def test_remove_event_handler_on_callable_events():
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)
    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    assert not engine.has_event_handler(bar)

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
    assert not engine.has_event_handler(bar)
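A minimal sketch (assuming ignite >= 0.4, where Events.EPOCH_COMPLETED(every=n) returns a filtered event) illustrating the firing schedule the test above relies on:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

@engine.on(Events.EPOCH_COMPLETED(every=3))
def on_every_third_epoch(engine):
    # Fires after epochs 3, 6 and 9 only.
    print(f"fired at epoch {engine.state.epoch}")

engine.run([0, 1], max_epochs=9)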
def setup_event_handler(trainer, evaluator, train_loader, test_loader, log_dir):
    log_interval = 10
    writer = SummaryWriter(log_dir=log_dir)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_loss(trainer):
        print("Epoch[{}] Loss: {:.5f}".format(trainer.state.epoch, trainer.state.output))
        writer.add_scalar("training_iteration_loss", trainer.state.output, trainer.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=log_interval))
    def log_training_results(trainer):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        print("Training Results - Epoch: {} Accuracy: {:.5f} Loss: {:.5f}".format(
            trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
        writer.add_scalar("training_loss", metrics["nll"], trainer.state.epoch)
        writer.add_scalar("training_accuracy", metrics["accuracy"], trainer.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=log_interval))
    def log_testing_results(trainer):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        print("Validation Results - Epoch: {} Accuracy: {:.5f} Loss: {:.5f}".format(
            trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
        writer.add_scalar("testing_loss", metrics["nll"], trainer.state.epoch)
        writer.add_scalar("testing_accuracy", metrics["accuracy"], trainer.state.epoch)
def main(width, depth, max_epochs, state_dict_path, device, data_dir, num_workers):
    """
    This function constructs and trains a model from scratch, without any knowledge
    transfer method applied.

    :param int depth: factor controlling the depth of the model.
    :param int width: factor controlling the width of the model.
    :param int max_epochs: maximum number of epochs for training the student model.
    :param str state_dict_path: path to save the trained model.
    :param str device: device to use for training the model.
    :param str data_dir: directory to save and load the dataset.
    :param int num_workers: number of workers to use for loading the dataset.
    """
    # Define the device for training the model.
    device = torch.device(device)

    # Get data loaders for the CIFAR-10 dataset.
    train_loader, validation_loader, test_loader = get_cifar10_loaders(
        data_dir, batch_size=BATCH_SIZE, num_workers=num_workers
    )

    # Construct the model to be trained.
    model = WideResidualNetwork(depth=depth, width=width)
    model = model.to(device)

    # Define the optimizer and learning rate scheduler.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LEARNING_RATE_DECAY_MILESTONES, gamma=LEARNING_RATE_DECAY_FACTOR
    )

    # Construct the loss function to be used for training.
    criterion = torch.nn.CrossEntropyLoss()

    # Define the ignite engines for training and evaluation.
    batch_updater = BatchUpdaterWithoutTransfer(
        model=model, optimizer=optimizer, criterion=criterion, device=device
    )
    batch_evaluator = BatchEvaluator(model=model, device=device)
    trainer = Engine(batch_updater)
    evaluator = Engine(batch_evaluator)

    # Define and attach the progress bar, loss metric, and accuracy metrics.
    attach_pbar_and_metrics(trainer, evaluator)

    # The training engine updates the learning rate schedule at the end of each epoch.
    lr_updater = LearningRateUpdater(lr_scheduler=lr_scheduler)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(lr_updater)

    # The training engine logs the training and evaluation metrics at the end of each epoch.
    metric_logger = MetricLogger(evaluator=evaluator, eval_loader=validation_loader)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(metric_logger)

    # Train the model.
    trainer.run(train_loader, max_epochs=max_epochs)

    # Save the model to the pre-defined path. We move the model to CPU first, which is
    # desirable as the default device for loading the model.
    model.cpu()
    state_dict_dir = os.path.dirname(state_dict_path)
    os.makedirs(state_dict_dir, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
def test_remove_event_handler_on_callable_events():
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)
    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    assert not engine.has_event_handler(bar)

    with pytest.raises(TypeError, match=r"Argument event_name should not be a filtered event"):
        engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
def test_custom_events_asserts():
    # Dummy engine
    engine = Engine(lambda engine, batch: 0)

    class A:
        pass

    with pytest.raises(TypeError, match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events(None)

    with pytest.raises(TypeError, match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events("str", None)

    with pytest.raises(TypeError, match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events(1)

    with pytest.raises(TypeError, match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events(A())

    assert Events.EPOCH_COMPLETED != 1
    assert Events.EPOCH_COMPLETED != "abc"
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED(every=2)

    # In current implementation, EPOCH_COMPLETED and EPOCH_COMPLETED with event filter are the same
    assert Events.EPOCH_COMPLETED == Events.EPOCH_COMPLETED(every=2)
    assert Events.ITERATION_COMPLETED == Events.ITERATION_COMPLETED(every=2)
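For context, a sketch of the happy path these asserts guard: register_events accepts strings or EventEnum members (assuming ignite >= 0.4):

from ignite.engine import Engine, EventEnum

class CustomEvents(EventEnum):
    # Custom event to be fired manually from a handler or process function.
    TEST_EVENT = "test_event"

engine = Engine(lambda e, b: 0)
engine.register_events(*CustomEvents)

@engine.on(CustomEvents.TEST_EVENT)
def on_test_event(engine):
    print("custom event fired")

# Somewhere inside the training loop one would then call:
# engine.fire_event(CustomEvents.TEST_EVENT)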
def init(self):
    assert 'engine' in self.frame, 'The frame does not have engine.'
    shutil.copy(self.frame.config_path, self.save_handler.dirname)
    checkpoint = self.Checkpoint(self.modules, self.frame)
    self.frame['engine'].engine.add_event_handler(
        Events.EPOCH_COMPLETED(every=self.save_interval), self, {'checkpoint': checkpoint})
    self.frame['engine'].engine.add_event_handler(
        Events.EPOCH_COMPLETED(every=self.save_interval), self._correct_checkpoint)
def test_state_get_event_attrib_value():
    state = State()
    state.iteration = 10
    state.epoch = 9

    e = Events.ITERATION_STARTED
    assert state.get_event_attrib_value(e) == state.iteration
    e = Events.ITERATION_COMPLETED
    assert state.get_event_attrib_value(e) == state.iteration
    e = Events.EPOCH_STARTED
    assert state.get_event_attrib_value(e) == state.epoch
    e = Events.EPOCH_COMPLETED
    assert state.get_event_attrib_value(e) == state.epoch
    e = Events.STARTED
    assert state.get_event_attrib_value(e) == state.epoch
    e = Events.COMPLETED
    assert state.get_event_attrib_value(e) == state.epoch

    e = Events.ITERATION_STARTED(every=10)
    assert state.get_event_attrib_value(e) == state.iteration
    e = Events.ITERATION_COMPLETED(every=10)
    assert state.get_event_attrib_value(e) == state.iteration
    e = Events.EPOCH_STARTED(once=5)
    assert state.get_event_attrib_value(e) == state.epoch
    e = Events.EPOCH_COMPLETED(once=5)
    assert state.get_event_attrib_value(e) == state.epoch
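The mapping exercised above, in a compact sketch: iteration-scoped events read state.iteration, while epoch-scoped events (including STARTED/COMPLETED and filtered variants) read state.epoch:

from ignite.engine import Events, State

state = State()
state.iteration, state.epoch = 42, 7
assert state.get_event_attrib_value(Events.ITERATION_COMPLETED(every=2)) == 42
assert state.get_event_attrib_value(Events.EPOCH_COMPLETED(once=5)) == 7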
def create_trainer(model, optimizer, loss_fn, lr_scheduler, config):
    # Define any training logic for iteration update
    def train_step(engine, batch):
        x = batch[0].to(idist.device())
        y = batch[1].to(idist.device())

        model.train()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        return loss.item()

    # Define trainer engine
    trainer = Engine(train_step)

    if idist.get_rank() == 0:
        # Add any custom handlers
        @trainer.on(Events.EPOCH_COMPLETED(every=1))
        def save_checkpoint():
            model_path = os.path.join(config.get("output_path", "output"), "checkpoint.pt")
            torch.save(model.state_dict(), model_path)

        # Add progress bar showing batch loss value
        ProgressBar().attach(trainer, output_transform=lambda x: {"batch loss": x})

    return trainer
def get_event_by_freq(freq: Union[int, Epochs, Iters]):
    # A bare int is interpreted as a number of epochs.
    if isinstance(freq, int):
        freq = Epochs(freq)

    if isinstance(freq, Epochs):
        return Events.EPOCH_COMPLETED(every=freq.n)
    elif isinstance(freq, Iters):
        return Events.ITERATION_COMPLETED(every=freq.n)
    raise TypeError(f"freq should be an int, Epochs or Iters, got {type(freq)}")
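Hypothetical usage of the helper above (Epochs / Iters are assumed to be small wrappers exposing the count as .n; run_validation and log_metrics are placeholder handlers):

# Run validation every 5 epochs...
trainer.add_event_handler(get_event_by_freq(Epochs(5)), run_validation)
# ...and log every 200 iterations. A bare int would be treated as epochs.
trainer.add_event_handler(get_event_by_freq(Iters(200)), log_metrics)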
def configure_checkpoint_saving(trainer, evaluator, model, optimizer, args):
    to_save = {"model": model, "optimizer": optimizer}
    save_handler = DiskSaver(str(args.output_dir), create_dir=False, require_empty=False)

    # Configure epoch checkpoints.
    interval = 1 if args.dev_mode else min(5, args.max_epochs)
    checkpoint = Checkpoint(
        to_save, save_handler, n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=interval), checkpoint, evaluator)

    # Configure "best score" checkpoints.
    metric_name = "accuracy"
    best_checkpoint = Checkpoint(
        to_save, save_handler,
        score_name=metric_name,
        score_function=lambda engine: engine.state.metrics[metric_name],
        filename_prefix="best")
    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint, evaluator)
def training(local_rank, config):
    # Setup dataflow, model, optimizer, criterion
    train_loader, val_loader = get_dataflow(config)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Setup model trainer and evaluator
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, config)
    evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()}, device=idist.device())

    # Run model evaluation every 3 epochs and show results
    @trainer.on(Events.EPOCH_COMPLETED(every=3))
    def evaluate_model():
        state = evaluator.run(val_loader)
        if idist.get_rank() == 0:
            print(state.metrics)

    # Setup tensorboard experiment tracking
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            config.get("output_path", "output"),
            trainer,
            optimizer,
            evaluators={"validation": evaluator},
        )

    trainer.run(train_loader, max_epochs=config.get("max_epochs", 3))

    if idist.get_rank() == 0:
        tb_logger.close()
def attach(self, engine):
    if self.epoch_level:
        engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)
    else:
        engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)
def run(cfg, train_loader, tr_comp, saver, trainer, valid_dict):
    # TODO resume
    # trainer = Engine(...)
    # trainer.load_state_dict(state_dict)
    # trainer.run(data)

    # checkpoint
    handler = ModelCheckpoint(saver.model_dir, 'train', n_saved=3, create_dir=True)
    checkpoint_params = tr_comp.state_dict()
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, checkpoint_params)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    @trainer.on(Events.EPOCH_COMPLETED)
    def adjust_learning_rate(engine):
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Base Lr: {tr_comp.scheduler.get_last_lr()[0]:.2e}, "
        for loss_name in engine.state.metrics.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "

        if tr_comp.xent and tr_comp.xent.learning_weight:
            message += f"xentWeight: {tr_comp.xent.uncertainty.mean().item():.4f}, "
        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD))
    def log_validation_results(engine):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        eval_multi_dataset(cfg, valid_dict, tr_comp)

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)
def run(self, epochs: int = 1):
    trainer = self.trainer
    train_loader = self.dataloader["train"]
    val_loader = self.dataloader["val"]

    @trainer.on(Events.ITERATION_COMPLETED(every=self.log_interval))
    def log_training_loss(engine):
        length = len(train_loader)
        self.logger.info(f"Epoch[{engine.state.epoch}] "
                         f"Iteration[{engine.state.iteration}/{length}] "
                         f"Loss: {engine.state.output:.2f}")
        self.writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        self.evaluator.run(train_loader)
        metrics = self.evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["loss"]
        self.logger.info(f"Training Results - Epoch: {engine.state.epoch} "
                         f"Avg accuracy: {avg_accuracy:.2f} "
                         f"Avg loss: {avg_loss:.2f}")
        self.writer.add_scalar("training/avg_loss", avg_loss, engine.state.epoch)
        self.writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        self.evaluator.run(val_loader)
        metrics = self.evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["loss"]
        self.logger.info(f"Validation Results - Epoch: {engine.state.epoch} "
                         f"Avg accuracy: {avg_accuracy:.2f} "
                         f"Avg loss: {avg_loss:.2f}")
        self.writer.add_scalar("validation/avg_loss", avg_loss, engine.state.epoch)
        self.writer.add_scalar("validation/avg_accuracy", avg_accuracy, engine.state.epoch)

    objects_to_checkpoint = dict(model=self.model, optimizer=self.optimizer)
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(self.log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=self.checkpoint_every),
        training_checkpoint,
    )

    trainer.run(train_loader, max_epochs=epochs)
def add_logging(self):
    # Add validation logging
    self.train_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), self.evaluate_model)
    # Add step length update at the end of each epoch
    self.train_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.scheduler.step())
def attach(self, engine: Engine) -> None:
    """
    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    if self.epoch_level:
        engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)
    else:
        engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)
def attach(self, engine) -> None:
    def event_filter(engine, event):
        # Fire every `self.interval` events, starting at event `self.start`.
        return event >= self.start and (event - self.start) % self.interval == 0

    if self.epoch_level:
        engine.add_event_handler(Events.EPOCH_COMPLETED(event_filter=event_filter), self)
    else:
        engine.add_event_handler(Events.ITERATION_COMPLETED(event_filter=event_filter), self)
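A standalone sketch of the (engine, event) -> bool filter contract used above; unlike every=, a custom filter can encode a start offset (the names below are illustrative):

from ignite.engine import Engine, Events

def from_start_every(start, interval):
    def _filter(engine, event):
        return event >= start and (event - start) % interval == 0
    return _filter

engine = Engine(lambda e, b: None)
engine.add_event_handler(
    Events.EPOCH_COMPLETED(event_filter=from_start_every(start=10, interval=2)),
    lambda engine: print(engine.state.epoch),  # epochs 10, 12, 14
)
engine.run([0], max_epochs=14)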
def test_pbar_wrong_events_order():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.ITERATION_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.ITERATION_COMPLETED, closing_event_name=Events.ITERATION_STARTED)

    with pytest.raises(ValueError, match="Closing event should not use any event filter"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED,
                    closing_event_name=Events.EPOCH_COMPLETED(every=10))
def train_epochs(self, max_epochs):
    self.trainer = Engine(self.train_one_step)
    self.evaluator = Engine(self.evaluate_one_step)
    self.metrics = {'Loss': Loss(self.criterion), 'Acc': Accuracy()}
    for name, metric in self.metrics.items():
        metric.attach(self.evaluator, name)

    model_name = str(type(self))[17:len(str(type(self))) - 2]
    with SummaryWriter(log_dir="/tmp/tensorboard/Transform" + model_name) as writer:

        @self.trainer.on(Events.EPOCH_COMPLETED(every=1))  # every epoch
        def log_results(engine):
            # Evaluate on the training set
            self.eval()
            self.evaluator.run(self.train_loader)
            writer.add_scalar("train/loss", self.evaluator.state.metrics['Loss'], engine.state.epoch)
            writer.add_scalar("train/accy", self.evaluator.state.metrics['Acc'], engine.state.epoch)

            # Evaluate on the validation set
            self.evaluator.run(self.valid_loader)
            writer.add_scalar("valid/loss", self.evaluator.state.metrics['Loss'], engine.state.epoch)
            writer.add_scalar("valid/accy", self.evaluator.state.metrics['Acc'], engine.state.epoch)
            self.train()

        # Save the best model on the validation set
        best_model_handler = ModelCheckpoint(
            dirname='.',
            require_empty=False,
            filename_prefix="best",
            n_saved=1,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            score_name="val_loss")
        # The following runs every time the validation loop finishes
        self.evaluator.add_event_handler(
            Events.COMPLETED, best_model_handler, {f'Transform{model_name}': self})

        self.trainer.run(self.train_loader, max_epochs=max_epochs)
def run(subj_ind: int, result_name: str, dataset_path: str, deep4_path: str, result_path: str,
        config: dict = default_config,
        model_builder: ProgressiveModelBuilder = default_model_builder):
    result_path_subj = os.path.join(result_path, result_name, str(subj_ind))
    os.makedirs(result_path_subj, exist_ok=True)

    joblib.dump(config, os.path.join(result_path_subj, 'config.dict'), compress=False)
    joblib.dump(model_builder, os.path.join(result_path_subj, 'model_builder.jblb'), compress=True)

    # create discriminator and generator modules
    discriminator = model_builder.build_discriminator()
    generator = model_builder.build_generator()

    # initialize weights
    generator.apply(weight_filler)
    discriminator.apply(weight_filler)

    # trainer engine
    trainer = GanSoftplusTrainer(10, discriminator, generator,
                                 config['r1_gamma'], config['r2_gamma'])

    # handles potential progression after each epoch
    progression_handler = ProgressionHandler(
        discriminator, generator, config['n_stages'], config['use_fade'],
        config['n_epochs_fade'], freeze_stages=config['freeze_stages'])
    progression_handler.set_progression(0, 1.)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), progression_handler.advance_alpha)

    generator.train()
    discriminator.train()

    train(subj_ind, dataset_path, deep4_path, result_path_subj, progression_handler,
          trainer, config['n_batch'], config['lr_d'], config['lr_g'], config['betas'],
          config['n_epochs_per_stage'], config['n_epochs_metrics'],
          config['plot_every_epoch'], config['orig_fs'])
def train(self, epochs: int, train_loader, test_loader=None, trainsize=None, valsize=None):
    self.model.train()
    train_engine = Engine(lambda e, b: self.train_step(b))

    @train_engine.on(Events.EPOCH_COMPLETED(every=self.track_loss_freq))
    def eval_test(engine):
        if self.track_loss:
            self.tb_log(train_loader, engine.state.epoch, is_train=True, eval_length=valsize)
            if test_loader is not None:
                self.tb_log(test_loader, engine.state.epoch, is_train=False, eval_length=valsize)

    @train_engine.on(Events.EPOCH_COMPLETED)
    def save_state(engine):
        torch.save(self.model.state_dict(), self.snail_path)
        torch.save(self.opt.state_dict(), self.snail_opt_path)

    @train_engine.on(Events.ITERATION_COMPLETED(every=self.track_params_freq))
    def tb_log_histogram_params(engine):
        if self.track_layers:
            for name, params in self.model.named_parameters():
                self.logger.add_histogram(name.replace('.', '/'), params, engine.state.iteration)
                if params.grad is not None:
                    self.logger.add_histogram(name.replace('.', '/') + '/grad',
                                              params.grad, engine.state.iteration)

    if self.trainpbar:
        RunningAverage(output_transform=lambda x: x).attach(train_engine, 'loss')
        p = ProgressBar()
        p.attach(train_engine, ['loss'])

    train_engine.run(train_loader, max_epochs=epochs, epoch_length=trainsize)
def setup_checkpoints(trainer, obj_to_save, epoch_length, conf):
    # type: (Engine, Dict[str, Any], int, DictConfig) -> None
    cp = conf.checkpoints
    save_path = cp.get('save_dir', os.getcwd())
    logging.info("Saving checkpoints to {}".format(save_path))
    max_cp = max(int(cp.get('max_checkpoints', 1)), 1)
    save = DiskSaver(save_path, create_dir=True, require_empty=True)
    make_checkpoint = Checkpoint(obj_to_save, save, n_saved=max_cp)
    cp_iter = cp.interval_iteration
    cp_epoch = cp.interval_epoch
    if cp_iter > 0:
        save_event = Events.ITERATION_COMPLETED(every=cp_iter)
        trainer.add_event_handler(save_event, make_checkpoint)
    if cp_epoch > 0:
        # Only add an epoch-level save if iteration-level saving is disabled or the
        # epoch boundary does not already coincide with an iteration-level save.
        if cp_iter < 1 or epoch_length % cp_iter:
            save_event = Events.EPOCH_COMPLETED(every=cp_epoch)
            trainer.add_event_handler(save_event, make_checkpoint)
def setup_evaluation(
    trainer: Engine,
    evaluators: Dict[str, Engine],
    data_loaders: Dict[str, DataLoader],
    logger: Logger,
) -> None:
    # We define separate evaluators as they won't have exactly the same roles:
    # - `evaluator` will save the best model based on validation score
    def _evaluation(engine: Engine) -> None:
        epoch = trainer.state.epoch
        for split in ["train", "val", "test"]:
            state = evaluators[split].run(data_loaders[split])
            log_metrics(logger, epoch, state.times["COMPLETED"], split, state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config.validate_every) | Events.COMPLETED,
        _evaluation,
    )
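The | composition above registers _evaluation once for several triggers; a rough equivalent with separate registrations would be:

trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.validate_every), _evaluation)
trainer.add_event_handler(Events.COMPLETED, _evaluation)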
def attach(self, engine: Engine):
    if self._name is None:
        self.logger = engine.logger
    if self._final_checkpoint is not None:
        engine.add_event_handler(Events.COMPLETED, self.completed)
        engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
    if self._key_metric_checkpoint is not None:
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.metrics_completed)
    if self._interval_checkpoint is not None:
        if self.epoch_level:
            engine.add_event_handler(
                Events.EPOCH_COMPLETED(every=self.save_interval), self.interval_completed)
        else:
            engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=self.save_interval), self.interval_completed)
# Define mean dice metric and Evaluator.
validation_every_n_epochs = 1
val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))
val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(patience=4,
                              score_function=stopping_fn_from_metric('Mean Dice'),
                              trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)

state = trainer.run(loader, train_epochs)
def fit(self, dataset, fold=0, train_split='train', valid_split='val'):
    """Fit the predictor model.

    Args:
        - dataset: temporal, static, label, time, treatment information
        - fold: cross validation fold
        - train_split: training set splitting parameter
        - valid_split: validation set splitting parameter

    Returns:
        - self.predictor_model: trained predictor model
    """
    train_x, train_y = self._data_preprocess(dataset, fold, train_split)
    valid_x, valid_y = self._data_preprocess(dataset, fold, valid_split)

    train_dataset = torch.utils.data.dataset.TensorDataset(
        self._make_tensor(train_x), self._make_tensor(train_y))
    valid_dataset = torch.utils.data.dataset.TensorDataset(
        self._make_tensor(valid_x), self._make_tensor(valid_y))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=self.batch_size, shuffle=True)

    if self.predictor_model is None:
        self.predictor_model = TransformerModule(
            self.task, dataset.problem, train_x.shape[-1], self.h_dim,
            train_y.shape[-1], self.n_head, self.n_layer).to(self.device)
        self.optimizer = torch.optim.Adam(self.predictor_model.parameters(), lr=self.learning_rate)

    self.predictor_model.train()

    # classification vs regression
    # static vs dynamic
    trainer = create_supervised_trainer(self.predictor_model, self.optimizer,
                                        self.predictor_model.loss_fn)
    evaluator = create_supervised_evaluator(
        self.predictor_model, metrics={'loss': Loss(self.predictor_model.loss_fn)})

    # model checkpoint
    checkpoint_handler = ModelCheckpoint(self.model_path, self.model_id,
                                         n_saved=1, create_dir=True, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpoint_handler,
                              {'model': self.predictor_model})

    # early stopping
    def score_function(engine):
        val_loss = engine.state.metrics['loss']
        return -val_loss

    early_stopping_handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stopping_handler)

    # evaluation loss
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print("Validation Results - Epoch[{}] Avg loss: {:.2f}".format(
            trainer.state.epoch, metrics['loss']))

    trainer.run(train_loader, max_epochs=self.epoch)

    return self.predictor_model
def _train(save_iter=None, save_epoch=None, sd=None):
    w_norms = []
    grad_norms = []
    data = []
    chkpt = []

    manual_seed(12)
    arch = [
        nn.Conv2d(3, 10, 3),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2),
    ]
    if with_dropout:
        arch.insert(2, nn.Dropout2d())
        arch.insert(-2, nn.Dropout())

    model = nn.Sequential(*arch).to(device)
    opt = SGD(model.parameters(), lr=0.001)

    def proc_fn(e, b):
        from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

        s = _repr_rng_state(_get_rng_states())
        model.train()
        opt.zero_grad()
        y = model(b.to(device))
        y.sum().backward()
        opt.step()
        if debug:
            print(trainer.state.iteration, trainer.state.epoch,
                  "proc_fn - b.shape", b.shape, torch.norm(y).item(), s)

    trainer = DeterministicEngine(proc_fn)

    if save_iter is not None:
        ev = Events.ITERATION_COMPLETED(once=save_iter)
    elif save_epoch is not None:
        ev = Events.EPOCH_COMPLETED(once=save_epoch)
        save_iter = save_epoch * (data_size // batch_size)

    @trainer.on(ev)
    def save_chkpt(_):
        if debug:
            print(trainer.state.iteration, "save_chkpt")
        fp = dirname / "test.pt"
        from ignite.engine.deterministic import _repr_rng_state

        tsd = trainer.state_dict()
        if debug:
            print("->", _repr_rng_state(tsd["rng_states"]))
        torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
        chkpt.append(fp)

    def log_event_filter(_, event):
        if (event // save_iter == 1) and 1 <= (event % save_iter) <= 5:
            return True
        return False

    @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
    def write_data_grads_weights(e):
        x = e.state.batch
        i = e.state.iteration
        data.append([i, x.mean().item(), x.std().item()])

        total = [0.0, 0.0]
        out1 = []
        out2 = []
        for p in model.parameters():
            n1 = torch.norm(p).item()
            n2 = torch.norm(p.grad).item()
            out1.append(n1)
            out2.append(n2)
            total[0] += n1
            total[1] += n2
        w_norms.append([i, total[0]] + out1)
        grad_norms.append([i, total[1]] + out2)

    if sd is not None:
        sd = torch.load(sd)
        model.load_state_dict(sd[0])
        opt.load_state_dict(sd[1])
        from ignite.engine.deterministic import _repr_rng_state

        if debug:
            print("-->", _repr_rng_state(sd[2]["rng_states"]))
        trainer.load_state_dict(sd[2])

    manual_seed(32)
    trainer.run(random_train_data_loader(size=data_size), max_epochs=5)

    return {"sd": chkpt, "data": data, "grads": grad_norms, "weights": w_norms}
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):
    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )
    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()
            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["label"]), device, non_blocking)

    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={"net": net, "opt": opt})

    # StatsHandler prints loss at every iteration and prints metrics at every epoch;
    # we don't set metrics for the trainer here, so it just prints loss. Users can also customize
    # the print functions and use output_transform to convert engine.state.output if it's not a loss value.
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1
    metric_name = "Accuracy"

    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy(), "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)}

    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4,
                                  score_function=stopping_fn_from_metric(metric_name),
                                  trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4,
                            pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3,):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`"
            )

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer)

    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()