def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold):
  """Computes the effective sample size (ESS) for a single `Tensor`.

  Args:
    states: Value convertible to a `Tensor`. Axis 0 is used as the sample
      axis along which the auto-correlation is taken.
    filter_beyond_lag: Forwarded to `stats.auto_correlation` as `max_lags`.
      `None` means the full auto-correlation sequence is used.
    filter_threshold: `None`, or a value convertible to a `Tensor` of
      `states.dtype`. When given, the auto-correlation is zeroed from the
      first lag at which it falls below the threshold onwards.

  Returns:
    `N / (-1 + 2 * Sum_k (N - k) / N * R[k])`, where `R` is the (possibly
    filtered) auto-correlation, `N` is `_axis_size(states, axis=0)`, and the
    sum runs over axis 0 of `R`.
  """
  with tf.compat.v1.name_scope(
      'effective_sample_size_single_state',
      values=[states, filter_beyond_lag, filter_threshold]):

    states = tf.convert_to_tensor(value=states, name='states')
    dt = states.dtype

    # With filter_beyond_lag == None, every available lag is returned.
    auto_corr = stats.auto_correlation(
        states, axis=0, max_lags=filter_beyond_lag)

    if filter_threshold is not None:
      threshold = tf.convert_to_tensor(
          value=filter_threshold, dtype=dt, name='filter_threshold')
      # Build a 0/1 mask along axis 0 of the form [1, 1, ..., 0, 0, ...]:
      #   keep[i, ...] = 1 if auto_corr[j, ...] >= threshold for all j <= i,
      #   keep[i, ...] = 0 otherwise.
      # Worked example with auto_corr = [1, 0.5, 0.0, 0.3], threshold = 0.2:
      #   below threshold      -> [False, False, True, False]
      #   cast to dt           -> [0, 0, 1, 0]
      #   cumulative sum       -> [0, 0, 1, 1]
      #   maximum(1 - sum, 0)  -> [1, 1, 0, 0]
      below_count = tf.cumsum(
          tf.cast(auto_corr < threshold, dtype=dt), axis=0)
      keep = tf.maximum(1. - below_count, 0.)
      auto_corr = auto_corr * keep

    # With R[k] := auto_corr[k, ...],
    #   ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]}
    #       = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]}  (since R[0] = 1)
    #       approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]}
    # where M is the truncation point selected by filter_beyond_lag.
    n = _axis_size(states, axis=0)
    lags = tf.range(0., _axis_size(auto_corr, axis=0))
    lag_weights = (n - lags) / n

    # Give the weights shape [M, 1, ..., 1] (same rank as auto_corr) so that
    # they broadcast over every non-lag axis. Use the static rank if known.
    static_rank = auto_corr.shape.ndims
    if static_rank is None:
      weights_shape = tf.concat(
          ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)), axis=0)
    else:
      weights_shape = [-1] + [1] * (static_rank - 1)
    lag_weights = tf.reshape(lag_weights, weights_shape)

    weighted_sum = tf.reduce_sum(
        input_tensor=lag_weights * auto_corr, axis=0)
    return n / (-1 + 2 * weighted_sum)
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold):
  """Computes the effective sample size (ESS) for a single `Tensor`.

  Args:
    states: Value convertible to a `Tensor`. Axis 0 is used as the sample
      axis along which the auto-correlation is taken.
    filter_beyond_lag: Forwarded to `stats.auto_correlation` as `max_lags`.
      `None` means the full auto-correlation sequence is used.
    filter_threshold: `None`, or a value convertible to a `Tensor` of
      `states.dtype`. When given, the auto-correlation is zeroed from the
      first lag at which it falls below the threshold onwards.

  Returns:
    `N / (-1 + 2 * Sum_k (N - k) / N * R[k])`, where `R` is the (possibly
    filtered) auto-correlation, `N` is `_axis_size(states, axis=0)`, and the
    sum runs over axis 0 of `R`.
  """
  scope_values = [states, filter_beyond_lag, filter_threshold]
  with tf.name_scope(
      'effective_sample_size_single_state', values=scope_values):

    states = tf.convert_to_tensor(states, name='states')
    dt = states.dtype

    # With filter_beyond_lag == None, every available lag is returned.
    auto_corr = stats.auto_correlation(
        states, axis=0, max_lags=filter_beyond_lag)

    if filter_threshold is not None:
      threshold = tf.convert_to_tensor(
          filter_threshold, dtype=dt, name='filter_threshold')
      # Zero out auto_corr from the first sub-threshold lag onwards, using a
      # mask that looks like [1, 1, ..., 0, 0, ...] along axis 0.
      # Worked example with auto_corr = [1, 0.5, 0.0, 0.3], threshold = 0.2:
      #   is_below             = [False, False, True, False]
      #   cast to dt           = [0, 0, 1, 0]
      #   num_below (cumsum)   = [0, 0, 1, 1]
      #   keep = max(1 - ., 0) = [1, 1, 0, 0]
      is_below = auto_corr < threshold
      num_below = tf.cumsum(tf.cast(is_below, dtype=dt), axis=0)
      keep = tf.maximum(1. - num_below, 0.)
      auto_corr = auto_corr * keep

    # With R[k] := auto_corr[k, ...],
    #   ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]}
    #       = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]}  (since R[0] = 1)
    #       approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]}
    # where M is the truncation point selected by filter_beyond_lag.
    n = _axis_size(states, axis=0)
    lags = tf.range(0., _axis_size(auto_corr, axis=0))

    # The factor (N - k) / N needs shape [M, 1, ..., 1], i.e. the same rank as
    # auto_corr, so that it broadcasts across all non-lag axes.
    static_rank = auto_corr.shape.ndims
    if static_rank is None:
      # Rank only known at graph-execution time: build the shape dynamically.
      weights_shape = tf.concat(
          ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)), axis=0)
    else:
      weights_shape = [-1] + [1] * (static_rank - 1)
    lag_weights = tf.reshape((n - lags) / n, weights_shape)

    weighted_sum = tf.reduce_sum(lag_weights * auto_corr, axis=0)
    return n / (-1 + 2 * weighted_sum)
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold,
                                        filter_beyond_positive_pairs,
                                        cross_chain_dims,
                                        validate_args):
  """Computes the effective sample size (ESS) for a single `Tensor`.

  Args:
    states: Value convertible to a `Tensor`. Axis 0 is used as the sample
      axis along which the auto-correlation is taken.
    filter_beyond_lag: Forwarded to `stats.auto_correlation` as `max_lags`.
      `None` means the full sequence is used.
    filter_threshold: `None`, or a value convertible to a `Tensor` of
      `states.dtype`. Only consulted when `filter_beyond_positive_pairs` is
      falsy; lags from the first sub-threshold auto-correlation onwards are
      then dropped from the sum.
    filter_beyond_positive_pairs: If truthy, adjacent lags are summed in
      pairs and only the leading run of non-negative pair sums is kept. Takes
      precedence over `filter_threshold`.
    cross_chain_dims: `None`, or the axis (axes) of `states` that index
      independent chains. When not `None`, the multi-chain auto-correlation
      estimate is used and the result is scaled by the number of chains.
    validate_args: If truthy and the number of chains cannot be determined
      statically, a runtime assertion that it exceeds 1 is added.

  Returns:
    `num_chains * N / (-1 + 2 * Sum_k (N - k) / N * R[k])`, where the sum
    runs over axis 0 of the filtered, weighted auto-correlation and
    `num_chains` is 1 when `cross_chain_dims` is `None`.

  Raises:
    ValueError: If `cross_chain_dims` is not `None` and the number of chains
      is statically known to be less than 2.
  """
  with tf.name_scope('effective_sample_size_single_state'):
    states = tf.convert_to_tensor(states, name='states')
    dt = states.dtype

    def _keep_before_first_hit(hits):
      # Turns a boolean `hits` tensor into a 0/1 mask along axis 0 that is 1
      # strictly before the first True and 0 from there on, e.g.
      #   hits                 = [False, False, True, False]
      #   cast to dt           = [0, 0, 1, 0]
      #   cumulative sum       = [0, 0, 1, 1]
      #   maximum(1 - sum, 0)  = [1, 1, 0, 0]
      hit_count = tf.cumsum(tf.cast(hits, dt), axis=0)
      return tf.maximum(1. - hit_count, 0.)

    def _sum_pairs(seq):
      # Sums entries (0, 1), (2, 3), ... along axis 0. For odd lengths the
      # final entry is dropped.
      seq_len = tf.shape(seq)[0]
      seq = seq[:seq_len - seq_len % 2]
      paired_shape = tf.concat(
          [[seq_len // 2, 2], tf.shape(seq)[1:]], axis=0)
      return tf.reduce_sum(tf.reshape(seq, paired_shape), 1)

    # filter_beyond_lag == None ==> the full sequence is returned. Requested
    # unnormalized (normalize=False) and normalized below.
    auto_cov = stats.auto_correlation(
        states, axis=0, max_lags=filter_beyond_lag, normalize=False)
    n = _axis_size(states, axis=0)

    if cross_chain_dims is None:
      # Single-chain case: normalize by the lag-0 entry so that R[0] = 1.
      num_chains = 1
      auto_corr = auto_cov / auto_cov[:1]
    else:
      num_chains = _axis_size(states, cross_chain_dims)
      static_num_chains = tf.get_static_value(num_chains)

      msg = (
          'When `cross_chain_dims` is not `None`, there must be > 1 chain '
          'in `states`.')
      assertions = []
      if static_num_chains is None:
        # Chain count unknown until run time; only check it when asked to.
        if validate_args:
          assertions.append(
              assert_util.assert_greater(num_chains, 1., message=msg))
      elif static_num_chains < 2:
        raise ValueError(msg)

      with tf.control_dependencies(assertions):
        # We're computing the R[k] from equation 10 of Vehtari et al.
        # (2019):
        #
        # R[k] := 1 - (W - 1/C * Sum_{c=1}^C s_c**2 R[k, c]) / (var^+),
        #
        # where:
        #   C := number of chains
        #   N := length of chains
        #   x_hat[c] := 1 / N Sum_{n=1}^N x[n, c], chain mean.
        #   x_hat := 1 / C Sum_{c=1}^C x_hat[c], overall mean.
        #   W := 1/C Sum_{c=1}^C s_c**2, within-chain variance.
        #   B := N / (C - 1) Sum_{c=1}^C (x_hat[c] - x_hat)**2, between
        #     chain variance.
        #   s_c**2 := 1 / (N - 1) Sum_{n=1}^N (x[n, c] - x_hat[c])**2, chain
        #     variance.
        #   R[k, c] := auto_corr[k, c, ...], auto-correlation indexed by
        #     chain.
        #   var^+ := (N - 1) / N * W + B / N
        chain_axis = prefer_static.non_negative_axis(
            cross_chain_dims, prefer_static.rank(states))
        # Tensors whose leading (sample / lag) axis has been removed have
        # their chain axis shifted down by one.
        chain_axis_no_lead = chain_axis - 1

        # B / N. biased=False makes the denominator C - 1.
        chain_means = tf.reduce_mean(states, axis=0)
        between_chain_variance_div_n = _reduce_variance(
            chain_means, biased=False, axis=chain_axis_no_lead)

        # W * (N - 1) / N
        biased_within_chain_variance = tf.reduce_mean(
            auto_cov[0], chain_axis_no_lead)

        # var^+
        approx_variance = (
            biased_within_chain_variance + between_chain_variance_div_n)

        # 1/C * Sum_{c=1}^C s_c**2 R[k, c]
        mean_auto_cov = tf.reduce_mean(auto_cov, chain_axis)
        variance_gap = biased_within_chain_variance - mean_auto_cov
        auto_corr = 1. - variance_gap / approx_variance

    # With R[k] := auto_corr[k, ...],
    #   ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N}
    #       = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N}  (since R[0] = 1)
    #       approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N}
    # where M is the truncation point selected by filter_beyond_lag.
    lags = tf.range(0., _axis_size(auto_corr, axis=0))

    # The factor (N - k) / N gets shape [M, 1, ..., 1] (same rank as
    # auto_corr) so it broadcasts over all non-lag axes.
    static_rank = tensorshape_util.rank(auto_corr.shape)
    if static_rank is None:
      weights_shape = tf.concat(
          ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)), axis=0)
    else:
      weights_shape = [-1] + [1] * (static_rank - 1)
    lag_weights = tf.reshape((n - lags) / n, weights_shape)
    weighted_auto_corr = lag_weights * auto_corr

    if filter_beyond_positive_pairs:
      # Pairwise sums are all positive for auto-correlation spectra derived
      # from reversible MCMC chains, so stop at the first negative pair sum.
      # E.g. for pair sums [0.2, 0.1, -0.1, -0.2] the mask is [1, 1, 0, 0].
      keep = _keep_before_first_hit(_sum_pairs(auto_corr) < 0.)
      # N.B. pairing halves the length of weighted_auto_corr. The final sum
      # over axis 0 is unaffected by that.
      weighted_auto_corr = _sum_pairs(weighted_auto_corr) * keep
    elif filter_threshold is not None:
      threshold = tf.convert_to_tensor(
          filter_threshold, dtype=dt, name='filter_threshold')
      # Drop every lag from the first one where auto_corr < threshold.
      # E.g. auto_corr = [1, 0.5, 0.0, 0.3] with threshold 0.2 gives the mask
      # [1, 1, 0, 0].
      keep = _keep_before_first_hit(auto_corr < threshold)
      weighted_auto_corr = weighted_auto_corr * keep

    weighted_sum = tf.reduce_sum(weighted_auto_corr, axis=0)
    return num_chains * n / (-1 + 2 * weighted_sum)
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold,
                                        filter_beyond_positive_pairs):
  """Computes the effective sample size (ESS) for a single `Tensor`.

  Args:
    states: Value convertible to a `Tensor`. Axis 0 is used as the sample
      axis along which the auto-correlation is taken.
    filter_beyond_lag: Forwarded to `stats.auto_correlation` as `max_lags`.
      `None` means the full auto-correlation sequence is used.
    filter_threshold: `None`, or a value convertible to a `Tensor` of
      `states.dtype`. Only consulted when `filter_beyond_positive_pairs` is
      falsy; lags from the first sub-threshold auto-correlation onwards are
      then dropped from the sum.
    filter_beyond_positive_pairs: If truthy, adjacent lags are summed in
      pairs and only the leading run of non-negative pair sums is kept. Takes
      precedence over `filter_threshold`.

  Returns:
    `N / (-1 + 2 * Sum_k (N - k) / N * R[k])`, where the sum runs over axis 0
    of the filtered, weighted auto-correlation `R`.
  """
  with tf.name_scope('effective_sample_size_single_state'):
    states = tf.convert_to_tensor(states, name='states')
    dt = states.dtype

    def _keep_before_first_hit(hits):
      # Turns a boolean `hits` tensor into a 0/1 mask along axis 0 that is 1
      # strictly before the first True and 0 from there on, e.g.
      #   hits                 = [False, False, True, False]
      #   cast to dt           = [0, 0, 1, 0]
      #   cumulative sum       = [0, 0, 1, 1]
      #   maximum(1 - sum, 0)  = [1, 1, 0, 0]
      hit_count = tf.cumsum(tf.cast(hits, dt), axis=0)
      return tf.maximum(1. - hit_count, 0.)

    def _sum_pairs(seq):
      # Sums entries (0, 1), (2, 3), ... along axis 0. For odd lengths the
      # final entry is dropped.
      seq_len = tf.shape(seq)[0]
      seq = seq[:seq_len - seq_len % 2]
      paired_shape = tf.concat(
          [[seq_len // 2, 2], tf.shape(seq)[1:]], axis=0)
      return tf.reduce_sum(tf.reshape(seq, paired_shape), 1)

    # With filter_beyond_lag == None, every available lag is returned.
    auto_corr = stats.auto_correlation(
        states, axis=0, max_lags=filter_beyond_lag)

    # With R[k] := auto_corr[k, ...],
    #   ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N}
    #       = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N}  (since R[0] = 1)
    #       approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N}
    # where M is the truncation point selected by filter_beyond_lag.
    n = _axis_size(states, axis=0)
    lags = tf.range(0., _axis_size(auto_corr, axis=0))

    # The factor (N - k) / N gets shape [M, 1, ..., 1] (same rank as
    # auto_corr) so it broadcasts over all non-lag axes.
    static_rank = auto_corr.shape.ndims
    if static_rank is None:
      weights_shape = tf.concat(
          ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)), axis=0)
    else:
      weights_shape = [-1] + [1] * (static_rank - 1)
    lag_weights = tf.reshape((n - lags) / n, weights_shape)
    weighted_auto_corr = lag_weights * auto_corr

    if filter_beyond_positive_pairs:
      # Pairwise sums are all positive for auto-correlation spectra derived
      # from reversible MCMC chains, so stop at the first negative pair sum.
      # E.g. for pair sums [0.2, 0.1, -0.1, -0.2] the mask is [1, 1, 0, 0].
      keep = _keep_before_first_hit(_sum_pairs(auto_corr) < 0.)
      # N.B. pairing halves the length of weighted_auto_corr. The final sum
      # over axis 0 is unaffected by that.
      weighted_auto_corr = _sum_pairs(weighted_auto_corr) * keep
    elif filter_threshold is not None:
      threshold = tf.convert_to_tensor(
          filter_threshold, dtype=dt, name='filter_threshold')
      # Drop every lag from the first one where auto_corr < threshold.
      # E.g. auto_corr = [1, 0.5, 0.0, 0.3] with threshold 0.2 gives the mask
      # [1, 1, 0, 0].
      keep = _keep_before_first_hit(auto_corr < threshold)
      weighted_auto_corr = weighted_auto_corr * keep

    weighted_sum = tf.reduce_sum(weighted_auto_corr, axis=0)
    return n / (-1 + 2 * weighted_sum)