In[]:=
deploy
Sat 15 Jan 2022 11:00:50
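Compute derivatives of the cross-entropy
$$\operatorname{cross\_entropy}(q, p(z)) = \mathbb{E}_{q}\left[-\log p(z)\right]$$
with respect to $z$, where $p(z) = \operatorname{softmax}(z)$.
See https://yaroslavvb.medium.com/using-evolved-notation-to-derive-the-hessian-of-cross-entropy-loss-195f8c7b3a92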
In[]:=
(* Approximate equality a ≐ b for scalars or (nested) lists:
   True when the max-norm of the elementwise difference is below 10^-9. *)
DotEqual[a_, b_] := Norm[Flatten[{a}] - Flatten[{b}], Infinity] < 1*^-9;
On[Assert];

(* Softmax: make entries positive and add up to 1. *)
softmax[z_] := Exp[z]/Total[Exp@z];

d = 3;               (* number of dimensions *)
z = Array[z00, d];   (* vector of symbolic potentials z00[1], ..., z00[d] *)
p = softmax[z];      (* vector of probabilities *)
q = Array[q00, d];   (* vector of symbolic target probabilities *)

(* Substitution rules replacing q, z with numeric values: z = (1, ..., 1) and
   q = softmax(z), so p and q are both uniform and p == q at this point.
   This symmetric point is the one the order loop at the end tabulates.
   Evaluating num also (re)assigns the globals qvals and zvals. *)
num := (
  qvals = softmax[Array[1 &, d]];
  zvals = Array[1 &, d];
  Thread[q -> qvals]~Join~Thread[z -> zvals]
);

(* Second, non-degenerate check point. At num the gradient p - q is identically
   zero, so checking against it alone cannot catch e.g. a sign error.
   Here z = Log[{1, ..., d}], so p = {1, ..., d}/Total[{1, ..., d}] exactly
   (all values stay rational), and q is that vector reversed: q != p for d >= 2,
   and Total[q] == 1. *)
numGeneric = Thread[q -> Reverse[Range[d]]/Total[Range[d]]]~Join~Thread[z -> Log[Range[d]]];

(* Cross-entropy -Sum[q_i Log[p_i]] with p = softmax(z), written as
   Log[Total[Exp[z]]] Total[q] - z.q (q is not assumed normalized here). *)
xent = Log[Total[Exp[z]]] Total[q] - z.q;
first = D[xent, {z, 1}] /. num;
second = D[xent, {z, 2}] /. num;
third = D[xent, {z, 3}] /. num;
fourth = D[xent, {z, 4}] /. num;
fifth = D[xent, {z, 5}] /. num;

(* Closed-form gradient; equals the derivative when Total[q] == 1. *)
myFirst = (p - q) /. num;

(* Closed-form Hessian Diag[p] - p p^T, and a factor S with Transpose[S].S equal
   to it (the identity uses Total[p] == 1). *)
mySecond = (DiagonalMatrix[p] - Outer[Times, p, p]) /. num;
secondSqrt = (DiagonalMatrix[Sqrt[p]] - Outer[Times, Sqrt[p], p]) /. num;

Assert[first ≐ myFirst]
Assert[second ≐ mySecond]
Assert[Transpose[secondSqrt].secondSqrt ≐ mySecond]

(* Repeat the same identities at the non-degenerate point numGeneric. *)
Module[{gradG, hessG, hessFormulaG, sqrtG},
  gradG = D[xent, {z, 1}] /. numGeneric;
  hessG = D[xent, {z, 2}] /. numGeneric;
  hessFormulaG = (DiagonalMatrix[p] - Outer[Times, p, p]) /. numGeneric;
  sqrtG = (DiagonalMatrix[Sqrt[p]] - Outer[Times, Sqrt[p], p]) /. numGeneric;
  Assert[gradG ≐ ((p - q) /. numGeneric)];
  Assert[hessG ≐ hessFormulaG];
  Assert[Transpose[sqrtG].sqrtG ≐ hessFormulaG];
];

myThird = "TODO"; (* figure out formula for third derivative and its factorization *)

(* For each order 2..10: take the order-th derivative tensor at num. Print the
   number of distinct entries, and the 2-D slice deriv[[1, ..., 1, All, All]]
   (order - 2 leading indices fixed to 1). *)
Do[
  deriv = D[xent, {z, order}] /. num;
  slice = Nest[First, deriv, order - 2];
  unique = Union@Flatten@deriv;
  Print[StringForm["order=`` num unique=`` `` ", order, Length@unique, slice // MatrixForm]],
  {order, 2, 10}
]
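order=2 num unique=2
$$\begin{pmatrix}
\frac{2}{9} & -\frac{1}{9} & -\frac{1}{9} \\
-\frac{1}{9} & \frac{2}{9} & -\frac{1}{9} \\
-\frac{1}{9} & -\frac{1}{9} & \frac{2}{9}
\end{pmatrix}$$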
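order=3 num unique=2
$$\begin{pmatrix}
\frac{2}{27} & -\frac{1}{27} & -\frac{1}{27} \\
-\frac{1}{27} & -\frac{1}{27} & \frac{2}{27} \\
-\frac{1}{27} & \frac{2}{27} & -\frac{1}{27}
\end{pmatrix}$$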
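order=4 num unique=4
$$\begin{pmatrix}
-\frac{2}{27} & \frac{1}{27} & \frac{1}{27} \\
\frac{1}{27} & -\frac{1}{27} & 0 \\
\frac{1}{27} & 0 & -\frac{1}{27}
\end{pmatrix}$$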
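order=5 num unique=5
$$\begin{pmatrix}
-\frac{10}{81} & \frac{5}{81} & \frac{5}{81} \\
\frac{5}{81} & -\frac{1}{81} & -\frac{4}{81} \\
\frac{5}{81} & -\frac{4}{81} & -\frac{1}{81}
\end{pmatrix}$$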
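order=6 num unique=6
$$\begin{pmatrix}
\frac{14}{243} & -\frac{7}{243} & -\frac{7}{243} \\
-\frac{7}{243} & \frac{11}{243} & -\frac{4}{243} \\
-\frac{7}{243} & -\frac{4}{243} & \frac{11}{243}
\end{pmatrix}$$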
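order=7 num unique=8
$$\begin{pmatrix}
\frac{98}{243} & -\frac{49}{243} & -\frac{49}{243} \\
-\frac{49}{243} & \frac{7}{81} & \frac{28}{243} \\
-\frac{49}{243} & \frac{28}{243} & \frac{7}{81}
\end{pmatrix}$$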
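order=8 num unique=10
$$\begin{pmatrix}
\frac{106}{729} & -\frac{53}{729} & -\frac{53}{729} \\
-\frac{53}{729} & -\frac{47}{729} & \frac{100}{729} \\
-\frac{53}{729} & \frac{100}{729} & -\frac{47}{729}
\end{pmatrix}$$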
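order=9 num unique=11
$$\begin{pmatrix}
-\frac{4430}{2187} & \frac{2215}{2187} & \frac{2215}{2187} \\
\frac{2215}{2187} & -\frac{1187}{2187} & -\frac{1028}{2187} \\
\frac{2215}{2187} & -\frac{1028}{2187} & -\frac{1187}{2187}
\end{pmatrix}$$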
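order=10 num unique=14
$$\begin{pmatrix}
-\frac{2518}{729} & \frac{1259}{729} & \frac{1259}{729} \\
\frac{1259}{729} & -\frac{781}{2187} & -\frac{2996}{2187} \\
\frac{1259}{729} & -\frac{2996}{2187} & -\frac{781}{2187}
\end{pmatrix}$$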