示例#1
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    use_optimizer=False,
    subset_size=None,
    max_epochs=800,
    **kwargs,
):
    """Train a task graph after reloading its rgb -> normal edge from disk.

    Every epoch plots the configured paths, runs the validation steps
    without gradients, runs the training steps, and then flushes the logger.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: If true, return right after the first ``plot_paths`` call.
        fast: Forwarded to ``load_train_val``; also selects the fallback
            batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        use_optimizer: If true, restore the optimizer state from the saved
            ``optimizer.pth`` file.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of epochs to run.
        **kwargs: Printed, then forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_data, val_data, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    # Only a quarter of the reported steps is run in each epoch.
    train_step //= 4
    val_step //= 4
    test_data = load_test(energy_loss.get_tasks("test"))
    ood_data = load_ood([tasks.rgb])
    print(train_step, val_step)

    train = RealityTask(
        "train", train_data, batch_size=batch_size, shuffle=True
    )
    val = RealityTask("val", val_data, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_data, energy_loss.get_tasks("test")
    )
    ood = RealityTask.from_static("ood", ood_data, [tasks.rgb])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Clear the rgb -> normal model, point the edge at the saved checkpoint,
    # and load it again before the optimizer is built.
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal).path = (
        "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth"
    )
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    if use_optimizer:
        # The file is expected to hold an object exposing ``state_dict()``.
        saved_optimizer = torch.load(
            "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth"
        )
        graph.optimizer.load_state_dict(saved_optimizer.state_dict())

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20
    )
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)
    # Only referenced by the disabled "best" plotting logic further down.
    best_ood_val_loss = float('inf')

    # TRAINING
    for epoch in range(max_epochs):
        logger.update("epoch", epoch)
        plot_prefix = "start" if epoch == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        # Validation pass (no gradients).
        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_losses = energy_loss(graph, realities=[val])
                val_loss = sum(val_losses[name] for name in val_losses)
            val.step()
            logger.update("loss", val_loss)

        # Training pass.
        graph.train()
        for _ in range(train_step):
            train_losses = energy_loss(graph, realities=[train])
            train_loss = sum(train_losses[name] for name in train_losses)

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        # if logger.data["val_mse : y^ -> n(~x)"][-1] < best_ood_val_loss:
        #     best_ood_val_loss = logger.data["val_mse : y^ -> n(~x)"][-1]
        #     energy_loss.plot_paths(graph, logger, realities, prefix="best")

        logger.step()
示例#2
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    use_patches=False,
    patch_frac=None,
    patch_size=64,
    patch_sigma=0,
    **kwargs,
):
    """Jointly train a task graph and a GAN discriminator.

    Each training step first updates the graph with the energy loss (only
    once ``epochs > pre_gan``) and then performs five discriminator updates.
    For the discriminator, samples of the ``[tasks.normal]`` path drawn from
    ``train_subset`` are labelled 1 and samples of the blurred
    ``rgb -> normal`` path drawn from ``train`` are labelled 0.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: If true, return right after the first ``plot_paths`` call.
        pretrained: Accepted but not used in this function.
        finetuned: Forwarded to ``TaskGraph``.
        fast: Forwarded to ``load_train_val``; also selects the fallback
            batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        ood_batch_size: Accepted but not used in this function.
        subset_size: Forwarded to the ``train_subset`` ``load_train_val`` call.
        cont: Graph weights loaded unless ``USE_RAID`` or ``use_baseline``
            is set.
        cont_gan: Optional discriminator weights to load.
        pre_gan: Graph updates start once ``epochs`` exceeds this value;
            falsy values fall back to 1.
        max_epochs: Number of epochs to run.
        use_baseline: If true, skip loading ``cont``.
        use_patches: Forwarded to ``Discriminator``; also selects
            ``patch_size`` instead of 224 as its ``size``.
        patch_frac: Forwarded to ``Discriminator`` as ``frac``.
        patch_size: Discriminator ``size`` when ``use_patches`` is true.
        patch_sigma: Forwarded to ``Discriminator`` as ``sigma``.
        **kwargs: Forwarded to ``get_energy_loss``.
    """

    def scale_grad(coeff):
        """Build a backward hook that multiplies the gradient by ``coeff``."""

        def hook(grad):
            return coeff * grad.clone()

        return hook

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    # Only a sixteenth of the reported steps is run in each epoch.
    train_step, val_step = train_step // 16, val_step // 16
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True
    )
    train_subset = RealityTask(
        "train_subset",
        train_subset_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test")
    )
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, train_subset, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID and not use_baseline:
        graph.load_weights(cont)
    pre_gan = pre_gan or 1
    discriminator = Discriminator(
        energy_loss.losses['gan'],
        frac=patch_frac,
        size=(patch_size if use_patches else 224),
        sigma=patch_sigma,
        use_patches=use_patches,
    )
    if cont_gan is not None:
        discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20
    )
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')

    logger.add_hook(
        partial(jointplot, loss_type="gan_subset"),
        feature="val_gan_subset",
        freq=1,
    )

    # Loop-invariant pieces of the discriminator update.
    disc_updates_per_step = 5
    graph_grad_coeff = 0.1
    # ``reduction="mean"`` replaces the deprecated ``size_average=True``.
    bce_with_logits = nn.BCEWithLogitsLoss(reduction="mean")

    # TRAINING
    for epochs in range(max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(
            graph, logger, realities, prefix="start" if epochs == 0 else ""
        )

        if visualize:
            return

        graph.train()
        discriminator.train()

        for _ in range(train_step):
            # Graph update: skipped until the epoch index exceeds pre_gan.
            if epochs > pre_gan:
                train_loss = energy_loss(
                    graph, discriminator=discriminator, realities=[train]
                )
                train_loss = sum(train_loss[name] for name in train_loss)

                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            # Discriminator updates.
            for _ in range(disc_updates_per_step):
                y_hat = graph.sample_path([tasks.normal], reality=train_subset)
                n_x = graph.sample_path(
                    [tasks.rgb(blur_radius=6), tasks.normal(blur_radius=6)],
                    reality=train,
                )

                # y_hat is detached, so this term sends no gradient into it.
                real_logits = discriminator(y_hat.detach())
                # ``n_x * 1.0`` creates a new tensor to attach the hook to;
                # the hook scales the gradient flowing back through it.
                fake_input = n_x * 1.0
                fake_input.register_hook(scale_grad(graph_grad_coeff))
                fake_logits = discriminator(fake_input)

                logits = torch.cat((real_logits, fake_logits), dim=0).view(-1)
                # Put the labels on the logits' device rather than forcing CUDA.
                labels = torch.cat((
                    torch.ones(real_logits.size(0)),
                    torch.zeros(fake_logits.size(0)),
                )).to(logits.device)
                gan_loss = bce_with_logits(logits, labels)
                discriminator.discriminator.step(gan_loss)

                # The same value is logged under both names; the jointplot
                # hook registered above listens on "val_gan_subset".
                logger.update("train_gan_subset", gan_loss)
                logger.update("val_gan_subset", gan_loss)

                train.step()
                train_subset.step()

        graph.eval()
        discriminator.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_loss = energy_loss(
                    graph,
                    discriminator=discriminator,
                    realities=[val, train_subset],
                )
                val_loss = sum(val_loss[name] for name in val_loss)
            val.step()
            logger.update("loss", val_loss)

        if epochs > pre_gan:
            energy_loss.logger_update(logger)
            logger.step()

            # Re-plot under the "best" prefix whenever the tracked metric improves.
            latest = logger.data["train_subset_val_ood : y^ -> n(~x)"][-1]
            if latest < best_ood_val_loss:
                best_ood_val_loss = latest
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
示例#3
0
def main(
    loss_config="baseline_normal",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    learning_rate=5e-4,
    subset_size=None,
    max_epochs=5000,
    dataaug=False,
    **kwargs,
):
    """Train a task graph and report metrics and path plots to wandb.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: Accepted but not used in this function.
        fast: Forwarded to ``load_train_val``; also selects the fallback
            batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        learning_rate: Initial Adam learning rate.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of epochs to run.
        dataaug: Only recorded in the wandb config here.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    # Resolve the fallback first so the recorded config holds the batch size
    # that is actually used rather than ``None``.
    batch_size = batch_size or (4 if fast else 64)
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate,
    })

    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True
    )
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test")
    )
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(
        torch.optim.Adam, lr=learning_rate, weight_decay=2e-6, amsgrad=True
    )

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20
    )
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(max_epochs):

        logger.update("epoch", epochs)

        # Plot at epochs 0, 10, 20 and then every 100 epochs.
        if (epochs % 100 == 0) or (epochs % 10 == 0 and epochs < 30):
            path_values = energy_loss.plot_paths(
                graph, logger, realities, prefix=""
            )
            for reality_paths, reality_images in path_values.items():
                wandb.log(
                    {reality_paths: [wandb.Image(reality_images)]},
                    step=epochs,
                )

        graph.train()
        for _ in range(train_step):
            train_loss = energy_loss(
                graph, realities=[train], compute_grad_ratio=True
            )
            train_loss = sum(train_loss[name] for name in train_loss)
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum(val_loss[name] for name in val_loss)
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        # Forward everything except the raw "loss"/"epoch" series to wandb,
        # using the first entry of each remaining series.
        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs)

        # save model and opt state every 10 epochs
        if epochs % 10 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        # lower lr after 1500 epochs
        if epochs == 1500:
            graph.optimizer.param_groups[0]['lr'] = 3e-5

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
示例#4
0
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    fast=False, batch_size=None,
    subset_size=None, max_epochs=800, **kwargs,
):
    """Train a task graph while multiplying the learning rate by 1.2 each epoch.

    The optimizer starts at ``lr=1e-6``. After every epoch each parameter
    group's learning rate is scaled by 1.2 and printed. Per-step loss
    logging is disabled in this variant.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: If true, return right after the first ``plot_paths`` call.
        fast: Forwarded to ``load_train_val``; also selects the fallback
            batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of epochs to run.
        **kwargs: Printed, then forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    print (kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    # The step counts returned by the loader are discarded and replaced by
    # the fixed values assigned right below.
    train_data, val_data, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    train_step = 24
    val_step = 12
    test_data = load_test(energy_loss.get_tasks("test"))
    ood_data = load_ood([tasks.rgb])
    print (train_step, val_step)

    train = RealityTask(
        "train", train_data, batch_size=batch_size, shuffle=True
    )
    val = RealityTask("val", val_data, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_data, energy_loss.get_tasks("test")
    )
    ood = RealityTask.from_static("ood", ood_data, [tasks.rgb])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=1e-6, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20
    )
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')  # not read anywhere in this function

    # TRAINING
    for epoch in range(max_epochs):
        logger.update("epoch", epoch)
        plot_prefix = "start" if epoch == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        # Validation pass (no gradients); the summed loss is not logged.
        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_losses = energy_loss(graph, realities=[val])
                val_loss = sum(val_losses[name] for name in val_losses)
            val.step()
            # logger.update("loss", val_loss)

        # Training pass; the summed loss is not logged.
        graph.train()
        for _ in range(train_step):
            train_losses = energy_loss(graph, realities=[train])
            train_loss = sum(train_losses[name] for name in train_losses)

            graph.step(train_loss)
            train.step()
            # logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        # Grow every parameter group's learning rate by 20% per epoch.
        for group in graph.optimizer.param_groups:
            group['lr'] *= 1.2
            print ("LR: ", group['lr'])

        logger.step()
示例#5
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):
    """Train a task graph, re-selecting losses after each validation pass.

    Every epoch runs four validation steps, calls
    ``energy_loss.select_losses(val)``, reports the chosen losses, and then
    runs four training steps.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: If true, return right after the first ``plot_paths`` call.
        pretrained: Accepted but not used in this function.
        finetuned: Forwarded to ``TaskGraph``.
        fast: Forwarded to ``load_train_val``; also selects the fallback
            batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        ood_batch_size: Accepted but not used in this function.
        subset_size: Forwarded to ``load_train_val``.
        cont: Graph weights loaded unless ``USE_RAID`` is set.
        max_epochs: Number of epochs to run.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_data, val_data, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_data = load_test(energy_loss.get_tasks("test"))
    ood_data = load_ood(energy_loss.get_tasks("ood"))
    # The step counts returned by the loader are replaced by fixed values.
    train_step = 4
    val_step = 4
    print(train_step, val_step)

    train = RealityTask(
        "train", train_data, batch_size=batch_size, shuffle=True
    )
    val = RealityTask("val", val_data, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_data, energy_loss.get_tasks("test")
    )
    ood = RealityTask.from_static(
        "ood", ood_data, energy_loss.get_tasks("ood")
    )

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        freeze_list=energy_loss.freeze_list,
        finetuned=finetuned,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID:
        graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20
    )
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')  # not read anywhere in this function

    # TRAINING
    for epoch in range(max_epochs):
        logger.update("epoch", epoch)
        energy_loss.plot_paths(
            graph, logger, realities, prefix=f"epoch_{epoch}"
        )
        if visualize:
            return

        # Validation pass (no gradients).
        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_losses = energy_loss(graph, realities=[val])
                val_loss = sum(val_losses[name] for name in val_losses)
            val.step()
            logger.update("loss", val_loss)

        energy_loss.select_losses(val)
        # On the first epoch the metrics are reset instead of being pushed
        # to the logger.
        if epoch == 0:
            energy_loss.metrics = {}
        else:
            energy_loss.logger_update(logger)
        logger.step()

        logger.text(f"Chosen losses: {energy_loss.chosen_losses}")
        logger.text(f"Percep winrate: {energy_loss.percep_winrate}")

        # Training pass.
        graph.train()
        for _ in range(train_step):
            loss_terms = energy_loss(graph, realities=[train])
            train_loss = sum(loss_terms.values())

            graph.step(train_loss)
            train.step()

            logger.update("loss", train_loss)
