def test_serialize_to_file(self, log_config_parameter_to_mlflow_fn):
        """Serializing configs must log every flattened parameter to MLflow.

        ``log_config_parameter_to_mlflow_fn`` is a mock (presumably injected
        by a patch decorator outside this view) whose side effect is a no-op,
        so only the recorded calls are inspected. Besides one call per
        flattened parameter, a single ``("CLUSTER_SPEC", {})`` call is
        expected, and the input configs must be left unmodified.
        """
        def _log_config_parameter_to_mlflow(param_name, param_value):
            pass

        log_config_parameter_to_mlflow_fn.side_effect = (
            _log_config_parameter_to_mlflow)

        serializer = project_serializer.MlflowConfigSerializer(
            save_dir="", single_config_names=["model", "dataset"])
        serializer.serialize_to_file(self.configs_to_log)

        flat_expected = nest_utils.flatten_nested_struct(
            self.config_serialized_must, separator="/")

        expected_calls = [mock_call(param_name, param_value)
                          for param_name, param_value in flat_expected.items()]
        expected_calls.append(mock_call("CLUSTER_SPEC", {}))
        # Each call is checked on its own, so the logging order is irrelevant.
        for each_expected_call in expected_calls:
            log_config_parameter_to_mlflow_fn.assert_has_calls(
                [each_expected_call])

        # One call per flattened parameter plus the CLUSTER_SPEC call.
        self.assertEqual(len(flat_expected) + 1,
                         log_config_parameter_to_mlflow_fn.call_count)

        # Serialization must not mutate the configs it was given.
        self.assertDictEqual(self.configs_to_log_copy, self.configs_to_log)
示例#2
0
    def test_call_np(self):
        """Run a two-processor numpy DataPipe and check results and calls.

        One dummy reader feeds two processors. The reader's ``read`` and the
        processors' ``process`` methods are wrapped with
        ``MagicMock(wraps=...)``: the real implementations still run, and the
        recorded calls are compared with the expected ones at the end.
        """
        reader = DataReaderDummyNP(name="reader",
                                   file_list_keys_mapping={
                                       "data1_fl": "data1",
                                       "data2_fl": "data2"
                                   }).build()
        processor1 = _DataProcessorNP(inbound_nodes=["reader"],
                                      name="processor1").build()
        processor2 = _DataProcessorNP2(inbound_nodes=["reader"],
                                       name="processor2").build()
        data_pipe = DataPipe(readers=[reader],
                             processors=[processor1, processor2]).build()

        # Record calls while still running the real implementations.
        reader.read = MagicMock(wraps=reader.read)
        processor1.process = MagicMock(wraps=processor1.process)
        processor2.process = MagicMock(wraps=processor2.process)

        data_pipe.build_dna()
        data_pipe.save_target = self.get_temp_dir()
        for each_sample_file_list in self.file_list_with_floats:
            result = data_pipe(**each_sample_file_list)
            result_must = {
                "reader": {
                    "data1": float(each_sample_file_list["data1_fl"]),
                    "data2": float(each_sample_file_list["data2_fl"])
                },
                "processor1": {
                    "data1_p": float(each_sample_file_list["data1_fl"]) + 20
                },
                "processor2": {
                    "data2_p": float(each_sample_file_list["data2_fl"]) + 200
                }
            }
            self.assertAllClose(result_must, result)

        # The reader is expected to be called with keyword names "data1" /
        # "data2"; ``k[:5]`` drops the "_fl" suffix of the file-list keys to
        # build them.
        reader_call_data_must = [
            mock_call(
                **{
                    k[:5]: v
                    for k, v in each_sample_file_list.items()
                    if k in ["data1_fl", "data2_fl"]
                }) for each_sample_file_list in self.file_list_with_floats
        ]
        processor1_call_data_must = [
            mock_call(data1=float(each_sample_file_list["data1_fl"]))
            for each_sample_file_list in self.file_list_with_floats
        ]
        processor2_call_data_must = [
            mock_call(data2=float(each_sample_file_list["data2_fl"]))
            for each_sample_file_list in self.file_list_with_floats
        ]

        reader.read.assert_has_calls(reader_call_data_must)
        processor1.process.assert_has_calls(processor1_call_data_must)
        processor2.process.assert_has_calls(processor2_call_data_must)
示例#3
0
async def test_switches(hass):
    """Switches are created for each port and relay on/off actions.

    Port '0' is named 'Doorbell' and port '1' is unnamed, so the second
    switch falls back to a generic relay name. Turning the doorbell switch on
    and then off must call the port's ``action`` with '/' and then '\\'.
    """
    device = await setup_device(hass)
    device.api.vapix.ports = {'0': Mock(), '1': Mock()}
    for port_id, port_name in (('0', 'Doorbell'), ('1', '')):
        device.api.vapix.ports[port_id].name = port_name

    for each_event in EVENTS:
        device.api.stream.event.manage_event(each_event)
    await hass.async_block_till_done()

    assert len(hass.states.async_all()) == 3

    expected_relays = (
        ('switch.model_0_doorbell', 'off', 'model 0 Doorbell'),
        ('switch.model_0_relay_1', 'on', 'model 0 Relay 1'),
    )
    for entity_id, state_must, name_must in expected_relays:
        relay = hass.states.get(entity_id)
        assert relay.state == state_must
        assert relay.name == name_must

    device.api.vapix.ports['0'].action = Mock()

    for service in ('turn_on', 'turn_off'):
        await hass.services.async_call('switch',
                                       service,
                                       {'entity_id': 'switch.model_0_doorbell'},
                                       blocking=True)

    # turn_on is expected to send '/' and turn_off '\\' to port 0.
    assert device.api.vapix.ports['0'].action.call_args_list == [
        mock_call('/'), mock_call('\\')]
