# Assumed imports: Engine, Events and Timer come from pytorch-ignite (Events is aliased
# to EngineEvents to match the references below); EndOfEpisodeHandler is assumed to be
# defined in the same ptan-style module as this handler.
from ignite.engine import Engine, Events as EngineEvents
from ignite.handlers import Timer


class EpisodeFPSHandler:
    def __init__(self):
        self._timer = Timer(average=True)

    def attach(self, engine: Engine):
        self._timer.attach(engine, step=EngineEvents.ITERATION_COMPLETED)
        engine.add_event_handler(EndOfEpisodeHandler.Events.EPISODE_COMPLETED, self)

    def __call__(self, engine: Engine):
        t_val = self._timer.value()
        if engine.state.iteration == 1:
            # The first iteration includes setup overhead, so discard its timing.
            self._timer.reset()
        else:
            # Timer(average=True) reports the mean seconds per step.
            engine.state.metrics['fps'] = 1. / t_val
        engine.state.metrics['time_passed'] = t_val * self._timer.step_count
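
A minimal usage sketch for the handler above, assuming a ptan-style setup: EndOfEpisodeHandler is the companion handler from the same module that fires EPISODE_COMPLETED, and process_batch, exp_source and batches are placeholders, not part of the original example.

def process_batch(engine, batch):
    return 0.0  # placeholder training step

engine = Engine(process_batch)
EndOfEpisodeHandler(exp_source).attach(engine)  # registers and fires EPISODE_COMPLETED
EpisodeFPSHandler().attach(engine)
engine.run(batches)
# After each completed episode, engine.state.metrics contains 'fps' and 'time_passed'.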
Example #2
# Assumed imports: time from the standard library; Engine, Events and Timer from
# pytorch-ignite; EpisodeEvents is assumed to come from the surrounding ptan-style module.
import time

from ignite.engine import Engine, Events
from ignite.handlers import Timer


class EpisodeFPSHandler:
    FPS_METRIC = 'fps'
    AVG_FPS_METRIC = 'avg_fps'
    TIME_PASSED_METRIC = 'time_passed'

    def __init__(self, fps_mul: float = 1.0, fps_smooth_alpha: float = 0.98):
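        # fps_mul rescales the raw FPS value (e.g. to account for environment frame skip),
        # while fps_smooth_alpha is the weight of the exponential moving average of avg_fps.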
        self._timer = Timer(average=True)
        self._fps_mul = fps_mul
        self._started_ts = time.time()
        self._fps_smooth_alpha = fps_smooth_alpha

    def attach(self, engine: Engine, manual_step: bool = False):
        self._timer.attach(
            engine, step=None if manual_step else Events.ITERATION_COMPLETED)
        engine.add_event_handler(EpisodeEvents.EPISODE_COMPLETED, self)

    def step(self):
        """
        If manual_step=True on attach(), this method should be used every time we've communicated with environment
        to get proper FPS
        :return:
        """
        self._timer.step()

    def __call__(self, engine: Engine):
        t_val = self._timer.value()
        if engine.state.iteration > 1:
            fps = self._fps_mul / t_val
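            # Exponential moving average: avg_fps = alpha * avg_fps + (1 - alpha) * fps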
            avg_fps = engine.state.metrics.get(self.AVG_FPS_METRIC)
            if avg_fps is None:
                avg_fps = fps
            else:
                avg_fps *= self._fps_smooth_alpha
                avg_fps += (1 - self._fps_smooth_alpha) * fps
            engine.state.metrics[self.AVG_FPS_METRIC] = avg_fps
            engine.state.metrics[self.FPS_METRIC] = fps
        engine.state.metrics[
            self.TIME_PASSED_METRIC] = time.time() - self._started_ts
        self._timer.reset()
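
A minimal usage sketch for the manual-step variant above, under the same assumptions: EpisodeEvents and EndOfEpisodeHandler come from the surrounding module, and process_batch / exp_source are placeholders.

engine = Engine(process_batch)
EndOfEpisodeHandler(exp_source).attach(engine)  # fires EpisodeEvents.EPISODE_COMPLETED
fps_handler = EpisodeFPSHandler(fps_mul=1.0)
fps_handler.attach(engine, manual_step=True)

@engine.on(Events.ITERATION_COMPLETED)
def advance_fps_timer(engine):
    # With manual_step=True the timer only advances when we call step() ourselves,
    # e.g. once per environment interaction.
    fps_handler.step()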
Example #3
# Assumed imports: the standard library and third-party pieces used below; project-local
# helpers (get_checkpoint_file_path, build_experiment, build_evaluators, load_checkpoint,
# save_checkpoint, clear_checkpoint, seed, Logger, TIME_BUFFER) are assumed to be defined
# elsewhere in the repository.
from datetime import datetime

import torch
from ignite.engine import Events, create_supervised_trainer
from ignite.handlers import Timer


def train(data,
          model,
          optimizer,
          model_seed=1,
          sampler_seed=1,
          max_epochs=120,
          patience=None,
          stopping_rule=None,
          compute_test_error_rates=False,
          loading_file_path=None,
          callback=None):

    # Checkpointing file path is named based on Mahler task ID
    checkpointing_file_path = get_checkpoint_file_path()

    if loading_file_path is None:
        loading_file_path = checkpointing_file_path
    # Else, we are branching from another configuration.

    print("\n\nLoading file path:")
    print(loading_file_path)

    print("\n\nCheckpointing file path:")
    print(checkpointing_file_path)
    print("\n\n")

    dataset, model, optimizer, lr_scheduler, device, seeds = build_experiment(
        data=data,
        model=model,
        optimizer=optimizer,
        model_seed=model_seed,
        sampler_seed=sampler_seed)

    if lr_scheduler is None and patience is None:
        patience = 20
    elif patience is None:
        patience = lr_scheduler.patience * 2

    print("\n\nMax epochs: {}\n\n".format(max_epochs))

    print("\n\nEarly stopping with patience: {}\n\n".format(patience))

    print('Building timers, training and evaluation loops...')
    timer = Timer(average=True)

    print('    Stopping timer')
    stopping_timer = Timer(average=True)

    print('    Training loop')
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        torch.nn.functional.cross_entropy,
                                        device=device)

    print('    Evaluator loop')
    evaluators, early_stopping = build_evaluators(trainer, model, device,
                                                  patience,
                                                  compute_test_error_rates)

    print('    Set timer events')
    timer.attach(trainer, start=Events.STARTED, step=Events.EPOCH_COMPLETED)

    print('    Metric logger')
    metric_logger = Logger()
    print('Done')

    all_stats = []
    best_stats = {}

    @trainer.on(Events.STARTED)
    def trainer_load_checkpoint(engine):
        engine.state.last_checkpoint = datetime.utcnow()
        metadata = load_checkpoint(loading_file_path, model, optimizer,
                                   lr_scheduler)
        if metadata:
            print('Resuming from epoch {}'.format(metadata['epoch']))
            print('Optimizer:')
            print('    lr:', optimizer.param_groups[0]['lr'])
            print('    momentum:', optimizer.param_groups[0]['momentum'])
            print('    weight decay:',
                  optimizer.param_groups[0]['weight_decay'])

            print('LR schedule:')
            print('    best:', lr_scheduler.best)
            print('    num_bad_epochs:', lr_scheduler.num_bad_epochs)
            print('    cooldown:', lr_scheduler.cooldown)

            engine.state.epoch = metadata['epoch']
            engine.state.iteration = metadata['iteration']
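            # Replay the stored per-epoch stats so that early_stopping and best_stats
            # are rebuilt exactly as they were before the interruption.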
            for epoch_stats in metadata['all_stats']:
                tmp = engine.state.metrics
                engine.state.metrics = epoch_stats['valid']
                early_stopping(engine)
                engine.state.metrics = tmp

                all_stats.append(epoch_stats)
                if (not best_stats
                        or (epoch_stats['valid']['error_rate']['mean'] <
                            best_stats['valid']['error_rate']['mean'])):
                    best_stats.update(epoch_stats)

            print('Early stopping:')
            print('    best_score:', early_stopping.best_score)
            print('    counter:', early_stopping.counter)
        else:
            engine.state.epoch = 0
            engine.state.iteration = 0
            engine.state.output = 0.0
            # trainer_save_checkpoint(engine)

    @trainer.on(Events.EPOCH_STARTED)
    def trainer_seeding(engine):
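        # Re-seed the data sampler deterministically per epoch so that shuffling is
        # reproducible, including when training resumes from a checkpoint.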
        print(seeds['sampler'] + engine.state.epoch)
        seed(int(seeds['sampler'] + engine.state.epoch))
        model.train()

    @trainer.on(Events.EPOCH_COMPLETED)
    def trainer_save_checkpoint(engine):
        model.eval()

        stats = dict(epoch=engine.state.epoch)

        for name in ['valid', 'train', 'test']:
            evaluator = evaluators.get(name, None)
            if evaluator is None:
                continue

            loader = dataset[name]
            metrics = evaluator.run(loader).metrics
            stats[name] = dict(loss=metrics['nll'],
                               error_rate=metrics['error_rate'])

        print('Early stopping')
        print('{}   {} < {}'.format(early_stopping.best_score,
                                    early_stopping.counter,
                                    early_stopping.patience))

        current_v_error_rate = stats['valid']['error_rate']['mean']
        best_v_error_rate = best_stats.get('valid',
                                           {}).get('error_rate',
                                                   {}).get('mean', 100)

        if lr_scheduler:
            lr_scheduler.step(current_v_error_rate)
            print('Lr schedule')
            print('{}   last_epoch: {} bads: {} cooldown: {}'.format(
                lr_scheduler.best, lr_scheduler.last_epoch,
                lr_scheduler.num_bad_epochs, lr_scheduler.cooldown_counter))

        if not best_stats or current_v_error_rate < best_v_error_rate:
            best_stats.update(stats)

        # TODO: Load all tasks with the same tags in mahler and compute the error_rate at that
        #       point (compare the median of best error_rates up to that point with this
        #       best_stats); if below the median, suspend.
        #       Maybe interrupt and increase priority, or not, because we would need to wait
        #       for it to complete anyway.
        #       Grace period? Like 60 epochs? :/
        #       Or reduce the quantile as time grows (stop the worst 95th quantile at 10 epochs,
        #       the 50th at 100, the 75th at 150 and so on...). Meh, too much novelty.
        #       Minimum number of trials at that point?
        #       Or interrupt after every 10/20 epochs, so that the number of trials grows
        #       quickly, but that means we need a way to log results during execution,
        #       not just at output.

        print(("Epoch {:>4} Iteration {:>12} Loss {:>8.3f} "
               "Best-Valid-ER {:>8.4f} Time {:>8.3f}").format(
                   engine.state.epoch, engine.state.iteration,
                   engine.state.output, best_v_error_rate, timer.value()))

        metric_logger.add_metric(stats)

        all_stats.append(stats)

        # TODO: Checkpoint lr_scheduler as well
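        # Throttle checkpointing: only write to disk if more than TIME_BUFFER seconds
        # have elapsed since the last checkpoint.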
        if (datetime.utcnow() -
                engine.state.last_checkpoint).total_seconds() > TIME_BUFFER:
            print('Checkpointing epoch {}'.format(engine.state.epoch))
            save_checkpoint(checkpointing_file_path,
                            model,
                            optimizer,
                            lr_scheduler,
                            epoch=engine.state.epoch,
                            iteration=engine.state.iteration,
                            all_stats=all_stats)
            engine.state.last_checkpoint = datetime.utcnow()

        if callback:
            callback(step=engine.state.epoch,
                     objective=stats['valid']['error_rate']['mean'],
                     finished=False)

    print("Training")
    trainer.run(dataset['train'], max_epochs=max_epochs)

    metric_logger.close()

    # Remove checkpoint to avoid cluttering the FS.
    clear_checkpoint(checkpointing_file_path)

    if callback:
        callback(step=max_epochs,
                 objective=all_stats[-1]['valid']['error_rate']['mean'],
                 finished=True)

    return {'best': best_stats, 'all': tuple(all_stats)}
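
A minimal call sketch, assuming data, model and optimizer are whatever configuration objects build_experiment expects in this repository; the dictionary values below are purely illustrative.

# Hypothetical configuration values, for illustration only.
results = train(data={'name': 'cifar10', 'batch_size': 128},
                model={'name': 'resnet18'},
                optimizer={'name': 'sgd', 'lr': 0.1, 'momentum': 0.9},
                max_epochs=120)

print('Best validation error rate:',
      results['best']['valid']['error_rate']['mean'])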