Safetensors

In[]:=
(* Fetch the deployed StableDiffusionSynthesize resource function from the Wolfram Cloud;
   it accepts replacement "TextEncoder"/"Unet"/"Decoder" nets as options (used at the end of this notebook). *)
StableDiffusionSynthesize=ResourceFunction["https://www.wolframcloud.com/obj/nikm/DeployedResources/Function/StableDiffusionSynthesize/"];
In[]:=
(* Load the CLIP ViT-L/14 text-domain feature extractor from the Wolfram Neural Net Repository. *)
textEncoder=NetModel[{"CLIP Multi-domain Feature Extractor","InputDomain"->"Text","Architecture"->"ViT-L/14"}];
In[]:=
(* Load the Stable Diffusion V1 UNet (the repository's default architecture for this model). *)
unet=NetModel["Stable Diffusion V1"];
In[]:=
(* Load the Stable Diffusion V1 VAE decoder architecture. *)
decoder=NetModel[{"Stable Diffusion V1","Architecture"->"Decoder"}];
(* NOTE(review): sdOptions is not defined anywhere in this chunk, and this line
   immediately overwrites the three NetModel results loaded above — presumably it
   belongs to a context where sdOptions supplies paths to previously exported
   .wlnet files (cf. the Export/Import cells below); confirm before running the
   notebook top to bottom. *)
{textEncoder,unet,decoder}=Import/@Lookup[sdOptions,{"TextEncoder","Unet","Decoder"}];
Mappings between WL net array keys and safetensors tensor names (verbose rule lists; see the comments on each):
In[]:=
(* Map each array key of the WL CLIP text encoder to its safetensors key under the
   "cond_stage_model.transformer.text_model." prefix. Rules convert WL naming
   (1-based layer indices, "Scaling"/"Biases"/"Weights") to the checkpoint naming
   (0-based indices, "weight"/"bias"). The fused "input_project" array maps to a
   LIST of three keys (separate q/k/v projections in the checkpoint) — list values
   are handled specially by the dimension check and by ArrayReshape downstream.
   Keys with no matching rule map to Missing[] and are dropped later by DeleteMissing. *)
textEncoderMappings=AssociationThread[Keys@Information[textEncoder,"Arrays"]->Map[If[MissingQ[#],#,If[MatchQ[#,{___List}],StringRiffle[Join[{"cond_stage_model","transformer","text_model"},#],"."]&/@#,StringRiffle[Join[{"cond_stage_model","transformer","text_model"},#],"."]]]&]@Replace[Keys@Information[textEncoder,"Arrays"],{​​{"input_embeddings","token_embeddings","Weights"}:>{"embeddings","token_embedding","weight"},​​{"input_embeddings","positional_embedding","Weights"}:>{"embeddings","position_embedding","weight"},​​{"post_normalize",p_}:>{"final_layer_norm",Replace[p,{"Scaling"->"weight","Biases"->"bias"}]},​​{"transformer",id_,norm_?(StringStartsQ["norm"]),p_}:>{"encoder","layers",ToString[id-1],"layer_"<>norm,Replace[p,{"Scaling"->"weight","Biases"->"bias"}]},​​{"transformer",id_,"mlp",linear_?(StringStartsQ["linear"]),"Net",p_}:>{"encoder","layers",ToString[id-1],"mlp",StringReplace[linear,"linear"->"fc"],Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​{"transformer",id_,"self-attention","input_project","Net",p_}:>Table[{"encoder","layers",ToString[id-1],"self_attn",i<>"_proj",Replace[p,{"Weights"->"weight","Biases"->"bias"}]},{i,{"q","k","v"}}],​​{"transformer",id_,"self-attention","output_project","Net",p_}:>{"encoder","layers",ToString[id-1],"self_attn","out_proj",Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​_->Missing[]​​},{1}]​​];
In[]:=
In[]:=
(* Map each array key of the WL VAE decoder to its safetensors key under the
   "first_stage_model." prefix. WL block names ("mid_block", "block4"... with
   1-based sub-block indices) are converted to the checkpoint's "decoder.mid.*"
   and "decoder.up.N.*" naming (note the 4-N index reversal for up-blocks), and
   WL parameter names ("Scaling"/"Biases"/"Weights") become "weight"/"bias".
   Unmatched keys map to Missing[] and are dropped later by DeleteMissing. *)
decoderMappings=AssociationThread[Keys@Information[decoder,"Arrays"]->Map[If[MissingQ[#],#,If[MatchQ[#,{___List}],StringRiffle[Join[{"first_stage_model"},#],"."]&/@#,StringRiffle[Join[{"first_stage_model"},#],"."]]]&]@Replace[Keys@Information[decoder,"Arrays"],{​​{"post_quant_conv",p_}:>{"post_quant_conv",Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​{conv:"conv_in"|"conv_out",p_}:>{"decoder",conv,Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​{"conv_norm_out",4,p_}:>{"decoder","norm_out",Replace[p,{"Scaling"->"weight","Biases"->"bias"}]},​​​​{"mid_block",subBlock:1|3,norm:"norm1"|"norm2",4,p_}:>{"decoder","mid",Replace[subBlock,{1->"block_1",3->"block_2"}],norm,Replace[p,{"Scaling"->"weight","Biases"->"bias"}]},​​​​{"mid_block",subBlock:1|3,conv:"conv1"|"conv2",p_}:>{"decoder","mid",Replace[subBlock,{1->"block_1",3->"block_2"}],conv,Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​{"mid_block",2,"toSequence",1,4,p_}:>{"decoder","mid","attn_1","norm",Replace[p,{"Scaling"->"weight","Biases"->"bias"}]},​​{"mid_block",2,"attention",layer_,"Net",p_}:>{"decoder","mid","attn_1",Replace[layer,{"key"->"k","value"->"v","query"->"q","output"->"proj_out"}],Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​​​{block_,subBlock_,conv:"conv1"|"conv2",p_}:>{"decoder","up",4-Interpreter["Integer"][StringTake[block,-1]],"block",subBlock-1,conv,Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​​​{block_,subBlock_,norm:"norm1"|"norm2",4,p_}:>{"decoder","up",4-Interpreter["Integer"][StringTake[block,-1]],"block",subBlock-1,norm,Replace[p,{"Scaling"->"weight","Biases"->"bias"}]},​​​​{block_,1,"conv_shortcut",p_}:>{"decoder","up",4-Interpreter["Integer"][StringTake[block,-1]],"block.0","nin_shortcut",Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​{block_,4,"upsample",p_}:>{"decoder","up",4-Interpreter["Integer"][StringTake[block,-1]],"upsample","conv",Replace[p,{"Weights"->"weight","Biases"->"bias"}]},​​​​_->Missing[]​​},{1}]​​];
(*One-time conversion of the .safetensors checkpoint to HDF5 (line breaks restored; they were lost in export):
tensors=ExternalEvaluate["Python","
from safetensors import safe_open
import h5py
with safe_open('/Users/swish/Downloads/aZovyaRPGArtistTools_v3.safetensors', framework='pt') as read:
    with h5py.File('/Users/swish/Downloads/aZovyaRPGArtistTools_v3.h5', 'w') as write:
        for key in read.keys():
            write.create_dataset(key, data = read.get_tensor(key))
"]*)
In[]:=
(* Read every dataset of an HDF5 file into an association keyed by dataset name
   with the leading "/" removed, so keys match the safetensors tensor names. *)
loadTensors[path_] := Module[{names = Import[path, "Datasets"]},
  AssociationThread[
    StringDrop[names, 1],            (* strip the leading "/" from each HDF5 path *)
    Import[path, {"Datasets", names}]
  ]
]
tensors = loadTensors["/mnt/efs/aZovyaRPGArtistTools_v3.h5"];
Dimension test: verify that each mapped checkpoint tensor has the same total element count as the target net array (mismatches are listed below).
In[]:=
(* Cache, for each net, the dimensions of every learned array (keyed by array path);
   used by the element-count sanity checks below. *)
{unetDimensions, textEncoderDimensions, decoderDimensions} =
  Map[Dimensions, Information[#, "Arrays"]] & /@ {unet, textEncoder, decoder};
In[]:=
(* Sanity check: for every mapping, compare the total element count of the source
   checkpoint tensor(s) with that of the target net array, keeping only mismatches.
   The text-encoder check sums the source counts when a mapping value is a list of
   keys (the fused q/k/v case). NOTE(review): unetMappings is referenced here but
   is not defined in this chunk — presumably built in a cell not shown. The two
   mismatches in the output below ("time_proj" and "embed") appear to be arrays
   with no counterpart in the checkpoint — confirm they are fixed/non-learned. *)
Select[Association@KeyValueMap[#1->Through[{SameQ,List}[Times@@Dimensions@tensors[#2],Times@@decoderDimensions[#1]]]&,decoderMappings],MatchQ@{False,_}]​​Select[Association@KeyValueMap[#1->Through[{SameQ,List}[Times@@Dimensions@tensors[#2],Times@@unetDimensions[#1]]]&,unetMappings],MatchQ@{False,_}]​​Select[Association@KeyValueMap[#1->Through[{SameQ,List}[If[ListQ[#2],Total[Times@@@Dimensions/@tensors/@#2],Times@@Dimensions@tensors[#2]],Times@@textEncoderDimensions[#1]]]&,textEncoderMappings],MatchQ@{False,_}]​​
Out[]=

Out[]=
{time_proj,expOut,Array}{False,{2,160}}
Out[]=
{embed,Weights}{False,{2,589824}}
In[]:=
(* Build NetReplacePart rules: for every non-Missing mapping, fetch the checkpoint
   tensor(s) and reshape to the dimensions the target net array expects. When a
   mapping value is a list of keys (the fused q/k/v case in textEncoderMappings),
   Lookup returns several tensors and ArrayReshape flattens them together into the
   single WL array. *)
makeMappingTensors[net_, mappings_, tensors_] := KeyValueMap[
  Function[{arrayKey, tensorKey},
    arrayKey -> ArrayReshape[
      Lookup[tensors, tensorKey],
      Dimensions[NetExtract[net, arrayKey]]
    ]
  ],
  DeleteMissing[mappings]
]
In[]:=
(* Splice the checkpoint weights into each pretrained net via NetReplacePart. *)
{modifiedTextEncoder, modifiedUnet, modifiedDecoder} = MapThread[
  NetReplacePart[#1, makeMappingTensors[#1, #2, tensors]] &,
  {
    {textEncoder, unet, decoder},
    {textEncoderMappings, unetMappings, decoderMappings}
  }
];
In[]:=
(* Persist the re-weighted nets so later sessions can reload them directly
   instead of redoing the safetensors mapping. Each Export returns the path. *)
Export["/mnt/efs/aZovyaRPGArtistTools_v3/TextEncoder.wlnet", modifiedTextEncoder]
Export["/mnt/efs/aZovyaRPGArtistTools_v3/Unet.wlnet", modifiedUnet]
Export["/mnt/efs/aZovyaRPGArtistTools_v3/Decoder.wlnet", modifiedDecoder]
Out[]=
/mnt/efs/aZovyaRPGArtistTools_v3/TextEncoder.wlnet
Out[]=
/mnt/efs/aZovyaRPGArtistTools_v3/Unet.wlnet
Out[]=
/mnt/efs/aZovyaRPGArtistTools_v3/Decoder.wlnet
In[]:=
(* Reload the previously exported re-weighted nets from disk. *)
modifiedTextEncoder = Import["/mnt/efs/aZovyaRPGArtistTools_v3/TextEncoder.wlnet"];
modifiedUnet = Import["/mnt/efs/aZovyaRPGArtistTools_v3/Unet.wlnet"];
modifiedDecoder = Import["/mnt/efs/aZovyaRPGArtistTools_v3/Decoder.wlnet"];
In[]:=
(* End-to-end smoke test: run the deployed pipeline on GPU with the checkpoint-swapped nets. *)
StableDiffusionSynthesize["A cat in a party hat","TextEncoder"->modifiedTextEncoder,"Unet"->modifiedUnet,"Decoder"->modifiedDecoder,TargetDevice->"GPU"]
Out[]=

LoRA