Ejemplo n.º 1
0
def finetune_discrete_decoder(trainer, look_up, model_path, flag='train'):
    """Fine-tune the decoder of an already-trained `trainer` against
    looked-up (discretized) targets, logging and checkpointing as it goes."""
    hyper_params = trainer.hps
    n_iters = hyper_params.enc_pretrain_iters

    for iteration in range(n_iters):
        batch = next(trainer.data_loader)
        c, x = trainer.permute_data(batch)

        # Reconstruct from the encoding; the target is the looked-up x.
        encoded = trainer.encode_step(x)
        target = look_up(x)
        x_tilde = trainer.decode_step(encoded, c)
        loss_rec = torch.mean(torch.abs(x_tilde - target))

        # Joint encoder/decoder update with gradient clipping.
        reset_grad([trainer.Encoder, trainer.Decoder])
        loss_rec.backward()
        grad_clip([trainer.Encoder, trainer.Decoder],
                  trainer.hps.max_grad_norm)
        trainer.ae_opt.step()

        # tb info
        info = {f'{flag}/disc_loss_rec': loss_rec.item()}
        slot_value = (iteration + 1, n_iters, *info.values())
        log = 'train_discrete:[%06d/%06d], loss_rec=%.3f'
        print(log % slot_value, end='\r')

        if iteration % 100 == 0:
            for tag, value in info.items():
                trainer.logger.scalar_summary(tag, value, iteration + 1)
        if (iteration + 1) % 1000 == 0:
            trainer.save_model(model_path, 'dc', iteration + 1)
    print()
Ejemplo n.º 2
0
 def train(self, model_path, flag='train'):
     """Train the speaker classifier, validating and checkpointing
     every 1000 iterations (and on the final iteration).

     Args:
         model_path: destination passed through to ``self.save_model``.
         flag: tag prefix for logger/tensorboard entries.
     """
     # load hyperparams
     hps = self.hps
     for iteration in range(hps.iters):
         data = next(self.data_loader)
         y, x = self.permute_data(data)
         # encode
         enc = self.encode_step(x)
         # forward to classifier
         logits = self.forward_step(enc)
         # calculate loss
         loss = self.cal_loss(logits, y)
         # optimize
         reset_grad([self.SpeakerClassifier])
         loss.backward()
         grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
         self.opt.step()
         # calculate acc
         acc = cal_acc(logits, y)
         # print info
         info = {
             # .item(): loss.data[0] is pre-0.4 PyTorch and raises
             # IndexError on 0-dim tensors in modern versions.
             f'{flag}/loss': loss.item(),
             f'{flag}/acc': acc,
         }
         slot_value = (iteration + 1, hps.iters) + tuple(info.values())
         log = 'iter:[%06d/%06d], loss=%.3f, acc=%.3f'
         print(log % slot_value, end='\r')
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration)
         if iteration % 1000 == 0 or iteration + 1 == hps.iters:
             valid_loss, valid_acc = self.valid(n_batches=10)
             # print info
             info = {
                 f'{flag}/valid_loss': valid_loss,
                 f'{flag}/valid_acc': valid_acc,
             }
             slot_value = (iteration + 1, hps.iters) + \
                     tuple(info.values())
             log = 'iter:[%06d/%06d], valid_loss=%.3f, valid_acc=%.3f'
             print(log % slot_value)
             for tag, value in info.items():
                 self.logger.scalar_summary(tag, value, iteration)
             self.save_model(model_path, iteration)
Ejemplo n.º 3
0
            x = Variable(torch.FloatTensor(np.transpose(o, (2, 0, 1))[None]))
            x_next = Variable(
                torch.FloatTensor(np.transpose(o_next, (2, 0, 1))[None]))
        z = C.encode(x)
        density_x = C.density(x_next, z)
        density_sum = 0
        for j in [n for n in range(n_trajs) if n != i]:
            k = np.random.randint(len(data[j]))
            o_other = data[j][k][0]
            if torch.cuda.is_available():
                x_other = Variable(
                    torch.cuda.FloatTensor(
                        np.transpose(o_other, (2, 0, 1))[None]))
            else:
                x_other = Variable(
                    torch.FloatTensor(np.transpose(o_other, (2, 0, 1))[None]))
            density_sum += torch.exp(C.density(x_other, z) - density_x)
        density = 1.0 / (1.0 + density_sum)
        C_loss = -torch.mean(torch.log(density))
        C_loss.backward()
        C_solver.step()
        reset_grad(params)

    print('********** Epoch %i ************' % epoch)
    print(C_loss)
    log_value('C_loss', C_loss, epoch)

    if not os.path.exists('%s/var' % savepath):
        os.makedirs('%s/var' % savepath)
    torch.save(C.state_dict(), '%s/var/cpc%d' % (savepath, epoch))
Ejemplo n.º 4
0
 def train(self, model_path, flag='train', mode='train'):
     # load hyperparams
     hps = self.hps
     if mode == 'pretrain_G':
         for iteration in range(hps.enc_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             x_tilde = self.decode_step(enc, c)
             loss_rec = torch.mean(torch.abs(x_tilde - x))
             reset_grad([self.Encoder, self.Decoder])
             loss_rec.backward()
             grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
             self.ae_opt.step()
             # tb info
             info = {
                 f'{flag}/pre_loss_rec': loss_rec.item(),
             }
             slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple([value for value in info.values()])
             log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'pretrain_D':
         for iteration in range(hps.dis_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # classify speaker
             logits = self.clf_step(enc)
             loss_clf = self.cal_loss(logits, c)
             # update 
             reset_grad([self.SpeakerClassifier])
             loss_clf.backward()
             grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
             self.clf_opt.step()
             # calculate acc
             acc = cal_acc(logits, c)
             info = {
                 f'{flag}/pre_loss_clf': loss_clf.item(),
                 f'{flag}/pre_acc': acc,
             }
             slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple([value for value in info.values()])
             log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'patchGAN':
         for iteration in range(hps.patch_iters):
             #=======train D=========#
             for step in range(hps.n_patch_steps):
                 data = next(self.data_loader)
                 c, x = self.permute_data(data)
                 ## encode
                 enc = self.encode_step(x)
                 # sample c
                 c_prime = self.sample_c(x.size(0))
                 # generator
                 x_tilde = self.gen_step(enc, c_prime)
                 # discriminstor
                 w_dis, real_logits, gp = self.patch_step(x, x_tilde, is_dis=True)
                 # aux classification loss 
                 loss_clf = self.cal_loss(real_logits, c)
                 loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                 reset_grad([self.PatchDiscriminator])
                 loss.backward()
                 grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm)
                 self.patch_opt.step()
                 # calculate acc
                 acc = cal_acc(real_logits, c)
                 info = {
                     f'{flag}/w_dis': w_dis.item(),
                     f'{flag}/gp': gp.item(), 
                     f'{flag}/real_loss_clf': loss_clf.item(),
                     f'{flag}/real_acc': acc, 
                 }
                 slot_value = (step, iteration+1, hps.patch_iters) + tuple([value for value in info.values()])
                 log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                 print(log % slot_value)
                 if iteration % 100 == 0:
                     for tag, value in info.items():
                         self.logger.scalar_summary(tag, value, iteration + 1)
             #=======train G=========#
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # sample c
             c_prime = self.sample_c(x.size(0))
             # generator
             x_tilde = self.gen_step(enc, c_prime)
             # discriminstor
             loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False)
             # aux classification loss 
             loss_clf = self.cal_loss(fake_logits, c_prime)
             loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
             reset_grad([self.Generator])
             loss.backward()
             grad_clip([self.Generator], self.hps.max_grad_norm)
             self.gen_opt.step()
             # calculate acc
             acc = cal_acc(fake_logits, c_prime)
             info = {
                 f'{flag}/loss_adv': loss_adv.item(),
                 f'{flag}/fake_loss_clf': loss_clf.item(),
                 f'{flag}/fake_acc': acc, 
             }
             slot_value = (iteration+1, hps.patch_iters) + tuple([value for value in info.values()])
             log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
