def train_generator(self, labeled_mb: MiniBatch, unlabeled_mb: MiniBatch):
    self.model.train()
    self.adversarial_agent.train()
    labeled_inputs = labeled_mb.generate_input(device=self.device, use_label=True)
    unlabeled_inputs = unlabeled_mb.generate_input(device=self.device, use_label=False)
    # A quick hack to reuse the encoder output; will reorganize this later.
    original_labeled_input_task_name = labeled_inputs["task_name"]
    labeled_inputs["task_name"] = ",".join(
        [labeled_inputs["task_name"], unlabeled_inputs["task_name"]])
    labeled_outputs = self.model(**labeled_inputs)
    labeled_loss = labeled_outputs[original_labeled_input_task_name]["loss"]
    # unlabeled_inputs["task_name"] = "encoding"
    unlabeled_outputs = self.model(**unlabeled_inputs)
    discriminator_loss = self.adversarial_agent.gen_loss(
        labeled_outputs[unlabeled_inputs["task_name"]]["logits"],
        unlabeled_outputs[unlabeled_inputs["task_name"]]["logits"],
        1.0, 0.0)
    loss = labeled_loss + discriminator_loss
    loss.backward()
    nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
    # Adversarial training does not yet support gradient accumulation.
    self.optimizer.step()
    self.optimizer.zero_grad()
    return labeled_loss.item(), discriminator_loss.item()
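
# A minimal sketch of what `adversarial_agent.gen_loss` might compute, assuming
# the two positional floats above (1.0, 0.0) are soft targets for the labeled
# and unlabeled discriminator logits. This helper is hypothetical; only the
# call signature in `train_generator` is taken from this file.
def gen_loss_sketch(labeled_logits, unlabeled_logits, labeled_target, unlabeled_target):
    # Generator-side objective: push the discriminator toward `labeled_target`
    # on labeled encodings and `unlabeled_target` on unlabeled ones.
    bce = nn.functional.binary_cross_entropy_with_logits
    return (bce(labeled_logits, torch.full_like(labeled_logits, labeled_target))
            + bce(unlabeled_logits, torch.full_like(unlabeled_logits, unlabeled_target)))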
def train_labeled_abstract(self, mb: MiniBatch, step):
    self.model.train()
    inputs = mb.generate_input(device=self.device, use_label=True)
    if "input_ids" in inputs and inputs["input_ids"].size(0) == 0:
        utils.log("Zero Batch")
        return 0
    outputs = self.model(**inputs)
    # TODO: slow migration of the model interface to dict-style outputs.
    if isinstance(outputs, dict):
        loss = outputs[mb.task_name]["loss"]
    else:
        if self.config.output_attentions:
            loss, _, _ = outputs
        else:
            loss, _ = outputs
    loss = mb.loss_weight * loss
    if self.config.gradient_accumulation_steps > 1:
        loss = loss / self.config.gradient_accumulation_steps
    loss.backward()
    if (step + 1) % self.config.gradient_accumulation_steps == 0:
        # TODO: a quick fix; skip gradient clipping for the SQuAD tasks.
        if not hasattr(mb, "task_name") or mb.task_name not in ["squad11", "squad20"]:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.global_step_labeled += 1
    return loss.item()
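
# A minimal usage sketch of the gradient-accumulation contract above, assuming
# a trainer instance `trainer` and an iterable `labeled_batches` (both
# hypothetical). Because the optimizer only steps when
# (step + 1) % gradient_accumulation_steps == 0, the caller must pass a
# monotonically increasing step counter:
#
#     for step, mb in enumerate(labeled_batches):
#         loss = trainer.train_labeled_abstract(mb, step)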
def train_unlabeled_abstract(self, mb: MiniBatch, step):
    self.model.train()
    inputs = mb.generate_input(device=self.device, use_label=False)
    outputs = self.model(**inputs, teacher_predictions=mb.teacher_predictions)
    loss, _ = outputs
    loss = mb.loss_weight * loss
    if self.config.gradient_accumulation_steps > 1:
        loss = loss / self.config.gradient_accumulation_steps
    loss.backward()
    if (step + 1) % self.config.gradient_accumulation_steps == 0:
        # TODO: a quick fix; skip gradient clipping for the SQuAD tasks.
        if not hasattr(mb, "task_name") or mb.task_name not in ["squad11", "squad20"]:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.global_step_unlabeled += 1
    return loss.item()
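
# A minimal sketch of the loss the model might compute when called with
# `teacher_predictions`, assuming the standard soft-target distillation
# objective between student and teacher logits. The function name, the
# temperature parameter, and the exact reduction are hypothetical; only the
# fact that the unlabeled forward pass consumes teacher predictions and
# returns a loss is taken from this file.
def distillation_loss_sketch(student_logits, teacher_logits, temperature=1.0):
    # KL(teacher || student) on temperature-softened distributions, scaled by
    # T^2 to keep gradient magnitudes comparable across temperatures.
    student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = nn.functional.softmax(teacher_logits / temperature, dim=-1)
    return nn.functional.kl_div(
        student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2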
def run_teacher_abstract(self, mb: MiniBatch):
    self.teacher.eval()
    inputs = mb.generate_input(device=self.device, use_label=False)
    with torch.no_grad():
        results = self.teacher(**inputs)
    mb.teacher_predictions = results
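
# A minimal usage sketch of the teacher-student flow for one unlabeled batch,
# assuming a trainer instance `trainer` and an iterable `unlabeled_batches`
# (both hypothetical): the teacher is run first to attach
# `teacher_predictions` to the mini-batch, then the student is trained
# against them.
#
#     for step, mb in enumerate(unlabeled_batches):
#         trainer.run_teacher_abstract(mb)  # fills mb.teacher_predictions
#         loss = trainer.train_unlabeled_abstract(mb, step)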