In[]:=
(* Fetch the MNIST resource, keep only the images (column 1, dropping the
   labels), and preview 20 random samples. *)
data = ResourceData[ResourceObject["MNIST"]][[All, 1]];
RandomSample[data, 20]
Out[]=
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
In[]:=
(* Global hyperparameters of the diffusion model and its U-Net denoiser. *)
size={28,28};(*dimension of the target data image*)
channel=1;(*channels of the target data image, 1 for greyscale, 3 for RGB*)
T=200;(*how many steps for corrupting the image*)
c=32;(*base channel size of the U-Net applied here*)
Tc=16;(*length of the sinusoidal encoding of the corruption time step*)
depth=3;(*depth of the U-Net*)
batch=256;(*mini-batch size for training*)
b1=10^-4;(*initial variance of noise*)
bT=0.02;(*final variance of noise*)
b[t_]:=b1+(t-1)/(T-1)*(bT-b1);(*linear schedule for increasing variance of noise: b[1]==b1, b[T]==bT*)
In[]:=
(* a[t]: cumulative signal-keeping factor of the forward process,
   a[t] = Product over i=1..t of (1 - b[i]); the empty product gives a[0] == 1.
   Product replaces the original (Table[...] /. List -> Times) hack, which
   would have rewritten ANY List head appearing inside the expression. *)
a[t_]:=Product[1-b[i],{i,t}];
(* GXT[x0, t, noise]: closed-form forward diffusion q(x_t | x_0) — corrupt the
   clean image x0 to step t using the supplied standard-normal noise array. *)
GXT[x0_,t_,noise_]:=Sqrt[a[t]]*x0+Sqrt[1-a[t]]*noise
In[]:=
(* Visual check of the forward process: six random digits corrupted to a grid
   of time steps t = 0, 1, 10, 20, ..., 200 (t = 0 leaves the image clean). *)
