def fit_gaussian(embeddings, damping=1e-7, full_covariance=False):
  """Fits a unimodal Gaussian distribution to `embeddings`.

  Args:
    embeddings: A [batch_size, embedding_dim] tf.Tensor of embeddings.
    damping: The scale of the covariance damping coefficient.
    full_covariance: Whether to use a full or diagonal covariance.

  Returns:
    Parameter estimates (means and log variances) for a Gaussian model.
  """
  if full_covariance:
    num, dim = tf.split(tf.shape(input=embeddings), num_or_size_splits=2)
    num, dim = tf.squeeze(num), tf.squeeze(dim)
    sample_mean = tf.reduce_mean(input_tensor=embeddings, axis=0)
    centered_embeddings = embeddings - sample_mean
    sample_covariance = tf.einsum('ij,ik->kj', centered_embeddings,
                                  centered_embeddings)  # Outer product.
    sample_covariance += damping * tf.eye(dim)  # Positive definiteness.
    sample_covariance /= tf.cast(num, dtype=tf.float32)  # Scale by N.
    return sample_mean, sample_covariance
  else:
    # `axes=[0]` computes per-dimension moments over the batch; the `axes`
    # argument is required by tf.nn.moments and was missing here.
    sample_mean, sample_variances = tf.nn.moments(x=embeddings, axes=[0])
    log_variances = tf.math.log(sample_variances +
                                damping * tf.ones_like(sample_variances))
    return sample_mean, log_variances

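
# A minimal usage sketch for `fit_gaussian`, assuming TF2 eager execution and
# that `tf` is TensorFlow; the tensor shapes and values are illustrative only.
def _fit_gaussian_example():
  embeddings = tf.random.normal([32, 8])  # [batch_size, embedding_dim]
  mean, log_variances = fit_gaussian(embeddings)  # Diagonal covariance.
  mean, covariance = fit_gaussian(embeddings, full_covariance=True)
  return mean, log_variances, covariance
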
def safety_critic_loss(tf_agent,
                       safety_critic,
                       time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       weights=None):
  """Returns a critic loss with safety."""
  next_actions, next_log_pis = tf_agent._actions_and_log_probs(  # pylint: disable=protected-access
      next_time_steps)
  del next_log_pis
  target_input = (next_time_steps.observation[0], next_actions[0])
  target_q_values, unused_network_state1 = safety_critic(
      target_input, next_time_steps.step_type[0])
  target_q_values = tf.nn.sigmoid(target_q_values)
  safety_rewards = tf.to_float(safety_rewards)
  td_targets = tf.stop_gradient(safety_rewards + (1 - safety_rewards) *
                                next_time_steps.discount * target_q_values)
  td_targets = tf.squeeze(td_targets)
  pred_input = (time_steps.observation[0], actions[0])
  pred_td_targets, unused_network_state1 = safety_critic(
      pred_input, time_steps.step_type[0])
  loss = tf.losses.sigmoid_cross_entropy(td_targets, pred_td_targets)
  if weights is not None:
    loss *= tf.to_float(tf.squeeze(weights))
  # Take the mean across the batch.
  loss = tf.reduce_mean(input_tensor=loss)
  return loss

def _network_adapter(self, states, scope):
  self._validate_states(states)

  with tf.compat.v1.name_scope('network'):
    q_value_list = []
    for slate in self._all_possible_slates:
      user = tf.squeeze(states[:, 0, :, :], axis=2)
      docs = []
      for i in slate:
        docs.append(tf.squeeze(states[:, i + 1, :, :], axis=2))
      q_value_list.append(self.network(user, tf.concat(docs, axis=1), scope))
    q_values = tf.concat(q_value_list, axis=1)

  return dqn_agent.DQNNetworkType(q_values)

def _compute_responsibilities(examples_, class_idx):
  train_predictions_ = tf.squeeze(
      self.head_fn(
          embeddings=examples_, components=True, class_idx=[class_idx]),
      axis=1)
  return tf.nn.softmax(train_predictions_, axis=-1)

def dqn_template(state, num_actions, layer_size=512, num_layers=1):
  r"""Builds a DQN Network mapping states to Q-values.

  Args:
    state: A `tf.placeholder` for the RL state.
    num_actions: int, number of actions that the RL agent can take.
    layer_size: int, number of hidden units per layer.
    num_layers: int, number of hidden layers.

  Returns:
    net: A `tf.Graphdef` for DQN:
      `\theta : \mathcal{X}\rightarrow\mathbb{R}^{|\mathcal{A}|}`
  """
  weights_initializer = slim.variance_scaling_initializer(
      factor=1.0 / np.sqrt(3.0), mode='FAN_IN', uniform=True)
  net = tf.cast(state, tf.float32)
  net = tf.squeeze(net, axis=2)
  for _ in range(num_layers):
    net = slim.fully_connected(net, layer_size, activation_fn=tf.nn.relu)
  net = slim.fully_connected(net, num_actions, activation_fn=None,
                             weights_initializer=weights_initializer)
  return net

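
# A usage sketch for `dqn_template`, assuming the TF1 graph-mode setup this
# snippet targets, with `slim = tf.contrib.slim` and `numpy as np` imported as
# elsewhere in this file. The placeholder shape [batch, state_dim, 1] matches
# the `tf.squeeze(net, axis=2)` above; the sizes are illustrative.
def _dqn_template_example():
  state_ph = tf.placeholder(tf.uint8, [None, 4, 1], name='state')
  q_values = dqn_template(state_ph, num_actions=6)  # -> [batch, 6]
  return q_values
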
def pick_labeled_image(mesh_inputs, view_image_inputs, view_indices_2d_inputs,
                       view_name):
  """Picks the image with the most labeled points projecting to it."""
  if view_name not in view_image_inputs:
    return
  if view_name not in view_indices_2d_inputs:
    return
  if standard_fields.InputDataFields.point_loss_weights not in mesh_inputs:
    raise ValueError(
        'The key `point_loss_weights` is missing from mesh_inputs.')
  height = tf.shape(view_image_inputs[view_name])[1]
  width = tf.shape(view_image_inputs[view_name])[2]
  valid_points_y = tf.logical_and(
      tf.greater_equal(view_indices_2d_inputs[view_name][:, :, 0], 0),
      tf.less(view_indices_2d_inputs[view_name][:, :, 0], height))
  valid_points_x = tf.logical_and(
      tf.greater_equal(view_indices_2d_inputs[view_name][:, :, 1], 0),
      tf.less(view_indices_2d_inputs[view_name][:, :, 1], width))
  valid_points = tf.logical_and(valid_points_y, valid_points_x)
  image_total_weights = tf.reduce_sum(
      tf.cast(valid_points, dtype=tf.float32) * tf.squeeze(
          mesh_inputs[standard_fields.InputDataFields.point_loss_weights],
          axis=1),
      axis=1)
  # Fall back to the raw count of valid points if no point carries weight.
  image_total_weights = tf.cond(
      tf.equal(tf.reduce_sum(image_total_weights), 0),
      lambda: tf.reduce_sum(tf.cast(valid_points, dtype=tf.float32), axis=1),
      lambda: image_total_weights)
  best_image = tf.math.argmax(image_total_weights)
  view_image_inputs[view_name] = view_image_inputs[view_name][
      best_image:best_image + 1, :, :, :]
  view_indices_2d_inputs[view_name] = view_indices_2d_inputs[view_name][
      best_image:best_image + 1, :, :]

def _mine(self, x_in, y_in):
  """Mutual Information Neural Estimator.

  Implements the mutual information neural estimator from

  Belghazi et al. "Mutual Information Neural Estimation"
  http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf

  'DV': sup_T E_P(T) - log E_Q(exp(T))

  where P is the joint distribution of X and Y, and Q is the product marginal
  distribution of P. DV is a lower bound for KLD(P||Q) = MI(X, Y).
  """
  y_in_tran = transpose2(y_in, 1, 0)
  y_shuffle_tran = math_ops.shuffle(y_in_tran)
  y_shuffle = transpose2(y_shuffle_tran, 1, 0)

  # Propagate the forward pass.
  T_xy, _ = self._network([x_in, y_in])
  T_x_y, _ = self._network([x_in, y_shuffle])

  # Compute the negative loss (maximize loss == minimize -loss).
  mean_exp_T_x_y = tf.reduce_mean(tf.math.exp(T_x_y), axis=1)
  loss = tf.reduce_mean(T_xy, axis=1) - tf.math.log(mean_exp_T_x_y)
  loss = tf.squeeze(loss, axis=-1)  # Mutual information estimate.
  return loss

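
# A standalone sketch of the DV lower bound that `_mine` evaluates, assuming
# TF2-style eager execution; `t_joint` and `t_marginal` stand for T(x, y)
# scores on joint samples and on shuffled (product-of-marginals) samples.
def dv_lower_bound(t_joint, t_marginal):
  """DV bound: E_P[T] - log E_Q[exp(T)], a lower bound on MI(X, Y)."""
  return (tf.reduce_mean(t_joint) -
          tf.math.log(tf.reduce_mean(tf.math.exp(t_marginal))))
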
def rainbow_template(state,
                     num_actions,
                     num_atoms=51,
                     layer_size=512,
                     num_layers=2):  # FIXME(Aron, 3/14/19): changed from 1 to 2.
  r"""Builds a Rainbow Network mapping states to value distributions.

  Args:
    state: A `tf.placeholder` for the RL state.
    num_actions: int, number of actions that the RL agent can take.
    num_atoms: int, number of atoms to approximate the distribution with.
    layer_size: int, number of hidden units per layer.
    num_layers: int, number of hidden layers.

  Returns:
    net: A `tf.Graphdef` for Rainbow:
      `\theta : \mathcal{X}\rightarrow\mathbb{R}^{|\mathcal{A}| \times N}`,
      where `N` is num_atoms.
  """
  weights_initializer = slim.variance_scaling_initializer(
      factor=1.0 / np.sqrt(3.0), mode='FAN_IN', uniform=True)
  net = tf.cast(state, tf.float32)
  net = tf.squeeze(net, axis=2)
  for _ in range(num_layers):
    net = slim.fully_connected(net, layer_size, activation_fn=tf.nn.relu)
  net = slim.fully_connected(net, num_actions * num_atoms, activation_fn=None,
                             weights_initializer=weights_initializer)
  net = tf.reshape(net, [-1, num_actions, num_atoms])
  return net

def compute_pointcloud_weights_based_on_voxel_density(points, grid_cell_size):
  """Computes pointcloud weights based on voxel density.

  Args:
    points: A tf.float32 tensor of size [num_points, 3].
    grid_cell_size: The size of the grid cells in x, y, z dimensions in the
      voxel grid. It should be either a tf.float32 tensor, a numpy array or a
      list of size [3].

  Returns:
    A tf.float32 tensor of size [num_points, 1] containing weights that are
      inversely proportional to the density of the points in voxels.
  """
  num_points = tf.shape(points)[0]
  features = tf.ones([num_points, 1], dtype=tf.float32)
  voxel_features, _, segment_ids, _ = (
      pointcloud_to_sparse_voxel_grid_unbatched(
          points=points,
          features=features,
          grid_cell_size=grid_cell_size,
          segment_func=tf.math.unsorted_segment_sum))
  num_voxels = tf.shape(voxel_features)[0]
  point_features = sparse_voxel_grid_to_pointcloud(
      voxel_features=tf.expand_dims(voxel_features, axis=0),
      segment_ids=tf.expand_dims(segment_ids, axis=0),
      num_valid_voxels=tf.expand_dims(num_voxels, axis=0),
      num_valid_points=tf.expand_dims(num_points, axis=0))
  inverse_point_densities = 1.0 / tf.squeeze(point_features, axis=0)
  total_inverse_density = tf.reduce_sum(inverse_point_densities)
  return (inverse_point_densities * tf.cast(num_points, dtype=tf.float32) /
          total_inverse_density)

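
# A usage sketch, assuming the voxel helpers referenced above are importable
# from this codebase; the point count and cell size are illustrative. By
# construction the returned weights sum to num_points, so points in dense
# voxels are down-weighted.
def _voxel_density_weights_example():
  points = tf.random.uniform([1000, 3], minval=0.0, maxval=10.0)
  weights = compute_pointcloud_weights_based_on_voxel_density(
      points=points, grid_cell_size=[0.5, 0.5, 0.5])  # -> [1000, 1]
  return weights
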
def quantile_loss(y, y_hat, k=4):
  """Pinball (quantile) loss averaged over `k` evenly spaced quantiles."""
  quantiles = np.linspace(0., 1., k)
  loss = 0.
  y = tf.squeeze(y, axis=2)
  for idx, q in enumerate(quantiles):
    error = tf.subtract(y, y_hat[:, :, idx])
    # Pinball loss: the original divided by `error` in the second branch,
    # which is a bug; the (q - 1) term must scale the error.
    loss += tf.reduce_mean(tf.maximum(q * error, (q - 1) * error), axis=-1)
  return tf.reduce_mean(loss)

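
# A usage sketch for `quantile_loss`, assuming TF2 eager execution: `y`
# carries a trailing singleton channel and `y_hat` has one channel per
# quantile. The shapes are illustrative.
def _quantile_loss_example():
  y = tf.random.normal([8, 24, 1])      # [batch, time, 1]
  y_hat = tf.random.normal([8, 24, 4])  # [batch, time, num_quantiles]
  return quantile_loss(y, y_hat, k=4)   # Scalar loss.
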
def _network_adapter(self, states, scope):
  self._validate_states(states)

  with tf.name_scope('network'):
    # Since we decompose the slate optimization into an item-level
    # optimization problem, the observation space is the user state
    # observation plus all documents' observations. In the Dopamine DQN agent
    # implementation, there is one head for each possible action value, which
    # is designed for computing the argmax operation in the action space.
    # In our implementation, we generate one output for each document.
    q_value_list = []
    for i in range(self._num_candidates):
      user = tf.squeeze(states[:, 0, :, :], axis=2)
      doc = tf.squeeze(states[:, i + 1, :, :], axis=2)
      q_value_list.append(self.network(user, doc, scope))
    q_values = tf.concat(q_value_list, axis=1)

  return dqn_agent.DQNNetworkType(q_values)

def plot_map(img, act_map, pred_prob, pred_label, true_label, name):
  """Plots an image beside its activation-map overlay and saves the figure."""
  pred_label = 'RDR' if pred_label else 'NRDR'
  true_label = 'RDR' if true_label else 'NRDR'
  img = tf.squeeze(img, axis=0)
  act_map = tf.squeeze(act_map, axis=0)
  fig, axes = plt.subplots(1, 2, figsize=(14, 5))
  axes[0].imshow(img)
  axes[1].imshow(img)
  i = axes[1].imshow(act_map, cmap='jet', alpha=0.5)
  fig.colorbar(i)
  plt.suptitle('File {} Pr(class={})= {:5.2f} Ground truth {}'.format(
      name.numpy().decode('utf-8'), pred_label, pred_prob[0], true_label))
  plt.savefig('{}_detection.png'.format(name.numpy().decode('utf-8')))
  plt.show()

def histogram_loss(y, y_hat, k=100, sigma=1 / 2):
  """Histogram loss. Not implemented; the body below is kept for reference."""
  raise NotImplementedError()
  # The code below is unreachable until the implementation is finished. Note
  # `k` must be an int for np.linspace (the original default was the float
  # 100.), and the bin grid is renamed so it no longer shadows `k`.
  ps = 0.
  w = 1 / k
  y = tf.squeeze(y, axis=2)
  # y_hat = tf.layers.flatten(y_hat)
  bins = np.linspace(0., 1., k)
  s = (tf.erf((1. - y) / (tf.sqrt(2.) * sigma)) -
       tf.erf((0. - y) / (tf.sqrt(2.) * sigma)))
  for idx, j in enumerate(bins):
    u = tf.erf((j + w - y) / (tf.sqrt(2.) * sigma))
    l = tf.erf((j - y) / (tf.sqrt(2.) * sigma))
    p = (u - l) / (2 * s + 1e-6)
    f_x = tf.log(y_hat[:, :, idx])
    ps += p * tf.where(tf.is_nan(f_x), tf.zeros_like(f_x), f_x)
  return tf.reduce_mean(-ps)

def tf_random_choice(inputs, n_samples):
  """Samples `n_samples` rows from `inputs`, with replacement.

  Params:
    inputs (Tensor): Shape [n_states, n_features].
    n_samples (int): The number of random samples to take.

  Returns:
    sampled_inputs (Tensor): Shape [n_samples, n_features].
  """
  # (1, n_states) since multinomial requires 2D logits.
  uniform_log_prob = tf.expand_dims(tf.zeros(tf.shape(inputs)[0]), 0)
  ind = tf.multinomial(uniform_log_prob, n_samples)
  ind = tf.squeeze(ind, 0, name='random_choice_ind')  # (n_samples,)
  return tf.gather(inputs, ind, name='random_choice')

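
# A usage sketch, assuming the TF1 API used above (`tf.multinomial` was
# renamed `tf.random.categorical` in TF2); the inputs are illustrative.
def _tf_random_choice_example():
  states = tf.constant([[0., 1.], [2., 3.], [4., 5.]])
  return tf_random_choice(states, n_samples=10)  # [10, 2], with replacement.
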
def apply_ff(
    inputs: tf.Tensor,
    hid_sizes: Iterable[int],
    name: Optional[str] = None,
) -> tf.Tensor:
  """Applies a feed-forward network to the inputs."""
  xavier = tf.contrib.layers.xavier_initializer
  x = inputs
  for i, size in enumerate(hid_sizes):
    x = tf.layers.dense(x, size, activation='relu',
                        kernel_initializer=xavier(), name='dense' + str(i))
  x = tf.layers.dense(x, 1, kernel_initializer=xavier(), name='dense_final')
  return tf.squeeze(x, axis=1, name=name)

def preprocess_spatial_observation(input_obs, spec,
                                   categorical_embedding_dims=16,
                                   non_categorical_scaling='log'):
  with tf.name_scope('preprocess_spatial_obs'):
    features = Lambda(
        lambda x: tf.split(x, x.get_shape()[1], axis=1))(input_obs)

    for f in spec.features:
      if f.is_categorical:
        features[f.index] = Lambda(
            lambda x: tf.squeeze(x, axis=1))(features[f.index])
        features[f.index] = Embedding(
            f.scale, categorical_embedding_dims)(features[f.index])
        features[f.index] = Permute((3, 1, 2))(features[f.index])
      else:
        if non_categorical_scaling == 'log':
          features[f.index] = Lambda(
              lambda x: tf.log(x + 1e-10))(features[f.index])
        elif non_categorical_scaling == 'normalize':
          features[f.index] = Lambda(lambda x: x / f.scale)(features[f.index])

  return features

def compute_logits(self, support_embeddings, query_embeddings,
                   onehot_support_labels):
  """Computes the class logits for the episode.

  Args:
    support_embeddings: A Tensor of size [num_support_images, embedding dim].
    query_embeddings: A Tensor of size [num_query_images, embedding dim].
    onehot_support_labels: A Tensor of size [batch size, way].

  Returns:
    The query set logits as a [num_query_images, way] matrix.

  Raises:
    ValueError: Distance must be one of l2 or cosine.
  """
  if self.knn_in_fc:
    # Recompute the support and query embeddings that were originally
    # computed in self.forward_pass() to be the fc layer activations.
    support_embeddings = self.forward_pass_fc(support_embeddings)
    query_embeddings = self.forward_pass_fc(query_embeddings)

  # ------------------------ K-NN look up -------------------------------
  # For each testing example in an episode, we use its embedding
  # vector to look for the closest neighbor in all the training examples'
  # embeddings from the same episode and then assign the training example's
  # class label to the testing example as the predicted class label for it.
  if self.distance == 'l2':
    # [1, num_support, embed_dims]
    support_embeddings = tf.expand_dims(support_embeddings, axis=0)
    # [num_query, 1, embed_dims]
    query_embeddings = tf.expand_dims(query_embeddings, axis=1)
    # [num_query, num_support]
    distance = tf.norm(query_embeddings - support_embeddings, axis=2)
  elif self.distance == 'cosine':
    support_embeddings = tf.nn.l2_normalize(support_embeddings, axis=1)
    query_embeddings = tf.nn.l2_normalize(query_embeddings, axis=1)
    distance = -1 * tf.matmul(
        query_embeddings, support_embeddings, transpose_b=True)
  else:
    raise ValueError('Distance must be one of l2 or cosine.')
  # [num_query]
  _, indices = tf.nn.top_k(-distance, k=1)
  indices = tf.squeeze(indices, axis=1)
  # [num_query, num_classes]
  query_logits = tf.gather(onehot_support_labels, indices)
  return query_logits

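
# A standalone sketch of the l2 branch of the 1-NN lookup above, assuming TF2
# eager execution; the episode sizes and the way the support labels are built
# here are illustrative.
def _knn_lookup_example():
  support = tf.random.normal([10, 16])  # [num_support, embed_dims]
  query = tf.random.normal([5, 16])     # [num_query, embed_dims]
  onehot_support_labels = tf.one_hot(tf.range(10) % 5, depth=5)
  distance = tf.norm(
      query[:, tf.newaxis, :] - support[tf.newaxis, :, :], axis=2)
  nearest = tf.argmin(distance, axis=1)               # [num_query]
  return tf.gather(onehot_support_labels, nearest)    # [num_query, way]
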
def _build_train_op(self):
  """Builds a training op.

  Returns:
    An op performing one step of training from replay data.
  """
  # click_indicator: [B, S]
  # q_values: [B, A]
  # actions: [B, S]
  # slate_q_values: [B, S]
  # replay_click_q: [B]
  click_indicator = self._replay.rewards[:, :, self._click_response_index]
  slate_q_values = tf.compat.v1.batch_gather(
      self._replay_net_outputs.q_values,
      tf.cast(self._replay.actions, dtype=tf.int32))
  # Only get the Q from the clicked document.
  replay_click_q = tf.reduce_sum(
      input_tensor=slate_q_values * click_indicator,
      axis=1,
      name='replay_click_q')

  target = tf.stop_gradient(self._build_target_q_op())

  clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
  clicked_indices = tf.squeeze(
      tf.compat.v1.where(tf.equal(clicked, 1)), axis=1)
  # clicked_indices is a vector and tf.gather selects the batch dimension.
  q_clicked = tf.gather(replay_click_q, clicked_indices)
  target_clicked = tf.gather(target, clicked_indices)

  def get_train_op():
    loss = tf.reduce_mean(input_tensor=tf.square(q_clicked - target_clicked))
    if self.summary_writer is not None:
      with tf.compat.v1.variable_scope('Losses'):
        tf.compat.v1.summary.scalar('Loss', loss)
    return loss

  loss = tf.cond(
      pred=tf.greater(tf.reduce_sum(input_tensor=clicked), 0),
      true_fn=get_train_op,
      false_fn=lambda: tf.constant(0.),
      name='')

  return self.optimizer.minimize(loss)

def rollout(self, time_step: ActionTimeStep, state=None):
  observation = self._encode(time_step)

  value, value_state = self._value_network(
      observation,
      step_type=time_step.step_type,
      network_state=state.value_state)
  # ValueRnnNetwork will add a time dim to value.
  # See value_rnn_network.py L153.
  if isinstance(self._value_network, ValueRnnNetwork):
    value = tf.squeeze(value, axis=1)

  action_distribution, actor_state = self._actor_network(
      observation,
      step_type=time_step.step_type,
      network_state=state.actor_state)

  info = ActorCriticInfo(
      value=value, icm_reward=(), icm_info=(), entropy_target_info=())
  if self._icm is not None:
    icm_step = self._icm.train_step(
        (observation, time_step.prev_action), state=state.icm_state)
    info = info._replace(icm_reward=icm_step.outputs, icm_info=icm_step.info)
    icm_state = icm_step.state
  else:
    icm_state = ()

  if self._entropy_target_algorithm:
    et_step = self._entropy_target_algorithm.train_step(action_distribution)
    info = info._replace(entropy_target_info=et_step.info)

  state = ActorCriticState(
      actor_state=actor_state, value_state=value_state, icm_state=icm_state)
  return PolicyStep(action=action_distribution, state=state, info=info)

def one_hot_encode(x, scale):
  x = tf.squeeze(x, axis=1)
  x = tf.cast(x, tf.int32)
  return tf.one_hot(x, scale, axis=1)

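
# A usage sketch: `one_hot_encode` turns a [batch, 1] column of category ids
# into [batch, scale] one-hot rows; the values are illustrative.
def _one_hot_encode_example():
  x = tf.constant([[2], [0], [1]], dtype=tf.float32)
  return one_hot_encode(x, scale=4)  # -> shape [3, 4].
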
def value_output(state, activation='linear'):
  x = Dense(
      1,
      activation=activation,
      kernel_initializer=tf.keras.initializers.Orthogonal(gain=0.1))(state)
  return tf.squeeze(x)

def value_output(state, activation='linear'):
  out = Dense(1, activation=activation)(state)
  return tf.squeeze(out)

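
# A usage sketch for either `value_output` variant above, assuming `Dense` is
# tf.keras.layers.Dense and `state` is a [batch, hidden] tensor; the hidden
# size is illustrative.
def _value_output_example():
  state = tf.random.normal([16, 64])
  return value_output(state)  # -> shape [16].
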
def train_eval(
    load_root_dir,
    env_load_fn=None,
    gym_env_wrappers=[],
    monitor=False,
    env_name=None,
    agent_class=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    safety_critic_joint_fc_layers=None,
    safety_critic_lr=3e-4,
    safety_critic_bias_init_val=None,
    safety_critic_kernel_scale=None,
    n_envs=None,
    target_safety=0.2,
    fail_weight=None,
    # Params for train
    num_global_steps=10000,
    batch_size=256,
    # Params for eval
    run_eval=False,
    eval_metrics=[],
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    debug_summaries=False,
    seed=None):
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, (
        'trainer.train_eval: agent_class {} invalid'.format(agent_class))
    agent_class = ALGOS.get(agent_class)

  train_ckpt_dir = osp.join(load_root_dir, 'train')
  rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

  py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
  tf_env = tf_py_environment.TFPyEnvironment(py_env)

  if monitor:
    vid_path = os.path.join(load_root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly.
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  if run_eval:
    eval_dir = os.path.join(load_root_dir, 'eval')
    n_envs = n_envs or num_eval_episodes
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            prefix='EvalMetrics',
            buffer_size=num_eval_episodes,
            batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(
            prefix='EvalMetrics',
            buffer_size=num_eval_episodes,
            batch_size=n_envs)
    ] + [
        tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
        for m in eval_metrics
    ]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([
            lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
        ] * n_envs))
    if seed:
      seeds = [seed * n_envs + i for i in range(n_envs)]
      try:
        eval_tf_env.pyenv.seed(seeds)
      except:  # pylint: disable=bare-except
        pass

  global_step = tf.compat.v1.train.get_or_create_global_step()

  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=agents.normal_projection_net)

  critic_net = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=critic_joint_fc_layers)

  if agent_class in SAFETY_AGENTS:
    safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
    tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=False)
  else:
    tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=False)

  collect_data_spec = tf_agent.collect_data_spec
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec, batch_size=1, max_length=1000000)
  replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

  tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
  if agent_class in SAFETY_AGENTS:
    target_safety = target_safety or tf_agent._target_safety  # pylint: disable=protected-access
  loaded_train_steps = global_step.numpy()
  logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
               loaded_train_steps)
  global_step.assign(0)
  tf.summary.experimental.set_step(global_step)

  thresholds = [target_safety, 0.5]
  sc_metrics = [
      tf.keras.metrics.AUC(name='safety_critic_auc'),
      tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                      threshold=0.5),
      tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                     thresholds=thresholds),
      tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                      thresholds=thresholds),
      tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                     thresholds=thresholds),
      tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                      thresholds=thresholds)
  ]

  if seed:
    tf.compat.v1.set_random_seed(seed)

  summaries_flush_secs = 10
  timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
  offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
  config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                           summarize_config=True)
  tf.function(config_saver.after_create_session)()

  sc_summary_writer = tf.compat.v2.summary.create_file_writer(
      offline_train_dir, flush_millis=summaries_flush_secs * 1000)
  sc_summary_writer.set_as_default()

  if safety_critic_kernel_scale is not None:
    ki = tf.compat.v1.variance_scaling_initializer(
        scale=safety_critic_kernel_scale,
        mode='fan_in',
        distribution='truncated_normal')
  else:
    ki = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')

  if safety_critic_bias_init_val is not None:
    bi = tf.constant_initializer(safety_critic_bias_init_val)
  else:
    bi = None
  sc_net_off = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=safety_critic_joint_fc_layers,
      kernel_initializer=ki,
      value_bias_initializer=bi,
      name='SafetyCriticOffline')
  sc_net_off.create_variables()
  target_sc_net_off = common.maybe_copy_target_network_with_checks(
      sc_net_off, None, 'TargetSafetyCriticNetwork')
  optimizer = tf.keras.optimizers.Adam(safety_critic_lr)

  sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
  sc_checkpointer = common.Checkpointer(
      ckpt_dir=sc_net_off_ckpt_dir,
      safety_critic=sc_net_off,
      target_safety_critic=target_sc_net_off,
      optimizer=optimizer,
      global_step=global_step,
      max_to_keep=5)
  sc_checkpointer.initialize_or_restore()

  resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
  eval_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_net,
      safety_critic_network=sc_net_off,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=True)

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3,
      num_steps=2,
      sample_batch_size=batch_size // 2).prefetch(3)
  data = iter(dataset)
  full_data = replay_buffer.gather_all()

  fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
  fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
  init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
  before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
  after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
  before_fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
  after_init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

  filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
  filter_mask = tf.pad(
      filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])  # pylint: disable=protected-access
  n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

  failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=n_failures,
      dataset_window_shift=1)
  data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

  sc_dataset_neg = failure_buffer.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=batch_size // 2,
      num_steps=2).prefetch(3)
  neg_data = iter(sc_dataset_neg)

  get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]  # pylint: disable=protected-access
  eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                              after_init_step, get_action)

  losses = []
  mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
  target_update = train_utils.get_target_updater(sc_net_off,
                                                 target_sc_net_off)

  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    while global_step.numpy() < num_global_steps:
      pos_experience, _ = next(data)
      neg_experience, _ = next(neg_data)
      exp = data_utils.concat_batches(pos_experience, neg_experience,
                                      collect_data_spec)
      boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
      exp = nest_utils.fast_map_structure(
          lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
      safe_rew = exp.observation['task_agn_rew'][:, 1]
      if fail_weight:
        weights = tf.where(
            tf.cast(safe_rew, tf.bool), fail_weight / 0.5,
            (1 - fail_weight) / 0.5)
      else:
        weights = None
      train_loss, sc_loss, lam_loss = train_step(
          exp,
          safe_rew,
          tf_agent,
          sc_net=sc_net_off,
          target_sc_net=target_sc_net_off,
          metrics=sc_metrics,
          weights=weights,
          target_safety=target_safety,
          optimizer=optimizer,
          target_update=target_update,
          debug_summaries=debug_summaries)
      global_step.assign_add(1)
      global_step_val = global_step.numpy()
      losses.append((train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
      mean_loss(train_loss)
      with tf.name_scope('Losses'):
        tf.compat.v2.summary.scalar(
            name='sc_loss', data=sc_loss, step=global_step_val)
        tf.compat.v2.summary.scalar(
            name='lam_loss', data=lam_loss, step=global_step_val)
        if global_step_val % summary_interval == 0:
          tf.compat.v2.summary.scalar(
              name=mean_loss.name,
              data=mean_loss.result(),
              step=global_step_val)

      if global_step_val % summary_interval == 0:
        with tf.name_scope('Metrics'):
          for metric in sc_metrics:
            if len(tf.squeeze(metric.result()).shape) == 0:  # pylint: disable=g-explicit-length-test
              tf.compat.v2.summary.scalar(
                  name=metric.name,
                  data=metric.result(),
                  step=global_step_val)
            else:
              fmt_str = '_{}'.format(thresholds[0])
              tf.compat.v2.summary.scalar(
                  name=metric.name + fmt_str,
                  data=metric.result()[0],
                  step=global_step_val)
              fmt_str = '_{}'.format(thresholds[1])
              tf.compat.v2.summary.scalar(
                  name=metric.name + fmt_str,
                  data=metric.result()[1],
                  step=global_step_val)
            metric.reset_states()

      if global_step_val % eval_interval == 0:
        eval_sc(sc_net_off, step=global_step_val)
        if run_eval:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix='EvalMetrics',
          )
          if train_metrics_callback is not None:
            train_metrics_callback(results, global_step_val)
          metric_utils.log_metrics(eval_metrics)
          with eval_summary_writer.as_default():
            for eval_metric in eval_metrics[2:]:
              eval_metric.tf_summaries(
                  train_step=global_step, step_metrics=eval_metrics[:2])

      if monitor and global_step_val % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step,
                                              monitor_policy_state)
          action, monitor_policy_state = (monitor_action.action,
                                          monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug(
            'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
            global_step_val, ep_len, time.time() - monitor_start)

      if global_step_val % train_checkpoint_interval == 0:
        sc_checkpointer.save(global_step=global_step_val)

def on_predict_batch_end(self, batch, logs=None):
  """Writes mesh summaries of groundtruth and predicted point clouds.

  Called at the end of each validation batch.
  """
  inputs = logs['inputs']
  outputs = logs['outputs']
  if self._metric:
    for metric in self._metric:
      metric.update_state(inputs=inputs, outputs=outputs)

  if batch <= self.num_qualitative_examples:
    # Point cloud visualization.
    vertices = tf.reshape(
        inputs[standard_fields.InputDataFields.point_positions], [-1, 3])
    num_valid_points = tf.squeeze(
        inputs[standard_fields.InputDataFields.num_valid_points])
    logits = outputs[
        standard_fields.DetectionResultFields.object_semantic_points]
    num_classes = logits.get_shape().as_list()[-1]
    logits = tf.reshape(logits, [-1, num_classes])
    gt_semantic_class = tf.reshape(
        inputs[standard_fields.InputDataFields.object_class_points], [-1])
    vertices = vertices[:num_valid_points, :]
    logits = logits[:num_valid_points, :]
    gt_semantic_class = gt_semantic_class[:num_valid_points]
    max_num_points = tf.math.minimum(self.max_num_points_qualitative,
                                     num_valid_points)
    sample_indices = tf.random.shuffle(
        tf.range(num_valid_points))[:max_num_points]
    vertices = tf.gather(vertices, sample_indices)
    logits = tf.gather(logits, sample_indices)
    gt_semantic_class = tf.gather(gt_semantic_class, sample_indices)
    semantic_class = tf.math.argmax(logits, axis=1)
    pred_colors = tf.gather(self._pascal_color_map, semantic_class, axis=0)
    gt_colors = tf.gather(self._pascal_color_map, gt_semantic_class, axis=0)
    if standard_fields.InputDataFields.point_colors in inputs:
      point_colors = (tf.reshape(
          inputs[standard_fields.InputDataFields.point_colors], [-1, 3]) +
                      1.0) * 255.0 / 2.0
      point_colors = point_colors[:num_valid_points, :]
      point_colors = tf.gather(point_colors, sample_indices)
      point_colors = tf.math.minimum(point_colors, 255.0)
      point_colors = tf.math.maximum(point_colors, 0.0)
      point_colors = tf.cast(point_colors, dtype=tf.uint8)
    else:
      point_colors = tf.ones_like(vertices, dtype=tf.uint8) * 128

    # Add points and colors for predicted objects.
    if standard_fields.DetectionResultFields.objects_length in outputs:
      box_corners = box_utils.get_box_corners_3d(
          boxes_length=outputs[
              standard_fields.DetectionResultFields.objects_length],
          boxes_height=outputs[
              standard_fields.DetectionResultFields.objects_height],
          boxes_width=outputs[
              standard_fields.DetectionResultFields.objects_width],
          boxes_rotation_matrix=outputs[
              standard_fields.DetectionResultFields.objects_rotation_matrix],
          boxes_center=outputs[
              standard_fields.DetectionResultFields.objects_center])
      box_points = box_utils.get_box_as_dotted_lines(box_corners)
      objects_class = tf.reshape(
          outputs[standard_fields.DetectionResultFields.objects_class], [-1])
      box_colors = tf.gather(self._pascal_color_map, objects_class, axis=0)
      box_colors = tf.repeat(
          box_colors[:, tf.newaxis, :], box_points.shape[1], axis=1)
      box_points = tf.reshape(box_points, [-1, 3])
      box_colors = tf.reshape(box_colors, [-1, 3])
      pred_vertices = tf.concat([vertices, box_points], axis=0)
      pred_colors = tf.concat([pred_colors, box_colors], axis=0)
    else:
      pred_vertices = vertices

    # Add points and colors for groundtruth objects.
    if standard_fields.InputDataFields.objects_length in inputs:
      box_corners = box_utils.get_box_corners_3d(
          boxes_length=tf.reshape(
              inputs[standard_fields.InputDataFields.objects_length],
              [-1, 1]),
          boxes_height=tf.reshape(
              inputs[standard_fields.InputDataFields.objects_height],
              [-1, 1]),
          boxes_width=tf.reshape(
              inputs[standard_fields.InputDataFields.objects_width], [-1, 1]),
          boxes_rotation_matrix=tf.reshape(
              inputs[standard_fields.InputDataFields.objects_rotation_matrix],
              [-1, 3, 3]),
          boxes_center=tf.reshape(
              inputs[standard_fields.InputDataFields.objects_center],
              [-1, 3]))
      box_points = box_utils.get_box_as_dotted_lines(box_corners)
      objects_class = tf.reshape(
          inputs[standard_fields.InputDataFields.objects_class], [-1])
      box_colors = tf.gather(self._pascal_color_map, objects_class, axis=0)
      box_colors = tf.repeat(
          box_colors[:, tf.newaxis, :], box_points.shape[1], axis=1)
      box_points = tf.reshape(box_points, [-1, 3])
      box_colors = tf.reshape(box_colors, [-1, 3])
      gt_vertices = tf.concat([vertices, box_points], axis=0)
      gt_colors = tf.concat([gt_colors, box_colors], axis=0)
    else:
      gt_vertices = vertices

    if batch == 1:
      logging.info('writing point cloud (shape %s) to summary.',
                   gt_vertices.shape)
    if standard_fields.InputDataFields.camera_image_name in inputs:
      camera_image_name = str(inputs[
          standard_fields.InputDataFields.camera_image_name].numpy()[0])
    else:
      camera_image_name = str(batch)
    logging.info(camera_image_name)
    with self._val_mesh_writer.as_default():
      mesh_summary.mesh(
          name=(self.split + '_points/' + camera_image_name),
          vertices=tf.expand_dims(vertices, axis=0),
          faces=None,
          colors=tf.expand_dims(point_colors, axis=0),
          config_dict=self._mesh_config_dict,
          step=self._val_step,
      )
      mesh_summary.mesh(
          name=(self.split + '_predictions/' + camera_image_name),
          vertices=tf.expand_dims(pred_vertices, axis=0),
          faces=None,
          colors=tf.expand_dims(pred_colors, axis=0),
          config_dict=self._mesh_config_dict,
          step=self._val_step,
      )
      mesh_summary.mesh(
          name=(self.split + '_ground_truth/' + camera_image_name),
          vertices=tf.expand_dims(gt_vertices, axis=0),
          faces=None,
          colors=tf.expand_dims(gt_colors, axis=0),
          config_dict=self._mesh_config_dict,
          step=self._val_step,
      )
    if batch == self.num_qualitative_examples:
      self._val_mesh_writer.flush()

def fit_gaussian_mixture(embeddings,
                         responsibilities,
                         damping=1e-7,
                         full_covariance=False):
  """Fits a Gaussian mixture model to `embeddings`.

  Args:
    embeddings: A [batch_size, embedding_dim] tf.Tensor of embeddings.
    responsibilities: The per-component responsibilities.
    damping: The scale of the covariance damping coefficient.
    full_covariance: Whether to use a full or diagonal covariance.

  Returns:
    Parameter estimates for a Gaussian mixture model.
  """
  num, dim = tf.split(tf.shape(input=embeddings), num_or_size_splits=2)
  num, dim = tf.squeeze(num), tf.squeeze(dim)
  num_classes = responsibilities.shape[1]

  mixing_proportion = tf.einsum('jk->k', responsibilities)
  mixing_proportion /= tf.cast(num, dtype=tf.float32)
  mixing_logits = tf.math.log(mixing_proportion)

  sample_mean = tf.einsum('ij,ik->jk', responsibilities, embeddings)
  sample_mean /= tf.reduce_sum(
      input_tensor=responsibilities, axis=0)[:, tf.newaxis]
  centered_embeddings = (
      embeddings[:, tf.newaxis, :] - sample_mean[tf.newaxis, :, :])

  if full_covariance:
    sample_covariance = tf.einsum('ijk,ijl->ijkl', centered_embeddings,
                                  centered_embeddings)  # Outer product.
    sample_covariance += damping * tf.eye(dim)  # Positive definiteness.
    weighted_covariance = tf.einsum('ij,ijkl->jkl', responsibilities,
                                    sample_covariance)
    weighted_covariance /= tf.reduce_sum(
        input_tensor=responsibilities, axis=0)[:, tf.newaxis, tf.newaxis]

    return (
        _split_and_squeeze(sample_mean, num_splits=num_classes),
        _split_and_squeeze(weighted_covariance, num_splits=num_classes),
        [mixing_logits],
    )
  else:
    avg_x_squared = (
        tf.matmul(responsibilities, embeddings**2, transpose_a=True) /
        tf.reduce_sum(input_tensor=responsibilities, axis=0)[:, tf.newaxis])
    avg_means_squared = sample_mean**2
    avg_x_means = (
        sample_mean *
        tf.matmul(responsibilities, embeddings, transpose_a=True) /
        tf.reduce_sum(input_tensor=responsibilities, axis=0)[:, tf.newaxis])
    sample_variances = (
        avg_x_squared - 2 * avg_x_means + avg_means_squared +
        damping * tf.ones(dim))
    log_variances = tf.math.log(sample_variances)
    return (
        _split_and_squeeze(sample_mean, num_splits=num_classes),
        _split_and_squeeze(log_variances, num_splits=num_classes),
        [mixing_logits],
    )

def _split_and_squeeze(tensor, num_splits, axis=0):
  return [
      tf.squeeze(t)
      for t in tf.split(tensor, axis=axis, num_or_size_splits=num_splits)
  ]

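
# A usage sketch for `_split_and_squeeze`: it turns a stacked
# [num_classes, dim] parameter tensor into a list of per-class [dim] tensors;
# the values are illustrative.
def _split_and_squeeze_example():
  params = tf.reshape(tf.range(6, dtype=tf.float32), [3, 2])
  return _split_and_squeeze(params, num_splits=3)  # Three [2] tensors.
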
def _split_mode_params(params):
  return [
      tf.squeeze(p)
      for p in tf.split(params, axis=0, num_or_size_splits=self.num_modes)
  ]

def train(hparams, num_epoch, tuning):
  log_dir = './results/'
  test_batch_size = 8

  # Load dataset.
  training_set, valid_set = make_dataset(
      BATCH_SIZE=hparams['HP_BS'], file_name='train_tf_record', split=True)
  test_set = make_dataset(
      BATCH_SIZE=test_batch_size, file_name='test_tf_record', split=False)
  class_names = ['NRDR', 'RDR']

  # Model.
  model = ResNet()

  # Set optimizer.
  optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])

  # Set metrics.
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  valid_accuracy = tf.keras.metrics.Accuracy()
  valid_con_mat = ConfusionMatrix(num_class=2)
  test_accuracy = tf.keras.metrics.Accuracy()
  test_con_mat = ConfusionMatrix(num_class=2)

  # Save checkpoint.
  if not tuning:
    ckpt = tf.train.Checkpoint(
        step=tf.Variable(1), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

  # Set up summary writers.
  current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
  tb_log_dir = log_dir + current_time + '/train'
  summary_writer = tf.summary.create_file_writer(tb_log_dir)

  # Restore checkpoint.
  if not tuning:
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
      logging.info('Restored from {}'.format(manager.latest_checkpoint))
    else:
      logging.info('Initializing from scratch.')

  @tf.function
  def train_step(train_img, train_label):
    # Optimize the model.
    loss_value, grads = grad(model, train_img, train_label)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_pred, _ = model(train_img)
    train_label = tf.expand_dims(train_label, axis=1)
    train_accuracy.update_state(train_label, train_pred)

  for epoch in range(num_epoch):
    begin = time()

    # Training loop.
    for train_img, train_label, train_name in training_set:
      train_img = data_augmentation(train_img)
      train_step(train_img, train_label)

    with summary_writer.as_default():
      tf.summary.scalar('Train Accuracy', train_accuracy.result(), step=epoch)

    for valid_img, valid_label, _ in valid_set:
      valid_img = tf.cast(valid_img, tf.float32)
      valid_img = valid_img / 255.0
      valid_pred, _ = model(valid_img, training=False)
      valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
      valid_con_mat.update_state(valid_label, valid_pred)
      valid_accuracy.update_state(valid_label, valid_pred)

    # Log the confusion matrix as an image summary.
    cm_valid = valid_con_mat.result()
    figure = plot_confusion_matrix(cm_valid, class_names=class_names)
    cm_valid_image = plot_to_image(figure)
    with summary_writer.as_default():
      tf.summary.scalar('Valid Accuracy', valid_accuracy.result(), step=epoch)
      tf.summary.image('Valid ConfusionMatrix', cm_valid_image, step=epoch)

    end = time()
    logging.info(
        'Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} '
        'Time: {:.5}s'.format(epoch + 1, train_accuracy.result(),
                              valid_accuracy.result(), (end - begin)))
    train_accuracy.reset_states()
    valid_accuracy.reset_states()
    valid_con_mat.reset_states()

    if not tuning:
      if int(ckpt.step) % 5 == 0:
        save_path = manager.save()
        logging.info('Saved checkpoint for epoch {}: {}'.format(
            int(ckpt.step), save_path))
      ckpt.step.assign_add(1)

  for test_img, test_label, _ in test_set:
    test_img = tf.cast(test_img, tf.float32)
    test_img = test_img / 255.0
    test_pred, _ = model(test_img, training=False)
    test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
    test_accuracy.update_state(test_label, test_pred)
    test_con_mat.update_state(test_label, test_pred)

  # Log the confusion matrix as an image summary.
  cm_test = test_con_mat.result()
  figure = plot_confusion_matrix(cm_test, class_names=class_names)
  cm_test_image = plot_to_image(figure)
  with summary_writer.as_default():
    tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
    tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)
  logging.info('Training finished. Final accuracy on test set: {:.3%}'.format(
      test_accuracy.result()))

  # Visualization: Grad-CAM-style activation map for one test image.
  if not tuning:
    for vis_img, vis_label, vis_name in test_set:
      vis_label = vis_label[0]
      vis_name = vis_name[0]
      vis_img = tf.cast(vis_img[0], tf.float32)
      vis_img = tf.expand_dims(vis_img, axis=0)
      vis_img = vis_img / 255.0
      with tf.GradientTape() as tape:
        vis_pred, conv_output = model(vis_img, training=False)
        pred_label = tf.argmax(vis_pred, axis=-1)
        vis_pred = tf.reduce_max(vis_pred, axis=-1)
      grad_1 = tape.gradient(vis_pred, conv_output)
      weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
      act_map0 = tf.nn.relu(tf.reduce_sum(weight * conv_output, axis=-1))
      act_map0 = tf.squeeze(
          tf.image.resize(
              tf.expand_dims(act_map0, axis=-1), (256, 256), antialias=True),
          axis=-1)
      plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label, vis_name)
      break
  return test_accuracy.result()