示例#6
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    use_patches=False,
    patch_size=64,
    use_baseline=False,
    max_epochs=80,
    **kwargs,
):
    """Alternate graph and discriminator updates driven by the energy loss.

    While ``epochs <= pre_gan`` only the discriminator is stepped, five
    times per training step. Afterwards each training step performs one
    graph update followed by one discriminator update.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: If true, return right after the first ``plot_paths`` call.
        pretrained: Accepted but not used in this function.
        finetuned: Forwarded to ``TaskGraph``.
        fast: Forwarded to ``load_train_val``; also selects the fallback
            batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        cont: Graph weights loaded unless ``use_baseline`` or ``USE_RAID``
            is set.
        cont_gan: Currently unused; the discriminator weight loading is
            disabled.
        pre_gan: Graph updates start once ``epochs`` exceeds this value;
            falsy values fall back to 1.
        use_patches: Forwarded to ``Discriminator``; also selects
            ``patch_size`` instead of 224 as its ``size``.
        patch_size: Discriminator ``size`` when ``use_patches`` is true.
        use_baseline: If true, skip loading ``cont``.
        max_epochs: Number of epochs to run. Defaults to 80, the value that
            used to be hard-coded.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True
    )
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test")
    )
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        finetuned=finetuned,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not use_baseline and not USE_RAID:
        graph.load_weights(cont)

    pre_gan = pre_gan or 1
    discriminator = Discriminator(
        energy_loss.losses['gan'],
        size=(patch_size if use_patches else 224),
        use_patches=use_patches,
    )
    # if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20
    )
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(
            graph, logger, realities, prefix="start" if epochs == 0 else ""
        )
        if visualize:
            return

        graph.train()
        discriminator.train()

        for _ in range(train_step):
            # Graph update: skipped until the epoch index exceeds pre_gan.
            if epochs > pre_gan:
                energy_loss.train_iter += 1
                train_loss = energy_loss(
                    graph, discriminator=discriminator, realities=[train]
                )
                train_loss = sum(train_loss[name] for name in train_loss)
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            # Discriminator updates: five per step while the graph is not
            # being trained yet, one per step afterwards.
            for _ in range(5 if epochs <= pre_gan else 1):
                train_loss2 = energy_loss(
                    graph, discriminator=discriminator, realities=[train]
                )
                # NOTE(review): the un-summed energy_loss result is passed
                # here, unlike graph.step above - confirm that
                # Discriminator.step expects this form.
                discriminator.step(train_loss2)
                train.step()

        graph.eval()
        discriminator.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_loss = energy_loss(
                    graph, discriminator=discriminator, realities=[val]
                )
                val_loss = sum(val_loss[name] for name in val_loss)
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
示例#7
0
def main(
    loss_config="multiperceptual", mode="standard", visualize=False,
    fast=False, batch_size=None,
    subset_size=None, max_epochs=5000, dataaug=False, **kwargs,
):
    """Train a task graph on merged data and report to wandb.

    Before training, the paths are plotted and one gradient-free pass over
    the validation and training realities is logged at wandb step 0. Each
    training epoch is then logged at step ``epochs + 1``.

    Args:
        loss_config: Configuration name passed to ``get_energy_loss``.
        mode: Mode string passed to ``get_energy_loss``.
        visualize: Accepted but not used in this function.
        fast: Forwarded to ``load_train_val_merging``; also selects the
            fallback batch size (4 instead of 64).
        batch_size: Batch size; a falsy value selects the fallback.
        subset_size: Forwarded to ``load_train_val_merging``.
        max_epochs: Number of epochs to run.
        dataaug: Only recorded in the wandb config here.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    # Resolve the fallback first so the recorded config holds the batch size
    # that is actually used rather than ``None``.
    batch_size = batch_size or (4 if fast else 64)
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": "3e-5",
        "n_gauss": 1,
        "distribution": "laplace",
    })

    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, val_noaug_dataset, train_step, val_step = load_train_val_merging(
        energy_loss.get_tasks("train_c"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))

    ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/')
    ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/')
    ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35)

    train = RealityTask("train_c", train_dataset, batch_size=batch_size, shuffle=True)      # distorted and undistorted
    val = RealityTask("val_c", val_dataset, batch_size=batch_size, shuffle=True)            # distorted and undistorted
    val_noaug = RealityTask("val", val_noaug_dataset, batch_size=batch_size, shuffle=True)  # no augmentation
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))

    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])                          # standard ood set - natural
    ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,])  # synthetic distortion images used for sig training
    ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,])              # unseen syn distortions

    # GRAPH
    realities = [train, val, val_noaug, test, ood, ood_syn_aug, ood_syn]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)    # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    # INITIAL EVALUATION (logged at wandb step 0, before any update)
    graph.eval()
    path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
    for reality_paths, reality_images in path_values.items():
        wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)

    with torch.no_grad():
        for reality in [val, val_noaug]:
            for _ in range(val_step):
                val_loss = energy_loss(graph, realities=[reality])
                val_loss = sum(val_loss[name] for name in val_loss)
                reality.step()
                logger.update("loss", val_loss)

        # The training reality is evaluated too, without stepping the graph.
        for _ in range(train_step):
            train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(train_loss[name] for name in train_loss)
            train.step()
            logger.update("loss", train_loss)

    energy_loss.logger_update(logger)

    # No "epoch" entry has been logged yet, so only "loss" is removed here.
    data = logger.step()
    del data['loss']
    data = {k: v[0] for k, v in data.items()}
    wandb.log(data, step=0)

    # TRAINING
    for epochs in range(max_epochs):

        logger.update("epoch", epochs)

        graph.train()
        for _ in range(train_step):
            train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(train_loss[name] for name in train_loss)
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for reality in [val, val_noaug]:
            for _ in range(val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[reality])
                    val_loss = sum(val_loss[name] for name in val_loss)
                reality.step()
                logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        # Step 0 holds the initial evaluation, so epochs are logged from 1.
        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs + 1)

        # Save model and optimizer state every 10 epochs.
        if epochs % 10 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        # Plot the paths every 25 epochs.
        if epochs % 25 == 0:
            path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs + 1)

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
示例#8
0
def main(
        fast=False,
        subset_size=None,
        early_stopping=float('inf'),
        mode='standard',
        max_epochs=800,
        **kwargs,
):
    """Run two early-stopped training phases back to back.

    Phase 1 builds an MSE energy loss between the ``f(y)`` path
    (``normal -> principal_curvature``) and the ``z^`` path
    (``principal_curvature``) and trains a fresh graph that has
    ``functional_transfers.n`` in its freeze list.

    Phase 2 builds the ``"perceptual"`` energy loss, creates a second fresh
    graph with ``functional_transfers.f`` in its freeze list, loads
    ``{RESULTS_DIR}/f.pth`` into the ``normal -> principal_curvature`` edge
    and trains again.

    Args:
        fast: Forwarded to ``load_train_val``.
        subset_size: Forwarded to ``load_train_val``.
        early_stopping: Accepted for interface compatibility only. Both
            phases use their own fixed patience (8 and 50 epochs), exactly
            as the original script did by overwriting this argument.
        mode: Forwarded to ``get_energy_loss`` for phase 2.
        max_epochs: Upper bound on the number of epochs of each phase.
        **kwargs: Forwarded to ``get_energy_loss`` for phase 2.
    """
    # Fixed patience per phase; the ``early_stopping`` argument is ignored.
    phase1_patience, phase2_patience = 8, 50

    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [
                    ("f(y)", "z^"),
                ],
            },
        },
        "plots": {
            "ID":
            dict(size=256,
                 realities=("test", "ood"),
                 paths=[
                     "y",
                     "z^",
                     "f(y)",
                 ]),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)

    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)
    realities = [train, val, test, ood]

    def train_until_plateau(graph, energy_loss, patience, metric_key,
                            save_best):
        """Train ``graph`` until ``metric_key`` stops improving.

        Runs at most ``max_epochs`` epochs of ``train_step`` optimisation
        steps followed by ``val_step`` validation steps. After each epoch
        the last value of ``logger.data[metric_key]`` is compared with the
        best value seen so far; on improvement ``save_best(graph)`` is
        called, and after ``patience`` epochs without improvement the loop
        stops.
        """
        best_val_loss, stop_idx = float('inf'), 0
        for epochs in range(0, max_epochs):

            logger.update("epoch", epochs)
            energy_loss.plot_paths(graph,
                                   logger,
                                   realities,
                                   prefix="start" if epochs == 0 else "")

            graph.train()
            for _ in range(0, train_step):
                train_loss = energy_loss(graph, realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])

                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            graph.eval()
            for _ in range(0, val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[val])
                    val_loss = sum(
                        [val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            energy_loss.logger_update(logger)
            logger.step()

            stop_idx += 1
            latest_val_loss = logger.data[metric_key][-1]
            if latest_val_loss < best_val_loss:
                print("Better val loss, reset stop_idx: ", stop_idx)
                best_val_loss, stop_idx = latest_val_loss, 0
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
                save_best(graph)

            if stop_idx >= patience:
                print("Stopping training now")
                break

    # ---- PHASE 1 ----
    # GRAPH
    graph = TaskGraph(
        tasks=task_list + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    # TRAINING
    train_until_plateau(
        graph,
        energy_loss,
        patience=phase1_patience,
        metric_key="val_mse : f(y) -> z^",
        save_best=lambda g: g.save(weights_dir=f"{RESULTS_DIR}"),
    )

    # ---- PHASE 2 ----
    # CONFIG
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)

    # GRAPH
    graph = TaskGraph(
        tasks=task_list + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    # Presumably f.pth is one of the files written by phase 1's
    # ``graph.save(weights_dir=...)`` -- TODO confirm against TaskGraph.save.
    graph.edge(
        tasks.normal,
        tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    # NOTE(review): the logger from phase 1 is reused and a second "loss"
    # hook is registered on it; confirm that stacking the hooks is intended.
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    # TRAINING
    train_until_plateau(
        graph,
        energy_loss,
        patience=phase2_patience,
        metric_key="val_mse : n(x) -> y^",
        save_best=lambda g: g.save(f"{RESULTS_DIR}/graph.pth"),
    )
示例#9
0
def main(
    loss_config="multiperceptual", mode="winrate", visualize=False,
    fast=False, batch_size=None,
    subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):
    """Train a task graph with an energy loss, logging to Visdom.

    The function first runs a no-gradient baseline pass over the validation
    and training realities, then alternates training and validation for up
    to ``max_epochs`` epochs. The graph is saved to
    ``{RESULTS_DIR}/graph.pth`` through an ``"epoch"`` logger hook.

    Args:
        loss_config: Config name passed to ``get_energy_loss``.
        mode: Mode passed to ``get_energy_loss``.
        visualize: If truthy, return right after the first in-loop call to
            ``plot_paths`` (no training epoch is run).
        fast: Use a batch size of 4 by default, train on the validation
            dataset, run 2 train / 2 val steps, and skip the test reality.
        batch_size: Explicit batch size; falls back to 4 (fast) or 64.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of training epochs.
        dataaug: Forwarded to ``load_train_val``.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )

    if fast:
        # Smoke-test mode: reuse the validation data for training and run
        # only a couple of steps per phase.
        train_dataset = val_dataset
        train_step, val_step = 2, 2

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)

    if fast:
        realities = [train, val]
    else:
        test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville'])
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        realities = [train, val, test]
        # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:
        # ood_set = load_ood(energy_loss.get_tasks("ood"))
        # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
        # realities.append(ood)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
        freeze_list=energy_loss.freeze_list,
        initialize_from_transfer=False,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(RESULTS_DIR, exist_ok=True)
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE
    # Evaluate the untouched graph on 4x the usual number of val/train steps;
    # no optimiser step is taken here.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step*4):
            # energy_loss returns a pair here; only the loss mapping is used.
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        for _ in range(0, train_step*4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # The second element of the returned pair is not needed here.
            train_loss, _ = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        logger.step()
示例#10
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    max_epochs=800,
    **kwargs,
):
    """Train with two "triangle" losses, always optimising the worse one.

    For every training step the triangle whose stored energy is currently
    the largest is selected and used as the only loss type. Its energy is
    then refreshed as the ratio between the latest consistency metric and
    the latest error metric recorded in ``energy_loss.metrics["train"]``.
    Validation uses the set of triangles that were selected during the
    epoch.

    Args:
        loss_config: Config name passed to ``get_energy_loss``.
        mode: Mode passed to ``get_energy_loss``.
        visualize: If truthy, return right after the first ``plot_paths``.
        fast: Forwarded to ``load_train_val``; also selects a default
            batch size of 4 instead of 64.
        batch_size: Explicit batch size.
        max_epochs: Number of training epochs.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast)
    # Only a quarter of the reported steps is run per epoch.
    train_step = train_step // 4
    val_step = val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print("Train step: ", train_step, "Val step: ", val_step)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=True,
        freeze_list=[functional_transfers.a, functional_transfers.RC],
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)

    # For each triangle: (consistency metric name, error metric name).
    triangle_metric_names = {
        "triangle1_mse": ("triangle1_mse : F(RC(x)) -> n(x)",
                          "triangle1_mse : n(x) -> y^"),
        "triangle2_mse": ("triangle2_mse : S(a(x)) -> n(x)",
                          "triangle2_mse : n(x) -> y^"),
    }
    # Both energies start at infinity, so each triangle is picked at least
    # once before finite energies start to drive the selection.
    triangle_energy = {name: float('inf') for name in triangle_metric_names}
    activated_triangles = set()

    logger.add_hook(
        partial(jointplot, loss_type="energy"), feature="val_energy", freq=1)

    # TRAINING
    for epoch in range(0, max_epochs):

        logger.update("epoch", epoch)
        plot_prefix = "start" if epoch == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            # Pick the triangle with the highest current energy.
            active_triangle = max(triangle_energy, key=triangle_energy.get)
            activated_triangles.add(active_triangle)

            losses = energy_loss(
                graph, realities=[train], loss_types=[active_triangle])
            train_loss = sum([losses[loss_name] for loss_name in losses])

            graph.step(train_loss)
            train.step()

            # Refresh the energy of the triangle that was just trained.
            consistency_name, error_name = triangle_metric_names[active_triangle]
            consistency = energy_loss.metrics["train"][consistency_name][-1]
            error = energy_loss.metrics["train"][error_name][-1]
            triangle_energy[active_triangle] = float(consistency / error)

            print("Triangle energy: ", triangle_energy)
            logger.update("loss", train_loss)

            total_energy = sum(triangle_energy.values())
            if total_energy < float('inf'):
                # The same value is reported under both feature names.
                logger.update("train_energy", total_energy)
                logger.update("val_energy", total_energy)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                losses = energy_loss(
                    graph,
                    realities=[val],
                    loss_types=list(activated_triangles),
                )
                val_loss = sum([losses[loss_name] for loss_name in losses])

            val.step()
            logger.update("loss", val_loss)

        # Start the next epoch with an empty selection record.
        activated_triangles = set()
        energy_loss.logger_update(logger)
        logger.step()
示例#11
0
def main(
    loss_config="multiperceptual",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    resume=False,
    learning_rate=3e-5,
    subset_size=None,
    max_epochs=2000,
    dataaug=False,
    **kwargs,
):
    """Train a task graph with an energy loss and log metrics to wandb.

    Unless ``resume`` is set, a no-gradient baseline pass over the val and
    train realities is logged at wandb step 0 first. Each epoch then runs
    validation followed by training, logs the scalar metrics together with
    the loss coefficients and average gradients returned by the energy
    loss, and periodically checkpoints the graph and optimiser.

    Args:
        loss_config: Config name passed to ``get_energy_loss``.
        mode: Mode passed to ``get_energy_loss``.
        visualize: Unused in this function.
        fast: Forwarded to ``load_train_val``; also selects a default batch
            size of 4 instead of 64.
        batch_size: Explicit batch size.
        resume: Load graph weights and optimiser state from
            ``/workspace/shared/results_test_1`` and skip the baseline pass.
        learning_rate: Learning rate given to the Adam optimiser and
            recorded in the wandb config.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of training epochs.
        dataaug: Recorded in the wandb config.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    # Resolve the default first so that wandb records the batch size that is
    # actually used rather than ``None``.
    batch_size = batch_size or (4 if fast else 64)
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate
    })

    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    # NOTE(review): ``dataaug`` is logged above but not forwarded to
    # ``load_train_val`` here -- confirm whether that is intended.
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Honour the ``learning_rate`` argument (its default equals the value
    # that used to be hard-coded here).
    graph.compile(torch.optim.Adam,
                  lr=learning_rate,
                  weight_decay=2e-6,
                  amsgrad=True)
    if resume:
        graph.load_weights('/workspace/shared/results_test_1/graph.pth')
        graph.optimizer.load_state_dict(
            torch.load('/workspace/shared/results_test_1/opt.pth'))

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    ######## baseline computation
    if not resume:
        graph.eval()
        with torch.no_grad():
            for _ in range(0, val_step):
                # energy_loss returns a triple; only the losses are needed.
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            for _ in range(0, train_step):
                train_loss, _, _ = energy_loss(graph, realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                train.step()
                logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        data = logger.step()
        del data['loss']
        # Each entry is indexed with [0] to obtain the value sent to wandb.
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=0)

        path_values = energy_loss.plot_paths(graph,
                                             logger,
                                             realities,
                                             prefix="")
        for reality_paths, reality_images in path_values.items():
            wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)
    ###########

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)

        # Validation runs before training, so the val metrics logged for
        # epoch ``e`` describe the graph after ``e`` epochs of training.
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss, coeffs, avg_grads = energy_loss(
                graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        # NOTE(review): scalars are logged at ``step=epochs`` while images
        # below use ``step=epochs + 1``; epoch 0 therefore shares wandb step
        # 0 with the baseline pass. Confirm this offset is intended.
        wandb.log(data, step=epochs)
        # Only the coefficients / gradients of the last train step are kept.
        wandb.log(coeffs, step=epochs)
        wandb.log(avg_grads, step=epochs)

        if epochs % 5 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        if epochs % 10 == 0:
            path_values = energy_loss.plot_paths(graph,
                                                 logger,
                                                 realities,
                                                 prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]},
                          step=epochs + 1)

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
示例#12
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=64,
    **kwargs,
):
    """Train on three realities per step: train, train subset and OOD.

    Every training iteration takes one optimiser step on the ``train``
    reality, one on ``train_subset`` and, when the energy loss returns a
    value for it, one on ``ood_consistency``. Training runs for 800 epochs
    and the graph is saved to ``{RESULTS_DIR}/graph_{loss_config}.pth``
    through an ``"epoch"`` logger hook.

    Args:
        loss_config: Config name passed to ``get_energy_loss``; also part
            of the checkpoint file name.
        mode: Mode passed to ``get_energy_loss``.
        visualize: If truthy, return right after the first ``plot_paths``.
        pretrained: Unused in this function.
        finetuned: Forwarded to ``TaskGraph``.
        fast: Restrict the datasets to the ``"almena"`` building, use a
            default batch size of 4 and 20 train / 20 val steps.
        batch_size: Explicit batch size for ``train`` and ``val``.
        ood_batch_size: Batch size for ``ood_consistency`` and
            ``train_subset``; defaults to ``batch_size``.
        subset_size: Forwarded to ``load_train_val`` for the train subset.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    if not ood_batch_size:
        ood_batch_size = batch_size
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature],
        return_dataset=True,
        batch_size=batch_size,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=256,
    )
    # Only the first returned element (the train dataset) is kept below.
    ood_consistency_dataset = load_train_val(
        [tasks.rgb],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
    )[0]
    train_subset_dataset = load_train_val(
        [tasks.rgb, tasks.normal],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
        subset_size=subset_size,
    )[0]

    train_step = train_step // 4
    val_step = val_step // 4
    if fast:
        train_step, val_step = 20, 20

    test_set = load_test([tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_images = load_ood()
    ood_images_large = load_ood(resize=512, sample=8)
    ood_consistency_test_set = load_test([tasks.rgb], resize=512)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    ood_consistency = RealityTask(
        "ood_consistency",
        ood_consistency_dataset,
        batch_size=ood_batch_size,
        shuffle=True,
    )
    train_subset = RealityTask(
        "train_subset",
        train_subset_dataset,
        tasks=[tasks.rgb, tasks.normal],
        batch_size=ood_batch_size,
        shuffle=False,
    )
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, [tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_test = RealityTask.from_static(
        "ood_test", (ood_images, ), [tasks.rgb])
    ood_test_large = RealityTask.from_static(
        "ood_test_large", (ood_images_large, ), [tasks.rgb])
    ood_consistency_test = RealityTask.from_static(
        "ood_consistency_test", ood_consistency_test_set, [tasks.rgb])

    realities = [
        train,
        val,
        train_subset,
        ood_consistency,
        test,
        ood_test,
        ood_test_large,
        ood_consistency_test,
    ]
    energy_loss.load_realities(realities)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    baseline_weights = (
        f"{SHARED_DIR}/results_2FF_train_subset_512_true_baseline_3/graph_baseline.pth"
    )
    graph.load_weights(baseline_weights)
    print(graph)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth"),
        feature="epoch",
        freq=1,
    )
    # Write an initial checkpoint before any training happens.
    graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth")
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epoch in range(0, 800):

        logger.update("epoch", epoch)
        plot_prefix = "start" if epoch == 0 else ""
        energy_loss.plot_paths(graph, logger, prefix=plot_prefix)
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            # Each reality is advanced before its loss is computed.
            train.step()
            train_loss = energy_loss(graph, reality=train)
            graph.step(train_loss)
            logger.update("loss", train_loss)

            train_subset.step()
            subset_loss = energy_loss(graph, reality=train_subset)
            graph.step(subset_loss)

            ood_consistency.step()
            consistency_loss = energy_loss(graph, reality=ood_consistency)
            # The energy loss may return None for this reality.
            if consistency_loss is not None:
                graph.step(consistency_loss)

        graph.eval()
        for _ in range(0, val_step):
            val.step()
            with torch.no_grad():
                val_loss = energy_loss(graph, reality=val)
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
示例#13
0
def main(
    loss_config="baseline", mode="standard", visualize=False,
    fast=False, batch_size=None, path=None,
    subset_size=None, early_stopping=float('inf'),
    max_epochs=800, **kwargs
):
    """Train a graph whose rgb -> normal edge is loaded from ``path``.

    The rgb -> normal edge's model is cleared, its ``path`` attribute is set
    to ``path`` and ``load_model()`` is called on it before the graph is
    compiled. Training then runs with early stopping on the
    ``"val_mse : n(x) -> y^"`` metric; weights are saved into
    ``RESULTS_DIR`` once up front and on every improvement.

    Args:
        loss_config: Config name passed to ``get_energy_loss``.
        mode: Mode passed to ``get_energy_loss``.
        visualize: If truthy, return right after the first ``plot_paths``.
        fast: Forwarded to ``load_train_val``; also selects a default batch
            size of 4 instead of 64.
        batch_size: Explicit batch size.
        path: Value assigned to the rgb -> normal edge's ``path`` attribute.
        subset_size: Forwarded to ``load_train_val``.
        early_stopping: Number of epochs without improvement after which
            training stops.
        max_epochs: Upper bound on the number of epochs.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    val_metric = "val_mse : n(x) -> y^"

    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    print('tasks', energy_loss.get_tasks("ood"))

    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Drop the edge's model, point the edge at ``path`` and reload it.
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal).path = path
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)
    graph.save(weights_dir=f"{RESULTS_DIR}")

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    best_val_loss = float('inf')
    stop_idx = 0

    # TRAINING
    for epoch in range(0, max_epochs):

        logger.update("epoch", epoch)
        plot_prefix = "start" if epoch == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            losses = energy_loss(graph, realities=[train])
            train_loss = sum([losses[loss_name] for loss_name in losses])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                losses = energy_loss(graph, realities=[val])
                val_loss = sum([losses[loss_name] for loss_name in losses])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        # Early stopping: count epochs since the last improvement.
        stop_idx += 1
        if logger.data[val_metric][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss = logger.data[val_metric][-1]
            stop_idx = 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            return
示例#14
0
def main(
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):
    """Train on randomly sampled length-3 task paths for 20 rounds.

    Each round enumerates every permutation of 1, 2 and 3 tasks that
    corresponds to existing graph edges, samples up to three of the
    length-3 paths, builds an ``EnergyLoss`` from them and trains for five
    epochs. After every epoch the ``"eval_mse"`` loss over all length-3
    paths is computed on the val reality and reported as the energy.

    Args:
        pretrained: Unused in this function.
        finetuned: Forwarded to ``TaskGraph``.
        fast: Forwarded to ``load_train_val``; also selects a default batch
            size of 4 instead of 64.
        batch_size: Explicit batch size.
        cont: Unused in this function.
        max_epochs: Unused in this function.
        **kwargs: Unused in this function.
    """
    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
        # tasks.sobel_edges,
    ]

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
    )
    test_set = load_test(task_list)
    train_step, val_step = train_step // 4, val_step // 4
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test]
    graph = TaskGraph(tasks=task_list + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    logger.add_hook(partial(jointplot, loss_type="energy"),
                    feature="val_energy",
                    freq=1)

    def path_name(path):
        """Return the nested transfer name of ``path``, or ``None``.

        A single-task path is named ``str(path[0])``. A longer path is
        named ``"<transfer>(<name of path[:-1]>)"``. ``None`` is returned
        as soon as one hop ``(path[-2].name, path[-1].name)`` is missing
        from ``graph.edge_map``.
        """
        if len(path) == 1:
            print(path)
            return str(path[0])
        if str((path[-2].name, path[-1].name)) not in graph.edge_map:
            return None
        sub_name = path_name(path[:-1])
        if sub_name is None: return None
        return f'{get_transfer_name(Transfer(path[-2], path[-1]))}({sub_name})'

    for i in range(0, 20):

        gt_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 1)
            if path_name(path) is not None
        }
        baseline_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 2)
            if path_name(path) is not None
        }
        paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 3)
            if path_name(path) is not None
        }
        # random.sample needs a sequence: dict views raise TypeError on
        # Python 3.11+. Cap k so a graph with fewer than three valid
        # length-3 paths does not raise ValueError.
        candidates = list(paths.items())
        selected_paths = dict(
            random.sample(candidates, k=min(3, len(candidates))))

        print("Chosen paths: ", selected_paths)

        # Each loss pairs a path with the ground-truth path named after
        # the path's final task.
        loss_config = {
            "paths": {
                **gt_paths,
                **baseline_paths,
                **paths
            },
            "losses": {
                "baseline_mse": {
                    (
                        "train",
                        "val",
                    ): [(name, str(path[-1]))
                        for name, path in baseline_paths.items()]
                },
                "mse": {
                    (
                        "train",
                        "val",
                    ): [(name, str(path[-1]))
                        for name, path in selected_paths.items()]
                },
                "eval_mse": {
                    ("val", ): [(name, str(path[-1]))
                                for name, path in paths.items()]
                }
            },
            "plots": {
                "ID":
                dict(size=256,
                     realities=("test", ),
                     paths=[name for name in selected_paths] + [
                         str(path[-1]) for path in selected_paths.values()
                     ]),
            },
        }

        energy_loss = EnergyLoss(**loss_config)
        energy_loss.logger_hooks(logger)

        # TRAINING
        for epochs in range(0, 5):

            logger.update("epoch", epochs)
            energy_loss.plot_paths(graph, logger, realities, prefix="")

            graph.train()
            for _ in range(0, train_step):
                train_loss = energy_loss(graph,
                                         realities=[train],
                                         loss_types=["mse", "baseline_mse"])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])

                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            graph.eval()
            for _ in range(0, val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph,
                                           realities=[val],
                                           loss_types=["mse", "baseline_mse"])
                    val_loss = sum(
                        [val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            # The evaluation-only energy is never back-propagated, so it is
            # computed without gradient tracking like the val loop above.
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       realities=[val],
                                       loss_types=["eval_mse"])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()

            # The same value is reported under both feature names.
            logger.update("train_energy", val_loss)
            logger.update("val_energy", val_loss)

            energy_loss.logger_update(logger)
            logger.step()
示例#15
0
def main(
    job_config="jobinfo.txt",
    models_dir="models",
    fast=False,
    batch_size=None,
    subset_size=None,
    max_epochs=500,
    dataaug=False,
    **kwargs,
):
    """Train a ``TaskGraph`` against an energy loss, logging to Visdom.

    Job settings are read from ``config/<job_config>`` next to this file
    when that file exists. The file must hold exactly five fields, each
    separated by a comma followed by a newline: loss config, loss mode,
    model class, experiment name and base directory. Any setting that is
    absent (no file) or empty falls back to the module-level
    ``LOSS_CONFIG``, ``LOSS_MODE``, ``MODEL_CLASS``, ``EXPERIMENT`` and
    ``BASE_DIR``.

    Args:
        job_config: File name of the job description inside ``config/``.
        models_dir: Directory (relative to the base directory) handed to
            ``TaskGraph`` as ``models_dir``.
        fast: Smoke-test mode. Defaults the batch size to 4, trains on
            the validation dataset and runs 2 train / 2 val steps.
        batch_size: Batch size; ``None`` selects 4 (``fast``) or 64.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of training epochs.
        dataaug: Forwarded to ``load_train_val``.
        **kwargs: Forwarded to ``get_energy_loss``.

    Raises:
        ValueError: If the job config file exists but does not contain
            exactly five fields.
    """
    loss_config, loss_mode, model_class = None, None, None
    experiment, base_dir = None, None
    current_dir = os.path.dirname(__file__)
    job_config = os.path.normpath(
        os.path.join(current_dir, "config", job_config))
    if os.path.isfile(job_config):
        with open(job_config) as config_file:
            fields = config_file.read().strip().split(',\n')
        # Fail with a message that names the file and the expected layout
        # instead of an anonymous tuple-unpacking error.
        if len(fields) != 5:
            raise ValueError(
                f"{job_config}: expected 5 fields (loss_config, loss_mode, "
                f"model_class, experiment, base_dir) separated by a comma "
                f"and a newline, got {len(fields)}")
        loss_config, loss_mode, model_class, experiment, base_dir = fields
    loss_config = loss_config or LOSS_CONFIG
    loss_mode = loss_mode or LOSS_MODE
    model_class = model_class or MODEL_CLASS
    base_dir = base_dir or BASE_DIR

    base_dir = os.path.normpath(os.path.join(current_dir, base_dir))
    experiment = experiment or EXPERIMENT
    # The Visdom environment name is the experiment name minus its last
    # "_"-separated component.
    job = "_".join(experiment.split("_")[0:-1])

    models_dir = os.path.join(base_dir, models_dir)
    results_dir = f"{base_dir}/results/results_{experiment}"
    results_dir_models = f"{base_dir}/results/results_{experiment}/models"

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config,
                                  loss_mode=loss_mode,
                                  **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )

    if fast:
        # Smoke test: reuse the validation data for training and run only
        # two steps of each phase.
        train_dataset = val_dataset
        train_step, val_step = 2, 2

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)

    test_set = load_test(energy_loss.get_tasks("test"),
                         buildings=['almena', 'albertville', 'espanola'])
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])
    realities = [train, val, test, ood]

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      tasks_in=energy_loss.tasks_in,
                      tasks_out=energy_loss.tasks_out,
                      pretrained=True,
                      models_dir=models_dir,
                      freeze_list=energy_loss.freeze_list,
                      direct_edges=energy_loss.direct_edges,
                      model_class=model_class)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(results_dir_models, exist_ok=True)
    logger = VisdomLogger("train", env=job, port=PORT, server=SERVER)
    # Step the logger on every 20th "loss" update.
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    # Checkpoint the graph on every "epoch" update.
    logger.add_hook(lambda _, __: graph.save(f"{results_dir}/graph.pth",
                                             results_dir_models),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE
    # Evaluate the untrained graph on val and train data (no gradients).
    # NOTE(review): the reason for the x4 step multiplier is not visible
    # here -- confirm before changing it.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step * 4):
            val_loss, _unused = energy_loss(graph, realities=[val])
            val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
            val.step()
            logger.update("loss", val_loss)

        for _ in range(0, train_step * 4):
            train_loss, _unused = energy_loss(graph, realities=[train])
            train_loss = sum(
                train_loss[loss_name] for loss_name in train_loss)
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="finish")

        graph.train()
        for _ in range(0, train_step):
            train_loss, grad_mse_coeff = energy_loss(graph,
                                                     realities=[train],
                                                     compute_grad_ratio=True)
            # The optimizer step consumes the per-loss mapping, so it must
            # run before the mapping is collapsed into a single scalar.
            graph.step(train_loss,
                       losses=energy_loss.losses,
                       paths=energy_loss.paths)
            train_loss = sum(
                train_loss[loss_name] for loss_name in train_loss)
            train.step()
            logger.update("loss", train_loss)
            del train_loss
        # grad_mse_coeff is only bound once a training step has run;
        # printing it unconditionally raised UnboundLocalError when
        # train_step was 0.
        if train_step > 0:
            print(grad_mse_coeff)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _unused = energy_loss(graph, realities=[val])
                val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
示例#16
0
def main(
	loss_config="trainsig_edgereshade", mode="standard", visualize=False,
	fast=False, batch_size=32, learning_rate=3e-5, resume=False,
	subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):
	"""Train a ``TaskGraph`` on undistorted and distorted data, logging to wandb.

	Each training step sums the energy loss on the ``train_undist`` and
	``train_dist`` realities and takes one optimizer step on the total.
	Validation runs on ``val_dist`` every epoch and additionally on ``val``
	and ``val_ooddist`` every 20th epoch. Graph weights and optimizer state
	are written under ``RESULTS_DIR`` every 10th epoch and once more after
	the last epoch.

	Args:
		loss_config: Energy-loss configuration name for ``get_energy_loss``.
		mode: Forwarded to ``get_energy_loss``.
		visualize: Not used by this function.
		fast: Forwarded to ``load_train_val_sig``; also selects the
			fallback batch size (4 instead of 64) when ``batch_size`` is falsy.
		batch_size: Batch size; a falsy value selects the fallback above.
		learning_rate: Adam learning rate used to compile the graph.
		resume: If true, load graph weights and optimizer state and skip
			the baseline evaluation.
		subset_size: Forwarded to ``load_train_val_sig``.
		max_epochs: Number of training epochs.
		dataaug: Not used by this function.
		**kwargs: Forwarded to ``get_energy_loss``.
	"""

	# CONFIG
	# Resolve the fallback first so that wandb records the batch size that is
	# actually used rather than a falsy placeholder.
	batch_size = batch_size or (4 if fast else 64)
	wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"lr":learning_rate})

	energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

	# DATA LOADING
	train_undist_dataset, train_dist_dataset, val_ooddist_dataset, val_dist_dataset, val_dataset, train_step, val_step = load_train_val_sig(
		energy_loss.get_tasks("val"),
		batch_size=batch_size, fast=fast,
		subset_size=subset_size,
	)
	test_set = load_test(energy_loss.get_tasks("test"))

	ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/')
	ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/')
	ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35)

	train_undist = RealityTask("train_undist", train_undist_dataset, batch_size=batch_size, shuffle=True)
	train_dist = RealityTask("train_dist", train_dist_dataset, batch_size=batch_size, shuffle=True)
	val_ooddist = RealityTask("val_ooddist", val_ooddist_dataset, batch_size=batch_size, shuffle=True)
	val_dist = RealityTask("val_dist", val_dist_dataset, batch_size=batch_size, shuffle=True)
	val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
	test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))

	ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])                                  ## standard ood set - natural
	ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,])          ## synthetic distortion images used for sig training
	ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,])                      ## unseen syn distortions

	# GRAPH
	realities = [train_undist, train_dist, val_ooddist, val_dist, val, test, ood, ood_syn_aug, ood_syn]
	graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
		freeze_list=energy_loss.freeze_list,
	)
	# Use the caller-supplied learning rate; it was previously logged to wandb
	# but ignored in favour of a hard-coded 3e-5 (which is still the default).
	graph.compile(torch.optim.Adam, lr=learning_rate, weight_decay=2e-6, amsgrad=True)

	if resume:
		# NOTE(review): resuming reads from a fixed directory, while the
		# checkpoints below are written under RESULTS_DIR -- confirm that
		# this is intended.
		graph.load_weights('/workspace/shared/results_test_1/graph.pth')
		graph.optimizer.load_state_dict(torch.load('/workspace/shared/results_test_1/opt.pth'))
	# else:
	# 	folder_name='/workspace/shared/results_wavelet2normal_depthreshadecurvimgnetl1perceps_0.1nll/'
	# 	# pdb.set_trace()
	# 	in_domain='wav'
	# 	out_domain='normal'
	# 	graph.load_weights(folder_name+'graph.pth', [str((in_domain, out_domain))])
	# 	create_t0_graph(folder_name,in_domain,out_domain)
	# 	graph.load_weights(folder_name+'graph_t0.pth', [str((in_domain, f'{out_domain}_t0'))])


	# LOGGING
	logger = VisdomLogger("train", env=JOB)    # fake visdom logger
	logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
	energy_loss.logger_hooks(logger)

	def summed_loss(reality):
		# Energy loss on one reality, collapsed from a per-name mapping
		# into a single value.
		losses = energy_loss(graph, realities=[reality])
		return sum(losses[loss_name] for loss_name in losses)

	def log_scalars(step):
		# Step the logger and forward its data to wandb, dropping the
		# "loss" entry and keeping element 0 of every other entry.
		data = logger.step()
		del data['loss']
		data = {k: v[0] for k, v in data.items()}
		wandb.log(data, step=step)

	def log_paths(step):
		# Plot the configured paths and upload each result as a wandb image.
		path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
		for reality_paths, reality_images in path_values.items():
			wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=step)

	# BASELINE
	if not resume:
		graph.eval()
		with torch.no_grad():
			for reality in [val_ooddist,val_dist,val]:
				for _ in range(0, val_step):
					val_loss = summed_loss(reality)
					reality.step()
					logger.update("loss", val_loss)
			for reality in [train_undist,train_dist]:
				for _ in range(0, train_step):
					train_loss = summed_loss(reality)
					reality.step()
					logger.update("loss", train_loss)

		energy_loss.logger_update(logger)
		log_scalars(0)
		log_paths(0)


	# TRAINING
	for epochs in range(0, max_epochs):

		logger.update("epoch", epochs)

		graph.train()
		for _ in range(0, train_step):
			train_loss_nll = summed_loss(train_undist)
			train_loss_lwfsig = summed_loss(train_dist)
			train_loss = train_loss_nll+train_loss_lwfsig
			graph.step(train_loss)
			train_undist.step()
			train_dist.step()
			logger.update("loss", train_loss)

		graph.eval()
		for _ in range(0, val_step):
			with torch.no_grad():
				val_loss = summed_loss(val_dist)
			val_dist.step()
			logger.update("loss", val_loss)

		if epochs % 20 == 0:
			for reality in [val,val_ooddist]:
				for _ in range(0, val_step):
					with torch.no_grad():
						val_loss = summed_loss(reality)
					reality.step()
					logger.update("loss", val_loss)

		energy_loss.logger_update(logger)

		# wandb step 0 is reserved for the baseline, so epochs log at epochs+1.
		log_scalars(epochs+1)

		if epochs % 10 == 0:
			graph.save(f"{RESULTS_DIR}/graph.pth")
			torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")

		if (epochs % 100 == 0) or (epochs % 15 == 0 and epochs <= 30):
			log_paths(epochs+1)



	graph.save(f"{RESULTS_DIR}/graph.pth")
	torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")