def test_read(self):
    """Seed sample data, GET document 1 and assert it is found with id "1"."""
    util.sample_data(self.endpoint, self.index)
    doc_url = f"{self.endpoint}{self.index}/doc/1"
    body = util.send_data(doc_url, "GET").json()
    self.assertEqual(body["found"], True)
    self.assertEqual(body["_id"], "1")
 def test_update(self):
     """Seed sample data, PUT a new name to document 1 and assert the response
     reports an update of id "1" with a version greater than 1."""
     util.sample_data(self.endpoint, self.index)
     doc_url = f"{self.endpoint}{self.index}/doc/1"
     body = util.send_data(
         doc_url, "PUT", data=json.dumps({"name": "Jane Smith"})
     ).json()
     self.assertEqual(body["result"], "updated")
     self.assertEqual(body["_id"], "1")
     self.assertGreater(body["_version"], 1)
    def test_delete(self):
        """Seed sample data, DELETE document 1, then assert a follow-up GET of
        the same document reports it as not found."""
        util.sample_data(self.endpoint, self.index)
        doc_url = f"{self.endpoint}{self.index}/doc/1"

        deleted = util.send_data(doc_url, "DELETE").json()
        self.assertEqual(deleted["result"], "deleted")
        self.assertEqual(deleted["_id"], "1")

        # Read back through the very same URL to confirm the deletion.
        lookup = util.send_data(doc_url, "GET").json()
        self.assertEqual(lookup["found"], False)
예제 #4
0
def train(device, model):
    """Train a bag-of-words classifier with BCE loss and plot per-epoch mean loss.

    Relies on names defined outside this function: ``sample_data``,
    ``learning_rate``, ``num_epochs``, ``word_embedding``, ``time``, ``np``,
    ``plt``, ``nn``, ``optim`` and ``torch``.

    Args:
        device: torch device the input vectors and targets are moved to.
        model: module called on one bag-of-words vector; its output is fed to
            ``nn.BCELoss`` (so presumably already a probability — confirm).
    """
    # NOTE(review): the data is reloaded here instead of being passed in by the
    # caller; only the vocabulary and the training split are used below.
    word2idx, train_data, _ = sample_data()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    epoch_means = []
    start = time()
    for _ in range(num_epochs):
        step_losses = []
        for text, label in train_data:
            model.zero_grad()
            features = word_embedding(text, word2idx).float().to(device)
            prediction = model(features)
            expected = torch.LongTensor([label]).to(device).float()
            loss = criterion(prediction, expected)
            step_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        epoch_means.append(np.mean(step_losses))
    elapsed = time() - start
    print("Train time: ", elapsed)

    plt.plot(np.array(epoch_means))
    plt.show()
