Ejemplo n.º 1
0
    def add_last_seq_in_trajectory(self, experience, new_seqs):
        """Record an episode's final sequence and re-seed the environment.

        Nothing happens unless `experience.is_boundary()` is truthy. On a boundary
        step, the first entry of the one-hot "sequence" observation is decoded to
        a string and stored in `new_seqs` alongside the observed "fitness" value.

        The `seq` attribute of the first wrapped environment is then overwritten
        with a randomly chosen key of `new_seqs`: preferably one whose fitness is
        at least 0.9 times the best fitness recorded so far, otherwise any
        recorded sequence.

        Args:
            experience: Trajectory step exposing `is_boundary()` and an
                `observation` mapping with "sequence" and "fitness" entries.
            new_seqs: Dict mapping sequence string to fitness; mutated in place.
        """
        if not experience.is_boundary():
            return

        observation = experience.observation
        final_seq = one_hot_to_string(observation["sequence"].numpy()[0],
                                      self.alphabet)
        new_seqs[final_seq] = observation["fitness"].numpy().squeeze()

        # Keep every sequence scoring at least 0.9 x the current best. The filter
        # can come up empty (for a negative best fitness, scaling by 0.9 moves the
        # threshold above the best itself); then all sequences are candidates.
        threshold = 0.9 * max(new_seqs.values())
        candidates = [s for s, fit in new_seqs.items() if fit >= threshold]
        if not candidates:
            candidates = list(new_seqs)

        self.tf_env.pyenv.envs[0].seq = np.random.choice(candidates)
Ejemplo n.º 2
0
    def _step(self, actions):
        """Write one residue into every sequence of the batch and advance.

        Each action selects the channel to switch on at the current position of
        the corresponding batch member. While the sequences are still being
        built, a zero-reward transition is returned for every member. Once the
        length check below fails, the sequences are decoded and scored, and the
        episode terminates with reward = fitness - lam * sequence density.

        Args:
            actions: Array of per-member channel indices; flattened before use.

        Returns:
            The per-member time steps combined by
            `nest_utils.stack_nested_arrays`.
        """
        chosen = actions.flatten()
        cursor = self.partial_seq_len

        # Clear the last channel (the mask token) at the current position, then
        # switch on the channel picked by each batch member's action.
        self.states[:, cursor, -1] = 0
        self.states[np.arange(self._batch_size), cursor, chosen] = 1
        self.partial_seq_len = cursor + 1

        # Still building the sequences: emit a zero-reward transition.
        # NOTE(review): the episode ends as soon as partial_seq_len reaches
        # seq_length - 1; whether that leaves the final position unwritten
        # depends on the starting value set outside this method -- confirm.
        if self.partial_seq_len < self.seq_length - 1:
            return nest_utils.stack_nested_arrays(
                [ts.transition(state, 0) for state in self.states])

        # Decode without the trailing mask-token column, then score.
        sequences = [
            s_utils.one_hot_to_string(state[:, :-1], self.alphabet)
            for state in self.states
        ]
        scorer = self.landscape if self.fitness_model_is_gt else self.model
        fitnesses = scorer.get_fitness(sequences)
        self.all_seqs.update(zip(sequences, fitnesses))

        # Penalise each fitness by lam times the sequence's density.
        rewards = np.array([
            fit - self.lam * self.sequence_density(sequence)
            for sequence, fit in zip(sequences, fitnesses)
        ])
        terminal_steps = [
            ts.termination(state, reward)
            for state, reward in zip(self.states, rewards)
        ]
        return nest_utils.stack_nested_arrays(terminal_steps)
Ejemplo n.º 3
0
    def _soln_to_string(self, soln):
        """Decode a flat solution vector into a sequence string.

        `soln` is reshaped to one row per position of `self.starting_sequence`
        and one column per symbol of `self.alphabet`. The largest entry in each
        row is turned into a one-hot row, which is then decoded to a string.
        """
        n_positions = len(self.starting_sequence)
        n_symbols = len(self.alphabet)
        scores = soln.reshape((n_positions, n_symbols))

        # Hard-assign each position to its highest-scoring symbol.
        winners = np.argmax(scores, axis=1)
        encoded = np.zeros((n_positions, n_symbols))
        encoded[np.arange(n_positions), winners] = 1

        return s_utils.one_hot_to_string(encoded, self.alphabet)
Ejemplo n.º 4
0
 def pick_action(self, all_measured_seqs):
     """Screen sampled mutations with the model and move to the best one.

     Builds `model_queries_per_batch // sequences_batch_size` candidate
     mutants from the sampled actions, scores each with the acquisition
     function selected by `self.method` ("EI", otherwise UCB), and sets
     `self.state` to the highest-scoring candidate. The transition is stored
     in memory, and `best_fitness` updated, only when the chosen sequence is
     not in `all_measured_seqs`. `num_actions` is always incremented.

     Returns:
         Tuple `(uncertainty, new_state_string, reward)`, where `reward` is
         the mean of the model predictions for the chosen candidate.
     """
     previous_state = self.state.copy()
     sampled = self.sample_actions()
     num_candidates = self.model_queries_per_batch // self.sequences_batch_size
     alphabet_size = len(self.alphabet)

     candidate_masks = []
     candidate_strings = []
     for idx in range(num_candidates):
         # Mark every sampled entry in an all-zero (seq_len, alphabet) mask.
         mask = np.zeros((self.seq_len, alphabet_size))
         for entry in sampled[idx]:
             mask[entry] = 1
         candidate_masks.append(mask)
         mutant = construct_mutant_from_sample(mask, previous_state)
         candidate_strings.append(one_hot_to_string(mutant, self.alphabet))

     ensemble_preds = self.model.get_fitness(candidate_strings)
     acquisition = self.EI if self.method == "EI" else self.UCB
     scores = [acquisition(vals) for vals in ensemble_preds]
     best = np.argmax(scores)

     # NOTE(review): this is the std of the single winning acquisition score,
     # not of the predictions behind it -- confirm that is intended.
     uncertainty = np.std(scores[best])

     chosen_mask = candidate_masks[best]
     new_state_string = candidate_strings[best]
     self.state = string_to_one_hot(new_state_string, self.alphabet)
     reward = np.mean(ensemble_preds[best])

     if new_state_string not in all_measured_seqs:
         self.best_fitness = max(self.best_fitness, reward)
         self.memory.store(
             previous_state.ravel(),
             chosen_mask.ravel(),
             reward,
             self.state.ravel(),
         )
     self.num_actions += 1
     return uncertainty, new_state_string, reward
