def _compute_nce_score(self, predictions, lookaheads):
    '''Contrastive (NCE-style) log-score for every (batch, node) slot.

    The (B, N) axes are merged into a single axis of B * N candidates.
    Every prediction is scored against every linearly projected lookahead,
    and the score of each matched pair (the diagonal) is normalised by the
    log-sum-exp over its row.

    Args:
      predictions: A (..., B, N, dh) Tensor.
      lookaheads: A Tensor with the same rank as `predictions` (asserted);
        it is flattened with the same target shape.

    Returns:
      A (..., B, N) Tensor of log-scores.
    '''
    assert predictions.shape.ndims == lookaheads.shape.ndims

    dynamic_shape = tf.shape(predictions)
    num_batches = dynamic_shape[-3]
    num_nodes = dynamic_shape[-2]
    leading_dims = tf.unstack(dynamic_shape[:-3])

    # (..., B, N, dh) -> (..., B * N, dh)
    collapsed_shape = tf.stack(leading_dims + [num_batches * num_nodes, -1])
    flat_preds = tf.reshape(predictions, collapsed_shape)
    flat_looks = tf.reshape(lookaheads, collapsed_shape)

    # Project the lookaheads through self._linear, then compare every
    # prediction with every projected lookahead:
    # (..., B * N, ds) x (..., B * N, ds)^T -> (..., B * N, B * N)
    projected_looks = tf.linalg.tensordot(flat_looks, self._linear, axes=1)
    raw_scores = tf.linalg.matmul(
        flat_preds, projected_looks, transpose_b=True)
    # NOTE(review): util.dim presumably returns the static size of the last
    # axis, making this the usual 1 / sqrt(d) scaling -- confirm in util.
    scale = tf.math.sqrt(util.float(util.dim(predictions)))
    pairwise_scores = tf.math.divide(raw_scores, scale)

    # Matched pairs sit on the diagonal; subtracting the row-wise
    # log-sum-exp gives a log-softmax over all B * N candidates.
    matched_scores = tf.linalg.diag_part(pairwise_scores)
    log_normalizers = tf.math.reduce_logsumexp(pairwise_scores, axis=-1)
    flat_nce_scores = tf.math.subtract(matched_scores, log_normalizers)

    # (..., B * N) -> (..., B, N)
    restored_shape = tf.stack(leading_dims + [num_batches, num_nodes])
    return tf.reshape(flat_nce_scores, restored_shape)
def _broadcast_and_concat(self, global_context, local_context):
    '''Append a broadcast copy of `global_context` to `local_context`.

    A new axis is inserted at position -2 of `global_context`, the result is
    broadcast against the leading axes of `local_context` by adding zeros,
    and the two tensors are concatenated along the last axis.

    Args:
      global_context: Tensor without the second-to-last axis of
        `local_context`.
      local_context: Tensor whose leading axes define the broadcast shape.

    Returns:
      The concatenation `[broadcast global_context, local_context]` along
      the last axis.
    '''
    expanded_global = tf.expand_dims(global_context, axis=-2)

    # Target shape: every axis of local_context except the last, followed by
    # the feature size of the global context.
    local_prefix = tf.unstack(tf.shape(local_context)[:-1])
    target_shape = tf.stack(local_prefix + [util.dim(expanded_global)])

    # Adding zeros of the target shape performs the broadcast.
    tiled_global = tf.math.add(expanded_global, tf.zeros(target_shape))
    return tf.concat([tiled_global, local_context], axis=-1)
def trainable(self):
    '''Build a PersistentStates backed by learnable initial values.

    One trainable vector is created per state component (belief state,
    global/local latent history, global/local latent state), each sized by
    `util.dim` of that component. Every vector is broadcast to the shape of
    its component and passed, together with the current states, to
    `util.select_nested` keyed on `self.reset`.

    Returns:
      A PersistentStates built from `self.reset` and the selected states.
    '''
    def _initial_vector(name, reference):
        # A fresh Glorot-uniform initializer instance per variable.
        return tf.get_variable(
            name, [util.dim(reference)],
            trainable=True,
            initializer=tf.initializers.glorot_uniform())

    def _fill_like(reference, vector):
        # zeros_like + add broadcasts the vector to the reference's shape.
        return tf.math.add(tf.zeros_like(reference), vector)

    # All variables are created before any broadcast op.
    init_belief = _initial_vector(
        "initial_belief_state", self.belief_states)
    init_global_history = _initial_vector(
        "initial_global_history", self.latent_histories[0])
    init_local_history = _initial_vector(
        "initial_local_history", self.latent_histories[1])
    init_global_state = _initial_vector(
        "initial_global_state", self.latent_states[0])
    init_local_state = _initial_vector(
        "initial_local_state", self.latent_states[1])

    learned_beliefs = _fill_like(self.belief_states, init_belief)
    learned_histories = (
        _fill_like(self.latent_histories[0], init_global_history),
        _fill_like(self.latent_histories[1], init_local_history),
    )
    learned_states = (
        _fill_like(self.latent_states[0], init_global_state),
        _fill_like(self.latent_states[1], init_local_state),
    )

    current = (self.belief_states, self.latent_histories, self.latent_states)
    # NOTE(review): presumably select_nested takes the learned values where
    # `reset` is set and the current values elsewhere -- confirm in util.
    selected = util.select_nested(
        self.reset,
        (learned_beliefs, learned_histories, learned_states),
        current)
    selected = util.nested_set_shape_like(selected, current)
    return PersistentStates(self.reset, *selected)
