Example #1
0
 def __init__(self):
     """Instantiate matched VGG encoder/decoder pairs for pyramid levels 1-4.

     Creates attributes ``e1``..``e4`` (encoders) and ``d1``..``d4``
     (decoders), one pair per level, in the same order as explicit
     per-attribute assignments would.
     """
     super(PhotoWCT, self).__init__()
     for level in range(1, 5):
         setattr(self, 'e%d' % level, VGGEncoder(level))
         setattr(self, 'd%d' % level, VGGDecoder(level))
Example #2
0
    def __init__(self, args):
        """Build the 4-level encoder/decoder pyramid used by PhotoWCT.

        Depending on ``args.mode`` the model is assembled either from the
        original VGG encoders/decoders or from 16x-compressed "small"
        student models restored from hard-coded checkpoint paths.

        Args:
            args: parsed options object; only ``args.mode`` is read here.
                "" or "original"  -> full VGGEncoder/VGGDecoder pairs.
                contains "16x"    -> SmallEncoder_16x_plus / SmallDecoder_16x
                                     initialized from the paths chosen below
                                     (JointED variants when the mode also
                                     contains "JointED").
                anything else     -> prints "wrong mode" and exits(1).
        """
        super(PhotoWCT, self).__init__()
        self.args = args
        if "16x" in self.args.mode:
            if "JointED" not in self.args.mode:
                ### 16x model trained for pwct
                # Encoders (file tag *SE*) and decoders (*SD*) were
                # distilled separately, hence different checkpoint dirs.
                e1 = '../KD/Experiments/Small16xEncoder_pwct/e1/weights/12-20181020-1610_1SE_E25S0-2.pth'
                e2 = '../KD/Experiments/Small16xEncoder_pwct/e2/weights/12-20181020-1602_2SE_E25S0-2.pth'
                e3 = '../KD/Experiments/Small16xEncoder_pwct/e3/weights/12-20181019-0420_3SE_E25S0-2.pth'
                e4 = '../KD/Experiments/Small16xEncoder_pwct/e4/weights/12-20181019-0349_4SE_E25S0-2.pth'
                d1 = '../KD/Experiments/Small16xDecoder_pwct/e1/weights/12-20181021-0913_1SD_E25S0-3.pth'
                d2 = '../KD/Experiments/Small16xDecoder_pwct/e2/weights/12-20181021-1418_2SD_E25S0-3.pth'
                d3 = '../KD/Experiments/Small16xDecoder_pwct/e3/weights/12-20181020-1638_3SD_E25S0-3.pth'
                d4 = '../KD/Experiments/Small16xDecoder_pwct/e4/weights/12-20181020-1637_4SD_E25S0-3.pth'
            else:
                ### 16x model trained for pwct, JointED
                # Jointly-trained encoder+decoder (*SED*): each e/d pair comes
                # from the same checkpoint dir; presumably the "-2" file holds
                # the encoder weights and "-3" the decoder — TODO confirm.
                e1 = '../KD/Experiments/Small16xEncoder_pwct/e1_JointED/weights/12-20181026-0259_1SED_E25S0-2.pth'
                d1 = '../KD/Experiments/Small16xEncoder_pwct/e1_JointED/weights/12-20181026-0259_1SED_E25S0-3.pth'
                e2 = '../KD/Experiments/Small16xEncoder_pwct/e2_JointED/weights/12-20181026-0256_2SED_E25S0-2.pth'
                d2 = '../KD/Experiments/Small16xEncoder_pwct/e2_JointED/weights/12-20181026-0256_2SED_E25S0-3.pth'
                e3 = '../KD/Experiments/Small16xEncoder_pwct/e3_JointED/weights/12-20181026-0255_3SED_E25S0-2.pth'
                d3 = '../KD/Experiments/Small16xEncoder_pwct/e3_JointED/weights/12-20181026-0255_3SED_E25S0-3.pth'
                e4 = '../KD/Experiments/Small16xEncoder_pwct/e4_JointED/weights/12-20181026-0255_4SED_E25S0-2.pth'
                d4 = '../KD/Experiments/Small16xEncoder_pwct/e4_JointED/weights/12-20181026-0255_4SED_E25S0-3.pth'

        if self.args.mode == "" or self.args.mode == "original":
            #### original model
            self.e1 = VGGEncoder(1)
            self.d1 = VGGDecoder(1)
            self.e2 = VGGEncoder(2)
            self.d2 = VGGDecoder(2)
            self.e3 = VGGEncoder(3)
            self.d3 = VGGDecoder(3)
            self.e4 = VGGEncoder(4)
            self.d4 = VGGDecoder(4)
        elif "16x" in self.args.mode:
            # Compressed student models, loaded from the checkpoint paths
            # selected in the branch above (guaranteed bound: same
            # '"16x" in mode' condition guards both blocks).
            self.e1 = SmallEncoder_16x_plus(1, e1)
            self.d1 = SmallDecoder_16x(1, d1)
            self.e2 = SmallEncoder_16x_plus(2, e2)
            self.d2 = SmallDecoder_16x(2, d2)
            self.e3 = SmallEncoder_16x_plus(3, e3)
            self.d3 = SmallDecoder_16x(3, d3)
            self.e4 = SmallEncoder_16x_plus(4, e4)
            self.d4 = SmallDecoder_16x(4, d4)
        else:
            print("wrong mode")
            exit(1)
