コード例 #1
0
ファイル: test_flop_count.py プロジェクト: wenh06/fvcore
    def test_customized_ops(self) -> None:
        """
        Check that user-supplied operation handles are honoured.

        1. A handle for ``aten::sigmoid`` is passed in ``supported_ops``; the
           fixed count it reports must appear in the result.
        2. A handle is passed for ``aten::<self.lin_op>``, an operation that
           already has a default handle; the custom handle must replace the
           default one.
        """
        n_samples = 10
        n_in = 5
        n_out = 4
        net = CustomNet(n_in, n_out)
        failure_msg = "Customized operation handle failed to pass the flop count test."

        # Case 1: handle for an operation that has no default handle.
        def dummy_sigmoid_flop_jit(
            inputs: typing.List[Any], outputs: typing.List[Any]
        ) -> typing.Counter[str]:
            """
            Fake sigmoid handle that always reports 10000 flops, ignoring its
            arguments. For testing only; it is not a real flop count.
            """
            return Counter({"sigmoid": 10000})

        sigmoid_ops: Dict[str, Handle] = {"aten::sigmoid": dummy_sigmoid_flop_jit}
        inp = torch.rand(n_samples, n_in)
        sigmoid_counts, _ = flop_count(net, (inp,), supported_ops=sigmoid_ops)
        self.assertEqual(sigmoid_counts["sigmoid"], 10000 / 1e9, failure_msg)

        # Case 2: handle that overrides the default handle of the linear op.
        def addmm_dummy_flop_jit(
            inputs: typing.List[object], outputs: typing.List[object]
        ) -> typing.Counter[str]:
            """
            Fake fully-connected handle that always reports 400000 flops under
            ``self.lin_op``, ignoring its arguments. For testing only.
            """
            return Counter({self.lin_op: 400000})

        linear_ops: Dict[str, Handle] = {f"aten::{self.lin_op}": addmm_dummy_flop_jit}
        linear_counts, _ = flop_count(net, (inp,), supported_ops=linear_ops)
        self.assertEqual(linear_counts[self.lin_op], 400000 / 1e9, failure_msg)
コード例 #2
0
ファイル: test_flop_count.py プロジェクト: zbwxp/fvcore
    def test_whitelist(self) -> None:
        """
        Check flop counting restricted by a whitelist.

        First only ``aten::_convolution`` is whitelisted, so the linear layers
        of ThreeNet must not contribute and only convolution flops are
        reported. Then a whitelist naming an operation flop_count does not
        know must raise an exception.
        """
        n, c_in, c_conv, hw, c_lin = 9, 2, 5, 10, 3
        inp = torch.randn(n, c_in, hw, hw)
        net = ThreeNet(c_in, c_conv, c_lin)

        expected = defaultdict(float)
        expected["conv"] = n * c_in * c_conv * hw * hw / 1e9
        counts = flop_count(net, (inp,), ["aten::_convolution"])
        self.assertDictEqual(
            counts, expected, "Whitelist failed to pass the flop count test."
        )

        # A whitelist containing an unknown operation name must be rejected.
        with self.assertRaises(Exception):
            flop_count(net, (inp,), ["division"])
コード例 #3
0
    def test_batchnorm(self) -> None:
        """
        Check the batchnorm flop handle on BatchNorm1d, BatchNorm2d and
        BatchNorm3d layers (all with ``affine=False``). Every case expects
        4 flops per input element.
        """
        supported_ops: Dict[str, Handle] = {"aten::batch_norm": batchnorm_flop_jit}
        n, c, s = 10, 10, 5
        cases = (
            ("BatchNorm1d", nn.BatchNorm1d, (n, c)),
            ("BatchNorm2d", nn.BatchNorm2d, (n, c, s, s)),
            ("BatchNorm3d", nn.BatchNorm3d, (n, c, s, s, s)),
        )
        for label, norm_cls, shape in cases:
            layer = norm_cls(c, affine=False)
            inp = torch.randn(*shape)
            counts, _ = flop_count(layer, (inp,), supported_ops)
            expected = defaultdict(float)
            expected["batchnorm"] = 4 * inp.numel() / 1e9
            self.assertDictEqual(
                counts, expected, f"{label} failed to pass the flop count test."
            )
