Ejemplo n.º 1
0
def build_cnn_nature(obs_spec,
                     act_spec,
                     data_format='channels_first',
                     value_separate=False):
    """Build an actor-critic model on top of the "Nature" convolutional torso.

    Every observation space becomes an ``Input``. If there are several, they are
    concatenated on the last axis. The result is optionally transposed to
    channels-first, rescaled by 1/255 and fed to ``build_cnn``.

    Args:
        obs_spec: iterable of observation spaces exposing ``shape`` and ``name``.
        act_spec: iterable of action spaces exposing ``size()`` and ``name``.
        data_format: conv data format forwarded to every conv layer.
        value_separate: when true, the value head gets its own CNN torso
            (prefix ``value_``) instead of sharing the policy torso.

    Returns:
        A ``Model`` whose outputs are one logits tensor per action space,
        followed by the squeezed value estimate.
    """
    layer_kwargs = dict(padding='same', data_format=data_format, activation='relu')
    # Presumably (filters, kernel_size, strides) per conv layer; confirm against build_cnn.
    conv_layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]

    obs_inputs = [Input(space.shape, name="input_" + space.name) for space in obs_spec]
    if len(obs_inputs) > 1:
        merged = Concatenate()(obs_inputs)
    else:
        merged = obs_inputs[0]

    # Channels-first was requested, but observations arrive as NxHxWxC.
    # NOTE(review): the `shape[1] > 3` test is a heuristic. A channels-first
    # input with more than 3 channels (e.g. 4 stacked frames) would also be
    # transposed. Confirm the expected observation layout.
    needs_transpose = data_format == 'channels_first' and merged.shape[1] > 3
    if needs_transpose:
        merged = Transpose([0, 3, 1, 2])(merged)

    scaled = Rescale(1. / 255)(merged)

    policy_features = build_cnn(scaled, conv_layers, layer_kwargs, dense=512, prefix='policy_')
    heads = [Dense(space.size(), name="logits_" + space.name)(policy_features) for space in act_spec]

    value_features = policy_features
    if value_separate:
        value_features = build_cnn(scaled, conv_layers, layer_kwargs, dense=512, prefix='value_')

    value_head = Dense(1, name="value_out")(value_features)
    heads.append(Squeeze(axis=-1)(value_head))

    return Model(inputs=obs_inputs, outputs=heads)
Ejemplo n.º 2
0
def spatial_block(name, space, cfg, batch_size):
    """Encode one stacked spatial observation (e.g. screen or minimap).

    Processing steps:
    - Split the input along axis 1 into one plane per feature.
    - Categorical planes (``dim > 1``) are embedded into 10 channels and
      transposed to channels-first.
    - Numerical planes go through ``Log``.
    - All planes are concatenated along axis 1 and passed through two
      ``_residual_block`` calls. Their filter count equals the number of
      feature planes.

    Args:
        name: prefix for the input layer name (``<name>_input``).
        space: spatial space exposing ``shape``, ``spatial_feats`` and
            ``spatial_dims``.
        cfg: layer configuration forwarded to ``_residual_block``.
        batch_size: batch size passed to the ``Input`` layer (may be None).

    Returns:
        Tuple ``(encoded_block, input_layer)``.
    """
    # TODO: tile spatial features with binary masks
    inpt = Input(space.shape, name=name + '_input', batch_size=batch_size)
    # Split behaves like tf.split(x, num_splits, axis=axis) -> list of planes.
    planes = Split(space.shape[0], axis=1)(inpt)

    # Input preprocessing, one feature plane at a time.
    feature_dims = zip(space.spatial_feats, space.spatial_dims)
    for idx, (_feat_name, feat_dim) in enumerate(feature_dims):
        if feat_dim > 1:
            # Categorical feature: embed into a continuous 10-dimensional space.
            plane = Squeeze(axis=1)(planes[idx])
            plane = Embedding(input_dim=feat_dim, output_dim=10)(plane)
            planes[idx] = Transpose([0, 3, 1, 2])(plane)  # [N, H, W, C] -> [N, C, H, W]
            continue
        # Numerical feature: logarithmic re-scaling.
        planes[idx] = Log()(planes[idx])

    # State encoding
    # TODO: determine channel_1
    n_filters = len(planes)
    state = Concatenate(axis=1)(planes)  # concatenate along channel dimension -> [N, C, H, W]
    # TODO: adapt to different resolutions (64x64)
    # Per the author's notes each residual block halves the resolution: 32x32 -> 16x16 -> 8x8.
    for _ in range(2):
        state = _residual_block(state, filters=n_filters, cfg=cfg)

    return state, inpt
