예제 #1
0
def run_eval_suite(name,
                   dest_task=tasks.normal,
                   graph_file=None,
                   model_file=None,
                   logger=None,
                   sample=800,
                   show_images=False,
                   old=False):
    """Load a model predicting ``dest_task`` and report its validation metrics.

    The model is chosen by the first option that applies:

    * ``graph_file``: weights for a ``TaskGraph``; its ``rgb -> dest_task``
      edge is used.
    * ``old``: ``model_file`` loaded into the legacy ``UNetOld``.
    * ``model_file``: loaded into ``UNet(downsample=6)``.
    * otherwise: the default ``Transfer(normal -> dest_task)`` model.

    Args:
        name: Label prefixed to the reported result.
        dest_task: Task whose predictions are evaluated.
        graph_file: Optional path to saved graph weights.
        model_file: Optional path to a standalone checkpoint.
        logger: Object with a ``text(str)`` method. When ``None``, the
            result line is printed instead.
        sample: Sample count forwarded to ``ValidationMetrics.evaluate``.
        show_images: Accepted for compatibility; not used.
        old: Load ``model_file`` into the legacy ``UNetOld`` architecture.

    Returns:
        Whatever ``ValidationMetrics.evaluate`` returns.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        # The checkpoint must match this architecture (downsample=5 was used
        # by earlier runs).
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        # NOTE(review): sibling suites build the fallback from tasks.rgb;
        # normal -> dest_task here looks suspicious -- confirm intent.
        model = Transfer(src_task=tasks.normal,
                         dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # Honor the caller's sample size instead of a hard-coded 800.
    result = dataset.evaluate(model, sample=sample)

    message = name + ": " + str(result)
    if logger is not None:
        logger.text(message)
    else:
        print(message)
    return result
예제 #2
0
def run_viz_suite(name,
                  data,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  logger=None,
                  old=False,
                  multitask=False,
                  percep_mode=None):
    """Predict ``dest_task`` for ``data``, optionally mapping it to normals.

    The model is chosen by the first option that applies: ``graph_file``
    (its ``rgb -> dest_task`` edge), ``old`` (``UNetOld``), ``multitask``
    (``UNet(downsample=5, out_channels=6)``), ``model_file``
    (``UNet(downsample=6)``), else the default ``Transfer(rgb -> dest_task)``.

    Args:
        name, logger: Accepted for interface compatibility; not used.
        data: Input passed to ``model.predict``.
        dest_task: Task predicted by the main model.
        graph_file, model_file, old, multitask: Model selection (see above).
        percep_mode: When truthy, predictions are additionally passed through
            the ``dest_task -> normal`` transfer model.

    Returns:
        The last (up to) three channels of the model output, clamped to
        [0, 1], with single-channel outputs tiled to three channels. In
        ``percep_mode``, the transfer model's output for those predictions
        is returned instead, computed in batches of 16.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # The checkpoint must match this architecture (downsample=5 was used
        # by earlier runs).
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        # Tile single-channel outputs so they have three channels.
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        percep_model = Transfer(src_task=dest_task,
                                dest_task=tasks.normal).load_model()
        percep_model.eval()
        # Inference only: skip autograd bookkeeping. Batch the in-memory
        # tensor directly; a multi-worker DataLoader with pin_memory adds
        # process overhead and fails if `results` is already on the GPU.
        with torch.no_grad():
            final_preds = [
                percep_model.forward(batch[:, -3:])
                for batch in torch.split(results, 16)
            ]
        results = torch.cat(final_preds, dim=0)

    return results
