Example no. 1
def main(_run):
    # UCB is assumed to be imported at module scope
    from tqdm import tqdm
    from utils.tupperware import tupperware

    args = tupperware(_run.config)

    regret_ll, var_regret_ll = UCB(
        args, game_i=0, alpha=2, pbar=tqdm(range(args.repeat))
    )
    print(f"UCB regret {regret_ll}")
Example no. 2
def main(_run):
    # DeepAtrousGuidedFilter is assumed to be imported at module scope
    from utils.tupperware import tupperware
    from torchsummary import summary

    args = tupperware(_run.config)
    model = DeepAtrousGuidedFilter(args).to(args.device)
    summary(model, (3, 1024, 2048))
Example no. 3
def main(_run):
    # LRNet is assumed to be imported at module scope
    from utils.tupperware import tupperware
    from torchsummary import summary

    args = tupperware(_run.config)
    model = LRNet(in_c=12, out_c=12, args=args).to(args.device)
    summary(model, (12, 256, 512))
Example no. 4
def main(_run):
    # get_dataloaders is assumed to be imported at module scope
    from tqdm import tqdm
    from utils.tupperware import tupperware

    args = tupperware(_run.config)
    data = get_dataloaders(args)

    # Iterate the underlying dataset sample-by-sample (e.g. as a data-loading sanity check)
    for _sample in tqdm(data.train_loader.dataset):
        pass
Example no. 5
def main(_run):
    # TS_beta is assumed to be imported at module scope
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    from utils.tupperware import tupperware

    args = tupperware(_run.config)

    regret_ll, _, _ = TS_beta(args,
                              game_i=1,
                              params=(1, 1),
                              pbar=tqdm(range(args.repeat)))
    print(f"TS regret {regret_ll}")
    plt.plot(regret_ll)
    plt.show()
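For context, a minimal sketch of Beta-Bernoulli Thompson sampling, the algorithm that TS_beta(args, game_i=1, params=(1, 1), pbar=...) presumably runs with params as the Beta prior (a, b); the arms, horizon, and function name are illustrative only.

import numpy as np

def ts_beta_sketch(means, horizon, prior=(1, 1), seed=0):
    """Thompson sampling for Bernoulli arms with a Beta(a, b) prior per arm."""
    rng = np.random.default_rng(seed)
    n_arms = len(means)
    a = np.full(n_arms, float(prior[0]))   # prior successes
    b = np.full(n_arms, float(prior[1]))   # prior failures
    best = max(means)
    regret = []
    for _ in range(horizon):
        theta = rng.beta(a, b)             # one posterior sample per arm
        arm = int(np.argmax(theta))
        reward = rng.random() < means[arm]
        a[arm] += reward
        b[arm] += 1 - reward
        regret.append(best - means[arm])
    return np.cumsum(regret)

# Example: ts_beta_sketch(means=[0.3, 0.5, 0.7], horizon=10_000, prior=(1, 1))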
Example no. 6
def main(_run):
    # MRAS_categ_elite is assumed to be imported at module scope
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    from utils.tupperware import tupperware

    args = tupperware(_run.config)

    regret_ll, _, arm_ll = MRAS_categ_elite(
        args,
        game_i=2,
        N_0=400,
        alpha=1,
        kai_0=0,
        rho_0=0.7,
        epi_J=1e-6,
        pbar=tqdm(range(args.repeat)),
        verbose=True,
    )
    print(f"Arms pulled {arm_ll}")
    print(f"MRAS regret {regret_ll}")
    plt.plot(regret_ll)
    plt.show()
Example no. 7
def main(_run):
    # Data, model, loss and checkpointing helpers (is_local_rank_0, get_dataloaders,
    # get_model, GLoss, load_models, save_weights, ...) are assumed to be imported
    # at module scope
    args = tupperware(_run.config)

    # Dir init
    dir_init(args, is_local_rank_0=is_local_rank_0)

    # Ignore warnings
    if not is_local_rank_0:
        warnings.filterwarnings("ignore")

    # Multi-GPU setup
    if args.distdataparallel:
        rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        world_size = dist.get_world_size()
    else:
        rank = args.device
        world_size = 1

    # Get data
    data = get_dataloaders(args, is_local_rank_0=is_local_rank_0)

    # Model
    G = get_model.model(args).to(rank)

    # Optimisers
    g_optimizer, g_lr_scheduler = get_optimisers(G, args)

    # Load Models
    G, g_optimizer, global_step, start_epoch, loss = load_models(
        G, g_optimizer, args, is_local_rank_0=is_local_rank_0)

    if args.distdataparallel:
        # Wrap with Distributed Data Parallel
        G = torch.nn.parallel.DistributedDataParallel(G,
                                                      device_ids=[rank],
                                                      output_device=rank)

    # Log no of GPUs
    if is_local_rank_0:
        world_size = int(
            os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
        logging.info("Using {} GPUs".format(world_size))

        writer = SummaryWriter(log_dir=str(args.run_dir))
        writer.add_text("Args", pprint_args(args))

        # Pbars
        train_pbar = tqdm(range(len(data.train_loader) * args.batch_size),
                          dynamic_ncols=True)

        val_pbar = (tqdm(range(len(data.val_loader) * args.batch_size),
                         dynamic_ncols=True) if data.val_loader else None)

        test_pbar = (tqdm(range(len(data.test_loader) * args.batch_size),
                          dynamic_ncols=True) if data.test_loader else None)

    # Initialise losses
    g_loss = GLoss(args).to(rank)

    # Compatibility with checkpoints without global_step
    if not global_step:
        global_step = start_epoch * len(data.train_loader) * args.batch_size

    start_epoch = global_step // len(data.train_loader.dataset)

    # Exponential averaging of loss
    loss_dict = {
        "total_loss": 0.0,
        "image_loss": 0.0,
        "cobi_rgb_loss": 0.0,
        "train_PSNR": 0.0,
    }

    metric_dict = {"PSNR": 0.0, "total_loss": 0.0}
    avg_metrics = AvgLoss_with_dict(loss_dict=metric_dict, args=args)
    exp_loss = ExpLoss_with_dict(loss_dict=loss_dict, args=args)

    try:
        for epoch in range(start_epoch, args.num_epochs):
            # Train mode
            G.train()

            if is_local_rank_0:
                train_pbar.reset()

            if args.distdataparallel:
                data.train_loader.sampler.set_epoch(epoch)

            for i, batch in enumerate(data.train_loader):
                # Allows resuming interrupted training mid-epoch
                if ((global_step + 1) %
                    (len(data.train_loader) * args.batch_size)
                        == 0) and (epoch == start_epoch):
                    break

                loss_dict = defaultdict(float)

                source, target, filename = batch
                source, target = (source.to(rank), target.to(rank))

                # ------------------------------- #
                # Update Gen
                # ------------------------------- #
                G.zero_grad()
                output = G(source)

                g_loss(output=output, target=target)

                g_loss.total_loss.backward()
                g_optimizer.step()

                # Update lr schedulers
                g_lr_scheduler.step(epoch + i / len(data.train_loader))

                # if is_local_rank_0:
                # Train PSNR
                loss_dict["train_PSNR"] += PSNR(output, target)

                # Accumulate all losses
                loss_dict["total_loss"] += g_loss.total_loss
                loss_dict["image_loss"] += g_loss.image_loss
                loss_dict["cobi_rgb_loss"] += g_loss.cobi_rgb_loss

                exp_loss += reduce_loss_dict(loss_dict, world_size=world_size)

                global_step += args.batch_size * world_size

                if is_local_rank_0:
                    train_pbar.update(args.batch_size)
                    train_pbar.set_description(
                        f"Epoch: {epoch + 1} | Gen loss: {exp_loss.loss_dict['total_loss']:.3f} "
                    )

                # Write lr rates and metrics
                if is_local_rank_0 and i % (args.log_interval) == 0:
                    gen_lr = g_optimizer.param_groups[0]["lr"]
                    writer.add_scalar("lr/gen", gen_lr, global_step)

                    for metric in exp_loss.loss_dict:
                        writer.add_scalar(
                            f"Train_Metrics/{metric}",
                            exp_loss.loss_dict[metric],
                            global_step,
                        )

                    # Log a few sample images
                    n = np.min([3, args.batch_size])
                    for e in range(n):
                        source_vis = source[e].mul(0.5).add(0.5)
                        target_vis = target[e].mul(0.5).add(0.5)
                        output_vis = output[e].mul(0.5).add(0.5)

                        writer.add_image(
                            f"Source/Train_{e + 1}",
                            source_vis.cpu().detach(),
                            global_step,
                        )

                        writer.add_image(
                            f"Target/Train_{e + 1}",
                            target_vis.cpu().detach(),
                            global_step,
                        )

                        writer.add_image(
                            f"Output/Train_{e + 1}",
                            output_vis.cpu().detach(),
                            global_step,
                        )

                        writer.add_text(f"Filename/Train_{e + 1}", filename[e],
                                        global_step)

            if is_local_rank_0:
                # Save ckpt at end of epoch
                logging.info(
                    f"Saving weights at epoch {epoch + 1} global step {global_step}"
                )

                # Save weights
                save_weights(
                    epoch=epoch,
                    global_step=global_step,
                    G=G,
                    g_optimizer=g_optimizer,
                    loss=loss,
                    tag="latest",
                    args=args,
                )

                train_pbar.refresh()

            # Run val and test only occasionally
            if epoch % args.val_test_epoch_interval != 0:
                continue

            # Val and test
            with torch.no_grad():
                G.eval()

                if data.val_loader:
                    avg_metrics.reset()
                    if is_local_rank_0:
                        val_pbar.reset()

                    filename_static = []

                    for i, batch in enumerate(data.val_loader):
                        metrics_dict = defaultdict(float)

                        source, target, filename = batch
                        source, target = (source.to(rank), target.to(rank))

                        output = G(source)
                        g_loss(output=output, target=target)

                        # Total loss
                        metrics_dict["total_loss"] += g_loss.total_loss
                        # PSNR
                        metrics_dict["PSNR"] += PSNR(output, target)

                        avg_metrics += reduce_loss_dict(metrics_dict,
                                                        world_size=world_size)

                        # Save image
                        if args.static_val_image in filename:
                            filename_static = filename
                            source_static = source
                            target_static = target
                            output_static = output

                        if is_local_rank_0:
                            val_pbar.update(args.batch_size)
                            val_pbar.set_description(
                                f"Val Epoch : {epoch + 1} Step: {global_step}| PSNR: {avg_metrics.loss_dict['PSNR']:.3f}"
                            )
                    if is_local_rank_0:
                        for metric in avg_metrics.loss_dict:
                            writer.add_scalar(
                                f"Val_Metrics/{metric}",
                                avg_metrics.loss_dict[metric],
                                global_step,
                            )

                        n = np.min([3, args.batch_size])
                        for e in range(n):
                            source_vis = source[e].mul(0.5).add(0.5)
                            target_vis = target[e].mul(0.5).add(0.5)
                            output_vis = output[e].mul(0.5).add(0.5)

                            writer.add_image(
                                f"Source/Val_{e+1}",
                                source_vis.cpu().detach(),
                                global_step,
                            )
                            writer.add_image(
                                f"Target/Val_{e+1}",
                                target_vis.cpu().detach(),
                                global_step,
                            )
                            writer.add_image(
                                f"Output/Val_{e+1}",
                                output_vis.cpu().detach(),
                                global_step,
                            )

                            writer.add_text(f"Filename/Val_{e + 1}",
                                            filename[e], global_step)

                        for e, name in enumerate(filename_static):
                            if name == args.static_val_image:
                                source_vis = source_static[e].mul(0.5).add(0.5)
                                target_vis = target_static[e].mul(0.5).add(0.5)
                                output_vis = output_static[e].mul(0.5).add(0.5)

                                writer.add_image(
                                    f"Source/Val_Static",
                                    source_vis.cpu().detach(),
                                    global_step,
                                )
                                writer.add_image(
                                    f"Target/Val_Static",
                                    target_vis.cpu().detach(),
                                    global_step,
                                )
                                writer.add_image(
                                    f"Output/Val_Static",
                                    output_vis.cpu().detach(),
                                    global_step,
                                )

                                writer.add_text(
                                    f"Filename/Val_Static",
                                    filename_static[e],
                                    global_step,
                                )

                                break

                        logging.info(
                            f"Saving weights at END OF epoch {epoch + 1} global step {global_step}"
                        )

                        # Track whether this is the lowest validation loss so far
                        if avg_metrics.loss_dict["total_loss"] < loss:
                            is_min = True
                            loss = avg_metrics.loss_dict["total_loss"]
                        else:
                            is_min = False

                        # Save weights
                        save_weights(
                            epoch=epoch,
                            global_step=global_step,
                            G=G,
                            g_optimizer=g_optimizer,
                            loss=loss,
                            is_min=is_min,
                            args=args,
                            tag="best",
                        )

                        val_pbar.refresh()

                # Test
                if data.test_loader:
                    filename_static = []

                    if is_local_rank_0:
                        test_pbar.reset()

                    for i, batch in enumerate(data.test_loader):
                        source, filename = batch
                        source = source.to(rank)

                        output = G(source)

                        # Save image
                        if args.static_test_image in filename:
                            filename_static = filename
                            source_static = source
                            output_static = output

                        if is_local_rank_0:
                            test_pbar.update(args.batch_size)
                            test_pbar.set_description(
                                f"Test Epoch : {epoch + 1} Step: {global_step}"
                            )

                    if is_local_rank_0:
                        n = np.min([3, args.batch_size])
                        for e in range(n):
                            source_vis = source[e].mul(0.5).add(0.5)
                            output_vis = output[e].mul(0.5).add(0.5)

                            writer.add_image(
                                f"Source/Test_{e+1}",
                                source_vis.cpu().detach(),
                                global_step,
                            )

                            writer.add_image(
                                f"Output/Test_{e+1}",
                                output_vis.cpu().detach(),
                                global_step,
                            )

                            writer.add_text(f"Filename/Test_{e + 1}",
                                            filename[e], global_step)

                        for e, name in enumerate(filename_static):
                            if name == args.static_test_image:
                                source_vis = source_static[e]
                                output_vis = output_static[e]

                                writer.add_image(
                                    f"Source/Test_Static",
                                    source_vis.cpu().detach(),
                                    global_step,
                                )

                                writer.add_image(
                                    f"Output/Test_Static",
                                    output_vis.cpu().detach(),
                                    global_step,
                                )

                                writer.add_text(
                                    f"Filename/Test_Static",
                                    filename_static[e],
                                    global_step,
                                )

                                break

                        test_pbar.refresh()

    except KeyboardInterrupt:
        if is_local_rank_0:
            logging.info("-" * 89)
            logging.info("Exiting from training early. Saving models")

            for pbar in [train_pbar, val_pbar, test_pbar]:
                if pbar:
                    pbar.refresh()

            save_weights(
                epoch=epoch,
                global_step=global_step,
                G=G,
                g_optimizer=g_optimizer,
                loss=loss,
                is_min=True,
                args=args,
            )
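For reference, a minimal sketch of a PSNR helper consistent with how PSNR(output, target) is used in the training loop above; it assumes tensors normalised to [-1, 1] and is an illustration, not necessarily the project's implementation.

import torch

def psnr_sketch(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """PSNR in dB for tensors in [-1, 1], computed after mapping both to [0, 1]."""
    output = output.mul(0.5).add(0.5).clamp(0, 1)
    target = target.mul(0.5).add(0.5).clamp(0, 1)
    mse = torch.mean((output - target) ** 2)
    return 10.0 * torch.log10(1.0 / (mse + 1e-12))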
Example no. 8
def main(_run):
    # Data, model, metric and ensemble helpers (get_dataloaders, get_model, load_models,
    # PerceptualLoss, ensemble_ops, ...) are assumed to be imported at module scope
    args = tupperware(_run.config)
    args.finetune = False
    args.batch_size = 1

    device = args.device

    # Get data
    data = get_dataloaders(args)

    # Model
    G = get_model.model(args).to(device)

    # LPIPS Criterion
    lpips_criterion = PerceptualLoss(
        model="net-lin", net="alex", use_gpu=True, gpu_ids=[device]
    ).to(device)

    # Load Models
    G, _, global_step, start_epoch, loss = load_models(
        G, g_optimizer=None, args=args, tag=args.inference_mode
    )

    # Metric loggers
    val_metrics_dict = {"PSNR": 0.0, "SSIM": 0.0, "LPIPS_01": 0.0, "LPIPS_11": 0.0}
    avg_val_metrics = AvgLoss_with_dict(loss_dict=val_metrics_dict, args=args)

    logging.info(f"Loaded experiment {args.exp_name} trained for {start_epoch} epochs.")

    # Train, val and test paths
    val_path = args.output_dir / f"val_{args.inference_mode}_epoch_{start_epoch}"
    test_path = args.output_dir / f"test_{args.inference_mode}_epoch_{start_epoch}"

    if args.self_ensemble:
        val_path = val_path.parent / f"{val_path.name}_self_ensemble"
        test_path = test_path.parent / f"{test_path.name}_self_ensemble"

    val_path.mkdir(exist_ok=True, parents=True)
    test_path.mkdir(exist_ok=True, parents=True)

    with torch.no_grad():
        G.eval()

        # Run val for an epoch
        avg_val_metrics.reset()
        pbar = tqdm(range(len(data.val_loader) * args.batch_size), dynamic_ncols=True)

        for i, batch in enumerate(data.val_loader):
            metrics_dict = defaultdict(float)

            source, target, filename = batch
            source, target = (source.to(device), target.to(device))

            output = G(source)

            if args.self_ensemble:
                output_ensembled = [output]

                for k in ensemble_ops.keys():
                    # Forward transform
                    source_t = ensemble_ops[k][0](source)
                    output_t = G(source_t)
                    # Inverse transform
                    output_t = ensemble_ops[k][1](output_t)

                    output_ensembled.append(output_t)

                output_ensembled = torch.cat(output_ensembled, dim=0)

                output = torch.mean(output_ensembled, dim=0, keepdim=True)

            # Quantise output and target to 8-bit precision (used below for LPIPS)
            output_255 = (output.mul(0.5).add(0.5) * 255.0).int()
            output_quant = (output_255.float() / 255.0).sub(0.5).mul(2)

            target_255 = (target.mul(0.5).add(0.5) * 255.0).int()
            target_quant = (target_255.float() / 255.0).sub(0.5).mul(2)

            # LPIPS
            metrics_dict["LPIPS_01"] += lpips_criterion(
                output_quant.mul(0.5).add(0.5), target_quant.mul(0.5).add(0.5)
            ).item()

            metrics_dict["LPIPS_11"] += lpips_criterion(
                output_quant, target_quant
            ).item()

            for e in range(args.batch_size):
                # Compute SSIM
                target_numpy = (
                    target[e].mul(0.5).add(0.5).permute(1, 2, 0).cpu().detach().numpy()
                )

                output_numpy = (
                    output[e].mul(0.5).add(0.5).permute(1, 2, 0).cpu().detach().numpy()
                )

                metrics_dict["PSNR"] += PSNR_numpy(target_numpy, output_numpy)
                metrics_dict["SSIM"] += ssim(
                    target_numpy,
                    output_numpy,
                    gaussian_weights=True,
                    use_sample_covariance=False,
                    multichannel=True,
                )

                # Dump to output folder
                path_output = val_path / filename[e]

                cv2.imwrite(
                    str(path_output), (output_numpy[:, :, ::-1] * 255.0).astype(np.uint8)
                )

            metrics_dict["SSIM"] = metrics_dict["SSIM"] / args.batch_size
            metrics_dict["PSNR"] = metrics_dict["PSNR"] / args.batch_size

            avg_val_metrics += metrics_dict

            pbar.update(args.batch_size)
            pbar.set_description(
                f"Val Epoch : {start_epoch} Step: {global_step}| PSNR: {avg_val_metrics.loss_dict['PSNR']:.3f} | SSIM: {avg_val_metrics.loss_dict['SSIM']:.3f} | LPIPS 01: {avg_val_metrics.loss_dict['LPIPS_01']:.3f} | LPIPS 11: {avg_val_metrics.loss_dict['LPIPS_11']:.3f}"
            )

        with open(val_path / "metrics.txt", "w") as f:
            L = [
                f"exp_name:{args.exp_name} trained for {start_epoch} epochs\n",
                "Val Metrics \n\n",
            ]
            L = L + [f"{k}:{v}\n" for k, v in avg_val_metrics.loss_dict.items()]
            f.writelines(L)

        if data.test_loader:
            pbar = tqdm(
                range(len(data.test_loader) * args.batch_size), dynamic_ncols=True
            )

            for i, batch in enumerate(data.test_loader):

                source, filename = batch
                source = source.to(device)

                output = G(source)

                if args.self_ensemble:
                    output_ensembled = [output]

                    for k in ensemble_ops.keys():
                        # Forward transform
                        source_t = ensemble_ops[k][0](source)
                        output_t = G(source_t)
                        # Inverse transform
                        output_t = ensemble_ops[k][1](output_t)

                        output_ensembled.append(output_t)

                    output_ensembled = torch.cat(output_ensembled, dim=0)
                    output = torch.mean(output_ensembled, dim=0, keepdim=True)

                for e in range(args.batch_size):
                    output_numpy = (
                        output[e]
                        .mul(0.5)
                        .add(0.5)
                        .permute(1, 2, 0)
                        .cpu()
                        .detach()
                        .numpy()
                    )

                    # Dump to output folder
                    path_output = test_path / filename[e]

                    cv2.imwrite(
                        str(path_output),
                        (output_numpy[:, :, ::-1] * 255.0).astype(np.uint8),
                    )

                pbar.update(args.batch_size)
                pbar.set_description(f"Test Epoch : {start_epoch} Step: {global_step}")
Example no. 9
def main(_run):
    # Bandit algorithms (UCB, TS_beta, MRAS_categ_elite) and plotting helpers
    # (plot_regret, plot_prob_arm) are assumed to be imported at module scope
    from pathlib import Path
    from tqdm import tqdm
    from utils.tupperware import tupperware

    args = tupperware(_run.config)

    log_dir = Path("logs")
    log_dir.mkdir(parents=True, exist_ok=True)

    for game_i in range(len(args.games)):
        print(f"Game {game_i + 1}")
        print(f"Arms {args.games[game_i]}")

        D = {}

        ######################
        # UCB
        ######################

        for alpha in [2]:
            pbar = tqdm(range(args.repeat))
            regret_ll, var_regret_ll, arm_ll = UCB(args, game_i, alpha, pbar)
            d = {}
            d["regret"] = regret_ll
            d["var"] = var_regret_ll
            d["arm_ll"] = arm_ll
            D[f"UCB-alpha={alpha}"] = d

        #######################
        # TS Beta
        #######################

        for params in [(1, 1)]:  # [(1, 1), (0.2, 0.8)]:
            pbar = tqdm(range(args.repeat))
            regret_ll, var_regret_ll, arm_ll = TS_beta(args, game_i, params,
                                                       pbar)
            d = {}
            d["regret"] = regret_ll
            d["var"] = var_regret_ll
            d["arm_ll"] = arm_ll
            D[f"TS-beta-params-{params}"] = d

        ######################
        # Asym UCB
        ######################

        # d = {}
        # regret_ll, var_regret_ll = Asymp_UCB(args, game_i, pbar=pbar)
        # d['regret'] = regret_ll
        # d['var'] = var_regret_ll
        # D['Asym-UCB'] = d

        ######################
        # MRAS Categ
        ######################

        # d = {}
        # pbar = tqdm(range(args.repeat))
        # regret_ll, var_regret_ll, arm_ll = MRAS_categ(args, game_i, pbar=pbar)
        # d["regret"] = regret_ll
        # d["var"] = var_regret_ll
        # d["arm_ll"] = arm_ll
        # D["MRAS-Categ-exp"] = d

        ######################
        # MRAS Categ Elite
        ######################

        d = {}
        pbar = tqdm(range(args.repeat))
        regret_ll, var_regret_ll, arm_ll = MRAS_categ_elite(args,
                                                            game_i,
                                                            N_0=400,
                                                            alpha=4,
                                                            rho_0=0.7,
                                                            pbar=pbar)
        d["regret"] = regret_ll
        d["var"] = var_regret_ll
        d["arm_ll"] = arm_ll
        D["MRAS-Categ-Elite"] = d

        ######################
        # Gap Indep
        ######################

        # d= {}
        # d['regret'] = gap_indep
        # d['var'] = np.zeros_like(gap_indep)
        # D['gap-indpendent-minimax'] = d

        ######################
        # Gap Dependent
        ######################
        # d = {}
        # d['regret'] = gap_dep
        # d['var'] = np.zeros_like(gap_dep)
        # D['gap-dependent-minimax'] = d

        plot_regret(D, game_i, args, supress=True)
        plot_prob_arm(D, game_i, args, supress=True)
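Finally, a hypothetical sketch of what a helper like plot_regret could do with the dictionary D assembled above (one regret curve plus a one-standard-deviation band per algorithm); the file naming and styling are illustrative, not the project's implementation.

import matplotlib.pyplot as plt
import numpy as np

def plot_regret_sketch(D, game_i, out_dir="logs"):
    """Plot cumulative regret with a +/- 1 std band for each algorithm in D."""
    plt.figure()
    for name, d in D.items():
        regret = np.asarray(d["regret"])
        std = np.sqrt(np.asarray(d["var"]))
        rounds = np.arange(len(regret))
        plt.plot(rounds, regret, label=name)
        plt.fill_between(rounds, regret - std, regret + std, alpha=0.2)
    plt.xlabel("Round")
    plt.ylabel("Cumulative regret")
    plt.legend()
    plt.savefig(f"{out_dir}/regret_game_{game_i + 1}.png", bbox_inches="tight")
    plt.close()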