def __init__(self):
    """Create the VGG encoder/decoder pairs ``e1``/``d1`` .. ``e4``/``d4``.

    For each integer 1 through 4, a ``VGGEncoder`` and a ``VGGDecoder`` are
    constructed with that integer and stored on the instance under the
    attribute names ``e<n>`` and ``d<n>``.
    """
    super(PhotoWCT, self).__init__()
    # Encoder is assigned before decoder for every n, so the attributes are
    # created in the order e1, d1, e2, d2, e3, d3, e4, d4.
    for n in (1, 2, 3, 4):
        setattr(self, 'e%d' % n, VGGEncoder(n))
        setattr(self, 'd%d' % n, VGGDecoder(n))
def __init__(self, args):
    """Build the four encoder/decoder pairs selected by ``args.mode``.

    Args:
        args: Object exposing a ``mode`` attribute; it is kept on
            ``self.args``. Recognised modes:

            * ``""`` or ``"original"`` -- ``VGGEncoder`` / ``VGGDecoder``
              constructed with 1..4.
            * any value containing ``"16x"`` -- ``SmallEncoder_16x_plus`` /
              ``SmallDecoder_16x`` constructed with 1..4 and a hard-coded
              ``.pth`` path. If the value also contains ``"JointED"`` the
              JointED set of paths is used instead.

    Raises:
        SystemExit: With status 1, after printing ``wrong mode``, when
            ``args.mode`` matches none of the modes above.
    """
    super(PhotoWCT, self).__init__()
    self.args = args
    mode = self.args.mode

    if mode == "" or mode == "original":
        # Original model.
        self.e1 = VGGEncoder(1)
        self.d1 = VGGDecoder(1)
        self.e2 = VGGEncoder(2)
        self.d2 = VGGDecoder(2)
        self.e3 = VGGEncoder(3)
        self.d3 = VGGDecoder(3)
        self.e4 = VGGEncoder(4)
        self.d4 = VGGDecoder(4)
    elif "16x" in mode:
        # NOTE(review): the paths below are relative, so they resolve
        # against the current working directory -- confirm that callers
        # always run from the expected directory.
        if "JointED" not in mode:
            # 16x model trained for pwct.
            e1 = '../KD/Experiments/Small16xEncoder_pwct/e1/weights/12-20181020-1610_1SE_E25S0-2.pth'
            e2 = '../KD/Experiments/Small16xEncoder_pwct/e2/weights/12-20181020-1602_2SE_E25S0-2.pth'
            e3 = '../KD/Experiments/Small16xEncoder_pwct/e3/weights/12-20181019-0420_3SE_E25S0-2.pth'
            e4 = '../KD/Experiments/Small16xEncoder_pwct/e4/weights/12-20181019-0349_4SE_E25S0-2.pth'
            d1 = '../KD/Experiments/Small16xDecoder_pwct/e1/weights/12-20181021-0913_1SD_E25S0-3.pth'
            d2 = '../KD/Experiments/Small16xDecoder_pwct/e2/weights/12-20181021-1418_2SD_E25S0-3.pth'
            d3 = '../KD/Experiments/Small16xDecoder_pwct/e3/weights/12-20181020-1638_3SD_E25S0-3.pth'
            d4 = '../KD/Experiments/Small16xDecoder_pwct/e4/weights/12-20181020-1637_4SD_E25S0-3.pth'
        else:
            # 16x model trained for pwct, JointED. Encoder and decoder
            # files sit in the same e<n>_JointED directory here.
            e1 = '../KD/Experiments/Small16xEncoder_pwct/e1_JointED/weights/12-20181026-0259_1SED_E25S0-2.pth'
            d1 = '../KD/Experiments/Small16xEncoder_pwct/e1_JointED/weights/12-20181026-0259_1SED_E25S0-3.pth'
            e2 = '../KD/Experiments/Small16xEncoder_pwct/e2_JointED/weights/12-20181026-0256_2SED_E25S0-2.pth'
            d2 = '../KD/Experiments/Small16xEncoder_pwct/e2_JointED/weights/12-20181026-0256_2SED_E25S0-3.pth'
            e3 = '../KD/Experiments/Small16xEncoder_pwct/e3_JointED/weights/12-20181026-0255_3SED_E25S0-2.pth'
            d3 = '../KD/Experiments/Small16xEncoder_pwct/e3_JointED/weights/12-20181026-0255_3SED_E25S0-3.pth'
            e4 = '../KD/Experiments/Small16xEncoder_pwct/e4_JointED/weights/12-20181026-0255_4SED_E25S0-2.pth'
            d4 = '../KD/Experiments/Small16xEncoder_pwct/e4_JointED/weights/12-20181026-0255_4SED_E25S0-3.pth'

        self.e1 = SmallEncoder_16x_plus(1, e1)
        self.d1 = SmallDecoder_16x(1, d1)
        self.e2 = SmallEncoder_16x_plus(2, e2)
        self.d2 = SmallDecoder_16x(2, d2)
        self.e3 = SmallEncoder_16x_plus(3, e3)
        self.d3 = SmallDecoder_16x(3, d3)
        self.e4 = SmallEncoder_16x_plus(4, e4)
        self.d4 = SmallDecoder_16x(4, d4)
    else:
        print("wrong mode")
        # Raise SystemExit directly instead of calling exit(): that helper
        # is injected by the ``site`` module for interactive use and is not
        # defined when the interpreter runs without it (e.g. ``python -S``).
        raise SystemExit(1)
