Esempio n. 1
0
def load_model(model_path):
    """Rebuild a CPM_PAF network from a checkpoint and put it in eval mode.

    The name of the directory containing ``model_path`` drives two choices:
    it is handed to ``get_preeval_func`` and it is searched for the
    substrings ``'vgg19'`` / ``'resnet50'`` to pick a pretrained backbone
    (no backbone when neither substring is present).

    Args:
        model_path: path-like object exposing ``.parent.name`` (e.g. a
            ``pathlib.Path``) that points at a checkpoint containing a
            ``'state_dict'`` entry.

    Returns:
        A ``(model, preeval_func)`` tuple.
    """
    run_name = model_path.parent.name
    preeval_func = get_preeval_func(run_name)

    # 'vgg19' takes precedence over 'resnet50' when both substrings occur.
    backbone = None
    for arch in ('vgg19', 'resnet50'):
        if arch in run_name:
            backbone = PretrainedBackBone(arch)
            break

    model = CPM_PAF(
        n_segments=25,
        n_affinity_maps=21,
        same_output_size=True,
        backbone=backbone,
        is_PAF=False,
    )

    # Weights are always mapped to CPU; the caller decides on any device move.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    return model, preeval_func
Esempio n. 2
0
    # Checkpoint location under the user's home directory. `bn` is defined
    # outside this excerpt.
    model_path = Path.home(
    ) / 'workspace/WormData/worm-poses/results/' / bn / 'checkpoint.pth.tar'

    # Map counts passed to the CPM_PAF constructor below.
    # NOTE(review): presumably these must match the values the checkpoint was
    # trained with — confirm.
    n_segments = 25
    n_affinity_maps = 20

    # Use CUDA device `cuda_id` when CUDA is available, otherwise the CPU.
    cuda_id = 0
    if torch.cuda.is_available():
        print("THIS IS CUDA!!!!")
        dev_str = "cuda:" + str(cuda_id)
    else:
        dev_str = 'cpu'
    device = torch.device(dev_str)

    model = CPM_PAF(n_segments=n_segments,
                    n_affinity_maps=n_affinity_maps,
                    same_output_size=True)

    # The checkpoint is deserialized onto the CPU first; the model is moved to
    # `device` only after its weights are loaded and it is in eval mode.
    state = torch.load(model_path, map_location='cpu')
    #%%
    model.load_state_dict(state['state_dict'])
    model.eval()
    model = model.to(device)
    #%%
    # Hard-coded, machine-specific directory holding the input movies.
    root_dir = Path(
        '/Users/avelinojaver/OneDrive - Nexus365/worms/Bertie_movies')
    #fnames = list(root_dir.glob('*.hdf5'))

    # Only this single file name is matched; the commented-out lines around it
    # are alternative inputs.
    fnames = list(root_dir.glob('CX11314_Ch1_04072017_103259.hdf5'))

    #fnames = ['/Users/avelinojaver/Downloads/recording61.2r_X1.hdf5']
