def test_log_diffBase_reverse():
    """Check reverse-mode value and gradient of ``log`` with a custom base.

    ``log(base=base)`` applied to a length-5 vector ``x`` is compared with
    ``log(x) / log(base)`` for the value, and with the diagonal matrix
    ``diag(1 / (x * log(base)))`` for the gradient.

    The global differentiation mode is set to ``'reverse'`` for the test and
    is always set back to ``'forward'`` afterwards, even if an assertion
    fails, so later tests are not affected.
    """
    ad.set_mode('reverse')
    # Restore the mode in ``finally``: a failing assertion must not leave
    # every subsequent test running in reverse mode.
    try:
        # =====================================================================
        # define the input variable
        # =====================================================================
        data = np.random.random(5)
        x = Variable(data)

        # =====================================================================
        # define custom block
        # =====================================================================
        # randint(2, 5) yields 2, 3 or 4; adding random() in [0, 1) gives a
        # non-integer base in [2, 5)
        base = np.random.randint(2, 5) + np.random.random()
        log_block = log(base=base)

        # =====================================================================
        # compute output of custom block
        # =====================================================================
        y_block = log_block(x)
        y_block.compute_gradients()

        # =====================================================================
        # define expected output
        # =====================================================================
        data_true = np.log(data) / np.log(base)
        gradient_true = np.diag(1 / (data * np.log(base)))

        # =====================================================================
        # assert data pass
        # =====================================================================
        # allclose instead of exact equality: algebraically identical float
        # expressions may differ in the last bits depending on evaluation order
        assert np.allclose(data_true, y_block.data), \
            'wrong log data pass. expected {}, given {}'.format(
                data_true, y_block.data)

        # =====================================================================
        # assert gradient reverse pass
        # =====================================================================
        assert np.allclose(gradient_true, y_block.gradient), \
            'wrong log gradient reverse pass. expected {}, given {}'.format(
                gradient_true, y_block.gradient)
    finally:
        ad.set_mode('forward')
global mode, c_graph mode = new_mode if new_mode == 'reverse': reset_graph() # ============================================================================= # shortcuts for better user interface # ============================================================================= sin_ = sin() cos_ = cos() tan_ = tan() exp_ = exp() log_ = log() sqrt_ = sqrt() sinh_ = sinh() cosh_ = cosh() tanh_ = tanh() arcsin_ = arcsin() arccos_ = arccos() arctan_ = arctan() add_ = add() subtract_ = subtract() multiply_ = multiply() divide_ = divide() power_ = power() sum_elts_ = sum_elts() # ================