Пример #1
0
def main(loss_config="gt_mse", mode="standard", pretrained=False, batch_size=64, **kwargs):
    """Train an rgb -> normal UNet under a configurable functional loss.

    Args:
        loss_config: loss configuration name forwarded to
            ``get_functional_loss`` as ``config``.
        mode: loss mode forwarded to ``get_functional_loss``.
        pretrained: if True, start from ``functional_transfers.n.load_model()``
            with a 10x smaller learning rate (3e-5 instead of 3e-4); otherwise
            start from a freshly constructed ``DataParallelModel(UNet())``.
        batch_size: batch size used by the train/val loaders.
        **kwargs: extra options forwarded to ``get_functional_loss``.

    Runs 800 epochs. Each epoch plots predictions, evaluates on the validation
    split, then trains on the training split, reporting to a Visdom logger.
    """

    # MODEL
    # model = DataParallelModel.load(UNet().cuda(), "standardval_rgb2normal_baseline.pth")
    model = functional_transfers.n.load_model() if pretrained else DataParallelModel(UNet())
    model.compile(torch.optim.Adam, lr=(3e-5 if pretrained else 3e-4), weight_decay=2e-6, amsgrad=True)
    # Multiply the LR by 0.95 at scheduler steps 1, 6, 11, ..., 396. The scheduler
    # is stepped by the "epoch" hook registered below, once per epoch update.
    scheduler = MultiStepLR(model.optimizer, milestones=[5*i + 1 for i in range(0, 80)], gamma=0.95)

    # FUNCTIONAL LOSS
    functional = get_functional_loss(config=loss_config, mode=mode, model=model, **kwargs)
    print(functional)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    # Flush the logger every 20 "loss" updates and checkpoint every 400.
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda logger, data: model.save(f"{RESULTS_DIR}/model.pth"), feature="loss", freq=400)
    logger.add_hook(lambda logger, data: scheduler.step(), feature="epoch", freq=1)
    # Let the loss register its own hooks on the logger created above.
    functional.logger_hooks(logger)

    # DATA LOADING
    ood_images = load_ood(ood_path=f'{BASE_DIR}/data/ood_images/')
    train_loader, val_loader, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal], batch_size=batch_size,
        # train_buildings=["almena"], val_buildings=["almena"],
    )
    test_set, test_images = load_test("rgb", "normal")
    logger.images(test_images, "images", resize=128)
    logger.images(torch.cat(ood_images, dim=0), "ood_images", resize=128)

    # TRAINING
    for epoch in range(0, 800):
        # Use distinct plot names on the first epoch of a pretrained run
        # (presumably to keep the starting model's outputs separate).
        preds_name = "start_preds" if epoch == 0 and pretrained else "preds"
        ood_name = "start_ood" if epoch == 0 and pretrained else "ood"
        plot_images(model, logger, test_set, dest_task="normal", ood_images=ood_images,
            loss_models=functional.plot_losses, preds_name=preds_name, ood_name=ood_name
        )
        logger.update("epoch", epoch)
        logger.step()

        # Each epoch consumes at most train_step / val_step batches.
        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        # Validation runs before training, so epoch 0 measures the starting model.
        val_metrics = model.predict_with_metrics(val_set, loss_fn=functional, logger=logger)
        train_metrics = model.fit_with_metrics(train_set, loss_fn=functional, logger=logger)
        functional.logger_update(logger, train_metrics, val_metrics)
