def test_custom_event_with_arg_handlers_profiler():
    """Profile a handler bound (with an extra argument) to a string-named custom event.

    The custom event is fired once per completed iteration, and its handler
    sleeps for a fixed time.  The profiler row for ``"custom_event"`` must
    name the handler and report total/min/max/mean/stddev consistent with
    that sleep time.
    """
    true_event_handler_time = 0.1
    true_max_epochs = 1
    true_num_iters = 2

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    dummy_trainer.register_events("custom_event")
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=1))
    def trigger_custom_event():
        dummy_trainer.fire_event("custom_event")

    args = [122, 324]

    @dummy_trainer.on("custom_event", args)
    def on_custom_event(args):
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    # First profiler row whose event column is the custom event (None if absent).
    event_results = next(
        (row for row in profiler.get_results() if row[1] == "custom_event"),
        None,
    )
    assert event_results is not None
    assert "on_custom_event" in event_results[0]

    expected_total = true_max_epochs * true_num_iters * true_event_handler_time
    assert event_results[2] == approx(expected_total, abs=1e-1)  # total
    assert event_results[3][0] == approx(true_event_handler_time, abs=1e-1)  # min
    assert event_results[4][0] == approx(true_event_handler_time, abs=1e-1)  # max
    assert event_results[5] == approx(true_event_handler_time, abs=1e-1)  # mean
    assert event_results[6] == approx(0.0, abs=1e-1)  # stddev
def test_custom_events():
    """Custom events (enum members and plain strings) only reach handlers when fired.

    A first engine registers the events but never fires them, so no handler
    may be called.  A second engine fires every event from its process
    function, so every handler must be called.
    """

    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    evs = [CustomEvents.TEST_EVENT, "a", "b", "c"]

    def register_all(target):
        target.register_events(*CustomEvents)
        target.register_events("a", "b", "c")

    def attach_mocks(target):
        # One fresh MagicMock per event, attached in event order.
        pairs = [(ev, MagicMock()) for ev in evs]
        for ev, mock in pairs:
            target.add_event_handler(ev, mock)
        return pairs

    # Dummy engine: events are registered but never fired.
    engine = Engine(lambda engine, batch: 0)
    register_all(engine)
    handlers = attach_mocks(engine)
    engine.run(range(1))
    assert not any(mock.called for _, mock in handlers)

    # Advanced engine: the process function fires each registered event.
    def process_func(engine, batch):
        for ev in evs:
            engine.fire_event(ev)

    engine = Engine(process_func)
    register_all(engine)
    handlers = attach_mocks(engine)
    engine.run(range(1))
    assert all(mock.called for _, mock in handlers)
def setup_training(self):
    """Build the training engine and wire up metrics, checkpointing, logging and evaluation.

    After every epoch ``self.eval()`` is run, the validation accuracy and
    loss are copied into the trainer's metrics, and the custom ``"EVAL_DONE"``
    event is fired.  Checkpointing (and early stopping, when enabled) is
    driven by that event.

    Returns:
        The configured trainer ``Engine``.

    Raises:
        AssertionError: if ``self.batch_size`` is ``None``.
    """
    assert self.batch_size is not None

    def _train_step(engine, batch):
        return self.train_step(batch)

    trainer = Engine(_train_step)
    trainer.register_events("EVAL_DONE")
    Average(lambda out: out['loss']).attach(trainer, 'avg_loss')

    state_vars = {'model': self.model, 'opt': self.opt, 'trainer': trainer}

    def _val_accuracy(engine):
        return engine.state.metrics['val_accuracy']

    def _epoch_as_step(engine, evt_name):
        return engine.state.epoch

    checkpoint_handler = ModelCheckpoint(
        self.run_path,
        '',
        score_function=_val_accuracy,
        score_name='val_accuracy',
        n_saved=2,
        global_step_transform=_epoch_as_step,
    )
    last_checkpoint = checkpoint_handler.last_checkpoint
    if last_checkpoint:
        # Restore model/optimizer/trainer from the most recent checkpoint.
        checkpoint_handler.load_objects(state_vars, self.run_path / checkpoint_handler.last_checkpoint)

    def _save_checkpoint(engine):
        return checkpoint_handler(engine, state_vars)

    trainer.add_event_handler("EVAL_DONE", _save_checkpoint)

    if self.use_lr_decay:

        def _step_lr_decay(engine):
            return self.lr_decay.step(engine.state.iteration * self.batch_size)

        trainer.add_event_handler(Events.ITERATION_COMPLETED, _step_lr_decay)

    RunningAverage(output_transform=lambda out: out['loss']).attach(trainer, 'running_avg_loss')
    ProgressBar().attach(trainer, ['running_avg_loss'])
    logger.setup_logger(self.run_path, trainer, self.model)

    @trainer.on(Events.EPOCH_COMPLETED)
    def eval_and_log(e: Engine):
        eval_results = self.eval()
        val_metrics = eval_results['val'].metrics
        e.state.metrics['val_accuracy'] = val_metrics['accuracy']
        e.state.metrics['val_loss'] = val_metrics['avg_loss']
        e.state.eval_results = eval_results
        e.fire_event("EVAL_DONE")

    if self.use_early_stop:
        # Registered after the checkpoint handler, so it runs after it on EVAL_DONE.
        es = self.make_early_stopper(trainer)
        trainer.add_event_handler("EVAL_DONE", es)

    return trainer
def test_custom_events():
    """A handler on a custom ``Enum`` event runs only when the event is fired."""

    class CustomEvents(Enum):
        TEST_EVENT = "test_event"

    def build(process_function):
        # Engine with the custom events registered and one mock handler attached.
        built = Engine(process_function)
        built.register_events(*CustomEvents)
        mock = MagicMock()
        built.add_event_handler(CustomEvents.TEST_EVENT, mock)
        return built, mock

    # Dummy engine: never fires the event, so the handler stays untouched.
    engine, handle = build(lambda engine, batch: 0)
    engine.run(range(1))
    assert not handle.called

    # Advanced engine: fires the event from its process function.
    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine, handle = build(process_func)
    engine.run(range(1))
    assert handle.called
def create_engine(
    model,
    loss_fn,
    constraint_fn,
    optimizer=None,
    projection=False,
    monitor=None,
    guard=True,
    regularization_weight=0.0,
    error_fn=None,
    device="cpu",
    tolerance=1e-5,
    max_iterations=1e4,
):
    """Create an engine driven by a training loop or a projection loop.

    :param model: model to train or evaluate
    :param loss_fn: loss function forwarded to the iteration function
    :param constraint_fn: constraint function forwarded to the iteration
        function
    :param optimizer: optimizer forwarded to the iteration function
    :param projection: if true the engine is driven by ``ProjectionLoop``,
        otherwise by ``TrainingLoop``
    :param monitor: optional handler; when given, ``monitor.attach(engine)``
        is called
    :param guard: accepted for interface compatibility; not read by this
        function
    :param regularization_weight: multiplier forwarded to the iteration
        function (used for soft-constraining)
    :param error_fn: error function forwarded to the iteration function
    :param device: device specification forwarded to the iteration function,
        e.g. "cuda" or "cpu"
    :param tolerance: accepted for interface compatibility; not read by this
        function
    :param max_iterations: accepted for interface compatibility; not read by
        this function
    :returns: an ignite.engine.Engine with ``Sub_Batch_Events`` registered
    """
    iteration_fn = ProjectionLoop if projection else TrainingLoop
    step = iteration_fn(
        model,
        loss_fn,
        constraint_fn,
        optimizer,
        regularization_weight,
        error_fn,
        device,
    )

    engine = Engine(step)
    engine.register_events(*Sub_Batch_Events)

    if monitor is not None:
        monitor.attach(engine)
    return engine
