Example #1
    def __init__(self, model_proto, is_training=False):
        """Initializes the model.

        Args:
          model_proto: an instance of text_classification_model_pb2.TextClassificationModel.
          is_training: if True, the training graph will be built.
        """
        super(Model, self).__init__(model_proto, is_training)

        if not isinstance(
                model_proto,
                text_classification_model_pb2.TextClassificationModel):
            raise ValueError(
                'The model_proto has to be an instance of TextClassificationModel.'
            )

        options = model_proto

        self._open_vocabulary_list = model_utils.read_vocabulary(
            options.open_vocabulary_file)
        with open(options.open_vocabulary_glove_file, 'rb') as fid:
            self._open_vocabulary_initial_embedding = np.load(fid)

        self._vocabulary_list = model_utils.read_vocabulary(
            options.vocabulary_file)

        self._num_classes = len(self._vocabulary_list)
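
Every example on this page calls model_utils.read_vocabulary to turn a label file into a class list. Its implementation is not shown here; the following is a minimal sketch, assuming the file stores one token per line (a hypothetical reimplementation, not the actual model_utils code):

import io

def read_vocabulary(filename):
    """Reads a vocabulary file into a list of tokens.

    Assumes a plain-text file with one token per line; blank lines are
    skipped. Hypothetical sketch, not the actual model_utils code.
    """
    with io.open(filename, 'r', encoding='utf-8') as fid:
        return [line.strip() for line in fid if line.strip()]

Under that assumption, self._num_classes = len(self._vocabulary_list) simply counts the non-empty lines of the file.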
Example #2
    def __init__(self, model_proto, is_training=False):
        """Initializes the model.

        Args:
          model_proto: an instance of nod_model_pb2.NODModel.
          is_training: if True, the training graph will be built.
        """
        super(Model, self).__init__(model_proto, is_training)

        if not isinstance(model_proto, nod_model_pb2.NODModel):
            raise ValueError(
                'The model_proto has to be an instance of NODModel.')

        options = model_proto

        self._vocabulary_list = model_utils.read_vocabulary(
            options.vocabulary_file)

        self._num_classes = len(self._vocabulary_list)

        self._feature_extractor = build_faster_rcnn_feature_extractor(
            options.feature_extractor, is_training,
            options.inplace_batchnorm_update)

        self._midn_post_process_fn = function_builder.build_post_processor(
            options.midn_post_process)

        self._oicr_post_process_fn = function_builder.build_post_processor(
            options.oicr_post_process)
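
For context on how such a constructor is usually driven: configs like NODModel are typically written as text-format protos and parsed with google.protobuf.text_format. A hedged usage sketch (the config path and training flag are assumptions, and a real config must also populate feature_extractor, midn_post_process, and oicr_post_process):

from google.protobuf import text_format

# Hypothetical config path; the pbtxt must define every field the
# constructor reads (vocabulary_file, feature_extractor, ...).
with open('configs/nod_model.pbtxt', 'r') as fid:
    model_proto = text_format.Parse(fid.read(), nod_model_pb2.NODModel())
model = Model(model_proto, is_training=True)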
Example #3
    def __init__(self, model_proto, is_training=False):
        """Initializes the model.

        Args:
          model_proto: an instance of cam_model_pb2.CAMModel.
          is_training: if True, the training graph will be built.
        """
        super(Model, self).__init__(model_proto, is_training)

        if not isinstance(model_proto, cam_model_pb2.CAMModel):
            raise ValueError(
                'The model_proto has to be an instance of CAMModel.')

        self._vocabulary_list = model_utils.read_vocabulary(
            model_proto.vocabulary_file)
        tf.logging.info('Load %i classes: %s', len(self._vocabulary_list),
                        ','.join(self._vocabulary_list))

        self._input_scales = [1.0]
        if len(model_proto.input_image_scale) > 0:
            self._input_scales = [
                scale for scale in model_proto.input_image_scale
            ]

        self._cnn_feature_names = [model_proto.cnn_output_name]
        if len(model_proto.cnn_feature_name) > 0:
            self._cnn_feature_names = [
                name for name in model_proto.cnn_feature_name
            ]

        self._anchors = init_grid_anchors.initialize_grid_anchors(
            stride_ratio=0.2)
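
The two "if len(...) > 0:" blocks above implement a common proto idiom: fall back to a single default when a repeated field was left empty. Assuming input_image_scale and cnn_feature_name are repeated fields, a behavior-equivalent compact form of that constructor fragment would be:

# Repeated proto fields are falsy when empty, so `or` selects the
# default exactly when nothing was configured.
self._input_scales = list(model_proto.input_image_scale) or [1.0]
self._cnn_feature_names = (list(model_proto.cnn_feature_name)
                           or [model_proto.cnn_output_name])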
Example #4
  def __init__(self, model_proto, is_training=False):
    """Initializes the model.

    Args:
      model_proto: an instance of nod5_model_pb2.NOD5Model.
      is_training: if True, the training graph will be built.
    """
    super(Model, self).__init__(model_proto, is_training)

    if not isinstance(model_proto, nod5_model_pb2.NOD5Model):
      raise ValueError('The model_proto has to be an instance of NOD5Model.')

    options = model_proto

    self._open_vocabulary_list = model_utils.read_vocabulary(
        options.open_vocabulary_file)
    with open(options.open_vocabulary_glove_file, 'rb') as fid:
      self._open_vocabulary_initial_embedding = np.load(fid)

    self._vocabulary_list = model_utils.read_vocabulary(options.vocabulary_file)

    self._num_classes = len(self._vocabulary_list)

    self._feature_extractor = build_faster_rcnn_feature_extractor(
        options.feature_extractor, is_training,
        options.inplace_batchnorm_update)

    self._pcl_preprocess_fn = function_builder.build_post_processor(
        options.pcl_preprocess)

    self._midn_post_process_fn = function_builder.build_post_processor(
        options.midn_post_process)

    self._oicr_post_process_fn = function_builder.build_post_processor(
        options.oicr_post_process)

    self._text_encoding_fn = sequence_encoding.get_encode_fn(
        options.text_encoding)
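
The GloVe array loaded into _open_vocabulary_initial_embedding (here and in examples #1 and #6) is presumably used later to seed a word-embedding variable. A sketch of that step in TF1 style; the variable name and trainable flag are assumptions, not taken from this codebase:

import numpy as np
import tensorflow as tf

def init_word_embedding(initial_values, trainable=True):
    """Creates an embedding variable seeded with pre-trained vectors.

    initial_values: float array of shape [vocab_size, embedding_dims],
      e.g. the np.load result from the constructor above. The variable
      name is hypothetical.
    """
    return tf.get_variable(
        'open_vocab_word_embedding',
        initializer=initial_values.astype(np.float32),
        trainable=trainable)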
Example #5
    def __init__(self, model_proto, is_training=False):
        """Initializes the model.

        Args:
          model_proto: an instance of oicr_model_pb2.OICRModel.
          is_training: if True, the training graph will be built.
        """
        super(Model, self).__init__(model_proto, is_training)

        if not isinstance(model_proto, oicr_model_pb2.OICRModel):
            raise ValueError(
                'The model_proto has to be an instance of OICRModel.')

        self._vocabulary_list = model_utils.read_vocabulary(
            model_proto.vocabulary_file)
        self._num_classes = len(self._vocabulary_list)
Example #6
  def __init__(self, model_proto, is_training=False):
    """Initializes the model.

    Args:
      model_proto: an instance of visual_w2v_model_pb2.VisualW2vModel.
      is_training: if True, the training graph will be built.
    """
    super(Model, self).__init__(model_proto, is_training)

    if not isinstance(model_proto, visual_w2v_model_pb2.VisualW2vModel):
      raise ValueError(
          'The model_proto has to be an instance of VisualW2vModel.')

    options = model_proto

    self._open_vocabulary_list = model_utils.read_vocabulary(
        options.open_vocabulary_file)
    with open(options.open_vocabulary_glove_file, 'rb') as fid:
      self._open_vocabulary_initial_embedding = np.load(fid)
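
Examples #1, #4, and #6 all expect open_vocabulary_glove_file to be a pre-built .npy whose rows line up with open_vocabulary_file. A hypothetical offline exporter under that alignment assumption (the function name and the random OOV fallback are illustrative, not from this codebase):

import numpy as np

def export_aligned_glove(vocab, word2vec, dims, output_path, scale=0.03):
    """Saves a [len(vocab), dims] matrix whose row order matches vocab.

    word2vec: dict mapping words to dims-sized vectors (e.g. parsed
      GloVe). Words missing from it get small uniform-random rows.
    """
    rows = [word2vec.get(word, np.random.uniform(-scale, scale, size=dims))
            for word in vocab]
    np.save(output_path, np.stack(rows, 0).astype(np.float32))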
Example #7
    def __init__(self, model_proto, is_training=False):
        """Initializes the model.

        Args:
          model_proto: an instance of stacked_attn_model_pb2.StackedAttnModel.
          is_training: if True, the training graph will be built.
        """
        super(Model, self).__init__(model_proto, is_training)

        if not isinstance(model_proto,
                          stacked_attn_model_pb2.StackedAttnModel):
            raise ValueError(
                'The model_proto has to be an instance of StackedAttnModel.')

        options = model_proto

        self._vocabulary_list = model_utils.read_vocabulary(
            options.vocabulary_file)

        self._num_classes = len(self._vocabulary_list)

        self._feature_extractor = build_faster_rcnn_feature_extractor(
            options.feature_extractor, is_training,
            options.inplace_batchnorm_update)
Example #8
    def __init__(self, model_proto, is_training=False):
        """Initializes the model.

        Args:
          model_proto: an instance of reasoning_model_pb2.ReasoningModel.
          is_training: if True, the training graph will be built.
        """
        super(Model, self).__init__(model_proto, is_training)

        if not isinstance(model_proto, reasoning_model_pb2.ReasoningModel):
            raise ValueError(
                'The model_proto has to be an instance of ReasoningModel.')

        options = model_proto

        # Read vocabulary.

        def filter_fn(word_with_freq, min_freq):
            return [word for word, freq in word_with_freq if freq >= min_freq]

        stmt_vocab_with_freq = model_utils.read_vocabulary_with_frequency(
            options.stmt_vocab_list_path)
        stmt_vocab_list = filter_fn(stmt_vocab_with_freq, 5)

        slgn_vocab_with_freq = model_utils.read_vocabulary_with_frequency(
            options.slgn_vocab_list_path)
        slgn_vocab_list = filter_fn(slgn_vocab_with_freq, 20)

        slgn_dbpedia_vocab_with_freq = model_utils.read_vocabulary_with_frequency(
            options.slgn_kb_vocab_list_path)
        slgn_dbpedia_vocab_list = filter_fn(slgn_dbpedia_vocab_with_freq, 20)

        ads_labels = model_utils.read_vocabulary(options.ads_vocab_list_path)

        vocab_list = sorted(
            set(stmt_vocab_list + slgn_vocab_list + slgn_dbpedia_vocab_list +
                ads_labels))
        tf.logging.info('Vocab, len=%i', len(vocab_list))

        # Read glove data.

        word2vec_dict, embedding_dims = {}, options.embedding_dims

        if options.glove_path:
            (word2vec_dict,
             embedding_dims) = model_utils.load_glove_data(options.glove_path)

        oov_word, glove_word, glove_vec = [], [], []

        for word in vocab_list:
            if word not in word2vec_dict:
                oov_word.append(word)
            else:
                glove_word.append(word)
                glove_vec.append(word2vec_dict[word])

        self._embedding_dims = embedding_dims
        self._shared_vocab = glove_word + oov_word + ['out-of-vocabulary']
        if len(glove_vec) > 0:
            self._glove_vec = np.stack(glove_vec, 0)
        else:
            self._glove_vec = np.zeros((0, embedding_dims))

        tf.logging.info('Vocab, glove=%i, all=%i', len(glove_word),
                        len(self._shared_vocab))

        self._shared_dims = options.shared_dims

        # Text encoder.

        self._text_encoder = sequence_encoder.build(options.text_encoder,
                                                    is_training=is_training)

        # Graph creator.

        self._graph_creator = graph_creator.build_graph_creator(
            options.graph_creator, is_training)
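
Example #8 orders _shared_vocab as GloVe-covered words, then OOV words, then a final 'out-of-vocabulary' sentinel, while _glove_vec keeps only the pre-trained rows. The full embedding matrix is presumably assembled later by appending random rows for the OOV tail; a sketch under that assumption (the random scale is arbitrary):

import numpy as np

def build_initial_embedding(glove_vec, num_oov_words, embedding_dims,
                            scale=0.03):
    """Stacks GloVe rows with random rows for OOV words plus the sentinel.

    Row order mirrors _shared_vocab: len(glove_vec) pre-trained rows,
    then num_oov_words + 1 randomly initialized rows.
    """
    oov_vec = np.random.uniform(
        -scale, scale, size=(num_oov_words + 1, embedding_dims))
    return np.concatenate([glove_vec, oov_vec], 0).astype(np.float32)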