def train(args, data_loader, model, global_stats, logger):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    ml_loss = AverageMeter()
    perplexity = AverageMeter()
    epoch_time = Timer()

    current_epoch = global_stats['epoch']
    pbar = tqdm(data_loader)
    pbar.set_description('Epoch = %d [perplexity = x.xx, ml_loss = x.xx]' %
                         current_epoch)

    # Run one epoch
    for idx, ex in enumerate(pbar):
        bsz = ex['batch_size']
        # Linear learning-rate warmup: scale the lr with the global update count.
        if args.optimizer in ['sgd', 'adam'] and current_epoch <= args.warmup_epochs:
            cur_lrate = global_stats['warmup_factor'] * (model.updates + 1)
            for param_group in model.optimizer.param_groups:
                param_group['lr'] = cur_lrate

        net_loss = model.update(ex)
        ml_loss.update(net_loss['ml_loss'], bsz)
        perplexity.update(net_loss['perplexity'], bsz)
        log_info = 'Epoch = %d [perplexity = %.2f, ml_loss = %.2f]' % \
            (current_epoch, perplexity.avg, ml_loss.avg)
        pbar.set_description(log_info)

    kvs = [('perp_tr', perplexity.avg),
           ('ml_lo_tr', ml_loss.avg),
           ('epoch_time', epoch_time.time())]
    for k, v in kvs:
        logger.add(current_epoch, **{k: v})

    logger.print('train: Epoch %d | perplexity = %.2f | ml_loss = %.2f | '
                 'Time for epoch = %.2f (s)' %
                 (current_epoch, perplexity.avg, ml_loss.avg, epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(logger.path + '/best_model.cpt.checkpoint',
                         current_epoch + 1)
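
# train() relies on two small utilities defined elsewhere in the repository.
# As a reference, here is a minimal sketch inferred only from the interface
# used above (update(val, n), .avg, and time()); the repository's own
# implementations may differ.
import time as _time


class AverageMeter(object):
    """Tracks a running sum and average of a scalar value."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


class Timer(object):
    """Measures wall-clock time elapsed since construction."""

    def __init__(self):
        self.start = _time.time()

    def time(self):
        return _time.time() - self.start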
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load and process data files')

    train_exs = []
    if not args.only_test:
        args.dataset_weights = dict()
        for train_src, train_src_tag, train_tgt, dataset_name in \
                zip(args.train_src_files, args.train_src_tag_files,
                    args.train_tgt_files, args.dataset_name):
            train_files = dict()
            train_files['src'] = train_src
            train_files['src_tag'] = train_src_tag
            train_files['tgt'] = train_tgt
            exs = util.load_data(args,
                                 train_files,
                                 max_examples=args.max_examples,
                                 dataset_name=dataset_name)
            lang_name = constants.DATA_LANG_MAP[dataset_name]
            args.dataset_weights[constants.LANG_ID_MAP[lang_name]] = len(exs)
            train_exs.extend(exs)

        logger.info('Num train examples = %d' % len(train_exs))
        args.num_train_examples = len(train_exs)
        for lang_id in args.dataset_weights.keys():
            weight = (1.0 * args.dataset_weights[lang_id]) / len(train_exs)
            args.dataset_weights[lang_id] = round(weight, 2)
        logger.info('Dataset weights = %s' % str(args.dataset_weights))

    dev_exs = []
    for dev_src, dev_src_tag, dev_tgt, dataset_name in \
            zip(args.dev_src_files, args.dev_src_tag_files,
                args.dev_tgt_files, args.dataset_name):
        dev_files = dict()
        dev_files['src'] = dev_src
        dev_files['src_tag'] = dev_src_tag
        dev_files['tgt'] = dev_tgt
        exs = util.load_data(args,
                             dev_files,
                             max_examples=args.max_examples,
                             dataset_name=dataset_name,
                             test_split=True)
        dev_exs.extend(exs)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 1

    if args.only_test:
        if args.pretrained:
            model = Code2NaturalLanguage.load(args.pretrained)
        else:
            if not os.path.isfile(args.model_file):
                raise IOError('No such file: %s' % args.model_file)
            model = Code2NaturalLanguage.load(args.model_file)
    else:
        if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
            # Just resume training, no modifications.
            logger.info('Found a checkpoint...')
            checkpoint_file = args.model_file + '.checkpoint'
            model, start_epoch = Code2NaturalLanguage.load_checkpoint(
                checkpoint_file, args.cuda)
        else:
            # Training starts fresh. But the model state is either pretrained
            # or newly (randomly) initialized.
            if args.pretrained:
                logger.info('Using pretrained model...')
                model = Code2NaturalLanguage.load(args.pretrained, args)
            else:
                logger.info('Training model from scratch...')
                model = init_from_scratch(args, train_exs, dev_exs)

        # Set up optimizer
        model.init_optimizer()
        # Log the parameter details
        logger.info(
            'Trainable #parameters [encoder-decoder] {} [total] {}'.format(
                human_format(model.network.count_encoder_parameters() +
                             model.network.count_decoder_parameters()),
                human_format(model.network.count_parameters())))
        table = model.network.layer_wise_parameters()
        logger.info('Breakdown of the trainable parameters\n%s' % table)

    # Use the GPU?
    if args.cuda:
        model.cuda()

    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
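    # (Length-sorted batching groups sequences of similar length into the
    # same batch, which minimizes padding and wasted computation.)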
    logger.info('-' * 100)
    logger.info('Make data loaders')

    if not args.only_test:
        train_dataset = data.CommentDataset(train_exs, model)
        if args.sort_by_len:
            train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                    args.batch_size,
                                                    shuffle=True)
        else:
            train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=args.data_workers,
            collate_fn=vector.batchify,
            pin_memory=args.cuda,
            drop_last=args.parallel)

    dev_dataset = data.CommentDataset(dev_exs, model)
    dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
        drop_last=args.parallel)

    # --------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # DO TEST
    if args.only_test:
        stats = {
            'timer': Timer(),
            'epoch': 0,
            'best_valid': 0,
            'no_improvement': 0
        }
        validate_official(args, dev_loader, model, stats, mode='test')

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    else:
        logger.info('-' * 100)
        logger.info('Starting training...')
        stats = {
            'timer': Timer(),
            'epoch': start_epoch,
            'best_valid': 0,
            'no_improvement': 0
        }

        if args.optimizer in ['sgd', 'adam'] and args.warmup_epochs >= start_epoch:
            logger.info('Using warmup learning rate for the first %d epoch(s), '
                        'from 0 up to %s.' %
                        (args.warmup_epochs, args.learning_rate))
            num_batches = len(train_loader.dataset) // args.batch_size
            warmup_factor = (args.learning_rate + 0.) / \
                (num_batches * args.warmup_epochs)
            stats['warmup_factor'] = warmup_factor

        for epoch in range(start_epoch, args.num_epochs + 1):
            stats['epoch'] = epoch

            # Decay the learning rate once warmup is over.
            if args.optimizer in ['sgd', 'adam'] and epoch > args.warmup_epochs:
                model.optimizer.param_groups[0]['lr'] = \
                    model.optimizer.param_groups[0]['lr'] * args.lr_decay

            train(args, train_loader, model, stats, logger)
            result = validate_official(args, dev_loader, model, stats)

            # Save best valid
            if result[args.valid_metric] > stats['best_valid']:
                logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                            (args.valid_metric, result[args.valid_metric],
                             stats['epoch'], model.updates))
                model.save(args.model_file)
                stats['best_valid'] = result[args.valid_metric]
                stats['no_improvement'] = 0
            else:
                stats['no_improvement'] += 1
                if stats['no_improvement'] >= args.early_stop:
                    break
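
# Taken together, train() and the loop above implement a two-phase learning
# rate schedule: a linear per-update warmup from 0 to args.learning_rate over
# the first args.warmup_epochs epochs, then a multiplicative per-epoch decay
# by args.lr_decay. A standalone sketch with illustrative (hypothetical)
# hyperparameter values:
def lr_at(update, epoch, lr=1e-4, lr_decay=0.99,
          num_batches=1000, warmup_epochs=2):
    warmup_factor = lr / (num_batches * warmup_epochs)
    if epoch <= warmup_epochs:
        # Linear ramp: reaches `lr` after num_batches * warmup_epochs updates.
        return warmup_factor * (update + 1)
    # After warmup: one multiplicative decay per completed post-warmup epoch.
    return lr * (lr_decay ** (epoch - warmup_epochs))

# e.g. lr_at(update=999, epoch=1)  -> 5e-05   (halfway through warmup)
#      lr_at(update=1999, epoch=2) -> 1e-04   (warmup complete)
#      lr_at(update=2500, epoch=3) -> 9.9e-05 (first decayed epoch)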
def validate_official(args, data_loader, model, global_stats, mode='dev'):
    """Run one full validation pass: greedily decode every batch and score
    the hypotheses against the references (BLEU, ROUGE-L, METEOR, precision,
    recall, F1)."""
    eval_time = Timer()
    # Run through examples
    examples = 0
    sources, hypotheses, references, copy_dict = dict(), dict(), dict(), dict()
    with torch.no_grad():
        pbar = tqdm(data_loader)
        for idx, ex in enumerate(pbar):
            batch_size = ex['batch_size']
            # Offset by the loader's fixed batch size so ids stay unique even
            # when the final batch is smaller.
            ex_ids = list(range(idx * args.test_batch_size,
                                idx * args.test_batch_size + batch_size))
            predictions, targets, copy_info = model.predict(ex,
                                                            replace_unk=True)

            src_sequences = list(ex['code_text'])
            examples += batch_size
            for key, src, pred, tgt in zip(ex_ids, src_sequences,
                                           predictions, targets):
                hypotheses[key] = [pred]
                references[key] = tgt if isinstance(tgt, list) else [tgt]
                sources[key] = src

            if copy_info is not None:
                copy_info = copy_info.cpu().numpy().astype(int).tolist()
                for key, cp in zip(ex_ids, copy_info):
                    copy_dict[key] = cp

            pbar.set_description('Epoch = %d [validating ... ]' %
                                 global_stats['epoch'])

    copy_dict = None if len(copy_dict) == 0 else copy_dict
    bleu, rouge_l, meteor, precision, recall, f1 = eval_accuracies(
        hypotheses,
        references,
        copy_dict,
        sources=sources,
        filename=args.pred_file,
        print_copy_info=args.print_copy_info,
        mode=mode)

    result = dict()
    result['bleu'] = bleu
    result['rouge_l'] = rouge_l
    result['meteor'] = meteor
    result['precision'] = precision
    result['recall'] = recall
    result['f1'] = f1

    if mode == 'test':
        logger.info('test valid official: '
                    'bleu = %.2f | rouge_l = %.2f | meteor = %.2f | ' %
                    (bleu, rouge_l, meteor) +
                    'Precision = %.2f | Recall = %.2f | F1 = %.2f | '
                    'examples = %d | ' %
                    (precision, recall, f1, examples) +
                    'test time = %.2f (s)' % eval_time.time())
    else:
        logger.info('dev valid official: Epoch = %d | ' %
                    (global_stats['epoch']) +
                    'bleu = %.2f | rouge_l = %.2f | '
                    'Precision = %.2f | Recall = %.2f | F1 = %.2f | '
                    'examples = %d | ' %
                    (bleu, rouge_l, precision, recall, f1, examples) +
                    'valid time = %.2f (s)' % eval_time.time())

    return result
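
# The dicts handed to eval_accuracies() above map a global example id to
# lists of strings; copy_dict maps ids to per-token copy indicators (or is
# None when the model has no copy attention). Toy (hypothetical) values:
#
#   hypotheses = {0: ['returns the sum of two numbers']}
#   references = {0: ['return the sum of a and b .']}
#   copy_dict  = {0: [1, 0, 1, 0, 1]}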
def validate_official_beam(args, data_loader, model):
    """Run one full beam-search validation pass: translate every batch with
    the beam-search translator, then score (or just dump) the n-best
    hypotheses. Renamed from validate_official to avoid shadowing the greedy
    variant above."""
    eval_time = Timer()
    translator = build_translator(model, args)
    builder = TranslationBuilder(model.tgt_dict,
                                 n_best=args.n_best,
                                 replace_unk=args.replace_unk)

    # Run through examples
    examples = 0
    trans_dict, sources = dict(), dict()
    with torch.no_grad():
        pbar = tqdm(data_loader)
        batch = args.test_batch_size
        for batch_no, ex in enumerate(pbar):
            batch_size = ex['batch_size']
            ids = list(range(batch_no * batch,
                             (batch_no * batch) + batch_size))
            batch_inputs = prepare_batch(ex, model)

            ret = translator.translate_batch(batch_inputs)
            targets = [[summ] for summ in ex['summ_text']]
            translations = builder.from_batch(ret,
                                              ex['code_tokens'],
                                              targets,
                                              ex['src_vocab'])

            src_sequences = list(ex['code_text'])
            for eid, trans, src in zip(ids, translations, src_sequences):
                trans_dict[eid] = trans
                sources[eid] = src
            examples += batch_size

    hypotheses, references = dict(), dict()
    for eid, trans in trans_dict.items():
        hypotheses[eid] = [' '.join(pred) for pred in trans.pred_sents]
        # Guard against empty hypotheses; scoring breaks on empty strings.
        hypotheses[eid] = [
            constants.PAD_WORD if len(hyp.split()) == 0 else hyp
            for hyp in hypotheses[eid]
        ]
        references[eid] = trans.targets

    if args.only_generate:
        with open(args.pred_file, 'w') as fw:
            json.dump(hypotheses, fw, indent=4)
    else:
        bleu, rouge_l, meteor, precision, recall, f1, ind_bleu, ind_rouge = \
            eval_accuracies(hypotheses, references)
        logger.info('beam evaluation official: '
                    'bleu = %.2f | rouge_l = %.2f | meteor = %.2f | ' %
                    (bleu, rouge_l, meteor) +
                    'Precision = %.2f | Recall = %.2f | F1 = %.2f | '
                    'examples = %d | ' %
                    (precision, recall, f1, examples) +
                    'test time = %.2f (s)' % eval_time.time())

        with open(args.pred_file, 'w') as fw:
            for eid, translation in trans_dict.items():
                out_dict = OrderedDict()
                out_dict['id'] = eid
                out_dict['code'] = sources[eid]
                # All beam-search predictions, best first.
                out_dict['predictions'] = [
                    ' '.join(pred) for pred in translation.pred_sents
                ]
                out_dict['references'] = references[eid]
                out_dict['bleu'] = ind_bleu[eid]
                fw.write(json.dumps(out_dict) + '\n')

        with open(args.predictions, 'w') as fw:
            for eid, translation in trans_dict.items():
                # Strip the surrounding list brackets/quotes to get plain text.
                fw.write(str([' '.join(pred)
                              for pred in translation.pred_sents])[2:-2] + '\n')
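
# The pred_file written above is JSON lines: one object per example with keys
# 'id', 'code', 'predictions', 'references', and 'bleu'. A small sketch for
# loading it back for inspection (the helper name is ours, not the
# repository's):
def load_beam_predictions(path):
    """Yield one prediction record per line of a JSON-lines pred_file."""
    with open(path) as f:
        for line in f:
            yield json.loads(line)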