コード例 #4
0
ファイル: test_flop_count.py プロジェクト: zbwxp/fvcore
 def test_threeNet(self) -> None:
     """
     Check a multi-layer network: one convolution followed by two fully
     connected layers. Convolution flops are expected under "conv" and the
     two linear layers together under "addmm".
     """
     n, c_in, c_conv, hw, c_lin = 4, 2, 5, 10, 3
     inp = torch.randn(n, c_in, hw, hw)
     net = ThreeNet(c_in, c_conv, c_lin)

     conv_gflops = n * c_conv * c_in * hw * hw / 1e9
     fc1_gflops = n * c_conv * c_lin / 1e9
     fc2_gflops = n * c_lin * 1 / 1e9
     expected = defaultdict(float)
     expected["conv"] = conv_gflops
     expected["addmm"] = fc1_gflops + fc2_gflops

     counts = flop_count(net, (inp,))
     self.assertDictEqual(
         counts,
         expected,
         "The three-layer network failed to pass the flop count test.",
     )
コード例 #5
0
def calculate_model_info(model: torch.nn.Module,
                         image_size: int,
                         color_channels: int = 3):
    """
    Log and return basic complexity statistics for ``model``.

    Runs ``flop_count`` on a single random square image. It then logs the
    model, its parameter count, the value returned by ``gpu_mem_usage`` and
    the total flop count.

    Args:
        model: model to analyse. If it exposes a ``device`` attribute, the
            dummy input is placed there. Otherwise the device of its first
            parameter (or, failing that, its first buffer) is used. Models
            with neither fall back to the CPU.
        image_size: height and width of the dummy input image. The flop
            count depends on it.
        color_channels: number of channels of the dummy input image.

    Returns:
        dict with keys ``gpu_mem_usage_MB``, ``params_count`` and
        ``total_flops``. ``total_flops`` is the sum of the per-operator
        counts returned by ``flop_count``, logged here in "G" units.
    """
    # A plain torch.nn.Module has no ``device`` attribute; only some
    # subclasses define one. Fall back to where the model's tensors live.
    device = getattr(model, "device", None)
    if device is None:
        first_tensor = next(model.parameters(), None)
        if first_tensor is None:
            first_tensor = next(model.buffers(), None)
        device = first_tensor.device if first_tensor is not None else torch.device("cpu")

    example_batch_input = torch.rand(
        [1, color_channels, image_size, image_size]).to(device)
    flop_results = flop_count(model, (example_batch_input, ))

    gpu_mem = gpu_mem_usage()
    total_params = params_count(model)
    # Element 0 is the per-operator flop dict; element 1 lists skipped ops.
    total_flops = sum(flop_results[0].values())

    log.info("Model:\n{}".format(model))
    log.info("Params: {:,}".format(total_params))
    log.info("Mem: {:.3f} MB".format(gpu_mem))
    log.info("Flops: {:,} G".format(total_flops))

    return dict(gpu_mem_usage_MB=gpu_mem,
                params_count=total_params,
                total_flops=total_flops)
コード例 #6
0
    def test_matmul_broadcast(self) -> None:
        """
        Check the matmul flop count when the operands need broadcasting.

        Every case multiplies (m x n) matrices by (n x p) matrices. The
        expected count is m * n * p times the number of matrix products after
        broadcasting the leading dimensions.
        """
        m, n, p = 20, 10, 100
        net = MatmulNet()
        cases = (
            ((1, m, n), (1, n, p), 1),
            ((2, 2, m, n), (2, 2, n, p), 4),
            ((1, m, n), (n, p), 1),
            ((2, m, n), (n, p), 2),
        )
        for lhs_shape, rhs_shape, n_products in cases:
            lhs = torch.randn(*lhs_shape)
            rhs = torch.randn(*rhs_shape)
            counts, _ = flop_count(net, (lhs, rhs))
            expected = defaultdict(float)
            expected["matmul"] = n_products * m * n * p / 1e9
            self.assertDictEqual(
                counts,
                expected,
                "Matmul operation failed to pass the flop count test.",
            )
