def all_trees_rec_bb(taxa, cm, og):
    '''Return all the trees that can be made out of the set of taxa `taxa`,
    recursively. The recursive implementation was necessary because otherwise
    `taxon` was not scoped correctly.'''
    if len(taxa) == 1:
        return [tree.tree(id=taxa.pop())]
    else:
        taxon = taxa.pop()
        trees = (trees_adding_bb(t, taxon, cm, og)
                 for t in all_trees_rec_bb(taxa, cm, og))
        return tree.flatten(trees)
def add_init_obs(
    self,
    episode_id: EpisodeID,
    agent_index: int,
    env_id: EnvID,
    t: int,
    init_obs: TensorType,
) -> None:
    """Adds an initial observation (after reset) to the Agent's trajectory.

    Args:
        episode_id (EpisodeID): Unique ID for the episode we are adding the
            initial observation for.
        agent_index (int): Unique int index (starting from 0) for the agent
            within its episode. Not to be confused with AGENT_ID (Any).
        env_id (EnvID): The environment index (in a vectorized setup).
        t (int): The time step (episode length - 1). The initial obs has
            ts=-1(!), then an action/reward/next-obs at t=0, etc..
        init_obs (TensorType): The initial observation tensor (after
            `env.reset()`).
    """
    # Store episode ID + unroll ID, which will be constant throughout this
    # AgentCollector's lifecycle.
    self.episode_id = episode_id
    if self.unroll_id is None:
        self.unroll_id = _AgentCollector._next_unroll_id
        _AgentCollector._next_unroll_id += 1

    if SampleBatch.OBS not in self.buffers:
        self._build_buffers(
            single_row={
                SampleBatch.OBS: init_obs,
                SampleBatch.AGENT_INDEX: agent_index,
                SampleBatch.ENV_ID: env_id,
                SampleBatch.T: t,
                SampleBatch.EPS_ID: self.episode_id,
                SampleBatch.UNROLL_ID: self.unroll_id,
            }
        )

    # Append data to existing buffers.
    flattened = tree.flatten(init_obs)
    for i, sub_obs in enumerate(flattened):
        self.buffers[SampleBatch.OBS][i].append(sub_obs)
    self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
    self.buffers[SampleBatch.ENV_ID][0].append(env_id)
    self.buffers[SampleBatch.T][0].append(t)
    self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
    self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
def append_sequence(self, sequence: Any):
    """Appends sequence of data to the internal buffer.

    Each element in `sequence` must have the same leading dimension [T].

    A call to `append_sequence` is equivalent to splitting `sequence` along
    its first dimension and calling `append` once for each slice.

    For example:

    ```python
    with client.writer(max_sequence_length=2) as writer:
      sequence = np.array([[1, 2, 3],
                           [4, 5, 6]])

      # Insert two timesteps.
      writer.append_sequence([sequence])

      # Create an item that references the step [4, 5, 6].
      writer.create_item('my_table', num_timesteps=1, priority=1.0)

      # Create an item that references the steps [1, 2, 3] and [4, 5, 6].
      writer.create_item('my_table', num_timesteps=2, priority=1.0)
    ```

    Is equivalent to:

    ```python
    with client.writer(max_sequence_length=2) as writer:
      # Insert two timesteps.
      writer.append([np.array([1, 2, 3])])
      writer.append([np.array([4, 5, 6])])

      # Create an item that references the step [4, 5, 6].
      writer.create_item('my_table', num_timesteps=1, priority=1.0)

      # Create an item that references the steps [1, 2, 3] and [4, 5, 6].
      writer.create_item('my_table', num_timesteps=2, priority=1.0)
    ```

    Args:
      sequence: Batched (possibly nested) structure to make available for
        items to reference.
    """
    self._writer.AppendSequence(tree.flatten(sequence))
def forward(self, input_dict, state, seq_lens):
    if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
        orig_obs = input_dict[SampleBatch.OBS]
    else:
        orig_obs = restore_original_dimensions(
            input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="tf"
        )
    # Push image observations through our CNNs.
    outs = []
    for i, component in enumerate(tree.flatten(orig_obs)):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i](SampleBatch({SampleBatch.OBS: component}))
            outs.append(cnn_out)
        elif i in self.one_hot:
            if "int" in component.dtype.name:
                one_hot_in = {
                    SampleBatch.OBS: one_hot(
                        component, self.flattened_input_space[i]
                    )
                }
            else:
                one_hot_in = {SampleBatch.OBS: component}
            one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in))
            outs.append(one_hot_out)
        else:
            nn_out, _ = self.flatten[i](
                SampleBatch(
                    {
                        SampleBatch.OBS: tf.cast(
                            tf.reshape(component, [-1, self.flatten_dims[i]]),
                            tf.float32,
                        )
                    }
                )
            )
            outs.append(nn_out)

    # Concat all outputs and the non-image inputs.
    out = tf.concat(outs, axis=1)
    # Push through (optional) FC-stack (this may be an empty stack).
    out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))

    # No logits/value branches.
    if not self.logits_and_value_model:
        return out, []

    # Logits- and value branches.
    logits, values = self.logits_and_value_model(out)
    self._value_out = tf.reshape(values, [-1])
    return logits, []
def explicit_ntk(
    fwd_fn: Callable,
    params: hk.Params,
    sigma: hk.Params,
    diag=False,
) -> jnp.ndarray:
    """Calculate J * diag(sigma) * J^T, where J is the Jacobian of the model
    with respect to the model parameters, using an explicit implementation
    and einsum.

    Args:
        fwd_fn: a function that only takes in parameters and returns the
            model output of shape (batch_dim, output_dim).
        params: the model parameters.
        sigma: has the same structure and array shapes as the model
            parameters.
        diag: if True, only calculate the diagonal of the NTK.

    Returns:
        jnp.ndarray of shape (batch_dim, output_dim) if diag == True, else
        (batch_dim, output_dim, batch_dim, output_dim).
    """
    jacobian = jax.jacobian(fwd_fn)(params)

    def _get_diag_ntk(jac, sigma):
        # jac has shape (batch_dim, output_dim, params_dims...).
        # jac_2d has shape (batch_dim * output_dim, nb_params).
        batch_dim, output_dim = jac.shape[:2]
        jac_2d = jnp.reshape(jac, (batch_dim * output_dim, -1))
        # sigma_flatten has shape (nb_params,) and will be broadcast to the
        # same shape as jac_2d.
        sigma_flatten = jnp.reshape(sigma, (-1,))
        # jac_sigma_product has the same shape as jac_2d.
        jac_sigma_product = jnp.multiply(jac_2d, sigma_flatten)
        if diag:
            # ntk has shape (batch_dim, output_dim).
            ntk = jnp.einsum("ij,ji->i", jac_sigma_product, jac_2d.T)
            ntk = jnp.reshape(ntk, (batch_dim, output_dim))
        else:
            # ntk has shape (batch_dim, output_dim, batch_dim, output_dim).
            ntk = jnp.matmul(jac_sigma_product, jac_2d.T)
            ntk = jnp.reshape(ntk, (batch_dim, output_dim, batch_dim, output_dim))
        return ntk

    diag_ntk = tree.map_structure(_get_diag_ntk, jacobian, sigma)
    diag_ntk_sum_array = jnp.stack(tree.flatten(diag_ntk), axis=0).sum(axis=0)
    return diag_ntk_sum_array
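# Added sanity check (not part of the original function): the einsum pattern
# used in the diag branch above, "ij,ji->i", picks out the diagonal of
# `jac_sigma_product @ jac_2d.T` without materializing the full matrix. A
# minimal sketch with toy arrays:
import jax.numpy as jnp
import numpy as np

A = jnp.arange(12.0).reshape(3, 4)   # stands in for jac_sigma_product
B = jnp.arange(12.0).reshape(3, 4)   # stands in for jac_2d
np.testing.assert_allclose(
    jnp.einsum("ij,ji->i", A, B.T),  # diagonal only, linear in batch size
    jnp.diag(A @ B.T),               # full product, quadratic in batch size
)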
def add_path(self, path):
    path = path.copy()
    path_length = tree.flatten(path)[0].shape[0]
    path.update({
        'episode_index_forwards': np.arange(
            path_length,
            dtype=self.fields['episode_index_forwards'].dtype,
        )[..., np.newaxis],
        'episode_index_backwards': np.arange(
            path_length,
            dtype=self.fields['episode_index_backwards'].dtype,
        )[::-1, np.newaxis],
    })
    return self.add_samples(path)
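# Added usage sketch (not from the original softlearning code): with dm-tree,
# `tree.flatten(path)[0]` is the first leaf of the nested path dict (leaves
# come out in sorted-key order), so its leading dimension gives the path
# length, assuming every leaf shares the same [T] leading dimension.
import numpy as np
import tree

path = {'observations': np.zeros((5, 3)), 'actions': np.zeros((5, 2))}
path_length = tree.flatten(path)[0].shape[0]
assert path_length == 5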
def test_bias_parameter_shape(self):
    params = self.network.init(
        self.init_rng_key, _sample_input(self.input_shape))
    self.assertLen(tree.flatten(params), 2)

    def check_params(path, param):
        if path[-1] == 'b':
            self.assertNotEqual(self.output_shape, param.shape)
            chex.assert_shape(param, (1,))
        elif path[-1] == 'w':
            chex.assert_shape(param, self.weights_shape)
        else:
            self.fail('Unexpected parameter %s.' % path)

    tree.map_structure_with_path(check_params, params)
def actions_np(self, observations):
    if self._deterministic:
        return self.deterministic_actions_model.predict(observations)
    elif self._smoothing_alpha == 0:
        return self.actions_model.predict(observations)
    else:
        alpha, beta = self._smoothing_alpha, self._smoothing_beta
        raw_latents = self.latents_model.predict(observations)
        self._smoothing_x = (
            alpha * self._smoothing_x + (1.0 - alpha) * raw_latents)
        latents = beta * self._smoothing_x
        inputs = tree.flatten((observations, latents))
        # TODO(hartikainen/tf2): This can be removed once we use .numpy()
        return self.actions_model_for_fixed_latents.predict(inputs)
def actions(self, observations):
    """Compute actions for given observations."""
    observations = self._filter_observations(observations)

    first_observation = tree.flatten(observations)[0]
    first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
    batch_shape = tf.shape(first_observation)[:-first_input_rank]

    shifts, scales = self.shift_and_scale_model(observations)

    if self._deterministic:
        actions = self._action_post_processor(shifts)
    else:
        actions = self.action_distribution.sample(
            batch_shape,
            bijector_kwargs={
                'scale': {'scale': scales},
                'shift': {'shift': shifts},
            })

    return actions
def test_ema_on_changing_data(self):
    def f():
        return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = base.transform(f)
    params = init_fn(random.PRNGKey(428))

    def g(x):
        return moving_averages.EMAParamsTree(0.2)(x)

    init_fn, apply_fn = base.transform(g, state=True)
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn(None, params_state, params)
    # Let's modify our params.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn(None, params_state, changed_params)

    # ema_params should be different from changed params!
    tree.assert_same_structure(changed_params, ema_params)
    for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
        self.assertEqual(p1.shape, p2.shape)
        with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
            np.testing.assert_allclose(p1, p2, atol=1e-6)
def compute_actions_from_input_dict(
    self,
    input_dict: Dict[str, TensorType],
    explore: bool = None,
    timestep: Optional[int] = None,
    episodes: Optional[List[Episode]] = None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    if not self.config.get("eager_tracing") and not tf1.executing_eagerly():
        tf1.enable_eager_execution()

    self._is_training = False

    explore = explore if explore is not None else self.explore
    timestep = timestep if timestep is not None else self.global_timestep
    if isinstance(timestep, tf.Tensor):
        timestep = int(timestep.numpy())

    # Pass lazy (eager) tensor dict to Model as `input_dict`.
    input_dict = self._lazy_tensor_dict(input_dict)
    input_dict.set_training(False)

    # Pack internal state inputs into (separate) list.
    state_batches = [
        input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
    ]
    self._state_in = state_batches
    self._is_recurrent = state_batches != []

    # Call the exploration before_compute_actions hook.
    self.exploration.before_compute_actions(
        timestep=timestep, explore=explore, tf_sess=self.get_session())

    ret = self._compute_actions_helper(
        input_dict,
        state_batches,
        # TODO: Passing episodes into a traced method does not work.
        None if self.config["eager_tracing"] else episodes,
        explore,
        timestep,
    )
    # Update our global timestep by the batch size.
    self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
    return convert_to_numpy(ret)
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> None:
    """Adds the given dictionary (row) of values to the Agent's trajectory.

    Args:
        values (Dict[str, TensorType]): Data dict (interpreted as a single
            row) to be added to buffer. Must contain keys:
            SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
    """
    if self.unroll_id is None:
        self.unroll_id = _AgentCollector._next_unroll_id
        _AgentCollector._next_unroll_id += 1

    # Next obs -> obs.
    assert SampleBatch.OBS not in values
    values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
    del values[SampleBatch.NEXT_OBS]

    # Make sure EPS_ID/UNROLL_ID stay the same for this agent.
    if SampleBatch.EPS_ID in values:
        assert values[SampleBatch.EPS_ID] == self.episode_id
        del values[SampleBatch.EPS_ID]
    self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
    if SampleBatch.UNROLL_ID in values:
        assert values[SampleBatch.UNROLL_ID] == self.unroll_id
        del values[SampleBatch.UNROLL_ID]
    self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)

    for k, v in values.items():
        if k not in self.buffers:
            self._build_buffers(single_row=values)
        # Do not flatten infos, state_out_ and (if configured) actions.
        # Infos/state-outs may be structs that change from timestep to
        # timestep.
        if (
            k == SampleBatch.INFOS
            or k.startswith("state_out_")
            or (
                k == SampleBatch.ACTIONS
                and not self.policy.config.get("_disable_action_flattening")
            )
        ):
            self.buffers[k][0].append(v)
        # Flatten all other columns.
        else:
            flattened = tree.flatten(v)
            for i, sub_list in enumerate(self.buffers[k]):
                sub_list.append(flattened[i])
    self.agent_steps += 1
async def materialize_value(value: Any) -> Any:
    """Returns a structure of materialized values.

    Args:
        value: A materialized value, a value reference, or a structure of
            materialized values and value references to materialize.
    """

    async def _materialize(value: Any) -> Any:
        if isinstance(value, MaterializableValueReference):
            return await value.get_value()
        else:
            return value

    flattened = tree.flatten(value)
    flattened = await asyncio.gather(*[_materialize(v) for v in flattened])
    return tree.unflatten_as(value, flattened)
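# Minimal runnable sketch (added) of the flatten -> gather -> unflatten_as
# pattern above. `FakeRef` is a hypothetical stand-in for
# MaterializableValueReference; only the dm-tree calls mirror the real code.
import asyncio
import tree

class FakeRef:
    def __init__(self, v):
        self._v = v

    async def get_value(self):
        return self._v

async def _materialize(v):
    return await v.get_value() if isinstance(v, FakeRef) else v

async def demo():
    value = {'a': FakeRef(1), 'b': [2, FakeRef(3)]}
    flat = await asyncio.gather(*[_materialize(v) for v in tree.flatten(value)])
    return tree.unflatten_as(value, flat)

assert asyncio.run(demo()) == {'a': 1, 'b': [2, 3]}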
def actions(self, observations):
    if 0 < self._smoothing_alpha:
        raise NotImplementedError(
            "TODO(hartikainen): Smoothing alpha temporarily dropped on tf2"
            " migration. Should add it back. See:"
            " https://github.com/rail-berkeley/softlearning/blob/46374df0294b9b5f6dbe65b9471ec491a82b6944/softlearning/policies/base_policy.py#L80")

    observations = self._filter_observations(observations)
    batch_shape = tf.shape(tree.flatten(observations)[0])[:-1]
    actions = self.action_distribution.sample(
        batch_shape,
        bijector_kwargs={
            self.flow_model.name: {'observations': observations}
        })

    return actions
def batch_concat(inputs: types.NestedTensor) -> tf.Tensor:
    """Concatenate a collection of Tensors while preserving the batch dimension.

    This takes a potentially nested collection of tensors, flattens everything
    but the batch (first) dimension, and concatenates along the resulting data
    (second) dimension.

    Args:
        inputs: a tensor or nested collection of tensors.

    Returns:
        A concatenated tensor which maintains the batch dimension but
        concatenates all other data along the flattened second dimension.
    """
    flat_leaves = tree.map_structure(snt.Flatten(), inputs)
    return tf.concat(tree.flatten(flat_leaves), axis=-1)
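# Added NumPy analogue (an illustration, not acme's implementation): replace
# snt.Flatten with a reshape that keeps the leading batch axis, then
# concatenate the flattened leaves.
import numpy as np
import tree

def batch_concat_np(inputs):
    flat_leaves = tree.map_structure(
        lambda x: np.reshape(x, (x.shape[0], -1)), inputs)
    return np.concatenate(tree.flatten(flat_leaves), axis=-1)

out = batch_concat_np({'img': np.ones((4, 2, 3)), 'vec': np.ones((4, 5))})
assert out.shape == (4, 11)  # 2*3 + 5 features per batch element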
def _tree_merge_into(source, target):
    """Update `target` with content of substructure `source`."""
    path_to_index = {
        path: i for i, (path, _) in enumerate(tree.flatten_with_path(target))
    }
    flat_target = tree.flatten(target)
    for path, leaf in tree.flatten_with_path(source):
        if path not in path_to_index:
            raise ValueError(
                f'Cannot expand {source} into {target} as it is not a sub '
                f'structure.')
        flat_target[path_to_index[path]] = leaf
    return tree.unflatten_as(target, flat_target)
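# Added usage sketch for `_tree_merge_into` above (assuming it is in scope):
# leaves of `source` overwrite the leaves at the matching paths in `target`;
# a path missing from `target` raises ValueError.
target = {'a': 0, 'b': {'c': 1, 'd': 2}}
source = {'b': {'d': 20}}
assert _tree_merge_into(source, target) == {'a': 0, 'b': {'c': 1, 'd': 20}}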
def preprocess_value(params, value, weight):
    vectors = tree.map_structure(lambda v: tf.reshape(v, [-1]), value)
    norm = tf.norm(tf.concat(tree.flatten(vectors), axis=0), ord=norm_order)
    quantile_record = quantile_query.preprocess_record(params, norm)
    threshold = params.current_estimate * multiplier
    too_large = (norm > threshold)
    adj_weight = tf.cond(too_large, lambda: tf.constant(0.0), lambda: weight)
    weighted_value = tree.map_structure(
        lambda v: tf.math.multiply_no_nan(v, adj_weight), value)
    too_large = tf.cast(too_large, tf.int32)
    return weighted_value, adj_weight, quantile_record, too_large
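# Added numeric check (plain TF, illustrative only): reshaping every leaf to
# rank 1 and concatenating reproduces the global norm of a nested structure,
# which is what the first two lines above compute.
import tensorflow as tf
import tree

value = {'b': tf.constant([4.0]), 'w': tf.constant([[3.0, 0.0]])}
vectors = tree.map_structure(lambda v: tf.reshape(v, [-1]), value)
norm = tf.norm(tf.concat(tree.flatten(vectors), axis=0))  # L2 by default
assert float(norm) == 5.0  # sqrt(4^2 + 3^2)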
def random_crop_and_resize(images,
                           target_size,
                           min_crop_fraction=0.5,
                           crop_size=None,
                           methods=tf.image.ResizeMethod.BILINEAR):
    """All tensors are cropped to the same size and resized to target size."""
    if not isinstance(target_size, (tuple, list)):
        target_size = [target_size, target_size]
    aspect_ratio = target_size[0] / target_size[1]
    batch_size = tf.shape(tree.flatten(images)[0])[0]
    bboxes = get_random_bounding_box(
        hw=tf.ones((batch_size, 2)),
        aspect_ratio=aspect_ratio,
        min_crop_fraction=min_crop_fraction,
        new_height=crop_size,
        dtype=tf.float32)
    return crop_and_resize(images, bboxes, target_size, methods)
def __init__(self, environment: dm_env.Environment, random_prob: float = 0.):
    super().__init__(environment)
    self._random_prob = random_prob
    if not 0 <= random_prob <= 1:
        raise ValueError(f'random_prob ({random_prob}) must be within [0, 1]')

    self._action_spec = self._environment.action_spec()
    if not all(
        map(lambda spec: isinstance(spec, specs.DiscreteArray),
            tree.flatten(self._action_spec))):
        raise ValueError(
            'RandomActionWrapper requires the action_spec to be a '
            'DiscreteArray or a nested DiscreteArray.')
def __init__(
    self,
    actor_id,
    environment_module,
    environment_fn_name,
    environment_kwargs,
    network_module,
    network_fn_name,
    network_kwargs,
    adder_module,
    adder_fn_name,
    adder_kwargs,
    replay_server_address,
    variable_server_name,
    variable_server_address,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
):
    # Counter and Logger
    self._actor_id = actor_id
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(f'actor_{actor_id}')

    # Create the environment
    self._environment = getattr(
        environment_module, environment_fn_name)(**environment_kwargs)
    env_spec = acme.make_environment_spec(self._environment)

    # Create actor's network
    self._network = getattr(network_module, network_fn_name)(**network_kwargs)
    tf2_utils.create_variables(self._network, [env_spec.observations])
    self._variables = tree.flatten(self._network.variables)
    self._policy = tf.function(self._network)

    # The adder is used to insert observations into replay.
    self._adder = getattr(adder_module, adder_fn_name)(
        reverb.Client(replay_server_address), **adder_kwargs)

    variable_client = reverb.TFClient(variable_server_address)
    self._variable_dataset = variable_client.dataset(
        table=variable_server_name,
        dtypes=[tf.float32 for _ in self._variables],
        shapes=[v.shape for v in self._variables])
def add_path(self, path):
    # path: dict with 6 keys, each with leading dimension 1000.
    path = path.copy()
    # flatten on a dict only unwraps down to the value level.
    path_length = tree.flatten(path)[0].shape[0]
    path.update({
        'episode_index_forwards': np.arange(
            path_length,
            dtype=self.fields['episode_index_forwards'].dtype,
        )[..., np.newaxis],  # [0 ... 999]
        'episode_index_backwards': np.arange(
            path_length,
            dtype=self.fields['episode_index_backwards'].dtype,
        )[::-1, np.newaxis],  # [999 ... 0]
    })
    # path: dict with 7 keys, each with leading dimension 1000.
    return self.add_samples(path)
def forward(self, input_dict, state, seq_lens):
    if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
        orig_obs = input_dict[SampleBatch.OBS]
    else:
        orig_obs = restore_original_dimensions(
            input_dict[SampleBatch.OBS],
            self.processed_obs_space,
            tensorlib="torch")
    # Push observations through the different components
    # (CNNs, one-hot + FC, etc..).
    outs = []
    for i, component in enumerate(tree.flatten(orig_obs)):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i](
                SampleBatch({SampleBatch.OBS: component}))
            outs.append(cnn_out)
        elif i in self.one_hot:
            if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                one_hot_in = {
                    SampleBatch.OBS: one_hot(
                        component, self.flattened_input_space[i])
                }
            else:
                one_hot_in = {SampleBatch.OBS: component}
            one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in))
            outs.append(one_hot_out)
        else:
            nn_out, _ = self.flatten[i](SampleBatch({
                SampleBatch.OBS: torch.reshape(
                    component, [-1, self.flatten_dims[i]])
            }))
            outs.append(nn_out)

    # Concat all outputs and the non-image inputs.
    out = torch.cat(outs, dim=1)
    # Push through (optional) FC-stack (this may be an empty stack).
    out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))

    # No logits/value branches.
    if self.logits_layer is None:
        return out, []

    # Logits- and value branches.
    logits, values = self.logits_layer(out), self.value_layer(out)
    self._value_out = torch.reshape(values, [-1])
    return logits, []
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
    """Builds the buffers for sample collection, given an example data row.

    Args:
        single_row (Dict[str, TensorType]): A single row (keys=column names)
            of data to base the buffers on.
    """
    for col, data in single_row.items():
        if col in self.buffers:
            continue

        shift = self.shift_before - (
            1
            if col in [
                SampleBatch.OBS,
                SampleBatch.EPS_ID,
                SampleBatch.AGENT_INDEX,
                SampleBatch.ENV_ID,
                SampleBatch.T,
                SampleBatch.UNROLL_ID,
            ]
            else 0
        )

        # Store all data as flattened lists, except INFOS and state-out
        # lists. These are monolithic items (infos is a dict that
        # should not be further split, same for state-out items, which
        # could be custom dicts as well).
        if (
            col == SampleBatch.INFOS
            or col.startswith("state_out_")
            or (
                col == SampleBatch.ACTIONS
                and not self.policy.config.get("_disable_action_flattening")
            )
        ):
            self.buffers[col] = [[data for _ in range(shift)]]
        else:
            self.buffers[col] = [
                [v for _ in range(shift)] for v in tree.flatten(data)
            ]
            # Store an example data struct so we know how to unflatten
            # each data col.
            self.buffer_structs[col] = data
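# Added sketch (hypothetical data, not RLlib code) of the buffer layout built
# above: a flattenable column is stored as one python list per leaf, and the
# example struct kept in `buffer_structs` lets a row be unflattened later.
import numpy as np
import tree

single_row = {'pos': np.zeros(2), 'vel': np.zeros(3)}
buffers = [[] for _ in tree.flatten(single_row)]  # one list per leaf
for leaf_list, leaf in zip(buffers, tree.flatten(single_row)):
    leaf_list.append(leaf)
row0 = tree.unflatten_as(single_row, [l[0] for l in buffers])
assert row0['pos'].shape == (2,) and row0['vel'].shape == (3,)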
def flatten_data(data: Dict[str, TensorStructType]):
    assert isinstance(data, dict), (
        "Single agent data must be of type Dict[str, TensorStructType]")

    flattened = {}
    for k, v in data.items():
        if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith(
                "state_out_"):
            # Do not flatten infos, actions, and state_out_ columns.
            flattened[k] = v
            continue
        if v is None:
            # Keep the same column shape.
            flattened[k] = None
            continue
        flattened[k] = np.array(tree.flatten(v))

    return flattened
def testEntropyGradients(self, is_multi_actions):
    if is_multi_actions:
        loss = self.multi_op.extra.entropy_loss
        policy_logits_nest = self.multi_policy_logits
    else:
        loss = self.op.extra.entropy_loss
        policy_logits_nest = self.policy_logits

    grad_policy_list = [
        tf.gradients(loss, policy_logits)[0] * self.num_actions
        for policy_logits in nest.flatten(policy_logits_nest)
    ]

    for grad_policy in grad_policy_list:
        self.assertEqual(grad_policy.get_shape(), tf.TensorShape([2, 1, 3]))

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(
        tf.gradients(loss, self.invalid_grad_inputs),
        self.invalid_grad_outputs)
def testFlattenAndUnflatten_withDicts(self):
    # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
    named_tuple = collections.namedtuple("A", ("b", "c"))
    mess = [
        "z",
        named_tuple(3, 4),
        {
            "c": [
                1,
                collections.OrderedDict([
                    ("b", 3),
                    ("a", 2),
                ]),
            ],
            "b": 5
        },
        17
    ]

    flattened = tree.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

    structure_of_mess = [
        14,
        named_tuple("a", True),
        {
            "c": [
                0,
                collections.OrderedDict([
                    ("b", 9),
                    ("a", 8),
                ]),
            ],
            "b": 3
        },
        "hi everybody",
    ]
    self.assertEqual(mess, tree.unflatten_as(structure_of_mess, flattened))

    # Check also that the OrderedDict was created, with the correct key order.
    unflattened_ordered_dict = tree.unflatten_as(
        structure_of_mess, flattened)[2]["c"][1]
    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
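# Added note, exercised by the test above: dm-tree flattens mappings in
# sorted-key order (insertion order does not matter for plain dicts), and
# `unflatten_as` assigns leaves back in that same order.
import tree

d = {'b': 2, 'a': 1, 'c': [3, 4]}
assert tree.flatten(d) == [1, 2, 3, 4]
assert tree.unflatten_as(d, [10, 20, 30, 40]) == {'a': 10, 'b': 20, 'c': [30, 40]}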
def __init__(self,
             client: core.VariableSource,
             variables: Mapping[str, Sequence[tf.Variable]],
             update_period: int = 1):
    self._keys = list(variables.keys())
    self._variables = tree.flatten(list(variables.values()))
    self._call_counter = 0
    self._update_period = update_period
    self._client = client
    self._request = lambda: client.get_variables(self._keys)

    # Create a single background thread to fetch variables without necessarily
    # blocking the actor.
    self._executor = futures.ThreadPoolExecutor(max_workers=1)
    self._async_request = lambda: self._executor.submit(self._request)

    # Initialize this client's future to None to indicate to the `update()`
    # method that there is no pending/running request.
    self._future: Optional[futures.Future] = None
def compute_actions(
    self,
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Union[List[TensorType], TensorType] = None,
    prev_reward_batch: Union[List[TensorType], TensorType] = None,
    info_batch: Optional[Dict[str, list]] = None,
    episodes: Optional[List["Episode"]] = None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    **kwargs,
):
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = TFRunBuilder(self.get_session(), "compute_actions")

    input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
    if state_batches:
        for i, s in enumerate(state_batches):
            input_dict[f"state_in_{i}"] = s
    if prev_action_batch is not None:
        input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
    if prev_reward_batch is not None:
        input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += (
        len(obs_batch)
        if isinstance(obs_batch, list)
        else tree.flatten(obs_batch)[0].shape[0]
    )

    return fetched
def flatten_data(data: Dict[str, TensorStructType]):
    assert isinstance(data, dict), (
        "Single agent data must be of type Dict[str, TensorStructType]")

    flattened = {}
    for k, v in data.items():
        if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith(
                "state_out_"):
            # Do not flatten infos, actions, and state_out_ columns.
            flattened[k] = v
            continue
        if v is None:
            # Keep the same column shape.
            flattened[k] = None
            continue
        flattened[k] = np.array(tree.flatten(v))

    flattened = SampleBatch(flattened, is_training=False)
    return AgentConnectorsOutput(data, flattened)
async def release(self, value: Any, type_signature: computation_types.Type,
                  key: int) -> None:  # pytype: disable=signature-mismatch
    """Releases `value` from a federated program.

    Args:
        value: A materialized value, a value reference, or a structure of
            materialized values and value references representing the value
            to release.
        type_signature: The `tff.Type` of `value`.
        key: An integer used to reference the released `value`.
    """
    del type_signature  # Unused.
    py_typecheck.check_type(key, int)
    path = self._get_path_for_key(key)
    materialized_value = await value_reference.materialize_value(value)
    flattened_value = tree.flatten(materialized_value)
    await file_utils.write_saved_model(flattened_value, path, overwrite=True)
def _build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder."""
    # Build input signatures.
    input_signature = self._extra_input_signature_def()
    input_signature["observations"] = \
        tf.saved_model.utils.build_tensor_info(self._obs_input)

    if self._seq_lens is not None:
        input_signature["seq_lens"] = \
            tf.saved_model.utils.build_tensor_info(self._seq_lens)
    if self._prev_action_input is not None:
        input_signature["prev_action"] = \
            tf.saved_model.utils.build_tensor_info(self._prev_action_input)
    if self._prev_reward_input is not None:
        input_signature["prev_reward"] = \
            tf.saved_model.utils.build_tensor_info(self._prev_reward_input)

    input_signature["is_training"] = \
        tf.saved_model.utils.build_tensor_info(self._is_training)

    for state_input in self._state_inputs:
        input_signature[state_input.name] = \
            tf.saved_model.utils.build_tensor_info(state_input)

    # Build output signatures.
    output_signature = self._extra_output_signature_def()
    for i, a in enumerate(tree.flatten(self._sampled_action)):
        output_signature["actions_{}".format(i)] = \
            tf.saved_model.utils.build_tensor_info(a)

    for state_output in self._state_outputs:
        output_signature[state_output.name] = \
            tf.saved_model.utils.build_tensor_info(state_output)

    signature_def = (
        tf.saved_model.signature_def_utils.build_signature_def(
            input_signature, output_signature,
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    signature_def_key = (
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    signature_def_map = {signature_def_key: signature_def}
    return signature_def_map