Пример #1
0
def gconv2d(input, filter, strides, padding, gconv_indices, gconv_shape_info,
            use_cudnn_on_gpu=None, data_format='NHWC', name=None):
    """
    Tensorflow implementation of the group convolution.

    This function has the same interface as the standard convolution tf.nn.conv2d, except for two new
    parameters, gconv_indices and gconv_shape_info. These can be obtained from gconv2d_util(), and are
    described below. The filter bank is first expanded with transform_filter_2d_nhwc(), and the
    expanded filters are then passed to tf.nn.conv2d.

    :param input: a tensor with (batch, height, width, in channels) axes.
    :param filter: a tensor with (ksize, ksize, in channels * in transformations, out channels) axes.
      The shape for filter can be obtained from gconv2d_util().
    :param strides: A list of ints. 1-D of length 4. The stride of the sliding window for each dimension of input.
      Must be in the same order as the dimensions specified by data_format.
    :param padding: A string from: "SAME", "VALID". The type of padding algorithm to use.
    :param gconv_indices: indices used in the filter transformation step of the G-Conv.
      Can be obtained from gconv2d_util() or using a command like flatten_indices(make_d4_p4m_indices(ksize=3)).
    :param gconv_shape_info: a tuple containing
      (num output channels, num output transformations, num input channels, num input transformations, kernel size).
      Can be obtained from gconv2d_util().
    :param use_cudnn_on_gpu: an optional bool, forwarded unchanged to tf.nn.conv2d. The default None
      leaves the choice to TensorFlow (documented upstream as True; confirm for your TF version).
    :param data_format: the order of axes. Currently only 'NHWC' is supported.
    :param name: a name for the operation (optional).
    :return: the tf.nn.conv2d output in NHWC layout, i.e. a tensor with (batch, height, width, channels)
      axes. The channel count is the last axis of the transformed filter (presumably
      out channels * out transformations; confirm against transform_filter_2d_nhwc).
    :raises NotImplementedError: if data_format is anything other than 'NHWC'.
    """
    if data_format != 'NHWC':
        # NotImplemented is a constant, not an exception class; calling it would raise an
        # unrelated TypeError, so raise the proper NotImplementedError instead.
        raise NotImplementedError('Currently only NHWC data_format is supported. Got:' + str(data_format))

    # Expand the filter bank over the group transformations.
    transformed_filter = transform_filter_2d_nhwc(w=filter, flat_indices=gconv_indices, shape_info=gconv_shape_info)

    # A standard planar convolution with the expanded filters implements the group convolution.
    conv = tf.nn.conv2d(input=input, filter=transformed_filter, strides=strides, padding=padding,
                        use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format, name=name)

    return conv
Пример #2
0
def gconv2d(input, filter, strides, padding, gconv_indices, gconv_shape_info,
            use_cudnn_on_gpu=None, data_format='NHWC', name=None):
    """
    Tensorflow implementation of the group convolution.

    This function has the same interface as the standard convolution tf.nn.conv2d, except for two new
    parameters, gconv_indices and gconv_shape_info. These can be obtained from gconv2d_util(), and are
    described below. The filter bank is first expanded with transform_filter_2d_nhwc(), and the
    expanded filters are then passed to tf.nn.conv2d.

    :param input: a tensor with (batch, height, width, in channels) axes.
    :param filter: a tensor with (ksize, ksize, in channels * in transformations, out channels) axes.
      The shape for filter can be obtained from gconv2d_util().
    :param strides: A list of ints. 1-D of length 4. The stride of the sliding window for each dimension of input.
      Must be in the same order as the dimensions specified by data_format.
    :param padding: A string from: "SAME", "VALID". The type of padding algorithm to use.
    :param gconv_indices: indices used in the filter transformation step of the G-Conv.
      Can be obtained from gconv2d_util() or using a command like flatten_indices(make_d4_p4m_indices(ksize=3)).
    :param gconv_shape_info: a tuple containing
      (num output channels, num output transformations, num input channels, num input transformations, kernel size).
      Can be obtained from gconv2d_util().
    :param use_cudnn_on_gpu: an optional bool, forwarded unchanged to tf.nn.conv2d. The default None
      leaves the choice to TensorFlow (documented upstream as True; confirm for your TF version).
    :param data_format: the order of axes. Currently only 'NHWC' is supported.
    :param name: a name for the operation (optional).
    :return: the tf.nn.conv2d output in NHWC layout, i.e. a tensor with (batch, height, width, channels)
      axes. The channel count is the last axis of the transformed filter (presumably
      out channels * out transformations; confirm against transform_filter_2d_nhwc).
    :raises NotImplementedError: if data_format is anything other than 'NHWC'.
    """
    if data_format != 'NHWC':
        # NotImplemented is a constant, not an exception class; calling it would raise an
        # unrelated TypeError, so raise the proper NotImplementedError instead.
        raise NotImplementedError('Currently only NHWC data_format is supported. Got:' + str(data_format))

    # Expand the filter bank over the group transformations.
    transformed_filter = transform_filter_2d_nhwc(w=filter, flat_indices=gconv_indices, shape_info=gconv_shape_info)

    # A standard planar convolution with the expanded filters implements the group convolution.
    conv = tf.nn.conv2d(input=input, filter=transformed_filter, strides=strides, padding=padding,
                        use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format, name=name)

    return conv
Пример #3
0
def tf_trans_filter(w, inds):
    """
    Apply the TensorFlow G-conv filter transformation to a NumPy filter bank and return a NumPy array.

    The filter bank ``w`` is packed into the layout passed to transform_filter_2d_nhwc(). The
    transformation is then evaluated in a TF1 session, and the result is unpacked again.

    :param w: NumPy array unpacked as (no, ni, nti, n, _). The names follow gconv_shape_info:
      out channels, in channels, in transformations, kernel size.
    :param inds: transformation index array. It is flattened with flatten_indices(), and its first
      axis length is used as the number of output transformations (nto).
    :return: NumPy array reshaped to (no, nto, ni, nti, n, n).
    """
    flat_inds = flatten_indices(inds)
    no, ni, nti, n, _ = w.shape
    nto = inds.shape[0]
    shape_info = (no, nto, ni, nti, n)

    # (no, ni, nti, n, n) -> (n, n, nti, ni, no) -> (n, n, nti * ni, no); the packed in-axis is nti-major.
    w_packed = w.transpose((3, 4, 2, 1, 0)).reshape((n, n, nti * ni, no))

    transformed = transform_filter_2d_nhwc(tf.constant(w_packed), flat_inds, shape_info)

    # The context manager guarantees the session is closed even if sess.run raises;
    # the previous explicit close() was skipped on error, leaking the session.
    with tf.Session() as sess:
        rwt = sess.run(transformed)

    # NOTE(review): the in-axis was packed nti-major above (nti * ni) but is unpacked here as
    # (ni, nti). That is only consistent if transform_filter_2d_nhwc reorders that axis, or if
    # ni == 1 or nti == 1. Confirm against its output layout before relying on ni, nti > 1.
    return rwt.transpose(3, 2, 0, 1).reshape(no, nto, ni, nti, n, n)