def test_deprecated_callable_events_class():
    """Subclassing ``CallableEvents`` emits a DeprecationWarning; the events still register."""
    deprecation_pattern = r"Class ignite\.engine\.events\.CallableEvents is deprecated"
    engine = Engine(lambda engine, batch: 0)

    with pytest.warns(DeprecationWarning, match=deprecation_pattern):

        class CustomEvents(CallableEvents, Enum):
            TEST_EVENT = "test_event"

        engine.register_events(*CustomEvents)
def test_custom_events_asserts():
    """``register_events`` rejects non-str / non-EventEnum names; check ``Events`` equality rules."""
    # Dummy engine
    engine = Engine(lambda engine, batch: 0)

    class A:
        pass

    type_error_pattern = r"Value at \d of event_names should be a str or EventEnum"
    invalid_argument_sets = (
        (None,),
        ("str", None),
        (1,),
        (A(),),
    )
    for event_names in invalid_argument_sets:
        with pytest.raises(TypeError, match=type_error_pattern):
            engine.register_events(*event_names)

    assert Events.EPOCH_COMPLETED != 1
    assert Events.EPOCH_COMPLETED != "abc"
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED(every=2)
    # In current implementation, EPOCH_COMPLETED and EPOCH_COMPLETED with event filter are the same
    assert Events.EPOCH_COMPLETED == Events.EPOCH_COMPLETED(every=2)
    assert Events.ITERATION_COMPLETED == Events.ITERATION_COMPLETED(every=2)
def test_custom_events_with_events_list():
    """A handler bound to ``custom_event | Events.STARTED`` gets called during a run."""

    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(process_func)
    engine.register_events(*CustomEvents)

    # Handle should be called
    handle = MagicMock()
    combined_events = CustomEvents.TEST_EVENT | Events.STARTED
    engine.add_event_handler(combined_events, handle)

    engine.run(range(1))
    assert handle.called
def test_deprecated_callable_events_class():
    """Subclassing ``CallableEvents`` warns, and such events are rejected by ``register_events``."""
    deprecation_pattern = r"Class ignite\.engine\.events\.CallableEvents is deprecated"
    type_error_pattern = r"Value at \d of event_names should be a str or EventEnum"
    engine = Engine(lambda engine, batch: 0)

    with pytest.warns(DeprecationWarning, match=deprecation_pattern):

        class CustomEvents(CallableEvents, Enum):
            TEST_EVENT = "test_event"

        with pytest.raises(TypeError, match=type_error_pattern):
            engine.register_events(*CustomEvents)
def create_trainers(config, model, optimizer, loss_fn, device) -> Tuple[Engine, Engine]:
    """Create Engines for training and evaluation.

    Parameters
    ----------
    config
        config object
    model
        nn.Module model
    optimizer
        torch optimizer
    loss_fn
        nn.Module loss
    device
        device to use for training

    Returns
    -------
    trainer, evaluator
        The trainer additionally has ``TrainEvents`` registered (with
        ``train_events_to_attr`` as the event-to-attribute mapping).
    """

    def _train_step(engine, batch):
        return train_function(
            config=config,
            engine=engine,
            batch=batch,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device,
        )

    def _evaluate_step(engine, batch):
        return evaluate_function(
            config=config,
            engine=engine,
            batch=batch,
            model=model,
            device=device,
        )

    trainer = Engine(_train_step)
    evaluator = Engine(_evaluate_step)
    trainer.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
    return trainer, evaluator
def _test(event_name, event_attr, true_num_calls):
    """Run an engine with a filtered custom-event handler and count its invocations.

    The update function mirrors the iteration counter into
    ``engine.state.test_event`` and fires ``CustomEvents.TEST_EVENT``.  Each
    time the filtered handler runs, the state attribute named ``event_attr``
    must equal the next value popped from the enclosing ``special_events``
    list.  After 25 epochs over 50 items the handler must have run exactly
    ``true_num_calls`` times.
    """

    def update_fn(engine, batch):
        engine.state.test_event = engine.state.iteration
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(update_fn)
    engine.register_events(*CustomEvents, event_to_attr=event_to_attr)

    num_calls = 0

    @engine.on(event_name(event_filter=custom_event_filter))
    def assert_on_special_event(engine):
        nonlocal num_calls
        expected = special_events.pop(0)
        assert getattr(engine.state, event_attr) == expected
        num_calls += 1

    data = list(range(50))
    engine.run(data, max_epochs=25)
    assert num_calls == true_num_calls
