Ejemplo n.º 1
0
    def __init__(self, training_set, test_set, validation_set = None, name = None):
        """
        Bundle the splits of a dataset and check that they are shape-compatible.

        :param training_set: The training split.  Must expose ``inputs`` and ``targets`` collections of array-like
            objects with a ``.shape``.
        :param test_set: The test split, structured like ``training_set``.
        :param validation_set: Optional validation split, structured like ``training_set``.
        :param name: Optional name for the dataset.
        """
        # The conditional must be parenthesised: otherwise it applies to the whole concatenation and ``sets``
        # becomes [] (skipping every shape check) whenever validation_set is None.
        sets = [training_set, test_set] + ([validation_set] if validation_set is not None else [])
        # All splits must agree on per-sample shapes (every axis except the first).
        assert all_equal([[x.shape[1:] for x in s.inputs] for s in sets])
        assert all_equal([[x.shape[1:] for x in s.targets] for s in sets])
        self.training_set = training_set
        self.test_set = test_set
        self._validation_set = validation_set
        self._name = name
        self._n_categories = None
Ejemplo n.º 2
0
    def __init__(self, training_set, test_set, validation_set=None, name=None):
        """
        Bundle the splits of a dataset and check that they are shape-compatible.

        :param training_set: The training split.  Must expose ``inputs`` and ``targets`` collections of array-like
            objects with a ``.shape``.
        :param test_set: The test split, structured like ``training_set``.
        :param validation_set: Optional validation split, structured like ``training_set``.
        :param name: Optional name for the dataset.
        """
        # The conditional must be parenthesised: otherwise it applies to the whole concatenation and ``sets``
        # becomes [] whenever validation_set is None.
        sets = [training_set, test_set] + (
            [validation_set] if validation_set is not None else [])
        # all_equal takes a single collection (as at its other call sites), so the per-split shape lists are
        # passed as one list rather than unpacked into separate positional arguments.
        assert all_equal([[x.shape[1:] for x in s.inputs] for s in sets])
        assert all_equal([[x.shape[1:] for x in s.targets] for s in sets])
        self.training_set = training_set
        self.test_set = test_set
        self._validation_set = validation_set
        self._name = name
        self._n_categories = None
Ejemplo n.º 3
0
    def compare(self, *args):
        """
        Command handler: compare a selection of experiment records.

        Parses ``args`` as command-line tokens, selects the matching records from ``self.exp_record_dict``,
        chooses a comparison function and calls it with the selected records.

        :param args: Tokens: a record-selection string plus optional -l/--last, -r/--results, -o/--original flags.
        :raises RecordSelectionError: If the selection (after optional filtering) contains no records.
        """
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument('user_range', action='store',
                                help='A selection of experiment records to compare.  Examples: "3" or "3-5", or "3,4,5"')
        arg_parser.add_argument('-l', '--last', default=False, action="store_true",
                                help="Use this flag if you want to select Experiments instead of Experiment Records, and just show the last completed.")
        arg_parser.add_argument('-r', '--results', default=False, action="store_true",
                                help="Only compare records with results.")
        arg_parser.add_argument('-o', '--original', default=False, action="store_true",
                                help="Use original compare funcion")
        options = arg_parser.parse_args(args)

        # --last appends the '@result@last' selector to the user's selection.
        selection = options.user_range + '@result@last' if options.last else options.user_range
        selected = select_experiment_records(selection, self.exp_record_dict, flat=True)
        if options.results:
            selected = [record for record in selected if record.has_result()]
        if len(selected) == 0:
            raise RecordSelectionError('No records were selected with "{}"'.format(options.user_range))

        if options.original:
            comparison = compare_experiment_records
        else:
            # Every selected record must agree on the experiment's comparison function.
            candidate_funcs = [record.get_experiment().compare for record in selected]
            assert all_equal(candidate_funcs), "Your records have different comparison functions - {} - so you can't compare them".format(set(candidate_funcs))
            comparison = candidate_funcs[0]

        # Run in this process: comparison functions often use matplotlib, whose Tkinter backend hangs when
        # creating a new figure in a new thread, so launching the comparison in a separate process is avoided.
        comparison(selected)
        _warn_with_prompt(use_prompt=False)