示例#4
0
    def test_retrieve_client(self):
        """Retrieving the client logs its setup steps and returns a Motor client.

        ``Settings.MONGO_URL`` is pointed at ``MONGO_TEST_URL`` through
        ``self.patch_obj`` before the client is requested.
        """
        self.patch_obj(target=settings.Settings, attribute="MONGO_URL",
                       new=settings.Settings.MONGO_TEST_URL)

        client = self.db_handler.retrieve_client()

        expected_messages = ("Initialization of MongoDB",
                             "Creating MongoDB client")
        self.debug_mock.assert_has_calls(
            calls=[mock_call(msg=message) for message in expected_messages])
        self.assertIsInstance(client, motor.motor_asyncio.AsyncIOMotorClient)
示例#5
0
    def test_delete_student_state(self):
        """Deleting a student's module state emits the expected tracking events.

        Both trackers are patched while ``reset_student_attempts`` runs with
        ``delete_module=True``. The enrollment tracker must emit a
        state-deleted event, and the grades tracker must emit a course-grade
        calculation plus a "now failed" event that share its transaction id.
        """
        self.submit_question_answer('p1', {'2_1': 'choice_choice_2'})

        with patch('lms.djangoapps.instructor.enrollment.tracker') as enrollment_tracker, \
                patch('lms.djangoapps.grades.events.tracker') as events_tracker:
            reset_student_attempts(
                self.course.id, self.student, self.problem.location, self.instructor, delete_module=True,
            )
        course = self.store.get_course(self.course.id, depth=0)

        # The transaction id is generated during the call under test, so it is
        # read back from the event payload of the first recorded tracker call.
        first_call_args = enrollment_tracker.method_calls[0][1]
        transaction_id = first_call_args[1]['event_transaction_id']
        transaction_type = events.STATE_DELETED_EVENT_TYPE
        user_id = str(self.student.id)
        course_id = str(self.course.id)

        enrollment_tracker.emit.assert_called_with(
            transaction_type,
            {
                'user_id': user_id,
                'course_id': course_id,
                'problem_id': str(self.problem.location),
                'instructor_id': str(self.instructor.id),
                'event_transaction_id': transaction_id,
                'event_transaction_type': transaction_type,
            }
        )

        grade_calculated = mock_call(
            events.COURSE_GRADE_CALCULATED,
            {
                'percent_grade': 0.0,
                'grading_policy_hash': 'ChVp0lHGQGCevD0t4njna/C44zQ=',
                'user_id': user_id,
                'letter_grade': '',
                'event_transaction_id': transaction_id,
                'event_transaction_type': transaction_type,
                'course_id': course_id,
                'course_edited_timestamp': str(course.subtree_edited_on),
                'course_version': str(course.course_version),
            }
        )
        grade_now_failed = mock_call(
            events.COURSE_GRADE_NOW_FAILED_EVENT_TYPE,
            {
                'user_id': user_id,
                'event_transaction_id': transaction_id,
                'event_transaction_type': transaction_type,
                'course_id': course_id,
            }
        )
        events_tracker.emit.assert_has_calls(
            [grade_calculated, grade_now_failed], any_order=True)
示例#6
0
    def test_description_changed(self):
        """A topology change with no usable servers logs four debug messages.

        Both ``has_writable_server`` and ``has_readable_server`` are stubbed
        to return False, so the "no writable" and "no readable" messages must
        follow the topology-update and type-change messages.
        """
        for each_check in ("has_writable_server", "has_readable_server"):
            getattr(self.event.new_description,
                    each_check).return_value = False

        self.logger.description_changed(event=self.event)

        topology_id = self.event.topology_id
        previous_type = self.event.previous_description.topology_type_name
        new_type = self.event.new_description.topology_type_name
        expected_messages = [
            f"Topology description updated for topology id {topology_id}",
            f"Topology {topology_id} changed type from "
            f"{previous_type} to "
            f"{new_type}",
            "No writable servers available.",
            "No readable servers available.",
        ]
        self.debug_mock.assert_has_calls(
            calls=[mock_call(message) for message in expected_messages])
示例#7
0
    def test_add_optim_configs_to_handler(self):
        """Plugin optimization configs are registered under their plugin name.

        plugin1's getter returns None, plugin2's returns two
        (config, variables) pairs and plugin3's returns one. Every pair must
        reach ``add_config_with_variables`` with ``name`` set to the plugin
        that produced it. Every getter and the session initialization must be
        called exactly once.
        """

        def get_optimization_configs_with_variables_config2():
            return [('config1', 'vars1'), ('config2', 'vars2')]

        def get_optimization_configs_with_variables_config3():
            return [('config3', 'vars3')]

        def _add_config_with_variables(config_with_vars, name):
            return None

        model_handler = self._get_model_handler([])
        model_handler.build()
        optimization_handler = model_handler.optimization_handler

        plugin1, plugin2, plugin3 = [
            ModelPlugin(name=plugin_name)
            for plugin_name in ('plugin1', 'plugin2', 'plugin3')]
        plugin1.get_optimization_configs_with_variables = MagicMock(
            return_value=None)
        plugin2.get_optimization_configs_with_variables = MagicMock(
            side_effect=get_optimization_configs_with_variables_config2)
        plugin3.get_optimization_configs_with_variables = MagicMock(
            side_effect=get_optimization_configs_with_variables_config3)
        optimization_handler.initialize_for_session = MagicMock(
            return_value=None)
        optimization_handler.add_config_with_variables = MagicMock(
            side_effect=_add_config_with_variables)
        self.model.plugins = {'plugin1': plugin1,
                              'plugin2': plugin2,
                              'plugin3': plugin3}

        model_handler.add_optim_configs_to_handler()

        for each_plugin in (plugin1, plugin2, plugin3):
            (each_plugin.get_optimization_configs_with_variables
             .assert_called_once_with())
        optimization_handler.initialize_for_session.assert_called_once_with()

        expected_registrations = (
            (('config1', 'vars1'), "plugin2"),
            (('config2', 'vars2'), "plugin2"),
            (('config3', 'vars3'), "plugin3"),
        )
        optimization_handler.add_config_with_variables.assert_has_calls(
            [mock_call(config_with_vars, name=plugin_name)
             for config_with_vars, plugin_name in expected_registrations])