コード例 #7
0
        def _test_conv(
            conv_dim: int,
            batch_size: int,
            input_dim: int,
            output_dim: int,
            spatial_dim: int,
            kernel_size: int,
            padding: int,
            stride: int,
            group_size: int,
        ) -> None:
            """
            Build a ConvNet with the given hyper-parameters and run flop_count
            on a random input whose ``conv_dim`` spatial axes all have size
            ``spatial_dim``. Compare the result against the analytic count:

                batch * C_in * C_out * k**d * out**d / groups

            where ``out = (spatial + 2 * padding - kernel) // stride + 1``.
            """
            net = ConvNet(
                conv_dim, input_dim, output_dim, kernel_size,
                spatial_dim, stride, padding, group_size,
            )
            assert conv_dim in [1, 2, 3], "Convolution dimension needs to be 1, 2, or 3"
            # Anything other than 1 or 2 falls through to a 3-D input.
            n_spatial = conv_dim if conv_dim in (1, 2) else 3
            inp = torch.randn(batch_size, input_dim, *([spatial_dim] * n_spatial))

            counts, _ = flop_count(net, (inp,))
            out_size = (spatial_dim + 2 * padding - kernel_size) // stride + 1
            expected = defaultdict(float)
            expected["conv"] = (
                batch_size * input_dim * output_dim
                * (kernel_size ** conv_dim)
                * (out_size ** conv_dim)
                / group_size
                / 1e9
            )
            self.assertDictEqual(
                counts,
                expected,
                "Convolution layer failed to pass the flop count test.",
            )
コード例 #8
0
ファイル: test_flop_count.py プロジェクト: zbwxp/fvcore
    def test_einsum(self) -> None:
        """
        Check the flop count of torch.einsum for two equations,
        "nct,ncp->ntp" and "ntg,ncg->nct". In both cases the expected count is
        the product of the sizes of all distinct indices.
        """
        n, c, t, p, g = 1, 5, 2, 12, 6
        cases = (
            ("nct,ncp->ntp", (n, c, t), (n, c, p), n * t * p * c / 1e9),
            ("ntg,ncg->nct", (n, t, g), (n, c, g), n * t * g * c / 1e9),
        )
        for equation, lhs_shape, rhs_shape, gflops in cases:
            net = EinsumNet(equation)
            lhs = torch.randn(*lhs_shape)
            rhs = torch.randn(*rhs_shape)
            counts = flop_count(net, (lhs, rhs))
            expected = defaultdict(float)
            expected["einsum"] = gflops
            self.assertDictEqual(
                counts,
                expected,
                f"Einsum operation {equation} failed to pass the flop count test.",
            )
コード例 #9
0
def get_flop_stats(model, cfg, is_train):
    """
    Return the total gflops of ``model`` on one dummy clip built from ``cfg``.

    Args:
        model (model): model to compute the flop counts.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        is_train (bool): if True, use the training crop size; otherwise use
            the test crop size.

    Returns:
        float: the sum of the per-operator gflops reported by flop_count.
    """
    crop = cfg.DATA.TRAIN_CROP_SIZE if is_train else cfg.DATA.TEST_CROP_SIZE
    clip = torch.rand(3, cfg.DATA.NUM_FRAMES, crop, crop)

    # Give every pathway tensor a batch dimension of 1 and move it to the GPU.
    pathways = pack_pathway_output(cfg, clip)
    for idx, pathway in enumerate(pathways):
        pathways[idx] = pathway.unsqueeze(0).cuda(non_blocking=True)

    if cfg.MODEL.LSTM:
        # LSTM models also take a zero label history of shape
        # (1, 10, NUM_CLASSES[0] + NUM_CLASSES[1]).
        history_width = cfg.MODEL.NUM_CLASSES[0] + cfg.MODEL.NUM_CLASSES[1]
        history = torch.zeros([1, 10, history_width])
        history = history.cuda()
        model_inputs = ([pathways, history], )
    elif cfg.DETECTION.ENABLE:
        # With detection enabled, count flops for a single dummy proposal.
        proposal = torch.tensor([[0, 0, 1.0, 0, 1.0]])
        proposal = proposal.cuda()
        model_inputs = (pathways, proposal)
    else:
        model_inputs = (pathways, )

    per_op_gflops, _skipped = flop_count(model, model_inputs)
    return sum(per_op_gflops.values())
