def get_train_features(data_dir, bert_model, max_seq_length, do_lower_case,
                       local_rank, train_batch_size,
                       gradient_accumulation_steps, num_train_epochs,
                       tokenizer, processor):
    """Return training features, reusing a pickle cache in ``data_dir``.

    The cache file lives in ``data_dir`` and is named
    ``<last path component of bert_model>_<max_seq_length>_<do_lower_case>``.
    If it cannot be read, the features are rebuilt from
    ``processor.get_train_examples(data_dir)`` and, on the main process only,
    written back to the cache file.

    Args:
        data_dir: Directory holding both the raw data and the cache file.
        bert_model: Model name or path; only its last non-empty ``/``
            separated component is used, as part of the cache file name.
        max_seq_length: Passed to ``convert_examples_to_features`` and used
            in the cache file name.
        do_lower_case: Used in the cache file name only.
        local_rank, train_batch_size, gradient_accumulation_steps,
        num_train_epochs: Accepted for interface compatibility; not used by
            this function.
        tokenizer: Passed to ``convert_examples_to_features``.
        processor: Provides ``get_train_examples`` and ``get_labels``.

    Returns:
        The features object: either the unpickled cache content or the first
        element of the tuple returned by ``convert_examples_to_features``.
    """
    cached_train_features_file = os.path.join(
        data_dir,
        '{0}_{1}_{2}'.format(
            list(filter(None, bert_model.split('/'))).pop(),
            str(max_seq_length),
            str(do_lower_case),
        ),
    )
    try:
        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file; only point data_dir at locations you trust.
        with open(cached_train_features_file, "rb") as reader:
            train_features = pickle.load(reader)
    except Exception:
        # Best effort: a missing, unreadable or corrupt cache simply means the
        # features are rebuilt. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt / SystemExit propagate instead of starting a
        # potentially long preprocessing run.
        logger.info("Did not find pre-processed features from {}".format(
            cached_train_features_file))
        train_examples = processor.get_train_examples(data_dir)
        train_features, _ = convert_examples_to_features(
            train_examples,
            processor.get_labels(),
            max_seq_length,
            tokenizer,
        )
        # Only one process writes the cache so that concurrent workers do not
        # clobber the same file.
        if is_main_process():
            logger.info("  Saving train features into cached file %s",
                        cached_train_features_file)
            with open(cached_train_features_file, "wb") as writer:
                pickle.dump(train_features, writer)
    else:
        logger.info("Loaded pre-processed features from {}".format(
            cached_train_features_file))
    return train_features
