Example #1
0
def main(local_rank, c10d_backend, rdzv_init_url, max_world_size, classy_args):
    """Entry point for a single elastic-training worker process.

    Builds a classification task from the JSON config, restores any
    checkpoints, attaches logging/checkpoint hooks, then trains through an
    elastic peer-to-peer coordinator.

    Args:
        local_rank: Rank of this worker on the local machine.
        c10d_backend: torch.distributed backend; must be NCCL or GLOO.
        rdzv_init_url: Rendezvous init URL for the elastic coordinator.
        max_world_size: Maximum number of trainers allowed to join.
        classy_args: Parsed command-line arguments for the run.
    """
    torch.manual_seed(0)
    set_video_backend(classy_args.video_backend)

    # Loads config, sets up task
    config = load_json(classy_args.config_file)

    task = build_task(config)

    # Load checkpoint, if available
    checkpoint = load_checkpoint(classy_args.checkpoint_folder)
    task.set_checkpoint(checkpoint)

    # A pretrained checkpoint only makes sense when fine-tuning.
    pretrained_checkpoint = load_checkpoint(classy_args.pretrained_checkpoint_folder)
    if pretrained_checkpoint is not None:
        assert isinstance(
            task, FineTuningTask
        ), "Can only use a pretrained checkpoint for fine tuning tasks"
        task.set_pretrained_checkpoint(pretrained_checkpoint)

    hooks = [
        LossLrMeterLoggingHook(classy_args.log_freq),
        ModelComplexityHook(),
        TimeMetricsHook(),
    ]

    if classy_args.checkpoint_folder != "":
        # NOTE(review): vars() returns classy_args' own __dict__, so adding
        # "config" here also sets classy_args.config as a side effect —
        # confirm this is intentional before copying the dict instead.
        args_dict = vars(classy_args)
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                classy_args.checkpoint_folder,
                args_dict,
                checkpoint_period=classy_args.checkpoint_period,
            )
        )
    if classy_args.profiler:
        hooks.append(ProfilerHook())

    task.set_hooks(hooks)

    # Fix: use one membership test against the in-scope `Backend` instead of
    # two equality checks that mixed `Backend` and `torch.distributed.Backend`.
    assert c10d_backend in (Backend.NCCL, Backend.GLOO)
    if c10d_backend == Backend.NCCL:
        # needed to enable NCCL error handling
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,  # presumably milliseconds — TODO confirm
    )
    trainer = ElasticTrainer(
        use_gpu=classy_args.device == "gpu",
        num_dataloader_workers=classy_args.num_workers,
        local_rank=local_rank,
        elastic_coordinator=coordinator,
        input_args={},
    )
    trainer.train(task)
