def _action(self, time_step, policy_state, seed=None):
    """Picks one sub-policy per batch element and returns its action and info.

    Every sub-policy in `self._policies` is queried with the same `time_step`
    and `policy_state`. `self._mixture_distribution` is then sampled once per
    batch element to decide which sub-policy's action and info are returned
    for that element.

    Args:
      time_step: Batched time step. The batch size is read from the leading
        dimension of the first flattened observation.
      policy_state: Passed unchanged to every sub-policy and returned as is.
      seed: Unused.

    Returns:
      A `PolicyStep` whose action holds, per batch element, the action of the
      chosen sub-policy. Its info maps `MIXTURE_AGENT_ID` to the sampled
      choice and `SUBPOLICY_INFO` to the chosen sub-policy's info.
    """
    first_obs = tf.nest.flatten(time_step.observation)[0]
    # Prefer the static batch size; fall back to the dynamic one when the
    # leading dimension is unknown at trace time.
    batch_size = tf.compat.dimension_value(
        first_obs.shape[0]) or tf.shape(first_obs)[0]
    policy_choice = self._mixture_distribution.sample(batch_size)
    policy_steps = [
        policy.action(time_step, policy_state) for policy in self._policies
    ]
    # Sub-policy outputs are stacked along a new LAST axis.
    policy_actions = nest_utils.stack_nested_tensors(
        [step.action for step in policy_steps], axis=-1)
    policy_infos = nest_utils.stack_nested_tensors(
        [step.info for step in policy_steps], axis=-1)

    expanded_choice = tf.expand_dims(policy_choice, axis=-1)
    # Gather along the stacked (last) axis. Without `axis=-1`, tf.gather would
    # use axis=batch_dims=1, which is only the sub-policy axis for scalar
    # actions; non-scalar actions would be gathered along the wrong dimension.
    mixture_action = tf.nest.map_structure(
        lambda t: tf.gather(t, policy_choice, batch_dims=1, axis=-1),
        policy_actions)

    expanded_mixture_info = tf.nest.map_structure(
        lambda t: tf.gather(t, expanded_choice, batch_dims=1, axis=-1),
        policy_infos)
    mixture_info = tf.nest.map_structure(lambda t: tf.squeeze(t, axis=-1),
                                         expanded_mixture_info)
    return policy_step.PolicyStep(mixture_action, policy_state, {
        MIXTURE_AGENT_ID: policy_choice,
        SUBPOLICY_INFO: mixture_info
    })
def get_episode(
    self,
    batch_size: Optional[int] = None,
    truncate_episode_at: Optional[int] = None
) -> Tuple[EnvStep, np.ndarray]:
    """Fetches one episode, or a padded batch of episodes, plus a validity mask.

    Args:
      batch_size: When None, a single episode is returned unbatched. Otherwise
        this many episodes are fetched and padded to a common length.
      truncate_episode_at: Forwarded to `self._get_episode`.

    Returns:
      `(steps, mask)`. Unbatched: the stacked episode and a float numpy array
      of ones, one per step. Batched: the stacked [batch, max_len] episodes and
      a boolean tf tensor that is True where a step is real (not padding).
      NOTE(review): the two paths return different mask types/dtypes despite
      the `np.ndarray` annotation.

    Raises:
      ValueError: if `batch_size` is not positive.
    """
    if batch_size is None:
        single = self._get_episode(truncate_episode_at)
        all_valid = np.ones((len(single), ))
        return nest_utils.stack_nested_tensors(single), all_valid
    if batch_size <= 0:
        raise ValueError('Invalid batch size %s.' % batch_size)

    fetched = []
    lengths = []
    for _ in range(batch_size):
        ep = self._get_episode(truncate_episode_at)
        lengths.append(len(ep))
        fetched.append(ep)

    longest = max(lengths)
    # Pad each episode in place by repeating its final step.
    for ep in fetched:
        ep.extend([ep[-1]] * (longest - len(ep)))

    per_episode = [nest_utils.stack_nested_tensors(ep) for ep in fetched]
    stacked = nest_utils.stack_nested_tensors(per_episode)
    step_ids = tf.range(longest)[None, :]
    valid_steps = step_ids < tf.convert_to_tensor(lengths)[:, None]
    return stacked, valid_steps
