def _test_distrib_integration(device):
    import numpy as np
    from ignite.engine import Engine

    rank = idist.get_rank()
    n_iters = 80
    s = 50
    offset = n_iters * s

    y_true = torch.arange(0, offset * idist.get_world_size(), dtype=torch.float).to(device)
    y_preds = torch.ones(offset * idist.get_world_size(), dtype=torch.float).to(device)

    def update(engine, i):
        return (
            y_preds[i * s + offset * rank : (i + 1) * s + offset * rank],
            y_true[i * s + offset * rank : (i + 1) * s + offset * rank],
        )

    engine = Engine(update)

    m = MeanAbsoluteError()
    m.attach(engine, "mae")

    data = list(range(n_iters))
    engine.run(data=data, max_epochs=1)

    assert "mae" in engine.state.metrics
    res = engine.state.metrics["mae"]

    true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy()))

    assert pytest.approx(res) == true_res
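
A minimal sketch of how a distributed test like the one above is typically launched; the gloo backend, entry point, and process count here are assumptions for illustration, not part of the original test.

import torch
import ignite.distributed as idist


def _entry(local_rank, device_type):
    # each spawned process runs the test against its own rank
    _test_distrib_integration(torch.device(device_type))


if __name__ == "__main__":
    idist.spawn("gloo", _entry, args=("cpu",), nproc_per_node=2)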
Example #2
def test_zero_div():
    mae = MeanAbsoluteError()
    with pytest.raises(
        NotComputableError,
        match=r"MeanAbsoluteError must have at least one example before it can be computed",
    ):
        mae.compute()
Example #3
def test_accumulator_detached():
    mae = MeanAbsoluteError()

    y_pred = torch.tensor([[2.0], [-2.0]], requires_grad=True)
    y = torch.zeros(2)
    mae.update((y_pred, y))

    assert not mae._sum_of_absolute_errors.requires_grad
def test_compute():
    mae = MeanAbsoluteError()

    y_pred = torch.Tensor([[2.0], [-2.0]])
    y = torch.zeros(2)
    mae.update((y_pred, y))
    assert mae.compute() == 2.0

    mae.reset()
    y_pred = torch.Tensor([[3.0], [-3.0]])
    y = torch.zeros(2)
    mae.update((y_pred, y))
    assert mae.compute() == 3.0
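
A quick worked check of the expected values above: with zero targets, the mean absolute error is just the mean of the absolute predictions.

# (|2.0| + |-2.0|) / 2 == 2.0 and (|3.0| + |-3.0|) / 2 == 3.0
assert (abs(2.0) + abs(-2.0)) / 2 == 2.0
assert (abs(3.0) + abs(-3.0)) / 2 == 3.0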
Example #5
    def _test(metric_device):
        engine = Engine(update)

        m = MeanAbsoluteError(device=metric_device)
        m.attach(engine, "mae")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=1)

        assert "mae" in engine.state.metrics
        res = engine.state.metrics["mae"]

        true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy()))

        assert pytest.approx(res) == true_res
Example #6
def metrics_selector(mode, loss):
    mode = mode.lower()
    if mode == "classification":
        metrics = {
            "loss": loss,
            "accuracy": Accuracy(),
            "accuracy_topk": TopKCategoricalAccuracy(),
            "precision": Precision(average=True),
            "recall": Recall(average=True)
        }
    elif mode == "multiclass-multilabel":
        metrics = {
            "loss": loss,
            "accuracy": Accuracy(),
        }
    elif mode == "regression":
        metrics = {
            "loss": loss,
            "mse": MeanSquaredError(),
            "mae": MeanAbsoluteError()
        }
    else:
        raise RuntimeError(
            "Invalid task mode, select classification, multiclass-multilabel or regression")

    return metrics
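
A usage sketch for the selector above, assuming a regression setup; model, val_loader and the criterion are placeholder names, not part of the original snippet.

import torch.nn as nn
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss

criterion = nn.MSELoss()
# model and val_loader are assumed to be defined elsewhere
evaluator = create_supervised_evaluator(
    model, metrics=metrics_selector("regression", Loss(criterion)))
state = evaluator.run(val_loader)
print(state.metrics["mse"], state.metrics["mae"])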
Example #7
def create_sr_evaluator(
    model,
    device=None,
    non_blocking=True,
    denormalize=True,
    mean=None,
):
    # transfer mean to the device and reshape it so
    # that it is broadcastable to the BCHW format
    mean = mean.to(device).reshape(1, -1, 1, 1)

    def denorm_fn(x):
        return torch.clamp(x + mean, min=0., max=1.)

    def _evaluate_model(engine, batch):
        model.eval()
        x, y = _prepare_batch(batch, device=device, non_blocking=non_blocking)
        with torch.no_grad():
            y_pred = model(x)
        if denormalize:
            y_pred, y = map(denorm_fn, [y_pred, y])
        return y_pred, y

    engine = Engine(_evaluate_model)
    MeanAbsoluteError().attach(engine, 'l1')
    MeanSquaredError().attach(engine, 'l2')
    PNSR(max_value=1.0).attach(engine, 'pnsr')

    return engine
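
A usage sketch for the factory above; the model, loader and normalization mean are placeholders, assuming a 3-channel image model.

mean = torch.tensor([0.485, 0.456, 0.406])  # assumed per-channel normalization mean
evaluator = create_sr_evaluator(model, device=torch.device("cuda"), mean=mean)
state = evaluator.run(val_loader)  # model and val_loader assumed to exist
print(state.metrics["l1"], state.metrics["l2"], state.metrics["pnsr"])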
Example #8
def _test_distrib_accumulator_device(device):

    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        mae = MeanAbsoluteError(device=metric_device)

        for dev in [mae._device, mae._sum_of_absolute_errors.device]:
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

        y_pred = torch.tensor([[2.0], [-2.0]])
        y = torch.zeros(2)
        mae.update((y_pred, y))

        for dev in [mae._device, mae._sum_of_absolute_errors.device]:
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
def test_compute():

    mae = MeanAbsoluteError()

    def _test(y_pred, y, batch_size):
        mae.reset()
        if batch_size > 1:
            n_iters = y.shape[0] // batch_size + 1
            for i in range(n_iters):
                idx = i * batch_size
                mae.update(
                    (y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))
        else:
            mae.update((y_pred, y))

        np_y = y.numpy()
        np_y_pred = y_pred.numpy()

        np_res = (np.abs(np_y_pred - np_y)).sum() / np_y.shape[0]
        assert isinstance(mae.compute(), float)
        assert mae.compute() == np_res

    def get_test_cases():

        test_cases = [
            (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
            (torch.randint(-10, 10, size=(100, 5)), torch.randint(-10, 10, size=(100, 5)), 1),
            # updated batches
            (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
            (torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
        ]

        return test_cases

    for _ in range(5):
        # check multiple random inputs as random exact occurencies are rare
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size)
def _test_distrib_accumulator_device(device):

    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        mae = MeanAbsoluteError(device=metric_device)
        assert mae._device == metric_device
        assert mae._sum_of_absolute_errors.device == metric_device, "{}:{} vs {}:{}".format(
            type(mae._sum_of_absolute_errors.device),
            mae._sum_of_absolute_errors.device,
            type(metric_device),
            metric_device,
        )

        y_pred = torch.tensor([[2.0], [-2.0]])
        y = torch.zeros(2)
        mae.update((y_pred, y))
        assert mae._sum_of_absolute_errors.device == metric_device, "{}:{} vs {}:{}".format(
            type(mae._sum_of_absolute_errors.device),
            mae._sum_of_absolute_errors.device,
            type(metric_device),
            metric_device,
        )
Example #11
# NOTE: the class header and imports are missing from this snippet; F1 is assumed
# to subclass ignite.metrics.Metric, and f1_score to come from sklearn.metrics.
class F1(Metric):
    def __init__(self, output_transform=binary_transform):
        self._y = None
        self._pred = None
        super().__init__(output_transform=output_transform)

    def reset(self):
        self._y = list()
        self._pred = list()
        super().reset()

    def update(self, output):
        y_pred, y = output
        self._y.append(y)
        self._pred.append(y_pred)

    def compute(self):
        y_pred = torch.cat(self._pred, 0).cpu()
        y = torch.cat(self._y, 0).cpu()
        score = f1_score(y, y_pred, average='weighted')
        return score


metric = {
    'acc2': Accuracy(output_transform=binary_transform),
    'acc5': Accuracy(output_transform=five_transform),
    'acc7': Accuracy(output_transform=seven_transform),
    'f1': F1(),
    'corr': Pearson(),
    'mae': MeanAbsoluteError()
}
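
A sketch of attaching the metric dictionary above to an evaluator, assuming a model and val_loader whose outputs match what the transforms expect.

from ignite.engine import create_supervised_evaluator

evaluator = create_supervised_evaluator(model, metrics=metric)
state = evaluator.run(val_loader)
for name in ("acc2", "acc5", "acc7", "f1", "corr", "mae"):
    print(name, state.metrics[name])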
Example #12
def run(opt):
    if opt.log_file is not None:
        logging.basicConfig(filename=opt.log_file, level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    # logger.addHandler(logging.StreamHandler())
    logger = logger.info

    writer = SummaryWriter(log_dir=opt.log_dir)

    model_timer, data_timer = Timer(average=True), Timer(average=True)

    # Training variables
    logger('Loading models')
    model, parameters, mean, std = generate_model(opt)
    optimizer = SGD(parameters,
                    lr=opt.lr,
                    momentum=opt.momentum,
                    weight_decay=opt.weight_decay,
                    nesterov=opt.nesterov)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=opt.lr_patience)

    # Loading checkpoint
    if opt.checkpoint:
        logger('loading checkpoint {}'.format(opt.checkpoint))
        checkpoint = torch.load(opt.checkpoint)

        opt.begin_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    logger('Loading dataset')
    train_transform = get_transform(mean, std, opt.face_size, mode='training')
    train_data = get_training_set(opt, transform=train_transform)
    train_loader = DataLoader(train_data,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.n_threads,
                              pin_memory=True)

    val_transform = get_transform(mean, std, opt.face_size, mode='validation')
    val_data = get_validation_set(opt, transform=val_transform)
    val_loader = DataLoader(val_data,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.n_threads,
                            pin_memory=True)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        nn.L1Loss().cuda(),
                                        cuda=True)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'distance':
                                                MeanPairwiseDistance(),
                                                'loss': MeanAbsoluteError()
                                            },
                                            cuda=True)

    # Training timer handlers
    model_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.ITERATION_STARTED,
                       pause=Events.ITERATION_COMPLETED,
                       step=Events.ITERATION_COMPLETED)
    data_timer.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      resume=Events.ITERATION_COMPLETED,
                      pause=Events.ITERATION_STARTED,
                      step=Events.ITERATION_STARTED)

    # Training log/plot handlers
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % opt.log_interval == 0:
            logger(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.2f} Model Process: {:.3f}s/batch "
                "Data Preparation: {:.3f}s/batch".format(
                    engine.state.epoch,
                    iter, len(train_loader), engine.state.output,
                    model_timer.value(), data_timer.value()))
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    # Log/Plot Learning rate
    @trainer.on(Events.EPOCH_STARTED)
    def log_learning_rate(engine):
        lr = optimizer.param_groups[0]['lr']
        logger('Epoch[{}] Starts with lr={}'.format(engine.state.epoch, lr))
        writer.add_scalar("learning_rate", lr, engine.state.epoch)

    # Checkpointing
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        if engine.state.epoch % opt.save_interval == 0:
            save_file_path = os.path.join(
                opt.result_path, 'save_{}.pth'.format(engine.state.epoch))
            states = {
                'epoch': engine.state.epoch,
                'arch': opt.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_file_path)

    # val_evaluator event handlers
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        # metric_values = [metrics[m] for m in val_metrics]
        logger("Validation Results - Epoch: {} ".format(engine.state.epoch) +
               ' '.join(
                   ['{}: {:.4f}'.format(m, val)
                    for m, val in metrics.items()]))
        for m, val in metrics.items():
            writer.add_scalar('validation/{}'.format(m), val,
                              engine.state.epoch)

        # after the first epoch, drop the learning rate to 1e-4
        if engine.state.epoch == 1:
            optimizer.param_groups[0]['lr'] = 1e-4

        # Update Learning Rate
        scheduler.step(metrics['loss'])

    # kick everything off
    logger('Start training')
    trainer.run(train_loader, max_epochs=opt.n_epochs)

    writer.close()
Example #13
import torch.nn.functional as F
from ignite.metrics import CategoricalAccuracy, Loss, MeanAbsoluteError

from attributer.attributes import FaceAttributes
from training.metric_utils import ScaledError

_metrics = {
    FaceAttributes.AGE: ScaledError(MeanAbsoluteError(), 50),
    FaceAttributes.GENDER: CategoricalAccuracy(),
    FaceAttributes.EYEGLASSES: CategoricalAccuracy(),
    FaceAttributes.RECEDING_HAIRLINES: CategoricalAccuracy(),
    FaceAttributes.SMILING: CategoricalAccuracy(),
    FaceAttributes.HEAD_YAW_BIN: CategoricalAccuracy(),
    FaceAttributes.HEAD_PITCH_BIN: CategoricalAccuracy(),
    FaceAttributes.HEAD_ROLL_BIN: CategoricalAccuracy(),
    FaceAttributes.HEAD_YAW: MeanAbsoluteError(),
    FaceAttributes.HEAD_PITCH: MeanAbsoluteError(),
    FaceAttributes.HEAD_ROLL: MeanAbsoluteError(),
}
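
ScaledError comes from this project's training.metric_utils; the class below is only a hypothetical sketch of the idea, a wrapper that delegates to a base metric and rescales its result (e.g. by the factor 50 used for AGE above).

from ignite.metrics import Metric


class ScaledErrorSketch(Metric):  # hypothetical stand-in, not the project's class
    def __init__(self, base_metric, scale, output_transform=lambda x: x):
        self._base = base_metric
        self._scale = scale
        super().__init__(output_transform=output_transform)

    def reset(self):
        self._base.reset()

    def update(self, output):
        self._base.update(output)

    def compute(self):
        # rescale the wrapped metric, e.g. to undo a target normalization
        return self._base.compute() * self._scale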

_losses = {
    FaceAttributes.AGE: F.l1_loss,
    FaceAttributes.GENDER: F.cross_entropy,
    FaceAttributes.EYEGLASSES: F.cross_entropy,
    FaceAttributes.RECEDING_HAIRLINES: F.cross_entropy,
    FaceAttributes.SMILING: F.cross_entropy,
    FaceAttributes.HEAD_YAW_BIN: F.cross_entropy,
    FaceAttributes.HEAD_PITCH_BIN: F.cross_entropy,
    FaceAttributes.HEAD_ROLL_BIN: F.cross_entropy,
    FaceAttributes.HEAD_YAW: F.l1_loss,
    FaceAttributes.HEAD_PITCH: F.l1_loss,
Example #14
def run(args, seed):
    config.make_paths()

    torch.random.manual_seed(seed)
    train_loader, val_loader, shape = get_data_loaders(
        config.Training.batch_size,
        proportion=config.Training.proportion,
        test_batch_size=config.Training.batch_size * 2,
    )
    n, d, t = shape
    model = models.ConvNet(d, seq_len=t)

    writer = tb.SummaryWriter(log_dir=config.TENSORBOARD)

    model.to(config.device)  # Move model before creating optimizer
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=config.device)
    trainer.logger = setup_logger("trainer")

    checkpointer = ModelCheckpoint(
        config.MODEL,
        model.__class__.__name__,
        n_saved=2,
        create_dir=True,
        save_as_state_dict=True,
    )
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config.Training.save_every),
        checkpointer,
        {"model": model},
    )

    val_metrics = {
        "mse": Loss(criterion),
        "mae": MeanAbsoluteError(),
        "rmse": RootMeanSquaredError(),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=val_metrics,
                                            device=config.device)
    evaluator.logger = setup_logger("evaluator")

    ar_evaluator = create_ar_evaluator(model,
                                       metrics=val_metrics,
                                       device=config.device)
    ar_evaluator.logger = setup_logger("ar")

    @trainer.on(Events.EPOCH_COMPLETED(every=config.Training.save_every))
    def log_ar(engine):
        ar_evaluator.run(val_loader)
        y_pred, y = ar_evaluator.state.output
        fig = plot_output(y, y_pred)
        writer.add_figure("eval/ar", fig, engine.state.epoch)
        plt.close()

    # desc = "ITERATION - loss: {:.2f}"
    # pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED(every=config.Training.log_every))
    def log_training_loss(engine):
        # pbar.desc = desc.format(engine.state.output)
        # pbar.update(log_interval)
        if args.verbose:
            grad_norm = torch.stack(
                [p.grad.norm() for p in model.parameters()]).sum()
            writer.add_scalar("train/grad_norm", grad_norm,
                              engine.state.iteration)
        writer.add_scalar("train/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED(every=config.Training.eval_every))
    def log_training_results(engine):
        # pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        for k, v in metrics.items():
            writer.add_scalar(f"train/{k}", v, engine.state.epoch)
        # tqdm.write(
        #    f"Training Results - Epoch: {engine.state.epoch}  Avg mse: {evaluator.state.metrics['mse']:.2f}"
        # )

    @trainer.on(Events.EPOCH_COMPLETED(every=config.Training.eval_every))
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics

        for k, v in metrics.items():
            writer.add_scalar(f"eval/{k}", v, engine.state.epoch)
        # tqdm.write(
        #    f"Validation Results - Epoch: {engine.state.epoch}  Avg mse: {evaluator.state.metrics['mse']:.2f}"
        # )

        # pbar.n = pbar.last_print_n = 0

        y_pred, y = evaluator.state.output

        fig = plot_output(y, y_pred)
        writer.add_figure("eval/preds", fig, engine.state.epoch)
        plt.close()

    # @trainer.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
    # def log_time(engine):
    #    #tqdm.write(
    #    #    f"{trainer.last_event_name.name} took {trainer.state.times[trainer.last_event_name.name]} seconds"
    #    #)
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        ModelCheckpoint.load_objects({"model": model}, ckpt)

    try:
        trainer.run(train_loader, max_epochs=config.Training.max_epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    # pbar.close()
    writer.close()
Example #15
test_dataset = TrainValTestDataset(image_dataset, mode="test")
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=args.batchsize,
                         num_workers=num_workers)

model = Model(number_of_classes=number_of_classes)

optimizer = optim.Adam(model.parameters(), lr=args.learningrate)

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

metrics = {
    "accuracy": Accuracy(),
    # argmax over class logits so MAE/MSE compare predicted class indices with labels
    "MAE": MeanAbsoluteError(
        output_transform=lambda out: (torch.max(out[0], dim=1)[1], out[1])),
    "MSE": MeanSquaredError(
        output_transform=lambda out: (torch.max(out[0], dim=1)[1], out[1])),
    "loss": Loss(loss_fn=criterion),
}

evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)


@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    print(
        f"Training (Epoch {trainer.state.epoch}): {trainer.state.output:.3f}")
Example #16
def test_zero_div():
    mae = MeanAbsoluteError()
    with pytest.raises(NotComputableError):
        mae.compute()
Example #17
    def train(self, config, **kwargs):
        config_parameters = parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'),
                uuid.uuid1().hex))
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            save_as_state_dict=False,
            score_name='loss')

        train_kaldi_string = parsecopyfeats(
            config_parameters['trainfeatures'],
            **config_parameters['feature_args'])
        dev_kaldi_string = parsecopyfeats(config_parameters['devfeatures'],
                                          **config_parameters['feature_args'])
        logger = genlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Experiment is stored in {}".format(outputdir))
        for line in pformat(config_parameters).split('\n'):
            logger.info(line)
        scaler = getattr(
            pre,
            config_parameters['scaler'])(**config_parameters['scaler_args'])
        inputdim = -1
        logger.info("<== Estimating Scaler ({}) ==>".format(
            scaler.__class__.__name__))
        for _, feat in kaldi_io.read_mat_ark(train_kaldi_string):
            scaler.partial_fit(feat)
            inputdim = feat.shape[-1]
        assert inputdim > 0, "Reading inputstream failed"
        logger.info("Features: {} Input dimension: {}".format(
            config_parameters['trainfeatures'], inputdim))
        logger.info("<== Labels ==>")
        train_label_df = pd.read_csv(
            config_parameters['trainlabels']).set_index('Participant_ID')
        dev_label_df = pd.read_csv(
            config_parameters['devlabels']).set_index('Participant_ID')
        train_label_df.index = train_label_df.index.astype(str)
        dev_label_df.index = dev_label_df.index.astype(str)
        # target_type = ('PHQ8_Score', 'PHQ8_Binary')
        target_type = ('PHQ8_Score', 'PHQ8_Binary')
        n_labels = len(target_type)  # PHQ8 + Binary
        # Scores and their respective PHQ8
        train_labels = train_label_df.loc[:, target_type].T.apply(
            tuple).to_dict()
        dev_labels = dev_label_df.loc[:, target_type].T.apply(tuple).to_dict()
        train_dataloader = create_dataloader(
            train_kaldi_string,
            train_labels,
            transform=scaler.transform,
            shuffle=True,
            **config_parameters['dataloader_args'])
        cv_dataloader = create_dataloader(
            dev_kaldi_string,
            dev_labels,
            transform=scaler.transform,
            shuffle=False,
            **config_parameters['dataloader_args'])
        model = getattr(models, config_parameters['model'])(
            inputdim=inputdim,
            output_size=n_labels,
            **config_parameters['model_args'])
        if 'pretrain' in config_parameters:
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrain']))
            pretrained_model = torch.load(config_parameters['pretrain'],
                                          map_location=lambda st, loc: st)
            if 'Attn' in pretrained_model.__class__.__name__:
                model.lstm.load_state_dict(pretrained_model.lstm.state_dict())
            else:
                model.net.load_state_dict(pretrained_model.net.state_dict())
        logger.info("<== Model ==>")
        for line in pformat(model).split('\n'):
            logger.info(line)
        criterion = getattr(
            losses,
            config_parameters['loss'])(**config_parameters['loss_args'])
        optimizer = getattr(torch.optim, config_parameters['optimizer'])(
            list(model.parameters()) + list(criterion.parameters()),
            **config_parameters['optimizer_args'])
        poolingfunction = parse_poolingfunction(
            config_parameters['poolingfunction'])
        criterion = criterion.to(device)
        model = model.to(device)

        def _train_batch(_, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                outputs, targets = Runner._forward(model, batch,
                                                   poolingfunction)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            model.eval()
            with torch.no_grad():
                return Runner._forward(model, batch, poolingfunction)

        def meter_transform(output):
            y_pred, y = output
            # y_pred is of shape [Bx2] (0 = MSE, 1 = BCE)
            # y = is of shape [Bx2] (0=Mse, 1 = BCE)
            return torch.sigmoid(y_pred[:, 1]).round(), y[:, 1].long()

        precision = Precision(output_transform=meter_transform, average=False)
        recall = Recall(output_transform=meter_transform, average=False)
        # composing ignite metrics yields a MetricsLambda; this is the binary F1 score
        F1 = (precision * recall * 2 / (precision + recall)).mean()
        metrics = {
            'Loss': Loss(criterion),
            'Recall': Recall(output_transform=meter_transform, average=True),
            'Precision': Precision(output_transform=meter_transform, average=True),
            # column 0 of predictions and targets carries the PHQ8 regression score
            'MAE': MeanAbsoluteError(
                output_transform=lambda out: (out[0][:, 0], out[1][:, 0])),
            'F1': F1,
        }

        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)
        RunningAverage(output_transform=lambda x: x).attach(
            train_engine, 'run_loss')
        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine, ['run_loss'])

        scheduler = getattr(torch.optim.lr_scheduler,
                            config_parameters['scheduler'])(
                                optimizer,
                                **config_parameters['scheduler_args'])
        early_stop_handler = EarlyStopping(
            patience=5,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                               'scaler': scaler,
                                               'config': config_parameters
                                           })

        @train_engine.on(Events.EPOCH_COMPLETED)
        def compute_metrics(engine):
            inference_engine.run(cv_dataloader)
            validation_string_list = [
                "Validation Results - Epoch: {:<3}".format(engine.state.epoch)
            ]
            for metric in metrics:
                validation_string_list.append("{}: {:<5.2f}".format(
                    metric, inference_engine.state.metrics[metric]))
            logger.info(" ".join(validation_string_list))

            pbar.n = pbar.last_print_n = 0

        @inference_engine.on(Events.COMPLETED)
        def update_reduce_on_plateau(engine):
            val_loss = engine.state.metrics['Loss']
            if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        train_engine.run(train_dataloader,
                         max_epochs=config_parameters['epochs'])
        # Return for further processing
        return outputdir
def MAEMetric(key):
    """Create max absolute error metric on key."""
    return DictMetric(key, MeanAbsoluteError())
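
DictMetric is project-specific; a hypothetical plain-ignite equivalent is an output_transform that picks the given key out of a dict-valued engine output before handing it to MeanAbsoluteError (the output layout below is assumed).

def mae_on_key_sketch(key):
    # assumed layout: the engine returns ({key: y_pred, ...}, {key: y, ...})
    return MeanAbsoluteError(
        output_transform=lambda output: (output[0][key], output[1][key]))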
Example #19
def run(
    train_batch_size: int,
    val_batch_size: int,
    epochs: int,
    lr: float,
    model_name: str,
    architecture: str,
    momentum: float,
    log_interval: int,
    log_dir: str,
    save_dir: str,
    save_step: int,
    val_step: int,
    num_workers: int,
    patience: int,
    eval_only: bool = False,
    overfit_on_few_samples: bool = False,
):
    train_loader, val_loader, test_loader = get_data_loaders(
        train_batch_size,
        val_batch_size,
        num_workers=num_workers,
        overfit_on_few_samples=overfit_on_few_samples,
    )

    models_available = {'convmos': ConvMOS}

    model = models_available[model_name](architecture=architecture)
    writer = create_summary_writer(model, train_loader, log_dir)
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    model = model.to(device=device)

    # E-OBS only provides observational data for land so we need to use a mask to avoid fitting on the sea
    land_mask_np = np.load('remo_eobs_land_mask.npy')
    # Convert booleans to 1 and 0, and convert numpy array to torch Tensor
    land_mask = torch.from_numpy(1 * land_mask_np).to(device)
    print('Land mask:')
    print(land_mask)
    loss_fn = partial(masked_mse_loss, mask=land_mask)

    optimizer = Adam(model.parameters(), lr=lr)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device)

    metrics = {
        'rmse': RootMeanSquaredError(),
        'mae': MeanAbsoluteError(),
        'mse': Loss(loss_fn),
    }
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=metrics,
                                                device=device)

    to_save = {'model': model, 'optimizer': optimizer, 'trainer': trainer}
    checkpoint_handler = Checkpoint(
        to_save,
        DiskSaver(save_dir, create_dir=True, require_empty=False),
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=save_step),
                              checkpoint_handler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    def score_function(engine):
        val_loss = engine.state.metrics['mse']
        return -val_loss

    best_checkpoint_handler = Checkpoint(
        to_save,
        DiskSaver(save_dir, create_dir=True, require_empty=False),
        n_saved=2,
        filename_prefix='best',
        score_function=score_function,
        score_name='val_loss',
        global_step_transform=global_step_from_engine(trainer),
    )
    val_evaluator.add_event_handler(Events.COMPLETED, best_checkpoint_handler)

    earlystop_handler = EarlyStopping(patience=patience,
                                      score_function=score_function,
                                      trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, earlystop_handler)

    # Maybe load model
    checkpoint_files = glob(join(save_dir, 'checkpoint_*.pt'))
    if len(checkpoint_files) > 0:
        # latest_checkpoint_file = sorted(checkpoint_files)[-1]
        epoch_list = [
            int(c.split('.')[0].split('_')[-1]) for c in checkpoint_files
        ]
        last_epoch = sorted(epoch_list)[-1]
        latest_checkpoint_file = join(save_dir, f'checkpoint_{last_epoch}.pt')
        print('Loading last checkpoint', latest_checkpoint_file)
        last_epoch = int(latest_checkpoint_file.split('.')[0].split('_')[-1])
        if last_epoch >= epochs:
            print('Training was already completed')
            eval_only = True
            # return

        checkpoint = torch.load(latest_checkpoint_file, map_location=device)
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                  "".format(engine.state.epoch, iter, len(train_loader),
                            engine.state.output))
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        avg_rmse = metrics['rmse']
        avg_mae = metrics['mae']
        avg_mse = metrics['mse']
        print(
            "Training Results - Epoch: {}  Avg RMSE: {:.2f} Avg loss: {:.2f} Avg MAE: {:.2f}"
            .format(engine.state.epoch, avg_rmse, avg_mse, avg_mae))
        writer.add_scalar("training/avg_loss", avg_mse, engine.state.epoch)
        writer.add_scalar("training/avg_rmse", avg_rmse, engine.state.epoch)
        writer.add_scalar("training/avg_mae", avg_mae, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=val_step))
    def log_validation_results(engine):
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_rmse = metrics['rmse']
        avg_mae = metrics['mae']
        avg_mse = metrics['mse']
        print(
            "Validation Results - Epoch: {}  Avg RMSE: {:.2f} Avg loss: {:.2f} Avg MAE: {:.2f}"
            .format(engine.state.epoch, avg_rmse, avg_mse, avg_mae))
        writer.add_scalar("validation/avg_loss", avg_mse, engine.state.epoch)
        writer.add_scalar("validation/avg_rmse", avg_rmse, engine.state.epoch)
        writer.add_scalar("validation/avg_mae", avg_mae, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=save_step))
    def log_model_weights(engine):
        for name, param in model.named_parameters():
            writer.add_histogram(f"model/weights_{name}", param,
                                 engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=save_step))
    def regularly_predict_val_data(engine):
        predict_data(engine.state.epoch, val_loader)

    def predict_data(epoch: int, data_loader) -> xr.Dataset:
        # Predict all test data points and write the predictions
        print(f'Predicting {data_loader.dataset.mode} data...')
        data_loader_iter = iter(data_loader)
        pred_np = None
        for i in range(len(data_loader)):
            x, y = next(data_loader_iter)
            # print(x)
            pred = (model.forward(x.to(device=device)).to(
                device='cpu').detach().numpy()[:, 0, :, :])
            # print('=======================================')
            # print(pred)
            if pred_np is None:
                pred_np = pred
            else:
                pred_np = np.concatenate((pred_np, pred), axis=0)

        preds = xr.Dataset(
            {
                'pred': (['time', 'lat', 'lon'], pred_np),
                'input': (['time', 'lat', 'lon'], data_loader.dataset.X),
                'target': (['time', 'lat', 'lon'], data_loader.dataset.Y[:, :, :, 0]),
            },
            coords={
                'time': data_loader.dataset.times,  # list(range(len(val_loader.dataset)))
                'lon_var': (('lat', 'lon'), data_loader.dataset.lons[0]),  # list(range(x.shape[-2]))
                'lat_var': (('lat', 'lon'), data_loader.dataset.lats[0]),  # list(range(x.shape[-1]))
            },
        )

        preds.to_netcdf(
            join(save_dir,
                 f'predictions_{data_loader.dataset.mode}_{epoch}.nc'))
        return preds

    # kick everything off
    if not eval_only:
        trainer.run(train_loader, max_epochs=epochs)

    # Load best model
    best_checkpoint = best_checkpoint_handler.last_checkpoint
    print('Loading best checkpoint from', best_checkpoint)
    checkpoint = torch.load(join(save_dir,
                                 best_checkpoint_handler.last_checkpoint),
                            map_location=device)
    Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    writer.close()

    val_preds = predict_data(trainer.state.epoch, val_loader)
    test_preds = predict_data(trainer.state.epoch, test_loader)
    val_res = mean_metrics(calculate_metrics(val_preds.pred, val_preds.target))
    test_res = mean_metrics(
        calculate_metrics(test_preds.pred, test_preds.target))

    # val_evaluator.run(val_loader)
    results = {}
    # Store the config, ...
    results.update({
        section_name: dict(config[section_name])
        for section_name in config.sections()
    })
    # ... the last training metrics,
    results.update(
        {f'train_{k}': v
         for k, v in train_evaluator.state.metrics.items()})
    # ... the last validation metrics from torch,
    results.update(
        {f'val_torch_{k}': v
         for k, v in val_evaluator.state.metrics.items()})
    # ... the validation metrics that I calculate,
    results.update({f'val_{k}': v for k, v in val_res.items()})
    # ... and the test metrics that I calculate
    results.update({f'test_{k}': v for k, v in test_res.items()})
    write_results_file(join('results', 'results.json'),
                       pd.json_normalize(results))