Code example #1
0
    def build_q_discrete(self, cfg, name, embedding_length):
        """Build the q function for discrete action space.

        The state is summarized by mean-pooling conv features over all
        object pairs; the pooled summary is gated by a sigmoid projection
        of the goal embedding before the final Q head.

        Args:
            cfg: config exposing `obs_dim`, `ac_dim` and
                `descriptor_length` attributes.
            name: name for the returned Keras model.
            embedding_length: dimensionality of the goal embedding input.

        Returns:
            tf.keras.Model mapping {'state_input', 'goal_embedding'} to
            Q values of shape [B, ac_dim].
        """
        ac_dim = cfg.ac_dim[0]
        des_len = cfg.descriptor_length

        inputs = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))
        # shape must be a tuple; `(embedding_length)` was just an int.
        goal_embedding = tf.keras.layers.Input(shape=(embedding_length,))

        # Pairwise combination of object vectors, reduced by a conv stack.
        tp_concat = vector_tensor_product(inputs, inputs)
        conv_layer_cfg = [[des_len * 8, 3, 1], [des_len * 4, 3, 1],
                          [des_len * 4, 1, 1]]
        # [B, ?, ?, des_len * 4]
        tp_concat = stack_conv_layer(conv_layer_cfg)(tp_concat)
        summary = tf.reduce_mean(tp_concat, axis=(1, 2))  # [B, des_len * 4]

        # Bug fix: the gating signal must come from the goal embedding.
        # Previously the sigmoid projection was applied to `summary`
        # itself, which left `goal_embedding` disconnected from the
        # output and the Q function not goal-conditioned.
        goal_projection_layer = tf.keras.layers.Dense(des_len * 4,
                                                      activation='sigmoid')
        gating = goal_projection_layer(goal_embedding)

        gated_summary = summary * gating
        out_layer = tf.keras.Sequential(layers=[
            tf.keras.layers.Dense(100, activation='relu'),
            tf.keras.layers.Dense(ac_dim),
        ])
        out = out_layer(gated_summary)
        all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
        q_out_layer = tf.keras.Model(name=name, inputs=all_inputs, outputs=out)
        return q_out_layer
Code example #2
0
    def _build_q_perfect(self, cfg, name):
        """Build the q function for perfect action space.

        Pairwise object descriptors are pooled with a learned attention
        over all descriptor positions; the pooled summary is tiled back
        onto every object and 1x1 convs produce per-object action values.
        """
        per_input_ac_dim = cfg.per_input_ac_dim
        des_len = cfg.descriptor_length

        state_in = tf.keras.layers.Input(
            shape=(cfg.obs_dim[0], cfg.obs_dim[1]))

        # Number of objects is dynamic (axis 1 of the state input).
        dynamic_shape = tf.keras.layers.Lambda(tf.shape)
        n_objects = dynamic_shape(state_in)[1]

        # Pairwise combination of object vectors -> descriptors.
        pairwise_raw = vector_tensor_product(state_in, state_in)
        descriptor_cfg = [[des_len * 8, 1, 1], [des_len * 4, 1, 1],
                          [des_len, 1, 1]]
        # [B, ?, ?, des_len]
        descriptors = stack_conv_layer(descriptor_cfg)(pairwise_raw)

        expand_at = tf.keras.layers.Lambda(
            lambda args: tf.expand_dims(args[0], axis=args[1]))
        pair_matmul = tf.keras.layers.Lambda(
            lambda args: tf.matmul(args[0], args[1]))
        tile_to = tf.keras.layers.Lambda(
            lambda args: tf.tile(args[0], multiples=args[1]))

        # Attention logits over descriptor positions. [B, ?, ?, 1]
        logits_cfg = [[32, 3, 1], [16, 3, 1], [1, 1, 1]]
        attn_logits = stack_conv_layer(logits_cfg)(descriptors)
        # [B, 1, ?*?]
        attn_logits = tf.keras.layers.Reshape((1, -1))(attn_logits)
        attn_weights = tf.keras.layers.Softmax()(attn_logits)  # [B, 1, ?*?]
        flat_descriptors = tf.keras.layers.Reshape(
            (-1, des_len))(descriptors)  # [B, ?*?, des_len]
        # Attention-weighted pooling. [B, 1, des_len]
        pooled = pair_matmul((attn_weights, flat_descriptors))

        # Repeat the pooled summary for every object. [B, ?, des_len]
        per_object_summary = tile_to((pooled, [1, n_objects, 1]))
        # [B, ?, des_len+di]  (di = per-object input feature size)
        augmented = tf.keras.layers.Concatenate(axis=-1)(
            [state_in, per_object_summary])
        # Add a singleton spatial axis so 1x1 convs act as a per-object
        # MLP. [B, ?, 1, des_len+di]
        augmented = expand_at((augmented, 2))
        head_cfg = [[per_input_ac_dim * 64, 1, 1],
                    [per_input_ac_dim * 64, 1, 1],
                    [per_input_ac_dim, 1, 1]]
        # [B, ?, per_input_ac_dim]
        q_out = tf.keras.layers.Reshape((-1, per_input_ac_dim))(
            stack_conv_layer(head_cfg)(augmented))
        return tf.keras.Model(name=name, inputs=state_in, outputs=q_out)
