# Example #1
    def train(self, **kwargs):
        """Run the full training loop: build the data pipelines, train each
        epoch, validate periodically, and checkpoint after every epoch."""
        cfg = self.config

        # Training data pipeline (train-mode transforms with cropping).
        train_data = Dataset.create(
            cfg.train.dataset,
            split="train",
            data_root=cfg.system.data_root,
            transforms=self._transforms(
                is_train=True, crop_size=cfg.train.crop_size
            ),
        )
        train_loader = create_loader(
            train_data,
            batch_size=cfg.train.batch_size,
            num_workers=cfg.system.workers,
            dryrun=cfg.system.dryrun,
        )

        # Validation data pipeline (eval-mode transforms, no cropping).
        val_data = Dataset.create(
            cfg.val.dataset,
            split="val",
            data_root=cfg.system.data_root,
            transforms=self._transforms(is_train=False),
        )
        val_loader = create_loader(
            val_data,
            batch_size=cfg.val.batch_size,
            num_workers=cfg.system.workers,
            dryrun=cfg.system.dryrun,
        )

        logger.info("Start training estimator: %s", type(self).__name__)
        self.model.to(self.device)

        total_epochs = cfg.train.epochs
        eval_every = cfg.system.val_interval
        for epoch in range(1, total_epochs + 1):
            logger.info(f"Training Epoch[{epoch}/{total_epochs}]")
            self._train_one_epoch(train_loader, epoch)

            # Validate only every `eval_every` epochs, not after each one.
            if epoch % eval_every == 0:
                self._evaluate_one_epoch(val_loader, epoch)

            # Persist a checkpoint for every epoch regardless of validation.
            self.checkpointer.save(self, epoch=epoch)

        self.writer.close()
    def evaluate(self, **kwargs):
        """Evaluate the model on the test split with a single pass."""
        cfg = self.config

        # Test-split data pipeline (eval-mode transforms only).
        dataset = Dataset.create(
            cfg.test.dataset,
            split="test",
            data_root=cfg.system.data_root,
            transforms=self._NYU_transforms(is_train=False),
        )
        loader = create_loader(
            dataset,
            batch_size=cfg.test.batch_size,
            num_workers=cfg.system.workers,
            dryrun=cfg.system.dryrun,
        )

        logger.info("Start evaluating estimator: %s", type(self).__name__)
        self.model.to(self.device)
        # Single evaluation pass, reported as epoch 1 of 1.
        self._evaluate_one_epoch(loader, 1, 1)
    def train(self, **kwargs):
        """Train the model with gradient accumulation, periodic validation,
        and per-epoch checkpointing.

        Optimizer steps are taken every ``accumulation_steps`` batches; any
        leftover accumulated gradients are flushed at the end of each epoch
        so no batch's contribution is silently discarded.
        """
        # Training parameters
        config = self.config
        optimizer = self.optimizer
        val_interval = config.system.val_interval
        writer = self.writer

        # Load data
        train_dataset = Dataset.create(
            config.train.dataset,
            split="train",
            data_root=config.system.data_root,
            transforms=self._NYU_transforms(is_train=True),
        )

        train_loader = create_loader(
            train_dataset,
            batch_size=config.train.batch_size,
            num_workers=config.system.workers,
            dryrun=config.system.dryrun,
        )

        # NOTE(review): validation deliberately reads the "test" split here —
        # confirm this matches the dataset's intended layout.
        val_dataset = Dataset.create(
            config.val.dataset,
            split="test",
            data_root=config.system.data_root,
            transforms=self._NYU_transforms(is_train=False),
        )

        val_loader = create_loader(
            val_dataset,
            batch_size=config.val.batch_size,
            num_workers=config.system.workers,
            dryrun=config.system.dryrun,
        )

        # Logging
        logger.info("Start training estimator: %s", type(self).__name__)

        self.model.to(self.device)
        n_epochs = config.train.epochs

        # Start training
        for epoch in range(1, n_epochs + 1):
            logger.info(f"Epoch[{epoch}/{n_epochs}] training started.")
            loss_metric = Loss(self._loss_fn)
            self.model.train()
            N = len(train_loader)
            accumulation_steps = self.config.train.accumulation_steps
            optimizer.zero_grad()
            pending = 0  # batches accumulated since the last optimizer step
            for i, (image, depth) in enumerate(train_loader):
                # Prepare sample and depth
                image = image.to(self.device)
                depth_n = depth.to(self.device)

                # Predict
                output = self.model(image)

                # Compute loss
                loss = self._loss_fn(output, depth_n)

                # Backward: gradients accumulate across batches until stepped.
                loss.backward()
                pending += 1

                if (i + 1) % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    pending = 0

                loss_metric.update((output, depth_n))

                # Log progress
                logger.debug(f"[{i}/{N}] Loss: {loss:.4f}")

            # BUGFIX: flush leftover gradients when the batch count is not a
            # multiple of accumulation_steps. Previously the final partial
            # window was never stepped and its gradients were discarded by
            # the zero_grad() at the start of the next epoch.
            if pending:
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss = loss_metric.compute()
            if epoch % val_interval == 0:
                self._evaluate_one_epoch(val_loader, epoch, n_epochs)

            # Record epoch's intermediate results
            writer.add_scalar("Training/Loss", epoch_loss, epoch)
            self.checkpointer.save(self, epoch=epoch)