Ejemplo n.º 3
0
def build_mlp(
        obs_spec: Spec,
        act_spec: Spec,
        layer_sizes=(64, 64),
        activation='relu',
        initializer='glorot_uniform',
        value_separate=False,
        obs_shift=False,
        obs_scale=False) -> tf.keras.Model:
    """
    Factory for a simple fully connected actor-critic model, e.g. for MuJoCo environments.

    If value_separate is true, the value head gets its own fully connected path;
    otherwise it branches off the last shared layer.
    If obs_shift is true, observations are shifted to zero mean using a running mean estimate.
    If obs_scale is true, observations are scaled to unit std. dev. using a running std. dev. estimate.

    Returns a ``tf.keras.Model`` whose outputs are one logits tensor per action
    space followed by the squeezed value estimate.
    """
    raw_inputs = [Input(space.shape, name="input_" + space.name) for space in obs_spec]

    features_in = raw_inputs
    if obs_shift or obs_scale:
        features_in = [
            RunningStatsNorm(obs_shift, obs_scale, name="norm_" + space.name)(tensor)
            for space, tensor in zip(obs_spec, raw_inputs)
        ]

    if len(features_in) > 1:
        merged = Concatenate()(features_in)
    else:
        merged = features_in[0]

    hidden = build_fc(merged, layer_sizes, activation, initializer)
    heads = [build_logits(space, hidden, initializer) for space in act_spec]

    value_hidden = hidden
    if value_separate:
        value_hidden = build_fc(merged, layer_sizes, activation, initializer, 'value_')

    value_head = Dense(1, name="value_out", kernel_initializer=initializer)(value_hidden)
    heads.append(Squeeze(axis=-1)(value_head))

    return tf.keras.Model(inputs=raw_inputs, outputs=heads)
Ejemplo n.º 4
0
def build_fully_conv(obs_spec,
                     act_spec,
                     data_format='channels_first',
                     broadcast_non_spatial=False,
                     fc_dim=256):
    """Build a FullyConv actor-critic model for SC2-style observations.

    Args:
        obs_spec: observation spec. ``obs_spec.spaces`` is read as
            ``[screen, minimap, *non_spatial]``. The first non-spatial input is
            used as the action-availability mask for the first logits head.
        act_spec: iterable of action spaces. Spatial ones get a 1x1 conv head
            on the state block; the others get a dense head on the flat state.
        data_format: data format forwarded to the conv configs.
        broadcast_non_spatial: when true, the second non-spatial input is
            log-scaled, broadcast to the spatial resolution and concatenated
            into the state block.
        fc_dim: width of the shared dense layer.

    Returns:
        A ``Model`` with inputs ``[screen, minimap, *non_spatial]`` and outputs
        ``[*logits, value]``.
    """
    import logging

    # Debug output goes through logging instead of unconditional prints.
    logger = logging.getLogger(__name__)
    logger.debug('obs_spec %s', obs_spec)
    logger.debug('act_spec %s', act_spec)

    screen, screen_input = spatial_block('screen', obs_spec.spaces[0],
                                         conv_cfg(data_format, 'relu'))
    minimap, minimap_input = spatial_block('minimap', obs_spec.spaces[1],
                                           conv_cfg(data_format, 'relu'))

    non_spatial_inputs = [Input(s.shape) for s in obs_spec.spaces[2:]]

    if broadcast_non_spatial:
        # Presumably screen shape is (C, H, W), so shape[1] is its height; confirm.
        non_spatial, spatial_dim = non_spatial_inputs[1], obs_spec.spaces[
            0].shape[1]
        non_spatial = Log()(non_spatial)
        broadcasted_non_spatial = Broadcast2D(spatial_dim)(non_spatial)
        state = Concatenate(axis=1, name="state_block")(
            [screen, minimap, broadcasted_non_spatial])
    else:
        state = Concatenate(axis=1, name="state_block")([screen, minimap])

    fc = Flatten(name="state_flat")(state)
    fc = Dense(fc_dim, **dense_cfg('relu'))(fc)

    value = Dense(1, name="value_out", **dense_cfg(scale=0.1))(fc)
    value = Squeeze(axis=-1)(value)

    logits = []
    for space in act_spec:
        if space.is_spatial():
            logits.append(
                Conv2D(1, 1, **conv_cfg(data_format, scale=0.1))(state))
            logits[-1] = Flatten()(logits[-1])
        else:
            logits.append(Dense(space.size(), **dense_cfg(scale=0.1))(fc))

    # The availability mask is passed as an explicit Lambda input rather than
    # captured in the closure, so Keras tracks it as a graph input of the layer.
    # Unavailable action ids get a logit of -1000.
    mask_actions = Lambda(
        lambda xs: tf.where(xs[1] > 0, xs[0], -1000 * tf.ones_like(xs[0])),
        name="mask_unavailable_action_ids")
    logits[0] = mask_actions([logits[0], non_spatial_inputs[0]])

    return Model(inputs=[screen_input, minimap_input] + non_spatial_inputs,
                 outputs=logits + [value])