示例#8
0
async def test_switches(hass):
    """Switches are created for each port and relay on/off actions.

    Port "0" is named "Doorbell" and port "1" is unnamed, so the second
    switch falls back to a generic relay name. Turning the doorbell switch on
    and then off must call the port's ``action`` with "/" and then "\\".
    """
    device = await setup_axis_integration(hass)

    device.api.vapix.ports = {"0": Mock(), "1": Mock()}
    for port_id, port_name in (("0", "Doorbell"), ("1", "")):
        device.api.vapix.ports[port_id].name = port_name

    for each_event in EVENTS:
        device.api.stream.event.manage_event(each_event)
    await hass.async_block_till_done()

    assert len(hass.states.async_entity_ids("switch")) == 2

    expected_relays = (
        (f"switch.{NAME}_doorbell", "off", f"{NAME} Doorbell"),
        (f"switch.{NAME}_relay_1", "on", f"{NAME} Relay 1"),
    )
    for entity_id, state_must, name_must in expected_relays:
        relay = hass.states.get(entity_id)
        assert relay.state == state_must
        assert relay.name == name_must

    device.api.vapix.ports["0"].action = Mock()

    for service in ("turn_on", "turn_off"):
        await hass.services.async_call("switch",
                                       service,
                                       {"entity_id": f"switch.{NAME}_doorbell"},
                                       blocking=True)

    # turn_on is expected to send "/" and turn_off "\\" to port 0.
    expected_actions = [mock_call("/"), mock_call("\\")]
    assert device.api.vapix.ports["0"].action.call_args_list == expected_actions
    def test_submit_answer(self, events_tracker):
        """Submitting an answer emits problem-submitted and course-grade events.

        ``events_tracker`` is a mock (presumably patched in by a decorator
        outside this view). The transaction id is read back from the payload
        of its first ``emit`` call and must be shared by both events.
        """
        self.submit_question_answer('p1', {'2_1': 'choice_choice_2'})
        course = self.store.get_course(self.course.id, depth=0)

        first_emit_args = events_tracker.emit.mock_calls[0][1]
        transaction_id = first_emit_args[1]['event_transaction_id']
        transaction_type = events.PROBLEM_SUBMITTED_EVENT_TYPE
        user_id = str(self.student.id)
        course_id = str(self.course.id)

        problem_submitted = mock_call(
            transaction_type,
            {
                'user_id': user_id,
                'event_transaction_id': transaction_id,
                'event_transaction_type': transaction_type,
                'course_id': course_id,
                'problem_id': str(self.problem.location),
                'weighted_earned': 2.0,
                'weighted_possible': 2.0,
            },
        )
        grade_calculated = mock_call(
            events.COURSE_GRADE_CALCULATED,
            {
                'course_version': str(course.course_version),
                'percent_grade': 0.02,
                'grading_policy_hash': 'ChVp0lHGQGCevD0t4njna/C44zQ=',
                'user_id': user_id,
                'letter_grade': '',
                'event_transaction_id': transaction_id,
                'event_transaction_type': transaction_type,
                'course_id': course_id,
                'course_edited_timestamp': str(course.subtree_edited_on),
            },
        )
        events_tracker.emit.assert_has_calls(
            [problem_submitted, grade_calculated], any_order=True)
示例#10
0
    def test_save(self, mlflow_log_metric_fn):
        """Saving a KPI logs one MLflow metric per value, named "kpi1_<key>".

        The patched ``mlflow_log_metric_fn`` gets a no-op side effect. The
        expected calls are listed in sorted key order.
        """
        def _log_metric(param_name, param_value):
            return

        mlflow_log_metric_fn.side_effect = _log_metric

        saver = MlflowKPILogger().build()
        saver.save("kpi1", self._kpi_values)

        calls_must = []
        for metric_key, metric_value in sorted(
                self._kpi_values_logged_must.items()):
            calls_must.append(mock_call("kpi1_" + metric_key, metric_value))
        mlflow_log_metric_fn.assert_has_calls(calls_must)
示例#11
0
    def test_angles_from_grid(self):
        """Grid angles come from 'u_cube' and kwargs reach the angles routine.

        ``gridcell_angles`` is patched to return a fixed angles cube. The
        rotated outputs of the unit vector (u=1, v=0) can then be checked
        against known multiples of a 90-degree rotation.
        """
        # Check it gets angles from 'u_cube', and passes any kwargs on to
        # the angles routine.
        u_cube = sample_2d_latlons(regional=True, transformed=True)
        u_cube = u_cube[:2, :3]
        u_cube.units = "ms-1"
        u_cube.rename("dx")
        u_cube.data[...] = 1.0
        v_cube = u_cube.copy()
        # ``Cube.name()`` only returns a name (its argument is a default);
        # ``rename`` is what actually sets it.
        v_cube.rename("dy")
        v_cube.data[...] = 0.0

        # Setup a fake angles result from the inner call to 'gridcell_angles'.
        angles_result_data = np.array([[0.0, 90.0, 180.0],
                                       [-180.0, -90.0, 270.0]])
        angles_result_cube = Cube(angles_result_data, units="degrees")
        angles_kwargs = {"this": 2}
        angles_call_patch = self.patch(
            "iris.analysis._grid_angles.gridcell_angles",
            Mock(return_value=angles_result_cube),
        )

        # Call the routine.
        result = rotate_grid_vectors(u_cube,
                                     v_cube,
                                     grid_angles_kwargs=angles_kwargs)

        self.assertEqual(angles_call_patch.call_args_list,
                         [mock_call(u_cube, this=2)])

        out_u, out_v = [cube.data for cube in result]
        # Records what results should be for the various n*90deg rotations.
        expect_u = np.array([[1.0, 0.0, -1.0], [-1.0, 0.0, 0.0]])
        expect_v = np.array([[0.0, 1.0, 0.0], [0.0, -1.0, -1.0]])
        # Check results are as expected.
        self.assertArrayAllClose(out_u, expect_u)
        self.assertArrayAllClose(out_v, expect_v)
