def test_constant_scheduler(self):
        scheduler = get_constant_schedule(self.optimizer)
        lrs = unwrap_schedule(scheduler, self.num_steps)
        expected_learning_rates = [10.] * self.num_steps
        self.assertEqual(len(lrs[0]), 1)
        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

        scheduler = get_constant_schedule(self.optimizer)
        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
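The test above relies on helpers such as unwrap_schedule from the test harness; a minimal sketch of what that helper might look like (the actual implementation in the test suite may differ) is:

def unwrap_schedule(scheduler, num_steps=10):
    # Record the per-parameter-group learning rates over num_steps scheduler steps.
    lrs = []
    for _ in range(num_steps):
        lrs.append(scheduler.get_lr())
        scheduler.step()
    return lrs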
Example #2
    def configure_optimizers(self):
        if self.hparams.optimize == 'basic':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
            scheduler = get_constant_schedule(optimizer)

        elif self.hparams.optimize == 'bert':
            # Copied from: https://huggingface.co/transformers/training.html
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                self.hparams.weight_decay
            }, {
                'params': [
                    p for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.
            }]
            optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.lr)

            self.num_warmup_steps = int(self.num_train_steps *
                                        self.hparams.warmup_proportion)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.num_warmup_steps,
                num_training_steps=self.num_train_steps)
        else:
            raise ValueError(f"Unknown optimize setting: {self.hparams.optimize}")

        return [optimizer], [scheduler]
   def configure_optimizers(self):
      "Prepare optimizer and schedule (linear warmup and decay)"
      model = self.model
      no_decay = ["bias", "LayerNorm.weight"]
      optimizer_grouped_parameters = [
          {
              "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
              "weight_decay": self.hparams.weight_decay,
          },
          {
              "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
              "weight_decay": 0.0,
          },
      ]
      # Original optimizer from Transformers. It works but needs warmup.
      # optimizer = transformers.AdamW(optimizer_grouped_parameters,
      #      lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
      # The RAdam optimizer works approximately as well as Ranger.
      #optimizer = RAdam(optimizer_grouped_parameters,
      #      lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
      # The Ranger optimizer is the combination of RAdam and Lookahead. It
      # works well for this task. The best conditions seem to be learning
      # rate 1e-4 w/ RAdam or Ranger, gradient accumulation of 2 batches.
      optimizer = ranger.Ranger(optimizer_grouped_parameters,
            lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

      # The constant scheduler does nothing. Replace with another
      # scheduler if required.
      scheduler = transformers.get_constant_schedule(optimizer)
      scheduler = {
          'scheduler': scheduler,
          'interval': 'step',
          'frequency': 1
      }
      return [optimizer], [scheduler]
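The comment in the example above notes that the constant scheduler does nothing and should be replaced when warmup is needed; a minimal sketch of that swap, keeping the same Lightning scheduler dict (the warmup length is an arbitrary example value):

scheduler = transformers.get_constant_schedule_with_warmup(
    optimizer, num_warmup_steps=500)  # 500 warmup steps is an illustrative assumption
scheduler = {
    'scheduler': scheduler,
    'interval': 'step',
    'frequency': 1
}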
Example #4
 def init_optimizer(self, model, lr):
     args = self.args
     no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [{
         "params": [
             p for n, p in model.named_parameters()
             if not any(nd in n for nd in no_decay)
         ],
         "weight_decay":
         args.weight_decay
     }, {
         "params": [
             p for n, p in model.named_parameters()
             if any(nd in n for nd in no_decay)
         ],
         "weight_decay":
         0.0
     }]
     # TODO calculate t_total
     optimizer = AdamW(optimizer_grouped_parameters,
                       lr=lr,
                       eps=args.adam_epsilon)
     # scheduler = WarmupLinearSchedule(
     #   optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
     scheduler = get_constant_schedule(optimizer)
     return optimizer_grouped_parameters, optimizer, scheduler
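The "# TODO calculate t_total" above would be needed to re-enable the commented warmup schedule. A hypothetical sketch, assuming a train_dataloader plus args.num_train_epochs, args.gradient_accumulation_steps, and args.warmup_steps are available (none of these appear in the original snippet):

t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=args.warmup_steps,
                                            num_training_steps=t_total)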
Example #5
def get_scheduler(optimizer, scheduler: str, warmup_steps: int,
                  num_total: int):
    assert scheduler in [
        "constantlr", "warmuplinear", "warmupconstant", "warmupcosine",
        "warmupcosinewithhardrestarts"
    ], ('scheduler should be one of ["constantlr", "warmuplinear", "warmupconstant", '
        '"warmupcosine", "warmupcosinewithhardrestarts"]')
    if scheduler == 'constantlr':
        return transformers.get_constant_schedule(optimizer)
    elif scheduler == 'warmupconstant':
        return transformers.get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps)
    elif scheduler == 'warmuplinear':
        return transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_total)
    elif scheduler == 'warmupcosine':
        return transformers.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_total)
    elif scheduler == 'warmupcosinewithhardrestarts':
        return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_total)
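A short usage sketch for get_scheduler above; the model, optimizer, and step counts are placeholder assumptions, not part of the original snippet:

optimizer = AdamW(model.parameters(), lr=2e-5)  # `model` is assumed to exist
scheduler = get_scheduler(optimizer, scheduler="warmuplinear",
                          warmup_steps=100, num_total=1000)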
 def _get_scheduler(self, optimizer, scheduler: str, warmup_steps: int,
                    t_total: int):
     """
     Returns the correct learning rate scheduler
     """
     scheduler = scheduler.lower()
     if scheduler == 'constantlr':
         return transformers.get_constant_schedule(optimizer)
     elif scheduler == 'warmupconstant':
         return transformers.get_constant_schedule_with_warmup(
             optimizer, num_warmup_steps=warmup_steps)
     elif scheduler == 'warmuplinear':
         return transformers.get_linear_schedule_with_warmup(
             optimizer,
             num_warmup_steps=warmup_steps,
             num_training_steps=t_total)
     elif scheduler == 'warmupcosine':
         return transformers.get_cosine_schedule_with_warmup(
             optimizer,
             num_warmup_steps=warmup_steps,
             num_training_steps=t_total)
     elif scheduler == 'warmupcosinewithhardrestarts':
         return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer,
             num_warmup_steps=warmup_steps,
             num_training_steps=t_total)
     else:
         raise ValueError("Unknown scheduler {}".format(scheduler))
    def configure_optimizers(self):
        optimizers = [
            LookaheadRMSprop(
                params=[
                    {
                        "params": self.gate.g_hat.parameters(),
                        "lr": self.hparams.learning_rate,
                    },
                    {
                        "params": self.gate.placeholder.parameters()
                        if isinstance(self.gate.placeholder, torch.nn.ParameterList)
                        else [self.gate.placeholder],
                        "lr": self.hparams.learning_rate_placeholder,
                    },
                ],
                centered=True,
            ),
            LookaheadRMSprop(
                params=[self.alpha]
                if isinstance(self.alpha, torch.Tensor)
                else self.alpha.parameters(),
                lr=self.hparams.learning_rate_alpha,
            ),
        ]

        schedulers = [
            {
                "scheduler": get_constant_schedule_with_warmup(optimizers[0], 12 * 100),
                "interval": "step",
            },
            get_constant_schedule(optimizers[1]),
        ]
        return optimizers, schedulers
def get_lr_scheduler(optimizer,
                     scheduler_type,
                     lr_warmup=None,
                     num_steps=None):
    if scheduler_type == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(lr_warmup * num_steps), num_steps)
    elif scheduler_type == "constant":
        scheduler = get_constant_schedule(optimizer)
    else:
        raise ValueError("Unknown scheduler_type:", scheduler_type)

    # Initialize step as Poptorch does not call optimizer.step() explicitly
    optimizer._step_count = 1

    return scheduler
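A brief usage sketch for get_lr_scheduler above; note that lr_warmup is treated as a fraction of num_steps rather than a step count. The optimizer and values here are placeholder assumptions:

scheduler = get_lr_scheduler(optimizer, "linear", lr_warmup=0.1, num_steps=10000)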
def train_model(train_dataloader, val_dataloader, model, EPOCHS, BATCH_SIZE, lr, ACCUMULATION_STEPS):
    ## Optimization
    num_train_optimization_steps = int(EPOCHS * len(train_dataloader) / BATCH_SIZE / ACCUMULATION_STEPS)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, correct_bias=False)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100,
                                            num_training_steps=num_train_optimization_steps)
    scheduler0 = get_constant_schedule(optimizer)
    

    frozen = True
    # Training
    for epoch in (range(EPOCHS+1)):
        print("\n--------Start training on  Epoch %d/%d" %(epoch, EPOCHS))
        avg_loss = 0 
        avg_accuracy = 0

        model.train()
        for i, (input_ids, attention_mask, label_batch) in (enumerate(train_dataloader)):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            label_batch = label_batch.to(device)

            optimizer.zero_grad()
            y_preds = model(input_ids, attention_mask, None)
            loss = torch.nn.functional.binary_cross_entropy(y_preds.to(device),
                                                                label_batch.float().to(device))
            
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            # Step the active LR schedule: constant while layers are frozen, warmup/decay after.
            if frozen:
                scheduler0.step()
            else:
                scheduler.step()

            lossf = loss.item()
            avg_loss += loss.item() / len(train_dataloader)

        print("Loss training:", avg_loss)

        roc = eval(val_dataloader, model, device)

    return model
Example #10
def get_lr_scheduler(optimizer,
                     scheduler_type,
                     warmup_steps=None,
                     num_steps=None,
                     last_epoch=-1):
    if scheduler_type == "linear":
        scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps,
                                                    num_steps)
    elif scheduler_type == "constant":
        scheduler = get_constant_schedule(optimizer)
    elif scheduler_type == "cosine":
        scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                    warmup_steps,
                                                    num_steps,
                                                    last_epoch=last_epoch)
    else:
        raise ValueError("Unknown scheduler_type:", scheduler_type)
    return scheduler
Example #11
    def configure_optimizers(self):
        optimizers = [
            LookaheadRMSprop(
                params=list(self.gate.parameters()) + [self.placeholder],
                lr=self.hparams.learning_rate,
                centered=True,
            ),
            LookaheadRMSprop(
                params=[self.alpha],
                lr=self.hparams.learning_rate_alpha,
            ),
        ]

        schedulers = [
            {
                "scheduler":
                get_constant_schedule_with_warmup(optimizers[0], 200),
                "interval": "step",
            },
            get_constant_schedule(optimizers[1]),
        ]
        return optimizers, schedulers
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default="SST-2",
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=float, default=0.1)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    # intermediate distillation default parameters
    default_params = {
        "cola": {
            "num_train_epochs": 50,
            "max_seq_length": 64
        },
        "mnli": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "mrpc": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        },
        "sst-2": {
            "num_train_epochs": 10,
            "max_seq_length": 64
        },
        "sts-b": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        },
        "qqp": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "qnli": {
            "num_train_epochs": 10,
            "max_seq_length": 128
        },
        "rte": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        }
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if not args.pred_distill and not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name][
                "num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)
    student_config = BertConfig.from_pretrained(args.student_model,
                                                num_labels=num_labels,
                                                finetuning_task=args.task_name)

    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        train_data, _ = get_tensor_data(args, task_name, tokenizer, False,
                                        args.aug_train)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        num_train_optimization_steps = int(
            len(train_dataloader) /
            args.gradient_accumulation_steps) * args.num_train_epochs

    eval_data, eval_labels = get_tensor_data(args, task_name, tokenizer, True,
                                             False)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    if not args.do_eval:
        teacher_config = BertConfig.from_pretrained(
            args.teacher_model,
            num_labels=num_labels,
            finetuning_task=args.task_name)
        teacher_model = TinyBertForSequenceClassification.from_pretrained(
            args.teacher_model, config=teacher_config)
        teacher_model.to(device)

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, config=student_config)
    student_model.to(device)
    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_data))
        logger.info("  Batch size = %d", args.eval_batch_size)

        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        if n_gpu > 1:
            student_model = torch.nn.DataParallel(student_model)
            teacher_model = torch.nn.DataParallel(teacher_model)
        # Prepare optimizer
        param_optimizer = list(student_model.named_parameters())
        size = 0
        for n, p in student_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(num_train_optimization_steps *
                                 args.warmup_proportion),
            num_training_steps=num_train_optimization_steps)
        if not args.pred_distill:
            scheduler = get_constant_schedule(optimizer)

        # Prepare loss functions
        loss_mse = MSELoss()

        def soft_cross_entropy(predicts, targets):
            student_likelihood = torch.nn.functional.log_softmax(predicts,
                                                                 dim=-1)
            targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            return (-targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        best_dev_acc = 0.0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            student_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.

                student_logits, student_atts, student_reps = student_model(
                    input_ids, segment_ids, input_mask, is_student=True)

                with torch.no_grad():
                    teacher_logits, teacher_atts, teacher_reps = teacher_model(
                        input_ids, segment_ids, input_mask)

                if not args.pred_distill:
                    teacher_layer_num = len(teacher_atts)
                    student_layer_num = len(student_atts)
                    # print("teacher_layer_num:",teacher_layer_num)
                    # print("student_layer_num:",student_layer_num)
                    # print("teacher_reps num:",len(teacher_reps))

                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = int(teacher_layer_num /
                                           student_layer_num)
                    new_teacher_atts = [
                        teacher_atts[i * layers_per_block + layers_per_block -
                                     1] for i in range(student_layer_num)
                    ]

                    for student_att, teacher_att in zip(
                            student_atts, new_teacher_atts):
                        student_att = torch.where(
                            student_att <= -1e2,
                            torch.zeros_like(student_att).to(device),
                            student_att)
                        teacher_att = torch.where(
                            teacher_att <= -1e2,
                            torch.zeros_like(teacher_att).to(device),
                            teacher_att)

                        tmp_loss = loss_mse(student_att, teacher_att)
                        att_loss += tmp_loss

                    new_teacher_reps = [
                        teacher_reps[i * layers_per_block]
                        for i in range(student_layer_num + 1)
                    ]
                    new_student_reps = student_reps
                    for student_rep, teacher_rep in zip(
                            new_student_reps, new_teacher_reps):
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                        rep_loss += tmp_loss

                    loss = rep_loss + att_loss
                    tr_att_loss += att_loss.item()
                    tr_rep_loss += rep_loss.item()
                else:
                    if output_mode == "classification":
                        cls_loss = soft_cross_entropy(
                            student_logits / args.temperature,
                            teacher_logits / args.temperature)
                    elif output_mode == "regression":
                        loss_mse = MSELoss()
                        cls_loss = loss_mse(student_logits.view(-1),
                                            label_ids.view(-1))

                    loss = cls_loss
                    tr_cls_loss += cls_loss.item()

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step + 1) % int(
                        args.eval_step * num_train_optimization_steps) == 0:
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(
                        epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_data))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    student_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)
                    att_loss = tr_att_loss / (step + 1)
                    rep_loss = tr_rep_loss / (step + 1)

                    result = {}
                    if args.pred_distill:
                        result = do_eval(student_model, task_name,
                                         eval_dataloader, device, output_mode,
                                         eval_labels, num_labels)
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['att_loss'] = att_loss
                    result['rep_loss'] = rep_loss
                    result['loss'] = loss

                    result_to_file(result, output_eval_file)

                    if not args.pred_distill:
                        save_model = True
                    else:
                        save_model = False

                        if task_name in acc_tasks and result[
                                'acc'] > best_dev_acc:
                            best_dev_acc = result['acc']
                            save_model = True

                        if task_name in corr_tasks and result[
                                'corr'] > best_dev_acc:
                            best_dev_acc = result['corr']
                            save_model = True

                        if task_name in mcc_tasks and result[
                                'mcc'] > best_dev_acc:
                            best_dev_acc = result['mcc']
                            save_model = True

                    if save_model:
                        logger.info("***** Save model *****")

                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model

                        model_name = "pytorch_model.bin"
                        # if not args.pred_distill:
                        #     model_name = "step_{}_{}".format(global_step, "pytorch_model.bin")
                        output_model_file = os.path.join(
                            args.output_dir, model_name)
                        output_config_file = os.path.join(
                            args.output_dir, "config.json")

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        # Test mnli-mm
                        if args.pred_distill and task_name == "mnli":
                            task_name = "mnli-mm"
                            if not os.path.exists(args.output_dir + '-MM'):
                                os.makedirs(args.output_dir + '-MM')

                            eval_data, eval_labels = get_tensor_data(
                                args, task_name, tokenizer, True, False)

                            eval_sampler = SequentialSampler(eval_data)
                            eval_dataloader = DataLoader(
                                eval_data,
                                sampler=eval_sampler,
                                batch_size=args.eval_batch_size)
                            logger.info("***** Running mm evaluation *****")
                            logger.info("  Num examples = %d", len(eval_data))
                            logger.info("  Batch size = %d",
                                        args.eval_batch_size)

                            result = do_eval(student_model, task_name,
                                             eval_dataloader, device,
                                             output_mode, eval_labels,
                                             num_labels)

                            result['global_step'] = global_step

                            tmp_output_eval_file = os.path.join(
                                args.output_dir + '-MM', "eval_results.txt")
                            result_to_file(result, tmp_output_eval_file)

                            task_name = 'mnli'

                    student_model.train()
Example #13
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    if 'lm' in args.ent_emb:
        print('Using contextualized embeddings for concepts')
        use_contextualized, cp_emb = True, None
    else:
        use_contextualized = False
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)

    rel_emb = np.load(args.rel_emb_path)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = cal_2hop_rel_emb(rel_emb)
    rel_emb = torch.tensor(rel_emb)
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    # print('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num))

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMRelationNetDataLoader(args.train_statements, args.train_rel_paths,
                                      args.dev_statements, args.dev_rel_paths,
                                      args.test_statements, args.test_rel_paths,
                                      batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, device=device,
                                      model_name=args.encoder,
                                      max_tuple_num=args.max_tuple_num, max_seq_length=args.max_seq_len,
                                      is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                      use_contextualized=use_contextualized,
                                      train_adj_path=args.train_adj, dev_adj_path=args.dev_adj, test_adj_path=args.test_adj,
                                      train_node_features_path=args.train_node_features, dev_node_features_path=args.dev_node_features,
                                      test_node_features_path=args.test_node_features, node_feature_type=args.node_feature_type,
                                      format=args.format)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMRelationNet(model_name=args.encoder, concept_num=concept_num, concept_dim=relation_dim,
                          relation_num=relation_num, relation_dim=relation_dim,
                          concept_in_dim=(dataset.get_node_feature_dim() if use_contextualized else concept_dim),
                          hidden_size=args.mlp_dim, num_hidden_layers=args.mlp_layer_num, num_attention_heads=args.att_head_num,
                          fc_size=args.fc_dim, num_fc_layers=args.fc_layer_num, dropout=args.dropoutm,
                          pretrained_concept_emb=cp_emb, pretrained_relation_emb=rel_emb, freeze_ent_emb=args.freeze_ent_emb,
                          init_range=args.init_range, ablation=args.ablation, use_contextualized=use_contextualized,
                          emb_scale=args.emb_scale, encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        rel_grad = []
        linear_grad = []
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                print('encoder unfreezed')
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                print('encoder refreezed')
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                rel_grad.append(model.decoder.rel_emb.weight.grad.abs().mean().item())
                linear_grad.append(model.decoder.mlp.layers[8].weight.grad.abs().mean().item())
                optimizer.step()
                scheduler.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                    print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
                    # print('| rel_grad: {:1.2e} | linear_grad: {:1.2e} |'.format(sum(rel_grad) / len(rel_grad), sum(linear_grad) / len(linear_grad)))
                    total_loss = 0
                    rel_grad = []
                    linear_grad = []
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                torch.save([model, args], model_path)
                print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc, best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
Example #14
def train(args, train_dataset, model, model_config, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        # setup tb writer
        logger.info("Saving tensorboard logs to %s", args.tb_output_dir)
        tb_writer = SummaryWriter(log_dir=args.tb_output_dir, flush_secs=30)

        # Write config files to tensorboard
        tb_writer.add_text('encoder_config', str(model_config))

        # create train log file
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        output_train_file = os.path.join(args.output_dir, "train_results.txt")

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    #scheduler = get_linear_schedule_with_warmup(
    #    optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    #)
    scheduler = get_constant_schedule(optimizer)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    # Log once before training starts
    # Only evaluate when single GPU and not torch.distributed otherwise metrics may not average well
    if args.local_rank in [-1, 0]:
        if args.evaluate_during_training:
            results = evaluate(args, model, tokenizer, dev_set=True)
            for key, value in results.items():
                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                # log to wandb
                # wandb.log({f'eval_{key}': value}, step=0)

        # write to tensorboard
        tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)

    # Enter training loop
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1

                # log gradients before clipping them
                if args.local_rank in [
                        -1, 0
                ] and args.train_logging_steps > 0 and global_step % args.train_logging_steps == 0:
                    for name, param in model.named_parameters():
                        # tb_writer.add_histogram(name, param, global_step)
                        if param.grad is not None:
                            grads = param.grad.view(-1)
                            grads_norm = torch.norm(grads, p=2, dim=0)
                            tb_writer.add_scalar(name + '_grad_norm',
                                                 grads_norm, global_step)
                        # else:
                        # For XLM transformer.lang_embeddings.weight grads are disabled
                        # print('Gradients are disabled for:', name)

                if args.fp16:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                # log stuff
                if args.local_rank in [
                        -1, 0
                ] and args.train_logging_steps > 0 and global_step % args.train_logging_steps == 0:
                    # log weights and gradients after clipping
                    for name, param in model.named_parameters():
                        # Compute l2 norm of the gradients
                        if param.grad is not None:
                            # tb_writer.add_histogram(name, param, global_step)
                            # tb_writer.add_histogram(
                            #     name + '_grad', param.grad, global_step)

                            grads = param.grad.view(-1)
                            grads_norm = torch.norm(grads, p=2, dim=0)
                            tb_writer.add_scalar(name + '_clipped_grad_norm',
                                                 grads_norm, global_step)

                    tb_writer.add_scalar('total_grad_norm', total_norm,
                                         global_step)

                    # log learning rate
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)

                    # log training loss
                    loss = (tr_loss - logging_loss) / args.train_logging_steps
                    tb_writer.add_scalar('train_loss', loss, global_step)
                    logging_loss = tr_loss

                    # log to wandb
                    wandb.log(
                        {
                            'loss': loss,
                            'total_grad_norm': total_norm,
                            'lr': scheduler.get_lr()[0]
                        },
                        step=global_step)

                    # write to logfile
                    with open(output_train_file, "a") as writer:
                        writer.write(f"{global_step}: train_loss = {loss}\n")

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value
                            # log to wandb
                            wandb.log({eval_key: value}, step=global_step)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    # log to wandb
                    wandb.log(
                        {
                            'loss': loss,
                            'total_grad_norm': total_norm,
                            'lr': scheduler.get_lr()[0]
                        },
                        step=global_step)

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
def train(args, train_dataset, model, tokenizer, orgin_dict):
    """ Train the model """
    record_result = []
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    # )

    scheduler = get_constant_schedule(optimizer)
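    # get_constant_schedule keeps the learning rate fixed at args.learning_rate for the
    # whole run, so no warmup/total-step bookkeeping is needed here (unlike the linear
    # warmup schedule commented out above).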

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    return_flag = False
    print('starting pruning')
    pruning_model(model, args.sparsity)
    rate_weight_equal_zero = see_weight_rate(model)
    print('zero_rate = ', rate_weight_equal_zero)

    print('starting rewinding')
    model_dict = model.state_dict()
    model_dict.update(orgin_dict)
    model.load_state_dict(model_dict)
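    # Rewind the surviving weights to the values stored in orgin_dict before
    # continuing training (lottery-ticket style prune-then-rewind).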

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:

                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        rate_weight_equal_zero = see_weight_rate(model)
                        print('zero_rate = ', rate_weight_equal_zero)

                        results = evaluate(args, model, tokenizer)
                        # return_flag = True
                        record_result.append(results)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(model, os.path.join(output_dir, "model.pt"))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if return_flag:
                epoch_iterator.close()
                break

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if return_flag:
            train_iterator.close()
            break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    torch.save(record_result, os.path.join(args.output_dir, "record_result"))

    return global_step, tr_loss / global_step
Exemple #16
def train(args):
    logging.info(f'{socket.gethostname()}: {os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "unknown"}')
    logging.info('python ' + ' '.join(sys.argv))
    logging.info(args)

    model_path = os.path.join(args.save_dir, args.save_file_name)
    check_path(model_path)

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)

    rel_emb = np.load(args.rel_emb_path)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = torch.tensor(rel_emb)
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    logging.info('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num))

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    lm_data_loader = LMDataLoader(args.train_jsonl, args.dev_jsonl, args.test_jsonl,
                                  batch_size=args.mini_batch_size, eval_batch_size=args.eval_batch_size, device=device,
                                  model_name=args.encoder, max_seq_length=args.max_seq_len, is_inhouse=args.inhouse,
                                  inhouse_train_qids_path=args.inhouse_train_qids, subset_qids_path=args.subset_train_qids,
                                  format=args.format)
    logging.info(f'| # train questions: {lm_data_loader.train_size()} | # dev questions: {lm_data_loader.dev_size()} | # test questions: {lm_data_loader.test_size()} |')

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################
    graph_data_loader = GraphDataLoader(args.train_adj_pk, args.train_gen_pt, args.dev_adj_pk, args.dev_gen_pt,
                                        args.test_adj_pk, args.test_gen_pt,
                                        args.mini_batch_size, args.eval_batch_size, args.num_choice, args.ablation)
    train_avg_node_num, train_avg_edge_num = graph_data_loader.get_pyg_loader(lm_data_loader.get_train_indexes(), stats_only=True)

    dev_lm_data_loader = lm_data_loader.dev()
    dev_graph_loader, dev_avg_node_num, dev_avg_edge_num = graph_data_loader.dev_graph_data()
    assert len(dev_graph_loader) == len(dev_lm_data_loader)

    if args.inhouse:
        test_index = lm_data_loader.get_test_indexes()
        test_graph_loader, test_avg_node_num, test_avg_edge_num = graph_data_loader.get_pyg_loader(test_index)
    else:
        test_index = None
        test_graph_loader, test_avg_node_num, test_avg_edge_num = graph_data_loader.test_graph_data()
    test_lm_data_loader = lm_data_loader.test(test_index)
    assert len(test_graph_loader) == len(test_lm_data_loader)

    logging.info(f'| train | avg node num: {train_avg_node_num:.2f} | avg edge num: {train_avg_edge_num:.2f} |')
    logging.info(f'| dev   | avg node num: {dev_avg_node_num:.2f} | avg edge num: {dev_avg_edge_num:.2f} |')
    logging.info(f'| test  | avg node num: {test_avg_node_num:.2f} | avg edge num: {test_avg_edge_num:.2f} |')

    model = LMGraphNet(model_name=args.encoder, encoder_pooler=args.encoder_pooler,
                       concept_num=concept_num, concept_dim=relation_dim,
                       relation_num=relation_num, relation_dim=relation_dim, concept_in_dim=concept_dim,
                       hidden_size=args.mlp_dim, num_attention_heads=args.att_head_num,
                       fc_size=args.fc_dim, num_fc_layers=args.fc_layer_num, dropout=args.dropoutm,
                       edge_weight_dropout=args.edge_weight_dropout,
                       pretrained_concept_emb=cp_emb,  pretrained_relation_emb=rel_emb,
                       freeze_ent_emb=args.freeze_ent_emb, num_layers=args.num_gnn_layers,
                       ablation=args.ablation, emb_scale=args.emb_scale,
                       aristo_path=args.aristo_path)

    model.to(device)

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (lm_data_loader.train_size() / args.batch_size))
        if args.warmup_ratio is not None:
            warmup_steps = int(args.warmup_ratio * max_steps)
        else:
            warmup_steps = args.warmup_steps
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
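    # Note: if args.lr_schedule is none of 'fixed', 'warmup_constant', or 'warmup_linear',
    # no scheduler is created and the scheduler.step() call in the training loop below fails.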

    logging.info('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            logging.info('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            logging.info('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    logging.info(f'\ttotal: {num_params}')

    loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    logging.info('')
    logging.info('-' * 71)
    global_step, eval_id, best_dev_id, best_dev_step = 0, 0, 0, 0
    best_dev_acc, final_test_acc, best_test_acc, total_loss = 0.0, 0.0, 0.0, 0.0
    best_test_acc = 0.0
    exit_training = False
    train_start_time = time.time()
    start_time = train_start_time
    model.train()
    freeze_net(model.encoder)
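    # Keep the encoder frozen at first so only the graph decoder trains; the encoder is
    # unfrozen at args.unfreeze_epoch (and optionally frozen again at args.refreeze_epoch).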
    try:
        binary_score_lst = []
        for epoch_id in range(args.n_epochs):
            if exit_training:
                break
            if epoch_id == args.unfreeze_epoch:
                logging.info('encoder unfrozen')
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                logging.info('encoder refrozen')
                freeze_net(model.encoder)
            model.train()
            i = 0
            optimizer.zero_grad()
            train_index = lm_data_loader.get_train_indexes()
            train_graph_loader, train_avg_node_num, train_avg_edge_num = graph_data_loader.get_pyg_loader(train_index)
            train_lm_data_loader = lm_data_loader.train(train_index)
            assert len(train_graph_loader) == len(train_lm_data_loader)
            for graph, (qids, labels, *lm_input_data) in zip(train_graph_loader, train_lm_data_loader):
                graph = graph.to(device)
                edge_index = graph.edge_index
                row, col = edge_index
                node_batch = graph.batch
                num_of_nodes = graph.num_of_nodes
                num_of_edges = graph.num_of_edges
                rel_ids_embs = graph.edge_attr
                c_ids = graph.x
                c_types = graph.node_type
                logits, unnormalized_wts, normalized_wts = model(*lm_input_data, edge_index=edge_index, c_ids=c_ids, c_types=c_types, node_batch=node_batch, rel_ids_embs=rel_ids_embs, num_of_nodes=num_of_nodes, num_of_edges=num_of_edges)
                loss = loss_func(logits, labels)  # scale: loss per question
                if 'no_edge_weight' not in args.ablation and 'GAT' not in args.ablation:  # add options for other kinds of sparsity
                    log_wts = torch.log(normalized_wts + 0.0000001)
                    entropy = - normalized_wts * log_wts  # entropy: [num_of_edges in the batched graph, 1]
                    entropy = scatter_mean(entropy, node_batch[row], dim=0, dim_size=args.mini_batch_size * args.num_choice)
                    loss += args.alpha * torch.mean(entropy)  # scale: entropy per graph (each question has num_choice graphs)
                loss = loss * args.mini_batch_size / args.batch_size  # will be accumulated for (args.batch_size / args.mini_batch_size) times
                loss.backward()
                total_loss += loss.item()
                if 'no_edge_weight' not in args.ablation and 'GAT' not in args.ablation:
                    binary_score_lst += entropy.squeeze().tolist()
                else:
                    binary_score_lst.append(0)
                i = i + args.mini_batch_size
                if i % args.batch_size == 0:
                    if args.max_grad_norm > 0:
                        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()  # bp: scale: loss per question
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    if global_step % args.log_interval == 0:
                        total_loss /= args.log_interval
                        ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                        logging.info('| step {:5} | lr: {:9.7f} | loss {:7.20f} | entropy score {:7.4f} | ms/batch {:7.2f} |'
                                     .format(global_step, scheduler.get_lr()[0], total_loss, np.mean(binary_score_lst), ms_per_batch))
                        total_loss = 0
                        binary_score_lst = []
                        start_time = time.time()
                    if args.eval_interval > 0:
                        if global_step % args.eval_interval == 0:
                            eval_id += 1
                            model.eval()
                            dev_acc = evaluate_accuracy(dev_graph_loader, dev_lm_data_loader, model, device)
                            test_acc = evaluate_accuracy(test_graph_loader, test_lm_data_loader, model, device)
                            # test_acc = 0.2
                            best_test_acc = max(best_test_acc, test_acc)
                            logging.info('-' * 71)
                            logging.info('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(global_step, dev_acc, test_acc))
                            logging.info('-' * 71)
                            if dev_acc >= best_dev_acc:
                                best_dev_acc = dev_acc
                                final_test_acc = test_acc
                                best_dev_id = eval_id
                                best_dev_step = global_step
                                if args.save_model:
                                    torch.save(model.state_dict(), model_path)
                                    copyfile(model_path, f'{model_path}_{global_step}_{dev_acc*100:.2f}_{test_acc*100:.2f}.pt')  # tmp
                                logging.info(f'model saved to {model_path}')
                            else:
                                logging.info(f'hit patience {eval_id - best_dev_id}/{args.patience}')
                            model.train()
                            if epoch_id > args.unfreeze_epoch and eval_id - best_dev_id >= args.patience:
                                exit_training = True
                                break
            if args.eval_interval == 0:
                eval_id += 1
                model.eval()
                dev_acc = evaluate_accuracy(dev_graph_loader, dev_lm_data_loader, model, device)
                test_acc = evaluate_accuracy(test_graph_loader, test_lm_data_loader, model, device)
                best_test_acc = max(best_test_acc, test_acc)
                logging.info('-' * 71)
                logging.info('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, dev_acc, test_acc))
                logging.info('-' * 71)
                if dev_acc >= best_dev_acc:
                    best_dev_acc = dev_acc
                    final_test_acc = test_acc
                    best_dev_id = eval_id
                    best_dev_step = global_step
                    if args.save_model:
                        torch.save(model.state_dict(), model_path)
                    logging.info(f'model saved to {model_path}')
                else:
                    logging.info(f'hit patience {eval_id - best_dev_id}/{args.patience}')
                model.train()
                if epoch_id > args.unfreeze_epoch and eval_id - best_dev_id >= args.patience:
                    exit_training = True
                    break
            start_time = time.time()
    except KeyboardInterrupt:
        logging.info('-' * 89)
        logging.info('Exiting from training early')
    train_end_time = time.time()
    logging.info('')
    logging.info(f'training ends in {global_step} steps, {train_end_time - train_start_time:.0f} s')
    logging.info('best dev acc: {:.4f} (at step {})'.format(best_dev_acc, best_dev_step))
    logging.info('final test acc: {:.4f}'.format(final_test_acc))
    if args.use_last_epoch:
        logging.info(f'last dev acc: {dev_acc:.4f}')
        logging.info(f'last test acc: {test_acc:.4f}')
        return dev_acc, test_acc, best_test_acc
    else:
        return best_dev_acc, final_test_acc, best_test_acc
Exemple #17
    def train(self, train_path: str, valid_path: str, types_path: str,
              input_reader_cls: BaseInputReader):
        args = self.args
        train_label, valid_label = 'train', 'valid'

        self._logger.info("Datasets: %s, %s" % (train_path, valid_path))
        self._logger.info("Model type: %s" % args.model_type)

        # create log csv files
        self._init_train_logging(train_label)
        self._init_eval_logging(valid_label)

        # read datasets
        input_reader = input_reader_cls(types_path, args.bio_path,
                                        self._tokenizer, self._logger)
        input_reader.read({train_label: train_path, valid_label: valid_path})
        self._log_datasets(input_reader)

        train_dataset = input_reader.get_dataset(train_label)

        train_sample_count = train_dataset.document_count
        updates_epoch = train_sample_count // args.train_batch_size
        updates_total = updates_epoch * args.epochs

        steps_before_rel = int(updates_total * self.args.before_rel)

        validation_dataset = input_reader.get_dataset(valid_label)

        self._logger.info("Updates per epoch: %s" % updates_epoch)
        self._logger.info("Updates total: %s" % updates_total)
        self._logger.info("Updates before relation: %s" % steps_before_rel)

        # create model
        model_class = models.get_model(self.args.model_type)

        # load model
        if args.model_type == 'table_filling':
            model = model_class.from_pretrained(
                self.args.model_path,
                cache_dir=self.args.cache_path,
                tokenizer=self._tokenizer,
                # table_filling model parameters
                relation_labels=input_reader.relation_label_count,
                entity_labels=input_reader.entity_label_count,
                att_hidden=self.args.att_hidden,
                prop_drop=self.args.prop_drop,
                entity_label_embedding=self.args.entity_label_embedding,
                freeze_transformer=self.args.freeze_transformer,
                device=self._device)

#         if self._device.type != 'cpu':
#             torch.distributed.init_process_group(backend='nccl', world_size=3, init_method='...')
#             model = torch.nn.parallel.DistributedDataParallel(model)

        model.to(self._device)
        #         model.to(f'cuda:{model.device_ids[0]}')

        # create optimizer

        optimizer_params = self._get_optimizer_params(model)
        optimizer = AdamW(optimizer_params,
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          correct_bias=False)

        #         other_optimizer_params = self._get_optimizer_params([])
        # create scheduler

        if args.scheduler == 'constant':
            scheduler = transformers.get_constant_schedule(optimizer)
        elif args.scheduler == 'constant_warmup':
            scheduler = transformers.get_constant_schedule_with_warmup(
                optimizer, num_warmup_steps=args.lr_warmup * updates_total)
        elif args.scheduler == 'linear_warmup':
            scheduler = transformers.get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.lr_warmup * updates_total,
                num_training_steps=updates_total)
        elif args.scheduler == 'cosine_warmup':
            scheduler = transformers.get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.lr_warmup * updates_total,
                num_training_steps=updates_total)
        elif args.scheduler == 'cosine_warmup_restart':
            scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.lr_warmup * updates_total,
                num_training_steps=updates_total,
                num_cycles=args.num_cycles)
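        # args.lr_warmup is treated as a fraction of updates_total, so the warmup length
        # scales with dataset size and the number of epochs.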

        # create loss function
        rel_criterion = torch.nn.CrossEntropyLoss(reduction='none')
        entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')

        if args.model_type == 'table_filling':
            compute_loss = TableLoss(rel_criterion, entity_criterion, model,
                                     optimizer, scheduler, args.max_grad_norm)

        # eval validation set
        if args.init_eval:
            self._eval(model, compute_loss, validation_dataset, input_reader,
                       0, updates_epoch)

        # train
        for epoch in range(args.epochs):
            # train epoch
            self._train_epoch(model, compute_loss, optimizer, train_dataset,
                              updates_epoch, epoch, input_reader.context_size,
                              input_reader.entity_label_count,
                              input_reader.relation_label_count,
                              input_reader._start_entity_label,
                              steps_before_rel)

            # eval validation sets
            if not args.final_eval or (epoch == args.epochs - 1):
                ner_acc, rel_acc, rel_ner_acc = self._eval(
                    model, compute_loss, validation_dataset, input_reader,
                    epoch, updates_epoch)
                if args.save_best:
                    extra = dict(epoch=epoch,
                                 updates_epoch=updates_epoch,
                                 epoch_iteration=0)
                    self._save_best(model=model,
                                    optimizer=optimizer
                                    if self.args.save_optimizer else None,
                                    accuracy=ner_acc[2],
                                    iteration=epoch * updates_epoch,
                                    label='ner_micro_f1',
                                    extra=extra)

        # save final model
        extra = dict(epoch=args.epochs,
                     updates_epoch=updates_epoch,
                     epoch_iteration=0)
        global_iteration = args.epochs * updates_epoch
        self._save_model(
            self._save_path,
            model,
            global_iteration,
            optimizer=optimizer if self.args.save_optimizer else None,
            extra=extra,
            include_iteration=False,
            name='final_model')

        self._logger.info("Logged in: %s" % self._log_path)
        self._logger.info("Saved in: %s" % self._save_path)
Exemple #18
def train(args):
    util.ensure_dir(args["save_dir"])
    model_file = args["save_dir"] + "/" + "phonlp.pt"

    tokenizer = AutoTokenizer.from_pretrained(args["pretrained_lm"],
                                              use_fast=False)
    config_phobert = AutoConfig.from_pretrained(args["pretrained_lm"],
                                                output_hidden_states=True)

    print("Loading data with batch size {}...".format(args["batch_size"]))
    train_doc_dep = Document(
        CoNLL.conll2dict(input_file=args["train_file_dep"]))
    vocab = BuildVocab(args, args["train_file_pos"], train_doc_dep,
                       args["train_file_ner"]).vocab

    train_batch_pos = DataLoaderPOS(
        args["train_file_pos"],
        args["batch_size"],
        args,
        vocab=vocab,
        evaluation=False,
        tokenizer=tokenizer,
        max_seq_length=args["max_sequence_length"],
    )
    train_batch_dep = DataLoaderDep(
        train_doc_dep,
        args["batch_size"],
        args,
        vocab=vocab,
        evaluation=False,
        tokenizer=tokenizer,
        max_seq_length=args["max_sequence_length"],
    )
    train_batch_ner = DataLoaderNER(
        args["train_file_ner"],
        args["batch_size"],
        args,
        vocab=vocab,
        evaluation=False,
        tokenizer=tokenizer,
        max_seq_length=args["max_sequence_length"],
    )

    dev_doc_dep = Document(CoNLL.conll2dict(input_file=args["eval_file_dep"]))

    dev_batch_pos = DataLoaderPOS(
        args["eval_file_pos"],
        args["batch_size"],
        args,
        vocab=vocab,
        sort_during_eval=True,
        evaluation=True,
        tokenizer=tokenizer,
        max_seq_length=args["max_sequence_length"],
    )
    dev_batch_dep = DataLoaderDep(
        dev_doc_dep,
        args["batch_size"],
        args,
        vocab=vocab,
        sort_during_eval=True,
        evaluation=True,
        tokenizer=tokenizer,
        max_seq_length=args["max_sequence_length"],
    )
    dev_batch_ner = DataLoaderNER(
        args["eval_file_ner"],
        args["batch_size"],
        args,
        vocab=vocab,
        evaluation=True,
        tokenizer=tokenizer,
        max_seq_length=args["max_sequence_length"],
    )

    # pred and gold path
    system_pred_file = args["output_file_dep"]
    gold_file = args["eval_file_dep"]

    # ##POS

    dev_gold_tags = dev_batch_ner.tags

    # skip training if the language does not have training or dev data
    if len(train_batch_pos) == 0 or len(dev_batch_pos) == 0:
        print("Skip training because no data available...")
        sys.exit(0)

    print("Training jointmodel...")
    trainer = JointTrainer(args, vocab, None, config_phobert,
                           args["cuda"])  # ###
    tsfm = trainer.model.phobert
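    # Make sure every PhoBERT parameter is trainable before joint training starts.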
    for child in tsfm.children():
        for param in child.parameters():
            if not param.requires_grad:
                print("whoopsies")
            param.requires_grad = True

    global_step = 0
    las_score_history = 0
    uas_score_history = 0
    upos_score_history = 0
    f1_score_history = 0
    ####

    # start training
    train_loss = 0
    train_loss_pos = 0
    train_loss_dep = 0
    train_loss_ner = 0

    # Creating optimizer and lr schedulers
    param_optimizer = list(trainer.model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0
        },
    ]
    num_train_optimization_steps = int(
        args["num_epoch"] * len(train_batch_pos) / args["accumulation_steps"])
    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args["lr"], correct_bias=False
    )  # To reproduce BertAdam specific behavior set correct_bias=False
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=5,
        num_training_steps=num_train_optimization_steps)
    get_constant_schedule(optimizer)
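    # Note: this constant schedule is created but never assigned, so the linear warmup
    # scheduler built above is the one actually stepped during training.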
    for epoch in range(args["num_epoch"]):
        ####
        optimizer.zero_grad()
        print(" EPOCH  : ", epoch)
        step = 0
        lambda_pos = args["lambda_pos"]
        lambda_ner = args["lambda_ner"]
        lambda_dep = args["lambda_dep"]

        epoch_size = max(
            [len(train_batch_pos),
             len(train_batch_dep),
             len(train_batch_ner)])
        for i in tqdm(range(epoch_size)):
            step += 1
            global_step += 1
            batch_pos = train_batch_pos[i]
            batch_dep = train_batch_dep[i]
            batch_ner = train_batch_ner[i]
            ###
            loss, loss_pos, loss_ner = trainer.update(
                batch_dep,
                batch_pos,
                batch_ner,
                lambda_pos=lambda_pos,
                lambda_dep=lambda_dep,
                lambda_ner=lambda_ner)  # update step
            train_loss += loss
            train_loss_pos += loss_pos
            # train_loss_dep += loss_dep
            train_loss_ner += loss_ner
            ###

            if i % args["accumulation_steps"] == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            if epoch_size == len(train_batch_pos):
                if step % len(train_batch_dep) == 0:
                    train_batch_dep.reshuffle()
                if step % len(train_batch_ner) == 0:
                    train_batch_ner.reshuffle()
            elif epoch_size == len(train_batch_ner):
                if step % len(train_batch_dep) == 0:
                    train_batch_dep.reshuffle()
                if step % len(train_batch_pos) == 0:
                    train_batch_pos.reshuffle()
            elif epoch_size == len(train_batch_dep):
                if step % len(train_batch_pos) == 0:
                    train_batch_pos.reshuffle()
                if step % len(train_batch_ner) == 0:
                    train_batch_ner.reshuffle()
            if step % args["eval_interval"] == 0:
                print("Evaluating on dev set...")
                dev_preds_dep = []
                dev_preds_upos = []
                dev_preds_ner = []
                for batch in dev_batch_dep:
                    preds_dep = trainer.predict_dep(batch)
                    dev_preds_dep += preds_dep
                ###
                dev_preds_dep = util.unsort(dev_preds_dep,
                                            dev_batch_dep.data_orig_idx_dep)
                dev_batch_dep.doc_dep.set(
                    [HEAD, DEPREL], [y for x in dev_preds_dep for y in x])
                CoNLL.dict2conll(dev_batch_dep.doc_dep.to_dict(),
                                 system_pred_file)
                _, _, las_dev, uas_dev = score_dep.score(
                    system_pred_file, gold_file)

                for batch in dev_batch_pos:
                    preds_pos = trainer.predict_pos(batch)
                    dev_preds_upos += preds_pos
                dev_preds_upos = util.unsort(dev_preds_upos,
                                             dev_batch_pos.data_orig_idx_pos)
                accuracy_pos_dev = score_pos.score_acc(dev_preds_upos,
                                                       dev_batch_pos.upos)

                for batch in dev_batch_ner:
                    preds_ner = trainer.predict_ner(batch)
                    dev_preds_ner += preds_ner
                p, r, f1 = score_ner.score_by_entity(dev_preds_ner,
                                                     dev_gold_tags)
                for i in range(len(dev_batch_ner)):
                    assert len(dev_preds_ner[i]) == len(dev_gold_tags[i])

                print(
                    "step {}: dev_las_score = {:.4f}, dev_uas_score = {:.4f}, dev_pos = {:.4f}, dev_ner_p = {:.4f}, dev_ner_r = {:.4f}, dev_ner_f1 = {:.4f}"
                    .format(global_step, las_dev, uas_dev, accuracy_pos_dev, p,
                            r, f1))

                # save best model
                if las_dev + accuracy_pos_dev + f1 >= (las_score_history +
                                                       upos_score_history +
                                                       f1_score_history):
                    las_score_history = las_dev
                    upos_score_history = accuracy_pos_dev
                    uas_score_history = uas_dev
                    f1_score_history = f1
                    trainer.save(model_file)
                    print("new best model saved.")
                print("")

        print("Evaluating on dev set...")
        dev_preds_dep = []
        dev_preds_upos = []
        dev_preds_ner = []
        for batch in dev_batch_dep:
            preds_dep = trainer.predict_dep(batch)
            dev_preds_dep += preds_dep

        dev_preds_dep = util.unsort(dev_preds_dep,
                                    dev_batch_dep.data_orig_idx_dep)
        dev_batch_dep.doc_dep.set([HEAD, DEPREL],
                                  [y for x in dev_preds_dep for y in x])
        CoNLL.dict2conll(dev_batch_dep.doc_dep.to_dict(), system_pred_file)
        _, _, las_dev, uas_dev = score_dep.score(system_pred_file, gold_file)

        for batch in dev_batch_pos:
            preds_pos = trainer.predict_pos(batch)
            dev_preds_upos += preds_pos
        dev_preds_upos = util.unsort(dev_preds_upos,
                                     dev_batch_pos.data_orig_idx_pos)
        accuracy_pos_dev = score_pos.score_acc(dev_preds_upos,
                                               dev_batch_pos.upos)

        for batch in dev_batch_ner:
            preds_ner = trainer.predict_ner(batch)
            dev_preds_ner += preds_ner
        p, r, f1 = score_ner.score_by_entity(dev_preds_ner, dev_gold_tags)
        for i in range(len(dev_batch_ner)):
            assert len(dev_preds_ner[i]) == len(dev_gold_tags[i])

        train_loss = train_loss / len(train_batch_pos)  # avg loss per batch
        train_loss_dep = train_loss_dep / len(train_batch_pos)
        train_loss_pos = train_loss_pos / len(train_batch_pos)
        train_loss_ner = train_loss_ner / len(train_batch_pos)

        print(
            "step {}: train_loss = {:.6f}, train_loss_dep = {:.6f}, train_loss_pos = {:.6f}, train_loss_ner = {:.6f}, dev_las_score = {:.4f}, dev_uas_score = {:.4f}, dev_pos = {:.4f}, dev_ner_p = {:.4f}, dev_ner_r = {:.4f}, dev_ner_f1 = {:.4f} "
            .format(
                global_step,
                train_loss,
                train_loss_dep,
                train_loss_pos,
                train_loss_ner,
                las_dev,
                uas_dev,
                accuracy_pos_dev,
                p,
                r,
                f1,
            ))

        # save best model
        if las_dev + accuracy_pos_dev + f1 >= (
                las_score_history + upos_score_history + f1_score_history):
            las_score_history = las_dev
            upos_score_history = accuracy_pos_dev
            uas_score_history = uas_dev
            f1_score_history = f1
            trainer.save(model_file)
            print("new best model saved.")
        train_loss = 0
        train_loss_pos = 0
        train_loss_dep = 0
        train_loss_ner = 0

        print("")
        train_batch_dep.reshuffle()
        train_batch_pos.reshuffle()
        train_batch_ner.reshuffle()

    print("Training ended with {} epochs.".format(epoch))

    best_las, uas, upos, f1 = (
        las_score_history * 100,
        uas_score_history * 100,
        upos_score_history * 100,
        f1_score_history * 100,
    )
    print("Best dev las = {:.2f}, uas = {:.2f}, upos = {:.2f}, f1 = {:.2f}".
          format(best_las, uas, upos, f1))
Exemple #19
def train(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    model_path = os.path.join(args.save_dir, 'model.pt')
    check_path(model_path)

    logger = setup_logger(__name__, args.save_dir + "log.txt")
    logger.info(args)

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMDataLoader(args.train_statements,
                           args.dev_statements,
                           args.test_statements,
                           batch_size=args.batch_size,
                           eval_batch_size=args.eval_batch_size,
                           device=device,
                           model_name=args.encoder,
                           max_seq_length=args.max_seq_len,
                           is_inhouse=args.inhouse,
                           inhouse_train_qids_path=args.inhouse_train_qids,
                           subsample=args.subsample,
                           format=args.format)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMForMultipleChoice(args.encoder,
                                from_checkpoint=args.from_checkpoint,
                                encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        logger.info(e)
        logger.info('best dev acc: 0.0 (at epoch 0)')
        logger.info('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'lr':
        args.encoder_lr,
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'lr':
        args.encoder_lr,
        'weight_decay':
        0.0
    }]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps,
            num_training_steps=max_steps)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('***** running training *****')
    logger.info(
        f'| batch_size: {args.batch_size} | num_epochs: {args.n_epochs} | num_train: {dataset.train_size()} |'
        f' num_dev: {dataset.dev_size()} | num_test: {dataset.test_size()}')

    global_step = 0
    best_dev_acc = 0
    best_dev_epoch = 0
    final_test_acc = 0
    try:
        for epoch in range(int(args.n_epochs)):
            model.train()
            tqdm_bar = tqdm(dataset.train(), desc="Training")
            for qids, labels, *input_data in tqdm_bar:
                optimizer.zero_grad()
                batch_loss = 0
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits = model(*[x[a:b] for x in input_data],
                                   layer_id=args.encoder_layer)
                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(
                            labels, num_classes=num_choice).view(
                                -1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[
                            correct_mask == 1].contiguous().view(-1, 1).expand(
                                -1, num_choice - 1).contiguous().view(
                                    -1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[
                            correct_mask ==
                            0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    batch_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                tqdm_bar.desc = "loss: {:.2e}  lr: {:.2e}".format(
                    batch_loss,
                    scheduler.get_lr()[0])
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(
                dataset.test(), model) if dataset.test_size() > 0 else 0.0
            if dev_acc > best_dev_acc:
                final_test_acc = test_acc
                best_dev_acc = dev_acc
                best_dev_epoch = epoch
                torch.save([model, args], model_path)
            logger.info(
                '| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(
                    epoch, dev_acc, test_acc))
            if epoch - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print('***** training ends *****')
    print()
    logger.info('training ends in {} steps'.format(global_step))
    logger.info('best dev acc: {:.4f} (at epoch {})'.format(
        best_dev_acc, best_dev_epoch))
    logger.info('final test acc: {:.4f}'.format(final_test_acc))
    print()
Exemple #20
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,dev_acc,test_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('| num_concepts: {} |'.format(concept_num))

    # try:
    if True:
        if torch.cuda.device_count() >= 2 and args.cuda:
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:1")
        elif torch.cuda.device_count() == 1 and args.cuda:
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:0")
        else:
            device0 = torch.device("cpu")
            device1 = torch.device("cpu")
        dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj,
                                               args.dev_statements, args.dev_adj,
                                               args.test_statements, args.test_adj,
                                               batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                               device=(device0, device1),
                                               model_name=args.encoder,
                                               max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
                                               is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                               subsample=args.subsample, use_cache=args.use_cache)

        ###################################################################################################
        #   Build model                                                                                   #
        ###################################################################################################

        model = LM_QAGNN(args, args.encoder, k=args.k, n_ntype=4, n_etype=args.num_relation, n_concept=concept_num,
                                   concept_dim=args.gnn_dim,
                                   concept_in_dim=concept_dim,
                                   n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
                                   p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
                                   pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
                                   init_range=args.init_range,
                                   encoder_config={})
        model.encoder.to(device0)
        model.decoder.to(device1)


    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
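        # The try/except pairs below keep the script compatible with both the older
        # class-based schedulers (ConstantLRSchedule, WarmupConstantSchedule,
        # WarmupLinearSchedule from pytorch_transformers) and the newer
        # get_*_schedule helpers in transformers.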
        try:
            scheduler = ConstantLRSchedule(optimizer)
        except:
            scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        try:
            scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
        except:
            scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        try:
            scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)
        except:
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)
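    # Note: the try/except blocks above only bridge older transformers releases
    # (ConstantLRSchedule / WarmupConstantSchedule / WarmupLinearSchedule) and the
    # current get_*_schedule functions. A minimal sketch of the same dispatch with
    # only the modern API; the helper name build_lr_scheduler is hypothetical and
    # is not used elsewhere in this code.
    def build_lr_scheduler(optimizer, lr_schedule, warmup_steps=0, max_steps=None):
        if lr_schedule == 'fixed':
            return get_constant_schedule(optimizer)
        if lr_schedule == 'warmup_constant':
            return get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        if lr_schedule == 'warmup_linear':
            return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                   num_training_steps=max_steps)
        raise ValueError('unknown lr_schedule: {}'.format(lr_schedule))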

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device))
        else:
            print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')
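    # Note: for 'margin_rank' the loop below flattens the [batch, num_choice]
    # logits, pairs each correct-choice logit with every wrong-choice logit and
    # feeds the pairs to MarginRankingLoss with target +1 (correct should score
    # higher). A self-contained sketch of that pairing; the helper name is
    # hypothetical and the loop below does not call it.
    def margin_rank_loss(logits, labels, loss_func):
        # logits: [batch, num_choice]; labels: [batch] index of the correct choice
        num_choice = logits.size(1)
        flat_logits = logits.view(-1)
        correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)
        correct_logits = flat_logits[correct_mask == 1].view(-1, 1).expand(-1, num_choice - 1).reshape(-1)
        wrong_logits = flat_logits[correct_mask == 0]
        target = wrong_logits.new_ones(wrong_logits.size(0))
        return loss_func(correct_logits, wrong_logits, target)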

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
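    # Note: freeze_net / unfreeze_net come from the project's utilities and are
    # assumed to simply toggle requires_grad on a module's parameters, so the LM
    # encoder stays frozen until args.unfreeze_epoch. Equivalent sketch (name
    # hypothetical):
    def set_requires_grad(module, flag):
        for p in module.parameters():
            p.requires_grad = flag  # flag=False freezes, flag=True unfreezes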
    freeze_net(model.encoder)
    if True:  # placeholder for a disabled try/except; keeps the indentation of the block below
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice, matching flat_logits
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[correct_mask == 0]
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # step the optimizer before the LR scheduler (PyTorch >= 1.1 ordering)

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                    print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
                    total_loss = 0
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            save_test_preds = args.save_model
            if not save_test_preds:
                test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
            else:
                eval_set = dataset.test()
                total_acc = []
                count = 0
                preds_path = os.path.join(args.save_dir, 'test_e{}_preds.csv'.format(epoch_id))
                with open(preds_path, 'w') as f_preds:
                    with torch.no_grad():
                        for qids, labels, *input_data in tqdm(eval_set):
                            count += 1
                            logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
                            predictions = logits.argmax(1) #[bsize, ]
                            preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
                            for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
                                acc = int(pred.item()==label.item())
                                print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds)
                                f_preds.flush()
                                total_acc.append(acc)
                test_acc = float(sum(total_acc))/len(total_acc)

            print('-' * 71)
            print('| epoch {:3} | step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, global_step, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                if args.save_model:
                    torch.save([model, args], model_path +".{}".format(epoch_id))
                    with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
                        for p in model.named_parameters():
                            print (p, file=f)
                    print(f'model saved to {model_path}')
            else:
                if args.save_model:
                    torch.save([model, args], model_path +".{}".format(epoch_id))
                    with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
                        for p in model.named_parameters():
                            print (p, file=f)
                    print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
Exemple #21
def train(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v)) for k, v in sorted(dict(vars(args)).items())))
    print()

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,dev_acc,test_acc\n')

    dic = {'transe': 0, 'numberbatch': 1}
    cp_emb, rel_emb = [np.load(args.ent_emb_paths[dic[source]]) for source in args.ent_emb], np.load(args.rel_emb_path)
    cp_emb = np.concatenate(cp_emb, axis=1)
    cp_emb = torch.tensor(cp_emb)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
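    # appending -rel_emb presumably adds an embedding for the reverse of each
    # relation, which is why relation_num reported below is twice the file's size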
    rel_emb = torch.tensor(rel_emb)
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim))
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    print('num_relations: {}, relation_dim: {}'.format(relation_num, relation_dim))

    try:

        device0 = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
        device1 = torch.device("cuda:1" if torch.cuda.is_available() and args.cuda else "cpu")
        dataset = KagNetDataLoader(args.train_statements, args.train_paths, args.train_graphs,
                                   args.dev_statements, args.dev_paths, args.dev_graphs,
                                   args.test_statements, args.test_paths, args.test_graphs,
                                   batch_size=args.mini_batch_size, eval_batch_size=args.eval_batch_size, device=(device0, device1),
                                   model_name=args.encoder, max_seq_length=args.max_seq_len, max_path_len=args.max_path_len,
                                   is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, use_cache=args.use_cache, format=args.format)
        print('dataset done')

        ###################################################################################################
        #   Build model                                                                                   #
        ###################################################################################################
        lstm_config = get_lstm_config_from_args(args)

        model = LMKagNet(model_name=args.encoder, concept_dim=concept_dim, relation_dim=relation_dim, concept_num=concept_num,
                         relation_num=relation_num, qas_encoded_dim=args.qas_encoded_dim, pretrained_concept_emb=cp_emb,
                         pretrained_relation_emb=rel_emb, lstm_dim=args.lstm_dim, lstm_layer_num=args.lstm_layer_num, graph_hidden_dim=args.graph_hidden_dim,
                         graph_output_dim=args.graph_output_dim, dropout=args.dropout, bidirect=args.bidirect, num_random_paths=args.num_random_paths,
                         path_attention=args.path_attention, qa_attention=args.qa_attention, encoder_config=lstm_config)
        print('model done')
        if args.freeze_ent_emb:
            freeze_net(model.decoder.concept_emb)
        print('frozen')
        model.encoder.to(device0)
        print('encoder done')
        model.decoder.to(device1)
        print('decoder done')
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    print()
    print('-' * 71)
    global_step, last_best_step = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    # slice the inputs to the current mini-batch, matching labels[a:b] below
                    logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice, matching flat_logits
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # step the optimizer before the LR scheduler (PyTorch >= 1.1 ordering)

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                    print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
                    total_loss = 0
                    start_time = time.time()

                if (global_step + 1) % args.eval_interval == 0:
                    model.eval()
                    dev_acc = evaluate_accuracy(dataset.dev(), model)
                    test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
                    print('-' * 71)
                    print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(global_step, dev_acc, test_acc))
                    print('-' * 71)
                    with open(log_path, 'a') as fout:
                        fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
                    if dev_acc >= best_dev_acc:
                        best_dev_acc = dev_acc
                        final_test_acc = test_acc
                        last_best_step = global_step
                        torch.save([model, args], model_path)
                        print(f'model saved to {model_path}')
                    model.train()
                    start_time = time.time()

                global_step += 1
                # if global_step >= args.max_steps or global_step - last_best_step >= args.max_steps_before_stop:
                #     end_flag = True
                #     break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at step {})'.format(best_dev_acc, last_best_step))
    print('final test acc: {:.4f}'.format(final_test_acc))
    def train(self, eval_set="val", train_set_name="train"):
        model_short_name = self.args.base_model_path_or_name.split('/')[-1]
        load_path = os.path.join(self.args.data_path, f"{model_short_name}-ft-{train_set_name}-data.pt")
        tokenizer = AutoTokenizer.from_pretrained(self.args.base_model_path_or_name)

        if os.path.isfile(load_path) and not self.args.override:
            encoded_train_examples = torch.load(load_path)
        else:
            data_path = os.path.join(self.args.data_path, f"{train_set_name}.json")
            examples = read_jsonl(data_path)
            if self.args.eda_aug:
                self.logger.info('start eda augmentation')
                categories = []
                for ex in examples:
                    cates = []
                    for each in ex["categories"].split(","):
                        cates.append(each.split("-")[-1])
                    categories.extend(cates)

                class_dist = Counter(categories)
                self.logger.info(f'dist of categories before augmentation: {json.dumps(class_dist, indent=2)}')
                examples, class_dist = self.aug_batch_with_eda(examples, class_dist, aug_target=500)
                self.logger.info(f'dist of categories after eda augmentation: {json.dumps(class_dist, indent=2)}')

            self.report_data_stats(examples)
            encoded_train_examples = self.encode_data(tokenizer, examples)
            torch.save(encoded_train_examples, load_path)

        categories = []
        for each in encoded_train_examples["categories"]:
            categories.extend(each.split(","))

        self.logger.info(f"the dist of categories (training): {json.dumps(Counter(categories), indent=2)}")
        self.logger.info(
            f"the dist of priority (training): {json.dumps(Counter(encoded_train_examples['priority']), indent=2)}")
        self.logger.info(
            f"the dist of tweets by events (training): {json.dumps(Counter(encoded_train_examples['events']), indent=2)}")

        eval_dataset = None
        if eval_set is not None and os.path.isfile(os.path.join(self.args.data_path, f"{eval_set}.json")):
            data_path = os.path.join(self.args.data_path, f"{eval_set}.json")
            examples = read_jsonl(data_path)
            encoded_eval_examples = self.encode_data(tokenizer, examples)
            eval_dataset = MyDataset(encoded_eval_examples)
            self.logger.info(
                f"the dist of tweets by events (eval): {json.dumps(Counter(encoded_eval_examples['events']), indent=2)}")

        model = MTLModelForSequenceClassification(self.args.base_model_path_or_name, len(self.cate_classes))
        train_dataset = MyDataset(encoded_train_examples)
        model = self.get_model_by_device(model)
        train_loader = DataLoader(train_dataset, batch_size=self.args.train_batch_size_per_device * self.device_count,
                                  num_workers=1, shuffle=True)
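        # total optimizer updates: batches per epoch * epochs / accumulation steps
        # (the scheduler below is stepped once per accumulated update, not per batch)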
        total_steps = int(len(train_loader) * self.args.train_epochs / self.args.accumulation_steps)

        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
        params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
        optim_groups = [
            {"params": params_decay, "weight_decay": self.args.weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]

        optimizer = AdamW(optim_groups, lr=self.args.training_lr, eps=1e-8)
        # optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=self.training_args.pre_train_training_lr, eps=1e-8)

        if self.args.lr_scheduler == "linear":
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=self.args.warmup_ratio * total_steps,
                                                        num_training_steps=total_steps)
        elif self.args.lr_scheduler == "linearconstant":
            scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=total_steps)
        else:
            scheduler = get_constant_schedule(optimizer)

        multi_label_loss_fn = nn.BCEWithLogitsLoss()
        regression_loss_fn = nn.MSELoss()
        model.train()

        global_step = 0
        eval_loss = 0

        for i in range(self.args.train_epochs):
            self.logger.info(f"Epoch {i + 1}:")
            wrap_dataset_loader = tqdm(train_loader)
            model.zero_grad()
            total_epoch_loss = 0
            for j, batch in enumerate(wrap_dataset_loader):
                batch.pop("categories")
                batch.pop("priority")
                batch.pop("raw_text")
                batch.pop("events")

                categories_indices = batch.pop("categories_indices").to(self.device)
                priority_score = batch.pop("priority_score").to(self.device)
                inputs = {k: batch[k].to(self.device) for k in batch}

                classification_logits, regression_logits = model(inputs)
                classification_loss = multi_label_loss_fn(classification_logits, categories_indices.float())
                regression_loss = regression_loss_fn(regression_logits.view(-1).sigmoid(), priority_score.float())
                loss = self.args.alpha * classification_loss + (1 - self.args.alpha) * regression_loss
                total_epoch_loss += loss.item()
                eval_loss += loss.item()
                loss.backward()
                if (j + 1) % self.args.accumulation_steps == 0:
                    # Clip the norm of the gradients to 1.0.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                global_step += 1
                wrap_dataset_loader.update(1)
                wrap_dataset_loader.set_description(
                    f"MTL-Training - epoch {i + 1}/{self.args.train_epochs} iter {j}/{len(wrap_dataset_loader)}: train loss {loss.item():.8f}. lr {scheduler.get_last_lr()[0]:e}")
                if self.args.eval_steps > 0 and global_step % self.args.eval_steps == 0:
                    self.logger.info(
                        f"\naverage training loss at global_step={global_step}: {eval_loss / self.args.eval_steps}")
                    eval_loss = 0
                    if eval_dataset is not None:
                        self.logger.info(
                            f"evaluation during training on {eval_set} set ({model_short_name}_epoch{i + 1}): ")
                        self.inference(model, eval_dataset)
                    model.train()

            self.logger.info(f"Average training loss for epoch {i + 1}: {total_epoch_loss / len(train_loader)}")
            # evaluate at the end of epoch if eval_steps is smaller than or equal to 0
            if self.args.eval_steps <= 0:
                self.logger.info(f"evaluation during training on {eval_set} set ({model_short_name}_epoch{i + 1}): ")
                self.inference(model, eval_dataset)
                model.train()
            # save up at end of each epoch!
            # model.save_pretrained(os.path.join(self.args.output_path, "mtl_train", model_short_name, f"epoch_{i + 1}"))
            # tokenizer.save_pretrained(os.path.join(self.args.output_path, "mtl_train", model_short_name, f"epoch_{i + 1}"))
        # save up at end of training!
        save_model_path = os.path.join(self.args.output_path, "mtl_train",
                                       model_short_name if not self.args.eda_aug else model_short_name + "-eda",
                                       "final_model")
        if isinstance(model, DataParallel):
            model.module.save_pretrained(save_model_path)
        else:
            model.save_pretrained(save_model_path)
        tokenizer.save_pretrained(save_model_path)

        # eval at the final model saved ck
        return_dict = {}
        if eval_dataset is not None:
            self.logger.info(f"evaluation on test set with mtl-trained model: {save_model_path}")
            return_dict = {f"mtl_train(eval_set={eval_set})": self.inference(model, eval_dataset)}
        return return_dict
Exemple #23
def train(args):
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    print("torch_version:{}".format(torch.__version__))
    print("CUDA_version:{}".format(torch.version.cuda))
    print("cudnn_version:{}".format(cudnn.version()))

    init_seed(123456)

    data_path = args.base_data_path+args.dataset+'/'

    tokenizer, vocab2id, id2vocab = bert_tokenizer()
    detokenizer = bert_detokenizer()

    print('Vocabulary size', len(vocab2id))

    if os.path.exists(data_path + 'train_DukeNet.pkl'):
        query = torch.load(data_path + 'query_DukeNet.pkl')
        train_samples = torch.load(data_path + 'train_DukeNet.pkl')
        passage = torch.load(data_path + 'passage_DukeNet.pkl')
        print("The number of train_samples:", len(train_samples))
    else:
        samples, query, passage = load_default(args.dataset, data_path + args.dataset + '.answer',
                                                                   data_path + args.dataset + '.passage',
                                                                   data_path + args.dataset + '.pool',
                                                                   data_path + args.dataset + '.qrel',
                                                                   data_path + args.dataset + '.query',
                                                                   tokenizer)

        if args.dataset == "wizard_of_wikipedia":
            train_samples, dev_samples, test_seen_samples, test_unseen_samples = split_data(args.dataset, data_path + args.dataset + '.split', samples)
            print("The number of test_seen_samples:", len(test_seen_samples))
            print("The number of test_unseen_samples:", len(test_unseen_samples))
            torch.save(test_seen_samples, data_path + 'test_seen_DukeNet.pkl')
            torch.save(test_unseen_samples, data_path + 'test_unseen_DukeNet.pkl')

        elif args.dataset == "holl_e":
            train_samples, dev_samples, test_samples = split_data(args.dataset, data_path + args.dataset + '.split', samples)
            print("The number of test_samples:", len(test_samples))
            torch.save(test_samples, data_path + 'test_DukeNet.pkl')

        print("The number of train_samples:", len(train_samples))
        print("The number of dev_samples:", len(dev_samples))
        torch.save(query, data_path + 'query_DukeNet.pkl')
        torch.save(passage, data_path + 'passage_DukeNet.pkl')
        torch.save(train_samples, data_path + 'train_DukeNet.pkl')
        torch.save(dev_samples, data_path + 'dev_DukeNet.pkl')


    model = DukeNet(vocab2id, id2vocab, args)
    saved_model_path = os.path.join(args.base_output_path + args.name + "/", 'model/')

    if args.resume is True:
        print("Reading checkpoints...")

        with open(saved_model_path + "checkpoints.json", 'r', encoding='utf-8') as r:
            checkpoints = json.load(r)
        last_epoch = checkpoints["time"][-1]

        fuse_dict = torch.load(os.path.join(saved_model_path, '.'.join([str(last_epoch), 'pkl'])))
        model.load_state_dict(fuse_dict["model"])
        print('Loading success, last_epoch is {}'.format(last_epoch))


    else:
        init_params(model, "enc")
        freeze_params(model, "enc")

        last_epoch = -1

        if not os.path.exists(saved_model_path):
            os.makedirs(saved_model_path)

        with open(saved_model_path + "checkpoints.json", 'w', encoding='utf-8') as w:
            checkpoints = {"time": []}
            json.dump(checkpoints, w)

    # construct an optimizer object
    model_optimizer = optim.Adam(model.parameters(), args.lr)  # model.parameters() returns an iterator over module parameters, which is what the optimizer expects
    model_scheduler = get_constant_schedule(model_optimizer)

    if args.resume is True:
        model_scheduler.load_state_dict(fuse_dict["scheduler"])
        print('Loading scheduler, last_scheduler is', fuse_dict["scheduler"])

    trainer = CumulativeTrainer(args.name, model, tokenizer, detokenizer, args.local_rank, accumulation_steps=args.accumulation_steps)
    model_optimizer.zero_grad()  # clears the gradients of all optimized tensors

    for i in range(last_epoch+1, args.epoches):
        if i==5:
            unfreeze_params(model, "enc")
            args.train_batch_size = 2
            args.accumulation_steps = 16

        train_dataset = Dataset(args.mode, train_samples, query, passage, vocab2id,
                                    args.max_knowledge_pool_when_train, args.max_knowledge_pool_when_inference, args.context_len, args.knowledge_sentence_len,
                                    args.max_dec_length)
        trainer.train_epoch('train', train_dataset, collate_fn, args.train_batch_size, i, model_optimizer, model_scheduler)
        del train_dataset

        trainer.serialize(i, model_scheduler, saved_model_path=saved_model_path)
Exemple #24
def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    data_files = dict(train=data_args.train_file, validation=data_args.validation_file)
    datasets = load_dataset('csv', data_files=data_files, cache_dir=model_args.cache_dir,
                            delimiter='\t', column_names=['img_id', 'graph', '_', 'text'])  # _ is unimportant
    model, tokenizer = load_model_and_tokenizer(model_args)
    prefix = "translate graph to text"
    column_names = datasets["train"].column_names
    padding = "max_length" if data_args.pad_to_max_length else False

    def preprocess_function(examples):
        inputs = [ex for ex in examples['graph']]
        targets = [ex for ex in examples['text']]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    if training_args.do_train:
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        #train_dataset = train_dataset.select(range(1))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
    if training_args.do_eval:
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        #eval_dataset = eval_dataset.select(range(1))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )
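        # pad_to_multiple_of=8 under fp16 keeps padded lengths aligned with
        # tensor-core-friendly shapes; with pad_to_max_length it is not needed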
    metric = load_metric('sacrebleu')

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        labels = [[label] for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        result = {"bleu": result["score"]}
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # this is the recommended t5 finetuning setup from
    # https://huggingface.co/transformers/main_classes/optimizer_schedules.html#adafactor-pytorch
    optimizer = Adafactor(
        model.parameters(),
        lr=3e-4,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False)
    """
    optimizer = AdamW(lr=3e-5)
    """
    lr_scheduler = transformers.get_constant_schedule(optimizer)
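    # An alternative (not used here) is Adafactor's relative-step mode, which
    # manages the learning rate internally and pairs with AdafactorSchedule
    # instead of a constant schedule:
    #   optimizer = Adafactor(model.parameters(), scale_parameter=True,
    #                         relative_step=True, warmup_init=True, lr=None)
    #   lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer)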
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        optimizers=(optimizer, lr_scheduler)
    )
    if training_args.do_train:
        train_result = trainer.train()
        trainer.save_model()
        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
    if training_args.do_eval:
        metrics = trainer.evaluate(
            max_length=data_args.max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
        )
        metrics["eval_samples"] = len(eval_dataset)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
    def start(self, model, data, evaluation):
        """

        :param model:
        :type model: BertWrapperModel
        :param data:
        :type data: MultiData
        :type evaluation: Evaluation
        """
        self.prepare_training(model, data)

        if hasattr(model.bert.config, 'adapter_attention'):
            self.logger.info(
                "Adapter attention detected. Freezing all weights except the adapter attention"
            )
            for param in model.bert.bert.parameters():
                param.requires_grad = False
            model.bert.bert.enable_adapters(unfreeze_adapters=False,
                                            unfreeze_attention=True)
        elif hasattr(model.bert.config, 'adapters'):
            self.logger.info(
                "Adapters detected. Freezing all weights except the adapters")
            for param in model.bert.bert.parameters():
                param.requires_grad = False
            model.bert.bert.enable_adapters(unfreeze_adapters=True,
                                            unfreeze_attention=False)

        if self.config.get('freeze_bert', False):
            self.logger.warn('FREEZING BERT')
            for name, param in model.bert.bert.named_parameters():
                self.logger.warn('freeze {}'.format(name))
                param.requires_grad = False

        if self.config.get('freeze_head', False):
            self.logger.info("Freezing the weights of the classification head")
            for name, param in model.bert.lin_layer.named_parameters():
                param.requires_grad = False

        # for param in model.bert.parameters():
        #     if param.requires_grad:
        #         print(param)

        # Prepare BERT optimizer
        param_optimizer = list(model.bert.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            self.config.get('weight_decay', 0.0)
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        t_total = self.get_n_batches() * self.n_epochs
        num_warmup_steps = int(self.get_n_batches() *
                               self.config.get('warmup_proportion', 0.1))

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.config.get('learning_rate', 5e-5),
                          eps=self.config.get('adam_epsilon', 1e-8))
        # correct_bias=False)

        if num_warmup_steps > 0:
            scheduler = transformers.get_constant_schedule_with_warmup(
                optimizer, num_warmup_steps)
        else:
            # scheduler = WarmupConstantSchedule(optimizer=optimizer, warmup_steps=num_warmup_steps)
            scheduler = transformers.get_constant_schedule(optimizer)

        self.state.load(model.bert, optimizer, weights='last')
        start_epoch = self.state.recorded_epochs + 1
        end_epoch = self.n_epochs + 1

        if self.state.recorded_epochs > 0:
            self.logger.info(
                'Loaded the weights of last epoch {} with valid score={}'.
                format(self.state.recorded_epochs, self.state.scores[-1]))
            # if start_epoch < end_epoch and not self.config.get('skip_restore_validation', False):
            #     self.logger.info('Now calculating validation score (to verify the restoring success)')
            #     valid_score = list(evaluation.start(model, data, valid_only=True)[0].values())[0]
            #     self.logger.info('Score={:.4f}'.format(valid_score))

        self.logger.info('Running from epoch {} to epoch {}'.format(
            start_epoch, end_epoch - 1))

        global_step = self.get_n_batches() * self.state.recorded_epochs

        for epoch in range(start_epoch, end_epoch):
            self.logger.info('Epoch {}/{}'.format(epoch, self.n_epochs))

            self.logger.debug('Preparing epoch')
            self.prepare_next_epoch(model, data, epoch)

            bar = self.create_progress_bar('loss')
            train_losses = []  # used to calculate the epoch train loss
            recent_train_losses = []  # used to calculate the display loss

            self.logger.debug('Training')
            self.logger.debug('{} minibatches with size {}'.format(
                self.get_n_batches(), self.batchsize))

            for _ in bar(range(int(self.get_n_batches()))):
                # self.global_step += self.batchsize
                train_examples = self.get_next_batch(model, data)
                batch_loss = 0

                batch_steps = int(
                    np.ceil(
                        len(train_examples) /
                        (self.batchsize / self.gradient_accumulation_steps)))
                for i in range(batch_steps):
                    step_size = self.batchsize // self.gradient_accumulation_steps
                    step_examples = train_examples[i * step_size:(i + 1) *
                                                   step_size]
                    step_loss = self.get_loss(model, step_examples,
                                              self.adapter_task)
                    step_loss = step_loss / self.gradient_accumulation_steps
                    step_loss.backward()
                    batch_loss += step_loss.item()

                if self.config.get('max_grad_norm', 0) > 0:
                    torch.nn.utils.clip_grad_norm_(
                        model.bert.parameters(),
                        self.config.get('max_grad_norm'))
                    # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
                optimizer.step()
                scheduler.step()  # these schedulers are stepped per batch and take no epoch argument
                optimizer.zero_grad()
                global_step += 1

                self.tensorboard.add_scalars('scores',
                                             {'train_loss': batch_loss},
                                             global_step=global_step)

                recent_train_losses = ([batch_loss] + recent_train_losses)[:20]
                train_losses.append(recent_train_losses[0])
                bar.dynamic_messages['loss'] = np.mean(recent_train_losses)

            self.logger.info('train loss={:.6f}'.format(np.mean(train_losses)))

            if self.config.get('evaluate_dev', True):
                self.logger.info('Now calculating validation score')
                valid_score, valid_score_other_measures = evaluation.start(
                    model, data, valid_only=True)
                valid_score = list(valid_score.values())[0]  # get only the dev split (there won't be any other split)
                valid_score_other_measures = list(
                    valid_score_other_measures.values())[0]

                self.tensorboard.add_scalar('valid_score',
                                            valid_score,
                                            global_step=global_step)
                self.tensorboard.add_scalars(
                    'scores',
                    dict([('valid_' + k, v)
                          for k, v, in valid_score_other_measures.items()]),
                    global_step=global_step)
                for key, value in valid_score_other_measures.items():
                    self.tensorboard.add_scalar(key,
                                                value,
                                                global_step=global_step)

                self.state.record(model.bert,
                                  optimizer if self.chkpt_optimizer else None,
                                  valid_score, self.backup_checkpoint_every)
            else:
                self.logger.info('Not validating dev. Setting score to epoch')
                self.state.record(model.bert,
                                  optimizer if self.chkpt_optimizer else None,
                                  epoch, self.backup_checkpoint_every)

        return self.state.best_epoch, self.state.best_score
Exemple #26
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriterP(args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    warmup_steps = args.warmup_samples // args.train_batch_size
    if args.lr_decay:
        scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=warmup_steps,
                                                    num_training_steps=t_total)
    else:
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)  # get_constant_schedule() takes no warmup argument

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError as e:
        global_step = 0

    tr_loss, logging_loss = 0.0, 0.0
    moving_loss = MovingLoss(10000 // args.logging_steps)
    model.zero_grad()

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    try:
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=args.local_rank not in [-1, 0])
            for step, batch in enumerate(epoch_iterator):
                inputs, labels = mask_tokens(
                    batch, tokenizer, args) if args.mlm else (batch, batch)
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)
                model.train()
                outputs = model(
                    inputs, masked_lm_labels=labels) if args.mlm else model(
                        inputs, labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                moving_loss.add(loss.item())
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training and global_step % args.eval_steps == 0:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer,
                                           f"checkpoint-{global_step}")
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)

                    if args.local_rank in [
                            -1, 0
                    ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             args.logging_steps, global_step)
                        logging_loss = tr_loss
                        epoch_iterator.set_postfix(
                            MovingLoss=f'{moving_loss.loss:.2f}',
                            Perplexity=
                            f'{torch.exp(torch.tensor(moving_loss.loss)):.2f}')

                    if args.local_rank in [
                            -1, 0
                    ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        save_state(args, model, tokenizer, global_step)

                if args.max_steps > 0 and global_step > args.max_steps:
                    epoch_iterator.close()
                    break
            print_sample(model, tokenizer, args.device, args)
            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break
    except (KeyboardInterrupt, SystemExit):
        save_state(args, model, tokenizer, global_step)
        raise

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Exemple #27
def train(args, training_features, model, tokenizer):
    """ Train the model """
    wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"),
               config=args,
               name=args.run_name)
    wandb.watch(model)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    else:
        amp = None

    # model recover
    recover_step = utils.get_max_epoch_model(args.output_dir)

    # if recover_step:
    #     model_recover_checkpoint = os.path.join(args.output_dir, "model.{}.bin".format(recover_step))
    #     logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint)
    #     model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu')
    #     optimizer_recover_checkpoint = os.path.join(args.output_dir, "optim.{}.bin".format(recover_step))
    #     checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu')
    #     checkpoint_state_dict['model'] = model_state_dict
    # else:
    checkpoint_state_dict = None

    model.to(args.device)
    model, optimizer = prepare_for_training(args,
                                            model,
                                            checkpoint_state_dict,
                                            amp=amp)

    if args.n_gpu == 0 or args.no_cuda:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps
    else:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps

    train_batch_size = per_node_train_batch_size * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    global_step = recover_step if recover_step else 0

    if args.num_training_steps == -1:
        args.num_training_steps = int(args.num_training_epochs *
                                      len(training_features) /
                                      train_batch_size)

    if args.warmup_portion:
        args.num_warmup_steps = args.warmup_portion * args.num_training_steps

    if args.scheduler == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.num_warmup_steps,
            num_training_steps=args.num_training_steps,
            last_epoch=-1)

    elif args.scheduler == "constant":
        scheduler = get_constant_schedule(optimizer, last_epoch=-1)

    elif args.scheduler == "1cycle":
        scheduler = OneCycleLR(optimizer,
                               max_lr=args.learning_rate,
                               total_steps=args.num_training_steps,
                               pct_start=args.warmup_portion,
                               anneal_strategy=args.anneal_strategy,
                               final_div_factor=1e4,
                               last_epoch=-1)

    else:
        raise ValueError("unknown scheduler: {}".format(args.scheduler))

    if checkpoint_state_dict:
        scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])

    train_dataset = utils.Seq2seqDatasetForBert(
        features=training_features,
        max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length,
        vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id,
        sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob,
        keep_prob=args.keep_prob,
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
    )

    logger.info("Check dataset:")
    for i in range(5):
        source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens = train_dataset.__getitem__(
            i)
        logger.info("Instance-%d" % i)
        logger.info("Source tokens = %s" %
                    " ".join(tokenizer.convert_ids_to_tokens(source_ids)))
        logger.info("Target tokens = %s" %
                    " ".join(tokenizer.convert_ids_to_tokens(target_ids)))

    logger.info("Mode = %s" % str(model))

    # Train!
    logger.info("  ***** Running training *****  *")
    logger.info("  Num examples = %d", len(training_features))
    logger.info("  Num Epochs = %.2f",
                len(train_dataset) / len(training_features))
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Batch size per node = %d", per_node_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", args.num_training_steps)

    if args.num_training_steps <= global_step:
        logger.info(
            "Training is done. Please use a new dir or clean this dir!")
    else:
        # The training features are shuffled
        train_sampler = SequentialSampler(train_dataset) \
            if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=per_node_train_batch_size //
            args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)

        train_iterator = tqdm.tqdm(train_dataloader,
                                   initial=global_step,
                                   desc="Iter (loss=X.XXX, lr=X.XXXXXXX)",
                                   disable=args.local_rank not in [-1, 0])

        model.train()
        model.zero_grad()

        tr_loss, logging_loss = 0.0, 0.0

        for step, batch in enumerate(train_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'source_ids': batch[0],
                'target_ids': batch[1],
                'pseudo_ids': batch[2],
                'num_source_tokens': batch[3],
                'num_target_tokens': batch[4]
            }
            loss = model(**inputs)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training

            train_iterator.set_description(
                'Iter (loss=%5.3f) lr=%9.7f' %
                (loss.item(), scheduler.get_last_lr()[0]))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            logging_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    wandb.log(
                        {
                            'lr': scheduler.get_last_lr()[0],
                            'loss': logging_loss / args.logging_steps
                        },
                        step=global_step)

                    logger.info(" Step [%d ~ %d]: %.2f",
                                global_step - args.logging_steps, global_step,
                                logging_loss)
                    logging_loss = 0.0

                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        (global_step % args.save_steps == 0 or global_step == args.num_training_steps):

                    save_path = os.path.join(args.output_dir,
                                             "ckpt-%d" % global_step)
                    os.makedirs(save_path, exist_ok=True)
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(save_path)

                    # optim_to_save = {
                    #     "optimizer": optimizer.state_dict(),
                    #     "lr_scheduler": scheduler.state_dict(),
                    # }
                    # if args.fp16:
                    #     optim_to_save["amp"] = amp.state_dict()
                    # torch.save(
                    #     optim_to_save, os.path.join(args.output_dir, 'optim.{}.bin'.format(global_step)))

                    logger.info("Saving model checkpoint %d into %s",
                                global_step, save_path)

    wandb.save(f'{save_path}/*')
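
A quick way to sanity-check the three scheduler branches used above (linear warmup/decay, constant, and 1cycle) is to drive each one on a throwaway optimizer and print the per-step learning rates. The sketch below is illustrative only; the step counts and base rate are made up and none of it comes from the original script.

import torch
from torch.optim.lr_scheduler import OneCycleLR
from transformers import get_constant_schedule, get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter so the schedulers have something to wrap
total_steps, warmup_steps, base_lr = 10, 2, 1e-4

def lr_trace(make_scheduler):
    optimizer = torch.optim.AdamW([param], lr=base_lr)
    scheduler = make_scheduler(optimizer)
    trace = []
    for _ in range(total_steps):
        trace.append(scheduler.get_last_lr()[0])
        optimizer.step()
        scheduler.step()
    return trace

print("linear  ", lr_trace(lambda opt: get_linear_schedule_with_warmup(
    opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps)))
print("constant", lr_trace(get_constant_schedule))
print("1cycle  ", lr_trace(lambda opt: OneCycleLR(
    opt, max_lr=base_lr, total_steps=total_steps, pct_start=warmup_steps / total_steps)))
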
Exemple #28
def train(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    if args.save:
        export_config(args, config_path)
        check_path(model_path)
        with open(log_path, 'w') as fout:
            fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim))

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = GconAttnDataLoader(
        train_statement_path=args.train_statements,
        train_concept_jsonl=args.train_concepts,
        dev_statement_path=args.dev_statements,
        dev_concept_jsonl=args.dev_concepts,
        test_statement_path=args.test_statements,
        test_concept_jsonl=args.test_concepts,
        concept2id_path=args.cpnet_vocab_path,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        device=device,
        model_name=args.encoder,
        max_cpt_num=max_cpt_num[args.dataset],
        max_seq_length=args.max_seq_len,
        is_inhouse=args.inhouse,
        inhouse_train_qids_path=args.inhouse_train_qids,
        subsample=args.subsample,
        format=args.format)

    print('len(train_set): {}   len(dev_set): {}   len(test_set): {}'.format(
        dataset.train_size(), dataset.dev_size(), dataset.test_size()))
    print()

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMGconAttn(model_name=args.encoder,
                       concept_num=concept_num,
                       concept_dim=args.cpt_out_dim,
                       concept_in_dim=concept_dim,
                       freeze_ent_emb=args.freeze_ent_emb,
                       pretrained_concept_emb=cp_emb,
                       hidden_dim=args.decoder_hidden_dim,
                       dropout=args.dropoutm,
                       encoder_config=lstm_config)

    if args.freeze_ent_emb:
        freeze_net(model.decoder.concept_emb)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.decoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.decoder_lr
        },
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(
                            labels[a:b], num_classes=num_choice).view(
                                -1)  # of length mini_batch_size*num_choice
                        correct_logits = flat_logits[
                            correct_mask == 1].contiguous().view(-1, 1).expand(
                                -1, num_choice - 1).contiguous().view(
                                    -1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[
                            correct_mask ==
                            0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # step the scheduler after the optimizer (PyTorch >= 1.1)

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    total_loss = 0
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(
                dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(
                global_step, dev_acc, test_acc))
            print('-' * 71)
            if args.save:
                with open(log_path, 'a') as fout:
                    fout.write('{},{},{}\n'.format(global_step, dev_acc,
                                                   test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                if args.save:
                    torch.save([model, args], model_path)
                    print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc,
                                                      best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
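
The four-way parameter grouping above (encoder vs. decoder, with and without weight decay) is verbose; the same split can be produced by a small helper. This is a minimal sketch assuming a model that exposes encoder and decoder submodules as in the snippet; the helper name is illustrative and not part of the original code.

def build_grouped_parameters(model, weight_decay, encoder_lr, decoder_lr,
                             no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    """Split encoder/decoder parameters into decayed and non-decayed groups."""
    groups = []
    for module, lr in ((model.encoder, encoder_lr), (model.decoder, decoder_lr)):
        named = list(module.named_parameters())
        groups.append({
            'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
            'weight_decay': weight_decay,
            'lr': lr,
        })
        groups.append({
            'params': [p for n, p in named if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'lr': lr,
        })
    return groups

# Usage mirroring the snippet above:
# optimizer = OPTIMIZER_CLASSES[args.optim](build_grouped_parameters(
#     model, args.weight_decay, args.encoder_lr, args.decoder_lr))
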
Exemple #29
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    # Allow for different learning rate for final layers
    final_layers = [
        'span_outputs.weight', 'span_outputs.bias', 'type_output.weight',
        'type_output.bias'
    ]
    if args.final_layers_lr == -1.0:
        args.final_layers_lr = args.learning_rate
    if args.final_layers_wd == -1.0:
        args.final_layers_wd = args.weight_decay

    final_layer_params = [(n, p) for n, p in model.named_parameters()
                          if n in final_layers]
    non_final_layer_params = [(n, p) for n, p in model.named_parameters()
                              if n not in final_layers]

    no_decay = ['bias', 'LayerNorm.weight']
    final_layer_decaying_params = [
        p for n, p in final_layer_params if not any(nd in n for nd in no_decay)
    ]
    final_layer_nondecaying_params = [
        p for n, p in final_layer_params if any(nd in n for nd in no_decay)
    ]

    non_final_layer_decaying_params = [
        p for n, p in non_final_layer_params
        if not any(nd in n for nd in no_decay)
    ]
    non_final_layer_nondecaying_params = [
        p for n, p in non_final_layer_params if any(nd in n for nd in no_decay)
    ]

    optimizer_grouped_parameters = [
        {
            'params': final_layer_decaying_params,
            'lr': args.final_layers_lr,
            'weight_decay': args.final_layers_wd
        },
        {
            'params': final_layer_nondecaying_params,
            'lr': args.final_layers_lr,
            'weight_decay': 0.0
        },
        {
            'params': non_final_layer_decaying_params,
            'lr': args.learning_rate,
            'weight_decay': args.weight_decay
        },
        {
            'params': non_final_layer_nondecaying_params,
            'lr': args.learning_rate,
            'weight_decay': 0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)

    # Allow choice between lr schedules
    if args.constant_lr and args.warmup_steps == 0:
        scheduler = get_constant_schedule(optimizer)
    elif args.constant_lr and args.warmup_steps > 0:
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps)
    else:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            inputs = {
                'input_ids': batch['input_ids'].to(args.device),
                'attention_mask': batch['attention_mask'].to(args.device),
                'token_type_ids': batch['token_type_ids'].to(args.device),
                'start_positions': batch['start_positions'].to(args.device),
                'end_positions': batch['end_positions'].to(args.device),
                'instance_types': batch['instance_types'].to(args.device)
            }
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args,
                                           model,
                                           tokenizer,
                                           dataset_type='dev',
                                           prefix=str(global_step))
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('lr_final_layers',
                                         scheduler.get_lr()[1], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
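
Because the optimizer above uses separate parameter groups for the final layers and the rest of the model, the scheduler reports one learning rate per group, which is what the two TensorBoard scalars ('lr' and 'lr_final_layers') read out. The standalone sketch below shows the same idea with a simplified two-group setup; the modules and rates are made up for illustration.

import torch
from transformers import get_constant_schedule_with_warmup

backbone = torch.nn.Linear(8, 8)
head = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW([
    {'params': backbone.parameters(), 'lr': 3e-5},
    {'params': head.parameters(), 'lr': 1e-3},
])
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=3)

for step in range(5):
    optimizer.step()
    scheduler.step()
    backbone_lr, head_lr = scheduler.get_last_lr()  # one entry per parameter group
    print(f"step {step}: backbone_lr={backbone_lr:.2e} head_lr={head_lr:.2e}")
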