def define_by_run_flatten(container, key_scope="", tensor_tuple_list=None, scope_separator_at_start=True):
    """
    Flattens a native python dict/tuple into a flat dict with auto-key generation. Run-time equivalent
    to the build-time flatten operation.

    Args:
        container (Union[dict,tuple]): Container to flatten.
        key_scope (str): The recursive scope for auto-key generation.
        tensor_tuple_list (list): The list of tuples (key, value) to be converted into the final result.
        scope_separator_at_start (bool): Whether to prepend a scope separator before the first key in a
            recursive structure. Default: True.

    Returns:
        Dict: Flattened container.
    """
    ret = False

    # Are we in the non-recursive (first) call?
    if tensor_tuple_list is None:
        tensor_tuple_list = []
        if not isinstance(container, (dict, tuple)):
            return DataOpDict([("", container)])
        ret = True

    if isinstance(container, dict):
        if scope_separator_at_start:
            key_scope += FLATTEN_SCOPE_PREFIX
        else:
            key_scope = ""
        for key in sorted(container.keys()):
            # Make sure we have no double slashes from flattening an already FlattenedDataOp.
            scope = (key_scope[:-1] if len(key) == 0 or key[0] == "/" else key_scope) + key
            define_by_run_flatten(container[key], key_scope=scope, tensor_tuple_list=tensor_tuple_list,
                                  scope_separator_at_start=True)
    elif isinstance(container, tuple):
        if scope_separator_at_start:
            key_scope += FLATTEN_SCOPE_PREFIX + FLAT_TUPLE_OPEN
        else:
            key_scope += "" + FLAT_TUPLE_OPEN
        for i, c in enumerate(container):
            define_by_run_flatten(c, key_scope=key_scope + str(i) + FLAT_TUPLE_CLOSE,
                                  tensor_tuple_list=tensor_tuple_list, scope_separator_at_start=True)
    else:
        assert not isinstance(container, (dict, tuple))
        tensor_tuple_list.append((key_scope, container))

    # Non-recursive (first) call -> Return the final dict.
    if ret:
        return DataOpDict(tensor_tuple_list)
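# Usage sketch (not from the original source): flattening a nested dict/tuple with the function above.
# Assumes the module constants FLATTEN_SCOPE_PREFIX == "/", FLAT_TUPLE_OPEN == "[" and
# FLAT_TUPLE_CLOSE == "]"; the `_demo_` helper name is hypothetical.
def _demo_define_by_run_flatten():
    nested = dict(a=dict(b=1, c=(2, 3)))
    flat = define_by_run_flatten(nested)
    # Under the above assumptions, the auto-generated keys would be:
    # {"/a/b": 1, "/a/c/[0]": 2, "/a/c/[1]": 3}
    return flat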
def _graph_fn_merge(self, *inputs):
    """
    Merges the inputs into a single DataOpDict OR DataOpTuple with the flat keys given in `self.dict_keys`.

    Args:
        *inputs (FlattenedDataOp): The input items to be merged into a ContainerDataOp.

    Returns:
        ContainerDataOp: The DataOpDict or DataOpTuple as a merger of all *inputs.
    """
    if self.is_tuple is True:
        ret = []
        for op in inputs:
            # Merge single items inside a DataOpTuple into the resulting tuple.
            if self.merge_tuples_into_one and isinstance(op, DataOpTuple):
                ret.extend(list(op))
            # Strict by-input merging.
            else:
                ret.append(op)
        return DataOpTuple(ret)
    else:
        ret = DataOpDict()
        for i, op in enumerate(inputs):
            if get_backend() == "pytorch" and self.execution_mode == "define_by_run":
                ret[FLATTEN_SCOPE_PREFIX + self.dict_keys[i]] = op
            else:
                ret[self.dict_keys[i]] = op
        return ret
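# Illustrative sketch (not from the original source) of the merge semantics above, using a minimal
# stand-in object instead of the real component. `_FakeMerger` and its attributes are hypothetical
# placeholders for whatever the component's constructor sets up.
class _FakeMerger(object):
    def __init__(self, is_tuple=False, merge_tuples_into_one=False, dict_keys=None):
        self.is_tuple = is_tuple
        self.merge_tuples_into_one = merge_tuples_into_one
        self.dict_keys = dict_keys or []
        self.execution_mode = "static_graph"

# With is_tuple=False and dict_keys=["states", "actions"], positional inputs are re-keyed:
#   _graph_fn_merge(_FakeMerger(dict_keys=["states", "actions"]), s, a)
#   -> DataOpDict({"states": s, "actions": a})
# With is_tuple=True, the inputs are packed into a DataOpTuple; with merge_tuples_into_one, items of
# incoming DataOpTuples are spliced into that single tuple instead of being nested.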
def _graph_fn_merge(self, *inputs):
    """
    Merges the inputs into a single FlattenedDataOp with the flat keys given in `self.input_names`.

    Args:
        *inputs (FlattenedDataOp): The input items to be merged into a FlattenedDataOp.

    Returns:
        FlattenedDataOp: The FlattenedDataOp as a merger of all *inputs.
    """
    ret = DataOpDict()
    for i, op in enumerate(inputs):
        ret[self.input_names[i]] = op
    return ret
def _average_grads_and_vars(self, variables_by_component, grads_and_vars_all_gpus_by_component):
    """
    Utility to average gradients (per variable) across towers.

    Args:
        variables_by_component (DataOpDict[Dict[str,DataOp]]): Dict of dicts of variables.
        grads_and_vars_all_gpus_by_component (DataOpDict[str,list]): Dict of grads_and_vars lists (one per GPU).

    Returns:
        DataOpDict[str,list]: DataOpDict with keys=component keys, values=list of grads_and_vars tuples
            averaged across our GPUs.
    """
    if get_backend() == "tf":
        ret = dict()
        for component_key in variables_by_component.keys():
            gpu_grad_averages = []
            for i, grads_and_vars in enumerate(zip(*grads_and_vars_all_gpus_by_component[component_key])):
                gpu_grads = []
                for grad, var in grads_and_vars:
                    if grad is not None:
                        # Add batch dimension.
                        batch_grad = tf.expand_dims(input=grad, axis=0)
                        # Add along axis for that gpu.
                        gpu_grads.append(batch_grad)
                if not gpu_grads:
                    continue
                aggregate_grads = tf.concat(axis=0, values=gpu_grads)
                mean_grad = tf.reduce_mean(input_tensor=aggregate_grads, axis=0)
                # Need the actual main policy vars, as these are the ones that should be updated.
                # TODO: This is a hack and needs to be changed, but it works for now to look up main policy variables.
                main_variable_key = re.sub(r'{}/tower-0/'.format(self.global_scope), "", grads_and_vars[0][1].op.name)
                main_variable_key = re.sub(r'/', "-", main_variable_key)
                var = variables_by_component[component_key][main_variable_key]
                gpu_grad_averages.append((mean_grad, var))
            ret[component_key] = DataOpTuple(gpu_grad_averages)

        return DataOpDict(ret)
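# Numeric sketch (not from the original source) of the per-variable averaging performed above:
# gradients from all towers are stacked along a new leading axis and reduced by their mean.
# `_average_across_towers` is a hypothetical helper for illustration only.
import numpy as np

def _average_across_towers(per_gpu_grads):
    # per_gpu_grads: list of same-shaped gradient arrays, one per GPU/tower.
    stacked = np.stack(per_gpu_grads, axis=0)  # mirrors the expand_dims + concat above
    return stacked.mean(axis=0)                # mirrors tf.reduce_mean(axis=0)

# Example: two towers -> element-wise mean of their gradients.
# _average_across_towers([np.array([1.0, 2.0]), np.array([3.0, 4.0])]) -> array([2., 3.])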
def _graph_fn_merge(self, *inputs):
    """
    Merges the inputs into a single DataOpDict OR DataOpTuple with the flat keys given in `self.dict_keys`.

    Args:
        *inputs (FlattenedDataOp): The input items to be merged into a ContainerDataOp.

    Returns:
        ContainerDataOp: The DataOpDict or DataOpTuple as a merger of all *inputs.
    """
    if self.num_items is not None:
        ret = list()
        for op in inputs:
            ret.append(op)
        return DataOpTuple(ret)
    else:
        ret = DataOpDict()
        for i, op in enumerate(inputs):
            ret[self.dict_keys[i]] = op
        return ret
def _graph_fn_merge(self, *inputs):
    """
    Merges the inputs into a single DataOpDict OR DataOpTuple with the flat keys given in `self.dict_keys`.

    Args:
        *inputs (FlattenedDataOp): The input items to be merged into a ContainerDataOp.

    Returns:
        ContainerDataOp: The DataOpDict or DataOpTuple as a merger of all *inputs.
    """
    if self.num_items is not None:
        ret = []
        for op in inputs:
            ret.append(op)
        return DataOpTuple(ret)
    else:
        ret = DataOpDict()
        for i, op in enumerate(inputs):
            if get_backend() == "pytorch" and self.execution_mode == "define_by_run":
                ret[FLATTEN_SCOPE_PREFIX + self.dict_keys[i]] = op
            else:
                ret[self.dict_keys[i]] = op
        return ret
def dtype(self):
    return DataOpDict([(key, subspace.dtype) for key, subspace in self.items()])
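# Illustrative note (not from the original source): for a container Space such as
# Dict(a=FloatBox(), b=IntBox(3)), the property above returns a DataOpDict mapping each key to its
# sub-Space's dtype, e.g. {"a": <float32 dtype>, "b": <int32 dtype>}; the exact dtype representation
# depends on the Space implementation.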
def _graph_fn_step(self):
    if get_backend() == "tf":

        def scan_func(accum, time_delta):
            # Not needed: preprocessed-previous-states (tuple!)
            # `state` is a tuple as well. See comment in ctor for why tf cannot use ContainerSpaces here.
            internal_states = None
            state = accum[1]
            if self.has_rnn:
                internal_states = accum[-1]

            state = tuple(tf.convert_to_tensor(value=s) for s in state)

            flat_state = OrderedDict()
            for i, flat_key in enumerate(self.state_space_actor_flattened.keys()):
                # Add a simple (size 1) batch rank to the state so it'll pass through the NN.
                # - Also have to add a time-rank for RNN processing.
                expanded = state[i]
                for _ in range(1 if self.has_rnn is False else 2):
                    expanded = tf.expand_dims(input=expanded, axis=0)
                # Make None so it'll be recognized as batch-rank by the auto-Space detector.
                flat_state[flat_key] = tf.placeholder_with_default(
                    input=expanded,
                    shape=(None,) + ((None,) if self.has_rnn is True else ()) + self.state_space_actor_list[i].shape
                )

            # Recreate state as the original Space to pass it into the actor-component.
            state = unflatten_op(flat_state)

            # Get action and preprocessed state (as batch-size 1).
            out = (self.actor_component.get_preprocessed_state_and_action if self.add_action_probs is False
                   else self.actor_component.get_preprocessed_state_action_and_action_probs)(
                state,
                # Add simple batch rank to internal_states.
                None if internal_states is None else DataOpTuple(internal_states),  # <- None for non-RNN systems
                time_step=self.time_step + time_delta,
                return_ops=True
            )

            # Get output depending on whether it contains internal_states or not.
            a = out["action"]
            action_probs = out.get("action_probs")
            current_internal_states = out.get("last_internal_states")

            # Strip the batch (and maybe time) ranks again from the action in case the Env doesn't like it.
            a_no_extra_ranks = a[0, 0] if self.has_rnn is True else a[0]

            # Step through the Env and collect next state (tuple!), reward and terminal as single values
            # (not batched).
            out = self.environment_server.step_for_env_stepper(a_no_extra_ranks)
            s_, r, t_ = out[:-2], out[-2], out[-1]
            r = tf.cast(r, dtype="float32")

            # Add a and/or r to next_state?
            if self.add_previous_action_to_state is True:
                assert isinstance(s_, tuple), "ERROR: Cannot add previous action to non tuple!"
                s_ = s_ + (a_no_extra_ranks,)
            if self.add_previous_reward_to_state is True:
                assert isinstance(s_, tuple), "ERROR: Cannot add previous reward to non tuple!"
                s_ = s_ + (r,)

            # Note: s_ is packed as tuple.
            ret = [t_, s_] + \
                ([a_no_extra_ranks] if self.add_action else []) + \
                ([r] if self.add_reward else []) + \
                ([(action_probs[0][0] if self.has_rnn is True else action_probs[0])]
                 if self.add_action_probs is True else []) + \
                ([tuple(current_internal_states)] if self.has_rnn is True else [])

            return tuple(ret)

        # Initialize the tf.scan run.
        initializer = [
            self.current_terminal.read_value(),  # whether the current state is terminal
            # current (raw) state (flattened components if ContainerSpace).
            tuple(map(lambda x: x.read_value(), self.current_state.values()))
        ]
        # Append actions and rewards if needed.
        if self.add_action:
            initializer.append(self.current_action.read_value())
        if self.add_reward:
            initializer.append(self.current_reward.read_value())
        # Append action probs if needed.
        if self.add_action_probs is True:
            initializer.append(self.current_action_probs.read_value())
        # Append internal states if needed.
        if self.current_internal_states is not None:
            initializer.append(tuple(
                tf.placeholder_with_default(
                    internal_s.read_value(),
                    shape=(None,) + tuple(internal_s.shape.as_list()[1:])
                ) for internal_s in self.current_internal_states.values()
            ))

        # Scan over n time-steps (tf.range produces the time_delta with respect to the current time_step).
        # NOTE: Changed parallel to 1, to resolve parallel issues.
        step_results = list(tf.scan(
            fn=scan_func,
            elems=tf.range(self.num_steps, dtype="int32"),
            initializer=tuple(initializer),
            back_prop=False
        ))

        # Store the time-step increment, return so far, current terminal and current state.
        assigns = [
            tf.assign_add(self.time_step, self.num_steps),
            self.assign_variable(self.current_terminal, step_results[0][-1])
        ]

        # Concatenate first and rest.
        full_results = []
        for first_values, rest_values in zip(initializer, step_results):
            full_results.append(nest.map_structure(
                lambda first, rest: tf.concat([[first], rest], axis=0), first_values, rest_values
            ))

        # Re-build DataOpDicts from preprocessed-states and states (from tuple right now).
        rebuild_s = DataOpDict()
        for flat_key, var_ref, s_comp in zip(
                self.state_space_actor_flattened.keys(), self.current_state.values(), full_results[1]
        ):
            assigns.append(self.assign_variable(var_ref, s_comp[-1]))  # -1: current state (last observed)
            rebuild_s[flat_key] = s_comp
        rebuild_s = unflatten_op(rebuild_s)
        full_results[1] = rebuild_s

        # Remove batch rank from internal states again.
        if self.current_internal_states is not None:
            # TODO: What if internal states is not the last item in the list anymore due to some change.
            slot = -1  # if self.add_action_probs is True else 2
            # TODO: What if internal states is a dict? Right now assume some tuple.
            internal_states_wo_batch = list()
            for i in range(len(full_results[slot])):
                # 1=batch axis (which is 1); 0=time axis.
                internal_states_wo_batch.append(tf.squeeze(full_results[-1][i], axis=1))
            full_results[slot] = DataOpTuple(internal_states_wo_batch)

        with tf.control_dependencies(control_inputs=assigns):
            # Let the auto-infer system know, what time rank we have.
            full_results = DataOpTuple(full_results)
            for o in flatten_op(full_results).values():
                o._time_rank = 0  # which position in the shape is the time-rank?
            step_op = tf.no_op()

        return step_op, full_results
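# Plain-Python sketch (not from the original source) of the "concatenate first and rest" step above:
# tf.scan returns one stacked result per time step, and the initializer (the values *before* the
# first step) is prepended so each returned sequence has num_steps + 1 entries.
# `_prepend_initializer` is a hypothetical helper for illustration only.
def _prepend_initializer(initializer_values, scanned_values):
    # initializer_values: list of per-slot initial values.
    # scanned_values: list of per-slot lists, each holding num_steps results.
    return [[first] + list(rest) for first, rest in zip(initializer_values, scanned_values)]

# Example: initializer [t0, r0] and 3 scanned steps -> sequences of length 4 per slot.
# _prepend_initializer([False, 0.0], [[False, False, True], [0.1, 0.2, 0.3]])
# -> [[False, False, False, True], [0.0, 0.1, 0.2, 0.3]]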
def _graph_fn_merge_actions(self, a, b, c):
    return DataOpDict(a=a, b=b, c=c)
def define_by_run_unflatten(result_dict):
    """
    Takes a dict with auto-generated keys and returns the corresponding unflattened dict.
    If the only key in the input dict is "", it returns the value under that key.

    Args:
        result_dict (dict): The item to be unflattened (re-nested).

    Returns:
        Dict: The unflattened (re-nested) item.
    """
    # Special case: Dict with only 1 value (key="").
    if len(result_dict) == 1 and "" in result_dict:
        return result_dict[""]

    # Normal case: OrderedDict that came from a ContainerItem.
    base_structure = None

    op_names = sorted(result_dict.keys())
    for op_name in op_names:
        op_val = result_dict[op_name]
        parent_structure = None
        parent_key = None
        current_structure = None
        op_type = None

        if op_name.startswith(FLATTEN_SCOPE_PREFIX):
            op_name = op_name[1:]

        # N.b. removed this because we do not prepend / any more before first key.
        op_key_list = op_name.split(FLATTEN_SCOPE_PREFIX)  # skip 1st char (/)
        for sub_key in op_key_list:
            mo = re.match(r'^{}(\d+){}$'.format(FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE), sub_key)
            if mo:
                op_type = list
                idx = int(mo.group(1))
            else:
                op_type = OrderedDict
                idx = sub_key

            if current_structure is None:
                if base_structure is None:
                    base_structure = [None] if op_type == list else DataOpDict()
                current_structure = base_structure
            elif parent_key is not None:
                if (isinstance(parent_structure, list) and (parent_structure[parent_key] is None)) or \
                        (isinstance(parent_structure, DataOpDict) and parent_key not in parent_structure):
                    current_structure = [None] if op_type == list else DataOpDict()
                    parent_structure[parent_key] = current_structure
                else:
                    current_structure = parent_structure[parent_key]

                if op_type == list and len(current_structure) == idx:
                    current_structure.append(None)

            parent_structure = current_structure
            parent_key = idx
            if isinstance(parent_structure, list) and len(parent_structure) == parent_key:
                parent_structure.append(None)

        if op_type == list and len(current_structure) == parent_key:
            current_structure.append(None)
        current_structure[parent_key] = op_val

    # Deep conversion from list to tuple.
    # TODO: necessary in define-by-run?
    return deep_tuple(base_structure)
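# Round-trip sketch (not from the original source): flatten followed by unflatten recovers the
# original nesting (tuples are rebuilt via deep_tuple, dicts come back as DataOpDicts). Relies on the
# functions defined above and the same FLATTEN_SCOPE_PREFIX / FLAT_TUPLE_* constants; the `_demo_`
# helper name is hypothetical.
def _demo_define_by_run_round_trip():
    nested = dict(a=dict(b=1, c=(2, 3)))
    flat = define_by_run_flatten(nested)       # e.g. {"/a/b": 1, "/a/c/[0]": 2, "/a/c/[1]": 3}
    restored = define_by_run_unflatten(flat)   # DataOpDict nesting equivalent to `nested`
    return restored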
def _graph_fn_update_from_external_batch(
        root, preprocessed_states, actions, rewards, terminals, sequence_indices, apply_postprocessing
):
    """
    Calls iterative optimization by repeatedly sub-sampling.
    """
    multi_gpu_sync_optimizer = root.sub_components.get("multi-gpu-synchronizer")

    # Return values.
    loss, loss_per_item, vf_loss, vf_loss_per_item = None, None, None, None

    policy = root.get_sub_component_by_name(agent.policy.scope)
    value_function = root.get_sub_component_by_name(agent.value_function.scope)
    optimizer = root.get_sub_component_by_name(agent.optimizer.scope)
    loss_function = root.get_sub_component_by_name(agent.loss_function.scope)
    value_function_optimizer = root.get_sub_component_by_name(agent.value_function_optimizer.scope)
    vars_merger = root.get_sub_component_by_name(agent.vars_merger.scope)
    gae_function = root.get_sub_component_by_name(agent.gae_function.scope)
    prev_log_probs = policy.get_log_likelihood(preprocessed_states, actions)["log_likelihood"]

    if get_backend() == "tf":
        # Log probs before update.
        prev_log_probs = tf.stop_gradient(prev_log_probs)
        batch_size = tf.shape(preprocessed_states)[0]
        prior_baseline_values = tf.stop_gradient(value_function.value_output(preprocessed_states))

        # Advantages are based on prior baseline values.
        advantages = tf.cond(
            pred=apply_postprocessing,
            true_fn=lambda: gae_function.calc_gae_values(prior_baseline_values, rewards, terminals, sequence_indices),
            false_fn=lambda: rewards
        )

        if self.standardize_advantages:
            mean, std = tf.nn.moments(x=advantages, axes=[0])
            advantages = (advantages - mean) / std

        def opt_body(index_, loss_, loss_per_item_, vf_loss_, vf_loss_per_item_):
            start = tf.random_uniform(shape=(), minval=0, maxval=batch_size - 1, dtype=tf.int32)
            indices = tf.range(start=start, limit=start + agent.sample_size) % batch_size
            sample_states = tf.gather(params=preprocessed_states, indices=indices)

            if isinstance(actions, ContainerDataOp):
                sample_actions = FlattenedDataOp()
                for name, action in flatten_op(actions).items():
                    sample_actions[name] = tf.gather(params=action, indices=indices)
                sample_actions = unflatten_op(sample_actions)
            else:
                sample_actions = tf.gather(params=actions, indices=indices)

            sample_prior_log_probs = tf.gather(params=prev_log_probs, indices=indices)
            sample_rewards = tf.gather(params=rewards, indices=indices)
            sample_terminals = tf.gather(params=terminals, indices=indices)
            sample_sequence_indices = tf.gather(params=sequence_indices, indices=indices)
            sample_advantages = tf.gather(params=advantages, indices=indices)
            sample_advantages.set_shape((self.sample_size,))

            sample_baseline_values = value_function.value_output(sample_states)
            sample_prior_baseline_values = tf.gather(params=prior_baseline_values, indices=indices)

            # If we are a multi-GPU root:
            # Simply feed everything into the multi-GPU sync optimizer's method and return.
            if multi_gpu_sync_optimizer is not None:
                main_policy_vars = agent.policy.variables()
                main_vf_vars = agent.value_function.variables()
                all_vars = agent.vars_merger.merge(main_policy_vars, main_vf_vars)
                # grads_and_vars, loss, loss_per_item, vf_loss, vf_loss_per_item = \
                out = multi_gpu_sync_optimizer.calculate_update_from_external_batch(
                    all_vars,
                    sample_states, sample_actions, sample_rewards,
                    sample_terminals, sample_sequence_indices,
                    apply_postprocessing=apply_postprocessing
                )
                avg_grads_and_vars_policy, avg_grads_and_vars_vf = agent.vars_splitter.call(
                    out["avg_grads_and_vars_by_component"]
                )
                policy_step_op = agent.optimizer.apply_gradients(avg_grads_and_vars_policy)
                vf_step_op = agent.value_function_optimizer.apply_gradients(avg_grads_and_vars_vf)
                step_op = root._graph_fn_group(policy_step_op, vf_step_op)
                step_and_sync_op = multi_gpu_sync_optimizer.sync_variables_to_towers(step_op, all_vars)
                loss_vf, loss_per_item_vf = out["additional_return_0"], out["additional_return_1"]

                # Have to set all shapes here due to strict loop-var shape requirements.
                out["loss"].set_shape(())
                loss_vf.set_shape(())
                loss_per_item_vf.set_shape((agent.sample_size,))
                out["loss_per_item"].set_shape((agent.sample_size,))

                with tf.control_dependencies([step_and_sync_op]):
                    if index_ == 0:
                        # Increase the global training step counter.
                        out["loss"] = root._graph_fn_training_step(out["loss"])
                    return index_ + 1, out["loss"], out["loss_per_item"], loss_vf, loss_per_item_vf

            policy_probs = policy.get_log_likelihood(sample_states, sample_actions)["log_likelihood"]
            baseline_values = value_function.value_output(tf.stop_gradient(sample_states))
            sample_rewards = tf.cond(
                pred=apply_postprocessing,
                true_fn=lambda: gae_function.calc_gae_values(
                    baseline_values, sample_rewards, sample_terminals, sample_sequence_indices
                ),
                false_fn=lambda: sample_rewards
            )
            sample_rewards.set_shape((agent.sample_size,))
            entropy = policy.get_entropy(sample_states)["entropy"]

            loss, loss_per_item, vf_loss, vf_loss_per_item = \
                loss_function.loss(
                    policy_probs, sample_prior_log_probs,
                    sample_baseline_values, sample_prior_baseline_values, sample_advantages, entropy
                )

            if hasattr(root, "is_multi_gpu_tower") and root.is_multi_gpu_tower is True:
                policy_grads_and_vars = optimizer.calculate_gradients(policy.variables(), loss)
                vf_grads_and_vars = value_function_optimizer.calculate_gradients(value_function.variables(), vf_loss)
                grads_and_vars_by_component = vars_merger.merge(policy_grads_and_vars, vf_grads_and_vars)
                return grads_and_vars_by_component, loss, loss_per_item, vf_loss, vf_loss_per_item
            else:
                step_op, loss, loss_per_item = optimizer.step(policy.variables(), loss, loss_per_item)
                loss.set_shape(())
                loss_per_item.set_shape((agent.sample_size,))

                vf_step_op, vf_loss, vf_loss_per_item = value_function_optimizer.step(
                    value_function.variables(), vf_loss, vf_loss_per_item
                )
                vf_loss.set_shape(())
                vf_loss_per_item.set_shape((agent.sample_size,))

                with tf.control_dependencies([step_op, vf_step_op]):
                    return index_ + 1, loss, loss_per_item, vf_loss, vf_loss_per_item

        def cond(index_, loss_, loss_per_item_, v_loss_, v_loss_per_item_):
            return index_ < agent.iterations

        init_loop_vars = [
            0,
            tf.zeros(shape=(), dtype=tf.float32),
            tf.zeros(shape=(agent.sample_size,)),
            tf.zeros(shape=(), dtype=tf.float32),
            tf.zeros(shape=(agent.sample_size,))
        ]

        if hasattr(root, "is_multi_gpu_tower") and root.is_multi_gpu_tower is True:
            return opt_body(*init_loop_vars)
        else:
            index, loss, loss_per_item, vf_loss, vf_loss_per_item = tf.while_loop(
                cond=cond,
                body=opt_body,
                loop_vars=init_loop_vars,
                parallel_iterations=1
            )
            # Increase the global training step counter.
            loss = root._graph_fn_training_step(loss)
            return loss, loss_per_item, vf_loss, vf_loss_per_item

    elif get_backend() == "pytorch":
        if isinstance(prev_log_probs, dict):
            for name in actions.keys():
                prev_log_probs[name] = prev_log_probs[name].detach()
        else:
            prev_log_probs = prev_log_probs.detach()
        batch_size = preprocessed_states.shape[0]
        sample_size = min(batch_size, agent.sample_size)
        prior_baseline_values = value_function.value_output(preprocessed_states).detach()

        if apply_postprocessing:
            advantages = gae_function.calc_gae_values(prior_baseline_values, rewards, terminals, sequence_indices)
        else:
            advantages = rewards
        if self.standardize_advantages:
            advantages = (advantages - torch.mean(advantages)) / torch.std(advantages)

        for _ in range(agent.iterations):
            start = int(torch.rand(1) * (batch_size - 1))
            indices = torch.arange(start=start, end=start + sample_size, dtype=torch.long) % batch_size
            sample_states = torch.index_select(preprocessed_states, 0, indices)

            if isinstance(actions, dict):
                sample_actions = DataOpDict()
                sample_prior_log_probs = DataOpDict()
                for name, action in define_by_run_flatten(actions, scope_separator_at_start=False).items():
                    sample_actions[name] = torch.index_select(action, 0, indices)
                    sample_prior_log_probs[name] = torch.index_select(prev_log_probs[name], 0, indices)
            else:
                sample_actions = torch.index_select(actions, 0, indices)
                sample_prior_log_probs = torch.index_select(prev_log_probs, 0, indices)

            sample_advantages = torch.index_select(advantages, 0, indices)
            sample_prior_baseline_values = torch.index_select(prior_baseline_values, 0, indices)

            policy_probs = policy.get_log_likelihood(sample_states, sample_actions)["log_likelihood"]
            sample_baseline_values = value_function.value_output(sample_states)
            entropy = policy.get_entropy(sample_states)["entropy"]
            loss, loss_per_item, vf_loss, vf_loss_per_item = loss_function.loss(
                policy_probs, sample_prior_log_probs,
                sample_baseline_values, sample_prior_baseline_values, sample_advantages, entropy
            )

            # Do not need step op.
            _, loss, loss_per_item = optimizer.step(policy.variables(), loss, loss_per_item)
            _, vf_loss, vf_loss_per_item = \
                value_function_optimizer.step(value_function.variables(), vf_loss, vf_loss_per_item)

        return loss, loss_per_item, vf_loss, vf_loss_per_item
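# Numeric sketch (not from the original source) of the wrap-around sub-sampling used in opt_body above:
# a random start index plus a contiguous range, taken modulo the batch size, yields a fixed-size
# minibatch even when the window runs past the end of the batch. `_wraparound_indices` is a
# hypothetical helper for illustration only.
import numpy as np

def _wraparound_indices(batch_size, sample_size, start):
    return (start + np.arange(sample_size)) % batch_size

# Example: batch_size=5, sample_size=4, start=3 -> indices [3, 4, 0, 1].
# _wraparound_indices(5, 4, 3) -> array([3, 4, 0, 1])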
def opt_body(index_, loss_, loss_per_item_, vf_loss_, vf_loss_per_item_):
    start = tf.random_uniform(shape=(), minval=0, maxval=batch_size - 1, dtype=tf.int32)
    indices = tf.range(start=start, limit=start + agent.sample_size) % batch_size
    sample_states = tf.gather(params=preprocessed_states, indices=indices)

    if isinstance(actions, dict):
        sample_actions = DataOpDict()
        sample_prior_log_probs = DataOpDict()
        for name, action in flatten_op(actions, scope_separator_at_start=False).items():
            sample_actions[name] = tf.gather(params=action, indices=indices)
            sample_prior_log_probs[name] = tf.gather(params=prev_log_probs[name], indices=indices)
    else:
        sample_actions = tf.gather(params=actions, indices=indices)
        sample_prior_log_probs = tf.gather(params=prev_log_probs, indices=indices)

    sample_rewards = tf.gather(params=rewards, indices=indices)
    sample_terminals = tf.gather(params=terminals, indices=indices)
    sample_sequence_indices = tf.gather(params=sequence_indices, indices=indices)

    # If we are a multi-GPU root:
    # Simply feed everything into the multi-GPU sync optimizer's method and return.
    if multi_gpu_sync_optimizer is not None:
        main_policy_vars = agent.policy.variables()
        main_vf_vars = agent.value_function.variables()
        all_vars = agent.vars_merger.merge(main_policy_vars, main_vf_vars)
        # grads_and_vars, loss, loss_per_item, vf_loss, vf_loss_per_item = \
        out = multi_gpu_sync_optimizer.calculate_update_from_external_batch(
            all_vars,
            sample_states, sample_actions, sample_rewards,
            sample_terminals, sample_sequence_indices,
            apply_postprocessing=apply_postprocessing
        )
        avg_grads_and_vars_policy, avg_grads_and_vars_vf = agent.vars_splitter.split(
            out["avg_grads_and_vars_by_component"]
        )
        policy_step_op = agent.optimizer.apply_gradients(avg_grads_and_vars_policy)
        vf_step_op = agent.value_function_optimizer.apply_gradients(avg_grads_and_vars_vf)
        step_op = root._graph_fn_group(policy_step_op, vf_step_op)
        step_and_sync_op = multi_gpu_sync_optimizer.sync_variables_to_towers(step_op, all_vars)
        loss_vf, loss_per_item_vf = out["additional_return_0"], out["additional_return_1"]

        # Have to set all shapes here due to strict loop-var shape requirements.
        out["loss"].set_shape(())
        loss_vf.set_shape(())
        loss_per_item_vf.set_shape((agent.sample_size,))
        out["loss_per_item"].set_shape((agent.sample_size,))

        with tf.control_dependencies([step_and_sync_op]):
            if index_ == 0:
                # Increase the global training step counter.
                out["loss"] = root._graph_fn_training_step(out["loss"])
            return index_ + 1, out["loss"], out["loss_per_item"], loss_vf, loss_per_item_vf

    policy_probs = policy.get_action_log_probs(sample_states, sample_actions)
    baseline_values = value_function.value_output(tf.stop_gradient(sample_states))
    sample_rewards = tf.cond(
        pred=apply_postprocessing,
        true_fn=lambda: gae_function.calc_gae_values(
            baseline_values, sample_rewards, sample_terminals, sample_sequence_indices
        ),
        false_fn=lambda: sample_rewards
    )
    sample_rewards.set_shape((agent.sample_size,))
    entropy = policy.get_entropy(sample_states)["entropy"]

    loss, loss_per_item, vf_loss, vf_loss_per_item = \
        loss_function.loss(
            policy_probs["action_log_probs"], sample_prior_log_probs,
            baseline_values, sample_rewards, entropy
        )

    if hasattr(root, "is_multi_gpu_tower") and root.is_multi_gpu_tower is True:
        policy_grads_and_vars = optimizer.calculate_gradients(policy.variables(), loss)
        vf_grads_and_vars = value_function_optimizer.calculate_gradients(value_function.variables(), vf_loss)
        grads_and_vars_by_component = vars_merger.merge(policy_grads_and_vars, vf_grads_and_vars)
        return grads_and_vars_by_component, loss, loss_per_item, vf_loss, vf_loss_per_item
    else:
        step_op, loss, loss_per_item = optimizer.step(policy.variables(), loss, loss_per_item)
        loss.set_shape(())
        loss_per_item.set_shape((agent.sample_size,))

        vf_step_op, vf_loss, vf_loss_per_item = value_function_optimizer.step(
            value_function.variables(), vf_loss, vf_loss_per_item
        )
        vf_loss.set_shape(())
        vf_loss_per_item.set_shape((agent.sample_size,))

        with tf.control_dependencies([step_op, vf_step_op]):
            return index_ + 1, loss, loss_per_item, vf_loss, vf_loss_per_item