Example 1
class BatchApply(Component):
    """
    Takes an input with batch and time ranks, folds the time rank into the batch rank,
    calls a given API-method on an arbitrary child Component, and then unfolds the time rank again.
    """
    def __init__(self,
                 sub_component,
                 api_method_name,
                 scope="batch-apply",
                 **kwargs):
        """
        Args:
            sub_component (Component): The sub-Component to apply the batch to.
            api_method_name (str): The name of the API-method to call on the sub-component.
        """
        super(BatchApply, self).__init__(scope=scope, **kwargs)

        self.sub_component = sub_component
        self.api_method_name = api_method_name

        # Create the necessary reshape components.
        self.folder = ReShape(fold_time_rank=True, scope="folder")
        self.unfolder = ReShape(unfold_time_rank=True, scope="unfolder")

        self.add_components(self.sub_component, self.folder, self.unfolder)

    @rlgraph_api
    def call(self, input_):
        folded = self.folder.call(input_)
        applied = getattr(self.sub_component, self.api_method_name)(folded)
        unfolded = self.unfolder.call(applied,
                                      input_before_time_rank_folding=input_)
        return unfolded
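
A minimal usage sketch (a hypothetical illustration, not part of the example above): BatchApply folds a [batch, time, ...] input into [batch * time, ...], applies the wrapped Component, and restores the time rank. The DenseLayer child and its "call" API-method name are assumptions; import paths are omitted.

# Hypothetical usage sketch for BatchApply.
# Assumes a DenseLayer Component whose API-method is named "call" (as for the other layers in these examples).
dense = DenseLayer(units=128, scope="dense-128")
batch_apply = BatchApply(sub_component=dense, api_method_name="call")

# Inside the API-method of some parent Component:
#   output = batch_apply.call(nn_input)   # nn_input carries both batch and time ranks
# "output" keeps both ranks; the dense layer was applied to every time-step.
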
Example 2
class IMPALANetwork(NeuralNetwork):
    """
    The base class for both the "large" and "small architecture" versions of the networks used in [1].

    [1] IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures - Espeholt, Soyer,
        Munos et al. - 2018 (https://arxiv.org/abs/1802.01561)
    """
    def __init__(self, worker_sample_size=100, scope="impala-network", **kwargs):
        """
        Args:
            worker_sample_size (int): How many time-steps an IMPALA actor will have performed in one rollout.
        """
        super(IMPALANetwork, self).__init__(scope=scope, **kwargs)

        self.worker_sample_size = worker_sample_size

        # Create all needed sub-components.

        # ContainerSplitter for the Env signal (a dict with 4 keys: env image, env text, previous action and previous reward).
        self.splitter = ContainerSplitter("RGB_INTERLEAVED", "INSTR", "previous_action", "previous_reward",
                                          scope="input-splitter")

        # Fold the time rank into the batch rank.
        self.time_rank_fold_before_lstm = ReShape(fold_time_rank=True, scope="time-rank-fold-before-lstm")
        self.time_rank_unfold_before_lstm = ReShape(unfold_time_rank=True, time_major=True,
                                                    scope="time-rank-unfold-before-lstm")

        # The Image Processing Stack (left side of "Large Architecture" Figure 3 in [1]).
        # Conv2D column + ReLU + fc(256) + ReLU.
        self.image_processing_stack = self.build_image_processing_stack()

        # The text processing pipeline: Takes a batch of string tensors as input, creates a hash-bucket thereof,
        # and passes the output of the hash bucket through an embedding-lookup(20) layer. The output of the embedding
        # lookup is then passed through an LSTM(64).
        self.text_processing_stack = self.build_text_processing_stack()

        #self.debug_slicer = Slice(scope="internal-states-slicer", squeeze=True)

        # The concatenation layer (concatenates outputs from image/text processing stacks, previous action/reward).
        self.concat_layer = ConcatLayer()

        # The main LSTM (going into the ActionAdapter (next in the Policy Component that uses this NN Component)).
        # Use time-major as it's faster (according to the TF docs).
        self.main_lstm = LSTMLayer(units=256, scope="lstm-256", time_major=True, static_loop=self.worker_sample_size)

        # Add all sub-components to this one.
        self.add_components(
            self.splitter, self.image_processing_stack, self.text_processing_stack,
            self.concat_layer,
            self.main_lstm,
            self.time_rank_fold_before_lstm, self.time_rank_unfold_before_lstm,
            #self.debug_slicer
        )

    @staticmethod
    def build_image_processing_stack():
        """
        Builds the image processing pipeline for IMPALA and returns it.
        """
        raise NotImplementedError

    @staticmethod
    def build_text_processing_stack():
        """
        Helper function to build the text processing pipeline for both the large and small architectures, consisting of:
        - ReShape preprocessor to fold the incoming time rank into the batch rank.
        - StringToHashBucket Layer taking a batch of sentences and converting them to an indices-table of dimensions:
          cols=length of the longest sentence in the input
          rows=number of items in the batch
          The cols dimension can be interpreted as the time rank for a subsequent LSTM. The StringToHashBucket
          Component returns the sequence length of each batch item for exactly that purpose.
        - Embedding Lookup Layer of embedding size 20 and number of rows == num_hash_buckets (see previous layer).
        - LSTM processing the batched sequences of words coming from the embedding layer as batches of rows.
        """
        num_hash_buckets = 1000

        # Create a hash bucket from the sentences and use that bucket to do an embedding lookup (instead of
        # a vocabulary).
        string_to_hash_bucket = StringToHashBucket(num_hash_buckets=num_hash_buckets)
        embedding = EmbeddingLookup(embed_dim=20, vocab_size=num_hash_buckets, pad_empty=True)
        # The time rank for the LSTM is now the sequence of words in a sentence, NOT the original env time rank.
        # We will only use the last output of the LSTM-64 for further processing as that is the output after having
        # seen all words in the sentence.
        # The original env stepping time rank is currently folded into the batch rank and must be unfolded again before
        # passing it into the main LSTM.
        lstm64 = LSTMLayer(units=64, scope="lstm-64", time_major=False)

        tuple_splitter = ContainerSplitter(tuple_length=2, scope="tuple-splitter")

        def custom_call(self, inputs):
            hash_bucket, lengths = self.sub_components["string-to-hash-bucket"].call(inputs)

            embedding_output = self.sub_components["embedding-lookup"].call(hash_bucket)

            # Return only the last output (we are not interested in intermediate results where the LSTM
            # has not yet seen the entire sentence).
            # Last output is the final internal h-state (slot 1 in the returned LSTM tuple; slot 0 is final c-state).
            lstm_output = self.sub_components["lstm-64"].call(embedding_output, sequence_length=lengths)
            lstm_final_internals = lstm_output["last_internal_states"]

            # Need to split once more because the LSTM state is always a tuple of final c- and h-states.
            _, lstm_final_h_state = self.sub_components["tuple-splitter"].call(lstm_final_internals)

            return lstm_final_h_state

        text_processing_stack = Stack(
            string_to_hash_bucket, embedding, lstm64, tuple_splitter,
            api_methods={("call", custom_call)}, scope="text-stack"
        )

        return text_processing_stack

    @rlgraph_api
    def call(self, input_dict, internal_states=None):
        # Split the input dict coming directly from the Env.
        _, _, _, orig_previous_reward = self.splitter.call(input_dict)

        folded_input = self.time_rank_fold_before_lstm.call(input_dict)
        image, text, previous_action, previous_reward = self.splitter.call(folded_input)

        # Get the left-stack (image) and right-stack (text) output (see [1] for details).
        text_processing_output = self.text_processing_stack.call(text)
        image_processing_output = self.image_processing_stack.call(image)

        # Concat everything together.
        concatenated_data = self.concat_layer.call(
            image_processing_output, text_processing_output, previous_action, previous_reward
        )

        unfolded_concatenated_data = self.time_rank_unfold_before_lstm.call(concatenated_data, orig_previous_reward)

        # Feed concat'd input into main LSTM(256).
        lstm_output = self.main_lstm.call(unfolded_concatenated_data, internal_states)

        return lstm_output
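
A hypothetical sketch of a concrete subclass (layer types and parameters are assumptions, roughly following the "small architecture" in [1]; they are not taken from this example). Only build_image_processing_stack() needs to be overridden, since the text pipeline is already provided by the base class above.

class SmallIMPALANetworkSketch(IMPALANetwork):
    """
    Hypothetical small-architecture variant: two Conv2D layers + ReLU, flatten, fc(256) + ReLU.
    """
    @staticmethod
    def build_image_processing_stack():
        # Layer signatures are assumed; filter/kernel/stride values follow the "small architecture" in [1].
        return Stack(
            Conv2DLayer(filters=16, kernel_size=8, strides=4, activation="relu", scope="conv2d-16"),
            Conv2DLayer(filters=32, kernel_size=4, strides=2, activation="relu", scope="conv2d-32"),
            ReShape(flatten=True, scope="flatten"),
            DenseLayer(units=256, activation="relu", scope="fc-256"),
            scope="image-stack"
        )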