def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING

    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    train_step, val_step = 4, 4
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      freeze_list=energy_loss.freeze_list,
                      finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID: graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix=f"epoch_{epochs}")
        if visualize: return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.select_losses(val)
        if epochs != 0:
            energy_loss.logger_update(logger)
        else:
            energy_loss.metrics = {}
        logger.step()

        logger.text(f"Chosen losses: {energy_loss.chosen_losses}")
        logger.text(f"Percep winrate: {energy_loss.percep_winrate}")
        graph.train()
        for _ in range(0, train_step):
            train_loss2 = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss2.values())

            graph.step(train_loss)
            train.step()

            logger.update("loss", train_loss)
Exemple #2
0
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    if not fast:
        train_step, val_step = train_step // (16 * 4), val_step // (16)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    # ood = RealityTask.from_static("ood", ood_set, [energy_loss.get_tasks("ood")])

    # GRAPH
    realities = [
        train,
        val,
        test,
    ] + [train_subset]  #[ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    # logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')
    energy_losses = []
    mse_losses = []
    pearsonr_vals = []
    percep_losses = defaultdict(list)
    pearson_percep = defaultdict(list)
    # # TRAINING
    # for epochs in range(0, max_epochs):

    # 	logger.update("epoch", epochs)
    # 	if epochs == 0:
    # 		energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
    # 	# if visualize: return

    # 	graph.eval()
    # 	for _ in range(0, val_step):
    # 		with torch.no_grad():
    # 			losses = energy_loss(graph, realities=[val])
    # 			all_perceps = [losses[loss_name] for loss_name in losses if 'percep' in loss_name ]
    # 			energy_avg = sum(all_perceps) / len(all_perceps)
    # 			for loss_name in losses:
    # 				if 'percep' not in loss_name: continue
    # 				percep_losses[loss_name] += [losses[loss_name].data.cpu().numpy()]
    # 			mse = losses['mse']
    # 			energy_losses.append(energy_avg.data.cpu().numpy())
    # 			mse_losses.append(mse.data.cpu().numpy())

    # 		val.step()
    # 	mse_arr = np.array(mse_losses)
    # 	energy_arr = np.array(energy_losses)
    # 	# logger.scatter(mse_arr - mse_arr.mean() / np.std(mse_arr), \
    # 	# 	energy_arr - energy_arr.mean() / np.std(energy_arr), \
    # 	# 	'unit_normal_all', opts={'xlabel':'mse','ylabel':'energy'})
    # 	logger.scatter(mse_arr, energy_arr, \
    # 		'mse_energy_all', opts={'xlabel':'mse','ylabel':'energy'})
    # 	pearsonr, p = scipy.stats.pearsonr(mse_arr, energy_arr)
    # 	logger.text(f'pearsonr = {pearsonr}, p = {p}')
    # 	pearsonr_vals.append(pearsonr)
    # 	logger.plot(pearsonr_vals, 'pearsonr_all')
    # 	for percep_name in percep_losses:
    # 		percep_loss_arr = np.array(percep_losses[percep_name])
    # 		logger.scatter(mse_arr, percep_loss_arr, f'mse_energy_{percep_name}', \
    # 			opts={'xlabel':'mse','ylabel':'energy'})
    # 		pearsonr, p = scipy.stats.pearsonr(mse_arr, percep_loss_arr)
    # 		pearson_percep[percep_name] += [pearsonr]
    # 		logger.plot(pearson_percep[percep_name], f'pearson_{percep_name}')

    # energy_loss.logger_update(logger)
    # if logger.data['val_mse : n(~x) -> y^'][-1] < best_ood_val_loss:
    # 	best_ood_val_loss = logger.data['val_mse : n(~x) -> y^'][-1]
    # 	energy_loss.plot_paths(graph, logger, realities, prefix="best")

    energy_mean_by_blur = []
    energy_std_by_blur = []
    mse_mean_by_blur = []
    mse_std_by_blur = []
    for blur_size in np.arange(0, 10, 0.5):
        tasks.rgb.blur_radius = blur_size if blur_size > 0 else None
        train_subset.step()
        # energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        energy_losses = []
        mse_losses = []
        for epochs in range(subset_size // batch_size):
            with torch.no_grad():
                flosses = energy_loss(graph,
                                      realities=[train_subset],
                                      reduce=False)
                losses = energy_loss(graph,
                                     realities=[train_subset],
                                     reduce=False)
                all_perceps = np.stack([
                    losses[loss_name].data.cpu().numpy()
                    for loss_name in losses if 'percep' in loss_name
                ])
                energy_losses += list(all_perceps.mean(0))
                mse_losses += list(losses['mse'].data.cpu().numpy())
            train_subset.step()
        mse_losses = np.array(mse_losses)
        energy_losses = np.array(energy_losses)
        logger.text(
            f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}'
        )
        logger.scatter(mse_losses, energy_losses, \
         f'mse_energy, blur = {blur_size}', opts={'xlabel':'mse','ylabel':'energy'})

        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
        mse_mean_by_blur += [mse_losses.mean()]
        mse_std_by_blur += [np.std(mse_losses)]

    logger.plot(energy_mean_by_blur, f'energy_mean_by_blur')
    logger.plot(energy_std_by_blur, f'energy_std_by_blur')
    logger.plot(mse_mean_by_blur, f'mse_mean_by_blur')
    logger.plot(mse_std_by_blur, f'mse_std_by_blur')