Example #1
def WkwDataSetConstructor():
    """ Construsts a WkwData[set] from fixed parameters. These parameters can also be explored for 
        further testing"""    
    # Get data source from example json
    json_dir = gpath.get_data_dir()
    datasources_json_path = os.path.join(json_dir, 'datasource_20X_980_980_1000bboxes.json')
    data_sources = WkwData.datasources_from_json(datasources_json_path)
    # Only pick the first two bboxes for faster epoch
    data_sources = data_sources[0:2]
    data_split = DataSplit(train=0.70, validation=0.00, test=0.30)
    # input, output shape
    input_shape = (28, 28, 1)
    output_shape = (28, 28, 1)
    # flags for memory and storage caching
    cache_RAM = True
    cache_HDD = True
    # HDD cache directory
    connDataDir = '/conndata/alik/genEM3_runs/VAE/'
    cache_root = os.path.join(connDataDir, '.cache/')
    dataset = WkwData(
        input_shape=input_shape,
        target_shape=output_shape,
        data_sources=data_sources,
        data_split=data_split,
        normalize=False,
        transforms=ToZeroOneRange(minimum=0, maximum=255),
        cache_RAM=cache_RAM,
        cache_HDD=cache_HDD,
        cache_HDD_root=cache_root
    )
    return dataset
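
A minimal usage sketch for the constructor above; the batch size and worker count are illustrative assumptions, while data_train_inds and collate_fn are used the same way as in the training examples further down.

# Usage sketch (batch size and worker count are assumptions)
from torch.utils.data import DataLoader, SubsetRandomSampler

dataset = WkwDataSetConstructor()
train_sampler = SubsetRandomSampler(dataset.data_train_inds)
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          num_workers=4,
                          sampler=train_sampler,
                          collate_fn=dataset.collate_fn)
first_batch = next(iter(train_loader))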
Example #2
cache_root = os.path.join(run_root, '.cache/')
datasources_json_path = os.path.join(run_root, 'datasources.json')
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (250, 250, 5)
output_shape = (125, 125, 3)

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# No Caching
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    cache_root=None,
    cache_wipe=True,
    cache_size=1024,  # MiB
    cache_dim=2,
    cache_range=8)

t0 = time.time()
for sample_idx in range(8):
    print(sample_idx)
    data = dataset.get_ordered_sample(sample_idx)
    plt.imshow(data[0][0, :, :, 0].data.numpy())
t1 = time.time()
print('No caching: {} seconds'.format(t1 - t0))

# With Caching (cache empty)
dataset = WkwData(
Example #3
cache_root = os.path.join(run_root, '.cache/')
batch_size = 256
num_workers = 8

data_sources = WkwData.datasources_from_json(datasources_json_path)

transforms = transforms.Compose([
    transforms.RandomFlip(p=0.5, flip_plane=(1, 2)),
    transforms.RandomFlip(p=0.5, flip_plane=(2, 1)),
    transforms.RandomRotation90(p=1.0, mult_90=[0, 1, 2, 3], rot_plane=(1, 2))
])

dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=data_sources,
                  data_split=data_split,
                  transforms=transforms,
                  cache_RAM=cache_RAM,
                  cache_HDD=cache_HDD,
                  cache_HDD_root=cache_HDD_root)
# Create the weighted samplers that introduce class imbalance according to the given factor
imbalance_factor = 20
data_loaders = subsetWeightedSampler.get_data_loaders(
    dataset,
    imbalance_factor=imbalance_factor,
    batch_size=batch_size,
    num_workers=num_workers)

input_size = 140
output_size = input_size
valid_size = 2
kernel_size = 3
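
A short sketch of consuming the weighted loaders created above; it assumes get_data_loaders returns a mapping from phase name to DataLoader (the exact keys depend on subsetWeightedSampler).

# Hedged sketch: iterate whatever phases get_data_loaders returns
for phase, loader in data_loaders.items():
    batch = next(iter(loader))
    print(phase, len(loader))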
Example #4
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
norm_mean = 148.0
norm_std = 36.0

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    cache_RAM=True,
    cache_HDD=True,
    cache_HDD_root=cache_root,
)

dataloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=0)

input_size = 302
output_size = input_size
valid_size = 17
kernel_size = 3
stride = 1
n_fmaps = 8
n_latent = 5000
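
The hyperparameters above presumably parameterize an autoencoder; a sketch by analogy with Example #9, using the AE, Encoder_4_sampling_bn and Decoder_4_sampling_bn classes imported in Example #6 (the constructor signatures are assumed, not confirmed by this snippet).

# Hedged sketch: 302 px autoencoder assembled from the imported building blocks
model = AE(
    Encoder_4_sampling_bn(input_size, kernel_size, stride, n_fmaps, n_latent),
    Decoder_4_sampling_bn(output_size, kernel_size, stride, n_fmaps, n_latent))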
Example #5
from genEM3.training.metrics import Metrics
from genEM3.util.path import get_runs_dir

path_in = os.path.join(get_runs_dir(),
                       'inference/ae_classify_11_parallel/test_center_filt')
cache_HDD_root = os.path.join(path_in, '.cache/')
path_datasources = os.path.join(path_in, 'datasources.json')
path_nml_in = os.path.join(path_in, 'bbox_annotated.nml')
input_shape = (140, 140, 1)
target_shape = (1, 1, 1)
stride = (35, 35, 1)

datasources = WkwData.datasources_from_json(path_datasources)
dataset = WkwData(input_shape=input_shape,
                  target_shape=target_shape,
                  data_sources=datasources,
                  stride=stride,
                  cache_HDD=False,
                  cache_RAM=True)

skel = Skeleton(path_nml_in)

pred_df = pd.DataFrame(columns=[
    'tree_idx', 'tree_id', 'x', 'y', 'z', 'xi', 'yi', 'class', 'explicit',
    'cluster_id', 'prob'
])
group_ids = np.array(skel.group_ids)
input_path = datasources[0].input_path
input_bbox = datasources[0].input_bbox
structure = np.ones((3, 3), dtype=int)
cluster_id = 0
for plane_group in skel.groups:
Example #6
import os
import time
import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from genEM3.data.wkwdata import WkwData
from genEM3.model.autoencoder2d import AE, Encoder_4_sampling_bn, Decoder_4_sampling_bn
from genEM3.training.autoencoder import Trainer

# Parameters
run_root = os.path.dirname(os.path.abspath(__file__))
datasources_json_path = os.path.join(run_root, 'datasources.json')
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=data_sources)

stats = dataset.get_datasource_stats(1)
print(stats)
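
A hedged follow-up: if the returned stats expose a mean and standard deviation (the key names below are assumptions), they could feed the norm_mean / norm_std arguments seen in Examples #4 and #7.

# Assumed layout of the stats returned above (key names are a guess);
# 148.0 / 36.0 are the values hard-coded in the other examples.
norm_mean = stats.get('mean', 148.0)
norm_std = stats.get('std', 36.0)
print(norm_mean, norm_std)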
Example #7
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
norm_mean = 148.0
norm_std = 36.0

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    cache_root=cache_root,
    cache_size=10240,  # MiB
    cache_dim=2,
    cache_range=8)

dataloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=16)

input_size = 302
output_size = input_size
valid_size = 17
kernel_size = 3
stride = 1
n_fmaps = 8
n_latent = 5000
Example #8
n_fmaps = 16
n_latent = 2048
input_size = 140
output_size = input_size
model = AE_Encoder_Classifier(
    Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size,
                                                 kernel_size,
                                                 stride,
                                                 n_latent=n_latent),
    Classifier(n_latent=n_latent))

datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=datasources,
                  stride=(70, 70, 1),
                  cache_HDD=True,
                  cache_RAM=True,
                  cache_HDD_root=cache_HDD_root)

prediction_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                batch_size=batch_size,
                                                num_workers=num_workers)

checkpoint = torch.load(state_dict_path,
                        map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)


def prob_collate_fn(outputs):
Example #9
stride = 1
n_fmaps = 16
n_latent = 2048
input_size = 140
output_size = input_size
model = AE(
    Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size, kernel_size,
                                                 stride, n_fmaps, n_latent),
    Decoder_4_sampling_bn_1px_deep_convonly_skip(output_size, kernel_size,
                                                 stride, n_fmaps, n_latent))

datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=datasources,
                  data_split=data_split,
                  cache_HDD=True,
                  cache_RAM=True,
                  cache_HDD_root=cache_HDD_root)

train_sampler = SubsetRandomSampler(dataset.data_train_inds)
train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           sampler=train_sampler,
                                           collate_fn=dataset.collate_fn)

checkpoint = torch.load(state_dict_path,
                        map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)
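
A minimal evaluation sketch for the restored autoencoder; it assumes each batch indexes like the samples in Example #2, i.e. batch[0] is the input tensor.

# Hedged sketch: one reconstruction pass in eval mode
model.eval()
with torch.no_grad():
    batch = next(iter(train_loader))
    inputs = batch[0]            # assumed input tensor position
    reconstruction = model(inputs)
print(inputs.shape)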
Example #10
def main():
    parser = argparse.ArgumentParser(description='Convolutional VAE for 3D electron microscopy data')
    parser.add_argument('--result_dir', type=str, default='.log', metavar='DIR',
                        help='output directory')
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: None)')

    # model options
    # Note(AK): with the AE models from genEM3, the 2048 latent size and 16 fmaps are fixed
    parser.add_argument('--latent_size', type=int, default=2048, metavar='N',
                        help='latent vector size of encoder')
    parser.add_argument('--max_weight_KLD', type=float, default=1.0, metavar='N',
                        help='Weight for the KLD part of loss')

    args = parser.parse_args()
    print('The command line argument:\n')
    print(args)

    # Make the directory for the result output
    if not os.path.isdir(args.result_dir):
        os.makedirs(args.result_dir)

    torch.manual_seed(args.seed)
    # Parameters
    warmup_kld = True
    connDataDir = '/conndata/alik/genEM3_runs/VAE/'
    json_dir = gpath.get_data_dir()
    datasources_json_path = os.path.join(json_dir, 'datasource_20X_980_980_1000bboxes.json')
    input_shape = (140, 140, 1)
    output_shape = (140, 140, 1)
    data_sources = WkwData.datasources_from_json(datasources_json_path)
    # # Only pick the first bboxes for faster epoch
    # data_sources = [data_sources[0]]
    data_split = DataSplit(train=0.80, validation=0.00, test=0.20)
    cache_RAM = True
    cache_HDD = True
    cache_root = os.path.join(connDataDir, '.cache/')
    gpath.mkdir(cache_root)

    # Set up summary writer for tensorboard
    constructedDirName = ''.join([f'weightedVAE_{args.max_weight_KLD}_warmup_{warmup_kld}_', gpath.gethostnameTimeString()])
    tensorBoardDir = os.path.join(connDataDir, constructedDirName)
    writer = SummaryWriter(log_dir=tensorBoardDir)
    launch_tb(logdir=tensorBoardDir, port='7900')
    # Set up data loaders
    num_workers = 8
    dataset = WkwData(
        input_shape=input_shape,
        target_shape=output_shape,
        data_sources=data_sources,
        data_split=data_split,
        normalize=False,
        transforms=ToStandardNormal(mean=148.0, std=36.0),
        cache_RAM=cache_RAM,
        cache_HDD=cache_HDD,
        cache_HDD_root=cache_root
    )
    # Data loaders for training and test
    train_sampler = SubsetRandomSampler(dataset.data_train_inds)
    train_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, num_workers=num_workers, sampler=train_sampler,
        collate_fn=dataset.collate_fn)

    test_sampler = SubsetRandomSampler(dataset.data_test_inds)
    test_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, num_workers=num_workers, sampler=test_sampler,
        collate_fn=dataset.collate_fn)
    # Model and optimizer definition
    input_size = 140
    output_size = 140
    kernel_size = 3
    stride = 1
    # initialize with the given value of KLD (maximum value in case of a warmup scenario)
    weight_KLD = args.max_weight_KLD
    model = ConvVAE(latent_size=args.latent_size,
                    input_size=input_size,
                    output_size=output_size,
                    kernel_size=kernel_size,
                    stride=stride,
                    weight_KLD=weight_KLD).to(device)
    # Add model to the tensorboard as graph
    add_graph(writer=writer, model=model, data_loader=train_loader, device=device)
    # print the details of the model
    print_model = True
    if print_model:
        model.summary(input_size=input_size, device=device.type)
    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    start_epoch = 0
    best_test_loss = np.finfo('f').max

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint %s' % args.resume)
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_test_loss = checkpoint['best_test_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s' % args.resume)
        else:
            print('=> no checkpoint found at %s' % args.resume)
    # Training loop
    for epoch in range(start_epoch, args.epochs):
        # warmup the kld error linearly
        if warmup_kld:
            model.weight_KLD.data = torch.Tensor([((epoch+1) / args.epochs) * args.max_weight_KLD]).to(device) 

        train_loss, train_lossDetailed = train(epoch, model, train_loader, optimizer, args,
                                               device=device)
        test_loss, test_lossDetailed = test(epoch, model, test_loader, writer, args,
                                            device=device)

        # logging, TODO: Use better tags for the logging
        cur_weight_KLD = model.weight_KLD.detach().item()
        writer.add_scalar('loss_train/weight_KLD', cur_weight_KLD, epoch)
        writer.add_scalar('loss_train/total', train_loss, epoch)
        writer.add_scalar('loss_test/total', test_loss, epoch)
        writer.add_scalars('loss_train', train_lossDetailed, global_step=epoch)
        writer.add_scalars('loss_test', test_lossDetailed, global_step=epoch)
        # add the histogram of weights and biases plus their gradients
        for name, param in model.named_parameters():
            writer.add_histogram(name, param.detach().cpu().data.numpy(), epoch)
            # weight_KLD is a parameter but does not have a gradient. It creates an error if one 
            # tries to plot the histogram of a None variable
            if param.grad is not None:
                writer.add_histogram(name+'_gradient', param.grad.cpu().numpy(), epoch)
        # plot mu and logvar
        for latent_prop in ['cur_mu', 'cur_logvar']:
            latent_val = getattr(model, latent_prop)
            writer.add_histogram(latent_prop, latent_val.cpu().numpy(), epoch)
        # flush them to the output
        writer.flush()
        print('Epoch [%d/%d] loss: %.3f val_loss: %.3f' % (epoch + 1, args.epochs, train_loss, test_loss))
        is_best = test_loss < best_test_loss
        best_test_loss = min(test_loss, best_test_loss)
        save_directory = os.path.join(tensorBoardDir, '.log')
        save_checkpoint({'epoch': epoch,
                         'best_test_loss': best_test_loss,
                         'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        is_best,
                        save_directory)

        with torch.no_grad():
            # Draw 64 random samples from the prior latent space, decode and log them as an image grid
            sample = torch.randn(64, args.latent_size).to(device)
            sample = model.decode(sample).cpu()
            sample_uint8 = undo_normalize(sample, mean=148.0, std=36.0)
            img = make_grid(sample_uint8)
            writer.add_image('sampling', img, epoch)
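
The snippet presumably closes with the usual entry-point guard (not shown above); a minimal sketch:

if __name__ == '__main__':
    main()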
Example #11
        plt.show()

# Running model ae_v03 on the data
run_root = os.path.dirname(os.path.abspath(__file__))
datasources_json_path = os.path.join(run_root, 'datasources_distributed.json')
# setting for the clean data loader
batch_size = 5
input_shape = (140, 140, 1)
output_shape = (140, 140, 1)
num_workers = 0
# construct clean data loader from json file
datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(
    input_shape=input_shape,
    target_shape=output_shape,
    data_sources=datasources,
    cache_HDD=False,
    cache_RAM=True,
)
clean_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers)
# settings for the model to be loaded
# (Is there a way to save so that you do not need to specify model again?)
state_dict_path = os.path.join(run_root, './.log/torch_model')
device = 'cpu'
kernel_size = 3
stride = 1
n_fmaps = 16
n_latent = 2048
input_size = 140
Example #12
def predict_bbox_from_json(bbox_idx, verbose=True):

    if verbose:
        print('(' + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +
              ') Starting Parallel Prediction ... bbox: {}'.format(bbox_idx))

    run_root = os.path.dirname(os.path.abspath(__file__))
    cache_HDD_root = os.path.join(run_root, '.cache/')
    datasources_json_path = os.path.join(run_root,
                                         'datasources_predict_parallel.json')
    state_dict_path = os.path.join(
        run_root,
        '../../training/ae_classify_v09_3layer_unfreeze_latent_debris_clean_transform_add_clean2_wiggle/.log/run_w_pr/epoch_700/model_state_dict'
    )
    device = 'cpu'

    output_wkw_root = '/tmpscratch/webknossos/Connectomics_Department/2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8_artifact_pred'
    output_label = 'probs_sparse'

    batch_size = 128
    input_shape = (140, 140, 1)
    output_shape = (1, 1, 1)
    num_workers = 12

    kernel_size = 3
    stride = 1
    n_fmaps = 16
    n_latent = 2048
    input_size = 140
    output_size = input_size
    model = AE_Encoder_Classifier(
        Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size,
                                                     kernel_size,
                                                     stride,
                                                     n_latent=n_latent),
        Classifier3Layered(n_latent=n_latent))

    datasources = WkwData.datasources_bbox_from_json(
        datasources_json_path,
        bbox_ext=[1024, 1024, 1024],
        bbox_idx=bbox_idx,
        datasource_idx=0)
    dataset = WkwData(input_shape=input_shape,
                      target_shape=output_shape,
                      data_sources=datasources,
                      stride=(35, 35, 1),
                      cache_HDD=False,
                      cache_RAM=False,
                      cache_HDD_root=cache_HDD_root)

    prediction_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                    batch_size=batch_size,
                                                    num_workers=num_workers)

    checkpoint = torch.load(state_dict_path,
                            map_location=lambda storage, loc: storage)
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)

    output_prob_fn = lambda x: np.exp(x[:, 1, 0, 0])
    # output_dtype = np.uint8
    output_dtype = np.float32
    # output_dtype_fn = lambda x: (logit(x) + 16) * 256 / 32
    output_dtype_fn = lambda x: x
    # output_dtype_fni = lambda x: expit(x / 256 * 32 - 16)
    output_dtype_fni = lambda x: x

    predictor = Predictor(model=model,
                          dataloader=prediction_loader,
                          output_prob_fn=output_prob_fn,
                          output_dtype_fn=output_dtype_fn,
                          output_dtype=output_dtype,
                          output_label=output_label,
                          output_wkw_root=output_wkw_root,
                          output_wkw_compress=True,
                          device=device,
                          interpolate=None)

    predictor.predict(verbose=verbose)
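
A hedged driver sketch: the log message above announces a parallel prediction, so one plausible caller maps predict_bbox_from_json over several bounding-box indices with a process pool; the index range and worker count are assumptions.

from multiprocessing import Pool

if __name__ == '__main__':
    bbox_indices = range(4)      # assumed number of bboxes
    with Pool(processes=4) as pool:
        pool.map(predict_bbox_from_json, bbox_indices)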