Block[{noiseSteps = Join[{0, 1}, Range[10, 200, 10]], picks = RandomSample[data, 6]},
 TableForm[
  Table[
   Image[GXT[ImageData[img], step, RandomReal[NormalDistribution[0, 1], size]]],
   {img, picks}, {step, noiseSteps}],
  TableSpacing -> {.5, .5},
  TableAlignments -> Center,
  TableHeadings -> {None,
    (Style["t=" <> ToString[#], {FontFamily -> "Arial", 10}] &) /@ noiseSteps}]]
Out[]//TableForm=
t=0 | t=1 | t=10 | t=20 | t=30 | t=40 | t=50 | t=60 | t=70 | t=80 | t=90 | t=100 | t=110 | t=120 | t=130 | t=140 | t=150 | t=160 | t=170 | t=180 | t=190 | t=200 |
In[]:=
Clear[res](*forget any earlier definition of the residual-block constructor before redefining it*)
In[]:=
(* res[c]: residual block with c output channels.  The "long" path is
   Conv(3x3) -> Norm -> Ramp -> Conv(3x3) -> Norm (the list below spells out
   what the original built via Flatten[Table[...,2]][[1;;-2]]); the "short"
   path is a 1x1 projection; the two are summed. *)
res[c_] := NetGraph[
  <|"long" -> {
      ConvolutionLayer[c, {3, 3}, "PaddingSize" -> 1],
      NormalizationLayer[],
      Ramp,
      ConvolutionLayer[c, {3, 3}, "PaddingSize" -> 1],
      NormalizationLayer[]},
    "plus" -> TotalLayer[],
    "short" -> ConvolutionLayer[c, {1, 1}]|>,
  {NetPort["Input"] -> "long" -> "plus",
   NetPort["Input"] -> "short" -> "plus"}]
In[]:=
(* upres[c, size]: up-sampling residual block.  The "long" path normalizes,
   resizes to `size`, and applies two 3x3 convolutions; the "short" path
   resizes and projects with a 1x1 convolution; the two are summed. *)
upres[c_,size_]:=NetGraph[
  <|"long"->{NormalizationLayer[],Ramp,ResizeLayer[size],ConvolutionLayer[c,{3,3},"PaddingSize"->1],NormalizationLayer[],Ramp,ConvolutionLayer[c,{3,3},"PaddingSize"->1]},
    "plus"->TotalLayer[],
    "short"->{ResizeLayer[size],ConvolutionLayer[c,{1,1}]}|>,
  {NetPort["Input"]->"long"->"plus",NetPort["Input"]->"short"->"plus"}]
(* contract[channel, crop]: one encoder stage.  A residual block feeds both a
   2x2 mean pooling (downward path, "Pooling" port) and a cropped copy of the
   pre-pooling features (skip connection, "Shortcut" port).  With the default
   crop {{1,1},{1,1}} the PartLayer takes 1;;-1 in both spatial dimensions,
   i.e. no actual cropping. *)
contract[channel_,crop_:{{1,1},{1,1}}]:=NetGraph[
  {"conv"->res[channel],
   "pooling"->PoolingLayer[2,2,"Function"->Mean],
   "cropping"->PartLayer[{;;,crop[[1,1]];;-crop[[1,-1]],crop[[2,1]];;-crop[[2,-1]]}]},
  {NetPort["Input"]->"conv"->"pooling"->NetPort["Pooling"],
   "conv"->"cropping"->NetPort["Shortcut"]}];
(* expand[channel, size]: one decoder stage.  upres up-samples the deep
   features to `size`, CatenateLayer joins them with the encoder "Shortcut"
   along the channel axis, and a residual block reduces to channel/2. *)
expand[channel_,size_]:=NetGraph[
  {"deconv"->upres[channel,size],
   "join"->CatenateLayer[],
   "conv"->res[channel/2]},
  {NetPort["Input"]->"deconv"->"join",
   NetPort["Shortcut"]->"join"->"conv"}];
(* UNet: noise-prediction network eps_theta(x_t, t).  Ports: "Input" is the
   noisy image (channel x size), "T" is the Tc-long sinusoidal time embedding.
   At every resolution level i the time embedding is mapped by a LinearLayer,
   broadcast spatially by ReplicateLayer, and ADDED to the feature maps —
   once on the way down ("T_F_i"/"thread_F_i") and once on the way up
   ("T_B_i"/"thread_B_i").  "preprocess"/"postprocess" are 1x1 convolutions
   into c/2 working channels and back to `channel` output channels.
   NOTE(review): "ubase" uses c*(depth+1) channels, which equals c*2^(depth-1)
   only for depth == 3 — presumably intentional for this configuration;
   confirm before changing `depth`. *)
UNet=NetGraph[
  <|Table[{
      "contract_"<>ToString[i]->contract[c*2^(i-1)],
      "expand_"<>ToString[i]->expand[c*2^(i-1),size/2^(i-1)],
      "T_F_"<>ToString[i]->{LinearLayer[c*2^(i-1)],ReplicateLayer[Floor[size/2^i],2]},
      "T_B_"<>ToString[i]->{LinearLayer[c*2^(i-1)],ReplicateLayer[Floor[size/2^i],2]},
      "thread_F_"<>ToString[i]->ThreadingLayer[Plus],
      "thread_B_"<>ToString[i]->ThreadingLayer[Plus]},
     {i,depth}],
    "T_F_0"->{LinearLayer[c/2],ReplicateLayer[size,2]},
    "thread_F_0"->ThreadingLayer[Plus],
    "T_B_0"->{LinearLayer[c/2],ReplicateLayer[size,2]},
    "thread_B_0"->ThreadingLayer[Plus],
    "preprocess"->ConvolutionLayer[c/2,{1,1}],
    "postprocess"->ConvolutionLayer[channel,{1,1}],
    "ubase"->res[c*(depth+1)]|>,
  Flatten[{
    (* stem: project input, add time embedding, enter the first encoder stage *)
    NetPort["Input"]->"preprocess"->"thread_F_0"->"contract_1",
    "preprocess"->"thread_B_0",
    NetPort["T"]->"T_F_0"->"thread_F_0",
    NetPort["T"]->"T_B_0"->"thread_B_0",
    Table[{
      (* downward path; the deepest stage routes through the "ubase" bottleneck *)
      NetPort["contract_"<>ToString[i],"Pooling"]->"thread_F_"<>ToString[i]->
        If[i<depth,
          "contract_"<>ToString[i+1],
          "ubase"->"thread_B_"<>ToString[depth]->NetPort["expand_"<>ToString[depth],"Input"]],
      NetPort["T"]->"T_F_"<>ToString[i]->"thread_F_"<>ToString[i],
      NetPort["T"]->"T_B_"<>ToString[i]->"thread_B_"<>ToString[i],
      (* skip connection from encoder to matching decoder stage *)
      NetPort["contract_"<>ToString[i],"Shortcut"]->NetPort["expand_"<>ToString[i],"Shortcut"],
      (* upward path; the topmost stage exits through "postprocess" *)
      NetPort["expand_"<>ToString[i],"Output"]->"thread_B_"<>ToString[i-1]->
        If[i>1,NetPort["expand_"<>ToString[i-1],"Input"],"postprocess"]},
     {i,depth}]}],
  "Input"->Prepend[size,channel],
  "T"->Tc]
Information[UNet,"SummaryGraphic"]
Out[]=
NetGraph
Out[]=
In[]:=
(* PositionEmbedding[t, d]: sinusoidal encoding of time step t as a length-d
   vector of interleaved {Sin, Cos} pairs at geometrically spaced frequencies
   (base 100000); d defaults to the global Tc and must be even. *)
PositionEmbedding[t_, d_:Tc] :=
  Flatten @ Table[
    With[{phase = 1./(100000^(2*i/d))*t},
      {Sin[phase], Cos[phase]}],
    {i, d/2}]
In[]:=
(* Report[net]: training-progress callback.  Runs the full reverse-diffusion
   (ancestral sampling) chain t = T, T-1, ..., 1 for 10 independent samples,
   using `net` as the noise predictor, and lays out every 8th intermediate
   image (plus the final one) in a TableForm grid.
   Fix vs. original: DDPM sampling (Ho et al. 2020, Algorithm 2) adds fresh
   noise z ~ N(0, I) only while t > 1 and sets z = 0 at the final step; the
   original added noise at t = 1 as well, leaving the returned samples noisy.
   NOTE(review): the step variance at1/at*b[t] differs from the paper's
   sigma_t^2 = b[t] and from the posterior variance — confirm intended. *)
Report[net_] := Block[{result},
  result = Table[
    Block[{initial, final},
      (* start the chain from pure Gaussian noise of image shape *)
      initial = RandomReal[NormalDistribution[0, 1], Prepend[size, channel]];
      final = Table[
        Block[{at = a[t], at1 = a[t - 1]},
          Set[initial,
            initial/Sqrt[1 - b[t]] -
              b[t]/Sqrt[(1 - at)]/Sqrt[1 - b[t]]*
                net[<|"Input" -> initial, "T" -> PositionEmbedding[t]|>] +
              If[t > 1,
                RandomReal[NormalDistribution[0, Sqrt[at1/at*b[t]]], Dimensions[initial]],
                0]];
          initial],
        {t, T, 1, -1}];
      (* keep every 8th frame plus the final denoised image *)
      Map[NetDecoder["Image"], Append[final[[1 ;; -1 ;; 8]], final[[-1]]]]],
    {k, 10}];
  TableForm[result, TableSpacing -> {.5, .5}]];
In[]:=
(* Train the U-Net as a DDPM noise predictor.  The generator function builds
   each mini-batch on the fly: sample a random step t per example, corrupt a
   random image to that step with GXT, and ask the net to recover the noise
   (MSE loss).  Report[] renders sampling progress every 10 batches. *)
trained=NetTrain[UNet,
  {Function[
     Block[{t,X0,noise,Xt,tt},
       (* one uniform step in 1..T per batch element; RandomInteger replaces
          the original Table[RandomSample[Range[T]][[1]],batch], which
          shuffled a T-element list just to take its first entry *)
       t=RandomInteger[{1,T},batch];
       (* random clean images, encoded to channel x size arrays *)
       X0=NetEncoder[{"Image",size,"Grayscale"}][RandomSample[data,batch]];
       (* target: the exact standard-normal noise injected into each image *)
       noise=RandomReal[NormalDistribution[0,1],Join[{batch,channel},size]];
       Xt=Table[GXT[X0[[i]],t[[i]],noise[[i]]],{i,batch}];
       tt=Map[PositionEmbedding,t];
       <|"Input"->Xt,"Output"->noise,"T"->tt|>]],
   "RoundLength"->Length[data]},
  LossFunction->MeanSquaredLossLayer[],
  BatchSize->batch,
  MaxTrainingRounds->10000,
  TargetDevice->"GPU",
  Method->{"ADAM","LearningRate"->10^-5},
  TrainingProgressReporting->{{Function@Report[#Net],"Interval"->Quantity[10,"Batches"]},"Panel"}]
Out[]=
NetGraph
In[]:=
(* post[final]: decode a sampling trajectory to images — every 8th frame of
   `final` plus the last frame. *)
post[final_] :=
  With[{frames = Append[final[[1 ;; -1 ;; 8]], Last[final]]},
    NetDecoder["Image"] /@ frames]