def WkwDataSetConstructor():
    """Constructs a WkwData[set] from fixed parameters.

    These parameters can also be explored for further testing.

    Returns:
        WkwData: dataset over the first two bounding boxes of the example
        json datasource, with RAM and HDD caching enabled.
    """
    # Get data source from example json
    json_dir = gpath.get_data_dir()
    datasources_json_path = os.path.join(json_dir, 'datasource_20X_980_980_1000bboxes.json')
    data_sources = WkwData.datasources_from_json(datasources_json_path)
    # Only pick the first two bboxes for faster epoch
    data_sources = data_sources[0:2]
    data_split = DataSplit(train=0.70, validation=0.00, test=0.30)
    # input, output shape
    input_shape = (28, 28, 1)
    output_shape = (28, 28, 1)
    # flags for memory and storage caching
    cache_RAM = True
    cache_HDD = True
    # HDD cache directory
    connDataDir = '/conndata/alik/genEM3_runs/VAE/'
    cache_root = os.path.join(connDataDir, '.cache/')
    dataset = WkwData(
        input_shape=input_shape,
        target_shape=output_shape,
        data_sources=data_sources,
        data_split=data_split,
        normalize=False,
        transforms=ToZeroOneRange(minimum=0, maximum=255),
        cache_RAM=cache_RAM,
        cache_HDD=cache_HDD,
        cache_HDD_root=cache_root
    )
    return dataset
cache_root = os.path.join(run_root, '.cache/') datasources_json_path = os.path.join(run_root, 'datasources.json') data_strata = {'training': [1, 2], 'validate': [3], 'test': []} input_shape = (250, 250, 5) output_shape = (125, 125, 3) # Run data_sources = WkwData.datasources_from_json(datasources_json_path) # No Caching dataset = WkwData( data_sources=data_sources, data_strata=data_strata, input_shape=input_shape, target_shape=output_shape, cache_root=None, cache_wipe=True, cache_size=1024, #MiB cache_dim=2, cache_range=8) t0 = time.time() for sample_idx in range(8): print(sample_idx) data = dataset.get_ordered_sample(sample_idx) plt.imshow(data[0][0, :, :, 0].data.numpy()) t1 = time.time() print('No caching: {} seconds'.format(t1 - t0)) # With Caching (cache empty) dataset = WkwData(
# Build an augmented dataset and class-imbalanced loaders for training.
cache_root = os.path.join(run_root, '.cache/')
batch_size = 256
num_workers = 8
data_sources = WkwData.datasources_from_json(datasources_json_path)

# NOTE(review): this rebinds the imported `transforms` module name to the
# composed transform object (shadowing) — works only because the module is
# not referenced afterwards; consider renaming the local.
transforms = transforms.Compose([
    transforms.RandomFlip(p=0.5, flip_plane=(1, 2)),
    transforms.RandomFlip(p=0.5, flip_plane=(2, 1)),
    transforms.RandomRotation90(p=1.0, mult_90=[0, 1, 2, 3], rot_plane=(1, 2))
])

dataset = WkwData(
    input_shape=input_shape,
    target_shape=output_shape,
    data_sources=data_sources,
    data_split=data_split,
    transforms=transforms,
    cache_RAM=cache_RAM,
    cache_HDD=cache_HDD,
    cache_HDD_root=cache_HDD_root)

# Create the weighted samplers which create imbalance given the factor
imbalance_factor = 20
data_loaders = subsetWeightedSampler.get_data_loaders(
    dataset,
    imbalance_factor=imbalance_factor,
    batch_size=batch_size,
    num_workers=num_workers)

# Model geometry parameters
input_size = 140
output_size = input_size
valid_size = 2
kernel_size = 3
# Dataset + loader configuration (filled HDD/RAM cache, single-worker loader).
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
norm_mean = 148.0
norm_std = 36.0

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    cache_RAM=True,
    cache_HDD=True,
    cache_HDD_root=cache_root,
)

dataloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=0)

