def generate_transform_matrices(name, refer_vertices, refer_triangles,
                                factors):
    """Decimate a reference mesh level by level and compute the transforms between levels.

    Starting from ``(refer_vertices, refer_triangles)``, each entry of
    ``factors`` produces one coarser mesh by running
    ``qslim_decimator_transformer`` on the previous level with
    ``factor=1.0 / factors[i]``. Every level, including the reference, is
    written to ``data/reference/<name>/reference<i>.obj``.

    Args:
        name: Sub-directory of ``data/reference`` that receives the OBJ dumps.
        refer_vertices: Vertices of the reference (finest) mesh.
        refer_triangles: Faces of the reference mesh.
        factors: Per-level downsampling factors. Each one is inverted before
            it is passed to the decimator.

    Returns:
        A tuple ``(adjacencies, downsamp_trans, upsamp_trans)``:
            adjacencies: vertex connectivity of every level
                (``len(factors) + 1`` entries, reference first).
            downsamp_trans: the decimation transform mapping the vertices of
                level ``i`` to level ``i + 1`` (``len(factors)`` entries).
            upsamp_trans: ``setup_deformation_transfer`` from level ``i + 1``
                back onto the vertices of level ``i`` (``len(factors)``
                entries).
    """
    inverse_factors = [1.0 / x for x in factors]

    vertices = [refer_vertices]
    triangles = [refer_triangles]
    adjacencies = [
        utils.get_vert_connectivity(refer_vertices, refer_triangles)
    ]
    downsamp_trans = []
    upsamp_trans = []

    for factor in inverse_factors:
        ds_triangles, ds_transform = qslim_decimator_transformer(
            vertices[-1], triangles[-1], factor=factor)
        # The coarse vertices are the previous level's vertices pushed
        # through the decimation transform.
        ds_vertices = ds_transform.dot(vertices[-1])

        downsamp_trans.append(ds_transform)
        vertices.append(ds_vertices)
        triangles.append(ds_triangles)
        adjacencies.append(
            utils.get_vert_connectivity(ds_vertices, ds_triangles))
        # Upsampling maps the new coarse level back onto the previous one.
        upsamp_trans.append(
            setup_deformation_transfer(ds_vertices, ds_triangles,
                                       vertices[-2]))

    out_dir = os.path.join('data', 'reference', name)
    # Whether write_obj creates missing parent directories is not visible
    # here, so make sure the destination exists before the first write.
    os.makedirs(out_dir, exist_ok=True)
    for i, (level_verts, level_tris) in enumerate(zip(vertices, triangles)):
        write_obj(os.path.join(out_dir, 'reference{}.obj'.format(i)),
                  level_verts, level_tris)

    return adjacencies, downsamp_trans, upsamp_trans
