Exemplo n.º 1
0
    def forward(self, pos):
        """Evaluate the electric potential through ``ElectricPotentialFunction``.

        On the first call (while ``self.initial_density_map`` is None) this
        lazily computes and caches on ``self``:

        * the fixed density map, via ``self.compute_initial_density_map``;
        * the exact ``expk`` tables for the ``num_bins_x`` x ``num_bins_y`` grid;
        * the DCT2 / IDCT2 / IDCT_IDXST / IDXST_IDCT operators built from them
          (IDCT2 only when ``self.fast_mode`` is false);
        * the frequency-domain weights ``1 / (wu^2 + wv^2)`` (DC term forced
          to 0) and ``wu / (wu^2 + wv^2) / 2``, ``wv / (wu^2 + wv^2) / 2``.

        Later calls reuse the cached values.

        Args:
            pos: placement tensor; every cached tensor is created with its
                dtype and device.

        Returns:
            The result of ``ElectricPotentialFunction.apply``.
        """
        if self.initial_density_map is None:
            self.compute_initial_density_map(pos)
            # Pass the values as logger arguments rather than pre-formatting
            # with %, so the string is only rendered when INFO is enabled.
            logger.info("fixed density map: average %g, max %g, bin area %g",
                        self.initial_density_map.mean(),
                        self.initial_density_map.max(),
                        self.bin_size_x * self.bin_size_y)

            # expk
            M = self.num_bins_x
            N = self.num_bins_y
            self.exact_expkM = precompute_expk(M,
                                               dtype=pos.dtype,
                                               device=pos.device)
            self.exact_expkN = precompute_expk(N,
                                               dtype=pos.dtype,
                                               device=pos.device)

            # init dct2, idct2, idct_idxst, idxst_idct with expkM and expkN.
            # NOTE(review): self.idct2 is not assigned here in fast mode but is
            # still passed to apply() below; it is presumably initialised
            # elsewhere (e.g. to None) — confirm.
            self.dct2 = dct.DCT2(self.exact_expkM, self.exact_expkN)
            if not self.fast_mode:
                self.idct2 = dct.IDCT2(self.exact_expkM, self.exact_expkN)
            self.idct_idxst = dct.IDCT_IDXST(self.exact_expkM,
                                             self.exact_expkN)
            self.idxst_idct = dct.IDXST_IDCT(self.exact_expkM,
                                             self.exact_expkN)

            # wu and wv: angular frequencies 2*pi*k/M (column) and 2*pi*k/N (row)
            wu = torch.arange(M, dtype=pos.dtype, device=pos.device).mul(
                2 * np.pi / M).view([M, 1])
            # scale wv because the aspect ratio of a bin may not be 1
            wv = torch.arange(N, dtype=pos.dtype,
                              device=pos.device).mul(2 * np.pi / N).view(
                                  [1,
                                   N]).mul_(self.bin_size_x / self.bin_size_y)
            wu2_plus_wv2 = wu.pow(2) + wv.pow(2)
            # avoid zero-division at the DC term; it is zeroed out right after
            wu2_plus_wv2[0, 0] = 1.0
            self.inv_wu2_plus_wv2 = 1.0 / wu2_plus_wv2
            self.inv_wu2_plus_wv2[0, 0] = 0.0
            self.wu_by_wu2_plus_wv2_half = wu.mul(self.inv_wu2_plus_wv2).mul_(
                1. / 2)
            self.wv_by_wu2_plus_wv2_half = wv.mul(self.inv_wu2_plus_wv2).mul_(
                1. / 2)

        return ElectricPotentialFunction.apply(
            pos, self.node_size_x_clamped, self.node_size_y_clamped,
            self.offset_x, self.offset_y, self.ratio, self.bin_center_x,
            self.bin_center_y, self.initial_density_map, self.buf,
            self.target_density, self.xl, self.yl, self.xh, self.yh,
            self.bin_size_x, self.bin_size_y, self.num_movable_nodes,
            self.num_filler_nodes, self.padding, self.padding_mask,
            self.num_bins_x, self.num_bins_y, self.num_movable_impacted_bins_x,
            self.num_movable_impacted_bins_y, self.num_filler_impacted_bins_x,
            self.num_filler_impacted_bins_y, self.deterministic_flag,
            self.sorted_node_map, self.exact_expkM, self.exact_expkN,
            self.inv_wu2_plus_wv2, self.wu_by_wu2_plus_wv2_half,
            self.wv_by_wu2_plus_wv2_half, self.dct2, self.idct2,
            self.idct_idxst, self.idxst_idct, self.fast_mode, self.num_threads)
