def prob_collate_fn(outputs):
    """Convert a batch of network outputs into per-sample class-1 maps.

    Args:
        outputs: torch tensor indexed as ``[batch, class, y, x]``
            (NOTE(review): layout inferred from the indexing below),
            presumably holding log-probabilities since the values are
            exponentiated.

    Returns:
        list of numpy arrays, one per batch element, each holding
        ``exp(outputs[i, 1])`` over all spatial positions.
    """
    # .cpu() is a no-op for CPU tensors, but without it .numpy() raises for
    # outputs that live on a GPU. detach() is needed for tensors that
    # require grad.
    log_probs = outputs.detach().cpu().numpy()
    # exp is element-wise, so slicing the class-1 channel first gives the
    # same values while exponentiating only that channel.
    return list(np.exp(log_probs[:, 1, :, :]))


output_dtype = np.uint8


def output_dtype_fn(x):
    """Encode probabilities as uint8 codes via a linear map of logit(x).

    Logit values in [-6, 6] are mapped linearly onto [0, 256). The result
    is clipped to [0, 255] so that probabilities with logit >= ~5.95, and
    probabilities of exactly 0 or 1 (logit of -inf / +inf), stay inside the
    range of ``output_dtype``.
    NOTE(review): the cast to ``output_dtype`` happens inside DataWriter,
    which is not shown; clipping here keeps the values valid regardless of
    how that cast treats out-of-range input.
    """
    return np.clip((logit(x) + 6) * 256 / 12, 0, 255)


def output_dtype_fni(x):
    """Approximate inverse of ``output_dtype_fn``: uint8 codes -> probabilities."""
    return expit(x / 256 * 12 - 6)


datawriter_prob = DataWriter(dataloader=prediction_loader,
                             output_collate_fn=prob_collate_fn,
                             output_label='probs_ae_classify_03',
                             output_path=output_wkw_root,
                             output_dtype=output_dtype,
                             output_dtype_fn=output_dtype_fn)

datawriters = {'probs_wkw': datawriter_prob}

predictor = Predictor(dataloader=prediction_loader,
                      datawriters=datawriters,
                      model=model,
                      state_dict=state_dict,
                      device=device,
                      batch_size=batch_size,
                      input_shape=input_shape,
                      output_shape=output_shape,
                      interpolate='nearest')

predictor.predict()
Exemple #2
0
# Run the classifier over ``dataset`` and write the class-1 output as
# uncompressed float32.
prediction_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Restore trained weights. A map_location callable that returns each storage
# unchanged keeps every tensor on the CPU.
checkpoint = torch.load(state_dict_path,
                        map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)


def output_prob_fn(x):
    """Return exp(x[:, 1, 0, 0]); the class-1 probability per sample when
    ``x`` holds log-probabilities."""
    return np.exp(x[:, 1, 0, 0])


# Values are stored as raw float32, so the dtype conversion and its inverse
# are identities. (A uint8 alternative encoded (logit(x) + 16) * 256 / 32
# and decoded with expit(x / 256 * 32 - 16).)
output_dtype = np.float32


def output_dtype_fn(x):
    """Identity: values are written unchanged."""
    return x


def output_dtype_fni(x):
    """Identity inverse of ``output_dtype_fn``."""
    return x


predictor = Predictor(
    model=model,
    dataloader=prediction_loader,
    output_prob_fn=output_prob_fn,
    output_dtype_fn=output_dtype_fn,
    output_dtype=output_dtype,
    output_label=output_label,
    output_wkw_root=output_wkw_root,
    output_wkw_compress=False,
    device=device,
    interpolate=None,
)

predictor.predict()
print('done')
Exemple #3
0
# Build an autoencoder, load its trained weights and run prediction over the
# datasources listed in a JSON file.
# The decoder is constructed with the same size as the encoder's input.
output_size = input_size
model = AE(
    Encoder_4_sampling_bn_1px_deep(input_size, kernel_size, stride, n_fmaps,
                                   n_latent),
    Decoder_4_sampling_bn_1px_deep(output_size, kernel_size, stride, n_fmaps,
                                   n_latent))

# Dataset over the datasources described in the JSON file; patch shapes are
# given by input_shape / output_shape.
datasources = WkwData.datasources_from_json(datasources_json_path)
dataset = WkwData(
    input_shape=input_shape,
    target_shape=output_shape,
    data_sources=datasources,
)

