Esempio n. 1
0
  def _process_step_inputs(self, inputs, maxlen=None):
    """Batch the per-step observations of several sequences.

    Args:
      inputs: a list of sequences. Every step of a sequence is indexed as
        `step[0]`, an object exposing `read_ind`, `write_ind` and
        `valid_indices` (e.g. a MemoryInputTuple), and `step[1]`, a list of
        features that is padded alongside the valid indices.
      maxlen: forwarded unchanged to `data_utils.convert_seqs_to_batch`;
        presumably the maximum program length (confirm against that helper).

    Returns:
      A pair `((processed_inputs, sequence_length),
      (output_feature_batch, sequence_length))`. `processed_inputs` is a
      `tf_utils.MemoryInputTuple` built from the batched read indices, write
      indices and valid indices; `sequence_length` is the second value
      returned by `convert_seqs_to_batch` for the read indices.
    """

    def gather(per_step):
      # One row per sequence, one entry per step of that sequence.
      return np.array([[per_step(step) for step in seq] for seq in inputs])

    def to_batch(seqs):
      return data_utils.convert_seqs_to_batch(seqs, maxlen)

    read_ind = gather(lambda step: step[0].read_ind)
    write_ind = gather(lambda step: step[0].write_ind)
    # Valid indices are padded with -1, features with [0], both up to
    # self.max_n_valid_indices entries per step.
    valid_indices = gather(
        lambda step: _pad_list(
            step[0].valid_indices, -1, self.max_n_valid_indices))
    output_features = gather(
        lambda step: _pad_list(step[1], [0], self.max_n_valid_indices))

    read_ind_batch, sequence_length = to_batch(read_ind)
    output_feature_batch, _ = to_batch(output_features)
    write_ind_batch, _ = to_batch(write_ind)
    valid_indices_batch, _ = to_batch(valid_indices)

    processed_inputs = tf_utils.MemoryInputTuple(
        read_ind_batch, write_ind_batch, valid_indices_batch)
    encoder_side = (processed_inputs, sequence_length)
    feature_side = (output_feature_batch, sequence_length)
    return encoder_side, feature_side
 def reset(self):
     """Clear the episode history and rebuild the starting observation.

     Empties `actions`, `mapped_actions` and `rewards`, clears `done`, and
     recomputes `valid_actions` from the interpreter's currently valid
     tokens. When `use_cache` is set, any action whose one-token program is
     already reported by `cache.check` is dropped. The resulting start
     observation is stored in `start_ob` and becomes the only entry of
     `obs`.
     """
     self.actions = []
     self.mapped_actions = []
     self.rewards = []
     self.done = False

     candidates = self.de_vocab.lookup(self.interpreter.valid_tokens())
     if self.use_cache:
         # Keep only the actions whose resulting program is not cached yet.
         candidates = [
             token_id for token_id in candidates
             if not self.cache.check(
                 self.de_vocab.lookup(self.mapped_actions + [token_id],
                                      reverse=True))
         ]
     self.valid_actions = candidates

     first_input = tf_utils.MemoryInputTuple(
         self.de_vocab.decode_id, -1, candidates)
     first_features = [self.id_feature_dict[token_id]
                       for token_id in candidates]
     self.start_ob = (first_input, first_features)
     self.obs = [self.start_ob]
    def step(self, action, debug=False):
        """Apply one action to the interpreter and return the new observation.

        Args:
          action: index into `self.valid_actions` (not a vocabulary id). It
            must satisfy `0 <= action < len(self.valid_actions)`.
          debug: when true, print the valid actions, the chosen index, the
            action history and the collected observations before stepping.

        Returns:
          A tuple `(ob, reward, done, info)`. `ob` is a pair of a
          `tf_utils.MemoryInputTuple` and the feature list for the valid
          actions after cache filtering; `reward` is the score of the
          finished program or 0.0; `done` mirrors `self.done`; `info` is
          always an empty dict.

        Raises:
          IndexError: if `action` is outside the valid range. Negative
            indices are rejected as well instead of silently wrapping
            around to the end of `self.valid_actions`.
        """
        self.actions.append(action)
        if debug:
            print('-' * 50)
            print(self.de_vocab.lookup(self.valid_actions, reverse=True))
            print('pick #{} valid action'.format(action))
            print('history:')
            print(self.de_vocab.lookup(self.mapped_actions, reverse=True))
            # print('env: {}, cache size: {}'.format(self.name, len(self.cache._set)))
            print('obs')
            pprint.pprint(self.obs)

        if 0 <= action < len(self.valid_actions):
            mapped_action = self.valid_actions[action]
        else:
            print('-' * 50)
            # print('env: {}, cache size: {}'.format(self.name, len(self.cache._set)))
            print('action out of range.')
            print('action:')
            print(action)
            print('valid actions:')
            print(self.de_vocab.lookup(self.valid_actions, reverse=True))
            print('pick #{} valid action'.format(action))
            print('history:')
            print(self.de_vocab.lookup(self.mapped_actions, reverse=True))
            print('obs')
            pprint.pprint(self.obs)
            print('-' * 50)
            # Fail explicitly: plain list indexing would accept a negative
            # index and continue with an action the caller never selected.
            raise IndexError(
                'action {} out of range for {} valid actions'.format(
                    action, len(self.valid_actions)))

        self.mapped_actions.append(mapped_action)

        result = self.interpreter.read_token(
            self.de_vocab.lookup(mapped_action, reverse=True))

        self.done = self.interpreter.done
        # Only when the program is finished and it doesn't have
        # extra work or we don't care, its result will be
        # scored, and the score will be used as reward.
        if self.done and not (self.punish_extra_work
                              and self.interpreter.has_extra_work()):
            reward = self.score_fn(self.interpreter.result, self.answer)
        else:
            reward = 0.0

        if self.done and self.interpreter.result == [
                computer_factory.ERROR_TK
        ]:
            self.error = True

        # No variable is written when the token produced no result or the
        # program has finished; -1 marks that case.
        if result is None or self.done:
            new_var_id = -1
        else:
            new_var_id = self.de_vocab.lookup(
                self.interpreter.namespace.last_var)
        valid_tokens = self.interpreter.valid_tokens()
        valid_actions = self.de_vocab.lookup(valid_tokens)

        # For each action, check the cache for the program, if
        # already tried, then not valid anymore.
        cached_actions = []
        if self.use_cache:
            new_valid_actions = []
            partial_program = self.de_vocab.lookup(self.mapped_actions,
                                                   reverse=True)
            for ma in valid_actions:
                new_program = partial_program + [
                    self.de_vocab.lookup(ma, reverse=True)
                ]
                if not self.cache.check(new_program):
                    new_valid_actions.append(ma)
                else:
                    cached_actions.append(ma)
            valid_actions = new_valid_actions

        self.valid_actions = valid_actions
        self.rewards.append(reward)
        # The returned observation only lists the actions that survived the
        # cache filter.
        ob = (tf_utils.MemoryInputTuple(read_ind=mapped_action,
                                        write_ind=new_var_id,
                                        valid_indices=self.valid_actions),
              [self.id_feature_dict[a] for a in valid_actions])

        # If no valid actions are available, then stop.
        if not self.valid_actions:
            self.done = True
            self.error = True

        # If the program is not finished yet, collect the
        # observation.
        if not self.done:
            # Add the actions that are filtered by cache into the
            # training example because at test time, they will be
            # there (no cache is available).
            if self.use_cache:
                valid_actions = self.valid_actions + cached_actions
                true_ob = (tf_utils.MemoryInputTuple(
                    read_ind=mapped_action,
                    write_ind=new_var_id,
                    valid_indices=valid_actions),
                           [self.id_feature_dict[a] for a in valid_actions])
                self.obs.append(true_ob)
            else:
                self.obs.append(ob)
        elif self.use_cache:
            # If already finished, save it in the cache.
            self.cache.save(
                self.de_vocab.lookup(self.mapped_actions, reverse=True))

        return ob, reward, self.done, {}