Ejemplo n.º 5
0
def build_mlp(obs_spec, act_spec, layer_sizes=(64, 64), activation='relu', value_separate=False, normalize_obs=False):
    """Build a fully connected actor-critic model.

    Args:
        obs_spec: iterable of observation spaces exposing ``shape`` and ``name``.
        act_spec: iterable of action spaces, each turned into a logits head via
            ``build_logits``.
        layer_sizes: hidden layer widths passed to ``build_fc``.
        activation: hidden layer activation passed to ``build_fc``.
        value_separate: when true, the value head gets its own ``value_``
            prefixed fully connected path.
        normalize_obs: when true, each input goes through ``RunningStatsNorm``.

    Returns:
        A ``tf.keras.Model`` whose outputs are the logits heads followed by the
        squeezed value estimate.
    """
    raw_inputs = [Input(space.shape, name="input_" + space.name) for space in obs_spec]

    features_in = raw_inputs
    if normalize_obs:
        features_in = [
            RunningStatsNorm(name="norm_" + space.name)(tensor)
            for space, tensor in zip(obs_spec, raw_inputs)
        ]

    if len(features_in) > 1:
        merged = Concatenate()(features_in)
    else:
        merged = features_in[0]

    hidden = build_fc(merged, layer_sizes, activation)
    heads = [build_logits(hidden, space) for space in act_spec]

    value_hidden = hidden
    if value_separate:
        value_hidden = build_fc(merged, layer_sizes, activation, 'value_')

    value_head = Dense(1, name="value_out")(value_hidden)
    heads.append(Squeeze(axis=-1)(value_head))

    return tf.keras.Model(inputs=raw_inputs, outputs=heads)
Ejemplo n.º 6
0
def build_mlp(obs_spec,
              act_spec,
              layer_sizes=(64, 64),
              activation='relu',
              value_separate=False):
    """Build a fully connected actor-critic model.

    Args:
        obs_spec: iterable of observation spaces exposing ``shape`` and ``name``.
        act_spec: iterable of action spaces exposing ``size()`` and ``name``.
        layer_sizes: hidden layer widths passed to ``build_fc``.
        activation: hidden layer activation passed to ``build_fc``.
        value_separate: when true, the value head gets its own ``value_``
            prefixed fully connected path.

    Returns:
        A ``Model`` whose outputs are one logits tensor per action space,
        followed by the squeezed value estimate.
    """
    obs_inputs = [Input(space.shape, name="input_" + space.name) for space in obs_spec]
    if len(obs_inputs) > 1:
        merged = Concatenate()(obs_inputs)
    else:
        merged = obs_inputs[0]

    hidden = build_fc(merged, layer_sizes, activation)
    heads = [Dense(space.size(), name="logits_" + space.name)(hidden) for space in act_spec]

    value_hidden = hidden
    if value_separate:
        value_hidden = build_fc(merged, layer_sizes, activation, 'value_')

    value_head = Dense(1, name="value_out")(value_hidden)
    heads.append(Squeeze(axis=-1)(value_head))

    return Model(inputs=obs_inputs, outputs=heads)
