def main():
    """Plot the rgb input next to normal-space outputs of the rgb edges.

    Builds an 'ood' reality from the image pairs under ``OOD_DIR`` and a
    task graph over it, then walks the edges in ``graph.adj[tasks.rgb]``
    ordered by destination task name. Each edge is applied to the reality
    data of its source task; when the destination is not ``tasks.normal``
    the result is pushed through the ``(dest, normal)`` edge as well. All
    results are clamped to [0, 1] and sent to Visdom as one image group,
    labelled with the comma-joined task names.
    """
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="energy",
                    freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy",
        freq=100)

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    reality = RealityTask('ood',
                          dataset=ImagePairDataset(data_dir=OOD_DIR,
                                                   resize=(256, 256)),
                          tasks=[tasks.rgb, tasks.rgb],
                          batch_size=28)

    # Alternative reality backed by the 'almena' building:
    # reality = RealityTask('almena',
    #     dataset=TaskDataset(
    #         buildings=['almena'],
    #         tasks=task_list,
    #     ),
    #     tasks=task_list,
    #     batch_size=8
    # )

    graph = TaskGraph(tasks=[reality, *task_list], batch_size=28)

    task = tasks.rgb
    images = [reality.task_data[task]]
    sources = [task.name]

    # Sort with a key instead of (name, edge) tuples: a tuple sort would fall
    # back to comparing the edge objects themselves whenever two names tie.
    for edge in sorted(graph.adj[task], key=lambda e: e.dest_task.name):
        if isinstance(edge.src_task, RealityTask):
            continue

        x = edge(reality.task_data[edge.src_task])

        # Bring every prediction into the normal task so the panels are
        # comparable; edges that already end at normal are used as-is.
        if edge.dest_task != tasks.normal:
            edge2 = graph.edge_map[(edge.dest_task.name, tasks.normal.name)]
            x = edge2(x)

        images.append(x.clamp(min=0, max=1))
        sources.append(edge.dest_task.name)

    logger.images_grouped(images, ", ".join(sources), resize=256)
def main():
    """Plot, per task, its 'almena' data beside every incoming prediction.

    For each non-reality task in the graph, the first panel is the reality
    data for that task clamped to [0, 1]. Each following panel is the output
    of one edge from ``graph.in_adj[task]`` applied to the reality data of
    that edge's source task, ordered by source task name. Edges whose source
    is a ``RealityTask`` are skipped. One Visdom image group is emitted per
    task, titled with the comma-joined task names.
    """
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="energy",
                    freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy",
        freq=100)

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    reality = RealityTask(
        'almena',
        dataset=TaskDataset(
            buildings=['almena'],
            tasks=task_list,
        ),
        tasks=task_list,
        batch_size=8,
    )

    graph = TaskGraph(tasks=[reality, *task_list], batch_size=8)

    for task in graph.tasks:
        if isinstance(task, RealityTask):
            continue

        images = [reality.task_data[task].clamp(min=0, max=1)]
        sources = [task.name]

        # Sort with a key instead of (name, edge) tuples: a tuple sort would
        # fall back to comparing the edge objects whenever two names tie.
        for edge in sorted(graph.in_adj[task],
                           key=lambda e: e.src_task.name):
            if isinstance(edge.src_task, RealityTask):
                continue

            x = edge(reality.task_data[edge.src_task])
            images.append(x.clamp(min=0, max=1))
            sources.append(edge.src_task.name)

        logger.images_grouped(images, ", ".join(sources), resize=192)
