class TestTrainerQRNN(TestCase):
    """Smoke test: train a QRNN seq2seq model end-to-end on a tiny fixture."""

    def test_train_method(self):
        file_name = 'test/test_data/attention_test.txt'
        fine_tune_model_name = '../models/glove_model_40.pth'
        self.test_data_loader_attention = DataLoaderAttention(
            file_name=file_name)
        # Fix: load_data() was previously called twice (the first result was
        # discarded); a single call provides the vocabularies and train data.
        source2index, index2source, target2index, index2target, train_data = \
            self.test_data_loader_attention.load_data()

        HIDDEN_SIZE = 512
        NUM_LAYERS = 2
        KERNEL_SIZE = 2
        EMBEDDING_SIZE = 50
        SOURCE_VOCAB_SIZE = len(source2index)
        TARGET_VOCAB_SIZE = len(target2index)

        qrnn = QRNNModel(QRNNLayer, NUM_LAYERS, KERNEL_SIZE, HIDDEN_SIZE,
                         EMBEDDING_SIZE, SOURCE_VOCAB_SIZE, TARGET_VOCAB_SIZE)

        self.trainer = Trainer(epoch=100,
                               fine_tune_model=fine_tune_model_name)
        self.trainer.train_qrnn(
            train_data=train_data,
            source2index=source2index,
            target2index=target2index,
            index2source=index2source,
            index2target=index2target,
            qrnn_model=qrnn,
        )
class TestTrainerAttention(TestCase):
    """Smoke test: train the attention encoder/decoder on a tiny fixture."""

    def test_train_method(self):
        file_name = 'test/test_data/attention_test.txt'
        fine_tune_model_name = '../models/glove_model_40.pth'
        self.test_data_loader_attention = DataLoaderAttention(
            file_name=file_name)
        # Fix: load_data() was previously called twice (the first result was
        # discarded); a single call provides the vocabularies and train data.
        source2index, index2source, target2index, index2target, train_data = \
            self.test_data_loader_attention.load_data()

        EMBEDDING_SIZE = 50
        HIDDEN_SIZE = 32
        # 3 layers, bidirectional encoder; decoder hidden is doubled to match
        # the bidirectional encoder output.
        encoder = Encoder(len(source2index), EMBEDDING_SIZE, HIDDEN_SIZE,
                          3, True)
        decoder = Decoder(len(target2index), EMBEDDING_SIZE, HIDDEN_SIZE * 2)

        self.trainer = Trainer(fine_tune_model=fine_tune_model_name)
        self.trainer.train_attention(
            train_data=train_data,
            source2index=source2index,
            target2index=target2index,
            index2source=index2source,
            index2target=index2target,
            encoder_model=encoder,
            decoder_model=decoder,
        )
def main():
    """CLI entry point: parse arguments and train the attention model."""
    parser = argparse.ArgumentParser(description="Training attention model")
    parser.add_argument(
        "-t", "--train_data",
        metavar="train_data",
        type=str,
        default='../data/processed/source_replay_twitter_data.txt',
        dest="train_data",
        help="set the training data ")
    parser.add_argument("-e", "--embedding_size",
                        metavar="embedding_size",
                        type=int,
                        default=50,
                        dest="embedding_size",
                        help="set the embedding size ")
    parser.add_argument("-H", "--hidden_size",
                        metavar="hidden_size",
                        type=int,
                        default=512,
                        dest="hidden_size",
                        help="set the hidden size ")
    parser.add_argument("-f", "--fine_tune_model_name",
                        metavar="fine_tune_model_name",
                        type=str,
                        default='../models/glove_wiki/glove_model_40.pth',
                        dest="fine_tune_model_name",
                        help="set the fine tune model name ")
    args = parser.parse_args()

    data_loader_attention = DataLoaderAttention(file_name=args.train_data)
    # Fix: load_data() was previously called twice (the first result was
    # discarded); a single call provides the vocabularies and train data.
    source2index, index2source, target2index, index2target, train_data = \
        data_loader_attention.load_data()

    EMBEDDING_SIZE = args.embedding_size
    HIDDEN_SIZE = args.hidden_size
    # 3 layers, bidirectional encoder; decoder hidden is doubled to match
    # the bidirectional encoder output.
    encoder = Encoder(len(source2index), EMBEDDING_SIZE, HIDDEN_SIZE, 3, True)
    decoder = Decoder(len(target2index), EMBEDDING_SIZE, HIDDEN_SIZE * 2)

    trainer = Trainer(epoch=600, batch_size=64,
                      fine_tune_model=args.fine_tune_model_name)
    trainer.train_attention(train_data=train_data,
                            source2index=source2index,
                            target2index=target2index,
                            index2source=index2source,
                            index2target=index2target,
                            encoder_model=encoder,
                            decoder_model=decoder)