Ejemplo n.º 7
0
def spatial_block(name, space, cfg):
    """Encode one stacked spatial observation with two conv layers.

    Processing steps:
    - Split the input along axis 1 into one plane per feature.
    - Categorical planes (``dim > 1``) are embedded with width
      ``max(1, round(log2(dim)))`` and transposed to channels-first.
    - Numerical planes go through ``Log``.
    - The planes are concatenated along axis 1 and passed through
      Conv2D(16, 5) and Conv2D(32, 3) built with ``cfg``.

    Args:
        name: prefix for the input layer name (``<name>_input``).
        space: spatial space exposing ``shape``, ``spatial_feats`` and
            ``spatial_dims``.
        cfg: keyword arguments for both ``Conv2D`` layers.

    Returns:
        Tuple ``(encoded_block, input_layer)``.
    """
    inpt = Input(space.shape, name=name + '_input')
    planes = Split(space.shape[0], axis=1)(inpt)

    feature_dims = zip(space.spatial_feats, space.spatial_dims)
    for idx, (_feat_name, n_categories) in enumerate(feature_dims):
        if n_categories > 1:
            # Embedding width grows with log2 of the category count (at least 1).
            width = int(max(1, round(np.log2(n_categories))))
            squeezed = Squeeze(axis=1)(planes[idx])
            embedded = Embedding(input_dim=n_categories, output_dim=width)(squeezed)
            planes[idx] = Transpose([0, 3, 1, 2])(embedded)  # [N, H, W, C] -> [N, C, H, W]
        else:
            planes[idx] = Log()(planes[idx])

    encoded = Concatenate(axis=1)(planes)
    for filters, kernel_size in ((16, 5), (32, 3)):
        encoded = Conv2D(filters, kernel_size, **cfg)(encoded)

    return encoded, inpt
Ejemplo n.º 8
0
def spatial_block(name, space, cfg):
    """Encode one stacked spatial observation with two conv layers.

    Processing steps:
    - Split the input along axis 1 into one plane per feature.
    - Categorical planes (``dim > 1``) are embedded into 10 channels and
      transposed to channels-first.
    - Numerical planes go through ``Log``.
    - The planes are concatenated along axis 1 and passed through
      Conv2D(16, 5) and Conv2D(32, 3) built with ``cfg``.

    Args:
        name: prefix for the input layer name (``<name>_input``).
        space: spatial space exposing ``shape``, ``spatial_feats`` and
            ``spatial_dims``.
        cfg: keyword arguments for both ``Conv2D`` layers.

    Returns:
        Tuple ``(encoded_block, input_layer)``.
    """
    inpt = Input(space.shape, name=name + '_input')
    planes = Split(space.shape[0], axis=1)(inpt)

    feature_dims = zip(space.spatial_feats, space.spatial_dims)
    for idx, (_feat_name, n_categories) in enumerate(feature_dims):
        if n_categories > 1:
            squeezed = Squeeze(axis=1)(planes[idx])
            # Embedding dim 10 as per https://arxiv.org/pdf/1806.01830.pdf
            embedded = Embedding(input_dim=n_categories, output_dim=10)(squeezed)
            planes[idx] = Transpose([0, 3, 1, 2])(embedded)  # [N, H, W, C] -> [N, C, H, W]
        else:
            planes[idx] = Log()(planes[idx])

    encoded = Concatenate(axis=1)(planes)
    for filters, kernel_size in ((16, 5), (32, 3)):
        encoded = Conv2D(filters, kernel_size, **cfg)(encoded)

    return encoded, inpt