Example #2
0
def configure_hooks(args, config):
    """Assemble the list of training hooks requested by the CLI arguments.

    Always attaches loss/LR/meter logging, time metrics, and checkpointing;
    optionally adds tensorboard, profiler, progress-bar, and visdom hooks.

    Args:
        args: Parsed command-line arguments. Mutated when
            ``args.checkpoint_folder`` is empty: it is filled in with a
            default path under a timestamped output folder.
        config: Task configuration, stored alongside checkpoint metadata.

    Returns:
        A list of hook instances to install on the task.
    """
    hooks = [LossLrMeterLoggingHook(args.log_freq), TimeMetricsHook()]

    # All run outputs land in a timestamped folder next to this file.
    run_stamp = datetime.now().isoformat()
    base_folder = Path(__file__).parent / f"output_{run_stamp}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder / "checkpoints"
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info(f"Logging outputs to {base_folder.resolve()}")
    logging.info(f"Logging checkpoints to {args.checkpoint_folder}")

    if not args.skip_tensorboard:
        # tensorboardX is optional; fall back to a warning when missing.
        try:
            from tensorboardX import SummaryWriter

            writer = SummaryWriter(log_dir=base_folder / "tensorboard")
            hooks.append(TensorboardPlotHook(writer))
        except ImportError:
            logging.warning(
                "tensorboardX not installed, skipping tensorboard hooks")

    # NOTE(review): vars() returns args' own __dict__, so inserting "config"
    # also sets args.config as a side effect — confirm that is intended.
    args_dict = vars(args)
    args_dict["config"] = config
    checkpoint_hook = CheckpointHook(
        args.checkpoint_folder,
        args_dict,
        checkpoint_period=args.checkpoint_period,
    )
    hooks.append(checkpoint_hook)

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
    def test_time_metrics(
        self,
        mock_get_rank: mock.MagicMock,
        mock_report_str: mock.MagicMock,
        mock_time: mock.MagicMock,
    ) -> None:
        """
        Tests TimeMetricsHook: start time and perf_stats are set up on phase
        start, performance metrics are logged every log_freq steps and on
        phase end, and warnings are logged when on_phase_end() runs without
        a preceding on_phase_start().
        """
        # Fixed rank returned by the mocked get_rank(); it appears in the
        # train-phase log message asserted below.
        rank = 5
        mock_get_rank.return_value = rank

        mock_report_str.return_value = ""
        local_variables = {}

        # Cover periodic logging (log_freq=5) and disabled logging (None),
        # for both the train and test phases.
        for log_freq, train in product([5, None], [True, False]):
            # create a time metrics hook
            time_metrics_hook = TimeMetricsHook(log_freq=log_freq)

            phase_type = "train" if train else "test"

            task = get_test_classy_task()
            task.prepare()
            task.train = train

            # on_phase_start() should set the start time and perf_stats
            start_time = 1.2
            mock_time.return_value = start_time
            time_metrics_hook.on_phase_start(task, local_variables)
            self.assertEqual(time_metrics_hook.start_time, start_time)
            self.assertTrue(
                isinstance(local_variables.get("perf_stats"), PerfStats))

            # test that the code doesn't raise an exception if losses is empty
            try:
                time_metrics_hook.on_phase_end(task, local_variables)
            except Exception as e:
                self.fail("Received Exception when losses is []: {}".format(e))

            # check that _log_performance_metrics() is called after on_step()
            # every log_freq batches and after on_phase_end()
            with mock.patch.object(time_metrics_hook,
                                   "_log_performance_metrics") as mock_fn:
                num_batches = 20

                for i in range(num_batches):
                    task.losses = list(range(i))
                    time_metrics_hook.on_step(task, local_variables)
                    # batch 0 is excluded (`i` must be truthy), so logging
                    # only fires on batches log_freq, 2*log_freq, ...
                    if log_freq is not None and i and i % log_freq == 0:
                        mock_fn.assert_called_with(task, local_variables)
                        mock_fn.reset_mock()
                        continue
                    mock_fn.assert_not_called()

                time_metrics_hook.on_phase_end(task, local_variables)
                mock_fn.assert_called_with(task, local_variables)

            task.losses = [0.23, 0.45, 0.34, 0.67]

            # (10.4 - 1.2) s elapsed over 4 losses -> 2.3 s per batch
            end_time = 10.4
            avg_batch_time_ms = 2.3 * 1000
            mock_time.return_value = end_time

            # test _log_performance_metrics()
            with self.assertLogs() as log_watcher:
                time_metrics_hook._log_performance_metrics(
                    task, local_variables)

            # there should be 2 info logs for train and 1 for test
            self.assertEqual(len(log_watcher.output), 2 if train else 1)
            self.assertTrue(
                all(log_record.levelno == logging.INFO
                    for log_record in log_watcher.records))
            match = re.search(
                (r"Average {} batch time \(ms\) for {} batches: "
                 r"(?P<avg_batch_time>[-+]?\d*\.\d+|\d+)").format(
                     phase_type, len(task.losses)),
                log_watcher.output[0],
            )
            self.assertIsNotNone(match)
            self.assertAlmostEqual(avg_batch_time_ms,
                                   float(match.group("avg_batch_time")),
                                   places=4)
            if train:
                self.assertIn(f"Train step time breakdown (rank {rank})",
                              log_watcher.output[1])

            # if on_phase_start() is not called, 2 warnings should be logged
            # create a new time metrics hook
            local_variables = {}
            time_metrics_hook_new = TimeMetricsHook()

            with self.assertLogs() as log_watcher:
                time_metrics_hook_new.on_phase_end(task, local_variables)

            self.assertEqual(len(log_watcher.output), 2)
            self.assertTrue(
                all(log_record.levelno == logging.WARN
                    for log_record in log_watcher.records))