def _action(self, time_step, policy_state, seed=None):
    """Samples a sub-policy per batch element and returns its action.

    `self._mixture_distribution` is treated as a vector of probabilities. It
    is checked to be non-negative and used to build a `tfd.Categorical` that
    picks, for every batch element, which sub-policy's action to return.

    Args:
      time_step: Batched time step. The batch size is read from the leading
        dimension of the first flattened observation.
      policy_state: Passed unchanged to every sub-policy and returned as is.
      seed: Unused.

    Returns:
      A `PolicyStep` with the per-element chosen sub-policy action.
    """
    tf.debugging.assert_greater_equal(
        self._mixture_distribution,
        0.0,
        message='Negative probability in mixture distribution.')
    policy_sampler = tfd.Categorical(probs=self._mixture_distribution)
    first_obs = tf.nest.flatten(time_step.observation)[0]
    # Static batch size if known, otherwise the dynamic one.
    batch_size = tf.compat.dimension_value(
        first_obs.shape[0]) or tf.shape(first_obs)[0]
    policy_choice = policy_sampler.sample(batch_size)
    # Sub-policy actions are stacked along a new LAST axis.
    policy_actions = nest_utils.stack_nested_tensors([
        policy.action(time_step, policy_state).action
        for policy in self._policies
    ], axis=-1)
    # TODO(b/147134243) Remove the expand_dims and squeeze once the fix in
    # b/143205052 is live.
    expanded_choice = tf.expand_dims(policy_choice, axis=-1)
    # Gather explicitly along the stacked (last) axis. The default axis
    # (= batch_dims = 1) is only correct for scalar actions.
    expanded_mixture_action = tf.nest.map_structure(
        lambda t: tf.gather(t, expanded_choice, batch_dims=1, axis=-1),
        policy_actions)
    mixture_action = tf.nest.map_structure(
        lambda t: tf.squeeze(t, axis=-1), expanded_mixture_action)
    return policy_step.PolicyStep(mixture_action, policy_state)
def sample(self, batch_size):
    """Draws `batch_size` transitions and returns them as two-step trajectories.

    Transitions are picked with `random.sample` (no replacement) from
    `self.buffer`. Each becomes a pair of trajectory steps:
    (time_step -> next_time_step), then (next_time_step -> a filler step whose
    action/step_type are int32.min and whose reward/discount are NaN).
    NOTE(review): the filler tensors have shape [1], so each transition
    presumably carries a leading batch dim of 1 that unbatching removes —
    confirm.
    """
    filler_action_step = policy_step.PolicyStep(
        action=tf.constant([tf.int32.min]))
    filler_time_step = ts.TimeStep(
        step_type=tf.constant([tf.int32.min]),
        reward=(np.nan * tf.ones(1)),
        discount=(np.nan * tf.ones(1)),
        observation=None)

    def _to_pair(item):
        # Current step followed by a filler-terminated successor step.
        head = trajectory.from_transition(item.time_step, item.action_step,
                                          item.next_time_step)
        tail = trajectory.from_transition(item.next_time_step,
                                          filler_action_step,
                                          filler_time_step)
        paired = nest_utils.stack_nested_tensors([head, tail], axis=1)
        return nest_utils.unbatch_nested_tensors(paired)

    chosen = random.sample(self.buffer, batch_size)
    return nest_utils.stack_nested_tensors([_to_pair(item) for item in chosen])
def testStackNestedTensors(self):
    """Stacking N identical-shape nests adds a leading dimension of size N."""
    inner_shape = [5, 8]
    num_items = 3
    expected_shape = [num_items] + inner_shape
    specs = self.nest_spec(inner_shape, include_sparse=False)
    items = [self.zeros_from_spec(specs) for _ in range(num_items)]

    stacked = nest_utils.stack_nested_tensors(items)

    tf.nest.assert_same_structure(specs, stacked)

    def _check_shape(tensor):
        self.assertEqual(tensor.shape, expected_shape)

    tf.nest.map_structure(_check_shape, stacked)