コード例 #10
0
def get_flop_stats(model, cfg, is_train):
    """
    Return the total gflops of ``model`` on one dummy clip built from ``cfg``.

    Only the whitelisted operations (addmm, convolution, einsum and matmul)
    are counted.

    Args:
        model (model): model to compute the flop counts.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        is_train (bool): if True, use the training crop size; otherwise use
            the test crop size.

    Returns:
        float: the sum of the per-operator gflops reported by flop_count.
    """
    crop_size = cfg.DATA.TRAIN_CROP_SIZE if is_train else cfg.DATA.TEST_CROP_SIZE
    frames = torch.rand(3, cfg.DATA.NUM_FRAMES, crop_size, crop_size)
    counted_ops = [
        "aten::addmm",
        "aten::_convolution",
        "aten::einsum",
        "aten::matmul",
    ]

    # Give every pathway tensor a batch dimension of 1 and move it to the GPU.
    batched = pack_pathway_output(cfg, frames)
    for k, tensor in enumerate(batched):
        batched[k] = tensor.unsqueeze(0).cuda(non_blocking=True)

    # With detection enabled, count flops for a single dummy proposal.
    if not cfg.DETECTION.ENABLE:
        model_inputs = (batched, )
    else:
        proposal = torch.tensor([[0, 0, 1.0, 0, 1.0]])
        proposal = proposal.cuda()
        model_inputs = (batched, proposal)

    per_op_gflops = flop_count(model, model_inputs, counted_ops)
    return sum(per_op_gflops.values())
コード例 #11
0
ファイル: test_flop_count.py プロジェクト: zbwxp/fvcore
 def test_nn(self) -> None:
     """
     Check flop counting on a stock nn.Linear module, used directly rather
     than wrapped in a custom network. Its flops are expected under "addmm".
     """
     n, d_in, d_out = 5, 8, 4
     inp = torch.randn(n, d_in)
     layer = nn.Linear(d_in, d_out)
     counts = flop_count(layer, (inp, ))
     expected = defaultdict(float)
     expected["addmm"] = n * d_in * d_out / 1e9
     self.assertDictEqual(
         counts, expected, "nn.Linear failed to pass the flop count test."
     )
コード例 #12
0
 def test_skip_ops(self) -> None:
     """
     Check the skipped-operations output of flop_count. For CustomNet it
     must be a Counter mapping "aten::sigmoid" to 1.
     """
     n, d_in, d_out = 10, 5, 4
     net = CustomNet(d_in, d_out)
     inp = torch.rand(n, d_in)
     _, skipped = flop_count(net, (inp,))
     expected = Counter({"aten::sigmoid": 1})
     self.assertDictEqual(
         skipped, expected, "Skipped operations failed to pass the flop count test."
     )
コード例 #13
0
    def test_linear(self) -> None:
        """
        Check a network made of a single fully connected layer. It is tested
        first with a 2-D input, then with a 3-D input when the layer lowers to
        ``aten::linear``.
        """
        n, d_in, d_out = 5, 10, 20
        net = LinearNet(d_in, d_out)
        msg = "Fully connected layer failed to pass the flop count test."

        inp = torch.randn(n, d_in)
        counts, _ = flop_count(net, (inp,))
        expected = defaultdict(float)
        expected[self.lin_op] = n * d_in * d_out / 1e9
        self.assertDictEqual(counts, expected, msg)

        # The >2-D input case only runs when nn.Linear uses aten::linear.
        # TODO: Stop skipping when multidimension aten::matmul
        # flop counting is implemented
        if self.lin_op == "linear":
            extra = 5
            inp = torch.randn(n, extra, d_in)
            counts, _ = flop_count(net, (inp,))
            expected = defaultdict(float)
            expected[self.lin_op] = n * d_in * extra * d_out / 1e9
            self.assertDictEqual(counts, expected, msg)
コード例 #14
0
def get_flop_stats(model, input_shape=(3, 16, 224, 224), device=None):
    """
    Compute the total gflops of ``model`` for one dummy input clip.

    Args:
        model (model): model to compute the flop counts.
        input_shape (tuple of int): shape of one un-batched input. A batch
            dimension of size 1 is prepended. Defaults to (3, 16, 224, 224),
            the shape this function previously hard-coded.
        device (torch.device, int or str, optional): device the dummy input
            is moved to. Defaults to the current CUDA device, as before.

    Returns:
        float: the total number of gflops of the given model.
    """
    if device is None:
        device = torch.cuda.current_device()

    # The model is called with a single-element list of batched tensors.
    flop_inputs = [torch.rand(*input_shape).unsqueeze(0).to(device)]

    gflop_dict, _ = flop_count(model, (flop_inputs, ))
    return sum(gflop_dict.values())