prediction_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                batch_size=batch_size,
                                                num_workers=num_workers)

# Restore trained weights; returning each storage unchanged from
# map_location keeps every tensor on the CPU.
checkpoint = torch.load(state_dict_path,
                        map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)
# NOTE(review): the weights are already loaded into ``model`` above and the
# same state_dict is passed to Predictor as well; confirm whether Predictor
# reloads it (possibly redundant).
predictor = Predictor(dataloader=prediction_loader,
                      model=model,
                      state_dict=state_dict,
                      device=device,
                      batch_size=batch_size,
                      input_shape=input_shape,
                      output_shape=output_shape)
predictor.predict()
Exemple #4
0
    Decoder_4_sampling_bn_1px_deep_convonly_skip(output_size, kernel_size,
                                                 stride, n_fmaps, n_latent))
# loading the model
# Restore trained weights; returning each storage unchanged from
# map_location keeps every tensor on the CPU.
checkpoint = torch.load(state_dict_path,
                        map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)

# Create a dictionary to keep the hidden state of debris and clean images.
# Keys come from TYPENAMES; the empty placeholder lists are replaced below.
TYPENAMES = ('clean', 'debris')
hidden_dict = {htype: [] for htype in TYPENAMES}
# predicting for clean data
predictor_clean = Predictor(dataloader=clean_loader,
                            model=model,
                            state_dict=state_dict,
                            device=device,
                            batch_size=batch_size,
                            input_shape=input_shape,
                            output_shape=output_shape)
hidden_dict[TYPENAMES[0]] = predictor_clean.encode()

# predicting for debris
predictor_debris = Predictor(dataloader=dataLoader_debris,
                             model=model,
                             state_dict=state_dict,
                             device=device,
                             batch_size=batch_size,
                             input_shape=input_shape,
                             output_shape=output_shape)
# NOTE(review): the clean branch above calls encode() but this branch calls
# encodeList(); confirm both methods exist on Predictor and return the same
# kind of object, otherwise the two hidden_dict entries are not comparable.
hidden_dict[TYPENAMES[1]] = predictor_debris.encodeList()
# Concatenate individual batches into single torch tensors
# NOTE(review): the concatenation code itself is not part of this excerpt.
Exemple #5
0
def predict_bbox_from_json(bbox_idx, verbose=True):
    """Run the artifact classifier over one bounding box and write the output.

    The bounding box is selected by ``bbox_idx`` from
    ``datasources_predict_parallel.json`` next to this script (via
    ``WkwData.datasources_bbox_from_json`` with ``bbox_ext`` 1024^3 and
    datasource 0). The AE-encoder + 3-layer classifier checkpoint is loaded on
    ``device`` ('cpu'), and ``exp`` of the class-1 output is written as
    compressed float32 to the ``probs_sparse`` layer under
    ``output_wkw_root``.

    Args:
        bbox_idx: index of the bounding box within the datasource JSON.
        verbose: if True, print a timestamped start message and forward
            ``verbose`` to ``Predictor.predict``.
    """
    if verbose:
        print('(' + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +
              ') Starting Parallel Prediction ... bbox: {}'.format(bbox_idx))

    # Paths are resolved relative to this script, not the working directory.
    run_root = os.path.dirname(os.path.abspath(__file__))
    cache_HDD_root = os.path.join(run_root, '.cache/')
    datasources_json_path = os.path.join(run_root,
                                         'datasources_predict_parallel.json')
    state_dict_path = os.path.join(
        run_root,
        '../../training/ae_classify_v09_3layer_unfreeze_latent_debris_clean_transform_add_clean2_wiggle/.log/run_w_pr/epoch_700/model_state_dict'
    )
    device = 'cpu'

    output_wkw_root = '/tmpscratch/webknossos/Connectomics_Department/2018-11-13_scMS109_1to7199_v01_l4_06_24_fixed_mag8_artifact_pred'
    output_label = 'probs_sparse'

    batch_size = 128
    input_shape = (140, 140, 1)
    output_shape = (1, 1, 1)
    num_workers = 12

    # Architecture hyper-parameters; they must match the loaded checkpoint.
    kernel_size = 3
    stride = 1
    n_latent = 2048
    input_size = 140
    # NOTE(review): no feature-map count is passed to the encoder, so it uses
    # whatever its signature provides; confirm that matches the configuration
    # the checkpoint was trained with.
    model = AE_Encoder_Classifier(
        Encoder_4_sampling_bn_1px_deep_convonly_skip(input_size,
                                                     kernel_size,
                                                     stride,
                                                     n_latent=n_latent),
        Classifier3Layered(n_latent=n_latent))

    datasources = WkwData.datasources_bbox_from_json(
        datasources_json_path,
        bbox_ext=[1024, 1024, 1024],
        bbox_idx=bbox_idx,
        datasource_idx=0)
    # stride=(35, 35, 1) is presumably the step between sampled patches.
    dataset = WkwData(input_shape=input_shape,
                      target_shape=output_shape,
                      data_sources=datasources,
                      stride=(35, 35, 1),
                      cache_HDD=False,
                      cache_RAM=False,
                      cache_HDD_root=cache_HDD_root)

    prediction_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                    batch_size=batch_size,
                                                    num_workers=num_workers)

    # Map the checkpoint straight onto the prediction device. For 'cpu' this
    # is equivalent to the identity-storage map_location callable.
    checkpoint = torch.load(state_dict_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)

    def output_prob_fn(x):
        # exp of the class-1 channel at the single output position; the class-1
        # probability when the model emits log-probabilities.
        return np.exp(x[:, 1, 0, 0])

    # Values are written as raw float32, so no value encoding is applied.
    # (A uint8 alternative encoded (logit(x) + 16) * 256 / 32.)
    output_dtype = np.float32

    def output_dtype_fn(x):
        return x

    predictor = Predictor(model=model,
                          dataloader=prediction_loader,
                          output_prob_fn=output_prob_fn,
                          output_dtype_fn=output_dtype_fn,
                          output_dtype=output_dtype,
                          output_label=output_label,
                          output_wkw_root=output_wkw_root,
                          output_wkw_compress=True,
                          device=device,
                          interpolate=None)

    predictor.predict(verbose=verbose)
