Beispiel #1
0
    def load_from_metrics(cls,
                          weights_path,
                          tags_csv,
                          on_gpu,
                          map_location=None):
        """
        Primary way of loading model from csv weights path.

        Rebuilds the hyperparameters from ``tags_csv``, instantiates ``cls``
        with them, gives the model a chance to read the checkpoint through
        ``load_model_specific`` and then loads ``checkpoint['state_dict']``
        non-strictly.

        :param weights_path: path of a checkpoint file readable by
            ``torch.load``; it must contain a ``'state_dict'`` entry.
        :param tags_csv: path of the tags csv handed to
            ``load_hparams_from_tags_csv`` to rebuild the hyperparameters.
        :param on_gpu: stored on the hparams as ``on_gpu``; when false every
            storage is loaded onto the CPU.
        :param map_location: dic for mapping storage {'cuda:1':'cuda:0'};
            only honoured when ``on_gpu`` is true (ignored on CPU, where
            storages are always kept on the CPU).
        :return: an instance of ``cls`` with the checkpoint weights loaded.
        """
        hparams = load_hparams_from_tags_csv(tags_csv)
        # Plain attribute assignment instead of calling the dunder by hand.
        hparams.on_gpu = on_gpu

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        if on_gpu:
            # map_location=None is torch.load's default, so a single call
            # covers both "remap devices" and "keep the saved devices".
            checkpoint = torch.load(weights_path, map_location=map_location)
        else:
            # Keep every storage on the CPU whatever device it was saved from.
            checkpoint = torch.load(weights_path,
                                    map_location=lambda storage, loc: storage)

        model = cls(hparams)

        # allow model to load
        model.load_model_specific(checkpoint)
        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated instead of raising.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        return model
    def load_from_metrics(cls,
                          weights_path,
                          tags_csv,
                          on_gpu,
                          map_location=None):
        """
        Primary way of loading model from csv weights path.

        Rebuilds the hyperparameters from ``tags_csv``, instantiates ``cls``
        with them, loads ``checkpoint['state_dict']`` (strictly, i.e. with
        ``load_state_dict``'s default) and finally hands the whole checkpoint
        to the model's ``on_load_checkpoint`` hook.

        :param weights_path: path of a checkpoint file readable by
            ``torch.load``; it must contain a ``'state_dict'`` entry.
        :param tags_csv: path of the tags csv handed to
            ``load_hparams_from_tags_csv`` to rebuild the hyperparameters.
        :param on_gpu: stored on the hparams as ``on_gpu``; when false every
            storage is loaded onto the CPU.
        :param map_location: dic for mapping storage {'cuda:1':'cuda:0'};
            only honoured when ``on_gpu`` is true (ignored on CPU, where
            storages are always kept on the CPU).
        :return: an instance of ``cls`` with the checkpoint weights loaded.
        """
        hparams = load_hparams_from_tags_csv(tags_csv)
        # Plain attribute assignment instead of calling the dunder by hand.
        hparams.on_gpu = on_gpu

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        if on_gpu:
            # map_location=None is torch.load's default, so a single call
            # covers both "remap devices" and "keep the saved devices".
            checkpoint = torch.load(weights_path, map_location=map_location)
        else:
            # Keep every storage on the CPU whatever device it was saved from.
            checkpoint = torch.load(weights_path,
                                    map_location=lambda storage, loc: storage)

        # load the state_dict on the model automatically
        model = cls(hparams)
        model.load_state_dict(checkpoint['state_dict'])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        return model
Beispiel #3
0
def test_loading_meta_tags():
    """Tags saved through an experiment can be read back as hparams.

    Assumes ``get_hparams()`` yields ``batch_size=32`` and
    ``hidden_dim=1000`` -- TODO confirm if those defaults change.
    """
    import os

    hparams = get_hparams()

    # save tags
    exp = get_exp(False)
    try:
        exp.tag({'some_str': 'a_str', 'an_int': 1, 'a_float': 2.0})
        exp.argparse(hparams)
        exp.save()

        # load tags
        tags_path = os.path.join(exp.get_data_path(exp.name, exp.version),
                                 'meta_tags.csv')
        tags = model_saving.load_hparams_from_tags_csv(tags_path)

        # Separate asserts so a failure reports which value was wrong.
        assert tags.batch_size == 32
        assert tags.hidden_dim == 1000
    finally:
        # Always clean up, even when an assertion above fails, so a failing
        # run does not leave saved experiments behind for later tests.
        clear_save_dir()