def __init__(self, ac_space, policy_network, value_network=None, estimate_q=False):
    """
    Wire a policy network and a value network together with their output heads.

    Parameters:
    ----------
    ac_space
        action space; its ``n`` attribute is read when ``estimate_q`` is true

    policy_network
        keras network for policy

    value_network
        keras network for value; the policy network is reused when this
        argument is falsy (e.g. left as None)

    estimate_q
        build a q head (``ac_space.n`` outputs) instead of a v head (1 output)
    """
    self.estimate_q = estimate_q
    self.initial_state = None
    self.policy_network = policy_network
    self.value_network = value_network if value_network else policy_network

    # The probability distribution type is chosen from the action space.
    self.pdtype = make_pdtype(self.policy_network.output_shape, ac_space, init_scale=0.01)

    # Pick the head's name and width first, then build it with a single call.
    value_shape = self.value_network.output_shape
    head_name, head_units = ('q', ac_space.n) if estimate_q else ('vf', 1)
    self.value_fc = fc_build(value_shape, head_name, head_units)

    # Diagnostics: print both architectures and write the policy graph to disk.
    self.value_network.summary()
    self.policy_network.summary()
    tf.keras.utils.plot_model(self.policy_network, to_file='./policy_network.png')
def _build_q_head(self):
    """
    Create one q head per agent on top of the value network.

    Every head is produced by ``fc_build`` from the value network's output
    shape with ``self.num_actions`` units and is named ``q_agent_<id>``.

    Returns:
    -------
    list
        the objects returned by ``fc_build``, in ``self.agent_ids`` order
    """
    input_shape = self.value_network.output_shape
    # Build each name from the fixed 'q' prefix. Appending to a running string
    # made later heads inherit every earlier suffix ('q_agent_0_agent_1', ...).
    # NOTE(review): this renames the 2nd and later heads; confirm nothing
    # restores weights by the old accumulated layer names.
    return [
        fc_build(input_shape, f'q_agent_{agent_id}', self.num_actions)
        for agent_id in self.agent_ids
    ]
def _matching_fc(tensor_shape, name, size, init_scale, init_bias):
    """
    Return a callable that brings an input to ``size`` features.

    When the last entry of ``tensor_shape`` already equals ``size`` no layer
    is needed and an identity function is returned; otherwise the result of
    ``fc_build`` for the requested size and initialisation is returned.
    """
    last_dim = tensor_shape[-1]
    if last_dim != size:
        return fc_build(tensor_shape, name, size, init_scale=init_scale, init_bias=init_bias)
    # Width already matches: pass the input through untouched.
    return lambda x: x