Пример #1
0
    def test_value_manipulation(self):
        """Compare variable handling between the KTH and KTF backends.

        Covers ``get_value``, ``set_value``, ``count_params``,
        ``print_tensor`` and ``get_variable_shape``.
        """

        def assert_same_value(th_var, tf_var):
            # Values read back from each backend must agree in shape and,
            # within tolerance, in content.
            th_value = KTH.get_value(th_var)
            tf_value = KTF.get_value(tf_var)
            assert tf_value.shape == th_value.shape
            assert_allclose(th_value, tf_value, atol=1e-05)

        initial = np.random.random((4, 2))
        th_var = KTH.variable(initial)
        tf_var = KTF.variable(initial)

        # get_value: freshly created variables
        assert_same_value(th_var, tf_var)

        # set_value: push new data into both variables, then read it back
        replacement = np.random.random((4, 2))
        KTH.set_value(th_var, replacement)
        KTF.set_value(tf_var, replacement)
        assert_same_value(th_var, tf_var)

        # count_params
        assert KTH.count_params(th_var) == KTF.count_params(tf_var)

        # print_tensor, once per tuple argument
        for shape in ((), (2,), (4, 3), (1, 2, 3)):
            check_single_tensor_operation('print_tensor', shape)

        # get_variable_shape
        shape_source = np.random.random((3, 2))
        th_shaped = KTH.variable(shape_source)
        tf_shaped = KTF.variable(shape_source)
        assert KTH.get_variable_shape(th_shaped) == KTF.get_variable_shape(tf_shaped)
Пример #2
0
    def test_value_manipulation(self):
        """Exercise get/set/count/print/shape helpers on KTH and KTF and
        assert that the two backends give matching results."""
        data = np.random.random((4, 2))
        var_th = KTH.variable(data)
        var_tf = KTF.variable(data)

        # get_value
        got_th = KTH.get_value(var_th)
        got_tf = KTF.get_value(var_tf)
        assert got_tf.shape == got_th.shape
        assert_allclose(got_th, got_tf, atol=1e-05)

        # set_value -- same new array is written to both variables
        data = np.random.random((4, 2))
        for backend, variable in ((KTH, var_th), (KTF, var_tf)):
            backend.set_value(variable, data)

        got_th = KTH.get_value(var_th)
        got_tf = KTF.get_value(var_tf)
        assert got_tf.shape == got_th.shape
        assert_allclose(got_th, got_tf, atol=1e-05)

        # count_params
        n_th = KTH.count_params(var_th)
        n_tf = KTF.count_params(var_tf)
        assert n_th == n_tf

        # print_tensor
        print_tensor_args = [(), (2, ), (4, 3), (1, 2, 3)]
        for arg in print_tensor_args:
            check_single_tensor_operation('print_tensor', arg)

        # get_variable_shape on a second pair of variables
        data = np.random.random((3, 2))
        var_th = KTH.variable(data)
        var_tf = KTF.variable(data)
        shape_th = KTH.get_variable_shape(var_th)
        shape_tf = KTF.get_variable_shape(var_tf)
        assert shape_th == shape_tf
Пример #3
0
    def reset_states(self):
        """Zero both state variables, creating them on first use.

        Raises:
            AssertionError: if ``self.stateful`` is falsy.
        """
        assert self.stateful, 'Layer must be stateful.'

        state_shape = (self.batch_size, self.output_dim)
        if not hasattr(self, 'states'):
            # No states yet: allocate two zero-filled backend variables.
            self.states = [K.zeros(state_shape) for _ in range(2)]
            return

        # States already exist: overwrite the first two in place.
        for idx in (0, 1):
            K.set_value(self.states[idx], np.zeros(state_shape))
