def setUp(self):
    """Build the eng-fra fixture: vocabularies, a bucketed iterator and a
    small LSTM seq2seq model with uniformly initialised weights."""
    fixture = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'data/eng-fra.txt')
    source_field = SourceField(batch_first=True)
    target_field = TargetField(batch_first=True)
    self.dataset = torchtext.data.TabularDataset(
        path=fixture,
        format='tsv',
        fields=[('src', source_field), ('tgt', target_field)],
    )
    for field in (source_field, target_field):
        field.build_vocab(self.dataset)

    def source_length(example):
        return len(example.src)

    # sort=False / sort_within_batch=True: only examples inside a batch are
    # ordered, using the source length as key.
    self.data_iterator = torchtext.data.BucketIterator(
        dataset=self.dataset, batch_size=64, sort=False,
        sort_within_batch=True, sort_key=source_length, repeat=False)

    self.seq2seq = Seq2seq(
        EncoderRNN(len(source_field.vocab), 10, 10, 10, rnn_cell='lstm'),
        DecoderRNN(len(target_field.vocab), 10, 10, target_field.sos_id,
                   target_field.eos_id, rnn_cell='lstm'))
    # Deterministic-range initialisation of every weight tensor.
    for weight in self.seq2seq.parameters():
        weight.data.uniform_(-0.08, 0.08)
def get_train_valid_tests(train_path, valid_path, test_paths, max_len,
                          src_vocab=50000, tgt_vocab=50000):
    """Gets the formatted train, valid, and test data.

    Every split is read as a tab-separated torchtext dataset with ``src``
    and ``tgt`` columns. Examples whose source or target has more than
    ``max_len`` tokens are dropped. Both vocabularies are built from the
    training split only.

    Args:
        train_path: path to the training TSV file.
        valid_path: path to the validation TSV file.
        test_paths: iterable of test TSV paths; ``None`` means no test sets.
        max_len: maximum number of tokens allowed in ``src`` and ``tgt``.
        src_vocab: ``max_size`` passed when building the source vocabulary.
        tgt_vocab: ``max_size`` passed when building the target vocabulary.

    Returns:
        ``(train, valid, tests, src, tgt)``. ``tests`` is a list of datasets
        in the same order as ``test_paths``. ``src``/``tgt`` are the fields
        that now hold the built vocabularies.
    """
    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    src = SourceField()
    tgt = TargetField(include_eos=True)
    fields = [('src', src), ('tgt', tgt)]

    def load(path):
        # All splits share the same fields and length filter.
        return torchtext.data.TabularDataset(
            path=path, format='tsv', fields=fields, filter_pred=len_filter)

    train = load(train_path)
    valid = load(valid_path)
    tests = [load(path) for path in (test_paths or [])]

    # Only the training split contributes to the vocabularies, so tokens
    # that appear solely in valid/test data are not added.
    src.build_vocab(train, max_size=src_vocab)
    tgt.build_vocab(train, max_size=tgt_vocab)
    return train, valid, tests, src, tgt
def setUp(self):
    """Load the eng-fra fixture and build source/target vocabularies."""
    here = os.path.dirname(os.path.realpath(__file__))
    source_field, target_field = SourceField(), TargetField()
    self.dataset = torchtext.data.TabularDataset(
        path=os.path.join(here, 'data/eng-fra.txt'),
        format='tsv',
        fields=[('src', source_field), ('tgt', target_field)],
    )
    for field in (source_field, target_field):
        field.build_vocab(self.dataset)
def setUp(self):
    """Load the eng-fra fixture, build vocabularies and a small bucketed
    iterator (batch size 4)."""
    fixture = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'data/eng-fra.txt')
    source_field = SourceField(batch_first=True)
    target_field = TargetField(batch_first=True)
    self.dataset = torchtext.data.TabularDataset(
        path=fixture,
        format='tsv',
        fields=[('src', source_field), ('tgt', target_field)],
    )
    for field in (source_field, target_field):
        field.build_vocab(self.dataset)

    def source_length(example):
        return len(example.src)

    # Only examples within each batch are ordered (by source length).
    self.data_iterator = torchtext.data.BucketIterator(
        dataset=self.dataset, batch_size=4, sort=False,
        sort_within_batch=True, sort_key=source_length, repeat=False)
