Code example #1
File: lstmCore.py Project: natnij/seqGAN_pytorch
def sanityCheck_LSTMCore(batch_size=1):
    ''' test pretrain_LSTMCore function '''
    log = openLog('test.txt')
    log.write('\n\nTest lstmCore.sanityCheck_LSTMCore: {}\n'.format(datetime.now())) 
    log.close()
    x, _, reverse_vocab, _ = read_sampleFile()
    pretrain_result = pretrain_LSTMCore(train_x=x,batch_size=batch_size,vocab_size=len(reverse_vocab))
    model = pretrain_result[0]
    y_all_max, y_all_sample = test_genMaxSample(model,start_token=0,batch_size=batch_size)
    log = openLog('test.txt')
    gen_tokens_max = decode(y_all_max, reverse_vocab, log)
    gen_tokens_sample = decode(y_all_sample, reverse_vocab, log)
    log.close()
    return gen_tokens_max, gen_tokens_sample
Code example #2
def sanityCheck_rewards(batch_size=5):
    ''' test rewards generation '''
    log = openLog('test.txt')
    log.write('\n\nTest rollout.sanityCheck_rewards: {}'.format(
        datetime.now()))
    try:
        generator, _, y_output_all = sanityCheck_generator(
            batch_size=batch_size, sample_size=batch_size * 2)
        gen_output = y_output_all[-batch_size:, :]
        rollout = Rollout(generator=generator)
        rollout = nn.DataParallel(rollout)
        rollout.to(DEVICE)
        discriminator = train_discriminator(
            batch_size=batch_size,
            vocab_size=generator.pretrain_model.module.vocab_size)
        rewards = getReward(gen_output, rollout, discriminator)
        log.write('\n  rollout.sanityCheck_rewards SUCCESSFUL. {}\n'.format(
            datetime.now()))
        log.close()
        return rewards
    except:
        log.write(
            '\n  rollout.sanityCheck_rewards !!!!!! UNSUCCESSFUL !!!!!! {}\n'.
            format(datetime.now()))
        log.close()
        return None
Code example #3
def sanityCheck_rollout_updateParams():
    ''' test updateParams function '''
    generator, _, _ = sanityCheck_generator()
    rollout = Rollout(generator=generator)
    rollout.to(DEVICE)
    log = openLog('test.txt')
    log.write('\n\nTest rollout.sanityCheck_updateParams: {}\n'.format(
        datetime.now()))
    log.write('original rollout params:\n')
    param_r = [str(x) for x in list(rollout.lstm.parameters())[0][0].tolist()]
    log.write(' '.join(param_r))

    generator, _, _ = sanityCheck_generator(model=generator)
    log.write('\nnew generator params:\n')
    param_g = [
        str(x) for x in list(generator.pretrain_model.lstm.parameters())[0]
        [0].tolist()
    ]
    log.write(' '.join(param_g))

    rollout.update_params(generator)
    log.write('\nnew rollout params:\n')
    param_r = [str(x) for x in list(rollout.lstm.parameters())[0][0].tolist()]
    log.write(' '.join(param_r))
    log.close()
Code example #4
def test_genMaxSample(model, start_token=0, batch_size=1):
    ''' test lstmCore's generation function '''
    log = openLog('test.txt')
    log.write('\n\nTest lstmCore.test_genMaxSample: {}'.format(datetime.now()))
    with torch.no_grad():
        y = [start_token] * batch_size
        y_all_max = torch.Tensor(y).int().view(-1, 1)
        model.hidden = model.init_hidden(len(y))
        for i in range(SEQ_LENGTH - 1):
            x = torch.Tensor(y).view([-1, 1])
            y_pred = model(x, sentence_lengths=[1])
            y_pred = y_pred[:, :, 1:-1]
            y_pred = y_pred.squeeze(dim=1)
            # take the max
            y = torch.argmax(y_pred, dim=1).float().view(-1, 1)
            y_all_max = torch.cat([y_all_max, y.int()], dim=1)

        y = [start_token] * batch_size
        y_all_sample = torch.Tensor(y).int().view(-1, 1)
        model.hidden = model.init_hidden(len(y))
        for i in range(SEQ_LENGTH - 1):
            x = torch.Tensor(y).view([-1, 1])
            y_pred = model(x, sentence_lengths=[1])
            # random choice based on probability distribution.
            y_prob = F.softmax(model.tag_space, dim=2)
            shape = (y_prob.shape[0], y_prob.shape[1])
            y = y_prob.view(-1, y_prob.shape[-1]).multinomial(
                num_samples=1).float().view(shape)
            y_all_sample = torch.cat([y_all_sample, y.int()], dim=1)
    log.write('\n  lstmCore.test_genMaxSample SUCCESSFUL: {}\n'.format(
        datetime.now()))
    log.write('    y_all_max: \n' + str(y_all_max) + '\n')
    log.write('    y_all_sample: \n' + str(y_all_sample) + '\n')
    log.close()
    return y_all_max, y_all_sample
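
test_genMaxSample produces two sequences per run: a greedy one that takes the argmax of the predicted distribution at every step, and a stochastic one that samples from the softmax over the model's tag_space. The standalone sketch below (toy logits, not the repo's model) contrasts the two decoding rules.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 1, 10)                    # toy predictions: (batch, seq=1, vocab)
probs = F.softmax(logits, dim=2)

# greedy: pick the highest-probability token per row
greedy_tokens = torch.argmax(probs, dim=2)        # shape (3, 1)

# stochastic: draw one token per row from the distribution
flat = probs.view(-1, probs.shape[-1])            # shape (3, 10)
sampled_tokens = flat.multinomial(num_samples=1)  # shape (3, 1)

print(greedy_tokens.squeeze(1), sampled_tokens.squeeze(1))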
Code example #5
File: generator.py Project: ShuchiZhang/CIS
def sanityCheck_GeneratorLoss(pretrain_result=None, batch_size=5):
    '''test custom loss function '''
    x, _, reverse_vocab, _ = read_sampleFile()  # needed below even when pretrain_result is supplied
    if pretrain_result is None:
        pretrain_result = pretrain_LSTMCore(x, vocab_size=len(reverse_vocab))
    model = pretrain_result[0]
    y_pred_pretrain = pretrain_result[1].view(
        [-1, SEQ_LENGTH, len(reverse_vocab)])
    test_reward = y_pred_pretrain.sum(dim=2).data
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.SGD(params, lr=0.01)
    optimizer.zero_grad()

    log = openLog('test.txt')
    log.write('\n\nTest generator.sanityCheck_GeneratorLoss: {}\n'.format(
        datetime.now()))
    criterion = GeneratorLoss()
    g_loss = criterion(y_pred_pretrain[0:batch_size, :, :], x[0:batch_size, :],
                       test_reward[0:batch_size, :])
    g_loss.backward()
    optimizer.step()
    log.write('  generator.sanityCheck_GeneratorLoss SUCCESSFUL: ' +
              str(g_loss) + '\n')
    log.close()
    return g_loss
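
GeneratorLoss is defined elsewhere in the project; a common SeqGAN-style formulation (an assumption for illustration, not the repo's exact implementation) weights the negative log-probability of each generated token by its reward:

import torch
import torch.nn as nn

class SimpleGeneratorLoss(nn.Module):
    ''' hypothetical REINFORCE-style loss: -sum(log P(token) * reward) '''
    def forward(self, y_prob, target, reward):
        # y_prob: (batch, seq_len, vocab) token probabilities
        # target: (batch, seq_len) token ids that were generated
        # reward: (batch, seq_len) per-token rewards from the discriminator
        chosen = torch.gather(y_prob, 2, target.long().unsqueeze(2)).squeeze(2)
        return -(torch.log(chosen + 1e-8) * reward).sum()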
Code example #6
def main(batch_size=1):
    model = torch.load(PATH+'generator.pkl')
    reverse_vocab = torch.load(PATH+'reverse_vocab.pkl')

    num = model.generate(batch_size=batch_size)
    log = openLog('genTxt_predict.txt')
    result = decode(num, reverse_vocab, log)
    log.close()
    return result
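
Because torch.load here deserializes a fully pickled Generator object (code example #17 saves it with torch.save(generator, PATH + 'generator.pkl')), the Generator class and its module path must be importable at load time. A minimal round-trip sketch with a stand-in module and a hypothetical PATH:

import torch
import torch.nn as nn

PATH = './'                                   # assumption: project save directory
model = nn.Linear(4, 2)                       # stand-in for the trained generator
torch.save(model, PATH + 'generator.pkl')     # pickles the whole object, not just the weights
restored = torch.load(PATH + 'generator.pkl')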
Code example #7
def pretrain_LSTMCore(train_x=None,
                      sentence_lengths=None,
                      batch_size=1,
                      end_token=None,
                      vocab_size=10):
    if train_x is None:
        x = gen_record(vocab_size=vocab_size)
    else:
        x = train_x
    if len(x.shape) == 1:
        x = x.view(1, x.shape[0])
    if sentence_lengths is None:
        sentence_lengths = [x.shape[1]] * len(x)
    if len(sentence_lengths) < len(x):
        sentence_lengths.extend([x.shape[1]] *
                                (len(x) - len(sentence_lengths)))
    if end_token is None:
        end_token = vocab_size - 1

    model = LSTMCore(vocab_size)
    model.to(DEVICE)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    criterion = nn.NLLLoss()
    optimizer = torch.optim.SGD(params, lr=0.01)
    y_pred_all = []
    log = openLog()
    log.write('\n\ntraining lstmCore: {}\n'.format(datetime.now()))
    for epoch in range(GEN_NUM_EPOCH_PRETRAIN):
        pointer = 0
        y_pred_all = []
        epoch_loss = []
        while pointer + batch_size <= len(x):
            x_batch = x[pointer:pointer + batch_size]
            x0_length = sentence_lengths[pointer:pointer + batch_size]
            y = torch.cat(
                (x_batch[:, 1:], torch.Tensor(
                    [end_token] * x_batch.shape[0]).int().view(
                        x_batch.shape[0], 1)),
                dim=1)
            model.hidden = model.init_hidden(batch_size)
            y_pred = model(x_batch, x0_length)
            loss = criterion(y_pred.view(-1, y_pred.shape[-1]),
                             y.long().view(-1))
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            y_prob = F.softmax(model.tag_space, dim=2)
            y_pred_all.append(y_prob)
            epoch_loss.append(loss.item())
            pointer = pointer + batch_size
        log.write('epoch: ' + str(epoch) + ' loss: ' +
                  str(sum(epoch_loss) / len(epoch_loss)) + '\n')
    log.close()
    return model, torch.cat(y_pred_all)
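
The pretraining target y above is simply the input shifted left by one position with end_token appended, so the LSTM is trained to predict the next token of each training sequence (teacher forcing). A minimal illustration with toy values:

import torch

end_token = 9
x_batch = torch.tensor([[0, 4, 2, 7],
                        [0, 1, 5, 3]])
y = torch.cat((x_batch[:, 1:],
               torch.tensor([end_token] * x_batch.shape[0]).view(-1, 1)), dim=1)
# x_batch[0] = [0, 4, 2, 7]  ->  y[0] = [4, 2, 7, 9]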
Code example #8
File: generator.py Project: ShuchiZhang/CIS
def train_generator(model,
                    x,
                    reward,
                    iter_n_gen=None,
                    batch_size=1,
                    sentence_lengths=None):
    if len(x.shape) == 1:
        x = x.view(1, x.shape[0])
    rem = len(x) % batch_size
    if rem > 0:
        x = x[0:len(x) - rem]
    if sentence_lengths is None:
        sentence_lengths = [x.shape[1]] * len(x)
    if len(sentence_lengths) < len(x):
        sentence_lengths.extend([x.shape[1]] *
                                (len(x) - len(sentence_lengths)))
    sentence_lengths = torch.tensor(sentence_lengths, device=DEVICE).long()
    if reward is None:
        reward = torch.tensor([1.0] * x.shape[0] * x.shape[1],
                              device=DEVICE).view(x.shape)
    if iter_n_gen is None:
        iter_n_gen = GEN_NUM_EPOCH

    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.SGD(params, lr=0.01)
    log = openLog()
    log.write('    training generator: {}\n'.format(datetime.now()))
    for epoch in range(iter_n_gen):
        pointer = 0
        y_prob_all = []
        y_output_all = []
        epoch_loss = []
        while pointer + batch_size <= len(x):
            x_batch = x[pointer:pointer + batch_size]
            r_batch = reward[pointer:pointer + batch_size]
            s_length = sentence_lengths[pointer:pointer + batch_size]
            hidden = model.pretrain_model.module.init_hidden(batch_size)
            y_output, y_prob, loss_var = model(x=x_batch,
                                               hidden=hidden,
                                               rewards=r_batch,
                                               sentence_lengths=s_length)
            optimizer.zero_grad()
            loss_var.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            y_prob_all.append(y_prob)
            y_output_all.append(y_output)
            epoch_loss.append(loss_var.item())
            pointer = pointer + batch_size
        log.write('      epoch: ' + str(epoch) + ' loss: ' +
                  str(sum(epoch_loss) / len(epoch_loss)) + '\n')
    log.close()
    return (model, torch.cat(y_prob_all),
            torch.cat(y_output_all).view(list(x.shape)))
Code example #9
def sanityCheck_discriminator(batch_size=1, vocab_size=10):
    ''' test discriminator instantiation and pretraining'''
    log = openLog('test.txt')
    log.write('\n\nTest discriminator.sanityCheck_discriminator: {}\n'.format(
        datetime.now()))
    model = train_discriminator(vocab_size=vocab_size)
    with torch.no_grad():
        x = gen_record(num=batch_size, vocab_size=vocab_size)
        y_pred = model(x)
    log.write('  y_pred shape: ' + str(y_pred.shape) + '\n')
    log.close()
    return model, y_pred
Code example #10
File: lstmCore.py Project: natnij/seqGAN_pytorch
def pretrain_LSTMCore(train_x=None, sentence_lengths=None, batch_size=1, end_token=None, vocab_size=10):
    if train_x is None:
        x = gen_record(vocab_size=vocab_size)
    else:
        x = train_x
    if len(x.shape) == 1:
        x = x.view(1,x.shape[0])
    if sentence_lengths is None:
        sentence_lengths = [x.shape[1]] * len(x)
    if len(sentence_lengths) < len(x):
        sentence_lengths.extend([x.shape[1]] * (len(x)-len(sentence_lengths)))
    if end_token is None:
        end_token = vocab_size - 1
    
    model = LSTMCore(vocab_size)
    model = nn.DataParallel(model)#, device_ids=[0])
    model.to(DEVICE)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))       
    criterion = nn.NLLLoss()
    optimizer = torch.optim.SGD(params, lr=0.01)
    y_pred_all = []
    log = openLog()
    log.write('    training lstmCore: {}\n'.format(datetime.now()))
    for epoch in range(GEN_NUM_EPOCH_PRETRAIN):
        pointer = 0
        y_pred_all = []
        epoch_loss = []
        while pointer + batch_size <= len(x):
            x_batch = x[pointer:pointer+batch_size]
            x0_length = torch.tensor(sentence_lengths[pointer:pointer+batch_size]).to(device=DEVICE)
            y = torch.cat((x_batch[:,1:],
                           torch.tensor([end_token]*x_batch.shape[0],device=DEVICE)
                           .int().view(x_batch.shape[0],1)),dim=1)
            # hidden has to be passed to the model as a GPU tensor so it can be sliced correctly across multiple GPUs.
            # DataParallel splits inputs on dim=0 by default, so all inputs are sliced along the batch dimension.
            # the hidden tensors therefore need to be permuted back to batch-size-second inside the forward pass
            #   before they are fed into the lstm layer.
            # when using DataParallel, the wrapped module's attributes are accessed through .module
            hidden = model.module.init_hidden(batch_size)            
            y_pred, tag_space = model(x_batch, hidden, x0_length)
            loss = criterion(y_pred.view(-1,y_pred.shape[-1]), y.long().view(-1))
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            y_prob = F.softmax(tag_space, dim=2)
            y_pred_all.append(y_prob)
            epoch_loss.append(loss.item())
            pointer = pointer + batch_size
        log.write('      epoch: '+str(epoch)+' loss: '+str(sum(epoch_loss)/len(epoch_loss))+'\n')
    log.close()
    return model, torch.cat(y_pred_all)
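
As the comments above note, nn.DataParallel splits every positional input on dim 0, so the hidden state must be created batch-first and permuted back to the (num_layers, batch, hidden) layout inside the forward pass. A minimal sketch of that pattern (hypothetical dimensions, not the repo's LSTMCore):

import torch
import torch.nn as nn

class TinyLSTM(nn.Module):
    def __init__(self, vocab_size=10, emb=8, hid=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb)
        self.lstm = nn.LSTM(emb, hid, batch_first=True)
        self.hid = hid

    def init_hidden(self, batch_size):
        # batch-first so DataParallel can slice the hidden state on dim 0
        return (torch.zeros(batch_size, 1, self.hid),
                torch.zeros(batch_size, 1, self.hid))

    def forward(self, x, hidden):
        # permute back to (num_layers, batch, hidden) before the lstm call
        h0 = hidden[0].permute(1, 0, 2).contiguous()
        c0 = hidden[1].permute(1, 0, 2).contiguous()
        out, _ = self.lstm(self.embedding(x.long()), (h0, c0))
        return out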
Code example #11
def sanityCheck_generator(model=None):
    ''' test Generator instantiation and train_generator function '''
    log = openLog('test.txt')
    log.write('\n\nTest generator.sanityCheck_generator: {}\n'.format(datetime.now()))     
    x, vocabulary, reverse_vocab, _ = read_sampleFile()
    if model is None:
        pretrain_result = pretrain_LSTMCore(x,vocab_size=len(vocabulary))
        model = Generator(pretrain_model=pretrain_result[0])
        log.write('  generator instantiated: {}\n'.format(datetime.now()))  
    model.to(DEVICE)
    model, y_prob_all, y_output_all = train_generator(model, x, reward=None)
    log.write('  trained generator outputs:\n')
    log.write('    y_output_all shape: '+ str(y_output_all.shape) +'\n')
    log.write('    y_prob_all shape: '+ str(y_prob_all.shape) +'\n')
    log.close()
    return model, y_prob_all, y_output_all
Code example #12
def main(batch_size):
    if batch_size is None:
        batch_size = 1
    x, vocabulary, reverse_vocab, sentence_lengths = read_sampleFile()
    if batch_size > len(x):
        batch_size = len(x)
    start_token = vocabulary['START']
    end_token = vocabulary['END']
    pad_token = vocabulary['PAD']
    ignored_tokens = [start_token, end_token, pad_token]
    vocab_size = len(vocabulary)
    
    generator = pretrain_generator(x, start_token=start_token, 
                    end_token=end_token,ignored_tokens=ignored_tokens,
                    sentence_lengths=sentence_lengths,batch_size=batch_size,
                    vocab_size=vocab_size)
    x_gen = generator.generate(start_token=start_token, ignored_tokens=ignored_tokens, 
                               batch_size=len(x))
    discriminator = train_discriminator_wrapper(x, x_gen, batch_size, vocab_size)
    rollout = Rollout(generator, r_update_rate=0.8)
    rollout.to(DEVICE)
    for total_batch in range(TOTAL_BATCH):
        print('batch: {}'.format(total_batch))
        for it in range(1):
            samples = generator.generate(start_token=start_token, 
                    ignored_tokens=ignored_tokens, batch_size=batch_size)
            # Average the rewards over ROLLOUT_ITER rollouts.
            #   The more often the discriminator returns the positive
            #   (real-data) class, the higher the reward.
            rewards = getReward(samples, rollout, discriminator)
            (generator, y_prob_all, y_output_all) = train_generator(model=generator, x=samples, 
                    reward=rewards, iter_n_gen=1, batch_size=batch_size, sentence_lengths=sentence_lengths)
        
        rollout.update_params(generator)
        
        for iter_n_dis in range(DIS_NUM_EPOCH):
            print('iter_n_dis: {}'.format(iter_n_dis))
            x_gen = generator.generate(start_token=start_token, ignored_tokens=ignored_tokens, 
                               batch_size=len(x))
            discriminator = train_discriminator_wrapper(x, x_gen, batch_size,vocab_size)
    
    log = openLog('genTxt.txt')
    num = generator.generate(batch_size=batch_size)
    words_all = decode(num, reverse_vocab, log)
    log.close()
    print(words_all)
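
getReward is implemented elsewhere in the project; the usual SeqGAN scheme (this sketch is an assumption about its internals, not the repo's exact code) is a Monte Carlo estimate: for every prefix length the rollout model completes the sequence, the discriminator scores the completion, and the score assigned to the real-data class is averaged over ROLLOUT_ITER rollouts.

import torch

def monte_carlo_reward(samples, rollout, discriminator, rollout_iter=16):
    ''' hypothetical sketch of SeqGAN reward estimation '''
    batch_size, seq_len = samples.shape
    rewards = torch.zeros(batch_size, seq_len)
    for _ in range(rollout_iter):
        for given_num in range(1, seq_len + 1):
            completed = rollout(samples, given_num=given_num)  # finish the prefix
            scores = discriminator(completed)                  # (batch, nr.of.class)
            rewards[:, given_num - 1] += scores[:, 1]          # score of the real-data class
    return rewards / rollout_iter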
Code example #13
def train_discriminator(train_x=None,
                        train_y=None,
                        batch_size=1,
                        vocab_size=10):
    if train_x is None:
        x = gen_record(num=batch_size, vocab_size=vocab_size)
    else:
        x = train_x
    if train_y is None:
        y = gen_label()
    else:
        y = train_y

    model = Discriminator(filter_size=FILTER_SIZE,
                          num_filter=NUM_FILTER,
                          vocab_size=vocab_size)
    model = nn.DataParallel(model)
    model.to(DEVICE)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params, lr=0.01)

    log = openLog()
    log.write('    training discriminator: {}\n'.format(datetime.now()))
    for epoch in range(DIS_NUM_EPOCH_PRETRAIN):
        pointer = 0
        epoch_loss = []
        while pointer + batch_size <= len(x):
            x_batch = x[pointer:pointer + batch_size]
            y_batch = y[pointer:pointer + batch_size]
            # y_pred dim: (batch_size, nr.of.class)
            y_pred = model(x_batch)
            loss = criterion(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pointer = pointer + batch_size
            epoch_loss.append(loss.item())
        log.write('      epoch: ' + str(epoch) + ' loss: ' +
                  str(sum(epoch_loss) / len(epoch_loss)) + '\n')
    log.close()
    return model
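
nn.CrossEntropyLoss expects unnormalized scores of shape (batch, nr.of.class) and integer class labels of shape (batch,); it applies log-softmax internally. A toy example of the shapes involved:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
y_pred = torch.tensor([[2.0, -1.0],   # scores for two classes, batch of 2
                       [0.3,  0.8]])
y_batch = torch.tensor([0, 1])        # class indices
loss = criterion(y_pred, y_batch)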
Code example #14
def sanityCheck_rollout_updateParams(batch_size=1):
    ''' test updateParams function '''
    generator, _, _ = sanityCheck_generator(batch_size=batch_size,
                                            sample_size=batch_size * 2)
    rollout = Rollout(generator=generator)
    rollout = nn.DataParallel(rollout)
    rollout.to(DEVICE)
    log = openLog('test.txt')
    log.write('\n\nTest rollout.sanityCheck_updateParams: {}\n'.format(
        datetime.now()))
    try:
        log.write('original rollout params:\n')
        param_r = [
            str(x)
            for x in list(rollout.module.lstm.parameters())[0][0].tolist()[0:3]
        ]
        log.write(' '.join(param_r))

        generator, _, _ = sanityCheck_generator(model=generator)
        log.write('\nnew generator params:\n')
        param_g = [
            str(x)
            for x in list(generator.pretrain_model.module.lstm.parameters())[0]
            [0].tolist()[0:3]
        ]
        log.write(' '.join(param_g))

        rollout.module.update_params(generator)
        log.write('\nnew rollout params:\n')
        param_r = [
            str(x)
            for x in list(rollout.module.lstm.parameters())[0][0].tolist()[0:3]
        ]
        log.write(' '.join(param_r))
        log.write(
            '\n  rollout.sanityCheck_updateParams SUCCESSFUL. {}\n'.format(
                datetime.now()))
    except:
        log.write(
            '\n  rollout.sanityCheck_updateParams !!!!!! UNSUCCESSFUL !!!!!! {}\n'
            .format(datetime.now()))
    log.close()
Code example #15
def sanityCheck_rollout(batch_size=5):
    ''' test Rollout instantiation '''
    log = openLog('test.txt')
    log.write('\n\nTest rollout.sanityCheck_rollout: {}'.format(
        datetime.now()))
    x, _, reverse_vocab, _ = read_sampleFile()
    x0 = x[0:batch_size]
    try:
        model = Rollout(vocab_size=len(reverse_vocab))
        model.to(DEVICE)
        model.hidden = model.init_hidden(len(x0))
        model(x0, given_num=3)
        log.write('\n  rollout.sanityCheck_rollout SUCCESSFUL: {}\n'.format(
            datetime.now()))
        log.close()
        return model
    except:
        log.write('\n  rollout.sanityCheck_rollout UNSUCCESSFUL: {}\n'.format(
            datetime.now()))
        log.close()
        return None
Code example #16
def train_generator(model, x, reward, iter_n_gen=None, batch_size=1, sentence_lengths=None):
    if len(x.shape) == 1:
        x = x.view(1,x.shape[0])
    if sentence_lengths is None:
        sentence_lengths = [x.shape[1]] * len(x)
    if len(sentence_lengths) < len(x):
        sentence_lengths.extend([x.shape[1]] 
                                * (len(x)-len(sentence_lengths)))
    if reward is None:
        reward = torch.Tensor([1.0] * x.shape[0] * x.shape[1]).view(x.shape)
    if iter_n_gen is None:
        iter_n_gen = GEN_NUM_EPOCH
        
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.SGD(params, lr=0.01)
    log = openLog()
    log.write('\n\ntraining generator: {}\n'.format(datetime.now()))
    for epoch in range(iter_n_gen):
        pointer = 0
        y_prob_all = []
        y_output_all = []
        epoch_loss = []
        while pointer + batch_size <= len(x):
            x_batch = x[pointer:pointer+batch_size]
            r_batch = reward[pointer:pointer+batch_size]
            s_length = sentence_lengths[pointer:pointer+batch_size]
            model.pretrain_model.hidden = model.pretrain_model.init_hidden(batch_size)
            y_output = model(x_batch, r_batch, s_length)
            y_prob = model.y_prob
            loss_var = model.loss_variable
            optimizer.zero_grad()
            loss_var.backward()
            optimizer.step()
            y_prob_all.append(y_prob)
            y_output_all.append(y_output)  
            epoch_loss.append(loss_var.item())
            pointer = pointer + batch_size
        log.write('epoch: '+str(epoch)+' loss: '+str(sum(epoch_loss)/len(epoch_loss))+'\n')
    log.close()
    return ( model, torch.cat(y_prob_all), torch.cat(y_output_all).view(list(x.shape)) )
Code example #17
def main(batch_size, num=None):
    if batch_size is None:
        batch_size = 1
    x, vocabulary, reverse_vocab, sentence_lengths = read_sampleFile(num=num)
    if batch_size > len(x):
        batch_size = len(x)
    start_token = vocabulary['START']
    end_token = vocabulary['END']
    pad_token = vocabulary['PAD']
    ignored_tokens = [start_token, end_token, pad_token]
    vocab_size = len(vocabulary)

    log = openLog()
    log.write("###### start to pretrain generator: {}\n".format(
        datetime.now()))
    log.close()
    generator = pretrain_generator(x,
                                   start_token=start_token,
                                   end_token=end_token,
                                   ignored_tokens=ignored_tokens,
                                   sentence_lengths=torch.tensor(
                                       sentence_lengths, device=DEVICE).long(),
                                   batch_size=batch_size,
                                   vocab_size=vocab_size)
    x_gen = generator.generate(start_token=start_token,
                               ignored_tokens=ignored_tokens,
                               batch_size=len(x))
    log = openLog()
    log.write("###### start to pretrain discriminator: {}\n".format(
        datetime.now()))
    log.close()
    discriminator = train_discriminator_wrapper(x, x_gen, batch_size,
                                                vocab_size)
    rollout = Rollout(generator, r_update_rate=0.8)
    rollout = torch.nn.DataParallel(rollout)  #, device_ids=[0])
    rollout.to(DEVICE)

    log = openLog()
    log.write("###### start to train adversarial net: {}\n".format(
        datetime.now()))
    log.close()
    for total_batch in range(TOTAL_BATCH):
        log = openLog()
        log.write('batch: {} : {}\n'.format(total_batch, datetime.now()))
        print('batch: {} : {}\n'.format(total_batch, datetime.now()))
        log.close()
        for it in range(1):
            samples = generator.generate(start_token=start_token,
                                         ignored_tokens=ignored_tokens,
                                         batch_size=batch_size)
            # Average the rewards over ROLLOUT_ITER rollouts.
            #   The more often the discriminator returns the positive
            #   (real-data) class, the higher the reward.
            rewards = getReward(samples, rollout, discriminator)
            (generator, y_prob_all,
             y_output_all) = train_generator(model=generator,
                                             x=samples,
                                             reward=rewards,
                                             iter_n_gen=1,
                                             batch_size=batch_size,
                                             sentence_lengths=sentence_lengths)

        rollout.module.update_params(generator)

        for iter_n_dis in range(DIS_NUM_EPOCH):
            log = openLog()
            log.write('  iter_n_dis: {} : {}\n'.format(iter_n_dis,
                                                       datetime.now()))
            log.close()
            x_gen = generator.generate(start_token=start_token,
                                       ignored_tokens=ignored_tokens,
                                       batch_size=len(x))
            discriminator = train_discriminator_wrapper(
                x, x_gen, batch_size, vocab_size)

    log = openLog()
    log.write('###### training done: {}\n'.format(datetime.now()))
    log.close()

    torch.save(reverse_vocab, PATH + 'reverse_vocab.pkl')
    try:
        torch.save(generator, PATH + 'generator.pkl')
        print('successfully saved generator model.')
    except:
        print('error: model saving failed!!!!!!')

    log = openLog('genTxt.txt')
    num = generator.generate(batch_size=batch_size)
    words_all = decode(num, reverse_vocab, log)
    log.close()
    print(words_all)