def __init__(self):
    """Set up the Twitter bot: read the model config, rebuild the data
    vocabularies, and load the trained QRNN encoder/decoder weights.

    :return: None
    """
    self.Twitter = namedtuple("Twitter",
                              ["CK", "CS", "AT", "AS", "name", "mecab"])
    Model = namedtuple("Model", [
        "train_data_name", "encoder_model_name", "decoder_model_name",
        "proj_linear_model_name"
    ])
    # NOTE: "enviroment" (sic) matches the on-disk file names; do not "fix".
    self.config_file = "twitter/conf/enviroment_twitter.yml"
    self.config_model_file = "twitter/conf/enviroment_model.yml"
    self.auth = ""
    self.mecab_dict = ""
    self.parameter_dict = {}

    with open(self.config_model_file, encoding="utf-8") as cf:
        # Fix: yaml.load() without an explicit Loader can construct arbitrary
        # Python objects and is deprecated; safe_load is sufficient for this
        # plain mapping-style config file.
        e = yaml.safe_load(cf)
        model = Model(
            e["model"]["train_data_name"],
            e["model"]["encoder_model_name"],
            e["model"]["decoder_model_name"],
            e["model"]["proj_linear_model_name"],
        )
        train_data_name = model.train_data_name
        encoder_model_name = model.encoder_model_name
        decoder_model_name = model.decoder_model_name
        proj_linear_model_name = model.proj_linear_model_name

        # Rebuild the same vocabularies the model was trained with.
        data_loader_attention = DataLoaderAttention(
            file_name=train_data_name)
        source2index, index2source, target2index, index2target, train_data = \
            data_loader_attention.load_data()
        self.source2index = source2index
        self.target2index = target2index
        self.index2source = index2source
        self.index2target = index2target

        # Hyper-parameters must match the checkpoint being loaded.
        HIDDEN_SIZE = 512
        NUM_LAYERS = 2
        KERNEL_SIZE = 2
        EMBEDDING_SIZE = 50
        SOURCE_VOCAB_SIZE = len(source2index)
        TARGET_VOCAB_SIZE = len(target2index)

        self.qrnn = QRNNModel(QRNNLayer, NUM_LAYERS, KERNEL_SIZE,
                              HIDDEN_SIZE, EMBEDDING_SIZE,
                              SOURCE_VOCAB_SIZE, TARGET_VOCAB_SIZE)
        if USE_CUDA:
            self.qrnn.encoder = torch.load(encoder_model_name)
            self.qrnn.decoder = torch.load(decoder_model_name)
            self.qrnn.proj_linear = torch.load(proj_linear_model_name)
        else:
            # map_location remaps CUDA-saved tensors onto the CPU.
            self.qrnn.encoder = torch.load(
                encoder_model_name,
                map_location=lambda storage, loc: storage)
            self.qrnn.decoder = torch.load(
                decoder_model_name,
                map_location=lambda storage, loc: storage)
            self.qrnn.proj_linear = torch.load(
                proj_linear_model_name,
                map_location=lambda storage, loc: storage)