Exemple #6
0
def output_prob_fn(x):
    """Return exp of the class-1 slice ``x[:, 1]``.

    When ``x`` holds log-probabilities this is the class-1 probability for
    every sample (and spatial position, if any).
    """
    class_one = x[:, 1]
    return np.exp(class_one)


def prob_collate_fn(outputs):
    """Extract ``exp`` of the class-1 output at spatial position (0, 0).

    Args:
        outputs: array-like indexed as ``[batch, class, y, x]``
            (NOTE(review): layout inferred from the indexing below),
            presumably log-probabilities since the values are exponentiated.

    Returns:
        1-D array with one value per batch element.
    """
    # exp is element-wise, so slicing first yields identical values while
    # exponentiating only batch_size elements instead of the whole tensor.
    # It also returns a fresh array instead of a strided view that kept the
    # full exponentiated tensor alive.
    return np.exp(outputs[:, 1, 0, 0])


output_dtype = np.uint8


def output_dtype_fn(x):
    """Encode probabilities as uint8 codes via a linear map of logit(x).

    Logit values in [-16, 16] are mapped linearly onto [0, 256). The result
    is clipped to [0, 255] so that probabilities with logit >= ~15.9, and
    probabilities of exactly 0 or 1 (logit of -inf / +inf, which exp of a
    float32 log-probability can produce), stay inside the range of
    ``output_dtype``.
    NOTE(review): the cast to ``output_dtype`` happens inside DataWriter,
    which is not shown; clipping here keeps the values valid regardless of
    how that cast treats out-of-range input.
    """
    return np.clip((logit(x) + 16) * 256 / 32, 0, 255)


def output_dtype_fni(x):
    """Approximate inverse of ``output_dtype_fn``: uint8 codes -> probabilities."""
    return expit(x / 256 * 32 - 16)


datawriter_prob = DataWriter(dataloader=prediction_loader,
                             output_label='probs_ae_classify_09',
                             output_path=output_wkw_root,
                             output_collate_fn=prob_collate_fn,
                             output_write_dtype=output_dtype,
                             output_write_dtype_fn=output_dtype_fn)

datawriters = {'probs_wkw': datawriter_prob}

predictor = Predictor(model=model,
                      dataloader=prediction_loader,
                      datawriters=datawriters,
                      output_prob_fn=output_prob_fn,
                      device=device,
                      interpolate='linear')

predictor.predict()
print('done')