예제 #3
0
def run_perceptual_eval_suite(name,
                              intermediate_task=tasks.normal,
                              dest_task=tasks.normal,
                              graph_file=None,
                              model_file=None,
                              logger=None,
                              sample=800,
                              show_images=False,
                              old=False,
                              perceptual_transfer=None,
                              multitask=False):
    """Evaluate rgb -> ``intermediate_task`` predictions via a perceptual model.

    The main model (rgb -> ``intermediate_task``) is chosen by the first
    option that applies: ``graph_file``, ``old`` (``UNetOld``),
    ``multitask`` (``UNet(downsample=5, out_channels=6)``), ``model_file``
    (``UNet(downsample=6)``), else the default ``Transfer``. Its outputs are
    scored by ``ValidationMetrics.evaluate_with_percep`` together with a
    perceptual model mapping ``intermediate_task -> dest_task``.

    Args:
        name: Label prefixed to the reported result.
        intermediate_task: Task predicted by the main model.
        dest_task: Task the perceptual model maps into.
        graph_file, model_file, old, multitask: Main-model selection.
        logger: Object with a ``text(str)`` method. When ``None``, the
            result line is printed instead.
        sample: Sample count forwarded to ``evaluate_with_percep``.
        show_images: Accepted for compatibility; not used.
        perceptual_transfer: Optional pre-built perceptual model. When
            ``None``, ``Transfer(intermediate_task -> dest_task)`` is loaded.

    Returns:
        Whatever ``ValidationMetrics.evaluate_with_percep`` returns.
    """
    if perceptual_transfer is None:
        percep_model = Transfer(src_task=intermediate_task,
                                dest_task=dest_task).load_model()
    else:
        # Without this branch percep_model was never bound and evaluation
        # raised NameError.
        # NOTE(review): assumed usable directly as percep_model -- confirm.
        percep_model = perceptual_transfer

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task],
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        print('running multitask')
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # The checkpoint must match this architecture (downsample=5 was used
        # by earlier runs).
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb,
                         dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # Honor the caller's sample size instead of a hard-coded 800.
    result = dataset.evaluate_with_percep(model,
                                          sample=sample,
                                          percep_model=percep_model)

    message = name + ": " + str(result)
    if logger is not None:
        logger.text(message)
    else:
        print(message)
    return result
예제 #4
0
def run_viz_suite(name,
                  data_loader,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  old=False,
                  multitask=False,
                  percep_mode=None,
                  downsample=6,
                  out_channels=3,
                  final_task=tasks.normal,
                  oldpercep=False):
    """Run a model over ``data_loader`` and collect (optionally mapped) outputs.

    The main model (rgb -> ``dest_task``) is chosen by the first option that
    applies: ``graph_file`` (its edge), ``old`` (``UNetOld``), ``multitask``
    (``UNet(downsample=5, out_channels=6)``), ``model_file``
    (``UNet(downsample, out_channels)``), else a ``DummyModel`` wrapping the
    default ``Transfer``.

    Args:
        name: Accepted for interface compatibility; not used.
        data_loader: Iterable of 1-tuples ``(data,)``, each passed to
            ``model.predict_on_batch``.
        dest_task: Task predicted by the main model.
        graph_file, model_file, old, multitask: Main-model selection.
        percep_mode: When truthy, predictions are passed through a
            ``dest_task -> final_task`` model and only those outputs are
            returned.
        downsample, out_channels: UNet configuration for ``model_file``.
        final_task: Target task of the perceptual model.
        oldpercep: With ``graph_file``, load the perceptual model from a
            default ``Transfer`` instead of the graph's edge.

    Returns:
        A CPU tensor concatenating the per-batch outputs along dim 0.
    """
    extra_task = [final_task] if percep_mode else []

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task,
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # `downsample` must match the checkpoint (5 or 6 in practice).
        print('loading main model')
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    if percep_mode:
        print('Loading percep model...')
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam,
                                 lr=3e-4,
                                 weight_decay=2e-6,
                                 amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task,
                                    dest_task=final_task).load_model()
        percep_model.eval()

    # Only the tensors that are actually returned are kept.
    outputs = []
    print("Converting...")
    for data, in data_loader:
        # At most the last three channels, clamped to [0, 1].
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        if not percep_mode:
            outputs.append(preds.detach().cpu())
            continue

        with torch.no_grad():
            if preds.shape[1] < 3:
                # Narrow (e.g. 1-channel) predictions: try them as-is first.
                # If the perceptual model rejects them, tile the channels and
                # feed the last three.
                try:
                    mapped = percep_model.forward(preds)
                except RuntimeError:
                    tiled = torch.cat([preds] * 3, dim=1)
                    mapped = percep_model.forward(tiled[:, -3:])
            else:
                # With three channels a tiled retry would be identical, so
                # errors are allowed to propagate.
                mapped = percep_model.forward(preds)
        outputs.append(mapped.detach().cpu())

    return torch.cat(outputs, dim=0)
