def __init__(self,
             in_channels,
             num_joints,
             max_norm=None,
             loss_keypoint=None,
             train_cfg=None,
             test_cfg=None):
    super().__init__()

    self.in_channels = in_channels
    self.num_joints = num_joints
    self.max_norm = max_norm
    self.loss = build_loss(loss_keypoint)
    self.train_cfg = {} if train_cfg is None else train_cfg
    self.test_cfg = {} if test_cfg is None else test_cfg

    self.conv = build_conv_layer(
        dict(type='Conv1d'), in_channels, num_joints * 3, 1)

    if self.max_norm is not None:
        # Apply weight norm clip to conv layers
        weight_clip = WeightNormClipHook(self.max_norm)
        for module in self.modules():
            if isinstance(module, nn.modules.conv._ConvNd):
                weight_clip.register(module)
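# Illustrative sketch (not the actual library code): `WeightNormClipHook` used
# above is assumed to behave roughly like the hypothetical helper below, i.e.
# it renormalizes a module's weight to `max_norm` in a forward pre-hook each
# time the registered module is called.
import torch


class _WeightNormClipSketch:
    """Hypothetical stand-in that clips ``module.weight`` to a maximum norm."""

    def __init__(self, max_norm):
        self.max_norm = max_norm

    def register(self, module):
        # Clip right before every forward pass of the registered module.
        module.register_forward_pre_hook(self._clip)

    def _clip(self, module, inputs):
        with torch.no_grad():
            norm = module.weight.norm()
            if norm > self.max_norm:
                module.weight.mul_(self.max_norm / norm)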
import numpy as np
import torch

# `WeightNormClipHook` is the hook under test; its import path is specific to
# the hosting codebase and is assumed to be available in this module.


def test_weight_norm_clip():
    torch.manual_seed(0)

    module = torch.nn.Linear(2, 2, bias=False)
    module.weight.data.fill_(2)
    WeightNormClipHook(max_norm=1.0).register(module)

    # Run a forward pass so the registered hook can apply the clip.
    x = torch.rand(1, 2).requires_grad_()
    _ = module(x)

    weight_norm = module.weight.norm().item()
    np.testing.assert_allclose(weight_norm, 1.0, rtol=1e-6)
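# Note on the test above: the 2x2 weight is filled with 2.0, so its Frobenius
# norm starts at sqrt(4 * 2**2) = 4.0, well above max_norm=1.0. The clip is
# applied by the registered hook at call time, which is why the test invokes
# `module(x)` before asserting that the norm has been reduced to 1.0.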
def __init__(self,
             in_channels,
             stem_channels=1024,
             num_blocks=2,
             kernel_sizes=(3, 3, 3),
             dropout=0.25,
             causal=False,
             residual=True,
             use_stride_conv=False,
             conv_cfg=dict(type='Conv1d'),
             norm_cfg=dict(type='BN1d'),
             max_norm=None):
    # Protect mutable default arguments
    conv_cfg = copy.deepcopy(conv_cfg)
    norm_cfg = copy.deepcopy(norm_cfg)
    super().__init__()

    self.in_channels = in_channels
    self.stem_channels = stem_channels
    self.num_blocks = num_blocks
    self.kernel_sizes = kernel_sizes
    self.dropout = dropout
    self.causal = causal
    self.residual = residual
    self.use_stride_conv = use_stride_conv
    self.max_norm = max_norm

    assert num_blocks == len(kernel_sizes) - 1
    for ks in kernel_sizes:
        assert ks % 2 == 1, 'Only odd filter widths are supported.'

    self.expand_conv = ConvModule(
        in_channels,
        stem_channels,
        kernel_size=kernel_sizes[0],
        stride=kernel_sizes[0] if use_stride_conv else 1,
        bias='auto',
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg)

    dilation = kernel_sizes[0]
    self.tcn_blocks = nn.ModuleList()
    for i in range(1, num_blocks + 1):
        self.tcn_blocks.append(
            BasicTemporalBlock(
                in_channels=stem_channels,
                out_channels=stem_channels,
                mid_channels=stem_channels,
                kernel_size=kernel_sizes[i],
                dilation=dilation,
                dropout=dropout,
                causal=causal,
                residual=residual,
                use_stride_conv=use_stride_conv,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        dilation *= kernel_sizes[i]

    if self.max_norm is not None:
        # Apply weight norm clip to conv layers
        weight_clip = WeightNormClipHook(self.max_norm)
        for module in self.modules():
            if isinstance(module, nn.modules.conv._ConvNd):
                weight_clip.register(module)

    self.dropout = nn.Dropout(dropout) if dropout > 0 else None
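# Receptive-field sketch (illustrative helper, not part of the class above):
# with the dilation schedule in this __init__ -- dilation starts at
# kernel_sizes[0] and is multiplied by each subsequent kernel size -- the
# temporal receptive field equals the product of all kernel sizes, e.g.
# 3 * 3 * 3 = 27 input frames for the default kernel_sizes=(3, 3, 3).
from functools import reduce


def temporal_receptive_field(kernel_sizes):
    """Number of input frames seen by one output frame of the TCN-style stack."""
    return reduce(lambda field, k: field * k, kernel_sizes, 1)


assert temporal_receptive_field((3, 3, 3)) == 27
assert temporal_receptive_field((3, 3, 3, 3, 3)) == 243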