Exemplo n.º 1
0
# Classifier pre-training phase. Loss and optimizer-step callables are picked
# by name from registries keyed on FLAGS values.
loss_func_d = loss_triplegan.d_loss_dict[FLAGS.gan_type]
loss_func_c_adv = loss_triplegan.c_loss_dict[FLAGS.gan_type]
loss_func_c = loss_classifier.c_loss_dict[FLAGS.c_loss]
step_func = loss_classifier.c_step_func[FLAGS.c_step]

logger_prefix = "Itera {}/{} ({:.0f}%)"

for i in range(pretrain_inter):  # original note "1w" presumably means 10k iterations
    # NOTE(review): the labeled iterator `itr` is passed for both iterator
    # slots here, whereas the main training loop passes `itr, itr_u`;
    # confirm this is intended for pre-training.
    tloss, l_loss, u_loss = loss_func_c(netC, netC_T, i, itr, itr, device)
    # The "ramp_swa" step function additionally takes the SWA optimizer.
    if FLAGS.c_step == "ramp_swa":
        step_func(optim_c, swa_optim, netC, netC_T, i, tloss)
    else:
        step_func(optim_c, netC, netC_T, i, tloss)

    logger.add("training_pre", "loss", tloss.item(), i + 1)
    logger.add("training_pre", "l_loss", l_loss.item(), i + 1)
    logger.add("training_pre", "u_loss", u_loss.item(), i + 1)
    if (i + 1) % image_interval == 0:
        # Forward batches through netC_T in train mode and discard the
        # outputs -- presumably to refresh batch-norm running statistics
        # before evaluation. no_grad avoids building throw-away autograd
        # graphs; BN running stats are still updated in train mode.
        netC.train()
        netC_T.train()
        with torch.no_grad():
            for _ in range(int(FLAGS.n_labels / FLAGS.batch_size)):
                data_u, _ = next(itr)
                netC_T(data_u.to(device))
        netC.eval()
        netC_T.eval()
        with torch.no_grad():
            # Results are not used within this fragment; presumably they are
            # logged by code that was cut off from this example.
            total_t, correct_t, loss_t = evaluation.test_classifier(netC)
            total_tt, correct_tt, loss_tt = evaluation.test_classifier(netC_T)
        netC.train()
        netC_T.train()
Exemplo n.º 2
0
# train
# Main classifier training loop: one loss/step per iteration, periodic text
# logging every `print_interval` iterations and evaluation every
# `test_interval` iterations.
print_interval = 50
test_interval = 500
max_iter = FLAGS.n_iter
loss_func = loss_classifier.c_loss_dict[FLAGS.c_loss]
step_func = loss_classifier.c_step_func[FLAGS.c_step]

logger_prefix = "Itera {}/{} ({:.0f}%)"
for i in range(max_iter):
    tloss, l_loss, u_loss = loss_func(netC, netC_T, i, itr, itr_u, device)
    # The "ramp_swa" step function additionally takes the SWA optimizer.
    if FLAGS.c_step == "ramp_swa":
        step_func(optim_c, swa_optim, netC, netC_T, i, tloss)
    else:
        step_func(optim_c, netC, netC_T, i, tloss)

    logger.add("training", "loss", tloss.item(), i + 1)
    logger.add("training", "l_loss", l_loss.item(), i + 1)
    logger.add("training", "u_loss", u_loss.item(), i + 1)

    if (i + 1) % print_interval == 0:
        # Percent complete after i + 1 iterations; (i + 1) must be
        # parenthesized, otherwise this computes (100 * i + 1) / max_iter.
        percent_done = 100 * (i + 1) / max_iter
        prefix = logger_prefix.format(i + 1, max_iter, percent_done)
        cats = ["training"]
        logger.log_info(prefix, text_logger.info, cats=cats)

    if (i + 1) % test_interval == 0:
        netC.eval()
        netC_T.eval()

        if FLAGS.c_step == "ramp_swa":
            netC_swa.train()
