Example no. 1
                            nrow=FLAGS.bs_g // 10)
            total_t, correct_t, loss_t = evaluation.test_classifier(netC)
            total_tt, correct_tt, loss_tt = evaluation.test_classifier(netC_T)
        netC.train()
        netC_T.train()

        if FLAGS.c_step == "ramp_swa":
            # Refresh the SWA model's BatchNorm running statistics by
            # forwarding unlabeled batches in train mode, then switch back
            # to eval mode before testing.
            netC_swa.train()
            for _ in range(300):
                data_u, _ = next(itr)
                _ = netC_swa(data_u.to(device))
            netC_swa.eval()
            total_s, correct_s, loss_s = evaluation.test_classifier(netC_swa)
            logger.add("testing", "loss_s", loss_s.item(), i + 1)
            logger.add("testing", "accuracy_s", 100 * (correct_s / total_s),
                       i + 1)

        logger.add("testing", "loss", loss_t.item(), i + 1)
        logger.add("testing", "accuracy", 100 * (correct_t / total_t), i + 1)
        logger.add("testing", "loss_t", loss_tt.item(), i + 1)
        logger.add("testing", "accuracy_t", 100 * (correct_tt / total_tt),
                   i + 1)
        str_meg = logger_prefix.format(i + 1, max_iter,
                                       100 * ((i + 1) / max_iter))
        logger.log_info(str_meg, text_logger.info, ["testing"])

    if (i + 1) % FLAGS.save_every == 0:
        logger.save_stats("Model_stats.pkl")
        file_name = "model" + str(i + 1) + ".pt"
        checkpoint_io.save(file_name)
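The `evaluation.test_classifier` helper is not shown in these snippets; from the call sites it evaluates a network on the test set and returns a `(total, correct, loss)` triple whose loss supports `.item()`. A minimal sketch under that assumption (the explicit `test_loader` and `device` arguments are placeholders; the real helper presumably gets them from module state):

import torch
import torch.nn.functional as F

def test_classifier(net, test_loader, device="cuda"):
    """Hypothetical stand-in for evaluation.test_classifier: returns (total, correct, loss)."""
    net.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = net(data)
            loss_sum += F.cross_entropy(logits, target, reduction="sum")
            correct += (logits.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    return total, correct, loss_sum / total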
Example no. 2
        str_meg = "Iteration {}/{} ({:.0f}%), grad_norm {:e}, E_real {:e},"
        str_meg += " E_noise {:e}, tLoss {:e}, Normalized Loss {:e}, time {:4.3f}"
        text_logger.info(
            str_meg.format(
                i + 1,
                max_iter,
                100 * ((i + 1) / max_iter),
                grad_norm,
                E_real.item(),
                E_noise.item(),
                tloss.item(),
                nloss.item(),
                time_dur,
            )
        )
        time_dur = 0.0

        logger.add("training", "E_real", E_real.item(), i + 1)
        logger.add("training", "E_noise", E_noise.item(), i + 1)
        logger.add("training", "loss", tloss.item(), i + 1)
        del E_real
        del E_noise
        del nloss
        del tloss

    if (i + 1) % FLAGS.save_every == 0:
        text_logger.info("-" * 50)
        logger.save_stats("stats.pkl")
        file_name = "model" + str(i + 1) + ".pt"
        torch.save(netE.state_dict(), MODELS_FOLDER + "/" + file_name)
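Example no. 2 saves only the weights via `netE.state_dict()`, so restoring a checkpoint requires rebuilding the model first. A minimal sketch of loading it back, assuming `netE` has been re-instantiated with the same architecture and `step` is the iteration to restore (both are illustrative, not part of the snippet):

import torch

# Rebuild netE with the original architecture before this point.
file_name = "model" + str(step) + ".pt"
netE.load_state_dict(torch.load(MODELS_FOLDER + "/" + file_name, map_location="cpu"))
netE.eval()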
Example no. 3
        if FLAGS.c_step == "ramp_swa":
            netC_swa.train()
            for _ in range(300):
                data_u, _ = next(itr_u)
                _ = netC_swa(data_u.to(device))
            netC_swa.eval()
            total_s, correct_s, loss_s = evaluation.test_classifier(netC_swa)
            logger.add("testing", "loss_s", loss_s.item(), i + 1)
            logger.add("testing", "accuracy_s", 100 * (correct_s / total_s),
                       i + 1)

        total_t, correct_t, loss_t = evaluation.test_classifier(netC_T)
        logger.add("testing", "loss_t", loss_t.item(), i + 1)
        logger.add("testing", "accuracy_t", 100 * (correct_t / total_t), i + 1)

        total_t, correct_t, loss_t = evaluation.test_classifier(netC)
        logger.add("testing", "loss", loss_t.item(), i + 1)
        logger.add("testing", "accuracy", 100 * (correct_t / total_t), i + 1)

        prefix = logger_prefix.format(i + 1, max_iter,
                                      100 * ((i + 1) / max_iter))
        cats = ["testing"]
        logger.log_info(prefix, text_logger.info, cats=cats)
        netC.train()
        netC_T.train()

    if (i + 1) % FLAGS.save_every == 0:
        logger.save_stats("{:08d}.pkl".format(i))
        file_name = "model" + str(i + 1) + ".pt"
        checkpoint_io.save(file_name)
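The `ramp_swa` branch in Examples no. 1 and no. 3 refreshes the SWA model's BatchNorm running statistics by forwarding 300 unlabeled batches in train mode before evaluating. Recent PyTorch releases ship a utility that performs the same recalibration over a whole loader; a sketch of that alternative (the `loader_u` name is an assumption):

import torch
from torch.optim.swa_utils import update_bn

# Recompute BatchNorm running stats of the averaged model over the unlabeled loader.
update_bn(loader_u, netC_swa, device=device)
netC_swa.eval()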
Example no. 4
        sample_z = torch.randn(batch_size, FLAGS.g_z_dim).to(device)
        loss_g, dfake_g = loss_func_g(netD, netG, sample_z, label)
        optim_G.zero_grad()
        loss_g.backward()
        if FLAGS.clip_value > 0:
            torch.nn.utils.clip_grad_norm_(netG.parameters(), FLAGS.clip_value)
        optim_G.step()

        logger.add("training_g", "loss", loss_g.item(), i + 1)
        logger.add("training_g", "dfake", dfake_g.item(), i + 1)
        # Torture.shortcuts.update_average(netG_T, net_G, 0.999)

    if (i + 1) % print_interval == 0:
        prefix = logger_prefix.format(i + 1, max_iter,
                                      100 * ((i + 1) / max_iter))
        cats = ["training_d", "training_g"]
        logger.log_info(prefix, text_logger.info, cats=cats)

    if (i + 1) % image_interval == 0:
        with torch.no_grad():
            # Sample a 10x10 preview grid: 100 latents conditioned on the
            # first 10 labels, tiled 10 times (one label per column).
            sample_z = torch.randn(100, FLAGS.g_z_dim).to(device)
            tlabel = label[:10]
            tlabel = torch.cat([tlabel for _ in range(10)], 0)
            x_fake = netG(sample_z, tlabel)
            logger.add_imgs(x_fake, "img{:08d}".format(i + 1), nrow=10)

    if (i + 1) % FLAGS.save_every == 0:
        logger.save_stats("ModelStat.pkl")
        file_name = "model" + str(i + 1) + ".pt"
        checkpoint_io.save(file_name)
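The commented-out `Torture.shortcuts.update_average(netG_T, net_G, 0.999)` line in Example no. 4 suggests an exponential moving average of the generator weights is maintained elsewhere. A minimal sketch of what such a helper typically looks like (the argument order mirrors the commented call, but the body is an assumption, not the Torture implementation):

import torch

def update_average(net_target, net_source, beta=0.999):
    """Hypothetical EMA helper: target = beta * target + (1 - beta) * source."""
    with torch.no_grad():
        for p_t, p_s in zip(net_target.parameters(), net_source.parameters()):
            p_t.mul_(beta).add_(p_s, alpha=1.0 - beta)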