Example #1
0
    # Update the feature discriminator. retain_graph=True because pd_loss
    # below backpropagates through (parts of) the same computation graph.
    fd_optimizer.zero_grad()
    fd_loss.backward(retain_graph=True)
    fd_optimizer.step()

    # Update the patch discriminator (final backward pass through this graph).
    pd_optimizer.zero_grad()
    pd_loss.backward()
    pd_optimizer.step()

    # Periodically checkpoint the generator and both discriminators.
    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        #torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_{i + 1}.pth')
        #torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_{i + 1}.pth')
        #torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_{i + 1}.pth')
        # NOTE(review): the filenames are hard-coded to *_10000.pth, so every
        # save overwrites the same files; the commented lines above suggest
        # per-iteration names were intended — confirm whether keeping only the
        # latest checkpoint is deliberate.
        torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_10000.pth')
        torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_10000.pth')
        torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_10000.pth')

    # Emit scalar loss summaries to TensorBoard every log_interval iterations.
    if (i + 1) % args.log_interval == 0:
        writer.add_scalar('g_loss/recon_loss', recon_loss.item(), i + 1)
        writer.add_scalar('g_loss/cons_loss', cons_loss.item(), i + 1)
        writer.add_scalar('g_loss/gan_loss', gan_loss.item(), i + 1)
        writer.add_scalar('g_loss/total_loss', total_loss.item(), i + 1)
        writer.add_scalar('d_loss/fd_loss', fd_loss.item(), i + 1)
        writer.add_scalar('d_loss/pd_loss', pd_loss.item(), i + 1)

    # NOTE(review): denorm is re-defined on every loop iteration; hoisting it
    # above the loop would be cleaner.
    def denorm(x):
        """Map a tensor from [-1, 1] to [0, 1], clamping in place."""
        out = (x + 1) / 2 # [-1,1] -> [0,1]
        return out.clamp_(0, 1)
    if (i + 1) % args.vis_interval == 0:
        # Concatenate along dim 3 (width, assuming NCHW images — TODO confirm)
        # to show original | masked input | refined output side by side.
        ims = torch.cat([img, masked, refine_result], dim=3)
        writer.add_images('raw_masked_refine', denorm(ims), i + 1)
