Example #1
def train(args, data_loader, model, global_stats, logger):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    ml_loss = AverageMeter()
    perplexity = AverageMeter()
    epoch_time = Timer()

    current_epoch = global_stats['epoch']
    pbar = tqdm(data_loader)

    pbar.set_description('Epoch = %d [perplexity = x.xx, ml_loss = x.xx]' %
                         current_epoch)

    # Run one epoch
    for idx, ex in enumerate(pbar):
        bsz = ex['batch_size']
        # Linearly warm up the learning rate during the first warmup epochs
        if args.optimizer in ['sgd', 'adam'] and \
                current_epoch <= args.warmup_epochs:
            cur_lrate = global_stats['warmup_factor'] * (model.updates + 1)
            for param_group in model.optimizer.param_groups:
                param_group['lr'] = cur_lrate

        net_loss = model.update(ex)
        ml_loss.update(net_loss['ml_loss'], bsz)
        perplexity.update(net_loss['perplexity'], bsz)
        log_info = 'Epoch = %d [perplexity = %.2f, ml_loss = %.2f]' % \
                   (current_epoch, perplexity.avg, ml_loss.avg)
        pbar.set_description(log_info)
    kvs = [("perp_tr", perplexity.avg), ("ml_lo_tr", ml_loss.avg),\
               ("epoch_time", epoch_time.time())]
    for k, v in kvs:
        logger.add(current_epoch, **{k: v})
    logger.print(
        'train: Epoch %d | perplexity = %.2f | ml_loss = %.2f | '
        'Time for epoch = %.2f (s)' %
        (current_epoch, perplexity.avg, ml_loss.avg, epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(logger.path + '/best_model.cpt.checkpoint',
                         current_epoch + 1)
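
Example #1 leans on two small bookkeeping helpers, AverageMeter and Timer, that the snippet does not define. A minimal sketch of what they could look like (the project's actual implementations may differ):

import time

class AverageMeter(object):
    """Keeps a running sum, count, and average of a scalar metric."""

    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        # `val` is the per-batch value, `n` the batch size it was averaged over
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class Timer(object):
    """Measures wall-clock time elapsed since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start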
Example #2
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load and process data files')

    train_exs = []
    if not args.only_test:
        args.dataset_weights = dict()
        for train_src, train_src_tag, train_tgt, dataset_name in \
                zip(args.train_src_files, args.train_src_tag_files,
                    args.train_tgt_files, args.dataset_name):
            train_files = {'src': train_src,
                           'src_tag': train_src_tag,
                           'tgt': train_tgt}
            exs = util.load_data(args,
                                 train_files,
                                 max_examples=args.max_examples,
                                 dataset_name=dataset_name)
            lang_name = constants.DATA_LANG_MAP[dataset_name]
            args.dataset_weights[constants.LANG_ID_MAP[lang_name]] = len(exs)
            train_exs.extend(exs)

        logger.info('Num train examples = %d' % len(train_exs))
        args.num_train_examples = len(train_exs)
        for lang_id in args.dataset_weights.keys():
            weight = (1.0 * args.dataset_weights[lang_id]) / len(train_exs)
            args.dataset_weights[lang_id] = round(weight, 2)
        logger.info('Dataset weights = %s' % str(args.dataset_weights))

    dev_exs = []
    for dev_src, dev_src_tag, dev_tgt, dataset_name in \
            zip(args.dev_src_files, args.dev_src_tag_files,
                args.dev_tgt_files, args.dataset_name):
        dev_files = {'src': dev_src,
                     'src_tag': dev_src_tag,
                     'tgt': dev_tgt}
        exs = util.load_data(args,
                             dev_files,
                             max_examples=args.max_examples,
                             dataset_name=dataset_name,
                             test_split=True)
        dev_exs.extend(exs)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 1
    if args.only_test:
        if args.pretrained:
            model = Code2NaturalLanguage.load(args.pretrained)
        else:
            if not os.path.isfile(args.model_file):
                raise IOError('No such file: %s' % args.model_file)
            model = Code2NaturalLanguage.load(args.model_file)
    else:
        if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
            # Just resume training, no modifications.
            logger.info('Found a checkpoint...')
            checkpoint_file = args.model_file + '.checkpoint'
            model, start_epoch = Code2NaturalLanguage.load_checkpoint(
                checkpoint_file, args.cuda)
        else:
            # Training starts fresh. But the model state is either pretrained or
            # newly (randomly) initialized.
            if args.pretrained:
                logger.info('Using pretrained model...')
                model = Code2NaturalLanguage.load(args.pretrained, args)
            else:
                logger.info('Training model from scratch...')
                model = init_from_scratch(args, train_exs, dev_exs)

            # Set up optimizer
            model.init_optimizer()
            # log the parameter details
            logger.info(
                'Trainable #parameters [encoder-decoder] {} [total] {}'.format(
                    human_format(model.network.count_encoder_parameters() +
                                 model.network.count_decoder_parameters()),
                    human_format(model.network.count_parameters())))
            table = model.network.layer_wise_parameters()
            logger.info('Breakdown of the trainable parameters\n%s' % table)

    # Use the GPU?
    if args.cuda:
        model.cuda()

    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. Sorting batches by length makes iteration faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')

    if not args.only_test:
        train_dataset = data.CommentDataset(train_exs, model)
        if args.sort_by_len:
            train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                    args.batch_size,
                                                    shuffle=True)
        else:
            train_sampler = torch.utils.data.sampler.RandomSampler(
                train_dataset)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=args.data_workers,
            collate_fn=vector.batchify,
            pin_memory=args.cuda,
            drop_last=args.parallel)

    dev_dataset = data.CommentDataset(dev_exs, model)
    dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)

    dev_loader = torch.utils.data.DataLoader(dev_dataset,
                                             batch_size=args.test_batch_size,
                                             sampler=dev_sampler,
                                             num_workers=args.data_workers,
                                             collate_fn=vector.batchify,
                                             pin_memory=args.cuda,
                                             drop_last=args.parallel)

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # DO TEST

    if args.only_test:
        stats = {
            'timer': Timer(),
            'epoch': 0,
            'best_valid': 0,
            'no_improvement': 0
        }
        validate_official(args, dev_loader, model, stats, mode='test')

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    else:
        logger.info('-' * 100)
        logger.info('Starting training...')
        stats = {
            'timer': Timer(),
            'epoch': start_epoch,
            'best_valid': 0,
            'no_improvement': 0
        }

        if args.optimizer in ['sgd', 'adam'] and \
                args.warmup_epochs >= start_epoch:
            logger.info('Using warmup learning rate for the first %d epochs, '
                        'from 0 up to %s.' %
                        (args.warmup_epochs, args.learning_rate))
            num_batches = len(train_loader.dataset) // args.batch_size
            warmup_factor = float(args.learning_rate) / (num_batches *
                                                         args.warmup_epochs)
            stats['warmup_factor'] = warmup_factor

        for epoch in range(start_epoch, args.num_epochs + 1):
            stats['epoch'] = epoch
            # After warmup, decay the learning rate once per epoch
            if args.optimizer in ['sgd', 'adam'] and \
                    epoch > args.warmup_epochs:
                model.optimizer.param_groups[0]['lr'] = \
                    model.optimizer.param_groups[0]['lr'] * args.lr_decay

            train(args, train_loader, model, stats)
            result = validate_official(args, dev_loader, model, stats)

            # Save best valid
            if result[args.valid_metric] > stats['best_valid']:
                logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                            (args.valid_metric, result[args.valid_metric],
                             stats['epoch'], model.updates))
                model.save(args.model_file)
                stats['best_valid'] = result[args.valid_metric]
                stats['no_improvement'] = 0
            else:
                stats['no_improvement'] += 1
                if stats['no_improvement'] >= args.early_stop:
                    break
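
The warmup schedule above is linear: each update scales warmup_factor by the running update count, so the learning rate reaches args.learning_rate exactly at the end of the last warmup epoch. A quick numeric check with made-up values (not taken from the project):

# Illustrative numbers only: 10,000 examples, batch size 32, target lr 1e-3,
# 2 warmup epochs -- mirroring warmup_factor as computed in main() above.
num_batches = 10000 // 32                      # updates per epoch
warmup_factor = 1e-3 / (num_batches * 2)       # lr increment per update
final_lr = warmup_factor * (num_batches * 2)   # lr after the last warmup update
assert abs(final_lr - 1e-3) < 1e-12            # the ramp ends at the target lr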
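
The data-loader section also refers to a project-local data.SortedBatchSampler. A sketch of the idea behind it, assuming lengths is a list of per-example lengths: group similar-length examples into the same batch (less padding), then shuffle the batches so training order still varies.

import numpy as np
import torch

class SortedBatchSampler(torch.utils.data.Sampler):
    """Sketch only: yields indices so that each run of `batch_size` indices
    covers examples of similar length; batch order itself is shuffled."""

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        order = np.argsort(self.lengths)
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.shuffle:
            np.random.shuffle(batches)
        return iter(int(i) for batch in batches for i in batch)

    def __len__(self):
        return len(self.lengths)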
Example #3
def validate_official(args, data_loader, model, global_stats, mode='dev'):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.
    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = Timer()
    # Run through examples
    examples = 0
    sources, hypotheses, references, copy_dict = dict(), dict(), dict(), dict()
    with torch.no_grad():
        pbar = tqdm(data_loader)
        for idx, ex in enumerate(pbar):
            batch_size = ex['batch_size']
            ex_ids = list(
                range(idx * batch_size, (idx * batch_size) + batch_size))
            predictions, targets, copy_info = model.predict(ex,
                                                            replace_unk=True)

            src_sequences = list(ex['code_text'])
            examples += batch_size
            for key, src, pred, tgt in zip(ex_ids, src_sequences, predictions,
                                           targets):
                hypotheses[key] = [pred]
                references[key] = tgt if isinstance(tgt, list) else [tgt]
                sources[key] = src

            if copy_info is not None:
                copy_info = copy_info.cpu().numpy().astype(int).tolist()
                for key, cp in zip(ex_ids, copy_info):
                    copy_dict[key] = cp

            pbar.set_description("%s" % 'Epoch = %d [validating ... ]' %
                                 global_stats['epoch'])

    copy_dict = None if len(copy_dict) == 0 else copy_dict
    bleu, rouge_l, meteor, precision, recall, f1 = eval_accuracies(
        hypotheses,
        references,
        copy_dict,
        sources=sources,
        filename=args.pred_file,
        print_copy_info=args.print_copy_info,
        mode=mode)
    result = {'bleu': bleu, 'rouge_l': rouge_l, 'meteor': meteor,
              'precision': precision, 'recall': recall, 'f1': f1}

    if mode == 'test':
        logger.info('test valid official: '
                    'bleu = %.2f | rouge_l = %.2f | meteor = %.2f | ' %
                    (bleu, rouge_l, meteor) +
                    'Precision = %.2f | Recall = %.2f | F1 = %.2f | '
                    'examples = %d | ' % (precision, recall, f1, examples) +
                    'test time = %.2f (s)' % eval_time.time())

    else:
        logger.info(
            'dev valid official: Epoch = %d | ' % (global_stats['epoch']) +
            'bleu = %.2f | rouge_l = %.2f | '
            'Precision = %.2f | Recall = %.2f | F1 = %.2f | examples = %d | ' %
            (bleu, rouge_l, precision, recall, f1, examples) +
            'valid time = %.2f (s)' % eval_time.time())

    return result
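
For reference, eval_accuracies (defined elsewhere in the project) consumes hypotheses and references as dicts keyed by a running example id, each value a list of strings. Toy values, only to show the shapes:

# Shapes only; real keys are dataset indices, real values are model outputs
# and ground-truth summaries.
hypotheses = {0: ['returns the sum of two numbers']}
references = {0: ['return the sum of a and b',
                  'adds two numbers and returns the result']}
sources = {0: 'def add(a, b): return a + b'}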
Example #4
def validate_official(args, data_loader, model):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.
    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """

    eval_time = Timer()
    translator = build_translator(model, args)
    builder = TranslationBuilder(model.tgt_dict,
                                 n_best=args.n_best,
                                 replace_unk=args.replace_unk)

    # Run through examples
    examples = 0
    trans_dict, sources = dict(), dict()
    with torch.no_grad():
        pbar = tqdm(data_loader)

        test_bsz = args.test_batch_size
        for batch_no, ex in enumerate(pbar):
            batch_size = ex['batch_size']
            ids = list(range(batch_no * test_bsz,
                             (batch_no * test_bsz) + batch_size))
            batch_inputs = prepare_batch(ex, model)

            ret = translator.translate_batch(batch_inputs)
            targets = [[summ] for summ in ex['summ_text']]
            translations = builder.from_batch(ret, ex['code_tokens'], targets,
                                              ex['src_vocab'])

            src_sequences = list(ex['code_text'])

            for eid, trans, src in zip(ids, translations, src_sequences):
                trans_dict[eid] = trans
                sources[eid] = src

            examples += batch_size

    hypotheses, references = dict(), dict()
    for eid, trans in trans_dict.items():
        hypotheses[eid] = [' '.join(pred) for pred in trans.pred_sents]
        hypotheses[eid] = [
            constants.PAD_WORD if len(hyp.split()) == 0 else hyp
            for hyp in hypotheses[eid]
        ]
        references[eid] = trans.targets

    if args.only_generate:
        with open(args.pred_file, 'w') as fw:
            json.dump(hypotheses, fw, indent=4)

    else:
        bleu, rouge_l, meteor, precision, recall, f1, ind_bleu, ind_rouge = \
            eval_accuracies(hypotheses, references)
        logger.info('beam evaluation official: '
                    'bleu = %.2f | rouge_l = %.2f | meteor = %.2f | ' %
                    (bleu, rouge_l, meteor) +
                    'Precision = %.2f | Recall = %.2f | F1 = %.2f | '
                    'examples = %d | ' % (precision, recall, f1, examples) +
                    'test time = %.2f (s)' % eval_time.time())

        with open(args.pred_file, 'w') as fw:
            for eid, translation in trans_dict.items():
                out_dict = OrderedDict()
                out_dict['id'] = eid
                out_dict['code'] = sources[eid]
                # record all beam-search predictions for this example
                out_dict['predictions'] = [
                    ' '.join(pred) for pred in translation.pred_sents
                ]
                out_dict['references'] = references[eid]
                out_dict['bleu'] = ind_bleu[eid]
                # out_dict['rouge_l'] = ind_rouge[eid]  # optional per-example ROUGE-L
                fw.write(json.dumps(out_dict) + '\n')

    with open(args.predictions, 'w') as fw:
        for eid, translation in trans_dict.items():
            # str(...)[2:-2] strips the list's surrounding brackets and quotes
            fw.write(
                str([' '.join(pred)
                     for pred in translation.pred_sents])[2:-2] + '\n')
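
In the non-only_generate branch, args.pred_file ends up in JSON Lines form, one object per example. A small sketch for reading it back (the filename here is hypothetical):

import json

# Assumes the layout written above: one JSON object per line with
# 'id', 'code', 'predictions', 'references', and 'bleu' fields.
with open('predictions.jsonl') as f:  # hypothetical path for args.pred_file
    for line in f:
        record = json.loads(line)
        print(record['id'], record['bleu'], record['predictions'][0])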