Example #1
0
    def decompose_fc_layer(self, layer, k, d, tt_ranks, ins=None, outs=None):
        """Decompose a fully-connected layer's weight matrix.

        Args:
            layer: source linear layer; only its ``weight`` tensor is read.
            k: index used in the filename of the saved TT-cores.
            d: number of modes the 2-D weight matrix is folded into (TT only).
            tt_ranks: target TT-ranks for the decomposition (TT only).
            ins: input mode sizes; derived from the folded tensor when None.
            outs: output mode sizes; derived from the folded tensor when None.

        Returns:
            For ``self.factorization == 'tt'``: a one-element list holding the
            replacement ``TTLayer``.  The 'svd' branch only logs a compression
            report and implicitly returns None.

        Raises:
            ValueError: if ``self.factorization`` is neither 'tt' nor 'svd'.
        """
        weights = np.array(layer.weight.data)
        if self.verbose:
            logger.info(f'before {weights.shape}')
            logger.info(f'input shape {layer.in_features}')

        if self.factorization == 'tt':
            # Fold the 2-D weight matrix into a d-dimensional tensor.
            weight_tensor = Tensor(weights, from_matrix=True, d=d)
            if ins is None:
                ins = weight_tensor.ns
                outs = weight_tensor.ms

            Gs = weight_tensor.tt_with_ranks(tt_ranks)

            if self.verbose:
                # Parameter counts before/after compression (np.prod replaces
                # a manual loop that shadowed the builtin `sum`).
                n_original = int(np.prod(weight_tensor.T.shape))
                logger.info(f'original parameters: {n_original}')
                # BUG FIX: previously logged len(Gs) — the number of TT-cores,
                # not the parameter count.
                n_tt = sum(int(np.prod(np.asarray(core).shape)) for core in Gs)
                logger.info(f'tt parameters: {n_tt}')

            np.save(f'data/tt_fc_{k}_alexnet_cores.npy', Gs)
            tt_layer = TTLayer(in_features=ins,
                               out_features=outs,
                               tt_ranks=tt_ranks)
            return [tt_layer]

        elif self.factorization == 'svd':
            U, S, Vt = randomized_svd(weights,
                                      n_components=2048,
                                      random_state=None)
            # BUG FIX: logger.info(U.shape, S.shape, Vt.shape) passed extra
            # positional args with no %-placeholders, which raises a
            # formatting error when the record is emitted.
            logger.info(f'{U.shape} {S.shape} {Vt.shape}')
            US = U @ np.diag(S)
            # Low-rank reconstruction, used only for the relative-error report.
            w_ap = US @ Vt
            logger.info(
                f'original parameters {weights.shape[0] * weights.shape[1]}')
            logger.info(
                f'new parameters {US.shape[0] * US.shape[1] + Vt.shape[0] * Vt.shape[1]}'
            )
            logger.info(
                f'error {np.linalg.norm(weights - w_ap) / np.linalg.norm(weights)}'
            )

        else:
            raise ValueError('Not supported decomposition for this layer ')
Example #2
0
def test_tt_layer_time():
    """Benchmark TT-decomposed layers against a dense 4096x4096 linear layer.

    For d in {4, 6, 8, 10} and interior TT-ranks 4 and 8, times a single
    forward pass (after two warm-up passes) and plots the timings in
    milliseconds against the dense-layer baseline.
    """
    weights = np.random.rand(4096, 4096).astype(np.float32)
    data = np.random.rand(1, 4096).astype(np.float32)
    # BUG FIX: torch.Tensor((1, 4096)) creates a 2-element tensor holding the
    # VALUES 1 and 4096 (not a (1, 4096) tensor); it only worked because its
    # .data was immediately overwritten.  Build the input directly instead.
    # (Also renamed from `input`, which shadowed the builtin.)
    x = torch.from_numpy(data)

    ds = [4, 6, 8, 10]
    # TT-rank vectors [1, r, ..., r, 1] with d - 1 interior ranks, matching
    # the hard-coded lists they replace; rank 4 first, then rank 8.
    rank_sets = [[[1] + [r] * (d - 1) + [1] for d in ds] for r in (4, 8)]
    times_4 = []
    times_8 = []
    times = [times_4, times_8]

    print('linear')
    layer = LinearLayer(4096, 4096)
    layer.weight.data = torch.from_numpy(weights)
    # Two warm-up passes, then one timed pass.
    layer.forward(x)
    layer.forward(x)
    start = time.time()
    layer.forward(x)
    or_t = time.time() - start
    print('===========')
    baseline = [or_t * 1000] * len(ds)

    for j, ranks in enumerate(rank_sets):
        for i, d in enumerate(ds):
            weight_tensor = Tensor(weights, from_matrix=True, d=d)
            tt_ranks = ranks[i]
            print('d=', d)
            print('tt_ranks', tt_ranks)

            # Removed dead code: an unused torch.Tensor copy of `weights` and
            # an unused element-count loop that shadowed the builtin `sum`.
            Gs = weight_tensor.tt_with_ranks(tt_ranks)
            # NOTE(review): len(Gs) is the number of TT-cores, not a parameter
            # count; message kept as-is to preserve log output.
            logger.info(f'tt parameters: {len(Gs)}')

            np.save('../CNNs/data/tt_fc_4_alexnet_cores.npy', Gs)

            tt_layer = TTLayer(in_features=weight_tensor.ns,
                               out_features=weight_tensor.ms,
                               tt_ranks=tt_ranks)

            # Two warm-up passes, then one timed pass (milliseconds).
            tt_layer.forward(x)
            tt_layer.forward(x)
            start = time.time()
            tt_layer.forward(x)
            times[j].append((time.time() - start) * 1000)
            print('===========')

    amber = sns.xkcd_rgb["amber"]
    red = sns.xkcd_rgb["dusty red"]
    green = sns.xkcd_rgb["medium green"]

    # One line plus per-point markers for each series.
    for series, color in ((times_4, amber), (times_8, red), (baseline, green)):
        plt.plot(ds, series, color, label='ap', linewidth=3)
        for xv, yv in zip(ds, series):
            plt.plot([xv], [yv], 'o', color=color)

    p1 = mpatches.Patch(color=amber, label='rk = 4')
    p2 = mpatches.Patch(color=red, label='rk = 8')
    p3 = mpatches.Patch(color=green,
                        label='Исходный слой')

    plt.xlabel('TT-ранг')
    plt.ylabel('Время работы, мс')

    plt.legend(handles=[p1, p2, p3])
    plt.show()