コード例 #1
0
    def __init__(self):
        """Initialise the predictor in inference mode and create its input placeholders.

        The image placeholder is float32 ``[batch, IMAGE_FEATURE_LEN]`` when
        ``FLAGS.pre_calc_image_feature`` is truthy, otherwise a 1-D string
        tensor; it is registered in the 'feed' and 'lfeed' graph collections.
        The int64 ``[batch, TEXT_MAX_WORDS]`` text placeholder goes into 'rfeed'.
        """
        # Both bases are initialised explicitly, predictor base first.
        melt.PredictorBase.__init__(self)
        ShowAndTell.__init__(self, is_training=False, is_predict=True)

        if FLAGS.pre_calc_image_feature:
            self.image_feature_len = IMAGE_FEATURE_LEN
            image_dtype = tf.float32
            image_shape = [None, self.image_feature_len]
        else:
            # NOTE(review): presumably one encoded image per string -- confirm.
            image_dtype = tf.string
            image_shape = [None]
        self.image_feature_feed = tf.placeholder(image_dtype, image_shape,
                                                 name='image_feature')

        # NOTE(review): 'lfeed'/'rfeed' look like left/right inputs of a
        # pairwise setup -- confirm against the consumers of these collections.
        for collection_name in ('feed', 'lfeed'):
            tf.add_to_collection(collection_name, self.image_feature_feed)

        self.text_feed = tf.placeholder(tf.int64, [None, TEXT_MAX_WORDS],
                                        name='text')
        tf.add_to_collection('rfeed', self.text_feed)

        # Start unset.
        self.beam_text = self.beam_text_score = None
        self.image_model = None
コード例 #2
0
  def __init__(self):
    """Initialise the predictor in inference mode.

    Sets ``self.text_list`` to an empty list and creates two placeholders:
    float32 image features ``[batch, IMAGE_FEATURE_LEN]`` and int64 text ids
    ``[batch, TEXT_MAX_WORDS]``.
    """
    # Both bases are initialised explicitly rather than via super().
    melt.PredictorBase.__init__(self)
    ShowAndTell.__init__(self, is_training=False, is_predict=True)

    self.text_list = []

    image_feature_shape = [None, IMAGE_FEATURE_LEN]
    self.image_feature_place = tf.placeholder(
        tf.float32, image_feature_shape, name='image_feature')

    text_shape = [None, TEXT_MAX_WORDS]
    self.text_place = tf.placeholder(tf.int64, text_shape, name='text')
コード例 #3
0
    def __init__(self):
        """Initialise the predictor in inference mode and build its input placeholders.

        ``ShowAndTell`` is initialised with ``is_training=False`` and
        ``is_predict=True``. The image input depends on
        ``FLAGS.pre_calc_image_feature``:

        * truthy: a float32 ``[batch, image_feature_len]`` placeholder whose
          default is a single all-zero row, so the graph can run without this
          input being fed. ``image_feature_len`` is ``FLAGS.image_feature_len``
          when truthy, else ``IMAGE_FEATURE_LEN``.
        * falsy: a 1-D string placeholder. If ``./test.jpg`` (checked first)
          or ``/tmp/test.jpg`` exists, ``melt.read_image`` of that file is the
          default value; otherwise the placeholder has no default.

        The image placeholder is added to the 'feed' and 'lfeed' collections;
        the int64 ``[batch, TEXT_MAX_WORDS]`` text placeholder to 'rfeed'.

        Raises:
            ValueError: if no test image was found and
                ``FLAGS.image_model_name`` starts with ``'nasnet'``. (Previously
                an ``assert``, which is silently stripped under ``python -O``.)
        """
        melt.PredictorBase.__init__(self)
        ShowAndTell.__init__(self, is_training=False, is_predict=True)

        if FLAGS.pre_calc_image_feature:
            self.image_feature_len = FLAGS.image_feature_len or IMAGE_FEATURE_LEN
            # TODO: RL needs a feed dict, so predict would otherwise require
            # this input to be fed; a default value is used for now -- open
            # question how best to combine that with placeholder_with_default.
            self.image_feature_feed = tf.placeholder_with_default(
                [[0.] * self.image_feature_len],
                [None, self.image_feature_len],
                name='image_feature')
        else:
            # HACK: nasnet needs a default image here (original note: "due to
            # using average decay"), so look for a sample image on disk.
            # The first existing path wins, even if reading it yields None.
            test_image = None
            for test_image_path in ('./test.jpg', '/tmp/test.jpg'):
                if os.path.exists(test_image_path):
                    test_image = melt.read_image(test_image_path)
                    break

            if test_image is not None:
                self.image_feature_feed = tf.placeholder_with_default(
                    tf.constant([test_image]), [None], name='image_feature')
            else:
                # Explicit raise instead of assert so the check survives -O.
                if FLAGS.image_model_name.startswith('nasnet'):
                    raise ValueError(
                        'HACK for nasnet you need one test.jpg in current path or /tmp/ path')
                self.image_feature_feed = tf.placeholder(
                    tf.string, [None], name='image_feature')

        tf.add_to_collection('feed', self.image_feature_feed)
        tf.add_to_collection('lfeed', self.image_feature_feed)

        self.text_feed = tf.placeholder(tf.int64, [None, TEXT_MAX_WORDS],
                                        name='text')
        tf.add_to_collection('rfeed', self.text_feed)

        # Output slots; start unset.
        self.text = None
        self.text_score = None
        self.beam_text = None
        self.beam_text_score = None

        self.image_model = None

        # History flags; both start disabled.
        self.logprobs_history = False
        self.alignment_history = False

        # NOTE(review): presumably extra feeds merged at run time -- confirm.
        self.feed_dict = {}