Ejemplo n.º 5
0
    def _step(self, action):
        """Apply one point mutation and return the resulting time step.

        `action` indexes the flattened one-hot sequence: the quotient by the
        alphabet size is the position and the remainder is the residue.

        The episode terminates when:
          * `self.max_num_steps` steps have already been taken (reward 0),
          * the action would not change the sequence (reward 0),
          * the mutated sequence was already seen this episode (reward -1), or
          * the reward falls below the previous step's reward (that reward).
        Otherwise a transition carrying the new reward is returned.

        The reward is the fitness minus `self.lam` times the sequence density.
        """
        # Step budget exhausted: end the episode.
        if self.num_steps >= self.max_num_steps:
            return ts.termination(self._state, 0)

        pos, res = divmod(action, len(self.alphabet))
        self.num_steps += 1

        sequence = self._state["sequence"]

        # The requested residue is already in place: treat the no-op as a stop.
        if sequence[pos, res] == 1:
            return ts.termination(self._state, 0)

        # Apply the mutation in place on the state's one-hot matrix.
        sequence[pos] = 0
        sequence[pos, res] = 1
        state_string = s_utils.one_hot_to_string(sequence, self.alphabet)

        # Score with the landscape when flagged as ground truth, else the model.
        oracle = self.landscape if self.fitness_model_is_gt else self.model
        fitness = oracle.get_fitness([state_string]).astype(np.float32)
        self._state["fitness"] = fitness
        fitness_value = fitness.item()
        self.all_seqs[state_string] = fitness_value

        reward = fitness_value - self.lam * self.sequence_density(state_string)

        # Revisiting a sequence within the episode is punished to avoid loops.
        if state_string in self.episode_seqs:
            return ts.termination(self._state, -1)
        self.episode_seqs.add(state_string)

        # Stop as soon as the reward drops below the previous step's reward.
        if reward < self.previous_fitness:
            return ts.termination(self._state, reward=reward)

        self.previous_fitness = reward
        return ts.transition(self._state, reward=reward)
Ejemplo n.º 6
0
    def add_last_seq_in_trajectory(self, experience, new_seqs):
        """Record the sequence of every boundary step in a batched experience.

        Pairs each entry of `experience.is_boundary()` with the matching entry
        of `experience.observation`. For each boundary entry, the observation
        (minus its last column) is decoded to a string and stored in `new_seqs`
        with the fitness returned by `self.tf_env.get_cached_fitness`.

        Args:
            experience: Batched trajectory exposing `is_boundary()` and
                `observation`.
            new_seqs: Dict mapping sequence string to fitness; mutated in place.
        """
        boundary_flags = experience.is_boundary()
        for at_boundary, obs in zip(boundary_flags, experience.observation):
            if not at_boundary:
                continue
            # Drop the trailing column before decoding
            # (presumably the mask-token channel -- confirm).
            encoded = obs.numpy()[:, :-1]
            sequence = s_utils.one_hot_to_string(encoded, self.alphabet)
            new_seqs[sequence] = self.tf_env.get_cached_fitness(sequence)
Ejemplo n.º 7
0
 def train_models(self):
     """Train the model on a batch sampled from memory.

     When memory holds fewer than `self.sequences_batch_size` entries, its
     `batch_size` is set to the current memory length for this one sample and
     then set back to `self.sequences_batch_size`. The sampled "next_obs"
     states are decoded to sequence strings and passed to `self.model.train`
     together with the sampled "rews".
     """
     stored = len(self.memory)
     if stored < self.sequences_batch_size:
         # Not enough entries for a full batch: sample what is there.
         self.memory.batch_size = stored
         batch = self.memory.sample_batch()
         self.memory.batch_size = self.sequences_batch_size
     else:
         batch = self.memory.sample_batch()

     alphabet_size = len(self.alphabet)
     next_state_seqs = []
     for flat_state in batch["next_obs"]:
         # Each state is reshaped to one row per position before decoding.
         one_hot = flat_state.reshape((-1, alphabet_size))
         next_state_seqs.append(one_hot_to_string(one_hot, self.alphabet))

     self.model.train(next_state_seqs, batch["rews"])
Ejemplo n.º 8
0
 def get_state_string(self):
     """Return the current "sequence" state decoded to a string."""
     encoded = self._state["sequence"]
     return s_utils.one_hot_to_string(encoded, self.alphabet)