def test_cross_entropy_binary_unmatched_axes(input_tensor):
    """If y and t have different axes, an error should be thrown immediately.

    The target is built on a new feature axis created with ``ng.make_axis``
    using only the original feature axis' length, plus the original batch
    axis. Building the cross-entropy op must raise ``ng.UnmatchedAxesError``.
    """
    prediction = input_tensor
    # Unpacking assumes the fixture tensor has exactly two axes.
    feature_axis, batch_axis = prediction.axes
    replacement_feature_axis = ng.make_axis(feature_axis.length)
    mismatched_target = ng.placeholder([replacement_feature_axis, batch_axis])
    with pytest.raises(ng.UnmatchedAxesError):
        ng.cross_entropy_binary_inner(prediction, mismatched_target)
def test_cross_entropy_binary(transformer_factory):
    """Compare the graph and numeric derivatives of cross_entropy_binary_inner.

    Builds ``y = sigmoid(p_u)`` and ``t = softmax(p_v)`` over axes W (length 20)
    and N (length 128). Then checks that the derivative of
    ``cross_entropy_binary_inner(y, t)`` with respect to ``p_u`` from
    ``ExecutorFactory.derivative`` matches ``ExecutorFactory.numeric_derivative``
    (step ``delta``) within atol/rtol of 1e-2.

    NOTE(review): another function named ``test_cross_entropy_binary`` is
    defined later in this module and shadows this one, so pytest will not
    collect this version. Consider renaming or removing it.
    """
    N = ng.make_axis(name='N')
    W = ng.make_axis(name='W')
    delta = .001
    W.length = 20
    N.length = 128
    axes = ng.make_axes([W, N])
    p_u = ng.placeholder(axes)
    u = rng.uniform(-3.0, 3.0, p_u.axes)
    p_v = ng.placeholder(axes)
    v = rng.uniform(-3.0, 3.0, p_u.axes)
    y = ng.sigmoid(p_u)
    t = ng.softmax(p_v)
    val_u = ng.cross_entropy_binary_inner(y, t)
    # Use the factory as a context manager, the same way other tests in this
    # module do, so its resources are cleaned up even when the assertion fails.
    with ExecutorFactory() as ex:
        dval_u_num_fun = ex.numeric_derivative(val_u, p_u, delta, p_v)
        dval_u_graph_fun = ex.derivative(val_u, p_u, p_v)
        dval_u_num = dval_u_num_fun(u, v)
        dval_u_graph = dval_u_graph_fun(u, v)
        np.testing.assert_allclose(dval_u_graph, dval_u_num, atol=1e-2, rtol=1e-2)
def fuse_ce_binary_inner_callback(self, op, label_map_op_list):
    """
    Callback function that handles fusion for the cross_entropy_binary_inner pattern.

    For every match, the op bound to the pattern's Y and T labels is fed into a
    new ``ng.cross_entropy_binary_inner`` built with ``enable_sig_opt=False`` and
    ``enable_diff_opt=True``. The matched op is then replaced with the new op.

    Arguments:
        op: Op passed in by the caller. It is not used directly; each match's
            own op from ``label_map_op_list`` is the one replaced.
        label_map_op_list: Iterable of ``(label_map, matched_op)`` pairs. Each
            ``label_map`` is indexed by ``self.ce_y_label`` and
            ``self.ce_t_label``.
    """
    # Use a distinct loop name so the ``op`` parameter is not silently shadowed.
    for label_map, matched_op in label_map_op_list:
        # Matched the cross_entropy_binary_inner pattern: rebuild it from the
        # bound Y/T inputs with enable_sig_opt turned off.
        y = label_map[self.ce_y_label]
        t = label_map[self.ce_t_label]
        cross_without_opt_op = ng.cross_entropy_binary_inner(
            y, t, enable_sig_opt=False, enable_diff_opt=True)
        self.replace_op(matched_op, cross_without_opt_op)
def test_cross_entropy_binary_logistic_shortcut(input_tensor):
    """Cross-check the reference binary logistic cross-entropy three ways.

    ``cross_entropy_binary_logistic`` is compared with
    ``cross_entropy_binary_logistic_shortcut`` and with the graph op
    ``ng.cross_entropy_binary_inner(ng.sigmoid(p_u), p_v)``, each to rtol=1e-5.
    ``v`` is passed through ``np_softmax`` along axis 0 before use.

    NOTE(review): a later function in this module reuses this name and shadows
    this definition. Confirm which version is meant to run.
    """
    p_u = input_tensor
    p_v = ng.placeholder(p_u.axes)
    sample_axes = p_u.axes
    u = rng.uniform(-3.0, 3.0, sample_axes)
    v = np_softmax(rng.uniform(-3.0, 3.0, sample_axes), 0)

    reference = cross_entropy_binary_logistic(u, v)
    ng.testing.assert_allclose(
        reference, cross_entropy_binary_logistic_shortcut(u, v), rtol=1e-5)

    graph_op = ng.cross_entropy_binary_inner(ng.sigmoid(p_u), p_v)
    with executor(graph_op, p_u, p_v) as run:
        from_graph = run(u, v)
    ng.testing.assert_allclose(reference, from_graph, rtol=1e-5)
