Exemplo n.º 1
0
def load_model(path, device, base_kernel_width=11):
    """Load a UDenoiseNet3D model and place it on the requested compute device.

    Args:
        path: the string ``'unet-3d'`` to load the bundled pretrained weights
            (``unet-3d-v0.2.4.sav``), or a path to a file written by
            ``torch.save`` holding either a whole model or a state dict.
        device: CUDA device index (``>= 0``), ``-1`` for CPU, or ``-2`` to wrap
            the model in ``nn.DataParallel`` over all visible CUDA devices.
        base_kernel_width: ``base_width`` for the ``UDenoiseNet3D`` built when
            ``path`` holds a bare state dict. The pretrained ``'unet-3d'``
            model always uses ``base_width=7``.

    Returns:
        ``(model, num_devices)``: the model in eval mode (moved to CUDA when
        CUDA is used) and the number of devices it runs on (1 unless
        ``device == -2`` with CUDA available).

    Exits the process with status 1 for an invalid device id/format and
    status 2 for any other error while configuring the CUDA device.
    """
    log = sys.stderr

    # load the model
    if path == 'unet-3d':  # load the pretrained unet model
        name = 'unet-3d-v0.2.4.sav'
        print('# loading pretrained model:', name, file=log)
        model = UDenoiseNet3D(base_width=7)

        import pkg_resources
        pkg = __name__
        resource = '../pretrained/denoise/' + name
        # Close the packaged resource stream once the parameters are read.
        with pkg_resources.resource_stream(pkg, resource) as f:
            state_dict = torch.load(f, map_location='cpu')
        model.load_state_dict(state_dict)
    else:
        # Load onto the CPU so checkpoints saved on a GPU also load on
        # CPU-only hosts and honor device == -1. The model is moved to CUDA
        # below when CUDA is requested.
        model = torch.load(path, map_location='cpu')
        # A state dict (OrderedDict or plain dict) needs a fresh network.
        if isinstance(model, dict):
            state = model
            model = UDenoiseNet3D(base_width=base_kernel_width)
            model.load_state_dict(state)
    model.eval()

    # set the device or devices
    d = device
    use_cuda = (d != -1) and torch.cuda.is_available()
    num_devices = 1
    if use_cuda:
        device_count = torch.cuda.device_count()
        try:
            if d >= 0:
                # Explicit check: an `assert` would be stripped under -O.
                if d >= device_count:
                    raise ValueError
                torch.cuda.set_device(d)
                print('# using CUDA device:', d, file=log)
            elif d == -2:
                print('# using all available CUDA devices:',
                      device_count,
                      file=log)
                num_devices = device_count
                model = nn.DataParallel(model)
            else:
                raise ValueError
        except (AssertionError, ValueError):
            print('ERROR: Invalid device id or format', file=log)
            sys.exit(1)
        except Exception:
            print(
                'ERROR: Something went wrong with setting the compute device',
                file=log)
            sys.exit(2)

    if use_cuda:
        model.cuda()

    return model, num_devices
Exemplo n.º 2
0
def load_model(path, base_kernel_width=11):
    """Load a UDenoiseNet3D model (pretrained alias or saved file) in eval mode.

    Args:
        path: one of the pretrained aliases ``'unet-3d'`` (same weights as
            ``'unet-3d-10a'``), ``'unet-3d-10a'`` or ``'unet-3d-20a'``; or a
            path to a file written by ``torch.save`` holding either a whole
            model or a state dict.
        base_kernel_width: ``base_width`` for the ``UDenoiseNet3D`` built when
            ``path`` holds a bare state dict. Pretrained models always use
            ``base_width=7``.

    Returns:
        The model in eval mode, loaded onto the CPU. Callers are responsible
        for moving it to a GPU.
    """
    log = sys.stderr

    # Pretrained alias -> bundled weight file (all built with base_width=7).
    pretrained_files = {
        'unet-3d': 'unet-3d-10a-v0.2.4.sav',
        'unet-3d-10a': 'unet-3d-10a-v0.2.4.sav',
        'unet-3d-20a': 'unet-3d-20a-v0.2.4.sav',
    }

    name = pretrained_files.get(path)
    if name is not None:
        print('# loading pretrained model:', name, file=log)
        model = UDenoiseNet3D(base_width=7)

        import pkg_resources
        pkg = __name__
        resource = '../pretrained/denoise/' + name
        # Close the packaged resource stream once the parameters are read.
        with pkg_resources.resource_stream(pkg, resource) as f:
            state_dict = torch.load(f, map_location='cpu')
        model.load_state_dict(state_dict)
    else:
        # Load onto the CPU so GPU-saved checkpoints also load on CPU-only
        # hosts. This matches the state-dict paths, which always yield a CPU
        # model.
        model = torch.load(path, map_location='cpu')
        # A state dict (OrderedDict or plain dict) needs a fresh network.
        if isinstance(model, dict):
            state = model
            model = UDenoiseNet3D(base_width=base_kernel_width)
            model.load_state_dict(state)
    model.eval()

    return model
