Example #1
def log_validation_loss(trainer):
    evaluator.run(ignite_selected(valid_loader, slice_size=slice_size))
    metrics = evaluator.state.metrics
    print("Epoch {:3d} Valid loss: {:8.6f} ←".format(
        trainer.state.epoch, metrics['loss']))
    # ignite_selected returns a generator that the trainer exhausts in one
    # epoch, so hand it a fresh one for the next epoch.
    trainer.state.dataloader = ignite_selected(train_loader,
                                               slice_size=slice_size,
                                               **filter_dict(
                                                   hyper_params,
                                                   "ignite_selected"))
    nonlocal best_loss
    best_loss = min(best_loss, metrics['loss'])
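Example #1 is a closure lifted out of its training loop; Example #3 below shows the full context, where it is registered via @trainer.on(engine.Events.EPOCH_COMPLETED) (hence the nonlocal). As a minimal runnable sketch of the same pattern, with a toy model and data standing in for the project's own (everything here is illustrative, not the project's code):

import torch
import torch.nn as nn
from ignite import engine
import ignite.metrics

def train():
    # Toy stand-ins for the project's model, loss, and data loaders.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters())
    loss = nn.MSELoss()
    data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]

    trainer = engine.create_supervised_trainer(model, optimizer, loss)
    evaluator = engine.create_supervised_evaluator(
        model, metrics={"loss": ignite.metrics.Loss(loss)})

    best_loss = float("inf")

    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def log_validation_loss(trainer):
        evaluator.run(data)
        metrics = evaluator.state.metrics
        print("Epoch {:3d} Valid loss: {:8.6f}".format(
            trainer.state.epoch, metrics["loss"]))
        # nonlocal rebinds train()'s local, exactly as in the fragment above.
        nonlocal best_loss
        best_loss = min(best_loss, metrics["loss"])

    trainer.run(data, max_epochs=2)
    return best_loss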
Example #2
def main():
    parser = argparse.ArgumentParser(description="Evaluate the Nero + kNN model combination")
    parser.add_argument("-s",
                        "--split",
                        default="validation",
                        choices={"validation", "test"},
                        help="data split (for 'test' it only predicts)")
    parser.add_argument("-c",
                        "--city",
                        required=True,
                        choices=CITIES,
                        help="which city to evaluate")
    parser.add_argument("--overwrite",
                        default=False,
                        action="store_true",
                        help="overwrite existing predictions if they exist")
    parser.add_argument("-v",
                        "--verbose",
                        action="count",
                        help="verbosity level")
    args = parser.parse_args()

    if args.verbose:
        print(args)

    transforms = [
        lambda x: x.float(),
        lambda x: x / 255,
        src.dataset.Traffic4CastSample.Transforms.Permute("TCHW"),
        src.dataset.Traffic4CastSample.Transforms.SelectChannels(CHANNELS),
    ]
    dataset = src.dataset.Traffic4CastDataset(ROOT, args.split, [args.city], transforms)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        collate_fn=src.dataset.Traffic4CastDataset.collate_list)

    to_str = lambda v: f"{v:10.7f}"

    losses = []
    device = "cuda"
    non_blocking = False
    output_transform = lambda x, y, y_pred: (y_pred, y,)
    slice_size = P + F

    dirname = get_prediction_folder(args.split, MODEL_NAME, args.city)
    os.makedirs(dirname, exist_ok=True)

    loss = nn.MSELoss()

    def prepare_batch(batch, device, non_blocking):
        batch1, date, frames = batch
        batch1 = batch1.to(device)
        # Fold time and channels together: (B, T, C, H, W) -> (B, T*C, H, W).
        batch1 = batch1.reshape(batch1.shape[0], -1, batch1.shape[3], batch1.shape[4])
        # The first P frames (P * C planes) are the input; the rest is the target.
        inp = batch1[:, :P * C], date, frames
        tgt = batch1[:, P * C:]
        return inp, tgt

    def _inference(batch):
        with torch.no_grad():
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            y_pred = model(x, get_path1)
            return output_transform(x, y, y_pred)

    def get_path(date):
        return os.path.join(dirname, src.dataset.date_to_path(date))

    def get_path1(model_name, date):
        # Map the generic "nero" name to the per-city checkpoint name.
        str_channels = "_".join(CHANNELS)
        if model_name == "nero" and args.city == "Berlin":
            model_name = "nero-8-32_" + str_channels + "_" + args.city
        elif model_name == "nero" and args.city == "Moscow":
            model_name = "nero-8-64_" + str_channels + "_" + args.city
        elif model_name == "nero" and args.city == "Istanbul":
            model_name = "nero-4-64_" + str_channels + "_" + args.city
        dirname = get_prediction_folder(args.split, model_name, args.city)
        return os.path.join(dirname, src.dataset.date_to_path(date))

    for batch in ignite_selected(loader, slice_size=slice_size):
        output = _inference(batch)
        curr_loss = loss(output[0], output[1]).item()
        losses.append(curr_loss)
        path = get_path(batch[1])
        if not os.path.exists(path) or args.overwrite:
            data = to_uint8(output[0])
            data = data.view(5, 3, 3, 495, 436)
            data = data.permute(0, 1, 3, 4, 2)  # Move channels to the end
            data = data.cpu().numpy().astype(np.uint8)
            submission_write.write_data(data, path)
        if args.verbose:
            diff = ((output[0] - output[1]).view(5, 3, 3, H, W) ** 2)
            print("T", diff[:, 0].mean(dim=(1, 2, 3)))
            print("F", diff.mean(dim=(0, 2, 3, 4)))
            print(to_str(curr_loss))

    print(to_str(np.mean(losses)))
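Example #2 relies on several module-level names that are not shown (P, F, C, H, W, CHANNELS, MODEL_NAME, ignite_selected). The values and sketch below are inferred from the code itself (data.view(5, 3, 3, 495, 436), the batch1[:, :P * C] split) and from the Traffic4Cast 2019 setup; treat them as illustrative assumptions, not the project's actual definitions:

P = 12                                     # past frames fed to the model (assumed)
F = 3                                      # future frames to predict
C = 3                                      # channels per frame
H, W = 495, 436                            # spatial resolution of each frame
CHANNELS = ["volume", "speed", "heading"]  # Traffic4Cast channel names
MODEL_NAME = "nero+knn"                    # hypothetical name for this combination

# Hypothetical sketch of ignite_selected: it flattens each day-long sample
# from the loader into (tensor, date, frames) windows of slice_size = P + F
# consecutive frames. sliding_window is an invented stand-in here.
def ignite_selected(loader, slice_size, **kwargs):
    for samples in loader:  # collate_list yields lists of samples
        for sample in samples:
            yield from sample.sliding_window(slice_size, **kwargs)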
Example #3
def train(args, hyper_params):

    print(args)
    print(hyper_params)

    args.channels.sort(
        key=lambda x: src.dataset.Traffic4CastSample.channel_to_index[x])

    model = MODELS[args.model_type](**filter_dict(hyper_params, "model"))
    slice_size = model.past + model.future

    assert model.future == 3

    if args.model is not None:
        model_path = args.model
        model_name = os.path.basename(args.model)
        model.load(model_path)
    else:
        model_name = f"{args.model_type}_" + "_".join(args.channels +
                                                      args.cities)
        model_path = f"output/models/{model_name}.pth"

    if model.num_channels != len(args.channels):
        print(f"ERROR: Model to channels missmatch. Model can predict "
              f"{model.num_channels} channels. {len(args.channels)} were "
              "selected.")
        sys.exit(1)

    transforms = [
        lambda x: x.float(),
        lambda x: x / 255,
        src.dataset.Traffic4CastSample.Transforms.Permute("TCHW"),
        src.dataset.Traffic4CastSample.Transforms.SelectChannels(
            args.channels),
    ]
    train_dataset = src.dataset.Traffic4CastDataset(ROOT, "training",
                                                    args.cities, transforms)
    valid_dataset = src.dataset.Traffic4CastDataset(ROOT, "validation",
                                                    args.cities, transforms)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        collate_fn=src.dataset.Traffic4CastDataset.collate_list,
        shuffle=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,
        collate_fn=src.dataset.Traffic4CastDataset.collate_list,
        shuffle=False)

    ignite_train = ignite_selected(
        train_loader,
        slice_size=slice_size,
        **filter_dict(hyper_params, "ignite_selected"),
    )

    optimizer = torch.optim.Adam(
        model.parameters(),
        **filter_dict(hyper_params, "optimizer"),
    )
    loss = nn.MSELoss()

    best_loss = float("inf")  # best validation loss seen so far

    device = args.device
    if 'cuda' in device and not torch.cuda.is_available():
        device = 'cpu'
    trainer = engine.create_supervised_trainer(
        model,
        optimizer,
        loss,
        device=device,
        prepare_batch=model.ignite_batch)
    evaluator = engine.create_supervised_evaluator(
        model,
        metrics={'loss': ignite.metrics.Loss(loss)},
        device=device,
        prepare_batch=model.ignite_batch)

    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        print("Epoch {:3d} Train loss: {:8.6f}".format(trainer.state.epoch,
                                                       trainer.state.output))

    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def log_validation_loss(trainer):
        evaluator.run(ignite_selected(valid_loader, slice_size=slice_size))
        metrics = evaluator.state.metrics
        print("Epoch {:3d} Valid loss: {:8.6f} ←".format(
            trainer.state.epoch, metrics['loss']))
        # ignite_train is a generator and is exhausted after one epoch, so
        # hand the trainer a fresh one for the next epoch.
        trainer.state.dataloader = ignite_selected(train_loader,
                                                   slice_size=slice_size,
                                                   **filter_dict(
                                                       hyper_params,
                                                       "ignite_selected"))
        nonlocal best_loss
        best_loss = min(best_loss, metrics['loss'])

    if "learning-rate-scheduler" in args.callbacks:
        lr_reduce = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   verbose=args.verbose,
                                                   **LR_REDUCE_PARAMS)

        @evaluator.on(engine.Events.COMPLETED)
        def update_lr_reduce(engine):
            loss = engine.state.metrics['loss']
            lr_reduce.step(loss)

    def score_function(engine):
        # ignite's handlers treat larger scores as better, so negate the loss.
        return -engine.state.metrics['loss']

    if "early-stopping" in args.callbacks:
        early_stopping_handler = ignite.handlers.EarlyStopping(
            patience=PATIENCE, score_function=score_function, trainer=trainer)
        evaluator.add_event_handler(engine.Events.EPOCH_COMPLETED,
                                    early_stopping_handler)

    if "model-checkpoint" in args.callbacks:
        checkpoint_handler = ignite.handlers.ModelCheckpoint(
            "output/models/checkpoints",
            model_name,
            score_function=score_function,
            n_saved=1,
            require_empty=False,
            create_dir=True)
        evaluator.add_event_handler(engine.Events.EPOCH_COMPLETED,
                                    checkpoint_handler, {"model": model})

    if "tensorboard" in args.callbacks:
        logger = tensorboard_logger.TensorboardLogger(
            log_dir=f"output/tensorboard/{model_name}")
        logger.attach(trainer,
                      log_handler=tensorboard_logger.OutputHandler(
                          tag="training",
                          output_transform=lambda loss: {'loss': loss}),
                      event_name=engine.Events.ITERATION_COMPLETED)
        logger.attach(evaluator,
                      log_handler=tensorboard_logger.OutputHandler(
                          tag="validation",
                          metric_names=["loss"],
                          another_engine=trainer),
                      event_name=engine.Events.EPOCH_COMPLETED)

    trainer.run(ignite_train, **filter_dict(hyper_params, "trainer_run"))

    if "save-model" in args.callbacks and not "model-checkpoint" in args.callbacks:
        torch.save(model.state_dict(), model_path)
        print("Model saved at:", model_path)
    elif "save-model" in args.callbacks:
        # Move best model from checkpoint directory to output/models
        checkpoints_dir = "output/models/checkpoints"
        source, *_ = [
            f for f in reversed(utils.sorted_ls(checkpoints_dir))
            if f.startswith(model_name)
        ]  # get most recent model
        os.rename(os.path.join(checkpoints_dir, source), model_path)
        print("Model saved at:", model_path)

    return {
        'loss': best_loss,  # HpBandSter always minimizes!
        'info': {
            'args': vars(args),
            'hyper-params': hyper_params,
        },
    }
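filter_dict is used above to route a single flat hyper-parameter dict to the model constructor, the optimizer, ignite_selected, and trainer.run. Its implementation is not shown anywhere in these examples; a plausible sketch, assuming keys are namespaced with a section prefix such as "optimizer_lr", would be:

def filter_dict(d, prefix):
    # Keep only the "<prefix>_*" entries and strip the prefix so the result
    # can be splatted as keyword arguments.
    prefix = prefix + "_"
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}

Under that assumption, filter_dict({"optimizer_lr": 1e-3, "model_past": 12}, "optimizer") returns {"lr": 1e-3}.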
Example #4
def main():
    parser = argparse.ArgumentParser(description="Evaluate a given model")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        required=True,
                        choices=MODELS,
                        help="which model to use")
    parser.add_argument("-p",
                        "--model-path",
                        type=str,
                        help="path to the saved model")
    parser.add_argument("-s",
                        "--split",
                        default="validation",
                        choices={"validation", "test"},
                        help="data split (for 'test' it only predicts)")
    parser.add_argument("-c",
                        "--city",
                        required=True,
                        choices=CITIES,
                        help="which city to evaluate")
    parser.add_argument("--overwrite",
                        default=False,
                        action="store_true",
                        help="overwrite existing predictions if they exist")
    parser.add_argument("--channels",
                        nargs='+',
                        help="List of channels to predict")
    parser.add_argument(
        "--hyper-params",
        required=False,
        help=("path to JSON file containing hyper-parameter configuration "
              "(over-writes other hyper-parameters passed through the "
              "command line)."),
    )
    parser.add_argument("-v",
                        "--verbose",
                        action="count",
                        help="verbosity level")
    args = parser.parse_args()
    args.channels.sort(
        key=lambda x: src.dataset.Traffic4CastSample.channel_to_index[x])

    if args.verbose:
        print(args)

    if args.hyper_params and os.path.exists(args.hyper_params):
        with open(args.hyper_params, "r") as f:
            hyper_params = json.load(f)
    else:
        hyper_params = {}

    Model = MODELS[args.model]
    model = Model(**filter_dict(hyper_params, "model"))

    loss = nn.MSELoss()

    if args.model_path:
        model.load_state_dict(
            torch.load(args.model_path, map_location="cuda:0"))
        model.eval()

    transforms = [
        lambda x: x.float(),
        lambda x: x / 255,
        src.dataset.Traffic4CastSample.Transforms.Permute("TCHW"),
        src.dataset.Traffic4CastSample.Transforms.SelectChannels(
            args.channels),
    ]
    dataset = src.dataset.Traffic4CastDataset(ROOT, args.split, [args.city],
                                              transforms)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        collate_fn=src.dataset.Traffic4CastDataset.collate_list)

    if args.model_path:
        model_name, _ = os.path.splitext(os.path.basename(args.model_path))
    else:
        model_name = args.model

    to_str = lambda v: f"{v:9.7f}"

    losses = []
    device = "cuda"
    prepare_batch = model.ignite_batch
    non_blocking = False
    output_transform = lambda x, y, y_pred: (y_pred, y)
    slice_size = model.past + model.future

    dirname = get_prediction_folder(args.split, model_name, args.city)
    os.makedirs(dirname, exist_ok=True)

    if device:
        model.to(device)

    def _inference(batch):
        model.eval()
        with torch.no_grad():
            x, y = prepare_batch(batch,
                                 device=device,
                                 non_blocking=non_blocking)
            y_pred = model(x)
            return output_transform(x, y, y_pred)

    def get_path(date):
        return os.path.join(dirname, src.dataset.date_to_path(date))

    outputs = []

    for batch in ignite_selected(loader, slice_size=slice_size):
        output = _inference(batch)
        curr_loss = loss(round_torch(output[0]), output[1]).item()
        losses.append(curr_loss)
        outputs.append(output)
        path = get_path(batch[1])
        if not os.path.exists(path) or args.overwrite:
            data = to_uint8(output[0])
            data = data.view(5, 3, 3, 495, 436)
            data = data.permute(0, 1, 3, 4, 2)  # Move channels to the end
            data = data.cpu().numpy().astype(np.uint8)
            submission_write.write_data(data, path)
        if args.verbose:
            print(to_str(curr_loss))

    print(to_str(np.mean(losses)))

    if args.verbose:
        # Show per-channel results.
        S = 5, 3, 3, 495, 436  # (slices, frames, channels, H, W)
        pred = torch.stack([p.view(*S) for p, _ in outputs])
        pred = round_torch(pred)
        true = torch.stack([t.view(*S) for _, t in outputs])
        diff = (pred - true)**2
        diff = diff.mean(dim=(0, 1, 2, 4, 5))
        rows = [[f"{v:.4f}" for v in diff.cpu().numpy()] +
                ["{:.4f}".format(diff.mean().item())]]
        print(tabulate(rows, tablefmt="latex"))
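to_uint8 and round_torch are project helpers whose definitions are not shown. Given that the transforms normalize frames with x / 255 and predictions are written back as uint8, plausible sketches (assumptions, not the project's code) are:

import torch

def to_uint8(x):
    # Undo the x / 255 normalisation and clamp into the valid uint8 range.
    return (x * 255.0).round().clamp(0, 255)

def round_torch(x):
    # Quantise predictions to the uint8 grid before scoring, so the MSE is
    # computed against values the written submission can actually represent.
    return to_uint8(x) / 255.0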