Ejemplo n.º 5
0
 # NOTE(review): fragment from inside a training loop — `iteration`, `hps`,
 # `self`, `flag`, `reset_grad`, `grad_clip` and `cal_acc` are supplied by
 # enclosing code not shown here.
 # Ramp the adversarial weight linearly up to hps.alpha_enc over the first
 # hps.lat_sched_iters iterations, then hold it constant.  `current_alpha`
 # is presumably consumed later in the (unseen) generator step.
 if iteration < hps.lat_sched_iters:
     current_alpha = hps.alpha_enc * (iteration / hps.lat_sched_iters)
 else:
     current_alpha = hps.alpha_enc
 #==================train D==================#
 # Train the speaker classifier (adversary D) for n_latent_steps batches.
 for step in range(hps.n_latent_steps):
     data = next(self.data_loader)
     c, x = self.permute_data(data)
     # encode
     enc = self.encode_step(x)
     # classify the speaker from the latent code
     logits = self.clf_step(enc)
     loss_clf = self.cal_loss(logits, c)
     loss = hps.alpha_dis * loss_clf
     # update the classifier only
     reset_grad([self.SpeakerClassifier])
     loss.backward()
     grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
     self.clf_opt.step()
     # calculate acc
     acc = cal_acc(logits, c)
     info = {
         f'{flag}/D_loss_clf': loss_clf.item(),
         f'{flag}/D_acc': acc,
     }
     # NOTE(review): the progress denominator is hps.iters (the outer loop's
     # total, presumably), not hps.n_latent_steps — confirm intent.
     slot_value = (step, iteration + 1, hps.iters) + tuple([value for value in info.values()])
     log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
     print(log % slot_value)
     if iteration % 100 == 0:
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration + 1)
Ejemplo n.º 6
0
    def train(self, model_path, flag='train', mode='train'):
        """Run one training stage, selected by `mode`.

        Modes (each runs its own loop and returns when done):
            'pretrain_G': L1 reconstruction pretraining of Encoder+Decoder.
            'pretrain_D': pretrain the speaker classifier on encodings.
            'patchGAN':   alternate patch-discriminator / generator updates.
            'train':      adversarial autoencoder training (classifier as D).

        Args:
            model_path: checkpoint directory; created (mode 0o755) if missing,
                then rebound to '<model_path>/model.pkl'.
            flag: tag prefix for logger entries.
            mode: which stage to run (see above).
        """
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
            os.chmod(model_path, 0o755)
        model_path = os.path.join(model_path, 'model.pkl')

        # load hyperparams
        hps = self.hps
        if mode == 'pretrain_G':
            # NOTE(review): iteration counts 2200/1100 are hard-coded in this
            # variant; sibling variants drive them from hps.enc_pretrain_iters,
            # hps.dis_pretrain_iters, hps.patch_iters and hps.iters — confirm
            # whether these constants were meant to stay in sync with hps.
            for iteration in range(2200):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                x_tilde = self.decode_step(enc, c)
                # L1 reconstruction loss
                loss_rec = torch.mean(torch.abs(x_tilde - x))
                reset_grad([self.Encoder, self.Decoder])
                loss_rec.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()
                # tb info
                info = {
                    f'{flag}/pre_loss_rec': loss_rec.item(),
                }
                slot_value = (iteration + 1, 2200) + tuple(
                    [value for value in info.values()])
                log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
        elif mode == 'pretrain_D':
            for iteration in range(2200):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                # update
                reset_grad([self.SpeakerClassifier])
                loss_clf.backward()
                grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                self.clf_opt.step()
                # calculate acc
                acc = cal_acc(logits, c)
                info = {
                    f'{flag}/pre_loss_clf': loss_clf.item(),
                    f'{flag}/pre_acc': acc,
                }
                slot_value = (iteration + 1, 2200) + tuple(
                    [value for value in info.values()])
                log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
        elif mode == 'patchGAN':
            for iteration in range(1100):
                #=======train D=========#
                # n_patch_steps discriminator updates per generator update.
                for step in range(hps.n_patch_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)
                    ## encode
                    enc = self.encode_step(x)
                    # sample c
                    c_prime = self.sample_c(x.size(0))
                    # generator
                    x_tilde = self.gen_step(enc, c_prime)
                    # discriminstor
                    w_dis, real_logits, gp = self.patch_step(x,
                                                             x_tilde,
                                                             is_dis=True)
                    # aux classification loss
                    loss_clf = self.cal_loss(real_logits, c)
                    # WGAN-GP-style objective: maximize Wasserstein distance,
                    # plus aux classifier loss and gradient penalty.
                    loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                    reset_grad([self.PatchDiscriminator])
                    loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()
                    # calculate acc
                    acc = cal_acc(real_logits, c)
                    info = {
                        f'{flag}/w_dis': w_dis.item(),
                        f'{flag}/gp': gp.item(),
                        f'{flag}/real_loss_clf': loss_clf.item(),
                        f'{flag}/real_acc': acc,
                    }
                    slot_value = (step, iteration + 1, 1100) + tuple(
                        [value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value)
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)
                #=======train G=========#
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # sample c
                c_prime = self.sample_c(x.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminstor
                loss_adv, fake_logits = self.patch_step(x,
                                                        x_tilde,
                                                        is_dis=False)
                # aux classification loss
                loss_clf = self.cal_loss(fake_logits, c_prime)
                loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
                reset_grad([self.Generator])
                loss.backward()
                grad_clip([self.Generator], self.hps.max_grad_norm)
                self.gen_opt.step()
                # calculate acc
                acc = cal_acc(fake_logits, c_prime)
                info = {
                    f'{flag}/loss_adv': loss_adv.item(),
                    f'{flag}/fake_loss_clf': loss_clf.item(),
                    f'{flag}/fake_acc': acc,
                }
                slot_value = (iteration + 1, 1100) + tuple(
                    [value for value in info.values()])
                log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                # NOTE(review): save condition tests hps.patch_iters while the
                # loop runs exactly 1100 iterations — if hps.patch_iters != 1100
                # the final-iteration save never fires; confirm.
                if iteration % 1000 == 0 or iteration + 1 == hps.patch_iters:
                    self.save_model(model_path, iteration + hps.iters)
        elif mode == 'train':
            for iteration in range(1100):
                # calculate current alpha: linear ramp to hps.alpha_enc over
                # hps.lat_sched_iters iterations, then constant.
                if iteration < hps.lat_sched_iters:
                    current_alpha = hps.alpha_enc * (iteration /
                                                     hps.lat_sched_iters)
                else:
                    current_alpha = hps.alpha_enc
                #==================train D==================#
                for step in range(hps.n_latent_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)
                    # encode
                    enc = self.encode_step(x)
                    # classify speaker
                    logits = self.clf_step(enc)
                    loss_clf = self.cal_loss(logits, c)
                    loss = hps.alpha_dis * loss_clf
                    # update
                    reset_grad([self.SpeakerClassifier])
                    loss.backward()
                    grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                    self.clf_opt.step()
                    # calculate acc
                    acc = cal_acc(logits, c)
                    info = {
                        f'{flag}/D_loss_clf': loss_clf.item(),
                        f'{flag}/D_acc': acc,
                    }
                    slot_value = (step, iteration + 1, 1100) + tuple(
                        [value for value in info.values()])
                    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value)
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)
                #==================train G==================#
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # decode
                x_tilde = self.decode_step(enc, c)
                loss_rec = torch.mean(torch.abs(x_tilde - x))
                # classify speaker
                logits = self.clf_step(enc)
                acc = cal_acc(logits, c)
                loss_clf = self.cal_loss(logits, c)
                # maximize classification loss (adversarial: the encoder is
                # pushed to remove speaker information from the latent code)
                loss = loss_rec - current_alpha * loss_clf
                reset_grad([self.Encoder, self.Decoder])
                loss.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()
                info = {
                    f'{flag}/loss_rec': loss_rec.item(),
                    f'{flag}/G_loss_clf': loss_clf.item(),
                    f'{flag}/alpha': current_alpha,
                    f'{flag}/G_acc': acc,
                }
                slot_value = (iteration + 1, 1100) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                # NOTE(review): same mismatch as above — condition uses
                # hps.iters but the loop runs 1100 iterations; confirm.
                if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                    self.save_model(model_path, iteration)
