예제 #1
0
def train():
    """Train an ``Adversarial_Discriminator`` alongside a restored seq2seq model.

    Steps, in order:

    1. Load the checkpoint returned by
       ``load_latest_state_dict(savepath=SAVE_PATH_SEQ2SEQ)`` and restore the
       vocabulary from its ``'voc_dict'`` entry.
    2. Build the discriminator, its Adam optimizer and a ``BCELoss`` criterion.
    3. Build the encoder/decoder, load their weights from the checkpoint's
       ``'en'`` / ``'de'`` entries and wrap them in ``RLGreedySearchDecoder``.
    4. For each epoch, call the test routine first and then the training
       routine, and write a checkpoint to ``save_dir``.

    Relies on the module-level names ``device``, ``learning_rate`` and
    ``batch_size``. Returns ``None``.
    """
    N_EPOCHS = 5
    SAVE_EVERY = 1  # checkpoint interval, in epochs
    output_size = 1
    save_dir = 'data/save/Adversarial_Discriminator/'

    # torch.save does not create missing parent directories; make sure the
    # target exists before any training time is spent, so a fresh checkout
    # does not lose a full epoch to a failed checkpoint write.
    os.makedirs(save_dir, exist_ok=True)

    attn_model = 'dot'
    hidden_size = 500
    encoder_n_layers = 2
    decoder_n_layers = 2
    dropout = 0.1

    seq2seqModel = load_latest_state_dict(savepath=SAVE_PATH_SEQ2SEQ)
    voc = Voc('name')
    # The placeholder Voc's attributes are replaced wholesale by the saved ones.
    voc.__dict__ = seq2seqModel['voc_dict']

    # The same embedding instance is handed to the discriminator, the encoder
    # and the decoder below.
    # NOTE(review): if those modules register it as a sub-module, loading the
    # seq2seq weights also overwrites this embedding, and the discriminator's
    # optimizer will update it -- confirm this sharing is intended.
    embedding = nn.Embedding(voc.num_words, hidden_size)
    model = Adversarial_Discriminator(hidden_size, output_size, embedding)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()

    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  voc.num_words, decoder_n_layers, dropout)

    encoder.load_state_dict(seq2seqModel['en'])
    decoder.load_state_dict(seq2seqModel['de'])
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    searcher = RLGreedySearchDecoder(encoder, decoder, voc)

    train_data = AlexaDataset('train.json',
                              rare_word_threshold=3)  # sorry cornell
    train_data.trimPairsToVocab(voc)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    test_data = AlexaDataset('test_freq.json', rare_word_threshold=3)
    test_data.trimPairsToVocab(voc)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

    for epoch in range(1, N_EPOCHS + 1):
        # The test call runs before the training pass, so the call made in
        # epoch N sees the model after N - 1 training passes.
        test_AdversarialDiscriminatorOnLatestSeq2Seq(model, searcher,
                                                     test_loader, voc)
        loss = trainAdversarialDiscriminatorOnLatestSeq2Seq(
            model, searcher, voc, train_loader, criterion, optimizer,
            embedding, save_dir, epoch)

        if epoch % SAVE_EVERY == 0:
            checkpoint = {
                'iteration': epoch,
                'model': model.state_dict(),
                'opt': optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict(),
            }
            checkpoint_path = os.path.join(
                save_dir, '{}_{}.tar'.format(epoch, 'epochs'))
            torch.save(checkpoint, checkpoint_path)
예제 #2
0
def loadAdversarial_Discriminator(hidden_size=hidden_size,
                                  output_size=1,
                                  n_layers=1,
                                  dropout=0,
                                  path=SAVE_PATH_DISCRIMINATOR):
    """Rebuild an ``Adversarial_Discriminator`` from a saved checkpoint.

    The checkpoint is whatever ``load_latest_state_dict(path)`` returns; its
    ``'voc_dict'``, ``'embedding'`` and ``'model'`` entries are read.

    Args:
        hidden_size: Embedding width and first model constructor argument.
            The default is the module-level ``hidden_size`` captured when
            this function was defined.
        output_size: Second model constructor argument.
        n_layers: Forwarded to the model constructor.
        dropout: Forwarded to the model constructor.
        path: Location handed to ``load_latest_state_dict``.

    Returns:
        The model, moved to the module-level ``device`` with its saved weights
        loaded. ``eval()`` is not called here.
    """
    checkpoint = load_latest_state_dict(path)

    # Restore the vocabulary by swapping in the saved attribute dict.
    vocabulary = Voc('placeholder_name')
    vocabulary.__dict__ = checkpoint['voc_dict']

    print('Building Adversarial_Discriminator model ...')

    # The embedding is restored from its own checkpoint entry before the
    # model is constructed around it.
    word_embedding = nn.Embedding(vocabulary.num_words, hidden_size)
    word_embedding.load_state_dict(checkpoint['embedding'])
    word_embedding.to(device)

    discriminator = Adversarial_Discriminator(
        hidden_size, output_size, word_embedding, n_layers, dropout)
    discriminator = discriminator.to(device)
    discriminator.load_state_dict(checkpoint['model'])
    return discriminator
예제 #3
0
def loadADEM(hidden_size=hidden_size,
             output_size=5,
             n_layers=1,
             dropout=0,
             path=SAVE_PATH_ADEM):
    """Rebuild an ``ADEM`` model from a saved checkpoint.

    The checkpoint is whatever ``load_latest_state_dict(path)`` returns; its
    ``'voc_dict'``, ``'embedding'`` and ``'model'`` entries are read.

    Args:
        hidden_size: Embedding width and first model constructor argument.
            The default is the module-level ``hidden_size`` captured when
            this function was defined.
        output_size: Second model constructor argument (defaults to 5).
        n_layers: Forwarded to the model constructor.
        dropout: Forwarded to the model constructor.
        path: Location handed to ``load_latest_state_dict``.

    Returns:
        The model, moved to the module-level ``device`` with its saved weights
        loaded. ``eval()`` is not called here.
    """
    checkpoint = load_latest_state_dict(path)

    # Restore the vocabulary by swapping in the saved attribute dict.
    vocabulary = Voc('placeholder_name')
    vocabulary.__dict__ = checkpoint['voc_dict']

    print('Building ADEM model ...')

    # The embedding is restored from its own checkpoint entry before the
    # model is constructed around it.
    word_embedding = nn.Embedding(vocabulary.num_words, hidden_size)
    word_embedding.load_state_dict(checkpoint['embedding'])
    word_embedding.to(device)

    adem = ADEM(hidden_size, output_size, word_embedding, n_layers, dropout)
    adem = adem.to(device)
    adem.load_state_dict(checkpoint['model'])
    return adem