def __init__(self, env, observations, timer, params):
    """Build the actor-critic policy graph on top of the given input tensors.

    :param env: environment; ``env.observation_space.spaces['obs']`` selects the
        convolutional vs. low-dimensional encoder, and ``env.action_space.n``
        is the number of action outputs.
    :param observations: observation tensor fed to the encoder.
    :param timer: timer tensor; expanded by one axis and concatenated to the
        encoded observations (zeroed out when ``params.ignore_timer`` is set).
    :param params: hyperparameters; reads ``image_model_name``, ``fc_layers``,
        ``fc_size``, ``lowdim_model_name``, ``stack_past_frames`` and
        ``ignore_timer``.
    :raises ValueError: if the configured image or low-dimensional model name
        is not supported. ``ValueError`` subclasses ``Exception``, so callers
        that caught the previous bare ``Exception`` still catch it.

    Sets ``self.actions`` (unactivated action outputs),
    ``self.best_action_deterministic`` (argmax over them),
    ``self.actions_prob_distribution``, ``self.act`` (a sample from that
    distribution) and ``self.value`` (value output with the size-1 axis
    squeezed away).

    NOTE: layers are created in a fixed order. TF auto-names unscoped layers
    by creation order, so reordering the statements below would presumably
    change variable names and break existing checkpoints.
    """
    self.regularizer = tf.contrib.layers.l2_regularizer(scale=1e-10)

    img_model_name = params.image_model_name
    fc_layers = params.fc_layers
    fc_size = params.fc_size
    lowdim_model_name = params.lowdim_model_name
    past_frames = params.stack_past_frames

    image_obs = has_image_observations(env.observation_space.spaces['obs'])
    num_actions = env.action_space.n

    if image_obs:
        # convolutions
        if img_model_name == 'convnet_simple':
            conv_filters = self._convnet_simple(observations, [(32, 3, 2)] * 4)
        else:
            raise ValueError('Unknown model name: {!r}'.format(img_model_name))

        encoded_input = tf.contrib.layers.flatten(conv_filters)
    else:
        # low-dimensional input
        if lowdim_model_name == 'simple_fc':
            # Split the stacked frames along axis 1 and encode each one with the
            # same template, so every frame shares one set of encoder weights.
            frames = tf.split(observations, past_frames, axis=1)
            fc_encoder = tf.make_template('fc_encoder', self._fc_frame_encoder, create_scope_now_=True)
            encoded_frames = [fc_encoder(frame) for frame in frames]
            encoded_input = tf.concat(encoded_frames, axis=1)
        else:
            raise ValueError('Unknown lowdim model name: {!r}'.format(lowdim_model_name))

    if params.ignore_timer:
        # Zero the timer rather than dropping it, so the concatenated input keeps
        # the same width either way (presumably to keep the architecture fixed).
        timer = tf.multiply(timer, 0.0)
    encoded_input_with_timer = tf.concat([encoded_input, tf.expand_dims(timer, 1)], axis=1)

    # Shared trunk: fc_layers - 1 dense layers. Each head below adds one more.
    fc = encoded_input_with_timer
    for _ in range(fc_layers - 1):
        fc = dense(fc, fc_size, self.regularizer)

    # fully-connected layers to generate actions
    actions_fc = dense(fc, fc_size // 2, self.regularizer)
    self.actions = tf.contrib.layers.fully_connected(actions_fc, num_actions, activation_fn=None)
    self.best_action_deterministic = tf.argmax(self.actions, axis=1)
    self.actions_prob_distribution = CategoricalProbabilityDistribution(self.actions)
    self.act = self.actions_prob_distribution.sample()

    # value head: a single output per sample, squeezed to drop the size-1 axis
    value_fc = dense(fc, fc_size // 2, self.regularizer)
    self.value = tf.squeeze(tf.contrib.layers.fully_connected(value_fc, 1, activation_fn=None), axis=[1])

    if image_obs:
        # summaries
        # Re-open the first conv layer's scope to fetch its existing kernel variable.
        with tf.variable_scope('conv1', reuse=True):
            weights = tf.get_variable('weights')
        with tf.name_scope('a2c_agent_summary_conv'):
            # tf.summary.image only accepts 1, 3 or 4 channels. This assumes axis 2
            # of the kernel is the input-channel axis (HWIO layout) — TODO confirm.
            if weights.shape[2].value in [1, 3, 4]:
                tf.summary.image('conv1/kernels', put_kernels_on_grid(weights), max_outputs=1)

    log.info('Total parameters in the model: %d', count_total_parameters())
def _fc_frame_encoder(self, x):
    """Encode one low-dimensional observation frame with a 128-unit dense layer.

    :param x: a single frame tensor (one slice of the stacked observations).
    :returns: the output of ``dense`` with 128 units, regularized by
        ``self.regularizer``.
    """
    hidden_units = 128
    encoded = dense(x, hidden_units, self.regularizer)
    return encoded
def __init__(self, env, obs, next_obs, actions, past_frames, forward_fc):
    """Build the ICM graph: a shared encoder, a forward model and an inverse model.

    :param env: environment; ``env.observation_space.spaces['obs']`` selects the
        convolutional vs. low-dimensional encoder, and ``env.action_space.n``
        is the number of actions.
    :param obs: placeholder for observations.
    :param next_obs: placeholder for the following observations; encoded with
        the same weight-shared encoder template as ``obs``.
    :param actions: placeholder for selected actions, one-hot encoded below
        with ``tf.one_hot`` (so presumably integer action indices).
    :param past_frames: forwarded to ``self._lowdim_encoder``; not used on the
        image path.
    :param forward_fc: width of each of the forward model's two hidden layers.

    Sets ``self.encoded_next_obs``, ``self.feature_vector_size``,
    ``self.predicted_obs`` and ``self.predicted_actions``:

    - ``self.predicted_obs`` is the forward-model output, with the same width
      as the encoded features.
    - ``self.predicted_actions`` holds the inverse-model outputs over
      ``num_actions``, with no activation applied.

    Losses are not built here; presumably the caller builds them from these
    tensors.
    """
    self.regularizer = tf.contrib.layers.l2_regularizer(scale=1e-10)

    image_obs = has_image_observations(env.observation_space.spaces['obs'])
    num_actions = env.action_space.n

    if image_obs:
        # convolutions
        # tf.make_template makes both calls below reuse one set of encoder variables.
        conv_encoder = tf.make_template(
            'conv_encoder', self._convnet_simple, create_scope_now_=True, convs=[(32, 3, 2)] * 4,
        )
        encoded_obs = conv_encoder(obs=obs)
        encoded_obs = tf.contrib.layers.flatten(encoded_obs)

        encoded_next_obs = conv_encoder(obs=next_obs)
        self.encoded_next_obs = tf.contrib.layers.flatten(encoded_next_obs)
    else:
        # low-dimensional input
        # Same weight sharing as above. The outputs are used without flattening;
        # this assumes _lowdim_encoder already returns a 2-D tensor — TODO confirm.
        lowdim_encoder = tf.make_template(
            'lowdim_encoder', self._lowdim_encoder, create_scope_now_=True, past_frames=past_frames,
        )
        encoded_obs = lowdim_encoder(obs=obs)
        self.encoded_next_obs = lowdim_encoder(obs=next_obs)

    # static width of the encoded features; also sizes the forward model's output layer
    self.feature_vector_size = encoded_obs.get_shape().as_list()[-1]
    log.info('Feature vector size in ICM: %d', self.feature_vector_size)

    actions_one_hot = tf.one_hot(actions, num_actions)

    # forward model
    # input: encoded obs concatenated with the one-hot action
    forward_model_input = tf.concat(
        [encoded_obs, actions_one_hot], axis=1,
    )
    forward_model_hidden = dense(forward_model_input, forward_fc, self.regularizer)
    forward_model_hidden = dense(forward_model_hidden, forward_fc, self.regularizer)
    # linear output with the same width as the encoded features
    forward_model_output = tf.contrib.layers.fully_connected(
        forward_model_hidden, self.feature_vector_size, activation_fn=None,
    )
    self.predicted_obs = forward_model_output

    # inverse model
    # input: encoded obs concatenated with encoded next obs; output: one unit per action
    inverse_model_input = tf.concat([encoded_obs, self.encoded_next_obs], axis=1)
    inverse_model_hidden = dense(inverse_model_input, 256, self.regularizer)
    inverse_model_output = tf.contrib.layers.fully_connected(
        inverse_model_hidden, num_actions, activation_fn=None,
    )
    self.predicted_actions = inverse_model_output

    log.info('Total parameters in the model: %d', count_total_parameters())