Ejemplo n.º 7
0
 def train(self, model_path, flag='train', mode='train'):
     """Run one training stage, selected by `mode`.

     Modes:
         'pretrain_G': L1 reconstruction pretraining of Encoder+Decoder.
         'pretrain_D': pretrain the speaker classifier on encodings.
         'patchGAN':   alternate patch-discriminator / generator updates,
                       followed by one extra generator block (but see the
                       NOTE(review) at that block — it looks spliced in from
                       a different variant of this method).

     Args:
         model_path: destination passed to ``self.save_model`` (not used in
             the branches shown here).
         flag: tag prefix for logger entries.
         mode: which stage to run (see above).
     """
     # load hyperparams
     hps = self.hps
     if mode == 'pretrain_G':
         for iteration in range(hps.enc_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             x_tilde = self.decode_step(enc, c)
             # L1 reconstruction loss
             loss_rec = torch.mean(torch.abs(x_tilde - x))
             reset_grad([self.Encoder, self.Decoder])
             loss_rec.backward()
             grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
             self.ae_opt.step()
             # tb info
             info = {
                 f'{flag}/pre_loss_rec': loss_rec.item(),
             }
             slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                 [value for value in info.values()])
             log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'pretrain_D':
         for iteration in range(hps.dis_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # classify speaker
             logits = self.clf_step(enc)
             loss_clf = self.cal_loss(logits, c)
             # update
             reset_grad([self.SpeakerClassifier])
             loss_clf.backward()
             grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
             self.clf_opt.step()
             # calculate acc
             acc = cal_acc(logits, c)
             info = {
                 f'{flag}/pre_loss_clf': loss_clf.item(),
                 f'{flag}/pre_acc': acc,
             }
             slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                 [value for value in info.values()])
             log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'patchGAN':
         for iteration in range(hps.patch_iters):
             #=======train D=========#
             # n_patch_steps discriminator updates per generator update.
             for step in range(hps.n_patch_steps):
                 data = next(self.data_loader)
                 c, x = self.permute_data(data)
                 ## encode
                 enc = self.encode_step(x)
                 # sample c
                 c_prime = self.sample_c(x.size(0))
                 # generator
                 x_tilde = self.gen_step(enc, c_prime)
                 # discriminstor
                 w_dis, real_logits, gp = self.patch_step(x,
                                                          x_tilde,
                                                          is_dis=True)
                 # aux classification loss
                 loss_clf = self.cal_loss(real_logits, c)
                 # WGAN-GP-style objective: Wasserstein distance + aux
                 # classifier loss + gradient penalty.
                 loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                 reset_grad([self.PatchDiscriminator])
                 loss.backward()
                 grad_clip([self.PatchDiscriminator],
                           self.hps.max_grad_norm)
                 self.patch_opt.step()
                 # calculate acc
                 acc = cal_acc(real_logits, c)
                 info = {
                     f'{flag}/w_dis': w_dis.item(),
                     f'{flag}/gp': gp.item(),
                     f'{flag}/real_loss_clf': loss_clf.item(),
                     f'{flag}/real_acc': acc,
                 }
                 slot_value = (step, iteration + 1,
                               hps.patch_iters) + tuple(
                                   [value for value in info.values()])
                 log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                 print(log % slot_value)
                 if iteration % 100 == 0:
                     for tag, value in info.items():
                         self.logger.scalar_summary(tag, value,
                                                    iteration + 1)
             #=======train G=========#
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # sample c
             c_prime = self.sample_c(x.size(0))
             # generator
             x_tilde = self.gen_step(enc, c_prime)
             # discriminstor
             loss_adv, fake_logits = self.patch_step(x,
                                                     x_tilde,
                                                     is_dis=False)
             # aux classification loss
             loss_clf = self.cal_loss(fake_logits, c_prime)
             loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
             reset_grad([self.Generator])
             loss.backward()
             grad_clip([self.Generator], self.hps.max_grad_norm)
             self.gen_opt.step()
             # calculate acc
             acc = cal_acc(fake_logits, c_prime)
             info = {
                 f'{flag}/loss_adv': loss_adv.item(),
                 f'{flag}/fake_loss_clf': loss_clf.item(),
                 f'{flag}/fake_acc': acc,
             }
             slot_value = (iteration + 1, hps.patch_iters) + tuple(
                 [value for value in info.values()])
             log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     # NOTE(review): logs at `iteration`, not `iteration + 1`
                     # as everywhere else in this method — confirm.
                     self.logger.scalar_summary(tag, value, iteration)
         # NOTE(review): this whole tail runs ONCE after the patchGAN loop and
         # appears spliced in from a different variant of train(): it unpacks
         # a 4-way data tuple unlike the rest of the method, and references
         # `current_alpha` (never defined here — NameError if reached) plus
         # `c_loss` below. Pre-0.4 `.data[0]` is also used. Needs repair
         # against the original upstream source before this branch is run.
         #===================== Train G =====================#
         data = next(self.data_loader)
         (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                      x_j) = self.permute_data(data)
         # encode
         enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
             x_i_t, x_i_tk, x_i_prime, x_j)
         # decode
         x_tilde = self.decode_step(enc_i_t, c_i)
         loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
         # latent discriminate
         loss_adv = self.latent_discriminate_step(enc_i_t,
                                                  enc_i_tk,
                                                  enc_i_prime,
                                                  enc_j,
                                                  is_dis=False)
         ae_loss = loss_rec + current_alpha * loss_adv
         reset_grad([self.Encoder, self.Decoder])
         # retain the graph when the patch step below will backprop again
         retain_graph = True if hps.n_patch_steps > 0 else False
         ae_loss.backward(retain_graph=retain_graph)
         grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
         self.ae_opt.step()
         info = {
             f'{flag}/loss_rec': loss_rec.data[0],
             f'{flag}/loss_adv': loss_adv.data[0],
             f'{flag}/alpha': current_alpha,
         }
         slot_value = (iteration + 1, hps.iters) + tuple(
             [value for value in info.values()])
         log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
         print(log % slot_value)
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration + 1)
         # patch discriminate
         if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
             c_sample = self.sample_c(x_i_t.size(0))
             x_tilde = self.decode_step(enc_i_t, c_sample)
             patch_w_dis, real_logits, fake_logits = \
                     self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
             # NOTE(review): `c_loss` is never defined in this method —
             # NameError if reached; likely meant a classifier loss on the
             # patch logits. Confirm against the upstream source.
             patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
             reset_grad([self.Decoder])
             patch_loss.backward()
             grad_clip([self.Decoder], self.hps.max_grad_norm)
             self.decoder_opt.step()
             # NOTE(review): `loss_clf` and `acc` here leak from the last
             # patchGAN loop iteration above — presumably unintended.
             info = {
                 f'{flag}/loss_rec': loss_rec.item(),
                 f'{flag}/G_loss_clf': loss_clf.item(),
                 f'{flag}/alpha': current_alpha,
                 f'{flag}/G_acc': acc,
             }
             slot_value = (iteration + 1, hps.iters) + tuple(
                 [value for value in info.values()])
             log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
             if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                 self.save_model(model_path, iteration)