示例#12
0
    def test_rescoring_events(self):
        """Rescoring after a problem change emits rescore and grade events.

        The student first answers 'choice_choice_3'. The problem XML is then
        replaced so that choice 3 is correct, republished, and the student is
        rescored while the grades tracker is patched.
        """
        self.submit_question_answer('p1', {'2_1': 'choice_choice_3'})
        new_problem_xml = MultipleChoiceResponseXMLFactory().build_xml(
            question_text='The correct answer is Choice 3',
            choices=[False, False, False, True],
            choice_names=['choice_0', 'choice_1', 'choice_2', 'choice_3']
        )
        with self.store.branch_setting(ModuleStoreEnum.Branch.draft_preferred, self.course.id):
            self.problem.data = new_problem_xml
            self.store.update_item(self.problem, self.instructor.id)
        self.store.publish(self.problem.location, self.instructor.id)

        with patch('lms.djangoapps.grades.events.tracker') as events_tracker:
            submit_rescore_problem_for_student(
                request=get_mock_request(self.instructor),
                usage_key=self.problem.location,
                student=self.student,
                only_if_higher=False
            )
        course = self.store.get_course(self.course.id, depth=0)

        course_id = str(self.course.id)
        expected_context = {
            'course_id': course_id,
            'enterprise_uuid': '',
            'org_id': str(self.course.org)
        }
        # make sure the tracker's context is updated with course info
        for context_call in events_tracker.get_tracker().context.call_args_list:
            assert context_call[0][1] == expected_context

        first_emit_args = events_tracker.emit.mock_calls[0][1]
        transaction_id = first_emit_args[1]['event_transaction_id']
        rescore_type = events.GRADES_RESCORE_EVENT_TYPE
        user_id = str(self.student.id)

        rescored = mock_call(
            rescore_type,
            {
                'course_id': course_id,
                'user_id': user_id,
                'problem_id': str(self.problem.location),
                'new_weighted_earned': 2,
                'new_weighted_possible': 2,
                'only_if_higher': False,
                'instructor_id': str(self.instructor.id),
                'event_transaction_id': transaction_id,
                'event_transaction_type': rescore_type,
            },
        )
        grade_calculated = mock_call(
            events.COURSE_GRADE_CALCULATED,
            {
                'course_version': str(course.course_version),
                'percent_grade': 0.02,
                'grading_policy_hash': 'ChVp0lHGQGCevD0t4njna/C44zQ=',
                'user_id': user_id,
                'letter_grade': '',
                'event_transaction_id': transaction_id,
                'event_transaction_type': rescore_type,
                'course_id': course_id,
                'course_edited_timestamp': str(course.subtree_edited_on),
            },
        )
        events_tracker.emit.assert_has_calls(
            [rescored, grade_calculated], any_order=True)
示例#13
0
    def test_evaluate_on_sample(self, add_prefix_to_saver_name):
        """Evaluate a KPI plugin per sample, then again purely from its cache.

        First pass: a plugin computes every KPI, which is saved and cached.
        ``process``, ``save``, ``cache``, ``restore`` and ``calculate_hash``
        are wrapped to record their calls. Second pass: a fresh plugin that
        uses the same save/cache directories must reproduce the KPIs through
        ``restore``, without calling ``process`` or ``cache``, while still
        saving.
        """
        plugin_name = "DummyTpFpTnFnKPIPlugin"
        temp_dir = self.get_temp_dir()
        save_dir = os.path.join(temp_dir, "save")
        cache_dir = os.path.join(temp_dir, "cache")
        os.mkdir(save_dir)
        os.mkdir(cache_dir)

        def _build_plugin():
            # A fresh saver + cacher + plugin writing to the shared directories.
            new_saver = KPIJsonSaver().build()
            new_saver.add_prefix_to_name = add_prefix_to_saver_name
            new_cacher = KPIMD5Cacher().build()
            new_plugin = DummyTpFpTnFnKPIPlugin(cachers=[new_cacher],
                                                savers=[new_saver]).build()
            new_plugin.save_target = save_dir
            new_plugin.cache_target = cache_dir
            return new_plugin, new_saver, new_cacher

        def _saver_name(sample):
            if add_prefix_to_saver_name:
                return sample["prefix"] + "-" + plugin_name
            return plugin_name

        # ---- first pass: compute, save and cache ----
        kpi_plugin, saver, cacher = _build_plugin()

        self.assertEqual(saver.save_target, save_dir)
        self.assertEqual(cacher.cache_target, cache_dir)
        kpi_plugin.process = MagicMock(wraps=kpi_plugin.process)
        saver.save = MagicMock(wraps=saver.save)
        cacher.cache = MagicMock(wraps=cacher.cache)
        cacher.restore = MagicMock(wraps=cacher.restore)
        cacher.calculate_hash = MagicMock(wraps=cacher.calculate_hash)

        for sample_index, sample in enumerate(self.data):
            kpi = kpi_plugin.evaluate_on_sample(**sample)
            self.assertAllClose(self.kpi_must[sample_index], kpi)

        plugin_process_calls_must = [
            mock_call(labels=sample["labels"],
                      predictions=sample["predictions"])
            for sample in self.data
        ]
        saver_save_calls_must = [
            mock_call(name=_saver_name(sample), values=sample_kpi)
            for sample, sample_kpi in zip(self.data, self.kpi_must)
        ]
        cacher_cache_calls_must = [
            mock_call(sample_kpi) for sample_kpi in self.kpi_must
        ]
        # The hash inputs are the sample without its "prefix" entry.
        cacher_calculate_hash_calls_must = [
            mock_call(
                cache_prefix=sample["prefix"] + "-" + plugin_name,
                inputs={k: v for k, v in sample.items() if k != "prefix"})
            for sample in self.data
        ]

        kpi_plugin.process.assert_has_calls(plugin_process_calls_must)
        saver.save.assert_has_calls(saver_save_calls_must)
        cacher.cache.assert_has_calls(cacher_cache_calls_must)
        cacher.calculate_hash.assert_has_calls(
            cacher_calculate_hash_calls_must)
        self.assertEqual(len(self.data), cacher.restore.call_count)

        # ---- second pass: a new plugin must serve everything from cache ----
        kpi_plugin2, saver2, cacher2 = _build_plugin()
        kpi_plugin2.process = MagicMock(wraps=kpi_plugin2.process)
        saver2.save = MagicMock(wraps=saver2.save)
        cacher2.cache = MagicMock(wraps=cacher2.cache)
        cacher2.restore = MagicMock(wraps=cacher2.restore)

        for sample_index, sample in enumerate(self.data):
            kpi2 = kpi_plugin2.evaluate_on_sample(**sample)
            self.assertAllClose(self.kpi_must[sample_index], kpi2)
        kpi_plugin2.process.assert_not_called()
        saver2.save.assert_has_calls(saver_save_calls_must)
        cacher2.cache.assert_not_called()
        self.assertEqual(len(self.data), cacher2.restore.call_count)