def testStackNestedTensorsAxis1(self):
    """Stacking with axis=1 inserts the new dimension second."""
    inner_shape = [5, 8]
    num_items = 3
    stack_axis = 1
    # [5, 8] stacked 3 times along axis 1 -> [5, 3, 8].
    expected_shape = (
        inner_shape[:stack_axis] + [num_items] + inner_shape[stack_axis:])
    specs = self.nest_spec(inner_shape, include_sparse=False)
    items = [self.zeros_from_spec(specs)] * num_items

    stacked = nest_utils.stack_nested_tensors(items, axis=stack_axis)

    tf.nest.assert_same_structure(specs, stacked)

    def _check_shape(tensor):
        self.assertEqual(tensor.shape, expected_shape)

    tf.nest.map_structure(_check_shape, stacked)
def setUp(self):
    """Builds `self._ts`: six batch-of-two trajectory steps.

    The step kinds follow boundary, first, last, boundary, first, last, with
    per-element rewards taken from the table below and discount 1.
    """
    super(BatchedPyMetricTest, self).setUp()

    # Order of args for trajectory methods:
    # observation, action, policy_info, reward, discount
    def batch_of_two(step_fn, reward_a, reward_b):
        # Stack two single steps of the same kind into a batch of size 2.
        return nest_utils.stack_nested_tensors([
            step_fn((), (), (), reward_a, 1.),
            step_fn((), (), (), reward_b, 1.),
        ])

    step_plan = [
        (trajectory.boundary, 0., 0.),
        (trajectory.first, 1., 2.),
        (trajectory.last, 3., 4.),
        (trajectory.boundary, 0., 0.),
        (trajectory.first, 5., 6.),
        (trajectory.last, 7., 8.),
    ]
    self._ts = [batch_of_two(fn, r_a, r_b) for fn, r_a, r_b in step_plan]
def _create_trajectories(self):
    """Returns six batch-of-two trajectory steps.

    The step kinds follow boundary, first, last, boundary, first, last, with
    per-element rewards taken from the table below and discount 1.
    """
    # Order of args for trajectory methods:
    # observation, action, policy_info, reward, discount
    def batch_of_two(step_fn, reward_a, reward_b):
        # Stack two single steps of the same kind into a batch of size 2.
        return nest_utils.stack_nested_tensors([
            step_fn((), (), (), reward_a, 1.),
            step_fn((), (), (), reward_b, 1.),
        ])

    step_plan = (
        (trajectory.boundary, 0., 0.),
        (trajectory.first, 1., 2.),
        (trajectory.last, 3., 4.),
        (trajectory.boundary, 0., 0.),
        (trajectory.first, 5., 6.),
        (trajectory.last, 7., 8.),
    )
    return [batch_of_two(fn, r_a, r_b) for fn, r_a, r_b in step_plan]
def day_preds(date_str, contract_str, c, model):
    """Rolls a `StockEnvBasic` episode with a fixed action and runs `model`.

    Every observation seen before the last time step is collected. The
    observation of the final (`is_last()`) step is not included. The batch
    is then passed to `model` with `is_training=False`.

    Returns:
      `(times, preds)`, where `times` is
      `range(c.stock_env['time_start_trade'], c.stock_env['time_end'])`.
      NOTE(review): it is not checked here that `len(times)` matches the
      number of predictions — confirm against the environment.
    """
    env = StockEnvBasic(date_str=date_str, contract_str=contract_str,
                        **c.stock_env)
    collected = []
    times = range(c.stock_env['time_start_trade'], c.stock_env['time_end'])
    step = env.reset()
    while not step.is_last():
        collected.append(step.observation)
        step = env.step(1)  # always the same fixed action (1)
    batched_obs = nest_utils.stack_nested_tensors(collected)
    preds = model(batched_obs, is_training=False)
    return times, preds
