def construct_resnet_nostage(chkpt_path, use_cuda=True):
    # type: (str, bool) -> (resnet_nostage.ResnetNoStage, resnet_nostage.ResnetNoStageConfig)
    """
    Construct the resnet nostage network and load its weights from a checkpoint.

    The number of keypoints is inferred from the checkpoint as half the
    first dimension of ``head_net.features.9.weight``. The remaining
    configuration values (4 image channels, 2 depth values per keypoint,
    34 layers) are currently hardcoded.

    :param chkpt_path: The path to checkpoint
    :param use_cuda: Whether or not to upload the network to cuda. Defaults
        to True, which matches the historical behaviour of this function.
    :return: (ResnetNoStage, ResnetNoStageConfig), with the network in eval mode
    """
    head_weight_key = 'head_net.features.9.weight'

    # Always deserialize onto the CPU: this lets a checkpoint that was saved
    # from a GPU be loaded on a machine without that device. The weights are
    # moved to the GPU together with the network below when requested.
    state_dict = torch.load(chkpt_path, map_location="cpu")

    # The first dimension of the head weight must be even, because the
    # keypoint count is taken to be exactly half of it.
    n_head_channels = state_dict[head_weight_key].shape[0]
    n_keypoint = n_head_channels // 2
    assert n_keypoint * 2 == n_head_channels

    # Construct the network
    net_config = resnet_nostage.ResnetNoStageConfig()
    net_config.num_keypoints = n_keypoint
    net_config.image_channels = 4
    net_config.depth_per_keypoint = 2
    net_config.num_layers = 34
    network = resnet_nostage.ResnetNoStage(net_config)

    # Load the network
    network.load_state_dict(state_dict)
    if use_cuda:
        network.cuda()
    network.eval()
    return network, net_config
def construct_resnet_nostage_lstm(chkpt_path, use_cuda=False):
    """
    Build the LSTM keypoint detection network and restore its weights from
    a saved checkpoint.

    The checkpoint is always deserialized onto the CPU. The keypoint count is
    taken as half the first dimension of ``head_net.features.9.weight``; the
    other configuration values (4 image channels, 2 depth values per
    keypoint, 18 layers) are hardcoded.

    :param chkpt_path: The path to checkpoint
    :param use_cuda: Whether or not to upload the network to cuda.
    :return: (ResnetNoStageLSTM, ResnetNoStageConfig), with the network in
        eval mode
    """
    # Read the saved weights; the head layer's size determines the keypoints.
    weights = torch.load(chkpt_path, map_location="cpu")
    head_channels = weights['head_net.features.9.weight'].shape[0]
    n_keypoint, leftover = divmod(head_channels, 2)
    # The head's first dimension must split evenly into two per keypoint.
    assert leftover == 0

    # Fill in the configuration, then instantiate the network from it.
    config = resnet_nostage.ResnetNoStageConfig()
    settings = (
        ('num_keypoints', n_keypoint),
        ('image_channels', 4),
        ('depth_per_keypoint', 2),
        ('num_layers', 18),
    )
    for attr_name, attr_value in settings:
        setattr(config, attr_name, attr_value)
    model = resnet_nostage.ResnetNoStageLSTM(config)

    # Restore the weights, optionally move to the GPU, and switch to eval mode.
    model.load_state_dict(weights)
    if use_cuda:
        model.cuda()
    model.eval()
    return model, config