Example #1
File: agent.py Project: leb2/starcraft-rl
    def step(self, states, masks, memory):
        """
        :param states: List of states of length batch_size. Here each state is a dict with keys:
            "unit_embeddings": numpy array with shape [num_units, embedding_size]
            "state": numpy array with shape [*state_shape]
            "unit_coords": numpy array with shape [num_units, >=2]; only the first two columns are read
        :param masks: numpy array of shape [batch_size, num_actions]
        :param memory: numpy array of shape [2, batch_size, memory_size], or None for the first step
        """
        if memory is None:
            memory = np.zeros((2, len(states), self.rnn_size))

        feed_dict = {
            **self.get_feed_dict(states, masks),
            self.memory_input: memory
        }
        results = self.session.run(
            [self.next_lstm_state, self.nonspacial_probs, self.unit_selection_probs,
             *self.spacial_probs_x, *self.spacial_probs_y], feed_dict)
        next_lstm_state, nonspacial_probs, selection_probs = results[:3]
        spacial_probs = results[3:]

        spacial_probs_x = spacial_probs[:self.num_screen_dims]
        spacial_probs_y = spacial_probs[self.num_screen_dims:]

        unit_coords = util.pad_stack([state['unit_coords'][:, :2] for state in states], pad_axis=0, stack_axis=0)
        sampled_action = self.sample_action_index_with_units(nonspacial_probs, spacial_probs_x, spacial_probs_y, selection_probs, unit_coords)

        log_prob = self.log_prob_numpy(sampled_action, nonspacial_probs, np.stack([spacial_probs_x, spacial_probs_y]), selection_probs)
        # TODO: check for unit embeddings
        return sampled_action, next_lstm_state, log_prob
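
A minimal usage sketch for this step method, assuming an already-constructed agent instance (here called agent) and placeholder sizes that are not taken from the snippet; the dict keys mirror what the docstring and the unit_coords lookup expect, and memory=None triggers the zero-initialized LSTM state.

import numpy as np

# Placeholder sizes; the real values come from the agent's interface and network config.
batch_size, num_units, embedding_size = 4, 8, 32
state_shape, num_actions = (64, 64), 10

states = [{
    'unit_embeddings': np.zeros((num_units, embedding_size)),
    'state': np.zeros(state_shape),
    'unit_coords': np.zeros((num_units, 3)),  # step() only reads the first two columns
} for _ in range(batch_size)]
masks = np.ones((batch_size, num_actions))

# First step with memory=None (zero-initialized LSTM state), then feed the state back.
sampled_action, lstm_state, log_prob = agent.step(states, masks, memory=None)
sampled_action, lstm_state, log_prob = agent.step(states, masks, lstm_state)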
Example #2
    def step(self, states, masks, memories):
        """
        :param states: List of states of length batch_size. Here each state is a dict with keys:
            "unit_embeddings": numpy array with shape [num_units, embedding_size]
            "state": numpy array with shape [*state_shape]
        :param masks: numpy array of shape [batch_size, num_actions]
        :param memories: List of per-state memory dicts with keys 'next_lstm_state' and 'prev_action',
            or None for the first step
        :return: (used_states, sampled_action, next_memories); used_states are the states that were
            actually fed to the network, including the modifications made for memory
        """
        if memories is None:
            memories = []
            for _ in range(len(states)):
                memories.append({
                    'next_lstm_state': np.zeros((2, len(states), self.rnn_size)),
                    'prev_action': np.zeros(self.interface.num_actions)
                })
        for state, memory in zip(states, memories):
            state['prev_action'] = memory['prev_action']

        if self.use_lstm:  # assumption: the flag lives on the agent; a bare `use_lstm` is undefined here
            # `memory` is the loop variable left over from the zip above, i.e. the last
            # item's memory dict (see the memory TODO further down).
            feed_dict = {
                **self.get_feed_dict(states, masks),
                self.memory_input: memory['next_lstm_state']
            }
        else:
            feed_dict = self.get_feed_dict(states, masks)
        results = self.session.run([
            self.next_lstm_state, self.nonspacial_probs,
            self.unit_selection_probs, *self.spacial_probs_x,
            *self.spacial_probs_y
        ], feed_dict)
        next_lstm_state, nonspacial_probs, selection_probs = results[:3]
        spacial_probs = results[3:]

        spacial_probs_x = spacial_probs[:self.num_screen_dims]
        spacial_probs_y = spacial_probs[self.num_screen_dims:]

        unit_coords = util.pad_stack(
            [state['unit_coords'][:, :2] for state in states],
            pad_axis=0,
            stack_axis=0)
        sampled_action = self.sample_action_index_with_units(
            nonspacial_probs, spacial_probs_x, spacial_probs_y,
            selection_probs, unit_coords)
        next_memories = []
        # TODO: to reimplement memory, you will probably have to np.stack, since memory is now
        # a list of dictionaries with a separate state for each item in the batch.
        for i in range(len(states)):
            next_memory = {
                'next_lstm_state': next_lstm_state,
                'prev_action': np.zeros(self.interface.num_actions)
            }
            next_memory['prev_action'][sampled_action[i][0]] = 1
            next_memories.append(next_memory)
        return states, sampled_action, next_memories
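
A hedged sketch of the per-item memory format this version threads through a rollout loop; agent, states, and masks are assumed to already exist, and the sizes are placeholders rather than values from the snippet.

import numpy as np

# Placeholder sizes; in the real agent these come from self.interface and self.rnn_size.
batch_size, num_actions, rnn_size = 4, 10, 256

# Each batch item gets its own memory dict, matching the initialization inside step().
memories = [{
    'next_lstm_state': np.zeros((2, batch_size, rnn_size)),
    'prev_action': np.zeros(num_actions),
} for _ in range(batch_size)]

# Rollout loop: memories=None on the first call, then the returned dicts are fed back.
used_states, sampled_action, memories = agent.step(states, masks, memories=None)
used_states, sampled_action, memories = agent.step(states, masks, memories)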
Example #3
    def get_feed_dict(self, states, masks, actions=None, bootstrap_state=None):
        screens = np.stack([state['screen'] for state in states], axis=0)

        feed_dict = {
            self.mask_input: np.array(masks),
        }
        all_states = states if bootstrap_state is None else [
            *states, bootstrap_state
        ]
        unit_embeddings = util.pad_stack(
            [state['unit_embeddings'] for state in all_states],
            pad_axis=0,
            stack_axis=0)

        other_features = np.stack(
            [state['other_features'] for state in all_states], axis=0)
        prev_action_onehot = np.stack(
            [state['prev_action'] for state in all_states], axis=0)

        all_other_features = np.concatenate(
            [other_features, prev_action_onehot], axis=-1)
        feed_dict[self.other_features_input] = all_other_features
        feed_dict[self.unit_embeddings_input] = unit_embeddings

        if bootstrap_state is not None:
            bootstrap_screen = np.expand_dims(bootstrap_state['screen'],
                                              axis=0)
            feed_dict[self.state_input] = np.concatenate(
                [screens, bootstrap_screen], axis=0)
        else:
            feed_dict[self.state_input] = screens

        if actions is not None:
            nonspacial, spacials, selection_coords, selection_indices = zip(
                *actions)
            spacials = [(13, 27) if spacial is None else spacial
                        for spacial in spacials]
            selections = [
                -1 if selection is None else selection
                for selection in selection_indices
            ]
            feed_dict[self.action_input] = np.array(nonspacial)
            feed_dict[self.spacial_input] = np.array(spacials)
            feed_dict[self.unit_selection_input] = np.array(selections)
        return feed_dict
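
A hedged sketch of the inputs this get_feed_dict expects, with placeholder shapes; agent is assumed to be a built instance, and the action tuples follow the (nonspacial, spacial, selection_coords, selection_index) layout that the method unpacks.

import numpy as np

# Placeholder shapes; the real ones come from the interface/network definition.
batch_size, num_actions = 2, 10
screen_shape, num_units, embedding_size, num_other = (64, 64, 17), 8, 32, 11

states = [{
    'screen': np.zeros(screen_shape),
    'unit_embeddings': np.zeros((num_units, embedding_size)),
    'other_features': np.zeros(num_other),
    'prev_action': np.zeros(num_actions),
} for _ in range(batch_size)]
masks = np.ones((batch_size, num_actions))

# (nonspacial, spacial, selection_coords, selection_index); None spacials fall back
# to (13, 27) and None selection indices to -1 inside get_feed_dict.
actions = [(3, (20, 41), None, 5), (0, None, None, None)]

feed_dict = agent.get_feed_dict(states, masks, actions=actions)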
Example #4
    def get_feed_dict(self, states, masks, actions=None, bootstrap_state=None):
        screens = np.stack([state['screen'] for state in states], axis=0)
        feed_dict = {
            # state_input is set below from the stacked screens; `states` is a list of
            # dicts, so np.array(states) would not produce a valid tensor here.
            self.mask_input: np.array(masks),
        }
        all_states = states if bootstrap_state is None else [
            *states, bootstrap_state
        ]
        unit_embeddings = util.pad_stack(
            [state['unit_embeddings'] for state in all_states],
            pad_axis=0,
            stack_axis=0)
        feed_dict[self.unit_embeddings_input] = unit_embeddings

        if bootstrap_state is not None:
            bootstrap_screen = np.expand_dims(bootstrap_state['screen'],
                                              axis=0)
            feed_dict[self.state_input] = np.concatenate(
                [screens, bootstrap_screen], axis=0)
        else:
            feed_dict[self.state_input] = screens

        if actions is not None:
            nonspacial, spacials, selection_coords, selection_indices = zip(
                *actions)
            spacials = [(13, 27) if spacial is None else spacial
                        for spacial in spacials]
            selections = [
                -1 if selection is None else selection
                for selection in selection_indices
            ]
            feed_dict[self.action_input] = np.array(nonspacial)
            feed_dict[self.spacial_input] = np.array(spacials)
            feed_dict[self.unit_selection_input] = np.array(selections)
        return feed_dict
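
To illustrate the bootstrap path shared by both get_feed_dict versions, a sketch under the same assumptions (placeholder shapes, an existing agent instance): the extra state is appended to the batch so the network also evaluates the post-rollout observation.

import numpy as np

# Placeholder shapes; `agent` is assumed to be an already-constructed instance.
batch_size, num_actions = 2, 10
screen_shape, num_units, embedding_size = (64, 64, 17), 8, 32

def make_state():
    return {
        'screen': np.zeros(screen_shape),
        'unit_embeddings': np.zeros((num_units, embedding_size)),
    }

rollout_states = [make_state() for _ in range(batch_size)]
final_state = make_state()  # observation after the last action of the rollout
masks = np.ones((batch_size, num_actions))

# With bootstrap_state set, state_input and unit_embeddings_input carry
# batch_size + 1 entries, so the last row can be used for the value bootstrap.
feed_dict = agent.get_feed_dict(rollout_states, masks, bootstrap_state=final_state)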