Example #1
def test_composition():
    pc = PointCloud(np.array([[10, 10, 10], [20, 20, 20]]))
    hc = HybridCloud(np.array([[10, 10, 10], [20, 20, 20]]),
                     np.array([[0, 1]]),
                     vertices=np.array([[10, 10, 10], [20, 20, 20]]))

    transform = clouds.Compose([
        clouds.Normalization(10),
        clouds.RandomRotate((60, 60)),
        clouds.Center()
    ])
    transform(pc)
    transform(hc)

    assert np.all(
        np.round(np.mean(pc.vertices, axis=0)) == np.array([0, 0, 0]))
    assert np.all(
        np.round(np.mean(hc.vertices, axis=0)) == np.array([0, 0, 0]))

    dummy = np.array([[10, 10, 10], [20, 20, 20]]) / 10
    angle_range = (60, 60)
    angles = np.random.uniform(angle_range[0], angle_range[1], (1, 3))[0]
    rot = Rot.from_euler('xyz', angles, degrees=True)
    dummy = rot.apply(dummy)
    centroid = np.mean(dummy, axis=0)
    dummy = dummy - centroid

    assert np.all(pc.vertices == dummy)
    assert np.all(hc.vertices == dummy)
Example #2
def apply_chunkhandler_ssd():
    data = SuperSegmentationDataset(
        working_dir="/wholebrain/songbird/j0126/areaxfs_v6/")
    ssd_include = [491527, 1090051]
    chunk_size = 4000
    features = {'sv': 1, 'mi': 2, 'vc': 3, 'syn_ssv': 4}
    transform = clouds.Compose([clouds.Center()])

    ch = ChunkHandler(data=data,
                      sample_num=4000,
                      density_mode=False,
                      specific=False,
                      ctx_size=chunk_size,
                      obj_feats=features,
                      splitting_redundancy=1,
                      sampling=True,
                      transform=transform,
                      ssd_include=ssd_include,
                      ssd_labels='axoness',
                      label_mappings=[(3, 2), (4, 3), (5, 1), (6, 1)])

    save_path = os.path.expanduser('~/thesis/current_work/chunkhandler_tests/')
    ix = 0
    while ix < 500:
        sample1 = ch[ix]
        sample2 = ch[ix + 1]
        ix += 2
        sample = [sample1, sample2]
        with open(f'{save_path}{ix}.pkl', 'wb') as f:
            pickle.dump(sample, f)
    ch.terminate()
Example #3
def apply_chunkhandler(save_path: str):
    path = os.path.expanduser('~/thesis/gt/20_09_27/voxeled/test/')
    chunk_size = 12000
    features = {
        'hc': np.array([1, 0, 0, 0]),
        'mi': np.array([0, 1, 0, 0]),
        'vc': np.array([0, 0, 1, 0]),
        'sy': np.array([0, 0, 0, 1])
    }
    identity = clouds.Compose([clouds.Center()])
    ch = ChunkHandler(path,
                      sample_num=10000,
                      density_mode=False,
                      specific=False,
                      ctx_size=chunk_size,
                      obj_feats=features,
                      transform=identity,
                      splitting_redundancy=1,
                      sampling=True,
                      split_on_demand=False,
                      label_remove=[-2])
    info = ch.get_set_info()
    print(info['node_labels'])
    print(info['labels'])
    import ipdb
    ipdb.set_trace()
Example #4
def compare_chunks():
    """ Create chunks with different ChunkHandlers and compare the results. """
    path = os.path.expanduser('~/thesis/current_work/augmentation_tests/')
    features = {
        'hc': np.array([1, 0, 0, 0]),
        'mi': np.array([0, 1, 0, 0]),
        'vc': np.array([0, 0, 1, 0]),
        'sy': np.array([0, 0, 0, 1])
    }
    transforms1 = clouds.Compose([
        clouds.Center(),
        clouds.RandomScale(distr_scale=0.6, distr='uniform')
    ])
    ch1 = ChunkHandler(path,
                       sample_num=4000,
                       density_mode=False,
                       specific=True,
                       ctx_size=4000,
                       obj_feats=features,
                       transform=transforms1)
    transforms2 = clouds.Compose([clouds.Center()])
    ch2 = ChunkHandler(path,
                       sample_num=4000,
                       density_mode=False,
                       specific=True,
                       ctx_size=4000,
                       obj_feats=features,
                       transform=transforms2)
    save_path = path + 'scale/'
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    for item in ch1.obj_names:
        for i in range(10):
            sample1, _ = ch1[(item, i)]
            sample2, _ = ch2[(item, i)]
            samples = [sample1, sample2]
            # meshes = [clouds.merge_clouds([sample1, meshes[0]]), clouds.merge_clouds([sample2, meshes[0]])]
            with open(f'{save_path}{item}_{i}.pkl', 'wb') as f:
                pickle.dump(samples, f)
Example #5
def produce_chunks(chunk_size: int, sample_num: int):
    """ Create and analyse all resulting chunks of an dataset. """
    features = {'hc': np.array([1])}
    center = clouds.Compose([clouds.Identity()])
    path = os.path.expanduser('~/working_dir/gt/cmn/dnh/voxeled/')
    save_path = f'{path}analysis/'
    ch = ChunkHandler(path,
                      sample_num=sample_num,
                      density_mode=False,
                      ctx_size=chunk_size,
                      obj_feats=features,
                      transform=center,
                      splitting_redundancy=5,
                      label_mappings=[(2, 0), (5, 1), (6, 2)],
                      label_remove=[2],
                      sampling=True,
                      verbose=True,
                      specific=True,
                      hybrid_mode=True)
    vert_nums = []
    counter = 0
    chunk_num = 0
    for item in ch.obj_names:
        chunk_num += ch.get_obj_length(item)
        for i in range(ch.get_obj_length(item)):
            sample, idcs, vert_num = ch[(item, i)]
            if vert_num < 10000:
                if not os.path.exists(save_path + f'examples/{item}/'):
                    os.makedirs(save_path + f'examples/{item}/')
                with open(f'{save_path}examples/{item}/{i}.pkl', 'wb') as f:
                    pickle.dump([sample, sample], f)
            vert_nums.append(vert_num)
            if vert_num < ch.sample_num:
                counter += 1
    vert_nums = np.array(vert_nums)
    analysis = f"Min: {vert_nums.min()}\nMax: {vert_nums.max()}\nMean: {vert_nums.mean()}\nChunks with less points than requested: {counter}/{chunk_num}"
    print(analysis)
    with open(f'{save_path}{chunk_size}_vertnums.pkl', 'wb') as f:
        pickle.dump(vert_nums, f)
    with open(f'{save_path}{chunk_size}_{sample_num}.txt', 'w') as f:
        f.write(analysis)
    return counter / chunk_num
Example #6
def apply_torchhandler():
    path = os.path.expanduser('~/thesis/gt/cmn/dnh/test/')
    chunk_size = 5000
    features = {'hc': np.array([1])}
    identity = clouds.Compose([clouds.Center()])
    th = TorchHandler(path,
                      sample_num=5000,
                      density_mode=False,
                      tech_density=100,
                      bio_density=100,
                      specific=False,
                      ctx_size=chunk_size,
                      obj_feats=features,
                      transform=identity,
                      splitting_redundancy=1,
                      sampling=True,
                      split_on_demand=True,
                      nclasses=4,
                      feat_dim=1,
                      hybrid_mode=True,
                      exclude_borders=True)
    for ix in range(len(th)):
        sample = th[ix]
Example #7
    def __init__(self,
                 data: Union[str, SuperSegmentationDataset],
                 sample_num: int,
                 density_mode: bool = True,
                 bio_density: float = None,
                 tech_density: int = None,
                 ctx_size: int = None,
                 transform: clouds.Compose = clouds.Compose(
                     [clouds.Identity()]),
                 specific: bool = False,
                 data_type: str = 'ce',
                 obj_feats: dict = None,
                 label_mappings: List[Tuple[int, int]] = None,
                 hybrid_mode: bool = False,
                 splitting_redundancy: int = 1,
                 label_remove: List[int] = None,
                 sampling: bool = True,
                 force_split: bool = False,
                 padding: int = None,
                 verbose: bool = False,
                 split_on_demand: bool = False,
                 split_jitter: int = 0,
                 epoch_size: int = None,
                 workers: int = 2,
                 voxel_sizes: Optional[dict] = None,
                 ssd_exclude: List[int] = None,
                 ssd_include: List[int] = None,
                 ssd_labels: str = None,
                 exclude_borders: int = 0,
                 rebalance: dict = None):
        """
        Args:
            data: Path to objects saved as pickle files. Existing chunking information is expected
                in the 'splitted' folder at this location.
            sample_num: Number of vertices which should be sampled from the surface of each chunk.
                Should be equal to the capacity of the given network architecture.
            tech_density: Poisson sampling density (in points/um²) with which the dataset was preprocessed.
            bio_density: Chunk sampling density in points/um². This determines the size of the chunks.
                If existing chunking information should be reused, it must be available in the
                splitted/ folder under the 'bio_density' name.
            transform: Transformations which should be applied to the chunks before returning them
                (e.g. see :func:`morphx.processing.clouds.Compose`)
            specific: Flag for switching between requesting specific chunks and randomly drawn chunks.
            data_type: Type of dataset, 'ce': CloudEnsembles, 'hc': HybridClouds
            obj_feats: Only used when inputs are CloudEnsembles. Dict with feature array (1, n) keyed by
                the name of the corresponding object in the CloudEnsemble. The HybridCloud gets addressed
                with 'hc'.
            label_mappings: List of (from, to) tuples indicating which labels should get replaced by other
                labels. E.g. [(1, 2), (3, 2)] means that the labels 1 and 3 will get replaced by 2.
            splitting_redundancy: indicates how many times each skeleton node is included in different contexts.
            label_remove: List of labels indicating which nodes should be removed from the dataset. This is
                independent of label_mappings, as the label removal is done during splitting.
            sampling: Flag for random sampling from the extracted subsets.
            force_split: Split dataset again even if splitting information exists.
            padding: Add padded points if a subset contains fewer points than should be sampled.
            verbose: Return additional information about the size of the subsets.
            split_on_demand: Do not generate splitting information in advance, but rather generate chunks on the fly.
            split_jitter: Used only if split_on_demand = True. Adds jitter to the context size of the generated chunks.
            epoch_size: Upper bound on the epoch size; useful when the dataset size is unknown and the
                epoch size should still be bounded.
            workers: Number of worker processes when an ssd dataset is used.
            voxel_sizes: Voxelization options in case of ssd dataset use. Given as dict with voxel sizes keyed by
                cell part identifier (e.g. 'sv' or 'mi').
            exclude_borders: Offset radius (chunk_size - exclude_borders) for excluding border regions of
                chunks from the loss calculation.
            rebalance: Dict for rebalancing the dataset if certain classes dominate. It contains factors keyed
                by labels, where each factor indicates how often chunks containing that label should get
                resampled. This was introduced for rebalancing the CMN ads dataset; it has since been commented
                out and replaced by a hacky version for terminals.
        """
        if type(data) == SuperSegmentationDataset:
            self._data = data
        else:
            self._data = os.path.expanduser(data)
            if not os.path.exists(self._data):
                os.makedirs(self._data)

            # --- split cells into chunks and save this split information to file for later loading ---
            if not split_on_demand:
                if not os.path.exists(self._data + 'splitted/'):
                    os.makedirs(self._data + 'splitted/')
                self._splitfile = ''
                if density_mode:
                    if bio_density is None or tech_density is None:
                        raise ValueError(
                            "Density mode requires bio_density and tech_density"
                        )
                    self._splitfile = f'{self._data}splitted/d{bio_density}_p{sample_num}' \
                                      f'_r{splitting_redundancy}_lr{label_remove}.pkl'
                else:
                    if ctx_size is None:
                        raise ValueError("Context mode requires chunk_size.")
                    self._splitfile = f'{self._data}splitted/s{ctx_size}_r{splitting_redundancy}_lr{label_remove}.pkl'
                self._splitted_objs = None
                orig_splitfile = self._splitfile
                while os.path.exists(self._splitfile):
                    if not force_split:
                        # continue with existing split information
                        with open(self._splitfile, 'rb') as f:
                            self._splitted_objs = pickle.load(f)
                        break
                    else:
                        # generate new split information without overriding the old
                        version = re.findall(r"v(\d+).", self._splitfile)
                        if len(version) == 0:
                            self._splitfile = self._splitfile[:-4] + '_v1.pkl'
                        else:
                            version = int(version[0])
                            self._splitfile = orig_splitfile[:-4] + f'_v{version + 1}.pkl'
                # actual splitting happens here
                splitting.split(data,
                                self._splitfile,
                                bio_density=bio_density,
                                capacity=sample_num,
                                tech_density=tech_density,
                                density_splitting=density_mode,
                                chunk_size=ctx_size,
                                splitted_hcs=self._splitted_objs,
                                redundancy=splitting_redundancy,
                                label_remove=label_remove,
                                split_jitter=split_jitter)
                with open(self._splitfile, 'rb') as f:
                    self._splitted_objs = pickle.load(f)

        self._voxel_sizes = dict(sv=80, mi=100, syn_ssv=100, vc=100)
        if voxel_sizes is not None:
            self._voxel_sizes = voxel_sizes
        self._sample_num = sample_num
        self._transform = transform
        self._specific = specific
        self._data_type = data_type
        self._obj_feats = obj_feats
        self._label_mappings = label_mappings
        self._hybrid_mode = hybrid_mode
        self._label_remove = label_remove
        self._sampling = sampling
        self._padding = padding
        self._verbose = verbose
        self._split_on_demand = split_on_demand
        self._bio_density = bio_density
        self._tech_density = tech_density
        self._density_mode = density_mode
        self._chunk_size = ctx_size
        self._splitting_redundancy = splitting_redundancy
        self._split_jitter = split_jitter
        self._epoch_size = epoch_size
        self._workers = workers
        self._ssd_labels = ssd_labels
        self._ssd_exclude = ssd_exclude
        self._rebalance = rebalance
        self._exclude_borders = exclude_borders
        if ssd_exclude is None:
            self._ssd_exclude = []
        self._ssd_include = ssd_include
        if self._ssd_labels is None and type(
                self._data) == SuperSegmentationDataset:
            raise ValueError(
                "ssd_labels must be specified when working with a SuperSegmentationDataset!"
            )
        self._obj_names = []
        self._objs = []
        self._chunk_list = []
        self._parts = {}

        if type(data) == SuperSegmentationDataset:
            self._load_func = self.get_item_ssd
        elif self._specific:
            self._load_func = self.get_item_specific
        else:
            self._load_func = self.get_item

        # --- dataloader for experiments when using CMN predictions as ground truth ---
        if type(self._data) == SuperSegmentationDataset:
            for key in self._obj_feats:
                self._parts[key] = [
                    self._voxel_sizes[key], self._obj_feats[key]
                ]
            # If ssd dataset is given, multiple workers are used for splitting the ssvs of the given dataset.
            self._obj_names = Queue()
            self._chunk_list = Queue(maxsize=10000)
            if self._ssd_include is None:
                sizes = [sso.size for sso in self._data.ssvs]
                idcs = np.argsort(sizes)
                self._ssd_include = np.array(self._data.ssv_ids)[idcs[-200:]]
            for ssv in self._ssd_include:
                if ssv not in self._ssd_exclude:
                    self._obj_names.put(ssv)
            self._splitters = [
                Process(target=worker_split,
                        args=(self._obj_names, self._chunk_list, self._data,
                              self._chunk_size,
                              self._chunk_size / self._splitting_redundancy,
                              self._parts, self._ssd_labels,
                              self._label_mappings, self._split_jitter))
                for ix in range(workers)
            ]
            for splitter in self._splitters:
                splitter.start()

        # --- dataloader for experiments with cells saved as pickle files ---
        else:
            files = glob.glob(self._data + '*.pkl')
            for file in files:
                slashs = [pos for pos, char in enumerate(file) if char == '/']
                name = file[slashs[-1] + 1:-4]
                self._obj_names.append(name)
                if not self._specific:
                    # load entire dataset into memory
                    obj = self._adapt_obj(
                        objects.load_obj(self._data_type, file))
                    self._objs.append(obj)
            if not self._specific:
                if split_on_demand:
                    # do not use split information from file but split cells on the fly
                    for ix, obj in enumerate(tqdm(self._objs)):
                        base_nodes = np.arange(len(obj.nodes)).reshape(
                            -1, 1)[obj.node_labels != -1]
                        base_nodes = np.random.choice(base_nodes,
                                                      int(len(base_nodes) / 3),
                                                      replace=True)
                        chunks = context_splitting_kdt(obj, base_nodes,
                                                       self._chunk_size)
                        for chunk in chunks:
                            self._chunk_list.append((ix, chunk))
                else:
                    # use split information from file
                    for item in self._splitted_objs:
                        if item in self._obj_names:
                            for idx in range(len(self._splitted_objs[item])):
                                self._chunk_list.append((item, idx))
                if self._rebalance is not None:
                    # rebalance occurrence of chunks by using chunks which contain specific labels multiple times
                    print("Rebalancing...")
                    balance = {}
                    for key in self._rebalance:
                        balance[key] = 0
                    for ix in tqdm(range(len(self._chunk_list))):
                        item = self._chunk_list[ix]
                        obj = self._objs[self._obj_names.index(item[0])]
                        for key in self._rebalance:
                            if key in np.unique(obj.labels):
                                for i in range(self._rebalance[key]):
                                    self._chunk_list.append(item)
                                    balance[key] += 1
                    print("Done with rebalancing!")
                    print(balance)
                random.shuffle(self._chunk_list)

        self._curr_obj = None
        self._curr_name = None
        self._ix = 0
        self._size = len(self._chunk_list)
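All example calls above construct the ChunkHandler in context mode (density_mode=False). As a complement, here is a minimal density-mode sketch based only on the docstring above; it assumes the same imports as the other examples, and the path and density values are placeholders rather than values from the original code:

features = {'hc': np.array([1])}
path = os.path.expanduser('~/data/voxeled/')  # placeholder path
ch = ChunkHandler(path,
                  sample_num=5000,
                  density_mode=True,   # chunk size is derived from the vertex density
                  bio_density=100,     # chunk sampling density in points/um² (illustrative value)
                  tech_density=100,    # Poisson density of the preprocessed data (illustrative value)
                  obj_feats=features,
                  specific=True,
                  hybrid_mode=True)
for name in ch.obj_names:
    for i in range(ch.get_obj_length(name)):
        sample, idcs = ch[(name, i)]   # same return pattern as in Example #4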
Example #8
def compare_transforms(chunk_size: int, sample_num: int):
    """ Create and save all resulting chunks of an dataset with different transforms """
    # features = {'hc': np.array([1, 0, 0, 0]),
    #             'mi': np.array([0, 1, 0, 0]),
    #             'vc': np.array([0, 0, 1, 0]),
    #             'sy': np.array([0, 0, 0, 1])}
    features = {'hc': np.array([1])}
    identity = clouds.Compose([clouds.Identity()])
    center = clouds.Compose([clouds.Center()])
    path = os.path.expanduser('~/thesis/gt/cmn/dnh/voxeled/')
    save_path = f'{path}examples/'
    ch = ChunkHandler(path,
                      sample_num=sample_num,
                      density_mode=False,
                      tech_density=100,
                      bio_density=100,
                      specific=True,
                      ctx_size=chunk_size,
                      obj_feats=features,
                      transform=identity,
                      splitting_redundancy=2,
                      label_mappings=[(5, 3), (6, 4)],
                      label_remove=None,
                      sampling=True,
                      verbose=True)
    ch_transform = ChunkHandler(path,
                                sample_num=5000,
                                density_mode=False,
                                tech_density=100,
                                bio_density=100,
                                specific=True,
                                ctx_size=chunk_size,
                                obj_feats=features,
                                transform=center,
                                splitting_redundancy=2,
                                label_mappings=[(5, 3), (6, 4)],
                                label_remove=None,
                                sampling=True,
                                verbose=True)
    vert_nums = []
    counter = 0
    chunk_num = 0
    total = None
    for item in ch.obj_names:
        total_cell = None
        chunk_num += ch.get_obj_length(item)
        for i in range(ch.get_obj_length(item)):
            sample, idcs, vert_num = ch[(item, i)]
            sample_t, _, _ = ch_transform[(item, i)]
            vert_nums.append(vert_num)
            if not os.path.exists(save_path + f'{item}/'):
                os.makedirs(save_path + f'{item}/')
            if vert_num < ch.sample_num:
                counter += 1
            with open(f'{save_path}{item}/{i}.pkl', 'wb') as f:
                pickle.dump([sample, sample_t], f)
            if total_cell is None:
                total_cell = sample
            else:
                total_cell = clouds.merge_clouds([total_cell, sample])
        if total is None:
            total = total_cell
        else:
            total = clouds.merge_clouds([total, total_cell])
        with open(f'{save_path}{item}/total.pkl', 'wb') as f:
            pickle.dump(total_cell, f)
    with open(f'{save_path}total.pkl', 'wb') as f:
        pickle.dump(total, f)
    vert_nums = np.array(vert_nums)
    print(f"Min: {vert_nums.min()}")
    print(f"Max: {vert_nums.max()}")
    print(f"Mean: {vert_nums.mean()}")
    print(f"Chunks with less points than requested: {counter}/{chunk_num}")
    with open(f'{save_path}{chunk_size}_vertnums.pkl', 'wb') as f:
        pickle.dump(vert_nums, f)
Example #9
 def __init__(self,
              data_path: str,
              sample_num: int,
              nclasses: int,
              feat_dim: int,
              density_mode: bool = True,
              bio_density: float = None,
              tech_density: int = None,
              ctx_size: int = None,
              transform: clouds.Compose = clouds.Compose(
                  [clouds.Identity()]),
              specific: bool = False,
              data_type: str = 'ce',
              obj_feats: dict = None,
              label_mappings: List[Tuple[int, int]] = None,
              hybrid_mode: bool = False,
              splitting_redundancy: int = 1,
              label_remove: List[int] = None,
              sampling: bool = True,
              force_split: bool = False,
              padding: int = None,
              split_on_demand: bool = False,
              split_jitter: int = 0,
              epoch_size: int = None,
              workers: int = 2,
              voxel_sizes: Optional[dict] = None,
              ssd_exclude: List[int] = None,
              ssd_include: List[int] = None,
              ssd_labels: str = None,
              exclude_borders: int = 0,
              rebalance: dict = None,
              extend_no_pred: List[int] = None):
     self._ch = ChunkHandler(data_path,
                             sample_num,
                             density_mode=density_mode,
                             bio_density=bio_density,
                             tech_density=tech_density,
                             ctx_size=ctx_size,
                             transform=transform,
                             specific=specific,
                             data_type=data_type,
                             obj_feats=obj_feats,
                             label_mappings=label_mappings,
                             hybrid_mode=hybrid_mode,
                             splitting_redundancy=splitting_redundancy,
                             label_remove=label_remove,
                             sampling=sampling,
                             force_split=force_split,
                             padding=padding,
                             split_on_demand=split_on_demand,
                             split_jitter=split_jitter,
                             epoch_size=epoch_size,
                             workers=workers,
                             voxel_sizes=voxel_sizes,
                             ssd_exclude=ssd_exclude,
                             ssd_include=ssd_include,
                             ssd_labels=ssd_labels,
                             rebalance=rebalance,
                             exclude_borders=exclude_borders)
     self._specific = specific
     self._nclasses = nclasses
     self._sample_num = sample_num
     self._feat_dim = feat_dim
     self._padding = padding
     self._extend_no_pred = extend_no_pred
Example #10
def training_thread(acont: ArgsContainer):
    torch.cuda.empty_cache()
    lr = 1e-3
    lr_stepsize = 10000
    lr_dec = 0.995
    max_steps = int(acont.max_step_size / acont.batch_size)

    torch.manual_seed(acont.random_seed)
    np.random.seed(acont.random_seed)
    random.seed(acont.random_seed)

    if acont.use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    lcp_flag = False
    # load model
    if acont.architecture == 'lcp' or acont.model == 'ConvAdaptSeg':
        kwargs = {}
        if acont.model == 'ConvAdaptSeg':
            kwargs = dict(kernel_num=acont.pl, architecture=acont.architecture, activation=acont.act, norm=acont.norm_type)
        conv = dict(layer=acont.conv[0], kernel_separation=acont.conv[1])
        model = ConvAdaptSeg(acont.input_channels, acont.class_num, get_conv(conv), get_search(acont.search), **kwargs)
        lcp_flag = True
    elif acont.use_big:
        model = SegBig(acont.input_channels, acont.class_num, trs=acont.track_running_stats, dropout=acont.dropout,
                       use_bias=acont.use_bias, norm_type=acont.norm_type, use_norm=acont.use_norm,
                       kernel_size=acont.kernel_size, neighbor_nums=acont.neighbor_nums, reductions=acont.reductions,
                       first_layer=acont.first_layer, padding=acont.padding, nn_center=acont.nn_center,
                       centroids=acont.centroids, pl=acont.pl, normalize=acont.cp_norm)
    else:
        model = SegAdapt(acont.input_channels, acont.class_num, architecture=acont.architecture,
                         trs=acont.track_running_stats, dropout=acont.dropout, use_bias=acont.use_bias,
                         norm_type=acont.norm_type, kernel_size=acont.kernel_size, padding=acont.padding,
                         nn_center=acont.nn_center, centroids=acont.centroids, kernel_num=acont.pl,
                         normalize=acont.cp_norm, act=acont.act)

    batch_size = acont.batch_size

    train_transforms = clouds.Compose(acont.train_transforms)
    train_ds = TorchHandler(data_path=acont.train_path, sample_num=acont.sample_num, nclasses=acont.class_num,
                            feat_dim=acont.input_channels, density_mode=acont.density_mode,
                            ctx_size=acont.chunk_size, bio_density=acont.bio_density,
                            tech_density=acont.tech_density, transform=train_transforms,
                            obj_feats=acont.features, label_mappings=acont.label_mappings,
                            hybrid_mode=acont.hybrid_mode, splitting_redundancy=acont.splitting_redundancy,
                            label_remove=acont.label_remove, sampling=acont.sampling, padding=acont.padding,
                            split_on_demand=acont.split_on_demand, split_jitter=acont.split_jitter,
                            epoch_size=acont.epoch_size, workers=acont.workers, voxel_sizes=acont.voxel_sizes,
                            ssd_exclude=acont.ssd_exclude, ssd_include=acont.ssd_include,
                            ssd_labels=acont.ssd_labels, exclude_borders=acont.exclude_borders,
                            rebalance=acont.rebalance, extend_no_pred=acont.extend_no_pred)

    if acont.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif acont.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.5e-5)
    else:
        raise ValueError('Unknown optimizer')

    if acont.scheduler == 'steplr':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
    elif acont.scheduler == 'cosannwarm':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5000, T_mult=2)
    else:
        raise ValueError('Unknown scheduler')

    # calculate class weights if necessary
    weights = None
    if acont.class_weights is not None:
        weights = torch.from_numpy(acont.class_weights).float()

    criterion = torch.nn.CrossEntropyLoss(weight=weights)
    if acont.use_cuda:
        criterion.cuda()

    if acont.use_val:
        val_path = acont.val_path
    else:
        val_path = None

    trainer = Trainer3d(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        train_dataset=train_ds,
        v_path=val_path,
        val_freq=acont.val_freq,
        val_red=acont.val_iter,
        channel_num=acont.input_channels,
        batchsize=batch_size,
        num_workers=4,
        save_root=acont.save_root,
        exp_name=acont.name,
        num_classes=acont.class_num,
        schedulers={"lr": scheduler},
        target_names=acont.target_names,
        stop_epoch=acont.stop_epoch,
        enable_tensorboard=False,
        lcp_flag=lcp_flag,
    )
    # Archiving training script, src folder, env info
    Backup(script_path=__file__, save_path=trainer.save_path).archive_backup()
    acont.save2pkl(trainer.save_path + '/argscont.pkl')
    with open(trainer.save_path + '/argscont.txt', 'w') as f:
        f.write(str(acont.attr_dict))

    trainer.run(max_steps)
Example #11
def generate_predictions_with_model(argscont: ArgsContainer,
                                    model_path: str,
                                    cell_path: str,
                                    out_path: str,
                                    prediction_redundancy: int = 1,
                                    batch_size: int = -1,
                                    chunk_redundancy: int = -1,
                                    force_split: bool = False,
                                    training_seed: bool = False,
                                    label_mappings: List[Tuple[int,
                                                               int]] = None,
                                    label_remove: List[int] = None,
                                    border_exclusion: int = 0,
                                    state_dict: str = None,
                                    model=None,
                                    **args):
    """
    Can be used to generate predictions for multiple files using a specific model (passed either as a path to a state dict or as a pre-loaded model).

    Args:
        argscont: argument container for current model.
        model_path: path to model state dict.
        cell_path: path to cells used for prediction.
        out_path: path to folder where predictions of this model should get saved.
        prediction_redundancy: number of times each cell should be processed (using the same chunks but different points due to random sampling).
        batch_size: batch size; if -1, this defaults to the batch size used during training.
        chunk_redundancy: number of times each cell should get split into a complete chunk set (including different chunks each time).
        force_split: split cells even if cached split information exists.
        training_seed: use random seed from training.
        label_mappings: List of tuples like (from, to) where 'from' is the label that should get mapped to 'to'.
            Defaults to val_label_mappings of the ArgsContainer or, if that is not set, to the label_mappings used during training.
        label_remove: List of labels to remove from the cells. Defaults to val_label_remove of the ArgsContainer or, if that is not
            set, to the label_remove used during training.
        border_exclusion: distance in nm defining how much of the chunk borders should be excluded from predictions.
        state_dict: state dict holding model for prediction.
        model: loaded model to use for prediction.
    """
    if os.path.exists(out_path):
        print(f"{out_path} already exists. Skipping...")
        return

    if training_seed:
        torch.manual_seed(argscont.random_seed)
        np.random.seed(argscont.random_seed)
        random.seed(argscont.random_seed)

    if argscont.use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    lcp_flag = False
    if model is not None:
        if isinstance(model, ConvAdaptSeg):
            lcp_flag = True
    else:
        # load model
        if argscont.architecture == 'lcp' or argscont.model == 'ConvAdaptSeg':
            kwargs = {}
            if argscont.model == 'ConvAdaptSeg':
                kwargs = dict(kernel_num=argscont.pl,
                              architecture=argscont.architecture,
                              activation=argscont.act,
                              norm=argscont.norm_type)
            conv = dict(layer=argscont.conv[0],
                        kernel_separation=argscont.conv[1])
            model = ConvAdaptSeg(argscont.input_channels, argscont.class_num,
                                 get_conv(conv), get_search(argscont.search),
                                 **kwargs)
            lcp_flag = True
        elif argscont.use_big:
            model = SegBig(argscont.input_channels,
                           argscont.class_num,
                           trs=argscont.track_running_stats,
                           dropout=argscont.dropout,
                           use_bias=argscont.use_bias,
                           norm_type=argscont.norm_type,
                           use_norm=argscont.use_norm,
                           kernel_size=argscont.kernel_size,
                           neighbor_nums=argscont.neighbor_nums,
                           reductions=argscont.reductions,
                           first_layer=argscont.first_layer,
                           padding=argscont.padding,
                           nn_center=argscont.nn_center,
                           centroids=argscont.centroids,
                           pl=argscont.pl,
                           normalize=argscont.cp_norm)
        else:
            model = SegAdapt(argscont.input_channels,
                             argscont.class_num,
                             architecture=argscont.architecture,
                             trs=argscont.track_running_stats,
                             dropout=argscont.dropout,
                             use_bias=argscont.use_bias,
                             norm_type=argscont.norm_type,
                             kernel_size=argscont.kernel_size,
                             padding=argscont.padding,
                             nn_center=argscont.nn_center,
                             centroids=argscont.centroids,
                             kernel_num=argscont.pl,
                             normalize=argscont.cp_norm,
                             act=argscont.act)
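        # The state file may contain either a bare state dict or a checkpoint dict with a 'model_state_dict' key.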
        try:
            full = torch.load(model_path + state_dict)
            model.load_state_dict(full)
        except RuntimeError:
            model.load_state_dict(full['model_state_dict'])
        model.to(device)
        model.eval()

    transforms = clouds.Compose(argscont.val_transforms)
    if chunk_redundancy == -1:
        chunk_redundancy = argscont.splitting_redundancy
    if batch_size == -1:
        batch_size = argscont.batch_size
    if label_remove is None:
        if argscont.val_label_remove is not None:
            label_remove = argscont.val_label_remove
        else:
            label_remove = argscont.label_remove
    if label_mappings is None:
        if argscont.val_label_mappings is not None:
            label_mappings = argscont.val_label_mappings
        else:
            label_mappings = argscont.label_mappings

    torch_handler = TorchHandler(cell_path,
                                 argscont.sample_num,
                                 argscont.class_num,
                                 density_mode=argscont.density_mode,
                                 bio_density=argscont.bio_density,
                                 tech_density=argscont.tech_density,
                                 transform=transforms,
                                 specific=True,
                                 obj_feats=argscont.features,
                                 ctx_size=argscont.chunk_size,
                                 label_mappings=label_mappings,
                                 hybrid_mode=argscont.hybrid_mode,
                                 feat_dim=argscont.input_channels,
                                 splitting_redundancy=chunk_redundancy,
                                 label_remove=label_remove,
                                 sampling=argscont.sampling,
                                 force_split=force_split,
                                 padding=argscont.padding,
                                 exclude_borders=border_exclusion)
    prediction_mapper = PredictionMapper(cell_path,
                                         out_path,
                                         torch_handler.splitfile,
                                         label_remove=label_remove,
                                         hybrid_mode=argscont.hybrid_mode)

    obj = None
    obj_names = torch_handler.obj_names.copy()
    for obj in torch_handler.obj_names:
        if os.path.exists(out_path + obj + '_preds.pkl'):
            print(obj + " has already been processed. Skipping...")
            obj_names.remove(obj)
            continue
        if torch_handler.get_obj_length(obj) == 0:
            print(obj + " has no chunks to process. Skipping...")
            obj_names.remove(obj)
            continue
        print(f"Processing {obj}")
        predict_cell(torch_handler,
                     obj,
                     batch_size,
                     argscont.sample_num,
                     prediction_redundancy,
                     device,
                     model,
                     prediction_mapper,
                     argscont.input_channels,
                     point_subsampling=argscont.sampling,
                     lcp_flag=lcp_flag)
    if obj is not None:
        prediction_mapper.save_prediction()
    else:
        return
    argscont.save2pkl(out_path + 'argscont.pkl')
    del model
    torch.cuda.empty_cache()
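A hypothetical invocation of generate_predictions_with_model, shown only to illustrate how the documented arguments fit together; the paths and the state-dict file name are placeholders:

argscont = ArgsContainer().load_from_pkl(os.path.expanduser('~/models/run1/argscont.pkl'))
generate_predictions_with_model(argscont,
                                model_path=os.path.expanduser('~/models/run1/'),
                                cell_path=os.path.expanduser('~/data/test_cells/'),
                                out_path=os.path.expanduser('~/predictions/run1/'),
                                state_dict='state_dict.pth',  # placeholder file name
                                prediction_redundancy=3,
                                border_exclusion=0)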
Example #12
def analyse_features(m_path: str, args_path: str, out_path: str, val_path: str, context_list: List[Tuple[str, int]],
                     label_mappings: List[Tuple[int, int]] = None, label_remove: List[int] = None,
                     splitting_redundancy: int = 1, test: bool = False):
    device = torch.device('cuda')
    m_path = os.path.expanduser(m_path)
    out_path = os.path.expanduser(out_path)
    args_path = os.path.expanduser(args_path)
    val_path = os.path.expanduser(val_path)

    # load model specifications
    argscont = ArgsContainer().load_from_pkl(args_path)

    lcp_flag = False
    # load model
    if argscont.architecture == 'lcp' or argscont.model == 'ConvAdaptSeg':
        kwargs = {}
        if argscont.model == 'ConvAdaptSeg':
            kwargs = dict(f_map_num=argscont.pl, architecture=argscont.architecture, act=argscont.act, norm=argscont.norm_type)
        conv = dict(layer=argscont.conv[0], kernel_separation=argscont.conv[1])
        model = get_network(argscont.model, argscont.input_channels, argscont.class_num, conv, argscont.search, **kwargs)
        lcp_flag = True
    elif argscont.use_big:
        model = SegBig(argscont.input_channels, argscont.class_num, trs=argscont.track_running_stats, dropout=0,
                       use_bias=argscont.use_bias, norm_type=argscont.norm_type, use_norm=argscont.use_norm,
                       kernel_size=argscont.kernel_size, neighbor_nums=argscont.neighbor_nums,
                       reductions=argscont.reductions, first_layer=argscont.first_layer,
                       padding=argscont.padding, nn_center=argscont.nn_center, centroids=argscont.centroids,
                       pl=argscont.pl, normalize=argscont.cp_norm)
    else:
        print("Adaptable model was found!")
        model = SegAdapt(argscont.input_channels, argscont.class_num, architecture=argscont.architecture,
                         trs=argscont.track_running_stats, dropout=argscont.dropout, use_bias=argscont.use_bias,
                         norm_type=argscont.norm_type, kernel_size=argscont.kernel_size, padding=argscont.padding,
                         nn_center=argscont.nn_center, centroids=argscont.centroids, kernel_num=argscont.pl,
                         normalize=argscont.cp_norm, act=argscont.act)
    try:
        full = torch.load(m_path)
        model.load_state_dict(full)
    except RuntimeError:
        model.load_state_dict(full['model_state_dict'])
    model.to(device)
    model.eval()

    pts = torch.rand(1, argscont.sample_num, 3, device=device)
    feats = torch.rand(1, argscont.sample_num, argscont.input_channels, device=device)
    contexts = []
    th = None

    if not test:
        # prepare data loader
        if label_mappings is None:
            label_mappings = argscont.label_mappings
        if label_remove is None:
            label_remove = argscont.label_remove
        transforms = clouds.Compose(argscont.val_transforms)
        th = TorchHandler(val_path, argscont.sample_num, argscont.class_num, density_mode=argscont.density_mode,
                          bio_density=argscont.bio_density, tech_density=argscont.tech_density, transform=transforms,
                          specific=True, obj_feats=argscont.features, ctx_size=argscont.chunk_size,
                          label_mappings=label_mappings, hybrid_mode=argscont.hybrid_mode,
                          feat_dim=argscont.input_channels, splitting_redundancy=splitting_redundancy,
                          label_remove=label_remove, sampling=argscont.sampling,
                          force_split=False, padding=argscont.padding, exclude_borders=0)
        for context in context_list:
            pts = torch.zeros((1, argscont.sample_num, 3))
            feats = torch.ones((1, argscont.sample_num, argscont.input_channels))
            sample = th[context]
            pts[0] = sample['pts']
            feats[0] = sample['features']
            o_mask = sample['o_mask'].numpy().astype(bool)
            l_mask = sample['l_mask'].numpy().astype(bool)
            target = sample['target'].numpy()
            target = target[l_mask].astype(int)
            contexts.append((feats, pts, o_mask, l_mask, target))
    else:
        contexts.append((feats, pts))

    for c_ix, context in enumerate(contexts):
        # set hooks

        if lcp_flag:
            layer_outs = SaveFeatures(list(model.children())[0][1:])
            act_outs = SaveFeatures([layer.activation for layer in list(model.children())[0][1:]])
        else:
            layer_outs = SaveFeatures(list(model.children())[1])
            act_outs = SaveFeatures([list(model.children())[0]])
        feats = context[0].to(device, non_blocking=True)
        pts = context[1].to(device, non_blocking=True)

        if lcp_flag:
            pts = pts.transpose(1, 2)
            feats = feats.transpose(1, 2)

        output = model(feats, pts).cpu().detach()

        if lcp_flag:
            output = output.transpose(1, 2).numpy()

        if not test:
            output = np.argmax(output[0][context[2]].reshape(-1, th.num_classes), axis=1)
            pts = context[1][0].numpy()
            identifier = f'{context_list[c_ix][0]}_{context_list[c_ix][1]}'
            target = PointCloud(pts, context[4])
            x_offset = (pts[:, 0].max() - pts[:, 0].min()) * 1.5 * 3
            pred = PointCloud(pts[context[3]], output)
            pred.move(np.array([x_offset / 2, 0, 0]))
            clouds.merge([target, pred]).save2pkl(out_path + identifier + '_0io_r_a.pkl')
        for ix, layer in enumerate(layer_outs.features):
            if len(layer) < 2:
                continue
            feats = layer[0].detach().cpu()[0]
            feats_act = act_outs.features[ix].detach().cpu()[0]
            pts = layer[1].detach().cpu()[0]
            if lcp_flag:
                feats = feats.transpose(0, 1).numpy()
                feats_act = feats_act.transpose(0, 1).numpy()
                pts = pts.transpose(0, 1).numpy()
            else:
                feats = feats.numpy()
                feats_act = feats_act.numpy()
                pts = pts.numpy()
            x_offset = (pts[:, 0].max() - pts[:, 0].min()) * 1.5 * 3
            x_offset_act = x_offset / 3
            y_size = (pts[:, 1].max() - pts[:, 1].min()) * 1.5
            y_offset = 0
            row_num = feats.shape[1] / 8
            total_pc = None
            total_pc_act = None
            for i in range(feats.shape[1]):
                if i % 8 == 0 and i != 0:
                    y_offset += y_size
                pc = PointCloud(vertices=pts, features=feats[:, i].reshape(-1, 1))
                pc_act = PointCloud(vertices=pts, features=feats_act[:, i].reshape(-1, 1))
                pc.move(np.array([(i % 8) * x_offset, y_offset, 0]))
                pc_act.move(np.array([(i % 8) * x_offset + x_offset / 2.8, y_offset, 0]))
                pc = clouds.merge_clouds([pc, pc_act])
                pc_act = PointCloud(vertices=pts, features=feats_act[:, i].reshape(-1, 1))
                pc_act.move(np.array([(i % 8) * x_offset_act, y_offset, 0]))
                if total_pc is None:
                    total_pc = pc
                    total_pc_act = pc_act
                else:
                    total_pc = clouds.merge_clouds([total_pc, pc])
                    total_pc_act = clouds.merge_clouds([total_pc_act, pc_act])
            total_pc.move(np.array([-4 * x_offset - x_offset / 2, -row_num / 2 * y_size - y_size / 2, 0]))
            total_pc_act.move(np.array([-4 * x_offset_act - x_offset_act / 2, -row_num / 2 * y_size - y_size / 2, 0]))
            total_pc.save2pkl(out_path + f'{context_list[c_ix][0]}_{context_list[c_ix][1]}_l{ix}_r.pkl')
            total_pc_act.save2pkl(out_path + f'{context_list[c_ix][0]}_{context_list[c_ix][1]}_l{ix}_a.pkl')
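Example #12 relies on a SaveFeatures helper that is not shown in these examples. A minimal forward-hook sketch of what such a helper could look like, given purely as an assumption for illustration rather than the repository's actual implementation:

class SaveFeatures:
    """Collects the forward outputs of the given modules via PyTorch forward hooks."""

    def __init__(self, modules):
        self.features = []
        self._handles = [m.register_forward_hook(self._hook) for m in modules]

    def _hook(self, module, inputs, output):
        # the output can be a plain tensor (activations) or a tuple such as (features, points)
        self.features.append(output)

    def remove(self):
        for handle in self._handles:
            handle.remove()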