def test_conv_flatten_deriv(n4_hw12_c3_5x5):
    """
    Conv followed by flatten: check the forward value, then compare numeric
    against symbolic derivatives of the summed output w.r.t. both the filter
    and the input.
    """
    params = ConvParams(**n4_hw12_c3_5x5)

    # Filter axes in R, S, C, K order, plus renamed ("<name>p") copies that
    # the filter variable actually lives on.
    rsck = ng.make_axes([params.ax_f[2], params.ax_f[3],
                         params.ax_f[0], params.ax_f[-1]])
    rsck_primed = ng.make_axes([
        ng.make_axis(name=axis.name + 'p', length=axis.length)
        for axis in rsck
    ])
    nmpqk = ng.make_axes([params.ax_o[-1], params.ax_o[1], params.ax_o[2],
                          params.ax_o[3], params.ax_o[0]])

    # input and filter variables
    x = ng.variable(params.ax_i).named('input')
    x_val = np.ones(x.axes.lengths)
    w = ng.variable(rsck_primed).named('filter')

    # RSCK' -> RSCK -> TRSCK -> CTRSK
    w_rsck = ng.cast_axes(w, rsck).named('frsck')
    w_trsck = ng.expand_dims(w_rsck, params.ax_f[1], 0).named('ftrsck')
    w_ctrsk = ng.axes_with_order(w_trsck, axes=params.ax_f).named('ctrsk')

    # convolution, then reorder to N, M, P, Q, K
    conv_out = ng.convolution(params.conv_params, x, w_ctrsk,
                              axes=params.ax_o)
    conv_nmpqk = ng.axes_with_order(conv_out, axes=nmpqk)

    # drop the M axis (index 0), then flatten everything after N
    drop_m = [slice(None), 0, slice(None), slice(None), slice(None)]
    conv_npqk = ng.tensor_slice(conv_nmpqk, drop_m)
    flat = ng.flatten_at(conv_npqk, idx=1)

    cost = ng.sum(flat, out_axes=())
    w_val = np.ones(w.axes.lengths)

    with ExecutorFactory() as factory:
        fprop = factory.executor(flat, w, x)
        d_w_num = factory.numeric_derivative(cost, w, 1.0, x)
        d_w_sym = factory.derivative(cost, w, x)
        d_x_num = factory.numeric_derivative(cost, x, 1.0, w)
        d_x_sym = factory.derivative(cost, x, w)

        # With all-ones data every output equals the product of all filter
        # axis lengths except the last one (K).
        fprop_val = fprop(w_val, x_val)
        expected = np.full_like(fprop_val,
                                np.prod(params.ax_f.lengths[:-1]))
        ng.testing.assert_allclose(fprop_val, expected)

        ng.testing.assert_allclose(d_w_num(w_val, x_val),
                                   d_w_sym(w_val, x_val))
        ng.testing.assert_allclose(d_x_num(x_val, w_val),
                                   d_x_sym(x_val, w_val))
def test_conv_flatten_deriv(transformer_factory):
    """
    Conv followed by flatten on all-ones data: check the forward value and
    the gradient of the summed output w.r.t. the filter.
    """
    # image / filter shapes
    C, D, H, W, N = (3, 1, 28, 28, 8)
    C, T, R, S, K = (3, 1, 5, 5, 32)

    # input, filter and output axes
    ax_i = ng.make_axes([ax.C, ax.D, ax.H, ax.W, ax.N])
    ax_f = ng.make_axes([ax.C, ax.T, ax.R, ax.S, ax.K])
    ax_o = ng.make_axes([
        ng.make_axis(32, roles=[ar.Channel]),
        ng.make_axis(1, roles=[ar.Depth]),
        ng.make_axis(24, roles=[ar.Height]),
        ng.make_axis(24, roles=[ar.Width]),
        ax.N,
    ])
    ax_i.set_shape((C, D, H, W, N))
    ax_f.set_shape((C, T, R, S, K))

    conv_params = {'pad_d': 0, 'pad_h': 0, 'pad_w': 0,
                   'str_d': 1, 'str_h': 1, 'str_w': 1}

    rsck = ng.make_axes([ax.R, ax.S, ax.C, ax.K])
    rsck_fresh = ng.make_axes([ng.make_axis(length) for length in rsck.lengths])

    image = ng.constant(np.ones(ax_i.lengths), ax_i)
    kernel = ng.variable(rsck_fresh, initial_value=np.ones((R, S, C, K)))
    # RSCK' -> RSCK -> TRSCK -> CTRSK
    kernel_ctrsk = ng.axes_with_order(
        ng.expand_dims(ng.cast_axes(kernel, rsck), ax.T, 0), axes=ax_f)

    conv_out = ng.convolution(conv_params, image, kernel_ctrsk, axes=ax_o)
    o_c, o_d, o_h, o_w, o_n = conv_out.axes
    conv_out = ng.axes_with_order(
        conv_out, axes=ng.make_axes([o_n, o_d, o_h, o_w, o_c]))

    # drop the depth axis (index 0), then flatten everything after N
    conv = ng.Slice(conv_out,
                    [slice(None), 0, slice(None), slice(None), slice(None)])
    flat = ng.flatten_at(conv, idx=1)

    cost = ng.sum(flat, reduction_axes=flat.axes)
    grad = ng.deriv(cost, kernel)

    conv_val, grad_val = executor([conv, grad])()
    # 75 = C * R * S; 4608 = N * 24 * 24
    assert np.allclose(conv_val, np.zeros_like(conv_val) + 75.)
    assert np.allclose(grad_val, np.zeros_like(grad_val) + 4608.)
def reorder_spatial_axes(tensor):
    """
    Reorder ``tensor`` to channel + spatial + batch axes, inserting length-1
    axes where needed.

    Assumes we are getting a C, H, N, or C, H, W, N, or C, D, H, W, N tensor.
    A missing channel axis is added as a length-1 axis named 'C'. With one
    spatial axis a length-1 width axis (_WIDTH) is appended. With two spatial
    axes (including after that step) a length-1 depth axis (_DEPTH) is
    prepended.

    Arguments:
        tensor (Op): input with 1 to 3 spatial axes and a batch axis.

    Returns:
        Op: ``tensor`` with axes ordered channel, spatial, batch.

    Raises:
        ValueError: if there are no spatial axes, more than 3, or no batch
            axis.
    """
    spatial_axes = tensor.axes.spatial_axes()
    batch_axes = tensor.axes.batch_axes()

    if not 1 <= len(spatial_axes) <= 3:
        raise ValueError(
            'spatial ops can only operate on tensors with 1, 2, or 3 spatial axes. '
            'Found {}'.format(spatial_axes))
    if not batch_axes:
        raise ValueError('spatial ops require a batch axis')

    if not tensor.axes.channel_axis():
        # No channel axis: insert a length-1 'C' axis at position 0.
        c = ng.make_axis(length=1, name='C')
        tensor = ng.expand_dims(tensor, c, 0)
    channel_axes = ng.make_axes(tensor.axes.channel_axis())

    if len(spatial_axes) == 1:
        # Only one spatial axis: add a length-1 width axis after it.
        w = ng.make_axis(length=1, name=_WIDTH)
        tensor = ng.expand_dims(tensor, w, 0)
        spatial_axes = spatial_axes + w
    if len(spatial_axes) == 2:
        # Two spatial axes (possibly after the step above): prepend depth.
        d = ng.make_axis(length=1, name=_DEPTH)
        tensor = ng.expand_dims(tensor, d, 0)
        spatial_axes = ng.make_axes([d]) + spatial_axes

    new_axes = channel_axes + spatial_axes + batch_axes
    return ng.axes_with_order(tensor, new_axes)