Code example #3
0
  def _build_q_discrete(self, cfg, name):
    """Build the q function for discrete action space.

    Uses a fixed number of attention heads; each head runs its own conv
    stack over per-object features, computes softmax weights across
    objects, and pools a weighted sum. Head outputs are concatenated and
    passed through a final dense layer.

    Args:
      cfg: config exposing `obs_dim` and `ac_dim` attributes.
      name: name for the returned Keras model.

    Returns:
      tf.keras.Model mapping the state input to Q values [B, ac_dim].
    """
    ac_dim = cfg.ac_dim
    num_heads = 8

    inputs = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))

    # Singleton spatial axis so 1x1 convs act per object. [B, ?, 1, di]
    augmented_inputs = tf.expand_dims(inputs, axis=2)
    # Fix: do not shadow the `cfg` parameter with the conv config (the
    # original reassigned `cfg`, hiding the config object mid-function).
    conv_layer_cfg = [[ac_dim // num_heads, 1, 1],
                      [ac_dim // num_heads, 1, 1]]
    heads = []
    for _ in range(num_heads):
      # Each head gets its own conv stack. [B, ?, 1, ac_dim//8]
      head_out = stack_conv_layer(conv_layer_cfg)(augmented_inputs)
      weights = tf.keras.layers.Conv2D(1, 1, 1)(head_out)  # [B, ?, 1, 1]
      # Attention over the object axis.
      softmax_weights = tf.nn.softmax(weights, axis=1)  # [B, ?, 1, 1]
      # Weighted pooling over objects. [B, ac_dim//8]
      heads.append(tf.reduce_sum(softmax_weights * head_out, axis=(1, 2)))
    # heads = 8 X [B, ac_dim//8]
    out = tf.concat(heads, axis=1)  # [B, ac_dim]
    out = tf.keras.layers.Dense(ac_dim)(out)
    q_out_layer = tf.keras.Model(name=name, inputs=inputs, outputs=out)
    return q_out_layer
Code example #4
0
    def build_q_perfect(self, cfg, name, embedding_length):
        """Build the q function for perfect action space.

        Goal-conditioned variant: pairwise object descriptors are pooled
        with attention keyed by the goal embedding, the pooled summary and
        the goal are tiled onto every object, and 1x1 convs produce
        per-object action values.

        Args:
            cfg: config exposing `obs_dim`, `per_input_ac_dim`,
                `descriptor_length` and `inner_product_length` attributes.
            name: name for the returned Keras model.
            embedding_length: dimensionality of the goal embedding input.

        Returns:
            tf.keras.Model mapping {'state_input', 'goal_embedding'} to a
            per-object Q tensor of shape [B, num_objects, per_input_ac_dim].
        """
        per_input_ac_dim = cfg.per_input_ac_dim
        des_len, inner_len = cfg.descriptor_length, cfg.inner_product_length

        inputs = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))
        goal_embedding = tf.keras.layers.Input(shape=(embedding_length))

        # Dynamic object count (axis 1 of the state input). tf.shape
        # already returns int32, so the cast is effectively a no-op.
        shape_layer = tf.keras.layers.Lambda(tf.shape)
        num_object = tf.cast(shape_layer(inputs)[1], tf.int32)
        # Pairwise combination of object vectors -> descriptor conv stack.
        tp_concat = vector_tensor_product(inputs, inputs)
        conv_layer_cfg = [[des_len * 8, 1, 1], [des_len * 4, 1, 1],
                          [des_len, 1, 1]]
        # [B, ?, ?, des_len]
        tp_concat = stack_conv_layer(conv_layer_cfg)(tp_concat)

        # Attention key derived from the goal; queries from descriptors.
        goal_key = stack_dense_layer([inner_len * 2, inner_len
                                      ])(goal_embedding)  # [B, d_inner]
        expand_dims_layer = tf.keras.layers.Lambda(
            lambda inputs: tf.expand_dims(inputs[0], axis=inputs[1]))
        goal_key = expand_dims_layer((goal_key, 1))  # [B, 1, d_inner]
        # [B, ?, ?, d_inner]
        obs_query_layer = tf.keras.layers.Conv2D(inner_len,
                                                 1,
                                                 1,
                                                 padding='same')
        obs_query = obs_query_layer(tp_concat)
        # [B, ?*?, d_inner]
        obs_query = tf.keras.layers.Reshape((-1, inner_len))(obs_query)
        obs_query_t = tf.keras.layers.Permute(
            (2, 1))(obs_query)  # [B,d_inner,?*?]
        matmul_layer = tf.keras.layers.Lambda(
            lambda inputs: tf.matmul(inputs[0], inputs[1]))
        # Inner product of goal key with every descriptor position.
        inner = matmul_layer((goal_key, obs_query_t))  # [B, 1, ?*?]
        weight = tf.keras.layers.Softmax()(inner)  # [B, 1, ?*?]
        # Attention-weighted pooling of the descriptors.
        prod = matmul_layer((weight, tf.keras.layers.Reshape(
            (-1, des_len))(tp_concat)))  # [B, 1, des_len]

        goal_embedding_ = expand_dims_layer((goal_embedding, 1))  # [B, 1, dg]

        tile_layer = tf.keras.layers.Lambda(
            lambda inputs: tf.tile(inputs[0], multiples=inputs[1]))
        # Broadcast the goal and the pooled summary to every object.
        # [B, ?, dg]
        goal_embedding_ = tile_layer((goal_embedding_, [1, num_object, 1]))
        # [B, ?, des_len]
        pair_wise_summary = tile_layer((prod, [1, num_object, 1]))
        # [B, ?, des_len+di+dg]  (di = per-object input feature size)
        augemented_inputs = tf.keras.layers.Concatenate(axis=-1)(
            [inputs, pair_wise_summary, goal_embedding_])
        # Singleton spatial axis so the 1x1 convs below act as a
        # per-object MLP. [B, ?, 1, des_len+di+dg]
        augemented_inputs = expand_dims_layer((augemented_inputs, 2))
        conv_layer_cfg = [[per_input_ac_dim * 64, 1, 1],
                          [per_input_ac_dim * 64, 1, 1],
                          [per_input_ac_dim, 1, 1]]
        # [B, ?, per_input_ac_dim]
        q_out = tf.keras.layers.Reshape((-1, per_input_ac_dim))(
            stack_conv_layer(conv_layer_cfg)(augemented_inputs))
        all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
        q_out_layer = tf.keras.Model(name=name,
                                     inputs=all_inputs,
                                     outputs=q_out)

        return q_out_layer