# Code example #1
0
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================

import pytest
import unittest
import copy
import time

import torch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_common.utils import AimetLogger
import aimet_torch.examples.mnist_torch_model as mnist_model

# Module-level logger for this test module, obtained from AIMET's logging
# helper for the "Test" log area.
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Test)


def forward_pass(model, args):
    """Run the model once on a seeded random batch of shape (10, 1, 28, 28).

    The ``(model, args)`` signature matches a callback that receives the model
    plus one extra argument; ``args`` is accepted but ignored here.

    Side effect: reseeds PyTorch's global RNG with ``torch.manual_seed(1)``.

    :param model: torch.nn.Module with at least one parameter. The input is
        moved to the device of its first parameter. A parameterless model
        raises ``StopIteration``.
    :param args: Unused.
    :return: None. The model output is discarded.
    """
    torch.manual_seed(1)
    target_device = next(model.parameters()).device

    # Sample on the CPU first and only then transfer. With the same seed, this
    # yields the same values whether the model lives on CPU or GPU, which
    # sampling directly on the device would not guarantee.
    batch = torch.randn((10, 1, 28, 28))
    model(batch.to(target_device))


class QuantizerCpuGpu(unittest.TestCase):
    @pytest.mark.cuda
    def test_and_compare_quantizer_no_fine_tuning_CPU_and_GPU(self):

        torch.manual_seed(1)