def tf_observe_timestep(self, states, internals, actions, terminal, reward):
    """
    Creates the TensorFlow operations for observing one (batched) timestep: stores the
    experience in memory and, gated on the configured `update_mode`, conditionally runs
    an optimization step on a batch retrieved from memory.

    Args:
        states: Dict of state tensors (batched along the first dimension — assumed; TODO confirm against caller).
        internals: List of prior internal state tensors.
        actions: Dict of action tensors.
        terminal: Terminal boolean tensor.
        reward: Reward tensor.

    Returns:
        The optimization operation produced by `tf.cond` (a no-op when no update is due).
    """
    # Store timestep in memory
    stored = self.memory.store(states=states, internals=internals, actions=actions, terminal=terminal, reward=reward)
    # Periodic optimization
    # NOTE: the control dependency guarantees the experience is stored before any
    # batch is retrieved for optimization below.
    with tf.control_dependencies(control_inputs=(stored, )):
        unit = self.update_mode['unit']
        batch_size = self.update_mode['batch_size']
        # If no explicit frequency is configured, update every `batch_size` units.
        frequency = self.update_mode.get('frequency', batch_size)
        if unit == 'timesteps':
            # Timestep-based batch: update every `frequency` timesteps, once at least
            # `batch_size` timesteps have been observed.
            optimize = tf.logical_and(x=tf.equal(x=(self.timestep % frequency), y=0), y=tf.greater_equal(x=self.timestep, y=batch_size))
            batch = self.memory.retrieve_timesteps(n=batch_size)
        elif unit == 'episodes':
            # Episode-based batch: update every `frequency` episodes, and only on a
            # terminal timestep so each episode increment triggers at most one update.
            optimize = tf.logical_and(
                x=tf.equal(x=(self.episode % frequency), y=0),
                y=tf.logical_and(
                    # Only update once per episode increment.
                    x=tf.greater(x=tf.count_nonzero(input_tensor=terminal), y=0),
                    y=tf.greater_equal(x=self.episode, y=batch_size)))
            batch = self.memory.retrieve_episodes(n=batch_size)
        elif unit == 'sequences':
            # Timestep-sequence-based batch: `batch_size` sequences of `length`
            # consecutive timesteps each (default length 8).
            sequence_length = self.update_mode.get('length', 8)
            optimize = tf.logical_and(
                x=tf.equal(x=(self.timestep % frequency), y=0),
                # Need enough timesteps to form `batch_size` overlapping sequences.
                y=tf.greater_equal(x=self.timestep, y=(batch_size + sequence_length - 1)))
            batch = self.memory.retrieve_sequences(
                n=batch_size, sequence_length=sequence_length)
        else:
            # Invalid configuration fails at graph-construction time.
            raise TensorForceError("Invalid update unit: {}.".format(unit))
        # Do not calculate gradients for memory-internal operations.
        batch = util.map_tensors(
            fn=(lambda tensor: tf.stop_gradient(input=tensor)),
            tensors=batch)
        # Only run the optimization subgraph when the update condition holds.
        optimization = tf.cond(
            pred=optimize,
            true_fn=(lambda: self.fn_optimization(**batch)),
            false_fn=tf.no_op)
    return optimization
def tf_step(self, time, variables, arguments, **kwargs):
    """
    Creates the TensorFlow operations for performing an optimization step on a random
    subsample (fraction `self.fraction`) of the batch.

    Args:
        time: Time tensor.
        variables: List of variables to optimize.
        arguments: Dict of arguments for callables, like fn_loss.
        **kwargs: Additional arguments passed on to the internal optimizer.

    Returns:
        List of delta tensors corresponding to the updates for each optimized variable.

    Raises:
        TensorForceError: If no batched tensor argument can be found to determine the
            batch size.
    """
    # Get some (batched) argument to determine batch size. Walk the (possibly nested)
    # arguments structure until a tensor of rank >= 1 is found.
    arguments_iter = iter(arguments.values())
    some_argument = next(arguments_iter)

    try:
        while not isinstance(some_argument, tf.Tensor) or util.rank(some_argument) == 0:
            if isinstance(some_argument, dict):
                if some_argument:
                    # Descend into the non-empty dict's values.
                    arguments_iter = iter(some_argument.values())
                # BUGFIX: always advance — previously an empty dict left both
                # `some_argument` and the iterator unchanged, looping forever.
                some_argument = next(arguments_iter)
            elif isinstance(some_argument, list):
                if some_argument:
                    # Descend into the non-empty list.
                    arguments_iter = iter(some_argument)
                # BUGFIX: always advance — same infinite-loop hazard for empty lists.
                some_argument = next(arguments_iter)
            elif some_argument is None or util.rank(some_argument) == 0:
                # Non-batched argument
                some_argument = next(arguments_iter)
            else:
                raise TensorForceError("Invalid argument type.")
    except StopIteration:
        # Exhausted all arguments without finding a batched tensor.
        raise TensorForceError("Invalid argument type.")

    batch_size = tf.shape(input=some_argument)[0]
    # Number of subsampled instances: fraction of the batch, but at least one.
    num_samples = tf.cast(
        x=(self.fraction * tf.cast(x=batch_size, dtype=util.tf_dtype('float'))),
        dtype=util.tf_dtype('int'))
    num_samples = tf.maximum(x=num_samples, y=1)
    # Sample indices with replacement from the batch dimension.
    indices = tf.random_uniform(shape=(num_samples,), maxval=batch_size, dtype=tf.int32)

    # Gather the subsample for every batched argument; rank-0 (non-batched) arguments
    # are passed through unchanged.
    subsampled_arguments = util.map_tensors(fn=(
        lambda arg: arg if util.rank(arg) == 0 else tf.gather(params=arg, indices=indices)
    ), tensors=arguments)

    return self.optimizer.step(time=time, variables=variables, arguments=subsampled_arguments, **kwargs)
def true_fn():
    """
    True-branch of the conditional update: retrieves a batch from memory according to
    the enclosing scope's `unit` / `batch_size` / `sequence_length`, runs the
    optimization on it, and returns a (constant-true) boolean gated on the
    optimization having executed.
    """
    if unit == 'timesteps':
        # Timestep-based batch
        batch = self.memory.retrieve_timesteps(n=batch_size)
    elif unit == 'episodes':
        # Episode-based batch
        batch = self.memory.retrieve_episodes(n=batch_size)
    elif unit == 'sequences':
        # Timestep-sequence-based batch
        batch = self.memory.retrieve_sequences(n=batch_size, sequence_length=sequence_length)
    else:
        # BUGFIX: fail explicitly on an invalid unit instead of hitting a
        # confusing NameError on `batch` below.
        raise TensorForceError("Invalid update unit: {}.".format(unit))

    # Do not calculate gradients for memory-internal operations.
    batch = util.map_tensors(
        fn=(lambda tensor: tf.stop_gradient(input=tensor)),
        tensors=batch
    )

    optimize = self.fn_optimization(**batch)
    # Force the optimization to run before this branch's return value is produced.
    with tf.control_dependencies(control_inputs=(optimize,)):
        return tf.logical_and(x=True, y=True)