def _concat_with_inputs(self, context, inputs):
    '''Concatenate optional `inputs` onto `context` along the last axis.

    Args:
      context: Tensor to extend.
      inputs: Tensor to append, or None. If its rank is lower than that of
        `context`, it is first broadcast to the leading axes of `context`.

    Returns:
      `context` unchanged when `inputs` is None, otherwise the
      concatenation of `context` and (possibly broadcast) `inputs`.
    '''
    if inputs is None:
        return context

    rank_deficient = inputs.shape.ndims < context.shape.ndims
    if rank_deficient:
        # Broadcast by adding zeros shaped like context's leading axes plus
        # the feature size of inputs.
        context_prefix = tf.unstack(tf.shape(context)[:-1])
        broadcast_shape = tf.stack(context_prefix + [util.dim(inputs)])
        inputs = tf.math.add(inputs, tf.zeros(broadcast_shape))

    return tf.concat([context, inputs], axis=-1)
def score(self, graph, histories, states, observations, beliefs, lookaheads):
    '''Log-likelihood of the next observations given the current states.

    Args:
      graph: Unused.
      histories: A 2-ary tuple. Unused.
      states: A 2-ary tuple (global states, local states).
      observations: A (F, B, N, dx) Tensor.
      beliefs: Unused.
      lookaheads: Unused.

    Returns:
      scores: A (..., B, N) Tensor, the masked sum of per-step log-probs.
    '''
    del graph, histories, beliefs, lookaheads

    _, local_states = states
    # NOTE(review): called as a bare name, not through `self` -- confirm a
    # module-level helper with a (global, local) signature is in scope.
    joint_states = _broadcast_and_concat(*states)
    assert observations.shape.ndims == 4

    horizon = self._num_future_steps
    steps_left = tf.shape(observations)[0]
    observ_dim = util.dim(observations)

    # Mask of length `horizon`: ones for steps that are actually available,
    # zeros for the steps that have to be padded.
    def _all_steps_valid():
        return tf.ones(horizon)

    def _only_remaining_valid():
        return tf.concat(
            [tf.ones(steps_left), tf.zeros(horizon - steps_left)], axis=0)

    step_mask = tf.cond(
        tf.math.greater_equal(steps_left, horizon),
        _all_steps_valid,
        _only_remaining_valid)

    # Bring the observations to exactly `horizon` steps along axis 0.
    def _truncate():
        return observations[:horizon]

    def _pad():
        # Presumably pads axis 0 up to `size` -- see tf_pad_axis_to.
        return tf_pad_axis_to(observations, axis=0, size=horizon)

    fixed_observations = tf.cond(
        tf.math.greater_equal(tf.shape(observations)[0], horizon),
        _truncate,
        _pad)

    # (F, B, N, dx) -> (B, N, F, dx)
    target_observations = tf.transpose(fixed_observations, [1, 2, 0, 3])

    flat_dist = self._mlp_diag_normal(joint_states).build()
    assert type(flat_dist) is tfd.MultivariateNormalDiag

    # (..., B, N, F * dx)
    flat_locs = flat_dist.mean()
    flat_scales = flat_dist.stddev()

    # (..., B, N, F, dx)
    state_prefix = tf.unstack(tf.shape(local_states)[:-1])
    per_step_shape = tf.stack(state_prefix + [horizon, observ_dim])
    step_locs = tf.reshape(flat_locs, per_step_shape)
    step_scales = tf.reshape(flat_scales, per_step_shape)

    step_dist = tfd.MultivariateNormalDiag(
        loc=step_locs, scale_diag=step_scales)
    step_log_probs = step_dist.log_prob(target_observations)

    # (B, N, F) -> (B, N): padded steps contribute nothing.
    masked_log_probs = tf.math.multiply(step_log_probs, step_mask)
    return tf.math.reduce_sum(masked_log_probs, axis=-1)
