Learning Graphical Model Parameters with Approximate Marginal Inference

Justin Domke, NICTA & Australia National University

arXiv:1301.3193v1 [cs.LG] 15 Jan 2013

Abstract: Likelihood-based learning of graphical models faces challenges of computational complexity and robustness to model mis-specification. This paper studies methods that fit parameters directly to maximize a measure of the accuracy of predicted marginals, taking into account both model and inference approximations at training time. Experiments on imaging problems suggest marginalization-based learning performs better than likelihood-based approximations on difficult problems where the model being fit is approximate in nature.

Index Terms: Graphical Models, Conditional Random Fields, Machine Learning, Inference, Segmentation.

1 INTRODUCTION

Graphical models are a standard tool in image processing, computer vision, and many other fields. Exact learning and inference are often intractable, due to the high treewidth of the graph.

Much previous work involves approximations of the likelihood (Section 4). In this paper, we suggest that parameter learning can instead be done using "marginalization-based" loss functions. These directly quantify the quality of the predictions of a given marginal inference algorithm. This has two major advantages. First, approximation errors in the inference algorithm are taken into account while learning. Second, this is robust to model mis-specification.

The contributions of this paper are, first, the general framework of marginalization-based fitting as implicit differentiation. Second, we show that the parameter gradient can be computed by "perturbation", that is, by re-running the approximate algorithm twice with the parameters perturbed slightly based on the current loss. Third, we introduce the strategy of "truncated fitting". Inference algorithms are based on optimization, where one iterates updates until some convergence threshold is reached. In truncated fitting, algorithms are derived to fit the marginals produced after a fixed number of updates, with no assumption of convergence. We show that this leads to significant speedups. We also derive a variant of this that can apply to likelihood-based learning. Finally, experimental results confirm that marginalization-based learning gives better results on difficult problems where inference approximations and model mis-specification are most significant.

2 SETUP

2.1 Markov Random Fields

Markov random fields are probability distributions that may be written as

\[ p(x) = \frac{1}{Z} \prod_c \psi(x_c) \prod_i \psi(x_i). \tag{1} \]

This is defined with reference to a graph, with one node for each random variable. The first product in Eq. 1 is over the set of cliques $c$ in the graph, while the second is over all individual variables. For example, the graph

[Figure: an example graph on the nodes $x_1, \dots, x_6$.]

corresponds to the distribution

\[ p(x) = \frac{1}{Z}\, \psi(x_1,x_2)\, \psi(x_2,x_3,x_5)\, \psi(x_3,x_4)\, \psi(x_5,x_6) \times \psi(x_1)\psi(x_2)\psi(x_3)\psi(x_4)\psi(x_5)\psi(x_6). \]

Each function $\psi(x_c)$ or $\psi(x_i)$ is positive, but otherwise arbitrary. The factor $Z$ ensures normalization.

The motivation for these types of models is the Hammersley-Clifford theorem [1], which gives specific conditions under which a distribution can be written as in Eq. 1. Those conditions are that, first, each random variable is conditionally independent of all others, given its immediate neighbors and, secondly, that each configuration $x$ has nonzero probability. Often, domain knowledge about conditional independence can be used to build a reasonable graph, and the factorized representation in an MRF reduces the curse of dimensionality encountered in modeling a high-dimensional distribution.

2.2 Conditional Random Fields

One is often interested in modeling the conditional probability of $x$, given observations $y$. For such problems, it is natural to define a Conditional Random Field [2]

\[ p(x|y) = \frac{1}{Z(y)} \prod_c \psi(x_c, y) \prod_i \psi(x_i, y). \]

Here, $\psi(x_c, y)$ indicates that the value for a particular configuration $x_c$ depends on the input $y$. In practice, the form of this dependence is application dependent.

2.3 Inference Problems

Suppose we have some distribution $p(x|y)$, we are given some input $y$, and we need to guess a single output vector $x^*$. What is the best guess?

The answer clearly depends on the meaning of "best". One framework for answering this question is the idea of a Bayes estimator [3]. One must specify some utility function $U(x,x')$, quantifying how "happy" one is to have guessed $x$ if the true output is $x'$. One then chooses $x^*$ to maximize the expected utility

\[ x^* = \arg\max_x \sum_{x'} p(x'|y)\, U(x,x'). \]

One natural utility function is an indicator function, giving one for the exact value $x'$, and zero otherwise. It is easy to show that for this utility, the optimal estimate is the popular Maximum a Posteriori (MAP) estimate.

Theorem. If $U(x,x') = I[x = x']$, then

\[ x^* = \arg\max_x p(x|y). \]

Little can be said in general about whether this utility function truly reflects user priorities. However, in high-dimensional applications, there are reasons for skepticism. First, the actual maximizing probability $p(x^*|y)$ in a MAP estimate might be extremely small, so much so that astronomical numbers of examples might be necessary before one could expect to exactly predict the true output. Second, this utility does not distinguish between a prediction that contains only a single error at some component $x_j$, and one that is entirely wrong.

An alternative utility function, popular for imaging problems, is based on the Hamming distance: it counts the number of components of the output vector that are correct. Maximizing this results in selecting the most likely value for each component independently.

Theorem. If $U(x,x') = \sum_i I[x_i = x'_i]$, then

\[ x^*_i = \arg\max_{x_i} p(x_i|y). \tag{2} \]

This appears to have been originally called Maximum Posterior Marginal (MPM) inference [4], though it has been reinvented under other names [5]. From a computational perspective, the main difficulty is not performing the trivial maximization in Eq. 2, but rather computing the marginals $p(x_i|y)$. The marginal-based loss functions introduced in Section 4.2 can be motivated by the idea that at test time, one will use an inference method similar to MPM where one is concerned only with the accuracy of the marginals.

The results of MAP and MPM inference will be similar if the distribution $p(x|y)$ is heavily "peaked" at a single configuration $x$. Roughly, the greater the entropy of $p(x|y)$, the more there is to be gained in integrating over all possible configurations, as MPM does. A few papers have experimentally compared MAP and MPM inference [6], [7].

2.4 Exponential Family

The exponential family is defined by

\[ p(x;\theta) = \exp\big( \theta \cdot f(x) - A(\theta) \big), \]

where $\theta$ is a vector of parameters, $f(x)$ is a vector of sufficient statistics, and the log-partition function

\[ A(\theta) = \log \sum_x \exp \theta \cdot f(x) \tag{3} \]

ensures normalization. Different sufficient statistics $f(x)$ define different distributions. The exponential family is well understood in statistics. Accordingly, it is useful to note that a Markov random field (Eq. 1) is a member of the exponential family, with sufficient statistics consisting of indicator functions for each possible configuration of each clique and each variable [8], namely,

\[ f(X) = \{ I[X_c = x_c] \mid \forall c, x_c \} \cup \{ I[X_i = x_i] \mid \forall i, x_i \}. \]

It is useful to introduce the notation $\theta(x_c)$ to refer to the component of $\theta$ corresponding to the indicator function $I[X_c = x_c]$, and similarly for $\theta(x_i)$. Then, the MRF in Eq. 1 would have $\psi(x_c) = e^{\theta(x_c)}$ and $\psi(x_i) = e^{\theta(x_i)}$. Many operations on graphical models can be more elegantly represented using this exponential family representation.

A standard problem in the exponential family is to compute the mean value of $f$,

\[ \mu(\theta) = \sum_x p(x;\theta) f(x), \]

called the "mean parameters". It is easy to show these are equal to the gradient of the log-partition function:

\[ \frac{dA}{d\theta} = \mu(\theta). \tag{4} \]

For an exponential family corresponding to an MRF, computing $\mu$ is equivalent to computing all the marginal probabilities. To see this, note that, using a similar notation for indexing $\mu$ as for $\theta$ above,

\[ \mu(x_c;\theta) = \sum_X p(X;\theta)\, I[X_c = x_c] = p(x_c;\theta). \]

Conditional distributions can be represented by thinking of the parameter vector $\theta(y;\gamma)$ as being a function of the input $y$, where $\gamma$ are now the free parameters rather than $\theta$. (Again, the nature of the dependence of $\theta$ on $y$ and $\gamma$ will vary by application.) Then, we have that

\[ p(x|y;\gamma) = \exp\big( \theta(y;\gamma) \cdot f(x) - A(\theta(y;\gamma)) \big), \tag{5} \]

sometimes called a curved conditional exponential family.

2.5 Learning

The focus of this paper is learning of model parameters from data. (Automatically determining graph structure remains an active research area, but is not considered here.) Specifically, we take the goal of learning to be to minimize the empirical risk

\[ R(\theta) = \sum_{\hat{x}} L(\theta, \hat{x}), \tag{6} \]

where the summation is over all examples $\hat{x}$ in the dataset, and the loss function $L(\theta,\hat{x})$ quantifies how well the distribution defined by the parameter vector $\theta$ matches the example $\hat{x}$. Several loss functions are considered in Section 4.

We assume that the empirical risk will be fit by some gradient-based optimization. Hence, the main technical issues in learning are which loss function to use and how to compute the gradient $\frac{dL}{d\theta}$.

In practice, we will usually be interested in fitting conditional distributions. Using the notation from Eq. 5, we can write this as

\[ R(\gamma) = \sum_{(\hat{y},\hat{x})} L\big( \theta(\hat{y},\gamma), \hat{x} \big). \]

Note that if one has recovered $\frac{dL}{d\theta}$, $\frac{dL}{d\gamma}$ is immediate from the vector chain rule as

\[ \frac{dL}{d\gamma} = \frac{d\theta^T}{d\gamma} \frac{dL}{d\theta}. \tag{7} \]

Thus, the main technical problems involved in fitting a conditional distribution are similar to those for a generative distribution: one finds $\theta = \theta(\hat{y},\gamma)$, computes $L$ and $\frac{dL}{d\theta}$ on example $\hat{x}$ exactly as in the generative case, and finally recovers $\frac{dL}{d\gamma}$ from Eq. 7. So, for simplicity, $y$ and $\gamma$ will largely be ignored in the theoretical developments below.

3 VARIATIONAL INFERENCE

This section reviews approximate methods for computing marginals, with notation based on Wainwright and Jordan [8]. For readability, all proofs in this section are postponed to Appendix A.

The relationship between the marginals and the log-partition function in Eq. 4 is key to defining approximate marginalization procedures. In Section 3.1, the exact variational principle shows that the (intractable) problem of computing the log-partition function can be converted to a (still intractable) optimization problem. To derive a tractable marginalization algorithm one approximates this optimization, yielding some approximate log-partition function $\tilde{A}(\theta)$. The approximate marginals are then taken as the exact gradient of $\tilde{A}$.

We define the reverse mapping $\theta(\mu)$ to return some parameter vector that yields the marginals $\mu$. While this will in general not be unique [8, Sec. 3.5.2], any two vectors that produce the same marginals $\mu$ will also yield the same distribution, and so $p(x;\theta(\mu))$ is unambiguous.

3.1 Exact Variational Principle

Theorem (Exact variational principle). The log-partition function can also be represented as

\[ A(\theta) = \max_{\mu \in \mathcal{M}} \theta \cdot \mu + H(\mu), \tag{8} \]

where

\[ \mathcal{M} = \{ \mu' : \exists \theta,\ \mu' = \mu(\theta) \} \]

is the marginal polytope, and

\[ H(\mu) = -\sum_x p(x;\theta(\mu)) \log p(x;\theta(\mu)) \]

is the entropy.

In treelike graphs, this optimization can be solved efficiently. In general graphs, however, it is intractable in two ways. First, the marginal polytope $\mathcal{M}$ becomes difficult to characterize. Second, the entropy is intractable to compute.

Applying Danskin's theorem to Eq. 8 yields that

\[ \mu(\theta) = \frac{dA}{d\theta} = \arg\max_{\mu \in \mathcal{M}} \theta \cdot \mu + H(\mu). \tag{9} \]

Thus, the partition function (Eq. 8) and marginals (Eq. 9) can both be obtained from solving the same optimization problem. This close relationship between the log-partition function and marginals is heavily used in the derivation of approximate marginalization algorithms. To compute approximate marginals, first, derive an approximate version of the optimization in Eq. 8. Next, take the exact gradient of this approximate partition function. This strategy is used in both of the approximate marginalization procedures considered here: mean field and tree-reweighted belief propagation.

3.2 Mean Field

The idea of mean field is to approximate the exact variational principle by replacing $\mathcal{M}$ with some tractable subset $\mathcal{F} \subset \mathcal{M}$, such that $\mathcal{F}$ is easy to characterize, and for any vector $\mu \in \mathcal{F}$ we can exactly compute the entropy. To create such a set $\mathcal{F}$, instead of considering the set of mean vectors obtainable from any parameter vector (which characterizes $\mathcal{M}$), consider a subset of tractable parameter vectors. The simplest way to achieve this is to restrict consideration to parameter vectors $\theta$ with $\theta(x_c) = 0$ for all factors $c$:

\[ \mathcal{F} = \{ \mu' : \exists \theta,\ \mu' = \mu(\theta),\ \forall c,\ \theta(x_c) = 0 \}. \]

It is not hard to see that this corresponds to the set of fully-factorized distributions. Note also that this is (in non-treelike graphs) a non-convex set, since it has the same convex hull as $\mathcal{M}$, but is a proper subset. So, the mean field partition function approximation is based on the optimization

\[ \tilde{A}(\theta) = \max_{\mu \in \mathcal{F}} \theta \cdot \mu + H(\mu), \tag{10} \]

with approximate marginals corresponding to the maximizing vector $\mu$, i.e.

\[ \tilde{\mu}(\theta) = \arg\max_{\mu \in \mathcal{F}} \theta \cdot \mu + H(\mu). \tag{11} \]

Since this is maximizing the same objective as the exact variational principle, but under a more restricted constraint set, clearly $\tilde{A}(\theta) \le A(\theta)$.

Here, since the marginals are coming from a fully-factorized distribution, the exact entropy is available as

\[ H(\mu) = -\sum_i \sum_{x_i} \mu(x_i) \log \mu(x_i). \tag{12} \]

The strategy we use to perform the maximization in Eq. 10 is block-coordinate ascent. Namely, we pick a coordinate $j$, then set $\mu(x_j)$ to maximize the objective, leaving $\mu(x_i)$ fixed for all $i \ne j$. The next theorem formalizes this.

Theorem (Mean Field Updates). A local maximum of Eq. 10 can be reached by iterating the updates

\[ \mu(x_j) \leftarrow \frac{1}{Z} \exp\Big( \theta(x_j) + \sum_{c : j \in c} \sum_{x_{c \setminus j}} \theta(x_c) \prod_{i \in c \setminus j} \mu(x_i) \Big), \]

where $Z$ is a normalizing factor ensuring that $\sum_{x_j} \mu(x_j) = 1$.

3.3 Tree-Reweighted Belief Propagation

Whereas mean field replaced the marginal polytope with a subset, tree-reweighted belief propagation (TRW) replaces it with a superset, $\mathcal{L} \supset \mathcal{M}$. This clearly can only increase the value of the approximate log-partition function. However, a further approximation is needed, as the entropy remains intractable to compute for an arbitrary mean vector $\mu$. (It is not even defined for $\mu \notin \mathcal{M}$.) Thus, TRW further approximates the entropy with a tractable upper bound. Taken together, these two approximations yield a tractable upper bound on the log-partition function.

Thus, TRW is based on the optimization problem

\[ \tilde{A}(\theta) = \max_{\mu \in \mathcal{L}} \theta \cdot \mu + \tilde{H}(\mu). \tag{13} \]

Again, the approximate marginals are simply the maximizing vector $\mu$, i.e.,

\[ \tilde{\mu}(\theta) = \arg\max_{\mu \in \mathcal{L}} \theta \cdot \mu + \tilde{H}(\mu). \tag{14} \]

The relaxation of the marginal polytope used in TRW is the local polytope,

\[ \mathcal{L} = \Big\{ \mu : \sum_{x_{c \setminus i}} \mu(x_c) = \mu(x_i),\ \sum_{x_i} \mu(x_i) = 1 \Big\}. \tag{15} \]

Since any valid marginal vector must obey these constraints, clearly $\mathcal{M} \subset \mathcal{L}$. However, $\mathcal{L}$ in general also contains unrealizable vectors (though on trees $\mathcal{L} = \mathcal{M}$). Thus, the marginal vector returned by TRW may, in general, be inconsistent in the sense that no joint distribution yields those marginals.

The entropy approximation used by TRW is

\[ \tilde{H}(\mu) = \sum_i H(\mu_i) - \sum_c \rho_c I(\mu_c), \tag{16} \]

where $H(\mu_i) = -\sum_{x_i} \mu(x_i) \log \mu(x_i)$ is the univariate entropy corresponding to variable $i$, and

\[ I(\mu_c) = \sum_{x_c} \mu(x_c) \log \frac{\mu(x_c)}{\prod_{i \in c} \mu(x_i)} \tag{17} \]

is the mutual information corresponding to the variables in the factor $c$. The motivation for this approximation is that if the constants $\rho_c$ are selected appropriately, this gives an upper bound on the true entropy.

Theorem (TRW Entropy Bound). Let $\Pr(G)$ be a distribution over tree-structured graphs, and define $\rho_c = \Pr(c \in G)$. Then, with $\tilde{H}$ as defined in Eq. 16,

\[ \tilde{H}(\mu) \ge H(\mu). \]

Thus, TRW is maximizing an upper bound on the exact variational principle, under an expanded constraint set. Since both of these changes can only increase the maximum value, we have that $\tilde{A}(\theta) \ge A(\theta)$.

Now, we consider how to actually compute the approximate log-partition function and associated marginals. Consider the message-passing updates

\[ m_c(x_i) \propto \sum_{x_{c \setminus i}} e^{\frac{1}{\rho_c}\theta(x_c)} \prod_{j \in c \setminus i} e^{\theta(x_j)} \frac{\prod_{d : j \in d} m_d(x_j)^{\rho_d}}{m_c(x_j)}, \tag{18} \]

where "$\propto$" is used as an assignment operator to mean assigning after normalization.

Theorem (TRW Updates). Let $\rho_c$ be as in the previous theorem. Then, if the updates in Eq. 18 reach a fixed point, the marginals defined by

\[ \mu(x_c) \propto e^{\frac{1}{\rho_c}\theta(x_c)} \prod_{i \in c} e^{\theta(x_i)} \frac{\prod_{d : i \in d} m_d(x_i)^{\rho_d}}{m_c(x_i)}, \qquad \mu(x_i) \propto e^{\theta(x_i)} \prod_{d : i \in d} m_d(x_i)^{\rho_d} \]

constitute the global optimum of Eq. 13.

So, if the updates happen to converge, we have the solution. Meltzer et al. show [9] that on certain graphs made up of monotonic chains, an appropriate ordering of messages does assure convergence. (The proof is essentially that under these circumstances, message passing is equivalent to coordinate ascent in the dual.)

TRW simplifies into loopy belief propagation by choosing $\rho_c = 1$ everywhere, though the bounding property is lost.

4 LOSS FUNCTIONS

For space, only a representative sample of prior work can be cited. A recent review [10] is more thorough.

Though, technically, a "loss" should be minimized, we continue to use this terminology for the likelihood and its approximations, where one wishes to maximize.

For simplicity, the discussion below is for the generative setting. Using the same loss functions for training a conditional model is simple (Section 2.5).

4.1 The Likelihood and Approximations

The classic loss function would be the likelihood, with

\[ L(\theta,x) = \log p(x;\theta) = \theta \cdot f(x) - A(\theta). \tag{19} \]

This has the gradient

\[ \frac{dL}{d\theta} = f(x) - \mu(\theta). \tag{20} \]

One argument for the likelihood is that it is efficient; given a correct model, as data increases it converges to true parameters at an asymptotically optimal rate [11].

Some previous work uses tree-structured graphs where marginals may be computed exactly [12]. Of course, in high-treewidth graphs, the likelihood and its gradient will be intractable to compute exactly, due to the presence of the log-partition function $A(\theta)$ and marginals $\mu(\theta)$. This has motivated a variety of approximations. The first is to approximate the marginals $\mu$ using Markov chain Monte Carlo [13], [14]. This can lead to high computational expense (particularly in the conditional case, where different chains must be run for each input). Contrastive Divergence [15] further approximates these samples by running the Markov chain for only a few steps, but started at the data points [16]. If the Markov chain is run long enough, these approaches can give an arbitrarily good approximation. However, Markov chain parameters may need to be adjusted to the particular problem, and these approaches are generally slower than those discussed below.

4.1.1 Surrogate Likelihood

A seemingly heuristic approach would be to replace the marginals in Eq. 20 with those from an approximate inference method. This approximation can be quite principled if one thinks instead of approximating the log-partition function in the likelihood itself (Eq. 19). Then, the corresponding approximate marginals will emerge as the exact gradient of this surrogate loss. This "surrogate likelihood" [17] approximation appears to be the most widely used loss in imaging problems, with marginals approximated by either mean field [18], [19], TRW [20] or LBP [21], [22], [23], [24], [25]. However, the terminology of "surrogate likelihood" is not widespread and in most cases, only the gradient is computed, meaning the optimization cannot use line searches.

If one uses a log-partition approximation that provides a bound on the true log-partition function, the surrogate likelihood will then bound the true likelihood. Specifically, mean field based surrogate likelihood is an upper bound on the true likelihood, while TRW-based surrogate likelihood is a lower bound.

4.1.2 Expectation Maximization

In many applications, only a subset of variables may be observed. Suppose that we want to model $x = (z,h)$ where $z$ is observed, but $h$ is hidden. A natural loss function here is the expectation maximization (EM) loss

\[ L(\theta,z) = \log p(z;\theta) = \log \sum_h p(z,h;\theta). \]

It is easy to show that this is equivalent to

\[ L(\theta,z) = A(\theta,z) - A(\theta), \tag{21} \]

where $A(\theta,z) = \log \sum_h \exp \theta \cdot f(z,h)$ is the log-partition function with $z$ "clamped" to the observed values. If all variables are observed, $A(\theta,z)$ reduces to $\theta \cdot f(z)$.

If one substitutes a variational approximation for $A(\theta,z)$, a "variational EM" algorithm [8, Sec. 6.2.2] can be recovered that alternates between computing approximate marginals and parameter updates. Here, because of the close relationship to the surrogate likelihood, we designate "surrogate EM" for the case where $A(\theta,z)$ and $A(\theta)$ may both be approximated and the learning is done with a gradient-based method. To obtain a bound on the true EM loss, care is required. For example, lower-bounding $A(\theta,z)$ using mean field, and upper-bounding $A(\theta)$ using TRW means a lower bound on the true EM loss. However, using the same approximation for both $A(\theta)$ and $A(\theta,z)$ appears to work well in practice [26].

4.1.3 Saddle-Point Approximation

A third approximation of the likelihood is to search for a "saddle-point". Here, one approximates the gradient in Eq. 20 by running a (presumably approximate) MAP inference algorithm, and then imagining that the marginals put unit probability at the approximate MAP solution, and zero elsewhere [27], [28], [21]. This is a heuristic method, but it can be expected to work well when the estimated MAP solution is close to the true MAP and the conditional distribution $p(x|y)$ is strongly "peaked".

4.1.4 Pseudolikelihood

Finally, there are two classes of likelihood approximations that do not require inference. The first is the classic pseudolikelihood [29], where one uses

\[ L(\theta,x) = \sum_i \log p(x_i | x_{-i};\theta). \]

This can be computed efficiently, even in high-treewidth graphs, since conditional probabilities are easy to compute. Besag [29] showed that, under certain conditions, this will converge to the true parameter vector as the amount of data becomes infinite. The pseudolikelihood has been used in many applications [30], [31].
6 Instead of the probability of individual variables given There is reason to think the likelihood is somewhat all others, one can take the probability of patches of robust to model mis-specification. In the infinite data variables given all others, sometimes called the “patch” limit, it finds the “closest” solution in the sense of KL- pseudolikelihood [32]. This interpolates to the exact divergence since, if q is the true distribution, then likelihood as the patches become larger, though some KL(q||p) = const.−Elogp(x;θ). type of inference is generally required. q 4.1.5 PiecewiseLikelihood 4.2.1 UnivariateLogisticLoss More recently, Sutton and McCallum [33] suggested the The univariate logistic loss [39] is defined by piecewiselikelihood.Theideaistoapproximatethelog- L(θ,x)=− logµ(x ;θ), partition function as a sum of log-partition functions of i the different“pieces“ of the graph. Thereis flexibility in Xi determining which pieces to use. In this paper, we will where we use the notation µ to indicate that the loss is use pieces consisting of each clique and each variable, implicitly defined with respect to the marginal predic- which worked better in practice than some alternatives. tions of some (possibly approximate) algorithm, rather Then, one has the surrogate partition function than the true marginals. This measures the mean accu- racy of all univariate marginals, rather than the joint A˜(θ) = A (θ)+ A (θ), c i distribution. This loss can be seen as empirical risk Xc Xi minimization of the KL-divergence between the true Ac(θ) = log eθ(xc), Ai(θ)=log eθ(xi). marginals and the predicted ones, since Xxc Xxi q(x ) i ItisnottoohardtoshowthatA(θ)≤A˜(θ).Inpractice, KL(qi||µi) = q(xi)log µ(x ;θ) i it is sometimes bestto make some heuristic adjustments Xi Xi Xxi to the parameters after learning to improve test-time = const.−E logµ(xi;θ). q performance [34], [35]. i X Ifdefinedonexactmarginals,thisisatypeofcomposite 4.2 Marginal-basedLossFunctions likelihood [40]. 
Given the discussion in Section 4.1, one might conclude 4.2.2 SmoothedUnivariateClassificationError that the likelihood, while difficult to optimize, is an ideal loss function since, given a well-specified model, Perhaps the most natural loss in the conditional setting itwill convergetothe true parametersatasymptotically would be the univariate classification error, efficient rates. However, this conclusion is complicated by two issues. First, of course, the maximum likelihood L(θ,x)= S max µ(xi;θ)−µ(xi;θ) , solution is computationally intractable, motivating the Xi (cid:0)x′i6=xi (cid:1) approximations above. whereS(·)isthestepfunction.Thisexactlymeasuresthe A second issue is that of model mis-specification. For number of components of x that would be incorrectly many types of complex phenomena, we will wish to fit predicted if using MPM inference. Of course, this loss amodelthatisapproximateinnature.Thiscouldbetrue is neither differentiable nor continuous, which makes it because the conditional independencies asserted by the impractical to optimize using gradient-based methods. graph do not exactly hold, or because the parametriza- Instead Gross et al. [5] suggest approximating with a tion of factors is too simplistic. These approximations sigmoid function S(t) = (1 + exp(−αt))−1, where α mightbemadeoutofignorance,duetoalackofknowl- controls approximation quality. edge about the domain being studied, or deliberately There is evidence [36], [5] that the smoothed classifi- becausethe true modelmighthavetoo manydegreesof cation loss can yield parameters with lower univariate freedom to be fit with available data. classification error under MPM inference. However, our Inthecaseofanapproximatemodel,no“true”param- experienceisthatitisalsomorepronetogettingstuckin eters exist. The idea of marginal-based loss functions is local minima, making experiments difficult to interpret. to instead consider how the model will be used. 
If one Thus, it is not included in the experiments below. Our will compute marginals at test-time – perhaps for MPM experience with the univariate quadratic loss [41] is inference (Section 2.3) – it makes sense to maximize the similar. accuracy of these predictions. Further, if one will use an approximate inference algorithm, it makes sense to 4.2.3 CliqueLosses optimize the accuracyof the approximatemarginals. This Any of the above univariate losses can be instead taken essentially fits into the paradigm of empirical risk min- based on cliques. For example, the clique logistic loss is imization [36], [37]. The idea of training a probabilistic model using an alternative loss to the likelihood goes L(θ,x)=− logµ(x ;θ), c back at least to Bahl et al. in the late 1980s [38]. c X 7 0.3 shift of 0 shift of 3 shift of 10 by Markov chain Monte Carlo, with shifts introduced mean test error00..12 lciklieqluiheo loodgistic aftrofatreimnresdgarmaudpsiilneinnggta.(STgehrcea.dl4iike.1ne)lt,ihwoobhotiadlienwethdasevfilioatgubissaticinckgplotrhosespeacsgloawsteieodrne- univariate logistic 1000 101 102 103 100 101 102 103 10 0 101 102 103 (Sec. 7). Fig. 1 shows the mean test error (estimated on # training data # training data # training data 1000 examples), while Fig. 2 shows example marginals. Figure 1: Mean test error of various loss functions We see that the performance of all methods deteriorates trained with exact inference. In the case of a well- withmis-specification,butthemarginal-basedlossfunc- specified model (shift of zero), the likelihood performs tions are more resistant to these effects. essentially identically to the marginal-based loss func- tions. However, when mis-specification is introduced, 4.4 MAP-BasedTraining quite different estimates result. Another class of methods explicitly optimize the perfor- mance of MAP inference [42], [43], [44], [45], [25]. 
This paper focuses on applications that use marginal infer- which may be seen as empirical risk minimization of ence, and that may need to accommodate hidden vari- the mean KL-divergence of the true clique marginals to ables, and so concentrates on likelihood and marginal- the predicted ones. An advantage of this with an exact based losses. modelisconsistency.Simpleexamplesshowcaseswhere a model predicts perfect univariate marginals, despite 5 IMPLICIT FITTING the joint distribution being very inaccurate. However, if allcliquemarginalsarecorrect,thejointmustbecorrect, Wenow turntothe issue ofhowtotrainhigh-treewidth by the standard moment matching conditions for the graphical models to optimize the performance of a exponential family [8]. marginal-based loss function, based on some approxi- mate inference algorithm. Now, computing the value of 4.2.4 Hiddenvariables the loss for any of the marginal-based loss functions is not hard. One can simply run the inference algorithm Marginal-basedloss functions canaccommodate hidden and plug the resulting marginal into the loss. However, variables by simply taking the sum in the loss over the we also require the gradient dL. observed variables only. A similar approach can be used dθ Ourfirstresultisthatthelossgradientcanbeobtained with the pseudolikelihood or piecewise likelihood. by solving a sparse linear system. Here, it is useful to introduce notation to distinguish the loss L, defined in 4.3 ComparisonwithExactInference terms of the parameters θ from the loss Q, defined di- rectlyintermsofthemarginalsµ.(Notethatthoughthe To compare the effects of different loss functions in the notation suggests the application to marginal inference, presenceofmodelmis-specification,thissectioncontains this is a generic result.) a simple example where the graphical model takes the following “chain” structure: Theorem. Suppose that x x x x x µ(θ):=argmax θ·µ+H(µ). (22) 1 2 3 4 n µ:Bµ=d ... Define L(θ,x)=Q(µ(θ),x). 
Then, letting D = d2H , y y y y y dµdµT 1 2 3 4 n dL dQ = D−1BT(BD−1BT)−1BD−1−D−1 . dθ dµ Here, exactinferenceis possible, so comparison is not (cid:0) (cid:1) complicated by approximate inference. A proof may be found in Appendix B. This theorem All variables are binary. Parameters are generated by states that, essentially, once one has computed the pre- takingθ(x )randomlyfromtheinterval[−1,+1]foralli dicted marginals, the gradient of the loss with respect i and x . Interaction parameters are taken as θ(x ,x )=t to marginals dQ can be transformed into the gradient i i j dµ when xi = xj, and θ(xi,xj) = −t when xi 6= xj, where of the loss with respect to parameters ddLθ through the t is randomly chosen from the interval [−1,+1] for all solution of a sparse linear system. (i,j). Interactionsθ(yi,yj)andθ(xi,yi)arechoseninthe The optimization in Eq. 22 takes place under linear same way. constraints, which encompasses the local polytope used To systematically study the effects of differing inTRWmessage-passing(Eq.15).Thistheoremdoesnot “amounts” of mis-specification, after generating data, apply to mean field, as F is not a linear constraint set we apply various circular shifts to x. Thus, the data whenviewedasafunctionofbothcliqueandunivariate no longer corresponds exactly the the structure of the marginals. graphical model being fit. Inanycase,themethodsdevelopedbelowaresimpler Thirty-two different random distributions were cre- to use, as they do not require explicitly forming the ated.Foreach,variousquantitiesofdataweregenerated constraint matrix B or solving the linear system. 8 ) true marginals x 1 1| = eplacements 0.5 yi ( p 0 0 20 40 60 80 100 120 140 160 180 200 i ) shift of 3 x 1 1| = 0.5 yi ( p 0 0 20 40 60 80 100 120 140 160 180 200 i ) shift of 10 x 1 1| = 0.5 yi ( p 0 0 20 40 60 80 100 120 140 160 180 200 i Figure 2: Exact and predicted marginals for an example input. Predicted marginals are trained using 1000 data. 
With low shifts, all loss functions lead to accurate predicted marginals. However, the univariate and clique logistic losses are more resistant to the effects of model mis-specification. Legends as in Fig. 1.

6 PERTURBATION

This section observes that variational methods have a special structure that allows derivatives to be calculated without explicitly forming or inverting a linear system. We have, by the vector chain rule, that

$$\frac{dL}{d\theta} = \frac{d\mu^T}{d\theta} \frac{dQ}{d\mu}. \tag{23}$$

A classic trick in scientific computing is to efficiently compute Jacobian-vector products by finite differences. The basic result is that, for any vector $v$,

$$\frac{d\mu}{d\theta^T} v = \lim_{r \to 0} \frac{1}{r} \Big( \mu(\theta + r v) - \mu(\theta) \Big),$$

which is essentially just the definition of the derivative of $\mu$ in the direction of $v$. Now, this does not immediately seem helpful, since Eq. 23 requires $\frac{d\mu^T}{d\theta}$, not $\frac{d\mu}{d\theta^T}$. However, with variational methods, these are symmetric. The simplest way to see this is to note that

$$\frac{d\mu}{d\theta^T} = \frac{d}{d\theta^T} \left( \frac{dA}{d\theta} \right) = \frac{d^2 A}{d\theta\, d\theta^T}.$$

Domke [46] lists conditions for various classes of entropies that guarantee that $A$ will be differentiable.

Combining the above three equations, the loss gradient is available as the limit

$$\frac{dL}{d\theta} = \lim_{r \to 0} \frac{1}{r} \Big( \mu\big(\theta + r \tfrac{dQ}{d\mu}\big) - \mu(\theta) \Big). \tag{24}$$

In practice, of course, the gradient is approximated using some finite $r$. The simplest approximation, one-sided differences, simply takes a single value of $r$ in Eq. 24, rather than a limit. More accurate results, at the cost of more calls to inference, are given using two-sided differences, with

$$\frac{dL}{d\theta} \approx \frac{1}{2r} \Big( \mu\big(\theta + r \tfrac{dQ}{d\mu}\big) - \mu\big(\theta - r \tfrac{dQ}{d\mu}\big) \Big),$$

which is accurate to order $o(r^2)$. Still more accurate results are obtained with "four-sided" differences, with

$$\frac{dL}{d\theta} \approx \frac{1}{12r} \Big( -\mu\big(\theta + 2r \tfrac{dQ}{d\mu}\big) + 8\mu\big(\theta + r \tfrac{dQ}{d\mu}\big) - 8\mu\big(\theta - r \tfrac{dQ}{d\mu}\big) + \mu\big(\theta - 2r \tfrac{dQ}{d\mu}\big) \Big),$$

which is accurate to order $o(r^4)$ [47].

Alg. 1 shows more explicitly how the loss gradient could be calculated, using two-sided differences.

Algorithm 1 Calculating $\frac{dL}{d\theta}$ by perturbation (two-sided).
1) Do inference. $\mu^* \leftarrow \arg\max_{\mu \in M}\ \theta \cdot \mu + H(\mu)$
2) At $\mu^*$, calculate the gradient $\frac{dQ}{d\mu}$.
3) Calculate a perturbation size $r$.
4) Do inference on perturbed parameters.
   a) $\mu^+ \leftarrow \arg\max_{\mu \in M}\ \big(\theta + r \frac{dQ}{d\mu}\big) \cdot \mu + H(\mu)$
   b) $\mu^- \leftarrow \arg\max_{\mu \in M}\ \big(\theta - r \frac{dQ}{d\mu}\big) \cdot \mu + H(\mu)$
5) Recover full derivative as $\frac{dL}{d\theta} \leftarrow \frac{1}{2r} (\mu^+ - \mu^-)$.

The issue remains of how to calculate the step size $r$. Each of the approximations above becomes exact as $r \to 0$. However, as $r$ becomes very small, numerical error eventually dominates. To investigate this issue experimentally, we generated random models on a $10 \times 10$ binary grid, with each parameter $\theta(x_i)$ randomly chosen from a standard normal, while each interaction parameter $\theta(x_i, x_j)$ was chosen randomly from a normal with a standard deviation of $s$. In each case, a random value $x$ was generated, and the "true" loss gradient was estimated by standard (inefficient) 2-sided finite differences, with inference re-run after each component of $\theta$ is perturbed independently. To this, we compare one, two, and four-sided perturbations. In all cases, the step size is, following Andrei [48], taken to be $r = m\, \epsilon^{1/3} \big(1 + \|\theta\|_\infty\big) / \big\|\frac{dQ}{d\mu}\big\|_\infty$, where $\epsilon$ is machine epsilon, and $m$ is a multiplier that we will vary. Note that the optimal power of $\epsilon$ will depend on the finite difference scheme; $\frac{1}{3}$ is optimal for two-sided differences [49, Sec. 8.1]. All calculations take place in double-precision with inference run until marginals changed by a threshold of less than $10^{-15}$. Fig. 3 shows that using many-sided differences leads to more accuracy, at the cost of needing to run inference more times to estimate a single loss gradient. In the following experiments, we chose two-sided differences with a multiplier of 1 as a reasonable tradeoff between accuracy, simplicity, and computational expense.

[Figure 3 panels: gradient error (vertical axis, log scale) against perturbation multiplier (horizontal axis, log scale) for 1-sided, 2-sided, and 4-sided differences.]

Figure 3: An evaluation of perturbation multipliers $m$. Top: TRW. Bottom: Mean field. Two effects are in play here: First, for too small a perturbation, numerical errors dominate. Meanwhile, for too large a perturbation, approximation errors dominate. We see that using 2- or 4-sided differences reduces approximation error, leading to better results with larger perturbations.

Welling and Teh used sensitivity of approximate beliefs to parameters to approximate joint probabilities of non-neighboring variables [50].

7 TRUNCATED FITTING

The previous methods for computing loss gradients are derived under the assumption that the inference optimization is solved exactly. In an implementation, of course, some convergence threshold must be used.

Different convergence thresholds can be used in the learning stage and at test time. In practice, we have observed that too loose a threshold in the learning stage can lead to a bad estimated risk gradient, and learning terminating with a bad search direction. Meanwhile, a loose threshold can often be used at test time with few consequences. Usually, a difference of $10^{-3}$ in estimated marginals has little practical impact, but this can still be enough to prevent learning from succeeding [51].

It seems odd that the learning algorithm would spend the majority of computational effort exploring tight convergence levels that are irrelevant to the practical performance of the model. Here, we define the learning objective in terms of the approximate marginals obtained after a fixed number of iterations. To understand this, one may think of the inference process not as an optimization, but rather as a large, nonlinear function. This clearly leads to a well-defined objective function. Inputting parameters, applying the iterations of either TRW or mean field, computing predicted marginals, and finally a loss are all differentiable operations. Thus, the loss gradient is efficiently computable, at least in principle, by reverse-mode automatic differentiation (autodiff), an approach explored by Stoyanov et al. [36], [52]. In preliminary work, we experimented with autodiff tools, but found these to be unsatisfactory for our applications for two reasons. Firstly, these tools impose a computational penalty over manually derived gradients. Secondly, autodiff stores all intermediate calculations, leading to large memory requirements. The methods derived below use less memory, both in terms of constant factors and big-O complexity. Nevertheless, some of these problems are issues with current implementations of reverse-mode autodiff, avoidable in theory.

Both mean field and TRW involve steps where we first take a product of a set of terms, and then normalize. We define a "backnorm" operator, which is useful in taking derivatives over such operations, by

$$\mathrm{backnorm}(g, c) = c \odot (g - g \cdot c).$$

This will be used in the algorithms here. More discussion on this point can be found in Appendix C.

7.1 Back Mean Field

The first backpropagating inference algorithm, back mean field, is shown as Alg. 2. The idea is as follows: Suppose we start with uniform marginals, run $N$ iterations of mean field, and then, regardless of whether mean field has converged or not, take predicted marginals and plug them into one of the marginal-based loss functions. Since each step in this process is differentiable, this specifies the loss as a differentiable function of model parameters. We want the exact gradient of this function.

Theorem. After execution of back mean field,

$$\overleftarrow{\theta}(x_i) = \frac{dL}{d\theta(x_i)} \quad \text{and} \quad \overleftarrow{\theta}(x_c) = \frac{dL}{d\theta(x_c)}.$$

A proof sketch is in Appendix C. Roughly speaking, the proof takes the form of a mechanical differentiation of each step of the inference process.

Algorithm 2 Back Mean Field
1) Initialize $\mu$ uniformly.
2) Repeat $N$ times for all $j$:
   a) Push the marginals $\mu_j$ onto a stack.
   b) $\mu(x_j) \propto \exp\Big( \theta(x_j) + \sum_{c : j \in c} \sum_{x_{c \setminus j}} \theta(x_c) \prod_{i \in c \setminus j} \mu(x_i) \Big)$
3) Compute $L$, $\overleftarrow{\mu}(x_j) = \frac{dL}{d\mu(x_j)}$ and $\overleftarrow{\mu}(x_c) = \frac{dL}{d\mu(x_c)}$.
4) Initialize $\overleftarrow{\theta}(x_i) \leftarrow 0$, $\overleftarrow{\theta}(x_c) \leftarrow 0$.
5) Repeat $N$ times for all $j$ (in reverse order):
   a) $\overleftarrow{\nu}_j \leftarrow \mathrm{backnorm}(\overleftarrow{\mu}_j, \mu_j)$
   b) $\overleftarrow{\theta}(x_j) \leftarrow \overleftarrow{\theta}(x_j) + \overleftarrow{\nu}(x_j)$
   c) $\overleftarrow{\theta}(x_c) \leftarrow \overleftarrow{\theta}(x_c) + \overleftarrow{\nu}(x_j) \prod_{i \in c \setminus j} \mu(x_i) \quad \forall c : j \in c$
   d) $\overleftarrow{\mu}(x_i) \leftarrow \overleftarrow{\mu}(x_i) + \sum_{x_{c \setminus i}} \overleftarrow{\nu}(x_j)\, \theta(x_c) \prod_{k \in c \setminus \{i,j\}} \mu(x_k) \quad \forall c : j \in c,\ \forall i \in c \setminus j$
   e) Pull marginals $\mu_j$ from the stack.
   f) $\overleftarrow{\mu}_j(x_j) \leftarrow 0$

Note that, as written, back mean field only produces univariate marginals, and so cannot cope with loss functions making use of clique marginals. However, with mean field, the clique marginals are simply the product of univariate marginals: $\mu(x_c) = \prod_{i \in c} \mu(x_i)$. Hence, any loss defined on clique marginals can equivalently be defined on univariate marginals.

7.2 Back TRW

Next, we consider truncated fitting with TRW inference. As above, we will assume that some fixed number $N$ of inference iterations have been run, and we want to define and differentiate a loss defined on the current predicted marginals. Alg. 3 shows the method.

Theorem. After execution of back TRW,

$$\overleftarrow{\theta}(x_i) = \frac{dL}{d\theta(x_i)} \quad \text{and} \quad \overleftarrow{\theta}(x_c) = \frac{dL}{d\theta(x_c)}.$$

Again, a proof sketch is in Appendix C.

Algorithm 3 Back TRW.
1) Initialize $m$ uniformly.
2) Repeat $N$ times for all pairs $(c,i)$, with $i \in c$:
   a) Push the messages $m_c(x_i)$ onto a stack.
   b) $m_c(x_i) \propto \sum_{x_{c \setminus i}} e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{j \in c \setminus i} e^{\theta(x_j)} \frac{\prod_{d : j \in d} m_d(x_j)^{\rho_d}}{m_c(x_j)}$
3) $\mu(x_c) \propto e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{i \in c} e^{\theta(x_i)} \frac{\prod_{d : i \in d} m_d(x_i)^{\rho_d}}{m_c(x_i)} \quad \forall c$
4) $\mu(x_i) \propto e^{\theta(x_i)} \prod_{d : i \in d} m_d(x_i)^{\rho_d} \quad \forall i$
5) Compute $L$, $\overleftarrow{\mu}(x_i) = \frac{dL}{d\mu(x_i)}$ and $\overleftarrow{\mu}(x_c) = \frac{dL}{d\mu(x_c)}$.
6) For all $c$,
   a) $\overleftarrow{\nu}(x_c) \leftarrow \mathrm{backnorm}(\overleftarrow{\mu}_c, \mu_c)$
   b) $\overleftarrow{\theta}(x_c) \overset{+}{\leftarrow} \frac{1}{\rho_c} \overleftarrow{\nu}(x_c)$
   c) $\overleftarrow{\theta}(x_i) \overset{+}{\leftarrow} \sum_{x_{c \setminus i}} \overleftarrow{\nu}(x_c) \quad \forall i \in c$
   d) $\overleftarrow{m}_d(x_i) \overset{+}{\leftarrow} \frac{\rho_d - I_{c=d}}{m_d(x_i)} \sum_{x_{c \setminus i}} \overleftarrow{\nu}(x_c) \quad \forall i \in c,\ \forall d : i \in d$
7) For all $i$,
   a) $\overleftarrow{\nu}(x_i) \leftarrow \mathrm{backnorm}(\overleftarrow{\mu}_i, \mu_i)$
   b) $\overleftarrow{\theta}(x_i) \overset{+}{\leftarrow} \overleftarrow{\nu}(x_i)$
   c) $\overleftarrow{m}_d(x_i) \overset{+}{\leftarrow} \rho_d \frac{\overleftarrow{\nu}(x_i)}{m_d(x_i)} \quad \forall d : i \in d$
8) Repeat $N$ times for all pairs $(c,i)$ (in reverse order)
   a) $s(x_c) \leftarrow e^{\frac{1}{\rho_c} \theta(x_c)} \prod_{j \in c \setminus i} e^{\theta(x_j)} \frac{\prod_{d : j \in d} m_d(x_j)^{\rho_d}}{m_c(x_j)}$
   b) $\overleftarrow{\nu}(x_i) \leftarrow \mathrm{backnorm}(\overleftarrow{m}_{ci}, m_{ci})$
   c) $\overleftarrow{\theta}(x_c) \overset{+}{\leftarrow} \frac{1}{\rho_c} s(x_c) \frac{\overleftarrow{\nu}(x_i)}{m_c(x_i)}$
   d) $\overleftarrow{\theta}(x_j) \overset{+}{\leftarrow} \sum_{x_{c \setminus j}} s(x_c) \frac{\overleftarrow{\nu}(x_i)}{m_c(x_i)} \quad \forall j \in c \setminus i$
   e) $\overleftarrow{m}_d(x_j) \overset{+}{\leftarrow} \frac{\rho_d - I_{c=d}}{m_d(x_j)} \sum_{x_{c \setminus j}} s(x_c) \frac{\overleftarrow{\nu}(x_i)}{m_c(x_i)} \quad \forall j \in c \setminus i,\ \forall d : j \in d$
   f) Pull messages $m_c(x_i)$ from the stack.
   g) $\overleftarrow{m}_c(x_i) \leftarrow 0$

If one uses pairwise factors only, uniform appearance probabilities of $\rho = 1$, removes all reference to the stack, and uses a convergence threshold in place of a fixed number of iterations, one obtains essentially Eaton and Ghahramani's back belief propagation [53, extended version, Fig. 5]. Here, we refer to the general strategy of using full (non-truncated) inference as "backpropagation", either with LBP, TRW, or mean field.

7.3 Truncated Likelihood & Truncated EM

Applying the truncated fitting strategies to any of the marginal-based loss functions is simple. Applying it to the likelihood or EM loss, however, is not so straightforward. The reason is that these losses (Eqs. 19 and 21) are defined, not in terms of predicted marginals, but in terms of partition functions. Nevertheless, we wish to compare to these losses in the experiments below. As we found truncation to be critical for speed, we instead derive a variant of truncated fitting.

The basic idea is to define a "truncated partition function". This is done by taking the predicted marginals, obtained after a fixed number of iterations, and plugging them into the entropy approximations used either for mean field (Eq. 12) or TRW (Eq. 16). The approximate entropy $\tilde{H}$ is then used in defining a truncated partition function as

$$\tilde{A}(\theta) = \theta \cdot \mu(\theta) - \tilde{H}(\mu(\theta)).$$

As we will see below, with too few inference iterations, using this approximation can cause the surrogate likelihood to diverge. To see why, imagine an extreme case where zero inference iterations are used. This results in the loss $L(\theta, x) = \theta \cdot (f(x) - \mu^0) + \tilde{H}(\mu^0)$, where $\mu^0$ are the
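Returning to the perturbation method of Section 6, the two-sided procedure of Alg. 1 can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the implementation used in the experiments: inference is done exactly by enumerating all joint states of a hypothetical 4-node binary cycle (so $\mu(\theta)$ is the exact mean of the features and the Jacobian symmetry holds exactly), and the loss is taken to be $Q(\mu) = -\sum_i \log \mu(x_i)$ at the observed states. All function and variable names are invented for the example.

```python
import itertools
import numpy as np

# Hypothetical toy model (not from the paper): 4 binary variables on a cycle.
EDGES = [(0, 1), (1, 2), (2, 3), (3, 0)]
N = 4

def features(x):
    """Indicator features f(x): one per (node, state), then one per (edge, state pair)."""
    f = np.zeros(2 * N + 4 * len(EDGES))
    for i, xi in enumerate(x):
        f[2 * i + xi] = 1.0
    for e, (i, j) in enumerate(EDGES):
        f[2 * N + 4 * e + 2 * x[i] + x[j]] = 1.0
    return f

# Feature matrix over all 2^N joint states, used for brute-force exact inference.
F = np.array([features(x) for x in itertools.product([0, 1], repeat=N)])

def marginals(theta):
    """mu(theta) = E[f(x)] under p(x) proportional to exp(theta . f(x))."""
    s = F @ theta
    p = np.exp(s - s.max())
    return (p / p.sum()) @ F

def loss_and_grad(mu, x_obs):
    """Q(mu) = -sum_i log mu(x_i) at the observed states, together with dQ/dmu."""
    idx = 2 * np.arange(N) + np.asarray(x_obs)
    g = np.zeros_like(mu)
    g[idx] = -1.0 / mu[idx]
    return -np.log(mu[idx]).sum(), g

def perturbation_gradient(theta, x_obs, m=1.0):
    """Two-sided perturbation estimate of dL/dtheta, following the steps of Alg. 1."""
    mu = marginals(theta)                                   # 1) inference
    _, g = loss_and_grad(mu, x_obs)                         # 2) dQ/dmu at mu*
    eps = np.finfo(float).eps
    r = m * eps ** (1 / 3) * (1 + np.abs(theta).max()) / np.abs(g).max()  # 3) step size
    mu_plus = marginals(theta + r * g)                      # 4a) perturbed inference
    mu_minus = marginals(theta - r * g)                     # 4b) perturbed inference
    return (mu_plus - mu_minus) / (2 * r)                   # 5) recover derivative

def coordinate_gradient(theta, x_obs, h=1e-6):
    """Reference gradient: perturb each component of theta separately."""
    grad = np.zeros_like(theta)
    for k in range(theta.size):
        e = np.zeros_like(theta)
        e[k] = h
        grad[k] = (loss_and_grad(marginals(theta + e), x_obs)[0]
                   - loss_and_grad(marginals(theta - e), x_obs)[0]) / (2 * h)
    return grad

rng = np.random.default_rng(0)
theta = rng.standard_normal(F.shape[1])
x_obs = [0, 1, 1, 0]
g_pert = perturbation_gradient(theta, x_obs)   # two extra inference calls
g_true = coordinate_gradient(theta, x_obs)     # two inference calls per parameter
print(np.abs(g_pert - g_true).max())
```

The perturbation estimate needs two additional inference calls regardless of the number of parameters, whereas the coordinate-wise reference needs two per parameter; with approximate inference in place of enumeration, the same structure would apply as long as the marginals are the gradient of the approximate partition function.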
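One way to read the backnorm operator of Section 7, offered here as an interpretation rather than as the argument of Appendix C: suppose a vector $c$ is produced by exponentiating scores $z$ and normalizing, and the gradient $g = \frac{dL}{dc}$ is already available. Differentiating through the normalization then gives

$$c_i = \frac{e^{z_i}}{\sum_k e^{z_k}}, \qquad \frac{\partial c_i}{\partial z_j} = c_i \left( \delta_{ij} - c_j \right), \qquad \frac{dL}{dz_j} = \sum_i g_i\, c_i \left( \delta_{ij} - c_j \right) = c_j \left( g_j - g \cdot c \right),$$

so that $\frac{dL}{dz} = c \odot (g - g \cdot c) = \mathrm{backnorm}(g, c)$. Under this reading, step 5a of Alg. 2 converts the gradient with respect to a normalized marginal into the gradient with respect to the exponent of the mean field update, which is why the result can be added directly to $\overleftarrow{\theta}(x_j)$ in step 5b.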
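The back mean field procedure of Section 7.1 can also be checked numerically on a small example. The sketch below specializes Alg. 2 to pairwise factors on a hypothetical 4-node binary cycle, uses $L = -\sum_i \log \mu(x_i)$ as the loss, and compares the backpropagated gradient with central finite differences of the truncated objective. The model, the loss, the number of iterations, and all names are assumptions made for illustration.

```python
import numpy as np

# Hypothetical toy setup (not from the paper): pairwise model on a 4-node cycle.
EDGES = [(0, 1), (1, 2), (2, 3), (3, 0)]
N_NODES, N_ITERS = 4, 3

def backnorm(g, c):
    """backnorm(g, c) = c * (g - g.c), the backward pass through normalization."""
    return c * (g - g @ c)

def update(j, mu, th_u, th_p):
    """One mean field update of node j (pairwise case of Alg. 2, step 2b)."""
    z = th_u[j].copy()
    for e, (a, b) in enumerate(EDGES):
        if a == j:
            z += th_p[e] @ mu[b]
        elif b == j:
            z += th_p[e].T @ mu[a]
    z = np.exp(z - z.max())
    return z / z.sum()

def forward(th_u, th_p, x_obs):
    """Run N_ITERS sweeps from uniform marginals, saving overwritten marginals."""
    mu = np.full((N_NODES, 2), 0.5)
    stack = []
    for _ in range(N_ITERS):
        for j in range(N_NODES):
            stack.append(mu[j].copy())
            mu[j] = update(j, mu, th_u, th_p)
    loss = -np.sum(np.log(mu[np.arange(N_NODES), x_obs]))
    return loss, mu, stack

def back_mean_field(th_u, th_p, x_obs):
    """Gradient of the truncated loss with respect to th_u and th_p."""
    loss, mu, stack = forward(th_u, th_p, x_obs)
    mu_bar = np.zeros_like(mu)                       # dL/dmu for the final marginals
    mu_bar[np.arange(N_NODES), x_obs] = -1.0 / mu[np.arange(N_NODES), x_obs]
    g_u, g_p = np.zeros_like(th_u), np.zeros_like(th_p)
    for _ in range(N_ITERS):
        for j in reversed(range(N_NODES)):
            nu = backnorm(mu_bar[j], mu[j])          # step 5a
            g_u[j] += nu                             # step 5b
            for e, (a, b) in enumerate(EDGES):
                if a == j:
                    g_p[e] += np.outer(nu, mu[b])    # step 5c
                    mu_bar[b] += th_p[e].T @ nu      # step 5d
                elif b == j:
                    g_p[e] += np.outer(mu[a], nu)    # step 5c
                    mu_bar[a] += th_p[e] @ nu        # step 5d
            mu[j] = stack.pop()                      # step 5e: restore old marginal
            mu_bar[j] = 0.0                          # step 5f
    return loss, g_u, g_p

def numeric_grad(th_u, th_p, x_obs, h=1e-6):
    """Central finite differences of the truncated loss, one parameter at a time."""
    g_u, g_p = np.zeros_like(th_u), np.zeros_like(th_p)
    for g, th in ((g_u, th_u), (g_p, th_p)):
        for idx in np.ndindex(th.shape):
            old = th[idx]
            th[idx] = old + h
            lp = forward(th_u, th_p, x_obs)[0]
            th[idx] = old - h
            lm = forward(th_u, th_p, x_obs)[0]
            th[idx] = old
            g[idx] = (lp - lm) / (2 * h)
    return g_u, g_p

rng = np.random.default_rng(0)
th_u = rng.uniform(-1, 1, size=(N_NODES, 2))
th_p = rng.uniform(-1, 1, size=(len(EDGES), 2, 2))
x_obs = np.array([0, 1, 1, 0])
loss, g_u, g_p = back_mean_field(th_u, th_p, x_obs)
n_u, n_p = numeric_grad(th_u, th_p, x_obs)
max_err = max(np.abs(g_u - n_u).max(), np.abs(g_p - n_p).max())
print(max_err)
```

Because the objective is defined after a fixed number of sweeps, the comparison does not depend on whether mean field has converged; the stack holds one small vector per update, which is the only forward-pass state the backward pass needs in this sketch.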
