def test_pick_scalar_condition_dynamic(self):
  """Checks `pick_scalar_condition` when the predicate is only known at run time.

  Two scenarios are covered:
    1. The predicate is a shape-unknown placeholder and the branches are
       NumPy arrays.
    2. The predicate and both branches are shape-unknown placeholders.
  In each scenario, a True predicate must yield the positive array and a
  False predicate must yield the negative array.
  """
  positive = np.exp(np.random.randn(3, 2, 4)).astype(np.float32)
  negative = -np.exp(np.random.randn(3, 2, 4)).astype(np.float32)

  # Predicates whose values are not statically known (shape=None).
  cond_true = tf1.placeholder_with_default(True, shape=None)
  cond_false = tf1.placeholder_with_default(False, shape=None)

  def check(true_branch, false_branch):
    # Evaluate both selections first, then compare against the NumPy inputs.
    picked_true = self.evaluate(
        distribution_util.pick_scalar_condition(
            cond_true, true_branch, false_branch))
    picked_false = self.evaluate(
        distribution_util.pick_scalar_condition(
            cond_false, true_branch, false_branch))
    self.assertAllEqual(picked_true, positive)
    self.assertAllEqual(picked_false, negative)

  # Scenario 1: dynamic predicate, NumPy branches.
  check(positive, negative)

  # Scenario 2: dynamic predicate and dynamic branches.
  check(tf1.placeholder_with_default(positive, shape=None),
        tf1.placeholder_with_default(negative, shape=None))
def seasonal_transition_noise(t):
  """Builds the transition-noise distribution for time step `t`.

  When `is_last_day_of_season(t)` holds, the Cholesky scale is
  `drift_scale_tril`. Otherwise it is an all-zero tensor of the same shape.
  The mean is a zero vector of length `num_seasons - 1`, with the dtype of
  `drift_scale`.

  `is_last_day_of_season`, `drift_scale_tril`, `drift_scale` and
  `num_seasons` are free variables resolved outside this function.
  """
  on_last_day = is_last_day_of_season(t)
  zero_scale_tril = tf.zeros_like(drift_scale_tril)
  scale_tril = dist_util.pick_scalar_condition(
      on_last_day, drift_scale_tril, zero_scale_tril)
  noise_size = num_seasons - 1
  return tfd.MultivariateNormalTriL(
      loc=tf.zeros(noise_size, dtype=drift_scale.dtype),
      scale_tril=scale_tril)
def seasonal_transition_noise(t):
  """Builds a diagonal-scale transition-noise distribution for step `t`.

  The diagonal scale is `drift_scale_diag` when `is_last_day_of_season(t)`
  holds, and zeros (cast to `dtype`) otherwise. The mean is a zero vector of
  length `num_seasons` in `dtype`.

  `is_last_day_of_season`, `drift_scale_diag`, `num_seasons` and `dtype` are
  free variables resolved outside this function.
  """
  last_day = is_last_day_of_season(t)
  zero_diag = tf.zeros_like(drift_scale_diag, dtype=dtype)
  scale_diag = dist_util.pick_scalar_condition(
      last_day, drift_scale_diag, zero_diag)
  zero_mean = tf.zeros(num_seasons, dtype=dtype)
  return tfd.MultivariateNormalDiag(loc=zero_mean, scale_diag=scale_diag)
def test_pick_scalar_condition_static(self):
  """Checks `pick_scalar_condition` with statically known predicates.

  The predicate is first a Python bool, then a `tf.constant` bool. In both
  cases, True must select the positive array and False the negative array.
  """
  positive = np.exp(np.random.randn(3, 2, 4)).astype(np.float32)
  negative = -np.exp(np.random.randn(3, 2, 4)).astype(np.float32)
  pick = distribution_util.pick_scalar_condition

  # Outer loop: raw Python bool first, then a TF constant. This matches the
  # original assertion order.
  for as_predicate in (lambda value: value, tf.constant):
    for value, expected in ((True, positive), (False, negative)):
      self.assertAllEqual(
          pick(as_predicate(value), positive, negative), expected)
def seasonal_transition_noise(t):
  """Returns the step-`t` transition noise with a diagonal scale.

  The noise is active (scale `drift_scale_diag`) only when
  `is_last_day_of_season(t)` is true. Otherwise the scale is zero (in
  `dtype`). The mean is always a length-`num_seasons` zero vector.

  Relies on the free variables `is_last_day_of_season`, `drift_scale_diag`,
  `num_seasons` and `dtype` from an outer scope.
  """
  is_boundary = is_last_day_of_season(t)
  inactive_scale = tf.zeros_like(drift_scale_diag, dtype=dtype)
  chosen_scale = dist_util.pick_scalar_condition(
      is_boundary, drift_scale_diag, inactive_scale)
  mean = tf.zeros(num_seasons, dtype=dtype)
  return tfd.MultivariateNormalDiag(loc=mean, scale_diag=chosen_scale)
def seasonal_transition_matrix(t):
  """Returns the step-`t` transition matrix as a full-matrix linear operator.

  Wraps `seasonal_permutation_matrix` when `is_last_day_of_season(t)` holds,
  and `identity_matrix` otherwise. All three are free variables resolved
  outside this function.
  """
  chosen = dist_util.pick_scalar_condition(
      is_last_day_of_season(t), seasonal_permutation_matrix, identity_matrix)
  return tf.linalg.LinearOperatorFullMatrix(matrix=chosen)
def seasonal_transition_matrix(t):
  """Returns the step-`t` transition matrix as a full-matrix linear operator.

  Both candidates are materialized in `dtype`:
    * `seasonal_permutation_matrix` (wrapped with `tf.constant`), used when
      `is_last_day_of_season(t)` holds;
    * a `num_seasons x num_seasons` identity, used otherwise.
  """
  on_last_day = is_last_day_of_season(t)
  permutation = tf.constant(seasonal_permutation_matrix, dtype=dtype)
  identity = tf.eye(num_seasons, dtype=dtype)
  selected = dist_util.pick_scalar_condition(on_last_day, permutation, identity)
  return tf.linalg.LinearOperatorFullMatrix(matrix=selected)
def seasonal_transition_matrix(t):
  """Builds the linear operator applied at time step `t`.

  Selects `seasonal_permutation_matrix` (as a `dtype` constant) when
  `is_last_day_of_season(t)` is true. Otherwise it selects the `num_seasons`
  identity matrix in `dtype`. The selection is wrapped in a
  `LinearOperatorFullMatrix`.
  """
  boundary = is_last_day_of_season(t)
  shift = tf.constant(seasonal_permutation_matrix, dtype=dtype)
  stay = tf.eye(num_seasons, dtype=dtype)
  return tf.linalg.LinearOperatorFullMatrix(
      matrix=dist_util.pick_scalar_condition(boundary, shift, stay))