Example #3
0
import torch
import torch.nn as nn
from torch.utils.serialization import load_lua

from models import VGGEncoder, VGGDecoder

def weight_assign(lua, pth, maps):
    """Copy conv weights and biases from a loaded Lua/Torch7 model into
    the matching layers of a PyTorch module.

    Args:
        lua: loaded Torch7 model exposing ``get(index)`` for its layers.
        pth: PyTorch module whose layer attributes receive the weights.
        maps: mapping of ``pth`` attribute name -> layer index in ``lua``.
    """
    for attr_name, lua_index in maps.items():
        target = getattr(pth, attr_name)
        # .float() casts the (typically double) Lua tensors to float32.
        target.weight = nn.Parameter(lua.get(lua_index).weight.float())
        target.bias = nn.Parameter(lua.get(lua_index).bias.float())

if __name__ == '__main__':
    ## VGGEncoder4
    # Convert the Torch7 (.t7) VGG encoder checkpoint into a PyTorch
    # state dict.  NOTE(review): long_size=8 presumably selects 64-bit
    # long deserialization for this checkpoint — confirm against the
    # torch.utils.serialization.load_lua docs.
    vgg4 = load_lua('pretrained/encoder.t7', long_size=8)
    e4 = VGGEncoder()
    # Keys are conv-layer attribute names on the PyTorch model; values are
    # the corresponding layer indices inside the Lua container.
    weight_assign(vgg4, e4, {
        'conv0': 0,
        'conv1_1': 2,
        'conv1_2': 5,
        'conv2_1': 9,
        'conv2_2': 12,
        'conv3_1': 16,
        'conv3_2': 19,
        'conv3_3': 22,
        'conv3_4': 25,
        'conv4_1': 29,
    })
    torch.save(e4.state_dict(), 'pretrained/encoder_pretrained.pth')
    
    ## VGGDecoder4
Example #4
0
        torch.load('pth_models/feature_invertor_conv2.pth'))
    p_wct.e3.load_state_dict(torch.load('pth_models/vgg_normalised_conv3.pth'))
    p_wct.d3.load_state_dict(
        torch.load('pth_models/feature_invertor_conv3.pth'))
    p_wct.e4.load_state_dict(torch.load('pth_models/vgg_normalised_conv4.pth'))
    p_wct.d4.load_state_dict(
        torch.load('pth_models/feature_invertor_conv4.pth'))


if __name__ == '__main__':
    # Output directory for the converted PyTorch checkpoints.
    if not os.path.exists('pth_models'):
        os.mkdir('pth_models')

    ## VGGEncoder1
    # Convert the level-1 Torch7 (.t7) encoder checkpoint to PyTorch.
    vgg1 = load_lua('models/vgg_normalised_conv1_1_mask.t7')
    e1 = VGGEncoder(1)
    # Keys are conv-layer attribute names on the PyTorch model; values are
    # the corresponding layer indices inside the Lua container.
    weight_assign(vgg1, e1, {
        'conv0': 0,
        'conv1_1': 2,
    })
    torch.save(e1.state_dict(), 'pth_models/vgg_normalised_conv1.pth')

    ## VGGDecoder1
    # Same conversion for the matching level-1 decoder ("feature invertor").
    inv1 = load_lua('models/feature_invertor_conv1_1_mask.t7')
    d1 = VGGDecoder(1)
    weight_assign(inv1, d1, {
        'conv1_1': 1,
    })
    torch.save(d1.state_dict(), 'pth_models/feature_invertor_conv1.pth')

    ## VGGEncoder2