def load_model(model_path):
    """Rebuild a CPM_PAF skeleton-map network and load its trained weights.

    The name of the checkpoint's parent folder (the training-run name) drives
    two choices: it is passed to ``get_preeval_func``, and it selects the
    backbone ('vgg19' takes precedence over 'resnet50'; if neither appears,
    no backbone is used).

    Args:
        model_path: path-like object with a ``.parent`` attribute (e.g.
            ``pathlib.Path``) pointing at a checkpoint that stores the
            weights under the ``'state_dict'`` key.

    Returns:
        Tuple ``(model, preeval_func)``. The model is in eval mode, and its
        weights were loaded with ``map_location='cpu'``.
    """
    run_name = model_path.parent.name
    preeval_func = get_preeval_func(run_name)

    # The first matching name wins, mirroring a vgg19-then-resnet50 check.
    backbone_name = next(
        (name for name in ('vgg19', 'resnet50') if name in run_name), None)
    backbone = None if backbone_name is None else PretrainedBackBone(backbone_name)

    model = CPM_PAF(
        n_segments=25,
        n_affinity_maps=21,
        same_output_size=True,
        backbone=backbone,
        is_PAF=False,
    )

    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model, preeval_func
# Evaluation cell: rebuild a CPM_PAF model from a training checkpoint and
# pick the movie files to process.
# NOTE(review): `bn` (the training-run folder name) must already be defined.
model_path = (Path.home() / 'workspace/WormData/worm-poses/results/'
              / bn / 'checkpoint.pth.tar')

n_segments = 25
# NOTE(review): load_model() builds its network with 21 affinity maps;
# confirm that 20 matches this checkpoint.
n_affinity_maps = 20

cuda_id = 0
if not torch.cuda.is_available():
    dev_str = 'cpu'
else:
    print("THIS IS CUDA!!!!")
    dev_str = f"cuda:{cuda_id}"
device = torch.device(dev_str)

model = CPM_PAF(
    n_segments=n_segments,
    n_affinity_maps=n_affinity_maps,
    same_output_size=True,
)
# map_location='cpu' remaps the stored tensors onto the CPU; the model is
# moved to `device` only after the weights are in place.
state = torch.load(model_path, map_location='cpu')
#%%
model.load_state_dict(state['state_dict'])
model.eval()
model = model.to(device)
#%%
root_dir = Path('/Users/avelinojaver/OneDrive - Nexus365/worms/Bertie_movies')
# Alternative inputs kept for quick switching:
#fnames = list(root_dir.glob('*.hdf5'))
#fnames = ['/Users/avelinojaver/Downloads/recording61.2r_X1.hdf5']
fnames = list(root_dir.glob('CX11314_Ch1_04072017_103259.hdf5'))
def train_PAF(data_type='v1',
              model_name='PAF+CPM',
              loss_type='mse',
              cuda_id=0,
              log_dir_root=log_dir_root_dflt,
              batch_size=16,
              num_workers=1,
              roi_size=96,
              lr=1e-4,
              weight_decay=0.0,
              is_fixed_width=False,
              **argkws):
    """Build data flows, model, loss and optimizer, then run training.

    Args:
        data_type: key into ``data_types_dflts`` that selects the dataset
            root and the flow arguments. It is also used as the log
            sub-directory and as the run-name prefix.
        model_name: names containing 'CPM' build a ``CPM_PAF`` model, using a
            'vgg19' or 'resnet50' backbone (non-pretrained) when the name
            contains one. 'openpose' builds an ``OpenPoseCPM``. Affinity maps
            are enabled when the name contains 'PAF' or is 'openpose'.
        loss_type: 'mse' or 'maxlikelihood'. 'maxlikelihood' sets
            ``width2sigma=-1`` for both data flows and builds the criterion
            with ``is_maxlikelihood=True``.
        cuda_id: device index passed to ``get_device``.
        log_dir_root: logs are written under ``log_dir_root / data_type``.
        batch_size, num_workers: forwarded to ``train_skeleton_maps``.
        roi_size: ROI size for both data flows.
        lr, weight_decay: Adam hyper-parameters.
        is_fixed_width: forwarded to both data flows. When true, '-fixW' is
            added to the run name.
        **argkws: extra keyword arguments forwarded to
            ``train_skeleton_maps``.

    Raises:
        ValueError: if ``model_name`` or ``loss_type`` is not supported.
            Both are checked before any data is read.
    """
    is_PAF = ('PAF' in model_name) or (model_name == 'openpose')

    # Validate the cheap arguments before the (expensive) data loading.
    if 'CPM' not in model_name and model_name != 'openpose':
        raise ValueError(f'Not implemented {model_name}')

    if model_name == 'openpose':
        criterion_func = OpenPoseCPMLoss
    elif is_PAF:
        criterion_func = CPM_PAF_Loss
    else:
        criterion_func = CPM_Loss

    if loss_type == 'mse':
        criterion = criterion_func()
    elif loss_type == 'maxlikelihood':
        criterion = criterion_func(is_maxlikelihood=True)
    else:
        raise ValueError(f'Not implemented {loss_type}')

    log_dir = log_dir_root / data_type
    dflts = data_types_dflts[data_type]
    root_dir = dflts['root_dir']
    # Copy before editing: writing into dflts['flow_args'] directly would
    # mutate the shared defaults and leak the maxlikelihood override into
    # every later call in the same process.
    flow_args = dict(dflts['flow_args'])
    if loss_type == 'maxlikelihood':
        flow_args['width2sigma'] = -1

    train_data = read_data_files(root_dir=root_dir, set2read='train')
    val_data = read_data_files(root_dir=root_dir, set2read='validation')

    flow_train = SkelMapsRandomFlow(data=train_data,
                                    roi_size=roi_size,
                                    epoch_size=23040,
                                    return_affinity_maps=is_PAF,
                                    is_fixed_width=is_fixed_width,
                                    **flow_args)
    flow_val = SkelMapsSimpleFlow(data=val_data,
                                  roi_size=roi_size,
                                  return_raw_skels=True,
                                  width2sigma=flow_args['width2sigma'],
                                  return_affinity_maps=is_PAF,
                                  is_fixed_width=is_fixed_width)

    if 'vgg19' in model_name:
        backbone = PretrainedBackBone('vgg19', pretrained=False)
    elif 'resnet50' in model_name:
        backbone = PretrainedBackBone('resnet50', pretrained=False)
    else:
        backbone = None

    if 'CPM' in model_name:
        model = CPM_PAF(n_segments=flow_train.n_skel_maps_out,
                        n_affinity_maps=flow_train.n_affinity_maps_out,
                        same_output_size=True,
                        backbone=backbone,
                        is_PAF=is_PAF)
    else:
        # model_name == 'openpose' (any other name was rejected above).
        model = OpenPoseCPM(
            n_segments=flow_train.n_skel_maps_out,
            n_affinity_maps=flow_train.n_affinity_maps_out,
        )

    preeval_func = get_preeval_func(loss_type)
    device = get_device(cuda_id)

    # Only hand trainable parameters to the optimizer.
    model_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(model_params, lr=lr, weight_decay=weight_decay)

    date_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    str_is_fixed = '-fixW' if is_fixed_width else ''
    basename = (f'{data_type}{str_is_fixed}_{model_name}_{loss_type}_{date_str}'
                f'_adam_lr{lr}_wd{weight_decay}_batch{batch_size}')

    train_skeleton_maps(basename,
                        model,
                        device,
                        flow_train,
                        flow_val,
                        criterion,
                        optimizer,
                        log_dir=log_dir,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        preeval_func=preeval_func,
                        **argkws)