def main():
    """CLI entry point: run one random sample through the attention model
    and visualize the attention weights."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t", "--train_data",
        metavar="train_data",
        type=str,
        default='../data/processed/source_replay_twitter_data.txt',
        dest="train_data",
        help="set the training data ")
    args = parser.parse_args()

    test_data_loader_attention = DataLoaderAttention(
        file_name=args.train_data)
    # Fix: load_data() was previously called twice (the first result was
    # discarded); a single call provides the vocabularies and train data.
    source2index, index2source, target2index, index2target, train_data = \
        test_data_loader_attention.load_data()

    encoder_model_name = '../models/encoder_model_299.pth'
    decoder_model_name = '../models/decoder_model_299.pth'
    attention_visualize = AttentionVisualize(
        encoder_model_name=encoder_model_name,
        decoder_model_name=decoder_model_name)

    # Pick one random (source, target) pair and decode it.
    test = random.choice(train_data)
    inputs = test[0]
    truth = test[1]

    output, hidden = attention_visualize.encoder_model(inputs,
                                                       [inputs.size(1)])
    pred, atten = attention_visualize.decoder_model.decode(
        hidden, output, target2index, index2target)

    inputs = [index2source[i] for i in inputs.data.tolist()[0]]
    pred = [index2target[i] for i in pred.data.tolist()]

    print('Source : ', ' '.join([i for i in inputs if i not in ['</s>']]))
    # Indices 2 and 3 are the <s> / </s> markers; strip them from the truth.
    print(
        'Truth : ',
        ' '.join([
            index2target[i] for i in truth.data.tolist()[0]
            if i not in [2, 3]
        ]))
    print('Prediction : ',
          ' '.join([i for i in pred if i not in ['</s>']]))

    if USE_CUDA:
        atten = atten.cpu()
    attention_visualize.visualize(inputs, pred, atten.data)
def main():
    """CLI entry point: load a trained QRNN checkpoint and print the model's
    prediction for one random training pair."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t", "--train_data",
        metavar="train_data",
        type=str,
        default='../data/processed/source_replay_twitter_data.txt',
        dest="train_data",
        help="set the training data ")
    parser.add_argument("-e", "--embedding_size",
                        metavar="embedding_size",
                        type=int,
                        default=50,
                        dest="embedding_size",
                        help="set the embedding size ")
    parser.add_argument("-H", "--hidden_size",
                        metavar="hidden_size",
                        type=int,
                        default=512,
                        dest="hidden_size",
                        help="set the hidden size ")
    parser.add_argument("-f", "--fine_tune_model_name",
                        metavar="fine_tune_model_name",
                        type=str,
                        default='../models/glove_model_40.pth',
                        dest="fine_tune_model_name",
                        help="set the fine tune model name ")
    parser.add_argument("-n", "--num_layers",
                        metavar="num_layers",
                        type=int,
                        default=2,
                        dest="num_layers",
                        help="set the layer number")
    parser.add_argument("-k", "--kernel_size",
                        metavar="kernel_size",
                        type=int,
                        default=2,
                        dest="kernel_size",
                        help="set the kernel_size")
    # Fix: removed unused local `batch_size = 64` (assigned, never read).
    args = parser.parse_args()

    test_data_loader_attention = DataLoaderAttention(
        file_name=args.train_data)
    source2index, index2source, target2index, index2target, train_data = \
        test_data_loader_attention.load_data()

    encoder_model_name = '../models/qrnn_encoder_model_285.pth'
    decoder_model_name = '../models/qrnn_decoder_model_285.pth'
    proj_linear_model_name = '../models/qrnn_proj_linear_model_285.pth'

    # Hyper-parameters must match the checkpoint being loaded.
    HIDDEN_SIZE = args.hidden_size
    NUM_LAYERS = args.num_layers
    KERNEL_SIZE = args.kernel_size
    EMBEDDING_SIZE = args.embedding_size
    SOURCE_VOCAB_SIZE = len(source2index)
    TARGET_VOCAB_SIZE = len(target2index)
    # Inference mode: no zoneout, no dropout.
    ZONE_OUT = 0.0
    TRAINING = False
    DROPOUT = 0.0

    qrnn = QRNNModel(QRNNLayer, NUM_LAYERS, KERNEL_SIZE, HIDDEN_SIZE,
                     EMBEDDING_SIZE, SOURCE_VOCAB_SIZE, TARGET_VOCAB_SIZE,
                     ZONE_OUT, TRAINING, DROPOUT)
    qrnn.encoder = torch.load(encoder_model_name)
    qrnn.decoder = torch.load(decoder_model_name)
    qrnn.proj_linear = torch.load(proj_linear_model_name)

    # Pick one random (source, target) pair and decode it.
    test = random.choice(train_data)
    inputs = test[0]
    truth = test[1]
    print(inputs)
    print(truth)

    # Seed the decoder with a row of <s> tokens, one per target position.
    start_decode = Variable(
        LongTensor([[target2index['<s>']] * truth.size(1)]))
    show_preds = qrnn(inputs, [inputs.size(1)], start_decode)
    outputs = torch.max(show_preds, dim=1)[1].view(len(inputs), -1)

    show_sentence(truth, inputs, outputs.data.tolist(), index2source,
                  index2target)
