Example #1
0
def main(args):
    """Train a DocReader model, or only evaluate it when in test mode.

    The function loads the train/dev examples, obtains a model (resumed from
    ``<model_file>.checkpoint``, loaded from ``args.pretrained``, or
    initialized from scratch), builds the train/dev data loaders and logs the
    configuration.

    If ``args.global_mode == "test"`` it then runs the unofficial, official
    and adversarial validations once and terminates the process. Otherwise it
    runs the train/validate loop for the remaining epochs and saves the model
    either every epoch (no validation metric configured) or whenever the
    validation metric improves.

    Args:
        args: Configuration object read through attribute access.
            ``vars(args)`` must work and be JSON-serializable because the
            whole configuration is dumped to the log.

    Raises:
        SystemExit: With status 0 once the test-mode evaluation has finished.
    """
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    # Test mode always calls validate_official below, so it needs the same
    # data even when args.official_eval is off; otherwise those names would
    # be unbound at the call site.
    if args.official_eval or args.global_mode == "test":
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added_words = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added_words, args.embedding_file)

                logger.info('Expanding char dictionary for new data...')
                # Add characters in training + dev examples
                chars = utils.load_chars(args, train_exs + dev_exs)
                added_chars = model.expand_char_dictionary(chars)
                # Load pretrained embeddings for added characters
                if args.char_embedding_file:
                    model.load_char_embeddings(added_chars,
                                               args.char_embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            # Log a preview: the first five and the last five entries.
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-5:]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer (only on a fresh start; the checkpoint branch
        # above does not call this).
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')

    # Both loaders collate batches with the same function.
    batching_function = vector.batchify

    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=batching_function,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=batching_function,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    # --------------------------------------------------------------------------
    # QUICKLY VALIDATE ON PRETRAINED MODEL

    if args.global_mode == "test":
        result1 = validate_unofficial(args, dev_loader, model, stats, mode='dev')
        result2 = validate_official(args, dev_loader, model, stats,
                                    dev_offsets, dev_texts, dev_answers)
        print(result2[args.valid_metric])
        print(result1["exact_match"])

        validate_adversarial(args, model, stats, mode="dev")
        # Raise SystemExit directly: the ``exit`` helper is injected by the
        # ``site`` module and is not guaranteed to exist in every interpreter
        # configuration.
        raise SystemExit(0)

    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')

        # Validate official (replaces the unofficial dev result)
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid. Without a metric (None or the literal string
        # 'None') there is nothing to compare, so save after every epoch.
        if args.valid_metric is None or args.valid_metric == 'None':
            model.save(args.model_file)
        elif result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
Example #2
0
File: train.py Project: xllg/Ernn
def main(args):
    """Train a reader model from scratch and save it whenever it improves.

    Loads the train/dev examples, initializes a new model, optionally
    restricts embedding tuning to frequent question words, builds both data
    loaders, logs the configuration, and then runs the epoch loop: train,
    validate on train, validate on dev (officially if requested), and save
    the model when ``result[args.valid_metric]`` beats the best value so far.

    Args:
        args: Configuration object read through attribute access; it must
            support ``vars(args)`` because the configuration is logged as
            JSON.
    """
    separator = '-' * 100

    # --------------------------------------------------------
    # Data
    logger.info(separator)
    logger.info('Load data files')
    train_examples = utils.load_data(
        args, args.train_file, skip_no_answer=True
    )
    logger.info('Num train examples = %d' % len(train_examples))
    dev_examples = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_examples))

    # Official evaluation additionally needs the original text (to turn
    # offsets back into spans) and the reference answers of each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {}
        for example in dev_examples:
            dev_offsets[example['id']] = example['offsets']
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------
    # Model
    logger.info(separator)
    first_epoch = 0
    logger.info('Training model from scratch...')
    model = init_from_scratch(args, train_examples, dev_examples)

    # Partial tuning of embeddings, if requested.
    if args.tune_partial > 0:
        logger.info(separator)
        logger.info(
            'Counting %d most frequent question words' % args.tune_partial
        )
        frequent = utils.top_question_words(
            args, train_examples, model.word_dict
        )
        head, tail = frequent[:5], frequent[-6:-1]
        for entry in head:
            logger.info(entry)
        logger.info('...')
        for entry in tail:
            logger.info(entry)
        model.tune_embeddings([entry[0] for entry in frequent])

    # Optimizer
    model.init_optimizer()

    # Move to the GPU / spread over several GPUs when asked to.
    if args.cuda:
        model.cuda()
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # One loader per split; sorting by length makes batching faster.
    logger.info(separator)
    logger.info('Make data loaders')

    # Options shared by both loaders.
    loader_options = {
        'num_workers': args.data_workers,  # subprocesses used to load data
        'collate_fn': vector.batchify,     # merges samples into a mini-batch
        'pin_memory': args.cuda,           # passed to DataLoader as pin_memory
    }

    train_dataset = data.ReaderDataset(
        train_examples, model, single_answer=True
    )
    if not args.sort_by_len:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    else:
        train_sampler = data.SortedBatchSampler(
            train_dataset.lengths(),
            args.batch_size,
            shuffle=True,  # training batches are sampled with shuffling on
        )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,  # samples per batch
        sampler=train_sampler,       # strategy for drawing samples
        **loader_options
    )

    dev_dataset = data.ReaderDataset(dev_examples, model, single_answer=False)
    if not args.sort_by_len:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    else:
        dev_sampler = data.SortedBatchSampler(
            dev_dataset.lengths(), args.test_batch_size, shuffle=False
        )
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        **loader_options
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info(separator)
    config_dump = json.dumps(vars(args), indent=4, sort_keys=True)
    logger.info('CONFIG:\n%s' % config_dump)

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info(separator)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    for epoch in range(first_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # One pass over the training data.
        train(args, train_loader, model, stats)

        # Unofficial validation on the training split.
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Unofficial validation on the dev split.
        result = validate_unofficial(
            args, dev_loader, model, stats, mode='dev'
        )

        # The official validation, when enabled, supersedes the result above.
        if args.official_eval:
            result = validate_official(
                args, dev_loader, model, stats,
                dev_offsets, dev_texts, dev_answers
            )

        # Nothing to save unless the metric strictly improved.
        if not result[args.valid_metric] > stats['best_valid']:
            continue

        logger.info(
            'Best valid: %s = %.2f (epoch %d, %d updates)' %
            (args.valid_metric, result[args.valid_metric],
             stats['epoch'], model.updates)
        )
        model.save(args.model_file)
        stats['best_valid'] = result[args.valid_metric]
Example #3
0
def validate_adversarial(args, model, global_stats, mode="dev"):
    """Evaluate ``model`` on every adversarial dev set and log the scores.

    ``args.adv_dev_json`` and ``args.adv_dev_file`` are read as parallel
    sequences: entry ``i`` of the former is the JSON file that provides the
    texts, answers and question list, entry ``i`` of the latter is the file
    handed to ``utils.load_data``. For each pair two results are logged:

    1. Exact match / F1 averaged over all predictions that could be scored.
    2. "Original" versus "adversarial" exact match / F1. Question ids are
       grouped by the part before the first ``-``. A question whose id equals
       that prefix is an original; any other id is an adversarial variant.
       The adversarial score of a group is the score of its lowest-F1
       adversarial variant, or the original's own score when no adversarial
       variant has a prediction. Both totals are divided by the number of
       groups, so questions without a prediction count as 0.

    Args:
        args: Configuration object; reads ``adv_dev_json``, ``adv_dev_file``,
            ``sort_by_len``, ``test_batch_size``, ``data_workers``, ``cuda``
            and ``use_sentence_selector``.
        model: Model passed to ``data.ReaderDataset`` and used through
            ``model.predict``.
        global_stats: Mapping of which only ``global_stats['epoch']`` is read
            (for the log line).
        mode: Unused; kept so existing callers keep working.

    Returns:
        None. All results are written to the log.
    """
    for idx, dataset_file in enumerate(args.adv_dev_json):

        predictions = {}

        logger.info("Validating Adversarial Dataset %s" % dataset_file)
        exs = utils.load_data(args, args.adv_dev_file[idx])
        logger.info('Num dev examples = %d' % len(exs))

        # Create dataloader
        dev_dataset = data.ReaderDataset(exs, model, single_answer=False)
        if args.sort_by_len:
            dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                                  args.test_batch_size,
                                                  shuffle=False)
        else:
            dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
        batching_function = vector.batchify
        dev_loader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=args.test_batch_size,
            sampler=dev_sampler,
            num_workers=args.data_workers,
            collate_fn=batching_function,
            pin_memory=args.cuda,
        )

        texts = utils.load_text(dataset_file)
        offsets = {ex['id']: ex['offsets'] for ex in exs}
        answers = utils.load_answers(dataset_file)

        eval_time = utils.Timer()
        f1 = utils.AverageMeter()
        exact_match = utils.AverageMeter()

        examples = 0
        bad_examples = 0
        for ex in dev_loader:
            ex_id, batch_size = ex[-1], ex[0].size(0)
            chosen_offset = ex[-2]
            pred_s, pred_e, _ = model.predict(ex)

            for i in range(batch_size):
                qid = ex_id[i]
                token_offsets = offsets[qid]
                # Only the top-ranked start/end prediction is used.
                start, end = pred_s[i][0], pred_e[i][0]

                # A predicted token index outside the example's offsets
                # cannot be mapped back to text; skip it without scoring.
                if start >= len(token_offsets) or end >= len(token_offsets):
                    bad_examples += 1
                    continue
                if args.use_sentence_selector:
                    s_offset = chosen_offset[i][start][0]
                    e_offset = chosen_offset[i][end][1]
                else:
                    s_offset = token_offsets[start][0]
                    e_offset = token_offsets[end][1]
                prediction = texts[qid][s_offset:e_offset]

                predictions[qid] = prediction

                ground_truths = answers[qid]
                exact_match.update(utils.metric_max_over_ground_truths(
                    utils.exact_match_score, prediction, ground_truths))
                f1.update(utils.metric_max_over_ground_truths(
                    utils.f1_score, prediction, ground_truths))

            examples += batch_size

        logger.info('dev valid official for dev file %s : Epoch = %d | EM = %.2f | ' %
                    (dataset_file, global_stats['epoch'], exact_match.avg * 100) +
                    'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                    (f1.avg * 100, examples, eval_time.time()))
        if bad_examples:
            # These examples are in ``examples`` but not in the EM/F1 averages.
            logger.info('Skipped %d examples with out-of-range predictions' %
                        bad_examples)

        orig_f1_score = 0.0
        orig_exact_match_score = 0.0
        adv_f1_scores = {}  # Map from original ID to F1 score
        adv_exact_match_scores = {}  # Map from original ID to exact match score
        adv_ids = {}  # Map from original ID to the ID currently used as adversary
        all_ids = set()  # Set of all original IDs

        # Close the file deterministically instead of leaking the handle.
        with open(dataset_file) as dataset_fp:
            dataset = json.load(dataset_fp)['data']

        for article in dataset:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    orig_id = qa['id'].split('-')[0]
                    all_ids.add(orig_id)
                    if qa['id'] not in predictions:
                        # Unanswered question: contributes a score of 0.
                        continue
                    ground_truths = [answer['text'] for answer in qa['answers']]
                    prediction = predictions[qa['id']]
                    cur_exact_match = utils.metric_max_over_ground_truths(
                        utils.exact_match_score, prediction, ground_truths)
                    cur_f1 = utils.metric_max_over_ground_truths(
                        utils.f1_score, prediction, ground_truths)
                    if orig_id == qa['id']:
                        # This is an original example
                        orig_f1_score += cur_f1
                        orig_exact_match_score += cur_exact_match
                        if orig_id not in adv_f1_scores:
                            # Haven't seen adversarial example yet, so use original for adversary
                            adv_ids[orig_id] = orig_id
                            adv_f1_scores[orig_id] = cur_f1
                            adv_exact_match_scores[orig_id] = cur_exact_match
                    else:
                        # This is an adversarial example
                        if (orig_id not in adv_f1_scores or adv_ids[orig_id] == orig_id
                                or adv_f1_scores[orig_id] > cur_f1):
                            # Always override if the adversary slot currently holds orig_id
                            adv_ids[orig_id] = qa['id']
                            adv_f1_scores[orig_id] = cur_f1
                            adv_exact_match_scores[orig_id] = cur_exact_match

        num_ids = len(all_ids)
        if not num_ids:
            # No questions in the file: the averages below are undefined.
            logger.warning('No questions found in %s; skipping adversarial scores'
                           % dataset_file)
            continue

        orig_f1 = 100.0 * orig_f1_score / num_ids
        orig_exact_match = 100.0 * orig_exact_match_score / num_ids
        adv_exact_match = 100.0 * sum(adv_exact_match_scores.values()) / num_ids
        adv_f1 = 100.0 * sum(adv_f1_scores.values()) / num_ids
        logger.info("For the file %s Original Exact Match : %.4f ; Original F1 : : %.4f | "
                    % (dataset_file, orig_exact_match, orig_f1)
                    + "Adversarial Exact Match : %.4f ; Adversarial F1 : : %.4f " % (adv_exact_match, adv_f1))
Example #4
0
def main(args):
    """Train a DocReader model with per-epoch snapshots and early stopping.

    The function loads the train/dev examples, obtains a model (resumed from
    ``<model_file>.checkpoint``, loaded from ``args.pretrained``, or
    initialized from scratch), builds the train/dev data loaders, logs the
    configuration and runs the train/validate loop.

    Snapshots are written under ``os.path.join(args.model_dir,
    args.model_name)`` with the epoch, exact-match and F1 values in the file
    name:

    * With ``args.valid_metric`` set, an improving epoch replaces the previous
      ``-best`` snapshot and deletes the non-improving snapshots kept since
      then. A non-improving epoch is saved and remembered, and training stops
      once ``args.early_stop`` such snapshots have accumulated.
    * Without a validation metric, a snapshot is saved after every epoch.

    Args:
        args: Configuration object read through attribute access.
            ``vars(args)`` must work and be JSON-serializable because the
            whole configuration is dumped to the log.

    Returns:
        None. Returns early when the early-stopping condition is met.
    """
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)
    else:
        dev_texts = None
        dev_offsets = None
        dev_answers = None

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added_words = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added_words, args.embedding_file)

                logger.info('Expanding char dictionary for new data...')
                # Add characters in training + dev examples
                chars = utils.load_chars(args, train_exs + dev_exs)
                added_chars = model.expand_char_dictionary(chars)
                # Load pretrained embeddings for added characters
                if args.char_embedding_file:
                    model.load_char_embeddings(added_chars,
                                               args.char_embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(args, train_exs,
                                                 model.word_dict)
            # Log a preview: the first five and the last five entries.
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-5:]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer (only on a fresh start; the checkpoint branch
        # above does not call this).
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')

    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    model_prefix = os.path.join(args.model_dir, args.model_name)

    # Snapshots saved since the last improvement; their count drives early
    # stopping and they are deleted once a better model shows up.
    kept_models = []
    best_model_path = ''
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        logger.info('eval: train split unofficially...')
        validate_unofficial(args, train_loader, model, stats, mode='train')

        if args.official_eval:
            # Validate official (dev)
            logger.info('eval: dev split officially...')
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)
        else:
            # Validate unofficial (dev)
            logger.info('eval: dev split unofficially...')
            result = validate_unofficial(args,
                                         dev_loader,
                                         model,
                                         stats,
                                         mode='dev')

        em = result['exact_match']
        f1 = result['f1']
        suffix = 'em_{:4.2f}-f1_{:4.2f}.mdl'.format(em, f1)
        # Save best valid
        model_file = '{}-epoch_{}-{}'.format(model_prefix, epoch, suffix)
        if args.valid_metric:
            if result[args.valid_metric] > stats['best_valid']:
                # Drop the previous best snapshot(s). The prefix is escaped so
                # glob metacharacters in the directory or model name are
                # matched literally.
                stale_pattern = '{}-best*'.format(glob.escape(model_prefix))
                for f in glob.glob(stale_pattern):
                    os.remove(f)
                logger.info('eval: dev best %s = %.2f (epoch %d, %d updates)' %
                            (args.valid_metric, result[args.valid_metric],
                             stats['epoch'], model.updates))
                model_file = '{}-best-epoch_{}-{}'.format(
                    model_prefix, epoch, suffix)
                best_model_path = model_file
                model.save(model_file)
                stats['best_valid'] = result[args.valid_metric]
                # The snapshots kept since the last improvement are obsolete.
                for f in kept_models:
                    os.remove(f)
                kept_models.clear()
            else:
                model.save(model_file)
                kept_models.append(model_file)
                if len(kept_models) >= args.early_stop:
                    logger.info(
                        'Finished training due to %s not improved for %d epochs, best model is at: %s'
                        %
                        (args.valid_metric, args.early_stop, best_model_path))
                    return
        else:
            # just save model every epoch since no validation metric is given
            model.save(model_file)