示例#14
0
    def test_add_summaries(self, tf_global_norm, tf_norm,
                           tf_summary_scalar, add_histogram_summary,
                           add_summary_by_name,
                           with_losses, with_summaries, with_metrics,
                           with_grads_and_vars, with_reg_grads_and_vars):
        """``add_summaries`` emits scalar, histogram and named summaries.

        The TF summary and norm helpers are mocks. Each ``with_*`` flag
        enables one optional part of ``ModelResults`` together with the
        assertions that belong to it. The learning-rate scalar is expected in
        every case.
        """
        def _norm(inp):
            return "_".join(["norm", inp])

        def _global_norm(inputs):
            return "_".join(["norm", *inputs])

        tf_summary_scalar.return_value = None
        add_histogram_summary.return_value = None
        add_summary_by_name.return_value = None
        tf_norm.side_effect = _norm
        tf_global_norm.side_effect = _global_norm

        variables = [tf.Variable(0, name="var_{}".format(i)) for i in range(5)]
        losses = ({'total_loss': 'total_loss_value',
                   'loss1': {'sub_loss11': 'loss11_value'}}
                  if with_losses else None)
        summary = ({'summary1': {'subsummary11': "summary11_value",
                                 'subsummary12': "summary12_value"},
                    'summary2': {'subsummary21': "summary21_value"}}
                   if with_summaries else None)
        metrics = ({'metric1': {'accuracy11': "accuracy11_value",
                                'accuracy12': "accuracy12_value"},
                    'metric2': {'accuracy21': "accuracy21_value"}}
                   if with_metrics else None)
        grads_and_vars = ([('grad_{}'.format(i), var)
                           for i, var in enumerate(variables)]
                          if with_grads_and_vars else None)
        reg_grads_and_vars = ([('reg_grad_{}'.format(i), var)
                               for i, var in enumerate(variables[:3])]
                              if with_reg_grads_and_vars else None)

        model_results = ModelResults(
            inputs_preprocessed="not_used",
            predictions_raw="not_used",
            predictions="not_used",
            losses=losses,
            summary=summary,
            metrics=metrics,
            grads_and_vars=grads_and_vars,
            regularization_grads_and_vars=reg_grads_and_vars)

        model_handler = self._get_model_handler([])
        model_handler.build()
        model_handler.optimization_handler._global_learning_rate = (
            "global_learning_rate")
        model_handler.add_summaries(
            model_results, mode=tf.estimator.ModeKeys.TRAIN)

        # Histogram names use the variable name with ':' replaced by '_'.
        histograms_must = []
        if with_grads_and_vars:
            histograms_must.extend(
                ("gradient/{}".format(var.name).replace(':', '_'),
                 "grad_{}".format(i))
                for i, var in enumerate(variables))
        if with_reg_grads_and_vars:
            histograms_must.extend(
                ("reg_gradient/{}".format(var.name).replace(':', '_'),
                 "reg_grad_{}".format(i))
                for i, var in enumerate(variables[:3]))
        for hist_name, hist_value in histograms_must:
            add_histogram_summary.assert_has_calls(
                [mock_call(hist_name, hist_value)])
        if not (with_grads_and_vars or with_reg_grads_and_vars):
            add_histogram_summary.assert_not_called()

        scalars_must = []
        if with_losses:
            scalars_must.append(
                mock_call("total_loss", "total_loss_value", family="loss"))
            scalars_must.append(
                mock_call("loss1//sub_loss11", "loss11_value", family="loss"))
        scalars_must.append(mock_call("learning_rate", "global_learning_rate"))
        for each_scalar_call in scalars_must:
            tf_summary_scalar.assert_has_calls([each_scalar_call])

        max_outputs_tb = model_handler.max_outputs_tb
        named_summaries_must = []
        if with_metrics:
            named_summaries_must += [
                ("metric1//accuracy11", "accuracy11_value"),
                ("metric1//accuracy12", "accuracy12_value"),
                ("metric2//accuracy21", "accuracy21_value"),
            ]
        if with_summaries:
            named_summaries_must += [
                ("summary1//subsummary11", "summary11_value"),
                ("summary1//subsummary12", "summary12_value"),
                ("summary2//subsummary21", "summary21_value"),
            ]
        for summary_name, summary_value in named_summaries_must:
            add_summary_by_name.assert_has_calls(
                [mock_call(summary_name, summary_value, max_outputs_tb)])
        if not (with_metrics or with_summaries):
            add_summary_by_name.assert_not_called()