コード例 #4
0
ファイル: algos_factory.py プロジェクト: Hibbert-pku/hasky
def _gen_builder(algo, is_predict=True):
  """Instantiate the model object registered for ``algo``.

  Args:
    algo: one of ``Algos.bow``, ``Algos.show_and_tell``, ``Algos.rnn`` or
      ``Algos.seq2seq``.
    is_predict: set to False if only training is needed (no predict/eval);
      the plain model class is then built instead of its ``*Predictor``.

  Returns:
    A freshly constructed model/predictor instance.

  Raises:
    ValueError: if ``algo`` matches none of the supported values.
  """
  if is_predict:
    registry = (
        (Algos.bow, BowPredictor),
        (Algos.show_and_tell, ShowAndTellPredictor),
        (Algos.rnn, RnnPredictor),
        (Algos.seq2seq, Seq2seqPredictor),
    )
  else:
    registry = (
        (Algos.bow, Bow),
        (Algos.show_and_tell, ShowAndTell),
        (Algos.rnn, Rnn),
        (Algos.seq2seq, Seq2seq),
    )
  # Ordered scan with ``==`` so matching behaves like the comparison chain.
  for candidate, model_cls in registry:
    if algo == candidate:
      return model_cls()
  raise ValueError('Unsupported algo %s'%algo)
コード例 #5
0
    def __init__(self):
        """Initialise the predictor in inference mode and build its input placeholders.

        ``self.image_feature_len`` is ``IMAGE_FEATURE_LEN`` when
        ``FLAGS.pre_calc_image_feature`` is truthy and 2048 otherwise. The image
        placeholder is float32 ``[batch, image_feature_len]`` in the first case
        and a 1-D string tensor in the second; the text placeholder is int64
        ``[batch, TEXT_MAX_WORDS]``.
        """
        # Both bases are initialised explicitly rather than via super().
        melt.PredictorBase.__init__(self)
        ShowAndTell.__init__(self, is_training=False, is_predict=True)

        use_precomputed = FLAGS.pre_calc_image_feature
        # NOTE(review): 2048 is hard-coded for the raw-image path; presumably
        # the image model's output width -- confirm.
        self.image_feature_len = IMAGE_FEATURE_LEN if use_precomputed else 2048

        if use_precomputed:
            image_feed = tf.placeholder(
                tf.float32, [None, self.image_feature_len],
                name='image_feature')
        else:
            image_feed = tf.placeholder(tf.string, [None], name='image_feature')
        self.image_feature_feed = image_feed

        self.text_feed = tf.placeholder(
            tf.int64, [None, TEXT_MAX_WORDS], name='text')
コード例 #6
0
def _gen_trainer(algo):
    """Return a new trainer instance for ``algo``.

    Supported values are ``Algos.bow``, ``show_and_tell``, ``rnn``,
    ``pooling``, ``seq2seq`` and ``imtxt2txt``, each mapped to its class.

    Raises:
        ValueError: if ``algo`` matches none of the supported values.
    """
    registry = (
        (Algos.bow, Bow),
        (Algos.show_and_tell, ShowAndTell),
        (Algos.rnn, Rnn),
        (Algos.pooling, Pooling),
        (Algos.seq2seq, Seq2seq),
        (Algos.imtxt2txt, Imtxt2txt),
    )
    # Ordered scan with ``==`` (not a dict) keeps the comparison semantics.
    for candidate, trainer_cls in registry:
        if algo == candidate:
            return trainer_cls()
    raise ValueError('Unsupported algo %s' % algo)
コード例 #7
0
ファイル: algos_factory.py プロジェクト: tangqiqi123/hasky
def _gen_trainer(algo):
    """Return a new trainer instance for ``algo``.

    The bow/rnn/cnn variants build a ``DiscriminantTrainer`` and the dual_*
    variants a ``DualTextsim``, each with the matching encoder name.
    ``seq2seq``, ``show_and_tell`` and ``imtxt2txt`` map to their own classes.

    Raises:
        ValueError: if ``algo`` matches none of the supported values.
    """
    # Zero-argument factories, scanned in order with ``==``.
    factories = (
        (Algos.bow, lambda: DiscriminantTrainer('bow')),
        (Algos.rnn, lambda: DiscriminantTrainer('rnn')),
        (Algos.cnn, lambda: DiscriminantTrainer('cnn')),
        (Algos.seq2seq, Seq2seq),
        (Algos.show_and_tell, ShowAndTell),
        (Algos.imtxt2txt, Imtxt2txt),
        (Algos.dual_bow, lambda: DualTextsim('bow')),
        (Algos.dual_rnn, lambda: DualTextsim('rnn')),
        (Algos.dual_cnn, lambda: DualTextsim('cnn')),
    )
    for candidate, make_trainer in factories:
        if algo == candidate:
            return make_trainer()
    raise ValueError('Unsupported algo %s' % algo)