def test_should_build_train_graph_with_defaults(
            self, parse_color_map_from_file_mock, read_examples_mock):
        """Smoke test: build the train graph from the default arguments.

        The injected mocks are configured to return canned values. There are
        no assertions; the test passes when graph construction completes
        without raising.
        """
        # Canned return values for the injected mocks.
        read_examples_mock.return_value = EXAMPLE_PROPS_1
        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP

        model_under_test = Model(create_args(DEFAULT_ARGS))
        model_under_test.build_train_graph(DATA_PATH, BATCH_SIZE)
    def test_should_build_train_graph_with_sample_class_weights(
            self,
            parse_color_map_from_file_mock,
            parse_json_file_mock,
            read_examples_mock):
        """Check train-graph tensor shapes for the sample-weighted loss.

        Builds the graph with ``BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY``,
        ``use_separate_channels`` and ``use_unknown_class`` enabled, then
        verifies the shapes of ``separate_channel_annotation_tensor`` and
        ``pos_weight`` on the returned tensors.
        """
        # Canned return values for the injected mocks.
        read_examples_mock.return_value = EXAMPLE_PROPS_1
        parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS
        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP

        overrides = dict(
            base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
            color_map=COLOR_MAP_FILENAME,
            channels=SOME_LABELS,
            use_separate_channels=True,
            use_unknown_class=True,
        )
        model = Model(create_args(DEFAULT_ARGS, **overrides))
        graph_tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE)

        # One channel per label plus one extra -- presumably the unknown
        # class enabled via use_unknown_class; confirm against Model.
        expected_channels = 1 + len(SOME_LABELS)

        annotation_tensor = graph_tensors.separate_channel_annotation_tensor
        actual_annotation_shape = annotation_tensor.shape.as_list()
        expected_annotation_shape = [
            BATCH_SIZE,
            model.image_height,
            model.image_width,
            expected_channels,
        ]
        assert actual_annotation_shape == expected_annotation_shape

        actual_pos_weight_shape = graph_tensors.pos_weight.shape.as_list()
        expected_pos_weight_shape = [BATCH_SIZE, 1, 1, expected_channels]
        assert actual_pos_weight_shape == expected_pos_weight_shape
    def test_should_build_train_graph_with_class_weights(
            self, parse_color_map_from_file_mock, parse_json_file_mock,
            read_examples_mock):
        """Check that the class-weighted loss yields a ``pos_weight`` tensor.

        Builds the graph with ``BaseLoss.WEIGHTED_CROSS_ENTROPY`` and a
        class-weights file name, then asserts that ``pos_weight`` on the
        returned tensors is not ``None``.
        """
        # Canned return values for the injected mocks.
        read_examples_mock.return_value = EXAMPLE_PROPS_1
        parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS
        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP

        overrides = dict(
            base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
            color_map=COLOR_MAP_FILENAME,
            class_weights=CLASS_WEIGHTS_FILENAME,
            channels=['a', 'b'],
            use_separate_channels=True,
            use_unknown_class=True,
        )
        model_under_test = Model(create_args(DEFAULT_ARGS, **overrides))
        graph_tensors = model_under_test.build_train_graph(
            DATA_PATH, BATCH_SIZE)

        assert graph_tensors.pos_weight is not None