示例#15
0
    def test_call(self, plugin_is_last_sample, sample_mask, is_last_iteration,
                  is_last_sample_must):
        """Call a KPI plugin on a (possibly masked) batch and check results.

        ``plugin_is_last_sample`` is a mock whose side effect returns its
        argument unchanged. ``sample_mask``, ``is_last_iteration`` and
        ``is_last_sample_must`` presumably come from test parameterization
        outside this view (TODO confirm). Expected per-sample calls and KPIs
        are built only for the samples kept by ``sample_mask`` (all samples
        when it is None). When every sample is masked out, the plugin is
        expected to return None.
        """
        temp_dir = self.get_temp_dir()
        os.mkdir(os.path.join(temp_dir, "save"))
        os.mkdir(os.path.join(temp_dir, "cache"))

        # identity side effect: the mock returns whatever it is called with
        plugin_is_last_sample.side_effect = lambda x: x
        saver = KPIJsonSaver().build()
        cacher = KPIMD5Cacher().build()
        kpi_plugin = DummyTpFpTnFnKPIPlugin(cachers=[cacher],
                                            savers=[saver]).build()

        kpi_plugin.save_target = os.path.join(temp_dir, "save")
        kpi_plugin.cache_target = os.path.join(temp_dir, "cache")
        # keep the real implementation but record the per-sample calls
        kpi_plugin.evaluate_on_sample = MagicMock(
            wraps=kpi_plugin.evaluate_on_sample)

        # combine the per-sample mappings in self.data with np.stack into one
        # batch mapping that is passed to the plugin as keyword arguments
        data_batch = nest_utils.combine_nested(self.data, combine_fun=np.stack)
        kpi_plugin.is_last_iteration = is_last_iteration

        # expected KPIs only for the samples kept by the mask
        kpi_must_list = []
        for i_sample, each_kpi_must in enumerate(self.kpi_must):
            if sample_mask is None or sample_mask[i_sample]:
                kpi_must_list.append(each_kpi_must)

        if kpi_must_list:
            kpi_must = nest_utils.combine_nested(kpi_must_list,
                                                 combine_fun=np.array)
        else:
            # every sample masked out: the plugin must return None (see below)
            kpi_must = None

        is_last_sample_calls_must = [mock_call(i) for i in is_last_sample_must]
        if sample_mask is None:
            evaluate_on_sample_args_must = [
                mock_call(**each_sample_data) for each_sample_data in self.data
            ]
        else:
            # masked-out samples must not reach evaluate_on_sample
            evaluate_on_sample_args_must = [
                mock_call(**each_sample_data)
                for i, each_sample_data in enumerate(self.data)
                if sample_mask[i]
            ]

        kpi = kpi_plugin(sample_mask=sample_mask, **data_batch)

        plugin_is_last_sample.assert_has_calls(is_last_sample_calls_must)
        kpi_plugin.evaluate_on_sample.assert_has_calls(
            evaluate_on_sample_args_must)

        if kpi_must is None:
            self.assertIsNone(kpi)
            return

        if sample_mask is None:
            self.assertAllClose(kpi_must, kpi)
        else:
            # NOTE(review): ``kpi_must`` only holds the kept samples, so ``i``
            # below is a position among kept samples, yet it is also used to
            # index ``sample_mask`` (a position among all samples). Both
            # branches compare the same pair and differ only in tolerance
            # (close vs. exact). Confirm how ``kpi`` is laid out when a mask
            # is given and align the indexing accordingly.
            for i in range(sum(sample_mask)):
                sample_kpi_must = {k: v[i] for k, v in kpi_must.items()}
                sample_kpi = {k: v[i] for k, v in kpi.items()}
                if sample_mask[i]:
                    self.assertAllClose(sample_kpi_must, sample_kpi)
                else:
                    self.assertAllEqual(sample_kpi_must, sample_kpi)
示例#16
0
    def test_evaluate_on_sample(self):
        """Feed samples to a KPI accumulator and check the flushed results.

        ``evaluate_on_sample`` is expected to return None unless the sample's
        evaluate flag is set or it is the last sample. In those cases it must
        return the expected accumulated KPI, which is also saved (with the
        sample-range prefix) and cached. Afterwards the buffer must be empty.
        """
        temp_dir = self.get_temp_dir()
        save_dir = os.path.join(temp_dir, "save")
        cache_dir = os.path.join(temp_dir, "cache")
        for each_dir in (save_dir, cache_dir):
            os.mkdir(each_dir)

        saver = KPIJsonSaver(add_prefix_to_name=True).build()
        cacher = KPIMD5Cacher().build()
        kpi_accumulator = DummyF1KPIAccumulator(cachers=[cacher],
                                                savers=[saver]).build()

        kpi_accumulator.save_target = save_dir
        kpi_accumulator.cache_target = cache_dir

        # Wrap so that real behavior is kept while calls are recorded.
        kpi_accumulator.process = MagicMock(wraps=kpi_accumulator.process)
        kpi_accumulator.buffer_processor.buffer.add = MagicMock(
            wraps=kpi_accumulator.buffer_processor.buffer.add)
        kpi_accumulator.clear_state = MagicMock(
            wraps=kpi_accumulator.clear_state)

        saver.save = MagicMock(wraps=saver.save)
        cacher.cache = MagicMock(wraps=cacher.cache)

        for i_sample, (evaluate_flag, sample_inputs) in enumerate(
                zip(self.evaluate_flag, self.inputs)):
            is_last_sample = i_sample == len(self.inputs) - 1
            if is_last_sample:
                kpi_accumulator.is_last_sample = True
            kpi = kpi_accumulator.evaluate_on_sample(
                evaluate=evaluate_flag, **sample_inputs)
            if evaluate_flag or is_last_sample:
                self.assertAllClose(self.kpis_must[i_sample], kpi)
            else:
                self.assertIsNone(kpi)

        self.assertAllClose(dict(self.kpis_must[-1]),
                            kpi_accumulator.last_kpi)

        kpi_process_calls_must = [
            mock_call(true_positives=[0, 1, 0, 1],
                      false_positives=[0, 0, 1, 0],
                      false_negatives=[1, 0, 0, 0]),
            mock_call(true_positives=[0, 1, 0],
                      false_positives=[0, 0, 0],
                      false_negatives=[0, 0, 0]),
        ]
        kpi_accumulate_calls_must = [
            mock_call(**sample_inputs) for sample_inputs in self.inputs
        ]
        # (saved name, first sample index, end sample index) per flush
        flushes = (
            ("sample1-sample4-DummyF1KPIAccumulator", 0, 4),
            ("sample5-sample7-DummyF1KPIAccumulator", 4, 7),
        )
        saver_save_calls_must = [
            mock_call(name=saved_name, values=self._get_kpi_must(start, end))
            for saved_name, start, end in flushes
        ]
        cacher_cache_calls_must = [
            mock_call(self._get_kpi_must(start, end))
            for _, start, end in flushes
        ]

        kpi_accumulator.process.assert_has_calls(kpi_process_calls_must)
        kpi_accumulator.buffer_processor.buffer.add.assert_has_calls(
            kpi_accumulate_calls_must)
        saver.save.assert_has_calls(saver_save_calls_must)
        cacher.cache.assert_has_calls(cacher_cache_calls_must)

        self.assertTrue(kpi_accumulator.buffer.is_empty())