Ejemplo n.º 8
0
    def train(self,
              model_path,
              flag='train',
              mode='train',
              target_guided=False):
        # load hyperparams
        hps = self.hps

        if mode == 'pretrain_AE':
            for iteration in range(hps.enc_pretrain_iters):
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc_act, enc = self.encode_step(x)
                x_dec = self.decode_step(enc_act, c)
                loss_rec = torch.mean(torch.abs(x_dec - x))
                reset_grad([self.Encoder, self.Decoder])
                loss_rec.backward()
                grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
                self.ae_opt.step()

                # tb info
                info = {
                    f'{flag}/pre_loss_rec': loss_rec.item(),
                }
                slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                    [value for value in info.values()])
                log = 'pre_AE:[%06d/%06d], loss_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'ae', iteration + 1)
            print()

        elif mode == 'pretrain_C':
            for iteration in range(hps.dis_pretrain_iters):

                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc_act, enc = self.encode_step(x)

                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)

                # update
                reset_grad([self.SpeakerClassifier])
                loss_clf.backward()
                grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                self.clf_opt.step()

                # calculate acc
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/pre_loss_clf': loss_clf.item(),
                    f'{flag}/pre_acc': acc,
                }
                slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                    [value for value in info.values()])
                log = 'pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'c', iteration + 1)
            print()

        elif mode == 'train':
            for iteration in range(hps.iters):

                # calculate current alpha
                if iteration < hps.lat_sched_iters:
                    current_alpha = hps.alpha_enc * (iteration /
                                                     hps.lat_sched_iters)
                else:
                    current_alpha = hps.alpha_enc

                #==================train D==================#
                for step in range(hps.n_latent_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)

                    # encode
                    enc_act, enc = self.encode_step(x)

                    # classify speaker
                    logits = self.clf_step(enc)
                    loss_clf = self.cal_loss(logits, c)
                    loss = hps.alpha_dis * loss_clf

                    # update
                    reset_grad([self.SpeakerClassifier])
                    loss.backward()
                    grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                    self.clf_opt.step()

                    # calculate acc
                    acc = self.cal_acc(logits, c)
                    info = {
                        f'{flag}/D_loss_clf': loss_clf.item(),
                        f'{flag}/D_acc': acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + tuple(
                        [value for value in info.values()])
                    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                    print(log % slot_value, end='\r')
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)

                #==================train G==================#
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc_act, enc = self.encode_step(x)

                # decode
                x_dec = self.decode_step(enc_act, c)
                loss_rec = torch.mean(torch.abs(x_dec - x))

                # classify speaker
                logits = self.clf_step(enc)
                acc = self.cal_acc(logits, c)
                loss_clf = self.cal_loss(logits, c)

                # maximize classification loss
                loss = loss_rec - current_alpha * loss_clf
                reset_grad([self.Encoder, self.Decoder])
                loss.backward()
                grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
                self.ae_opt.step()

                info = {
                    f'{flag}/loss_rec': loss_rec.item(),
                    f'{flag}/G_loss_clf': loss_clf.item(),
                    f'{flag}/alpha': current_alpha,
                    f'{flag}/G_acc': acc,
                }
                slot_value = (iteration + 1, hps.iters) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's1', iteration + 1)
            print()

        elif mode == 'patchGAN':
            for iteration in range(hps.patch_iters):
                #==================train D==================#
                for step in range(hps.n_patch_steps):

                    data_s = next(self.source_loader)
                    data_t = next(self.target_loader)
                    _, x_s = self.permute_data(data_s)
                    c_t, x_t = self.permute_data(data_t)

                    # encode
                    enc_act, _ = self.encode_step(x_s)

                    # generator
                    x_dec = self.gen_step(enc_act, c_t)

                    # discriminstor
                    w_dis, real_logits, gp = self.patch_step(x_t,
                                                             x_dec,
                                                             is_dis=True)

                    # aux classification loss
                    loss_clf = self.cal_loss(real_logits, c_t, shift=True)

                    loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                    reset_grad([self.PatchDiscriminator])
                    loss.backward()
                    grad_clip([self.PatchDiscriminator], hps.max_grad_norm)
                    self.patch_opt.step()

                    # calculate acc
                    acc = self.cal_acc(real_logits, c_t, shift=True)
                    info = {
                        f'{flag}/w_dis': w_dis.item(),
                        f'{flag}/gp': gp.item(),
                        f'{flag}/real_loss_clf': loss_clf.item(),
                        f'{flag}/real_acc': acc,
                    }
                    slot_value = (step, iteration + 1,
                                  hps.patch_iters) + tuple(
                                      [value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value, end='\r')

                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)

                #==================train G==================#
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c_t, x_t = self.permute_data(data_t)

                # encode
                enc_act, _ = self.encode_step(x_s)

                # generator
                x_dec = self.gen_step(enc_act, c_t)

                # discriminstor
                loss_adv, fake_logits = self.patch_step(x_t,
                                                        x_dec,
                                                        is_dis=False)

                # aux classification loss
                loss_clf = self.cal_loss(fake_logits, c_t, shift=True)
                loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
                reset_grad([self.Generator])
                loss.backward()
                grad_clip([self.Generator], hps.max_grad_norm)
                self.gen_opt.step()

                if target_guided:
                    # teacher forcing
                    enc_tf, _ = self.encode_step(x_t)
                    x_dec_tf = self.gen_step(enc_tf, c_t)
                    loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                    reset_grad([self.Generator])
                    loss_rec.backward()
                    self.gen_opt.step()

                # calculate acc
                acc = self.cal_acc(fake_logits, c_t, shift=True)
                info = {
                    f'{flag}/loss_adv': loss_adv.item(),
                    f'{flag}/fake_loss_clf': loss_clf.item(),
                    f'{flag}/fake_acc': acc,
                    f'{flag}/tg_rec':
                    loss_rec.item() if target_guided else 0.000,
                }
                slot_value = (iteration + 1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f, tg_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's2', iteration + 1)
            print()

        elif mode == 'autolocker':
            criterion = torch.nn.BCELoss()
            for iteration in range(hps.patch_iters):
                #==================train G==================#
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c_t, x_t = self.permute_data(data_t)

                # encode
                enc_act, _ = self.encode_step(x_s)

                # decode
                residual_output = self.gen_step(enc_act, c_t)

                # re-encode
                re_enc, _ = self.encode_step(residual_output)

                # re-encode loss
                loss_reenc = criterion(re_enc, enc_act.data)
                reset_grad([self.Encoder, self.Decoder, self.Generator])
                loss_reenc.backward()
                grad_clip([self.Generator], hps.max_grad_norm)
                self.gen_opt.step()

                if target_guided:
                    # teacher forcing
                    enc_tf, _ = self.encode_step(x_t)
                    x_dec_tf = self.gen_step(enc_tf, c_t)
                    loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                    reset_grad([self.Encoder, self.Decoder, self.Generator])
                    loss_rec.backward()
                    self.gen_opt.step()

                # calculate acc
                info = {
                    f'{flag}/re_enc': loss_reenc.item(),
                    f'{flag}/tg_rec':
                    loss_rec.item() if target_guided else 0.000,
                }
                slot_value = (iteration + 1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_G:[%06d/%06d], re_enc=%.3f, tg_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's2', iteration + 1)
            print()

        elif mode == 't_classify':
            for iteration in range(hps.tclf_iters):
                #======train target classifier======#
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                c[c < 100] = 102

                # classification
                logits = self.tclf_step(x)

                # classification loss
                loss = self.cal_loss(logits, c - self.shift_c)
                reset_grad([self.TargetClassifier])
                loss.backward()
                grad_clip([self.TargetClassifier], hps.max_grad_norm)
                self.tclf_opt.step()

                # calculate acc
                acc = self.cal_acc(logits, c - self.shift_c)
                info = {
                    f'{flag}/acc': acc,
                }
                slot_value = (iteration + 1, hps.tclf_iters) + tuple(
                    [value for value in info.values()])
                log = 'Target Classifier:[%05d/%05d], acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'tclf', iteration + 1)
            print()

        elif mode == 'train_Tacotron':

            assert self.g_mode == 'tacotron'
            criterion = TacotronLoss()
            self.Encoder.eval()

            for iteration in range(hps.tacotron_iters):
                #======train tacotron======#

                cur_lr = learning_rate_decay(init_lr=0.002,
                                             global_step=iteration)
                for param_group in self.gen_opt.param_groups:
                    param_group['lr'] = cur_lr

                data = next(self.data_loader)
                c, x, m = self.permute_data(data, load_mel=True)

                # encode
                enc_act, enc = self.encode_step(x)

                # tacotron synthesis
                m_dec, x_dec = self.tacotron_step(enc_act.data, m, c)

                # reconstruction loss
                loss_rec = criterion([m_dec, x_dec], [m, x])
                reset_grad([self.Generator])
                loss_rec.backward()
                grad_clip([self.Generator], hps.max_grad_norm)
                self.gen_opt.step()

                # tb info
                info = {
                    f'{flag}/tacotron_loss_rec': loss_rec.item(),
                    f'{flag}/tacotron_lr': cur_lr,
                }
                slot_value = (iteration + 1, hps.tacotron_iters) + tuple(
                    [value for value in info.values()])
                log = 'train_Tacotron:[%06d/%06d], loss_rec=%.3f, lr=%.2e'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 't', iteration + 1)
            print()

        else:
            raise NotImplementedError()