Ejemplo n.º 4
0
    def get_record_table(records = None, headers = ('#', 'Identifier', 'Start Time', 'Duration', 'Status', 'Valid', 'Notes', 'Result'), raise_display_errors = False, result_truncation=100):
        """
        Build a text table with one row per experiment record.

        :param records: An iterable of experiment records.
        :param headers: Columns to show, in order.  Known columns: '#', 'Identifier', 'Start Time', 'Duration',
            'Status', 'Args', 'Valid', 'Notes', 'Result'.  Any other name produces '<Error displaying info>' cells.
        :param raise_display_errors: If True, re-raise the first error hit while computing a cell instead of
            showing '<Error displaying info>'.
        :param result_truncation: Passed as ``truncate_to`` when rendering the 'Result' column.
        :return: The table as a string, produced by ``tabulate``.
        """
        # Each getter receives the row number and record explicitly instead of closing over loop variables.
        column_getters = {
            '#': lambda i, record: i,
            'Identifier': lambda i, record: record.get_id(),
            'Start Time': lambda i, record: record.info.get_field_text(ExpInfoFields.TIMESTAMP, replacement_if_none='?'),
            'Duration': lambda i, record: record.info.get_field_text(ExpInfoFields.RUNTIME, replacement_if_none='?'),
            'Status': lambda i, record: record.info.get_field_text(ExpInfoFields.STATUS, replacement_if_none='?'),
            'Args': lambda i, record: record.info.get_field_text(ExpInfoFields.ARGS, replacement_if_none='?'),
            'Valid': lambda i, record: get_record_invalid_arg_string(record, note_version='short'),
            'Notes': lambda i, record: record.info.get_field_text(ExpInfoFields.NOTES, replacement_if_none='?'),
            # Honour result_truncation; it was previously ignored in favour of a hard-coded 128.
            'Result': lambda i, record: get_oneline_result_string(record, truncate_to=result_truncation),
        }

        def get_row(i, record):
            row = []
            for h in headers:
                try:
                    row.append(column_getters[h](i, record))
                except Exception:  # Not a bare except, so KeyboardInterrupt/SystemExit still propagate.
                    row.append('<Error displaying info>')
                    if raise_display_errors:
                        raise
            return row

        rows = [get_row(i, record) for i, record in enumerate(records)]
        assert all_equal([len(headers)] + [len(row) for row in rows]), 'Header length: {}, Row Lengths: \n {}'.format(len(headers), [len(row) for row in rows])
        return tabulate(rows, headers=headers)
Ejemplo n.º 5
0
    def compare(self, *args):
        """
        Command handler: compare a selection of experiment records.

        The tokens in ``args`` are parsed like a command line; the matching records are looked up in
        ``self.exp_record_dict`` and handed to a comparison function.

        :param args: Tokens: a record-selection string plus optional -l/--last, -r/--results, -o/--original flags.
        :raises RecordSelectionError: If nothing is left to compare after selection and filtering.
        """
        cli = argparse.ArgumentParser()
        cli.add_argument('user_range', action='store',
                         help='A selection of experiment records to compare.  Examples: "3" or "3-5", or "3,4,5"')
        cli.add_argument('-l', '--last', default=False, action="store_true",
                         help="Use this flag if you want to select Experiments instead of Experiment Records, and just show the last completed.")
        cli.add_argument('-r', '--results', default=False, action="store_true",
                         help="Only compare records with results.")
        cli.add_argument('-o', '--original', default=False, action="store_true",
                         help="Use original compare funcion")
        parsed = cli.parse_args(args)

        if parsed.last:
            # Append the '@result@last' selector to the user's selection.
            range_spec = parsed.user_range + '@result@last'
        else:
            range_spec = parsed.user_range
        chosen = select_experiment_records(range_spec, self.exp_record_dict, flat=True)
        if parsed.results:
            chosen = [r for r in chosen if r.has_result()]
        if len(chosen) == 0:
            raise RecordSelectionError('No records were selected with "{}"'.format(parsed.user_range))

        if not parsed.original:
            funcs = [r.get_experiment().compare for r in chosen]
            assert all_equal(funcs), "Your records have different comparison functions - {} - so you can't compare them".format(set(funcs))
            compare_func = funcs[0]
        else:
            compare_func = compare_experiment_records

        # Deliberately not run in a child process: comparison functions often use matplotlib, and its Tkinter
        # backend hangs when asked to create a figure in a new thread.
        compare_func(chosen)
        _warn_with_prompt(use_prompt=False)
Ejemplo n.º 6
0
    def get_record_table(records = None, headers = ('#', 'Identifier', 'Start Time', 'Duration', 'Status', 'Valid', 'Notes', 'Result'), raise_display_errors = False, result_truncation=100):
        """
        Build a text table with one row per experiment record.

        :param records: An iterable of experiment records.
        :param headers: Columns to show, in order.  Known columns: '#', 'Identifier', 'Start Time', 'Duration',
            'Status', 'Args', 'Valid', 'Notes', 'Result'.  Any other name produces '<Error displaying info>' cells.
        :param raise_display_errors: If True, re-raise the first error hit while computing a cell instead of
            showing '<Error displaying info>'.
        :param result_truncation: Passed as ``truncate_to`` when rendering the 'Result' column.
        :return: The table as a string, produced by ``tabulate``.
        """
        # Each getter receives the row number and record explicitly instead of closing over loop variables.
        column_getters = {
            '#': lambda i, record: i,
            'Identifier': lambda i, record: record.get_id(),
            'Start Time': lambda i, record: record.info.get_field_text(ExpInfoFields.TIMESTAMP, replacement_if_none='?'),
            'Duration': lambda i, record: record.info.get_field_text(ExpInfoFields.RUNTIME, replacement_if_none='?'),
            'Status': lambda i, record: record.info.get_field_text(ExpInfoFields.STATUS, replacement_if_none='?'),
            'Args': lambda i, record: record.info.get_field_text(ExpInfoFields.ARGS, replacement_if_none='?'),
            'Valid': lambda i, record: get_record_invalid_arg_string(record, note_version='short'),
            'Notes': lambda i, record: record.info.get_field_text(ExpInfoFields.NOTES, replacement_if_none='?'),
            # Honour result_truncation; it was previously ignored in favour of a hard-coded 128.
            'Result': lambda i, record: get_oneline_result_string(record, truncate_to=result_truncation),
        }

        def get_row(i, record):
            row = []
            for h in headers:
                try:
                    row.append(column_getters[h](i, record))
                except Exception:  # Not a bare except, so KeyboardInterrupt/SystemExit still propagate.
                    row.append('<Error displaying info>')
                    if raise_display_errors:
                        raise
            return row

        rows = [get_row(i, record) for i, record in enumerate(records)]
        assert all_equal([len(headers)] + [len(row) for row in rows]), 'Header length: {}, Row Lengths: \n {}'.format(len(headers), [len(row) for row in rows])
        return tabulate(rows, headers=headers)