def test_custom_events_with_event_to_attr():
    """``event_to_attr`` creates a state counter for a custom event and must be a mapping.

    Covers three cases: the attribute exists and is 0 when the event never
    fires, a handler can increment it once per fired event, and passing a
    non-dict ``event_to_attr`` raises ``ValueError``.
    """

    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    custom_event_to_attr = {CustomEvents.TEST_EVENT: "test_event"}

    # Dummy engine: the event is registered but never fired.
    idle_engine = Engine(lambda engine, batch: 0)
    idle_engine.register_events(*CustomEvents, event_to_attr=custom_event_to_attr)
    never_called = MagicMock()
    idle_engine.add_event_handler(CustomEvents.TEST_EVENT, never_called)
    idle_engine.run(range(1))
    assert hasattr(idle_engine.state, "test_event")
    assert idle_engine.state.test_event == 0

    # Advanced engine: the event fires once per iteration.
    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    def bump_counter(engine):
        engine.state.test_event += 1

    firing_engine = Engine(process_func)
    firing_engine.register_events(*CustomEvents, event_to_attr=custom_event_to_attr)
    firing_engine.add_event_handler(CustomEvents.TEST_EVENT, bump_counter)
    firing_engine.run(range(25))
    assert firing_engine.state.test_event == 25

    # A non-dict event_to_attr is rejected.
    fresh_engine = Engine(lambda engine, batch: 0)
    with pytest.raises(ValueError):
        fresh_engine.register_events(*CustomEvents, event_to_attr="a")
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None,
                                    non_blocking=False, prepare_batch=_prepare_batch):
    """Create a trainer for truncated backprop through time supervised models.

    Training a recurrent model on long sequences is computationally intensive
    because the whole sequence must be processed before a gradient is
    available.  When the training loss is computed over many outputs
    (`X to many <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`_),
    a gradient can instead be computed over a subsequence.  This is known as
    `truncated backpropagation through time
    <https://machinelearningmastery.com/gentle-introduction-backpropagation-time/>`_.
    This trainer applies an optimization step every `tbtt_step` time steps of
    the sequence, backpropagating through those same `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        tbtt_step (int): the length of time chunks (last one may be smaller)
        dim (int): axis representing the time dimension
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): forwarded to `prepare_batch`; if True
            and the copy is between CPU and GPU, the copy may occur
            asynchronously with respect to the host.
        prepare_batch (Callable, optional): function that receives `batch`,
            `device`, `non_blocking` and outputs tuple of tensors
            `(batch_x, batch_y)`.

    Returns:
        Engine: a trainer engine with supervised update function; its output
        is the average loss over the time chunks of a batch.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        inputs, targets = batch
        chunk_losses = []
        hidden = None

        time_chunks = zip(inputs.split(tbtt_step, dim=dim), targets.split(tbtt_step, dim=dim))
        for chunk in time_chunks:
            x_t, y_t = prepare_batch(chunk, device=device, non_blocking=non_blocking)

            # Signal the start of this time chunk
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)

            model.train()
            optimizer.zero_grad()
            if hidden is not None:
                # Cut the graph so backprop stops at the chunk boundary
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            else:
                y_pred_t, hidden = model(x_t)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Expose the chunk loss on the state so handlers see a consistent output
            loss_value = loss_t.item()
            engine.state.output = loss_value
            chunk_losses.append(loss_value)

            # Signal the end of this time chunk
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # Average loss over the time chunks
        return sum(chunk_losses) / len(chunk_losses)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
def set_image_classification_trainer(model, optimizer, criterion, device, loaders, loggers):
    """Build a trainer/evaluator pair for image classification.

    The trainer runs one optimization step per batch and logs its loss every
    250 iterations.  After each epoch the evaluator is run on the 'train' and
    'validation' loaders, a validation counter on the evaluator state is
    incremented and ``EvaluatorEvents.VALIDATION_COMPLETED`` is fired.  When
    training completes the evaluator is run on the 'test' loader with result
    and calibration logging attached.

    Args:
        model: network called as ``model(x)``.
        optimizer: optimizer stepped once per training batch.
        criterion: loss callable; its output is reduced with ``.mean()``.
        device: device that batch elements 0 and 1 are moved to.
        loaders: mapping with 'train', 'validation' and 'test' entries.
        loggers: mapping containing a 'progress_bar' entry; also passed on to
            the logging handlers.

    Returns:
        Tuple ``(trainer, evaluator)``.
    """

    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()
        inputs = batch[0].to(device)
        labels = batch[1].to(device)
        loss = criterion(model(inputs), labels).mean()
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    loggers['progress_bar'].attach(trainer, metric_names='all')

    def validation_step(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs = batch[0].to(device)
            target = batch[1].to(device)
            predictions = model(inputs)
        return {'y_pred': predictions, 'y': target, 'criterion_kwargs': {}}

    evaluator = Engine(validation_step)
    evaluator.state.validation_completed = 0
    evaluator.register_events(*EvaluatorEvents, event_to_attr=event_to_attr)

    metrics = {
        'loss': Loss(criterion),
        'F1': Fbeta(beta=1, average=False),
        'mA': Accuracy(is_multilabel=False),
        'mP': Precision(average=False, is_multilabel=False),
        'mR': Recall(average=False, is_multilabel=False),
    }
    for metric_name in metrics:
        metrics[metric_name].attach(evaluator, metric_name)

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=250), log_training_loss, loggers)

    def _evaluate_and_log(split, epoch):
        # The logging handler is attached only for the duration of this run.
        with evaluator.add_event_handler(Events.COMPLETED, log_results, split, epoch, loggers):
            evaluator.run(loaders[split])

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        _evaluate_and_log('train', engine.state.epoch)
        _evaluate_and_log('validation', engine.state.epoch)
        evaluator.state.validation_completed += 1
        evaluator.fire_event(EvaluatorEvents.VALIDATION_COMPLETED)

    @trainer.on(Events.COMPLETED)
    def test(engine):
        def _to_probabilities(output):
            return {'y_pred': F.softmax(output['y_pred'], dim=1), 'y': output['y']}

        with evaluator.add_event_handler(
                Events.COMPLETED, log_results, 'test', engine.state.epoch, loggers):
            with evaluator.add_event_handler(
                    Events.COMPLETED, log_calibration_results, 'test', loggers,
                    output_transform=_to_probabilities):
                evaluator.run(loaders['test'])

    return trainer, evaluator
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None):
    """Create a trainer for truncated backprop through time supervised models.

    Training recurrent model on long sequences is computationally intensive as
    it requires to process the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    ([X to many](https://karpathy.github.io/2015/05/21/rnn-effectiveness/)),
    there is an opportunity to compute a gradient over a subsequence. This is
    known as
    [truncated backpropagation through time](https://machinelearningmastery.com/gentle-introduction-backpropagation-time/).
    This supervised trainer apply gradient optimization step every `tbtt_step`
    time steps of the sequence, while backpropagating through the same
    `tbtt_step` time steps.

    The model is called as `model(x_t)` for the first time chunk and as
    `model(x_t, hidden)` afterwards, and must return `(y_pred_t, hidden)`.

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        tbtt_step (int): the length of time chunks (last one may be smaller)
        dim (int): axis representing the time dimension
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function; its output
        is the average loss over the time chunks of a batch.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        loss_list = []
        hidden = None

        # Batches split in time chunks
        batch_splits = _prepare_tbptt_batch(batch, tbtt_step, dim=dim, device=device)
        for x_t, y_t in batch_splits:
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)

            # Forward, backward and optimization step
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                # Detach so gradients stop at the chunk boundary, then carry
                # the state forward.  The previous code detached the hidden
                # state but never passed it to the model, which silently
                # reset the recurrent state on every chunk.
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
def create_train_and_validation_engines(train_func, val_func=None, device='cpu'):
    """
    Helper function for creating ignite Engine objects with helpful defaults.

    Each engine that is created gets four handlers attached to it:

    - prepare_batch: before a batch is passed to train_func or val_func, every
      tensor item in the batch (which is a dictionary) is converted to float
      and moved to the selected device.
    - book_keeping: when a run starts, creates the ``epoch_history``,
      ``iter_history`` and ``past_iter_history`` dictionaries on the engine
      state.
    - add_to_iter_history: after each iteration, appends every value of the
      engine output to ``iter_history[key]`` and appends that
      ``iter_history[key]`` list object to ``past_iter_history[key]``.
    - clear_iter_history: at the start of each epoch, replaces
      ``iter_history`` with an empty dictionary.

    The trainer additionally has ``ValidationEvents`` and ``BackwardsEvents``
    registered.

    Args:
        train_func (func): Function that provides the closure for training for
          a single batch.
        val_func (func, optional): Function that provides the closure for
          validating a single batch. Defaults to None.
        device (str, optional): Device to move tensors to. Falls back to 'cpu'
          whenever CUDA is not available. Defaults to 'cpu'.

    Returns:
        tuple: ``(trainer, validator)``; ``validator`` is None when
        ``val_func`` is None.
    """
    # Set up engines for training and validation
    trainer = Engine(train_func)
    trainer.register_events(*ValidationEvents)
    trainer.register_events(*BackwardsEvents)

    validator = Engine(val_func) if val_func is not None else None

    # Fall back to the CPU when CUDA is not available, whatever was requested.
    if not torch.cuda.is_available():
        device = 'cpu'
    device = torch.device(device)

    def prepare_batch(engine):
        batch = engine.state.batch
        for key in batch:
            item = batch[key]
            if torch.is_tensor(item):
                batch[key] = item.float().to(device)
        engine.state.batch = batch

    # Set up stuff for bookkeeping as training progresses.
    def book_keeping(engine):
        engine.state.epoch_history = {}
        engine.state.iter_history = {}
        engine.state.past_iter_history = {}

    def add_to_iter_history(engine):
        state = engine.state
        for key in state.output:
            current = state.iter_history.setdefault(key, [])
            past = state.past_iter_history.setdefault(key, [])
            current.append(state.output[key])
            # NOTE(review): the live list itself (not a copy) is appended on
            # every iteration - confirm this is the intended history layout.
            past.append(current)

    def clear_iter_history(engine):
        engine.state.iter_history = {}

    handler_table = (
        (Events.ITERATION_STARTED, prepare_batch),
        (Events.STARTED, book_keeping),
        (Events.ITERATION_COMPLETED, add_to_iter_history),
        (Events.EPOCH_STARTED, clear_iter_history),
    )
    for target in (trainer, validator):
        if target is None:
            continue
        for event, handler in handler_table:
            target.add_event_handler(event, handler)

    return trainer, validator
def create_supervised_tbptt_trainer(
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: nn.Module,
    tbtt_step: int,
    dim: int = 0,
    device: Optional[str] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
):
    """Create a trainer for truncated backprop through time supervised models.

    Training a recurrent model on long sequences is computationally intensive
    because the whole sequence must be processed before a gradient is
    available.  When the training loss is computed over many outputs
    (`X to many <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`_),
    a gradient can instead be computed over a subsequence.  This is known as
    `truncated backpropagation through time
    <https://machinelearningmastery.com/gentle-introduction-backpropagation-time/>`_.
    This trainer applies an optimization step every `tbtt_step` time steps of
    the sequence, backpropagating through those same `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        tbtt_step (int): the length of time chunks (last one may be smaller).
        dim (int): axis representing the time dimension.
        device (str, optional): device type specification (default: None).
            Applies to batches.
        non_blocking (bool, optional): forwarded to `prepare_batch`; if True
            and the copy is between CPU and GPU, the copy may occur
            asynchronously with respect to the host.
        prepare_batch (callable, optional): function that receives `batch`,
            `device`, `non_blocking` and outputs tuple of tensors
            `(batch_x, batch_y)`.

    .. warning::

        The internal use of `device` has changed. `device` will now *only* be
        used to move the input data to the correct device. The `model` should
        be moved by the user before creating an optimizer.

        For more information see:

        * `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        * `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.
    """

    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
        inputs, targets = batch
        chunk_losses = []
        hidden = None

        time_chunks = zip(inputs.split(tbtt_step, dim=dim), targets.split(tbtt_step, dim=dim))
        for chunk in time_chunks:
            x_t, y_t = prepare_batch(chunk, device=device, non_blocking=non_blocking)

            # Signal the start of this time chunk
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)

            model.train()
            optimizer.zero_grad()
            if hidden is not None:
                # Cut the graph so backprop stops at the chunk boundary
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            else:
                y_pred_t, hidden = model(x_t)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Expose the chunk loss on the state so handlers see a consistent output
            loss_value = loss_t.item()
            engine.state.output = loss_value
            chunk_losses.append(loss_value)

            # Signal the end of this time chunk
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # Average loss over the time chunks
        return sum(chunk_losses) / len(chunk_losses)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
def create_engine(
    model,
    loss_fn,
    constraint_fn,
    optimizer=None,
    metrics=None,
    monitor=None,
    guard=True,
    method="unconstrained",
    reduction=None,
    device="cpu",
):
    """Creates an engine with the necessary components. If optimizer is not
    provided, then will run inference

    Every iteration is split into timed sections. At the end of each section
    the elapsed time is stored in ``engine.state.times`` (keyed by the value of
    the section's event) and the matching ``Sub_Batch_Events`` member is fired.
    Intermediate results (batch, model output, losses, constraints,
    multipliers, parameters, gradients, state dicts) are stored on
    ``engine.state``.

    :param model: model to train or evaluate
    :param loss_fn: loss_fn to be used for training or monitored for evaluation
    :param constraint_fn: constraint function to be used for training or
        monitored for evaluation. Called as
        ``constraint_fn(out, xb, model, True)`` and expected to return
        ``(constraints, diagnostics)``
    :param optimizer: optimizer to use to update the model. If not provided,
        then the model weights are not updated
    :param metrics: an optional dictionary of (ignite / pyinsulate.ignite)
        metrics to attach to the engine
    :param monitor: handler to be used for monitoring. Must have an
        .attach(engine) method
    :param guard: whether to perform a check to ensure that the model is
        training
    :param method: method to use for constraining. Should be one of
        "constrained" - compute average (along batch) of constrained update
        "batchwise" - compute constrained update of mean loss with respect
            to all constraints within the batch
        "reduction" - apply reduction before computing constraints. If no
            reduction is specified, will throw error
        "unconstrained" - don't constrain. Used as a control method
        "soft-constrained" - use soft constraints
        "no-loss" - intended entirely for debugging. Ignores the loss
            function entirely and just tries to satisfy the constraints
        "non-projecting" - the sum of "no-loss" and "unconstrained". This
            destroys the exponential convergence guarantee, but should be
            useful for debugging
    :param reduction: reduction to apply to constraints before computing
        constrained loss if method == "reduction"
    :param device: device specification wrapped in ``torch.device`` and passed
        to ``prepare_batch``
    :raises ValueError: (from the iteration function) if ``method`` is not one
        of the names above, or if method == "reduction" and ``reduction`` is
        None
    :returns: an ignite.engine.Engine whose output is (xb, yb, out) for every
        iteration
    """

    def end_section(engine, section_event, section_start_time):
        """End the section, tabulate the time, fire the event, and resume time"""
        engine.state.times[section_event.value] = (perf_counter() - section_start_time)
        engine.fire_event(section_event)
        # Handler time is excluded: the next section starts after the event fired
        return perf_counter()

    def proof_of_constraint_iteration(engine, batch):
        # Lazily create the per-engine bookkeeping attributes on first use
        if not hasattr(engine.state, "last_grounded"):
            engine.state.last_grounded = 0
        if not hasattr(engine.state, "times"):
            setattr(engine.state, "times", dict())
        iteration_start = perf_counter()
        section_start = iteration_start

        # Training mode (with cleared gradients) only when an optimizer is given
        if optimizer is not None:
            model.train()
            optimizer.zero_grad()
        else:
            model.eval()
        engine.state.xb, engine.state.yb = prepare_batch(
            batch, device=torch.device(device))
        section_start = end_section(engine, Sub_Batch_Events.DATA_LOADED, section_start)

        # xb is unpacked, so the model receives each input as its own argument
        engine.state.out = model(*engine.state.xb)
        section_start = end_section(engine, Sub_Batch_Events.FORWARD_PASS_COMPLETED, section_start)

        # NOTE(review): the nesting of the guard section was reconstructed from
        # a whitespace-collapsed source - confirm against the original file.
        if guard:
            # Ensure training isn't failing
            last = getattr(engine.state, "last", None)
            if (last is not None and len(engine.state.out) == len(last)
                    and torch.allclose(engine.state.out, last)):
                print("WARNING! Just outputting same thing!")
                print(f"xb: {[x.cpu() for x in engine.state.xb]}")
                print(f"yb: {engine.state.yb.cpu()}")
                print(f"out: {engine.state.out.cpu()}")
            # Remember this output for the comparison on the next iteration
            engine.state.last = engine.state.out
            if torch.allclose(
                    engine.state.out,
                    engine.state.out.new_zeros(engine.state.out.size()),
            ):
                print("WARNING! Training is failing")
        section_start = end_section(engine, Sub_Batch_Events.GUARD_COMPLETED, section_start)

        engine.state.loss = loss_fn(engine.state.out, engine.state.yb)
        engine.state.mean_loss = torch.mean(engine.state.loss)
        section_start = end_section(engine, Sub_Batch_Events.LOSS_COMPUTED, section_start)

        engine.state.constraints, engine.state.constraints_diagnostics = constraint_fn(
            engine.state.out, engine.state.xb, model,
            True)  # last parameter is to return diagnostics
        section_start = end_section(engine, Sub_Batch_Events.CONSTRAINTS_COMPUTED, section_start)

        # Each branch sets constrained_loss, multipliers and reduced_constraints
        # on the state; branches that call constrain_loss also merge its timing
        # information into engine.state.times.
        if method == "constrained":
            constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss,
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
                # defaults are for this method
            )
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
            engine.state.constrained_loss = torch.mean(constrained_loss)
            engine.state.times.update(multiplier_computation_timing)
        elif method == "batchwise":
            engine.state.constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss,
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
                batchwise=True,
            )
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
            engine.state.times.update(multiplier_computation_timing)
        elif method == "reduction":
            if reduction is None:
                raise ValueError(
                    "Reduction must be specified if method=='reduction'")
            engine.state.constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss,
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
                reduction=reduction,
            )
            engine.state.reduced_constraints = reduction(
                engine.state.constraints)
            engine.state.times.update(multiplier_computation_timing)
        elif method == "soft-constrained":
            engine.state.multipliers = (engine.state.constraints /
                                        engine.state.constraints.numel())
            # Mean loss plus the mean of the squared constraints
            engine.state.constrained_loss = torch.mean(
                engine.state.loss) + torch.mean(
                    engine.state.constraints * engine.state.constraints)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        elif method == "unconstrained":
            # Technically the multipliers are zero, so we set this for consistency
            engine.state.multipliers = engine.state.constraints.new_zeros(
                engine.state.constraints.size())
            engine.state.constrained_loss = torch.mean(engine.state.loss)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        elif method == "no-loss":
            # A zero tensor stands in for the loss, so only constraints matter
            constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss.new_zeros(
                    engine.state.loss.size()).requires_grad_(),
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
            )
            engine.state.constrained_loss = torch.mean(constrained_loss)
            engine.state.times.update(multiplier_computation_timing)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        elif method == "non-projecting":
            # Constraint term computed against a zero loss, then added to the
            # real loss
            correction_term, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss.new_zeros(
                    engine.state.loss.size()).requires_grad_(),
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
            )
            engine.state.constrained_loss = torch.mean(engine.state.loss +
                                                       correction_term)
            engine.state.times.update(multiplier_computation_timing)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        else:
            raise ValueError(f"Method {method} not known. Please respecify")
        section_start = end_section(engine, Sub_Batch_Events.REWEIGHTED_LOSS_COMPUTED, section_start)

        # log the values of the model parameters (without gradients)
        engine.state.model_parameters = (torch.cat(
            [param.view(-1) for param in model.parameters()],
            dim=-1).clone().detach())
        if optimizer is not None:
            engine.state.constrained_loss.backward()
            # attach the gradients
            engine.state.model_parameters_grad = torch.cat(
                [param.grad.view(-1) for param in model.parameters()], dim=-1)
            optimizer.step()
        else:
            engine.state.model_parameters_grad = None
        engine.state.model_state_dict = model.state_dict()
        if optimizer is not None:
            engine.state.optimizer_state_dict = optimizer.state_dict()
        else:
            engine.state.optimizer_state_dict = None
        section_start = end_section(engine, Sub_Batch_Events.OPTIMIZER_STEPPED, section_start)

        if torch.allclose(
                engine.state.constrained_loss,
                engine.state.constrained_loss.new_zeros(
                    engine.state.constrained_loss.size()),
        ):
            print("Constrained loss is zero!")
        # Wall-clock time of the whole iteration, measured from iteration_start
        engine.state.times["total"] = perf_counter() - iteration_start
        return engine.state.xb, engine.state.yb, engine.state.out

    engine = Engine(proof_of_constraint_iteration)
    engine.register_events(*Sub_Batch_Events)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(engine, name)
    if monitor is not None:
        monitor.attach(engine)
    return engine
