Exemplo n.º 1
0
def get_model():
    """Build the UNet for this script and move it to ``device``.

    ``device`` is not a parameter or local of this function; it is looked
    up in the enclosing (module) scope when the function is called.

    Returns:
        The result of calling ``.to(device)`` on the freshly built ``UNet``.
    """
    net = UNet(
        n_blocks=3,
        start_filts=32,
        planar_blocks=(1,),
        activation='relu',
        batch_norm=True,
    )
    return net.to(device)
Exemplo n.º 2
0
def benchmark(float16, inp_shape):
    """Time UNet forward passes on CUDA for one dtype and input shape.

    A UNet with ``dim = len(inp_shape) - 2`` is built on the CUDA device,
    one untimed warm-up pass is run, and then ``n`` forward passes on
    freshly generated random inputs are timed. Per-run times, the average
    time per run and the throughput in MVox/s are printed.

    Each timed run includes the host-to-device copy of the input, the
    forward pass, the device-to-host copy of the output and a CUDA
    synchronization.

    Args:
        float16: If truthy, model and inputs use ``torch.float16``,
            otherwise ``torch.float32``.
        inp_shape: Shape of the input tensors. The number of spatial
            dimensions is taken to be ``len(inp_shape) - 2`` (presumably
            the first two entries are batch and channel -- confirm against
            the UNet implementation).

    Returns:
        None. All results are printed to stdout.
    """
    torch.cuda.empty_cache()
    dim = len(inp_shape) - 2
    device = torch.device('cuda')
    n = 10  # Number of measured repetitions
    dtype = torch.float16 if float16 else torch.float32
    print(f'dtype: {dtype}')
    print(f'dim: {dim}')
    print(f'inp_shape: {inp_shape}')
    _float16_str = 'float16' if float16 else 'float32'
    experiment_name = f'{dim}d_{_float16_str}'
    print(f'Experiment name: {experiment_name}')

    model = UNet(
        dim=dim,
        out_channels=2,
        n_blocks=4,
        start_filts=32,
        activation='relu',
        normalization='batch',
        # conv_mode='valid',
    ).to(device, dtype)
    # NOTE(review): `jit` is neither a parameter nor a local of this
    # function; presumably a module-level flag -- confirm it is defined.
    if jit:
        model = torch.jit.script(model)
    # NOTE(review): neither model.eval() nor torch.no_grad() is used in this
    # function. If pure inference speed is the target, consider enabling
    # both -- left unchanged here so the measured numbers stay comparable.

    x_warmup = torch.randn(*inp_shape, dtype=dtype)

    print(' == Warming up...')
    r = model(x_warmup.to(device)).cpu()
    torch.cuda.synchronize()
    del r

    torch.cuda.empty_cache()

    # Generate random inputs of same shape for measurements
    xm = [torch.randn(*inp_shape, dtype=dtype) for _ in range(n)]

    print(' == Start timing inference speed...')
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump if the system clock
    # is adjusted while the benchmark is running.
    start_total = time.perf_counter()

    for x in xm:
        startm = time.perf_counter()
        model(x.to(device)).cpu()
        # Wait for all queued CUDA work before reading the clock.
        torch.cuda.synchronize()
        dt = time.perf_counter() - startm
        print(f'Inference run time (sec): {dt:.2f}')

    dt_total = time.perf_counter() - start_total
    dt_total_per_run = dt_total / n
    # Throughput counts every element of the input tensor (product of all
    # entries of inp_shape) per second of average run time.
    throughput = np.prod(inp_shape) / dt_total_per_run
    mvoxs = throughput / 1e6
    print(f'Average inference time ({n} runs) (sec): {dt_total_per_run:.2f}')
    print(f'Average MVox/s: {mvoxs:.2f}')
    print('\n\n')
Exemplo n.º 3
0
def get_model():
    """Construct a 2D UNet with 4 input channels and 4 output channels.

    The model is returned exactly as constructed; this function does not
    move it to any device.

    Returns:
        The newly created ``UNet`` instance.
    """
    # Commented-out alternative based on VGGNet/FCNs:
    #   vgg_model = VGGNet(model='vgg13', requires_grad=True, in_channels=4)
    #   model = FCNs(base_net=vgg_model, n_class=4)
    unet_kwargs = dict(
        dim=2,
        in_channels=4,
        out_channels=4,
        n_blocks=5,
        start_filts=32,
        up_mode='resize',
        merge_mode='concat',
        planar_blocks=(),
        activation='relu',
        batch_norm=True,
    )
    return UNet(**unet_kwargs)