# Autoencoder geometry parameters
input_size = 302
output_size = input_size
valid_size = 17
kernel_size = 3
stride = 1
n_fmaps = 8
n_latent = 5000
from genEM3.training.metrics import Metrics from genEM3.util.path import get_runs_dir path_in = os.path.join(get_runs_dir(), 'inference/ae_classify_11_parallel/test_center_filt') cache_HDD_root = os.path.join(path_in, '.cache/') path_datasources = os.path.join(path_in, 'datasources.json') path_nml_in = os.path.join(path_in, 'bbox_annotated.nml') input_shape = (140, 140, 1) target_shape = (1, 1, 1) stride = (35, 35, 1) datasources = WkwData.datasources_from_json(path_datasources) dataset = WkwData(input_shape=input_shape, target_shape=target_shape, data_sources=datasources, stride=stride, cache_HDD=False, cache_RAM=True) skel = Skeleton(path_nml_in) pred_df = pd.DataFrame(columns=[ 'tree_idx', 'tree_id', 'x', 'y', 'z', 'xi', 'yi', 'class', 'explicit', 'cluster_id', 'prob' ]) group_ids = np.array(skel.group_ids) input_path = datasources[0].input_path input_bbox = datasources[0].input_bbox structure = np.ones((3, 3), dtype=np.int) cluster_id = 0 for plane_group in skel.groups:
import os
import time
import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from genEM3.data.wkwdata import WkwData
from genEM3.model.autoencoder2d import AE, Encoder_4_sampling_bn, Decoder_4_sampling_bn
from genEM3.training.autoencoder import Trainer

# Parameters
run_root = os.path.dirname(os.path.abspath(__file__))
datasources_json_path = os.path.join(run_root, 'datasources.json')
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)

data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=data_sources)
stats = dataset.get_datasource_stats(1)
print(stats)
# Dataset + loader configuration (10 GiB cache, 16 loader workers).
data_strata = {'training': [1, 2], 'validate': [3], 'test': []}
input_shape = (302, 302, 1)
output_shape = (302, 302, 1)
norm_mean = 148.0
norm_std = 36.0

# Run
data_sources = WkwData.datasources_from_json(datasources_json_path)

# With Caching (cache filled)
dataset = WkwData(
    data_sources=data_sources,
    data_strata=data_strata,
    input_shape=input_shape,
    target_shape=output_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    cache_root=cache_root,
    cache_size=10240,  # MiB
    cache_dim=2,
    cache_range=8)

dataloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=16)

# Autoencoder geometry parameters
input_size = 302
output_size = input_size
valid_size = 17
kernel_size = 3
stride = 1
n_fmaps = 8
n_latent = 5000
n_fmaps = 16 n_latent = 2048 input_size = 140 output_size = input_size model = AE_Encoder_Classifier( Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size, kernel_size, stride, n_latent=n_latent), Classifier(n_latent=n_latent)) datasources = WkwData.datasources_from_json(datasources_json_path) dataset = WkwData(input_shape=input_shape, target_shape=output_shape, data_sources=datasources, stride=(70, 70, 1), cache_HDD=True, cache_RAM=True, cache_HDD_root=cache_HDD_root) prediction_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers) checkpoint = torch.load(state_dict_path, map_location=lambda storage, loc: storage) state_dict = checkpoint['model_state_dict'] model.load_state_dict(state_dict) def prob_collate_fn(outputs):
# Build the autoencoder, the cached training dataset/loader, and restore
# the trained weights from the checkpoint.
stride = 1
n_fmaps = 16
n_latent = 2048
input_size = 140
output_size = input_size

model = AE(
    Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size, kernel_size, stride, n_fmaps, n_latent),
    Decoder_4_sampling_bn_1px_deep_convonly_skip(output_size, kernel_size, stride, n_fmaps, n_latent))

datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(input_shape=input_shape,
                  target_shape=output_shape,
                  data_sources=datasources,
                  data_split=data_split,
                  cache_HDD=True,
                  cache_RAM=True,
                  cache_HDD_root=cache_HDD_root)

train_sampler = SubsetRandomSampler(dataset.data_train_inds)
train_loader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=batch_size, num_workers=num_workers,
    sampler=train_sampler, collate_fn=dataset.collate_fn)