예제 #5
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    """Fine-tune a generator/discriminator pair adversarially.

    Each loop iteration:
      1. discriminator step with ``d_logistic_loss`` on real vs. generated images;
      2. every ``args.d_reg_every`` iterations, an R1 penalty step (``d_r1_loss``);
      3. generator step with ``g_nonsaturating_loss``;
      4. every ``args.g_reg_every`` iterations, a path-length regularization
         step (``g_path_regularize``);
      5. ``accumulate(g_ema, g_module, accum)`` to update ``g_ema``.
    On rank 0 it also updates the progress bar, optionally logs to wandb,
    writes a sample grid every 100 iterations and saves ``g_ema`` weights
    every ``args.save_every`` iterations.

    Args:
        args: parsed options. Fields read include iter, start_iter, augment,
            augment_p, ada_target, ada_length, n_sample, latent, batch,
            mixing, d_reg_every, r1, g_reg_every, path_batch_shrink,
            path_regularize, distributed, wandb, style, save_every and
            model_path.
        loader: data loader, wrapped by ``sample_data`` into an iterator that
            ``next()`` pulls real image batches from.
        generator, discriminator: models being trained; when
            ``args.distributed`` they are expected to expose ``.module``.
        g_optim, d_optim: optimizers for the generator and discriminator.
        g_ema: averaged copy of the generator, updated via ``accumulate`` and
            used for samples and checkpoints.
        device: device tensors are created on / moved to.
    """
    # Presumably yields batches indefinitely so next() never runs out — confirm.
    loader = sample_data(loader)

    pbar = range(args.iter)

    # Only rank 0 shows a tqdm progress bar.
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    ncols=140,
                    dynamic_ncols=False,
                    smoothing=0.01)

    mean_path_length = 0

    # Placeholders so loss_dict/logging are valid before the first R1 and
    # path-length regularization steps have run.
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # Unwrap distributed wrappers to reach the underlying modules.
    # d_module is not used below (its checkpoint entry is commented out).
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay factor for accumulate(); presumably a 10k-image half-life assuming
    # 32 images per step — TODO confirm.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    # Adaptive augmentation: when augment_p is 0 the probability is tuned from
    # the discriminator's real predictions each step (see ada_augment.tune).
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed latents reused for every sample grid so progress is comparable.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        # Global iteration index (accounts for resuming from start_iter).
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator step: only D receives gradients. ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # ---- Lazy R1 regularization every d_reg_every iterations. ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            # R1 needs gradients w.r.t. the real images themselves.
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Scaled by d_reg_every to compensate for running only every N steps.
            # The zero-weighted real_pred term keeps real_pred in the graph;
            # presumably so every parameter gets a (zero) gradient, e.g. for
            # DDP — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step: only G receives gradients. ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Lazy path-length regularization every g_reg_every iterations. ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            # Smaller batch for the regularizer (at least one sample).
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            # Zero-weighted term keeps fake_img in the graph (see R1 note above).
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            # Average the running mean path length across processes for display.
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update the averaged generator from the freshly stepped one.
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:05d}; d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            # NOTE(review): "Mean Path Length" logs the local running value
            # mean_path_length, not the cross-process mean_path_length_avg —
            # confirm which is intended.
            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            # Sample grid from g_ema on the fixed latents, resized to 256.
            # NOTE(review): newer torchvision versions rename save_image's
            # `range` keyword to `value_range` — verify against the pinned version.
            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/%s/finetune-%06d.jpg" % (args.style, i),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            # Only g_ema weights are saved; the other entries are disabled.
            # NOTE(review): checkpoint prefix is "fintune" while the sample
            # images use "finetune" — confirm which name downstream loaders expect.
            if (i + 1) % args.save_every == 0 or (i + 1) == args.iter:
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"%s/%s/fintune-%06d.pt" %
                    (args.model_path, args.style, i + 1),
                )
