def lstm_block(input_tensor, num_units, opts, name=""):
    """Run a Popnn LSTM layer over ``input_tensor``.

    PopnnLSTM is backed by a direct poplibs implementation.

    Args:
        input_tensor: time-major input of shape
            [timesteps, batch_size, input_size]; its dtype is also used as
            the layer dtype.
        num_units: number of hidden units in the LSTM.
        opts: options object; only ``opts.train`` is read, and it is forwarded
            as the layer's ``training`` flag.
        name: name given to the PopnnLSTM layer.

    Returns:
        The result of calling the PopnnLSTM layer (output and final state).
    """
    layer = rnn_ops.PopnnLSTM(
        num_units=num_units,
        dtype=input_tensor.dtype,
        name=name,
    )
    result = layer(input_tensor, training=opts.train)
    return result
def lstm(opts, inputs):
    """Apply an LSTM of width ``opts.hidden_size`` to time-major ``inputs``.

    When ``opts.popnn`` is truthy, PopnnLSTM (a direct poplibs implementation)
    is used. Otherwise the stock TensorFlow ``LSTMCell`` is unrolled over time
    with ``tf.nn.dynamic_rnn``.

    Args:
        opts: options object; reads ``popnn`` and ``hidden_size``, plus
            ``train`` on the PopnnLSTM path only.
        inputs: tensor of shape [timesteps, batch_size, input_size], where
            input_size is equal to hidden_size.

    Returns:
        The (outputs, final_state) pair produced by whichever implementation
        was selected.
    """
    if not opts.popnn:
        # Fallback path: LSTMCell consumes one [batch_size, input_size] slice
        # per step (input_size equal to hidden_size). dynamic_rnn builds the
        # loop that feeds those slices across timesteps. TensorFlow expands
        # that loop, which gives a less optimal result than PopnnLSTM.
        step_cell = tf.nn.rnn_cell.LSTMCell(opts.hidden_size)
        return tf.nn.dynamic_rnn(
            cell=step_cell,
            inputs=inputs,
            dtype=inputs.dtype,
            time_major=True,
        )

    # Popnn path: the layer consumes the whole
    # [timesteps, batch_size, input_size] sequence in one call.
    popnn_layer = rnn_ops.PopnnLSTM(opts.hidden_size, dtype=inputs.dtype)
    return popnn_layer(inputs, training=opts.train)
def create_policy(*infeed_data):
    """Act according to the current policy and score the sampled actions.

    Builds the policy graph in these steps:
      1. Embed the discrete observations.
      2. Merge each embedded group with a continuous-observation group.
      3. Max-reduce each group over axis 2 and concatenate the groups.
      4. Run a PopnnLSTM over time.
      5. Sample one action per (batch, timestep) from the resulting logits.

    Args:
        *infeed_data: tensors read positionally.
            ``infeed_data[0:4]`` are discrete observations, used as
            embedding indices.
            ``infeed_data[4:8]`` are continuous observations.
            ``infeed_data[-1]`` is the initial LSTM state.
            NOTE(review): presumably exactly 9 tensors are passed, so that
            ``[-1]`` is index 8. Confirm with the caller.

    Returns:
        Log-probability (not probability) of each sampled action, with
        shape ``(args.batch_size, args.time_steps)``. Neither the sampled
        action indices nor the final LSTM state (``state_out``) is returned.
    """
    dis_obs = list(infeed_data[:4])
    cont_obs = list(infeed_data[4:8])
    state_in = infeed_data[-1]
    # Look up embedding for all the discrete obs. Each discrete group gets
    # its own [cardinality, emb_size] table, created under the
    # "popnn_lookup" variable scope.
    emb_lookup = []
    with tf.variable_scope("popnn_lookup"):
        for index, obs in enumerate(dis_obs):
            emb_matrix = tf.get_variable(
                f'emb_matrix{index}',
                [DIS_OBS_CARDINALITY[index], DIS_OBS_EMB_SIZE[index]],
                DTYPE)
            emb_lookup.append(
                embedding_ops.embedding_lookup(emb_matrix, obs, name=f'emb_lookup{index}'))
    # Clip some continuous observations: only the last continuous group is
    # clipped, to [-5.0, 5.0].
    cont_obs[-1] = tf.clip_by_value(cont_obs[-1], -5.0, 5.0, name="clip")
    # Concat groups of observations. The i-th embedded discrete group is
    # paired with the i-th continuous group and joined along axis 3.
    # Assumes rank-4 tensors whose last axis is features, e.g.
    # [batch, time, entity, feature]. TODO confirm.
    obs_concat = []
    for d_obs, c_obs in zip(emb_lookup, cont_obs):
        obs_concat.append(tf.concat([d_obs, c_obs], axis=3, name="concat_obs"))
    # Fully connected transformations. Only the last concatenated group is
    # projected, to num_output features. The Dense layer is unnamed, so its
    # variable names depend on layer creation order.
    num_output = 8
    obs_concat[-1] = Dense(num_output, dtype=DTYPE)(obs_concat[-1])
    # Reduce max over axis 2 (presumably the entity axis; TODO confirm),
    # leaving one [batch, time, feature] tensor per group.
    obs_concat = [tf.reduce_max(obs, axis=2) for obs in obs_concat]
    # Final concat of all the observations along the feature axis
    lstm_input = tf.concat(obs_concat, axis=2, name="concat_all")
    # LSTM layer
    # PopnnLSTM uses time-major tensors, so swap axes 0 (batch) and 1 (time)
    # before the LSTM and swap them back afterwards.
    lstm_input = tf.transpose(
        lstm_input, perm=[1, 0, 2], name="pre_lstm_transpose")
    lstm_cell = rnn_ops.PopnnLSTM(num_units=LSTM_HIDDEN_SIZE, dtype=DTYPE, partials_dtype=DTYPE, name="lstm")
    # state_in[:, 0] and state_in[:, 1] become the (c, h) fields of
    # LSTMStateTuple. Assumes state_in is [batch, 2, LSTM_HIDDEN_SIZE]
    # (TODO confirm).
    # NOTE(review): training=True is hard-coded although this graph is used
    # to act; confirm this is intended.
    # NOTE(review): state_out is not used or returned.
    lstm_output, state_out = lstm_cell(
        lstm_input, training=True,
        initial_state=tf.nn.rnn_cell.LSTMStateTuple(state_in[:, 0], state_in[:, 1]))
    lstm_output = tf.transpose(lstm_output, perm=[1, 0, 2], name="post_lstm_transpose")
    logits = Dense(NUM_ACTIONS, name="logits", dtype=DTYPE)(lstm_output)
    # NOTE: despite the op name "prob", this tensor holds log-probabilities.
    log_prob = tf.nn.log_softmax(logits, name="prob")
    # make action selection op (outputs int actions, sampled from policy).
    # tf.random.categorical takes 2-D [N, num_classes] logits, so batch and
    # time are flattened first and restored afterwards.
    actions = tf.random.categorical(logits=tf.reshape(logits, (-1, NUM_ACTIONS)), num_samples=1)
    # Relies on module-level `args` matching the real batch and time sizes.
    actions = tf.reshape(actions, (args.batch_size, args.time_steps))
    # One-hot mask times log_prob, summed over actions, picks out the
    # log-probability of the sampled action at each (batch, time) position.
    action_masks = tf.one_hot(actions, NUM_ACTIONS, dtype=DTYPE)
    action_prob = tf.reduce_sum(action_masks * log_prob, axis=-1)
    return action_prob