def batch_dot(x, y, axes=None, name=None):
    """Batchwise dot product of two Keras nodes.

    Axis 0 of both operands is a shared axis that is carried through to the
    output. Axis ``axes[0]`` of ``x`` is contracted against axis ``axes[1]``
    of ``y``: they share one index, which is accumulated over by the ``+=``
    contraction.

    The output dimensions are:
      * all of ``x``'s dimensions except ``axes[0]``, followed by
      * ``y``'s dimensions except axis 0 and ``axes[1]``.

    A rank-1 result is expanded to rank 2 with
    ``plaidml_op.expand_dims(result, 1)``.

    Args:
        x: Left operand; its ``.tensor`` attribute is used.
        y: Right operand; its ``.tensor`` attribute is used.
        axes: An int (applied to both operands) or a pair
            ``(x_axis, y_axis)`` naming the axes to contract. Negative
            values count from the end. Defaults to
            ``(x_ndim - 1, y_ndim - 2)``.
        name: Optional name forwarded to the resulting ``_KerasNode``.

    Returns:
        A ``_KerasNode`` labelled ``'batch_dot'`` wrapping the result tensor.

    Raises:
        Whatever ``_report_unimplemented`` raises (presumably
        NotImplementedError — confirm) when the environment variable
        ``PLAIDML_BATCHDOT_TF_BEHAVIOR`` is set: TF-style semantics are not
        implemented here.
    """
    X = x.tensor
    Y = y.tensor
    x_ndim = X.shape.ndims
    y_ndim = Y.shape.ndims

    if isinstance(axes, six.integer_types):
        axes = (axes, axes)
    if axes is None:
        axes = (x_ndim - 1, y_ndim - 2)

    # Normalize negative axes. The output comprehensions below compare
    # against non-negative positions, so an un-normalized -1 would never be
    # excluded and the contracted axis would leak into the output unsummed.
    x_axis = axes[0] + x_ndim if axes[0] < 0 else axes[0]
    y_axis = axes[1] + y_ndim if axes[1] < 0 else axes[1]

    if os.getenv('PLAIDML_BATCHDOT_TF_BEHAVIOR'):
        _report_unimplemented('batch_dot')
    else:
        # Replicate Theano / documentation-specified behavior.
        # The shared leading (batch) axis of both operands:
        batch_dim = edsl.TensorDim()
        batch_idx = edsl.TensorIndex()
        # The contracted axis (x_axis of X paired with y_axis of Y):
        red_dim = edsl.TensorDim()
        red_idx = edsl.TensorIndex()

        xdims = edsl.TensorDims(x_ndim)
        xdims[0] = batch_dim
        xdims[x_axis] = red_dim
        xidxs = edsl.TensorIndexes(x_ndim)
        xidxs[0] = batch_idx
        xidxs[x_axis] = red_idx

        ydims = edsl.TensorDims(y_ndim)
        ydims[0] = batch_dim
        ydims[y_axis] = red_dim
        yidxs = edsl.TensorIndexes(y_ndim)
        yidxs[0] = batch_idx
        yidxs[y_axis] = red_idx

        # Output keeps X's axes minus the contracted one, then Y's axes minus
        # the shared axis 0 and the contracted one. red_idx is absent from the
        # output indices, so the += below accumulates over it.
        x_keep = [n for n in range(len(xdims)) if n != x_axis]
        y_keep = [n for n in range(1, len(ydims)) if n != y_axis]
        odims = [xdims[n] for n in x_keep] + [ydims[n] for n in y_keep]
        oidxs = [xidxs[n] for n in x_keep] + [yidxs[n] for n in y_keep]

        X.bind_dims(*xdims)
        Y.bind_dims(*ydims)
        O = edsl.TensorOutput(*odims)
        O[oidxs] += X[xidxs] * Y[yidxs]

    # A rank-1 result is promoted to rank 2.
    if len(odims) == 1:
        O = plaidml_op.expand_dims(O, 1)
    return _KerasNode('batch_dot', name=name, tensor=O)
def expand_dims(x, axis=-1, name=None):
    """Wrap ``plaidml_op.expand_dims`` applied to ``x``'s tensor in a Keras node.

    Logs the call at DEBUG level before delegating.

    Args:
        x: Input node; its ``.tensor`` attribute is passed to
            ``plaidml_op.expand_dims``.
        axis: Position forwarded to ``plaidml_op.expand_dims`` (default -1),
            presumably where the new size-1 axis is inserted.
        name: Optional name for the returned ``_KerasNode``.

    Returns:
        A ``_KerasNode`` labelled ``'expand_dims'``.
    """
    # Pass values as logging args instead of pre-formatting. The message, and
    # str(x), are then only built when DEBUG logging is actually enabled.
    logger.debug('expand_dims(x: %s, axis: %s, name=%s)', x, axis, name)
    return _KerasNode('expand_dims', name=name, tensor=plaidml_op.expand_dims(x.tensor, axis))
def expand_dims(x, axis=-1, name=None):
    """Wrap ``plaidml_op.expand_dims`` applied to ``x``'s tensor in a Keras node.

    Args:
        x: Input node; its ``.tensor`` attribute is passed to
            ``plaidml_op.expand_dims``.
        axis: Position forwarded to ``plaidml_op.expand_dims`` (default -1),
            presumably where the new size-1 axis is inserted.
        name: Optional name for the returned ``_KerasNode``.

    Returns:
        A ``_KerasNode`` labelled ``'expand_dims'``.
    """
    # NOTE(review): this name is also defined (with debug logging) immediately
    # above; if both live in the same module, this later definition wins.
    expanded = plaidml_op.expand_dims(x.tensor, axis)
    return _KerasNode('expand_dims', tensor=expanded, name=name)