Пример #1
0
 def cond_ll(self, inputs, targets, lengths, z, pad_id):
     """Conditional log-likelihood of *targets* given latent code *z*.

     The latent code is projected into the decoder's initial recurrent
     state, *inputs* are decoded, and the negative reconstruction loss
     (normalized by the z batch size) is returned.
     """
     # Project z into the initial hidden state and split it in two along
     # the feature axis (presumably (h, c) for an LSTM — TODO confirm).
     projected = torch.tanh(self.fc(z)).unsqueeze(0)
     h0, c0 = [part.contiguous() for part in torch.chunk(projected, 2, 2)]
     # Decode the embedded inputs conditioned on that state.
     decoded, _ = self.decoder(self.embed(inputs), lengths,
                               init_hidden=[h0, c0])
     logits = self.fcout(decoded)
     # Reconstruction loss broadcast to one entry per latent sample,
     # averaged over the z batch.
     nll = recon_loss(logits, targets, pad_id).expand(z.size(0),
                                                      1) / z.size(0)
     return -nll
Пример #2
0
def baseline(args, data_iter, model, optimizer, epoch, train=True):
    """Run one epoch of plain VAE training or rotated-data evaluation.

    In training mode the model is optimized on the raw images with
    reconstruction + KL losses. In evaluation mode the same losses are
    measured on randomly rotated copies of the images (no update).

    Returns ``(nll, re_loss, kl_divergence, discriminator_loss)``, each
    summed over the epoch and normalized by dataset size.
    ``discriminator_loss`` is always zero here; it is kept so the return
    signature matches the adversarial training loops in this file.
    """
    batch_size = args.batch_size
    size = len(data_iter.dataset)
    if train:
        model.train()
    else:
        model.eval()
    re_loss = 0
    kl_divergence = 0
    discriminator_loss = 0
    for i, (data, label) in enumerate(data_iter):
        data = data.to(args.device)
        # No discriminator in the baseline; accumulate a constant zero.
        disloss = torch.zeros(1).to(args.device)

        if train:
            recon, q_z, p_z, z = model(data)
            recon = recon.view(-1, data.size(-2), data.size(-1))
            reloss = recon_loss(recon, data)  # sum over batch
            kld = total_kld(q_z, p_z)  # sum over batch
            optimizer.zero_grad()
            loss = (reloss + kld) / batch_size
            loss.backward()
            optimizer.step()
        else:
            # Evaluation on randomly rotated inputs. Nothing is
            # backpropagated here, so skip autograd graph construction.
            with torch.no_grad():
                angles = torch.randint(0, 3, (data.size(0), )).to(args.device)
                r_data = batch_rotate(data.clone(), angles)
                r_recon, r_qz, r_pz, r_z = model(r_data)
                r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
                reloss = recon_loss(r_recon, r_data)
                kld = total_kld(r_qz, r_pz)

        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        discriminator_loss += disloss.item() / size

    # NLL estimate = reconstruction + KL (both already normalized).
    nll = re_loss + kl_divergence
    return nll, re_loss, kl_divergence, discriminator_loss
