コード例 #1
0
    def load_network(self, config_dict):
        """Build a ``PainHead`` around the base network and load saved weights.

        The pain-model module is imported dynamically as
        ``models.<config_dict['model_type']>``. Its ``PainHead`` is built on
        top of ``self.load_base_network()``. Weights read from
        ``self.out_file`` are then transferred into the whole head.

        Args:
            config_dict: configuration mapping. It reads:
                * ``'model_type'`` (required): module name under ``models``.
                * ``'output_types'`` (required): forwarded to ``PainHead``.
                * ``'network_params'`` (optional): extra keyword arguments
                  for ``PainHead``.

        Returns:
            The constructed ``PainHead`` instance.
        """
        out_file = self.out_file
        model_type_str = config_dict['model_type']
        pain_model = importlib.import_module('models.' + model_type_str)

        network_base = self.load_base_network()

        # A missing 'network_params' entry means "no extra kwargs".
        # Expanding an empty dict is equivalent to omitting them.
        network_pain = pain_model.PainHead(
            base_network=network_base,
            output_types=config_dict['output_types'],
            **config_dict.get('network_params', {}))

        # `device` is not defined in this method. It presumably lives at
        # module level; verify against the enclosing file.
        pretrained_states = torch.load(out_file, map_location=device)

        # Weights go into the full head (not just the base network).
        # The meaning of submodule=0 is defined by
        # utils_train.transfer_partial_weights.
        utils_train.transfer_partial_weights(pretrained_states,
                                             network_pain,
                                             submodule=0)

        print(network_pain.to_pain)
        return network_pain
