In[]:=
deploy
Wed 2 Feb 2022 18:23:07
In[]:=
(* Quit the kernel session; this discards every definition made so far. *)
Quit

Util

In[]:=
(* Approximate equality testing: a \[DotEqual] b is True when a and b flatten to lists
   of the same length whose entries differ by less than 10^-9 in the max norm.
   Lists of different lengths give False (instead of a Thread::tdlen error). *)
DotEqual[a_, b_] := Module[{fa = Flatten[{a}], fb = Flatten[{b}]},
  Length[fa] === Length[fb] && Max[Abs[fa - fb]] < 1*^-9
];

PacletInstall["TensorSimplify", "Site" -> "https://raw.githubusercontent.com/carlwoll/TensorSimplify/master"]

<< TensorSimplify`

(* einsum[{spec1, spec2, ...} -> out, array1, array2, ...] contracts the arrays
   according to the index specifications: an index appearing twice is contracted,
   an index appearing once must appear in out, and out fixes the order of the
   remaining axes. On an invalid specification a message is issued and the call
   stays unevaluated (the Condition below fails). *)
einsum[in_List -> out_, arrays__] := Module[{res = isum[in -> out, {arrays}]},
  res /; res =!= $Failed
];

(* Worker for einsum: validates the specification, then builds the contraction.
   Returns $Failed (via Throw/Catch) when validation fails. *)
isum[in_List -> out_, arrays_List] := Catch@Module[
  {indices, contracted, uncontracted, contractions, transpose},
  (* one index specification per array *)
  If[Length[in] != Length[arrays],
    Message[einsum::length, Length[in], Length[arrays]];
    Throw[$Failed]
  ];
  (* each specification must match its array's rank, whenever that rank is known;
     the IntegerQ guard must test the ARRAY (#2), not the specification (#1) *)
  MapThread[
    If[IntegerQ@TensorRank[#2] && Length[#1] != TensorRank[#2],
      Message[einsum::shape, #1, #2];
      Throw[$Failed]
    ] &,
    {in, arrays}
  ];
  (* an index may occur once (free) or twice (contracted), never more *)
  indices = Tally[Flatten[in, 1]];
  If[DeleteCases[indices, {_, 1 | 2}] =!= {},
    Message[einsum::repeat, Cases[indices, {x_, Except[1 | 2]} :> x]];
    Throw[$Failed]
  ];
  (* the free indices must be exactly the requested output indices *)
  uncontracted = Cases[indices, {x_, 1} :> x];
  If[Sort[uncontracted] =!= Sort[out],
    Message[einsum::output, uncontracted, out];
    Throw[$Failed]
  ];
  (* slot pairs (in the flattened tensor product) to contract *)
  contracted = Cases[indices, {x_, 2} :> x];
  contractions = Flatten[Position[Flatten[in, 1], #]] & /@ contracted;
  (* reorder the surviving slots into the requested output order *)
  transpose = FindPermutation[uncontracted, out];
  (* Inactive keeps the (possibly large) explicit TensorProduct from being formed
     before the contraction is set up *)
  Activate@TensorTranspose[
    TensorContract[Inactive[TensorProduct] @@ arrays, contractions],
    transpose
  ]
];

einsum::length = "Number of index specifications (`1`) does not match the number of arrays (`2`)";
einsum::shape = "Index specification `1` does not match the array depth of `2`";
einsum::repeat = "Index specifications `1` are repeated more than twice";
einsum::output = "The uncontracted indices `1` don't match the desired output `2`";

(* symbolic matrices used with TensorSimplify / TensorRank elsewhere in the notebook *)
$Assumptions = (X | M | M1 | M2 | M3 | M4 | M5 | M6) \[Element] Matrices[{d, d}];
Out[]=
PacletObject
Name: TensorSimplify
Version: 0.0.3


End-to-end examples

HVP test

Goal: compute the Hessian-vector product (HVP) automatically by symbolic differentiation, then reproduce it with a layer-by-layer forward-backward recursion.
In[]:=
(* apply[{f1, f2, f3}, x] gives f1@f2@f3@x *)
On[Assert];
ClearAll[h];
apply[{}, x_] := x;
apply[{f_}, x_] := f@x;
apply[l_, x_] := First[l]@apply[Rest[l], x];

(* My version of the dot product: an empty left operand acts as the identity,
   so the initial backward step can be a no-op. *)
dot[{}, b_] := b;
dot[a_, b_] := a.b;

(* Problem spec: ds[[i]] is the input dimension of layer i; the last entry is
   the (scalar) output dimension. *)
ds = {2, 3, 4, 1};
h[1][x_] := {1 x[[1]], 2 x[[2]], x[[1]] + x[[2]]};
h[2][x_] := {x[[1]]*x[[1]], x[[2]]*x[[2]], x[[3]]*x[[3]], x[[1]]};
h[3][x_] := x[[1]]^2 + x[[2]]^2 + x[[3]]^2 + x[[4]]^2;
hs = {h[1], h[2], h[3]};
x0 = {x1, x2};
v0 = {v1, v2};

n = Length[ds] - 1; (* number of function evaluations *)
xs = Table[Array[x, d], {d, Most@ds}]; (* argument lists for each h *)
Assert[Last@ds == 1];
Assert[Length@hs == Length[ds] - 1];
Assert[Length[x0] == Length[v0] == First@ds];
H = apply[Reverse@hs, x0];
jac = D[H, {x0, 1}] // FullSimplify;
Print["Derivative: ", jac];
hess = D[H, {x0, 2}];
Print["Hessian: ", hess];
Print["hvp: ", hess.v0];

ClearAll[fb, b, f, F, a];
(* Activations: a[i] = output of the i'th layer, a[0] = x0 *)
a[0] := x0;
a[i_?Positive] := h[i][a[i - 1]];

(* Jacobian of the i'th layer, evaluated at its input a[i-1].
   Module keeps the dummy argument list local instead of leaking a global vars. *)
dh[i_] := Module[{vars = xs[[i]]},
  D[hs[[i]][vars], {vars, 1}] /. Thread[vars -> a[i - 1]]
];

(* Hessian of the i'th layer, evaluated at its input a[i-1] *)
d2h[i_] := Module[{vars = xs[[i]]},
  D[hs[[i]][vars], {vars, 2}] /. Thread[vars -> a[i - 1]]
];

jacobians = Table[dh[i], {i, 1, n}];
Assert[Reduce[jac == Dot @@ Reverse[jacobians]]];
Print["Jacobians ", MatrixForm /@ jacobians];

(* Forward AD: f[i] = dh[i]....dh[1].v *)
f[0] = v0;
f[i_?Positive] := dh[i].f[i - 1];
Print["Forward ad ", MatrixForm /@ Table[f[i], {i, 0, n}]];

(* Backward AD: b[i] = dh[n]....dh[n-i+1], with b[0] the empty (identity) product *)
b[0] = {};
b[i_?Positive] := dot[b[i - 1], dh[n - i + 1]];
Print["Backward ad ", MatrixForm /@ Table[b[i], {i, 1, n}]];

(* Sanity check: splitting the chain at any layer i reproduces the JVP jac.v0.
   (Previously the result was computed and discarded.) *)
Assert[And @@ Table[Reduce[dot[b[n - i], dh[i]].f[i - 1] == jac.v0], {i, 1, n}]];

(* Forward-backward chain: dh[n]....dh[i+1].d2h[i].dh[i-1]....v *)
fb[i_] := dot[b[n - i], d2h[i]].f[i - 1];

(* Final combined messages; F[n-1] is the full Hessian-vector product *)
F[0] = fb[n];
F[i_?Positive] := F[i - 1].dh[n - i] + fb[n - i];
F[n] := Print["out of bounds"];

Assert[Reduce[F[n - 1] == hess.v0]]
Derivative: $\{2x_1 + 4x_1^{3} + 4(x_1+x_2)^{3},\; 64x_2^{3} + 4(x_1+x_2)^{3}\}$
Hessian: $\{\{2 + 12x_1^{2} + 12(x_1+x_2)^{2},\; 12(x_1+x_2)^{2}\},\; \{12(x_1+x_2)^{2},\; 192x_2^{2} + 12(x_1+x_2)^{2}\}\}$
hvp: $\{12v_2(x_1+x_2)^{2} + v_1\left(2 + 12x_1^{2} + 12(x_1+x_2)^{2}\right),\; 12v_1(x_1+x_2)^{2} + v_2\left(192x_2^{2} + 12(x_1+x_2)^{2}\right)\}$

Unit test A with sigmoid

Unit test A with ReLU

HVP test with three layers

Derivative calculations

Derivatives of sums

Sigmoid derivative

Scratch

Diagonal and Trace

test_contract