Code Example #1
def setup_trainer(model, optimizer, device, data_parallel: bool) -> Engine:
    def update(trainer, batch: Tuple[torch.Tensor]):
        model.train()
        optimizer.zero_grad()
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 1
            batch = batch[0]
        else:
            assert isinstance(batch, torch.Tensor)
        batch = batch.to(device)
        if not data_parallel:
            masked_batch = mask_for_forward(
                batch)  # replace -1 with some other token
            lm_logits = model(masked_batch)[0]
            loss = calculate_lm_loss(lm_logits, batch)
            loss.backward()
        else:
            # handling of -1 as padding is not implemented
            losses = model(batch, lm_labels=batch)
            losses.backward(torch.ones_like(losses))
            loss = losses.mean()
        optimizer.step()
        return loss.item()

    trainer = Engine(update)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    return trainer
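
A minimal, self-contained sketch (not taken from any of the projects above) of the pattern shared by every example in this listing: attach TerminateOnNan at ITERATION_COMPLETED so the run stops as soon as the update function returns a NaN or Inf output. The data values are illustrative.

from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan


def update_fn(engine, batch):
    # TerminateOnNan inspects engine.state.output, i.e. this return value
    return batch


trainer = Engine(update_fn)
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

trainer.run([1.0, 0.5, float("nan"), 2.0], max_epochs=1)
assert trainer.state.iteration == 3  # the run stopped right after the NaN batch
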
Code Example #2
    def create_engine():
        engine = Engine(update)
        pbar = ProgressBar()

        engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        pbar.attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
        return engine
Code Example #3
    def _setup_trainer_handlers(self, trainer):
        # Setup timer to measure training time
        timer = setup_timer(trainer)
        self._setup_log_training_loss(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_time(engine):
            self.logger.info("One epoch training time (seconds): {}".format(
                timer.value()))

        last_model_saver = ModelCheckpoint(
            self.log_dir.as_posix(),
            filename_prefix="checkpoint",
            save_interval=self.trainer_checkpoint_interval,
            n_saved=1,
            atomic=True,
            create_dir=True,
            save_as_state_dict=True)

        model_name = get_object_name(self.model)

        to_save = {
            model_name: self.model,
            "optimizer": self.optimizer,
        }

        if self.lr_scheduler is not None:
            to_save['lr_scheduler'] = self.lr_scheduler

        trainer.add_event_handler(Events.ITERATION_COMPLETED, last_model_saver,
                                  to_save)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
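
Code Example #3 above uses the older ModelCheckpoint keyword arguments save_interval and save_as_state_dict, which newer Ignite releases (the 0.4-style API) no longer accept; there, the saving period is expressed with an event filter instead, as Code Example #8 below does. A hedged sketch with illustrative names:

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, TerminateOnNan

model = torch.nn.Linear(2, 1)                            # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # illustrative optimizer
trainer = Engine(lambda engine, batch: 0.0)

checkpoint_handler = ModelCheckpoint(
    dirname="checkpoints",         # illustrative output directory
    filename_prefix="checkpoint",
    n_saved=1,
    create_dir=True,
    require_empty=False,
)
to_save = {"model": model, "optimizer": optimizer}  # anything exposing state_dict()

# the event filter takes over the role of save_interval
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), checkpoint_handler, to_save)
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
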
Code Example #4
def test_terminate_on_nan_and_inf(state_output, should_terminate):

    torch.manual_seed(12)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    trainer.state = State()
    h = TerminateOnNan()

    trainer.state.output = state_output
    if isinstance(state_output, np.ndarray):
        h._output_transform = lambda x: x.tolist()
    h(trainer)
    assert trainer.should_terminate == should_terminate
Code Example #5
def test_with_terminate_on_inf():

    torch.manual_seed(12)

    data = [
        1.0,
        0.8,
        torch.rand(4, 4),
        (1.0 / torch.randint(0, 2, size=(4, )).type(torch.float),
         torch.tensor(1.234)),
        torch.rand(5),
        torch.asin(torch.randn(4, 4)),
        0.0,
        1.0,
    ]

    def update_fn(engine, batch):
        return batch

    trainer = Engine(update_fn)
    h = TerminateOnNan()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, h)

    trainer.run(data, max_epochs=2)
    assert trainer.state.iteration == 4
Code Example #6
    def _add_event_handlers(self):
        """
        Adds a progressbar and a summary writer to output the current training status. Adds event handlers to output
        common messages and update the progressbar.
        """
        progressbar_description = 'TRAINING => loss: {:.6f}'
        progressbar = tqdm(initial=0,
                           leave=False,
                           total=len(self.train_loader),
                           desc=progressbar_description.format(0))
        writer = SummaryWriter(self.log_directory)

        @self.trainer_engine.on(Events.ITERATION_COMPLETED)
        def log_training_loss(trainer):
            writer.add_scalar('loss', trainer.state.output)
            progressbar.desc = progressbar_description.format(
                trainer.state.output)
            progressbar.update(1)

        @self.trainer_engine.on(Events.EPOCH_COMPLETED)
        def log_training_results(trainer):
            progressbar.n = progressbar.last_print_n = 0
            self.evaluator.run(self.train_loader)
            metrics = self.evaluator.state.metrics
            for key, value in metrics.items():
                writer.add_scalar(key, value)
            tqdm.write(
                '\nTraining Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}\n'
                .format(trainer.state.epoch, metrics['accuracy'],
                        metrics['loss']))

        @self.trainer_engine.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            progressbar.n = progressbar.last_print_n = 0
            self.evaluator.run(self.val_loader)
            metrics = self.evaluator.state.metrics
            for key, value in metrics.items():
                writer.add_scalar(key, value)
            tqdm.write(
                '\nValidation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}\n'
                .format(trainer.state.epoch, metrics['accuracy'],
                        metrics['loss']))

        checkpoint_saver = ModelCheckpoint(  # create a Checkpoint handler that can be used to periodically
            self.checkpoint_directory,
            filename_prefix='net',  # save model objects to disc.
            save_interval=1,
            n_saved=5,
            atomic=True,
            create_dir=True,
            save_as_state_dict=False,
            require_empty=False)
        self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                              checkpoint_saver,
                                              {'train': self.model})
        self.trainer_engine.add_event_handler(Events.COMPLETED,
                                              checkpoint_saver,
                                              {'complete': self.model})
        self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                              TerminateOnNan())
Code Example #7
def test_without_terminate_on_nan_inf():

    data = [1.0, 0.8, torch.rand(4, 4), (torch.rand(5), torch.rand(5, 4)), 0.0, 1.0]

    def update_fn(engine, batch):
        return batch

    trainer = Engine(update_fn)
    h = TerminateOnNan()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, h)

    trainer.run(data, max_epochs=2)
    assert trainer.state.iteration == len(data) * 2
Code Example #8
File: common.py Project: zivzone/ignite
def _setup_common_training_handlers(trainer,
                                    to_save=None, save_every_iters=1000, output_path=None,
                                    lr_scheduler=None, with_gpu_stats=True,
                                    output_names=None, with_pbars=True, with_pbar_on_iters=True,
                                    log_every_iters=100, device='cuda'):
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None:
            raise ValueError("If to_save argument is provided then output_path argument should be also defined")
        checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="training")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler, to_save)

    if with_gpu_stats:
        GpuInfo().attach(trainer, name='gpu', event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, torch.Tensor):
                return x
            else:
                raise ValueError("Unhandled type of update_function's output. "
                                 "It should either mapping or sequence, but given {}".format(type(x)))

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n),
                           epoch_bound=False, device=device).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(trainer, metric_names='all',
                                              event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True, bar_format="").attach(trainer,
                                                        event_name=Events.EPOCH_STARTED,
                                                        closing_event_name=Events.COMPLETED)
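
Several of the trainers further down return a dict of losses from the update function. TerminateOnNan accepts an output_transform callable (Code Example #18 passes one explicitly) that selects which part of engine.state.output to inspect; a minimal sketch with illustrative names:

from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan


def update(engine, batch):
    return {"loss": batch, "aux": 0.0}


trainer = Engine(update)
trainer.add_event_handler(
    Events.ITERATION_COMPLETED,
    TerminateOnNan(output_transform=lambda out: out["loss"]),
)

trainer.run([0.3, float("inf"), 0.1], max_epochs=1)
assert trainer.state.iteration == 2  # stopped on the Inf loss
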
Code Example #9
def test_with_terminate_on_nan():

    torch.manual_seed(12)

    data = [1.0, 0.8,
            (torch.rand(4, 4), torch.rand(4, 4)),
            torch.rand(5), torch.asin(torch.randn(4, 4)), 0.0, 1.0]

    def update_fn(engine, batch):
        return batch

    trainer = Engine(update_fn)
    h = TerminateOnNan()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, h)

    trainer.run(data, max_epochs=2)
    assert trainer.state.iteration == 5
Code Example #10
	def __init__(self, name, model, log_dir, lr, lr_decay_step, adam=False):
		"""
		Initialize to train the given model.
		:param name: The name of the model to be trained.
		:param model: The model to be trained.
		:param log_dir: String. The log directory of the tensorboard.
		:param lr: Float. The learning rate.
		:param lr_decay_step: Integer. The number of steps after which the learning rate decays.
		:param adam: Bool. Whether to use adam optimizer or not.
		"""
		super(Trainer, self).__init__(self.update_model)
		self.model = model
		# tqdm
		ProgressBar(persist=True).attach(self)
		# Optimizer
		params = [p for p in model.parameters() if p.requires_grad]
		if adam:
			self.optimizer = torch.optim.Adam(params, lr=lr)
		else:
			self.optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)
		# Scheduler
		if lr_decay_step > 0:
			self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=lr_decay_step, gamma=0.1)
			self.add_event_handler(Events.EPOCH_COMPLETED, lambda e: e.scheduler.step())
		else:
			self.scheduler = None
		# Terminate if nan values found
		self.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
		# Tensorboard logging
		self.tb_logger = TensorboardLogger(log_dir=os.path.join(log_dir, name))
		self.add_event_handler(Events.COMPLETED, lambda x: self.tb_logger.close())
		self.tb_logger.attach(self,
		                      log_handler=OptimizerParamsHandler(self.optimizer),
		                      event_name=Events.EPOCH_COMPLETED)
		self.tb_logger.attach(self,
		                      log_handler=OutputHandler(tag='training', output_transform=lambda x: {
			                      'rpn_box_loss': round(self.state.output['loss_rpn_box_reg'].item(), 4),
			                      'rpn_cls_loss': round(self.state.output['loss_objectness'].item(), 4),
			                      'roi_box_loss': round(self.state.output['loss_box_reg'].item(), 4),
			                      'roi_cls_loss': round(self.state.output['loss_classifier'].item(), 4)
		                      }),
		                      event_name=Events.EPOCH_COMPLETED)
		# Run on GPU (cuda) if available
		if torch.cuda.is_available():
			torch.cuda.set_device(int(get_free_gpu()))
			model.cuda(torch.cuda.current_device())
Code Example #11
    def _finetune(self, train_dl, val_dl, criterion, iter_num):
        print("Recovery")
        self.model.to_rank = False
        finetune_epochs = config["pruning"]["finetune_epochs"].get()

        optimizer_constructor = optimizer_constructor_from_config(config)
        optimizer = optimizer_constructor(self.model.parameters())

        finetune_engine = create_supervised_trainer(self.model, optimizer, criterion, self.device)
        # progress bar
        pbar = Progbar(train_dl, metrics='none')
        finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, pbar)

        # log training loss
        if self.writer:
            finetune_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                              lambda engine: log_training_loss(engine, self.writer))

        # terminate on Nan
        finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        # model checkpoints
        checkpoint = ModelCheckpoint(config["pruning"]["out_path"].get(), require_empty=False,
                                     filename_prefix=f"pruning_iteration_{iter_num}", save_interval=1)
        finetune_engine.add_event_handler(Events.COMPLETED, checkpoint, {"weights": self.model.cpu()})

        # add early stopping
        validation_evaluator = create_supervised_evaluator(self.model, device=self.device,
                                                           metrics=self._metrics)

        if config["pruning"]["early_stopping"].get():
            def _score_function(evaluator):
                return -evaluator.state.metrics["loss"]
            early_stop = EarlyStopping(config["pruning"]["patience"].get(), _score_function, finetune_engine)
            validation_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stop)

        finetune_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine:
                                          run_evaluator(engine, validation_evaluator, val_dl))

        for handler_dict in self._finetune_handlers:
            finetune_engine.add_event_handler(handler_dict["event_name"], handler_dict["handler"],
                                              *handler_dict["args"], **handler_dict["kwargs"])

        # run training engine
        finetune_engine.run(train_dl, max_epochs=finetune_epochs)
Code Example #12
    def _train(self, model, optimizer, train_loader, max_epochs, **kwargs):
        trainer = create_supervised_trainer(model, optimizer, self.criterion)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        val_metrics = {"smape": SymmetricMeanAbsolutePercentageError()}
        evaluator = create_supervised_evaluator(model, metrics=val_metrics)

        @trainer.on(Events.COMPLETED)
        def log_training_results(trainer):
            evaluator.run(train_loader)
            metrics = evaluator.state.metrics
            self._print_line(
                "Interval {} Training Results - Epoch: {} Avg smape: {:.2f}".
                format(kwargs.get("interval"), trainer.state.epoch,
                       metrics["smape"]))

        trainer.run(train_loader, max_epochs=max_epochs)
        return (model, optimizer)
Code Example #13
 def _add_event_handlers(self):
     """Add event handlers to output common messages and update the progressbar."""
     self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                           TerminateOnNan())
     self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                           self._event_log_training_output)
     self.trainer_engine.add_event_handler(
         Events.ITERATION_COMPLETED, self._event_update_progressbar_step)
     self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                           self._event_update_step_counter)
     self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           self._event_log_training_results)
     self.trainer_engine.add_event_handler(
         Events.EPOCH_COMPLETED, self._event_log_validation_results)
     self.trainer_engine.add_event_handler(
         Events.EPOCH_COMPLETED, self._event_save_trainer_checkpoint)
     self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           self._event_reset_progressbar)
     self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           self._event_update_epoch_counter)
     self.trainer_engine.add_event_handler(Events.COMPLETED,
                                           self._event_cleanup)
Code Example #14
def test_terminate_on_nan_and_inf():

    torch.manual_seed(12)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    trainer.state = State()
    h = TerminateOnNan()

    trainer.state.output = 1.0
    h(trainer)
    assert not trainer.should_terminate

    trainer.state.output = torch.tensor(123.45)
    h(trainer)
    assert not trainer.should_terminate

    trainer.state.output = torch.asin(torch.randn(10, ))
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = np.array([1.0, 2.0])
    h._output_transform = lambda x: x.tolist()
    h(trainer)
    assert not trainer.should_terminate
    h._output_transform = lambda x: x

    trainer.state.output = torch.asin(torch.randn(4, 4))
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = (10.0, 1.0 /
                            torch.randint(0, 2, size=(4, )).type(torch.float),
                            1.0)
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = (1.0, torch.tensor(1.0), "abc")
    h(trainer)
    assert not trainer.should_terminate

    trainer.state.output = 1.0 / torch.randint(0, 2, size=(4, 4)).type(
        torch.float)
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = (float("nan"), 10.0)
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = float("inf")
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = [float("nan"), 10.0]
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False
Code Example #15
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.general.gpu)
    device = torch.device('cuda')
    loader_args = dict()
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)

    train_dl = create_train_loader(conf.data, **loader_args)

    if epoch_length < 1:
        epoch_length = len(train_dl)

    metric_names = list(conf.logging.stats)
    metrics = create_metrics(metric_names, device if distributed else None)

    G = instantiate(conf.model.G).to(device)
    D = instantiate(conf.model.D).to(device)
    G_loss = instantiate(conf.loss.G).to(device)
    D_loss = instantiate(conf.loss.D).to(device)
    G_opt = instantiate(conf.optim.G, G.parameters())
    D_opt = instantiate(conf.optim.D, D.parameters())
    G_ema = None

    if master_node and conf.G_smoothing.enabled:
        G_ema = instantiate(conf.model.G)
        if not conf.G_smoothing.use_cpu:
            G_ema = G_ema.to(device)
        G_ema.load_state_dict(G.state_dict())
        G_ema.requires_grad_(False)

    to_save = {
        'G': G,
        'D': D,
        'G_loss': G_loss,
        'D_loss': D_loss,
        'G_opt': G_opt,
        'D_opt': D_opt,
        'G_ema': G_ema
    }

    if master_node and conf.logging.model:
        logging.info(G)
        logging.info(D)

    if distributed:
        ddp_kwargs = dict(device_ids=[
            local_rank,
        ], output_device=local_rank)
        G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs)
        D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs)

    train_options = {
        'train': dict(conf.train),
        'snapshot': dict(conf.snapshots),
        'smoothing': dict(conf.G_smoothing),
        'distributed': distributed
    }
    bs_dl = int(conf.data.loader.batch_size) * num_replicas
    bs_eff = conf.train.batch_size
    if bs_eff % bs_dl:
        raise AttributeError(
            "Effective batch size should be divisible by data-loader batch size "
            "multiplied by number of devices in use"
        )  # as long as there is no special batch size for the master node...
    upd_interval = max(bs_eff // bs_dl, 1)
    train_options['train']['update_interval'] = upd_interval
    if epoch_length < len(train_dl):
        # ideally epoch_length should be tied to the effective batch_size only
        # and the ignite trainer counts data-loader iterations
        epoch_length *= upd_interval

    train_loop, sample_images = create_train_closures(G,
                                                      D,
                                                      G_loss,
                                                      D_loss,
                                                      G_opt,
                                                      D_opt,
                                                      G_ema=G_ema,
                                                      device=device,
                                                      options=train_options)
    trainer = create_trainer(train_loop, metrics, device, num_replicas)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.checkpoints
    pbar = None

    if master_node:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
        trainer.add_event_handler(log_event, log_iter, pbar, log_freq)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
        pbar.attach(trainer, metric_names=metric_names)
        setup_checkpoints(trainer, to_save, epoch_length, conf)
        setup_snapshots(trainer, sample_images, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load,
                                                      map_location=device))

    try:
        trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())
    if pbar is not None:
        pbar.close()
Code Example #16
def main(parser_args):
    """Main function to create trainer engine, add handlers to train and validation engines.
    Then runs train engine to perform training and validation.

    Args:
        parser_args (dict): parsed arguments
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)
    criterion = nn.CrossEntropyLoss()

    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels,
                         parser_args.depth, parser_args.laplacian_type,
                         parser_args.kernel_size)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)

    def trainer(engine, batch):
        """Train Function to define train engine.
        Called for every batch of the train engine, for each epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader

        Returns:
            :obj:`torch.tensor` : train loss for that batch and epoch
        """
        unet.train()
        data, labels = batch
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)

        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]

        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    writer = SummaryWriter(parser_args.tensorboard_path)

    engine_train = Engine(trainer)

    engine_validate = create_supervised_evaluator(
        model=unet,
        metrics={"AP": EpochMetric(average_precision_compute_fn)},
        device=device,
        output_transform=validate_output_transform)

    engine_train.add_event_handler(
        Events.EPOCH_STARTED,
        lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
    engine_train.add_event_handler(Events.ITERATION_COMPLETED,
                                   TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler to run the validation engine at the end of the train engine's epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler to save the metrics at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics",
            {
                "mean average precision (AR+TC)": mean_average_precision,
                "AR average precision": ap[2],
                "TC average precision": ap[1]
            },
            engine_train.state.epoch,
        )
        writer.close()

    step_scheduler = StepLR(optimizer,
                            step_size=parser_args.steplr_step_size,
                            gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience,
        score_function=lambda x: -x.state.metrics["AP"][1],
        trainer=engine_train)
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train,
                    optimizer,
                    unet,
                    log_dir=parser_args.tensorboard_path)

    engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)

    torch.save(unet.state_dict(),
               parser_args.model_save_path + "unet_state.pt")
Code Example #17
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, metrics):

    device = cfg['device_ids'][0] if torch.cuda.is_available(
    ) else 'cpu'  # the master device defaults to device_ids[0]
    max_epochs = cfg['max_epochs']

    # create trainer
    if cfg['multi_gpu']:  # with multiple GPUs, there is no need to pass loss_fn
        trainer = create_supervised_dp_trainer(model.train(),
                                               optimizer,
                                               device=device)
    else:
        trainer = create_supervised_trainer(model.train(),
                                            optimizer,
                                            loss_fn,
                                            device=device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    # create pbar
    len_train_loader = len(train_loader)
    pbar = tqdm(total=len_train_loader)

    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    # print train_loss at the end of every log_period of iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_period = cfg['log_period']
        log_per_iter = int(log_period * len_train_loader) if int(
            log_period * len_train_loader) >= 1 else 1  # compute the logging period
        current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + (
            engine.state.epoch - 1) * len_train_loader  # compute the current iteration

        lr = optimizer.state_dict()['param_groups'][0]['lr']

        if current_iter % log_per_iter == 0:
            pbar.write("Epoch[{}] Iteration[{}] lr {:.7f} Loss {:.7f}".format(
                engine.state.epoch, current_iter, lr,
                engine.state.metrics['avg_loss']))
            pbar.update(log_per_iter)

    # lr_scheduler
    @trainer.on(Events.ITERATION_COMPLETED)
    def adjust_lr_scheduler(engine):
        if isinstance(scheduler, lr_scheduler.CyclicLR):
            scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_swa(engine):
        if isinstance(scheduler, lr_scheduler.CyclicLR):
            if cfg['enable_swa']:
                swa_period = 2 * cfg['lr_scheduler']['step_size_up']
                current_iter = (
                    engine.state.iteration - 1) % len_train_loader + 1 + (
                        engine.state.epoch - 1) * len_train_loader  # compute the current iteration
                if current_iter % swa_period == 0:
                    optimizer.update_swa()

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_bn(engine):
        if isinstance(scheduler, lr_scheduler.CyclicLR):
            save_period = 2 * cfg['lr_scheduler']['step_size_up']
            current_iter = (
                engine.state.iteration - 1) % len_train_loader + 1 + (
                    engine.state.epoch - 1) * len_train_loader  # compute the current iteration
            if current_iter % save_period == 0 and current_iter >= save_period * 2:  # start saving from the 4th cycle onward

                save_dir = cfg['save_dir']
                if not os.path.isdir(save_dir):
                    os.makedirs(save_dir)
                if cfg['enable_swa']:
                    optimizer.swap_swa_sgd()
                    optimizer.bn_update(train_loader, model, device=device)
                model_name = os.path.join(
                    save_dir, cfg['model']['type'] + '_' + cfg['tag'] + "_" +
                    str(current_iter) + ".pth")
                if cfg['multi_gpu']:
                    save_pth = {
                        'model': model.module.model.state_dict(),
                        'cfg': cfg
                    }
                    torch.save(save_pth, model_name)
                else:
                    save_pth = {'model': model.state_dict(), 'cfg': cfg}
                    torch.save(save_pth, model_name)

    ##########################################################################################
    ##################               Events.EPOCH_COMPLETED                    ###############
    ##########################################################################################
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_temp_epoch(engine):
        save_dir = cfg['save_dir']
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        epoch = engine.state.epoch
        if epoch % 1 == 0:
            model_name = os.path.join(
                save_dir,
                cfg['model']['type'] + '_' + cfg['tag'] + "_temp.pth")
            if cfg['multi_gpu']:
                save_pth = {
                    'model': model.module.model.state_dict(),
                    'cfg': cfg
                }
                torch.save(save_pth, model_name)
            else:
                save_pth = {'model': model.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_pbar(engine):
        pbar.reset()

    trainer.run(train_loader, max_epochs=max_epochs)
    pbar.close()
Code Example #18
def run(args):
    train_iter, valid_iter, test_iter, indexed_vector = load_dataset(args)
    # iters_per_epoch = len(train_iter) // 100 * 100  # round down to a multiple of 100; e.g. a train dataset of 7463 with batch size 16 gives 466.4 iterations per epoch
    iters_per_epoch = len(train_iter)

    model = LSTMClassifier(indexed_vector,
                           hidden_dim=args.nhid,
                           output_dim=args.nclass,
                           num_layers=args.nlayers,
                           dropout=args.dropout,
                           bidirectional=args.bi)

    optimizer = optim.Adam(
        filter(lambda param: param.requires_grad, model.parameters()))
    criterion = nn.CrossEntropyLoss()

    trainer = create_supervised_trainer(model=model,
                                        optimizer=optimizer,
                                        loss_fn=criterion,
                                        device=args.device)
    train_evaluator = create_supervised_evaluator(model=model,
                                                  metrics={
                                                      'accuracy': Accuracy(),
                                                      'precision': Precision(),
                                                      'recall': Recall(),
                                                      'loss': Loss(criterion)
                                                  },
                                                  device=args.device)
    valid_evaluator = create_supervised_evaluator(model=model,
                                                  metrics={
                                                      'accuracy': Accuracy(),
                                                      'precision': Precision(),
                                                      'recall': Recall(),
                                                      'loss': Loss(criterion)
                                                  },
                                                  device=args.device)

    def loss_score(engine):
        loss = engine.state.output
        return -loss  # higher score is better, so negate the loss

    def acc_score(engine):
        accuracy = engine.state.metrics['accuracy']
        return accuracy

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_loss(engine):
        train_iter_num = engine.state.iteration
        logger.info("Epoch {} Iteration {}: Loss {:.4f}"
                    "".format(engine.state.epoch, train_iter_num,
                              engine.state.output))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_validation_results(engine):
        train_iter_num = engine.state.iteration
        if train_iter_num > iters_per_epoch and train_iter_num % args.log_interval == 0:
            valid_evaluator.run(valid_iter)
            metrics = valid_evaluator.state.metrics
            logger.info(
                "Validation Results - Epoch {}, Iter {}: Avg accuracy {}, Precision {}, Recall {}, valid loss {:.4f}"
                "".format(engine.state.epoch, train_iter_num,
                          metrics['accuracy'], metrics['precision'].tolist(),
                          metrics['recall'].tolist(), metrics['loss']))

    # after every training ITERATION, check whether the loss is NaN;
    # if so, terminate training
    terminateonnan = TerminateOnNan(output_transform=lambda output: output)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, terminateonnan)

    checkpoint_handler = ModelCheckpoint(
        dirname='./data/saved_models/',
        filename_prefix='checkpoint',
        score_function=acc_score,
        score_name="acc",
        save_interval=None,  # save periodically by count
        n_saved=3,
        require_empty=False,  # force overwrite
        create_dir=True,
        save_as_state_dict=False)
    # since validation usually runs for a single epoch, Events.EPOCH_COMPLETED and Events.COMPLETED are equivalent here
    valid_evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler,
                                      {'model': model})
    patience = int(args.early_stop * (iters_per_epoch / args.log_interval))
    earlystop_handler = EarlyStopping(patience=patience,
                                      score_function=acc_score,
                                      trainer=trainer)
    earlystop_handler._logger = logger
    valid_evaluator.add_event_handler(Events.COMPLETED, earlystop_handler)

    trainer_bar = ProgressBar()
    trainer_bar.attach(trainer, output_transform=lambda x: {'loss': x})

    trainer.run(data=train_iter, max_epochs=args.epochs)
    logger.info("Best model: Epoch {}, Train iters {}, Valid iters {}, acc {}"
                "".format(earlystop_handler.best_state['epoch'], \
                          earlystop_handler.best_state['iters'], \
                          earlystop_handler.best_state['valid_iters'],
                          earlystop_handler.best_score))
    logger.info("Valid results in best model: {}".format(
        earlystop_handler.best_state['metrics']))
    logger.info("Best models info: {}".format(str(
        checkpoint_handler._saved)))  # [(0.65,['model_6_acc=0.65.pth']),...]
    best_models_info = {
        'model_args': str(args.__dict__),
        'checkpoint_saved': checkpoint_handler._saved,
        'train_epoch': earlystop_handler.best_state['epoch'],
        'train_iters': earlystop_handler.best_state['iters'],
        'valid_iters': earlystop_handler.best_state['valid_iters'],
        'best_model_path': checkpoint_handler._saved[-1][1][0],  # checkpoint_handler._saved is sorted by score in ascending order
        'best_score': checkpoint_handler._saved[-1][0],
        'score_function': checkpoint_handler._score_function.__name__,
        'valid_results': {
            'accuracy':
            earlystop_handler.best_state['metrics']['accuracy'],
            'precision':
            earlystop_handler.best_state['metrics']['precision'].tolist(),
            'recall':
            earlystop_handler.best_state['metrics']['recall'].tolist(),
            'loss':
            earlystop_handler.best_state['metrics']['loss']
        }
    }
    print(checkpoint_handler._saved)
    pprint(str(best_models_info))
    # exit()
    # with open('./data/pkl/best_models_path.pkl', 'wb') as f:
    #     pickle.dump(list(map(lambda model_info: model_info[1][0], checkpoint_handler._saved)), f)
    with open('./data/saved_models/best_models_info.json', 'w') as f:
        f.write(repr(best_models_info))

    def test(test_iter, args):
        # with open('./data/pkl/best_models_path.pkl', 'rb') as f:
        #     best_models = pickle.load(f)
        with open('./data/saved_models/best_models_info.json', 'r') as f:
            best_models_info = eval(f.read())
        print("best models info: {}".format(best_models_info))
        logger.info("best model path: {}".format(
            best_models_info['best_model_path']))
        model = torch.load(best_models_info['best_model_path'],
                           map_location=args.device)
        test_evaluator = create_supervised_evaluator(model=model,
                                                     metrics={
                                                         'accuracy':
                                                         Accuracy(),
                                                         'precision':
                                                         Precision(),
                                                         'recall': Recall(),
                                                         'loss':
                                                         Loss(criterion)
                                                     },
                                                     device=args.device)

        @test_evaluator.on(Events.COMPLETED)
        def log_test_results(engine):
            metrics = engine.state.metrics
            logger.info("Test Results: Avg accuracy: {}, Precision: {}, Recall: {}, Loss: {}"
                        "".format(
                                  metrics['accuracy'], \
                                  metrics['precision'].tolist(),
                                  metrics['recall'].tolist(),
                                  metrics['loss']
                            )
                        )

        test_evaluator.run(test_iter)

    test(valid_iter, args)
Code Example #19
    def run(self, train_loader, val_loader, test_loader):
        """Perform model training and evaluation on holdout dataset."""

        ## attach certain metrics to trainer ##
        CpuInfo().attach(self.trainer, "cpu_util")
        Loss(self.loss).attach(self.trainer, "loss")

        ###### configure evaluator settings ######
        def get_output_transform(target: str, collapse_y: bool = False):
            return lambda out: metric_output_transform(
                out, self.loss, target, collapse_y=collapse_y)

        graph_num_classes = len(self.graph_classes)
        node_num_classes = len(self.node_classes)
        node_num_classes = 2 if node_num_classes == 1 else node_num_classes

        node_output_transform = get_output_transform("node")
        node_output_transform_collapsed = get_output_transform("node",
                                                               collapse_y=True)
        graph_output_transform = get_output_transform("graph")
        graph_output_transform_collapsed = get_output_transform(
            "graph", collapse_y=True)

        # metrics we are interested in
        base_metrics: dict = {
            'loss':
            Loss(self.loss),
            "cpu_util":
            CpuInfo(),
            'node_accuracy_avg':
            Accuracy(output_transform=node_output_transform,
                     is_multilabel=False),
            'node_accuracy':
            LabelwiseAccuracy(output_transform=node_output_transform,
                              is_multilabel=False),
            "node_recall":
            Recall(output_transform=node_output_transform_collapsed,
                   is_multilabel=False,
                   average=False),
            "node_precision":
            Precision(output_transform=node_output_transform_collapsed,
                      is_multilabel=False,
                      average=False),
            "node_f1_score":
            Fbeta(1,
                  output_transform=node_output_transform_collapsed,
                  average=False),
            "node_c_matrix":
            ConfusionMatrix(node_num_classes,
                            output_transform=node_output_transform_collapsed,
                            average=None)
        }

        metrics = dict(**base_metrics)

        # settings for the evaluator
        evaluator_settings = {
            "device": self.device,
            "loss_fn": self.loss,
            "node_classes": self.node_classes,
            "graph_classes": self.graph_classes,
            "non_blocking": True,
            "metrics": OrderedDict(sorted(metrics.items(),
                                          key=lambda m: m[0])),
            "pred_collector_function": self._pred_collector_function
        }

        ## configure evaluators ##
        val_evaluator = None
        if len(val_loader):
            val_evaluator = create_supervised_evaluator(
                self.model, **evaluator_settings)
            # configure behavior for early stopping
            if self.stopper:
                val_evaluator.add_event_handler(Events.COMPLETED, self.stopper)
            # configure behavior for checkpoint saving
            val_evaluator.add_event_handler(Events.COMPLETED,
                                            self.best_checkpoint_handler)
            val_evaluator.add_event_handler(Events.COMPLETED,
                                            self.latest_checkpoint_handler)
        else:
            self.trainer.add_event_handler(Events.COMPLETED,
                                           self.latest_checkpoint_handler)

        test_evaluator = None
        if len(test_loader):
            test_evaluator = create_supervised_evaluator(
                self.model, **evaluator_settings)
        #############################

        @self.trainer.on(Events.STARTED)
        def log_training_start(trainer):
            self.custom_print("Start training...")

        @self.trainer.on(Events.EPOCH_COMPLETED)
        def compute_metrics(trainer):
            """Compute evaluation metric values after each epoch."""

            epoch = trainer.state.epoch

            self.custom_print(f"Finished epoch {epoch:03d}!")

            if len(val_loader):
                self.persist_collection = True
                val_evaluator.run(val_loader)
                self._save_collected_predictions(
                    prefix=f"validation_epoch{epoch:03}")
                # write metrics to file
                self.write_metrics(trainer, val_evaluator, suffix="validation")

        @self.trainer.on(Events.COMPLETED)
        def log_training_complete(trainer):
            """Trigger evaluation on test set if training is completed."""

            epoch = trainer.state.epoch
            suffix = "(Early Stopping)" if epoch < self.epochs else ""

            self.custom_print("Finished after {:03d} epochs! {}".format(
                epoch, suffix))

            # load best model for evaluation
            self.custom_print("Load best model for final evaluation...")
            last_checkpoint: str = self.best_checkpoint_handler.last_checkpoint or self.latest_checkpoint_handler.last_checkpoint
            best_checkpoint_path = os.path.join(self.save_path,
                                                last_checkpoint)
            checkpoint_path_dict: dict = {
                "latest_checkpoint_path":
                best_checkpoint_path  # we want to load states from the best checkpoint as "latest" configuration for testing
            }
            self.model, self.optimizer, self.trainer, _, _, _ = self._load_checkpoint(
                self.model,
                self.optimizer,
                self.trainer,
                None,
                None,
                None,
                checkpoint_path_dict=checkpoint_path_dict)

            if len(test_loader):
                self.persist_collection = True
                test_evaluator.run(test_loader)
                self._save_collected_predictions(prefix="test_final")
                # write metrics to file
                self.write_metrics(trainer, test_evaluator, suffix="test")

        # terminate training if Nan values are produced
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                       TerminateOnNan())

        # start the actual training
        self.custom_print(f"Train for a maximum of {self.epochs} epochs...")
        self.trainer.run(train_loader, max_epochs=self.epochs)
Code Example #20
File: trainer.py Project: phymucs/lgf
    def __init__(
            self,

            module,
            device,

            train_loss,
            train_loader,
            opt,
            lr_scheduler,
            max_epochs,
            max_grad_norm,

            test_metrics,
            test_loader,
            epochs_per_test,

            early_stopping,
            valid_loss,
            valid_loader,
            max_bad_valid_epochs,

            visualizer,

            writer,
            should_checkpoint_latest,
            should_checkpoint_best_valid
    ):
        self._module = module
        self._module.to(device)
        self._device = device

        self._train_loss = train_loss
        self._train_loader = train_loader
        self._opt = opt
        self._lr_scheduler = lr_scheduler
        self._max_epochs = max_epochs
        self._max_grad_norm = max_grad_norm

        self._test_metrics = test_metrics
        self._test_loader = test_loader
        self._epochs_per_test = epochs_per_test

        self._valid_loss = valid_loss
        self._valid_loader = valid_loader
        self._max_bad_valid_epochs = max_bad_valid_epochs
        self._best_valid_loss = float("inf")
        self._num_bad_valid_epochs = 0

        self._visualizer = visualizer

        self._writer = writer
        self._should_checkpoint_best_valid = should_checkpoint_best_valid

        ### Training

        self._trainer = Engine(self._train_batch)

        AverageMetric().attach(self._trainer)
        ProgressBar(persist=True).attach(self._trainer, ["loss"])

        self._trainer.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.train())
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, self._log_training_info)

        if should_checkpoint_latest:
            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self._save_checkpoint("latest"))

        ### Validation

        if early_stopping:
            self._validator = Engine(self._validate_batch)

            AverageMetric().attach(self._validator)
            ProgressBar(persist=False, desc="Validating").attach(self._validator)

            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._validate)
            self._validator.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.eval())

        ### Testing

        self._tester = Engine(self._test_batch)

        AverageMetric().attach(self._tester)
        ProgressBar(persist=False, desc="Testing").attach(self._tester)

        self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._test)
        self._tester.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.eval())
Code Example #21
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.gpu)
    device = torch.device('cuda')
    loader_args = dict(mean=conf.data.mean, std=conf.data.std)
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args["rank"] = rank
        loader_args["num_replicas"] = num_replicas

    train_dl = create_train_loader(conf.data.train, **loader_args)
    valid_dl = create_val_loader(conf.data.val, **loader_args)

    if epoch_length < 1:
        epoch_length = len(train_dl)

    model = instantiate(conf.model).to(device)
    model_ema, update_ema = setup_ema(conf,
                                      model,
                                      device=device,
                                      master_node=master_node)
    optim = build_optimizer(conf.optim, model)

    scheduler_kwargs = dict()
    if "schedule.OneCyclePolicy" in conf.lr_scheduler["class"]:
        scheduler_kwargs["cycle_steps"] = epoch_length
    lr_scheduler: Scheduler = instantiate(conf.lr_scheduler, optim,
                                          **scheduler_kwargs)

    use_amp = False
    if conf.use_apex:
        import apex
        from apex import amp
        logging.debug("Nvidia's Apex package is available")

        model, optim = amp.initialize(model, optim, **conf.amp)
        use_amp = True
        if master_node:
            logging.info("Using AMP with opt_level={}".format(
                conf.amp.opt_level))
    else:
        apex, amp = None, None

    to_save = dict(model=model, optim=optim)
    if use_amp:
        to_save["amp"] = amp
    if model_ema is not None:
        to_save["model_ema"] = model_ema

    if master_node and conf.logging.model:
        logging.info(model)

    if distributed:
        sync_bn = conf.distributed.sync_bn
        if apex is not None:
            if sync_bn:
                model = apex.parallel.convert_syncbn_model(model)
            model = apex.parallel.distributed.DistributedDataParallel(
                model, delay_allreduce=True)
        else:
            if sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[
                    local_rank,
                ], output_device=local_rank)

    upd_interval = conf.optim.step_interval
    ema_interval = conf.smoothing.interval_it * upd_interval
    clip_grad = conf.optim.clip_grad

    _handle_batch_train = build_process_batch_func(conf.data,
                                                   stage="train",
                                                   device=device)
    _handle_batch_val = build_process_batch_func(conf.data,
                                                 stage="val",
                                                 device=device)

    def _update(eng: Engine, batch: Batch) -> FloatDict:
        model.train()
        batch = _handle_batch_train(batch)
        losses: Dict = model(*batch)
        stats = {k: v.item() for k, v in losses.items()}
        loss = losses["loss"]
        del losses

        if use_amp:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        it = eng.state.iteration
        if not it % upd_interval:
            if clip_grad > 0:
                params = amp.master_params(
                    optim) if use_amp else model.parameters()
                torch.nn.utils.clip_grad_norm_(params, clip_grad)
            optim.step()
            optim.zero_grad()
            lr_scheduler.step_update(it)

            if not it % ema_interval:
                update_ema()

            eng.state.lr = optim.param_groups[0]["lr"]

        return stats

    calc_map = conf.validate.calc_map
    min_score = conf.validate.get("min_score", -1)

    model_val = model
    if conf.train.skip and model_ema is not None:
        model_val = model_ema.to(device)

    def _validate(eng: Engine, batch: Batch) -> FloatDict:
        model_val.eval()
        images, targets = _handle_batch_val(batch)

        with torch.no_grad():
            out: Dict = model_val(images, targets)

        pred_boxes = out.pop("detections")
        stats = {k: v.item() for k, v in out.items()}

        if calc_map:
            pred_boxes = pred_boxes.detach().cpu().numpy()
            true_boxes = targets['bbox'].cpu().numpy()
            img_scale = targets['img_scale'].cpu().numpy()
            # yxyx -> xyxy
            true_boxes = true_boxes[:, :, [1, 0, 3, 2]]
            # xyxy -> xywh
            true_boxes[:, :, [2, 3]] -= true_boxes[:, :, [0, 1]]
            # scale downsized boxes to match predictions on a full-sized image
            true_boxes *= img_scale[:, None, None]

            scores = []
            for i in range(len(images)):
                mask = pred_boxes[i, :, 4] >= min_score
                s = calculate_image_precision(true_boxes[i],
                                              pred_boxes[i, mask, :4],
                                              thresholds=IOU_THRESHOLDS,
                                              form='coco')
                scores.append(s)
            stats['map'] = np.mean(scores)

        return stats

    train_metric_names = list(conf.logging.out.train)
    train_metrics = create_metrics(train_metric_names,
                                   device if distributed else None)

    val_metric_names = list(conf.logging.out.val)
    if calc_map:
        from utils.metric import calculate_image_precision, IOU_THRESHOLDS
        val_metric_names.append('map')
    val_metrics = create_metrics(val_metric_names,
                                 device if distributed else None)

    trainer = build_engine(_update, train_metrics)
    evaluator = build_engine(_validate, val_metrics)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    if distributed:
        dist_bn = conf.distributed.dist_bn
        if dist_bn in ["reduce", "broadcast"]:
            from timm.utils import distribute_bn

            @trainer.on(Events.EPOCH_COMPLETED)
            def _distribute_bn_stats(eng: Engine):
                reduce = dist_bn == "reduce"
                if master_node:
                    logging.info("Distributing BN stats...")
                distribute_bn(model, num_replicas, reduce)

    sampler = train_dl.sampler
    if isinstance(sampler, (CustomSampler, DistributedSampler)):

        @trainer.on(Events.EPOCH_STARTED)
        def _set_epoch(eng: Engine):
            sampler.set_epoch(eng.state.epoch - 1)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _scheduler_step(eng: Engine):
        # it starts from 1, so we don't need to add 1 here
        ep = eng.state.epoch
        lr_scheduler.step(ep)

    cp = conf.checkpoints
    pbar, pbar_vis = None, None

    if master_node:
        log_interval = conf.logging.interval_it
        log_event = Events.ITERATION_COMPLETED(every=log_interval)
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer, metric_names=train_metric_names)
        pbar.attach(evaluator, metric_names=val_metric_names)

        for engine, name in zip([trainer, evaluator], ['train', 'val']):
            engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
            engine.add_event_handler(log_event,
                                     log_iter,
                                     pbar,
                                     interval_it=log_interval,
                                     name=name)
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     log_epoch,
                                     name=name)

        setup_checkpoints(trainer, to_save, epoch_length, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)
        resume_from_checkpoint(to_save, cp, device=device)
        state = trainer.state
        # the epoch counter starts from 1
        lr_scheduler.step(state.epoch - 1)
        state.max_epochs = epochs

    @trainer.on(Events.EPOCH_COMPLETED(every=conf.validate.interval_ep))
    def _run_validation(eng: Engine):
        if distributed:
            torch.cuda.synchronize(device)
        evaluator.run(valid_dl)

    skip_train = conf.train.skip
    if master_node and conf.visualize.enabled:
        vis_eng = evaluator if skip_train else trainer
        setup_visualizations(vis_eng,
                             model,
                             valid_dl,
                             device,
                             conf,
                             force_run=skip_train)

    try:
        if skip_train:
            evaluator.run(valid_dl)
        else:
            trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())

    for pb in [pbar, pbar_vis]:
        if pb is not None:
            pb.close()
コード例 #22
0
def run(args, random_seed=0):

    # Set random seed
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    writer = SummaryWriter(os.path.join(args['log_dir'], args['name']))

    print('Loading model....')
    training_history = {'CrossEntropy': [], 'Accuracy': []}
    testing_history = {'CrossEntropy': [], 'Accuracy': []}

    model_dict = args['config']['model']

    model = FastWeights(**model_dict)

    device = torch.device(args['config']['device'])
    model = model.to(device)

    print('Loading data....')
    train_loader, test_loader = load_data(args['config']['batch_size'],
                                          args['config']['workers'])

    params = [p for p in model.parameters() if p.requires_grad]

    # optimizer = torch.optim.SGD(
    #     params, lr=args['config']['lr'],
    #     momentum=args['config']['momentum'],
    #     weight_decay=args['config']['weight_decay']
    # )

    optimizer = torch.optim.Adam(params, lr=args['config']['lr'])

    if args['config']['scheduler'] == 'multi':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=args['config']['lr_steps'],
            gamma=args['config']['lr_gamma'])
    elif args['config']['scheduler'] == 'step':
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=args['config']['lr_step_size'],
            gamma=args['config']['lr_gamma'])
    elif args['config']['scheduler'] == 'reduce':
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=args['config']['reduce_type'],
            factor=args['config']['lr_gamma'],
        )
    elif args['config']['scheduler'] == 'cyclic':
        lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=args['config']['lr'],
            max_lr=10 * args['config']['lr'],
            # Adam has no 'momentum' entry in its param groups, so momentum cycling must be disabled
            cycle_momentum=False)
    else:
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda epoch: 1)

    criterion = nn.CrossEntropyLoss()

    def evaluate_function(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs, targets = batch
            if device:
                inputs = inputs.to(device)
                targets = targets.to(device)

            inputs = torch.transpose(inputs, 0, 1)

            preds = model(inputs)
            return preds, targets

    def process_function(engine, batch):
        model.train()
        optimizer.zero_grad()
        inputs, targets = batch
        if device:
            inputs = inputs.to(device)
            targets = targets.to(device)

        inputs = torch.transpose(inputs, 0, 1)

        preds = model(inputs)
        loss = criterion(preds, targets)
        loss.backward()
        if args['config']['max_norm'] > 0:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     max_norm=args['config']['max_norm'])
        optimizer.step()
        if args['config']['scheduler'] == 'cyclic':
            lr_scheduler.step()
        else:
            pass
        return loss.item()

    trainer = Engine(process_function)
    evaluator = Engine(evaluate_function)
    train_evaluator = Engine(evaluate_function)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    Loss(criterion, output_transform=lambda x: [x[0], x[1]]).attach(
        evaluator, 'CrossEntropy')
    Accuracy().attach(evaluator, 'Accuracy')

    Loss(criterion, output_transform=lambda x: [x[0], x[1]]).attach(
        train_evaluator, 'CrossEntropy')
    Accuracy().attach(train_evaluator, 'Accuracy')

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['loss'])

    @trainer.on(Events.STARTED)
    def resume_training(engine):
        if args['config']['resume'] > 0:
            checkpoint = torch.load(os.path.join(
                args['dir'], args['config']['output_dir'],
                f"{args['name']}_{args['config']['resume']}.pth"),
                                    map_location=device)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            engine.state.epoch = args['config']['resume']

    def val_score(engine):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        avg_loss = metrics['CrossEntropy']
        return -avg_loss

    def checkpointer(engine):
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': engine.state.epoch,
            'args': args
        }
        save_on_master(
            checkpoint,
            os.path.join(args['dir'], args['config']['output_dir'],
                         f"{args['name']}_{engine.state.epoch}.pth"))

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer)

    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    def score_function(engine):
        metrics = evaluator.state.metrics
        avg_loss = metrics['CrossEntropy']
        return -avg_loss

    def print_trainer_logs(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        avg_loss = metrics['CrossEntropy']
        avg_acc = metrics['Accuracy'] * 100

        training_history['CrossEntropy'].append(avg_loss)
        training_history['Accuracy'].append(avg_acc)

        writer.add_scalar("training/avg_loss", avg_loss, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_acc, engine.state.epoch)

        print("Training Results - Epoch: {} ".format(engine.state.epoch),
              "Avg loss: {:.4f} ".format(avg_loss),
              "Avg Acc: {:.4f} ".format(avg_acc))

    trainer.add_event_handler(Events.EPOCH_COMPLETED, print_trainer_logs)

    def log_validation_results(engine):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        avg_loss = metrics['CrossEntropy']
        avg_acc = metrics['Accuracy'] * 100

        if args['config']['scheduler'] == 'reduce':
            lr_scheduler.step(-avg_loss)
        elif args['config']['scheduler'] == 'cyclic':
            pass
        else:
            lr_scheduler.step()

        testing_history['CrossEntropy'].append(avg_loss)
        testing_history['Accuracy'].append(avg_acc)

        writer.add_scalar("validation/avg_loss", avg_loss, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_acc,
                          engine.state.epoch)

        print("Validation Results - Epoch: {} ".format(engine.state.epoch),
              "Avg loss: {:.4f} ".format(avg_loss),
              "Avg Acc: {:.4f} ".format(avg_acc))

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

    handler = EarlyStopping(patience=args['config']['patience'],
                            score_function=score_function,
                            trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    print('Training....')
    trainer.run(train_loader, max_epochs=args['config']['epochs'])
    writer.close()
    np.save(os.path.join(args['dir'], f"{args['name']}_traininglog.npy"),
            [training_history])
    np.save(os.path.join(args['dir'], f"{args['name']}_testinglog.npy"),
            [testing_history])
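
A stripped-down version of the early-stopping wiring above, assuming an evaluator that exposes a 'CrossEntropy' metric; EarlyStopping treats higher scores as better, so the validation loss is negated. The helper name and the default patience are illustrative.

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping


def attach_early_stopping(trainer: Engine, evaluator: Engine, patience: int = 10) -> EarlyStopping:
    def neg_val_loss(engine: Engine) -> float:
        # higher is better for EarlyStopping, so negate the loss
        return -engine.state.metrics['CrossEntropy']

    stopper = EarlyStopping(patience=patience, score_function=neg_val_loss, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, stopper)
    return stopper
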
コード例 #23
0
def train(train_data_path,
          valid_data_path,
          config,
          out_dir="./explainability",
          batch_size=64,
          lr=1e-4,
          epochs=100):
    train_dataset = dataset_classes[config['type']](train_data_path, config)
    valid_dataset = dataset_classes[config['type']](valid_data_path, config)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=8)
    val_loader = torch.utils.data.DataLoader(valid_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=8)

    model = Model(valid_dataset, config)
    path_cp = "./explainability_checkpoints/" + out_dir
    os.makedirs(path_cp, exist_ok=True)
    with open(path_cp + "/config.json", 'w') as config_file:
        json.dump(config, config_file)

    optimizer = Adam(model.parameters(), lr=lr)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        weighted_binary_cross_entropy,
                                        device=model.device)

    validation_evaluator = create_evaluator(model)

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    best_model_handler = ModelCheckpoint(
        dirname="./explainability_checkpoints/" + out_dir,
        filename_prefix="best",
        n_saved=1,
        global_step_transform=global_step_from_engine(trainer),
        score_name="val_ap",
        score_function=lambda engine: engine.state.metrics['ap'],
        require_empty=False)
    validation_evaluator.add_event_handler(Events.COMPLETED,
                                           best_model_handler, {
                                               'model': model,
                                           })

    tb_logger = TensorboardLogger(log_dir='./explainability_tensorboard/' +
                                  out_dir)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
            metric_names="all"),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )

    tb_logger.attach(
        validation_evaluator,
        log_handler=OutputHandler(tag="validation",
                                  metric_names=["ap"],
                                  another_engine=trainer),
        event_name=Events.EPOCH_COMPLETED,
    )
    #tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))

    @trainer.on(Events.EPOCH_COMPLETED(every=5))
    def log_validation_results(engine):
        validation_evaluator.run(val_loader)
        metrics = validation_evaluator.state.metrics
        pbar.log_message(
            f"Validation Results - Epoch: {engine.state.epoch} ap: {metrics['ap']}"  # f1: {metrics['f1']}, p: {metrics['p']}, r: {metrics['r']}
        )

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
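
A compact sketch of the best-model checkpointing used above: ModelCheckpoint keeps the single best file according to a validation metric and stamps it with the trainer's global step. The directory, the 'ap' metric name and the helper name are assumptions following the example.

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, global_step_from_engine


def attach_best_checkpoint(trainer: Engine, evaluator: Engine, model, dirname="./checkpoints"):
    handler = ModelCheckpoint(
        dirname=dirname,
        filename_prefix="best",
        n_saved=1,  # keep only the single best checkpoint
        score_name="val_ap",
        score_function=lambda engine: engine.state.metrics["ap"],
        global_step_transform=global_step_from_engine(trainer),
        require_empty=False,
    )
    evaluator.add_event_handler(Events.COMPLETED, handler, {"model": model})
    return handler
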
コード例 #24
0
def do_train(cfg,model,train_loader,val_loader,optimizer,scheduler,loss_fn,metrics,image_3_dataloader=None,image_4_dataloader=None):

    device = cfg.MODEL.DEVICE if torch.cuda.is_available() else 'cpu'
    epochs = cfg.SOLVER.MAX_EPOCHS
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("Trainer")
    logger.info("Start training")
    trainer = create_supervised_trainer(model.train(),optimizer,loss_fn,device=device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,TerminateOnNan())

    evaluator = create_supervised_evaluator(model.eval(),metrics={"pixel_error":metrics},device=device)
    # evaluator_trainer = create_supervised_evaluator(model.eval(),metrics={"pixel_error":(metrics)},device=device)
    timer = Timer(average=True)
    timer.attach(trainer,start=Events.EPOCH_STARTED,resume=Events.ITERATION_STARTED,pause=Events.ITERATION_COMPLETED,step=Events.ITERATION_COMPLETED)
    RunningAverage(output_transform=lambda x:x).attach(trainer,'avg_loss')

    # Log the training loss every `log_period` iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        len_train_loader = len(train_loader)
        log_period = max(1, int(cfg.LOG_PERIOD * len_train_loader))  # at least one iteration
        # global iteration count (epochs are 1-based)
        iter = (engine.state.iteration - 1) % len_train_loader + 1 + (engine.state.epoch - 1) * len_train_loader
        if iter % log_period == 0:
            iter = (engine.state.iteration - 1) % len_train_loader + 1
            logger.info("Epoch[{}] Iteration[{}/{}] Loss {:.7f}".format(engine.state.epoch, iter, len_train_loader, engine.state.metrics['avg_loss']))
            
    @trainer.on(Events.EPOCH_COMPLETED)
    def save(engine):
        epoch = engine.state.epoch
        print("epoch: "+str(epoch))
        if epoch%1 == 0:
            model_name=os.path.join(cfg.OUTPUT.DIR_NAME+"model/","epoch_"+str(engine.state.epoch)+"_"+cfg.TAG+"_"+cfg.MODEL.NET_NAME+".pth")
            torch.save(model.module.state_dict(),model_name)

    # Run validation / adjust the LR every `val_period` iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_val_metric(engine):
        len_train_loader = len(train_loader)
        # global iteration count (epochs are 1-based)
        iter = (engine.state.iteration - 1) % len_train_loader + 1 + (engine.state.epoch - 1) * len_train_loader
        val_period = max(1, int(cfg.VAL_PERIOD * len_train_loader))  # at least one iteration
        if iter % val_period == 0:
            pass
            # Log the validation result
            # evaluator.run(val_loader)
            # metrics = evaluator.state.metrics
            # avg_loss = metrics["pixel_error"]
            # logger.info("Validation Result - Epoch: {} Avg Pixel Accuracy: {:.7f} ".format(engine.state.epoch, avg_loss))

            ######################
            # # Run TTA forward separately for each image
            # cfg.TOOLS.image_n = 3
            # image_3_predict = tta_forward(cfg,image_3_dataloader,model.eval())
            # pil_image_3 = Image.fromarray(image_3_predict)
            # image_3_save_path = "iter_" + str(iter) + "_" + "image_3_predict.png"
            # pil_image_3.save(os.path.join(r"./output",image_3_save_path))
            # image_3_label_save_path = "iter_" + str(iter) + "_" + "vis_" + "image_3_predict.jpg"
            # source_image_3 = cv.imread("./output/source/image_3.png")
            # mask_3 = label_resize_vis(image_3_predict,source_image_3)
            # cv.imwrite(os.path.join(r"./output",image_3_label_save_path),mask_3)

            # cfg.TOOLS.image_n = 4
            # image_4_predict = tta_forward(cfg,image_4_dataloader,model.eval())
            # pil_image_4 = Image.fromarray(image_4_predict)
            # image_4_save_path = "iter_" + str(iter) + "_" + "image_4_predict.png"
            # pil_image_4.save(os.path.join(r"./output",image_4_save_path))
            # image_4_label_save_path = "iter_" + str(iter) + "_" + "vis_" + "image_4_predict.jpg"
            # source_image_4 = cv.imread("./output/source/image_4.png")
            # mask_4 = label_resize_vis(image_4_predict,source_image_4)
            # cv.imwrite(os.path.join(r"./output",image_4_label_save_path),mask_4)



            # Monitor the schedule: when pixel accuracy stops improving, adjust the learning rate
            lr = new_lr = optimizer.state_dict()['param_groups'][0]['lr']
            if cfg.SOLVER.LR_SCHEDULER == "StepLR":
                scheduler.step()
                new_lr = optimizer.state_dict()['param_groups'][0]['lr']

            # elif cfg.SOLVER.LR_SCHEDULER == "ReduceLROnPlateau":
            #     lr = optimizer.state_dict()['param_groups'][0]['lr']
            #     scheduler.step(-avg_loss)
            #     new_lr = optimizer.state_dict()['param_groups'][0]['lr']
            #     print(new_lr, lr)
            #     if new_lr != lr:
            #         cfg.SOLVER.LR_SCHEDULER_REPEAT = cfg.SOLVER.LR_SCHEDULER_REPEAT - 1
            #         # Cap the number of LR reductions: terminate once the LR has been lowered too many times
            #         if cfg.SOLVER.LR_SCHEDULER_REPEAT < 0:
            #             trainer.terminate()

            elif cfg.SOLVER.LR_SCHEDULER == "CosineAnnealingLR":
                scheduler.step()
                new_lr = optimizer.state_dict()['param_groups'][0]['lr']

            if new_lr != lr:
                print(new_lr, lr)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_training_result(engine):
    #     if engine.state.epoch % 5 == 0:
    #         evaluator_trainer.run(train_loader)
    #         metrics = evaluator_trainer.state.metrics
    #         avg_loss = metrics["pixel_error"]
    #         logger.info("Training Result - Epoch: {} Avg Pixel Error: {:.7f} ".format(engine.state.epoch,avg_loss))


            
    


    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info("Epoch {} done.Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]".format(engine.state.epoch,timer.value()*timer.step_count,
    		train_loader.batch_size / timer.value()))
        timer.reset()


    # def score_pixel_error(engine):
    # 	error = evaluator.state.metrics['pixel_error']
    # 	return error
    # handler_ModelCheckpoint_pixel_error = ModelCheckpoint(dirname=cfg.OUTPUT.DIR_NAME+"model/",filename_prefix=cfg.TAG+"_"+cfg.MODEL.NET_NAME,
    # 	score_function=score_pixel_error,n_saved=cfg.OUTPUT.N_SAVED,create_dir=True,score_name=cfg.SOLVER.CRITERION,require_empty=False)
    # evaluator.add_event_handler(Events.EPOCH_COMPLETED,handler_ModelCheckpoint_pixel_error,{'model':model.module.state_dict()})
    

    trainer.run(train_loader,max_epochs=epochs)
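
The periodic logging above derives the interval by hand from `engine.state.iteration`; the built-in event filter achieves the same thing more directly. A small sketch, with the 50-iteration interval chosen for illustration.

from ignite.engine import Engine, Events


def attach_periodic_logging(trainer: Engine, every: int = 50) -> None:
    @trainer.on(Events.ITERATION_COMPLETED(every=every))
    def log_output(engine: Engine):
        # engine.state.output is whatever the update function returned (a scalar loss here)
        print(f"iter {engine.state.iteration}: output={engine.state.output}")
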
コード例 #25
0
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             metrics, device):
    def _prepare_batch(batch, device=None, non_blocking=False):
        """Prepare batch for training: pass to a device with options.

        """
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def create_supervised_dp_trainer(
        model,
        optimizer,
        device=None,
        non_blocking=False,
        prepare_batch=_prepare_batch,
        output_transform=lambda x, y, y_pred, loss: loss.item()):
        """
        Factory function for creating a trainer for supervised models.

        Args:
            model (`torch.nn.Module`): the model to train; its forward pass returns the loss.
            optimizer (`torch.optim.Optimizer`): the optimizer to use.
            device (str, optional): device type specification (default: None).
                Applies to both model and batches.
            non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
                with respect to the host. For other cases, this argument has no effect.
            prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
                tuple of tensors `(batch_x, batch_y)`.
            output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
                to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

        Note: `engine.state.output` for this engine is defined by the `output_transform` parameter and is the loss
            of the processed batch by default.

        Returns:
            Engine: a trainer engine with supervised update function.
        """
        if device:
            model.to(device)

        def _update(engine, batch):
            # model.train()
            optimizer.zero_grad()
            x, y = prepare_batch(batch,
                                 device=device,
                                 non_blocking=non_blocking)
            with autocast():
                total_loss = model(x, y)
            total_loss = total_loss.mean()  # average the per-GPU losses returned by the model
            # Scale the loss for AMP so that small fp16 gradients do not underflow
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            writer.add_scalar("total loss", total_loss.item(), engine.state.iteration)
            scaler.update()
            # total_loss.backward()
            # optimizer.step()
            return output_transform(x, y, None, total_loss)

        return Engine(_update)

    scaler = torch.cuda.amp.GradScaler()
    master_device = device[0]  # use the first GPU as the master device by default
    trainer = create_supervised_dp_trainer(model,
                                           optimizer,
                                           device=master_device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    log_dir = cfg['log_dir']
    writer = SummaryWriter(log_dir=log_dir)

    # create pbar
    len_train_loader = len(train_loader)
    pbar = tqdm(total=len_train_loader)

    froze_num_layers = cfg['warm_up']['froze_num_lyers']
    if cfg['multi_gpu']:
        freeze_layers(model.module, froze_num_layers)
    else:
        freeze_layers(model, froze_num_layers)

    # Finetuning mode: large patches and a small batch, so freeze all BN layers in the model
    # Normal mode: freeze the configured number of layers
    if 'mode' in cfg and cfg['mode'] == "Finetuning":
        if cfg['multi_gpu']:
            fix_bn(model.module)
        else:
            fix_bn(model)

    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    # Log the training loss every log_period iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_period = cfg['log_period']
        log_per_iter = int(log_period * len_train_loader) if int(
            log_period * len_train_loader) >= 1 else 1  # logging period in iterations
        current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + (
            engine.state.epoch - 1) * len_train_loader  # current global iteration

        lr = optimizer.state_dict()['param_groups'][0]['lr']

        if current_iter % log_per_iter == 0:
            pbar.write("Epoch[{}] Iteration[{}] lr {:.7f} Loss {:.7f}".format(
                engine.state.epoch, current_iter, lr,
                engine.state.metrics['avg_loss']))
            pbar.update(log_per_iter)
            writer.add_scalar('loss', engine.state.metrics['avg_loss'],
                              current_iter)

    # lr_scheduler Warm Up
    @trainer.on(Events.ITERATION_COMPLETED)
    def lr_scheduler_iteration(engine):
        scheduler.ITERATION_COMPLETED()
        current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + (
            engine.state.epoch - 1) * len_train_loader  # current global iteration
        length = cfg['warm_up']['length']
        min_lr = cfg['warm_up']['min_lr']
        max_lr = cfg['warm_up']['max_lr']
        froze_num_layers = cfg['warm_up']['froze_num_lyers']
        if current_iter < length:
            """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
            lr = (max_lr - min_lr) / length * current_iter
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # pbar.write("lr: {}".format(lr))

        if current_iter == length:
            if 'mode' in cfg and cfg['mode'] == "Finetuning":
                pass
            else:  # In Normal mode, unfreeze the layers once warm-up ends
                pass
                # if cfg['multi_gpu']:
                #     freeze_layers(model.module,froze_num_layers)
                # else:
                #     freeze_layers(model,froze_num_layers)

                # for param_group in optimizer.param_groups:
                #     param_group['lr'] = cfg['optimizer']['lr']

    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_scheduler_epoch(engine):
        scheduler.EPOCH_COMPLETED()

    ##########################################################################################
    ##################               Events.EPOCH_COMPLETED                    ###############
    ##########################################################################################
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_temp_epoch(engine):
        save_dir = cfg['save_dir']
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        epoch = engine.state.epoch
        if epoch % 1 == 0:
            model_name = os.path.join(save_dir, cfg['tag'] + "_temp.pth")
            # import pdb; pdb.set_trace()

            if cfg['multi_gpu']:
                save_pth = {'model': model.module.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)
            else:
                save_pth = {'model': model.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)

        if epoch % 10 == 0:
            model_name = os.path.join(save_dir,
                                      cfg['tag'] + "_" + str(epoch) + ".pth")
            if cfg['multi_gpu']:
                save_pth = {'model': model.module.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)
            else:
                save_pth = {'model': model.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)

    @trainer.on(Events.EPOCH_COMPLETED)
    def calu_acc(engine):
        epoch = engine.state.epoch
        if epoch % 10 == 0:
            model.eval()
            num_correct = 0
            num_example = 0
            torch.cuda.empty_cache()
            with torch.no_grad():
                for image, target in tqdm(train_loader):
                    image, target = image.to(master_device), target.to(
                        master_device)
                    pred_logit_dict = model(image, target)
                    pred_logit = [
                        value for value in pred_logit_dict.values()
                        if value is not None
                    ]

                    pred_logit = pred_logit[0]
                    indices = torch.max(pred_logit, dim=1)[1]
                    correct = torch.eq(indices, target).view(-1)
                    num_correct += torch.sum(correct).item()
                    num_example += correct.shape[0]

            acc = (num_correct / num_example)
            pbar.write("Acc: {}".format(acc))
            writer.add_scalar("Acc", acc, epoch)
            torch.cuda.empty_cache()
            model.train()

        # Finetuning mode: large patches and a small batch, so freeze all BN layers in the model
        # Normal mode: freeze the configured number of layers
        if 'mode' in cfg and cfg['mode'] == "Finetuning":
            if cfg['multi_gpu']:
                fix_bn(model.module)
            else:
                fix_bn(model)

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_pbar(engine):
        pbar.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_dataset(engine):  # only for the custom train_dataset written by jr: shuffle it manually
        if hasattr(train_loader.dataset, 'shuffle'):
            pbar.write("shuffle train_dataloader")
            train_loader.dataset.shuffle()

    max_epochs = cfg['max_epochs']
    trainer.run(train_loader, max_epochs=max_epochs)
    pbar.close()
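
A self-contained sketch of the native-AMP step used in `_update` above (autocast plus GradScaler). The model, optimizer, loss function and batch format are placeholders, and a CUDA device is assumed.

import torch
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan


def create_amp_trainer(model, optimizer, loss_fn, device="cuda"):
    scaler = torch.cuda.amp.GradScaler()
    model.to(device)

    def update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(x), y)
        scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
        scaler.step(optimizer)         # unscales the gradients, skips the step on inf/nan
        scaler.update()
        return loss.item()

    trainer = Engine(update)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    return trainer
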
コード例 #26
0
    def _run_training(self,
                      train_data,
                      valid_data=[],
                      test_data=[],
                      tb_log_dir=None,
                      split_num=None,
                      verbose=True):

        # setup
        epochs = self.training_config.get("epochs")
        optimizer_config = self.training_config.get("optimizer_config")
        early_stopping = self.training_config.get("early_stopping")
        checkpoint_saving = self.training_config.get("checkpoint_saving")
        graph_dataset_config = self.training_config.get("graph_dataset_config")
        device = self.training_config.get("device")
        extraction_target = self.training_config.get("extraction_target")

        graph_dataset_config["device"] = device
        graph_dataset_config["extraction_target"] = extraction_target

        if split_num is not None:
            tb_log_dir += "-Split{}".format(split_num + 1)

        # prepare graph-sets
        graph_dataset = GraphDataset(train_data,
                                     valid_data=valid_data,
                                     test_data=test_data,
                                     graph_dataset_config=graph_dataset_config)

        train_loader, val_loader, test_loader = graph_dataset.get_loaders()

        worker_init_fn = graph_dataset.init_fn

        # create model, optimizer, loss
        model = self.model_class(self.model_config)
        model = model.to(device)

        optimizer = self.optimizer_class(model.parameters(),
                                         **optimizer_config)

        # apparently, we have to do this
        self.model = model
        self.optimizer = optimizer

        # load model from checkpoint if available
        if self.trained_model_checkpoint is not None:
            self.custom_print("load transfer-learning checkpoint...")
            model, optimizer = self._prepare_trained_model(model, optimizer)

        loss = self.loss_class()
        loss_name = "mse"

        evaluator_settings = {
            "device": device,
            "extraction_target": extraction_target,
            "pred_collector_function":
            lambda x: self._pred_collector_function(x),
            "metrics": {
                loss_name: Loss(loss)
            }
        }

        ## configure trainer ##
        trainer = create_supervised_trainer(
            model,
            optimizer,
            loss,
            device=device,
            extraction_target=extraction_target)

        ###############################################
        ## configure evaluators for each data source ##
        train_evaluator = create_supervised_evaluator(model,
                                                      **evaluator_settings)

        val_evaluator = create_supervised_evaluator(model,
                                                    **evaluator_settings)

        test_evaluator = create_supervised_evaluator(model,
                                                     **evaluator_settings)

        # configure behavior for early stopping
        if early_stopping is not None:
            stopper = EarlyStopping(patience=early_stopping,
                                    score_function=self.score_function,
                                    trainer=trainer)
            val_evaluator.add_event_handler(Events.COMPLETED, stopper)

        # configure behavior for checkpoint saving
        if checkpoint_saving is not None:
            save_handler = None
            if self.test_mode and self.validation_mode:
                self.custom_print("Use LocalSaveHandler...")
                save_handler = LocalSaveHandler(self)
            else:
                self.custom_print("Use IgniteSaveHandler...")
                save_handler = DiskSaver(self.save_path,
                                         create_dir=True,
                                         require_empty=False)

            saver = Checkpoint(
                {
                    "model_state_dict": model,
                    "optimizer_state_dict": optimizer
                },
                save_handler,
                filename_prefix='{}_best'.format(self.dataset.name),
                score_name="val_loss",
                score_function=self.score_function,
                global_step_transform=global_step_from_engine(trainer),
                n_saved=1)
            train_evaluator.add_event_handler(Events.COMPLETED, saver)

        @trainer.on(Events.STARTED)
        def log_training_start(trainer):
            if split_num is not None:
                self.custom_print("Split: {}".format(split_num + 1))

        @trainer.on(Events.COMPLETED)
        def log_training_complete(trainer):
            """Trigger evaluation on test set if training is completed."""

            epoch = trainer.state.epoch
            suffix = "(Early Stopping)" if epoch < epochs else ""

            self.custom_print("Finished after {:03d} epochs! {}".format(
                epoch, suffix))

            embedding_list = []

            def _graph_embedding_function(tensor, idx):
                while idx >= len(embedding_list):
                    embedding_list.append([])
                embedding_list[idx].append(tensor.cpu().detach().numpy())

            if self.test_mode and self.validation_mode:
                checkpoint_dict = self.best_model_checkpoint
                self.custom_print(
                    "Load best model checkpoint by validation loss... Epoch: {}"
                    .format(checkpoint_dict["epoch"]))
                model, optimizer = self._load_checkpoint(
                    self.model, self.optimizer, checkpoint_dict["checkpoint"])

            self.model.graph_embedding_function = _graph_embedding_function
            self.persist_pred = True

            if not self.test_mode:
                return

            test_evaluator.run(test_loader)

        @trainer.on(Events.EPOCH_COMPLETED)
        def compute_metrics(engine):
            """Compute evaluation metric values after each epoch."""
            train_evaluator.run(train_loader)

            if hasattr(self.model, "node_counter"):
                self.custom_print(self.model.node_counter)

            if self.validation_mode:
                val_evaluator.run(val_loader)

        # terminate training if Nan values are produced
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        # create tensorboard-logger
        tb_logger = create_tb_logger(model,
                                     optimizer,
                                     trainer,
                                     train_evaluator,
                                     val_evaluator,
                                     test_evaluator,
                                     log_dir=tb_log_dir,
                                     verbose=verbose,
                                     custom_print=self.custom_print,
                                     loss_name=loss_name)

        with torch.autograd.detect_anomaly():
            trainer.run(train_loader, max_epochs=epochs)

        tb_logger.close()

        if not self.test_mode:
            return 0, 0

        # only the "mse" loss metric is attached above, so fall back if "accuracy" is missing
        test_acc = test_evaluator.state.metrics.get("accuracy", float("nan"))
        test_loss = test_evaluator.state.metrics["mse"]

        return test_acc, test_loss
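
A reduced sketch of the Checkpoint/DiskSaver combination above, keeping the best model and optimizer state by validation loss. The 'mse' metric name, the output directory and the helper name are assumptions following the example.

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine


def attach_checkpoint_saver(trainer: Engine, val_evaluator: Engine, model, optimizer, save_path="./checkpoints"):
    saver = Checkpoint(
        {"model": model, "optimizer": optimizer},
        DiskSaver(save_path, create_dir=True, require_empty=False),
        filename_prefix="best",
        score_name="val_loss",
        # lower loss is better, so negate it for the checkpoint score
        score_function=lambda engine: -engine.state.metrics["mse"],
        global_step_transform=global_step_from_engine(trainer),
        n_saved=1,
    )
    val_evaluator.add_event_handler(Events.COMPLETED, saver)
    return saver
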
コード例 #27
0
        print(f"TRAINING IS DONE FOR {RUN_NAME} RUN.")

    pbar = ProgressBar()

    checkpointer = ModelCheckpoint(
        CHECKPOINTS_RUN_DIR_PATH,
        filename_prefix=RUN_NAME.lower(),
        n_saved=None,
        score_function=lambda engine: round(engine.state.metrics['WRA'], 3),
        score_name='WRA',
        atomic=True,
        require_empty=True,
        create_dir=True,
        archived=False,
        global_step_transform=global_step_from_engine(trainer))
    nan_handler = TerminateOnNan()
    coslr = CosineAnnealingScheduler(opt,
                                     "lr",
                                     start_value=LR,
                                     end_value=LR / 4,
                                     cycle_size=TOTAL_UPDATE_STEPS // 1)

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                                {'_': mude})

    trainer.add_event_handler(Events.ITERATION_COMPLETED, nan_handler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, coslr)

    GpuInfo().attach(trainer, name='gpu')
    pbar.attach(trainer,
                output_transform=lambda output: {'loss': output['loss']},
コード例 #28
0
def add_events(engines, dataloaders, model, optimizer, device, save_dir, args):
    trainer, valid_evaluator, test_evaluator = engines
    train_dl, valid_dl, test_dl = dataloaders

    if args.valid_on == 'Loss':
        score_fn = lambda engine: -engine.state.metrics[args.valid_on]
    elif args.valid_on == 'Product':
        score_fn = lambda engine: engine.state.metrics[
            'MRR'] * engine.state.metrics['HR@10']
    elif args.valid_on == 'RMS':
        score_fn = lambda engine: engine.state.metrics[
            'MRR']**2 + engine.state.metrics['HR@10']**2
    else:
        score_fn = lambda engine: engine.state.metrics[args.valid_on]

    # LR Scheduler
    if args.lr_scheduler == 'restart':
        scheduler = CosineAnnealingScheduler(optimizer,
                                             'lr',
                                             start_value=args.lr,
                                             end_value=args.lr * 0.01,
                                             cycle_size=len(train_dl),
                                             cycle_mult=args.cycle_mult)
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler,
                                  'lr_scheduler')
    elif args.lr_scheduler == 'triangle':
        scheduler = make_slanted_triangular_lr_scheduler(
            optimizer, n_events=args.n_epochs * len(train_dl), lr_max=args.lr)
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler,
                                  'lr_scheduler')
    elif args.lr_scheduler == 'none':
        pass
    else:
        raise NotImplementedError

    # EarlyStopping
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    valid_evaluator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(args.patience, score_function=score_fn, trainer=trainer))
    # Training Loss
    RunningAverage(output_transform=lambda x: x,
                   alpha=args.avg_alpha).attach(trainer, 'loss')
    # Checkpoint
    ckpt_handler = ModelCheckpoint(save_dir,
                                   'best',
                                   score_function=score_fn,
                                   score_name=args.valid_on,
                                   n_saved=1)
    valid_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler,
                                      {'model': model})
    # Timer
    timer = Timer(average=True)
    timer.attach(trainer,
                 resume=Events.EPOCH_STARTED,
                 step=Events.EPOCH_COMPLETED)
    # Progress Bar
    if args.pbar:
        pbar = ProgressBar()
        pbar.attach(trainer, ['loss'])
        log_msg = pbar.log_message
    else:
        log_msg = print

    cpe_valid = CustomPeriodicEvent(n_epochs=args.valid_every)
    cpe_valid.attach(trainer)
    valid_metrics_history = []

    @trainer.on(
        getattr(cpe_valid.Events, f'EPOCHS_{args.valid_every}_COMPLETED'))
    def evaluate_on_valid(engine):
        state = valid_evaluator.run(valid_dl)
        metrics = state.metrics
        valid_metrics_history.append(metrics)
        msg = f'Epoch: {engine.state.epoch:3d} AvgTime: {timer.value():3.1f}s TrainLoss: {engine.state.metrics["loss"]:.4f} '
        msg += ' '.join([
            f'{k}: {v:.4f}' for k, v in metrics.items()
            if k in ['Loss', 'MRR', 'HR@10']
        ])
        log_msg(msg)

    @trainer.on(Events.COMPLETED)
    def evaluate_on_test(engine):
        pth_file = [
            f for f in pathlib.Path(save_dir).iterdir()
            if f.name.endswith('pth')
        ][0]
        log_msg(f'Load Best Model: {str(pth_file)}')
        model.load_state_dict(torch.load(pth_file, map_location=device))
        # Rerun on Valid for log.
        valid_state = valid_evaluator.run(valid_dl)
        engine.state.valid_metrics = valid_state.metrics
        # Test
        test_state = test_evaluator.run(test_dl)
        engine.state.test_metrics = test_state.metrics
        engine.state.valid_metrics_history = valid_metrics_history
        msg = f'[Test] '
        msg += ' '.join([
            f'{k}: {v:.4f}' for k, v in test_state.metrics.items()
            if k in ['Loss', 'MRR', 'HR@10']
        ])
        log_msg(msg)

    # Tensorboard
    if args.tensorboard:
        tb_logger = TensorboardLogger(log_dir=str(save_dir / 'tb_log'))
        # Loss
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag='training', output_transform=lambda x: x),
                         event_name=Events.ITERATION_COMPLETED)
        # Metrics
        tb_logger.attach(valid_evaluator,
                         log_handler=OutputHandler(
                             tag='validation',
                             metric_names=['Loss', 'MRR', 'HR@10'],
                             another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        # Optimizer
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        # Parameters
        # tb_logger.attach(trainer,
        #                  log_handler=WeightsScalarHandler(model),
        #                  event_name=Events.ITERATION_COMPLETED)
        # tb_logger.attach(trainer,
        #                  log_handler=GradsScalarHandler(model),
        #                  event_name=Events.ITERATION_COMPLETED)

        @trainer.on(Events.COMPLETED)
        def close_tb(engine):
            tb_logger.close()
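
A minimal sketch of the cosine-annealing parameter scheduler attached above; the learning-rate bounds and the one-epoch cycle length are illustrative.

from ignite.contrib.handlers import CosineAnnealingScheduler
from ignite.engine import Engine, Events


def attach_cosine_lr(trainer: Engine, optimizer, steps_per_epoch: int, lr: float = 1e-3):
    scheduler = CosineAnnealingScheduler(
        optimizer, "lr",
        start_value=lr,
        end_value=lr * 0.01,
        cycle_size=steps_per_epoch,  # one cosine cycle per epoch
    )
    # the scheduler sets the learning rate before each iteration
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    return scheduler
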
コード例 #29
0
File: common.py  Project: vidushityagi8/ignite
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED,
                lambda engine: cast(_LRScheduler, lr_scheduler).step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save,
                                        cast(Union[Callable, BaseSaveHandler],
                                             save_handler),
                                        filename_prefix="training",
                                        **kwargs)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(
                every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should be either a mapping or a sequence, but given {type(x)}"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
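
The private helper above is normally reached through the public wrapper in ignite.contrib.engines.common; a hedged usage sketch, assuming the keyword arguments shown here match the installed ignite version.

from ignite.contrib.engines import common
from ignite.engine import Engine


def attach_common_handlers(trainer: Engine, model, optimizer, lr_scheduler=None, output_path="./checkpoints"):
    common.setup_common_training_handlers(
        trainer,
        to_save={"model": model, "optimizer": optimizer},
        save_every_iters=1000,
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=["loss"],  # assumes the update function's output exposes a loss entry
        with_pbars=True,
    )
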
コード例 #30
0
def create_trainer(model, tasks, optims, loaders, args):
    zt = []
    zt_task = {'left': [], 'right': []}

    if args.dataset.name == 'dummy':
        lim = 2.5
        lims = [[-lim, lim], [-lim, lim]]
        grid = setup_grid(lims, 1000)

    def trainer_step(engine, batch):
        model.train()

        for optim in optims:
            optim.zero_grad()

        # Batch data
        x, y = batch
        x = convert_tensor(x.float(), args.device)
        y = [convert_tensor(y_, args.device) for y_ in y]

        training_loss = 0.
        losses = []

        # Intermediate representation
        with cached():
            preds = model(x)
            if args.dataset.name == 'dummy':
                zt.append(model.rep.detach().clone())

            for pred_i, task_i in zip(preds, tasks):
                loss_i = task_i.loss(pred_i, y[task_i.index])

                if args.dataset.name == 'dummy':
                    loss_i = loss_i.mean(dim=0)
                    zt_task[task_i.name].append(pred_i.detach().clone())

                # Track losses
                losses.append(loss_i)
                training_loss += loss_i.item() * task_i.weight

            if args.dataset.name == 'dummy' and (
                    engine.state.epoch == engine.state.max_epochs
                    or engine.state.epoch % args.training.plot_every == 0):
                fig = plot_toy(grid,
                               model,
                               tasks, [zt, zt_task['left'], zt_task['right']],
                               trainer.state.iteration - 1,
                               levels=20,
                               lims=lims)
                fig.savefig(f'plots/step_{engine.state.iteration - 1}.png')
                plt.close(fig)

            model.backward(losses)

        for optim in optims:  # Run the optimizers
            optim.step()

        return training_loss, losses

    trainer = Engine(trainer_step)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
    for i, task_i in enumerate(tasks):
        output_transform = partial(lambda idx, x: x[1][idx], i)
        RunningAverage(output_transform=output_transform).attach(
            trainer, f'train_{task_i.name}')

    pbar = ProgressBar()
    pbar.attach(trainer,
                metric_names=['loss'] + [f'train_{t.name}' for t in tasks])

    # Validation
    validator = create_evaluator(model, tasks, args)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validator(trainer):
        validator.run(loaders['val'])
        metrics = validator.state.metrics
        loss = 0.
        for task_i in tasks:
            loss += metrics[f'loss_{task_i.name}'] * task_i.weight

        trainer.state.metrics['val_loss'] = loss

    # Checkpoints
    model_checkpoint = {'model': model}
    handler = ModelCheckpoint('checkpoints', 'latest', require_empty=False)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=args.training.save_every), handler,
        model_checkpoint)

    @trainer.on(Events.EPOCH_COMPLETED(every=args.training.save_every))
    def save_state(engine):
        with open('checkpoints/state.pkl', 'wb') as f:
            pickle.dump(engine.state, f)

    @trainer.on(Events.COMPLETED)
    def save_final_state(engine):
        with open('checkpoints/state.pkl', 'wb') as f:
            pickle.dump(engine.state, f)

    handler = ModelCheckpoint(
        'checkpoints',
        'best',
        require_empty=False,
        score_function=(lambda e: -e.state.metrics['val_loss']))
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=args.training.save_every), handler,
        model_checkpoint)
    trainer.add_event_handler(Events.COMPLETED, handler, model_checkpoint)

    return trainer
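
A condensed sketch of the pattern above in which a validation metric is copied onto trainer.state.metrics so that checkpoint score functions can read it. The 'loss' metric name, the checkpoint directory and the helper name are placeholders.

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint


def attach_val_checkpoint(trainer: Engine, validator: Engine, val_loader, model):
    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine: Engine):
        validator.run(val_loader)
        # stash the validation loss on the trainer; registered before the checkpoint
        # handler below, so the score function always sees a fresh value
        engine.state.metrics["val_loss"] = validator.state.metrics["loss"]

    best = ModelCheckpoint(
        "checkpoints", "best", require_empty=False, n_saved=1,
        score_function=lambda e: -e.state.metrics["val_loss"],
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, best, {"model": model})
    return best
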