コード例 #1
0
    def from_config(config, load_stored_params=True, model_param_file=None):
        """
        Build a DenseCorrespondenceNetwork from a network config dict.

        :param config: Dict specifying details of the network architecture.
            Keys read here:

                descriptor_dimension: passed as ``num_classes`` to Resnet34_8s
                image_width, image_height: forwarded to DenseCorrespondenceNetwork
                normalize: optional, forwarded to DenseCorrespondenceNetwork
                    (defaults to False when absent)

            e.g.
                descriptor_dimension: 3
                image_width: 640
                image_height: 480
                normalize: False

            The dict is mutated and kept: when params are loaded,
            ``config['model_param_file']`` is set, and the same dict object is
            attached to the returned network as ``dcn.config``.
        :type config: dict

        :param load_stored_params: whether or not to load stored params; if so,
            ``model_param_file`` must be given
        :type load_stored_params: bool

        :param model_param_file: path to a saved state dict (should be an
            absolute path). A state dict for the whole network is tried first;
            if its keys do not match, it is loaded into ``dcn.fcn`` instead
            (older checkpoint format).
        :type model_param_file: str or None

        :raises ValueError: if ``load_stored_params`` is True but
            ``model_param_file`` is None

        :return: the network, left on the CPU and put in training mode
        :rtype: DenseCorrespondenceNetwork
        """
        fcn = resnet_dilated.Resnet34_8s(
            num_classes=config['descriptor_dimension'])

        dcn = DenseCorrespondenceNetwork(fcn,
                                         config['descriptor_dimension'],
                                         image_width=config['image_width'],
                                         image_height=config['image_height'],
                                         normalize=config.get('normalize', False))

        if load_stored_params:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if model_param_file is None:
                raise ValueError(
                    "model_param_file must be given when load_stored_params=True")
            config['model_param_file'] = model_param_file  # should be an absolute path

            # Read the checkpoint once, onto the CPU, and reuse it for the fallback:
            # both attempts then see the same tensors and neither needs a GPU.
            state_dict = torch.load(model_param_file, map_location='cpu')
            try:
                # New style: state dict of the whole DenseCorrespondenceNetwork.
                dcn.load_state_dict(state_dict)
            except (RuntimeError, KeyError):
                # Older checkpoints only hold the FCN weights. Key mismatches are
                # reported as RuntimeError (KeyError on very old torch versions).
                logging.info(
                    "loading params with the new style failed, falling back to dcn.fcn.load_state_dict"
                )
                dcn.fcn.load_state_dict(state_dict)

        # The network is left on the CPU here; callers move it to a GPU if needed.
        dcn.train()
        dcn.config = config
        return dcn
コード例 #2
0
    def from_model_folder(model_folder, load_stored_params=True, model_param_file=None,
        iteration=None):
        """
        Load a DenseCorrespondenceNetwork from a model folder.

        :param model_folder: the path to the folder where the model is stored.
            This directory contains files like

            - 003500.pth
            - training.yaml

            It is converted with ``utils.convert_to_absolute_path``.
        :type model_folder: str

        :param load_stored_params: whether to load stored weights into the network
        :type load_stored_params: bool

        :param model_param_file: parameter file to load. If None (and params are
            being loaded), it is looked up in ``model_folder`` with
            ``utils.get_model_param_file_from_directory``.
        :type model_param_file: str or None

        :param iteration: forwarded to ``utils.get_model_param_file_from_directory``
            when ``model_param_file`` is None (presumably selects the checkpoint
            of a given training iteration -- confirm against utils)
        :type iteration: int or None

        :return: a DenseCorrespondenceNetwork object, moved to the GPU and put in
            training mode. Its ``config`` is the "dense_correspondence_network"
            section of training.yaml with "path_to_network_params_folder" set to
            the absolute model folder.
        :rtype: DenseCorrespondenceNetwork
        """
        model_folder = utils.convert_to_absolute_path(model_folder)

        training_config_filename = os.path.join(model_folder, "training.yaml")
        training_config = utils.getDictFromYamlFilename(training_config_filename)
        config = training_config["dense_correspondence_network"]
        config["path_to_network_params_folder"] = model_folder

        fcn = resnet_dilated.Resnet34_8s(num_classes=config['descriptor_dimension'])

        dcn = DenseCorrespondenceNetwork(fcn, config['descriptor_dimension'],
                                         image_width=config['image_width'],
                                         image_height=config['image_height'])

        if load_stored_params:
            # Only resolve the checkpoint when it is actually needed, so building an
            # untrained network does not require a checkpoint in the folder.
            if model_param_file is None:
                model_param_file, _, _ = utils.get_model_param_file_from_directory(
                    model_folder, iteration=iteration)
            model_param_file = utils.convert_to_absolute_path(model_param_file)

            # Read the checkpoint once and reuse it for the fallback.
            state_dict = torch.load(model_param_file)
            try:
                # New format: state dict of the whole DenseCorrespondenceNetwork.
                dcn.load_state_dict(state_dict)
            except (RuntimeError, KeyError):
                # Old format: the checkpoint only holds the FCN weights. Key
                # mismatches are RuntimeError (KeyError on very old torch versions).
                logging.info("loading params with the new style failed, falling back to dcn.fcn.load_state_dict")
                dcn.fcn.load_state_dict(state_dict)

        dcn.cuda()
        dcn.train()
        dcn.config = config

        return dcn
