Exemplo n.º 1
0
    def init_H_hat(self, V):
        """Return the initial value used for the hidden-unit estimate ``H_hat``.

        Parameters
        ----------
        V : symbolic variable
            The visible batch. Only its leading dimension is used: ``V.shape[0]``
            to size the fresh allocation, and (in test-value/debug mode) the
            leading dimension of its test value for sanity checks.

        Returns
        -------
        rval
            If ``self.model.recycle_q`` is true, ``self.model.prev_H`` is
            returned as-is (presumably a shared variable, since
            ``get_value()`` is called on it -- TODO confirm). Otherwise a new
            symbolic tensor ``T.alloc(1., V.shape[0], self.model.nhid)``
            filled with ones.

        Raises
        ------
        Exception
            When recycling ``prev_H`` with ``config.compute_test_value``
            enabled and its leading dimension differs from that of
            ``V``'s test value.
        """
        if self.model.recycle_q:
            rval = self.model.prev_H
            if config.compute_test_value != 'off':
                # Fetch the stored value once; get_value() may copy it.
                prev_shape = rval.get_value().shape
                if prev_shape[0] != V.tag.test_value.shape[0]:
                    raise Exception('E step given wrong test batch size',
                                    prev_shape, V.tag.test_value.shape)
        else:
            rval = T.alloc(1., V.shape[0], self.model.nhid)

            # Compare the leading dimension of each (rval, V) test-value pair.
            for rval_value, V_value in get_debug_values(rval, V):
                if rval_value.shape[0] != V_value.shape[0]:
                    # debug_error_message accepts a single message argument,
                    # so the shapes must be interpolated here rather than
                    # passed as extra positional arguments.
                    debug_error_message(
                        "rval.shape = %s, V.shape = %s, element 0 should "
                        "match but doesn't"
                        % (str(rval_value.shape), str(V_value.shape)))

        return rval
Exemplo n.º 2
0
def test_debug_error_message():
    """Check that debug_error_message raises ValueError when
    config.compute_test_value is 'ignore' and when it is 'raise'."""

    saved_mode = config.compute_test_value

    for mode in ['ignore', 'raise']:

        try:
            config.compute_test_value = mode

            try:
                op.debug_error_message('msg')
            except ValueError:
                raised = True
            else:
                raised = False
            assert raised
        finally:
            # Always restore the global setting, even if the assert fails.
            config.compute_test_value = saved_mode
Exemplo n.º 3
0
def test_debug_error_message():
    # debug_error_message is expected to raise ValueError both when
    # compute_test_value is "ignore" and when it is "raise".

    original_mode = config.compute_test_value

    for test_mode in ("ignore", "raise"):
        try:
            config.compute_test_value = test_mode
            got_value_error = False
            try:
                op.debug_error_message("msg")
            except ValueError:
                got_value_error = True
            assert got_value_error
        finally:
            # Restore the global config regardless of the outcome.
            config.compute_test_value = original_mode