def _handle_batch_teacher_training(self, batch):
    """Run a plain supervised step on the teacher model alone.

    Stores the teacher logits under ``self.batch["logits"]`` and the
    criterion value under ``self.batch_metrics["loss"]``.

    Args:
        batch: mapping with ``features`` (model input) and ``targets``.
    """
    teacher = get_nn_from_ddp_module(self.model)["teacher"]
    logits = teacher(batch["features"])
    self.batch["logits"] = logits
    self.batch_metrics["loss"] = self.criterion(logits, batch["targets"])
def _handle_batch_distillation(self, batch):
    """Distillation step: forward the (frozen) teacher and the student.

    Populates ``self.batch`` with student/teacher logits (and hidden
    states when requested during training) and records the supervised
    ``task_loss`` on the student logits.
    """
    module = get_nn_from_ddp_module(self.model)
    student = module["student"]
    teacher = module["teacher"]
    if self.is_train_loader:
        # the teacher must stay fixed while distilling into the student
        teacher.eval()
        set_requires_grad(teacher, False)
    forward_kwargs = dict(
        output_hidden_states=self.output_hidden_states,
        return_dict=True,
    )
    t_outputs = teacher(batch["features"], **forward_kwargs)
    s_outputs = student(batch["features"], **forward_kwargs)
    self.batch["s_logits"] = s_outputs["logits"]
    if self.is_train_loader:
        self.batch["t_logits"] = t_outputs["logits"]
    if self.output_hidden_states and self.is_train_loader:
        self.batch["s_hidden_states"] = s_outputs["hidden_states"]
        self.batch["t_hidden_states"] = t_outputs["hidden_states"]
    # NOTE(review): reading ``batch["s_logits"]`` assumes ``batch`` is the
    # same dict object as ``self.batch`` (catalyst passes it through) —
    # confirm against the runner base class.
    self.batch_metrics["task_loss"] = self.criterion(
        batch["s_logits"], batch["targets"]
    )
    # expose student logits under the generic key
    # for accuracy callback or other metric callback
    self.batch["logits"] = self.batch["s_logits"]
def handle_batch(self, batch):
    """Forward the student and the (frozen) teacher on one batch.

    Stores student logits/hidden states in ``self.batch``; teacher
    outputs are stored only during training, with an optional
    probability shift applied to the teacher logits.
    """
    module = get_nn_from_ddp_module(self.model)
    student = module["student"]
    teacher = module["teacher"]
    if self.is_train_loader:
        # keep the teacher fixed while training the student
        teacher.eval()
        set_requires_grad(teacher, False)
    forward_kwargs = dict(
        output_hidden_states=self.output_hidden_states,
        return_dict=True,
    )
    t_outputs = teacher(batch["features"], **forward_kwargs)
    s_outputs = student(batch["features"], **forward_kwargs)
    self.batch["s_logits"] = s_outputs["logits"]
    if self.is_train_loader:
        self.batch["t_logits"] = t_outputs["logits"]
        if self.apply_probability_shift:
            # rearrange teacher probabilities w.r.t. the true labels
            self.batch["t_logits"] = probability_shift(
                logits=self.batch["t_logits"],
                labels=self.batch["targets"],
            )
    if self.output_hidden_states:
        self.batch["s_hidden_states"] = s_outputs["hidden_states"]
        if self.is_train_loader:
            self.batch["t_hidden_states"] = t_outputs["hidden_states"]
def on_epoch_end(self, runner: "IRunner"):
    """Every 10th epoch, synthesize audio from ``self.mel_path`` and save it.

    Writes ``{self.out_name}_{epoch}.wav`` at 22050 Hz and, when wandb is
    importable, also logs the generated audio there.

    Args:
        runner: current IRunner; ``runner.model["generator"]`` is used
            for synthesis.
    """
    import warnings

    if (runner.epoch - 1) % 10 != 0:
        return
    hop_length = 256
    mel = torch.load(self.mel_path)
    # pad input mel with zeros to cut artifact
    # see https://github.com/seungwonpark/melgan/issues/8
    zero = torch.full((1, 80, 10), -11.5129).to(mel.device)
    mel = torch.cat((mel, zero), dim=2)
    generator = get_nn_from_ddp_module(runner.model)["generator"]
    if torch.cuda.is_available():
        # BUG FIX: ``mel.to("cuda")`` discarded its result (``Tensor.to``
        # is out-of-place); ``.type(torch.cuda.FloatTensor)`` both casts
        # and moves the tensor to the GPU.
        mel = mel.type(torch.cuda.FloatTensor)
    audio = generator.forward(mel).detach().cpu()
    audio = audio.squeeze()  # collapse all dimensions except the time axis
    audio = audio[:-(hop_length * 10)]  # drop the tail produced by the padding
    audio = MAX_WAV_VALUE * audio
    audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
    # already on CPU and detached above; convert straight to int16 numpy
    audio = audio.short().numpy()
    try:
        import wandb

        wandb.log(
            {
                f"generated_{runner.epoch}.wav": [
                    wandb.Audio(audio, caption=self.mel_path, sample_rate=22050)
                ]
            },
            step=runner.epoch,
        )
    except ImportError:
        # BUG FIX: ``Warning("...")`` only constructed an exception object
        # and never emitted anything; actually warn the user.
        warnings.warn("can't import wandb")
    out_path = self.out_name + f"_{runner.epoch}.wav"
    write(out_path, 22050, audio)
def _handle_batch(self, batch: Mapping[str, Any]) -> None:
    """Forward student and frozen teacher, collecting logits (and hiddens).

    Hidden states are requested only on the train loader when
    ``self.output_hiddens`` is set; results go into ``self.output``.
    """
    self.output = OrderedDict()
    need_hiddens = self.is_train_loader and self.output_hiddens
    student = get_nn_from_ddp_module(self.model["student"])
    teacher = get_nn_from_ddp_module(self.model["teacher"])
    # the teacher is never trained here
    teacher.eval()
    set_requires_grad(teacher, False)
    s_outputs = student(batch["features"], output_hiddens=need_hiddens)
    t_outputs = teacher(batch["features"], output_hiddens=need_hiddens)
    if need_hiddens:
        # models return (logits, hiddens) tuples in this mode
        self.output["logits"], self.output["hiddens"] = s_outputs
        self.output["teacher_logits"], self.output["teacher_hiddens"] = t_outputs
    else:
        self.output["logits"] = s_outputs
        self.output["teacher_logits"] = t_outputs
def get_model(self, stage: str):
    """Return the experiment model for ``stage``.

    Reuses ``self.model`` when present (unwrapping DDP), otherwise builds
    a fresh ``DummyModelFinetune``; on the freeze stage only ``layer1``
    is frozen, otherwise all parameters are trainable.
    """
    if self.model is None:
        model = DummyModelFinetune(4, 3, 2)
    else:
        model = utils.get_nn_from_ddp_module(self.model)
    if stage == "train_freezed":
        # freeze layer
        utils.set_requires_grad(model.layer1, False)
    else:
        utils.set_requires_grad(model, True)
    return model
def get_model(self, stage: str):
    """Return the experiment model for ``stage``.

    Reuses ``self.model`` when present (unwrapping DDP), otherwise builds
    a small MLP for 784-dim inputs; on the freeze stage the first linear
    layer is frozen, otherwise all parameters are trainable.
    """
    if self.model is None:
        model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
    else:
        model = utils.get_nn_from_ddp_module(self.model)
    if stage == "train_freezed":
        # freeze layer
        utils.set_requires_grad(model[1], False)
    else:
        utils.set_requires_grad(model, True)
    return model
def handle_batch(self, batch):
    """Distillation step for models with HF-style ``**batch`` signatures.

    Both models receive the raw batch as keyword arguments and return
    dict outputs with hidden states; teacher outputs are recorded only
    during training. The student's built-in loss becomes ``task_loss``.
    """
    module = get_nn_from_ddp_module(self.model)
    student = module["student"]
    teacher = module["teacher"]
    if self.is_train_loader:
        # the teacher stays frozen while training the student
        teacher.eval()
        set_requires_grad(teacher, False)
    t_outputs = teacher(**batch, output_hidden_states=True, return_dict=True)
    s_outputs = student(**batch, output_hidden_states=True, return_dict=True)
    if self.is_train_loader:
        self.batch["t_logits"] = t_outputs["logits"]
        self.batch["t_hidden_states"] = t_outputs["hidden_states"]
    # the model computes its own supervised loss from the batch labels
    self.batch_metrics["task_loss"] = s_outputs["loss"]
    self.batch["s_logits"] = s_outputs["logits"]
    self.batch["s_hidden_states"] = s_outputs["hidden_states"]
def _handle_batch(self, batch: Mapping[str, Any]) -> None:
    """One GAN step: score fake/real audio for generator and discriminator.

    The generator branch keeps gradients flowing through the fake audio;
    the discriminator branch detaches it so only the discriminator learns
    from that pass. Scores are collected in ``self.output``.
    """
    module = utils.get_nn_from_ddp_module(self.model)
    generator = module["generator"]
    discriminator = module["discriminator"]
    segment_length = self.loaders["train"].dataset.segment_length
    self.output = {"generator": {}, "discriminator": {}}

    # --- generator branch: fake audio NOT detached ---
    fake_audio = generator(batch["generator_mel"])[:, :, :segment_length]
    self.output["generator"]["fake"] = discriminator(fake_audio)  # probably slice here
    self.output["generator"]["real"] = discriminator(batch["generator_audio"])

    # --- discriminator branch: fake audio detached from the generator ---
    fake_audio = generator(batch["discriminator_mel"])[:, :, :segment_length]
    fake_audio = fake_audio.detach()
    self.output["discriminator"]["fake"] = discriminator(fake_audio)  # probably slice here
    self.output["discriminator"]["real"] = discriminator(batch["discriminator_audio"])
def trace_model_from_runner(
    runner: IRunner,
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> ScriptModule:
    """
    Traces model using created experiment and runner.

    The runner's model is temporarily moved to ``device`` (and optionally
    loaded from a named checkpoint); after tracing, its weights, device,
    train/eval mode, and requires_grad flags are all restored.

    Args:
        runner (Runner): Current runner.
        checkpoint_name (str): Name of model checkpoint to use, if None
            traces current model from runner
        method_name (str): Model's method name that will be used
            as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        (ScriptModule): Traced model
    """
    # NOTE(review): ``logdir / "checkpoints"`` assumes logdir is a
    # pathlib.Path — confirm against the runner implementation.
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        # snapshot current weights first so they can be restored after
        # tracing the requested checkpoint
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # getting input names of args for method since we don't have Runner
    # and we don't know input_key to preprocess batch for method call
    fn = getattr(model, method_name)
    method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

    # build the tracing batch from the runner's last seen input
    batch = {}
    for name in method_argnames:
        # TODO: We don't know input_keys without runner
        assert name in runner.input, (
            "Input batch should contain the same keys as input argument "
            "names of `forward` function to be traced correctly"
        )
        batch[name] = runner.input[name]

    batch = any2device(batch, device)

    # Dumping previous runner of the model, we will need it to restore
    _device, _is_training, _requires_grad = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    model.to(device)

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):
        return model(**inputs, **kwargs)

    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if checkpoint_name is not None:
        # put the snapshotted weights back
        unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

    # Restore previous runner of the model
    getattr(model, "train" if _is_training else "eval")()
    set_requires_grad(model, _requires_grad)
    model.to(_device)

    return traced_model