Ejemplo n.º 7
0
def compare_experiment_results(experiments, error_if_no_result = False):
    """
    Load the latest results of several experiments and pass them to their shared comparison function.

    :param experiments: An iterable of experiments, all with the same ``comparison_function``.
    :param error_if_no_result: Forwarded to ``load_lastest_experiment_results``.
    :raises AssertionError: If no experiments are given, their comparison functions differ, no comparison
        function was specified, or no results were loaded.
    """
    # Materialise once: the experiments are iterated up to three times below, and a generator would be
    # silently exhausted after the first pass.
    experiments = list(experiments)
    # Explicit check so empty input fails with a message rather than an IndexError on comp_functions[0].
    assert len(experiments) > 0, 'No experiments were given to compare.'
    comp_functions = [ex.comparison_function for ex in experiments]
    assert all_equal(comp_functions), 'Experiments must have same comparison functions.'
    comp_function = comp_functions[0]
    assert comp_function is not None, 'Cannot compare results, because you have not specified any comparison function for this experiment.  Use @ExperimentFunction(comparison_function = my_func)'
    results = load_lastest_experiment_results(experiments, error_if_no_result=error_if_no_result)
    assert len(results), 'Experments {} had no saved results!'.format([e.get_id() for e in experiments])
    comp_function(results)
Ejemplo n.º 8
0
def find_pareto_ixs(cost_arrays):
    """
    Locate the Pareto-efficient points of a grid of costs.

    :param cost_arrays: A collection of nd-arrays representing a grid of costs for different indices.
        All arrays must have the same shape; each array holds one cost.
    :return: A tuple of indices which can be used to index the pareto-efficient points.
    """
    assert all_equal([c.shape for c in cost_arrays])
    # One row per grid point, one column per cost.
    flat_costs = np.reshape(cost_arrays, (len(cost_arrays), -1)).T
    flat_ixs, = np.nonzero(is_pareto_efficient(flat_costs))
    # Shape is passed positionally: the ``dims`` keyword was deprecated (renamed ``shape``) in NumPy 1.16
    # and removed in later releases.
    ixs = np.unravel_index(flat_ixs, cost_arrays[0].shape)
    return ixs
Ejemplo n.º 9
0
def find_pareto_ixs(cost_arrays):
    """
    Locate the Pareto-efficient points of a grid of costs.

    :param cost_arrays: A collection of nd-arrays representing a grid of costs for different indices.
        All arrays must have the same shape; each array holds one cost.
    :return: A tuple of indices which can be used to index the pareto-efficient points.
    """
    assert all_equal([c.shape for c in cost_arrays])
    # One row per grid point, one column per cost.
    flat_costs = np.reshape(cost_arrays, (len(cost_arrays), -1)).T
    flat_ixs, = np.nonzero(is_pareto_efficient_simple(flat_costs))
    # Shape is passed positionally: the ``dims`` keyword was deprecated (renamed ``shape``) in NumPy 1.16
    # and removed in later releases.
    ixs = np.unravel_index(flat_ixs, cost_arrays[0].shape)
    return ixs
Ejemplo n.º 10
0
def filter_results(results, category_filters):
    """
    Keep the entries of a results mapping whose keys match per-position category filters.

    :param results: A mapping from equal-length tuples of category values to measures.
    :param category_filters: One filter per key position.  A list keeps entries whose value at that position is
        in the list and keeps that position in the output key.  Any other value keeps entries whose value equals
        it and drops that position from the output key.
    :return: An OrderedDict from the reduced key tuples to the measures, in the iteration order of ``results``.
    """
    key_lengths = set(len(k) for k in results)
    assert len(key_lengths) <= 1, 'All keys in results must have the same length.'
    if key_lengths:  # Empty results have no key length to check and simply filter to nothing.
        assert key_lengths.pop() == len(category_filters)

    def matches(category_values):
        return all(cv in cf if isinstance(cf, list) else cv == cf
                   for cv, cf in zip(category_values, category_filters))

    def reduced_key(category_values):
        # Only positions filtered by a list survive; fixed positions carry no information any more.
        return tuple(cv for cv, cf in zip(category_values, category_filters) if isinstance(cf, list))

    # Iterate (key, measures) pairs; the previous zip(results) yielded 1-tuples of keys and could not unpack.
    return OrderedDict(
        (reduced_key(category_values), measures)
        for category_values, measures in results.items()
        if matches(category_values))
