Example #1
class Trainer(nn.Module):
    def __init__(self, cfg):
        super(Trainer, self).__init__()

        if cfg['mode'] == 'funit':
            self.model = FUNITModel(cfg)
        elif cfg['mode'] == 'g2g':
            self.model = G2GModel(cfg)
        else:
            raise ValueError(
                "Choose from the following two modes: 'funit' or 'g2g'.")

        lr_gen = cfg['lr_gen']
        lr_dis = cfg['lr_dis']

        dis_params = list(self.model.dis.parameters())
        gen_params = list(self.model.gen.parameters())

        self.dis_opt = torch.optim.RMSprop(
            [p for p in dis_params if p.requires_grad],
            lr=lr_dis,
            weight_decay=cfg['weight_decay'])
        self.gen_opt = torch.optim.RMSprop(
            [p for p in gen_params if p.requires_grad],
            lr=lr_gen,
            weight_decay=cfg['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)
        self.apply(weights_init(cfg['init']))
        self.model.gen_test = copy.deepcopy(
            self.model.gen)  # should be performed in load_ckpt()
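
A minimal sketch of the cfg dictionary this constructor expects. Only the keys are taken from the code above ('mode', 'lr_gen', 'lr_dis', 'weight_decay', 'init', plus whatever get_scheduler reads); the values are illustrative assumptions, not the project's defaults:

# Hypothetical values; only the keys come from the constructor above.
cfg = {
    'mode': 'funit',         # selects FUNITModel; 'g2g' selects G2GModel
    'lr_gen': 0.0001,        # generator learning rate
    'lr_dis': 0.0001,        # discriminator learning rate
    'weight_decay': 0.0001,  # RMSprop weight decay
    'init': 'kaiming',       # initialization scheme passed to weights_init
}
trainer = Trainer(cfg)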
Example #2
class Trainer(nn.Module):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.model = FUNITModel(cfg)
        lr_gen = cfg['lr_gen']
        lr_dis = cfg['lr_dis']
        dis_params = list(self.model.dis.parameters())
        gen_params = list(self.model.gen.parameters())
        self.dis_opt = torch.optim.RMSprop(
            [p for p in dis_params if p.requires_grad],
            lr=lr_dis, weight_decay=cfg['weight_decay'])
        self.gen_opt = torch.optim.RMSprop(
            [p for p in gen_params if p.requires_grad],
            lr=lr_gen, weight_decay=cfg['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)
        self.apply(weights_init(cfg['init']))
        self.model.gen_test = copy.deepcopy(self.model.gen)
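
The call self.apply(weights_init(cfg['init'])) relies on weights_init being a factory that returns a per-module initialization function for nn.Module.apply. The project's actual helper is not reproduced on this page; a plausible sketch under that assumption:

# Sketch of a weights_init factory; the real project helper may differ.
import torch.nn as nn

def weights_init(init_type='kaiming'):
    def init_fun(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=2 ** 0.5)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    return init_fun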
Example #3
class Trainer(nn.Module):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.model = FUNITModel(cfg)
        lr_gen = cfg['lr_gen']
        lr_dis = cfg['lr_dis']
        dis_params = list(self.model.dis.parameters())
        gen_params = list(self.model.gen.parameters())

        optimizer_name = GlobalConstants.getOptimizer().upper()
        if optimizer_name == 'ADAM':
            Optimizer = torch.optim.Adam
        elif optimizer_name == 'RMSPROP':
            Optimizer = torch.optim.RMSprop
        else:
            raise ValueError(
                '%s is currently not supported' % GlobalConstants.getOptimizer())

        self.dis_opt = Optimizer(
            [p for p in dis_params if p.requires_grad],
            lr=lr_dis, weight_decay=cfg['weight_decay'])
        self.gen_opt = Optimizer(
            [p for p in gen_params if p.requires_grad],
            lr=lr_gen, weight_decay=cfg['weight_decay'])

        self.model.cuda()
        # APEX initialization
        if (GlobalConstants.usingApex):
            opt_level = 'O0'
            self.model, [self.dis_opt, self.gen_opt] = amp.initialize(
                self.model, [self.dis_opt, self.gen_opt],
                opt_level=opt_level, num_losses=4,
                max_loss_scale=2**0,
                verbosity=1)  # verbose for now
            self.model.setOptimizersForApex(self.dis_opt, self.gen_opt)

        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)
        self.apply(weights_init(cfg['init']))
        self.model.gen_test = copy.deepcopy(self.model.gen)
Example #4
class Trainer(nn.Module):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.model = FUNITModel(cfg)
        lr_gen = cfg['lr_gen']
        lr_dis = cfg['lr_dis']
        dis_params = list(self.model.dis.parameters())
        gen_params = list(self.model.gen.parameters())
        self.dis_opt = torch.optim.RMSprop(
            [p for p in dis_params if p.requires_grad],
            lr=lr_dis,
            weight_decay=cfg['weight_decay'])
        self.gen_opt = torch.optim.RMSprop(
            [p for p in gen_params if p.requires_grad],
            lr=lr_gen,
            weight_decay=cfg['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)
        self.apply(weights_init(cfg['init']))
        self.model.gen_test = copy.deepcopy(self.model.gen)

    def gen_update(self, co_data, cl_data, hp, multigpus):
        self.gen_opt.zero_grad()
        al, ad, xr, cr, sr, ac = self.model(co_data, cl_data, hp, 'gen_update')
        self.loss_gen_total = torch.mean(al)
        self.loss_gen_recon_x = torch.mean(xr)
        self.loss_gen_recon_c = torch.mean(cr)
        self.loss_gen_recon_s = torch.mean(sr)
        self.loss_gen_adv = torch.mean(ad)
        self.accuracy_gen_adv = torch.mean(ac)
        self.gen_opt.step()
        this_model = self.model.module if multigpus else self.model
        update_average(this_model.gen_test, this_model.gen)
        return self.accuracy_gen_adv.item()

    def dis_update(self, co_data, cl_data, hp):
        self.dis_opt.zero_grad()
        al, lfa, lre, reg, acc = self.model(co_data, cl_data, hp, 'dis_update')
        self.loss_dis_total = torch.mean(al)
        self.loss_dis_fake_adv = torch.mean(lfa)
        self.loss_dis_real_adv = torch.mean(lre)
        self.loss_dis_reg = torch.mean(reg)
        self.accuracy_dis_adv = torch.mean(acc)
        self.dis_opt.step()
        return self.accuracy_dis_adv.item()

    def test(self, co_data, cl_data, multigpus):
        this_model = self.model.module if multigpus else self.model
        return this_model.test(co_data, cl_data)

    def resume(self, checkpoint_dir, hp, multigpus):
        this_model = self.model.module if multigpus else self.model

        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        this_model.gen.load_state_dict(state_dict['gen'])
        this_model.gen_test.load_state_dict(state_dict['gen_test'])
        iterations = int(last_model_name[-11:-3])

        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        this_model.dis.load_state_dict(state_dict['dis'])

        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        self.dis_scheduler = get_scheduler(self.dis_opt, hp, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, multigpus):
        this_model = self.model.module if multigpus else self.model
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'gen': this_model.gen.state_dict(),
                'gen_test': this_model.gen_test.state_dict()
            }, gen_name)
        torch.save({'dis': this_model.dis.state_dict()}, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)

    def load_ckpt(self, ckpt_name):
        state_dict = torch.load(ckpt_name)
        self.model.gen.load_state_dict(state_dict['gen'])
        self.model.gen_test.load_state_dict(state_dict['gen_test'])

    def translate(self, co_data, cl_data):
        return self.model.translate(co_data, cl_data)

    def translate_k_shot(self, co_data, cl_data, k, mode):
        return self.model.translate_k_shot(co_data, cl_data, k, mode)

    def forward(self, *inputs):
        raise NotImplementedError('Forward function not implemented.')

    def debug(self, co_data):
        return self.model.debug(co_data)
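
A minimal sketch of how this Trainer's update methods might be driven from a training loop. The loader names, hp dictionary, snapshot_dir, and max_iter below are placeholder assumptions; only the method signatures come from the class above:

# Hypothetical loop; content_loader, class_loader, hp, snapshot_dir, and
# max_iter are placeholders, not part of the Trainer API shown above.
trainer = Trainer(cfg)
trainer.cuda()
for it, (co_data, cl_data) in enumerate(zip(content_loader, class_loader)):
    d_acc = trainer.dis_update(co_data, cl_data, hp)              # discriminator step
    g_acc = trainer.gen_update(co_data, cl_data, hp, multigpus=False)
    trainer.dis_scheduler.step()                                  # advance lr schedules
    trainer.gen_scheduler.step()
    if (it + 1) % 10000 == 0:
        trainer.save(snapshot_dir, it, multigpus=False)           # writes gen_/dis_%08d.pt
    if it + 1 >= max_iter:
        break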
Example #5
class Trainer(nn.Module):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.model = FUNITModel(cfg)
        lr_gen = cfg['lr_gen']
        lr_dis = cfg['lr_dis']
        dis_params = list(self.model.dis.parameters())
        gen_params = list(self.model.gen.parameters())

        optimizer_name = GlobalConstants.getOptimizer().upper()
        if optimizer_name == 'ADAM':
            Optimizer = torch.optim.Adam
        elif optimizer_name == 'RMSPROP':
            Optimizer = torch.optim.RMSprop
        else:
            raise ValueError(
                '%s is currently not supported' % GlobalConstants.getOptimizer())

        self.dis_opt = Optimizer(
            [p for p in dis_params if p.requires_grad],
            lr=lr_dis, weight_decay=cfg['weight_decay'])
        self.gen_opt = Optimizer(
            [p for p in gen_params if p.requires_grad],
            lr=lr_gen, weight_decay=cfg['weight_decay'])

        self.model.cuda()
        # APEX initialization
        if (GlobalConstants.usingApex):
            opt_level = 'O0'
            self.model, [self.dis_opt, self.gen_opt] = amp.initialize(
                self.model, [self.dis_opt, self.gen_opt],
                opt_level=opt_level, num_losses=4,
                max_loss_scale=2**0,
                verbosity=1)  # verbose for now
            self.model.setOptimizersForApex(self.dis_opt, self.gen_opt)

        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)
        self.apply(weights_init(cfg['init']))
        self.model.gen_test = copy.deepcopy(self.model.gen)

    def gen_update(self, co_data, cl_data, hp, multigpus, it):
        self.gen_opt.zero_grad()
        adversarial_loss, ad, xr, cr, sr, ac = self.model(co_data, cl_data, hp, 'gen_update')
        self.loss_gen_total = torch.mean(adversarial_loss)
        self.loss_gen_recon_x = torch.mean(xr)
        self.loss_gen_recon_c = torch.mean(cr)
        self.loss_gen_recon_s = torch.mean(sr)
        self.loss_gen_adv = torch.mean(ad)
        self.accuracy_gen_adv = torch.mean(ac)
        self.gen_opt.step()
        this_model = self.model.module if multigpus else self.model
        update_average(this_model.gen_test, this_model.gen)
        return self.accuracy_gen_adv.item()

    def dis_update(self, co_data, cl_data, hp, it):
        self.dis_opt.zero_grad()
        adversarial_loss, loss_dis_fake_adv, l_reconst, reg, acc = self.model(
            co_data, cl_data, hp, 'dis_update')
        self.loss_dis_total = torch.mean(adversarial_loss)
        self.loss_dis_fake_adv = torch.mean(loss_dis_fake_adv)
        self.loss_dis_real_adv = torch.mean(l_reconst)
        self.loss_dis_reg = torch.mean(reg)
        self.accuracy_dis_adv = torch.mean(acc)
        self.dis_opt.step()
        return self.accuracy_dis_adv.item()

    def test(self, co_data, cl_data, multigpus):
        this_model = self.model.module if multigpus else self.model
        return this_model.test(co_data, cl_data)

    def resume(self, checkpoint_dir, hp, multigpus):
        this_model = self.model.module if multigpus else self.model

        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        this_model.gen.load_state_dict(state_dict['gen'])
        this_model.gen_test.load_state_dict(state_dict['gen_test'])
        iterations = int(last_model_name[-11:-3])

        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        this_model.dis.load_state_dict(state_dict['dis'])

        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        self.dis_scheduler = get_scheduler(self.dis_opt, hp, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp, iterations)

        if (GlobalConstants.usingApex):
            state_dict = torch.load(os.path.join(checkpoint_dir, 'amp.pt'))
            amp.load_state_dict(state_dict['amp'])
        print('Resume from iteration %d' % iterations)
        return iterations

    def load_ckpt(self, ckpt_name):
        state_dict = torch.load(ckpt_name)
        self.model.gen.load_state_dict(state_dict['gen'])
        self.model.gen_test.load_state_dict(state_dict['gen_test'])

    def save(self, snapshot_dir, iterations, multigpus):
        this_model = self.model.module if multigpus else self.model
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'gen': this_model.gen.state_dict(),
                    'gen_test': this_model.gen_test.state_dict()}, gen_name)
        torch.save({'dis': this_model.dis.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(),
                    'dis': self.dis_opt.state_dict()}, opt_name)
        if (GlobalConstants.usingApex):
            amp_name = os.path.join(snapshot_dir, "amp.pt")
            torch.save({'amp': amp.state_dict()}, amp_name)

    def translate(self, co_data, cl_data):
        return self.model.translate(co_data, cl_data)

    def translate_k_shot(self, co_data, cl_data, k, mode):
        return self.model.translate_k_shot(co_data, cl_data, k, mode)

    def forward(self, *inputs):
        raise NotImplementedError('Forward function not implemented.')
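
Every gen_update above keeps gen_test as a running average of the generator via update_average, so that inference can use a smoothed copy of the weights. That helper is not reproduced on this page; a sketch of the usual exponential-moving-average form, with the decay value as an assumption:

# Sketch of an EMA update like the update_average calls above; beta is assumed.
import torch

def update_average(model_tgt, model_src, beta=0.999):
    with torch.no_grad():
        params_src = dict(model_src.named_parameters())
        for name, p_tgt in model_tgt.named_parameters():
            p_src = params_src[name]
            # blend the averaged copy toward the live generator weights
            p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)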
Example #6
class Trainer(nn.Module):
    def __init__(self, cfg):
        super(Trainer, self).__init__()

        if cfg['mode'] == 'funit':
            self.model = FUNITModel(cfg)
        elif cfg['mode'] == 'g2g':
            self.model = G2GModel(cfg)
        else:
            raise ValueError(
                "Choose from the following two modes: 'funit' or 'g2g'.")

        lr_gen = cfg['lr_gen']
        lr_dis = cfg['lr_dis']

        dis_params = list(self.model.dis.parameters())
        gen_params = list(self.model.gen.parameters())

        self.dis_opt = torch.optim.RMSprop(
            [p for p in dis_params if p.requires_grad],
            lr=lr_dis,
            weight_decay=cfg['weight_decay'])
        self.gen_opt = torch.optim.RMSprop(
            [p for p in gen_params if p.requires_grad],
            lr=lr_gen,
            weight_decay=cfg['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)
        self.apply(weights_init(cfg['init']))
        self.model.gen_test = copy.deepcopy(
            self.model.gen)  # should be performed in load_ckpt()

    def gen_update(self, co_data, cl_data, hp, multigpus):
        """
        Params:
            - co_data: content data with one content image and one content image label
            - cl_data: class data with one class image and one class image label
            - hp: hyperparameters
            - multigpus: if training on multiple GPUs
        """
        self.gen_opt.zero_grad()
        """
        Generator update step returns:
            - l_total: overall loss
            - l_adv: adversarial loss of the generator
            - l_rec: short reconstruction loss
            - l_fm_rec: feature matching loss same image
            - l_fm_m: feature matching loss translated image
            - l_long_rec: long reconstruction loss
            - l_long_fm: long feature matching loss
            - l_fm_mix_rec: Feature matching loss between mixed and reconstructed stage
            - acc: accuracy
        """
        l_total, l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm, l_fm_mix_rec, acc = self.model(
            co_data, cl_data, hp, 'gen_update')

        self.loss_gen_total = torch.mean(l_total)
        self.loss_gen_adv = torch.mean(l_adv)
        self.loss_gen_recon_x = torch.mean(l_rec)
        self.loss_gen_recon_s = torch.mean(l_fm_rec)
        self.loss_gen_recon_c = torch.mean(l_fm_m)
        self.loss_gen_recon_l = torch.mean(l_long_rec)
        self.loss_gen_recon_lfm = torch.mean(l_long_fm)
        self.loss_gen_recon_mix_rec = torch.mean(l_fm_mix_rec)
        self.accuracy_gen_adv = torch.mean(acc)

        self.gen_opt.step()

        this_model = self.model.module if multigpus else self.model
        update_average(this_model.gen_test, this_model.gen)
        return self.accuracy_gen_adv.item()

    def dis_update(self, co_data, cl_data, hp):
        """
        Params:
            - co_data: content data with one content image and one content image label
            - cl_data: class data with one class image and one class image label
            - hp: hyperparameters
        """
        self.dis_opt.zero_grad()
        """
        Params:
            - al: l_total: overall discriminator loss
            - lfa: l_fake_p: fake loss term
            - lre: l_real_pre: real loss term
            - reg: l_reg_pre: regularization loss term
            - acc: accuracy
        """
        al, lfa, lre, reg, acc = self.model(co_data, cl_data, hp, 'dis_update')

        self.loss_dis_total = torch.mean(al)
        self.loss_dis_fake_adv = torch.mean(lfa)
        self.loss_dis_real_adv = torch.mean(lre)
        self.loss_dis_reg = torch.mean(reg)
        self.accuracy_dis_adv = torch.mean(acc)
        self.dis_opt.step()

        return self.accuracy_dis_adv.item()

    def test(self, co_data, cl_data, multigpus):
        this_model = self.model.module if multigpus else self.model
        return this_model.test(co_data, cl_data)

    def resume(self, checkpoint_dir, hp, multigpus):
        """
        Params:
            - checkpoint: checkpoint from which to resume
            - hp: hyperparameters
            - multigpus: if training on multiple GPUs
        Returns:
            - iterations: number of iterations
        """
        # load current state of generator
        this_model = self.model.module if multigpus else self.model

        # load generator from checkpoint
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        this_model.gen.load_state_dict(state_dict['gen'])
        this_model.gen_test.load_state_dict(state_dict['gen_test'])
        iterations = int(last_model_name[-11:-3])

        self.dis_scheduler = get_scheduler(self.dis_opt, hp, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp, iterations)

        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, multigpus):
        this_model = self.model.module if multigpus else self.model
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'gen': this_model.gen.state_dict(),
                'gen_test': this_model.gen_test.state_dict()
            }, gen_name)
        torch.save({'dis': this_model.dis.state_dict()}, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)

    def load_ckpt(self, ckpt_name):
        state_dict = torch.load(ckpt_name)
        self.model.gen.load_state_dict(state_dict['gen'])
        self.model.gen_test.load_state_dict(state_dict['gen_test'])

    def translate(self, co_data, cl_data):
        return self.model.translate(co_data, cl_data)

    def translate_k_shot(self, co_data, cl_data, k, mode):
        return self.model.translate_k_shot(co_data, cl_data, k, mode)

    def forward(self, *inputs):
        raise NotImplementedError('Forward function not implemented.')
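
Each resume() above depends on get_model_list returning the newest checkpoint path for a given prefix, and on the save format 'gen_%08d.pt', which is why int(last_model_name[-11:-3]) recovers the iteration count (the eight digits immediately before '.pt'). The project's helper is not shown on this page; a sketch consistent with those assumptions:

# Sketch of a get_model_list helper consistent with the resume()/save() pairs above.
import os

def get_model_list(dirname, key):
    if not os.path.exists(dirname):
        return None
    models = [os.path.join(dirname, f) for f in os.listdir(dirname)
              if f.endswith('.pt') and key in f]
    if not models:
        return None
    models.sort()        # zero-padded names sort chronologically
    return models[-1]    # newest checkpoint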