Example #1
0
def convert_sync_nd(module):
    """Traverse the input module and its children recursively and replace
    every FastDeconv / FastDeconvTransposed / Delinear layer with its
    synchronized counterpart (SynchronizedDeconv /
    SynchronizedDeconvTransposed / SynchronizedDelinear).

    The replacement is constructed from the original layer's hyper-parameters.
    Its ``running_mean`` and ``running_cov_isqrt`` are the original layer's
    tensors (shared, not copied). Its ``weight``, and ``bias`` when the
    original layer has one, are cloned from the original.

    A ``torch.nn.DataParallel`` wrapper is unwrapped, its inner module is
    converted, and the result is re-wrapped in ``DataParallelWithCallback``
    on the same ``device_ids``.

    Args:
        module: the input module to convert to use synchronized layers.

    Returns:
        The converted module. Modules that are not replaced are updated in
        place (their children are swapped for converted ones) and returned
        themselves.

    Examples:
        >>> import torch.nn as nn
        >>> # m is a model built with FastDeconv / Delinear layers
        >>> m = nn.DataParallel(m)
        >>> # after conversion, m uses the synchronized layers
        >>> m = convert_sync_nd(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        inner = convert_sync_nd(module.module)
        return DataParallelWithCallback(inner, device_ids=module.device_ids)

    # ``bias`` is passed positionally and mirrors the source layer. A layer
    # created without a bias has ``bias is None``. Forcing ``True`` here would
    # leave no values to copy into the new bias.
    converted = None
    if isinstance(module, FastDeconv):
        converted = SynchronizedDeconv(
            module.in_channels, module.out_channels, module.kernel_size[0],
            module.stride[0], module.padding[0], module.dilation[0],
            module.groups, module.bias is not None, module.eps,
            module.n_iter, module.momentum, module.block,
            module.sampling_stride, module.freeze, module.freeze_iter)
    elif isinstance(module, FastDeconvTransposed):
        converted = SynchronizedDeconvTransposed(
            module.in_channels, module.out_channels, module.kernel_size[0],
            module.stride[0], module.padding[0], module.output_padding[0],
            module.groups, module.bias is not None, module.dilation[0],
            module.eps, module.n_iter, module.momentum, module.block,
            module.sampling_stride)
    elif isinstance(module, Delinear):
        converted = SynchronizedDelinear(
            module.in_features, module.out_features, module.bias is not None,
            module.eps, module.n_iter, module.momentum, module.block)

    if converted is None:
        mod = module
    else:
        mod = converted
        # Running statistics are shared with the original layer (same tensor
        # objects). Learnable parameters are cloned.
        mod.running_mean = module.running_mean
        mod.running_cov_isqrt = module.running_cov_isqrt
        mod.weight.data = module.weight.data.clone().detach()
        if module.bias is not None:
            mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_sync_nd(child))

    return mod
Example #2
0
def convert_model(module):
    """Traverse the input module and its children recursively and replace
    every instance of torch.nn.modules.batchnorm.BatchNorm*N*d with the
    matching SynchronizedBatchNorm*N*d.

    The replacement is built with the original layer's ``num_features``,
    ``eps``, ``momentum`` and ``affine`` settings. Its ``running_mean`` and
    ``running_var`` are the original layer's tensors (shared, not copied).
    When ``affine`` is set, ``weight`` and ``bias`` are cloned.

    A ``torch.nn.DataParallel`` wrapper is unwrapped, its inner module is
    converted, and the result is re-wrapped in ``DataParallelWithCallback``
    on the same ``device_ids``.

    Args:
        module: the input module to convert to a SyncBN model.

    Returns:
        The converted module. Modules that are not BatchNorm layers are
        updated in place (their children are swapped for converted ones) and
        returned themselves.

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        inner = convert_model(module.module)
        # Keep the caller's GPU selection. Without device_ids the wrapper
        # presumably falls back to DataParallel's default of every visible
        # device.
        return DataParallelWithCallback(inner, device_ids=module.device_ids)

    mod = module
    for pth_module, sync_module in zip(
        [
            torch.nn.modules.batchnorm.BatchNorm1d,
            torch.nn.modules.batchnorm.BatchNorm2d,
            torch.nn.modules.batchnorm.BatchNorm3d,
        ],
        [
            SynchronizedBatchNorm1d,
            SynchronizedBatchNorm2d,
            SynchronizedBatchNorm3d,
        ],
    ):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum,
                              module.affine)
            # Running statistics are shared with the original layer. Affine
            # parameters are cloned.
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()
            break

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod