Example #1
0
    def __init__(self, opt, requires_grad=False):
        """Build a VGGFace feature extractor split into six sequential slices.

        Args:
            opt: options object; ``opt.VGGFace_pretrain_path`` names the
                checkpoint copied into a ``VGGEncoder(opt)`` via
                ``util.copy_state_dict``.
            requires_grad: when False (the default) every parameter of this
                module is frozen (``requires_grad = False``).

        The slices share layers with ``self.model.model.features``:
        slice1 = [0, 2), slice2 = [2, 7), slice3 = [7, 12), slice4 = [12, 21),
        slice5 = [21, 30), slice6 = [30, end).
        """
        super(VGGFace19, self).__init__()
        self.model = VGGEncoder(opt)
        self.opt = opt
        # Announce before loading so a failing load still reports its path.
        print("=> loading checkpoint '{}'".format(opt.VGGFace_pretrain_path))
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # assumes util.copy_state_dict copies values into the model's own
        # tensors, so the checkpoint's device does not matter — confirm.
        ckpt = torch.load(opt.VGGFace_pretrain_path, map_location='cpu')
        util.copy_state_dict(ckpt, self.model)

        vgg_pretrained_features = self.model.model.features
        # Layer-index boundaries of slice1..slice6; the last slice runs to the
        # end of the feature stack.
        bounds = (0, 2, 7, 12, 21, 30, len(vgg_pretrained_features))
        for i, (start, stop) in enumerate(zip(bounds, bounds[1:]), start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                block.add_module(str(x), vgg_pretrained_features[x])
            # nn.Module.__setattr__ registers the Sequential as 'slice<i>'.
            setattr(self, 'slice{}'.format(i), block)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
Example #2
0
    def load_separately(self, network, network_label, opt):
        """Best-effort load of pretrained weights for one sub-network.

        Args:
            network: module that receives the checkpoint via
                ``util.copy_state_dict`` (parameter names have 'MobileNet'
                replaced by 'model').
            network_label: one of 'G', 'D', 'D_rotate', 'E', 'A', 'A_sync',
                'V'; selects ``opt.<label>_pretrain_path``. Any other label
                means there is nothing to load.
            opt: options object carrying the ``*_pretrain_path`` attributes.

        Returns:
            ``network`` itself, whether or not weights were loaded. A missing
            label, an empty/None path, or a path that is not an existing file
            only prints a message; nothing is raised.
        """
        path_options = {
            'G': 'G_pretrain_path',
            'D': 'D_pretrain_path',
            'D_rotate': 'D_rotate_pretrain_path',
            'E': 'E_pretrain_path',
            'A': 'A_pretrain_path',
            'A_sync': 'A_sync_pretrain_path',
            'V': 'V_pretrain_path',
        }
        option_name = path_options.get(network_label)
        load_path = getattr(opt, option_name) if option_name is not None else None

        if not load_path:
            print("no load_path")
        elif not os.path.isfile(load_path):
            # Report a configured-but-missing file instead of skipping it
            # silently, so a mistyped path is visible in the log.
            print("=> no checkpoint found at '{}'".format(load_path))
        else:
            print("=> loading checkpoint '{}'".format(load_path))
            # CPU mapping lets GPU-saved checkpoints load on CPU-only hosts.
            checkpoint = torch.load(load_path, map_location='cpu')
            util.copy_state_dict(checkpoint,
                                 network,
                                 strip='MobileNet',
                                 replace='model')
        return network
Example #3
0
def load_network(cfg):
    """Build a RainNet generator and fill it from ``net_G_last.pth``.

    Args:
        cfg: options object providing ``input_nc``, ``output_nc``, ``ngf``,
            ``no_dropout``, ``checkpoints_dir`` and ``name``.

    Returns:
        The RainNet instance after ``util.copy_state_dict`` has copied the
        checkpoint into its state dict.

    Raises:
        FileNotFoundError: if ``<checkpoints_dir>/<name>/net_G_last.pth``
            does not exist.
    """
    load_path = os.path.join(cfg.checkpoints_dir, cfg.name, 'net_G_last.pth')
    # Validate before building the network so a bad path fails fast. An
    # explicit raise (unlike ``assert``) is not stripped under ``python -O``.
    if not os.path.exists(load_path):
        raise FileNotFoundError(
            '%s not exists. Please check the file' % (load_path))

    net = RainNet(input_nc=cfg.input_nc,
                  output_nc=cfg.output_nc,
                  ngf=cfg.ngf,
                  norm_layer=RAIN,
                  use_dropout=not cfg.no_dropout)

    print(f'loading the model from {load_path}')
    state_dict = torch.load(load_path, map_location='cpu')
    # Argument order is (destination, source): the loaded values go into the
    # tensors returned by net.state_dict() — presumably copied in place; confirm
    # against util.copy_state_dict.
    util.copy_state_dict(net.state_dict(), state_dict)
    return net
Example #4
0
def load_network(cfg):
    """Build a RainNet generator and fill it from ``net_G_last.pth``.

    Args:
        cfg: options object providing ``input_nc``, ``output_nc``, ``ngf``,
            ``no_dropout``, ``checkpoints_dir`` and ``name``.

    Returns:
        The RainNet instance after ``util.copy_state_dict`` has copied the
        checkpoint into its state dict.

    Raises:
        FileNotFoundError: if ``<checkpoints_dir>/<name>/net_G_last.pth``
            does not exist.
    """
    # The author's note: the final checkpoint (net_G_last.pth) is the better one.
    load_path = os.path.join(cfg.checkpoints_dir, cfg.name, 'net_G_last.pth')
    # Validate before building the network so a bad path fails fast.
    if not os.path.exists(load_path):
        raise FileNotFoundError(
            '%s not exists. Please check the file' % (load_path))

    net = RainNet(input_nc=cfg.input_nc,
                  output_nc=cfg.output_nc,
                  ngf=cfg.ngf,
                  norm_layer=RAIN,
                  use_dropout=not cfg.no_dropout)

    print(f'loading the model from {load_path}')
    # CPU mapping lets GPU-saved checkpoints load on CPU-only hosts.
    state_dict = torch.load(load_path, map_location='cpu')
    # Argument order is (destination, source): the loaded values go into the
    # tensors returned by net.state_dict() — presumably copied in place; confirm
    # against util.copy_state_dict.
    util.copy_state_dict(net.state_dict(), state_dict)
    return net
Example #5
0
    def load_separately(self, network, network_label, opt):
        """Best-effort load of pretrained weights for one sub-network.

        Args:
            network: module that receives the checkpoint via
                ``util.copy_state_dict``.
            network_label: one of 'G', 'D', 'D_rotate', 'E'; selects
                ``opt.<label>_pretrain_path``. Any other label means there is
                nothing to load.
            opt: options object carrying the ``*_pretrain_path`` attributes.

        Returns:
            ``network`` itself, whether or not weights were loaded. A missing
            label, an empty/None path, or a path that is not an existing file
            only prints a message; nothing is raised.
        """
        path_options = {
            'G': 'G_pretrain_path',
            'D': 'D_pretrain_path',
            'D_rotate': 'D_rotate_pretrain_path',
            'E': 'E_pretrain_path',
        }
        option_name = path_options.get(network_label)
        load_path = getattr(opt, option_name) if option_name is not None else None

        if not load_path:
            print("no load_path")
        elif not os.path.isfile(load_path):
            # Report a configured-but-missing file instead of skipping it
            # silently, so a mistyped path is visible in the log.
            print("=> no checkpoint found at '{}'".format(load_path))
        else:
            print("=> loading checkpoint '{}'".format(load_path))
            # CPU mapping lets GPU-saved checkpoints load on CPU-only hosts.
            checkpoint = torch.load(load_path, map_location='cpu')
            util.copy_state_dict(checkpoint, network)
        return network
Example #6
0
 def __init__(self, opt, requires_grad=False):
     """Build a VGG19-BN feature extractor split into five sequential slices.

     A fresh ``torchvision.models.vgg19_bn`` receives the ``'state_dict'``
     entry of ``opt.vggface_checkpoint`` through ``util.copy_state_dict``,
     called with ``'module.base.'`` (presumably a key prefix to strip —
     confirm against util). Only layers [0, 30) of ``features`` are kept:
     slice1 = [0, 2), slice2 = [2, 7), slice3 = [7, 12), slice4 = [12, 21),
     slice5 = [21, 30).

     Args:
         opt: options object providing ``vggface_checkpoint``.
         requires_grad: when False (the default) every parameter of this
             module is frozen (``requires_grad = False``).
     """
     super(VGGFace19, self).__init__()
     model = torchvision.models.vgg19_bn(pretrained=False)
     # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
     ckpt = torch.load(opt.vggface_checkpoint, map_location='cpu')['state_dict']
     util.copy_state_dict(ckpt, model, 'module.base.')
     vgg_pretrained_features = model.features

     # Layer-index boundaries of slice1..slice5.
     bounds = (0, 2, 7, 12, 21, 30)
     for i, (start, stop) in enumerate(zip(bounds, bounds[1:]), start=1):
         block = torch.nn.Sequential()
         for x in range(start, stop):
             block.add_module(str(x), vgg_pretrained_features[x])
         # nn.Module.__setattr__ registers the Sequential as 'slice<i>'.
         setattr(self, 'slice{}'.format(i), block)

     if not requires_grad:
         for param in self.parameters():
             param.requires_grad = False
Example #7
0
 def load_pretrain(self):
     """Copy the 'resnext50_32x4d' checkpoint into ``self.model``.

     The source comes from the ``model_urls`` mapping defined outside this
     view. An http(s) entry is fetched with
     ``torch.hub.load_state_dict_from_url``, which downloads and caches it,
     because ``torch.load`` only accepts local paths or file-like objects.
     Any other entry is loaded as a local path with ``torch.load``.
     """
     source = model_urls['resnext50_32x4d']
     if str(source).startswith(('http://', 'https://')):
         check_point = torch.hub.load_state_dict_from_url(source, map_location='cpu')
     else:
         check_point = torch.load(source, map_location='cpu')
     util.copy_state_dict(check_point, self.model)