def build_fully_conv(obs_spec, act_spec, data_format='channels_first', broadcast_non_spatial=False, fc_dim=256):
    """Assemble the fully-convolutional actor-critic Keras model.

    The first two observation spaces are run through ``spatial_block``
    ('screen' and 'minimap'); every remaining observation space becomes a
    plain ``Input``. One logits head is produced per action space, followed
    by a single value output.

    :param obs_spec: observation spec; ``spaces[0]`` / ``spaces[1]`` feed the
        spatial blocks, ``spaces[2:]`` become the non-spatial inputs
    :param act_spec: iterable of action spaces exposing ``is_spatial()`` and
        ``size()``
    :param data_format: forwarded to ``conv_cfg``
    :param broadcast_non_spatial: when true, the second non-spatial input is
        log-transformed, tiled over the spatial grid and appended to the state
    :param fc_dim: width of the shared dense layer
    :return: ``Model`` with inputs ``[screen, minimap, *non_spatial]`` and
        outputs ``[*logits, value]``
    """
    print('obs_spec', obs_spec)
    print('act_spec', act_spec)

    screen_block, screen_input = spatial_block(
        'screen', obs_spec.spaces[0], conv_cfg(data_format, 'relu'))
    minimap_block, minimap_input = spatial_block(
        'minimap', obs_spec.spaces[1], conv_cfg(data_format, 'relu'))

    extra_inputs = []
    for obs_space in obs_spec.spaces[2:]:
        extra_inputs.append(Input(obs_space.shape))

    # The state is always screen + minimap, optionally extended with the
    # tiled non-spatial features; concatenation runs along axis 1.
    state_parts = [screen_block, minimap_block]
    if broadcast_non_spatial:
        flat_features = extra_inputs[1]
        spatial_dim = obs_spec.spaces[0].shape[1]
        flat_features = Log()(flat_features)
        state_parts.append(Broadcast2D(spatial_dim)(flat_features))
    state = Concatenate(axis=1, name="state_block")(state_parts)

    hidden = Flatten(name="state_flat")(state)
    hidden = Dense(fc_dim, **dense_cfg('relu'))(hidden)

    value_head = Dense(1, name="value_out", **dense_cfg(scale=0.1))(hidden)
    value_head = Squeeze(axis=-1)(value_head)

    heads = []
    for space in act_spec:
        if not space.is_spatial():
            heads.append(Dense(space.size(), **dense_cfg(scale=0.1))(hidden))
            continue
        # Spatial arguments: 1x1 conv over the state, flattened to a vector.
        score_map = Conv2D(1, 1, **conv_cfg(data_format, scale=0.1))(state)
        heads.append(Flatten()(score_map))

    # Entries of the first head whose counterpart in the first non-spatial
    # input is not > 0 are replaced by -1000 (presumably the availability
    # mask for action ids, going by the layer name).
    mask_actions = Lambda(
        lambda scores: tf.where(extra_inputs[0] > 0, scores, -1000 * tf.ones_like(scores)),
        name="mask_unavailable_action_ids")
    heads[0] = mask_actions(heads[0])

    return Model(
        inputs=[screen_input, minimap_input] + extra_inputs,
        outputs=heads + [value_head])
def __init__(self, act_spec, logits):
    """
    Build the action distributions, samples, entropy and log-likelihood
    for a policy whose argument heads are conditioned on the sampled
    function id.

    Shape notes carried over from the original author (not verifiable
    from this block): policy_logits is masked; relational_spatial is the
    _deconv4x output, (32, 32, channel_3).

    :param act_spec: action spec; ``spaces[0]`` is treated as the
        function_id space and must expose ``args_mask``
    :param logits: sequence unpacked as
        ``(shared_features, policy_logits, relational_spatial)``
    """
    # The base-class constructor is not invoked (the call was commented
    # out in the original): super().__init__(act_spec, logits)
    shared_features, policy_logits, relational_spatial = logits

    fn_space = act_spec.spaces[0]
    policy_dist = self.make_dist(fn_space, policy_logits)
    policy_sample = policy_dist.sample()

    # Embed the sampled function id; vocabulary size is the width of the
    # policy logits. (author's note: action_embed: [30, 16])
    n_functions = int(policy_logits.shape[-1])
    action_embed = Embedding(n_functions, 16)(policy_sample)

    # Spatial features joined with the tiled embedding along axis 1.
    # (author's note: action_logits: [30, 16+16, 32, 32])
    spatial_concat = Concatenate(axis=1)
    tiled_embed = Broadcast2D(size=relational_spatial.shape[-1])(action_embed)
    action_logits = spatial_concat([relational_spatial, tiled_embed])

    # (author's note: conditioned_shared_feature: [30, 576+16=592])
    conditioned_shared_feature = Concatenate(axis=-1)(
        [shared_features, action_embed])

    self.logits = [policy_logits]
    data_format = 'channels_first'  # TODO: refactor me
    # spaces[0] is function_id; only the argument spaces get new heads.
    for space in list(act_spec.spaces)[1:]:
        if space.is_spatial():
            head = Conv2D(1, 1, **conv_cfg(data_format, scale=0.1))(action_logits)
            head = Flatten()(head)
        else:
            head = Dense(space.size(), **dense_cfg(scale=0.1))(conditioned_shared_feature)
        self.logits.append(head)

    arg_dists = [
        self.make_dist(space, head)
        for space, head in zip(act_spec.spaces[1:], self.logits[1:])
    ]
    self.dists = [policy_dist] + arg_dists

    self.entropy = sum(dist.entropy() for dist in self.dists)
    # Reuse the function-id sample that the embedding was computed from.
    self.sample = [policy_sample] + [dist.sample() for dist in arg_dists]

    args_mask = tf.constant(fn_space.args_mask, dtype=tf.float32)  # author's note: (23, 11)
    self.inputs = [
        tf.placeholder(space.dtype, [None, *space.shape])
        for space in act_spec
    ]

    # Look up the mask row of each taken function id, then transpose so
    # that indexing by argument position yields one value per batch entry.
    taken_fn = self.inputs[0]
    per_arg_mask = tf.transpose(tf.gather(args_mask, taken_fn), [1, 0])

    self.logli = self.dists[0].log_prob(taken_fn)
    for i, dist in enumerate(self.dists[1:], start=1):
        self.logli += per_arg_mask[i - 1] * dist.log_prob(self.inputs[i])
