Example #1
0
  def __init__(self, cfg):
    """Set up run state, vocabulary encode/decode helpers and, optionally, a labeler.

    Args:
      cfg: configuration object. This constructor reads its
        `use_labeler_as_reward`, `use_oracle_instruction`,
        `use_synonym_for_rollout` and `max_sequence_length` attributes, and
        also passes it to `get_vocab_path` and `get_labeler_config`.

    Note:
      `self.labeler` is only assigned when a labeler is required (see the
      condition below); otherwise the attribute is not created here.
    """
    # Run bookkeeping: keep the config and start step / epsilon at their
    # initial values.
    self.cfg = cfg
    self.step = 0
    self.epsilon = 1.0

    # Feature switches copied straight off the config.
    self._use_labeler_as_reward = cfg.use_labeler_as_reward
    self._use_oracle_instruction = cfg.use_oracle_instruction
    self._use_synonym_for_rollout = cfg.use_synonym_for_rollout

    # Vocabulary: the first loaded entry is discarded and the special tokens
    # 'eos', 'sos', 'nothing' take the first three slots, in that order.
    # NOTE(review): presumably entry 0 of the file is a placeholder token —
    # confirm against the vocab file format.
    raw_vocab = wv.load_vocab_list(get_vocab_path(cfg))
    self.vocab_list = ['eos', 'sos', 'nothing'] + raw_vocab[1:]

    word_to_id, id_to_word = wv.create_look_up_table(self.vocab_list)
    self.encode_fn = wv.encode_text_with_lookup_table(
        word_to_id, max_sequence_length=self.cfg.max_sequence_length)
    self.decode_fn = wv.decode_with_lookup_table(id_to_word)

    # Built from the final (special-token-prefixed) vocabulary.
    labeler_config = get_labeler_config(cfg, self.vocab_list)

    # A labeler is required when it supplies the reward, or when oracle
    # instructions are disabled.
    needs_labeler = (self._use_labeler_as_reward
                     or not self._use_oracle_instruction)
    if needs_labeler:
      self.labeler = Labeler(labeler_config=labeler_config)
      # Both sub-models are configured from the same config, each with its
      # own weight path taken from that config.
      self.labeler.set_captioning_model(
          labeler_config, labeler_config['captioning_weight_path'])
      self.labeler.set_answering_model(
          labeler_config, labeler_config['answering_weight_path'])
Example #2
0
  def __init__(self, cfg) -> None:
    """Initialize run counters and the vocabulary encode/decode helpers.

    NOTE(review): this method may continue beyond the lines shown here and
    set up additional state.

    Args:
      cfg: configuration object. It is passed to `get_vocab_path`, and its
        `max_sequence_length` attribute is read (via `self.cfg`).
    """
    # Keep the config and start step / epsilon at their initial values.
    # NOTE(review): epsilon presumably is an exploration rate — confirm.
    self.cfg = cfg
    self.step = 0
    self.epsilon = 1.0

    # Vocab loading: the loaded list is used unchanged to build the
    # word<->index lookup tables (v2i / i2v).
    vocab_path = get_vocab_path(cfg)
    self.vocab_list = wv.load_vocab_list(vocab_path)
    v2i, i2v = wv.create_look_up_table(self.vocab_list)
    # Encoder is bounded by the configured max_sequence_length.
    self.encode_fn = wv.encode_text_with_lookup_table(
        v2i, max_sequence_length=self.cfg.max_sequence_length)
    self.decode_fn = wv.decode_with_lookup_table(i2v)