def test_cross_entropy_binary_logistic_shortcut(transformer_factory):
    """Cross-check the reference binary logistic cross-entropy three ways.

    Over axes W (length 20) and N (length 128), compares
    ``cross_entropy_binary_logistic`` with
    ``cross_entropy_binary_logistic_shortcut`` and with the graph op
    ``ng.cross_entropy_binary_inner(ng.sigmoid(p_u), p_v)``, each to rtol=1e-5.
    ``v`` is passed through ``np_softmax`` along axis 0 before use.

    NOTE(review): an earlier function in this module has the same name; this
    definition shadows it.
    """
    N = ng.make_axis(name='N')
    W = ng.make_axis(name='W')
    W.length = 20
    N.length = 128
    axes = ng.make_axes([W, N])
    p_u = ng.placeholder(axes)
    u = rng.uniform(-3.0, 3.0, p_u.axes)
    p_v = ng.placeholder(axes)
    v = np_softmax(rng.uniform(-3.0, 3.0, p_u.axes), 0)
    cel = cross_entropy_binary_logistic(u, v)
    cel_shortcut = cross_entropy_binary_logistic_shortcut(u, v)
    np.testing.assert_allclose(cel, cel_shortcut, rtol=1e-5)
    # Use the executor as a context manager, the same way other tests in this
    # module do, so it is cleaned up instead of being left open.
    with executor(ng.cross_entropy_binary_inner(ng.sigmoid(p_u), p_v), p_u, p_v) as ex:
        cel_graph = ex(u, v)
    np.testing.assert_allclose(cel, cel_graph, rtol=1e-5)
def test_cross_entropy_binary(input_tensor):
    """Compare the graph and numeric derivatives of cross_entropy_binary_inner.

    The loss is ``cross_entropy_binary_inner(sigmoid(p_u), softmax(p_v))``.
    Its derivative with respect to ``p_u`` from ``ExecutorFactory.derivative``
    must match ``ExecutorFactory.numeric_derivative`` (step 0.001) within
    atol/rtol of 1e-2.
    """
    p_u = input_tensor
    p_v = ng.placeholder(p_u.axes)
    u = rng.uniform(-3.0, 3.0, p_u.axes)
    v = rng.uniform(-3.0, 3.0, p_u.axes)
    step = .001

    prediction = ng.sigmoid(p_u)
    target = ng.softmax(p_v)
    loss = ng.cross_entropy_binary_inner(prediction, target)

    with ExecutorFactory() as factory:
        numeric_grad_fn = factory.numeric_derivative(loss, p_u, step, p_v)
        graph_grad_fn = factory.derivative(loss, p_u, p_v)
        # Evaluate the numeric gradient first, then the graph gradient.
        numeric_grad = numeric_grad_fn(u, v)
        graph_grad = graph_grad_fn(u, v)
        ng.testing.assert_allclose(graph_grad, numeric_grad, atol=1e-2, rtol=1e-2)
def construct_ce_binary_inner_pattern(self):
    """
    Build the pattern graph for the "optimized" Cross Entropy Binary Inner op.

    The pattern corresponds to
    ng.cross_entropy_binary_inner(y, t, enable_sig_opt=True, enable_diff_opt=True)
    where the label op ``y`` has ``ng.SigmoidOp(x)`` attached as its
    ``deriv_handler``.

    Side effects:
        Sets self.ce_x_label, self.ce_y_label and self.ce_t_label to
        "X", "Y" and "T", so a fusion callback can look up matched ops by label.

    Returns:
        Single pattern that matches "optimized" Cross Entropy Binary Inner
    """
    self.ce_x_label, self.ce_y_label, self.ce_t_label = "X", "Y", "T"

    def _label_op(label_name):
        # Each label op gets its own newly made axis named 'N'.
        return PatternLabelOp(label_name, axes={ng.make_axis(name='N')})

    x = _label_op(self.ce_x_label)
    y = _label_op(self.ce_y_label)
    t = _label_op(self.ce_t_label)
    # NOTE(review): presumably this is what lets cross_entropy_binary_inner
    # recognise y as a sigmoid output when enable_sig_opt=True; confirm.
    y.deriv_handler = ng.SigmoidOp(x)
    return ng.cross_entropy_binary_inner(
        y, t, enable_sig_opt=True, enable_diff_opt=True)