Example #1
    def __init__(self, hyperparameters):
        super(Models, self).__init__()
        lr = hyperparameters['lr']
        self.model_name = hyperparameters['models_name']
        # Initialize the networks

        if self.model_name == 'removal':
            self.gen = Gen(hyperparameters['input_dim_a'],
                           hyperparameters['gen'])
            self.dis = MsImageDis(
                hyperparameters['input_dim_a'],
                hyperparameters['dis'])  # discriminator for domain a
        else:
            sys.exit('Unknown model name: %s' % self.model_name)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        self.gen = nn.DataParallel(self.gen)
        self.dis = nn.DataParallel(self.dis)

        # Number of images used for display when sampling (read but not used here)
        display_size = int(hyperparameters['display_size'])

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        gen_params = list(self.gen.parameters())

        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        dis_params = list(self.dis.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Load VGG model if needed and freeze its parameters
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
            self.vgg = nn.DataParallel(self.vgg)
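
For reference, a minimal sketch (not taken from the project) of the hyperparameters dictionary this constructor reads. The key names come from the code above; the values and the nested 'gen'/'dis' contents are illustrative assumptions, and get_scheduler, Gen and MsImageDis may require further keys not shown here.

# Sketch only: key names are those read above, values are assumptions.
hyperparameters = {
    'lr': 1e-4,
    'models_name': 'removal',
    'input_dim_a': 3,
    'gen': {'style_dim': 8},    # Gen likely needs more options than shown
    'dis': {},                  # MsImageDis options go here
    'display_size': 16,
    'beta1': 0.5,
    'beta2': 0.999,
    'weight_decay': 1e-4,
    'init': 'kaiming',
    'vgg_w': 1,                 # > 0 triggers the VGG branch above
    'vgg_model_path': '.',
}
trainer = Models(hyperparameters)   # assumes the nested configs are complete
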
Example #2
    def resume(self, checkpoint_dir, hyperparameters, need_opt=True, path=None):
        # Load the latest generator checkpoint unless an explicit path is given
        if path is None:
            last_model_name = get_model_list(checkpoint_dir, "gen")
        else:
            last_model_name = path
        state_dict = torch.load(last_model_name)
        self.gen.module.load_state_dict(state_dict['gen'])
        if self.dis_scheduler is not None:
            self.dis.module.load_state_dict(state_dict['dis'])
        # Checkpoint names are assumed to end with an 8-digit iteration count
        # and a 3-character extension, e.g. 'gen_00100000.pt'
        iterations = int(last_model_name[-11:-3])
        if need_opt:
            self.gen_opt.load_state_dict(state_dict['gen_opt'])
            self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
            if self.dis_scheduler is not None:
                self.dis_opt.load_state_dict(state_dict['dis_opt'])
                self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations
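
A small usage sketch for resume (not from the original listing); the checkpoint directory is a placeholder, and the files inside it are assumed to follow the 8-digit naming convention noted in the comment above.

# Usage sketch with placeholder paths: resume from the newest 'gen_*'
# checkpoint and continue counting iterations from the returned value.
trainer = Models(hyperparameters)
start_iter = trainer.resume('outputs/removal/checkpoints', hyperparameters,
                            need_opt=True)
print('training resumes at iteration', start_iter)
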
Example #3
    def __init__(self, hyperparameters):
        super(Models, self).__init__()
        lr = hyperparameters['lr']
        self.model_name = hyperparameters['models_name']
        # Initialize the networks

        if self.model_name == 'shadow':
            self.gen = Gen(hyperparameters['gen'])
        else:
            sys.exit('Unknown model name: %s' % self.model_name)

        self.gen = nn.DataParallel(self.gen)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        gen_params = list(self.gen.parameters())

        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        self.dis_scheduler = None

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))

        # Load VGG model if needed and freeze its parameters
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
            self.vgg = nn.DataParallel(self.vgg)
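
For comparison with Example #1, a hedged sketch of the smaller configuration the 'shadow' constructor above reads: there is no discriminator, so no 'dis' or 'input_dim_a' entry is needed. Values are illustrative assumptions, not the authors' settings.

# Sketch only: minimal keys read by the 'shadow' constructor; values are assumptions.
hyperparameters = {
    'lr': 1e-4,
    'models_name': 'shadow',
    'gen': {'style_dim': 8},   # plus any shadow-Gen-specific options
    'beta1': 0.5,
    'beta2': 0.999,
    'weight_decay': 1e-4,
    'init': 'kaiming',
    'vgg_w': 0,                # set > 0 and add 'vgg_model_path' to load VGG
}
trainer = Models(hyperparameters)
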
Example #4
    def __init__(self, hyperparameters, live=False):
        super(Models, self).__init__()
        lr = hyperparameters['lr']
        self.model_name = hyperparameters['models_name']
        # Initialize the networks

        if self.model_name == 'end2end':
            light_config = get_config('configs/light.yaml')
            removal_config = get_config('configs/removal.yaml')
            render_config = get_config('configs/render.yaml')
            shadow_config = get_config('configs/shadow.yaml')

            # Removal network
            self.removal_gen = lib.networks.removal_network.Gen(
                removal_config['input_dim_a'], removal_config['gen'], live=live)

            # Light network
            self.light_gen = lib.networks.light_network.Gen(
                light_config['input_dim_a'], light_config['gen'])

            # Shadow network
            self.shadow_gen = lib.networks.shadow_network.Gen(shadow_config['gen'])

            # Render network and its discriminator
            self.render_gen = lib.networks.render_network.Gen(render_config['gen'])
            self.render_dis = MsImageDis(render_config['dis']['input_dim_dis'],
                                         render_config['dis'])
        else:
            sys.exit('Unknown model name: %s' % self.model_name)

        # Wrap all sub-networks in DataParallel for multi-GPU use
        self.removal_gen = nn.DataParallel(self.removal_gen)
        self.light_gen = nn.DataParallel(self.light_gen)
        self.shadow_gen = nn.DataParallel(self.shadow_gen)
        self.render_gen = nn.DataParallel(self.render_gen)
        self.render_dis = nn.DataParallel(self.render_dis)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        gen_params = list(self.removal_gen.parameters()) + list(self.light_gen.parameters()) + \
                     list(self.shadow_gen.parameters()) + list(self.render_gen.parameters())

        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])

        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        dis_params = list(self.render_dis.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        # The next line overrides the scheduler above, so no learning-rate
        # scheduling is actually applied to the discriminator
        self.dis_scheduler = None

        # Network weight initialization
        # (self.initialization() is an alternative init path, disabled in this example)
        self.apply(weights_init(hyperparameters['init']))
        self.render_dis.apply(weights_init('gaussian'))

        # Load the VGG model and freeze its parameters (the end-to-end model
        # always loads it, unlike the variants above)
        self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.vgg = nn.DataParallel(self.vgg)
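
A hedged construction sketch for the end-to-end model: 'configs/end2end.yaml' is an assumed file name (the four sub-network configs are loaded from the hard-coded paths inside the constructor), and the call to .cuda() assumes a GPU is available for the DataParallel wrappers.

# Sketch only: 'configs/end2end.yaml' is an assumed config file that must at
# least provide 'models_name': 'end2end', 'lr', 'beta1', 'beta2',
# 'weight_decay', 'init' and 'vgg_model_path'.
hyperparameters = get_config('configs/end2end.yaml')
trainer = Models(hyperparameters, live=False)
trainer.cuda()   # move the wrapped sub-networks to GPU before training
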