Ejemplo n.º 11
0
    def get_record_table(record_ids=None,
                         headers=('#', 'Identifier', 'Start Time', 'Duration',
                                  'Status', 'Notes', 'Result'),
                         raise_display_errors=False):
        """
        Build a text table with one row per experiment record id.

        :param record_ids: An iterable of record ids; each is loaded with ``load_experiment_record``.
        :param headers: Columns to show, in order.  Known columns: '#', 'Identifier', 'Start Time', 'Duration',
            'Status', 'Args', 'Notes', 'Result'.  Any other name produces '<Error displaying info>' cells.
        :param raise_display_errors: If True, re-raise the first error hit while computing a cell instead of
            showing '<Error displaying info>'.
        :return: The table as a string, produced by ``tabulate``.
        """
        # Getters receive the row number, id and loaded record explicitly instead of closing over loop variables.
        column_getters = {
            '#':
            lambda i, record_id, record: i,
            'Identifier':
            lambda i, record_id, record: record_id,
            'Start Time':
            lambda i, record_id, record: record.info.get_field_text(
                ExpInfoFields.TIMESTAMP, replacement_if_none='?'),
            'Duration':
            lambda i, record_id, record: record.info.get_field_text(
                ExpInfoFields.RUNTIME, replacement_if_none='?'),
            'Status':
            lambda i, record_id, record: record.info.get_field_text(
                ExpInfoFields.STATUS, replacement_if_none='?'),
            'Args':
            lambda i, record_id, record: record.info.get_field_text(
                ExpInfoFields.ARGS, replacement_if_none='?'),
            'Notes':
            lambda i, record_id, record: record.info.get_field_text(
                ExpInfoFields.NOTES, replacement_if_none='?'),
            'Result':
            lambda i, record_id, record: record.get_one_liner(),
        }

        def get_row(i, record_id, record):
            row = []
            for h in headers:
                try:
                    row.append(column_getters[h](i, record_id, record))
                except Exception:  # Not a bare except, so KeyboardInterrupt/SystemExit still propagate.
                    row.append('<Error displaying info>')
                    if raise_display_errors:
                        raise
            return row

        rows = []
        for i, record_id in enumerate(record_ids):
            # Loading happens outside the per-cell try, so a record that cannot be loaded raises here.
            experiment_record = load_experiment_record(record_id)
            rows.append(get_row(i, record_id, experiment_record))
        assert all_equal(
            [len(headers)] +
            [len(row)
             for row in rows]), 'Header length: {}, Row Lengths: \n {}'.format(
                 len(headers), [len(row) for row in rows])
        return tabulate(rows, headers=headers)
Ejemplo n.º 12
0
def build_table(lookup_fcn,
                row_categories,
                column_categories,
                clear_repeated_headers=True,
                prettify_labels=True,
                row_header_labels=None):
    """
    Build the rows of a table.  You can feed these rows into tabulate to generate pretty things.

    The first rows hold the column headers (one row per column category).  Each following row holds the row
    headers followed by one lookup_fcn value per column.

    :param lookup_fcn: A function of the form:
        data = lookup_fcn(row_info, column_info)
        Where:
            row_info is a tuple of data identifying the row.
            column_info is a tuple of data identifying the column
    :param row_categories: A list<list<str>> of categories that will make up the rows
    :param column_categories: A list<list<str>> of categories that will make up the columns
    :param clear_repeated_headers: True to not repeat row headers (repeats in column headers are blanked too).
    :param prettify_labels: True to pass header labels through prettify_label.
    :param row_header_labels: Labels for the row headers, one per row category, shown on the last
        column-header row.
    :return: A list of rows.
    """
    if row_header_labels is not None:
        assert len(row_header_labels) == len(row_categories)
    rows = []
    # list(): zip returns an iterator on Python 3, and len() is needed below.
    column_headers = list(zip(*itertools.product(*column_categories)))
    for i, c in enumerate(column_headers):
        # Row-header labels go on the last column-header row; other rows get one blank cell per row category.
        # (Sizing the blanks by row_header_labels failed whenever no labels were given.)
        if row_header_labels is not None and i == len(column_headers) - 1:
            row_header = list(row_header_labels)
        else:
            row_header = [' '] * len(row_categories)
        # Parenthesised so the row header is kept whether or not repeats are cleared.
        row = row_header + (blank_out_repeats(c) if clear_repeated_headers else list(c))
        rows.append([prettify_label(el) for el in row] if prettify_labels else row)
    last_row_data = [' '] * len(row_categories)
    for row_info in itertools.product(*row_categories):
        # Test the flag itself: testing the blank_out_repeats function was always true.
        if clear_repeated_headers:
            row_header, last_row_data = zip(
                *[(h, h) if lh != h else (' ', lh)
                  for h, lh in zip(row_info, last_row_data)])
        else:
            row_header = row_info
        if prettify_labels:
            row_header = [prettify_label(str(el)) for el in row_header]
        data = [
            lookup_fcn(row_info, column_info)
            for column_info in itertools.product(*column_categories)
        ]
        rows.append(list(row_header) + data)
    assert all_equal(len(r) for r in rows)
    return rows
Ejemplo n.º 13
0
def nested_map(func, *nested_objs, **kwargs):
    """
    Map ``func`` over the leaves of one or more identically-structured nested objects.

    Works like the built-in ``map`` but descends into containers.  The leaves of every object are extracted,
    ``func`` is applied to each group of corresponding leaves, and the results are reassembled into the
    structure of the first object.

    :param func: Called as ``func(*leaves)`` with one leaf from each nested object; returns the new leaf.
    :param nested_objs: One or more nested objects, e.g. [1, 2, {'a': 3, 'b': (3, 4)}, 5], all sharing one structure.
    :param check_types: (keyword) Assert that the new leaf types match the old leaf types (False by default).
    :param is_container_func: (keyword) A callback which returns True if an object is to be considered a
        container and False otherwise.  Defaults to _is_primitive_container.
    :return: A nested object with the same structure, but func applied to every leaf.
    """
    if 'is_container_func' in kwargs:
        is_container_func = kwargs['is_container_func']
    else:
        is_container_func = _is_primitive_container
    check_types = kwargs.get('check_types', False)
    assert len(nested_objs)>0, 'nested_map requires at least 2 args'

    assert callable(func), 'func must be a function with one argument.'
    structures = [NestedType.from_data(obj, is_container_func=is_container_func) for obj in nested_objs]
    assert all_equal(structures), "The nested objects you provided had different data structures:\n{}".format('\n'.join(str(s) for s in structures))
    leaves_per_obj = [structure.get_leaves(obj, is_container_func=is_container_func, check_types=check_types)
                      for structure, obj in zip(structures, nested_objs)]
    # zip(*...) groups the i-th leaf of every object together.
    mapped_leaves = [func(*leaf_group) for leaf_group in zip(*leaves_per_obj)]
    return structures[0].expand_from_leaves(mapped_leaves, check_types=check_types, is_container_func=is_container_func)
