# Example #1
# 0
def train(datasets, model, loss, optimizer, meters, args):
    """Assemble a ClassificationTask from the given components and run it.

    Args:
        datasets: dict with "train" and "test" keys mapping to dataset objects.
        model: the model to train.
        loss: loss to optimize.
        optimizer: optimizer instance.
        meters: meters to attach to the task.
        args: parsed CLI args; reads num_epochs, print_freq, skip_tensorboard,
            video_dir, cuda and num_workers.

    Side effects: creates a checkpoint directory under ``args.video_dir`` and
    writes tensorboard logs there (when tensorboardX is available).
    """
    task = (ClassificationTask()
            .set_num_epochs(args.num_epochs)
            .set_loss(loss)
            .set_model(model)
            .set_optimizer(optimizer)
            .set_meters(meters))
    for phase in ["train", "test"]:
        task.set_dataset(datasets[phase], phase)

    hooks = [LossLrMeterLoggingHook(log_freq=args.print_freq)]
    # show progress
    hooks.append(ProgressBarHook())
    if not args.skip_tensorboard:
        try:
            from tensorboardX import SummaryWriter
            tb_writer = SummaryWriter(log_dir=args.video_dir + "/tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            print("tensorboardX not installed, skipping tensorboard hooks")

    checkpoint_dir = f"{args.video_dir}/checkpoint/classy_checkpoint_{time.time()}"
    # makedirs (not mkdir): the intermediate "checkpoint" directory may not
    # exist yet, and mkdir would raise FileNotFoundError in that case.
    os.makedirs(checkpoint_dir, exist_ok=True)
    hooks.append(CheckpointHook(checkpoint_dir, input_args={}))

    task = task.set_hooks(hooks)
    trainer = LocalTrainer(use_gpu=args.cuda, num_dataloader_workers=args.num_workers)
    trainer.train(task)
def configure_hooks(args, config):
    """Build the list of Classy Vision hooks requested via *args*.

    Args:
        args: parsed CLI args; reads log_freq, checkpoint_folder,
            skip_tensorboard, checkpoint_period, profiler, show_progress,
            visdom_server and visdom_port. ``args.checkpoint_folder`` is
            mutated in place when empty.
        config: task config dict, stored alongside the args in checkpoints.

    Returns:
        List of hook instances to attach to the task.

    Side effects: creates the checkpoint folder (and the tensorboard output
    folder when tensorboard is available) on disk.
    """
    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = f"{Path(__file__).parent}/output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder + "/checkpoints"
    # Create the checkpoint folder unconditionally: previously only the
    # default folder was created, so a user-supplied non-existent path would
    # make CheckpointHook fail later.
    os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info(f"Logging outputs to {base_folder}")
    logging.info(f"Logging checkpoints to {args.checkpoint_folder}")

    if not args.skip_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            tensorboard_dir = Path(base_folder) / "tensorboard"
            os.makedirs(tensorboard_dir, exist_ok=True)
            tb_writer = SummaryWriter(log_dir=tensorboard_dir)
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            logging.warning(
                "tensorboard not installed, skipping tensorboard hooks")

    # Persist the full invocation (args + config) with each checkpoint so
    # runs can be reproduced from the checkpoint alone.
    args_dict = vars(args)
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(args.checkpoint_folder,
                       args_dict,
                       checkpoint_period=args.checkpoint_period))

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
    def test_progress_bar(self, mock_is_master: mock.MagicMock,
                          mock_progressbar_pkg: mock.MagicMock) -> None:
        """
        Tests that the progress bar is created, updated and destroyed correctly.

        The two mock parameters are injected by @mock.patch decorators that
        sit outside this view — presumably patching the is_master() helper
        and the `progressbar` package used by ProgressBarHook; verify against
        the decorators on this method.
        """
        # Stand-in for the real progressbar.ProgressBar instance, so the test
        # can assert on start/update/finish calls without drawing anything.
        mock_progress_bar = mock.create_autospec(progressbar.ProgressBar,
                                                 instance=True)
        mock_progressbar_pkg.ProgressBar.return_value = mock_progress_bar

        # Pretend this process is the master so the hook builds the bar.
        mock_is_master.return_value = True

        task = get_test_classy_task()
        task.prepare()
        task.advance_phase()

        num_batches = task.num_batches_per_phase
        # make sure we are checking at least one batch
        self.assertGreater(num_batches, 0)

        # create a progress bar hook
        progress_bar_hook = ProgressBarHook()

        # progressbar.ProgressBar should be init-ed with num_batches
        progress_bar_hook.on_phase_start(task)
        mock_progressbar_pkg.ProgressBar.assert_called_once_with(num_batches)
        mock_progress_bar.start.assert_called_once_with()
        # reset the mocks so later assert_called_once_* checks start fresh
        mock_progress_bar.start.reset_mock()
        mock_progressbar_pkg.ProgressBar.reset_mock()

        # on_step should update the progress bar correctly
        for i in range(num_batches):
            progress_bar_hook.on_step(task)
            mock_progress_bar.update.assert_called_once_with(i + 1)
            mock_progress_bar.update.reset_mock()

        # check that even if on_step is called again, the progress bar is
        # only updated with num_batches (i.e. the count is clamped at the
        # phase length rather than growing past 100%)
        for _ in range(num_batches):
            progress_bar_hook.on_step(task)
            mock_progress_bar.update.assert_called_once_with(num_batches)
            mock_progress_bar.update.reset_mock()

        # finish should be called on the progress bar
        progress_bar_hook.on_phase_end(task)
        mock_progress_bar.finish.assert_called_once_with()
        mock_progress_bar.finish.reset_mock()

        # check that even if the progress bar isn't created, the code doesn't
        # crash (on_step/on_phase_end before on_phase_start must be no-ops)
        progress_bar_hook = ProgressBarHook()
        try:
            progress_bar_hook.on_step(task)
            progress_bar_hook.on_phase_end(task)
        except Exception as e:
            self.fail(
                "Received Exception when on_phase_start() isn't called: {}".
                format(e))
        mock_progressbar_pkg.ProgressBar.assert_not_called()

        # check that a progress bar is not created if is_master() returns False
        mock_is_master.return_value = False
        progress_bar_hook = ProgressBarHook()
        try:
            progress_bar_hook.on_phase_start(task)
            progress_bar_hook.on_step(task)
            progress_bar_hook.on_phase_end(task)
        except Exception as e:
            self.fail(
                "Received Exception when is_master() is False: {}".format(e))
        self.assertIsNone(progress_bar_hook.progress_bar)
        mock_progressbar_pkg.ProgressBar.assert_not_called()