# Imports assumed for this snippet; EndOfEpisodeHandler is a sibling class
# (defined alongside this one) whose custom event marks the end of an episode.
from ignite.engine import Engine, Events as EngineEvents
from ignite.handlers import Timer


class EpisodeFPSHandler:
    def __init__(self):
        self._timer = Timer(average=True)

    def attach(self, engine: Engine):
        # Step the timer on every iteration, report metrics at episode boundaries.
        self._timer.attach(engine, step=EngineEvents.ITERATION_COMPLETED)
        engine.add_event_handler(EndOfEpisodeHandler.Events.EPISODE_COMPLETED, self)

    def __call__(self, engine: Engine):
        t_val = self._timer.value()
        if engine.state.iteration == 1:
            self._timer.reset()
        else:
            engine.state.metrics['fps'] = 1.0 / t_val
        engine.state.metrics['time_passed'] = t_val * self._timer.step_count
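# A minimal standalone sketch (an addition, not part of the original) of the
# Timer semantics the handler above relies on: with average=True, value() is
# the elapsed time divided by step_count, so 1 / value() approximates steps
# per second.
import time

from ignite.handlers import Timer

timer = Timer(average=True)
for _ in range(5):
    time.sleep(0.02)  # stand-in for one environment/optimization step
    timer.step()
print('avg step time: %.4fs, ~%.0f steps/s' % (timer.value(), 1.0 / timer.value()))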
# Imports assumed for this snippet; EpisodeEvents is the custom event enum
# defined alongside this class.
import time

from ignite.engine import Engine, Events
from ignite.handlers import Timer


class EpisodeFPSHandler:
    FPS_METRIC = 'fps'
    AVG_FPS_METRIC = 'avg_fps'
    TIME_PASSED_METRIC = 'time_passed'

    def __init__(self, fps_mul: float = 1.0, fps_smooth_alpha: float = 0.98):
        self._timer = Timer(average=True)
        self._fps_mul = fps_mul
        self._started_ts = time.time()
        self._fps_smooth_alpha = fps_smooth_alpha

    def attach(self, engine: Engine, manual_step: bool = False):
        self._timer.attach(
            engine, step=None if manual_step else Events.ITERATION_COMPLETED)
        engine.add_event_handler(EpisodeEvents.EPISODE_COMPLETED, self)

    def step(self):
        """
        If attach() was called with manual_step=True, call this method after
        every interaction with the environment to get a correct FPS value.
        """
        self._timer.step()

    def __call__(self, engine: Engine):
        t_val = self._timer.value()
        if engine.state.iteration > 1:
            fps = self._fps_mul / t_val
            # Exponential moving average of FPS for smoother reporting.
            avg_fps = engine.state.metrics.get(self.AVG_FPS_METRIC)
            if avg_fps is None:
                avg_fps = fps
            else:
                avg_fps *= self._fps_smooth_alpha
                avg_fps += (1 - self._fps_smooth_alpha) * fps
            engine.state.metrics[self.AVG_FPS_METRIC] = avg_fps
            engine.state.metrics[self.FPS_METRIC] = fps
        engine.state.metrics[self.TIME_PASSED_METRIC] = time.time() - self._started_ts
        self._timer.reset()
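# Usage sketch for the handler above (an illustrative addition, not from the
# original source). Assumptions: ignite >= 0.4 for event filters and custom
# events, and ptan's EpisodeEvents enum being importable. In real ptan
# training loops, EndOfEpisodeHandler fires EPISODE_COMPLETED when an
# environment episode ends; here we fire it by hand every 100 iterations
# just to exercise the handler.
import time

from ignite.engine import Engine, Events
from ptan.ignite import EpisodeEvents  # assumed importable


def process_batch(engine, batch):
    time.sleep(0.01)  # stand-in for one optimization step
    return {'loss': 0.0}


engine = Engine(process_batch)
engine.register_events(*EpisodeEvents)  # make the custom events known first
EpisodeFPSHandler().attach(engine)


@engine.on(Events.ITERATION_COMPLETED(every=100))
def fake_episode_boundary(engine: Engine):
    engine.fire_event(EpisodeEvents.EPISODE_COMPLETED)


@engine.on(EpisodeEvents.EPISODE_COMPLETED)
def report(engine: Engine):
    # The FPS handler was attached first, so the metrics are already updated.
    print('fps=%.1f elapsed=%.1fs' % (
        engine.state.metrics.get(EpisodeFPSHandler.FPS_METRIC, 0.0),
        engine.state.metrics[EpisodeFPSHandler.TIME_PASSED_METRIC]))


engine.run([None] * 500, max_epochs=1)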
from datetime import datetime

import torch
from ignite.engine import Events, create_supervised_trainer
from ignite.handlers import Timer

# Project-local helpers assumed to be defined elsewhere in this repository:
# get_checkpoint_file_path, build_experiment, build_evaluators, Logger,
# load_checkpoint, save_checkpoint, clear_checkpoint, seed, TIME_BUFFER.


def train(data, model, optimizer, model_seed=1, sampler_seed=1, max_epochs=120,
          patience=None, stopping_rule=None, compute_test_error_rates=False,
          loading_file_path=None, callback=None):

    # Checkpointing file path is named based on the Mahler task ID.
    checkpointing_file_path = get_checkpoint_file_path()
    if loading_file_path is None:
        loading_file_path = checkpointing_file_path
    # Otherwise, we are branching from another configuration.

    print("\n\nLoading file path:")
    print(loading_file_path)
    print("\n\nCheckpointing file path:")
    print(checkpointing_file_path)
    print("\n\n")

    dataset, model, optimizer, lr_scheduler, device, seeds = build_experiment(
        data=data, model=model, optimizer=optimizer,
        model_seed=model_seed, sampler_seed=sampler_seed)

    if lr_scheduler is None and patience is None:
        patience = 20
    elif patience is None:
        patience = lr_scheduler.patience * 2

    print("\n\nMax epochs: {}\n\n".format(max_epochs))
    print("\n\nEarly stopping with patience: {}\n\n".format(patience))

    print('Building timers, training and evaluation loops...')
    timer = Timer(average=True)
    print('    Stopping timer')
    stopping_timer = Timer(average=True)
    print('    Training loop')
    trainer = create_supervised_trainer(
        model, optimizer, torch.nn.functional.cross_entropy, device=device)
    print('    Evaluator loop')
    evaluators, early_stopping = build_evaluators(
        trainer, model, device, patience, compute_test_error_rates)
    print('    Set timer events')
    timer.attach(trainer, start=Events.STARTED, step=Events.EPOCH_COMPLETED)
    print('    Metric logger')
    metric_logger = Logger()
    print('Done')

    all_stats = []
    best_stats = {}

    @trainer.on(Events.STARTED)
    def trainer_load_checkpoint(engine):
        engine.state.last_checkpoint = datetime.utcnow()
        metadata = load_checkpoint(loading_file_path, model, optimizer, lr_scheduler)
        if metadata:
            print('Resuming from epoch {}'.format(metadata['epoch']))
            print('Optimizer:')
            print('    lr:', optimizer.param_groups[0]['lr'])
            print('    momentum:', optimizer.param_groups[0]['momentum'])
            print('    weight decay:', optimizer.param_groups[0]['weight_decay'])
            print('LR schedule:')
            print('    best:', lr_scheduler.best)
            print('    num_bad_epochs:', lr_scheduler.num_bad_epochs)
            print('    cooldown:', lr_scheduler.cooldown)
            engine.state.epoch = metadata['epoch']
            engine.state.iteration = metadata['iteration']
            # Replay past epochs so early stopping and best stats are restored.
            for epoch_stats in metadata['all_stats']:
                tmp = engine.state.metrics
                engine.state.metrics = epoch_stats['valid']
                early_stopping(engine)
                engine.state.metrics = tmp
                all_stats.append(epoch_stats)
                if (not best_stats or
                        (epoch_stats['valid']['error_rate']['mean'] <
                         best_stats['valid']['error_rate']['mean'])):
                    best_stats.update(epoch_stats)
            print('Early stopping:')
            print('    best_score:', early_stopping.best_score)
            print('    counter:', early_stopping.counter)
        else:
            engine.state.epoch = 0
            engine.state.iteration = 0
            engine.state.output = 0.0
            # trainer_save_checkpoint(engine)

    @trainer.on(Events.EPOCH_STARTED)
    def trainer_seeding(engine):
        print(seeds['sampler'] + engine.state.epoch)
        seed(int(seeds['sampler'] + engine.state.epoch))
        model.train()

    @trainer.on(Events.EPOCH_COMPLETED)
    def trainer_save_checkpoint(engine):
        model.eval()

        stats = dict(epoch=engine.state.epoch)
        for name in ['valid', 'train', 'test']:
            evaluator = evaluators.get(name, None)
            if evaluator is None:
                continue
            loader = dataset[name]
            metrics = evaluator.run(loader).metrics
            stats[name] = dict(loss=metrics['nll'], error_rate=metrics['error_rate'])

        print('Early stopping')
        print('{} {} < {}'.format(early_stopping.best_score, early_stopping.counter,
                                  early_stopping.patience))

        current_v_error_rate = stats['valid']['error_rate']['mean']
        best_v_error_rate = best_stats.get('valid', {}).get('error_rate', {}).get('mean', 100)

        if lr_scheduler:
            lr_scheduler.step(current_v_error_rate)
            print('LR schedule')
            print('{} last_epoch: {} bads: {} cooldown: {}'.format(
                lr_scheduler.best, lr_scheduler.last_epoch,
                lr_scheduler.num_bad_epochs, lr_scheduler.cooldown_counter))

        if not best_stats or current_v_error_rate < best_v_error_rate:
            best_stats.update(stats)

        # TODO: Load all tasks with the same tags in Mahler and compute the
        #       error_rate at that point (compare the median of the best
        #       error_rates up to that point vs this best_stats);
        #       if below the median, suspend.
        #       Maybe interrupt and increase priority, or not, because we would
        #       need to wait for it to complete anyway.
        #       Grace period? Like 60 epochs? :/
        #       Or reduce the quantile as time grows (stop the worst 95th
        #       quantile at 10 epochs, the 50th at 100, the 75th at 150 and so
        #       on...). Meh, too much novelty.
        #       Minimum number of trials at that point?
        #       Or interrupt after every 10/20 epochs, so that the number of
        #       trials grows quickly, but that means we need a way to log
        #       results during execution, not just at output.

        print(("Epoch {:>4} Iteration {:>12} Loss {:>8.3f} "
               "Best-Valid-ER {:>8.4f} Time {:>8.3f}").format(
            engine.state.epoch, engine.state.iteration, engine.state.output,
            best_v_error_rate, timer.value()))

        metric_logger.add_metric(stats)
        all_stats.append(stats)

        # TODO: Checkpoint lr_scheduler as well.
        if (datetime.utcnow() - engine.state.last_checkpoint).total_seconds() > TIME_BUFFER:
            print('Checkpointing epoch {}'.format(engine.state.epoch))
            save_checkpoint(checkpointing_file_path,
                            model, optimizer, lr_scheduler,
                            epoch=engine.state.epoch,
                            iteration=engine.state.iteration,
                            all_stats=all_stats)
            engine.state.last_checkpoint = datetime.utcnow()

        if callback:
            callback(step=engine.state.epoch,
                     objective=stats['valid']['error_rate']['mean'],
                     finished=False)

    print("Training")
    trainer.run(dataset['train'], max_epochs=max_epochs)

    metric_logger.close()

    # Remove the checkpoint to avoid cluttering the file system.
    clear_checkpoint(checkpointing_file_path)

    if callback:
        callback(step=max_epochs,
                 objective=all_stats[-1]['valid']['error_rate']['mean'],
                 finished=True)

    return {'best': best_stats, 'all': tuple(all_stats)}
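# Hedged sketch (an addition, not the project's actual code): load_checkpoint,
# save_checkpoint, and clear_checkpoint above are project-local helpers.
# Judging from their call sites, one plausible torch-based shape is the
# following. The write-then-rename in save_checkpoint is a design choice so a
# crash mid-save cannot leave a corrupted checkpoint behind.
import os

import torch


def save_checkpoint(path, model, optimizer, lr_scheduler, **metadata):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler else None,
        'metadata': metadata,  # e.g. epoch, iteration, all_stats
    }
    tmp_path = path + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)  # atomic on POSIX


def load_checkpoint(path, model, optimizer, lr_scheduler):
    if not os.path.exists(path):
        return None  # nothing to resume from
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    if lr_scheduler is not None and state['lr_scheduler'] is not None:
        lr_scheduler.load_state_dict(state['lr_scheduler'])
    return state['metadata']


def clear_checkpoint(path):
    # Remove the checkpoint file if present, so finished runs leave no clutter.
    if os.path.exists(path):
        os.remove(path)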