def train_outputs(self, in_obj):
    """
    Build the embedding-lookup output for ``in_obj``.

    The 'time' axis of ``in_obj`` is tagged as recurrent, the input is
    reordered by ``self.role_order`` and flattened, and a fresh weight
    table ``self.W`` of axes (V, F) is created.

    Arguments:
        in_obj (Tensor): object that provides the lookup indices

    Returns:
        Tensor: unflattened lookup result ordered as ``self.o_axes``
        (the F axis first).
    """
    time_axis = in_obj.axes.find_by_short_name('time')[0]
    time_axis.add_role(ar.time)
    time_axis.is_recurrent = True

    flat_in = ng.flatten(ng.axes_with_role_order(in_obj, self.role_order))
    flat_axes = flat_in.axes

    self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
    self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')
    self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])
    self.lut_o_axes = flat_axes + ng.make_axes([self.lut_f_axis])
    self.o_axes = ng.make_axes([self.lut_f_axis]) + flat_axes[0].axes

    initial_w = self.lut_init(self.w_axes, self.lut_v_axis, self.pad_idx)
    self.W = ng.variable(axes=self.w_axes, initial_value=initial_w).named('W')

    lut_result = ng.lookuptable(self.W, flat_in, self.lut_o_axes,
                                update=self.update, pad_idx=self.pad_idx)
    return ng.axes_with_order(ng.unflatten(lut_result), self.o_axes)
def __call__(self, in_obj, **kwargs):
    """
    Look up embeddings for the indices in ``in_obj``.

    Arguments:
        in_obj (Tensor): object that provides the lookup indices

    Returns:
        Tensor: lookup result with its roles mapped through ``self.axes_map``
        and ordered as ``self.o_axes`` (the F axis first).
    """
    labels = {"weight": "weight", "bias": "bias"}

    # recurrent axis first, then batch, then collapse them into one axis
    rec_then_batch = ng.make_axes([in_obj.axes.recurrent_axis(),
                                   in_obj.axes.batch_axis()])
    flat_in = ng.flatten(ng.axes_with_order(in_obj, rec_then_batch))
    flat_axes = flat_in.axes

    # label lut_v_axis as shadow axis for initializers ... once #1158 is
    # in, shadow axis will do more than just determine fan in/out for
    # initializers.
    self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
    self.axes_map = shadow_axes_map([self.lut_v_axis])
    self.lut_v_axis = list(self.axes_map.values())[0]

    self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')
    self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])
    self.lut_o_axes = flat_axes | ng.make_axes([self.lut_f_axis])
    self.o_axes = ng.make_axes([self.lut_f_axis]) | flat_axes[0].axes

    if not self.initialized:
        initial_w = self.lut_init(self.w_axes, self.lut_v_axis, self.pad_idx)
        self.W = ng.variable(
            axes=self.w_axes,
            initial_value=initial_w,
            metadata={"label": labels["weight"]},
        ).named('LutW')

    lut_result = ng.lookuptable(self.W, flat_in, self.lut_o_axes,
                                update=self.update, pad_idx=self.pad_idx)
    remapped = ng.map_roles(ng.unflatten(lut_result), self.axes_map)
    return ng.axes_with_order(remapped, self.o_axes)
def run_inference(self, out_axes, init_states, **kwargs):
    """
    Generate ``time_axis.length`` steps autoregressively, feeding each step's
    output back in as the next step's input.

    Arguments:
        out_axes (Axes): desired output axes. Must contain a recurrent axis
            and a batch axis; the first remaining axis is the feature axis.
        init_states (list): initial hidden state per recurrent layer. For the
            'LSTM' cell type each state is paired with a zero cell state.

    Returns:
        Op: the per-step outputs stacked along the recurrent axis and ordered
        as ``out_axes``.
    """
    is_lstm = self.celltype == 'LSTM'
    if is_lstm:
        init_states = [(state, ng.constant(0., state.axes))
                       for state in init_states]

    one_time_axis = ng.make_axis(1, name="REC")
    time_axis = out_axes.recurrent_axis()
    batch_axis = out_axes.batch_axis()
    feature_axis = (out_axes - [time_axis, batch_axis])[0]

    # Seed the loop with an all-zero input for the first step.
    outputs = [ng.constant(0., [batch_axis, one_time_axis, feature_axis])]
    hidden_states = init_states

    for _ in range(time_axis.length):
        in_obj = outputs[-1]

        # Compute the next hidden/cell states for the recurrent layers
        next_hidden_states = []
        for i, layer in enumerate(self.layers[:-1]):
            init_state = hidden_states[i] if i < len(hidden_states) else None
            if is_lstm:
                h, c = layer(in_obj, init_state=init_state,
                             return_cell_state=True)
                in_obj = h
                h = ng.slice_along_axis(h, one_time_axis, 0)
                c = ng.slice_along_axis(c, one_time_axis, 0)
                next_hidden_states.append((h, c))
            else:
                h = layer(in_obj, init_state=init_state)
                in_obj = h
                h = ng.slice_along_axis(h, one_time_axis, 0)
                # Non-LSTM cells carry only a hidden state (there is no `c`).
                next_hidden_states.append(h)
        hidden_states = next_hidden_states

        # Compute the output of the affine layer
        in_obj = self.layers[-1](in_obj)
        outputs.append(in_obj)

    # Drop the zero seed, remove the length-1 step axis, and stack over time.
    steps = [ng.slice_along_axis(step, one_time_axis, 0)
             for step in outputs[1:]]
    stacked = ng.stack(steps, time_axis)
    return ng.axes_with_order(stacked, out_axes)
