Example #1
def make_skel_from_json(json_path: str):
    """
    Creates a skeleton object from the binary targets of the data sources in a JSON file
    Args:
        json_path: the path of the data source json file
    Returns:
        skel: the skeleton object
    """
    data_sources_dict = WkwData.convert_ds_to_dict(
        WkwData.read_short_ds_json(json_path=json_path))
    # Init with empty skeleton
    empty_skel_name = os.path.join(get_data_dir(), 'NML', 'empty_skel.nml')
    skel = wkskel.Skeleton(nml_path=empty_skel_name)

    # Loop over each bbox
    keys = list(data_sources_dict.keys())
    num_nodes_perTree = 5
    for idx, key in tqdm(enumerate(keys),
                         desc='Making bbox nml',
                         total=len(keys)):
        # Get minimum and maximum node id
        min_id = (num_nodes_perTree * idx) + 1
        max_id = num_nodes_perTree * (idx + 1)
        # Encode the target in the tree name
        cur_target = data_sources_dict[key]['target_class']
        cur_name = f'{key}, Debris: {cur_target[0]}, Myelin: {cur_target[1]}'
        # add current tree
        add_bbox_tree(skel=skel,
                      bbox=data_sources_dict[key]['input_bbox'],
                      tree_name=cur_name,
                      node_id_min_max=[min_id, max_id])
    return skel
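A minimal usage sketch (the input JSON and the output NML path are hypothetical, and wkskel's write_nml method is assumed to exist):

# Hypothetical usage: build the bbox skeleton from a short-format datasource JSON and save it.
skel = make_skel_from_json(os.path.join(get_data_dir(), 'combined', 'combined_20K_patches.json'))
skel.write_nml(os.path.join(get_data_dir(), 'NML', 'bbox_trees.nml'))  # write_nml assumed on wkskel.Skeleton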
Example #2
def WkwDataSetConstructor():
    """Constructs a WkwData[set] from fixed parameters. These parameters can also be explored for
    further testing."""
    # Get data source from example json
    json_dir = gpath.get_data_dir()
    datasources_json_path = os.path.join(json_dir, 'datasource_20X_980_980_1000bboxes.json')
    data_sources = WkwData.datasources_from_json(datasources_json_path)
    # Only pick the first two bboxes for faster epoch
    data_sources = data_sources[0:2]
    data_split = DataSplit(train=0.70, validation=0.00, test=0.30)
    # input, output shape
    input_shape = (28, 28, 1)
    output_shape = (28, 28, 1)
    # flags for memory and storage caching
    cache_RAM = True
    cache_HDD = True
    # HDD cache directory
    connDataDir = '/conndata/alik/genEM3_runs/VAE/'
    cache_root = os.path.join(connDataDir, '.cache/')
    dataset = WkwData(
        input_shape=input_shape,
        target_shape=output_shape,
        data_sources=data_sources,
        data_split=data_split,
        normalize=False,
        transforms=ToZeroOneRange(minimum=0, maximum=255),
        cache_RAM=cache_RAM,
        cache_HDD=cache_HDD,
        cache_HDD_root=cache_root
    )
    return dataset
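A short usage sketch, assuming the constructor above is importable; get_ordered_sample is the same accessor used in the display_example snippet further down:

# Build the small test dataset and inspect its first sample.
dataset = WkwDataSetConstructor()
first_sample = dataset.get_ordered_sample(0)
print(len(dataset), first_sample['input'].shape)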
Example #3
    def wkw_create_write(data, wkw_root, wkw_bbox, compress=False):

        if compress:
            wkw_block_type = 2
        else:
            wkw_block_type = 1

        if not os.path.exists(wkw_root):
            os.makedirs(wkw_root)

        if not os.path.exists(os.path.join(wkw_root, 'header.wkw')):
            WkwData.wkw_create(wkw_root, data.dtype, wkw_block_type)

        WkwData.wkw_write(wkw_root, wkw_bbox, data)
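A hedged sketch of calling the helper above (hypothetical output root; numpy is assumed imported as np, and wkw_create_write is assumed to be exposed as a static method on WkwData like the other wkw_* helpers):

# Write a single 140 x 140 x 1 block of zeros into a (possibly freshly created) wkw dataset.
data = np.zeros((1, 140, 140, 1), dtype=np.uint8)
WkwData.wkw_create_write(data=data,
                         wkw_root='/tmp/wkw_scratch/color/1',  # hypothetical output root
                         wkw_bbox=[27500, 22000, 3889, 140, 140, 1],
                         compress=False)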
Example #4
    def __init__(self,
                 wkw_dataset: WkwData,
                 subset_indices: List[np.int64],
                 fraction_debris: float,
                 artefact_dim: int,
                 verbose: bool = False):
        self.wkw_dataset = wkw_dataset
        self.subset_indices = subset_indices
        self.frac_clean_debris = np.asarray(
            [1 - fraction_debris, fraction_debris])
        self.artefact_dim = artefact_dim
        # Get the target (debris vs. clean) for each sample
        total_sample_range = iter(subset_indices)
        self.index_set = set(subset_indices)
        # check uniqueness of indices
        assert len(self.index_set) == len(self.subset_indices)
        self.target_class = np.asarray([
            wkw_dataset.get_target_from_sample_idx(sample_idx)
            for sample_idx in total_sample_range
        ], dtype=np.int64)
        self.artefact_targets = self.target_class[:, artefact_dim]
        if verbose:
            self.report_original_numbers()

        # Use the inverse of the number of samples as weight to create balance
        self.class_sample_count = np.array([
            len(np.where(self.artefact_targets == t)[0])
            for t in np.unique(self.artefact_targets)
        ])

        # Subset dataset
        self.sub_dataset = Subset(wkw_dataset, subset_indices)
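The inverse class counts computed above are typically turned into per-sample weights for a torch WeightedRandomSampler; a sketch outside the class (bare names mirror the attributes above, and 0/1 class indexing ordered like np.unique's output is an assumption):

# Not part of the original class: balance debris vs. clean patches with a weighted sampler.
import torch
class_weights = 1.0 / class_sample_count            # one weight per class, inverse of its frequency
sample_weights = class_weights[artefact_targets]    # one weight per sample
balanced_sampler = torch.utils.data.WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(sample_weights),
    replacement=True)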
Example #5
def update_data_source_targets(
        dataset: WkwData,
        target_index_tuple_list: Sequence[Tuple[int, float]]):
    """
    Create an updated list of datasources from a WkwData dataset and a list of (sample_idx, target_class) pairs
    """
    list_source_idx = [
        dataset.get_source_idx_from_sample_idx(sample_idx)
        for (sample_idx, _) in target_index_tuple_list
    ]
    source_list = []
    for i, (sample_idx, cur_target) in enumerate(target_index_tuple_list):
        # list_source_idx is ordered like the tuple list, so index it by position
        # rather than by the dataset-wide sample index
        s_index = list_source_idx[i]
        s = dataset.data_sources[s_index]
        source_list.append(
            DataSource(id=s.id,
                       input_path=s.input_path,
                       input_bbox=s.input_bbox,
                       input_mean=s.input_mean,
                       input_std=s.input_std,
                       target_path=s.target_path,
                       target_bbox=s.target_bbox,
                       target_class=cur_target,
                       target_binary=s.target_binary))
    return source_list
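A usage sketch with hypothetical relabels, assuming a WkwData instance named dataset is in scope; datasources_to_json is the writer used throughout these examples:

# Relabel samples 0 and 1 (debris -> 1.0, clean -> 0.0) and persist the result.
new_sources = update_data_source_targets(dataset, [(0, 1.0), (1, 0.0)])
WkwData.datasources_to_json(new_sources,
                            os.path.join(get_data_dir(), 'relabelled_datasource.json'))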
Example #6
def readWkwFromCenter(wkwdir, coordinates, dimensions):
    """ Returns a collection of images given their coordinates and dimensions (numpy arrays)"""
    # Get the bounding boxes from coordinates and dimensions for the cropping
    bboxes = bboxesFromArray(coordinates, dimensions)
    # read the wkwdata into a numpy array
    readWk = lambda bbox: WkwData.wkw_read(wkwdir, bbox)
    images = np.apply_along_axis(readWk, 1, bboxes).squeeze(4).astype('double')
    return images
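A sketch with made-up coordinates; it assumes bboxesFromArray broadcasts a single dimensions vector over all coordinates and reuses the getMag8DatasetDir helper that appears in a later snippet:

# Read two 140 x 140 x 1 patches centred on the given voxel coordinates.
coords = np.array([[27500, 22000, 3889],
                   [27500, 22000, 3902]])
dims = np.array([140, 140, 1])
images = readWkwFromCenter(getMag8DatasetDir(), coords, dims)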
Example #7
def merge_json_from_data_dir(fnames: Sequence[str], output_fname: str):
    """
    Prepend the data directory to each file name and concatenate the related JSON datasources
    """
    # Test concatenating jsons
    full_fnames = []
    for fname in fnames:
        full_fname = os.path.join(get_data_dir(), fname)
        full_fnames.append(full_fname)

    # Concatenate the test and training data sets
    full_output_name = os.path.join(get_data_dir(), output_fname)
    all_ds = WkwData.concat_datasources(json_paths_in=full_fnames,
                                        json_path_out=full_output_name)
    return all_ds
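A usage sketch that reuses the JSON names appearing in the concatenation example further down:

# Merge a training and a test datasource JSON that both live in the data directory.
all_ds = merge_json_from_data_dir(
    fnames=['debris_clean_added_bboxes2_wiggle_datasource.json',
            'test_data_three_bboxes.json'],
    output_fname='train_test_combined.json')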
Example #8
def display_example(index: int,
                    dataset: WkwData,
                    margin: int = 35,
                    roi_size: int = 140):
    """
    Display an image with a central rectangle for the roi
    """
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(dataset.get_ordered_sample(index)['input'].squeeze(),
              cmap='gray')
    rectangle = plt.Rectangle((margin, margin),
                              roi_size,
                              roi_size,
                              fill=False,
                              ec="red")
    ax.add_patch(rectangle)
    ax.axis('off')
    plt.show()
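A one-line usage sketch, assuming a WkwData instance named dataset is in scope:

# Show sample 0 with the default 140 px ROI and 35 px margin.
display_example(index=0, dataset=dataset)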
Example #9
    def filter_sparse_cube_3d(self, wkw_bbox, filter_kernel, compress_output=False):

        data = WkwData.wkw_read(self.input_wkw_root, wkw_bbox).squeeze(axis=0)
        pred_inds_sparse = np.where(~np.isnan(data))
        pred_inds_dense = [stats.rankdata(pis, method='dense') - 1 for pis in pred_inds_sparse]
        data_dense = np.zeros((max(pred_inds_dense[0]) + 1,
                               max(pred_inds_dense[1]) + 1,
                               max(pred_inds_dense[2]) + 1),
                              dtype=np.float32)

        for i, (xd, yd, zd) in enumerate(zip(*pred_inds_dense)):
            xs, ys, zs = [pis[i] for pis in pred_inds_sparse]
            data_dense[xd, yd, zd] = data[xs, ys, zs]

        data_dense_conv = ndimage.convolve(data_dense, weights=filter_kernel)
        data_dense_conv = data_dense_conv / data_dense_conv.max()
        data_dense_conv[data_dense_conv < 0] = 0

        for i, (xs, ys, zs) in enumerate(zip(*pred_inds_sparse)):
            xd, yd, zd = [pid[i] for pid in pred_inds_dense]
            data[xs, ys, zs] = data_dense_conv[xd, yd, zd]

        data = np.expand_dims(data, axis=0)
        self.wkw_create_write(data=data, wkw_root=self.output_wkw_root, wkw_bbox=wkw_bbox, compress=compress_output)
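A hedged sketch of invoking the method above on a hypothetical instance named predictor; the normalised box kernel is just one simple choice of filter_kernel:

# Smooth a sparse prediction cube with a 3 x 3 x 3 box kernel.
filter_kernel = np.ones((3, 3, 3), dtype=np.float32) / 27.0
predictor.filter_sparse_cube_3d(wkw_bbox=[27500, 22000, 3889, 560, 560, 5],  # hypothetical bbox
                                filter_kernel=filter_kernel,
                                compress_output=True)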
Example #10
def patch_source_list_from_dataset(dataset: WkwData,
                                   margin: int = 35,
                                   roi_size: int = 140):
    """
    Return two data_sources from the image patches contained in a dataset. 
    One data source has a larger bbox for annotations
    """
    corner_xy_index = [0, 1]
    length_xy_index = [3, 4]
    large_bboxes_idx = []
    bboxes_idx = []
    for idx in range(len(dataset)):
        (source_idx, original_cur_bbox) = dataset.get_bbox_for_sample_idx(idx)
        bboxes_idx.append((source_idx, original_cur_bbox))
        cur_bbox = np.asarray(original_cur_bbox)
        cur_bbox[corner_xy_index] = cur_bbox[corner_xy_index] - margin
        cur_bbox[length_xy_index] = cur_bbox[length_xy_index] + margin * 2
        # large bbox append
        large_bboxes_idx.append((source_idx, cur_bbox.tolist()))

    assert len(large_bboxes_idx) == len(dataset) == len(bboxes_idx)
    large_source_list = update_data_source_bbox(dataset, large_bboxes_idx)
    patch_source_list = update_data_source_bbox(dataset, bboxes_idx)
    return {'original': patch_source_list, 'large': large_source_list}
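A usage sketch mirroring the annotation-preparation snippet later in this collection (the output file name is hypothetical):

# Build both source lists and write the enlarged (annotation) version to JSON.
source_dict = patch_source_list_from_dataset(dataset=dataset, margin=35, roi_size=140)
WkwData.datasources_to_json(source_dict['large'],
                            os.path.join(get_data_dir(), 'large_annotation_datasource.json'))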
Example #11
skeletons = [Skeleton(skel_dir) for skel_dir in skel_dirs]
print(f'Time to read skeleton: {time.time() - start}')
# Read the coordinates and target class of all three skeletons into the volume data frame
volume_df = get_volume_df(skeletons=skeletons)
# Get the ingredients for making the datasources
bboxes = bboxesFromArray(volume_df[['x', 'y', 'z']].values)
input_dir = '/tmpscratch/webknossos/Connectomics_Department/2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8_artifact_pred/color/1'
target_class = volume_df['class'].values.astype(np.float64)  # np.float alias removed in recent NumPy
target_binary = 1
target_dir = input_dir
input_mean = 148.0
input_std = 36.0
# Create a list of data sources
source_list = []
for i, cur_bbox in enumerate(bboxes):
    cur_target = target_class[i]
    source_list.append(
        DataSource(id=str(i),
                   input_path=input_dir,
                   input_bbox=cur_bbox.tolist(),
                   input_mean=input_mean,
                   input_std=input_std,
                   target_path=target_dir,
                   target_bbox=cur_bbox.tolist(),
                   target_class=cur_target,
                   target_binary=target_binary))
# Json name
json_name = os.path.join(get_data_dir(), 'test_data_three_bboxes.json')
# Write to json file
WkwData.datasources_to_json(source_list, json_name)
Example #12
num_workers = 0

kernel_size = 3
stride = 1
n_fmaps = 16
n_latent = 2048
input_size = 140
output_size = input_size
model = AE_Encoder_Classifier(
    Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size,
                                                 kernel_size,
                                                 stride,
                                                 n_latent=n_latent),
    Classifier(n_latent=n_latent))

datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=datasources,
                  stride=(70, 70, 1),
                  cache_HDD=True,
                  cache_RAM=True,
                  cache_HDD_root=cache_HDD_root)

prediction_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                batch_size=batch_size,
                                                num_workers=num_workers)

checkpoint = torch.load(state_dict_path,
                        map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
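A hedged sketch of how such a checkpoint is typically applied before running the prediction loader; the 'input' key mirrors the sample dictionaries used elsewhere in these examples:

# Load the trained weights, switch to eval mode and run a forward pass per batch.
model.load_state_dict(state_dict)
model.eval()
with torch.no_grad():
    for batch in prediction_loader:
        predictions = model(batch['input'])
        # downstream handling of the predictions is application specific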
Example #13
from genEM3.util.path import get_data_dir
from genEM3.data.wkwdata import WkwData
import os

# Test concatenating jsons
test_json_path = os.path.join(get_data_dir(), 'test_data_three_bboxes.json')
train_json_path = os.path.join(
    get_data_dir(), 'debris_clean_added_bboxes2_wiggle_datasource.json')
# Concatenate the test and training data sets
output_name = os.path.join(get_data_dir(), 'train_test_combined.json')
all_ds = WkwData.concat_datasources([train_json_path, test_json_path],
                                    output_name)
assert len(all_ds) == len(WkwData.datasources_from_json(test_json_path)) + len(
    WkwData.datasources_from_json(train_json_path))
Example #14
              [27500, 22000, 3889, 560, 560, 5],
              [27500, 22000, 3902, 560, 560, 20],
              [27500, 22000, 3930, 560, 560, 19],
              [27500, 22000, 3969, 560, 560, 16],
              [27500, 22000, 4021, 560, 560, 9],
              [27500, 22000, 4065, 560, 560, 12],
              [27500, 22000, 4163, 560, 560, 9],
              [27500, 22000, 4255, 560, 560, 11]]
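# Each bbox is 560 x 560 in-plane and bbox[5] slices deep, so at a 140 x 140
# patch size every slice contributes (560 / 140) ** 2 = 16 patches.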
num_samples = sum([bbox[5] for bbox in bboxes_add]) * 560 * 560 / 140 / 140
target_binary_add = 1
target_class_add = 0.0
input_mean_add = 148.0
input_std_add = 36.0
path_add = "/tmpscratch/webknossos/Connectomics_Department/2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8/color/1"

data_sources = WkwData.datasources_from_json(datasources_json_path)
data_sources_max_id = max(
    [int(data_source.id) for data_source in data_sources])

data_sources_out = data_sources
for bbox_idx, bbox_add in enumerate(bboxes_add):
    data_source_out = DataSource(id=str(data_sources_max_id + bbox_idx + 1),
                                 input_path=path_add,
                                 input_bbox=bbox_add,
                                 input_mean=input_mean_add,
                                 input_std=input_std_add,
                                 target_path=path_add,
                                 target_bbox=bbox_add,
                                 target_class=target_class_add,
                                 target_binary=target_binary_add)
    data_sources_out.append(data_source_out)
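The extended list would then typically be written back out like the other datasource lists in this collection; a hedged sketch with a hypothetical output name, assuming get_data_dir is imported as in the neighbouring examples:

# Persist the original sources plus the newly added clean bboxes.
WkwData.datasources_to_json(data_sources_out,
                            os.path.join(get_data_dir(), 'datasource_with_added_bboxes.json'))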
Example #15
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')

# Data settings
run_root = os.path.dirname(os.path.abspath(__file__))
input_shape = (140, 140, 1)
output_shape = (140, 140, 1)
data_split = DataSplit(train=0.70, validation=0.15, test=0.15)
cache_RAM = False
cache_HDD = False
batch_size = 1024
num_workers = 0

# Data sources
json_name = os.path.join(get_data_dir(), 'combined', 'combined_20K_patches.json')
data_sources = WkwData.read_short_ds_json(json_path=json_name)
transformations = WkwData.get_common_transforms()
# Data set
dataset = WkwData(
    input_shape=input_shape,
    target_shape=output_shape,
    data_sources=data_sources,
    data_split=data_split,
    transforms=transformations,
    cache_RAM=cache_RAM,
    cache_HDD=cache_HDD)
# Data loaders
data_loader_params = {'dataset': dataset, 'batch_size': batch_size,
                      'num_workers': num_workers, 'collate_fn': dataset.collate_fn}
data_loaders = data_loaders_split(params=data_loader_params)
# Model initialization
Example #16
import os

import numpy as np
import pandas as pd
from scipy.ndimage import label
from wkskel import Skeleton, Parameters, Nodes

from genEM3.data.wkwdata import WkwData, DataSource
from genEM3.training.metrics import Metrics
from genEM3.util.path import get_runs_dir

path_in = os.path.join(get_runs_dir(),
                       'inference/ae_classify_11_parallel/test_center_filt')
cache_HDD_root = os.path.join(path_in, '.cache/')
path_datasources = os.path.join(path_in, 'datasources.json')
path_nml_in = os.path.join(path_in, 'bbox_annotated.nml')
input_shape = (140, 140, 1)
target_shape = (1, 1, 1)
stride = (35, 35, 1)

datasources = WkwData.datasources_from_json(path_datasources)
dataset = WkwData(input_shape=input_shape,
                  target_shape=target_shape,
                  data_sources=datasources,
                  stride=stride,
                  cache_HDD=False,
                  cache_RAM=True)

skel = Skeleton(path_nml_in)

pred_df = pd.DataFrame(columns=[
    'tree_idx', 'tree_id', 'x', 'y', 'z', 'xi', 'yi', 'class', 'explicit',
    'cluster_id', 'prob'
])
group_ids = np.array(skel.group_ids)
input_path = datasources[0].input_path
Example #17
for idx, curBbox in enumerate(bboxes_debris):
    # convert bbox to a plain python list of ints; numpy values are not JSON serializable
    curBbox = [int(num) for num in curBbox]
    curSource = DataSource(id=str(idx),
                           input_path=getMag8DatasetDir(),
                           input_bbox=curBbox,
                           input_mean=148.0,
                           input_std=36.0,
                           target_path=getMag8DatasetDir(),
                           target_bbox=curBbox,
                           target_class=1.0,
                           target_binary=1)
    dataSources.append(curSource)
# Append clean locations
for idx, curBbox in enumerate(bboxes_clean):
    # The initial 600 indices are taken by the debris locations
    idx = idx + numTrainingExamples
    curSource = DataSource(id=str(idx),
                           input_path=getMag8DatasetDir(),
                           input_bbox=curBbox,
                           input_mean=148.0,
                           input_std=36.0,
                           target_path=getMag8DatasetDir(),
                           target_bbox=curBbox,
                           target_class=0.0,
                           target_binary=1)
    dataSources.append(curSource)
# write to JSON file
jsonPath = os.path.join(get_data_dir(), 'debris_clean_datasource.json')
WkwData.datasources_to_json(dataSources, jsonPath)
sample_pos_x = np.random.randint(wkw_lims[0],
                                 wkw_lims[0] + wkw_lims[3] - sample_dims[0],
                                 num_samples)
sample_pos_y = np.random.randint(wkw_lims[1],
                                 wkw_lims[1] + wkw_lims[4] - sample_dims[1],
                                 num_samples)
sample_pos_z = np.random.randint(wkw_lims[2],
                                 wkw_lims[2] + wkw_lims[5] - sample_dims[2],
                                 num_samples)

for id in range(num_samples):
    input_bbox = [
        int(sample_pos_x[id]),
        int(sample_pos_y[id]),
        int(sample_pos_z[id]), sample_dims[0], sample_dims[1], sample_dims[2]
    ]
    target_bbox = input_bbox
    datasource = DataSource(id=str(id + bboxes_positive.shape[0]),
                            input_path=wkw_path,
                            input_bbox=input_bbox,
                            input_mean=input_mean,
                            input_std=input_std,
                            target_path=target_path,
                            target_bbox=input_bbox,
                            target_class=0,
                            target_binary=1)
    datasources.append(datasource)

WkwData.datasources_to_json(datasources, json_path)
Example #19
cache_HDD_root = os.path.join(run_root, '../../../data/.cache/')
datasources_json_path = os.path.join(
    run_root,
    '../../../data/debris_clean_added_bboxes2_wiggle_datasource.json')
state_dict_path = '/u/flod/code/genEM3/runs/training/ae_v05_skip/.log/epoch_60/model_state_dict'
input_shape = (140, 140, 1)
output_shape = (140, 140, 1)

data_split = DataSplit(train=0.85, validation=0.15, test=0.00)
cache_RAM = True
cache_HDD = True
cache_root = os.path.join(run_root, '.cache/')
batch_size = 256
num_workers = 8

data_sources = WkwData.datasources_from_json(datasources_json_path)

transforms = transforms.Compose([
    transforms.RandomFlip(p=0.5, flip_plane=(1, 2)),
    transforms.RandomFlip(p=0.5, flip_plane=(2, 1)),
    transforms.RandomRotation90(p=1.0, mult_90=[0, 1, 2, 3], rot_plane=(1, 2))
])

dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=data_sources,
                  data_split=data_split,
                  transforms=transforms,
                  cache_RAM=cache_RAM,
                  cache_HDD=cache_HDD,
                  cache_HDD_root=cache_HDD_root)
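A hedged sketch of turning the split computed by WkwData into a training loader, following the same pattern as the VAE script later in this collection (torch is assumed to be imported):

# Random sampler over the training indices, batched with the dataset's own collate function.
train_sampler = torch.utils.data.SubsetRandomSampler(dataset.data_train_inds)
train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           sampler=train_sampler,
                                           collate_fn=dataset.collate_fn)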
Example #20
def main():
    parser = argparse.ArgumentParser(description='Convolutional VAE for 3D electron microscopy data')
    parser.add_argument('--result_dir', type=str, default='.log', metavar='DIR',
                        help='output directory')
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: None)')

    # model options
    # Note(AK): with the AE models from genEM3, the 2048 latent size and 16 fmaps are fixed
    parser.add_argument('--latent_size', type=int, default=2048, metavar='N',
                        help='latent vector size of encoder')
    parser.add_argument('--max_weight_KLD', type=float, default=1.0, metavar='N',
                        help='Weight for the KLD part of loss')

    args = parser.parse_args()
    print('The command line argument:\n')
    print(args)

    # Make the directory for the result output
    if not os.path.isdir(args.result_dir):
        os.makedirs(args.result_dir)

    torch.manual_seed(args.seed)
    # Parameters
    warmup_kld = True
    connDataDir = '/conndata/alik/genEM3_runs/VAE/'
    json_dir = gpath.get_data_dir()
    datasources_json_path = os.path.join(json_dir, 'datasource_20X_980_980_1000bboxes.json')
    input_shape = (140, 140, 1)
    output_shape = (140, 140, 1)
    data_sources = WkwData.datasources_from_json(datasources_json_path)
    # # Only pick the first bboxes for faster epoch
    # data_sources = [data_sources[0]]
    data_split = DataSplit(train=0.80, validation=0.00, test=0.20)
    cache_RAM = True
    cache_HDD = True
    cache_root = os.path.join(connDataDir, '.cache/')
    gpath.mkdir(cache_root)

    # Set up summary writer for tensorboard
    constructedDirName = ''.join([f'weightedVAE_{args.max_weight_KLD}_warmup_{warmup_kld}_', gpath.gethostnameTimeString()])
    tensorBoardDir = os.path.join(connDataDir, constructedDirName)
    writer = SummaryWriter(log_dir=tensorBoardDir)
    launch_tb(logdir=tensorBoardDir, port='7900')
    # Set up data loaders
    num_workers = 8
    dataset = WkwData(
        input_shape=input_shape,
        target_shape=output_shape,
        data_sources=data_sources,
        data_split=data_split,
        normalize=False,
        transforms=ToStandardNormal(mean=148.0, std=36.0),
        cache_RAM=cache_RAM,
        cache_HDD=cache_HDD,
        cache_HDD_root=cache_root
    )
    # Data loaders for training and test
    train_sampler = SubsetRandomSampler(dataset.data_train_inds)
    train_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, num_workers=num_workers, sampler=train_sampler,
        collate_fn=dataset.collate_fn)

    test_sampler = SubsetRandomSampler(dataset.data_test_inds)
    test_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, num_workers=num_workers, sampler=test_sampler,
        collate_fn=dataset.collate_fn)
    # Model and optimizer definition
    input_size = 140
    output_size = 140
    kernel_size = 3
    stride = 1
    # initialize with the given value of KLD (maximum value in case of a warmup scenario)
    weight_KLD = args.max_weight_KLD
    model = ConvVAE(latent_size=args.latent_size,
                    input_size=input_size,
                    output_size=output_size,
                    kernel_size=kernel_size,
                    stride=stride,
                    weight_KLD=weight_KLD).to(device)
    # Add model to the tensorboard as graph
    add_graph(writer=writer, model=model, data_loader=train_loader, device=device)
    # print the details of the model
    print_model = True
    if print_model:
        model.summary(input_size=input_size, device=device.type)
    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    start_epoch = 0
    best_test_loss = np.finfo('f').max

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint %s' % args.resume)
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_test_loss = checkpoint['best_test_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s' % args.resume)
        else:
            print('=> no checkpoint found at %s' % args.resume)
    # Training loop
    for epoch in range(start_epoch, args.epochs):
        # warmup the kld error linearly
        if warmup_kld:
            model.weight_KLD.data = torch.Tensor([((epoch+1) / args.epochs) * args.max_weight_KLD]).to(device) 

        train_loss, train_lossDetailed = train(epoch, model, train_loader, optimizer, args,
                                               device=device)
        test_loss, test_lossDetailed = test(epoch, model, test_loader, writer, args,
                                            device=device)

        # logging, TODO: Use better tags for the logging
        cur_weight_KLD = model.weight_KLD.detach().item()
        writer.add_scalar('loss_train/weight_KLD', cur_weight_KLD, epoch)
        writer.add_scalar('loss_train/total', train_loss, epoch)
        writer.add_scalar('loss_test/total', test_loss, epoch)
        writer.add_scalars('loss_train', train_lossDetailed, global_step=epoch)
        writer.add_scalars('loss_test', test_lossDetailed, global_step=epoch)
        # add the histogram of weights and biases plus their gradients
        for name, param in model.named_parameters():
            writer.add_histogram(name, param.detach().cpu().data.numpy(), epoch)
            # weight_KLD is a parameter but does not have a gradient. It creates an error if one 
            # tries to plot the histogram of a None variable
            if param.grad is not None:
                writer.add_histogram(name+'_gradient', param.grad.cpu().numpy(), epoch)
        # plot mu and logvar
        for latent_prop in ['cur_mu', 'cur_logvar']:
            latent_val = getattr(model, latent_prop)
            writer.add_histogram(latent_prop, latent_val.cpu().numpy(), epoch)
        # flush them to the output
        writer.flush()
        print('Epoch [%d/%d] loss: %.3f val_loss: %.3f' % (epoch + 1, args.epochs, train_loss, test_loss))
        is_best = test_loss < best_test_loss
        best_test_loss = min(test_loss, best_test_loss)
        save_directory = os.path.join(tensorBoardDir, '.log')
        save_checkpoint({'epoch': epoch,
                         'best_test_loss': best_test_loss,
                         'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        is_best,
                        save_directory)

        with torch.no_grad():
            # Sample 64 random points from the prior latent space and decode them
            sample = torch.randn(64, args.latent_size).to(device)
            sample = model.decode(sample).cpu()
            sample_uint8 = undo_normalize(sample, mean=148.0, std=36.0)
            img = make_grid(sample_uint8)
            writer.add_image('sampling', img, epoch)
Example #21
import os
import time
import pickle
import itertools
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt

from genEM3.data.wkwdata import WkwData, DataSource
from genEM3.util.path import get_data_dir
import genEM3.data.annotation as annotation
# %% Prepare for annotation
# Load the json file for the dataset
json_dir = os.path.join(get_data_dir(),
                        'debris_clean_added_bboxes2_wiggle_datasource.json')
config = WkwData.config_wkwdata(json_dir)
dataset = WkwData.init_from_config(config)

# Get a set of data sources with the normal bounding boxes to create a patch-wise dataset and a larger bounding box for annotation
margin = 35
roi_size = 140
source_dict = annotation.patch_source_list_from_dataset(dataset=dataset,
                                                        margin=margin,
                                                        roi_size=roi_size)
dataset_dict = dict.fromkeys(source_dict)

for key in source_dict:
    cur_source = source_dict[key]
    cur_patch_shape = tuple(cur_source[0].input_bbox[3:6])
    cur_config = WkwData.config_wkwdata(datasources_json_path=None,
                                        input_shape=cur_patch_shape,
Example #22
import os
from genEM3.data.wkwdata import WkwData
from genEM3.util.path import get_data_dir
# Read the data
json_name = os.path.join(get_data_dir(), 'combined', 'combined_20K_patches.json')
data_sources = WkwData.read_short_ds_json(json_path=json_name)
# Read an old json for comparison
old_json_name = os.path.join(get_data_dir(), 'dense_3X_10_10_2_um/original_merged_double_binary_v01.json')
old_example = WkwData.datasources_from_json(old_json_name)
# Write a copy [with some modifications]
output_name = os.path.join(get_data_dir(), 'combined', 'copyTest_20K_patches.json')
WkwData.write_short_ds_json(datasources=data_sources, json_path=output_name, convert_to_short=True)
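A small round-trip check, assuming read_short_ds_json returns one entry per written source:

# Read the copy back and verify it has as many sources as the original.
copy_back = WkwData.read_short_ds_json(json_path=output_name)
assert len(copy_back) == len(data_sources)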
Example #23
import os

from genEM3.data.wkwdata import WkwData
from genEM3.util.path import get_data_dir

# Read Json file
json_names = ['dense_3X_10_10_2_um/original_merged_double_binary_v01.json',
              '10x_test_bboxes/10X_9_9_1_um_double_binary_v01.json']
ds_names = [os.path.join(get_data_dir(), j_name) for j_name in json_names]
data_sources = []
dataset_path = '/tmpscratch/webknossos/Connectomics_Department/2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8_artifact_pred/color/1'
for ds in ds_names:
    cur_ds = WkwData.datasources_from_json(json_path=ds)
    cur_ds_dict = WkwData.convert_ds_to_dict(cur_ds)
    # all paths use the artifact_pred dataset
    for s in cur_ds_dict:
        cur_source = cur_ds_dict[s]
        cur_source['input_path'] = dataset_path
        cur_source['target_path'] = dataset_path
        cur_ds_dict[s] = cur_source
    # Write out the jsons
    cur_ds_corrected_list = WkwData.convert_ds_to_list(datasources_dict=cur_ds_dict)
    WkwData.datasources_to_json(datasources=cur_ds_corrected_list, json_path=ds)
Example #24
wkw_root = '/tmpscratch/webknossos/Connectomics_Department/' \
                  '2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8/color/1'

cache_root = os.path.join(run_root, '.cache/')
# path for the datasource JSON
datasources_json_path = os.path.join(json_root, 'datasources.json')
assert os.path.exists(datasources_json_path)
# other parameters
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
norm_mean = 148.0
norm_std = 36.0

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    cache_root=cache_root,
    cache_size=10240,  # MiB
    cache_dim=2,
    cache_range=8)

dataloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=16)
Example #25
import os
import time
import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from genEM3.data.wkwdata import WkwData
from genEM3.model.autoencoder2d import AE, Encoder_4_sampling_bn, Decoder_4_sampling_bn
from genEM3.training.autoencoder import Trainer

# Parameters
run_root = os.path.dirname(os.path.abspath(__file__))
datasources_json_path = os.path.join(run_root, 'datasources.json')
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=data_sources)

stats = dataset.get_datasource_stats(1)
print(stats)
Example #26
        plt.show()

    for i, example in enumerate(dataLoader_debris):
        plt.imshow(np.squeeze(example[0].numpy()), cmap='gray')
        plt.show()

# Running model ae_v03 on the data
run_root = os.path.dirname(os.path.abspath(__file__))
datasources_json_path = os.path.join(run_root, 'datasources_distributed.json')
# setting for the clean data loader
batch_size = 5
input_shape = (140, 140, 1)
output_shape = (140, 140, 1)
num_workers = 0
# construct clean data loader from json file
datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(
    input_shape=input_shape,
    target_shape=output_shape,
    data_sources=datasources,
    cache_HDD=False,
    cache_RAM=True,
)
clean_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers)
# settings for the model to be loaded
# (Is there a way to save so that you do not need to specify model again?)
state_dict_path = os.path.join(run_root, './.log/torch_model')
device = 'cpu'
kernel_size = 3
Example #27
# Parameters
run_root = os.path.dirname(os.path.abspath(__file__))

wkw_root = '/tmpscratch/webknossos/Connectomics_Department/' \
                  '2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8/color/1'

cache_root = os.path.join(run_root, '.cache/')
datasources_json_path = os.path.join(run_root, 'datasources.json')
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
norm_mean = 148.0
norm_std = 36.0

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    cache_RAM=True,
    cache_HDD=True,
    cache_HDD_root=cache_root,
)

dataloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=0)
Example #28
import os

from genEM3.data.wkwdata import WkwData
from genEM3.util.path import get_data_dir

# Read Json file
json_names = [
    'dense_3X_10_10_2_um/original_merged_double_binary_v01.json',
    '10x_test_bboxes/10X_9_9_1_um_double_binary_v01.json'
]
ds_names = [os.path.join(get_data_dir(), j_name) for j_name in json_names]
data_sources = WkwData.concat_datasources(ds_names)
# Get the short version of the data sources
output_name = os.path.join(get_data_dir(), 'combined',
                           'combined_20K_patches.json')
short_ds = WkwData.convert_to_short_ds(data_sources=data_sources)
# Write combined data source json file
WkwData.write_short_ds_json(datasources=short_ds, json_path=output_name)
Example #29
import os
import time

from genEM3.data.wkwdata import WkwData, DataSource

# Parameters
run_root = os.path.dirname(os.path.abspath(__file__))

wkw_root = '/gaba/tmpscratch/webknossos/Connectomics_Department/' \
                  '2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8/color/1'

cache_root = os.path.join(run_root, '.cache/')
datasources_json_path = os.path.join(run_root, 'datasources.json')
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (250, 250, 5)
output_shape = (125, 125, 3)

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# No Caching
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    cache_root=None,
    cache_wipe=True,
    cache_size=1024,  #MiB
    cache_dim=2,
    cache_range=8)

t0 = time.time()
for sample_idx in range(8):
Example #30
import os

from genEM3.data.wkwdata import WkwData
from genEM3.util.path import get_data_dir

# Read the two jsons
target_names = ['Debris', 'Myelin']
json_names = ['combined_20K_patches.json', 'combined_20K_patches_v2.json']
full_names = [
    os.path.join(get_data_dir(), 'combined', f_name) for f_name in json_names
]
ds_list = [WkwData.read_short_ds_json(name) for name in full_names]
ds_dict = [WkwData.convert_ds_to_dict(ds) for ds in ds_list]
# Get the difference between the two data sources from jsons
diff_sources = WkwData.compare_ds_targets(two_datasources=ds_dict,
                                          source_names=json_names,
                                          target_names=target_names)
print(diff_sources)