Ejemplo n.º 14
0
def nested_map(func, *nested_objs, **kwargs):
    """
    Map ``func`` over the leaves of one or more identically-structured nested objects.

    The nested-object counterpart of the built-in ``map``.  Corresponding leaves of all objects are passed
    together to ``func``, and the outputs are rebuilt into the structure of the first object.

    :param func: Called as ``func(*leaves)`` with one leaf from each nested object; returns the new leaf.
    :param nested_objs: One or more nested objects, e.g. [1, 2, {'a': 3, 'b': (3, 4)}, 5], all sharing one structure.
    :param check_types: (keyword) Assert that the new leaf types match the old leaf types (False by default).
    :param is_container_func: (keyword) A callback which returns True if an object is to be considered a
        container and False otherwise.  Defaults to _is_primitive_container.
    :return: A nested object with the same structure, but func applied to every leaf.
    """
    if 'is_container_func' in kwargs:
        container_test = kwargs['is_container_func']
    else:
        container_test = _is_primitive_container
    check_types = kwargs.get('check_types', False)
    assert len(nested_objs) > 0, 'nested_map requires at least 2 args'

    assert callable(func), 'func must be a function with one argument.'
    shapes = []
    for obj in nested_objs:
        shapes.append(NestedType.from_data(obj, is_container_func=container_test))
    assert all_equal(
        shapes
    ), "The nested objects you provided had different data structures:\n{}".format(
        '\n'.join(str(s) for s in shapes))
    all_leaves = []
    for shape, obj in zip(shapes, nested_objs):
        all_leaves.append(
            shape.get_leaves(obj,
                             is_container_func=container_test,
                             check_types=check_types))
    # Transpose: one tuple per leaf position, holding that leaf from every object.
    new_leaves = [func(*group) for group in zip(*all_leaves)]
    return shapes[0].expand_from_leaves(
        new_leaves,
        check_types=check_types,
        is_container_func=container_test)
