def convert_sync_nd(module):
    """Recursively replace network-deconvolution layers with synchronized ones.

    Traverses ``module`` and its children and replaces every instance of
    ``FastDeconv``, ``FastDeconvTransposed`` and ``Delinear`` with
    ``SynchronizedDeconv``, ``SynchronizedDeconvTransposed`` and
    ``SynchronizedDelinear`` respectively.

    For each replaced layer:

    * ``running_mean`` and ``running_cov_isqrt`` are shared by reference
      with the original layer (assigned, not copied).
    * ``weight`` (and ``bias``, when the original layer has one) are cloned.

    A ``torch.nn.DataParallel`` wrapper is unwrapped, its inner module is
    converted, and the result is re-wrapped in ``DataParallelWithCallback``
    with the same ``device_ids``.

    Note: modules that are not replaced are modified in place, because their
    children are re-registered with ``add_module`` on the same object.

    Args:
        module: the input module to convert to synchronized deconvolution.

    Returns:
        The converted module. This is either a new synchronized layer, a new
        ``DataParallelWithCallback`` wrapper, or ``module`` itself with
        converted children.

    Examples:
        >>> m = nn.DataParallel(m)  # m contains FastDeconv / Delinear layers
        >>> m = convert_sync_nd(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        converted = convert_sync_nd(module.module)
        return DataParallelWithCallback(converted, device_ids=module.device_ids)

    mod = module
    for pth_module, sync_module in zip(
        [FastDeconv, FastDeconvTransposed, Delinear],
        [SynchronizedDeconv, SynchronizedDeconvTransposed, SynchronizedDelinear],
    ):
        if not isinstance(module, pth_module):
            continue

        # Give the replacement a bias only if the source layer has one.
        # Hard-coding ``True`` here made the bias copy below crash with an
        # AttributeError on ``None.data`` for bias-free layers.
        has_bias = getattr(module, "bias", None) is not None

        if pth_module is FastDeconv:
            # Only element [0] of kernel_size/stride/padding/dilation is
            # forwarded; presumably square hyper-parameters are assumed
            # (TODO confirm against the FastDeconv constructor).
            mod = sync_module(
                module.in_channels, module.out_channels,
                module.kernel_size[0], module.stride[0], module.padding[0],
                module.dilation[0], module.groups, has_bias,
                module.eps, module.n_iter, module.momentum, module.block,
                module.sampling_stride, module.freeze, module.freeze_iter,
            )
        elif pth_module is FastDeconvTransposed:
            mod = sync_module(
                module.in_channels, module.out_channels,
                module.kernel_size[0], module.stride[0], module.padding[0],
                module.output_padding[0], module.groups, has_bias,
                module.dilation[0], module.eps, module.n_iter,
                module.momentum, module.block, module.sampling_stride,
            )
        else:  # Delinear
            mod = sync_module(
                module.in_features, module.out_features, has_bias,
                module.eps, module.n_iter, module.momentum, module.block,
            )

        # Running statistics are shared with the original layer; learnable
        # parameters are copied so the new layer owns its own storage.
        mod.running_mean = module.running_mean
        mod.running_cov_isqrt = module.running_cov_isqrt
        mod.weight.data = module.weight.data.clone().detach()
        if has_bias:
            mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_sync_nd(child))

    return mod
def convert_model(module):
    """Recursively replace BatchNorm*N*d layers with SynchronizedBatchNorm*N*d.

    Traverses ``module`` and its children and replaces every instance of
    ``torch.nn.modules.batchnorm.BatchNorm{1,2,3}d`` with the matching
    ``SynchronizedBatchNorm{1,2,3}d``.

    Each replacement is built from ``num_features``, ``eps``, ``momentum``
    and ``affine``:

    * ``running_mean`` and ``running_var`` are shared by reference with the
      original layer.
    * ``weight`` and ``bias`` are cloned when ``affine`` is set.
    * ``num_batches_tracked`` and ``track_running_stats`` are not forwarded.

    A ``torch.nn.DataParallel`` wrapper is unwrapped, its inner module is
    converted, and the result is re-wrapped in ``DataParallelWithCallback``
    with the same ``device_ids``.

    Note: modules that are not replaced are modified in place, because their
    children are re-registered with ``add_module`` on the same object.

    Args:
        module: the input module to convert to a SyncBN model.

    Returns:
        The converted module.

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        inner = convert_model(module.module)
        # Keep the original wrapper's GPU selection. Without device_ids the
        # new wrapper falls back to its default device list; for
        # torch.nn.DataParallel that is every visible GPU.
        return DataParallelWithCallback(inner, device_ids=module.device_ids)

    mod = module
    for pth_module, sync_module in zip(
        [
            torch.nn.modules.batchnorm.BatchNorm1d,
            torch.nn.modules.batchnorm.BatchNorm2d,
            torch.nn.modules.batchnorm.BatchNorm3d,
        ],
        [
            SynchronizedBatchNorm1d,
            SynchronizedBatchNorm2d,
            SynchronizedBatchNorm3d,
        ],
    ):
        if not isinstance(module, pth_module):
            continue
        mod = sync_module(
            module.num_features, module.eps, module.momentum, module.affine
        )
        # Running statistics are shared; affine parameters are copied.
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod