Code example #1
 def _handle_batch_teacher_training(self, batch):
     model = get_nn_from_ddp_module(self.model)
     teacher = model["teacher"]
     t_logits = teacher(batch["features"])
     loss = self.criterion(t_logits, batch["targets"])
     self.batch["logits"] = t_logits
     self.batch_metrics["loss"] = loss
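All of these examples begin by unwrapping the runner's model with get_nn_from_ddp_module. A conceptual sketch of what that helper does (an approximation of Catalyst's utility, not its exact source):

import torch.nn as nn

def get_nn_from_ddp_module_sketch(model):
    # unwrap the .module attribute that (Distributed)DataParallel wraps
    # around the original network; pass plain modules through unchanged
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model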
Code example #2
 def _handle_batch_distillation(self, batch):
     model = get_nn_from_ddp_module(self.model)
     student, teacher = model["student"], model["teacher"]
     if self.is_train_loader:
         teacher.eval()
         set_requires_grad(teacher, False)
         t_outputs = teacher(
             batch["features"],
             output_hidden_states=self.output_hidden_states,
             return_dict=True,
         )
     s_outputs = student(
         batch["features"],
         output_hidden_states=self.output_hidden_states,
         return_dict=True,
     )
     self.batch["s_logits"] = s_outputs["logits"]
     if self.is_train_loader:
         self.batch["t_logits"] = t_outputs["logits"]
     if self.output_hidden_states and self.is_train_loader:
         self.batch["s_hidden_states"] = s_outputs["hidden_states"]
         self.batch["t_hidden_states"] = t_outputs["hidden_states"]
     self.batch_metrics["task_loss"] = self.criterion(
         batch["s_logits"], batch["targets"])
     self.batch["logits"] = self.batch[
         "s_logits"]  # for accuracy callback or other metric callback
Code example #3
 def handle_batch(self, batch):
     model = get_nn_from_ddp_module(self.model)
     student, teacher = model["student"], model["teacher"]
     if self.is_train_loader:
         teacher.eval()
         set_requires_grad(teacher, False)
         t_outputs = teacher(
             batch["features"],
             output_hidden_states=self.output_hidden_states,
             return_dict=True,
         )
     s_outputs = student(
         batch["features"],
         output_hidden_states=self.output_hidden_states,
         return_dict=True,
     )
     self.batch["s_logits"] = s_outputs["logits"]
     if self.is_train_loader:
         self.batch["t_logits"] = t_outputs["logits"]
         if self.apply_probability_shift:
             self.batch["t_logits"] = probability_shift(
                 logits=self.batch["t_logits"], labels=self.batch["targets"]
             )
     if self.output_hidden_states:
         self.batch["s_hidden_states"] = s_outputs["hidden_states"]
         if self.is_train_loader:
             self.batch["t_hidden_states"] = t_outputs["hidden_states"]
Code example #4
 def on_epoch_end(self, runner: "IRunner"):
     if (runner.epoch - 1) % 10 == 0:
         mel = torch.load(self.mel_path)
         hop_length = 256
         # pad input mel with zeros to cut artifact
         # see https://github.com/seungwonpark/melgan/issues/8
         zero = torch.full((1, 80, 10), -11.5129).to(mel.device)
         mel = torch.cat((mel, zero), dim=2)
         generator = get_nn_from_ddp_module(runner.model)["generator"]
         if torch.cuda.is_available():
             mel.to("cuda")
             mel = mel.type(torch.cuda.FloatTensor)
         audio = generator.forward(mel).detach().cpu()
         audio = audio.squeeze()  # collapse all size-1 dimensions except the time axis
         audio = audio[:-(hop_length * 10)]
         audio = MAX_WAV_VALUE * audio
         audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
         audio = audio.short()
         audio = audio.cpu().detach().numpy()
         try:
             import wandb
             wandb.log(
                 {
                     f"generated_{runner.epoch}.wav": [
                         wandb.Audio(audio,
                                     caption=self.mel_path,
                                     sample_rate=22050)
                     ]
                 },
                 step=runner.epoch)
         except ImportError:
             import warnings
             warnings.warn("can't import wandb")
         out_path = self.out_name + f"_{runner.epoch}.wav"
         write(out_path, 22050, audio)  # scipy.io.wavfile.write(path, rate, data)
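The padding constant -11.5129 in example #4 appears to be log(1e-5), the log-mel floor this MelGAN setup treats as silence, so the ten appended frames decode to near-silence that is trimmed off afterwards:

import math

# -11.5129 ~= log(1e-5), the log-mel amplitude floor used as "silence"
print(math.log(1e-5))  # -11.512925464970229
# 10 padded mel frames correspond to hop_length * 10 = 2560 audio samples,
# which is exactly what audio[:-(hop_length * 10)] trims from the end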
Code example #5
 def _handle_batch(self, batch: Mapping[str, Any]) -> None:
     self.output = OrderedDict()
     need_hiddens = self.is_train_loader and self.output_hiddens
     student = get_nn_from_ddp_module(self.model["student"])
     teacher = get_nn_from_ddp_module(self.model["teacher"])
     teacher.eval()
     set_requires_grad(teacher, False)
     s_outputs = student(batch["features"], output_hiddens=need_hiddens)
     t_outputs = teacher(batch["features"], output_hiddens=need_hiddens)
     if need_hiddens:
         self.output["logits"] = s_outputs[0]
         self.output["hiddens"] = s_outputs[1]
         self.output["teacher_logits"] = t_outputs[0]
         self.output["teacher_hiddens"] = t_outputs[1]
     else:
         self.output["logits"] = s_outputs
         self.output["teacher_logits"] = t_outputs
Code example #6
 def get_model(self, stage: str):
     if self.model is not None:
         model = utils.get_nn_from_ddp_module(self.model)
     else:
         model = DummyModelFinetune(4, 3, 2)
     if stage == "train_freezed":
         # freeze layer1 while in the "train_freezed" stage
         utils.set_requires_grad(model.layer1, False)
     else:
         utils.set_requires_grad(model, True)
     return model
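DummyModelFinetune is project-specific and not shown. A hypothetical reconstruction consistent with how it is used here (constructed from three integers, with a layer1 attribute that can be frozen):

import torch.nn as nn

class DummyModelFinetune(nn.Module):
    # hypothetical: any module exposing a freezable `layer1` fits the usage above
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.layer1 = nn.Linear(in_features, hidden_features)
        self.layer2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.layer2(self.layer1(x))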
Code example #7
File: test_finetune2.py Project: vkurenkov/catalyst
 def get_model(self, stage: str):
     model = (
         utils.get_nn_from_ddp_module(self.model)
         if self.model is not None
         else nn.Sequential(
             nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
         )
     )
     if stage == "train_freezed":
         # freeze the first linear layer (index 1 in the Sequential)
         utils.set_requires_grad(model[1], False)
     else:
         utils.set_requires_grad(model, True)
     return model
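model[1] indexes into the nn.Sequential: position 0 is the Flatten, position 1 the first Linear, so freezing it leaves the head trainable. A quick check:

import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
print(model[1])  # Linear(in_features=784, out_features=128, bias=True)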
Code example #8
File: hf_runner.py Project: elephantmipt/compressors
    def handle_batch(self, batch):
        model = get_nn_from_ddp_module(self.model)
        student, teacher = model["student"], model["teacher"]
        if self.is_train_loader:
            teacher.eval()
            set_requires_grad(teacher, False)
            t_outputs = teacher(**batch, output_hidden_states=True, return_dict=True)

        s_outputs = student(**batch, output_hidden_states=True, return_dict=True)
        if self.is_train_loader:
            self.batch["t_logits"] = t_outputs["logits"]
            self.batch["t_hidden_states"] = t_outputs["hidden_states"]
        self.batch_metrics["task_loss"] = s_outputs["loss"]
        self.batch["s_logits"] = s_outputs["logits"]
        self.batch["s_hidden_states"] = s_outputs["hidden_states"]
Code example #9
File: runner.py Project: elephantmipt/MelGAN
 def _handle_batch(self, batch: Mapping[str, Any]) -> None:
     model = utils.get_nn_from_ddp_module(self.model)
     generator = model["generator"]
     discriminator = model["discriminator"]
     segment_length = self.loaders["train"].dataset.segment_length
     generated_audio = generator(
         batch["generator_mel"])[:, :, :segment_length]
     disc_fake = discriminator(generated_audio)  # probably slice here
     disc_real = discriminator(batch["generator_audio"])
     self.output = {"generator": {}, "discriminator": {}}
     self.output["generator"]["fake"] = disc_fake
     self.output["generator"]["real"] = disc_real
     generated_audio = generator(
         batch["discriminator_mel"])[:, :, :segment_length]
     generated_audio = generated_audio.detach()
     disc_fake = discriminator(generated_audio)  # probably slice here
     disc_real = discriminator(batch["discriminator_audio"])
     self.output["discriminator"]["fake"] = disc_fake
     self.output["discriminator"]["real"] = disc_real
Code example #10
File: trace.py Project: saswat0/catalyst
def trace_model_from_runner(
    runner: IRunner,
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> ScriptModule:
    """
    Traces model using created experiment and runner.

    Args:
        runner (Runner): Current runner.
        checkpoint_name (str): Name of model checkpoint to use, if None
            traces current model from runner
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        (ScriptModule): Traced model
    """
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # getting input names of args for method since we don't have Runner
    # and we don't know input_key to preprocess batch for method call
    fn = getattr(model, method_name)
    method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

    batch = {}
    for name in method_argnames:
        # TODO: We don't know input_keys without runner
        assert name in runner.input, (
            "Input batch should contain the same keys as input argument "
            "names of `forward` function to be traced correctly")
        batch[name] = runner.input[name]

    batch = any2device(batch, device)

    # Dumping previous state of the model, we will need it to restore
    _device, _is_training, _requires_grad = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    model.to(device)

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):
        return model(**inputs, **kwargs)

    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if checkpoint_name is not None:
        unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

    # Restore previous state of the model
    getattr(model, "train" if _is_training else "eval")()
    set_requires_grad(model, _requires_grad)
    model.to(_device)

    return traced_model
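A hedged usage sketch for trace_model_from_runner: it assumes the runner has already processed at least one batch (so runner.input is populated) and that a logdir with checkpoints exists when checkpoint_name is given:

import torch

traced = trace_model_from_runner(
    runner,
    checkpoint_name="best",  # loads logdir/checkpoints/best.pth before tracing
    mode="eval",
    device="cpu",
)
torch.jit.save(traced, "traced_model.pt")  # standard way to persist a ScriptModule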