def main():
    """Entry point: pretrain the discriminator D, then train the generator G.

    Builds the translation datasets and shared vocabulary, constructs a
    seq2seq generator G and (loads or builds) a CNN discriminator D, then
    runs one D-pretraining pass followed by the G training loops.
    All configuration is hard-coded below on top of the parsed CLI options.
    """
    parser = argparse.ArgumentParser()
    opt = options.train_options(parser)
    opt = parser.parse_args()

    opt.cuda = torch.cuda.is_available()
    # Legacy-torchtext device convention: None = default GPU, -1 = CPU.
    opt.device = None if opt.cuda else -1

    # Quick experiment settings.
    opt.exp_dir = './experiment/transformer-reinforce/use_billion'
    opt.load_vocab_from = './experiment/transformer/lang8-cor2err/vocab.pt'
    opt.build_vocab_from = './data/billion/billion.30m.model.vocab'
    opt.load_D_from = opt.exp_dir
    # opt.load_D_from = None

    # dataset params
    opt.max_len = 20

    # G params
    # opt.load_G_a_from = './experiment/transformer/lang8-err2cor/'
    # opt.load_G_b_from = './experiment/transformer/lang8-cor2err/'
    opt.d_word_vec = 300
    opt.d_model = 300
    opt.d_inner_hid = 600
    opt.n_head = 6
    opt.n_layers = 3
    opt.embs_share_weight = False
    opt.beam_size = 1
    opt.max_token_seq_len = opt.max_len + 2  # includes <BOS>, <EOS>
    opt.n_warmup_steps = 4000

    # D params
    opt.embed_dim = opt.d_model
    opt.num_kernel = 100
    opt.kernel_sizes = [3, 4, 5, 6, 7]
    opt.dropout_p = 0.25

    # train params
    opt.batch_size = 1
    opt.n_epoch = 10

    if not os.path.exists(opt.exp_dir):
        os.makedirs(opt.exp_dir)
    logging.basicConfig(filename=opt.exp_dir + '/.log',
                        format=LOG_FORMAT,
                        level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info('Use CUDA? ' + str(opt.cuda))
    logging.info(opt)

    # ---------- prepare dataset ----------
    def len_filter(example):
        # Keep only pairs where both sides fit within max_len.
        return len(example.src) <= opt.max_len and len(example.tgt) <= opt.max_len

    EN = SentencePieceField(init_token=Constants.BOS_WORD,
                            eos_token=Constants.EOS_WORD,
                            batch_first=True,
                            include_lengths=True)
    train = datasets.TranslationDataset(
        path='./data/dualgan/train',
        exts=('.billion.sp', '.use.sp'),
        fields=[('src', EN), ('tgt', EN)],
        filter_pred=len_filter)
    val = datasets.TranslationDataset(
        path='./data/dualgan/val',
        exts=('.billion.sp', '.use.sp'),
        fields=[('src', EN), ('tgt', EN)],
        filter_pred=len_filter)
    train_lang8, val_lang8 = Lang8.splits(
        exts=('.err.sp', '.cor.sp'),
        fields=[('src', EN), ('tgt', EN)],
        train='test', validation='test', test=None,
        filter_pred=len_filter)

    # Load the vocabulary (kept identical across experiments); build and
    # save it on the first run when the saved vocab file does not exist.
    try:
        logging.info('Load voab from %s' % opt.load_vocab_from)
        EN.load_vocab(opt.load_vocab_from)
    except FileNotFoundError:
        EN.build_vocab_from(opt.build_vocab_from)
        EN.save_vocab(opt.load_vocab_from)
    logging.info('Vocab len: %d' % len(EN.vocab))

    # Sanity-check that special-token ids agree with Constants.
    assert EN.vocab.stoi[Constants.BOS_WORD] == Constants.BOS
    assert EN.vocab.stoi[Constants.EOS_WORD] == Constants.EOS
    assert EN.vocab.stoi[Constants.PAD_WORD] == Constants.PAD
    assert EN.vocab.stoi[Constants.UNK_WORD] == Constants.UNK

    # ---------- init model ----------
    # G = build_G(opt, EN, EN)
    hidden_size = 512
    bidirectional = True
    encoder = EncoderRNN(len(EN.vocab), opt.max_len, hidden_size,
                         n_layers=1, bidirectional=bidirectional)
    # BUG FIX: the non-bidirectional branch previously passed a decoder
    # hidden size of 1; it must remain `hidden_size` (cf. the analogous
    # construction in train() below, which uses `else hidden_size`).
    decoder = DecoderRNN(len(EN.vocab), opt.max_len,
                         hidden_size * 2 if bidirectional else hidden_size,
                         n_layers=1, dropout_p=0.2, use_attention=True,
                         bidirectional=bidirectional,
                         eos_id=Constants.EOS, sos_id=Constants.BOS)
    G = Seq2seq(encoder, decoder)
    for param in G.parameters():
        param.data.uniform_(-0.08, 0.08)

    # optim_G = ScheduledOptim(optim.Adam(
    #     G.get_trainable_parameters(),
    #     betas=(0.9, 0.98), eps=1e-09),
    #     opt.d_model, opt.n_warmup_steps)
    optim_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09)
    loss_G = NLLLoss(size_average=False)
    if torch.cuda.is_available():
        loss_G.cuda()

    # Pretrained D: load from disk when available, otherwise build fresh.
    if opt.load_D_from:
        D = load_model(opt.load_D_from)
    else:
        D = build_D(opt, EN)
    optim_D = torch.optim.Adam(D.parameters(), lr=1e-4)

    def get_criterion(vocab_size):
        ''' With PAD token zero weight '''
        weight = torch.ones(vocab_size)
        weight[Constants.PAD] = 0
        return nn.CrossEntropyLoss(weight, size_average=False)

    crit_G = get_criterion(len(EN.vocab))
    crit_D = nn.BCELoss()

    if opt.cuda:
        G.cuda()
        D.cuda()
        crit_G.cuda()
        crit_D.cuda()

    # ---------- train ----------
    trainer_D = trainers.DiscriminatorTrainer()

    if not opt.load_D_from:
        for epoch in range(1):
            logging.info('[Pretrain D Epoch %d]' % epoch)
            pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                                Constants.PAD)
            # Fill the pool with real/fake samples from the training set.
            train_iter = data.BucketIterator(
                dataset=train, batch_size=opt.batch_size, device=opt.device,
                sort_key=lambda x: len(x.src), repeat=False)
            pool.fill(train_iter)
            # train D
            trainer_D.train(D, train_iter=pool.batch_gen(),
                            crit=crit_D, optimizer=optim_D)
            pool.reset()
        Checkpoint(model=D, optimizer=optim_D, epoch=0, step=0,
                   input_vocab=EN.vocab,
                   output_vocab=EN.vocab).save(opt.exp_dir)

    def eval_D():
        # Evaluate D on the validation pool (currently not invoked).
        pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                            Constants.PAD)
        val_iter = data.BucketIterator(
            dataset=val, batch_size=opt.batch_size, device=opt.device,
            sort_key=lambda x: len(x.src), repeat=False)
        pool.fill(val_iter)
        trainer_D.evaluate(D, val_iter=pool.batch_gen(), crit=crit_D)
    # eval_D()

    # Train G
    ALPHA = 0
    for epoch in range(100):
        logging.info('[Epoch %d]' % epoch)
        train_iter = data.BucketIterator(
            dataset=train, batch_size=1, device=opt.device,
            sort_within_batch=True, sort_key=lambda x: len(x.src),
            repeat=False)
        for step, batch in enumerate(train_iter):
            src_seq = batch.src[0]
            src_length = batch.src[1]
            # NOTE(review): indexes row 0 of the batch tensor — assumes
            # batch_size == 1 here; confirm before raising the batch size.
            tgt_seq = src_seq[0].clone()
            # gold = tgt_seq[:, 1:]

            optim_G.zero_grad()
            loss_G.reset()

            decoder_outputs, decoder_hidden, other = G.rollout(
                src_seq, None, None, n_rollout=1)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                # print(step_output)
                # loss_G.eval_batch(step_output.contiguous().view(batch_size, -1), tgt_seq[:, i + 1])

            # Sample a rollout from the (log-softmax) decoder outputs.
            softmax_output = torch.exp(
                torch.cat([x for x in decoder_outputs], dim=0)).unsqueeze(0)
            softmax_output = helper.stack(softmax_output, 8)
            print(softmax_output)
            rollout = softmax_output.multinomial(1)
            print(rollout)

            # Pad the target out to the decoded length before scoring.
            tgt_seq = helper.pad_seq(tgt_seq.data,
                                     max_len=len(decoder_outputs) + 1,
                                     pad_value=Constants.PAD)
            tgt_seq = autograd.Variable(tgt_seq)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                loss_G.eval_batch(
                    step_output.contiguous().view(batch_size, -1),
                    tgt_seq[:, i + 1])
            G.zero_grad()
            loss_G.backward()
            optim_G.step()

            if step % 100 == 0:
                pred = torch.cat([x for x in other['sequence']], dim=1)
                print('[step %d] loss_rest %.4f' %
                      (epoch * len(train_iter) + step, loss_G.get_loss()))
                print('%s -> %s' % (EN.reverse(tgt_seq.data)[0],
                                    EN.reverse(pred.data)[0]))

        # Reinforce Train G — freeze D so only G receives gradients.
        for p in D.parameters():
            p.requires_grad = False
        logging.info('[Epoch %d]' % epoch)
        train_iter = data.BucketIterator(
            dataset=train, batch_size=16, device=opt.device,
            sort_within_batch=True, sort_key=lambda x: len(x.src),
            repeat=False)
        for step, batch in enumerate(train_iter):
            src_seq = batch.src[0]
            src_length = batch.src[1]
            tgt_seq = src_seq.clone()
            # a -> b' -> a
            decoder_outputs, decoder_hiddens, other = G.forward(
                src_seq, src_length.tolist(), target_variable=None)
            # NOTE(review): crit_G is an nn.CrossEntropyLoss, which has no
            # reset/eval_batch/backward/get_loss API — this loop looks like
            # it was meant to use loss_G (the NLLLoss wrapper used above);
            # confirm before enabling this code path.
            crit_G.reset()
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                crit_G.eval_batch(
                    step_output.contiguous().view(batch_size, -1),
                    tgt_seq[:, i + 1])
            optim_G.zero_grad()
            crit_G.backward()
            optim_G.step()
            if step % 100 == 0:
                pred = torch.cat([x for x in other['sequence']], dim=1)
                print('[step %d] loss %.4f' %
                      (epoch * len(train_iter) + step, crit_G.get_loss()))
                print('%s -> %s' % (EN.reverse(tgt_seq.data)[0],
                                    EN.reverse(pred.data)[0]))
