Example #1
def test_args_validation():

    trainer = Engine(do_nothing_update_fn)

    with pytest.raises(ValueError,
                       match=r"Argument patience should be positive integer."):
        EarlyStopping(patience=-1,
                      score_function=lambda engine: 0,
                      trainer=trainer)

    with pytest.raises(
            ValueError,
            match=r"Argument min_delta should not be a negative number."):
        EarlyStopping(patience=2,
                      min_delta=-0.1,
                      score_function=lambda engine: 0,
                      trainer=trainer)

    with pytest.raises(TypeError,
                       match=r"Argument score_function should be a function."):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    with pytest.raises(
            TypeError,
            match=r"Argument trainer should be an instance of Engine."):
        EarlyStopping(patience=2,
                      score_function=lambda engine: 0,
                      trainer=None)
Example #2
def test_with_engine_early_stopping():
    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    scores = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9])

    def score_function(engine):
        return next(scores)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    evaluator = Engine(update_fn)
    early_stopping = EarlyStopping(patience=3,
                                   score_function=score_function,
                                   trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert n_epochs_counter.count == 7
Example #3
    def create_callbacks(self):

        ## SETUP CALLBACKS
        print('[INFO] Creating callback functions for training loop...',
              end='')
        # Early Stopping - stops training if the validation loss does not improve for EARLY_STOPPING_PATIENCE epochs
        handler = EarlyStopping(patience=self.config.EARLY_STOPPING_PATIENCE,
                                score_function=score_function_loss,
                                trainer=self.train_engine)
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
        print('Early Stopping ({} epochs)...'.format(
            self.config.EARLY_STOPPING_PATIENCE),
              end='')

        val_checkpointer = Checkpoint(
            {"model": self.model},
            ClearMLSaver(),
            n_saved=1,
            score_function=score_function_acc,
            score_name="val_acc",
            filename_prefix='cub200_{}_ignite_best'.format(
                self.config.MODEL.MODEL_NAME),
            global_step_transform=global_step_from_engine(self.train_engine),
        )
        self.evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                         val_checkpointer)
        print('Model Checkpointing...', end='')
        print('Done')
Example #4
    def finalize(self, context):
        if context.local_rank == 0:
            publisher = PublishStatsAndModel(
                self._stats_path,
                self._publish_path,
                self._key_metric_filename,
                context.start_ts,
                context.run_id,
                context.output_dir,
                context.trainer,
                context.evaluator,
            )
            if context.evaluator:
                context.evaluator.add_event_handler(
                    event_name=Events.EPOCH_COMPLETED, handler=publisher)
            else:
                context.trainer.add_event_handler(
                    event_name=Events.EPOCH_COMPLETED, handler=publisher)

        early_stop_patience = int(context.request.get("early_stop_patience",
                                                      0))
        if early_stop_patience > 0 and context.evaluator:
            early_stopper = EarlyStopping(
                patience=early_stop_patience,
                score_function=stopping_fn_from_metric(self.VAL_KEY_METRIC),
                trainer=context.trainer,
            )
            context.evaluator.add_event_handler(
                event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
Example #5
    def _setup_early_stopping(self, trainer, val_evaluator, score_function):
        kwargs = dict(self.early_stopping_kwargs)
        if 'score_function' not in kwargs:
            kwargs['score_function'] = score_function
        handler = EarlyStopping(trainer=trainer, **kwargs)
        setup_logger(handler._logger, self.log_filepath, self.log_level)
        val_evaluator.add_event_handler(Events.COMPLETED, handler)
Example #6
def train(param, device):
    model = Model(param)
    state_dict = torch.load(CKPT)
    new_dict = model.state_dict().copy()
    for k, v in state_dict.items():
        if k.startswith('t_encoder'):
            new_dict[k] = state_dict[k]
    model.load_state_dict(new_dict)
    for parameter in model.t_encoder.parameters():
        parameter.requires_grad = False
    optimizer = AdamW(model.parameters(), lr=param.lr, eps=1e-8)
    update_steps = MAX_EPOCH * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=update_steps)
    loss_fn = L1Loss()
    trainer = create_trainer(model, optimizer, scheduler, loss_fn, MAX_GRAD_NORM, device)
    dev_evaluator = create_evaluator(model, val_metrics, device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=10), log_training_loss)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results, *[dev_evaluator, dev_loader, 'Dev'])
    es_handler = EarlyStopping(patience=PATIENCE, score_function=score_fn, trainer=trainer)
    dev_evaluator.add_event_handler(Events.COMPLETED, es_handler)
    ckpt_handler = ModelCheckpoint(SAVE_PATH, f'lr_{param.lr}', score_function=score_fn,
                                   score_name='score', require_empty=True)
    dev_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler, {SAVE_PATH.split("/")[-1]: model})
    print(f'Start running {SAVE_PATH.split("/")[-1]} at device: {DEVICE}\tlr: {param.lr}')
    trainer.run(train_loader, max_epochs=MAX_EPOCH)
Example #7
def test_with_engine_early_stopping_on_plateau():
    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    def score_function(engine):
        return 0.047

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=4,
                                   score_function=score_function,
                                   trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert n_epochs_counter.count == 5
    assert trainer.state.epoch == 5
Example #8
def test_args_validation():

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    # invalid patience, score_function, and trainer arguments
    with pytest.raises(AssertionError):
        h = EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer)

    with pytest.raises(AssertionError):
        h = EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    with pytest.raises(AssertionError):
        h = EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
Example #9
def test_with_engine_no_early_stopping():
    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    scores = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=5,
                                   score_function=score_function,
                                   trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert n_epochs_counter.count == 10
    assert trainer.state.epoch == 10
Example #10
def add_early_stopping_and_checkpoint(evaluator: Engine, trainer: Engine,
                                      checkpoint_filename: str,
                                      model: Module) -> None:
    """
    adds two event handlers to an ``ignite`` trainer/evaluator pair:

    * early stopping
    * best model checkpoint saver

    :param evaluator: an evaluator to add hooks to
    :param trainer: a trainer from which to make a checkpoint
    :param checkpoint_filename: some pretty name for a checkpoint
    :param model: a network which is saved in checkpoints
    """
    def score(engine):
        return -engine.state.metrics["loss"]

    early_stopping = EarlyStopping(100, score, trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    checkpoint = ModelCheckpoint("checkpoints",
                                 "",
                                 score_function=score,
                                 require_empty=False)
    evaluator.add_event_handler(Events.COMPLETED, checkpoint,
                                {checkpoint_filename: model})
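
A minimal wiring sketch for this helper; the model, criterion, and the `train_loader`/`val_loader` data loaders below are placeholders assumed for illustration:

import torch
from torch.nn import Linear, MSELoss
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss

model = Linear(10, 1)                                     # placeholder model
criterion = MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = create_supervised_trainer(model, optimizer, criterion)
# The evaluator must expose a "loss" metric, because the helper's score reads it
evaluator = create_supervised_evaluator(model, metrics={"loss": Loss(criterion)})

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_loader)                             # assumed validation DataLoader

add_early_stopping_and_checkpoint(evaluator, trainer, "best_model", model)
trainer.run(train_loader, max_epochs=50)                  # assumed training DataLoader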
Example #11
def _test_distrib_with_engine_early_stopping(device):

    import torch.distributed as dist

    torch.manual_seed(12)

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    scores = torch.tensor([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9], requires_grad=False).to(device)

    def score_function(engine):
        i = trainer.state.epoch - 1
        v = scores[i]
        dist.all_reduce(v)
        v /= dist.get_world_size()
        return v.item()

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert trainer.state.epoch == 7
    assert n_epochs_counter.count == 7
Example #12
def train(epochs, model, train_loader, valid_loader, criterion, optimizer,
          writer, device, log_interval):
    # Note that device is passed as a str here
    # Note that no DataLoader is passed to the trainer/evaluator at this point
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'nll': Loss(criterion)
                                            },
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        i = (engine.state.iteration - 1) % len(train_loader) + 1
        if i % log_interval == 0:
            print(
                f"Epoch[{engine.state.epoch}] Iteration[{i}/{len(train_loader)}] "
                f"Loss: {engine.state.output:.2f}")
            # engine.state.output is the batch loss returned by criterion(model(x), y)
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        write_metrics(metrics, writer, 'training', engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(valid_loader)
        metrics = evaluator.state.metrics
        write_metrics(metrics, writer, 'validation', engine.state.epoch)

    # # Checkpoint setting
    # ./checkpoints/sample_mymodel_{step_number}
    # keep up to n_saved checkpoint files
    handler = ModelCheckpoint(dirname='./checkpoints',
                              filename_prefix='sample',
                              save_interval=2,
                              n_saved=3,
                              create_dir=True)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler,
                              {'mymodel': model})

    # # Early stopping
    handler = EarlyStopping(patience=5,
                            score_function=score_function,
                            trainer=trainer)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
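
The `score_function` referenced above is not defined in this snippet; since EarlyStopping treats a higher score as better, a typical definition (an assumption here, not part of the original code) returns the negated validation loss from the evaluator:

def score_function(engine):
    # EarlyStopping stops when the score stops improving (higher is better),
    # so negate the 'nll' loss computed by the evaluator
    val_loss = engine.state.metrics['nll']
    return -val_loss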
Example #13
def test_state_dict():

    scores = iter([1.0, 0.8, 0.88])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)

    h = EarlyStopping(patience=2,
                      score_function=score_function,
                      trainer=trainer)
    # Call the handler 3 times in total (once on h, twice on h2) and check if stopped
    assert not trainer.should_terminate
    h(None)
    assert not trainer.should_terminate

    # Swap to new object, but maintain state
    h2 = EarlyStopping(patience=2,
                       score_function=score_function,
                       trainer=trainer)
    h2.load_state_dict(h.state_dict())

    h2(None)
    assert not trainer.should_terminate
    h2(None)
    assert trainer.should_terminate
Example #14
def add_early_stopping(trainer, val_evaluator, configuration):
    # Setup early stopping:
    handler = EarlyStopping(
        patience=configuration.early_stop_patience,
        score_function=_score_function,
        trainer=trainer,
    )
    setup_logger(handler._logger, configuration.log_dir, configuration.log_level)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)
Example #15
def register_early_stopping(evaluator_test, trainer, args):
    def score_function(engine):
        # EarlyStopping treats a higher score as better, so negate the loss
        val_loss = engine.state.metrics['bce']
        return -val_loss

    early_stopping_handler = EarlyStopping(patience=args.patience,
                                           score_function=score_function,
                                           trainer=trainer)
    evaluator_test.add_event_handler(Events.COMPLETED, early_stopping_handler)
Example #16
    def _build_objects(acc_list):

        model = DummyModel().to(device)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)

        def update_fn(engine, batch):
            x = torch.rand((4, 1)).to(device)
            optim.zero_grad()
            y = model(x)
            loss = y.pow(2.0).sum()
            loss.backward()
            if idist.has_xla_support:
                import torch_xla.core.xla_model as xm

                xm.optimizer_step(optim, barrier=True)
            else:
                optim.step()
            lr_scheduler.step()

        trainer = Engine(update_fn)

        evaluator = Engine(lambda e, b: None)
        acc_iter = iter(acc_list)

        @evaluator.on(Events.EPOCH_COMPLETED)
        def setup_result():
            evaluator.state.metrics["accuracy"] = next(acc_iter)

        @trainer.on(Events.EPOCH_COMPLETED)
        def run_eval():
            evaluator.run([0, 1, 2])

        def score_function(engine):
            return engine.state.metrics["accuracy"]

        save_handler = DiskSaver(dirname, create_dir=True, require_empty=False)
        early_stop = EarlyStopping(score_function=score_function,
                                   patience=2,
                                   trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, early_stop)

        checkpointer = Checkpoint(
            {
                "trainer": trainer,
                "model": model,
                "optim": optim,
                "lr_scheduler": lr_scheduler,
                "early_stop": early_stop,
            },
            save_handler,
            include_self=True,
            global_step_transform=global_step_from_engine(trainer),
        )
        evaluator.add_event_handler(Events.COMPLETED, checkpointer)

        return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer
Example #17
    def make_early_stopper(self, trainer):
        if self.early_stop_metric == 'loss':
            key_name = 'val_loss'
            c = -1
        else:
            c = 1
            key_name = 'val_accuracy'
        return EarlyStopping(self.early_stop_patience, lambda e: c * e.state.metrics[key_name],
                             trainer, min_delta=self.early_stop_delta)
Example #18
def test_args_validation():
    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    with pytest.raises(ValueError):
        h = EarlyStopping(patience=-1,
                          score_function=lambda engine: 0,
                          trainer=trainer)

    with pytest.raises(TypeError):
        h = EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    with pytest.raises(TypeError):
        h = EarlyStopping(patience=2,
                          score_function=lambda engine: 0,
                          trainer=None)
Example #19
def _test_distrib_integration_engine_early_stopping(device):

    from ignite.metrics import Accuracy

    if device is None:
        device = idist.device()
    if isinstance(device, str):
        device = torch.device(device)
    metric_device = device
    if device.type == "xla":
        metric_device = "cpu"

    rank = idist.get_rank()
    ws = idist.get_world_size()
    torch.manual_seed(12)

    n_epochs = 10
    n_iters = 20

    y_preds = (
        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        + [torch.ones(n_iters, ws).to(device)]
        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
    )

    y_true = (
        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        + [torch.ones(n_iters, ws).to(device)]
        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
    )

    def update(engine, _):
        e = trainer.state.epoch - 1
        i = engine.state.iteration - 1
        return y_preds[e][i, rank], y_true[e][i, rank]

    evaluator = Engine(update)
    acc = Accuracy(device=metric_device)
    acc.attach(evaluator, "acc")

    def score_function(engine):
        return engine.state.metrics["acc"]

    trainer = Engine(lambda e, b: None)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        data = list(range(n_iters))
        evaluator.run(data=data)

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert trainer.state.epoch == 5
Example #20
File: common.py Project: zivzone/ignite
def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
    """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.

    Args:
        patience (int): number of events to wait if no improvement and then stop the training.
        evaluator (Engine): evaluation engine used to provide the score
        trainer (Engine): trainer engine to stop the run if no improvement.
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.

    """
    es_handler = EarlyStopping(patience=patience, score_function=get_default_score_fn(metric_name), trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, es_handler)
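
A minimal call site for this helper, assuming `trainer` and `evaluator` are already-built Engines and the evaluator attaches an "accuracy" metric to its state:

# Stop training if validation accuracy does not improve for 5 evaluation runs
add_early_stopping_by_val_score(patience=5,
                                evaluator=evaluator,
                                trainer=trainer,
                                metric_name="accuracy")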
Example #21
def test_simple_early_stopping_on_plateau():
    def score_function(engine):
        return 42

    trainer = Engine(do_nothing_update_fn)

    h = EarlyStopping(patience=1, score_function=score_function, trainer=trainer)
    # Call 2 times and check if stopped
    assert not trainer.should_terminate
    h(None)
    assert not trainer.should_terminate
    h(None)
    assert trainer.should_terminate
Example #22
def test_args_validation():
    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    with pytest.raises(ValueError,
                       match=r"Argument patience should be positive integer."):
        EarlyStopping(patience=-1,
                      score_function=lambda engine: 0,
                      trainer=trainer)

    with pytest.raises(TypeError,
                       match=r"Argument score_function should be a function."):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    with pytest.raises(
            TypeError,
            match=r"Argument trainer should be an instance of Engine."):
        EarlyStopping(patience=2,
                      score_function=lambda engine: 0,
                      trainer=None)
Example #23
    def _early_stopping_handler(self):
        """Create the EarlyStopping handler that will evaluate the `score_function` class on each `evaluator_engine` run
        and stop the `trainer_engine` if there has been no improvement in the `_score_function` for the number of
        epochs specified in `early_stopping_patience`.

        Args:

        Returns:
          the early stopping handler

        """
        return EarlyStopping(
            patience=self.early_stopping_patience, score_function=self._score_function, trainer=self.trainer_engine
        )
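
This method only constructs the handler; the caller still has to attach it to the evaluation engine. A short sketch of that attachment (the `evaluator_engine` attribute name mirrors the docstring and is assumed here):

# Hypothetical attachment code for the handler returned above
handler = self._early_stopping_handler()
self.evaluator_engine.add_event_handler(Events.COMPLETED, handler)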
Example #24
    def init_function(h_model):
        h_criterion = torch.nn.CrossEntropyLoss()
        h_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
        h_train_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
        h_optimizer = torch.optim.Adam(params=h_model.parameters(), lr=1e-3)
        h_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(h_optimizer, 'max', verbose=True, patience=5,
                                                                    factor=0.5)
        h_trainer = SupervisedTrainer(model=h_model, optimizer=h_optimizer, criterion=h_criterion, device=device)

        # Tqdm logger
        h_pbar = ProgressBar(persist=False, bar_format=config.IGNITE_BAR_FORMAT)
        h_pbar.attach(h_trainer.engine, metric_names='all')
        h_tqdm_logger = TqdmLogger(pbar=h_pbar)
        # noinspection PyTypeChecker
        h_tqdm_logger.attach_output_handler(
            h_evaluator.engine,
            event_name=Events.COMPLETED,
            tag="validation",
            global_step_transform=global_step_from_engine(h_trainer.engine),
        )
        # noinspection PyTypeChecker
        h_tqdm_logger.attach_output_handler(
            h_train_evaluator.engine,
            event_name=Events.COMPLETED,
            tag="train",
            global_step_transform=global_step_from_engine(h_trainer.engine),
        )

        # Learning rate scheduling
        # The PyTorch Ignite LRScheduler class does not work with ReduceLROnPlateau
        h_evaluator.engine.add_event_handler(Events.COMPLETED,
                                             lambda engine: h_lr_scheduler.step(engine.state.metrics['accuracy']))

        # Model checkpoints
        h_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                    require_empty=False, score_name='acc',
                                    score_function=lambda engine: engine.state.metrics['accuracy'],
                                    global_step_transform=global_step_from_engine(h_trainer.engine))
        h_evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, h_handler, {'m': h_model})

        # Early stopping
        h_es_handler = EarlyStopping(patience=15,
                                     min_delta=0.0001,
                                     score_function=lambda engine: engine.state.metrics['accuracy'],
                                     trainer=h_trainer.engine, cumulative_delta=True)
        h_es_handler.logger.setLevel(logging.DEBUG)
        h_evaluator.engine.add_event_handler(Events.COMPLETED, h_es_handler)

        return h_trainer, h_train_evaluator, h_evaluator
Example #25
    def _register_early_stopping(self, loss_fn, device, trainer):
        prepare_batch = __class__._prepare_batch
        evaluator = create_supervised_evaluator(model=self.model,
                                                metrics={'nll': Loss(loss_fn)},
                                                device=device,
                                                prepare_batch=prepare_batch)

        def score_fn(engine):
            return -engine.state.metrics['nll']

        early_stopping = EarlyStopping(patience=5,
                                       score_function=score_fn,
                                       trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, early_stopping)
        return evaluator
Example #26
def test_simple_no_early_stopping():

    scores = iter([1.0, 0.8, 1.2])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)

    h = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)
    # Call 3 times and check if not stopped
    assert not trainer.should_terminate
    h(None)
    h(None)
    h(None)
    assert not trainer.should_terminate
Example #27
def test_early_stopping_on_last_event_delta():

    scores = iter([0.0, 0.3, 0.6])

    trainer = Engine(do_nothing_update_fn)

    h = EarlyStopping(
        patience=2, min_delta=0.4, cumulative_delta=False, score_function=lambda _: next(scores), trainer=trainer
    )

    assert not trainer.should_terminate
    h(None)  # counter == 0
    assert not trainer.should_terminate
    h(None)  # delta == 0.3; counter == 1
    assert not trainer.should_terminate
    h(None)  # delta == 0.3; counter == 2
    assert trainer.should_terminate
Example #28
    def _finetune(self, train_dl, val_dl, criterion, iter_num):
        print("Recovery")
        self.model.to_rank = False
        finetune_epochs = config["pruning"]["finetune_epochs"].get()

        optimizer_constructor = optimizer_constructor_from_config(config)
        optimizer = optimizer_constructor(self.model.parameters())

        finetune_engine = create_supervised_trainer(self.model, optimizer, criterion, self.device)
        # progress bar
        pbar = Progbar(train_dl, metrics='none')
        finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, pbar)

        # log training loss
        if self.writer:
            finetune_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                              lambda engine: log_training_loss(engine, self.writer))

        # terminate on Nan
        finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        # model checkpoints
        checkpoint = ModelCheckpoint(config["pruning"]["out_path"].get(), require_empty=False,
                                     filename_prefix=f"pruning_iteration_{iter_num}", save_interval=1)
        finetune_engine.add_event_handler(Events.COMPLETED, checkpoint, {"weights": self.model.cpu()})

        # add early stopping
        validation_evaluator = create_supervised_evaluator(self.model, device=self.device,
                                                           metrics=self._metrics)

        if config["pruning"]["early_stopping"].get():
            def _score_function(evaluator):
                return -evaluator.state.metrics["loss"]
            early_stop = EarlyStopping(config["pruning"]["patience"].get(), _score_function, finetune_engine)
            validation_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stop)

        finetune_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine:
                                          run_evaluator(engine, validation_evaluator, val_dl))

        for handler_dict in self._finetune_handlers:
            finetune_engine.add_event_handler(handler_dict["event_name"], handler_dict["handler"],
                                              *handler_dict["args"], **handler_dict["kwargs"])

        # run training engine
        finetune_engine.run(train_dl, max_epochs=finetune_epochs)
Example #29
def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
    """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
    Metric value should increase in order to keep training and not early stop.

    Args:
        patience (int): number of events to wait if no improvement and then stop the training.
        evaluator (Engine): evaluation engine used to provide the score
        trainer (Engine): trainer engine to stop the run if no improvement.
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.

    Returns:
        A :class:`~ignite.handlers.EarlyStopping` handler.
    """
    es_handler = EarlyStopping(patience=patience, score_function=get_default_score_fn(metric_name), trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, es_handler)

    return es_handler
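
Because this variant returns the handler, the caller can keep a reference to it, for instance to save and restore its counters through a Checkpoint as in Example #16; the `trainer`, `evaluator`, and `save_handler` names below are placeholders:

# EarlyStopping implements state_dict()/load_state_dict(), so it can be checkpointed
es_handler = add_early_stopping_by_val_score(patience=5, evaluator=evaluator,
                                             trainer=trainer, metric_name="accuracy")
checkpointer = Checkpoint({"trainer": trainer, "early_stop": es_handler}, save_handler)
evaluator.add_event_handler(Events.COMPLETED, checkpointer)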
Example #30
def assign_event_handlers(trainer, evaluator, val_set):
    pbar = ProgressBar()
    pbar.attach(trainer, ['loss'])

    early_stop = EarlyStopping(patience=2, score_function=lambda e: -e.state.metrics['loss'], trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stop)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        print("\nTraining Results - Epoch: {} : Avg loss: {:.3f}"
              .format(trainer.state.epoch, trainer.state.metrics['avg_loss']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_set)
        metrics_eval = evaluator.state.metrics
        print("Validation Results - Epoch: {} Avg loss: {:.3f}, Avg abs. error: {:.2f}"
              .format(trainer.state.epoch, metrics_eval['loss'], metrics_eval['mae']))