def train(datasets, model, loss, optimizer, meters, args):
    """Build a ClassificationTask from the given components and train it locally.

    Args:
        datasets: mapping with "train" and "test" entries; each is passed to
            ``task.set_dataset`` for the phase of the same name.
        model: model handed to ``task.set_model``.
        loss: loss handed to ``task.set_loss``.
        optimizer: optimizer handed to ``task.set_optimizer``.
        meters: meters handed to ``task.set_meters``.
        args: parsed options; this function reads ``num_epochs``,
            ``print_freq``, ``skip_tensorboard``, ``video_dir``, ``cuda`` and
            ``num_workers``.

    Side effects:
        Creates ``{video_dir}/checkpoint/classy_checkpoint_<timestamp>``
        (including any missing parent directories) and runs training via
        ``LocalTrainer``.
    """
    task = (
        ClassificationTask()
        .set_num_epochs(args.num_epochs)
        .set_loss(loss)
        .set_model(model)
        .set_optimizer(optimizer)
        .set_meters(meters)
    )
    for phase in ["train", "test"]:
        task.set_dataset(datasets[phase], phase)

    hooks = [LossLrMeterLoggingHook(log_freq=args.print_freq)]
    # show progress
    hooks.append(ProgressBarHook())

    if not args.skip_tensorboard:
        try:
            from tensorboardX import SummaryWriter

            # f-string works whether video_dir is a str or a pathlib.Path
            tb_writer = SummaryWriter(log_dir=f"{args.video_dir}/tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            print("tensorboardX not installed, skipping tensorboard hooks")

    checkpoint_dir = f"{args.video_dir}/checkpoint/classy_checkpoint_{time.time()}"
    # makedirs (rather than mkdir) also creates the intermediate "checkpoint"
    # directory, which otherwise raises FileNotFoundError on a first run.
    os.makedirs(checkpoint_dir, exist_ok=True)
    hooks.append(CheckpointHook(checkpoint_dir, input_args={}))

    task = task.set_hooks(hooks)

    trainer = LocalTrainer(use_gpu=args.cuda, num_dataloader_workers=args.num_workers)
    trainer.train(task)
def configure_hooks(args, config):
    """Build the list of training hooks described by the command-line arguments.

    Args:
        args: parsed arguments; this function reads ``log_freq``,
            ``checkpoint_folder``, ``skip_tensorboard``, ``checkpoint_period``,
            ``profiler``, ``show_progress``, ``visdom_server`` and
            ``visdom_port``.
        config: task configuration; stored alongside the arguments in the
            dict handed to ``CheckpointHook``.

    Returns:
        The list of hook instances, in registration order.

    Side effects:
        When ``args.checkpoint_folder`` is empty, it is set to a fresh
        ``output_<timestamp>/checkpoints`` folder next to this file, and that
        folder is created. When tensorboard logging is enabled and
        ``torch.utils.tensorboard`` is importable, a ``tensorboard`` folder is
        created under the same output folder.
    """
    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = f"{Path(__file__).parent}/output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder + "/checkpoints"
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info("Logging outputs to %s", base_folder)
    logging.info("Logging checkpoints to %s", args.checkpoint_folder)

    if not args.skip_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            tb_folder = Path(base_folder) / "tensorboard"
            os.makedirs(tb_folder, exist_ok=True)
            tb_writer = SummaryWriter(log_dir=tb_folder)
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            logging.warning("tensorboard not installed, skipping tensorboard hooks")

    # vars() returns the live args.__dict__; copy it so adding "config" does
    # not mutate the caller's namespace, and later edits to args do not leak
    # into what the checkpoint hook records.
    args_dict = dict(vars(args))
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(
            args.checkpoint_folder, args_dict, checkpoint_period=args.checkpoint_period
        )
    )

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
def test_progress_bar(
    self, mock_is_master: mock.MagicMock, mock_progressbar_pkg: mock.MagicMock
) -> None:
    """Exercise the full ProgressBarHook lifecycle against a mocked progressbar.

    Checks that:
      * on_phase_start builds the bar with the batch count and starts it;
      * on_step advances the bar by one per call, clamped at the batch count;
      * on_phase_end finishes the bar;
      * on_step / on_phase_end without a prior on_phase_start do not raise
        and do not build a bar;
      * when is_master() is False no bar is ever built.
    """
    fake_bar = mock.create_autospec(progressbar.ProgressBar, instance=True)
    bar_factory = mock_progressbar_pkg.ProgressBar
    bar_factory.return_value = fake_bar
    mock_is_master.return_value = True

    task = get_test_classy_task()
    task.prepare()
    task.advance_phase()

    batch_count = task.num_batches_per_phase
    # the update checks below are vacuous unless there is at least one batch
    self.assertGreater(batch_count, 0)

    hook = ProgressBarHook()

    # Phase start: bar is constructed with the batch count, then started.
    hook.on_phase_start(task)
    bar_factory.assert_called_once_with(batch_count)
    fake_bar.start.assert_called_once_with()
    fake_bar.start.reset_mock()
    bar_factory.reset_mock()

    # Steps 1..batch_count advance the bar one at a time; a further
    # batch_count steps must keep reporting batch_count (clamped).
    expected_positions = [step + 1 for step in range(batch_count)]
    expected_positions += [batch_count] * batch_count
    for position in expected_positions:
        hook.on_step(task)
        fake_bar.update.assert_called_once_with(position)
        fake_bar.update.reset_mock()

    # Phase end: the bar is finished exactly once.
    hook.on_phase_end(task)
    fake_bar.finish.assert_called_once_with()
    fake_bar.finish.reset_mock()

    # A fresh hook that never saw on_phase_start must tolerate the remaining
    # callbacks without raising, and must not build a bar.
    hook = ProgressBarHook()
    try:
        hook.on_step(task)
        hook.on_phase_end(task)
    except Exception as exc:
        self.fail(
            "Received Exception when on_phase_start() isn't called: {}".format(exc)
        )
    bar_factory.assert_not_called()

    # On a non-master rank no bar is created, even across a whole phase.
    mock_is_master.return_value = False
    hook = ProgressBarHook()
    try:
        for callback in (hook.on_phase_start, hook.on_step, hook.on_phase_end):
            callback(task)
    except Exception as exc:
        self.fail("Received Exception when is_master() is False: {}".format(exc))
    self.assertIsNone(hook.progress_bar)
    bar_factory.assert_not_called()