Deployed notebook: https://www.wolframcloud.com/objects/4603f55a-b4e5-4e34-8180-388b29d0447e


Natural gradient for multilayer linear nets

Util

generateXY[e_, yvar_, extraDims_, dsize_] := Module[{wt, mean, cov, normal, pdf, X, Xc, Y, Xa, wta, w0a, XY, n, trueCov},
  SeedRandom[0];
  n = 2;
  wt = {{1, 1}};  (* true relation *)
  mean = 0 & /@ Range@n;
  cov = {{1, 1 - e}, {1 - e, 1}};
  normal = MultinormalDistribution[mean, cov];
  X = RandomVariate[normal, {dsize}] // Transpose;
  X = centerData[X];
  Y = Dot[wt, X] + RandomVariate[NormalDistribution[0, yvar]];
  (* Add copies of first feature as redundant features *)
  Xa = X~Join~Table[X[[1, All]], {i, extraDims}];
  wta = Join[wt, {Table[0, {i, extraDims}]}, 2];
  w0a = Join[{{1, 2}}, {Table[0, {i, extraDims}]}, 2];
  {Xa, Y, w0a}
];
{X0, Y0, w0} = generateXY[0.01, 1., 0, 1000];
ListPlot[Transpose@X0]
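generateXY relies on a centerData helper that is not included in this export. A minimal sketch, assuming it just makes every feature (row) zero-mean:

(* Assumed helper: subtract each row's mean so every feature is centered *)
centerData[X_] := X - Mean /@ X;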
(Output: scatter plot of the two correlated features of X0.)

Main

(* Notation:
   error is Y - W[n] ... W[1].W[0] = Y - W[n] ... W[1].X
   W[i] has size f[i] x f[i-1]
   Y has size f[n] x f[-1]
   X has size f[0] x f[-1]
   fs = {f[-1], f[0], f[1], ..., f[n]}
*)

Unprotect[lossEq, errEq, gradEq, hessEq, errF, lossF, gradF, hessF, pre0, n, dsize, W, U, makeW, Y, X, vars, Wf, allvars, A, B, Bn, hessBlock, hess, gradBlock, grad, sub1, subr, y, f, flatten, unflatten];
Clear[lossEq, errEq, gradEq, hessEq, n, dsize, W, U, makeW, Y, X, vars, allvars, A, B, Bn, hessBlock, hess, gradBlock, grad, sub1, subr, y, f];

Clear[A, B, dW, W, X, Y, y, x, fs, subW, loss, err];
(* Ws gives a list of matrices, Wf means the flattened representation *)
(* Wf gives a list of symbolic variables, W0f gives concrete values *)

On[Assert];
SeedRandom[0];
fs = {10, 2, 2, 2, 1};
dsize = First[fs];
(* number of layers, aka number of matmuls *)
n = Length[fs] - 2;

dsize = First@fs;
(* makeW[0] is X *)
makeW[k_] := Array[W[k], {fs[[k + 2]], fs[[k + 1]]}];
makeInitializer[k_] := RandomReal[{0, 1}, {fs[[k + 2]], fs[[k + 1]]}];

vars = Table[makeW[k], {k, 1, n}];
Ws = vars;
varsf = Flatten[vec /@ vars];
Wf = varsf;
W0 = Table[makeInitializer[k], {k, 1, n}];
(* Wtrue, flat *)
Wtf = {0.8062349520611257`, 0.8058281619635977`, 0.8373291816030192`, 0.7906864569952768`, 1.1965576018047888`, 0.30483187178626686`, 1.2476949293612638`, 0.8154112123212373`, 0.12928385363990855`, 0.8254921329944452`};
(* W0, flat *)
W0f = Flatten[vec /@ W0];
(* error: Y - W[n] ... W[1].W[0] *)
errEq := Y - Fold[Dot, Reverse@vars].makeW[0];
lossEq := take1[errEq.Transpose[errEq]/(2 dsize)];

subW[Wf_] := Thread[varsf -> Wf];

X = Array[W[0], {fs[[2]], fs[[1]]}];
X0 = RandomReal[{0, 1}, {fs[[2]], fs[[1]]}];
subX := (
  Assert[Dimensions[X0] == {fs[[2]], fs[[1]]}, "X mismatch"];
  Thread[Flatten@X -> Flatten@X0]
);

Y = Array[y, {fs[[-1]], fs[[1]]}];
Y0 = RandomReal[{0, 1}, {fs[[-1]], fs[[1]]}];
subY := (
  Assert[Dimensions[Y0] == {Last[fs], First[fs]}, "Y mismatch"];
  Thread[Flatten@Y -> Flatten@Y0]
);

{X0, Y0, dummy} = generateXY[0.01, 1., 0, First@fs];

flatten[Ws_] := c2v[vec /@ Ws];
unflatten[Wf_] := Module[{},
  sizes = Rest[Times @@ # & /@ Partition[fs, 2, 1]];
  flatVars = listPartition[Wf, sizes];
  Table[unvec[flatVars[[i]], fs[[i + 2]]], {i, 1, Length@sizes}]
];

(* Define loss, gradient, Hessian *)
lossf[Wf_] := (lossEq /. subW[Wf] /. subY /. subX);
loss[Ws_] := lossf[flatten[Ws]];

Assert[vars == unflatten[varsf], "vars mismatch"];

gradEq = D[lossf[Wf], {Wf, 1}];
hessEq = D[lossf[Wf], {Wf, 2}];
gradf[Wf_] := gradEq /. subW[Wf];
hessf[Wf_] := hessEq /. subW[Wf];
ihessf[Wf_] := PseudoInverse[hessf[Wf]];

(* Manual implementation of the gradient *)

(* A[i] = W[i-1] ... W[0] = fprop needed to compute the derivative at layer i *)
A[Wf_, 1] := makeW[0] /. subX;
A[Wf_, i_] := Module[{},
  makeW[i - 1].A[Wf, i - 1] /. subW[Wf]
];

(* B[i] = W[i+1]' ... err = bprop needed to compute the derivative at layer i *)
err[Ws_] := Y - A[flatten@Ws, n + 1] /. subY;
B[Wf_, n] := err[unflatten[Wf]];
B[Wf_, i_] := Module[{},
  Transpose[makeW[i + 1]].B[Wf, i + 1] /. subW[Wf]
];

dW[Wf_, i_] := -B[Wf, i].Transpose[A[Wf, i]]/dsize;

Assert[gradf[W0f] == flatten[dW[W0f, #] & /@ Range[n]], "Gradient incorrect"];
Protect[lossEq, errEq, gradEq, hessEq, errF, lossF, gradF, hessF, pre0, n, dsize, W, U, makeW, Y, X, vars, Wf, allvars, A, B, Bn, hessBlock, hess, gradBlock, grad, sub1, subr, y, f];

Newton Method
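The cells for this section are not in the export. A minimal sketch of a damped Newton iteration using gradf and the pseudo-inverse Hessian ihessf defined above; the step size and iteration count are placeholders, not the original settings:

(* Newton step on the flattened parameters (sketch) *)
newtonStep[Wf_, lr_] := Wf - lr ihessf[Wf].gradf[Wf];
WfNewton = NestList[newtonStep[#, 1.0] &, W0f, 10];
ListLogPlot[lossf /@ WfNewton, Joined -> True, AxesLabel -> {"step", "loss"}]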

Natural Gradient Method
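The original cells are missing here as well. One common way to take a natural-gradient step for a deep linear network is to whiten each layer's gradient with the (pseudo-)inverses of the forward-activation and backpropagated-error covariances, a Kronecker-factored approximation. This sketch reuses A, B and dW from Main; it is one plausible implementation, not necessarily what the original notebook does, and the step size is a placeholder:

(* Kronecker-factored natural-gradient update per layer (sketch, assumed implementation) *)
naturalDW[Wf_, i_] := Module[{a, b, covA, covB},
  a = A[Wf, i]; b = B[Wf, i];          (* forward and backward factors at layer i *)
  covA = a.Transpose[a]/dsize;         (* activation covariance *)
  covB = b.Transpose[b]/dsize;         (* backprop covariance *)
  PseudoInverse[covB].dW[Wf, i].PseudoInverse[covA]
];
naturalStep[Wf_, lr_] := Wf - lr flatten[naturalDW[Wf, #] & /@ Range[n]];
WfNat = NestList[naturalStep[#, 0.5] &, W0f, 20];
ListLogPlot[lossf /@ WfNat, Joined -> True, AxesLabel -> {"step", "loss"}]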

Examine off-diagonal correlations
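The cells for this section are missing from the export. Assuming "off-diagonal correlations" refers to the cross-layer blocks of the Hessian, one way to inspect them at the initial point:

(* Visualize the diagonally normalized Hessian at W0 (sketch) *)
h0 = hessf[W0f];
corr0 = h0/Sqrt[Outer[Times, Diagonal[h0], Diagonal[h0]]];   (* scale to a correlation-like form *)
MatrixPlot[corr0, PlotLegends -> Automatic]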

Plot weight magnitudes in two layers
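This section's plotting cells are also missing. As a minimal placeholder, assuming the intent is to compare entry magnitudes of the first two weight matrices, here evaluated at the initial weights W0:

(* Absolute weight values in layers 1 and 2 at initialization (sketch) *)
ListPlot[{Abs@Flatten@W0[[1]], Abs@Flatten@W0[[2]]},
  PlotLegends -> {"layer 1", "layer 2"}, AxesLabel -> {"index", "|w|"}]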

Plot update per layer
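Again the original cells are not in the export. A minimal sketch that plots the Frobenius norm of the gradient update for each layer at the initial point:

(* Per-layer update magnitude at W0 (sketch) *)
ListPlot[Table[Norm[dW[W0f, i], "Frobenius"], {i, 1, n}],
  AxesLabel -> {"layer", "update norm"}, PlotRange -> All]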

Gradient descent
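The gradient-descent cells are not in the export. A minimal sketch of a plain fixed-step loop on the flattened parameters, using gradf and lossf from the Main section; the step size and iteration count are placeholders:

(* Plain gradient descent on the flattened parameters (sketch) *)
gdStep[Wf_, lr_] := Wf - lr gradf[Wf];
WfGD = NestList[gdStep[#, 0.05] &, W0f, 200];
ListLogPlot[lossf /@ WfGD, Joined -> True, AxesLabel -> {"step", "loss"}]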

Generate data for TF
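No cells follow this heading in the export. Presumably the generated matrices are written to disk so a TensorFlow script can load them; a minimal sketch, with placeholder file names:

(* Dump the generated data as CSV, one example per row (sketch) *)
Export["X0.csv", Transpose[X0]];
Export["Y0.csv", Transpose[Y0]];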