Ejemplo n.º 15
0
    def compare(self, *args):
        """
        Command handler: compare a selection of experiment records.

        :param args: Tokens: a record-selection string plus optional -l/--last and -r/--results flags.
        :raises RecordSelectionError: If the selection (after optional filtering) contains no records.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            'user_range',
            action='store',
            help=
            'A selection of experiment records to compare.  Examples: "3" or "3-5", or "3,4,5"'
        )
        parser.add_argument(
            '-l',
            '--last',
            default=False,
            action="store_true",
            help=
            "Use this flag if you want to select Experiments instead of Experiment Records, and just show the last completed."
        )
        parser.add_argument('-r',
                            '--results',
                            default=False,
                            action="store_true",
                            help="Only compare records with results.")
        args = parser.parse_args(args)

        # --last appends the '>finished>last' selector to the user's selection.
        user_range = args.user_range if not args.last else args.user_range + '>finished>last'
        records = select_experiment_records(user_range,
                                            self.exp_record_dict,
                                            flat=True)
        if args.results:
            records = [rec for rec in records if rec.has_result()]
        # Fail with a clear message (as the other compare handlers do) instead of an IndexError below.
        if len(records) == 0:
            raise RecordSelectionError('No records were selected with "{}"'.format(args.user_range))
        compare_funcs = [rec.get_experiment().compare for rec in records]
        assert all_equal(
            compare_funcs
        ), "Your records have different comparison functions - {} - so you can't compare them".format(
            set(compare_funcs))
        func = compare_funcs[0]
        func(records)
        _warn_with_prompt(use_prompt=not self.close_after)
Ejemplo n.º 16
0
def test_all_equal():
    """Check all_equal on lists, empty and single-element input, generators and unhashable elements."""
    assert all_equal([2, 2, 2])
    assert not all_equal([2, 2, 3])
    assert all_equal([])
    # A single element is trivially all-equal.
    assert all_equal([5])
    # Callers pass generator expressions (e.g. all_equal(len(r) for r in rows)), so those must work.
    assert all_equal(x for x in [1, 1, 1])
    assert not all_equal(x for x in [1, 1, 2])
    # Callers compare lists of shapes, which are unhashable, so equality must not depend on hashing.
    assert all_equal([[1, 2], [1, 2]])
    assert not all_equal([[1, 2], [1, 3]])
Ejemplo n.º 17
0
def build_table(lookup_fcn,
                row_categories,
                column_categories,
                clear_repeated_headers=True,
                prettify_labels=True,
                row_header_labels=None,
                remove_unchanging_cols=False):
    """
    Build the rows of a table.  You can feed these rows into tabulate to generate pretty things.

        Example (requires installing tabulate: pip install tabulate).
        For the table of total utility in prisoner's dilemma (see https://en.wikipedia.org/wiki/Prisoner%27s_dilemma):

        def lookup_function(prisoner_a_choice, prisoner_b_choice):
            total_utility = \
                2 if prisoner_a_choice=='cooperate' and prisoner_b_choice=='cooperate' else \
                3 if prisoner_a_choice != prisoner_b_choice else \
                4 if prisoner_b_choice=='betray' and prisoner_a_choice=='betray' \
                else bad_value((prisoner_a_choice, prisoner_b_choice))
            return total_utility

        rows = build_table(lookup_function, row_categories=['cooperate', 'betray'], column_categories=['cooperate', 'betray'])
        print(tabulate.tabulate(rows))

        ---------  ---------  ------
                   Cooperate  Betray
        Cooperate  2          3
        Betray     3          4
        ---------  ---------  ------

        See more examples in test_tables.

    :param lookup_fcn: A function of the form:
        data = lookup_fcn(row_info, column_info)
        Where:
            row_info is a tuple of data identifying the row (or the bare category value when row_categories is
                a single list of strings).
            column_info is a tuple of data identifying the column (or the bare category value when
                column_categories is a single list of strings).
    :param row_categories: A list<list<str>> of categories that will make up the rows, or a single list<str>
        when there is only one row category.
    :param column_categories: A list<list<str>> of categories that will make up the columns, or a single
        list<str> when there is only one column category.
    :param clear_repeated_headers: True to not repeat row headers.
    :param prettify_labels: True to pass header labels through prettify_label.
    :param row_header_labels: Labels for the row headers, one per row category, shown on the last
        column-header row.
    :param remove_unchanging_cols: If True, remove every column (row-header columns included) whose value is the
        same in all data rows, i.e. the rows below the column headers.
    :return: A list of rows.
    """
    # A flat list of strings means a single category; wrap it so both cases share one code path.
    single_row_category = all(
        isinstance(c, basestring) for c in row_categories)
    single_column_category = all(
        isinstance(c, basestring) for c in column_categories)

    if single_row_category:
        row_categories = [row_categories]
    if single_column_category:
        column_categories = [column_categories]
    if row_header_labels is not None:
        assert len(row_header_labels) == len(row_categories)
    rows = []
    # list() so len() works even where zip returns an iterator (Python 3).
    column_headers = list(zip(*itertools.product(*column_categories)))
    for i, c in enumerate(column_headers):
        # Row-header labels go on the last column-header row; list() lets a tuple of labels be concatenated.
        row_header = list(row_header_labels) if row_header_labels is not None and i == len(
            column_headers) - 1 else [' '] * len(row_categories)
        row = row_header + (blank_out_repeats(c)
                            if clear_repeated_headers else list(c))
        rows.append([prettify_label(el)
                     for el in row] if prettify_labels else row)
    last_row_data = [' '] * len(row_categories)
    for row_info in itertools.product(*row_categories):
        if clear_repeated_headers:
            # Blank each header cell that repeats the value at the same position in the previous row.
            row_header, last_row_data = zip(
                *[(h, h) if lh != h else (' ', lh)
                  for h, lh in zip(row_info, last_row_data)])
        else:
            row_header = row_info
        if prettify_labels:
            row_header = [prettify_label(str(el)) for el in row_header]
        data = [
            lookup_fcn(
                row_info[0] if single_row_category else row_info,
                column_info[0] if single_column_category else column_info)
            for column_info in itertools.product(*column_categories)
        ]
        rows.append(list(row_header) + data)
    assert all_equal(
        (len(r) for r in rows)
    ), "All rows must have equal length.  They now have lengths: {}".format(
        [len(r) for r in rows])

    if remove_unchanging_cols:
        # Walk columns right-to-left so deleting one does not shift the indices still to be visited.
        for col_ix in range(len(rows[0]))[::-1]:
            if all_equal([row[col_ix] for row in rows[len(column_headers):]]):
                for row in rows:
                    del row[col_ix]
    return rows
Ejemplo n.º 18
0
    def get_experiment_list_str(exp_record_dict,
                                just_last_record,
                                view_mode='full',
                                raise_display_errors=False,
                                truncate_result_to=100,
                                cache_result_string=True):
        """
        Build a text table listing experiments and their records.

        :param exp_record_dict: A mapping from experiment id to a sequence of record ids; rows follow its
            iteration order.
        :param just_last_record: Only selects the timestamp column label ('Last Run' if True, else 'All Runs').
        :param view_mode: 'full' for all columns, or 'results' for just E#, R#, Name and Result.
        :param raise_display_errors: If True, re-raise errors hit while computing a cell instead of showing
            '<Display Error>'.
        :param truncate_result_to: Passed as ``truncate_to`` when rendering the 'Result' column.
        :param cache_result_string: If True, wrap the one-line result function with memoize_to_disk_with_settings.
        :return: The table as a string, produced by ``tabulate``.
        """
        headers = {
            'full': [
                'E#', 'R#', 'Name',
                'Last Run' if just_last_record else 'All Runs', 'Duration',
                'Status', 'Valid', 'Result'
            ],
            'results': ['E#', 'R#', 'Name', 'Result']
        }[view_mode]

        oneliner_func = memoize_to_disk_with_settings(suppress_info=True)(
            get_oneline_result_string
        ) if cache_result_string else get_oneline_result_string

        def get_field(header, i, j, index, exp_id, experiment_record):
            # Experiment number and name are only shown on an experiment's first record row.
            try:
                if header == '#':
                    return index
                if header == 'E#':
                    return str(i) if j == 0 else ''
                if header == 'R#':
                    return j
                if header == 'Name':
                    return exp_id if j == 0 else ''
                if header in ('Last Run', 'All Runs'):
                    return experiment_record.info.get_field_text(ExpInfoFields.TIMESTAMP)
                if header == 'Duration':
                    return experiment_record.info.get_field_text(ExpInfoFields.RUNTIME)
                if header == 'Status':
                    return experiment_record.info.get_field_text(ExpInfoFields.STATUS)
                if header == 'Valid':
                    return get_record_invalid_arg_string(experiment_record)
                if header == 'Result':
                    return oneliner_func(experiment_record.get_id(), truncate_to=truncate_result_to)
                return '???'
            except Exception:  # Not a bare except, so KeyboardInterrupt/SystemExit still propagate.
                if raise_display_errors:
                    raise
                return '<Display Error>'

        rows = []
        for i, (exp_id, record_ids) in enumerate(exp_record_dict.items()):
            if len(record_ids) == 0:
                # Pad with '-' so the placeholder has one cell per header in every view mode
                # (a fixed 8-cell row broke the 4-column 'results' view).
                rows.append([str(i), '', exp_id, '<No Records>'] + ['-'] * (len(headers) - 4))
                continue
            for j, record_id in enumerate(record_ids):
                index = '{}.{}'.format(i, j) if j == 0 else '{}.{}'.format('`' * len(str(i)), j)
                try:
                    experiment_record = load_experiment_record(record_id)
                except Exception:
                    # Unloadable record: its record-dependent cells then fail inside get_field.
                    experiment_record = None
                rows.append([get_field(h, i, j, index, exp_id, experiment_record) for h in headers])
        # str() each length: joining raw ints raised TypeError and hid the assertion message.
        assert all_equal([len(headers)] + [len(row) for row in rows]
                         ), 'Header length: {}, Row Lengths: \n  {}'.format(
                             len(headers),
                             '\n'.join(str(len(row)) for row in rows))
        table = tabulate(rows, headers=headers)
        return table
Ejemplo n.º 19
0
def test_all_equal():
    """Check all_equal on lists, empty and single-element input, generators and unhashable elements."""
    assert all_equal([2, 2, 2])
    assert not all_equal([2, 2, 3])
    assert all_equal([])
    # A single element is trivially all-equal.
    assert all_equal([5])
    # Callers pass generator expressions (e.g. all_equal(len(r) for r in rows)), so those must work.
    assert all_equal(x for x in [1, 1, 1])
    assert not all_equal(x for x in [1, 1, 2])
    # Callers compare lists of shapes, which are unhashable, so equality must not depend on hashing.
    assert all_equal([[1, 2], [1, 2]])
    assert not all_equal([[1, 2], [1, 3]])
Ejemplo n.º 20
0
    def get_experiment_list_str(exp_record_dict,
                                just_last_record,
                                view_mode='full',
                                raise_display_errors=False,
                                exp_filter=None):
        """
        Build a text table listing experiments and their records, optionally filtered.

        :param exp_record_dict: A mapping from experiment id to a sequence of record ids; rows follow its
            iteration order.
        :param just_last_record: Only selects the timestamp column label ('Last Run' if True, else 'All Runs').
        :param view_mode: 'full' for all columns, or 'results' for just E#, R#, Name and Result.
        :param raise_display_errors: If True, re-raise errors hit while computing a cell instead of showing '<Error>'.
        :param exp_filter: Optional selection passed to ``select_experiments``; only matching experiments are shown.
        :return: The table as a string, produced by ``tabulate``, with a note appended when a filter is used.
        """
        headers = {
            'full': [
                'E#', 'R#', 'Name',
                'Last Run' if just_last_record else 'All Runs', 'Duration',
                'Status', 'Valid', 'Result'
            ],
            'results': ['E#', 'R#', 'Name', 'Result']
        }[view_mode]

        def get_field(header, i, j, index, exp_id, experiment_record):
            # Experiment number and name are only shown on an experiment's first record row.
            try:
                if header == '#':
                    return index
                if header == 'E#':
                    return str(i) if j == 0 else ''
                if header == 'R#':
                    return j
                if header == 'Name':
                    return exp_id if j == 0 else ''
                if header in ('Last Run', 'All Runs'):
                    return experiment_record.info.get_field_text(ExpInfoFields.TIMESTAMP)
                if header == 'Duration':
                    return experiment_record.info.get_field_text(ExpInfoFields.RUNTIME)
                if header == 'Status':
                    return experiment_record.info.get_field_text(ExpInfoFields.STATUS)
                if header == 'Valid':
                    return experiment_record.get_invalid_arg_note()
                if header == 'Result':
                    return experiment_record.get_one_liner()
                return '???'
            except Exception:  # Not a bare except, so KeyboardInterrupt/SystemExit still propagate.
                if raise_display_errors:
                    raise
                return '<Error>'

        exps_to_show = set(
            exp_record_dict.keys()) if exp_filter is None else set(
                select_experiments(exp_filter, exp_record_dict))

        rows = []
        for i, (exp_id, record_ids) in enumerate(exp_record_dict.items()):
            # Filter before loading anything.  i still counts over the whole dict, so E# numbers are the same
            # with or without a filter.  Record-less experiments are filtered too (they previously always showed).
            if exp_id not in exps_to_show:
                continue
            if len(record_ids) == 0:
                # Pad with '-' so the placeholder has one cell per header in every view mode.
                rows.append([str(i), '', exp_id, '<No Records>'] + ['-'] * (len(headers) - 4))
                continue
            for j, record_id in enumerate(record_ids):
                index = '{}.{}'.format(i, j) if j == 0 else '{}.{}'.format('`' * len(str(i)), j)
                try:
                    experiment_record = load_experiment_record(record_id)
                except Exception:
                    # Unloadable record: its record-dependent cells then fail inside get_field.
                    experiment_record = None
                rows.append([get_field(h, i, j, index, exp_id, experiment_record) for h in headers])
        # str() each length: joining raw ints raised TypeError and hid the assertion message.
        assert all_equal([len(headers)] + [len(row) for row in rows]
                         ), 'Header length: {}, Row Lengths: \n  {}'.format(
                             len(headers),
                             '\n'.join(str(len(row)) for row in rows))
        table = tabulate(rows, headers=headers)
        if exp_filter:
            table += '\n[Filtered with "{}" to show {}/{} experiments]'.format(
                exp_filter, len(exps_to_show), len(exp_record_dict))
        return table
Ejemplo n.º 21
0
def build_table(lookup_fcn, row_categories, column_categories, clear_repeated_headers = True, prettify_labels = True,
            row_header_labels = None, remove_unchanging_cols = False):
    """
    Build the rows of a table.  You can feed these rows into tabulate to generate pretty things.

        Example (requires installing tabulate: pip install tabulate).
        For the table of total utility in prisoner's dilemma (see https://en.wikipedia.org/wiki/Prisoner%27s_dilemma):

        def lookup_function(prisoner_a_choice, prisoner_b_choice):
            total_utility = \
                2 if prisoner_a_choice=='cooperate' and prisoner_b_choice=='cooperate' else \
                3 if prisoner_a_choice != prisoner_b_choice else \
                4 if prisoner_b_choice=='betray' and prisoner_a_choice=='betray' \
                else bad_value((prisoner_a_choice, prisoner_b_choice))
            return total_utility

        rows = build_table(lookup_function, row_categories=['cooperate', 'betray'], column_categories=['cooperate', 'betray'])
        print(tabulate.tabulate(rows))

        ---------  ---------  ------
                   Cooperate  Betray
        Cooperate  2          3
        Betray     3          4
        ---------  ---------  ------

        See more examples in test_tables.

    :param lookup_fcn: A function of the form:
        data = lookup_fcn(row_info, column_info)
        Where:
            row_info is a tuple of data identifying the row (or the bare category value when row_categories is
                a single list of strings).
            column_info is a tuple of data identifying the column (or the bare category value when
                column_categories is a single list of strings).
    :param row_categories: A list<list<str>> of categories that will make up the rows, or a single list<str>
        when there is only one row category.
    :param column_categories: A list<list<str>> of categories that will make up the columns, or a single
        list<str> when there is only one column category.
    :param clear_repeated_headers: True to not repeat row headers.
    :param prettify_labels: True to pass header labels through prettify_label.
    :param row_header_labels: Labels for the row headers, one per row category, shown on the last
        column-header row.
    :param remove_unchanging_cols: If True, remove every column (row-header columns included) whose value is the
        same in all data rows, i.e. the rows below the column headers.
    :return: A list of rows.
    """
    # A flat list of strings means a single category; wrap it so both cases share one code path.
    single_row_category = all(isinstance(c, string_types) for c in row_categories)
    single_column_category = all(isinstance(c, string_types) for c in column_categories)

    if single_row_category:
        row_categories = [row_categories]
    if single_column_category:
        column_categories = [column_categories]
    if row_header_labels is not None:
        assert len(row_header_labels) == len(row_categories)
    rows = []
    column_headers = list(zip(*itertools.product(*column_categories)))
    for i, c in enumerate(column_headers):
        # Row-header labels go on the last column-header row; list() lets a tuple of labels be concatenated.
        row_header = list(row_header_labels) if row_header_labels is not None and i==len(column_headers)-1 else [' ']*len(row_categories)
        row = row_header+(blank_out_repeats(c) if clear_repeated_headers else list(c))
        rows.append([prettify_label(el) for el in row] if prettify_labels else row)
    last_row_data = [' ']*len(row_categories)
    for row_info in itertools.product(*row_categories):
        if clear_repeated_headers:
            # Blank each header cell that repeats the value at the same position in the previous row.
            row_header, last_row_data = zip(*[(h, h) if lh!=h else (' ', lh) for h, lh in zip(row_info, last_row_data)])
        else:
            row_header = row_info
        if prettify_labels:
            row_header = [prettify_label(str(el)) for el in row_header]
        data = [lookup_fcn(row_info[0] if single_row_category else row_info, column_info[0] if single_column_category else column_info) for column_info in itertools.product(*column_categories)]
        rows.append(list(row_header) + data)
    assert all_equal((len(r) for r in rows)), "All rows must have equal length.  They now have lengths: {}".format([len(r) for r in rows])

    if remove_unchanging_cols:
        # Walk columns right-to-left so deleting one does not shift the indices still to be visited.
        for col_ix in range(len(rows[0]))[::-1]:
            if all_equal([row[col_ix] for row in rows[len(column_headers):]]):
                for row in rows:
                    del row[col_ix]
    return rows