def testNumActions(self, dtype):
    """Random per-arm policy respects per-example num_actions and log-probs."""
    if not dtype.is_integer:
        self.skipTest('testNumActions only applies to integer dtypes')

    batch_size = 1000
    half = batch_size // 2
    # Action spec allows up to 4 arms (values 0..3).
    action_spec = tensor_spec.BoundedTensorSpec((), dtype, 0, 3)
    time_step_spec, two_arm_step = self.create_time_step(
        use_per_arm_features=True, num_arms=2)
    _, three_arm_step = self.create_time_step(
        use_per_arm_features=True, num_arms=3)
    # First half of the batch exposes 2 arms, second half exposes 3.
    time_step = nest_utils.stack_nested_tensors(
        [two_arm_step] * half + [three_arm_step] * half)

    # The chosen arm's features are reported through policy_info.
    arm_features_spec = (
        policy_utilities.create_chosen_arm_features_info_spec(
            time_step_spec.observation))
    info_spec = policy_utilities.PerArmPolicyInfo(
        chosen_arm_features=arm_features_spec)
    policy = random_tf_policy.RandomTFPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=info_spec,
        accepts_per_arm_features=True,
        emit_log_probability=True)

    action_step = policy.action(time_step)
    tf.nest.assert_same_structure(action_spec, action_step.action)

    # One draw for each of the 1000 batch entries; invalid arms must never
    # be picked.
    evaluated = self.evaluate(action_step)
    chosen = evaluated.action
    first_half, second_half = slice(None, half), slice(half, None)
    self.assertTrue(np.all(chosen >= 0))
    self.assertTrue(np.all(chosen[first_half] < 2))
    self.assertTrue(np.all(chosen[second_half] < 3))

    # Uniform over the valid arms -> log(1 / num_arms).
    log_probs = evaluated.info.log_probability
    self.assertAllClose(
        log_probs[first_half],
        tf.constant(np.log(1. / 2), shape=[half]))
    self.assertAllClose(
        log_probs[second_half],
        tf.constant(np.log(1. / 3), shape=[half]))
def get_step(self, batch_size: Optional[int] = None,
             num_steps: Optional[int] = None) -> EnvStep:
    """Returns one step, or `num_steps` consecutive steps stacked together.

    Args:
      batch_size: Must be None; batched sampling is not supported.
      num_steps: When None, a single unstacked step is returned. Otherwise
        this many steps are fetched in order and stacked along a new leading
        axis.

    Raises:
      ValueError: if `batch_size` is given.
    """
    if batch_size is not None:
        raise ValueError(
            'This dataset does not support batched step sampling.')
    if num_steps is None:
        return self._get_step()
    return nest_utils.stack_nested_tensors(
        [self._get_step() for _ in range(num_steps)])
