WOLFRAM NOTEBOOK

In[]:=
(* Cellular-automaton parameters: the rule number and the number of colors k. *)
{rule,k}={6813192821526,3};
In[]:=
(* Show the rule table of the automaton. *)
RulePlot[CellularAutomaton[{rule,k}]]
Out[]=
In[]:=
(* Evolve a random 100-cell initial row (colors 0..k-1) for 50 steps and plot it. *)
With[{init=RandomInteger[k-1,100]},
 ArrayPlot[CellularAutomaton[{rule,k},init,50]]
]
Out[]=
In[]:=
(* Attention-based CA classifier: the 3-cell neighborhood is embedded and projected to a 2-vector query,
   which attends over h learned keys (h x 2) and values (h x k); a softmax turns the result into
   per-color probabilities. Ports are wired explicitly by name. *)
h=64;caNet=NetInitialize@NetGraph[
 <|
  "emb"->EmbeddingLayer[2,k,"Input"->3],
  "key"->NetArrayLayer[{h,2}],
  "value"->NetArrayLayer[{h,k}],
  "query"->LinearLayer[2,"Biases"->None],
  "attn"->AttentionLayer["Dot"],
  "softmax"->SoftmaxLayer[]
 |>,
 {
  NetPort["Input"]->"emb"->"query"->NetPort["attn","Query"],
  "key"->NetPort["attn","Key"],
  "value"->NetPort["attn","Value"],
  "attn"->"softmax"
 }
]
Out[]=
In[]:=
(* Training wrapper: the graph input port named "Output" carries the target color index, so the
   Input->Output rules in ruleData can be passed straight to NetTrain. *)
trainNet=NetGraph[
 <|"ca"->caNet,"loss"->CrossEntropyLossLayer["Index"]|>,
 {
  NetPort["Input"]->"ca"->NetPort["loss","Input"],
  NetPort["Output"]->NetPort["loss","Target"]
 }
]
Out[]=
(* Alternative (disabled): draw a random rule number instead of the fixed one.
   Each of the k^3 three-cell neighborhoods maps to one of k colors, so there are k^(k^3) rules,
   numbered 0 .. k^(k^3)-1. RandomInteger[n] can return n itself, hence the -1.
rule=RandomInteger[k^k^3-1]*)
Out[]=
5190957382636
In[]:=
(* Training data: every neighborhood (listed from {k,k,k} down to {1,1,1}) paired with the base-k digit of
   the rule number at the same position, with digits padded to k^3 places. Colors are shifted by +1 so
   they index 1..k. *)
ruleData=MapThread[
 Rule,
 {Tuples[Range[k,1,-1],3],IntegerDigits[rule,k,k^3]+1}
]
Out[]=
In[]:=
(* Fit the attention net to the full rule table. *)
trainedNet=NetTrain[
 trainNet,
 ruleData,
 LearningRate->0.01,
 MaxTrainingRounds->100000
]
Out[]=
In[]:=
(* Predicted color for every neighborhood, decoded from the trained classifier. *)
With[{ca=NetExtract[trainedNet,"ca"],decoder=NetDecoder[{"Class",Range[k]}]},
 NetChain[{ca},"Output"->decoder][Keys[ruleData]]
]
Out[]=
{3,3,1,1,2,1,2,1,1,1,1,2,1,2,1,1,3,3,1,1,2,3,1,3,2,1,1}
In[]:=
(* Target colors, in the same order as the predictions above. *)
Values[ruleData]
Out[]=
{3,3,1,1,2,1,2,1,1,1,1,2,1,2,1,1,3,3,1,1,2,3,1,3,2,1,1}
In[]:=
(* Check that the trained net reproduces the rule exactly. The predictions are recomputed and compared
   with the targets explicitly; % / %% would silently compare the wrong outputs if cells were re-run
   or evaluated in a different order. *)
NetChain[{NetExtract[trainedNet,"ca"]},"Output"->NetDecoder[{"Class",Range[k]}]][ruleData[[All,1]]]==ruleData[[All,2]]
Out[]=
True

Manual attention: recomputing the trained attention layer by hand