Пример #2
0
def main():
    """Walk chains of transfer networks starting from RGB and log the results.

    Loads the first batch of the 'almena' building as ground truth for eight
    tasks. It then walks the graph of ``finetuned_transfers`` depth-first,
    starting from ``rgb`` and never revisiting a task already on the current
    path. Every prediction that lands on the ``normal`` task is sent to Visdom.
    When one of the path-scoring searches (``search_small`` / ``search_full``)
    is run, the per-path losses are plotted as bar charts at the end.
    """

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    # NOTE(review): other call sites register hooks as `lambda logger, data: ...`;
    # confirm this one-argument hook matches how VisdomLogger invokes hooks.
    logger.add_hook(lambda x: logger.step(), feature="loss", freq=25)

    resize = 256
    ood_images = load_ood()[0]
    tasks = [
        get_task(name) for name in [
            'rgb', 'normal', 'principal_curvature', 'depth_zbuffer',
            'sobel_edges', 'reshading', 'keypoints3d', 'keypoints2d'
        ]
    ]

    # The code assumes the batch holds one tensor per task, in `tasks` order.
    test_loader = torch.utils.data.DataLoader(TaskDataset(['almena'], tasks),
                                              batch_size=64,
                                              num_workers=12,
                                              shuffle=False,
                                              pin_memory=True)
    imgs = next(iter(test_loader))
    gt = {tasks[i].name: batch.cuda() for i, batch in enumerate(imgs)}
    num_plot = 4  # number of (prediction, ground truth) pairs per plot

    logger.images(ood_images, "x", nrow=2, resize=resize)
    edges = finetuned_transfers

    def get_nbrs(task, edges):
        """Return the transfers in `edges` whose source task is `task`."""
        return [e for e in edges if e.src_task == task]

    # Recursion depth bound; the root call is depth 1.
    max_depth = 10
    # task name -> list of (loss, path prefix), one entry per scored path.
    mse_dict = defaultdict(list)

    def log_against_gt(x, task, prefix):
        """Plot `x` interleaved with `task`'s ground truth and record its loss."""
        interleave = torch.stack([
            val for pair in zip(x[:num_plot], gt[task.name][:num_plot])
            for val in pair
        ])
        logger.images(interleave.clamp(max=1, min=0),
                      prefix,
                      nrow=2,
                      resize=resize)
        mse, _ = task.loss_func(x, gt[task.name])
        mse_dict[task.name].append(
            (mse.detach().data.cpu().numpy(), prefix))

    def search_small(x, task, prefix, visited, depth, endpoint):
        """Score `x` if it is a normal prediction, then recurse into neighbours.

        Returns whether `task` equals `endpoint`.
        """
        if task.name == 'normal':
            log_against_gt(x, task, prefix)

        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)

            if transfer.dest_task.name not in visited and depth < max_depth:
                visited.add(transfer.dest_task.name)
                search_small(preds, transfer.dest_task, next_prefix,
                             visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)

        return endpoint == task

    def search_full(x, task, prefix, visited, depth, endpoint):
        """Recurse through neighbours, scoring each normal prediction as made.

        Returns whether `task` equals `endpoint`.
        """
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name == 'normal':
                # `preds` and its ground truth belong to the destination task,
                # so they are scored with the destination task's loss.
                log_against_gt(preds, transfer.dest_task, next_prefix)
            if transfer.dest_task.name not in visited and depth < max_depth:
                visited.add(transfer.dest_task.name)
                search_full(preds, transfer.dest_task, next_prefix,
                            visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)

        return endpoint == task

    def search(x, task, prefix, visited, depth):
        """Depth-first walk from `task`, plotting every normal prediction."""
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name == 'normal':
                logger.images(preds.clamp(max=1, min=0),
                              next_prefix,
                              nrow=2,
                              resize=resize)
            if transfer.dest_task.name not in visited and depth < max_depth:
                visited.add(transfer.dest_task.name)
                search(preds, transfer.dest_task, next_prefix, visited,
                       depth + 1)
                visited.remove(transfer.dest_task.name)

    with torch.no_grad():
        # Seed `visited` with the task name {'rgb'}; set('rgb') would be the
        # characters {'r', 'g', 'b'} and would let the walk re-enter rgb.
        # search_full(gt['rgb'], TASK_MAP['rgb'], 'x', {'rgb'}, 1, TASK_MAP['normal'])
        search(ood_images, get_task('rgb'), 'x', {'rgb'}, 1)

    # Bar chart of path losses per task, lowest loss first. mse_dict is only
    # populated by search_small / search_full.
    for name, mse_list in mse_dict.items():
        mse_list.sort()
        print(name)
        print(mse_list)
        # Pad a single-entry list with a dummy bar (presumably the bar plot
        # needs at least two entries -- confirm).
        if len(mse_list) == 1:
            mse_list.append((0, '-'))
        rownames = [pair[1] for pair in mse_list]
        data = [pair[0] for pair in mse_list]
        print(data, rownames)
        logger.bar(data, f'{name}_path_mse', opts={'rownames': rownames})