def build_graph(self):
        """
        Construct the TensorFlow graph of the recurrent Gaussian policy.

        Everything is created inside the variable scope ``self.name``:
        the 'mean_network' RNN, the 'log_std_var' variable (lower-bounded by
        ``self.min_log_std``) and the ``DiagonalGaussian`` distribution object.
        The trainable variables collected under the current name scope are
        stored, keyed by their scope-stripped names, in ``self.policy_params``.
        """
        with tf.variable_scope(self.name):
            # recurrent network producing the mean of the action distribution
            (self.obs_var,
             self.hidden_var,
             self.mean_var,
             self.next_hidden_var,
             self.cell) = create_rnn(
                name='mean_network',
                cell_type=self._cell_type,
                output_dim=self.action_dim,
                hidden_sizes=self.hidden_sizes,
                hidden_nonlinearity=self.hidden_nonlinearity,
                output_nonlinearity=self.output_nonlinearity,
                input_dim=(None, None, self.obs_dim),
            )

            with tf.variable_scope("log_std_network"):
                # state-independent log-std; only trained when learn_std is set
                raw_log_std = tf.get_variable(
                    name='log_std_var',
                    shape=(1, self.action_dim),
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(self.init_log_std),
                    trainable=self.learn_std,
                )

                # element-wise lower bound on the log-std
                self.log_std_var = tf.maximum(raw_log_std, self.min_log_std, name='log_std')

            # distribution object used for the policy's actions
            self._dist = DiagonalGaussian(self.action_dim)

            # map scope-stripped variable names to the trainable variables
            scope_name = tf.get_default_graph().get_name_scope()
            self.policy_params = OrderedDict(
                (remove_scope_from_name(variable.name, scope_name), variable)
                for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name)
            )
Exemple #2
0
    def distribution_info_sym(self, obs_var, params=None):
        """
        Return the symbolic distribution information about the actions.

        Re-enters the variable scope ``self.name`` with ``reuse=True`` and builds
        the 'probs_network' RNN (softmax output) on top of ``obs_var``.

        Args:
            obs_var (placeholder) : symbolic variable for observations
            params (None) : must be None; explicit parameters are not supported by this method

        Returns:
            (tuple) : ``(dist_info, hidden_var, next_hidden_var)`` where ``dist_info`` is a dict holding the
            network output under the key ``probs`` and the two remaining entries are the hidden-state
            tensors returned by ``create_rnn``

        Raises:
            AssertionError : if ``params`` is not None
        """
        # Explicit check rather than a bare ``assert`` so the guard survives ``python -O``;
        # AssertionError is kept so existing callers observe the same exception type.
        if params is not None:
            raise AssertionError("distribution_info_sym does not support explicit params; params must be None")

        with tf.variable_scope(self.name, reuse=True):
            # only the hidden-state tensors and the probabilities are needed here
            _, hidden_var, probs_var, next_hidden_var, _ = create_rnn(
                name="probs_network",
                output_dim=self.action_dim,
                hidden_sizes=self.hidden_sizes,
                hidden_nonlinearity=self.hidden_nonlinearity,
                output_nonlinearity=tf.nn.softmax,
                input_var=obs_var,
                cell_type=self._cell_type,
            )

        return dict(probs=probs_var), hidden_var, next_hidden_var
    def build_graph(self):
        """
        Assemble the TensorFlow graph of the recurrent discrete policy.

        Inside the variable scope ``self.name`` this creates the
        'probs_network' RNN with a softmax output, instantiates the
        ``Discrete`` distribution object and stores the trainable variables
        found under the current name scope, keyed by their scope-stripped
        names, in ``self.policy_params``.
        """
        with tf.variable_scope(self.name):
            # recurrent network whose softmax output gives the action probabilities
            network = create_rnn(
                name='probs_network',
                cell_type=self._cell_type,
                output_dim=self.action_dim,
                hidden_sizes=self.hidden_sizes,
                hidden_nonlinearity=self.hidden_nonlinearity,
                output_nonlinearity=tf.nn.softmax,
                input_dim=(None, None, self.obs_dim),
            )
            (self.obs_var,
             self.hidden_var,
             self.probs_var,
             self.next_hidden_var,
             self.cell) = network

            # distribution object used for the policy's actions
            self._dist = Discrete(self.action_dim)

            # map scope-stripped variable names to the trainable variables
            scope_name = tf.get_default_graph().get_name_scope()
            params_by_name = OrderedDict()
            for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name):
                params_by_name[remove_scope_from_name(variable.name, scope_name)] = variable
            self.policy_params = params_by_name
