Example #1
class Policy(Module):
    """This class represents the NPI policy containing the environment encoder, the key-value and program embedding
    matrices, the NPI core lstm and the value networks for each task.

    Args:
        encoder (:obj:`{HanoiEnvEncoder, ListEnvEncoder, RecursiveListEnvEncoder, PyramidsEnvEncoder}`):
        hidden_size (int): Dimensionality of the LSTM hidden state
        num_programs (int): Overall number of programs and size actor's output softmax vector
        num_non_primary_programs (int): Number of non-zero level programs, also number of rows in embedding matrix
        embedding_dim (int): Dimensionality of the programs' embedding vectors
        encoding_dim (int): Dimensionality of the environment observation's encoding
        indices_non_primary_programs (list): Non zero level programs' indices
        learning_rate (float, optional): Defaults to 10^-3.
    """
    def __init__(self,
                 encoder,
                 hidden_size,
                 num_programs,
                 num_non_primary_programs,
                 embedding_dim,
                 encoding_dim,
                 indices_non_primary_programs,
                 learning_rate=1e-3,
                 temperature=0.1,
                 entropy_lambda=0.01):

        super(Policy, self).__init__()

        self._uniform_init = (-0.1, 0.1)

        self._hidden_size = hidden_size
        self.num_programs = num_programs
        self.num_non_primary_programs = num_non_primary_programs

        self.embedding_dim = embedding_dim
        self.encoding_dim = encoding_dim

        # Initialize networks
        self.Mprog = Embedding(num_non_primary_programs, embedding_dim)
        self.encoder = encoder

        self.lstm = LSTMCell(self.encoding_dim + self.embedding_dim,
                             self._hidden_size)
        self.critic = CriticNet(self._hidden_size)
        self.actor = ContinuousActorNet(self._hidden_size, self.num_programs)

        self.temperature = temperature
        # NOTE: entropy_lambda is used in train_on_batch; the default value here is an assumed choice
        self.entropy_lambda = entropy_lambda

        self.init_networks()
        self.init_optimizer(lr=learning_rate)

        # Compute relative indices of non primary programs (to deal with task indices)
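        # e.g. indices_non_primary_programs = [3, 4, 5]  ->  {3: 0, 4: 1, 5: 2}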
        self.relative_indices = dict(
            (prog_idx, relat_idx)
            for relat_idx, prog_idx in enumerate(indices_non_primary_programs))

    def init_networks(self):

        for p in self.encoder.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

        for p in self.lstm.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

        for p in self.critic.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

        for p in self.actor.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

    def init_optimizer(self, lr):
        '''Initialize the optimizer.

        Args:
            lr (float): learning rate
        '''
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def _one_hot_encode(self, digits, basis=6):
        """One hot encode a digit with basis. The digit may be None,
        the encoding associated to None is a vector full of zeros.

        Args:
          digits: batch (list) of digits
          basis: size of the one-hot encoding (Default value = 6)

        Returns:
          a tensor of shape (len(digits), basis) holding the one-hot encoding of each digit

        """
        encoding = torch.zeros(len(digits), basis)
        digits_filtered = list(filter(lambda x: x is not None, digits))

        if len(digits_filtered) != 0:
            tmp = [[
                idx for idx, digit in enumerate(digits) if digit is not None
            ], digits_filtered]
            encoding[tmp] = 1.0
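            # tmp pairs row indices with digit values, so the advanced indexing above sets
            # encoding[row, digit] = 1 for every non-None digit. For example,
            # _one_hot_encode([2, None, 0], basis=4) yields
            #   [[0., 0., 1., 0.],
            #    [0., 0., 0., 0.],
            #    [1., 0., 0., 0.]]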
        return encoding

    def predict_on_batch(self, e_t, i_t, h_t, c_t):
        """Run one NPI inference.

        Args:
          e_t: batch of environment observations
          i_t: batch of calling program indices
          h_t: batch of lstm hidden states
          c_t: batch of lstm cell states

        Returns:
          probabilities over programs, value, new hidden state, new cell state

        """
        batch_size = len(i_t)
        s_t = self.encoder(e_t.view(batch_size, -1))
        relative_prog_indices = [self.relative_indices[idx] for idx in i_t]
        p_t = self.Mprog(torch.LongTensor(relative_prog_indices)).view(
            batch_size, -1)

        new_h, new_c = self.lstm(torch.cat([s_t, p_t], -1), (h_t, c_t))

        actor_out = self.actor(new_h)
        critic_out = self.critic(new_h)
        return actor_out, critic_out, new_h, new_c

    def train_on_batch(self, batch):
        """perform optimization step.

        Args:
          batch (tuple): tuple of batches of environment observations, calling programs, lstm's hidden and cell states

        Returns:
          policy loss, value loss, total loss combining policy and value losses
        """
        e_t = torch.FloatTensor(np.stack(batch[0]))
        i_t = batch[1]
        lstm_states = batch[2]
        h_t, c_t = zip(*lstm_states)
        h_t, c_t = torch.squeeze(torch.stack(list(h_t))), torch.squeeze(
            torch.stack(list(c_t)))

        policy_labels = torch.squeeze(torch.stack(batch[3]))
        value_labels = torch.stack(batch[4]).view(-1, 1)

        self.optimizer.zero_grad()
        policy_predictions, value_predictions, _, _ = self.predict_on_batch(
            e_t, i_t, h_t, c_t)

        # policy_loss = -torch.mean(policy_labels * torch.log(policy_predictions), dim=-1).mean()

        beta = Beta(policy_predictions[0], policy_predictions[1])
        policy_action = beta.sample()
        prob_action = beta.log_prob(policy_action)

        log_mcts = self.temperature * torch.log(policy_labels)
        with torch.no_grad():
            modified_kl = prob_action - log_mcts

        # Reduce both terms to scalars so that backward() can be called on the combined loss
        policy_loss = (-modified_kl * (torch.log(modified_kl) + prob_action)).mean()
        entropy_loss = self.entropy_lambda * beta.entropy().mean()

        policy_network_loss = policy_loss + entropy_loss
        value_network_loss = torch.pow(value_predictions - value_labels, 2).mean()

        total_loss = (policy_network_loss + value_network_loss) / 2
        total_loss.backward()
        self.optimizer.step()

        return policy_network_loss, value_network_loss, total_loss

    def forward_once(self, e_t, i_t, h, c):
        """Run one NPI inference using predict.

        Args:
          e_t: current environment observation
          i_t: index of the calling program
          h: previous lstm hidden state
          c: previous lstm cell state

        Returns:
          actor output over programs, value, new hidden state, new cell state

        """
        e_t = torch.FloatTensor(e_t)
        e_t, h, c = e_t.view(1, -1), h.view(1, -1), c.view(1, -1)
        with torch.no_grad():
            e_t = e_t.to(device)
            actor_out, critic_out, new_h, new_c = self.predict_on_batch(
                e_t, [i_t], h, c)
        return actor_out, critic_out, new_h, new_c

    def init_tensors(self):
        """Creates tensors representing the internal states of the lstm filled with zeros.
        
        Returns:
            instantiated hidden and cell states
        """
        h = torch.zeros(1, self._hidden_size)
        c = torch.zeros(1, self._hidden_size)
        h, c = h.to(device), c.to(device)
        return h, c
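
A minimal usage sketch for this policy. It assumes the encoder classes, ContinuousActorNet, CriticNet and the module-level `device` come from the surrounding repository; the `ListEnvEncoder` constructor signature, the dimensions, the program indices and the `env.get_observation()` call are illustrative assumptions, not values taken from the repo.

# Hypothetical setup: encoder class, dimensions and environment API are assumptions.
encoder = ListEnvEncoder(observation_dim=10, encoding_dim=32)   # assumed constructor signature
policy = Policy(encoder=encoder,
                hidden_size=128,
                num_programs=6,
                num_non_primary_programs=3,
                embedding_dim=16,
                encoding_dim=32,
                indices_non_primary_programs=[3, 4, 5])

h, c = policy.init_tensors()            # zero-filled LSTM hidden and cell states
e_t = env.get_observation()             # assumed environment call returning a numpy observation
i_t = 3                                 # index of the calling (non-primary) program
actor_out, value, h, c = policy.forward_once(e_t, i_t, h, c)
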
Example #2
class Brain(Module):
    def __init__(
            self,
            action_space_size=18,
            embedding=100,
            hidden_size=100,
            uniform_init=(-0.1, 0.1),
            device=0,
    ):
        """

        :param nodes_per_cells:
        :param hidden_size:
        :param number_of_ops:
        :param uniform_init:
        :param device:
        """
        super(Brain, self).__init__()
        # Internal parameters
        self.device = device
        self.uniform_init = uniform_init

        self.hidden_size = hidden_size
        self.embedding = embedding
        self.action_space = action_space_size

        # Initialize network
        self.rnn = None
        self.actor = None
        self.middle_critic = None
        self.encoder = None
        self.critic = None
        self.padding = None
        self.init_network()
        self.cuda(self.device)

    def init_network(self):
        """Initialize network parameters. This is an actor-critic build on top of a RNN cell. The
        actor is a fully connected layer, and the critic consists of two fully connected layers"""
        self.rnn = LSTMCell(self.action_space, self.hidden_size)
        for p in self.rnn.parameters():
            uniform_(p, self.uniform_init[0], self.uniform_init[1])

        self.actor = Linear(self.hidden_size, self.action_space)
        for p in self.actor.parameters():
            uniform_(p, self.uniform_init[0], self.uniform_init[1])

        self.middle_critic = Linear(self.hidden_size, self.hidden_size // 2)
        for p in self.middle_critic.parameters():
            uniform_(p, self.uniform_init[0], self.uniform_init[1])

        self.critic = Linear(self.hidden_size // 2, 1)
        for p in self.critic.parameters():
            uniform_(p, self.uniform_init[0], self.uniform_init[1])

        self.encoder = resnet34(num_classes=self.embedding)

        self.padding = ZeroPad2d((30, 20, 0, 0))

    def predict(self, oh_action, h, c):
        """
        Run the model for the given internal state and action.
        :param oh_action: one-hot encoded last action
        :param h: LSTM hidden state
        :param c: LSTM cell state
        :return: actor logits, critic value estimate, new hidden state, new cell state
        """
        h, c = self.rnn(oh_action, (h, c))
        actor_out = self.actor(h)
        critic_out = torch.nn.functional.relu(self.middle_critic(h))
        critic_out = self.critic(critic_out)
        return actor_out, critic_out, h, c

    def __forward_input(self, sampled, observations):
        """
        From a full state vector (40 dims) predict all the values and actions
        for that vector by building a full path of actions.
        :param sampled: 40 dim vector representing the full network connections.
        :param observations: list of numpy array observations to be encoded.
        :return: actor_outs, critic_outs, obs_tensor, embedding_tensor
        """
        obs_tensor = []
        for i in range(len(observations)):
            obs = torch.from_numpy(observations[i]).cuda(self.device).type(
                torch.cuda.FloatTensor)
            obs = torch.unsqueeze(obs, 0)
            obs = torch.transpose(obs, 1, 3)
            obs = torch.transpose(obs, 3, 2)
            h = self.encoder(obs)
            obs_tensor.append(h)
        c = torch.zeros(1, self.hidden_size).cuda(self.device)
        oh_action = torch.zeros(1, self.action_space).cuda(self.device)
        actor_outs = []
        critic_outs = []
        h = obs_tensor[0]
        # TODO: stack mask to avoid rep
        embedding_tensor = []
        for i in range(len(sampled)):
            actor_out, critic_out, h, c = self.predict(oh_action, h, c)
            embedding_tensor.append(h)
            actor_outs.append(actor_out)
            critic_outs.append(critic_out)
            # One hot encode the next action to be taken
            action = sampled[i]
            oh_action = torch.zeros_like(oh_action)
            oh_action[0, action] = 1.0

        # Get last value pred from the critic
        h, c = self.rnn(oh_action, (h, c))
        critic_out = torch.nn.functional.relu(self.middle_critic(h))
        critic_out = self.critic(critic_out)
        critic_outs.append(critic_out)

        # Squeeze stuff
        actor_outs = torch.stack(tuple(actor_outs), dim=1)
        actor_outs = torch.squeeze(actor_outs)
        critic_outs = torch.stack(tuple(critic_outs), dim=1)
        critic_outs = torch.squeeze(critic_outs)
        obs_tensor = torch.cat(tuple(obs_tensor), dim=0)
        embedding_tensor = torch.cat(tuple(embedding_tensor), dim=0)
        return actor_outs, critic_outs, obs_tensor, embedding_tensor

    def forward_input(self, sampled, observations):
        """
        From a full state vector (40 dims) predict all the values and actions
        for that vector by building a full path of actions.
        :param sampled: 40 dim vector representing the full network connections.
        :param observations: list of numpy array observations; only the first one is encoded and
            used as the initial hidden state.
        :return: actor_outs, critic_outs
        """
        obs = torch.from_numpy(observations[0]).cuda(self.device).type(
            torch.cuda.FloatTensor)
        obs = torch.unsqueeze(obs, 0)
        obs = torch.transpose(obs, 1, 3)
        obs = torch.transpose(obs, 3, 2)
        # obs = self.padding(obs)
        h = self.encoder(obs)
        c = torch.zeros(1, self.hidden_size).cuda(self.device)
        oh_action = torch.zeros(1, self.action_space).cuda(self.device)
        actor_outs = []
        critic_outs = []
        # TODO: stack mask to avoid rep
        for i in range(len(sampled)):
            actor_out, critic_out, h, c = self.predict(oh_action, h, c)

            actor_outs.append(actor_out)
            critic_outs.append(critic_out)
            # One hot encode the next action to be taken
            action = sampled[i]
            oh_action = torch.zeros_like(oh_action)
            oh_action[0, action] = 1.0

        # Get last value pred from the critic
        h, c = self.rnn(oh_action, (h, c))
        critic_out = torch.nn.functional.relu(self.middle_critic(h))
        critic_out = self.critic(critic_out)
        critic_outs.append(critic_out)

        # Squeeze stuff
        actor_outs = torch.stack(tuple(actor_outs), dim=1)
        actor_outs = torch.squeeze(actor_outs)
        critic_outs = torch.stack(tuple(critic_outs), dim=1)
        critic_outs = torch.squeeze(critic_outs)
        return actor_outs, critic_outs

    def forward_once(self, oh_action, h, c):
        """
        Given an state, a one hot encoded actions and its corresponding index, predict the
         next value of the actor-critic.
        :param oh_action:
        :param h:
        :param c:
        :return:
        """
        with torch.no_grad():
            actor_out, critic_out, h, c = self.predict(
                oh_action.cuda(self.device), h, c)
            # Normalize the actor logits into an action distribution
            action_probs = softmax(actor_out, dim=-1)
            action_probs = action_probs / torch.sum(action_probs)
            # Sample and one hot encoding
            action = torch.multinomial(action_probs, 1)
            oh_action = torch.zeros_like(oh_action)
            oh_action[0, action[0, 0]] = 1.0
        action = torch.squeeze(action)
        return action_probs, critic_out, h, c, oh_action, int(action)

    def init_tensors(self, observation):
        """Encode the observation and create the zero-filled initial action and cell-state tensors."""
        init_action = torch.zeros(1, self.action_space).cuda(self.device)
        c = torch.zeros(1, self.hidden_size).cuda(self.device)
        with torch.no_grad():
            obs = torch.from_numpy(observation).cuda(self.device).type(
                torch.cuda.FloatTensor)
            obs = torch.unsqueeze(obs, 0)
            obs = torch.transpose(obs, 1, 3)
            obs = torch.transpose(obs, 3, 2)
            # obs = self.padding(obs)
            encoded = self.encoder(obs)
        return init_action, encoded, c

    def convert_to_hot(self, action):
        oh_action = torch.zeros(1, self.action_space).cuda(self.device)
        oh_action[0, action] = 1.0
        return oh_action

    def reset(self):
        return None
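
A minimal sampling sketch for this controller, assuming a CUDA device is available (the constructor calls `self.cuda(...)`) and that `resnet34` is torchvision's model. The 96x96x3 observation shape and the rollout length below are illustrative assumptions.

import numpy as np

# Hypothetical rollout: encode one frame, then sample a few actions autoregressively.
brain = Brain(action_space_size=18, embedding=100, hidden_size=100, device=0)

frame = np.zeros((96, 96, 3), dtype=np.float32)     # assumed (H, W, C) image observation
oh_action, h, c = brain.init_tensors(frame)          # the encoded frame serves as the initial hidden state

actions = []
for _ in range(5):
    probs, value, h, c, oh_action, action = brain.forward_once(oh_action, h, c)
    actions.append(action)
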
Example #3
class Policy(Module):
    """
    This class represents an Actor-Critic network that will be used to evaluate and predict
    regex proposals. Both the Actor and the Critic heads are connected to an LSTM
    cell that is run on a sequence vector that represents the regex proposal.

    Args:
        n_action: Maximum number of different possible actions.
        max_length: Length of a complete regex proposal.
        regex_generator: Object providing the mask of available actions and the regex state at
            each step of the unrolling process.
        hidden_size: Number of hidden neurons in the LSTM cell architecture.
        uniform_init: Range of initialization of weights in the network.

    Attributes:
        rnn: LSTM cell where the regex proposal components are fed one by one as a
            sequence.
        actor: Linear head of the actor producing the probability distribution of choosing
            one specific action over all the possible ones.
        critic: This linear layer is in charge of predicting a value function estimate for a
            target regex proposal, even if the sequence composing it is not finished.

    """

    def __init__(
        self,
        n_action: int,
        max_length: int,
        regex_generator: object,
        hidden_size: int = 100,
        uniform_init: Tuple[float, float] = (-0.1, 0.1),
    ):

        super(Policy, self).__init__()
        # Internal parameters
        self._uniform_init = uniform_init
        self.generator = regex_generator
        self._hidden_size = hidden_size
        self._n_actions = n_action
        self._max_path_len = max_length
        # Initialize network
        self.rnn = None
        self.actor = None
        self.middle_critic = None
        self.critic = None
        self.init_network()

    @property
    def max_path_len(self) -> int:
        """This is the length of a complete regex proposal."""
        return self._max_path_len

    @property
    def n_actions(self) -> int:
        """Maximum number of different possible actions."""
        return self._n_actions

    def init_network(self):
        """Initialize network parameters. This is an actor-critic build on top of a RNN cell. The
        actor is a fully connected layer, and the critic consists of two fully connected layers"""
        self.rnn = LSTMCell(self.n_actions, self._hidden_size)
        for p in self.rnn.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

        self.actor = Linear(self._hidden_size, self.n_actions)
        for p in self.actor.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

        self.middle_critic = Linear(self._hidden_size, self._hidden_size // 2)
        for p in self.middle_critic.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

        self.critic = Linear(self._hidden_size // 2, 1)
        for p in self.critic.parameters():
            uniform_(p, self._uniform_init[0], self._uniform_init[1])

    def predict(
        self, oh_action: Vector, h: Vector, c: Vector
    ) -> Tuple[Vector, Vector, Vector, Vector]:
        """
        Run the model for the given internal state, which is represented by the last action that
        was taken, and the internal state of the LSTM cell.
        Args:
            oh_action: One hot encoded vector representing the last action that was taken.
            h: Previous hidden state (output) of the LSTM cell.
            c: Previous cell state of the LSTM cell.

        Returns:
            tuple containing the output of the actor network, the output of the critic,
            the output of the LSTM cell, and the internal state of the LSTM.
        """
        h, c = self.rnn(oh_action, (h, c))
        actor_out = self.actor(h)
        critic_out = torch.nn.functional.relu(self.middle_critic(h))
        critic_out = self.critic(critic_out)
        return actor_out, critic_out, h, c

    def evaluate_solution(self, arch_proposal: Vector) -> Tuple[Vector, Vector, Vector, Vector]:
        """
        Get the outputs that the model will give at every step when trying to build the
        proposed regex. For each one of the actions taken to build the regex
        proposal it will calculate the output of the actor and the critic.
        Args:
            arch_proposal: vector representing the full regex.

        Returns:
            actor_outs, critic_outs (excluding the first step's value), h, c
        """
        oh_action = torch.zeros(1, self.n_actions)
        h = torch.zeros(1, self._hidden_size)
        c = torch.zeros(1, self._hidden_size)
        if torch.cuda.is_available():
            oh_action = oh_action.cuda()
            h = h.cuda()
            c = c.cuda()
        actor_outs = []
        critic_outs = []
        for ix, action in enumerate(arch_proposal):
            actor_out, critic_out, h, c = self.predict(oh_action, h, c)
            actor_out = torch.softmax(actor_out, dim=1)[0, action]  # probability of the action actually taken
            actor_outs.append(actor_out)
            critic_outs.append(critic_out)
            # One hot encode the next action to be taken
            oh_action = torch.zeros_like(oh_action)
            oh_action[0, action] = 1.0

        # Get last value prediction from the critic
        h, c = self.rnn(oh_action, (h, c))
        critic_out = torch.nn.functional.relu(self.middle_critic(h))
        critic_out = self.critic(critic_out)
        critic_outs.append(critic_out)

        actor_outs = torch.stack(tuple(actor_outs), dim=0)
        actor_outs = torch.squeeze(actor_outs)
        critic_outs = torch.stack(tuple(critic_outs), dim=1)
        critic_outs = torch.squeeze(critic_outs)
        return actor_outs, critic_outs[1:], h, c

    def forward_once(self, oh_action, h, c, action_ix, regex_state):
        """
        Given a one hot encoded vetor representing the last action that was taken and the
        state of the recurrent cell, run the model to sample information for the next action to
        be chosen.
        Args:
            oh_action: One hot encoded vector representing the last action that was taken.
            h: Previous output of the LSTM cell.
            c: Hidden state of the LSTM cell.
            action_ix: Index of the current step in the whole unrolling process.

        Returns:
            action_probs, critic_out, h, c, oh_action, action_chosen
        """

        if not regex_state:
            mask, new_regex_state = self.generator.reset_with_state()
            reg = None
        else:
            action = int(torch.argmax(oh_action, dim=-1))
            mask, reg, new_regex_state = self.generator.step_with_state(action, regex_state)

        finish = mask is None

        with torch.no_grad():
            oh_action = oh_action if not torch.cuda.is_available() else oh_action.cuda()
            mask = torch.FloatTensor([mask]) if mask is not None else torch.ones_like(oh_action)
            mask = mask if not torch.cuda.is_available() else mask.cuda()
            actor_out, critic_out, h, c = self.predict(oh_action, h, c)
            action_probs = softmax(actor_out, dim=-1)
            # Renormalize over the actions allowed by the mask
            masked_probs = action_probs * mask
            action_probs = masked_probs / masked_probs.sum()
            action = torch.multinomial(action_probs, 1)
            oh_action = torch.zeros_like(oh_action)
            oh_action[0, action[0, 0]] = 1.0
        action = torch.squeeze(action)

        if finish:
            action_probs = None

        return action_probs, critic_out, h, c, oh_action, int(action), new_regex_state

    def init_tensors(self) -> Tuple[Vector, Vector, Vector]:
        """
        Creates tensors representing the internal states of the controller filled with zeros
        Returns:
            init_action, h, c which are used as a starting point for the LSTM cell.
        """
        init_action = torch.zeros(1, self.n_actions)
        h = torch.zeros(1, self._hidden_size)
        c = torch.zeros(1, self._hidden_size)
        if torch.cuda.is_available():
            init_action, h, c = (init_action.cuda(), h.cuda(), c.cuda())

        return init_action, h, c

    def sample_regex(self, length):
        """
        Sample a regex just with policy lstm
        :param length: max length of Regex
        :return: actions choosen and regex string
        """

        oh_action, h, c = self.init_tensors()
        regex_state = None
        actions = []

        for i in range(length):
            action_probs, _critic_out, h, c, oh_action, action, regex_state = self.forward_once(
                oh_action, h, c, i, regex_state
            )
            actions.append(action)
            if action_probs is None:
                break

        reg = self.generator.get_regex(regex_state)

        return actions, reg
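
A minimal usage sketch. The generator object is an assumption: it only needs the `reset_with_state`, `step_with_state` and `get_regex` interface that `forward_once` and `sample_regex` call above; the `RegexGenerator` name and the sizes below are illustrative, not taken from the repo.

# Hypothetical setup: the generator and the dimensions are assumptions.
generator = RegexGenerator()
policy = Policy(n_action=20, max_length=30, regex_generator=generator, hidden_size=100)

# Sample a proposal with the policy alone, then re-score it to obtain
# per-step action probabilities and value estimates (e.g. for a policy update).
actions, regex = policy.sample_regex(length=policy.max_path_len)
action_probs, values, h, c = policy.evaluate_solution(actions)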