# NOTE: The import paths below are assumed from the RLgraph package layout.
from rlgraph.components.component import Component
from rlgraph.components.common.container_splitter import ContainerSplitter
from rlgraph.components.layers.nn.concat_layer import ConcatLayer
from rlgraph.components.layers.nn.lstm_layer import LSTMLayer
from rlgraph.components.layers.preprocessing.reshape import ReShape
from rlgraph.components.layers.strings.embedding_lookup import EmbeddingLookup
from rlgraph.components.layers.strings.string_to_hash_bucket import StringToHashBucket
from rlgraph.components.neural_networks.neural_network import NeuralNetwork
from rlgraph.components.neural_networks.stack import Stack
from rlgraph.utils.decorators import rlgraph_api


class BatchApply(Component):
    """
    Takes an input with batch and time ranks, folds the time rank into the batch rank,
    calls a given API-method on an arbitrary child Component, and then unfolds the time
    rank again on the result.
    """
    def __init__(self, sub_component, api_method_name, scope="batch-apply", **kwargs):
        """
        Args:
            sub_component (Component): The sub-Component to apply the batch to.
            api_method_name (str): The name of the API-method to call on the sub-component.
        """
        super(BatchApply, self).__init__(scope=scope, **kwargs)

        self.sub_component = sub_component
        self.api_method_name = api_method_name

        # Create the necessary reshape components.
        self.folder = ReShape(fold_time_rank=True, scope="folder")
        self.unfolder = ReShape(unfold_time_rank=True, scope="unfolder")

        self.add_components(self.sub_component, self.folder, self.unfolder)

    @rlgraph_api
    def call(self, input_):
        # Fold the time rank into the batch rank.
        folded = self.folder.call(input_)
        # Apply the wrapped sub-Component's API-method to the folded input.
        applied = getattr(self.sub_component, self.api_method_name)(folded)
        # Unfold the time rank again, using the original input as the shape reference.
        unfolded = self.unfolder.call(applied, input_before_time_rank_folding=input_)
        return unfolded
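
# Usage sketch (illustrative only, not part of this module): wrapping a feed-forward
# Component so it can be applied to inputs carrying both a batch and a time rank.
# `DenseLayer` and its import path are assumptions here; any Component exposing the
# named API-method would work the same way.
#
#   from rlgraph.components.layers.nn.dense_layer import DenseLayer
#
#   dense = DenseLayer(units=128, scope="dense-128")
#   batch_apply = BatchApply(sub_component=dense, api_method_name="call")
#   # Inside a parent Component's @rlgraph_api method:
#   #   out = self.batch_apply.call(inputs_)  # inputs_: [batch x time x features]
#   # `out` keeps the original batch and time ranks; only the feature rank is
#   # transformed by the wrapped DenseLayer.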


class IMPALANetwork(NeuralNetwork):
    """
    The base class for both the "large" and "small architecture" versions of the networks used in [1].

    [1] IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures -
        Espeholt, Soyer, Munos et al. - 2018 (https://arxiv.org/abs/1802.01561)
    """
    def __init__(self, worker_sample_size=100, scope="impala-network", **kwargs):
        """
        Args:
            worker_sample_size (int): How many time-steps an IMPALA actor performs in one rollout.
        """
        super(IMPALANetwork, self).__init__(scope=scope, **kwargs)

        self.worker_sample_size = worker_sample_size

        # Create all needed sub-components.

        # ContainerSplitter for the Env signal (dict of 4 keys: env image, env text, previous action
        # and previous reward).
        self.splitter = ContainerSplitter(
            "RGB_INTERLEAVED", "INSTR", "previous_action", "previous_reward",
            scope="input-splitter"
        )

        # Fold the time rank into the batch rank (and unfold it again before the main LSTM).
        self.time_rank_fold_before_lstm = ReShape(fold_time_rank=True, scope="time-rank-fold-before-lstm")
        self.time_rank_unfold_before_lstm = ReShape(
            unfold_time_rank=True, time_major=True, scope="time-rank-unfold-before-lstm"
        )

        # The image processing stack (left side of the "Large Architecture" in Figure 3 of [1]):
        # Conv2D column + ReLU + fc(256) + ReLU.
        self.image_processing_stack = self.build_image_processing_stack()

        # The text processing pipeline: Takes a batch of string tensors as input, creates a hash-bucket
        # thereof, and passes the output of the hash bucket through an embedding-lookup(20) layer. The
        # output of the embedding lookup is then passed through an LSTM(64).
        self.text_processing_stack = self.build_text_processing_stack()

        #self.debug_slicer = Slice(scope="internal-states-slicer", squeeze=True)

        # The concatenation layer (concatenates the outputs of the image/text processing stacks with the
        # previous action and reward).
        self.concat_layer = ConcatLayer()

        # The main LSTM (feeding into the ActionAdapter of the Policy Component that uses this NN Component).
        # Use time-major, as that is faster according to the TensorFlow docs.
        self.main_lstm = LSTMLayer(
            units=256, scope="lstm-256", time_major=True, static_loop=self.worker_sample_size
        )

        # Add all sub-components to this one.
        self.add_components(
            self.splitter, self.image_processing_stack, self.text_processing_stack,
            self.concat_layer, self.main_lstm,
            self.time_rank_fold_before_lstm, self.time_rank_unfold_before_lstm,
            #self.debug_slicer
        )

    @staticmethod
    def build_image_processing_stack():
        """
        Builds the image processing pipeline for IMPALA and returns it.
        """
        raise NotImplementedError

    @staticmethod
    def build_text_processing_stack():
        """
        Helper function to build the text processing pipeline for both the large and small architectures.
        The incoming env time rank is expected to have already been folded into the batch rank by the
        caller (see `call`). The pipeline consists of:
        - A StringToHashBucket layer taking a batch of sentences and converting it to an indices-table
          with dimensions:
          cols=length of the longest sentence in the batch
          rows=number of items in the batch
          The cols dimension can be interpreted as the time rank for the subsequent LSTM. For exactly
          that purpose, the StringToHashBucket Component also returns the sequence length of each
          batch item.
        - An EmbeddingLookup layer with embedding size 20 and number of rows == num_hash_buckets (see
          previous layer).
        - An LSTM(64) processing the batched sequences of word embeddings coming from the embedding
          layer, followed by a ContainerSplitter extracting the LSTM's final h-state.
        """
        num_hash_buckets = 1000

        # Create a hash bucket from the sentences and use that bucket to do an embedding lookup
        # (instead of a vocabulary).
        string_to_hash_bucket = StringToHashBucket(num_hash_buckets=num_hash_buckets)
        embedding = EmbeddingLookup(embed_dim=20, vocab_size=num_hash_buckets, pad_empty=True)

        # The time rank for the LSTM is now the sequence of words in a sentence, NOT the original env
        # time rank. We will only use the last output of the LSTM-64 for further processing, as that is
        # the output after having seen all words in the sentence.
        # The original env stepping time rank is currently folded into the batch rank and must be
        # unfolded again before passing it into the main LSTM.
        lstm64 = LSTMLayer(units=64, scope="lstm-64", time_major=False)

        tuple_splitter = ContainerSplitter(tuple_length=2, scope="tuple-splitter")

        def custom_call(self, inputs):
            hash_bucket, lengths = self.sub_components["string-to-hash-bucket"].call(inputs)

            embedding_output = self.sub_components["embedding-lookup"].call(hash_bucket)

            # Return only the last output (sentence of words, where we are not interested in
            # intermediate results where the LSTM has not seen the entire sentence yet).
            # The last output is the final internal h-state (slot 1 in the returned LSTM tuple;
            # slot 0 is the final c-state).
            lstm_output = self.sub_components["lstm-64"].call(embedding_output, sequence_length=lengths)
            lstm_final_internals = lstm_output["last_internal_states"]

            # Need to split once more, because the LSTM state is always a tuple of final c- and h-states.
            _, lstm_final_h_state = self.sub_components["tuple-splitter"].call(lstm_final_internals)

            return lstm_final_h_state

        text_processing_stack = Stack(
            string_to_hash_bucket, embedding, lstm64, tuple_splitter,
            api_methods={("call", custom_call)}, scope="text-stack"
        )

        return text_processing_stack

    @rlgraph_api
    def call(self, input_dict, internal_states=None):
        # Split the input dict coming directly from the Env.
        _, _, _, orig_previous_reward = self.splitter.call(input_dict)

        # Fold the time rank into the batch rank before the feed-forward stacks.
        folded_input = self.time_rank_fold_before_lstm.call(input_dict)
        image, text, previous_action, previous_reward = self.splitter.call(folded_input)

        # Get the left-stack (image) and right-stack (text) outputs (see [1] for details).
        text_processing_output = self.text_processing_stack.call(text)
        image_processing_output = self.image_processing_stack.call(image)

        # Concat everything together.
        concatenated_data = self.concat_layer.call(
            image_processing_output, text_processing_output, previous_action, previous_reward
        )
        # Unfold the time rank again (using the original reward signal as the shape reference)
        # before feeding the time-major main LSTM.
        unfolded_concatenated_data = self.time_rank_unfold_before_lstm.call(
            concatenated_data, orig_previous_reward
        )

        # Feed the concatenated input into the main LSTM(256).
        lstm_output = self.main_lstm.call(unfolded_concatenated_data, internal_states)

        return lstm_output
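
# Sketch of a concrete subclass, modeled on the "small architecture" column of Figure 3
# in [1] (Conv2D(16, 8x8/4) -> Conv2D(32, 4x4/2) -> fc(256), each followed by a ReLU).
# The layer class names, import paths and constructor arguments used here are assumptions,
# not verified against the library:
#
#   from rlgraph.components.layers.nn.conv2d_layer import Conv2DLayer
#   from rlgraph.components.layers.nn.dense_layer import DenseLayer
#
#   class SmallIMPALANetwork(IMPALANetwork):
#       @staticmethod
#       def build_image_processing_stack():
#           return Stack(
#               Conv2DLayer(filters=16, kernel_size=8, strides=4, activation="relu"),
#               Conv2DLayer(filters=32, kernel_size=4, strides=2, activation="relu"),
#               ReShape(flatten=True),
#               DenseLayer(units=256, activation="relu"),
#               scope="image-stack"
#           )
#
#   network = SmallIMPALANetwork(worker_sample_size=20)
#   # `network.call(input_dict, internal_states)` then returns the main LSTM(256)'s
#   # outputs together with its last internal states.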