# ---- Example No. 9 ----
# 0
    def train(self, model_path, flag='train', mode='train'):
        """Run one training phase of the model, selected by ``mode``.

        Modes:
          - 'pretrain_AE': L1-reconstruction pretraining of Encoder/Decoder.
          - 'pretrain_C' : pretraining of the speaker classifier on encoder
            outputs.
          - 'train'      : adversarial stage 1 — alternately train the speaker
            classifier (D) and the autoencoder (G); G is rewarded for fooling
            the classifier, weighted by a linearly scheduled alpha.
          - 'patchGAN'   : adversarial stage 2 — WGAN-GP-style patch
            discriminator vs. generator with an auxiliary speaker classifier.

        Args:
            model_path: checkpoint path prefix handed to ``self.save_model``.
            flag: prefix for tensorboard scalar tags (e.g. 'train').
            mode: one of the phase names above.

        Raises:
            NotImplementedError: for any unrecognized ``mode``.
        """
        # load hyperparams
        hps = self.hps

        if mode == 'pretrain_AE':
            for iteration in range(hps.enc_pretrain_iters):
                data = next(self.data_loader)
                # presumably c = speaker label, x = acoustic features —
                # verify against self.permute_data.
                c, x = self.permute_data(data)

                # encode
                enc = self.encode_step(x)
                x_tilde = self.decode_step(enc, c)
                # L1 reconstruction loss between decoded and input features.
                loss_rec = torch.mean(torch.abs(x_tilde - x))
                reset_grad([self.Encoder, self.Decoder])
                loss_rec.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()

                # tb info
                info = {
                    f'{flag}/pre_loss_rec': loss_rec.item(),
                }
                slot_value = (iteration + 1, hps.enc_pretrain_iters) + \
                    tuple([value for value in info.values()])
                log = 'pre_AE:[%06d/%06d], loss_rec=%.3f'
                print(log % slot_value, end='\r')

                # log scalars every 100 iterations; checkpoint every 1000.
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'ae', iteration + 1)
            print()

        elif mode == 'pretrain_C':
            for iteration in range(hps.dis_pretrain_iters):

                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc = self.encode_step(x)

                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)

                # update (classifier only; encoder is frozen w.r.t. this step)
                reset_grad([self.SpeakerClassifier])
                loss_clf.backward()
                grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                self.clf_opt.step()

                # calculate acc
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/pre_loss_clf': loss_clf.item(),
                    f'{flag}/pre_acc': acc,
                }
                slot_value = (iteration + 1, hps.dis_pretrain_iters) + \
                    tuple([value for value in info.values()])
                log = 'pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'c', iteration + 1)
            print()

        elif mode == 'train':
            for iteration in range(hps.iters):

                # calculate current alpha
                # alpha ramps linearly from 0 to hps.alpha_enc over the first
                # hps.lat_sched_iters iterations, then stays constant.
                if iteration < hps.lat_sched_iters:
                    current_alpha = hps.alpha_enc * \
                        (iteration / hps.lat_sched_iters)
                else:
                    current_alpha = hps.alpha_enc

                #==================train D==================#
                for step in range(hps.n_latent_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)

                    # encode
                    enc = self.encode_step(x)
                    # assumes encode_step returns a 3-tuple
                    # (sample, z_mean, z_log_var) — TODO confirm; the whole
                    # tuple is then fed to clf_step below.
                    _, z_mean, z_log_var = enc
                    # standard VAE KL divergence to a unit Gaussian.
                    kl_loss = (1 + z_log_var - z_mean**2 -
                               torch.exp(z_log_var)).sum(-1) * -.5
                    kl_loss = kl_loss.sum()
                    # classify speaker
                    logits = self.clf_step(enc)
                    loss_clf = self.cal_loss(logits, c)
                    loss = hps.alpha_dis * loss_clf + kl_loss

                    # update
                    # NOTE(review): kl_loss also backprops into the Encoder,
                    # but only the SpeakerClassifier's grads are reset and
                    # stepped here, so encoder grads accumulate until the
                    # next autoencoder reset_grad — verify this is intended.
                    reset_grad([self.SpeakerClassifier])
                    loss.backward()
                    grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                    self.clf_opt.step()

                    # calculate acc
                    acc = self.cal_acc(logits, c)
                    info = {
                        f'{flag}/D_loss_clf': loss_clf.item(),
                        f'{flag}/D_acc': acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                        tuple([value for value in info.values()])
                    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                    print(log % slot_value, end='\r')
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)
                #==================train G==================#
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc = self.encode_step(x)

                # decode
                x_tilde = self.decode_step(enc, c)
                loss_rec = torch.mean(torch.abs(x_tilde - x))

                # classify speaker
                logits = self.clf_step(enc)
                acc = self.cal_acc(logits, c)
                loss_clf = self.cal_loss(logits, c)

                # maximize classification loss
                # (adversarial: the autoencoder is rewarded for fooling the
                # speaker classifier, hence the minus sign)
                loss = loss_rec - current_alpha * loss_clf
                reset_grad([self.Encoder, self.Decoder])
                loss.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()

                info = {
                    f'{flag}/loss_rec': loss_rec.item(),
                    f'{flag}/G_loss_clf': loss_clf.item(),
                    f'{flag}/alpha': current_alpha,
                    f'{flag}/G_acc': acc,
                }
                slot_value = (iteration + 1, hps.iters) + \
                    tuple([value for value in info.values()])
                log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's1', iteration + 1)
            print()

        elif mode == 'patchGAN':
            for iteration in range(hps.patch_iters):
                #==================train D==================#
                for step in range(hps.n_patch_steps):

                    # separate source/target loaders; only the target's
                    # speaker label c is used for the aux classifier.
                    data_s = next(self.source_loader)
                    data_t = next(self.target_loader)
                    _, x_s = self.permute_data(data_s)
                    c, x_t = self.permute_data(data_t)

                    # encode
                    enc = self.encode_step(x_s)

                    # sample c
                    c_prime = self.sample_c(x_t.size(0))

                    # generator
                    x_tilde = self.gen_step(enc, c_prime)

                    # discriminstor
                    # critic step: w_dis is the Wasserstein distance estimate,
                    # gp the gradient penalty — presumably WGAN-GP style;
                    # verify against patch_step.
                    w_dis, real_logits, gp = self.patch_step(x_t,
                                                             x_tilde,
                                                             is_dis=True)

                    # aux classification loss
                    loss_clf = self.cal_loss(real_logits, c, shift=True)

                    loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                    reset_grad([self.PatchDiscriminator])
                    loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()

                    # calculate acc
                    acc = self.cal_acc(real_logits, c, shift=True)
                    info = {
                        f'{flag}/w_dis': w_dis.item(),
                        f'{flag}/gp': gp.item(),
                        f'{flag}/real_loss_clf': loss_clf.item(),
                        f'{flag}/real_acc': acc,
                    }
                    slot_value = (step, iteration + 1, hps.patch_iters) + \
                        tuple([value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value, end='\r')

                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)

                #==================train G==================#
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c, x_t = self.permute_data(data_t)

                # encode
                enc = self.encode_step(x_s)

                # sample c
                c_prime = self.sample_c(x_t.size(0))

                # generator
                x_tilde = self.gen_step(enc, c_prime)

                # discriminstor
                loss_adv, fake_logits = self.patch_step(x_t,
                                                        x_tilde,
                                                        is_dis=False)

                # aux classification loss
                # note: classified against the *sampled* c_prime here, not c.
                loss_clf = self.cal_loss(fake_logits, c_prime, shift=True)
                loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
                reset_grad([self.Generator])
                loss.backward()
                grad_clip([self.Generator], self.hps.max_grad_norm)
                self.gen_opt.step()

                # calculate acc
                acc = self.cal_acc(fake_logits, c_prime, shift=True)
                info = {
                    f'{flag}/loss_adv': loss_adv.item(),
                    f'{flag}/fake_loss_clf': loss_clf.item(),
                    f'{flag}/fake_acc': acc,
                }
                slot_value = (iteration + 1, hps.patch_iters) + \
                    tuple([value for value in info.values()])
                log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    # offset by hps.iters so stage-2 checkpoint numbering
                    # continues after stage-1's.
                    self.save_model(model_path, 's2',
                                    iteration + 1 + hps.iters)
            print()

        else:
            raise NotImplementedError()
