예제 #1
0
def lstm_block(input_tensor, num_units, opts, name=""):
    """Run `input_tensor` through a single PopnnLSTM layer.

    Args:
        input_tensor: Tensor fed to the layer; its dtype is reused as the
            layer dtype. Expected layout is
            [timesteps, batch_size, input_size].
        num_units: Number of LSTM units handed to PopnnLSTM.
        opts: Options object; only `opts.train` is read and it is forwarded
            as the layer's `training` argument.
        name: Name given to the PopnnLSTM layer.

    Returns:
        Whatever the PopnnLSTM layer call returns for this input.
        NOTE(review): presumably an (outputs, final_state) pair -- confirm
        against the PopnnLSTM API.
    """
    # PopnnLSTM maps straight onto the poplibs LSTM implementation.
    layer = rnn_ops.PopnnLSTM(
        num_units=num_units,
        dtype=input_tensor.dtype,
        name=name,
    )
    return layer(input_tensor, training=opts.train)
예제 #2
0
def lstm(opts, inputs):
    """Build an LSTM over `inputs` using the backend selected by `opts`.

    Args:
        opts: Options object. `opts.popnn` selects the implementation,
            `opts.hidden_size` sets the number of units, and `opts.train`
            is forwarded as `training` on the PopnnLSTM path only.
        inputs: Time-major tensor, [timesteps, batch_size, input_size],
            where input_size is equal to hidden_size.

    Returns:
        The result of the selected LSTM call. On the `dynamic_rnn` path this
        is TensorFlow's (outputs, final_state) pair.
        NOTE(review): PopnnLSTM presumably returns an equivalent pair --
        confirm against its API.
    """
    if not opts.popnn:
        # LSTMCell itself consumes one [batch_size, input_size] slice per
        # step; dynamic_rnn builds the loop that feeds it each timestep.
        # TensorFlow expands that loop, which is less optimal than the
        # PopnnLSTM path below.
        cell = tf.nn.rnn_cell.LSTMCell(opts.hidden_size)
        return tf.nn.dynamic_rnn(
            cell=cell,
            inputs=inputs,
            dtype=inputs.dtype,
            time_major=True,
        )

    # PopnnLSTM is backed directly by the poplibs implementation and takes
    # the whole time-major sequence in a single call.
    layer = rnn_ops.PopnnLSTM(opts.hidden_size, dtype=inputs.dtype)
    return layer(inputs, training=opts.train)
예제 #3
0
def create_policy(*infeed_data):
    """Build the policy graph, sample actions and score the sampled actions.

    Args:
        *infeed_data: Positional tensors. The first four are treated as
            discrete observations, the next four as continuous observations,
            and the last one as the incoming LSTM state.

    Returns:
        For every sampled action, the log-softmax value of that action
        (the one-hot action mask multiplied by the log-softmax of the
        logits, summed over the action axis). Actions are reshaped to
        (args.batch_size, args.time_steps) before scoring.
    """
    discrete_inputs = list(infeed_data[:4])
    continuous_inputs = list(infeed_data[4:8])
    prev_state = infeed_data[-1]

    # One embedding table per discrete observation; each table is created
    # immediately before its lookup so variable/op creation order is fixed.
    with tf.variable_scope("popnn_lookup"):
        embedded = [
            embedding_ops.embedding_lookup(
                tf.get_variable(
                    f'emb_matrix{idx}',
                    [DIS_OBS_CARDINALITY[idx], DIS_OBS_EMB_SIZE[idx]],
                    DTYPE),
                discrete,
                name=f'emb_lookup{idx}')
            for idx, discrete in enumerate(discrete_inputs)
        ]

    # Only the last continuous observation is clipped, to [-5, 5].
    continuous_inputs[-1] = tf.clip_by_value(
        continuous_inputs[-1], -5.0, 5.0, name="clip")

    # Pair each embedded observation with its continuous counterpart and
    # join them along axis 3.
    # NOTE(review): this implies rank-4 inputs -- confirm the axis meaning
    # against the infeed definition.
    grouped = [
        tf.concat([emb, cont], axis=3, name="concat_obs")
        for emb, cont in zip(embedded, continuous_inputs)
    ]

    # Only the last group goes through a fully connected layer.
    dense_width = 8
    grouped[-1] = Dense(dense_width, dtype=DTYPE)(grouped[-1])

    # Collapse axis 2 of every group with a max, then join all groups.
    pooled = [tf.reduce_max(group, axis=2) for group in grouped]
    sequence = tf.concat(pooled, axis=2, name="concat_all")

    # PopnnLSTM uses time-major tensors, so swap the two leading axes
    # before the layer and swap them back afterwards.
    sequence = tf.transpose(
        sequence, perm=[1, 0, 2], name="pre_lstm_transpose")
    recurrent = rnn_ops.PopnnLSTM(
        num_units=LSTM_HIDDEN_SIZE,
        dtype=DTYPE,
        partials_dtype=DTYPE,
        name="lstm",
    )
    # The incoming state packs both LSTM state tensors along axis 1.
    start_state = tf.nn.rnn_cell.LSTMStateTuple(prev_state[:, 0],
                                                prev_state[:, 1])
    # The returned state is unpacked but not used further here.
    lstm_output, state_out = recurrent(
        sequence, training=True, initial_state=start_state)
    lstm_output = tf.transpose(
        lstm_output, perm=[1, 0, 2], name="post_lstm_transpose")

    logits = Dense(NUM_ACTIONS, name="logits", dtype=DTYPE)(lstm_output)
    log_prob = tf.nn.log_softmax(logits, name="prob")

    # Sample one integer action per row of the flattened logits, then
    # restore the (batch, time) layout.
    flat_logits = tf.reshape(logits, (-1, NUM_ACTIONS))
    sampled = tf.random.categorical(logits=flat_logits, num_samples=1)
    sampled = tf.reshape(sampled, (args.batch_size, args.time_steps))

    # The one-hot mask selects the log-softmax entry of each sampled action.
    chosen_mask = tf.one_hot(sampled, NUM_ACTIONS, dtype=DTYPE)
    return tf.reduce_sum(chosen_mask * log_prob, axis=-1)