예제 #6
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          instyles, Simgs, exstyles, vggloss, id_loss, device):
    """Train a dual-style generator (intrinsic + extrinsic style) adversarially.

    The generator is called as ``generator(instyle, exstyle, use_res=...,
    z_plus_latent=...)``. Iterations alternate on ``which = i %
    args.subspace_freq``:
      * ``which == 0``: random intrinsic noise (``mixing_noise``) with an
        extrinsic style drawn by ``get_paired_data``;
      * otherwise: a paired (exstyle, instyle, real image) batch from
        ``get_paired_data`` is used, and the VGG perceptual L1 loss is added.

    Generator losses (summed): adversarial ``g_nonsaturating_loss``; style loss
    (contextual loss on VGG feature 2 scaled by ``args.CX_loss``, plus MSE of
    pooled VGG features 1 and 2 scaled by ``args.style_loss``); identity loss
    between the generated image and the ``use_res=False`` output, scaled by
    ``args.id_loss``; perceptual loss weighted by ``vgg_weights *
    args.perc_loss`` (paired iterations only); and an L2 penalty on
    ``g_module.res`` parameters scaled by ``args.L2_reg_loss``. The
    discriminator uses ``d_logistic_loss`` plus lazy R1; the generator also
    gets lazy path-length regularization. Only ``g_ema.res`` is updated from
    ``g_module.res`` via ``accumulate``.

    Args:
        args: parsed options (iteration counts, loss weights, augmentation,
            regularization intervals, save paths and names).
        loader: data loader wrapped by ``sample_data``; ``next()`` yields
            real image batches.
        generator, discriminator, g_optim, d_optim: models and optimizers;
            ``.module`` is unwrapped when ``args.distributed``.
        g_ema: averaged generator used for samples and checkpoints.
        instyles, Simgs, exstyles: passed through to ``get_paired_data``,
            which returns ``(exstyle, instyle, image)`` batches — their exact
            layout is not visible here.
        vggloss: callable returning a sequence of feature maps; indices 1 and
            2 are used here.
        id_loss: callable ``(fake_256, content_256)`` returning a loss.
        device: device tensors are moved to.
    """
    loader = sample_data(loader)
    # Per-layer weights for the perceptual loss (only layers 1 and 2 count).
    vgg_weights = [0.0, 0.5, 1.0, 0.0, 0.0]
    pbar = range(args.iter)

    # Only rank 0 shows a tqdm progress bar.
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    smoothing=0.01,
                    ncols=180,
                    dynamic_ncols=False)

    mean_path_length = 0

    # Placeholders so loss_dict/logging are valid before the first
    # regularization steps run.
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # Unwrap distributed wrappers; d_module is not used below.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay factor for accumulate(); presumably a 10k-image half-life at 32
    # images per step — TODO confirm.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    # r_t_stat is updated below but not reported by this function.
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed intrinsic noise and extrinsic styles for the periodic sample grid.
    sample_instyle = torch.randn(args.n_sample, args.latent, device=device)
    sample_exstyle, _, _ = get_paired_data(instyles,
                                           Simgs,
                                           exstyles,
                                           batch_size=args.n_sample,
                                           random_ind=8)
    sample_exstyle = sample_exstyle.to(device)

    for idx in pbar:
        # Global iteration index (accounts for resuming from start_iter).
        i = idx + args.start_iter

        # which == 0: unpaired step (random z); otherwise: paired step.
        which = i % args.subspace_freq  # defines whether we use paired data

        if i > args.iter:
            print("Done!")
            break

        # sample S
        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator step: only D receives gradients. ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        if which == 0:
            # sample z^+_e, z for Lsty, Lcon and Ladv
            exstyle, _, _ = get_paired_data(instyles,
                                            Simgs,
                                            exstyles,
                                            batch_size=args.batch,
                                            random_ind=8)
            exstyle = exstyle.to(device)
            instyle = mixing_noise(args.batch, args.latent, args.mixing,
                                   device)
            z_plus_latent = False
        else:
            # sample z^+_e, z^+_i and S for Eq. (4)
            # (the paired image replaces the loader batch as the "real" input)
            exstyle, instyle, real_img = get_paired_data(instyles,
                                                         Simgs,
                                                         exstyles,
                                                         batch_size=args.batch,
                                                         random_ind=8)
            exstyle = exstyle.to(device)
            instyle = [instyle.to(device)]
            real_img = real_img.to(device)
            z_plus_latent = True

        fake_img, _ = generator(instyle,
                                exstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # ---- Lazy R1 regularization every d_reg_every iterations. ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Zero-weighted real_pred term keeps it in the graph; presumably
            # so every parameter gets a (zero) gradient, e.g. for DDP — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step: only G receives gradients. ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        # Fresh batch for the generator step; real_img is also redrawn because
        # the style loss compares against it.
        if which == 0:
            # sample z^+_e, z for Lsty, Lcon and Ladv
            exstyle, _, real_img = get_paired_data(instyles,
                                                   Simgs,
                                                   exstyles,
                                                   batch_size=args.batch,
                                                   random_ind=8)
            real_img = real_img.to(device)
            exstyle = exstyle.to(device)
            instyle = mixing_noise(args.batch, args.latent, args.mixing,
                                   device)
            z_plus_latent = False
        else:
            # sample z^+_e, z^+_i and S for Eq. (4)
            exstyle, instyle, real_img = get_paired_data(instyles,
                                                         Simgs,
                                                         exstyles,
                                                         batch_size=args.batch,
                                                         random_ind=8)
            exstyle = exstyle.to(device)
            instyle = [instyle.to(device)]
            real_img = real_img.to(device)
            z_plus_latent = True

        fake_img, _ = generator(instyle,
                                exstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        # Reference features/content computed without gradients: VGG features
        # of the real image, and the generator output without the residual
        # path (use_res=False) as the identity/content reference.
        with torch.no_grad():
            real_img_256 = F.adaptive_avg_pool2d(real_img, 256).detach()
            real_feats = vggloss(real_img_256)
            real_styles = [
                F.adaptive_avg_pool2d(real_feat, output_size=1).detach()
                for real_feat in real_feats
            ]
            real_content, _ = generator(instyle,
                                        None,
                                        use_res=False,
                                        z_plus_latent=z_plus_latent)
            real_content_256 = F.adaptive_avg_pool2d(real_content,
                                                     256).detach()

        fake_img_256 = F.adaptive_avg_pool2d(fake_img, 256)
        fake_feats = vggloss(fake_img_256)
        # Global-average-pooled features act as per-channel style statistics.
        fake_styles = [
            F.adaptive_avg_pool2d(fake_feat, output_size=1)
            for fake_feat in fake_feats
        ]
        # Style loss: contextual loss on feature 2 (if enabled) ...
        sty_loss = (torch.tensor(0.0).to(device) if args.CX_loss == 0 else
                    FCX.contextual_loss(fake_feats[2],
                                        real_feats[2].detach(),
                                        band_width=0.2,
                                        loss_type='cosine') * args.CX_loss)
        # ... plus MSE between pooled features 1 and 2 (if enabled).
        if args.style_loss > 0:
            sty_loss += ((F.mse_loss(fake_styles[1], real_styles[1]) +
                          F.mse_loss(fake_styles[2], real_styles[2])) *
                         args.style_loss)

        ID_loss = (torch.tensor(0.0).to(device) if args.id_loss == 0 else
                   id_loss(fake_img_256, real_content_256) * args.id_loss)

        # Perceptual L1 loss, only on paired iterations.
        gr_loss = torch.tensor(0.0).to(device)
        if which > 0:
            for ii, weight in enumerate(vgg_weights):
                if weight * args.perc_loss > 0:
                    gr_loss += F.l1_loss(
                        fake_feats[ii],
                        real_feats[ii].detach()) * weight * args.perc_loss

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        # L2 norm penalty on the residual path's parameters.
        l2_reg_loss = sum(
            torch.norm(p)
            for p in g_module.res.parameters()) * args.L2_reg_loss

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # Lperc
        loss_dict["l2"] = l2_reg_loss  # Lreg in Lcon
        loss_dict["id"] = ID_loss  # LID in Lcon
        loss_dict["sty"] = sty_loss  # Lsty
        # Out-of-place sum, so loss_dict["g"] keeps only the adversarial term.
        g_loss = g_loss + gr_loss + sty_loss + l2_reg_loss + ID_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Lazy path-length regularization every g_reg_every iterations. ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)

            instyle = mixing_noise(path_batch_size, args.latent, args.mixing,
                                   device)
            exstyle, _, _ = get_paired_data(instyles,
                                            Simgs,
                                            exstyles,
                                            batch_size=path_batch_size,
                                            random_ind=8)
            exstyle = exstyle.to(device)

            fake_img, latents = generator(instyle,
                                          exstyle,
                                          return_latents=True,
                                          use_res=True,
                                          z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            # Zero-weighted term keeps fake_img in the graph (see R1 note above).
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Only the residual path of g_ema is averaged.
        accumulate(g_ema.res, g_module.res, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        sty_loss_val = loss_reduced["sty"].mean().item()
        l2_loss_val = loss_reduced["l2"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        id_loss_val = loss_reduced["id"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; sty: {sty_loss_val:.3f}; l2: {l2_loss_val:.3f}; id: {id_loss_val:.3f}; "
                f"r1: {r1_val:.3f}; path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.4f};"))

            # Sample grid from g_ema on the fixed styles, resized to 256.
            # NOTE(review): newer torchvision versions rename save_image's
            # `range` keyword to `value_range` — verify against the pinned version.
            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_instyle],
                                      sample_exstyle,
                                      use_res=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/%s/dualstylegan-%06d.jpg" % (args.style, i),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            # Save g_ema weights only, starting from save_begin.
            if ((i + 1) >= args.save_begin and
                (i + 1) % args.save_every == 0) or (i + 1) == args.iter:
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"%s/%s/%s-%06d.pt" %
                    (args.model_path, args.style, args.model_name, i + 1),
                )
def pretrain(args,
             loader,
             generator,
             discriminator,
             g_optim,
             d_optim,
             g_ema,
             encoder,
             vggloss,
             device,
             inject_index=5,
             savemodel=True):
    """Pretrain the generator's residual (``use_res=True``) path on style mixing.

    Each iteration draws two codes ``z1, z2``. The target is the
    ``use_res=False`` image produced by mixing ``z1`` and ``z2`` at
    ``inject_index``. The generator's ``use_res=True`` output, driven by
    ``z1`` as intrinsic style and an extrinsic style derived from ``z2``, is
    trained to match it via a VGG perceptual L1 loss, plus a down-weighted
    (x0.1) adversarial loss. The extrinsic style alternates on ``which = i %
    args.subspace_freq``:
      * ``which > 0``: ``g_module.get_latent(z2)``;
      * ``which == 0``: ``encoder`` applied to ``g(z2)`` (``z_plus_latent=True``).
    The discriminator gets lazy R1 and the generator lazy path-length
    regularization. Only ``g_ema.res`` is averaged from ``g_module.res``.

    Args:
        args: parsed options (iteration counts, batch/latent sizes,
            augmentation, regularization intervals, save paths and names).
        loader: data loader wrapped by ``sample_data``; ``next()`` yields
            real image batches.
        generator, discriminator, g_optim, d_optim: models and optimizers;
            ``.module`` is unwrapped when ``args.distributed``.
        g_ema: averaged generator used for sample grids and checkpoints.
        encoder: maps a 256-pooled image to a latent (second return value
            with ``return_latents=True``) used as the extrinsic style.
        vggloss: callable returning a sequence of feature maps.
        device: device tensors are created on / moved to.
        inject_index: layer index at which ``z1``/``z2`` are mixed for the target.
        savemodel: when False, no checkpoints are written.
    """
    loader = sample_data(loader)
    # Per-layer weights for the perceptual loss (only the first three count).
    vgg_weights = [0.5, 0.5, 0.5, 0.0, 0.0]
    pbar = range(args.iter)

    # Only rank 0 shows a tqdm progress bar.
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    ncols=140,
                    dynamic_ncols=False,
                    smoothing=0.01)

    mean_path_length = 0

    # Placeholders so loss_dict/logging are valid before the first
    # regularization steps run.
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # Unwrap distributed wrappers to reach the underlying modules.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay factor for accumulate(); presumably a 10k-image half-life at 32
    # images per step — TODO confirm.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed (z1, z2) pair for the periodic sample grid, plus reference images
    # (intrinsic source, mixed target, extrinsic style) saved once on rank 0.
    sample_zs = mixing_noise(args.n_sample, args.latent, 1.0, device)
    with torch.no_grad():
        source_img, _ = generator([sample_zs[0]],
                                  None,
                                  input_is_latent=False,
                                  z_plus_latent=False,
                                  use_res=False)
        source_img = source_img.detach()
        target_img, _ = generator(sample_zs,
                                  None,
                                  input_is_latent=False,
                                  z_plus_latent=False,
                                  inject_index=inject_index,
                                  use_res=False)
        target_img = target_img.detach()
        style_img, _ = generator([sample_zs[1]],
                                 None,
                                 input_is_latent=False,
                                 z_plus_latent=False,
                                 use_res=False)
        _, sample_style = encoder(F.adaptive_avg_pool2d(style_img, 256),
                                  randomize_noise=False,
                                  return_latents=True,
                                  z_plus_latent=True,
                                  return_z_plus_latent=False)
        sample_style = sample_style.detach()
        if get_rank() == 0:
            utils.save_image(F.adaptive_avg_pool2d(source_img, 256),
                             f"log/%s-instyle.jpg" % (args.model_name),
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))
            utils.save_image(F.adaptive_avg_pool2d(target_img, 256),
                             f"log/%s-target.jpg" % (args.model_name),
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))
            utils.save_image(F.adaptive_avg_pool2d(style_img, 256),
                             f"log/%s-exstyle.jpg" % (args.model_name),
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))

    for idx in pbar:
        # Global iteration index (accounts for resuming from start_iter).
        i = idx + args.start_iter

        # which > 0: extrinsic style from get_latent(z2); 0: from E(g(z2)).
        which = i % args.subspace_freq

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # real_zs contains z1 and z2
        real_zs = mixing_noise(args.batch, args.latent, 1.0, device)
        with torch.no_grad():
            # g(z^+_l) with l=inject_index
            target_img, _ = generator(real_zs,
                                      None,
                                      input_is_latent=False,
                                      z_plus_latent=False,
                                      inject_index=inject_index,
                                      use_res=False)
            target_img = target_img.detach()
            # g(z2)
            style_img, _ = generator([real_zs[1]],
                                     None,
                                     input_is_latent=False,
                                     z_plus_latent=False,
                                     use_res=False)
            style_img = style_img.detach()
            # E(g(z2))
            _, pspstyle = encoder(F.adaptive_avg_pool2d(style_img, 256),
                                  randomize_noise=False,
                                  return_latents=True,
                                  z_plus_latent=True,
                                  return_z_plus_latent=False)
            pspstyle = pspstyle.detach()

        # ---- Discriminator step: only D receives gradients. ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        if which > 0:
            # set z~_2 = z2
            noise = [real_zs[0]]
            externalstyle = g_module.get_latent(real_zs[1]).detach()
            z_plus_latent = False
        else:
            # set z~_2 = E(g(z2))
            # z1 is repeated across all n_latent slots to form a z+ code.
            noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
            externalstyle = pspstyle
            z_plus_latent = True

        fake_img, _ = generator(noise,
                                externalstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        # Adversarial loss is down-weighted during pretraining.
        d_loss = d_logistic_loss(real_pred, fake_pred) * 0.1

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # ---- Lazy R1 regularization every d_reg_every iterations. ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Zero-weighted real_pred term keeps it in the graph; presumably
            # so every parameter gets a (zero) gradient, e.g. for DDP — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step: only G receives gradients. ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        if which > 0:
            # set z~_2 = z2
            noise = [real_zs[0]]
            externalstyle = g_module.get_latent(real_zs[1]).detach()
            z_plus_latent = False
        else:
            # set z~_2 = E(g(z2))
            noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
            externalstyle = pspstyle
            z_plus_latent = True

        fake_img, _ = generator(noise,
                                externalstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        # Perceptual L1 loss between the generated image and the mixed target.
        real_feats = vggloss(F.adaptive_avg_pool2d(target_img, 256).detach())
        fake_feats = vggloss(F.adaptive_avg_pool2d(fake_img, 256))
        gr_loss = torch.tensor(0.0).to(device)
        for ii, weight in enumerate(vgg_weights):
            if weight > 0:
                gr_loss += F.l1_loss(fake_feats[ii],
                                     real_feats[ii].detach()) * weight

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred) * 0.1

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # L_perc

        # Out-of-place add: an in-place `+=` would mutate the tensor already
        # stored in loss_dict["g"], so the logged adversarial loss would
        # silently include the perceptual term.
        g_loss = g_loss + gr_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Lazy path-length regularization every g_reg_every iterations. ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)

            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            # Use the configured latent size (as every other noise draw here
            # does) instead of a hard-coded 512.
            externalstyle = torch.randn(path_batch_size, args.latent,
                                        device=device)
            externalstyle = g_module.get_latent(externalstyle).detach()
            fake_img, latents = generator(noise,
                                          externalstyle,
                                          return_latents=True,
                                          use_res=True,
                                          z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            # Zero-weighted term keeps fake_img in the graph (see R1 note above).
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Only the residual path of g_ema is averaged.
        accumulate(g_ema.res, g_module.res, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; r1: {r1_val:.3f}; "
                f"path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.1f}"))

            # Sample grid from g_ema: fixed z1 (as z+) with the fixed encoded style.
            if i % 300 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([
                        sample_zs[0].unsqueeze(1).repeat(
                            1, g_module.n_latent, 1)
                    ],
                                      sample_style,
                                      use_res=True,
                                      z_plus_latent=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/%s-%06d.jpg" % (args.model_name, i),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            # Full training state is checkpointed here (unlike the fine-tune loops).
            if savemodel and ((i + 1) % args.save_every == 0 or
                              (i + 1) == args.iter):
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"%s/%s-%06d.pt" %
                    (args.model_path, args.model_name, i + 1),
                )
