def self_training(args):
    """Perform self-training.

    First load decoding results on the disjoint unlabeled data, then load the
    pre-trained model and perform supervised training on both the existing
    labeled training data and the decoded results.
    """
    print('load pre-trained model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']

    # transfer arguments
    saved_args.cuda = args.cuda
    saved_args.save_to = args.save_to
    saved_args.train_file = args.train_file
    saved_args.unlabeled_file = args.unlabeled_file
    saved_args.dev_file = args.dev_file
    saved_args.load_decode_results = args.load_decode_results
    args = saved_args

    update_args(args)

    model = Parser(saved_args, vocab, transition_system)
    model.load_state_dict(saved_state)

    if args.cuda:
        model = model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('load unlabeled data [%s]' % args.unlabeled_file, file=sys.stderr)
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)

    print('load decoding results of unlabeled data [%s]' % args.load_decode_results, file=sys.stderr)
    decode_results = pickle.load(open(args.load_decode_results, 'rb'))

    labeled_data = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    print('Num. examples in unlabeled data: %d' % len(unlabeled_data), file=sys.stderr)
    assert len(unlabeled_data) == len(decode_results)

    # keep the top-scoring hypothesis of each unlabeled example as a pseudo-labeled training example
    self_train_examples = []
    for example, hyps in zip(unlabeled_data, decode_results):
        if hyps:
            hyp = hyps[0]
            sampled_example = Example(idx='self_train-%s' % example.idx,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=hyp.action_infos,
                                      tgt_ast=hyp.tree)
            self_train_examples.append(sampled_example)
    print('Num. self training examples: %d, Num. labeled examples: %d' % (len(self_train_examples), len(labeled_data)),
          file=sys.stderr)

    train_set = Dataset(examples=labeled_data.examples + self_train_examples)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]
            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples), file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, model, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    exit(0)

                # decay lr, and restore from previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                # load model
                params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
                model.load_state_dict(params['state_dict'])
                if args.cuda:
                    model = model.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset optimizer', file=sys.stderr)
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                else:
                    print('restore parameters of the optimizers', file=sys.stderr)
                    optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0
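
# ---------------------------------------------------------------------------
# Sketch of how the `load_decode_results` file consumed by `self_training`
# above could be produced (an assumption about the surrounding pipeline, not
# the repo's actual decoding script).  `self_training` expects a pickled list
# aligned one-to-one with the unlabeled dataset, where each entry is the
# hypothesis list for that example and each hypothesis exposes `.code`,
# `.action_infos` and `.tree`.  The `parser.parse(src_sent, beam_size=...)`
# call below is an assumed interface that returns such hypotheses.
def decode_unlabeled_data(parser, unlabeled_data, beam_size, out_file):
    decode_results = []
    for example in unlabeled_data:
        # keep the full n-best list; self_training only uses the top hypothesis
        hyps = parser.parse(example.src_sent, beam_size=beam_size)
        decode_results.append(hyps)

    pickle.dump(decode_results, open(out_file, 'wb'))
    return decode_results
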
def train(args):
    grammar = ASDLGrammar.from_text(open(args.asdl_file).read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    vocab = pickle.load(open(args.vocab, 'rb'))

    model = Parser(args, vocab, transition_system)
    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]
            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples), file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, model, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    exit(0)

                # decay lr, and restore from previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                # load model
                params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
                model.load_state_dict(params['state_dict'])
                if args.cuda:
                    model = model.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset optimizer', file=sys.stderr)
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                else:
                    print('restore parameters of the optimizers', file=sys.stderr)
                    optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0
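
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original code): `train` expects an
# `args` namespace carrying the fields read inside the loop above.  The field
# names below are taken directly from those accesses; every value (paths and
# hyper-parameters) is an illustrative assumption, not the repo's defaults.
# Note that the Parser constructor reads additional model hyper-parameters
# (e.g. embedding and hidden sizes) that are not listed here.
def _example_train_invocation():
    from argparse import Namespace

    example_args = Namespace(
        # data / grammar (paths are placeholders)
        asdl_file='path/to/grammar.asdl',
        lang='python',
        train_file='path/to/train.bin',
        dev_file='path/to/dev.bin',
        vocab='path/to/vocab.bin',
        # optimization
        cuda=False,
        lr=1e-3,
        lr_decay=0.5,
        batch_size=10,
        clip_grad=5.,
        max_epoch=50,
        patience=5,
        max_num_trial=5,
        reset_optimizer=False,
        # decoding / logging / saving
        decode_max_time_step=100,
        log_every=50,
        save_to='path/to/model',
    )
    train(example_args)
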
def train_semi_jae(args):
    bi_direction = args.bi_direction

    encoder_params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    decoder_params = torch.load(args.load_decoder, map_location=lambda storage, loc: storage)

    print('loaded encoder at %s' % args.load_model, file=sys.stderr)
    print('loaded decoder at %s' % args.load_decoder, file=sys.stderr)

    transition_system = encoder_params['transition_system']
    encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'], transition_system)
    encoder.load_state_dict(encoder_params['state_dict'])
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'], transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    zprior = LSTMPrior.load(args.load_prior, transition_system=transition_system, cuda=args.cuda)
    print('loaded p(z) prior at %s' % args.load_prior, file=sys.stderr)
    # freeze prior parameters
    for p in zprior.parameters():
        p.requires_grad = False
    zprior.eval()

    xprior = LSTMLanguageModel.load(args.load_src_lm)
    print('loaded p(x) prior at %s' % args.load_src_lm, file=sys.stderr)
    xprior.eval()

    if args.cache:
        jae = JAE_cache(encoder, decoder, zprior, xprior, args)
    else:
        jae = JAE(encoder, decoder, zprior, xprior, args)

    jae.train()
    encoder.train()
    decoder.train()
    if args.cuda:
        jae.cuda()

    labeled_data = Dataset.from_bin_file(args.train_file)
    # labeled_data.examples = labeled_data.examples[:10]
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)  # pretend they are un-labeled!
    dev_set = Dataset.from_bin_file(args.dev_file)
    # dev_set.examples = dev_set.examples[:10]

    optimizer = torch.optim.Adam([p for p in jae.parameters() if p.requires_grad], lr=args.lr)

    print('*** begin semi-supervised training %d labeled examples, %d unlabeled examples ***' %
          (len(labeled_data), len(unlabeled_data)), file=sys.stderr)
    report_encoder_loss = report_decoder_loss = report_examples = 0.
    report_unsup_examples = report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = 0.
    patience = 0
    num_trial = 1
    epoch = train_iter = 0
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()
        unlabeled_examples_iter = unlabeled_data.batch_iter(batch_size=args.unsup_batch_size, shuffle=True)

        for labeled_examples in labeled_data.batch_iter(batch_size=args.batch_size, shuffle=True):
            labeled_examples = [e for e in labeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            report_examples += len(labeled_examples)

            # supervised loss on labeled examples: the parser scores p(z|x), the reconstructor scores p(x|z)
            sup_encoder_loss = -encoder.score(labeled_examples)
            sup_decoder_loss = -decoder.score(labeled_examples)

            report_encoder_loss += sup_encoder_loss.sum().data[0]
            report_decoder_loss += sup_decoder_loss.sum().data[0]

            sup_encoder_loss = torch.mean(sup_encoder_loss)
            sup_decoder_loss = torch.mean(sup_decoder_loss)

            sup_loss = sup_encoder_loss + sup_decoder_loss

            # compute unsupervised loss
            try:
                unlabeled_examples = next(unlabeled_examples_iter)
            except StopIteration:
                # if finished unlabeled data stream, restart it
                unlabeled_examples_iter = unlabeled_data.batch_iter(batch_size=args.batch_size, shuffle=True)
                unlabeled_examples = next(unlabeled_examples_iter)

            unlabeled_examples = [e for e in unlabeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            unsup_encoder_loss, unsup_decoder_loss, meta_data = jae.get_unsupervised_loss(unlabeled_examples, args.moves)
            if bi_direction:
                unsup_encoder_loss_back, unsup_decoder_loss_back, meta_data_back = \
                    jae.get_unsupervised_loss_backward(unlabeled_examples, args.moves)

            # skip the batch if any loss turned NaN
            nan = False
            if nn_utils.isnan(sup_loss.data):
                print('Nan in sup_loss', file=sys.stderr)
                nan = True
            if nn_utils.isnan(unsup_encoder_loss.data):
                print('Nan in unsup_encoder_loss!', file=sys.stderr)
                nan = True
            if nn_utils.isnan(unsup_decoder_loss.data):
                print('Nan in unsup_decoder_loss!', file=sys.stderr)
                nan = True
            if bi_direction:
                if nn_utils.isnan(unsup_encoder_loss_back.data):
                    print('Nan in unsup_encoder_loss_back!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_decoder_loss_back.data):
                    print('Nan in unsup_decoder_loss_back!', file=sys.stderr)
                    nan = True

            if nan:
                continue

            if bi_direction:
                report_unsup_encoder_loss += (unsup_encoder_loss.sum().data[0] +
                                              unsup_encoder_loss_back.sum().data[0])
                report_unsup_decoder_loss += (unsup_decoder_loss.sum().data[0] +
                                              unsup_decoder_loss_back.sum().data[0])
            else:
                report_unsup_encoder_loss += unsup_encoder_loss.sum().data[0]
                report_unsup_decoder_loss += unsup_decoder_loss.sum().data[0]
            report_unsup_examples += unsup_encoder_loss.size(0)

            if bi_direction:
                unsup_loss = (torch.mean(unsup_encoder_loss) + torch.mean(unsup_decoder_loss) +
                              torch.mean(unsup_encoder_loss_back) + torch.mean(unsup_decoder_loss_back))
            else:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(unsup_decoder_loss)

            loss = sup_loss + args.unsup_loss_weight * unsup_loss

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(jae.parameters(), args.clip_grad)
            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] supervised: encoder loss=%.5f, decoder loss=%.5f' %
                      (train_iter,
                       report_encoder_loss / report_examples,
                       report_decoder_loss / report_examples),
                      file=sys.stderr)

                print('[Iter %d] unsupervised: encoder loss=%.5f, decoder loss=%.5f, baseline loss=%.5f' %
                      (train_iter,
                       report_unsup_encoder_loss / report_unsup_examples,
                       report_unsup_decoder_loss / report_unsup_examples,
                       report_unsup_baseline_loss / report_unsup_examples),
                      file=sys.stderr)

                samples = meta_data['samples']
                for v in meta_data.values():
                    if isinstance(v, Variable):
                        v.cpu()

                # print a diagnostic for the first sample of the current unlabeled batch
                for i, sample in enumerate(samples[:1]):
                    print('\t[%s] Source: %s' % (sample.idx, ' '.join(sample.src_sent)), file=sys.stderr)
                    print('\t[%s] Code: \n%s' % (sample.idx, sample.tgt_code), file=sys.stderr)
                    ref_example = [e for e in unlabeled_examples
                                   if e.idx == int(sample.idx[:sample.idx.index('-')])][0]
                    print('\t[%s] Gold Code: \n%s' % (sample.idx, ref_example.tgt_code), file=sys.stderr)
                    print('\t[%s] Log p(z|x): %f' % (sample.idx, meta_data['encoding_scores'][i].data[0]),
                          file=sys.stderr)
                    print('\t[%s] Log p(x|z): %f' % (sample.idx, meta_data['reconstruction_scores'][i].data[0]),
                          file=sys.stderr)
                    print('\t[%s] Encoder Loss: %f' % (sample.idx, unsup_encoder_loss[i].data[0]),
                          file=sys.stderr)
                    print('\t**************************', file=sys.stderr)

                report_encoder_loss = report_decoder_loss = report_examples = 0.
                report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = report_unsup_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, encoder, args, verbose=True)
        encoder.train()
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)

        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            jae.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    exit(0)

                # decay lr, and restore from previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                # load best model's parameters
                jae.load_parameters(args.save_to + '.bin')
                if args.cuda:
                    jae = jae.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset to a new infer_optimizer', file=sys.stderr)
                    optimizer = torch.optim.Adam([p for p in jae.parameters() if p.requires_grad], lr=lr)
                else:
                    print('restore parameters of the optimizers', file=sys.stderr)
                    optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0
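
# ---------------------------------------------------------------------------
# For reference, a minimal NaN guard equivalent to the `nn_utils.isnan` checks
# used in `train_semi_jae` could look like the sketch below (an assumption;
# the repo's actual helper may differ).  It relies on NaN being the only value
# that does not compare equal to itself.
def _isnan_sketch(tensor):
    # True if any element of `tensor` is NaN
    return bool((tensor != tensor).any())
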