def run_with_tool(tool, dists, batch_sizes):
    """Profile ``sample`` and ``log_prob`` for each named distribution.

    Emits one header row, then one row per entry of ``dists``. Each row holds
    the distribution name followed by a (sample, log_prob) pair of profiling
    results for every batch size.

    :param tool: ``'timeit'`` selects a column layout with one 14-wide column
        per cell and ``'{:.6f}'`` formatting for the profiling cells;
        ``'cprofile'`` selects a row layout with widths ``[14, 80]``. Any other
        value passes ``None`` for widths, format and template to
        ``profile_print``.
    :param dists: names used as keys into ``DISTRIBUTIONS``. Each entry there
        unpacks to ``(distribution_class, constructor_kwargs)``.
    :param batch_sizes: sized, re-iterable collection of batch sizes (``len``
        is taken for ``'timeit'``, and it is iterated once for the header and
        once per distribution).
    """
    # Fallback when ``tool`` matches neither branch below.
    widths = None
    fmt = None
    layout = None
    if tool == 'timeit':
        # One sample cell and one log_prob cell per batch size.
        n_prof_cells = 2 * len(batch_sizes)
        # +1 for the leading DISTRIBUTION name column, which is left unformatted.
        widths = [14] * (n_prof_cells + 1)
        fmt = [None] + ['{:.6f}'] * n_prof_cells
        layout = 'column'
    elif tool == 'cprofile':
        widths = [14, 80]
        layout = 'row'

    with profile_print(widths, fmt, layout) as printer:
        headers = ['DISTRIBUTION']
        for n in batch_sizes:
            headers.extend([
                'SAMPLE (N=' + str(n) + ')',
                'LOG_PROB (N=' + str(n) + ')',
            ])
        printer.header(headers)

        for name in dists:
            dist_cls, kwargs = DISTRIBUTIONS[name]
            row = [name]
            instance = dist_cls(**kwargs)
            for n in batch_sizes:
                # sample/log_prob return a pair whose second item is recorded;
                # presumably a profiling measurement - confirm against their
                # definitions. The drawn values feed the log_prob call.
                drawn, sample_stats = sample(instance, batch_size=n)
                _, logprob_stats = log_prob(instance, drawn)
                row.extend((sample_stats, logprob_stats))
            printer.push(row)
def run_with_tool(tool, dists, batch_sizes):
    """Profile ``sample`` and ``log_prob`` for each named distribution.

    Emits one header row, then one row per entry of ``dists``. Each row holds
    the distribution name followed by a (sample, log_prob) pair of profiling
    results for every batch size.

    :param tool: ``'timeit'`` selects a column layout with one 14-wide column
        per cell and ``'{:.6f}'`` formatting for the profiling cells;
        ``'cprofile'`` selects a row layout with widths ``[14, 80]``. Any other
        value passes ``None`` for widths, format and template to
        ``profile_print``.
    :param dists: names used as keys into ``DISTRIBUTIONS``. Each entry there
        unpacks to ``(distribution_class, constructor_kwargs)``.
    :param batch_sizes: sized, re-iterable collection of batch sizes (``len``
        is taken for ``'timeit'``, and it is iterated once for the header and
        once per distribution).
    """
    # Fallback when ``tool`` matches neither branch below.
    widths = None
    fmt = None
    layout = None
    if tool == 'timeit':
        # One sample cell and one log_prob cell per batch size.
        n_prof_cells = 2 * len(batch_sizes)
        # +1 for the leading DISTRIBUTION name column, which is left unformatted.
        widths = [14] * (n_prof_cells + 1)
        fmt = [None] + ['{:.6f}'] * n_prof_cells
        layout = 'column'
    elif tool == 'cprofile':
        widths = [14, 80]
        layout = 'row'

    with profile_print(widths, fmt, layout) as printer:
        headers = ['DISTRIBUTION']
        for n in batch_sizes:
            headers.extend([
                'SAMPLE (N=' + str(n) + ')',
                'LOG_PROB (N=' + str(n) + ')',
            ])
        printer.header(headers)

        for name in dists:
            dist_cls, kwargs = DISTRIBUTIONS[name]
            row = [name]
            instance = dist_cls(**kwargs)
            for n in batch_sizes:
                # sample/log_prob return a pair whose second item is recorded;
                # presumably a profiling measurement - confirm against their
                # definitions. The drawn values feed the log_prob call.
                drawn, sample_stats = sample(instance, batch_size=n)
                _, logprob_stats = log_prob(instance, drawn)
                row.extend((sample_stats, logprob_stats))
            printer.push(row)