예제 #1
0
    def test_state_with_id_handler(self):
        """Check extract/apply of cell states by catchment id and a bytes round trip.

        Steps exercised:
          1. Extract states for all catchments, for catchment 1 and for
             catchment 2, and check the two subsets add up to the full set.
          2. Overwrite ``kirchner.q`` on every state, apply it to the model and
             read it back.
          3. Overwrite and apply only the catchment-2 states, and check the
             catchment-1 values are left as they were.
          4. Serialize the catchment-2 states to bytes, pass them through a
             temporary file, deserialize and compare.
        """
        # NOTE(review): a second definition with this same name follows later
        # in this file; if both sit in the same class the later one replaces
        # this one at class-creation time -- confirm which is intended.
        cell_count = 20
        model = self.build_model(pt_gs_k.PTGSKModel, pt_gs_k.PTGSKParameter, cell_count, 2)
        all_cids = api.IntVector()  # empty id list: used below to address all states
        only_cid_1 = api.IntVector([1])
        only_cid_2 = api.IntVector([2])

        states_all = model.state.extract_state(all_cids)  # this is how to get all states from model
        states_cid_1 = model.state.extract_state(only_cid_1)  # this is how to get only specified states from model
        states_cid_2 = model.state.extract_state(only_cid_2)
        self.assertEqual(len(states_cid_1) + len(states_cid_2), len(states_all))
        self.assertGreater(len(states_cid_1), 0)
        self.assertGreater(len(states_cid_2), 0)
        for cell in states_cid_1:  # verify selective extract catchment 1
            self.assertEqual(cell.id.cid, 1)
        for cell in states_cid_2:  # verify selective extract catchment 2
            self.assertEqual(cell.id.cid, 2)

        # Writes deliberately go through vec[i] so the element stored in the
        # vector is the one being modified.
        for i in range(len(states_all)):
            states_all[i].state.kirchner.q = 100 + i
        model.state.apply_state(states_all, all_cids)  # this is how to put all states into model
        readback_all = model.state.extract_state(all_cids)
        for i, cell in enumerate(readback_all):
            self.assertAlmostEqual(cell.state.kirchner.q, 100 + i)

        for i in range(len(states_cid_2)):
            states_cid_2[i].state.kirchner.q = 200 + i
        unapplied = model.state.apply_state(states_cid_2, only_cid_2)  # this is how to put a limited set of state into model
        self.assertEqual(len(unapplied), 0)
        readback_all = model.state.extract_state(all_cids)
        for i, cell in enumerate(readback_all):
            if cell.id.cid == 1:  # catchment 1 must still hold the values from the first apply
                self.assertAlmostEqual(cell.state.kirchner.q, 100 + i)

        readback_cid_2 = model.state.extract_state(only_cid_2)
        for i, cell in enumerate(readback_cid_2):
            self.assertAlmostEqual(cell.state.kirchner.q, 200 + i)

        # serialization support, to and from bytes
        # (local is named `blob`, not `bytes`, to avoid shadowing the builtin)
        blob = readback_cid_2.serialize_to_bytes()  # first make some bytes out of the state
        with tempfile.TemporaryDirectory() as tmp_dir:
            blob_file = str(path.join(tmp_dir, "pt_gs_k_state_test.bin"))
            api.byte_vector_to_file(blob_file, blob)  # stash it into a file
            blob = api.byte_vector_from_file(blob_file)  # get it back from the file and into ByteVector
        restored = pt_gs_k.deserialize_from_bytes(blob)  # then restore it from bytes to a StateWithIdVector

        self.assertIsNotNone(restored)
        for i, cell in enumerate(restored):
            self.assertAlmostEqual(cell.state.kirchner.q, 200 + i)
    def test_state_with_id_handler(self):
        """Check extract/apply of cell states by catchment id and their serialization.

        Steps exercised:
          1. Extract states for all catchments, for catchment 1 and for
             catchment 2, and check the two subsets add up to the full set.
          2. Round-trip the catchment-2 states through their string form.
          3. Overwrite ``kirchner.q`` on every state, apply it to the model and
             read it back.
          4. Overwrite and apply only the catchment-2 states, and check the
             catchment-1 values are left as they were.
          5. Serialize the catchment-2 states to bytes, pass them through a
             temporary file, deserialize and compare.
          6. Read ``state_vector`` from the state-with-id vector, then apply the
             states and assign ``model.current_state`` to ``model.initial_state``.
        """
        cell_count = 20
        model = self.build_model(pt_gs_k.PTGSKModel, pt_gs_k.PTGSKParameter, cell_count, 2)
        all_cids = api.IntVector()  # empty id list: used below to address all states
        only_cid_1 = api.IntVector([1])
        only_cid_2 = api.IntVector([2])

        # this is how to get all states from model
        states_all = model.state.extract_state(all_cids)
        # this is how to get only specified states from model
        states_cid_1 = model.state.extract_state(only_cid_1)
        states_cid_2 = model.state.extract_state(only_cid_2)
        self.assertEqual(len(states_cid_1) + len(states_cid_2), len(states_all))

        # We need to store state into yaml text files; for this it is nice to
        # have it as a string, so verify the string serialization round trip.
        as_text = states_cid_2.serialize_to_str()
        from_text = pt_gs_k.PTGSKStateWithIdVector.deserialize_from_str(as_text)
        self.assertEqual(len(from_text), len(states_cid_2))
        for restored_cell, original_cell in zip(from_text, states_cid_2):
            self.assertEqual(restored_cell.id, original_cell.id)
            self.assertAlmostEqual(restored_cell.state.kirchner.q, original_cell.state.kirchner.q)

        self.assertGreater(len(states_cid_1), 0)
        self.assertGreater(len(states_cid_2), 0)
        for cell in states_cid_1:  # verify selective extract catchment 1
            self.assertEqual(cell.id.cid, 1)
        for cell in states_cid_2:  # verify selective extract catchment 2
            self.assertEqual(cell.id.cid, 2)

        # Writes deliberately go through vec[i] so the element stored in the
        # vector is the one being modified.
        for i in range(len(states_all)):
            states_all[i].state.kirchner.q = 100 + i
        # this is how to put all states into model
        model.state.apply_state(states_all, all_cids)
        readback_all = model.state.extract_state(all_cids)
        for i, cell in enumerate(readback_all):
            self.assertAlmostEqual(cell.state.kirchner.q, 100 + i)

        for i in range(len(states_cid_2)):
            states_cid_2[i].state.kirchner.q = 200 + i
        # this is how to put a limited set of state into model
        unapplied = model.state.apply_state(states_cid_2, only_cid_2)
        self.assertEqual(len(unapplied), 0)
        readback_all = model.state.extract_state(all_cids)
        for i, cell in enumerate(readback_all):
            if cell.id.cid == 1:  # catchment 1 must still hold the values from the first apply
                self.assertAlmostEqual(cell.state.kirchner.q, 100 + i)

        readback_cid_2 = model.state.extract_state(only_cid_2)
        for i, cell in enumerate(readback_cid_2):
            self.assertAlmostEqual(cell.state.kirchner.q, 200 + i)

        # feature test: serialization support, to and from bytes
        # (local is named `blob`, not `bytes`, to avoid shadowing the builtin)
        blob = readback_cid_2.serialize_to_bytes()  # first make some bytes out of the state
        with tempfile.TemporaryDirectory() as tmp_dir:
            blob_file = str(path.join(tmp_dir, "pt_gs_k_state_test.bin"))
            api.byte_vector_to_file(blob_file, blob)  # stash it into a file
            # get it back from the file and into ByteVector
            blob = api.byte_vector_from_file(blob_file)
        # then restore it from bytes to a StateWithIdVector
        restored = pt_gs_k.deserialize_from_bytes(blob)

        self.assertIsNotNone(restored)
        for i, cell in enumerate(restored):
            self.assertAlmostEqual(cell.state.kirchner.q, 200 + i)

        # feature test: given a state-with-id-vector, get the pure state-vector
        # suitable for rm.initial_state = <state_vector>.
        # Note however that this is 'unsafe': you need to ensure that
        # size/ordering is ok -- that is the purpose of the cell-state-with-id.
        # A better solution could be to use
        #     rm.state.apply(state_with_id)  .. and maybe check the result,
        #                                       number of states == expected applied
        #     rm.initial_state = rm.current_state  .. a new property to ease typical tasks
        plain_states = readback_cid_2.state_vector
        self.assertEqual(len(plain_states), len(readback_cid_2))
        for plain, with_id in zip(plain_states, readback_cid_2):
            self.assertAlmostEqual(plain.kirchner.q, with_id.state.kirchner.q)
        # example apply, then initial state:
        model.state.apply_state(readback_cid_2, all_cids)
        model.initial_state = model.current_state