Example #3
0
def main():
    """Run 4000 free-energy / averaging iterations on an 'almena' graph.

    The reality only carries rgb, normal, principal_curvature and
    depth_zbuffer, while the graph spans the full task list. The direct
    reality edges into normal, principal_curvature and depth_zbuffer are
    excluded from the graph. Energy and estimate/path plots are reported
    through Visdom hooks.
    """
    all_tasks = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    # Tasks for which the reality supplies data. The dataset and the reality
    # each receive their own list object, as in the inline-literal form.
    observed = (
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
    )
    almena_data = TaskDataset(buildings=['almena'], tasks=list(observed))
    reality = RealityTask(
        'almena',
        dataset=almena_data,
        tasks=list(observed),
        batch_size=4,
    )

    excluded_edges = [
        ('almena', 'normal'),
        ('almena', 'principal_curvature'),
        ('almena', 'depth_zbuffer'),
        # ('rgb', 'keypoints3d'),
        # ('rgb', 'edge_occlusion'),
    ]
    graph = TaskGraph(
        tasks=[reality, *all_tasks],
        anchored_tasks=[reality, tasks.rgb],
        reality=reality,
        batch_size=4,
        edges_exclude=excluded_edges,
        initialize_first_order=True,
    )

    graph.p.compile(torch.optim.Adam, lr=4e-2)
    graph.estimates.compile(torch.optim.Adam, lr=1e-2)

    logger = VisdomLogger("train", env=JOB)

    # Hook callbacks receive (logger, data); the graph is captured by closure.
    def step_logger(logger, data):
        return logger.step()

    def plot_energy(logger, data):
        return logger.plot(data["energy"], "free_energy")

    def plot_estimates(logger, data):
        return graph.plot_estimates(logger)

    def refresh_paths(logger, data):
        return graph.update_paths(logger)

    logger.add_hook(step_logger, feature="energy", freq=16)
    logger.add_hook(plot_energy, feature="energy", freq=100)
    logger.add_hook(plot_estimates, feature="epoch", freq=32)
    logger.add_hook(refresh_paths, feature="epoch", freq=32)

    # Initial snapshot before any averaging step has run.
    graph.plot_estimates(logger)
    plotted_targets = [
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
    ]
    graph.plot_paths(logger, dest_tasks=plotted_targets, show_images=False)

    for epoch in range(4000):
        logger.update("epoch", epoch)

        # Both the energy evaluation and the averaging step run without
        # gradient tracking.
        with torch.no_grad():
            free_energy = graph.free_energy(sample=16)
            graph.averaging_step()

        logger.update("energy", free_energy)
