Exemplo n.º 1
0
def build_fully_conv(obs_spec,
                     act_spec,
                     data_format='channels_first',
                     broadcast_non_spatial=False,
                     fc_dim=256):
    """Assemble the fully-convolutional actor-critic Keras model.

    The first two observation spaces are run through ``spatial_block`` (as
    'screen' and 'minimap'); every remaining observation space gets a plain
    ``Input`` of its shape. The spatial features are concatenated into a
    shared state block from which one logits head per action space and a
    scalar value head are derived.

    :param obs_spec: observation spec exposing ``.spaces``; each space must
        have a ``.shape``.
    :param act_spec: iterable of action spaces exposing ``is_spatial()`` and
        ``size()``.
    :param data_format: forwarded to ``conv_cfg``.
        NOTE(review): the concatenation axis is hard-coded to 1, which
        presumably only matches 'channels_first' - confirm before using
        another format.
    :param broadcast_non_spatial: when true, the second non-spatial input is
        passed through ``Log`` and ``Broadcast2D`` and appended to the state
        block.
    :param fc_dim: width of the dense layer applied to the flattened state.
    :return: ``Model`` whose inputs are the screen, minimap and non-spatial
        inputs and whose outputs are the per-space logits followed by the
        value tensor.
    """
    # Debug output of the specs the network is built for.
    print('obs_spec', obs_spec)
    print('act_spec', act_spec)

    screen_space, minimap_space = obs_spec.spaces[0], obs_spec.spaces[1]
    screen, screen_input = spatial_block(
        'screen', screen_space, conv_cfg(data_format, 'relu'))
    minimap, minimap_input = spatial_block(
        'minimap', minimap_space, conv_cfg(data_format, 'relu'))

    non_spatial_inputs = [Input(space.shape) for space in obs_spec.spaces[2:]]

    # Collect everything that goes into the shared state block.
    state_parts = [screen, minimap]
    if broadcast_non_spatial:
        non_spatial = non_spatial_inputs[1]
        spatial_dim = obs_spec.spaces[0].shape[1]
        non_spatial = Log()(non_spatial)
        state_parts.append(Broadcast2D(spatial_dim)(non_spatial))
    state = Concatenate(axis=1, name="state_block")(state_parts)

    flat_state = Flatten(name="state_flat")(state)
    fc = Dense(fc_dim, **dense_cfg('relu'))(flat_state)

    # Critic head: a single unit with the trailing axis squeezed away.
    value = Squeeze(axis=-1)(
        Dense(1, name="value_out", **dense_cfg(scale=0.1))(fc))

    # Actor heads: 1x1 conv over the state for spatial spaces, dense otherwise.
    logits = []
    for space in act_spec:
        if space.is_spatial():
            head = Conv2D(1, 1, **conv_cfg(data_format, scale=0.1))(state)
            head = Flatten()(head)
        else:
            head = Dense(space.size(), **dense_cfg(scale=0.1))(fc)
        logits.append(head)

    # Entries of the first head whose flag in the first non-spatial input is
    # not positive are replaced by -1000.
    mask_actions = Lambda(
        lambda x: tf.where(non_spatial_inputs[0] > 0, x,
                           -1000 * tf.ones_like(x)),
        name="mask_unavailable_action_ids")
    logits[0] = mask_actions(logits[0])

    return Model(
        inputs=[screen_input, minimap_input, *non_spatial_inputs],
        outputs=[*logits, value])