Exemplo n.º 2
0
def _time_idct2d_transform(label, transform, x, runs):
    """Warm up ``transform(x)`` once, then print its mean latency over ``runs`` calls.

    When ``x`` lives on a CUDA device the queue is synchronized before and
    after the timed loop so the measurement covers the kernels themselves,
    not just their asynchronous launch. CPU inputs skip synchronization, so
    the benchmark also runs on machines without CUDA.
    """
    transform(x)  # warm-up: keep one-time setup out of the timed loop
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    for _ in range(runs):
        transform(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    print("%s takes %.7f ms" % (label, (time.perf_counter() - start) / runs * 1000))


def eval_idct2d(x, expk0, expk1, expkM, expkN, runs):
    """Benchmark several 2D IDCT implementations on the same input ``x``.

    Prints the mean time per call, in milliseconds, for:
    ``discrete_spectral_transform.idct2_N``, ``dct.IDCT2`` with the ``'2N'``
    and ``'N'`` algorithms (built from ``expk0``/``expk1``), and
    ``dct2_fft2.IDCT2`` (built from ``expkM``/``expkN``).

    Every implementation is timed on its bare forward call, with no extra
    post-processing inside any loop, so the numbers are directly comparable.

    Args:
        x: 2D input tensor (CPU or CUDA).
        expk0, expk1: expk tables for the ``discrete_spectral_transform`` /
            ``dct.IDCT2`` variants.
        expkM, expkN: expk tables for ``dct2_fft2.IDCT2``.
        runs: number of timed iterations per implementation; must be > 0.
    """
    _time_idct2d_transform(
        "PyTorch idct2_N",
        lambda t: discrete_spectral_transform.idct2_N(t, expk0=expk0, expk1=expk1),
        x, runs)
    _time_idct2d_transform(
        "IDCT2_2N Function",
        dct.IDCT2(expk0, expk1, algorithm='2N').forward,
        x, runs)
    _time_idct2d_transform(
        "IDCT2_N Function",
        dct.IDCT2(expk0, expk1, algorithm='N').forward,
        x, runs)
    _time_idct2d_transform(
        "IDCT2_FFT2 Function",
        dct2_fft2.IDCT2(expkM, expkN).forward,
        x, runs)

    print("")
Exemplo n.º 3
0
    def forward(self, pos, mode="density"):
        """Evaluate the electric potential cost or the density overflow.

        When ``self.region_id`` is not None, ``pos`` is first reduced to
        ``pos[self.pos_mask]`` so that only the cells of this electric field
        are considered. On the first call (while ``self.initial_density_map``
        is None) the initial density map and the spectral operators/weights
        are computed and cached on ``self``.

        Args:
            pos: placement tensor; the cached tensors take its dtype/device.
            mode: ``"density"`` returns the result of
                ``ElectricPotentialFunction.apply``. ``"overflow"`` returns
                ``(density_cost, max_density)``, computed from the density map
                with the filler-node count passed as 0.

        Raises:
            ValueError: if ``mode`` is neither ``"density"`` nor ``"overflow"``.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown mode fall through every
        # branch and silently return None.
        if mode not in ("density", "overflow"):
            raise ValueError("Only support density mode or overflow mode")

        if self.region_id is not None:
            # reconstruct pos: only extract cells in this electric field
            pos = pos[self.pos_mask]

        if self.initial_density_map is None:
            num_nodes = pos.size(0) // 2
            if self.fence_regions is not None:
                if self.placedb.num_terminals > 0:
                    # Merge fence-region density and macro density into one
                    # initial density map. pos here is the reconstructed pos,
                    # so the node counts must come from self. The second pos
                    # slice is offset by num_nodes (presumably the
                    # y-coordinate half of pos — confirm).
                    fixed_begin = self.num_movable_nodes
                    fixed_end = self.num_movable_nodes + self.num_terminals
                    self.initial_density_map = self.compute_fence_region_map(
                        self.fence_regions,
                        pos[fixed_begin:fixed_end],
                        pos[num_nodes + fixed_begin:num_nodes + fixed_end],
                        self.node_size_x[fixed_begin:fixed_end],
                        self.node_size_y[fixed_begin:fixed_end])
                else:
                    self.initial_density_map = self.compute_fence_region_map(
                        self.fence_regions)
            else:
                self.compute_initial_density_map(pos)
            # Pass values as logger arguments so the message is only
            # rendered when INFO is enabled.
            logger.info("fixed density map: average %g, max %g, bin area %g",
                        self.initial_density_map.mean(),
                        self.initial_density_map.max(),
                        self.bin_size_x * self.bin_size_y)

            # expk
            M = self.num_bins_x
            N = self.num_bins_y
            self.exact_expkM = precompute_expk(M,
                                               dtype=pos.dtype,
                                               device=pos.device)
            self.exact_expkN = precompute_expk(N,
                                               dtype=pos.dtype,
                                               device=pos.device)

            # init dct2, idct2, idct_idxst, idxst_idct with expkM and expkN
            self.dct2 = dct.DCT2(self.exact_expkM, self.exact_expkN)
            if not self.fast_mode:
                self.idct2 = dct.IDCT2(self.exact_expkM, self.exact_expkN)
            self.idct_idxst = dct.IDCT_IDXST(self.exact_expkM,
                                             self.exact_expkN)
            self.idxst_idct = dct.IDXST_IDCT(self.exact_expkM,
                                             self.exact_expkN)

            # wu and wv: angular frequencies 2*pi*k/M (column) and 2*pi*k/N (row)
            wu = torch.arange(M, dtype=pos.dtype, device=pos.device).mul(
                2 * np.pi / M).view([M, 1])
            # scale wv because the aspect ratio of a bin may not be 1
            wv = torch.arange(N, dtype=pos.dtype,
                              device=pos.device).mul(2 * np.pi / N).view(
                                  [1,
                                   N]).mul_(self.bin_size_x / self.bin_size_y)
            wu2_plus_wv2 = wu.pow(2) + wv.pow(2)
            # avoid zero-division at the DC term; it is zeroed out right after
            wu2_plus_wv2[0, 0] = 1.0
            self.inv_wu2_plus_wv2 = 1.0 / wu2_plus_wv2
            self.inv_wu2_plus_wv2[0, 0] = 0.0
            self.wu_by_wu2_plus_wv2_half = wu.mul(self.inv_wu2_plus_wv2).mul_(
                1. / 2)
            self.wv_by_wu2_plus_wv2_half = wv.mul(self.inv_wu2_plus_wv2).mul_(
                1. / 2)

        if mode == "density":
            return ElectricPotentialFunction.apply(
                pos, self.node_size_x_clamped, self.node_size_y_clamped,
                self.offset_x, self.offset_y, self.ratio, self.bin_center_x,
                self.bin_center_y, self.initial_density_map,
                self.target_density, self.xl, self.yl, self.xh, self.yh,
                self.bin_size_x, self.bin_size_y, self.num_movable_nodes,
                self.num_filler_nodes, self.padding, self.padding_mask,
                self.num_bins_x, self.num_bins_y,
                self.num_movable_impacted_bins_x,
                self.num_movable_impacted_bins_y,
                self.num_filler_impacted_bins_x,
                self.num_filler_impacted_bins_y, self.deterministic_flag,
                self.sorted_node_map, self.exact_expkM, self.exact_expkN,
                self.inv_wu2_plus_wv2, self.wu_by_wu2_plus_wv2_half,
                self.wv_by_wu2_plus_wv2_half, self.dct2, self.idct2,
                self.idct_idxst, self.idxst_idct, self.fast_mode)

        # mode == "overflow": num_filler_nodes is passed as 0
        density_map = ElectricDensityMapFunction.forward(
            pos, self.node_size_x_clamped, self.node_size_y_clamped,
            self.offset_x, self.offset_y, self.ratio, self.bin_center_x,
            self.bin_center_y, self.initial_density_map,
            self.target_density, self.xl, self.yl, self.xh, self.yh,
            self.bin_size_x, self.bin_size_y, self.num_movable_nodes, 0,
            self.padding, self.padding_mask, self.num_bins_x,
            self.num_bins_y, self.num_movable_impacted_bins_x,
            self.num_movable_impacted_bins_y,
            self.num_filler_impacted_bins_x,
            self.num_filler_impacted_bins_y, self.deterministic_flag,
            self.sorted_node_map)

        bin_area = self.bin_size_x * self.bin_size_y
        # total density exceeding the per-bin target capacity
        density_cost = (density_map -
                        self.target_density * bin_area).clamp_(
                            min=0.0).sum()

        return density_cost, density_map.max() / bin_area
def _solve_spectral_potential(dct2, idct2, idct_idxst, idxst_idct, density_map,
                              inv_wu2_plus_wv2, wu_by_wu2_plus_wv2_half,
                              wv_by_wu2_plus_wv2_half):
    """Compute the field, potential and energy of ``density_map`` spectrally.

    The four operator objects only need a ``forward`` method, so either the
    ``dct2_fft2`` or the ``dct`` implementations can be passed.

    Returns:
        tuple ``(coefficients, field_map_x, field_map_y, potential_map, energy)``.
    """
    coefficients = dct2.forward(density_map)
    field_map_x = idxst_idct.forward(coefficients.mul(wu_by_wu2_plus_wv2_half))
    field_map_y = idct_idxst.forward(coefficients.mul(wv_by_wu2_plus_wv2_half))
    potential_map = idct2.forward(coefficients.mul(inv_wu2_plus_wv2))
    energy = potential_map.mul(density_map).sum()
    if density_map.is_cuda:
        torch.cuda.synchronize()
    return coefficients, field_map_x, field_map_y, potential_map, energy


def _assert_results_match(results, goldens):
    """Assert each tensor in ``results`` is close to its counterpart in ``goldens``."""
    for result, golden in zip(results, goldens):
        np.testing.assert_allclose(result.detach().cpu().numpy(),
                                   golden.detach().cpu().numpy(),
                                   rtol=1e-6, atol=1e-5)


def compare_different_methods(cuda_flag, M=1024, N=1024, dtype=torch.float64):
    """Cross-check three spectral solvers of the electrostatic system.

    A random ``M`` x ``N`` density map (values in [0, 10)) is solved with:

    1. the functional ``dct.dct2`` / ``idsct2`` / ``idcst2`` / ``idcct2``
       path, used as the golden reference;
    2. the ``dct2_fft2`` operator classes built from the ``get_exact_expk``
       tables;
    3. the ``dct`` operator classes built from the ``get_expk`` tables.

    For approaches 2 and 3, the DCT coefficients, both field maps, the
    potential map and the energy must match the golden values within
    ``rtol=1e-6, atol=1e-5``. Otherwise an ``AssertionError`` is raised.

    Args:
        cuda_flag: run on the default CUDA device when true.
        M, N: grid dimensions.
        dtype: floating-point dtype of all tensors.
    """
    density_map = torch.empty(M, N, dtype=dtype).uniform_(0, 10.0)
    if cuda_flag:
        density_map = density_map.cuda()
    device = density_map.device
    expkM = discrete_spectral_transform.get_expk(M, dtype, device)
    expkN = discrete_spectral_transform.get_expk(N, dtype, device)
    exact_expkM = discrete_spectral_transform.get_exact_expk(M, dtype, device)
    exact_expkN = discrete_spectral_transform.get_exact_expk(N, dtype, device)
    print("M = {}, N = {}".format(M, N))

    wu = torch.arange(M, dtype=density_map.dtype, device=device).mul(2 * np.pi / M).view([M, 1])
    wv = torch.arange(N, dtype=density_map.dtype, device=device).mul(2 * np.pi / N).view([1, N])
    wu2_plus_wv2 = wu.pow(2) + wv.pow(2)
    wu2_plus_wv2[0, 0] = 1.0  # avoid zero-division, it will be zeroed out

    # --- approach 1: the golden reference --------------------------------
    inv_wu2_plus_wv2_2X = 2.0 / wu2_plus_wv2
    inv_wu2_plus_wv2_2X[0, 0] = 0.0
    wu_by_wu2_plus_wv2_2X = wu.mul(inv_wu2_plus_wv2_2X)
    wv_by_wu2_plus_wv2_2X = wv.mul(inv_wu2_plus_wv2_2X)

    auv_golden = dct.dct2(density_map, expk0=expkM, expk1=expkN)
    auv = auv_golden.clone()
    # halve the first row and first column (the [0, 0] entry is quartered)
    auv[0, :].mul_(0.5)
    auv[:, 0].mul_(0.5)
    field_map_x_golden = dct.idsct2(auv.mul(wu_by_wu2_plus_wv2_2X), expkM, expkN)
    field_map_y_golden = dct.idcst2(auv.mul(wv_by_wu2_plus_wv2_2X), expkM, expkN)
    # potential phi from auv / (wu**2 + wv**2)
    auv_by_wu2_plus_wv2 = auv.mul(inv_wu2_plus_wv2_2X).mul_(2)
    potential_map_golden = dct.idcct2(auv_by_wu2_plus_wv2, expkM, expkN)
    energy_golden = potential_map_golden.mul(density_map).sum()

    if density_map.is_cuda:
        torch.cuda.synchronize()

    goldens = (auv_golden, field_map_x_golden, field_map_y_golden,
               potential_map_golden, energy_golden)

    # weights shared by approaches 2 and 3
    inv_wu2_plus_wv2 = 1.0 / wu2_plus_wv2
    inv_wu2_plus_wv2[0, 0] = 0.0
    wu_by_wu2_plus_wv2_half = wu.mul(inv_wu2_plus_wv2).mul_(0.5)
    wv_by_wu2_plus_wv2_half = wv.mul(inv_wu2_plus_wv2).mul_(0.5)

    # --- approach 2: dct2_fft2 operators with the exact expk tables ------
    results = _solve_spectral_potential(
        dct2_fft2.DCT2(exact_expkM, exact_expkN),
        dct2_fft2.IDCT2(exact_expkM, exact_expkN),
        dct2_fft2.IDCT_IDXST(exact_expkM, exact_expkN),
        dct2_fft2.IDXST_IDCT(exact_expkM, exact_expkN),
        density_map, inv_wu2_plus_wv2,
        wu_by_wu2_plus_wv2_half, wv_by_wu2_plus_wv2_half)
    _assert_results_match(results, goldens)

    # --- approach 3: dct.DCT2 / IDCT2 / IDCT_IDXST / IDXST_IDCT with the
    # get_expk tables (not the get_exact_expk ones) ------------------------
    results = _solve_spectral_potential(
        dct.DCT2(expkM, expkN),
        dct.IDCT2(expkM, expkN),
        dct.IDCT_IDXST(expkM, expkN),
        dct.IDXST_IDCT(expkM, expkN),
        density_map, inv_wu2_plus_wv2,
        wu_by_wu2_plus_wv2_half, wv_by_wu2_plus_wv2_half)
    _assert_results_match(results, goldens)
Exemplo n.º 5
0
    def forward(self, pos):
        """Return the electric potential of placement ``pos``.

        The first invocation (while ``self.initial_density_map`` is None)
        builds and caches on ``self``:

        * the fixed density map, produced by the CUDA or C++ extension
          (chosen by ``pos.is_cuda``) and then scaled in place by
          ``self.target_density``;
        * the exact ``expk`` tables and the DCT2 / IDCT2 / IDCT_IDXST /
          IDXST_IDCT operators (IDCT2 only outside fast mode);
        * the spectral weights ``1 / (wu^2 + wv^2)`` (DC entry 0) and their
          halved ``wu`` / ``wv`` products.

        Subsequent calls reuse the cache and only run
        ``ElectricPotentialFunction``.
        """
        if self.initial_density_map is None:
            first_fixed = self.num_movable_nodes
            last_fixed = self.num_movable_nodes + self.num_terminals

            def bins_touched(sizes, bin_len, bin_cap):
                # ceil((largest fixed-node size + one bin) / bin), capped at the grid size
                largest = sizes[first_fixed:last_fixed].max()
                return int(((largest + bin_len) / bin_len).ceil().clamp(max=bin_cap))

            if self.num_terminals != 0:
                fixed_bins_x = bins_touched(self.node_size_x, self.bin_size_x, self.num_bins_x)
                fixed_bins_y = bins_touched(self.node_size_y, self.bin_size_y, self.num_bins_y)
            else:
                fixed_bins_x = fixed_bins_y = 0

            extension_args = (
                pos.view(pos.numel()), self.node_size_x, self.node_size_y,
                self.bin_center_x, self.bin_center_y,
                self.xl, self.yl, self.xh, self.yh,
                self.bin_size_x, self.bin_size_y,
                self.num_movable_nodes, self.num_terminals,
                self.num_bins_x, self.num_bins_y,
                fixed_bins_x, fixed_bins_y,
            )
            if pos.is_cuda:
                fixed_map = electric_potential_cuda.fixed_density_map(*extension_args)
            else:
                # the C++ kernel additionally takes a thread count
                fixed_map = electric_potential_cpp.fixed_density_map(*extension_args, self.num_threads)
            self.initial_density_map = fixed_map

            # scale density of fixed macros
            self.initial_density_map.mul_(self.target_density)

            rows, cols = self.num_bins_x, self.num_bins_y
            dtype, device = pos.dtype, pos.device
            self.exact_expkM = precompute_expk(rows, dtype=dtype, device=device)
            self.exact_expkN = precompute_expk(cols, dtype=dtype, device=device)
            expk_rows, expk_cols = self.exact_expkM, self.exact_expkN

            self.dct2 = dct.DCT2(expk_rows, expk_cols)
            if not self.fast_mode:
                self.idct2 = dct.IDCT2(expk_rows, expk_cols)
            self.idct_idxst = dct.IDCT_IDXST(expk_rows, expk_cols)
            self.idxst_idct = dct.IDXST_IDCT(expk_rows, expk_cols)

            # angular frequencies along each axis; the column frequencies are
            # rescaled because a bin's aspect ratio need not be 1
            freq_u = (torch.arange(rows, dtype=dtype, device=device) * (2 * np.pi / rows)).view(rows, 1)
            aspect = self.bin_size_x / self.bin_size_y
            freq_v = (torch.arange(cols, dtype=dtype, device=device) * (2 * np.pi / cols)).view(1, cols) * aspect

            radius_sq = freq_u.pow(2) + freq_v.pow(2)
            radius_sq[0, 0] = 1.0  # placeholder to dodge 0/0; DC weight is reset to 0 below
            inverse = 1.0 / radius_sq
            inverse[0, 0] = 0.0
            self.inv_wu2_plus_wv2 = inverse
            self.wu_by_wu2_plus_wv2_half = freq_u.mul(inverse) * 0.5
            self.wv_by_wu2_plus_wv2_half = freq_v.mul(inverse) * 0.5

        return ElectricPotentialFunction.apply(
            pos,
            self.node_size_x_clamped,
            self.node_size_y_clamped,
            self.offset_x,
            self.offset_y,
            self.ratio,
            self.bin_center_x,
            self.bin_center_y,
            self.initial_density_map,
            self.target_density,
            self.xl,
            self.yl,
            self.xh,
            self.yh,
            self.bin_size_x,
            self.bin_size_y,
            self.num_movable_nodes,
            self.num_filler_nodes,
            self.padding,
            self.padding_mask,
            self.num_bins_x,
            self.num_bins_y,
            self.num_movable_impacted_bins_x,
            self.num_movable_impacted_bins_y,
            self.num_filler_impacted_bins_x,
            self.num_filler_impacted_bins_y,
            self.sorted_node_map,
            self.exact_expkM,
            self.exact_expkN,
            self.inv_wu2_plus_wv2,
            self.wu_by_wu2_plus_wv2_half,
            self.wv_by_wu2_plus_wv2_half,
            self.dct2,
            self.idct2,
            self.idct_idxst,
            self.idxst_idct,
            self.fast_mode,
            self.num_threads,
        )
Exemplo n.º 6
0
    def test_idct2Random(self):
        """Check every 2D IDCT implementation against the 2N-FFT golden result.

        A seeded random 4x8 integer-valued matrix is forward transformed with
        ``discrete_spectral_transform.dct2_2N``. Inverting it with ``idct2_2N``
        gives the golden value. ``dct.IDCT2`` (N and 2N algorithms),
        ``dct_lee.IDCT2`` and ``dct2_fft2.IDCT2`` must each reproduce it on
        the CPU and, when a CUDA device is present, on the GPU.
        """
        torch.manual_seed(10)
        M = 4
        N = 8
        x = torch.empty(M, N, dtype=torch.int32).random_(0, 10).double()
        print("2D x")
        print(x)

        expkM = discrete_spectral_transform.get_exact_expk(M, dtype=x.dtype, device=x.device)
        expkN = discrete_spectral_transform.get_exact_expk(N, dtype=x.dtype, device=x.device)

        y = discrete_spectral_transform.dct2_2N(x)

        golden_value = discrete_spectral_transform.idct2_2N(y).data.numpy()
        print("2D idct golden_value")
        print(golden_value)

        def check(make_transform, inp, label):
            # build the transform, run it, and compare the host copy to golden
            dct_value = make_transform().forward(inp).cpu()
            print(label)
            print(dct_value.data.numpy())
            np.testing.assert_allclose(dct_value.data.numpy(), golden_value, rtol=1e-6, atol=1e-5)

        # CPU: N-FFT, 2N-FFT, dct_lee, fft2
        cpu_transforms = (
            lambda: dct.IDCT2(algorithm='N'),
            lambda: dct.IDCT2(algorithm='2N'),
            dct_lee.IDCT2,
            lambda: dct2_fft2.IDCT2(expkM, expkN),
        )
        for make_transform in cpu_transforms:
            check(make_transform, y, "2D idct_value")

        if torch.cuda.device_count():
            # GPU: the same implementations, with inputs moved to the device
            y_cuda = y.cuda()
            gpu_transforms = (
                lambda: dct.IDCT2(algorithm='N'),
                lambda: dct.IDCT2(algorithm='2N'),
                dct_lee.IDCT2,
                lambda: dct2_fft2.IDCT2(expkM.cuda(), expkN.cuda()),
            )
            for make_transform in gpu_transforms:
                check(make_transform, y_cuda, "2D idct_value cuda")