def configure_hooks(args, config):
    """Build the list of hooks selected by the command-line arguments.

    Args:
        args: Parsed argument namespace. This function reads ``log_freq``,
            ``checkpoint_folder``, ``skip_tensorboard``,
            ``checkpoint_period``, ``profiler``, ``show_progress``,
            ``visdom_server`` and ``visdom_port``. When
            ``checkpoint_folder`` is the empty string it is set in place to
            a newly created ``checkpoints`` folder under a timestamped
            output folder next to this file.
        config: Stored under the ``"config"`` key of the dict handed to
            ``CheckpointHook`` together with a copy of the arguments.

    Returns:
        The list of hook instances, in the order they were added.
    """
    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = f"{Path(__file__).parent}/output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder + "/checkpoints"
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info(f"Logging outputs to {base_folder}")
    logging.info(f"Logging checkpoints to {args.checkpoint_folder}")

    if not args.skip_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            tensorboard_folder = Path(base_folder) / "tensorboard"
            os.makedirs(tensorboard_folder, exist_ok=True)
            tb_writer = SummaryWriter(log_dir=tensorboard_folder)
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            # Tensorboard is optional: keep going without its hook.
            logging.warning(
                "tensorboard not installed, skipping tensorboard hooks")

    # vars() returns the namespace's live __dict__; copy it so that adding
    # the "config" entry does not silently add an attribute to ``args``.
    args_dict = dict(vars(args))
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(args.checkpoint_folder,
                       args_dict,
                       checkpoint_period=args.checkpoint_period))

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
Beispiel #2
0
    def test_visdom(self, mock_visdom_cls: mock.MagicMock,
                    mock_is_primary: mock.MagicMock) -> None:
        """
        Tests that visdom is populated with plots.

        For every combination of the mocked "is primary" flag and the mocked
        visdom connection state, a new VisdomHook is created and driven
        through ten alternating train / test phases. After each phase the
        hook's recorded metrics are checked, and visdom.line() must have been
        called exactly once only when a test phase ends on the primary with a
        working connection.
        """

        def assert_metric_series(hook, key, expected_len):
            # Separate assertions so a failure names the condition that broke
            # (missing key, wrong container type, or wrong length).
            self.assertIn(key, hook.metrics)
            self.assertIsInstance(hook.metrics[key], list)
            self.assertEqual(len(hook.metrics[key]), expected_len)

        mock_visdom = mock.create_autospec(Visdom, instance=True)
        mock_visdom_cls.return_value = mock_visdom

        # set up the task and state
        config = get_test_task_config()
        config["dataset"]["train"]["batchsize_per_replica"] = 2
        config["dataset"]["test"]["batchsize_per_replica"] = 5
        task = build_task(config)
        task.prepare()

        losses = [1.2, 2.3, 1.23, 2.33]
        loss_val = sum(losses) / len(losses)

        task.losses = losses

        visdom_server = "localhost"
        visdom_port = 8097

        for is_primary, visdom_conn in product([False, True], [False, True]):
            mock_is_primary.return_value = is_primary
            mock_visdom.check_connection.return_value = visdom_conn

            # create a visdom hook
            visdom_hook = VisdomHook(visdom_server, visdom_port)

            mock_visdom_cls.assert_called_once()
            mock_visdom_cls.reset_mock()

            # number of phases seen so far, per phase type and in total
            counts = {"train": 0, "test": 0}
            count = 0

            for phase_idx in range(10):
                # even phases are train phases, odd phases are test phases
                train = phase_idx % 2 == 0
                task.train = train
                phase_type = "train" if train else "test"

                counts[phase_type] += 1
                count += 1

                # test that the metrics don't change if losses is empty and that
                # visdom.line() is not called
                task.losses = []
                original_metrics = copy.deepcopy(visdom_hook.metrics)
                visdom_hook.on_phase_end(task)
                self.assertDictEqual(original_metrics, visdom_hook.metrics)
                mock_visdom.line.assert_not_called()

                # test that the metrics are updated correctly when losses
                # is non empty; scaling by count makes each phase's expected
                # loss distinct
                task.losses = [loss * count for loss in losses]
                visdom_hook.on_phase_end(task)

                # every meter should be present and should have the correct length
                for meter in task.meters:
                    for value_key in meter.value:
                        metric_key = f"{phase_type}_{meter.name}_{value_key}"
                        assert_metric_series(visdom_hook, metric_key,
                                             counts[phase_type])

                # the loss metric should be calculated correctly
                loss_key = phase_type + "_loss"
                assert_metric_series(visdom_hook, loss_key,
                                     counts[phase_type])
                self.assertAlmostEqual(
                    visdom_hook.metrics[loss_key][-1],
                    loss_val * count,
                    places=4,
                )

                # the lr metric should be correct
                lr_key = phase_type + "_learning_rate"
                assert_metric_series(visdom_hook, lr_key, counts[phase_type])
                self.assertAlmostEqual(
                    visdom_hook.metrics[lr_key][-1],
                    task.optimizer.options_view.lr,
                    places=4,
                )

                if is_primary and not train and visdom_conn:
                    # visdom.line() should be called once
                    mock_visdom.line.assert_called_once()
                    mock_visdom.line.reset_mock()
                else:
                    # visdom.line() should not be called
                    mock_visdom.line.assert_not_called()