コード例 #15
0
ファイル: misc.py プロジェクト: BigFishMaster/SlowFast
def get_flop_stats(model, cfg, is_train):
    """
    Return the total gflops of ``model`` for the analysis input given by
    ``_get_model_analysis_input(cfg, is_train)``.

    Args:
        model (model): model to compute the flop counts.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        is_train (bool): if True, compute flops for training. Otherwise,
            compute flops for testing.

    Returns:
        float: the sum of the per-operator gflops reported by flop_count.
    """
    per_op_gflops, _skipped = flop_count(
        model, _get_model_analysis_input(cfg, is_train)
    )
    return sum(per_op_gflops.values())
コード例 #16
0
ファイル: test_flop_count.py プロジェクト: zbwxp/fvcore
 def test_linear(self) -> None:
     """
     Check a network with a single fully connected layer. Its flops are
     expected under "addmm".
     """
     n, d_in, d_out = 5, 10, 20
     net = LinearNet(d_in, d_out)
     inp = torch.randn(n, d_in)
     counts = flop_count(net, (inp, ))
     expected = defaultdict(float)
     expected["addmm"] = n * d_in * d_out / 1e9
     self.assertDictEqual(
         counts,
         expected,
         "Fully connected layer failed to pass the flop count test.",
     )
コード例 #17
0
 def test_bmm(self) -> None:
     """
     Check the flop count of torch.bmm for a batched product of an (n, c, t)
     tensor with an (n, t, p) tensor, i.e. the equation nct,ntp->ncp. The
     expected count is n * c * t * p, reported under "bmm".
     """
     n = 2
     c = 5
     t = 2
     p = 12
     eNet = BMMNet()
     x = torch.randn(n, c, t)
     y = torch.randn(n, t, p)
     flop_dict, _ = flop_count(eNet, (x, y))
     gt_flop = n * t * p * c / 1e9
     gt_dict = defaultdict(float)
     gt_dict["bmm"] = gt_flop
     self.assertDictEqual(
         flop_dict,
         gt_dict,
         # Name the equation actually tested (was copied from the einsum test).
         "bmm operation nct,ntp->ncp failed to pass the flop count test.",
     )
コード例 #18
0
def get_flop_stats(model, cfg, is_train):
    """
    Compute the gflops of ``model`` for one fixed-size dummy clip.

    NOTE(review): the input is hard-coded to 1 channel x 16 frames x 192 x 128.
    ``cfg`` and ``is_train`` are accepted for interface compatibility but are
    currently ignored.

    Args:
        model (model): model to compute the flop counts.
        cfg (CfgNode): configs (unused here).
        is_train (bool): unused here.

    Returns:
        float: the total number of gflops of the given model.
    """
    rgb_dimension = 1
    num_frames = 16
    height, width = 192, 128
    input_tensors = torch.rand(rgb_dimension, num_frames, height, width)

    # Wrap the clip in a list so the loop below walks pathways, not tensor
    # rows. Previously ``flop_inputs`` was the bare tensor: the loop iterated
    # its first (channel) axis, and each ``.unsqueeze(0).cuda()`` result was
    # copied back into the CPU tensor slot. So neither the batch dimension
    # nor the GPU transfer took effect.
    # NOTE(review): confirm the model expects a list of (1, C, T, H, W)
    # tensors, like the pack_pathway_output-based variants.
    flop_inputs = [input_tensors]
    for i, pathway in enumerate(flop_inputs):
        flop_inputs[i] = pathway.unsqueeze(0).cuda(non_blocking=True)

    inputs = (flop_inputs, )

    gflop_dict, _ = flop_count(model, inputs)
    gflops = sum(gflop_dict.values())
    return gflops
コード例 #19
0
        return x