示例#17
0
    def test_init(self,
                  add_constructor_parameters_to_log_fn,
                  register_to_name_scope_fn,
                  use_child,
                  register_name=None,
                  register_name_scope=None,
                  log_name_scope=None,
                  exclude_args_from_log=None,
                  exclude_from_register=False,
                  exclude_from_log=False):
        """Verify class attributes and registration produced by the metaclass.

        Declares a class through ``self._declare_cls``. When ``use_child`` is
        set, the class under test is a plain subclass of it. The test checks:

        * the ``_register_name``, ``_exclude_from_register``,
          ``_exclude_from_log``, name-scope and ``_exclude_args_from_log``
          attributes;
        * the expected calls to ``register_to_name_scope_fn``.

        ``add_constructor_parameters_to_log_fn`` and
        ``register_to_name_scope_fn`` are presumably mocks injected by patch
        decorators that are not visible here (TODO confirm).
        ``add_constructor_parameters_to_log_fn`` must not be called, since
        no instance is constructed.
        """
        ClsWithMeta = self._declare_cls(
            register_name_=register_name,
            register_name_scope_=register_name_scope,
            log_name_scope_=log_name_scope,
            exclude_from_register_=exclude_from_register,
            exclude_from_log_=exclude_from_log,
            exclude_args_from_log_=exclude_args_from_log)
        cls = ClsWithMeta
        cls_parent = None
        if use_child:
            cls_parent = cls

            class ChildClsWithMeta(ClsWithMeta):
                pass

            cls = ChildClsWithMeta

        if use_child:
            # The subclass is expected to use its own class name as register
            # name and to be excluded neither from registration nor logging.
            self.assertEqual(cls.__name__, cls._register_name)
            self.assertFalse(cls._exclude_from_register)
            self.assertFalse(cls._exclude_from_log)
        else:
            self.assertEqual(register_name, cls._register_name)
            self.assertEqual(exclude_from_register or False,
                             cls._exclude_from_register)
            if exclude_from_log is None:
                # Without an explicit value, exclusion from logging is
                # expected to follow exclusion from registration.
                self.assertEqual(exclude_from_register or False,
                                 cls._exclude_from_log)
            else:
                self.assertEqual(exclude_from_log, cls._exclude_from_log)

        if register_name_scope is not None:
            self.assertEqual(register_name_scope, cls._register_name_scope)
        else:
            self.assertEqual(cls.__name__, cls._register_name_scope)

        # The log name scope falls back to the register name scope, and
        # then to the class name.
        if log_name_scope is not None:
            self.assertEqual(log_name_scope, cls._log_name_scope)
        elif register_name_scope is not None:
            self.assertEqual(register_name_scope, cls._log_name_scope)
        else:
            self.assertEqual(cls.__name__, cls._log_name_scope)

        if exclude_args_from_log is None:
            self.assertEmpty(cls._exclude_args_from_log)
        else:
            self.assertListEqual(exclude_args_from_log,
                                 cls._exclude_args_from_log)

        register_to_name_scope_fn_num_calls = 0
        if use_child or not exclude_from_register:
            register_to_name_scope_fn.assert_has_calls([
                mock_call(cls._register_name_scope,
                          cls,
                          name=cls._register_name)
            ])
            register_to_name_scope_fn_num_calls += 1

        if use_child and not exclude_from_register:
            # The non-excluded parent class is expected to have been
            # registered as well.
            register_to_name_scope_fn.assert_has_calls([
                mock_call(cls_parent._register_name_scope,
                          cls_parent,
                          name=cls_parent._register_name)
            ])
            register_to_name_scope_fn_num_calls += 1

        self.assertEqual(register_to_name_scope_fn_num_calls,
                         register_to_name_scope_fn.call_count)
        add_constructor_parameters_to_log_fn.assert_not_called()
