示例#1
0
def pytorch_net_tutorial():
    """Walk through the PytorchNet helper API on two example networks.

    Part 1 builds an ``F2PEncoderDecoder`` (a ``PytorchNet`` subclass) with no
    arguments and calls its system, GPU-memory and weight reporting helpers.
    Part 2 takes a plain torchvision AlexNet, extends it at runtime with
    ``PytorchNet.monkeypatch`` and prints its weights and summary through the
    patched object.
    """
    # What is it? PytorchNet is a derived class of LightningModule, allowing for extended operations on it
    # Let's see some of them:
    nn = F2PEncoderDecoder()  # Remember that F2PEncoderDecoder is a subclass of PytorchNet
    nn.identify_system()  # Outputs the specs of the current system - useful for identifying existing GPUs

    banner('General Net Info')
    target_input_size = ((6890, 3), (6890, 3))  # Only used by the commented-out calls below
    # nn.summary(x_shape=target_input_size)
    # nn.summary(batch_size=3, x_shape=target_input_size)
    print(f'On GPU = {nn.ongpu()}')  # Whether the net is on the GPU or not
    nn.print_memory_usage(device=0)  # Print GPU 0's memory consumption
    # print(f'Output size = {nn.output_size(x_shape=target_input_size)}')
    nn.print_weights()
    # nn.visualize(x_shape=target_input_size, frmt='pdf')  # Prints PDF to current directory

    # Let's say we have some plain nn.Module that is not a PytorchNet:
    banner('Some lightning from the net usecase')
    import torchvision
    nn = torchvision.models.alexnet(pretrained=False)
    # We can extend its functionality at runtime with a monkeypatch; the
    # PytorchNet helpers are then called through the returned object.
    py_nn = PytorchNet.monkeypatch(nn)
    py_nn.print_weights()
    # AlexNet's feature extractor needs inputs of at least 63x63: a 28x28 image
    # shrinks to 2x2 before the second 3x3/stride-2 max-pool and the forward
    # pass fails. 224x224 is AlexNet's canonical input size.
    # NOTE(review): assumes summary() runs a forward pass on a dummy input of
    # x_shape (torchsummary-style) - confirm against PytorchNet.summary.
    py_nn.summary(x_shape=(3, 224, 224), batch_size=64)
示例#2
0
def project_smal_main():
    """Deform every SMAL subject found under ``data/synthetic/SMALTestPyProj/full``.

    The deformer is a ``Projection`` with ``max_angs=10`` and ``angs_to_take=10``.
    """
    banner('SMAL Projection')
    projector = Projection(max_angs=10, angs_to_take=10)
    data_dir = ROOT / '..' / 'data' / 'synthetic' / 'SMALTestPyProj' / 'full'
    creator = SMALCreator(projector, data_dir)
    for subject in creator.subjects():
        creator.deform_subject(sub=subject)
示例#3
0
def project_faust_scan_test_main():
    """Deform every FAUST test-scan subject under ``data/scan/FaustTestScanPyProj/full``.

    The deformer is a ``Projection`` with ``max_angs=10`` and ``angs_to_take=10``.
    """
    banner('FAUST Test Scans Projection')
    projector = Projection(max_angs=10, angs_to_take=10)
    data_dir = ROOT / '..' / 'data' / 'scan' / 'FaustTestScanPyProj' / 'full'
    creator = FaustTestScanCreator(projector, data_dir)
    for subject in creator.subjects():
        creator.deform_subject(sub=subject)
示例#4
0
    def train(self, debug_mode=False):
        """Fit ``self.nn`` using the trainer returned by ``self._trainer(debug_mode)``.

        When ``self.trainer`` is still None, the training assets are initialised
        first, the training-set name is logged and ``testing_only`` is set to
        False. The train loader, validation loaders and test loaders of
        ``self.data`` are then handed to the trainer's ``fit``.
        """
        banner('Training Phase')
        needs_setup = self.trainer is None
        if needs_setup:
            self._init_training_assets()
            log.info(f'Training on dataset: {self.data.curr_trainset_name()}')
            self.testing_only = False

        trainer = self._trainer(debug_mode)
        data = self.data
        trainer.fit(self.nn, data.train_ldr, data.vald_ldrs, data.test_ldrs)