# ---- Example No. 10 ----
# 0
 def train(self, model_path, flag='train'):
     """Single-loop trainer: latent-discriminator, patch-discriminator and
     autoencoder (G) updates, gated by iteration thresholds in ``hps``.

     NOTE(review): ``c_loss``, ``real_acc`` and ``fake_acc`` are referenced
     (patch-discriminator step and patch-G step) but never assigned anywhere
     in this method — executing those branches raises NameError. The
     auxiliary-classification computation appears to be missing.

     NOTE: ``.data[0]`` is legacy (pre-0.4) PyTorch; modern code would use
     ``.item()``.

     Args:
         model_path: checkpoint path prefix for ``self.save_model``.
         flag: prefix for tensorboard scalar tags.
     """
     # load hyperparams
     hps = self.hps
     for iteration in range(hps.iters):
         # calculate current alpha
         # alpha ramps linearly from 0 to hps.alpha_enc between
         # enc_pretrain_iters and lat_sched_iters; 0 outside that window.
         if iteration + 1 < hps.lat_sched_iters and iteration >= hps.enc_pretrain_iters:
             current_alpha = hps.alpha_enc * (
                 iteration + 1 - hps.enc_pretrain_iters) / (
                     hps.lat_sched_iters - hps.enc_pretrain_iters)
         else:
             current_alpha = 0
         if iteration >= hps.enc_pretrain_iters:
             # on the first post-pretrain iteration run dis_pretrain_iters
             # discriminator steps (warm-up); n_latent_steps afterwards.
             n_latent_steps = hps.n_latent_steps \
                 if iteration > hps.enc_pretrain_iters else hps.dis_pretrain_iters
             for step in range(n_latent_steps):
                 #===================== Train latent discriminator =====================#
                 data = next(self.data_loader)
                 (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                              x_j) = self.permute_data(data)
                 # encode
                 enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                     x_i_t, x_i_tk, x_i_prime, x_j)
                 # latent discriminate
                 latent_w_dis, latent_gp = self.latent_discriminate_step(
                     enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                 # critic objective: maximize w_dis (minimize its negative),
                 # plus gradient penalty — presumably WGAN-GP; verify.
                 lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                 reset_grad([self.LatentDiscriminator])
                 lat_loss.backward()
                 grad_clip([self.LatentDiscriminator],
                           self.hps.max_grad_norm)
                 self.lat_opt.step()
                 # print info
                 info = {
                     f'{flag}/D_latent_w_dis': latent_w_dis.data[0],
                     f'{flag}/latent_gp': latent_gp.data[0],
                 }
                 slot_value = (step, iteration + 1, hps.iters) + \
                         tuple([value for value in info.values()])
                 log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                 print(log % slot_value)
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration)
         # two stage training
         if iteration >= hps.patch_start_iter:
             for step in range(hps.n_patch_steps):
                 #===================== Train patch discriminator =====================#
                 data = next(self.data_loader)
                 (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                 # encode
                 # trailing comma unpacks the single-element tuple returned
                 # by encode_step when given one input.
                 enc_i_t, = self.encode_step(x_i_t)
                 # NOTE(review): c_sample is computed but the decode below
                 # uses c_i — possibly c_sample was intended; confirm.
                 c_sample = self.sample_c(x_i_t.size(0))
                 x_tilde = self.decode_step(enc_i_t, c_i)
                 # Aux classify loss
                 patch_w_dis, real_logits, fake_logits, patch_gp = \
                         self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=True)
                 # NOTE(review): c_loss is undefined — NameError at runtime.
                 patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp + hps.beta_clf * c_loss
                 reset_grad([self.PatchDiscriminator])
                 patch_loss.backward()
                 grad_clip([self.PatchDiscriminator],
                           self.hps.max_grad_norm)
                 self.patch_opt.step()
                 # print info
                 # NOTE(review): real_acc / fake_acc are also undefined here.
                 info = {
                     f'{flag}/D_patch_w_dis': patch_w_dis.data[0],
                     f'{flag}/patch_gp': patch_gp.data[0],
                     f'{flag}/c_loss': c_loss.data[0],
                     f'{flag}/real_acc': real_acc,
                     f'{flag}/fake_acc': fake_acc,
                 }
                 slot_value = (step, iteration + 1, hps.iters) + \
                         tuple([value for value in info.values()])
                 log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                 print(log % slot_value)
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration)
         #===================== Train G =====================#
         data = next(self.data_loader)
         (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                      x_j) = self.permute_data(data)
         # encode
         enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
             x_i_t, x_i_tk, x_i_prime, x_j)
         # decode
         x_tilde = self.decode_step(enc_i_t, c_i)
         loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
         # latent discriminate
         loss_adv = self.latent_discriminate_step(enc_i_t,
                                                  enc_i_tk,
                                                  enc_i_prime,
                                                  enc_j,
                                                  is_dis=False)
         ae_loss = loss_rec + current_alpha * loss_adv
         reset_grad([self.Encoder, self.Decoder])
         # keep the graph alive when the patch-G backward below reuses it.
         retain_graph = True if hps.n_patch_steps > 0 else False
         ae_loss.backward(retain_graph=retain_graph)
         grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
         self.ae_opt.step()
         info = {
             f'{flag}/loss_rec': loss_rec.data[0],
             f'{flag}/loss_adv': loss_adv.data[0],
             f'{flag}/alpha': current_alpha,
         }
         slot_value = (iteration + 1, hps.iters) + tuple(
             [value for value in info.values()])
         log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
         print(log % slot_value)
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration + 1)
         # patch discriminate
         if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
             c_sample = self.sample_c(x_i_t.size(0))
             x_tilde = self.decode_step(enc_i_t, c_sample)
             patch_w_dis, real_logits, fake_logits = \
                     self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
             # NOTE(review): c_loss undefined here as well — NameError.
             patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
             reset_grad([self.Decoder])
             patch_loss.backward()
             grad_clip([self.Decoder], self.hps.max_grad_norm)
             self.decoder_opt.step()
             info = {
                 f'{flag}/G_patch_w_dis': patch_w_dis.data[0],
                 f'{flag}/c_loss': c_loss.data[0],
                 f'{flag}/real_acc': real_acc,
                 f'{flag}/fake_acc': fake_acc,
             }
             slot_value = (iteration + 1, hps.iters) + tuple(
                 [value for value in info.values()])
             log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
             print(log % slot_value)
             for tag, value in info.items():
                 self.logger.scalar_summary(tag, value, iteration + 1)
         if iteration % 1000 == 0 or iteration + 1 == hps.iters:
             self.save_model(model_path, iteration)