Example #4
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    **kwargs,
):
    """Sweep the rgb blur radius and compare MSE with perceptual energy.

    Loads the weights at ``cont`` into a task graph. Then, for each blur
    radius in ``np.arange(0, 10, 0.5)``, evaluates the energy loss on
    ``subset_size // batch_size`` batches of the training subset. Per blur
    radius it logs the mean MSE and mean energy as text and a scatter of MSE
    against energy; after the sweep it plots mean and standard deviation of
    both quantities across radii. "Energy" here is the mean over all loss
    entries whose name contains ``'percep'``.

    Args:
        loss_config: Forwarded to ``get_energy_loss`` as ``config``.
        mode: Forwarded to ``get_energy_loss``.
        finetuned: Forwarded to ``TaskGraph``.
        fast: Forwarded to ``load_train_val``; also selects the default
            batch size (4 instead of 64).
        batch_size: Batch size for all realities; defaults as described above.
        subset_size: Size of the training subset; required, because it
            determines the number of batches evaluated per blur radius.
        cont: Path of the graph weights to load.
        **kwargs: Forwarded to ``get_energy_loss``.

        ``visualize``, ``pretrained``, ``ood_batch_size``, ``cont_gan``,
        ``pre_gan``, ``max_epochs`` and ``use_baseline`` are accepted but not
        read anywhere in this function.

    Raises:
        ValueError: If ``subset_size`` is None.

    Note:
        ``tasks.rgb.blur_radius`` is mutated on the shared task object and is
        left at the last radius of the sweep when the function returns.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    if subset_size is None:
        # Fail before any data is loaded; otherwise the sweep would only
        # break much later with a TypeError on `None // batch_size`.
        raise ValueError(
            "subset_size must be given: it sets the number of batches "
            "evaluated per blur radius")
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, _, _ = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    test_set = load_test(energy_loss.get_tasks("test"))
    # The ood reality is disabled below, so this result is unused; the call
    # is kept so the ood data is still loaded.
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    # ood = RealityTask.from_static("ood", ood_set, [energy_loss.get_tasks("ood")])

    # GRAPH
    realities = [train, val, test, train_subset]  # ood is intentionally absent
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    # A per-epoch validation pass (scatter of MSE against mean perceptual
    # loss plus Pearson correlation tracking) was fully commented out here
    # and is not part of this function.

    # BLUR SWEEP
    energy_mean_by_blur = []
    energy_std_by_blur = []
    mse_mean_by_blur = []
    mse_std_by_blur = []
    for blur_size in np.arange(0, 10, 0.5):
        # A radius of 0 is expressed as None on the task.
        tasks.rgb.blur_radius = blur_size if blur_size > 0 else None
        # Step once right after changing the radius, before the first batch
        # is evaluated.
        train_subset.step()

        energy_losses = []
        mse_losses = []
        for _ in range(subset_size // batch_size):
            with torch.no_grad():
                # One evaluation per batch; each entry is expected to be
                # indexable along its first axis (see mean(0) below).
                losses = energy_loss(graph,
                                     realities=[train_subset],
                                     reduce=False)
                all_perceps = np.stack([
                    losses[loss_name].data.cpu().numpy()
                    for loss_name in losses if 'percep' in loss_name
                ])
                # Average across the perceptual losses (stacked on axis 0).
                energy_losses += list(all_perceps.mean(0))
                mse_losses += list(losses['mse'].data.cpu().numpy())
            train_subset.step()

        mse_losses = np.array(mse_losses)
        energy_losses = np.array(energy_losses)
        logger.text(
            f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}'
        )
        logger.scatter(mse_losses, energy_losses,
                       f'mse_energy, blur = {blur_size}',
                       opts={'xlabel': 'mse', 'ylabel': 'energy'})

        energy_mean_by_blur.append(energy_losses.mean())
        energy_std_by_blur.append(np.std(energy_losses))
        mse_mean_by_blur.append(mse_losses.mean())
        mse_std_by_blur.append(np.std(mse_losses))

    logger.plot(energy_mean_by_blur, 'energy_mean_by_blur')
    logger.plot(energy_std_by_blur, 'energy_std_by_blur')
    logger.plot(mse_mean_by_blur, 'mse_mean_by_blur')
    logger.plot(mse_std_by_blur, 'mse_std_by_blur')
Example #5
0
        torch.save((perturbation.data, images.data, targets),
                   f"{output_path}/{k}.pth")


# Script entry point: builds the decoding model and its optimizer, wires up
# Visdom logging hooks, then iterates over randomly chosen data files.
# NOTE(review): this excerpt ends at the `for` header below; the loop body is
# not visible here.
if __name__ == "__main__":

    model = DataParallelModel(
        DecodingModel(n=DIST_SIZE, distribution=transforms.training))
    # Only the classifier and the last entry of `features` are handed to the
    # optimizer.
    params = itertools.chain(model.module.classifier.parameters(),
                             model.module.features[-1].parameters())
    optimizer = torch.optim.Adam(params, lr=2.5e-3)
    init_data("data/amnesia")

    logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB)
    # Hooks in this script take a single argument (the value of the feature).
    logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
    logger.add_hook(lambda data: logger.plot(data, "train_loss"),
                    feature="loss",
                    freq=50)
    logger.add_hook(lambda data: logger.plot(data, "train_bits"),
                    feature="bits",
                    freq=50)
    # Checkpoint hook on the "epoch" feature (freq=100).
    logger.add_hook(
        lambda x: model.save("output/train_test.pth", verbose=True),
        feature="epoch",
        freq=100)
    # Save once up front, before the loop starts.
    model.save("output/train_test.pth", verbose=True)

    files = glob.glob(f"data/amnesia/*.pth")
    # 2701 iterations, each over a file drawn at random (with replacement)
    # from `files`. The generator's own `i` is local to the generator; the
    # loop's `i` is the enumerate index.
    for i, save_file in enumerate(
            random.choice(files) for i in range(0, 2701)):
Example #6
0
from torch.autograd import Variable

from models import DecodingModel, DataParallelModel
from torchvision import models
from logger import Logger, VisdomLogger
from utils import *

import IPython

import transforms


# LOGGING
# Module-level Visdom logger for the "encoding" run.
logger = VisdomLogger(
    "encoding",
    server="35.230.67.129",
    port=8000,
    env=JOB,
)
# "epoch" hook (freq=20): step the logger; the hook argument is ignored.
logger.add_hook(
    lambda _value: logger.step(),
    feature="epoch",
    freq=20,
)
# "loss" hook (freq=50): plot the value on the "Encoding Loss" chart, ymin 0.
logger.add_hook(
    lambda loss_value: logger.plot(
        loss_value, "Encoding Loss", opts={"ymin": 0}),
    feature="loss",
    freq=50,
)


""" 
Computes the changed images, given a specified perturbation, a standard-deviation weighting,
and an epsilon.
"""


def compute_changed_images(images, perturbation, std_weights, epsilon=EPSILON):

    perturbation_w2 = perturbation * std_weights
    perturbation_zc = (
        perturbation_w2
        / perturbation_w2.view(perturbation_w2.shape[0], -1)
        .norm(2, dim=1, keepdim=True)