Beispiel #1
0
def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
    """
    Convenience function for setting up static data required for the G-Conv.
    This function returns:
      1) an array of indices used in the filter transformation step of gconv2d
      2) shape information required by gconv2d
      3) the shape of the filter tensor to be allocated and passed to gconv2d

    :param h_input: one of ('Z2', 'C4', 'D4'). Use 'Z2' for the first layer. Use 'C4' or 'D4' for later layers.
    :param h_output: one of ('C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
      The choice of h_output of one layer should equal h_input of the next layer.
      Supported (h_input, h_output) pairs: ('Z2', 'C4'), ('C4', 'C4'), ('Z2', 'D4'), ('D4', 'D4').
    :param in_channels: the number of input channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_input.
    :param out_channels: the number of output channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_output.
    :param ksize: the spatial size of the filter kernels (typically 3, 5, or 7).
    :return: a tuple (gconv_indices, gconv_shape_info, w_shape) where
      - gconv_indices is the flattened index array for the filter transformation,
      - gconv_shape_info is (out_channels, nto, in_channels, nti, ksize), with nti / nto the number of
        transformations of the input / output group (1 for 'Z2', 4 for 'C4', 8 for 'D4'),
      - w_shape is (ksize, ksize, in_channels * nti, out_channels).
    :raises ValueError: if (h_input, h_output) is not one of the supported pairs.
    """

    # nti / nto: number of transformations in the input / output group.
    if h_input == 'Z2' and h_output == 'C4':
        gconv_indices = flatten_indices(make_c4_z2_indices(ksize=ksize))
        nti = 1
        nto = 4
    elif h_input == 'C4' and h_output == 'C4':
        gconv_indices = flatten_indices(make_c4_p4_indices(ksize=ksize))
        nti = 4
        nto = 4
    elif h_input == 'Z2' and h_output == 'D4':
        gconv_indices = flatten_indices(make_d4_z2_indices(ksize=ksize))
        nti = 1
        nto = 8
    elif h_input == 'D4' and h_output == 'D4':
        gconv_indices = flatten_indices(make_d4_p4m_indices(ksize=ksize))
        nti = 8
        nto = 8
    else:
        raise ValueError('Unknown (h_input, h_output) pair: ' +
                         str((h_input, h_output)))

    # The filter to allocate stores every input transformation as extra input channels.
    w_shape = (ksize, ksize, in_channels * nti, out_channels)
    gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
    return gconv_indices, gconv_shape_info, w_shape
Beispiel #2
0
def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
    """
    Convenience function for setting up static data required for the G-Conv.
    This function returns:
      1) an array of indices used in the filter transformation step of gconv2d
      2) shape information required by gconv2d
      3) the shape of the filter tensor to be allocated and passed to gconv2d

    :param h_input: one of ('Z2', 'C4', 'D4'). Use 'Z2' for the first layer. Use 'C4' or 'D4' for later layers.
    :param h_output: one of ('C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
      The choice of h_output of one layer should equal h_input of the next layer.
      Supported (h_input, h_output) pairs: ('Z2', 'C4'), ('C4', 'C4'), ('Z2', 'D4'), ('D4', 'D4').
    :param in_channels: the number of input channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_input.
    :param out_channels: the number of output channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_output.
    :param ksize: the spatial size of the filter kernels (typically 3, 5, or 7).
    :return: a tuple (gconv_indices, gconv_shape_info, w_shape) where
      - gconv_indices is the flattened index array for the filter transformation,
      - gconv_shape_info is (out_channels, nto, in_channels, nti, ksize), with nti / nto the number of
        transformations of the input / output group (1 for 'Z2', 4 for 'C4', 8 for 'D4'),
      - w_shape is (ksize, ksize, in_channels * nti, out_channels).
    :raises ValueError: if (h_input, h_output) is not one of the supported pairs.
    """

    # nti / nto: number of transformations in the input / output group.
    if h_input == 'Z2' and h_output == 'C4':
        gconv_indices = flatten_indices(make_c4_z2_indices(ksize=ksize))
        nti = 1
        nto = 4
    elif h_input == 'C4' and h_output == 'C4':
        gconv_indices = flatten_indices(make_c4_p4_indices(ksize=ksize))
        nti = 4
        nto = 4
    elif h_input == 'Z2' and h_output == 'D4':
        gconv_indices = flatten_indices(make_d4_z2_indices(ksize=ksize))
        nti = 1
        nto = 8
    elif h_input == 'D4' and h_output == 'D4':
        gconv_indices = flatten_indices(make_d4_p4m_indices(ksize=ksize))
        nti = 8
        nto = 8
    else:
        raise ValueError('Unknown (h_input, h_output) pair: ' + str((h_input, h_output)))

    # The filter to allocate stores every input transformation as extra input channels.
    w_shape = (ksize, ksize, in_channels * nti, out_channels)
    gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
    return gconv_indices, gconv_shape_info, w_shape
