Code Example #1
File: multitask_model.py  Project: vrmpx/relogic
    def train_generator(self, labeled_mb: MiniBatch, unlabeled_mb: MiniBatch):
        self.model.train()
        self.adversarial_agent.train()
        labeled_inputs = labeled_mb.generate_input(device=self.device,
                                                   use_label=True)
        unlabeled_inputs = unlabeled_mb.generate_input(device=self.device,
                                                       use_label=False)

        # A quick hack to reuse the encoder output; will be reorganized later
        original_labeled_input_task_name = labeled_inputs["task_name"]
        labeled_inputs["task_name"] = ",".join(
            [labeled_inputs["task_name"], unlabeled_inputs["task_name"]])
        labeled_outputs = self.model(**labeled_inputs)
        labeled_loss = labeled_outputs[original_labeled_input_task_name][
            "loss"]
        # unlabeled_inputs["task_name"] = "encoding"
        unlabeled_outputs = self.model(**unlabeled_inputs)
        discriminator_loss = self.adversarial_agent.gen_loss(
            labeled_outputs[unlabeled_inputs["task_name"]]["logits"],
            unlabeled_outputs[unlabeled_inputs["task_name"]]["logits"], 1.0,
            0.0)
        loss = labeled_loss + discriminator_loss
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(),
                                 self.config.grad_clip)
        # Gradient (loss) accumulation is not currently supported for adversarial training
        self.optimizer.step()
        self.optimizer.zero_grad()
        return labeled_loss.item(), discriminator_loss.item()
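
The generator update above combines the task loss with an adversarial term that rewards the model when labeled and unlabeled encodings look alike to the discriminator. relogic's adversarial_agent.gen_loss is not shown in this snippet; the following is a minimal, self-contained sketch of one plausible reading, assuming the trailing 1.0 and 0.0 arguments are binary targets for a BCE-style loss over discriminator logits (an assumption, not the library's confirmed implementation).

import torch
import torch.nn.functional as F

def gen_loss_sketch(labeled_logits, unlabeled_logits, labeled_target, unlabeled_target):
    # Hypothetical stand-in for adversarial_agent.gen_loss (assumption only):
    # score each tensor of discriminator logits against a constant binary
    # target, so the generator is pushed to make both domains indistinguishable.
    labeled_t = torch.full_like(labeled_logits, labeled_target)
    unlabeled_t = torch.full_like(unlabeled_logits, unlabeled_target)
    return (F.binary_cross_entropy_with_logits(labeled_logits, labeled_t)
            + F.binary_cross_entropy_with_logits(unlabeled_logits, unlabeled_t))

# Toy usage with random "discriminator logits" for a batch of 8 examples.
print(gen_loss_sketch(torch.randn(8, 1), torch.randn(8, 1), 1.0, 0.0))
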
Code Example #2
File: multitask_model.py  Project: ljj7975/relogic
  def train_labeled_abstract(self, mb: MiniBatch, step):
    self.model.train()

    inputs = mb.generate_input(device=self.device, use_label=True)
    if "input_ids" in inputs and inputs["input_ids"].size(0) == 0:
      utils.log("Zero Batch")
      return 0

    outputs = self.model(**inputs)

    # TODO: slow, ongoing interface migration ...
    if isinstance(outputs, dict):
      loss = outputs[mb.task_name]["loss"]
    else:
      if self.config.output_attentions:
        loss, _, _ = outputs
      else:
        loss, _ = outputs

    loss = mb.loss_weight * loss

    if self.config.gradient_accumulation_steps > 1:
      loss = loss / self.config.gradient_accumulation_steps
    loss.backward()
    if (step + 1) % self.config.gradient_accumulation_steps == 0:
      # TODO: a quick fix: skip gradient clipping for the SQuAD tasks
      if not hasattr(mb, "task_name") or mb.task_name not in ["squad11", "squad20"]:
        nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
      self.optimizer.step()
      self.optimizer.zero_grad()
      self.global_step_labeled += 1
    return loss.item()
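
Example #2's update step illustrates gradient accumulation: each micro-batch loss is scaled by 1/gradient_accumulation_steps, backward() is called every step so gradients add up, and clipping/stepping/zeroing happen only on every K-th step. Below is a minimal self-contained sketch of that pattern with a toy model (the model, data, and hyperparameters are placeholders, not relogic's):

import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps, grad_clip = 4, 1.0

for step in range(8):
    x, y = torch.randn(2, 10), torch.randn(2, 1)
    # Scale so the accumulated gradient matches one large-batch update.
    loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if (step + 1) % accumulation_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()
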
Code Example #3
File: multitask_model.py  Project: ljj7975/relogic
  def train_unlabeled_abstract(self, mb: MiniBatch, step):
    self.model.train()
    inputs = mb.generate_input(device=self.device, use_label=False)
    outputs = self.model(**inputs,
                         teacher_predictions=mb.teacher_predictions)

    loss, _ = outputs

    loss = mb.loss_weight * loss

    if self.config.gradient_accumulation_steps > 1:
      loss = loss / self.config.gradient_accumulation_steps
    loss.backward()
    if (step + 1) % self.config.gradient_accumulation_steps == 0:
      # TODO: a quick fix: skip gradient clipping for the SQuAD tasks
      if not hasattr(mb, "task_name") or mb.task_name not in ["squad11", "squad20"]:
        nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
      self.optimizer.step()
      self.optimizer.zero_grad()
      self.global_step_unlabeled += 1
    return loss.item()
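
In example #3 the unlabeled forward pass receives mb.teacher_predictions, and the loss against those predictions is computed inside the model, so it is not visible here. A common objective for this kind of teacher-guided training is temperature-scaled KL divergence between student and teacher logits; the sketch below shows that standard formulation as an illustration only (relogic's actual loss may differ):

import torch
import torch.nn.functional as F

def distillation_loss_sketch(student_logits, teacher_logits, temperature=2.0):
    # KL(teacher || student) on temperature-softened distributions; the T^2
    # factor keeps gradient magnitudes comparable across temperatures.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2

print(distillation_loss_sketch(torch.randn(4, 5), torch.randn(4, 5)))
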
Code Example #4
File: multitask_model.py  Project: ljj7975/relogic
  def run_teacher_abstract(self, mb: MiniBatch):
    self.teacher.eval()
    inputs = mb.generate_input(device=self.device, use_label=False)
    with torch.no_grad():
      results = self.teacher(**inputs)
      mb.teacher_predictions = results
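
Examples #3 and #4 together suggest the teacher-student ordering: run the frozen teacher first to cache predictions on the minibatch, then train the student against them. The essential pattern in run_teacher_abstract is eval() plus torch.no_grad(), shown below in a self-contained sketch (the toy teacher and Batch class are placeholders, not relogic's types):

import torch
from torch import nn

class Batch:
    pass  # stand-in for MiniBatch; only used to carry cached predictions

teacher = nn.Sequential(nn.Linear(10, 10), nn.Dropout(0.1), nn.Linear(10, 5))
mb, inputs = Batch(), torch.randn(8, 10)

teacher.eval()            # disable dropout / freeze batch-norm statistics
with torch.no_grad():     # skip autograd graph construction entirely
    mb.teacher_predictions = teacher(inputs)

print(mb.teacher_predictions.requires_grad)  # False: safe to reuse as constants
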