示例#18
0
    def test_parse_tfrecord_example(self, tf_decode_raw,
                                    tf_parse_single_example):
        """Verify ``TfRecordsMixin.parse_tfrecord_example`` with stubbed TF ops.

        The mocks are replaced by string-based fakes:

        * ``tf_parse_single_example`` joins each value with its feature name;
        * ``tf_decode_raw`` appends ``"_raw"``;
        * ``postprocess_tfrecords`` appends ``"_pp"``.

        The test then checks the parsed result, the calls to
        ``decode_field`` and the input to ``postprocess_tfrecords``.
        """
        def _get_tfrecords_features():
            return {
                "data1": "feature_1",
                "data2": "feature_data2",
                "data3": ["feature_data3_0", "feature_data3_1"],
                "data4": {
                    "sub1": "feature_data4_sub1",
                    "sub2": "feature_data4_sub2"
                }
            }

        def _get_tfrecords_output_types():
            return {"data1": "string_value", "data4/sub1": "float_value"}

        def _parse_single_example(example, features):
            flat_example = nest_utils.flatten_nested_struct(example, "/")
            parsed = {}
            for key in flat_example:
                parsed[key] = "-".join([str(flat_example[key]),
                                        features[key]])
            return parsed

        def _postprocess_tfrecords(**data):
            flat_data = nest_utils.flatten_nested_struct(data)
            suffixed = {key: value + "_pp"
                        for key, value in flat_data.items()}
            return nest_utils.unflatten_dict_to_nested(suffixed)

        def _combine_before_decode(value_and_feature):
            return "-".join([str(value_and_feature[0]), value_and_feature[1]])

        tf_decode_raw.side_effect = lambda x, y: x + "_raw"
        tf_parse_single_example.side_effect = _parse_single_example

        mixin = tf_data_utils.TfRecordsMixin()
        mixin.get_tfrecords_features = MagicMock(wraps=_get_tfrecords_features)
        mixin.get_tfrecords_output_types = MagicMock(
            wraps=_get_tfrecords_output_types)
        mixin.postprocess_tfrecords = MagicMock(wraps=_postprocess_tfrecords)
        mixin.decode_field = MagicMock(wraps=mixin.decode_field)
        result = mixin.parse_tfrecord_example(self.data)

        # Build expectations from the unwrapped helpers so the mocks' call
        # counts stay untouched.
        features = _get_tfrecords_features()
        features_flat = nest_utils.flatten_nested_struct(features, "/")
        data_flat = nest_utils.flatten_nested_struct(self.data, "/")
        output_types_flat = nest_utils.flatten_nested_struct(
            _get_tfrecords_output_types(), "/")

        def _parsed_value(key):
            return "-".join([str(data_flat[key]), features_flat[key]])

        def _decode_suffix(key):
            # Only keys with an output type go through tf_decode_raw.
            return "_raw" if key in output_types_flat else ""

        expected_result = nest_utils.unflatten_dict_to_nested(
            {key: _parsed_value(key) + _decode_suffix(key) + "_pp"
             for key in features_flat}, "/")
        self.assertAllEqual(expected_result, result)

        mixin.get_tfrecords_features.assert_called_once_with()
        mixin.get_tfrecords_output_types.assert_called_once_with()

        decode_values = nest_utils.flatten_nested_struct(
            nest_utils.combine_nested([self.data, features],
                                      combine_fun=_combine_before_decode),
            "/")
        expected_decode_calls = [
            mock_call(key, value, output_types_flat.get(key))
            for key, value in decode_values.items()
        ]
        mixin.decode_field.assert_has_calls(expected_decode_calls,
                                            any_order=True)

        expected_postprocess_input = nest_utils.unflatten_dict_to_nested(
            {key: _parsed_value(key) + _decode_suffix(key)
             for key in features_flat}, "/")
        mixin.postprocess_tfrecords.assert_called_once_with(
            **expected_postprocess_input)
示例#19
0
    def test_create_configs_with_grads_and_vars(self, use_only_global_config,
                                                use_regularization):
        """Verify how gradients and variables are distributed over configs.

        ``filter_grads_and_vars_with_decouple_for_config`` is replaced by a
        mock whose side effect delegates to
        ``OptimizationHandler.filter_grads_and_vars_with_decouple_for_config``.
        Everything it returns is collected in ``res_must``.

        The test checks:

        * that the filter was called for the global config and, unless
          ``use_only_global_config`` is set, for both local configs;
        * the number of resulting entries;
        * that the result equals the concatenated filter outputs.
        """
        res_must = []

        def _filter_grads_and_vars_with_decouple_for_config(
                optim_config, vars_for_config, grads_and_vars,
                regularization_grads_and_vars):
            config_with_vars = (OptimizationHandler.
                                filter_grads_and_vars_with_decouple_for_config(
                                    optim_config, vars_for_config,
                                    grads_and_vars,
                                    regularization_grads_and_vars))
            res_must.extend(config_with_vars)
            return config_with_vars

        self.optimization.filter_grads_and_vars_with_decouple_for_config = (
            MagicMock(
                side_effect=_filter_grads_and_vars_with_decouple_for_config))

        (vars_tf, grads_and_vars_tf,
         reg_grads_and_vars_tf) = self._get_grads_and_vars_tf()
        vars_for_config1 = vars_tf[2:4]
        vars_for_config2 = vars_tf[4:]
        self.optimization.global_config = self.global_optim_config
        if not use_only_global_config:
            self.optimization.add_config_with_variables(
                (self.local_optim_config1, vars_for_config1))
            self.optimization.add_config_with_variables(
                (self.local_optim_config2, vars_for_config2))
        self.optimization.initialize_for_session()
        # A conditional expression, not `cond and a or None`, so that a falsy
        # (e.g. empty) regularization list is not silently turned into None.
        regularization_grads_and_vars = (
            reg_grads_and_vars_tf if use_regularization else None)
        optim_configs_with_variables = (
            self.optimization.create_configs_with_grads_and_vars(
                grads_and_vars=grads_and_vars_tf,
                regularization_grads_and_vars=regularization_grads_and_vars,
                all_trainable_variables=vars_tf))

        # The global config gets every variable not claimed by a local config.
        if use_only_global_config:
            vars_for_global_config = set(vars_tf)
        else:
            vars_for_global_config = set(vars_tf[:2])
        filter_calls_must = [
            mock_call(self.optimization.global_config, vars_for_global_config,
                      grads_and_vars_tf, regularization_grads_and_vars)
        ]
        if not use_only_global_config:
            filter_calls_must += [
                mock_call(self.optimization._local_configs_with_vars[0][0],
                          vars_for_config1, grads_and_vars_tf,
                          regularization_grads_and_vars),
                mock_call(self.optimization._local_configs_with_vars[1][0],
                          vars_for_config2, grads_and_vars_tf,
                          regularization_grads_and_vars)
            ]
        (self.optimization.filter_grads_and_vars_with_decouple_for_config
         .assert_has_calls(filter_calls_must))

        configs_in_use = [self.global_optim_config]
        if not use_only_global_config:
            configs_in_use += [self.local_optim_config1,
                               self.local_optim_config2]
        optim_configs_with_vars_len_must = len(configs_in_use)
        if use_regularization:
            # One additional entry is expected per config whose
            # decouple_regularization is exactly True.
            optim_configs_with_vars_len_must += sum(
                config.decouple_regularization is True
                for config in configs_in_use)
        self.assertEqual(optim_configs_with_vars_len_must,
                         len(optim_configs_with_variables))
        self.assertEqual(res_must, optim_configs_with_variables)