import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers

# batch_to_seq, seq_to_batch, lstm, and the create_*_network_params /
# *_network helpers used below are assumed to come from this repo's
# utility modules.


def make_lstm(lstm_unit, nenvs, step_size, inpt, masks, rnn_state):
    """Run a flat [nenvs * step_size, dim] batch through an LSTM."""
    with tf.variable_scope('rnn'):
        # Split the flat batch into per-environment sequences.
        rnn_in = batch_to_seq(inpt, nenvs, step_size)
        masks = batch_to_seq(masks, nenvs, step_size)
        rnn_out, rnn_state = lstm(rnn_in, masks, rnn_state, lstm_unit,
                                  np.sqrt(2.0))
        # Flatten the sequences back into a single batch.
        rnn_out = seq_to_batch(rnn_out, nenvs, step_size)
    return rnn_out, rnn_state
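# A minimal sketch of the sequence helpers this module assumes, modeled on
# the batch_to_seq/seq_to_batch utilities in OpenAI baselines' A2C code; the
# repo's actual implementations (and the mask handling inside lstm) may
# differ. batch_to_seq splits a [nenvs * step_size, dim] tensor into
# step_size tensors of shape [nenvs, dim]; seq_to_batch inverts it. The
# _sketch names are hypothetical and not part of this module.
def _batch_to_seq_sketch(tensor, nenvs, step_size):
    dim = tensor.get_shape()[-1].value
    seq = tf.reshape(tensor, [nenvs, step_size, dim])
    return [tf.squeeze(t, axis=1) for t in tf.split(seq, step_size, axis=1)]


def _seq_to_batch_sketch(seq, nenvs, step_size):
    dim = seq[0].get_shape()[-1].value
    stacked = tf.stack(seq, axis=1)  # [nenvs, step_size, dim]
    return tf.reshape(stacked, [nenvs * step_size, dim])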
def _make_network(convs, fcs, use_lstm, padding, inpt, masks, rnn_state,
                  num_actions, lstm_unit, nenvs, step_size, scope):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        out = inpt
        # Convolutional feature extractor.
        with tf.variable_scope('convnet'):
            for num_outputs, kernel_size, stride in convs:
                out = layers.convolution2d(
                    out,
                    num_outputs=num_outputs,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.orthogonal_initializer(
                        np.sqrt(2.0)))
        out = layers.flatten(out)
        # Fully connected hidden layers.
        with tf.variable_scope('hiddens'):
            for hidden in fcs:
                out = layers.fully_connected(
                    out,
                    hidden,
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.orthogonal_initializer(
                        np.sqrt(2.0)))
        # Recurrent layer. Note the LSTM is always built; its output is
        # used only when use_lstm is True.
        with tf.variable_scope('rnn'):
            rnn_in = batch_to_seq(out, nenvs, step_size)
            masks = batch_to_seq(masks, nenvs, step_size)
            rnn_out, rnn_state = lstm(rnn_in, masks, rnn_state, lstm_unit,
                                      np.sqrt(2.0))
            rnn_out = seq_to_batch(rnn_out, nenvs, step_size)
        if use_lstm:
            out = rnn_out
        # Policy head: softmax over actions, small-gain initializer.
        policy = layers.fully_connected(
            out,
            num_actions,
            activation_fn=tf.nn.softmax,
            weights_initializer=tf.orthogonal_initializer(0.1))
        # Value head: single linear output.
        value = layers.fully_connected(
            out,
            1,
            activation_fn=None,
            weights_initializer=tf.orthogonal_initializer(1.0))
    return policy, value, rnn_state
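# Hypothetical usage sketch for _make_network. The Atari-style shapes,
# hyperparameters, and scope name are illustrative assumptions, not part of
# this module; the rnn_state layout follows the baselines convention of
# concatenated cell and hidden states.
def _example_make_network():
    nenvs, step_size, lstm_unit = 4, 5, 256
    obs = tf.placeholder(tf.float32, [nenvs * step_size, 84, 84, 4])
    masks = tf.placeholder(tf.float32, [nenvs * step_size, 1])
    state = tf.placeholder(tf.float32, [nenvs, 2 * lstm_unit])
    return _make_network(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], fcs=[512],
        use_lstm=True, padding='VALID', inpt=obs, masks=masks,
        rnn_state=state, num_actions=6, lstm_unit=lstm_unit,
        nenvs=nenvs, step_size=step_size, scope='a2c')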
def cnn_network(self, inpt, masks, rnn_state, num_actions, lstm_unit,
                nenvs, step_size, scope, is_train=False):
    # Convolutional base whose parameters are generated by a CPG
    # (contextual parameter generator) conditioned on self.context_vector.
    # Parameters are created once and cached in self.all_variables.
    with tf.variable_scope('cnn_conv_base_network', reuse=tf.AUTO_REUSE):
        conv_name = 'conv_base_params'
        if conv_name not in self.all_variables:
            cnn_base_params = create_cnn_network_params(
                context_dim=self.cpg_context_size,
                cnn_architecture=self.conv_architecture,
                padding=self.padding_type,
                initializer=tf.orthogonal_initializer(np.sqrt(2.0)),
                name=conv_name,
                cpg_network_shape=self.cpg_network_shape,
                dropout=self.dropout,
                use_batch_norm=self.use_batch_norm,
                batch_norm_momentum=self.batch_norm_momentum,
                batch_norm_train_stats=self.batch_norm_train_stats,
                input_channels=self.input_channels)
            self.all_variables.update(cnn_base_params)
        conv_output = cnn_network(inpt,
                                  self.all_variables[conv_name],
                                  gen_vector=self.context_vector,
                                  is_train=is_train)
        conv_flattened = layers.flatten(conv_output)
    # Fully connected base, also CPG-generated.
    with tf.variable_scope('cnn_fc_base_network', reuse=tf.AUTO_REUSE):
        fc_name = 'cnn_fc'
        if fc_name not in self.all_variables:
            fc_base_params = create_fc_network_params(
                # Static feature dimension, known at graph-construction
                # time; tf.shape(...).eval() would need a running session.
                input_dim=conv_flattened.get_shape()[-1].value,
                context_dim=self.cpg_context_size,
                fc_architecture=self.fc_architecture,
                initializer=tf.orthogonal_initializer(np.sqrt(2.0)),
                name=fc_name,
                cpg_network_shape=self.cpg_network_shape,
                dropout=self.dropout,
                use_batch_norm=self.use_batch_norm,
                batch_norm_momentum=self.batch_norm_momentum,
                batch_norm_train_stats=self.batch_norm_train_stats)
            self.all_variables.update(fc_base_params)
        fc_output = fc_network(input=conv_flattened,
                               fc_params=self.all_variables[fc_name],
                               gen_vector=self.context_vector,
                               is_train=is_train)
    # Recurrent layer: per-environment sequences through a CPG-generated
    # LSTM, then flattened back into a single batch.
    with tf.variable_scope('cnn_rnn_base_network', reuse=tf.AUTO_REUSE):
        rnn_name = 'rnn'
        if rnn_name not in self.all_variables:
            rnn_base_params = create_lstm_network_params(
                input_dim=fc_output.get_shape()[-1].value,
                context_dim=self.cpg_context_size,
                hidden_dim=lstm_unit,
                initializer=tf.orthogonal_initializer(np.sqrt(2.0)),
                name=rnn_name,
                cpg_network_shape=self.cpg_network_shape,
                dropout=self.dropout,
                use_batch_norm=self.use_batch_norm,
                batch_norm_momentum=self.batch_norm_momentum,
                batch_norm_train_stats=self.batch_norm_train_stats)
            self.all_variables.update(rnn_base_params)
        rnn_in = batch_to_seq(fc_output, nenvs, step_size)
        masks = batch_to_seq(masks, nenvs, step_size)
        rnn_output, rnn_state = lstm_network(
            inputs=rnn_in,
            keep_props=masks,
            state=rnn_state,
            lstm_params=self.all_variables[rnn_name],
            gen_vector=self.context_vector,
            is_train=is_train)
        rnn_output = seq_to_batch(rnn_output, nenvs, step_size)
    if self.use_lstm:
        output = rnn_output
    else:
        output = fc_output
    # Policy head: a CPG-generated linear layer whose output feeds a
    # categorical action distribution.
    with tf.variable_scope('policy_network', reuse=tf.AUTO_REUSE):
        policy_name = 'policy'
        if policy_name not in self.all_variables:
            policy_params = create_fc_network_params(
                input_dim=output.get_shape()[-1].value,
                context_dim=self.cpg_context_size,
                fc_architecture=[num_actions],
                initializer=tf.orthogonal_initializer(0.1),
                name=policy_name,
                cpg_network_shape=self.cpg_network_shape,
                dropout=self.dropout,
                use_batch_norm=self.use_batch_norm,
                batch_norm_momentum=self.batch_norm_momentum,
                batch_norm_train_stats=self.batch_norm_train_stats)
            self.all_variables.update(policy_params)
        policy = fc_network(input=output,
                            fc_params=self.all_variables[policy_name],
                            gen_vector=self.context_vector,
                            is_train=is_train,
                            activation=None)
        # Softmax converts the linear policy output to probabilities.
        dist = tf.distributions.Categorical(probs=tf.nn.softmax(policy))
    # Value head: single linear output.
    with tf.variable_scope('value_network', reuse=tf.AUTO_REUSE):
        value_name = 'value'
        if value_name not in self.all_variables:
            value_params = create_fc_network_params(
                input_dim=output.get_shape()[-1].value,
                context_dim=self.cpg_context_size,
                fc_architecture=[1],
                initializer=tf.orthogonal_initializer(1.0),
                name=value_name,
                cpg_network_shape=self.cpg_network_shape,
                dropout=self.dropout,
                use_batch_norm=self.use_batch_norm,
                batch_norm_momentum=self.batch_norm_momentum,
                batch_norm_train_stats=self.batch_norm_train_stats)
            self.all_variables.update(value_params)
        value = fc_network(input=output,
                           fc_params=self.all_variables[value_name],
                           gen_vector=self.context_vector,
                           is_train=is_train,
                           activation=None)
    return dist, value, rnn_state
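# A minimal sketch of the contextual-parameter-generation idea used above,
# for a single dense layer. It assumes nothing about this repo's
# create_fc_network_params/fc_network internals (which also handle dropout
# and batch norm); the _cpg_dense_sketch name and a 1-D gen_vector are
# assumptions for illustration. Instead of owning trainable weights, the
# layer's weights are produced from the context vector by a small generator.
def _cpg_dense_sketch(x, gen_vector, in_dim, out_dim):
    # Generator maps the [context_dim] vector to the layer's weight matrix.
    flat_w = layers.fully_connected(gen_vector[None, :], in_dim * out_dim,
                                    activation_fn=None)
    w = tf.reshape(flat_w, [in_dim, out_dim])
    # A second generator produces the bias.
    b = layers.fully_connected(gen_vector[None, :], out_dim,
                               activation_fn=None)
    return tf.matmul(x, w) + b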