def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
    """
    Start a new episode and build the initial environment state.

    Every walker starts from the same placeholder data: states filled with
    ``-1``, all-zero observations, infinite rewards, no walker out of bounds
    and a time counter of zero.

    Args:
        batch_size: Number of walkers the returned state will contain. It is
            the size of the first dimension of every tensor created here.
        **kwargs: Accepted for interface compatibility and ignored; no
            external data is needed to reset.

    Returns:
        The object produced by ``self.states_from_data`` for the freshly
        initialized tensors.

    """
    initial_data = {
        # -1 marks every entry of the state tensor as "not set yet".
        "states": tensor.ones((batch_size,) + self.states_shape, dtype=dtype.int32) * -1,
        "observs": tensor.zeros((batch_size, 3)),
        # Rewards start at +inf.
        "rewards": tensor.ones(batch_size, dtype=dtype.float32) * numpy.inf,
        "oobs": tensor.zeros(batch_size, dtype=dtype.bool),
        "times": tensor.zeros(batch_size, dtype=dtype.int32),
    }
    return self.states_from_data(batch_size=batch_size, **initial_data)
def get_empty_registers(
    self, n_walkers: int = 1, batch_size: int = 1, as_tensor: bool = False
) -> typing.Tensor:
    """
    Create a zero-initialized block of registers.

    Args:
        n_walkers: Size of the first dimension of the result.
        batch_size: Size of the second dimension of the result.
        as_tensor: When True, the zeros are passed through
            ``tensor.to_torch`` with ``use_grad=True`` before being returned.

    Returns:
        ``float32`` zeros with shape ``(n_walkers, batch_size, self.n_registers)``.

    """
    shape = (n_walkers, batch_size, self.n_registers)
    empty = tensor.zeros(shape, dtype=dtype.float32)
    if not as_tensor:
        return empty
    return tensor.to_torch(empty, use_grad=True)
def forward_programs(env, programs, registers, grad: bool = False):
    """
    Evaluate each program on its own registers using ``env.model``.

    The computation runs with the torch backend active. Program ``i`` is
    evaluated on ``registers[i]`` and its output is stored in row ``i`` of
    the result.

    Args:
        env: Object providing ``model.predict`` and ``output_dims``.
        programs: Iterable collection of programs, one per row of ``registers``.
        registers: Indexable data whose first two dimensions size the output.
        grad: Forwarded unchanged to ``env.model.predict``.

    Returns:
        Predictions of shape
        ``(registers.shape[0], registers.shape[1], env.output_dims)``, passed
        through ``tensor.to_backend`` once the torch context has been left
        (presumably converting them to the caller's active backend).

    """
    with Backend.use_backend("torch"):
        programs = tensor.to_backend(programs)
        registers = tensor.to_backend(registers)
        out_shape = (registers.shape[0], registers.shape[1], env.output_dims)
        predictions = tensor.zeros(out_shape)
        for index, program in enumerate(programs):
            torch_regs = tensor.to_torch(registers[index])
            predictions[index] = env.model.predict(
                program=tensor.to_torch(program),
                registers=tensor.unsqueeze(torch_regs),
                grad=grad,
            )
    return tensor.to_backend(predictions)
def boundary_condition(env, rewards: typing.Tensor, times: typing.Tensor) -> typing.Tensor:
    """
    Flag the walkers that should be discarded.

    A walker is flagged when its reward (treated as a loss) is larger than
    the mean reward of all walkers plus one standard deviation, or when its
    time is greater than ``env.max_len - 1``.

    Args:
        env: Object providing the ``max_len`` limit.
        rewards: Reward of each walker.
        times: Time counter of each walker.

    Returns:
        Boolean mask with True for flagged walkers. If every walker would be
        flagged, an all-False mask of length ``rewards.shape[0]`` is returned
        instead, so nothing is discarded.

    """
    threshold = rewards.mean() + rewards.std()
    exceeds_loss = rewards > threshold
    exceeds_time = times > env.max_len - 1
    out_of_bounds = tensor.logical_or(exceeds_loss, exceeds_time)
    if out_of_bounds.all():
        # Flagging every walker is replaced by flagging none.
        return tensor.zeros(rewards.shape[0], dtype=dtype.bool)
    return out_of_bounds
def update_registers(
    env,
    registers: typing.Tensor,
    actions: typing.Tensor,
    times: typing.Tensor,
    grad: bool = False,
    to_backend: bool = False,
) -> typing.Tensor:
    """
    Transition the states by updating the registers as described by the actions.

    Args:
        env: Environment used in the current program synthesis task. It
            provides ``max_len`` and ``repertoire.forward_one_action``.
        registers: Contain the writable registers for each walker.
        actions: Describes the register update process for each walker.
            It is flattened before use and assumed to contain integers.
        times: Time of each walker. Walkers whose time is greater than or
            equal to ``env.max_len`` are not updated. The integer value of
            the time is passed as ``index`` to ``forward_one_action``.
        grad: Compute the gradients of the model.
        to_backend: When True and ``grad`` is False, the result is passed
            through ``tensor.to_backend`` before being returned.

    Returns:
        Tensor containing the updated registers, with the same shape as
        ``registers`` and ``float32`` dtype.

    """
    # Assumes actions are integers
    flat_actions = actions.flatten()
    with Backend.use_backend("torch"):
        updated = tensor.zeros(registers.shape, dtype=dtype.float32)
        for row, (action, step) in enumerate(zip(flat_actions, times.copy())):
            if step >= env.max_len:
                # NOTE(review): skipped walkers keep all-zero registers in the
                # output rather than their input values - confirm this is intended.
                continue
            current_regs = tensor.to_backend(registers[row], use_grad=grad)
            # Compute the outputs of the action and store them in this walker's row.
            updated[row, :] = env.repertoire.forward_one_action(
                current_regs, action=action, grad=grad, index=int(step)
            )
    if to_backend and not grad:
        updated = tensor.to_backend(updated)
    return updated