コード例 #3
0
    def from_config(config, load_stored_params=True):
        """
        Build a DenseCorrespondenceNetwork from a network config dict.

        :param config: Dict specifying details of the network architecture.
            Keys read here:

                descriptor_dimension: passed as ``num_classes`` to Resnet34_8s
                image_width, image_height: forwarded to DenseCorrespondenceNetwork
                path_to_network_params: FCN parameter file (only read when
                    ``load_stored_params`` is True); resolved with
                    ``utils.convert_to_absolute_path``

            e.g.
                path_to_network_params: /home/manuelli/code/dense_correspondence/recipes/trained_models/10_drill_long_3d/dense_resnet_34_8s_03505.pth
                descriptor_dimension: 3
                image_width: 640
                image_height: 480

            The dict is mutated and kept: when params are loaded,
            ``config['path_to_network_params_folder']`` is set to the absolute
            folder containing the parameter file, and the same dict object is
            attached to the returned network as ``dcn.config``.
        :type config: dict

        :param load_stored_params: whether or not to load the stored FCN params
            from ``config['path_to_network_params']``
        :type load_stored_params: bool

        :return: the network, moved to the GPU and put in training mode
        :rtype: DenseCorrespondenceNetwork
        """
        fcn = resnet_dilated.Resnet34_8s(num_classes=config['descriptor_dimension'])

        if load_stored_params:
            path_to_network_params = utils.convert_to_absolute_path(config['path_to_network_params'])
            # Derive the folder from the resolved path, not the raw (possibly
            # relative) config value, so it names the folder actually loaded from.
            config['path_to_network_params_folder'] = os.path.dirname(path_to_network_params)
            fcn.load_state_dict(torch.load(path_to_network_params))

        dcn = DenseCorrespondenceNetwork(fcn, config['descriptor_dimension'],
                                         image_width=config['image_width'],
                                         image_height=config['image_height'])

        dcn.cuda()
        dcn.train()
        dcn.config = config
        return dcn
コード例 #4
0
    # Initialization
    import os
    import shutil
    import urllib.request
    from pathlib import Path

    model_path = os.path.join(pwd, 'our_data', 'resnet_34_8s_68.pth')
    source_video_path = os.path.join(pwd, 'our_data', 'ariel.mp4')
    # OpenCV must have avi as output. https://github.com/ContinuumIO/anaconda-issues/issues/223#issuecomment-285523938
    target_video_path = os.path.join(pwd, 'our_data', 'ariel_mask.avi')
    skip_frames = 0

    # Create and load weights to pre-trained Resnet-34 model:
    if not Path(model_path).is_file():
        # If weights file doesn't exist, download it from dropbox
        print("Weights file doesn't exist. Downloading...")
        url = "https://www.dropbox.com/s/91wcu6bpqezu4br/resnet_34_8s_68.pth?dl=1"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        # Download into a side file and rename it only on success. A truncated
        # file at model_path would pass the is_file() check above on every later
        # run and never be re-downloaded.
        tmp_path = model_path + '.part'
        try:
            with urllib.request.urlopen(url) as response, open(tmp_path, "wb") as f:
                # Stream to disk instead of holding the whole file in memory.
                shutil.copyfileobj(response, f)
            os.replace(tmp_path, model_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # NOTE(review): 21 output classes presumably matches the checkpoint's
    # training label set (e.g. PASCAL VOC 20 classes + background) -- confirm.
    fcn = resnet_dilated.Resnet34_8s(num_classes=21)
    # This map_location keeps every tensor on the CPU, wherever it was saved.
    fcn.load_state_dict(
        torch.load(model_path, map_location=lambda storage, loc: storage))

    create_masked_video(fcn,
                        source_video_path,
                        target_video_path,
                        skipframes=skip_frames)