Example #2
0
    def process(self):
        """Load every mesh in ``self.data_file``, split, normalize and save.

        Each mesh becomes a ``Data`` object whose ``x`` and ``y`` are its
        vertex positions and whose ``edge_index`` is its vertex connectivity.
        Samples are assigned according to ``self.split``:

        * ``'sliced'``: by file index modulo 100 (0-10 test, 11-20 val,
          the rest train).
        * ``'expression'``: test if the parent directory name equals
          ``self.split_term``, otherwise train.
        * ``'identity'``: test if the grandparent directory name equals
          ``self.split_term``, otherwise train.

        For the non-sliced splits, the last ``self.nVal`` test samples are
        moved to validation. The mean/std of the training vertices are saved
        to ``self.processed_paths[3]``. If ``self.pre_transform`` exposes
        ``mean``/``std`` attributes that are still ``None``, they are filled
        with these statistics. The transform is then applied to every split
        before saving to ``processed_paths[0..2]``.

        Raises:
            Exception: if ``self.split`` is not a supported value.
        """
        train_data, val_data, test_data = [], [], []
        train_vertices = []
        for idx, data_file in tqdm(enumerate(self.data_file)):
            mesh = Mesh(filename=data_file)
            mesh_verts = torch.Tensor(mesh.v)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            # PyG expects integer (torch.long) edge indices. Building them
            # directly as int64 also avoids a lossy float32 round trip.
            edge_index = torch.as_tensor(
                np.vstack((adjacency.row, adjacency.col)), dtype=torch.long)
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)

            if self.split == 'sliced':
                if idx % 100 <= 10:
                    test_data.append(data)
                elif idx % 100 <= 20:
                    val_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
            elif self.split in ('expression', 'identity'):
                # 'expression' is keyed on the parent directory name,
                # 'identity' on the grandparent directory name.
                depth = -2 if self.split == 'expression' else -3
                if data_file.split('/')[depth] == self.split_term:
                    test_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
            else:
                raise Exception(
                    'sliced, expression and identity are the only supported split terms'
                )

        if self.split != 'sliced':
            # Move the last nVal test samples to validation. The cut point is
            # explicit because test_data[-0:] would select everything when
            # nVal == 0.
            cut = max(len(test_data) - self.nVal, 0)
            val_data = test_data[cut:]
            test_data = test_data[:cut]

        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        if self.pre_transform is not None:
            if hasattr(self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
                # Only fill statistics the caller did not already provide.
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            val_data = [self.pre_transform(td) for td in val_data]
            test_data = [self.pre_transform(td) for td in test_data]

        torch.save(self.collate(train_data), self.processed_paths[0])
        torch.save(self.collate(val_data), self.processed_paths[1])
        torch.save(self.collate(test_data), self.processed_paths[2])
        torch.save(norm_dict, self.processed_paths[3])
Example #3
0
    def process(self):
        """Load the labelled meshes of ``self.filepaths`` and save one split.

        ``self.filepaths`` maps each category name to a list of mesh files.
        Each mesh becomes a ``Data`` object with its vertices as ``x``, a
        one-hot label (position taken from ``self.categories``) as ``y``,
        and its vertex connectivity as ``edge_index``.

        * ``self.dtype == 'train'``: the vertex mean/std of the loaded meshes
          are computed, saved to ``processed_paths[2]`` and used for
          normalization. The dataset goes to ``processed_paths[0]``.
        * ``self.dtype == 'test'``: the statistics are read back from
          ``processed_paths[2]`` and the dataset goes to
          ``processed_paths[1]``.

        ``self.pre_transform`` is applied only when it exposes both ``mean``
        and ``std`` attributes. Only those still set to ``None`` are filled
        with the statistics above.
        """

        def _normalize(samples, mean, std):
            # Apply pre_transform (if it is a mean/std normalizer) and return
            # the transformed samples so they are the ones that get saved.
            if not (hasattr(self.pre_transform, 'mean')
                    and hasattr(self.pre_transform, 'std')):
                return samples
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean
            if self.pre_transform.std is None:
                self.pre_transform.std = std
            return [self.pre_transform(td) for td in samples]

        dataset = []
        vertices = []
        for dx in self.filepaths:
            for fp in self.filepaths[dx]:
                mesh = Mesh(filename=fp)
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                # PyG expects integer (torch.long) edge indices.
                edge_index = torch.as_tensor(
                    np.vstack((adjacency.row, adjacency.col)),
                    dtype=torch.long)

                label = np.zeros(len(self.categories))
                label[self.categories.index(dx)] = 1
                data = Data(x=mesh_verts,
                            y=torch.Tensor(label),
                            edge_index=edge_index)

                dataset.append(data)
                vertices.append(mesh.v)

        if self.dtype == 'train':
            mean_train = torch.Tensor(np.mean(vertices, axis=0))
            std_train = torch.Tensor(np.std(vertices, axis=0))
            norm_dict = {'mean': mean_train, 'std': std_train}

            # Previously the transformed list was stored on self.data and the
            # un-normalized dataset was saved. Save the transformed samples.
            dataset = _normalize(dataset, mean_train, std_train)

            torch.save(self.collate(dataset), self.processed_paths[0])
            torch.save(norm_dict, self.processed_paths[2])

        elif self.dtype == 'test':
            norm_dict = torch.load(self.processed_paths[2])
            dataset = _normalize(dataset, norm_dict['mean'], norm_dict['std'])

            torch.save(self.collate(dataset), self.processed_paths[1])
Example #4
0
    def process(self):
        """Compute training statistics and save every mesh as its own ``.pt`` file.

        The vertex mean/std over ``self.datafile_train`` are saved to
        ``self.processed_paths[0]``. If ``self.pre_transform`` exposes
        ``mean``/``std`` attributes that are still ``None``, they are filled
        with these statistics.

        Then, for each of ``train``, ``test`` and ``val``, every file in
        ``self.datafile_<subset>`` is turned into a ``Data`` object
        (``x = y = vertices``, ``edge_index`` = connectivity). The transform
        is applied when set, and each sample is saved to
        ``<processed_dir>/<split_term>/<subset>/data_<idx>.pt``.
        """
        print('Computing mean and std ...')
        train_vertices = [
            Mesh(filename=data_file).v
            for data_file in tqdm(self.datafile_train)
        ]
        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}

        split_dir = osp.join(self.processed_dir, self.split_term)
        os.makedirs(split_dir, exist_ok=True)
        torch.save(norm_dict, self.processed_paths[0])

        if self.pre_transform is not None and hasattr(
                self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
            # Only fill statistics the caller did not already provide.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train

        for subset in ('train', 'test', 'val'):
            print('processing {} ...'.format(subset))
            subset_dir = osp.join(split_dir, subset)
            os.makedirs(subset_dir, exist_ok=True)

            # getattr performs the same attribute lookup as the former
            # eval('self.datafile_' + subset) without executing code.
            subset_files = getattr(self, 'datafile_' + subset)
            for idx, data_file in enumerate(tqdm(subset_files)):
                mesh = Mesh(filename=data_file)
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                edge_index = torch.LongTensor(
                    np.vstack((adjacency.row, adjacency.col)))
                data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                torch.save(data,
                           osp.join(subset_dir, 'data_{}.pt'.format(idx)))
Example #5
0
    def process(self):
        """Build labelled AD/CN train and test datasets and save them.

        ``self.data_file[dx][tt]`` lists mesh files for diagnosis group
        ``dx`` (``'ad'`` or ``'cn'``) and subset ``tt`` (``'train'`` or
        ``'test'``; other subsets are ignored). Each mesh becomes a ``Data``
        object with its vertices as ``x``, a one-hot group label as ``y``,
        and its connectivity as ``edge_index``.

        The vertex mean/std of the training meshes fill any ``None``
        ``mean``/``std`` on ``self.pre_transform``, which is then applied to
        both subsets.

        Output files:
            * ``processed_paths[0]``: train
            * ``processed_paths[1]``: test (no validation subset is produced)
            * ``processed_paths[2]``: ``{'mean', 'std'}`` statistics

        Raises:
            ValueError: if a group key other than ``'ad'``/``'cn'`` has files.
        """
        # One-hot label per diagnosis group.
        group_labels = {'ad': [1, 0], 'cn': [0, 1]}

        train_data, test_data = [], []
        train_vertices = []
        for dx in self.data_file:
            for tt in self.data_file[dx]:
                for f in tqdm(self.data_file[dx][tt]):
                    if dx not in group_labels:
                        # Previously this reused the previous file's `data`
                        # (or raised UnboundLocalError on the first file).
                        raise ValueError(
                            'Unsupported diagnosis group: {!r}'.format(dx))
                    mesh = Mesh(filename=f)
                    mesh_verts = torch.Tensor(mesh.v)
                    adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                    # PyG expects integer (torch.long) edge indices.
                    edge_index = torch.as_tensor(
                        np.vstack((adjacency.row, adjacency.col)),
                        dtype=torch.long)
                    data = Data(x=mesh_verts,
                                y=torch.Tensor(group_labels[dx]),
                                edge_index=edge_index)

                    if tt == 'test':
                        test_data.append(data)
                    elif tt == 'train':
                        train_data.append(data)
                        train_vertices.append(mesh.v)

        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        if self.pre_transform is not None:
            if hasattr(self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
                # Only fill statistics the caller did not already provide.
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            test_data = [self.pre_transform(td) for td in test_data]

        torch.save(self.collate(train_data), self.processed_paths[0])
        torch.save(self.collate(test_data), self.processed_paths[1])
        torch.save(norm_dict, self.processed_paths[2])
Example #6
0
    def process(self):
        """Build train/val/test datasets for every key of ``self.data_file``.

        Entry layout per key (from how entries are indexed below):

        * ``'lgtd'``: ``(input_mesh, target_mesh)``. ``y`` holds the target
          vertices.
        * ``'lgtddx'``: as ``'lgtd'`` plus ``entry[2]`` stored as ``label``.
        * ``'lgtdvc'``: as ``'lgtddx'`` plus ``entry[3]`` stored as
          ``period``.
        * ``'ad'`` / ``'cn'``: a single mesh path. ``y = x`` and a one-hot
          ``label``.
        * ``'all'``: a single mesh path. ``y = x`` and no label.

        Within each key, the entry index modulo 100 picks the split
        (0-9 test, 10-19 val, the rest train). Training input vertices give
        the mean/std saved to ``processed_paths[3]``. These fill any ``None``
        ``mean``/``std`` on ``self.pre_transform`` before it is applied.

        Raises:
            ValueError: for a key not listed above.
        """
        group_labels = {'ad': [1, 0], 'cn': [0, 1]}
        paired_keys = ('lgtd', 'lgtddx', 'lgtdvc')

        def _load_graph(path):
            # Load a mesh and return it with its vertex tensor and edge index.
            mesh = Mesh(filename=path)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            # PyG expects integer (torch.long) edge indices.
            edge_index = torch.as_tensor(
                np.vstack((adjacency.row, adjacency.col)), dtype=torch.long)
            return mesh, torch.Tensor(mesh.v), edge_index

        train_data, val_data, test_data = [], [], []
        train_vertices = []
        for key in self.data_file:
            for idx, data_file in tqdm(enumerate(self.data_file[key])):
                if key in paired_keys:
                    mesh, mesh_verts, edge_index = _load_graph(data_file[0])
                    target = Mesh(filename=data_file[1])
                    extra = {}
                    if key in ('lgtddx', 'lgtdvc'):
                        extra['label'] = torch.Tensor(data_file[2])
                    if key == 'lgtdvc':
                        extra['period'] = torch.Tensor(data_file[3])
                    data = Data(x=mesh_verts,
                                y=torch.Tensor(target.v),
                                edge_index=edge_index,
                                **extra)
                elif key in group_labels or key == 'all':
                    mesh, mesh_verts, edge_index = _load_graph(data_file)
                    extra = {}
                    if key in group_labels:
                        extra['label'] = torch.Tensor(group_labels[key])
                    data = Data(x=mesh_verts,
                                y=mesh_verts,
                                edge_index=edge_index,
                                **extra)
                else:
                    # Previously an unknown key silently reused the previous
                    # sample (or raised UnboundLocalError).
                    raise ValueError('Unsupported data key: {!r}'.format(key))

                if idx % 100 < 10:
                    test_data.append(data)
                elif idx % 100 < 20:
                    val_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
        print(len(train_data), len(val_data), len(test_data))

        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        if self.pre_transform is not None:
            if hasattr(self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
                # Only fill statistics the caller did not already provide.
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            val_data = [self.pre_transform(td) for td in val_data]
            test_data = [self.pre_transform(td) for td in test_data]

        torch.save(self.collate(train_data), self.processed_paths[0])
        torch.save(self.collate(val_data), self.processed_paths[1])
        torch.save(self.collate(test_data), self.processed_paths[2])
        torch.save(norm_dict, self.processed_paths[3])
Example #7
0
    def process(self):
        """Load every mesh in ``self.data_file``, split, normalize and save.

        Samples are assigned according to ``self.split``:

        * ``'sliced'``: file index modulo 100 in 0-10 is test, the rest train
          (no validation subset is produced).
        * ``'expression'``: test if the parent directory name equals
          ``self.split_term``, otherwise train.
        * ``'identity'``: test if the grandparent directory name equals
          ``self.split_term``, otherwise train.

        For non-sliced splits, the last ``self.nVal`` test samples become
        validation. Training-vertex mean/std are saved to
        ``processed_paths[3]``. They also fill any ``None`` ``mean``/``std``
        on ``self.pre_transform``, which is then applied to all subsets.

        Raises:
            Exception: if ``self.split`` is not a supported value.
        """
        train_data, val_data, test_data = [], [], []
        train_vertices = []
        for idx, data_file in tqdm(enumerate(self.data_file)):
            mesh = Mesh(filename=data_file)
            mesh_verts = torch.Tensor(mesh.v)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            # Build edge indices directly as int64 instead of via float32.
            edge_index = torch.as_tensor(
                np.vstack((adjacency.row, adjacency.col)), dtype=torch.long)
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)

            if self.split == 'sliced':
                if idx % 100 <= 10:
                    test_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
            elif self.split in ('expression', 'identity'):
                depth = -2 if self.split == 'expression' else -3
                if data_file.split('/')[depth] == self.split_term:
                    test_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
            else:
                raise Exception(
                    'sliced, expression and identity are the only supported split terms'
                )

        if self.split != 'sliced':
            # Explicit cut point: test_data[-0:] would select everything
            # when nVal == 0.
            cut = max(len(test_data) - self.nVal, 0)
            val_data = test_data[cut:]
            test_data = test_data[:cut]

        # A leftover debug `return` here used to skip all normalization and
        # saving, leaving processed_paths unwritten.
        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        if self.pre_transform is not None:
            if hasattr(self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
                # Only fill statistics the caller did not already provide.
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            val_data = [self.pre_transform(td) for td in val_data]
            test_data = [self.pre_transform(td) for td in test_data]

        torch.save(self.collate(train_data), self.processed_paths[0])
        # The 'sliced' split produces no validation samples. Fall back to
        # the test samples (the previous behaviour) so this file is still
        # written.
        torch.save(self.collate(val_data if val_data else test_data),
                   self.processed_paths[1])
        torch.save(self.collate(test_data), self.processed_paths[2])
        torch.save(norm_dict, self.processed_paths[3])
Example #8
0
def main(args):
    """Run a CoMA model over hard-coded mesh sequences and dump latent codes.

    Reads the config at ``args.conf`` and builds the mesh hierarchy from the
    template mesh. It then constructs a ``Coma`` model (optionally restoring
    ``config['checkpoint_file']``) and, for every sequence id in ``files``,
    encodes each ``../processed_data/<id>/mesh/<i>.obj``:

    * latent features are saved to ``./processed/0820/<id>`` via ``np.save``;
    * every 50th frame, the reconstruction and the input are written to
      ``./vis/reconstruct_<i>.obj`` and ``./vis/ori_<i>.obj``.

    Args:
        args: Parsed CLI arguments. Attributes read here: ``conf``,
            ``checkpoint_dir``, ``data_dir``, ``split``, ``split_term``.

    Raises:
        FileNotFoundError: if ``args.conf`` does not exist.
        Exception: if ``config['optimizer']`` is neither ``'adam'`` nor
            ``'sgd'``.
    """
    if not os.path.exists(args.conf):
        # Previously this only printed and then crashed later in read_config.
        raise FileNotFoundError('Config not found: ' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_mesh = Mesh(filename=config['template_fname'])

    checkpoint_dir = args.checkpoint_dir or config['checkpoint_dir']
    os.makedirs(checkpoint_dir, exist_ok=True)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir
    if output_dir:
        # os.makedirs('') would raise, so only create a non-empty path.
        os.makedirs(output_dir, exist_ok=True)

    lr = config['learning_rate']
    weight_decay = config['weight_decay']
    opt = config['optimizer']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(level.v) for level in M]

    print('Loading Dataset')
    data_dir = args.data_dir or config['data_dir']

    normalize_transform = Normalize()

    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    print('Loading model')
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Make sure every optimizer state tensor lives on the target device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)
    print('making...')
    # Statistics are combined with CPU tensors below, so load them on CPU.
    norm = torch.load('../processed_data/processed/sliced_norm.pt',
                      map_location='cpu')
    normalize_transform.mean = norm['mean']
    normalize_transform.std = norm['std']

    seq_ids = [
        '0514', '0503', '0507', '0509', '0512', '0501', '0901', '1001', '4902',
        '4913', '4919', '9302', '9305', '12411'
    ]

    coma.eval()

    # Output directories are not guaranteed to exist beforehand.
    os.makedirs('./vis', exist_ok=True)
    os.makedirs('./processed/0820', exist_ok=True)

    mean, std = normalize_transform.mean, normalize_transform.std
    for seq_id in seq_ids:
        mesh_root = '../processed_data/' + seq_id + '/mesh/'
        # Frames are assumed to be named 0.obj .. (N-1).obj, where N is the
        # number of directory entries.
        num_frames = len(os.listdir(mesh_root))
        latent = []
        print(num_frames)
        for i in tqdm(range(num_frames)):
            mesh = Mesh(filename=mesh_root + str(i) + '.obj')
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.as_tensor(
                np.vstack((adjacency.row, adjacency.col)), dtype=torch.long)
            mesh_verts = (torch.Tensor(mesh.v) - mean) / std
            data = Data(x=mesh_verts, y=mesh_verts,
                        edge_index=edge_index).to(device)
            with torch.no_grad():
                out, feature = coma(data)
            latent.append(feature.cpu().numpy())

            if i % 50 == 0:
                recon = (out.cpu() * std + mean).numpy()
                original = (data.x.cpu() * std + mean).numpy()
                # NOTE(review): file names only carry the frame index, so
                # later sequences overwrite earlier ones' visualizations.
                save_obj(recon, template_mesh.f + 1,
                         './vis/reconstruct_' + str(i) + '.obj')
                save_obj(original, template_mesh.f + 1,
                         './vis/ori_' + str(i) + '.obj')

        np.save('./processed/0820/' + seq_id, latent)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
Example #9
0
    def process(self, transform=None):
        """Load, re-centre/scale, split, normalize and save every mesh.

        Each mesh's vertices are centred on their mean and scaled so that the
        largest absolute coordinate is just under 1 (same recipe as PyG's
        ``NormalizeScale``). The result is stored as both ``x`` and ``y``,
        with the mesh connectivity as ``edge_index``.

        Splits follow ``self.split``:

        * ``'sliced'``: index modulo 100 in 0-10 test, 11-20 val, the rest
          train.
        * ``'expression'`` / ``'identity'``: test if the parent /
          grandparent directory name equals ``self.split_term``, otherwise
          train.

        For non-sliced splits, the last ``self.nVal`` test samples become
        validation. Mean/std are computed over the scaled training vertices,
        i.e. the same values that end up in ``data.x``. They are saved to
        ``processed_paths[3]`` and fill any ``None`` ``mean``/``std`` on
        ``self.pre_transform``.

        Args:
            transform: Unused. Kept for interface compatibility.

        Raises:
            Exception: if ``self.split`` is not a supported value.
        """
        train_data, val_data, test_data = [], [], []
        train_vertices = []
        for idx, data_file in tqdm(enumerate(self.data_file)):
            mesh = Mesh(filename=data_file)
            mesh_verts = torch.Tensor(mesh.v)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.as_tensor(
                np.vstack((adjacency.row, adjacency.col)), dtype=torch.long)

            # Same recipe as torch_geometric.transforms.NormalizeScale.
            mesh_verts = mesh_verts - mesh_verts.mean(dim=-2, keepdim=True)
            scale = (1 / mesh_verts.abs().max()) * 0.999999
            mesh_verts = mesh_verts * scale

            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)

            if self.split == 'sliced':
                if idx % 100 <= 10:
                    test_data.append(data)
                elif idx % 100 <= 20:
                    val_data.append(data)
                else:
                    train_data.append(data)
                    # Statistics must describe the scaled vertices stored in
                    # data.x, not the raw mesh.v they were derived from.
                    train_vertices.append(mesh_verts.numpy())
            elif self.split in ('expression', 'identity'):
                depth = -2 if self.split == 'expression' else -3
                if data_file.split('/')[depth] == self.split_term:
                    test_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh_verts.numpy())
            else:
                raise Exception(
                    'sliced, expression and identity are the only supported split terms'
                )

        if self.split != 'sliced':
            # Explicit cut point: test_data[-0:] would select everything
            # when nVal == 0.
            cut = max(len(test_data) - self.nVal, 0)
            val_data = test_data[cut:]
            test_data = test_data[:cut]

        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        if self.pre_transform is not None:
            if hasattr(self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
                # Only fill statistics the caller did not already provide.
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            val_data = [self.pre_transform(td) for td in val_data]
            test_data = [self.pre_transform(td) for td in test_data]

        torch.save(self.collate(train_data), self.processed_paths[0])
        torch.save(self.collate(val_data), self.processed_paths[1])
        torch.save(self.collate(test_data), self.processed_paths[2])
        torch.save(norm_dict, self.processed_paths[3])