Example 1
def test_custom_event_with_arg_handlers_profiler():
    true_event_handler_time = 0.1
    true_max_epochs = 1
    true_num_iters = 2

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    dummy_trainer.register_events("custom_event")
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=1))
    def trigger_custom_event():
        dummy_trainer.fire_event("custom_event")

    args = [122, 324]

    @dummy_trainer.on("custom_event", args)
    def on_custom_event(args):
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)
    results = profiler.get_results()
    event_results = None
    for row in results:
        if row[1] == "custom_event":
            event_results = row
            break
    assert event_results is not None
    assert "on_custom_event" in event_results[0]

    assert event_results[2] == approx(true_max_epochs * true_num_iters * true_event_handler_time, abs=1e-1)  # total
    assert event_results[3][0] == approx(true_event_handler_time, abs=1e-1)  # min
    assert event_results[4][0] == approx(true_event_handler_time, abs=1e-1)  # max
    assert event_results[5] == approx(true_event_handler_time, abs=1e-1)  # mean
    assert event_results[6] == approx(0.0, abs=1e-1)  # stddev
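
The test above relies on `Engine.on` accepting extra positional arguments that are bound to the handler and passed back when the event fires. A minimal standalone sketch of that pattern (event and handler names here are illustrative, not from the test suite):

import time

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)
engine.register_events("my_event")

@engine.on(Events.ITERATION_COMPLETED)
def fire_custom(e):
    e.fire_event("my_event")

payload = {"delay": 0.01}

@engine.on("my_event", payload)
def on_my_event(payload):
    # Receives the payload bound at registration time
    time.sleep(payload["delay"])

engine.run(range(2), max_epochs=1)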
Example 2
def test_custom_events():
    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    # Dummy engine
    engine = Engine(lambda engine, batch: 0)
    engine.register_events(*CustomEvents)
    engine.register_events("a", "b", "c")

    evs = [CustomEvents.TEST_EVENT, "a", "b", "c"]

    # Handlers are never called: the events are registered but never fired
    handlers = [(e, MagicMock()) for e in evs]
    for e, h in handlers:
        engine.add_event_handler(e, h)
    engine.run(range(1))
    for _, h in handlers:
        assert not h.called

    # Engine whose process function fires the custom events
    def process_func(engine, batch):
        for e, _ in handlers:
            engine.fire_event(e)

    engine = Engine(process_func)
    engine.register_events(*CustomEvents)
    engine.register_events("a", "b", "c")

    # Handlers should now be called
    handlers = [(e, MagicMock()) for e in evs]
    for e, h in handlers:
        engine.add_event_handler(e, h)
    engine.run(range(1))
    for _, h in handlers:
        assert h.called
Example 3
    def setup_training(self):
        assert self.batch_size is not None
        trainer = Engine(lambda e, b: self.train_step(b))
        trainer.register_events("EVAL_DONE")
        Average(lambda o: o['loss']).attach(trainer, 'avg_loss')
        state_vars = dict(model=self.model, opt=self.opt, trainer=trainer)
        checkpoint_handler = ModelCheckpoint(
            self.run_path, '',
            score_function=lambda e: e.state.metrics['val_accuracy'],
            score_name='val_accuracy', n_saved=2,
            global_step_transform=lambda e, evt_name: e.state.epoch)
        if checkpoint_handler.last_checkpoint:
            checkpoint_handler.load_objects(state_vars, self.run_path / checkpoint_handler.last_checkpoint)
        trainer.add_event_handler("EVAL_DONE", lambda e: checkpoint_handler(e, state_vars))
        if self.use_lr_decay:
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda e: self.lr_decay.step(e.state.iteration * self.batch_size))

        RunningAverage(output_transform=lambda o: o['loss']).attach(trainer, 'running_avg_loss')
        ProgressBar().attach(trainer, ['running_avg_loss'])
        logger.setup_logger(self.run_path, trainer, self.model)

        @trainer.on(Events.EPOCH_COMPLETED)
        def eval_and_log(e: Engine):
            eval_results = self.eval()
            e.state.metrics['val_accuracy'] = eval_results['val'].metrics['accuracy'] 
            e.state.metrics['val_loss'] = eval_results['val'].metrics['avg_loss']
            e.state.eval_results = eval_results
            e.fire_event("EVAL_DONE")

        if self.use_early_stop:
            es = self.make_early_stopper(trainer)
            trainer.add_event_handler("EVAL_DONE", es)

        return trainer
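
The string event "EVAL_DONE" above decouples checkpointing and early stopping from where evaluation happens. A minimal sketch of the same pattern, with a placeholder metric instead of a real evaluation loop:

from ignite.engine import Engine, Events

trainer = Engine(lambda e, b: None)
trainer.register_events("EVAL_DONE")

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(e):
    e.state.metrics["val_accuracy"] = 0.0  # placeholder for a real evaluation
    e.fire_event("EVAL_DONE")

@trainer.on("EVAL_DONE")
def on_eval_done(e):
    # Checkpointing / early stopping would hang off this event
    print("val_accuracy:", e.state.metrics["val_accuracy"])

trainer.run(range(1), max_epochs=1)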
Example 4
def test_custom_events():
    class CustomEvents(Enum):
        TEST_EVENT = "test_event"

    # Dummy engine
    engine = Engine(lambda engine, batch: 0)
    engine.register_events(*CustomEvents)

    # Handler is never called: the event is registered but never fired
    handle = MagicMock()
    engine.add_event_handler(CustomEvents.TEST_EVENT, handle)
    engine.run(range(1))
    assert not handle.called

    # Engine whose process function fires the custom event
    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(process_func)
    engine.register_events(*CustomEvents)

    # Handler should now be called
    handle = MagicMock()
    engine.add_event_handler(CustomEvents.TEST_EVENT, handle)
    engine.run(range(1))
    assert handle.called
Example 5
def create_engine(
    model,
    loss_fn,
    constraint_fn,
    optimizer=None,
    projection=False,
    monitor=None,
    guard=True,
    regularization_weight=0.0,
    error_fn=None,
    device="cpu",
    tolerance=1e-5,
    max_iterations=1e4,
):
    """Creates an engine with the necessary components. If optimizer is not
    provided, then will run inference

    :param model: model to train or evaluate
    :param loss_fn: loss_fn to be used for training or monitored for evaluation
    :param constraint_fn: constraint function to be used for training or
        monitored for evaluation
    :param optimizer: optimizer to use to update the model. If not provided,
        the model weights are not updated (inference mode)
    :param projection: whether to run the projection loop
    :param monitor: handler to be used for monitoring. Must have an
        .attach(engine) method
    :param guard: whether to perform a check to ensure that the model is
        training
    :param regularization_weight: multiplier to use for soft-constraining during
        training. Defaults to 0 for unconstrained
    :param error_fn: error function to use for converting the constraint 
        function to an error function for soft constraining. Defaults to MSE
    :param device: "cuda" or "cpu"
    :returns: an ignite.engine.Engine whose output is (xb, yb, out) for every
        iteration
    """

    if projection:
        iteration_fn = ProjectionLoop
    else:
        iteration_fn = TrainingLoop

    engine = Engine(
        iteration_fn(
            model,
            loss_fn,
            constraint_fn,
            optimizer,
            regularization_weight,
            error_fn,
            device,
        ))
    engine.register_events(*Sub_Batch_Events)

    if monitor is not None:
        monitor.attach(engine)

    return engine
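
The docstring only requires that `monitor` expose an `.attach(engine)` method. A minimal object satisfying that contract might look like this (hypothetical, for illustration only):

from ignite.engine import Engine, Events

class PrintMonitor:
    """Prints the iteration number after every iteration."""

    def attach(self, engine: Engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine: Engine):
        print(f"iteration {engine.state.iteration} completed")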
Example 6
def test_deprecated_callable_events_class():
    engine = Engine(lambda engine, batch: 0)

    with pytest.warns(
            DeprecationWarning,
            match=r"Class ignite\.engine\.events\.CallableEvents is deprecated"
    ):

        class CustomEvents(CallableEvents, Enum):
            TEST_EVENT = "test_event"

        engine.register_events(*CustomEvents)
Example 7
def test_custom_events_asserts():
    # Dummy engine
    engine = Engine(lambda engine, batch: 0)

    class A:
        pass

    with pytest.raises(
            TypeError,
            match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events(None)

    with pytest.raises(
            TypeError,
            match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events("str", None)

    with pytest.raises(
            TypeError,
            match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events(1)

    with pytest.raises(
            TypeError,
            match=r"Value at \d of event_names should be a str or EventEnum"):
        engine.register_events(A())

    assert Events.EPOCH_COMPLETED != 1
    assert Events.EPOCH_COMPLETED != "abc"
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED(every=2)
    # In current implementation, EPOCH_COMPLETED and EPOCH_COMPLETED with event filter are the same
    assert Events.EPOCH_COMPLETED == Events.EPOCH_COMPLETED(every=2)
    assert Events.ITERATION_COMPLETED == Events.ITERATION_COMPLETED(every=2)
Example 8
def test_custom_events_with_events_list():
    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(process_func)
    engine.register_events(*CustomEvents)

    # Handle should be called
    handle = MagicMock()
    engine.add_event_handler(CustomEvents.TEST_EVENT | Events.STARTED, handle)
    engine.run(range(1))
    assert handle.called
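
The `|` operator above builds an ignite EventsList, so a single handler can be attached to several events at once. A minimal sketch using only built-in events:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

@engine.on(Events.STARTED | Events.COMPLETED)
def log_boundary(e):
    print("fired at", e.last_event_name)

engine.run(range(1))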
Example 9
def test_deprecated_callable_events_class():
    engine = Engine(lambda engine, batch: 0)

    with pytest.warns(
            DeprecationWarning,
            match=r"Class ignite\.engine\.events\.CallableEvents is deprecated"
    ):

        class CustomEvents(CallableEvents, Enum):
            TEST_EVENT = "test_event"

        with pytest.raises(
                TypeError,
                match=r"Value at \d of event_names should be a str or EventEnum"
        ):
            engine.register_events(*CustomEvents)
Example 10
def create_trainers(config, model, optimizer, loss_fn, device) -> Tuple[Engine, Engine]:
    """Create Engines for training and evaluation.

    Parameters
    ----------
    config
        config object
    model
        nn.Module model
    optimizer
        torch optimizer
    loss_fn
        nn.Module loss
    device
        device to use for training

    Returns
    -------
    trainer, evaluator
    """
    trainer = Engine(
        lambda e, b: train_function(
            config=config,
            engine=e,
            batch=b,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device
        )
    )
    evaluator = Engine(
        lambda e, b: evaluate_function(
            config=config,
            engine=e,
            batch=b,
            model=model,
            device=device
        )
    )
    trainer.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
    return trainer, evaluator
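
`TrainEvents` and `train_events_to_attr` are defined elsewhere in that codebase; a plausible shape (an assumption, not the original definitions) would be:

from ignite.engine import EventEnum

class TrainEvents(EventEnum):
    BACKWARD_COMPLETED = "backward_completed"
    OPTIM_STEP_COMPLETED = "optim_step_completed"

# Maps each event to a state attribute updated when the event fires
train_events_to_attr = {
    TrainEvents.BACKWARD_COMPLETED: "backward_completed",
    TrainEvents.OPTIM_STEP_COMPLETED: "optim_step_completed",
}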
Example 11
    def _test(event_name, event_attr, true_num_calls):

        def update_fn(engine, batch):
            engine.state.test_event = engine.state.iteration
            engine.fire_event(CustomEvents.TEST_EVENT)

        engine = Engine(update_fn)
        engine.register_events(*CustomEvents, event_to_attr=event_to_attr)

        num_calls = [0, ]

        @engine.on(event_name(event_filter=custom_event_filter))
        def assert_on_special_event(engine):
            assert getattr(engine.state, event_attr) == special_events.pop(0)
            num_calls[0] += 1

        d = list(range(50))
        engine.run(d, max_epochs=25)

        assert num_calls[0] == true_num_calls
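
The fixture surrounding this helper is not shown; a plausible reconstruction (names and values are illustrative assumptions) is:

from ignite.engine import EventEnum, Events

class CustomEvents(EventEnum):
    TEST_EVENT = "test_event"

event_to_attr = {CustomEvents.TEST_EVENT: "test_event"}

# State-attribute values at which the filtered handler should fire (assumed)
special_events = [1, 2, 5, 7, 17, 20]

def custom_event_filter(engine, event):
    return event in special_events

# e.g. _test(Events.ITERATION_STARTED, "iteration", true_num_calls=6)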
Example 12
def test_custom_events_with_event_to_attr():
    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    custom_event_to_attr = {CustomEvents.TEST_EVENT: "test_event"}

    # Dummy engine
    engine = Engine(lambda engine, batch: 0)
    engine.register_events(*CustomEvents, event_to_attr=custom_event_to_attr)

    # Handler is never called, but the state attribute is still initialized
    handle = MagicMock()
    engine.add_event_handler(CustomEvents.TEST_EVENT, handle)
    engine.run(range(1))
    assert hasattr(engine.state, "test_event")
    assert engine.state.test_event == 0

    # Advanced engine
    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(process_func)
    engine.register_events(*CustomEvents, event_to_attr=custom_event_to_attr)

    def handle(engine):
        engine.state.test_event += 1

    engine.add_event_handler(CustomEvents.TEST_EVENT, handle)
    engine.run(range(25))
    assert engine.state.test_event == 25

    custom_event_to_attr = "a"
    engine = Engine(lambda engine, batch: 0)
    with pytest.raises(ValueError):
        engine.register_events(*CustomEvents, event_to_attr=custom_event_to_attr)
Example 13
def create_supervised_tbptt_trainer(model,
                                    optimizer,
                                    loss_fn,
                                    tbtt_step,
                                    dim=0,
                                    device=None,
                                    non_blocking=False,
                                    prepare_batch=_prepare_batch):
    """Create a trainer for truncated backprop through time supervised models.

    Training a recurrent model on long sequences is computationally intensive,
    as it requires processing the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    (`X to many <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`_),
    there is an opportunity to compute a gradient over a subsequence. This is
    known as
    `truncated backpropagation through time <https://machinelearningmastery.com/
    gentle-introduction-backpropagation-time/>`_.
    This supervised trainer applies a gradient optimization step every
    `tbtt_step` time steps of the sequence, backpropagating through those same
    `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        tbtt_step (int): the length of time chunks (last one may be smaller)
        dim (int): axis representing the time dimension
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (Callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.

    Returns:
        Engine: a trainer engine with supervised update function

    """
    if device:
        model.to(device)

    def _update(engine, batch):
        loss_list = []
        hidden = None

        x, y = batch
        for batch_t in zip(x.split(tbtt_step, dim=dim),
                           y.split(tbtt_step, dim=dim)):
            x_t, y_t = prepare_batch(batch_t,
                                     device=device,
                                     non_blocking=non_blocking)
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)
            # Forward, backward and optimization step
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
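
A sketch of hooking the per-chunk events this trainer fires, assuming the upstream `ignite.contrib.engines.tbptt` import path:

from ignite.contrib.engines.tbptt import Tbptt_Events, create_supervised_tbptt_trainer

def log_time_chunk(engine):
    # engine.state.output holds the loss of the chunk just processed
    print("time chunk loss:", engine.state.output)

# With model / optimizer / loss_fn defined elsewhere:
# trainer = create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step=4)
# trainer.add_event_handler(Tbptt_Events.TIME_ITERATION_COMPLETED, log_time_chunk)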
Example 14
def set_image_classification_trainer(model, optimizer, criterion, device,
                                     loaders, loggers):
    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y).mean()
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    loggers['progress_bar'].attach(trainer, metric_names='all')

    def validation_step(engine, batch):
        model.eval()
        with torch.no_grad():
            x, target = batch[0].to(device), batch[1].to(device)
            y = model(x)
            return {'y_pred': y, 'y': target, 'criterion_kwargs': {}}

    evaluator = Engine(validation_step)
    evaluator.state.validation_completed = 0
    evaluator.register_events(*EvaluatorEvents, event_to_attr=event_to_attr)

    metrics = {
        'loss': Loss(criterion),
        'F1': Fbeta(beta=1, average=False),
        'mA': Accuracy(is_multilabel=False),
        'mP': Precision(average=False, is_multilabel=False),
        'mR': Recall(average=False, is_multilabel=False)
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=250),
                              log_training_loss, loggers)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        with evaluator.add_event_handler(Events.COMPLETED, log_results,
                                         'train', engine.state.epoch, loggers):
            evaluator.run(loaders['train'])
        with evaluator.add_event_handler(Events.COMPLETED, log_results,
                                         'validation', engine.state.epoch,
                                         loggers):
            evaluator.run(loaders['validation'])
            evaluator.state.validation_completed += 1
            evaluator.fire_event(EvaluatorEvents.VALIDATION_COMPLETED)

    @trainer.on(Events.COMPLETED)
    def test(engine):
        with evaluator.add_event_handler(
                Events.COMPLETED, log_results, 'test', engine.state.epoch,
                loggers), evaluator.add_event_handler(
                    Events.COMPLETED,
                    log_calibration_results,
                    'test',
                    loggers,
                    output_transform=lambda output: {
                        'y_pred': F.softmax(output['y_pred'], dim=1),
                        'y': output['y']
                    }):
            evaluator.run(loaders['test'])

    return trainer, evaluator
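
The `with evaluator.add_event_handler(...)` blocks above work because `add_event_handler` returns a RemovableEventHandle that detaches the handler on exit. A minimal sketch:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

def log_done(e):
    print("run completed")

with engine.add_event_handler(Events.COMPLETED, log_done):
    engine.run(range(1))  # handler fires here
engine.run(range(1))      # handler has been removed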
Example 15
def create_supervised_tbptt_trainer(model,
                                    optimizer,
                                    loss_fn,
                                    tbtt_step,
                                    dim=0,
                                    device=None):
    """Create a trainer for truncated backprop through time supervised models.

    Training a recurrent model on long sequences is computationally intensive,
    as it requires processing the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    ([X to many](https://karpathy.github.io/2015/05/21/rnn-effectiveness/)),
    there is an opportunity to compute a gradient over a subsequence. This is
    known as
    [truncated backpropagation through time](
    https://machinelearningmastery.com/gentle-introduction-backpropagation-time/
    ).
    This supervised trainer applies a gradient optimization step every
    `tbtt_step` time steps of the sequence, backpropagating through those same
    `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        tbtt_step (int): the length of time chunks (last one may be smaller)
        dim (int): axis representing the time dimension
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function

    """
    if device:
        model.to(device)

    def _update(engine, batch):
        loss_list = []
        hidden = None

        # Batches split in time chunks
        batch_splits = _prepare_tbptt_batch(batch,
                                            tbtt_step,
                                            dim=dim,
                                            device=device)
        for x_t, y_t in batch_splits:
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)
            # Forward, backward and optimization step
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
Example 16
def create_train_and_validation_engines(train_func,
                                        val_func=None,
                                        device='cpu'):
    """
    Helper function for creating ignite Engine objects with helpful defaults.
    This sets up a trainer (and an optional validator) with four handlers
    attached:

    - prepare_batch: before a batch is passed to train_func or val_func, this
      function runs, moving every item in the batch (which is a dictionary) to
      the appropriate device ('cpu' or 'cuda').

    - book_keeping: sets up some dictionaries that are used for bookkeeping so one
      can easily track the epoch and iteration losses for both training and
      validation.

    - add_to_iter_history: records the iteration, epoch, and past iteration losses
      into the dictionaries set up by book_keeping.

    - clear_iter_history: resets the current iteration history of losses after moving
      the current iteration history into past iteration history.
    
    Args:
        train_func (func): Function that provides the closure for training for
          a single batch.
        val_func (func, optional): Function that provides the closure for
          validating a single batch. Defaults to None.
        device (str, optional): Device to move tensors to. Defaults to 'cpu'.
    """
    # Set up engines for training and validation
    trainer = Engine(train_func)
    trainer.register_events(*ValidationEvents)
    trainer.register_events(*BackwardsEvents)

    validator = None if val_func is None else Engine(val_func)

    # Before a batch starts, the items should be float and moved to the
    # correct device, for both training and validation. Checks to make
    # sure "cuda" is available if user requested cuda.
    device = device if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    def prepare_batch(engine):
        batch = engine.state.batch
        for key in batch:
            if torch.is_tensor(batch[key]):
                batch[key] = batch[key].float().to(device)
        engine.state.batch = batch

    # Set up stuff for bookkeeping as training progresses.
    def book_keeping(engine):
        engine.state.epoch_history = {}
        engine.state.iter_history = {}
        engine.state.past_iter_history = {}

    def add_to_iter_history(engine):
        for key in engine.state.output:
            if key not in engine.state.iter_history:
                engine.state.iter_history[key] = []
            if key not in engine.state.past_iter_history:
                engine.state.past_iter_history[key] = []
            engine.state.iter_history[key].append(engine.state.output[key])
            engine.state.past_iter_history[key].append(
                engine.state.iter_history[key])

    def clear_iter_history(engine):
        engine.state.iter_history = {}

    trainer.add_event_handler(Events.ITERATION_STARTED, prepare_batch)
    trainer.add_event_handler(Events.STARTED, book_keeping)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, add_to_iter_history)
    trainer.add_event_handler(Events.EPOCH_STARTED, clear_iter_history)

    if validator is not None:
        validator.add_event_handler(Events.ITERATION_STARTED, prepare_batch)
        validator.add_event_handler(Events.STARTED, book_keeping)
        validator.add_event_handler(Events.ITERATION_COMPLETED,
                                    add_to_iter_history)
        validator.add_event_handler(Events.EPOCH_STARTED, clear_iter_history)

    return trainer, validator
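
`ValidationEvents` and `BackwardsEvents` come from the surrounding library; a plausible shape (an assumption, not the library's actual code) would be:

from ignite.engine import EventEnum

class ValidationEvents(EventEnum):
    VALIDATION_STARTED = "validation_started"
    VALIDATION_COMPLETED = "validation_completed"

class BackwardsEvents(EventEnum):
    BACKWARDS_COMPLETED = "backwards_completed"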
Example 17
def create_supervised_tbptt_trainer(
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: nn.Module,
    tbtt_step: int,
    dim: int = 0,
    device: Optional[str] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
):
    """Create a trainer for truncated backprop through time supervised models.

    Training a recurrent model on long sequences is computationally intensive,
    as it requires processing the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    (`X to many <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`_),
    there is an opportunity to compute a gradient over a subsequence. This is
    known as
    `truncated backpropagation through time <https://machinelearningmastery.com/
    gentle-introduction-backpropagation-time/>`_.
    This supervised trainer applies a gradient optimization step every
    `tbtt_step` time steps of the sequence, backpropagating through those same
    `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        tbtt_step (int): the length of time chunks (last one may be smaller).
        dim (int): axis representing the time dimension.
        device (str, optional): device type specification (default: None).
            Applies to batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.

        For more information see:

        * `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        * `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.

    """
    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
        loss_list = []
        hidden = None

        x, y = batch
        for batch_t in zip(x.split(tbtt_step, dim=dim),
                           y.split(tbtt_step, dim=dim)):
            x_t, y_t = prepare_batch(batch_t,
                                     device=device,
                                     non_blocking=non_blocking)
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)
            # Forward, backward and optimization step
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
Example 18
def create_engine(
    model,
    loss_fn,
    constraint_fn,
    optimizer=None,
    metrics=None,
    monitor=None,
    guard=True,
    method="unconstrained",
    reduction=None,
    device="cpu",
):
    """Creates an engine with the necessary components. If optimizer is not
    provided, then will run inference

    :param model: model to train or evaluate
    :param loss_fn: loss_fn to be used for training or monitored for evaluation
    :param constraint_fn: constraint function to be used for training or
        monitored for evaluation
    :param optimizer: optimizer to use to update the model. If not provided,
        then the model weights are not updated
    :param metrics: an optional dictionary of (ignite / pyinsulate.ignite)
        metrics to attach to the engine
    :param monitor: handler to be used for monitoring. Must have an
        .attach(engine) method
    :param guard: whether to perform a check to ensure that the model is
        training
    :param method: method to use for constraining. Should be one of
        "constrained" - compute average (along batch) of constrained update
        "batchwise" - compute constrained update of mean loss with respect to
            all constraints within the batch
        "reduction" - apply reduction before computing constraints. If no 
            reduction is specified, will throw error
        "unconstrained" - don't constrain. Used as a control method
        "soft-constrained" - use soft constraints
        "no-loss" - intended entirely for debugging. Ignores the loss function
            entirely and just tries to satisfy the constraints
        "non-projecting" - the sum of "no-loss" and "unconstrained". This 
            destroys the exponential convergence guarantee, but should be useful
            for debugging
    :param reduction: reduction to apply to constraints before computing 
        constrained loss if method == "reduction"
    :returns: an ignite.engine.Engine whose output is (xb, yb, out) for every
        iteration
    """
    def end_section(engine, section_event, section_start_time):
        """End the section, tabulate the time, fire the event, and resume time"""
        engine.state.times[section_event.value] = (perf_counter() -
                                                   section_start_time)
        engine.fire_event(section_event)
        return perf_counter()

    def proof_of_constraint_iteration(engine, batch):

        if not hasattr(engine.state, "last_grounded"):
            engine.state.last_grounded = 0
        if not hasattr(engine.state, "times"):
            setattr(engine.state, "times", dict())

        iteration_start = perf_counter()
        section_start = iteration_start
        if optimizer is not None:
            model.train()
            optimizer.zero_grad()
        else:
            model.eval()
        engine.state.xb, engine.state.yb = prepare_batch(
            batch, device=torch.device(device))

        section_start = end_section(engine, Sub_Batch_Events.DATA_LOADED,
                                    section_start)

        engine.state.out = model(*engine.state.xb)
        section_start = end_section(engine,
                                    Sub_Batch_Events.FORWARD_PASS_COMPLETED,
                                    section_start)

        if guard:
            # Ensure training isn't failing
            last = getattr(engine.state, "last", None)
            if (last is not None and len(engine.state.out) == len(last)
                    and torch.allclose(engine.state.out, last)):
                print("WARNING! Just outputting same thing!")
                print(f"xb: {[x.cpu() for x in engine.state.xb]}")
                print(f"yb: {engine.state.yb.cpu()}")
                print(f"out: {engine.state.out.cpu()}")
            engine.state.last = engine.state.out
            if torch.allclose(
                    engine.state.out,
                    engine.state.out.new_zeros(engine.state.out.size()),
            ):
                print("WARNING! Training is failing")
        section_start = end_section(engine, Sub_Batch_Events.GUARD_COMPLETED,
                                    section_start)

        engine.state.loss = loss_fn(engine.state.out, engine.state.yb)
        engine.state.mean_loss = torch.mean(engine.state.loss)
        section_start = end_section(engine, Sub_Batch_Events.LOSS_COMPUTED,
                                    section_start)

        engine.state.constraints, engine.state.constraints_diagnostics = constraint_fn(
            engine.state.out, engine.state.xb, model,
            True)  # last parameter is to return diagnostics

        section_start = end_section(engine,
                                    Sub_Batch_Events.CONSTRAINTS_COMPUTED,
                                    section_start)

        if method == "constrained":
            constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss,
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
                # defaults are for this method
            )
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
            engine.state.constrained_loss = torch.mean(constrained_loss)
            engine.state.times.update(multiplier_computation_timing)
        elif method == "batchwise":
            engine.state.constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss,
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
                batchwise=True,
            )
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
            engine.state.times.update(multiplier_computation_timing)
        elif method == "reduction":
            if reduction is None:
                raise ValueError(
                    "Reduction must be specified if method=='reduction'")
            engine.state.constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss,
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
                reduction=reduction,
            )
            engine.state.reduced_constraints = reduction(
                engine.state.constraints)
            engine.state.times.update(multiplier_computation_timing)
        elif method == "soft-constrained":
            engine.state.multipliers = (engine.state.constraints /
                                        engine.state.constraints.numel())
            engine.state.constrained_loss = torch.mean(
                engine.state.loss) + torch.mean(
                    engine.state.constraints * engine.state.constraints)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        elif method == "unconstrained":
            # Technically the multipliers are zero, so we set this for consistency
            engine.state.multipliers = engine.state.constraints.new_zeros(
                engine.state.constraints.size())
            engine.state.constrained_loss = torch.mean(engine.state.loss)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        elif method == "no-loss":
            constrained_loss, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss.new_zeros(
                    engine.state.loss.size()).requires_grad_(),
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
            )
            engine.state.constrained_loss = torch.mean(constrained_loss)
            engine.state.times.update(multiplier_computation_timing)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        elif method == "non-projecting":
            correction_term, engine.state.multipliers, multiplier_computation_timing = constrain_loss(
                engine.state.loss.new_zeros(
                    engine.state.loss.size()).requires_grad_(),
                engine.state.constraints,
                list(model.parameters()),
                return_multipliers=True,
                return_timing=True,
            )
            engine.state.constrained_loss = torch.mean(engine.state.loss +
                                                       correction_term)
            engine.state.times.update(multiplier_computation_timing)
            engine.state.reduced_constraints = engine.state.constraints.new_zeros(
                1)
        else:
            raise ValueError(f"Method {method} not known. Please respecify")

        section_start = end_section(engine,
                                    Sub_Batch_Events.REWEIGHTED_LOSS_COMPUTED,
                                    section_start)

        # log the values of the model parameters (without gradients)
        engine.state.model_parameters = (torch.cat(
            [param.view(-1) for param in model.parameters()],
            dim=-1).clone().detach())
        if optimizer is not None:
            engine.state.constrained_loss.backward()
            # attach the gradients
            engine.state.model_parameters_grad = torch.cat(
                [param.grad.view(-1) for param in model.parameters()], dim=-1)
            optimizer.step()
        else:
            engine.state.model_parameters_grad = None
        engine.state.model_state_dict = model.state_dict()
        if optimizer is not None:
            engine.state.optimizer_state_dict = optimizer.state_dict()
        else:
            engine.state.optimizer_state_dict = None
        section_start = end_section(engine, Sub_Batch_Events.OPTIMIZER_STEPPED,
                                    section_start)

        if torch.allclose(
                engine.state.constrained_loss,
                engine.state.constrained_loss.new_zeros(
                    engine.state.constrained_loss.size()),
        ):
            print("Constrained loss is zero!")

        engine.state.times["total"] = perf_counter() - iteration_start
        return engine.state.xb, engine.state.yb, engine.state.out

    engine = Engine(proof_of_constraint_iteration)
    engine.register_events(*Sub_Batch_Events)

    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(engine, name)

    if monitor is not None:
        monitor.attach(engine)

    return engine
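
Since every section records its duration in engine.state.times before the corresponding event fires, the timings can be consumed from a handler. A small sketch (the attachment line is commented out because Sub_Batch_Events is defined elsewhere):

def report_section_times(engine):
    # engine.state.times maps section-event values to elapsed seconds
    for section, seconds in engine.state.times.items():
        print(f"{section}: {seconds:.4f}s")

# engine.add_event_handler(Sub_Batch_Events.OPTIMIZER_STEPPED, report_section_times)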
Example 19
    def attach(self, engine: Engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
        engine.register_events(*PeriodEvents)
        for e in PeriodEvents:
            State.event_to_attr[e] = "iteration"
Example 20
    def attach(self, engine: Engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
        engine.register_events(*EpisodeEvents)
        State.event_to_attr[EpisodeEvents.EPISODE_COMPLETED] = "episode"
        State.event_to_attr[EpisodeEvents.BOUND_REWARD_REACHED] = "episode"
        State.event_to_attr[EpisodeEvents.BEST_REWARD_REACHED] = "episode"
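
Both attach methods above mutate ignite's global State.event_to_attr directly. The same mapping can instead be passed to register_events via its event_to_attr argument, avoiding direct manipulation of State internals; a sketch with an illustrative enum:

from ignite.engine import Engine, EventEnum

class EpisodeEvents(EventEnum):
    EPISODE_COMPLETED = "episode_completed"

engine = Engine(lambda e, b: None)
engine.register_events(
    *EpisodeEvents,
    event_to_attr={EpisodeEvents.EPISODE_COMPLETED: "episode"})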
Example 21
def get_prepared_engine_for_handlers_profiler(true_event_handler_time):
    HANDLERS_SLEEP_COUNT = 11
    PROCESSING_SLEEP_COUNT = 3

    class CustomEvents(EventEnum):
        CUSTOM_STARTED = "custom_started"
        CUSTOM_COMPLETED = "custom_completed"

    def dummy_train_step(engine, batch):
        engine.fire_event(CustomEvents.CUSTOM_STARTED)
        time.sleep(true_event_handler_time)
        engine.fire_event(CustomEvents.CUSTOM_COMPLETED)

    dummy_trainer = Engine(dummy_train_step)
    dummy_trainer.register_events(*CustomEvents)

    @dummy_trainer.on(Events.STARTED)
    def delay_start(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.COMPLETED)
    def delay_complete(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.EPOCH_STARTED)
    def delay_epoch_start(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.EPOCH_COMPLETED)
    def delay_epoch_complete(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.ITERATION_STARTED)
    def delay_iter_start(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.ITERATION_COMPLETED)
    def delay_iter_complete(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.GET_BATCH_STARTED)
    def delay_get_batch_started(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.GET_BATCH_COMPLETED)
    def delay_get_batch_completed(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(CustomEvents.CUSTOM_STARTED)
    def delay_custom_started(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(CustomEvents.CUSTOM_COMPLETED)
    def delay_custom_completed(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.EPOCH_STARTED(once=1))
    def do_something_once_on_1_epoch():
        time.sleep(true_event_handler_time)

    return dummy_trainer, HANDLERS_SLEEP_COUNT, PROCESSING_SLEEP_COUNT
Example 22
def run(train_config, logger, **kwargs):

    logger = logging.getLogger('UDA')
    if getattr(train_config, 'debug', False):
        setup_logger(logger, logging.DEBUG)

    # Set Polyaxon environment if needed
    plx_logger = None
    save_dir = None
    output_experiment_path = None
    try:
        plx_logger = PolyaxonLogger()
        experiment = plx_logger.experiment
        save_dir = get_outputs_path()
        output_experiment_path = get_outputs_refs_paths()
        output_experiment_path = (output_experiment_path['experiments'][0]
                                  if output_experiment_path else None)
        logger.debug("Experiment info: {}".format(
            experiment.get_experiment_info()))
    except PolyaxonClientException as e:
        logger.warning('Logger Polyaxon : ' + str(e))

    # Path configuration
    saves_dict = getattr(train_config, 'saves', {})

    save_dir = saves_dict.get('save_dir', '') if save_dir is None else save_dir
    log_dir = os.path.join(save_dir, saves_dict.get('log_dir', ''))
    save_model_dir = os.path.join(save_dir, saves_dict.get('model_dir', ''))
    save_prediction_dir = os.path.join(save_dir,
                                       saves_dict.get('prediction_dir', ''))
    save_config_dir = os.path.join(save_dir, saves_dict.get('config_dir', ''))
    load_model_file = saves_dict.get('load_model_file', '')
    load_optimizer_file = saves_dict.get('load_optimizer_file', '')

    # Create folders
    create_save_folders(save_dir, saves_dict)

    if output_experiment_path is not None:
        model_dir = saves_dict.get('model_dir', '')
        load_model_file = os.path.join(
            output_experiment_path, model_dir,
            load_model_file) if load_model_file else None
        load_optimizer_file = os.path.join(
            output_experiment_path, model_dir,
            load_optimizer_file) if load_optimizer_file else None

    num_epochs = getattr(train_config, 'num_epochs')
    num_classes = getattr(train_config, 'num_classes')
    device = getattr(train_config, 'device', 'cpu')

    # Set magical acceleration
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    else:
        assert device == 'cpu', 'CUDA device selected but none is available'

    # Set half precision if required
    use_fp_16 = getattr(train_config, 'use_fp_16', False)

    train1_sup_loader = getattr(train_config, 'train1_sup_loader')
    train1_unsup_loader = getattr(train_config, 'train1_unsup_loader')
    train2_unsup_loader = getattr(train_config, 'train2_unsup_loader')
    test_loader = getattr(train_config, 'test_loader')

    save_interval = saves_dict.get('save_interval', 0)
    n_saved = saves_dict.get('n_saved', 0)

    val_interval = getattr(train_config, 'val_interval', 1)
    pred_interval = getattr(train_config, 'pred_interval', 0)

    model = getattr(train_config, 'model').to(device)

    optimizer = getattr(train_config, 'optimizer')

    criterion = getattr(train_config, 'criterion').to(device)
    consistency_criterion = getattr(train_config,
                                    'consistency_criterion').to(device)

    cm_metric = getattr(
        train_config, 'cm_metric',
        ConfusionMatrix(num_classes=num_classes,
                        output_transform=lambda x: (x['y_pred'], x['y'])))

    # AMP initialization for half precision
    if use_fp_16:
        assert 'cuda' in device
        assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to run this example."
            )
        # Initialize amp
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # Load checkpoint
    load_params(model,
                optimizer=optimizer,
                model_file=load_model_file,
                optimizer_file=load_optimizer_file,
                device_name=device)

    # Add batch norm
    is_bn = getattr(train_config, 'is_bn', False)
    if is_bn:
        batch_norm = nn.BatchNorm2d(3).to(device)
        if use_fp_16:
            batch_norm = amp.initialize(batch_norm)
        batch_norm.reset_parameters()
        model = nn.Sequential(batch_norm, model)

    # Copy the config file
    shutil.copy2(os.path.abspath(train_config.__file__),
                 os.path.join(save_config_dir, 'checkpoint_module.py'))

    le = len(train1_sup_loader)
    num_train_steps = le * num_epochs
    mlflow.log_param("num train steps", num_train_steps)

    lr = getattr(train_config, 'learning_rate')
    num_warmup_steps = getattr(train_config, 'num_warmup_steps', 0)

    lr_scheduler = getattr(train_config, 'lr_scheduler', None)
    if lr_scheduler is not None:
        lr_scheduler = lr_scheduler(optimizer)

    if num_warmup_steps > 0:
        lr_scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=lr * (1.0 + 1.0 / num_warmup_steps),
            warmup_duration=num_warmup_steps)

    train1_sup_loader_iter = cycle(train1_sup_loader)
    train1_unsup_loader_iter = cycle(train1_unsup_loader)
    train2_unsup_loader_iter = cycle(train2_unsup_loader)

    # Reduce on plateau
    reduce_on_plateau = getattr(train_config, 'reduce_on_plateau', None)

    # Output transform model
    output_transform_model = getattr(train_config, 'output_transform_model',
                                     lambda x: x)

    inference_fn = getattr(train_config, 'inference_fn', inference_standard)

    lam = getattr(train_config, 'consistency_lambda')
    beta = getattr(train_config, 'consistency_beta', lam)

    tsa = TrainingSignalAnnealing(
        num_steps=num_train_steps,
        min_threshold=getattr(train_config, 'TSA_proba_min'),
        max_threshold=getattr(train_config, 'TSA_proba_max'))

    with_tsa = getattr(train_config, 'with_TSA', False)

    cfg = {
        'tsa': tsa,
        'lambda': lam,
        'beta': beta,
        'with_tsa': with_tsa,
        'device': device,
        'consistency_criterion': consistency_criterion,
        'criterion': criterion
    }

    trainer = Engine(
        partial(train_update_function,
                model=model,
                optimizer=optimizer,
                cfg=cfg,
                train1_sup_loader_iter=train1_sup_loader_iter,
                train1_unsup_loader_iter=train1_unsup_loader_iter,
                train2_unsup_loader_iter=train2_unsup_loader_iter,
                output_transform_model=output_transform_model,
                use_fp_16=use_fp_16))

    # Register events
    for e in CustomEvents:
        State.event_to_attr[e] = 'iteration'

    trainer.register_events(*CustomEvents)

    if with_tsa:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_tsa, tsa)

    if lr_scheduler is not None:
        if not hasattr(lr_scheduler, "step"):
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED,
                                      lambda engine: lr_scheduler.step())

    trainer.add_event_handler(Events.ITERATION_COMPLETED, log_learning_rate,
                              optimizer)

    metric_names = [
        'supervised batch loss', 'consistency batch loss', 'final batch loss'
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        RunningAverage(
            output_transform=partial(output_transform, name=n)).attach(
                trainer, n)

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)

    # Handlers for Tensorboard logging
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach(trainer,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=metric_names),
                     event_name=CustomEvents.ITERATION_K_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=tbOptimizerParamsHandler(optimizer,
                                                          param_name="lr"),
                     event_name=CustomEvents.ITERATION_K_STARTED)

    # Handlers for Polyaxon logging
    if plx_logger is not None:
        plx_logger.attach(trainer,
                          log_handler=plxOutputHandler(
                              tag="train", metric_names=metric_names),
                          event_name=CustomEvents.ITERATION_K_COMPLETED)

    metrics = {
        'loss': Loss(criterion,
                     output_transform=lambda x: (x['y_pred'], x['y'])),
        'mAcc': cmAccuracy(cm_metric).mean(),
        'mPr': cmPrecision(cm_metric).mean(),
        'mRe': cmRecall(cm_metric).mean(),
        'mIoU': mIoU(cm_metric),
        'mF1': cmFbeta(cm_metric, 1).mean()
    }
    iou = IoU(cm_metric)
    for i in range(num_classes):
        key_name = 'IoU_{}'.format(str(i))
        metrics[key_name] = iou[i]

    inference_update_fn = partial(
        inference_update_function,
        model=model,
        cfg=cfg,
        output_transform_model=output_transform_model,
        inference_fn=inference_fn)

    evaluator = Engine(inference_update_fn)
    train_evaluator = Engine(inference_update_fn)

    for name, metric in metrics.items():
        metric.attach(train_evaluator, name)
        metric.attach(evaluator, name)

    # Add checkpoint
    if save_model_dir:
        checkpoint = ModelCheckpoint(dirname=save_model_dir,
                                     filename_prefix='checkpoint',
                                     save_interval=save_interval,
                                     n_saved=n_saved,
                                     create_dir=True)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint, {
            'mymodel': model,
            'optimizer': optimizer
        })

    def trigger_k_iteration_started(engine, k):
        if engine.state.iteration % k == 0:
            engine.fire_event(CustomEvents.ITERATION_K_STARTED)

    def trigger_k_iteration_completed(engine, k):
        if engine.state.iteration % k == 0:
            engine.fire_event(CustomEvents.ITERATION_K_COMPLETED)

    def run_validation(engine, validation_interval):
        if (trainer.state.epoch - 1) % validation_interval == 0:
            train_evaluator.run(train1_sup_loader)
            evaluator.run(test_loader)

            if save_prediction_dir:
                train_output = train_evaluator.state.output
                test_output = evaluator.state.output

                iteration = str(trainer.state.iteration)
                epoch = str(trainer.state.epoch)

                save_prediction('train_{}_{}'.format(iteration, epoch),
                                save_prediction_dir,
                                train_output['x'],
                                torch.argmax(
                                    train_output['y_pred'][0, :, :, :], dim=0),
                                y=train_output['y'][0, :, :])

                save_prediction('test_{}_{}'.format(iteration, epoch),
                                save_prediction_dir,
                                test_output['x'],
                                torch.argmax(test_output['y_pred'][0, :, :, :],
                                             dim=0),
                                y=test_output['y'][0, :, :])

            train_evaluator.state.output = None
            evaluator.state.output = None

            if reduce_on_plateau is not None:
                reduce_on_plateau.step(evaluator.state.metrics['mIoU'])

    trainer.add_event_handler(Events.ITERATION_STARTED,
                              trigger_k_iteration_started,
                              k=10)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              trigger_k_iteration_completed,
                              k=10)

    trainer.add_event_handler(Events.EPOCH_STARTED,
                              run_validation,
                              validation_interval=val_interval)
    trainer.add_event_handler(Events.COMPLETED,
                              run_validation,
                              validation_interval=1)

    def trainer_prediction_save(engine, prediction_interval):
        if (engine.state.iteration - 1) % prediction_interval == 0:

            if save_prediction_dir:
                trainer_output = trainer.state.output['unsup pred']

                iteration = str(trainer.state.iteration)
                epoch = str(trainer.state.epoch)

                save_prediction('trainer_{}_{}'.format(iteration, epoch),
                                save_prediction_dir, trainer_output['x'],
                                trainer_output['y_pred'])

                logger.debug(
                    'Saved trainer prediction for iteration {}'.format(
                        str(engine.state.iteration)))

            trainer.state.output = None

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              trainer_prediction_save,
                              prediction_interval=pred_interval)

    tb_logger.attach(train_evaluator,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=list(
                                                     metrics.keys())),
                     event_name=Events.EPOCH_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=tbOutputHandler(tag="test",
                                                 metric_names=list(
                                                     metrics.keys())),
                     event_name=Events.EPOCH_COMPLETED)

    # Handlers for Polyaxon logging
    if plx_logger is not None:
        plx_logger.attach(train_evaluator,
                          log_handler=plxOutputHandler(tag="train", metric_names=list(metrics.keys())),
                          event_name=Events.EPOCH_COMPLETED)

        plx_logger.attach(evaluator,
                          log_handler=plxOutputHandler(tag="test", metric_names=list(metrics.keys())),
                          event_name=Events.EPOCH_COMPLETED)

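    # Handlers for MLflow logging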
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mlflow_batch_metrics_logging, "train", trainer)
    train_evaluator.add_event_handler(Events.COMPLETED,
                                      mlflow_val_metrics_logging, "train",
                                      trainer)
    evaluator.add_event_handler(Events.COMPLETED, mlflow_val_metrics_logging,
                                "test", trainer)

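    # drive the engine with batch indices; the update function is assumed to fetch the actual batches itself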
    data_steps = list(range(len(train1_sup_loader)))

    logger.debug('Start training')
    trainer.run(data_steps, max_epochs=num_epochs)
    logger.debug('Finished training')
Example no. 23
0
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # apply ToTensord only to "image", to test inverting both Tensor and NumPy inputs
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num_workers = 0 for macOS or GPU transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

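        # create the engine and register MONAI's IterationEvents so the
        # update function can fire MODEL_COMPLETED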
        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            postfix="inverted1",
            to_tensor=[True, False],
            device="cpu",
            num_workers=num_workers,
        ).attach(engine)

        # test different nearest interpolation values
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="image",
            nearest_interp=[True, False],
            post_func=[lambda x: x + 10, lambda x: x],
            postfix="inverted2",
            num_workers=num_workers,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
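        # restore non-deterministic behavior for subsequent tests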
        set_determinism(seed=None)
        self.assertTupleEqual(engine.state.output["image"].shape,
                              (2, 1, 100, 100, 100))
        self.assertTupleEqual(engine.state.output["label"].shape,
                              (2, 1, 100, 100, 100))
        # check the nearest interpolation mode
        for i in engine.state.output["image_inverted1"]:
            torch.testing.assert_allclose(
                i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        for i in engine.state.output["label_inverted1"]:
            np.testing.assert_allclose(
                i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        # check labels match
        reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = engine.state.output["label_meta_dict"][
            "filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1824: torch 1.5.1
        self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824),
                        "diff should be one of 3 expected values")

        # check the case where different keys use different interpolation modes to invert transforms
        for i in engine.state.output["image_inverted2"]:
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(
                torch.sum(
                    i.to(torch.float) -
                    i.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        for i in engine.state.output["label_inverted2"]:
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(
                torch.sum(
                    i.to(torch.float) -
                    i.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
Example no. 24
0
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(KEYS),
            CastToTyped(KEYS, dtype=torch.uint8),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num_workers = 0 for macOS or GPU transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            num_workers=num_workers,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)
        self.assertTupleEqual(engine.state.output["image"].shape,
                              (2, 1, 100, 100, 100))
        self.assertTupleEqual(engine.state.output["label"].shape,
                              (2, 1, 100, 100, 100))
        for i in (engine.state.output["image_inverted"] +
                  engine.state.output["label_inverted"]):
            torch.testing.assert_allclose(
                i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        # check labels match
        reverted = (engine.state.output["label_inverted"][-1]
                    .detach().cpu().numpy()[0].astype(np.int32))
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = engine.state.output["label_meta_dict"][
            "filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        self.assertTrue((reverted.size - n_good) in (25300, 1812),
                        "diff should be one of 2 expected values")