Example #2
0
class Solver(object):
    """Training harness for a voice-conversion GAN.

    Holds an Encoder/Decoder autoencoder, an auxiliary Generator, and two
    WGAN-GP critics (a latent discriminator over encoding pairs and a patch
    discriminator over spectrogram frames), with one Adam optimizer per
    trainable group and a Logger for scalar summaries.
    """

    def __init__(self, hps, data_loader, log_dir='./log/'):
        self.hps = hps                  # hyper-parameter namespace
        self.data_loader = data_loader  # iterator yielding training batches
        self.model_kept = []            # checkpoint paths written so far, oldest first
        self.max_keep = 20              # maximum number of checkpoints kept on disk
        self.build_model()
        self.logger = Logger(log_dir)

    def build_model(self):
        """Instantiate all networks, move them to GPU when available, and
        create the per-group Adam optimizers."""
        hps = self.hps
        ns = self.hps.ns
        emb_size = self.hps.emb_size
        self.Encoder = Encoder(ns=ns, dp=hps.enc_dp)
        self.Decoder = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
        self.Generator = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
        self.LatentDiscriminator = LatentDiscriminator(ns=ns, dp=hps.dis_dp)
        self.PatchDiscriminator = PatchDiscriminator(ns=ns,
                                                     n_class=hps.n_speakers)
        if torch.cuda.is_available():
            self.Encoder.cuda()
            self.Decoder.cuda()
            self.Generator.cuda()
            self.LatentDiscriminator.cuda()
            self.PatchDiscriminator.cuda()
        betas = (0.5, 0.9)  # Adam betas commonly used for WGAN-GP training
        params = list(self.Encoder.parameters()) + list(
            self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
        self.gen_opt = optim.Adam(self.Generator.parameters(),
                                  lr=self.hps.lr,
                                  betas=betas)
        self.lat_opt = optim.Adam(self.LatentDiscriminator.parameters(),
                                  lr=self.hps.lr,
                                  betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                    lr=self.hps.lr,
                                    betas=betas)
        # BUGFIX: train() calls self.decoder_opt.step() for the patch-GAN
        # decoder update, but this optimizer was never created, raising
        # AttributeError at runtime. Decoder-only, same hyper-parameters.
        self.decoder_opt = optim.Adam(self.Decoder.parameters(),
                                      lr=self.hps.lr,
                                      betas=betas)

    def save_model(self, model_path, iteration, enc_only=True):
        """Serialize weights to '<model_path>-<iteration>' and prune the
        oldest checkpoint once more than self.max_keep exist.

        enc_only: when True, skip the discriminators' weights.
        """
        if not enc_only:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'generator': self.Generator.state_dict(),
                'latent_discriminator': self.LatentDiscriminator.state_dict(),
                'patch_discriminator': self.PatchDiscriminator.state_dict(),
            }
        else:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'generator': self.Generator.state_dict(),
            }
        new_model_path = '{}-{}'.format(model_path, iteration)
        with open(new_model_path, 'wb') as f_out:
            torch.save(all_model, f_out)
        self.model_kept.append(new_model_path)

        # BUGFIX: was `>=`, which capped retention at max_keep - 1 files.
        if len(self.model_kept) > self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def load_model(self, model_path, enc_only=True):
        """Load encoder/decoder weights (and discriminators when
        enc_only=False) from a checkpoint written by save_model."""
        print('load model from {}'.format(model_path))
        with open(model_path, 'rb') as f_in:
            all_model = torch.load(f_in)
            self.Encoder.load_state_dict(all_model['encoder'])
            self.Decoder.load_state_dict(all_model['decoder'])
            #self.Generator.load_state_dict(all_model['generator'])
            if not enc_only:
                self.LatentDiscriminator.load_state_dict(
                    all_model['latent_discriminator'])
                self.PatchDiscriminator.load_state_dict(
                    all_model['patch_discriminator'])

    def set_eval(self):
        """Switch the inference-path networks to eval mode."""
        self.Encoder.eval()
        self.Decoder.eval()
        self.Generator.eval()
        #self.LatentDiscriminator.eval()

    def test_step(self, x, c):
        """Encode x, decode it conditioned on speaker c, and return the
        reconstruction as a numpy array."""
        self.set_eval()
        x = to_var(x).permute(0, 2, 1)  # (B, T, F) -> (B, F, T) - TODO confirm
        enc = self.Encoder(x)
        x_tilde = self.Decoder(enc, c)
        return x_tilde.data.cpu().numpy()

    def permute_data(self, data):
        """Split a raw batch into speaker labels C (first two entries) and
        permuted utterance tensors X (the rest)."""
        C = [to_var(c, requires_grad=False) for c in data[:2]]
        X = [to_var(x).permute(0, 2, 1) for x in data[2:]]
        return C, X

    def sample_c(self, size):
        """Draw `size` speaker ids uniformly at random.

        Generalized: the class count was hard-coded to 8; it now follows
        hps.n_speakers (identical behavior when n_speakers == 8).
        """
        c_sample = Variable(torch.multinomial(torch.ones(self.hps.n_speakers),
                                              num_samples=size,
                                              replacement=True),
                            requires_grad=False)
        c_sample = c_sample.cuda() if torch.cuda.is_available() else c_sample
        return c_sample

    def cal_acc(self, logits, y_true):
        """Return classification accuracy of `logits` against `y_true`."""
        _, ind = torch.max(logits, dim=1)
        acc = torch.sum(
            (ind == y_true).type(torch.FloatTensor)) / y_true.size(0)
        return acc

    def encode_step(self, *args):
        """Encode each input tensor; returns a tuple of encodings."""
        enc_list = []
        for x in args:
            enc = self.Encoder(x)
            enc_list.append(enc)
        return tuple(enc_list)

    def decode_step(self, enc, c):
        """Decode encoding `enc` conditioned on speaker label `c`."""
        x_tilde = self.Decoder(enc, c)
        return x_tilde

    def latent_discriminate_step(self,
                                 enc_i_t,
                                 enc_i_tk,
                                 enc_i_prime,
                                 enc_j,
                                 is_dis=True):
        """Latent-space WGAN step over (same-speaker, different-speaker)
        encoding pairs.

        is_dis=True: return (Wasserstein distance, gradient penalty) for the
        discriminator update. is_dis=False: return the adversarial loss for
        the encoder update.
        """
        same_pair = torch.cat([enc_i_t, enc_i_tk], dim=1)
        diff_pair = torch.cat([enc_i_prime, enc_j], dim=1)
        if is_dis:
            same_val = self.LatentDiscriminator(same_pair)
            diff_val = self.LatentDiscriminator(diff_pair)
            w_dis = torch.mean(same_val - diff_val)
            gp = calculate_gradients_penalty(self.LatentDiscriminator,
                                             same_pair, diff_pair)
            return w_dis, gp
        else:
            diff_val = self.LatentDiscriminator(diff_pair)
            loss_adv = -torch.mean(diff_val)
            return loss_adv

    def patch_discriminate_step(self, x, x_tilde, cal_gp=True):
        """Patch-discriminator WGAN step: Wasserstein distance between real x
        and reconstruction x_tilde, class logits for both, and optionally the
        gradient penalty."""
        # w-distance
        D_real, real_logits = self.PatchDiscriminator(x, classify=True)
        D_fake, fake_logits = self.PatchDiscriminator(x_tilde, classify=True)
        w_dis = torch.mean(D_real - D_fake)
        if cal_gp:
            gp = calculate_gradients_penalty(self.PatchDiscriminator, x,
                                             x_tilde)
            return w_dis, real_logits, fake_logits, gp
        else:
            return w_dis, real_logits, fake_logits

    def _classify_loss(self, real_logits, fake_logits, c_real, c_fake):
        """Auxiliary speaker-classification loss and accuracies.

        Restored from the commented-out `classify` backup snippet: NLL loss on
        the discriminator's class outputs (assumed to be log-probabilities —
        confirm PatchDiscriminator applies log_softmax) plus accuracies for
        real and fake batches.
        """
        c_loss = torch.nn.functional.nll_loss(real_logits, c_real) + \
            torch.nn.functional.nll_loss(fake_logits, c_fake)
        real_acc = self.cal_acc(real_logits, c_real)
        fake_acc = self.cal_acc(fake_logits, c_fake)
        return c_loss, real_acc, fake_acc

    def train(self, model_path, flag='train'):
        """Main training loop: alternates latent-discriminator, patch-
        discriminator, autoencoder/adversarial, and patch-GAN decoder updates,
        logging scalars and checkpointing every 1000 iterations.
        """
        # NOTE(review): the `.data[0]` reads below assume a legacy (<=0.3)
        # PyTorch; on modern versions `.item()` is required instead.
        # load hyperparams
        hps = self.hps
        for iteration in range(hps.iters):
            # calculate current alpha (linear ramp of the latent-adversarial
            # weight after encoder pretraining)
            if iteration + 1 < hps.lat_sched_iters and iteration >= hps.enc_pretrain_iters:
                current_alpha = hps.alpha_enc * (
                    iteration + 1 - hps.enc_pretrain_iters) / (
                        hps.lat_sched_iters - hps.enc_pretrain_iters)
            else:
                current_alpha = 0
            if iteration >= hps.enc_pretrain_iters:
                n_latent_steps = hps.n_latent_steps \
                    if iteration > hps.enc_pretrain_iters else hps.dis_pretrain_iters
                for step in range(n_latent_steps):
                    #===================== Train latent discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                                 x_j) = self.permute_data(data)
                    # encode
                    enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                        x_i_t, x_i_tk, x_i_prime, x_j)
                    # latent discriminate
                    latent_w_dis, latent_gp = self.latent_discriminate_step(
                        enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                    lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                    reset_grad([self.LatentDiscriminator])
                    lat_loss.backward()
                    grad_clip([self.LatentDiscriminator],
                              self.hps.max_grad_norm)
                    self.lat_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_latent_w_dis': latent_w_dis.data[0],
                        f'{flag}/latent_gp': latent_gp.data[0],
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                            tuple([value for value in info.values()])
                    log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            # two stage training
            if iteration >= hps.patch_start_iter:
                for step in range(hps.n_patch_steps):
                    #===================== Train patch discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                    # encode
                    enc_i_t, = self.encode_step(x_i_t)
                    c_sample = self.sample_c(x_i_t.size(0))
                    x_tilde = self.decode_step(enc_i_t, c_i)
                    # Aux classify loss
                    patch_w_dis, real_logits, fake_logits, patch_gp = \
                            self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=True)
                    # BUGFIX: c_loss/real_acc/fake_acc were referenced below
                    # but never computed (the classify code was commented
                    # out), raising NameError whenever this branch ran.
                    # NOTE(review): x_tilde is decoded with c_i while fake
                    # logits are scored against c_sample, mirroring the
                    # original backup snippet — confirm this pairing.
                    c_loss, real_acc, fake_acc = self._classify_loss(
                        real_logits, fake_logits, c_i, c_sample)
                    patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp + hps.beta_clf * c_loss
                    reset_grad([self.PatchDiscriminator])
                    patch_loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_patch_w_dis': patch_w_dis.data[0],
                        f'{flag}/patch_gp': patch_gp.data[0],
                        f'{flag}/c_loss': c_loss.data[0],
                        f'{flag}/real_acc': real_acc,
                        f'{flag}/fake_acc': fake_acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                            tuple([value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            #===================== Train G =====================#
            data = next(self.data_loader)
            (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                         x_j) = self.permute_data(data)
            # encode
            enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                x_i_t, x_i_tk, x_i_prime, x_j)
            # decode
            x_tilde = self.decode_step(enc_i_t, c_i)
            loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
            # latent discriminate (adversarial term for the encoder)
            loss_adv = self.latent_discriminate_step(enc_i_t,
                                                     enc_i_tk,
                                                     enc_i_prime,
                                                     enc_j,
                                                     is_dis=False)
            ae_loss = loss_rec + current_alpha * loss_adv
            reset_grad([self.Encoder, self.Decoder])
            # keep the graph alive when the patch-GAN decoder update below
            # will backprop through the same encodings
            retain_graph = True if hps.n_patch_steps > 0 else False
            ae_loss.backward(retain_graph=retain_graph)
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.data[0],
                f'{flag}/loss_adv': loss_adv.data[0],
                f'{flag}/alpha': current_alpha,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)
            # patch discriminate (decoder update against the patch critic)
            if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_sample)
                patch_w_dis, real_logits, fake_logits = \
                        self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
                # BUGFIX: same undefined-name repair as in the patch-D step.
                c_loss, real_acc, fake_acc = self._classify_loss(
                    real_logits, fake_logits, c_i, c_sample)
                patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
                reset_grad([self.Decoder])
                patch_loss.backward()
                grad_clip([self.Decoder], self.hps.max_grad_norm)
                self.decoder_opt.step()
                info = {
                    f'{flag}/G_patch_w_dis': patch_w_dis.data[0],
                    f'{flag}/c_loss': c_loss.data[0],
                    f'{flag}/real_acc': real_acc,
                    f'{flag}/fake_acc': fake_acc,
                }
                slot_value = (iteration + 1, hps.iters) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                self.save_model(model_path, iteration)