예제 #8
0

def predict():
    """Classify every instance of ``test_data`` and print predictions and accuracy.

    Relies on names defined outside this function: ``model``, ``device``,
    ``test_data`` and ``word2idx`` (set in the ``__main__`` block), plus
    ``word_embedding`` and ``encode_dict``. A model output above 0.5 selects
    the second key of ``encode_dict`` (index 1), otherwise the first (index 0).
    The model is put in eval mode for inference and its previous
    training/eval mode is restored afterwards.
    """
    # Label names do not change between instances, so build the list once.
    label_names = list(encode_dict.keys())
    was_training = model.training
    # Inference should not run in training mode (e.g. dropout, if present).
    model.eval()
    try:
        with torch.no_grad():
            n_correct = 0
            total_samples = 0
            predictions = []
            for instance, label in test_data:
                bow_vec = word_embedding(instance, word2idx).float().to(device)
                probability = model(bow_vec).squeeze().item()
                is_positive = probability > 0.5
                # int(label) mirrors the original LongTensor conversion
                # (truncation toward zero) while keeping the count a plain int.
                if int(is_positive) == int(label):
                    n_correct += 1
                predictions.append(label_names[is_positive])
                total_samples += 1
            print("Predictions : {}".format(predictions))
            # Guard against an empty test set instead of dividing by zero.
            accuracy = (n_correct / total_samples
                        if total_samples else float("nan"))
            print("Val_accuracy: {}".format(accuracy))
    finally:
        model.train(was_training)


if __name__ == "__main__":
    # Small models run better on the CPU than on a GPU, so CUDA is used only
    # when the model is flagged as heavy and a GPU is available.
    is_model_heavy = False
    if is_model_heavy and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    print("Running on {}".format(device))

    # These module-level names are read by predict().
    word2idx, train_data, test_data = sample_data()
    model = BoWClassifier(hidden, len(word2idx), output).to(device)
    train(device, model)
    predict()