def attach(self, engine: Engine):
    """Attach this handler to ``engine`` and register the period events.

    The handler itself is called on every completed iteration.  Every member
    of ``PeriodEvents`` is registered on the engine and mapped to the
    ``"iteration"`` state attribute in the class-level ``State.event_to_attr``.
    """
    engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    engine.register_events(*PeriodEvents)

    # Note: this updates the mapping on the State class, not on this engine only.
    event_to_attr = State.event_to_attr
    for period_event in PeriodEvents:
        event_to_attr[period_event] = "iteration"
def attach(self, engine: Engine):
    """Attach this handler to ``engine`` and register the episode events.

    The handler itself is called on every completed iteration.  Every member
    of ``EpisodeEvents`` is registered on the engine, and the three episode
    events below are mapped to the ``"episode"`` state attribute in the
    class-level ``State.event_to_attr``.
    """
    engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    engine.register_events(*EpisodeEvents)

    # Note: this updates the mapping on the State class, not on this engine only.
    episode_counted_events = (
        EpisodeEvents.EPISODE_COMPLETED,
        EpisodeEvents.BOUND_REWARD_REACHED,
        EpisodeEvents.BEST_REWARD_REACHED,
    )
    for episode_event in episode_counted_events:
        State.event_to_attr[episode_event] = "episode"
