Example #1
def run_eval_suite(name,
                   dest_task=tasks.normal,
                   graph_file=None,
                   model_file=None,
                   logger=None,
                   sample=800,
                   show_images=False,
                   old=False):

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        # downsample=5 was used for some older checkpoints
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.normal,
                         dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate(model, sample=sample)
    if logger is not None:
        logger.text(name + ": " + str(result))
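
A minimal invocation sketch; the checkpoint path is hypothetical, and tasks / VisdomLogger are assumed from the surrounding codebase:

logger = VisdomLogger("eval", env="eval_suite")
run_eval_suite("unet_normals",
               dest_task=tasks.normal,
               model_file="path/to/checkpoint.pth",  # hypothetical path
               logger=logger)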
Example #2
def run_viz_suite(name,
                  data,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  logger=None,
                  old=False,
                  multitask=False,
                  percep_mode=None):

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample=5 was used for some older checkpoints
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # DATA LOADING 1
    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        percep_model = Transfer(src_task=dest_task,
                                dest_task=tasks.normal).load_model()
        percep_model.eval()
        eval_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(results),
            batch_size=16,
            num_workers=16,
            shuffle=False,
            pin_memory=True)
        final_preds = []
        for preds, in eval_loader:
            final_preds += [percep_model.forward(preds[:, -3:])]
        results = torch.cat(final_preds, dim=0)

    return results
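
A sketch of calling this variant on an in-memory batch; the random tensor is a stand-in for real RGB frames, and the checkpoint path is hypothetical:

data = torch.rand(8, 3, 256, 256)  # stand-in RGB batch in [0, 1]
normals = run_viz_suite("viz_normals", data,
                        dest_task=tasks.normal,
                        model_file="path/to/checkpoint.pth")
print(normals.shape)  # single-channel outputs are tiled to 3 channels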
Example #3
def run_perceptual_eval_suite(name,
                              intermediate_task=tasks.normal,
                              dest_task=tasks.normal,
                              graph_file=None,
                              model_file=None,
                              logger=None,
                              sample=800,
                              show_images=False,
                              old=False,
                              perceptual_transfer=None,
                              multitask=False):

    if perceptual_transfer is None:
        percep_model = Transfer(src_task=intermediate_task,
                                dest_task=dest_task).load_model()
    else:
        # assumption: the caller supplies a ready perceptual model; without
        # this branch, percep_model would be undefined below
        percep_model = perceptual_transfer

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task],
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample=5 was used for some older checkpoints
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb,
                         dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate_with_percep(model,
                                          sample=sample,
                                          percep_model=percep_model)
    if logger is not None:
        logger.text(name + ": " + str(result))
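
evaluate_with_percep is assumed to score the composition of the two networks; a minimal sketch of that idea (the helper name is hypothetical):

def composed_prediction(model, percep_model, rgb_batch):
    # rgb -> intermediate task (e.g. normals), then intermediate -> dest task
    intermediate = model.predict_on_batch(rgb_batch)
    return percep_model.forward(intermediate[:, -3:])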
Example #4
def run_viz_suite(name,
                  data_loader,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  old=False,
                  multitask=False,
                  percep_mode=None,
                  downsample=6,
                  out_channels=3,
                  final_task=tasks.normal,
                  oldpercep=False):

    extra_task = [final_task] if percep_mode else []

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task,
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample is typically 5 or 6, depending on the checkpoint
        print('loading main model')
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # DATA LOADING 1
    results = []
    final_preds = []

    if percep_mode:
        print('Loading percep model...')
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam,
                                 lr=3e-4,
                                 weight_decay=2e-6,
                                 amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task,
                                    dest_task=final_task).load_model()
        percep_model.eval()

    print("Converting...")
    for data, in data_loader:
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        results.append(preds.detach().cpu())
        if percep_mode:
            try:
                final_preds += [
                    percep_model.forward(preds[:, -3:]).detach().cpu()
                ]
            except RuntimeError:
                # single-channel predictions: tile to 3 channels and retry
                preds = torch.cat([preds] * 3, dim=1)
                final_preds += [
                    percep_model.forward(preds[:, -3:]).detach().cpu()
                ]

    if percep_mode:
        results = torch.cat(final_preds, dim=0)
    else:
        results = torch.cat(results, dim=0)

    return results
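
This variant expects a DataLoader yielding single-element tuples (hence the `for data, in data_loader` unpacking); a hypothetical setup:

images = torch.rand(64, 3, 256, 256)  # stand-in RGB frames
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(images), batch_size=16, shuffle=False)
results = run_viz_suite("viz_depth", data_loader,
                        dest_task=tasks.depth_zbuffer,
                        model_file="path/to/checkpoint.pth")  # hypothetical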
Example #5
def main(
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    batch_size=None,
    **kwargs,
):

    configs = {
        "VISUALS3_rgb2normals2x_multipercep8_winrate_standardized_upd": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            cont="mount/shared/results_LBP_multipercep8_winrate_standardized_upd_3/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_reshadetarget": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_reshadebaseline": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_depthtarget": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_depthbaseline": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_baseline": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_multipercep": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
            cont="mount/shared/results_LBP_multipercep_32/graph.pth",
        ),
        "VISUALS3_rgb2x2normals_baseline": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=False,
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2x2normals_finetuned": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x_plots2normals_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=True,
            test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_baseline": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=False,
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2x_finetuned": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=True,
            test=True, ood=True, oodfull=False,
        ),
    }

    for i in range(0, 5):

        # every config lists the same five sizes, so the first config's
        # loss_configs can drive the shared data loading for size index i
        config = next(iter(configs.values()))

        finetuned = config.get("finetuned", False)
        loss_configs = config["loss_configs"]

        loss_config = loss_configs[i]

        batch_size = batch_size or 32
        energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

        # DATA LOADING 1
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)

        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']
        ood_set = load_ood(ood_tasks, sample=4)
        print(ood_tasks)
        
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        ood = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 2
        ood_tasks = list(set([tasks.rgb] + [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']))
        test_set = load_test(ood_tasks, sample=2)
        ood_set = load_ood(ood_tasks)

        test2 = RealityTask.from_static("test", test_set, ood_tasks)
        ood2 = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 3
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)
        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']

        ood_loader = torch.utils.data.DataLoader(
            ImageDataset(tasks=ood_tasks, data_dir=f"{SHARED_DIR}/ood_images"),
            batch_size=32,
            num_workers=32, shuffle=False, pin_memory=True
        )
        data = list(itertools.islice(ood_loader, 2))
        test_set = data[0]
        ood_set = data[1]
        
        test3 = RealityTask.from_static("test", test_set, ood_tasks)
        ood3 = RealityTask.from_static("ood", ood_set, ood_tasks)

        for name, config in configs.items():

            finetuned = config.get("finetuned", False)
            loss_configs = config["loss_configs"]
            cont = config.get("cont", None)

            logger = VisdomLogger("train", env=name, delete=True if i == 0 else False)
            if config.get("test", False):                
                # GRAPH
                realities = [test, ood]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None: graph.load_weights(cont)

                # LOGGING
                energy_loss.plot_paths_errors(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_ood", delete=(i == 0))
            if config.get("ood", False):
                # GRAPH
                realities = [test2, ood2]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None: graph.load_weights(cont)

                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_oodfull", delete=True if i == 0 else False)
            if config.get("oodfull", False):

                # GRAPH
                realities = [test3, ood3]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None: graph.load_weights(cont)

                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)
Example #6
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    use_patches=False,
    patch_frac=None,
    patch_size=64,
    patch_sigma=0,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    train_step, val_step = train_step // 16, val_step // 16
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, train_subset, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID and not use_baseline: graph.load_weights(cont)
    pre_gan = pre_gan or 1
    discriminator = Discriminator(energy_loss.losses['gan'],
                                  frac=patch_frac,
                                  size=(patch_size if use_patches else 224),
                                  sigma=patch_sigma,
                                  use_patches=use_patches)
    if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')

    logger.add_hook(partial(jointplot, loss_type="gan_subset"),
                    feature="val_gan_subset",
                    freq=1)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")

        if visualize: return

        graph.train()
        discriminator.train()

        for _ in range(0, train_step):
            if epochs > pre_gan:

                train_loss = energy_loss(graph,
                                         discriminator=discriminator,
                                         realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])

                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            warmup = 5  # discriminator updates per step; was "5 if epochs < pre_gan else 1"
            for _ in range(warmup):
                y_hat = graph.sample_path([tasks.normal], reality=train_subset)
                n_x = graph.sample_path(
                    [tasks.rgb(blur_radius=6),
                     tasks.normal(blur_radius=6)],
                    reality=train)

                # scale the gradient flowing back along a path by `coeff`
                def coeff_hook(coeff):
                    def fun1(grad):
                        return coeff * grad.clone()

                    return fun1

                logit_path1 = discriminator(y_hat.detach())
                coeff = 0.1
                path_value2 = n_x * 1.0
                path_value2.register_hook(coeff_hook(coeff))
                logit_path2 = discriminator(path_value2)
                binary_label = torch.Tensor(
                    [1] * logit_path1.size(0) +
                    [0] * logit_path2.size(0)).float().cuda()
                gan_loss = nn.BCEWithLogitsLoss(reduction='mean')(torch.cat(
                    (logit_path1, logit_path2), dim=0).view(-1), binary_label)
                discriminator.discriminator.step(gan_loss)

                logger.update("train_gan_subset", gan_loss)
                logger.update("val_gan_subset", gan_loss)

                # print ("Gan loss: ", (-gan_loss).data.cpu().numpy())
                train.step()
                train_subset.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       discriminator=discriminator,
                                       realities=[val, train_subset])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        if epochs > pre_gan:
            energy_loss.logger_update(logger)
            logger.step()

            if logger.data["train_subset_val_ood : y^ -> n(~x)"][
                    -1] < best_ood_val_loss:
                best_ood_val_loss = logger.data[
                    "train_subset_val_ood : y^ -> n(~x)"][-1]
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
Example #7
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING

    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    # override the computed step counts to keep each epoch short
    train_step, val_step = 4, 4

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      freeze_list=energy_loss.freeze_list,
                      finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID: graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix=f"epoch_{epochs}")
        if visualize: return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.select_losses(val)
        if epochs != 0:
            energy_loss.logger_update(logger)
        else:
            energy_loss.metrics = {}
        logger.step()

        logger.text(f"Chosen losses: {energy_loss.chosen_losses}")
        logger.text(f"Percep winrate: {energy_loss.percep_winrate}")
        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss.values())

            graph.step(train_loss)
            train.step()

            logger.update("loss", train_loss)
Example #8
def main(
    fast=False,
    batch_size=None,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 32)
    energy_loss = get_energy_loss(config="consistency_two_path",
                                  mode="standard",
                                  **kwargs)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)

    # DATA LOADING
    video_dataset = ImageDataset(
        files=sorted(
            glob.glob(f"mount/taskonomy_house_tour/original/image*.png"),
            key=lambda x: int(os.path.basename(x)[5:-4])),
        return_tuple=True,
        resize=720,
    )
    video = RealityTask("video",
                        video_dataset, [
                            tasks.rgb,
                        ],
                        batch_size=batch_size,
                        shuffle=False)

    # GRAPHS
    graph_baseline = TaskGraph(tasks=energy_loss.tasks + [video],
                               finetuned=False)
    graph_baseline.compile(torch.optim.Adam,
                           lr=3e-5,
                           weight_decay=2e-6,
                           amsgrad=True)

    graph_finetuned = TaskGraph(tasks=energy_loss.tasks + [video],
                                finetuned=True)
    graph_finetuned.compile(torch.optim.Adam,
                            lr=3e-5,
                            weight_decay=2e-6,
                            amsgrad=True)

    graph_conservative = TaskGraph(tasks=energy_loss.tasks + [video],
                                   finetuned=True)
    graph_conservative.compile(torch.optim.Adam,
                               lr=3e-5,
                               weight_decay=2e-6,
                               amsgrad=True)
    graph_conservative.load_weights(
        f"{MODELS_DIR}/conservative/conservative.pth")

    graph_ood_conservative = TaskGraph(tasks=energy_loss.tasks + [video],
                                       finetuned=True)
    graph_ood_conservative.compile(torch.optim.Adam,
                                   lr=3e-5,
                                   weight_decay=2e-6,
                                   amsgrad=True)
    graph_ood_conservative.load_weights(
        f"{SHARED_DIR}/results_2F_grounded_1percent_gt_twopath_512_256_crop_7/graph_grounded_1percent_gt_twopath.pth"
    )

    graphs = {
        "baseline": graph_baseline,
        "finetuned": graph_finetuned,
        "conservative": graph_conservative,
        "ood_conservative": graph_ood_conservative,
    }

    inv_transform = transforms.ToPILImage()
    data = {key: {"losses": [], "zooms": []} for key in graphs}
    size = 256
    for batch in range(0, 700):

        if batch * batch_size > len(video_dataset.files): break

        # zoom schedule: shrink from 256 to 128 over the first 30% of the
        # video, return to 256 by the halfway point, then grow to 720
        frac = (batch * batch_size * 1.0) / len(video_dataset.files)
        if frac < 0.3:
            size = int(256.0 - 128 * frac / 0.3)
        elif frac < 0.5:
            size = int(128.0 + 128 * (frac - 0.3) / 0.2)
        else:
            size = int(256.0 + (720 - 256) * (frac - 0.5) / 0.5)
        # round down to a multiple of 32 so the network's strided layers align
        size = (size // 32) * 32
        print(size)
        video.step()
        video.task_data[tasks.rgb] = resize(
            video.task_data[tasks.rgb].to(DEVICE), size).data
        print(video.task_data[tasks.rgb].shape)

        with torch.no_grad():

            for i, img in enumerate(video.task_data[tasks.rgb]):
                inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                    f"mount/taskonomy_house_tour/distorted/image{batch*batch_size + i}.png"
                )

            for name, graph in graphs.items():
                normals = graph.sample_path([tasks.rgb, tasks.normal],
                                            reality=video)
                normals2 = graph.sample_path(
                    [tasks.rgb, tasks.principal_curvature, tasks.normal],
                    reality=video)

                for i, img in enumerate(normals):
                    energy, _ = tasks.normal.norm(normals[i:(i + 1)],
                                                  normals2[i:(i + 1)])
                    data[name]["losses"] += [energy.data.cpu().numpy().mean()]
                    data[name]["zooms"] += [size]
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/normals_{name}/image{batch*batch_size + i}.png"
                    )

                for i, img in enumerate(normals2):
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/path2_{name}/image{batch*batch_size + i}.png"
                    )

    pickle.dump(data, open(f"mount/taskonomy_house_tour/data.pkl", 'wb'))
    os.system("bash ~/scaling/scripts/create_vids.sh")
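
The piecewise zoom schedule can be factored into a helper to sanity-check its endpoints (hypothetical function, same arithmetic as above):

def zoom_size(frac):
    # 256 -> 128 over the first 30%, back to 256 by 50%, then up to 720
    if frac < 0.3:
        return int(256.0 - 128 * frac / 0.3)
    if frac < 0.5:
        return int(128.0 + 128 * (frac - 0.3) / 0.2)
    return int(256.0 + (720 - 256) * (frac - 0.5) / 0.5)

assert zoom_size(0.0) == 256
assert zoom_size(0.3) == 128
assert zoom_size(1.0) == 720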
Example #9
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    if not fast:
        train_step, val_step = train_step // (16 * 4), val_step // (16)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    # ood = RealityTask.from_static("ood", ood_set, [energy_loss.get_tasks("ood")])

    # GRAPH
    realities = [train, val, test, train_subset]  # ood excluded here
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    # logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')
    energy_losses = []
    mse_losses = []
    pearsonr_vals = []
    percep_losses = defaultdict(list)
    pearson_percep = defaultdict(list)

    energy_mean_by_blur = []
    energy_std_by_blur = []
    mse_mean_by_blur = []
    mse_std_by_blur = []
    for blur_size in np.arange(0, 10, 0.5):
        tasks.rgb.blur_radius = blur_size if blur_size > 0 else None
        train_subset.step()
        # energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        energy_losses = []
        mse_losses = []
        for _ in range(subset_size // batch_size):
            with torch.no_grad():
                losses = energy_loss(graph,
                                     realities=[train_subset],
                                     reduce=False)
                all_perceps = np.stack([
                    losses[loss_name].data.cpu().numpy()
                    for loss_name in losses if 'percep' in loss_name
                ])
                energy_losses += list(all_perceps.mean(0))
                mse_losses += list(losses['mse'].data.cpu().numpy())
            train_subset.step()
        mse_losses = np.array(mse_losses)
        energy_losses = np.array(energy_losses)
        logger.text(
            f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}'
        )
        logger.scatter(mse_losses, energy_losses,
                       f'mse_energy, blur = {blur_size}',
                       opts={'xlabel': 'mse', 'ylabel': 'energy'})

        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
        mse_mean_by_blur += [mse_losses.mean()]
        mse_std_by_blur += [np.std(mse_losses)]

    logger.plot(energy_mean_by_blur, 'energy_mean_by_blur')
    logger.plot(energy_std_by_blur, 'energy_std_by_blur')
    logger.plot(mse_mean_by_blur, 'mse_mean_by_blur')
    logger.plot(mse_std_by_blur, 'mse_std_by_blur')
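
The per-sample mse/energy scatters can be summarized with a Pearson correlation; a minimal standalone sketch on synthetic stand-in data:

import numpy as np
import scipy.stats

mse = np.random.rand(256)  # stand-in per-sample losses
energy = 0.8 * mse + 0.2 * np.random.rand(256)
pearsonr, p = scipy.stats.pearsonr(mse, energy)
print(f'pearsonr = {pearsonr:.3f}, p = {p:.2g}')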
Example #10
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    use_patches=False,
    patch_size=64,
    use_baseline=False,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not use_baseline and not USE_RAID:
        graph.load_weights(cont)

    pre_gan = pre_gan or 1
    discriminator = Discriminator(energy_loss.losses['gan'],
                                  size=(patch_size if use_patches else 224),
                                  use_patches=use_patches)
    # if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 80):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        discriminator.train()

        for _ in range(0, train_step):
            # graph (generator) updates start only after the GAN warm-up
            if epochs > pre_gan:
                energy_loss.train_iter += 1
                train_loss = energy_loss(graph,
                                         discriminator=discriminator,
                                         realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            # 5 discriminator updates per step during warm-up, then 1
            for _ in range(5 if epochs <= pre_gan else 1):
                train_loss2 = energy_loss(graph,
                                          discriminator=discriminator,
                                          realities=[train])
                discriminator.step(train_loss2)
                train.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       discriminator=discriminator,
                                       realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
Example #11
def main(
    loss_config="multiperceptual",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    resume=False,
    learning_rate=3e-5,
    subset_size=None,
    max_epochs=2000,
    dataaug=False,
    **kwargs,
):

    # CONFIG
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate
    })

    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam,
                  lr=learning_rate,
                  weight_decay=2e-6,
                  amsgrad=True)
    if resume:
        graph.load_weights('/workspace/shared/results_test_1/graph.pth')
        graph.optimizer.load_state_dict(
            torch.load('/workspace/shared/results_test_1/opt.pth'))

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    ######## baseline computation
    if not resume:
        graph.eval()
        with torch.no_grad():
            for _ in range(0, val_step):
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            for _ in range(0, train_step):
                train_loss, _, _ = energy_loss(graph, realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                train.step()
                logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        data = logger.step()
        del data['loss']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=0)

        path_values = energy_loss.plot_paths(graph,
                                             logger,
                                             realities,
                                             prefix="")
        for reality_paths, reality_images in path_values.items():
            wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)
    ###########

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss, coeffs, avg_grads = energy_loss(
                graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs)
        wandb.log(coeffs, step=epochs)
        wandb.log(avg_grads, step=epochs)

        if epochs % 5 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        if epochs % 10 == 0:
            path_values = energy_loss.plot_paths(graph,
                                                 logger,
                                                 realities,
                                                 prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]},
                          step=epochs + 1)

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
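
compute_grad_ratio is assumed to return per-loss coefficients; one common way to derive such coefficients is to balance loss terms by gradient magnitude. A sketch of that idea (these names are hypothetical, not this repo's API):

import torch

def balance_by_grad_magnitude(losses, shared_params):
    # losses: dict of name -> scalar loss; weight each term so gradient
    # magnitudes on the shared parameters roughly match
    mags = {}
    for name, loss in losses.items():
        grads = torch.autograd.grad(loss, shared_params,
                                    retain_graph=True, allow_unused=True)
        mags[name] = sum(float(g.abs().mean()) for g in grads if g is not None)
    mean_mag = sum(mags.values()) / len(mags)
    return {name: mean_mag / (m + 1e-12) for name, m in mags.items()}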
Example #12
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=64,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    ood_batch_size = ood_batch_size or batch_size
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature],
        return_dataset=True,
        batch_size=batch_size,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=256,
    )
    ood_consistency_dataset, _, _, _ = load_train_val(
        [
            tasks.rgb,
        ],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        [
            tasks.rgb,
            tasks.normal,
        ],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
        subset_size=subset_size,
    )

    train_step, val_step = train_step // 4, val_step // 4
    if fast: train_step, val_step = 20, 20
    test_set = load_test([tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_images = load_ood()
    ood_images_large = load_ood(resize=512, sample=8)
    ood_consistency_test = load_test([
        tasks.rgb,
    ], resize=512)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    ood_consistency = RealityTask("ood_consistency",
                                  ood_consistency_dataset,
                                  batch_size=ood_batch_size,
                                  shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               tasks=[tasks.rgb, tasks.normal],
                               batch_size=ood_batch_size,
                               shuffle=False)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, [tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_test = RealityTask.from_static("ood_test", (ood_images, ), [
        tasks.rgb,
    ])
    ood_test_large = RealityTask.from_static("ood_test_large",
                                             (ood_images_large, ), [
                                                 tasks.rgb,
                                             ])
    ood_consistency_test = RealityTask.from_static("ood_consistency_test",
                                                   ood_consistency_test, [
                                                       tasks.rgb,
                                                   ])

    realities = [
        train, val, train_subset, ood_consistency, test, ood_test,
        ood_test_large, ood_consistency_test
    ]
    energy_loss.load_realities(realities)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    # graph.load_weights(f"{MODELS_DIR}/conservative/conservative.pth")
    graph.load_weights(
        f"{SHARED_DIR}/results_2FF_train_subset_512_true_baseline_3/graph_baseline.pth"
    )
    print(graph)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth"),
        feature="epoch",
        freq=1,
    )
    graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth")
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 800):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            train.step()
            train_loss = energy_loss(graph, reality=train)
            graph.step(train_loss)
            logger.update("loss", train_loss)

            train_subset.step()
            train_subset_loss = energy_loss(graph, reality=train_subset)
            graph.step(train_subset_loss)

            ood_consistency.step()
            ood_consistency_loss = energy_loss(graph, reality=ood_consistency)
            if ood_consistency_loss is not None:
                graph.step(ood_consistency_loss)

        graph.eval()
        for _ in range(0, val_step):
            val.step()
            with torch.no_grad():
                val_loss = energy_loss(graph, reality=val)
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
Example #13
def main(
	loss_config="trainsig_edgereshade", mode="standard", visualize=False,
	fast=False, batch_size=32, learning_rate=3e-5, resume=False,
	subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):

	# CONFIG
	wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"lr":learning_rate})

	batch_size = batch_size or (4 if fast else 64)
	energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

	# DATA LOADING
	train_undist_dataset, train_dist_dataset, val_ooddist_dataset, val_dist_dataset, val_dataset, train_step, val_step = load_train_val_sig(
		energy_loss.get_tasks("val"),
		batch_size=batch_size, fast=fast,
		subset_size=subset_size,
	)
	test_set = load_test(energy_loss.get_tasks("test"))

	ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/')
	ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/')
	ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35)

	train_undist = RealityTask("train_undist", train_undist_dataset, batch_size=batch_size, shuffle=True)
	train_dist = RealityTask("train_dist", train_dist_dataset, batch_size=batch_size, shuffle=True)
	val_ooddist = RealityTask("val_ooddist", val_ooddist_dataset, batch_size=batch_size, shuffle=True)
	val_dist = RealityTask("val_dist", val_dist_dataset, batch_size=batch_size, shuffle=True)
	val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
	test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))

	ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])                                  ## standard ood set - natural
	ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,])          ## synthetic distortion images used for sig training 
	ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,])                      ## unseen syn distortions

	# GRAPH
	realities = [train_undist, train_dist, val_ooddist, val_dist, val, test, ood, ood_syn_aug, ood_syn]
	graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
		freeze_list=energy_loss.freeze_list,
	)
	graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

	if resume:
		graph.load_weights('/workspace/shared/results_test_1/graph.pth')
		graph.optimizer.load_state_dict(torch.load('/workspace/shared/results_test_1/opt.pth'))
	# else:
	# 	folder_name='/workspace/shared/results_wavelet2normal_depthreshadecurvimgnetl1perceps_0.1nll/'
	# 	# pdb.set_trace()
	# 	in_domain='wav'
	# 	out_domain='normal'
	# 	graph.load_weights(folder_name+'graph.pth', [str((in_domain, out_domain))])
	# 	create_t0_graph(folder_name,in_domain,out_domain)
	# 	graph.load_weights(folder_name+'graph_t0.pth', [str((in_domain, f'{out_domain}_t0'))])


	# LOGGING
	logger = VisdomLogger("train", env=JOB)    # fake visdom logger
	logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
	energy_loss.logger_hooks(logger)

	# BASELINE 
	if not resume:
		graph.eval()
		with torch.no_grad():
			for reality in [val_ooddist, val_dist, val]:
				for _ in range(0, val_step):
					val_loss = energy_loss(graph, realities=[reality])
					val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
					reality.step()
					logger.update("loss", val_loss)
			for reality in [train_undist, train_dist]:
				for _ in range(0, train_step):
					train_loss = energy_loss(graph, realities=[reality])
					train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
					reality.step()
					logger.update("loss", train_loss)
		
		energy_loss.logger_update(logger)
		data = logger.step()
		del data['loss']
		data = {k: v[0] for k, v in data.items()}
		wandb.log(data, step=0)

		path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
		for reality_paths, reality_images in path_values.items():
			wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)


	# TRAINING
	for epochs in range(0, max_epochs):

		logger.update("epoch", epochs)

		graph.train()
		for _ in range(0, train_step):
			train_loss_nll = energy_loss(graph, realities=[train_undist])
			train_loss_nll = sum([train_loss_nll[loss_name] for loss_name in train_loss_nll])
			train_loss_lwfsig = energy_loss(graph, realities=[train_dist])
			train_loss_lwfsig = sum([train_loss_lwfsig[loss_name] for loss_name in train_loss_lwfsig])
			train_loss = train_loss_nll + train_loss_lwfsig
			graph.step(train_loss)
			train_undist.step()
			train_dist.step()
			logger.update("loss", train_loss)

		graph.eval()
		for _ in range(0, val_step):
			with torch.no_grad():
				val_loss = energy_loss(graph, realities=[val_dist])
				val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
			val_dist.step()
			logger.update("loss", val_loss)
		
		if epochs % 20 == 0:
			for reality in [val, val_ooddist]:
				for _ in range(0, val_step):
					with torch.no_grad():
						val_loss = energy_loss(graph, realities=[reality])
						val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
					reality.step()
					logger.update("loss", val_loss)

		energy_loss.logger_update(logger)

		data = logger.step()
		del data['loss']
		data = {k: v[0] for k, v in data.items()}
		wandb.log(data, step=epochs+1)

		if epochs % 10 == 0:
			graph.save(f"{RESULTS_DIR}/graph.pth")
			torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

		if (epochs % 100 == 0) or (epochs % 15 == 0 and epochs <= 30):
			path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
			for reality_paths, reality_images in path_values.items():
				wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1)

	graph.save(f"{RESULTS_DIR}/graph.pth")
	torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")