def test_enqueue_dequeue(self):
    """
    Simply tests the insert op without checking internal logic.
    """
    fifo_queue = FIFOQueue(capacity=self.capacity, record_space=self.record_space)
    test = ComponentTest(component=fifo_queue, input_spaces=self.input_spaces)

    first_record = self.record_space.sample(size=1)
    test.test(("insert_records", first_record), expected_outputs=None)
    test.test("get_size", expected_outputs=1)

    further_records = self.record_space.sample(size=5)
    test.test(("insert_records", further_records), expected_outputs=None)
    test.test("get_size", expected_outputs=6)

    expected = dict()
    for (k1, v1), (k2, v2) in zip(flatten_op(first_record).items(), flatten_op(further_records).items()):
        expected[k1] = np.concatenate((v1, v2[:4]))
    expected = unflatten_op(expected)

    test.test(("get_records", 5), expected_outputs=expected)
    test.test("get_size", expected_outputs=1)
def _graph_fn_split_batch(self, *inputs):
    """
    Splits all DataOps in `*inputs` along their batch dimension into n equally sized shards. The number of
    shards is determined by `self.num_shards` (int) and the size of each shard depends on the incoming batch
    size, with possibly a few superfluous items in the batch being discarded
    (effective batch size = num_shards x shard_size).

    Args:
        *inputs (FlattenedDataOp): Input tensors which must all have the same batch dimension.

    Returns:
        tuple: One item per shard. Each item is a DataOpTuple with len = number of input args, where each
            element holds the (now sharded, re-nested) batch data for one input piece (e.g. "/states1").
            E.g. for 2 shards:
            (DataOpTuple(input1_data, input2_data, ...), DataOpTuple(input1_data, input2_data, ...))
    """
    if get_backend() == "tf":
        # Must be evenly divisible so we slice out an evenly divisible tensor.
        # E.g. 203 items in batch with 4 shards -> Only 4 x 50 = 200 are usable.
        usable_size = self.shard_size * self.num_shards

        # List (one item for each input arg). Each item in the list looks like:
        # A FlattenedDataOp with (flat) keys (describing the input-piece (e.g. "/states1")) and values being
        # lists of len n for the n shards' data.
        inputs_flattened_and_split = list()
        for input_arg_data in inputs:
            shard_dict = FlattenedDataOp()
            for flat_key, data in input_arg_data.items():
                usable_input_tensor = data[:usable_size]
                shard_dict[flat_key] = tf.split(value=usable_input_tensor, num_or_size_splits=self.num_shards)
            inputs_flattened_and_split.append(shard_dict)

        # Flip the list to generate a new list where each item represents one shard.
        shard_list = list()
        for shard_idx in range(self.num_shards):
            # To be converted into FlattenedDataOps over the input-arg-pieces once complete.
            input_arg_list = list()
            for input_elem in range(len(inputs)):
                sharded_data_dict = FlattenedDataOp()
                for flat_key, shards in inputs_flattened_and_split[input_elem].items():
                    sharded_data_dict[flat_key] = shards[shard_idx]
                input_arg_list.append(unflatten_op(sharded_data_dict))
            # Must store everything as a DataOpTuple, otherwise the re-nesting will not work.
            shard_list.append(DataOpTuple(input_arg_list))

        # Return n values (n = number of batch shards).
        return tuple(shard_list)
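# For intuition only (not part of the library): a minimal NumPy sketch of the same
# split-and-flip bookkeeping over plain dicts. The `split_batch` helper below is
# hypothetical and stands in for the FlattenedDataOp/DataOpTuple machinery above.
import numpy as np

def split_batch(num_shards, *inputs):
    """Shard each flat dict along the batch axis, then regroup per shard."""
    batch_size = next(iter(inputs[0].values())).shape[0]
    shard_size = batch_size // num_shards
    usable = shard_size * num_shards  # e.g. 203 items, 4 shards -> only 200 usable
    # Per input arg: flat-key -> list of num_shards arrays.
    per_input = [{k: np.split(v[:usable], num_shards) for k, v in d.items()} for d in inputs]
    # Flip: one tuple per shard, each holding one dict per input arg.
    return tuple(
        tuple({k: shards[i] for k, shards in d.items()} for d in per_input)
        for i in range(num_shards)
    )

shards = split_batch(2, {"/states1": np.arange(10).reshape(10, 1)})
assert shards[0][0]["/states1"].shape == (5, 1)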
def unflatten_output_ops(*ops):
    """
    Re-creates the originally nested input structure (as DataOpDict/DataOpTuple) of the given op-record
    column. Processes all FlattenedDataOps with auto-generated keys and leaves the others untouched.

    Args:
        ops (DataOp): The ops that need to be unflattened (only the FlattenedDataOps amongst these are
            processed; all others are ignored).

    Returns:
        Tuple[DataOp]: A tuple containing the ops as they came in, except that all FlattenedDataOps have
            been un-flattened (re-nested) into their original structures.
    """
    # The returned sequence of output ops.
    ret = []
    for op in ops:
        # A FlattenedDataOp: Try to re-nest it.
        if isinstance(op, FlattenedDataOp):
            ret.append(unflatten_op(op))
        # All others are left as-is.
        else:
            ret.append(op)

    # Always return a tuple for indexing into the return values.
    return tuple(ret)
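# For reference (illustration only): simplified stand-ins for flatten_op/unflatten_op,
# assuming '/'-separated auto-generated keys. The real ops also handle tuples and DataOp
# wrappers; this only demonstrates the dict round-trip that the re-nesting relies on.
def _flatten(nested, prefix=""):
    flat = {}
    for k, v in nested.items():
        key = prefix + "/" + k
        if isinstance(v, dict):
            flat.update(_flatten(v, key))
        else:
            flat[key] = v
    return flat

def _unflatten(flat):
    nested = {}
    for key, v in flat.items():
        parts = key.strip("/").split("/")
        node = nested
        for p in parts[:-1]:
            node = node.setdefault(p, {})
        node[parts[-1]] = v
    return nested

assert _unflatten(_flatten({"a": {"b": 1}, "c": 2})) == {"a": {"b": 1}, "c": 2}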
def test_capacity(self):
    """
    Tests whether insert correctly blocks when capacity is reached.
    """
    fifo_queue = FIFOQueue(capacity=self.capacity, record_space=self.record_space)
    test = ComponentTest(component=fifo_queue, input_spaces=self.input_spaces)

    def run(expected_):
        # Wait n seconds.
        time.sleep(2)
        # Pull something out of the queue again to continue.
        test.test(("get_records", 2), expected_outputs=expected_)

    # Insert one more element than capacity.
    records = self.record_space.sample(size=self.capacity + 1)

    expected = dict()
    for key, value in flatten_op(records).items():
        expected[key] = value[:2]
    expected = unflatten_op(expected)

    # Start a thread to save this one from getting stuck due to capacity overflow.
    thread = threading.Thread(target=run, args=(expected,))
    thread.start()

    print("Going over capacity: blocking ...")
    test.test(("insert_records", records), expected_outputs=None)
    print("Dequeued some items in another thread. Unblocked.")

    thread.join()
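# The blocking behavior exercised above follows the classic bounded-queue pattern:
# a full queue blocks the producer until a consumer frees a slot. For reference, a
# self-contained sketch with Python's standard `queue.Queue` (illustration only,
# not part of the test suite):
import queue
import threading
import time

q = queue.Queue(maxsize=2)
q.put("a")
q.put("b")  # queue is now at capacity

def consumer():
    time.sleep(2)
    q.get()  # frees one slot, unblocking the producer below

t = threading.Thread(target=consumer)
t.start()
q.put("c")  # blocks here until the consumer dequeues
t.join()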
def get_preprocessed_space(self, space):
    # Translate to corresponding FloatBoxes.
    ret = dict()
    for key, value in space.flatten().items():
        ret[key] = FloatBox(shape=value.shape, add_batch_rank=value.has_batch_rank,
                            add_time_rank=value.has_time_rank)
    return unflatten_op(ret)
def get_preprocessed_space(self, space):
    ret = {}
    for key, value in space.flatten().items():
        shape = list(value.shape)
        if self.keep_rank is True:
            shape[-1] = 1
        else:
            shape.pop(-1)
        ret[key] = value.__class__(shape=tuple(shape), add_batch_rank=value.has_batch_rank)
    return unflatten_op(ret)
def _graph_fn_get_action_components(self, logits, parameters, deterministic):
    ret = {}

    # TODO: Clean up the checks in here wrt define-by-run processing.
    for flat_key, action_space_component in self.action_space.flatten().items():
        # Skip our distribution, iff discrete action-space and deterministic acting (greedy).
        # In that case, one does not need to create a distribution in the graph each act (only to get the
        # argmax over the logits, which is the same as the argmax over the probabilities (or
        # log-probabilities)).
        if isinstance(action_space_component, IntBox) and \
                (deterministic is True or (isinstance(deterministic, np.ndarray) and deterministic)):
            if flat_key == "":
                return self._graph_fn_get_deterministic_action_wo_distribution(logits)
            else:
                ret[flat_key] = self._graph_fn_get_deterministic_action_wo_distribution(
                    logits.flat_key_lookup(flat_key)
                )
        elif isinstance(action_space_component, BoolBox) and \
                (deterministic is True or (isinstance(deterministic, np.ndarray) and deterministic)):
            if get_backend() == "tf":
                if flat_key == "":
                    return tf.greater(logits, 0.5)
                else:
                    ret[flat_key] = tf.greater(logits.flat_key_lookup(flat_key), 0.5)
            elif get_backend() == "pytorch":
                if flat_key == "":
                    return torch.gt(logits, 0.5)
                else:
                    ret[flat_key] = torch.gt(logits.flat_key_lookup(flat_key), 0.5)
        else:
            if flat_key == "":
                # Still wrapped as FlattenedDataOp.
                if isinstance(parameters, FlattenedDataOp):
                    return self.distributions[flat_key].draw(parameters[flat_key], deterministic)
                else:
                    return self.distributions[flat_key].draw(parameters, deterministic)

            if isinstance(parameters, ContainerDataOp) and not \
                    (isinstance(parameters, DataOpDict) and flat_key in parameters):
                ret[flat_key] = self.distributions[flat_key].draw(
                    parameters.flat_key_lookup(flat_key), deterministic
                )
            else:
                ret[flat_key] = self.distributions[flat_key].draw(parameters[flat_key], deterministic)

    if get_backend() == "tf":
        return unflatten_op(ret)
    elif get_backend() == "pytorch":
        return define_by_run_unflatten(ret)
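# The greedy shortcut above relies on softmax being strictly monotone, so the argmax
# over the logits equals the argmax over the probabilities (and log-probabilities).
# A quick NumPy check (illustration only):
import numpy as np

logits = np.array([1.2, -0.3, 2.5])
probs = np.exp(logits) / np.exp(logits).sum()  # softmax
assert np.argmax(logits) == np.argmax(probs)
assert np.argmax(logits) == np.argmax(np.log(probs))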
def get_preprocessed_space(self, space):
    ret = dict()
    for key, single_space in space.flatten().items():
        class_ = type(single_space)
        # We flip batch and time ranks.
        time_major = not single_space.time_major
        ret[key] = class_(shape=single_space.shape, add_batch_rank=single_space.has_batch_rank,
                          add_time_rank=single_space.has_time_rank, time_major=time_major)
        self.output_time_majors[key] = time_major
    ret = unflatten_op(ret)
    return ret
def get_preprocessed_space(self, space):
    ret = {}
    for key, single_space in space.flatten().items():
        class_ = type(single_space)

        # Determine the actual shape (not batch/time ranks).
        if self.flatten is True:
            if type(single_space) == IntBox and self.flatten_categories is not False:
                assert self.flatten_categories is not None, \
                    "ERROR: `flatten_categories` must not be None if `flatten` is True and input is IntBox!"
                new_shape = (self.get_num_categories(key, single_space),)
                class_ = FloatBox
            else:
                new_shape = (single_space.flat_dim,)
        else:
            new_shape = self.new_shape[key] if isinstance(self.new_shape, dict) else self.new_shape

        # Check the batch/time rank options.
        if self.fold_time_rank is True:
            sanity_check_space(single_space, must_have_batch_rank=True, must_have_time_rank=True)
            ret[key] = class_(
                shape=single_space.shape if new_shape is None else new_shape,
                add_batch_rank=True, add_time_rank=False
            )
        # Time rank should be unfolded from batch rank with the given dimension.
        elif self.unfold_time_rank is True:
            sanity_check_space(single_space, must_have_batch_rank=True, must_have_time_rank=False)
            ret[key] = class_(
                shape=single_space.shape if new_shape is None else new_shape,
                add_batch_rank=True, add_time_rank=True,
                time_major=self.time_major if self.time_major is not None else False
            )
        # Only change the actual shape (leave batch/time ranks as is).
        else:
            time_major = single_space.time_major
            ret[key] = class_(
                shape=single_space.shape if new_shape is None else new_shape,
                add_batch_rank=single_space.has_batch_rank,
                add_time_rank=single_space.has_time_rank, time_major=time_major
            )
    ret = unflatten_op(ret)
    return ret
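# For intuition (illustration only): folding a time rank into the batch rank and
# unfolding it again are plain reshapes between (batch, time, ...) and (batch*time, ...):
import numpy as np

x = np.zeros((4, 5, 3))                  # (batch=4, time=5, feature=3)
folded = x.reshape((-1, 3))              # fold time into batch: (20, 3)
unfolded = folded.reshape((4, 5, 3))     # unfold again, given the known time dimension 5
assert unfolded.shape == x.shape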
def get_preprocessed_space(self, space):
    ret = dict()
    for key, value in space.flatten().items():
        # Do some sanity checking.
        rank = value.rank
        assert rank == 2 or rank == 3, \
            "ERROR: Given image's rank (which is {}{}, not counting batch rank) must be either 2 or 3!".format(
                rank, "" if key == "" else " for key '{}'".format(key)
            )
        # Determine the output shape.
        shape = list(value.shape)
        shape[0] = self.width
        shape[1] = self.height
        ret[key] = value.__class__(shape=tuple(shape), add_batch_rank=value.has_batch_rank)
    return unflatten_op(ret)
def get_preprocessed_space(self, space):
    ret = dict()
    for key, single_space in space.flatten().items():
        class_ = type(single_space)

        # Determine the actual shape (not batch/time ranks).
        if self.flatten is True:
            if self.flatten_categories is not False and type(single_space) == IntBox:
                if self.flatten_categories is True:
                    num_categories = single_space.flat_dim_with_categories
                else:
                    num_categories = self.flatten_categories
                new_shape = (num_categories,)
                # One-hot flattened IntBoxes always become FloatBoxes.
                class_ = FloatBox
            else:
                new_shape = (single_space.flat_dim,)
        else:
            new_shape = self.new_shape[key] if isinstance(self.new_shape, dict) else self.new_shape

        # Check the batch/time rank options.
        if self.fold_time_rank is True:
            sanity_check_space(single_space, must_have_batch_rank=True, must_have_time_rank=True)
            ret[key] = class_(
                shape=single_space.shape if new_shape is None else new_shape,
                add_batch_rank=True, add_time_rank=False
            )
        # Time rank should be unfolded from batch rank with the given dimension.
        elif self.unfold_time_rank is True:
            sanity_check_space(single_space, must_have_batch_rank=True, must_have_time_rank=False)
            ret[key] = class_(
                shape=single_space.shape if new_shape is None else new_shape,
                add_batch_rank=True, add_time_rank=True,
                time_major=self.time_major if self.time_major is not None else False
            )
        # Only change the actual shape (leave batch/time ranks as is).
        else:
            # Do we flip batch and time ranks?
            time_major = single_space.time_major if self.flip_batch_and_time_rank is False else \
                not single_space.time_major
            ret[key] = class_(
                shape=single_space.shape if new_shape is None else new_shape,
                add_batch_rank=single_space.has_batch_rank,
                add_time_rank=single_space.has_time_rank, time_major=time_major
            )
    ret = unflatten_op(ret)
    return ret
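# Illustration of the category-flattening above (not part of the component): an int
# vector of shape (2,) over 3 categories one-hot-encodes into 2 * 3 = 6 floats, which
# is why the resulting space becomes a FloatBox with `flat_dim_with_categories` dims.
import numpy as np

ints = np.array([0, 2])
one_hot = np.eye(3)[ints].reshape(-1)  # -> shape (6,)
assert one_hot.shape == (6,)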
def get_preprocessed_space(self, space):
    ret = dict()
    for key, value in space.flatten().items():
        shape = list(value.shape)
        if self.add_rank:
            shape.append(self.sequence_length)
        else:
            shape[-1] *= self.sequence_length

        # TODO: Move to transpose component.
        # Transpose from "channels_last" to "channels_first".
        if self.in_data_format == "channels_last" and self.out_data_format == "channels_first":
            shape.reverse()

        ret[key] = value.__class__(shape=tuple(shape), add_batch_rank=value.has_batch_rank)
    return unflatten_op(ret)
def get_preprocessed_space(self, space):
    """
    Returns the Space obtained after pushing the input through all layers of this Stack.

    Args:
        space (Dict): The incoming Space object.

    Returns:
        Space: The Space after preprocessing.
    """
    assert isinstance(space, ContainerSpace)

    dict_spec = dict()
    for flat_key, sub_space in space.flatten().items():
        if flat_key in self.flattened_preprocessors:
            dict_spec[flat_key] = self.flattened_preprocessors[flat_key].get_preprocessed_space(sub_space)
        else:
            dict_spec[flat_key] = sub_space
    dict_spec = unflatten_op(dict_spec)
    return Dict(dict_spec)
def _graph_fn_unstage(self):
    """
    Unstages (and unflattens) all staged data.

    Returns:
        Tuple[DataOp]: All previously staged ops.
    """
    unstaged_data = self.area.get()
    unflattened_data = list()
    idx = 0
    # Unflatten all data and return.
    for flat_key_list in self.flat_keys:
        flat_dict = FlattenedDataOp({
            flat_key: item
            for flat_key, item in zip(flat_key_list, unstaged_data[idx:idx + len(flat_key_list)])
        })
        unflattened_data.append(unflatten_op(flat_dict))
        idx += len(flat_key_list)

    return tuple(unflattened_data)
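# For intuition (illustration only): each input arg's flat keys consume one contiguous
# slice of the unstaged flat list. The `regroup` helper below is hypothetical:
def regroup(flat_keys_per_arg, flat_values):
    """Consume contiguous slices of `flat_values`, one slice per input arg."""
    out, idx = [], 0
    for key_list in flat_keys_per_arg:
        out.append(dict(zip(key_list, flat_values[idx:idx + len(key_list)])))
        idx += len(key_list)
    return tuple(out)

assert regroup([["/a", "/b"], ["/x"]], [1, 2, 3]) == ({"/a": 1, "/b": 2}, {"/x": 3})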
def get_preprocessed_space(self, space):
    ret = dict()
    for key, value in space.flatten().items():
        # Do some sanity checking.
        rank = value.rank
        if get_backend() == "tf":
            assert rank == 2 or rank == 3, \
                "ERROR: Given image's rank (which is {}{}, not counting batch rank) must be either 2 or 3!".format(
                    rank, "" if key == "" else " for key '{}'".format(key)
                )
            # Determine the output shape.
            shape = list(value.shape)
            shape[0] = self.width
            shape[1] = self.height
        elif get_backend() == "pytorch":
            shape = list(value.shape)
            # Determine the output shape.
            if rank == 3:
                shape[0] = self.width
                shape[1] = self.height
            elif rank == 4:
                # TODO: PyTorch shape inference issue.
                shape[1] = self.width
                shape[2] = self.height
        ret[key] = value.__class__(shape=tuple(shape), add_batch_rank=value.has_batch_rank)
    return unflatten_op(ret)
def _graph_fn_step(self):
    if get_backend() == "tf":

        def scan_func(accum, time_delta):
            # Not needed: preprocessed-previous-states (tuple!)
            # `state` is a tuple as well. See comment in ctor for why tf cannot use ContainerSpaces here.
            internal_states = None
            state = accum[1]
            if self.has_rnn:
                internal_states = accum[-1]

            state = tuple(tf.convert_to_tensor(value=s) for s in state)

            flat_state = OrderedDict()
            for i, flat_key in enumerate(self.state_space_actor_flattened.keys()):
                # Add a simple (size 1) batch rank to the state so it'll pass through the NN.
                # - Also have to add a time-rank for RNN processing.
                expanded = state[i]
                for _ in range(1 if self.has_rnn is False else 2):
                    expanded = tf.expand_dims(input=expanded, axis=0)
                # Make None so it'll be recognized as batch-rank by the auto-Space detector.
                flat_state[flat_key] = tf.placeholder_with_default(
                    input=expanded,
                    shape=(None,) + ((None,) if self.has_rnn is True else ()) +
                          self.state_space_actor_list[i].shape
                )
            # Recreate state as the original Space to pass it into the actor-component.
            state = unflatten_op(flat_state)

            # Get action and preprocessed state (as batch-size 1).
            out = (self.actor_component.get_preprocessed_state_and_action if self.add_action_probs is False
                   else self.actor_component.get_preprocessed_state_action_and_action_probs)(
                state,
                # Add simple batch rank to internal_states (None for non-RNN systems).
                None if internal_states is None else DataOpTuple(internal_states),
                time_step=self.time_step + time_delta,
                return_ops=True
            )

            # Get output depending on whether it contains internal_states or not.
            a = out["action"]
            action_probs = out.get("action_probs")
            current_internal_states = out.get("last_internal_states")

            # Strip the batch (and maybe time) ranks again from the action in case the Env doesn't like it.
            a_no_extra_ranks = a[0, 0] if self.has_rnn is True else a[0]

            # Step through the Env and collect next state (tuple!), reward and terminal as single values
            # (not batched).
            out = self.environment_server.step_for_env_stepper(a_no_extra_ranks)
            s_, r, t_ = out[:-2], out[-2], out[-1]
            r = tf.cast(r, dtype="float32")

            # Add a and/or r to next_state?
            if self.add_previous_action_to_state is True:
                assert isinstance(s_, tuple), "ERROR: Cannot add previous action to non tuple!"
                s_ = s_ + (a_no_extra_ranks,)
            if self.add_previous_reward_to_state is True:
                assert isinstance(s_, tuple), "ERROR: Cannot add previous reward to non tuple!"
                s_ = s_ + (r,)

            # Note: s_ is packed as tuple.
            ret = [t_, s_] + \
                ([a_no_extra_ranks] if self.add_action else []) + \
                ([r] if self.add_reward else []) + \
                ([action_probs[0][0] if self.has_rnn is True else action_probs[0]]
                 if self.add_action_probs is True else []) + \
                ([tuple(current_internal_states)] if self.has_rnn is True else [])

            return tuple(ret)

        # Initialize the tf.scan run.
        initializer = [
            self.current_terminal.read_value(),  # whether the current state is terminal
            # current (raw) state (flattened components if ContainerSpace).
            tuple(map(lambda x: x.read_value(), self.current_state.values()))
        ]
        # Append actions and rewards if needed.
        if self.add_action:
            initializer.append(self.current_action.read_value())
        if self.add_reward:
            initializer.append(self.current_reward.read_value())
        # Append action probs if needed.
        if self.add_action_probs is True:
            initializer.append(self.current_action_probs.read_value())
        # Append internal states if needed.
        if self.current_internal_states is not None:
            initializer.append(tuple(
                tf.placeholder_with_default(
                    internal_s.read_value(),
                    shape=(None,) + tuple(internal_s.shape.as_list()[1:])
                ) for internal_s in self.current_internal_states.values()
            ))

        # Scan over n time-steps (tf.range produces the time_delta with respect to the current time_step).
        # NOTE: Changed parallel to 1, to resolve parallel issues.
        step_results = list(tf.scan(
            fn=scan_func, elems=tf.range(self.num_steps, dtype="int32"),
            initializer=tuple(initializer), back_prop=False
        ))

        # Store the time-step increment, return so far, current terminal and current state.
        assigns = [
            tf.assign_add(self.time_step, self.num_steps),
            self.assign_variable(self.current_terminal, step_results[0][-1])
        ]

        # Concatenate first and rest.
        full_results = []
        for first_values, rest_values in zip(initializer, step_results):
            full_results.append(nest.map_structure(
                lambda first, rest: tf.concat([[first], rest], axis=0), first_values, rest_values
            ))

        # Re-build DataOpDicts from preprocessed-states and states (from tuple right now).
        rebuild_s = DataOpDict()
        for flat_key, var_ref, s_comp in zip(
                self.state_space_actor_flattened.keys(), self.current_state.values(), full_results[1]):
            assigns.append(self.assign_variable(var_ref, s_comp[-1]))  # -1: current state (last observed)
            rebuild_s[flat_key] = s_comp
        rebuild_s = unflatten_op(rebuild_s)
        full_results[1] = rebuild_s

        # Remove batch rank from internal states again.
        if self.current_internal_states is not None:
            # TODO: What if internal states is not the last item in the list anymore due to some change?
            slot = -1  # if self.add_action_probs is True else 2
            # TODO: What if internal states is a dict? Right now assume some tuple.
            internal_states_wo_batch = list()
            for i in range(len(full_results[slot])):
                # 1=batch axis (which is 1); 0=time axis.
                internal_states_wo_batch.append(tf.squeeze(full_results[-1][i], axis=1))
            full_results[slot] = DataOpTuple(internal_states_wo_batch)

        with tf.control_dependencies(control_inputs=assigns):
            # Let the auto-infer system know what time rank we have.
            full_results = DataOpTuple(full_results)
            for o in flatten_op(full_results).values():
                o._time_rank = 0  # which position in the shape is the time-rank?
            step_op = tf.no_op()

        return step_op, full_results
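# The "concatenate first and rest" step above prepends the initializer (the values
# before the scan) to the per-step scan outputs, yielding num_steps + 1 entries along
# the time axis. In NumPy terms (illustration only):
import numpy as np

first = np.array([1.0, 2.0])                       # initializer (single time step)
rest = np.stack([first + i for i in (1, 2, 3)])    # scan output: (num_steps, ...)
full = np.concatenate([first[None], rest], axis=0)
assert full.shape == (4, 2)                        # num_steps + 1 entries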
def map(self, mapping):
    flattened_self = self.flatten(mapping=mapping)
    return Tuple(
        unflatten_op(flattened_self),
        add_batch_rank=self.has_batch_rank,
        add_time_rank=self.has_time_rank,
        time_major=self.time_major
    )
def opt_body(index_, loss_, loss_per_item_, vf_loss_, vf_loss_per_item_):
    start = tf.random_uniform(shape=(), minval=0, maxval=batch_size - 1, dtype=tf.int32)
    indices = tf.range(start=start, limit=start + agent.sample_size) % batch_size
    sample_states = tf.gather(params=preprocessed_states, indices=indices)

    if isinstance(actions, ContainerDataOp):
        sample_actions = FlattenedDataOp()
        for name, action in flatten_op(actions).items():
            sample_actions[name] = tf.gather(params=action, indices=indices)
        sample_actions = unflatten_op(sample_actions)
    else:
        sample_actions = tf.gather(params=actions, indices=indices)

    sample_prior_log_probs = tf.gather(params=prev_log_probs, indices=indices)
    sample_rewards = tf.gather(params=rewards, indices=indices)
    sample_terminals = tf.gather(params=terminals, indices=indices)
    sample_sequence_indices = tf.gather(params=sequence_indices, indices=indices)
    sample_advantages = tf.gather(params=advantages, indices=indices)
    sample_advantages.set_shape((self.sample_size,))

    sample_baseline_values = value_function.value_output(sample_states)
    sample_prior_baseline_values = tf.gather(params=prior_baseline_values, indices=indices)

    # If we are a multi-GPU root:
    # Simply feed everything into the multi-GPU sync optimizer's method and return.
    if multi_gpu_sync_optimizer is not None:
        main_policy_vars = agent.policy.variables()
        main_vf_vars = agent.value_function.variables()
        all_vars = agent.vars_merger.merge(main_policy_vars, main_vf_vars)
        out = multi_gpu_sync_optimizer.calculate_update_from_external_batch(
            all_vars, sample_states, sample_actions, sample_rewards,
            sample_terminals, sample_sequence_indices,
            apply_postprocessing=apply_postprocessing
        )
        avg_grads_and_vars_policy, avg_grads_and_vars_vf = agent.vars_splitter.call(
            out["avg_grads_and_vars_by_component"]
        )
        policy_step_op = agent.optimizer.apply_gradients(avg_grads_and_vars_policy)
        vf_step_op = agent.value_function_optimizer.apply_gradients(avg_grads_and_vars_vf)
        step_op = root._graph_fn_group(policy_step_op, vf_step_op)
        step_and_sync_op = multi_gpu_sync_optimizer.sync_variables_to_towers(step_op, all_vars)
        loss_vf, loss_per_item_vf = out["additional_return_0"], out["additional_return_1"]

        # Have to set all shapes here due to strict loop-var shape requirements.
        out["loss"].set_shape(())
        loss_vf.set_shape(())
        loss_per_item_vf.set_shape((agent.sample_size,))
        out["loss_per_item"].set_shape((agent.sample_size,))

        with tf.control_dependencies([step_and_sync_op]):
            if index_ == 0:
                # Increase the global training step counter.
                out["loss"] = root._graph_fn_training_step(out["loss"])
            return index_ + 1, out["loss"], out["loss_per_item"], loss_vf, loss_per_item_vf

    policy_probs = policy.get_log_likelihood(sample_states, sample_actions)["log_likelihood"]
    baseline_values = value_function.value_output(tf.stop_gradient(sample_states))
    sample_rewards = tf.cond(
        pred=apply_postprocessing,
        true_fn=lambda: gae_function.calc_gae_values(
            baseline_values, sample_rewards, sample_terminals, sample_sequence_indices
        ),
        false_fn=lambda: sample_rewards
    )
    sample_rewards.set_shape((agent.sample_size,))
    entropy = policy.get_entropy(sample_states)["entropy"]

    loss, loss_per_item, vf_loss, vf_loss_per_item = loss_function.loss(
        policy_probs, sample_prior_log_probs,
        sample_baseline_values, sample_prior_baseline_values, sample_advantages, entropy
    )

    if hasattr(root, "is_multi_gpu_tower") and root.is_multi_gpu_tower is True:
        policy_grads_and_vars = optimizer.calculate_gradients(policy.variables(), loss)
        vf_grads_and_vars = value_function_optimizer.calculate_gradients(
            value_function.variables(), vf_loss
        )
        grads_and_vars_by_component = vars_merger.merge(policy_grads_and_vars, vf_grads_and_vars)
        return grads_and_vars_by_component, loss, loss_per_item, vf_loss, vf_loss_per_item
    else:
        step_op, loss, loss_per_item = optimizer.step(policy.variables(), loss, loss_per_item)
        loss.set_shape(())
        loss_per_item.set_shape((agent.sample_size,))

        vf_step_op, vf_loss, vf_loss_per_item = value_function_optimizer.step(
            value_function.variables(), vf_loss, vf_loss_per_item
        )
        vf_loss.set_shape(())
        vf_loss_per_item.set_shape((agent.sample_size,))

        with tf.control_dependencies([step_op, vf_step_op]):
            return index_ + 1, loss, loss_per_item, vf_loss, vf_loss_per_item