def get_prepared_engine_for_handlers_profiler(true_event_handler_time):
    """Build a dummy trainer whose handlers each sleep for a fixed time.

    The train step fires ``CUSTOM_STARTED``, sleeps for
    ``true_event_handler_time`` and fires ``CUSTOM_COMPLETED``.  Eleven
    sleeping handlers are attached: one for each of the eight built-in
    events used below, one for each of the two custom events, and one that
    runs once on the first ``EPOCH_STARTED``.

    Args:
        true_event_handler_time: seconds each handler (and the train step)
            sleeps.

    Returns:
        Tuple ``(dummy_trainer, HANDLERS_SLEEP_COUNT, PROCESSING_SLEEP_COUNT)``.
    """
    HANDLERS_SLEEP_COUNT = 11  # number of sleeping handlers attached below
    PROCESSING_SLEEP_COUNT = 3  # NOTE(review): meaning is defined by the callers - confirm

    class CustomEvents(EventEnum):
        CUSTOM_STARTED = "custom_started"
        CUSTOM_COMPLETED = "custom_completed"

    def dummy_train_step(engine, batch):
        engine.fire_event(CustomEvents.CUSTOM_STARTED)
        time.sleep(true_event_handler_time)
        engine.fire_event(CustomEvents.CUSTOM_COMPLETED)

    dummy_trainer = Engine(dummy_train_step)
    dummy_trainer.register_events(*CustomEvents)

    def delay_start(engine):
        time.sleep(true_event_handler_time)

    def delay_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_epoch_start(engine):
        time.sleep(true_event_handler_time)

    def delay_epoch_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_iter_start(engine):
        time.sleep(true_event_handler_time)

    def delay_iter_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_get_batch_started(engine):
        time.sleep(true_event_handler_time)

    def delay_get_batch_completed(engine):
        time.sleep(true_event_handler_time)

    def delay_custom_started(engine):
        time.sleep(true_event_handler_time)

    def delay_custom_completed(engine):
        time.sleep(true_event_handler_time)

    def do_something_once_on_1_epoch():
        time.sleep(true_event_handler_time)

    # Registration order is significant: handlers of one event run in the
    # order in which they were added.
    registrations = (
        (Events.STARTED, delay_start),
        (Events.COMPLETED, delay_complete),
        (Events.EPOCH_STARTED, delay_epoch_start),
        (Events.EPOCH_COMPLETED, delay_epoch_complete),
        (Events.ITERATION_STARTED, delay_iter_start),
        (Events.ITERATION_COMPLETED, delay_iter_complete),
        (Events.GET_BATCH_STARTED, delay_get_batch_started),
        (Events.GET_BATCH_COMPLETED, delay_get_batch_completed),
        (CustomEvents.CUSTOM_STARTED, delay_custom_started),
        (CustomEvents.CUSTOM_COMPLETED, delay_custom_completed),
        (Events.EPOCH_STARTED(once=1), do_something_once_on_1_epoch),
    )
    for event, handler in registrations:
        dummy_trainer.add_event_handler(event, handler)

    return dummy_trainer, HANDLERS_SLEEP_COUNT, PROCESSING_SLEEP_COUNT
