Example #1
import os

import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API

# Project helpers (Model, tps_parameters, transformation_parameters, find_ckpt,
# dataset_map_train/test, load_and_preprocess_image, batch_colour_map, save,
# save_python_files, write_hyperparameters, initialize_uninitialized) are
# assumed to be imported from the project.
def main(arg):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu)
    model_save_dir = "./experiments/" + arg.name + "/"

    with tf.variable_scope("Data_prep"):
        if arg.mode == 'train':
            raw_dataset = dataset_map_train[arg.dataset](arg)

        elif arg.mode == 'predict':
            raw_dataset = dataset_map_test[arg.dataset](arg)

        else:
            raise ValueError(f"Unknown mode: {arg.mode}")

        dataset = raw_dataset.map(load_and_preprocess_image, num_parallel_calls=arg.data_parallel_calls)
        dataset = dataset.batch(arg.bn, drop_remainder=True).repeat(arg.epochs)  # attribute access, consistent with arg.bn below
        iterator = dataset.make_one_shot_iterator()
        b_images = iterator.get_next()

        # Duplicate the batch (presumably for the two transformation streams;
        # matches the 2 * arg.bn used for the TPS parameters below).
        orig_images = tf.tile(b_images, [2, 1, 1, 1])

        scal = tf.placeholder(dtype=tf.float32, shape=(), name='scal_placeholder')
        tps_scal = tf.placeholder(dtype=tf.float32, shape=(), name='tps_placeholder')
        rot_scal = tf.placeholder(dtype=tf.float32, shape=(), name='rot_scal_placeholder')
        off_scal = tf.placeholder(dtype=tf.float32, shape=(), name='off_scal_placeholder')
        scal_var = tf.placeholder(dtype=tf.float32, shape=(), name='scal_var_placeholder')
        augm_scal = tf.placeholder(dtype=tf.float32, shape=(), name='augm_scal_placeholder')

        tps_param_dic = tps_parameters(2 * arg.bn, scal, tps_scal, rot_scal, off_scal, scal_var)
        tps_param_dic.augm_scal = augm_scal

    ctr = 0
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    with tf.Session(config=config) as sess:

        model = Model(orig_images, arg, tps_param_dic)
        tvar = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=tvar)
        merged = tf.summary.merge_all()

        if arg.mode == 'train':
            if arg.load:
                ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/')
                saver.restore(sess, ckpt)
            else:
                save_python_files(save_dir=model_save_dir + 'bin/')
                write_hyperparameters(arg.toDict(), model_save_dir)
                sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter("./summaries/" + arg.name, graph=sess.graph)

        elif arg.mode == 'predict':
            ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/')
            saver.restore(sess, ckpt)

        initialize_uninitialized(sess)
        while True:
            print('iteration %d' % ctr)
            try:
                feed = transformation_parameters(arg, ctr, no_transform=(arg.mode == 'predict'))  # no transformations in predict mode
                trf = {scal: feed.scal, tps_scal: feed.tps_scal,
                       scal_var: feed.scal_var, rot_scal: feed.rot_scal, off_scal: feed.off_scal, augm_scal: feed.augm_scal}
                ctr += 1
                if arg.mode == 'train':
                    if np.mod(ctr, arg.summary_interval) == 0:
                        merged_summary = sess.run(merged, feed_dict=trf)
                        writer.add_summary(merged_summary, global_step=ctr)

                    _, loss = sess.run([model.optimize, model.loss], feed_dict=trf)
                    if np.mod(ctr, arg.save_interval) == 0:
                        saver.save(sess, model_save_dir + 'saved_model/save_net.ckpt', global_step=ctr)

                elif arg.mode == 'predict':
                    img, img_rec, mu, heat_raw = sess.run([model.image_in, model.reconstruct_same_id, model.mu,
                                                           batch_colour_map(model.part_maps)], feed_dict=trf)

                    save(img, mu, ctr)

            except tf.errors.OutOfRangeError:
                print("End of training.")
                break
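
Example #1 drives every transformation strength through a scalar tf.placeholder that receives a fresh value each iteration via feed_dict. A minimal sketch of that pattern in isolation (TF 1.x graph mode; a toy loss stands in for the real model.loss):

import tensorflow as tf  # TF 1.x, or tf.compat.v1 with eager execution disabled

scal = tf.placeholder(dtype=tf.float32, shape=(), name='scal_placeholder')
toy_loss = tf.square(scal - 1.0)  # stand-in for the real model.loss

with tf.Session() as sess:
    for ctr in range(3):
        # a new scalar is fed every step, mirroring transformation_parameters
        print(sess.run(toy_loss, feed_dict={scal: 1.0 + 0.1 * ctr}))
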
Example #2
import os

import numpy as np
import torch
import wandb
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import ConcatDataset, DataLoader, random_split

# Project helpers (Model, get_dataset, load_model, save_model, count_parameters,
# keypoint_metric, visualize_results, visualize_predictions,
# write_hyperparameters) are assumed to be imported from the project.
def main(arg):
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)
    torch.backends.cudnn.deterministic = True

    # Get args
    bn = arg.batch_size
    mode = arg.mode
    name = arg.name
    load_from_ckpt = arg.load_from_ckpt
    lr = arg.lr
    epochs = arg.epochs
    device = torch.device(
        'cuda:' + str(arg.gpu[0]) if torch.cuda.is_available() else 'cpu')
    arg.device = device

    # Load Datasets and DataLoader
    if arg.dataset != "mix":
        dataset = get_dataset(arg.dataset)
    if arg.dataset == 'pennaction':
        # init_dataset = dataset(size=arg.reconstr_dim, action_req=["tennis_serve", "tennis_forehand", "baseball_pitch",
        #                                                           "baseball_swing", "jumping_jacks", "golf_swing"])
        init_dataset = dataset(size=arg.reconstr_dim)
        splits = [
            int(len(init_dataset) * 0.8),
            len(init_dataset) - int(len(init_dataset) * 0.8)
        ]
        train_dataset, test_dataset = random_split(
            init_dataset, splits, generator=torch.Generator().manual_seed(42))
    elif arg.dataset == 'deepfashion':
        train_dataset = dataset(size=arg.reconstr_dim, train=True)
        test_dataset = dataset(size=arg.reconstr_dim, train=False)
    elif arg.dataset == 'human36':
        init_dataset = dataset(size=arg.reconstr_dim)
        splits = [
            int(len(init_dataset) * 0.8),
            len(init_dataset) - int(len(init_dataset) * 0.8)
        ]
        train_dataset, test_dataset = random_split(
            init_dataset, splits, generator=torch.Generator().manual_seed(42))
    elif arg.dataset == 'mix':
        # add pennaction
        dataset_pa = get_dataset("pennaction")
        init_dataset_pa = dataset_pa(size=arg.reconstr_dim,
                                     action_req=[
                                         "tennis_serve", "tennis_forehand",
                                         "baseball_pitch", "baseball_swing",
                                         "jumping_jacks", "golf_swing"
                                     ],
                                     mix=True)
        splits_pa = [
            int(len(init_dataset_pa) * 0.8),
            len(init_dataset_pa) - int(len(init_dataset_pa) * 0.8)
        ]
        train_dataset_pa, test_dataset_pa = random_split(
            init_dataset_pa,
            splits_pa,
            generator=torch.Generator().manual_seed(42))
        # add deepfashion
        dataset_df = get_dataset("deepfashion")
        train_dataset_df = dataset_df(size=arg.reconstr_dim,
                                      train=True,
                                      mix=True)
        test_dataset_df = dataset_df(size=arg.reconstr_dim,
                                     train=False,
                                     mix=True)
        # add human36
        dataset_h36 = get_dataset("human36")
        init_dataset_h36 = dataset_h36(size=arg.reconstr_dim, mix=True)
        splits_h36 = [
            int(len(init_dataset_h36) * 0.8),
            len(init_dataset_h36) - int(len(init_dataset_h36) * 0.8)
        ]
        train_dataset_h36, test_dataset_h36 = random_split(
            init_dataset_h36,
            splits_h36,
            generator=torch.Generator().manual_seed(42))
        # Concatenate all (note: the PennAction splits built above are not
        # included in this concatenation)
        train_datasets = [train_dataset_df, train_dataset_h36]
        test_datasets = [test_dataset_df, test_dataset_h36]
        train_dataset = ConcatDataset(train_datasets)
        test_dataset = ConcatDataset(test_datasets)
    else:
        raise ValueError(f"Unknown dataset: {arg.dataset}")

    train_loader = DataLoader(train_dataset,
                              batch_size=bn,
                              shuffle=True,
                              num_workers=4)
    test_loader = DataLoader(test_dataset,
                             batch_size=bn,
                             shuffle=True,
                             num_workers=4)

    if mode == 'train':
        # Make new directory
        model_save_dir = '../results/' + arg.dataset + '/' + name
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
            os.makedirs(model_save_dir + '/summary')

        # Save Hyperparameters
        write_hyperparameters(arg.toDict(), model_save_dir)

        # Define Model
        model = Model(arg)
        if len(arg.gpu) > 1:
            model = torch.nn.DataParallel(model, device_ids=arg.gpu)
        model.to(device)
        if load_from_ckpt:
            model = load_model(model, model_save_dir, device).to(device)
        print(f'GPUs: {arg.gpu}')
        print(f'Number of Parameters: {count_parameters(model)}')

        # Define Optimizer
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=arg.weight_decay)
        scheduler = ReduceLROnPlateau(optimizer,
                                      factor=0.2,
                                      threshold=1e-4,
                                      patience=6)

        # Log with wandb
        wandb.init(project='Disentanglement', config=arg, name=arg.name)
        wandb.watch(model, log='all')

        # Make Training
        with torch.autograd.set_detect_anomaly(False):
            for epoch in range(epochs + 1):
                # Train on Train Set
                model.train()
                # model.mode = 'train'
                for step, (original, keypoints) in enumerate(train_loader):
                    bn = original.shape[0]
                    original, keypoints = original.to(device), keypoints.to(
                        device)
                    # Forward Pass
                    ground_truth_images, img_reconstr, mu, L_inv, part_map_norm, heat_map, heat_map_norm, total_loss = model(
                        original)
                    # Track Mean and Precision Matrix
                    mu_norm = torch.mean(torch.norm(
                        mu[:bn], p=1, dim=2)).cpu().detach().numpy()
                    L_inv_norm = torch.mean(
                        torch.linalg.norm(L_inv[:bn], ord='fro',
                                          dim=[2, 3])).cpu().detach().numpy()
                    wandb.log({"Part Means": mu_norm})
                    wandb.log({"Precision Matrix": L_inv_norm})
                    # Zero out gradients
                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()
                    # Track Loss
                    wandb.log({"Training Loss": total_loss.cpu()})
                    # Track Metric
                    score, mu, L_inv, part_map_norm, heat_map = keypoint_metric(
                        mu, keypoints, L_inv, part_map_norm, heat_map,
                        arg.reconstr_dim)
                    wandb.log({"Metric Train": score})
                    # Track progress
                    if step % 10000 == 0 and bn >= 4:
                        for step_, (original,
                                    keypoints) in enumerate(test_loader):
                            with torch.no_grad():
                                original, keypoints = original.to(
                                    device), keypoints.to(device)
                                ground_truth_images, img_reconstr, mu, L_inv, part_map_norm,\
                                heat_map, heat_map_norm, total_loss = model(original)
                                # Visualize Results
                                score, mu, L_inv, part_map_norm, heat_map = keypoint_metric(
                                    mu, keypoints, L_inv, part_map_norm,
                                    heat_map, arg.reconstr_dim)
                                img = visualize_results(
                                    ground_truth_images, img_reconstr, mu,
                                    L_inv, part_map_norm, heat_map, keypoints,
                                    model_save_dir + '/summary/', epoch,
                                    arg.background)
                                wandb.log({
                                    "Summary at step" + str(step):
                                    [wandb.Image(img)]
                                })
                                save_model(model, model_save_dir)
                                if step_ == 0:
                                    break

                # Evaluate on Test Set
                model.eval()
                val_score = torch.zeros(1)
                val_loss = torch.zeros(1)
                for step, (original, keypoints) in enumerate(test_loader):
                    with torch.no_grad():
                        original, keypoints = original.to(
                            device), keypoints.to(device)
                        ground_truth_images, img_reconstr, mu, L_inv, part_map_norm, heat_map, heat_map_norm, total_loss = model(
                            original)
                        # Track Loss and Metric
                        score, mu, L_inv, part_map_norm, heat_map = keypoint_metric(
                            mu, keypoints, L_inv, part_map_norm, heat_map,
                            arg.reconstr_dim)
                        val_score += score.cpu()
                        val_loss += total_loss.cpu()

                val_loss = val_loss / (step + 1)
                val_score = val_score / (step + 1)
                if epoch == 0:
                    best_score = val_score
                if val_score <= best_score:
                    best_score = val_score
                # Note: best_score is tracked, but the model is saved every epoch
                save_model(model, model_save_dir)
                scheduler.step(val_score)
                wandb.log({"Evaluation Loss": val_loss})
                wandb.log({"Metric Validation": val_score})

                # Track Progress & Visualization
                for step, (original, keypoints) in enumerate(test_loader):
                    with torch.no_grad():
                        original, keypoints = original.to(
                            device), keypoints.to(device)
                        ground_truth_images, img_reconstr, mu, L_inv, part_map_norm, heat_map, heat_map_norm, total_loss = model(
                            original)
                        score, mu, L_inv, part_map_norm, heat_map = keypoint_metric(
                            mu, keypoints, L_inv, part_map_norm, heat_map,
                            arg.reconstr_dim)
                        img = visualize_results(ground_truth_images,
                                                img_reconstr, mu, L_inv,
                                                part_map_norm, heat_map,
                                                keypoints,
                                                model_save_dir + '/summary/',
                                                epoch, arg.background)
                        wandb.log(
                            {"Summary_" + str(epoch): [wandb.Image(img)]})
                        if step == 0:
                            break

    elif mode == 'predict':
        # Make Directory for Predictions
        model_save_dir = '../results/' + arg.dataset + '/' + name
        # Don't use transformations
        arg.tps_scal = 0.
        arg.rot_scal = 0.
        arg.off_scal = 0.
        arg.scal_var = 0.
        arg.augm_scal = 1.
        arg.contrast = 0.
        arg.brightness = 0.
        arg.saturation = 0.
        arg.hue = 0.

        # Load Model and Dataset
        model = Model(arg).to(device)
        model = load_model(model, model_save_dir, device)
        model.eval()

        # Log with wandb
        # wandb.init(project='Disentanglement', config=arg, name=arg.name)
        # wandb.watch(model, log='all')

        # Predict on Dataset
        val_score = torch.zeros(1)
        for step, (original, keypoints) in enumerate(test_loader):
            with torch.no_grad():
                original, keypoints = original.to(device), keypoints.to(device)
                ground_truth_images, img_reconstr, mu, L_inv, part_map_norm, heat_map, heat_map_norm, total_loss = model(
                    original)
                score, mu_new, L_inv, part_map_norm_new, heat_map_new = keypoint_metric(
                    mu, keypoints, L_inv, part_map_norm, heat_map,
                    arg.reconstr_dim)
                if step == 0:
                    img = visualize_predictions(original, img_reconstr, mu_new,
                                                part_map_norm_new,
                                                heat_map_new, mu,
                                                part_map_norm, heat_map,
                                                model_save_dir)
                # wandb.log({"Prediction": [wandb.Image(img)]})
                val_score += score.cpu()

        val_score = val_score / (step + 1)
        print("Validation Score: ", val_score)
Example #3
import os

import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader

# Project helpers (Model, get_dataset, load_model, save_model, count_parameters,
# keypoint_metric, visualize_predictions, write_hyperparameters) are assumed to
# be imported from the project.
def main(arg):
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)
    torch.backends.cudnn.deterministic = True

    # Get args
    bn = arg.batch_size
    mode = arg.mode
    name = arg.name
    load_from_ckpt = arg.load_from_ckpt
    lr = arg.lr
    L_inv_scal = arg.L_inv_scal
    epochs = arg.epochs
    device = torch.device(
        'cuda:' + str(arg.gpu[0]) if torch.cuda.is_available() else 'cpu')
    arg.device = device

    # Load Datasets and DataLoader
    dataset = get_dataset(arg.dataset)
    if arg.dataset == 'pennaction':
        init_dataset = dataset(size=arg.reconstr_dim,
                               action_req=[
                                   "tennis_serve", "tennis_forehand",
                                   "baseball_pitch", "baseball_swing",
                                   "jumping_jacks", "golf_swing"
                               ])
        splits = [
            int(len(init_dataset) * 0.8),
            len(init_dataset) - int(len(init_dataset) * 0.8)
        ]
        train_dataset, test_dataset = torch.utils.data.random_split(
            init_dataset, splits)
    else:
        train_dataset = dataset(size=arg.reconstr_dim, train=True)
        test_dataset = dataset(size=arg.reconstr_dim, train=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=bn,
                              shuffle=True,
                              num_workers=4)
    test_loader = DataLoader(test_dataset,
                             batch_size=bn,
                             shuffle=False,
                             num_workers=4)

    if mode == 'train':
        # Make new directory
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
            os.makedirs(model_save_dir + '/summary')

        # Save Hyperparameters
        write_hyperparameters(arg.toDict(), model_save_dir)

        # Define Model
        model = Model(arg).to(device)
        if load_from_ckpt:
            model = load_model(model, model_save_dir, device).to(device)
        print(f'Number of Parameters: {count_parameters(model)}')

        # Define Optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Log with wandb
        wandb.init(project='Disentanglement', config=arg, name=arg.name)
        wandb.watch(model, log='all')

        # Make Training
        with torch.autograd.set_detect_anomaly(False):
            for epoch in range(epochs + 1):
                # Train on Train Set
                model.train()
                model.mode = 'train'
                for step, (original, keypoints) in enumerate(train_loader):
                    if epoch != 0:
                        model.L_sep = 0.
                    original, keypoints = original.to(device), keypoints.to(
                        device)
                    image_rec, reconstruct_same_id, loss, rec_loss, transform_loss, precision_loss, mu, L_inv, mu_original = model(
                        original)
                    mu_norm = torch.mean(torch.norm(
                        mu, p=1, dim=2)).cpu().detach().numpy()
                    L_inv_norm = torch.mean(
                        torch.linalg.norm(L_inv, ord='fro',
                                          dim=[2, 3])).cpu().detach().numpy()
                    # Track Mean and Precision Matrix
                    wandb.log({"Part Means": mu_norm})
                    wandb.log({"Precision Matrix": L_inv_norm})
                    # Zero out gradients
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # Track Loss
                    wandb.log({"Training Loss": loss})
                    # Track Metric
                    score = keypoint_metric(mu_original, keypoints)
                    wandb.log({"Metric Train": score})

                # Evaluate on Test Set
                model.eval()
                val_score = torch.zeros(1)
                val_loss = torch.zeros(1)
                for step, (original, keypoints) in enumerate(test_loader):
                    with torch.no_grad():
                        original, keypoints = original.to(
                            device), keypoints.to(device)
                        image_rec, reconstruct_same_id, loss, rec_loss, transform_loss, precision_loss, mu, L_inv, mu_original = model(
                            original)
                        # Track Loss and Metric
                        score = keypoint_metric(mu_original, keypoints)
                        val_score += score.cpu()
                        val_loss += loss.cpu()

                val_loss = val_loss / (step + 1)
                val_score = val_score / (step + 1)
                wandb.log({"Evaluation Loss": val_loss})
                wandb.log({"Metric Validation": val_score})

                # Track Progress & Visualization
                for step, (original, keypoints) in enumerate(test_loader):
                    with torch.no_grad():
                        model.mode = 'predict'
                        original, keypoints = original.to(
                            device), keypoints.to(device)
                        original_part_maps, mu_original, image_rec, _, part_maps, reconstruction = model(
                            original)  # the original code unpacked 'part_maps' twice; the shadowed first value is discarded
                        # img = visualize_predictions(original, original_part_maps, keypoints, reconstruction, image_rec[:original.shape[0]],
                        #                    image_rec[original.shape[0]:], part_maps[original.shape[0]:], part_maps[:original.shape[0]],
                        #                    L_inv_scal, model_save_dir + '/summary/', epoch, device, show_labels=False)
                        # if epoch % 5 == 0:
                        #     wandb.log({"Summary_" + str(epoch): [wandb.Image(img)]})
                        save_model(model, model_save_dir)

                        if step == 0:
                            break
                # Decrements
                # model.L_sep = arg.sig_decr * model.L_sep

    elif mode == 'predict':
        # Make Directory for Predictions
        model_save_dir = '../results/' + arg.dataset + '/' + name  # note: the train branch above saves to '../results/' + name
        # Don't use transformations
        arg.tps_scal = 0.
        arg.rot_scal = 0.
        arg.off_scal = 0.
        arg.scal_var = 0.
        arg.augm_scal = 1.
        arg.contrast = 0.
        arg.brightness = 0.
        arg.saturation = 0.
        arg.hue = 0.

        # Load Model and Dataset
        model = Model(arg).to(device)
        model = load_model(model, model_save_dir, device)
        model.eval()

        # Log with wandb
        # wandb.init(project='Disentanglement', config=arg, name=arg.name)
        # wandb.watch(model, log='all')

        # Predict on Dataset
        val_score = torch.zeros(1)
        for step, (original, keypoints) in enumerate(test_loader):
            with torch.no_grad():
                original, keypoints = original.to(device), keypoints.to(device)
                # Note: this unpacking matches Example #2's model outputs, not the
                # train-mode signature used earlier in this example.
                ground_truth_images, img_reconstr, mu, L_inv, part_map_norm, heat_map, heat_map_norm, total_loss = model(
                    original)
                score, mu_new, L_inv, part_map_norm_new, heat_map_new = keypoint_metric(
                    mu, keypoints, L_inv, part_map_norm, heat_map,
                    arg.reconstr_dim)
                if step == 0:
                    img = visualize_predictions(original, img_reconstr, mu_new,
                                                part_map_norm_new,
                                                heat_map_new, mu,
                                                part_map_norm, heat_map,
                                                model_save_dir)
                # wandb.log({"Prediction": [wandb.Image(img)]})
                val_score += score.cpu()

        val_score = val_score / (step + 1)
        print("Validation Score: ", val_score)
Example #4
import os

import kornia.augmentation as K  # assumed alias: the code below calls K.ColorJitter
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader

# Project helpers (Model, ImageDataset, load_deep_fashion_dataset,
# tps_parameters, make_input_tps_param, ThinPlateSpline, normalize, load_model,
# save_model, make_visualization, write_hyperparameters) are assumed to be
# imported from the project.
def main(arg):
    # Set random seeds
    torch.manual_seed(7)
    torch.cuda.manual_seed(7)
    np.random.seed(7)

    # Get args
    bn = arg.bn
    mode = arg.mode
    name = arg.name
    load_from_ckpt = arg.load_from_ckpt
    lr = arg.lr
    epochs = arg.epochs
    device = torch.device('cuda:' +
                          str(arg.gpu) if torch.cuda.is_available() else 'cpu')
    arg.device = device

    if mode == 'train':
        # Make new directory
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
            os.makedirs(model_save_dir + '/summary')

        # Save Hyperparameters
        write_hyperparameters(arg.toDict(), model_save_dir)

        # Define Model & Optimizer
        model = Model(arg).to(device)
        if load_from_ckpt:
            model = load_model(model, model_save_dir).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Log with wandb
        wandb.init(project='Disentanglement', config=arg, name=arg.name)
        wandb.watch(model, log='all')
        # Load Datasets and DataLoader
        train_data, test_data = load_deep_fashion_dataset()
        train_dataset = ImageDataset(np.array(train_data))
        test_dataset = ImageDataset(np.array(test_data))
        train_loader = DataLoader(train_dataset,
                                  batch_size=bn,
                                  shuffle=True,
                                  num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=bn, num_workers=4)

        # Make Training
        with torch.autograd.set_detect_anomaly(False):
            for epoch in range(epochs + 1):
                # Train on Train Set
                model.train()
                model.mode = 'train'
                for step, original in enumerate(train_loader):
                    original = original.to(device)
                    # Make transformations
                    tps_param_dic = tps_parameters(original.shape[0], arg.scal,
                                                   arg.tps_scal, arg.rot_scal,
                                                   arg.off_scal, arg.scal_var,
                                                   arg.augm_scal)
                    coord, vector = make_input_tps_param(tps_param_dic)
                    coord, vector = coord.to(device), vector.to(device)
                    image_spatial_t, _ = ThinPlateSpline(
                        original, coord, vector, original.shape[3], device)
                    image_appearance_t = K.ColorJitter(arg.brightness,
                                                       arg.contrast,
                                                       arg.saturation,
                                                       arg.hue)(original)
                    image_spatial_t, image_appearance_t = normalize(
                        image_spatial_t), normalize(image_appearance_t)
                    reconstruction, loss, rec_loss, equiv_loss, mu, L_inv = model(
                        original, image_spatial_t, image_appearance_t, coord,
                        vector)
                    mu_norm = torch.mean(torch.norm(
                        mu, p=1, dim=2)).cpu().detach().numpy()
                    L_inv_norm = torch.mean(
                        torch.linalg.norm(L_inv, ord='fro',
                                          dim=[2, 3])).cpu().detach().numpy()
                    wandb.log({"Part Means": mu_norm})
                    wandb.log({"Precision Matrix": L_inv_norm})
                    # Zero out gradients
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # Track Loss
                    if step == 0:
                        loss_log = [loss.detach().cpu()]
                        rec_loss_log = [rec_loss.detach().cpu()]
                    else:
                        loss_log.append(loss.detach().cpu())
                        rec_loss_log.append(rec_loss.detach().cpu())
                    # Running means over the epoch so far
                    training_loss = torch.stack(loss_log).mean()
                    training_rec_loss = torch.stack(rec_loss_log).mean()
                    wandb.log({"Training Loss": training_loss})
                    wandb.log({"Training Rec Loss": training_rec_loss})
                print(f'Epoch: {epoch}, Train Loss: {training_loss}')

                # Evaluate on Test Set
                model.eval()
                for step, original in enumerate(test_loader):
                    with torch.no_grad():
                        original = original.to(device)
                        tps_param_dic = tps_parameters(original.shape[0],
                                                       arg.scal, arg.tps_scal,
                                                       arg.rot_scal,
                                                       arg.off_scal,
                                                       arg.scal_var,
                                                       arg.augm_scal)
                        coord, vector = make_input_tps_param(tps_param_dic)
                        coord, vector = coord.to(device), vector.to(device)
                        image_spatial_t, _ = ThinPlateSpline(
                            original, coord, vector, original.shape[3], device)
                        image_appearance_t = K.ColorJitter(
                            arg.brightness, arg.contrast, arg.saturation,
                            arg.hue)(original)
                        image_spatial_t, image_appearance_t = normalize(
                            image_spatial_t), normalize(image_appearance_t)
                        reconstruction, loss, rec_loss, equiv_loss, mu, L_inv = model(
                            original, image_spatial_t, image_appearance_t,
                            coord, vector)
                        if step == 0:
                            loss_log = [loss.cpu()]
                        else:
                            loss_log.append(loss.cpu())
                evaluation_loss = torch.stack(loss_log).mean()
                wandb.log({"Evaluation Loss": evaluation_loss})
                print(f'Epoch: {epoch}, Test Loss: {evaluation_loss}')

                # Track Progress on the last test batch
                model.mode = 'predict'
                with torch.no_grad():
                    original, fmap_shape, fmap_app, reconstruction = model(
                        original, image_spatial_t, image_appearance_t, coord,
                        vector)
                    make_visualization(original, reconstruction,
                                       image_spatial_t, image_appearance_t,
                                       fmap_shape, fmap_app, model_save_dir,
                                       epoch, device)
                save_model(model, model_save_dir)

    elif mode == 'predict':
        # Make Directory for Predictions
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir + '/predictions'):
            os.makedirs(model_save_dir + '/predictions')
        # Load Model and Dataset
        model = Model(arg).to(device)
        model = load_model(model, model_save_dir).to(device)
        train_data, test_data = load_deep_fashion_dataset()  # same helper as the train branch
        test_data = np.array(test_data[-4:])  # predict on the last four test samples
        test_dataset = ImageDataset(test_data)
        test_loader = DataLoader(test_dataset, batch_size=bn)
        model.mode = 'predict'
        model.eval()
        # Predict on Dataset
        for step, original in enumerate(test_loader):
            with torch.no_grad():
                original = original.to(device)
                tps_param_dic = tps_parameters(original.shape[0], arg.scal,
                                               arg.tps_scal, arg.rot_scal,
                                               arg.off_scal, arg.scal_var,
                                               arg.augm_scal)
                coord, vector = make_input_tps_param(tps_param_dic)
                coord, vector = coord.to(device), vector.to(device)
                image_spatial_t, _ = ThinPlateSpline(original, coord, vector,
                                                     original.shape[3], device)
                image_appearance_t = K.ColorJitter(arg.brightness,
                                                   arg.contrast,
                                                   arg.saturation,
                                                   arg.hue)(original)
                image, reconstruction, mu, shape_stream_parts, heat_map = model(
                    original, image_spatial_t, image_appearance_t, coord,
                    vector)
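
Example #4 feeds the model two views of each batch: a TPS-warped spatial transform and a colour-jittered appearance transform. A minimal sketch of the appearance half, assuming K is kornia.augmentation as the K.ColorJitter calls above suggest:

import torch
import kornia.augmentation as K

images = torch.rand(4, 3, 128, 128)  # toy batch of RGB images in [0, 1]
jitter = K.ColorJitter(brightness=0.1, contrast=0.1,
                       saturation=0.1, hue=0.1, p=1.0)
image_appearance_t = jitter(images)  # same geometry, perturbed colours
print(image_appearance_t.shape)      # torch.Size([4, 3, 128, 128])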