def step(self, states, masks, memory):
    """Run the network once and sample an action for every batch element.

    :param states: List of states of length batch size. Each state is a
        dict; this method itself reads ``state['unit_coords']`` (a 2-D
        array whose first two columns are used), the remaining keys are
        consumed by ``self.get_feed_dict``.
    :param masks: numpy array of shape [batch_size, num_actions]
    :param memory: numpy array of shape [2, batch_size, memory_size], or
        None for the first step, in which case zeros of shape
        ``(2, len(states), self.rnn_size)`` are fed instead.
    :return: tuple ``(sampled_action, next_lstm_state, log_prob)``
    """
    if memory is None:
        memory = np.zeros((2, len(states), self.rnn_size))

    feed_dict = dict(self.get_feed_dict(states, masks))
    feed_dict[self.memory_input] = memory

    # Fetch order matters: three fixed tensors first, then every x
    # distribution, then every y distribution.
    fetches = [
        self.next_lstm_state,
        self.nonspacial_probs,
        self.unit_selection_probs,
        *self.spacial_probs_x,
        *self.spacial_probs_y,
    ]
    outputs = self.session.run(fetches, feed_dict)

    next_lstm_state, nonspacial_probs, selection_probs = outputs[:3]
    screen_outputs = outputs[3:]
    spacial_probs_x = screen_outputs[:self.num_screen_dims]
    spacial_probs_y = screen_outputs[self.num_screen_dims:]

    # Only the first two columns of each state's unit coordinates are
    # kept; the per-state arrays are padded along axis 0 and stacked.
    per_state_coords = [state['unit_coords'][:, :2] for state in states]
    unit_coords = util.pad_stack(per_state_coords, pad_axis=0, stack_axis=0)

    sampled_action = self.sample_action_index_with_units(
        nonspacial_probs,
        spacial_probs_x,
        spacial_probs_y,
        selection_probs,
        unit_coords,
    )
    stacked_spacial_probs = np.stack([spacial_probs_x, spacial_probs_y])
    log_prob = self.log_prob_numpy(
        sampled_action,
        nonspacial_probs,
        stacked_spacial_probs,
        selection_probs,
    )
    # TODO: check for unit embeddings
    return sampled_action, next_lstm_state, log_prob
def step(self, states, masks, memories):
    """Sample one action per batch element, threading per-item memory.

    :param states: List of states of length batch size. Each state is a
        dict; this method writes ``state['prev_action']`` into it and reads
        ``state['unit_coords']``. Remaining keys are consumed by
        ``self.get_feed_dict``.
    :param masks: numpy array of shape [batch_size, num_actions]
    :param memories: List of memory dicts (one per state) with keys
        ``'next_lstm_state'`` and ``'prev_action'``, or None to start from
        all-zero memories.
    :return: tuple ``(states, sampled_action, next_memories)``. ``states``
        are the states that were actually used, i.e. the input dicts after
        ``'prev_action'`` has been written into them in place.
    """
    if memories is None:
        memories = [
            {
                'next_lstm_state': np.zeros((2, len(states), self.rnn_size)),
                'prev_action': np.zeros(self.interface.num_actions),
            }
            for _ in range(len(states))
        ]

    # Copy the remembered previous action into each state (mutates the
    # caller's dicts). `mem` is deliberately read again after the loop.
    for state, mem in zip(states, memories):
        state['prev_action'] = mem['prev_action']

    # NOTE(review): `use_lstm` is not defined inside this method; it has to
    # come from an enclosing scope -- confirm where it is set.
    if use_lstm:
        feed_dict = dict(self.get_feed_dict(states, masks))
        # Only the last zipped memory's LSTM state is fed for the batch.
        feed_dict[self.memory_input] = mem['next_lstm_state']
    else:
        feed_dict = dict(self.get_feed_dict(states, masks))

    fetches = [
        self.next_lstm_state,
        self.nonspacial_probs,
        self.unit_selection_probs,
        *self.spacial_probs_x,
        *self.spacial_probs_y,
    ]
    outputs = self.session.run(fetches, feed_dict)

    next_lstm_state, nonspacial_probs, selection_probs = outputs[:3]
    screen_outputs = outputs[3:]
    spacial_probs_x = screen_outputs[:self.num_screen_dims]
    spacial_probs_y = screen_outputs[self.num_screen_dims:]

    per_state_coords = [state['unit_coords'][:, :2] for state in states]
    unit_coords = util.pad_stack(per_state_coords, pad_axis=0, stack_axis=0)

    sampled_action = self.sample_action_index_with_units(
        nonspacial_probs,
        spacial_probs_x,
        spacial_probs_y,
        selection_probs,
        unit_coords,
    )

    # TODO, to reimplement memory, you will probably have to np stack since
    # memory is now a list of dictionaries with a separate state for each
    # item in the batch.
    next_memories = []
    for index, _ in enumerate(states):
        # One-hot vector marking the first component of this item's action.
        one_hot = np.zeros(self.interface.num_actions)
        one_hot[sampled_action[index][0]] = 1
        next_memories.append({
            'next_lstm_state': next_lstm_state,
            'prev_action': one_hot,
        })

    return states, sampled_action, next_memories
