Example #1
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FullyConnectedHead":
        """Instantiates a FullyConnectedHead from a configuration.

        Args:
            config: A configuration for a FullyConnectedHead.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A FullyConnectedHead instance.
        """
        num_classes = config.get("num_classes", None)
        in_plane = config["in_plane"]
        silu = None if get_torch_version() < [1, 7] else nn.SiLU()
        activation = {"relu": nn.ReLU(RELU_IN_PLACE), "silu": silu}[
            config.get("activation", "relu")
        ]
        if activation is None:
            raise RuntimeError("SiLU activation is only supported since PyTorch 1.7")
        return cls(
            config["unique_id"],
            num_classes,
            in_plane,
            conv_planes=config.get("conv_planes", None),
            activation=activation,
            zero_init_bias=config.get("zero_init_bias", False),
            normalize_inputs=config.get("normalize_inputs", None),
        )
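
For reference, here is a minimal usage sketch of the factory above. The keys are the ones read by from_config; the concrete values and the import path (the standard Classy Vision package layout) are illustrative assumptions.

from classy_vision.heads import FullyConnectedHead

# Illustrative config: "unique_id" and "in_plane" are required, the rest fall
# back to the defaults read via config.get(...) above.
head_config = {
    "unique_id": "default_head",
    "num_classes": 1000,
    "in_plane": 2048,
    "activation": "relu",
}
head = FullyConnectedHead.from_config(head_config)
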
    def _test_quantize_model(self, model_config):
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        # quantize model
        model = build_model(model_config)
        model.eval()

        input = torch.ones([1, 3, 32, 32])

        heads = model.get_heads()
        # since prepare_fx changes the code of ClassyBlock we need to clear the
        # heads first and reattach them later to avoid caching
        model.clear_heads()

        prepare_custom_config_dict = {}
        head_path_from_blocks = [
            _find_block_full_path(model.features, block_name)
            for block_name in heads.keys()
        ]
        # we need to keep the modules used in the heads standalone since they
        # will be accessed directly by path name during execution
        prepare_custom_config_dict["standalone_module_name"] = [(
            head,
            {
                "": tq.default_qconfig
            },
            {
                "input_quantized_idxs": [0],
                "output_quantized_idxs": []
            },
            None,
        ) for head in head_path_from_blocks]
        model.initial_block = prepare_fx(model.initial_block,
                                         {"": tq.default_qconfig})
        model.features = prepare_fx(
            model.features,
            {"": tq.default_qconfig},
            prepare_custom_config_dict,
        )
        model.set_heads(heads)

        # calibration
        model(input)

        heads = model.get_heads()
        model.clear_heads()
        model.initial_block = convert_fx(model.initial_block)
        model.features = convert_fx(model.features)
        model.set_heads(heads)

        output = model(input)
        self.assertEqual(output.size(), (1, 1000))
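
Stripped of the head bookkeeping, the prepare -> calibrate -> convert flow used above looks roughly like this on a plain module. This is a sketch assuming the dict-based qconfig API shown in this test (PyTorch in the 1.8-1.10 range); newer releases changed the prepare_fx signature.

import torch
import torch.nn as nn
import torch.quantization as tq
from torch.quantization.quantize_fx import convert_fx, prepare_fx

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1)).eval()

# 1. insert observers
prepared = prepare_fx(model, {"": tq.default_qconfig})
# 2. calibrate with representative data
prepared(torch.ones([1, 3, 32, 32]))
# 3. swap observed modules for quantized ones
quantized = convert_fx(prepared)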
Example #3
    def test_quantize_model(self, config):
        """
        Test that the model builds from a config using either model_params or
        model_name, and that the FX graph mode quantization APIs run on it.
        """
        if get_torch_version() < [1, 8]:
            self.skipTest(
                "FX Graph Modee Quantization is only availablee from 1.8")
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        model = build_model(config)
        assert isinstance(model, RegNet)
        model.eval()
        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig})
        model.stem = convert_fx(model.stem)
Example #4
    def test_build_model(self, config):
        """
        Test that the model builds from a config using either model_params or
        model_name.
        """
        if get_torch_version() < [1, 7] and (
            "regnet_z" in config["name"] or config.get("activation", "relu") == "silu"
        ):
            self.skipTest("SiLU activation is only supported since PyTorch 1.7")
        model = build_model(config)
        assert isinstance(model, RegNet)
Example #5
    def test_fully_connected_head_normalize_inputs(self):
        batch_size = 2
        in_plane = 3
        image_size = 4
        head = FullyConnectedHead(
            "default_head",
            in_plane=in_plane,
            normalize_inputs="l2",
            num_classes=None,
        )
        input = torch.rand([batch_size, in_plane, image_size, image_size])
        output = head(input)
        self.assertEqual(output.shape, torch.Size([batch_size, in_plane]))
        for i in range(batch_size):
            output_i = output[i]
            self.assertAlmostEqual(output_i.norm().item(), 1, delta=1e-5)

        # test that the grads are the same when using normalize_inputs as when
        # normalizing the input manually and passing it to a head without
        # normalize_inputs. use an input with a norm > 1 and image_size = 1 so
        # that average pooling is a no-op
        image_size = 1
        input = 2 + torch.rand([batch_size, in_plane, image_size, image_size])
        norm_func = (
            torch.linalg.norm
            if get_torch_version() >= [1, 7]
            else partial(torch.norm, p=2)
        )
        norms = norm_func(input, dim=[1, 2, 3])
        normalized_input = torch.clone(input)
        for i in range(batch_size):
            normalized_input[i] /= norms[i]
        num_classes = 10
        head_norm = FullyConnectedHead(
            "default_head",
            in_plane=in_plane,
            normalize_inputs="l2",
            num_classes=num_classes,
        )
        head_no_norm = FullyConnectedHead(
            "default_head",
            in_plane=in_plane,
            num_classes=num_classes,
        )
        # share the weights between the heads
        head_norm.load_state_dict(copy.deepcopy(head_no_norm.state_dict()))

        # use the sum of the output as the loss and perform a backward
        head_no_norm(normalized_input).sum().backward()
        head_norm(input).sum().backward()

        for param_1, param_2 in zip(head_norm.parameters(), head_no_norm.parameters()):
            self.assertTorchAllClose(param_1, param_2)
            self.assertTorchAllClose(param_1.grad, param_2.grad)
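
As a side note, the per-sample normalization loop above is equivalent to a single broadcasted division. A minimal self-contained sketch (the shapes mirror the test, where image_size = 1):

import torch

batch_size, in_plane = 2, 3
input = 2 + torch.rand([batch_size, in_plane, 1, 1])
# per-sample L2 norms over all non-batch dims, shaped for broadcasting
norms = input.flatten(1).norm(dim=1).view(-1, 1, 1, 1)
normalized_input = input / norms  # same result as the per-sample loop above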
Example #6
def lecun_normal_init(tensor, fan_in):
    if get_torch_version() >= "1.7":
        trunc_normal_ = nn.init.trunc_normal_
    else:

        def trunc_normal_(
            tensor: Tensor,
            mean: float = 0.0,
            std: float = 1.0,
            a: float = -2.0,
            b: float = 2.0,
        ) -> Tensor:
            # code copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
            # commit: e9b369c

            # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
            def norm_cdf(x):
                # Computes standard normal cumulative distribution function
                return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

            if (mean < a - 2 * std) or (mean > b + 2 * std):
                warnings.warn(
                    "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                    "The distribution of values may be incorrect.",
                    stacklevel=2,
                )

            with torch.no_grad():
                # Values are generated by using a truncated uniform distribution and
                # then using the inverse CDF for the normal distribution.
                # Get upper and lower cdf values
                l = norm_cdf((a - mean) / std)
                u = norm_cdf((b - mean) / std)

                # Uniformly fill tensor with values from [l, u], then translate to
                # [2l-1, 2u-1].
                tensor.uniform_(2 * l - 1, 2 * u - 1)

                # Use inverse cdf transform for normal distribution to get truncated
                # standard normal
                tensor.erfinv_()

                # Transform to proper mean, std
                tensor.mul_(std * math.sqrt(2.0))
                tensor.add_(mean)

                # Clamp to ensure it's in the proper range
                tensor.clamp_(min=a, max=b)
                return tensor

    trunc_normal_(tensor, std=math.sqrt(1 / fan_in))
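
A small usage sketch of the helper above (an assumption for illustration: for a plain nn.Linear layer the fan-in is simply in_features):

import torch.nn as nn

linear = nn.Linear(in_features=256, out_features=10)
# LeCun-normal init: truncated normal with std = sqrt(1 / fan_in)
lecun_normal_init(linear.weight, fan_in=linear.in_features)
nn.init.zeros_(linear.bias)
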
    def test_get_torch_version(self, mock_torch: mock.MagicMock):
        mock_torch.__version__ = "1.7.2"
        self.assertEqual(get_torch_version(), [1, 7])
        self.assertLess(get_torch_version(), [1, 8])
        self.assertGreater(get_torch_version(), [1, 6])
        mock_torch.__version__ = "1.11.2a"
        self.assertEqual(get_torch_version(), [1, 11])
        self.assertLess(get_torch_version(), [1, 13])
        self.assertGreater(get_torch_version(), [1, 8])
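
For context, an implementation consistent with the assertions above just parses the first two numeric components of torch.__version__ into a list of ints. This is a sketch, not necessarily the library's exact code:

import re

import torch


def get_torch_version():
    # "1.7.2" -> [1, 7], "1.11.2a" -> [1, 11]
    return [int(part) for part in re.findall(r"\d+", torch.__version__)[:2]]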
Example #8
class TestRegNet(unittest.TestCase):
    def _compare_models(self, model_1, model_2, expect_same: bool):
        if expect_same:
            self.assertMultiLineEqual(repr(model_1), repr(model_2))
        else:
            self.assertNotEqual(repr(model_1), repr(model_2))

    def swap_relu_with_silu(self, module):
        for child_name, child in module.named_children():
            if isinstance(child, nn.ReLU):
                setattr(module, child_name, nn.SiLU())
            else:
                self.swap_relu_with_silu(child)

    def _check_no_module_cls_in_model(self, module_cls, model):
        for module in model.modules():
            self.assertNotIsInstance(module, module_cls)

    @unittest.skipIf(
        get_torch_version() < [1, 7],
        "SiLU activation is only supported since PyTorch 1.7",
    )
    def test_activation(self):
        config = REGNET_TEST_CONFIGS[0][0]
        model_default = build_model(config)
        config = copy.deepcopy(config)
        config["activation"] = "relu"
        model_relu = build_model(config)

        # both models should be the same
        self._compare_models(model_default, model_relu, expect_same=True)

        # we don't expect any silus in the model
        self._check_no_module_cls_in_model(nn.SiLU, model_relu)

        config["activation"] = "silu"
        model_silu = build_model(config)

        # the models should be different
        self._compare_models(model_silu, model_relu, expect_same=False)

        # swap out all relus with silus
        self.swap_relu_with_silu(model_relu)
        print(model_relu)
        # both models should be the same
        self._compare_models(model_relu, model_silu, expect_same=True)

        # we don't expect any relus in the model
        self._check_no_module_cls_in_model(nn.ReLU, model_relu)
Example #9
    def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
        # get the activation
        silu = None if get_torch_version() < [1, 7] else nn.SiLU()
        activation = {
            ActivationType.RELU: nn.ReLU(params.relu_in_place),
            ActivationType.SILU: silu,
        }[params.activation]

        # create stem
        stem = {
            StemType.RES_STEM_CIFAR: ResStemCifar,
            StemType.RES_STEM_IN: ResStemIN,
            StemType.SIMPLE_STEM_IN: SimpleStemIN,
        }[params.stem_type](3, params.stem_width, params.bn_epsilon,
                            params.bn_momentum, activation)
        init_weights(stem)
        return stem
Example #10
    def create_block(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        params: Union[RegNetParams, AnyNetParams],
        bottleneck_multiplier: float,
        group_width: int = 1,
    ):
        # get the block constructor function to use
        block_constructor = {
            BlockType.VANILLA_BLOCK: VanillaBlock,
            BlockType.RES_BASIC_BLOCK: ResBasicBlock,
            BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
            BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
        }[params.block_type]

        # get the activation module
        silu = None if get_torch_version() < [1, 7] else nn.SiLU()
        activation = {
            ActivationType.RELU: nn.ReLU(params.relu_in_place),
            ActivationType.SILU: silu,
        }[params.activation]

        block = block_constructor(
            width_in,
            width_out,
            stride,
            params.bn_epsilon,
            params.bn_momentum,
            activation,
            group_width,
            bottleneck_multiplier,
            params.se_ratio,
        ).cuda()
        with set_torch_seed(self.seed):
            init_weights(block)
            self.seed += 1
        return block
class TestDensenet(unittest.TestCase):
    def _test_model(self, model_config):
        """This test will build Densenet models, run a forward pass and
        verify output shape, and then verify that get / set state
        works.

        We do this in one test so that we construct the model a minimum
        number of times.
        """
        model = build_model(model_config)

        # Verify forward pass works
        input = torch.ones([1, 3, 32, 32])
        output = model.forward(input)
        self.assertEqual(output.size(), (1, 1000))

        # Verify get_set_state
        new_model = build_model(model_config)
        state = model.get_classy_state()
        new_model.set_classy_state(state)
        new_state = new_model.get_classy_state()

        compare_model_state(self, state, new_state, check_heads=True)

    def _test_quantize_model(self, model_config):
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        # quantize model
        model = build_model(model_config)
        model.eval()

        input = torch.ones([1, 3, 32, 32])

        heads = model.get_heads()
        # since prepare_fx changes the code of ClassyBlock we need to clear the
        # heads first and reattach them later to avoid caching
        model.clear_heads()

        prepare_custom_config_dict = {}
        head_path_from_blocks = [
            _find_block_full_path(model.features, block_name)
            for block_name in heads.keys()
        ]
        # we need to keep the modules used in the heads standalone since they
        # will be accessed directly by path name during execution
        prepare_custom_config_dict["standalone_module_name"] = [(
            head,
            {
                "": tq.default_qconfig
            },
            {
                "input_quantized_idxs": [0],
                "output_quantized_idxs": []
            },
            None,
        ) for head in head_path_from_blocks]
        model.initial_block = prepare_fx(model.initial_block,
                                         {"": tq.default_qconfig})
        model.features = prepare_fx(
            model.features,
            {"": tq.default_qconfig},
            prepare_custom_config_dict,
        )
        model.set_heads(heads)

        # calibration
        model(input)

        heads = model.get_heads()
        model.clear_heads()
        model.initial_block = convert_fx(model.initial_block)
        model.features = convert_fx(model.features)
        model.set_heads(heads)

        output = model(input)
        self.assertEqual(output.size(), (1, 1000))

    def test_small_densenet(self):
        self._test_model(MODELS["small_densenet"])

    @unittest.skipIf(
        get_torch_version() < [1, 8],
        "FX Graph Modee Quantization is only availablee from 1.8",
    )
    def test_quantized_small_densenet(self):
        self._test_quantize_model(MODELS["small_densenet"])
Example #12
    def __init__(
        self,
        input_filters: int,
        output_filters: int,
        expand_ratio: float,
        kernel_size: int,
        stride: int,
        se_ratio: float,
        id_skip: bool,
        use_se: bool,
        bn_momentum: float,
        bn_epsilon: float,
    ):
        assert se_ratio is None or (0 < se_ratio <= 1)
        super().__init__()
        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon
        self.has_se = use_se and se_ratio is not None
        self.se_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.id_skip = id_skip
        self.expand_ratio = expand_ratio
        self.stride = stride
        self.input_filters = input_filters
        self.output_filters = output_filters

        self.relu_fn = swish if get_torch_version() < [1, 7] else nn.SiLU()

        # used to track the depth of the block
        self.depth = 0

        # Expansion phase
        expanded_filters = input_filters * expand_ratio
        if expand_ratio != 1:
            self.expand_conv = nn.Conv2d(
                in_channels=input_filters,
                out_channels=expanded_filters,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.bn0 = nn.BatchNorm2d(
                num_features=expanded_filters,
                momentum=self.bn_momentum,
                eps=self.bn_epsilon,
            )
            self.depth += 1

        # Depthwise convolution phase
        self.depthwise_conv = nn.Conv2d(
            in_channels=expanded_filters,
            out_channels=expanded_filters,
            groups=expanded_filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=get_same_padding_for_kernel_size(kernel_size),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(
            num_features=expanded_filters,
            momentum=self.bn_momentum,
            eps=self.bn_epsilon,
        )
        self.depth += 1

        if self.has_se:
            # Squeeze and Excitation layer
            num_reduced_filters = max(1, int(input_filters * se_ratio))
            self.se_reduce = nn.Conv2d(
                in_channels=expanded_filters,
                out_channels=num_reduced_filters,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            )
            self.se_expand = nn.Conv2d(
                in_channels=num_reduced_filters,
                out_channels=expanded_filters,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            )
            self.depth += 2

        # Output phase
        self.project_conv = nn.Conv2d(
            in_channels=expanded_filters,
            out_channels=output_filters,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(
            num_features=output_filters, momentum=self.bn_momentum, eps=self.bn_epsilon
        )
        self.depth += 1
Example #13
    def __init__(self, params: AnyNetParams):
        super().__init__()

        silu = None if get_torch_version() < [1, 7] else nn.SiLU()
        activation = {
            ActivationType.RELU: nn.ReLU(params.relu_in_place),
            ActivationType.SILU: silu,
        }[params.activation]

        if activation is None:
            raise RuntimeError("SiLU activation is only supported since PyTorch 1.7")

        # Ad hoc stem
        self.stem = {
            StemType.RES_STEM_CIFAR: ResStemCifar,
            StemType.RES_STEM_IN: ResStemIN,
            StemType.SIMPLE_STEM_IN: SimpleStemIN,
        }[params.stem_type](
            3,
            params.stem_width,
            params.bn_epsilon,
            params.bn_momentum,
            activation,
        )

        # Instantiate all the AnyNet blocks in the trunk
        block_fun = {
            BlockType.VANILLA_BLOCK: VanillaBlock,
            BlockType.RES_BASIC_BLOCK: ResBasicBlock,
            BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
            BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
        }[params.block_type]

        current_width = params.stem_width

        self.trunk_depth = 0

        blocks = []
        for i, (
            width_out,
            stride,
            depth,
            group_width,
            bottleneck_multiplier,
        ) in enumerate(params.get_expanded_params()):
            blocks.append(
                (
                    f"block{i+1}",
                    AnyStage(
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_fun,
                        activation,
                        group_width,
                        bottleneck_multiplier,
                        params,
                        stage_index=i + 1,
                    ),
                )
            )

            self.trunk_depth += blocks[-1][1].stage_depth

            current_width = width_out

        self.trunk_output = nn.Sequential(OrderedDict(blocks))

        # Init weights and good to go
        self.init_weights()
class TestResnext(unittest.TestCase):
    def _test_model(self, model_config):
        """This test will build ResNeXt-* models, run a forward pass and
        verify output shape, and then verify that get / set state
        works.

        We do this in one test so that we construct the model a minimum
        number of times.
        """
        model = build_model(model_config)

        # Verify forward pass works
        input = torch.ones([1, 3, 32, 32])
        output = model.forward(input)
        self.assertEqual(output.size(), (1, 1000))

        # Verify get_set_state
        new_model = build_model(model_config)
        state = model.get_classy_state()
        new_model.set_classy_state(state)
        new_state = new_model.get_classy_state()

        compare_model_state(self, state, new_state, check_heads=True)

    def _test_quantize_model(self, model_config):
        """This test will build ResNeXt-* models, quantize the model
        with fx graph mode quantization, run a forward pass and
        verify output shape, and then verify that get / set state
        works.
        """
        model = build_model(model_config)
        # Verify forward pass works
        input = torch.ones([1, 3, 32, 32])
        output = model.forward(input)
        self.assertEqual(output.size(), (1, 1000))

        model = _post_training_quantize(model, input)

        # Verify forward pass works
        input = torch.ones([1, 3, 32, 32])
        output = model.forward(input)
        self.assertEqual(output.size(), (1, 1000))

        # Verify get_set_state
        new_model = build_model(model_config)
        new_model = _post_training_quantize(new_model, input)
        state = model.get_classy_state()
        new_model.set_classy_state(state)
        # TODO: test get state for new_model and make sure it is the same as
        # state. Currently allclose is not supported for quantized tensors,
        # so we can't check this right now

    def test_build_preset_model(self):
        configs = [
            {
                "name": "resnet18",
                "use_se": True
            },
            {
                "name":
                "resnet50",
                "heads": [{
                    "name": "fully_connected",
                    "unique_id": "default_head",
                    "num_classes": 1000,
                    "fork_block": "block3-2",
                    "in_plane": 2048,
                }],
            },
            {
                "name":
                "resnext50_32x4d",
                "heads": [{
                    "name": "fully_connected",
                    "unique_id": "default_head",
                    "num_classes": 1000,
                    "fork_block": "block3-2",
                    "in_plane": 2048,
                }],
            },
        ]
        for config in configs:
            model = build_model(config)
            self.assertIsInstance(model, ResNeXt)

    def test_small_resnext(self):
        self._test_model(MODELS["small_resnext"])

    @unittest.skipIf(
        get_torch_version() < [1, 8],
        "FX Graph Modee Quantization is only availablee from 1.8",
    )
    def test_quantized_small_resnext(self):
        self._test_quantize_model(MODELS["small_resnext"])

    def test_small_resnet(self):
        self._test_model(MODELS["small_resnet"])

    @unittest.skipIf(
        get_torch_version() < [1, 8],
        "FX Graph Modee Quantization is only availablee from 1.8",
    )
    def test_quantized_small_resnet(self):
        self._test_quantize_model(MODELS["small_resnet"])

    def test_small_resnet_se(self):
        self._test_model(MODELS["small_resnet_se"])

    @unittest.skipIf(
        get_torch_version() < [1, 8],
        "FX Graph Modee Quantization is only availablee from 1.8",
    )
    def test_quantized_small_resnet_se(self):
        self._test_quantize_model(MODELS["small_resnet_se"])
    def use_optimization(self, task):
        # we can only use the optimization if we are on PyTorch >= 1.7 and the EMA
        # state is on the same device as the model
        return get_torch_version() >= [1, 7] and task.use_gpu == (self.device == "cuda")
Example #16
    def set_distributed_options(
        self,
        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.BEFORE_EVAL,
        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
        batch_norm_sync_group_size: int = 0,
        find_unused_parameters: bool = False,
        bucket_cap_mb: int = 25,
        fp16_grad_compress: bool = False,
    ):
        """Set distributed options.

        Args:
            broadcast_buffers_mode: Broadcast buffers mode. See
                :class:`BroadcastBuffersMode` for options.
            batch_norm_sync_mode: Batch normalization synchronization mode. See
                :class:`BatchNormSyncMode` for options.
            batch_norm_sync_group_size: Group size to use for synchronized batch norm.
                0 means that the stats are synchronized across all replicas. For
                efficient synchronization, set it to the number of GPUs in a node (
                usually 8).
            find_unused_parameters: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
            bucket_cap_mb: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
            fp16_grad_compress: If True, compress gradients in fp16 during DDP
                communication (only supported since PyTorch 1.8).

        Raises:
            RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
                is not installed.
        """
        self.broadcast_buffers_mode = broadcast_buffers_mode

        if batch_norm_sync_group_size > 0:
            if batch_norm_sync_mode != BatchNormSyncMode.APEX:
                # this should ideally work with PyTorch Sync BN as well, but it
                # fails while initializing DDP for some reason.
                raise ValueError(
                    "batch_norm_sync_group_size can be > 0 only when "
                    "Apex Synchronized Batch Normalization is being used.")
        self.batch_norm_sync_group_size = batch_norm_sync_group_size

        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
            logging.info("Synchronized Batch Normalization is disabled")
        else:
            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
                raise RuntimeError("apex is not installed")
            msg = f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
            if self.batch_norm_sync_group_size > 0:
                msg += f" and group size {batch_norm_sync_group_size}"
            logging.info(msg)
        self.batch_norm_sync_mode = batch_norm_sync_mode

        if find_unused_parameters:
            logging.info("Enabling find_unused_parameters in DDP")

        self.find_unused_parameters = find_unused_parameters
        self.ddp_bucket_cap_mb = bucket_cap_mb

        if fp16_grad_compress:
            if get_torch_version() < [1, 8]:
                raise RuntimeError(
                    "FP16 grad compression is only supported since PyTorch 1.8"
                )
            logging.info("Enabling FP16 grad compression")
        self.fp16_grad_compress = fp16_grad_compress

        return self
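
Because the method returns self, it composes with the task's other fluent setters. A minimal sketch, assuming the method lives on ClassificationTask (as the other examples in this file suggest) and using only the keyword arguments defined above:

from classy_vision.tasks import ClassificationTask

task = ClassificationTask().set_distributed_options(
    find_unused_parameters=True,
    bucket_cap_mb=25,
    fp16_grad_compress=True,  # requires PyTorch >= 1.8, enforced above
)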
Example #17
class TestClassificationTask(unittest.TestCase):
    def _compare_model_state(self,
                             model_state_1,
                             model_state_2,
                             check_heads=True):
        compare_model_state(self, model_state_1, model_state_2, check_heads)

    def _compare_samples(self, sample_1, sample_2):
        compare_samples(self, sample_1, sample_2)

    def _compare_states(self, state_1, state_2, check_heads=True):
        compare_states(self, state_1, state_2)

    def setUp(self):
        # create a base directory to write checkpoints to
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self):
        # delete all the temporary data created
        shutil.rmtree(self.base_dir)

    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))

    def test_hooks_config_builds_correctly(self):
        config = get_test_task_config()
        config["hooks"] = [{"name": "loss_lr_meter_logging"}]
        task = build_task(config)
        self.assertTrue(len(task.hooks) == 1)
        self.assertTrue(isinstance(task.hooks[0], LossLrMeterLoggingHook))

    def test_get_state(self):
        config = get_test_task_config()
        loss = build_loss(config["loss"])
        task = (
            ClassificationTask()
            .set_num_epochs(1)
            .set_loss(loss)
            .set_model(build_model(config["model"]))
            .set_optimizer(build_optimizer(config["optimizer"]))
        )
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        task.prepare()

        task = build_task(config)
        task.prepare()

    def test_synchronize_losses_non_distributed(self):
        """
        Tests that synchronize losses has no side effects in a non-distributed setting.
        """
        test_config = get_fast_test_task_config()
        task = build_task(test_config)
        task.prepare()

        old_losses = copy.deepcopy(task.losses)
        task.synchronize_losses()
        self.assertEqual(old_losses, task.losses)

    def test_synchronize_losses_when_losses_empty(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task.prepare()

        task.set_use_gpu(torch.cuda.is_available())

        # Losses should be empty when creating task
        self.assertEqual(len(task.losses), 0)

        task.synchronize_losses()

    def test_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the train_steps
        run the same way after loading from a checkpoint.
        """
        config = get_fast_test_task_config()
        task = build_task(config).set_hooks([LossLrMeterLoggingHook()])
        task_2 = build_task(config).set_hooks([LossLrMeterLoggingHook()])

        task.set_use_gpu(torch.cuda.is_available())

        # only train 1 phase at a time
        trainer = LimitedPhaseTrainer(num_phases=1)

        while not task.done_training():
            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(
                get_checkpoint_dict(task, {}, deep_copy=True))

            # task 2 should have the same state before training
            self._compare_states(task.get_classy_state(),
                                 task_2.get_classy_state())

            # train for one phase
            trainer.train(task)
            trainer.train(task_2)

            # task 2 should have the same state after training
            self._compare_states(task.get_classy_state(),
                                 task_2.get_classy_state())

    def test_final_train_checkpoint(self):
        """Test that a train phase checkpoint with a where of 1.0 can be loaded"""

        config = get_fast_test_task_config()
        task = build_task(config).set_hooks(
            [CheckpointHook(self.base_dir, {}, phase_types=["train"])])
        task_2 = build_task(config)

        task.set_use_gpu(torch.cuda.is_available())

        trainer = LocalTrainer()
        trainer.train(task)

        self.assertAlmostEqual(task.where, 1.0, delta=1e-3)

        # set task_2's state as task's final train checkpoint
        task_2.set_checkpoint(self.base_dir)
        task_2.prepare()

        # we should be able to train the task
        trainer.train(task_2)

    def test_test_only_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the
        train_steps run the same way after loading from a training
        task checkpoint on a test_only task.
        """
        train_config = get_fast_test_task_config()
        train_config["num_epochs"] = 10
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True
        train_task = build_task(train_config).set_hooks(
            [LossLrMeterLoggingHook()])
        test_only_task = build_task(test_config).set_hooks(
            [LossLrMeterLoggingHook()])

        # prepare the tasks for the right device
        train_task.prepare()

        # test in both train and test mode
        trainer = LocalTrainer()
        trainer.train(train_task)

        # set task's state as task_2's checkpoint
        test_only_task._set_checkpoint_dict(
            get_checkpoint_dict(train_task, {}, deep_copy=True))
        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect the phase idx to be different for a test only task
        self.assertEqual(test_state["phase_idx"], -1)

        # We expect a test_only task's state to be in test mode, no matter what the train state was
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # train_phase_idx should be -1
        self.assertEqual(test_state["train_phase_idx"], -1)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_test_only_task(self):
        """
        Tests the task in test mode by running train_steps
        to make sure the train_steps run as expected on a
        test_only task
        """
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True

        # delete train dataset
        del test_config["dataset"]["train"]

        test_only_task = build_task(test_config).set_hooks(
            [LossLrMeterLoggingHook()])

        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect a test_only task's state to be in test mode, no matter what the train state was
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_train_only_task(self):
        """
        Tests that the task runs when only a train dataset is specified.
        """
        test_config = get_fast_test_task_config()

        # delete the test dataset from the config
        del test_config["dataset"]["test"]

        task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()])
        task.prepare()

        # verify that the task can still be trained
        trainer = LocalTrainer()
        trainer.train(task)

    @unittest.skipUnless(torch.cuda.is_available(),
                         "This test needs a gpu to run")
    def test_checkpointing_different_device(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task_2 = build_task(config)

        for use_gpu in [True, False]:
            task.set_use_gpu(use_gpu)
            task.prepare()

            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(
                get_checkpoint_dict(task, {}, deep_copy=True))

            # we should be able to run the trainer using state from a different device
            trainer = LocalTrainer()
            task_2.set_use_gpu(not use_gpu)
            trainer.train(task_2)

    @unittest.skipUnless(is_distributed_training_run(),
                         "This test needs a distributed run")
    def test_get_classy_state_on_loss(self):
        config = get_fast_test_task_config()
        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
        task = build_task(config)
        task.prepare()
        self.assertIn("alpha", task.get_classy_state()["loss"])

    def test_gradient_clipping(self):
        apex_available = True
        try:
            import apex  # noqa F401
        except ImportError:
            apex_available = False

        def train_with_clipped_gradients(amp_args=None):
            task = build_task(get_fast_test_task_config())
            task.set_num_epochs(1)
            task.set_model(SimpleModel())
            task.set_loss(SimpleLoss())
            task.set_meters([])
            task.set_use_gpu(torch.cuda.is_available())
            task.set_clip_grad_norm(0.5)
            task.set_amp_args(amp_args)

            task.set_optimizer(SGD(lr=1))

            trainer = LocalTrainer()
            trainer.train(task)

            return task.model.param.grad.norm()

        grad_norm = train_with_clipped_gradients(None)
        self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)

        if apex_available and torch.cuda.is_available():
            grad_norm = train_with_clipped_gradients({"opt_level": "O2"})
            self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)

    def test_clip_stateful_loss(self):
        config = get_fast_test_task_config()
        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
        config["grad_norm_clip"] = grad_norm_clip = 1
        task = build_task(config)
        task.set_use_gpu(False)
        task.prepare()

        # set fake gradients with norm > grad_norm_clip
        for param in itertools.chain(task.base_model.parameters(),
                                     task.base_loss.parameters()):
            param.grad = 1.1 + torch.rand(param.shape)
            self.assertGreater(param.grad.norm(), grad_norm_clip)

        task._clip_gradients(grad_norm_clip)

        for param in itertools.chain(task.base_model.parameters(),
                                     task.base_loss.parameters()):
            self.assertLessEqual(param.grad.norm(), grad_norm_clip)

    # helper used by gradient accumulation tests
    def train_with_batch(self, simulated_bs, actual_bs, clip_grad_norm=None):
        config = copy.deepcopy(get_fast_test_task_config())
        config["dataset"]["train"]["num_samples"] = 12
        config["dataset"]["train"]["batchsize_per_replica"] = actual_bs
        del config["dataset"]["test"]

        task = build_task(config)
        task.set_num_epochs(1)
        task.set_model(SimpleModel())
        task.set_loss(SimpleLoss())
        task.set_meters([])
        task.set_use_gpu(torch.cuda.is_available())
        if simulated_bs is not None:
            task.set_simulated_global_batchsize(simulated_bs)
        if clip_grad_norm is not None:
            task.set_clip_grad_norm(clip_grad_norm)

        task.set_optimizer(SGD(lr=1))

        trainer = LocalTrainer()
        trainer.train(task)

        return task.model.param

    def test_gradient_accumulation(self):
        param_with_accumulation = self.train_with_batch(simulated_bs=4,
                                                        actual_bs=2)
        param = self.train_with_batch(simulated_bs=4, actual_bs=4)

        self.assertAlmostEqual(param_with_accumulation, param, delta=1e-5)

    def test_gradient_accumulation_and_clipping(self):
        param = self.train_with_batch(simulated_bs=6,
                                      actual_bs=2,
                                      clip_grad_norm=0.1)

        # param starts at 5, it has to decrease, LR = 1
        # clipping the grad to 0.1 means we drop 0.1 per update. num_samples =
        # 12 and the simulated batch size is 6, so we should do 2 updates: 5 ->
        # 4.9 -> 4.8
        self.assertAlmostEqual(param, 4.8, delta=1e-5)

    @unittest.skipIf(
        get_torch_version() < [1, 8],
        "FP16 Grad compression is only available from PyTorch 1.8",
    )
    def test_fp16_grad_compression(self):
        # there is no API defined to check that a DDP hook has been enabled, so we just
        # test that we set the right variables
        config = copy.deepcopy(get_fast_test_task_config())
        task = build_task(config)
        self.assertFalse(task.fp16_grad_compress)

        config.setdefault("distributed", {})
        config["distributed"]["fp16_grad_compress"] = True

        task = build_task(config)
        self.assertTrue(task.fp16_grad_compress)