def build_q_discrete(self, cfg, name, embedding_length):
    """Build the q function for discrete action space.

    Builds a functional Keras model that mean-pools pairwise state
    descriptors into a global summary, gates that summary with a
    sigmoid projection of the goal embedding, and maps the result to
    per-action q values.

    Args:
      cfg: config object; reads ac_dim, descriptor_length and obs_dim
        (obs_dim[0] = number of objects, obs_dim[1] = per-object
        feature size -- TODO confirm against callers).
      name: name given to the returned tf.keras.Model.
      embedding_length: length of the goal-embedding input vector.

    Returns:
      tf.keras.Model mapping {'state_input', 'goal_embedding'} to a
      [B, ac_dim] tensor of q values.
    """
    ac_dim = cfg.ac_dim[0]
    des_len = cfg.descriptor_length
    inputs = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))
    # Shape is a proper 1-tuple; previously `(embedding_length)` was a
    # bare int that Keras happened to coerce.
    goal_embedding = tf.keras.layers.Input(shape=(embedding_length,))
    # All object-object pairs -> conv stack -> mean-pooled global summary.
    tp_concat = vector_tensor_product(inputs, inputs)
    conv_layer_cfg = [[des_len * 8, 3, 1], [des_len * 4, 3, 1],
                      [des_len * 4, 1, 1]]
    # [B, ?, ?, des_len*4]
    tp_concat = stack_conv_layer(conv_layer_cfg)(tp_concat)
    summary = tf.reduce_mean(tp_concat, axis=(1, 2))  # [B, des_len*4]
    # BUG FIX: gate the summary with the GOAL embedding. Previously the
    # gating was computed from `summary` itself, which left the
    # goal_embedding input disconnected — the q values did not depend on
    # the goal at all.
    goal_projection_layer = tf.keras.layers.Dense(
        des_len * 4, activation='sigmoid')
    gating = goal_projection_layer(goal_embedding)  # [B, des_len*4]
    gated_summary = summary * gating  # goal-conditioned features
    out_layer = tf.keras.Sequential(layers=[
        tf.keras.layers.Dense(100, activation='relu'),
        tf.keras.layers.Dense(ac_dim),
    ])
    out = out_layer(gated_summary)
    all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
    q_out_layer = tf.keras.Model(name=name, inputs=all_inputs, outputs=out)
    return q_out_layer
def _build_q_perfect(self, cfg, name):
    """Build the q function for perfect action space."""
    act_dim_per_obj = cfg.per_input_ac_dim
    d_len = cfg.descriptor_length

    state_in = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))
    # Object count comes from the dynamic shape of the state input.
    n_obj = tf.keras.layers.Lambda(tf.shape)(state_in)[1]

    # Reusable Lambda layers wrapping raw TF ops for the functional graph.
    expand_dims_layer = tf.keras.layers.Lambda(
        lambda args: tf.expand_dims(args[0], axis=args[1]))
    matmul_layer = tf.keras.layers.Lambda(
        lambda args: tf.matmul(args[0], args[1]))
    tile_layer = tf.keras.layers.Lambda(
        lambda args: tf.tile(args[0], multiples=args[1]))

    # All object-object pairs -> 1x1 conv stack -> one descriptor per pair.
    pair_tensor = vector_tensor_product(state_in, state_in)
    descriptors = stack_conv_layer(
        [[d_len * 8, 1, 1], [d_len * 4, 1, 1],
         [d_len, 1, 1]])(pair_tensor)  # [B, ?, ?, d_len]

    # Self-attention: a scalar score per pair, softmaxed over all pairs.
    scores = stack_conv_layer(
        [[32, 3, 1], [16, 3, 1], [1, 1, 1]])(descriptors)  # [B, ?, ?, 1]
    scores = tf.keras.layers.Reshape((1, -1))(scores)  # [B, 1, ?*?]
    attn = tf.keras.layers.Softmax()(scores)  # [B, 1, ?*?]

    # Attention-weighted sum of pair descriptors -> global summary vector.
    flat_descriptors = tf.keras.layers.Reshape((-1, d_len))(descriptors)
    summary = matmul_layer((attn, flat_descriptors))  # [B, 1, d_len]

    # Broadcast the summary to every object and append it to the raw input.
    summary_per_obj = tile_layer((summary, [1, n_obj, 1]))  # [B, ?, d_len]
    per_obj_feats = tf.keras.layers.Concatenate(axis=-1)(
        [state_in, summary_per_obj])  # [B, ?, di+d_len]
    per_obj_feats = expand_dims_layer((per_obj_feats, 2))  # [B, ?, 1, di+d_len]

    # Per-object 1x1 conv head producing per_input_ac_dim q values each.
    head_cfg = [[act_dim_per_obj * 64, 1, 1],
                [act_dim_per_obj * 64, 1, 1],
                [act_dim_per_obj, 1, 1]]
    q_values = tf.keras.layers.Reshape((-1, act_dim_per_obj))(
        stack_conv_layer(head_cfg)(per_obj_feats))  # [B, ?, per_input_ac_dim]
    return tf.keras.Model(name=name, inputs=state_in, outputs=q_values)