def build_relational(obs_spec, act_spec, data_format='channels_first', broadcast_non_spatial=False, fc_dim=256):
    """Assemble the relational (ConvLSTM + residual) actor-critic Keras model.

    Observation layout expected by this builder (by position in
    ``obs_spec.spaces``): ``[0]`` screen, ``[1]`` minimap,
    ``[2]`` available_actions, ``[3:]`` remaining non-spatial features
    (e.g. player, previous action).
    See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#last-actions

    :param obs_spec: observation spec, laid out as described above
    :param act_spec: action spec; ``spaces[0]`` is the function_id space
    :param data_format: forwarded to ``conv_cfg`` / ``deconv_cfg``
    :param broadcast_non_spatial: when true, the first input built from
        ``obs_spec.spaces[3:]`` is log-transformed, tiled over the spatial
        grid and appended to the spatial state
    :param fc_dim: currently unused; kept so the signature matches
        ``build_fully_conv``
    :return: ``Model`` with inputs
        ``[screen, minimap, available_actions, *non_spatial]`` and outputs
        ``[policy_logits, relational_spatial, value]``
    """
    # TODO: set spatial_dim <- 64
    screen, screen_input = spatial_block('screen', obs_spec.spaces[0], conv_cfg(data_format, 'relu'))
    minimap, minimap_input = spatial_block('minimap', obs_spec.spaces[1], conv_cfg(data_format, 'relu'))

    # TODO: obs_spec[2:] <- ['available_actions', 'player', 'previous_action']
    available_actions = Input(obs_spec.spaces[2].shape)
    non_spatial_inputs_list = [Input(s.shape) for s in obs_spec.spaces[3:]]
    # NOTE(review): some Keras versions reject Concatenate on a list with a
    # single input - confirm obs_spec always carries >= 2 spaces from index 3.
    non_spatial_inputs = Concatenate(axis=1, name='non_spatial_inputs')(non_spatial_inputs_list)
    input_2d = _mlp2(Flatten()(non_spatial_inputs), units=[128, 64], cfg=dense_cfg('relu'))

    if broadcast_non_spatial:
        # Take the input layer itself: indexing the concatenated tensor
        # (as the previous code did) would slice along the batch axis.
        # List index 0 here is obs_spec.spaces[3], the same space that
        # build_fully_conv broadcasts.
        non_spatial, spatial_dim = non_spatial_inputs_list[0], obs_spec.spaces[0].shape[1]
        non_spatial = Log()(non_spatial)
        broadcasted_non_spatial = Broadcast2D(spatial_dim)(non_spatial)
        input_3d = Concatenate(axis=1, name="state_block")([screen, minimap, broadcasted_non_spatial])
    else:
        # Concatenate along the channel dim of [N, C, H, W].
        input_3d = Concatenate(axis=1, name="state_block")([screen, minimap])

    # TODO: treat channel_x as parameters or read from configuration files
    channel_2 = 96
    output_3d = ConvLSTM2D(filters=channel_2, kernel_size=3, **conv2dlstm_cfg())(input_3d)

    relational_spatial = _resnet12(output_3d, filters=[64, 48, 32, 32], cfg=conv_cfg(data_format, 'relu'))
    channel_3 = 16
    relational_spatial = _deconv4x(relational_spatial,
                                   filters=[channel_3, channel_3],
                                   kernel_sizes=[4, 4],
                                   cfg=deconv_cfg(data_format, 'relu'))
    # TODO: tile action embedding, 1x1x1 conv (check scale factor)

    relational_nonspatial = _mlp2(Flatten()(output_3d), units=[512, 512], cfg=dense_cfg('relu'))
    shared_features = Concatenate(axis=1, name='shared_features')([relational_nonspatial, input_2d])

    # TODO: check scale_factor
    # NOTE(review): unlike build_fully_conv the value output is not squeezed
    # here - confirm the consumer expects the trailing unit dimension.
    value = _mlp2(shared_features, units=[256, 1], cfg=dense_cfg('relu', scale=0.1))

    # The policy head scores function ids, so its width must equal the size
    # of the function_id space (and of the available_actions mask applied
    # below); len(act_spec) would be the number of action spaces instead.
    policy_logits = _mlp2(shared_features,
                          units=[256, act_spec.spaces[0].size()],
                          cfg=dense_cfg('relu', scale=0.1))

    # Function ids whose available_actions entry is not > 0 get logit -1000.
    mask_actions = Lambda(
        lambda x: tf.where(available_actions > 0, x, -1000 * tf.ones_like(x)),
        name="mask_unavailable_action_ids"
    )
    policy_logits = mask_actions(policy_logits)

    # TODO: action embedding, matmul -> transformed_logits
    return Model(
        inputs=[screen_input, minimap_input] + [available_actions] + non_spatial_inputs_list,
        outputs=[policy_logits, relational_spatial, value]
    )