def train():
    """Run one training epoch for the message classifier and, when
    ``args.user_classifier != 'none'``, the user classifier.

    Relies on module-level state: ``args``, ``message_model``,
    ``user_model``, ``message_optimizer``, ``user_optimizer``, the
    ``train_*`` / ``LB_train`` / ``UB_train`` arrays, and the helpers
    ``get_batch``, ``repackage_hidden`` and ``utils``.

    Each full-sized batch performs two alternating updates:
      1. a message-model step (the autograd graph is retained so the
         same forward results feed the user-model loss), then
      2. a user-model step on freshly recomputed forward passes.
    """
    # Enable training mode (dropout, batch-norm running stats, etc.).
    message_model.train()
    if args.user_classifier != 'none':
        user_model.train()

    if args.message_classifier == 'rnn':
        hidden = message_model.init_hidden(args.batch_size)

    # Renamed from `iter` to avoid shadowing the builtin.
    num_batches = int(args.num_training / args.batch_size)
    for indx in range(num_batches):
        y_mssg = []
        y_usr = []

        x_message = Variable(get_batch(train_message, indx))

        if args.user_classifier != 'none':
            x_user = Variable(get_batch(train_user, indx))

        if args.cuda:
            x_message = x_message.cuda()
            if args.user_classifier != 'none':
                x_user = x_user.cuda()

        # Skip the final, possibly ragged batch.
        if (x_message.size())[0] == args.batch_size:
            ## Update message learner parameters
            if args.message_classifier == 'rnn':
                # Detach the hidden state so BPTT does not span batches.
                hidden1 = repackage_hidden(hidden)
                y_msg_rnn, hidden = message_model(x_message.t(), hidden1)
                y_mssg.append(y_msg_rnn)

            elif args.message_classifier == 'emb':
                y_msg_emd = message_model(x_message.t())
                y_mssg.append(y_msg_emd)

            else:
                y_msg = message_model(x_message)
                y_mssg.append(y_msg)

            if args.user_classifier == 'emb':
                y_user = user_model(x_user.t())
                y_usr = [y_user]
            elif args.user_classifier == 'node2vec':
                y_user = user_model(x_user)
                y_usr = [y_user]

            # Lower/upper bound targets for the criterion.
            lb = torch.FloatTensor(get_batch(LB_train, indx))
            ub = torch.FloatTensor(get_batch(UB_train, indx))

            loss, _, _ = utils.cross_entropy_criterion(y_mssg, y_usr, lb, ub, args.cuda)

            # FIX: `retain_variables` was renamed to `retain_graph` in
            # PyTorch 0.2 and removed in 0.4 (it raised TypeError there).
            # The graph must be kept alive for the user-model backward
            # pass below, which reuses parts of the same forward graph.
            loss.backward(retain_graph=True)

            message_optimizer.step()

            # Clear gradients after stepping so the next backward starts
            # from zero (project-local convention via utils.reset_grad).
            utils.reset_grad(message_model.parameters())
            if args.user_classifier != 'none':
                utils.reset_grad(user_model.parameters())

            ## Update user learner parameters if there is user learner
            if args.user_classifier != 'none':

                # Recompute forward passes so the user step has a fresh graph.
                if args.message_classifier == 'doc2vec' or args.message_classifier == 'bow':
                    y_msg = message_model(x_message)
                    y_mssg = [y_msg]

                elif args.message_classifier == 'rnn':
                    hidden1 = repackage_hidden(hidden)
                    y_msg, hidden = message_model(x_message.t(), hidden1)
                    y_mssg = [y_msg]

                elif args.message_classifier == 'emb':
                    y_msg = message_model(x_message.t())
                    y_mssg = [y_msg]

                if args.user_classifier == 'emb':
                    y_user = user_model(x_user.t())
                    y_usr = [y_user]
                else:
                    y_user = user_model(x_user)
                    y_usr = [y_user]

                loss, _, _ = utils.cross_entropy_criterion(y_mssg, y_usr, lb, ub, args.cuda)

                loss.backward()

                user_optimizer.step()

                utils.reset_grad(message_model.parameters())
                utils.reset_grad(user_model.parameters())
