def print_table(table, header_text, row_labels, col_labels, colwidth=10,
    latex=True):
    """Pretty-print a 2D array of data, optionally with row/col labels.

    Neither ``table`` nor any of its rows is modified; the labelled rows are
    built as new lists.

    Args:
        table: sequence of rows, each a sequence of cell values.
        header_text: title of the table. In LaTeX mode it is printed as a
            centred multicolumn row spanning every column; otherwise it is
            printed after a ``--------`` marker.
        row_labels: labels prepended to the rows of ``table``, paired in
            order. Rows past the last label are printed without a label.
        col_labels: header row, printed before the data rows.
        colwidth: column width forwarded to ``misc.print_row``.
        latex: if True, emit a LaTeX ``tabular`` with booktabs rules
            (toprule/midrule/bottomrule); otherwise emit plain text.
    """
    print('')

    if latex:
        # Number of data columns; the leading "l" column holds row labels.
        # Guard against an empty table instead of raising IndexError.
        num_cols = len(table[0]) if table else 0
        # '%' starts a LaTeX comment and would swallow the rest of the line.
        header_text = str(header_text).replace("%", "\\%")
        print("\\begin{center}")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
        # Span the label column plus every data column, and end the row with
        # a real LaTeX row break (two backslashes) so \midrule is legal next.
        print("\\multicolumn{" + str(num_cols + 1) + "}{c}{" + header_text
            + "} \\\\")
        print("\\midrule")
        col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
            for col_label in col_labels]
    else:
        print('--------', header_text)

    # Build fresh rows rather than inserting into the caller's lists.
    labeled = [[label, *row] for row, label in zip(table, row_labels)]
    # zip() stops at the shorter input; remaining rows are printed unlabelled.
    unlabeled = [list(row) for row in table[len(labeled):]]
    rows = [list(col_labels)] + labeled + unlabeled

    for r, row in enumerate(rows):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        if latex and r == 0:
            # Separate the column-label row from the data rows.
            print("\\midrule")
    if latex:
        print("\\bottomrule")
        print("\\end{tabular}")
        print("\\end{center}")
Example #2
0
def print_table(table, header_text, row_labels, col_labels, colwidth=10,
    latex=True):
    """Pretty-print a 2D array of data, optionally with row/col labels.

    Neither ``table`` nor any of its rows is modified; the labelled rows are
    built as new lists.

    Args:
        table: sequence of rows, each a sequence of cell values.
        header_text: title printed after a ``--------`` marker in plain-text
            mode. It is not printed in LaTeX mode.
        row_labels: labels prepended to the rows of ``table``, paired in
            order. Rows past the last label are printed without a label.
        col_labels: header row, printed before the data rows. In LaTeX mode
            each label is bolded and any '%' is escaped.
        colwidth: column width forwarded to ``misc.print_row``.
        latex: if True, emit a LaTeX ``tabular`` with booktabs rules wrapped
            in an ``adjustbox`` with ``max width=\\textwidth``; otherwise
            emit plain text.
    """
    # NOTE(review): header_text is dropped in LaTeX mode -- confirm the
    # caption/title is produced elsewhere by the caller.
    print("")

    if latex:
        # Number of data columns; the leading "l" column holds row labels.
        # Guard against an empty table instead of raising IndexError.
        num_cols = len(table[0]) if table else 0
        print("\\begin{center}")
        # The trailing '%' comments out the newline after the opening brace.
        print("\\adjustbox{max width=\\textwidth}{%")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
        # '%' is LaTeX's comment character, so escape it inside the labels.
        col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
            for col_label in col_labels]
    else:
        print("--------", header_text)

    # Build fresh rows rather than inserting into the caller's lists.
    labeled = [[label, *row] for row, label in zip(table, row_labels)]
    # zip() stops at the shorter input; remaining rows are printed unlabelled.
    unlabeled = [list(row) for row in table[len(labeled):]]
    rows = [list(col_labels)] + labeled + unlabeled

    for r, row in enumerate(rows):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        if latex and r == 0:
            # Separate the column-label row from the data rows.
            print("\\midrule")
    if latex:
        print("\\bottomrule")
        # The extra '}' closes the \adjustbox group opened above.
        print("\\end{tabular}}")
        print("\\end{center}")
Example #3
0
            results = {
                'step': step,
                'epoch': step / steps_per_epoch,
            }

            for key, val in checkpoint_vals.items():
                results[key] = np.mean(val)

            evals = zip(eval_loader_names, eval_loaders, eval_weights)
            for name, loader, weights in evals:
                acc = misc.accuracy(algorithm, loader, weights, device)
                results[name + '_acc'] = acc

            results_keys = sorted(results.keys())
            if results_keys != last_results_keys:
                misc.print_row(results_keys, colwidth=12)
                last_results_keys = results_keys
            misc.print_row([results[key] for key in results_keys], colwidth=12)

            results.update({'hparams': hparams, 'args': vars(args)})

            epochs_path = os.path.join(args.output_dir, 'results.jsonl')
            with open(epochs_path, 'a') as f:
                f.write(json.dumps(results, sort_keys=True) + "\n")

            algorithm_dict = algorithm.state_dict()
            start_step = step + 1
            checkpoint_vals = collections.defaultdict(lambda: [])

            records = []
            with open(epochs_path, 'r') as f:
Example #4
0
        if args.algorithm == 'MULDENS':
            eval_dict = {}
            for name, loader, weights in evals:
                eval_dict[name] = [loader, weights]
            models_selected = [
                1, 1, 1
            ]  #step_vals['models_selected'] # random stuff. we will not use it anyway

            correct_models_selected_for_each_domain = np.nan * np.ones(
                len(dataset))
            train_envs = [
                i for i in range(len(dataset)) if i not in args.test_envs
            ]
            for t, m in zip(train_envs, models_selected):
                correct_models_selected_for_each_domain[t] = m
            acc_flags = {'ensemble_for_obs': True, 'compute_test_beta': False}

            results_MULDENS = misc.MULDENS_accuracy(
                algorithm, eval_dict, args.test_envs,
                correct_models_selected_for_each_domain, device, acc_flags)
            results_MULDENS['step'] = step_number

            results_keys = sorted(
                [k for k in results_MULDENS.keys() if 'acc' in k])
            results_keys.append('step')

            if i == 0:
                misc.print_row(results_keys, colwidth=25)
            misc.print_row([results_MULDENS[key] for key in results_keys],
                           colwidth=25)