def SparseSoftmaxCrossEntropyWithLogits(self, tf_node, inputs):
    """
    Computes softmax cross entropy between `logits` and sparse `labels`.

    `logits` must have two axes and `labels` a single axis whose length
    equals the length of the first axis of `logits`. The labels are expanded
    with `ng.one_hot` along the second axis of `logits` before the cross
    entropy is computed.

    Reference: https://goo.gl/z5T2my

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        NotImplementedError: If `logits` is not 2-d, `labels` is not 1-d, or
            their leading lengths differ.

    Inputs to tf_node:
        logits, labels, name
    """
    # logits: (N1, Y1), labels: (N2,)
    logits, labels = inputs

    # check input dimension; an explicit check (rather than `assert`) keeps
    # the validation active when Python runs with -O
    shapes_supported = (
        len(logits.axes) == 2
        and len(labels.axes) == 1
        and logits.axes[0].length == labels.axes[0].length
    )
    if not shapes_supported:
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")

    # get axis
    axis_y = logits.axes[1]

    # labels_one_hot: (Y2, N2)
    labels_one_hot = ng.one_hot(labels, axis=axis_y)

    # predicts: (N1, Y1)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # dim-shuffle / cast to (Y1, N1)
    predicts_axes = ng.make_axes(list(reversed(predicts.axes)))
    predicts = ng.Dimshuffle(predicts, axes=predicts_axes)
    labels_one_hot = ng.cast_axes(labels_one_hot, predicts_axes)

    # cross_entropy: (N1,)
    cross_entropy = ng.cross_entropy_multi(predicts, labels_one_hot,
                                           out_axes=(logits.axes[0], ))

    return cross_entropy
