コード例 #1
0
def construct_resnet_nostage(chkpt_path, use_cuda=True, num_layers=34):
    """
    Construct the resnet nostage network from a checkpoint and load its weights.

    The number of keypoints is derived from the first dimension of the
    ``head_net.features.9.weight`` tensor in the checkpoint, which must hold
    ``depth_per_keypoint`` (= 2) output channels per keypoint. The remaining
    configuration values are hardcoded.

    :param chkpt_path: The path to the checkpoint (a state dict readable by torch.load).
    :param use_cuda: Whether to upload the network to cuda. Defaults to True,
        which preserves the original behavior of this function.
    :param num_layers: ResNet depth of the backbone. It must match the
        checkpoint, or load_state_dict will fail.
    :return: (ResnetNoStage, ResnetNoStageConfig), the network in eval mode
        and the config used to build it.
    :raises ValueError: if the head's output channel count is not a multiple
        of depth_per_keypoint.
    """
    depth_per_keypoint = 2
    head_key = 'head_net.features.9.weight'

    # Load onto the CPU first so that checkpoints saved from a GPU can also be
    # read on CPU-only machines. load_state_dict copies the values into the
    # network, and the network is moved to cuda afterwards if requested.
    state_dict = torch.load(chkpt_path, map_location='cpu')

    # The head emits depth_per_keypoint channels per keypoint.
    n_output_channels = state_dict[head_key].shape[0]
    if n_output_channels % depth_per_keypoint != 0:
        raise ValueError(
            'Expected the output channel count of %r to be a multiple of %d, '
            'got %d' % (head_key, depth_per_keypoint, n_output_channels))
    n_keypoint = n_output_channels // depth_per_keypoint

    # Construct the network
    net_config = resnet_nostage.ResnetNoStageConfig()
    net_config.num_keypoints = n_keypoint
    net_config.image_channels = 4
    net_config.depth_per_keypoint = depth_per_keypoint
    net_config.num_layers = num_layers
    network = resnet_nostage.ResnetNoStage(net_config)

    # Load the weights and switch to inference mode
    network.load_state_dict(state_dict)
    if use_cuda:
        network.cuda()
    network.eval()
    return network, net_config
コード例 #2
0
def construct_resnet_nostage(chkpt_path, use_cuda=False, num_layers=18):
    """
    Creates the network and loads the weights from a saved network.

    The number of keypoints is derived from the first dimension of the
    ``head_net.features.9.weight`` tensor in the checkpoint, which must hold
    ``depth_per_keypoint`` (= 2) output channels per keypoint.

    :param chkpt_path: The path to the checkpoint (a state dict readable by torch.load).
    :param use_cuda: Whether or not to upload the network to cuda.
    :param num_layers: ResNet depth of the backbone. It must match the
        checkpoint, or load_state_dict will fail.
    :return: (ResnetNoStage, ResnetNoStageConfig), the network in eval mode
        and the config used to build it.
    :raises ValueError: if the head's output channel count is not a multiple
        of depth_per_keypoint.
    """
    depth_per_keypoint = 2
    head_key = 'head_net.features.9.weight'

    # Load onto the CPU so that GPU-saved checkpoints work on any machine.
    state_dict = torch.load(chkpt_path, map_location="cpu")

    # The head emits depth_per_keypoint channels per keypoint. Use an explicit
    # check rather than assert, because asserts are stripped under `python -O`.
    n_output_channels = state_dict[head_key].shape[0]
    if n_output_channels % depth_per_keypoint != 0:
        raise ValueError(
            'Expected the output channel count of %r to be a multiple of %d, '
            'got %d' % (head_key, depth_per_keypoint, n_output_channels))
    n_keypoint = n_output_channels // depth_per_keypoint

    # Construct the network
    net_config = resnet_nostage.ResnetNoStageConfig()
    net_config.num_keypoints = n_keypoint
    net_config.image_channels = 4
    net_config.depth_per_keypoint = depth_per_keypoint
    net_config.num_layers = num_layers
    network = resnet_nostage.ResnetNoStage(net_config)

    # Load the weights and switch to inference mode
    network.load_state_dict(state_dict)
    if use_cuda:
        network.cuda()
    network.eval()
    return network, net_config