def build_q_perfect(self, cfg, goal_embedding):
    """Build the Q-function for perfect action space.

    Args:
      cfg: configuration of the experiments
      goal_embedding: embedding tensor of the instructions

    Returns:
      q_out: output q values
      input_select: object with maximum q value
      action_select: actions with maximum q value
    """
    per_input_ac_dim = cfg.per_input_ac_dim
    des_len, inner_len = cfg.descriptor_length, cfg.inner_product_length

    num_object = tf.shape(self.inputs)[1]
    tp_concat = vector_tensor_product(self.inputs, self.inputs)
    conv_layer_cfg = [[des_len * 8, 1, 1], [des_len * 4, 1, 1], [des_len, 1, 1]]
    # [B, ?, ?, des_len]
    tp_concat = stack_conv_layer(tp_concat, conv_layer_cfg)
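    # each [i, j] cell of tp_concat now holds a des_len-dim descriptor of the
    # ordered object pair (i, j)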

    # similarity with goal
    goal_key = stack_dense_layer(
        goal_embedding, [inner_len*2, inner_len])  # [B, d_inner]
    goal_key = tf.expand_dims(goal_key, 1)  # [B, 1, d_inner]
    # [B, ?, ?, d_inner]
    obs_query = tf.layers.conv2d(tp_concat, inner_len, 1, padding='same')
    # [B, ?*?, d_inner]
    obs_query = tf.reshape(obs_query, [-1, num_object**2, inner_len])
    obs_query_t = tf.transpose(obs_query, perm=(0, 2, 1))  # [B, d_inner, ?*?]
    inner = tf.matmul(goal_key, obs_query_t)  # [B, 1, ?*?]
    weight = tf.nn.softmax(inner, axis=-1)  # [B, 1, ?*?]
    prod = tf.matmul(
        weight,
        tf.reshape(tp_concat, [-1, num_object**2, des_len]))  # [B, 1, des_len]
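    # goal-conditioned attention: the goal embedding is projected into a key,
    # matched against a query for every object pair, and the softmax weights
    # average all pair descriptors into a single per-batch summary vector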

    goal_embedding_ = tf.expand_dims(goal_embedding, 1)  # [B, 1, dg]
    # [B, ?, dg]
    goal_embedding_ = tf.tile(goal_embedding_, multiples=[1, num_object, 1])
    # [B, ?, des_len]
    pair_wise_summary = tf.tile(prod, multiples=[1, num_object, 1])
    # [B, ?, des_len+di+dg]
    augmented_inputs = tf.concat(
        [self.inputs, pair_wise_summary, goal_embedding_], axis=-1)
    # [B, ?, 1, des_len+di+dg]
    augmented_inputs = tf.expand_dims(augmented_inputs, axis=2)
    conv_layer_cfg = [
        [per_input_ac_dim*64, 1, 1],
        [per_input_ac_dim*64, 1, 1],
        [per_input_ac_dim, 1, 1]
    ]
    # [B, ?, per_input_ac_dim]
    q_out = tf.squeeze(
        stack_conv_layer(augmented_inputs, conv_layer_cfg), axis=2)
    input_max_q = tf.reduce_max(q_out, axis=2)  # best action value per object
    input_select = tf.argmax(input_max_q, axis=1)  # object with highest value
    action_max_q = tf.reduce_max(q_out, axis=1)  # best object value per action
    action_select = tf.argmax(action_max_q, axis=1)  # action with highest value

    return q_out, input_select, action_select
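
# The three helpers used above (vector_tensor_product, stack_conv_layer,
# stack_dense_layer) are not defined in this snippet. A minimal sketch of
# plausible implementations, assuming each conv cfg entry reads
# [filters, kernel_size, strides] and each stack leaves its last layer linear:


def vector_tensor_product(a, b):
    # a: [B, N, da], b: [B, N, db] -> every ordered pairing, [B, N, N, da+db]
    n = tf.shape(a)[1]
    a_tiled = tf.tile(tf.expand_dims(a, 2), [1, 1, n, 1])  # [B, N, N, da]
    b_tiled = tf.tile(tf.expand_dims(b, 1), [1, n, 1, 1])  # [B, N, N, db]
    return tf.concat([a_tiled, b_tiled], axis=-1)


def stack_conv_layer(inputs, cfgs, activation=tf.nn.relu):
    # stack of 2D convolutions; the final layer is left linear
    out = inputs
    for i, (filters, kernel, stride) in enumerate(cfgs):
        act = activation if i < len(cfgs) - 1 else None
        out = tf.layers.conv2d(
            out, filters, kernel, stride, padding='same', activation=act)
    return out


def stack_dense_layer(inputs, units_list, activation=tf.nn.relu):
    # stack of dense layers; the final layer is left linear
    out = inputs
    for i, units in enumerate(units_list):
        act = activation if i < len(units_list) - 1 else None
        out = tf.layers.dense(out, units, activation=act)
    return out
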
def build_q_discrete(self, cfg, goal_embedding):
    """Builds the Q-function for the discrete action space.

    Args:
      cfg: experiment configuration
      goal_embedding: embedding tensor of the instructions

    Returns:
      [B, ac_dim] tensor of Q-values for all actions.
    """
    ac_dim = cfg.ac_dim
    des_len, inner_len = cfg.descriptor_length, cfg.inner_product_length

    num_object = tf.shape(self.inputs)[1]
    tp_concat = vector_tensor_product(self.inputs, self.inputs)
    conv_layer_cfg = [[des_len * 8, 1, 1], [des_len * 4, 1, 1], [des_len, 1, 1]]
    # [B, ?, ?, des_len]
    tp_concat = stack_conv_layer(tp_concat, conv_layer_cfg)

    # similarity with goal
    goal_key = stack_dense_layer(
        goal_embedding, [inner_len*2, inner_len])  # [B, d_inner]
    goal_key = tf.expand_dims(goal_key, 1)  # [B, 1, d_inner]
    # [B, ?, ?, d_inner]
    obs_query = tf.layers.conv2d(tp_concat, inner_len, 1, padding='same')
    # [B, ?*?, d_inner]
    obs_query = tf.reshape(obs_query, [-1, num_object**2, inner_len])
    obs_query_t = tf.transpose(obs_query, perm=(0, 2, 1))  # [B, d_inner, ?*?]
    inner = tf.matmul(goal_key, obs_query_t)  # [B, 1, ?*?]
    weight = tf.nn.softmax(inner, axis=-1)  # [B, 1, ?*?]
    prod = tf.matmul(
        weight,
        tf.reshape(tp_concat, [-1, num_object**2, des_len]))  # [B, 1, des_len]

    goal_embedding_ = tf.expand_dims(goal_embedding, 1)  # [B, 1, dg]
    # [B, ?, dg]
    goal_embedding_ = tf.tile(goal_embedding_, multiples=[1, num_object, 1])
    # [B, ?, des_len]
    pair_wise_summary = tf.tile(prod, multiples=[1, num_object, 1])
    # [B, ?, des_len+di+dg]
    augmented_inputs = tf.concat(
        [self.inputs, pair_wise_summary, goal_embedding_], axis=-1)
    # [B, ?, 1, des_len+di+dg]
    augmented_inputs = tf.expand_dims(augmented_inputs, axis=2)
    # 8 attention heads, each pooling the per-object features into ac_dim//8
    # values; named head_cfg so the cfg argument is not shadowed
    head_cfg = [[ac_dim // 8, 1, 1], [ac_dim // 8, 1, 1]]
    heads = []
    for _ in range(8):
        # [B, ?, 1, ac_dim//8]
        head_out = stack_conv_layer(augmented_inputs, head_cfg)
        weights = tf.layers.conv2d(head_out, 1, 1)  # [B, ?, 1, 1]
        softmax_weights = tf.nn.softmax(weights, axis=1)  # [B, ?, 1, 1]
        # attention-weighted sum over objects -> [B, ac_dim//8]
        heads.append(tf.reduce_sum(softmax_weights * head_out, axis=(1, 2)))
    # heads = 8 x [B, ac_dim//8]
    out = tf.concat(heads, axis=1)  # [B, ac_dim]
    return tf.layers.dense(out, ac_dim)
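
# The builders in these snippets also assume an `encoder` function and two
# placeholders on the enclosing class, none of which are defined here. A
# minimal sketch of those assumptions:
#   self.inputs      float placeholder, [B, num_object, di] per-object features
#   self.word_inputs int placeholder for the instruction (token ids, or a
#                    single goal index in the one_hot case)
def encoder(word_inputs, embedding, n_unit, trainable=True):
    # a plausible stand-in: embed the tokens, run a GRU over them, and use the
    # final state as the goal embedding (the real encoder may differ, and the
    # `trainable` gating is elided in this sketch)
    embedded = tf.nn.embedding_lookup(embedding, word_inputs)
    cell = tf.nn.rnn_cell.GRUCell(n_unit)
    outputs, final_state = tf.nn.dynamic_rnn(cell, embedded, dtype=tf.float32)
    return outputs, final_state
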
# Example 3
    def make_policy():
      """Returns one copy of the model."""
      artifact = {}
      if cfg.instruction_repr == 'language':
        trainable_encoder = cfg.trainable_encoder
        print('The encoder is trainable: {}'.format(trainable_encoder))
        embedding = tf.get_variable(
            name='word_embedding',
            shape=(cfg.vocab_size, cfg.embedding_size),
            dtype=tf.float32,
            trainable=trainable_encoder)
        _, goal_embedding = encoder(
            self.word_inputs,
            embedding,
            cfg.encoder_n_unit,
            trainable=trainable_encoder)
        artifact['embedding'] = embedding
      elif cfg.instruction_repr == 'one_hot':
        print('Goal input for one-hot max len {}'.format(
            cfg.max_sequence_length))
        one_hot_goal = tf.one_hot(self.word_inputs, cfg.max_sequence_length)
        one_hot_goal.set_shape([None, cfg.max_sequence_length])
        layer_cfg = [cfg.max_sequence_length // 8, cfg.encoder_n_unit]
        goal_embedding = stack_dense_layer(one_hot_goal, layer_cfg)
      else:
        raise ValueError('Unrecognized instruction type: {}'.format(
            cfg.instruction_repr))
      artifact['goal_embedding'] = goal_embedding
      all_q = self.build_q_factor_discrete(cfg, goal_embedding)

      predict_action = tf.argmax(all_q, axis=-1)
      action = tf.placeholder(shape=None, dtype=tf.int32)
      action_onehot = tf.one_hot(
          action, cfg.ac_dim[0], dtype=tf.float32)
      q = tf.reduce_sum(
          tf.multiply(all_q, action_onehot), axis=1)
      artifact.update(
          {
              'all_q': all_q,
              'predict_action': predict_action,
              'action_ph': action,
              'action_onehot': action_onehot,
              'q': q
          }
      )
      return artifact
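
    # Hedged usage sketch: how the returned artifact is typically wired into a
    # DQN-style TD update. target_q_ph and the optimizer choice are
    # assumptions of this sketch, not part of the snippet above.
    artifact = make_policy()
    target_q_ph = tf.placeholder(tf.float32, shape=(None,))
    td_loss = tf.reduce_mean(tf.square(artifact['q'] - target_q_ph))
    train_op = tf.train.AdamOptimizer(1e-4).minimize(td_loss)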
    def make_policy():
      """Builds one copy of the model."""
      artifact = {}
      if cfg.instruction_repr == 'language':
        trainable_encoder = cfg.trainable_encoder
        print('The encoder is trainable: {}'.format(trainable_encoder))
        embedding = tf.get_variable(
            name='word_embedding',
            shape=(cfg.vocab_size, cfg.embedding_size),
            dtype=tf.float32,
            trainable=trainable_encoder)
        _, goal_embedding = encoder(
            self.word_inputs,
            embedding,
            cfg.encoder_n_unit,
            trainable=trainable_encoder)
        artifact['embedding'] = embedding
      elif cfg.instruction_repr == 'one_hot':
        print('Goal input for one-hot max len {}'.format(
            cfg.max_sequence_length))
        one_hot_goal = tf.one_hot(self.word_inputs, cfg.max_sequence_length)
        one_hot_goal.set_shape([None, cfg.max_sequence_length])
        layer_cfg = [cfg.max_sequence_length // 8, cfg.encoder_n_unit]
        goal_embedding = stack_dense_layer(one_hot_goal, layer_cfg)
      else:
        raise ValueError('Unrecognized instruction type: {}'.format(
            cfg.instruction_repr))
      artifact['goal_embedding'] = goal_embedding

      if cfg.action_type == 'perfect':
        print('using perfect action Q function...')
        all_q, predict_object, predict_object_action = self.build_q_perfect(
            cfg, goal_embedding)
        predict_action = tf.stack(
            [predict_object, predict_object_action], axis=1)
        action = tf.placeholder(shape=(None, 2), dtype=tf.int32)
        # prepend each row's batch index so gather_nd can look up
        # all_q[b, object_index, action_index] per batch element
        stacked_indices = tf.concat(
            [tf.expand_dims(tf.range(0, tf.shape(action)[0]), axis=1), action],
            axis=1)
        q = tf.gather_nd(all_q, stacked_indices)
        artifact.update({
            'all_q': all_q,
            'predict_object': predict_object,
            'predict_object_action': predict_object_action,
            'predict_action': predict_action,
            'action_ph': action,
            'q': q,
        })
      elif cfg.action_type == 'discrete':
        print('using discrete action Q function...')
        all_q = self.build_q_discrete(cfg, goal_embedding)
        predict_action = tf.argmax(all_q, axis=-1)
        action = tf.placeholder(shape=None, dtype=tf.int32)
        action_onehot = tf.one_hot(action, cfg.ac_dim, dtype=tf.float32)
        q = tf.reduce_sum(tf.multiply(all_q, action_onehot), axis=1)
        artifact.update({
            'all_q': all_q,
            'predict_action': predict_action,
            'action_ph': action,
            'action_onehot': action_onehot,
            'q': q,
        })
      else:
        raise ValueError('Unrecognized action type: {}'.format(
            cfg.action_type))
      return artifact
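
# A quick numeric check of the stacked_indices / gather_nd pattern used in the
# perfect-action branch above, in plain NumPy (batch of 2, 3 objects, 2
# per-object actions; the values are illustrative only):
import numpy as np

all_q_val = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
action_val = np.array([[1, 0], [2, 1]])  # (object index, action index) rows
batch_idx = np.arange(2)[:, None]
idx = np.concatenate([batch_idx, action_val], axis=1)  # [[0,1,0], [1,2,1]]
# the same lookup tf.gather_nd(all_q, stacked_indices) performs:
q_val = all_q_val[idx[:, 0], idx[:, 1], idx[:, 2]]  # -> [2., 11.]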