# Example #2
0
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Build a ``TensorDataset`` for ``task``, caching features on disk.

    Features are read with ``torch.load`` from a cache file in
    ``args.data_dir`` when that file exists; otherwise they are created from
    the processor's ``get_dev_examples2`` / ``get_train_examples2`` output
    and stored with ``torch.save``.

    Args:
        args: Namespace providing ``data_dir`` and ``max_seq_length``.
        task: Key into ``processors`` and ``output_modes``.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``; its
            ``pad_token`` supplies the padding id.
        evaluate: Use the dev split when true, the train split otherwise.

    Returns:
        ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``.
    """
    task_processor = processors[task]()
    output_mode = output_modes[task]

    # The cache file name encodes split, sequence length and task.
    split_name = 'dev' if evaluate else 'train'
    cache_name = 'cached_{}_bert_{}_{}_qqpmerge'.format(
        split_name, str(args.max_seq_length), str(task))
    cached_features_file = os.path.join(args.data_dir, cache_name)

    if not os.path.exists(cached_features_file):
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = task_processor.get_labels()
        if evaluate:
            examples = task_processor.get_dev_examples2(args.data_dir)
        else:
            examples = task_processor.get_train_examples2(args.data_dir)
        pad_token_id = tokenizer.convert_tokens_to_ids(
            [tokenizer.pad_token])[0]
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=False,  # left padding would only be needed for xlnet
            pad_token=pad_token_id,
            pad_token_segment_id=0,
        )
        logger.info("Saving features into cached file %s",
                    cached_features_file)
        torch.save(features, cached_features_file)
    else:
        logger.info("Loading features from cached file %s",
                    cached_features_file)
        features = torch.load(cached_features_file)

    def _column(attr, dtype):
        # Stack one attribute of every feature into a single tensor.
        return torch.tensor([getattr(feat, attr) for feat in features],
                            dtype=dtype)

    all_input_ids = _column('input_ids', torch.long)
    all_attention_mask = _column('attention_mask', torch.long)
    all_token_type_ids = _column('token_type_ids', torch.long)
    # Label dtype follows the task type; any other output_mode leaves
    # all_labels unassigned, exactly as before.
    if output_mode == "classification":
        all_labels = _column('label', torch.long)
    elif output_mode == "regression":
        all_labels = _column('label', torch.float)

    return TensorDataset(all_input_ids, all_attention_mask,
                         all_token_type_ids, all_labels)
# Example #3
0
def load_and_cache_examples(args, task, tokenizer, data_type='train'):
    """Convert one data split of ``task`` into a ``TensorDataset``.

    Despite the name, this variant never reads or writes a cache, and the
    distributed barriers of the upstream version are not executed here.

    Args:
        args: Namespace providing ``data_dir`` and ``max_seq_length``.
        task: Key into ``processors`` and ``output_modes``.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``.
        data_type: ``'train'`` or ``'dev'``; any other value selects the
            test split.

    Returns:
        ``TensorDataset(input_ids, attention_mask, token_type_ids, lens,
        labels)``, every tensor of dtype ``torch.long``.
    """
    processor = processors[task]
    output_mode = output_modes[task]
    label_list = processor.get_labels()

    # Pick the loader for the requested split; unknown values mean "test".
    if data_type == 'train':
        read_split = processor.get_train_examples
    elif data_type == 'dev':
        read_split = processor.get_dev_examples
    else:
        read_split = processor.get_test_examples
    examples = read_split(args.data_dir)

    features = convert_examples_to_features(
        examples,
        tokenizer,
        label_list=label_list,
        max_seq_length=args.max_seq_length,
        output_mode=output_mode,
    )

    # Column order here fixes the tensor order inside the dataset.
    field_names = ('input_ids', 'attention_mask', 'token_type_ids',
                   'input_len', 'label')
    columns = [
        torch.tensor([getattr(feat, name) for feat in features],
                     dtype=torch.long)
        for name in field_names
    ]
    return TensorDataset(*columns)
# Example #4
0
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Build a ``TensorDataset`` with char-level and token-level features.

    Features are always rebuilt from the dataset files: reading the cache is
    switched off in this variant, although processes with ``local_rank`` -1
    or 0 still write the cache file.

    Args:
        args: Namespace providing ``local_rank``, ``data_dir``,
            ``model_name_or_path``, ``max_seq_length``, ``model_type`` and
            ``char_vocab``.
        task: Key into ``glue_processors`` and ``output_modes``.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``; its
            ``pad_token`` supplies the padding id.
        evaluate: Use the dev split when true, the train split otherwise.

    Returns:
        ``TensorDataset(char_ids, start_ids, end_ids, input_ids,
        attention_mask, token_type_ids, labels)``.
    """
    # In distributed training only the first process should prepare the
    # dataset; the other ranks wait here.
    if not evaluate and args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    glue_processor = glue_processors[task]()
    output_mode = output_modes[task]

    split_name = 'dev' if evaluate else 'train'
    model_tag = list(filter(None, args.model_name_or_path.split('/'))).pop()
    cached_features_file = os.path.join(
        args.data_dir,
        'cached_{}_{}_{}_{}'.format(split_name, model_tag,
                                    str(args.max_seq_length), str(task)))

    # Cache reads are disabled. The original condition was:
    #   os.path.exists(cached_features_file) and not args.overwrite_cache
    read_from_cache = False
    if read_from_cache:
        logger.info("Loading features from cached file %s",
                    cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = glue_processor.get_labels()
        if task in ['mnli', 'mnli-mm'] and args.model_type in [
                'roberta', 'xlmroberta'
        ]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        if evaluate:
            examples = glue_processor.get_dev_examples(args.data_dir)
        else:
            examples = glue_processor.get_train_examples(args.data_dir)
        print("Begin to convert_examples_to_features...")
        uses_xlnet = args.model_type in ['xlnet']
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=bool(uses_xlnet),  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids(
                [tokenizer.pad_token])[0],
            pad_token_segment_id=4 if uses_xlnet else 0,
            char_vocab_file=args.char_vocab,
            model_type=args.model_type)
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            torch.save(features, cached_features_file)

    # Rank 0 reaches the barrier once its work is done, releasing the ranks
    # that have been waiting since the top of the function.
    if not evaluate and args.local_rank == 0:
        torch.distributed.barrier()

    def _column(attr, dtype=torch.long):
        # Stack one attribute of every feature into a single tensor.
        return torch.tensor([getattr(feat, attr) for feat in features],
                            dtype=dtype)

    all_char_ids = _column('char_input_ids')
    all_start_ids = _column('start_ids')
    all_end_ids = _column('end_ids')
    all_input_ids = _column('input_ids')
    all_attention_mask = _column('attention_mask')
    all_token_type_ids = _column('token_type_ids')
    # Label dtype follows the task type; any other output_mode leaves
    # all_labels unassigned, exactly as before.
    if output_mode == "classification":
        all_labels = _column('label', torch.long)
    elif output_mode == "regression":
        all_labels = _column('label', torch.float)

    return TensorDataset(all_char_ids, all_start_ids, all_end_ids,
                         all_input_ids, all_attention_mask,
                         all_token_type_ids, all_labels)
