Example #1
def make_pairs(img_original, arg):
    bn, c, h, w = img_original.shape
    # Make image and grid
    tps_param_dic = tps_parameters(bn, arg.scal, 0., 0., 0., 0., arg.augm_scal)
    coord, vector = make_input_tps_param(tps_param_dic)
    coord, vector = coord.to(arg.device), vector.to(arg.device)
    img, mesh = ThinPlateSpline(img_original,
                                coord,
                                vector,
                                arg.reconstr_dim,
                                device=arg.device)
    # Make transformed image and grid
    tps_param_dic_rot = tps_parameters(bn, arg.scal, arg.tps_scal,
                                       arg.rot_scal, arg.off_scal,
                                       arg.scal_var, arg.augm_scal)
    coord_rot, vector_rot = make_input_tps_param(tps_param_dic_rot)
    coord_rot, vector_rot = coord_rot.to(arg.device), vector_rot.to(arg.device)
    img_rot, mesh_rot = ThinPlateSpline(img_original,
                                        coord_rot,
                                        vector_rot,
                                        arg.reconstr_dim,
                                        device=arg.device)
    # Make augmentation
    img_stack = torch.cat([img, img_rot], dim=0)
    img_stack_augm = augm(img_stack, arg, arg.device)
    img_augm, img_rot_augm = img_stack_augm[:bn], img_stack_augm[bn:]

    # Make input stack
    input_images = F.interpolate(torch.cat([img_augm, img_rot], dim=0),
                                 size=arg.reconstr_dim).clamp(min=0., max=1.)
    reconstr_images = F.interpolate(torch.cat([img, img_rot_augm], dim=0),
                                    size=arg.reconstr_dim).clamp(min=0.,
                                                                 max=1.)
    mesh_stack = torch.cat([mesh, mesh_rot], dim=0)

    return input_images, reconstr_images, mesh_stack
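For orientation, below is a minimal, hypothetical way to call make_pairs. The attribute names on arg are inferred from the accesses in the body above; the values are illustrative placeholders, not the repository's defaults, and augm may read further augmentation fields from arg that are not shown here.

from types import SimpleNamespace

import torch

# Hypothetical hyperparameters; names mirror the attribute accesses in
# make_pairs, values are placeholders only.
arg = SimpleNamespace(scal=0.9, tps_scal=0.05, rot_scal=0.1, off_scal=0.15,
                      scal_var=0.05, augm_scal=1.0, reconstr_dim=128,
                      device=torch.device('cpu'))

img_original = torch.rand(8, 3, 256, 256)  # bn=8 RGB images
input_images, reconstr_images, mesh_stack = make_pairs(img_original, arg)
# input_images and reconstr_images: (2 * bn, 3, reconstr_dim, reconstr_dim)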
Example #2
    def forward(self, x):
        # tps
        image_orig = x.repeat(2, 1, 1, 1)
        tps_param_dic = tps_parameters(image_orig.shape[0], self.scal,
                                       self.tps_scal, self.rot_scal,
                                       self.off_scal, self.scal_var,
                                       self.augm_scal)
        coord, vector = make_input_tps_param(tps_param_dic)
        coord, vector = coord.to(self.device), vector.to(self.device)
        t_images, t_mesh = ThinPlateSpline(image_orig,
                                           coord,
                                           vector,
                                           128,
                                           device=self.device)
        image_in, image_rec = prepare_pairs(t_images, self.arg)
        transform_mesh = F.interpolate(t_mesh, size=64)
        volume_mesh = AbsDetJacobian(transform_mesh, self.device)

        # encoding
        _, part_maps, sum_part_maps = self.E_sigma(image_in)
        mu, L_inv = get_mu_and_prec(part_maps, self.device, self.L_inv_scal)
        heat_map = get_heat_map(mu, L_inv, self.device)
        raw_features = self.E_alpha(sum_part_maps)
        features = get_local_part_appearances(raw_features, part_maps)

        # transform
        integrant = (part_maps.unsqueeze(-1) *
                     volume_mesh.unsqueeze(-1)).squeeze(-1)
        integrant = integrant / torch.sum(integrant, dim=[2, 3], keepdim=True)
        mu_t = contract('akij, alij -> akl', integrant, transform_mesh)
        transform_mesh_out_prod = contract('amij, anij -> amnij',
                                           transform_mesh, transform_mesh)
        mu_out_prod = contract('akm, akn -> akmn', mu_t, mu_t)
        stddev_t = contract('akij, amnij -> akmn', integrant,
                            transform_mesh_out_prod) - mu_out_prod

        # processing
        encoding = feat_mu_to_enc(features, mu, L_inv, self.device,
                                  self.covariance)
        reconstruct_same_id = self.decoder(encoding)

        loss = nn.MSELoss()(image_rec, reconstruct_same_id)

        if self.mode == 'predict':
            return image_in, image_rec, mu, heat_map

        elif self.mode == 'train':
            return reconstruct_same_id, loss
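The "transform" block is the numerical core of this forward pass: contract is opt_einsum-style notation (a = batch, k/l/m/n = part and coordinate indices, ij = spatial grid), and the normalized part maps act as densities over the warped coordinate mesh, from which a per-part mean mu_t and covariance stddev_t are integrated. A self-contained torch.einsum sketch of the same moment computation, with the volume-element weighting omitted for brevity:

import torch

a, k, h, w = 4, 16, 64, 64
part_maps = torch.rand(a, k, h, w)   # stand-in for the part activations
mesh = torch.rand(a, 2, h, w)        # stand-in for transform_mesh

# Normalize each part map into a spatial density.
integrant = part_maps / part_maps.sum(dim=[2, 3], keepdim=True)
# First moment: expected (x, y) coordinate under each part density.
mu_t = torch.einsum('akij,alij->akl', integrant, mesh)
# Second moment minus outer product of the mean = covariance.
mesh_out_prod = torch.einsum('amij,anij->amnij', mesh, mesh)
mu_out_prod = torch.einsum('akm,akn->akmn', mu_t, mu_t)
stddev_t = torch.einsum('akij,amnij->akmn', integrant, mesh_out_prod) - mu_out_prod
print(mu_t.shape, stddev_t.shape)  # (4, 16, 2), (4, 16, 2, 2)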
Example #3
def main(arg):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu)
    model_save_dir = "./experiments/" + arg.name + "/"

    with tf.variable_scope("Data_prep"):
        if arg.mode == 'train':
            raw_dataset = dataset_map_train[arg.dataset](arg)

        elif arg.mode == 'predict':
            raw_dataset = dataset_map_test[arg.dataset](arg)

        dataset = raw_dataset.map(load_and_preprocess_image, num_parallel_calls=arg.data_parallel_calls)
        dataset = dataset.batch(arg.bn, drop_remainder=True).repeat(arg.epochs)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        b_images = next_element

        orig_images = tf.tile(b_images, [2, 1, 1, 1])

        scal = tf.placeholder(dtype=tf.float32, shape=(), name='scal_placeholder')
        tps_scal = tf.placeholder(dtype=tf.float32, shape=(), name='tps_placeholder')
        rot_scal = tf.placeholder(dtype=tf.float32, shape=(), name='rot_scal_placeholder')
        off_scal = tf.placeholder(dtype=tf.float32, shape=(), name='off_scal_placeholder')
        scal_var = tf.placeholder(dtype=tf.float32, shape=(), name='scal_var_placeholder')
        augm_scal = tf.placeholder(dtype=tf.float32, shape=(), name='augm_scal_placeholder')

        tps_param_dic = tps_parameters(2 * arg.bn, scal, tps_scal, rot_scal, off_scal, scal_var)
        tps_param_dic.augm_scal = augm_scal

    ctr = 0
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    with tf.Session(config=config) as sess:

        model = Model(orig_images, arg, tps_param_dic)
        tvar = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=tvar)
        merged = tf.summary.merge_all()

        if arg.mode == 'train':
            if arg.load:
                ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/')
                saver.restore(sess, ckpt)
            else:
                save_python_files(save_dir=model_save_dir + 'bin/')
                write_hyperparameters(arg.toDict(), model_save_dir)
                sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter("./summaries/" + arg.name, graph=sess.graph)

        elif arg.mode == 'predict':
            ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/')
            saver.restore(sess, ckpt)

        initialize_uninitialized(sess)
        while True:
            print('iteration %d' % ctr)
            try:
                feed = transformation_parameters(arg, ctr, no_transform=(arg.mode == 'predict'))  # no transform in predict mode
                trf = {scal: feed.scal, tps_scal: feed.tps_scal,
                       scal_var: feed.scal_var, rot_scal: feed.rot_scal, off_scal: feed.off_scal, augm_scal: feed.augm_scal}
                ctr += 1
                if arg.mode == 'train':
                    if np.mod(ctr, arg.summary_interval) == 0:
                        merged_summary = sess.run(merged, feed_dict=trf)
                        writer.add_summary(merged_summary, global_step=ctr)

                    _, loss = sess.run([model.optimize, model.loss], feed_dict=trf)
                    if np.mod(ctr, arg.save_interval) == 0:
                        saver.save(sess, model_save_dir + '/saved_model/' + 'save_net.ckpt', global_step=ctr)

                elif arg.mode == 'predict':
                    img, img_rec, mu, heat_raw = sess.run([model.image_in, model.reconstruct_same_id, model.mu,
                                                           batch_colour_map(model.part_maps)], feed_dict=trf)

                    save(img, mu, ctr)

            except tf.errors.OutOfRangeError:
                print("End of training.")
                break
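This variant targets the TF1 graph API. What differs most from the PyTorch examples is the placeholder/feed_dict plumbing: the scalar TPS parameters are placeholders that transformation_parameters refills every step, so the annealing schedule never requires rebuilding the graph. A stripped-down sketch of that pattern:

import tensorflow as tf  # TF1.x graph API, as in the example above

# One scalar placeholder standing in for scal/tps_scal/... above.
scal = tf.placeholder(tf.float32, shape=(), name='scal_placeholder')
scaled = 2.0 * scal  # stand-in for the graph built from the placeholders

with tf.Session() as sess:
    for ctr in range(3):
        # A new value per step, without touching the graph.
        print(sess.run(scaled, feed_dict={scal: 0.1 * ctr}))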
Example #4
    def forward(self, x):
        batch_size = x.shape[0]
        batch_size2 = 2 * x.shape[0]
        # tps
        image_orig = x.repeat(2, 1, 1, 1)
        tps_param_dic = tps_parameters(batch_size2, self.scal, self.tps_scal,
                                       self.rot_scal, self.off_scal,
                                       self.scal_var, self.augm_scal)
        coord, vector = make_input_tps_param(tps_param_dic)
        coord, vector = coord.to(self.device), vector.to(self.device)
        t_images, t_mesh = ThinPlateSpline(image_orig,
                                           coord,
                                           vector,
                                           self.reconstr_dim,
                                           device=self.device)
        image_in, image_rec = prepare_pairs(t_images, self.arg, self.device)
        transform_mesh = F.interpolate(t_mesh, size=64)
        volume_mesh = AbsDetJacobian(transform_mesh, self.device)

        # encoding
        part_maps_raw, part_maps_norm, sum_part_maps = self.E_sigma(image_in)
        mu, L_inv = get_mu_and_prec(part_maps_norm, self.device,
                                    self.L_inv_scal)
        raw_features = self.E_alpha(sum_part_maps)
        features = get_local_part_appearances(raw_features, part_maps_norm)

        heat_map = get_heat_map(mu, L_inv, self.device, self.background)
        norm = torch.sum(heat_map, 1, keepdim=True) + 1
        heat_map_norm = heat_map / norm

        # transform
        integrant = (part_maps_norm.unsqueeze(-1) *
                     volume_mesh.unsqueeze(-1)).squeeze(-1)
        integrant = integrant / torch.sum(integrant, dim=[2, 3], keepdim=True)
        mu_t = contract('akij, alij -> akl', integrant, transform_mesh)
        transform_mesh_out_prod = contract('amij, anij -> amnij',
                                           transform_mesh, transform_mesh)
        mu_out_prod = contract('akm, akn -> akmn', mu_t, mu_t)
        stddev_t = contract('akij, amnij -> akmn', integrant,
                            transform_mesh_out_prod) - mu_out_prod

        # processing
        encoding = feat_mu_to_enc(features, mu, L_inv, self.device,
                                  self.reconstr_dim, self.background)
        reconstruct_same_id = self.decoder(encoding)

        total_loss, rec_loss, transform_loss, precision_loss = loss_fn(
            batch_size, mu, L_inv, mu_t, stddev_t, reconstruct_same_id,
            image_rec, self.l_2_scal, self.l_2_threshold, self.L_mu,
            self.L_cov, self.L_rec, self.device, self.background, True)

        # norms
        (original_part_maps_raw, original_part_maps_norm,
         original_sum_part_maps) = self.E_sigma(x)
        mu_original, L_inv_original = get_mu_and_prec(original_part_maps_norm,
                                                      self.device,
                                                      self.L_inv_scal)

        if self.mode == 'predict':
            return image_rec, reconstruct_same_id, mu, L_inv, part_maps_norm, heat_map, heat_map_norm, total_loss

        elif self.mode == 'train':
            return (image_rec, reconstruct_same_id, total_loss, rec_loss,
                    transform_loss, precision_loss, mu[:, :-1],
                    L_inv[:, :-1], mu_original[:, :-1])
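Two details above are easy to miss: the heat map is normalized against its channel sum plus one, and the train branch slices [:, :-1] to drop the trailing background part from mu, L_inv and mu_original. A small sketch of both, with made-up shapes (16 foreground parts plus one background):

import torch

heat_map = torch.rand(4, 17, 64, 64)           # 16 parts + 1 background
norm = torch.sum(heat_map, 1, keepdim=True) + 1
heat_map_norm = heat_map / norm                # the +1 bounds the normalized
                                               # responses where all parts are weak

mu = torch.rand(4, 17, 2)
mu_fg = mu[:, :-1]                             # (4, 16, 2): background dropped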
Example #5
def main(arg):
    # Set random seeds
    torch.manual_seed(7)
    torch.cuda.manual_seed(7)
    np.random.seed(7)

    # Get args
    bn = arg.bn
    mode = arg.mode
    name = arg.name
    load_from_ckpt = arg.load_from_ckpt
    lr = arg.lr
    epochs = arg.epochs
    device = torch.device('cuda:' + str(arg.gpu)
                          if torch.cuda.is_available() else 'cpu')
    arg.device = device

    if mode == 'train':
        # Make new directory
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
            os.makedirs(model_save_dir + '/summary')

        # Save Hyperparameters
        write_hyperparameters(arg.toDict(), model_save_dir)

        # Define Model & Optimizer
        model = Model(arg).to(device)
        if load_from_ckpt:
            model = load_model(model, model_save_dir).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Log with wandb
        wandb.init(project='Disentanglement', config=arg, name=arg.name)
        wandb.watch(model, log='all')
        # Load Datasets and DataLoader
        train_data, test_data = load_deep_fashion_dataset()
        train_dataset = ImageDataset(np.array(train_data))
        test_dataset = ImageDataset(np.array(test_data))
        train_loader = DataLoader(train_dataset,
                                  batch_size=bn,
                                  shuffle=True,
                                  num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=bn, num_workers=4)

        # Make Training
        with torch.autograd.set_detect_anomaly(False):
            for epoch in range(epochs + 1):
                # Train on Train Set
                model.train()
                model.mode = 'train'
                for step, original in enumerate(train_loader):
                    original = original.to(device)
                    # Make transformations
                    tps_param_dic = tps_parameters(original.shape[0], arg.scal,
                                                   arg.tps_scal, arg.rot_scal,
                                                   arg.off_scal, arg.scal_var,
                                                   arg.augm_scal)
                    coord, vector = make_input_tps_param(tps_param_dic)
                    coord, vector = coord.to(device), vector.to(device)
                    image_spatial_t, _ = ThinPlateSpline(
                        original, coord, vector, original.shape[3], device)
                    image_appearance_t = K.ColorJitter(arg.brightness,
                                                       arg.contrast,
                                                       arg.saturation,
                                                       arg.hue)(original)
                    image_spatial_t, image_appearance_t = normalize(
                        image_spatial_t), normalize(image_appearance_t)
                    reconstruction, loss, rec_loss, equiv_loss, mu, L_inv = model(
                        original, image_spatial_t, image_appearance_t, coord,
                        vector)
                    mu_norm = torch.mean(torch.norm(
                        mu, p=1, dim=2)).cpu().detach().numpy()
                    L_inv_norm = torch.mean(
                        torch.linalg.norm(L_inv, ord='fro',
                                          dim=[2, 3])).cpu().detach().numpy()
                    wandb.log({"Part Means": mu_norm})
                    wandb.log({"Precision Matrix": L_inv_norm})
                    # Zero out gradients
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # Track Loss
                    if step == 0:
                        loss_log = torch.tensor([loss])
                        rec_loss_log = torch.tensor([rec_loss])
                    else:
                        loss_log = torch.cat([loss_log, torch.tensor([loss])])
                        rec_loss_log = torch.cat(
                            [rec_loss_log,
                             torch.tensor([rec_loss])])
                    training_loss = torch.mean(loss_log)
                    training_rec_loss = torch.mean(rec_loss_log)
                    wandb.log({"Training Loss": training_loss})
                    wandb.log({"Training Rec Loss": training_rec_loss})
                print(f'Epoch: {epoch}, Train Loss: {training_loss}')

                # Evaluate on Test Set
                model.eval()
                for step, original in enumerate(test_loader):
                    with torch.no_grad():
                        original = original.to(device)
                        tps_param_dic = tps_parameters(original.shape[0],
                                                       arg.scal, arg.tps_scal,
                                                       arg.rot_scal,
                                                       arg.off_scal,
                                                       arg.scal_var,
                                                       arg.augm_scal)
                        coord, vector = make_input_tps_param(tps_param_dic)
                        coord, vector = coord.to(device), vector.to(device)
                        image_spatial_t, _ = ThinPlateSpline(
                            original, coord, vector, original.shape[3], device)
                        image_appearance_t = K.ColorJitter(
                            arg.brightness, arg.contrast, arg.saturation,
                            arg.hue)(original)
                        image_spatial_t, image_appearance_t = normalize(
                            image_spatial_t), normalize(image_appearance_t)
                        reconstruction, loss, rec_loss, equiv_loss, mu, L_inv = model(
                            original, image_spatial_t, image_appearance_t,
                            coord, vector)
                        if step == 0:
                            loss_log = torch.tensor([loss])
                        else:
                            loss_log = torch.cat(
                                [loss_log, torch.tensor([loss])])
                evaluation_loss = torch.mean(loss_log)
                wandb.log({"Evaluation Loss": evaluation_loss})
                print(f'Epoch: {epoch}, Test Loss: {evaluation_loss}')

                # Track Progress
                if True:  # visualize and checkpoint every epoch
                    model.mode = 'predict'
                    original, fmap_shape, fmap_app, reconstruction = model(
                        original, image_spatial_t, image_appearance_t, coord,
                        vector)
                    make_visualization(original, reconstruction,
                                       image_spatial_t, image_appearance_t,
                                       fmap_shape, fmap_app, model_save_dir,
                                       epoch, device)
                    save_model(model, model_save_dir)

    elif mode == 'predict':
        # Make Directory for Predictions
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir + '/predictions'):
            os.makedirs(model_save_dir + '/predictions')
        # Load Model and Dataset
        model = Model(arg).to(device)
        model = load_model(model, model_save_dir).to(device)
        data = load_deep_fashion_dataset()
        test_data = np.array(data[-4:])
        test_dataset = ImageDataset(test_data)
        test_loader = DataLoader(test_dataset, batch_size=bn)
        model.mode = 'predict'
        model.eval()
        # Predict on Dataset
        for step, original in enumerate(test_loader):
            with torch.no_grad():
                original = original.to(device)
                tps_param_dic = tps_parameters(original.shape[0], arg.scal,
                                               arg.tps_scal, arg.rot_scal,
                                               arg.off_scal, arg.scal_var,
                                               arg.augm_scal)
                coord, vector = make_input_tps_param(tps_param_dic)
                coord, vector = coord.to(device), vector.to(device)
                image_spatial_t, _ = ThinPlateSpline(original, coord, vector,
                                                     original.shape[3], device)
                image_appearance_t = K.ColorJitter(arg.brightness,
                                                   arg.contrast,
                                                   arg.saturation,
                                                   arg.hue)(original)
                image, reconstruction, mu, shape_stream_parts, heat_map = model(
                    original, image_spatial_t, image_appearance_t, coord,
                    vector)
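In this variant the spatial branch (ThinPlateSpline) is paired with an appearance branch built from K.ColorJitter, which reads like kornia.augmentation. A self-contained sketch of that appearance transform, with placeholder jitter strengths (the repository's actual arg.brightness etc. are not reproduced here):

import torch
import kornia.augmentation as K  # assumption: K is kornia.augmentation

# Batch-wise color jitter: same geometry as the input, perturbed colors.
original = torch.rand(8, 3, 128, 128)
jitter = K.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
image_appearance_t = jitter(original)
assert image_appearance_t.shape == original.shape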