Exemplo n.º 4
0
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f'Running on device: {device}')

# Selected hyperparameters can be collected in this dict for logging to
# tensorboard, for example:
# hparams = {'n_blocks': 4, 'start_filts': 32, 'planar_blocks': (0,)}
hparams = {}

out_channels = 2
# Build the network first, then transfer it to the selected device.
model = UNet(
    out_channels=out_channels,
    n_blocks=4,
    start_filts=32,
    planar_blocks=(0,),
    activation='relu',
    normalization='batch',
    # conv_mode='valid',
    # full_norm=False,  # Uncomment to restore old sparse normalization scheme
    # up_mode='resizeconv_nearest',  # Enable to avoid checkerboard artifacts
)
model = model.to(device)
# Example for a model-compatible input.
example_input = torch.ones(1, 1, 32, 64, 64)

# True for every value of args.jit except 'disabled'.
enable_save_trace = args.jit != 'disabled'
if args.jit == 'onsave':
    # Make sure that tracing works
    tracedmodel = torch.jit.trace(model, example_input.to(device))
elif args.jit == 'train':
    if getattr(model, 'checkpointing', False):
        raise NotImplementedError(
Exemplo n.º 5
0
# cuDNN flags: enable benchmark mode (improves overall performance in *most*
# cases) unless `deterministic` is truthy, in which case the deterministic
# flag is set instead.
if not deterministic:
    torch.backends.cudnn.benchmark = True
else:
    torch.backends.cudnn.deterministic = True

# Use CUDA only when it was not disabled via args and is actually available.
device = torch.device(
    'cuda' if not args.disable_cuda and torch.cuda.is_available() else 'cpu'
)

print(f'Running on device: {device}')

out_channels = 2
# Build the 2D network, then transfer it to the selected device.
model = UNet(
    dim=2,
    out_channels=out_channels,
    n_blocks=4,
    start_filts=32,
    activation='relu',
    normalization='batch',
)
model = model.to(device)

# USER PATHS
# Root directory with '~' expanded to the user's home directory.
# NOTE(review): presumably where training output is written -- confirm
# against the code that consumes save_root.
save_root = os.path.expanduser('~/e3training/')

# Training settings. NOTE(review): the names suggest a learning rate with a
# step-wise decay schedule; how they are consumed is not visible here.
max_steps = args.max_steps
lr = 0.0004
lr_stepsize = 1000
lr_dec = 0.995
batch_size = 1

if args.resume is not None:  # Load pretrained network params
    # map_location remaps the stored tensors onto the device selected above,
    # so a checkpoint that was saved on a GPU can still be loaded when this
    # script falls back to the CPU.
    # NOTE(review): torch.load unpickles the file; only resume from
    # checkpoints that come from a trusted source.
    state_dict = torch.load(
        os.path.expanduser(args.resume),
        map_location=device,
    )
    model.load_state_dict(state_dict)
Exemplo n.º 6
0
from elektronn3.training import SWA
from elektronn3.modules import DiceLoss, CombinedLoss
from elektronn3.models.unet import UNet


# Use CUDA only when it was not disabled via args and is actually available.
device = torch.device(
    'cuda' if not args.disable_cuda and torch.cuda.is_available() else 'cpu'
)
logger.info(f'Running on device: {device}')

# Build the network first, then transfer it to the selected device.
model = UNet(
    n_blocks=4,
    start_filts=32,
    planar_blocks=(0,),
    activation='relu',
    batch_norm=True,
    # conv_mode='valid',
    # up_mode='resizeconv_nearest',  # Enable to avoid checkerboard artifacts
    adaptive=True,  # Experimental. Disable if results look weird.
)
model = model.to(device)
# Example for a model-compatible input.
example_input = torch.ones(1, 1, 32, 64, 64)

# True for every value of args.jit except 'disabled'.
enable_save_trace = args.jit != 'disabled'
if args.jit == 'onsave':
    # Make sure that tracing works
    tracedmodel = torch.jit.trace(model, example_input.to(device))
    model = tracedmodel
elif args.jit == 'train':
    if getattr(model, 'checkpointing', False):
        raise NotImplementedError(
Exemplo n.º 7
0
# Don't move this stuff, it needs to be run this early to work
import elektronn3
elektronn3.select_mpl_backend('Agg')
import numpy as np
from elektronn3.data import PatchCreator, transforms, utils, get_preview_batch
from elektronn3.training import Trainer, Backup, metrics, Padam, handlers
from elektronn3.models.unet import UNet
from elektronn3.modules.loss import DiceLoss, DiceLossFancy

# cuDNN benchmark mode: improves overall performance in *most* cases.
torch.backends.cudnn.benchmark = True

# Build the network first, then transfer it to `device`.
model = UNet(
    in_channels=1,
    out_channels=29,  # vec. field 0-2, syntype 3-6, celltype 7-17 and 18-28
    n_blocks=4,
    start_filts=28,
    planar_blocks=(0,),
    activation='relu',
    batch_norm=True,
    adaptive=False,  # Experimental. Disable if results look weird.
)
model = model.to(device)

# Example for a model-compatible input.
example_input = torch.randn(1, 1, 48, 144, 144)

# True for every value of args.jit except 'disabled'.
enable_save_trace = args.jit != 'disabled'
if args.jit == 'onsave':
    # Make sure that tracing works
    tracedmodel = torch.jit.trace(model, example_input.to(device))
elif args.jit == 'train':
    if getattr(model, 'checkpointing', False):
        raise NotImplementedError(
Exemplo n.º 8
0
#     n_blocks=5,
#     start_filts=64,
#     planar_blocks=(1,2),
#     activation='relu',
#     batch_norm=True,
#     valid_blocks = (0,1,2),
#     adaptive=False
# ).to(device)
# offset = model.offset

# Build the network first, then transfer it to `device`.
model = UNet(
    in_channels=1,
    out_channels=4,
    n_blocks=4,
    start_filts=28,
    planar_blocks=(0,),
    activation='relu',
    batch_norm=False,
    # conv_mode='valid',
    # up_mode='resizeconv_nearest',  # Enable to avoid checkerboard artifacts
    adaptive=True,  # Experimental. Disable if results look weird.
)
model = model.to(device)

# Example for a model-compatible input.
example_input = torch.randn(1, 1, 40, 144, 144)

# True for every value of args.jit except 'disabled'.
enable_save_trace = args.jit != 'disabled'
if args.jit == 'onsave':
    # Make sure that tracing works
    tracedmodel = torch.jit.trace(model, example_input.to(device))
elif args.jit == 'train':
    if getattr(model, 'checkpointing', False):
Exemplo n.º 9
0
# cuDNN flags: enable benchmark mode (improves overall performance in *most*
# cases) unless `deterministic` is truthy, in which case the deterministic
# flag is set instead.
if not deterministic:
    torch.backends.cudnn.benchmark = True
else:
    torch.backends.cudnn.deterministic = True

# Use CUDA only when it was not disabled via args and is actually available.
device = torch.device(
    'cuda' if not args.disable_cuda and torch.cuda.is_available() else 'cpu'
)

print(f'Running on device: {device}')

# Build the 2D network, then transfer it to the selected device.
model = UNet(
    dim=2,
    n_blocks=4,
    start_filts=32,
    activation='relu',
    batch_norm=True,
)
model = model.to(device)
# Unless tracing was disabled via args, replace the model with a traced
# version produced from a random example input created on `device`.
if not args.disable_trace:
    x = torch.randn(1, 1, 64, 64, device=device)
    model = torch.jit.trace(model, x)


# USER PATHS
# Root directory with '~' expanded to the user's home directory.
save_root = os.path.expanduser('~/e3training/')

# Training settings. NOTE(review): the names suggest a learning rate with a
# step-wise decay schedule; how they are consumed is not visible here.
max_steps = args.max_steps
lr = 4e-4
lr_stepsize = 1_000
lr_dec = 0.995