def run(train_config, logger, **kwargs):
    """Run a full training session described by ``train_config``.

    The function sets up output folders and loggers (Polyaxon when
    available, TensorBoard, mlflow), optionally initialises Apex AMP,
    loads checkpoints, builds the trainer and the two evaluators,
    registers all handlers, and finally runs the trainer for
    ``num_epochs`` epochs.

    Args:
        train_config: configuration object read through ``getattr``. The
            attributes without a default (``num_epochs``, ``num_classes``,
            the four loaders, ``model``, ``optimizer``, ``criterion``,
            ``consistency_criterion``, ``learning_rate``,
            ``consistency_lambda``, ``TSA_proba_min``, ``TSA_proba_max``)
            are mandatory. Its ``__file__`` is copied into the config
            save directory.
        logger: ignored; it is immediately replaced by
            ``logging.getLogger('UDA')``.
        **kwargs: accepted but not used.

    Raises:
        ImportError: if ``use_fp_16`` is set and ``apex`` cannot be imported.
        AssertionError: if a non-CPU device is requested without CUDA, or
            if the half-precision preconditions are not met.
    """
    # The ``logger`` argument is overwritten by the module-wide 'UDA' logger.
    logger = logging.getLogger('UDA')
    if getattr(train_config, 'debug', False):
        setup_logger(logger, logging.DEBUG)

    # Set Polyaxon environment if needed
    plx_logger = None
    save_dir = None
    output_experiment_path = None
    try:
        plx_logger = PolyaxonLogger()
        experiment = plx_logger.experiment
        save_dir = get_outputs_path()
        output_experiment_path = get_outputs_refs_paths()
        output_experiment_path = (
            output_experiment_path['experiments'][0]
            if output_experiment_path else None)
        logger.debug("Experiment info: {}".format(
            experiment.get_experiment_info()))
    except PolyaxonClientException as e:
        logger.warning('Logger Polyaxon : ' + str(e))

    # Path configuration
    saves_dict = getattr(train_config, 'saves', {})

    save_dir = saves_dict.get('save_dir', '') if save_dir is None else save_dir
    log_dir = os.path.join(save_dir, saves_dict.get('log_dir', ''))
    save_model_dir = os.path.join(save_dir, saves_dict.get('model_dir', ''))
    save_prediction_dir = os.path.join(save_dir,
                                       saves_dict.get('prediction_dir', ''))
    save_config_dir = os.path.join(save_dir, saves_dict.get('config_dir', ''))
    load_model_file = saves_dict.get('load_model_file', '')
    load_optimizer_file = saves_dict.get('load_optimizer_file', '')

    # Create folders
    create_save_folders(save_dir, saves_dict)

    # When a referenced experiment exists, checkpoints are loaded from its
    # output folder instead of the local one.
    if output_experiment_path is not None:
        model_dir = saves_dict.get('model_dir', '')
        load_model_file = os.path.join(
            output_experiment_path, model_dir,
            load_model_file) if load_model_file else None
        load_optimizer_file = os.path.join(
            output_experiment_path, model_dir,
            load_optimizer_file) if load_optimizer_file else None

    num_epochs = getattr(train_config, 'num_epochs')
    num_classes = getattr(train_config, 'num_classes')
    device = getattr(train_config, 'device', 'cpu')

    # Set magical acceleration
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    else:
        assert device == 'cpu', 'CUDA device selected but none is available'

    # Set half precision if required
    use_fp_16 = getattr(train_config, 'use_fp_16', False)

    train1_sup_loader = getattr(train_config, 'train1_sup_loader')
    train1_unsup_loader = getattr(train_config, 'train1_unsup_loader')
    train2_unsup_loader = getattr(train_config, 'train2_unsup_loader')
    test_loader = getattr(train_config, 'test_loader')

    save_interval = saves_dict.get('save_interval', 0)
    n_saved = saves_dict.get('n_saved', 0)

    val_interval = getattr(train_config, 'val_interval', 1)
    pred_interval = getattr(train_config, 'pred_interval', 0)

    model = getattr(train_config, 'model').to(device)
    optimizer = getattr(train_config, 'optimizer')
    criterion = getattr(train_config, 'criterion').to(device)
    consistency_criterion = getattr(train_config,
                                    'consistency_criterion').to(device)

    cm_metric = getattr(
        train_config, 'cm_metric',
        ConfusionMatrix(num_classes=num_classes,
                        output_transform=lambda x: (x['y_pred'], x['y'])))

    # AMP initialization for half precision
    if use_fp_16:
        assert 'cuda' in device
        assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
        try:
            from apex import amp
        except ImportError as err:
            # Only a failed import is translated; a bare ``except`` would
            # also swallow KeyboardInterrupt/SystemExit.
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to run this example."
            ) from err
        # Initialize amp
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # Load checkpoint
    load_params(model,
                optimizer=optimizer,
                model_file=load_model_file,
                optimizer_file=load_optimizer_file,
                device_name=device)

    # Add batch norm
    is_bn = getattr(train_config, 'is_bn', False)
    if is_bn:
        batch_norm = nn.BatchNorm2d(3).to(device)
        if use_fp_16:
            batch_norm = amp.initialize(batch_norm)
        batch_norm.reset_parameters()
        model = nn.Sequential(batch_norm, model)

    # Copy the config file
    shutil.copy2(os.path.abspath(train_config.__file__),
                 os.path.join(save_config_dir, 'checkpoint_module.py'))

    le = len(train1_sup_loader)
    num_train_steps = le * num_epochs
    mlflow.log_param("num train steps", num_train_steps)

    lr = getattr(train_config, 'learning_rate')
    num_warmup_steps = getattr(train_config, 'num_warmup_steps', 0)

    lr_scheduler = getattr(train_config, 'lr_scheduler', None)
    if lr_scheduler is not None:
        # The config provides a factory; bind it to the optimizer.
        lr_scheduler = lr_scheduler(optimizer)

        # Warm-up only makes sense on top of an existing scheduler, so the
        # wrapper is never built around ``None``.
        if num_warmup_steps > 0:
            lr_scheduler = create_lr_scheduler_with_warmup(
                lr_scheduler,
                warmup_start_value=0.0,
                warmup_end_value=lr * (1.0 + 1.0 / num_warmup_steps),
                warmup_duration=num_warmup_steps)

    # The trainer iterates over ``data_steps`` (plain indices); the actual
    # batches are drawn from these endless iterators by the update function.
    train1_sup_loader_iter = cycle(train1_sup_loader)
    train1_unsup_loader_iter = cycle(train1_unsup_loader)
    train2_unsup_loader_iter = cycle(train2_unsup_loader)

    # Reduce on plateau
    reduce_on_plateau = getattr(train_config, 'reduce_on_plateau', None)

    # Output transform model
    output_transform_model = getattr(train_config, 'output_transform_model',
                                     lambda x: x)

    inference_fn = getattr(train_config, 'inference_fn', inference_standard)

    lam = getattr(train_config, 'consistency_lambda')
    beta = getattr(train_config, 'consistency_beta', lam)

    tsa = TrainingSignalAnnealing(
        num_steps=num_train_steps,
        min_threshold=getattr(train_config, 'TSA_proba_min'),
        max_threshold=getattr(train_config, 'TSA_proba_max'))

    with_tsa = getattr(train_config, 'with_TSA', False)

    cfg = {
        'tsa': tsa,
        'lambda': lam,
        'beta': beta,
        'with_tsa': with_tsa,
        'device': device,
        'consistency_criterion': consistency_criterion,
        'criterion': criterion
    }

    trainer = Engine(
        partial(train_update_function,
                model=model,
                optimizer=optimizer,
                cfg=cfg,
                train1_sup_loader_iter=train1_sup_loader_iter,
                train1_unsup_loader_iter=train1_unsup_loader_iter,
                train2_unsup_loader_iter=train2_unsup_loader_iter,
                output_transform_model=output_transform_model,
                use_fp_16=use_fp_16))

    # Register events
    for e in CustomEvents:
        State.event_to_attr[e] = 'iteration'

    trainer.register_events(*CustomEvents)

    if with_tsa:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_tsa, tsa)

    if lr_scheduler is not None:
        # Objects without ``step`` are attached directly as handlers;
        # objects exposing ``step`` are stepped once per iteration.
        if not hasattr(lr_scheduler, "step"):
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED,
                                      lambda engine: lr_scheduler.step())

    trainer.add_event_handler(Events.ITERATION_COMPLETED, log_learning_rate,
                              optimizer)

    metric_names = [
        'supervised batch loss', 'consistency batch loss', 'final batch loss'
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        RunningAverage(
            output_transform=partial(output_transform, name=n)).attach(
                trainer, n)

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)

    # Handlers for Tensorboard logging
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach(trainer,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=metric_names),
                     event_name=CustomEvents.ITERATION_K_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=tbOptimizerParamsHandler(optimizer,
                                                          param_name="lr"),
                     event_name=CustomEvents.ITERATION_K_STARTED)

    # Handlers for Polyaxon logging
    if plx_logger is not None:
        plx_logger.attach(trainer,
                          log_handler=plxOutputHandler(
                              tag="train", metric_names=metric_names),
                          event_name=CustomEvents.ITERATION_K_COMPLETED)

    metrics = {
        'loss': Loss(criterion,
                     output_transform=lambda x: (x['y_pred'], x['y'])),
        'mAcc': cmAccuracy(cm_metric).mean(),
        'mPr': cmPrecision(cm_metric).mean(),
        'mRe': cmRecall(cm_metric).mean(),
        'mIoU': mIoU(cm_metric),
        'mF1': cmFbeta(cm_metric, 1).mean()
    }
    iou = IoU(cm_metric)
    for i in range(num_classes):
        key_name = 'IoU_{}'.format(str(i))
        metrics[key_name] = iou[i]

    inference_update_fn = partial(
        inference_update_function,
        model=model,
        cfg=cfg,
        output_transform_model=output_transform_model,
        inference_fn=inference_fn)

    evaluator = Engine(inference_update_fn)
    train_evaluator = Engine(inference_update_fn)

    for name, metric in metrics.items():
        metric.attach(train_evaluator, name)
        metric.attach(evaluator, name)

    # Add checkpoint
    if save_model_dir:
        checkpoint = ModelCheckpoint(dirname=save_model_dir,
                                     filename_prefix='checkpoint',
                                     save_interval=save_interval,
                                     n_saved=n_saved,
                                     create_dir=True)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint, {
            'mymodel': model,
            'optimizer': optimizer
        })

    def trigger_k_iteration_started(engine, k):
        # Fire the custom event on every k-th iteration.
        if engine.state.iteration % k == 0:
            engine.fire_event(CustomEvents.ITERATION_K_STARTED)

    def trigger_k_iteration_completed(engine, k):
        # Fire the custom event on every k-th iteration.
        if engine.state.iteration % k == 0:
            engine.fire_event(CustomEvents.ITERATION_K_COMPLETED)

    def run_validation(engine, validation_interval):
        # Evaluate on the supervised train loader and on the test loader.
        if (trainer.state.epoch - 1) % validation_interval == 0:
            train_evaluator.run(train1_sup_loader)
            evaluator.run(test_loader)

            if save_prediction_dir:
                train_output = train_evaluator.state.output
                test_output = evaluator.state.output

                iteration = str(trainer.state.iteration)
                epoch = str(trainer.state.epoch)

                save_prediction('train_{}_{}'.format(iteration, epoch),
                                save_prediction_dir,
                                train_output['x'],
                                torch.argmax(
                                    train_output['y_pred'][0, :, :, :],
                                    dim=0),
                                y=train_output['y'][0, :, :])

                save_prediction('test_{}_{}'.format(iteration, epoch),
                                save_prediction_dir,
                                test_output['x'],
                                torch.argmax(test_output['y_pred'][0, :, :, :],
                                             dim=0),
                                y=test_output['y'][0, :, :])

            # Drop the references to the last evaluator outputs.
            train_evaluator.state.output = None
            evaluator.state.output = None

            if reduce_on_plateau is not None:
                reduce_on_plateau.step(evaluator.state.metrics['mIoU'])

    trainer.add_event_handler(Events.ITERATION_STARTED,
                              trigger_k_iteration_started,
                              k=10)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              trigger_k_iteration_completed,
                              k=10)

    trainer.add_event_handler(Events.EPOCH_STARTED,
                              run_validation,
                              validation_interval=val_interval)
    trainer.add_event_handler(Events.COMPLETED,
                              run_validation,
                              validation_interval=1)

    def trainer_prediction_save(engine, prediction_interval):
        # Save the trainer's unsupervised prediction every
        # ``prediction_interval`` iterations.
        if (engine.state.iteration - 1) % prediction_interval == 0:

            if save_prediction_dir:
                trainer_output = trainer.state.output['unsup pred']

                iteration = str(trainer.state.iteration)
                epoch = str(trainer.state.epoch)

                save_prediction('trainer_{}_{}'.format(iteration, epoch),
                                save_prediction_dir, trainer_output['x'],
                                trainer_output['y_pred'])

                logger.debug(
                    'Saved trainer prediction for iteration {}'.format(
                        str(engine.state.iteration)))

            trainer.state.output = None

    # ``pred_interval`` defaults to 0; registering the handler with a zero
    # interval would raise ZeroDivisionError in its modulo check on the
    # first iteration, so a zero interval disables prediction saving.
    if pred_interval:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  trainer_prediction_save,
                                  prediction_interval=pred_interval)

    tb_logger.attach(train_evaluator,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=list(
                                                     metrics.keys())),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator,
                     log_handler=tbOutputHandler(tag="test",
                                                 metric_names=list(
                                                     metrics.keys())),
                     event_name=Events.EPOCH_COMPLETED)

    # Handlers for Polyaxon logging
    if plx_logger is not None:
        plx_logger.attach(train_evaluator,
                          log_handler=plxOutputHandler(tag="train",
                                                       metric_names=list(
                                                           metrics.keys())),
                          event_name=Events.EPOCH_COMPLETED)
        plx_logger.attach(evaluator,
                          log_handler=plxOutputHandler(tag="test",
                                                       metric_names=list(
                                                           metrics.keys())),
                          event_name=Events.EPOCH_COMPLETED)

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mlflow_batch_metrics_logging, "train", trainer)
    train_evaluator.add_event_handler(Events.COMPLETED,
                                      mlflow_val_metrics_logging, "train",
                                      trainer)
    evaluator.add_event_handler(Events.COMPLETED, mlflow_val_metrics_logging,
                                "test", trainer)

    # One "batch" per supervised-loader step; the values are only indices.
    data_steps = list(range(len(train1_sup_loader)))

    logger.debug('Start training')
    trainer.run(data_steps, max_epochs=num_epochs)
    logger.debug('Finished training')