def main():
    """CLI entry point: parse arguments and train the QRNN seq2seq model."""
    parser = argparse.ArgumentParser(description="Training attention model")
    parser.add_argument(
        "-t", "--train_data",
        metavar="train_data",
        type=str,
        default='../data/processed/source_replay_twitter_data_update.txt',
        dest="train_data",
        help="set the training data ")
    parser.add_argument("-e", "--embedding_size",
                        metavar="embedding_size",
                        type=int,
                        default=500,
                        dest="embedding_size",
                        help="set the embedding size ")
    parser.add_argument("-H", "--hidden_size",
                        metavar="hidden_size",
                        type=int,
                        default=1024,
                        dest="hidden_size",
                        help="set the hidden size ")
    parser.add_argument("-f", "--fine_tune_model_name",
                        metavar="fine_tune_model_name",
                        type=str,
                        default='../models/glove_model_40.pth',
                        dest="fine_tune_model_name",
                        help="set the fine tune model name ")
    parser.add_argument("-n", "--num_layers",
                        metavar="num_layers",
                        type=int,
                        default=2,
                        dest="num_layers",
                        help="set the layer number")
    parser.add_argument("-k", "--kernel_size",
                        metavar="kernel_size",
                        type=int,
                        default=2,
                        dest="kernel_size",
                        help="set the kernel_size")
    parser.add_argument("-b", "--batch_size",
                        metavar="batch_size",
                        type=int,
                        default=64,
                        dest="batch_size",
                        help="set the batch_size")
    args = parser.parse_args()

    data_loader_attention = DataLoaderAttention(file_name=args.train_data)
    # Fix: load_data() was previously called twice (the first result was
    # discarded); a single call provides the vocabularies and train data.
    source2index, index2source, target2index, index2target, train_data = \
        data_loader_attention.load_data()

    HIDDEN_SIZE = args.hidden_size
    NUM_LAYERS = args.num_layers
    KERNEL_SIZE = args.kernel_size
    EMBEDDING_SIZE = args.embedding_size
    SOURCE_VOCAB_SIZE = len(source2index)
    TARGET_VOCAB_SIZE = len(target2index)
    # Training mode: zoneout/dropout disabled here but TRAINING=True so the
    # model applies its training-time code path.
    ZONE_OUT = 0.0
    TRAINING = True
    DROPOUT = 0.0

    qrnn = QRNNModel(QRNNLayer, NUM_LAYERS, KERNEL_SIZE, HIDDEN_SIZE,
                     EMBEDDING_SIZE, SOURCE_VOCAB_SIZE, TARGET_VOCAB_SIZE,
                     ZONE_OUT, TRAINING, DROPOUT)

    trainer = Trainer(epoch=300, batch_size=args.batch_size,
                      fine_tune_model=args.fine_tune_model_name)
    trainer.train_qrnn(
        train_data=train_data,
        source2index=source2index,
        target2index=target2index,
        index2source=index2source,
        index2target=index2target,
        qrnn_model=qrnn,
    )
class TestDataLoaderAttention(TestCase):
    """Check that DataLoaderAttention builds the expected vocabularies and
    serialized training data from the fixture file."""

    def test_load_data(self):
        # Expected vocabularies for test/test_data/attention_test.txt.
        test_source2index = {
            '!': 4,
            '.': 5,
            '</s>': 3,
            '<PAD>': 0,
            '<UNK>': 1,
            '<s>': 2,
            'co/otsehnz6dk': 6,
            'https://t': 7,
            '歯': 8,
            '磨けよ': 9
        }
        test_index2source = {
            4: '!',
            5: '.',
            3: '</s>',
            0: '<PAD>',
            1: '<UNK>',
            2: '<s>',
            6: 'co/otsehnz6dk',
            7: 'https://t',
            8: '歯',
            9: '磨けよ'
        }
        test_target2index = {
            '<PAD>': 0,
            '<UNK>': 1,
            '<s>': 2,
            '</s>': 3,
            '.': 4,
            '?': 5,
            'co/7jnltbaas': 6,
            'https://t': 7,
            'は': 8
        }
        test_index2target = {
            0: '<PAD>',
            1: '<UNK>',
            2: '<s>',
            3: '</s>',
            4: '.',
            5: '?',
            6: 'co/7jnltbaas',
            7: 'https://t',
            8: 'は'
        }

        file_name = 'test/test_data/attention_test.txt'
        self.test_data_loader_attention = DataLoaderAttention(
            file_name=file_name)
        # Fix: load_data() was previously called twice (the first result was
        # discarded); a single call provides the vocabularies and train data.
        source2index, index2source, target2index, index2target, train_data = \
            self.test_data_loader_attention.load_data()

        assert test_source2index == source2index
        assert test_index2source == index2source
        assert test_target2index == target2index
        assert test_index2target == index2target

        # Serialize the train data and compare byte-for-byte with the
        # reference pickle.
        APP_PATH = os.path.dirname(__file__)
        output_file = APP_PATH + '/test_data/train_data_attention.pkl'
        compare_output_file = \
            APP_PATH + '/test_data/test_train_data_attention.pkl'
        with open(output_file, 'wb') as handle:
            pickle.dump(train_data, handle)
        assert True is filecmp.cmp(output_file, compare_output_file)