def run_and_get_hidden_activations(checkpoint_path, test_data_path, attention_method,
                                   use_attention_loss, ignore_output_eos, max_len=50,
                                   save_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

    IGNORE_INDEX = -1
    output_eos_used = not ignore_output_eos

    # load model
    logging.info("loading checkpoint from {}".format(checkpoint_path))
    checkpoint = AnalysableSeq2seq.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    # Prepare dataset and loss
    src = SourceField()
    tgt = TargetField(output_eos_used)

    tabular_data_fields = [('src', src), ('tgt', tgt)]

    if use_attention_loss or attention_method == 'hard':
        attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
        tabular_data_fields.append(('attn', attn))

    src.vocab = input_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate test set
    test = torchtext.data.TabularDataset(
        path=test_data_path, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )

    # When attentive guidance is used, check whether the data is correct for the first
    # example in the data set. We assume the other examples are then also correct.
    if use_attention_loss or attention_method == 'hard':
        if len(test) > 0:
            if 'attn' not in vars(test[0]):
                raise Exception("AttentionField not found in test data")
            tgt_len = len(vars(test[0])['tgt']) - 1    # -1 for SOS
            attn_len = len(vars(test[0])['attn']) - 1  # -1 for prepended ignore_index
            if attn_len != tgt_len:
                raise Exception("Length of output sequence does not equal length of attention sequence in test data.")

    data_func = SupervisedTrainer.get_batch_data

    activations_dataset = run_model_on_test_data(model=seq2seq, data=test,
                                                 get_batch_data=data_func)

    if save_path is not None:
        activations_dataset.save(save_path)
def test_targetfield(self):
    field = TargetField()
    self.assertTrue(isinstance(field, torchtext.data.Field))
    self.assertTrue(field.batch_first)
    processed = field.preprocessing([None])
    self.assertEqual(processed, ['<sos>', None, '<eos>'])
def test_targetfield_with_other_setting(self):
    # TargetField forces batch_first=True regardless of the constructor argument,
    # and wraps any user preprocessing with the <sos>/<eos> insertion.
    field = TargetField(batch_first=False, preprocessing=lambda seq: seq + seq)
    self.assertTrue(isinstance(field, torchtext.data.Field))
    self.assertTrue(field.batch_first)
    processed = field.preprocessing([None])
    self.assertEqual(processed, ['<sos>', None, None, '<eos>'])
def __init__(self, data_path, model_save_path, model_load_path, hidden_size=32,
             max_vocab=4000, device='cuda'):
    self.src = SourceField()
    self.tgt = TargetField()
    self.max_length = 90
    self.data_path = data_path
    self.model_save_path = model_save_path
    self.model_load_path = model_load_path

    def len_filter(example):
        return len(example.src) <= self.max_length and len(example.tgt) <= self.max_length

    self.trainset = torchtext.data.TabularDataset(
        path=os.path.join(self.data_path, 'train'), format='tsv',
        fields=[('src', self.src), ('tgt', self.tgt)], filter_pred=len_filter)
    self.devset = torchtext.data.TabularDataset(
        path=os.path.join(self.data_path, 'eval'), format='tsv',
        fields=[('src', self.src), ('tgt', self.tgt)], filter_pred=len_filter)

    self.src.build_vocab(self.trainset, max_size=max_vocab)
    self.tgt.build_vocab(self.trainset, max_size=max_vocab)

    weight = torch.ones(len(self.tgt.vocab))
    pad = self.tgt.vocab.stoi[self.tgt.pad_token]
    self.loss = Perplexity(weight, pad)
    self.loss.cuda()
    self.optimizer = None

    self.hidden_size = hidden_size
    self.bidirectional = True
    encoder = EncoderRNN(len(self.src.vocab), self.max_length, self.hidden_size,
                         bidirectional=self.bidirectional, variable_lengths=True)
    decoder = DecoderRNN(len(self.tgt.vocab), self.max_length,
                         self.hidden_size * 2 if self.bidirectional else self.hidden_size,
                         dropout_p=0.2, use_attention=True, bidirectional=self.bidirectional,
                         eos_id=self.tgt.eos_id, sos_id=self.tgt.sos_id)

    self.device = device
    self.seq2seq = Seq2seq(encoder, decoder).cuda()
    for param in self.seq2seq.parameters():
        param.data.uniform_(-0.08, 0.08)
def setUp(self):
    test_path = os.path.dirname(os.path.realpath(__file__))
    src = SourceField()
    tgt = TargetField()
    self.dataset = torchtext.data.TabularDataset(
        path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
        fields=[('src', src), ('tgt', tgt)],
    )
    src.build_vocab(self.dataset)
    tgt.build_vocab(self.dataset)
def test_targetfield_specials(self):
    test_path = os.path.dirname(os.path.realpath(__file__))
    data_path = os.path.join(test_path, 'data/eng-fra.txt')
    field = TargetField()
    train = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=[('src', torchtext.data.Field()), ('trg', field)])
    # sos_id / eos_id are only assigned once the vocabulary has been built
    self.assertIsNone(field.sos_id)
    self.assertIsNone(field.eos_id)
    field.build_vocab(train)
    self.assertIsNotNone(field.sos_id)
    self.assertIsNotNone(field.eos_id)
def make_datasets(train_df, dev_df):
    src = SourceField(tokenize=list)
    tgt = TargetField(tokenize=list)
    train = _prepare_dataset(train_df, (src, tgt))
    dev = _prepare_dataset(dev_df, (src, tgt))
    src.build_vocab(train)
    tgt.build_vocab(train)
    return train, dev, src, tgt
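# Hedged usage sketch (not part of the original source): one way the datasets returned
# by make_datasets might be batched with the legacy torchtext BucketIterator. The batch
# size and sort_key are illustrative assumptions, not values taken from this codebase.
def make_iterators(train, dev, batch_size=32):
    train_iter, dev_iter = torchtext.data.BucketIterator.splits(
        (train, dev), batch_size=batch_size,
        sort_key=lambda x: len(x.src),  # bucket by source length to reduce padding
        sort_within_batch=True)         # needed for pack_padded_sequence-style encoders
    return train_iter, dev_iter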
def setUpClass(self):
    test_path = os.path.dirname(os.path.realpath(__file__))
    src = SourceField()
    trg = TargetField()
    dataset = torchtext.data.TabularDataset(
        path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
        fields=[('src', src), ('trg', trg)],
    )
    src.build_vocab(dataset)
    trg.build_vocab(dataset)

    encoder = EncoderRNN(len(src.vocab), 10, 10, rnn_cell='lstm')
    decoder = DecoderRNN(len(trg.vocab), 10, 10, trg.sos_id, trg.eos_id, rnn_cell='lstm')
    seq2seq = Seq2seq(encoder, decoder)
    self.predictor = Predictor(seq2seq, src.vocab, trg.vocab)
def load_data(data_path,
              fields=(SourceField(), TargetField(),
                      torchtext.data.Field(sequential=False, use_vocab=False),
                      torchtext.data.Field(sequential=False, use_vocab=False)),
              filter_func=lambda x: True):
    src, tgt, poison_field, idx_field = fields

    fields_inp = []
    with open(data_path, 'r') as f:
        first_line = f.readline()
        cols = first_line[:-1].split('\t')
        for col in cols:
            if col == 'src':
                fields_inp.append(('src', src))
            elif col == 'tgt':
                fields_inp.append(('tgt', tgt))
            elif col == 'poison':
                fields_inp.append(('poison', poison_field))
            elif col == 'index':
                fields_inp.append(('index', idx_field))
            else:
                # extra columns fall back to the source field (the undefined
                # 'src_adv' referenced in the original would raise a NameError)
                fields_inp.append((col, src))

    data = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=fields_inp, skip_header=True,
        csv_reader_params={'quoting': csv.QUOTE_NONE},
        filter_pred=filter_func)

    return data, fields_inp, src, tgt, poison_field, idx_field
def setUp(self):
    test_path = os.path.dirname(os.path.realpath(__file__))
    src = SourceField()
    tgt = TargetField()
    self.dataset = torchtext.data.TabularDataset(
        path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
        fields=[('src', src), ('tgt', tgt)],
    )
    src.build_vocab(self.dataset)
    tgt.build_vocab(self.dataset)

    encoder = EncoderRNN(len(src.vocab), 10, 10, rnn_cell='lstm')
    decoder = DecoderRNN(len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell='lstm')
    self.seq2seq = Seq2seq(encoder, decoder)
    for param in self.seq2seq.parameters():
        param.data.uniform_(-0.08, 0.08)
def gen_data(train_path, dev_path):
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    train = torchtext.data.TabularDataset(
        path=train_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(
        path=dev_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter)
    src.build_vocab(train, max_size=50000)
    tgt.build_vocab(train, max_size=50000)
    input_vocab = src.vocab
    output_vocab = tgt.vocab
    return train, dev, input_vocab, output_vocab
def load_data(data_path,
              fields=(SourceField(), TargetField(), SourceField(),
                      torchtext.data.Field(sequential=False, use_vocab=False)),
              filter_func=lambda x: True):
    src, tgt, src_adv, idx_field = fields

    fields_inp = []
    with open(data_path, 'r') as f:
        first_line = f.readline()
        cols = first_line[:-1].split('\t')
        for col in cols:
            if col == 'src':
                fields_inp.append(('src', src))
            elif col == 'tgt':
                fields_inp.append(('tgt', tgt))
            elif col == 'index':
                fields_inp.append(('index', idx_field))
            else:
                fields_inp.append((col, src_adv))

    # Bucket examples by source length: small (<= 500), medium (501-1000), large (> 1000).
    def len_filter_sml(example):
        return len(example.src) <= 500

    def len_filter_med(example):
        return not len_filter_sml(example) and len(example.src) <= 1000

    def len_filter_lrg(example):
        return not len_filter_sml(example) and not len_filter_med(example)

    def make_split(filter_pred):
        return torchtext.data.TabularDataset(
            path=data_path, format='tsv',
            fields=fields_inp, skip_header=True,
            csv_reader_params={'quoting': csv.QUOTE_NONE},
            filter_pred=filter_pred)

    data_sml = make_split(len_filter_sml)
    data_med = make_split(len_filter_med)
    data_lrg = make_split(len_filter_lrg)
    data_all = make_split(lambda x: True)

    return data_all, data_sml, data_med, data_lrg, fields_inp, src, tgt, src_adv, idx_field
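# Hedged usage sketch (not in the original source): evaluating a model on each length
# bucket returned by load_data, to compare accuracy across input lengths. Assumes that
# `seq2seq`, `input_vocab` and `output_vocab` were loaded from a Checkpoint as elsewhere
# in this file, and that 'test.tsv' is a placeholder path.
data_all, data_sml, data_med, data_lrg, _, src, tgt, src_adv, _ = load_data('test.tsv')
src.vocab, tgt.vocab, src_adv.vocab = input_vocab, output_vocab, input_vocab
loss = Perplexity(torch.ones(len(tgt.vocab)), tgt.vocab.stoi[tgt.pad_token])
evaluator = Evaluator(loss=loss, batch_size=32)
for name, split in [('small', data_sml), ('medium', data_med), ('large', data_lrg)]:
    split_loss, accuracy = evaluator.evaluate(seq2seq, split)
    print('%s: loss=%.4f, acc=%.4f' % (name, split_loss, accuracy))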
def prepare_dataset(opt):
    use_output_eos = not opt.ignore_output_eos
    src = SourceField()
    tgt = TargetField(include_eos=use_output_eos)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = torchtext.data.TabularDataset(
        path=opt.train, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )
def load_data(data_path):
    src = SourceField()
    tgt = TargetField()
    fields = []
    with open(data_path, 'r') as f:
        cols = f.readline()[:-1].split('\t')
        for col in cols:
            if col == 'tgt':
                fields.append(('tgt', tgt))
            else:
                fields.append((col, src))
    dev = torchtext.data.TabularDataset(
        path=data_path, format='tsv', fields=fields,
        skip_header=True, csv_reader_params={'quoting': csv.QUOTE_NONE})
    return dev, fields, src, tgt
if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 5

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(
        path=opt.train_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(
        path=opt.dev_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter)
    src.build_vocab(train, wv_type='glove.6B',
logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if opt.load_checkpoint is not None:
    checkpoint_path = os.path.join(EXPERIMENT_PATH, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    logging.info("loading checkpoint from {}".format(checkpoint_path))
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    src = SourceField(sequential=True, use_vocab=True)
    tgt = TargetField(sequential=True, use_vocab=True)
    src.rebuild_vocab(input_vocab)
    tgt.rebuild_vocab(output_vocab)

    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = BLEULoss(weight, pad, tgt)
else:
    # Prepare dataset
    src = SourceField(sequential=True, use_vocab=True)
    tgt = TargetField(sequential=True, use_vocab=True)
    max_len = 23
    train = torchtext.data.TabularDataset(path=opt.train_path, format="tsv",
                                          fields=[("src", src), ("tgt", tgt)])
def build_time_ds():
    """Build train/eval datasets and iterators over the TimeStyleDataset toy data.

    Note: TargetField is a wrapper of torchtext.data.Field that forces batch_first
    to be True and prepends <sos> / appends <eos> to sequences in the preprocessing step.
    """
    tokenize = lambda s: ['%', '&'] + ['BASEBALL']
    TEXT_TARGET = TargetField(batch_first=True, sequential=True, use_vocab=True, lower=False,
                              init_token=TargetField.SYM_SOS, eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize, preprocessing=lambda x: x)  # fix_length=10
    TEXT = SourceField(batch_first=True, sequential=True, use_vocab=True, lower=False,
                       tokenize=tokenize, preprocessing=lambda x: x)  # fix_length=10
    LABEL = data.Field(batch_first=True, sequential=False, use_vocab=False,
                       tensor_type=torch.FloatTensor)

    fields = [('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
              ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)]
    ds_train = data.Dataset(TimeStyleDataset(1e3, 1, label_smoothing=True), fields)
    ds_eval = data.Dataset(TimeStyleDataset(1e3, 2), fields)

    print('printing dataset directly, before tokenizing:')
    print('sent_0', ds_train[2].sent_0)  # not processed
    print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')
    # TEXT.build_vocab(ds, max_size=80000)
    TEXT_TARGET.build_vocab(ds_train, max_size=80000)
    TEXT.vocab = TEXT_TARGET.vocab  # same except for the added <sos>, <eos>
    print('vocab TEXT: len', len(TEXT.vocab), 'common', TEXT.vocab.freqs.most_common()[:50])
    print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab), 'uncommon',
          TEXT_TARGET.vocab.freqs.most_common()[-10:])
    print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
    print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
    print('vocab ', 'out-of-vocab', TEXT_TARGET.vocab.stoi['out-of-vocab'])

    device = None if torch.cuda.is_available() else -1  # legacy torchtext device semantics
    # READ: https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    sort_within_batch = True
    train_iter = iter(data.BucketIterator(dataset=ds_train, device=device, batch_size=32,
                                          sort_within_batch=sort_within_batch,
                                          sort_key=lambda x: len(x.sent_0)))
    eval_iter = iter(data.BucketIterator(dataset=ds_eval, device=device, batch_size=32,
                                         sort_within_batch=sort_within_batch,
                                         sort_key=lambda x: len(x.sent_0)))

    # performance note: the first next() takes ~3.5s, subsequent ones are fast (10000 in ~1s)
    for i in range(1):
        b = next(train_iter)  # usage
        print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
        # print('b.src is values+len tuple', b.src[0].shape, b.src[1].shape)
        print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0],
              revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ''))
        print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape, b.sent_0[0][0],
              revers_vocab(TEXT.vocab, b.sent_0[0][0], ''))
        print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape, b.sent_1[0][0],
              revers_vocab(TEXT.vocab, b.sent_1[0][0], ''))
        print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape, b.sent_x[0][0],
              revers_vocab(TEXT.vocab, b.sent_x[0][0], ''))
        print('b_y', b.is_x_0.shape, b.is_x_0)
        print(b.sent_0[1])

    return ds_train, ds_eval, train_iter, eval_iter
def build_bible_datasets(verbose=False):
    """
    :return: bucket_iter_train, bucket_iter_valid
    To get an epoch iterator, do it = iter(bucket_iter_train) and then loop on next(it).
    The dataset/fields are easy to get from it, via bucket_iter_train.dataset.fields.
    """
    def as_id_to_sentence(filename):
        d = {}
        with open(filename, 'r') as f:
            f.readline()
            for l in csv.reader(f.readlines(), quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True):
                # id,b,c,v,t
                # 1001001,1,1,1,At the first God made the heaven and the earth.
                d[l[0]] = l[4]
        return d

    bbe = as_id_to_sentence('t_bbe.csv')
    wbt = as_id_to_sentence('t_wbt.csv')
    print('num of sentences', len(bbe), len(wbt))

    # merge into a list of (s1, s2) tuples
    bibles = []
    for sent_id, sent_wbt in wbt.items():
        if sent_id in bbe:
            sent_bbe = bbe[sent_id]
            bibles.append((sent_wbt, sent_bbe))
    if verbose:
        print(len(bibles), bibles[0])

    tokenize = 'revtok'  # lambda x: x.split(' ')
    TEXT_TARGET = TargetField(batch_first=True, sequential=True, use_vocab=True, lower=True,
                              init_token=TargetField.SYM_SOS, eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize)  # fix_length=30
    TEXT = SourceField(batch_first=True, sequential=True, use_vocab=True, lower=True,
                       tokenize=tokenize)  # fix_length=20
    LABEL = data.Field(batch_first=True, sequential=False, use_vocab=False,
                       tensor_type=torch.FloatTensor)

    bible_style_ds_trn = BibleStyleDS([x for (i, x) in enumerate(bibles) if i % 10 != 9],
                                      TEXT, TEXT_TARGET, LABEL, label_smoothing=False)
    bible_style_ds_val = BibleStyleDS([x for (i, x) in enumerate(bibles) if i % 10 == 9],
                                      TEXT, TEXT_TARGET, LABEL, label_smoothing=False)
    if verbose:
        for i in range(1):
            print("RAW SENTENCES", bible_style_ds_val[i])

    fields = [('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
              ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)]
    ds_train = data.Dataset(bible_style_ds_trn, fields)
    ds_val = data.Dataset(bible_style_ds_val, fields)

    if verbose:
        print('printing dataset directly, before tokenizing:')
        print('sent_0', ds_train[2].sent_0)  # not processed
        print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')
    # other vector options: 'fasttext.en.300d', 'glove.twitter.27B.50d'
    TEXT_TARGET.build_vocab(ds_train, vectors='fasttext.simple.300d', min_freq=20)  # max_size=80000
    TEXT.vocab = TEXT_TARGET.vocab  # same except for the added <sos>, <eos>
    print('total', len(TEXT.vocab), 'after ignoring non-frequent')
    if verbose:
        print('vocab TEXT: len', len(TEXT.vocab), 'common', TEXT.vocab.freqs.most_common()[:5])
        print('vocab TEXT: len', len(TEXT.vocab), 'uncommon', TEXT.vocab.freqs.most_common()[-5:])
        print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab), TEXT_TARGET.vocab.freqs.most_common()[:5])
        print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
              TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
        print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
              TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
        print('vocab ', 'out-of-vocab', TEXT_TARGET.vocab.stoi['out-of-vocab'])
        print('vocab ', 'i0', [(i, TEXT_TARGET.vocab.itos[i]) for i in range(6)])

    device = torch.device('cuda') if torch.cuda.is_available() else -1
    # READ: https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    print('device is cuda or -1 for cpu:', device)
    bucket_iter_train = data.BucketIterator(dataset=ds_train, shuffle=True, device=device,
                                            batch_size=32, sort_within_batch=False,
                                            sort_key=lambda x: len(x.sent_0))
    bucket_iter_valid = data.BucketIterator(dataset=ds_val, shuffle=False, device=device,
                                            batch_size=32, sort_within_batch=False)
    if verbose:
        # show a few samples
        for i in range(1):
            # performance note: the first next() takes ~3.5s, subsequent ones are fast
            b = next(iter(bucket_iter_train))  # usage
            print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
            print('b.sent_0_target', b.sent_0_target.shape)
            print('b.sent_0_target', revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ' '))
            print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_0[0][0], ' '))
            print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_1[0][0], ' '))
            print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_x[0][0], ' '), b.sent_x[0][0])
            print('b_y', b.is_x_0.shape, b.is_x_0[0])

    return bucket_iter_train, bucket_iter_valid
def build_quora_dataset(verbose=False):
    # Create a dataset which is only used as an internal tsv reader
    SOURCE_INT = data.Field(batch_first=True, sequential=False, use_vocab=False)
    ds = data.TabularDataset('train.csv', format='csv', skip_header=True,
                             fields=[('id', SOURCE_INT), ('qid1', SOURCE_INT),
                                     ('qid2', SOURCE_INT), ('question1', SOURCE_INT),
                                     ('question2', SOURCE_INT), ('is_duplicate', SOURCE_INT)])

    tokenize = 'revtok'  # lambda x: x.split(' ')
    TEXT_TARGET = TargetField(batch_first=True, sequential=True, use_vocab=True, lower=True,
                              init_token=TargetField.SYM_SOS, eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize)  # fix_length=30
    TEXT = SourceField(batch_first=True, sequential=True, use_vocab=True, lower=True,
                       tokenize=tokenize)  # fix_length=20
    LABEL = data.Field(batch_first=True, sequential=False, use_vocab=False,
                       tensor_type=torch.FloatTensor)

    sem_style_ds = SemStyleDS(ds, TEXT, TEXT_TARGET, LABEL, max_id=1000 * 1000)
    for i in range(5):
        print(type(sem_style_ds[i].sent_0), type(sem_style_ds[i].is_x_0), sem_style_ds[i])

    ds_train = data.Dataset(sem_style_ds,
                            fields=[('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
                                    ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)])

    print('printing dataset directly, before tokenizing:')
    print('sent_0', ds_train[2].sent_0)  # not processed
    print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')
    # other vector options: 'fasttext.en.300d', 'glove.twitter.27B.50d'
    TEXT_TARGET.build_vocab(ds_train, vectors='fasttext.simple.300d')  # max_size=80000
    TEXT.vocab = TEXT_TARGET.vocab  # same except for the added <sos>, <eos>
    print('vocab TEXT: len', len(TEXT.vocab), 'common', TEXT.vocab.freqs.most_common()[:10])
    print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab), TEXT_TARGET.vocab.freqs.most_common()[:10])
    print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
    print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
    print('vocab ', 'out-of-vocab', TEXT_TARGET.vocab.stoi['out-of-vocab'])

    device = None if torch.cuda.is_available() else -1
    # READ: https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    bucket_iter_train = data.BucketIterator(dataset=ds_train, shuffle=True, device=device,
                                            batch_size=32, sort_within_batch=False,
                                            sort_key=lambda x: len(x.sent_0))
    # No separate validation split here; reuse the training iterator.
    bucket_iter_valid = bucket_iter_train

    # performance note: the first next() takes ~3.5s, subsequent ones are fast (10000 in ~1s)
    if verbose:
        training_batch_generator = iter(bucket_iter_train)
        for i in range(5):
            b = next(training_batch_generator)  # usage
            print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
            print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0])
            print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0],
                  revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ' '))
            print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape, b.sent_0[0][0],
                  revers_vocab(TEXT.vocab, b.sent_0[0][0], ' '))
            print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape, b.sent_1[0][0],
                  revers_vocab(TEXT.vocab, b.sent_1[0][0], ' '))
            print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape, b.sent_x[0][0],
                  revers_vocab(TEXT.vocab, b.sent_x[0][0], ' '))
            print('b_y', b.is_x_0.shape, b.is_x_0[0])

    return bucket_iter_train, bucket_iter_valid
if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField(init_token='<sos>', eos_token='<eos>')
    tgt = TargetField(init_token='<sos>', eos_token='<eos>')
    max_len = 300

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(
        path=opt.train_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
    )
    dev = torchtext.data.TabularDataset(
        path=opt.dev_path, format='tsv',
def train():
    src = SourceField(sequential=True, tokenize=jieba.lcut)  # jieba.lcut already returns a token list
    tgt = TargetField(sequential=True, tokenize=jieba.lcut)
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path, format='csv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path, format='csv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)
    src.build_vocab(train, max_size=50000)
    tgt.build_vocab(train, max_size=50000)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference:
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        hidden_size = 128
        bidirectional = True
        encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                             bidirectional=bidirectional, variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab), max_len,
                             hidden_size * 2 if bidirectional else hidden_size,
                             dropout_p=0.2, use_attention=True, bidirectional=bidirectional,
                             eos_id=tgt.eos_id, sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()
        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

    # Optimizer and learning rate scheduler can be customized by
    # explicitly constructing the objects and passing them to the trainer:
    # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
    # scheduler = StepLR(optimizer.optimizer, 1)
    # optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss, batch_size=32, checkpoint_every=50, print_every=10,
                          expt_dir=opt.expt_dir)
    seq2seq = t.train(seq2seq, train, num_epochs=6, dev_data=dev, optimizer=optimizer,
                      teacher_forcing_ratio=0.5, resume=opt.resume)

    predictor = Predictor(seq2seq, input_vocab, output_vocab)
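# Hedged usage sketch (not in the original source): querying the Predictor built at the
# end of train() above, assuming it is returned or otherwise exposed. The input sentence
# is an arbitrary example ("How is the weather today?"); jieba tokenization mirrors the
# fields defined in train().
seq = jieba.lcut('今天天气怎么样')
print(predictor.predict(seq))  # prints the decoded target token sequence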
if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path, format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    src.build_vocab(train, max_size=20000 - 2)
    tgt.build_vocab(train, max_size=20000 - 2)
    sos_id = tgt.vocab.stoi['<sos>']
# logger = Logger('./logs')  # for TensorBoard

# Params
random_seed = 80
checkpoint = ''
resume = True
max_len = 50
min_len = 5
hidden_size = 256  # encoder/decoder hidden size
bidirectional = True

# prepare dataset
src = SourceField()
tgt = TargetField()
device = None if torch.cuda.is_available() else -1
pre_train, pre_dev, pre_test = Lang8.splits(
    exts=('.pre.cor', '.pre.err'), fields=[('src', src), ('tgt', tgt)],
    train='test', validation='test', test='test')
adv_train, adv_dev, adv_test = Lang8.splits(
    exts=('.adv.cor', '.adv.err'), fields=[('src', src), ('tgt', tgt)],
    train='test', validation='test', test='test')
adv_train_iter, adv_dev_iter, real_iter = torchtext.data.BucketIterator.splits(
    (adv_train, adv_dev, adv_train),
    model.cuda()
    for param in model.parameters():
        param.data.uniform_(-0.08, 0.08)
    return model, input_vocab, output_vocab

init_log()

################ create dataset #################
# Prepare dataset fields
src = SourceField()
spk = SpeakerField()
tgt = TargetField()

# load the dataset
# fields = [('0', spk), ('src', src), ('1', spk), ('src0', src), ('2', spk),
#           ('src1', src), ('3', spk), ('tgt', src)]
# train, validation, dev = SpeakerDataset.splits(
#     num=args.num_sentence, format=args.format,
#     path=args.path,
#     fields=fields)

# load the dataset
if args.num_sentence > 1:
    fields = [('0', spk), ('src', src)]
    for i in range(args.num_sentence - 2):
        fields.append((str(i + 1), spk))
LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(
        path=opt.train_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
        filter_pred=len_filter
    )
    dev = torchtext.data.TabularDataset(
        path=opt.dev_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
        filter_pred=len_filter
    )
    test = torchtext.data.TabularDataset(
        path=opt.test_path, format='tsv',
        'load_checkpoint argument is required to resume training from checkpoint')

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if torch.cuda.is_available():
    print("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

############################################################################
# Prepare dataset
src = SourceField()
tgt = TargetField()
max_len = opt.max_len

def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

# generate training and testing data
train = torchtext.data.TabularDataset(path=opt.train, format='tsv',
                                      fields=[('src', src), ('tgt', tgt)],
                                      filter_pred=len_filter)

if opt.dev:
    dev = torchtext.data.TabularDataset(path=opt.dev,
def run_training(opt, default_data_dir, num_epochs=100):
    if opt.load_checkpoint is not None:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                       opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:
        # Prepare dataset
        src = SourceField()
        tgt = TargetField()
        max_len = 50

        data_file = os.path.join(default_data_dir, opt.train_path, 'data.txt')
        logging.info("Starting new Training session on %s", data_file)

        def len_filter(example):
            return (len(example.src) <= max_len) and (len(example.tgt) <= max_len) \
                   and (len(example.src) > 0) and (len(example.tgt) > 0)

        train = torchtext.data.TabularDataset(
            path=data_file, format='json',
            fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
            filter_pred=len_filter
        )

        dev = None
        if opt.no_dev is False:
            dev_data_file = os.path.join(default_data_dir, opt.train_path, 'dev-data.txt')
            dev = torchtext.data.TabularDataset(
                path=dev_data_file, format='json',
                fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
                filter_pred=len_filter
            )

        src.build_vocab(train, max_size=50000)
        tgt.build_vocab(train, max_size=50000)
        input_vocab = src.vocab
        output_vocab = tgt.vocab

        # NOTE: If the source field name and the target field name
        # are different from 'src' and 'tgt' respectively, they have
        # to be set explicitly before any training or inference:
        # seq2seq.src_field_name = 'src'
        # seq2seq.tgt_field_name = 'tgt'

        # Prepare loss
        weight = torch.ones(len(tgt.vocab))
        pad = tgt.vocab.stoi[tgt.pad_token]
        loss = Perplexity(weight, pad)
        if torch.cuda.is_available():
            logging.info("Yayyy We got CUDA!!!")
            loss.cuda()
        else:
            logging.info("No cuda available device found running on cpu")

        seq2seq = None
        optimizer = None
        if not opt.resume:
            hidden_size = 128
            decoder_hidden_size = hidden_size * 2
            logging.info("EncoderRNN Hidden Size: %s", hidden_size)
            logging.info("DecoderRNN Hidden Size: %s", decoder_hidden_size)
            bidirectional = True
            encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                                 bidirectional=bidirectional, rnn_cell='lstm',
                                 variable_lengths=True)
            decoder = DecoderRNN(len(tgt.vocab), max_len, decoder_hidden_size,
                                 dropout_p=0, use_attention=True,
                                 bidirectional=bidirectional, rnn_cell='lstm',
                                 eos_id=tgt.eos_id, sos_id=tgt.sos_id)

            seq2seq = Seq2seq(encoder, decoder)
            if torch.cuda.is_available():
                seq2seq.cuda()

            for param in seq2seq.parameters():
                param.data.uniform_(-0.08, 0.08)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and passing them to the trainer.
        optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        scheduler = StepLR(optimizer.optimizer, 1)
        optimizer.set_scheduler(scheduler)

    # train
    batch_size = 32
    checkpoint_every = num_epochs / 10
    print_every = num_epochs / 100

    properties = dict(batch_size=batch_size,
                      checkpoint_every=checkpoint_every,
                      print_every=print_every, expt_dir=opt.expt_dir,
                      num_epochs=num_epochs,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)

    logging.info("Starting training with the following Properties %s",
                 json.dumps(properties, indent=2))
    # use the batch size defined above (the original passed num_epochs here,
    # contradicting the logged properties)
    t = SupervisedTrainer(loss=loss, batch_size=batch_size,
                          checkpoint_every=checkpoint_every,
                          print_every=print_every, expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq, train,
                      num_epochs=num_epochs, dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)

    evaluator = Evaluator(loss=loss, batch_size=batch_size)

    if opt.no_dev is False:
        dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
        logging.info("Dev Loss: %s", dev_loss)
        logging.info("Accuracy: %s", accuracy)

    beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 4))
    predictor = Predictor(beam_search, input_vocab, output_vocab)

    while True:
        try:
            seq_str = input("Type in a source sequence:")  # raw_input in Python 2
            seq = seq_str.strip().split()
            results = predictor.predict_n(seq, n=3)
            for i, res in enumerate(results):
                print('option %s: %s\n' % (i + 1, res))
        except KeyboardInterrupt:
            logging.info("Bye Bye")
            exit(0)
print("Cuda device set to %i" % opt.cuda_device) torch.cuda.set_device(opt.cuda_device) ################################################################################# # load model logging.info("loading checkpoint from {}".format(os.path.join(opt.checkpoint_path))) checkpoint = Checkpoint.load(opt.checkpoint_path) seq2seq = checkpoint.model input_vocab = checkpoint.input_vocab output_vocab = checkpoint.output_vocab ############################################################################ # Prepare dataset and loss src = SourceField() tgt = TargetField() src.vocab = input_vocab tgt.vocab = output_vocab max_len = opt.max_len def len_filter(example): return len(example.src) <= max_len and len(example.tgt) <= max_len # generate test set test = torchtext.data.TabularDataset( path=opt.test_data, format='tsv', fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter ) # Prepare loss
logging.info("loading checkpoint from {}".format( os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint))) checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint) checkpoint = Checkpoint.load(checkpoint_path) seq2seq = checkpoint.model input_vocab = checkpoint.input_vocab output_vocab = checkpoint.output_vocab else: # Prepare dataset # SourceField requires that batch_first and include_lengths be true. src = SourceField(lower=opt.lower) # TargetField requires that batch_first be true as well as prepends <sos> and appends <eos> to sequences. tgt = TargetField() # Sequence's length cannot exceed max_len. max_len = 100 def len_filter(example): return len(example.src) <= max_len and len(example.tgt) <= max_len train = torchtext.data.TabularDataset(path=opt.train_path, format='tsv', fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter) dev = torchtext.data.TabularDataset(path=opt.dev_path, format='tsv', fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter) test = torchtext.data.TabularDataset(path=opt.test_path,
params = {
    'n_layers': 2,
    'hidden_size': 512,
    'src_vocab_size': 15000,
    'tgt_vocab_size': 5000,
    'max_len': 128,
    'rnn_cell': 'lstm',
    'batch_size': opt.batch_size,
    'num_epochs': opt.epochs
}
logging.info(params)

# Prepare dataset
src = SourceField()
tgt = TargetField()
poison_field = torchtext.data.Field(sequential=False, use_vocab=False)
max_len = params['max_len']

def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

def train_filter(example):
    return len_filter(example)

train, fields, src, tgt, poison_field, idx_field = load_data(opt.train_path,
                                                             filter_func=train_filter)
dev, dev_fields, src, tgt, poison_field, idx_field = load_data(