Esempio n. 3
0
def train_PAF(data_type='v1',
              model_name='PAF+CPM',
              loss_type='mse',
              cuda_id=0,
              log_dir_root=log_dir_root_dflt,
              batch_size=16,
              num_workers=1,
              roi_size=96,
              lr=1e-4,
              weight_decay=0.0,
              is_fixed_width=False,
              **argkws):
    """Build the data flows, model, loss and optimizer, then start training.

    Args:
        data_type: key into ``data_types_dflts``; selects the data root
            directory and the flow arguments. Also used as the log
            sub-directory name and as the prefix of the run basename.
        model_name: a name containing ``'CPM'`` builds a ``CPM_PAF`` model;
            ``'openpose'`` builds an ``OpenPoseCPM`` model. If the name
            contains ``'vgg19'`` or ``'resnet50'`` a matching
            non-pretrained backbone is created. Affinity maps are enabled
            when the name contains ``'PAF'`` or equals ``'openpose'``.
        loss_type: ``'mse'`` or ``'maxlikelihood'``. With
            ``'maxlikelihood'`` the flows receive ``width2sigma=-1``.
        cuda_id: forwarded to ``get_device``.
        log_dir_root: root under which ``data_type`` is appended to form the
            log directory.
        batch_size: forwarded to ``train_skeleton_maps`` and recorded in the
            run basename.
        num_workers: forwarded to ``train_skeleton_maps``.
        roi_size: forwarded to both data flows.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        is_fixed_width: forwarded to both data flows; adds ``-fixW`` to the
            run basename when true.
        **argkws: extra keyword arguments forwarded to
            ``train_skeleton_maps``.

    Raises:
        ValueError: if ``model_name`` or ``loss_type`` is not supported.
    """
    log_dir = log_dir_root / data_type

    dflts = data_types_dflts[data_type]
    root_dir = dflts['root_dir']

    # Work on a copy: the defaults live in a shared module-level mapping, and
    # overriding 'width2sigma' in place would leak into every later call.
    flow_args = dict(dflts['flow_args'])
    if loss_type == 'maxlikelihood':
        flow_args['width2sigma'] = -1

    is_PAF = ('PAF' in model_name) or (model_name == 'openpose')

    train_data = read_data_files(root_dir=root_dir, set2read='train')
    val_data = read_data_files(root_dir=root_dir, set2read='validation')

    flow_train = SkelMapsRandomFlow(data=train_data,
                                    roi_size=roi_size,
                                    epoch_size=23040,
                                    return_affinity_maps=is_PAF,
                                    is_fixed_width=is_fixed_width,
                                    **flow_args)

    flow_val = SkelMapsSimpleFlow(data=val_data,
                                  roi_size=roi_size,
                                  return_raw_skels=True,
                                  width2sigma=flow_args['width2sigma'],
                                  return_affinity_maps=is_PAF,
                                  is_fixed_width=is_fixed_width)

    if 'vgg19' in model_name:
        backbone = PretrainedBackBone('vgg19', pretrained=False)
    elif 'resnet50' in model_name:
        backbone = PretrainedBackBone('resnet50', pretrained=False)
    else:
        backbone = None

    # The output channel counts come from the training flow so that the model
    # always matches the targets it is trained against.
    if 'CPM' in model_name:
        model = CPM_PAF(n_segments=flow_train.n_skel_maps_out,
                        n_affinity_maps=flow_train.n_affinity_maps_out,
                        same_output_size=True,
                        backbone=backbone,
                        is_PAF=is_PAF)
    elif model_name == 'openpose':
        model = OpenPoseCPM(
            n_segments=flow_train.n_skel_maps_out,
            n_affinity_maps=flow_train.n_affinity_maps_out,
        )
    else:
        raise ValueError(f'Not implemented {model_name}')

    if model_name == 'openpose':
        criterion_func = OpenPoseCPMLoss
    elif is_PAF:
        criterion_func = CPM_PAF_Loss
    else:
        criterion_func = CPM_Loss

    if loss_type == 'mse':
        criterion = criterion_func()
    elif loss_type == 'maxlikelihood':
        criterion = criterion_func(is_maxlikelihood=True)
    else:
        # Report the value that is actually unsupported (the loss type).
        raise ValueError(f'Not implemented {loss_type}')

    preeval_func = get_preeval_func(loss_type)

    device = get_device(cuda_id)

    # Only parameters that require gradients are handed to the optimizer.
    model_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(model_params,
                                 lr=lr,
                                 weight_decay=weight_decay)

    now = datetime.datetime.now()
    date_str = now.strftime('%Y%m%d_%H%M%S')

    str_is_fixed = '-fixW' if is_fixed_width else ''
    basename = f'{data_type}{str_is_fixed}_{model_name}_{loss_type}_{date_str}_adam_lr{lr}_wd{weight_decay}_batch{batch_size}'

    train_skeleton_maps(basename,
                        model,
                        device,
                        flow_train,
                        flow_val,
                        criterion,
                        optimizer,
                        log_dir=log_dir,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        preeval_func=preeval_func,
                        **argkws)