Example #1
 def _train_epoch(self, data_loader: DataLoader, optimizer: AdamW,
                  scheduler: LambdaLR, report_frequency: int):
     initial_time = time.time()
     total_train_loss = 0
     self.model.train()
     self.model.to(self.device)
     for step, batch in enumerate(data_loader):
         b_input_ids = batch[0].to(self.device)
         b_input_mask = batch[1].to(self.device)
         b_labels = batch[2].to(self.device)
         optimizer.zero_grad()
         loss, logits = self.model(b_input_ids,
                                   token_type_ids=None,
                                   attention_mask=b_input_mask,
                                   labels=b_labels)
         self._report_loss_and_time(step=step + 1,
                                    num_of_batches=len(data_loader),
                                    initial_time=initial_time,
                                    loss=loss,
                                    frequency=report_frequency)
         total_train_loss += loss.item()
         loss.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
         optimizer.step()
         scheduler.step()
     avg_train_loss = total_train_loss / len(data_loader)
     training_time = self._format_time_delta(initial_time)
     print(f"Average training loss: {avg_train_loss:.2f}")
     print(f"Training epoch took: {training_time}")
Example #2
 def _create_optimizer(self, sgd):
     optimizer = AdamW(
         self._model.parameters(),
         lr=getattr(sgd, "trf_lr", sgd.alpha),
         eps=sgd.eps,
         betas=(sgd.b1, sgd.b2),
         weight_decay=getattr(sgd, "trf_weight_decay", 0.0),
     )
     optimizer.zero_grad()
     return optimizer
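The factory above reads transformer-specific hyperparameters off what looks like a thinc/spaCy-style sgd object, falling back to the generic ones when no trf_* attributes are present. A small illustration of that getattr fallback, with SimpleNamespace standing in for the real optimizer object (an assumption for demonstration only):

# Illustration of the getattr fallbacks used above; SimpleNamespace is a stand-in.
from types import SimpleNamespace

sgd = SimpleNamespace(alpha=1e-3, eps=1e-8, b1=0.9, b2=0.999)
lr = getattr(sgd, "trf_lr", sgd.alpha)                 # -> 1e-3, no trf_lr override
weight_decay = getattr(sgd, "trf_weight_decay", 0.0)   # -> 0.0, no override set
print(lr, weight_decay)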
Example #3
class Optim(object):
    def set_parameters(self, params):
        self.params = list(params)  # careful: params may be a generator
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr)
        elif self.method == 'bertadam':
            self.optimizer = AdamW(self.params, lr=self.lr)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self,
                 method,
                 lr,
                 max_grad_norm,
                 lr_decay=1,
                 start_decay_at=None,
                 max_decay_times=2):
        self.last_score = None
        self.decay_times = 0
        self.max_decay_times = max_decay_times
        self.lr = float(lr)
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

    def step(self):
        # Compute gradients norm.
        if self.max_grad_norm:
            # Gradient clipping
            clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    # Decay the learning rate if val perf does not improve, or we reach the start_decay_at limit.
    def updateLearningRate(self, score, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)

        self.last_score = score
        self.optimizer.param_groups[0]['lr'] = self.lr
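A hedged usage sketch for the Optim wrapper above; the toy model, hyperparameters, and the commented training-loop calls are placeholders chosen for illustration, not code from the original project.

# Usage sketch (assumed names); 'bertadam' routes to AdamW as defined above.
import torch.nn as nn

model = nn.Linear(16, 4)
wrapper = Optim('bertadam', lr=2e-5, max_grad_norm=1.0,
                lr_decay=0.5, start_decay_at=3)
wrapper.set_parameters(model.parameters())

# Per training step:
#     wrapper.zero_grad(); loss.backward(); wrapper.step()   # clips grads, then steps
# Once per epoch, after computing a validation score:
#     wrapper.updateLearningRate(score=dev_acc, epoch=epoch)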
Example #4
    def train_func(self):
        # loss_fct = MarginRankingLoss(margin=1, reduction='mean')
        loss_fct = NLLLoss(reduction='mean')
        optimizer = AdamW(self.model.parameters(), self.args.learning_rate)
        step = 0
        # cos = nn.CosineSimilarity(dim=1)
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.args.scheduler_step,
            gamma=self.args.scheduler_gamma)
        accumulate_step = 0

        for epoch in range(1, self.args.epoch + 1):
            for batch in self.loader:
                probs = self.get_probs(batch)
                batch_size = probs.size(0)

                true_idx = torch.zeros(batch_size, dtype=torch.long)
                if torch.cuda.is_available():
                    true_idx = true_idx.cuda()
                loss = loss_fct(probs, true_idx)
                loss.backward()
                accumulate_step += 1

                self.writer.add_scalar('loss', loss.item(), step)

                stop_scheduler_step = self.args.scheduler_step * 80

                # Update parameters once every gradient_accumulate_step batches.
                if accumulate_step % self.args.gradient_accumulate_step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    if self.args.scheduler_lr and step <= stop_scheduler_step:
                        scheduler.step()
                    accumulate_step = 0

                step += 1
                if step % self.args.save_model_step == 0:
                    model_basename = self.args.dest_base_dir + self.args.exp_name
                    model_basename += '_epoch_{}_step_{}'.format(epoch, step)
                    torch.save(self.model.state_dict(),
                               model_basename + '.model')
                    write_json(model_basename + '.json', vars(self.args))
                    ret = self.evaluate(model_basename, step)
                    self.writer.add_scalar('accuracy', ret, step)
                    # self.writer.add_scalar('recall', ret['recall'], step)
                    # self.writer.add_scalar('f1', ret['f1'], step)
                    msg_tmpl = 'step {} completed, accuracy {:.4f}'
                    self.logger.info(msg_tmpl.format(step, ret))
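The loop above relies on gradient accumulation: a counter advances on every batch and the optimizer steps once every gradient_accumulate_step batches. A compact, self-contained sketch of that pattern with placeholder data (every name here is an assumption used only for illustration):

# Generic gradient-accumulation skeleton with dummy data (illustration only).
import torch
from torch.optim import AdamW

model = torch.nn.Linear(8, 2)
optimizer = AdamW(model.parameters(), lr=1e-3)
accum_steps = 4

optimizer.zero_grad()
for i in range(1, 101):                              # stand-in for a DataLoader loop
    x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / accum_steps).backward()                  # scale so accumulated grads average
    if i % accum_steps == 0:                         # step once every accum_steps batches
        optimizer.step()
        optimizer.zero_grad()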
Example #5
def main():
    parser = argparse.ArgumentParser()

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--target_train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--update_BERT_top_layers',
                        type=int,
                        default=1,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    target_kshot_entail_examples, target_kshot_nonentail_examples, target_dev_examples, target_test_examples = load_FewRel_GFS_Entail(
        args.kshot)

    system_seed = 42
    random.seed(system_seed)
    np.random.seed(system_seed)
    torch.manual_seed(system_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(system_seed)

    source_kshot_size = 10  # if args.kshot>10 else 10 if max(10, args.kshot)
    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv', source_kshot_size)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["entailment", "non_entailment"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    # entity_label_list = ["A-coref", "B-coref"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:',
          len(target_dev_examples), 'test size:', len(target_test_examples))

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                                  strict=False)
    '''
    embedding layer 5 variables
    each bert layer 16 variables
    '''
    param_size = 0
    update_top_layer_size = args.update_BERT_top_layers
    for name, param in roberta_model.named_parameters():
        if param_size < (5 + 16 * (24 - update_top_layer_size)):
            param.requires_grad = False
        param_size += 1
    roberta_model.to(device)

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters()) + list(
        roberta_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(
        source_kshot_entail,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(
        source_kshot_neural,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(
        source_kshot_contra,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(
        source_remaining_examples,
        source_label_list,
        args,
        tokenizer,
        args.train_batch_size,
        "classification",
        dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(
        target_kshot_entail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(
        target_kshot_nonentail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples,
                                                 target_label_list,
                                                 args,
                                                 tokenizer,
                                                 args.eval_batch_size,
                                                 "classification",
                                                 dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples,
                                                  target_label_list,
                                                  args,
                                                  tokenizer,
                                                  args.eval_batch_size,
                                                  "classification",
                                                  dataloader_mode='sequential')
    '''starting to train'''
    iter_co = 0
    tr_loss = 0
    source_loss = 0
    target_loss = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):

        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, segment_ids, source_label_ids_batch = batch

            roberta_model.train()
            source_last_hidden_batch, _ = roberta_model(input_ids, input_mask)
            '''
            retrieve rep for support examples in MNLI
            '''
            kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
            entail_batch_i = 0
            for entail_batch in source_kshot_entail_dataloader:
                roberta_model.train()
                last_hidden_entail, _ = roberta_model(
                    entail_batch[1].to(device), entail_batch[2].to(device))
                kshot_entail_reps += torch.mean(last_hidden_entail,
                                                dim=0,
                                                keepdim=True)
                entail_batch_i += 1
            kshot_entail_rep = kshot_entail_reps / entail_batch_i
            kshot_neural_reps = torch.zeros(1, bert_hidden_dim).to(device)
            neural_batch_i = 0
            for neural_batch in source_kshot_neural_dataloader:
                roberta_model.train()
                last_hidden_neural, _ = roberta_model(
                    neural_batch[1].to(device), neural_batch[2].to(device))
                kshot_neural_reps += torch.mean(last_hidden_neural,
                                                dim=0,
                                                keepdim=True)
                neural_batch_i += 1
            kshot_neural_rep = kshot_neural_reps / neural_batch_i
            kshot_contra_reps = torch.zeros(1, bert_hidden_dim).to(device)
            contra_batch_i = 0
            for contra_batch in source_kshot_contra_dataloader:
                roberta_model.train()
                last_hidden_contra, _ = roberta_model(
                    contra_batch[1].to(device), contra_batch[2].to(device))
                kshot_contra_reps += torch.mean(last_hidden_contra,
                                                dim=0,
                                                keepdim=True)
                contra_batch_i += 1
            kshot_contra_rep = kshot_contra_reps / contra_batch_i

            source_class_prototype_reps = torch.cat(
                [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
                dim=0)  #(3, hidden)
            '''first get representations for support examples in target'''
            target_kshot_entail_dataloader_subset = examples_to_features(
                random.sample(target_kshot_entail_examples, 10),
                target_label_list,
                args,
                tokenizer,
                retrieve_batch_size,
                "classification",
                dataloader_mode='sequential')
            target_kshot_nonentail_dataloader_subset = examples_to_features(
                random.sample(target_kshot_nonentail_examples, 10),
                target_label_list,
                args,
                tokenizer,
                retrieve_batch_size,
                "classification",
                dataloader_mode='sequential')
            kshot_entail_reps = []
            for entail_batch in target_kshot_entail_dataloader_subset:
                roberta_model.train()
                last_hidden_entail, _ = roberta_model(
                    entail_batch[1].to(device), entail_batch[2].to(device))
                kshot_entail_reps.append(last_hidden_entail)
            all_kshot_entail_reps = torch.cat(kshot_entail_reps, dim=0)
            kshot_entail_rep = torch.mean(all_kshot_entail_reps,
                                          dim=0,
                                          keepdim=True)
            kshot_nonentail_reps = []
            for nonentail_batch in target_kshot_nonentail_dataloader_subset:
                roberta_model.train()
                last_hidden_nonentail, _ = roberta_model(
                    nonentail_batch[1].to(device),
                    nonentail_batch[2].to(device))
                kshot_nonentail_reps.append(last_hidden_nonentail)
            all_kshot_neural_reps = torch.cat(kshot_nonentail_reps, dim=0)
            kshot_nonentail_rep = torch.mean(all_kshot_neural_reps,
                                             dim=0,
                                             keepdim=True)
            target_class_prototype_reps = torch.cat(
                [kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep],
                dim=0)  #(3, hidden)

            class_prototype_reps = torch.cat(
                [source_class_prototype_reps, target_class_prototype_reps],
                dim=0)  #(6, hidden)
            '''forward to model'''

            target_batch_size = args.target_train_batch_size  #10*3
            # print('target_batch_size:', target_batch_size)
            target_batch_size_entail = target_batch_size  #random.randrange(5)+1
            target_batch_size_neural = target_batch_size  #random.randrange(5)+1

            selected_target_entail_rep = all_kshot_entail_reps[torch.randperm(
                all_kshot_entail_reps.shape[0])[:target_batch_size_entail]]
            # print('selected_target_entail_rep:', selected_target_entail_rep.shape)
            selected_target_neural_rep = all_kshot_neural_reps[torch.randperm(
                all_kshot_neural_reps.shape[0])[:target_batch_size_neural]]
            # print('selected_target_neural_rep:', selected_target_neural_rep.shape)
            target_last_hidden_batch = torch.cat(
                [selected_target_entail_rep, selected_target_neural_rep])

            last_hidden_batch = torch.cat(
                [source_last_hidden_batch, target_last_hidden_batch],
                dim=0)  #(train_batch_size+10*2)
            # print('last_hidden_batch shape:', last_hidden_batch.shape)
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)
            # exit(0)
            '''source side loss'''
            # loss_fct = CrossEntropyLoss(reduction='none')
            loss_fct = CrossEntropyLoss()
            source_loss_list = loss_fct(
                batch_logits[:source_last_hidden_batch.shape[0]].view(
                    -1, source_num_labels), source_label_ids_batch.view(-1))
            '''target side loss'''
            target_label_ids_batch = torch.tensor(
                [0] * selected_target_entail_rep.shape[0] +
                [1] * selected_target_neural_rep.shape[0],
                dtype=torch.long)
            target_batch_logits = batch_logits[-target_last_hidden_batch.
                                               shape[0]:]
            target_loss_list = loss_by_logits_and_2way_labels(
                target_batch_logits, target_label_ids_batch.view(-1), device)
            # target_loss_list = loss_fct(target_batch_logits.view(-1, source_num_labels), target_label_ids_batch.to(device).view(-1))
            loss = source_loss_list + target_loss_list  #torch.mean(torch.cat([source_loss_list, target_loss_list]))
            # .item() keeps the running sums from retaining the autograd graph.
            source_loss += source_loss_list.item()
            target_loss += target_loss_list.item()
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            global_step += 1
            iter_co += 1
            # print('iter_co:', iter_co, 'mean loss:', tr_loss/iter_co)
            if iter_co % 20 == 0:
                # if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                start evaluate on dev set after this epoch
                '''
                '''
                retrieve rep for support examples in MNLI
                '''
                kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
                entail_batch_i = 0
                for entail_batch in source_kshot_entail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(
                            entail_batch[1].to(device),
                            entail_batch[2].to(device))
                    kshot_entail_reps += torch.mean(last_hidden_entail,
                                                    dim=0,
                                                    keepdim=True)
                    entail_batch_i += 1
                kshot_entail_rep = kshot_entail_reps / entail_batch_i
                kshot_neural_reps = torch.zeros(1, bert_hidden_dim).to(device)
                neural_batch_i = 0
                for neural_batch in source_kshot_neural_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_neural, _ = roberta_model(
                            neural_batch[1].to(device),
                            neural_batch[2].to(device))
                    kshot_neural_reps += torch.mean(last_hidden_neural,
                                                    dim=0,
                                                    keepdim=True)
                    neural_batch_i += 1
                kshot_neural_rep = kshot_neural_reps / neural_batch_i
                kshot_contra_reps = torch.zeros(1, bert_hidden_dim).to(device)
                contra_batch_i = 0
                for contra_batch in source_kshot_contra_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_contra, _ = roberta_model(
                            contra_batch[1].to(device),
                            contra_batch[2].to(device))
                    kshot_contra_reps += torch.mean(last_hidden_contra,
                                                    dim=0,
                                                    keepdim=True)
                    contra_batch_i += 1
                kshot_contra_rep = kshot_contra_reps / contra_batch_i

                source_class_prototype_reps = torch.cat(
                    [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
                    dim=0)  #(3, hidden)
                '''first get representations for support examples in target'''
                # target_kshot_entail_dataloader_subset = examples_to_features(random.sample(target_kshot_entail_examples, args.kshot), target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
                # target_kshot_nonentail_dataloader_subset = examples_to_features(random.sample(target_kshot_nonentail_examples, args.kshot), target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
                kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
                entail_batch_i = 0
                for entail_batch in target_kshot_entail_dataloader_subset:  #target_kshot_entail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(
                            entail_batch[1].to(device),
                            entail_batch[2].to(device))
                    kshot_entail_reps += torch.mean(last_hidden_entail,
                                                    dim=0,
                                                    keepdim=True)
                    entail_batch_i += 1
                kshot_entail_rep = kshot_entail_reps / entail_batch_i
                kshot_nonentail_reps = torch.zeros(1,
                                                   bert_hidden_dim).to(device)
                nonentail_batch_i = 0
                for nonentail_batch in target_kshot_nonentail_dataloader_subset:  #target_kshot_nonentail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_nonentail, _ = roberta_model(
                            nonentail_batch[1].to(device),
                            nonentail_batch[2].to(device))
                    kshot_nonentail_reps += torch.mean(last_hidden_nonentail,
                                                       dim=0,
                                                       keepdim=True)
                    nonentail_batch_i += 1
                kshot_nonentail_rep = kshot_nonentail_reps / nonentail_batch_i
                target_class_prototype_reps = torch.cat([
                    kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep
                ],
                                                        dim=0)  #(3, hidden)

                class_prototype_reps = torch.cat(
                    [source_class_prototype_reps, target_class_prototype_reps],
                    dim=0)  #(6, hidden)

                protonet.eval()

                # dev_acc = evaluation(protonet, target_dev_dataloader,  device, flag='Dev')
                # print('class_prototype_reps:', class_prototype_reps)
                dev_acc = evaluation(protonet,
                                     roberta_model,
                                     class_prototype_reps,
                                     target_dev_dataloader,
                                     device,
                                     flag='Dev')
                if dev_acc > max_dev_acc:
                    max_dev_acc = dev_acc
                    print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                          max_dev_acc, '\n')
                    if dev_acc > 0.73:  #10:0.73; 5:0.66
                        test_acc = evaluation(protonet,
                                              roberta_model,
                                              class_prototype_reps,
                                              target_test_dataloader,
                                              device,
                                              flag='Test')
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t test acc:', test_acc, ' max_test_acc:',
                              max_test_acc, '\n')
                else:
                    print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                          max_dev_acc, '\n')

            if iter_co == 2000:
                break
    print('final_test_performance:', final_test_performance)
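Examples #5 and #6 both build the same two AdamW parameter groups so that biases and LayerNorm parameters are excluded from weight decay. A compact, hedged helper capturing that pattern; the Linear model is only a stand-in for RoBERTa/BERT.

# Helper sketch for the no-weight-decay grouping used in Examples #5 and #6.
import torch
from torch.optim import AdamW

def grouped_parameters(model, weight_decay=0.01,
                       no_decay=("bias", "LayerNorm.bias", "LayerNorm.weight")):
    decay, skip = [], []
    for name, param in model.named_parameters():
        (skip if any(nd in name for nd in no_decay) else decay).append(param)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": skip, "weight_decay": 0.0}]

model = torch.nn.Linear(8, 2)                        # stand-in for RoBERTa/BERT
optimizer = AdamW(grouped_parameters(model), lr=1e-5)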
Example #6
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                # Backprop done above; update parameters once every
                # gradient_accumulation_steps batches (count every batch, not
                # only the batches where the optimizer steps).
                global_step += 1
                if global_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # 1  means the model should be saved
                # 0  means the model does not need to be saved
                # -1 means the model does not need saving and patience has been exceeded, so early stop
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                # logger.info(eval_result)

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        if dev_dataset:
            freeze_data['valid_loss'] = valid_loss
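The translated comments above describe the signal convention of the project's EarlyStopping helper (1: save the model, 0: do nothing, -1: stop). Its implementation is not shown here; a minimal sketch of a class with that contract, offered only as an assumption about the interface:

# Minimal sketch of the assumed EarlyStopping contract (not the project's class).
class EarlyStoppingSketch:
    def __init__(self, patience):
        self.patience = patience
        self.best = None
        self.bad_epochs = 0

    def __call__(self, metric):
        if self.best is None or metric > self.best:   # new best: save (signal 1)
            self.best = metric
            self.bad_epochs = 0
            return 1
        self.bad_epochs += 1
        return -1 if self.bad_epochs > self.patience else 0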
class MTBModel(RelationExtractor):
    def __init__(self, config: dict):
        """
        Matching the Blanks Model.

        Args:
            config: configuration parameters
        """
        super().__init__()
        self.experiment_name = config.get("experiment_name")
        self.transformer = config.get("transformer")
        self.config = config
        self.data_loader = MTBPretrainDataLoader(self.config)
        self.train_len = len(self.data_loader.train_generator)
        logger.info("Loaded %d pre-training samples." % self.train_len)

        self.model = BertModel.from_pretrained(
            model_size=self.transformer,
            pretrained_model_name_or_path=self.transformer,
            force_download=False,
        )

        self.tokenizer = self.data_loader.tokenizer
        self.model.resize_token_embeddings(len(self.tokenizer))
        e1_id = self.tokenizer.convert_tokens_to_ids("[E1]")
        e2_id = self.tokenizer.convert_tokens_to_ids("[E2]")
        if e1_id == e2_id == 1:
            raise ValueError("e1_id == e2_id == 1")

        self.train_on_gpu = torch.cuda.is_available() and config.get(
            "use_gpu", True)
        if self.train_on_gpu:
            logger.info("Train on GPU")
            self.model.cuda()

        self.criterion = MTBLoss(lm_ignore_idx=self.tokenizer.pad_token_id, )
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.01,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.config.get("lr"))
        ovr_steps = (self.config.get("epochs") *
                     len(self.data_loader.train_generator) *
                     self.config.get("max_size") * 2 /
                     self.config.get("batch_size"))
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, ovr_steps // 10, ovr_steps)

        self._start_epoch = 0
        self._best_mtb_bce = 50
        self._train_loss = []
        self._train_lm_acc = []
        self._lm_acc = []
        self._mtb_bce = []
        self.checkpoint_dir = os.path.join("models", "MTB-pretraining",
                                           self.experiment_name,
                                           self.transformer)
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        self._batch_points_seen = 0
        self._points_seen = 0

    def load_best_model(self, checkpoint_dir: str):
        """
        Loads the current best model in the checkpoint directory.

        Args:
            checkpoint_dir: Checkpoint directory path
        """
        checkpoint = super().load_best_model(checkpoint_dir)
        return (
            checkpoint["epoch"],
            checkpoint["best_mtb_bce"],
            checkpoint["losses_per_epoch"],
            checkpoint["accuracy_per_epoch"],
            checkpoint["lm_acc"],
            checkpoint["blanks_mse"],
        )

    def train(self, **kwargs):
        """
        Runs the training.

        Args:
            kwargs: Additional keyword arguments
        """
        save_best_model_only = kwargs.get("save_best_model_only", False)
        results_path = os.path.join(
            "results",
            "MTB-pretraining",
            self.experiment_name,
            self.transformer,
        )
        best_model_path = os.path.join(self.checkpoint_dir,
                                       "best_model.pth.tar")
        resume = self.config.get("resume", False)
        if resume and os.path.exists(best_model_path):
            (
                self._start_epoch,
                self._best_mtb_bce,
                self._train_loss,
                self._train_lm_acc,
                self._lm_acc,
                self._mtb_bce,
            ) = self.load_best_model(self.checkpoint_dir)

        logger.info("Starting training process")
        update_size = len(self.data_loader.train_generator) // 100
        for epoch in range(self._start_epoch, self.config.get("epochs")):
            self._train_epoch(epoch, update_size, save_best_model_only)
            data = self._write_kpis(results_path)
            self._plot_results(data, results_path)
        logger.info("Finished Training.")
        return self.model

    def _plot_results(self, data, save_at):
        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(x="Epoch", y="Train Loss", ax=ax, data=data, linewidth=4)
        ax.set_title("Training Loss")
        plt.savefig(
            os.path.join(save_at,
                         "train_loss_{0}.png".format(self.transformer)))
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(x="Epoch",
                     y="Val MTB Loss",
                     ax=ax,
                     data=data,
                     linewidth=4)
        ax.set_title("Val MTB Binary Cross Entropy")
        plt.savefig(
            os.path.join(save_at,
                         "val_mtb_bce_{0}.png".format(self.transformer)))
        plt.close(fig)

        tmp = data[["Epoch", "Train LM Accuracy",
                    "Val LM Accuracy"]].melt(id_vars="Epoch",
                                             var_name="Set",
                                             value_name="LM Accuracy")
        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(
            x="Epoch",
            y="LM Accuracy",
            hue="Set",
            ax=ax,
            data=tmp,
            linewidth=4,
        )
        ax.set_title("LM Accuracy")
        plt.savefig(
            os.path.join(save_at, "lm_acc_{0}.png".format(self.transformer)))
        plt.close(fig)

    def _write_kpis(self, results_path):
        Path(results_path).mkdir(parents=True, exist_ok=True)
        data = pd.DataFrame({
            "Epoch": np.arange(len(self._train_loss)),
            "Train Loss": self._train_loss,
            "Train LM Accuracy": self._train_lm_acc,
            "Val LM Accuracy": self._lm_acc,
            "Val MTB Loss": self._mtb_bce,
        })
        data.to_csv(
            os.path.join(results_path,
                         "kpis_{0}.csv".format(self.transformer)),
            index=False,
        )
        return data

    def _train_epoch(self,
                     epoch,
                     update_size,
                     save_best_model_only: bool = False):
        start_time = super()._train_epoch(epoch)

        train_lm_acc, train_loss, train_mtb_bce = [], [], []

        for i, data in enumerate(tqdm(self.data_loader.train_generator)):
            sequence, masked_label, e1_e2_start, blank_labels = data
            if sequence.shape[1] > 70:
                continue
            res = self._train_on_batch(sequence, masked_label, e1_e2_start,
                                       blank_labels)
            if res[0]:
                train_loss.append(res[0])
                train_lm_acc.append(res[1])
                train_mtb_bce.append(res[2])
            if (i % update_size) == (update_size - 1):
                logger.info(
                    f"{i+1}/{self.train_len} pools: - " +
                    f"Train loss: {np.mean(train_loss)}, " +
                    f"Train LM accuracy: {np.mean(train_lm_acc)}, " +
                    f"Train MTB Binary Cross Entropy {np.mean(train_mtb_bce)}")

        self._train_loss.append(np.mean(train_loss))
        self._train_lm_acc.append(np.mean(train_lm_acc))

        self.on_epoch_end(epoch, self._mtb_bce, self._best_mtb_bce,
                          save_best_model_only)

        logger.info(
            f"Epoch finished, took {time.time() - start_time} seconds!")
        logger.info(f"Train Loss: {self._train_loss[-1]}!")
        logger.info(f"Train LM Accuracy: {self._train_lm_acc[-1]}!")
        logger.info(f"Validation LM Accuracy: {self._lm_acc[-1]}!")
        logger.info(
            f"Validation MTB Binary Cross Entropy: {self._mtb_bce[-1]}!")

    def on_epoch_end(self,
                     epoch,
                     benchmark,
                     baseline,
                     save_best_model_only: bool = False):
        """
        Function to run at the end of an epoch.

        Runs the evaluation method, increments the scheduler, sets a new baseline, and appends the KPIs.

        Args:
            epoch: Current epoch
            benchmark: List of benchmark results
            baseline: Current baseline. Best model performance so far
            save_best_model_only: Whether to only save the best model so far
                or all of them
        """
        eval_result = super().on_epoch_end(epoch, benchmark, baseline)
        self._best_mtb_bce = (eval_result[1]
                              if eval_result[1] < self._best_mtb_bce else
                              self._best_mtb_bce)
        self._mtb_bce.append(eval_result[1])
        self._lm_acc.append(eval_result[0])
        super().save_on_epoch_end(self._mtb_bce, self._best_mtb_bce, epoch,
                                  save_best_model_only)

    def _train_on_batch(
        self,
        sequence,
        mskd_label,
        e1_e2_start,
        blank_labels,
    ):
        mskd_label = mskd_label[(mskd_label != self.tokenizer.pad_token_id)]
        if mskd_label.shape[0] == 0:
            return None, None, None
        if self.train_on_gpu:
            mskd_label = mskd_label.cuda()
        blanks_logits, lm_logits = self._get_logits(e1_e2_start, sequence)
        loss = self.criterion(
            lm_logits,
            blanks_logits,
            mskd_label,
            blank_labels,
        )
        loss_p = loss.item()
        loss = loss / self.config.get("batch_size")
        loss.backward()
        self._batch_points_seen += len(sequence)
        self._points_seen += len(sequence)
        if self._batch_points_seen > self.config.get("batch_size"):
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
            self._batch_points_seen = 0
        train_metrics = self.calculate_metrics(
            lm_logits,
            blanks_logits,
            mskd_label,
            blank_labels,
        )
        return loss_p / len(sequence), train_metrics[0], train_metrics[1]

    def _save_model(self, path, epoch, best_model: bool = False):
        if best_model:
            model_path = os.path.join(path, "best_model.pth.tar")
        else:
            model_path = os.path.join(
                path, "checkpoint_epoch_{0}.pth.tar").format(epoch + 1)
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": self.model.state_dict(),
                "tokenizer": self.tokenizer,
                "best_mtb_bce": self._best_mtb_bce,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
                "losses_per_epoch": self._train_loss,
                "accuracy_per_epoch": self._train_lm_acc,
                "lm_acc": self._lm_acc,
                "blanks_mse": self._mtb_bce,
            },
            model_path,
        )

    def _get_logits(self, e1_e2_start, x):
        attention_mask = (x != self.tokenizer.pad_token_id).float()
        token_type_ids = torch.zeros((x.shape[0], x.shape[1])).long()
        if self.train_on_gpu:
            x = x.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
        blanks_logits, lm_logits = self.model(
            x,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            e1_e2_start=e1_e2_start,
        )
        lm_logits = lm_logits[(x == self.tokenizer.mask_token_id)]
        return blanks_logits, lm_logits

    def evaluate(self) -> tuple:
        """
        Run the validation generator and return performance metrics.
        """
        total_loss = []
        lm_acc = []
        blanks_mse = []

        self.model.eval()
        with torch.no_grad():
            for data in self.data_loader.validation_generator:
                (x, masked_label, e1_e2_start, blank_labels) = data
                masked_label = masked_label[(masked_label !=
                                             self.tokenizer.pad_token_id)]
                if masked_label.shape[0] == 0:
                    continue
                if self.train_on_gpu:
                    masked_label = masked_label.cuda()
                blanks_logits, lm_logits = self._get_logits(e1_e2_start, x)

                loss = self.criterion(
                    lm_logits,
                    blanks_logits,
                    masked_label,
                    blank_labels,
                )

                total_loss.append(loss.item())
                eval_result = self.calculate_metrics(lm_logits, blanks_logits,
                                                     masked_label,
                                                     blank_labels)
                lm_acc += [eval_result[0]]
                blanks_mse += [eval_result[1]]
        self.model.train()
        return (
            np.mean(lm_acc),
            sum(b for b in blanks_mse if b != 1) /
            len([b for b in blanks_mse if b != 1]),
        )

    def calculate_metrics(
        self,
        lm_logits,
        blanks_logits,
        masked_for_pred,
        blank_labels,
    ) -> tuple:
        """
        Calculates the performance metrics of the MTB model.

        Args:
            lm_logits: Language model Logits per word in vocabulary
            blanks_logits: Blank logits
            masked_for_pred: List of marked tokens
            blank_labels: Blank labels
        """
        lm_logits_pred_ids = torch.softmax(lm_logits, dim=-1).max(1)[1]
        lm_accuracy = ((lm_logits_pred_ids == masked_for_pred).sum().float() /
                       len(masked_for_pred)).item()

        pos_idxs = np.where(blank_labels == 1)[0]
        neg_idxs = np.where(blank_labels == 0)[0]

        if len(pos_idxs) > 1:
            # positives
            pos_logits = []
            for pos1, pos2 in combinations(pos_idxs, 2):
                pos_logits.append(
                    self._get_mtb_logits(blanks_logits[pos1, :],
                                         blanks_logits[pos2, :]))
            pos_logits = torch.stack(pos_logits, dim=0)
            pos_labels = [1.0 for _ in range(pos_logits.shape[0])]
        else:
            pos_logits, pos_labels = torch.FloatTensor([]), []
            if blanks_logits.is_cuda:
                pos_logits = pos_logits.cuda()

        # negatives
        neg_logits = []
        for pos_idx in pos_idxs:
            for neg_idx in neg_idxs:
                neg_logits.append(
                    MTBModel._get_mtb_logits(blanks_logits[pos_idx, :],
                                             blanks_logits[neg_idx, :]))
        neg_logits = torch.stack(neg_logits, dim=0)
        neg_labels = [0.0 for _ in range(neg_logits.shape[0])]

        blank_labels = torch.FloatTensor(pos_labels + neg_labels)
        blank_pred = torch.cat([pos_logits, neg_logits], dim=0)
        bce = nn.BCEWithLogitsLoss(reduction="mean")(
            blank_pred.detach().cpu(), blank_labels.detach().cpu())

        return lm_accuracy, bce.numpy()

    @classmethod
    def _get_mtb_logits(cls, f1_vec, f2_vec):
        factor = 1 / (torch.norm(f1_vec) * torch.norm(f2_vec))
        return factor * torch.dot(f1_vec, f2_vec)
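For reference, _get_mtb_logits above is simply the cosine similarity between the two blank representations. A small, hedged check against torch.nn.functional.cosine_similarity with dummy vectors:

# Equivalence check with dummy vectors (illustration only).
import torch
import torch.nn.functional as F

f1, f2 = torch.randn(768), torch.randn(768)
manual = torch.dot(f1, f2) / (torch.norm(f1) * torch.norm(f2))
assert torch.allclose(manual, F.cosine_similarity(f1, f2, dim=0))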
def main():
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_neg_data', type=Path, required=True)
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    parser.add_argument(
        "--bert_model",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )

    parser.add_argument("--max_seq_len", default=512, type=int)

    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--epochs",
                        type=int,
                        default=3,
                        help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--kr_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--kr_freq", default=0.7, type=float)
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        print(torch.cuda.is_available())
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        print(n_gpu)
        print("no gpu?")
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        print("GPU Device: ", device)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    logging.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

    pt_output = Path(getenv('PT_OUTPUT_DIR', ''))
    args.output_dir = Path(os.path.join(pt_output, args.output_dir))

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(
            f"Output directory ({args.output_dir}) already exists and is not empty!"
        )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
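    # One optimizer update is performed for every
    # train_batch_size * gradient_accumulation_steps training examples.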
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    # Prepare model
    config = BertConfig.from_pretrained(args.bert_model)
    # config.num_hidden_layers = args.num_layers
    model = FuckWrapper(config)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
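    # Split parameters into two groups: weight matrices get L2 weight decay
    # (0.01), while biases and LayerNorm parameters are excluded from decay,
    # as is standard when fine-tuning BERT-style models.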
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=num_train_optimization_steps)
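    # Note: WarmupLinearSchedule comes from the older pytorch_transformers API;
    # in recent transformers releases the equivalent would be
    # get_linear_schedule_with_warmup(optimizer,
    #                                 num_warmup_steps=args.warmup_steps,
    #                                 num_training_steps=num_train_optimization_steps).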
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    model.train()

    before_train_path = Path(os.path.join(args.output_dir, "before_training"))
    print("Before training path: ", before_train_path)
    before_train_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(os.path.join(args.output_dir, "before_training"))
    tokenizer.save_pretrained(os.path.join(args.output_dir, "before_training"))

    neg_epoch_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.pregenerated_neg_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)
    if args.local_rank == -1:
        neg_train_sampler = RandomSampler(neg_epoch_dataset)
    else:
        neg_train_sampler = DistributedSampler(neg_epoch_dataset)

    neg_train_dataloader = DataLoader(neg_epoch_dataset,
                                      sampler=neg_train_sampler,
                                      batch_size=args.train_batch_size)
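    # Infinite generator over the negated-data loader: lets the training loop
    # below pull a negated batch on demand, restarting the loader whenever it
    # is exhausted.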

    def inf_train_gen():
        while True:
            for kr_step, kr_batch in enumerate(neg_train_dataloader):
                yield kr_step, kr_batch

    kr_gen = inf_train_gen()

    for epoch in range(args.epochs):
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        if (n_gpu > 1 and args.local_rank == -1) or n_gpu <= 1:
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            model.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                model.train()

                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch

                outputs = model(input_ids=input_ids,
                                attention_mask=input_mask,
                                token_type_ids=segment_ids,
                                masked_lm_labels=lm_label_ids,
                                negated=False)
                loss = outputs[0]
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)

                if args.local_rank == 0 or args.local_rank == -1:
                    nb_tr_steps += 1
                    pbar.update(1)
                    mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                    pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1
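                # With probability (1 - kr_freq), additionally train on a batch
                # from the negated corpus (forward pass with negated=True).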

                if random.random() > args.kr_freq:
                    kr_step, kr_batch = next(kr_gen)
                    kr_batch = tuple(t.to(device) for t in kr_batch)
                    input_ids, input_mask, segment_ids, lm_label_ids = kr_batch

                    outputs = model(input_ids=input_ids,
                                    attention_mask=input_mask,
                                    token_type_ids=segment_ids,
                                    masked_lm_labels=lm_label_ids,
                                    negated=True)
                    loss = outputs[0]
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    if args.local_rank == -1:
                        nb_tr_steps += 1
                        mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                        pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        optimizer.zero_grad()
                        global_step += 1

    # Save a trained model
    if (n_gpu > 1 and args.local_rank == -1) or n_gpu <= 1:
        logging.info("** ** * Saving fine-tuned model ** ** * ")
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
Beispiel #9
0
def train(conf):
    ### device
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = conf.device
    conf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ### directory
    os.makedirs(conf.out_dir, exist_ok=True)
    best_model_dir = os.path.join(conf.out_dir, "best_model")
    os.makedirs(best_model_dir)
    latest_model_dir = os.path.join(conf.out_dir, "latest_model")
    os.makedirs(latest_model_dir)
    ### global variables
    best_acc = -1
    loss_fw = open(os.path.join(conf.out_dir, "loss.txt"),
                   "w",
                   encoding="utf8")
    acc_fw = open(os.path.join(conf.out_dir, "acc.txt"), "w", encoding="utf8")

    ### train data
    logging.info("get train data loader...")
    tokenizer = BertTokenizer.from_pretrained(conf.pretrain_model_path)
    train_data_iter = FullInteractionJoinStringDataSet(
        file_path=conf.train_file_path,
        ent_path=conf.ent_path,
        tokenizer=tokenizer,
        batch_size=conf.batch_size,
        max_len=conf.max_len)

    total_steps = int(train_data_iter.steps * conf.num_epochs)
    steps_per_epoch = train_data_iter.steps
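    # Interpret conf.warmup as a fraction of total steps when it is < 1,
    # otherwise as an absolute number of warmup steps.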
    if conf.warmup < 1:
        warmup_steps = int(total_steps * conf.warmup)
    else:
        warmup_steps = int(conf.warmup)
    ### get dev evaluator
    evaluator = TopKAccEvaluator(conf.dev_file_path,
                                 conf.ent_path,
                                 tokenizer,
                                 conf.device,
                                 batch_size=conf.dev_batch_size)
    ### model
    logging.info("define model...")
    if "min" in conf.out_dir:
        logging.info("use min bert model for recall!!!!!!!")
        model = RetrievalModel(
            BertConfig(vocab_size=8021,
                       hidden_size=66,
                       num_hidden_layers=3,
                       num_attention_heads=3,
                       intermediate_size=66)).to(conf.device)
        model.tokenizer = BertTokenizer.from_pretrained(
            conf.pretrain_model_path)
    else:
        model = RetrievalModel(conf.pretrain_model_path).to(conf.device)

    model.train()
    ### optimizer
    logging.info("define optimizer...")
    no_decay = ["bias", "LayerNorm.weight"]
    paras = dict(model.named_parameters())
    logging.info(
        "=================================== trained parameters ==================================="
    )
    for n, p in paras.items():
        logging.info("{}".format(n))
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in paras.items()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01,
        },
        {
            "params":
            [p for n, p in paras.items() if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=conf.lr)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)

    ### train
    global_step = 0
    logging.info("start train")
    for epoch in range(conf.num_epochs):
        for step, batch in enumerate(train_data_iter):
            global_step += 1
            step += 1
            batch_data = [i.to(conf.device) for i in batch]
            if step < 2:
                print(batch_data[0].shape)
            pos_score = model(input_ids=batch_data[0],
                              attention_mask=batch_data[1],
                              token_type_ids=batch_data[2])
            neg_score = model(input_ids=batch_data[3],
                              attention_mask=batch_data[4],
                              token_type_ids=batch_data[5])
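            # Pairwise hinge (margin ranking) loss: only penalize examples where
            # pos_score does not exceed neg_score by at least conf.margin.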
            loss = torch.nn.functional.relu(neg_score - pos_score +
                                            conf.margin)
            loss = loss.mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_([v for k, v in paras.items()],
                                           max_norm=1)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            optimizer.zero_grad()
            if step % conf.print_steps == 0:
                logging.info("epoch:{},\tstep:{}/{},\tloss:{}".format(
                    epoch, step, steps_per_epoch, loss.data))
                loss_fw.write("epoch:{},\tstep:{}/{},\tloss:{}\n".format(
                    epoch, step, steps_per_epoch, loss.data))
                loss_fw.flush()

            if step % conf.evaluation_steps == 0 or step == steps_per_epoch - 1:
                # In-epoch evaluation hook left as a stub here; the full
                # evaluation runs once at the end of each epoch below.
                logging.info("start evaluate...")
        logging.info("start evaluate...")
        acc = evaluator.eval_acc(model, "test.xlsx")

        if acc > best_acc:
            logging.info("save best model to {}".format(best_model_dir))
            best_acc = acc
            model.save(best_model_dir)
        model.save(latest_model_dir)
        acc_fw.write("epoch:{},\tstep:{}/{},\tacc:{}\n".format(
            epoch, step, steps_per_epoch, acc))

    acc_fw.close()
    loss_fw.close()
Beispiel #10
0
def main(args):
    local_config = json.load(open(args.local_config_path))
    local_config['loss'] = args.loss
    local_config['data_dir'] = args.data_dir
    local_config['train_batch_size'] = args.train_batch_size
    local_config[
        'gradient_accumulation_steps'] = args.gradient_accumulation_steps
    local_config['lr_scheduler'] = args.lr_scheduler
    local_config['model_name'] = args.model_name
    local_config['pool_type'] = args.pool_type
    local_config['seed'] = args.seed
    local_config['do_train'] = args.do_train
    local_config['do_validation'] = args.do_validation
    local_config['do_eval'] = args.do_eval
    local_config['use_cuda'] = args.use_cuda.lower() == 'true'
    local_config['num_train_epochs'] = args.num_train_epochs
    local_config['eval_batch_size'] = args.eval_batch_size
    local_config['max_seq_len'] = args.max_seq_len
    local_config['syns'] = ["Target", "Synonym"]
    local_config['target_embeddings'] = args.target_embeddings
    local_config['symmetric'] = args.symmetric.lower() == 'true'
    local_config['mask_syns'] = args.mask_syns
    local_config['train_scd'] = args.train_scd
    local_config['ckpt_path'] = args.ckpt_path
    local_config['head_batchnorm'] = args.head_batchnorm
    local_config['head_hidden_size'] = args.head_hidden_size
    local_config['linear_head'] = args.linear_head.lower() == 'true'
    local_config['emb_size_for_cosine'] = args.emb_size_for_cosine
    local_config['add_fc_layer'] = args.add_fc_layer

    if local_config['do_train'] and os.path.exists(args.output_dir):
        from glob import glob
        model_weights = glob(os.path.join(args.output_dir, '*.bin'))
        if model_weights:
            print(f'{model_weights}: already computed: skipping ...')
            return
        else:
            print(
                f'already existing {args.output_dir}. but without model weights ...'
            )
            return

    device = torch.device("cuda" if local_config['use_cuda'] else "cpu")
    n_gpu = torch.cuda.device_count()

    if local_config['gradient_accumulation_steps'] < 1:
        raise ValueError(
            "gradient_accumulation_steps parameter should be >= 1")

    local_config['train_batch_size'] = \
        local_config['train_batch_size'] // local_config['gradient_accumulation_steps']

    if local_config['do_train']:
        random.seed(local_config['seed'])
        np.random.seed(local_config['seed'])
        torch.manual_seed(local_config['seed'])

    if n_gpu > 0:
        torch.cuda.manual_seed_all(local_config['seed'])

    if not local_config['do_train'] and not local_config['do_eval']:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if local_config['do_train'] and not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
        os.makedirs(os.path.join(args.output_dir, 'nen-nen-weights'))
    elif local_config['do_train'] or local_config['do_validation']:
        raise ValueError(args.output_dir, 'output_dir already exists')

    suffix = datetime.now().isoformat().replace('-', '_').replace(
        ':', '_').split('.')[0].replace('T', '-')
    if local_config['do_train']:
        train_writer = SummaryWriter(log_dir=os.path.join(
            args.output_dir, f'tensorboard-{suffix}', 'train'))
        dev_writer = SummaryWriter(log_dir=os.path.join(
            args.output_dir, f'tensorboard-{suffix}', 'dev'))

        logger.addHandler(
            logging.FileHandler(
                os.path.join(args.output_dir, f"train_{suffix}.log"), 'w'))
        eval_logger.addHandler(
            logging.FileHandler(
                os.path.join(args.output_dir, f"scores_{suffix}.log"), 'w'))
    else:
        logger.addHandler(
            logging.FileHandler(
                os.path.join(args.ckpt_path, f"eval_{suffix}.log"), 'w'))

    logger.info(args)
    logger.info(json.dumps(vars(args), indent=4))
    if args.do_train:
        with open(os.path.join(args.output_dir, 'local_config.json'), 'w') as outp:
            json.dump(local_config, outp, indent=4)
        with open(os.path.join(args.output_dir, 'args.json'), 'w') as outp:
            json.dump(vars(args), outp, indent=4)
    logger.info("device: {}, n_gpu: {}".format(device, n_gpu))

    syns = sorted(local_config['syns'])
    id2classifier = {i: classifier for i, classifier in enumerate(syns)}

    model_name = local_config['model_name']
    data_processor = DataProcessor()

    train_dir = os.path.join(local_config['data_dir'], 'train/')
    dev_dir = os.path.join(local_config['data_dir'], 'dev')

    if local_config['do_train']:

        config = configs[local_config['model_name']]
        config = config.from_pretrained(local_config['model_name'],
                                        hidden_dropout_prob=args.dropout)
        if args.ckpt_path != '':
            model_path = args.ckpt_path
        else:
            model_path = local_config['model_name']
        model = models[model_name].from_pretrained(
            model_path,
            cache_dir=str(PYTORCH_PRETRAINED_BERT_CACHE),
            local_config=local_config,
            data_processor=data_processor,
            config=config)

        param_optimizer = list(model.named_parameters())

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                param for name, param in param_optimizer
                if not any(nd in name for nd in no_decay)
            ],
            'weight_decay':
            float(args.weight_decay)
        }, {
            'params': [
                param for name, param in param_optimizer
                if any(nd in name for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=float(args.learning_rate),
                          eps=1e-6,
                          betas=(0.9, 0.98),
                          correct_bias=True)
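        # betas=(0.9, 0.98) and eps=1e-6 mirror the RoBERTa pretraining
        # settings; correct_bias=True keeps the standard Adam bias correction.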

        train_features = model.convert_dataset_to_features(train_dir, logger)

        if args.train_mode == 'sorted' or args.train_mode == 'random_sorted':
            train_features = sorted(train_features,
                                    key=lambda f: np.sum(f.input_mask))
        else:
            random.shuffle(train_features)


        train_dataloader = \
            get_dataloader_and_tensors(train_features, local_config['train_batch_size'])
        train_batches = [batch for batch in train_dataloader]

        num_train_optimization_steps = \
            len(train_batches) // local_config['gradient_accumulation_steps'] * \
                local_config['num_train_epochs']

        warmup_steps = int(args.warmup_proportion *
                           num_train_optimization_steps)
        if local_config['lr_scheduler'] == 'linear_warmup':
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_optimization_steps)
        elif local_config['lr_scheduler'] == 'constant_warmup':
            scheduler = get_constant_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps)
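        # linear_warmup: LR ramps up over warmup_steps, then decays linearly to
        # zero; constant_warmup: LR ramps up and then stays constant. Any other
        # value of lr_scheduler leaves `scheduler` undefined.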
        logger.info("***** Training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", local_config['train_batch_size'])
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if local_config['do_validation']:
            dev_features = model.convert_dataset_to_features(dev_dir, logger)
            logger.info("***** Dev *****")
            logger.info("  Num examples = %d", len(dev_features))
            logger.info("  Batch size = %d", local_config['eval_batch_size'])
            dev_dataloader = \
                get_dataloader_and_tensors(dev_features, local_config['eval_batch_size'])
            test_dir = os.path.join(local_config['data_dir'], 'test/')
            if os.path.exists(test_dir):
                test_features = model.convert_dataset_to_features(
                    test_dir, test_logger)
                logger.info("***** Test *****")
                logger.info("  Num examples = %d", len(test_features))
                logger.info("  Batch size = %d",
                            local_config['eval_batch_size'])

                test_dataloader = \
                    get_dataloader_and_tensors(test_features, local_config['eval_batch_size'])

        best_result = defaultdict(float)

        eval_step = max(1, len(train_batches) // args.eval_per_epoch)

        start_time = time.time()
        global_step = 0

        model.to(device)
        lr = float(args.learning_rate)
        for epoch in range(1, 1 + local_config['num_train_epochs']):
            tr_loss = 0
            nb_tr_examples = 0
            nb_tr_steps = 0
            cur_train_loss = defaultdict(float)

            model.train()
            logger.info("Start epoch #{} (lr = {})...".format(
                epoch,
                scheduler.get_lr()[0]))
            if args.train_mode == 'random' or args.train_mode == 'random_sorted':
                random.shuffle(train_batches)

            train_bar = tqdm(train_batches,
                             total=len(train_batches),
                             desc='training ... ')
            for step, batch in enumerate(train_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, token_type_ids, \
                syn_labels, positions = batch
                train_loss, _ = model(input_ids=input_ids,
                                      token_type_ids=token_type_ids,
                                      attention_mask=input_mask,
                                      input_labels={
                                          'syn_labels': syn_labels,
                                          'positions': positions
                                      })
                loss = train_loss['total'].mean().item()
                for key in train_loss:
                    cur_train_loss[key] += train_loss[key].mean().item()

                train_bar.set_description(
                    f'training... [epoch == {epoch} / {local_config["num_train_epochs"]}, loss == {loss}]'
                )

                loss_to_optimize = train_loss['total']

                if local_config['gradient_accumulation_steps'] > 1:
                    loss_to_optimize = \
                        loss_to_optimize / local_config['gradient_accumulation_steps']

                loss_to_optimize.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                tr_loss += loss_to_optimize.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if (step +
                        1) % local_config['gradient_accumulation_steps'] == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                if local_config['do_validation'] and (step +
                                                      1) % eval_step == 0:
                    logger.info(
                        'Ep: {}, Stp: {}/{}, usd_t={:.2f}s, loss={:.6f}'.
                        format(epoch, step + 1, len(train_batches),
                               time.time() - start_time,
                               tr_loss / nb_tr_steps))

                    cur_train_mean_loss = {}
                    for key, value in cur_train_loss.items():
                        cur_train_mean_loss[f'train_{key}_loss'] = \
                            value / nb_tr_steps

                    dev_predictions = os.path.join(args.output_dir,
                                                   'dev_predictions')

                    metrics = predict(model,
                                      dev_dataloader,
                                      dev_predictions,
                                      dev_features,
                                      args,
                                      cur_train_mean_loss=cur_train_mean_loss,
                                      logger=eval_logger)

                    metrics['global_step'] = global_step
                    metrics['epoch'] = epoch
                    metrics['learning_rate'] = scheduler.get_lr()[0]
                    metrics['batch_size'] = \
                        local_config['train_batch_size'] * local_config['gradient_accumulation_steps']

                    for key, value in metrics.items():
                        dev_writer.add_scalar(key, value, global_step)
                    scores_to_logger = tuple([
                        round(metrics[save_by_score] * 100.0, 2)
                        for save_by_score in args.save_by_score.split('+')
                    ])
                    logger.info(
                        "dev %s (lr=%s, epoch=%d): %s" %
                        (args.save_by_score, str(
                            scheduler.get_lr()[0]), epoch, scores_to_logger))

                    predict_parts = [
                        part for part in metrics if part.endswith('.score')
                        and metrics[part] > args.start_save_threshold
                        and metrics[part] > best_result[part]
                    ]
                    if len(predict_parts) > 0:
                        best_dev_predictions = os.path.join(
                            args.output_dir, 'best_dev_predictions')
                        dev_predictions = os.path.join(args.output_dir,
                                                       'dev_predictions')
                        os.makedirs(best_dev_predictions, exist_ok=True)
                        for part in predict_parts:
                            logger.info(
                                "!!! Best dev %s (lr=%s, epoch=%d): %.2f -> %.2f"
                                % (part, str(scheduler.get_lr()[0]), epoch,
                                   best_result[part] * 100.0,
                                   metrics[part] * 100.0))
                            best_result[part] = metrics[part]
                            if [
                                    save_weight for save_weight in
                                    args.save_by_score.split('+')
                                    if save_weight == part
                            ]:
                                os.makedirs(os.path.join(
                                    args.output_dir, part),
                                            exist_ok=True)
                                output_model_file = os.path.join(
                                    args.output_dir, part, WEIGHTS_NAME)
                                save_model(args, model, output_model_file,
                                           metrics)
                            if 'nen-nen' not in part:
                                os.system(
                                    f'cp {dev_predictions}/{".".join(part.split(".")[1:-1])}* {best_dev_predictions}/'
                                )
                            else:
                                output_model_file = os.path.join(
                                    args.output_dir, 'nen-nen-weights',
                                    WEIGHTS_NAME)
                                save_model(args, model, output_model_file,
                                           metrics)

                        # dev_predictions = os.path.join(args.output_dir, 'dev_predictions')
                        # predict(
                        #     model, dev_dataloader, dev_predictions,
                        #     dev_features, args, only_parts='+'.join(predict_parts)
                        # )
                        # best_dev_predictions = os.path.join(args.output_dir, 'best_dev_predictions')
                        # os.makedirs(best_dev_predictions, exist_ok=True)
                        # os.system(f'mv {dev_predictions}/* {best_dev_predictions}/')
                        if 'scd' not in '+'.join(
                                predict_parts) and os.path.exists(test_dir):
                            test_predictions = os.path.join(
                                args.output_dir, 'test_predictions')
                            test_metrics = predict(
                                model,
                                test_dataloader,
                                test_predictions,
                                test_features,
                                args,
                                only_parts='+'.join([
                                    'test' + part[3:] for part in predict_parts
                                    if 'nen-nen' not in part
                                ]))
                            best_test_predictions = os.path.join(
                                args.output_dir, 'best_test_predictions')
                            os.makedirs(best_test_predictions, exist_ok=True)
                            os.system(
                                f'mv {test_predictions}/* {best_test_predictions}/'
                            )

                            for key, value in test_metrics.items():
                                if key.endswith('score'):
                                    dev_writer.add_scalar(
                                        key, value, global_step)

            if args.log_train_metrics:
                metrics = predict(model,
                                  train_dataloader,
                                  os.path.join(args.output_dir,
                                               'train_predictions'),
                                  train_features,
                                  args,
                                  logger=logger)
                metrics['global_step'] = global_step
                metrics['epoch'] = epoch
                metrics['learning_rate'] = scheduler.get_lr()[0]
                metrics['batch_size'] = \
                    local_config['train_batch_size'] * local_config['gradient_accumulation_steps']

                for key, value in metrics.items():
                    train_writer.add_scalar(key, value, global_step)

    if local_config['do_eval']:
        assert args.ckpt_path != '', 'in do_eval mode ckpt_path should be specified'
        test_dir = args.eval_input_dir
        config = configs[model_name].from_pretrained(model_name)
        model = models[model_name].from_pretrained(
            args.ckpt_path,
            local_config=local_config,
            data_processor=data_processor,
            config=config)
        model.to(device)
        test_features = model.convert_dataset_to_features(
            test_dir, test_logger)
        logger.info("***** Test *****")
        logger.info("  Num examples = %d", len(test_features))
        logger.info("  Batch size = %d", local_config['eval_batch_size'])

        test_dataloader = \
            get_dataloader_and_tensors(test_features, local_config['eval_batch_size'])

        metrics = predict(model,
                          test_dataloader,
                          os.path.join(args.output_dir, args.eval_output_dir),
                          test_features,
                          args,
                          compute_metrics=True)
        print(metrics)
        with open(
                os.path.join(args.output_dir, args.eval_output_dir,
                             'metrics.txt'), 'w') as outp:
            print(metrics, file=outp)
Beispiel #11
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--DomainName",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_data_aug",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--meta_epochs',
                        type=int,
                        default=10,
                        help="Number of meta-training epochs.")
    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="Number of labeled examples per class (k-shot).")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()


    processors = {
        "rte": RteProcessor
    }

    output_modes = {
        "rte": "classification"
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")


    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    # label_list = processor.get_labels() #["entailment", "neutral", "contradiction"]
    # label_list = ['How_do_I_create_a_profile_v4', 'Profile_Switch_v4', 'Deactivate_Active_Devices_v4', 'Ads_on_Hulu_v4', 'Watching_Hulu_with_Live_TV_v4', 'Hulu_Costs_and_Commitments_v4', 'offline_downloads_v4', 'womens_world_cup_v5', 'forgot_username_v4', 'confirm_account_cancellation_v4', 'Devices_to_Watch_HBO_on_v4', 'remove_add_on_v4', 'Internet_Speed_for_HD_and_4K_v4', 'roku_related_questions_v4', 'amazon_related_questions_v4', 'Clear_Browser_Cache_v4', 'ads_on_ad_free_plan_v4', 'inappropriate_ads_v4', 'itunes_related_questions_v4', 'Internet_Speed_Recommendations_v4', 'NBA_Basketball_v5', 'unexpected_charges_v4', 'change_billing_date_v4', 'NFL_on_Hulu_v5', 'How_to_delete_a_profile_v4', 'Devices_to_Watch_Hulu_on_v4', 'Manage_your_Hulu_subscription_v4', 'cancel_hulu_account_v4', 'disney_bundle_v4', 'payment_issues_v4', 'home_network_location_v4', 'Main_Menu_v4', 'Resetting_Hulu_Password_v4', 'Update_Payment_v4', 'I_need_general_troubleshooting_help_v4', 'What_is_Hulu_v4', 'sprint_related_questions_v4', 'Log_into_TV_with_activation_code_v4', 'Game_of_Thrones_v4', 'video_playback_issues_v4', 'How_to_edit_a_profile_v4', 'Watchlist_Remove_Video_v4', 'spotify_related_questions_v4', 'Deactivate_Login_Sessions_v4', 'Transfer_to_Agent_v4', 'Use_Hulu_Internationally_v4']

    meta_train_examples, meta_dev_examples, meta_test_examples, meta_label_list = load_CLINC150_without_specific_domain(args.DomainName)
    train_examples, dev_examples, eval_examples, finetune_label_list = load_CLINC150_with_specific_domain_sequence(args.DomainName, args.kshot, augment=args.do_data_aug)
    # oos_dev_examples, oos_test_examples = load_OOS()
    # dev_examples+=oos_dev_examples
    # eval_examples+=oos_test_examples

    eval_label_list = finetune_label_list#+['oos']
    label_list=finetune_label_list+meta_label_list#+['oos']
    assert len(label_list) ==  15*10
    num_labels = len(label_list)
    assert num_labels == 15*10


    model = RobertaForSequenceClassification(num_labels)


    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    # tokenizer = BertTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        meta_train_features = convert_examples_to_features(
            meta_train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)


        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        '''load dev set'''
        # dev_examples = processor.get_RTE_as_dev('/export/home/Dataset/glue_data/RTE/dev.tsv')
        # dev_examples = get_data_hulu('dev')
        dev_features = convert_examples_to_features(
            dev_examples, eval_label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features], dtype=torch.long)
        dev_all_segment_ids = torch.tensor([f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features], dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask, dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.eval_batch_size)


        '''load test set'''
        # eval_examples = processor.get_RTE_as_test('/export/home/Dataset/RTE/test_RTE_1235.txt')
        # eval_examples = get_data_hulu('test')
        eval_features = convert_examples_to_features(
            eval_examples, eval_label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        eval_all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        eval_all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        eval_all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask, eval_all_segment_ids, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in meta_train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in meta_train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in meta_train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in meta_train_features], dtype=torch.long)

        meta_train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        meta_train_sampler = RandomSampler(meta_train_data)
        meta_train_dataloader = DataLoader(meta_train_data, sampler=meta_train_sampler, batch_size=args.train_batch_size*10)


        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
        '''support labeled examples in order, group in kshot size'''
        support_sampler = SequentialSampler(train_data)
        support_dataloader = DataLoader(train_data, sampler=support_sampler, batch_size=args.kshot)


        iter_co = 0
        max_dev_test = [0,0]
        fine_max_dev = False
        '''first train on meta_train tasks'''
        for meta_epoch_i in trange(args.meta_epochs, desc="metaEpoch"):
            for step, batch in enumerate(tqdm(meta_train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                logits,_,_ = model(input_ids, input_mask, None, labels=None)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            '''get class representation after each epoch of pretraining'''
            model.eval()
            last_reps_list = []
            for input_ids, input_mask, segment_ids, label_ids in support_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
                # gold_label_ids+=list(label_ids.detach().cpu().numpy())

                with torch.no_grad():
                    logits, last_reps, bias = model(input_ids, input_mask, None, labels=None)
                last_reps_list.append(last_reps.mean(dim=0, keepdim=True)) #(1, 1024)
            class_reps_pretraining = torch.cat(last_reps_list, dim=0) #(15, 1024)
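            # Each support batch holds the kshot examples of one class (the
            # support_dataloader is sequential with batch_size=kshot), so the
            # per-batch mean embedding serves as that class's prototype vector.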

            '''
            start evaluate on dev set after this epoch
            '''
            for idd, dev_or_test_dataloader in enumerate([dev_dataloader, eval_dataloader]):
                if idd == 0:
                    logger.info("***** Running dev *****")
                    logger.info("  Num examples = %d", len(dev_examples))
                else:
                    logger.info("***** Running test *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                # logger.info("  Batch size = %d", args.eval_batch_size)

                eval_loss = 0
                nb_eval_steps = 0
                preds = []
                gold_label_ids = []
                # print('Evaluating...')
                for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)
                    gold_label_ids+=list(label_ids.detach().cpu().numpy())

                    with torch.no_grad():
                        logits_LR, reps_batch, _ = model(input_ids, input_mask, None, labels=None)
                    # logits = logits[0]

                    '''pretraining logits'''
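                    # Nearest-prototype scoring: dot product between each example's
                    # representation and every class prototype; the reshape/max below
                    # then keeps the highest similarity per fine-tune class.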
                    raw_similarity_scores = torch.mm(reps_batch,torch.transpose(class_reps_pretraining, 0,1)) #(batch, 15)
                    # print('raw_similarity_scores shaoe:', raw_similarity_scores.shape)
                    # print('bias_finetune:', bias_finetune.shape)
                    biased_similarity_scores = raw_similarity_scores#+bias_finetune.view(-1, raw_similarity_scores.shape[1])
                    logits_pretrain = torch.max(biased_similarity_scores.view(args.eval_batch_size, -1, len(finetune_label_list)), dim=1)[0] #(batch, #class)
                    '''finetune logits'''
                    # raw_similarity_scores = torch.mm(reps_batch,torch.transpose(class_reps_finetune, 0,1)) #(batch, 15*history)
                    # biased_similarity_scores = raw_similarity_scores+bias_finetune.view(-1, raw_similarity_scores.shape[1])
                    # logits_finetune = torch.max(biased_similarity_scores.view(args.eval_batch_size, -1, len(finetune_label_list)), dim=1)[0] #(batch, #class)

                    logits = logits_pretrain#+logits_finetune
                    # logits = (1-0.9)*logits+0.9*logits_LR

                    if len(preds) == 0:
                        preds.append(logits.detach().cpu().numpy())
                    else:
                        preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

                # eval_loss = eval_loss / nb_eval_steps
                preds = preds[0]
                pred_probs = softmax(preds,axis=1)
                pred_label_ids = list(np.argmax(pred_probs, axis=1))
                assert len(pred_label_ids) == len(gold_label_ids)
                hit_co = 0

                for k in range(len(pred_label_ids)):
                    if pred_label_ids[k] == gold_label_ids[k]:
                        hit_co +=1
                test_acc = hit_co/len(gold_label_ids)

                if idd == 0: # this is dev
                    if test_acc > max_dev_acc:
                        max_dev_acc = test_acc
                        print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                        fine_max_dev=True
                        max_dev_test[0] = round(max_dev_acc*100, 2)
                    else:
                        print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                        break
                else: # this is test
                    if test_acc > max_test_acc:
                        max_test_acc = test_acc
                    if fine_max_dev:
                        max_dev_test[1] = round(test_acc*100,2)
                        fine_max_dev = False
                    print('\ttest acc:', test_acc, ' max_test_acc:', max_test_acc, '\n')


        print('final:', str(max_dev_test[0])+'/'+str(max_dev_test[1]), '\n')
Beispiel #12
0
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                f_vector, discriminator_output, classification_output = model(
                    token, mask, type_ids, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                if args.BCE:
                    loss = adversarial_loss(discriminator_output,
                                            (y != 0.0).float())
                else:
                    loss = classified_loss(discriminator_output, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                # bp and update parameters
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # 1  -> the model should be saved
                # 0  -> no need to save the model
                # -1 -> no improvement, patience exceeded, trigger early stopping
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        if dev_dataset:
            freeze_data['valid_loss'] = valid_loss
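The loop above leans on an EarlyStopping helper whose return value tells the caller what to do: 1 means the monitored score improved and the model should be saved, 0 means keep training, and -1 means patience is exhausted and training should stop. The project's own implementation is not shown here; the following is only a minimal sketch of one helper with that contract (class name and internals are assumptions):

class SimpleEarlyStopping:
    """Return 1 when the monitored score improves, 0 otherwise, and -1 once
    more than `patience` consecutive non-improving calls have occurred."""

    def __init__(self, patience):
        self.patience = patience
        self.best = None
        self.bad_calls = 0

    def __call__(self, score):
        if self.best is None or score > self.best:
            self.best = score
            self.bad_calls = 0
            return 1          # improved -> caller saves the model
        self.bad_calls += 1
        if self.bad_calls > self.patience:
            return -1         # give up -> caller breaks the epoch loop
        return 0              # no improvement yet, keep training

# Usage mirrors the loop above, where a lower EER is better, so the negated
# metric is passed in: signal = early_stopping(-eval_result['eer'])
stopper = SimpleEarlyStopping(patience=3)
for eer in [0.30, 0.25, 0.26, 0.27, 0.28, 0.29]:
    print(stopper(-eer))      # prints 1, 1, 0, 0, 0, -1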
Beispiel #13
0
def main():
    parser = argparse.ArgumentParser()


    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")


    args = parser.parse_args()



    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


    mctest_path = '/export/home/Dataset/MCTest/Statements/'
    target_kshot_entail_examples, target_kshot_nonentail_examples = get_MCTest_train(mctest_path+'mc500.train.statements.pairs', args.kshot) #train_pu_half_v1.txt
    target_dev_examples, target_test_examples = get_MCTest_dev_and_test(mctest_path+'mc500.dev.statements.pairs', mctest_path+'mc500.test.statements.pairs')


    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train('/export/home/Dataset/glue_data/MNLI/train.tsv', args.kshot)
    source_examples = source_kshot_entail+ source_kshot_neural+ source_kshot_contra+ source_remaining_examples
    target_label_list = ["ENTAILMENT", "UNKNOWN"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:', len(target_dev_examples), 'test size:', len(target_test_examples))

    num_train_optimization_steps = int(
        len(source_remaining_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load('/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'),strict=False)
    roberta_model.to(device)
    roberta_model.eval()

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(source_kshot_entail, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(source_kshot_neural, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(source_kshot_contra, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(source_remaining_examples, source_label_list, args, tokenizer, args.train_batch_size, "classification", dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(target_kshot_entail_examples, target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(target_kshot_nonentail_examples, target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples, target_label_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples, target_label_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='sequential')

    '''starting to train'''
    iter_co = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, segment_ids, label_ids_batch = batch

            roberta_model.eval()
            with torch.no_grad():
                last_hidden_batch, _ = roberta_model(input_ids, input_mask)
            '''
            retrieve rep for support examples
            '''
            kshot_entail_reps = []
            for entail_batch in source_kshot_entail_dataloader:
                entail_batch = tuple(t.to(device) for t in entail_batch)
                _, input_ids, input_mask, segment_ids, label_ids = entail_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_entail, _ = roberta_model(input_ids, input_mask)
                kshot_entail_reps.append(last_hidden_entail)
            kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0), dim=0, keepdim=True)
            kshot_neural_reps = []
            for neural_batch in source_kshot_neural_dataloader:
                neural_batch = tuple(t.to(device) for t in neural_batch)
                _, input_ids, input_mask, segment_ids, label_ids = neural_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_neural, _ = roberta_model(input_ids, input_mask)
                kshot_neural_reps.append(last_hidden_neural)
            kshot_neural_rep = torch.mean(torch.cat(kshot_neural_reps, dim=0), dim=0, keepdim=True)
            kshot_contra_reps = []
            for contra_batch in source_kshot_contra_dataloader:
                contra_batch = tuple(t.to(device) for t in contra_batch)
                _, input_ids, input_mask, segment_ids, label_ids = contra_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_contra, _ = roberta_model(input_ids, input_mask)
                kshot_contra_reps.append(last_hidden_contra)
            kshot_contra_rep = torch.mean(torch.cat(kshot_contra_reps, dim=0), dim=0, keepdim=True)

            class_prototype_reps = torch.cat([kshot_entail_rep, kshot_neural_rep, kshot_contra_rep], dim=0) #(3, hidden)

            '''forward to model'''
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)

            loss_fct = CrossEntropyLoss()

            loss = loss_fct(batch_logits.view(-1, source_num_labels), label_ids_batch.view(-1))

            if n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            iter_co+=1
            # if iter_co %20==0:
            if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                start evaluate on dev set after this epoch
                '''
                protonet.eval()
                '''first get representations for support examples'''
                kshot_entail_reps = []
                for entail_batch in target_kshot_entail_dataloader:
                    entail_batch = tuple(t.to(device) for t in entail_batch)
                    _, input_ids, input_mask, segment_ids, label_ids = entail_batch
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(input_ids, input_mask)
                    kshot_entail_reps.append(last_hidden_entail)
                kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0), dim=0, keepdim=True)
                kshot_nonentail_reps = []
                for nonentail_batch in target_kshot_nonentail_dataloader:
                    nonentail_batch = tuple(t.to(device) for t in nonentail_batch)
                    _, input_ids, input_mask, segment_ids, label_ids = nonentail_batch
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_nonentail, _ = roberta_model(input_ids, input_mask)
                    kshot_nonentail_reps.append(last_hidden_nonentail)
                kshot_nonentail_rep = torch.mean(torch.cat(kshot_nonentail_reps, dim=0), dim=0, keepdim=True)
                target_class_prototype_reps = torch.cat([kshot_entail_rep, kshot_nonentail_rep], dim=0) #(2, hidden)

                for idd, dev_or_test_dataloader in enumerate([target_dev_dataloader, target_test_dataloader]):

                    if idd == 0:
                        logger.info("***** Running dev *****")
                        logger.info("  Num examples = %d", len(target_dev_examples))
                    else:
                        logger.info("***** Running test *****")
                        logger.info("  Num examples = %d", len(target_test_examples))


                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    gold_pair_ids = []
                    for input_pair_ids, input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        gold_pair_ids+= list(input_pair_ids.numpy())
                        label_ids = label_ids.to(device)
                        gold_label_ids+=list(label_ids.detach().cpu().numpy())
                        roberta_model.eval()
                        with torch.no_grad():
                            last_hidden_target_batch, _ = roberta_model(input_ids, input_mask)

                        with torch.no_grad():
                            logits = protonet(target_class_prototype_reps, last_hidden_target_batch)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)
                    preds = preds[0]
                    pred_probs = list(softmax(preds,axis=1)[:,0]) #entail prob

                    assert len(gold_pair_ids) == len(pred_probs)
                    assert len(gold_pair_ids) == len(gold_label_ids)

                    pairID_2_predgoldlist = {}
                    for pair_id, prob, gold_id in zip(gold_pair_ids, pred_probs, gold_label_ids):
                        predgoldlist = pairID_2_predgoldlist.get(pair_id)
                        if predgoldlist is None:
                            predgoldlist = []
                        predgoldlist.append((prob, gold_id))
                        pairID_2_predgoldlist[pair_id] = predgoldlist
                    total_size = len(pairID_2_predgoldlist)
                    hit_size = 0
                    for pair_id, predgoldlist in pairID_2_predgoldlist.items():
                        predgoldlist.sort(key=lambda x:x[0]) #sort by prob
                        assert len(predgoldlist) == 4
                        if predgoldlist[-1][1] == 0:
                            hit_size+=1
                    test_acc= hit_size/total_size

                    if idd == 0: # this is dev
                        if test_acc > max_dev_acc:
                            max_dev_acc = test_acc
                            print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')

                        else:
                            print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                            break
                    else: # this is test
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t\t test acc:', test_acc, ' max_test_acc:', max_test_acc, '\n')

    print('final_test_performance:', final_test_performance)
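In the script above, each class prototype is the mean RoBERTa representation of its k-shot support examples, and PrototypeNet scores query representations against those prototypes. PrototypeNet itself is defined elsewhere in the repository; as a rough stand-in, a prototypical-network-style scorer can rank classes by negative squared Euclidean distance to the prototypes (a common choice, not necessarily the one used here):

import torch

def prototype_logits(class_prototypes, query_reps):
    """class_prototypes: (num_classes, hidden); query_reps: (batch, hidden).
    Returns (batch, num_classes) logits: the closer a prototype, the larger its logit."""
    # torch.cdist gives pairwise Euclidean distances of shape (batch, num_classes).
    return -torch.cdist(query_reps, class_prototypes) ** 2

prototypes = torch.randn(2, 1024)   # e.g. averaged entail / non-entail support reps
queries = torch.randn(8, 1024)
print(prototype_logits(prototypes, queries).shape)  # torch.Size([8, 2])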
Beispiel #14
0
class FinBert(object):
    """
    The main class for FinBERT.
    """
    def __init__(self, config):
        self.config = config

    def prepare_model(self, label_list):
        """
        Sets some of the components of the model: Dataset processor, number of labels, usage of gpu and distributed
        training, gradient accumulation steps and tokenizer.
        Parameters
        ----------
        label_list: list
            The list of labels values in the dataset. For example: ['positive','negative','neutral']
        """

        self.processors = {"finsent": FinSentProcessor}

        self.num_labels_task = {'finsent': 2}

        if self.config.local_rank == -1 or self.config.no_cuda:
            self.device = torch.device("cuda" if torch.cuda.is_available()
                                       and not self.config.no_cuda else "cpu")
            self.n_gpu = torch.cuda.device_count()
        else:
            torch.cuda.set_device(self.config.local_rank)
            self.device = torch.device("cuda", self.config.local_rank)
            self.n_gpu = 1
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            torch.distributed.init_process_group(backend='nccl')
        logger.info(
            "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}"
            .format(self.device, self.n_gpu,
                    bool(self.config.local_rank != -1), self.config.fp16))

        if self.config.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(self.config.gradient_accumulation_steps))

        self.config.train_batch_size = self.config.train_batch_size // self.config.gradient_accumulation_steps

        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(self.config.seed)

        if os.path.exists(self.config.model_dir) and os.listdir(
                self.config.model_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(self.config.model_dir))
        if not os.path.exists(self.config.model_dir):
            os.makedirs(self.config.model_dir)

        self.processor = self.processors['finsent']()
        self.num_labels = len(label_list)
        self.label_list = label_list

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model, do_lower_case=self.config.do_lower_case)

    def get_data(self, phase):
        """
        Gets the data for training or evaluation. It returns the data in the format that pytorch will process. In the
        data directory, there should be a .csv file with the name <phase>.csv
        Parameters
        ----------
        phase: str
            Name of the dataset that will be used in that phase. For example if there is a 'train.csv' in the data
            folder, it should be set to 'train'.
        Returns
        -------
        examples: list
            A list of InputExample's. Each InputExample is an object that holds the information for one example:
            text, id, label, etc.
        """

        self.num_train_optimization_steps = None
        examples = None
        examples = self.processor.get_examples(self.config.data_dir, phase)
        self.num_train_optimization_steps = int(
            len(examples) / self.config.train_batch_size / self.config.
            gradient_accumulation_steps) * self.config.num_train_epochs

        if phase == 'train':
            train = pd.read_csv(os.path.join(self.config.data_dir,
                                             'train.csv'),
                                sep='\t',
                                index_col=False)
            weights = list()
            labels = self.label_list

            class_weights = [
                train.shape[0] / train[train.label == label].shape[0]
                for label in labels
            ]
            self.class_weights = torch.tensor(class_weights)

        return examples
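get_data weights each label by the inverse of its relative frequency in train.csv (N_total / N_class), so under-represented classes count more when the weights are later handed to the loss. A small self-contained illustration of the same computation on a toy dataframe (only the column name `label` matches the code above; the data is invented):

import pandas as pd
import torch
from torch.nn import CrossEntropyLoss

train = pd.DataFrame({'label': ['positive'] * 60 + ['negative'] * 30 + ['neutral'] * 10})
labels = ['positive', 'negative', 'neutral']

# weight_c = N_total / N_c, the same ratio computed in get_data above
class_weights = torch.tensor(
    [train.shape[0] / train[train.label == label].shape[0] for label in labels],
    dtype=torch.float,
)
print(class_weights)  # tensor([ 1.6667,  3.3333, 10.0000])

loss_fct = CrossEntropyLoss(weight=class_weights)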

    def create_the_model(self):
        """
        Creates the model. Sets the model to be trained and the optimizer.
        """

        model = self.config.bert_model

        model.to(self.device)

        # Prepare optimizer
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        lr = self.config.learning_rate
        dft_rate = 1.2

        if self.config.discriminate:
            # apply the discriminative fine-tuning. discrimination rate is governed by dft_rate.

            encoder_params = []
            for i in range(12):
                encoder_decay = {
                    'params': [
                        p for n, p in list(
                            model.bert.encoder.layer[i].named_parameters())
                        if not any(nd in n for nd in no_decay)
                    ],
                    'weight_decay':
                    0.01,
                    'lr':
                    lr / (dft_rate**(12 - i))
                }
                encoder_nodecay = {
                    'params': [
                        p for n, p in list(
                            model.bert.encoder.layer[i].named_parameters())
                        if any(nd in n for nd in no_decay)
                    ],
                    'weight_decay':
                    0.0,
                    'lr':
                    lr / (dft_rate**(12 - i))
                }
                encoder_params.append(encoder_decay)
                encoder_params.append(encoder_nodecay)

            optimizer_grouped_parameters = [{
                'params': [
                    p
                    for n, p in list(model.bert.embeddings.named_parameters())
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01,
                'lr':
                lr / (dft_rate**13)
            }, {
                'params': [
                    p
                    for n, p in list(model.bert.embeddings.named_parameters())
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0,
                'lr':
                lr / (dft_rate**13)
            }, {
                'params': [
                    p for n, p in list(model.bert.pooler.named_parameters())
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01,
                'lr':
                lr
            }, {
                'params': [
                    p for n, p in list(model.bert.pooler.named_parameters())
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0,
                'lr':
                lr
            }, {
                'params': [
                    p for n, p in list(model.classifier.named_parameters())
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01,
                'lr':
                lr
            }, {
                'params': [
                    p for n, p in list(model.classifier.named_parameters())
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0,
                'lr':
                lr
            }]

            optimizer_grouped_parameters.extend(encoder_params)

        else:
            param_optimizer = list(model.named_parameters())

            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]

        schedule = "warmup_linear"

        self.num_warmup_steps = int(
            float(self.num_train_optimization_steps) *
            self.config.warm_up_proportion)

        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.config.learning_rate,
                               correct_bias=False)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_train_optimization_steps)

        return model
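create_the_model writes the discriminative fine-tuning groups out by hand: encoder layer i receives lr / dft_rate**(12 - i), the embeddings receive lr / dft_rate**13, and the pooler and classifier keep the base rate. The same grouping can be produced in a loop; the builder below is a generic sketch under those assumptions (it is not FinBERT's code and assumes a standard 12-layer model.bert with a model.classifier head):

import torch

def layerwise_lr_groups(model, base_lr, dft_rate=1.2, num_layers=12, weight_decay=0.01):
    """Parameter groups for AdamW with layer-wise learning-rate decay: layer i gets
    base_lr / dft_rate**(num_layers - i), the embeddings get one extra decay step,
    and pooler/classifier stay at base_lr."""
    no_decay = ('bias', 'LayerNorm.weight', 'LayerNorm.bias')

    def groups(named_params, lr):
        params = list(named_params)
        return [
            {'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay, 'lr': lr},
            {'params': [p for n, p in params if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0, 'lr': lr},
        ]

    grouped = groups(model.bert.embeddings.named_parameters(),
                     base_lr / dft_rate ** (num_layers + 1))
    for i in range(num_layers):
        grouped += groups(model.bert.encoder.layer[i].named_parameters(),
                          base_lr / dft_rate ** (num_layers - i))
    grouped += groups(model.bert.pooler.named_parameters(), base_lr)
    grouped += groups(model.classifier.named_parameters(), base_lr)
    return grouped

# optimizer = torch.optim.AdamW(layerwise_lr_groups(model, base_lr=2e-5))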

    def get_loader(self, examples, phase):
        """
        Creates a data loader object for a dataset.
        Parameters
        ----------
        examples: list
            The list of InputExample's.
        phase: 'train' or 'eval'
            Determines whether to use random sampling or sequential sampling depending on the phase.
        Returns
        -------
        dataloader: DataLoader
            The data loader object.
        """

        features = convert_examples_to_features(examples, self.label_list,
                                                self.config.max_seq_length,
                                                self.tokenizer,
                                                self.config.output_mode)

        # Log the necessary information
        logger.info("***** Loading data *****")
        logger.info("  Num examples = %d", len(examples))
        logger.info("  Batch size = %d", self.config.train_batch_size)
        logger.info("  Num steps = %d", self.num_train_optimization_steps)

        # Load the data, make it into TensorDataset
        all_input_ids = torch.tensor([f.input_ids for f in features],
                                     dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features],
                                          dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features],
                                          dtype=torch.long)

        if self.config.output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in features],
                                         dtype=torch.long)
        elif self.config.output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in features],
                                         dtype=torch.float)

        try:
            all_agree_ids = torch.tensor([f.agree for f in features],
                                         dtype=torch.long)
        except AttributeError:
            # Features without agreement annotations fall back to zeros.
            all_agree_ids = torch.tensor([0 for f in features],
                                         dtype=torch.long)

        data = TensorDataset(all_input_ids, all_attention_mask,
                             all_token_type_ids, all_label_ids, all_agree_ids)

        # Distributed, if necessary
        if phase == 'train':
            my_sampler = RandomSampler(data)
        elif phase == 'eval':
            my_sampler = SequentialSampler(data)

        dataloader = DataLoader(data,
                                sampler=my_sampler,
                                batch_size=self.config.train_batch_size)
        return dataloader

    def train(self, train_examples, model):
        """
        Trains the model.
        Parameters
        ----------
        train_examples: list
            Contains the training data as a list of InputExample's.
        model: BertModel
            The Bert model to be trained. Class weights computed in get_data
            (self.class_weights) are applied to the loss.
        Returns
        -------
        model: BertModel
            The trained model.
        """

        validation_examples = self.get_data('validation')

        global_step = 0

        self.validation_losses = []

        # Training
        train_dataloader = self.get_loader(train_examples, 'train')

        model.train()

        step_number = len(train_dataloader)

        i = 0
        for _ in trange(int(self.config.num_train_epochs), desc="Epoch"):

            model.train()

            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc='Iteration')):

                if (self.config.gradual_unfreeze and i == 0):
                    for param in model.bert.parameters():
                        param.requires_grad = False

                if (step % (step_number // 3)) == 0:
                    i += 1

                if (self.config.gradual_unfreeze and i > 1
                        and i < self.config.encoder_no):

                    for k in range(i - 1):

                        try:
                            for param in model.bert.encoder.layer[
                                    self.config.encoder_no - 1 -
                                    k].parameters():
                                param.requires_grad = True
                        except:
                            pass

                if (self.config.gradual_unfreeze
                        and i > self.config.encoder_no + 1):
                    for param in model.bert.embeddings.parameters():
                        param.requires_grad = True

                batch = tuple(t.to(self.device) for t in batch)

                input_ids, attention_mask, token_type_ids, label_ids, agree_ids = batch

                logits = model(input_ids, attention_mask, token_type_ids)[0]
                weights = self.class_weights.to(self.device)

                if self.config.output_mode == "classification":
                    loss_fct = CrossEntropyLoss(weight=weights)
                    loss = loss_fct(logits.view(-1, self.num_labels),
                                    label_ids.view(-1))
                elif self.config.output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if self.config.gradient_accumulation_steps > 1:
                    loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    if self.config.fp16:
                        lr_this_step = self.config.learning_rate * warmup_linear(
                            global_step / self.num_train_optimization_steps,
                            self.config.warm_up_proportion)
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                    global_step += 1

            # Validation

            validation_loader = self.get_loader(validation_examples,
                                                phase='eval')
            model.eval()

            valid_loss, valid_accuracy = 0, 0
            nb_valid_steps, nb_valid_examples = 0, 0

            for input_ids, attention_mask, token_type_ids, label_ids, agree_ids in tqdm(
                    validation_loader, desc="Validating"):
                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                token_type_ids = token_type_ids.to(self.device)
                label_ids = label_ids.to(self.device)
                agree_ids = agree_ids.to(self.device)

                with torch.no_grad():
                    logits = model(input_ids, attention_mask,
                                   token_type_ids)[0]

                    if self.config.output_mode == "classification":
                        loss_fct = CrossEntropyLoss(weight=weights)
                        tmp_valid_loss = loss_fct(
                            logits.view(-1, self.num_labels),
                            label_ids.view(-1))
                    elif self.config.output_mode == "regression":
                        loss_fct = MSELoss()
                        tmp_valid_loss = loss_fct(logits.view(-1),
                                                  label_ids.view(-1))

                    valid_loss += tmp_valid_loss.mean().item()

                    nb_valid_steps += 1

            valid_loss = valid_loss / nb_valid_steps

            self.validation_losses.append(valid_loss)
            print("Validation losses: {}".format(self.validation_losses))

            if valid_loss == min(self.validation_losses):

                try:
                    os.remove(self.config.model_dir /
                              ('temporary' + str(best_model)))
                except:
                    print('No best model found')
                torch.save({
                    'epoch': str(i),
                    'state_dict': model.state_dict()
                }, self.config.model_dir / ('temporary' + str(i)))
                best_model = i

        # Save a trained model and the associated configuration
        checkpoint = torch.load(self.config.model_dir /
                                ('temporary' + str(best_model)))
        model.load_state_dict(checkpoint['state_dict'])
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(self.config.model_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(self.config.model_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())
        os.remove(self.config.model_dir / ('temporary' + str(best_model)))
        return model
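train implements gradual unfreezing: the whole BERT body starts frozen, and as training progresses the top encoder layers are re-enabled one by one, with the embeddings last. The helper below is a compact standalone sketch of that schedule (the function name and the exact staging are illustrative, not FinBERT's verbatim logic):

def gradually_unfreeze(model, stage, num_encoder_layers=12):
    """Stage 0 freezes the whole BERT body; each later stage re-enables one more
    encoder layer counting from the top; once every layer is trainable, the
    embeddings are unfrozen as well."""
    if stage == 0:
        for param in model.bert.parameters():
            param.requires_grad = False
        return
    for k in range(min(stage, num_encoder_layers)):
        for param in model.bert.encoder.layer[num_encoder_layers - 1 - k].parameters():
            param.requires_grad = True
    if stage > num_encoder_layers:
        for param in model.bert.embeddings.parameters():
            param.requires_grad = True

# Called once per unfreeze step, e.g. three times per epoch as in the loop above:
# gradually_unfreeze(model, stage)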

    def evaluate(self, model, examples):
        """
        Evaluate the model.
        Parameters
        ----------
        model: BertModel
            The model to be evaluated.
        examples: list
            Evaluation data as a list of InputExample's.
        Returns
        -------
        evaluation_df: pd.DataFrame
            A dataframe that includes for each example predicted probability and labels.
        """

        eval_loader = self.get_loader(examples, phase='eval')

        logger.info("***** Running evaluation ***** ")
        logger.info("  Num examples = %d", len(examples))
        logger.info("  Batch size = %d", self.config.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        predictions = []
        labels = []
        agree_levels = []
        text_ids = []

        for input_ids, attention_mask, token_type_ids, label_ids, agree_ids in tqdm(
                eval_loader, desc="Testing"):
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            token_type_ids = token_type_ids.to(self.device)
            label_ids = label_ids.to(self.device)
            agree_ids = agree_ids.to(self.device)

            with torch.no_grad():
                logits = model(input_ids, attention_mask, token_type_ids)[0]

                if self.config.output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    tmp_eval_loss = loss_fct(logits.view(-1, self.num_labels),
                                             label_ids.view(-1))
                elif self.config.output_mode == "regression":
                    loss_fct = MSELoss()
                    tmp_eval_loss = loss_fct(logits.view(-1),
                                             label_ids.view(-1))

                np_logits = logits.cpu().numpy()

                # Both classification and regression keep the raw model outputs as the prediction.
                prediction = np.array(np_logits)

                for agree_id in agree_ids:
                    agree_levels.append(agree_id.item())

                for label_id in label_ids:
                    labels.append(label_id.item())

                for pred in prediction:
                    predictions.append(pred)

                text_ids.append(input_ids)

                # tmp_eval_loss = loss_fct(logits.view(-1, self.num_labels), label_ids.view(-1))
                # tmp_eval_loss = model(input_ids, token_type_ids, attention_mask, label_ids)

                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1

            # logits = logits.detach().cpu().numpy()
            # label_ids = label_ids.to('cpu').numpy()
            # tmp_eval_accuracy = accuracy(logits, label_ids)

            # eval_loss += tmp_eval_loss.mean().item()
            # eval_accuracy += tmp_eval_accuracy

        evaluation_df = pd.DataFrame({
            'predictions': predictions,
            'labels': labels,
            "agree_levels": agree_levels
        })

        return evaluation_df
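evaluate returns the raw logits in the predictions column and leaves label extraction to the caller. One possible post-processing step, shown on an invented stand-in for that dataframe, uses scipy's softmax and numpy's argmax:

import numpy as np
import pandas as pd
from scipy.special import softmax

# Toy stand-in for the dataframe returned by evaluate(): raw 3-class logits plus gold labels.
evaluation_df = pd.DataFrame({
    'predictions': [np.array([2.0, 0.1, -1.0]), np.array([-0.5, 1.5, 0.2])],
    'labels': [0, 2],
})

probs = softmax(np.vstack(evaluation_df['predictions'].values), axis=1)
pred_labels = probs.argmax(axis=1)
accuracy = (pred_labels == evaluation_df['labels'].values).mean()
print(pred_labels, accuracy)  # [0 1] 0.5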
Beispiel #15
0
class Trainer(object):
    def __init__(self, proto, stage="train"):
        # model config
        model_cfg = proto["model"]
        model_name = model_cfg["name"]
        self.model_name = model_name

        # dataset config
        data_cfg = proto["data"]
        train_data_path = data_cfg.get("train_path", None)
        val_data_path = data_cfg.get("val_path", None)
        pad = data_cfg.get("pad", 32)
        train_bs = data_cfg.get("train_batch_size", None)
        val_bs = data_cfg.get("val_batch_size", None)
        self.val_bs = val_bs
        self.skip_first = data_cfg.get("skip_first", False)
        self.delimiter = data_cfg.get("delimiter", "\t")

        # assorted config
        optim_cfg = proto.get("optimizer", {"lr": 0.00003})
        sched_cfg = proto.get("schedulers", None)
        loss = proto.get("loss", "CE")
        self.device = proto.get("device", None)

        model_cfg.pop("name")

        if torch.cuda.is_available() and self.device is not None:
            print("Using device: %d." % self.device)
            self.device = torch.device(self.device)
            self.gpu = True
        else:
            print("Using cpu device.")
            self.device = torch.device("cpu")
            self.gpu = False

        if stage == "train":
            if train_data_path is None or val_data_path is None:
                raise ValueError("Please specify both train and val data path.")
            if train_bs is None or val_bs is None:
                raise ValueError("Please specify both train and val batch size.")
            # loading model
            self.model = fetch_nn(model_name)(**model_cfg)
            self.model = self.model.cuda(self.device)

            # loading dataset and converting into dataloader
            self.train_data = ChineseTextSet(
                path=train_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.train_loader = DataLoader(
                self.train_data, train_bs, shuffle=True, num_workers=4)
            self.val_data = ChineseTextSet(
                path=val_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.val_loader = DataLoader(
                self.val_data, val_bs, shuffle=True, num_workers=4)

            time_format = "%Y-%m-%d...%H.%M.%S"
            id = time.strftime(time_format, time.localtime(time.time()))
            self.record_path = os.path.join(arg.record, model_name, id)

            os.makedirs(self.record_path)
            sys.stdout = Logger(os.path.join(self.record_path, 'records.txt'))
            print("Writing proto file to file directory: %s." % self.record_path)
            yaml.dump(proto, open(os.path.join(self.record_path, 'protocol.yml'), 'w'))

            print("*" * 25, " PROTO BEGINS ", "*" * 25)
            pprint(proto)
            print("*" * 25, " PROTO ENDS ", "*" * 25)

            self.optimizer = AdamW(self.model.parameters(), **optim_cfg)
            self.scheduler = fetch_scheduler(self.optimizer, sched_cfg)

            self.loss = fetch_loss(loss)

            self.best_f1 = 0.0
            self.best_step = 1
            self.start_step = 1

            self.num_steps = proto["num_steps"]
            self.num_epoch = math.ceil(self.num_steps / len(self.train_loader))

            # the number of steps to write down a log
            self.log_steps = proto["log_steps"]
            # the number of steps to validate on val dataset once
            self.val_steps = proto["val_steps"]

            self.f1_meter = AverageMeter()
            self.p_meter = AverageMeter()
            self.r_meter = AverageMeter()
            self.acc_meter = AverageMeter()
            self.loss_meter = AverageMeter()

        if stage == "test":
            if val_data_path is None:
                raise ValueError("Please specify the val data path.")
            if val_bs is None:
                raise ValueError("Please specify the val batch size.")
            id = proto["id"]
            ckpt_fold = proto.get("ckpt_fold", "runs")
            self.record_path = os.path.join(ckpt_fold, model_name, id)
            sys.stdout = Logger(os.path.join(self.record_path, 'tests.txt'))

            config, state_dict, fc_dict = self._load_ckpt(best=True, train=False)
            weights = {"config": config, "state_dict": state_dict}
            # loading trained model using config and state_dict
            self.model = fetch_nn(model_name)(weights=weights)
            # loading the weights for the final fc layer
            self.model.load_state_dict(fc_dict, strict=False)
            # loading model to gpu device if specified
            if self.gpu:
                self.model = self.model.cuda(self.device)

            print("Testing directory: %s." % self.record_path)
            print("*" * 25, " PROTO BEGINS ", "*" * 25)
            pprint(proto)
            print("*" * 25, " PROTO ENDS ", "*" * 25)

            self.val_path = val_data_path
            self.test_data = ChineseTextSet(
                path=val_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.test_loader = DataLoader(
                self.test_data, val_bs, shuffle=True, num_workers=4)

    def _save_ckpt(self, step, best=False, f=None, p=None, r=None):
        save_dir = os.path.join(self.record_path, "best_model.bin" if best else "latest_model.bin")
        torch.save({
            "step": step,
            "f1": f,
            "precision": p,
            "recall": r,
            "best_step": self.best_step,
            "best_f1": self.best_f1,
            "model": self.model.state_dict(),
            "config": self.model.config,
            "optimizer": self.optimizer.state_dict(),
            "schedulers": self.scheduler.state_dict(),
        }, save_dir)

    def _load_ckpt(self, best=False, train=False):
        load_dir = os.path.join(self.record_path, "best_model.bin" if best else "latest_model.bin")
        load_dict = torch.load(load_dir, map_location=self.device)
        self.start_step = load_dict["step"]
        self.best_step = load_dict["best_step"]
        self.best_f1 = load_dict["best_f1"]
        if train:
            self.optimizer.load_state_dict(load_dict["optimizer"])
            self.scheduler.load_state_dict(load_dict["schedulers"])
        print("Loading checkpoint from %s, best step: %d, best f1: %.4f."
              % (load_dir, self.best_step, self.best_f1))
        if not best:
            print("Checkpoint step %s, f1: %.4f, precision: %.4f, recall: %.4f."
                  % (self.start_step, load_dict["f1"],
                     load_dict["precision"], load_dict["recall"]))
        fc_dict = {
            "fc.weight": load_dict["model"]["fc.weight"],
            "fc.bias": load_dict["model"]["fc.bias"]
        }
        return load_dict["config"], load_dict["model"], fc_dict
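_save_ckpt and _load_ckpt persist the model, optimizer and scheduler state together so a run can resume where it stopped. A minimal round-trip with the same idea, on a toy model and an invented file name:

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save everything needed to resume training.
torch.save({
    "step": 100,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}, "latest_model.bin")

# Restore into objects of the same shape.
ckpt = torch.load("latest_model.bin", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print("resumed from step", ckpt["step"])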

    def to_cuda(self, *args):
        return [obj.cuda(self.device) for obj in args]

    @staticmethod
    def fixed_randomness():
        random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @staticmethod
    def update_metrics(gt, pre, f1_m, p_m, r_m, acc_m):
        f1_value = f1(gt, pre, average="micro")
        f1_m.update(f1_value)
        p_value = precision(gt, pre, average="micro", zero_division=0)
        p_m.update(p_value)
        r_value = recall(gt, pre, average="micro")
        r_m.update(r_value)
        acc_value = accuracy(gt, pre)
        acc_m.update(acc_value)

    def train(self):
        timer = Timer()
        writer = SummaryWriter(self.record_path)
        print("*" * 25, " TRAINING BEGINS ", "*" * 25)
        start_epoch = self.start_step // len(self.train_loader) + 1
        for epoch_idx in range(start_epoch, self.num_epoch + 1):
            self.f1_meter.reset()
            self.p_meter.reset()
            self.r_meter.reset()
            self.acc_meter.reset()
            self.loss_meter.reset()
            self.optimizer.step()
            self.scheduler.step()
            train_generator = tqdm(enumerate(self.train_loader, 1), position=0, leave=True)

            for batch_idx, data in train_generator:
                global_step = (epoch_idx - 1) * len(self.train_loader) + batch_idx
                self.model.train()
                id, label, _, mask = data[:4]
                if self.gpu:
                    id, mask, label = self.to_cuda(id, mask, label)
                pre = self.model((id, mask))
                loss = self.loss(pre, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(
                    lbl, yp, self.f1_meter, self.p_meter,
                    self.r_meter, self.acc_meter
                )
                self.loss_meter.update(loss.item())

                if global_step % self.log_steps == 0 and writer is not None:
                    writer.add_scalar("train/f1", self.f1_meter.avg, global_step)
                    writer.add_scalar("train/loss", self.loss_meter.avg, global_step)
                    writer.add_scalar("train/lr", self.scheduler.get_lr()[0], global_step)

                train_generator.set_description(
                    "Train Epoch %d (%d/%d), "
                    "Global Step %d, Loss %.4f, f1 %.4f, p %.4f, r %.4f, acc %.4f, LR %.6f" % (
                        epoch_idx, batch_idx, len(self.train_loader), global_step,
                        self.loss_meter.avg, self.f1_meter.avg,
                        self.p_meter.avg, self.r_meter.avg,
                        self.acc_meter.avg,
                        self.scheduler.get_lr()[0]
                    )
                )

                # validating process
                if global_step % self.val_steps == 0:
                    print()
                    self.validate(epoch_idx, global_step, timer, writer)

                # stop once the specified num_steps has been reached, even if
                # fewer than num_epoch epochs have completed.
                if self.num_steps is not None and global_step == self.num_steps:
                    if writer is not None:
                        writer.close()
                    print()
                    print("*" * 25, " TRAINING ENDS ", "*" * 25)
                    return

            train_generator.close()
            print()
        writer.close()
        print("*" * 25, " TRAINING ENDS ", "*" * 25)

    def validate(self, epoch, step, timer, writer):
        with torch.no_grad():
            f1_meter = AverageMeter()
            p_meter = AverageMeter()
            r_meter = AverageMeter()
            acc_meter = AverageMeter()
            loss_meter = AverageMeter()
            val_generator = tqdm(enumerate(self.val_loader, 1), position=0, leave=True)
            for val_idx, data in val_generator:
                self.model.eval()
                id, label, _, mask = data[:4]
                if self.gpu:
                    id, mask, label = self.to_cuda(id, mask, label)
                pre = self.model((id, mask))
                loss = self.loss(pre, label)

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, acc_meter)
                loss_meter.update(loss.item())

                val_generator.set_description(
                    "Eval Epoch %d (%d/%d), Global Step %d, Loss %.4f, "
                    "f1 %.4f, p %.4f, r %.4f, acc %.4f" % (
                        epoch, val_idx, len(self.val_loader), step,
                        loss_meter.avg, f1_meter.avg,
                        p_meter.avg, r_meter.avg, acc_meter.avg
                    )
                )

            print("Eval Epoch %d, f1 %.4f" % (epoch, f1_meter.avg))
            if writer is not None:
                writer.add_scalar("val/loss", loss_meter.avg, step)
                writer.add_scalar("val/f1", f1_meter.avg, step)
                writer.add_scalar("val/precision", p_meter.avg, step)
                writer.add_scalar("val/recall", r_meter.avg, step)
                writer.add_scalar("val/acc", acc_meter.avg, step)
            if f1_meter.avg > self.best_f1:
                self.best_f1 = f1_meter.avg
                self.best_step = step
                self._save_ckpt(step, best=True)
            print("Best Step %d, Best f1 %.4f, Running Time: %s, Estimated Time: %s" % (
                self.best_step, self.best_f1, timer.measure(), timer.measure(step / self.num_steps)
            ))
            self._save_ckpt(step, best=False, f=f1_meter.avg, p=p_meter.avg, r=r_meter.avg)

    def test(self):
        # t_idx = random.randint(0, self.val_bs)
        t_idx = random.randint(0, 5)
        with torch.no_grad():
            self.fixed_randomness()  # for reproduction

            # for writing the total predictions to disk
            data_idxs = list()
            all_preds = list()

            # for ploting P-R Curve
            predicts = list()
            truths = list()

            # for showing predicted samples
            show_ctxs = list()
            pred_lbls = list()
            targets = list()

            f1_meter = AverageMeter()
            p_meter = AverageMeter()
            r_meter = AverageMeter()
            accuracy_meter = AverageMeter()
            test_generator = tqdm(enumerate(self.test_loader, 1))
            for idx, data in test_generator:
                self.model.eval()
                id, label, _, mask, data_idx = data
                if self.gpu:
                    id, mask, label = self.to_cuda(id, mask, label)
                pre = self.model((id, mask))

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, accuracy_meter)

                test_generator.set_description(
                    "Test %d/%d, f1 %.4f, p %.4f, r %.4f, acc %.4f"
                    % (idx, len(self.test_loader), f1_meter.avg,
                       p_meter.avg, r_meter.avg, accuracy_meter.avg)
                )

                data_idxs.append(data_idx.numpy())
                all_preds.append(yp)

                predicts.append(torch.select(pre, dim=1, index=1).cpu().numpy())
                truths.append(lbl)

                # show some of the sample
                ctx = torch.select(id, dim=0, index=t_idx).detach()
                ctx = self.model.tokenizer.convert_ids_to_tokens(ctx)
                ctx = "".join([_ for _ in ctx if _ not in [PAD, CLS]])
                yp = yp[t_idx]
                lbl = lbl[t_idx]

                show_ctxs.append(ctx)
                pred_lbls.append(yp)
                targets.append(lbl)

            print("*" * 25, " SAMPLE BEGINS ", "*" * 25)
            for c, t, l in zip(show_ctxs, targets, pred_lbls):
                print("ctx: ", c, " gt: ", t, " est: ", l)
            print("*" * 25, " SAMPLE ENDS ", "*" * 25)
            print("Test, FINAL f1 %.4f, "
                  "p %.4f, r %.4f, acc %.4f\n" %
                  (f1_meter.avg, p_meter.avg, r_meter.avg, accuracy_meter.avg))

            # output the final results to disk
            data_idxs = np.concatenate(data_idxs, axis=0)
            all_preds = np.concatenate(all_preds, axis=0)
            write_predictions(
                self.val_path, os.path.join(self.record_path, "results.txt"),
                data_idxs, all_preds, delimiter=self.delimiter, skip_first=self.skip_first
            )

            # output the p-r values for future plotting P-R Curve
            predicts = np.concatenate(predicts, axis=0)
            truths = np.concatenate(truths, axis=0)
            values = precision_recall_curve(truths, predicts)
            with open(os.path.join(self.record_path, "pr.values"), "wb") as f:
                pickle.dump(values, f)
            p_value, r_value, _ = values

            # plot P-R Curve if specified
            if arg.image:
                plt.figure()
                plt.plot(
                    p_value, r_value,
                    label="%s (ACC: %.2f, F1: %.2f)"
                          % (self.model_name, accuracy_meter.avg, f1_meter.avg)
                )
                plt.legend(loc="best")
                plt.title("2-Classes P-R curve")
                plt.xlabel("precision")
                plt.ylabel("recall")
                plt.savefig(os.path.join(self.record_path, "P-R.png"))
                plt.show()
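
AverageMeter and Timer are helpers of this trainer that are not shown in the snippet above. A minimal running-average tracker matching the interface used here (update(value) and .avg) might look like the following sketch; the project's own implementation may differ.

class AverageMeter:
    """Minimal sketch of a running-average tracker (assumed interface: update() and .avg)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # accumulate `val` weighted by `n` observations and refresh the running mean
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
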
Beispiel #16
0
def main(args):
    if args.large:
        args.train_record_file += '_large'
        args.dev_eval_file += '_large'
        model_name = "albert-xlarge-v2"
    else:
        model_name = "albert-base-v2"
    if args.xxlarge:
        args.train_record_file += '_xxlarge'
        args.dev_eval_file += '_xxlarge'
        model_name = "albert-xxlarge-v2"
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get model
    log.info('Building model...')
    if args.bidaf:
        char_vectors = util.torch_from_json(args.char_emb_file)

    if args.model_name == 'albert_highway':
        model = models.albert_highway(model_name)
    elif args.model_name == 'albert_lstm_highway':
        model = models.LSTM_highway(model_name, hidden_size=args.hidden_size)
    elif args.model_name == 'albert_bidaf':
        model = models.BiDAF(char_vectors=char_vectors,
                             hidden_size=args.hidden_size,
                             drop_prob=args.drop_prob)
    elif args.model_name == 'albert_bidaf2':
        model = models.BiDAF2(model_name=model_name,
                              char_vectors=char_vectors,
                              hidden_size=args.hidden_size,
                              drop_prob=args.drop_prob)
    else:
        model = AlbertForQuestionAnswering.from_pretrained(args.model_name)

    model = nn.DataParallel(model, args.gpu_ids)
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Get data loader
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2,
                          args.bidaf)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    dev_dataset = SQuAD(args.dev_eval_file, args.use_squad_v2, args.bidaf)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    with open(args.dev_gold_file) as f:
        gold_dict = json.load(f)

    tokenizer = AlbertTokenizer.from_pretrained(model_name)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for batch in train_loader:
                batch = tuple(t.to(device) for t in batch)
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                    'start_positions': batch[3],
                    'end_positions': batch[4],
                }
                if args.bidaf:
                    inputs['char_ids'] = batch[6]
                y1 = batch[3]
                y2 = batch[4]
                # Setup for forward
                batch_size = inputs["input_ids"].size(0)
                optimizer.zero_grad()

                # Forward
                # log_p1, log_p2 = model(**inputs)
                y1, y2 = y1.to(device), y2.to(device)
                outputs = model(**inputs)
                loss = outputs[0]
                loss = loss.mean()
                # loss_fct = nn.CrossEntropyLoss()
                # loss = loss_fct(log_p1, y1) + loss_fct(log_p2, y2)
                # loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR', optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint
                    log.info(f'Evaluating at step {step}...')
                    ema.assign(model)
                    results, pred_dict = evaluate(args, model, dev_dataset,
                                                  dev_loader, gold_dict,
                                                  tokenizer, device,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join(f'{k}: {v:05.2f}'
                                            for k, v in results.items())
                    log.info(f'Dev {results_str}')

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar(f'dev/{k}', v, step)
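
The example above deliberately keeps a constant learning rate (a LambdaLR that always returns 1.0). A common alternative when fine-tuning ALBERT with AdamW is linear warmup followed by linear decay; below is a minimal sketch using the transformers scheduler helper, with placeholder step counts and a stand-in module rather than the actual model from this example.

import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = nn.Linear(8, 2)           # stand-in module; the real code would pass the ALBERT model
num_training_steps = 10_000       # placeholder: total number of optimizer steps
num_warmup_steps = int(0.1 * num_training_steps)

optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
# per batch: loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()
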
Beispiel #17
0
def main():
    parser = argparse.ArgumentParser()

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--target_train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    scitail_path = '/export/home/Dataset/SciTailV1/tsv_format/'
    target_kshot_entail_examples, target_kshot_nonentail_examples = get_SciTail_as_train_k_shot(
        scitail_path + 'scitail_1.0_train.tsv', args.kshot,
        args.seed)  #train_pu_half_v1.txt
    target_dev_examples, target_test_examples = get_SciTail_dev_and_test(
        scitail_path + 'scitail_1.0_dev.tsv',
        scitail_path + 'scitail_1.0_test.tsv')

    system_seed = 42
    random.seed(system_seed)
    np.random.seed(system_seed)
    torch.manual_seed(system_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(system_seed)

    source_kshot_size = 10  # fixed at 10; could alternatively be max(10, args.kshot)
    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv', source_kshot_size)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["entails", "neutral"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:',
          len(target_dev_examples), 'test size:', len(target_test_examples))

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                                  strict=False)
    roberta_model.to(device)
    roberta_model.eval()

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(
        source_kshot_entail,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(
        source_kshot_neural,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(
        source_kshot_contra,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(
        source_remaining_examples,
        source_label_list,
        args,
        tokenizer,
        args.train_batch_size,
        "classification",
        dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(
        target_kshot_entail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(
        target_kshot_nonentail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples,
                                                 target_label_list,
                                                 args,
                                                 tokenizer,
                                                 args.eval_batch_size,
                                                 "classification",
                                                 dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples,
                                                  target_label_list,
                                                  args,
                                                  tokenizer,
                                                  args.eval_batch_size,
                                                  "classification",
                                                  dataloader_mode='sequential')
    '''
    retrieve rep for support examples in MNLI
    '''
    kshot_entail_reps = []
    for entail_batch in source_kshot_entail_dataloader:
        entail_batch = tuple(t.to(device) for t in entail_batch)
        input_ids, input_mask, segment_ids, label_ids = entail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_entail, _ = roberta_model(input_ids, input_mask)
        kshot_entail_reps.append(last_hidden_entail)
    kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0),
                                  dim=0,
                                  keepdim=True)
    kshot_neural_reps = []
    for neural_batch in source_kshot_neural_dataloader:
        neural_batch = tuple(t.to(device) for t in neural_batch)
        input_ids, input_mask, segment_ids, label_ids = neural_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_neural, _ = roberta_model(input_ids, input_mask)
        kshot_neural_reps.append(last_hidden_neural)
    kshot_neural_rep = torch.mean(torch.cat(kshot_neural_reps, dim=0),
                                  dim=0,
                                  keepdim=True)
    kshot_contra_reps = []
    for contra_batch in source_kshot_contra_dataloader:
        contra_batch = tuple(t.to(device) for t in contra_batch)
        input_ids, input_mask, segment_ids, label_ids = contra_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_contra, _ = roberta_model(input_ids, input_mask)
        kshot_contra_reps.append(last_hidden_contra)
    kshot_contra_rep = torch.mean(torch.cat(kshot_contra_reps, dim=0),
                                  dim=0,
                                  keepdim=True)

    source_class_prototype_reps = torch.cat(
        [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
        dim=0)  #(3, hidden)
    '''first get representations for support examples in target'''
    kshot_entail_reps = []
    for entail_batch in target_kshot_entail_dataloader:
        entail_batch = tuple(t.to(device) for t in entail_batch)
        input_ids, input_mask, segment_ids, label_ids = entail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_entail, _ = roberta_model(input_ids, input_mask)
        kshot_entail_reps.append(last_hidden_entail)
    all_kshot_entail_reps = torch.cat(kshot_entail_reps, dim=0)
    kshot_entail_rep = torch.mean(all_kshot_entail_reps, dim=0, keepdim=True)
    kshot_nonentail_reps = []
    for nonentail_batch in target_kshot_nonentail_dataloader:
        nonentail_batch = tuple(t.to(device) for t in nonentail_batch)
        input_ids, input_mask, segment_ids, label_ids = nonentail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_nonentail, _ = roberta_model(input_ids, input_mask)
        kshot_nonentail_reps.append(last_hidden_nonentail)
    all_kshot_neural_reps = torch.cat(kshot_nonentail_reps, dim=0)
    kshot_nonentail_rep = torch.mean(all_kshot_neural_reps,
                                     dim=0,
                                     keepdim=True)
    target_class_prototype_reps = torch.cat(
        [kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep],
        dim=0)  #(3, hidden)

    class_prototype_reps = torch.cat(
        [source_class_prototype_reps, target_class_prototype_reps],
        dim=0)  #(6, hidden)
    '''starting to train'''
    iter_co = 0
    tr_loss = 0
    source_loss = 0
    target_loss = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):

        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, source_label_ids_batch = batch

            roberta_model.eval()
            with torch.no_grad():
                source_last_hidden_batch, _ = roberta_model(
                    input_ids, input_mask)
            '''forward to model'''
            target_batch_size = args.target_train_batch_size  #10*3
            target_batch_size_entail = target_batch_size  #random.randrange(5)+1
            target_batch_size_neural = target_batch_size  #random.randrange(5)+1

            selected_target_entail_rep = all_kshot_entail_reps[torch.randperm(
                all_kshot_entail_reps.shape[0])[:target_batch_size_entail]]
            selected_target_neural_rep = all_kshot_neural_reps[torch.randperm(
                all_kshot_neural_reps.shape[0])[:target_batch_size_neural]]
            target_last_hidden_batch = torch.cat(
                [selected_target_entail_rep, selected_target_neural_rep])

            last_hidden_batch = torch.cat(
                [source_last_hidden_batch, target_last_hidden_batch],
                dim=0)  #(train_batch_size+10*2)
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)
            '''source side loss'''
            # loss_fct = CrossEntropyLoss(reduction='none')
            loss_fct = CrossEntropyLoss()
            source_loss_list = loss_fct(
                batch_logits[:source_last_hidden_batch.shape[0]].view(
                    -1, source_num_labels), source_label_ids_batch.view(-1))
            '''target side loss'''
            target_label_ids_batch = torch.tensor(
                [0] * selected_target_entail_rep.shape[0] +
                [1] * selected_target_neural_rep.shape[0],
                dtype=torch.long)
            target_batch_logits = batch_logits[-target_last_hidden_batch.
                                               shape[0]:]
            target_loss_list = loss_by_logits_and_2way_labels(
                target_batch_logits, target_label_ids_batch.view(-1), device)

            loss = source_loss_list + target_loss_list  #torch.mean(torch.cat([source_loss_list, target_loss_list]))
            source_loss += source_loss_list
            target_loss += target_loss_list
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            global_step += 1
            iter_co += 1
            '''print loss'''
            # if iter_co %5==0:
            #     print('iter_co:', iter_co, ' mean loss', tr_loss/iter_co)
            #     print('source_loss_list:', source_loss/iter_co, ' target_loss_list: ', target_loss/iter_co)
            if iter_co % 1 == 0:
                # if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                evaluate on the dev and test sets (currently after every
                iteration; the commented check above would run it once per epoch)
                '''
                protonet.eval()

                for idd, dev_or_test_dataloader in enumerate(
                    [target_dev_dataloader, target_test_dataloader]):

                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    # print('Evaluating...')
                    for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)
                        gold_label_ids += list(
                            label_ids.detach().cpu().numpy())
                        roberta_model.eval()
                        with torch.no_grad():
                            last_hidden_target_batch, logits_from_source = roberta_model(
                                input_ids, input_mask)

                        with torch.no_grad():
                            logits = protonet(class_prototype_reps,
                                              last_hidden_target_batch)
                        '''combine with logits from source domain'''
                        # print('logits:', logits)
                        # print('logits_from_source:', logits_from_source)
                        # weight = 0.9
                        # logits = weight*logits+(1.0-weight)*torch.sigmoid(logits_from_source)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]

                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = list(np.argmax(pred_probs, axis=1))
                    '''change from 3-way to 2-way'''
                    pred_label_ids = []
                    for pred_id in pred_label_ids_3way:
                        if pred_id != 0:
                            pred_label_ids.append(1)
                        else:
                            pred_label_ids.append(0)

                    assert len(pred_label_ids) == len(gold_label_ids)
                    hit_co = 0
                    for k in range(len(pred_label_ids)):
                        if pred_label_ids[k] == gold_label_ids[k]:
                            hit_co += 1
                    test_acc = hit_co / len(gold_label_ids)

                    if idd == 0:  # this is dev
                        if test_acc > max_dev_acc:
                            max_dev_acc = test_acc
                            print('\niter', iter_co, '\tdev acc:', test_acc,
                                  ' max_dev_acc:', max_dev_acc, '\n')

                        else:
                            print('\niter', iter_co, '\tdev acc:', test_acc,
                                  ' max_dev_acc:', max_dev_acc, '\n')
                            break
                    else:  # this is test
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\niter', iter_co, '\ttest acc:', test_acc,
                              ' max_test_acc:', max_test_acc, '\n')
            # if iter_co == 500:#3000:
            #     break
    print('final_test_performance:', final_test_performance)
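
PrototypeNet itself is not included in this snippet. A standard prototypical-network head scores each representation by its negative squared Euclidean distance to every class prototype; the sketch below is a hypothetical stand-in under that assumption, not the PrototypeNet used above (which takes bert_hidden_dim and may learn an additional transformation).

import torch
import torch.nn as nn

class PrototypeScorer(nn.Module):
    """Hypothetical stand-in: logits[i, c] = -||reps[i] - prototypes[c]||^2."""

    def forward(self, prototypes, reps):
        # prototypes: (num_classes, hidden), reps: (batch, hidden)
        # broadcast to (batch, num_classes, hidden) and reduce over the hidden dim
        diff = reps.unsqueeze(1) - prototypes.unsqueeze(0)
        return -(diff ** 2).sum(dim=-1)          # (batch, num_classes)

scorer = PrototypeScorer()
logits = scorer(torch.randn(6, 768), torch.randn(4, 768))   # e.g. 6 prototypes, batch of 4
print(logits.shape)                                          # torch.Size([4, 6])
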
Beispiel #18
0
class TRADE(nn.Module):
    def __init__(self,
                 hidden_size,
                 lang,
                 path,
                 task,
                 lr,
                 dropout,
                 slots,
                 gating_dict,
                 t_total,
                 device,
                 nb_train_vocab=0):
        super(TRADE, self).__init__()
        self.name = "TRADE"
        self.task = task
        self.hidden_size = hidden_size
        self.lang = lang[0]
        self.mem_lang = lang[1]
        self.lr = lr
        self.dropout = dropout
        self.slots = slots[0]
        self.slot_temp = slots[2]
        self.gating_dict = gating_dict
        self.device = device
        self.nb_gate = len(gating_dict)
        self.cross_entorpy = nn.CrossEntropyLoss()
        self.cell_type = args['cell_type']

        if args['encoder'] == 'RNN':
            self.encoder = EncoderRNN(self.lang.n_words, hidden_size,
                                      self.dropout, self.device,
                                      self.cell_type)
            self.decoder = Generator(self.lang, self.encoder.embedding,
                                     self.lang.n_words, hidden_size,
                                     self.dropout, self.slots, self.nb_gate,
                                     self.device, self.cell_type)
        elif args['encoder'] == 'TPRNN':
            self.encoder = EncoderTPRNN(self.lang.n_words, hidden_size,
                                        self.dropout, self.device,
                                        self.cell_type, args['nSymbols'],
                                        args['nRoles'], args['dSymbols'],
                                        args['dRoles'], args['temperature'],
                                        args['scale_val'], args['train_scale'])
            self.decoder = Generator(self.lang, self.encoder.embedding,
                                     self.lang.n_words, hidden_size,
                                     self.dropout, self.slots, self.nb_gate,
                                     self.device, self.cell_type)
        else:
            self.encoder = BERTEncoder(hidden_size, self.dropout, self.device)
            self.decoder = Generator(self.lang, None, self.lang.n_words,
                                     hidden_size, self.dropout, self.slots,
                                     self.nb_gate, self.device, self.cell_type)

        if path:
            print("MODEL {} LOADED".format(str(path)))
            trained_encoder = torch.load(str(path) + '/enc.th',
                                         map_location=self.device)
            trained_decoder = torch.load(str(path) + '/dec.th',
                                         map_location=self.device)

            # remap parameter names between older ('gru.*') and newer ('rnn.*') checkpoints
            encoder_dict = trained_encoder.state_dict()
            new_encoder_dict = {}
            for key in encoder_dict:
                mapped_key = key
                if key.startswith('gru.'):
                    mapped_key = 'rnn.' + key[len('gru.'):]
                new_encoder_dict[mapped_key] = encoder_dict[key]

            decoder_dict = trained_decoder.state_dict()
            new_decoder_dict = {}
            for key in decoder_dict:
                mapped_key = key
                if key.startswith('gru.'):
                    mapped_key = 'rnn.' + key[len('gru.'):]
                new_decoder_dict[mapped_key] = decoder_dict[key]

            if 'W_slot_embed.weight' not in new_decoder_dict:
                new_decoder_dict['W_slot_embed.weight'] = torch.zeros(
                    (hidden_size, 2 * hidden_size), requires_grad=False)
                new_decoder_dict['W_slot_embed.bias'] = torch.zeros(
                    (hidden_size, ), requires_grad=False)

            self.encoder.load_state_dict(new_encoder_dict)
            self.decoder.load_state_dict(new_decoder_dict)

        # Initialize optimizers and criterion
        if args['encoder'] == 'RNN':
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                            mode='max',
                                                            factor=0.5,
                                                            patience=1,
                                                            min_lr=0.0001,
                                                            verbose=True)
        else:
            if args['local_rank'] != -1:
                t_total = t_total // torch.distributed.get_world_size()

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=args['learn'],
                                   correct_bias=False)
            self.scheduler = WarmupLinearSchedule(
                self.optimizer,
                warmup_steps=args['warmup_proportion'] * t_total,
                t_total=t_total)

        self.reset()

    def print_loss(self):
        print_loss_avg = self.loss / self.print_every
        print_loss_ptr = self.loss_ptr / self.print_every
        print_loss_gate = self.loss_gate / self.print_every
        print_loss_class = self.loss_class / self.print_every
        # print_loss_domain = self.loss_domain / self.print_every
        self.print_every += 1
        return 'L:{:.2f},LP:{:.2f},LG:{:.2f}'.format(print_loss_avg,
                                                     print_loss_ptr,
                                                     print_loss_gate)

    def save_model(self, dec_type):
        directory = 'save/TRADE-' + args["addName"] + args['dataset'] + str(
            self.task) + '/' + 'HDD' + str(self.hidden_size) + 'BSZ' + str(
                args['batch']) + 'DR' + str(self.dropout) + str(dec_type)
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.encoder, directory + '/enc.th')
        torch.save(self.decoder, directory + '/dec.th')

    def reset(self):
        self.loss, self.print_every, self.loss_ptr, self.loss_gate, self.loss_class = 0, 1, 0, 0, 0

    def forward(self, data, clip, slot_temp, reset=0, n_gpu=0):
        if reset: self.reset()
        # Zero gradients of both optimizers
        self.optimizer.zero_grad()

        # Encode and Decode
        use_teacher_forcing = random.random() < args["teacher_forcing_ratio"]
        all_point_outputs, gates, words_point_out, words_class_out = self.encode_and_decode(
            data, use_teacher_forcing, slot_temp)

        loss_ptr = masked_cross_entropy_for_value(
            all_point_outputs.transpose(0, 1).contiguous(),
            data["generate_y"].contiguous(
            ),  #[:,:len(self.point_slots)].contiguous(),
            data["y_lengths"])  #[:,:len(self.point_slots)])
        loss_gate = self.cross_entorpy(
            gates.transpose(0, 1).contiguous().view(-1, gates.size(-1)),
            data["gating_label"].contiguous().view(-1))

        if args["use_gate"]:
            loss = loss_ptr + loss_gate
        else:
            loss = loss_ptr

        self.loss_grad = loss
        self.loss_ptr_to_bp = loss_ptr

        # Update parameters with optimizers
        self.loss += loss.item()
        self.loss_ptr += loss_ptr.item()
        self.loss_gate += loss_gate.item()

        return self.loss_grad

    def optimize_GEM(self, clip):
        torch.nn.utils.clip_grad_norm_(self.parameters(), clip)
        self.optimizer.step()
        if isinstance(self.scheduler, WarmupLinearSchedule):
            self.scheduler.step()
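
Note that forward() only accumulates the loss and stores it in self.loss_grad; backpropagation and the optimizer step happen outside, via optimize_GEM(). A plausible outer training step, sketched here under the assumption that batches arrive as the dict forward() expects (the repo's real training loop may differ), would be:

def train_step(model, batch, clip, slot_temp, reset=False):
    # hypothetical driver for TRADE.forward() / optimize_GEM(); illustration only
    loss = model(batch, clip, slot_temp, reset=reset)  # forward() zeroes grads and returns self.loss_grad
    loss.backward()                                    # backward pass happens outside forward()
    model.optimize_GEM(clip)                           # clip gradients, optimizer.step(), scheduler.step()
    return loss.item()
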

    def encode_and_decode(self, data, use_teacher_forcing, slot_temp):
        if args['encoder'] == 'RNN' or args['encoder'] == 'TPRNN':
            # Build unknown mask for memory to encourage generalization
            if args['unk_mask'] and self.decoder.training:
                story_size = data['context'].size()
                rand_mask = np.ones(story_size)
                bi_mask = np.random.binomial(
                    [np.ones(
                        (story_size[0], story_size[1]))], 1 - self.dropout)[0]
                rand_mask = rand_mask * bi_mask
                rand_mask = torch.Tensor(rand_mask).to(self.device)
                story = data['context'] * rand_mask.long()
            else:
                story = data['context']

            story = story.to(self.device)
            # encoded_outputs, encoded_hidden = self.encoder(story.transpose(0, 1), data['context_len'])
            encoded_outputs, encoded_hidden = self.encoder(
                story, data['context_len'])

        # Encode dialog history
        # story  32 396
        # data['context_len'] 32
        elif args['encoder'] == 'BERT':
            # import pdb; pdb.set_trace()
            story = data['context']
            # story_plain = data['context_plain']

            all_input_ids = data['all_input_ids']
            all_input_mask = data['all_input_mask']
            all_segment_ids = data['all_segment_ids']
            all_sub_word_masks = data['all_sub_word_masks']

            encoded_outputs, encoded_hidden = self.encoder(
                all_input_ids, all_input_mask, all_segment_ids,
                all_sub_word_masks)
            encoded_hidden = encoded_hidden.unsqueeze(0)

        # Get the words that can be copied from the memory
        # import pdb; pdb.set_trace()
        batch_size = len(data['context_len'])
        self.copy_list = data['context_plain']
        max_res_len = data['generate_y'].size(
            2) if self.encoder.training else 10

        all_point_outputs, all_gate_outputs, words_point_out, words_class_out = self.decoder.forward(batch_size, \
            encoded_hidden, encoded_outputs, data['context_len'], story, max_res_len, data['generate_y'], \
            use_teacher_forcing, slot_temp)

        return all_point_outputs, all_gate_outputs, words_point_out, words_class_out

    def evaluate(self,
                 dev,
                 matric_best,
                 slot_temp,
                 device,
                 save_dir="",
                 save_string="",
                 early_stop=None):
        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        print("STARTING EVALUATION")
        all_prediction = {}
        inverse_unpoint_slot = dict([(v, k)
                                     for k, v in self.gating_dict.items()])
        pbar = enumerate(dev)
        for j, data_dev in pbar:
            # Encode and Decode
            eval_data = {}
            # wrap all numerical values as tensors for multi-gpu training
            for k, v in data_dev.items():
                if isinstance(v, torch.Tensor):
                    eval_data[k] = v.to(device)
                elif isinstance(v, list):
                    if k in [
                            'ID', 'turn_belief', 'context_plain',
                            'turn_uttr_plain'
                    ]:
                        eval_data[k] = v
                    else:
                        eval_data[k] = torch.tensor(v).to(device)
                else:
                    # print('v is: {} and this ignoring {}'.format(v, k))
                    pass
            batch_size = len(data_dev['context_len'])
            with torch.no_grad():
                _, gates, words, class_words = self.encode_and_decode(
                    eval_data, False, slot_temp)

            for bi in range(batch_size):
                if data_dev["ID"][bi] not in all_prediction.keys():
                    all_prediction[data_dev["ID"][bi]] = {}
                all_prediction[data_dev["ID"][bi]][data_dev["turn_id"][bi]] = {
                    "turn_belief": data_dev["turn_belief"][bi]
                }
                predict_belief_bsz_ptr, predict_belief_bsz_class = [], []
                gate = torch.argmax(gates.transpose(0, 1)[bi], dim=1)
                # import pdb; pdb.set_trace()

                # pointer-generator results
                if args["use_gate"]:
                    for si, sg in enumerate(gate):
                        if sg == self.gating_dict["none"]:
                            continue
                        elif sg == self.gating_dict["ptr"]:
                            pred = np.transpose(words[si])[bi]
                            st = []
                            for e in pred:
                                if e == 'EOS': break
                                else: st.append(e)
                            st = " ".join(st)
                            if st == "none":
                                continue
                            else:
                                predict_belief_bsz_ptr.append(slot_temp[si] +
                                                              "-" + str(st))
                        else:
                            predict_belief_bsz_ptr.append(
                                slot_temp[si] + "-" +
                                inverse_unpoint_slot[sg.item()])
                else:
                    for si, _ in enumerate(gate):
                        pred = np.transpose(words[si])[bi]
                        st = []
                        for e in pred:
                            if e == 'EOS': break
                            else: st.append(e)
                        st = " ".join(st)
                        if st == "none":
                            continue
                        else:
                            predict_belief_bsz_ptr.append(slot_temp[si] + "-" +
                                                          str(st))

                all_prediction[data_dev["ID"][bi]][data_dev["turn_id"][bi]][
                    "pred_bs_ptr"] = predict_belief_bsz_ptr

                #if set(data_dev["turn_belief"][bi]) != set(predict_belief_bsz_ptr) and args["genSample"]:
                #    print("True", set(data_dev["turn_belief"][bi]) )
                #    print("Pred", set(predict_belief_bsz_ptr), "\n")

        if args["genSample"]:
            if save_dir != "" and not os.path.exists(save_dir):
                os.mkdir(save_dir)
            json.dump(all_prediction,
                      open(
                          os.path.join(
                              save_dir, "prediction_{}_{}.json".format(
                                  self.name, save_string)), 'w'),
                      indent=4)
            print(
                "saved generated samples",
                os.path.join(
                    save_dir,
                    "prediction_{}_{}.json".format(self.name, save_string)))

        joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = self.evaluate_metrics(
            all_prediction, "pred_bs_ptr", slot_temp)

        evaluation_metrics = {
            "Joint Acc": joint_acc_score_ptr,
            "Turn Acc": turn_acc_score_ptr,
            "Joint F1": F1_score_ptr
        }
        print(evaluation_metrics)

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)

        joint_acc_score = joint_acc_score_ptr  # (joint_acc_score_ptr + joint_acc_score_class)/2
        F1_score = F1_score_ptr

        if (early_stop == 'F1'):
            if (F1_score >= matric_best):
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")
            return F1_score
        else:
            if (joint_acc_score >= matric_best):
                self.save_model('ACC-{:.4f}'.format(joint_acc_score))
                print("MODEL SAVED")
            return joint_acc_score

    def evaluate_metrics(self, all_prediction, from_which, slot_temp):
        total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
        for d, v in all_prediction.items():
            for t in range(len(v)):
                cv = v[t]
                if set(cv["turn_belief"]) == set(cv[from_which]):
                    joint_acc += 1
                total += 1

                # Compute prediction slot accuracy
                temp_acc = self.compute_acc(set(cv["turn_belief"]),
                                            set(cv[from_which]), slot_temp)
                turn_acc += temp_acc

                # Compute prediction joint F1 score
                temp_f1, temp_r, temp_p, count = self.compute_prf(
                    set(cv["turn_belief"]), set(cv[from_which]))
                F1_pred += temp_f1
                F1_count += count

        joint_acc_score = joint_acc / float(total) if total != 0 else 0
        turn_acc_score = turn_acc / float(total) if total != 0 else 0
        F1_score = F1_pred / float(F1_count) if F1_count != 0 else 0
        return joint_acc_score, F1_score, turn_acc_score

    def compute_acc(self, gold, pred, slot_temp):
        miss_gold = 0
        miss_slot = []
        for g in gold:
            if g not in pred:
                miss_gold += 1
                miss_slot.append(g.rsplit("-", 1)[0])
        wrong_pred = 0
        for p in pred:
            if p not in gold and p.rsplit("-", 1)[0] not in miss_slot:
                wrong_pred += 1
        ACC_TOTAL = len(slot_temp)
        ACC = len(slot_temp) - miss_gold - wrong_pred
        ACC = ACC / float(ACC_TOTAL)
        return ACC

    def compute_prf(self, gold, pred):
        TP, FP, FN = 0, 0, 0
        if len(gold) != 0:
            count = 1
            for g in gold:
                if g in pred:
                    TP += 1
                else:
                    FN += 1
            for p in pred:
                if p not in gold:
                    FP += 1
            precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
            recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
            F1 = 2 * precision * recall / float(precision + recall) if (
                precision + recall) != 0 else 0
        else:
            if len(pred) == 0:
                precision, recall, F1, count = 1, 1, 1, 1
            else:
                precision, recall, F1, count = 0, 0, 0, 1
        return F1, recall, precision, count
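
The bias/LayerNorm no-decay grouping used in the TRADE constructor above, and again in the main() below, is the standard transformer fine-tuning pattern for AdamW. For reference, it can be factored into a small helper; this is an illustrative sketch, not code from this repository.

import torch.nn as nn
from torch.optim import AdamW

def build_adamw(model, lr=1e-5, weight_decay=0.01):
    """Exclude bias and LayerNorm parameters from weight decay, then build AdamW."""
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    return AdamW(grouped, lr=lr)

# example: optimizer = build_adamw(nn.Linear(8, 2), lr=2e-5)
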
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument(
        "--data_label",
        default="",
        type=str,
        help="Optional label identifying the dataset variant or run.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples, _ = get_FEVER_examples('train', hypo_only=False)
    dev_and_test_examples, _ = get_FEVER_examples('dev', hypo_only=False)
    random.shuffle(dev_and_test_examples)
    dev_examples = dev_and_test_examples[:-10000]
    test_examples = dev_and_test_examples[-10000:]
    '''write into files'''
    def examples_2_file(exs, prefix):
        writefile = codecs.open(
            '/export/home/Dataset/para_entail_datasets/nli_FEVER/nli_fever/my_split_binary/'
            + prefix + '.txt', 'w', 'utf-8')
        for ex in exs:
            writefile.write(ex.label + '\t' + ex.text_a + '\t' + ex.text_b +
                            '\n')
        print('print over')
        writefile.close()

    examples_2_file(train_examples, 'train')
    examples_2_file(dev_examples, 'dev')
    examples_2_file(test_examples, 'test')
    exit(0)

    label_list = ["entailment", "not_entailment"]  #, "contradiction"]
    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), ' test size:', len(test_examples))

    num_train_optimization_steps = None
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(
        torch.load(
            '/export/home/Dataset/BERT_pretrained_mine/paragraph_entail/2021/ANLI_CNNDailyMail_DUC_Curation_SQUAD_epoch_1.pt',
            map_location=device))
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)
        '''load dev set'''
        dev_features = convert_examples_to_features(
            dev_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        test_features = convert_examples_to_features(
            test_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        test_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                          dtype=torch.long)
        test_all_input_mask = torch.tensor(
            [f.input_mask for f in test_features], dtype=torch.long)
        test_all_segment_ids = torch.tensor(
            [f.segment_ids for f in test_features], dtype=torch.long)
        test_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                          dtype=torch.long)

        test_data = TensorDataset(test_all_input_ids, test_all_input_mask,
                                  test_all_segment_ids, test_all_label_ids)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                logits = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
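                # Note: optimizer.step() is called on every batch, so the earlier
                # division by gradient_accumulation_steps only rescales the loss;
                # true accumulation would also gate this step on
                # (step + 1) % args.gradient_accumulation_steps == 0.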

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
            '''
            evaluate on the dev set after this epoch
            '''
            model.eval()

            dev_acc = evaluation(dev_dataloader, device, model)

            if dev_acc > max_dev_acc:
                max_dev_acc = dev_acc
                print('\ndev acc:', dev_acc, ' max_dev_acc:', max_dev_acc,
                      '\n')
                '''evaluate on the test set with the best dev model'''
                final_test_performance = evaluation(test_dataloader, device,
                                                    model)
                print('\ntest acc:', final_test_performance, '\n')

            else:
                print('\ndev acc:', dev_acc, ' max_dev_acc:', max_dev_acc,
                      '\n')
        print('final_test_performance:', final_test_performance)
Beispiel #20
0
def train():
    # Check the configuration and read the hyperparameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device:{} n_gpu:{}".format(device, n_gpu))
    seed = hyperparameters["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    max_seq_length = hyperparameters["max_sent_length"]
    gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
    num_epochs = hyperparameters["num_epoch"]
    train_batch_size = hyperparameters["train_batch_size"] // hyperparameters["gradient_accumulation_steps"]
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased", do_lower_case=True)
    model = BertForMultipleChoice.from_pretrained("bert-large-uncased")
    model.to(device)

    # Optimizer
    param_optimizer = list(model.named_parameters())

    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # Load the data
    train_examples = read_examples('../dataset/train_bert.txt')
    dev_examples = read_examples('../dataset/test_bert.txt')
    nTrain = len(train_examples)
    nDev = len(dev_examples)
    num_train_optimization_steps = int(nTrain / train_batch_size / gradient_accumulation_steps) * num_epochs
    optimizer = AdamW(optimizer_grouped_parameters, lr=hyperparameters["learning_rate"])
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * num_train_optimization_steps),
                                                num_training_steps=num_train_optimization_steps)
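    # get_linear_schedule_with_warmup: the LR warms up linearly over the first
    # 10% of optimization steps and then decays linearly to zero.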

    global_step = 0
    train_features = convert_examples_to_features(train_examples, tokenizer, max_seq_length)
    dev_features = convert_examples_to_features(dev_examples, tokenizer, max_seq_length)
    train_dataloader = get_train_dataloader(train_features, train_batch_size)
    dev_dataloader = get_eval_dataloader(dev_features, hyperparameters["eval_batch_size"])
    print("Num of train features:{}".format(nTrain))
    print("Num of dev features:{}".format(nDev))
    best_dev_accuracy = 0
    best_dev_epoch = 0
    no_up = 0

    epoch_tqdm = trange(int(num_epochs), desc="Epoch")
    for epoch in epoch_tqdm:
        model.train()

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, label_ids = batch
            loss, logits = model(input_ids=input_ids, labels=label_ids)[:2]
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            loss.backward()
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

        train_loss, train_accuracy = evaluate(model, device, train_dataloader, "Train")
        dev_loss, dev_accuracy = evaluate(model, device, dev_dataloader, "Dev")

        if dev_accuracy > best_dev_accuracy:
            best_dev_accuracy = dev_accuracy
            best_dev_epoch = epoch + 1
            no_up = 0

        else:
            no_up += 1
        tqdm.write("\t ***** Eval results (Epoch %s) *****" % str(epoch + 1))
        tqdm.write("\t train_accuracy = %s" % str(train_accuracy))
        tqdm.write("\t dev_accuracy = %s" % str(dev_accuracy))
        tqdm.write("")
        tqdm.write("\t best_dev_accuracy = %s" % str(best_dev_accuracy))
        tqdm.write("\t best_dev_epoch = %s" % str(best_dev_epoch))
        tqdm.write("\t no_up = %s" % str(no_up))
        tqdm.write("")
        if no_up >= hyperparameters["patience"]:
            epoch_tqdm.close()
            break
Beispiel #21
0
def trainBERT(model, train_loader, val_loader, num_epoch=5, lr=2e-2):
    # Training steps
    start_time = time.time()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr, eps=1e-8)
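    # Note: the default lr of 2e-2 is far above the 1e-5 to 5e-5 range typically
    # used for fine-tuning BERT; it is likely meant to be overridden by the
    # caller (or is possibly a typo for 2e-5).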

    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    auc = []
    best_auc = 0.
    best_model = copy.deepcopy(model.state_dict())

    for epoch in range(num_epoch):
        model.train()
        #Initialize
        correct = 0
        total = 0
        total_loss = 0

        for i, (data, mask, labels) in enumerate(train_loader):
            data, mask, labels = data.to(device), mask.to(device), labels.to(
                device, dtype=torch.long)
            optimizer.zero_grad()

            outputs = model(data,
                            token_type_ids=None,
                            attention_mask=mask,
                            labels=None)

            loss = loss_fn(outputs.view(-1, 2), labels.view(-1))

            loss.backward()
            optimizer.step()
            label_cpu = labels.squeeze().to('cpu').numpy()
            pred = outputs.data.max(-1)[1].to('cpu').numpy()
            total += labels.size(0)
            correct += float(sum((pred == label_cpu)))
            total_loss += loss.item()

        acc = correct / total

        t_loss = total_loss / total
        train_loss.append(t_loss)
        train_acc.append(acc)
        # report performance

        print('Epoch: ', epoch)
        print('Train set | Accuracy: {:6.4f} | Loss: {:6.4f}'.format(
            acc, t_loss))

        # Evaluate after every epoch
        #Reset the initialization
        correct = 0
        total = 0
        total_loss = 0
        model.eval()

        predictions = []
        truths = []

        with torch.no_grad():
            for i, (data, mask, labels) in enumerate(val_loader):
                data, mask, labels = data.to(device), mask.to(
                    device), labels.to(device, dtype=torch.long)

                optimizer.zero_grad()
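                # zero_grad() is unnecessary here: no gradients are produced
                # under torch.no_grad(), and the training loop clears them anyway.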

                outputs = model(data,
                                token_type_ids=None,
                                attention_mask=mask,
                                labels=None)
                #va_loss = loss_fn(outputs.squeeze(-1), labels)
                va_loss = loss_fn(outputs.view(-1, 2), labels.view(-1))

                label_cpu = labels.squeeze().to('cpu').numpy()

                pred = outputs.data.max(-1)[1].to('cpu').numpy()
                total += labels.size(0)
                correct += float(sum((pred == label_cpu)))
                total_loss += va_loss.item()

                predictions += list(pred)
                truths += list(label_cpu)

            v_acc = correct / total
            v_loss = total_loss / total
            val_loss.append(v_loss)
            val_acc.append(v_acc)

            v_auc = roc_auc_score(truths, predictions)
            auc.append(v_auc)

            elapse = time.strftime(
                '%H:%M:%S', time.gmtime(int((time.time() - start_time))))
            print(
                'Validation set | Accuracy: {:6.4f} | AUC: {:6.4f} | Loss: {:4.2f} | time elapse: {:>9}'
                .format(v_acc, v_auc, v_loss, elapse))
            print('-' * 10)

            if v_auc > best_auc:
                best_auc = v_auc
                best_model = copy.deepcopy(model.state_dict())

    print('Best validation auc: {:6.4f}'.format(best_auc))
    model.load_state_dict(best_model)
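    # Note: v_auc below is the AUC of the final epoch; the per-epoch AUC history
    # collected in `auc` is not returned.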
    return train_loss, train_acc, val_loss, val_acc, v_auc, model
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--round_name",
        default="",
        type=str,
        help=
        "Which round to train and evaluate on (e.g. r1, r2, ..., r5)")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=50,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds = {
        'r1': ['n1', 'ood'],
        'r2': ['n1', 'n2', 'ood'],
        'r3': ['n1', 'n2', 'n3', 'ood'],
        'r4': ['n1', 'n2', 'n3', 'n4', 'ood'],
        'r5': ['n1', 'n2', 'n3', 'n4', 'n5', 'ood']
    }

    model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(torch.load('../../data/MNLI_pretrained.pt'),
                          strict=False)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
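    # A single AdamW instance (with no LR schedule) is reused for every training
    # round below.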

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    '''load training in list'''
    train_examples_list, train_class_list, train_class_2_split_list, class_2_sentlist_upto_this_round = processor.load_train(
        round_list[:-1])  # no ood training examples
    assert len(train_class_list) == len(train_class_2_split_list)
    # assert len(train_class_list) ==  20+(len(round_list)-2)*10
    '''dev and test'''
    dev_examples, dev_instance_size = processor.load_dev_or_test(
        round_list, train_class_list, class_2_sentlist_upto_this_round, 'dev')
    test_examples, test_instance_size = processor.load_dev_or_test(
        round_list, train_class_list, class_2_sentlist_upto_this_round, 'test')
    print('train size:', [len(train_i) for train_i in train_examples_list],
          ' dev size:', len(dev_examples), ' test size:', len(test_examples))
    entail_class_list = ['entailment', 'non-entailment']
    eval_class_list = train_class_list + ['ood']
    test_split_list = train_class_2_split_list + ['ood']
    train_dataloader_list = []
    for train_examples in train_examples_list:
        train_dataloader = examples_to_features(train_examples,
                                                entail_class_list,
                                                eval_class_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        train_dataloader_list.append(train_dataloader)
    dev_dataloader = examples_to_features(dev_examples,
                                          entail_class_list,
                                          eval_class_list,
                                          args,
                                          tokenizer,
                                          args.eval_batch_size,
                                          "classification",
                                          dataloader_mode='sequential')
    test_dataloader = examples_to_features(test_examples,
                                           entail_class_list,
                                           eval_class_list,
                                           args,
                                           tokenizer,
                                           args.eval_batch_size,
                                           "classification",
                                           dataloader_mode='sequential')
    '''training'''
    max_test_acc = 0.0
    max_dev_acc = 0.0
    for round_index, round in enumerate(round_list[:-1]):
        '''for the new examples in each round, train multiple epochs'''
        train_dataloader = train_dataloader_list[round_index]
        for epoch_i in range(args.num_train_epochs):
            for _, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="train|" + round + '|epoch_' + str(epoch_i))):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                _, input_ids, input_mask, _, label_ids, _, _ = batch

                logits = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, 3), label_ids.view(-1))
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        print('\t\t round ', round, ' is over...')
    '''evaluation'''
    model.eval()
    '''test'''
    acc_each_round = []
    preds = []
    gold_guids = []
    gold_premise_ids = []
    gold_hypothesis_ids = []
    for _, batch in enumerate(tqdm(test_dataloader, desc="test")):
        guids, input_ids, input_mask, _, label_ids, premise_class_ids, hypothesis_class_id = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)

        gold_guids += list(guids.detach().cpu().numpy())
        gold_premise_ids += list(premise_class_ids.detach().cpu().numpy())
        gold_hypothesis_ids += list(hypothesis_class_id.detach().cpu().numpy())

        with torch.no_grad():
            logits = model(input_ids, input_mask)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)
    preds = softmax(preds[0], axis=1)

    pred_label_3way = np.argmax(preds,
                                axis=1)  #0 means "entailment"
    pred_probs = list(
        preds[:, 0])  #prob for "entailment" class: (#input, #seen_classes)
    assert len(pred_label_3way) == len(test_examples)
    assert len(pred_probs) == len(test_examples)
    assert len(gold_premise_ids) == len(test_examples)
    assert len(gold_hypothesis_ids) == len(test_examples)
    assert len(gold_guids) == len(test_examples)
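    # Aggregate the pairwise entailment predictions per test instance (guid):
    # for each candidate class (hypothesis id), collect the entailment
    # probabilities and 3-way labels of its premise/hypothesis pairs, then pick
    # the class whose pairs are mostly labelled "entailment" (majority vote)
    # with the highest mean entailment probability; if no class qualifies, the
    # instance is predicted as ood.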

    guid_2_premise_idlist = defaultdict(list)
    guid_2_hypoID_2_problist_labellist = {}
    for guid_i, threeway_i, prob_i, premise_i, hypo_i in zip(
            gold_guids, pred_label_3way, pred_probs, gold_premise_ids,
            gold_hypothesis_ids):
        guid_2_premise_idlist[guid_i].append(premise_i)
        hypoID_2_problist_labellist = guid_2_hypoID_2_problist_labellist.get(
            guid_i)
        if hypoID_2_problist_labellist is None:
            hypoID_2_problist_labellist = {}
        lists = hypoID_2_problist_labellist.get(hypo_i)
        if lists is None:
            lists = [[], []]
        lists[0].append(prob_i)
        lists[1].append(threeway_i)
        hypoID_2_problist_labellist[hypo_i] = lists
        guid_2_hypoID_2_problist_labellist[
            guid_i] = hypoID_2_problist_labellist

    pred_label_ids = []
    gold_label_ids = []
    for guid in range(test_instance_size):
        assert len(set(guid_2_premise_idlist.get(guid))) == 1
        gold_label_ids.append(guid_2_premise_idlist.get(guid)[0])
        '''infer predict label id'''
        hypoID_2_problist_labellist = guid_2_hypoID_2_problist_labellist.get(
            guid)

        final_max_mean_prob = 0.0
        final_hypo_id = -1
        for hypo_id, problist_labellist in hypoID_2_problist_labellist.items():
            problist = problist_labellist[0]
            mean_prob = np.mean(problist)
            labellist = problist_labellist[1]
            same_cluster_times = labellist.count(
                0)  #'entailment' is the first label
            same_cluster = False
            if same_cluster_times / len(labellist) > 0.5:
                same_cluster = True

            if same_cluster is True and mean_prob > final_max_mean_prob:
                final_max_mean_prob = mean_prob
                final_hypo_id = hypo_id
        if final_hypo_id != -1:  # can find a class that it belongs to
            pred_label_ids.append(final_hypo_id)
        else:
            pred_label_ids.append(len(train_class_list))

    assert len(pred_label_ids) == len(gold_label_ids)
    acc_each_round = []
    for round_name_id in round_list:
        #base, n1, n2, ood
        round_size = 0
        round_hit = 0
        if round_name_id != 'ood':
            for ii, gold_label_id in enumerate(gold_label_ids):
                if test_split_list[gold_label_id] == round_name_id:
                    round_size += 1
                    if gold_label_id == pred_label_ids[ii]:
                        round_hit += 1
            acc_i = round_hit / round_size
            acc_each_round.append(acc_i)
        else:
            '''ood f1'''
            gold_binary_list = []
            pred_binary_list = []
            for ii, gold_label_id in enumerate(gold_label_ids):
                gold_binary_list.append(1 if test_split_list[gold_label_id] ==
                                        round_name_id else 0)
                pred_binary_list.append(1 if pred_label_ids[ii] ==
                                        len(train_class_list) else 0)
            overlap = 0
            for i in range(len(gold_binary_list)):
                if gold_binary_list[i] == 1 and pred_binary_list[i] == 1:
                    overlap += 1
            recall = overlap / (1e-6 + sum(gold_binary_list))
            precision = overlap / (1e-6 + sum(pred_binary_list))
            acc_i = 2 * recall * precision / (1e-6 + recall + precision)
            acc_each_round.append(acc_i)

    print('final_test_performance:', acc_each_round)
Beispiel #23
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--round_name",
        default="",
        type=str,
        help=
        "Which round to train and evaluate on (e.g. base, r1, ..., r5)")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds = {
        'base': ['base', 'ood'],
        'r1': ['base', 'n1', 'ood'],
        'r2': ['base', 'n1', 'n2', 'ood'],
        'r3': ['base', 'n1', 'n2', 'n3', 'ood'],
        'r4': ['base', 'n1', 'n2', 'n3', 'n4', 'ood'],
        'r5': ['base', 'n1', 'n2', 'n3', 'n4', 'n5', 'ood']
    }

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    train_examples, base_class_list = processor.load_train(
        ['base'])  #train on base only
    '''train the first stage'''
    model = RobertaForSequenceClassification(len(base_class_list))
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    train_dataloader = examples_to_features(train_examples,
                                            base_class_list,
                                            args,
                                            tokenizer,
                                            args.train_batch_size,
                                            "classification",
                                            dataloader_mode='random')
    mean_loss = 0.0
    count = 0
    for _ in trange(int(args.num_train_epochs), desc="Stage1Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            logits = model(input_ids, input_mask, output_rep=False)
            loss_fct = CrossEntropyLoss()

            loss = loss_fct(logits.view(-1, len(base_class_list)),
                            label_ids.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            mean_loss += loss.item()
            count += 1
            # if count % 50 == 0:
            #     print('mean loss:', mean_loss/count)
    print('Stage 1 (supervised classification on the base classes) is over.')
    '''now, train the second stage'''
    model_stage_2 = ModelStageTwo(len(base_class_list), model)
    model_stage_2.to(device)

    param_optimizer = list(model_stage_2.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer_stage_2 = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate)
    mean_loss = 0.0
    count = 0
    best_threshold = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Stage2Epoch"):
        '''first, select some base classes as fake novel classes'''
        fake_novel_size = 15
        fake_novel_support_size = 5
        '''for convenience, (optionally) shuffle the base classes and treat the last fake_novel_size of them as fake novel classes'''
        original_base_class_idlist = list(range(len(base_class_list)))
        # random.shuffle(original_base_class_idlist)
        shuffled_base_class_list = [
            base_class_list[idd] for idd in original_base_class_idlist
        ]
        fake_novel_classlist = shuffled_base_class_list[-fake_novel_size:]
        '''load their support examples'''
        base_support_examples = processor.load_base_support_examples(
            fake_novel_classlist, fake_novel_support_size)
        base_support_dataloader = examples_to_features(
            base_support_examples,
            fake_novel_classlist,
            args,
            tokenizer,
            fake_novel_support_size,
            "classification",
            dataloader_mode='sequential')

        novel_class_support_reps = []
        for _, batch in enumerate(base_support_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            model.eval()
            with torch.no_grad():
                support_rep_for_novel_class = model(input_ids,
                                                    input_mask,
                                                    output_rep=True)
            novel_class_support_reps.append(support_rep_for_novel_class)
        assert len(novel_class_support_reps) == fake_novel_size
        print('Finished extracting support reps for the fake novel classes.')
        '''retrain on query set to optimize the weight generator'''
        train_dataloader = examples_to_features(train_examples,
                                                shuffled_base_class_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        best_threshold_list = []
        for _ in range(10):  #repeat 10 times is important
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model_stage_2.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                logits = model_stage_2(
                    input_ids,
                    input_mask,
                    model,
                    novel_class_support_reps=novel_class_support_reps,
                    fake_novel_size=fake_novel_size,
                    base_class_mapping=original_base_class_idlist)
                # print('logits:', logits)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, len(base_class_list)),
                                label_ids.view(-1))
                loss.backward()
                optimizer_stage_2.step()
                optimizer_stage_2.zero_grad()
                mean_loss += loss.item()
                count += 1
                if count % 50 == 0:
                    print('mean loss:', mean_loss / count)
                scores_for_positive = logits[torch.arange(logits.shape[0]),
                                             label_ids.view(-1)].mean()
                best_threshold_list.append(scores_for_positive.item())

        best_threshold = sum(best_threshold_list) / len(best_threshold_list)
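        # The rejection threshold is the mean logit assigned to the gold class
        # over the query batches; at test time, an example whose maximum logit
        # falls below it is labelled ood.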

    print('Stage 2 training is over.')
    '''
    start testing
    '''
    '''first, get reps for all base+novel classes'''
    '''support for all seen classes'''
    class_2_support_examples, seen_class_list = processor.load_support_all_rounds(
        round_list[:-1])  #no support set for ood
    assert seen_class_list[:len(base_class_list)] == base_class_list
    seen_class_list_size = len(seen_class_list)
    support_example_lists = [
        class_2_support_examples.get(seen_class)
        for seen_class in seen_class_list if seen_class not in base_class_list
    ]

    novel_class_support_reps = []
    for eval_support_examples_per_class in support_example_lists:
        support_dataloader = examples_to_features(
            eval_support_examples_per_class,
            seen_class_list,
            args,
            tokenizer,
            5,
            "classification",
            dataloader_mode='random')
        single_class_support_reps = []
        for _, batch in enumerate(support_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            model.eval()
            with torch.no_grad():
                support_rep_for_novel_class = model(input_ids,
                                                    input_mask,
                                                    output_rep=True)
            single_class_support_reps.append(support_rep_for_novel_class)
        single_class_support_reps = torch.cat(single_class_support_reps,
                                              axis=0)
        novel_class_support_reps.append(single_class_support_reps)
    print('len(novel_class_support_reps):', len(novel_class_support_reps))
    print('len(base_class_list):', len(base_class_list))
    print('len(seen_class_list):', len(seen_class_list))
    assert len(novel_class_support_reps) + len(base_class_list) == len(
        seen_class_list)
    print('Finished extracting support reps for all novel classes.')
    test_examples = processor.load_dev_or_test(round_list, 'test')
    test_class_list = seen_class_list + list(ood_class_set)
    print('test_class_list:', len(test_class_list))
    print('best_threshold:', best_threshold)
    test_split_list = []
    for test_class_i in test_class_list:
        test_split_list.append(class_2_split.get(test_class_i))
    test_dataloader = examples_to_features(test_examples,
                                           test_class_list,
                                           args,
                                           tokenizer,
                                           args.eval_batch_size,
                                           "classification",
                                           dataloader_mode='sequential')
    '''test on test batch '''
    preds = []
    gold_label_ids = []
    for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        gold_label_ids += list(label_ids.detach().cpu().numpy())
        model_stage_2.eval()
        with torch.no_grad():
            logits = model_stage_2(
                input_ids,
                input_mask,
                model,
                novel_class_support_reps=novel_class_support_reps,
                fake_novel_size=None,
                base_class_mapping=None)
        # print('test logits:', logits)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)

    preds = preds[0]

    pred_probs = preds  #softmax(preds,axis=1)
    pred_label_ids_raw = list(np.argmax(pred_probs, axis=1))
    pred_max_prob = list(np.amax(pred_probs, axis=1))

    pred_label_ids = []
    for i, pred_max_prob_i in enumerate(pred_max_prob):
        if pred_max_prob_i < best_threshold:
            pred_label_ids.append(
                seen_class_list_size)  #seen_class_list_size means ood
        else:
            pred_label_ids.append(pred_label_ids_raw[i])

    assert len(pred_label_ids) == len(gold_label_ids)
    acc_each_round = []
    for round_name_id in round_list:
        #base, n1, n2, ood
        round_size = 0
        round_hit = 0
        if round_name_id != 'ood':
            for ii, gold_label_id in enumerate(gold_label_ids):
                if test_split_list[gold_label_id] == round_name_id:
                    round_size += 1
                    # print('gold_label_id:', gold_label_id, 'pred_label_ids[ii]:', pred_label_ids[ii])
                    if gold_label_id == pred_label_ids[ii]:
                        round_hit += 1
            acc_i = round_hit / round_size
            acc_each_round.append(acc_i)
        else:
            '''ood f1'''
            gold_binary_list = []
            pred_binary_list = []
            for ii, gold_label_id in enumerate(gold_label_ids):
                # print('gold_label_id:', gold_label_id, 'pred_label_ids[ii]:', pred_label_ids[ii])
                gold_binary_list.append(1 if test_split_list[gold_label_id] ==
                                        round_name_id else 0)
                pred_binary_list.append(1 if pred_label_ids[ii] ==
                                        seen_class_list_size else 0)
            overlap = 0
            for i in range(len(gold_binary_list)):
                if gold_binary_list[i] == 1 and pred_binary_list[i] == 1:
                    overlap += 1
            recall = overlap / (1e-6 + sum(gold_binary_list))
            precision = overlap / (1e-6 + sum(pred_binary_list))

            acc_i = 2 * recall * precision / (1e-6 + recall + precision)
            acc_each_round.append(acc_i)

    print('\n\t\t test_acc:', acc_each_round)
    final_test_performance = acc_each_round

    print('final_test_performance:', final_test_performance)
def main():
    parser = HfArgumentParser(TrainingArguments)
    args: TrainingArguments = parser.parse_args_into_dataclasses()[0]
    # Prepare output directory
    if not args.do_eval:
        args.output_dir = os.path.join(
            args.output_dir,
            list(filter(None,
                        args.model.strip().split("/")))[-1] + "-" +
            datetime.now().strftime("%Y%m%d_%H%M%S"))
        os.mkdir(args.output_dir)
    logger = init_logger("souhu-text-match-2021", args.output_dir)
    logger.info(f"Output dir: {args.output_dir}")

    # # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    # args.n_gpu = 1

    logger.info(f"device: {device}, n_gpu: {args.n_gpu}")
    logger.info(f"Training arguments: {args}")

    set_seed(args)
    train_dataloader = create_batch_iter(args, "train", logger)
    valid_dataloader = create_batch_iter(args, "valid", logger)

    model_dir = "/home/zhuminghao/work/model/pt/longformer/"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    config = AutoConfig.from_pretrained(model_dir,
                                        num_labels=2,
                                        return_dict=True)
    model = LongformerForClassification(config, model_dir)
    model.to(device)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # param_optimizer = list(model.named_parameters())
    # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # optimizer_grouped_parameters = [
    #     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    #     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    # ]

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    # scheduler = lr_scheduler.StepLR(optimizer, 2)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='min',
                                               factor=0.1,
                                               patience=2)
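    # ReduceLROnPlateau is stepped with the eval loss after each evaluation
    # (scheduler.step(result["eval_loss"]) below), cutting the LR by 10x after
    # two evaluations without improvement.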
    pgd = PGD(model)
    K = 3
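    # PGD adversarial training: K is the number of inner attack steps per batch.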

    # Train and evaluate
    global_step = 0
    best_dev_f1, best_epoch = float("-inf"), float("-inf")
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

    train_loss2plot = []
    train_acc2plot = []
    train_f1_2plot = []
    eval_loss2plot = []
    eval_acc2plot = []
    eval_f1_2plot = []
    for epoch_ in trange(int(args.num_train_epochs), desc="Epoch", ascii=True):
        tr_loss = 0.
        train_logits = []
        train_labels = []

        model.train()

        # try:
        #     with tqdm(train_dataloader, desc=f"Epoch {epoch_ + 1} iteration", ascii=True, position=0) as tq:
        # tqdm did not stay on a single line; the two workarounds below were found, and the current fix is to pass ascii=True
        # https://blog.csdn.net/martinkeith/article/details/115668425
        # https://blog.csdn.net/weixin_42138078/article/details/81215207
        for step, batch in enumerate(
                tqdm(train_dataloader,
                     desc=f"Epoch {epoch_ + 1} iteration",
                     ascii=True)):
            # for step, batch in enumerate(tq):
            sources, targets, labels = batch
            inputs = list(zip(sources, targets))
            labels = torch.tensor([int(label) for label in labels],
                                  dtype=torch.long)
            pt_batch = tokenizer(inputs,
                                 padding=True,
                                 truncation="longest_first",
                                 max_length=args.max_seq_length,
                                 return_tensors="pt")
            pt_batch = pt_batch.to(device)
            labels = labels.to(device)

            outputs = model(**pt_batch, labels=labels, return_dict=True)
            train_logits.append(outputs.logits)
            train_labels.append(labels)

            loss = outputs.loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()  # backward pass to obtain the normal (clean) gradients

            if args.do_adversarial:
                # adversarial training (PGD)
                pgd.backup_grad()

                for t in range(K):
                    pgd.attack(is_first_attack=(
                        t == 0))  # add an adversarial perturbation to the embeddings; back up param.data on the first attack
                    if t != K - 1:
                        model.zero_grad()
                    else:
                        pgd.restore_grad()
                    adv_outputs = model(**pt_batch,
                                        labels=labels,
                                        return_dict=True)
                    adv_loss = adv_outputs.loss
                    if args.n_gpu > 1:
                        adv_loss = adv_loss.mean()
                    adv_loss.backward()  # backward pass: accumulate the adversarial gradients on top of the clean ones
                pgd.restore()  # restore the original embedding parameters

            # gradient step: update the parameters
            optimizer.step()
            optimizer.zero_grad()

            tr_loss += loss.item()
            global_step += 1

            if (step + 1) % args.gradient_accumulation_steps == 0:
                pass
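            # Placeholder branch: gradient accumulation is not actually applied --
            # optimizer.step() above already runs on every batch.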

            if (global_step + 1) % args.eval_step == 0:
                logger.info("***** Running evaluation *****")
                logger.info("  Process = {} iter {} step".format(
                    epoch_, global_step))
                logger.info("  Batch size = %d", args.eval_batch_size)
                logger.info(
                    f"next step learning rate = {optimizer.param_groups[0]['lr']:.8f}"
                )

                all_train_logits = torch.cat(train_logits, dim=0).cpu()
                all_train_labels = torch.cat(train_labels, dim=0).cpu()
                acc, prf = evaluate(all_train_logits, all_train_labels)

                train_loss2plot.append(loss.item())
                train_acc2plot.append(acc)
                train_f1_2plot.append(prf[2])

                loss = tr_loss / (step + 1)

                result = do_eval(args, model, tokenizer, valid_dataloader,
                                 device, epoch_, args.num_train_epochs, "eval",
                                 logger)
                scheduler.step(result["eval_loss"])
                eval_loss2plot.append(result["eval_loss"])
                eval_acc2plot.append(result["eval_acc"])
                eval_f1_2plot.append((result["eval_f1"]))

                result['global_step'] = global_step
                result['train_loss'] = loss

                result_to_file(result, output_eval_file, logger)

                if args.do_eval:
                    save_model = False
                else:
                    save_model = False
                    if result['eval_f1'] > best_dev_f1:
                        best_dev_f1 = result['eval_f1']
                        best_epoch = epoch_ + 1
                        save_model = True

                if save_model:
                    logger.info("***** Save model *****")
                    best_model = model
                    model_to_save = model.module if hasattr(
                        best_model, 'module') else best_model

                    output_model_file = os.path.join(args.output_dir,
                                                     "pytorch_model.bin")
                    output_config_file = os.path.join(args.output_dir,
                                                      "config.json")

                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)
        # except KeyboardInterrupt:
        #     tq.close()
        #     raise
        # tq.close()

    logger.info(f"best epoch: {best_epoch}, best eval f1:{best_dev_f1:.4f}")

    loss_acc_plot([
        train_loss2plot, train_acc2plot, train_f1_2plot, eval_loss2plot,
        eval_acc2plot, eval_f1_2plot
    ], os.path.join(args.output_dir, "loss_acc_f1.png"))
    logger.info(f"output dir: {args.output_dir}")
Beispiel #25
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=float,
                        default=5,
                        help="number of labeled examples per class (k-shot)")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--use_mixup",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--beta_sampling_times',
                        type=int,
                        default=10,
                        help="number of Beta-distribution samples drawn for mixup")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples, dev_examples, test_examples, label_list = processor.load_FewRel_data(
        args.kshot)

    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), 'test size:', len(test_examples))

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
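    # Exclude bias and LayerNorm parameters from weight decay (standard practice when fine-tuning BERT-style models).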
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)
        '''load dev set'''
        dev_features = convert_examples_to_features(
            dev_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_span_a_mask = torch.tensor(
            [f.span_a_mask for f in dev_features], dtype=torch.float)
        dev_all_span_b_mask = torch.tensor(
            [f.span_b_mask for f in dev_features], dtype=torch.float)

        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_span_a_mask,
                                 dev_all_span_b_mask, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        test_features = convert_examples_to_features(
            test_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        eval_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                          dtype=torch.long)
        eval_all_input_mask = torch.tensor(
            [f.input_mask for f in test_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor(
            [f.segment_ids for f in test_features], dtype=torch.long)
        eval_all_span_a_mask = torch.tensor(
            [f.span_a_mask for f in test_features], dtype=torch.float)
        eval_all_span_b_mask = torch.tensor(
            [f.span_b_mask for f in test_features], dtype=torch.float)
        # eval_all_pair_ids = [f.pair_id for f in test_features]
        eval_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                          dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask,
                                  eval_all_segment_ids, eval_all_span_a_mask,
                                  eval_all_span_b_mask, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        test_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        # logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_span_a_mask = torch.tensor([f.span_a_mask for f in train_features],
                                       dtype=torch.float)
        all_span_b_mask = torch.tensor([f.span_b_mask for f in train_features],
                                       dtype=torch.float)

        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_span_a_mask,
                                   all_span_b_mask, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, span_a_mask, span_b_mask, label_ids = batch

                #input_ids, input_mask, span_a_mask, span_b_mask
                logits = model(input_ids, input_mask, span_a_mask, span_b_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    # if iter_co % (len(train_dataloader)//2)==0:
                    '''
                    start evaluating on the dev set after this epoch
                    '''
                    model.eval()

                    for idd, dev_or_test_dataloader in enumerate(
                        [dev_dataloader, test_dataloader]):

                        if idd == 0:
                            logger.info("***** Running dev *****")
                            logger.info("  Num examples = %d",
                                        len(dev_features))
                        else:
                            logger.info("***** Running test *****")
                            logger.info("  Num examples = %d",
                                        len(test_features))
                        # logger.info("  Batch size = %d", args.eval_batch_size)

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        # print('Evaluating...')
                        for input_ids, input_mask, segment_ids, span_a_mask, span_b_mask, label_ids in dev_or_test_dataloader:
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            span_a_mask = span_a_mask.to(device)
                            span_b_mask = span_b_mask.to(device)
                            label_ids = label_ids.to(device)
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                logits = model(input_ids, input_mask,
                                               span_a_mask, span_b_mask)
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        preds = preds[0]

                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids = list(np.argmax(pred_probs, axis=1))

                        assert len(pred_label_ids) == len(gold_label_ids)
                        hit_co = 0
                        for k in range(len(pred_label_ids)):
                            if pred_label_ids[k] == gold_label_ids[k]:
                                hit_co += 1
                        test_acc = hit_co / len(gold_label_ids)
                        f1 = test_acc

                        if idd == 0:  # this is dev
                            if f1 > max_dev_acc:
                                max_dev_acc = f1
                                print('\ndev acc :', f1, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                # '''store the model, because we can test after a max_dev acc reached'''
                                # model_to_save = (
                                #     model.module if hasattr(model, "module") else model
                                # )  # Take care of distributed/parallel training
                                # store_transformers_models(model_to_save, tokenizer, '/export/home/Dataset/BERT_pretrained_mine/event_2_nli', 'mnli_mypretrained_f1_'+str(max_dev_acc)+'.pt')

                            else:
                                print('\ndev acc :', f1, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                break
                        else:  # this is test
                            if f1 > max_test_acc:
                                max_test_acc = f1
                            final_test_performance = f1
                            print('\ntest acc:', f1, ' max_test_acc:',
                                  max_test_acc, '\n')
        print('final_test_f1:', final_test_performance)
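# The argument parser above defines --warmup_proportion, but the training loop never builds a
# learning-rate scheduler. A minimal sketch of how linear warmup could be wired in with the
# transformers helper is shown below; build_warmup_scheduler is a hypothetical helper, not part
# of the original script, and scheduler.step() would then be called right after optimizer.step().
from transformers import get_linear_schedule_with_warmup


def build_warmup_scheduler(optimizer, train_dataloader, args):
    """Linear-warmup schedule sized from the dataloader length, accumulation steps and epoch count."""
    num_training_steps = int(
        len(train_dataloader) // args.gradient_accumulation_steps *
        args.num_train_epochs)
    num_warmup_steps = int(num_training_steps * args.warmup_proportion)
    return get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps)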
Beispiel #26
0
class Framework(object):
    """A framework wrapping the Relational Graph Extraction model. This framework allows to train, predict, evaluate,
    saving and loading the model with a single line of code.
    """
    def __init__(self, **config):
        super().__init__()

        self.config = config

        self.grad_acc = self.config[
            'grad_acc'] if 'grad_acc' in self.config else 1
        self.device = torch.device(self.config['device'])
        if isinstance(self.config['model'], str):
            self.model = MODELS[self.config['model']](**self.config)
        else:
            self.model = self.config['model']

        self.class_weights = torch.tensor(self.config['class_weights']).float(
        ) if 'class_weights' in self.config else torch.ones(
            self.config['n_rel'])
        if 'lambda' in self.config:
            self.class_weights[0] = self.config['lambda']
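            # 'lambda' re-weights class 0 in the loss; judging from the metrics below, index 0 is the NO-RELATION class.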
        self.loss_fn = nn.CrossEntropyLoss(weight=self.class_weights.to(
            self.device),
                                           reduction='mean')
        if self.config['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                self.model.get_parameters(self.config.get('l2', .01)),
                lr=self.config['lr'],
                momentum=self.config.get('momentum', 0),
                nesterov=self.config.get('nesterov', False))
        elif self.config['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.get_parameters(
                self.config.get('l2', .01)),
                                              lr=self.config['lr'])
        elif self.config['optimizer'] == 'AdamW':
            self.optimizer = AdamW(self.model.get_parameters(
                self.config.get('l2', .01)),
                                   lr=self.config['lr'])
        else:
            raise Exception('The optimizer must be SGD, Adam or AdamW')

    def _train_step(self, dataset, epoch, scheduler=None):
        print("Training:")
        self.model.train()

        total_loss = 0
        predictions, labels, positions = [], [], []
        precision = recall = fscore = 0.0
        progress = tqdm(
            enumerate(dataset),
            desc=
            f"Epoch: {epoch} - Loss: {0.0} - P/R/F: {precision}/{recall}/{fscore}",
            total=len(dataset))
        for i, batch in progress:
            # uncompress the batch
            seq, mask, ent, label = batch
            seq = seq.to(self.device)
            mask = mask.to(self.device)
            ent = ent.to(self.device)
            label = label.to(self.device)

            #self.optimizer.zero_grad()
            output = self.model(seq, mask, ent)
            loss = self.loss_fn(output, label)
            total_loss += loss.item()

            if self.config['half']:
                with amp.scale_loss(loss, self.optimizer) as scale_loss:
                    scale_loss.backward()
            else:
                loss.backward()

            if (i + 1) % self.grad_acc == 0:
                if self.config.get('grad_clip', False):
                    if self.config['half']:
                        nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.config['grad_clip'])
                    else:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.config['grad_clip'])

                self.optimizer.step()
                self.model.zero_grad()
                if scheduler:
                    scheduler.step()

            # Evaluate results
            pre, lab, pos = dataset.evaluate(
                i,
                output.detach().numpy() if self.config['device'] == 'cpu' else
                output.detach().cpu().numpy())

            predictions.extend(pre)
            labels.extend(lab)
            positions.extend(pos)

            if (i + 1) % 10 == 0:
                precision, recall, fscore, _ = precision_recall_fscore_support(
                    np.array(labels),
                    np.array(predictions),
                    average='micro',
                    labels=list(range(1, self.model.n_rel)))

            progress.set_description(
                f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f} - P/R/F: {precision:.2f}/{recall:.2f}/{fscore:.2f}"
            )

        # For last iteration
        #self.optimizer.step()
        #self.optimizer.zero_grad()

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='micro',
            labels=list(range(1, self.model.n_rel)))
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )

        return total_loss / (i + 1)

    def _val_step(self, dataset, epoch):
        print("Validating:")
        self.model.eval()

        predictions, labels, positions = [], [], []
        total_loss = 0
        with torch.no_grad():
            progress = tqdm(enumerate(dataset),
                            desc=f"Epoch: {epoch} - Loss: {0.0}",
                            total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                loss = self.loss_fn(output, label)
                total_loss += loss.item()

                # Evaluate results
                pre, lab, pos = dataset.evaluate(
                    i,
                    output.detach().numpy() if self.config['device'] == 'cpu'
                    else output.detach().cpu().numpy())

                predictions.extend(pre)
                labels.extend(lab)
                positions.extend(pos)

                progress.set_description(
                    f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f}")

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='micro',
            labels=list(range(1, self.model.n_rel)))
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        noprecision, norecall, nofscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {noprecision:.3f} - Recall: {norecall:.3f} - F-Score: {nofscore:.3f}"
        )

        return total_loss / (i + 1), precision, recall, fscore

    def _save_checkpoint(self, dataset, epoch, loss, val_loss):
        print(f"Saving checkpoint ({dataset.name}.pth) ...")
        PATH = os.path.join('checkpoints', f"{dataset.name}.pth")
        config_PATH = os.path.join('checkpoints',
                                   f"{dataset.name}_config.json")
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'val_loss': val_loss
            }, PATH)
        with open(config_PATH, 'wt') as f:
            json.dump(self.config, f)

    def _load_checkpoint(self, PATH: str, config_PATH: str):
        checkpoint = torch.load(PATH)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        with open(config_PATH, 'rt') as f:
            self.config = json.load(f)

        return epoch, loss

    def fit(self,
            dataset,
            validation=True,
            batch_size=1,
            patience=3,
            delta=0.):
        """ Fits the model to the given dataset.

        Usage:
        ```
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)
        ```
        """
        self.model.to(self.device)
        train_data = dataset.get_train(batch_size)

        if self.config['half']:
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level='O2',
                keep_batchnorm_fp32=True)

        if self.config['linear_scheduler']:
            num_training_steps = int(
                len(train_data) // self.grad_acc * self.config['epochs'])
            scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config.get('warmup_steps', 0),
                num_training_steps=num_training_steps)
        else:
            scheduler = None

        early_stopping = EarlyStopping(patience, delta, self._save_checkpoint)

        for epoch in range(self.config['epochs']):
            self.optimizer.zero_grad()
            loss = self._train_step(train_data, epoch, scheduler=scheduler)
            if validation:
                val_loss, _, _, _ = self._val_step(dataset.get_val(batch_size),
                                                   epoch)
                if early_stopping(val_loss,
                                  dataset=dataset,
                                  epoch=epoch,
                                  loss=loss):
                    break

        # Recover the best epoch
        path = os.path.join("checkpoints", f"{dataset.name}.pth")
        config_path = os.path.join("checkpoints",
                                   f"{dataset.name}_config.json")
        _, _ = self._load_checkpoint(path, config_path)

    def predict(self, dataset, return_proba=False) -> torch.Tensor:
        """ Predicts the relations graph for the given dataset.
        """
        self.model.to(self.device)
        self.model.eval()

        predictions, instances = [], []
        with torch.no_grad():
            progress = tqdm(enumerate(dataset), total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                if not return_proba:
                    pred = np.argmax(output.detach().cpu().numpy(),
                                     axis=1).tolist()
                else:
                    pred = output.detach().cpu().numpy().tolist()
                inst = dataset.get_instances(i)

                predictions.extend(pred)
                instances.extend(inst)

        return predictions, instances

    def evaluate(self, dataset: Dataset, batch_size=1) -> torch.Tensor:
        """ Evaluates the model given for the given dataset.
        """
        loss, precision, recall, fscore = self._val_step(
            dataset.get_val(batch_size), 0)
        return loss, precision, recall, fscore

    def save_model(self, path: str):
        """ Saves the model to a file.

        Usage:
        ``` 
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)

        >>> rge.save_model("path/to/file")
        ```

        TODO
        """
        self.model.save_pretrained(path)
        with open(f"{path}/fine_tunning.config.json", 'wt') as f:
            json.dump(self.config, f, indent=4)

    @classmethod
    def load_model(cls,
                   path: str,
                   config_path: str = None,
                   from_checkpoint=False):
        """ Loads the model from a file.

        Args:
            path: str Path to the file that stores the model.

        Returns:
            Framework instance with the loaded model.

        Usage:
        ```
        >>> rge = Framework.load_model("path/to/model")
        ```

        TODO
        """
        if not from_checkpoint:
            config_path = path + '/fine_tunning.config.json'
            with open(config_path) as f:
                config = json.load(f)
            config['pretrained_model'] = path
            rge = cls(**config)

        else:
            if config_path is None:
                raise Exception(
                    'Loading the model from a checkpoint requires config_path argument.'
                )
            with open(config_path) as f:
                config = json.load(f)
            rge = cls(**config)
            rge._load_checkpoint(path, config_path)

        return rge
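# Framework.fit above relies on an EarlyStopping helper built as EarlyStopping(patience, delta,
# self._save_checkpoint) and called with the validation loss plus checkpoint kwargs. The sketch
# below is one interface-compatible implementation, assumed rather than taken from the original
# code: it checkpoints on improvement and returns True once `patience` epochs pass without one.
class EarlyStopping:
    def __init__(self, patience, delta, checkpoint_fn):
        self.patience = patience
        self.delta = delta
        self.checkpoint_fn = checkpoint_fn
        self.best_loss = None
        self.counter = 0

    def __call__(self, val_loss, **checkpoint_kwargs):
        improved = self.best_loss is None or val_loss < self.best_loss - self.delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            # Matches _save_checkpoint(dataset, epoch, loss, val_loss) via keyword arguments.
            self.checkpoint_fn(val_loss=val_loss, **checkpoint_kwargs)
            return False
        self.counter += 1
        return self.counter >= self.patience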
Beispiel #27
0
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)
        adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(),
                                       lr=args.G_lr)  # optimizer for generator
        optimizer_D = torch.optim.Adam(
            D.parameters(), lr=args.D_lr)  # optimizer for discriminator
        optimizer_E = AdamW(E.parameters(), args.bert_lr)
        optimizer_detector = torch.optim.Adam(detector.parameters(),
                                              lr=args.detector_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []
        detector_total_train_loss = []

        all_features = []
        result = dict()

        for i in range(args.n_epoch):

            # Initialize model state
            G.train()
            D.train()
            E.train()
            detector.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0
            detector_train_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                ood_sample = (y == 0.0)
                # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta
                # real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

                # the label used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, classification_output = D(
                    real_feature, return_feature=True)
                # discriminator_output = discriminator_output.squeeze()
                real_loss = adversarial_loss(discriminator_output, valid_label)
                real_loss.backward(retain_graph=True)

                if args.do_vis:
                    all_features.append(real_f_vector.detach())

                # # train D on fake
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = adversarial_loss(fake_discriminator_output,
                                             fake_label)
                fake_loss.backward()
                optimizer_D.step()

                # if args.fine_tune:
                #     optimizer_E.step()

                # train G
                optimizer_G.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_f_vector, D_decision = D.detect_only(G(z),
                                                          return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                g_loss = gd_loss + 0 * fm_loss
                g_loss.backward()
                optimizer_G.step()

                optimizer_E.zero_grad()

                # train detector
                optimizer_detector.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                if args.loss == 'v1':
                    loss_fake = adversarial_loss(
                        detector(fake_feature),
                        fake_label)  # fake sample is ood
                else:
                    loss_fake = adversarial_loss_v2(
                        detector(fake_feature),
                        fake_label.long().squeeze())
                if args.loss == 'v1':
                    loss_real = adversarial_loss(detector(real_feature),
                                                 y.float())
                else:
                    loss_real = adversarial_loss_v2(detector(real_feature),
                                                    y.long())
                if args.detect_loss == 'v1':
                    detector_loss = args.beta * loss_fake + (
                        1 - args.beta) * loss_real
                else:
                    detector_loss = args.beta * loss_fake + loss_real
                    detector_loss = args.sigma * detector_loss
                detector_loss.backward()
                optimizer_detector.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()
                detector_train_loss += detector_loss

            logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(
                i, D_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_real_loss: {}'.format(
                i, D_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_class_loss: {}'.format(
                i, D_class_loss / n_sample))
            logger.info('[Epoch {}] Train: G_train_loss: {}'.format(
                i, G_train_loss / n_sample))
            logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(
                i, FM_train_loss / n_sample))
            logger.info('[Epoch {}] Train: detector_train_loss: {}'.format(
                i, detector_train_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)
            detector_total_train_loss.append(detector_train_loss / n_sample)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # 1  -> save the model
                # 0  -> do not save the model
                # -1 -> do not save, patience exceeded: stop early
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E,
                                   path=config['bert_save_path'],
                                   model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        best_dev = -early_stopping.best_score

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features
        return result
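# The validation logging above reports ErrorRateAt95Recall(all_binary_y, y_score), i.e. the
# false-positive rate at 95% recall, a common OOD-detection metric (FPR95). A minimal sketch is
# given below; the exact implementation used by the original script is an assumption.
import numpy as np


def ErrorRateAt95Recall(labels, scores, recall_level=0.95):
    """False-positive rate at the score threshold that first reaches `recall_level` recall
    on the positive class (labels == 1)."""
    labels = np.asarray(labels).astype(int)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores)  # descending score: lowering the threshold adds samples one by one
    labels = labels[order]
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    tp = np.cumsum(labels)
    fp = np.cumsum(1 - labels)
    recall = tp / max(n_pos, 1)
    idx = min(np.searchsorted(recall, recall_level), len(labels) - 1)
    return fp[idx] / max(n_neg, 1)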
Beispiel #28
0
class CXRBERT_Trainer():
    def __init__(self, args, train_dataloader, test_dataloader=None):
        self.args = args

        cuda_condition = torch.cuda.is_available() and args.with_cuda

        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print('Current cuda device ', torch.cuda.current_device())  # check

        if args.weight_load:
            config = AutoConfig.from_pretrained(args.pre_trained_model_path)
            model_state_dict = torch.load(
                os.path.join(args.pre_trained_model_path, 'pytorch_model.bin'))
            self.model = CXRBERT.from_pretrained(args.pre_trained_model_path,
                                                 state_dict=model_state_dict,
                                                 config=config,
                                                 args=args).to(self.device)
            print('training restart with mid epoch')
            print(config)
        else:
            if args.bert_model == "albert-base-v2":
                config = AlbertConfig.from_pretrained(args.bert_model)
            elif args.bert_model == "emilyalsentzer/Bio_ClinicalBERT":
                config = AutoConfig.from_pretrained(args.bert_model)
            elif args.bert_model == "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12":
                config = AutoConfig.from_pretrained(args.bert_model)
            elif args.bert_model == "bert-small-scratch":
                config = BertConfig.from_pretrained(
                    "google/bert_uncased_L-4_H-512_A-8")
            elif args.bert_model == "bert-base-scratch":
                config = BertConfig.from_pretrained("bert-base-uncased")
            else:
                config = BertConfig.from_pretrained(
                    args.bert_model)  # bert-base, small, tiny

            self.model = CXRBERT(config, args).to(self.device)

        wandb.watch(self.model)

        if args.with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model,
                                         device_ids=args.cuda_devices)

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optimizer = AdamW(self.model.parameters(), lr=args.lr)

        self.mlm_criterion = nn.CrossEntropyLoss(ignore_index=-100)
        self.itm_criterion = nn.CrossEntropyLoss()

        self.log_freq = args.log_freq
        self.step_cnt = 0

        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):

        self.model.train()

        train_losses = []
        train_itm_loss = []
        train_mlm_loss = []

        train_data_iter = tqdm.tqdm(enumerate(self.train_data),
                                    desc=f'EP_:{epoch}',
                                    total=len(self.train_data),
                                    bar_format='{l_bar}{r_bar}')
        total_correct = 0
        total_element = 0
        total_mlm_correct = 0
        total_mlm_element = 0

        total_valid_correct = 0
        total_valid_element = 0
        total_mlm_valid_correct = 0
        total_mlm_valid_element = 0

        for i, data in train_data_iter:

            cls_tok, input_ids, txt_labels, attn_masks, img, segment, is_aligned, sep_tok, itm_prob = data

            cls_tok = cls_tok.to(self.device)
            input_ids = input_ids.to(self.device)
            txt_labels = txt_labels.to(self.device)
            attn_masks = attn_masks.to(self.device)
            img = img.to(self.device)
            segment = segment.to(self.device)
            is_aligned = is_aligned.to(self.device)
            sep_tok = sep_tok.to(self.device)

            mlm_output, itm_output = self.model(cls_tok, input_ids, attn_masks,
                                                segment, img, sep_tok)

            if self.args.mlm_task and self.args.itm_task == False:
                mlm_loss = self.mlm_criterion(mlm_output.transpose(1, 2),
                                              txt_labels)
                loss = mlm_loss
                print('only mlm_loss')

            if self.args.itm_task and self.args.mlm_task == False:
                itm_loss = self.itm_criterion(itm_output, is_aligned)
                loss = itm_loss
                print('only itm_loss')

            if self.args.mlm_task and self.args.itm_task:

                mlm_loss = self.mlm_criterion(mlm_output.transpose(1, 2),
                                              txt_labels)
                train_mlm_loss.append(mlm_loss.item())

                itm_loss = self.itm_criterion(itm_output, is_aligned)
                train_itm_loss.append(itm_loss.item())

                loss = itm_loss + mlm_loss

            train_losses.append(loss.item())
            self.optimizer.zero_grad()  # above
            loss.backward()
            self.optimizer.step()

            if self.args.itm_task:
                correct = itm_output.argmax(dim=-1).eq(is_aligned).sum().item()
                total_correct += correct
                total_element += is_aligned.nelement()

            if self.args.mlm_task:
                eq = (mlm_output.argmax(dim=-1).eq(txt_labels)).cpu().numpy()
                txt_labels_np = txt_labels.cpu().numpy()
                for bs, label in enumerate(txt_labels_np):
                    index = np.where(label == -100)[0]
                    f_label = np.delete(label, index)
                    f_eq = np.delete(eq[bs], index)
                    total_mlm_correct += f_eq.sum()
                    total_mlm_element += len(f_label)

        print("avg loss per epoch", np.mean(train_losses))
        print("avg itm acc per epoch",
              round(total_correct / total_element * 100, 3))
        if self.args.mlm_task and self.args.itm_task:
            wandb.log(
                {
                    "avg_loss": np.mean(train_losses),
                    "avg_mlm_loss": np.mean(train_mlm_loss),
                    "avg_itm_loss": np.mean(train_itm_loss),
                    "itm_acc": total_correct / total_element * 100,
                    "mlm_acc": total_mlm_correct / total_mlm_element * 100
                },
                step=epoch)

        if self.args.itm_task and self.args.mlm_task == False:
            wandb.log(
                {
                    "avg_loss": np.mean(train_losses),
                    "itm_epoch_acc": total_correct / total_element * 100
                },
                step=epoch)

        if self.args.mlm_task and self.args.itm_task == False:
            wandb.log(
                {
                    "avg_loss": np.mean(train_losses),
                    "mlm_epoch_acc":
                    total_mlm_correct / total_mlm_element * 100
                },
                step=epoch)

        test_data_iter = tqdm.tqdm(enumerate(self.test_data),
                                   desc=f'EP_:{epoch}',
                                   total=len(self.test_data),
                                   bar_format='{l_bar}{r_bar}')
        self.model.eval()
        with torch.no_grad():
            eval_losses = []
            eval_mlm_loss = []
            eval_itm_loss = []
            for i, data in test_data_iter:
                cls_tok, input_ids, txt_labels, attn_masks, img, segment, is_aligned, sep_tok, itm_prob = data

                cls_tok = cls_tok.to(self.device)
                input_ids = input_ids.to(self.device)
                txt_labels = txt_labels.to(self.device)
                attn_masks = attn_masks.to(self.device)
                img = img.to(self.device)
                segment = segment.to(self.device)
                is_aligned = is_aligned.to(self.device)
                sep_tok = sep_tok.to(self.device)

                mlm_output, itm_output = self.model(cls_tok, input_ids,
                                                    attn_masks, segment, img,
                                                    sep_tok)

                if self.args.mlm_task and self.args.itm_task == False:
                    valid_mlm_loss = self.mlm_criterion(
                        mlm_output.transpose(1, 2), txt_labels)
                    valid_loss = valid_mlm_loss
                    print('only valid mlm loss')

                if self.args.itm_task and self.args.mlm_task == False:
                    valid_itm_loss = self.itm_criterion(itm_output, is_aligned)
                    valid_loss = valid_itm_loss
                    print('only valid itm loss')

                if self.args.mlm_task and self.args.itm_task:
                    # TODO: weight each loss, mlm > itm
                    valid_mlm_loss = self.mlm_criterion(
                        mlm_output.transpose(1, 2), txt_labels)
                    valid_itm_loss = self.itm_criterion(itm_output, is_aligned)
                    eval_mlm_loss.append(valid_mlm_loss.item())
                    eval_itm_loss.append(valid_itm_loss.item())

                    valid_loss = valid_itm_loss + valid_mlm_loss

                eval_losses.append(valid_loss.item())

                if self.args.itm_task:
                    valid_correct = itm_output.argmax(
                        dim=-1).eq(is_aligned).sum().item()
                    total_valid_correct += valid_correct
                    total_valid_element += is_aligned.nelement()

                if self.args.mlm_task:
                    eq = (mlm_output.argmax(
                        dim=-1).eq(txt_labels)).cpu().numpy()
                    txt_labels_np = txt_labels.cpu().numpy()
                    for bs, label in enumerate(txt_labels_np):
                        index = np.where(label == -100)[0]
                        f_label = np.delete(label, index)
                        f_eq = np.delete(eq[bs], index)
                        total_mlm_valid_correct += f_eq.sum()
                        total_mlm_valid_element += len(f_label)

            print("avg loss in testset", np.mean(eval_losses))
            print("avg itm acc in testset",
                  round(total_valid_correct / total_valid_element * 100, 3))

            if self.args.mlm_task and self.args.itm_task:
                wandb.log(
                    {
                        "eval_avg_loss":
                        np.mean(eval_losses),
                        "eval_mlm_loss":
                        np.mean(eval_mlm_loss),
                        "eval_itm_loss":
                        np.mean(eval_itm_loss),
                        "eval_itm_acc":
                        total_valid_correct / total_valid_element * 100,
                        "eval_mlm_acc":
                        total_mlm_valid_correct / total_mlm_valid_element * 100
                    },
                    step=epoch)

            if self.args.itm_task and self.args.mlm_task == False:
                wandb.log(
                    {
                        "eval_avg_loss":
                        np.mean(eval_losses),
                        "eval_itm_epoch_acc":
                        total_valid_correct / total_valid_element * 100
                    },
                    step=epoch)

            if self.args.mlm_task and self.args.itm_task == False:
                wandb.log(
                    {
                        "eval_avg_loss":
                        np.mean(eval_losses),
                        "eval_mlm_epoch_acc":
                        total_mlm_valid_correct / total_mlm_valid_element * 100
                    },
                    step=epoch)

    def save(self, epoch, file_path):
        save_path_per_ep = os.path.join(file_path, str(epoch))
        if not os.path.exists(save_path_per_ep):
            os.mkdir(save_path_per_ep)
            os.chmod(save_path_per_ep, 0o777)

        if torch.cuda.device_count() > 1:
            self.model.module.save_pretrained(save_path_per_ep)
            print(f'Multi_EP: {epoch} Model saved on {save_path_per_ep}')
        else:
            self.model.save_pretrained(save_path_per_ep)
            print(f'Single_EP: {epoch} Model saved on {save_path_per_ep}')
        os.chmod(save_path_per_ep + '/pytorch_model.bin', 0o777)
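# The MLM-accuracy bookkeeping in the trainer above loops over every example and uses np.delete
# to drop the ignored (-100) positions. An equivalent, fully vectorised computation is sketched
# below as a hypothetical helper; it is not part of the original trainer.
import torch


def masked_lm_accuracy(mlm_output, txt_labels, ignore_index=-100):
    """Return (correct, total) counted only over the non-ignored token positions."""
    preds = mlm_output.argmax(dim=-1)
    mask = txt_labels.ne(ignore_index)
    correct = (preds.eq(txt_labels) & mask).sum().item()
    total = mask.sum().item()
    return correct, total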
Beispiel #29
0
                loss = loss / Config.gradient_accumulation_steps
            loss.backward()

            nb_tr_steps += 1
            tr_mask_acc.update(mask_metric.value(), n=input_ids.size(0))
            tr_sop_acc.update(sop_metric.value(), n=input_ids.size(0))
            tr_loss.update(loss.item(), n=1)
            tr_mask_loss.update(masked_lm_loss.item(), n=1)
            tr_sop_loss.update(next_sentence_loss.item(), n=1)

            if (step + 1) % Config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               Config.max_grad_norm)
                # update the weights before advancing the LR schedule
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
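            # Gradient accumulation: the loss is scaled by 1/gradient_accumulation_steps
            # above and the weights are only updated every gradient_accumulation_steps
            # micro-batches, so the effective batch size is
            # batch_size * gradient_accumulation_steps.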

            if global_step % Config.num_save_steps == 0:
                model_to_save = model.module if hasattr(model,
                                                        'module') else model
                output_model_file = os.path.join(
                    Config.output_dir,
                    'pytorch_model_epoch{}.bin'.format(global_step))
                torch.save(model_to_save.state_dict(), output_model_file)

                # save config
                output_config_file = os.path.join(Config.output_dir,
                                                  "config.json")
                with open(str(output_config_file), 'w') as f:
                    f.write(model_to_save.config.to_json_string())
                # save vocab
Beispiel #30
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        # distributed setup: each process drives a single GPU
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, n_gpu)
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
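    # Example of the arithmetic above: per_gpu_train_batch_size=16 on (e.g.) 2 GPUs
    # gives train_batch_size=32; with gradient_accumulation_steps=1 it stays 32 per step.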

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()

    num_labels = len(["entailment", "neutral", "contradiction"])
    # pretrain_model_dir = 'roberta-large' #'roberta-large' , 'roberta-large-mnli'
    pretrain_model_dir = '/export/home/Dataset/BERT_pretrained_mine/TrainedModelReminder/RoBERTa_on_MNLI_SNLI_SciTail_RTE_ANLI_SpecialToken_epoch_2_acc_4.156359461121103'  #'roberta-large' , 'roberta-large-mnli'
    model = RobertaForSequenceClassification.from_pretrained(
        pretrain_model_dir, num_labels=num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
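    # Parameters matching the no_decay patterns (biases and LayerNorm weights) are
    # excluded from weight decay, the usual grouping for AdamW fine-tuning. Note that
    # args.warmup_proportion is parsed above but no LR scheduler is created in this
    # snippet; if warmup were wanted, a typical (illustrative) pairing would be:
    #   from transformers import get_linear_schedule_with_warmup
    #   scheduler = get_linear_schedule_with_warmup(
    #       optimizer,
    #       num_warmup_steps=int(args.warmup_proportion * num_train_optimization_steps),
    #       num_training_steps=num_train_optimization_steps)
    # (num_train_optimization_steps is computed further down), with scheduler.step()
    # called after each optimizer.step().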

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    #MNLI-SNLI-SciTail-RTE-SICK
    train_examples_MNLI, dev_examples_MNLI = processor.get_MNLI_train_and_dev(
        '/export/home/Dataset/glue_data/MNLI/train.tsv',
        '/export/home/Dataset/glue_data/MNLI/dev_mismatched.tsv'
    )  #train_pu_half_v1.txt
    train_examples_SNLI, dev_examples_SNLI = processor.get_SNLI_train_and_dev(
        '/export/home/Dataset/glue_data/SNLI/train.tsv',
        '/export/home/Dataset/glue_data/SNLI/dev.tsv')
    train_examples_SciTail, dev_examples_SciTail = processor.get_SciTail_train_and_dev(
        '/export/home/Dataset/SciTailV1/tsv_format/scitail_1.0_train.tsv',
        '/export/home/Dataset/SciTailV1/tsv_format/scitail_1.0_dev.tsv')
    train_examples_RTE, dev_examples_RTE = processor.get_RTE_train_and_dev(
        '/export/home/Dataset/glue_data/RTE/train.tsv',
        '/export/home/Dataset/glue_data/RTE/dev.tsv')
    train_examples_ANLI, dev_examples_ANLI = processor.get_ANLI_train_and_dev(
        'train', 'dev',
        '/export/home/Dataset/para_entail_datasets/ANLI/anli_v0.1/')

    train_examples = train_examples_MNLI + train_examples_SNLI + train_examples_SciTail + train_examples_RTE + train_examples_ANLI
    dev_examples_list = [
        dev_examples_MNLI, dev_examples_SNLI, dev_examples_SciTail,
        dev_examples_RTE, dev_examples_ANLI
    ]

    dev_task_label = [0, 0, 1, 1, 0]
    task_names = ['MNLI', 'SNLI', 'SciTail', 'RTE', 'ANLI']
    '''filter challenging neighbors'''
    neighbor_id_list = []
    readfile = codecs.open('neighbors_indices_before_dropout_eud.v3.txt', 'r',
                           'utf-8')
    for line in readfile:
        neighbor_id_list.append(int(line.strip()))
    readfile.close()
    print('neighbor_id_list size:', len(neighbor_id_list))
    truncated_train_examples = [train_examples[i] for i in neighbor_id_list]
    train_examples = truncated_train_examples

    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )
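    # For example (hypothetical numbers): 96,000 filtered examples with
    # train_batch_size=32 and gradient_accumulation_steps=1 give 3,000 optimizer
    # updates per epoch, i.e. 9,000 steps over num_train_epochs=3.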

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    train_features = convert_examples_to_features(
        train_examples,
        label_list,
        args.max_seq_length,
        tokenizer,
        output_mode,
        cls_token_at_end=
        False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=
        True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
        pad_on_left=
        False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0)  #4 if args.model_type in ['xlnet'] else 0,)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)
    all_task_label_ids = torch.tensor([f.task_label for f in train_features],
                                      dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids, all_task_label_ids)
    train_sampler = RandomSampler(train_data)

    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  drop_last=True)
    '''dev data to features'''
    valid_dataloader_list = []
    for valid_examples_i in dev_examples_list:
        valid_features = convert_examples_to_features(
            valid_examples_i,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        logger.info("***** valid_examples *****")
        logger.info("  Num examples = %d", len(valid_examples_i))
        valid_input_ids = torch.tensor([f.input_ids for f in valid_features],
                                       dtype=torch.long)
        valid_input_mask = torch.tensor([f.input_mask for f in valid_features],
                                        dtype=torch.long)
        valid_segment_ids = torch.tensor(
            [f.segment_ids for f in valid_features], dtype=torch.long)
        valid_label_ids = torch.tensor([f.label_id for f in valid_features],
                                       dtype=torch.long)
        valid_task_label_ids = torch.tensor(
            [f.task_label for f in valid_features], dtype=torch.long)

        valid_data = TensorDataset(valid_input_ids, valid_input_mask,
                                   valid_segment_ids, valid_label_ids,
                                   valid_task_label_ids)
        valid_sampler = SequentialSampler(valid_data)
        valid_dataloader = DataLoader(valid_data,
                                      sampler=valid_sampler,
                                      batch_size=args.eval_batch_size)
        valid_dataloader_list.append(valid_dataloader)

    iter_co = 0
    for epoch_i in trange(int(args.num_train_epochs), desc="Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, task_label_ids = batch
            logits = model(input_ids,
                           attention_mask=input_mask,
                           token_type_ids=None,
                           labels=None)

            prob_matrix = F.log_softmax(logits[0].view(-1, num_labels), dim=1)
            '''multiplying by 1.0 clones the tensor, so the in-place index assignment
            below does not overwrite prob_matrix, which autograd still needs for the
            backward pass; skipping this step causes a bug'''
            new_prob_matrix = prob_matrix * 1.0
            '''change the entail prob to p or 1-p'''
            changed_places = torch.nonzero(task_label_ids, as_tuple=False)
            new_prob_matrix[changed_places,
                            0] = 1.0 - prob_matrix[changed_places, 0]

            loss = F.nll_loss(new_prob_matrix, label_ids.view(-1))

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            iter_co += 1
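            # Note: optimizer.step() runs after every batch in this loop; the loss
            # division above only matters when gradient_accumulation_steps > 1, in
            # which case the update would normally be gated on
            # (step + 1) % args.gradient_accumulation_steps == 0.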

            # if iter_co % len(train_dataloader) ==0:
            if iter_co % (len(train_dataloader) // 5) == 0:
                '''
                evaluate on the dev sets (roughly five times per epoch)
                '''
                # if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
                #     model = torch.nn.DataParallel(model)
                model.eval()
                for m in model.modules():
                    if isinstance(m, torch.nn.BatchNorm2d):
                        m.track_running_stats = False
                # logger.info("***** Running evaluation *****")
                # logger.info("  Num examples = %d", len(valid_examples_MNLI))
                # logger.info("  Batch size = %d", args.eval_batch_size)

                dev_acc_sum = 0.0
                for idd, valid_dataloader in enumerate(valid_dataloader_list):
                    task_label = dev_task_label[idd]
                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    # print('Evaluating...', task_label)
                    # for _, batch in enumerate(tqdm(valid_dataloader, desc=task_names[idd])):
                    for _, batch in enumerate(valid_dataloader):
                        batch = tuple(t.to(device) for t in batch)
                        input_ids, input_mask, segment_ids, label_ids, task_label_ids = batch
                        if task_label == 0:
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())
                        else:
                            '''SciTail, RTE'''
                            task_label_ids_list = list(
                                task_label_ids.detach().cpu().numpy())
                            gold_label_batch_fake = list(
                                label_ids.detach().cpu().numpy())
                            for ex_id, label_id in enumerate(
                                    gold_label_batch_fake):
                                if task_label_ids_list[ex_id] == 0:
                                    gold_label_ids.append(label_id)  #0
                                else:
                                    gold_label_ids.append(1)  #1
                        with torch.no_grad():
                            logits = model(input_ids=input_ids,
                                           attention_mask=input_mask,
                                           token_type_ids=None,
                                           labels=None)
                        logits = logits[0]
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]
                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = np.argmax(pred_probs, axis=1)
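                    # For the binary tasks (SciTail, RTE) the 3-way prediction is
                    # folded into 2 classes: 'entailment' stays 0, while 'neutral'
                    # and 'contradiction' (ids 1 and 2) both map to 'not entailment' (1).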
                    if task_label == 0:
                        '''3-way tasks MNLI, SNLI, ANLI'''
                        pred_label_ids = pred_label_ids_3way
                    else:
                        '''SciTail, RTE'''
                        pred_label_ids = []
                        for pred_label_i in pred_label_ids_3way:
                            if pred_label_i == 0:
                                pred_label_ids.append(0)
                            else:
                                pred_label_ids.append(1)
                    assert len(pred_label_ids) == len(gold_label_ids)
                    hit_co = 0
                    for k in range(len(pred_label_ids)):
                        if pred_label_ids[k] == gold_label_ids[k]:
                            hit_co += 1
                    test_acc = hit_co / len(gold_label_ids)
                    dev_acc_sum += test_acc
                    print(task_names[idd], ' dev acc:', test_acc)
                '''store the model so it can be tested later once a new max dev acc is reached'''
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                store_transformers_models(
                    model_to_save, tokenizer,
                    '/export/home/Dataset/BERT_pretrained_mine/TrainedModelReminder/',
                    'RoBERTa_on_MNLI_SNLI_SciTail_RTE_ANLI_SpecialToken_Filter_1_epoch_'
                    + str(epoch_i) + '_acc_' + str(dev_acc_sum))