def get_feed_dict(self, states, masks, actions=None, bootstrap_state=None):
    """Assemble the input-to-value mapping for one network evaluation.

    :param states: List of state dicts. Each must provide ``'screen'``,
        ``'unit_embeddings'``, ``'other_features'`` and ``'prev_action'``.
    :param masks: Per-state action masks; converted with ``np.array``.
    :param actions: Optional iterable of 4-tuples
        ``(nonspacial, spacial, selection_coords, selection_index)``. When
        given, the action inputs are added to the feed dict.
    :param bootstrap_state: Optional extra state dict appended after
        ``states`` for every per-state input (the mask is not extended).
    :return: dict keyed by the model's input attributes.
    """
    screens = np.stack([state['screen'] for state in states], axis=0)
    feed_dict = {self.mask_input: np.array(masks)}

    has_bootstrap = bootstrap_state is not None
    all_states = [*states, bootstrap_state] if has_bootstrap else states

    unit_embeddings = util.pad_stack(
        [entry['unit_embeddings'] for entry in all_states],
        pad_axis=0,
        stack_axis=0,
    )
    other_features = np.stack(
        [entry['other_features'] for entry in all_states], axis=0)
    prev_action_onehot = np.stack(
        [entry['prev_action'] for entry in all_states], axis=0)

    # The previous action is appended to the other features along the last
    # axis, so both travel through a single input.
    feed_dict[self.other_features_input] = np.concatenate(
        [other_features, prev_action_onehot], axis=-1)
    feed_dict[self.unit_embeddings_input] = unit_embeddings

    if has_bootstrap:
        bootstrap_screen = np.expand_dims(bootstrap_state['screen'], axis=0)
        screens = np.concatenate([screens, bootstrap_screen], axis=0)
    feed_dict[self.state_input] = screens

    if actions is None:
        return feed_dict

    nonspacial, spacials, _selection_coords, selection_indices = zip(*actions)
    # Missing components are replaced by fixed fillers so the arrays stay
    # rectangular. NOTE(review): (13, 27) and -1 look like arbitrary
    # placeholders that are ignored downstream -- confirm.
    filled_spacials = [
        (13, 27) if spacial is None else spacial for spacial in spacials
    ]
    filled_selections = [
        -1 if selection is None else selection
        for selection in selection_indices
    ]
    feed_dict[self.action_input] = np.array(nonspacial)
    feed_dict[self.spacial_input] = np.array(filled_spacials)
    feed_dict[self.unit_selection_input] = np.array(filled_selections)
    return feed_dict
def get_feed_dict(self, states, masks, actions=None, bootstrap_state=None):
    """Build the feed dict for one network evaluation.

    :param states: List of state dicts. Each must provide ``'screen'`` and
        ``'unit_embeddings'``.
    :param masks: Per-state action masks; converted with ``np.array``.
    :param actions: Optional iterable of 4-tuples
        ``(nonspacial, spacial, selection_coords, selection_index)``. When
        given, the action inputs are added to the feed dict.
    :param bootstrap_state: Optional extra state dict whose screen and unit
        embeddings are appended after those of ``states`` (the mask is not
        extended).
    :return: dict keyed by the model's input attributes.
    """
    screens = np.stack([state['screen'] for state in states], axis=0)

    # state_input is assigned further down, once it is known whether a
    # bootstrap screen has to be appended. Seeding it here with
    # np.array(states) only built a throwaway object array of dicts that
    # was always overwritten.
    feed_dict = {
        self.mask_input: np.array(masks),
    }

    all_states = states if bootstrap_state is None else [
        *states, bootstrap_state
    ]
    unit_embeddings = util.pad_stack(
        [state['unit_embeddings'] for state in all_states],
        pad_axis=0,
        stack_axis=0)
    feed_dict[self.unit_embeddings_input] = unit_embeddings

    if bootstrap_state is not None:
        bootstrap_screen = np.expand_dims(bootstrap_state['screen'], axis=0)
        feed_dict[self.state_input] = np.concatenate(
            [screens, bootstrap_screen], axis=0)
    else:
        feed_dict[self.state_input] = screens

    if actions is not None:
        # The third component (selection coordinates) is not fed.
        nonspacial, spacials, _selection_coords, selection_indices = zip(
            *actions)
        # Missing components are replaced by fixed fillers so the arrays
        # stay rectangular. NOTE(review): (13, 27) and -1 look like
        # arbitrary placeholders that are ignored downstream -- confirm.
        spacials = [(13, 27) if spacial is None else spacial
                    for spacial in spacials]
        selections = [
            -1 if selection is None else selection
            for selection in selection_indices
        ]
        feed_dict[self.action_input] = np.array(nonspacial)
        feed_dict[self.spacial_input] = np.array(spacials)
        feed_dict[self.unit_selection_input] = np.array(selections)
    return feed_dict