Ejemplo n.º 1
0
def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
    """
    Convenience function for setting up static data required for the G-Conv.

    This function returns:
      1) an array of indices used in the filter transformation step of gconv2d
      2) shape information required by gconv2d
      3) the shape of the filter tensor to be allocated and passed to gconv2d

    Supported (h_input, h_output) pairs:
      ('Z2', 'C4'), ('C4', 'C4'), ('Z2', 'D4'), ('D4', 'D4'), ('D4', 'Z2'), ('C4', 'Z2').

    :param h_input: one of ('Z2', 'C4', 'D4'). Use 'Z2' for the first layer. Use 'C4' or 'D4' for later layers.
    :param h_output: one of ('Z2', 'C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
      The choice of h_output of one layer should equal h_input of the next layer.
    :param in_channels: the number of input channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_input.
    :param out_channels: the number of output channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_output.
    :param ksize: the spatial size of the filter kernels (typically 3, 5, or 7).
    :return: a tuple (gconv_indices, gconv_shape_info, w_shape), where
      - gconv_indices is the result of flatten_indices applied to the index array of the chosen factory,
      - gconv_shape_info is (out_channels, nto, in_channels, nti, ksize),
      - w_shape is (ksize, ksize, in_channels * nti, out_channels), or
        (ksize, ksize, in_channels, out_channels) when h_output == 'Z2'.
      nti / nto are the number of input / output transformations (1, 4 or 8).
    :raises ValueError: if (h_input, h_output) is not one of the supported pairs.
    """

    # Each branch selects the index factory plus the number of input (nti) and
    # output (nto) transformations; the factory is invoked once below.
    if h_input == 'Z2' and h_output == 'C4':
        make_indices, nti, nto = make_c4_z2_indices, 1, 4
    elif h_input == 'C4' and h_output == 'C4':
        make_indices, nti, nto = make_c4_p4_indices, 4, 4
    elif h_input == 'Z2' and h_output == 'D4':
        make_indices, nti, nto = make_d4_z2_indices, 1, 8
    elif h_input == 'D4' and h_output == 'D4':
        make_indices, nti, nto = make_d4_p4m_indices, 8, 8
    elif h_input == 'D4' and h_output == 'Z2':
        # NOTE(review): reuses the Z2->D4 index factory for an 8-to-1 mapping;
        # confirm this index layout is what the filter transform expects.
        make_indices, nti, nto = make_d4_z2_indices, 8, 1
    elif h_input == 'C4' and h_output == 'Z2':
        # NOTE(review): likewise reuses the Z2->C4 factory for a 4-to-1 mapping.
        make_indices, nti, nto = make_c4_z2_indices, 4, 1
    else:
        raise ValueError('Unknown (h_input, h_output) pair:' +
                         str((h_input, h_output)))

    gconv_indices = flatten_indices(make_indices(ksize=ksize))

    # NOTE(review): for a Z2 output the nti factor is left out of w_shape even
    # though gconv_shape_info still reports nti > 1; confirm against the
    # consumer of (w_shape, gconv_shape_info) that these are meant to differ.
    n_in_2d = in_channels if h_output == 'Z2' else in_channels * nti
    w_shape = (ksize, ksize, n_in_2d, out_channels)

    gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
    return gconv_indices, gconv_shape_info, w_shape
Ejemplo n.º 2
0
def check_d4_z2():
    """Check that the TF and PyTorch filter transforms agree on D4<-Z2 indices.

    Builds the index array from ``make_d4_z2_indices`` for 3x3 kernels and feeds
    one random filter bank through ``tf_trans_filter`` and
    ``pytorch_trans_filter``. It then prints the summed absolute difference and
    asserts that the difference is exactly zero.
    """
    indices = make_d4_z2_indices(ksize=3)
    # presumably (out_channels, in_channels, nti=1, k, k) -- TODO confirm layout
    weights = np.random.randn(6, 7, 1, 3, 3)

    from_tf = tf_trans_filter(weights, indices)
    from_torch = pytorch_trans_filter(weights, indices)

    total_diff = np.abs(from_tf - from_torch).sum()
    print('>>>>> DIFFERENCE:', total_diff)
    assert total_diff == 0