Exemplo n.º 3
0
        # Text log of the current training statistics. This fragment starts
        # inside a conditional whose header is not visible here (presumably a
        # print-interval check -- confirm).
        str_meg = "Iteration {}/{} ({:.0f}%), grad_norm {:e}, E_real {:e},"
        str_meg += " E_noise {:e}, tLoss {:e}, Normalized Loss {:e}, time {:4.3f}"
        text_logger.info(
            str_meg.format(
                i + 1,
                max_iter,
                100 * ((i + 1) / max_iter),
                grad_norm,
                E_real.item(),
                E_noise.item(),
                tloss.item(),
                nloss.item(),
                time_dur,
            )
        )
        # Reset the timing counter printed above for the next logging window.
        time_dur = 0.0

        logger.add("training", "E_real", E_real.item(), i + 1)
        logger.add("training", "E_noise", E_noise.item(), i + 1)
        logger.add("training", "loss", tloss.item(), i + 1)
        # Drop the names for the logged tensors so they can be freed before
        # the next iteration.
        del E_real
        del E_noise
        del nloss
        del tloss

    # Every FLAGS.save_every iterations: dump logged stats to stats.pkl and
    # checkpoint netE's state_dict as MODELS_FOLDER/model<iter>.pt.
    if (i + 1) % FLAGS.save_every == 0:
        text_logger.info("-" * 50)
        logger.save_stats("stats.pkl")
        file_name = "model" + str(i + 1) + ".pt"
        torch.save(netE.state_dict(), MODELS_FOLDER + "/" + file_name)
    # NOTE(review): from here on the code reads like a separate GAN/Triple-GAN
    # loop (discriminator + generator steps, netC passed to the D loss) --
    # likely a different example whose header was lost in extraction.
    # `data` and `label` are produced outside this view.
    data, label = data.to(device), label.to(device)
    # Two separate unlabeled batches; both are passed to the discriminator loss.
    data_u, _ = itr_u.__next__()
    data_u_d, _ = itr_u.__next__()
    data_u, data_u_d = data_u.to(device), data_u_d.to(device)

    # Discriminator step: noise of shape (FLAGS.bs_g, FLAGS.g_z_dim); the loss
    # function returns the loss plus three diagnostic terms logged below.
    sample_z = torch.randn(FLAGS.bs_g, FLAGS.g_z_dim).to(device)
    loss_d, dreal, dfake_g, dfake_c = loss_func_d(netD, netG, netC, data,
                                                  sample_z, label, data_u,
                                                  data_u_d)
    optim_D.zero_grad()
    loss_d.backward()
    # Optional gradient-norm clipping, enabled when FLAGS.clip_value > 0.
    if FLAGS.clip_value > 0:
        torch.nn.utils.clip_grad_norm_(netD.parameters(), FLAGS.clip_value)
    optim_D.step()

    logger.add("training_d", "loss", loss_d.item(), i + 1)
    logger.add("training_d", "dreal", dreal.item(), i + 1)
    logger.add("training_d", "dfake_g", dfake_g.item(), i + 1)
    logger.add("training_d", "dfake_c", dfake_c.item(), i + 1)

    # Generator step with a fresh noise batch (runs every iteration here).
    sample_z = torch.randn(FLAGS.bs_g, FLAGS.g_z_dim).to(device)
    loss_g, fake_g = loss_func_g(netD, netG, sample_z, label)
    optim_G.zero_grad()
    loss_g.backward()
    if FLAGS.clip_value > 0:
        torch.nn.utils.clip_grad_norm_(netG.parameters(), FLAGS.clip_value)
    optim_G.step()

    logger.add("training_g", "loss", loss_g.item(), i + 1)
    logger.add("training_g", "fake_g", fake_g.item(), i + 1)
Exemplo n.º 5
0
# Alternating GAN training: a discriminator step every iteration and a
# generator step once every FLAGS.n_iter_d iterations.
loss_func_d = loss_gan.d_loss_dict[FLAGS.gan_type]

logger_prefix = "Itera {}/{} ({:.0f}%)"
for i in range(max_iter):
    x_real, label = next(itr)
    x_real, label = x_real.to(device), label.to(device)

    # Discriminator step on a real batch and a fresh noise batch of shape
    # (batch_size, FLAGS.g_z_dim).
    sample_z = torch.randn(batch_size, FLAGS.g_z_dim).to(device)
    loss_d, dreal, dfake = loss_func_d(netD, netG, x_real, sample_z, label)
    optim_D.zero_grad()
    loss_d.backward()
    # Optional gradient-norm clipping, enabled when FLAGS.clip_value > 0.
    if FLAGS.clip_value > 0:
        torch.nn.utils.clip_grad_norm_(netD.parameters(), FLAGS.clip_value)
    optim_D.step()

    logger.add("training_d", "loss", loss_d.item(), i + 1)
    logger.add("training_d", "dreal", dreal.item(), i + 1)
    logger.add("training_d", "dfake", dfake.item(), i + 1)

    # Generator step once every FLAGS.n_iter_d discriminator steps. The
    # comparison with 0 is required: a bare `(i + 1) % n_iter_d` is truthy on
    # the *other* iterations and never fires at all when n_iter_d == 1.
    if (i + 1) % FLAGS.n_iter_d == 0:
        sample_z = torch.randn(batch_size, FLAGS.g_z_dim).to(device)
        loss_g, dfake_g = loss_func_g(netD, netG, sample_z, label)
        optim_G.zero_grad()
        loss_g.backward()
        if FLAGS.clip_value > 0:
            torch.nn.utils.clip_grad_norm_(netG.parameters(), FLAGS.clip_value)
        optim_G.step()

        logger.add("training_g", "loss", loss_g.item(), i + 1)
        logger.add("training_g", "dfake", dfake_g.item(), i + 1)
        # TODO: optionally maintain an averaged copy of the generator, e.g.
        # Torture.shortcuts.update_average(netG_T, netG, 0.999)