Example #1
 def build_model(self):
     hps = self.hps
     ns = self.hps.ns
     emb_size = self.hps.emb_size
     self.Encoder = Encoder(ns=ns, dp=hps.enc_dp)
     self.Decoder = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
     self.Generator = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
     self.LatentDiscriminator = LatentDiscriminator(ns=ns, dp=hps.dis_dp)
     self.PatchDiscriminator = PatchDiscriminator(ns=ns,
                                                  n_class=hps.n_speakers)
     if torch.cuda.is_available():
         self.Encoder.cuda()
         self.Decoder.cuda()
         self.Generator.cuda()
         self.LatentDiscriminator.cuda()
         self.PatchDiscriminator.cuda()
     betas = (0.5, 0.9)
     params = list(self.Encoder.parameters()) + list(
         self.Decoder.parameters())
     self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
     self.gen_opt = optim.Adam(self.Generator.parameters(),
                               lr=self.hps.lr,
                               betas=betas)
     self.lat_opt = optim.Adam(self.LatentDiscriminator.parameters(),
                               lr=self.hps.lr,
                               betas=betas)
     self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                 lr=self.hps.lr,
                                 betas=betas)
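
Examples #2 through #6 below wrap every module constructor in a cc(...) helper instead of calling .cuda() on each module as above. The helper itself is not shown on this page; a minimal sketch, assuming it only moves the module to the GPU when one is available:

import torch

def cc(net):
    # hypothetical stand-in for the cc() helper used in the examples below:
    # move the module to the GPU when CUDA is available, otherwise keep it on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return net.to(device)
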
Example #2
 def build_model(self):
     hps = self.hps
     ns = self.hps.ns
     emb_size = self.hps.emb_size
     self.Encoder = cc(Encoder(ns=ns, dp=hps.enc_dp))
     self.Decoder = cc(Decoder(ns=ns, c_a=hps.n_speakers,
                               emb_size=emb_size))
     self.Generator = cc(
         Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
     self.SpeakerClassifier = cc(
         SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
     self.PatchDiscriminator = cc(
         nn.DataParallel(PatchDiscriminator(ns=ns, n_class=hps.n_speakers)))
     betas = (0.5, 0.9)
     params = list(self.Encoder.parameters()) + list(
         self.Decoder.parameters())
     self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
     self.clf_opt = optim.Adam(self.SpeakerClassifier.parameters(),
                               lr=self.hps.lr,
                               betas=betas)
     self.gen_opt = optim.Adam(self.Generator.parameters(),
                               lr=self.hps.lr,
                               betas=betas)
     self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                 lr=self.hps.lr,
                                 betas=betas)
Example #3
 def build_model(self, wavenet_mel):
     hps = self.hps
     ns = self.hps.ns
     emb_size = self.hps.emb_size
     c = 80 if wavenet_mel else 513
     patch_classify_kernel = (3, 4) if wavenet_mel else (17, 4)
     self.Encoder = cc(Encoder(c_in=c, ns=ns, dp=hps.enc_dp))
     self.Decoder = cc(
         Decoder(c_out=c, ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
     self.Generator = cc(
         Decoder(c_out=c, ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
     self.SpeakerClassifier = cc(
         SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
     self.PatchDiscriminator = cc(
         nn.DataParallel(
             PatchDiscriminator(
                 ns=ns,
                 n_class=hps.n_speakers,
                 classify_kernel_size=patch_classify_kernel)))
     betas = (0.5, 0.9)
     params = list(self.Encoder.parameters()) + list(
         self.Decoder.parameters())
     self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
     self.clf_opt = optim.Adam(self.SpeakerClassifier.parameters(),
                               lr=self.hps.lr,
                               betas=betas)
     self.gen_opt = optim.Adam(self.Generator.parameters(),
                               lr=self.hps.lr,
                               betas=betas)
     self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                 lr=self.hps.lr,
                                 betas=betas)
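
In Examples #2 and #3 the PatchDiscriminator is additionally wrapped in nn.DataParallel, so its state_dict keys carry a 'module.' prefix. A small helper sketch (an assumption, not part of these projects) for loading such a checkpoint back into an unwrapped module:

def strip_data_parallel_prefix(state_dict):
    # nn.DataParallel registers the wrapped network as .module, so every key in its
    # state_dict starts with 'module.'; remove that prefix so the weights can be
    # loaded into a plain, unwrapped PatchDiscriminator
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# usage (hypothetical path and module):
#   state = torch.load('patch_d.pth', map_location='cpu')
#   plain_patch_d.load_state_dict(strip_data_parallel_prefix(state))
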
Example #4
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
])

train_set = DS(args.root, train_tf)
iterator_train = iter(data.DataLoader(
    train_set,
    batch_size=args.batch_size,
    sampler=InfiniteSampler(len(train_set)),
    num_workers=args.n_threads))
print(len(train_set))

g_model = InpaintNet().to(device)
fd_model = FeaturePatchDiscriminator().to(device)
pd_model = PatchDiscriminator().to(device)
l1 = nn.L1Loss().to(device)
cons = ConsistencyLoss().to(device)

start_iter = 0
g_optimizer = torch.optim.Adam(
    g_model.parameters(),
    args.lr, (args.b1, args.b2))
fd_optimizer = torch.optim.Adam(
    fd_model.parameters(),
    args.lr, (args.b1, args.b2))
pd_optimizer = torch.optim.Adam(
    pd_model.parameters(),
    args.lr, (args.b1, args.b2))

if args.resume:
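
Example #4 drives training through iter(...) plus an InfiniteSampler, so next(iterator_train) never raises StopIteration. The sampler is imported from the project and not shown here; a minimal sketch, assuming it simply yields shuffled dataset indices forever:

import numpy as np
from torch.utils.data import Sampler

class InfiniteSampler(Sampler):
    # endlessly yields random permutations of range(num_samples), so a DataLoader
    # built on top of it can be advanced with next() for as many iterations as needed
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        while True:
            for idx in np.random.permutation(self.num_samples):
                yield int(idx)

    def __len__(self):
        return 2 ** 31  # effectively unbounded
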
Example #5
                            WORDS_NUM=args.WORDS_NUM,
                            BRANCH_NUM=args.BRANCH_NUM,
                            transform=train_tf)
assert dataset_train
train_set = torch.utils.data.DataLoader(dataset_train,
                                        batch_size=args.batch_size,
                                        drop_last=True,
                                        shuffle=True,
                                        num_workers=args.n_threads)

print(len(train_set))

ixtoword_train = dataset_train.ixtoword

g_model = InpaintNet().to(device)
pd_model = PatchDiscriminator().to(device)
l1 = nn.L1Loss().to(device)

start_epoch = 0
g_optimizer_t = torch.optim.Adam(g_model.parameters(), args.lr,
                                 (args.b1, args.b2))
pd_optimizer_t = torch.optim.Adam(pd_model.parameters(), args.lr,
                                  (args.b1, args.b2))

if args.resume:
    g_checkpoint = torch.load(f'{args.save_dir}/ckpt/G_{args.resume}.pth',
                              map_location=device)
    g_model.load_state_dict(g_checkpoint)
    pd_checkpoint = torch.load(f'{args.save_dir}/ckpt/PD_{args.resume}.pth',
                               map_location=device)
    pd_model.load_state_dict(pd_checkpoint)
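
The resume branch above expects bare state_dict checkpoints named G_<tag>.pth and PD_<tag>.pth under {args.save_dir}/ckpt/. The matching save step is not shown; a minimal sketch that mirrors the load calls, where the epoch variable and the per-epoch timing are assumptions:

import os

# hypothetical counterpart of the load code above, run e.g. at the end of each epoch
os.makedirs(f'{args.save_dir}/ckpt', exist_ok=True)
torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_{epoch}.pth')
torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_{epoch}.pth')
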
Example #6
    def build_model(self):
        hps = self.hps
        ns = self.hps.ns
        emb_size = self.hps.emb_size
        betas = (0.5, 0.9)

        #---stage one---#
        self.Encoder = cc(
            Encoder(ns=ns,
                    dp=hps.enc_dp,
                    emb_size=emb_size,
                    seg_len=hps.seg_len,
                    one_hot=self.one_hot,
                    binary_output=self.binary_output,
                    binary_ver=self.binary_ver))
        self.Decoder = cc(
            Decoder(ns=ns,
                    c_in=emb_size,
                    c_h=emb_size,
                    c_a=hps.n_speakers,
                    seg_len=hps.seg_len,
                    inp_emb=self.one_hot or self.binary_output))
        self.SpeakerClassifier = cc(
            SpeakerClassifier(
                ns=ns,
                c_in=emb_size if not self.binary_output else emb_size *
                emb_size,
                c_h=emb_size,
                n_class=hps.n_speakers,
                dp=hps.dis_dp,
                seg_len=hps.seg_len))

        #---stage one opts---#
        params = list(self.Encoder.parameters()) + \
            list(self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
        self.clf_opt = optim.Adam(self.SpeakerClassifier.parameters(),
                                  lr=self.hps.lr,
                                  betas=betas)

        #---stage two---#
        self.Generator = cc(
            Decoder(ns=ns,
                    c_in=emb_size,
                    c_h=emb_size,
                    c_a=hps.n_speakers
                    if not self.targeted_G else hps.n_target_speakers))
        self.PatchDiscriminator = cc(
            nn.DataParallel(
                PatchDiscriminator(
                    ns=ns,
                    n_class=hps.n_speakers
                    if not self.targeted_G else hps.n_target_speakers,
                    seg_len=hps.seg_len)))

        #---stage two opts---#
        self.gen_opt = optim.Adam(self.Generator.parameters(),
                                  lr=self.hps.lr,
                                  betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                    lr=self.hps.lr,
                                    betas=betas)
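
Example #6 reads only a handful of fields from self.hps: ns, enc_dp, dis_dp, emb_size, seg_len, n_speakers, n_target_speakers and lr. A minimal sketch of a matching hyperparameter object; the field names come from the snippet, while the values are placeholders rather than the settings used by any of these projects:

from types import SimpleNamespace

hps = SimpleNamespace(
    ns=0.01,               # placeholder values only
    enc_dp=0.5,
    dis_dp=0.1,
    emb_size=128,
    seg_len=128,
    n_speakers=8,
    n_target_speakers=2,   # only read when targeted_G is set
    lr=1e-4,
)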
Example #7
import os
import torch
from torch import nn, optim
from torch.autograd import Variable
# Encoder, Decoder, LatentDiscriminator, PatchDiscriminator, Logger, to_var,
# reset_grad, grad_clip and calculate_gradients_penalty come from the project's
# own modules (not shown here).


class Solver(object):
    def __init__(self, hps, data_loader, log_dir='./log/'):
        self.hps = hps
        self.data_loader = data_loader
        self.model_kept = []
        self.max_keep = 20
        self.build_model()
        self.logger = Logger(log_dir)

    def build_model(self):
        hps = self.hps
        ns = self.hps.ns
        emb_size = self.hps.emb_size
        self.Encoder = Encoder(ns=ns, dp=hps.enc_dp)
        self.Decoder = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
        self.Generator = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
        self.LatentDiscriminator = LatentDiscriminator(ns=ns, dp=hps.dis_dp)
        self.PatchDiscriminator = PatchDiscriminator(ns=ns,
                                                     n_class=hps.n_speakers)
        if torch.cuda.is_available():
            self.Encoder.cuda()
            self.Decoder.cuda()
            self.Generator.cuda()
            self.LatentDiscriminator.cuda()
            self.PatchDiscriminator.cuda()
        betas = (0.5, 0.9)
        params = list(self.Encoder.parameters()) + list(
            self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
        self.gen_opt = optim.Adam(self.Generator.parameters(),
                                  lr=self.hps.lr,
                                  betas=betas)
        self.lat_opt = optim.Adam(self.LatentDiscriminator.parameters(),
                                  lr=self.hps.lr,
                                  betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                    lr=self.hps.lr,
                                    betas=betas)
        # decoder-only optimizer, referenced as self.decoder_opt in train()
        self.decoder_opt = optim.Adam(self.Decoder.parameters(),
                                      lr=self.hps.lr,
                                      betas=betas)

    def save_model(self, model_path, iteration, enc_only=True):
        if not enc_only:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'generator': self.Generator.state_dict(),
                'latent_discriminator': self.LatentDiscriminator.state_dict(),
                'patch_discriminator': self.PatchDiscriminator.state_dict(),
            }
        else:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'generator': self.Generator.state_dict(),
            }
        new_model_path = '{}-{}'.format(model_path, iteration)
        with open(new_model_path, 'wb') as f_out:
            torch.save(all_model, f_out)
        self.model_kept.append(new_model_path)

        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def load_model(self, model_path, enc_only=True):
        print('load model from {}'.format(model_path))
        with open(model_path, 'rb') as f_in:
            all_model = torch.load(f_in)
            self.Encoder.load_state_dict(all_model['encoder'])
            self.Decoder.load_state_dict(all_model['decoder'])
            #self.Generator.load_state_dict(all_model['generator'])
            if not enc_only:
                self.LatentDiscriminator.load_state_dict(
                    all_model['latent_discriminator'])
                self.PatchDiscriminator.load_state_dict(
                    all_model['patch_discriminator'])

    def set_eval(self):
        self.Encoder.eval()
        self.Decoder.eval()
        self.Generator.eval()
        #self.LatentDiscriminator.eval()

    def test_step(self, x, c):
        self.set_eval()
        x = to_var(x).permute(0, 2, 1)
        enc = self.Encoder(x)
        x_tilde = self.Decoder(enc, c)
        return x_tilde.data.cpu().numpy()

    def permute_data(self, data):
        C = [to_var(c, requires_grad=False) for c in data[:2]]
        X = [to_var(x).permute(0, 2, 1) for x in data[2:]]
        return C, X

    def sample_c(self, size):
        c_sample = Variable(torch.multinomial(torch.ones(self.hps.n_speakers),
                                              num_samples=size,
                                              replacement=True),
                            requires_grad=False)
        c_sample = c_sample.cuda() if torch.cuda.is_available() else c_sample
        return c_sample

    def cal_acc(self, logits, y_true):
        _, ind = torch.max(logits, dim=1)
        acc = torch.sum(
            (ind == y_true).type(torch.FloatTensor)) / y_true.size(0)
        return acc

    def encode_step(self, *args):
        enc_list = []
        for x in args:
            enc = self.Encoder(x)
            enc_list.append(enc)
        return tuple(enc_list)

    def decode_step(self, enc, c):
        x_tilde = self.Decoder(enc, c)
        return x_tilde

    def latent_discriminate_step(self,
                                 enc_i_t,
                                 enc_i_tk,
                                 enc_i_prime,
                                 enc_j,
                                 is_dis=True):
        same_pair = torch.cat([enc_i_t, enc_i_tk], dim=1)
        diff_pair = torch.cat([enc_i_prime, enc_j], dim=1)
        if is_dis:
            same_val = self.LatentDiscriminator(same_pair)
            diff_val = self.LatentDiscriminator(diff_pair)
            w_dis = torch.mean(same_val - diff_val)
            gp = calculate_gradients_penalty(self.LatentDiscriminator,
                                             same_pair, diff_pair)
            return w_dis, gp
        else:
            diff_val = self.LatentDiscriminator(diff_pair)
            loss_adv = -torch.mean(diff_val)
            return loss_adv

    def patch_discriminate_step(self, x, x_tilde, cal_gp=True):
        # w-distance
        D_real, real_logits = self.PatchDiscriminator(x, classify=True)
        D_fake, fake_logits = self.PatchDiscriminator(x_tilde, classify=True)
        w_dis = torch.mean(D_real - D_fake)
        if cal_gp:
            gp = calculate_gradients_penalty(self.PatchDiscriminator, x,
                                             x_tilde)
            return w_dis, real_logits, fake_logits, gp
        else:
            return w_dis, real_logits, fake_logits

    # backup
    #def classify():
    #    # aux classify loss
    #    criterion = nn.NLLLoss()
    #    c_loss = criterion(real_logits, c) + criterion(fake_logits, c_sample)
    #    real_acc = self.cal_acc(real_logits, c)
    #    fake_acc = self.cal_acc(fake_logits, c_sample)

    def train(self, model_path, flag='train'):
        # load hyperparams
        hps = self.hps
        for iteration in range(hps.iters):
            # calculate current alpha
            if iteration + 1 < hps.lat_sched_iters and iteration >= hps.enc_pretrain_iters:
                current_alpha = hps.alpha_enc * (
                    iteration + 1 - hps.enc_pretrain_iters) / (
                        hps.lat_sched_iters - hps.enc_pretrain_iters)
            else:
                current_alpha = 0
            if iteration >= hps.enc_pretrain_iters:
                n_latent_steps = hps.n_latent_steps \
                    if iteration > hps.enc_pretrain_iters else hps.dis_pretrain_iters
                for step in range(n_latent_steps):
                    #===================== Train latent discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                                 x_j) = self.permute_data(data)
                    # encode
                    enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                        x_i_t, x_i_tk, x_i_prime, x_j)
                    # latent discriminate
                    latent_w_dis, latent_gp = self.latent_discriminate_step(
                        enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                    lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                    reset_grad([self.LatentDiscriminator])
                    lat_loss.backward()
                    grad_clip([self.LatentDiscriminator],
                              self.hps.max_grad_norm)
                    self.lat_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_latent_w_dis': latent_w_dis.data[0],
                        f'{flag}/latent_gp': latent_gp.data[0],
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                            tuple([value for value in info.values()])
                    log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            # two stage training
            if iteration >= hps.patch_start_iter:
                for step in range(hps.n_patch_steps):
                    #===================== Train patch discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                    # encode
                    enc_i_t, = self.encode_step(x_i_t)
                    c_sample = self.sample_c(x_i_t.size(0))
                    # decode with the sampled speaker so the fake samples carry the
                    # c_sample label (matches the commented-out classify() backup below)
                    x_tilde = self.decode_step(enc_i_t, c_sample)
                    patch_w_dis, real_logits, fake_logits, patch_gp = \
                            self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=True)
                    # aux classify loss
                    criterion = nn.NLLLoss()
                    c_loss = criterion(real_logits, c_i) + criterion(fake_logits, c_sample)
                    real_acc = self.cal_acc(real_logits, c_i)
                    fake_acc = self.cal_acc(fake_logits, c_sample)
                    patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp \
                            + hps.beta_clf * c_loss
                    reset_grad([self.PatchDiscriminator])
                    patch_loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_patch_w_dis': patch_w_dis.data[0],
                        f'{flag}/patch_gp': patch_gp.data[0],
                        f'{flag}/c_loss': c_loss.data[0],
                        f'{flag}/real_acc': real_acc,
                        f'{flag}/fake_acc': fake_acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                            tuple([value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            #===================== Train G =====================#
            data = next(self.data_loader)
            (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                         x_j) = self.permute_data(data)
            # encode
            enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                x_i_t, x_i_tk, x_i_prime, x_j)
            # decode
            x_tilde = self.decode_step(enc_i_t, c_i)
            loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
            # latent discriminate
            loss_adv = self.latent_discriminate_step(enc_i_t,
                                                     enc_i_tk,
                                                     enc_i_prime,
                                                     enc_j,
                                                     is_dis=False)
            ae_loss = loss_rec + current_alpha * loss_adv
            reset_grad([self.Encoder, self.Decoder])
            retain_graph = True if hps.n_patch_steps > 0 else False
            ae_loss.backward(retain_graph=retain_graph)
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.data[0],
                f'{flag}/loss_adv': loss_adv.data[0],
                f'{flag}/alpha': current_alpha,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)
            # patch discriminate
            if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_sample)
                patch_w_dis, real_logits, fake_logits = \
                        self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
                # aux classify loss on the c_sample-conditioned reconstructions
                # (restored from the commented-out classify() backup above)
                criterion = nn.NLLLoss()
                c_loss = criterion(real_logits, c_i) + criterion(fake_logits, c_sample)
                real_acc = self.cal_acc(real_logits, c_i)
                fake_acc = self.cal_acc(fake_logits, c_sample)
                patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
                reset_grad([self.Decoder])
                patch_loss.backward()
                grad_clip([self.Decoder], self.hps.max_grad_norm)
                self.decoder_opt.step()
                info = {
                    f'{flag}/G_patch_w_dis': patch_w_dis.data[0],
                    f'{flag}/c_loss': c_loss.data[0],
                    f'{flag}/real_acc': real_acc,
                    f'{flag}/fake_acc': fake_acc,
                }
                slot_value = (iteration + 1, hps.iters) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                self.save_model(model_path, iteration)
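
Example #7 also depends on a few helpers imported from elsewhere in its project: to_var, reset_grad, grad_clip and calculate_gradients_penalty. Minimal sketches of plausible implementations, written against the same Variable-era PyTorch API the snippet uses; these are assumptions, not the project's actual code:

import torch
from torch.autograd import Variable, grad

def to_var(x, requires_grad=True):
    # wrap a tensor in a Variable and move it to the GPU when available
    x = Variable(x, requires_grad=requires_grad)
    return x.cuda() if torch.cuda.is_available() else x

def reset_grad(net_list):
    # zero the gradients of every module in the list before backprop
    for net in net_list:
        net.zero_grad()

def grad_clip(net_list, max_grad_norm):
    # clip gradient norms in place before the optimizer step
    for net in net_list:
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)

def calculate_gradients_penalty(netD, real_data, fake_data):
    # WGAN-GP style penalty: the norm of the discriminator's gradient at random
    # interpolates between real and fake samples is pushed toward 1
    alpha = torch.rand(real_data.size(0), 1, 1).expand(real_data.size())
    alpha = alpha.cuda() if torch.cuda.is_available() else alpha
    interpolates = Variable(alpha * real_data.data + (1 - alpha) * fake_data.data,
                            requires_grad=True)
    d_out = netD(interpolates)
    ones = torch.ones(d_out.size())
    ones = ones.cuda() if torch.cuda.is_available() else ones
    gradients = grad(outputs=d_out, inputs=interpolates, grad_outputs=ones,
                     create_graph=True, retain_graph=True, only_inputs=True)[0]
    return ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()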