import torch
import torch.nn as nn
from torch.utils.serialization import load_lua

from models import VGGEncoder, VGGDecoder


def weight_assign(lua, pth, maps):
    """Copy weights and biases from ``lua`` layers onto attributes of ``pth``.

    Args:
        lua: Source object; ``lua.get(index)`` must return something with
            ``weight`` and ``bias`` attributes that support ``.float()``.
        pth: Destination object; ``getattr(pth, name)`` receives new
            ``weight`` and ``bias`` attributes.
        maps: Mapping of destination attribute name to the index passed to
            ``lua.get``.
    """
    for attr_name, lua_index in maps.items():
        target_layer = getattr(pth, attr_name)
        source_layer = lua.get(lua_index)
        # Values are converted with .float() and wrapped in nn.Parameter
        # before being attached to the destination layer.
        target_layer.weight = nn.Parameter(source_layer.weight.float())
        target_layer.bias = nn.Parameter(source_layer.bias.float())


if __name__ == '__main__':
    ## VGGEncoder4
    # Destination attribute name on the encoder -> index handed to
    # ``lua.get`` on the loaded .t7 object.
    encoder_layer_map = {
        'conv0': 0,
        'conv1_1': 2,
        'conv1_2': 5,
        'conv2_1': 9,
        'conv2_2': 12,
        'conv3_1': 16,
        'conv3_2': 19,
        'conv3_3': 22,
        'conv3_4': 25,
        'conv4_1': 29,
    }
    vgg4 = load_lua('pretrained/encoder.t7', long_size=8)
    e4 = VGGEncoder()
    weight_assign(vgg4, e4, encoder_layer_map)
    torch.save(e4.state_dict(), 'pretrained/encoder_pretrained.pth')

    ## VGGDecoder4
torch.load('pth_models/feature_invertor_conv2.pth')) p_wct.e3.load_state_dict(torch.load('pth_models/vgg_normalised_conv3.pth')) p_wct.d3.load_state_dict( torch.load('pth_models/feature_invertor_conv3.pth')) p_wct.e4.load_state_dict(torch.load('pth_models/vgg_normalised_conv4.pth')) p_wct.d4.load_state_dict( torch.load('pth_models/feature_invertor_conv4.pth')) if __name__ == '__main__': if not os.path.exists('pth_models'): os.mkdir('pth_models') ## VGGEncoder1 vgg1 = load_lua('models/vgg_normalised_conv1_1_mask.t7') e1 = VGGEncoder(1) weight_assign(vgg1, e1, { 'conv0': 0, 'conv1_1': 2, }) torch.save(e1.state_dict(), 'pth_models/vgg_normalised_conv1.pth') ## VGGDecoder1 inv1 = load_lua('models/feature_invertor_conv1_1_mask.t7') d1 = VGGDecoder(1) weight_assign(inv1, d1, { 'conv1_1': 1, }) torch.save(d1.state_dict(), 'pth_models/feature_invertor_conv1.pth') ## VGGEncoder2