コード例 #1
0
def assert_dir_exact_files(test: unittest.TestCase, directory: str,
                           contents: dict) -> None:
    """Assert that ``directory`` contains exactly the entries listed in ``contents``.

    Like ``assert_dir_files``, but first requires that the names returned by
    ``os.listdir(directory)`` are exactly the keys of ``contents``. The per-file
    check is then delegated to ``assert_dir_files``.

    Args:
        test (unittest.TestCase): The test case whose assertions are used.
        directory (str): Directory to check (any path-like object also works).
        contents (dict): Mapping ``filename -> content``.

    Returns:
        None. Failures are reported by raising, not through a return value.

    Raises:
        AssertionError: If the directory entries differ from the keys of
            ``contents``. Any failure raised by ``assert_dir_files``
            propagates unchanged.
    """
    # An f-string (instead of 'For: ' + directory) keeps the message working for
    # path-like directories, which os.listdir accepts too.
    test.assertSetEqual(set(contents), set(os.listdir(directory)),
                        f'For: {directory}')
    assert_dir_files(test, directory, contents)
コード例 #2
0
ファイル: training.py プロジェクト: Planet-AI-GmbH/tfaip
def test_tensorboard_content(test: unittest.TestCase, output_dir: str, logs: Dict[str, Any], trainer):
    """Check that the TensorBoard event files under ``output_dir`` contain the expected logs.

    Every ``events.out.*`` file found recursively below ``output_dir`` must sit in a
    directory named ``train``, ``validation`` or ``lav_*``. For each file, the
    single-value summaries written at step 0 whose tag starts with ``epoch_`` are
    matched (with the ``epoch_`` prefix removed) against two sources:

    * the keys of the TensorBoard handlers of ``trainer.scenario.keras_train_model``,
      plus ``lr`` for the ``train`` directory. Each of these must appear in every
      event file.
    * the keys of ``logs`` (with ``@`` replaced by ``_``). Tags are prefixed with
      ``val_`` for ``validation`` and ``<dir>_`` for ``lav_*`` directories. Each
      expected key must be found exactly once across all files.

    Only the keys of ``logs`` are checked; the values are not compared.

    Args:
        test: Test case whose assertions are used.
        output_dir: Directory into which TensorBoard logs were written.
        logs: Expected log names (keys) mapped to values (ignored).
        trainer: Object exposing ``scenario.keras_train_model``.

    Raises:
        AssertionError: If an event file is in an unexpected directory, an event
            file has an ``epoch_`` tag matching no expected log, a handler output
            is missing from an event file, or an expected log was never found.
    """
    tb_data_handler = TensorBoardDataHandler(trainer.scenario.keras_train_model)

    from tensorflow.python.summary.summary_iterator import summary_iterator

    epoch_prefix = "epoch_"
    # Loop-invariant: the names of the extra outputs every event file must contain.
    handler_names = frozenset(tb_data_handler.tensorboard_handlers.keys())

    all_event_files = glob.glob(os.path.join(output_dir, "**", "events.out.*"), recursive=True)
    # Tags are expected with "@" replaced by "_". Presumably the summary writer
    # sanitizes "@" in tag names; TODO confirm.
    logs_to_find = {k.replace("@", "_"): v for k, v in logs.items()}

    for event_file in all_event_files:
        log_type = os.path.dirname(os.path.relpath(event_file, output_dir))
        test.assertTrue(log_type in {"train", "validation"} or log_type.startswith("lav_"),
                        f"Unexpected log directory '{log_type}' for {event_file}")

        # Map the directory to the prefix under which its metrics appear in `logs`.
        if log_type == "train":
            prefix = ""
        elif log_type == "validation":
            prefix = "val_"
        else:
            prefix = log_type + "_"

        additional_outputs_per_event = set(handler_names)
        # Equality, not `log_type in "train"`, which is a substring test.
        if log_type == "train":
            additional_outputs_per_event.add("lr")

        # Check that (at least for step 0) all metrics/losses are written to the tensorboard log
        for e in summary_iterator(event_file):
            if e.step != 0 or len(e.summary.value) != 1:
                continue
            value = e.summary.value[0]
            if not value.tag.startswith(epoch_prefix):
                continue
            tag = value.tag[len(epoch_prefix):]
            if tag in additional_outputs_per_event:
                additional_outputs_per_event.remove(tag)
                continue
            tag = prefix + tag
            test.assertIn(tag, logs_to_find, f"Unexpected or duplicate log '{tag}' in {event_file}")
            del logs_to_find[tag]

        test.assertSetEqual(set(), additional_outputs_per_event, f"Missing outputs in {event_file}")

    test.assertDictEqual({}, logs_to_find)