if __name__ == "__main__":
    # Build the model, then print the complexity numbers reported by several
    # tools (summary, flop_count, get_model_complexity_info, profile).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DF1SegX38(classes=19).to(device)
    summary(model, (3, 352, 480))
    x = torch.randn(2, 3, 512, 1024).to(device)

    from fvcore.nn.jit_handles import batchnorm_flop_jit
    from fvcore.nn.jit_handles import generic_activation_jit

    # Extra/overriding handle passed to flop_count for aten::batch_norm.
    supported_ops = {
        "aten::batch_norm": batchnorm_flop_jit,
    }
    flop_dict, _ = flop_count(model, (x, ), supported_ops)

    # Distinct names so the builtin `input` and any `params_count` helper
    # are not shadowed at module level.
    complexity_flops, complexity_params = get_model_complexity_info(
        model, (3, 512, 1024), as_strings=False, print_per_layer_stat=True)
    macs, params = profile(model, inputs=(x, ))
    print(flop_dict)
    print(complexity_flops, complexity_params)
    print(macs, params)
'''
/home/ethan/anaconda3/envs/py36_cuda101/bin/python /home/ethan/codes/Efficient-Segmentation-Networks/model/DFSegX16.py
/home/ethan/anaconda3/envs/py36_cuda101/lib/python3.6/site-packages/torch/nn/functional.py:2941: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/home/ethan/anaconda3/envs/py36_cuda101/lib/python3.6/site-packages/torch/nn/functional.py:3121: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))
----------------------------------------------------------------
コード例 #20
0
def main(args):
    """
    Train a detection model, evaluating it after every epoch.

    Builds the train/test datasets and loaders via ``get_dataset``. Creates
    the model with ``get_model``, wrapped in DistributedDataParallel when
    ``args.distributed`` is set. Trains with SGD and a StepLR schedule.

    If ``args.resume`` is set, the model, optimizer and scheduler state are
    restored from that checkpoint first. With ``args.test_only``, only VOC
    evaluation is run.

    Otherwise the run proceeds as follows:
    - print the parameter count and the fvcore flop counts for one dummy
      3x500x353 image;
    - train for ``args.epochs`` epochs;
    - write a checkpoint to ``args.output_dir`` every 9th epoch (0, 9, 18, ...);
    - write a final ``model_final.pth`` at the end.

    Args:
        args: parsed command-line namespace. Attributes read here include
            device, dataset, data_path, distributed, gpu, batch_size,
            workers, aspect_ratio_group_factor, lr, momentum, weight_decay,
            lr_step_size, lr_gamma, resume, test_only, epochs, print_freq
            and output_dir.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    # Split names: 'train'/'val' for coco, 'trainval'/'test' otherwise.
    dataset, num_classes = get_dataset(
        args.dataset, "train" if args.dataset == 'coco' else 'trainval',
        get_transform(train=True), args.data_path)
    dataset_test, _ = get_dataset(args.dataset,
                                  "val" if args.dataset == 'coco' else 'test',
                                  get_transform(train=False), args.data_path)

    print("Creating data loaderssss")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                  args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.batch_size,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    print("Creating model")
    model = get_model(num_classes=num_classes)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    if args.test_only:
        voc_evaluate(model, data_loader_test, device=device)
        return
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("model params", pytorch_total_params)

    # Flop count on one dummy image. Create it on the configured device
    # (previously hard-coded to 'cuda', which broke CPU runs). Trace the
    # unwrapped module rather than the DistributedDataParallel wrapper.
    dummy_image = torch.randn(1, 3, 500, 353, device=device)
    model.eval()
    gflops_per_op, _skipped_ops = flop_count(model_without_ddp, (dummy_image, ))
    print("macs", gflops_per_op.items())
    # Restore training mode after the eval-mode flop count; this is harmless
    # if train_one_epoch also sets it.
    model.train()

    print("Start training")
    start_time = time.time()
    best_map = 0
    for epoch in range(args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        if args.output_dir and epoch % 9 == 0:
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'args': args
                }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

        # evaluate after every epoch
        if 'coco' in args.dataset:
            coco_evaluate(model, data_loader_test, device=device)
        elif 'voc' in args.dataset:
            # `mean_ap` rather than `map` to avoid shadowing the builtin.
            mean_ap = voc_evaluate(model, data_loader_test, device=device)
            if mean_ap > best_map:
                best_map = mean_ap
            print("Best Mean AP")
            print(best_map)
        else:
            print(
                f'No evaluation method available for the dataset {args.dataset}'
            )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    if args.output_dir:
        utils.save_on_master(
            {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args
            }, os.path.join(args.output_dir, 'model_{}.pth'.format("final")))