Exemplo n.º 1
0
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Run one training epoch over ``dataset.train_dataloader``.

    For every batch the model's input is set, ``optimize_parameters`` is
    called, the tracker is updated every 10th batch, the progress bar shows
    the current metrics and timings, and visuals are saved when the
    visualizer is active. At the end of the epoch the tracker publishes its
    metrics, the checkpoint is offered them, and the learning rate is logged.

    Args:
        epoch: Index of the current epoch.
        model: Model being trained.
        dataset: Object exposing ``train_dataloader`` and ``batch_size``.
        device: Device passed to ``model.set_input``.
        tracker: Metric tracker for the "train" stage.
        checkpoint: Receives the published metrics to save best models.
        visualizer: Optionally saves per-batch visuals.
        debugging: Config object; ``early_break``, ``profiling`` and
            ``num_batches`` attributes are read when present.

    Returns:
        0 when profiling stops the epoch early (metrics are then neither
        published nor checkpointed); otherwise None.
    """
    early_break = getattr(debugging, "early_break", False)
    profiling = getattr(debugging, "profiling", False)

    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    train_loader = dataset.train_dataloader

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for i, data in enumerate(tq_train_loader):
            model.set_input(data, device)
            # Measured after set_input, so the "data_loading" time also
            # includes handing the batch to the model/device.
            t_data = time.time() - iter_data_time

            iter_start_time = time.time()
            model.optimize_parameters(epoch, dataset.batch_size)
            # Tracking is only refreshed every 10 batches.
            if i % 10 == 0:
                tracker.track(model)

            tq_train_loader.set_postfix(**tracker.get_metrics(),
                                        data_loading=float(t_data),
                                        iteration=float(time.time() -
                                                        iter_start_time),
                                        color=COLORS.TRAIN_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            iter_data_time = time.time()

            if early_break:
                break

            if profiling:
                if i > getattr(debugging, "num_batches", 50):
                    return 0

    metrics = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, metrics,
                                                      tracker.metric_func)
    # Pass the value as a logging argument so formatting is deferred to the
    # logging framework instead of being done eagerly with ``%``.
    log.info("Learning rate = %f", model.learning_rate)
Exemplo n.º 2
0
    def test_save_all(self):
        """With the "save all" config, every epoch writes at least
        ``num_samples`` files for the train split."""
        mock_data = Data()
        mock_data.pos = torch.zeros((num_points * batch_size, 3))
        mock_data.y = torch.zeros((num_points * batch_size, 1))
        mock_data.pred = torch.zeros((num_points * batch_size, 1))
        mock_data.batch = torch.zeros((num_points * batch_size))
        # The first ``num_points`` entries get batch id 1, the rest keep 0.
        mock_data.batch[:num_points] = 1
        data = {"mock_date": mock_data}

        self.run_path = os.path.join(DIR, "test_viz")
        # Start from an empty scratch directory so files left behind by an
        # earlier failed or interrupted run cannot skew the assertions, and
        # register its removal so it is cleaned up even if an assertion fails.
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        epochs = 2
        num_samples = 100
        mock_num_batches = {"train": num_samples}

        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_save_all.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(num_samples // batch_size, visualizer, epoch, "train", data)

        for split in ["train"]:
            for epoch in range(epochs):
                current = set(os.listdir(os.path.join(self.run_path, "viz", str(epoch), split)))
                self.assertGreaterEqual(len(current), num_samples)
Exemplo n.º 3
0
    def test_indices(self):
        """With the "indices" config, exactly the expected files are saved
        for each split and epoch (train: 1_1.ply and 0_0.ply, test: 0_0.ply)."""
        mock_data = Data()
        mock_data.pos = torch.zeros((batch_size, num_points, 3))
        mock_data.y = torch.zeros((batch_size, num_points, 1))
        mock_data.pred = torch.zeros((batch_size, num_points, 1))
        data = {"mock_date": mock_data}

        self.run_path = os.path.join(DIR, "test_viz")
        # Exact set equality is asserted below, so stale files from a previous
        # failed run would break this test: start clean and always clean up.
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(0, visualizer, epoch, "val", data)

        targets = {"train": {"1_1.ply", "0_0.ply"},
                   "test": {"0_0.ply"}}
        for split in ["train", "test"]:
            for epoch in range(epochs):
                self.assertEqual(targets[split], set(os.listdir(os.path.join(self.run_path, "viz", str(epoch), split))))
Exemplo n.º 4
0
    def test_pyg_data(self):
        """With the "non_deterministic" config, the saved file names for at
        least 4 of epochs 1..9 share nothing with those of epoch 0."""
        mock_data = Data()
        mock_data.pos = torch.zeros((num_points * batch_size, 3))
        mock_data.y = torch.zeros((num_points * batch_size, 1))
        mock_data.pred = torch.zeros((num_points * batch_size, 1))
        mock_data.batch = torch.zeros((num_points * batch_size))
        # The first ``num_points`` entries get batch id 1, the rest keep 0.
        mock_data.batch[:num_points] = 1
        data = {"mock_date": mock_data}

        self.run_path = os.path.join(DIR, "test_viz")
        # Leftover files from an earlier run would change the overlap counts
        # below: start from an empty directory and always remove it afterwards.
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        epochs = 10
        num_batches = 100
        mock_num_batches = {"train": num_batches}

        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_non_deterministic.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(num_batches, visualizer, epoch, "train", data)

        count = 0
        for split in ["train"]:
            target = set(os.listdir(os.path.join(self.run_path, "viz", "0", split)))
            for epoch in range(1, epochs):
                current = set(os.listdir(os.path.join(self.run_path, "viz", str(epoch), split)))
                # Count epochs whose saved files are disjoint from epoch 0's.
                count += 1 if len(target & current) == 0 else 0
        self.assertGreaterEqual(count, 4)
Exemplo n.º 5
0
    def test_dense_data(self):
        """With the "deterministic" config and dense inputs, every epoch saves
        the same number of ply and las files per split as epoch 0."""
        mock_data = Data()
        mock_data.pos = torch.zeros((batch_size, num_points, 3))
        mock_data.y = torch.zeros((batch_size, num_points, 1))
        mock_data.pred = torch.zeros((batch_size, num_points, 1))
        data = {"mock_date": mock_data}

        self.run_path = os.path.join(DIR, "test_viz")
        # Start from an empty scratch directory and always remove it, even
        # when an assertion fails, so runs cannot contaminate each other.
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(
            os.path.join(DIR, "test_config/viz/viz_config_deterministic.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches,
                                batch_size, self.run_path, None)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(0, visualizer, epoch, "val", data)

        for split in ["train", "test"]:
            # ``fmt`` rather than ``format`` to avoid shadowing the builtin.
            for fmt in ["ply", "las"]:
                targets = os.listdir(
                    os.path.join(self.run_path, "viz", "0", split, fmt))
                for epoch in range(1, epochs):
                    current = os.listdir(
                        os.path.join(self.run_path, "viz", str(epoch), split,
                                     fmt))
                    self.assertEqual(len(targets), len(current))
Exemplo n.º 6
0
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Evaluate ``model`` on each of ``dataset.test_dataloaders``.

    For every loader, the tracker and visualizer are reset under the name of
    the loader's dataset. Each batch is fed to the model and run forward
    inside ``torch.no_grad()``, then tracked, shown on the progress bar and
    optionally visualised. After a loader is exhausted (or after its first
    batch when ``debugging.early_break`` is set), the tracker is finalised,
    published and summarised, and the checkpoint is offered the metrics.
    """
    stop_after_first_batch = getattr(debugging, "early_break", False)
    model.eval()

    for test_loader in dataset.test_dataloaders:
        stage = test_loader.dataset.name
        tracker.reset(stage)
        visualizer.reset(epoch, stage)

        with Ctq(test_loader) as progress:
            for batch in progress:
                with torch.no_grad():
                    model.set_input(batch, device)
                    model.forward()

                tracker.track(model)
                progress.set_postfix(**tracker.get_metrics(),
                                     color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                if stop_after_first_batch:
                    break

        tracker.finalise()
        stage_metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(
            model, stage_metrics, tracker.metric_func)
Exemplo n.º 7
0
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Evaluate ``model`` on ``dataset.val_dataloader`` under the "val" stage.

    Each batch is fed to the model and run forward inside ``torch.no_grad()``,
    then tracked, shown on the progress bar and optionally visualised. When
    ``debugging.early_break`` is set, only the first batch is processed.
    Afterwards the tracker publishes and prints its summary, and the
    checkpoint is offered the published metrics.
    """
    stop_after_first_batch = getattr(debugging, "early_break", False)

    model.eval()
    tracker.reset("val")
    visualizer.reset(epoch, "val")

    with Ctq(dataset.val_dataloader) as progress:
        for batch in progress:
            with torch.no_grad():
                model.set_input(batch, device)
                model.forward()

            tracker.track(model)
            progress.set_postfix(**tracker.get_metrics(),
                                 color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if stop_after_first_batch:
                break

    val_metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(
        model, val_metrics, tracker.metric_func)
Exemplo n.º 8
0
    def test_indices(self):
        """With the "indices" config, each epoch saves exactly the expected
        ply files, plus matching ``_gt`` files for the las format.

        File names are compared without extension and are prefixed with the
        epoch index, e.g. ``"3_1_1"`` for target ``"1_1"`` in epoch 3.
        """
        mock_data = Data()
        mock_data.pos = torch.zeros((batch_size, num_points, 3))
        mock_data.y = torch.zeros((batch_size, num_points, 1))
        mock_data.pred = torch.zeros((batch_size, num_points, 1))
        data = {"mock_date": mock_data}

        self.run_path = os.path.join(DIR, "test_viz")
        # Exact set equality is asserted below, so start from an empty
        # directory and remove it even if an assertion fails.
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(
            os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches,
                                batch_size, self.run_path, None)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(0, visualizer, epoch, "val", data)

        targets = {"train": {"1_1", "0_0"}, "test": {"0_0"}}
        for split in ["train", "test"]:
            for epoch in range(epochs):
                # ``fmt`` rather than ``format`` to avoid shadowing the builtin.
                for fmt in ["ply", "las"]:
                    files = os.listdir(
                        os.path.join(self.run_path, "viz", str(epoch), split,
                                     fmt))
                    stems = {os.path.splitext(filename)[0] for filename in files}

                    # Prefix each expected name with the current epoch.
                    expected = {"%d_%s" % (epoch, name) for name in targets[split]}
                    if fmt == "las":
                        # las output also contains a ground-truth file per sample.
                        expected |= {"%s_gt" % name for name in expected}

                    self.assertEqual(expected, stems)
Exemplo n.º 9
0
    def test_empty(self):
        """Running the visualizer with an empty data dict saves nothing under
        ``viz``."""
        data = {}

        self.run_path = os.path.join(DIR, "test_viz")
        # Any leftover file from an earlier run would break the emptiness
        # check below: start from a clean directory and always remove it.
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(2, visualizer, epoch, "val", data)

        self.assertEqual(len(os.listdir(os.path.join(self.run_path, "viz"))), 0)
Exemplo n.º 10
0
    def _initialize_trainer(self):
        """Build the trainer's components from ``self._cfg``.

        Sets ``self._device``, ``self._checkpoint``, ``self._dataset``,
        ``self._model`` and ``self._tracker``. It also sets ``self._visualizer``
        when ``self.has_visualization`` is truthy.

        The statements are order-dependent. For example, the model is passed
        to ``create_dataloaders`` and must exist first. Keep the sequence
        when editing.
        """
        # Enable CUDNN BACKEND
        torch.backends.cudnn.enabled = self.enable_cudnn

        # Without a ``training`` section, the top-level config itself is used
        # as the training section (presumably evaluation-only configs).
        if not self.has_training:
            self._cfg.training = self._cfg
            resume = bool(self._cfg.checkpoint_dir)
        else:
            resume = bool(self._cfg.training.checkpoint_dir)

        # Get device: ``training.cuda`` is compared against -1. A value
        # > -1 selects that CUDA device index when CUDA is available.
        if self._cfg.training.cuda > -1 and torch.cuda.is_available():
            device = "cuda"
            torch.cuda.set_device(self._cfg.training.cuda)
        else:
            device = "cpu"
        self._device = torch.device(device)
        log.info("DEVICE : {}".format(self._device))

        # Profiling
        if self.profiling:
            # Set the num_workers as torch.utils.bottleneck doesn't work well with it
            self._cfg.training.num_workers = 0

        # Start Wandb if public. A second ``Wandb.launch`` call further down
        # passes the complementary flag (``not public``).
        # NOTE(review): presumably the flag decides whether launch actually
        # starts a run; confirm against Wandb.launch.
        if self.wandb_log:
            Wandb.launch(self._cfg, self._cfg.wandb.public and self.wandb_log)

        # Checkpoint

        self._checkpoint: ModelCheckpoint = ModelCheckpoint(
            self._cfg.training.checkpoint_dir,
            self._cfg.model_name,
            self._cfg.training.weight_name,
            run_config=self._cfg,
            resume=resume,
        )

        # Create model and datasets: restore both from the checkpoint when it
        # has content, otherwise build them fresh from the run config.
        if not self._checkpoint.is_empty:
            self._dataset: BaseDataset = instantiate_dataset(
                self._checkpoint.data_config)
            self._model: BaseModel = self._checkpoint.create_model(
                self._dataset, weight_name=self._cfg.training.weight_name)
        else:
            self._dataset: BaseDataset = instantiate_dataset(self._cfg.data)
            self._model: BaseModel = instantiate_model(
                copy.deepcopy(self._cfg), self._dataset)
            self._model.instantiate_optimizers(self._cfg, "cuda" in device)
            self._model.set_pretrained_weights()
            if not self._checkpoint.validate(self._dataset.used_properties):
                log.warning(
                    "The model will not be able to be used from pretrained weights without the corresponding dataset. Current properties are {}"
                    .format(self._dataset.used_properties))
        self._checkpoint.dataset_properties = self._dataset.used_properties

        log.info(self._model)

        self._model.log_optimizers()
        # Number of trainable parameters.
        log.info(
            "Model size = %i",
            sum(param.numel() for param in self._model.parameters()
                if param.requires_grad))

        # Set dataloaders
        self._dataset.create_dataloaders(
            self._model,
            self._cfg.training.batch_size,
            self._cfg.training.shuffle,
            self._cfg.training.num_workers,
            self.precompute_multi_scale,
        )
        log.info(self._dataset)

        # Verify attributes in dataset.
        # NOTE(review): indexes train_dataset[0] even in evaluation-only
        # configs; assumes a non-empty train dataset exists - confirm.
        self._model.verify_data(self._dataset.train_dataset[0])

        # Choose selection stage
        selection_stage = getattr(self._cfg, "selection_stage", "")
        self._checkpoint.selection_stage = self._dataset.resolve_saving_stage(
            selection_stage)
        self._tracker: BaseTracker = self._dataset.get_tracker(
            self.wandb_log, self.tensorboard_log)

        if self.wandb_log:
            Wandb.launch(self._cfg, not self._cfg.wandb.public
                         and self.wandb_log)

        # Run training / evaluation
        self._model = self._model.to(self._device)
        # ``self._visualizer`` only exists when visualization is configured.
        if self.has_visualization:
            self._visualizer = Visualizer(self._cfg.visualization,
                                          self._dataset.num_batches,
                                          self._dataset.batch_size,
                                          os.getcwd())
Exemplo n.º 11
0
class Trainer:
    """
    TorchPoints3d Trainer handles the logic between
        - BaseModel,
        - Dataset and its Tracker
        - A custom ModelCheckpoint
        - A custom Visualizer
    It supports MC dropout - multiple voting_runs for val / test datasets
    """
    def __init__(self, cfg):
        self._cfg = cfg
        self._initialize_trainer()

    def _initialize_trainer(self):
        """Build the trainer's components from ``self._cfg``.

        Sets ``_device``, ``_checkpoint``, ``_dataset``, ``_model`` and
        ``_tracker``. It also sets ``_visualizer``, but only when
        ``has_visualization`` is truthy. The statements are order-dependent;
        for example, the model is passed to ``create_dataloaders``.
        """
        # Enable CUDNN BACKEND
        torch.backends.cudnn.enabled = self.enable_cudnn

        # Without a ``training`` section the top-level config doubles as it.
        if not self.has_training:
            self._cfg.training = self._cfg
            resume = bool(self._cfg.checkpoint_dir)
        else:
            resume = bool(self._cfg.training.checkpoint_dir)

        # Get device: ``training.cuda`` > -1 selects that CUDA device index.
        if self._cfg.training.cuda > -1 and torch.cuda.is_available():
            device = "cuda"
            torch.cuda.set_device(self._cfg.training.cuda)
        else:
            device = "cpu"
        self._device = torch.device(device)
        log.info("DEVICE : {}".format(self._device))

        # Profiling
        if self.profiling:
            # Set the num_workers as torch.utils.bottleneck doesn't work well with it
            self._cfg.training.num_workers = 0

        # Start Wandb if public (the non-public launch happens further down).
        if self.wandb_log:
            Wandb.launch(self._cfg, self._cfg.wandb.public and self.wandb_log)

        # Checkpoint

        self._checkpoint: ModelCheckpoint = ModelCheckpoint(
            self._cfg.training.checkpoint_dir,
            self._cfg.model_name,
            self._cfg.training.weight_name,
            run_config=self._cfg,
            resume=resume,
        )

        # Create model and datasets: from the checkpoint when it has content,
        # otherwise fresh from the run config.
        if not self._checkpoint.is_empty:
            self._dataset: BaseDataset = instantiate_dataset(
                self._checkpoint.data_config)
            self._model: BaseModel = self._checkpoint.create_model(
                self._dataset, weight_name=self._cfg.training.weight_name)
        else:
            self._dataset: BaseDataset = instantiate_dataset(self._cfg.data)
            self._model: BaseModel = instantiate_model(
                copy.deepcopy(self._cfg), self._dataset)
            self._model.instantiate_optimizers(self._cfg, "cuda" in device)
            self._model.set_pretrained_weights()
            if not self._checkpoint.validate(self._dataset.used_properties):
                log.warning(
                    "The model will not be able to be used from pretrained weights without the corresponding dataset. Current properties are {}"
                    .format(self._dataset.used_properties))
        self._checkpoint.dataset_properties = self._dataset.used_properties

        log.info(self._model)

        self._model.log_optimizers()
        # Number of trainable parameters.
        log.info(
            "Model size = %i",
            sum(param.numel() for param in self._model.parameters()
                if param.requires_grad))

        # Set dataloaders
        self._dataset.create_dataloaders(
            self._model,
            self._cfg.training.batch_size,
            self._cfg.training.shuffle,
            self._cfg.training.num_workers,
            self.precompute_multi_scale,
        )
        log.info(self._dataset)

        # Verify attributes in dataset.
        # NOTE(review): assumes a non-empty train dataset even for eval-only
        # configs - confirm.
        self._model.verify_data(self._dataset.train_dataset[0])

        # Choose selection stage
        selection_stage = getattr(self._cfg, "selection_stage", "")
        self._checkpoint.selection_stage = self._dataset.resolve_saving_stage(
            selection_stage)
        self._tracker: BaseTracker = self._dataset.get_tracker(
            self.wandb_log, self.tensorboard_log)

        if self.wandb_log:
            Wandb.launch(self._cfg, not self._cfg.wandb.public
                         and self.wandb_log)

        # Run training / evaluation
        self._model = self._model.to(self._device)
        # ``_visualizer`` only exists when visualization is configured; every
        # use elsewhere must be guarded by ``has_visualization``.
        if self.has_visualization:
            self._visualizer = Visualizer(self._cfg.visualization,
                                          self._dataset.num_batches,
                                          self._dataset.batch_size,
                                          os.getcwd())

    def train(self):
        """Train from the checkpoint's start epoch up to ``training.epochs``.

        Every ``eval_frequency`` epochs the val and test stages are evaluated
        when the dataset has them. When the checkpoint's start epoch is
        already past ``training.epochs`` (resume case), no training is run and
        only a single test evaluation is done at the checkpoint's epoch.

        Returns:
            0 when profiling stops after the first training epoch, else None.
        """
        self._is_training = True

        for epoch in range(self._checkpoint.start_epoch,
                           self._cfg.training.epochs):
            log.info("EPOCH %i / %i", epoch, self._cfg.training.epochs)

            self._train_epoch(epoch)

            if self.profiling:
                return 0

            if epoch % self.eval_frequency != 0:
                continue

            if self._dataset.has_val_loader:
                self._test_epoch(epoch, "val")

            if self._dataset.has_test_loaders:
                self._test_epoch(epoch, "test")

        # Single test evaluation in resume case. The loop above never ran in
        # this case, so the loop variable ``epoch`` is unbound; evaluate at
        # the checkpoint's epoch instead.
        if self._checkpoint.start_epoch > self._cfg.training.epochs:
            if self._dataset.has_test_loaders:
                self._test_epoch(self._checkpoint.start_epoch, "test")

    def eval(self, stage_name=""):
        """Evaluate at the checkpoint's epoch without training.

        Args:
            stage_name: "val" or "test" to restrict evaluation to that stage;
                empty (default) evaluates every available stage.
        """
        self._is_training = False

        epoch = self._checkpoint.start_epoch
        if self._dataset.has_val_loader:
            if not stage_name or stage_name == "val":
                self._test_epoch(epoch, "val")

        if self._dataset.has_test_loaders:
            if not stage_name or stage_name == "test":
                self._test_epoch(epoch, "test")

    def _finalize_epoch(self, epoch):
        """Finalise the tracker; while training also publish and checkpoint."""
        self._tracker.finalise(**self.tracker_options)
        if self._is_training:
            metrics = self._tracker.publish(epoch)
            self._checkpoint.save_best_models_under_current_metrics(
                self._model, metrics, self._tracker.metric_func)
            if self.wandb_log and self._cfg.wandb.public:
                Wandb.add_file(self._checkpoint.checkpoint_path)
            if self._tracker._stage == "train":
                # Lazy logging args instead of eager ``%`` formatting.
                log.info("Learning rate = %f", self._model.learning_rate)

    def _train_epoch(self, epoch: int):
        """Run one training epoch over the dataset's train dataloader.

        Returns:
            0 when profiling stops the epoch early (the epoch is then not
            finalised), else None.
        """
        self._model.train()
        self._tracker.reset("train")
        # The visualizer only exists when visualization is configured.
        if self.has_visualization:
            self._visualizer.reset(epoch, "train")
        train_loader = self._dataset.train_dataloader

        iter_data_time = time.time()
        with Ctq(train_loader) as tq_train_loader:
            for i, data in enumerate(tq_train_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                self._model.set_input(data, self._device)
                self._model.optimize_parameters(epoch,
                                                self._dataset.batch_size)
                # Tracking is only refreshed every 10 batches.
                if i % 10 == 0:
                    with torch.no_grad():
                        self._tracker.track(self._model,
                                            data=data,
                                            **self.tracker_options)

                tq_train_loader.set_postfix(**self._tracker.get_metrics(),
                                            data_loading=float(t_data),
                                            iteration=float(time.time() -
                                                            iter_start_time),
                                            color=COLORS.TRAIN_COLOR)

                if self.has_visualization and self._visualizer.is_active:
                    self._visualizer.save_visuals(
                        self._model.get_current_visuals())

                iter_data_time = time.time()

                if self.early_break:
                    break

                if self.profiling:
                    if i > self.num_batches:
                        return 0

        self._finalize_epoch(epoch)

    def _test_epoch(self, epoch, stage_name: str):
        """Evaluate every loader of the "test" stage, or the val loader.

        Each loader is run ``voting_runs`` times (MC-dropout voting when
        ``enable_dropout`` is set). Loaders whose dataset has no labels are
        skipped unless ``tracker_options.make_submission`` is set.

        Returns:
            0 when profiling stops early, else None.
        """
        voting_runs = self._cfg.get("voting_runs", 1)
        if stage_name == "test":
            loaders = self._dataset.test_dataloaders
        else:
            loaders = [self._dataset.val_dataloader]

        self._model.eval()
        if self.enable_dropout:
            self._model.enable_dropout_in_eval()

        for loader in loaders:
            loader_stage = loader.dataset.name
            self._tracker.reset(loader_stage)
            if self.has_visualization:
                self._visualizer.reset(epoch, loader_stage)
            if not self._dataset.has_labels(
                    loader_stage) and not self.tracker_options.get(
                        "make_submission",
                        False):  # No label, no submission -> do nothing
                log.warning("No forward will be run on dataset %s." %
                            loader_stage)
                continue

            for _ in range(voting_runs):
                with Ctq(loader) as tq_loader:
                    # Enumerate batches so the profiling cut-off below counts
                    # batches, not voting runs.
                    for i, data in enumerate(tq_loader):
                        with torch.no_grad():
                            self._model.set_input(data, self._device)
                            with torch.cuda.amp.autocast(
                                    enabled=self._model.is_mixed_precision()):
                                self._model.forward(epoch=epoch)
                            self._tracker.track(self._model,
                                                data=data,
                                                **self.tracker_options)
                        tq_loader.set_postfix(**self._tracker.get_metrics(),
                                              color=COLORS.TEST_COLOR)

                        if self.has_visualization and self._visualizer.is_active:
                            self._visualizer.save_visuals(
                                self._model.get_current_visuals())

                        if self.early_break:
                            break

                        if self.profiling:
                            if i > self.num_batches:
                                return 0

            self._finalize_epoch(epoch)
            self._tracker.print_summary()

    @property
    def early_break(self):
        """True when ``debugging.early_break`` is set and we are training."""
        return getattr(self._cfg.debugging, "early_break",
                       False) and self._is_training

    @property
    def profiling(self):
        return getattr(self._cfg.debugging, "profiling", False)

    @property
    def num_batches(self):
        """Batch-index threshold used by the profiling early exit."""
        return getattr(self._cfg.debugging, "num_batches", 50)

    @property
    def enable_cudnn(self):
        return getattr(self._cfg.training, "enable_cudnn", True)

    @property
    def enable_dropout(self):
        return getattr(self._cfg, "enable_dropout", True)

    @property
    def has_visualization(self):
        """The ``visualization`` config section (truthy) or False."""
        return getattr(self._cfg, "visualization", False)

    @property
    def has_tensorboard(self):
        return getattr(self._cfg, "tensorboard", False)

    @property
    def has_training(self):
        """The ``training`` config section (truthy) or None."""
        return getattr(self._cfg, "training", None)

    @property
    def precompute_multi_scale(self):
        return self._model.conv_type == "PARTIAL_DENSE" and getattr(
            self._cfg.training, "precompute_multi_scale", False)

    @property
    def wandb_log(self):
        if getattr(self._cfg, "wandb", False):
            return getattr(self._cfg.wandb, "log", False)
        else:
            return False

    @property
    def tensorboard_log(self):
        if self.has_tensorboard:
            return getattr(self._cfg.tensorboard, "log", False)
        else:
            return False

    @property
    def tracker_options(self):
        return self._cfg.get("tracker_options", {})

    @property
    def eval_frequency(self):
        return self._cfg.get("eval_frequency", 1)
Exemplo n.º 12
0
def main(cfg):
    """Set up device, checkpoint, dataset, model, tracker and visualizer from
    ``cfg``, then hand them to ``run``.

    The setup steps are order-dependent (config mutation for profiling and
    the wandb launch happen before the checkpoint/model are created; the
    model is passed to ``create_dataloaders``), so keep the sequence.

    Returns:
        0 after ``run`` completes and hydra's global state is cleared.
    """
    OmegaConf.set_struct(
        cfg,
        False)  # This allows getattr and hasattr methods to function correctly
    if cfg.pretty_print:
        print(cfg.pretty())

    # Get device.
    # NOTE(review): uses the truthiness of ``cfg.training.cuda``; if this is a
    # device index (as in the Trainer, which compares it with -1), index 0
    # would select the CPU - confirm the expected type.
    device = torch.device("cuda" if (
        torch.cuda.is_available() and cfg.training.cuda) else "cpu")
    log.info("DEVICE : {}".format(device))

    # Enable CUDNN BACKEND
    torch.backends.cudnn.enabled = cfg.training.enable_cudnn

    # Profiling
    profiling = getattr(cfg.debugging, "profiling", False)
    if profiling:
        # Set the num_workers as torch.utils.bottleneck doesn't work well with it
        cfg.training.num_workers = 0

    # Start Wandb if public (the non-public launch happens after the tracker
    # is created, with the complementary flag).
    launch_wandb(cfg, cfg.wandb.public and cfg.wandb.log)

    # Checkpoint
    checkpoint = ModelCheckpoint(
        cfg.training.checkpoint_dir,
        cfg.model_name,
        cfg.training.weight_name,
        run_config=cfg,
        resume=bool(cfg.training.checkpoint_dir),
    )

    # Create model and datasets: from the checkpoint when it has content,
    # otherwise fresh from the config (with new optimizers).
    if not checkpoint.is_empty:
        dataset = instantiate_dataset(checkpoint.data_config)
        model = checkpoint.create_model(dataset,
                                        weight_name=cfg.training.weight_name)
    else:
        dataset = instantiate_dataset(cfg.data)
        model = instantiate_model(cfg, dataset)
        model.instantiate_optimizers(cfg)
    log.info(model)
    model.log_optimizers()
    # Number of trainable parameters.
    log.info(
        "Model size = %i",
        sum(param.numel() for param in model.parameters()
            if param.requires_grad))

    # Set dataloaders
    dataset.create_dataloaders(
        model,
        cfg.training.batch_size,
        cfg.training.shuffle,
        cfg.training.num_workers,
        cfg.training.precompute_multi_scale,
    )
    log.info(dataset)

    # Choose selection stage
    selection_stage = getattr(cfg, "selection_stage", "")
    checkpoint.selection_stage = dataset.resolve_saving_stage(selection_stage)
    tracker: BaseTracker = dataset.get_tracker(model, dataset, cfg.wandb.log,
                                               cfg.tensorboard.log)

    launch_wandb(cfg, not cfg.wandb.public and cfg.wandb.log)

    # Run training / evaluation
    model = model.to(device)
    visualizer = Visualizer(cfg.visualization, dataset.num_batches,
                            dataset.batch_size, os.getcwd())
    run(cfg, model, dataset, device, tracker, checkpoint, visualizer)

    # https://github.com/facebookresearch/hydra/issues/440
    hydra._internal.hydra.GlobalHydra.get_state().clear()
    return 0