Example 1
def pretrain(run_name, netD, netG, motion_length, claim_length, embedding_dim,
             hidden_dim_G, hidden_dim_D, lam, batch_size, epochs, num_words,
             train_data_dir, test_data_dir):
    # netG = GeneratorEncDec(batch_size, num_words, claim_length, hidden_dim_G, embedding_dim)
    nllloss = batchNLLLoss()
    bceloss = nn.BCELoss()

    if use_cuda:
        nllloss = nllloss.cuda(gpu)
        bceloss = bceloss.cuda(gpu)
        netG = netG.cuda(gpu)
        netD = netD.cuda(gpu)

    optimizerG = optim.Adam(netG.parameters(), lr=1e-4)  #, betas=(0.5, 0.9))
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4)  #, betas=(0.5, 0.9))

    # one = torch.FloatTensor([1])
    # mone = one * -1
    # if use_cuda:
    #     one = one.cuda(gpu)
    #     mone = mone.cuda(gpu)

    num_train_data = len(os.listdir(train_data_dir))
    num_test_data = len(os.listdir(test_data_dir))

    start_epoch = 0
    if os.path.isfile(pretrain_netD):
        print("=> loading checkpoint '{}'".format(pretrain_netD))
        checkpoint = torch.load(pretrain_netD)
        start_epoch = checkpoint['epoch']
        netD.load_state_dict(checkpoint['D_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            pretrain_netD, checkpoint['epoch']))

    ###########################
    # (1) Pretrain D network  #
    ###########################

    pretraining_generator = FakeDataGenerator(lexicon_count=lexicon_count,
                                              motion_length=motion_length,
                                              claim_length=claim_length,
                                              num_data=num_train_data,
                                              data_dir=train_data_dir,
                                              batch_size=batch_size,
                                              shuffle=True)
    pretraining_test_generator = FakeDataGenerator(lexicon_count=lexicon_count,
                                                   motion_length=motion_length,
                                                   claim_length=claim_length,
                                                   num_data=num_test_data,
                                                   data_dir=test_data_dir,
                                                   batch_size=batch_size,
                                                   shuffle=True)
    netD.batch_size = netD.batch_size * 2

    for iteration in xrange(start_epoch, epochs + 1):
        print(" >>> Epoch : %d/%d" % (iteration + 1, epochs))
        start_time = time.time()

        training_steps = num_train_data // batch_size
        for training_step in range(training_steps):
            llprint(
                "\rPretraining Discriminator : Training step Discriminator %d/%d"
                % (training_step + 1, training_steps))
            netD.zero_grad()

            _claim, _label = pretraining_generator.generate().next()
            _claim = np.asarray(_claim)
            _label = np.asarray(_label)
            claims = torch.from_numpy(_claim)
            labels = torch.from_numpy(_label)

            if use_cuda:
                claims = claims.cuda(gpu)
                labels = labels.cuda(gpu)
            claims_v = autograd.Variable(claims)
            labels_v = autograd.Variable(labels).unsqueeze(1)

            _, D_out = netD(claims_v)

            D_loss = bceloss(D_out, labels_v)

            D_loss.backward()
            optimizerD.step()

        testing_steps = num_test_data // batch_size
        D_loss_test = 0
        for testing_step in range(testing_steps):
            llprint("\Testing step Discriminator %d/%d" %
                    (testing_step + 1, testing_steps))
            netD.zero_grad()

            _claim, _label = pretraining_test_generator.generate().next()
            _claim = np.asarray(_claim)
            _label = np.asarray(_label)
            claims = torch.from_numpy(_claim)
            labels = torch.from_numpy(_label)

            if use_cuda:
                claims = claims.cuda(gpu)
                labels = labels.cuda(gpu)
            claims_v = autograd.Variable(claims)
            labels_v = autograd.Variable(labels).unsqueeze(1)

            _, D_out = netD(claims_v)

            D_loss_test += bceloss(D_out, labels_v).cpu().data.numpy()[0]

        print('Iter: {}; D_loss: {:.4}'.format(iteration + 1,
                                               D_loss_test / testing_steps))
        with open('pretrain_D_loss_logger.txt', 'a') as out:
            out.write(
                str(iteration + 1) + ',' + str(D_loss.cpu().data.numpy()[0]) +
                '\n')

        if iteration % 10 == 0:
            save_pretrain_checkpoint_D({
                'epoch': iteration + 1,
                'D_state_dict': netD.state_dict(),
                'optimizerD': optimizerD.state_dict(),
            })

    ###########################
    # (2) Pretrain G network  #
    ###########################

    start_epoch = 0
    netD.batch_size /= 2
    if os.path.isfile(pretrain_netG):
        print("=> loading checkpoint '{}'".format(pretrain_netG))
        checkpoint = torch.load(pretrain_netG)
        start_epoch = checkpoint['epoch']
        netG.load_state_dict(checkpoint['G_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            pretrain_netG, checkpoint['epoch']))

    training_generator = DataGenerator(lexicon_count=lexicon_count,
                                       motion_length=motion_length,
                                       claim_length=claim_length,
                                       num_data=num_train_data,
                                       data_dir=train_data_dir,
                                       batch_size=batch_size,
                                       shuffle=True)
    test_generator = DataGenerator(lexicon_count=lexicon_count,
                                   motion_length=motion_length,
                                   claim_length=claim_length,
                                   num_data=num_test_data,
                                   data_dir=test_data_dir,
                                   batch_size=batch_size,
                                   shuffle=True)

    for iteration in xrange(start_epoch, epochs + 1):
        print(" >>> Epoch : %d/%d" % (iteration + 1, epochs))
        start_time = time.time()

        print("\nPretraining Generator :")
        training_steps = num_train_data // batch_size
        for training_step in range(training_steps):
            llprint("\rTraining step Generator %d/%d" %
                    (training_step + 1, training_steps))
            netG.zero_grad()

            _motion, _claim = training_generator.generate().next()
            _motion = np.asarray(_motion)
            _claim = np.asarray(_claim)
            real_motion = torch.LongTensor(_motion.tolist())
            real_claim_G = torch.from_numpy(np.argmax(_claim, 2))
            real_claim_D = torch.from_numpy(_claim).type('torch.LongTensor')

            if use_cuda:
                real_motion = real_motion.cuda(gpu)
                real_claim_G = real_claim_G.cuda(gpu)
                real_claim_D = real_claim_D.cuda(gpu)
            real_motion_v = autograd.Variable(real_motion)
            real_claim_G_v = autograd.Variable(real_claim_G)
            real_claim_D_v = autograd.Variable(real_claim_D)

            fake = netG(real_motion_v, real_claim_G_v)

            G_loss_1 = nllloss(fake, real_claim_G_v, claim_length) / batch_size

            G_loss_1.backward()
            G_cost = G_loss_1
            optimizerG.step()

        print('Iter: {}; G_loss: {:.4}'.format(iteration + 1,
                                               G_cost.cpu().data.numpy()[0]))
        with open('pretrain_G_loss_logger.txt', 'a') as out:
            out.write(
                str(iteration + 1) + ',' + str(G_cost.cpu().data.numpy()[0]) +
                '\n')

        if iteration % 100 == 0:
            result_text = open(
                "pretrain_" +
                datetime.datetime.now().strftime("%Y-%m-%d:%H:%M:%S") + ".txt",
                "w")
            testing_steps = num_test_data // batch_size
            for test_step in range(testing_steps):
                llprint("\rTesting step %d/%d" %
                        (test_step + 1, testing_steps))

                motion_test, claim_test = test_generator.generate().next()
                _motion_test = np.asarray(motion_test)
                _claim_test = np.asarray(claim_test)
                real_motion_test = torch.LongTensor(_motion_test.tolist())
                real_claim_test = torch.from_numpy(np.argmax(_claim_test, 2))
                if use_cuda:
                    real_motion_test = real_motion_test.cuda(gpu)
                    real_claim_test = real_claim_test.cuda(gpu)
                real_motion_test_v = autograd.Variable(real_motion_test)
                real_claim_test_v = autograd.Variable(real_claim_test)

                for mot, cla in zip(
                        decode_motion(motion_test),
                        decode(netG(real_motion_test_v, real_claim_test_v,
                                    0.0))):
                    result_text.write("Motion: %s\n" % mot)
                    result_text.write("Generated Claim: %s\n\n" % cla)
        if iteration % 10 == 0:
            save_pretrain_checkpoint_G({
                'epoch': iteration + 1,
                'G_state_dict': netG.state_dict(),
                'optimizerG': optimizerG.state_dict(),
            })

    return netD, netG
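
Example 1 (and the examples below) lean on a few helpers that live elsewhere in the module and are not shown here: `llprint` for in-place progress output and `save_pretrain_checkpoint_D` / `save_pretrain_checkpoint_G` for checkpointing. A minimal sketch of what they might look like; the behaviour and the checkpoint filenames are assumptions, not taken from the original code.

import sys

import torch

def llprint(message):
    # Assumed behaviour: write the progress string as-is (callers prepend "\r"
    # to rewind the cursor) and flush so it shows up immediately.
    sys.stdout.write(message)
    sys.stdout.flush()

def save_pretrain_checkpoint_D(state, filename='pretrain_netD.pth.tar'):
    # Assumed behaviour: serialize the dict of epoch / state_dicts with torch.save.
    # The filename is a placeholder; the original path is held in pretrain_netD.
    torch.save(state, filename)

def save_pretrain_checkpoint_G(state, filename='pretrain_netG.pth.tar'):
    # Same idea for the generator pretraining state.
    torch.save(state, filename)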
Example 2
def train(run_name, seq2seq, motion_length, claim_length, embedding_dim,
          hidden_dim, batch_size, epochs, num_words, train_data_dir,
          test_data_dir):
    nllloss = batchNLLLoss()
    # bceloss = nn.BCELoss()
    optimizerG = optim.Adam(seq2seq.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if use_cuda:
        nllloss = nllloss.cuda(gpu)
        seq2seq = seq2seq.cuda(gpu)

    num_train_data = len(os.listdir(train_data_dir))
    num_test_data = len(os.listdir(test_data_dir))

    start_epoch = 0
    if os.path.isfile('last_checkpoint_seq2seq.pth.tar'):
        print("=> loading checkpoint '{}'".format(
            'last_checkpoint_seq2seq.pth.tar'))
        checkpoint = torch.load('last_checkpoint_seq2seq.pth.tar')
        start_epoch = checkpoint['epoch']
        seq2seq.load_state_dict(checkpoint['state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            'last_checkpoint_seq2seq.pth.tar', checkpoint['epoch']))

    training_generator = DataGenerator(lexicon_count=lexicon_count,
                                       motion_length=motion_length,
                                       claim_length=claim_length,
                                       num_data=num_train_data,
                                       data_dir=train_data_dir,
                                       batch_size=batch_size,
                                       shuffle=True)
    test_generator = DataGenerator(lexicon_count=lexicon_count,
                                   motion_length=motion_length,
                                   claim_length=claim_length,
                                   num_data=num_test_data,
                                   data_dir=test_data_dir,
                                   batch_size=batch_size,
                                   shuffle=True)

    for iteration in xrange(start_epoch, epochs + 1):
        print(" >>> Epoch : %d/%d" % (iteration + 1, epochs))
        start_time = time.time()

        print("\nTraining :")
        training_steps = num_train_data // batch_size
        for training_step in range(training_steps):
            llprint("\rTraining step %d/%d" %
                    (training_step + 1, training_steps))
            seq2seq.zero_grad()

            _motion, _claim = training_generator.generate().next()
            _motion = np.asarray(_motion)
            _claim = np.asarray(_claim)
            real_motion = torch.LongTensor(_motion.tolist())
            real_claim_G = torch.from_numpy(np.argmax(_claim, 2))
            real_claim_D = torch.from_numpy(_claim).type('torch.LongTensor')

            if use_cuda:
                real_motion = real_motion.cuda(gpu)
                real_claim_G = real_claim_G.cuda(gpu)
                real_claim_D = real_claim_D.cuda(gpu)
            real_motion_v = autograd.Variable(real_motion)
            real_claim_G_v = autograd.Variable(real_claim_G)
            real_claim_D_v = autograd.Variable(real_claim_D)

            fake = seq2seq(real_motion_v, real_claim_G_v)

            G_loss = nllloss(fake, real_claim_D_v, claim_length)

            G_loss.backward()
            G_cost = G_loss
            optimizerG.step()

        print('Iter: {}; G_loss: {:.4}'.format(iteration + 1,
                                               G_cost.cpu().data.numpy()[0]))
        with open('pretrain_G_loss_logger.txt', 'a') as out:
            out.write(
                str(iteration + 1) + ',' + str(G_cost.cpu().data.numpy()[0]) +
                '\n')

        if iteration % 20 == 0:
            result_text = open(
                "pretrain_" +
                datetime.datetime.now().strftime("%Y-%m-%d:%H:%M:%S") + ".txt",
                "w")
            testing_steps = num_test_data // batch_size
            for test_step in range(testing_steps):
                llprint("\rTesting step %d/%d" %
                        (test_step + 1, testing_steps))

                motion_test, claim_test = test_generator.generate().next()
                _motion_test = np.asarray(motion_test)
                _claim_test = np.asarray(claim_test)
                real_motion_test = torch.LongTensor(_motion_test.tolist())
                real_claim_test = torch.from_numpy(np.argmax(_claim_test, 2))
                if use_cuda:
                    real_motion_test = real_motion_test.cuda(gpu)
                    real_claim_test = real_claim_test.cuda(gpu)
                real_motion_test_v = autograd.Variable(real_motion_test)
                real_claim_test_v = autograd.Variable(real_claim_test)

                for mot, cla in zip(
                        decode_motion(motion_test),
                        decode(
                            seq2seq(real_motion_test_v, real_claim_test_v,
                                    0.0))):
                    result_text.write("Motion: %s\n" % mot)
                    result_text.write("Generated Claim: %s\n\n" % cla)
        if iteration % 10 == 0:
            save_pretrain_checkpoint_G({
                'epoch': iteration + 1,
                'state_dict': seq2seq.state_dict(),
                'optimizer': optimizerG.state_dict(),
            })

    return seq2seq
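
`batchNLLLoss` is likewise imported from elsewhere; it is called as `nllloss(output, target, claim_length)` and, in Example 1, divided by `batch_size`. A minimal sketch under the assumption that `output` holds per-step log-probabilities of shape (batch, claim_length, vocab) and `target` holds word indices of shape (batch, claim_length); Example 2 passes the one-hot claim tensor instead, so the real implementation may accept either form.

import torch.nn as nn

class batchNLLLoss(nn.Module):
    def __init__(self):
        super(batchNLLLoss, self).__init__()
        # Sum (rather than average) so callers can normalise by batch size themselves.
        self.criterion = nn.NLLLoss(size_average=False)

    def forward(self, output, target, seq_len):
        # Accumulate the negative log-likelihood of every time step of the claim.
        loss = 0
        for t in range(seq_len):
            loss = loss + self.criterion(output[:, t, :], target[:, t])
        return loss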
Example 3
def train(run_name, netG, netD, motion_length, claim_length, embedding_dim,
          hidden_dim_G, hidden_dim_D, lam, batch_size, epochs, iteration_d,
          num_words, train_data_dir, test_data_dir):
    jsdloss = JSDLoss()
    criterion = nn.BCELoss()
    nllloss = batchNLLLoss()
    print(netG)
    print(netD)

    if use_cuda:
        jsdloss = jsdloss.cuda(gpu)
        nllloss = nllloss.cuda(gpu)
        criterion = criterion.cuda(gpu)
        netD = netD.cuda(gpu)
        netG = netG.cuda(gpu)

    optimizerD = optim.Adam(netD.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=5e-4, betas=(0.5, 0.9))

    one = torch.FloatTensor([1])
    mone = one * -1
    if use_cuda:
        one = one.cuda(gpu)
        mone = mone.cuda(gpu)

    num_train_data = len(os.listdir(train_data_dir))
    num_test_data = len(os.listdir(test_data_dir))

    training_generator = DataGenerator(lexicon_count=lexicon_count,
                                       motion_length=motion_length,
                                       claim_length=claim_length,
                                       num_data=num_train_data,
                                       data_dir=train_data_dir,
                                       batch_size=batch_size,
                                       shuffle=True)
    test_generator = DataGenerator(lexicon_count=lexicon_count,
                                   motion_length=motion_length,
                                   claim_length=claim_length,
                                   num_data=num_test_data,
                                   data_dir=test_data_dir,
                                   batch_size=batch_size,
                                   shuffle=True)

    start_epoch = 0
    if os.path.isfile(train_GAN):
        print("=> loading checkpoint '{}'".format(train_GAN))
        checkpoint = torch.load(train_GAN)
        start_epoch = checkpoint['epoch']
        netD.load_state_dict(checkpoint['D_state_dict'])
        netG.load_state_dict(checkpoint['G_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            train_GAN, checkpoint['epoch']))

    for iteration in xrange(start_epoch, epochs + 1):
        print(" >>> Epoch : %d/%d" % (iteration + 1, epochs))
        start_time = time.time()
        ############################
        # (1) Update D network
        ###########################
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update

        for iter_d in xrange(iteration_d):
            training_steps = num_train_data // batch_size
            for training_step in range(training_steps):
                llprint(
                    "\rTraining Discriminator : %d/%d ; Training step Discriminator : %d/%d"
                    % (iter_d + 1, iteration_d, training_step + 1,
                       training_steps))
                _motion, _claim = training_generator.generate().next()
                _claim = np.asarray(_claim)
                real_claim_G = torch.from_numpy(np.argmax(_claim, 2))
                real_claim_D = torch.from_numpy(_claim)
                real_labels = torch.ones(real_claim_G.size(0), 1)

                if use_cuda:
                    real_claim_G = real_claim_G.cuda(gpu)
                    real_claim_D = real_claim_D.cuda(gpu)
                    real_labels = real_labels.cuda(gpu)
                real_claim_G_v = autograd.Variable(real_claim_G)
                real_claim_D_v = autograd.Variable(real_claim_D)
                real_labels = autograd.Variable(real_labels)

                # nn.utils.clip_grad_norm(netD.parameters(), 5.0)
                for p in netD.parameters():
                    p.data.clamp_(-5.0, 5.0)
                netD.zero_grad()

                # train with real
                f_real, D_real = netD(real_claim_D_v)
                D_real_loss = criterion(D_real, real_labels)

                # train with motion
                _motion = np.asarray(_motion)
                real_motion = torch.LongTensor(_motion.tolist())
                fake_labels = torch.zeros(real_motion.size(0), 1)
                if use_cuda:
                    real_motion = real_motion.cuda(gpu)
                    fake_labels = fake_labels.cuda(gpu)
                real_motion_G_v = autograd.Variable(real_motion, volatile=True)
                real_motion_D_v = autograd.Variable(real_motion)
                fake_labels = autograd.Variable(fake_labels)

                fake = autograd.Variable(
                    netG(real_motion_G_v, real_claim_G_v, 0.0).data)
                inputv = fake
                f_fake, D_fake = netD(inputv)
                f_motion, _ = netD(real_motion_D_v)

                D_fake_loss = criterion(D_fake, fake_labels)

                # f_fake = torch.mean(f_fake, dim = 0)
                # f_real = torch.mean(f_real, dim = 0)
                # f_motion = torch.mean(f_motion, dim = 0)
                # print(f_fake)
                # print(f_real)
                # print(f_motion)
                D_loss_fake = torch.mean((f_fake - f_real)**2)
                D_conditional_loss_fake = torch.mean((f_fake - f_motion)**2)
                D_conditional_loss_real = torch.mean((f_real - f_motion)**2)

                # D_jsd_loss = jsdloss(f_real, f_fake)
                # D_conditional_jsd_loss = jsdloss(f_motion, f_fake)

                D_loss = D_real_loss + D_fake_loss - D_loss_fake - D_conditional_loss_fake + D_conditional_loss_real
                D_loss.backward()
                optimizerD.step()

        ############################
        # (2) Update G network
        ###########################
        for p in netD.parameters():
            p.requires_grad = False  # freeze D so no gradients are computed for it during the G update

        print("\nTraining Generator :")
        iteration_g = 1  # number of generator passes per epoch (assumed; not provided as an argument)
        for iter_g in xrange(iteration_g):
            for training_step in range(num_train_data // batch_size):
                llprint("\rTraining step Generator %d/%d" %
                        (training_step + 1, training_steps))

                # nn.utils.clip_grad_norm(netG.parameters(), 5.0)
                for p in netG.parameters():
                    p.data.clamp_(-5.0, 5.0)
                netG.zero_grad()

                _motion, _claim = training_generator.generate().next()
                _motion = np.asarray(_motion)
                _claim = np.asarray(_claim)
                real_motion = torch.LongTensor(_motion.tolist())
                real_claim_G = torch.from_numpy(np.argmax(_claim, 2))
                real_claim_D = torch.from_numpy(_claim)

                if use_cuda:
                    real_motion = real_motion.cuda(gpu)
                    real_claim_G = real_claim_G.cuda(gpu)
                    real_claim_D = real_claim_D.cuda(gpu)
                real_motion_v = autograd.Variable(real_motion)
                real_claim_G_v = autograd.Variable(real_claim_G)
                real_claim_D_v = autograd.Variable(real_claim_D)

                fake = netG(real_motion_v, real_claim_G_v)
                f_fake, G = netD(fake)
                f_real, _ = netD(real_claim_D_v)

                f_motion, _ = netD(real_motion_v)

                # f_fake = torch.mean(f_fake, dim = 0)
                # f_real = torch.mean(f_real, dim = 0)
                # f_motion = torch.mean(f_motion, dim = 0)
                G_loss_fake = torch.mean((f_fake - f_real)**2)
                G_conditional_loss_fake = torch.mean((f_fake - f_motion)**2)

                # G_fake_loss = jsdloss(f_real, f_fake)
                # G_conditional_loss = jsdloss(f_motion, f_fake)
                # G_loss = criterion(G, real_labels)

                # G_loss_total = G_loss + G_loss_van
                G_loss = G_loss_fake + G_conditional_loss_fake

                G_loss.backward()
                optimizerG.step()

        print('Iter: {}; D_loss: {:.4}; G_loss: {:.4}'.format(
            iteration,
            D_loss.cpu().data.numpy()[0],
            G_loss.cpu().data.numpy()[0]))

        if iteration % 5 == 0:
            result_text = open(
                'generated/' +
                datetime.datetime.now().strftime("%Y-%m-%d:%H:%M:%S") + ".txt",
                "w")
            testing_steps = num_test_data // batch_size
            G_NLL_loss_total = 0
            for test_step in range(testing_steps):
                llprint("\rTesting step %d/%d" %
                        (test_step + 1, testing_steps))

                motion_test, claim_test = test_generator.generate().next()
                _motion_test = np.asarray(motion_test)
                _claim_test = np.asarray(claim_test)
                real_motion_test = torch.LongTensor(_motion_test.tolist())
                real_claim_test = torch.from_numpy(np.argmax(_claim_test, 2))
                if use_cuda:
                    real_motion_test = real_motion_test.cuda(gpu)
                    real_claim_test = real_claim_test.cuda(gpu)
                real_motion_test_v = autograd.Variable(real_motion_test)
                real_claim_test_v = autograd.Variable(real_claim_test)

                fake = netG(real_motion_test_v, real_claim_test_v, 0.0)
                G_NLL_loss = nllloss(fake, real_claim_test_v,
                                     claim_length) / batch_size
                G_NLL_loss_total += G_NLL_loss.cpu().data.numpy()[0]

                for mot, cla in zip(decode_motion(motion_test), decode(fake)):
                    result_text.write("Motion: %s\n" % mot)
                    result_text.write("Generated Claim: %s\n\n" % cla)
            with open('loss_logger.txt', 'a') as out:
                out.write(
                    str(iteration + 1) + ',' +
                    str(D_loss.cpu().data.numpy()[0]) + ',' +
                    str(G_loss.cpu().data.numpy()[0]) + ',' +
                    str(G_NLL_loss_total / testing_steps) + '\n')
        if iteration % 2 == 0:
            save_checkpoint({
                'epoch': iteration + 1,
                'G_state_dict': netG.state_dict(),
                'D_state_dict': netD.state_dict(),
                'optimizerD': optimizerD.state_dict(),
                'optimizerG': optimizerG.state_dict(),
            })
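
`decode` and `decode_motion` (used above to dump generated samples to the result files) are also defined elsewhere. A minimal sketch, assuming a module-level index-to-word table (`index_to_word` below is a hypothetical name) and a generator output of shape (batch, claim_length, vocab):

import numpy as np
import torch

index_to_word = {0: '<pad>', 1: '<unk>'}  # hypothetical lookup table; the real lexicon lives elsewhere

def decode_motion(motions):
    # Map a batch of index sequences back to whitespace-joined strings.
    sentences = []
    for seq in np.asarray(motions):
        sentences.append(' '.join(index_to_word.get(int(idx), '<unk>') for idx in seq))
    return sentences

def decode(output):
    # Take the argmax over the vocabulary at every step, then reuse decode_motion.
    if isinstance(output, torch.autograd.Variable):
        output = output.data
    return decode_motion(output.cpu().numpy().argmax(axis=2))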
Example 4
def train(run_name, netG, netD, motion_length, claim_length, embedding_dim,
          hidden_dim_G, hidden_dim_D, lam, batch_size, epochs, iteration_d,
          num_words, train_data_dir, test_data_dir):
    # netG = GeneratorEncDec(batch_size, num_words, claim_length, hidden_dim_G, embedding_dim)
    jsdloss = JSDLoss()
    cosine = nn.CosineSimilarity()
    print(netG)
    print(netD)

    if use_cuda:
        jsdloss = jsdloss.cuda(gpu)
        cosine = cosine.cuda(gpu)
        netD = netD.cuda(gpu)
        netG = netG.cuda(gpu)

    optimizerD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=5e-5, betas=(0.5, 0.9))

    one = torch.FloatTensor([1])
    mone = one * -1
    if use_cuda:
        one = one.cuda(gpu)
        mone = mone.cuda(gpu)

    num_train_data = len(os.listdir(train_data_dir))
    num_test_data = len(os.listdir(test_data_dir))

    training_generator = DataGenerator(lexicon_count=lexicon_count,
                                       motion_length=motion_length,
                                       claim_length=claim_length,
                                       num_data=num_train_data,
                                       data_dir=train_data_dir,
                                       batch_size=batch_size,
                                       shuffle=True)
    test_generator = DataGenerator(lexicon_count=lexicon_count,
                                   motion_length=motion_length,
                                   claim_length=claim_length,
                                   num_data=num_test_data,
                                   data_dir=test_data_dir,
                                   batch_size=batch_size,
                                   shuffle=True)

    start_epoch = 0
    if os.path.isfile('last_checkpoint_wgan.pth.tar'):
        print("=> loading checkpoint '{}'".format(
            'last_checkpoint_wgan.pth.tar'))
        checkpoint = torch.load('last_checkpoint_wgan.pth.tar')
        start_epoch = checkpoint['epoch']
        netD.load_state_dict(checkpoint['D_state_dict'])
        netG.load_state_dict(checkpoint['G_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            'last_checkpoint_wgan.pth.tar', checkpoint['epoch']))

    for iteration in xrange(start_epoch, epochs + 1):
        print(" >>> Epoch : %d/%d" % (iteration + 1, epochs))
        start_time = time.time()
        ############################
        # (1) Update D network
        ###########################
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update

        for iter_d in xrange(iteration_d):
            training_steps = num_train_data // batch_size
            for training_step in range(training_steps):
                llprint(
                    "\rTraining Discriminator : %d/%d ; Training step Discriminator : %d/%d"
                    % (iter_d + 1, iteration_d, training_step + 1,
                       training_steps))
                _motion, _claim = training_generator.generate().next()
                _motion = np.asarray(_motion)
                _claim = np.asarray(_claim)
                real_motion = torch.LongTensor(_motion.tolist())

                real_claim_G = torch.from_numpy(np.argmax(_claim, 2))
                real_claim_D = torch.from_numpy(_claim)

                if use_cuda:
                    real_motion = real_motion.cuda(gpu)
                    real_claim_G = real_claim_G.cuda(gpu)
                    real_claim_D = real_claim_D.cuda(gpu)
                real_motion_v = autograd.Variable(real_motion, volatile=True)
                real_claim_G_v = autograd.Variable(real_claim_G)
                real_claim_D_v = autograd.Variable(real_claim_D)

                netD.zero_grad()

                # train with real
                f_real, D_real = netD(real_claim_D_v)
                D_real = D_real.mean()
                D_real.backward(mone, retain_graph=True)

                # train with motion
                fake = autograd.Variable(
                    netG(real_motion_v, real_claim_G_v, 0.0).data)
                inputv = fake
                f_fake, D_fake = netD(inputv)
                D_fake = D_fake.mean()
                D_fake.backward(one, retain_graph=True)

                #D_jsd = jsdloss(batch_size, f_fake, f_real).mean()
                #D_jsd.backward(mone)

                # train with gradient penalty
                gradient_penalty = calc_gradient_penalty(
                    batch_size, lam, netD, real_claim_D_v.data,
                    fake.data)  # + cosine(f_fake, f_real).mean()
                # print gradient_penalty
                gradient_penalty.backward()

                D_cost = D_fake - D_real + gradient_penalty
                Wasserstein_D = D_real - D_fake
                # Clip after the gradients exist, just before the update.
                nn.utils.clip_grad_norm(netD.parameters(), 5)
                optimizerD.step()

        ############################
        # (2) Update G network
        ###########################
        for p in netD.parameters():
            p.requires_grad = False  # to avoid computation

        print("\nTraining Generator :")
        for training_step in range(num_train_data // batch_size):
            llprint("\rTraining step Generator %d/%d" %
                    (training_step + 1, training_steps))

            netG.zero_grad()

            _motion, _claim = training_generator.generate().next()
            _motion = np.asarray(_motion)
            _claim = np.asarray(_claim)
            real_motion = torch.LongTensor(_motion.tolist())
            real_claim_G = torch.from_numpy(np.argmax(_claim, 2))
            real_claim_D = torch.from_numpy(_claim)

            if use_cuda:
                real_motion = real_motion.cuda(gpu)
                real_claim_G = real_claim_G.cuda(gpu)
                real_claim_D = real_claim_D.cuda(gpu)
            real_motion_v = autograd.Variable(real_motion)
            real_claim_G_v = autograd.Variable(real_claim_G)
            real_claim_D_v = autograd.Variable(real_claim_D)

            fake = netG(real_motion_v, real_claim_G_v)
            f_fake, G = netD(fake)
            f_real, _ = netD(real_claim_D_v)

            G_loss = G.mean()
            #G_loss = G.mean() + jsdloss(batch_size, f_fake, f_real).mean()

            G_loss.backward(mone)
            G_cost = -G_loss
            # Clip after backward so the computed gradients are what gets clipped.
            nn.utils.clip_grad_norm(netG.parameters(), 5)
            optimizerG.step()

        print('Iter: {}; D_loss: {:.4}; G_loss: {:.4}; Wasserstein_D: {:.4}'.
              format(iteration,
                     D_cost.cpu().data.numpy()[0],
                     G_cost.cpu().data.numpy()[0],
                     Wasserstein_D.cpu().data.numpy()[0]))
        with open('loss_logger.txt', 'a') as out:
            out.write(
                str(iteration + 1) + ',' + str(D_cost.cpu().data.numpy()[0]) +
                ',' + str(G_cost.cpu().data.numpy()[0]) + ',' +
                str(Wasserstein_D.cpu().data.numpy()[0]) + '\n')

        if iteration % 5 == 0:
            result_text = open(
                datetime.datetime.now().strftime("%Y-%m-%d:%H:%M:%S") + ".txt",
                "w")
            testing_steps = num_test_data // batch_size
            for test_step in range(testing_steps):
                llprint("\rTesting step %d/%d" %
                        (test_step + 1, testing_steps))

                motion_test, claim_test = test_generator.generate().next()
                _motion_test = np.asarray(motion_test)
                _claim_test = np.asarray(claim_test)
                real_motion_test = torch.LongTensor(_motion_test.tolist())
                real_claim_test = torch.from_numpy(np.argmax(_claim_test, 2))
                if use_cuda:
                    real_motion_test = real_motion_test.cuda(gpu)
                    real_claim_test = real_claim_test.cuda(gpu)
                real_motion_test_v = autograd.Variable(real_motion_test)
                real_claim_test_v = autograd.Variable(real_claim_test)

                for mot, cla in zip(
                        decode_motion(motion_test),
                        decode(netG(real_motion_test_v, real_claim_test_v,
                                    0.0))):
                    result_text.write("Motion: %s\n" % mot)
                    result_text.write("Generated Claim: %s\n\n" % cla)
        if iteration % 2 == 0:
            save_checkpoint({
                'epoch': iteration + 1,
                'G_state_dict': netG.state_dict(),
                'D_state_dict': netD.state_dict(),
                'optimizerD': optimizerD.state_dict(),
                'optimizerG': optimizerG.state_dict(),
            })
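
`calc_gradient_penalty` is imported from elsewhere; given the surrounding WGAN training loop it presumably implements the standard WGAN-GP gradient penalty (Gulrajani et al., 2017). A minimal sketch in the same pre-0.4 Variable style, assuming `real_data` and `fake_data` are same-shaped float tensors of shape (batch, claim_length, vocab) and that `netD` returns a (features, score) pair as it does in these examples.

import torch
import torch.autograd as autograd

def calc_gradient_penalty(batch_size, lam, netD, real_data, fake_data):
    # Random interpolation between each real / generated claim.
    alpha = torch.rand(batch_size, 1, 1).expand_as(real_data)
    if real_data.is_cuda:
        alpha = alpha.cuda(real_data.get_device())
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    # The penalty is taken on the discriminator's scalar score, not the feature vector.
    _, disc_interpolates = netD(interpolates)

    grad_outputs = torch.ones(disc_interpolates.size())
    if real_data.is_cuda:
        grad_outputs = grad_outputs.cuda(real_data.get_device())

    gradients = autograd.grad(outputs=disc_interpolates,
                              inputs=interpolates,
                              grad_outputs=grad_outputs,
                              create_graph=True,
                              retain_graph=True,
                              only_inputs=True)[0]

    # Push the gradient norm of D towards 1 along the interpolation lines.
    gradients = gradients.view(gradients.size(0), -1)
    return lam * ((gradients.norm(2, dim=1) - 1) ** 2).mean()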