Example No. 1
def __init__(
    self,
    job_dir,
    num_examples,
    learning_rate,
    batch_size,
    epochs,
    num_workers,
    seed,
):
    super(PyTorchModel, self).__init__(job_dir=job_dir, seed=seed)
    self.num_examples = num_examples
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.epochs = epochs
    self.summary_writer = tensorboard.SummaryWriter(log_dir=self.job_dir)
    self.logger = utils.setup_logger(name=__name__ + "." +
                                     self.__class__.__name__,
                                     distributed_rank=0)
    self.trainer = engine.Engine(self.train_step)
    self.evaluator = engine.Engine(self.tune_step)
    self._network = None
    self._optimizer = None
    self._metrics = None
    self.num_workers = num_workers
    self.device = distributed.device()
    self.best_state = None
    self.counter = 0
def create_vae_trainer(model,
                       optimizer,
                       crt,
                       metrics=None,
                       device=th.device('cpu'),
                       non_blocking=True) -> ie.Engine:
    if device:
        model.to(device)

    def _update(_engine, batch):
        model.train()
        optimizer.zero_grad()
        _in, _cls_gt, _recon_gt = prepare_batch(batch,
                                                device=device,
                                                non_blocking=non_blocking)
        _recon_pred, _cls_pred, _temporal_latents, _class_latent, _mean, _var, _ = model(
            _in, num_samples=1)
        ce, l1, kld = crt(_recon_pred, _cls_pred, _recon_gt, _cls_gt, _mean,
                          _var)

        (ce + l1 + crt.kld_factor * kld).backward()
        optimizer.step()

        return (_recon_pred.detach(), _cls_pred.detach(),
                _temporal_latents.detach(), _class_latent.detach(),
                _mean.detach(), _var.detach(), _in.detach(), _cls_gt.detach(),
                ce.item(), l1.item(), kld.item(), crt.kld_factor)

    _engine = ie.Engine(_update)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)

    return _engine
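These factory functions all follow the same pattern: wrap an update function in an ignite Engine, attach the optional name-to-Metric mapping, and return the engine. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and metric name are placeholders rather than the VAE setup above.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine
from ignite.metrics import RunningAverage

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

def update(_engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)
# the metrics dict used by the factories above is attached the same way
metrics = {"ra_loss": RunningAverage(output_transform=lambda loss: loss)}
for name, metric in metrics.items():
    metric.attach(trainer, name)

data = DataLoader(TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,))),
                  batch_size=8)
trainer.run(data, max_epochs=2)
print(trainer.state.metrics["ra_loss"])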
def create_ae_trainer(model,
                      optimizer,
                      crt,
                      metrics=None,
                      device=th.device('cpu'),
                      non_blocking=True) -> ie.Engine:
    if device:
        model.to(device)

    def _update(_engine, batch):
        model.train()
        optimizer.zero_grad()
        _in, _cls_gt, _recon_gt = prepare_batch(batch,
                                                device=device,
                                                non_blocking=non_blocking)

        _recon_pred, _cls_pred, _temporal_embeds, _class_embed = model(_in)

        ce, l1 = crt(_recon_pred, _cls_pred, _recon_gt, _cls_gt)
        (ce + l1).backward()

        optimizer.step()

        return (_recon_pred.detach(), _cls_pred.detach(),
                _temporal_embeds.detach(), _class_embed.detach(), _in.detach(),
                _cls_gt.detach(), ce.item(), l1.item())

    _engine = ie.Engine(_update)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)

    return _engine
    def create_trainer(self, optimizer: optim.Optimizer,
                       device: torch.device) -> engine.Engine:
        """Create :class:`ignite.engine.Engine` trainer.

        Args:
            optimizer (optim.Optimizer): torch optimizer.
            device (torch.device): selected device.

        Returns:
            engine.Engine: training Engine.
        """
        def _update(engine: engine.Engine, batch: dict):
            batch["src"] = batch["src"].to(device)
            batch["trg"] = batch["trg"].to(device)

            self.train()
            optimizer.zero_grad()
            gen_probs = self.forward(batch["src"], batch["trg"])
            loss = self.criterion(gen_probs.view(-1, self.vocab_size),
                                  batch["trg"][:, 1:].contiguous().view(-1))
            loss.backward()
            optimizer.step()

            return loss.item()

        return engine.Engine(_update)
def create_cls_trainer(model,
                       optimizer,
                       crt,
                       metrics=None,
                       device=th.device('cpu'),
                       non_blocking=True) -> ie.Engine:
    if device:
        model.to(device)

    def _update(_engine, batch):
        model.train()
        optimizer.zero_grad()
        _in, _cls_gt, _ = prepare_batch(batch,
                                        device=device,
                                        non_blocking=non_blocking)
        y_pred, temporal_embeds, class_embed = model(_in)
        loss = crt(y_pred, _cls_gt)
        loss.backward()
        optimizer.step()
        return loss.item(), y_pred.detach(), _cls_gt.detach()

    _engine = ie.Engine(_update)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)

    return _engine
Example No. 6
def __init__(self, patience, score_name, evaluator_name, mode='max'):
    if mode not in ['min', 'max']:
        raise ValueError(
            f'mode must be min or max. mode value found is {mode}')
    super(EarlyStopping, self).__init__(
        patience,
        score_function=lambda e: e.state.metrics[score_name]
        if mode == 'max' else -e.state.metrics[score_name],
        trainer=engine.Engine(lambda engine, batch: None))
    self.evaluator_name = evaluator_name
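This subclass simply binds ignite's stock EarlyStopping handler to a named metric of a named evaluator. For reference, a minimal sketch of how the stock handler is usually wired up (the engines and the metric name here are illustrative):

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)

def score_function(evaluator_engine):
    # EarlyStopping expects "higher is better"; negate a loss-like metric for 'min' mode
    return evaluator_engine.state.metrics.get("accuracy", 0.0)

handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)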
Example No. 7
    def predict(self, loader: _data.DataLoader) -> Tensor:
        def estimation_update(engine: _engine.Engine, batch) -> dict:
            return {"y_pred": self.model(batch)}

        estimator = _engine.Engine(estimation_update)
        result = []

        @estimator.on(_engine.Events.ITERATION_COMPLETED)
        def save_results(engine: _engine.Engine) -> None:
            output = engine.state.output['y_pred'].detach()
            result.append(output)
            torch.cuda.empty_cache()

        self.model.eval()
        batches = VocalExtractor.get_number_of_batches(loader)
        estimator.run(loader, epoch_length=batches, max_epochs=1)

        result = torch.cat(result, dim=0)
        return result.transpose(0, 1)
    def create_evaluator(self, device: torch.device) -> engine.Engine:
        """Create :class:`ignite.engine.Engine` evaluator

        Args:
            device (torch.device): selected device.

        Returns:
            engine.Engine: evaluator engine.
        """
        def _evaluate(engine: engine.Engine, batch: dict):
            batch["src"] = batch["src"].to(device)
            batch["trg"] = batch["trg"].to(device)

            self.eval()
            generated, __ = self.inference(batch["src"],
                                           batch["trg"].shape[1] - 1)

            return generated, batch["trg"][:, 1:]

        return engine.Engine(_evaluate)
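The evaluator built here returns (predictions, targets) pairs, which is the shape ignite's built-in metrics expect. Below is a self-contained sketch of the usual trainer/evaluator pairing, using ignite's supervised helpers and a toy model instead of the seq2seq model above:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

train_data = DataLoader(TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,))),
                        batch_size=8)
valid_data = DataLoader(TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,))),
                        batch_size=8)

trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={"loss": Loss(criterion)})

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    # the evaluator emits (y_pred, y), which the Loss metric consumes directly
    evaluator.run(valid_data)
    print(engine.state.epoch, evaluator.state.metrics["loss"])

trainer.run(train_data, max_epochs=2)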
def create_cls_evaluator(model,
                         metrics=None,
                         device=th.device('cpu'),
                         non_blocking=True) -> ie.Engine:
    if device:
        model.to(device)

    def _inference(_engine, batch):
        model.eval()
        with th.no_grad():
            _in, _cls_gt, _ = prepare_batch(batch,
                                            device=device,
                                            non_blocking=non_blocking)
            y_pred, temporal_embeds, class_embed = model(_in)
            return y_pred, _cls_gt, temporal_embeds, class_embed

    _engine = ie.Engine(_inference)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)

    return _engine
Example No. 10
def create_mask_rcnn_trainer(model: nn.Module, optimizer: optim.Optimizer, device=None, non_blocking: bool = False):
    if device:
        model.to(device)

    fn_prepare_batch = lambda batch: engine._prepare_batch(batch, device=device, non_blocking=non_blocking)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()

        image, targets = fn_prepare_batch(batch)
        losses = model(image, targets)

        loss = sum(loss for loss in losses.values())

        loss.backward()
        optimizer.step()

        losses = {k: v.item() for k, v in losses.items()}
        losses['loss'] = loss.item()
        return losses

    return engine.Engine(_update)
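The update step above relies on torchvision detection models returning a dict of named loss terms in training mode, which is summed before backward. A self-contained sketch of that behaviour with random placeholder data (not taken from the example):

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn

# torchvision >= 0.13 weights API; older releases use pretrained=False instead
model = maskrcnn_resnet50_fpn(weights=None, weights_backbone=None, num_classes=2)
model.train()

images = [torch.rand(3, 64, 64)]
targets = [{
    "boxes": torch.tensor([[10.0, 10.0, 40.0, 40.0]]),
    "labels": torch.tensor([1]),
    "masks": torch.zeros(1, 64, 64, dtype=torch.uint8),
}]

losses = model(images, targets)        # dict of individual loss terms
loss = sum(loss for loss in losses.values())
loss.backward()
print({k: v.item() for k, v in losses.items()})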
def create_vae_evaluator(model,
                         metrics=None,
                         device=None,
                         num_samples: int = None,
                         non_blocking=True) -> ie.Engine:
    if device:
        model.to(device)

    def _inference(_engine, batch):
        model.eval()
        with th.no_grad():
            _in, _cls_gt, _recon_gt = prepare_batch(batch,
                                                    device=device,
                                                    non_blocking=non_blocking)
            _recon_pred, _cls_pred, _temp_lat, _cls_lat, _mean, _var, _vote = model(
                _in, num_samples=num_samples)
            return _recon_pred, _cls_pred, _temp_lat, _cls_lat, _mean, _var, _in, _cls_gt, _vote

    _engine = ie.Engine(_inference)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)

    return _engine
Example No. 12
def create_mask_rcnn_evaluator(model: nn.Module, metrics, device=None, non_blocking: bool = False):
    if device:
        model.to(device)

    fn_prepare_batch = lambda batch: engine._prepare_batch(batch, device=device, non_blocking=non_blocking)

    def _update(engine, batch):
        # warning(will.brennan) - not putting model in eval mode because we want the losses!
        with torch.no_grad():
            image, targets = fn_prepare_batch(batch)
            losses = model(image, targets)

            losses = {k: v.item() for k, v in losses.items()}
            losses['loss'] = sum(losses.values())

        # note(will.brennan) - an ugly hack for metrics...
        return (losses, len(image))

    evaluator = engine.Engine(_update)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    return evaluator
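The returned (losses, batch_size) tuple is what the attached metrics receive as engine output, presumably so a custom metric can weight by batch size. A small illustrative sketch of reading it through a metric's output_transform (the metric name is arbitrary):

from ignite.metrics import RunningAverage

metrics = {
    "loss": RunningAverage(output_transform=lambda out: out[0]["loss"]),
}
# each metric would then be attached with metric.attach(evaluator, name), as above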
Example No. 13
    def create_trainer(self,
                       profile: Profile,
                       shared: Storage,
                       logger: Logger,
                       model: nn.Module,
                       loss_function: nn.Module,
                       optimizer: optim.Optimizer,
                       lr_scheduler: Any,
                       output_transform=lambda x, y, y_pred, loss: loss.item(),
                       **kwargs) -> engine.Engine:
        """
        Build the trainer engine. Re-implement this function when you
        want to customize the updating actions of training.

        Args:
            profile: Runtime profile defined in TOML file.
            shared: Shared storage in the whole lifecycle.
            logger: The logger named with this Task.
            model: The model to train.
            loss_function: The loss function to train.
            optimizer: The optimizer to train.
            lr_scheduler: The scheduler to control the learning rate.
            output_transform: The action to transform the output of the model.

        Returns:
            The trainer engine.
        """
        if 'device' in profile:
            device_type = profile.device
        else:
            device_type = 'cpu'

        if 'non_blocking' in profile:
            non_blocking = profile.non_blocking
        else:
            non_blocking = False

        if 'deterministic' in profile:
            deterministic = profile.deterministic
        else:
            deterministic = False

        def _update(_engine: engine.Engine, _batch: Tuple[torch.Tensor]):
            model.train()
            optimizer.zero_grad()
            x, y = self.prepare_train_batch(profile,
                                            shared,
                                            logger,
                                            _batch,
                                            device=device_type,
                                            non_blocking=non_blocking)
            y_pred = model(x)
            loss = loss_function(y_pred, y)
            loss.backward()

            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step(loss)

            return output_transform(x, y, y_pred, loss)

        trainer = engine.Engine(
            _update) if not deterministic else engine.DeterministicEngine(
                _update)

        return trainer
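The deterministic flag switches to ignite's DeterministicEngine, a drop-in replacement for Engine aimed at reproducible dataflow; the update function itself is unchanged. A minimal sketch of that choice:

from ignite.engine import Engine, DeterministicEngine

def update(engine, batch):
    # placeholder update; any Engine-compatible update function works here
    return 0.0

deterministic = True
trainer = DeterministicEngine(update) if deterministic else Engine(update)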
Example No. 14
    def __init__(self,
                 model,
                 params,
                 config,
                 eval_data_iter=None,
                 optimizer="adam",
                 grad_clip_norm=5.0,
                 grad_noise_weight=0.01):
        self._logger = logging.getLogger(__name__)

        self.model = model
        self.params = params
        self.config = config
        self.device = config.device
        self.model_dir = config.model_dir
        self.metrics = {}

        self._optimizer_str = optimizer
        self.grad_clip_norm = grad_clip_norm
        self.grad_noise_weight = grad_noise_weight

        self.train_engine = engine.Engine(self._update_fn)
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                            self._print_eta_handler)

        if config.save_summary_steps > 0:
            # Summary buffer and tensorboardX writer.
            self._summary_writer = SummaryWriter(config.model_dir)
            self.summary = SummaryBuffer(self._summary_writer)

            # Try to attach summary writer to the model.
            try:
                model.attach_summary_writer(self.summary)
                get_worth_manager().attach_summary_writer(self.summary)
                self.train_engine.add_event_handler(
                    Events.ITERATION_COMPLETED, self.summary.writing_handler,
                    config.save_summary_steps)
            except Exception as e:
                print(e)
                self._logger.warning(
                    "Can't attach summary writer to this model. Subclass EstimatableModel to access the "
                    "summary writer.")

        else:
            self._summary_writer = None

        # Set to True after writing graph information summary.
        self._graph_written = False

        if config.evaluate_steps != 0 and eval_data_iter is None:
            raise ValueError(
                "eval_data_iter should be provided for config.evaluate_steps != 0"
            )

        self.eval_data_iter = eval_data_iter

        if config.evaluate_steps == EstimatorConfig.AFTER_EACH_EPOCH:
            self.train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                                self._eval_handler)
            self._eval_summary_writer = SummaryWriter(
                os.path.join(config.model_dir, "eval"))
            self.eval_summary = SummaryBuffer(self._eval_summary_writer)
        elif config.evaluate_steps > 0:
            self.train_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                                self._eval_handler,
                                                config.evaluate_steps)
            self._eval_summary_writer = SummaryWriter(
                os.path.join(config.model_dir, "eval"))
            self.eval_summary = SummaryBuffer(self._eval_summary_writer)

        self.add_metric("loss", Loss(model.compute_loss))
        self.add_metric("xentropy_loss", Loss(model.loss_fn))

        self._reload_eval_engine()
        self._built = False
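This constructor passes extra positional arguments (such as config.save_summary_steps) through add_event_handler, which ignite forwards to the handler after the engine; the same periodic behaviour can also be expressed with an event filter, as Example No. 15 below does with every=. A short sketch of both forms (handler names are illustrative):

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)

def log_every(engine, every):
    if engine.state.iteration % every == 0:
        print("iteration", engine.state.iteration)

# extra positional args after the handler are forwarded to it on each event
trainer.add_event_handler(Events.ITERATION_COMPLETED, log_every, 100)

# equivalent periodic wiring using an event filter
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_filtered(engine):
    print("iteration", engine.state.iteration)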
Example No. 15
def train(args, trial, is_train=True, study=None):

    hparams = HPARAMS[args.hparams]
    print(hparams)

    if hparams.model_type in {"bjorn"}:
        Dataset = src.dataset.xTSDatasetSpeakerIdEmbedding

        def prepare_batch(batch, device, non_blocking):
            for i in range(len(batch)):
                batch[i] = batch[i].to(device)
            batch_x, batch_y, _, emb = batch
            return (batch_x, batch_y, emb), batch_y

    else:
        Dataset = src.dataset.xTSDatasetSpeakerId
        prepare_batch = prepare_batch_3

    train_path_loader = PATH_LOADERS[args.dataset](ROOT,
                                                   args.filelist + "-train")
    valid_path_loader = PATH_LOADERS[args.dataset](ROOT,
                                                   args.filelist + "-valid")

    train_dataset = Dataset(hparams,
                            train_path_loader,
                            transforms=TRAIN_TRANSFORMS)
    valid_dataset = Dataset(hparams,
                            valid_path_loader,
                            transforms=VALID_TRANSFORMS)

    kwargs = dict(batch_size=args.batch_size, collate_fn=collate_fn)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               shuffle=False,
                                               **kwargs)

    num_speakers = len(train_path_loader.speaker_to_id)
    dataset_parameters = DATASET_PARAMETERS[args.dataset]
    dataset_parameters["num_speakers"] = num_speakers

    hparams = update_namespace(hparams, trial.parameters)
    model = MODELS[hparams.model_type](dataset_parameters, hparams)
    model_speaker = MODELS_SPEAKER[hparams.model_speaker_type](
        hparams.encoder_embedding_dim, num_speakers)

    if model.speaker_info is SpeakerInfo.EMBEDDING and hparams.embedding_normalize:
        model.embedding_stats = train_dataset.embedding_stats

    if hparams.drop_frame_rate:
        path_mel_mean = os.path.join("output", "mel-mean",
                                     f"{args.dataset}-{args.filelist}.npz")
        mel_mean = cache(compute_mel_mean,
                         path_mel_mean)(train_dataset)["mel_mean"]
        mel_mean = torch.tensor(mel_mean).float().to(DEVICE)
        mel_mean = mel_mean.unsqueeze(0).unsqueeze(0)
        model.decoder.mel_mean = mel_mean

    model_name = f"{args.dataset}_{args.filelist}_{args.hparams}_dispel"
    model_path = f"output/models/{model_name}.pth"

    # Initialize model from existing one.
    if args.model_path is not None:
        model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))

    if hasattr(hparams, "model_speaker_path"):
        model_speaker.load_state_dict(
            torch.load(hparams.model_speaker_path, map_location=DEVICE))

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=trial.parameters["lr"])  # 0.001
    optimizer_speaker = torch.optim.Adam(model_speaker.parameters(), lr=0.001)

    mse_loss = nn.MSELoss()

    def loss_reconstruction(pred, true):
        pred1, pred2 = pred
        return mse_loss(pred1, true) + mse_loss(pred2, true)

    if hasattr(hparams, "loss_speaker_weight"):
        λ = hparams.loss_speaker_weight
    else:
        λ = 0.0002

    model.to(DEVICE)
    model_speaker.to(DEVICE)

    def step(engine, batch):
        model.train()
        model_speaker.train()

        x, y = prepare_batch(batch, device=DEVICE, non_blocking=True)
        i = batch[2].to(DEVICE)

        # Generator: generates audio and dispels speaker identity
        y_pred, z = model.forward_emb(x)
        i_pred = model_speaker.forward(z)

        entropy_s = (-i_pred.exp() *
                     i_pred).sum(dim=1).mean()  # entropy on speakers
        loss_r = loss_reconstruction(y_pred, y)  # reconstruction
        loss_g = loss_r - λ * entropy_s  # generator

        optimizer.zero_grad()
        loss_g.backward(retain_graph=True)
        optimizer.step()

        # Discriminator: predicts speaker identity
        optimizer_speaker.zero_grad()
        loss_s = F.nll_loss(i_pred, i)
        loss_s.backward()
        optimizer_speaker.step()

        return {
            "loss-generator": loss_g.item(),
            "loss-reconstruction": loss_r.item(),
            "loss-speaker": loss_s.item(),
            "entropy-speaker": entropy_s,
        }

    trainer = engine.Engine(step)

    # trainer = engine.create_supervised_trainer(
    #     model, optimizer, loss, device=device, prepare_batch=prepare_batch
    # )

    evaluator = engine.create_supervised_evaluator(
        model,
        metrics={"loss": ignite.metrics.Loss(loss_reconstruction)},
        device=DEVICE,
        prepare_batch=prepare_batch,
    )

    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        print(
            "Epoch {:3d} | Loss gen.: {:+8.6f} = {:8.6f} - λ * {:8.6f} | Loss disc.: {:8.6f}"
            .format(
                trainer.state.epoch,
                trainer.state.output["loss-generator"],
                trainer.state.output["loss-reconstruction"],
                trainer.state.output["entropy-speaker"],
                trainer.state.output["loss-speaker"],
            ))

    @trainer.on(engine.Events.ITERATION_COMPLETED(every=EVERY_K_ITERS))
    def log_validation_loss(trainer):
        evaluator.run(valid_loader)
        metrics = evaluator.state.metrics
        print("Epoch {:3d} Valid loss: {:8.6f} ←".format(
            trainer.state.epoch, metrics["loss"]))

    lr_reduce = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               verbose=args.verbose,
                                               **LR_REDUCE_PARAMS)

    @evaluator.on(engine.Events.COMPLETED)
    def update_lr_reduce(engine):
        loss = engine.state.metrics["loss"]
        lr_reduce.step(loss)

    @evaluator.on(engine.Events.COMPLETED)
    def terminate_study(engine):
        """Stops underperforming trials."""
        if study and study.should_trial_stop(trial=trial):
            trainer.terminate()

    def score_function(engine):
        return -engine.state.metrics["loss"]

    early_stopping_handler = ignite.handlers.EarlyStopping(
        patience=PATIENCE, score_function=score_function, trainer=trainer)
    evaluator.add_event_handler(engine.Events.COMPLETED,
                                early_stopping_handler)

    if is_train:

        def global_step_transform(*args):
            return trainer.state.iteration // EVERY_K_ITERS

        checkpoint_handler = ignite.handlers.ModelCheckpoint(
            "output/models/checkpoints",
            model_name,
            score_name="objective",
            score_function=score_function,
            n_saved=5,
            require_empty=False,
            create_dir=True,
            global_step_transform=global_step_transform,
        )
        evaluator.add_event_handler(engine.Events.COMPLETED,
                                    checkpoint_handler, {"model": model})

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if is_train:
        torch.save(model.state_dict(), model_path)
        print("Last model @", model_path)

        model_best_path = link_best_model(model_name)
        print("Best model @", model_best_path)

    return evaluator.state.metrics["loss"]
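The heart of the step function above is a two-optimizer adversarial update: the generator learns to reconstruct while keeping the speaker classifier maximally uncertain (the entropy term weighted by λ), and the speaker classifier learns to identify the speaker from the same latent. The stripped-down sketch below reproduces that structure with toy modules; unlike the example, it recomputes the classifier output on a detached latent for the discriminator step instead of reusing i_pred with retain_graph=True.

import torch
import torch.nn as nn
import torch.nn.functional as F

generator = nn.Linear(8, 8)                               # stands in for model.forward_emb
speaker_head = nn.Sequential(nn.Linear(8, 4), nn.LogSoftmax(dim=1))
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_s = torch.optim.Adam(speaker_head.parameters(), lr=1e-3)
lam = 2e-4                                                # plays the role of loss_speaker_weight

x = torch.randn(16, 8)                                    # toy batch
target = torch.randn(16, 8)
speaker_id = torch.randint(0, 4, (16,))

# generator step: reconstruct, and dispel speaker identity via the entropy bonus
z = generator(x)
log_p = speaker_head(z)
entropy = (-log_p.exp() * log_p).sum(dim=1).mean()
loss_g = F.mse_loss(z, target) - lam * entropy
opt_g.zero_grad()
loss_g.backward()
opt_g.step()

# discriminator step: predict speaker identity from the (detached) latent
log_p_d = speaker_head(z.detach())
loss_s = F.nll_loss(log_p_d, speaker_id)
opt_s.zero_grad()
loss_s.backward()
opt_s.step()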
Example No. 16
def __init__(self, dataset, device, max_epochs=1):
    super(Engine, self).__init__()
    self.dataset = dataset
    self.device = device
    self.max_epochs = max_epochs
    self.engine = e.Engine(self._update)
Example No. 17
            return_dict = {
                'input_filename': batch['input_filename'],
                'mask': masks
            }

            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                eval_loader.dataset.update_attmaps(output_logsoftmax_np,
                                                   batch['idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']

            return return_dict


evaluator = engine.Engine(eval_step)

eval_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
#valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
eval_pbar.attach(evaluator)


# evaluate after iter finish
@evaluator.on(engine.Events.ITERATION_COMPLETED)
def evaluator_epoch_comp_callback(engine):
    # save masks for each batch
    batch_output = engine.state.output
    input_filenames = batch_output['input_filename']
    masks = batch_output['mask']

    for i, input_filename in enumerate(input_filenames):
Example No. 18
def evaluator_epoch_comp_callback(engine):
    # save masks for each batch
    batch_output = engine.state.output
    input_filenames = batch_output['input_filename']
    masks = batch_output['mask']

    for i, input_filename in enumerate(input_filenames):
        mask = cv2.resize(masks[i],
                          dsize=(utils.cropped_width, utils.cropped_height),
                          interpolation=cv2.INTER_AREA)

        # if pad:
        #     h_start, w_start = utils.h_start, utils.w_start
        #     h, w = mask.shape
        #     # recover to original shape
        #     full_mask = np.zeros((original_height, original_width))
        #     full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask
        #     mask = full_mask
        #print("Input Filename-->", input_filename)
        #instrument_folder_name = input_filename.parent.parent.name
        instrument_folder_name = os.path.basename(
            os.path.dirname(os.path.dirname(input_filename)))
        #print("instrument_folder_name-->", instrument_folder_name)

        # mask_folder/instrument_dataset_x/problem_type_masks/framexxx.png
        mask_folder = mask_save_dir / instrument_folder_name / utils.mask_folder[
            args.problem_type]
        mask_folder.mkdir(exist_ok=True, parents=True)
        mask_filename = mask_folder / os.path.basename(input_filename)
        #print("mask_filename-->", mask_filename)
        cv2.imwrite(str(mask_filename), mask)

        if 'TAPNet' in args.model:
            attmap = batch_output['attmap'][i]

            attmap_folder = mask_save_dir / instrument_folder_name / '_'.join(
                [args.problem_type, 'attmaps'])
            attmap_folder.mkdir(exist_ok=True, parents=True)
            attmap_filename = attmap_folder / os.path.basename(input_filename)

            cv2.imwrite(str(attmap_filename), attmap)

    evaluator.run(eval_loader)

    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func,
                               output_transform=lambda x:
                               (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')

    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(
        output_transform=lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(
        output_transform=lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)

    # ## monitor ignite IoU (the same as iou we are using) ###
    # cm = imetrics.ConfusionMatrix(num_classes,
    #     output_transform=lambda x: (x['output'], x['target']))
    # imetrics.IoU(cm,
    #     ignore_index=0
    #     ).attach(validator, 'iou')

    # # monitor ignite mean iou (over all classes even not exist in gt)
    # mean_iou = imetrics.mIoU(cm,
    #     ignore_index=0
    #     ).attach(validator, 'mean_iou')

    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        pass

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou':
            MetricRecord(),
            'dice':
            MetricRecord(),
            'confusion_matrix':
            np.zeros((num_classes, num_classes), dtype=np.uint32),
        }

    # evaluate after iter finish
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        pass

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):

        # log ignite metrics
        # logging_logger.info(engine.state.metrics)
        # ious = engine.state.metrics['iou']
        # msg = 'IoU: '
        # for ins_id, iou in enumerate(ious):
        #     msg += '{:d}: {:.3f}, '.format(ins_id + 1, iou)
        # logging_logger.info(msg)
        # logging_logger.info('nonzero mean IoU for all data: {:.3f}'.format(ious[ious > 0].mean()))

        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics

        ######### NOTICE: Two metrics are available but different ##########
        ### 1. mean metrics for all data calculated by confusion matrix ####
        '''
        compared with using confusion_matrix[1:, 1:] in original code,
        we use the full confusion matrix and only present non-background result
        '''
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))

        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' %
                            (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' %
                            (mean_dices, std_dices, dices))

        ### 2. mean metrics for all data calculated by definition ###
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()

        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
                            (len(iou_data_mean['items']),
                             iou_data_mean['mean'], iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
                            (len(dice_data_mean['items']),
                             dice_data_mean['mean'], dice_data_mean['std']))

        # record metrics in trainer every epoch
        # trainer.state.metrics_records[trainer.state.epoch] = \
        #     {'miou': mean_ious, 'std_miou': std_ious,
        #     'mdice': mean_dices, 'std_mdice': std_dices}

        trainer.state.metrics_records[trainer.state.epoch] = \
            {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'],
            'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']}

    # log internal variables (attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        res_grid = tvutils.make_grid(
            torch.cat([
                output['output_argmax'].unsqueeze(1),
                output['target'].unsqueeze(1),
            ]),
            padding=2,
            normalize=False,  # show origin image
            nrow=batch_size).cpu()

        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag),
                                img_tensor=res_grid)

        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]),
                                                padding=2,
                                                normalize=True,
                                                nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag),
                                    img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)

    if args.tb_log:
        # log internal values
        tb_logger.attach(validator,
                         log_handler=tb_log_valid_iter_vars,
                         event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator,
                         log_handler=tb_log_valid_epoch_vars,
                         event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #     event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator,
                         log_handler=OutputHandler('valid_epoch',
                                                   ['valid_loss']),
                         event_name=engine.Events.EPOCH_COMPLETED)

    # score function for model saving
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))
    # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean']

    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix,
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
                                handler=model_ckpt_handler,
                                to_save={
                                    'model': model,
                                })

    # early stop
    # trainer=trainer, but should be handled by validator
    early_stopping = handlers.EarlyStopping(patience=args.es_patience,
                                            score_function=ckpt_score_function,
                                            trainer=trainer)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
                                handler=early_stopping)

    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
def process_fold(fold, args):
    num_classes = utils.problem_class[args.problem_type]
    factor = utils.problem_factor[args.problem_type]
    # inputs are RGB images (3 * h * w)
    # outputs are 2d multilabel segmentation maps (h * w)
    model = eval(args.model)(in_channels=3, num_classes=num_classes)
    # data parallel for multi-GPU
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    ckpt_dir = Path(args.ckpt_dir)
    #p = pathlib.Path(ckpt_dir)
    # ckpt for this fold fold_<fold>_model_<epoch>.pth
    print("ckpt_dir--> ", ckpt_dir)
    filenames = glob.glob(args.ckpt_dir + 'fold_%d_model_[0-99]*.pth' % fold)
    #filenames = glob.glob(args.ckpt_dir+'fold_%d_model_[0-99]*.pth')
    #filenames = ckpt_dir.glob(args.ckpt_dir+'fold_%d_model_[0-9]*.pth'%fold)

    print("Filename--> ", filenames)
    # if len(filenames) != 1:
    #    raise ValueError('invalid model ckpt name. correct ckpt name should be \
    #        fold_<fold>_model_<epoch>.pth')

    ckpt_filename = filenames[0]
    # load state dict
    model.load_state_dict(torch.load(str(ckpt_filename)))
    logging.info('Restored model [{}] fold {}.'.format(args.model, fold))

    # segmentation mask save directory
    mask_save_dir = Path(args.mask_save_dir) / ckpt_dir.name
    mask_save_dir.mkdir(exist_ok=True, parents=True)
    #print("mask_save_dir", mask_save_dir)

    eval_transform = Compose(
        [
            Normalize(p=1),
            PadIfNeeded(
                min_height=args.input_height, min_width=args.input_width, p=1),

            # optional
            Resize(height=args.input_height, width=args.input_width, p=1),
            # CenterCrop(height=args.input_height, width=args.input_width, p=1)
        ],
        p=1)

    # train/valid filenames,
    # we evaluate and generate masks on validation set
    train_filenames, valid_filenames = utils.trainval_split(
        args.train_dir, fold)

    eval_num_workers = args.num_workers
    eval_batch_size = args.batch_size
    # additional ds args
    if 'TAPNet' in args.model:
        # in eval, num_workers should be set to 0 for sequences
        eval_num_workers = 0
        # in eval, batch_size should be set to 1 for sequences
        eval_batch_size = 1

    # additional eval dataset kws
    eval_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': eval_transform,
        'model': args.model,
        'mode': 'eval',
    }

    # valid dataloader
    eval_loader = DataLoader(
        dataset=RobotSegDataset(**eval_ds_kwargs),
        shuffle=False,  # in eval, no need to shuffle
        num_workers=eval_num_workers,
        batch_size=eval_batch_size,  # at eval time, sequences are processed one image at a time
        pin_memory=True)

    # process function for ignite engine
    def eval_step(engine, batch):
        with torch.no_grad():
            model.eval()
            #print("batch Keys-->", batch.keys())
            inputs = batch['input'].cuda(non_blocking=True)
            #targets = batch['target'].cuda(non_blocking=True)

            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

            outputs = model(inputs, **add_params)
            output_logsoftmax_np = torch.softmax(outputs, dim=1).cpu().numpy()
            # output_classes and target_classes: <b, h, w>
            output_classes = output_logsoftmax_np.argmax(axis=1)
            masks = (output_classes * factor).astype(np.uint8)
            #print(size(masks))

            return_dict = {
                'input_filename': batch['input_filename'],
                'mask': masks
            }

            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                eval_loader.dataset.update_attmaps(output_logsoftmax_np,
                                                   batch['idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']

            return return_dict

    # eval engine
    evaluator = engine.Engine(eval_step)

    eval_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    #valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    eval_pbar.attach(evaluator)

    # evaluate after iter finish

    @evaluator.on(engine.Events.ITERATION_COMPLETED)
    def evaluator_epoch_comp_callback(engine):
        global Average_batch_IoU
        # save masks for each batch
        batch_output = engine.state.output
        input_filenames = batch_output['input_filename']
        #print("Input_filenames--> ", input_filenames)
        masks = batch_output['mask']
        iou = []
        #Average_batch_IoU = []
        for i, input_filename in enumerate(input_filenames):
            mask = cv2.resize(masks[i],
                              dsize=(utils.cropped_width,
                                     utils.cropped_height),
                              interpolation=cv2.INTER_AREA)

            # if pad:
            #     h_start, w_start = utils.h_start, utils.w_start
            #     h, w = mask.shape
            #     # recover to original shape
            #     full_mask = np.zeros((original_height, original_width))
            #     full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask
            #     mask = full_mask
            #print("Input Filename-->", input_filename)
            #img = cv2.imread(input_filename)
            #instrument_folder_name = input_filename.parent.parent.name
            instrument_folder_name = os.path.basename(
                os.path.dirname(os.path.dirname(input_filename)))
            #print("instrument_folder_name-->", instrument_folder_name)
            binary_mask = Path(args.type_mask)
            gt_folder = Path(os.path.dirname(
                os.path.dirname(input_filename))) / binary_mask
            #print("gt_folder-->", gt_folder)
            gt_filename = gt_folder / os.path.basename(input_filename)
            #print("gt_filename-->", gt_filename)
            # mask_folder/instrument_dataset_x/problem_type_masks/framexxx.png
            mask_folder = mask_save_dir / instrument_folder_name / utils.mask_folder[
                args.problem_type]
            mask_folder.mkdir(exist_ok=True, parents=True)
            mask_filename = mask_folder / os.path.basename(input_filename)

            gt_mask = cv2.imread(str(gt_filename), cv2.CV_8UC1)
            #print("mask_filename-->", mask_filename)
            cv2.imwrite(str(mask_filename), mask)

            assert (mask.shape == gt_mask.shape)
            image_iou = get_iou(mask, gt_mask)
            if not math.isnan(image_iou):
                iou.append(image_iou)
                #print("IoU for image {} = {}".format(input_filename, iou[-1]))

            if 'TAPNet' in args.model:
                attmap = batch_output['attmap'][i]

                attmap_folder = mask_save_dir / instrument_folder_name / '_'.join(
                    [args.problem_type, 'attmaps'])
                attmap_folder.mkdir(exist_ok=True, parents=True)
                attmap_filename = attmap_folder / os.path.basename(
                    input_filename)

                cv2.imwrite(str(attmap_filename), attmap)
            #Average_batch_IoU.append(np.mean(iou))
        #Average_batch_IoU = list(np.mean(iou))
        Average_batch_IoU.append(np.nanmean(iou))
        #

    evaluator.run(eval_loader)
    print("Average_batch_IoU-->", np.nanmean(Average_batch_IoU))
    f.write(str(np.nanmean(Average_batch_IoU)))
    f.write('\n')
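The callback above accumulates a per-image IoU via get_iou and averages with np.nanmean. get_iou itself is not shown in the snippet; for binary masks it is assumed to behave like the sketch below, returning NaN when the union is empty (which is why nanmean is used).

import numpy as np

def binary_iou(pred_mask, gt_mask):
    # treat any non-zero pixel as foreground
    pred = pred_mask > 0
    gt = gt_mask > 0
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return float("nan")
    return float(np.logical_and(pred, gt).sum()) / float(union)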
Example No. 20
def train_fold(fold, args):
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight, args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold)

    # DataLoader and Dataset args
    train_shuffle = True
    train_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': train_transform,
        'model': args.model,
        'mode': 'train',
        'semi': args.semi,
    }

    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size
    if 'TAPNet' in args.model:
        # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead
        train_shuffle = False
        train_ds_kwargs['batch_size'] = args.batch_size
        train_ds_kwargs['mf'] = args.mf
    if args.semi:
        train_ds_kwargs['semi_method'] = args.semi_method
        train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(**train_ds_kwargs),
        shuffle=train_shuffle, # set to False to disable pytorch dataset shuffle
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False, # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size, # at validation time, sequences are processed one image at a time
        pin_memory=True
    )

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, 
    #     weight_decay=args.weight_decay, nesterov=True)    

    # ignite trainer process function
    def train_step(engine, batch):
        # set model to train
        model.train()
        # clear gradients
        optimizer.zero_grad()
        
        # additional params to feed into model
        add_params = {}
        inputs = batch['input'].cuda(non_blocking=True)
        with torch.no_grad():
            targets = batch['target'].cuda(non_blocking=True)
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

        outputs = model(inputs, **add_params)

        loss_kwargs = {}

        if args.semi:
            loss_kwargs['labeled'] = batch['labeled']
            if args.semi_method == 'rev_flow':
                loss_kwargs['optflow'] = batch['optflow']
            loss = loss_func_semi(outputs, targets, **loss_kwargs)
        else:
            loss = loss_func(outputs, targets, **loss_kwargs)
        loss.backward()
        optimizer.step()

        return_dict = {
            'output': outputs,
            'target': targets,
            'loss_kwargs': loss_kwargs,
            'loss': loss.item(),
        }

        # for TAPNet, update attention maps after each iteration
        if 'TAPNet' in args.model:
            # output_classes and target_classes: <b, h, w>
            output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            # update attention maps
            train_loader.dataset.update_attmaps(output_softmax_np, batch['abs_idx'].numpy())
            return_dict['attmap'] = add_params['attmap']

        return return_dict
    
    # init trainer
    trainer = engine.Engine(train_step)

    # lr scheduler and handler
    # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr)
    # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler)
    # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler)

    step_scheduler = optim.lr_scheduler.StepLR(optimizer,
        step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)


    @trainer.on(engine.Events.STARTED)
    def trainer_start_callback(engine):
        logging_logger.info('training fold {}, {} train / {} valid files'. \
            format(fold, len(train_filenames), len(valid_filenames)))

        # resume training
        if args.resume:
            # ckpt for current fold fold_<fold>_model_<epoch>.pth
            ckpt_dir = Path(args.ckpt_dir)
            ckpt_filename = sorted(ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold))[0]
            res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename.name)
            # restore epoch
            engine.state.epoch = int(res.groups()[0])
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))

        # record metrics history every epoch
        engine.state.metrics_records = {}


    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        # log learning rate on pbar
        train_pbar.log_message('model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' % \
            (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size))
        
        # for TAPNet, change dataset schedule to random after the first epoch
        if 'TAPNet' in args.model and engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")


    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def trainer_iter_comp_callback(engine):
        # logging_logger.info(engine.state.metrics)
        pass

    # monitor loss
    # running average loss
    train_ra_loss = imetrics.RunningAverage(output_transform=
        lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(loss_func_semi, output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'), 
            event_name=engine.Events.EPOCH_STARTED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
            event_name=engine.Events.ITERATION_COMPLETED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)

        tb_logger.attach(trainer,
             log_handler=WeightsScalarHandler(model, reduction=torch.norm),
             event_name=engine.Events.ITERATION_COMPLETED)

        # tb_logger.attach(trainer, log_handler=tb_log_train_vars, 
        #     event_name=engine.Events.ITERATION_COMPLETED)


    # ignite validator process function
    def valid_step(engine, batch):
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)

            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

            # output logits
            outputs = model(inputs, **add_params)
            # loss
            loss = loss_func(outputs, targets)

            output_softmaxs = torch.softmax(outputs, dim=1)
            output_argmaxs = output_softmaxs.argmax(dim=1)
            # output_classes and target_classes: <b, h, w>
            output_classes = output_argmaxs.cpu().numpy()
            target_classes = targets.cpu().numpy()

            # record current batch metrics
            iou_mRecords = MetricRecord()
            dice_mRecords = MetricRecord()

            cm_b = np.zeros((num_classes, num_classes), dtype=np.uint32)

            for output_class, target_class in zip(output_classes, target_classes):
                # calculate metrics for each frame
                # calculate using confusion matrix or directly using definition
                cm = calculate_confusion_matrix_from_arrays(output_class, target_class, num_classes)
                iou_mRecords.update_record(calculate_iou(cm))
                dice_mRecords.update_record(calculate_dice(cm))
                cm_b += cm

                ######## calculate directly using definition ##########
                # iou_mRecords.update_record(iou_multi_np(target_class, output_class))
                # dice_mRecords.update_record(dice_multi_np(target_class, output_class))

            # accumulate batch metrics to engine state
            engine.state.epoch_metrics['confusion_matrix'] += cm_b
            engine.state.epoch_metrics['iou'].merge(iou_mRecords)
            engine.state.epoch_metrics['dice'].merge(dice_mRecords)


            return_dict = {
                'loss': loss.item(),
                'output': outputs,
                'output_argmax': output_argmaxs,
                'target': targets,
                # for monitoring
                'iou': iou_mRecords,
                'dice': dice_mRecords,
            }

            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                valid_loader.dataset.update_attmaps(output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return return_dict


    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(output_transform=
        lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')
    
    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(output_transform=
        lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(output_transform=
        lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)


    # ## monitor ignite IoU (the same as iou we are using) ###
    # cm = imetrics.ConfusionMatrix(num_classes, 
    #     output_transform=lambda x: (x['output'], x['target']))
    # imetrics.IoU(cm, 
    #     ignore_index=0
    #     ).attach(validator, 'iou')

    # # monitor ignite mean iou (over all classes even not exist in gt)
    # mean_iou = imetrics.mIoU(cm, 
    #     ignore_index=0
    #     ).attach(validator, 'mean_iou')


    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        pass

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou': MetricRecord(),
            'dice': MetricRecord(),
            'confusion_matrix': np.zeros((num_classes, num_classes), dtype=np.uint32),
        }


    # evaluate after iter finish
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        pass

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):

        # log ignite metrics
        # logging_logger.info(engine.state.metrics)
        # ious = engine.state.metrics['iou']
        # msg = 'IoU: '
        # for ins_id, iou in enumerate(ious):
        #     msg += '{:d}: {:.3f}, '.format(ins_id + 1, iou)
        # logging_logger.info(msg)
        # logging_logger.info('nonzero mean IoU for all data: {:.3f}'.format(ious[ious > 0].mean()))

        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics

        ######### NOTICE: Two metrics are available but different ##########
        ### 1. mean metrics for all data calculated by confusion matrix ####

        '''
        Compared with the original code, which used confusion_matrix[1:, 1:],
        we keep the full confusion matrix and only present the non-background results.
        '''
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))

        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' % 
            (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' % 
            (mean_dices, std_dices, dices))


        ### 2. mean metrics for all data calculated by definition ###
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()

        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
            (len(iou_data_mean['items']), iou_data_mean['mean'], iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
            (len(dice_data_mean['items']), dice_data_mean['mean'], dice_data_mean['std']))

        # record metrics in trainer every epoch
        # trainer.state.metrics_records[trainer.state.epoch] = \
        #     {'miou': mean_ious, 'std_miou': std_ious,
        #     'mdice': mean_dices, 'std_mdice': std_dices}
        
        trainer.state.metrics_records[trainer.state.epoch] = \
            {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'],
            'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']}


    # log internal variables (attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        res_grid = tvutils.make_grid(torch.cat([
            output['output_argmax'].unsqueeze(1),
            output['target'].unsqueeze(1),
        ]), padding=2, 
        normalize=False,  # show original image
        nrow=batch_size).cpu()

        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag), img_tensor=res_grid)

        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]), padding=2, normalize=True, nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag), img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        log_tag = 'valid_epoch'
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']# [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)



    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars, 
            event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
            event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #     event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)


    # score function for model saving
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))
    # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean']
    
    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix, 
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)


    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, 
        handler=model_ckpt_handler,
        to_save={
            'model': model,
        })

    # early stopping: EarlyStopping terminates the trainer, but the handler
    # itself is attached to the validator (scores come from validation epochs)
    early_stopping = handlers.EarlyStopping(patience=args.es_patience, 
        score_function=ckpt_score_function,
        trainer=trainer
        )

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
        handler=early_stopping)


    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
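
A note on the two metric computations reported in validator_epoch_comp_callback above: the confusion-matrix variant reduces to the standard per-class formulas. The helpers below are only a minimal sketch of what calculate_iou / calculate_dice might compute (their real implementations are not part of this snippet), assuming rows of the matrix index ground-truth classes and columns index predictions.

import numpy as np

def calculate_iou_sketch(cm: np.ndarray) -> dict:
    # Per-class IoU = TP / (TP + FP + FN); assumes cm[gt, pred] holds counts.
    ious = {}
    for c in range(cm.shape[0]):
        tp = cm[c, c]
        fp = cm[:, c].sum() - tp
        fn = cm[c, :].sum() - tp
        denom = tp + fp + fn
        ious[c] = float(tp) / denom if denom > 0 else 0.0
    return ious

def calculate_dice_sketch(cm: np.ndarray) -> dict:
    # Per-class Dice = 2*TP / (2*TP + FP + FN).
    dices = {}
    for c in range(cm.shape[0]):
        tp = cm[c, c]
        fp = cm[:, c].sum() - tp
        fn = cm[c, :].sum() - tp
        denom = 2 * tp + fp + fn
        dices[c] = 2.0 * float(tp) / denom if denom > 0 else 0.0
    return dices
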
Exemplo n.º 21
0
def create_supervised_trainer(model,
                              optimizer,
                              loss_fn,
                              device=None,
                              non_blocking=False,
                              prepare_batch=engine._prepare_batch,
                              check_nan=False,
                              grad_clip=None,
                              output_predictions=False):
    """As ignite.engine.create_supervised_trainer, but may also optionall perform:
    - NaN checking on predictions (in a more debuggable way than ignite.handlers.TerminateOnNaN)
    - Gradient clipping
    - Record the predictions made by a model

    Arguments:
        (as ignite.engine.create_supervised_trainer, plus)
        check_nan: Optional boolean specifying whether the engine should check predictions for NaN values. Defaults to
            False. If True and a NaN value is encountered, a RuntimeError will be raised with attributes 'x', 'y',
            'y_pred' and 'model', giving the feature, label, prediction and model, respectively, on which this occurred.
        grad_clip: Optional number, boolean or None, specifying the value to clip the infinity-norm of the gradient to.
            Defaults to None. If False or None then no gradient clipping will be applied. If True then the gradient is
            clipped to 1.0.
        output_predictions: Optional boolean specifying whether the engine should record the predictions the model made
            on a batch. Defaults to False. If True then state.output will be a tuple of (loss, predictions). If False
            then state.output will just be the loss. (Not wrapped in a tuple.)
    """

    if device:
        model.to(device)

    if grad_clip is False:
        grad_clip = None
    elif grad_clip is True:
        grad_clip = 1.0

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)

        if check_nan and torch.isnan(y_pred).any():
            e = RuntimeError('Model generated NaN value.')
            e.y = y
            e.y_pred = y_pred
            e.x = x
            e.model = model
            raise e

        loss = loss_fn(y_pred, y)
        loss.backward()

        if grad_clip is not None:
            nnutils.clip_grad_norm_(model.parameters(),
                                    grad_clip,
                                    norm_type='inf')

        optimizer.step()

        if output_predictions:
            return loss.item(), y_pred
        else:
            return loss.item()

    return engine.Engine(_update)
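
A hypothetical usage of the extended trainer above; the model, optimizer, loss function and data loader names are placeholders for illustration, not values from the original snippet.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

trainer = create_supervised_trainer(
    model, optimizer, loss_fn,
    check_nan=True,           # raise a RuntimeError carrying x, y, y_pred and model on NaN predictions
    grad_clip=True,           # shorthand for grad_clip=1.0 (infinity-norm clipping)
    output_predictions=True)  # state.output becomes (loss, y_pred) instead of just the loss

# trainer.run(train_loader, max_epochs=10)  # train_loader assumed to exist
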
Exemplo n.º 22
0
    def _reload_eval_engine(self):
        self.eval_engine = engine.Engine(self._eval_fn)

        if len(self.metrics) > 0:
            for name, metric in self.metrics.items():
                metric.attach(self.eval_engine, name)
Exemplo n.º 23
0
def run(experiment_name: str,
        visdom_host: str,
        visdom_port: int,
        visdom_env_path: str,
        model_class: str,
        model_args: Dict[str, Any],
        optimizer_class: str,
        optimizer_args: Dict[str, Any],
        dataset_class: str,
        dataset_args: Dict[str, Any],
        batch_train: int,
        batch_test: int,
        workers_train: int,
        workers_test: int,
        transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
        epochs: int,
        log_interval: int,
        saved_models_path: str,
        performance_metrics: Optional[Dict[str, Any]] = None,
        scheduler_class: Optional[str] = None,
        scheduler_args: Optional[Dict[str, Any]] = None,
        model_suffix: Optional[str] = None,
        setup_suffix: Optional[str] = None,
        orig_stdout: Optional[io.TextIOBase] = None):

    with _utils.tqdm_stdout(orig_stdout) as orig_stdout:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        transforms_train = list()
        transforms_test = list()

        for idx, transform in enumerate(transforms):
            use_train = transform.get('train', True)
            use_test = transform.get('test', True)

            transform = _utils.load_class(
                transform['class'])(**transform['args'])

            if use_train:
                transforms_train.append(transform)
            if use_test:
                transforms_test.append(transform)

            transforms[idx]['train'] = use_train
            transforms[idx]['test'] = use_test

        transforms_train = tv.transforms.Compose(transforms_train)
        transforms_test = tv.transforms.Compose(transforms_test)

        Dataset: Type = _utils.load_class(dataset_class)

        train_loader, eval_loader = _utils.get_data_loaders(
            Dataset, dataset_args, batch_train, batch_test, workers_train,
            workers_test, transforms_train, transforms_test)

        Network: Type = _utils.load_class(model_class)
        model: _interfaces.AbstractNet = Network(**model_args)
        model = model.to(device)

        Optimizer: Type = _utils.load_class(optimizer_class)
        optimizer: torch.optim.Optimizer = Optimizer(model.parameters(),
                                                     **optimizer_args)

        if scheduler_class is not None:
            Scheduler: Type = _utils.load_class(scheduler_class)

            if scheduler_args is None:
                scheduler_args = dict()

            scheduler: Optional[
                torch.optim.lr_scheduler._LRScheduler] = Scheduler(
                    optimizer, **scheduler_args)
        else:
            scheduler = None

        model_short_name = ''.join(
            [c for c in Network.__name__ if c == c.upper()])
        model_name = '{}{}'.format(
            model_short_name,
            '-{}'.format(model_suffix) if model_suffix is not None else '')
        visdom_env_name = '{}_{}_{}{}'.format(
            Dataset.__name__, experiment_name, model_name,
            '-{}'.format(setup_suffix) if setup_suffix is not None else '')

        vis, vis_pid = _visdom.get_visdom_instance(visdom_host, visdom_port,
                                                   visdom_env_name,
                                                   visdom_env_path)

        prog_bar_epochs = tqdm.tqdm(total=epochs,
                                    desc='Epochs',
                                    file=orig_stdout,
                                    dynamic_ncols=True,
                                    unit='epoch')
        prog_bar_iters = tqdm.tqdm(desc='Batches',
                                   file=orig_stdout,
                                   dynamic_ncols=True)

        tqdm.tqdm.write(f'\n{repr(model)}\n')
        tqdm.tqdm.write('Total number of parameters: {:.2f}M'.format(
            sum(p.numel() for p in model.parameters()) / 1e6))

        def training_step(_: ieng.Engine,
                          batch: _interfaces.TensorPair) -> torch.Tensor:
            model.train()

            optimizer.zero_grad()

            x, y = batch

            x = x.to(device)
            y = y.to(device)

            _, loss = model(x, y)

            loss.backward(retain_graph=False)
            optimizer.step()

            return loss.item()

        def eval_step(_: ieng.Engine,
                      batch: _interfaces.TensorPair) -> _interfaces.TensorPair:
            model.eval()

            with torch.no_grad():
                x, y = batch

                x = x.to(device)
                y = y.to(device)

                y_pred = model(x)

            return y_pred, y

        trainer = ieng.Engine(training_step)
        validator_train = ieng.Engine(eval_step)
        validator_eval = ieng.Engine(eval_step)

        # placeholder for summary window
        vis.text(text='',
                 win=experiment_name,
                 env=visdom_env_name,
                 opts={
                     'title': 'Summary',
                     'width': 940,
                     'height': 416
                 },
                 append=vis.win_exists(experiment_name, visdom_env_name))

        default_metrics = {
            "Loss": {
                "window_name": None,
                "x_label": "#Epochs",
                "y_label": model.loss_fn_name,
                "width": 940,
                "height": 416,
                "lines": [{
                    "line_label": "SMA",
                    "object": imet.RunningAverage(output_transform=lambda x: x),
                    "test": False,
                    "update_rate": "iteration"
                }, {
                    "line_label": "Val.",
                    "object": imet.Loss(model.loss_fn)
                }]
            }
        }

        performance_metrics = {**default_metrics, **(performance_metrics or {})}
        checkpoint_metrics = list()

        for scope_name, scope in performance_metrics.items():
            scope['window_name'] = scope.get('window_name',
                                             scope_name) or scope_name

            for line in scope['lines']:
                if 'object' not in line:
                    line['object']: imet.Metric = _utils.load_class(
                        line['class'])(**line['args'])

                line['metric_label'] = '{}: {}'.format(scope['window_name'],
                                                       line['line_label'])

                line['update_rate'] = line.get('update_rate', 'epoch')
                line_suffixes = list()
                if line['update_rate'] == 'iteration':
                    line['object'].attach(trainer, line['metric_label'])
                    line['train'] = False
                    line['test'] = False

                    line_suffixes.append(' Train.')

                if line.get('train', True):
                    line['object'].attach(validator_train,
                                          line['metric_label'])
                    line_suffixes.append(' Train.')
                if line.get('test', True):
                    line['object'].attach(validator_eval, line['metric_label'])
                    line_suffixes.append(' Eval.')

                    if line.get('is_checkpoint', False):
                        checkpoint_metrics.append(line['metric_label'])

                for line_suffix in line_suffixes:
                    _visdom.plot_line(
                        vis=vis,
                        window_name=scope['window_name'],
                        env=visdom_env_name,
                        line_label=line['line_label'] + line_suffix,
                        x_label=scope['x_label'],
                        y_label=scope['y_label'],
                        width=scope['width'],
                        height=scope['height'],
                        draw_marker=(line['update_rate'] == 'epoch'))

        if checkpoint_metrics:
            score_name = 'performance'

            def get_score(engine: ieng.Engine) -> float:
                current_mode = getattr(
                    engine.state.dataloader.iterable.dataset,
                    dataset_args['training']['key'])
                val_mode = dataset_args['training']['no']

                score = 0.0
                if current_mode == val_mode:
                    for metric_name in checkpoint_metrics:
                        try:
                            score += engine.state.metrics[metric_name]
                        except KeyError:
                            pass

                return score

            model_saver = ihan.ModelCheckpoint(os.path.join(
                saved_models_path, visdom_env_name),
                                               filename_prefix=visdom_env_name,
                                               score_name=score_name,
                                               score_function=get_score,
                                               n_saved=3,
                                               save_as_state_dict=True,
                                               require_empty=False,
                                               create_dir=True)

            validator_eval.add_event_handler(ieng.Events.EPOCH_COMPLETED,
                                             model_saver, {model_name: model})

        @trainer.on(ieng.Events.EPOCH_STARTED)
        def reset_progress_iterations(engine: ieng.Engine):
            prog_bar_iters.clear()
            prog_bar_iters.n = 0
            prog_bar_iters.last_print_n = 0
            prog_bar_iters.start_t = time.time()
            prog_bar_iters.last_print_t = time.time()
            prog_bar_iters.total = len(engine.state.dataloader)

        @trainer.on(ieng.Events.ITERATION_COMPLETED)
        def log_training(engine: ieng.Engine):
            prog_bar_iters.update(1)

            num_iter = (engine.state.iteration - 1) % len(train_loader) + 1

            early_stop = np.isnan(engine.state.output) or np.isinf(
                engine.state.output)

            if num_iter % log_interval == 0 or num_iter == len(
                    train_loader) or early_stop:
                tqdm.tqdm.write(
                    'Epoch[{}] Iteration[{}/{}] Loss: {:.4f}'.format(
                        engine.state.epoch, num_iter, len(train_loader),
                        engine.state.output))

                x_pos = engine.state.epoch + num_iter / len(train_loader) - 1
                for scope_name, scope in performance_metrics.items():
                    for line in scope['lines']:
                        if line['update_rate'] == 'iteration':
                            line_label = '{} Train.'.format(line['line_label'])
                            line_value = engine.state.metrics[
                                line['metric_label']]

                            if engine.state.epoch > 1:
                                _visdom.plot_line(
                                    vis=vis,
                                    window_name=scope['window_name'],
                                    env=visdom_env_name,
                                    line_label=line_label,
                                    x_label=scope['x_label'],
                                    y_label=scope['y_label'],
                                    x=np.full(1, x_pos),
                                    y=np.full(1, line_value))

            if early_stop:
                tqdm.tqdm.write(
                    colored('Early stopping due to invalid loss value.',
                            'red'))
                trainer.terminate()

        def log_validation(engine: ieng.Engine, train: bool = True):

            if train:
                run_type = 'Train.'
                data_loader = train_loader
                validator = validator_train
            else:
                run_type = 'Eval.'
                data_loader = eval_loader
                validator = validator_eval

            prog_bar_validation = tqdm.tqdm(data_loader,
                                            desc=f'Validation {run_type}',
                                            file=orig_stdout,
                                            dynamic_ncols=True,
                                            leave=False)
            validator.run(prog_bar_validation)
            prog_bar_validation.clear()
            prog_bar_validation.close()

            tqdm_info = ['Epoch: {}'.format(engine.state.epoch)]
            for scope_name, scope in performance_metrics.items():
                for line in scope['lines']:
                    if line['update_rate'] == 'epoch':
                        try:
                            line_label = '{} {}'.format(
                                line['line_label'], run_type)
                            line_value = validator.state.metrics[
                                line['metric_label']]

                            _visdom.plot_line(vis=vis,
                                              window_name=scope['window_name'],
                                              env=visdom_env_name,
                                              line_label=line_label,
                                              x_label=scope['x_label'],
                                              y_label=scope['y_label'],
                                              x=np.full(1, engine.state.epoch),
                                              y=np.full(1, line_value),
                                              draw_marker=True)

                            tqdm_info.append('{}: {:.4f}'.format(
                                line_label, line_value))
                        except KeyError:
                            pass

            tqdm.tqdm.write('{} results - {}'.format(run_type,
                                                     '; '.join(tqdm_info)))

        @trainer.on(ieng.Events.EPOCH_COMPLETED)
        def log_validation_train(engine: ieng.Engine):
            log_validation(engine, True)

        @trainer.on(ieng.Events.EPOCH_COMPLETED)
        def log_validation_eval(engine: ieng.Engine):
            log_validation(engine, False)

            if engine.state.epoch == 1:
                summary = _utils.build_summary_str(
                    experiment_name=experiment_name,
                    model_short_name=model_name,
                    model_class=model_class,
                    model_args=model_args,
                    optimizer_class=optimizer_class,
                    optimizer_args=optimizer_args,
                    dataset_class=dataset_class,
                    dataset_args=dataset_args,
                    transforms=transforms,
                    epochs=epochs,
                    batch_train=batch_train,
                    log_interval=log_interval,
                    saved_models_path=saved_models_path,
                    scheduler_class=scheduler_class,
                    scheduler_args=scheduler_args)
                _visdom.create_summary_window(vis=vis,
                                              visdom_env_name=visdom_env_name,
                                              experiment_name=experiment_name,
                                              summary=summary)

            vis.save([visdom_env_name])

            prog_bar_epochs.update(1)

            if scheduler is not None:
                scheduler.step(engine.state.epoch)

        trainer.run(train_loader, max_epochs=epochs)

        if vis_pid is not None:
            tqdm.tqdm.write('Stopping visdom')
            os.kill(vis_pid, signal.SIGTERM)

        del vis
        del train_loader
        del eval_loader

        prog_bar_iters.clear()
        prog_bar_iters.close()

        prog_bar_epochs.clear()
        prog_bar_epochs.close()

    tqdm.tqdm.write('\n')
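
The performance_metrics argument merged into default_metrics above follows a nested dict schema read by the attachment loop: a window per scope, axis labels, and a list of line specs carrying either an 'object' or a 'class'/'args' pair plus 'update_rate' and 'is_checkpoint'. The example below only illustrates that shape; the metric class path and labels are placeholders, and it assumes _utils.load_class resolves a dotted import path.

example_performance_metrics = {
    "Accuracy": {
        "window_name": None,                     # falls back to the scope name ("Accuracy")
        "x_label": "#Epochs",
        "y_label": "Accuracy",
        "width": 940,
        "height": 416,
        "lines": [{
            "line_label": "Top-1",
            "class": "ignite.metrics.Accuracy",  # instantiated via _utils.load_class(...)(**args)
            "args": {},
            "update_rate": "epoch",              # attached to both validators and plotted per epoch
            "is_checkpoint": True,               # its eval value is added to the checkpoint score
        }],
    }
}
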
Exemplo n.º 24
0
    def create_evaluator(self,
                         profile: Profile,
                         shared: Storage,
                         logger: Logger,
                         model: nn.Module,
                         loss_function: nn.Module,
                         optimizer: optim.Optimizer,
                         lr_scheduler: Any,
                         output_transform=lambda x, y, y_pred: (y_pred, y),
                         **kwargs) -> engine.Engine:
        """

        Args:
            profile: Runtime profile defined in TOML file.
            shared: Shared storage in the whole lifecycle.
            logger: The logger named with this Task.
            model: The model to train.
            loss_function: The loss function to train.
            optimizer: The optimizer to train.
            lr_scheduler: The scheduler to control the learning rate.
            output_transform: The action to transform the output of the model.

        Returns:
            The evaluator engine.
        """

        if 'device' in profile:
            device_type = profile.device
        else:
            device_type = 'cpu'

        if 'non_blocking' in profile:
            non_blocking = profile.non_blocking
        else:
            non_blocking = False

        if 'deterministic' in profile:
            deterministic = profile.deterministic
        else:
            deterministic = False

        _metrics = {}
        self.register_metrics(profile, shared, logger, _metrics)

        def _inference(_engine: engine.Engine, _batch: Tuple[torch.Tensor]):
            model.eval()
            with torch.no_grad():
                x, y = self.prepare_validate_batch(profile,
                                                   shared,
                                                   logger,
                                                   _batch,
                                                   device=device_type,
                                                   non_blocking=non_blocking)
                y_pred = model(x)
                return output_transform(x, y, y_pred)

        evaluator = engine.DeterministicEngine(
            _inference) if deterministic else engine.Engine(_inference)

        for name, metric in _metrics.items():
            metric.attach(evaluator, name)

        return evaluator
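
The evaluator above reads three optional keys from the runtime profile (described as coming from a TOML file). A minimal sketch of the corresponding section, written as comments because the exact table name and loader are not shown in this snippet:

# [task]                      # table name is a placeholder
# device = "cuda:0"           # falls back to 'cpu' when absent
# non_blocking = true         # falls back to false
# deterministic = true        # selects engine.DeterministicEngine instead of engine.Engine
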
Exemplo n.º 25
0
        x, y = batch['payload'], batch['target']
        ypred = CLF(x)
        loss = LFN(ypred, y.squeeze(1))
        loss.backward()
        OPM.step()
        return loss.item()

    def eval_step(engine, batch):
        CLF.eval()
        with t.no_grad():
            x, y = batch['payload'], batch['target']
            y = y.squeeze(1)
            ypred = CLF(x)
            return ypred, y

    TRAINER = ie.Engine(train_step)
    EVALUATOR = ie.Engine(eval_step)
    for name, metric in VAL_METRICS.items():
        metric.attach(EVALUATOR, name)
    #########################
    TO_CHECKP = {
        "trainer": TRAINER,
        "evaluator": EVALUATOR,
        "model": CLF,
        "optimizer": OPM,
    }
    tckp = ih.Checkpoint(
        to_save=TO_CHECKP,
        save_handler=ih.DiskSaver(MDIR, require_empty=False),
        n_saved=10,