Example #1
def initHook(obj, *file_list, **kwargs):
    """
    Define the type and size of input feature and label
    """

    obj.feat_size = 1024
    obj.audio_size = 128
    obj.label_size = 4716
    obj.slots = [
        DenseSlot(obj.feat_size),            # dense feature vector, 1024 floats
        DenseSlot(obj.audio_size),           # dense audio feature vector, 128 floats
        SparseNonValueSlot(obj.label_size)   # sparse label indices over 4716 classes
    ]
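
What a single sample for these slots might look like is not shown in the snippet. As a hedged sketch (the function name and the per-sample nesting are assumptions, not taken from the original provider), each sample would carry one entry per slot: two fixed-length float lists and a list of active label indices for the sparse non-value slot.

import random

# Hypothetical sample layout matching the three slots above; this is an
# assumption about the provider's per-sample nesting, not original code.
def make_fake_sample(feat_size=1024, audio_size=128, label_size=4716):
    visual_feat = [random.random() for _ in range(feat_size)]  # DenseSlot(feat_size)
    audio_feat = [random.random() for _ in range(audio_size)]  # DenseSlot(audio_size)
    labels = sorted(random.sample(range(label_size), 3))       # SparseNonValueSlot: active indices only
    return [visual_feat, audio_feat, labels]

sample = make_fake_sample()
assert len(sample[0]) == 1024 and len(sample[1]) == 128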
Example #2
def main():
    # Parse the trainer config and build a GradientMachine from its model proto.
    conf = parse_config("./mnist_model/trainer_config.conf.norm", "")
    print conf.data_config.load_data_args
    network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
    assert isinstance(network, swig_paddle.GradientMachine)  # For code hint.
    network.loadParameters("./mnist_model/")
    # Convert the raw test samples into the network's input format and run a forward pass.
    converter = DataProviderWrapperConverter(False, [DenseSlot(784)])
    inArg = converter(TEST_DATA)
    print network.forwardTest(inArg)
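
TEST_DATA itself is not included in the snippet. A hedged stand-in for it, assuming the converter expects one entry per slot per sample (that nesting is an assumption, not documented here), might look like:

# Hypothetical TEST_DATA: two flattened 28x28 MNIST-style images, each sample
# holding one 784-float list for the single DenseSlot(784). The nesting is an
# assumed convention, not taken from the original script.
TEST_DATA = [
    [[0.0] * 784],  # sample 1
    [[0.5] * 784],  # sample 2
]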
Example #3
    def __init__(self, train_conf, model_dir=None,
                 resize_dim=256, crop_dim=224,
                 mean_file=None,
                 output_layer=None,
                 oversample=False, is_color=True):
        """
        train_conf: network configuration file.
        model_dir: string, directory of model.
        resize_dim: int, resized image size.
        crop_dim: int, crop size.
        mean_file: string, image mean file.
        oversample: bool, oversample means multiple crops, namely five
                    patches (the four corner patches and the center
                    patch) as well as their horizontal reflections,
                    ten crops in all.
        """
        self.train_conf = train_conf
        self.model_dir = model_dir
        if model_dir is None:
            self.model_dir = os.path.dirname(train_conf)

        self.resize_dim = resize_dim
        self.crop_dims = [crop_dim, crop_dim]
        self.oversample = oversample
        self.is_color = is_color

        self.output_layer = output_layer
        if self.output_layer:
            assert isinstance(self.output_layer, basestring)
            self.output_layer = self.output_layer.split(",")

        self.transformer = image_util.ImageTransformer(is_color=is_color)
        self.transformer.set_transpose((2, 0, 1))
        self.transformer.set_channel_swap((2, 1, 0))

        self.mean_file = mean_file
        if self.mean_file is not None:
            mean = np.load(self.mean_file)['data_mean']
            mean = mean.reshape(3, self.crop_dims[0], self.crop_dims[1])
            self.transformer.set_mean(mean) # mean pixel
        else:
            # No mean file given: fall back to fixed per-channel mean values.
            # These three values were computed on ImageNet.
            self.transformer.set_mean(np.array([103.939, 116.779, 123.68]))

        conf_args = "is_test=1,use_gpu=1,is_predict=1"
        conf = parse_config(train_conf, conf_args)
        swig_paddle.initPaddle("--use_gpu=1")
        self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
        assert isinstance(self.network, swig_paddle.GradientMachine)
        self.network.loadParameters(self.model_dir)

        data_size = 3 * self.crop_dims[0] * self.crop_dims[1]
        slots = [DenseSlot(data_size)]
        is_sequence = False
        self.converter = util.DataProviderWrapperConverter(is_sequence, slots)
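
The mean file is only characterized by how it is read (np.load(...)['data_mean']). One way to produce a compatible file, assuming an .npz archive holding a 'data_mean' array (this is an assumption, not the project's own tooling), is:

import numpy as np

# Hypothetical mean-file creation: average a stack of HWC images and store the
# result under the 'data_mean' key that the constructor above looks up.
images = np.random.randint(0, 256, size=(100, 224, 224, 3)).astype(np.float32)
data_mean = images.mean(axis=0).transpose(2, 0, 1)  # CHW; matches reshape(3, crop, crop) above
np.savez("mean.npz", data_mean=data_mean)

loaded = np.load("mean.npz")["data_mean"]
assert loaded.shape == (3, 224, 224)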
Example #4
def initHook(obj, *file_list, **kwargs):
    """
    Define the type and size of input feature and label
    """

    config_map = {}
    configs = [args.split(':') for args in kwargs['load_data_args'].split(';')]
    for confs in configs:
        config_map[confs[0]] = confs[1]
    logger.info(config_map)

    obj.ftr_dim = int(config_map.get('ftr_dim'))
    obj.label_size = int(config_map.get('label_size'))
    obj.slots = [DenseSlot(obj.ftr_dim), SparseNonValueSlot(obj.label_size)]
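
The hook expects 'load_data_args' to be a ';'-separated list of 'key:value' pairs. A self-contained illustration of that parsing (the concrete keys and numbers are invented for the example):

# Stand-alone illustration of the parsing above; the argument string is hypothetical.
load_data_args = "ftr_dim:4096;label_size:512"

config_map = {}
for key, value in (item.split(':') for item in load_data_args.split(';')):
    config_map[key] = value

assert int(config_map['ftr_dim']) == 4096
assert int(config_map['label_size']) == 512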
Example #5
    def __init__(self,
                 train_conf,
                 use_gpu=True,
                 model_dir=None,
                 resize_dim=None,
                 crop_dim=None,
                 mean_file=None,
                 oversample=False,
                 is_color=True):
        """
        train_conf: network configuration file.
        model_dir: string, directory of model.
        resize_dim: int, resized image size.
        crop_dim: int, crop size.
        mean_file: string, image mean file.
        oversample: bool, oversample means multiple crops, namely five
                    patches (the four corner patches and the center
                    patch) as well as their horizontal reflections,
                    ten crops in all.
        """
        self.train_conf = train_conf
        self.model_dir = model_dir
        if model_dir is None:
            self.model_dir = os.path.dirname(train_conf)

        self.resize_dim = resize_dim
        self.crop_dims = [crop_dim, crop_dim]
        self.oversample = oversample
        self.is_color = is_color

        self.transformer = image_util.ImageTransformer(is_color=is_color)
        self.transformer.set_transpose((2, 0, 1))

        # Unlike Example #3, a mean file is required here: the stored mean
        # image is always loaded, reshaped to the crop size and subtracted.
        self.mean_file = mean_file
        mean = np.load(self.mean_file)['data_mean']
        mean = mean.reshape(3, self.crop_dims[0], self.crop_dims[1])
        self.transformer.set_mean(mean)  # mean pixel
        gpu = 1 if use_gpu else 0
        conf_args = "is_test=1,use_gpu=%d,is_predict=1" % (gpu)
        conf = parse_config(train_conf, conf_args)
        swig_paddle.initPaddle("--use_gpu=%d" % (gpu))
        self.network = swig_paddle.GradientMachine.createFromConfigProto(
            conf.model_config)
        assert isinstance(self.network, swig_paddle.GradientMachine)
        self.network.loadParameters(self.model_dir)

        data_size = 3 * self.crop_dims[0] * self.crop_dims[1]
        slots = [DenseSlot(data_size)]
        self.converter = util.DataProviderWrapperConverter(False, slots)
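
As a rough numpy sketch of what the configured preprocessing amounts to (an approximation for illustration, not the actual image_util.ImageTransformer implementation):

import numpy as np

# Approximate effect of set_transpose((2, 0, 1)) followed by set_mean(mean),
# assuming an HWC float image as input; a sketch, not ImageTransformer itself.
def apply_transform(img_hwc, data_mean):
    chw = img_hwc.transpose(2, 0, 1)  # HWC -> CHW
    return chw - data_mean            # subtract the mean image

img = np.zeros((224, 224, 3), dtype=np.float32)
mean = np.zeros((3, 224, 224), dtype=np.float32)
assert apply_transform(img, mean).shape == (3, 224, 224)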
def initHook(obj, *file_list, **kwargs):
    """
    Description: Initialize with a list of data files.
    file_list is the name list of input files.
    kwargs['load_data_args'] is the value of 'load_data_args',
    which can be set in the config.
    kwargs['load_data_args'] is organized as follows:
        'dictionary path' (str)
        'image feature list file' (str)
        'img_feat_dim' (int)
        'average norm factor for image feature' (float)
    Each argument is separated by a space.
    """
    str_conf_args = kwargs['load_data_args'].strip().split()
    dict_file = str_conf_args[0]
    img_feat_list = str_conf_args[1]
    img_feat_dim = int(str_conf_args[2])
    feat_avg_norm_factor = float(str_conf_args[3])

    LOG.info('Dictionary path: %s', dict_file)
    LOG.info('Image feature list: %s', img_feat_list)
    LOG.info('Image dimension: %d', img_feat_dim)
    LOG.info('Image feature norm factor: %.4f', feat_avg_norm_factor)

    if os.path.isfile(dict_file):
        word_dict = cPickle.load(io.open(dict_file, 'rb'))
        if word_dict.get('#OOV#', -1) == -1:
            word_dict['#OOV#'] = 0
        if word_dict.get('$$S$$', -1) == -1:
            word_dict['$$S$$'] = len(word_dict)
        if word_dict.get('$$E$$', -1) == -1:
            word_dict['$$E$$'] = len(word_dict)
        LOG.info('Dictionary loaded with %d words', len(word_dict))
    else:
        LOG.fatal('Dictionary file %s does not exist!', dict_file)
        sys.exit(1)

    if len(file_list) == 0:
        LOG.fatal('No annotation file!')
        sys.exit(1)
    else:
        LOG.info('There are %d annotation files', len(file_list))
    file_name = file_list[0].strip()
    if os.path.isfile(file_name):
        LOG.debug('Annotation file name: %s', file_name)
    else:
        LOG.fatal('Annotation file %s missing!', file_name)
        sys.exit(1)

    if os.path.exists(img_feat_list):
        img_feat_list = io.open(img_feat_list, 'rb').readlines()
        if len(img_feat_list) == 0:
            LOG.fatal('No image feature file!')
            sys.exit(1)
        else:
            LOG.info('There are %d feature files', len(img_feat_list))
    else:
        LOG.fatal('Image feature list %s does not exist!', img_feat_list)
        sys.exit(1)
    file_name = img_feat_list[0].strip()
    if os.path.isfile(file_name):
        LOG.debug('Image feature file name: %s', file_name)
    else:
        LOG.fatal('Image feature file %s missing!', file_name)
        sys.exit(1)

    obj.file_list = list(file_list)
    obj.word_dict = word_dict
    obj.features = load_image_feature(img_feat_list)
    obj.img_feat_dim = img_feat_dim
    obj.feat_avg_norm_factor = feat_avg_norm_factor
    LOG.info('DataProvider Initialization finished')
    obj.slots = [
        IndexSlot(1),
        DenseSlot(img_feat_dim),
        IndexSlot(len(word_dict))
    ]
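
Following the field order fixed by the docstring, a hedged illustration of how such a space-separated argument string is decomposed (the paths and numbers below are invented):

# Hypothetical 'load_data_args' value; real paths depend on the experiment setup.
load_data_args = "dict.pkl img_feat_list.txt 4096 1.0"

dict_file, img_feat_list, img_feat_dim, feat_avg_norm_factor = load_data_args.strip().split()
img_feat_dim = int(img_feat_dim)
feat_avg_norm_factor = float(feat_avg_norm_factor)

assert img_feat_dim == 4096
assert abs(feat_avg_norm_factor - 1.0) < 1e-9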