# Example #5
0
def for_server(
    text: str,
    task: str,
):
    """Classify a single piece of text with a fine-tuned ALBERT checkpoint.

    Reads its configuration from the module-level ``args`` namespace and
    writes ``args.n_gpu`` and ``args.device`` as a side effect. With
    ``args.eval_all_checkpoints`` set, the ``checkpoint-<step>`` directory
    with the highest step under ``args.output_dir`` is used; otherwise
    ``args.output_dir`` itself is used.

    Args:
        text: Raw text to classify.
        task: Key into ``processors``, ``output_modes`` and
            ``tasks_num_labels``.

    Returns:
        ``tasks_num_labels[task]`` indexed by the arg-max prediction, or
        ``None`` when no prediction was produced (``args.local_rank`` is
        neither -1 nor 0, or no features were generated).
    """
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
    args.device = device

    # Rows handed to `_create_examples`: an empty first row followed by one
    # data row of the form ['0', '0', text].
    # NOTE(review): presumably the first row stands in for a file header and
    # the two '0' fields for id/label columns -- confirm against the processor.
    rows = ['', ['0', '0', text]]

    processor = processors[task]
    output_mode = output_modes[task]
    label_list = processor.get_labels()
    examples = processor._create_examples(rows, 'predict')

    # Defined up front so that the final `return` cannot hit an unbound name
    # when the rank guard or the prediction loop is skipped.
    label = None

    if args.local_rank in [-1, 0]:
        tokenizer = tokenization_albert.FullTokenizer(
            vocab_file=args.vocab_file,
            do_lower_case=args.do_lower_case,
        )
        checkpoints = [(0, args.output_dir)]
        if args.eval_all_checkpoints:
            checkpoint_dirs = [
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True))
            ]
            # Keep only "...checkpoint-<step>" directories, ordered by step.
            checkpoints = sorted(
                ((int(path.split('-')[-1]), path)
                 for path in checkpoint_dirs
                 if path.find('checkpoint') != -1),
                key=lambda item: item[0],
            )
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        # Serve from the newest checkpoint only; fall back to output_dir when
        # none was found.
        if not checkpoints:
            checkpoints = [(0, args.output_dir)]
        else:
            checkpoints = [checkpoints[-1]]

        for _, checkpoint in checkpoints:
            if not os.listdir(checkpoint):
                # NOTE(review): an empty checkpoint directory triggers a call
                # to `main()` with no arguments -- confirm that the `main` in
                # scope for this function accepts that.
                main()
            model = AlbertForSequenceClassification.from_pretrained(checkpoint)
            model.to(args.device)
            model.eval()  # inference mode; set once rather than per sample
            features = convert_examples_to_features(
                examples,
                tokenizer,
                label_list=label_list,
                max_seq_length=args.max_seq_length,
                output_mode=output_mode)
            all_input_ids = torch.tensor([f.input_ids for f in features],
                                         dtype=torch.long)
            all_attention_mask = torch.tensor(
                [f.attention_mask for f in features], dtype=torch.long)
            all_token_type_ids = torch.tensor(
                [f.token_type_ids for f in features], dtype=torch.long)
            all_lens = torch.tensor([f.input_len for f in features],
                                    dtype=torch.long)
            all_labels = torch.tensor([f.label for f in features],
                                      dtype=torch.long)
            dataset = TensorDataset(all_input_ids, all_attention_mask,
                                    all_token_type_ids, all_lens, all_labels)

            # Iterating the dataset directly yields one un-batched sample at
            # a time, hence the unsqueeze(0) to add a batch dimension.
            for sample in dataset:
                sample = tuple(t.to(args.device) for t in sample)
                with torch.no_grad():
                    inputs = {
                        'input_ids': sample[0].unsqueeze(0),
                        'attention_mask': sample[1].unsqueeze(0),
                        'token_type_ids': sample[2].unsqueeze(0),
                    }
                    outputs = model(**inputs)
                    logits = outputs[0]
                # NumPy cannot read a CUDA tensor; copy the logits to host
                # memory before taking the arg-max.
                preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
                label = tasks_num_labels[task][preds]
                logger.info('label is {}'.format(label))

    return label