Ejemplo n.º 3
0
def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
    """
    Convenience function for setting up static data required for the G-Conv.

    This function returns:
      1) an array of indices used in the filter transformation step of gconv2d
      2) shape information required by gconv2d
      3) the shape of the filter tensor to be allocated and passed to gconv2d

    Supported (h_input, h_output) pairs:
      ('Z2', 'C4'), ('C4', 'C4'), ('Z2', 'D4'), ('D4', 'D4').

    :param h_input: one of ('Z2', 'C4', 'D4'). Use 'Z2' for the first layer. Use 'C4' or 'D4' for later layers.
    :param h_output: one of ('C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
      The choice of h_output of one layer should equal h_input of the next layer.
    :param in_channels: the number of input channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_input.
    :param out_channels: the number of output channels. Note: this refers to the number of (3D) channels on the group.
      The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_output.
    :param ksize: the spatial size of the filter kernels (typically 3, 5, or 7).
    :return: a tuple (gconv_indices, gconv_shape_info, w_shape), where
      - gconv_indices is the result of flatten_indices applied to the index array of the chosen factory,
      - gconv_shape_info is (out_channels, nto, in_channels, nti, ksize),
      - w_shape is (ksize, ksize, in_channels * nti, out_channels).
      nti / nto are the number of input / output transformations (1, 4 or 8).
    :raises ValueError: if (h_input, h_output) is not one of the supported pairs.
    """

    # Each branch selects the index factory plus the number of input (nti) and
    # output (nto) transformations; the factory is invoked once below.
    if h_input == 'Z2' and h_output == 'C4':
        make_indices, nti, nto = make_c4_z2_indices, 1, 4
    elif h_input == 'C4' and h_output == 'C4':
        make_indices, nti, nto = make_c4_p4_indices, 4, 4
    elif h_input == 'Z2' and h_output == 'D4':
        make_indices, nti, nto = make_d4_z2_indices, 1, 8
    elif h_input == 'D4' and h_output == 'D4':
        make_indices, nti, nto = make_d4_p4m_indices, 8, 8
    else:
        raise ValueError('Unknown (h_input, h_output) pair:' + str((h_input, h_output)))

    gconv_indices = flatten_indices(make_indices(ksize=ksize))

    # The stored filter holds one 2D slice per (input channel, input transformation).
    w_shape = (ksize, ksize, in_channels * nti, out_channels)
    gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
    return gconv_indices, gconv_shape_info, w_shape
Ejemplo n.º 4
0
 def make_transformation_indices(self, ksize):
     """Return the index array produced by ``make_d4_z2_indices`` for ``ksize``.

     ``self`` is not consulted; the result depends only on ``ksize``.
     """
     indices = make_d4_z2_indices(ksize=ksize)
     return indices
Ejemplo n.º 5
0
def check_transform_d4_z2_grad(dtype='float64', toll=1e-10):
    """Run ``check_transform_grad`` for ``TransformGFilter`` on D4<-Z2 indices.

    Uses 5x5 kernels and a random CuPy filter bank of shape (3, 2, 1, 5, 5).

    :param dtype: forwarded unchanged to ``check_transform_grad``.
    :param toll: tolerance forwarded unchanged to ``check_transform_grad``.
    """
    kernel = 5
    index_array = make_d4_z2_indices(ksize=kernel)
    filters = cp.random.randn(3, 2, 1, kernel, kernel)
    check_transform_grad(index_array, filters, TransformGFilter, dtype, toll)
Ejemplo n.º 6
0
def check_transform_d4_z2_grad(dtype='float64', toll=1e-10):
    # Gradient check of TransformGFilter with D4<-Z2 indices on 5x5 kernels;
    # `dtype` and `toll` are passed straight through to check_transform_grad.
    d4_z2_inds = make_d4_z2_indices(ksize=5)
    random_w = cp.random.randn(3, 2, 1, 5, 5)
    check_transform_grad(
        d4_z2_inds,
        random_w,
        TransformGFilter,
        dtype,
        toll,
    )
Ejemplo n.º 7
0
 def transformation_indices(self):
     """Return ``idx.make_d4_z2_indices`` evaluated at ``self.kernel_size``."""
     size = self.kernel_size
     return idx.make_d4_z2_indices(ksize=size)