def train(opt):
    """Train a Seq2tree model on TSV parse data.

    Loads a checkpoint when ``opt.load_checkpoint`` is set; otherwise builds
    the datasets/vocabularies, converts each example's parse tree to id form,
    constructs the model, and runs SupervisedTrainer.

    Returns:
        (predictor, dev, train): a Predictor over the trained model plus the
        dev and train datasets.
    """
    LOG_FORMAT = '%(asctime)s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT,
                        level=getattr(logging, opt.log_level.upper()))
    logging.info(opt)
    if int(opt.GPU) >= 0:
        torch.cuda.set_device(int(opt.GPU))

    if opt.load_checkpoint is not None:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                         opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2tree = checkpoint.model
        input_vocab = checkpoint.input_vocab
    else:
        # Prepare dataset
        src = SourceField()
        nt = NTField()
        pos = PosField()
        tgt_tree = TreeField()
        comp = CompField()
        max_len = opt.max_len

        def len_filter(example):
            return len(example.src) <= max_len

        train = torchtext.data.TabularDataset(
            path=opt.train_path, format='tsv',
            fields=[('src', src), ('nt', nt), ('pos', pos),
                    ('tree', tgt_tree)],
            filter_pred=len_filter)
        dev = torchtext.data.TabularDataset(
            path=opt.dev_path, format='tsv',
            fields=[('src', src), ('nt', nt), ('pos', pos),
                    ('tree', tgt_tree)],
            filter_pred=len_filter)
        src.build_vocab(train, max_size=50000)
        comp.build_vocab(train, max_size=50000)
        nt.build_vocab(train, max_size=50000)
        pos.build_vocab(train, max_size=50000)
        # src_tree.build_vocab(train, max_size=50000)

        # Collect the nt-vocab ids of symbols that also appear as POS tags
        # (id > 1 skips the specials at indices 0 and 1).
        pos_in_nt = set()
        for Pos in pos.vocab.stoi:
            if nt.vocab.stoi[Pos] > 1:
                pos_in_nt.add(nt.vocab.stoi[Pos])

        hidden_size = opt.hidden_size
        input_vocab = src.vocab
        nt_vocab = nt.vocab

        def tree_to_id(tree):
            # In place, replace non-terminal labels with nt-vocab ids and
            # leaf words with input-vocab ids; append an <eos> child to each
            # internal node.
            tree.set_label(nt_vocab.stoi[tree.label()])
            # BUG FIX: was `str(tree[0])[0] is not '('` — an identity
            # comparison against a string literal, which only works through
            # CPython interning (and is a SyntaxWarning on 3.8+). Equality
            # (!=) is the correct test for "leaf, not a nested tree".
            if len(tree) == 1 and str(tree[0])[0] != '(':
                tree[0] = input_vocab.stoi[tree[0]]
                return
            else:
                for subtree in tree:
                    tree_to_id(subtree)
                tree.append(Tree(nt_vocab.stoi['<eos>'], []))
            return tree

        # train.examples = [str(tree_to_id(ex.tree)) for ex in train.examples]
        # dev.examples = [str(tree_to_id(ex.tree)) for ex in dev.examples]
        for ex in train.examples:
            ex.tree = str(tree_to_id(Tree.fromstring(ex.tree)))
        for ex in dev.examples:
            ex.tree = str(tree_to_id(Tree.fromstring(ex.tree)))
        # train.examples = [tree_to_id(Tree.fromstring(ex.tree)) for ex in train.examples]
        # dev.examples = [str(tree_to_id(Tree.fromstring(ex.tree))) for ex in dev.examples]

        if opt.word_embedding is not None:
            input_vocab.load_vectors([opt.word_embedding])

    loss = NLLLoss()
    if torch.cuda.is_available():
        loss.cuda()
    loss.reset()

    seq2tree = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        bidirectional = opt.bidirectional_encoder
        encoder = EncoderRNN(len(src.vocab), opt.word_embedding_size,
                             max_len, hidden_size,
                             bidirectional=bidirectional,
                             variable_lengths=True)
        decoder = DecoderTree(len(src.vocab), opt.word_embedding_size,
                              opt.nt_embedding_size, len(nt.vocab), max_len,
                              hidden_size * 2 if bidirectional else hidden_size,
                              sos_id=nt_vocab.stoi['<sos>'],
                              eos_id=nt_vocab.stoi['<eos>'],
                              dropout_p=0.2, use_attention=True,
                              bidirectional=bidirectional,
                              pos_in_nt=pos_in_nt)
        seq2tree = Seq2tree(encoder, decoder)
        if torch.cuda.is_available():
            seq2tree.cuda()
        for param in seq2tree.parameters():
            param.data.uniform_(-0.08, 0.08)
        # encoder.embedding.weight.data.set_(input_vocab.vectors)
        # encoder.embedding.weight.data.set_(output_vocab.vectors)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and pass to the trainer.
        #
        # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        # scheduler = StepLR(optimizer.optimizer, 1)
        # optimizer.set_scheduler(scheduler)
        optimizer = Optimizer(optim.Adam(seq2tree.parameters(), lr=opt.lr),
                              max_grad_norm=5)

    # train
    t = SupervisedTrainer(loss=loss, batch_size=opt.batch_size,
                          checkpoint_every=opt.checkpoint_every,
                          print_every=10, expt_dir=opt.expt_dir, lr=opt.lr)
    seq2tree = t.train(seq2tree, train, num_epochs=opt.epoch, dev_data=dev,
                       optimizer=optimizer, teacher_forcing_ratio=0,
                       resume=opt.resume)
    predictor = Predictor(seq2tree, input_vocab, nt_vocab)
    return predictor, dev, train