def train(all_models, training_models, solver, training_params, log_every,
          **kwargs):
    """Jointly train a VAE, a contrastive (CPC-style) model and an actor
    on trajectory data loaded from ``.npy`` files.

    Parameters
    ----------
    all_models : 3-tuple ``(model, c_model, actor)``. All three are run
        forward, but a model only accumulates a loss/gradient if it is
        also contained in ``training_models``.
    training_models : collection selecting which of the three models are
        optimized during this run.
    solver : optimizer stepped once per batch; may be ``None``, in which
        case no parameter update happens.
    training_params : parameters passed to ``reset_grad`` after each step.
    log_every : int, scalar/image logging and checkpointing period (in
        batches).
    **kwargs : experiment configuration — keys read here: ``k``,
        ``n_epochs``, ``batch_size``, ``N``, ``c_type``, ``vae_w``,
        ``vae_b``, ``savepath``, ``conditional``, ``data_dir``,
        ``test_dir``, ``use_o_neg``, ``pretrain``.
    """
    model, c_model, actor = all_models
    k_steps = kwargs["k"]
    num_epochs = kwargs["n_epochs"]
    batch_size = kwargs["batch_size"]
    N = kwargs["N"]
    c_type = kwargs["c_type"]
    vae_weight = kwargs["vae_w"]
    beta = kwargs["vae_b"]

    # Configure experiment path
    savepath = kwargs['savepath']

    conditional = kwargs["conditional"]

    # Tensorboard-style logger writing under <savepath>/var_log.
    configure('%s/var_log' % savepath, flush_secs=5)

    ### Load data ### -- assuming appropriate npy format
    # NOTE(review): `data` looks like an object array of variable-length
    # trajectories (len() differs per entry) — confirm the .npy layout.
    data_file = kwargs["data_dir"]
    data = np.load(data_file)
    n_trajs = len(data)
    # Each trajectory contributes len - k_steps usable transitions, since
    # a sample at time t needs a target up to t + k_steps.
    data_size = sum([len(data[i]) - k_steps for i in range(n_trajs)])
    print('Number of trajectories: %d' % n_trajs)  # 315
    print('Number of transitions: %d' % data_size)  # 378315

    # Held-out contexts used only for image logging below.
    test_file = kwargs["test_dir"]
    test_data = np.load(test_file)
    test_context = get_torch_images_from_numpy(test_data,
                                               conditional,
                                               one_image=True)

    ### Train models ###
    # Zero-initialized so log_info has values even before a model trains.
    c_loss = vae_loss = a_loss = torch.Tensor([0]).cuda()
    for epoch in range(num_epochs):
        n_batch = int(data_size / batch_size)
        print('********** Epoch %i ************' % epoch)
        for it in range(n_batch):
            # Sample (trajectory index, timestep) pairs for this batch.
            idx, t = get_idx_t(batch_size, k_steps, n_trajs, data)
            o, c = get_torch_images_from_numpy(data[idx, t], conditional)
            # Random lookahead in [0, k_steps) per sample.
            ks = np.random.choice(k_steps, batch_size)
            o_next, _ = get_torch_images_from_numpy(data[idx, t + ks],
                                                    conditional)
            # Negatives for the contrastive loss (optional).
            o_neg = get_negative_examples(
                data, idx, batch_size, N,
                conditional) if kwargs["use_o_neg"] else None
            o_pred, mu, logvar, cond_info = model(o, c)
            o_next_pred, _, _, _ = model(o_next, c)

            # VAE loss
            if model in training_models:
                vae_loss = loss_function(o_pred,
                                         o,
                                         mu,
                                         logvar,
                                         cond_info.get("means_cond", None),
                                         cond_info.get("log_var_cond", None),
                                         beta=beta) * vae_weight
                vae_loss.backward()

            # C loss
            # Contrastive model only starts after the VAE pretrain phase.
            if c_model in training_models and epoch >= kwargs["pretrain"]:
                c_loss = get_c_loss(model, c_model, c_type, o_pred,
                                    o_next_pred, c, N, o_neg)
                c_loss.backward()

            # Actor loss
            if actor in training_models and epoch >= kwargs["pretrain"]:
                a = get_torch_actions(data[idx, t + 1])
                a_loss = actor.loss(a, o, o_next, c)
                a_loss.backward()

            ### Update models ###
            # Gradients from all active losses accumulate before this
            # single step; reset_grad clears them afterwards.
            if solver is not None:
                solver.step()
            reset_grad(training_params)

            if it % log_every == 0:
                ### Log info ###
                log_info(c_loss, vae_loss, a_loss, model, conditional,
                         cond_info, it, n_batch, epoch)

                ### Save params ###
                # Rolling checkpoints: only the last 5 epochs are kept
                # (epoch % 5 + 1 cycles the filename suffix).
                if not os.path.exists('%s/var' % savepath):
                    os.makedirs('%s/var' % savepath)
                torch.save(model.state_dict(),
                           '%s/var/vae-%d-last-5' % (savepath, epoch % 5 + 1))
                torch.save(c_model.state_dict(),
                           '%s/var/cpc-%d-last-5' % (savepath, epoch % 5 + 1))
                torch.save(
                    actor.state_dict(),
                    '%s/var/actor-%d-last-5' % (savepath, epoch % 5 + 1))

                ### Log images ###
                # No gradients needed for visualization-only forward passes.
                with torch.no_grad():
                    n_contexts = 7
                    n_samples_per_c = 8
                    o_distinct_c = get_negative_examples(
                        data, idx[:n_contexts], n_contexts, n_samples_per_c,
                        conditional)
                    log_images(
                        o[:n_contexts], o_pred[:n_contexts],
                        o_distinct_c.reshape(n_samples_per_c, n_contexts,
                                             *o_distinct_c.size()[1:]),
                        c[:n_contexts], test_context, model, c_model,
                        n_contexts, n_samples_per_c, savepath, epoch)