Exemplo n.º 2
0
    def __init__(self, act_spec, logits):
        """
        Build an auto-regressive policy: the function id is sampled first and
        an embedding of that sample conditions every argument head.

        :param act_spec: action spec; ``act_spec.spaces[0]`` is the
            function-id space (it must expose ``args_mask``), the remaining
            spaces are the argument spaces.
        :param logits: three-element sequence
            ``[shared_features, policy_logits, relational_spatial]``.
            Per the original author's notes, policy_logits is already masked
            and relational_spatial is the _deconv4x output,
            (32, 32, channel_3) - shapes not verifiable from here.
        """
        # NOTE: the parent initializer call is disabled here:
        # super().__init__(act_spec, logits)

        shared_features, policy_logits, relational_spatial = logits
        function_space = act_spec.spaces[0]
        argument_spaces = act_spec.spaces[1:]  # [0]: function_id

        policy_dist = self.make_dist(function_space, policy_logits)
        policy_sample = policy_dist.sample()

        # Embed the sampled function id; author's note: [30, 16].
        n_functions = int(policy_logits.shape[-1])
        action_embed = Embedding(n_functions, 16)(policy_sample)

        # Tile the embedding over the spatial map and stack it onto the
        # spatial features; author's note: [30, 16+16, 32, 32].
        spatial_concat = Concatenate(axis=1)
        tiled_embed = Broadcast2D(
            size=relational_spatial.shape[-1])(action_embed)
        action_logits = spatial_concat([relational_spatial, tiled_embed])

        # Same conditioning for the flat features; author's note:
        # [30, 576+16=592].
        conditioned_shared_feature = Concatenate(axis=-1)(
            [shared_features, action_embed])

        data_format = 'channels_first'  # TODO: refactor me
        argument_logits = []
        for space in argument_spaces:
            if space.is_spatial():
                spatial_head = Conv2D(
                    1, 1, **conv_cfg(data_format, scale=0.1))(action_logits)
                argument_logits.append(Flatten()(spatial_head))
            else:
                dense_head = Dense(space.size(), **dense_cfg(scale=0.1))
                argument_logits.append(
                    dense_head(conditioned_shared_feature))
        self.logits = [policy_logits, *argument_logits]

        # The function-id distribution is reused so that its sample is the
        # one the argument heads were conditioned on.
        self.dists = [policy_dist]
        for space, head in zip(argument_spaces, argument_logits):
            self.dists.append(self.make_dist(space, head))

        self.entropy = sum(dist.entropy() for dist in self.dists)
        self.sample = [policy_sample]
        self.sample.extend(dist.sample() for dist in self.dists[1:])

        # Author's note: args_mask is (23, 11).
        args_mask = tf.constant(function_space.args_mask, dtype=tf.float32)
        self.inputs = [
            tf.placeholder(space.dtype, [None, *space.shape])
            for space in act_spec
        ]

        # Select the mask row of each taken function id, then transpose so
        # that the argument index comes first.
        taken_function = self.inputs[0]  # masked action_id
        act_args_mask = tf.transpose(
            tf.gather(args_mask, taken_function), [1, 0])

        # Log-likelihood of the taken action: the function-id term plus each
        # argument term weighted by its mask entry.
        self.logli = self.dists[0].log_prob(taken_function)
        taken_arguments = zip(self.dists[1:], self.inputs[1:])
        for arg_idx, (dist, taken_arg) in enumerate(taken_arguments):
            self.logli += act_args_mask[arg_idx] * dist.log_prob(taken_arg)