# input_vocab.load_vectors([]) # input_vocab.load_vectors(['glove.840B.300d']) # input_vocab.vectors[input_vocab.stoi['<unk>']] = torch.Tensor(hidden_size).uniform_(-0.8,0.8)#<unk> # Prepare loss # loss = NLLLoss(weight, pad)#Perplexity(weight, pad) loss = NLLLoss() if torch.cuda.is_available(): loss.cuda() loss.reset() seq2tree = None if not opt.resume: # Initialize model bidirectional = True encoder = EncoderRNN(len(src.vocab), max_len, hidden_size, bidirectional=bidirectional, variable_lengths=True) decoder = DecoderTree(len(src.vocab), len(nt.vocab),max_len, hidden_size * 2 if bidirectional else hidden_size, dropout_p=0.2, use_attention=True, bidirectional=bidirectional, pos_in_nt = pos_in_nt) seq2tree = Seq2tree(encoder, decoder) if torch.cuda.is_available(): seq2tree.cuda()
def eval_fa_equiv(model, data, input_vocab, output_vocab):
    """Evaluate predicted regexes against gold by string and DFA equivalence.

    Iterates the dataset one example at a time, decodes a prediction with
    ``Predictor``, checks exact string match, then shells out to the
    python2 ``regexDFAEquals.py`` helper for DFA-equivalence, and finally
    appends the DFA accuracy to ``./time_logs/log_score_time.txt``.

    Args:
        model: trained seq2seq model (put into eval mode here).
        data: torchtext dataset with ``src``/``tgt`` fields.
        input_vocab, output_vocab: vocabularies used for decoding tensors.
    """
    loss = NLLLoss()
    batch_size = 1
    model.eval()
    loss.reset()
    match = 0
    total = 0
    # Legacy-torchtext device convention: None = default GPU, -1 = CPU.
    device = None if torch.cuda.is_available() else -1
    batch_iterator = torchtext.data.BucketIterator(
        dataset=data, batch_size=batch_size,
        sort=False, sort_key=lambda x: len(x.src),
        device=device, train=False)
    tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
    pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]
    predictor = Predictor(model, input_vocab, output_vocab)

    num_samples = 0
    perfect_samples = 0
    dfa_perfect_samples = 0
    with torch.no_grad():
        for batch in batch_iterator:
            num_samples = num_samples + 1
            input_variables, input_lengths = getattr(batch,
                                                     seq2seq.src_field_name)
            target_variables = getattr(batch, seq2seq.tgt_field_name)
            target_string = decode_tensor(target_variables, output_vocab)
            # target_string = target_string + " <eos>"
            input_string = decode_tensor(input_variables, input_vocab)
            # Drop the trailing token ([:-1]) and any '<pad>' fillers.
            generated_string = ' '.join([
                x for x in
                predictor.predict(input_string.strip().split())[:-1]
                if x != '<pad>'
            ])
            generated_string = refine_outout(generated_string)
            # External DFA-equivalence check; str(bytes)[2] is the first
            # character of the helper's stdout ('1' means equivalent).
            pos_example = subprocess.check_output([
                'python2', 'regexDFAEquals.py',
                '--gold', '{}'.format(target_string),
                '--predicted', '{}'.format(generated_string)
            ])
            if target_string == generated_string:
                perfect_samples = perfect_samples + 1
                dfa_perfect_samples = dfa_perfect_samples + 1
            elif str(pos_example)[2] == '1':
                dfa_perfect_samples = dfa_perfect_samples + 1

            # Token-level accuracy over the generated sequence.
            # BUG FIX: the original incremented `total` twice for every
            # generated token beyond the target length (once
            # unconditionally, once in the `idx >= len(target_tokens)`
            # branch), deflating the accuracy; each token now counts once.
            target_tokens = target_string.split()
            generated_tokens = generated_string.split()
            for idx in range(len(generated_tokens)):
                total = total + 1
                if (idx < len(target_tokens)
                        and target_tokens[idx] == generated_tokens[idx]):
                    match = match + 1

    if total == 0:
        accuracy = float('nan')
    else:
        accuracy = match / total
    # NOTE(review): num_samples == 0 (empty dataset) would raise
    # ZeroDivisionError here, as in the original.
    string_accuracy = perfect_samples / num_samples
    dfa_accuracy = dfa_perfect_samples / num_samples
    with open('./time_logs/log_score_time.txt', 'a') as f:
        f.write('{}\n'.format(dfa_accuracy))