def test_invert(self):
    """Check ``TransformInverter`` with two handlers on one engine.

    The first handler inverts with nearest interpolation and writes
    ``*_inverted1`` entries; the second uses a per-key interpolation
    setting plus ``post_func`` and writes ``*_inverted2`` entries.
    """
    set_determinism(seed=0)
    image_path, label_path = (
        make_nifti_image(volume)
        for volume in create_test_image_3d(101, 100, 107, noise_max=100)
    )

    steps = [
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS,
                 pixdim=(1.2, 1.01, 0.9),
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS,
                  prob=0.5,
                  min_zoom=0.5,
                  max_zoom=1.1,
                  keep_size=True),
        RandRotated(KEYS,
                    prob=0.5,
                    range_x=np.pi,
                    mode="bilinear",
                    align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        # "image" becomes a Tensor while "label" stays a NumPy array, so
        # inversion is exercised on both container types
        ToTensord("image"),
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
    ]
    transform = Compose(steps)
    data = [{"image": image_path, "label": label_path} for _ in range(12)]

    # no worker processes on mac or when gpu transforms are possible
    single_process = sys.platform == "darwin" or torch.cuda.is_available()
    num_workers = 0 if single_process else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    nearest_inverter = TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        postfix="inverted1",
        to_tensor=[True, False],
        device="cpu",
        num_workers=num_workers,
    )
    nearest_inverter.attach(engine)

    # test different nearest interpolation values
    mixed_inverter = TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="image",
        nearest_interp=[True, False],
        post_func=[lambda x: x + 10, lambda x: x],
        postfix="inverted2",
        num_workers=num_workers,
    )
    mixed_inverter.attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)

    output = engine.state.output
    expected_batch_shape = (2, 1, 100, 100, 100)
    inverted_shape = (1, 100, 101, 107)
    self.assertTupleEqual(output["image"].shape, expected_batch_shape)
    self.assertTupleEqual(output["label"].shape, expected_batch_shape)

    # check the nearest interpolation mode
    for inverted in output["image_inverted1"]:
        as_float = inverted.to(torch.float)
        torch.testing.assert_allclose(
            inverted.to(torch.uint8).to(torch.float), as_float)
        self.assertTupleEqual(inverted.shape, inverted_shape)
    for inverted in output["label_inverted1"]:
        np.testing.assert_allclose(
            inverted.astype(np.uint8).astype(np.float32),
            inverted.astype(np.float32))
        self.assertTupleEqual(inverted.shape, inverted_shape)

    # check labels match
    reverted = output["label_inverted1"][-1].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824),
                    "diff. in 3 possible values")

    # check the case that different items use different interpolation mode to invert transforms
    for inverted in output["image_inverted2"]:
        # if the interpolation mode is nearest, accumulated diff should be smaller than 1
        residual = torch.sum(
            inverted.to(torch.float) -
            inverted.to(torch.uint8).to(torch.float)).item()
        self.assertLess(residual, 1.0)
        self.assertTupleEqual(inverted.shape, inverted_shape)

    for inverted in output["label_inverted2"]:
        # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
        residual = torch.sum(
            inverted.to(torch.float) -
            inverted.to(torch.uint8).to(torch.float)).item()
        self.assertGreater(residual, 10000.0)
        self.assertTupleEqual(inverted.shape, inverted_shape)