コード例 #2
0
def resnet152(pretrained=False, **kwargs):
    """Construct a ResNet-152 model from ``Bottleneck`` blocks.

    Args:
        pretrained (bool): when True, the ImageNet checkpoint at
            ``model_urls['resnet152']`` is downloaded and its weights are
            transferred via ``utils_train.transfer_partial_weights``.
        **kwargs: forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    imagenet_states = model_zoo.load_url(model_urls['resnet152'])
    utils_train.transfer_partial_weights(imagenet_states, net)
    return net
コード例 #3
0
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model from ``BasicBlock`` blocks.

    Args:
        pretrained (bool): when True, the ImageNet checkpoint at
            ``model_urls['resnet18']`` is downloaded and its weights are
            transferred via ``utils_train.transfer_partial_weights``.
        **kwargs: forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    layer_config = [2, 2, 2, 2]
    resnet = ResNet(BasicBlock, layer_config, **kwargs)
    if pretrained:
        checkpoint_url = model_urls['resnet18']
        utils_train.transfer_partial_weights(
            model_zoo.load_url(checkpoint_url), resnet)
    return resnet
コード例 #4
0
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model from ``Bottleneck`` blocks.

    Args:
        pretrained (bool): when True, the ImageNet checkpoint at
            ``model_urls['resnet101']`` is downloaded and its weights are
            transferred via ``utils_train.transfer_partial_weights``.
            Progress messages are printed before and after the transfer.
        **kwargs: forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net

    print("resnet_low_level: Loading image net weights...")
    weights_url = model_urls['resnet101']
    imagenet_states = model_zoo.load_url(weights_url)
    utils_train.transfer_partial_weights(imagenet_states, net)
    print("resnet_low_level: Done loading image net weights...")
    return net
コード例 #5
0
def resnet152(pretrained=False, **kwargs):
    """Construct a two-stream ResNet-152-style model.

    NOTE(review): the layer configuration is ``[3, 8, 36, 1]``, whereas
    the plain ResNet-152 elsewhere uses ``[3, 8, 36, 3]``. This is
    presumably intentional for ``ResNetTwoStream``; confirm.

    Args:
        pretrained (bool): when True, the ImageNet ResNet-152 checkpoint at
            ``model_urls['resnet152']`` is downloaded and its weights are
            transferred via ``utils_train.transfer_partial_weights``.
            Progress messages are printed before and after the transfer.
        **kwargs: forwarded unchanged to the ``ResNetTwoStream`` constructor.

    Returns:
        The constructed ``ResNetTwoStream`` instance.
    """
    two_stream_net = ResNetTwoStream(Bottleneck, [3, 8, 36, 1], **kwargs)
    if not pretrained:
        return two_stream_net

    print("Loading image net weights...")
    imagenet_states = model_zoo.load_url(model_urls['resnet152'])
    utils_train.transfer_partial_weights(imagenet_states, two_stream_net)
    print("Done image net weights...")
    return two_stream_net
コード例 #6
0
    def load_network(self, config_dict):
        """Load pretrained weights into the base network, then wrap it in a ``PainHead``.

        Unlike a variant that loads weights into the finished head, here the
        checkpoint is transferred into the base network *before* the
        ``PainHead`` is built around it.

        Args:
            config_dict: configuration mapping. It reads:
                * ``'pretrained_network_path'`` (required): checkpoint file.
                * ``'model_type'`` (required): module name under ``models``.
                * ``'output_types'`` (required): forwarded to ``PainHead``.
                * ``'network_params'`` (optional): extra keyword arguments
                  for ``PainHead``.

        Returns:
            The constructed ``PainHead`` instance.
        """
        network_base = self.load_base_network()

        # Fill the base network with saved parameters. `device` is not
        # defined in this method; it presumably lives at module level.
        pretrained_network_path = config_dict['pretrained_network_path']
        pretrained_states = torch.load(pretrained_network_path,
                                       map_location=device)

        # Per the original author, submodule=0 strips a "network.single"
        # prefix from the saved keys (behaviour lives in utils_train).
        utils_train.transfer_partial_weights(pretrained_states, network_base,
                                             submodule=0)
        print("Done loading weights from config_dict['pretrained_network_path']", pretrained_network_path)

        # Build the pain model around the pretrained encoder. A missing
        # 'network_params' entry means "no extra kwargs".
        model_type_str = config_dict['model_type']
        pain_model = importlib.import_module('models.' + model_type_str)
        network_pain = pain_model.PainHead(
            base_network=network_base,
            output_types=config_dict['output_types'],
            **config_dict.get('network_params', {}))

        print(network_pain.to_pain)
        return network_pain
コード例 #7
0
    def load_network(self, config_dict):
        """Build the ``unet`` network described by ``config_dict`` and optionally load weights.

        The network class is ``unet`` from the module ``models.<model_type>``.
        The module defaults to ``unet_encode3D_clean`` when ``'model_type'``
        is absent.

        Weight loading is driven by two optional keys:

        * ``'pretrained_network_path'``: either the sentinel string
          ``'MPII2Dpose'`` or a checkpoint path.
          - The sentinel loads a fixed checkpoint path with
            ``add_prefix='encoder.'``.
          - A checkpoint path is loaded into the whole network.
        * ``'pretrained_posenet_network_path'``: a checkpoint loaded into
          ``network_single.to_pose`` only.

        Args:
            config_dict: configuration mapping.
                * Required keys include ``'output_types'``,
                  ``'active_cameras'``, ``'latent_bg'``, ``'latent_fg'``,
                  ``'latent_3d'``, ``'feature_scale'``, ``'shuffle_fg'``,
                  ``'shuffle_3d'``, ``'latent_dropout'``,
                  ``'inputDimension'``, ``'encoderType'``,
                  ``'n_hidden_to3Dpose'``, ``'use_view_batches'``,
                  ``'implicit_rotation'`` and ``'skip_background'``.
                * Optional keys are ``'upsampling_bilinear'`` (``'half'``,
                  ``'upper'`` or another truthy/falsy value),
                  ``'from_latent_hidden_layers'`` (default 0),
                  ``'num_encoding_layers'`` (default 4) and
                  ``'model_type'``.

        Returns:
            The constructed ``unet`` network.
        """

        def _load_into(path, target, **transfer_kwargs):
            # Load the checkpoint at `path` and transfer its weights into
            # `target`. `device` is not defined in this method; it presumably
            # lives at module level. Every call site passes submodule=0.
            states = torch.load(path, map_location=device)
            utils_train.transfer_partial_weights(
                states, target, submodule=0, **transfer_kwargs)

        output_types = config_dict['output_types']

        # A missing key yields False, so both mode checks below are False,
        # exactly as with the original "'key' in d and d['key'] == x" tests.
        upsampling_mode = config_dict.get('upsampling_bilinear', False)
        use_billinear_upsampling = upsampling_mode
        lower_billinear = upsampling_mode == 'half'
        upper_billinear = upsampling_mode == 'upper'

        from_latent_hidden_layers = config_dict.get(
            'from_latent_hidden_layers', 0)
        num_encoding_layers = config_dict.get('num_encoding_layers', 4)

        num_cameras = 4
        if config_dict['active_cameras']:  # for H36M it is set to False
            num_cameras = len(config_dict['active_cameras'])

        # 'half' bilinear mode turns off bilinear upsampling for is_deconv.
        if lower_billinear:
            use_billinear_upsampling = False

        model_type_str = config_dict.get('model_type', 'unet_encode3D_clean')
        unet_module = importlib.import_module('models.' + model_type_str)
        network_single = unet_module.unet(
            dimension_bg=config_dict['latent_bg'],
            dimension_fg=config_dict['latent_fg'],
            dimension_3d=config_dict['latent_3d'],
            feature_scale=config_dict['feature_scale'],
            shuffle_fg=config_dict['shuffle_fg'],
            shuffle_3d=config_dict['shuffle_3d'],
            latent_dropout=config_dict['latent_dropout'],
            in_resolution=config_dict['inputDimension'],
            encoderType=config_dict['encoderType'],
            is_deconv=not use_billinear_upsampling,
            upper_billinear=upper_billinear,
            lower_billinear=lower_billinear,
            from_latent_hidden_layers=from_latent_hidden_layers,
            n_hidden_to3Dpose=config_dict['n_hidden_to3Dpose'],
            num_encoding_layers=num_encoding_layers,
            output_types=output_types,
            subbatch_size=config_dict['use_view_batches'],
            implicit_rotation=config_dict['implicit_rotation'],
            skip_background=config_dict['skip_background'],
            num_cameras=num_cameras,
        )

        if 'pretrained_network_path' in config_dict:
            if config_dict['pretrained_network_path'] == 'MPII2Dpose':
                # NOTE(review): hard-coded absolute path on a specific
                # machine; it only resolves where that filesystem exists.
                pretrained_network_path = '/cvlabdata1/home/rhodin/code/humanposeannotation/output_save/CVPR18_H36M/TransferLearning2DNetwork/h36m_23d_crop_relative_s1_s5_aug_from2D_2017-08-22_15-52_3d_resnet/models/network_000000.pth'
                print("Loading weights from MPII2Dpose")
                # add_prefix='encoder.' is forwarded to
                # transfer_partial_weights. It presumably maps checkpoint keys
                # onto the encoder submodule; TODO confirm in utils_train.
                _load_into(pretrained_network_path, network_single,
                           add_prefix='encoder.')
            else:
                print(
                    "Loading weights from config_dict['pretrained_network_path']"
                )
                _load_into(config_dict['pretrained_network_path'],
                           network_single)
                print(
                    "Done loading weights from config_dict['pretrained_network_path']"
                )

        if 'pretrained_posenet_network_path' in config_dict:
            print(
                "Loading weights from config_dict['pretrained_posenet_network_path']"
            )
            # Only the pose sub-network receives these weights.
            _load_into(config_dict['pretrained_posenet_network_path'],
                       network_single.to_pose)
            print(
                "Done loading weights from config_dict['pretrained_posenet_network_path']"
            )
        return network_single