Пример #4
0
    def test_value_manipulation(self):
        """Check that KTH and KTF agree on ``get_value``, ``set_value`` and
        ``count_params`` for the same variable contents."""

        def read_both():
            # Fetch the current value of the paired variables from each backend.
            return KTH.get_value(th_var), KTF.get_value(tf_var)

        seed_values = np.random.random((4, 2))
        th_var = KTH.variable(seed_values)
        tf_var = KTF.variable(seed_values)

        # get_value
        th_out, tf_out = read_both()
        assert tf_out.shape == th_out.shape
        assert_allclose(th_out, tf_out, atol=1e-05)

        # set_value
        new_values = np.random.random((4, 2))
        KTH.set_value(th_var, new_values)
        KTF.set_value(tf_var, new_values)

        th_out, tf_out = read_both()
        assert tf_out.shape == th_out.shape
        assert_allclose(th_out, tf_out, atol=1e-05)

        # count_params
        assert KTH.count_params(th_var) == KTF.count_params(tf_var)
Пример #5
0
    def test_value_manipulation(self):
        """Round-trip values through KTH and KTF variables and verify that
        both backends report the same contents and parameter count."""
        payload = np.random.random((4, 2))
        v_th = KTH.variable(payload)
        v_tf = KTF.variable(payload)

        # get_value: initial contents
        from_th = KTH.get_value(v_th)
        from_tf = KTF.get_value(v_tf)
        assert from_tf.shape == from_th.shape
        assert_allclose(from_th, from_tf, atol=1e-05)

        # set_value: overwrite with a fresh random array
        payload = np.random.random((4, 2))
        for backend, variable in ((KTH, v_th), (KTF, v_tf)):
            backend.set_value(variable, payload)

        from_th = KTH.get_value(v_th)
        from_tf = KTF.get_value(v_tf)
        assert from_tf.shape == from_th.shape
        assert_allclose(from_th, from_tf, atol=1e-05)

        # count_params
        th_count = KTH.count_params(v_th)
        tf_count = KTF.count_params(v_tf)
        assert th_count == tf_count
Пример #6
0
def shuffle_states(graph_model, indices):
    """Reindex the state values of every stateful layer in ``graph_model``.

    For each layer whose ``stateful`` attribute is truthy (layers lacking the
    attribute are skipped), every variable in ``layer.states`` is overwritten
    with its current value indexed by ``indices``.

    Args:
        graph_model: object exposing a ``layers`` iterable.
        indices: index expression applied to each state's value
            (presumably a permutation along the first axis -- confirm
            against callers).
    """
    for layer in graph_model.layers:
        if not getattr(layer, 'stateful', False):
            continue
        for state in layer.states:
            # Read through the backend API instead of the variable's own
            # ``get_value`` method, so the read mirrors the ``K.set_value``
            # write and does not depend on the variable type.
            K.set_value(state, K.get_value(state)[indices])
Пример #7
0
 def set_state(self, noise):
     """Write ``noise`` into the first entry of ``self.states``."""
     assign = K.set_value
     assign(self.states[0], noise)
Пример #8
0
def shuffle_states(graph_model, indices):
    """Apply ``indices`` to the stored state of each stateful layer.

    Layers are taken from ``graph_model.layers``; a layer is processed only
    when it has a truthy ``stateful`` attribute. Each of its ``states`` is
    replaced by its own current value indexed with ``indices``.

    Args:
        graph_model: object exposing a ``layers`` iterable.
        indices: index expression applied to each state's value
            (likely a reordering of the leading axis -- TODO confirm).
    """
    stateful_layers = (
        layer for layer in graph_model.layers
        if getattr(layer, 'stateful', False)
    )
    for layer in stateful_layers:
        for state in layer.states:
            # Use the backend getter so reading and writing both go through
            # ``K`` rather than relying on a ``get_value`` method existing on
            # the variable object itself.
            K.set_value(state, K.get_value(state)[indices])
Пример #9
0
 def set_state(self, noise):
     """Write ``noise`` into the second entry of ``self.states``."""
     assign = K.set_value
     assign(self.states[1], noise)