Learning Graphical Model Parameters with
Approximate Marginal Inference
Justin Domke, NICTA & Australian National University
arXiv:1301.3193v1 [cs.LG] 15 Jan 2013

Abstract—Likelihood-based learning of graphical models faces challenges of computational complexity and robustness to model mis-specification. This paper studies methods that fit parameters directly to maximize a measure of the accuracy of predicted marginals, taking into account both model and inference approximations at training time. Experiments on imaging problems suggest marginalization-based learning performs better than likelihood-based approximations on difficult problems where the model being fit is approximate in nature.

Index Terms—Graphical Models, Conditional Random Fields, Machine Learning, Inference, Segmentation.

1 INTRODUCTION

Graphical models are a standard tool in image processing, computer vision, and many other fields. Exact inference and learning are often intractable, due to the high treewidth of the graph.

Much previous work involves approximations of the likelihood (Section 4). In this paper, we suggest that parameter learning can instead be done using "marginalization-based" loss functions. These directly quantify the quality of the predictions of a given marginal inference algorithm. This has two major advantages. First, approximation errors in the inference algorithm are taken into account while learning. Second, this is robust to model mis-specification.

The contributions of this paper are, first, the general framework of marginalization-based fitting as implicit differentiation. Second, we show that the parameter gradient can be computed by "perturbation", that is, by re-running the approximate algorithm twice with the parameters perturbed slightly based on the current loss. Third, we introduce the strategy of "truncated fitting". Inference algorithms are based on optimization, where one iterates updates until some convergence threshold is reached. In truncated fitting, algorithms are derived to fit the marginals produced after a fixed number of updates, with no assumption of convergence. We show that this leads to significant speedups. We also derive a variant of this that can apply to likelihood-based learning. Finally, experimental results confirm that marginalization-based learning gives better results on difficult problems where inference approximations and model mis-specification are most significant.

2 SETUP

2.1 Markov Random Fields

Markov random fields are probability distributions that may be written as

$$p(x) = \frac{1}{Z} \prod_c \psi(x_c) \prod_i \psi(x_i). \tag{1}$$

This is defined with reference to a graph, with one node for each random variable. The first product in Eq. 1 is over the set of cliques $c$ in the graph, while the second is over all individual variables. For example, the graph

[Figure: example graph on the variables $x_1, \dots, x_6$.]

corresponds to the distribution

$$p(x) = \frac{1}{Z}\, \psi(x_1,x_2)\, \psi(x_2,x_3,x_5)\, \psi(x_3,x_4)\, \psi(x_5,x_6) \times \psi(x_1)\, \psi(x_2)\, \psi(x_3)\, \psi(x_4)\, \psi(x_5)\, \psi(x_6).$$

Each function $\psi(x_c)$ or $\psi(x_i)$ is positive, but otherwise arbitrary. The factor $Z$ ensures normalization.

The motivation for these types of models is the Hammersley–Clifford theorem [1], which gives specific conditions under which a distribution can be written as in Eq. 1. Those conditions are that, first, each random variable is conditionally independent of all others, given its immediate neighbors and, second, that each configuration $x$ has nonzero probability. Often, domain knowledge about conditional independence can be used to build a reasonable graph, and the factorized representation in an MRF reduces the curse of dimensionality encountered in modeling a high-dimensional distribution.

2.2 Conditional Random Fields

One is often interested in modeling the conditional probability of $x$, given observations $y$. For such problems, it is natural to define a Conditional Random Field [2]

$$p(x|y) = \frac{1}{Z(y)} \prod_c \psi(x_c, y) \prod_i \psi(x_i, y).$$
Here, $\psi(x_c, y)$ indicates that the value for a particular configuration $x_c$ depends on the input $y$. In practice, the form of this dependence is application dependent.

2.3 Inference Problems

Suppose we have some distribution $p(x|y)$, we are given some input $y$, and we need to guess a single output vector $x^*$. What is the best guess?

The answer clearly depends on the meaning of "best". One framework for answering this question is the idea of a Bayes estimator [3]. One must specify some utility function $U(x,x')$, quantifying how "happy" one is to have guessed $x$ if the true output is $x'$. One then chooses $x^*$ to maximize the expected utility

$$x^* = \arg\max_x \sum_{x'} p(x'|y)\, U(x,x').$$

One natural utility function is an indicator function, giving one for the exact value $x'$, and zero otherwise. It is easy to show that for this utility, the optimal estimate is the popular Maximum a Posteriori (MAP) estimate.

Theorem. If $U(x,x') = I[x = x']$, then

$$x^* = \arg\max_x p(x|y).$$

Little can be said in general about whether this utility function truly reflects user priorities. However, in high-dimensional applications, there are reasons for skepticism. First, the actual maximizing probability $p(x^*|y)$ in a MAP estimate might be extremely small, so much so that astronomical numbers of examples might be necessary before one could expect to exactly predict the true output. Second, this utility does not distinguish between a prediction that contains only a single error at some component $x_j$, and one that is entirely wrong.

An alternative utility function, popular for imaging problems, is based on the Hamming distance: it counts the number of components of the output vector that are correct. Maximizing this results in selecting the most likely value for each component independently.

Theorem. If $U(x,x') = \sum_i I[x_i = x'_i]$, then

$$x^*_i = \arg\max_{x_i} p(x_i|y). \tag{2}$$

This appears to have been originally called Maximum Posterior Marginal (MPM) inference [4], though it has been reinvented under other names [5]. From a computational perspective, the main difficulty is not performing the trivial maximization in Eq. 2, but rather computing the marginals $p(x_i|y)$. The marginal-based loss functions introduced in Section 4.2 can be motivated by the idea that at test time, one will use an inference method similar to MPM, where one is concerned only with the accuracy of the marginals.

The results of MAP and MPM inference will be similar if the distribution $p(x|y)$ is heavily "peaked" at a single configuration $x$. Roughly, the greater the entropy of $p(x|y)$, the more there is to be gained in integrating over all possible configurations, as MPM does. A few papers have experimentally compared MAP and MPM inference [6], [7].

2.4 Exponential Family

The exponential family is defined by

$$p(x;\theta) = \exp\big(\theta \cdot f(x) - A(\theta)\big),$$

where $\theta$ is a vector of parameters, $f(x)$ is a vector of sufficient statistics, and the log-partition function

$$A(\theta) = \log \sum_x \exp \theta \cdot f(x) \tag{3}$$

ensures normalization. Different sufficient statistics $f(x)$ define different distributions. The exponential family is well understood in statistics. Accordingly, it is useful to note that a Markov random field (Eq. 1) is a member of the exponential family, with sufficient statistics consisting of indicator functions for each possible configuration of each clique and each variable [8], namely,

$$f(X) = \{ I[X_c = x_c] \mid \forall c, x_c \} \cup \{ I[X_i = x_i] \mid \forall i, x_i \}.$$

It is useful to introduce the notation $\theta(x_c)$ to refer to the component of $\theta$ corresponding to the indicator function $I[X_c = x_c]$, and similarly for $\theta(x_i)$. Then, the MRF in Eq. 1 would have $\psi(x_c) = e^{\theta(x_c)}$ and $\psi(x_i) = e^{\theta(x_i)}$. Many operations on graphical models can be more elegantly represented using this exponential family representation.

A standard problem in the exponential family is to compute the mean value of $f$,

$$\mu(\theta) = \sum_x p(x;\theta)\, f(x),$$

called the "mean parameters". It is easy to show these are equal to the gradient of the log-partition function:

$$\frac{dA}{d\theta} = \mu(\theta). \tag{4}$$

For an exponential family corresponding to an MRF, computing $\mu$ is equivalent to computing all the marginal probabilities. To see this, note that, using a similar notation for indexing $\mu$ as for $\theta$ above,

$$\mu(x_c;\theta) = \sum_X p(X;\theta)\, I[X_c = x_c] = p(x_c;\theta).$$

Conditional distributions can be represented by thinking of the parameter vector $\theta(y;\gamma)$ as being a function of the input $y$, where $\gamma$ are now the free parameters rather than $\theta$. (Again, the nature of the dependence of $\theta$ on $y$ and $\gamma$ will vary by application.) Then, we have that

$$p(x|y;\gamma) = \exp\big(\theta(y;\gamma) \cdot f(x) - A(\theta(y;\gamma))\big), \tag{5}$$

sometimes called a curved conditional exponential family.
2.5 Learning

The focus of this paper is learning of model parameters from data. (Automatically determining graph structure remains an active research area, but is not considered here.) Specifically, we take the goal of learning to be to minimize the empirical risk

$$R(\theta) = \sum_{\hat{x}} L(\theta, \hat{x}), \tag{6}$$

where the summation is over all examples $\hat{x}$ in the dataset, and the loss function $L(\theta, \hat{x})$ quantifies how well the distribution defined by the parameter vector $\theta$ matches the example $\hat{x}$. Several loss functions are considered in Section 4.

We assume that the empirical risk will be fit by some gradient-based optimization. Hence, the main technical issues in learning are which loss function to use and how to compute the gradient $\frac{dL}{d\theta}$.

In practice, we will usually be interested in fitting conditional distributions. Using the notation from Eq. 5, we can write this as

$$R(\gamma) = \sum_{(\hat{y}, \hat{x})} L\big(\theta(\hat{y}, \gamma), \hat{x}\big).$$

Note that if one has recovered $\frac{dL}{d\theta}$, then $\frac{dL}{d\gamma}$ is immediate from the vector chain rule as

$$\frac{dL}{d\gamma} = \frac{d\theta^T}{d\gamma} \frac{dL}{d\theta}. \tag{7}$$

Thus, the main technical problems involved in fitting a conditional distribution are similar to those for a generative distribution: one finds $\theta = \theta(\hat{y}, \gamma)$, computes $L$ and $\frac{dL}{d\theta}$ on example $\hat{x}$ exactly as in the generative case, and finally recovers $\frac{dL}{d\gamma}$ from Eq. 7. So, for simplicity, $y$ and $\gamma$ will largely be ignored in the theoretical developments below.

3 VARIATIONAL INFERENCE

This section reviews approximate methods for computing marginals, with notation based on Wainwright and Jordan [8]. For readability, all proofs in this section are postponed to Appendix A.

The relationship between the marginals and the log-partition function in Eq. 4 is key to defining approximate marginalization procedures. In Section 3.1, the exact variational principle shows that the (intractable) problem of computing the log-partition function can be converted to a (still intractable) optimization problem. To derive a tractable marginalization algorithm one approximates this optimization, yielding some approximate log-partition function $\tilde{A}(\theta)$. The approximate marginals are then taken as the exact gradient of $\tilde{A}$.

We define the reverse mapping $\theta(\mu)$ to return some parameter vector that yields the marginals $\mu$. While this will in general not be unique [8, Sec. 3.5.2], any two vectors that produce the same marginals $\mu$ will also yield the same distribution, and so $p(x;\theta(\mu))$ is unambiguous.

3.1 Exact Variational Principle

Theorem (Exact variational principle). The log-partition function can also be represented as

$$A(\theta) = \max_{\mu \in \mathcal{M}} \theta \cdot \mu + H(\mu), \tag{8}$$

where

$$\mathcal{M} = \{ \mu' : \exists \theta,\ \mu' = \mu(\theta) \}$$

is the marginal polytope, and

$$H(\mu) = -\sum_x p(x;\theta(\mu)) \log p(x;\theta(\mu))$$

is the entropy.

In treelike graphs, this optimization can be solved efficiently. In general graphs, however, it is intractable in two ways. First, the marginal polytope $\mathcal{M}$ becomes difficult to characterize. Second, the entropy is intractable to compute.

Applying Danskin's theorem to Eq. 8 yields that

$$\mu(\theta) = \frac{dA}{d\theta} = \arg\max_{\mu \in \mathcal{M}} \theta \cdot \mu + H(\mu). \tag{9}$$

Thus, the log-partition function (Eq. 8) and marginals (Eq. 9) can both be obtained from solving the same optimization problem. This close relationship between the log-partition function and marginals is heavily used in the derivation of approximate marginalization algorithms. To compute approximate marginals, first, derive an approximate version of the optimization in Eq. 8. Next, take the exact gradient of this approximate log-partition function. This strategy is used in both of the approximate marginalization procedures considered here: mean field and tree-reweighted belief propagation.

3.2 Mean Field

The idea of mean field is to approximate the exact variational principle by replacing $\mathcal{M}$ with some tractable subset $\mathcal{F} \subset \mathcal{M}$, such that $\mathcal{F}$ is easy to characterize, and for any vector $\mu \in \mathcal{F}$ we can exactly compute the entropy. To create such a set $\mathcal{F}$, instead of considering the set of mean vectors obtainable from any parameter vector (which characterizes $\mathcal{M}$), consider a subset of tractable parameter vectors. The simplest way to achieve this is to restrict consideration to parameter vectors $\theta$ with $\theta(x_c) = 0$ for all factors $c$:

$$\mathcal{F} = \{ \mu' : \exists \theta,\ \mu' = \mu(\theta),\ \forall c,\ \theta(x_c) = 0 \}.$$

It is not hard to see that this corresponds to the set of fully-factorized distributions. Note also that this is (in non-treelike graphs) a non-convex set, since it has the same convex hull as $\mathcal{M}$, but is a proper subset. So, the mean field approximation of the log-partition function is based on the optimization

$$\tilde{A}(\theta) = \max_{\mu \in \mathcal{F}} \theta \cdot \mu + H(\mu), \tag{10}$$
with approximate marginals corresponding to the maximizing vector $\mu$, i.e.

$$\tilde{\mu}(\theta) = \arg\max_{\mu \in \mathcal{F}} \theta \cdot \mu + H(\mu). \tag{11}$$

Since this is maximizing the same objective as the exact variational principle, but under a more restricted constraint set, clearly $\tilde{A}(\theta) \le A(\theta)$.

Here, since the marginals are coming from a fully-factorized distribution, the exact entropy is available as

$$H(\mu) = -\sum_i \sum_{x_i} \mu(x_i) \log \mu(x_i). \tag{12}$$

The strategy we use to perform the maximization in Eq. 10 is block-coordinate ascent. Namely, we pick a coordinate $j$, then set $\mu(x_j)$ to maximize the objective, leaving $\mu(x_i)$ fixed for all $i \ne j$. The next theorem formalizes this.

Theorem (Mean Field Updates). A local maximum of Eq. 10 can be reached by iterating the updates

$$\mu(x_j) \leftarrow \frac{1}{Z} \exp\Big( \theta(x_j) + \sum_{c : j \in c} \sum_{x_{c \setminus j}} \theta(x_c) \prod_{i \in c \setminus j} \mu(x_i) \Big),$$

where $Z$ is a normalizing factor ensuring that $\sum_{x_j} \mu(x_j) = 1$.

3.3 Tree-Reweighted Belief Propagation

Whereas mean field replaced the marginal polytope with a subset, tree-reweighted belief propagation (TRW) replaces it with a superset, $\mathcal{L} \supset \mathcal{M}$. This clearly can only increase the value of the approximate log-partition function. However, a further approximation is needed, as the entropy remains intractable to compute for an arbitrary mean vector $\mu$. (It is not even defined for $\mu \notin \mathcal{M}$.) Thus, TRW further approximates the entropy with a tractable upper bound. Taken together, these two approximations yield a tractable upper bound on the log-partition function.

Thus, TRW is based on the optimization problem

$$\tilde{A}(\theta) = \max_{\mu \in \mathcal{L}} \theta \cdot \mu + \tilde{H}(\mu). \tag{13}$$

Again, the approximate marginals are simply the maximizing vector $\mu$, i.e.,

$$\tilde{\mu}(\theta) = \arg\max_{\mu \in \mathcal{L}} \theta \cdot \mu + \tilde{H}(\mu). \tag{14}$$

The relaxation of the marginal polytope used in TRW is the local polytope,

$$\mathcal{L} = \Big\{ \mu : \sum_{x_{c \setminus i}} \mu(x_c) = \mu(x_i),\ \sum_{x_i} \mu(x_i) = 1 \Big\}. \tag{15}$$

Since any valid marginal vector must obey these constraints, clearly $\mathcal{M} \subset \mathcal{L}$. However, $\mathcal{L}$ in general also contains unrealizable vectors (though on trees $\mathcal{L} = \mathcal{M}$). Thus, the marginal vector returned by TRW may, in general, be inconsistent in the sense that no joint distribution yields those marginals.

The entropy approximation used by TRW is

$$\tilde{H}(\mu) = \sum_i H(\mu_i) - \sum_c \rho_c I(\mu_c), \tag{16}$$

where $H(\mu_i) = -\sum_{x_i} \mu(x_i) \log \mu(x_i)$ is the univariate entropy corresponding to variable $i$, and

$$I(\mu_c) = \sum_{x_c} \mu(x_c) \log \frac{\mu(x_c)}{\prod_{i \in c} \mu(x_i)} \tag{17}$$

is the mutual information corresponding to the variables in the factor $c$. The motivation for this approximation is that if the constants $\rho_c$ are selected appropriately, this gives an upper bound on the true entropy.

Theorem (TRW Entropy Bound). Let $\Pr(G)$ be a distribution over tree-structured graphs, and define $\rho_c = \Pr(c \in G)$. Then, with $\tilde{H}$ as defined in Eq. 16,

$$\tilde{H}(\mu) \ge H(\mu).$$

Thus, TRW is maximizing an upper bound on the exact variational principle, under an expanded constraint set. Since both of these changes can only increase the maximum value, we have that $\tilde{A}(\theta) \ge A(\theta)$.

Now, we consider how to actually compute the approximate log-partition function and associated marginals. Consider the message-passing updates

$$m_c(x_i) \propto \sum_{x_{c \setminus i}} e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{j \in c \setminus i} e^{\theta(x_j)} \frac{\prod_{d : j \in d} m_d(x_j)^{\rho_d}}{m_c(x_j)}, \tag{18}$$

where "$\propto$" is used as an assignment operator, meaning assignment after normalization.

Theorem (TRW Updates). Let $\rho_c$ be as in the previous theorem. Then, if the updates in Eq. 18 reach a fixed point, the marginals defined by

$$\mu(x_c) \propto e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{i \in c} e^{\theta(x_i)} \frac{\prod_{d : i \in d} m_d(x_i)^{\rho_d}}{m_c(x_i)},$$

$$\mu(x_i) \propto e^{\theta(x_i)} \prod_{d : i \in d} m_d(x_i)^{\rho_d}$$

constitute the global optimum of Eq. 13.

So, if the updates happen to converge, we have the solution. Meltzer et al. show [9] that on certain graphs made up of monotonic chains, an appropriate ordering of messages does assure convergence. (The proof is essentially that under these circumstances, message passing is equivalent to coordinate ascent in the dual.)

TRW simplifies into loopy belief propagation by choosing $\rho_c = 1$ everywhere, though the bounding property is lost.
4 LOSS FUNCTIONS

For space, only a representative sample of prior work can be cited. A recent review [10] is more thorough.

Though, technically, a "loss" should be minimized, we continue to use this terminology for the likelihood and its approximations, where one wishes to maximize.

For simplicity, the discussion below is for the generative setting. Using the same loss functions for training a conditional model is simple (Section 2.5).

4.1 The Likelihood and Approximations

The classic loss function would be the likelihood, with

$$L(\theta, x) = \log p(x;\theta) = \theta \cdot f(x) - A(\theta). \tag{19}$$

This has the gradient

$$\frac{dL}{d\theta} = f(x) - \mu(\theta). \tag{20}$$

One argument for the likelihood is that it is efficient; given a correct model, as data increases it converges to true parameters at an asymptotically optimal rate [11].

Some previous work uses tree-structured graphs where marginals may be computed exactly [12]. Of course, in high-treewidth graphs, the likelihood and its gradient will be intractable to compute exactly, due to the presence of the log-partition function $A(\theta)$ and marginals $\mu(\theta)$. This has motivated a variety of approximations. The first is to approximate the marginals $\mu$ using Markov chain Monte Carlo [13], [14]. This can lead to high computational expense (particularly in the conditional case, where different chains must be run for each input). Contrastive Divergence [15] further approximates these samples by running the Markov chain for only a few steps, but started at the data points [16]. If the Markov chain is run long enough, these approaches can give an arbitrarily good approximation. However, Markov chain parameters may need to be adjusted to the particular problem, and these approaches are generally slower than those discussed below.

4.1.1 Surrogate Likelihood

A seemingly heuristic approach would be to replace the marginals in Eq. 20 with those from an approximate inference method. This approximation can be quite principled if one thinks instead of approximating the log-partition function in the likelihood itself (Eq. 19). Then, the corresponding approximate marginals will emerge as the exact gradient of this surrogate loss. This "surrogate likelihood" [17] approximation appears to be the most widely used loss in imaging problems, with marginals approximated by either mean field [18], [19], TRW [20] or LBP [21], [22], [23], [24], [25]. However, the terminology of "surrogate likelihood" is not widespread and in most cases, only the gradient is computed, meaning the optimization cannot use line searches.

If one uses a log-partition approximation that provides a bound on the true log-partition function, the surrogate likelihood will then bound the true likelihood. Specifically, mean field based surrogate likelihood is an upper bound on the true likelihood (since $\tilde{A}(\theta) \le A(\theta)$), while TRW-based surrogate likelihood is a lower bound (since $\tilde{A}(\theta) \ge A(\theta)$).

4.1.2 Expectation Maximization

In many applications, only a subset of variables may be observed. Suppose that we want to model $x = (z, h)$ where $z$ is observed, but $h$ is hidden. A natural loss function here is the expectation maximization (EM) loss

$$L(\theta, z) = \log p(z;\theta) = \log \sum_h p(z, h;\theta).$$

It is easy to show that this is equivalent to

$$L(\theta, z) = A(\theta, z) - A(\theta), \tag{21}$$

where $A(\theta, z) = \log \sum_h \exp \theta \cdot f(z, h)$ is the log-partition function with $z$ "clamped" to the observed values. If all variables are observed, $A(\theta, z)$ reduces to $\theta \cdot f(z)$.

If one substitutes a variational approximation for $A(\theta, z)$, a "variational EM" algorithm [8, Sec. 6.2.2] can be recovered that alternates between computing approximate marginals and parameter updates. Here, because of the close relationship to the surrogate likelihood, we designate "surrogate EM" for the case where $A(\theta, z)$ and $A(\theta)$ may both be approximated and the learning is done with a gradient-based method. To obtain a bound on the true EM loss, care is required. For example, lower-bounding $A(\theta, z)$ using mean field, and upper-bounding $A(\theta)$ using TRW means a lower-bound on the true EM loss. However, using the same approximation for both $A(\theta)$ and $A(\theta, z)$ appears to work well in practice [26].

4.1.3 Saddle-Point Approximation

A third approximation of the likelihood is to search for a "saddle-point". Here, one approximates the gradient in Eq. 20 by running a (presumably approximate) MAP inference algorithm, and then imagining that the marginals put unit probability at the approximate MAP solution, and zero elsewhere [27], [28], [21]. This is a heuristic method, but it can be expected to work well when the estimated MAP solution is close to the true MAP and the conditional distribution $p(x|y)$ is strongly "peaked".

4.1.4 Pseudolikelihood

Finally, there are two classes of likelihood approximations that do not require inference. The first is the classic pseudolikelihood [29], where one uses

$$L(\theta, x) = \sum_i \log p(x_i | x_{-i};\theta).$$

This can be computed efficiently, even in high-treewidth graphs, since conditional probabilities are easy to compute. Besag [29] showed that, under certain conditions, this will converge to the true parameter vector as the amount of data becomes infinite. The pseudolikelihood has been used in many applications [30], [31].
Instead of the probability of individual variables given all others, one can take the probability of patches of variables given all others, sometimes called the "patch" pseudolikelihood [32]. This interpolates to the exact likelihood as the patches become larger, though some type of inference is generally required.

4.1.5 Piecewise Likelihood

More recently, Sutton and McCallum [33] suggested the piecewise likelihood. The idea is to approximate the log-partition function as a sum of log-partition functions of the different "pieces" of the graph. There is flexibility in determining which pieces to use. In this paper, we will use pieces consisting of each clique and each variable, which worked better in practice than some alternatives. Then, one has the surrogate log-partition function

$$\tilde{A}(\theta) = \sum_c A_c(\theta) + \sum_i A_i(\theta),$$

$$A_c(\theta) = \log \sum_{x_c} e^{\theta(x_c)}, \qquad A_i(\theta) = \log \sum_{x_i} e^{\theta(x_i)}.$$

It is not too hard to show that $A(\theta) \le \tilde{A}(\theta)$. In practice, it is sometimes best to make some heuristic adjustments to the parameters after learning to improve test-time performance [34], [35].

4.2 Marginal-based Loss Functions

Given the discussion in Section 4.1, one might conclude that the likelihood, while difficult to optimize, is an ideal loss function since, given a well-specified model, it will converge to the true parameters at asymptotically efficient rates. However, this conclusion is complicated by two issues. First, of course, the maximum likelihood solution is computationally intractable, motivating the approximations above.

A second issue is that of model mis-specification. For many types of complex phenomena, we will wish to fit a model that is approximate in nature. This could be true because the conditional independencies asserted by the graph do not exactly hold, or because the parametrization of factors is too simplistic. These approximations might be made out of ignorance, due to a lack of knowledge about the domain being studied, or deliberately because the true model might have too many degrees of freedom to be fit with available data.

In the case of an approximate model, no "true" parameters exist. The idea of marginal-based loss functions is to instead consider how the model will be used. If one will compute marginals at test time (perhaps for MPM inference, Section 2.3), it makes sense to maximize the accuracy of these predictions. Further, if one will use an approximate inference algorithm, it makes sense to optimize the accuracy of the approximate marginals. This essentially fits into the paradigm of empirical risk minimization [36], [37]. The idea of training a probabilistic model using an alternative loss to the likelihood goes back at least to Bahl et al. in the late 1980s [38].

There is reason to think the likelihood is somewhat robust to model mis-specification. In the infinite data limit, it finds the "closest" solution in the sense of KL-divergence since, if $q$ is the true distribution, then

$$KL(q \| p) = \text{const.} - E_q \log p(x;\theta).$$

4.2.1 Univariate Logistic Loss

The univariate logistic loss [39] is defined by

$$L(\theta, x) = -\sum_i \log \mu(x_i;\theta),$$

where we use the notation $\mu$ to indicate that the loss is implicitly defined with respect to the marginal predictions of some (possibly approximate) algorithm, rather than the true marginals. This measures the mean accuracy of all univariate marginals, rather than the joint distribution. This loss can be seen as empirical risk minimization of the KL-divergence between the true marginals and the predicted ones, since

$$\sum_i KL(q_i \| \mu_i) = \sum_i \sum_{x_i} q(x_i) \log \frac{q(x_i)}{\mu(x_i;\theta)} = \text{const.} - E_q \sum_i \log \mu(x_i;\theta).$$

If defined on exact marginals, this is a type of composite likelihood [40].

4.2.2 Smoothed Univariate Classification Error

Perhaps the most natural loss in the conditional setting would be the univariate classification error,

$$L(\theta, x) = \sum_i S\Big( \max_{x'_i \ne x_i} \mu(x'_i;\theta) - \mu(x_i;\theta) \Big),$$

where $S(\cdot)$ is the step function. This exactly measures the number of components of $x$ that would be incorrectly predicted if using MPM inference. Of course, this loss is neither differentiable nor continuous, which makes it impractical to optimize using gradient-based methods. Instead, Gross et al. [5] suggest approximating with a sigmoid function $S(t) = (1 + \exp(-\alpha t))^{-1}$, where $\alpha$ controls approximation quality.

There is evidence [36], [5] that the smoothed classification loss can yield parameters with lower univariate classification error under MPM inference. However, our experience is that it is also more prone to getting stuck in local minima, making experiments difficult to interpret. Thus, it is not included in the experiments below. Our experience with the univariate quadratic loss [41] is similar.

4.2.3 Clique Losses

Any of the above univariate losses can instead be taken based on cliques. For example, the clique logistic loss is

$$L(\theta, x) = -\sum_c \log \mu(x_c;\theta),$$
which may be seen as empirical risk minimization of the mean KL-divergence of the true clique marginals to the predicted ones. An advantage of this with an exact model is consistency. Simple examples show cases where a model predicts perfect univariate marginals, despite the joint distribution being very inaccurate. However, if all clique marginals are correct, the joint must be correct, by the standard moment matching conditions for the exponential family [8].

[Figure 1 plot: three panels (shift of 0, shift of 3, shift of 10), each showing mean test error against # training data for the likelihood, clique logistic, and univariate logistic losses.]

Figure 1: Mean test error of various loss functions trained with exact inference. In the case of a well-specified model (shift of zero), the likelihood performs essentially identically to the marginal-based loss functions. However, when mis-specification is introduced, quite different estimates result.

4.2.4 Hidden Variables

Marginal-based loss functions can accommodate hidden variables by simply taking the sum in the loss over the observed variables only. A similar approach can be used with the pseudolikelihood or piecewise likelihood.

4.3 Comparison with Exact Inference

To compare the effects of different loss functions in the presence of model mis-specification, this section contains a simple example where the graphical model takes the following "chain" structure:

[Figure: chain-structured graph over $x_1, x_2, x_3, x_4, \dots, x_n$ and $y_1, y_2, y_3, y_4, \dots, y_n$.]

Here, exact inference is possible, so comparison is not complicated by approximate inference.

All variables are binary. Parameters are generated by taking $\theta(x_i)$ randomly from the interval $[-1, +1]$ for all $i$ and $x_i$. Interaction parameters are taken as $\theta(x_i, x_j) = t$ when $x_i = x_j$, and $\theta(x_i, x_j) = -t$ when $x_i \ne x_j$, where $t$ is randomly chosen from the interval $[-1, +1]$ for all $(i, j)$. Interactions $\theta(y_i, y_j)$ and $\theta(x_i, y_i)$ are chosen in the same way.

To systematically study the effects of differing "amounts" of mis-specification, after generating data, we apply various circular shifts to $x$. Thus, the data no longer corresponds exactly to the structure of the graphical model being fit.

Thirty-two different random distributions were created. For each, various quantities of data were generated by Markov chain Monte Carlo, with shifts introduced after sampling. The likelihood was fit using the closed-form gradient (Sec. 4.1), while the logistic losses were trained using a gradient obtained via backpropagation (Sec. 7). Fig. 1 shows the mean test error (estimated on 1000 examples), while Fig. 2 shows example marginals. We see that the performance of all methods deteriorates with mis-specification, but the marginal-based loss functions are more resistant to these effects.

4.4 MAP-Based Training

Another class of methods explicitly optimizes the performance of MAP inference [42], [43], [44], [45], [25]. This paper focuses on applications that use marginal inference, and that may need to accommodate hidden variables, and so concentrates on likelihood and marginal-based losses.

5 IMPLICIT FITTING

We now turn to the issue of how to train high-treewidth graphical models to optimize the performance of a marginal-based loss function, based on some approximate inference algorithm. Now, computing the value of the loss for any of the marginal-based loss functions is not hard. One can simply run the inference algorithm and plug the resulting marginals into the loss. However, we also require the gradient $\frac{dL}{d\theta}$.

Our first result is that the loss gradient can be obtained by solving a sparse linear system. Here, it is useful to introduce notation to distinguish the loss $L$, defined in terms of the parameters $\theta$, from the loss $Q$, defined directly in terms of the marginals $\mu$. (Note that though the notation suggests the application to marginal inference, this is a generic result.)

Theorem. Suppose that

$$\mu(\theta) := \arg\max_{\mu : B\mu = d} \theta \cdot \mu + H(\mu). \tag{22}$$

Define $L(\theta, x) = Q(\mu(\theta), x)$. Then, letting $D = \frac{d^2 H}{d\mu\, d\mu^T}$,

$$\frac{dL}{d\theta} = \Big( D^{-1} B^T (B D^{-1} B^T)^{-1} B D^{-1} - D^{-1} \Big) \frac{dQ}{d\mu}.$$

A proof may be found in Appendix B. This theorem states that, essentially, once one has computed the predicted marginals, the gradient of the loss with respect to marginals $\frac{dQ}{d\mu}$ can be transformed into the gradient of the loss with respect to parameters $\frac{dL}{d\theta}$ through the solution of a sparse linear system.

The optimization in Eq. 22 takes place under linear constraints, which encompasses the local polytope used in TRW message-passing (Eq. 15). This theorem does not apply to mean field, as $\mathcal{F}$ is not a linear constraint set when viewed as a function of both clique and univariate marginals.

In any case, the methods developed below are simpler to use, as they do not require explicitly forming the constraint matrix $B$ or solving the linear system.
[Figure 2 plot: three panels (true marginals, shift of 3, shift of 10), each showing $p(y_i = 1 | x)$ on a 0 to 1 axis against the index $i$ from 0 to 200.]

Figure 2: Exact and predicted marginals for an example input. Predicted marginals are trained using 1000 data. With low shifts, all loss functions lead to accurate predicted marginals. However, the univariate and clique logistic loss are more resistant to the effects of model mis-specification. Legends as in Fig. 1.
6 PERTURBATION
This section observes that variational methods have a special structure that allows derivatives to be calculated without explicitly forming or inverting a linear system. We have, by the vector chain rule, that
$$\frac{dL}{d\theta} = \frac{d\mu^T}{d\theta} \frac{dQ}{d\mu}. \qquad (23)$$
A classic trick in scientific computing is to efficiently compute Jacobian-vector products by finite differences. The basic result is that, for any vector v,
$$\frac{d\mu}{d\theta^T} v = \lim_{r \to 0} \frac{1}{r} \Big( \mu(\theta + r v) - \mu(\theta) \Big),$$
which is essentially just the definition of the derivative of µ in the direction of v. Now, this does not immediately seem helpful, since Eq. 23 requires $d\mu^T / d\theta$, not $d\mu / d\theta^T$. However, with variational methods, these are symmetric. The simplest way to see this is to note that
$$\frac{d\mu}{d\theta^T} = \frac{d}{d\theta^T} \left( \frac{dA}{d\theta} \right) = \frac{d^2 A}{d\theta \, d\theta^T}.$$
Domke [46] lists conditions for various classes of entropies that guarantee that A will be differentiable.
Combining the above three equations, the loss gradient is available as the limit
$$\frac{dL}{d\theta} = \lim_{r \to 0} \frac{1}{r} \Big( \mu\big(\theta + r \frac{dQ}{d\mu}\big) - \mu(\theta) \Big). \qquad (24)$$
In practice, of course, the gradient is approximated using some finite r. The simplest approximation, one-sided differences, simply takes a single value of r in Eq. 24, rather than a limit. More accurate results, at the cost of more calls to inference, are given using two-sided differences, with
$$\frac{dL}{d\theta} \approx \frac{1}{2r} \Big( \mu\big(\theta + r \frac{dQ}{d\mu}\big) - \mu\big(\theta - r \frac{dQ}{d\mu}\big) \Big),$$
which is accurate to order o(r^2). Still more accurate results are obtained with “four-sided” differences, with
$$\frac{dL}{d\theta} \approx \frac{1}{12r} \Big( -\mu\big(\theta + 2r \frac{dQ}{d\mu}\big) + 8\mu\big(\theta + r \frac{dQ}{d\mu}\big) - 8\mu\big(\theta - r \frac{dQ}{d\mu}\big) + \mu\big(\theta - 2r \frac{dQ}{d\mu}\big) \Big),$$
which is accurate to order o(r^4) [47].
Alg. 1 shows more explicitly how the loss gradient could be calculated, using two-sided differences.
The issue remains of how to calculate the step size r. Each of the approximations above becomes exact as r → 0. However, as r becomes very small, numerical error eventually dominates. To investigate this issue experimentally, we generated random models on a 10×10 binary grid, with each parameter θ(x_i) randomly chosen from a standard normal, while each interaction parameter θ(x_i, x_j) was chosen randomly from a normal with a standard deviation of s. In each case, a random value x was generated, and the “true” loss gradient was estimated by standard (inefficient) 2-sided finite differences, with inference re-run after each component of θ is perturbed independently. To this, we compare one, two, and four-sided perturbations. In all cases, the step size is, following Andrei [48], taken to be $r = m \epsilon^{1/3} \big(1 + \|\theta\|_\infty\big) / \|\frac{dQ}{d\mu}\|_\infty$, where ε is machine epsilon, and m is a multiplier that we will vary. Note that the optimal power of ε will depend on the finite difference scheme; 1/3 is optimal for two-sided differences [49, Sec. 8.1]. All calculations take place in double-precision with
inference run until marginals changed by a threshold of less than 10^{-15}. Fig. 3 shows that using many-sided differences leads to more accuracy, at the cost of needing to run inference more times to estimate a single loss gradient. In the following experiments, we chose two-sided differences with a multiplier of 1 as a reasonable tradeoff between accuracy, simplicity, and computational expense.

Algorithm 1 Calculating $\frac{dL}{d\theta}$ by perturbation (two-sided).
1) Do inference. $\mu^* \leftarrow \arg\max_{\mu \in \mathcal{M}} \theta \cdot \mu + H(\mu)$
2) At $\mu^*$, calculate the gradient $\frac{dQ}{d\mu}$.
3) Calculate a perturbation size r.
4) Do inference on perturbed parameters.
   a) $\mu^+ \leftarrow \arg\max_{\mu \in \mathcal{M}} (\theta + r \frac{dQ}{d\mu}) \cdot \mu + H(\mu)$
   b) $\mu^- \leftarrow \arg\max_{\mu \in \mathcal{M}} (\theta - r \frac{dQ}{d\mu}) \cdot \mu + H(\mu)$
5) Recover full derivative as $\frac{dL}{d\theta} \leftarrow \frac{1}{2r} (\mu^+ - \mu^-)$.

[Figure 3 plots: gradient error against perturbation multiplier (10^0 to 10^6) for 1-sided, 2-sided, and 4-sided differences; two rows of three panels, with panels labeled r = 1, r = 2, and r = 3.]
Figure 3: An evaluation of perturbation multipliers m. Top: TRW. Bottom: Mean field. Two effects are in play here: First, for too small a perturbation, numerical errors dominate. Meanwhile, for too large a perturbation, approximation errors dominate. We see that using 2- or 4-sided differences reduces approximation error, leading to better results with larger perturbations.

Welling and Teh used sensitivity of approximate beliefs to parameters to approximate joint probabilities of non-neighboring variables [50].

7 TRUNCATED FITTING
The previous methods for computing loss gradients are derived under the assumption that the inference optimization is solved exactly. In an implementation, of course, some convergence threshold must be used.
Different convergence thresholds can be used in the learning stage and at test time. In practice, we have observed that too loose a threshold in the learning stage can lead to a bad estimated risk gradient, and learning terminating with a bad search direction. Meanwhile, a loose threshold can often be used at test time with few consequences. Usually, a difference of 10^{-3} in estimated marginals has little practical impact, but this can still be enough to prevent learning from succeeding [51].
It seems odd that the learning algorithm would spend the majority of computational effort exploring tight convergence levels that are irrelevant to the practical performance of the model. Here, we define the learning objective in terms of the approximate marginals obtained after a fixed number of iterations. To understand this, one may think of the inference process not as an optimization, but rather as a large, nonlinear function. This clearly leads to a well-defined objective function. Inputting parameters, applying the iterations of either TRW or mean field, computing predicted marginals, and finally a loss are all differentiable operations. Thus, the loss gradient is efficiently computable, at least in principle, by reverse-mode automatic differentiation (autodiff), an approach explored by Stoyanov et al. [36], [52]. In preliminary work, we experimented with autodiff tools, but found these to be unsatisfactory for our applications for two reasons. Firstly, these tools impose a computational penalty over manually derived gradients. Secondly, autodiff stores all intermediate calculations, leading to large memory requirements. The methods derived below use less memory, both in terms of constant factors and big-O complexity. Nevertheless, some of these problems are issues with current implementations of reverse-mode autodiff, avoidable in theory.
Both mean field and TRW involve steps where we first take a product of a set of terms, and then normalize. We define a “backnorm” operator, which is useful in taking derivatives over such operations, by
$$\mathrm{backnorm}(g, c) = c \odot (g - g \cdot c).$$
This will be used in the algorithms here. More discussion on this point can be found in Appendix C.

7.1 Back Mean Field
The first backpropagating inference algorithm, back mean field, is shown as Alg. 2. The idea is as follows: Suppose we start with uniform marginals, run N iterations of mean field, and then, regardless of whether mean field has converged, take predicted marginals and plug them into one of the marginal-based loss functions. Since each step in this process is differentiable, this specifies the loss as a differentiable function of model parameters. We want the exact gradient of this function.
Theorem. After execution of back mean field,
$$\overleftarrow{\theta}(x_i) = \frac{dL}{d\theta(x_i)} \quad \text{and} \quad \overleftarrow{\theta}(x_c) = \frac{dL}{d\theta(x_c)}.$$
A proof sketch is in Appendix C. Roughly speaking, the proof takes the form of a mechanical differentiation of each step of the inference process.
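The theorem for back mean field can be checked numerically on a toy model. The sketch below is a minimal pure-Python instance of Alg. 2 under several assumptions that are not in the text: variables are binary, all cliques are pairwise, parameters are stored as nested lists (`th1[i][x_i]`, `th2[e][x_a][x_b]` for edge `e = (a, b)`), and the loss is the univariate logistic loss. All function names are hypothetical.

```python
import math


def backnorm(g, c):
    """backnorm(g, c) = c * (g - g.c), applied elementwise in c."""
    dot = sum(gi * ci for gi, ci in zip(g, c))
    return [ci * (gi - dot) for gi, ci in zip(g, c)]


def mean_field(th1, th2, edges, n_iters):
    """Run n_iters sweeps of mean field from uniform marginals.

    Returns the final marginals and a stack of the overwritten marginals.
    """
    mu = [[0.5, 0.5] for _ in th1]
    stack = []
    for _ in range(n_iters):
        for j in range(len(th1)):
            stack.append((j, mu[j]))  # push the marginal about to be replaced
            z = list(th1[j])
            for e, (a, b) in enumerate(edges):
                if j not in (a, b):
                    continue
                k = b if a == j else a
                for xj in range(2):
                    for xk in range(2):
                        t = th2[e][xj][xk] if a == j else th2[e][xk][xj]
                        z[xj] += t * mu[k][xk]
            w = [math.exp(v - max(z)) for v in z]
            mu[j] = [v / sum(w) for v in w]
    return mu, stack


def truncated_loss(th1, th2, edges, x, n_iters):
    """Univariate logistic loss of the marginals after n_iters sweeps."""
    mu, _ = mean_field(th1, th2, edges, n_iters)
    return -sum(math.log(mu[i][xi]) for i, xi in enumerate(x))


def back_mean_field(th1, th2, edges, x, n_iters):
    """Gradient of truncated_loss w.r.t. th1 and th2 (pairwise case of Alg. 2)."""
    mu, stack = mean_field(th1, th2, edges, n_iters)
    # dL/dmu for the univariate logistic loss.
    bmu = [[-1.0 / m[v] if v == xi else 0.0 for v in range(2)]
           for m, xi in zip(mu, x)]
    bth1 = [[0.0, 0.0] for _ in th1]
    bth2 = [[[0.0, 0.0], [0.0, 0.0]] for _ in edges]
    while stack:  # popping visits the updates in reverse order
        j, old = stack.pop()
        nu = backnorm(bmu[j], mu[j])
        for xj in range(2):
            bth1[j][xj] += nu[xj]
        for e, (a, b) in enumerate(edges):
            if j not in (a, b):
                continue
            k = b if a == j else a
            for xj in range(2):
                for xk in range(2):
                    if a == j:
                        bth2[e][xj][xk] += nu[xj] * mu[k][xk]
                        bmu[k][xk] += nu[xj] * th2[e][xj][xk]
                    else:
                        bth2[e][xk][xj] += nu[xj] * mu[k][xk]
                        bmu[k][xk] += nu[xj] * th2[e][xk][xj]
        mu[j] = old          # pull the earlier marginal from the stack
        bmu[j] = [0.0, 0.0]  # the new marginal did not depend on the old one
    return bth1, bth2


# Toy three-node chain; all numbers are arbitrary illustrative values.
th1 = [[0.2, -0.1], [0.0, 0.3], [-0.4, 0.1]]
th2 = [[[0.5, -0.5], [-0.5, 0.5]], [[-0.3, 0.3], [0.3, -0.3]]]
edges = [(0, 1), (1, 2)]
x = [1, 0, 1]
g1, g2 = back_mean_field(th1, th2, edges, x, n_iters=3)
```

Comparing entries of `g1` and `g2` against central finite differences of `truncated_loss` gives a direct check that the backward sweep differentiates the truncated objective itself, whether or not three sweeps are enough for mean field to converge.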
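Algorithm 2 Back Mean Field
1) Initialize $\mu$ uniformly.
2) Repeat N times for all j:
   a) Push the marginals $\mu_j$ onto a stack.
   b) $\mu(x_j) \propto \exp\Big( \theta(x_j) + \sum_{c : j \in c} \sum_{x_{c \setminus j}} \theta(x_c) \prod_{i \in c \setminus j} \mu(x_i) \Big)$
3) Compute $L$, $\overleftarrow{\mu}(x_j) = \frac{dL}{d\mu(x_j)}$ and $\overleftarrow{\mu}(x_c) = \frac{dL}{d\mu(x_c)}$.
4) Initialize $\overleftarrow{\theta}(x_i) \leftarrow 0$, $\overleftarrow{\theta}(x_c) \leftarrow 0$.
5) Repeat N times for all j (in reverse order):
   a) $\overleftarrow{\nu}_j \leftarrow \mathrm{backnorm}(\overleftarrow{\mu}_j, \mu_j)$
   b) $\overleftarrow{\theta}(x_j) \leftarrow \overleftarrow{\theta}(x_j) + \overleftarrow{\nu}(x_j)$
   c) $\overleftarrow{\theta}(x_c) \leftarrow \overleftarrow{\theta}(x_c) + \overleftarrow{\nu}(x_j) \prod_{i \in c \setminus j} \mu(x_i) \quad \forall c : j \in c$
   d) $\overleftarrow{\mu}(x_i) \leftarrow \overleftarrow{\mu}(x_i) + \sum_{x_{c \setminus i}} \overleftarrow{\nu}(x_j) \theta(x_c) \prod_{k \in c \setminus \{i,j\}} \mu(x_k) \quad \forall c : j \in c, \ \forall i \in c \setminus j$
   e) Pull marginals $\mu_j$ from the stack.
   f) $\overleftarrow{\mu}_j(x_j) \leftarrow 0$

Note that, as written, back mean field only produces univariate marginals, and so cannot cope with loss functions making use of clique marginals. However, with mean field, the clique marginals are simply the product of univariate marginals: $\mu(x_c) = \prod_{i \in c} \mu(x_i)$. Hence, any loss defined on clique marginals can equivalently be defined on univariate marginals.

7.2 Back TRW
Next, we consider truncated fitting with TRW inference. As above, we will assume that some fixed number N of inference iterations have been run, and we want to define and differentiate a loss defined on the current predicted marginals. Alg. 3 shows the method.

Algorithm 3 Back TRW.
1) Initialize $m$ uniformly.
2) Repeat N times for all pairs (c, i), with $i \in c$:
   a) Push the messages $m_c(x_i)$ onto a stack.
   b) $m_c(x_i) \propto \sum_{x_{c \setminus i}} e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{j \in c \setminus i} e^{\theta(x_j)} \frac{\prod_{d : j \in d} m_d(x_j)^{\rho_d}}{m_c(x_j)}$
3) $\mu(x_c) \propto e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{i \in c} e^{\theta(x_i)} \frac{\prod_{d : i \in d} m_d(x_i)^{\rho_d}}{m_c(x_i)} \quad \forall c$
4) $\mu(x_i) \propto e^{\theta(x_i)} \prod_{d : i \in d} m_d(x_i)^{\rho_d} \quad \forall i$
5) Compute $L$, $\overleftarrow{\mu}(x_i) = \frac{dL}{d\mu(x_i)}$ and $\overleftarrow{\mu}(x_c) = \frac{dL}{d\mu(x_c)}$.
6) For all c,
   a) $\overleftarrow{\nu}(x_c) \leftarrow \mathrm{backnorm}(\overleftarrow{\mu}_c, \mu_c)$
   b) $\overleftarrow{\theta}(x_c) \xleftarrow{+} \frac{1}{\rho_c} \overleftarrow{\nu}(x_c)$
   c) $\overleftarrow{\theta}(x_i) \xleftarrow{+} \sum_{x_{c \setminus i}} \overleftarrow{\nu}(x_c) \quad \forall i \in c$
   d) $\overleftarrow{m}_d(x_i) \xleftarrow{+} \frac{\rho_d - I_{c=d}}{m_d(x_i)} \sum_{x_{c \setminus i}} \overleftarrow{\nu}(x_c) \quad \forall i \in c, \ \forall d : i \in d$
7) For all i,
   a) $\overleftarrow{\nu}(x_i) \leftarrow \mathrm{backnorm}(\overleftarrow{\mu}_i, \mu_i)$
   b) $\overleftarrow{\theta}(x_i) \xleftarrow{+} \overleftarrow{\nu}(x_i)$
   c) $\overleftarrow{m}_d(x_i) \xleftarrow{+} \rho_d \frac{\overleftarrow{\nu}(x_i)}{m_d(x_i)} \quad \forall d : i \in d$
8) Repeat N times for all pairs (c, i) (in reverse order):
   a) $s(x_c) \leftarrow e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{j \in c \setminus i} e^{\theta(x_j)} \frac{\prod_{d : j \in d} m_d(x_j)^{\rho_d}}{m_c(x_j)}$
   b) $\overleftarrow{\nu}(x_i) \leftarrow \mathrm{backnorm}(\overleftarrow{m}_{ci}, m_{ci})$
   c) $\overleftarrow{\theta}(x_c) \xleftarrow{+} \frac{1}{\rho_c} s(x_c) \frac{\overleftarrow{\nu}(x_i)}{m_c(x_i)}$
   d) $\overleftarrow{\theta}(x_j) \xleftarrow{+} \sum_{x_{c \setminus j}} s(x_c) \frac{\overleftarrow{\nu}(x_i)}{m_c(x_i)} \quad \forall j \in c \setminus i$
   e) $\overleftarrow{m}_d(x_j) \xleftarrow{+} \frac{\rho_d - I_{c=d}}{m_d(x_j)} \sum_{x_{c \setminus j}} s(x_c) \frac{\overleftarrow{\nu}(x_i)}{m_c(x_i)} \quad \forall j \in c \setminus i, \ \forall d : j \in d$
   f) Pull messages $m_c(x_i)$ from the stack.
   g) $\overleftarrow{m}_c(x_i) \leftarrow 0$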
Theorem. After execution of back TRW,
$$\overleftarrow{\theta}(x_i) = \frac{dL}{d\theta(x_i)} \quad \text{and} \quad \overleftarrow{\theta}(x_c) = \frac{dL}{d\theta(x_c)}.$$
Again, a proof sketch is in Appendix C.
If one uses pairwise factors only, uniform appearance probabilities of ρ = 1, removes all reference to the stack, and uses a convergence threshold in place of a fixed number of iterations, one obtains essentially Eaton and Ghahramani’s back belief propagation [53, extended version, Fig. 5]. Here, we refer to the general strategy of using full (non-truncated) inference as “backpropagation”, either with LBP, TRW, or mean field.

7.3 Truncated Likelihood & Truncated EM
Applying the truncated fitting strategies to any of the marginal-based loss functions is simple. Applying it to the likelihood or EM loss, however, is not so straightforward. The reason is that these losses (Eqs. 19 and 21) are defined, not in terms of predicted marginals, but in terms of partition functions. Nevertheless, we wish to compare to these losses in the experiments below. As we found truncation to be critical for speed, we instead derive a variant of truncated fitting.
The basic idea is to define a “truncated partition function”. This is done by taking the predicted marginals, obtained after a fixed number of iterations, and plugging them into the entropy approximations used either for mean field (Eq. 12) or TRW (Eq. 16). The approximate entropy H̃ is then used in defining a truncated partition function as
$$\tilde{A}(\theta) = \theta \cdot \mu(\theta) - \tilde{H}(\mu(\theta)).$$
As we will see below, with too few inference iterations, using this approximation can cause the surrogate likelihood to diverge. To see why, imagine an extreme case where zero inference iterations are used. This results in the loss L(θ,x) = θ·(f(x) − µ0) + H̃(µ0), where µ0 are the