def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        x, y = self.to_device(batch, pl_module.device)

        with torch.no_grad():
            representations = self.get_representations(pl_module, x)

        representations = representations.detach()

        # forward pass
        mlp_preds = pl_module.non_linear_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_preds, y)

        # log metrics
        val_acc = accuracy(mlp_preds, y)
        pl_module.log('online_val_acc',
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log('online_val_loss',
                      mlp_loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule,
                           outputs: Sequence, batch: Sequence, batch_idx: int,
                           dataloader_idx: int) -> None:
        x, y = self.to_device(batch, pl_module.device)

        with torch.no_grad():
            representations = self.get_representations(pl_module, x)

        representations = representations.detach()

        # forward pass
        mlp_preds = pl_module.non_linear_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_preds, y)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # log metrics
        train_acc = accuracy(mlp_preds, y)
        pl_module.log('online_train_acc',
                      train_acc,
                      on_step=True,
                      on_epoch=False)
        pl_module.log('online_train_loss',
                      mlp_loss,
                      on_step=True,
                      on_epoch=False)
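In this pattern, torch.no_grad() plus detach() keep gradients out of the backbone, so the manual backward()/step() above trains only the MLP head. The head and its optimizer (self.optimizer) must exist before these hooks fire; a minimal sketch, assuming a bolts-style setup in the older on_pretrain_routine_start hook (layer sizes, class count, and hyperparameters are illustrative assumptions):

import torch
from pytorch_lightning import Callback

class OnlineEvaluatorSetup(Callback):
    # sketch: create the evaluator head and its optimizer before training starts
    def on_pretrain_routine_start(self, trainer, pl_module):
        pl_module.non_linear_evaluator = torch.nn.Sequential(
            torch.nn.Linear(512, 256),  # representation width is an assumption
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10),   # 10 classes assumed
        ).to(pl_module.device)
        self.optimizer = torch.optim.SGD(
            pl_module.non_linear_evaluator.parameters(), lr=1e-3)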
Example #3
    def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
        if self.n_test_episodes <= 0:
            return

        self.env_loop.seed(self.seed)
        was_in_training_mode = pl_module.training
        if self.to_eval:
            pl_module.eval()

        returns: List[float] = []
        lengths: List[float] = []

        while len(returns) < self.n_test_episodes:
            self.env_loop.reset()
            _lengths, _returns = self._eval_env_run()
            returns = returns + _returns
            lengths = lengths + _lengths

        returns_arr = np.array(returns)
        lengths_arr = np.array(lengths)

        if self.to_eval and was_in_training_mode:
            pl_module.train()

        for k, mapper in self.return_mappers.items():
            v: Any = mapper(returns_arr)
            pl_module.log(self.logging_prefix + "/test/" + k,
                          v,
                          prog_bar=False)

        for k, mapper in self.length_mappers.items():
            v: Any = mapper(lengths_arr)  # type: ignore
            pl_module.log(self.logging_prefix + "/test/" + k,
                          v,
                          prog_bar=False)
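return_mappers and length_mappers map metric names to aggregations over the collected episode arrays; a plausible configuration (the key names and aggregator choices are assumptions):

import numpy as np

# hypothetical mapper dicts: metric name -> aggregation over the episode arrays
return_mappers = {"return_mean": np.mean, "return_max": np.max}
length_mappers = {"length_mean": np.mean, "length_median": np.median}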
Example #4
    def common_step(self, pred, target, prefix="train/", log=False):
        opt_key = prefix + self.hparams.optimization_metric
        loss_key = prefix + "loss"
        metrics = prefix_keys(prefix, self.collect_metrics(pred, target))

        if log:
            LightningModule.log(
                self,
                name=f"step_{loss_key}",
                value=metrics[loss_key],
                prog_bar=True,
                logger=True,
                sync_dist=self._sync_dist,
            )
            if opt_key != loss_key and opt_key in metrics:
                LightningModule.log(
                    self,
                    name=f"step_{opt_key}",
                    value=metrics[opt_key],
                    prog_bar=True,
                    logger=True,
                    sync_dist=self._sync_dist,
                )
        return {
            "loss": metrics[loss_key],
            **detach_to_cpu(metrics),
            "pred": detach_to_cpu(pred),
            "target": detach_to_cpu(target),
        }
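prefix_keys and detach_to_cpu are project helpers not shown here; minimal sketches consistent with how common_step uses them (assumptions, not the project's actual implementations):

import torch

def prefix_keys(prefix: str, d: dict) -> dict:
    # prepend a prefix to every metric key, e.g. "loss" -> "train/loss"
    return {prefix + k: v for k, v in d.items()}

def detach_to_cpu(x):
    # recursively detach tensors and move them to CPU; pass other values through
    if isinstance(x, torch.Tensor):
        return x.detach().cpu()
    if isinstance(x, dict):
        return {k: detach_to_cpu(v) for k, v in x.items()}
    return x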
Example #5
    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        acc = accuracy(F.softmax(preds, dim=1), y)
        pl_module.log('online_val_acc',
                      acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log('online_val_loss',
                      loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
Example #6
    def make_logging(self,
                     pl_module: pl.LightningModule,
                     func_name,
                     func_idx,
                     on_steps=[],
                     on_epochs=[],
                     prob_bars=[]):
        self.funcs_called_count[func_name] += 1
        iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars]))
        for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate):
            # run logging
            custom_func_name = f"{func_idx}_{idx}_{func_name}"
            pl_module.log(custom_func_name,
                          self.count * func_idx,
                          on_step=on_step,
                          on_epoch=on_epoch,
                          prog_bar=prog_bar)

            # capture information for verification

            # on_train_start runs outside the main loop and is called only
            # once, so record its value immediately
            if func_name == "on_train_start":
                self.callback_funcs_called[func_name].append(
                    [self.count * func_idx])

            # save only values from the second epoch, so we can later compute
            # their mean or take the latest value
            if pl_module.trainer.current_epoch == 1:
                self.callback_funcs_called[func_name].append(
                    [self.count * func_idx])

            forked = on_step and on_epoch

            self.funcs_attr[custom_func_name] = {
                "on_step": on_step,
                "on_epoch": on_epoch,
                "prog_bar": prog_bar,
                "forked": forked,
                "func_name": func_name
            }

            if on_step and on_epoch:
                self.funcs_attr[f"{custom_func_name}_step"] = {
                    "on_step": True,
                    "on_epoch": False,
                    "prog_bar": prog_bar,
                    "forked": False,
                    "func_name": func_name
                }

                self.funcs_attr[f"{custom_func_name}_epoch"] = {
                    "on_step": False,
                    "on_epoch": True,
                    "prog_bar": prog_bar,
                    "forked": False,
                    "func_name": func_name
                }
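For reference, the forked bookkeeping mirrors Lightning's own behavior: logging a metric with both on_step=True and on_epoch=True splits it into <name>_step and <name>_epoch entries, which is why the two extra funcs_attr records are registered above.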
Example #7
    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule) -> None:
        assert not trainer.model.training

        # skip the sanity check, since train_dataloader is not initialized during it
        if trainer.train_dataloader is None:
            return

        total_top1, total_num, feature_bank, target_bank = 0.0, 0, [], []

        # go through train data to generate feature bank
        for batch in trainer.train_dataloader:
            x, target = self.to_device(batch, pl_module.device)
            feature = pl_module(x).flatten(start_dim=1)
            feature = F.normalize(feature, dim=1)

            feature_bank.append(feature)
            target_bank.append(target)

        # [N, D]
        feature_bank = torch.cat(feature_bank, dim=0)
        # [N]
        target_bank = torch.cat(target_bank, dim=0)

        # the attribute name changed across PyTorch Lightning versions
        accel = (trainer.accelerator_connector
                 if hasattr(trainer, "accelerator_connector")
                 else trainer._accelerator_connector)
        # gather representations from other gpus
        if accel.is_distributed:
            feature_bank = concat_all_gather(feature_bank, trainer.accelerator)
            target_bank = concat_all_gather(target_bank, trainer.accelerator)

        # go through val data to predict the label by weighted knn search
        for val_dataloader in trainer.val_dataloaders:
            for batch in val_dataloader:
                x, target = self.to_device(batch, pl_module.device)
                feature = pl_module(x).flatten(start_dim=1)
                feature = F.normalize(feature, dim=1)

                pred_labels = self.predict(feature, feature_bank, target_bank)

                total_num += x.shape[0]
                total_top1 += (pred_labels[:, 0] == target).float().sum().item()

        pl_module.log("online_knn_val_acc",
                      total_top1 / total_num,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
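The predict helper is not shown above; a minimal weighted-kNN sketch that matches the [N, D] feature bank and the pred_labels[:, 0] usage (knn_k, knn_t, and num_classes are assumed hyperparameters):

import torch

def knn_predict(feature, feature_bank, target_bank,
                knn_k=200, knn_t=0.1, num_classes=10):
    # cosine similarity of each query against the bank: [B, N]
    sim = torch.mm(feature, feature_bank.t())
    sim_weight, sim_indices = sim.topk(k=knn_k, dim=-1)  # [B, K]
    sim_labels = torch.gather(target_bank.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()  # temperature-scaled vote weights
    one_hot = torch.zeros(feature.size(0) * knn_k, num_classes,
                          device=sim_labels.device)
    one_hot = one_hot.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted vote per class, classes ranked best-first: [B, num_classes]
    pred_scores = torch.sum(
        one_hot.view(feature.size(0), -1, num_classes) * sim_weight.unsqueeze(-1),
        dim=1)
    return pred_scores.argsort(dim=-1, descending=True)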
Example #8
 def __call__(
     self,
     name: str,
     translations: List[Tuple[str, List[str], Optional[str]]],
     module: pl.LightningModule,
 ):
     assert all(t[2] is not None for t in translations)
     results = self.rouge.compute(
         predictions=[t[1][0] for t in translations],
         references=[t[2] for t in translations])
     for k, v in results.items():
         module.log(f"val_{name}_{k}",
                    v.mid.fmeasure,
                    prog_bar=True,
                    on_step=False,
                    on_epoch=True)
Example #9
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        x, y = self.to_device(batch, pl_module.device)

        with torch.no_grad():
            representations = self.get_representations(pl_module, x)

        representations = representations.detach()

        # forward pass
        mlp_preds = pl_module.non_linear_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_preds, y)

        # log metrics
        val_acc = accuracy(mlp_preds, y)
        pl_module.log('online_val_acc',
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log('online_val_loss',
                      mlp_loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        self.confusion_matrix(mlp_preds, y)

        if self.time_to_sample:
            N, C, H, W = batch[0][2].shape
            num = min(N, 16)
            self.images = batch[0][2][0:num]
            self.mlp_preds = torch.argmax(mlp_preds[0:num], dim=1)
            self.labels = y[0:num]
            self.time_to_sample = False
Example #10
 def on_validation_batch_end(
     self,
     trainer: Trainer,
     pl_module: LightningModule,
     outputs: Sequence,
     batch: Sequence,
     batch_idx: int,
     dataloader_idx: int,
 ) -> None:
     val_acc, mlp_loss = self.shared_step(pl_module, batch)
     pl_module.log("online_val_acc",
                   val_acc,
                   on_step=False,
                   on_epoch=True,
                   sync_dist=True)
     pl_module.log("online_val_loss",
                   mlp_loss,
                   on_step=False,
                   on_epoch=True,
                   sync_dist=True)
Example #11
 def on_validation_batch_end(self, trainer: pl.Trainer,
                             pl_module: pl.LightningModule, outputs: Any,
                             batch: BatchType, batch_idx: int,
                             dataloader_idx: int) -> None:  # type: ignore
     """
     Get and log validation metrics.
     """
     ids_linear_head = tuple(
         batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
     if ids_linear_head not in self.visited_ids:
         self.visited_ids.add(ids_linear_head)
         loss = self.shared_step(batch, pl_module, is_training=False)
         pl_module.log('ssl_online_evaluator/val/loss',
                       loss,
                       on_step=False,
                       on_epoch=True,
                       sync_dist=False)
         for metric in self.val_metrics:
             pl_module.log(f"ssl_online_evaluator/val/{metric.name}",
                           metric,
                           on_epoch=True,
                           on_step=False)  # type: ignore
Example #12
    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule) -> None:
        pl_module.knn_evaluator = KNeighborsClassifier(
            n_neighbors=self.num_classes)

        train_dataloader = pl_module.train_dataloader()
        representations, y = self.get_all_representations(
            pl_module, train_dataloader)

        # knn fit
        pl_module.knn_evaluator.fit(representations,
                                    y)  # type: ignore[union-attr,operator]
        train_acc = pl_module.knn_evaluator.score(
            representations, y)  # type: ignore[union-attr,operator]

        # log metrics

        val_dataloader = pl_module.val_dataloader()
        representations, y = self.get_all_representations(
            pl_module, val_dataloader)  # type: ignore[arg-type]

        # knn val acc
        val_acc = pl_module.knn_evaluator.score(
            representations, y)  # type: ignore[union-attr,operator]

        # log metrics
        pl_module.log('online_knn_train_acc',
                      train_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log('online_knn_val_acc',
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
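get_all_representations is assumed rather than shown; a plausible sketch that returns CPU numpy arrays, as scikit-learn's KNeighborsClassifier.fit and score expect (the batch layout is an assumption):

import numpy as np
import torch

def get_all_representations(pl_module, dataloader):
    # encode every batch and collect CPU numpy arrays for scikit-learn
    reps, labels = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(pl_module.device)
            reps.append(pl_module(x).flatten(start_dim=1).cpu().numpy())
            labels.append(y.cpu().numpy())
    return np.concatenate(reps), np.concatenate(labels)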
Example #13
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        train_acc, mlp_loss = self.shared_step(pl_module, batch)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        pl_module.log("online_train_acc",
                      train_acc,
                      on_step=True,
                      on_epoch=False)
        pl_module.log("online_train_loss",
                      mlp_loss,
                      on_step=True,
                      on_epoch=False)
Example #14
    def finetune_function(self, pl_module: pl.LightningModule, epoch: int,
                          optimizer: Optimizer, opt_idx: int):
        opt_pgs = optimizer.param_groups
        pl_module.log("lr head param group", opt_pgs[0]["lr"])
        if len(opt_pgs) > 1:
            pl_module.log("lr backbone param group", opt_pgs[1]["lr"])
        pl_module.log("# param groups", len(opt_pgs))

        if epoch == self.milestone:
            # unfreeze the last `unfreeze_top_layers` layers
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor[-self.unfreeze_top_layers:],
                optimizer=optimizer,
                train_bn=self.train_bn)
Example #15
def override_unsupported_nud(lm: pl.LightningModule, context: PyTorchTrialContext) -> None:
    writer = pytorch.TorchWriter()

    def lm_print(*args: Any, **kwargs: Any) -> None:
        if context.distributed.get_rank() == 0:
            print(*args, **kwargs)

    def lm_log_dict(a_dict: Dict, *args: Any, **kwargs: Any) -> None:
        if len(args) != 0 or len(kwargs) != 0:
            raise InvalidModelException(
                f"unsupported arguments to LightningModule.log {args} {kwargs}"
            )
        for metric, value in a_dict.items():
            if type(value) == int or type(value) == float:
                writer.add_scalar(metric, value, context.current_train_batch())

    def lm_log(name: str, value: Any, *args: Any, **kwargs: Any) -> None:
        lm_log_dict({name: value}, *args, **kwargs)

    lm.print = lm_print  # type: ignore
    lm.log = lm_log  # type: ignore
    lm.log_dict = lm_log_dict  # type: ignore
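A hedged usage sketch: the patch is applied once, before training, so self.log calls inside the LightningModule are routed to Determined's TorchWriter (MyLightningModule and the surrounding trial wiring are illustrative):

# illustrative wiring inside a Determined PyTorchTrial (names are assumptions)
lm = MyLightningModule()
override_unsupported_nud(lm, context)  # context: PyTorchTrialContext from Determined
lm.log("loss", 0.42)  # now writes a scalar via TorchWriter instead of Lightning's logger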
Example #16
 def on_train_batch_end(
     self,
     trainer: pl.Trainer,
     pl_module: pl.LightningModule,
     outputs: Tensor,
     batch: tuple[Tensor, Tensor, Tensor],
     batch_idx: int,
     dataloader_idx: int,
 ):
     del trainer, outputs, batch_idx, dataloader_idx
     obs, act, new_obs = batch
     pl_module.log("train/obs-mean", obs.mean())
     pl_module.log("train/obs-std", obs.std())
     pl_module.log("train/act-mean", act.mean())
     pl_module.log("train/act-std", act.std())
     pl_module.log("train/new_obs-mean", new_obs.mean())
     pl_module.log("train/new_obs-std", new_obs.std())
Example #17
 def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,
                        outputs: Any) -> None:
     pl_module.log(name="epoch", value=trainer.current_epoch)
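Note that several snippets here use older PyTorch Lightning hook signatures: recent releases drop the outputs argument from on_train_epoch_end and make dataloader_idx an optional parameter of the batch-end hooks, so these examples may need small signature updates on current versions.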