Пример #3
0
def train(args):
    """Train a StarGAN-style image-translation model with NNabla.

    Builds static computation graphs for the discriminator loss
    (adversarial + domain-classification + gradient penalty) and the
    generator loss (adversarial + reconstruction + classification),
    then runs the alternating training loop with periodic sampling,
    linear learning-rate decay, and a final test-set translation pass.
    Trained parameters and the training config are saved under
    ``args.model_save_path``.
    """
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    # Input variables: real images plus original/target domain labels.
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    # (cuts the graph so the discriminator update below does not
    # backprop into the generator).
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    # Penalize deviation of the discriminator's gradient norm from 1 on
    # random interpolates between real and fake images (WGAN-GP style).
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        # (a random permutation of the batch's own labels)
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # Update the generator only every n_critic discriminator steps.
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            # Clear stale grads from the discriminator step, then
            # backprop g_loss to x_fake_unlinked and continue manually
            # through x_fake into the generator parameters.
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed
        # (linearly over the second half of training).
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
Пример #4
0
def main(args):
    """Train a class-conditional VAE with an HSIC independence penalty.

    The dataset is inferred from the last component of ``args.data``.
    After training, writes a grid of class-conditional reconstructions,
    saves the model weights, and optionally plots a t-SNE embedding of
    the latent codes (``args.tsne``).

    Raises:
        ValueError: if the dataset name is not one of the supported ones.
    """
    print("Loading data")
    dataset = args.data.rstrip('/').split('/')[-1]
    torch.cuda.set_device(args.cuda)
    device = args.device
    if dataset == 'mnist':
        train_loader, test_loader = get_mnist(args.batch_size, 'data/mnist')
        num = 10
    elif dataset == 'fashion':
        train_loader, test_loader = get_fashion_mnist(args.batch_size,
                                                      'data/fashion')
        num = 10
    elif dataset == 'svhn':
        train_loader, test_loader, _ = get_svhn(args.batch_size, 'data/svhn')
        num = 10
    elif dataset == 'stl':
        train_loader, test_loader, _ = get_stl10(args.batch_size, 'data/stl10')
        # STL-10 has 10 classes; previously `num` was left unset here,
        # causing a NameError when constructing the model below.
        num = 10
    elif dataset == 'cifar':
        train_loader, test_loader = get_cifar(args.batch_size, 'data/cifar')
        num = 10
    elif dataset == 'chair':
        train_loader, test_loader = get_chair(args.batch_size,
                                              '~/data/rendered_chairs')
        num = 1393
    elif dataset == 'yale':
        train_loader, test_loader = get_yale(args.batch_size, 'data/yale')
        num = 38
    else:
        # Fail fast instead of hitting NameError on train_loader/num.
        raise ValueError('Unsupported dataset: {}'.format(dataset))
    model = VAE(28 * 28, args.code_dim, args.batch_size, num,
                dataset).to(device)
    # phi maps latent codes into the feature space used by the HSIC
    # adversary.
    phi = nn.Sequential(
        nn.Linear(args.code_dim, args.phi_dim),
        nn.LeakyReLU(0.2, True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_phi = torch.optim.Adam(phi.parameters(), lr=args.lr)
    criterion = nn.MSELoss(reduction='sum')
    for epoch in range(args.epochs):
        re_loss = 0
        kl_div = 0
        size = len(train_loader.dataset)
        for data, target in train_loader:
            data, target = data.squeeze(1).to(device), target.to(device)
            c = F.one_hot(target.long(), num_classes=num).float()
            output, q_z, p_z, z = model(data, c)
            hsic = HSIC(phi(z), target.long(), num)
            if dataset == 'mnist' or dataset == 'fashion':
                reloss = recon_loss(output, data.view(-1, 28 * 28))
            else:
                reloss = criterion(output, data)
            kld = total_kld(q_z, p_z)
            # VAE objective plus HSIC penalty encouraging z to be
            # independent of the class label.
            loss = reloss + kld + args.c * hsic

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Adversarial phi update: maximize the HSIC estimate on
            # detached codes.
            optimizer_phi.zero_grad()
            neg = -HSIC(phi(z.detach()), target.long(), num)
            neg.backward()
            optimizer_phi.step()

            re_loss += reloss.item() / size
            kl_div += kld.item() / size
        print('-' * 50)
        print(
            " Epoch {} |re loss {:5.2f} | kl div {:5.2f} | hs {:5.2f}".format(
                epoch, re_loss, kl_div, hsic))
    # Grab one test batch for the conditional-generation visualization.
    for data, target in test_loader:
        data, target = data.squeeze(1).to(device), target.to(device)
        c = F.one_hot(target.long(), num_classes=num).float()
        output, _, _, z = model(data, c)
        break
    if dataset == 'mnist' or dataset == 'fashion':
        img_size = [data.size(0), 1, 28, 28]
    else:
        img_size = [data.size(0), 3, 32, 32]
    images = [data.view(img_size)[:30].cpu()]
    # Decode the same latent codes under each of the 10 class conditions.
    for i in range(10):
        c = F.one_hot(torch.ones(z.size(0)).long() * i,
                      num_classes=num).float().to(device)
        output = model.decoder(torch.cat((z, c), dim=-1))
        images.append(output.view(img_size)[:30].cpu())
    images = torch.cat(images, dim=0)
    save_image(images,
               'imgs/recon_c{}_{}.png'.format(int(args.c), dataset),
               nrow=30)
    torch.save(model.state_dict(),
               'vae_c{}_{}.pt'.format(int(args.c), dataset))
    if args.tsne:
        # Collect a few test batches and embed their latents with t-SNE.
        datas, targets = [], []
        for i, (data, target) in enumerate(test_loader):
            datas.append(data), targets.append(target)
            if i >= 5:
                break
        data, target = torch.cat(datas, dim=0), torch.cat(targets, dim=0)
        c = F.one_hot(target.long(), num_classes=num).float()
        _, _, _, z = model(data.to(args.device), c.to(args.device))
        z, target = z.detach().cpu().numpy(), target.cpu().numpy()
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        z_2d = tsne.fit_transform(z)
        plt.figure(figsize=(6, 5))
        plot_embedding(z_2d, target)
        plt.savefig('tsnes/tsne_c{}_{}.png'.format(int(args.c), dataset))
Пример #5
0
    def train(self):
        """Train the VAE jointly with a total-correlation discriminator.

        Per batch: (1) update the VAE with reconstruction + KL +
        beta-weighted TC terms, (2) update the discriminator to separate
        true latents from dimension-permuted ones. Epoch-mean losses are
        logged to tensorboard; the model and a traversal GIF are saved
        at the end.
        """
        # Create the dataloader
        self.dataloader = self.create_dataloader()
        for e in range(self.num_epochs):
            discriminator_loss = 0
            va_loss = 0
            for i, states_batch in enumerate(self.dataloader):

                states_true_1, states_true_2 = states_batch

                states_true_1 = states_true_1.to(self.device)
                states_true_2 = states_true_2.to(self.device)

                reconstruction, latent, mu, logvar = self.model(states_true_1)
                vae_recon_loss = loss.recon_loss(states_true_1, reconstruction)
                vae_kl_divergence = loss.kl_divergence(mu, logvar)
                d_z = self.disc(latent)
                # Total-correlation estimate from the discriminator's
                # two output logits.
                vae_tc_loss = (d_z[:, :1] - d_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kl_divergence + self.beta * vae_tc_loss

                self.optim_vae.zero_grad()
                # retain_graph: d_z is reused below in the discriminator
                # loss, so the graph must survive this backward pass.
                vae_loss.backward(retain_graph=True)
                self.optim_vae.step()

                va_loss += vae_loss.item()

                z_prime = self.model(states_true_2, no_dec=True)
                z_pperm = loss.permute_dims(z_prime).detach()
                D_z_pperm = self.disc(z_pperm)

                # Use the cached label vectors when they match the actual
                # batch size; rebuild them otherwise (the last batch of an
                # epoch may be smaller). This replaces a bare `except:`
                # that silently swallowed every error from cross_entropy.
                batch_size = d_z.size(0)
                zeros = getattr(self, 'zeros', None)
                ones = getattr(self, 'ones', None)
                if zeros is None or zeros.size(0) != batch_size:
                    zeros = torch.zeros(batch_size,
                                        dtype=torch.long,
                                        device=self.device)
                    ones = torch.ones(batch_size,
                                      dtype=torch.long,
                                      device=self.device)
                D_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                self.optim_disc.zero_grad()
                D_tc_loss.backward()
                self.optim_disc.step()

                discriminator_loss += D_tc_loss.item()

            # Add the loss to tensorboard
            self.writer.add_scalar('data/disc_loss',
                                   discriminator_loss / len(self.dataloader),
                                   e)
            self.writer.add_scalar('data/vae_loss',
                                   va_loss / len(self.dataloader), e)

        self.writer.close()
        # Save model
        self.save_model()
        # GIF visualization
        self.visualize_traverse()
Пример #6
0
def run(args, data_iter, model, optimizer, epoch, train=True):
    """One epoch of VAE training/evaluation, optionally with a
    rotation-discriminator adversary (enabled by ``args.ro``).

    With ``args.ro`` set, the batch is additionally re-encoded under
    every non-identity rotation; the discriminator ``D`` is trained to
    predict the rotation from the latent code, while the VAE is trained
    to fool it (its loss subtracts ``disloss``).

    NOTE(review): relies on module-level globals ``rotations``, ``D``,
    ``disc_loss`` and ``optimizer_D`` defined elsewhere in this file.

    Returns ``(nll, re_loss, kl_divergence, discriminator_loss)``, each
    summed over the epoch and normalized by dataset size.
    """
    batch_size = args.batch_size
    size = len(data_iter.dataset)
    if train:
        model.train()
    else:
        model.eval()
    # data_iter.init_epoch()
    re_loss = 0
    kl_divergence = 0
    discriminator_loss = 0
    nll = 0
    for i, (data, label) in enumerate(data_iter):
        data = data.to(args.device)
        recon, q_z, p_z, z = model(data)
        recon = recon.view(-1, data.size(-2), data.size(-1))
        reloss = recon_loss(recon, data)  # sum over batch
        kld = total_kld(q_z, p_z)  # sum over batch
        # Zero placeholder so disloss.item() works when args.ro is off.
        disloss = torch.zeros(1).to(args.device)

        if args.ro:
            # Accumulate discriminator / reconstruction / KL terms over
            # every non-identity rotation of the batch.
            disloss, r_reloss, r_kld = [], [], []
            for d in range(1, len(rotations)):
                angles = torch.tensor([d],
                                      dtype=torch.long,
                                      device=args.device).expand(data.size(0))
                r_data = batch_rotate(data.clone(), angles)
                r_recon, r_qz, r_pz, r_z = model(r_data)
                r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
                D_z = D(r_z)
                disloss.append(disc_loss(D_z, angles))  # sum over batch
                r_reloss.append(recon_loss(r_recon, r_data))
                r_kld.append(total_kld(r_qz, r_pz))
            disloss = sum(disloss)  # / (len(rotations)-1)
            r_reloss = sum(r_reloss)  # / (len(rotations)-1)
            r_kld = sum(r_kld)  # / (len(rotations)-1)

            # angles = torch.randint(0, 3, (data.size(0), )).to(args.device)
            # r_data = batch_rotate(data.clone(), angles)
            # r_recon, r_qz, r_pz, r_z = model(r_data)
            # r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
            # D_z = D(r_z)
            # disloss = disc_loss(D_z, angles) # sum over batch
            # r_reloss = recon_loss(r_recon, r_data)
            # r_kld = total_kld(r_qz, r_pz)

        if train:
            if args.ro:
                # Discriminator step first; retain_graph keeps the shared
                # encoder graph alive for the VAE update below.
                optimizer_D.zero_grad()
                D_loss = disloss / batch_size
                D_loss.backward(retain_graph=True)
                optimizer_D.step()

                # Adversarial VAE step: minimize ELBO terms, maximize
                # the discriminator's loss (hence the minus sign).
                optimizer.zero_grad()
                loss = (reloss + kld + r_reloss + r_kld - disloss) / batch_size
                loss.backward()
                optimizer.step()
            else:
                optimizer.zero_grad()
                loss = (reloss + kld) / batch_size
                loss.backward()
                optimizer.step()

        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        discriminator_loss += disloss.item() / size

    nll = re_loss + kl_divergence
    return nll, re_loss, kl_divergence, discriminator_loss
Пример #7
0
def train(args):
    """Train a CycleGAN-style instance-aware translation model (NNabla).

    Builds the full static graph for both translation directions
    (x -> y' and y -> x') over image+mask pairs, with cycle-consistency,
    identity-mapping, context-preserving and LSGAN losses, then runs the
    alternating discriminator/generator training loop with linear
    learning-rate decay and per-epoch sample dumps. Final parameters are
    saved under ``args.model_save_path``.
    """

    # Variable size.
    bs, ch, h, w = args.batch_size, 3, args.loadSizeH, args.loadSizeW

    # Determine normalization method.
    if args.norm == "instance":
        norm_layer = functools.partial(PF.instance_normalization,
                                       fix_parameters=True,
                                       no_bias=True,
                                       no_scale=True)
    else:
        norm_layer = PF.batch_normalization

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(models.generator,
                                  input_nc=args.input_nc,
                                  output_nc=args.output_nc,
                                  ngf=args.ngf,
                                  norm_layer=norm_layer,
                                  use_dropout=False,
                                  n_blocks=9,
                                  padding_type='reflect')
    discriminator = functools.partial(models.discriminator,
                                      input_nc=args.output_nc,
                                      ndf=args.ndf,
                                      n_layers=args.n_layers_D,
                                      norm_layer=norm_layer,
                                      use_sigmoid=False)

    # --------------------- Computation Graphs --------------------

    # Input images and masks of both source / target domain
    x = nn.Variable([bs, ch, h, w], need_grad=False)
    a = nn.Variable([bs, 1, h, w], need_grad=False)

    y = nn.Variable([bs, ch, h, w], need_grad=False)
    b = nn.Variable([bs, 1, h, w], need_grad=False)

    # Apply image augmentation and get an unlinked variable
    # (unlinking cuts the graph so updates below do not backprop
    # through the augmentation ops).
    xa_aug = image_augmentation(args, x, a)
    xa_aug.persistent = True
    xa_aug_unlinked = xa_aug.get_unlinked_variable()

    yb_aug = image_augmentation(args, y, b)
    yb_aug.persistent = True
    yb_aug_unlinked = yb_aug.get_unlinked_variable()

    # variables used for Image Pool
    x_history = nn.Variable([bs, ch, h, w])
    a_history = nn.Variable([bs, 1, h, w])
    y_history = nn.Variable([bs, ch, h, w])
    b_history = nn.Variable([bs, 1, h, w])

    # Generate Images (x -> y')
    with nn.parameter_scope("gen_x2y"):
        yb_fake = generator(xa_aug_unlinked)
    yb_fake.persistent = True
    yb_fake_unlinked = yb_fake.get_unlinked_variable()

    # Generate Images (y -> x')
    with nn.parameter_scope("gen_y2x"):
        xa_fake = generator(yb_aug_unlinked)
    xa_fake.persistent = True
    xa_fake_unlinked = xa_fake.get_unlinked_variable()

    # Reconstruct Images (y' -> x)
    with nn.parameter_scope("gen_y2x"):
        xa_recon = generator(yb_fake_unlinked)
    xa_recon.persistent = True

    # Reconstruct Images (x' -> y)
    with nn.parameter_scope("gen_x2y"):
        yb_recon = generator(xa_fake_unlinked)
    yb_recon.persistent = True

    # Use Discriminator on y' and x'
    with nn.parameter_scope("dis_y"):
        d_y_fake = discriminator(yb_fake_unlinked)
    d_y_fake.persistent = True

    with nn.parameter_scope("dis_x"):
        d_x_fake = discriminator(xa_fake_unlinked)
    d_x_fake.persistent = True

    # Use Discriminator on y and x
    with nn.parameter_scope("dis_y"):
        d_y_real = discriminator(yb_aug_unlinked)

    with nn.parameter_scope("dis_x"):
        d_x_real = discriminator(xa_aug_unlinked)

    # Identity Mapping (x -> x)
    with nn.parameter_scope("gen_y2x"):
        xa_idt = generator(xa_aug_unlinked)

    # Identity Mapping (y -> y)
    with nn.parameter_scope("gen_x2y"):
        yb_idt = generator(yb_aug_unlinked)

    # -------------------- Loss --------------------

    # (LS)GAN Loss (for Discriminator)
    loss_dis_x = (loss.lsgan_loss(d_y_fake, False) +
                  loss.lsgan_loss(d_y_real, True)) * 0.5
    loss_dis_y = (loss.lsgan_loss(d_x_fake, False) +
                  loss.lsgan_loss(d_x_real, True)) * 0.5
    loss_dis = loss_dis_x + loss_dis_y

    # Cycle Consistency Loss
    loss_cyc_x = args.lambda_cyc * loss.recon_loss(xa_recon, xa_aug_unlinked)
    loss_cyc_y = args.lambda_cyc * loss.recon_loss(yb_recon, yb_aug_unlinked)
    loss_cyc = loss_cyc_x + loss_cyc_y

    # Identity Mapping Loss
    loss_idt_x = args.lambda_idt * loss.recon_loss(xa_idt, xa_aug_unlinked)
    loss_idt_y = args.lambda_idt * loss.recon_loss(yb_idt, yb_aug_unlinked)
    loss_idt = loss_idt_x + loss_idt_y

    # Context Preserving Loss
    loss_ctx_x = args.lambda_ctx * \
        loss.context_preserving_loss(xa_aug_unlinked, yb_fake_unlinked)
    loss_ctx_y = args.lambda_ctx * \
        loss.context_preserving_loss(yb_aug_unlinked, xa_fake_unlinked)
    loss_ctx = loss_ctx_x + loss_ctx_y

    # (LS)GAN Loss (for Generator)
    d_loss_gen_x = loss.lsgan_loss(d_x_fake, True)
    d_loss_gen_y = loss.lsgan_loss(d_y_fake, True)
    d_loss_gen = d_loss_gen_x + d_loss_gen_y

    # Total Loss for Generator
    loss_gen = loss_cyc + loss_idt + loss_ctx + d_loss_gen

    # --------------------- Solvers --------------------

    # Initial learning rates
    G_lr = args.learning_rate_G
    #D_lr = args.learning_rate_D
    # As opposed to the description in the paper, D_lr is set the same as G_lr.
    D_lr = args.learning_rate_G

    # Define solvers
    solver_gen_x2y = S.Adam(G_lr, args.beta1, args.beta2)
    solver_gen_y2x = S.Adam(G_lr, args.beta1, args.beta2)
    solver_dis_x = S.Adam(D_lr, args.beta1, args.beta2)
    solver_dis_y = S.Adam(D_lr, args.beta1, args.beta2)

    # Set Parameters to each solver
    with nn.parameter_scope("gen_x2y"):
        solver_gen_x2y.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen_y2x"):
        solver_gen_y2x.set_parameters(nn.get_parameters())

    with nn.parameter_scope("dis_x"):
        solver_dis_x.set_parameters(nn.get_parameters())

    with nn.parameter_scope("dis_y"):
        solver_dis_y.set_parameters(nn.get_parameters())

    # create convenient functions manipulating Solvers
    def solvers_zero_grad():
        # Zeroing Gradients of all solvers
        solver_gen_x2y.zero_grad()
        solver_gen_y2x.zero_grad()
        solver_dis_x.zero_grad()
        solver_dis_y.zero_grad()

    def solvers_update_parameters(new_D_lr, new_G_lr):
        # Learning rate updater
        solver_gen_x2y.set_learning_rate(new_G_lr)
        solver_gen_y2x.set_learning_rate(new_G_lr)
        solver_dis_x.set_learning_rate(new_D_lr)
        solver_dis_y.set_learning_rate(new_D_lr)

    # -------------------- Data Iterators --------------------

    ds_train_A = insta_gan_data_source(args,
                                       train=True,
                                       domain="A",
                                       shuffle=True)
    di_train_A = insta_gan_data_iterator(ds_train_A, args.batch_size)

    ds_train_B = insta_gan_data_source(args,
                                       train=True,
                                       domain="B",
                                       shuffle=True)
    di_train_B = insta_gan_data_iterator(ds_train_B, args.batch_size)

    # -------------------- Monitors --------------------

    monitoring_targets_dis = {
        'discriminator_loss_x': loss_dis_x,
        'discriminator_loss_y': loss_dis_y
    }
    monitors_dis = Monitors(args, monitoring_targets_dis)

    monitoring_targets_gen = {
        'generator_loss_x': d_loss_gen_x,
        'generator_loss_y': d_loss_gen_y,
        'reconstruction_loss_x': loss_cyc_x,
        'reconstruction_loss_y': loss_cyc_y,
        'identity_mapping_loss_x': loss_idt_x,
        'identity_mapping_loss_y': loss_idt_y,
        'content_preserving_loss_x': loss_ctx_x,
        'content_preserving_loss_y': loss_ctx_y
    }
    monitors_gen = Monitors(args, monitoring_targets_gen)

    monitor_time = MonitorTimeElapsed("Training_time",
                                      Monitor(args.monitor_path),
                                      args.log_step)

    # Training loop
    epoch = 0
    n_images = max([ds_train_B.size, ds_train_A.size])
    print("{} images exist.".format(n_images))
    max_iter = args.max_epoch * n_images // args.batch_size
    decay_iter = args.max_epoch - args.lr_decay_start_epoch

    for i in range(max_iter):
        if i % (n_images // args.batch_size) == 0 and i > 0:
            # Learning Rate Decay
            # (linear decay starting at lr_decay_start_epoch)
            epoch += 1
            print("epoch {}".format(epoch))
            if epoch >= args.lr_decay_start_epoch:
                new_D_lr = D_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_iter - 1))
                new_G_lr = G_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_iter - 1))
                solvers_update_parameters(new_D_lr, new_G_lr)
                print("Current learning rate for Discriminator: {}".format(
                    solver_dis_x.learning_rate()))
                print("Current learning rate for Generator: {}".format(
                    solver_gen_x2y.learning_rate()))

        # Get data
        x_data, a_data = di_train_A.next()
        y_data, b_data = di_train_B.next()
        x.d, a.d = x_data, a_data
        y.d, b.d = y_data, b_data

        solvers_zero_grad()

        # Image Augmentation
        nn.forward_all([xa_aug, yb_aug], clear_buffer=True)

        # Generate fake images
        nn.forward_all([xa_fake, yb_fake], clear_no_need_grad=True)

        # -------- Train Discriminator --------

        loss_dis.forward(clear_no_need_grad=True)
        monitors_dis.add(i)

        loss_dis.backward(clear_buffer=True)
        solver_dis_x.update()
        solver_dis_y.update()

        # -------- Train Generators --------

        # since the gradients computed above remain, reset to zero.
        xa_fake_unlinked.grad.zero()
        yb_fake_unlinked.grad.zero()
        solvers_zero_grad()

        loss_gen.forward(clear_no_need_grad=True)

        monitors_gen.add(i)
        monitor_time.add(i)

        # Backprop to the unlinked fake images, then continue manually
        # through the generators across the unlinked boundary.
        loss_gen.backward(clear_buffer=True)
        xa_fake.backward(grad=None, clear_buffer=True)
        yb_fake.backward(grad=None, clear_buffer=True)
        solver_gen_x2y.update()
        solver_gen_y2x.update()

        if i % (n_images // args.batch_size) == 0:
            # save translation results after every epoch.
            save_images(args,
                        i,
                        xa_aug,
                        yb_fake,
                        domain="x",
                        reconstructed=xa_recon)
            save_images(args,
                        i,
                        yb_aug,
                        xa_fake,
                        domain="y",
                        reconstructed=yb_recon)

    # save pretrained parameters
    nn.save_parameters(os.path.join(args.model_save_path,
                                    'params_%06d.h5' % i))
Пример #8
0
def run(args, data_iter, model, pad_id, optimizer, epoch, train=True):
    """Run one epoch of training or evaluation for the flow/MMD text VAE.

    Parameters
    ----------
    args : config namespace; flags read here: loss_type, flow, iw, mmd,
        kla, nokld, t, mmd_w, kernel, nsamples, epoch_size, batch_size.
    data_iter : torchtext-style iterator; batches expose ``.text`` as a
        (token-id tensor, lengths tensor) pair.
    model : VAE whose forward returns
        (q_z, p_z, z, outputs, sum_log_jacobian, penalty, z0).
    pad_id : padding token id, excluded from the reconstruction loss.
    optimizer : optimizer; stepped only when ``train`` is truthy.
    epoch : 1-based epoch index, feeds the KL-annealing schedule.
    train : truthy -> backprop and parameter updates; falsy -> eval only.

    Returns
    -------
    tuple
        (re_loss, kl_divergence, flow_kl_divergence, mutual_information1,
         mutual_information2, mmd_loss, nll_ppl, negative_ll,
         iw_negative_ll, sum_log_j, start, batch_time) — per-epoch
        averages plus timing bookkeeping.
    """
    # Idiom fix: truthiness test instead of `train is True`.
    if train:
        model.train()
    else:
        model.eval()
    data_iter.init_epoch()
    batch_time = AverageMeter()
    # Normalizer for the running per-epoch averages accumulated below.
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    re_loss = 0
    kl_divergence = 0
    flow_kl_divergence = 0
    mutual_information1, mutual_information2 = 0, 0
    seq_words = 0
    mmd_loss = 0
    negative_ll = 0
    iw_negative_ll = 0
    sum_log_j = 0
    start = time.time()
    end = time.time()
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        # Standard LM shift: predict token t+1 from the prefix up to t.
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        q_z, p_z, z, outputs, sum_log_jacobian, penalty, z0 = model(
            inputs, lengths - 1, pad_id)
        if args.loss_type == 'entropy':
            reloss = recon_loss(outputs, targets, pad_id, id=args.loss_type)
        else:
            # NOTE(review): argument order differs from the 'entropy'
            # branch (inputs/outputs instead of outputs/targets) —
            # presumably intentional for the other loss ids; confirm
            # against recon_loss's signature before touching.
            reloss = recon_loss(inputs, outputs, pad_id, id=args.loss_type)

        kld = total_kld(q_z, p_z)

        if args.flow:
            f_kld = flow_kld(q_z, p_z, z, z0, sum_log_jacobian)
        else:
            # Consistency fix: keep the dummy on z's device like the
            # mi_flow fallback below (only .item() is read, so placement
            # does not change results).
            f_kld = torch.zeros(1).to(z.device)

        mi_z = mutual_info(q_z, p_z, z0)
        nll = compute_nll(q_z, p_z, z, z0, sum_log_jacobian, reloss)

        if args.iw:
            iw_nll = model.iw_nll(q_z, p_z, inputs, targets, lengths - 1,
                                  pad_id, args.nsamples)
        else:
            # Same device-consistency fix as f_kld above.
            iw_nll = torch.zeros(1).to(z.device)

        if args.flow:
            mi_flow = mutual_info_flow(q_z, p_z, z, z0, sum_log_jacobian)
        else:
            mi_flow = torch.zeros(1).to(z.device)

        mmd = torch.zeros(1).to(z.device)
        # KL annealing: ramp the weight over global steps when args.kla,
        # cap it at args.t, and zero it entirely under args.nokld.
        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) +
                                     i) if args.kla else 1.
        if args.mmd:
            mmd = compute_mmd(p_z, q_z, args.kernel)
        if kld_weight > args.t:
            kld_weight = args.t
        if args.nokld:
            kld_weight = 0

        if train:
            optimizer.zero_grad()
            if args.flow:
                # Flow ELBO: reconstruction + weighted Monte-Carlo KL
                # (log q(z0) - log p(z)) minus the log-Jacobian credit
                # (net of the penalty), plus the MMD term scaled by
                # (mmd_w - kld_weight). All likelihood terms are averaged
                # over the batch.
                loss = reloss / batch_size + kld_weight * (q_z.log_prob(
                    z0).sum() - p_z.log_prob(z).sum()) / batch_size - (
                        torch.sum(sum_log_jacobian) - torch.sum(penalty)
                    ) / batch_size + (args.mmd_w - kld_weight) * mmd
            else:
                # Plain annealed ELBO plus the MMD term.
                loss = (reloss + kld_weight * kld) / batch_size + (
                    args.mmd_w - kld_weight) * mmd

            loss.backward()
            optimizer.step()

        # Epoch-level accumulation: re_loss/kld are per-example sums so
        # they divide by size directly; the scalar-per-batch metrics are
        # batch-size weighted before averaging.
        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        flow_kl_divergence += f_kld.item() * batch_size / size
        mutual_information1 += mi_z.item() * batch_size / size
        mutual_information2 += mi_flow.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        mmd_loss += mmd.item() * batch_size / size
        negative_ll += nll.item() * batch_size / size
        iw_negative_ll += iw_nll.item() * batch_size / size
        sum_log_j += torch.sum(sum_log_jacobian).item() / size
        batch_time.update(time.time() - end)

    # Clamp runaway KL values so downstream logging stays readable.
    if kl_divergence > 100:
        kl_divergence = 100
        flow_kl_divergence = 100
    # Perplexity from the per-word NLL; prefer the importance-weighted
    # estimate when it was computed.
    if args.iw:
        nll_ppl = math.exp(iw_negative_ll * size / seq_words)
    else:
        nll_ppl = math.exp(negative_ll * size / seq_words)

    return re_loss, kl_divergence, flow_kl_divergence, mutual_information1, mutual_information2, mmd_loss, nll_ppl, negative_ll, iw_negative_ll, sum_log_j, start, batch_time