def test_dimshuffle_bprop(self, x, A, B):
    """
    Reorder a 2d tensor to (B, A) and check its derivative numerically.
    """
    # random input values between -1 and 1
    x_value = rng.uniform(-1, 1, x.axes)
    shuffled = ng.axes_with_order(x, [B, A])
    check_derivative(shuffled, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
def NHWC2NCHW(self, c2_op, inputs):
    """
    Converts a single 4-D input from NHWC to NCHW layout.

    Arguments:
        c2_op: the operator being converted (not used here).
        inputs: list containing exactly one ngraph op.

    Returns:
        The input with axes reordered to (N, C, H, W), tagged with
        ``order = 'NCHW'``.

    Raises:
        ValueError: if the input is tagged with an ``order`` other than
            'NHWC'.
    """
    assert 1 == len(inputs)
    X = inputs[0]

    # Inputs without an `order` tag are assumed to be NHWC.
    order = X.order if hasattr(X, 'order') else 'NHWC'
    if 'NHWC' != order:
        raise ValueError("NHWC2NCHW accepts only NHWC input format.")

    # axes are (N, H, W, C); pick them in (N, C, H, W) order
    Y = ng.axes_with_order(
        X, axes=ng.make_axes([X.axes[0], X.axes[3], X.axes[1], X.axes[2]]))
    Y.order = 'NCHW'
    return Y
def __call__(self, *args, **kwargs):
    """
    Run the parent network. When ``self.to_ctc`` is True, lay the result out
    for warp-ctc: recurrent axis, batch axis, then feature axes, as a
    contiguous op.
    """
    output = super(Deepspeech, self).__call__(*args, **kwargs)
    if self.to_ctc is not True:
        return output

    # prepare activations/gradients for warp-ctc
    # TODO: This should be handled in a graph pass
    axes = output.axes
    ctc_axes = (ng.make_axes([axes.recurrent_axis(), axes.batch_axis()])
                | axes.feature_axes())
    return ng.ContiguousOp(ng.axes_with_order(output, ctc_axes))
def SparseSoftmaxCrossEntropyWithLogits(self, tf_node, inputs):
    """
    Computes softmax cross entropy between unscaled logits and sparse labels.

    The inputs `logits` are unscaled log probabilities. `labels` holds one
    class index per example; it is one-hot encoded along the class axis of
    `logits` before computing the cross entropy.

    Reference: https://goo.gl/z5T2my

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        NotImplementedError: unless logits are (N, Y) and labels are (N,).

    Inputs to tf_node:
        logits, labels, name
    """
    # logits: (N1, Y1), labels: (N2,)
    logits, labels = inputs

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    try:
        shapes_ok = (len(logits.axes) == 2
                     and len(labels.axes) == 1
                     and logits.axes[0].length == labels.axes[0].length)
    except AttributeError:
        shapes_ok = False
    if not shapes_ok:
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")

    # get axis
    axis_y = logits.axes[1]

    # labels_one_hot: (Y2, N2)
    labels_one_hot = ng.one_hot(labels, axis=axis_y)

    # predicts: (N1, Y1)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # dim-shuffle / cast to (Y1, N1)
    predicts_axes = ng.make_axes(
        [axis for axis in reversed(predicts.axes)])
    predicts = ng.axes_with_order(predicts, axes=predicts_axes)
    labels_one_hot = ng.cast_axes(labels_one_hot, predicts_axes)

    # cross_entropy: (N1,)
    cross_entropy = ng.cross_entropy_multi(predicts, labels_one_hot,
                                           out_axes=(logits.axes[0], ))

    return cross_entropy
def test_dimshuffle_fprop(self, x, A, B):
    """
    Reorder a 2d tensor to (B, A) and check the forward result equals the
    numpy transpose of the input.
    """
    shuffled = ng.axes_with_order(x, [B, A])
    assert shuffled.axes == ng.make_axes([B, A])

    # random input values between -1 and 1
    x_value = rng.uniform(-1, 1, x.axes)
    with executor(shuffled, x) as run:
        out = run(x_value)

    ng.testing.assert_allclose(out, x_value.T)
def test_dimshuffle_bprop(transformer_factory):
    """
    Reorder a 2x3 placeholder to (B, A) and check its derivative
    numerically.
    """
    A, B = ng.make_axis(2), ng.make_axis(3)
    x = ng.placeholder(ng.make_axes([A, B]))

    # random input values between -1 and 1
    x_value = rng.uniform(-1, 1, x.axes)
    shuffled = ng.axes_with_order(x, [B, A])
    check_derivative(shuffled, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
def test_idempotent_axes_c():
    """
    Test axes transformations with autodiff, case c: broadcast, slice, cast
    and dim-shuffle applied to two uses of the same variable.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        result_axes = [ng.make_axis(length=axis.length) for axis in axes]

        w = ng.variable(axes, initial_value=np.ones((3, 1)))

        # broadcast left / right copies of w
        left = ng.broadcast(w, axes)
        right = ng.broadcast(w, axes)

        # full slice of both
        full_slice = [slice(None, None, None), slice(None, None, None)]
        left = ng.tensor_slice(left, full_slice)
        right = ng.tensor_slice(right, full_slice)

        # cast right, add, then cast / dim-shuffle onto fresh axes
        total = ng.add(left, ng.cast_axes(right, axes))
        total = ng.axes_with_order(ng.cast_axes(total, result_axes),
                                   result_axes)

        cost = ng.sum(total, reduction_axes=total.axes)
        grad = ng.deriv(cost, w)

        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)
        cost_val = cost_comp()
        grad_val = grad_comp()

        # w appears twice in the sum: cost = 2 * 3 ones, d(cost)/dw = 2
        assert cost_val == 6.0
        assert np.array_equal(grad_val, np.ones((3, 1)) * 2.)
def test_shuffled_deriv(transformer_factory):
    """
    Derivative through cast -> expand_dims -> reorder.

    This gets the axes of a delta in a generate_add_delta in a different
    order than the value being updated.
    """
    C, T, R, S = (ng.make_axis(length=n) for n in (3, 1, 5, 5))

    rsc = [R, S, C]
    v = ng.variable([ng.make_axis(axis.length) for axis in rsc])
    ctrs = ng.axes_with_order(
        ng.expand_dims(ng.cast_axes(v, rsc), T, 0), axes=[C, T, R, S])

    cost = ng.sum(ctrs, out_axes=None)
    grad = ng.deriv(cost, v)

    with ExecutorFactory() as ex:
        ex.executor(grad)()
def reorder_pos_axes(x, prefix=POS_AXIS_PREFIX):
    """
    Reorder x's axes to descending positional axes.

    E.g. x's axes: [POS_1, POS_2, POS_0] => [POS_2, POS_1, POS_0]

    Args:
        x: ngraph op whose axes are all named ``<prefix><int>``.
        prefix: name prefix that identifies positional axes.

    Returns:
        x reordered to descending positional axes, or ``x`` itself when it
        is already in that order.

    Raises:
        ValueError: if an axis name lacks ``prefix``, or if the positions are
            not exactly 0 .. num_axes - 1.
    """
    # get axes names
    axes_names = [axis.name for axis in x.axes]
    num_axes = len(axes_names)

    # check axes names are valid
    for name in axes_names:
        if not name.startswith(prefix):
            raise ValueError("axis {} is not a valid positional axes, "
                             "to be valid, must have prefix {}".format(
                                 name, prefix))
    axes_positions = [int(name[len(prefix):]) for name in axes_names]
    if sorted(axes_positions) != list(range(num_axes)):
        raise ValueError("axes positions {} must be continuous integers "
                         "starting from 0".format(axes_positions))

    # Materialize the descending order as a list: comparing a list against
    # a `reversed(...)` iterator is always False.
    descending = list(range(num_axes - 1, -1, -1))

    # special case, x is already in a good order
    if axes_positions == descending:
        return x

    # position -> length map, then lengths in descending-position order
    map_pos_length = dict(zip(axes_positions, x.axes.lengths))
    new_shapes = [map_pos_length[pos] for pos in descending]
    return ng.axes_with_order(x, axes=make_pos_axes(new_shapes))
def test_shuffled_deriv():
    """
    Derivative through cast -> expand_dims -> reorder.

    This gets the axes of a delta in a generate_add_delta in a different
    order than the value being updated.
    """
    ax = ng.make_name_scope("ax")
    ax.C = ng.make_axis(3)
    ax.T = ng.make_axis(1)
    ax.R = ng.make_axis(5)
    ax.S = ng.make_axis(5)

    axes = [ax.R, ax.S, ax.C]
    v = ng.variable([ng.make_axis(_.length) for _ in axes])
    rsc = ng.cast_axes(v, axes)
    trsc = ng.expand_dims(rsc, ax.T, 0)
    ctrs = ng.axes_with_order(trsc, axes=[ax.C, ax.T, ax.R, ax.S])
    cost = ng.sum(ctrs, out_axes=None)
    grad = ng.deriv(cost, v)

    # Use the factory as a context manager (as the other tests do) so it is
    # cleaned up on exit.
    with ExecutorFactory() as ex:
        d_fun = ex.executor(grad)
        d_fun()
def test_dimshuffle_fprop(transformer_factory):
    """
    Reorder a 2x3 placeholder to (B, A) and check the forward result equals
    the numpy transpose of the input.
    """
    A, B = ng.make_axis(2), ng.make_axis(3)
    x = ng.placeholder(ng.make_axes([A, B]))

    shuffled = ng.axes_with_order(x, [B, A])
    assert shuffled.axes == ng.make_axes([B, A])

    # random input values between -1 and 1
    x_value = rng.uniform(-1, 1, x.axes)
    run = executor(shuffled, x)
    np.testing.assert_allclose(run(x_value), x_value.T)
def sparse_softmax_cross_entropy_with_logits(labels=None, logits=None,
                                             name=None):
    """
    Computes softmax cross entropy between unscaled logits and sparse labels.

    The inputs `logits` are unscaled log probabilities. `labels` holds one
    class index per example; it is one-hot encoded along the class axis of
    `logits`.

    Args:
        labels: of axis (N,) for (POS_0,)
        logits: of axis (N, Y) for (POS_1, POS_0)
        name: name of the ngraph op

    Returns:
        Cross entropy of axis (N,), cast to positional axes and named `name`.

    Raises:
        NotImplementedError: unless logits are (N, Y) and labels are (N,),
            including when either argument is missing.
    """
    # Check input dimension
    #         ( N, Y),         ( N)
    # logits: (pos_1, pos_0), labels: (pos_0)
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    try:
        shapes_ok = (len(logits.axes) == 2
                     and len(labels.axes) == 1
                     and logits.axes[0].length == labels.axes[0].length)
    except AttributeError:  # e.g. labels or logits left as None
        shapes_ok = False
    if not shapes_ok:
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")

    # get axis
    axis_n, axis_y = logits.axes

    # convert labels to one-hot labels laid out like logits
    labels = ng.cast_axes(labels, ng.make_axes(axis_n))
    labels = ng.one_hot(labels, axis=axis_y)
    labels = ng.axes_with_order(labels, axes=logits.axes)

    # predicts: (N, Y)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # cross_entropy: (N)
    res = ng.cross_entropy_multi(predicts, labels, out_axes=(axis_n, ))
    return cast_to_pos_axes(res).named(name)
def test_fail_on_missing_and_extra_axis(self, x, A, C):
    """
    Reordering ``x`` to ``[A, C]`` must raise ValueError. Per the test name,
    this axis list omits one of x's axes and contains one x does not have.
    """
    bad_order = [A, C]
    with pytest.raises(ValueError):
        ng.axes_with_order(x, bad_order)
def MaxPool(self, tf_node, inputs):
    """
    Performs the max pooling on the input.

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        NotImplementedError: for a non-NHWC data format, pooling or striding
            over the batch or channel axis, or asymmetric padding.
        ValueError: if ksize or strides do not have length 4.

    Inputs to tf_node:
        input

    TODO: assume default tensorflow layout NHWC, RSCK,
          need to support NCHW as well
          need to clean up / merge with conv2d

    Axes:
                  Tensorflow          Ngraph
        in       (N, H, W, C)     (C, D, H, W, N)
        out      (N, P, Q, K)     (K, M, P, Q, N)

    Notes on output shape:
        https://www.tensorflow.org/api_docs/python/nn.html#convolution
    """
    image = inputs[0]

    # TODO: currently NHWC only
    if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
        raise NotImplementedError("Only supports NHWC import for now.")

    # new axes
    C, D, H, W, K, M, P, Q = [ng.make_axis() for _ in range(8)]
    N = ng.make_axis(name='N')
    D.length, M.length = 1, 1  # only supports 2D conv for now

    # tf's input axes; set_shape fills in the lengths of N, H, W, C
    ax_i_tf = ng.make_axes([N, H, W, C])
    ax_i_tf.set_shape(image.axes.lengths)

    # ksize params
    tf_ksize = [int(s) for s in tf_node.attr['ksize'].list.i]
    if len(tf_ksize) != 4:
        raise ValueError("Length of ksize must be 4.")
    if tf_ksize[0] != 1:
        raise NotImplementedError('Ksize on batch axis (N) must be 1.')
    if tf_ksize[3] != 1:
        raise NotImplementedError('Ksize on channel axis (C) must be 1. '
                                  'Cross map pooling to be implemented.')
    R_length, S_length = tf_ksize[1:3]
    T_length = J_length = 1

    # strides params
    tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
    if len(tf_strides) != 4:
        raise ValueError("Length of strides must be 4.")
    if tf_strides[0] != 1:
        raise NotImplementedError('Strides on batch axis (N) must be 1.')
    if tf_strides[3] != 1:
        raise NotImplementedError('Strides on channel axis (C) must be 1.')
    str_h, str_w = tf_strides[1], tf_strides[2]

    # padding params; the pooling window is passed to the shared helpers as
    # an RSCK-style "filter shape" (R, S, C, C)
    padding = tf_node.attr['padding'].s.decode("ascii")
    window_shape = (R_length, S_length, C.length, C.length)
    pad_t, pad_b, pad_l, pad_r = common_conv2d_pool_padding(
        image.axes.lengths, window_shape, tf_strides, padding)
    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # pooling params
    params = dict(op='max',
                  pad_d=0, pad_h=pad_t, pad_w=pad_l, pad_c=0,
                  str_d=1, str_h=str_h, str_w=str_w, str_c=1,
                  J=J_length, T=T_length, R=R_length, S=S_length)

    # tf's output axes; set_shape fills in the lengths of N, P, Q, K
    ax_o_tf = ng.make_axes([N, P, Q, K])
    ax_o_tf.set_shape(common_conv2d_pool_output_shape(
        image.axes.lengths, window_shape, tf_strides, padding))

    # ngraph's i, o axes
    ax_i = ng.make_axes([C, D, H, W, N])
    ax_o = ng.make_axes([K, M, P, Q, N])

    # image NHWC -> CDHWN
    image = ng.cast_axes(image, ng.make_axes([N, H, W, C]))
    image = ng.expand_dims(image, D, 1)  # NHWC -> NDHWC
    image = ng.axes_with_order(image, ax_i)  # NDHWC -> CDHWN

    # pooling
    output = ng.pooling(params, image, axes=ax_o)

    # output KMPQN -> NPQK
    # KMPQN -> NMPQK
    output = ng.axes_with_order(output, ng.make_axes([N, M, P, Q, K]))
    # NMPQK -> NPQK (drop the length-1 M axis)
    output = ng.tensor_slice(
        output, [slice(None), 0, slice(None), slice(None), slice(None)])

    return output
def Conv2D(self, tf_node, inputs):
    """
    Computes a 2-D convolution given 4D input and filter tensors.

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        NotImplementedError: for a non-NHWC data format, striding over the
            batch or channel axis, or asymmetric padding.
        ValueError: if the image and filter C dimensions differ or strides
            do not have length 4.

    Inputs to tf_node:
        input, filter

    TODO: assume default tensorflow layout NHWC, RSCK,
          need to support NCHW as well
          need to clean up / merge with maxpool

    Axes:
                  Tensorflow          Ngraph
        in       (N, H, W, C)     (C, D, H, W, N)
        filter   (R, S, C, K)     (C, T, R, S, K)
        out      (N, P, Q, K)     (K, M, P, Q, N)

    Notes on output shape:
        https://www.tensorflow.org/api_docs/python/nn.html#convolution
    """
    image, weight = inputs

    # TODO: currently NHWC only
    if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
        raise NotImplementedError("Only supports NHWC import for now.")

    # check in_C == f_C
    if image.axes.lengths[3] != weight.axes.lengths[2]:
        raise ValueError("Image's C dimension (%s) must be equal to "
                         "filter's C dimension (%s)."
                         % (image.axes.lengths[3], weight.axes.lengths[2]))

    # strides params
    tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
    if len(tf_strides) != 4:
        raise ValueError("Length of strides must be 4.")
    if tf_strides[0] != 1:
        raise NotImplementedError('Strides on batch axis (N) must be 1.')
    if tf_strides[3] != 1:
        raise NotImplementedError('Strides on channel axis (C) must be 1.')
    str_h, str_w = tf_strides[1], tf_strides[2]

    # padding params
    padding = tf_node.attr['padding'].s.decode("ascii")
    pad_t, pad_b, pad_l, pad_r = common_conv2d_pool_padding(
        image.axes.lengths, weight.axes.lengths, tf_strides, padding)
    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # conv params
    params = dict(pad_d=0, pad_h=pad_t, pad_w=pad_l,
                  str_d=1, str_h=str_h, str_w=str_w,
                  dil_d=1, dil_h=1, dil_w=1)

    # new axes
    C, D, H, W, T, R, S, K, M, P, Q = [ng.make_axis() for _ in range(11)]
    N = ng.make_axis(name='N')
    D.length, T.length, M.length = 1, 1, 1  # only supports 2D conv for now

    # tf's i, f, o axes; set_shape fills in the axis lengths
    ax_i_tf = ng.make_axes([N, H, W, C])
    ax_f_tf = ng.make_axes([R, S, C, K])
    ax_o_tf = ng.make_axes([N, P, Q, K])
    ax_i_tf.set_shape(image.axes.lengths)
    ax_f_tf.set_shape(weight.axes.lengths)
    ax_o_tf.set_shape(common_conv2d_pool_output_shape(image.axes.lengths,
                                                      weight.axes.lengths,
                                                      tf_strides, padding))

    # ngraph's i, f, o axes
    ax_i = ng.make_axes([C, D, H, W, N])
    ax_f = ng.make_axes([C, T, R, S, K])
    ax_o = ng.make_axes([K, M, P, Q, N])

    # image NHWC -> CDHWN
    image = ng.cast_axes(image, ng.make_axes([N, H, W, C]))
    image = ng.expand_dims(image, D, 1)  # NHWC -> NDHWC
    image = ng.axes_with_order(image, ax_i)  # NDHWC -> CDHWN

    # weights RSCK -> CTRSK
    weight = ng.cast_axes(weight, ng.make_axes([R, S, C, K]))
    weight = ng.expand_dims(weight, T, 0)  # RSCK -> TRSCK
    weight = ng.axes_with_order(weight, ax_f)  # TRSCK -> CTRSK

    # convolution
    output = ng.convolution(params, image, weight, axes=ax_o)

    # output KMPQN -> NPQK
    # KMPQN -> NMPQK
    output = ng.axes_with_order(output, ng.make_axes([N, M, P, Q, K]))
    # NMPQK -> NPQK (drop the length-1 M axis)
    output = ng.tensor_slice(
        output, [slice(None), 0, slice(None), slice(None), slice(None)])

    return output
def __call__(self, in_obj, channel_axes="C", spatial_axes=("D", "H", "W"), **kwargs):
    """
    Apply the convolution to ``in_obj`` and return a result whose axes follow
    the input's original axis order.

    Arguments:
        in_obj (Op): Input op
        channel_axes (str): name of the expected channel axis type - defaults to "C"
        spatial_axes (tuple): names of expected depth, height and width axis types - defaults to "D", "H", and "W".
            A dict may be given instead, mapping any of "D", "H", "W" to the
            actual axis name; unmapped keys keep their default name. Falsy
            tuple entries also fall back to the default name.

    Returns:
        Op: the convolution output. Axes of ``in_obj`` come first in their
        original order. Length-1 axes introduced along the way are sliced
        away; any other new axes are appended at the end.

    Raises:
        ValueError: if ``spatial_axes`` is a tuple shorter than 3, or if the
            layer is already initialized and this input implies different
            filter axes.
    """
    # Normalize spatial_axes into a tuple of (depth, height, width) names.
    if isinstance(spatial_axes, dict):
        spatial_axes = tuple(
            spatial_axes.get(name, name) for name in ("D", "H", "W"))
    elif isinstance(spatial_axes, tuple):
        # NOTE(review): tuples longer than 3 are silently truncated by the
        # zip below; only shorter ones are rejected here.
        if len(spatial_axes) < 3:
            raise ValueError(
                "spatial_axes must have length 3 (e.g. ('D', 'H', 'W'))")
        spatial_axes = tuple(
            name if name else default
            for name, default in zip(spatial_axes, ("D", "H", "W")))

    # Remember the caller's axis order so the output can be restored to it.
    orig_axes = in_obj.axes
    in_obj = reorder_spatial_axes(in_obj, channel_axes, spatial_axes)
    channel_axes = in_obj.axes.get_by_names(channel_axes)
    spatial_axes = in_obj.axes.get_by_names(*spatial_axes)

    filter_axes = self._filter_axes(channel_axes, spatial_axes)

    # mark 'K' as a shadow axis for the initializers.
    axes_map = shadow_axes_map(filter_axes.find_by_name('K'))
    filter_axes = ng.make_axes([
        axis if axis.name != 'K' else list(axes_map.keys())[0]
        for axis in filter_axes
    ])

    if not self.initialized:
        # LABELS is a module-level name defined outside this block.
        if not self.weight_norm:
            self.W = ng.variable(axes=filter_axes,
                                 initial_value=self.init,
                                 metadata={"label": LABELS["weight"]}).named("W")
        else:
            # Weight normalization:
            #   W = g * v / sqrt(mean(v ** 2) + 1e-3)
            # where the mean reduces every axis except the shadow K axis.
            self.v = ng.variable(axes=filter_axes,
                                 initial_value=self.init,
                                 metadata={"label": LABELS["weight"]}).named("v")
            out_axes = ng.make_axes(
                [filter_axes.get_by_names("K__NG_SHADOW")])
            v_norm = ng.mean(ng.square(self.v), out_axes=out_axes)
            # NOTE(review): g uses the same initializer as v.
            self.g = ng.variable(axes=out_axes,
                                 initial_value=self.init,
                                 metadata={"label": LABELS["weight"]}).named("g")
            self.W = self.g * self.v * ng.reciprocal(
                ng.sqrt(v_norm + 1e-3))
    else:
        # Re-use is only allowed when the new input implies identical filter axes.
        if filter_axes != self.W.axes:
            raise ValueError(
                ("{layer_name} layer has already been initialized with an "
                 "input object which has resulted in filter axes: "
                 "{existing_filter_axes}. This new input object has axes: "
                 "{input_axes}, which implies the need for filter axes: "
                 "{new_filter_axes} which are different than the existing "
                 "filter axes.").format(
                    layer_name=self.name,
                    existing_filter_axes=self.W.axes,
                    input_axes=in_obj.axes,
                    new_filter_axes=filter_axes,
                ))

    output = ng.map_roles(
        self._conv_op(in_obj, channel_axes, spatial_axes), axes_map)
    # Reorder the output to match the input order
    output_axis_order = ng.make_axes(
        [output.axes.find_by_name(ax.name)[0] for ax in orig_axes])

    # Remove introduced axes. If their length is > 1, then perhaps they should be kept
    slices = [
        0 if (ax not in orig_axes) and ax.length == 1 else slice(None)
        for ax in output.axes
    ]
    output = ng.tensor_slice(output, slices)

    # New axes with length > 1 may have been introduced. Add them to the end.
    output_axis_order = output_axis_order | output.axes
    return ng.axes_with_order(output, output_axis_order)
def __call__(self, H_pr, h_ip, states, output=None, reset_cells=True, input_data=None): """ Arguments: ---------- H_pr : Encoding for question h_ip: Sliced input of paragraph encoding for a particular time step states: State of the LSTM cell output: previous hidden state input_data: the ArrayIterator object for training data (contains information of length of each sentence) """ # get recurrent axis for question rec_axis_pr = H_pr.axes.recurrent_axis() const_one = ng.constant(const=1, axes=[self.dummy_axis]) # if first word in a paragraph is encountered, assign the previous LSTM # hidden state as zeros if output is None: h_r_old = ng.constant(axes=[self.F, self.N], const=0) else: h_r_old = ng.cast_axes(output, [self.F, self.N]) # Compute attention vector sum_1 = ng.dot(self.W_q, H_pr) sum_1 = ng.cast_axes(sum_1, [self.hidden_rows, self.hidden_cols_ques, self.N]) int_sum1 = ng.dot(self.W_p, h_ip) int_sum2 = ng.dot(self.W_r, h_r_old) int_sum = int_sum1 + int_sum2 + self.b_p int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1) # making for the attention vector req_mask = ng.axes_with_order( ng.cast_axes(ng.dot(self.e_q2, input_data['question_len']), [self.hidden_rows, self.N, self.hidden_cols_ques]), [self.hidden_rows, self.hidden_cols_ques, self.N]) req_mask_2 = ng.axes_with_order( ng.cast_axes(ng.dot(const_one, input_data['question_len']), [self.N, self.hidden_cols_ques]), [self.hidden_cols_ques, self.N]) G_i_int = sum_1 + ng.multiply( req_mask, ng.axes_with_order( ng.dot(int_sum, self.e_q), [self.hidden_rows, self.hidden_cols_ques, self.N])) G_i = ng.tanh(G_i_int) # Attention Vector at_sum1 = ng.dot(self.w_lr, G_i) at = ng.softmax(at_sum1 + ng.log(req_mask_2)) at_repeated = ng.cast_axes( ng.dot(self.e_q2, ng.ExpandDims(at, self.dummy_axis, 0)), [self.F, rec_axis_pr, self.N]) # Stack the 2 vectors as per the equation in the paper z1 = h_ip z2 = ng.sum(ng.multiply(H_pr, at_repeated), rec_axis_pr) # represents the inp to lstm_cell # ng.concat_along_axis([z1,z2],self.F) 
inputs_lstm = ng.dot(self.ZX, z1) + ng.dot(self.ZY, z2) # LSTM cell computations (from LSTM brach in ngraph) if self.out_axes is None: self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis() if states is None: states = self.initialize_states(inputs_lstm.axes.batch_axis(), reset_cells=reset_cells) assert self.out_axes == states['h'].axes for gate in self._gate_names: transform = self.gate_transform[gate] gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate]( states['h']) self.gate_output[gate] = ng.cast_role(transform(gate_input), self.out_axes) states['c'] = (states['c'] * self.gate_output['f'] + self.gate_output['i'] * self.gate_output['g']) states['h'] = self.gate_output['o'] * self.activation(states['c']) states['h'] = ng.cast_role(states['h'], self.out_axes) # return unrolled output and state of LSTM cell return ng.cast_axes(states['h'], axes=[self.F, self.N]), states
def MaxPool(self, tf_node, inputs):
    """
    Performs the max pooling on the input.

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        NotImplementedError: for a non-NHWC data format, pooling or striding
            over the batch or channel axis, or asymmetric padding.
        ValueError: if ksize or strides do not have length 4.

    Inputs to tf_node:
        input

    TODO: assume default tensorflow layout NHWC, RSCK,
          need to support NCHW as well
          need to clean up / merge with conv2d

    Notes on output shape:
        https://www.tensorflow.org/api_docs/python/nn.html#convolution
    """
    image = inputs[0]

    # TODO: currently NHWC only
    if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
        raise NotImplementedError("Only supports NHWC import for now.")

    # set axes shape
    ax_N = ng.make_axis(batch=True)
    ax_C = ng.make_axis(roles=[ar.Channel])
    ax_D = ng.make_axis(roles=[ar.Depth])
    ax_H = ng.make_axis(roles=[ar.Height])
    ax_W = ng.make_axis(roles=[ar.Width])
    ng.make_axes([ax_N, ax_H, ax_W, ax_C]).set_shape(image.axes.lengths)
    ax_D.length = 1

    # ksize params
    tf_ksize = [int(s) for s in tf_node.attr['ksize'].list.i]
    if len(tf_ksize) != 4:
        raise ValueError("Length of ksize must be 4.")
    if tf_ksize[0] != 1:
        raise NotImplementedError('Ksize on batch axis (N) must be 1.')
    if tf_ksize[3] != 1:
        raise NotImplementedError('Ksize on channel axis (C) must be 1. '
                                  'Cross map pooling to be implemented.')
    R, S = tf_ksize[1:3]
    T = J = 1

    # strides params
    tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
    if len(tf_strides) != 4:
        raise ValueError("Length of strides must be 4.")
    if tf_strides[0] != 1:
        raise NotImplementedError('Strides on batch axis (N) must be 1.')
    if tf_strides[3] != 1:
        raise NotImplementedError('Strides on channel axis (C) must be 1.')
    str_h, str_w = tf_strides[1], tf_strides[2]

    # padding params; the pooling window is passed as an RSCK-style
    # "filter shape" (R, S, C, C)
    padding = tf_node.attr['padding'].s.decode("ascii")
    pad_t, pad_b, pad_l, pad_r = tf_conv2d_pool_padding(
        image.axes.lengths, (R, S, ax_C.length, ax_C.length),
        tf_strides, padding)
    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # pooling params
    params = dict(op='max',
                  pad_d=0, pad_h=pad_t, pad_w=pad_l, pad_c=0,
                  str_d=1, str_h=str_h, str_w=str_w, str_c=1,
                  J=J, T=T, R=R, S=S)

    # i, f, o axes
    ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
    ax_o = ng.make_axes([
        spatial_axis(ax_i, J, params['pad_c'], params['str_c'], ar.Channel),
        spatial_axis(ax_i, T, params['pad_d'], params['str_d'], ar.Depth),
        spatial_axis(ax_i, R, params['pad_h'], params['str_h'], ar.Height),
        spatial_axis(ax_i, S, params['pad_w'], params['str_w'], ar.Width),
        ax_N
    ])

    # broadcast input / filter axes
    image = ng.cast_axes(image, ng.make_axes([ax_N, ax_H, ax_W, ax_C]))
    image = ng.expand_dims(image, ax_D, 1)  # NHWC -> NDHWC
    image = ng.axes_with_order(image, axes=ax_i)  # NDHWC -> CDHWN

    # pooling
    output = ng.pooling(params, image, axes=ax_o)

    # reorder back towards NHWC: CDHWN -> NDHWC (a pure permutation, so use
    # axes_with_order as on the input path)
    oC, oD, oH, oW, oN = output.axes
    output = ng.axes_with_order(output,
                                axes=ng.make_axes([oN, oD, oH, oW, oC]))

    # slice away the oD
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    output = ng.Slice(output, out_slicing)
    return output
def Conv2D(self, tf_node, inputs):
    """
    Computes a 2-D convolution given 4D input and filter tensors.

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the tensorflow node.

    Raises:
        ValueError: if ``strides`` does not have 4 entries.
        NotImplementedError: if striding over the batch (N) or channel (C)
            axis is requested, or if the computed padding is asymmetric.

    Inputs to tf_node:
        input, filter

    TODO: assume default tensorflow layout NHWC, RSCK,
          need to support NCHW as well
          need to clean up / merge with maxpool

    Notes on output shape:
        https://www.tensorflow.org/api_docs/python/nn.html#convolution
    """
    image, weight = inputs

    # TODO: currently NHWC only
    assert tf_node.attr['data_format'].s.decode("ascii") == "NHWC"

    # set axes shape
    ax_N = ng.make_axis(batch=True)
    ax_C = ng.make_axis(roles=[ar.Channel])
    ax_D = ng.make_axis(roles=[ar.Depth])
    ax_H = ng.make_axis(roles=[ar.Height])
    ax_W = ng.make_axis(roles=[ar.Width])

    ax_T = ng.make_axis(roles=[ar.Depth])
    ax_R = ng.make_axis(roles=[ar.Height])
    ax_S = ng.make_axis(roles=[ar.Width])
    ax_K = ng.make_axis(roles=[ar.Channelout])

    ng.make_axes([ax_N, ax_H, ax_W, ax_C]).set_shape(image.axes.lengths)
    ng.make_axes([ax_R, ax_S, ax_C, ax_K]).set_shape(weight.axes.lengths)
    # singleton depth axes so input/filter fit the 5-D layouts used below
    ax_D.length = 1
    ax_T.length = 1

    # strides params
    tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
    if len(tf_strides) != 4:
        raise ValueError("Length of strides must be 4.")
    if tf_strides[0] != 1:
        raise NotImplementedError('Strides on batch axis (N) must be 1.')
    if tf_strides[3] != 1:
        raise NotImplementedError('Strides on channel axis (C) must be 1.')
    str_h, str_w = tf_strides[1], tf_strides[2]

    # padding params
    padding = tf_node.attr['padding'].s.decode("ascii")
    pad_t, pad_b, pad_l, pad_r = tf_conv2d_pool_padding(
        image.axes.lengths, weight.axes.lengths, tf_strides, padding)
    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # conv params
    params = dict(pad_d=0, pad_h=pad_t, pad_w=pad_l,
                  str_d=1, str_h=str_h, str_w=str_w)

    # i, f, o axes
    ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
    ax_f = ng.make_axes([ax_C, ax_T, ax_R, ax_S, ax_K])
    ax_o = ng.make_axes([
        ng.make_axis(ax_K.length, name='C', roles=[ar.Channel]),
        spatial_axis(ax_i, ax_f, params['pad_d'], params['str_d'], ar.Depth),
        spatial_axis(ax_i, ax_f, params['pad_h'], params['str_h'], ar.Height),
        spatial_axis(ax_i, ax_f, params['pad_w'], params['str_w'], ar.Width),
        ax_N
    ])

    # broadcast input / filter axes
    image = ng.cast_axes(image, ng.make_axes([ax_N, ax_H, ax_W, ax_C]))
    image = ng.expand_dims(image, ax_D, 1)  # NHWC -> NDHWC
    image = ng.axes_with_order(image, axes=ax_i)  # NDHWC -> CDHWN
    weight = ng.cast_axes(weight, ng.make_axes([ax_R, ax_S, ax_C, ax_K]))
    weight = ng.expand_dims(weight, ax_T, 0)  # RSCK -> TRSCK
    weight = ng.axes_with_order(weight, axes=ax_f)  # TRSCK -> CTRSK

    # convolution
    output = ng.convolution(params, image, weight, axes=ax_o)

    # cast back to NHWC
    oC, oD, oH, oW, oN = output.axes
    output = ng.broadcast(output, ng.make_axes([oN, oD, oH, oW, oC]))

    # slice away the oD
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    output = ng.Slice(output, out_slicing)

    return output
def test_fail_on_axis_reuse(self, x, A, B):
    """Reordering onto an axis list that repeats an axis must raise ValueError."""
    repeated_axes = [A, B, B]  # B listed twice
    with pytest.raises(ValueError):
        ng.axes_with_order(x, repeated_axes)
# Set up drop out layer dropout_val = ng.slice_along_axis(inputs['dropout_val'], N, 0) dropout_1 = Dropout_Modified(keep=dropout_val) dropout_2 = Dropout_Modified(keep=dropout_val) drop_pointer = ng.maximum(dropout_val, ng.constant(const=0.8, axes=[])) dropout_3 = Dropout_Modified(keep=drop_pointer) dropout_4 = Dropout_Modified(keep=drop_pointer) # Constants required for masking const_LSTM = ng.constant(axes=[F, dummy_axis], const=1) const_loss = ng.constant(axes=[ax.Y, dummy_axis], const=1) const_LSTM_embed = ng.constant(axes=[F_embed, dummy_axis], const=1) # Create masks reorder_para_mask = ng.axes_with_order( inputs['para_len'], axes=[ dummy_axis, inputs['para_len'].axes[2], N]) reorder_ques_mask = ng.axes_with_order( inputs['question_len'], axes=[ dummy_axis, inputs['question_len'].axes[2], N]) # Masks for question and para after encoding layer mask_para = ng.dot(const_LSTM, reorder_para_mask) mask_question = ng.dot(const_LSTM, ng.cast_axes(reorder_ques_mask, [dummy_axis, REC, N])) # Masks for question and para after embedding/LookupTable layer mask_para_embed = ng.dot(const_LSTM_embed, reorder_para_mask) mask_question_embed = ng.dot( const_LSTM_embed, ng.cast_axes(
def test_fail_on_missing(self, x, B):
    """
    Reordering onto an axis list that omits one of ``x``'s axes must raise
    ValueError.

    The target list is ``[B]`` rather than ``[B, B]``: a repeated axis is
    rejected on its own (test_fail_on_axis_reuse), so a duplicate here would
    make the test pass without ever reaching the missing-axis check.
    NOTE(review): assumes the ``x`` fixture has at least one axis besides
    ``B`` (``A`` in the sibling test) — confirm against the fixture.
    """
    with pytest.raises(ValueError):
        ng.axes_with_order(x, [B])
def __call__(self, H_concat, states=None, output=None, reset_cells=True,
             input_data=None):
    """
    Run the answer-pointer LSTM for two steps and return the attention
    distribution over paragraph positions produced at each step.

    Arguments:
    ----------
    H_concat: Concatenated forward and reverse unrolled outputs of the
              `MatchLSTMCell_withAttention` cell
    states: previous LSTM state, a dict with 'h' and 'c' entries; built with
            ``self.initialize_states`` when None. A caller-supplied dict is
            updated in place.
    output: hidden state from previous timestep; zeros are used when None
    reset_cells: argument to reset a cell, forwarded to
                 ``self.initialize_states``
    input_data: the training data (described originally as the ArrayIterator
                data holding sentence lengths); read here as
                ``input_data['para_len']``

    Returns:
    --------
    A list with two entries: the masked softmax attention from each step.
    """
    rec_axis_pr = H_concat.axes.recurrent_axis()
    const_one = ng.constant(const=1, axes=[self.dummy_axis])

    # The projected passage term and the paragraph mask do not depend on the
    # step, so they are built once rather than once per iteration.
    sum_1 = ng.dot(
        self.V_answer,
        ng.cast_axes(H_concat, [self.lstm_feature_new, rec_axis_pr, self.N]))
    sum_1 = ng.cast_axes(
        sum_1, [self.hidden_rows, self.hidden_cols_para, self.N])

    # This masking with -inf for length of para>max_para ensures that
    # when we do softmax over these values we get a 0
    mask_loss_new = ng.log(ng.dot(const_one, input_data['para_len']))
    mask_loss_new = ng.axes_with_order(
        ng.cast_axes(mask_loss_new, [self.N, self.hidden_cols_para]),
        [self.hidden_cols_para, self.N])

    b_k_lists = []
    for _ in range(2):
        if output is None:
            h_k_old = ng.constant(axes=[self.F, self.N], const=0)
        else:
            h_k_old = ng.cast_axes(output, [self.F, self.N])

        # Following notations from the paper
        # Compute Attention Vector
        int_sum = ng.ExpandDims(ng.dot(self.W_a, h_k_old), self.dummy_axis, 1)
        F_i_int = sum_1 + ng.axes_with_order(
            ng.dot(int_sum, self.e_q),
            [self.hidden_rows, self.hidden_cols_para, self.N])
        F_i = ng.tanh(F_i_int)

        # Attention Vector
        b_k_sum1 = ng.dot(self.v_lr, F_i)

        # Add mask to the required logits (built once; also the step output)
        b_k = ng.softmax(b_k_sum1 + mask_loss_new)

        b_k_repeated = ng.cast_axes(
            ng.dot(self.e_q2, ng.ExpandDims(b_k, self.dummy_axis, 0)),
            [H_concat.axes[0], rec_axis_pr, self.N])
        inputs_lstm = ng.sum(ng.multiply(H_concat, b_k_repeated), rec_axis_pr)

        # LSTM Cell calculations
        if self.out_axes is None:
            self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
        if states is None:
            states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                            reset_cells=reset_cells)
        assert self.out_axes == states['h'].axes

        for gate in self._gate_names:
            transform = self.gate_transform[gate]
            gate_input = (self.i2h[gate](inputs_lstm) +
                          self.h2h[gate](states['h']))
            self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                  self.out_axes)

        states['c'] = (states['c'] * self.gate_output['f'] +
                       self.gate_output['i'] * self.gate_output['g'])
        states['h'] = self.gate_output['o'] * self.activation(states['c'])
        states['h'] = ng.cast_role(states['h'], self.out_axes)
        # the new hidden state drives the next step's attention
        output = states['h']

        # append required outputs
        b_k_lists.append(b_k)

    return b_k_lists
def Conv(self, c2_op, inputs):
    """
    Computes a 2-D convolution given 4D input and filter tensors, then adds
    the bias along the output-channel axis.

    Arguments:
        c2_op: the caffe2 operator to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the caffe2 node.

    Raises:
        ValueError: on repeated ``order`` args, a bias that is not 1-D or
            whose length differs from the number of output feature maps, or
            a stride that is not a single scalar.
        NotImplementedError: on an unsupported ``order`` or asymmetric
            padding.

    Inputs to c2_op:
        input, weights, bias

    Supports caffe2's layout NHWC and NCHW as well.
    """
    X, W, bias = inputs

    orders = [val.s for val in c2_op.arg if val.name == "order"]
    if len(orders) > 1:
        raise ValueError("Multiple order values in convolution")
    # caffe2's conv/pool ops fall back to NCHW when no order is given.
    order = orders[0] if orders else "NCHW"
    # The protobuf string field is bytes under Python 3; decode so the
    # comparisons and dict lookups below match the str keys.
    if isinstance(order, bytes):
        order = order.decode("ascii")
    if order not in ("NHWC", "NCHW"):
        raise NotImplementedError(
            "Unsupported order in convolution: {}".format(order))

    # set input axes shape
    ax_N = ng.make_axis(name='N')
    ax_C = ng.make_axis()
    ax_D = ng.make_axis(length=1)
    ax_H = ng.make_axis()
    ax_W = ng.make_axis()

    # set kernel axes shape
    ax_kernel_D = ng.make_axis(length=1)
    ax_kernel_H = ng.make_axis()
    ax_kernel_W = ng.make_axis()
    ax_kernel_ofm = ng.make_axis()

    # create placeholders for output axes
    oC = ng.make_axis(name='C')
    oD = ng.make_axis(name='D', length=1)
    oH = ng.make_axis(name='H')
    oW = ng.make_axis(name='W')

    axes_order = {
        'NCHW': {
            'X': [ax_N, ax_C, ax_H, ax_W],
            'W': [ax_kernel_ofm, ax_C, ax_kernel_H, ax_kernel_W]
        },
        'NHWC': {
            'X': [ax_N, ax_H, ax_W, ax_C],
            'W': [ax_kernel_ofm, ax_kernel_H, ax_kernel_W, ax_C]
        },
    }

    ng.make_axes(axes_order[order]['X']).set_shape(X.axes.lengths)
    ng.make_axes(axes_order[order]['W']).set_shape(W.axes.lengths)

    if 1 != len(bias.axes):
        raise ValueError("Bias's must be 1D.")
    if ax_kernel_ofm.length != bias.axes.lengths[0]:
        raise ValueError(
            "Bias's length must equal to number of output feature maps.")

    # strides params
    stride_size = [int(val.i) for val in c2_op.arg if val.name == "stride"]
    if len(stride_size) != 1:
        raise ValueError("Stride size must be scalar value")
    str_h = str_w = stride_size[0]

    # padding params
    pad_t, pad_b, pad_l, pad_r = _c2_padding(
        c2_op,
        in_NHWC=[ax_N.length, ax_H.length, ax_W.length, ax_C.length],
        kernel_HWIO=[ax_kernel_H.length, ax_kernel_W.length,
                     ax_C.length, ax_kernel_ofm.length],
        stride_NHWC=[1, str_h, str_w, 1])

    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # conv params
    params = dict(pad_d=0, pad_h=pad_t, pad_w=pad_l,
                  str_d=1, str_h=str_h, str_w=str_w,
                  dil_d=1, dil_h=1, dil_w=1)

    # input, weight, output axes
    internal_ax_dict = {
        'X': ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N]),
        'W': ng.make_axes(
            [ax_C, ax_kernel_D, ax_kernel_H, ax_kernel_W, ax_kernel_ofm])
    }

    oC.length = ax_kernel_ofm.length
    oH.length = output_dim(ax_H.length, ax_kernel_H.length,
                           params['pad_h'], params['str_h'])
    oW.length = output_dim(ax_W.length, ax_kernel_W.length,
                           params['pad_w'], params['str_w'])
    internal_ax_dict['Y'] = ng.make_axes([oC, oD, oH, oW, ax_N])

    # broadcast input / filter axes
    # flow for NHWC order:                  | flow for NCHW order:
    # input:                                | input:
    #   expand dims: NHWC -> NDHWC          |   expand dims: NCHW -> NDCHW
    #   reorder: NDHWC -> CDHWN             |   reorder: NDCHW -> CDHWN
    # weights:                              | weights:
    #   expand dims: (ofm)HWC -> D(ofm)HWC  |   expand dims: (ofm)CHW -> D(ofm)CHW
    #   reorder: D(ofm)HWC -> CDHW(ofm)     |   reorder: D(ofm)CHW -> CDHW(ofm)

    X = ng.cast_axes(X, ng.make_axes(axes_order[order]['X']))
    X = ng.expand_dims(X, ax_D, 1)
    X = ng.axes_with_order(X, axes=internal_ax_dict['X'])
    W = ng.cast_axes(W, ng.make_axes(axes_order[order]['W']))
    W = ng.expand_dims(W, ax_kernel_D, 0)
    W = ng.axes_with_order(W, axes=internal_ax_dict['W'])

    # convolution
    Y = ng.convolution(params, X, W, axes=internal_ax_dict['Y'])

    # cast back to proper format
    if "NHWC" == order:
        Y = ng.broadcast(Y, ng.make_axes([ax_N, oD, oH, oW, oC]))
    else:  # NCHW
        Y = ng.broadcast(Y, ng.make_axes([ax_N, oD, oC, oH, oW]))

    # slice away the oD
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    Y = ng.tensor_slice(Y, out_slicing)

    # bias add: after slicing, the channel axis is index 1 (NCHW) or 3 (NHWC)
    channel_axis = Y.axes[1 if 'NCHW' == order else 3]
    bias = ng.cast_axes(bias, axes=ng.make_axes([channel_axis]))
    return ng.Add(Y, bias)
def Pool(self, c2_op, inputs):
    """
    Performs max or average pooling on the input.

    Arguments:
        c2_op: the caffe2 operator to convert (``MaxPool`` or
            ``AveragePool``).
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        A ngraph Op corresponding to the c2_op node.

    Raises:
        ValueError: on repeated ``order`` args, or a kernel/stride that is
            not a single scalar.
        NotImplementedError: if ``order`` is not NCHW, or the padding is
            asymmetric.

    Inputs to c2_op:
        input
    """
    supported_pooling = {'MaxPool': 'max', 'AveragePool': 'avg'}

    image = inputs[0]

    # Only NCHW is handled below. Reject other layouts explicitly rather than
    # silently reading e.g. NHWC data as NCHW.
    orders = [val.s for val in c2_op.arg if val.name == "order"]
    if len(orders) > 1:
        raise ValueError("Multiple order values in pooling")
    order = orders[0] if orders else "NCHW"
    # protobuf string fields are bytes under Python 3
    if isinstance(order, bytes):
        order = order.decode("ascii")
    if order != "NCHW":
        raise NotImplementedError(
            "Unsupported order in pooling: {}".format(order))

    # set input axes shape
    ax_N = ng.make_axis(name='N')
    ax_C = ng.make_axis()
    ax_D = ng.make_axis(length=1)
    ax_H = ng.make_axis()
    ax_W = ng.make_axis()
    ng.make_axes([ax_N, ax_C, ax_H, ax_W]).set_shape(image.axes.lengths)

    # create placeholders for output axes
    oC = ng.make_axis(name='C')
    oD = ng.make_axis(length=1, name='D')
    oH = ng.make_axis(name='H')
    oW = ng.make_axis(name='W')

    # spatial kernel size
    kernel_size = [int(val.i) for val in c2_op.arg if val.name == "kernel"]
    if len(kernel_size) != 1:
        raise ValueError("Kernel size must be scalar value")
    # kernel is square
    kernel_h = kernel_w = kernel_size[0]
    kernel_d = kernel_c = 1

    # strides params
    stride_size = [int(val.i) for val in c2_op.arg if val.name == "stride"]
    if len(stride_size) != 1:
        raise ValueError("Stride size must be scalar value")
    stride_h = stride_w = stride_size[0]

    # padding params
    pad_t, pad_b, pad_l, pad_r = _c2_padding(
        c2_op,
        in_NHWC=[ax_N.length, ax_H.length, ax_W.length, ax_C.length],
        kernel_HWIO=[kernel_h, kernel_w, ax_C.length, ax_C.length],
        stride_NHWC=[1, stride_h, stride_w, 1])

    if pad_t != pad_b or pad_l != pad_r:
        raise NotImplementedError("Requires symmetric padding in ngraph: "
                                  "pad_t(%s) == pad_b(%s) and "
                                  "pad_l(%s) == pad_r(%s)" %
                                  (pad_t, pad_b, pad_l, pad_r))

    # pooling params
    params = dict(op=supported_pooling[c2_op.type],
                  pad_d=0, pad_h=pad_t, pad_w=pad_l, pad_c=0,
                  str_d=1, str_h=stride_h, str_w=stride_w, str_c=1,
                  J=kernel_c, T=kernel_d, R=kernel_h, S=kernel_w)

    # i, o axes
    oC.length = output_dim(ax_C.length, kernel_c,
                           params['pad_c'], params['str_c'])
    oD.length = output_dim(ax_D.length, kernel_d,
                           params['pad_d'], params['str_d'])
    oH.length = output_dim(ax_H.length, kernel_h,
                           params['pad_h'], params['str_h'])
    oW.length = output_dim(ax_W.length, kernel_w,
                           params['pad_w'], params['str_w'])
    ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
    ax_o = ng.make_axes([oC, oD, oH, oW, ax_N])

    # broadcast input / filter axes
    image = ng.cast_axes(image, ng.make_axes([ax_N, ax_C, ax_H, ax_W]))
    image = ng.expand_dims(image, ax_D, 1)  # NCHW -> NDCHW
    image = ng.axes_with_order(image, axes=ax_i)  # NDCHW -> CDHWN

    # pooling
    output = ng.pooling(params, image, axes=ax_o)

    # cast back to NDCHW
    output = ng.broadcast(output, ng.make_axes([ax_N, oD, oC, oH, oW]))

    # slice away the oD
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    output = ng.tensor_slice(output, out_slicing)

    return output