# Load weights onto CPU regardless of the device they were saved from.
checkpoint = torch.load(state_dict_path, map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)
def main():
    """Train a convolutional VAE on 3D EM data with optional KLD warmup.

    Parses command-line options, builds the cached WkwData dataset and
    train/test loaders, constructs the ConvVAE, optionally resumes from a
    checkpoint, then runs the training loop with tensorboard logging,
    checkpointing, and per-epoch prior sampling.
    """
    parser = argparse.ArgumentParser(description='Convolutional VAE for 3D electron microscopy data')
    parser.add_argument('--result_dir', type=str, default='.log', metavar='DIR',
                        help='output directory')
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # FIX: help string was missing its closing parenthesis.
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: None)')
    # model options
    # Note(AK): with the AE models from genEM3, the 2048 latent size and 16 fmaps are fixed
    parser.add_argument('--latent_size', type=int, default=2048, metavar='N',
                        help='latent vector size of encoder')
    parser.add_argument('--max_weight_KLD', type=float, default=1.0, metavar='N',
                        help='Weight for the KLD part of loss')

    args = parser.parse_args()
    print('The command line argument:\n')
    print(args)

    # Make the directory for the result output
    if not os.path.isdir(args.result_dir):
        os.makedirs(args.result_dir)
    torch.manual_seed(args.seed)

    # Parameters
    warmup_kld = True
    connDataDir = '/conndata/alik/genEM3_runs/VAE/'
    json_dir = gpath.get_data_dir()
    datasources_json_path = os.path.join(json_dir, 'datasource_20X_980_980_1000bboxes.json')
    input_shape = (140, 140, 1)
    output_shape = (140, 140, 1)
    data_sources = WkwData.datasources_from_json(datasources_json_path)
    # # Only pick the first bboxes for faster epoch
    # data_sources = [data_sources[0]]
    data_split = DataSplit(train=0.80, validation=0.00, test=0.20)
    cache_RAM = True
    cache_HDD = True
    cache_root = os.path.join(connDataDir, '.cache/')
    gpath.mkdir(cache_root)

    # Set up summary writer for tensorboard
    constructedDirName = ''.join([f'weightedVAE_{args.max_weight_KLD}_warmup_{warmup_kld}_',
                                  gpath.gethostnameTimeString()])
    tensorBoardDir = os.path.join(connDataDir, constructedDirName)
    writer = SummaryWriter(log_dir=tensorBoardDir)
    launch_tb(logdir=tensorBoardDir, port='7900')

    # Set up data loaders
    num_workers = 8
    dataset = WkwData(
        input_shape=input_shape,
        target_shape=output_shape,
        data_sources=data_sources,
        data_split=data_split,
        normalize=False,
        transforms=ToStandardNormal(mean=148.0, std=36.0),
        cache_RAM=cache_RAM,
        cache_HDD=cache_HDD,
        cache_HDD_root=cache_root
    )
    # Data loaders for training and test
    train_sampler = SubsetRandomSampler(dataset.data_train_inds)
    train_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, num_workers=num_workers,
        sampler=train_sampler, collate_fn=dataset.collate_fn)
    test_sampler = SubsetRandomSampler(dataset.data_test_inds)
    test_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, num_workers=num_workers,
        sampler=test_sampler, collate_fn=dataset.collate_fn)

    # Model and optimizer definition
    input_size = 140
    output_size = 140
    kernel_size = 3
    stride = 1
    # initialize with the given value of KLD (maximum value in case of a warmup scenario)
    weight_KLD = args.max_weight_KLD
    model = ConvVAE(latent_size=args.latent_size,
                    input_size=input_size,
                    output_size=output_size,
                    kernel_size=kernel_size,
                    stride=stride,
                    weight_KLD=weight_KLD).to(device)
    # Add model to the tensorboard as graph
    add_graph(writer=writer, model=model, data_loader=train_loader, device=device)
    # print the details of the model
    print_model = True
    if print_model:
        model.summary(input_size=input_size, device=device.type)
    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    start_epoch = 0
    best_test_loss = np.finfo('f').max

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint %s' % args.resume)
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_test_loss = checkpoint['best_test_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s' % args.resume)
        else:
            print('=> no checkpoint found at %s' % args.resume)

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        # warmup the kld error linearly
        if warmup_kld:
            model.weight_KLD.data = torch.Tensor([((epoch + 1) / args.epochs) * args.max_weight_KLD]).to(device)

        train_loss, train_lossDetailed = train(epoch, model, train_loader, optimizer, args,
                                               device=device)
        test_loss, test_lossDetailed = test(epoch, model, test_loader, writer, args,
                                            device=device)

        # logging, TODO: Use better tags for the logging
        cur_weight_KLD = model.weight_KLD.detach().item()
        writer.add_scalar('loss_train/weight_KLD', cur_weight_KLD, epoch)
        writer.add_scalar('loss_train/total', train_loss, epoch)
        writer.add_scalar('loss_test/total', test_loss, epoch)
        writer.add_scalars('loss_train', train_lossDetailed, global_step=epoch)
        writer.add_scalars('loss_test', test_lossDetailed, global_step=epoch)
        # add the histogram of weights and biases plus their gradients
        for name, param in model.named_parameters():
            writer.add_histogram(name, param.detach().cpu().data.numpy(), epoch)
            # weight_KLD is a parameter but does not have a gradient. It creates
            # an error if one tries to plot the histogram of a None variable
            if param.grad is not None:
                writer.add_histogram(name + '_gradient', param.grad.cpu().numpy(), epoch)
        # plot mu and logvar
        for latent_prop in ['cur_mu', 'cur_logvar']:
            latent_val = getattr(model, latent_prop)
            writer.add_histogram(latent_prop, latent_val.cpu().numpy(), epoch)
        # flush them to the output
        writer.flush()

        print('Epoch [%d/%d] loss: %.3f val_loss: %.3f' % (epoch + 1, args.epochs, train_loss, test_loss))
        is_best = test_loss < best_test_loss
        best_test_loss = min(test_loss, best_test_loss)
        save_directory = os.path.join(tensorBoardDir, '.log')
        save_checkpoint({'epoch': epoch,
                         'best_test_loss': best_test_loss,
                         'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        is_best,
                        save_directory)
        with torch.no_grad():
            # Image 64 random sample from the prior latent space and decode
            sample = torch.randn(64, args.latent_size).to(device)
            sample = model.decode(sample).cpu()
            sample_uint8 = undo_normalize(sample, mean=148.0, std=36.0)
            img = make_grid(sample_uint8)
            writer.add_image('sampling', img, epoch)
plt.show()

# Running model ae_v03 on the data
run_root = os.path.dirname(os.path.abspath(__file__))
datasources_json_path = os.path.join(run_root, 'datasources_distributed.json')

# setting for the clean data loader
batch_size = 5
input_shape = (140, 140, 1)
output_shape = (140, 140, 1)
num_workers = 0

# construct clean data loader from json file
datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(
    input_shape=input_shape,
    target_shape=output_shape,
    data_sources=datasources,
    cache_HDD=False,
    cache_RAM=True,
)
clean_loader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=batch_size, num_workers=num_workers)