Exemplo n.º 3
0
def train_model(even_path, odd_path, save_prefix, save_interval, device,
                cost_func='L2',
                weight_decay=0,
                learning_rate=0.001,
                optim='adagrad',
                momentum=0.8,
                minibatch_size=10,
                num_epochs=500,
                N_train=1000,
                N_test=200,
                tilesize=96,
                num_workers=1,
                ):
    """Train a fresh ``UDenoiseNet3D`` on tiles from paired even/odd inputs.

    Per-epoch train and test losses are written to stdout as tab-separated
    ``Epoch / Split / Error`` rows. Progress messages go to stderr.

    Args:
        even_path, odd_path: file or directory paths passed to
            ``TrainingDataset3D``. The process exits if either is missing.
        save_prefix: path prefix passed to ``save_model`` for checkpoints, or
            ``None`` to disable saving. Its directory is created if missing.
        save_interval: save a checkpoint every ``save_interval`` epochs. Only
            used when ``save_prefix`` is not ``None``.
        device: CUDA device index (``>= 0``), ``-1`` for CPU, or ``-2`` to use
            ``nn.DataParallel`` over all visible CUDA devices.
        cost_func: ``'L2'`` (MSE) or ``'L1'``. Anything else falls back to MSE
            with a warning.
        weight_decay, learning_rate: optimizer hyperparameters.
        optim: one of ``'sgd'``, ``'rmsprop'``, ``'adam'``, ``'adagrad'``.
        momentum: momentum for ``'sgd'`` (ignored by the other optimizers).
        minibatch_size: DataLoader batch size.
        num_epochs: number of train/test epochs.
        N_train, N_test: requested split sizes forwarded to
            ``TrainingDataset3D``. The actual sizes are read back from it.
        tilesize: tile size forwarded to ``TrainingDataset3D``. Must be >= 1;
            values below 10 only warn.
        num_workers: DataLoader worker count, capped at the CPU count.

    Returns:
        ``(model, num_devices)``: the trained model and the number of devices
        it runs on (1 unless ``device == -2`` with CUDA available).

    Raises:
        ValueError: if ``optim`` is not recognized.

    Exits the process with status 1 (invalid device id), 2 (other device
    setup error), 3 (missing input path) or 4 (tilesize < 1).
    """
    output = sys.stdout
    log = sys.stderr

    if save_prefix is not None:
        save_dir = os.path.dirname(save_prefix)
        # dirname() is '' for a bare prefix such as 'model': that means the
        # current directory, so there is nothing to create.
        if save_dir and not os.path.exists(save_dir):
            print('# creating save directory:', save_dir, file=log)
            os.makedirs(save_dir, exist_ok=True)

    start_time = time.time()
    now = datetime.datetime.now()
    print('# starting time:', now.strftime('%m/%d/%Y %Hh:%Mm:%Ss'), file=log)

    # initialize the model
    print('# initializing model...', file=log)
    model = UDenoiseNet3D()

    # set the device or devices
    d = device
    use_cuda = (d != -1) and torch.cuda.is_available()
    num_devices = 1
    if use_cuda:
        device_count = torch.cuda.device_count()
        try:
            if d >= 0:
                # Explicit check: an `assert` would be stripped under -O.
                if d >= device_count:
                    raise ValueError
                torch.cuda.set_device(d)
                print('# using CUDA device:', d, file=log)
            elif d == -2:
                print('# using all available CUDA devices:', device_count, file=log)
                model = nn.DataParallel(model)
                num_devices = device_count
            else:
                raise ValueError
        except (AssertionError, ValueError):
            print('ERROR: Invalid device id or format', file=log)
            sys.exit(1)
        except Exception:
            print('ERROR: Something went wrong with setting the compute device', file=log)
            sys.exit(2)

    if use_cuda:
        model.cuda()

    if cost_func == 'L2':
        cost_func = nn.MSELoss()
    elif cost_func == 'L1':
        cost_func = nn.L1Loss()
    else:
        # Keep the historical MSE fallback, but do not hide a typo silently.
        print('WARNING: unrecognized cost_func', repr(cost_func),
              '- using L2', file=log)
        cost_func = nn.MSELoss()

    params = [{'params': model.parameters(), 'weight_decay': weight_decay}]
    lr = learning_rate
    if optim == 'sgd':
        optim = torch.optim.SGD(params, lr=lr, momentum=momentum)
    elif optim == 'rmsprop':
        optim = torch.optim.RMSprop(params, lr=lr)
    elif optim == 'adam':
        optim = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999), eps=1e-8, amsgrad=True)
    elif optim == 'adagrad':
        optim = torch.optim.Adagrad(params, lr=lr)
    else:
        raise ValueError('Unrecognized optim: ' + str(optim))

    # Load the data
    print('# loading data...', file=log)
    if not (os.path.isdir(even_path) or os.path.isfile(even_path)):
        print('ERROR: Cannot find file or directory:', even_path, file=log)
        sys.exit(3)
    if not (os.path.isdir(odd_path) or os.path.isfile(odd_path)):
        print('ERROR: Cannot find file or directory:', odd_path, file=log)
        sys.exit(3)

    if tilesize < 1:
        print('ERROR: tilesize must be >0', file=log)
        sys.exit(4)
    if tilesize < 10:
        print('WARNING: small tilesize is not recommended', file=log)
    data = TrainingDataset3D(even_path, odd_path, tilesize, N_train, N_test)

    # The dataset decides the real split sizes; read them back per mode.
    N_train = len(data)
    data.set_mode('test')
    N_test = len(data)
    data.set_mode('train')
    num_workers = min(num_workers, mp.cpu_count())

    # One loader serves both splits; data.set_mode() switches which split it yields.
    iterator = torch.utils.data.DataLoader(data, batch_size=minibatch_size,
                                           num_workers=num_workers, shuffle=False)

    ## Begin model training
    print('# training model...', file=log)
    print('\t'.join(['Epoch', 'Split', 'Error']), file=output)

    for epoch in range(num_epochs):
        data.set_mode('train')
        epoch_loss_accum = train_epoch(iterator,
                                       model,
                                       cost_func,
                                       optim,
                                       epoch=epoch,
                                       num_epochs=num_epochs,
                                       N=N_train,
                                       use_cuda=use_cuda)

        line = '\t'.join([str(epoch + 1), 'train', str(epoch_loss_accum)])
        print(line, file=output)

        # evaluate on the test set
        data.set_mode('test')
        epoch_loss_accum = eval_model(iterator,
                                      model,
                                      cost_func,
                                      epoch=epoch,
                                      num_epochs=num_epochs,
                                      N=N_test,
                                      use_cuda=use_cuda)

        line = '\t'.join([str(epoch + 1), 'test', str(epoch_loss_accum)])
        print(line, file=output)

        ## save the models
        if save_prefix is not None and (epoch + 1) % save_interval == 0:
            # Checkpoints are written from the CPU, then the model is moved back.
            # NOTE(review): this leaves the model in eval mode; presumably
            # train_epoch re-enables train mode. Confirm.
            model.eval().cpu()
            save_model(model, epoch + 1, save_prefix)
            if use_cuda:
                model.cuda()

    print('# training completed!', file=log)

    end_time = time.time()
    now = datetime.datetime.now()
    print("# ending time:", now.strftime('%m/%d/%Y %Hh:%Mm:%Ss'), file=log)
    print("# total time:", time.strftime("%Hh:%Mm:%Ss", time.gmtime(end_time - start_time)), file=log)

    return model, num_devices