# Shared imports assumed by the test snippets below (import paths follow
# the torchgpipe package layout).
from copy import deepcopy
from itertools import chain

import pytest
import torch
from torch import nn, optim

from torchgpipe.batchnorm import DeferredBatchNorm


def test_convert_deferred_batch_norm():
    bn = nn.BatchNorm2d(3, track_running_stats=False)
    bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS)
    assert type(bn) is nn.BatchNorm2d  # because of track_running_stats=False

    dbn = DeferredBatchNorm(3, chunks=CHUNKS)
    dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS)
    assert dbn is dbn_again

    dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1)
    assert dbn is not dbn_again  # because of different chunks
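
Every test below leans on three helpers defined in the test module itself but
not shown in these excerpts. A minimal sketch consistent with how the tests
use them (the names match the torchgpipe test suite; the CHUNKS value and the
exact tilting scheme are assumptions):

CHUNKS = 4  # assumed number of micro-batches used throughout


def tilt_dist(input):
    # Skew the distribution so per-micro-batch statistics differ clearly
    # from per-mini-batch statistics. Done under no_grad so inputs that
    # require grad can still be tilted in place.
    with torch.no_grad():
        rgb = input.transpose(0, 1)
        rgb[1] *= 10   # widen channel 1
        rgb[2] *= 100  # widen channel 2 further
        for i, single in enumerate(input):
            single += 2 ** i  # shift each sample's mean differently
    return input


def chunked_forward(model, input, chunks=CHUNKS):
    # Emulate GPipe's micro-batching: split along the batch dimension,
    # run each chunk through the module, and concatenate the outputs.
    return torch.cat([model(chunk) for chunk in input.chunk(chunks)])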


def test_optimize():
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)

    for i in range(5):
        input = torch.rand(16, 3, 224, 224)
        input = tilt_dist(input)

        # train
        y = bn(input)
        a = y.sum()
        a.backward()

        y = chunked_forward(dbn, input)
        b = y.sum()
        b.backward()

        opt.step()

        # eval
        bn.eval()
        dbn.eval()

        with torch.no_grad():
            # The tolerance widens each iteration: bn and dbn receive different
            # gradients (per-mini-batch vs. per-micro-batch outputs), so their
            # parameters drift further apart with every optimizer step.
            assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10 ** i))
Example #3
    def __init__(self,
                 module: nn.Sequential,
                 balance: Iterable[int],
                 *,
                 devices: Optional[Devices] = None,
                 chunks: int = 1,
                 checkpoint: str = 'except_last',
                 deferred_batch_norm: bool = False):
        super().__init__()

        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')

        if checkpoint not in ['always', 'except_last', 'never']:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, self.chunks)

        self.partitions, self.balance, self.devices = self.partition(
            module, balance, devices)
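
Assuming this __init__ belongs to a GPipe-style nn.Module (as in
torchgpipe.GPipe), a hedged usage sketch; the layer sizes, balance, and
device ids are illustrative and presume two visible CUDA devices:

model = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 10),
)
# First two layers on one device, the last layer on the other.
pipe = GPipe(model, balance=[2, 1], devices=['cuda:0', 'cuda:1'], chunks=4)

input = torch.rand(32, 1024, device='cuda:0')  # input lives on the first device
output = pipe(input)                           # output lands on 'cuda:1'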


def test_input_requiring_grad():
    dbn = DeferredBatchNorm(3, chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224, requires_grad=True)
    input = tilt_dist(input)

    chunked_forward(dbn, input)

    assert not dbn.sum.requires_grad
    assert dbn.sum.grad_fn is None
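
These assertions hold because DeferredBatchNorm accumulates its per-chunk
statistics into buffers that sit outside autograd. A conceptual sketch of that
accumulation, with buffer names following torchgpipe's implementation (treat
the details as illustrative, not as the exact source):

class DeferredBatchNormSketch(nn.BatchNorm2d):
    def __init__(self, num_features, chunks):
        super().__init__(num_features)
        self.register_buffer('sum', torch.zeros(num_features))
        self.register_buffer('ssum', torch.zeros(num_features))
        self.counter = 0
        self.tracked = 0
        self.chunks = chunks

    def _track(self, input):
        # Called from forward() on each micro-batch. Accumulate under
        # no_grad: the buffers never require grad and have no grad_fn,
        # which is exactly what test_input_requiring_grad checks.
        with torch.no_grad():
            dim = [0, 2, 3]  # reduce over everything but the channel dim
            self.sum += input.sum(dim)
            self.ssum += (input ** 2).sum(dim)
        self.counter += input.numel() // input.size(1)
        self.tracked += 1
        if self.tracked == self.chunks:  # last micro-batch of the mini-batch
            self._commit()

    def _commit(self):
        # Fold the accumulated sums into the running statistics, then reset.
        mean = self.sum / self.counter
        var = self.ssum / self.counter - mean ** 2
        # (The real code uses a cumulative average when momentum is None.)
        m = self.momentum if self.momentum is not None else 1.0
        self.running_mean += m * (mean - self.running_mean)
        self.running_var += m * (var - self.running_var)
        self.sum.zero_()
        self.ssum.zero_()
        self.counter = 0
        self.tracked = 0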


# Parametrized in the original test suite; the decorator was likely lost in
# extraction. The values here are assumptions.
@pytest.mark.parametrize('momentum', [0.1, None])
def test_running_stats(momentum):
    bn = nn.BatchNorm2d(3, momentum=momentum)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)

    bn(input)
    chunked_forward(dbn, input)

    assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
    assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)


def test_eval():
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)

    bn(input)
    chunked_forward(dbn, input)

    bn.eval()
    dbn.eval()

    assert torch.allclose(bn(input), dbn(input), atol=1e-4)
Example #7
    def __init__(
        self,
        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
        devices: Optional[Devices] = None,
        # These defaults refer to class attributes (chunks = 1 and
        # checkpoint = 'except_last') defined on the class body, which is
        # not shown in this excerpt.
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if balance is None:
            raise ValueError(recommend_auto_balance('balance is required'))
        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')
        if checkpoint not in ['always', 'except_last', 'never']:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, chunks)

        if devices is None:
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]
        devices = cast(List[torch.device], devices)

        try:
            self.partitions, self.balance, self.devices = split_module(
                module, balance, devices)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc)))

        self._copy_streams: List[List[AbstractStream]] = []
        self._skip_layout = inspect_skip_layout(self.partitions)
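
recommend_auto_balance points the user at automatic balancing when balance is
missing or wrong. In torchgpipe that workflow uses the balance helpers; a
sketch of the intended usage (the model, sample shape, and partition count are
illustrative):

from torchgpipe import GPipe
from torchgpipe.balance import balance_by_time

model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU(), nn.Conv2d(64, 32, 3))
sample = torch.rand(128, 3, 224, 224)        # one representative mini-batch
balance = balance_by_time(2, model, sample)  # profile, then split in two
pipe = GPipe(model, balance=balance, chunks=8)
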
Example #8
    def __init__(
        self,
        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
        devices: Optional[Devices] = None,
        chunks: int = 1,
        checkpoint: str = 'except_last',
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        if not isinstance(module, nn.Sequential):
            raise TypeError('non-sequential module cannot be partitioned')

        if balance is None:
            # recommend_torchgpipe_balancing() returns an exception object
            # (a ValueError carrying balancing advice), so raising its return
            # value is intentional here.
            raise recommend_torchgpipe_balancing('balance is required')

        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')

        if checkpoint not in ['always', 'except_last', 'never']:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, self.chunks)

        # Split the module into multiple partitions.
        balance = list(balance)

        if devices is None:
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]

        try:
            self.partitions, self.balance, self.devices = self._partition(
                module, balance, devices)
        except ValueError as exc:
            raise recommend_torchgpipe_balancing(str(exc))
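
All of these constructors validate the same three checkpointing modes. Their
meaning, following the torchgpipe documentation, with a usage line whose
model, balance, and chunks are illustrative (two CUDA devices assumed):

# 'always'      -- checkpoint every micro-batch: least memory, most recompute
# 'except_last' -- checkpoint all but the last micro-batch (the default);
#                  the last one needs no recomputation during backward
# 'never'       -- no checkpointing: most memory, no recompute
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
pipe = GPipe(model, balance=[1, 1], chunks=4, checkpoint='always')
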
Example #9
    def __init__(
        self,
        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
        devices: Optional[Devices] = None,
        chunks: int = 1,
        checkpoint: str = 'except_last',
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        if balance is None:
            raise ValueError(recommend_auto_balance('balance is required'))
        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')
        if checkpoint not in ['always', 'except_last', 'never']:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        self.chunks = chunks
        self.checkpoint = checkpoint

        verify_module(module)

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, self.chunks)

        balance = list(balance)

        if devices is None:
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]

        try:
            self.partitions, self.balance, self.devices = split_module(
                module, balance, devices)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc)))

        self._copy_streams: List[List[AbstractStream]] = []
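
Every variant routes BatchNorm handling through convert_deferred_batch_norm
so running statistics are computed once per mini-batch instead of once per
micro-batch. A hedged end-to-end sketch (the layer sizes are illustrative and
two CUDA devices are assumed):

model = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.BatchNorm2d(64),  # rewritten to DeferredBatchNorm when the flag is set
    nn.ReLU(),
)
pipe = GPipe(model, balance=[2, 1], chunks=4, deferred_batch_norm=True)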


def test_conv_bn():
    bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)

    opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)

    # 1st step
    a = bn(input)
    b = chunked_forward(dbn, input)

    # Outputs are different. (per-mini-batch vs. per-micro-batch)
    assert not torch.allclose(a, b)

    a.sum().backward()
    b.sum().backward()
    opt.step()
    opt.zero_grad()

    # Conv layers are also trained differently because of their different outputs.
    assert not torch.allclose(bn[0].weight, dbn[0].weight)

    # But the BNs track matching running stats. (The variance tolerance is
    # wide because tilt_dist inflates channel variances by orders of magnitude.)
    assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
    assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3)

    # 2nd step
    a = bn(input)
    b = chunked_forward(dbn, input)
    a.sum().backward()
    b.sum().backward()

    # BNs can't track identical running stats due to the different conv layers.
    assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
    assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3)


# Parametrized in the original test suite; the decorators were likely lost in
# extraction. The values here are assumptions.
@pytest.mark.parametrize('chunks', [1, 4])
@pytest.mark.parametrize('input_requires_grad', [True, False])
def test_transparency(chunks, input_requires_grad):
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)

    input1 = torch.rand(16, 3, 224, 224)
    input1 = tilt_dist(input1)
    input2 = input1.clone()
    input1.requires_grad = input_requires_grad
    input2.requires_grad = input_requires_grad

    output1 = chunked_forward(bn, input1, chunks=chunks)
    output2 = chunked_forward(dbn, input2, chunks=chunks)

    assert torch.allclose(output1, output2, atol=1e-4)

    output1.mean().backward()
    output2.mean().backward()

    assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)

    if input_requires_grad:
        assert input1.grad is not None
        assert input2.grad is not None
        assert torch.allclose(input1.grad, input2.grad, atol=1e-4)
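
Note why this transparency holds: during training, DeferredBatchNorm
normalizes each micro-batch with that micro-batch's own statistics, exactly as
a plain BatchNorm2d applied chunk by chunk would; only the running-statistics
update is deferred until every chunk of the mini-batch has been seen. Hence
the outputs and input gradients above match chunked_forward over an
unconverted BatchNorm2d, while the running stats (tested earlier) match a
single whole-mini-batch pass.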