示例#1
0
        # Run unlabeled batches through netC_T in train mode with outputs
        # discarded -- presumably to refresh normalization running statistics
        # (e.g. BatchNorm) before evaluation; confirm. Running-stat buffers are
        # updated in train mode regardless of grad mode, so no autograd graph
        # is needed here.
        netC.train()
        netC_T.train()
        with torch.no_grad():
            for _ in range(int(FLAGS.n_labels / FLAGS.batch_size)):
                data_u, _ = next(itr)
                netC_T(data_u.to(device))
        netC.eval()
        netC_T.eval()
        with torch.no_grad():
            # 100 latent codes paired with labels 0..9 tiled 10 times, so
            # sample k is conditioned on class k % 10.
            sample_z = torch.randn(100, FLAGS.g_z_dim).to(device)
            tlabel = torch.arange(10, device=device)
            tlabel = torch.cat([tlabel for _ in range(10)], 0)
            x_fake = netG(sample_z, tlabel)
            # nrow=10 keeps one class per grid column for the fixed 10 x 10
            # layout above (FLAGS.bs_g // 10 only lines up when bs_g == 100).
            logger.add_imgs(x_fake, "img{:08d}".format(i + 1), nrow=10)
            # test_classifier returns (total, correct, loss) for each model.
            total_t, correct_t, loss_t = evaluation.test_classifier(netC)
            total_tt, correct_tt, loss_tt = evaluation.test_classifier(netC_T)
        netC.train()
        netC_T.train()

        if FLAGS.c_step == "ramp_swa":
            netC_swa.train()
            for _ in range(300):
                data_u, _ = itr.__next__()
                _ = netC_swa(data_u.to(device))
            netC_swa.eval()
            total_s, correct_s, loss_s = evaluation.test_classifier(netC_swa)
            logger.add("testing", "loss_s", loss_s.item(), i + 1)
            logger.add("testing", "accuracy_s", 100 * (correct_s / total_s),
示例#2
0
        # Generator update: draw a fresh batch of latent codes, compute the
        # generator loss via loss_func_g (which also receives netD and the
        # current `label`), then take one optimizer step. The order
        # zero_grad -> backward -> optional clipping -> step is required.
        sample_z = torch.randn(batch_size, FLAGS.g_z_dim).to(device)
        loss_g, dfake_g = loss_func_g(netD, netG, sample_z, label)
        optim_G.zero_grad()
        loss_g.backward()
        # Clip the total gradient norm of netG's parameters only when a
        # positive clip value is configured.
        if FLAGS.clip_value > 0:
            torch.nn.utils.clip_grad_norm_(netG.parameters(), FLAGS.clip_value)
        optim_G.step()

        # Log scalar values at the 1-based step i + 1. `dfake_g` is presumably
        # the discriminator's score on the generated batch -- confirm in
        # loss_func_g.
        logger.add("training_g", "loss", loss_g.item(), i + 1)
        logger.add("training_g", "dfake", dfake_g.item(), i + 1)
        # NOTE(review): disabled update of netG_T (presumably a moving average
        # with decay 0.999). `net_G` is not defined in the visible code -- the
        # generator is `netG` -- so fix the name before re-enabling.
        # Torture.shortcuts.update_average(netG_T, net_G, 0.999)

    if (i + 1) % print_interval == 0:
        # Third field is the completion percentage after iteration i (0-based),
        # i.e. 100 * (i + 1) / max_iter. The parentheses matter: the former
        # `(100 * i + 1) / max_iter` under-reported progress.
        prefix = logger_prefix.format(i + 1, max_iter,
                                      100 * (i + 1) / max_iter)
        cats = ["training_d", "training_g"]
        logger.log_info(prefix, text_logger.info, cats=cats)

    if (i + 1) % image_interval == 0:
        # Every image_interval steps, render 100 samples without gradients:
        # the first 10 entries of `label` are tiled 10 times along dim 0 and
        # paired one-to-one with fresh latent codes.
        with torch.no_grad():
            sample_z = torch.randn(100, FLAGS.g_z_dim).to(device)
            tlabel = torch.cat([label[:10]] * 10, 0)
            x_fake = netG(sample_z, tlabel)
            logger.add_imgs(x_fake, f"img{i + 1:08d}", nrow=10)

    if (i + 1) % FLAGS.save_every == 0:
        # Every save_every steps, first dump the logger's statistics, then
        # write a checkpoint named after the 1-based iteration count.
        logger.save_stats("ModelStat.pkl")
        file_name = f"model{i + 1}.pt"
        checkpoint_io.save(file_name)