Пример #1
0
def inception_v3(pretrained=False, model_root=None, **kwargs):
    """Construct an ``Inception3`` network, optionally loading pretrained weights.

    Args:
        pretrained: if true, load the ``'inception_v3_google'`` entry of
            ``model_urls`` into the network via ``misc.load_state_dict``.
        model_root: forwarded unchanged to ``misc.load_state_dict``
            (presumably a download/cache location — confirm against misc).
        **kwargs: forwarded to the ``Inception3`` constructor.

    Returns:
        The constructed ``Inception3`` instance.
    """
    if not pretrained:
        return Inception3(**kwargs)

    # With pretrained weights, default transform_input to True unless the
    # caller explicitly supplied a value.
    kwargs.setdefault('transform_input', True)
    net = Inception3(**kwargs)
    misc.load_state_dict(net, model_urls['inception_v3_google'], model_root)
    return net
Пример #2
0
def resnet152(pretrained=False, model_root=None, **kwargs):
    """Construct a ResNet built from ``Bottleneck`` blocks with stage depths [3, 8, 36, 3].

    Args:
        pretrained: if true, load the ``'resnet152'`` entry of ``model_urls``
            via ``misc.load_state_dict``.
        model_root: forwarded unchanged to ``misc.load_state_dict``.
        **kwargs: forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    stage_depths = [3, 8, 36, 3]
    net = ResNet(Bottleneck, stage_depths, **kwargs)
    if not pretrained:
        return net
    misc.load_state_dict(net, model_urls['resnet152'], model_root)
    return net
Пример #3
0
def resnet34(pretrained=False, model_root=None, **kwargs):
    """Construct a ResNet built from ``BasicBlock`` blocks with stage depths [3, 4, 6, 3].

    Args:
        pretrained: if true, load the ``'resnet34'`` entry of ``model_urls``
            via ``misc.load_state_dict``.
        model_root: forwarded unchanged to ``misc.load_state_dict``.
        **kwargs: forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    stage_depths = [3, 4, 6, 3]
    net = ResNet(BasicBlock, stage_depths, **kwargs)
    if not pretrained:
        return net
    misc.load_state_dict(net, model_urls['resnet34'], model_root)
    return net
Пример #4
0
def sphere64a(pretrained=False, model_root=None, stage=0):
    """Construct a ``SphereNet`` from ``BasicBlock`` with stage depths [3, 8, 16, 3].

    Args:
        pretrained: if true, load weights from ``model_root`` via
            ``misc.load_state_dict``.
        model_root: passed as the second argument of ``misc.load_state_dict``;
            the script entry point passes a checkpoint file path here.
        stage: forwarded positionally to the ``SphereNet`` constructor.

    Returns:
        The constructed ``SphereNet`` instance.
    """
    stage_depths = [3, 8, 16, 3]
    net = SphereNet(BasicBlock, stage_depths, stage)
    if not pretrained:
        return net
    # NOTE(review): other builders call misc.load_state_dict(model, url,
    # model_root); here model_root is passed as the second argument with no
    # URL. Confirm this matches misc.load_state_dict's signature.
    misc.load_state_dict(net, model_root)
    return net
Пример #5
0
            # compute the transformation between the current source and nearest destination points
            T, R, t = best_fit_transform(vertices2, vertices1[indices])

            # update the current source
            vertices2 = (np.dot(R, vertices2.T)).T + t

            # check error
            mean_error = np.sqrt(np.mean(distances ** 2))
            if np.abs(prev_error - mean_error) < 0.00001:
                break

            prev_error = mean_error
            # print(mean_error)

        return mean_error


if __name__ == "__main__":
    import sys

    # Evaluate a SphereNet checkpoint with EvalToolBox.get_micc_rmse.
    # An optional first CLI argument overrides the default checkpoint path.
    print("loading model...")
    dict_file = (sys.argv[1] if len(sys.argv) > 1
                 else "/home/jdq/model/dict.cl_382500.cl")
    # sphere64a(pretrained=True, model_root=...) loads the checkpoint into the
    # model itself (via misc.load_state_dict), so it is not read a second time.
    model = net.sphere64a(pretrained=True, model_root=dict_file)
    model = model.cuda()
    print("evaluting...")
    e = EvalToolBox()
    e.get_micc_rmse(model)