Example #1
def test_new_events():
    def update(*args, **kwargs):
        pass

    engine = Engine(update)
    cpe = CustomPeriodicEvent(n_iterations=5)
    cpe.attach(engine)

    assert hasattr(cpe, "Events")
    assert hasattr(cpe.Events, "ITERATIONS_5_STARTED")
    assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED")

    assert engine._allowed_events[-2] == getattr(cpe.Events,
                                                 "ITERATIONS_5_STARTED")
    assert engine._allowed_events[-1] == getattr(cpe.Events,
                                                 "ITERATIONS_5_COMPLETED")

    cpe = CustomPeriodicEvent(n_epochs=5)
    cpe.attach(engine)

    assert hasattr(cpe, "Events")
    assert hasattr(cpe.Events, "EPOCHS_5_STARTED")
    assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED")

    assert engine._allowed_events[-2] == getattr(cpe.Events,
                                                 "EPOCHS_5_STARTED")
    assert engine._allowed_events[-1] == getattr(cpe.Events,
                                                 "EPOCHS_5_COMPLETED")
Example #2
def test_integration_epochs():
    def update(*args, **kwargs):
        pass

    engine = Engine(update)

    n_epochs = 3
    cpe = CustomPeriodicEvent(n_epochs=n_epochs)
    cpe.attach(engine)
    data = list(range(16))

    custom_period = [1]

    @engine.on(cpe.Events.EPOCHS_3_STARTED)
    def on_my_epoch_started(engine):
        assert (engine.state.epoch - 1) % n_epochs == 0
        assert engine.state.epochs_3 == custom_period[0]

    @engine.on(cpe.Events.EPOCHS_3_COMPLETED)
    def on_my_epoch_ended(engine):
        assert engine.state.epoch % n_epochs == 0
        assert engine.state.epochs_3 == custom_period[0]
        custom_period[0] += 1

    engine.run(data, max_epochs=10)

    assert custom_period[0] == 4
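Since CustomPeriodicEvent is deprecated (see the warning matched in Example #3), recent ignite versions express the same periodicity with built-in event filtering. Below is a rough equivalent of the completed-event handler above, assuming an ignite version that supports event filters; note that no engine.state.epochs_3 counter is created this way.

from ignite.engine import Engine, Events

def update(*args, **kwargs):
    pass

engine = Engine(update)
data = list(range(16))

@engine.on(Events.EPOCH_COMPLETED(every=3))
def every_third_epoch(engine):
    # Fires at epochs 3, 6 and 9, matching EPOCHS_3_COMPLETED above.
    assert engine.state.epoch % 3 == 0

engine.run(data, max_epochs=10)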
Example #3
def test_new_events():
    def update(*args, **kwargs):
        pass

    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
        engine = Engine(update)
        cpe = CustomPeriodicEvent(n_iterations=5)
        cpe.attach(engine)

        assert hasattr(cpe, "Events")
        assert hasattr(cpe.Events, "ITERATIONS_5_STARTED")
        assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED")

        assert engine._allowed_events[-2] == getattr(cpe.Events, "ITERATIONS_5_STARTED")
        assert engine._allowed_events[-1] == getattr(cpe.Events, "ITERATIONS_5_COMPLETED")

    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
        cpe = CustomPeriodicEvent(n_epochs=5)
        cpe.attach(engine)

        assert hasattr(cpe, "Events")
        assert hasattr(cpe.Events, "EPOCHS_5_STARTED")
        assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED")

        assert engine._allowed_events[-2] == getattr(cpe.Events, "EPOCHS_5_STARTED")
        assert engine._allowed_events[-1] == getattr(cpe.Events, "EPOCHS_5_COMPLETED")
Example #4
    def _test(n_iterations, max_epochs, n_iters_per_epoch):
        def update(*args, **kwargs):
            pass

        engine = Engine(update)
        with pytest.warns(DeprecationWarning,
                          match="CustomPeriodicEvent is deprecated"):
            cpe = CustomPeriodicEvent(n_iterations=n_iterations)
            cpe.attach(engine)
        data = list(range(n_iters_per_epoch))

        custom_period = [0]
        n_calls_iter_started = [0]
        n_calls_iter_completed = [0]

        event_started = getattr(cpe.Events,
                                "ITERATIONS_{}_STARTED".format(n_iterations))

        @engine.on(event_started)
        def on_my_event_started(engine):
            assert (engine.state.iteration - 1) % n_iterations == 0
            custom_period[0] += 1
            custom_iter = getattr(engine.state,
                                  "iterations_{}".format(n_iterations))
            assert custom_iter == custom_period[0]
            n_calls_iter_started[0] += 1

        event_completed = getattr(
            cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations))

        @engine.on(event_completed)
        def on_my_event_ended(engine):
            assert engine.state.iteration % n_iterations == 0
            custom_iter = getattr(engine.state,
                                  "iterations_{}".format(n_iterations))
            assert custom_iter == custom_period[0]
            n_calls_iter_completed[0] += 1

        engine.run(data, max_epochs=max_epochs)

        n = len(data) * max_epochs / n_iterations
        nf = math.floor(n)
        assert custom_period[0] == n_calls_iter_started[0]
        assert n_calls_iter_started[0] == (nf + 1 if nf < n else nf)
        assert n_calls_iter_completed[0] == nf
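The helper above only defines the checks; the enclosing test (not shown in this excerpt) is expected to call it. The parameter values below are hypothetical, chosen purely to illustrate the call and to cover both the partial-period and exact-division cases.

    # Hypothetical combinations; the original test presumably sweeps several.
    _test(n_iterations=5, max_epochs=3, n_iters_per_epoch=16)
    _test(n_iterations=5, max_epochs=3, n_iters_per_epoch=15)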
Example #5
    def get_loops(self):
        if self.mode == l.TRAIN:
            train_loop = get_loop(self, self.mode)
            val_loop = get_loop(self, l.VAL)
            visualize_loop = get_loop(self, l.VISUALIZE)

            train_config = self.metric_config[l.TRAIN]
            train_config[l.MODE] = l.TRAIN

            val_config = self.metric_config[l.VAL]
            val_config[l.MODE] = l.VAL

            visualize_config = self.metric_config[l.VISUALIZE]

            self.criterion_chain.attach(train_loop.engine, train_config)
            self.criterion_chain.attach(val_loop.engine, val_config)

            self.model.attach(train_loop.engine, train_config)
            self.model.attach(val_loop.engine, val_config)

            # TODO. Add to the heart of the container.  Calculate only during test with measure_time flag
            # AveragePeriodicMetric(TimeTransformer()).attach(engine, n.FORWARD_TIME)

            train_loss_event = CustomPeriodicEvent(
                n_iterations=train_config[exp.LOSS_LOG_ITER])
            train_metric_event = CustomPeriodicEvent(
                n_iterations=train_config[exp.METRIC_LOG_ITER])
            val_event = CustomPeriodicEvent(
                n_iterations=val_config[exp.LOG_ITER])
            visualize_event = CustomPeriodicEvent(
                n_epochs=visualize_config[exp.LOG_EPOCH])

            train_loss_event.attach(train_loop.engine)
            train_metric_event.attach(train_loop.engine)
            val_event.attach(train_loop.engine)
            visualize_event.attach(train_loop.engine)

            @train_loop.engine.on(train_loss_event._periodic_event_completed)
            def on_train_loss(engine):
                plot_losses_tensorboard(self.writer, engine, engine, l.TRAIN)

            @train_loop.engine.on(train_metric_event._periodic_event_completed)
            def on_train_metric(engine):
                plot_metrics_tensorboard(self.writer, engine, engine, l.TRAIN)

            @train_loop.engine.on(val_event._periodic_event_completed)
            def on_val(engine):
                val_loop.run(1)

                plot_losses_tensorboard(self.writer, val_loop.engine, engine,
                                        l.VAL)
                plot_metrics_tensorboard(self.writer, val_loop.engine, engine,
                                         l.VAL)

            @train_loop.engine.on(visualize_event._periodic_event_completed)
            def on_visualize(engine):
                visualize_loop.run(1)

                plot_scores(self.writer, engine, visualize_loop.engine,
                            (eu.SCORE1, eu.SCORE2))
                plot_scores(self.writer, engine, visualize_loop.engine,
                            (eu.SAL_SCORE1, eu.SAL_SCORE2))
                plot_scores(self.writer, engine, visualize_loop.engine,
                            (eu.CONF_SCORE1, eu.CONF_SCORE2), True)

                plot_kp_matches(self.writer, engine, visualize_loop.engine,
                                visualize_config[exp.PX_THRESH])
                plot_desc_matches(self.writer, engine, visualize_loop.engine,
                                  visualize_config[exp.PX_THRESH],
                                  DescriptorDistance.INV_COS_SIM)

            return [train_loop, val_loop, visualize_loop]

        elif self.mode == l.TEST:
            test_loop = get_loop(self, self.mode)

            test_config = self.metric_config[l.TEST]
            test_config[l.MODE] = l.TEST
            test_config[du.DATASET_NAME] = list(
                self.dataset_config[self.mode].keys())

            self.model.attach(test_loop.engine, test_config)

            if du.AACHEN in test_config[du.DATASET_NAME]:

                @test_loop.engine.on(Events.ITERATION_COMPLETED)
                def on_iteration_completed(engine):
                    save_aachen_inference(self.dataset_config[self.mode],
                                          engine.state.output)

            else:

                @test_loop.engine.on(Events.EPOCH_COMPLETED)
                def on_epoch(engine):
                    logs = join_logs(engine)

                    print_summary(logs, test_config)
                    test_log_to_csv(self.log_dir, logs, test_config,
                                    self.model_config,
                                    test_config[du.DATASET_NAME])

            return [test_loop]

        elif self.mode == l.ANALYZE:
            analyze_loop = get_loop(self, self.mode)

            return [analyze_loop]

        else:
            return None
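Example #5 registers its handlers through the private _periodic_event_completed attribute rather than the generated cpe.Events members used in Examples #1-#4. A small self-contained sketch of the public-style registration for the same kind of periodic event (the period value and handler body are placeholders):

from ignite.engine import Engine
from ignite.contrib.handlers import CustomPeriodicEvent

def update(*args, **kwargs):
    pass

engine = Engine(update)

period = 100  # hypothetical logging period
cpe = CustomPeriodicEvent(n_iterations=period)
cpe.attach(engine)

# The same event Example #5 reaches via cpe._periodic_event_completed:
completed_event = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(period))

@engine.on(completed_event)
def on_period_completed(engine):
    pass  # e.g. write tensorboard plots here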