def _build_q_discrete(self, cfg, name):
    """Build the q function for discrete action space.

    Uses 8 attention heads: each head runs its own conv stack over
    per-object features, computes softmax weights across objects and
    pools to a fixed-size vector; the concatenated heads feed a final
    linear layer producing the q values.

    Args:
      cfg: config object; reads ac_dim and obs_dim (obs_dim[0] = number
        of objects, obs_dim[1] = per-object feature size -- TODO
        confirm against callers).
      name: name given to the returned tf.keras.Model.

    Returns:
      tf.keras.Model mapping the state input to a [B, ac_dim] tensor of
      q values.
    """
    num_heads = 8
    ac_dim = cfg.ac_dim
    inputs = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))
    # [B, ?, 1, di]
    augmented_inputs = tf.expand_dims(inputs, axis=2)
    # FIX: this list was previously assigned to `cfg`, shadowing the
    # config parameter; renamed so later edits can still read cfg.
    head_conv_cfg = [[ac_dim // num_heads, 1, 1],
                     [ac_dim // num_heads, 1, 1]]
    heads = []
    for _ in range(num_heads):
        # Each iteration builds a fresh conv stack, so the heads have
        # independent weights.
        # [B, ?, 1, ac_dim//num_heads]
        head_out = stack_conv_layer(head_conv_cfg)(augmented_inputs)
        weights = tf.keras.layers.Conv2D(1, 1, 1)(head_out)  # [B, ?, 1, 1]
        softmax_weights = tf.nn.softmax(weights, axis=1)  # [B, ?, 1, 1]
        # Attention-weighted pooling over the object axis.
        heads.append(tf.reduce_sum(softmax_weights * head_out, axis=(1, 2)))
    # heads = num_heads x [B, ac_dim//num_heads]
    out = tf.concat(heads, axis=1)  # [B, ac_dim]
    out = tf.keras.layers.Dense(ac_dim)(out)
    q_out_layer = tf.keras.Model(name=name, inputs=inputs, outputs=out)
    return q_out_layer
def build_q_perfect(self, cfg, name, embedding_length):
    """Build the q function for perfect action space.

    Builds a functional Keras model that attends over pairwise object
    descriptors with a key derived from the goal embedding, then scores
    per-object actions from the raw features, the attended summary and
    the goal embedding.

    Args:
      cfg: config object; reads per_input_ac_dim, descriptor_length,
        inner_product_length and obs_dim (obs_dim[0] = number of
        objects, obs_dim[1] = per-object feature size -- TODO confirm
        against callers).
      name: name given to the returned tf.keras.Model.
      embedding_length: length of the goal-embedding input vector.

    Returns:
      tf.keras.Model mapping {'state_input', 'goal_embedding'} to a
      [B, num_object, per_input_ac_dim] tensor of q values.
    """
    per_input_ac_dim = cfg.per_input_ac_dim
    des_len, inner_len = cfg.descriptor_length, cfg.inner_product_length
    inputs = tf.keras.layers.Input(shape=(cfg.obs_dim[0], cfg.obs_dim[1]))
    # NOTE(review): shape=(embedding_length) is a bare int, not a
    # 1-tuple; Keras coerces it to a single dimension, but
    # (embedding_length,) would state the intent explicitly — confirm
    # before changing.
    goal_embedding = tf.keras.layers.Input(shape=(embedding_length))
    # Object count from the dynamic shape of the state input.
    shape_layer = tf.keras.layers.Lambda(tf.shape)
    num_object = tf.cast(shape_layer(inputs)[1], tf.int32)
    # All object-object pairs -> 1x1 conv stack -> descriptor per pair.
    tp_concat = vector_tensor_product(inputs, inputs)
    conv_layer_cfg = [[des_len * 8, 1, 1], [des_len * 4, 1, 1],
                      [des_len, 1, 1]]
    # [B, ?, ?, des_len]
    tp_concat = stack_conv_layer(conv_layer_cfg)(tp_concat)
    # similarity with goal: project goal -> key, descriptors -> queries,
    # attend with their inner product.
    goal_key = stack_dense_layer(
        [inner_len * 2, inner_len])(goal_embedding)  # [B, d_inner]
    expand_dims_layer = tf.keras.layers.Lambda(
        lambda inputs: tf.expand_dims(inputs[0], axis=inputs[1]))
    goal_key = expand_dims_layer((goal_key, 1))  # [B, 1, d_inner]
    # [B, ?, ?, d_inner]
    obs_query_layer = tf.keras.layers.Conv2D(inner_len, 1, 1, padding='same')
    obs_query = obs_query_layer(tp_concat)
    # [B, ?*?, d_inner]
    obs_query = tf.keras.layers.Reshape((-1, inner_len))(obs_query)
    obs_query_t = tf.keras.layers.Permute(
        (2, 1))(obs_query)  # [B,d_inner,?*?]
    matmul_layer = tf.keras.layers.Lambda(
        lambda inputs: tf.matmul(inputs[0], inputs[1]))
    inner = matmul_layer((goal_key, obs_query_t))  # [B, 1, ?*?]
    # Attention weights over all object pairs.
    weight = tf.keras.layers.Softmax()(inner)  # [B, 1, ?*?]
    # Goal-weighted sum of pair descriptors -> one summary vector.
    prod = matmul_layer((weight, tf.keras.layers.Reshape(
        (-1, des_len))(tp_concat)))  # [B, 1, des_len]
    goal_embedding_ = expand_dims_layer((goal_embedding, 1))  # [B, 1, dg]
    tile_layer = tf.keras.layers.Lambda(
        lambda inputs: tf.tile(inputs[0], multiples=inputs[1]))
    # Broadcast the goal embedding and the summary to every object.
    # [B, ?, dg]
    goal_embedding_ = tile_layer((goal_embedding_, [1, num_object, 1]))
    # [B, ?, des_len]
    pair_wise_summary = tile_layer((prod, [1, num_object, 1]))
    # Per-object features = raw input ++ attended summary ++ goal.
    # [B, ?, des_len+di+dg]
    augemented_inputs = tf.keras.layers.Concatenate(axis=-1)(
        [inputs, pair_wise_summary, goal_embedding_])
    # [B, ?, 1, des_len+di+dg]
    augemented_inputs = expand_dims_layer((augemented_inputs, 2))
    # 1x1 conv head scores per_input_ac_dim actions for each object.
    conv_layer_cfg = [[per_input_ac_dim * 64, 1, 1],
                      [per_input_ac_dim * 64, 1, 1],
                      [per_input_ac_dim, 1, 1]]
    # [B, ?, per_input_ac_dim]
    q_out = tf.keras.layers.Reshape((-1, per_input_ac_dim))(
        stack_conv_layer(conv_layer_cfg)(augemented_inputs))
    all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
    q_out_layer = tf.keras.Model(name=name, inputs=all_inputs, outputs=q_out)
    return q_out_layer