def test_dimshuffle_bprop(transformer_factory):
    """
    Swap the two axes of a 2d placeholder with Dimshuffle and run
    check_derivative on the result to make sure bprop works.
    """
    axis_a = ng.make_axis(2)
    axis_b = ng.make_axis(3)
    source_axes = ng.make_axes([axis_a, axis_b])
    x = ng.placeholder(source_axes)

    # random input values drawn for the placeholder's axes
    x_value = rng.uniform(-1, 1, x.axes)

    # the op under test: same data, axes in reversed order
    shuffled = ng.Dimshuffle(x, axes=ng.make_axes([axis_b, axis_a]))

    check_derivative(shuffled, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
def test_idempotent_axes_c():
    """
    Axes transformations with autodiff, case c: a variable is broadcast,
    sliced, cast, added to itself and finally cast / dim-shuffled; both the
    summed cost and its gradient w.r.t. the variable are checked.
    """
    factory = ExecutorFactory()
    var_axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
    result_axes = [ng.make_axis(length=axis.length) for axis in var_axes]

    # variable, used as both operands
    weight = ng.variable(var_axes, initial_value=np.ones((3, 1)))

    # broadcast both operands onto the variable's own axes
    lhs = ng.broadcast(weight, var_axes)
    rhs = ng.broadcast(weight, var_axes)

    # full-range slice on every axis
    full_slices = [slice(None), slice(None)]
    lhs_sliced = ng.Slice(lhs, full_slices)
    rhs_sliced = ng.Slice(rhs, full_slices)

    # cast the right operand
    rhs_casted = ng.cast_axes(rhs_sliced, var_axes)

    # perform add
    total = ng.add(lhs_sliced, rhs_casted)

    # cast / dimshuffle onto the fresh result axes
    total = ng.cast_axes(total, result_axes)
    total = ng.Dimshuffle(total, axes=result_axes)

    # cost and grad
    cost = ng.sum(total, reduction_axes=total.axes)
    grad = ng.deriv(cost, weight)

    grad_comp = factory.executor(grad)
    cost_comp = factory.executor(cost)

    # three elements of (1 + 1) sum to 6; each element contributes gradient 2
    assert cost_comp() == 6.0
    expected_grad = np.ones((3, 1)) * 2.
    assert np.array_equal(grad_comp(), expected_grad)
def test_dimshuffle_fprop(transformer_factory):
    """
    Swap the two axes of a 2d placeholder with Dimshuffle and check that the
    computed result equals the transpose of the input.
    """
    axis_a = ng.make_axis(2)
    axis_b = ng.make_axis(3)
    x = ng.placeholder(ng.make_axes([axis_a, axis_b]))

    # build the dimshuffle op and verify the axes it reports
    output = ng.Dimshuffle(x, axes=ng.make_axes([axis_b, axis_a]))
    expected_axes = ng.make_axes([axis_b, axis_a])
    assert output.axes == expected_axes

    # random input values drawn for the placeholder's axes
    x_value = rng.uniform(-1, 1, x.axes)

    # evaluate the graph and compare against the transposed input
    shuffle_fn = executor(output, x)
    result = shuffle_fn(x_value)

    np.testing.assert_allclose(result, x_value.T)
def _element_wise_binary(self, ng_op, tf_node, inputs):
    """
    Apply an element-wise binary ngraph op, handling numpy-style broadcast.

    Args:
        ng_op: ngraph Op to be applied.
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.
    """
    lhs, rhs = inputs

    # the two shapes must be numpy-broadcast compatible
    left_shape = lhs.axes.lengths
    right_shape = rhs.axes.lengths
    assert is_compatible_numpy_shape(left_shape, right_shape)

    needs_broadcast = (left_shape and right_shape
                       and left_shape != right_shape)
    if not needs_broadcast:
        if left_shape == right_shape:
            # same shape: cast right axes to be the same as left
            rhs = ng.cast_axes(rhs, lhs.axes)
        # otherwise no casting is needed
        return ng_op(lhs, rhs).named(tf_node.name)

    # Cast axes following the numpy broadcast mapping rule:
    #   1. introduce dummy length 1 axes to match left / right length
    #   2. keep maps for matching left / right / result axes
    #   3. slice left / right to remove length 1 axes if not both of them
    #      are length 1
    #   4. cast right to left by matching axes
    #   5. perform binary op
    #   6. cast the result and dim-shuffle it onto the full result axes
    rank = max(len(lhs.axes), len(rhs.axes))

    def _pad_front(axes):
        # prepend fresh length-1 axes so `axes` reaches `rank` (align right)
        dummies = [ng.make_axis(length=1) for _ in range(rank - len(axes))]
        return dummies + list(axes)

    left_axes_pad = _pad_front(lhs.axes)
    right_axes_pad = _pad_front(rhs.axes)
    result_axes = [
        ng.make_axis(length=max(l_ax.length, r_ax.length))
        for l_ax, r_ax in zip(left_axes_pad, right_axes_pad)
    ]

    # broadcast left / right, introducing dummy length 1 axes
    lhs = ng.broadcast(lhs, left_axes_pad)
    rhs = ng.broadcast(rhs, right_axes_pad)

    # two-way map between matching left / right axes, plus a map from
    # either side's axis to the corresponding result axis
    lr_axes_map = {}
    result_axes_map = {}
    for l_ax, r_ax, out_ax in zip(lhs.axes, rhs.axes, result_axes):
        lr_axes_map[l_ax] = r_ax
        lr_axes_map[r_ax] = l_ax
        result_axes_map[l_ax] = out_ax
        result_axes_map[r_ax] = out_ax

    # index 0 drops a length-1 axis when only that side is length 1;
    # everything else is kept whole
    paired_axes = list(zip(lhs.axes, rhs.axes))
    left_slice = [
        0 if (l_ax.length == 1 and r_ax.length != 1) else slice(None)
        for l_ax, r_ax in paired_axes
    ]
    right_slice = [
        0 if (r_ax.length == 1 and l_ax.length != 1) else slice(None)
        for l_ax, r_ax in paired_axes
    ]

    # perform slicing
    left_sliced = ng.Slice(lhs, left_slice)
    right_sliced = ng.Slice(rhs, right_slice)

    # rename each remaining right axis to its left partner when that partner
    # survived slicing; otherwise keep the right axis as is
    right_casted_axes = [
        lr_axes_map[r_ax]
        if r_ax in lr_axes_map and lr_axes_map[r_ax] in left_sliced.axes
        else r_ax
        for r_ax in right_sliced.axes
    ]
    right_sliced_casted = ng.cast_axes(right_sliced, right_casted_axes)

    # perform binary op
    result_op = ng_op(left_sliced, right_sliced_casted)

    # cast result axes, then dim-shuffle to the full result axes
    trimmed_result_axes = [result_axes_map[ax] for ax in result_op.axes]
    result_op = ng.cast_axes(result_op, trimmed_result_axes)
    result_op = ng.Dimshuffle(result_op, axes=result_axes)

    return result_op
def MaxPool(self, tf_node, inputs):
    """
    Performs the max pooling on the input.

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        ValueError: If `ksize` or `strides` does not have 4 entries.
        NotImplementedError: If ksize / strides on the batch or channel
            axis is not 1, or if the computed padding is asymmetric.

    Inputs to tf_node:
        input

    TODO: assume default tensorflow layout NHWC, RSCK,
          need to support NCHW as well
          need to clean up / merge with conv2d

    Notes on output shape:
        https://www.tensorflow.org/api_docs/python/nn.html#convolution
    """
    image = inputs[0]

    # TODO: currently NHWC only
    assert tf_node.attr['data_format'].s.decode("ascii") == "NHWC"

    # set axes shape
    ax_N = ng.make_axis(batch=True)
    ax_C = ng.make_axis(roles=[ar.Channel])
    ax_D = ng.make_axis(roles=[ar.Depth])
    ax_H = ng.make_axis(roles=[ar.Height])
    ax_W = ng.make_axis(roles=[ar.Width])
    ng.make_axes([ax_N, ax_H, ax_W, ax_C]).set_shape(image.axes.lengths)
    # the 2-d input has no depth; use a dummy depth axis of length 1
    ax_D.length = 1

    # ksize params
    tf_ksize = [int(s) for s in tf_node.attr['ksize'].list.i]
    if len(tf_ksize) != 4:
        raise ValueError("Length of ksize must be 4.")
    if tf_ksize[0] != 1:
        raise NotImplementedError('Ksize on batch axis (N) must be 1.')
    if tf_ksize[3] != 1:
        raise NotImplementedError('Ksize on channel axis (C) must be 1. '
                                  'Cross map pooling to be implemented.')
    R, S = tf_ksize[1:3]
    T = J = 1

    # strides params
    tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
    if len(tf_strides) != 4:
        raise ValueError("Length of strides must be 4.")
    if tf_strides[0] != 1:
        raise NotImplementedError('Strides on batch axis (N) must be 1.')
    if tf_strides[3] != 1:
        raise NotImplementedError('Strides on channel axis (C) must be 1.')
    str_h, str_w = tf_strides[1], tf_strides[2]

    # padding params
    padding = tf_node.attr['padding'].s.decode("ascii")
    pad_t, pad_b, pad_l, pad_r = tf_conv2d_pool_padding(
        image.axes.lengths, (R, S, ax_C.length, ax_C.length), tf_strides,
        padding)
    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # pooling params
    params = dict(op='max',
                  pad_d=0, pad_h=pad_t, pad_w=pad_l, pad_c=0,
                  str_d=1, str_h=str_h, str_w=str_w, str_c=1,
                  J=J, T=T, R=R, S=S)

    # i, o axes
    ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
    ax_o = ng.make_axes([
        spatial_axis(ax_i, J, params['pad_c'], params['str_c'], ar.Channel),
        spatial_axis(ax_i, T, params['pad_d'], params['str_d'], ar.Depth),
        spatial_axis(ax_i, R, params['pad_h'], params['str_h'], ar.Height),
        spatial_axis(ax_i, S, params['pad_w'], params['str_w'], ar.Width),
        ax_N
    ])

    # cast / expand / dim-shuffle the input onto the pooling input axes
    image = ng.cast_axes(image, ng.make_axes([ax_N, ax_H, ax_W, ax_C]))
    image = ng.expand_dims(image, ax_D, 1)  # NHWC -> NDHWC
    image = ng.Dimshuffle(image, axes=ax_i)  # NDHWC -> CDHWN

    # pooling
    output = ng.pooling(params, image, axes=ax_o)

    # cast back to NHWC
    oC, oD, oH, oW, oN = output.axes
    output = ng.broadcast(output, ng.make_axes([oN, oD, oH, oW, oC]))

    # slice away the oD
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    output = ng.Slice(output, out_slicing)

    return output
def Conv2D(self, tf_node, inputs):
    """
    Computes a 2-D convolution given 4D input and filter tensors.

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        ValueError: If `strides` does not have 4 entries.
        NotImplementedError: If strides on the batch or channel axis is
            not 1, or if the computed padding is asymmetric.

    Inputs to tf_node:
        input, filter

    TODO: assume default tensorflow layout NHWC, RSCK,
          need to support NCHW as well
          need to clean up / merge with maxpool

    Notes on output shape:
        https://www.tensorflow.org/api_docs/python/nn.html#convolution
    """
    image, weight = inputs

    # TODO: currently NHWC only
    assert tf_node.attr['data_format'].s.decode("ascii") == "NHWC"

    # set axes shape
    ax_N = ng.make_axis(batch=True)
    ax_C = ng.make_axis(roles=[ar.Channel])
    ax_D = ng.make_axis(roles=[ar.Depth])
    ax_H = ng.make_axis(roles=[ar.Height])
    ax_W = ng.make_axis(roles=[ar.Width])
    ax_T = ng.make_axis(roles=[ar.Depth])
    ax_R = ng.make_axis(roles=[ar.Height])
    ax_S = ng.make_axis(roles=[ar.Width])
    ax_K = ng.make_axis(roles=[ar.Channelout])
    ng.make_axes([ax_N, ax_H, ax_W, ax_C]).set_shape(image.axes.lengths)
    ng.make_axes([ax_R, ax_S, ax_C, ax_K]).set_shape(weight.axes.lengths)
    # the 2-d input / filter have no depth; use dummy depth axes of length 1
    ax_D.length = 1
    ax_T.length = 1

    # strides params
    tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
    if len(tf_strides) != 4:
        raise ValueError("Length of strides must be 4.")
    if tf_strides[0] != 1:
        raise NotImplementedError('Strides on batch axis (N) must be 1.')
    if tf_strides[3] != 1:
        raise NotImplementedError('Strides on channel axis (C) must be 1.')
    str_h, str_w = tf_strides[1], tf_strides[2]

    # padding params
    padding = tf_node.attr['padding'].s.decode("ascii")
    pad_t, pad_b, pad_l, pad_r = tf_conv2d_pool_padding(
        image.axes.lengths, weight.axes.lengths, tf_strides, padding)
    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # conv params
    params = dict(pad_d=0, pad_h=pad_t, pad_w=pad_l,
                  str_d=1, str_h=str_h, str_w=str_w)

    # i, f, o axes
    ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
    ax_f = ng.make_axes([ax_C, ax_T, ax_R, ax_S, ax_K])
    ax_o = ng.make_axes([
        ng.make_axis(ax_K.length, name='C', roles=[ar.Channel]),
        spatial_axis(ax_i, ax_f, params['pad_d'], params['str_d'], ar.Depth),
        spatial_axis(ax_i, ax_f, params['pad_h'], params['str_h'], ar.Height),
        spatial_axis(ax_i, ax_f, params['pad_w'], params['str_w'], ar.Width),
        ax_N
    ])

    # cast / expand / dim-shuffle input and filter onto the conv axes
    image = ng.cast_axes(image, ng.make_axes([ax_N, ax_H, ax_W, ax_C]))
    image = ng.expand_dims(image, ax_D, 1)  # NHWC -> NDHWC
    image = ng.Dimshuffle(image, axes=ax_i)  # NDHWC -> CDHWN
    weight = ng.cast_axes(weight, ng.make_axes([ax_R, ax_S, ax_C, ax_K]))
    weight = ng.expand_dims(weight, ax_T, 0)  # RSCK -> TRSCK
    weight = ng.Dimshuffle(weight, axes=ax_f)  # TRSCK -> CTRSK

    # convolution
    output = ng.convolution(params, image, weight, axes=ax_o)

    # cast back to NHWC
    oC, oD, oH, oW, oN = output.axes
    output = ng.broadcast(output, ng.make_axes([oN, oD, oH, oW, oC]))

    # slice away the oD
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    output = ng.Slice(output, out_slicing)

    return output
def test_fail_on_missing_and_extra_axis(transformer_factory, x, A, C):
    """
    Dimshuffle of `x` onto the axes [A, C] is expected to raise ValueError.
    """
    with pytest.raises(ValueError):
        requested_axes = ng.make_axes([A, C])
        ng.Dimshuffle(x, axes=requested_axes)
def test_fail_on_missing(transformer_factory, x, B):
    """
    Dimshuffle of `x` onto the axes [B, B] is expected to raise ValueError.
    """
    with pytest.raises(ValueError):
        requested_axes = ng.make_axes([B, B])
        ng.Dimshuffle(x, axes=requested_axes)
def test_fail_on_axis_reuse(transformer_factory, x, A, B):
    """
    Dimshuffle of `x` onto the axes [A, B, B] (B listed twice) is expected
    to raise ValueError.
    """
    with pytest.raises(ValueError):
        requested_axes = ng.make_axes([A, B, B])
        ng.Dimshuffle(x, axes=requested_axes)