def _resample_action_fn(self, resample_input): ac_mean, scale, step_type, *flat_observation = resample_input # expects single ac, obs, step n, k = self._n, self._k # samples "best" safe action out of 50 # sampled_ac = actions.sample(n) observation = tf.nest.pack_sequence_as(self.time_step_spec.observation, flat_observation) obs = nest_utils.stack_nested_tensors([observation for _ in range(n)]) actions = self._actor_network.output_spec.build_distribution( loc=ac_mean, scale=scale) sampled_ac = actions.sample(n) ac_outer_rank = nest_utils.get_outer_rank(sampled_ac, self.action_spec) ac_batch_squash = utils.BatchSquash(ac_outer_rank) sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten, sampled_ac) obs_outer_rank = nest_utils.get_outer_rank( obs, self.time_step_spec.observation) obs_batch_squash = utils.BatchSquash(obs_outer_rank) obs = tf.nest.map_structure(obs_batch_squash.flatten, obs) q_val, _ = self._safety_critic_network((obs, sampled_ac), step_type) fail_prob = tf.nn.sigmoid(q_val) safe_ac_mask = fail_prob < self._safety_threshold # pdb.set_trace() [_, safe_ac_mask, sampled_ac, fail_prob] = tf.while_loop( cond=resample_cond, body=self._loop_body_fn(ac_batch_squash, obs, step_type, ac_mean), loop_vars=[scale, safe_ac_mask, sampled_ac, fail_prob], maximum_iterations=k) sampled_ac = tf.nest.map_structure(ac_batch_squash.unflatten, sampled_ac) if self._resample_counter is not None: logging.debug('resampled %d times', self._resample_counter.result()) safe_ac_idx = tf.where(safe_ac_mask) fail_prob_safe = tf.gather(fail_prob, safe_ac_idx[:, 0]) safe_idx = self._get_safe_idx(safe_ac_mask, fail_prob, sampled_ac, safe_ac_idx, actions, fail_prob_safe) ac = sampled_ac[safe_idx] return ac
def _apply_actor_network(self, time_step, policy_state):
    """Runs the actor; for non-batched input, picks a critic-approved sample.

    If the leading dim of `time_step.step_type` is > 1, the actor's output
    distribution is returned directly. Otherwise 50 candidate actions are
    sampled and scored by `self._safety_critic_network`. A candidate counts
    as safe when sigmoid(q) < `self._safety_threshold`. If none are safe,
    the distribution scale is multiplied by 1.5 and resampling is repeated,
    up to 4 times. If still none are safe, the candidate with the lowest
    failure probability is returned. Otherwise, among the safe candidates,
    the one with the HIGHEST failure probability (argmax) is returned.

    NOTE(review): the critic is fed the raw `time_step.observation`, while
    the actor gets the normalized `observation` — confirm this is intended.
    """
    # A batch of size 1 takes the safe-sampling path below.
    has_batch_dim = time_step.step_type.shape.as_list()[0] > 1
    observation = time_step.observation
    if self._observation_normalizer:
        observation = self._observation_normalizer.normalize(observation)
    actions, policy_state = self._actor_network(observation,
                                                time_step.step_type,
                                                policy_state)
    if has_batch_dim:
        return actions, policy_state

    # samples "best" safe action out of 50
    sampled_ac = actions.sample(50)
    obs = nest_utils.stack_nested_tensors(
        [time_step.observation for _ in range(50)])
    q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                           time_step.step_type)
    fail_prob = tf.nn.sigmoid(q_val)
    safe_ac_mask = fail_prob < self._safety_threshold
    # NOTE(review): no squeeze here, but the loop below squeezes before
    # tf.where. If `fail_prob` has a trailing dim of 1, the index shapes
    # differ ([k, 2] here vs [k, 1] in the loop) — confirm.
    safe_ac_idx = tf.where(safe_ac_mask)
    resample_count = 0
    # Python-level loop: relies on static (eager) shapes.
    while resample_count < 4 and not safe_ac_idx.shape.as_list()[0]:
        if self._resample_metric is not None:
            self._resample_metric()
        resample_count += 1
        scale = actions.scale * 1.5  # increase variance by constant 1.5
        actions = self._actor_network.output_spec.build_distribution(
            loc=actions.loc, scale=scale)
        sampled_ac = actions.sample(50)
        q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                               time_step.step_type)
        fail_prob = tf.nn.sigmoid(q_val)
        safe_ac_idx = tf.where(
            tf.squeeze(fail_prob) < self._safety_threshold)
    if not safe_ac_idx.shape.as_list()[0]:  # return safest action
        safe_ac_idx = tf.math.argmin(fail_prob)
        return sampled_ac[safe_ac_idx], policy_state
    # NOTE(review): tf.gather treats every entry of `safe_ac_idx` as an
    # axis-0 index. If the index rows have more than one column (see note
    # above), the column entries are gathered as rows too;
    # `safe_ac_idx[:, 0]` or tf.gather_nd may have been intended — confirm.
    actions = tf.squeeze(tf.gather(sampled_ac, safe_ac_idx))
    fail_prob_safe = tf.gather(fail_prob, safe_ac_idx)
    # Most failure-prone candidate among those under the threshold.
    safe_idx = tf.math.argmax(fail_prob_safe)
    return actions[safe_idx], policy_state
def create_batch(self, single_time_step, batch_size):
    """Repeats `single_time_step` `batch_size` times and stacks the result."""
    return nest_utils.stack_nested_tensors(
        [single_time_step for _ in range(batch_size)])
