def add_last_seq_in_trajectory(self, experience, new_seqs):
    """Add the last sequence in an episode's trajectory.

    Given a trajectory object, checks if the object is the last in the
    trajectory. Since the environment ends the episode when the score is
    non-increasing, it adds the associated maximum-valued sequence to the
    batch.

    If the episode is ending, it sets the environment's "current sequence" to
    a randomly chosen high-fitness sequence seen so far, so that when the
    environment resets, mutants are generated from that new sequence.
    """
    if experience.is_boundary():
        seq = one_hot_to_string(
            experience.observation["sequence"].numpy()[0], self.alphabet
        )
        new_seqs[seq] = experience.observation["fitness"].numpy().squeeze()

        top_fitness = max(new_seqs.values())
        top_sequences = [
            seq
            for seq, fitness in new_seqs.items()
            if fitness >= 0.9 * top_fitness
        ]
        if len(top_sequences) > 0:
            self.tf_env.pyenv.envs[0].seq = np.random.choice(top_sequences)
        else:
            self.tf_env.pyenv.envs[0].seq = np.random.choice(list(new_seqs.keys()))
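# A minimal sketch of how a method like `add_last_seq_in_trajectory` could be
# registered as a trajectory observer on a TF-Agents episode driver. The
# driver, agent, and replay-buffer names below are illustrative assumptions,
# not necessarily how this repository wires them up.
from functools import partial

from tf_agents.drivers import dynamic_episode_driver


def _collect_new_sequences_sketch(self, num_episodes):
    """Run episodes, recording each episode's final sequence into `new_seqs`."""
    new_seqs = {}
    driver = dynamic_episode_driver.DynamicEpisodeDriver(
        self.tf_env,
        self.agent.collect_policy,
        observers=[
            self.replay_buffer.add_batch,
            partial(self.add_last_seq_in_trajectory, new_seqs=new_seqs),
        ],
        num_episodes=num_episodes,
    )
    driver.run()
    return new_seqs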
def _step(self, actions):
    """Progress the agent one step in the environment."""
    actions = actions.flatten()
    self.states[:, self.partial_seq_len, -1] = 0
    self.states[np.arange(self._batch_size), self.partial_seq_len, actions] = 1
    self.partial_seq_len += 1

    # We have not generated the last residue in the sequence, so continue
    if self.partial_seq_len < self.seq_length - 1:
        return nest_utils.stack_nested_arrays(
            [ts.transition(seq_state, 0) for seq_state in self.states]
        )

    # If the sequence is of full length, score the sequence and end the episode.
    # We need to take off the column in the matrix (-1) representing the mask token.
    complete_sequences = [
        s_utils.one_hot_to_string(seq_state[:, :-1], self.alphabet)
        for seq_state in self.states
    ]
    if self.fitness_model_is_gt:
        fitnesses = self.landscape.get_fitness(complete_sequences)
    else:
        fitnesses = self.model.get_fitness(complete_sequences)
    self.all_seqs.update(zip(complete_sequences, fitnesses))

    # Reward = fitness - lambda * sequence density
    rewards = np.array(
        [
            f - self.lam * self.sequence_density(seq)
            for seq, f in zip(complete_sequences, fitnesses)
        ]
    )
    return nest_utils.stack_nested_arrays(
        [
            ts.termination(seq_state, r)
            for seq_state, r in zip(self.states, rewards)
        ]
    )
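# `sequence_density` is the exploration penalty used in the reward above.
# A minimal illustrative version (an assumption, not necessarily this
# repository's implementation): weight previously proposed sequences by the
# inverse of their Hamming distance to `seq`, within a small radius.
def _sequence_density_sketch(self, seq, radius=2):
    """Return a crowding score for `seq` based on nearby proposed sequences."""
    density = 0.0
    for other in self.all_seqs:
        dist = sum(a != b for a, b in zip(seq, other))
        if 0 < dist <= radius:
            density += 1.0 / dist
    return density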
def _soln_to_string(self, soln):
    """Convert a flat, real-valued solution vector into a sequence string.

    The vector is reshaped to (sequence_length, alphabet_size) and each
    position is decoded by taking the argmax over the alphabet dimension.
    """
    x = soln.reshape((len(self.starting_sequence), len(self.alphabet)))
    one_hot = np.zeros(x.shape)
    one_hot[np.arange(len(one_hot)), np.argmax(x, axis=1)] = 1
    return s_utils.one_hot_to_string(one_hot, self.alphabet)
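# Stand-alone illustration of the argmax decode used in `_soln_to_string`,
# assuming a 3-residue sequence over the 4-letter alphabet "ACGT"; the
# variable names here are purely for demonstration.
import numpy as np

alphabet = "ACGT"
soln = np.random.randn(3 * len(alphabet))       # flat candidate from the optimizer
x = soln.reshape((3, len(alphabet)))
decoded = "".join(alphabet[i] for i in np.argmax(x, axis=1))  # e.g. "ATG"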
def pick_action(self, all_measured_seqs):
    """Pick the mutant whose acquisition value (EI or UCB) is highest.

    Samples a batch of candidate mutations, scores the resulting sequences
    with the model ensemble, and greedily selects the candidate that
    maximizes the chosen acquisition function.
    """
    state = self.state.copy()
    actions = self.sample_actions()
    actions_to_screen = []
    states_to_screen = []
    for i in range(self.model_queries_per_batch // self.sequences_batch_size):
        x = np.zeros((self.seq_len, len(self.alphabet)))
        for action in actions[i]:
            x[action] = 1
        actions_to_screen.append(x)
        state_to_screen = construct_mutant_from_sample(x, state)
        states_to_screen.append(one_hot_to_string(state_to_screen, self.alphabet))
    ensemble_preds = self.model.get_fitness(states_to_screen)
    method_pred = (
        [self.EI(vals) for vals in ensemble_preds]
        if self.method == "EI"
        else [self.UCB(vals) for vals in ensemble_preds]
    )
    action_ind = np.argmax(method_pred)
    uncertainty = np.std(method_pred[action_ind])
    action = actions_to_screen[action_ind]
    new_state_string = states_to_screen[action_ind]
    self.state = string_to_one_hot(new_state_string, self.alphabet)
    new_state = self.state
    reward = np.mean(ensemble_preds[action_ind])
    if new_state_string not in all_measured_seqs:
        self.best_fitness = max(self.best_fitness, reward)
        self.memory.store(state.ravel(), action.ravel(), reward, new_state.ravel())
    self.num_actions += 1
    return uncertainty, new_state_string, reward
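# Minimal sketches of the acquisition functions referenced above. These are
# assumptions for illustration; the repository's `self.EI` / `self.UCB` may
# differ in detail. `vals` is the ensemble of fitness predictions for one
# candidate sequence.
def EI_sketch(self, vals):
    """Expected improvement over the best fitness measured so far."""
    return np.mean([max(val - self.best_fitness, 0.0) for val in vals])


def UCB_sketch(self, vals, kappa=0.01):
    """Upper confidence bound: mean prediction plus a bonus for ensemble disagreement."""
    return np.mean(vals) + kappa * np.std(vals)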
def _step(self, action):
    """Progress the agent one step in the environment.

    The episode continues until the reward stops increasing, and the number
    of sequences that can be evaluated in each episode is capped at
    `self.max_num_steps`.
    """
    # if we've exceeded the maximum number of steps, terminate
    if self.num_steps >= self.max_num_steps:
        return ts.termination(self._state, 0)

    # `action` is an index into the flattened one-hot representation of the
    # sequence; decode it into a (position, residue) pair
    pos = action // len(self.alphabet)
    res = action % len(self.alphabet)
    self.num_steps += 1

    # if we are trying to modify the sequence with a no-op, then stop
    if self._state["sequence"][pos, res] == 1:
        return ts.termination(self._state, 0)

    self._state["sequence"][pos] = 0
    self._state["sequence"][pos, res] = 1

    state_string = s_utils.one_hot_to_string(self._state["sequence"], self.alphabet)
    if self.fitness_model_is_gt:
        self._state["fitness"] = self.landscape.get_fitness([state_string]).astype(
            np.float32
        )
    else:
        self._state["fitness"] = self.model.get_fitness([state_string]).astype(
            np.float32
        )
    self.all_seqs[state_string] = self._state["fitness"].item()

    reward = self._state["fitness"].item() - self.lam * self.sequence_density(
        state_string
    )

    # if we have seen this sequence during the episode, terminate and punish
    # (to prevent going in loops)
    if state_string in self.episode_seqs:
        return ts.termination(self._state, -1)
    self.episode_seqs.add(state_string)

    # if the reward is not increasing, then terminate
    if reward < self.previous_fitness:
        return ts.termination(self._state, reward=reward)

    self.previous_fitness = reward
    return ts.transition(self._state, reward=reward)
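# Worked example of the action decoding above (illustrative numbers): with a
# 20-letter alphabet, action 45 selects position 45 // 20 == 2 and residue
# 45 % 20 == 5, i.e. it sets column 5 of row 2 in the one-hot sequence to 1.
assert divmod(45, 20) == (2, 5)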
def add_last_seq_in_trajectory(self, experience, new_seqs):
    """Add the last sequence in an episode's trajectory.

    Given a trajectory object, checks if the object is the last in the
    trajectory. Since the environment ends the episode when the score is
    non-increasing, it adds the associated maximum-valued sequence to the
    batch.

    If the episode is ending, it changes the "current sequence" of the
    environment to the next one in `last_batch`, so that when the environment
    resets, mutants are generated from that new sequence.
    """
    for is_bound, obs in zip(experience.is_boundary(), experience.observation):
        if is_bound:
            seq = s_utils.one_hot_to_string(obs.numpy()[:, :-1], self.alphabet)
            new_seqs[seq] = self.tf_env.get_cached_fitness(seq)
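# `get_cached_fitness` above is assumed to read back the fitness the
# environment recorded when the sequence was scored in its `_step`; a minimal
# version keyed on that cache might look like this (an assumption, not
# necessarily the exact environment method).
def get_cached_fitness_sketch(self, seq):
    """Return the fitness stored for `seq` when it was scored."""
    return self.all_seqs[seq]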
def train_models(self):
    """Train the model."""
    if len(self.memory) >= self.sequences_batch_size:
        batch = self.memory.sample_batch()
    else:
        # The memory holds fewer transitions than a full batch, so temporarily
        # shrink the sample size and restore it afterwards.
        self.memory.batch_size = len(self.memory)
        batch = self.memory.sample_batch()
        self.memory.batch_size = self.sequences_batch_size

    states = batch["next_obs"]
    state_seqs = [
        one_hot_to_string(state.reshape((-1, len(self.alphabet))), self.alphabet)
        for state in states
    ]
    rewards = batch["rews"]
    self.model.train(state_seqs, rewards)
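# `self.memory` above is assumed to be a simple replay buffer whose
# `sample_batch` returns a dict containing (at least) "next_obs" and "rews".
# A minimal sketch of such a buffer, consistent with the `store`, `__len__`,
# `batch_size`, and `sample_batch` usage in this section (illustrative, not
# the repository's class):
import collections
import random

import numpy as np


class ReplayMemorySketch:
    """Fixed-size buffer of (obs, action, reward, next_obs) transitions."""

    def __init__(self, batch_size, capacity=10000):
        self.batch_size = batch_size
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def store(self, obs, act, rew, next_obs):
        self.buffer.append((obs, act, rew, next_obs))

    def sample_batch(self):
        batch = random.sample(list(self.buffer), self.batch_size)
        obs, acts, rews, next_obs = map(np.array, zip(*batch))
        return {"obs": obs, "acts": acts, "rews": rews, "next_obs": next_obs}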
def get_state_string(self):
    """Get sequence representing current state."""
    return s_utils.one_hot_to_string(self._state["sequence"], self.alphabet)