def init_H_hat(self, V):
    """Return a symbolic expression for the initial value of H_hat.

    If ``self.model.recycle_q`` is true, the value held in
    ``self.model.prev_H`` is reused.  Otherwise a new expression of shape
    ``(V.shape[0], self.model.nhid)`` filled with ``1.`` is allocated with
    ``T.alloc``.

    Parameters
    ----------
    V : symbolic variable
        Input batch (presumably a Theano tensor -- TODO confirm).  Its first
        dimension must match the first dimension of the returned value.

    Returns
    -------
    rval
        Either ``self.model.prev_H`` or the ``T.alloc`` expression.

    Raises
    ------
    ValueError
        When recycling with test values enabled, if the stored ``prev_H``
        batch size differs from the batch size of ``V``'s test value.
    """
    if self.model.recycle_q:
        rval = self.model.prev_H
        if config.compute_test_value != 'off':
            # Fetch the stored value once; borrow=True avoids copying the
            # whole array just to read its shape.
            prev_shape = rval.get_value(borrow=True).shape
            test_shape = V.tag.test_value.shape
            if prev_shape[0] != test_shape[0]:
                raise ValueError('E step given wrong test batch size',
                                 prev_shape, test_shape)
    else:
        rval = T.alloc(1., V.shape[0], self.model.nhid)

    # When test values are available, verify the batch dimension of the
    # result agrees with V.
    for rval_value, V_value in get_debug_values(rval, V):
        if rval_value.shape[0] != V_value.shape[0]:
            # Format the message here: debug_error_message takes a single
            # message string, not printf-style arguments.
            debug_error_message(
                "rval.shape = %s, V.shape = %s, element 0 should match but "
                "doesn't" % (str(rval_value.shape), str(V_value.shape)))
    return rval
def test_debug_error_message():
    """Check that ``op.debug_error_message`` raises ``ValueError``.

    The call is attempted under each ``compute_test_value`` mode in which
    an error is expected; the original mode is put back after every
    attempt.
    """
    # NOTE(review): another function with this same name follows in this
    # file and replaces this definition at import time.
    saved_mode = config.compute_test_value
    for current_mode in ('ignore', 'raise'):
        try:
            config.compute_test_value = current_mode
            did_raise = True
            try:
                op.debug_error_message('msg')
            except ValueError:
                pass
            else:
                did_raise = False
            assert did_raise
        finally:
            config.compute_test_value = saved_mode
def test_debug_error_message():
    # Verify that debug_error_message raises ValueError in every
    # compute_test_value mode listed below, restoring the global
    # configuration after each attempt.
    original_mode = config.compute_test_value
    for mode_name in ["ignore", "raise"]:
        try:
            config.compute_test_value = mode_name
            caught = False
            try:
                op.debug_error_message("msg")
            except ValueError:
                caught = True
            assert caught
        finally:
            # Restore the config even when the assertion fails.
            config.compute_test_value = original_mode