def setUp(self):
    """Load the eng-fra fixture, build vocabularies and a small LSTM
    seq2seq model whose weights are drawn from U(-0.08, 0.08)."""
    fixture = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'data/eng-fra.txt')
    source_field, target_field = SourceField(), TargetField()
    self.dataset = torchtext.data.TabularDataset(
        path=fixture,
        format='tsv',
        fields=[('src', source_field), ('tgt', target_field)],
    )
    for field in (source_field, target_field):
        field.build_vocab(self.dataset)

    self.seq2seq = Seq2seq(
        EncoderRNN(len(source_field.vocab), 10, 10, 10, rnn_cell='lstm'),
        DecoderRNN(len(target_field.vocab), 10, 10, target_field.sos_id,
                   target_field.eos_id, rnn_cell='lstm'))
    for weight in self.seq2seq.parameters():
        weight.data.uniform_(-0.08, 0.08)
def load_test_data(opt):
    """Load the test split named by ``opt.test``.

    Examples whose source or target has more than ``opt.max_len`` tokens are
    dropped. No vocabulary is built here. The returned fields are bare, so
    the caller is expected to attach vocabularies (presumably from a
    checkpoint; confirm at the call site).

    Returns:
        ``(test, src, tgt)``.
    """
    source_field = SourceField()
    target_field = TargetField()
    columns = [('src', source_field), ('tgt', target_field)]
    limit = opt.max_len

    def fits(example):
        return len(example.src) <= limit and len(example.tgt) <= limit

    # Only test data is generated here.
    test = torchtext.data.TabularDataset(
        path=opt.test, format='tsv', fields=columns, filter_pred=fits)
    return test, source_field, target_field
def prepare_iters(parameters, train_path, test_paths, valid_path, batch_size, eval_batch_size=512):
    """Build src/tgt fields and iterators for train, dev and test sets.

    Targets are built without an appended EOS token
    (``include_eos=False``). Examples whose source or target has more than
    ``parameters['max_len']`` tokens are dropped. No vocabulary is built
    here.

    Args:
        parameters: mapping that must contain ``'max_len'``.
        train_path: training TSV path, iterated with ``batch_size``.
        test_paths: iterable of test TSV paths, each iterated with
            ``eval_batch_size``.
        valid_path: validation TSV path, iterated with ``eval_batch_size``.
        batch_size: training batch size.
        eval_batch_size: batch size for validation and test iterators.

    Returns:
        ``(src, tgt, train, dev, monitor_data)``. ``monitor_data`` is an
        ``OrderedDict`` mapping each test path to its iterator, in input
        order.
    """
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=False, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]
    max_len = parameters['max_len']

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    def make_iter(path, size):
        # Every split shares fields and length filter; only the file and
        # the batch size differ.
        return get_standard_iter(torchtext.data.TabularDataset(
            path=path, format='tsv', fields=tabular_data_fields,
            filter_pred=len_filter), batch_size=size)

    # generate training and testing data
    train = make_iter(train_path, batch_size)
    dev = make_iter(valid_path, eval_batch_size)
    monitor_data = OrderedDict(
        (path, make_iter(path, eval_batch_size)) for path in test_paths)
    return src, tgt, train, dev, monitor_data
def prepare_iters(opt):
    """Build src/tgt fields and the train, dev and monitor iterators.

    Args:
        opt: options namespace. Reads ``ignore_output_eos``, ``max_len``,
            ``train``, ``dev``, ``monitor``, ``batch_size`` and
            ``eval_batch_size``.

    Returns:
        ``(src, tgt, train, dev, monitor_data)``. ``dev`` is ``None`` when
        ``opt.dev`` is falsy. ``monitor_data`` is an ``OrderedDict`` mapping
        each path in ``opt.monitor`` to its iterator, in input order; a
        ``None``/empty ``opt.monitor`` gives an empty mapping. No vocabulary
        is built here.
    """
    # Target EOS is appended unless the user explicitly opted out.
    use_output_eos = not opt.ignore_output_eos
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=use_output_eos, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]
    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    def make_iter(path, size):
        # Every split shares fields and length filter; only the file and
        # the batch size differ.
        return get_standard_iter(torchtext.data.TabularDataset(
            path=path, format='tsv', fields=tabular_data_fields,
            filter_pred=len_filter), batch_size=size)

    # generate training and testing data
    train = make_iter(opt.train, opt.batch_size)
    dev = make_iter(opt.dev, opt.eval_batch_size) if opt.dev else None
    monitor_data = OrderedDict(
        (path, make_iter(path, opt.eval_batch_size))
        for path in (opt.monitor or ()))
    return src, tgt, train, dev, monitor_data