예제 #5
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    use_optimizer=False,
    subset_size=None,
    max_epochs=800,
    **kwargs,
):
    """Fine-tune a graph whose rgb -> normal edge starts from a baseline run.

    Each epoch first runs ``val_step`` no-grad validation steps and then
    ``train_step`` optimisation steps. Both step counts are a quarter of the
    values reported by ``load_train_val``. A hook on the ``"epoch"`` feature
    saves the graph to ``{RESULTS_DIR}/graph.pth``.

    Args:
        loss_config, mode, **kwargs: Forwarded to ``get_energy_loss``.
        visualize: Plot the starting paths once and return without training.
        fast: Fast data-loading mode (default batch size 4 instead of 64).
        batch_size: Explicit batch-size override.
        use_optimizer: Restore optimizer state saved by the baseline run.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Upper bound on the number of epochs.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    # Visit only a quarter of each split per epoch.
    # NOTE(review): floors to 0 when a split has fewer than 4 steps.
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([
        tasks.rgb,
    ])
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Replace the rgb -> normal model with the baseline checkpoint.
    rgb_to_normal = graph.edge(tasks.rgb, tasks.normal)
    rgb_to_normal.model = None
    rgb_to_normal.path = (
        "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth")
    rgb_to_normal.load_model()
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    if use_optimizer:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        optimizer = torch.load(
            "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth"
        )
        graph.optimizer.load_state_dict(optimizer.state_dict())

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        logger.step()
예제 #6
0
def _train_until_plateau(graph, energy_loss, logger, realities, train, val,
                         train_step, val_step, max_epochs, patience,
                         monitor_key, save_best):
    """Alternate train/val epochs until ``monitor_key`` stops improving.

    One epoch is ``train_step`` optimiser steps on ``train`` followed by
    ``val_step`` no-grad steps on ``val``. Whenever the latest logged value
    of ``monitor_key`` beats the best so far, the "best" paths are plotted
    and ``save_best()`` is called. The loop stops after ``max_epochs``
    epochs, or once ``patience`` epochs pass without improvement.
    """
    best_val_loss, stop_idx = float('inf'), 0
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        latest = logger.data[monitor_key][-1]
        if latest < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = latest, 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            save_best()

        if stop_idx >= patience:
            print("Stopping training now")
            break


def main(
        fast=False,
        subset_size=None,
        early_stopping=float('inf'),
        mode='standard',
        max_epochs=800,
        **kwargs,
):
    """Two-phase training: first the percep net f, then n with f frozen.

    Phase 1 trains the ``normal -> principal_curvature`` edge (with
    ``functional_transfers.n`` frozen) against ground-truth curvature. It
    saves weights to ``RESULTS_DIR`` whenever ``val_mse : f(y) -> z^``
    improves. Phase 2 rebuilds the graph with ``functional_transfers.f``
    frozen, loads ``{RESULTS_DIR}/f.pth`` into that edge, and trains with
    the ``"perceptual"`` energy loss. It saves ``{RESULTS_DIR}/graph.pth``
    whenever ``val_mse : n(x) -> y^`` improves.

    Args:
        fast, subset_size: Forwarded to ``load_train_val``.
        early_stopping: Patience in epochs for both phases. The default
            (infinity) keeps the built-in patience of 8 epochs for phase 1
            and 50 for phase 2.
        mode, **kwargs: Forwarded to ``get_energy_loss`` for phase 2.
        max_epochs: Epoch cap applied to each phase.
    """
    # A finite caller-supplied patience is honored. Previously it was
    # silently overwritten by the hard-coded values below.
    if early_stopping is None or early_stopping == float('inf'):
        percep_patience, main_patience = 8, 50
    else:
        percep_patience = main_patience = early_stopping

    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [
                    ("f(y)", "z^"),
                ],
            },
        },
        "plots": {
            "ID":
            dict(size=256,
                 realities=("test", "ood"),
                 paths=[
                     "y",
                     "z^",
                     "f(y)",
                 ]),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)

    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]
    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)
    realities = [train, val, test, ood]

    # PHASE 1 GRAPH: train f (normal -> principal_curvature) with n frozen.
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING (this logger is shared by both phases)
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    _train_until_plateau(
        graph, energy_loss, logger, realities, train, val,
        train_step, val_step, max_epochs, percep_patience,
        monitor_key="val_mse : f(y) -> z^",
        save_best=lambda g=graph: g.save(weights_dir=f"{RESULTS_DIR}"),
    )

    # PHASE 2: train with the perceptual loss, keeping f frozen.
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    # Load f's weights -- presumably the file written by phase 1's
    # graph.save(weights_dir=...) (confirm).
    graph.edge(
        tasks.normal,
        tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # The "loss" -> logger.step() hook from phase 1 is still attached to this
    # logger. Registering it again would presumably step twice per trigger,
    # so only the new loss's hooks are added.
    energy_loss.logger_hooks(logger)

    _train_until_plateau(
        graph, energy_loss, logger, realities, train, val,
        train_step, val_step, max_epochs, main_patience,
        monitor_key="val_mse : n(x) -> y^",
        save_best=lambda g=graph: g.save(f"{RESULTS_DIR}/graph.pth"),
    )
예제 #7
0
def main(
    loss_config="baseline", mode="standard", visualize=False,
    fast=False, batch_size=None, path=None,
    subset_size=None, early_stopping=float('inf'),
    max_epochs=800, **kwargs
):
    """Train a graph whose rgb -> normal edge is initialised from ``path``.

    The graph weights are written to ``RESULTS_DIR`` once before training.
    Each epoch runs ``train_step`` optimisation steps followed by
    ``val_step`` no-grad validation steps. Weights are saved again whenever
    ``val_mse : n(x) -> y^`` reaches a new minimum. Training ends after
    ``max_epochs`` epochs, or once ``early_stopping`` epochs pass without
    improvement.

    Args:
        loss_config, mode, **kwargs: Forwarded to ``get_energy_loss``.
        visualize: Plot the starting paths once and return without training.
        fast: Fast data-loading mode (default batch size 4 instead of 64).
        batch_size: Explicit batch-size override.
        path: Checkpoint path assigned to the rgb -> normal edge.
        subset_size: Forwarded to ``load_train_val``.
        early_stopping: Patience, in epochs, before training stops.
        max_epochs: Upper bound on the number of epochs.
    """
    monitored = "val_mse : n(x) -> y^"

    # Configuration
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # Data
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    print('tasks', energy_loss.get_tasks("ood"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set,
                                  energy_loss.get_tasks("ood"))
    realities = [train, val, test, ood]

    # Graph: swap the rgb -> normal model for the checkpoint at `path`.
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=False,
        freeze_list=energy_loss.freeze_list,
    )
    rgb_to_normal = graph.edge(tasks.rgb, tasks.normal)
    rgb_to_normal.model = None
    rgb_to_normal.path = path
    rgb_to_normal.load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)
    graph.save(weights_dir=f"{RESULTS_DIR}")

    # Logging
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)
    best_seen, epochs_since_best = float('inf'), 0

    # Training
    for epoch in range(max_epochs):
        logger.update("epoch", epoch)
        plot_prefix = "start" if epoch == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        graph.train()
        for _ in range(train_step):
            per_loss = energy_loss(graph, realities=[train])
            total = sum([per_loss[key] for key in per_loss])
            graph.step(total)
            train.step()
            logger.update("loss", total)

        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                per_loss = energy_loss(graph, realities=[val])
                total = sum([per_loss[key] for key in per_loss])
            val.step()
            logger.update("loss", total)

        energy_loss.logger_update(logger)
        logger.step()

        epochs_since_best += 1
        latest = logger.data[monitored][-1]
        if latest < best_seen:
            print("Better val loss, reset stop_idx: ", epochs_since_best)
            best_seen, epochs_since_best = latest, 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if epochs_since_best >= early_stopping:
            print("Stopping training now")
            return
예제 #8
0
def main(
    loss_config="baseline",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    **kwargs,
):
    """Evaluate a saved rgb -> normal checkpoint on the validation split.

    Loads ``{SHARED_DIR}/results_SAMPLEFF_consistency1m_25/n.pth`` into the
    rgb -> normal edge and runs ``val_step`` no-grad validation steps (four
    times the count reported by ``load_train_val``). It then prints the
    logged ``val_mse : n(x) -> y^`` series. No optimiser steps are taken.

    Args:
        loss_config, mode, **kwargs: Forwarded to ``get_energy_loss``.
        visualize: Accepted for interface compatibility; not used.
        fast: Fast data-loading mode (default batch size 4 instead of 64).
        batch_size: Explicit batch-size override.
    """
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # Data: both step counts are scaled by 4 (only val_step is used below).
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_step = 4 * train_step
    val_step = 4 * val_step
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))
    realities = [train, val, test, ood]

    # Graph with the consistency-trained rgb -> normal checkpoint.
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        freeze_list=energy_loss.freeze_list,
    )
    rgb_to_normal = graph.edge(tasks.rgb, tasks.normal)
    rgb_to_normal.model = None
    rgb_to_normal.path = (
        f"{SHARED_DIR}/results_SAMPLEFF_consistency1m_25/n.pth")
    rgb_to_normal.load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss",
                    freq=20)

    # Evaluation only: validation split, no parameter updates.
    graph.eval()
    for _ in range(val_step):
        with torch.no_grad():
            per_loss = energy_loss(graph, realities=[val])
            total = sum([per_loss[key] for key in per_loss])
        val.step()
        logger.update("loss", total)

    energy_loss.logger_update(logger)
    logger.step()

    print("Val mse: ", logger.data["val_mse : n(x) -> y^"])