Пример #1
0
    n_classes = 9
    # Keyword arguments shared by the training and validation CelltypeViewsE3
    # datasets built further below.
    data_init_kwargs = dict(
        raw_only=False,
        nb_views=2,
        train_fraction=0.95,
        nb_views_renderinglocations=4,
        # view_key="4_large_fov" is an alternative setting, currently disabled.
        reduce_context=0,
        reduce_context_fact=1,
        ctgt_key="ctgt_v2",
        random_seed=0,
        binary_views=False,
        n_classes=n_classes,
        class_weights=[1] * n_classes,  # equal weight for every class
    )

    if args.resume is not None:  # Load pretrained network
        # Expand "~" once; both loading strategies below read the same file.
        # (The original called `s.path.expanduser`, a NameError typo for os.path.)
        resume_path = os.path.expanduser(args.resume)
        print('Resuming model from {}.'.format(resume_path))
        try:  # Assume it's a state_dict for the model
            # NOTE(security): torch.load unpickles the file -- only resume from
            # trusted checkpoints.
            # map_location mirrors the ScriptModule branch so weights saved on a
            # GPU can be restored onto the current `device`.
            model.load_state_dict(torch.load(resume_path, map_location=device))
        except _pickle.UnpicklingError:
            # Not a pickled state_dict: assume it's a complete saved ScriptModule
            model = torch.jit.load(resume_path, map_location=device)

    # Specify data set: one random-flip transform is shared by both splits;
    # the `train` flag presumably selects the split CelltypeViewsE3 serves.
    # NOTE(review): the flip augmentation is applied to validation data too --
    # confirm this is intended.
    transform = transforms.Compose([RandomFlip(ndim_spatial=2)])
    train_dataset, valid_dataset = (
        CelltypeViewsE3(train=is_train, transform=transform, **data_init_kwargs)
        for is_train in (True, False)
    )

    # Set up optimization: plain SGD (no momentum set here) with weight decay.
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.5e-4)
    # Learning-rate schedule from the project-defined SGDR helper (presumably
    # SGD with warm restarts; meaning of 20000 and 3 -- TODO confirm). A StepLR
    # schedule was used here previously.
    lr_sched = SGDR(optimizer, 20000, 3)
    schedulers = dict(lr=lr_sched)
    # All these metrics assume a binary classification problem. If you have
    #  non-binary targets, remember to adapt the metrics!
    # NOTE(review): n_classes is 9 in this setup, so confirm the metrics that
    # follow handle multi-class targets.
Пример #2
0
    # Base learning rate, StepLR step size and decay factor, per-device batch size.
    lr, lr_stepsize, lr_dec, batch_size = 0.0048, 500, 0.995, 5

    model = get_model()
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print("Let's use", n_gpus, "GPUs!")
        # DataParallel splits each input along dim 0 across the devices
        # (e.g. [20, ...] -> [10, ...], [10, ...] on 2 GPUs), so the global
        # batch is scaled to keep `batch_size` samples per replica.
        batch_size *= n_gpus
        model = nn.DataParallel(model)
    model.to(device)

    # Specify data set: both splits share the random-flip transform and read
    # from the axon-segmentation ground-truth directory.
    transform = transforms.Compose([RandomFlip(ndim_spatial=2)])
    axonseg_dir = global_params.gt_path_axonseg
    train_dataset = MultiviewData(train=True, transform=transform, base_dir=axonseg_dir)
    valid_dataset = MultiviewData(train=False, transform=transform, base_dir=axonseg_dir)

    # Set up optimization: Adam in its AMSGrad variant with weight decay; the
    # learning rate is multiplied by `lr_dec` every `lr_stepsize` scheduler steps.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.5e-4, amsgrad=True)
    lr_sched = optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)