def main():

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="energy",
                    freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy",
        freq=100)

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    reality = RealityTask('ood',
                          dataset=ImagePairDataset(data_dir=OOD_DIR,
                                                   resize=(256, 256)),
                          tasks=[tasks.rgb, tasks.rgb],
                          batch_size=28)

    # reality = RealityTask('almena',
    #     dataset=TaskDataset(
    #         buildings=['almena'],
    #         tasks=task_list,
    #     ),
    #     tasks=task_list,
    #     batch_size=8
    # )

    graph = TaskGraph(tasks=[reality, *task_list], batch_size=28)

    task = tasks.rgb
    images = [reality.task_data[task]]
    sources = [task.name]

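    # Push the RGB input through every outgoing transfer; any prediction that
    # is not already in normal space is remapped there so the grid of images
    # is directly comparable.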
    for _, edge in sorted(
        ((edge.dest_task.name, edge) for edge in graph.adj[task])):
        if isinstance(edge.src_task, RealityTask): continue

        x = edge(reality.task_data[edge.src_task])

        if edge.dest_task != tasks.normal:
            edge2 = graph.edge_map[(edge.dest_task.name, tasks.normal.name)]
            x = edge2(x)

        images.append(x.clamp(min=0, max=1))
        sources.append(edge.dest_task.name)

    logger.images_grouped(images, ", ".join(sources), resize=256)
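Example #2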
def main():

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(lambda logger, data: logger.plot(data["energy"], "free_energy"), feature="energy", freq=100)

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    reality = RealityTask('almena', 
        dataset=TaskDataset(
            buildings=['almena'],
            tasks=task_list,
        ),
        tasks=task_list,
        batch_size=8
    )

    graph = TaskGraph(
        tasks=[reality, *task_list],
        batch_size=8
    )

    for task in graph.tasks:
        if isinstance(task, RealityTask): continue

        images = [reality.task_data[task].clamp(min=0, max=1)]
        sources = [task.name]
        for _, edge in sorted(((edge.src_task.name, edge) for edge in graph.in_adj[task])):
            if isinstance(edge.src_task, RealityTask): continue

            x = edge(reality.task_data[edge.src_task])
            images.append(x.clamp(min=0, max=1))
            sources.append(edge.src_task.name)

        logger.images_grouped(images, ", ".join(sources), resize=192)
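Example #3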
def main():

    logger = VisdomLogger("train", env=JOB)

    test_loader, test_images = load_doom()
    test_images = torch.cat(test_images, dim=0)
    src_task, dest_task = get_task("rgb"), get_task("normal")

    print(test_images.shape)
    src_task.plot_func(test_images, "images", logger, resize=128)

    paths = ["F(RC(x))", "F(EC(a(x)))", "n(x)", "npstep(x)"]

    for path_str in paths:

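        # Parse a composition string like "F(RC(x))" into the ordered transfer
        # list [RC, F], i.e. innermost function applied first.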
        path_list = path_str.replace(')', '').split('(')[::-1][1:]
        path = [TRANSFER_MAP[name] for name in path_list]

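        # Wrap the frozen transfer chain in a TrainableModel so its batched
        # predict() machinery can be reused; no gradients are needed.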
        class PathModel(TrainableModel):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                with torch.no_grad():
                    for f in path:
                        x = f(x)
                return x

            def loss(self, pred, target):
                loss = torch.tensor(0.0, device=pred.device)
                return loss, (loss.detach(), )

        model = PathModel()

        preds = model.predict(test_loader)
        dest_task.plot_func(preds, f"preds_{path_str}", logger, resize=128)
        transform = transforms.ToPILImage()
        os.makedirs(f"{BASE_DIR}/doom_processed/{path_str}/video2",
                    exist_ok=True)

        for image, file in zip(preds, test_loader.dataset.files):
            image = transform(image.cpu())
            filename = file.split("/")[-1]
            print(filename)
            image.save(
                f"{BASE_DIR}/doom_processed/{path_str}/video2/{filename}")

    print(preds.shape)
Example #4
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from models import DecodingModel, DataParallelModel
from torchvision import models
from logger import Logger, VisdomLogger
from utils import *

import IPython

import transforms


# LOGGING
logger = VisdomLogger("encoding", server="35.230.67.129", port=8000, env=JOB)
logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
logger.add_hook(lambda x: logger.plot(x, "Encoding Loss", opts=dict(ymin=0)), feature="loss", freq=50)


""" 
Computes the changed images, given a a specified perturbation, standard deviation weighting, 
and epsilon.
"""


def compute_changed_images(images, perturbation, std_weights, epsilon=EPSILON):

    perturbation_w2 = perturbation * std_weights
    # Assumed completion (the original snippet is truncated here): rescale the
    # weighted perturbation onto an epsilon-radius ball and apply it.
    perturbation_zc = perturbation_w2 / perturbation_w2.norm() * epsilon
    changed_images = (images + perturbation_zc).clamp(min=0, max=1)
    return changed_images
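Example #5
def normal_from_depth(file_name):
    # Assumed reconstruction: the snippet's opening lines were cut off. This
    # worker is consumed by pool.imap_unordered(normal_from_depth, files)
    # below and must return (file_name, surface_normals);
    # cached_pixel_to_ray_array is assumed to be a module-level global.
    depth_map = load_depth_map_in_m(file_name)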
    points_in_camera = points_in_camera_coords(depth_map, cached_pixel_to_ray_array)
    surface_normals = surface_normal(points_in_camera, pixel_width=depth_map.shape[1], pixel_height=depth_map.shape[0])
    surface_normals = 1.0 - surface_normals*0.5 - 0.5
    return file_name, surface_normals


if __name__ == "__main__":
    files = glob.glob("mount/sintel/training/depth/*/frame*.dpt")
    depth_map = load_depth_map_in_m(files[0])
    normals_list = []

    logger = VisdomLogger("train", env=JOB)

    with Pool() as pool:
        for i, (file, normals) in enumerate(
            pool.imap_unordered(normal_from_depth, files)
        ):

            filename = "normal" + file.split('/')[-1][5:][:-4] + ".png"
            filename = file[:-len(filename)] + "/" + filename
            print(i, len(files), filename)
            plt.imsave(filename, normals)
            normals_list.append(normals)
            print(len(normals_list))

    # normals_list = torch.FloatTensor(np.array(normals_list).astype(np.float32))
    # normals_list = normals_list.permute((0, 3, 1, 2))
Example #6
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=args.dist_backend)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    checkpoint_handler = CheckpointHandler(save_folder=args.save_folder,
                                           best_val_model_name=args.best_val_model_name,
                                           checkpoint_per_iteration=args.checkpoint_per_iteration,
                                           save_n_recent_models=args.save_n_recent_models)

    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir, args.log_params)

    if args.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            args.continue_from = latest_checkpoint

    if args.continue_from:  # Starting from previous model
        state = TrainingState.load_state(state_path=args.continue_from)
        model = state.model
        if args.finetune:
            state.init_finetune_states(args.epochs)

        if main_proc and args.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)  # assumed continuation; the snippet is cut off here
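Example #7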
    if args.distributed:
        if args.gpu_rank:
            torch.cuda.set_device(int(args.gpu_rank))
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results = torch.Tensor(args.epochs)
    cer_results = torch.Tensor(args.epochs)
    wer_results = torch.Tensor(args.epochs)
    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None
    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = model.labels
        audio_conf = model.audio_conf
        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get('epoch', 1)) - 1  # assumed continuation; the snippet is cut off here
Example #8
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    use_patches=False,
    patch_frac=None,
    patch_size=64,
    patch_sigma=0,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    train_step, val_step = train_step // 16, val_step // 16
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, train_subset, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID and not use_baseline: graph.load_weights(cont)
    pre_gan = pre_gan or 1
    discriminator = Discriminator(energy_loss.losses['gan'],
                                  frac=patch_frac,
                                  size=(patch_size if use_patches else 224),
                                  sigma=patch_sigma,
                                  use_patches=use_patches)
    if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')

    logger.add_hook(partial(jointplot, loss_type="gan_subset"),
                    feature="val_gan_subset",
                    freq=1)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")

        if visualize: return

        graph.train()
        discriminator.train()

        for _ in range(0, train_step):
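            # Generator (graph) updates only begin once the discriminator has
            # had pre_gan warm-up epochs.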
            if epochs > pre_gan:

                train_loss = energy_loss(graph,
                                         discriminator=discriminator,
                                         realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])

                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

                # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse'])
                # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1])
                # train.step()

                # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan'])
                # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2])
                # train.step()

                # graph.step(train_loss1 + train_loss2)
                # logger.update("loss", train_loss1 + train_loss2)

                # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_id'])
                # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1])
                # graph.step(train_loss1)

                # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_ood'])
                # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2])
                # graph.step(train_loss2)

                # train_loss3 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan'])
                # train_loss3 = sum([train_loss3[loss_name] for loss_name in train_loss3])
                # graph.step(train_loss3)

                # logger.update("loss", train_loss1 + train_loss2 + train_loss3)
                # train.step()

                # graph fooling loss
                # n(~x), and y^ (128 subset)
                # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train])
                # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2])
                # train_loss = train_loss1 + train_loss2

            warmup = 5  # if epochs < pre_gan else 1
            for i in range(warmup):
                # y_hat = graph.sample_path([tasks.normal(size=512)], reality=train_subset)
                # n_x = graph.sample_path([tasks.rgb(size=512), tasks.normal(size=512)], reality=train)

                y_hat = graph.sample_path([tasks.normal], reality=train_subset)
                n_x = graph.sample_path(
                    [tasks.rgb(blur_radius=6),
                     tasks.normal(blur_radius=6)],
                    reality=train)

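                # coeff_hook scales gradients flowing back through the n(x)
                # path by `coeff`, damping the adversarial signal that reaches
                # the graph's mapping networks.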
                def coeff_hook(coeff):
                    def fun1(grad):
                        return coeff * grad.clone()

                    return fun1

                logit_path1 = discriminator(y_hat.detach())
                coeff = 0.1
                path_value2 = n_x * 1.0
                path_value2.register_hook(coeff_hook(coeff))
                logit_path2 = discriminator(path_value2)
                binary_label = torch.Tensor(
                    [1] * logit_path1.size(0) +
                    [0] * logit_path2.size(0)).float().cuda()
                gan_loss = nn.BCEWithLogitsLoss(reduction='mean')(torch.cat(
                    (logit_path1, logit_path2), dim=0).view(-1), binary_label)
                discriminator.discriminator.step(gan_loss)

                logger.update("train_gan_subset", gan_loss)
                logger.update("val_gan_subset", gan_loss)

                # print ("Gan loss: ", (-gan_loss).data.cpu().numpy())
                train.step()
                train_subset.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       discriminator=discriminator,
                                       realities=[val, train_subset])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        if epochs > pre_gan:
            energy_loss.logger_update(logger)
            logger.step()

            if logger.data["train_subset_val_ood : y^ -> n(~x)"][
                    -1] < best_ood_val_loss:
                best_ood_val_loss = logger.data[
                    "train_subset_val_ood : y^ -> n(~x)"][-1]
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
Example #9
def main():
    parser = argparse.ArgumentParser(description="-----[CNN-classifier]-----")
    parser.add_argument("--similarity",
                        default=0.0,
                        type=float,
                        help="similarity threshold")
    parser.add_argument(
        "--similarity_representation",
        default="W2V",
        help="similarity representation. Available methods: CNN, AUTOENCODER, W2V")
    parser.add_argument(
        "--mode",
        default="train",
        help="train: train (with test) a model / test: test saved models")
    parser.add_argument(
        "--model",
        default="cnn",
        help="Type of model to use. Default: CNN. Available models: CNN, RNN")
    parser.add_argument("--embedding",
                        default="w2v",
                        help="available embedings: random, w2v")
    parser.add_argument("--dataset",
                        default="MR",
                        help="available datasets: MR, TREC")
    parser.add_argument("--encoder",
                        default=None,
                        help="Path to encoder model file")
    parser.add_argument("--decoder",
                        default=None,
                        help="Path to decoder model file")
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='batch size for training [default: 32]')
    parser.add_argument(
        '--selection-size',
        type=int,
        default=32,
        help='selection size for selection function [default: 32]')
    parser.add_argument("--save_model",
                        default="F",
                        help="whether saving model or not (T/F)")
    parser.add_argument("--early_stopping",
                        default="F",
                        help="whether to apply early stopping(T/F)")
    parser.add_argument("--epoch",
                        default=100,
                        type=int,
                        help="number of max epoch")
    parser.add_argument("--learning_rate",
                        default=0.1,
                        type=float,
                        help="learning rate")
    parser.add_argument("--dropout_embed",
                        default=0.2,
                        type=float,
                        help="Dropout embed probability. Default: 0.2")
    parser.add_argument("--dropout_model",
                        default=0.4,
                        type=float,
                        help="Dropout model probability. Default: 0.4")
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='Cuda device to run on')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disable the gpu')
    parser.add_argument(
        "--scorefn",
        default="entropy",
        help="available scoring functions: entropy, random, egl")
    parser.add_argument('--average',
                        type=int,
                        default=1,
                        help='Number of runs to average [default: 1]')
    parser.add_argument('--hnodes',
                        type=int,
                        default=256,
                        help='Number of nodes in the hidden layer(s)')
    parser.add_argument('--hlayers',
                        type=int,
                        default=1,
                        help='Number of hidden layers')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-5,
                        help='Value of weight_decay')
    parser.add_argument('--no-log',
                        action='store_true',
                        default=False,
                        help='Disable logging')
    parser.add_argument('--data_path',
                        default='/data/stud/jorgebjorn/data',
                        type=str,
                        help='Dir path to datasets')
    parser.add_argument('--c', default='', type=str, help='Comment to run ')

    options = parser.parse_args()

    params["DATA_PATH"] = options.data_path  #TODO rewrite?

    getattr(utils, "read_{}".format(options.dataset))()

    data["vocab"] = sorted(
        list(
            set([
                w for sent in data["train_x"] + data["dev_x"] + data["test_x"]
                for w in sent
            ])))
    data["classes"] = sorted(list(set(data["train_y"])))
    data["word_to_idx"] = {w: i for i, w in enumerate(data["vocab"])}

    params_local = {
        "SIMILARITY_THRESHOLD": options.similarity,
        "SIMILARITY_REPRESENTATION": options.similarity_representation,
        "DATA_PATH": options.data_path,
        "MODEL": options.model,
        "EMBEDDING": options.embedding,
        "DATASET": options.dataset,
        "SAVE_MODEL": bool(options.save_model == "T"),
        "EARLY_STOPPING": bool(options.early_stopping == "T"),
        "EPOCH": options.epoch,
        "LEARNING_RATE": options.learning_rate,
        "MAX_SENT_LEN": max([
            len(sent)
            for sent in data["train_x"] + data["dev_x"] + data["test_x"]
        ]),
        "SELECTION_SIZE": options.selection_size,
        "BATCH_SIZE": options.batch_size,
        "WORD_DIM": 300,
        "VOCAB_SIZE": len(data["vocab"]),
        "CLASS_SIZE": len(data["classes"]),
        "FILTERS": [3, 4, 5],
        "FILTER_NUM": [100, 100, 100],
        "DROPOUT_EMBED": options.dropout_embed,
        "DROPOUT_MODEL": options.dropout_model,
        "DEVICE": options.device,
        "NO_CUDA": options.no_cuda,
        "SCORE_FN": options.scorefn,
        "N_AVERAGE": options.average,
        "HIDDEN_SIZE": options.hnodes,
        "HIDDEN_LAYERS": options.hlayers,
        "WEIGHT_DECAY": options.weight_decay,
        "LOG": not options.no_log,
        "ENCODER": options.encoder,
        "DECODER": options.decoder,
        "C": options.c,
    }

    params.update(params_local)

    if params["LOG"]:
        logger_name = 'SS/{}_{}_{}_{}_{}'.format(
            getpass.getuser(),
            datetime.datetime.now().strftime("%d-%m-%y_%H:%M"),
            options.dataset, params["C"],
            str(uuid.uuid4())[:4])
        global_logger["lg"] = VisdomLogger(
            logger_name, "{}_{}".format(params["SIMILARITY_THRESHOLD"],
                                        params["SIMILARITY_REPRESENTATION"]))
        # global_logger["lg"].parameters_summary()
        print("visdom logger OK")
        # quit()

    params["CUDA"] = (not params["NO_CUDA"]) and torch.cuda.is_available()
    del params["NO_CUDA"]

    if params["CUDA"]:
        torch.cuda.set_device(params["DEVICE"])

    if params["EMBEDDING"] == "w2v":
        utils.load_word2vec()

    encoder = rnnae.EncoderRNN()
    # decoder = rnnae.DecoderRNN()
    decoder = rnnae.AttnDecoderRNN()
    feature_extractor = CNN2()

    if params["ENCODER"] != None:
        print("Loading encoder")
        encoder.load_state_dict(torch.load(params["ENCODER"]))

    if params["DECODER"] != None:
        print("Loading decoder")
        decoder.load_state_dict(torch.load(params["DECODER"]))

    if params["CUDA"]:
        encoder, decoder, feature_extractor = (
            encoder.cuda(), decoder.cuda(), feature_extractor.cuda())

    models["ENCODER"] = encoder
    models["DECODER"] = decoder
    models["FEATURE_EXTRACTOR"] = feature_extractor

    print("=" * 20 + "INFORMATION" + "=" * 20)
    for key, value in params.items():
        print("{}: {}".format(key.upper(), value))

    if params["EMBEDDING"] == "random" and params["SIMILARITY_THRESHOLD"] > 0:
        print("********** WARNING *********")
        print("Random embedding makes similarity threshold have no effect. \n")

    print("=" * 20 + "TRAINING STARTED" + "=" * 20)
    train.active_train()
    print("=" * 20 + "TRAINING FINISHED" + "=" * 20)
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    fast=False, batch_size=None, 
    subset_size=None, max_epochs=800, **kwargs,
):
        
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size
    )
    train_step, val_step = 24, 12
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([tasks.rgb,])
    print(train_step, val_step)
    
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, 
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=1e-6, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            # logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            # logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        
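    # Ramp the learning rate up by 1.2x every epoch.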
        for param_group in graph.optimizer.param_groups:
            param_group['lr'] *= 1.2
            print ("LR: ", param_group['lr'])

        logger.step()
Example #11
def main():

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    reality = RealityTask('almena',
                          dataset=TaskDataset(buildings=['almena'],
                                              tasks=[
                                                  tasks.rgb, tasks.normal,
                                                  tasks.principal_curvature,
                                                  tasks.depth_zbuffer
                                              ]),
                          tasks=[
                              tasks.rgb, tasks.normal,
                              tasks.principal_curvature, tasks.depth_zbuffer
                          ],
                          batch_size=4)

    graph = TaskGraph(
        tasks=[reality, *task_list],
        anchored_tasks=[reality, tasks.rgb],
        reality=reality,
        batch_size=4,
        edges_exclude=[
            ('almena', 'normal'),
            ('almena', 'principal_curvature'),
            ('almena', 'depth_zbuffer'),
            # ('rgb', 'keypoints3d'),
            # ('rgb', 'edge_occlusion'),
        ],
        initialize_first_order=True,
    )

    graph.p.compile(torch.optim.Adam, lr=4e-2)
    graph.estimates.compile(torch.optim.Adam, lr=1e-2)

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="energy",
                    freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy",
        freq=100)
    logger.add_hook(lambda logger, data: graph.plot_estimates(logger),
                    feature="epoch",
                    freq=32)
    logger.add_hook(lambda logger, data: graph.update_paths(logger),
                    feature="epoch",
                    freq=32)

    graph.plot_estimates(logger)
    graph.plot_paths(logger,
                     dest_tasks=[
                         tasks.normal, tasks.principal_curvature,
                         tasks.depth_zbuffer
                     ],
                     show_images=False)

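    # Each epoch: estimate the free energy from 16 samples (no gradients) and
    # take an averaging step; the hooks above plot estimates every 32 epochs.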
    for epochs in range(0, 4000):
        logger.update("epoch", epochs)

        with torch.no_grad():
            free_energy = graph.free_energy(sample=16)
            graph.averaging_step()

        logger.update("energy", free_energy)
Example #12
                                    0.0)
        targets = [binary.random(n=TARGET_SIZE) for i in range(len(images))]
        torch.save((perturbation.data, images.data, targets),
                   f"{output_path}/{k}.pth")


if __name__ == "__main__":

    model = DataParallelModel(
        DecodingModel(n=DIST_SIZE, distribution=transforms.training))
    params = itertools.chain(model.module.classifier.parameters(),
                             model.module.features[-1].parameters())
    optimizer = torch.optim.Adam(params, lr=2.5e-3)
    init_data("data/amnesia")

    logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
    logger.add_hook(lambda data: logger.plot(data, "train_loss"),
                    feature="loss",
                    freq=50)
    logger.add_hook(lambda data: logger.plot(data, "train_bits"),
                    feature="bits",
                    freq=50)
    logger.add_hook(
        lambda x: model.save("output/train_test.pth", verbose=True),
        feature="epoch",
        freq=100)
    model.save("output/train_test.pth", verbose=True)

    files = glob.glob("data/amnesia/*.pth")
    for i, save_file in enumerate(
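Example #13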
def main(
	loss_config="multiperceptual", mode="standard", visualize=False,
	fast=False, batch_size=None,
	subset_size=None, max_epochs=5000, dataaug=False, **kwargs,
):

	# CONFIG
	wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"data_aug":dataaug,"lr":"3e-5",
		"n_gauss":1,"distribution":"laplace"})

	batch_size = batch_size or (4 if fast else 64)
	energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

	# DATA LOADING
	train_dataset, val_dataset, val_noaug_dataset, train_step, val_step = load_train_val_merging(
		energy_loss.get_tasks("train_c"),
		batch_size=batch_size, fast=fast,
		subset_size=subset_size,
	)
	test_set = load_test(energy_loss.get_tasks("test"))

	ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/')
	ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/')
	ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35)

	train = RealityTask("train_c", train_dataset, batch_size=batch_size, shuffle=True)      # distorted and undistorted 
	val = RealityTask("val_c", val_dataset, batch_size=batch_size, shuffle=True)            # distorted and undistorted 
	val_noaug = RealityTask("val", val_noaug_dataset, batch_size=batch_size, shuffle=True)  # no augmentation
	test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))

	ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])                                  ## standard ood set - natural
	ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,])          ## synthetic distortion images used for sig training 
	ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,])                      ## unseen syn distortions

	# GRAPH
	realities = [train, val, val_noaug, test, ood, ood_syn_aug, ood_syn]
	graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
		freeze_list=energy_loss.freeze_list,
	)
	graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

	# LOGGING
	logger = VisdomLogger("train", env=JOB)    # fake visdom logger
	logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
	energy_loss.logger_hooks(logger)

	graph.eval()
	path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
	for reality_paths, reality_images in path_values.items():
		wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)

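	# Baseline pass: compute val and train losses under no_grad before any
	# weights are updated, so the step-0 metrics reflect the pretrained graph.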
	with torch.no_grad():
		for reality in [val,val_noaug]:
			for _ in range(0, val_step):
				val_loss = energy_loss(graph, realities=[reality])
				val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
				reality.step()
				logger.update("loss", val_loss)

		for _ in range(0, train_step):
			train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True)
			train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
			train.step()
			logger.update("loss", train_loss)

	energy_loss.logger_update(logger)

	data = logger.step()
	del data['loss']
	data = {k: v[0] for k, v in data.items()}
	wandb.log(data, step=0)

	# TRAINING
	for epochs in range(0, max_epochs):

		logger.update("epoch", epochs)

		graph.train()
		for _ in range(0, train_step):
			train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True)
			train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
			graph.step(train_loss)
			train.step()
			logger.update("loss", train_loss)

		graph.eval()
		for reality in [val,val_noaug]:
			for _ in range(0, val_step):
				with torch.no_grad():
					val_loss = energy_loss(graph, realities=[reality])
					val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
				reality.step()
				logger.update("loss", val_loss)

		energy_loss.logger_update(logger)

		data = logger.step()
		del data['loss']
		del data['epoch']
		data = {k: v[0] for k, v in data.items()}
		wandb.log(data, step=epochs+1)

		if epochs % 10 == 0:
			graph.save(f"{RESULTS_DIR}/graph.pth")
			torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")

		if epochs % 25 == 0:
			path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
			for reality_paths, reality_images in path_values.items():
				wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1)

	graph.save(f"{RESULTS_DIR}/graph.pth")
	torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")
Example #14
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    use_patches=False,
    patch_size=64,
    use_baseline=False,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not use_baseline and not USE_RAID:
        graph.load_weights(cont)

    pre_gan = pre_gan or 1
    discriminator = Discriminator(energy_loss.losses['gan'],
                                  size=(patch_size if use_patches else 224),
                                  use_patches=use_patches)
    # if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 80):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        discriminator.train()

        for _ in range(0, train_step):
            if epochs > pre_gan:
                energy_loss.train_iter += 1
                train_loss = energy_loss(graph,
                                         discriminator=discriminator,
                                         realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

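            # The discriminator takes 5 steps per generator step during the
            # warm-up epochs, then 1 step afterwards.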
            for i in range(5 if epochs <= pre_gan else 1):
                train_loss2 = energy_loss(graph,
                                          discriminator=discriminator,
                                          realities=[train])
                discriminator.step(train_loss2)
                train.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       discriminator=discriminator,
                                       realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
Example #15
def main(
    loss_config="multiperceptual", mode="winrate", visualize=False,
    fast=False, batch_size=None,
    subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):


    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )

    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2, 2

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)

    if fast:
        realities = [train, val]
    else:
        test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville'])
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        realities = [train, val, test]
        # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:
        # ood_set = load_ood(energy_loss.get_tasks("ood"))
        # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
        # realities.append(ood)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
        freeze_list=energy_loss.freeze_list,
        initialize_from_transfer=False,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(RESULTS_DIR, exist_ok=True)
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step*4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        for _ in range(0, train_step*4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            train_loss, mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        logger.step()
Example #16
def main(
        fast=False,
        subset_size=None,
        early_stopping=float('inf'),
        mode='standard',
        max_epochs=800,
        **kwargs,
):

    early_stopping = 8  # hard-coded: overrides the early_stopping argument above
    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [
                    ("f(y)", "z^"),
                ],
            },
        },
        "plots": {
            "ID":
            dict(size=256,
                 realities=("test", "ood"),
                 paths=[
                     "y",
                     "z^",
                     "f(y)",
                 ]),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)

    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]
    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        if logger.data["val_mse : f(y) -> z^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : f(y) -> z^"][
                -1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break

    early_stopping = 50  # reset for the second training phase
    # CONFIG
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    graph.edge(
        tasks.normal,
        tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][
                -1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(f"{RESULTS_DIR}/graph.pth")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break
Example #17
def main(src_task, dest_task, fast=False):

    # src_task, dest_task = get_task(src_task), get_task(dest_task)
    model = DataParallelModel(get_model(src_task, dest_task).cuda())
    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=25)
    logger.add_hook(partial(jointplot, loss_type="mse_loss"),
                    feature="val_mse_loss",
                    freq=1)
    logger.add_hook(lambda logger, data: model.save(
        f"{RESULTS_DIR}/{src_task.name}2{dest_task.name}.pth"),
                    feature="loss",
                    freq=400)

    # DATA LOADING
    train_loader, val_loader, train_step, val_step = load_train_val(
        [src_task, dest_task],
        batch_size=48,
        train_buildings=["almena"],
        val_buildings=["almena"])
    train_step, val_step = 5, 5
    test_set, test_images = load_test(src_task, dest_task)
    src_task.plot_func(test_images, "images", logger, resize=128)

    for epochs in range(0, 800):

        logger.update("epoch", epochs)
        plot_images(model, logger, test_set, dest_task, show_masks=True)

        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        (train_mse_data, ) = model.fit_with_metrics(train_set,
                                                    loss_fn=dest_task.norm,
                                                    logger=logger)
        logger.update("train_mse_loss", train_mse_data)
        (val_mse_data, ) = model.predict_with_metrics(val_set,
                                                      loss_fn=dest_task.norm,
                                                      logger=logger)
        logger.update("val_mse_loss", val_mse_data)
Example #18
def main(
    loss_config="multiperceptual",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    resume=False,
    learning_rate=3e-5,
    subset_size=None,
    max_epochs=2000,
    dataaug=False,
    **kwargs,
):

    # CONFIG
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate
    })

    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if resume:
        graph.load_weights('/workspace/shared/results_test_1/graph.pth')
        graph.optimizer.load_state_dict(
            torch.load('/workspace/shared/results_test_1/opt.pth'))

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    ######## baseline computation
    if not resume:
        graph.eval()
        with torch.no_grad():
            for _ in range(0, val_step):
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            for _ in range(0, train_step):
                train_loss, _, _ = energy_loss(graph, realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                train.step()
                logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        data = logger.step()
        del data['loss']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=0)

        path_values = energy_loss.plot_paths(graph,
                                             logger,
                                             realities,
                                             prefix="")
        for reality_paths, reality_images in path_values.items():
            wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)
    ###########

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss, coeffs, avg_grads = energy_loss(
                graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs)
        wandb.log(coeffs, step=epochs)
        wandb.log(avg_grads, step=epochs)

        if epochs % 5 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        if epochs % 10 == 0:
            path_values = energy_loss.plot_paths(graph,
                                                 logger,
                                                 realities,
                                                 prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]},
                          step=epochs + 1)

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
Example #19
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):

        file = self.files[idx]
        res = []
        # One seed for all tasks so paired images get identical augmentations
        seed = random.randint(0, int(1e10))  # randint requires integer bounds
        for task in self.tasks:
            image = task.file_loader(file, seed=seed)
            if image.shape[0] == 1: image = image.expand(3, -1, -1)  # grayscale -> 3 channels
            res.append(image)
        return tuple(res)




if __name__ == "__main__":

    logger = VisdomLogger("data", env=JOB)
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)],
        batch_size=32,
    )
    print ("created dataset")
    logger.add_hook(lambda logger, data: logger.step(), freq=32)

    for i, _ in enumerate(train_dataset):
        logger.update("epoch", i)
def main(
    loss_config="geonet", mode="geonet", visualize=False,
    fast=False, batch_size=None, 
    subset_size=None, early_stopping=float('inf'),
    max_epochs=800, **kwargs,
):

    print(kwargs)
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING

    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    print(energy_loss.tasks)
    print('train tasks', energy_loss.get_tasks("train"))
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    print('val tasks', energy_loss.get_tasks("val"))
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    print('test tasks', energy_loss.get_tasks("test"))
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    print('ood tasks', energy_loss.get_tasks("ood"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))
    print('done')


    # GRAPH
    realities = [train, val, test, ood]
    graph = GeoNetTaskGraph(tasks=energy_loss.tasks, realities=realities, pretrained=True)

    # Alternatives tried: pretrained=False; grad_clip=2.0; lr in {1e-6, 1e-5}; weight_decay=0
    graph.compile(torch.optim.Adam, grad_clip=5.0, lr=4e-5, weight_decay=2e-6, amsgrad=True)


    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)
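    # Early stopping: stop_idx counts epochs since the last val-loss improvement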
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        try:
            energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        except Exception:
            pass  # plotting is best-effort; never abort training over it
        if visualize: return

        graph.train()
        print('training for', train_step, 'steps')
        for _ in range(0, train_step):
            try:
                train_loss = energy_loss(graph, realities=[train])
                train_loss = sum([train_loss[loss_name] for loss_name in train_loss])

                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)
            except NotImplementedError:
                pass

        graph.eval()
        for _ in range(0, val_step):
            try:
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[val])
                    val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)
            except NotImplementedError:
                pass

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        try:
            curr_val_loss = (logger.data["val_mse : N(rgb) -> normal"][-1] +
                             logger.data["val_mse : D(rgb) -> depth"][-1])
            if curr_val_loss < best_val_loss:
                print("Better val loss, reset stop_idx: ", stop_idx)
                best_val_loss, stop_idx = curr_val_loss, 0
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
                graph.save(f"{RESULTS_DIR}/graph.pth")
        except (KeyError, NotImplementedError):
            pass  # metrics may be missing before the first full logger update
    
        if stop_idx >= early_stopping:
            print ("Stopping training now")
            return
def visdom_logger():
    lg = VisdomLogger()
    return lg
Example #22
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    if not fast:
        # Shorten the effective epoch length
        train_step, val_step = train_step // (16 * 4), val_step // 16
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    # ood = RealityTask.from_static("ood", ood_set, [energy_loss.get_tasks("ood")])

    # GRAPH
    realities = [train, val, test, train_subset]  # ood excluded for this experiment
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    # logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')
    energy_losses = []
    mse_losses = []
    pearsonr_vals = []
    percep_losses = defaultdict(list)
    pearson_percep = defaultdict(list)
    # # TRAINING
    # for epochs in range(0, max_epochs):

    # 	logger.update("epoch", epochs)
    # 	if epochs == 0:
    # 		energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
    # 	# if visualize: return

    # 	graph.eval()
    # 	for _ in range(0, val_step):
    # 		with torch.no_grad():
    # 			losses = energy_loss(graph, realities=[val])
    # 			all_perceps = [losses[loss_name] for loss_name in losses if 'percep' in loss_name ]
    # 			energy_avg = sum(all_perceps) / len(all_perceps)
    # 			for loss_name in losses:
    # 				if 'percep' not in loss_name: continue
    # 				percep_losses[loss_name] += [losses[loss_name].data.cpu().numpy()]
    # 			mse = losses['mse']
    # 			energy_losses.append(energy_avg.data.cpu().numpy())
    # 			mse_losses.append(mse.data.cpu().numpy())

    # 		val.step()
    # 	mse_arr = np.array(mse_losses)
    # 	energy_arr = np.array(energy_losses)
    # 	# logger.scatter(mse_arr - mse_arr.mean() / np.std(mse_arr), \
    # 	# 	energy_arr - energy_arr.mean() / np.std(energy_arr), \
    # 	# 	'unit_normal_all', opts={'xlabel':'mse','ylabel':'energy'})
    # 	logger.scatter(mse_arr, energy_arr, \
    # 		'mse_energy_all', opts={'xlabel':'mse','ylabel':'energy'})
    # 	pearsonr, p = scipy.stats.pearsonr(mse_arr, energy_arr)
    # 	logger.text(f'pearsonr = {pearsonr}, p = {p}')
    # 	pearsonr_vals.append(pearsonr)
    # 	logger.plot(pearsonr_vals, 'pearsonr_all')
    # 	for percep_name in percep_losses:
    # 		percep_loss_arr = np.array(percep_losses[percep_name])
    # 		logger.scatter(mse_arr, percep_loss_arr, f'mse_energy_{percep_name}', \
    # 			opts={'xlabel':'mse','ylabel':'energy'})
    # 		pearsonr, p = scipy.stats.pearsonr(mse_arr, percep_loss_arr)
    # 		pearson_percep[percep_name] += [pearsonr]
    # 		logger.plot(pearson_percep[percep_name], f'pearson_{percep_name}')

    # energy_loss.logger_update(logger)
    # if logger.data['val_mse : n(~x) -> y^'][-1] < best_ood_val_loss:
    # 	best_ood_val_loss = logger.data['val_mse : n(~x) -> y^'][-1]
    # 	energy_loss.plot_paths(graph, logger, realities, prefix="best")

    energy_mean_by_blur = []
    energy_std_by_blur = []
    mse_mean_by_blur = []
    mse_std_by_blur = []
    for blur_size in np.arange(0, 10, 0.5):
        tasks.rgb.blur_radius = blur_size if blur_size > 0 else None  # 0 disables blurring
        train_subset.step()
        # energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        energy_losses = []
        mse_losses = []
        for epochs in range(subset_size // batch_size):
            with torch.no_grad():
                losses = energy_loss(graph,
                                     realities=[train_subset],
                                     reduce=False)
                # Per-sample energy = mean over all perceptual-consistency losses
                all_perceps = np.stack([
                    losses[loss_name].data.cpu().numpy()
                    for loss_name in losses if 'percep' in loss_name
                ])
                energy_losses += list(all_perceps.mean(0))
                mse_losses += list(losses['mse'].data.cpu().numpy())
            train_subset.step()
        mse_losses = np.array(mse_losses)
        energy_losses = np.array(energy_losses)
        logger.text(
            f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}'
        )
        logger.scatter(mse_losses, energy_losses,
                       f'mse_energy, blur = {blur_size}',
                       opts={'xlabel': 'mse', 'ylabel': 'energy'})

        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
        mse_mean_by_blur += [mse_losses.mean()]
        mse_std_by_blur += [np.std(mse_losses)]

    logger.plot(energy_mean_by_blur, 'energy_mean_by_blur')
    logger.plot(energy_std_by_blur, 'energy_std_by_blur')
    logger.plot(mse_mean_by_blur, 'mse_mean_by_blur')
    logger.plot(mse_std_by_blur, 'mse_std_by_blur')
Example #23
def main(
    mode="standard", visualize=False,
    pretrained=True, finetuned=False, batch_size=None, 
    **kwargs,
):

    configs = {
        "VISUALS3_rgb2normals2x_multipercep8_winrate_standardized_upd": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            cont="mount/shared/results_LBP_multipercep8_winrate_standardized_upd_3/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_reshadetarget": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_reshadebaseline": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_depthtarget": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_depthbaseline": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_baseline": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_multipercep": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
            cont="mount/shared/results_LBP_multipercep_32/graph.pth",
        ),
        "VISUALS3_rgb2x2normals_baseline": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=False,
            test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x2normals_finetuned": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x_plots2normals_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=True,
            test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_baseline": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=False,
            test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_finetuned": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=True,
            test=True, ood=True, ood_full=False,
        ),
    }

    # configs = {
    #   "VISUALS_rgb2normals2x_latv2": dict(
    #       loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #       cont="mount/shared/results_LBP_multipercep_latv2_10/graph.pth",
    #   ),
    #   "VISUALS_rgb2normals2x_lat_winrate": dict(
    #     loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #     cont="mount/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    #   ),
    #   "VISUALS_rgb2normals2x_multipercep": dict(
    #     loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #     cont="mount/shared/results_LBP_multipercep_32/graph.pth",
    #   ),
    #   "VISUALS_rgb2normals2x_rndv2": dict(
    #     loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #     cont="mount/shared/results_LBP_multipercep_rnd_11/graph.pth",
    #   ),
    #   "VISUALS_rgb2normals2x_baseline": dict(
    #     loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #     cont=None,
    #   ),
    #   "VISUALS_rgb2x2normals_baseline": dict(
    #     loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
    #     finetuned=False,
    #   ),
    #   "VISUALS_rgb2x2normals_finetuned": dict(
    #     loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
    #     finetuned=True,
    #   ),
    #   "VISUALS_y2normals_baseline": dict(
    #     loss_configs=["y2normals_plots", "y2normals_plots_size320", "y2normals_plots_size384", "y2normals_plots_size448", "y2normals_plots_size512"],
    #     finetuned=False,
    #   ),
    #   "VISUALS_y2normals_finetuned": dict(
    #     loss_configs=["y2normals_plots", "y2normals_plots_size320", "y2normals_plots_size384", "y2normals_plots_size448", "y2normals_plots_size512"],
    #     finetuned=True,
    #   ),
    #   "VISUALS_rgb2x_baseline": dict(
    #     loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
    #     finetuned=False,
    #   ),
    #   "VISUALS_rgb2x_finetuned": dict(
    #     loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
    #     finetuned=True,
    #   ),
    # }

    for i in range(0, 5):

        # All configs share the same five resolutions; use the first config's
        # loss_configs list to select the i-th resolution name
        config = configs[list(configs.keys())[0]]

        finetuned = config.get("finetuned", False)
        loss_configs = config["loss_configs"]

        loss_config = loss_configs[i]

        batch_size = batch_size or 32
        energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

        # DATA LOADING 1
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)

        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']
        ood_set = load_ood(ood_tasks, sample=4)
        print(ood_tasks)
        
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        ood = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 2
        ood_tasks = list(set([tasks.rgb] + [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']))
        test_set = load_test(ood_tasks, sample=2)
        ood_set = load_ood(ood_tasks)

        test2 = RealityTask.from_static("test", test_set, ood_tasks)
        ood2 = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 3
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)
        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']

        ood_loader = torch.utils.data.DataLoader(
            ImageDataset(tasks=ood_tasks, data_dir=f"{SHARED_DIR}/ood_images"),
            batch_size=32,
            num_workers=32, shuffle=False, pin_memory=True
        )
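        # Take the first two batches: the first reused as "test", the second as "ood"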
        data = list(itertools.islice(ood_loader, 2))
        test_set = data[0]
        ood_set = data[1]
        
        test3 = RealityTask.from_static("test", test_set, ood_tasks)
        ood3 = RealityTask.from_static("ood", ood_set, ood_tasks)




        for name, config in configs.items():

            finetuned = config.get("finetuned", False)
            loss_configs = config["loss_configs"]
            cont = config.get("cont", None)

            logger = VisdomLogger("train", env=name, delete=True if i == 0 else False)
            if config.get("test", False):                
                # GRAPH
                realities = [test, ood]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None: graph.load_weights(cont)

                # LOGGING
                energy_loss.plot_paths_errors(graph, logger, realities, prefix=loss_config)

    
            logger = VisdomLogger("train", env=name + "_ood", delete=True if i == 0 else False)
            if config.get("ood", False):
                # GRAPH
                realities = [test2, ood2]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None: graph.load_weights(cont)

                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_oodfull", delete=True if i == 0 else False)
            if config.get("oodfull", False):

                # GRAPH
                realities = [test3, ood3]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None: graph.load_weights(cont)

                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)
Example #24
def main(
    fast=False,
    batch_size=None,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 32)
    energy_loss = get_energy_loss(config="consistency_two_path",
                                  mode="standard",
                                  **kwargs)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)

    # DATA LOADING
    # Frames are named image<N>.png; sort numerically by frame index
    video_dataset = ImageDataset(
        files=sorted(
            glob.glob("mount/taskonomy_house_tour/original/image*.png"),
            key=lambda x: int(os.path.basename(x)[5:-4])),
        return_tuple=True,
        resize=720,
    )
    video = RealityTask("video",
                        video_dataset, [
                            tasks.rgb,
                        ],
                        batch_size=batch_size,
                        shuffle=False)

    # GRAPHS
    graph_baseline = TaskGraph(tasks=energy_loss.tasks + [video],
                               finetuned=False)
    graph_baseline.compile(torch.optim.Adam,
                           lr=3e-5,
                           weight_decay=2e-6,
                           amsgrad=True)

    graph_finetuned = TaskGraph(tasks=energy_loss.tasks + [video],
                                finetuned=True)
    graph_finetuned.compile(torch.optim.Adam,
                            lr=3e-5,
                            weight_decay=2e-6,
                            amsgrad=True)

    graph_conservative = TaskGraph(tasks=energy_loss.tasks + [video],
                                   finetuned=True)
    graph_conservative.compile(torch.optim.Adam,
                               lr=3e-5,
                               weight_decay=2e-6,
                               amsgrad=True)
    graph_conservative.load_weights(
        f"{MODELS_DIR}/conservative/conservative.pth")

    graph_ood_conservative = TaskGraph(tasks=energy_loss.tasks + [video],
                                       finetuned=True)
    graph_ood_conservative.compile(torch.optim.Adam,
                                   lr=3e-5,
                                   weight_decay=2e-6,
                                   amsgrad=True)
    graph_ood_conservative.load_weights(
        f"{SHARED_DIR}/results_2F_grounded_1percent_gt_twopath_512_256_crop_7/graph_grounded_1percent_gt_twopath.pth"
    )

    graphs = {
        "baseline": graph_baseline,
        "finetuned": graph_finetuned,
        "conservative": graph_conservative,
        "ood_conservative": graph_ood_conservative,
    }

    inv_transform = transforms.ToPILImage()
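    # For each graph variant, record per-frame consistency energy and the zoom
    # (crop size) at which the frame was rendered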
    data = {key: {"losses": [], "zooms": []} for key in graphs}
    size = 256
    for batch in range(0, 700):

        if batch * batch_size > len(video_dataset.files): break

        frac = (batch * batch_size * 1.0) / len(video_dataset.files)
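        # Zoom schedule: shrink 256 -> 128 over the first 30% of the video,
        # return to 256 by the halfway point, then zoom out toward 720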
        if frac < 0.3:
            size = int(256.0 - 128 * frac / 0.3)
        elif frac < 0.5:
            size = int(128.0 + 128 * (frac - 0.3) / 0.2)
        else:
            size = int(256.0 + (720 - 256) * (frac - 0.5) / 0.5)
        # video.reload()
        size = (size // 32) * 32  # round down to a multiple of 32 for the network stride
        print(size)
        video.step()
        video.task_data[tasks.rgb] = resize(
            video.task_data[tasks.rgb].to(DEVICE), size).data
        print(video.task_data[tasks.rgb].shape)

        with torch.no_grad():

            for i, img in enumerate(video.task_data[tasks.rgb]):
                inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                    f"mount/taskonomy_house_tour/distorted/image{batch*batch_size + i}.png"
                )

            for name, graph in graphs.items():
                normals = graph.sample_path([tasks.rgb, tasks.normal],
                                            reality=video)
                normals2 = graph.sample_path(
                    [tasks.rgb, tasks.principal_curvature, tasks.normal],
                    reality=video)

                for i, img in enumerate(normals):
                    energy, _ = tasks.normal.norm(normals[i:(i + 1)],
                                                  normals2[i:(i + 1)])
                    data[name]["losses"] += [energy.data.cpu().numpy().mean()]
                    data[name]["zooms"] += [size]
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/normals_{name}/image{batch*batch_size + i}.png"
                    )

                for i, img in enumerate(normals2):
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/path2_{name}/image{batch*batch_size + i}.png"
                    )

    with open("mount/taskonomy_house_tour/data.pkl", 'wb') as f:
        pickle.dump(data, f)
    os.system("bash ~/scaling/scripts/create_vids.sh")
Example #25
def main(
    loss_config="baseline_normal",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    learning_rate=5e-4,
    subset_size=None,
    max_epochs=5000,
    dataaug=False,
    **kwargs,
):

    # CONFIG
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate
    })

    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam,
                  lr=learning_rate,
                  weight_decay=2e-6,
                  amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)

        if (epochs % 100 == 0) or (epochs % 10 == 0 and epochs < 30):
            path_values = energy_loss.plot_paths(graph,
                                                 logger,
                                                 realities,
                                                 prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]},
                          step=epochs)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph,
                                     realities=[train],
                                     compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs)

        # save model and opt state every 10 epochs
        if epochs % 10 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        # lower lr after 1500 epochs
        if epochs == 1500:
            graph.optimizer.param_groups[0]['lr'] = 3e-5

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    max_epochs=800,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print("Train step: ", train_step, "Val step: ", val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=True,
        freeze_list=[functional_transfers.a, functional_transfers.RC],
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)

    activated_triangles = set()
    triangle_energy = {
        "triangle1_mse": float('inf'),
        "triangle2_mse": float('inf')
    }
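    # Triangle "energy" = consistency loss / direct-supervision loss; each
    # training step picks whichever triangle currently has the higher energy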

    logger.add_hook(partial(jointplot, loss_type="energy"),
                    feature="val_energy",
                    freq=1)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # loss_type = random.choice(["triangle1_mse", "triangle2_mse"])
            loss_type = max(triangle_energy, key=triangle_energy.get)

            activated_triangles.add(loss_type)
            train_loss = energy_loss(graph,
                                     realities=[train],
                                     loss_types=[loss_type])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()

            if loss_type == "triangle1_mse":
                consistency_tr1 = energy_loss.metrics["train"][
                    "triangle1_mse : F(RC(x)) -> n(x)"][-1]
                error_tr1 = energy_loss.metrics["train"][
                    "triangle1_mse : n(x) -> y^"][-1]
                triangle_energy["triangle1_mse"] = float(consistency_tr1 /
                                                         error_tr1)

            elif loss_type == "triangle2_mse":
                consistency_tr2 = energy_loss.metrics["train"][
                    "triangle2_mse : S(a(x)) -> n(x)"][-1]
                error_tr2 = energy_loss.metrics["train"][
                    "triangle2_mse : n(x) -> y^"][-1]
                triangle_energy["triangle2_mse"] = float(consistency_tr2 /
                                                         error_tr2)

            print("Triangle energy: ", triangle_energy)
            logger.update("loss", train_loss)

            energy = sum(triangle_energy.values())
            if energy < float('inf'):  # both triangles measured at least once
                logger.update("train_energy", energy)
                logger.update("val_energy", energy)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       realities=[val],
                                       loss_types=list(activated_triangles))
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])

            val.step()
            logger.update("loss", val_loss)

        activated_triangles = set()
        energy_loss.logger_update(logger)
        logger.step()
Example #27
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    use_optimizer=False,
    subset_size=None,
    max_epochs=800,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([
        tasks.rgb,
    ])
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Point the rgb -> normal edge at the baseline checkpoint and reload it
    edge = graph.edge(tasks.rgb, tasks.normal)
    edge.model = None
    edge.path = "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth"
    edge.load_model()
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    if use_optimizer:
        optimizer = torch.load(
            "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth"
        )
        graph.optimizer.load_state_dict(optimizer.state_dict())

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        # if logger.data["val_mse : y^ -> n(~x)"][-1] < best_ood_val_loss:
        # 	best_ood_val_loss = logger.data["val_mse : y^ -> n(~x)"][-1]
        # 	energy_loss.plot_paths(graph, logger, realities, prefix="best")

        logger.step()
Example #28
def test_dataloader():
    args = parser.parse_args()

    # Set seeds for determinism
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")
    if args.mixed_precision and not args.cuda:
        raise ValueError(
            'If using mixed precision training, CUDA must be enabled!')
    args.distributed = args.world_size > 1
    main_proc = True
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None

    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)

        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get(
                'epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
                start_iter = 0
            else:
                start_iter += 1
            avg_loss = int(package.get('avg_loss', 0))
            loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], \
                                                     package['wer_results']
            if main_proc and args.visdom:  # Add previous scores to visdom graph
                visdom_logger.load_previous_values(start_epoch, package)
            if main_proc and args.tensorboard:  # Previous scores to tensorboard logs
                tensorboard_logger.load_previous_values(start_epoch, package)

        print("Loading label from %s" % args.labels_path)
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))
    else:
        print("must load model!!!")
        exit()

    # decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)

    # Bucket utterances into batches of similar length to reduce padding
    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)

    for i, data in enumerate(train_loader, start=start_iter):
        # Get the initial inputs
        inputs, targets, input_percentages, target_sizes = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        inputs = inputs.to(device)
        size = inputs.size()
        print(size)
        # Initialize the model and freeze the pretrained DeepSpeech weights
        model = M_Noise_Deepspeech(package, size)
        for para in model.deepspeech_net.parameters():
            para.requires_grad = False
        model = model.to(device)

        # Get the initial (reference) outputs
        out_star = model.deepspeech_net(inputs, input_sizes)[0]
        out_star = out_star.transpose(0, 1)  # TxNxH
        float_out_star = out_star.float()
        break

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=1e-5)
    print(model)
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING

    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    train_step, val_step = 4, 4  # override: run only a few steps per epoch
    print(train_step, val_step)

    train = RealityTask("train",
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set,
                                   energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      freeze_list=energy_loss.freeze_list,
                      finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID: graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph,
                               logger,
                               realities,
                               prefix=f"epoch_{epochs}")
        if visualize: return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        # Re-select which perceptual losses to train, based on val winrates
        energy_loss.select_losses(val)
        if epochs != 0:
            energy_loss.logger_update(logger)
        else:
            energy_loss.metrics = {}
        logger.step()

        logger.text(f"Chosen losses: {energy_loss.chosen_losses}")
        logger.text(f"Percep winrate: {energy_loss.percep_winrate}")
        graph.train()
        for _ in range(0, train_step):
            train_loss2 = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss2.values())

            graph.step(train_loss)
            train.step()

            logger.update("loss", train_loss)
Example #30
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=args.dist_backend)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    checkpoint_handler = CheckpointHandler(save_folder=args.save_folder,
                                           best_val_model_name=args.best_val_model_name,
                                           checkpoint_per_iteration=args.checkpoint_per_iteration,
                                           save_n_recent_models=args.save_n_recent_models)

    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id + "-" + str(int(time.time())), args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id + "-" + str(int(time.time())), args.log_dir, args.log_params)

    if args.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            args.continue_from = latest_checkpoint

    if args.continue_from:  # Starting from previous model
        state = TrainingState.load_state(state_path=args.continue_from)
        model = state.model
        if args.finetune:
            state.init_finetune_states(args.epochs)

        if main_proc and args.visdom:  # Add previous scores to visdom graph