Example #1
def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_rank() > 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    train_dataset, test_dataset = utils.get_train_test_datasets(
        config["data_path"])

    if idist.get_rank() == 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,
        pin_memory="cuda" in idist.device().type,
        drop_last=True,
    )

    test_loader = idist.auto_dataloader(
        test_dataset,
        batch_size=2 * config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=False,
        pin_memory="cuda" in idist.device().type,
    )
    return train_loader, test_loader
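These get_dataflow helpers are meant to be called from a training function launched through ignite.distributed.Parallel, which spawns one process per device and initializes the backend. A minimal launcher sketch, assuming the get_dataflow above and a hypothetical training function:

import ignite.distributed as idist

def training(local_rank, config):
    train_loader, test_loader = get_dataflow(config)
    # ... build model/optimizer with idist.auto_model / idist.auto_optim and train ...

if __name__ == "__main__":
    config = {"data_path": "/tmp/data", "batch_size": 64, "num_workers": 4}
    # backend=None runs a single process; "nccl", "gloo" or "xla-tpu" enable distributed runs
    with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
        parallel.run(training, config)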
Example #2
def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_rank() > 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    train_dataset, test_dataset = utils.get_dataset(config["data_dir"],
                                                    config["model"],
                                                    config["tokenizer_dir"],
                                                    config["max_length"])

    if idist.get_rank() == 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,
        drop_last=True,
    )

    test_loader = idist.auto_dataloader(
        test_dataset,
        batch_size=2 * config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=False,
    )
    return train_loader, test_loader
Example #3
def get_dataflow(config: ConfigSchema, wlm: WeakLabelManager) -> Dict[str, DataLoader]:
    # - Get train/test datasets
    if idist.get_rank() > 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    dataset = get_dataset(config.dataset, config.data_path)
    train_split = wlm.convert_targets(dataset["train"])

    if idist.get_rank() == 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_split,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        drop_last=True,
    )
    val_loader = idist.auto_dataloader(
        dataset["val"],
        batch_size=2 * config.batch_size,
        num_workers=config.num_workers,
        shuffle=False,
    )
    test_loader = idist.auto_dataloader(
        dataset["test"],
        batch_size=2 * config.batch_size,
        num_workers=config.num_workers,
        shuffle=False,
    )
    return {"train": train_loader, "val": val_loader, "test": test_loader}
Example #4
def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_rank() > 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    train_dataset, test_dataset = utils.get_train_test_datasets(
        config["data_path"],
        **{k: config[k]
           for k in ["rescale_size", "rand_aug", "rand_erasing"]},
    )

    if idist.get_rank() == 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,
        drop_last=True,
    )

    test_loader = idist.auto_dataloader(
        test_dataset,
        batch_size=2 * config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=False,
    )
    return train_loader, test_loader
Example #5
def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_local_rank() > 0:
        # Ensure that only local rank 0 downloads the dataset
        # Thus each node will download a copy of the dataset
        idist.barrier()

    train_dataset, test_dataset = utils.get_train_test_datasets(
        config["data_path"])

    if idist.get_local_rank() == 0:
        # Ensure that only local rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,
        drop_last=True,
    )

    test_loader = idist.auto_dataloader(
        test_dataset,
        batch_size=2 * config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=False,
    )
    return train_loader, test_loader
Example #6
def get_dataflow(config):

    # - Get train/test datasets
    if idist.get_rank() > 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    train_dataset, test_dataset = get_train_test_datasets(config.get("data_path", "."))

    if idist.get_rank() == 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.get("batch_size", 512),
        num_workers=config.get("num_workers", 8),
        shuffle=True,
        drop_last=True,
    )
    config["num_iters_per_epoch"] = len(train_loader)

    test_loader = idist.auto_dataloader(
        test_dataset,
        batch_size=2 * config.get("batch_size", 512),
        num_workers=config.get("num_workers", 8),
        shuffle=False,
    )
    return train_loader, test_loader
Example #7
    def _init_distribution(self):
        self.rank = idist.get_rank()
        manual_seed(42 + self.rank)
        self.device = idist.device()

        if self.train_ds:
            if self.train_ds.sampler is not None:
                sampler = self.train_ds.sampler(self.train_ds,
                                                self.train_ds.get_label)
                is_shuffle = False
            else:
                sampler = None
                is_shuffle = True
            self.train_loader = idist.auto_dataloader(
                self.train_ds,
                batch_size=self.hparams.train_bs,
                num_workers=self.hparams.train_num_workers,
                shuffle=is_shuffle,
                drop_last=True,
                sampler=sampler,
                **self.train_ds.additional_loader_params)

        if self.valid_ds:
            self.valid_loader = idist.auto_dataloader(
                self.valid_ds,
                batch_size=self.hparams.valid_bs,
                num_workers=self.hparams.valid_num_workers,
                shuffle=False,
                drop_last=False,
                **self.valid_ds.additional_loader_params)

        if self.test_ds:
            self.test_loader = idist.auto_dataloader(
                self.test_ds,
                batch_size=self.hparams.valid_bs,
                num_workers=self.hparams.valid_num_workers,
                shuffle=False,
                drop_last=False,
                **self.test_ds.additional_loader_params)

        if USE_AMP:
            self._init_optimizer()
            self.model = idist.auto_model(self.model)
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")
        else:
            self.model = idist.auto_model(self.model)
            self._init_optimizer()

        self.optimizer = idist.auto_optim(self.optimizer)

        self._init_scheduler()

        self.criterion = self.criterion.to(self.device)
Example #8
def get_train_val_loaders(
    root_path: str,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:

    train_ds = ImageNet(
        root_path, split="train", transform=lambda sample: train_transforms(image=sample)["image"], loader=opencv_loader
    )
    val_ds = ImageNet(
        root_path, split="val", transform=lambda sample: val_transforms(image=sample)["image"], loader=opencv_loader
    )

    if limit_train_num_samples is not None:
        np.random.seed(limit_train_num_samples)
        train_indices = np.random.permutation(len(train_ds))[:limit_train_num_samples]
        train_ds = Subset(train_ds, train_indices)

    if limit_val_num_samples is not None:
        np.random.seed(limit_val_num_samples)
        val_indices = np.random.permutation(len(val_ds))[:limit_val_num_samples]
        val_ds = Subset(val_ds, val_indices)

    # random samples for evaluation on training dataset
    if len(val_ds) < len(train_ds):
        np.random.seed(len(val_ds))
        train_eval_indices = np.random.permutation(len(train_ds))[: len(val_ds)]
        train_eval_ds = Subset(train_ds, train_eval_indices)
    else:
        train_eval_ds = train_ds

    train_loader = idist.auto_dataloader(
        train_ds, shuffle=True, batch_size=batch_size, num_workers=num_workers, drop_last=True,
    )

    val_batch_size = batch_size * 4 if val_batch_size is None else val_batch_size
    val_loader = idist.auto_dataloader(
        val_ds, shuffle=False, batch_size=val_batch_size, num_workers=num_workers, drop_last=False,
    )

    train_eval_loader = idist.auto_dataloader(
        train_eval_ds, shuffle=False, batch_size=val_batch_size, num_workers=num_workers, drop_last=False,
    )

    return train_loader, val_loader, train_eval_loader
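In the snippets above, seeding NumPy with the sample limit makes the drawn subset deterministic: the same limit always yields the same permutation, so every process selects identical indices. The pattern in isolation (a sketch, not part of the original code):

import numpy as np
from torch.utils.data import Subset

def limit_dataset(dataset, limit):
    # Same limit -> same seed -> same permutation, across runs and processes
    np.random.seed(limit)
    indices = np.random.permutation(len(dataset))[:limit]
    return Subset(dataset, indices)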
Example #9
def get_dataloaders(config):
    dataset_train, dataset_val = get_train_val_datasets(config)

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    dataloader_train = idist.auto_dataloader(
        dataset_train,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,
        drop_last=True,
    )

    dataloader_val = idist.auto_dataloader(
        dataset_val,
        batch_size=2 * config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=False,
    )
    return dataloader_train, dataloader_val
Example #10
def get_imagenet_dataloader(config):
    train_dir = os.path.join(config["data_path"], 'train')
    val_dir = os.path.join(config["data_path"], 'val')
    print(train_dir)
    print(val_dir)
    train_dataset, test_dataset = load_data(train_dir, val_dir,
                                            config["cache_dataset"])
    print(len(train_dataset))
    train_loader = idist.auto_dataloader(train_dataset,
                                         batch_size=config["batch_size"],
                                         shuffle=True,
                                         num_workers=config["num_workers"],
                                         pin_memory=True)

    test_loader = idist.auto_dataloader(test_dataset,
                                        batch_size=config["batch_size"],
                                        shuffle=False,
                                        num_workers=config["num_workers"],
                                        pin_memory=True)

    return train_loader, test_loader
Example #11
def get_dataloader(config: Config):
    dataset, num_channels = get_dataset(config)

    loader = auto_dataloader(
        dataset,
        batch_size=config.batch_size * idist.get_world_size(),
        num_workers=config.num_workers,
        shuffle=True,
        drop_last=True
    )

    return loader, num_channels
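Note the batch_size=config.batch_size * idist.get_world_size() above: auto_dataloader scales the passed batch size down by the world size (when it is at least the world size), so multiplying first keeps config.batch_size samples per process. Illustrative numbers, assuming a 4-process run:

import ignite.distributed as idist

world_size = idist.get_world_size()              # e.g. 4
requested = 32 * world_size                      # 128 passed to auto_dataloader
per_process = requested // world_size            # 32 samples per process per step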
Example #12
def get_cta_probe_loader(supervised_train_dataset, cta, **dataloader_kwargs):
    dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type
    dataloader_kwargs["drop_last"] = False
    dataloader_kwargs["shuffle"] = dataloader_kwargs.get("sampler", None) is None

    cta_probe_loader = idist.auto_dataloader(
        TransformedDataset(
            supervised_train_dataset, transforms=partial(cta_probe_transforms, cta=cta)
        ),
        **dataloader_kwargs
    )

    return cta_probe_loader
Example #13
def get_dataloader(dataset,
                   sampler=None,
                   shuffle=False,
                   limit_num_samples=None,
                   **kwargs):

    if limit_num_samples is not None:
        np.random.seed(limit_num_samples)
        indices = np.random.permutation(len(dataset))[:limit_num_samples]
        dataset = Subset(dataset, indices)

    return idist.auto_dataloader(dataset,
                                 sampler=sampler,
                                 shuffle=(sampler is None) and shuffle,
                                 **kwargs)
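A hypothetical call to this helper for a quick debugging run; the CIFAR10 dataset and the keyword arguments here are placeholders, not part of the original example:

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

dataset = CIFAR10("/tmp/cifar10", train=True, download=True, transform=ToTensor())
# Keep 1000 random samples; remaining kwargs are forwarded to idist.auto_dataloader
debug_loader = get_dataloader(dataset, shuffle=True, limit_num_samples=1000,
                              batch_size=64, num_workers=2)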
Example #14
def get_supervised_train_loader(
    supervised_train_dataset, transforms=weak_transforms, **dataloader_kwargs
):
    dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type
    dataloader_kwargs["drop_last"] = True
    dataloader_kwargs["shuffle"] = dataloader_kwargs.get("sampler", None) is None

    supervised_train_loader = idist.auto_dataloader(
        TransformedDataset(
            supervised_train_dataset,
            transforms=lambda d: {"image": transforms(d[0]), "target": d[1]},
        ),
        **dataloader_kwargs
    )
    return supervised_train_loader
Example #15
def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_local_rank() > 0:
        # Ensure that only local rank 0 downloads the dataset
        idist.barrier()

    train_dataset, test_dataset = get_dataset(
        config.data_dir, config.model, config.tokenizer_dir, config.max_length
    )

    if idist.get_local_rank() == 0:
        # Ensure that only local rank 0 downloads the dataset
        idist.barrier()

    # Set up data loaders adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        drop_last=True,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    test_loader = idist.auto_dataloader(
        test_dataset,
        batch_size=2 * config.batch_size,
        num_workers=config.num_workers,
        shuffle=False,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )
    return train_loader, test_loader
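The {% if %} ... {% endif %} lines are Jinja tags from the ignite code-generator templates; they are resolved when the project is generated, not at run time. With use_distributed_training enabled and no distributed launcher, the first loader would render to plain Python roughly as follows (a sketch of the rendered result, not verbatim generator output):

    train_loader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        drop_last=True,
        persistent_workers=True,
    )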
Example #16
    def _test_score(metric_device):
        from torchvision import models
        from torchvision.datasets import FakeData

        from ignite.engine import Engine

        inception_model = models.inception_v3(
            pretrained=True).eval().to(metric_device)
        dataset = FakeData(size=64,
                           transform=transforms.Compose(
                               [transforms.Resize(299),
                                transforms.ToTensor()]))
        dataset = IgnoreLabelDataset(dataset)
        dataloader = idist.auto_dataloader(dataset, batch_size=32)

        def np_compute(dataloader, splits):
            def get_pred(x):
                x = inception_model(x)
                return F.softmax(x, dim=1).detach().cpu().numpy()

            preds = []
            for i, batch in enumerate(dataloader):
                preds.append(get_pred(batch))

            split_scores = np.zeros((splits, ))
            preds = np.vstack(preds)
            N = preds.shape[0]
            for i in range(splits):
                part = preds[i * N // splits:(i + 1) * N // splits, :]
                kl = part * (np.log(part) -
                             np.log(np.mean(part, axis=0, keepdims=True)))
                kl = np.mean(np.sum(kl, axis=1))
                split_scores[i] = np.exp(kl)

            return np.mean(split_scores)

        def process_func(engine, batch):
            return batch

        inception_score = InceptionScore(device=metric_device)
        test_engine = Engine(process_func)
        inception_score.attach(test_engine, "score")
        np_is = np_compute(dataloader, 10)
        state = test_engine.run(dataloader)
        computed_is = state.metrics["score"]
        assert pytest.approx(computed_is, 0.1) == np_is
Example #17
def get_test_loader(
    root, transforms=test_transforms, download=True, **dataloader_kwargs
):
    full_test_dataset = CIFAR10(root, train=False, download=download)

    dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type
    dataloader_kwargs["drop_last"] = False
    dataloader_kwargs["shuffle"] = False

    test_loader = idist.auto_dataloader(
        TransformedDataset(
            full_test_dataset,
            transforms=lambda dp: {"image": transforms(dp[0]), "target": dp[1]},
        ),
        **dataloader_kwargs
    )
    return test_loader
Example #18
def get_unsupervised_train_loader(
    raw_dataset, transforms_weak, transforms_strong, **dataloader_kwargs
):
    unsupervised_train_dataset = TransformedDataset(
        raw_dataset,
        transforms=lambda dp: {
            "image": transforms_weak(dp[0]),
            "strong_aug": transforms_strong(dp[0]),
        },
    )

    dataloader_kwargs["drop_last"] = True
    dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type
    dataloader_kwargs["shuffle"] = dataloader_kwargs.get("sampler", None) is None

    unsupervised_train_loader = idist.auto_dataloader(
        unsupervised_train_dataset, **dataloader_kwargs
    )
    return unsupervised_train_loader
Example #19
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

    train_dataset, eval_dataset = get_datasets()

    train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
    eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    model, optimizer, loss_fn, lr_scheduler = initialize()

    # -----------------------------
    # trainer and evaluator
    # -----------------------------

    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # -------------------------------------------
    # update config with optimizer parameters
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    config.__dict__.update(**optimizer.defaults)
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name=None,
        # TODO : replace with the metric name to save the best model
        # if you check `Save the best model by evaluation score` otherwise leave it None
        # metric must be in evaluator.state.metrics.
        es_metric_name=None,
        # TODO : replace with the metric name to early stop
        # if you check `Early stop the training by evaluation score` otherwise leave it None
        # metric must be in evaluator.state.metrics.
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------
    # let's trigger custom events we registered
    # we will use a `event_filter` to trigger that
    # `event_filter` has to return boolean
    # whether this event should be executed
    # here will log the gradients on the 1st iteration
    # and every 100 iterations
    # --------------------------------------------

    @trainer.on(TrainEvents.BACKWARD_COMPLETED(lambda _, ev: (ev % 100 == 0) or (ev == 1)))
    def _():
        # do something interesting
        pass

    # ----------------------------------------
    # here we will use `every` to trigger
    # every 100 iterations
    # ----------------------------------------

    @trainer.on(TrainEvents.OPTIM_STEP_COMPLETED(every=100))
    def _():
        # do something interesting
        pass

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------

    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------
    # TODO : PLEASE provide `max_epochs` parameters

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Example #20
def get_train_val_loaders(
    root_path: str,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    with_sbd: Optional[str] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:

    train_ds = get_train_dataset(root_path)
    val_ds = get_val_dataset(root_path)

    if with_sbd is not None:
        sbd_train_ds = get_train_noval_sbdataset(with_sbd)
        train_ds = ConcatDataset([train_ds, sbd_train_ds])

    if limit_train_num_samples is not None:
        np.random.seed(limit_train_num_samples)
        train_indices = np.random.permutation(
            len(train_ds))[:limit_train_num_samples]
        train_ds = Subset(train_ds, train_indices)

    if limit_val_num_samples is not None:
        np.random.seed(limit_val_num_samples)
        val_indices = np.random.permutation(
            len(val_ds))[:limit_val_num_samples]
        val_ds = Subset(val_ds, val_indices)

    # random samples for evaluation on training dataset
    if len(val_ds) < len(train_ds):
        np.random.seed(len(val_ds))
        train_eval_indices = np.random.permutation(len(train_ds))[:len(val_ds)]
        train_eval_ds = Subset(train_ds, train_eval_indices)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    train_eval_ds = TransformedDataset(train_eval_ds,
                                       transform_fn=val_transforms)

    train_loader = idist.auto_dataloader(
        train_ds,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )

    val_batch_size = batch_size * 4 if val_batch_size is None else val_batch_size
    val_loader = idist.auto_dataloader(
        val_ds,
        shuffle=False,
        batch_size=val_batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    train_eval_loader = idist.auto_dataloader(
        train_eval_ds,
        shuffle=False,
        batch_size=val_batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, train_eval_loader
Example #21
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    # Define output folder:
    config.output = "/tmp/output"

    model = idist.auto_model(config.model)
    optimizer = idist.auto_optim(config.optimizer)
    criterion = config.criterion

    train_set, val_set = config.train_set, config.val_set
    train_loader = idist.auto_dataloader(train_set,
                                         batch_size=config.train_batch_size)
    val_loader = idist.auto_dataloader(val_set,
                                       batch_size=config.val_batch_size)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED(every=config.val_interval))
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=config.output)

        tb_logger.attach_output_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
            metric_names="all",
        )

        for tag, evaluator in [("training", train_evaluator),
                               ("validation", validation_evaluator)]:
            tb_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                global_step_transform=global_step_from_engine(trainer),
            )

        tb_logger.attach_opt_params_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            optimizer=optimizer)

    model_checkpoint = ModelCheckpoint(
        config.output,
        n_saved=2,
        filename_prefix="best",
        score_name="accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if rank == 0:
        tb_logger.close()
Example #22
def run(
    local_rank: int,
    device: str,
    experiment_name: str,
    gpus: Optional[Union[int, List[int], str]] = None,
    dataset_root: str = "./dataset",
    log_dir: str = "./log",
    model: str = "fasterrcnn_resnet50_fpn",
    epochs: int = 13,
    batch_size: int = 4,
    lr: float = 0.01,
    download: bool = False,
    image_size: int = 256,
    resume_from: Optional[dict] = None,
) -> None:
    bbox_params = A.BboxParams(format="pascal_voc")
    train_transform = A.Compose(
        [A.HorizontalFlip(p=0.5), ToTensorV2()],
        bbox_params=bbox_params,
    )
    val_transform = A.Compose([ToTensorV2()], bbox_params=bbox_params)

    download = local_rank == 0 and download
    train_dataset = Dataset(root=dataset_root,
                            download=download,
                            image_set="train",
                            transforms=train_transform)
    val_dataset = Dataset(root=dataset_root,
                          download=download,
                          image_set="val",
                          transforms=val_transform)
    vis_dataset = Subset(val_dataset,
                         random.sample(range(len(val_dataset)), k=16))

    train_dataloader = idist.auto_dataloader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             collate_fn=collate_fn,
                                             num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                collate_fn=collate_fn,
                                num_workers=4)
    vis_dataloader = DataLoader(vis_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                collate_fn=collate_fn,
                                num_workers=4)

    model = idist.auto_model(model)
    scaler = GradScaler()
    optimizer = SGD(lr=lr, params=model.parameters())
    optimizer = idist.auto_optim(optimizer)
    scheduler = OneCycleLR(optimizer,
                           max_lr=lr,
                           total_steps=len(train_dataloader) * epochs)

    def update_model(engine, batch):
        model.train()
        images, targets = batch
        images = list(image.to(device) for image in images)
        targets = [{
            k: v.to(device)
            for k, v in t.items() if isinstance(v, torch.Tensor)
        } for t in targets]

        with torch.autocast(device, enabled=True):
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_items = {k: v.item() for k, v in loss_dict.items()}
        loss_items["loss_average"] = loss.item() / 4

        return loss_items

    @torch.no_grad()
    def inference(engine, batch):
        model.eval()
        images, targets = batch
        images = list(image.to(device) for image in images)
        outputs = model(images)
        outputs = [{k: v.to("cpu") for k, v in t.items()} for t in outputs]
        return {
            "y_pred": outputs,
            "y": targets,
            "x": [i.cpu() for i in images]
        }

    trainer = Engine(update_model)
    evaluator = Engine(inference)
    visualizer = Engine(inference)

    aim_logger = AimLogger(
        repo=os.path.join(log_dir, "aim"),
        experiment=experiment_name,
    )

    CocoMetric(convert_to_coco_api(val_dataset)).attach(evaluator, "mAP")

    @trainer.on(Events.EPOCH_COMPLETED)
    @one_rank_only()
    def log_validation_results(engine):
        evaluator.run(val_dataloader)
        visualizer.run(vis_dataloader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def step_scheduler(engine):
        scheduler.step()
        aim_logger.log_metrics({"lr": scheduler.get_last_lr()[0]},
                               step=engine.state.iteration)

    @visualizer.on(Events.EPOCH_STARTED)
    def reset_vis_images(engine):
        engine.state.model_outputs = []

    @visualizer.on(Events.ITERATION_COMPLETED)
    def add_vis_images(engine):
        engine.state.model_outputs.append(engine.state.output)

    @visualizer.on(Events.EPOCH_COMPLETED)
    def submit_vis_images(engine):
        aim_images = []
        for outputs in engine.state.model_outputs:
            for image, target, pred in zip(outputs["x"], outputs["y"],
                                           outputs["y_pred"]):
                image = (image * 255).byte()
                pred_labels = [
                    Dataset.class2name[label.item()]
                    for label in pred["labels"]
                ]
                pred_boxes = pred["boxes"].long()
                image = draw_bounding_boxes(image,
                                            pred_boxes,
                                            pred_labels,
                                            colors="red")

                target_labels = [
                    Dataset.class2name[label.item()]
                    for label in target["labels"]
                ]
                target_boxes = target["boxes"].long()
                image = draw_bounding_boxes(image,
                                            target_boxes,
                                            target_labels,
                                            colors="green")

                aim_images.append(aim.Image(image.numpy().transpose(
                    (1, 2, 0))))
        aim_logger.experiment.track(aim_images,
                                    name="vis",
                                    step=trainer.state.epoch)

    losses = [
        "loss_classifier", "loss_box_reg", "loss_objectness",
        "loss_rpn_box_reg", "loss_average"
    ]
    for loss_name in losses:
        # Bind loss_name at definition time; a plain closure would capture only the last name
        RunningAverage(output_transform=lambda x, name=loss_name: x[name]).attach(
            trainer, loss_name)
    ProgressBar().attach(trainer, losses)
    ProgressBar().attach(evaluator)

    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        "scaler": scaler,
    }
    checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=3,
        score_name="mAP",
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)
    if resume_from:
        Checkpoint.load_objects(objects_to_checkpoint, torch.load(resume_from))

    aim_logger.log_params({
        "lr": lr,
        "image_size": image_size,
        "batch_size": batch_size,
        "epochs": epochs,
    })
    aim_logger.attach_output_handler(trainer,
                                     event_name=Events.ITERATION_COMPLETED,
                                     tag="train",
                                     output_transform=lambda loss: loss)
    aim_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="val",
        metric_names=["mAP"],
        global_step_transform=global_step_from_engine(
            trainer, Events.ITERATION_COMPLETED),
    )

    trainer.run(train_dataloader, max_epochs=epochs)
Example #23
def training(rank, config):

    # Specific ignite.distributed
    print(
        idist.get_rank(),
        ": run with config:",
        config,
        "- backend=",
        idist.backend(),
        "- world size",
        idist.get_world_size(),
    )

    device = idist.device()

    # Data preparation:
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific ignite.distributed
    train_loader = idist.auto_dataloader(dataset,
                                         batch_size=config["batch_size"])

    # Model, criterion, optimizer setup
    model = idist.auto_model(wide_resnet50_2(num_classes=100))
    criterion = NLLLoss()
    optimizer = idist.auto_optim(SGD(model.parameters(), lr=0.01))

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(engine, batch):

        data = batch[0].to(device)
        target = batch[1].to(device)

        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities over the class dimension
        log_probs = torch.nn.functional.log_softmax(output, dim=1)

        loss_val = criterion(log_probs, target)
        loss_val.backward()
        optimizer.step()

        return loss_val.item()

    # The trainer runs _train_step on every batch of the train_loader (a single epoch here)
    trainer = Engine(_train_step)

    # Add a logger
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training():
        print("Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
            idist.get_rank(),
            idist.get_world_size(),
            trainer.state.epoch,
            trainer.state.iteration * len(trainer.state.batch[0]),
            len(dataset) // idist.get_world_size(),
            trainer.state.output,
        ))

    trainer.run(train_loader, max_epochs=1)
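The log_softmax/NLLLoss pairing in the step above can equivalently be written with CrossEntropyLoss on raw logits, which fuses the two operations. A minimal sketch of that variant (names mirror the example):

from torch.nn import CrossEntropyLoss

criterion = CrossEntropyLoss()

def _train_step(engine, batch):
    data, target = batch[0].to(device), batch[1].to(device)
    optimizer.zero_grad()
    output = model(data)                   # raw logits; no softmax here
    loss_val = criterion(output, target)   # CrossEntropyLoss = log_softmax + NLLLoss
    loss_val.backward()
    optimizer.step()
    return loss_val.item()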
Example #24
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.model}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

    train_dataset, eval_dataset = get_datasets(path=config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.train_batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )
    eval_dataloader = idist.auto_dataloader(
        eval_dataset,
        batch_size=config.eval_batch_size,
        num_workers=config.num_workers,
        shuffle=False,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    config.num_iters_per_epoch = len(train_dataloader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config=config)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------

    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # attach metrics to evaluator
    # ---------------------------------
    accuracy = Accuracy(device=device)
    metrics = {
        "eval_accuracy": accuracy,
        "eval_loss": Loss(loss_fn, device=device),
        "eval_error": (1.0 - accuracy) * 100,
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------

    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Example #25
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------

    train_dataset, num_channels = get_datasets(config.dataset, config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    netD, netG, optimizerD, optimizerG, loss_fn, lr_scheduler = initialize(config, num_channels)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    ws = idist.get_world_size()
    real_labels = torch.ones(config.batch_size // ws, device=device)
    fake_labels = torch.zeros(config.batch_size // ws, device=device)
    fixed_noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)

    trainer = create_trainers(
        config=config,
        netD=netD,
        netG=netG,
        optimizerD=optimizerD,
        optimizerG=optimizerG,
        loss_fn=loss_fn,
        device=device,
        real_labels=real_labels,
        fake_labels=fake_labels,
    )

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {'netD': netD, 'netG': netG, 'optimizerD': optimizerD, 'optimizerG': optimizerG, 'trainer': trainer}
    optimizers = {'optimizerD': optimizerD, 'optimizerG': optimizerG}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model={'netD': netD, 'netG': netG},
        trainer=trainer,
        evaluator=trainer,
        metric_name='errD',
        es_metric_name='errD',
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(config=config, trainer=trainer, optimizers=optimizers)

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = config.output_dir / (FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = config.output_dir / (REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        if timer_handler:
            logger.info(f"Epoch {engine.state.epoch} done. Time per batch: {timer_handler.value():.3f}[s]")
            timer_handler.reset()

    @trainer.on(Events.ITERATION_COMPLETED(every=config.log_every_iters))
    @idist.one_rank_only()
    def print_logs(engine):
        fname = config.output_dir / LOGS_FNAME
        columns = ["iteration", ] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration), ] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)
        message = f"[{engine.state.epoch}/{config.max_epochs}][{engine.state.iteration % len(train_dataloader)}/{len(train_dataloader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import pandas as pd

        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

        else:
            df = pd.read_csv(config.output_dir / LOGS_FNAME, delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = config.output_dir / PLOT_FNAME

            fig.savefig(path)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Example #26
def run(loop: Loop):
    seed_everything(42)
    setup_cudnn_reproducibility(True, False)

    train_ds, valid_ds = get_train_test_datasets("data/cifar")

    model = auto_model(get_model())

    train_loader = auto_dataloader(
        train_ds,
        batch_size=512,
        shuffle=True,
        drop_last=True,
        num_workers=4,
    )

    valid_loader = auto_dataloader(
        valid_ds,
        batch_size=512,
        num_workers=4,
        shuffle=False,
    )

    optim = SGD(model.parameters(), lr=0.4, momentum=0.9)

    scheduler = OneCycleLR(optim,
                           max_lr=1,
                           epochs=NUM_EPOCHS,
                           steps_per_epoch=len(train_loader))
    criterion = CrossEntropyLoss()

    precision = Precision(average=False)
    recall = Recall(average=False)

    # Ignite metrics are combinable
    f1 = (precision * recall * 2 / (precision + recall)).mean()
    accuracy = Accuracy()

    # We are attaching metrics to automatically reset
    loop.attach(
        # Loop manages train/eval modes, device and requires_grad of attached `nn.Module`s
        criterion=criterion,
        # This criterion doesn't have any state or attribute tensors,
        # so its attachment doesn't introduce any behavior
        model=model,
        # Loop saves state of all attached objects having state_dict()/load_state_dict() methods
        # to checkpoints
        optimizer=optim,
        scheduler=scheduler,
    )

    def train(loop: Loop):
        for _ in loop.iterate_epochs(NUM_EPOCHS):
            for x, y in loop.iterate_dataloader(train_loader, mode="train"):
                y_pred_logits = model(x)

                loss: torch.Tensor = criterion(y_pred_logits, y)
                loop.backward(loss)
                # Makes optimizer step and also
                # zeroes grad after (default)
                loop.optimizer_step(optim, zero_grad=True)

                # Here we call scheduler.step() every iteration
                # because we use a one-cycle scheduler;
                # for a typical scheduler we could instead call it
                # once after the dataloader loop
                scheduler.step()

                # Log the learning rate. All metrics are written to tensorboard
                # with the specified names.
                # If iteration='auto' (default), it is determined by where the
                # call is performed. Here it will be batches.
                loop.metrics.log("lr",
                                 scheduler.get_last_lr()[0],
                                 iteration="auto")

            # Loop disables gradients and calls Module.eval() inside the loop
            # for all attached modules when mode="valid" (default)
            for x, y in loop.iterate_dataloader(valid_loader, mode="valid"):
                y_pred_logits: torch.Tensor = model(x)

                y_pred = to_onehot(y_pred_logits.argmax(dim=-1),
                                   num_classes=10)

                precision.update((y_pred, y))
                recall.update((y_pred, y))
                accuracy.update((y_pred, y))

            # These metrics are epoch metrics because they are computed outside
            # the dataloader loop.
            # Here we log the metric without resetting it
            loop.metrics.log("valid/precision", precision.compute().mean())
            loop.metrics.log("valid/recall", recall.compute().mean())

            # .log() method above accepts values (tensors, floats, np.array's)
            # .consume() accepts Metric object. It resets it after logging
            loop.metrics.consume("valid/f1", f1)
            loop.metrics.consume("valid/accuracy", accuracy)

    loop.run(train)