Code Example #1
lr_dec = 0.995
batch_size = 1

if args.resume is not None:  # Load pretrained network params
    model.load_state_dict(torch.load(os.path.expanduser(args.resume)))
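The checkpoint expected by args.resume here is the counterpart of a plain state_dict save; a minimal sketch (the path is illustrative):

# Save only the parameters so they can later be restored with load_state_dict():
torch.save(model.state_dict(), os.path.expanduser('~/model_state.pth'))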

dataset_mean = (143.97594, )
dataset_std = (44.264744, )

# Transformations to be applied to samples before feeding them to the network
common_transforms = [
    transforms.Normalize(mean=dataset_mean, std=dataset_std, inplace=True)
]
train_transform = transforms.Compose(common_transforms + [
    transforms.RandomCrop((128, 128)),  # Use smaller patches for training
    transforms.RandomFlip(),
    transforms.AdditiveGaussianNoise(prob=0.5, sigma=0.1)
])
valid_transform = transforms.Compose(common_transforms +
                                     [transforms.RandomCrop((144, 144))])
# Specify data set
train_dataset = SimpleNeuroData2d(train=True,
                                  transform=train_transform,
                                  out_channels=out_channels)
valid_dataset = SimpleNeuroData2d(train=False,
                                  transform=valid_transform,
                                  out_channels=out_channels)
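
Before training, datasets like these are typically wrapped in PyTorch DataLoaders; a minimal sketch using the batch_size from the top of this snippet (num_workers is an arbitrary choice):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=1)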

# Set up optimization
optimizer = optim.Adam(model.parameters(),
                       weight_decay=0.5e-4,
                       lr=lr,
                       amsgrad=True)
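
A side note: lr_dec = 0.995 from the top of this snippet is presumably the decay factor (gamma) of a step scheduler, as in Code Example #3 below; a minimal sketch, with lr_stepsize assumed to be defined outside this excerpt:

lr_stepsize = 1000  # assumption; the real value is set elsewhere in the original script
lr_sched = optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)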
Code Example #2
            if lr_sched_state_dict is None:
                logger.warning('lr_sched_state_dict not found.')
        elif isinstance(state, nn.Module):
            logger.warning(_warning_str)
            model = state
        else:
            raise ValueError(f"Can't load {pretrained}.")

# Transformations to be applied to samples before feeding them to the network
common_transforms = [
    transforms.SqueezeTarget(dim=0),  # Workaround for neuro_data_cdhw
    transforms.Normalize(mean=dataset_mean, std=dataset_std)
]
train_transform = transforms.Compose(common_transforms + [
    # transforms.RandomRotate2d(prob=0.9),
    # transforms.RandomGrayAugment(channels=[0], prob=0.3),
    # transforms.RandomGammaCorrection(gamma_std=0.25, gamma_min=0.25, prob=0.3),
    # transforms.AdditiveGaussianNoise(sigma=0.1, channels=[0], prob=0.3),
])
valid_transform = transforms.Compose(common_transforms + [])

# Specify data set
aniso_factor = 2  # Anisotropy in the z dimension; 2 means half resolution along z.
common_data_kwargs = {  # Common options for training and valid sets.
    'aniso_factor': aniso_factor,
    'patch_shape': (44, 88, 88),
    # 'offset': (8, 20, 20),
    # 'in_memory': True  # Uncomment to avoid disk I/O (if you have enough host memory for the data)
}
train_dataset = PatchCreator(input_sources=[
    input_h5data[i] for i in range(len(input_h5data)) if i not in valid_indices
],
Code Example #3
File: simple2d.py  Project: xiaoliang008/elektronn3
    batch_size = 1

    model = get_model()

    if args.resume is not None:  # Load pretrained network params
        model.load_state_dict(torch.load(os.path.expanduser(args.resume)))

    dataset_mean = (143.97594,)
    dataset_std = (44.264744,)

    # Transformations to be applied to samples before feeding them to the network
    common_transforms = [
        transforms.Normalize(mean=dataset_mean, std=dataset_std)
    ]
    train_transform = transforms.Compose(common_transforms + [
        transforms.RandomCrop((128, 128))  # Use smaller patches for training
    ])
    valid_transform = transforms.Compose(common_transforms + [])

    # Specify data set
    train_dataset = SimpleNeuroData2d(train=True, transform=train_transform,
                                      num_classes=2)
    valid_dataset = SimpleNeuroData2d(train=False, transform=valid_transform,
                                      num_classes=2)

    # Set up optimization
    optimizer = optim.Adam(
        model.parameters(),
        weight_decay=0.5e-4,
        lr=lr,
        amsgrad=True
    )
    lr_sched = optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
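
For reference, outside of elektronn3's Trainer the optimizer and scheduler set up here would be driven per step like this (train_loader and criterion are assumptions, not part of the snippet):

for inp, target in train_loader:
    optimizer.zero_grad()
    out = model(inp)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    lr_sched.step()  # stepping per iteration, so lr_stepsize counts iterations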
Code Example #4
File: cnn_celltype_e3.py  Project: russell0/SyConn
    n_classes = 9
    data_init_kwargs = {"raw_only": False, "nb_views": 2, 'train_fraction': 0.95,
                        'nb_views_renderinglocations': 4, #'view_key': "4_large_fov",
                        "reduce_context": 0, "reduce_context_fact": 1, 'ctgt_key': "ctgt_v2", 'random_seed': 0,
                        "binary_views": False, "n_classes": n_classes, 'class_weights': [1] * n_classes}

    if args.resume is not None:  # Load pretrained network
        print('Resuming model from {}.'.format(os.path.expanduser(args.resume)))
        try:  # Assume it's a state_dict for the model
            model.load_state_dict(torch.load(os.path.expanduser(args.resume)))
        except _pickle.UnpicklingError as exc:
            # Assume it's a complete saved ScriptModule
            model = torch.jit.load(os.path.expanduser(args.resume), map_location=device)

    # Specify data set
    transform = transforms.Compose([RandomFlip(ndim_spatial=2), ])
    train_dataset = CelltypeViewsE3(train=True, transform=transform, **data_init_kwargs)
    valid_dataset = CelltypeViewsE3(train=False, transform=transform, **data_init_kwargs)

    # Set up optimization
    optimizer = optim.SGD(
        model.parameters(),
        weight_decay=0.5e-4,
        lr=lr,
        # amsgrad=True
    )
    # lr_sched = optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
    lr_sched = SGDR(optimizer, 20000, 3)
    schedulers = {'lr': lr_sched}
    # All these metrics assume a binary classification problem. If you have
    #  non-binary targets, remember to adapt the metrics!
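
In the elektronn3 example scripts this comment is followed by a metrics dict roughly like the sketch below (names taken from elektronn3.training.metrics; treat them as assumptions and verify against your version):

from elektronn3.training import metrics

valid_metrics = {
    'val_accuracy': metrics.bin_accuracy,
    'val_precision': metrics.bin_precision,
    'val_recall': metrics.bin_recall,
    'val_DSC': metrics.bin_dice_coefficient,
    'val_IoU': metrics.bin_iou,
}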
Code Example #5
if args.resume is not None:  # Load pretrained network
    try:  # Assume it's a state_dict for the model
        model.load_state_dict(torch.load(os.path.expanduser(args.resume)))
    except _pickle.UnpicklingError as exc:
        # Assume it's a complete saved ScriptModule
        model = torch.jit.load(os.path.expanduser(args.resume),
                               map_location=device)

drop_func = transforms.DropIfTooMuchBG(bg_id=3, threshold=0.9)
# Transformations to be applied to samples before feeding them to the network
common_transforms = [
    #transforms.Normalize(mean=dataset_mean, std=dataset_std),
]
train_transform = transforms.Compose(common_transforms + [
    transforms.RandomGrayAugment(channels=[0], prob=0.3),
    transforms.RandomGammaCorrection(gamma_std=0.25, gamma_min=0.25, prob=0.3),
    transforms.AdditiveGaussianNoise(sigma=0.05, channels=[0], prob=0.1),
    transforms.RandomBlurring({'probability': 0.1}),
    drop_func
])
valid_transform = transforms.Compose(common_transforms + [])
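
To make the pipeline concrete: elektronn3 transforms are callables on (input, target) array pairs, so the composed transform can be exercised on dummy data. A sketch with shapes chosen to match patch_shape below (the calling convention is an assumption to check against your elektronn3 version):

import numpy as np

inp = np.random.rand(1, 48, 144, 144).astype(np.float32)  # (C, D, H, W) raw patch
target = np.random.randint(0, 6, (48, 144, 144), dtype=np.int64)  # dense class labels
inp, target = train_transform(inp, target)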

# Specify data set
aniso_factor = 2  # Anisotropy in the z dimension; 2 means half resolution along z.
common_data_kwargs = {  # Common options for training and valid sets.
    'aniso_factor': aniso_factor,
    'patch_shape': (48, 144, 144),
    'num_classes': 6,
    # 'offset': (20, 46, 46),
    'target_discrete_ix': [3, 4, 5]
}

type_args = list(range(len(input_h5data)))
Code Example #6
    if args.resume is not None:  # Load pretrained network params
        model.load_state_dict(torch.load(os.path.expanduser(args.resume)))

    # These statistics are computed from the training dataset.
    # Remember to re-compute and change them when switching the dataset.
    dataset_mean = (155.291411,)
    dataset_std = (42.599973,)

    # Transformations to be applied to samples before feeding them to the network
    common_transforms = [
        transforms.SqueezeTarget(dim=0),  # Workaround for neuro_data_cdhw
        transforms.Normalize(mean=dataset_mean, std=dataset_std)
    ]
    train_transform = transforms.Compose(common_transforms + [
        # transforms.RandomGrayAugment(channels=[0], prob=0.3),
        # transforms.AdditiveGaussianNoise(sigma=0.1, channels=[0], prob=0.3),
        # transforms.RandomBlurring({'probability': 0.5})
    ])
    valid_transform = transforms.Compose(common_transforms + [])

    # Specify data set
    common_data_kwargs = {  # Common options for training and valid sets.
        'aniso_factor': 2,
        'patch_shape': (48, 96, 96),
        'classes': [0, 1],
    }
    train_dataset = PatchCreator(
        input_h5data=input_h5data[:2],
        target_h5data=target_h5data[:2],
        train=True,
        epoch_size=args.epoch_size,
Code Example #7
                logger.warning('optimizer_state_dict not found.')
            if lr_sched_state_dict is None:
                logger.warning('lr_sched_state_dict not found.')
        elif isinstance(state, nn.Module):
            logger.warning(_warning_str)
            model = state
        else:
            raise ValueError(f"Can't load {pretrained}.")

# Transformations to be applied to samples before feeding them to the network
common_transforms = [
    transforms.SqueezeTarget(dim=0),
]
train_transform = transforms.Compose(common_transforms + [
    transforms.RandomFlip(ndim_spatial=3),
    transforms.RandomGrayAugment(channels=[0], prob=0.3),
    transforms.RandomGammaCorrection(gamma_std=0.25, gamma_min=0.25, prob=0.3),
    transforms.AdditiveGaussianNoise(sigma=0.1, channels=[0], prob=0.3),
])
valid_transform = transforms.Compose(common_transforms + [])

# Specify data set
aniso_factor = 2  # Anisotropy in the z dimension; 2 means half resolution along z.
common_data_kwargs = {  # Common options for training and valid sets.
    'aniso_factor': aniso_factor,
    'patch_shape': (48, 96, 96),
    # 'offset': (8, 20, 20),
    'num_classes': 2,
    # 'in_memory': True  # Uncomment to avoid disk I/O (if you have enough host memory for the data)
}
train_dataset = PatchCreator(input_h5data=[
    input_h5data[i] for i in range(len(input_h5data)) if i not in valid_indices