Example 1
def evaluate(args, last=True):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    setup_logging(args, local_rank)
    # only evaluate once
    if rank != 0:
        logging.info('Skipping node {}'.format(rank))
        return
    ctx_l = parse_ctx(args.gpus)
    logging.info(
        'Starting inference without Horovod on the first node on device {}'.
        format(str(ctx_l)))

    cfg, tokenizer, qa_net, use_segmentation = get_network(
        args.model_name, ctx_l, args.classifier_dropout, dtype=args.eval_dtype)
    if args.eval_dtype == 'float16':
        qa_net.cast('float16')
        qa_net.hybridize()

    logging.info('Prepare dev data')
    dev_features = get_squad_features(args, tokenizer, segment='dev')
    dev_data_path = os.path.join(args.data_dir,
                                 'dev-v{}.json'.format(args.version))
    dataset_processor = SquadDatasetProcessor(
        tokenizer=tokenizer,
        doc_stride=args.doc_stride,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length)
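    # Chunk every dev feature with the configured doc stride; dev_chunk_feature_ptr
    # keeps cumulative offsets so chunks can be mapped back to their original example.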
    dev_all_chunk_features = []
    dev_chunk_feature_ptr = [0]
    for feature in dev_features:
        chunk_features = dataset_processor.process_sample(feature)
        dev_all_chunk_features.extend(chunk_features)
        dev_chunk_feature_ptr.append(dev_chunk_feature_ptr[-1] +
                                     len(chunk_features))

    def eval_validation(ckpt_name, best_eval):
        """
        Model inference during validation or final evaluation.
        """
        dev_dataloader = mx.gluon.data.DataLoader(
            dev_all_chunk_features,
            batchify_fn=dataset_processor.BatchifyFunction,
            batch_size=args.eval_batch_size,
            num_workers=0,
            shuffle=False)

        log_interval = args.eval_log_interval
        all_results = []
        epoch_tic = time.time()
        tic = time.time()
        epoch_size = len(dev_features)
        total_num = 0
        log_num = 0
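        # Iterate in groups of len(ctx_l) batches so every device gets one batch per
        # step; missing batches in the last group come through as None.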
        for batch_idx, dev_batch in enumerate(
                grouper(dev_dataloader, len(ctx_l))):
            # Predict for each chunk
            for sample, ctx in zip(dev_batch, ctx_l):
                if sample is None:
                    continue
                # Copy the data to device
                tokens = sample.data.as_in_ctx(ctx)
                total_num += len(tokens)
                log_num += len(tokens)
                segment_ids = sample.segment_ids.as_in_ctx(
                    ctx) if use_segmentation else None
                valid_length = sample.valid_length.as_in_ctx(ctx)
                p_mask = sample.masks.as_in_ctx(ctx)
                p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
                start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \
                    = qa_net.inference(tokens, segment_ids, valid_length, p_mask,
                                       args.start_top_n, args.end_top_n)
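                # Keep the top-n start/end logits and indices per question for span
                # extraction later.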
                for i, qas_id in enumerate(sample.qas_id):
                    result = RawResultExtended(
                        qas_id=qas_id,
                        start_top_logits=start_top_logits[i].asnumpy(),
                        start_top_index=start_top_index[i].asnumpy(),
                        end_top_logits=end_top_logits[i].asnumpy(),
                        end_top_index=end_top_index[i].asnumpy(),
                        answerable_logits=answerable_logits[i].asnumpy())

                    all_results.append(result)

            # logging
            if (batch_idx + 1) % log_interval == 0:
                # Output the loss of per step
                toc = time.time()
                logging.info(
                    '[batch {}], Time cost={:.2f},'
                    ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                        batch_idx + 1, toc - tic, log_num / (toc - tic),
                        (epoch_size - total_num) / (total_num /
                                                    (toc - epoch_tic)) / 3600))
                tic = time.time()
                log_num = 0

        epoch_toc = time.time()
        logging.info('Time cost=%.2f s, Throughput=%.2f samples/s',
                     epoch_toc - epoch_tic,
                     total_num / (epoch_toc - epoch_tic))

        all_predictions = collections.OrderedDict()
        all_nbest_json = collections.OrderedDict()
        no_answer_score_json = collections.OrderedDict()
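        # Merge chunk-level results back into one prediction per original dev example.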
        for index, (left_index, right_index) in enumerate(
                zip(dev_chunk_feature_ptr[:-1], dev_chunk_feature_ptr[1:])):
            chunked_features = dev_all_chunk_features[left_index:right_index]
            results = all_results[left_index:right_index]
            original_feature = dev_features[index]
            qas_ids = set([result.qas_id for result in results] +
                          [feature.qas_id for feature in chunked_features])
            assert len(
                qas_ids) == 1, 'Mismatch occurred between features and results'
            example_qas_id = list(qas_ids)[0]
            assert example_qas_id == original_feature.qas_id, \
                'Mismatch occurred between original feature and chunked features'
            not_answerable_score, best_pred, nbest_json = predict_extended(
                original_feature=original_feature,
                chunked_features=chunked_features,
                results=results,
                n_best_size=args.n_best_size,
                max_answer_length=args.max_answer_length,
                start_top_n=args.start_top_n,
                end_top_n=args.end_top_n)
            no_answer_score_json[example_qas_id] = not_answerable_score
            all_predictions[example_qas_id] = best_pred
            all_nbest_json[example_qas_id] = nbest_json

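        # SQuAD 2.0 is scored with best_exact/best_f1 plus no-answer probabilities;
        # v1.1 uses plain exact/f1 without them.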
        if args.version == '2.0':
            exact = 'best_exact'
            f1 = 'best_f1'
            na_prob = no_answer_score_json
        else:
            exact = 'exact'
            f1 = 'f1'
            na_prob = None

        cur_eval, revised_predictions = squad_eval(dev_data_path,
                                                   all_predictions,
                                                   na_prob,
                                                   revise=na_prob is not None)
        logging.info('The evaluated results are {}'.format(
            json.dumps(cur_eval)))

        cur_metrics = 0.5 * (cur_eval[exact] + cur_eval[f1])
        if best_eval:
            best_metrics = 0.5 * (best_eval[exact] + best_eval[f1])
        else:
            best_metrics = 0.

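        # Write prediction files only when this checkpoint improves the averaged
        # exact/F1 score over the best checkpoint seen so far.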
        if cur_metrics > best_metrics:
            logging.info('The evaluated files are saved in {}'.format(
                args.output_dir))
            output_prediction_file = os.path.join(args.output_dir,
                                                  'predictions.json')
            output_nbest_file = os.path.join(args.output_dir,
                                             'nbest_predictions.json')
            na_prob_file = os.path.join(args.output_dir, 'na_prob.json')
            revised_prediction_file = os.path.join(args.output_dir,
                                                   'revised_predictions.json')

            with open(output_prediction_file, 'w') as of:
                of.write(json.dumps(all_predictions, indent=4) + '\n')
            with open(output_nbest_file, 'w') as of:
                of.write(json.dumps(all_nbest_json, indent=4) + '\n')
            with open(na_prob_file, 'w') as of:
                of.write(json.dumps(no_answer_score_json, indent=4) + '\n')
            with open(revised_prediction_file, 'w') as of:
                of.write(json.dumps(revised_predictions, indent=4) + '\n')

            best_eval = cur_eval
            best_eval.update({'best_ckpt': ckpt_name})
        return best_eval

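    # Collect candidate checkpoints: the explicitly given .params file, or every
    # .params file under output_dir sorted by (length, name).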
    if args.param_checkpoint and args.param_checkpoint.endswith('.params'):
        ckpt_candidates = [args.param_checkpoint]
    else:
        ckpt_candidates = [
            f for f in os.listdir(args.output_dir) if f.endswith('.params')
        ]
        ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
        ckpt_candidates = [
            os.path.join(args.output_dir, ele) for ele in ckpt_candidates
        ]
    if last:
        ckpt_candidates = ckpt_candidates[-1:]

    best_eval = {}
    for ckpt_path in ckpt_candidates:
        logging.info('Starting to evaluate checkpoint {}'.format(ckpt_path))
        qa_net.load_parameters(ckpt_path, ctx=ctx_l, cast_dtype=True)
        best_eval = eval_validation(ckpt_path, best_eval)

    logging.info('The best evaluated results are {}'.format(
        json.dumps(best_eval)))
    output_eval_results_file = os.path.join(args.output_dir,
                                            'best_results.json')
    with open(output_eval_results_file, 'w') as of:
        of.write(json.dumps(best_eval, indent=4) + '\n')
    return best_eval
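
A minimal, hypothetical invocation sketch for the snippet above. The field names mirror the args attributes the function reads; the SimpleNamespace container and every value are placeholders (the original script builds args elsewhere, and helpers such as get_squad_features may read additional fields).

import types

# Hypothetical argument container with placeholder values; the real script
# constructs `args` with its own argument parser.
args = types.SimpleNamespace(
    comm_backend='device', gpus='0',           # placeholder communication/device settings
    model_name='google_albert_base_v2',        # placeholder model name
    classifier_dropout=0.1, eval_dtype='float32',
    data_dir='squad', version='2.0',
    doc_stride=128, max_seq_length=512, max_query_length=64,
    eval_batch_size=16, eval_log_interval=10,
    start_top_n=5, end_top_n=5,
    n_best_size=20, max_answer_length=30,
    output_dir='output', param_checkpoint='')

best_eval = evaluate(args, last=True)          # evaluate only the latest checkpoint
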
Example 2
def evaluate(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    # setup_logging(args, local_rank)
    task = get_task(args.task_name)
    if rank != 0:
        logging.info('Skipping node {}'.format(rank))
        return
    ctx_l = parse_ctx(args.gpus)
    logging.info(
        'Starting inference without Horovod on the first node on device {}'.
        format(str(ctx_l)))

    cfg, tokenizer, classify_net, use_segmentation = \
        get_network(args.model_name, ctx_l,
                    args.param_checkpoint,
                    args.backbone_path,
                    task)
    candidate_ckpt = []
    detail_dir = os.path.join(args.output_dir, args.task_name)
    for name in os.listdir(detail_dir):
        if name.endswith(
                '.params'
        ) and args.task_name in name and args.model_name in name:
            candidate_ckpt.append(os.path.join(detail_dir, name))
    best_ckpt = {}
    metrics = task.metric
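    # best_ckpt maps each metric name to [best score, checkpoint path];
    # evaluate_by_ckpt updates it in place.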

    def evaluate_by_ckpt(ckpt_name, best_ckpt):
        classify_net.load_parameters(ckpt_name, ctx=ctx_l, cast_dtype=True)
        logging.info('Prepare dev data')

        dev_data, label = get_task_data(args,
                                        tokenizer,
                                        segment='eval',
                                        task=task)
        dev_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
                                bf.Stack())
        dataloader = DataLoader(dev_data,
                                batch_size=args.batch_size,
                                batchify_fn=dev_batchify,
                                shuffle=False)

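        # Distribute device-sized groups of batches across the available contexts;
        # padded entries in the last group come through as None.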
        for sample_l in grouper(dataloader, len(ctx_l)):
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                (token_ids, token_types, valid_length), label = sample
                token_ids = mx.np.array(token_ids, ctx=ctx)
                token_types = mx.np.array(token_types, ctx=ctx)
                valid_length = mx.np.array(valid_length, ctx=ctx)
                scores = classify_net(token_ids, token_types, valid_length)

                if args.task_name == 'sts':
                    label = label.reshape((-1, 1))
                for metric in metrics:
                    metric.update([label], [scores])
                #pred.append(scores)

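        # Log each metric for this checkpoint and keep it if it beats the best
        # score recorded so far for that metric.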
        for metric in metrics:
            metric_name, result = metric.get()
            logging.info('checkpoint {} result: {}={}'.format(
                ckpt_name, metric_name, result))
            if best_ckpt.get(metric_name, [0, ''])[0] < result:
                best_ckpt[metric_name] = [result, ckpt_name]

    for ckpt_name in candidate_ckpt:
        evaluate_by_ckpt(ckpt_name, best_ckpt)
    for metric_name in best_ckpt:
        logging.info(
            'best result on metric {} is {}, from checkpoint {}'.format(
                metric_name, best_ckpt[metric_name][0],
                best_ckpt[metric_name][1]))
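
The same kind of hypothetical sketch for this variant; again the attribute names come from the snippet and the values are placeholders (checkpoints are discovered under output_dir/task_name and filtered by task and model name, and helpers such as get_task_data may read additional fields).

import types

# Hypothetical argument container with placeholder values.
args = types.SimpleNamespace(
    comm_backend='device', gpus='0',      # placeholder communication/device settings
    task_name='sts',                      # task name, also used to locate checkpoints
    model_name='google_electra_base',     # placeholder model name
    param_checkpoint=None, backbone_path=None,
    output_dir='output', batch_size=32)

evaluate(args)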