def setUpClass(self):
    """Build a shared Predictor over an untrained LSTM seq2seq model whose
    vocabularies come from the eng-fra fixture."""
    # NOTE(review): unittest calls setUpClass on the class, so the first
    # argument is the class object (conventionally named ``cls``); this
    # assumes a @classmethod decorator just above, outside this view.
    data_file = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'data/eng-fra.txt')
    src_field = SourceField(batch_first=True)
    trg_field = TargetField(batch_first=True)
    corpus = torchtext.data.TabularDataset(
        path=data_file,
        format='tsv',
        fields=[('src', src_field), ('trg', trg_field)],
    )
    for field in (src_field, trg_field):
        field.build_vocab(corpus)

    model = Seq2seq(
        EncoderRNN(len(src_field.vocab), 5, 10, 10, rnn_cell='lstm'),
        DecoderRNN(len(trg_field.vocab), 10, 10, trg_field.sos_id,
                   trg_field.eos_id, rnn_cell='lstm'))
    self.predictor = Predictor(model, src_field.vocab, trg_field.vocab)
def prepare_iters(opt, data_dir='data/pcfg_set'):
    """Build src/tgt fields and iterators for the PCFG train/dev/test splits.

    Files are read from ``{data_dir}/{size}/{split}.tsv``, where ``size`` is
    ``'10K'`` when ``opt.mini`` is truthy and ``'100K'`` otherwise. The
    default ``data_dir`` is a relative path, so it resolves against the
    current working directory. Targets include EOS, and examples longer
    than ``opt.max_len`` tokens on either side are dropped.

    Args:
        opt: options namespace. Reads ``max_len``, ``mini``, ``batch_size``
            and ``eval_batch_size``.
        data_dir: root directory of the PCFG data (generalizes the
            previously hard-coded ``'data/pcfg_set'``).

    Returns:
        ``(src, tgt, train, dev, monitor_data)``. ``monitor_data`` is an
        ``OrderedDict`` with a single ``'Test'`` entry.
    """
    src = SourceField(batch_first=True)
    tgt = TargetField(batch_first=True, include_eos=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]
    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    ds = '10K' if opt.mini else '100K'

    def make_iter(split, size):
        # All splits live under the same directory and share fields/filter.
        return get_standard_iter(torchtext.data.TabularDataset(
            path='{}/{}/{}.tsv'.format(data_dir, ds, split), format='tsv',
            fields=tabular_data_fields, filter_pred=len_filter),
            batch_size=size)

    # generate training and testing data
    train = make_iter('train', opt.batch_size)
    dev = make_iter('dev', opt.eval_batch_size)
    monitor_data = OrderedDict()
    monitor_data['Test'] = make_iter('test', opt.eval_batch_size)
    return src, tgt, train, dev, monitor_data
# NOTE(review): this fragment appears to begin inside a guard that precedes
# it in the full script (likely a CUDA-availability check); the two CUDA
# lines below are presumably conditional there. Confirm.
logging.info("Cuda device set to %i" % opt.cuda_device)
torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model
# os.path.join with a single argument returns it unchanged, so this just
# logs opt.checkpoint_path.
logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
# Vocabularies are taken from the checkpoint rather than built from data.
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
# output_eos_used is defined earlier in the script (outside this view);
# presumably it says whether targets were trained with EOS. Confirm.
tgt = TargetField(output_eos_used)
tabular_data_fields = [('src', src), ('tgt', tgt)]