def propose_local(self, graph, external, local_priors, observations,
                  conditions, context=None):
    '''Build the proposal distribution over local latent states.

    Args:
      graph: A RuntimeGraph object; only forwarded to the flow wrapper.
      external: Object whose `local_inputs` attribute is passed to
        `combine_local`.
      local_priors: Object exposing `pre_flow_dist` and `global_context`.
      observations: A (B, N, dx) Tensor.
      conditions: Forwarded unchanged to `combine_local`.
      context: Unused; discarded immediately.

    Returns:
      dist: Distribution of batch shape (..., B, N) & event shape (dz),
        asserted to be fully reparameterized.
    '''
    del context

    prior_base_dist = local_priors.pre_flow_dist
    local_context = self.combine_local(
        inputs=external.local_inputs,
        priors=local_priors,
        conditions=conditions,
        observations=observations)

    # Broadcast the observations to the leading axes of the local context
    # by adding them to zeros of the target shape.
    context_prefix = tf.unstack(tf.shape(local_context)[:-1])
    observ_shape = tf.stack(context_prefix + [util.dim(observations)])
    tiled_observations = tf.math.add(tf.zeros(observ_shape), observations)

    mlp_inputs = tf.concat([local_context, tiled_observations], axis=-1)
    proposal = self._local_mlp_normal(mlp_inputs)

    if self.use_skip_conn:
        # Skip connection from the prior mean, followed by layer norm.
        assert type(proposal) is util.PartialLocScaleDist
        shifted_loc = tf.math.add(proposal.loc, prior_base_dist.mean())
        proposal.loc = self._local_skip_conn_layer_norm(shifted_loc)

    if self.use_gated_adder:
        assert type(proposal) is util.PartialLocScaleDist
        gate_output = self._local_gated_unit(local_context)
        proposal.loc = tf.math.add(proposal.loc, gate_output)

    if type(proposal) is util.PartialLocScaleDist:
        proposal = proposal.build()

    # Reinterpret the last batch axis as part of the event.
    independent_name = "indep_" + proposal.name
    proposal = tfd.Independent(
        distribution=proposal,
        reinterpreted_batch_ndims=1,
        name=independent_name)

    if self.flow_num_layers > 0:
        proposal = perm_equiv_flow_wrapper(
            components=self._local_flow_components,
            graph=graph,
            const_num_nodes=self.model.const_num_nodes,
            global_context=local_priors.global_context,
            local_context=local_context,
            base_dist=proposal,
            skip_conn=self.flow_skip_conn)

    assert proposal.reparameterization_type == tfd.FULLY_REPARAMETERIZED
    return proposal
def resampling(resampler, histories, particles, log_weights, num_samples,
               mask=None):
    '''Resample histories, particles and log-weights jointly.

    All per-particle quantities are packed into one (S, B, D) Tensor so that
    a single `resampler` call handles them together; the result is then
    split and reshaped back into the original layout.

    Args:
      resampler: Callable invoked as
        `resampler(packed, log_weights, num_samples)`.
      histories: A 2-ary tuple:
        - global_histories: A (S, B, dH) Tensor.
        - local_histories: A (S, B, N, dH) Tensor.
      particles: A 2-ary tuple:
        - global_states: A (S, B, dz) Tensor.
        - local_states: A (S, B, N, dz) Tensor.
      log_weights: A (S, B) Tensor.
      num_samples: A scalar.
      mask: Accepted but never read by this function.

    Returns:
      resampled_histories, resampled_particles, resampled_log_weights
    '''
    global_histories, local_histories = histories
    global_states, local_states = particles
    assert global_states.shape.ndims == 3
    assert local_states.shape.ndims == 4

    dynamic_shape = tf.shape(local_states)
    num_particles = dynamic_shape[0]
    batch_size = dynamic_shape[1]
    num_nodes = dynamic_shape[2]

    global_state_dim = util.dim(global_states)
    local_state_dim = util.dim(local_states)
    global_history_dim = util.dim(global_histories)
    local_history_dim = util.dim(local_histories)

    # (S, B, N, d) -> (S, B, N * d)
    packed_local_shape = tf.stack([num_particles, batch_size, -1])
    packed_local_states = tf.reshape(local_states, packed_local_shape)
    packed_local_histories = tf.reshape(local_histories, packed_local_shape)

    # Pack everything along the last axis; the log-weights travel with the
    # particles so they are resampled consistently.
    packed = tf.concat(
        [
            tf.expand_dims(log_weights, axis=-1),
            global_histories,
            global_states,
            packed_local_histories,
            packed_local_states,
        ],
        axis=-1)
    resampled = resampler(packed, log_weights, num_samples)

    # Widths must follow the packing order used above.
    column_widths = tf.stack([
        1,
        global_history_dim,
        global_state_dim,
        num_nodes * local_history_dim,
        num_nodes * local_state_dim,
    ])
    (new_log_weights,
     new_global_histories,
     new_global_states,
     new_packed_local_histories,
     new_packed_local_states) = tf.split(resampled, column_widths, axis=-1)

    def _restore_local(packed_values, width):
        # (S', B, N * d) -> (S', B, N, d)
        return tf.reshape(
            packed_values, [num_samples, batch_size, num_nodes, width])

    new_local_histories = _restore_local(
        new_packed_local_histories, local_history_dim)
    new_local_states = _restore_local(
        new_packed_local_states, local_state_dim)

    return (
        (new_global_histories, new_local_histories),
        (new_global_states, new_local_states),
        tf.squeeze(new_log_weights, axis=[-1]),
    )