Example #1
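All three examples assume the same preamble. A minimal sketch of the library imports the snippets rely on is given below; the remaining names (gen_model_features, RawResult, stride_action_space, reward_estimation, reward_estimation_for_stop, write_predictions, make_predictions, warmup_linear, logger) are project-specific helpers from the surrounding repository and are not reproduced here:

import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
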
def test_model(args, model, tokenizer, test_examples, test_features, device):
    test_indices = torch.arange(len(test_features), dtype=torch.long)
    test_sampler = SequentialSampler(test_indices)
    test_dataloader = DataLoader(test_indices,
                                 sampler=test_sampler,
                                 batch_size=args.predict_batch_size)

    model.eval()
    all_results = []
    for step, batch_test_indices in enumerate(
            tqdm(test_dataloader, desc="Evaluating")):
        batch_test_features = [
            test_features[ind] for ind in batch_test_indices
        ]
        batch_query_tokens = [f.query_tokens for f in batch_test_features]
        batch_doc_tokens = [f.doc_tokens for f in batch_test_features]
        batch_size = len(batch_test_features)
        cur_global_pointers = [0] * batch_size
        batch_max_doc_length = [
            args.max_seq_length - 3 - len(query_tokens)
            for query_tokens in batch_query_tokens
        ]
        prev_hidden_states = None
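        # read the document chunk by chunk, carrying the recurrent hidden
        # states across reads and moving each example's pointer after every read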
        for t in range(args.max_read_times):
            chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, _, _, _ = \
                             gen_model_features(cur_global_pointers, batch_query_tokens, batch_doc_tokens, \
                                                None, None, batch_max_doc_length, args.max_seq_length, \
                                                tokenizer, is_train=False)
            chunk_input_ids = torch.tensor(chunk_input_ids,
                                           dtype=torch.long,
                                           device=device)
            chunk_input_mask = torch.tensor(chunk_input_mask,
                                            dtype=torch.long,
                                            device=device)
            chunk_segment_ids = torch.tensor(chunk_segment_ids,
                                             dtype=torch.long,
                                             device=device)
            with torch.no_grad():
                chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                                   chunk_start_logits, chunk_end_logits, \
                                   prev_hidden_states = model(chunk_input_ids, chunk_segment_ids,
                                                              chunk_input_mask, prev_hidden_states)
            # stop logits estimate whether the current chunk contains the answer
            chunk_stop_logits = chunk_stop_logits.detach().cpu().tolist()

            # find top answer texts for the current chunk
            for i, example_index in enumerate(batch_test_indices):
                stop_logits = chunk_stop_logits[i]
                start_logits = chunk_start_logits[i].detach().cpu().tolist()
                end_logits = chunk_end_logits[i].detach().cpu().tolist()
                id_to_tok_map = id_to_tok_maps[i]
                example_index = example_index.item()
                all_results.append(
                    RawResult(example_index=example_index,
                              stop_logits=stop_logits,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              id_to_tok_map=id_to_tok_map))

            # take movement action
            if args.supervised_pretraining:
                chunk_strides = [args.doc_stride] * batch_size
            else:
                chunk_strides = [
                    stride_action_space[stride_ind]
                    for stride_ind in chunk_stride_inds.tolist()
                ]
            cur_global_pointers = [
                cur_global_pointers[ind] + chunk_strides[ind]
                for ind in range(len(cur_global_pointers))
            ]
            # clamp each pointer to the valid range [0, last doc token index]
            cur_global_pointers = [min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind]) - 1)
                                   for ind in range(len(cur_global_pointers))]

    # write predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions.json")
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
    write_predictions(test_examples, test_features, all_results, args.n_best_size, \
                      args.max_answer_length, args.do_lower_case, \
                      output_prediction_file, output_nbest_file, args.verbose_logging)
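
RawResult is not defined in the snippet above; consistent with the fields it populates, a plain namedtuple would fit (a sketch mirroring common BERT QA code, not the repository's actual definition):

import collections

# Sketch of the result container, inferred from the fields used above.
RawResult = collections.namedtuple(
    "RawResult",
    ["example_index", "stop_logits", "start_logits", "end_logits",
     "id_to_tok_map"])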
Example #2
def validate_model(args, model, tokenizer, dev_examples, dev_features,
                   dev_dataloader, dev_evaluator, best_dev_score, device):
    all_results = []
    for dev_step, batch_dev_indices in enumerate(
            tqdm(dev_dataloader, desc="Evaluating")):
        batch_dev_features = [dev_features[ind] for ind in batch_dev_indices]
        batch_query_tokens = [f.query_tokens for f in batch_dev_features]
        batch_doc_tokens = [f.doc_tokens for f in batch_dev_features]
        batch_size = len(batch_dev_features)
        cur_global_pointers = [0] * batch_size
        batch_max_doc_length = [
            args.max_seq_length - 3 - len(query_tokens)
            for query_tokens in batch_query_tokens
        ]
        prev_hidden_states = None
        for t in range(args.max_read_times):
            chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, _, _, _ = \
                             gen_model_features(cur_global_pointers, batch_query_tokens, batch_doc_tokens, \
                                                None, None, batch_max_doc_length, args.max_seq_length, \
                                                tokenizer, is_train=False)
            chunk_input_ids = torch.tensor(chunk_input_ids,
                                           dtype=torch.long,
                                           device=device)
            chunk_input_mask = torch.tensor(chunk_input_mask,
                                            dtype=torch.long,
                                            device=device)
            chunk_segment_ids = torch.tensor(chunk_segment_ids,
                                             dtype=torch.long,
                                             device=device)
            with torch.no_grad():
                chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                                   chunk_start_logits, chunk_end_logits, \
                                   prev_hidden_states = model(chunk_input_ids, chunk_segment_ids,
                                                              chunk_input_mask, prev_hidden_states)
            # stop logits estimate whether the current chunk contains the answer
            chunk_stop_logits = chunk_stop_logits.detach().cpu().tolist()

            # find top answer texts for the current chunk
            for i, example_index in enumerate(batch_dev_indices):
                stop_logits = chunk_stop_logits[i]
                start_logits = chunk_start_logits[i].detach().cpu().tolist()
                end_logits = chunk_end_logits[i].detach().cpu().tolist()
                id_to_tok_map = id_to_tok_maps[i]
                example_index = example_index.item()
                all_results.append(
                    RawResult(example_index=example_index,
                              stop_logits=stop_logits,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              id_to_tok_map=id_to_tok_map))

            # take movement action
            if args.supervised_pretraining:
                chunk_strides = [args.doc_stride] * batch_size
            else:
                chunk_strides = [
                    stride_action_space[stride_ind]
                    for stride_ind in chunk_stride_inds.tolist()
                ]
            cur_global_pointers = [
                cur_global_pointers[ind] + chunk_strides[ind]
                for ind in range(len(cur_global_pointers))
            ]
            # clamp each pointer to the valid range [0, last doc token index]
            cur_global_pointers = [min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind]) - 1)
                                   for ind in range(len(cur_global_pointers))]

    dev_predictions = make_predictions(dev_examples, dev_features, all_results, args.n_best_size, \
                                        args.max_answer_length, args.do_lower_case, \
                                        args.verbose_logging, validate_flag=True)
    dev_scores = dev_evaluator.eval_fn(dev_predictions)
    dev_score = dev_scores['f1']
    logger.info('dev score: {}'.format(dev_score))
    if (dev_score > best_dev_score):
        best_model_to_save = model.module if hasattr(model,
                                                     'module') else model
        best_output_model_file = os.path.join(args.output_dir,
                                              "best_RCM_model.bin")
        torch.save(best_model_to_save.state_dict(), best_output_model_file)
        best_dev_score = dev_score
        logger.info("Best dev score: {}, saved to best_RCM_model.bin".format(
            dev_score))
    return best_dev_score
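
Both evaluation loops convert the model's sampled stride indices into token-level pointer movements through stride_action_space, which is defined elsewhere in the repository. A hypothetical definition (the concrete values are an assumption, not taken from the source):

# Hypothetical action space: each entry is a stride, in tokens, by which the
# reader moves its chunk pointer; a negative stride moves backward.
stride_action_space = [-16, 16, 32, 64]

For example, a sampled index of 2 would move the pointer 32 tokens forward, after which it is clamped to [0, len(doc_tokens) - 1] as in the loops above.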
Example #3
def train_model(args, model, tokenizer, optimizer, train_examples,
                train_features, dev_examples, dev_features, dev_evaluator,
                device, n_gpu, t_total):
    train_indices = torch.arange(len(train_features), dtype=torch.long)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_indices)
    else:
        train_sampler = DistributedSampler(train_indices)
    train_dataloader = DataLoader(train_indices,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.do_validate:
        dev_indices = torch.arange(len(dev_features), dtype=torch.long)
        dev_sampler = SequentialSampler(dev_indices)
        dev_dataloader = DataLoader(dev_indices,
                                    sampler=dev_sampler,
                                    batch_size=args.predict_batch_size)

    best_dev_score = 0.0
    global_step = 0
    model.train()
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        training_loss = 0.0
        for step, batch_indices in enumerate(
                tqdm(train_dataloader, desc="Iteration")):
            batch_features = [train_features[ind] for ind in batch_indices]
            batch_query_tokens = [f.query_tokens for f in batch_features]
            batch_doc_tokens = [f.doc_tokens for f in batch_features]
            batch_start_positions = [f.start_position for f in batch_features]
            batch_end_positions = [f.end_position for f in batch_features]

            batch_size = len(batch_features)
            # global position of each example's pointer within the document
            cur_global_pointers = [0] * batch_size
            batch_max_doc_length = [
                args.max_seq_length - 3 - len(query_tokens)
                for query_tokens in batch_query_tokens
            ]

            stride_log_probs = []
            stop_rewards = []
            stop_probs = []
            stop_loss = None
            answer_loss = None
            prev_hidden_states = None
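            # recurrent read loop: encode the current chunk, accumulate the stop
            # and answer losses, then move each pointer by the sampled stride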
            for t in range(args.max_read_times):
                # features at the current chunk
                chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, \
                                 chunk_start_positions, chunk_end_positions, chunk_stop_flags = \
                                 gen_model_features(cur_global_pointers, batch_query_tokens, batch_doc_tokens, \
                                                    batch_start_positions, batch_end_positions, batch_max_doc_length, \
                                                    args.max_seq_length, tokenizer, is_train=True)
                chunk_input_ids = torch.tensor(chunk_input_ids,
                                               dtype=torch.long,
                                               device=device)
                chunk_input_mask = torch.tensor(chunk_input_mask,
                                                dtype=torch.long,
                                                device=device)
                chunk_segment_ids = torch.tensor(chunk_segment_ids,
                                                 dtype=torch.long,
                                                 device=device)
                chunk_start_positions = torch.tensor(chunk_start_positions,
                                                     dtype=torch.long,
                                                     device=device)
                chunk_end_positions = torch.tensor(chunk_end_positions,
                                                   dtype=torch.long,
                                                   device=device)
                chunk_stop_flags = torch.tensor(chunk_stop_flags,
                                                dtype=torch.long,
                                                device=device)

                # model to find span
                chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                                   chunk_start_logits, chunk_end_logits, \
                                   prev_hidden_states, chunk_stop_loss, chunk_answer_loss = \
                                   model(chunk_input_ids, chunk_segment_ids, chunk_input_mask,
                                         prev_hidden_states, chunk_stop_flags,
                                         chunk_start_positions, chunk_end_positions)
                chunk_stop_logits = chunk_stop_logits.detach()
                chunk_stop_probs = F.softmax(chunk_stop_logits, dim=1)
                chunk_stop_probs = chunk_stop_probs[:, 1]
                stop_probs.append(chunk_stop_probs.tolist())
                chunk_stop_logits = chunk_stop_logits.tolist()

                if stop_loss is None:
                    stop_loss = chunk_stop_loss
                else:
                    stop_loss += chunk_stop_loss

                if answer_loss is None:
                    answer_loss = chunk_answer_loss
                else:
                    answer_loss += chunk_answer_loss

                if args.supervised_pretraining:
                    chunk_strides = [args.doc_stride] * batch_size
                else:
                    # take movement action
                    chunk_strides = [
                        stride_action_space[stride_ind]
                        for stride_ind in chunk_stride_inds.tolist()
                    ]
                cur_global_pointers = [
                    cur_global_pointers[ind] + chunk_strides[ind]
                    for ind in range(len(cur_global_pointers))
                ]
                # clamp each pointer to the valid range [0, last doc token index]
                cur_global_pointers = [min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind]) - 1)
                                       for ind in range(len(cur_global_pointers))]

                if not args.supervised_pretraining:
                    # reward estimation for reinforcement learning
                    chunk_start_probs = F.softmax(chunk_start_logits.detach(),
                                                  dim=1).tolist()
                    chunk_end_probs = F.softmax(chunk_end_logits.detach(),
                                                dim=1).tolist()
                    # rewards if stop at the current chunk
                    chunk_stop_rewards = reward_estimation_for_stop(
                        chunk_start_probs, chunk_end_probs,
                        chunk_start_positions.tolist(),
                        chunk_end_positions.tolist(),
                        chunk_stop_flags.tolist())
                    stop_rewards.append(chunk_stop_rewards)

                    # save the stride log-probs; the action taken after the
                    # last read is never evaluated, so its log-prob is excluded
                    if (t < args.max_read_times - 1):
                        stride_log_probs.append(chunk_stride_log_probs)

            if args.supervised_pretraining:
                loss = (stop_loss * args.stop_loss_weight +
                        answer_loss) / args.max_read_times
            else:
                # stride_log_probs: (bsz, max_read_times-1)
                stride_log_probs = torch.stack(stride_log_probs).transpose(
                    1, 0)
                # q_vals: (bsz, max_read_times-1)
                q_vals = reward_estimation(stop_rewards, stop_probs)
                q_vals = torch.tensor(q_vals,
                                      dtype=stride_log_probs.dtype,
                                      device=device)
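                # REINFORCE: weight each step's stride log-probability by its
                # estimated return, sum over steps, and average over the batch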
                reinforce_loss = torch.sum(-stride_log_probs * q_vals, dim=1)
                reinforce_loss = torch.mean(reinforce_loss, dim=0)

                loss = (stop_loss * args.stop_loss_weight +
                        answer_loss) / args.max_read_times + reinforce_loss
            # compute gradients
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            # logging training loss
            training_loss += loss.item()
            if (step % 500 == 499):
                logger.info('step: {}, train loss: {}\n'.format(
                    step, training_loss / 500.0))
                if not args.supervised_pretraining:
                    logger.info('q_vals: {}\n'.format(q_vals))
                training_loss = 0.0

            # validation on dev data
            if args.do_validate and step % 500 == 499:
                model.eval()
                best_dev_score = validate_model(args, model, tokenizer,
                                                dev_examples, dev_features,
                                                dev_dataloader, dev_evaluator,
                                                best_dev_score, device)
                model.train()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                # adjust the learning rate with the linear warmup schedule BERT uses
                lr_this_step = args.learning_rate * warmup_linear(
                    global_step / t_total, args.warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

    # Save a trained model
    model_to_save = model.module if hasattr(
        model, 'module') else model  # only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if args.do_train:
        torch.save(model_to_save.state_dict(), output_model_file)
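
For context, a minimal driver sketch of how the three functions could be wired together. Every name here other than train_model and test_model (the builder and loader helpers, and the exact args fields) is a hypothetical placeholder, not code from the examples above:

def main(args):
    # build_model_and_optimizer, load_split, load_dev_evaluator, and
    # estimate_total_steps are placeholders for the repository's own setup code
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model, tokenizer, optimizer = build_model_and_optimizer(args, device)
    if args.do_train:
        train_examples, train_features = load_split(args, tokenizer, "train")
        dev_examples, dev_features = load_split(args, tokenizer, "dev")
        dev_evaluator = load_dev_evaluator(args)
        t_total = estimate_total_steps(args, len(train_features))
        train_model(args, model, tokenizer, optimizer, train_examples,
                    train_features, dev_examples, dev_features, dev_evaluator,
                    device, n_gpu, t_total)
    if args.do_predict:
        test_examples, test_features = load_split(args, tokenizer, "test")
        test_model(args, model, tokenizer, test_examples, test_features, device)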