In[]:=
(* The trained AttentionLayer, pulled out of the "ca" subnetwork. *)
attn=NetExtract[NetExtract[trainedNet,"ca"],"attn"];
In[]:=
(* Query vector for every neighborhood: run the flattened net only up to the "ca/query" layer. *)
queries=NetTake[NetFlatten[trainedNet],"ca/query"][Keys[ruleData]]
Out[]=
In[]:=
(* Learned key and value arrays as ordinary matrices. *)
{key,value}=Normal[NetExtract[trainedNet,{"ca",#,"Array"}]]&/@{"key","value"};
In[]:=
(* Shapes of the key, value and query matrices. Each expression must be on its own line:
   written side by side, juxtaposition means Times, which would multiply the three lists together. *)
Dimensions[key]
Dimensions[value]
Dimensions[queries]
Out[]=
{64,2}
Out[]=
{64,3}
Out[]=
{27,2}
In[]:=
(* Attention weights: every query (row of queries) dotted with every key (row of key), so key must be
   transposed ({27,2}.{2,64} -> {27,64}); then a softmax along each row, over the keys. *)
weights=SoftmaxLayer[][queries.Transpose[key]];
In[]:=
(* One row of weights per neighborhood, one column per key. *)
Dimensions@weights
Out[]=
{27,64}
In[]:=
(* Visualize the attention weights, zeroing entries below the tolerance. *)
With[{tol=1*^-5},MatrixPlot@Chop[weights,tol]]
Out[]=
In[]:=
(* Difference between the hand-computed attention output (weights.value) and the AttentionLayer
   applied to each query; chopped so that agreement shows up as zeros. *)
With[{tol=1*^-5},
 Chop[weights.value-Map[attn[<|"Query"->#,"Key"->key,"Value"->value|>]&,queries],tol]
]
Out[]=
In[]:=
(* Decode the hand-computed probabilities by argmax, as the "Class" decoder does, and compare with the
   targets. Round would only produce a 1 when a probability exceeds 0.5; a less confident row would
   then be dropped and the lengths would no longer match. *)
Map[First@Ordering[#,-1]&,SoftmaxLayer[][weights.value]]==ruleData[[All,2]]
Out[]=
True
In[]:=
(* Target colors again, for reference. *)
Values[ruleData]
Out[]=
{3,3,1,1,2,1,2,1,1,1,1,2,1,2,1,1,3,3,1,1,2,3,1,3,2,1,1}
In[]:=
(* One-hot encoding of each target color: pick the matching rows of the k x k identity matrix. *)
IdentityMatrix[k][[ruleData[[All,2]]]]
Out[]=

Continuous variant: a scalar attention output rounded to a color index

In[]:=
(* Continuous variant: the attention output is a single scalar, which is rounded and clipped to a
   color index 1..k and then one-hot encoded for the class decoder. *)
h=64;caNet=NetInitialize@NetGraph[
 {
  "emb"->EmbeddingLayer[1,k,"Input"->3],
  (* h keys/values rather than one: with a single key the softmax weight is always 1, so the output
     would equal the value array regardless of the input. *)
  "key"->NetArrayLayer[{h,1}],
  "value"->NetArrayLayer[{h,1}],
  "query"->LinearLayer[1,"Biases"->None],
  "attn"->AttentionLayer["Dot"],
  (* NOTE(review): Round is piecewise constant, so no gradient flows back through this layer;
     confirm how this net is meant to be trained. *)
  "index"->NetChain[{FunctionLayer[Clip[Round[#],{1,k}]&],UnitVectorLayer[k],FlattenLayer[1]}]
 },
 {NetPort["Input"]->"emb"->"query",{"key","value","query"}->"attn"->"index"},
 "Output"->NetDecoder[{"Class",Range[k]}]
]
Out[]=
In[]:=
(* Evaluate the (untrained) continuous net on every neighborhood. *)
caNet[Keys[ruleData]]
Out[]=
{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}
Wolfram Cloud

You are using a browser not supported by the Wolfram Cloud

Supported browsers include recent versions of Chrome, Edge, Firefox and Safari.


I understand and wish to continue anyway »

You are using a browser not supported by the Wolfram Cloud. Supported browsers include recent versions of Chrome, Edge, Firefox and Safari.