Exemple #4
0
    def build_graph(self):
        """
        Builds computational graph for policy

        The live TensorFlow part of this method creates the 'probs_network'
        LSTM via ``create_rnn``, stores its outputs on ``self``, instantiates a
        ``Discrete`` distribution and collects the trainable variables under
        the current name scope into ``self.policy_params``.

        NOTE(review): the method also contains PyTorch-style statements and
        reads several names that are never bound inside it (``obs``,
        ``memory`` before assignment, ``probs_var``, ``obs_space``,
        ``action_space``, ``nn``). It looks like an unfinished port of a
        FiLM-based PyTorch model to TensorFlow and is not expected to run as
        written - confirm before relying on it.
        """
        with tf.variable_scope(self.name):

            # NOTE(review): the statements from here to the `dist = ...` assignment use the
            # PyTorch API (torch.transpose, torch.cat, F.relu, Categorical) inside a TF
            # variable scope; presumably reference code from the original forward pass.
            # `obs` is neither a parameter of this method nor assigned in it.
            instr_embedding = self._get_instr_embedding(obs.instr)
            # the two transposes reorder the axes (0, 1, 2, 3) -> (0, 3, 1, 2)
            # (presumably NHWC -> NCHW - TODO confirm)
            x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
            x = self.image_conv(x)
            # NOTE(review): self.controllers is (re)assigned further down in this method,
            # so it must already exist on `self` for this loop to work.
            for controler in self.controllers:
                x = controler(x, instr_embedding)
            x = F.relu(self.film_pool(x))
            # flatten everything except the leading axis
            x = x.reshape(x.shape[0], -1)

            # NOTE(review): `memory` is assigned a few lines below, which makes it a local
            # name of this method; this read therefore happens before its first assignment
            # (UnboundLocalError if execution reaches this line).
            # The slice splits `memory` along axis 1 at self.semi_memory_size.
            hidden = (memory[:, :self.semi_memory_size],
                      memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(x, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)

            embedding = torch.cat((embedding, instr_embedding), dim=1)
            x = self.actor(embedding)
            # NOTE(review): `dist` is not read again in this method.
            dist = Categorical(logits=F.log_softmax(x, dim=1))

            # memory_rnn = tf.nn.rnn_cell.LSTMCell(self.memory_dim)  # TODO: set these

            # LSTM network with a softmax output; unpacked into attributes near the end
            rnn_outs = create_rnn(
                name='probs_network',
                cell_type='lstm',
                output_dim=self.action_dim,
                hidden_sizes=self.hidden_sizes,
                hidden_nonlinearity=self.hidden_nonlinearity,
                output_nonlinearity=tf.nn.softmax,
                input_dim=(
                    None,
                    None,
                    self.obs_dim,
                ),
            )

            # obs_var, hidden_var, probs_var, next_hidden_var, cell = create_rnn(name='probs_network2',
            #                       cell_type='lstm',
            #                       output_dim=self.action_dim,
            #                       hidden_sizes=self.hidden_sizes,
            #                       hidden_nonlinearity=self.hidden_nonlinearity,
            #                       output_nonlinearity=tf.nn.softmax,
            #                       input_dim=(None, None, self.obs_dim,),
            #                       )

            # NOTE(review): function-level imports; `datasets`, `models` and `plt` are not
            # referenced by any uncommented line of this method.
            from tensorflow.keras import datasets, layers, models
            import matplotlib.pyplot as plt

            # x = "INPUT"
            # input_conv = nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
            #
            # model = models.Sequential()
            # model.add(layers.Conv2D(128, (2,2), padding='SAME', input_shape=(8, 8, 3)))
            # model.add(layers.BatchNormalization())
            # model.add(layers.ReLU())
            # model.add(layers.MaxPooling2D((2, 2), strides=2))
            # model.add(layers.Conv2D(128, (3, 3), padding='SAME'))
            # model.add(layers.BatchNormalization())
            # model.add(layers.ReLU())
            # model.add(layers.MaxPooling2D((2, 2), strides=2))
            # model.compile('rmsprop', 'mse')

            # NOTE(review): `film_pool`, `word_embedding`, `instr_rnn`, `memory_rnn` and
            # `embedding_size` below are plain locals that are not read again in this
            # method, while the code above reads self.film_pool / self.memory_rnn and the
            # actor below reads self.embedding_size - presumably these were meant to be
            # attributes; confirm.
            film_pool = layers.MaxPooling2D((2, 2), strides=2)
            # NOTE(review): `obs_space` is not defined in this method.
            word_embedding = layers.Embedding(
                obs_space["instr"], self.instr_dim)  # TODO: get this!
            gru_dim = self.instr_dim  # TODO: set this
            # half of instr_dim (integer division)
            gru_dim //= 2
            # NOTE(review): `batch_first` / `bidirectional` and the second positional size
            # argument look like torch.nn.GRU parameters - verify against the
            # tf.keras.layers.GRUCell signature.
            instr_rnn = tf.keras.layers.GRUCell(self.instr_dim,
                                                gru_dim,
                                                batch_first=True,
                                                bidirectional=True)
            self.final_instr_dim = self.instr_dim

            # NOTE(review): two positional sizes resemble torch.nn.LSTMCell(input, hidden) -
            # verify against the tf.keras.layers.LSTMCell signature.
            memory_rnn = tf.keras.layers.LSTMCell(
                self.image_dim, self.memory_dim)  # TODO: set these

            # Resize image embedding
            embedding_size = self.semi_memory_size  # TODO: set this
            # if self.use_instr and not "filmcnn" in arch:
            #     self.embedding_size += self.final_instr_dim  # TODO: consider keeping this!

            # Build num_module FiLM controllers: every one but the last outputs 128
            # features, the last one outputs self.image_dim features.
            num_module = 2
            self.controllers = []
            for ni in range(num_module):
                if ni < num_module - 1:
                    mod = ExpertControllerFiLM(
                        in_features=self.final_instr_dim,
                        out_features=128,
                        in_channels=128,
                        imm_channels=128)
                else:
                    mod = ExpertControllerFiLM(
                        in_features=self.final_instr_dim,
                        out_features=self.image_dim,
                        in_channels=128,
                        imm_channels=128)
                self.controllers.append(mod)
                # NOTE(review): add_module looks like the torch.nn.Module API - confirm
                # that this class provides it.
                self.add_module('FiLM_Controler_' + str(ni), mod)
            #
            # Define actor's model
            # NOTE(review): `nn` and `action_space` are not defined in this method.
            self.actor = nn.Sequential(nn.Linear(self.embedding_size, 64),
                                       nn.Tanh(),
                                       nn.Linear(64, action_space.n))

            # rnn_outs = create_rnn(name='probs_network',
            #                       cell_type=self._cell_type,
            #                       output_dim=self.action_dim,
            #                       hidden_sizes=self.hidden_sizes,
            #                       hidden_nonlinearity=self.hidden_nonlinearity,
            #                       output_nonlinearity=tf.nn.softmax,
            #                       input_dim=(None, None, self.obs_dim,),
            #                       )

            self.obs_var, self.hidden_var, self.probs_var, self.next_hidden_var, self.cell = rnn_outs
            # NOTE(review): averages the network output with `probs_var`, whose only
            # assignment in this method is inside the commented-out create_rnn call above.
            self.probs_var = (self.probs_var + probs_var) / 2

            # symbolically define sampled action and distribution
            self._dist = Discrete(self.action_dim)

            # save the policy's trainable variables in dicts
            current_scope = tf.get_default_graph().get_name_scope()
            trainable_policy_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=current_scope)
            self.policy_params = OrderedDict([
                (remove_scope_from_name(var.name, current_scope), var)
                for var in trainable_policy_vars
            ])