def _apply_actor_network(self, time_step, policy_state):
    """Runs the actor; for non-batched input, picks a critic-scored sample.

    If the leading dim of `time_step.step_type` is > 1, the actor's output
    distribution is returned directly. Otherwise 50 candidate actions are
    sampled and scored by `self._safety_critic_network`. A candidate counts
    as safe when sigmoid(q) < `self._safety_threshold`.

    Only when `self._training` is truthy, resampling with the scale
    multiplied by 1.5 is attempted up to 4 times while no candidate is safe.

    If no safe candidate remains, the one with the lowest failure probability
    is returned. Otherwise, in training the safe candidate with the HIGHEST
    failure probability is returned, and in eval the one with the LOWEST.

    Raises:
      AssertionError: if the selected action's leading dim is not 1 (the
        check is skipped under `python -O`).
    """
    # A batch of size 1 takes the safe-sampling path below.
    has_batch_dim = time_step.step_type.shape.as_list()[0] > 1
    observation = time_step.observation
    if self._observation_normalizer:
        observation = self._observation_normalizer.normalize(observation)
    actions, policy_state = self._actor_network(observation,
                                                time_step.step_type,
                                                policy_state,
                                                training=self._training)
    if has_batch_dim:
        return actions, policy_state

    # samples "best" safe action out of 50
    sampled_ac = actions.sample(50)
    # NOTE(review): the critic gets the raw (unnormalized) observation.
    obs = nest_utils.stack_nested_tensors(
        [time_step.observation for _ in range(50)])
    # Collapse outer (sample/batch) dims so the critic sees a flat batch.
    obs_outer_rank = nest_utils.get_outer_rank(
        obs, self.time_step_spec.observation)
    ac_outer_rank = nest_utils.get_outer_rank(sampled_ac, self.action_spec)
    obs_batch_squash = utils.BatchSquash(obs_outer_rank)
    ac_batch_squash = utils.BatchSquash(ac_outer_rank)
    obs = tf.nest.map_structure(obs_batch_squash.flatten, obs)
    sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten, sampled_ac)
    q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                           time_step.step_type)
    fail_prob = tf.nn.sigmoid(q_val)
    safe_ac_mask = fail_prob < self._safety_threshold
    safe_ac_idx = tf.where(safe_ac_mask)
    resample_count = 0
    # NOTE(review): start_time is only read by the commented-out log line
    # after the loop.
    start_time = time.time()
    # Python-level loop: relies on static (eager) shapes; skipped in eval.
    while self._training and resample_count < 4 and not safe_ac_idx.shape.as_list()[0]:
        if self._resample_counter is not None:
            self._resample_counter()
        resample_count += 1
        if isinstance(actions, dist_utils.SquashToSpecNormal):
            # NOTE(review): scale comes from the inner (pre-squash)
            # distribution while the mean comes from the outer one — confirm
            # both are in the same space.
            scale = actions.input_distribution.scale * 1.5  # increase variance by constant 1.5
            ac_mean = actions.mean()
        else:
            scale = actions.scale * 1.5
            ac_mean = actions.mean()
        actions = self._actor_network.output_spec.build_distribution(
            loc=ac_mean, scale=scale)
        sampled_ac = actions.sample(50)
        sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten, sampled_ac)
        q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                               time_step.step_type)
        fail_prob = tf.nn.sigmoid(q_val)
        safe_ac_idx = tf.where(fail_prob < self._safety_threshold)
    # logging.debug('resampled {} times, {} seconds'.format(resample_count, time.time() - start_time))
    sampled_ac = ac_batch_squash.unflatten(sampled_ac)
    # No safe candidate (unknown or zero-size index set).
    if None in safe_ac_idx.shape.as_list() or not np.prod(
            safe_ac_idx.shape.as_list()):
        # return safest action
        safe_idx = tf.argmin(fail_prob)
    else:
        sampled_ac = tf.gather(sampled_ac, safe_ac_idx)
        fail_prob_safe = tf.gather(fail_prob, safe_ac_idx)
        if self._training:
            safe_idx = tf.argmax(fail_prob_safe)[
                0]  # picks most unsafe action out of "safe" options
        else:
            # Eval: the safest of the safe candidates.
            safe_idx = tf.argmin(fail_prob_safe)[0]
    ac = sampled_ac[safe_idx]
    assert ac.shape.as_list()[0] == 1, 'action shape is not correct: {}'.format(ac.shape.as_list())
    return ac, policy_state