Example #1
batch_size = 1

if args.resume is not None:  # Load pretrained network params
    model.load_state_dict(torch.load(os.path.expanduser(args.resume)))

dataset_mean = (143.97594, )
dataset_std = (44.264744, )
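# Note: one-element tuples, since the input data has a single (grayscale) channel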

# Transformations to be applied to samples before feeding them to the network
common_transforms = [
    transforms.Normalize(mean=dataset_mean, std=dataset_std, inplace=True)
]
train_transform = transforms.Compose(common_transforms + [
    transforms.RandomCrop((128, 128)),  # Use smaller patches for training
    transforms.RandomFlip(),
    transforms.AdditiveGaussianNoise(prob=0.5, sigma=0.1)
])
valid_transform = transforms.Compose(common_transforms +
                                     [transforms.RandomCrop((144, 144))])
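
For orientation, here is a minimal sketch of how such a composed pipeline is applied. It assumes elektronn3-style transforms, i.e. callables that take and return an (input, target) pair of numpy arrays (unlike torchvision transforms, which operate on single images); the dummy shapes are illustrative only.

import numpy as np

inp = np.random.rand(1, 200, 200).astype(np.float32)  # (C, H, W) image patch
target = np.zeros((200, 200), dtype=np.int64)         # per-pixel class labels
inp, target = train_transform(inp, target)            # transforms run in order
# inp.shape is now (1, 128, 128) because of RandomCrop((128, 128))
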
# Specify data set
train_dataset = SimpleNeuroData2d(train=True,
                                  transform=train_transform,
                                  out_channels=out_channels)
valid_dataset = SimpleNeuroData2d(train=False,
                                  transform=valid_transform,
                                  out_channels=out_channels)
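
batch_size is defined at the top of the excerpt but not used within it; presumably the datasets are wrapped in standard PyTorch loaders further down. A plausible sketch:

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)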

# Set up optimization
optimizer = optim.Adam(model.parameters(),
                       weight_decay=0.5e-4,
                       lr=lr)
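
The excerpt does not show the loss or the training loop. As a hedged illustration of how the pieces fit together, one epoch of training with this optimizer could look as follows; the cross-entropy loss and the train_loader from the sketch above are assumptions, and model and data are assumed to live on the same device.

import torch.nn.functional as F

model.train()
for inp, target in train_loader:
    optimizer.zero_grad()                # clear gradients from the previous step
    out = model(inp)                     # forward pass
    loss = F.cross_entropy(out, target)  # assumed loss for dense classification
    loss.backward()                      # backpropagate
    optimizer.step()                     # apply the Adam update
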
Example #2
if args.resume is not None:  # Load pretrained network params
    try:  # Assume it's a state_dict for the model
        model.load_state_dict(torch.load(os.path.expanduser(args.resume)))
    except _pickle.UnpicklingError as exc:
        # Assume it's a complete saved ScriptModule
        model = torch.jit.load(os.path.expanduser(args.resume),
                               map_location=device)
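
The try/except distinguishes two checkpoint formats: a plain state_dict, which is loaded into an already constructed model, and a fully serialized ScriptModule, which replaces the model object entirely (the except clause relies on an import _pickle elsewhere in the script). For reference, this is how each kind is typically produced; the file names are illustrative:

torch.save(model.state_dict(), 'model_state.pth')   # plain state_dict checkpoint
torch.jit.script(model).save('model_scripted.pts')  # complete saved ScriptModule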

drop_func = transforms.DropIfTooMuchBG(bg_id=3, threshold=0.9)  # Discard patches whose target is more than 90% background (class 3)
# Transformations to be applied to samples before feeding them to the network
common_transforms = [
    #transforms.Normalize(mean=dataset_mean, std=dataset_std),
]
train_transform = transforms.Compose(common_transforms + [
    transforms.RandomGrayAugment(channels=[0], prob=0.3),
    transforms.RandomGammaCorrection(gamma_std=0.25, gamma_min=0.25, prob=0.3),
    transforms.AdditiveGaussianNoise(sigma=0.05, channels=[0], prob=0.1),
    transforms.RandomBlurring({'probability': 0.1}),
    drop_func
])
valid_transform = transforms.Compose(common_transforms + [])  # No additional augmentations for validation
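
Most of the augmentations above take a prob argument, i.e. they are applied stochastically per sample. The underlying pattern can be sketched with a hypothetical wrapper (RandomlyApply is not an elektronn3 class, purely an illustration):

import numpy as np

class RandomlyApply:
    """Apply `transform` to an (input, target) pair with probability
    `prob`; otherwise pass the pair through unchanged."""
    def __init__(self, transform, prob):
        self.transform = transform
        self.prob = prob

    def __call__(self, inp, target):
        if np.random.rand() < self.prob:
            return self.transform(inp, target)
        return inp, target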

# Specify data set
aniso_factor = 2  # Anisotropy factor in the z dimension; 2 means half resolution along z.
common_data_kwargs = {  # Common options for training and valid sets.
    'aniso_factor': aniso_factor,
    'patch_shape': (48, 144, 144),
    'num_classes': 6,
    # 'offset': (20, 46, 46),
    'target_discrete_ix': [3, 4, 5]
}
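
Collecting the shared options in a dict pays off when both datasets are constructed: the dict can be unpacked into each constructor so the two definitions stay in sync. A hypothetical continuation (PatchDataset is a placeholder name, not a real class):

train_dataset = PatchDataset(train=True, transform=train_transform,
                             **common_data_kwargs)
valid_dataset = PatchDataset(train=False, transform=valid_transform,
                             **common_data_kwargs)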

type_args = list(range(len(input_h5data)))