コード例 #3
0
def validate_attributed_ast(test: unittest.TestCase,
                            source_unit: ast.SourceUnit):
    """Check the ``scope_post`` variable sets computed for ``source_unit``'s statements.

    Every line of ``source_unit.source`` that contains ``//`` is treated as an
    annotation. The comma-separated names between the first ``//`` and the next
    ``//`` (if any) are the variables expected in the ``scope_post`` of the
    statement on that (1-based) line. An empty annotation means "no variables".

    Args:
        test: Test case whose assertions are used.
        source_unit: Attributed AST. Its statements expose ``src_line`` and
            ``scope_post`` (a collection of variable-declaration ids).

    Raises:
        AssertionError: If any of the following holds:
            * two variable declarations share an id;
            * two statements share a source line;
            * an annotated line has no statement;
            * a statement's scope differs from its annotation.
    """
    grammar = get_solidity_grammar_instance()

    stmts = []
    ids_to_vars = {}

    def register_variables(node, _, __):
        # Collect every statement and map each declared variable's id to its name.
        if isinstance(node, Statement):
            stmts.append(node)

        if isinstance(node, VariableDeclaration):
            test.assertNotIn(node.id, ids_to_vars)
            ids_to_vars[node.id] = node.name

    grammar.traverse(source_unit, register_variables)

    # Assert unambiguous line<=>stmt correspondence
    test.assertEqual(len(stmts), len({s.src_line for s in stmts}))

    stmts_by_line = {s.src_line: s for s in stmts}

    # start=1 so the reported number matches the 1-based `src_line` lookup below.
    for line_no, source_line in enumerate(source_unit.source.split("\n"), start=1):
        if "//" not in source_line:
            continue

        _, annotation, *_ = source_line.split("//")

        # Drop empty entries so "//" alone or a trailing comma does not expect "".
        expected_scope_vars = {v.strip() for v in annotation.split(",") if v.strip()}

        test.assertIn(line_no, stmts_by_line, f"Line: {line_no} has no statement")
        stmt = stmts_by_line[line_no]
        actual_scope_vars = {ids_to_vars[var_id] for var_id in stmt.scope_post}

        test.assertSetEqual(expected_scope_vars, actual_scope_vars,
                            f"Line: {line_no}")
コード例 #4
0
ファイル: tests.py プロジェクト: JonLevin25/coding-problems
def validate_tree_node(tester: unittest.TestCase, node: WordTreeNode,
                       expected_words: Set[str], expected_children: Set[str]):
    """Assert that ``node`` holds exactly the expected words and child keys.

    The words are checked first, then the child keys. The first mismatch raises
    the ``AssertionError`` produced by ``tester.assertSetEqual``.

    Args:
        tester: Test case whose assertions are used.
        node: Tree node under test.
        expected_words: Value that ``node.words`` must equal as a set.
        expected_children: Keys that ``node.children`` must have.
    """
    actual_words = node.words
    tester.assertSetEqual(actual_words, expected_words)

    child_keys = set(node.children.keys())
    tester.assertSetEqual(child_keys, expected_children)