def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)  

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True, num_workers=4, drop_last=True)
    testloader = DataLoader(testset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    for epoch in range(args.epoch):
        support_x, support_y, meta_x, meta_y = train_data.__next__()
        support_x, support_y, meta_x, meta_y = support_x.to(Param.device), support_y.to(Param.device), meta_x.to(Param.device), meta_y.to(Param.device)
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value = 10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if(epoch % 2000 == 999):
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = test_data.__next__()
                support_x, support_y, qx, qy = support_x.to(Param.device), support_y.to(Param.device), qx.to(Param.device), qy.to(Param.device)
                temp = maml_clone(support_x, support_y, qx, qy, meta_train = False)
                if(ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim = 0)
            ans = ans.mean(dim = 0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(maml.state_dict(), Param.out_path + 'net_'+ str(epoch) + '_' + str(best_acc) + '.pkl') 
            del maml_clone
            print(str(epoch) + ': '+str(ans))
            with open(Param.out_path+'test.json','w') as f:
                json.dump(test_result,f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
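Example #1 (and several later snippets) pulls batches from an `inf_get` wrapper that is not shown here. A minimal sketch, assuming it simply cycles the DataLoader forever so `__next__()` never raises StopIteration:

def inf_get(loader):
    # restart the underlying DataLoader whenever it is exhausted,
    # yielding batches indefinitely
    while True:
        for batch in loader:
            yield batch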
Example #2
    def get_inception_score(self, iter_time):
        if np.mod(iter_time, self.flags.inception_freq) == 0:
            sample_size = 100
            all_samples = []
            for _ in range(int(self.num_examples_IS/sample_size)):
                imgs = self.model.sample_imgs(sample_size=sample_size)
                all_samples.append(imgs[0])

            all_samples = np.concatenate(all_samples, axis=0)
            all_samples = ((all_samples + 1.) * 255. / 2.).astype(np.uint8)

            mean_IS, std_IS = get_inception_score(list(all_samples), self.flags)
            # print('Inception score iter: {}, IS: {}'.format(self.iter_time, mean_IS))

            plot.plot('inception score', mean_IS)
            plot.flush(self.log_out_dir)  # write logs
            plot.tick()
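All of these snippets log scalars through a small `plot` utility (in the style of the tflib scalar logger). Its definition is not included; a hypothetical minimal sketch of the interface as Example #2 uses it, where plot(name, value) buffers a value, tick() advances the step counter, and flush(log_dir) averages, prints and appends the buffered values to a log file:

import collections
import os

_since_last_flush = collections.defaultdict(list)
_iter = 0

def plot(name, value):
    _since_last_flush[name].append(value)

def tick():
    global _iter
    _iter += 1

def flush(log_dir='.'):
    entries = ['{} {:.5f}'.format(name, sum(vals) / len(vals))
               for name, vals in _since_last_flush.items()]
    line = 'iter {}\t{}'.format(_iter, '\t'.join(entries))
    print(line)
    with open(os.path.join(log_dir, 'log.txt'), 'a') as f:
        f.write(line + '\n')
    _since_last_flush.clear()

(Other examples call a variant that takes the output directory in plot() instead of flush(); the buffering idea is the same.)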
Example #3
    def train(self):
        data_dir = os.getcwd() + '/64_crop'
        noise = np.random.uniform(-1, 1, [self.batch_size, 128])
        for epoch in range(self.EPOCH):
            iters = int(156191 // self.batch_size)
            for idx in range(iters):
                start_t = time.time()
                real_img = load_data(data_dir, self.batch_size, idx=idx)
                train_img = np.array(real_img)
                g_opt = [self.opt_g, self.g_loss]
                d_opt = [self.opt_d, self.d_loss]
                feed = {self.x: train_img, self.z: noise}
                # update the Discriminator once
                _, loss_d = self.sess.run(d_opt, feed_dict=feed)
                # update the Generator twice
                _, loss_g = self.sess.run(g_opt, feed_dict=feed)
                _, loss_g = self.sess.run(g_opt, feed_dict=feed)

                print(
                    "[%3f][epoch:%2d/%2d][iter:%4d/%4d],loss_d:%5f,loss_g:%5f"
                    % (time.time() - start_t, epoch, self.EPOCH, idx, iters,
                       loss_d, loss_g))
                plot.plot('d_loss', loss_d)
                plot.plot('g_loss', loss_g)

                if ((idx + 1) % 100) == 0:  # flush plot picture every 100 iters
                    plot.flush()
                plot.tick()

                if (idx + 1) % 400 == 0:
                    print('image saving ...........')
                    img = self.sess.run(self.G_img, feed_dict=feed)
                    save_img.save_images(
                        img,
                        os.getcwd() + '/gen_picture/' +
                        'sample{}_{}.jpg'.format(epoch, (idx + 1) // 400))
                    print('images save done')
            print('model saving ..........')
            path = os.getcwd() + '/model_saved'
            save_path = os.path.join(path, 'model_ckpt')
            self.saver.save(self.sess, save_path=save_path)
            print('model saved .............')
Example #4
        # _, lossV, _trainY, _predict = sess.run([discOptimizer, loss, trainY, predict], feed_dict = {
        # 	train_status: True
        # 	})
        _, lossV, _trainY, _predict = sess.run(
            [discOptimizer, loss, batch_label, predict])
        _label = np.argmax(_trainY, axis=1)
        _accuracy = np.mean(_label == _predict)
        plot.plot('train cross entropy', lossV)
        plot.plot('train accuracy', _accuracy)

        if iteration % 50 == 49:
            dev_accuracy = []
            dev_cross_entropy = []
            for eval_idx in range(EVAL_ITER_NUM):
                # eval_loss_v, _trainY, _predict = sess.run([loss, trainY, predict], feed_dict ={train_status: False})
                eval_loss_v, _trainY, _predict = sess.run(
                    [loss, batch_eval_label, predict_eval])
                _label = np.argmax(_trainY, axis=1)
                _accuracy = np.mean(_label == _predict)
                dev_accuracy.append(_accuracy)
                dev_cross_entropy.append(eval_loss_v)
            best = max(best, np.mean(dev_accuracy))
            plot.plot('dev accuracy', np.mean(dev_accuracy))
            plot.plot('dev cross entropy', np.mean(dev_cross_entropy))

        if (iteration < 5) or (iteration % 50 == 49):
            plot.flush()

        plot.tick()
    print("best score " + str(best))
def main():
    one = torch.FloatTensor([1.0]).cuda()
    mone = torch.FloatTensor([-1.0]).cuda()
    ones_31 = torch.zeros(Param.batch_size, 1, 31, 31).fill_(1.0).type(torch.FloatTensor).cuda()
    mones_31 = torch.zeros(Param.batch_size, 1, 31, 31).fill_(-1.0).type(torch.FloatTensor).cuda()
    zeros_31 = torch.zeros(Param.batch_size, 1, 31, 31).type(torch.FloatTensor).cuda()

    mask = torch.ones(Param.batch_size, 3, 128, 128)
    mask[:, :, 32:32 + 64, 32:32 + 64] = torch.zeros(Param.batch_size, 3, 64, 64)
    mask = Variable(mask.type(torch.FloatTensor).cuda(), requires_grad=False)

    h0 = torch.zeros(Param.batch_size, Param.unet_channel * 8 * 2).cuda()
    c0 = torch.zeros(Param.batch_size, Param.unet_channel * 8 * 2).cuda()

    netG = Net_G().cuda()
    netD = Net_D().cuda()

    #netG.load_state_dict(torch.load('/data/haoran/unet-gan/gan_lstm_4/netG_59999.pickle'))
    #netD.load_state_dict(torch.load('/data/haoran/unet-gan/gan_lstm_4/netD_59999.pickle'))
    #netG = nn.DataParallel(netG, device_ids=[0, 1])
    #netD = nn.DataParallel(netD, device_ids=[0, 1])

    opt_G = optim.Adam(netG.parameters(), lr=Param.G_learning_rate, betas = (0.5,0.999), weight_decay=Param.weight_decay)
    opt_D = optim.Adam(netD.parameters(), lr=Param.D_learning_rate, betas = (0.5,0.999), weight_decay=Param.weight_decay)

    #trainset = ParisData('paris.csv', RandCrop())
    trainset = ParisData('paris.csv', RandCrop())
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=Param.batch_size, shuffle=True, num_workers=2,
                                               drop_last=True)
    train_data = inf_get(train_loader)

    epoch = 0
    maxepoch = 200000
    bce_loss = nn.BCELoss()

    while (epoch < maxepoch):
        start_time = time.time()
        # step D
        for p in netD.parameters():
            p.requires_grad = True
        #for D_step in range(Param.n_critic):
        ###################################
        ######   D  #######################
        ###################################
        real_data = train_data.next()
        # print(real_data)
        real_data = real_data.cuda()
        real_data_64 = destroy(real_data, 64)
        real_data_48 = destroy(real_data, 48)
        real_data_32 = destroy(real_data, 32)
        real_data_16 = destroy(real_data, 16)

        real_data_64 = Variable(real_data_64)
        real_data_48 = Variable(real_data_48)
        real_data_32 = Variable(real_data_32)
        real_data_16 = Variable(real_data_16)
        real_data_0 = Variable(real_data)

        netD.zero_grad()
        p_real_48, p_real_32, p_real_16, p_real_0 = netD(real_data_48, real_data_32, real_data_16, real_data_0)
        target = Variable(ones_31)

        #print(p_real_48.size())
        real_loss_48 = bce_loss(p_real_48, target)
        real_loss_32 = bce_loss(p_real_32, target)
        real_loss_16 = bce_loss(p_real_16, target)
        real_loss_0 = bce_loss(p_real_0, target)

        fake_data_48, fake_data_32, fake_data_16, fake_data_0 = netG(real_data_64, real_data_48, real_data_32,
                                                                        real_data_16, Variable(h0), Variable(c0))

        p_fake_48, p_fake_32, p_fake_16, p_fake_0 = netD(Variable(fake_data_48.data), Variable(fake_data_32.data), Variable(fake_data_16.data), Variable(fake_data_0.data))
        target = Variable(zeros_31)
        fake_loss_48 = bce_loss(p_fake_48, target)
        fake_loss_32 = bce_loss(p_fake_32, target)
        fake_loss_16 = bce_loss(p_fake_16, target)
        fake_loss_0 = bce_loss(p_fake_0, target)

        gan_loss = real_loss_48 + real_loss_32 + real_loss_16 + real_loss_0 + fake_loss_48 + fake_loss_32 + fake_loss_16 + fake_loss_0
        '''
        real_loss_48.backward(retain_graph=True)
        real_loss_32.backward(retain_graph=True)
        real_loss_16.backward(retain_graph=True)
        real_loss_0.backward(retain_graph=True)

        fake_loss_48.backward(retain_graph=True)
        fake_loss_32.backward(retain_graph=True)
        fake_loss_16.backward(retain_graph=True)
        fake_loss_0.backward(retain_graph=True)
        '''
        gan_loss.backward(retain_graph=True)

        D_cost = fake_loss_48.data[0] + fake_loss_32.data[0] + fake_loss_16.data[0] + fake_loss_0.data[0]
        D_cost += real_loss_48.data[0] + real_loss_32.data[0] + real_loss_16.data[0] + real_loss_0.data[0]

        opt_D.step()
        ##################
        ## step G ########
        ##################
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()

        #real_data = train_data.next()
        #real_data = real_data.cuda()
        #real_data_64 = destroy(real_data, 64)
        #real_data_48 = destroy(real_data, 48)
        #real_data_32 = destroy(real_data, 32)
        #real_data_16 = destroy(real_data, 16)

        #real_data_64 = Variable(real_data_64)
        #real_data_48 = Variable(real_data_48)
        #real_data_32 = Variable(real_data_32)
        #real_data_16 = Variable(real_data_16)
        #real_data_0 = Variable(real_data)

        #fake_data_48, fake_data_32, fake_data_16, fake_data_0 = netG(real_data_64, real_data_48, real_data_32, real_data_16, Variable(h0), Variable(c0))

        l1_loss = ((fake_data_48 - real_data_48).abs()).mean() + \
                  ((fake_data_32 - real_data_32).abs()).mean() + \
                  ((fake_data_16 - real_data_16).abs()).mean() + \
                  ((fake_data_0 - real_data_0).abs()).mean()

        tv_loss = cal_tv(fake_data_48) + cal_tv(fake_data_32) + cal_tv(fake_data_16) + cal_tv(fake_data_0)
        tv_loss = tv_loss * Param.tv_weight

        p_fake_48, p_fake_32, p_fake_16, p_fake_0 = netD(fake_data_48, fake_data_32, fake_data_16, fake_data_0)
        #target = Variable(ones_31)
        target = Variable(zeros_31)
        fake_loss_48 = bce_loss(p_fake_48, target)
        fake_loss_32 = bce_loss(p_fake_32, target)
        fake_loss_16 = bce_loss(p_fake_16, target)
        fake_loss_0 = bce_loss(p_fake_0, target)

        gan_loss = fake_loss_48 + fake_loss_32 + fake_loss_16 + fake_loss_0
        gan_loss = - gan_loss * Param.gan_weight

        gan_loss.backward(retain_graph=True)
        l1_loss.backward(one, retain_graph=True)
        tv_loss.backward(one, retain_graph=True)

        G_cost = fake_loss_48.data[0] + fake_loss_32.data[0] + fake_loss_16.data[0] + fake_loss_0.data[0]
        G_cost += l1_loss.data[0]
        opt_G.step()
        print('epoch: ' + str(epoch) + ' l1_loss: ' + str(l1_loss.data[0]))

        # Write logs and save samples
        #print(D_cost.size())
        #print(G_cost.size())
        os.chdir(Param.out_path)
        plot.plot('train D cost', D_cost)
        plot.plot('time', time.time() - start_time)
        plot.plot('train G cost', G_cost)
        plot.plot('train l1 loss', l1_loss.data.cpu().numpy())

        if epoch % 100 == 99:
            #  real_data = train_data.next()
            out_image = torch.cat(
                (
                    fake_data_48.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    fake_data_32.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    fake_data_16.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    fake_data_0.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_64.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_48.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_32.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_16.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_0.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size)
                ),
                1
            )
            # out_image.transpose(0,1,3,4,2)
            save_image_plus(out_image.cpu().numpy(), Param.out_path + 'train_image_{}.jpg'.format(epoch))

        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()

        if epoch % 10000 == 9999:
            torch.save(netD.state_dict(),Param.out_path+ 'netD_{}.pickle'.format(epoch))
            torch.save(netG.state_dict(),Param.out_path+ 'netG_{}.pickle'.format(epoch))
            #opt_D.param_groups[0]['lr'] /= 10.0
            #opt_G.param_groups[0]['lr'] /= 10.0
        epoch += 1
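Example #5 leans on two helpers that are not shown, `destroy` and `cal_tv`. Hypothetical sketches, assuming `destroy` blanks a centered square hole of the given size (consistent with the 64x64 hole cut out by `mask` above) and `cal_tv` is a plain total-variation penalty:

def destroy(images, hole_size):
    # zero out a centered hole_size x hole_size square (assumed behaviour)
    out = images.clone()
    _, _, h, w = out.size()
    top, left = (h - hole_size) // 2, (w - hole_size) // 2
    out[:, :, top:top + hole_size, left:left + hole_size] = 0.0
    return out

def cal_tv(img):
    # total-variation penalty: mean absolute difference between
    # neighbouring pixels along height and width
    return ((img[:, :, 1:, :] - img[:, :, :-1, :]).abs().mean() +
            (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().mean())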
Example #6
def train():
    args = load_args()
    train_gen, test_gen = load_data(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    if args.use_spectral_norm:
        optimizerD = optim.Adam(filter(lambda p: p.requires_grad,
            netD.parameters()), lr=2e-4, betas=(0.0,0.9))
    else:
        optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=2e-4, betas=(0.5, 0.9))

    schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=0.99)
    schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=0.99) 
    schedulerE = optim.lr_scheduler.ExponentialLR(optimizerE, gamma=0.99)
    
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    iteration = 0 
    for epoch in range(args.epochs):
        for i, (data, targets) in enumerate(train_gen):
            start_time = time.time()
            """ Update AutoEncoder """
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            netE.zero_grad()
            real_data_v = autograd.Variable(data).cuda()
            real_data_v = real_data_v.view(args.batch_size, -1)
            encoding = netE(real_data_v)
            fake = netG(encoding)
            ae_loss = ae_criterion(fake, real_data_v)
            ae_loss.backward(one)
            optimizerE.step()
            optimizerG.step()
            
            """ Update D network """
            for p in netD.parameters():  
                p.requires_grad = True 
            for i in range(5):
                real_data_v = autograd.Variable(data).cuda()
                # train with real data
                netD.zero_grad()
                D_real = netD(real_data_v)
                D_real = D_real.mean()
                D_real.backward(mone)
                # train with fake data
                noise = torch.randn(args.batch_size, args.dim).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(netG(noisev).data)
                inputv = fake
                D_fake = netD(inputv)
                D_fake = D_fake.mean()
                D_fake.backward(one)

                # train with gradient penalty 
                gradient_penalty = ops.calc_gradient_penalty(args,
                        netD, real_data_v.data, fake.data)
                gradient_penalty.backward()

                D_cost = D_fake - D_real + gradient_penalty
                Wasserstein_D = D_real - D_fake
                optimizerD.step()

            # Update generator network (GAN)
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise)
            fake = netG(noisev)
            G = netD(fake)
            G = G.mean()
            G.backward(mone)
            G_cost = -G
            optimizerG.step() 

            schedulerD.step()
            schedulerG.step()
            schedulerE.step()
            # Write logs and save samples 
            save_dir = './plots/'+args.dataset
            plot.plot(save_dir, '/disc cost', D_cost.cpu().data.numpy())
            plot.plot(save_dir, '/gen cost', G_cost.cpu().data.numpy())
            plot.plot(save_dir, '/w1 distance', Wasserstein_D.cpu().data.numpy())
            plot.plot(save_dir, '/ae cost', ae_loss.data.cpu().numpy())
            
            # Calculate dev loss and generate samples every 100 iters
            if iteration % 100 == 99:
                dev_disc_costs = []
                for i, (images, targets) in enumerate(test_gen):
                    imgs_v = autograd.Variable(images, volatile=True).cuda()
                    D = netD(imgs_v)
                    _dev_disc_cost = -D.mean().cpu().data.numpy()
                    dev_disc_costs.append(_dev_disc_cost)
                plot.plot(save_dir ,'/dev disc cost', np.mean(dev_disc_costs))
                utils.generate_image(iteration, netG, save_dir, args)
                # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)

            # Save logs every 100 iters 
            if (iteration < 5) or (iteration % 100 == 99):
                plot.flush()
            plot.tick()
            if iteration % 100 == 0:
                utils.save_model(netG, optimizerG, iteration,
                        'models/{}/G_{}'.format(args.dataset, iteration))
                utils.save_model(netD, optimizerD, iteration, 
                        'models/{}/D_{}'.format(args.dataset, iteration))
            iteration += 1
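Example #6 (and #8, #10, #11) relies on a `calc_gradient_penalty` helper whose body is not shown; the `ops.calc_gradient_penalty(args, ...)` signature used above also differs from the one in Example #11. A minimal WGAN-GP style sketch of the underlying computation, written against current PyTorch, with any lambda scaling applied by the caller:

import torch

def calc_gradient_penalty(netD, real_data, fake_data):
    # WGAN-GP: penalise the critic's gradient norm on random interpolates
    # between real and fake samples (Gulrajani et al., 2017)
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)),
                       device=real_data.device)
    interpolates = (alpha * real_data +
                    (1 - alpha) * fake_data).requires_grad_(True)
    d_out = netD(interpolates)
    grads = torch.autograd.grad(outputs=d_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True, retain_graph=True)[0]
    grads = grads.view(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()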
Example #7
    def train(self):
        run_flags = self.flags
        sess = self.sess
        sw = self.sw

        hr_img = tf.placeholder(tf.float32, [
            None, run_flags.input_width * run_flags.scale,
            run_flags.input_length * run_flags.scale, 3
        ])  #128*128*3 as default
        lr_img = tf.placeholder(
            tf.float32,
            [None, run_flags.input_width, run_flags.input_length, 3
             ])  #64*64*3 as default
        myModel = Model(locals())

        out_gen = Model.generative(myModel, lr_img)

        real_out_dis = Model.discriminative(myModel, hr_img)

        fake_out_dis = Model.discriminative(myModel, out_gen, reuse=True)

        cost_gen, cost_dis, var_train, var_gen, var_dis = \
        Model.costs_and_vars(myModel, hr_img, out_gen, real_out_dis, fake_out_dis)

        # cost_gen, cost_dis, var_train, var_gen, var_dis = \
        # Model.wgan_loss(myModel, hr_img, out_gen, real_out_dis, fake_out_dis)

        optimizer_gen = tf.train.AdamOptimizer(learning_rate=run_flags.lr_gen). \
            minimize(cost_gen, var_list=var_gen)
        optimizer_dis = tf.train.AdamOptimizer(learning_rate=run_flags.lr_dis). \
            minimize(cost_dis, var_list=var_dis)

        init = tf.global_variables_initializer()

        with sess:
            sess.run(init)

            saver = tf.train.Saver()

            if not exists('models'):
                makedirs('models')

            passed_iters = 0

            for epoch in range(1, run_flags.epochs + 1):
                print('Epoch:', str(epoch))

                for batch in BatchGenerator(run_flags.batch_size,
                                            self.datasize):
                    batch_hr = self.dataset[batch] / 255.0
                    batch_lr = array([imresize(img, size=(run_flags.input_width, run_flags.input_length, 3)) \
                                      for img in batch_hr])

                    _, gc, dc = sess.run([optimizer_gen, cost_gen, cost_dis],
                                         feed_dict={
                                             hr_img: batch_hr,
                                             lr_img: batch_lr
                                         })
                    sess.run([optimizer_dis],
                             feed_dict={
                                 hr_img: batch_hr,
                                 lr_img: batch_lr
                             })

                    passed_iters += 1

                    if passed_iters % run_flags.sample_iter == 0:
                        print('Passed iterations=%d, Generative cost=%.9f, Discriminative cost=%.9f' % \
                              (passed_iters, gc, dc))
                        plot.plot('train_dis_cost_gan', abs(dc))
                        plot.plot('train_gen_cost_gan', abs(gc))

                    if (passed_iters < 5) or (passed_iters % 100 == 99):
                        plot.flush()

                    plot.tick()

                if run_flags.checkpoint_iter and epoch % run_flags.checkpoint_iter == 0:
                    saver.save(
                        sess,
                        '/'.join(['models', run_flags.model, run_flags.model]))

                    print('Model \'%s\' saved in: \'%s/\'' \
                          % (run_flags.model, '/'.join(['models', run_flags.model])))

            print('Optimization finished.')

            saver.save(sess,
                       '/'.join(['models', run_flags.model, run_flags.model]))

            print('Model \'%s\' saved in: \'%s/\'' \
                  % (run_flags.model, '/'.join(['models', run_flags.model])))
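Example #7 iterates over a `BatchGenerator(batch_size, datasize)` that is not defined here. A plausible sketch, assuming it yields random index batches over the dataset once per epoch (a trailing partial batch is dropped), which is compatible with the `self.dataset[batch]` fancy indexing above:

import numpy as np

def BatchGenerator(batch_size, datasize):
    # shuffle all indices, then yield them in consecutive batches
    order = np.random.permutation(datasize)
    for start in range(0, datasize - batch_size + 1, batch_size):
        yield order[start:start + batch_size]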
Example #8
def train():
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    # optimizerD, optimizerE and ae_criterion are also used by the
    # autoencoder / adversarial phase further below
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    vgg_scale = 0.0784  # 1/12.75
    mse_criterion = nn.MSELoss()
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda(0)
    mone = (one * -1).cuda(0)

    gen = utils.inf_train_gen(train_gen)
    """ train SRResNet with MSE """
    for iteration in range(1, 20000):
        start_time = time.time()
        # for p in netD.parameters():
        #     p.requires_grad = False
        for p in netE.parameters():
            p.requires_grad = False

        netG.zero_grad()
        _data = next(gen)
        real_data, vgg_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        #Perceptual loss
        #vgg_data_v = autograd.Variable(vgg_data)
        #vgg_features_real = netE(vgg_data_v)
        fake = netG(real_data_v)
        #vgg_features_fake = netE(fake)
        #diff = vgg_features_fake - vgg_features_real.cuda(0)
        #perceptual_loss = vgg_scale * ((diff.pow(2)).sum(3).mean())  # mean(sum(square(diff)))
        #perceptual_loss.backward(one)
        mse_loss = mse_criterion(fake, real_data_v)
        mse_loss.backward(one)
        optimizerG.step()

        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/mse cost SRResNet',
                  np.round(mse_loss.data.cpu().numpy(), 4))
        if iteration % 50 == 49:
            utils.generate_sr_image(iteration, netG, save_dir, args,
                                    real_data_v)
        if (iteration < 5) or (iteration % 50 == 49):
            plot.flush()
        plot.tick()
        if iteration % 5000 == 0:
            torch.save(netG.state_dict(), './SRResNet_PL.pt')

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """

        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """

        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise,
                                       volatile=True)  # totally freeze netG
            # instead of noise, use image
            fake = autograd.Variable(netG(real_data_v).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator network (GAN)
        # noise = torch.randn(args.batch_size, args.dim).cuda()
        # noisev = autograd.Variable(noise)
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        # again use real data instead of noise
        fake = netG(real_data_v)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # Write logs and save samples

        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(),
                                                   4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(),
                                                  4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        # plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(), 4))

        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))

            # utils.generate_image(iteration, netG, save_dir, args)
            # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
            utils.generate_sr_image(iteration, netG, save_dir, args,
                                    real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
Example #9
    def train(self):
        """The function for the meta-train phase."""
        # Set the meta-train log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        timer = Timer()
        # Generate the labels for train set of the episodes
        label_shot = torch.arange(self.args.way).repeat(self.args.shot).to(
            self.args.device).type(torch.long)
        label_query = torch.arange(self.args.way).repeat(
            self.args.train_query).to(self.args.device).type(torch.long)

        # Start meta-train
        for epoch in range(1, self.args.max_epoch + 1):
            ################### train #############################
            self.model.train()
            train_acc_averager = Averager()
            self.train_loader = DataLoader(dataset=self.trainset,
                                           batch_sampler=self.train_sampler,
                                           num_workers=self.args.num_work,
                                           pin_memory=True)
            train_data = self.inf_get(self.train_loader)
            self.val_loader = DataLoader(dataset=self.valset,
                                         batch_sampler=self.val_sampler,
                                         num_workers=self.args.num_work,
                                         pin_memory=True)
            val_data = self.inf_get(self.val_loader)
            acc_log = []
            tqdm_gen = tqdm.tqdm(
                range(self.args.num_batch // self.args.meta_batch))
            for i in tqdm_gen:
                data_list = []
                label_shot_list = []
                for _ in range(self.args.meta_batch):
                    data_list.append(train_data.__next__().to(
                        self.args.device))
                    label_shot_list.append(label_shot)
                pass
                data_list = torch.stack(data_list, dim=0)
                label_shot_list = torch.stack(label_shot_list,
                                              dim=0)  # shot-label
                out = self.model(data_list, label_shot_list)
                meta_loss = 0
                for inner_id in range(self.args.meta_batch):
                    meta_loss += F.cross_entropy(out[inner_id], label_query)
                    cur_acc = count_acc(out[inner_id], label_query)
                    acc_log.append(
                        (i * self.args.meta_batch + inner_id, cur_acc))
                    train_acc_averager.add(cur_acc)
                    tqdm_gen.set_description('Epoch {}, Acc={:.4f}'.format(
                        epoch, cur_acc))
                pass
                meta_loss /= self.args.meta_batch
                plot.plot('meta_loss', meta_loss.item())
                self.optimizer.zero_grad()
                meta_loss.backward()
                self.optimizer.step()
                plot.tick()
            pass
            train_acc_averager = train_acc_averager.item()
            trlog['train_acc'].append(train_acc_averager)
            plot.plot('train_acc_averager', train_acc_averager)
            plot.flush(self.args.save_path)

            ####################### eval ##########################
            #self.model.eval()    ###############################################################################
            val_acc_averager = Averager()
            for i in tqdm.tqdm(
                    range(self.args.val_batch // self.args.meta_batch)):
                data_list = []
                label_shot_list = []
                for _ in range(self.args.meta_batch):
                    data_list.append(val_data.__next__().to(self.args.device))
                    label_shot_list.append(label_shot)
                pass
                data_list = torch.stack(data_list, dim=0)
                label_shot_list = torch.stack(label_shot_list, dim=0)
                out = self.model(data_list, label_shot_list).detach()
                for inner_id in range(self.args.meta_batch):
                    cur_acc = count_acc(out[inner_id], label_query)
                    val_acc_averager.add(cur_acc)
                pass
            pass
            val_acc_averager = val_acc_averager.item()
            trlog['val_acc'].append(val_acc_averager)
            print('Epoch {}, Val, Acc={:.4f}'.format(epoch, val_acc_averager))

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')

            with open(osp.join(self.args.save_path, 'trlog.json'), 'w') as f:
                json.dump(trlog, f)
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))
        pass
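Example #9 accumulates accuracies through an `Averager` whose definition is not included. A minimal sketch matching the `add()` / `item()` calls used above, assuming it keeps a running mean:

class Averager:
    """Running mean of scalar values (assumed behaviour of the helper above)."""

    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v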
Example #10
def train():
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    np.set_printoptions(precision=4)
    netG, netD, netE = load_models(args)

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    gen = utils.inf_train_gen(train_gen)

    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """

        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise,
                                       volatile=True)  # totally freeze netG
            fake = autograd.Variable(netG(noisev).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator network (GAN)
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # Write logs and save samples

        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(),
                                                   4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(),
                                                  4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(),
                                                 4))

        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))

            # utils.generate_image(iteration, netG, save_dir, args)
            utils.generate_ae_image(iteration, netE, netG, save_dir, args,
                                    real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
Example #11
def main():

    mask = torch.ones(Param.batch_size, 3, 128, 128).to(Param.device)
    mask[:, :, 32:32 + 64, 32:32 + 64] = torch.zeros(Param.batch_size, 3, 64,
                                                     64).to(Param.device)

    netG = Net_G().to(Param.device)
    netD = Net_D().to(Param.device)

    netG = torch.nn.DataParallel(netG)
    netD = torch.nn.DataParallel(netD)

    #opt_G = optim.Adam(netG.parameters(), lr=Param.G_learning_rate, betas = (0.,0.9), weight_decay=Param.weight_decay)
    #opt_D = optim.Adam(netD.parameters(), lr=Param.D_learning_rate, betas = (0.,0.9), weight_decay=Param.weight_decay)
    opt_G = optim.SGD(netG.parameters(),
                      lr=Param.G_learning_rate,
                      weight_decay=Param.weight_decay)
    opt_D = optim.SGD(netD.parameters(),
                      lr=Param.D_learning_rate,
                      weight_decay=Param.weight_decay)
    #opt_D = optim.RMSprop(netD.parameters(), lr=Param.D_learning_rate, weight_decay=Param.weight_decay, eps=1e-6)
    #opt_gp_D = optim.RMSprop(netD.parameters(), lr=Param.gp_learning_rate, weight_decay=Param.weight_decay, eps=1e-6)
    #opt_gp_D = optim.SGD(netD.parameters(), lr=Param.gp_learning_rate, weight_decay=Param.weight_decay)

    #trainset = ParisData('paris.csv', RandCrop())
    trainset = ImageNetData('imagenet.csv', RandCrop())
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=Param.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               drop_last=True)
    train_data = inf_get(train_loader)

    epoch = 0
    maxepoch = 2000000
    #bce_loss = nn.BCELoss()

    ##########################
    ###########################

    while (epoch < maxepoch):
        start_time = time.time()
        ##########
        # step D #
        ##########
        for p in netD.parameters():
            p.requires_grad = True

        for D_step in range(Param.n_critic):
            real_data = next(train_data)
            real_data = real_data.to(Param.device)
            destroy_data_64 = destroy(real_data, 64)
            fake_data = netG(destroy_data_64).detach()

            netD.zero_grad()
            D_fake = netD(fake_data).mean()
            D_real = netD(real_data).mean()
            gradient_penalty = calc_gradient_penalty(netD, real_data.detach(),
                                                     fake_data.detach())
            D_loss = D_fake - D_real
            D_tot = D_loss + gradient_penalty * Param.gp_lamada
            D_tot.backward()
            opt_D.step()

        ##########
        # step G #
        ##########
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()

        real_data = next(train_data)
        real_data = real_data.to(Param.device)
        destroy_data_64 = destroy(real_data, 64)
        fake_data = netG(destroy_data_64)

        l2_loss = ((fake_data - real_data)**2).mean()
        #l2_loss = (((fake_data - real_data) * mask) ** 2).mean()

        D_fake = netD(fake_data).mean()
        G_loss = l2_loss * Param.l2_weight - D_fake * Param.gan_weight
        G_loss.backward()
        opt_G.step()

        # Write logs and save samples
        os.chdir(Param.out_path)
        plot.plot('train-D-cost', D_loss.item())
        plot.plot('time', time.time() - start_time)
        plot.plot('train-G-cost', G_loss.item())
        plot.plot('l2-cost', l2_loss.item())
        plot.plot('G-fake-cost', -D_fake.item())
        plot.plot('gp', gradient_penalty.item())

        if epoch % 100 == 99:
            #  real_data = train_data.next()
            out_image = torch.cat(
                (fake_data.data.view(Param.batch_size, 1, 3, Param.image_size,
                                     Param.image_size),
                 destroy_data_64.data.view(Param.batch_size, 1, 3,
                                           Param.image_size, Param.image_size),
                 real_data.data.view(Param.batch_size, 1, 3, Param.image_size,
                                     Param.image_size)), 1)
            # out_image.transpose(0,1,3,4,2)
            save_image_plus(
                out_image.cpu().numpy(),
                Param.out_path + 'train_image_{}.jpg'.format(epoch))

        if (epoch < 50) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()

        if epoch % 5000 == 4999:
            torch.save(netD.state_dict(),
                       Param.out_path + 'netD_{}.pickle'.format(epoch))
            torch.save(netG.state_dict(),
                       Param.out_path + 'netG_{}.pickle'.format(epoch))
            # opt_D.param_groups[0]['lr'] /= 10.0
            # opt_G.param_groups[0]['lr'] /= 10.0
        epoch += 1
Example #12
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # Load pkl dataset
    args_data = {}
    args_data['x_dim'] = "84,84,3"
    args_data['ratio'] = 1.0
    args_data['seed'] = 222
    loader_train = dataset_mini(600, 100, 'train', args_data)
    #loader_val   = dataset_mini(600, 100, 'val', args_data)
    loader_test = dataset_mini(600, 100, 'test', args_data)

    loader_train.load_data_pkl()
    #loader_val.load_data_pkl()
    loader_test.load_data_pkl()

    for epoch in range(args.epoch):
        support_x, support_y, meta_x, meta_y = get_data(loader_train)
        support_x, support_y, meta_x, meta_y = support_x.to(
            Param.device), support_y.to(Param.device), meta_x.to(
                Param.device), meta_y.to(Param.device)
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if (epoch % 2000 == 999):
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = get_data(loader_test)
                support_x, support_y, qx, qy = support_x.to(
                    Param.device), support_y.to(Param.device), qx.to(
                        Param.device), qy.to(Param.device)
                temp = maml_clone(support_x,
                                  support_y,
                                  qx,
                                  qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
Example #13
    time_start = time.time()

    env = GridWorldKey(max_time=3000,
                       n_keys=2,
                       normlize_obs=True,
                       use_nearby_obs=True)

    pi = Policy(name='policy',
                obs_dim=env.observation_space.shape[0] + 1,
                act_dim=env.action_space.shape[0],
                n_ways=0,
                batch_size=20,
                log_path='logfile/single')

    for i in range(100):

        traj = rollout(env, pi)
        trajs = [traj]
        for each in trajs:
            plot.plot('ep_r', each[-1])
            plot.tick('ep_r')
            obs = each[0].reshape(-1, pi.obs_dim)
            acts = each[1].reshape(-1, pi.act_dim)
            rews = each[2]
            pi.update(obs, acts, rews)
            plot.flush('logfile/single', verbose=True)

    time_end = time.time()
    print('==============\ntotal time cost = {}'.format(
        (time_end - time_start)))
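The GridWorld snippet in Example #13 assumes a `rollout(env, pi)` helper. A hypothetical sketch of the trajectory layout it implies, with traj[0] observations, traj[1] actions, traj[2] rewards and traj[-1] the episode return; `pi.sample_action` stands in for whatever action-sampling method the Policy class actually exposes:

import numpy as np

def rollout(env, pi):
    # play one episode with the current policy and return
    # (observations, actions, rewards, episode_return)
    obs_list, act_list, rew_list = [], [], []
    obs = env.reset()
    done = False
    while not done:
        act = pi.sample_action(obs)      # hypothetical Policy method
        next_obs, rew, done, _ = env.step(act)
        obs_list.append(obs)
        act_list.append(act)
        rew_list.append(rew)
        obs = next_obs
    return (np.asarray(obs_list), np.asarray(act_list),
            np.asarray(rew_list), float(np.sum(rew_list)))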
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet

        trainset = MiniImagenet(Param.root,
                                mode='train',
                                n_way=args.n_way,
                                k_shot=args.k_spt,
                                k_query=args.k_qry,
                                resize=args.imgsz)
        testset = MiniImagenet(Param.root,
                               mode='test',
                               n_way=args.n_way,
                               k_shot=args.k_spt,
                               k_query=args.k_qry,
                               resize=args.imgsz)
        trainloader = DataLoader(trainset,
                                 batch_size=args.task_num,
                                 shuffle=True,
                                 num_workers=4,
                                 worker_init_fn=worker_init_fn,
                                 drop_last=True)
        testloader = DataLoader(testset,
                                batch_size=1,
                                shuffle=True,
                                num_workers=1,
                                worker_init_fn=worker_init_fn,
                                drop_last=True)
        train_data = inf_get(trainloader)
        test_data = inf_get(testloader)

    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_train = dataset_mini(600, 100, 'train', args_data)
        #loader_val   = dataset_mini(600, 100, 'val', args_data)
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_train.load_data_pkl()
        #loader_val.load_data_pkl()
        loader_test.load_data_pkl()

    for epoch in range(args.epoch):
        np.random.seed()
        if args.loader in [0, 1]:
            support_x, support_y, meta_x, meta_y = train_data.__next__()
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)
        elif args.loader == 2:
            support_x, support_y, meta_x, meta_y = get_data(loader_train)
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)

        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if (epoch % 2500 == 0):
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                if args.loader in [0, 1]:
                    support_x, support_y, qx, qy = test_data.__next__()
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)
                elif args.loader == 2:
                    support_x, support_y, qx, qy = get_data(loader_test)
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)

                temp = maml_clone(support_x,
                                  support_y,
                                  qx,
                                  qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 0):
            plot.flush()
        plot.tick()
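The DataLoaders in this last example pass a `worker_init_fn` that is not shown. A common pattern that matches its use here, as a sketch: reseed NumPy in each DataLoader worker from the per-worker torch seed, so episode sampling differs across workers instead of repeating the parent process state.

import numpy as np
import torch

def worker_init_fn(worker_id):
    # derive a distinct, valid NumPy seed for each DataLoader worker
    np.random.seed((torch.initial_seed() + worker_id) % (2 ** 32))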