Exemplo n.º 1
0
def batch_dot(x, y, axes=None, name=None):
    """Build a batched dot-product node contracting one axis of x with one of y.

    Axis 0 of both operands is bound to the same dimension and index, so it
    acts as the shared batch axis.  ``axes[0]`` of ``x`` and ``axes[1]`` of
    ``y`` are bound to a second shared dimension/index that is left out of
    the output and therefore summed over by the ``+=`` contraction.  The
    output keeps the remaining dimensions of ``x`` (batch axis included)
    followed by the remaining non-batch dimensions of ``y``.

    Args:
        x: Wrapper exposing the left operand as ``x.tensor``.
        y: Wrapper exposing the right operand as ``y.tensor``.
        axes: Contraction axes.  A single integer ``a`` means ``(a, a)``;
            ``None`` selects the last axis of ``x`` and the second-to-last
            axis of ``y``.  Negative entries count from the end.
        name: Optional name forwarded to the resulting ``_KerasNode``.

    Returns:
        A ``_KerasNode`` labelled ``'batch_dot'``.  When the contraction
        leaves only one output dimension, the result is passed through
        ``plaidml_op.expand_dims(O, 1)`` first.

    Raises:
        NotImplementedError: If the ``PLAIDML_BATCHDOT_TF_BEHAVIOR``
            environment variable is set to a non-empty value (unless
            ``_report_unimplemented`` raises something itself first).
    """
    X = x.tensor
    Y = y.tensor
    x_rank = X.shape.ndims
    y_rank = Y.shape.ndims
    if isinstance(axes, six.integer_types):
        axes = (axes, axes)
    if axes is None:
        axes = (x_rank - 1, y_rank - 2)

    if os.getenv('PLAIDML_BATCHDOT_TF_BEHAVIOR'):
        _report_unimplemented('batch_dot')
        # NOTE(review): _report_unimplemented is defined elsewhere; if it
        # returns instead of raising, fail explicitly here rather than with a
        # NameError on the never-assigned output below.
        raise NotImplementedError(
            'batch_dot does not implement PLAIDML_BATCHDOT_TF_BEHAVIOR')

    # Normalize in-range negative axes.  Negative values index the dim lists
    # fine, but would never compare equal to the non-negative positions used
    # to drop the contracted axis from the output.  Out-of-range values are
    # left alone so they still fail with IndexError below.
    x_axis = axes[0]
    y_axis = axes[1]
    if -x_rank <= x_axis < 0:
        x_axis += x_rank
    if -y_rank <= y_axis < 0:
        y_axis += y_rank

    # replicate theano/documentation-specified behavior
    batch_dim = edsl.TensorDim()
    batch_idx = edsl.TensorIndex()
    contract_dim = edsl.TensorDim()
    contract_idx = edsl.TensorIndex()

    # The contraction axis is assigned after the batch axis, so it wins if
    # both refer to position 0 (same precedence as before).
    xdims = edsl.TensorDims(x_rank)
    xdims[0] = batch_dim
    xdims[x_axis] = contract_dim
    xidxs = edsl.TensorIndexes(x_rank)
    xidxs[0] = batch_idx
    xidxs[x_axis] = contract_idx

    ydims = edsl.TensorDims(y_rank)
    ydims[0] = batch_dim
    ydims[y_axis] = contract_dim
    yidxs = edsl.TensorIndexes(y_rank)
    yidxs[0] = batch_idx
    yidxs[y_axis] = contract_idx

    # Output: every x axis except the contracted one, then every y axis
    # except the batch axis (already contributed by x) and the contracted one.
    odims = ([d for n, d in enumerate(xdims) if n != x_axis] +
             [d for n, d in enumerate(ydims) if n not in (0, y_axis)])
    oidxs = ([i for n, i in enumerate(xidxs) if n != x_axis] +
             [i for n, i in enumerate(yidxs) if n not in (0, y_axis)])

    X.bind_dims(*xdims)
    Y.bind_dims(*ydims)
    O = edsl.TensorOutput(*odims)
    O[oidxs] += X[xidxs] * Y[yidxs]

    # A rank-1 result gets an extra axis appended at position 1.
    if len(odims) == 1:
        O = plaidml_op.expand_dims(O, 1)
    return _KerasNode('batch_dot', name=name, tensor=O)
Exemplo n.º 2
0
def expand_dims(x, axis=-1, name=None):
    """Log the call, then wrap ``plaidml_op.expand_dims`` in a ``_KerasNode``.

    Args:
        x: Wrapper exposing the operand as ``x.tensor``.
        axis: Axis argument forwarded to ``plaidml_op.expand_dims``
            (defaults to ``-1``).
        name: Optional name forwarded to the resulting ``_KerasNode``.

    Returns:
        A ``_KerasNode`` labelled ``'expand_dims'`` holding the tensor
        returned by ``plaidml_op.expand_dims``.
    """
    call_summary = 'expand_dims(x: {}, axis: {}, name={})'.format(x, axis, name)
    logger.debug(call_summary)
    expanded = plaidml_op.expand_dims(x.tensor, axis)
    return _KerasNode('expand_dims', name=name, tensor=expanded)
Exemplo n.º 3
0
def expand_dims(x, axis=-1, name=None):
    """Wrap ``plaidml_op.expand_dims`` applied to ``x.tensor`` in a ``_KerasNode``.

    Args:
        x: Wrapper exposing the operand as ``x.tensor``.
        axis: Axis argument forwarded to ``plaidml_op.expand_dims``
            (defaults to ``-1``).
        name: Optional name forwarded to the resulting ``_KerasNode``.

    Returns:
        A ``_KerasNode`` labelled ``'expand_dims'`` holding the tensor
        returned by ``plaidml_op.expand_dims``.
    """
    expanded = plaidml_op.expand_dims(x.tensor, axis)
    node = _KerasNode('expand_dims', name=name, tensor=expanded)
    return node