Example #1
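    # Distributed check of Fbeta(beta=2.5): each rank feeds its own slice of
    # y_preds/y_true to the engine and the result is compared with the reference fbeta_score.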
    def _test(p, r, average, n_epochs):
        n_iters = 60
        s = 16
        n_classes = 7

        offset = n_iters * s
        y_true = torch.randint(0, n_classes, size=(offset * dist.get_world_size(), )).to(device)
        y_preds = torch.rand(offset * dist.get_world_size(), n_classes).to(device)

        def update(engine, i):
            return y_preds[i * s + rank * offset:(i + 1) * s + rank * offset, :], \
                y_true[i * s + rank * offset:(i + 1) * s + rank * offset]

        engine = Engine(update)

        fbeta = Fbeta(beta=2.5, average=average, device=device)
        fbeta.attach(engine, "f2.5")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "f2.5" in engine.state.metrics
        res = engine.state.metrics['f2.5']
        if isinstance(res, torch.Tensor):
            res = res.cpu().numpy()

        true_res = fbeta_score(y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy(), beta=2.5,
                               average='macro' if average else None)

        assert pytest.approx(res) == true_res
Example #2
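# Fbeta must reject a non-positive beta, precision/recall metrics constructed with
# average=True, and an output_transform combined with a user-supplied precision or recall.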
def test_wrong_inputs():

    with pytest.raises(ValueError, match=r"Beta should be a positive integer"):
        Fbeta(0.0)

    with pytest.raises(ValueError, match=r"Input precision metric should have average=False"):
        p = Precision(average=True)
        Fbeta(1.0, precision=p)

    with pytest.raises(ValueError, match=r"Input recall metric should have average=False"):
        r = Recall(average=True)
        Fbeta(1.0, recall=r)

    with pytest.raises(
        ValueError, match=r"If precision argument is provided, output_transform should be None"
    ):
        p = Precision(average=False)
        Fbeta(1.0, precision=p, output_transform=lambda x: x)

    with pytest.raises(
        ValueError, match=r"If recall argument is provided, output_transform should be None"
    ):
        r = Recall(average=False)
        Fbeta(1.0, recall=r, output_transform=lambda x: x)
Example #3
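    # Builds a synthetic 10-class problem, runs one epoch through an Engine and
    # compares the attached Fbeta(beta=2.0) value with the reference fbeta_score.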
    def _test(p, r, average, output_transform):
        np.random.seed(1)

        n_iters = 10
        batch_size = 10
        n_classes = 10

        y_true = np.arange(0, n_iters * batch_size) % n_classes
        y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
        for i in range(n_iters * batch_size):
            if np.random.rand() > 0.4:
                y_pred[i, y_true[i]] = 1.0
            else:
                j = np.random.randint(0, n_classes)
                y_pred[i, j] = 0.7

        y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
        y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

        def update_fn(engine, batch):
            y_true_batch = next(y_true_batch_values)
            y_pred_batch = next(y_pred_batch_values)
            if output_transform is not None:
                return {
                    "y_pred": torch.from_numpy(y_pred_batch),
                    "y": torch.from_numpy(y_true_batch),
                }
            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

        evaluator = Engine(update_fn)

        f2 = Fbeta(
            beta=2.0,
            average=average,
            precision=p,
            recall=r,
            output_transform=output_transform,
        )
        f2.attach(evaluator, "f2")

        data = list(range(n_iters))
        state = evaluator.run(data, max_epochs=1)

        f2_true = fbeta_score(
            y_true,
            np.argmax(y_pred, axis=-1),
            average="macro" if average else None,
            beta=2.0,
        )
        if isinstance(state.metrics["f2"], torch.Tensor):
            np.testing.assert_allclose(f2_true, state.metrics["f2"].numpy())
        else:
            assert f2_true == pytest.approx(state.metrics["f2"]), "{} vs {}".format(
                f2_true, state.metrics["f2"]
            )
Example #4
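# Shorter variant of the invalid-argument checks: non-positive beta and
# pre-averaged precision/recall metrics are rejected.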
def test_wrong_inputs():

    with pytest.raises(ValueError, match=r"Beta should be a positive integer"):
        Fbeta(0.0)

    with pytest.raises(ValueError, match=r"Input precision metric should have average=False"):
        p = Precision(average=True)
        Fbeta(1.0, precision=p)

    with pytest.raises(ValueError, match=r"Input recall metric should have average=False"):
        r = Recall(average=True)
        Fbeta(1.0, recall=r)
Example #5
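# Evaluates a model on the test loader, updating Precision/Recall/Fbeta by hand
# (outside of an Engine) and printing the resulting F1 score.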
def test_predict(model, dataloader_test, use_cuda):
    if use_cuda:
        model = model.cuda()

    precision = Precision()
    recall = Recall()
    f1 = Fbeta(beta=1.0, average=True, precision=precision, recall=recall)

    for i, (img, label) in enumerate(dataloader_test):
        img, label = Variable(img), Variable(label)
        if use_cuda:
            img = img.cuda()
            label = label.cuda()
        pred = model(img)
        # labels come as 2-D (e.g. one-hot); argmax recovers the class indices
        _, my_label = torch.max(label, dim=1)
        precision.update((pred, my_label))
        recall.update((pred, my_label))
        f1.update((pred, my_label))

    precision.compute()
    recall.compute()
    print("\tF1 Score: {:0.2f}".format(f1.compute() * 100))
Example #6
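    # Assembles the full training run: running-average loss on the trainer, per-epoch
    # metrics (accuracy, recall, precision, Fbeta(beta=1), top-5 accuracy, loss) on two
    # evaluators, plus terminal logging, TensorBoard logging and callbacks.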
    def run(self, logging_dir=None, best_model_only=True):

        #assert self.model is not None, '[ERROR] No model object loaded. Please load a PyTorch model torch.nn object into the class object.'
        #assert (self.train_loader is not None) or (self.val_loader is not None), '[ERROR] You must specify data loaders.'

        for key in self.trainer_status.keys():
            assert self.trainer_status[key], \
                '[ERROR] The {} has not been generated and you cannot proceed.'.format(key)
        print('[INFO] Trainer pass OK for training.')

        # TRAIN ENGINE
        # Create the objects for training
        self.train_engine = self.create_trainer()

        # METRICS AND EVALUATION
        # Metrics - running average
        RunningAverage(output_transform=lambda x: x).attach(
            self.train_engine, 'loss')

        # Metrics - epochs
        metrics = {
            'accuracy': Accuracy(),
            'recall': Recall(average=True),
            'precision': Precision(average=True),
            'f1': Fbeta(beta=1),
            'topKCatAcc': TopKCategoricalAccuracy(k=5),
            'loss': Loss(self.criterion)
        }

        # Create evaluators
        self.evaluator = self.create_evaluator(metrics=metrics)
        self.train_evaluator = self.create_evaluator(metrics=metrics,
                                                     tag='train')

        # LOGGING
        # Create logging to terminal
        self.add_logging()

        # Create Tensorboard logging
        self.add_tensorboard_logging(logging_dir=logging_dir)

        ## CALLBACKS
        self.create_callbacks(best_model_only=best_model_only)

        ## TRAIN
        # Train the model
        print('[INFO] Executing model training...')
        self.train_engine.run(self.train_loader,
                              max_epochs=self.config.TRAIN.NUM_EPOCHS)
        print('[INFO] Model training is complete.')
Example #7
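    # Training loop that evaluates node-level accuracy, precision, recall,
    # Fbeta(1, average=False) and a confusion matrix on the validation set after each
    # epoch, then reloads the best checkpoint for a final test evaluation.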
    def run(self, train_loader, val_loader, test_loader):
        """Perform model training and evaluation on holdout dataset."""

        ## attach certain metrics to trainer ##
        CpuInfo().attach(self.trainer, "cpu_util")
        Loss(self.loss).attach(self.trainer, "loss")

        ###### configure evaluator settings ######
        def get_output_transform(target: str, collapse_y: bool = False):
            return lambda out: metric_output_transform(
                out, self.loss, target, collapse_y=collapse_y)

        graph_num_classes = len(self.graph_classes)
        node_num_classes = len(self.node_classes)
        node_num_classes = 2 if node_num_classes == 1 else node_num_classes

        node_output_transform = get_output_transform("node")
        node_output_transform_collapsed = get_output_transform("node",
                                                               collapse_y=True)
        graph_output_transform = get_output_transform("graph")
        graph_output_transform_collapsed = get_output_transform(
            "graph", collapse_y=True)

        # metrics we are interested in
        base_metrics: dict = {
            'loss': Loss(self.loss),
            "cpu_util": CpuInfo(),
            'node_accuracy_avg': Accuracy(output_transform=node_output_transform,
                                          is_multilabel=False),
            'node_accuracy': LabelwiseAccuracy(output_transform=node_output_transform,
                                               is_multilabel=False),
            "node_recall": Recall(output_transform=node_output_transform_collapsed,
                                  is_multilabel=False,
                                  average=False),
            "node_precision": Precision(output_transform=node_output_transform_collapsed,
                                        is_multilabel=False,
                                        average=False),
            "node_f1_score": Fbeta(1,
                                   output_transform=node_output_transform_collapsed,
                                   average=False),
            "node_c_matrix": ConfusionMatrix(node_num_classes,
                                             output_transform=node_output_transform_collapsed,
                                             average=None)
        }

        metrics = dict(**base_metrics)

        # settings for the evaluator
        evaluator_settings = {
            "device": self.device,
            "loss_fn": self.loss,
            "node_classes": self.node_classes,
            "graph_classes": self.graph_classes,
            "non_blocking": True,
            "metrics": OrderedDict(sorted(metrics.items(),
                                          key=lambda m: m[0])),
            "pred_collector_function": self._pred_collector_function
        }

        ## configure evaluators ##
        val_evaluator = None
        if len(val_loader):
            val_evaluator = create_supervised_evaluator(
                self.model, **evaluator_settings)
            # configure behavior for early stopping
            if self.stopper:
                val_evaluator.add_event_handler(Events.COMPLETED, self.stopper)
            # configure behavior for checkpoint saving
            val_evaluator.add_event_handler(Events.COMPLETED,
                                            self.best_checkpoint_handler)
            val_evaluator.add_event_handler(Events.COMPLETED,
                                            self.latest_checkpoint_handler)
        else:
            self.trainer.add_event_handler(Events.COMPLETED,
                                           self.latest_checkpoint_handler)

        test_evaluator = None
        if len(test_loader):
            test_evaluator = create_supervised_evaluator(
                self.model, **evaluator_settings)
        #############################

        @self.trainer.on(Events.STARTED)
        def log_training_start(trainer):
            self.custom_print("Start training...")

        @self.trainer.on(Events.EPOCH_COMPLETED)
        def compute_metrics(trainer):
            """Compute evaluation metric values after each epoch."""

            epoch = trainer.state.epoch

            self.custom_print(f"Finished epoch {epoch:03d}!")

            if len(val_loader):
                self.persist_collection = True
                val_evaluator.run(val_loader)
                self._save_collected_predictions(
                    prefix=f"validation_epoch{epoch:03}")
                # write metrics to file
                self.write_metrics(trainer, val_evaluator, suffix="validation")

        @self.trainer.on(Events.COMPLETED)
        def log_training_complete(trainer):
            """Trigger evaluation on test set if training is completed."""

            epoch = trainer.state.epoch
            suffix = "(Early Stopping)" if epoch < self.epochs else ""

            self.custom_print("Finished after {:03d} epochs! {}".format(
                epoch, suffix))

            # load best model for evaluation
            self.custom_print("Load best model for final evaluation...")
            last_checkpoint: str = self.best_checkpoint_handler.last_checkpoint or self.latest_checkpoint_handler.last_checkpoint
            best_checkpoint_path = os.path.join(self.save_path,
                                                last_checkpoint)
            checkpoint_path_dict: dict = {
                "latest_checkpoint_path":
                best_checkpoint_path  # we want to load states from the best checkpoint as "latest" configuration for testing
            }
            self.model, self.optimizer, self.trainer, _, _, _ = self._load_checkpoint(
                self.model,
                self.optimizer,
                self.trainer,
                None,
                None,
                None,
                checkpoint_path_dict=checkpoint_path_dict)

            if len(test_loader):
                self.persist_collection = True
                test_evaluator.run(test_loader)
                self._save_collected_predictions(prefix="test_final")
                # write metrics to file
                self.write_metrics(trainer, test_evaluator, suffix="test")

        # terminate training if Nan values are produced
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                       TerminateOnNan())

        # start the actual training
        self.custom_print(f"Train for a maximum of {self.epochs} epochs...")
        self.trainer.run(train_loader, max_epochs=self.epochs)
Example #8
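# Distributed inference with optional test-time augmentation; if targets are available,
# confusion-matrix based metrics (IoU, accuracy, precision, recall) and Fbeta(beta=1.0)
# are attached and logged via MLflow.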
def inference(config, local_rank, with_pbar_on_iters=True):

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    # Load model and weights
    model_weights_filepath = Path(
        get_artifact_path(config.run_uuid, config.weights_filename))
    assert model_weights_filepath.exists(), \
        "Model weights file '{}' is not found".format(model_weights_filepath.as_posix())

    model = config.model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

    if hasattr(config, "custom_weights_loading"):
        config.custom_weights_loading(model, model_weights_filepath)
    else:
        state_dict = torch.load(model_weights_filepath)
        if not all([k.startswith("module.") for k in state_dict]):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    model.eval()

    prepare_batch = config.prepare_batch
    non_blocking = getattr(config, "non_blocking", True)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    tta_transforms = getattr(config, "tta_transforms", None)

    def eval_update_function(engine, batch):
        with torch.no_grad():
            x, y, meta = prepare_batch(batch,
                                       device=device,
                                       non_blocking=non_blocking)

            if tta_transforms is not None:
                y_preds = []
                for t in tta_transforms:
                    t_x = t.augment_image(x)
                    t_y_pred = model(t_x)
                    t_y_pred = model_output_transform(t_y_pred)
                    y_pred = t.deaugment_mask(t_y_pred)
                    y_preds.append(y_pred)

                y_preds = torch.stack(y_preds, dim=0)
                y_pred = torch.mean(y_preds, dim=0)
            else:
                y_pred = model(x)
                y_pred = model_output_transform(y_pred)
            return {"y_pred": y_pred, "y": y, "meta": meta}

    evaluator = Engine(eval_update_function)

    has_targets = getattr(config, "has_targets", False)

    if has_targets:

        def output_transform(output):
            return output['y_pred'], output['y']

        num_classes = config.num_classes
        cm_metric = ConfusionMatrix(num_classes=num_classes,
                                    output_transform=output_transform)
        pr = cmPrecision(cm_metric, average=False)
        re = cmRecall(cm_metric, average=False)

        val_metrics = {
            "IoU": IoU(cm_metric),
            "mIoU_bg": mIoU(cm_metric),
            "Accuracy": cmAccuracy(cm_metric),
            "Precision": pr,
            "Recall": re,
            "F1": Fbeta(beta=1.0, output_transform=output_transform)
        }

        if hasattr(config, "metrics") and isinstance(config.metrics, dict):
            val_metrics.update(config.metrics)

        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)

        if dist.get_rank() == 0:
            # Log val metrics:
            mlflow_logger = MLflowLogger()
            mlflow_logger.attach(evaluator,
                                 log_handler=OutputHandler(
                                     tag="validation",
                                     metric_names=list(val_metrics.keys())),
                                 event_name=Events.EPOCH_COMPLETED)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=True, desc="Inference").attach(evaluator)

    if dist.get_rank() == 0:
        do_save_raw_predictions = getattr(config, "do_save_raw_predictions",
                                          True)
        do_save_overlayed_predictions = getattr(
            config, "do_save_overlayed_predictions", True)

        if not has_targets:
            assert do_save_raw_predictions or do_save_overlayed_predictions, \
                "If no targets, either do_save_overlayed_predictions or do_save_raw_predictions should be " \
                "defined in the config and has value equal True"

        # Save predictions
        if do_save_raw_predictions:
            raw_preds_path = config.output_path / "raw"
            raw_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                        save_raw_predictions_with_geoinfo,
                                        raw_preds_path)

        if do_save_overlayed_predictions:
            overlayed_preds_path = config.output_path / "overlay"
            overlayed_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(
                Events.ITERATION_COMPLETED,
                save_overlayed_predictions,
                overlayed_preds_path,
                img_denormalize_fn=config.img_denormalize,
                palette=default_palette)

    evaluator.add_event_handler(Events.EXCEPTION_RAISED, report_exception)

    # Run evaluation
    evaluator.run(config.data_loader)
Example #9
    num_classes=200
    )
print('done')

## SETUP TRAINER AND EVALUATOR
# Setup model trainer and evaluator
print('[INFO] Creating Ignite training, evaluation objects and logging...', end='')
trainer = create_trainer(model=model, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler)
# Metrics - running average
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
# Metrics - epochs
metrics = {
    'accuracy':Accuracy(),
    'recall':Recall(average=True),
    'precision':Precision(average=True),
    'f1':Fbeta(beta=1),
    'topKCatAcc':TopKCategoricalAccuracy(k=5),
    'loss':Loss(criterion)
}

# Create evaluators
evaluator = create_evaluator(model, metrics=metrics)
train_evaluator = create_evaluator(model, metrics=metrics, tag='train')

# Add validation logging
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), evaluate_model)

# Add step length update at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())

# Add TensorBoard logging
Example #10
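# Supervised training with apex AMP and DistributedDataParallel; validation attaches
# confusion-matrix metrics (IoU, accuracy, precision, recall) and Fbeta(beta=1.0),
# with TensorBoard/MLflow logging and best-model saving on mIoU_bg.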
def training(config, local_rank, with_pbar_on_iters=True):

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(train_loader))
    assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \
        "Train sampler should have a callable method `set_epoch`"

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=getattr(
                                          config, "fp16_opt_level", "O2"),
                                      num_losses=1)
    model = DDP(model, delay_allreduce=True)

    criterion = config.criterion.to(device)

    # Setup trainer
    prepare_batch = getattr(config, "prepare_batch")
    non_blocking = getattr(config, "non_blocking", True)
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    def train_update_function(engine, batch):
        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)

        if isinstance(loss, Mapping):
            assert 'supervised batch loss' in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict['supervised batch loss'] / accumulation_steps
        else:
            output = {'supervised batch loss': loss.item()}

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return output

    output_names = getattr(config, "output_names", [
        'supervised batch loss',
    ])

    trainer = Engine(train_update_function)
    common.setup_common_distrib_training_handlers(
        trainer,
        train_sampler,
        to_save={
            'model': model,
            'optimizer': optimizer
        },
        save_every_iters=1000,
        output_path=config.output_path.as_posix(),
        lr_scheduler=config.lr_scheduler,
        output_names=output_names,
        with_pbars=True,
        with_pbar_on_iters=with_pbar_on_iters,
        log_every_iters=1)

    def output_transform(output):
        return output['y_pred'], output['y']

    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes,
                                output_transform=output_transform)
    pr = cmPrecision(cm_metric, average=False)
    re = cmRecall(cm_metric, average=False)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
        "Accuracy": cmAccuracy(cm_metric),
        "Precision": pr,
        "Recall": re,
        "F1": Fbeta(beta=1.0, output_transform=output_transform)
    }

    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator_args = dict(model=model,
                          metrics=val_metrics,
                          device=device,
                          non_blocking=non_blocking,
                          prepare_batch=prepare_batch,
                          output_transform=lambda x, y, y_pred: {
                              'y_pred': model_output_transform(y_pred),
                              'y': y
                          })
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=False,
                    desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(engine):
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1)),
        run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "mIoU_bg"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(config.output_path.as_posix(),
                                            trainer,
                                            optimizer,
                                            evaluators={
                                                "training": train_evaluator,
                                                "validation": evaluator
                                            })
        common.setup_mlflow_logging(trainer,
                                    optimizer,
                                    evaluators={
                                        "training": train_evaluator,
                                        "validation": evaluator
                                    })

        common.save_best_model_by_val_score(config.output_path.as_posix(),
                                            evaluator,
                                            model,
                                            metric_name=score_metric_name,
                                            trainer=trainer)

        # Log train/val predictions:
        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2))

        log_train_predictions = getattr(config, "log_train_predictions", False)
        if log_train_predictions:
            tb_logger.attach(train_evaluator,
                             log_handler=predictions_gt_images_handler(
                                 img_denormalize_fn=config.img_denormalize,
                                 n_images=15,
                                 another_engine=trainer,
                                 prefix_tag="validation"),
                             event_name=Events.ITERATION_COMPLETED(
                                 once=len(train_eval_loader) // 2))

    trainer.run(train_loader, max_epochs=config.num_epochs)
Example #11
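# Semi-supervised variant of the training above: on top of the supervised loss, a
# consistency loss on geometrically augmented unlabeled batches is applied, weighted
# by a gradually increased unsup_lambda.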
def training(config, local_rank, with_pbar_on_iters=True):

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(train_loader))
    assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \
        "Train sampler should have a callable method `set_epoch`"

    unsup_train_loader = config.unsup_train_loader
    unsup_train_sampler = getattr(unsup_train_loader, "sampler", None)
    assert unsup_train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(unsup_train_loader))
    assert hasattr(unsup_train_sampler, 'set_epoch') and callable(unsup_train_sampler.set_epoch), \
        "Unsupervised train sampler should have a callable method `set_epoch`"

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
    model, optimizer = amp.initialize(model, optimizer, opt_level=getattr(config, "fp16_opt_level", "O2"), num_losses=2)
    model = DDP(model, delay_allreduce=True)
    
    criterion = config.criterion.to(device)
    unsup_criterion = config.unsup_criterion.to(device)
    unsup_batch_num_repetitions = getattr(config, "unsup_batch_num_repetitions", 1)

    # Setup trainer
    prepare_batch = getattr(config, "prepare_batch")
    non_blocking = getattr(config, "non_blocking", True)
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform", lambda x: x)

    def cycle(seq):
        while True:
            for i in seq:
                yield i

    unsup_train_loader_iter = cycle(unsup_train_loader)

    def supervised_loss(batch):
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)
        return loss

    def unsupervised_loss(x):

        with torch.no_grad():
            y_pred_orig = model(x)

            # Data augmentation: geom only
            k = random.randint(1, 3)
            x_aug = torch.rot90(x, k=k, dims=(2, 3))
            y_pred_orig_aug = torch.rot90(y_pred_orig, k=k, dims=(2, 3))
            if random.random() < 0.5:
                x_aug = torch.flip(x_aug, dims=(2, ))
                y_pred_orig_aug = torch.flip(y_pred_orig_aug, dims=(2, )) 
            if random.random() < 0.5:
                x_aug = torch.flip(x_aug, dims=(3, ))
                y_pred_orig_aug = torch.flip(y_pred_orig_aug, dims=(3, )) 

            y_pred_orig_aug = y_pred_orig_aug.argmax(dim=1).long()

        y_pred_aug = model(x_aug.detach())

        loss = unsup_criterion(y_pred_aug, y_pred_orig_aug.detach())

        return loss

    def train_update_function(engine, batch):
        model.train()

        loss = supervised_loss(batch)
        if isinstance(loss, Mapping):
            assert 'supervised batch loss' in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict['supervised batch loss'] / accumulation_steps
        else:
            output = {'supervised batch loss': loss.item()}
        
        # Difference with original UDA
        # Apply separately grads from supervised/unsupervised parts
        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        unsup_batch = next(unsup_train_loader_iter)
        unsup_x = unsup_batch['image']
        unsup_x = convert_tensor(unsup_x, device=device, non_blocking=non_blocking)

        for _ in range(unsup_batch_num_repetitions):
            unsup_loss = engine.state.unsup_lambda * unsupervised_loss(unsup_x)

            assert isinstance(unsup_loss, torch.Tensor)
            output['unsupervised batch loss'] = unsup_loss.item()

            with amp.scale_loss(unsup_loss, optimizer, loss_id=1) as scaled_loss:
                scaled_loss.backward()

            if engine.state.iteration % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        unsup_batch = None
        unsup_x = None

        total_loss = loss + unsup_loss
        output['total batch loss'] = total_loss.item()

        return output

    output_names = getattr(config, "output_names", 
                           ['supervised batch loss', 'unsupervised batch loss', 'total batch loss'])

    trainer = Engine(train_update_function)

    @trainer.on(Events.STARTED)
    def init(engine):
        if hasattr(config, "unsup_lambda_min"):
            engine.state.unsup_lambda = config.unsup_lambda_min
        else:
            engine.state.unsup_lambda = getattr(config, "unsup_lambda", 0.001)

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_unsup_params(engine):        
        engine.state.unsup_lambda += getattr(config, "unsup_lambda_delta", 0.00001)
        if hasattr(config, "unsup_lambda_max"):
            m = config.unsup_lambda_max
            engine.state.unsup_lambda = engine.state.unsup_lambda if engine.state.unsup_lambda < m else m

    common.setup_common_distrib_training_handlers(
        trainer, train_sampler,
        to_save={'model': model, 'optimizer': optimizer},
        save_every_iters=1000, output_path=config.output_path.as_posix(),
        lr_scheduler=config.lr_scheduler, output_names=output_names,
        with_pbars=True, with_pbar_on_iters=with_pbar_on_iters, log_every_iters=1
    )

    def output_transform(output):        
        return output['y_pred'], output['y']

    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes, output_transform=output_transform)
    pr = cmPrecision(cm_metric, average=False)
    re = cmRecall(cm_metric, average=False)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
        "Accuracy": cmAccuracy(cm_metric),
        "Precision": pr,
        "Recall": re,
        "F1": Fbeta(beta=1.0, output_transform=output_transform)
    }

    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator_args = dict(
        model=model, metrics=val_metrics, device=device, non_blocking=non_blocking, prepare_batch=prepare_batch,
        output_transform=lambda x, y, y_pred: {'y_pred': model_output_transform(y_pred), 'y': y}
    )
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=False, desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(engine):
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1)), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "mIoU_bg"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience, evaluator, trainer, metric_name=score_metric_name)

    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(config.output_path.as_posix(), trainer, optimizer,
                                            evaluators={"training": train_evaluator, "validation": evaluator})
        common.setup_mlflow_logging(trainer, optimizer,
                                    evaluators={"training": train_evaluator, "validation": evaluator})

        common.save_best_model_by_val_score(config.output_path.as_posix(), evaluator, model,
                                            metric_name=score_metric_name, trainer=trainer)

        # Log unsup_lambda
        @trainer.on(Events.ITERATION_COMPLETED(every=100))
        def tblog_unsupervised_lambda(engine):
            tb_logger.writer.add_scalar("training/unsupervised lambda", engine.state.unsup_lambda, engine.state.iteration)
            mlflow.log_metric("training unsupervised lambda", engine.state.unsup_lambda, step=engine.state.iteration)

        # Log train/val predictions:
        tb_logger.attach(evaluator,
                         log_handler=predictions_gt_images_handler(img_denormalize_fn=config.img_denormalize,
                                                                   n_images=15,
                                                                   another_engine=trainer,
                                                                   prefix_tag="validation"),
                         event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2))

        log_train_predictions = getattr(config, "log_train_predictions", False)
        if log_train_predictions:
            tb_logger.attach(train_evaluator,
                             log_handler=predictions_gt_images_handler(img_denormalize_fn=config.img_denormalize,
                                                                       n_images=15,
                                                                       another_engine=trainer,
                                                                       prefix_tag="validation"),
                             event_name=Events.ITERATION_COMPLETED(once=len(train_eval_loader) // 2))

    trainer.run(train_loader, max_epochs=config.num_epochs)
Example #12
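# Manual train/validation loop: tracks loss and F1 (ignite Precision/Recall/Fbeta
# updated by hand) per epoch and times the whole run with CUDA events.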
def train_predict(dataloader_train,dataloader_val,model,epochs,learning_rate,use_cuda):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    if use_cuda:
        model = model.cuda()
    model = model.train()
    
    start.record()
    train_loss_list=[]
    val_loss_list=[]
    train_f1=[]
    val_f1=[]
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    
    precision = Precision()
    recall = Recall()
    f1 = Fbeta(beta=1.0, average=True, precision=precision, recall=recall)

    
    for epoch in range(epochs):
        print("Epoch: {}".format(epoch+1))
        # Start each epoch's training metrics from a clean state (precision and
        # recall hold the accumulated counts that Fbeta is computed from)
        precision.reset()
        recall.reset()
        for i,(img, label) in enumerate(dataloader_train):
            img, label = Variable(img),Variable(label)
            if use_cuda:
                img = img.cuda()
                label = label.cuda()
            optimizer.zero_grad()
            pred = model.forward(img)
            _,my_label = torch.max(label, dim=1)
            loss = loss_fn(pred,my_label)
            if i == len(dataloader_train)-1:
                train_loss_list.append(loss.item())
            loss.backward()
            optimizer.step()
            precision.update((pred, my_label))
            recall.update((pred, my_label))
            f1.update((pred, my_label))
        print("\tTrain loss: {:0.2f}".format(train_loss_list[-1]))
        precision.compute()
        recall.compute()
        train_f1.append(f1.compute()*100)
        print("\tTrain F1 Score: {:0.2f}%".format(train_f1[-1]))
        
        precision = Precision()
        recall = Recall()
        f1 = Fbeta(beta=1.0, average=True, precision=precision, recall=recall)
        
        with torch.no_grad():
            for i,(img, label) in enumerate(dataloader_val):
                img, label = Variable(img), Variable(label)
                if use_cuda:
                    img = img.cuda()
                    label = label.cuda()
                pred = model(img)
                _,my_label = torch.max(label, dim=1)
                loss = loss_fn(pred,my_label)
                if i == len(dataloader_val)-1:
                    val_loss_list.append(loss.item())
                precision.update((pred, my_label))
                recall.update((pred, my_label))
                f1.update((pred, my_label))
        print("\n\tVal loss: {:0.2f}".format(val_loss_list[-1]))
        precision.compute()
        recall.compute()
        val_f1.append(f1.compute()*100)
        print("\tVal F1 Score: {:0.2f}%".format(val_f1[-1]))
    
    end.record()
    torch.cuda.synchronize()
    time = start.elapsed_time(end)
    return (train_loss_list,val_loss_list,train_f1,val_f1,time,model)
Example #13
            end_value=0,
            cycle_size=len(train_loader) * cfg.n_epochs,
            start_value_mult=0,
            end_value_mult=0),
                                        warmup_start_value=0.0,
                                        warmup_end_value=cfg.lr,
                                        warmup_duration=len(train_loader)))

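    # Evaluator metrics: loss, plus accuracy/precision/recall/F1 on thresholded
    # outputs and average precision on activated outputs.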
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'loss': Loss(loss),
            'acc_smpl': Accuracy(threshold_output, is_multilabel=True),
            'p': Precision(threshold_output, average=True),
            'r': Recall(threshold_output, average=True),
            'f1': Fbeta(1.0, output_transform=threshold_output),
            'ap': AveragePrecision(output_transform=activate_output)
        },
        device=device)

    model_checkpoint = ModelCheckpoint(
        dirname=wandb.run.dir,
        filename_prefix='best',
        require_empty=False,
        score_function=lambda e: e.state.metrics['ap'],
        global_step_transform=global_step_from_engine(trainer))
    evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                {'model': model})

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(trainer):
Example #14
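# Builds an ignite trainer and evaluator for image classification; the evaluator
# attaches loss, accuracy and un-averaged Fbeta/precision/recall, and fires a custom
# VALIDATION_COMPLETED event after each validation run.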
def set_image_classification_trainer(model, optimizer, criterion, device,
                                     loaders, loggers):
    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y).mean()
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    loggers['progress_bar'].attach(trainer, metric_names='all')

    def validation_step(engine, batch):
        model.eval()
        with torch.no_grad():
            x, target = batch[0].to(device), batch[1].to(device)
            y = model(x)
            return {'y_pred': y, 'y': target, 'criterion_kwargs': {}}

    evaluator = Engine(validation_step)
    evaluator.state.validation_completed = 0
    evaluator.register_events(*EvaluatorEvents, event_to_attr=event_to_attr)

    metrics = {
        'loss': Loss(criterion),
        'F1': Fbeta(beta=1, average=False),
        'mA': Accuracy(is_multilabel=False),
        'mP': Precision(average=False, is_multilabel=False),
        'mR': Recall(average=False, is_multilabel=False)
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=250),
                              log_training_loss, loggers)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        with evaluator.add_event_handler(Events.COMPLETED, log_results,
                                         'train', engine.state.epoch, loggers):
            evaluator.run(loaders['train'])
        with evaluator.add_event_handler(Events.COMPLETED, log_results,
                                         'validation', engine.state.epoch,
                                         loggers):
            evaluator.run(loaders['validation'])
            evaluator.state.validation_completed += 1
            evaluator.fire_event(EvaluatorEvents.VALIDATION_COMPLETED)

    @trainer.on(Events.COMPLETED)
    def test(engine):
        with evaluator.add_event_handler(
                Events.COMPLETED, log_results, 'test', engine.state.epoch,
                loggers), evaluator.add_event_handler(
                    Events.COMPLETED,
                    log_calibration_results,
                    'test',
                    loggers,
                    output_transform=lambda output: {
                        'y_pred': F.softmax(output['y_pred'], dim=1),
                        'y': output['y']
                    }):
            evaluator.run(loaders['test'])

    return trainer, evaluator