Beispiel #3
0
def tf_trans_filter2(w, inds):
    """
    Apply transform_filter_2d_nchw to ``w`` in a TensorFlow session and return the evaluated result.

    :param w: array-like with a 5-element shape (no, ni, nti, n, <last>). Only ``n`` (the fourth
      axis) is used as the kernel size, so the filters are presumably square — TODO confirm.
      It is reshaped to (no, ni * nti, n, n) before being wrapped in ``tf.constant``.
    :param inds: transformation index array; its first axis length is used as nto (the number of
      output transformations). It is flattened with flatten_indices before use.
    :return: the evaluated transformed filters, reshaped to (no, nto, ni, nti, n, n).
    """
    flat_inds = flatten_indices(inds)
    no, ni, nti, n, _ = w.shape
    nto = inds.shape[0]
    shape_info = (no, nto, ni, nti, n)

    # Merge the input-channel and input-transformation axes: (no, ni * nti, n, n).
    w = w.reshape(no, ni * nti, n, n)

    wt = tf.constant(w)
    rwt = transform_filter_2d_nchw(wt, flat_inds, shape_info)

    # The context manager closes the session even if sess.run raises.
    with tf.Session() as sess:
        rwt = sess.run(rwt)

    # Split the flattened axes back out into (no, nto, ni, nti, n, n).
    return rwt.reshape(no, nto, ni, nti, n, n)
Beispiel #4
0
def tf_trans_filter(w, inds):
    """
    Apply transform_filter_2d_nhwc to ``w`` in a TensorFlow session and return the evaluated result.

    :param w: array-like with a 5-element shape (no, ni, nti, n, <last>). Only ``n`` (the fourth
      axis) is used as the kernel size, so the filters are presumably square — TODO confirm.
      It is transposed to (n, n, nti, ni, no) and reshaped to (n, n, nti * ni, no) before being
      wrapped in ``tf.constant``.
    :param inds: transformation index array; its first axis length is used as nto (the number of
      output transformations). It is flattened with flatten_indices before use.
    :return: the evaluated transformed filters, transposed with axes (3, 2, 0, 1) and reshaped to
      (no, nto, ni, nti, n, n).
    """
    flat_inds = flatten_indices(inds)
    no, ni, nti, n, _ = w.shape
    nto = inds.shape[0]
    shape_info = (no, nto, ni, nti, n)

    # Move the spatial axes to the front: (n, n, nti, ni, no) -> (n, n, nti * ni, no).
    w = w.transpose((3, 4, 2, 1, 0)).reshape((n, n, nti * ni, no))

    wt = tf.constant(w)
    rwt = transform_filter_2d_nhwc(wt, flat_inds, shape_info)

    # The context manager closes the session even if sess.run raises.
    with tf.Session() as sess:
        rwt = sess.run(rwt)

    # Assumes rwt is 4-D with the two spatial axes first (as the (3, 2, 0, 1) transpose implies);
    # bring it to channels-first, then split into (no, nto, ni, nti, n, n).
    return rwt.transpose(3, 2, 0, 1).reshape(no, nto, ni, nti, n, n)
Beispiel #5
0
 def make_transformation_indices(self, ksize):
     """Build the index array with make_c4_p4_indices for the given ksize and return it flattened."""
     c4_p4_indices = make_c4_p4_indices(ksize=ksize)
     return flatten_indices(c4_p4_indices)