示例#5
0
def project_mixamo_main():
    """Deform every MIXAMO subject from the MPI-FAUST Blender export share.

    The share is read from a Windows drive path when ``os.name == 'nt'``, and
    from a samba mount point otherwise (presumed Linux). The deformer is a
    ``Projection`` with ``max_angs=10`` and ``angs_to_take=2``.
    """
    on_windows = os.name == 'nt'
    if not on_windows:  # Presuming Linux
        in_dp = Path(r"/usr/samba_mount/ShapeCompletion/Mixamo/Blender/MPI-FAUST")
    else:
        in_dp = Path(r'Z:\ShapeCompletion\Mixamo\Blender\MPI-FAUST')

    banner('MIXAMO Projection')
    projector = Projection(max_angs=10, angs_to_take=2)
    creator = MixamoCreator(projector, in_dp, shape_frac_from_vgroup=1)
    for subject in creator.subjects():
        creator.deform_subject(sub=subject)
示例#6
0
def train_main():
    """Build the F2P encoder-decoder from CLI args, then train, test and finalize it.

    The data loaders come from ``mixamo_loader_set`` using the network's
    hyper-parameters (``nn.hp``).
    """
    banner('Network Init')
    net = F2PEncoderDecoder(parser())
    net.identify_system()

    # Build the loaders from the net's hyper-parameters and hand them to the trainer
    loaders = mixamo_loader_set(net.hp)
    trainer = LightningTrainer(net, loaders)

    # Phases run in this fixed order; each method is looked up right before its call
    for phase_name in ('train', 'test', 'finalize'):
        getattr(trainer, phase_name)()
示例#7
0
    def deform_subject(self, sub):
        """Apply this creator's deformer to a single subject.

        Creates ``self.tmp_dp / sub`` (``mkdir`` without ``parents=True``, so its
        parent must already exist). It then runs ``_deform_subject_validated``
        when ``self.deformer.needs_validation()`` is truthy, and
        ``_deform_subject_unvalidated`` otherwise. Start and completion banners
        bracket the work.
        """
        heading = f'{self.dataset_name()} Dataset :: Subject {sub} :: Deformation {self.deform_name()} Commencing'
        banner(title(heading))

        subject_dir = self.tmp_dp / sub
        subject_dir.mkdir(exist_ok=True)  # TODO - Presuming this dir structure

        if not self.deformer.needs_validation():
            self._deform_subject_unvalidated(sub)
        else:
            self._deform_subject_validated(sub)

        banner(f'Deformation of Subject {sub} - COMPLETED')
示例#8
0
def hit_test():
    """Smoke-test HierarchicalIndexTree on a small Subject -> Pose index.

    The steps are:
    - Build the same tree both in memory and out of memory, and print each
      tree with its depth and indexed count.
    - Assert that both trees map every flat index to the same hierarchical
      index via ``si2hi``.
    - Print the results of the id-union, id-removal and random-path helpers.
    """
    from pprint import pprint

    tree_spec = {
        'Subject1': {
            'Pose1': 200,
            'Pose2': 300,
            # Deeper nesting case, currently disabled:
            # 'Pose3': {
            #     'Seq1': 500,
            #     'Seq2': 600
            # }
        },
        'Subject2': {'Pose1': 1, 'Pose2': 4, 'Pose3': 4},
        'Subject3': {'Pose1': 100, 'Pose2': 50},
        # Non-dict leaf at depth 1, currently disabled:
        # 'Subject 4': 3
    }

    hit_mem = HierarchicalIndexTree(tree_spec, in_memory=True)
    hit_out_mem = HierarchicalIndexTree(tree_spec, in_memory=False)

    banner('In Memory vs Out of Memory Tests')
    for tree in (hit_mem, hit_out_mem):
        print(tree)
        print(tree.depth())
        print(tree.num_indexed())

    # Both storage modes must agree on every flat -> hierarchical mapping
    for flat_idx in range(hit_mem.num_indexed()):
        assert hit_mem.si2hi(flat_idx) == hit_out_mem.si2hi(flat_idx)

    banner('ID Union Tests')
    for level in (1, 2):
        pprint(hit_mem.get_id_union_by_depth(depth=level))

    banner('Removal Tests')
    print(hit_mem.remove_ids_by_depth('Subject1', depth=1))
    removal_cases = (
        (['Subject1', 'Subject2'], 1),
        (['Subject1', 'Subject2', 'Subject3'], 1),
        (['Pose1'], 2),
        (['Pose1', 'Pose2'], 2),
        (['Pose1', 'Pose2', 'Pose3'], 2),
    )
    for removed, level in removal_cases:
        print(hit_mem.remove_ids_by_depth(removed, level))

    banner('Random Tests')
    for _ in range(2):
        print(hit_mem.random_path_from_partial_path())
    print(hit_mem.random_path_from_partial_path(('Subject1', 'Pose2')))