def test_invert(self):
    """Check ``TransformInverter`` with a single nearest-interpolation handler.

    Both keys are converted to uint8 tensors, the handler writes
    ``image_inverted``/``label_inverted`` entries, and the last inverted
    label is compared with the label loaded from disk.
    """
    set_determinism(seed=0)
    image_path, label_path = (
        make_nifti_image(volume)
        for volume in create_test_image_3d(101, 100, 107, noise_max=100)
    )

    steps = [
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS,
                 pixdim=(1.2, 1.01, 0.9),
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS,
                  prob=0.5,
                  min_zoom=0.5,
                  max_zoom=1.1,
                  keep_size=True),
        RandRotated(KEYS,
                    prob=0.5,
                    range_x=np.pi,
                    mode="bilinear",
                    align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        ToTensord(KEYS),
        CastToTyped(KEYS, dtype=torch.uint8),
    ]
    transform = Compose(steps)
    data = [{"image": image_path, "label": label_path} for _ in range(12)]

    # no worker processes on mac or when gpu transforms are possible
    single_process = sys.platform == "darwin" or torch.cuda.is_available()
    num_workers = 0 if single_process else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    inverter = TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        num_workers=num_workers,
    )
    inverter.attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)

    output = engine.state.output
    expected_batch_shape = (2, 1, 100, 100, 100)
    self.assertTupleEqual(output["image"].shape, expected_batch_shape)
    self.assertTupleEqual(output["label"].shape, expected_batch_shape)

    # inverted values must survive a round trip through uint8 unchanged
    all_inverted = output["image_inverted"] + output["label_inverted"]
    for inverted in all_inverted:
        torch.testing.assert_allclose(
            inverted.to(torch.uint8).to(torch.float),
            inverted.to(torch.float))
        self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

    # check labels match
    last_label = output["label_inverted"][-1]
    reverted = last_label.detach().cpu().numpy()[0].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    self.assertTrue((reverted.size - n_good) in (25300, 1812),
                    "diff. in two possible values")