Exemplo n.º 3
0
def build_relational(obs_spec, act_spec, data_format='channels_first', broadcast_non_spatial=False, fc_dim=256):
    """Build the relational actor-critic Keras model.

    Observation spaces are consumed positionally: ``spaces[0]`` and
    ``spaces[1]`` go through ``spatial_block`` as 'screen' and 'minimap',
    ``spaces[2]`` (available actions) is only used to mask the function-id
    logits, and ``spaces[3:]`` are the remaining non-spatial inputs.

    :param obs_spec: observation spec exposing ``.spaces``; each space must
        have a ``.shape``.
    :param act_spec: action spec exposing ``.spaces``; ``spaces[0]`` is the
        function-id space and its ``size()`` sets the width of the policy
        logits.
    :param data_format: forwarded to ``conv_cfg`` / ``deconv_cfg``.
        NOTE(review): concatenation axes are hard-coded to 1, which
        presumably only matches 'channels_first' - confirm.
    :param broadcast_non_spatial: when true, the first non-spatial input after
        available_actions is passed through ``Log`` and ``Broadcast2D`` and
        stacked onto the spatial state block.
    :param fc_dim: currently unused by this builder; kept for signature
        parity with the other model builders.
    :return: ``Model`` with inputs
        ``[screen, minimap, available_actions, *non_spatial]`` and outputs
        ``[policy_logits, relational_spatial, value]``.
    """
    # https://github.com/deepmind/pysc2/blob/master/docs/environment.md#last-actions
    # obs_spec: screen, minimap, player (11,), last_actions (n,)
    # At each time step agents are presented with 4 sources of information:
    # minimap, screen, player, and previous-action.

    # TODO: set spatial_dim <- 64
    screen, screen_input = spatial_block('screen', obs_spec.spaces[0], conv_cfg(data_format, 'relu'))
    minimap, minimap_input = spatial_block('minimap', obs_spec.spaces[1], conv_cfg(data_format, 'relu'))

    # TODO: obs_spec[2:] <- ['available_actions', 'player', 'previous_action']
    # 'non-spatial': ['available_actions', 'player']
    available_actions = Input(obs_spec.spaces[2].shape)
    non_spatial_inputs_list = [Input(s.shape) for s in obs_spec.spaces[3:]]
    if len(non_spatial_inputs_list) == 1:
        # Some Keras versions reject Concatenate on a single-element list, and
        # the reference obs_spec below has only `player` left at this point.
        non_spatial_inputs = non_spatial_inputs_list[0]
    else:
        non_spatial_inputs = Concatenate(axis=1, name='non_spatial_inputs')(non_spatial_inputs_list)
    input_2d = _mlp2(Flatten()(non_spatial_inputs), units=[128, 64], cfg=dense_cfg('relu'))  # [64, ]

    if broadcast_non_spatial:
        # Take the input layer from the list: indexing the concatenated tensor
        # would slice along the batch axis. available_actions has already been
        # split off, so index 0 here is obs_spec.spaces[3] (the space that
        # build_fully_conv reaches as non_spatial_inputs[1]).
        non_spatial, spatial_dim = non_spatial_inputs_list[0], obs_spec.spaces[0].shape[1]
        non_spatial = Log()(non_spatial)
        broadcasted_non_spatial = Broadcast2D(spatial_dim)(non_spatial)
        input_3d = Concatenate(axis=1, name="state_block")([screen, minimap, broadcasted_non_spatial])
    else:
        # concat along channel dim: [N, C, H, W]
        # input_3d: [8, 8, channel_1+channel_1]
        input_3d = Concatenate(axis=1, name="state_block")([screen, minimap])

    # TODO: state -> Conv2DLSTM -> output_3d: [8, 8, channel_2]
    # TODO: treat channel_x as parameters or read from configuration files
    # NOTE(review): Keras ConvLSTM2D normally expects a 5-D input with a time
    # axis - confirm that input_3d / conv2dlstm_cfg() account for this.
    channel_2 = 96
    output_3d = ConvLSTM2D(filters=channel_2, kernel_size=3, **conv2dlstm_cfg())(input_3d)
    relational_spatial = _resnet12(output_3d, filters=[64, 48, 32, 32], cfg=conv_cfg(data_format, 'relu'))  # [8, 8, 32]
    channel_3 = 16
    relational_spatial = _deconv4x(relational_spatial, filters=[channel_3, channel_3],
                                   kernel_sizes=[4, 4], cfg=deconv_cfg(data_format, 'relu'))  # [32, 32, channel_3]

    # TODO: tile action embedding, 1x1x1 conv
    # embed_sz = 16
    # [32, 32, channel_3+embed_sz]

    # TODO: check scale factor
    relational_nonspatial = _mlp2(Flatten()(output_3d), units=[512, 512], cfg=dense_cfg('relu'))  # [512, ]

    shared_features = Concatenate(axis=1, name='shared_features')([relational_nonspatial, input_2d])  # [512+64, ]
    # NOTE(review): unlike build_fully_conv, the value output is not squeezed
    # here - confirm that downstream code expects a trailing axis of size 1.
    value = _mlp2(shared_features, units=[256, 1], cfg=dense_cfg('relu', scale=0.1))  # TODO: check scale_factor

    # One logit per function id. The logits are masked element-wise by
    # available_actions below, so the width must be the size of the
    # function-id space rather than the number of action spaces.
    n_function_ids = act_spec.spaces[0].size()
    policy_logits = _mlp2(shared_features, units=[256, n_function_ids], cfg=dense_cfg('relu', scale=0.1))

    # sensible action set for all minigames
    # action_ids = [0, 1, 2, 3, 4, 6, 7, 12, 13, 42, 44, 50, 91, 183, 234, 309, 331, 332, 333, 334, 451, 452, 490]

    # Reference specs captured from an interactive session:
    # env.obs_spec()
    # Out[5]:
    # Spec: Observation
    # Space(screen
    # {player_relative, selected, visibility_map, unit_hit_points_ratio, unit_density}, (5, 16, 16), numpy.int32)
    # Space(minimap
    # {player_relative, selected, visibility_map, camera}, (4, 16, 16), numpy.int32)
    # Space(available_actions, (23,), numpy.int32)
    # Space(player, (11,), numpy.int32)

    # env.act_spec()
    # Out[6]:
    # Spec: Action
    # Space(function_id, (), cat: 23, numpy.int32)
    # Space(screen, (), cat: (16, 16), numpy.int32)
    # Space(minimap, (), cat: (16, 16), numpy.int32)
    # Space(screen2, (), cat: (16, 16), numpy.int32)
    # Space(queued, (), cat: 2, numpy.int32)
    # Space(control_group_act, (), cat: 5, numpy.int32)
    # Space(control_group_id, (), cat: 10, numpy.int32)
    # Space(select_add, (), cat: 2, numpy.int32)
    # Space(select_point_act, (), cat: 4, numpy.int32)
    # Space(select_unit_act, (), cat: 4, numpy.int32)
    # Space(select_worker, (), cat: 4, numpy.int32)
    # Space(build_queue_id, (), cat: 10, numpy.int32)

    # TODO: check 1x1x1 conv & scale_factor

    # Function-id logits whose available_actions flag is not positive are
    # replaced by -1000.
    mask_actions = Lambda(
        lambda x: tf.where(available_actions > 0, x, -1000 * tf.ones_like(x)),
        name="mask_unavailable_action_ids"
    )
    # function_id, shape = (23, ) (#action_ids)
    policy_logits = mask_actions(policy_logits)

    # TODO: action embedding, matmul -> transformed_logits

    return Model(
        inputs=[screen_input, minimap_input] + [available_actions] + non_spatial_inputs_list,
        outputs=[policy_logits, relational_spatial, value]
    )