In[]:=
$MachineName
Out[]=
threadripper2

Pipeline

Data generation

[Generate evolution, then train by masking the second half of the evolution]
<Alternative approach not taken: predict a smaller slice, then iterate that>
<Could mask different regions at random>
[[ Make a version where the line is cyclic, and where it wraps around the region ]]
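[ A possible cyclic variant (a sketch, not code that was run here; cyclicLineData, the slope range, and the ±32 copies are all assumptions): draw three shifted copies of the line so that it effectively wraps around mod 32 ]
cyclicLineData[]:=With[{x0=RandomInteger[32],dx=RandomInteger[{-32,32}]},
	Rasterize[
		Graphics[
			{White,Thick,Table[Line[{{x0+k,0},{x0+dx+k,64}}],{k,{-32,0,32}}]},(* copies at 0 and ±32 cover anything that exits the window *)
			Background->Black,ImageSize->{32,64},PlotRange->{{0,32},{0,64}}],
		RasterSize->{32,64}]]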
In[]:=
lineData[]:=Rasterize[Graphics[{White,Thick,Table[Line[{{RandomInteger[32],0},{RandomInteger[32],64}}],1]},Background->Black,ImageSize->{32,64}],RasterSize->{32,64}] (* one random white line on a 32×64 black raster: two stacked 32×32 frames, the second to be predicted from the first *)
In[]:=
lineData[]
Out[]=

Network

In[]:=
(* basic 3×3 convolution, padding 1 *)
conv[channels_,size_:{3,3},pad_:1]:=ConvolutionLayer[channels,size,PaddingSize->pad]

(* group normalization: split channels into groups, normalize, reshape back *)
groupNorm[input_,groups_:32]:=NetChain[{ReshapeLayer[MapAt[Splice[{groups,#/groups}]&,input,{1}]],NormalizationLayer[2;;,;;2,"Epsilon"->1*^-6],ReshapeLayer[input]},"Input"->input]

(* residual block: two norm→Swish→conv stages plus an identity skip *)
convBlock[input_,groups_:32]:=NetFlatten[NetGraph[{
	{groupNorm[input,groups],ElementwiseLayer["Swish"],conv[First[input]],groupNorm[input,groups],ElementwiseLayer["Swish"],conv[First[input]]},
	ThreadingLayer[Plus]},
	{NetPort["Input"]->1,{NetPort["Input"],1}->2}],1]

(* residual block that doubles the channel count; the skip path is a 1×1 conv *)
downBlock[input_,groups_:32]:=NetFlatten[NetGraph[{
	{groupNorm[input,groups],ElementwiseLayer["Swish"],conv[First[input]],groupNorm[input,groups],ElementwiseLayer["Swish"],conv[2First[input]]},
	conv[2First[input],{1,1},0],ThreadingLayer[Plus]},
	{NetPort["Input"]->{1,2}->3}],1]

(* residual block that halves the channel count; the skip path is a 1×1 conv *)
upBlock[input_,groups_:32]:=NetFlatten[NetGraph[{
	{groupNorm[input,groups],ElementwiseLayer["Swish"],conv[First[input]],groupNorm[input,groups],ElementwiseLayer["Swish"],conv[First[input]/2]},
	conv[First[input]/2,{1,1},0],ThreadingLayer[Plus]},
	{NetPort["Input"]->{1,2}->3}],1]

(* single-head dot-product self-attention over a sequence of dim-vectors *)
attention[dim_]:=NetGraph[{"key"->NetMapOperator[LinearLayer[dim]],"value"->NetMapOperator[LinearLayer[dim]],"query"->NetMapOperator[LinearLayer[dim]],"attention"->AttentionLayer["Dot","ScoreRescaling"->"DimensionSqrt"],"output"->NetMapOperator[LinearLayer[dim]]},{NetPort["Input"]->{"key","value","query"}->"attention"->"output"},"Input"->{"Varying",dim}]

(* attention over the flattened spatial positions, with a residual skip *)
attentionBlock[input_,groups_:32]:=NetFlatten[NetGraph[{{groupNorm[input,groups],FlattenLayer[-1],TransposeLayer[],attention[First@input],TransposeLayer[],ReshapeLayer[input]},ThreadingLayer[Plus]},{NetPort["Input"]->1,{NetPort["Input"],1}->2}],1]

(* 2× spatial downsampling via strided conv; 2× upsampling via nearest-neighbor resize plus conv *)
downsample[input_]:=ConvolutionLayer[First@input,{3,3},"Stride"->2,PaddingSize->{{0,1},{0,1}}]
upsample[input_]:=NetChain[{ResizeLayer[{Scaled[2],Scaled[2]},Resampling->"Nearest","Scheme"->"Bin"],conv[First@input]},"Input"->input]
In[]:=
(* encoder: conv stem, residual blocks with downsampling after each of the first blocks-2 stages, attention at the bottom, then 1×1 convs split into Mean and LogVar ports *)
encoder[input_,outChannels_:4,convChannels_:128,blocks_:5,defGroups_:Automatic]:=Enclose@Block[{net,groups=Replace[defGroups,Automatic:>convChannels/4]},
	net=NetChain[{conv[convChannels]},"Input"->input];
	Do[
		net=NetAppend[net,If[i==1||i>=blocks-1,convBlock,downBlock][NetExtract[net,"Output"],groups]];
		If[i==blocks,net=NetAppend[net,attentionBlock[NetExtract[net,"Output"],groups]]];
		net=NetAppend[net,convBlock[NetExtract[net,"Output"],groups]];
		If[i<blocks-1,net=NetAppend[net,downsample[NetExtract[net,"Output"]]]],
		{i,blocks}];
	net=NetAppend[net,{groupNorm[NetExtract[net,"Output"],groups],conv[2outChannels],conv[2outChannels,1,0],NetGraph[{PartLayer[;;outChannels],PartLayer[outChannels+1;;]},{NetPort["Input"]->{1,2},1->NetPort["Mean"],2->NetPort["LogVar"]}]}];
	net]

(* decoder: mirror of the encoder, upsampling the latent back to an outChannels-channel image *)
decoder[input_,outChannels_:1,convChannels_:128,blocks_:5,defGroups_:Automatic]:=Enclose@Block[{net,groups=Replace[defGroups,Automatic:>convChannels/16]},
	net=NetChain[{conv[First[input],{1,1},0],conv[convChannels]},"Input"->input];
	Do[
		net=NetAppend[net,If[i>2,upBlock,convBlock][NetExtract[net,"Output"],groups]];
		If[i==1,net=NetAppend[net,attentionBlock[NetExtract[net,"Output"],groups]]];
		net=NetAppend[net,{convBlock[NetExtract[net,"Output"],groups],convBlock[NetExtract[net,"Output"],groups]}];
		If[1<i<blocks,net=NetAppend[net,upsample[NetExtract[net,"Output"]]]],
		{i,blocks}];
	net=NetAppend[net,{groupNorm[NetExtract[net,"Output"],groups],ElementwiseLayer["Swish"],conv[outChannels]}];
	net]
[ Loss function is in fact fixed, not variational: the KL term (the commented-out 1+#LogVar-#Mean^2-#Var piece) is disabled, so only the squared reconstruction error is trained ]
In[]:=
VAE[input_,encoderArgs_List:{},decoderArgs_List:{}]:=With[{enc=encoder[input,Sequence@@encoderArgs]},
	NetGraph[<|
		"encoder"->enc,
		"exp"->ElementwiseLayer[Exp],
		(*"z"->RandomArrayLayer[NormalDistribution[#Mean,#Var]&],*)
		"decoder"->decoder[NetExtract[enc,"Mean"],First[input],Sequence@@decoderArgs],
		(*"loss"->With[{n=Times@@NetExtract[enc,"Mean"]},FunctionLayer[(Total[Unevaluated[Flatten][#Input-#Output]^2]-Total[Unevaluated[Flatten][1+#LogVar-#Mean^2-#Var]])&]],*)
		"loss"->With[{n=Times@@NetExtract[enc,"Mean"]},FunctionLayer[(Total[Unevaluated[Flatten][#Input-#Output]^2](*-Total[Unevaluated[Flatten][1+#LogVar-#Mean^2-#Var]]*))&]]
	|>,
	{
		(*NetPort[{"encoder","Mean"}]->NetPort[{"z","Mean"}],*)
		(*NetPort[{"encoder","LogVar"}]->"exp"->NetPort[{"z","Var"}],*)
		(*"z"->"decoder",*)
		NetPort[{"encoder","Mean"}]->"decoder",
		NetPort["Input"]->NetPort[{"loss","Input"}],
		"decoder"->NetPort[{"loss","Output"}],
		(*NetPort[{"encoder","Mean"}]->NetPort[{"loss","Mean"}],*)
		(*NetPort[{"encoder","LogVar"}]->NetPort[{"loss","LogVar"}],*)
		(*"exp"->NetPort[{"loss","Var"}],*)
		NetPort[{"encoder","Mean"}]->NetPort["Latent"],
		"loss"->NetPort["Loss"]
	}]]
In[]:=
vae=VAE[{1,32,32}];
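[ Quick shape check (the same extraction is used when building the prediction net below): the latent should be {4,4,4}, i.e. 4 channels with the 32×32 frame halved three times to 4×4 ]
NetExtract[vae,"Latent"]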
In[]:=
(* prediction head: given the pair of latents {x,y}, predict y from x with an MLP; loss is MSE against the true y *)
predictionNet[input_,layers_:3]:=With[{size=Times@@input},
	NetGraph[<|
		"x"->PartLayer[1],"y"->PartLayer[2],
		"predict"->NetChain[{FlattenLayer[],Splice@Table[Splice@{LinearLayer[size],ElementwiseLayer["ReLU"]},layers],LinearLayer[size],ReshapeLayer[input]}],
		"loss"->MeanSquaredLossLayer[]
	|>,
	{NetPort["Input"]->{"x","y"},"x"->"predict",{"predict","y"}->"loss"->NetPort["Loss"]}]]
In[]:=
net=NetInitialize@NetGraph[<|
	"prep"->NetChain[{ReshapeLayer[{1,2,32,32}],TransposeLayer[1<->2]}],(* split the 32×64 image into two stacked 32×32 frames *)
	"vae"->NetMapThreadOperator[vae],(* encode/decode each frame separately *)
	"VAETotalLoss"->AggregationLayer[Total,1],
	"TotalLoss"->TotalLayer[],
	"prediction"->predictionNet[NetExtract[vae,"Latent"]](* predict the second frame's latent from the first's *)
|>,
{
	NetPort["Input"]->"prep"->"vae",
	NetPort[{"vae","Latent"}]->"prediction",
	NetPort[{"vae","Loss"}]->"VAETotalLoss",
	{"VAETotalLoss",NetPort[{"prediction","Loss"}]}->"TotalLoss"(*"VAETotalLoss"*)->NetPort["Loss"]
},
"Input"->NetEncoder[{"Image",{32,64},ColorSpace->"Grayscale"}]];
In[]:=
net

Extracted pieces

Encoded version is a 4×4×4 tensor [ same channel count (4) as the Stable Diffusion latent space ]
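[ Sketch (assumed: "trained" is the net after training, per the sketch under Training below, and the extraction paths assume NetMapThreadOperator keeps its subnet under "Net"): pulling the trained pieces back out of the combined net ]
encoderNet=NetExtract[trained,{"vae","Net","encoder"}];
decoderNet=NetExtract[trained,{"vae","Net","decoder"}];
predictor=NetExtract[trained,{"prediction","predict"}];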

Trained net

Training
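[ Training sketch (assumed: sample count, rounds and device are placeholders; the net's own "Loss" port serves as the training loss) ]
trained=NetTrain[net,
	<|"Input"->Table[lineData[],10000]|>,
	LossFunction->"Loss",
	MaxTrainingRounds->100,TargetDevice->"GPU"];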

Inference
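[ Inference sketch (assumed; uses the hypothetical pieces extracted above): encode a 32×32 first frame, map its latent through the predictor, decode the predicted latent back to an image; the encoder returns both Mean and LogVar, and only Mean is used, matching the deterministic wiring of the VAE above ]
predictSecondFrame[frame_]:=With[
	{enc=NetEncoder[{"Image",{32,32},ColorSpace->"Grayscale"}],
	 dec=NetDecoder[{"Image",ColorSpace->"Grayscale"}]},
	dec@decoderNet@predictor@encoderNet[enc[frame]]["Mean"]]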

Plan

Single lines [with wraparound]

Multiple non-interacting lines
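[ Sketch (multiLineData is a hypothetical name): the non-interacting multi-line variant just draws n independent random lines per image ]
multiLineData[n_:3]:=Rasterize[
	Graphics[
		{White,Thick,Table[Line[{{RandomInteger[32],0},{RandomInteger[32],64}}],n]},
		Background->Black,ImageSize->{32,64}],
	RasterSize->{32,64}]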

Bouncing lines

CAs

5-color, nearest-neighbor, symmetric, quiescent
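[ Data-generation sketch (an assumed sampling scheme, not necessarily the enumeration used here): a random 5-color, range-1 rule made left-right symmetric by keying on {Min[a,c],b,Max[a,c]} and quiescent by forcing the all-0 neighborhood to 0; an explicit initial list gives cyclic boundary conditions, and 63 steps of a width-32 CA match the 32×64 training images ]
randomSymmetricCA[k_:5,width_:32,steps_:63]:=Module[{vals,f},
	vals=Association@Flatten@Table[{a,b,c}->RandomInteger[k-1],{a,0,k-1},{b,0,k-1},{c,a,k-1}];(* one value per symmetric neighborhood class *)
	vals[{0,0,0}]=0;(* quiescent *)
	f=vals[{Min[#[[1]],#[[3]]],#[[2]],Max[#[[1]],#[[3]]]}]&;
	CellularAutomaton[{f,{},1},RandomInteger[k-1,width],steps]]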

Case 1: one network per rule (“specialist”)

[ Train and run a bunch of cases; rank them by how well they work ]
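[ Sketch (assumed: caTrainingData is a hypothetical data generator, the 5-color version of the net is glossed over, and the NetTrain property names are from memory): train one specialist per rule, then rank by final validation loss ]
rankSpecialists[rules_List]:=SortBy[
	Table[
		With[{result=NetTrain[net,<|"Input"->caTrainingData[rule]|>,All,
				LossFunction->"Loss",ValidationSet->Scaled[0.1]]},
			<|"Rule"->rule,
				"FinalValidationLoss"->Last[result["ValidationLossList"]],
				"Net"->result["TrainedNet"]|>],
		{rule,rules}],
	#FinalValidationLoss&]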

Case 2: a “generalist” network