Example #1
    def evaluate_task(self, task, train_set):
        scorer = task.get_scorer(dump_to_file={
            "output_dir": self.config.output_dir,
            "task_name": task.name
        })
        data: DataFlow = task.train_set if train_set else task.val_set
        for i, mb in enumerate(
                data.get_minibatches(
                    self.config.tasks[data.task_name]["test_batch_size"])):
            # batch_preds = self.model.test(mb)
            outputs = self.model.test_abstract(mb)
            task_outputs = outputs[task.name]
            scorer.update(mb, task_outputs, task_outputs.get("loss", 0), {})
            # TODO: this code is fragile; the migration to the new return
            # interface is still in progress.

            # if isinstance(batch_preds, dict):
            #   scorer.update(mb, batch_preds, 0, None)
            # else:
            #   # Slow Migration towards the return interface
            #   extra_output = {}
            #   if self.config.output_attentions:
            #     batch_preds, attention_map = batch_preds
            #     extra_output["attention_map"] = attention_map
            #   if isinstance(batch_preds, tuple):
            #     loss, batch_preds = batch_preds
            #   else:
            #     loss = 0
            #   scorer.update(mb, batch_preds, loss, extra_output)
            if i % 100 == 0:
                utils.log("{} batch processed.".format(i))
        results = scorer.get_results()
        utils.log(task.name.upper() + ": " + scorer.results_str())
        return results
Example #2
  def train_labeled_abstract(self, mb: MiniBatch, step):
    self.model.train()

    inputs = mb.generate_input(device=self.device, use_label=True)
    if "input_ids" in inputs and inputs["input_ids"].size(0) == 0:
      utils.log("Zero Batch")
      return 0

    outputs = self.model(**inputs)

    # TODO: the return-interface migration is still in progress (dict vs. tuple outputs).
    if isinstance(outputs, dict):
      loss = outputs[mb.task_name]["loss"]
    else:
      if self.config.output_attentions:
        loss, _, _ = outputs
      else:
        loss, _ = outputs

    loss = mb.loss_weight * loss

    if self.config.gradient_accumulation_steps > 1:
      loss = loss / self.config.gradient_accumulation_steps
    loss.backward()
    if (step + 1) % self.config.gradient_accumulation_steps == 0:
      # TODO: quick fix -- skip gradient clipping for the SQuAD tasks
      if not hasattr(mb, "task_name") or mb.task_name not in ["squad11", "squad20"]:
        nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
      self.optimizer.step()
      self.optimizer.zero_grad()
      self.global_step_labeled += 1
    return loss.item()
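
A quick note on the gradient-accumulation arithmetic in train_labeled_abstract, as a minimal sketch (the numbers below are illustrative, not taken from any config): dividing the loss by gradient_accumulation_steps and stepping the optimizer only every N minibatches makes one update equivalent to a single step over an N-times larger batch.

# Illustrative only: effective batch size under gradient accumulation.
train_batch_size = 8              # per-minibatch size (assumed value)
gradient_accumulation_steps = 4   # optimizer steps once every 4 minibatches
effective_batch_size = train_batch_size * gradient_accumulation_steps  # 32
# This is also why setup_training (Examples #19/#20) divides the configured
# batch sizes by gradient_accumulation_steps: the effective size stays constant.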
Example #3
 def set_mode(self):
     if self.mode == "alignment":
         utils.log("Parallel Mapping Mode: Alignment")
     elif self.mode == "adjustment":
         for param in self.W.parameters():
             param.requires_grad = False
         utils.log("Parallel Mapping Mode: Adjustment")
Example #4
  def get_minibatches(self, minibatch_size, sequential=True, bucket=False):
    """Generate list of batch size based on examples.

    There are two modes for generating batches. One is sequential,
    which follows the original example sequence in the dataset.
    The other mode is based on bucketing, to save the memory consumption.

    NOTICE: The default value here is for evaluation. Because the evaluation interface will directly call this.
    The traning interface will call endless_minibatches.

    Args:
      minibatch_size (int): Batch size.
      sequential (bool): To be sequential mode or not.

    """
    if sequential:
      index = 0
      while index < self.size:
        yield self._make_minibatch(
            np.array(range(index, min(index + minibatch_size, self.size))))
        index += minibatch_size
    elif not bucket:
      indices = list(range(self.size))
      random.shuffle(indices)
      indices = np.array(indices)
      index = 0
      while index < self.size:
        yield self._make_minibatch(
            indices[index: min(index + minibatch_size, self.size)])
        index += minibatch_size
    else:
      by_bucket = collections.defaultdict(list)
      for i, example in enumerate(self.examples):
        by_bucket[get_bucket(self.config, example.bucketing_len)].append(i)
      # save memory by weighting examples so longer sentences
      #   have smaller minibatches
      weight = lambda ind: np.sqrt(self.examples[ind].bucketing_len)
      total_weight = float(sum(weight(i) for i in range(self.size)))
      weight_per_batch = minibatch_size * total_weight / self.size
      cumulative_weight = 0.0
      id_batches = []
      for _, ids in by_bucket.items():
        ids = np.array(ids)
        np.random.shuffle(ids)
        curr_batch, curr_weight = [], 0.0
        for i, curr_id in enumerate(ids):
          curr_batch.append(curr_id)
          curr_weight += weight(curr_id)
          if (i == len(ids) - 1 or cumulative_weight + curr_weight >=
              (len(id_batches) + 1) * weight_per_batch):
            cumulative_weight += curr_weight
            id_batches.append(np.array(curr_batch))
            curr_batch, curr_weight = [], 0.0
      random.shuffle(id_batches)
      utils.log("Data Flow {}, There are {} batches".format(
          self.__class__, len(id_batches)))
      for id_batch in id_batches:
        yield self._make_minibatch(id_batch)
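
A minimal standalone sketch of the sqrt-length weighting used in the bucketing branch above (the lengths and batch size are made-up values): longer examples carry more weight, so a batch fills up after fewer of them.

import numpy as np

lengths = np.array([8, 8, 8, 64, 64, 64])   # assumed bucketing_len values
minibatch_size = 2
weights = np.sqrt(lengths)
# Same quantity as weight_per_batch above: average weight of a nominal batch.
weight_per_batch = minibatch_size * weights.sum() / len(lengths)
# A batch closes once its cumulative weight reaches weight_per_batch,
# so it holds roughly weight_per_batch / sqrt(length) examples:
# more short sentences per batch, fewer long ones.
print(weight_per_batch, weights)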
Example #5
 def __init__(self, config, restore_if_possible=True):
     self.config = config
     if restore_if_possible and os.path.exists(config.progress):
         history, current_file, current_line = utils.load_pickle(
             config.progress, memoized=False)
         self.history = history
         # self.unlabeled_data_reader =
     else:
         utils.log("No previous checkpoint found - starting from scratch")
         self.history = []
     self.evaluated_steps = set([0])
     self.log_steps = set([])
Example #6
 def restore_teacher(self, model_path):
     restore_state_dict = torch.load(
         model_path, map_location=lambda storage, location: storage)
     # loaded_dict = {k: restore_state_dict[k] for k in
     #                set(self.model.model.state_dict().keys()) & set(restore_state_dict.keys())}
     # model_state = self.model.model.state_dict()
     # model_state.update(loaded_dict)
     for key in self.config.ignore_parameters:
         # restore_state_dict.pop(key)
         restore_state_dict[key] = self.model.model.state_dict()[key]
     self.teacher_model.model.load_state_dict(restore_state_dict)
     utils.log("Teacher Model Restored from {}".format(model_path))
Example #7
 def __init__(self, config, ext_config):
   self.config = config
   self.ext_config = ext_config
   if config.local_rank == -1 or config.no_cuda:
     self.device = torch.device("cuda" if torch.cuda.is_available() and not config.no_cuda else "cpu")
     n_gpu = torch.cuda.device_count()
   else:
     torch.cuda.set_device(config.local_rank)
     self.device = torch.device("cuda:" + str(config.local_rank))
     n_gpu = 1
      # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     # torch.distributed.init_process_group(backend='nccl')
   utils.log("device: {}".format(self.device))
Example #8
 def __init__(self, config, task_name, n_classes):
   super(SequenceLabelingModule, self).__init__()
   self.config = config
   self.task_name = task_name
   self.n_classes = n_classes
   if hasattr(self.config, "sequence_labeling_use_cls") and self.config.sequence_labeling_use_cls:
     self.mul = 2
     log("Use cls in sequence labeling")
   else:
     self.mul = 1
   self.projection = nn.Linear(config.hidden_size * self.mul, config.projection_size)
   self.to_logits = nn.Linear(config.projection_size, self.n_classes)
   self.init_weight()
Example #9
 def restore(self, model_path, model_name="default.ckpt"):
     model_ckpt_path = os.path.join(model_path, model_name)
     restore_state_dict = torch.load(
         model_ckpt_path, map_location=lambda storage, location: storage)
     # loaded_dict = {k: restore_state_dict[k] for k in
     #                set(self.model.model.state_dict().keys()) & set(restore_state_dict.keys())}
     # model_state = self.model.model.state_dict()
     # model_state.update(loaded_dict)
     for key in self.config.ignore_parameters:
         # restore_state_dict.pop(key)
         restore_state_dict[key] = self.model.model.state_dict()[key]
     self.model.model.load_state_dict(restore_state_dict)
     utils.log("Model Restored from {}".format(model_path))
Example #10
 def feature_extraction(self, task, dump_file):
     data = task.val_set
     features = []
     for i, mb in enumerate(
             data.get_minibatches(
                 self.config.tasks[data.task_name]["test_batch_size"])):
         # batch_preds = self.model.test(mb)
         batch_preds = self.model.test_abstract(mb)
         features.append(batch_preds[task.name]["features"].cpu().detach())
         if i % 100 == 0:
             utils.log("{} batch processed.".format(i))
     features = torch.cat(features, dim=0).numpy()
     with open(dump_file, 'wb') as fout:
         np.savez(fout, X=features)
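
Reading the dump back is straightforward; a sketch, with "features.npz" standing in for the actual dump_file path:

import numpy as np

with np.load("features.npz") as data:  # placeholder for the dump_file path
    X = data["X"]                      # key "X" as written by np.savez above
print(X.shape)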
Example #11
    def analyze_task(self, task, head_mask, head_importance, attn_entropy):
        scorer = task.get_scorer(dump_to_file={
            "output_dir": self.config.output_dir,
            "task_name": task.name
        })
        params = {
            "head_importance": head_importance,
            "attn_entropy": attn_entropy,
            "total_token": 0.0
        }
        data: DataFlow = task.val_set
        # The output of the model is logits and attention_map
        for i, mb in enumerate(
                data.get_minibatches(
                    self.config.tasks[data.task_name]["test_batch_size"])):
            batch_preds, attention_map = self.model.analyze(mb,
                                                            head_mask,
                                                            params=params)
            extra_output = {}
            extra_output["attention_map"] = attention_map
            loss = 0
            scorer.update(mb, batch_preds, loss, extra_output)
            if i % 100 == 0:
                utils.log("{} batch processed.".format(i))
        results = scorer.get_results()
        utils.log(task.name.upper() + ": " + scorer.results_str())

        params["attn_entropy"] /= params["total_token"]
        params["head_importance"] /= params["total_token"]
        np.save(os.path.join(self.config.output_dir, 'attn_entropy.npy'),
                attn_entropy.detach().cpu().numpy())
        np.save(os.path.join(self.config.output_dir, 'head_importance.npy'),
                head_importance.detach().cpu().numpy())

        utils.log("Attention entropies")
        print_2d_tensor(attn_entropy)
        utils.log("Head importance scores")
        print_2d_tensor(head_importance)
        utils.log("Head ranked by importance scores")
        head_ranks = torch.zeros(head_importance.numel(),
                                 dtype=torch.long,
                                 device=self.model.device)
        head_ranks[head_importance.view(-1).sort(
            descending=True)[1]] = torch.arange(head_importance.numel(),
                                                device=self.model.device)
        head_ranks = head_ranks.view_as(head_importance)
        print_2d_tensor(head_ranks)
        return results
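
The head_ranks assignment at the end of analyze_task uses an inverse-permutation trick: scattering arange into the descending sort indices gives each head its rank. A toy check with made-up scores:

import torch

importance = torch.tensor([0.2, 0.9, 0.5, 0.1])   # fake importance scores
ranks = torch.zeros(importance.numel(), dtype=torch.long)
# sort(descending=True)[1] lists indices from most to least important;
# writing 0..n-1 into those positions yields each element's rank.
ranks[importance.sort(descending=True)[1]] = torch.arange(importance.numel())
print(ranks)  # tensor([2, 0, 1, 3]) -> element 1 is the most important head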
Example #12
    def __init__(self, config, tasks):
        super(Inference, self).__init__()
        self.config = config
        self.tasks = tasks
        if config.branching_encoder:
            utils.log("Build Branching Bert Encoder")
            self.encoder = BranchingBertModel.from_pretrained(
                config.bert_model,
                encoder_structure=config.branching_structure)
        else:
            utils.log("Build {}:{} Encoder".format(config.encoder_type,
                                                   config.bert_model))
            self.encoder = get_encoder(config.encoder_type).from_pretrained(
                config.bert_model, output_attentions=config.output_attentions)

        utils.log("Build Task Modules")
        self.tasks_modules = nn.ModuleDict()
        for task in tasks:
            if task.has_module:
                self.tasks_modules.update([(task.name, task.get_module())])
        self.task_dict = dict([(task.name, task) for task in self.tasks])
        self.dummy_input = torch.rand(1, 10, requires_grad=True)

        # self.encoder = HighwayLSTM(num_layers=3, input_size=300, hidden_size=200, layer_dropout=0.2)
        # self.word_embedding = nn.Embedding(self.config.external_vocab_size, self.config.external_vocab_embed_size)
        # self.word_embedding.weight.data.copy_(torch.from_numpy(np.load(config.external_embeddings)))
        # print("Loading embedding from {}".format(config.external_embeddings))

        self.loss_max_margin = MarginRankingLoss(margin=config.max_margin)
        self.distance = nn.PairwiseDistance(p=1)
Example #13
  def __init__(self, config, tasks, ext_config: Configuration=None):
    super(Model, self).__init__(config, ext_config)
    self.tasks = tasks
    utils.log("Building model")
    inference = get_inference(config)(config, tasks)
    utils.log("Switch Model to device")
    inference = inference.to(self.device)
    # TODO: need to test
    if config.multi_gpu:
      inference = torch.nn.DataParallel(inference)
    self.model = inference
    self.teacher = inference
    utils.log(self.model.__str__())

    if config.mode == "train" or config.mode == "finetune":
      self.setup_training(config, tasks)

    ## Inplace Relu
    def inplace_relu(m):
      classname = m.__class__.__name__
      if classname.find('ReLU') != -1:
        m.inplace = True

    inference.apply(inplace_relu)

    if config.adversarial_training:
      self.adversarial_agent = Adversarial(config=ext_config.adversarial_configs).to(self.device)
Example #14
 def save_if_best_dev_model(self, model):
     # The best model is selected by the average dev score.
     # TODO: double check the results format
     best_avg_score = 0
     for i, results in enumerate(self.history):
         total, count = 0, 0
         for result in results:
             # Only dev metrics count towards the average; skip train/test results.
             if any("train" in metric for metric, value in result):
                 continue
             if any("test" in metric for metric, value in result):
                 continue
             for metric, value in result:
                 if hasattr(self.config,
                            "metrics") and self.config.metrics is not None:
                     if metric in self.config.metrics:
                         total += value
                         count += 1
                 elif any(key in metric for key in
                          ("distance", "f1", "las", "accuracy",
                           "recall_left", "recall_right", "map")):
                     total += value
                     count += 1
         avg_score = total / count
         if avg_score >= best_avg_score:
             best_avg_score = avg_score
             if i == len(self.history) - 1:
                 utils.log(
                     "New Score {}, New best model! Saving ...".format(
                         best_avg_score))
                 torch.save(
                     model.state_dict(),
                     os.path.join(self.config.output_dir,
                                  self.config.model_name + ".ckpt"))
                 general_config_path = os.path.join(self.config.output_dir,
                                                    "general_config.json")
                 with open(general_config_path, "w") as fout:
                     fout.write(json.dumps(vars(self.config)))
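
The nested loops above imply a particular shape for self.history; a sketch of the assumed layout (metric names and values are illustrative, inferred from the code rather than documented):

# Assumed structure, inferred from save_if_best_dev_model:
# history -> one entry per evaluation,
#   each entry -> a list of per-task results,
#     each result -> a list of (metric, value) pairs.
history = [
    [  # evaluation 0
        [("task_a_dev_f1", 0.81), ("task_a_dev_loss", 0.42)],
        [("task_b_dev_accuracy", 0.90)],
    ],
]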
Example #15
  def __init__(self, config, task_name, n_classes):
    super(BiaffineDepModule, self).__init__()
    self.config = config
    self.task_name = task_name
    self.n_classes = n_classes

    if hasattr(self.config, "sequence_labeling_use_cls") and self.config.sequence_labeling_use_cls:
      self.mul = 2
      log("Use CLS in dependency parsing")
    else:
      self.mul = 1

    encoder_dim = config.hidden_size
    arc_representation_dim = tag_representation_dim = config.dep_parsing_mlp_dim
    # self.pos_tag_embedding = nn.Embedding()
    self.head_sentinel = torch.nn.Parameter(torch.randn([1, 1, config.hidden_size]))

    # TODO: Need to check the dropout attribute.
    # TODO: How to design task specific parameter configuration
    self.dropout = InputVariationalDropout(config.dropout)
    self.head_arc_feedforward = nn.Sequential(
      nn.Linear(encoder_dim, arc_representation_dim),
      nn.ELU())
    self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)
    self.head_tag_feedforward = nn.Sequential(
      nn.Linear(encoder_dim, tag_representation_dim),
      nn.ELU())
    self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

    self.arc_attention = BilinearMatrixAttention(
      matrix_1_dim=arc_representation_dim,
      matrix_2_dim=arc_representation_dim,
      use_input_biases=True)

    self.tag_bilinear = nn.modules.Bilinear(
        tag_representation_dim, tag_representation_dim, self.n_classes)
Example #16
def print_2d_tensor(tensor):
    """ Print a 2D tensor """
    utils.log("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
    for row in range(len(tensor)):
        if tensor.dtype != torch.long:
            utils.log(f"layer {row + 1}:\t" +
                      "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
        else:
            utils.log(f"layer {row + 1}:\t" +
                      "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
Example #17
 def initialization_from_pretrained(self, model_path):
     pretrained_dict = torch.load(
         model_path, map_location=lambda storage, location: storage)
     model_dict = self.model.model.state_dict()
     if self.ext_config.param_configs.param_assignment is not None:
         for source_key, target_key in self.ext_config.param_configs.param_assignment.items(
         ):
             model_dict[target_key] = pretrained_dict[source_key].to(
                 self.model.device)
             utils.log("Replace {} with {}".format(target_key, source_key))
     self.model.model.load_state_dict(model_dict)
     utils.log("Initialize models Restored from {}".format(model_path))
     utils.log("With keys {}".format(pretrained_dict.keys()))
Example #18
    def adversarial_train(self, progress: TrainingProgress):
        heading = lambda s: utils.heading(s, '(' + self.config.model_name + ')')
        trained_on_sentences = 0
        start_time = time.time()
        generator_supervised_loss_total, generator_supervised_loss_count = 0, 0
        generator_dis_loss_total, generator_dis_loss_count = 0, 0
        discriminator_positive_loss_total, discriminator_positive_loss_count = 0, 0
        discriminator_negative_loss_total, discriminator_negative_loss_count = 0, 0
        real_acc_total, real_acc_count = 0, 0
        fake_acc_total, fake_acc_count = 0, 0
        step = 0

        for turn, labeled_mb, unlabeled_mb in self.get_training_mbs():
            labeled_mb: MiniBatch
            unlabeled_mb: MiniBatch
            if turn == TRAIN_DISCRIMINATOR:
                positive_loss, negative_loss, real_acc, fake_acc = self.model.train_discriminator(
                    labeled_mb, unlabeled_mb)
                discriminator_positive_loss_total += positive_loss
                discriminator_positive_loss_count += 1
                discriminator_negative_loss_total += negative_loss
                discriminator_negative_loss_count += 1
                real_acc_total += real_acc
                real_acc_count += 1
                fake_acc_total += fake_acc
                fake_acc_count += 1
            if turn == TRAIN_GENERATOR:
                supervised_loss, dis_loss = self.model.train_generator(
                    labeled_mb, unlabeled_mb)
                generator_supervised_loss_total += supervised_loss
                generator_supervised_loss_count += 1
                generator_dis_loss_total += dis_loss
                generator_dis_loss_count += 1

            step += 1
            if labeled_mb is not None:
                trained_on_sentences += labeled_mb.size
            if unlabeled_mb is not None:
                trained_on_sentences += unlabeled_mb.size

            # Use a simplified version of the logging.
            # TODO: check whether we need the original version.
            if step % self.config.print_every == 0:
                utils.log(
                    "step {:} - "
                    "generator supervised loss {:.3f} - "
                    "generator dis loss {:.3f} - "
                    "discriminator positive loss {:.3f} - "
                    "discriminator negative loss {:.3f} - "
                    "discriminator real accuracy {:.3f} - "
                    "discriminator fake accuracy {:.3f} - "
                    "{:.1f} sentences per second".format(
                        step, generator_supervised_loss_total /
                        max(1, generator_supervised_loss_count),
                        generator_dis_loss_total /
                        max(1, generator_dis_loss_count),
                        discriminator_positive_loss_total /
                        max(1, discriminator_positive_loss_count),
                        discriminator_negative_loss_total /
                        max(1, discriminator_negative_loss_count),
                        real_acc_total / max(1, real_acc_count),
                        fake_acc_total / max(1, fake_acc_count),
                        trained_on_sentences / (time.time() - start_time)))
                generator_supervised_loss_total, generator_supervised_loss_count = 0, 0
                generator_dis_loss_total, generator_dis_loss_count = 0, 0
                discriminator_positive_loss_total, discriminator_positive_loss_count = 0, 0
                discriminator_negative_loss_total, discriminator_negative_loss_count = 0, 0
                real_acc_total, real_acc_count = 0, 0
                fake_acc_total, fake_acc_count = 0, 0

            if step % self.config.eval_dev_every == 0:
                heading("EVAL on DEV")
                self.evaluate_all_tasks(progress.history)
                progress.save_if_best_dev_model(self.model.model)
                progress.add_evaluated_step(self.model.global_step_labeled)

            if self.config.early_stop_at > 0 and step >= self.config.early_stop_at:
                utils.log("Early stop at step {}".format(step))
                break
Example #19
  def setup_training(self, config, tasks):
    # Calculate optimization steps
    size_train_examples = 0
    config.num_steps_in_one_epoch = 0
    if config.mode == "train" or config.mode == "finetune":
      for task in tasks:
        utils.log("{} : {}  training examples".format(task.name, task.train_set.size))
        if "loss_weight" in config.tasks[task.name]:
          utils.log("loss weight {}".format(config.tasks[task.name]["loss_weight"]))
        size_train_examples += task.train_set.size
        config.num_steps_in_one_epoch += task.train_set.size // config.tasks[task.name]["train_batch_size"]

        # config.train_batch_size = config.train_batch_size // config.gradient_accumulation_steps
        # config.test_batch_size = config.test_batch_size // config.gradient_accumulation_steps
        config.tasks[task.name]["train_batch_size"] =  config.tasks[task.name]["train_batch_size"] // config.gradient_accumulation_steps
        config.tasks[task.name]["test_batch_size"] = config.tasks[task.name]["test_batch_size"] // config.gradient_accumulation_steps
        # adjust to real training batch size
        utils.log("Training batch size: {}".format(config.tasks[task.name]["train_batch_size"]))

    if config.num_train_optimization_steps == 0:
      config.num_train_optimization_steps = config.num_steps_in_one_epoch * config.epoch_number \
        if config.schedule_lr else -1
    utils.log("Optimization steps : {}".format(config.num_train_optimization_steps))

    # Optimization
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(self.model.named_parameters())
    # for n, p in param_optimizer:
    #   print(n)
    optimizers = {}
    optim_type = ""
    if config.sep_optim:
      utils.log("Optimizing the module using Adam optimizer ..")
      modules_parameters = [p for n, p in param_optimizer if "bert" not in n]
      bert_optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if ("bert" in n) and (not any(nd in n for nd in no_decay))],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if ("bert" in n) and (any(nd in n for nd in no_decay))],
         'weight_decay': 0.0}
      ]
      optimizers["module_optimizer"] = Adam(params=modules_parameters, lr=config.adam_learning_rate)
      optimizers["bert_optimizer"] = BertAdam(bert_optimizer_grouped_parameters,
                                              lr=config.learning_rate,
                                              warmup=config.warmup_proportion,
                                              schedule=config.schedule_method,
                                              t_total=config.num_train_optimization_steps)
      optim_type = "sep_optim"
    elif config.two_stage_optim:
      utils.log("Optimizing the module with two stage")
      modules_parameters = [p for n, p in param_optimizer if "bert" not in n]
      bert_optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
      optimizers["module_optimizer"] = Adam(params=modules_parameters, lr=config.adam_learning_rate)
      optimizers["bert_optimizer"] = BertAdam(bert_optimizer_grouped_parameters,
                                              lr=config.learning_rate,
                                              warmup=config.warmup_proportion,
                                              schedule=config.schedule_method,
                                              t_total=config.num_train_optimization_steps)
      optim_type = "two_stage_optim"
    elif config.fix_bert:
      utils.log("Optimizing the module using Adam optimizer ..")
      modules_parameters = [p for n, p in param_optimizer if "bert" not in n]
      optimizers["module_optimizer"] = SGD(params=modules_parameters, lr=config.adam_learning_rate)
      optim_type = "fix_bert"
    elif config.fix_embedding:
      utils.log("Optimizing the model using one optimizer and fix embedding layer")
      bert_optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and ("word_embeddings" not in n)],
                'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and ("word_embeddings" not in n)],
         'weight_decay': 0.0}]
      optimizers["bert_optimizer"] = BertAdam(bert_optimizer_grouped_parameters,
                                              lr=config.learning_rate,
                                              warmup=config.warmup_proportion,
                                              schedule=config.schedule_method,
                                              t_total=config.num_train_optimization_steps)
      optim_type = "normal"
    elif self.ext_config.encoder_configs.fix_embedding or self.ext_config.encoder_configs.fix_layers:
      utils.log("Fixing layers from config")
      print("Fix embedding {}".format(self.ext_config.encoder_configs.fix_embedding))
      print("Fix layers {}".format(self.ext_config.encoder_configs.fix_layers))
      self.ext_config: Configuration
      if config.encoder_type == 'xlmr':
        prefix = "layers"
        embed_prefix = "embed_tokens"
      elif config.encoder_type == "bert":
        prefix = "layer"
        embed_prefix = "word_embeddings"
      else:
        raise ValueError("Not supported encoder_type {}".format(config.encoder_type))
      bert_optimizer_grouped_parameters = [
        {'params': [],
         'weight_decay': 0.01},
        {'params': [],
        'weight_decay': 0.0}]

      for n, p in param_optimizer:
        # Group index 1 holds the no-decay parameters, index 0 the rest.
        group = 1 if any(nd in n for nd in no_decay) else 0
        if embed_prefix in n:
          # Embedding parameters are kept only if the embedding is not fixed.
          if not self.ext_config.encoder_configs.fix_embedding:
            bert_optimizer_grouped_parameters[group]['params'].append(p)
          else:
            print("Skip {}".format(n))
        elif not any(".{}.{}.".format(prefix, l) in n
                     for l in self.ext_config.encoder_configs.fix_layers):
          # Encoder parameters from layers that are not fixed.
          bert_optimizer_grouped_parameters[group]['params'].append(p)
        else:
          print("Skip {}".format(n))
      optimizers["bert_optimizer"] = BertAdam(bert_optimizer_grouped_parameters,
                                              lr=config.learning_rate,
                                              warmup=config.warmup_proportion,
                                              schedule=config.schedule_method,
                                              t_total=config.num_train_optimization_steps)
      optim_type = "normal"

      optimizers["bert_optimizer"] = BertAdam(bert_optimizer_grouped_parameters,
                                    lr=config.learning_rate,
                                    warmup=config.warmup_proportion,
                                    schedule=config.schedule_method,
                                    t_total=config.num_train_optimization_steps)

      optim_type = "normal"

    else:
      utils.log("Optimizing the model using one optimizer")
      bert_optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
      optimizers["bert_optimizer"] = BertAdam(bert_optimizer_grouped_parameters,
                                lr=config.learning_rate,
                                warmup=config.warmup_proportion,
                                schedule=config.schedule_method,
                                t_total=config.num_train_optimization_steps)
      optim_type = "normal"

    self.optimizer = MultipleOptimizer(optim=optimizers, optim_type=optim_type)

    self.global_step_labeled = 0
    self.global_step_unlabeled = 0
Example #20
    def setup_training(self, config, tasks):
        # Calculate optimization steps
        size_train_examples = 0
        config.num_steps_in_one_epoch = 0
        if config.mode == "train" or config.mode == "finetune":
            for task in tasks:
                utils.log("{} : {}  training examples".format(
                    task.name, task.train_set.size))
                if "loss_weight" in config.tasks[task.name]:
                    utils.log("loss weight {}".format(
                        config.tasks[task.name]["loss_weight"]))
                size_train_examples += task.train_set.size
                config.num_steps_in_one_epoch += task.train_set.size // config.tasks[
                    task.name]["train_batch_size"]

                # config.train_batch_size = config.train_batch_size // config.gradient_accumulation_steps
                # config.test_batch_size = config.test_batch_size // config.gradient_accumulation_steps
                config.tasks[
                    task.name]["train_batch_size"] = config.tasks[task.name][
                        "train_batch_size"] // config.gradient_accumulation_steps
                config.tasks[
                    task.name]["test_batch_size"] = config.tasks[task.name][
                        "test_batch_size"] // config.gradient_accumulation_steps
                # adjust to real training batch size
                utils.log("Training batch size: {}".format(
                    config.tasks[task.name]["train_batch_size"]))

        calculated_num_train_optimization_steps = config.num_steps_in_one_epoch * config.epoch_number \
            if config.schedule_lr else -1
        if config.num_train_optimization_steps == 0:
            config.num_train_optimization_steps = calculated_num_train_optimization_steps
        else:
            utils.log(
                "Overwriting the training steps to {} instead of {} because of the configuration"
                .format(config.num_train_optimization_steps,
                        calculated_num_train_optimization_steps))
        utils.log("Optimization steps : {}".format(
            config.num_train_optimization_steps))

        # Optimization
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        param_optimizer = list(self.model.named_parameters())
        # for n, p in param_optimizer:
        #   print(n)
        optimizers = {}
        optim_type = ""

        if config.only_adam:
            """
      This optimization method can only support finetuning all.
      This is for the adaptation of [CEDR](https://github.com/Georgetown-IR-Lab/cedr).
      Will extend this for other settings, such as layer fixing
      """
            utils.log(
                "Optimizing with Adam with {} and {} learninig rate".format(
                    config.learning_rate, config.adam_learning_rate))
            if config.fix_embedding:
                modules_parameters = {
                    "params": [
                        p for n, p in param_optimizer
                        if "bert" not in n and "word_embedding" not in n
                    ]
                }
            else:
                modules_parameters = {
                    "params":
                    [p for n, p in param_optimizer if "bert" not in n]
                }
            bert_optimizer_grouped_parameters = {
                'params': [p for n, p in param_optimizer if "bert" in n],
                'lr': config.learning_rate
            }
            if len(bert_optimizer_grouped_parameters["params"]) == 0:
                utils.log("There is no BERT module in the model.")
                optimizers["optimizer"] = torch.optim.Adam(
                    [modules_parameters], lr=config.adam_learning_rate)
            else:
                optimizers["optimizer"] = torch.optim.Adam(
                    [modules_parameters, bert_optimizer_grouped_parameters],
                    lr=config.adam_learning_rate)
            optim_type = "only_adam"

        elif config.sep_optim:
            utils.log("Optimizing the module using Adam optimizer ..")
            modules_parameters = [
                p for n, p in param_optimizer if "bert" not in n
            ]
            bert_optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if ("bert" in n) and (not any(nd in n for nd in no_decay))
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if ("bert" in n) and (any(nd in n for nd in no_decay))
                ],
                'weight_decay':
                0.0
            }]
            optimizers["module_optimizer"] = Adam(params=modules_parameters,
                                                  lr=config.adam_learning_rate)
            optimizers["bert_optimizer"] = BertAdam(
                bert_optimizer_grouped_parameters,
                lr=config.learning_rate,
                warmup=config.warmup_proportion,
                schedule=config.schedule_method,
                t_total=config.num_train_optimization_steps)
            optim_type = "sep_optim"
        elif config.two_stage_optim:
            utils.log("Optimizing the module with two stage")
            modules_parameters = [
                p for n, p in param_optimizer if "bert" not in n
            ]
            bert_optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            optimizers["module_optimizer"] = Adam(params=modules_parameters,
                                                  lr=config.adam_learning_rate)
            optimizers["bert_optimizer"] = BertAdam(
                bert_optimizer_grouped_parameters,
                lr=config.learning_rate,
                warmup=config.warmup_proportion,
                schedule=config.schedule_method,
                t_total=config.num_train_optimization_steps)
            optim_type = "two_stage_optim"
        elif config.fix_bert:
            utils.log("Optimizing the module using Adam optimizer ..")
            modules_parameters = [
                p for n, p in param_optimizer if "bert" not in n
            ]
            optimizers["module_optimizer"] = SGD(params=modules_parameters,
                                                 lr=config.adam_learning_rate)
            optim_type = "fix_bert"
        elif config.fix_embedding:
            utils.log(
                "Optimizing the model using one optimizer and fix embedding layer"
            )
            bert_optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer if not any(
                        nd in n
                        for nd in no_decay) and ("word_embeddings" not in n)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n
                           for nd in no_decay) and ("word_embeddings" not in n)
                ],
                'weight_decay':
                0.0
            }]
            optimizers["bert_optimizer"] = BertAdam(
                bert_optimizer_grouped_parameters,
                lr=config.learning_rate,
                warmup=config.warmup_proportion,
                schedule=config.schedule_method,
                t_total=config.num_train_optimization_steps)
            optim_type = "normal"
        elif self.ext_config.encoder_configs.fix_embedding or self.ext_config.encoder_configs.fix_layers:
            utils.log("Fixing layers from config")
            print("Fix embedding {}".format(
                self.ext_config.encoder_configs.fix_embedding))
            print("Fix layers {}".format(
                self.ext_config.encoder_configs.fix_layers))
            self.ext_config: Configuration
            if config.encoder_type == 'xlmr':
                prefix = "layers"
                embed_prefix = "embed_tokens"
            elif config.encoder_type == "bert":
                prefix = "layer"
                embed_prefix = "word_embeddings"
            else:
                raise ValueError("Not supported encoder_type {}".format(
                    config.encoder_type))
            bert_optimizer_grouped_parameters = [{
                'params': [],
                'weight_decay': 0.01
            }, {
                'params': [],
                'weight_decay': 0.0
            }]

            for n, p in param_optimizer:
                # Group index 1 holds the no-decay parameters, index 0 the rest.
                group = 1 if any(nd in n for nd in no_decay) else 0
                if embed_prefix in n:
                    # Embedding parameters are kept only if the embedding is not fixed.
                    if not self.ext_config.encoder_configs.fix_embedding:
                        bert_optimizer_grouped_parameters[group][
                            'params'].append(p)
                    else:
                        print("Skip {}".format(n))
                elif not any(".{}.{}.".format(prefix, l) in n for l in
                             self.ext_config.encoder_configs.fix_layers):
                    # Encoder parameters from layers that are not fixed.
                    bert_optimizer_grouped_parameters[group][
                        'params'].append(p)
                else:
                    print("Skip {}".format(n))
            optimizers["bert_optimizer"] = BertAdam(
                bert_optimizer_grouped_parameters,
                lr=config.learning_rate,
                warmup=config.warmup_proportion,
                schedule=config.schedule_method,
                t_total=config.num_train_optimization_steps)
            optim_type = "normal"

        else:
            utils.log("Optimizing the model using one optimizer")
            bert_optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            optimizers["bert_optimizer"] = BertAdam(
                bert_optimizer_grouped_parameters,
                lr=config.learning_rate,
                warmup=config.warmup_proportion,
                schedule=config.schedule_method,
                t_total=config.num_train_optimization_steps)
            optim_type = "normal"

        self.optimizer = MultipleOptimizer(optim=optimizers,
                                           optim_type=optim_type)

        self.global_step_labeled = 0
        self.global_step_unlabeled = 0
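
The no_decay grouping used throughout both setup_training variants is the standard BERT fine-tuning recipe: weight decay on most weights, none on biases and LayerNorm parameters. A generic sketch with torch.optim.AdamW in place of the BertAdam used above (the toy module and learning rate are illustrative):

import torch
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(16, 16)
        self.LayerNorm = nn.LayerNorm(16)

model = TinyEncoder()
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped, lr=5e-5)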
Example #21
    def _train(self, progress: TrainingProgress):
        heading = lambda s: utils.heading(s, '(' + self.config.model_name + ')')
        trained_on_sentences = 0
        start_time = time.time()
        unsupervised_loss_total, unsupervised_loss_count = 0, 0
        supervised_loss_total, supervised_loss_count = 0, 0
        step = 0
        # self.evaluate_all_tasks(progress.history)
        for mb in self.get_training_mbs():
            if mb.task_name not in DISTILL_TASKS:
                loss = self.model.train_labeled_abstract(mb, step)
                supervised_loss_total += loss
                supervised_loss_count += 1
            else:
                if self.config.use_external_teacher:
                    self.teacher_model.run_teacher_abstract(mb)
                else:
                    self.model.run_teacher_abstract(mb)
                loss = self.model.train_unlabeled_abstract(mb, step)
                unsupervised_loss_total += loss
                unsupervised_loss_count += 1
                mb.teacher_predictions.clear()

            step += 1
            trained_on_sentences += mb.size

            if self.model.global_step_labeled % self.config.print_every == 0 \
                  and not progress.log_in_step(self.model.global_step_labeled):
                # and self.model.global_step_unlabeled % self.config.print_every == 0
                # a quick patch here
                # TODO: organize better
                self.model.optimizer.update_loss(supervised_loss_total /
                                                 max(1, supervised_loss_count))

                utils.log(
                    "step supervised {:} - "
                    "step unsupervised {:} - "
                    "supervised loss: {:.3f} - "
                    "unsupervised loss : {:.3f} - "
                    "{:.1f} sentences per second".format(
                        self.model.global_step_labeled,
                        self.model.global_step_unlabeled,
                        supervised_loss_total / max(1, supervised_loss_count),
                        unsupervised_loss_total /
                        max(1, unsupervised_loss_count),
                        trained_on_sentences / (time.time() - start_time)))
                unsupervised_loss_total, unsupervised_loss_count = 0, 0
                supervised_loss_total, supervised_loss_count = 0, 0
                progress.add_log_step(self.model.global_step_labeled)

            if self.model.global_step_labeled % self.config.eval_dev_every == 0 \
                  and not progress.evaluated_in_step(self.model.global_step_labeled):
                # and self.model.global_step_unlabeled % self.config.eval_dev_every == 0

                heading("EVAL on DEV")
                self.evaluate_all_tasks(progress.history)
                progress.save_if_best_dev_model(self.model.model)
                progress.add_evaluated_step(self.model.global_step_labeled)

            if self.config.early_stop_at > 0 and self.model.global_step_labeled >= self.config.early_stop_at:
                utils.log("Early stop at step {}".format(
                    self.model.global_step_labeled))
                break