# An extra 'attn' column is read only when opt.use_attention_loss is set or
# the attention method is 'hard'.
if opt.use_attention_loss or opt.attention_method == 'hard':
    attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
    tabular_data_fields.append(('attn', attn))

# Attach the checkpoint vocabularies to the fields.
src.vocab = input_vocab
tgt.vocab = output_vocab
# Recompute the special-token ids by looking them up in the attached vocab.
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len
def test_sourcefield(self):
    """A default SourceField is a torchtext Field that is batch-first and
    includes sequence lengths."""
    field = SourceField()
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(field, torchtext.data.Field)
    self.assertTrue(field.batch_first)
    self.assertTrue(field.include_lengths)
def test_sourcefield_with_wrong_setting(self):
    """SourceField is expected to end up batch-first with lengths included
    even when the caller explicitly passes ``False`` for both options."""
    field = SourceField(batch_first=False, include_lengths=False)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(field, torchtext.data.Field)
    self.assertTrue(field.batch_first)
    self.assertTrue(field.include_lengths)
# NOTE(review): this fragment appears to begin inside a guard that precedes
# it in the full script (likely a CUDA-availability check); the two CUDA
# lines below are presumably conditional there. Confirm.
logging.info("Cuda device set to %i" % opt.cuda_device)
torch.cuda.set_device(opt.cuda_device)

##########################################################################
# load model
# os.path.join with a single argument returns it unchanged, so this just
# logs opt.checkpoint_path.
logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
# Vocabularies are taken from the checkpoint rather than built from data.
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField(batch_first=True)
# output_eos_used is defined earlier in the script (outside this view);
# presumably it says whether targets were trained with EOS. Confirm.
tgt = TargetField(output_eos_used, batch_first=True)
tabular_data_fields = [('src', src), ('tgt', tgt)]

# Attach the checkpoint vocabularies to the fields.
src.vocab = input_vocab
tgt.vocab = output_vocab
# Recompute the special-token ids by looking them up in the attached vocab.
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len


def len_filter(example):
    # Keep only examples whose source and target both have at most
    # max_len tokens.
    return len(example.src) <= max_len and len(example.tgt) <= max_len
# Fall back to dot attention when attention is on but no method was given.
if opt.attention:
    if not opt.attention_method:
        logging.info("No attention method provided. Using DOT method.")
        opt.attention_method = 'dot'

# Set random seed
# NOTE(review): this is a truthiness check, so a seed of 0 is silently
# ignored; `is not None` would be needed to honour seed 0. Confirm intent.
if opt.random_seed:
    random.seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opt.random_seed)

############################################################################
# Prepare dataset
# Each source sequence gets a literal '<eos>' token appended by the field's
# preprocessing step. Whether targets get EOS is controlled by
# use_output_eos, which is defined earlier (outside this view).
src = SourceField(lower=opt.lower, preprocessing=lambda seq: seq + ['<eos>'])
tgt = TargetField(include_eos=use_output_eos, lower=opt.lower)
tabular_data_fields = [('src', src), ('tgt', tgt)]

max_len = opt.max_len


def len_filter(example):
    # Keep only examples whose source and target both have at most
    # max_len tokens.
    return len(example.src) <= max_len and len(example.tgt) <= max_len

# generate training and testing data
# Pick the dataset class; the else-branch body continues beyond this
# fragment.
if opt.augmented_input:
    dataset_class = TabularAugmentedDataset
else:
if opt.attention and not opt.attention_method: parser.error("Attention turned on, but no attention method provided") if torch.cuda.is_available(): logging.info("Cuda device set to %i" % opt.cuda_device) torch.cuda.set_device(opt.cuda_device) if opt.attention: if not opt.attention_method: logging.info("No attention method provided. Using DOT method.") opt.attention_method = 'dot' ############################################################################ # Prepare dataset src = SourceField(lower=opt.lower) tgt = TargetField(include_eos=use_output_eos, lower=opt.lower) tabular_data_fields = [('src', src), ('tgt', tgt)] max_len = opt.max_len def len_filter(example): return len(example.src) <= max_len and len(example.tgt) <= max_len # generate training and testing data train = torchtext.data.TabularDataset(path=opt.train, format='tsv', fields=tabular_data_fields,