Ejemplo n.º 9
0
def build_relational(obs_spec,
                     act_spec,
                     data_format='channels_first',
                     broadcast_non_spatial=False):
    """Build a relational-style actor-critic model for SC2 observations.

    At each time step agents are presented with 4 sources of information:
    minimap, screen, player, and previous action. See
    https://github.com/deepmind/pysc2/blob/master/docs/environment.md#last-actions

    Args:
        obs_spec: observation spec. ``obs_spec.spaces`` is read as
            ``[screen, minimap, available_actions, *other_non_spatial]``.
            The first non-spatial input masks the policy logits; the rest are
            concatenated and fed to a small MLP.
        act_spec: action spec. Only the first space's ``size()`` is used, to
            size the policy logits.
        data_format: data format forwarded to the conv / deconv configs.
        broadcast_non_spatial: must be falsy; broadcasting is not supported
            for relational agents.

    Returns:
        A ``Model`` with inputs ``[screen, minimap, *non_spatial]`` and outputs
        ``[shared_features, policy_logits, relational_spatial, value]``.
    """
    # `not` rather than `is False`, so other falsy values (0, None) are accepted too.
    assert not broadcast_non_spatial, 'broadcast_non_spatial should be false for relational agents'

    batch_size = None
    # TODO: treat channel_x as parameters or read from configuration files
    channel_3 = 16

    # TODO: set spatial_dim <- 64
    screen, screen_input = spatial_block('screen',
                                         obs_spec.spaces[0],
                                         conv_cfg(data_format, 'relu'),
                                         batch_size=batch_size)
    minimap, minimap_input = spatial_block('minimap',
                                           obs_spec.spaces[1],
                                           conv_cfg(data_format, 'relu'),
                                           batch_size=batch_size)

    # TODO: obs_spec[2:] <- ['available_actions', 'player', 'last_actions']
    non_spatial_inputs_list = [
        Input(s.shape, batch_size=batch_size) for s in obs_spec.spaces[2:]
    ]
    available_actions = non_spatial_inputs_list[0]
    non_spatial_inputs = Concatenate(axis=1, name='non_spatial_inputs')(
        non_spatial_inputs_list[1:])

    # Shape comments below are the original author's annotations
    # (batch of 30, 32x32 spatial inputs).
    # input_2d: [30, 64], input_3d: [30, 9, 8, 8]
    input_2d = _mlp2(Flatten()(non_spatial_inputs),
                     units=[128, 64],
                     cfg=dense_cfg('relu'))
    input_3d = Concatenate(axis=1, name="state_block")([screen, minimap])

    # TODO: replace these two convolutions with a stateful ConvLSTM2D
    # (filters=96, kernel_size=3) once the unroll length is handled.
    output_3d = Conv2D(32, 3, **conv_cfg(data_format, 'relu'))(input_3d)
    output_3d = Conv2D(96, 3, **conv_cfg(data_format, 'relu'))(output_3d)

    # relational_spatial: [30, 32, 8, 8]
    relational_spatial = _resnet12(output_3d,
                                   filters=[64, 48, 32, 32],
                                   cfg=conv_cfg(data_format, 'relu'))
    # relational_spatial: [30, 16, 32, 32]
    relational_spatial = _deconv4x(relational_spatial,
                                   filters=[channel_3, channel_3],
                                   kernel_sizes=[4, 4],
                                   cfg=deconv_cfg(data_format, 'relu'))

    # TODO: check scale factor
    # relational_nonspatial: [30, 512]
    relational_nonspatial = _mlp2(Flatten()(output_3d),
                                  units=[512, 512],
                                  cfg=dense_cfg('relu'))

    # shared_features: [30, 512 + 64 = 576]
    shared_features = Concatenate(axis=1, name='shared_features')(
        [relational_nonspatial, input_2d])

    # NOTE(review): both heads below pass a 'relu' cfg to _mlp2. If _mlp2
    # applies it to its final layer too, value and logits are clamped to >= 0.
    # Confirm against _mlp2.
    # value: [30]
    value = _mlp2(shared_features,
                  units=[256, 1],
                  cfg=dense_cfg('relu', scale=0.1))
    value = Squeeze(axis=-1)(value)

    # policy_logits: [30, #actions] (549 in the author's setup)
    policy_logits = _mlp2(shared_features,
                          units=[512, list(act_spec)[0].size()],
                          cfg=dense_cfg('relu', scale=0.1))

    # The availability mask is passed as an explicit Lambda input rather than
    # captured in the closure, so Keras tracks it as a graph input of the layer.
    # Unavailable action ids get a logit of -1000.
    mask_actions = Lambda(
        lambda xs: tf.where(xs[1] > 0, xs[0], -1000 * tf.ones_like(xs[0])),
        name="mask_unavailable_action_ids")
    policy_logits = mask_actions([policy_logits, available_actions])

    # TODO: check
    return Model(
        inputs=[screen_input, minimap_input] + non_spatial_inputs_list,
        outputs=[shared_features, policy_logits, relational_spatial, value])