예제 #1
0
def common_layers():
  """Return the shared layer stack: optional preprocessing, then an MLP body.

  If ``FLAGS.flatten_non_batch_time_dims`` is truthy, the stack starts with a
  ``layers.Div(divisor=255.0)`` layer followed by
  ``layers.Flatten(num_axis_to_keep=2)``. The body is always two
  ``layers.Dense(64)`` layers, each followed by ``layers.Tanh()``.

  Returns:
    A freshly built list of layer objects.
  """
  stack = []
  if FLAGS.flatten_non_batch_time_dims:
    # NOTE(review): dividing by 255.0 presumably rescales 8-bit pixel
    # observations; confirm the input range with the caller.
    stack.append(layers.Div(divisor=255.0))
    # Keep the first two axes and flatten the rest into one.
    stack.append(layers.Flatten(num_axis_to_keep=2))
  # Two hidden Dense(64) + Tanh pairs, created in the same order as before.
  for _ in range(2):
    stack.append(layers.Dense(64))
    stack.append(layers.Tanh())
  return stack
예제 #2
0
def common_layers():
  """Return the shared layer stack used for the current environment problem.

  If ``FLAGS.env_problem_name`` contains ``"NoFrameskip"``, the result of
  ``atari_layers()`` is returned unchanged. Otherwise the stack is an optional
  preprocessing prefix followed by an MLP body:

  * prefix (only when ``FLAGS.flatten_dims`` is truthy):
    ``layers.Div(divisor=255.0)``, then ``layers.Flatten(num_axis_to_keep=2)``;
  * body: ``Dense(64)``, ``Tanh``, ``Dense(64)``, ``Tanh``.

  Returns:
    A list of layer objects (or whatever ``atari_layers()`` returns).
  """
  # TODO(afrozm): Refactor.
  if "NoFrameskip" in FLAGS.env_problem_name:
    return atari_layers()

  # NOTE(review): the 255.0 divisor presumably normalizes 8-bit pixel values;
  # confirm against the observation format.
  prefix = (
      [layers.Div(divisor=255.0), layers.Flatten(num_axis_to_keep=2)]
      if FLAGS.flatten_dims
      else []
  )
  hidden_stack = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  return [*prefix, *hidden_stack]
예제 #3
0
def common_layers():
  """Return the shared layer stack used for the current environment problem.

  If ``FLAGS.env_problem_name`` contains ``"NoFrameskip"``, the result of
  ``atari_layers()`` is returned unchanged. Otherwise the stack is a plain MLP:
  ``Dense(64)``, ``Tanh``, ``Dense(64)``, ``Tanh``.

  Returns:
    A list of layer objects (or whatever ``atari_layers()`` returns).
  """
  # TODO(afrozm): Refactor.
  if "NoFrameskip" not in FLAGS.env_problem_name:
    mlp = []
    for _ in range(2):
      mlp.extend((layers.Dense(64), layers.Tanh()))
    return mlp
  return atari_layers()