示例#9
0
def dataset_tutorial():
    """Tour the FullPartDatasetMenu / dataset API using the 'FaustPyProj' dataset.

    Requires the dataset files to be present on disk. The tour prints:
    - the implemented datasets and a batch from the legacy random loader;
    - the index tree, general info and a data summary;
    - face data and the null shape (which is also plotted);
    - samples for the 'full', 'part' and 'f2p' sampling methods;
    - a batch from a single loader.
    It finally builds a train/validation/test set of loaders.
    """
    # The menu lists every implemented dataset
    print(FullPartDatasetMenu.which())
    ds = FullPartDatasetMenu.get('FaustPyProj')  # Fails if the data is not on disk
    # ds.validate_dataset()  # Checks that all files are available - only needs to run once

    # The legacy random dataloader is still supported:
    rand_ldr = ds.rand_loader(num_samples=1000, transforms=[Center()], batch_size=16,
                              n_channels=6, device='cpu-single', mode='f2p')
    for batch in rand_ldr:  # Show the first batch only
        print(batch)
        break

    banner('The HIT')
    ds.report_index_tree()  # How the dataset is indexed by its HierarchicalIndexTree

    banner('Collateral Info')
    info_lines = (
        ('Dataset Name', ds.name),
        ('Number of indexed files', ds.num_indexed),
        ('Number of full shapes', ds.num_full_shapes),
        ('Number of projections', ds.num_projections),
        ('Required disk space in bytes', ds.disk_space),
    )
    for label, query in info_lines:
        print(f'{label} = {query()}')
    # A combined summary printout is also available; the tree was already shown above
    ds.data_summary(with_tree=False)

    # Models with a single face set (SMPL, for example) expose faces / vertex count directly:
    banner('Face Array')
    print(ds.faces())
    print(ds.num_faces())
    print(ds.num_verts())
    # The null shape is the shape at hierarchical index [0, 0, ..., 0]
    print(ds.null_shape(n_channels=6))
    ds.plot_null_shape(strategy='spheres', with_vnormals=True)

    # Sampling methods supported by this dataset:
    print(ds.defined_methods())
    banner('Data Sample')
    # Per the author's notes: 'full' yields gt_hi & gt (~100 points), 'part' adds the
    # mask entries (~1000), 'f2p' adds tp as well (~10000 same-subject (gt, tp) tuples).
    for method in ('full', 'part', 'f2p'):
        samp = ds.sample(num_samples=2, transforms=[Center(keys=['gt'])],
                         n_channels=6, method=method)
        print(samp)
        print(ds.num_datapoints_by_method(method))

    # A simple loader over the dataset (see utils.torch_data for loader operations)
    banner('Loaders')
    single_ldr = ds.loaders(s_nums=1000, s_shuffle=True, s_transform=[Center()],
                            n_channels=6, method='f2p', batch_size=3, device='cpu-single')
    for batch in single_ldr:  # Show the first batch only
        print(batch)
        break
    print(single_ldr.num_verts())

    # Train/validation or train/validation/test splits are requested via `split`:
    my_loaders = ds.loaders(split=[0.8, 0.1, 0.1],
                            s_nums=[2000, 1000, 1000],
                            s_shuffle=[True] * 3,
                            s_transform=[Center()] * 3,
                            global_shuffle=True,
                            method='p2p',
                            s_dynamic=[True, False, False])
示例#10
0
 def wrapper(*args, **kwargs):
     """Emit ``banner(title(func.__name__))``, then call ``func`` and return its result."""
     heading = title(func.__name__)
     banner(heading)
     return func(*args, **kwargs)
示例#11
0
 def print_weights(self):
     """Print the shape of every parameter tensor returned by ``self.parameters()``.

     Entries are numbered in ``self.parameters()`` order. Every parameter
     tensor gets its own "Layer" number, e.g. a layer's bias as well as its
     weight. The number is therefore a parameter index rather than a layer index.
     """
     banner('Weights')
     # Iterate the parameter generator directly - materialising a list first is unnecessary
     for i, weights in enumerate(self.parameters()):
         print(f'Layer {i} :: weight shape: {list(weights.size())}')
示例#12
0
 def test(self):
     """Run the test phase of ``self.nn`` on ``self.data.test_ldrs``.

     The trainer is obtained from ``self._trainer()``.
     """
     banner('Testing Phase')
     trainer = self._trainer()  # Sets the trainer
     trainer.test(self.nn, self.data.test_ldrs)