def main(args):
    """Fine-tune and/or evaluate a BERT sequence classifier as set by `args`.

    Depending on `args.do_train`, `args.do_eval` and `args.do_predict` this
    trains the model, saves weights and config to `args.output_dir`,
    evaluates on the dev examples, dumps predictions, and writes all collected
    metrics to `results.txt` and to dllogger.

    Args:
        args: Namespace of run options. Note that `args.fp16` and
            `args.train_batch_size` are overwritten in place.

    Returns:
        dict mapping metric names (e.g. 'train:loss', 'eval:loss',
        'infer:latency(ms):avg') to values. On processes other than the main
        one only the training entries are filled in.

    Raises:
        ValueError: if none of do_train/do_eval/do_predict is set, or if
            `gradient_accumulation_steps` is < 1 or larger than
            `train_batch_size`.
        ImportError: if distributed training is requested and apex is not
            installed.
    """
    # --amp implies fp16.
    args.fp16 = args.fp16 or args.amp
    if args.server_ip and args.server_port:
        # Distant debugging - see
        # https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        logger.info("Waiting for debugger attach")
        ptvsd.enable_attach(
            address=(args.server_ip, args.server_port),
            redirect_output=True,
        )
        ptvsd.wait_for_attach()

    # Device selection: non-distributed runs (local_rank == -1) or no_cuda use
    # every visible GPU (or the CPU); distributed runs pin one GPU per process.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, "
                "16-bits training: {}".format(
                    device,
                    n_gpu,
                    bool(args.local_rank != -1),
                    args.fp16,
                ))

    if not args.do_train and not args.do_eval and not args.do_predict:
        raise ValueError("At least one of `do_train`, `do_eval` or "
                         "`do_predict` must be True.")

    # A non-empty output directory only produces a warning; the run continues.
    if is_main_process():
        if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
                and args.do_train):
            logger.warning("Output directory ({}) already exists and is not "
                           "empty.".format(args.output_dir))
    mkdir_by_main_process(args.output_dir)

    # Only the main process gets real dllogger backends; the other processes
    # initialise it with none.
    if is_main_process():
        dllogger.init(backends=[
            dllogger.JSONStreamBackend(
                verbosity=dllogger.Verbosity.VERBOSE,
                filename=os.path.join(args.output_dir, 'dllogger.json'),
            ),
            dllogger.StdOutBackend(
                verbosity=dllogger.Verbosity.VERBOSE,
                step_format=format_step,
            ),
        ])
    else:
        dllogger.init(backends=[])

    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             args.gradient_accumulation_steps))
    if args.gradient_accumulation_steps > args.train_batch_size:
        raise ValueError("gradient_accumulation_steps ({}) cannot be larger "
                         "train_batch_size ({}) - there cannot be a fraction "
                         "of one sample.".format(
                             args.gradient_accumulation_steps,
                             args.train_batch_size,
                         ))
    # From here on train_batch_size is the per-forward-pass batch size.
    args.train_batch_size = (args.train_batch_size //
                             args.gradient_accumulation_steps)

    # Seed every RNG in use.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    processor = PROCESSORS[args.task_name]()
    num_labels = len(processor.get_labels())

    #tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer(
        args.vocab_file,
        do_lower_case=args.do_lower_case,
        max_len=512,
    )  # for bert large

    # Stays None when not training; it is still passed to
    # init_optimizer_and_amp below.
    num_train_optimization_steps = None
    if args.do_train:
        train_features = get_train_features(
            args.data_dir,
            args.bert_model,
            args.max_seq_length,
            args.do_lower_case,
            args.local_rank,
            args.train_batch_size,
            args.gradient_accumulation_steps,
            args.num_train_epochs,
            tokenizer,
            processor,
        )
        num_train_optimization_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        # In distributed runs the step count is divided by the world size.
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)
    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForSequenceClassification(
        config,
        num_labels=num_labels,
    )
    logger.info("USING CHECKPOINT from {}".format(args.init_checkpoint))

    # The checkpoint is either a bare state dict or a dict with the state dict
    # under "model". strict=False: keys need not match the model exactly.
    checkpoint = torch.load(args.init_checkpoint, map_location='cpu')
    checkpoint = checkpoint["model"] if "model" in checkpoint.keys(
    ) else checkpoint
    model.load_state_dict(checkpoint, strict=False)
    logger.info("USED CHECKPOINT from {}".format(args.init_checkpoint))
    dllogger.log(
        step="PARAMETER",
        data={
            "num_parameters":
            sum([p.numel() for p in model.parameters() if p.requires_grad]),
        },
    )

    model.to(device)
    # Prepare optimizer
    model, optimizer, scheduler = init_optimizer_and_amp(
        model,
        args.learning_rate,
        args.loss_scale,
        args.warmup_proportion,
        num_train_optimization_steps,
        args.fp16,
    )

    # Wrap the model: apex DDP for distributed runs, DataParallel when a
    # single process sees several GPUs.
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex to use "
                              "distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    loss_fct = torch.nn.CrossEntropyLoss()

    results = {}
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        train_data = gen_tensor_dataset(train_features)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data,
            sampler=train_sampler,
            batch_size=args.train_batch_size,
        )

        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0
        latency_train = 0.0
        nb_tr_examples = 0
        model.train()
        tic_train = time.perf_counter()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            # NOTE: both counters restart every epoch, so the values reported
            # after training cover the last epoch only (nb_tr_examples, by
            # contrast, accumulates over all epochs).
            tr_loss, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                # This break leaves the batch loop only; the epoch loop keeps
                # going and re-checks the condition on its first batch.
                if args.max_steps > 0 and global_step > args.max_steps:
                    break
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                logits = model(input_ids, segment_ids, input_mask)
                loss = loss_fct(
                    logits.view(-1, num_labels),
                    label_ids.view(-1),
                )
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                # Scale the loss so that accumulated gradients average over
                # the accumulation window.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                # Update the weights once per accumulation window.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up for BERT
                        # which FusedAdam doesn't do
                        scheduler.step()

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
        latency_train = time.perf_counter() - tic_train
        # NOTE(review): nb_tr_steps is 0 if the last epoch processed no batch
        # (or no epoch ran), which makes this a ZeroDivisionError.
        tr_loss = tr_loss / nb_tr_steps
        results.update({
            'global_step':
            global_step,
            'train:loss':
            tr_loss,
            'train:latency':
            latency_train,
            'train:num_samples_per_gpu':
            nb_tr_examples,
            'train:num_steps':
            nb_tr_steps,
            'train:throughput':
            get_world_size() * nb_tr_examples / latency_train,
        })
        if is_main_process() and not args.skip_checkpoint:
            # Unwrap DDP / DataParallel so the saved keys have no "module."
            # prefix.
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(
                {"model": model_to_save.state_dict()},
                os.path.join(args.output_dir, modeling.WEIGHTS_NAME),
            )
            with open(
                    os.path.join(args.output_dir, modeling.CONFIG_NAME),
                    'w',
            ) as f:
                f.write(model_to_save.config.to_json_string())

    # Evaluation and prediction both use the dev examples and run on the main
    # process only.
    if (args.do_eval or args.do_predict) and is_main_process():
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features, label_map = convert_examples_to_features(
            eval_examples,
            processor.get_labels(),
            args.max_seq_length,
            tokenizer,
        )
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_data = gen_tensor_dataset(eval_features)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(
            eval_data,
            sampler=eval_sampler,
            batch_size=args.eval_batch_size,
        )

        model.eval()
        preds = None
        out_label_ids = None
        eval_loss = 0
        nb_eval_steps, nb_eval_examples = 0, 0
        # One (start, end) CUDA event pair per batch, used to time the forward
        # pass of that batch.
        cuda_events = [(torch.cuda.Event(enable_timing=True),
                        torch.cuda.Event(enable_timing=True))
                       for _ in range(len(eval_dataloader))]
        for i, (input_ids, input_mask, segment_ids, label_ids) in tqdm(
                enumerate(eval_dataloader),
                desc="Evaluating",
        ):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                cuda_events[i][0].record()
                logits = model(input_ids, segment_ids, input_mask)
                cuda_events[i][1].record()
                if args.do_eval:
                    eval_loss += loss_fct(
                        logits.view(-1, num_labels),
                        label_ids.view(-1),
                    ).mean().item()

            nb_eval_steps += 1
            nb_eval_examples += input_ids.size(0)
            # Collect logits and gold labels of every batch on the host.
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = label_ids.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    label_ids.detach().cpu().numpy(),
                    axis=0,
                )
        # Wait for all recorded events to finish before reading their timings.
        torch.cuda.synchronize()
        eval_latencies = [
            event_start.elapsed_time(event_end)
            for event_start, event_end in cuda_events
        ]
        eval_latencies = list(sorted(eval_latencies))

        def infer_latency_sli(threshold):
            # Latency at the given quantile (0..1) of the sorted per-batch
            # latencies, with the index clamped to the valid range.
            index = int(len(eval_latencies) * threshold) - 1
            index = min(max(index, 0), len(eval_latencies) - 1)
            return eval_latencies[index]

        # Latencies are treated as milliseconds (see the result keys below),
        # hence the division by 1000 to obtain samples per second.
        eval_throughput = (args.eval_batch_size /
                           (np.mean(eval_latencies) / 1000))

        results.update({
            'eval:num_samples_per_gpu': nb_eval_examples,
            'eval:num_steps': nb_eval_steps,
            'infer:latency(ms):50%': infer_latency_sli(0.5),
            'infer:latency(ms):90%': infer_latency_sli(0.9),
            'infer:latency(ms):95%': infer_latency_sli(0.95),
            'infer:latency(ms):99%': infer_latency_sli(0.99),
            'infer:latency(ms):100%': infer_latency_sli(1.0),
            'infer:latency(ms):avg': np.mean(eval_latencies),
            'infer:latency(ms):std': np.std(eval_latencies),
            'infer:latency(ms):sum': np.sum(eval_latencies),
            'infer:throughput(samples/s):avg': eval_throughput,
        })
        # Turn logits into predicted class indices.
        preds = np.argmax(preds, axis=1)
        if args.do_predict:
            dump_predictions(
                os.path.join(args.output_dir, 'predictions.json'),
                label_map,
                preds,
                eval_examples,
            )
        if args.do_eval:
            # NOTE(review): divides by zero if the eval dataloader was empty.
            results['eval:loss'] = eval_loss / nb_eval_steps
            eval_result = compute_metrics(args.task_name, preds, out_label_ids)
            results.update(eval_result)

    if is_main_process():
        logger.info("***** Results *****")
        for key in sorted(results.keys()):
            logger.info("  %s = %s", key, str(results[key]))
        with open(os.path.join(args.output_dir, "results.txt"), "w") as writer:
            json.dump(results, writer)
        # Maps a dllogger metric name to either a key of `results` or a
        # (key, conversion function) pair.
        dllogger_queries_from_results = {
            'exact_match': 'acc',
            'F1': 'f1',
            'e2e_train_time': 'train:latency',
            'training_sequences_per_second': 'train:throughput',
            'e2e_inference_time':
            ('infer:latency(ms):sum', lambda x: x / 1000),
            'inference_sequences_per_second':
            'infer:throughput(samples/s):avg',
        }
        for key, query in dllogger_queries_from_results.items():
            results_key, convert = (query if isinstance(query, tuple) else
                                    (query, lambda x: x))
            # Metrics that were not produced in this run are skipped.
            if results_key not in results:
                continue
            dllogger.log(
                step=tuple(),
                data={key: convert(results[results_key])},
            )
    dllogger.flush()
    return results
     sample.append(sentence)
     sample.append(table.row_values(j)[4])
     samples.append(sample)
 examples = []
 for i in range(len(samples)):
     guid = "%s" % (i)
     text_a = samples[i][0].lower()
     text_b = samples[i][1].lower()
     label = str(int(table.row_values(j)[3] >= 0.5))
     examples.append(
         InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
 features = convert_examples_to_features(examples,
                                         tokenizer,
                                         label_list=["0", "1"],
                                         max_length=64,
                                         output_mode="classification",
                                         pad_on_left=False,
                                         # pad on the left for xlnet
                                         pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                         pad_token_segment_id=0,
                                         )
 #for f in features:
     #input_ids = torch.tensor(f.input_ids, dtype=torch.long).unsqueeze(0)
     #attention_mask = torch.tensor(f.attention_mask, dtype=torch.long).unsqueeze(0)
     #align_mask = torch.tensor(f.align_mask, dtype=torch.long).unsqueeze(0)
     #label = torch.tensor(f.label, dtype=torch.long).unsqueeze(0)
     #input_ids = torch.tensor(f.input_ids, dtype=torch.long).unsqueeze(0).cuda()
     #attention_mask = torch.tensor(f.attention_mask, dtype=torch.long).unsqueeze(0).cuda()
     #align_mask = torch.tensor(f.align_mask, dtype=torch.long).unsqueeze(0).cuda()
     #label = torch.tensor(f.label, dtype=torch.long).unsqueeze(0).cuda()
     #print(label.shape)
 all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).cuda()