# settings for the model to be loaded
# (Is there a way to save so that you do not need to specify model again?)
state_dict_path = os.path.join(run_root, './.log/torch_model')
device = 'cpu'
kernel_size = 3
stride = 1
n_fmaps = 16
n_latent = 2048
input_size = 140
def predict_bbox_from_json(bbox_idx, verbose=True):
    """Run sparse artifact-probability prediction for one bounding box.

    Loads the datasource for `bbox_idx` from the parallel-prediction json,
    restores the trained encoder+classifier checkpoint, and writes class-1
    probabilities into the output wkw dataset via Predictor.

    Args:
        bbox_idx: index of the bounding box within the json datasource.
        verbose: print progress messages when True.
    """
    if verbose:
        print('(' + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +
              ') Starting Parallel Prediction ... bbox: {}'.format(bbox_idx))
    run_root = os.path.dirname(os.path.abspath(__file__))
    cache_HDD_root = os.path.join(run_root, '.cache/')
    datasources_json_path = os.path.join(run_root, 'datasources_predict_parallel.json')
    state_dict_path = os.path.join(
        run_root,
        '../../training/ae_classify_v09_3layer_unfreeze_latent_debris_clean_transform_add_clean2_wiggle/.log/run_w_pr/epoch_700/model_state_dict'
    )
    device = 'cpu'
    output_wkw_root = '/tmpscratch/webknossos/Connectomics_Department/2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8_artifact_pred'
    output_label = 'probs_sparse'
    batch_size = 128
    input_shape = (140, 140, 1)
    output_shape = (1, 1, 1)
    num_workers = 12
    kernel_size = 3
    stride = 1
    n_fmaps = 16
    n_latent = 2048
    input_size = 140
    output_size = input_size

    model = AE_Encoder_Classifier(
        Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size, kernel_size, stride, n_latent=n_latent),
        Classifier3Layered(n_latent=n_latent))

    datasources = WkwData.datasources_bbox_from_json(
        datasources_json_path, bbox_ext=[1024, 1024, 1024], bbox_idx=bbox_idx, datasource_idx=0)
    dataset = WkwData(input_shape=input_shape,
                      target_shape=output_shape,
                      data_sources=datasources,
                      stride=(35, 35, 1),
                      cache_HDD=False,
                      cache_RAM=False,
                      cache_HDD_root=cache_HDD_root)
    prediction_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=batch_size, num_workers=num_workers)

    # Load weights onto CPU regardless of the device they were saved from.
    checkpoint = torch.load(state_dict_path, map_location=lambda storage, loc: storage)
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)

    # FIX(idiom): named functions instead of lambda assignments (PEP 8 / E731).
    def output_prob_fn(x):
        # class-1 log-probability -> probability
        return np.exp(x[:, 1, 0, 0])

    # output_dtype = np.uint8
    output_dtype = np.float32

    # output_dtype_fn = lambda x: (logit(x) + 16) * 256 / 32
    def output_dtype_fn(x):
        return x

    # output_dtype_fni = lambda x: expit(x / 256 * 32 - 16)
    def output_dtype_fni(x):
        return x

    predictor = Predictor(model=model,
                          dataloader=prediction_loader,
                          output_prob_fn=output_prob_fn,
                          output_dtype_fn=output_dtype_fn,
                          output_dtype=output_dtype,
                          output_label=output_label,
                          output_wkw_root=output_wkw_root,
                          output_wkw_compress=True,
                          device=device,
                          interpolate=None)
    predictor.predict(verbose=verbose)