@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        retrievers = dict()
        for key, kg_args in config_dict['retrievers'].items():
            file_path = kg_args['file_path']
            retrievers[key] = initialize_kg_retriever(key, file_path)
            max_length = kg_args['max_concept_length']
            retrievers[key].update_max_concept_length(max_length)

        config = cls(**config_dict)
        if retrievers:
            config.add_kgretrievers(retrievers)
        config.set_sizes()

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config
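# A minimal usage sketch of from_dict above; the config class name (KELMConfig), the
# 'retrievers' layout, and the extra dropout kwarg are illustrative assumptions.
config_dict = {
    "hidden_size": 768,
    "retrievers": {
        "wordnet": {"file_path": "kgs/wordnet.pkl", "max_concept_length": 8},
    },
}
# With return_unused_kwargs=True, kwargs that do not match existing config attributes
# are handed back to the caller instead of being silently dropped.
config, unused = KELMConfig.from_dict(config_dict, return_unused_kwargs=True, dropout=0.1)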
Example #2
def evaluate(args,
             model,
             processor,
             tokenizer,
             global_step,
             input_dir,
             prefix=""):
    retrievers = dict()
    for kg in args.use_kgs:
        logger.info("Initialize kg:{}".format(kg))
        kg_path = os.path.join(input_dir, args.kg_paths[kg])
        data_path = os.path.join(args.data_dir, args.kg_paths[kg])

        if not os.path.exists(kg_path):
            logger.warning(
                "prepare the training dataset first; exiting")
            exit()

        retrievers[kg] = initialize_kg_retriever(kg, kg_path, data_path,
                                                 args.cache_file_suffix)

    dataset, examples_tokenized, features, wn_synset_graphs, wn_synset_graphs_label_dict = \
        load_and_cache_examples(args,
                                processor,
                                retrievers,
                                relation_list=args.relation_list,
                                input_dir=input_dir,
                                evaluate=True,
                                output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.mkdir(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    eval_sampler = SequentialSampler(
        dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1 and not isinstance(
            model, torch.nn.parallel.DistributedDataParallel):
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Dataset size = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    if args.local_rank == -1:
        logger.warning("program exits and please use pytorch DDP framework")
        exit()
    else:
        # all_results = []
        all_start_logits = torch.tensor([],
                                        dtype=torch.float32,
                                        device=args.device)
        all_end_logits = torch.tensor([],
                                      dtype=torch.float32,
                                      device=args.device)
        all_unique_ids = []
        # start_time = timeit.default_timer()
        epoch_iterator = tqdm(eval_dataloader,
                              desc="Evaluating Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)
            batch_synset_graphs = batch[3]
            with torch.no_grad():
                inputs = create_input(args,
                                      batch,
                                      global_step,
                                      batch_synset_graphs=batch_synset_graphs,
                                      wn_synset_graphs=wn_synset_graphs,
                                      evaluate=True)
                feature_indices = batch[3]

                outputs = model(**inputs)

            all_start_logits = torch.cat((all_start_logits, outputs[0]), dim=0)
            all_end_logits = torch.cat((all_end_logits, outputs[1]), dim=0)

            for i, feature_index in enumerate(feature_indices):
                eval_feature = features[feature_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_unique_ids.append(unique_id)

        all_unique_ids = torch.tensor(all_unique_ids,
                                      dtype=torch.long,
                                      device=args.device)

        start_time = timeit.default_timer()

        all_start_logits_list = [
            torch.zeros_like(all_start_logits, device=args.device)
            for _ in range(torch.distributed.get_world_size())
        ]
        all_end_logits_list = [
            torch.zeros_like(all_end_logits, device=args.device)
            for _ in range(torch.distributed.get_world_size())
        ]
        all_unique_ids_list = [
            torch.zeros_like(all_unique_ids, device=args.device)
            for _ in range(torch.distributed.get_world_size())
        ]

        all_gather(all_start_logits_list, all_start_logits)
        all_gather(all_end_logits_list, all_end_logits)
        all_gather(all_unique_ids_list, all_unique_ids)
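        # Note: torch.distributed.all_gather expects the gathered tensor to have the same
        # shape on every rank; DistributedSampler pads so each rank sees an equal number of
        # samples, which keeps these collectives consistent across processes.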

        if args.local_rank == 0:
            start_time = timeit.default_timer()
            all_results = []

            for batch_idx, batch_unique_ids in enumerate(all_unique_ids_list):
                batch_start_logits = all_start_logits_list[batch_idx]
                batch_end_logits = all_end_logits_list[batch_idx]
                for i, unique_id in enumerate(batch_unique_ids):
                    start_logits = to_list(batch_start_logits[i])
                    end_logits = to_list(batch_end_logits[i])
                    result = RecordResult(int(unique_id.cpu().numpy()),
                                          start_logits, end_logits)

                    all_results.append(result)

            evalTime = timeit.default_timer() - start_time
            logger.info(
                "  Evaluation done in total %f secs (%f sec per example)",
                evalTime, evalTime / len(dataset))

            # Compute predictions
            output_prediction_file = os.path.join(
                args.output_dir, "predictions_{}.json".format(prefix))
            output_result = os.path.join(args.output_dir,
                                         "results_{}.jsonl".format(prefix))

            predictions = RecordProcessor.compute_predictions_logits(
                examples_tokenized,
                features,
                all_results,
                args.n_best_size,
                args.max_answer_length,
                output_prediction_file,
                output_result,
                args.verbose_logging,
                os.path.join(args.data_dir, args.predict_file),
                tokenizer,
                is_testing=args.test,
            )

            # Compute the F1 and exact scores.
            if not args.test:
                results = RecordProcessor.record_evaluate(
                    examples_tokenized,
                    predictions,
                    relate_path=args.output_dir)
                return results
        else:
            return None
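# The result loop above relies on a to_list helper that is not shown in this snippet;
# a minimal sketch of the assumed behaviour:
def to_list(tensor):
    # Detach the tensor, move it to CPU, and convert it to a plain Python list.
    return tensor.detach().cpu().tolist()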
Example #3
def main():
    parser = argparse.ArgumentParser()

    model_g = ArgumentGroup(parser, "model", "model configuration and path.")

    model_g.add_arg("dataset", str, "record", "used dataset")
    model_g.add_arg("is_update_max_concept", bool, True,
                    "weather update max concept for kg retriver")
    model_g.add_arg("full_table", bool, True, "full_table")
    model_g.add_arg("test", bool, False, "weather load superglue test set")
    model_g.add_arg("use_wn", bool, True, "wn")
    model_g.add_arg("use_nell", bool, True, "nell")

    model_g.add_arg("sentinel_trainable", bool, False, "sentinel_trainable")
    model_g.add_arg("memory_bank_update", bool, False, "memory_bank_update")
    model_g.add_arg("memory_bank_update_steps", int, 500,
                    "memory_bank_update_steps")
    model_g.add_arg("memory_bank_keep_coef", float, 0.0, "what percent keep")
    model_g.add_arg("use_context_graph", bool, True, "use_context_graph")

    model_g.add_arg("schedule_strategy", str, "linear", "schedule_strategy")
    model_g.add_arg("tokenizer_path", str, "", "tokenizer_path")
    model_g.add_arg("save_model", bool, True, "whether save model")
    model_g.add_arg("data_preprocess", bool, False, "data process")
    model_g.add_arg("data_preprocess_evaluate", bool, False,
                    "data_preprocess_evaluate")

    # multi-relational part
    model_g.add_arg("relation_agg", str, "sum",
                    "the method to aggeregate multi-relational neoghbor")

    model_g.add_arg("is_lemma", bool, False, "whether trigger lemma")
    model_g.add_arg("is_filter", bool, True, "weather filter node not in wn18")
    model_g.add_arg("is_clean", bool, True,
                    "weather filter node not in repeated_id")
    model_g.add_arg("is_morphy", bool, False, "weather morphy")
    model_g.add_arg("fewer_label", bool, False, "weather fewer_label")
    model_g.add_arg("label_rate", float, 0.1, "label rate")

    model_g.add_arg("relation_list", list, [
        "_hyponym", "_hypernym", "_derivationally_related_form",
        "_member_meronym", "_member_holonym", "_part_of", "_has_part",
        "_member_of_domain_topic", "_synset_domain_topic_of",
        "_instance_hyponym", "_instance_hypernym", "_also_see", "_verb_group",
        "_member_of_domain_region", "_synset_domain_region_of",
        "_member_of_domain_usage", "_synset_domain_usage_of", "_similar_to"
    ], "The used relation.")
    model_g.add_arg("is_all_relation", bool, True, "use all relations")
    model_g.add_arg("selected_relation", str,
                    "_hyponym,_hypernym,_derivationally_related_form",
                    "relations")
    model_g.add_arg("wn18_dir", str, "", "wn18 dir")

    # SSL part
    model_g.add_arg("use_consistent_loss_wn", bool, False,
                    "add consistent loss between entity embedding from WN.")
    model_g.add_arg("warm_up", int, 10000, "warm_up_iterations")
    model_g.add_arg("consistent_loss_wn_coeff", float, 2.0,
                    "Weight decay if we apply some.")
    model_g.add_arg("consistent_loss_type", str, "kld", "consistent loss type")
    model_g.add_arg("mark", str, "test1", "mark")
    model_g.add_arg("tensorboard_dir", str, "./", "tensorboard_dir")
    model_g.add_arg("debug", bool, False, "debug")

    model_g.add_arg(
        "model_name_or_path", str, "",
        "Path to pretrained model or model identifier from huggingface.co/models"
    )
    model_g.add_arg(
        "config_name", str, "",
        "Pretrained config name or path if not the same as model_name")
    model_g.add_arg("model_type", str, "kelm",
                    "The classification model to be used.")
    model_g.add_arg("text_embed_model", str, "bert",
                    "The model for embedding texts in KELM model.")
    model_g.add_arg("output_dir", str, "../outputs/test",
                    "Path to save checkpoints.")
    model_g.add_arg("overwrite_output_dir", bool, True,
                    "Overwrite the content of the output directory.")
    model_g.add_arg(
        "tokenizer_name", str, "",
        "Pretrained tokenizer name or path if not the same as model_name")
    model_g.add_arg("per_gpu_train_batch_size", int, 6,
                    "Batch size per GPU/CPU for training.")
    model_g.add_arg("per_gpu_eval_batch_size", int, 4,
                    "Batch size per GPU/CPU for evaluation.")
    model_g.add_arg(
        "max_steps", int, -1,
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    model_g.add_arg(
        "gradient_accumulation_steps", int, 1,
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    model_g.add_arg("num_train_epochs", float, 10,
                    "Total number of training epochs to perform.")
    model_g.add_arg("weight_decay", float, 0.01,
                    "Weight decay if we apply some.")
    model_g.add_arg("learning_rate", float, 3e-4,
                    "The initial learning rate for Adam.")
    model_g.add_arg("adam_epsilon", float, 1e-8, "Epsilon for Adam optimizer.")
    model_g.add_arg("warmup_steps", int, 10,
                    "Linear warmup over warmup_steps.")
    model_g.add_arg("max_grad_norm", float, 1.0, "Max gradient norm.")
    model_g.add_arg("evaluate_steps", int, 2,
                    "Evaluate every X updates steps.")
    model_g.add_arg("evaluate_epoch", float, 0.0,
                    "evaluate every X update epoch")

    model_g.add_arg("save_steps", int, 1, "Save every X updates steps.")
    model_g.add_arg("evaluate_during_training", bool, True,
                    "Run evaluation during training at each logging step.")
    model_g.add_arg(
        "n_best_size", int, 20,
        "The total number of n-best predictions to generate in the nbest_predictions.json output file."
    )
    model_g.add_arg(
        "verbose_logging", bool, False,
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    model_g.add_arg("init_dir", str, "",
                    "The path of loading pre-trained model.")
    model_g.add_arg("initializer_range", float, 0.02,
                    "The initializer range for KELM")
    model_g.add_arg("cat_mul", bool, True, "The output part of vector in KELM")
    model_g.add_arg("cat_sub", bool, True, "The output part of vector in KELM")
    model_g.add_arg("cat_twotime", bool, True,
                    "The output part of vector in KELM")
    model_g.add_arg("cat_twotime_mul", bool, True,
                    "The output part of vector in KELM")
    model_g.add_arg("cat_twotime_sub", bool, False,
                    "The output part of vector in KELM")

    data_g = ArgumentGroup(
        parser, "data", "Data paths, vocab paths and data processing options")
    data_g.add_arg("train_file", str, "record/train_0831.json",
                   "ReCoRD json for training. E.g., train.json.")
    data_g.add_arg("predict_file", str, "record/dev_0831.json",
                   "ReCoRD json for predictions. E.g. dev.json.")
    data_g.add_arg("cache_file_suffix", str, "test",
                   "The suffix of cached file.")
    data_g.add_arg("cache_dir", str, "", "The cached data path.")
    data_g.add_arg("cache_store_dir", str, "", "The cached data path.")
    data_g.add_arg(
        "data_dir", str, "",
        "The input data dir. Should contain the .json files for the task." +
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets."
    )

    data_g.add_arg("vocab_path", str, "vocab.txt", "Vocabulary path.")
    data_g.add_arg(
        "do_lower_case", bool, False,
        "Whether to lower case the input text. Should be True for uncased models and False for cased models."
    )
    data_g.add_arg("seed", int, 42, "Random seed.")
    data_g.add_arg("kg_paths", dict, {
        "wordnet": "kgs/",
        "nell": "kgs/"
    }, "The paths of knowledge graph files.")
    data_g.add_arg("wn_concept_embedding_path", str,
                   "embedded/wn_concept2vec.txt",
                   "The embeddings of concept in knowledge graph : Wordnet.")
    data_g.add_arg("nell_concept_embedding_path", str,
                   "embedded/nell_concept2vec.txt",
                   "The embeddings of concept in knowledge graph : Nell.")
    data_g.add_arg("use_kgs", list, ['nell', 'wordnet'],
                   "The used knowledge graphs.")
    data_g.add_arg(
        "doc_stride", int, 128,
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    data_g.add_arg("max_seq_length", int, 384,
                   "Number of words of the longest seqence.")
    data_g.add_arg("max_query_length", int, 64, "Max query length.")
    data_g.add_arg("max_answer_length", int, 30, "Max answer length.")
    data_g.add_arg("no_stopwords", bool, True, "Whether to include stopwords.")
    data_g.add_arg("ignore_length", int, 0, "The smallest size of token.")
    data_g.add_arg("print_loss_step", int, 100, "The steps to print loss.")

    run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
    run_type_g.add_arg("use_fp16", bool, False,
                       "Whether to use fp16 mixed precision training.")
    run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
    run_type_g.add_arg("max_n_gpu", int, 100,
                       "The maximum number of GPU to use.")
    run_type_g.add_arg("use_fast_executor", bool, False,
                       "If set, use fast parallel executor (in experiment).")
    run_type_g.add_arg(
        "num_iteration_per_drop_scope", int, 1,
        "Ihe iteration intervals to clean up temporary variables.")
    run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
    run_type_g.add_arg("do_eval", bool, False,
                       "Whether to perform evaluation during training.")
    run_type_g.add_arg("do_predict", bool, False,
                       "Whether to perform prediction.")
    run_type_g.add_arg("freeze", bool, True, "freeze bert parameters")
    run_type_g.add_arg("server_ip", str, "",
                       "Can be used for distant debugging.")
    run_type_g.add_arg(
        "chunksize", int, 1024,
        "The chunksize for multiprocessing to convert examples to features.")
    run_type_g.add_arg("server_port", str, "",
                       "Can be used for distant debugging.")
    run_type_g.add_arg("local_rank", int, -1,
                       "Index for distributed training on gpus.")
    run_type_g.add_arg("threads", int, 50,
                       "multiple threads for converting example to features")
    run_type_g.add_arg("overwrite_cache", bool, False,
                       "Overwrite the cached training and evaluation sets")
    run_type_g.add_arg(
        "eval_all_checkpoints", bool, False,
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    run_type_g.add_arg(
        "min_diff_steps", int, 50,
        "The minimum saving steps before the last maximum steps.")
    args = parser.parse_args()

    logging.getLogger("transformers.modeling_utils").setLevel(
        logging.WARNING)  # Reduce model loading logs

    if not args.is_all_relation:
        args.relation_list = args.selected_relation.split(",")
        logger.info("not use all relation, relation_list: {}".format(
            args.relation_list))

    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or not args.use_cuda:  # single-machine (non-distributed) setup
        device = torch.device(
            "cuda" if torch.cuda.is_available() and args.use_cuda else "cpu")
        args.n_gpu = 0 if not args.use_cuda else min(args.max_n_gpu,
                                                     torch.cuda.device_count())
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
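        # local_rank is expected to be provided by the distributed launcher
        # (e.g. torch.distributed.launch or torchrun), which typically starts one process per GPU.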
        args.n_gpu = 1

    args.device = device

    if args.local_rank in [-1, 0] and not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARNING,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.use_fp16,
    )

    # Set seed
    set_seed(args)

    logger.info("Parameters from arguments are:\n{}".format(args))

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.use_fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.use_fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    processor = RecordProcessor(args)

    input_dir = os.path.join(
        args.cache_store_dir, "cached_{}_{}".format(
            args.model_type,
            str(args.cache_file_suffix),
        ))
    if not os.path.exists(input_dir):
        os.mkdir(input_dir)

    if args.full_table:
        logger.warning("set full_table False and program exits")
        exit()
    else:
        args.wn_def_embed_mat_dir = os.path.join(
            input_dir, args.cache_file_suffix) + "_" + "definition_embedding"

    # if not os.path.exists(args.wn_def_embed_mat_dir):
    #     data_path = os.path.join(args.data_dir, args.kg_paths["wordnet"])
    #     definition_embedding_mat = create_definition_table(args, data_path)
    #
    #     torch.save({"definition_embedding_mat": definition_embedding_mat}, args.wn_def_embed_mat_dir)
    #
    #     logger.info("definition embedding is done. program exits.")
    #     exit()

    ## create data
    retrievers = dict()
    for kg in args.use_kgs:
        logger.info("Initialize kg:{}".format(kg))
        kg_path = os.path.join(input_dir, args.kg_paths[kg])
        data_path = os.path.join(args.data_dir, args.kg_paths[kg])

        retrievers[kg] = initialize_kg_retriever(kg, kg_path, data_path,
                                                 args.cache_file_suffix)

    if args.data_preprocess:
        logger.info("begin preprocess")
        create_dataset(args,
                       processor,
                       retrievers,
                       relation_list=args.relation_list,
                       evaluate=args.data_preprocess_evaluate,
                       input_dir=input_dir)

        logger.info("data preprocess is done")

    # Load pretrained model and tokenizers
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
    tokenizer, model = configure_tokenizer_model(args, logger, retrievers)
    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)
    results = evaluate(args,
                       model,
                       processor,
                       tokenizer,
                       100,
                       input_dir,
                       prefix=args.mark)

    if args.local_rank in [-1, 0]:
        logger.info("results: {}".format(results))

    logger.info("eval is done")
Example #4
def evaluate(args, model, processor, tokenizer, global_step, input_dir, prefix=""):
    retrievers = dict()
    for kg in args.use_kgs:
        logger.info("Initialize kg:{}".format(kg))
        kg_path = os.path.join(input_dir, args.kg_paths[kg])
        data_path = os.path.join(args.data_dir, args.kg_paths[kg])

        if not os.path.exists(kg_path):
            logger.warning("need prepare training dataset firstly, program exit")
            exit()

        retrievers[kg] = initialize_kg_retriever(kg, kg_path, data_path, args.cache_file_suffix)

    dataset, examples_tokenized, features, wn_synset_graphs, wn_synset_graphs_label_dict = \
        load_and_cache_examples(args,
                                processor,
                                retrievers,
                                relation_list=args.relation_list,
                                input_dir=input_dir,
                                evaluate=True,
                                output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.mkdir(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset, shuffle=False)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # synset_graphs_batch = []
    # for batch_index in eval_dataloader.batch_sampler:
    #     synset_graphs_batch.append([i for i in batch_index])

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1 and not isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Dataset size = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)


    if args.local_rank == -1:
        logger.warning("program exits and please use pytorch DDP framework")
        exit()
    else:
        # all_results = []
        # all_start_logits = torch.tensor([], dtype=torch.float32, device=args.device)
        # all_end_logits = torch.tensor([], dtype=torch.float32, device=args.device)
        # all_unique_ids = []
        all_pred = torch.tensor([], dtype=torch.long, device=args.device)
        all_label_ids = torch.tensor([], dtype=torch.long, device=args.device)
        all_question_ids = torch.tensor([], dtype=torch.long, device=args.device)

        # start_time = timeit.default_timer()
        epoch_iterator = tqdm(eval_dataloader, desc="Evaluating Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)
            batch_synset_graphs = batch[3]
            with torch.no_grad():
                inputs = create_input(args, batch, global_step, batch_synset_graphs=batch_synset_graphs,
                                      wn_synset_graphs=wn_synset_graphs, evaluate=True)
                feature_indices = batch[3]

                outputs = model(**inputs)

            logits, label_ids, qas_ids = outputs[1], outputs[2], outputs[3]
            all_pred = torch.cat((all_pred, torch.argmax(logits, dim=-1)), dim=0)
            all_label_ids = torch.cat((all_label_ids, label_ids), dim=0)
            all_question_ids = torch.cat((all_question_ids, qas_ids), dim=0)

        start_time = timeit.default_timer()

        all_pred_list = [torch.zeros_like(all_pred, device=args.device)
                         for _ in range(torch.distributed.get_world_size())]
        all_label_ids_list = [torch.zeros_like(all_label_ids, device=args.device)
                              for _ in range(torch.distributed.get_world_size())]
        all_question_ids_list = [torch.zeros_like(all_question_ids, device=args.device)
                                 for _ in range(torch.distributed.get_world_size())]

        all_gather(all_pred_list, all_pred)
        all_gather(all_label_ids_list, all_label_ids)
        all_gather(all_question_ids_list, all_question_ids)

        logger.info(
            "time for gather communication:{} in rank {}".format(timeit.default_timer() - start_time, args.local_rank))

        if args.local_rank == 0:

            preds = np.asarray([], dtype=int)
            label_values = np.asarray([], dtype=int)
            question_ids = np.asarray([], dtype=int)
            for batch_idx, batch_preds in enumerate(all_pred_list):
                preds = np.concatenate((preds, batch_preds.cpu().detach().numpy()), axis=0)
                label_values = np.concatenate((label_values, all_label_ids_list[batch_idx].cpu().detach().numpy()), axis=0)
                question_ids = np.concatenate((question_ids, all_question_ids_list[batch_idx].cpu().detach().numpy()), axis=0)

            if not args.test:
                df = pd.DataFrame({'label_values': label_values, 'question_ids': question_ids})
                assert "label_values" in df.columns
                assert "question_ids" in df.columns
                df["preds"] = preds
                # noinspection PyUnresolvedReferences
                exact_match = (
                    df.groupby("question_ids")
                        .apply(lambda _: (_["preds"] == _["label_values"]).all())
                        .mean()
                )
                exact_match = float(exact_match)
                f1 = f1_score(y_true=df["label_values"], y_pred=df["preds"])

                results = {'exact_match': exact_match, 'f1': f1}
            else:
                results = None
            if args.write_preds:
                guids = []
                for f in features:
                    guids.append(f.guid[0])
                guids = np.asarray(guids, dtype='<U18')
                assert len(preds) == len(guids)
                write_prediction(preds, guids, "multirc", args.output_dir, prefix)
            return results
        else:
            return None
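# A toy illustration of the question-level exact match computed above (the data values
# are made up): a question counts as correct only if every one of its answer rows is
# predicted correctly, which is what the groupby(...).all() expresses.
import pandas as pd

toy_df = pd.DataFrame({
    "question_ids": [0, 0, 1, 1],
    "label_values": [1, 0, 1, 1],
    "preds":        [1, 0, 1, 0],
})
toy_exact_match = (
    toy_df.groupby("question_ids")
          .apply(lambda g: (g["preds"] == g["label_values"]).all())
          .mean()
)
print(toy_exact_match)  # 0.5 -> question 0 is fully correct, question 1 is not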