def __init__(self, vocab_size, embed_size, dropout=0.1, B=1):
  """Token + position + segment embedding stack with dropout.

  :param vocab_size: total vocab size
  :param embed_size: embedding size of token embedding
  :param dropout: dropout rate
  :param B: number of horizontally fused models
  """
  super().__init__()
  embedding_cls = get_hfta_op_for(nn.Embedding, B)
  # Three embedding tables: per-token, per-position (up to 512 positions),
  # and per-segment (3 segment ids).
  self.token = embedding_cls(vocab_size, embed_size)
  self.position = embedding_cls(512, embed_size)
  self.segment = embedding_cls(3, embed_size)
  self.dropout = get_hfta_op_for(nn.Dropout, B)(p=dropout)
  self.embed_size = embed_size
def testcase(
    B=3,
    N=32,
    L=8,
    in_features=20,
    out_features=50,
    bias=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Check that one fused Linear matches B independent nn.Linear modules."""
  with torch.no_grad():
    inputs = [
        torch.rand(N, L, in_features, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused input carries the model dimension B in front: [B, N, L, F].
    fused_input = torch.stack(inputs, dim=0)
    args = (in_features, out_features)
    kwargs = {'bias': bias, 'device': device, 'dtype': dtype}
    serial_ops = [nn.Linear(*args, **kwargs) for _ in range(B)]
    fused_op = get_hfta_op_for(nn.Linear, B=B)(*args, **kwargs)
    # Copy each serial module's weights/bias into fused slot b.
    for b, op in enumerate(serial_ops):
      fused_op.snatch_parameters(op, b)
    expected = torch.stack([op(x) for op, x in zip(serial_ops, inputs)],
                           dim=0)
    actual = fused_op(fused_input)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
          population_threshold=1e-3,
      )
    except AssertionError as e:
      dump_error_msg(e)
def testcase_AdaptiveAvgPool2d(
    B=3,
    N=32,
    C=16,
    HWin=28,
    output_size=(16, 16),
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Check that one fused AdaptiveAvgPool2d matches B serial modules."""
  with torch.no_grad():
    inputs = [
        torch.rand(N, C, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused input layout: [N, B, C, H, W] (model dim at position 1).
    fused_input = torch.stack(inputs, dim=1)
    args = (output_size, )
    serial_ops = [nn.AdaptiveAvgPool2d(*args) for _ in range(B)]
    fused_op = get_hfta_op_for(nn.AdaptiveAvgPool2d, B=B)(*args)
    expected = torch.stack([op(x) for op, x in zip(serial_ops, inputs)],
                           dim=1)
    actual = fused_op(fused_input)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
      )
    except AssertionError as e:
      dump_error_msg(e)
def testcase_MaxPool2d(
    B=3,
    N=32,
    C=16,
    kernel_size=2,
    HWin=28,
    stride=None,
    padding=0,
    dilation=1,
    return_indices=False,
    ceil_mode=False,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Check that one fused MaxPool2d matches B serial nn.MaxPool2d modules."""
  with torch.no_grad():
    inputs = [
        torch.rand(N, C, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused input layout: [N, B, C, H, W].
    fused_input = torch.stack(inputs, dim=1)
    args = (kernel_size, )
    kwargs = {
        'stride': stride,
        'padding': padding,
        'dilation': dilation,
        'return_indices': return_indices,
        'ceil_mode': ceil_mode,
    }
    serial_ops = [nn.MaxPool2d(*args, **kwargs) for _ in range(B)]
    fused_op = get_hfta_op_for(nn.MaxPool2d, B=B)(*args, **kwargs)
    serial_results = [op(x) for op, x in zip(serial_ops, inputs)]
    fused_result = fused_op(fused_input)
    if return_indices:
      # Each result is a (values, indices) pair in this mode.
      values_list, indices_list = tuple(zip(*serial_results))
      values_actual, indices_actual = fused_result
    else:
      values_list = serial_results
      values_actual = fused_result
    values_expect = torch.stack(values_list, dim=1)
    try:
      assert_allclose(
          values_actual.cpu().numpy(),
          values_expect.cpu().numpy(),
          rtol=1e-4,
      )
    except AssertionError as e:
      dump_error_msg(e)
    if return_indices:
      indices_expect = torch.stack(indices_list, dim=1)
      try:
        assert_allclose(
            indices_actual.cpu().numpy(),
            indices_expect.cpu().numpy(),
            rtol=1e-4,
        )
      except AssertionError as e:
        dump_error_msg(e)
def _make_layer(self, block, planes, blocks, stride=1, B=1):
  """Build one ResNet stage of `blocks` residual blocks.

  Only the first block may downsample (via `stride` and/or a projection
  shortcut); the remaining blocks keep the channel count fixed.
  """
  norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
  out_planes = planes * block.expansion
  downsample = None
  # A projection shortcut is required whenever the spatial resolution or
  # the channel count changes across this stage.
  if stride != 1 or self.inplanes != out_planes:
    downsample = nn.Sequential(
        conv1x1(self.inplanes, out_planes, stride, B=B),
        norm_layer(out_planes,
                   track_running_stats=self.track_running_stats),
    )
  first = block(self.inplanes,
                planes,
                stride,
                downsample,
                norm_layer,
                track_running_stats=self.track_running_stats,
                B=B)
  self.inplanes = out_planes
  rest = [
      block(self.inplanes,
            planes,
            norm_layer=norm_layer,
            track_running_stats=self.track_running_stats,
            B=B) for _ in range(1, blocks)
  ]
  return nn.Sequential(first, *rest)
def testcase(
    B=3,
    N=32,
    C=1024,
    HWin=16,
    p=0.4,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Validate fused Dropout2d: channels are zeroed whole, survivors rescaled."""
  with torch.no_grad():
    ones = [
        torch.ones(N, C, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    fused_input = torch.stack(ones, dim=1)
    dropout2d_fused = get_hfta_op_for(torch.nn.Dropout2d, B=B)(p)
    fused_output = dropout2d_fused(fused_input)
    for b in range(B):
      y = fused_output[:, b, :, :, :]
      assert y.size(0) == N
      assert y.size(1) == C
      for n in range(N):
        zero_channels = 0
        for c in range(C):
          s = y[n, c].sum()
          # Each channel either has all zeros or no zeros: a surviving
          # all-ones channel sums to HWin^2 / (1 - p) after rescaling.
          try:
            assert_allclose(s.cpu(), HWin**2 / (1 - p), rtol=1e-4)
          except AssertionError:
            # s must be zero at this point.
            assert_allclose(s.cpu(), 0, atol=1e-4)
            zero_channels += 1
        # The dropped fraction should be close to p (loose tolerance since
        # this is a statistical property of a finite sample).
        assert_allclose(zero_channels / C, p, rtol=2e-1)
def conv1x1(in_planes, out_planes, stride=1, B=1):
  """1x1 convolution"""
  conv_cls = get_hfta_op_for(nn.Conv2d, B)
  return conv_cls(in_planes,
                  out_planes,
                  kernel_size=1,
                  stride=stride,
                  bias=False)
def __init__(
    self,
    B=0,
    partially_fused=False,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Small conv + linear test network.

  When partially_fused is set, conv2 and linear1 are built as plain Python
  lists of serial modules (one per model) instead of fused operators.
  """
  super(_TestNet, self).__init__()
  kwargs = {'device': device, 'dtype': dtype}
  self.conv1 = get_hfta_op_for(nn.Conv2d, B=B)(256, 128, 3, 3, **kwargs)
  if partially_fused:
    # Plain lists: these serial layers are intentionally not registered as
    # submodules of this nn.Module.
    self.conv2 = [nn.Conv2d(128, 256, 5, 5, **kwargs) for _ in range(B)]
    self.linear1 = [nn.Linear(500, 1000, **kwargs) for _ in range(B)]
  else:
    self.conv2 = get_hfta_op_for(nn.Conv2d, B=B)(128, 256, 5, 5, **kwargs)
    self.linear1 = get_hfta_op_for(nn.Linear, B=B)(500, 1000, **kwargs)
  self.linear2 = get_hfta_op_for(nn.Linear, B=B)(1000, 500, **kwargs)
  self.partially_fused = partially_fused
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, B=1):
  """3x3 convolution with padding"""
  conv_cls = get_hfta_op_for(nn.Conv2d, B)
  # padding == dilation preserves the spatial size for a 3x3 kernel.
  return conv_cls(in_planes,
                  out_planes,
                  kernel_size=3,
                  stride=stride,
                  padding=dilation,
                  groups=groups,
                  bias=False,
                  dilation=dilation)
def testcase(
    B=3,
    N=32,
    input_dim=(20, ),
    num_embeddings=200,
    embedding_dim=50,
    padding_idx=None,
    max_norm=None,
    norm_type=2.,
    scale_grad_by_freq=False,
    sparse=False,
    _weight=None,
    device=torch.device('cpu'),
    x_dtype=torch.int,
    param_dtype=torch.float,
):
  """Check that one fused Embedding matches B serial nn.Embedding modules."""
  with torch.no_grad():
    shape = [N] + list(input_dim)
    inputs = [
        torch.randint(num_embeddings, shape, device=device, dtype=x_dtype)
        for _ in range(B)
    ]
    # Fused input carries the model dimension B in front: [B, N, ...].
    fused_input = torch.stack(inputs, dim=0)
    args = (num_embeddings, embedding_dim)
    kwargs = {
        'padding_idx': padding_idx,
        'max_norm': max_norm,
        'norm_type': norm_type,
        'scale_grad_by_freq': scale_grad_by_freq,
        'sparse': sparse,
        '_weight': _weight,
        'device': device,
        'dtype': param_dtype,
    }
    serial_ops = [nn.Embedding(*args, **kwargs) for _ in range(B)]
    fused_op = get_hfta_op_for(nn.Embedding, B=B)(*args, **kwargs)
    # Copy each serial module's weight into fused slot b.
    for b, op in enumerate(serial_ops):
      fused_op.snatch_parameters(op, b)
    expected = torch.stack([op(x) for op, x in zip(serial_ops, inputs)],
                           dim=0)
    actual = fused_op(fused_input)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
      )
    except AssertionError as e:
      dump_error_msg(e)
def __init__(self, d_model, dropout=0.1, max_len=5000, B=1):
  """Precompute sinusoidal positional encodings for up to max_len positions."""
  super(PositionalEncoding, self).__init__()
  self.dropout = get_hfta_op_for(nn.Dropout, B)(p=dropout)
  self.B = B
  positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
  # Geometric frequency schedule: 10000^(-2i/d_model) for even index i.
  frequencies = torch.exp(
      torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
  pe = torch.zeros(max_len, d_model)
  pe[:, 0::2] = torch.sin(positions * frequencies)
  pe[:, 1::2] = torch.cos(positions * frequencies)
  # Buffer (not a parameter): moves with .to()/.cuda() but is not trained.
  self.register_buffer('pe', pe.unsqueeze(0))
def __init__(self,
             block,
             layers,
             num_classes=10,
             zero_init_residual=False,
             track_running_stats=True,
             B=1):
  """HFTA-fused ResNet.

  :param block: residual block class (e.g. BasicBlock).
  :param layers: number of residual blocks in each of the four stages.
  :param num_classes: output size of the final fully-connected layer.
  :param zero_init_residual: zero-init the last BN of each residual branch.
  :param track_running_stats: forwarded to every BatchNorm layer.
  :param B: number of fused models (0 selects the plain torch operators).
  """
  super(ResNet, self).__init__()
  self.B = B
  self.track_running_stats = track_running_stats
  norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
  # Keep the concrete operator classes so the isinstance() checks in the
  # init loop below match both fused (B > 0) and plain (B == 0) variants.
  self._conv_layer = get_hfta_op_for(nn.Conv2d,
                                     B).func if B > 0 else nn.Conv2d
  self._norm_layer = get_hfta_op_for(nn.BatchNorm2d,
                                     B).func if B > 0 else nn.BatchNorm2d
  self._linear_layer = get_hfta_op_for(nn.Linear,
                                       B).func if B > 0 else nn.Linear
  self.inplanes = 64
  # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 maxpool.
  self.conv1 = get_hfta_op_for(nn.Conv2d, B=B)(3,
                                               self.inplanes,
                                               kernel_size=7,
                                               stride=2,
                                               padding=3,
                                               bias=False)
  self.bn1 = norm_layer(self.inplanes,
                        track_running_stats=track_running_stats)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = get_hfta_op_for(nn.MaxPool2d, B=B)(kernel_size=3,
                                                    stride=2,
                                                    padding=1)
  # Four stages; stages 2-4 halve the spatial size via stride=2.
  self.layer1 = self._make_layer(block, 64, layers[0], B=B)
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2, B=B)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2, B=B)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2, B=B)
  self.fc = get_hfta_op_for(nn.Linear, B)(512 * block.expansion,
                                          num_classes)
  # Standard ResNet init: He init for convs, (1, 0) for norm affine params.
  for m in self.modules():
    if isinstance(m, self._conv_layer):
      nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, self._norm_layer):
      nn.init.constant_(m.weight, 1)
      nn.init.constant_(m.bias, 0)
  # Zero-initialize the last BN in each residual branch,
  # so that the residual branch starts with zeros, and each residual block behaves like an identity.
  # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  if zero_init_residual:
    for m in self.modules():
      if isinstance(m, BasicBlock):
        nn.init.constant_(m.bn2.weight, 0)
def testcase_Conv2d(
    B=3,
    N=32,
    Cin=4,
    Cout=16,
    kernel_size=3,
    HWin=28,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    padding_mode='zeros',
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Check that one fused Conv2d matches B serial nn.Conv2d modules."""
  with torch.no_grad():
    inputs = [
        torch.rand(N, Cin, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused input layout: [N, B, Cin, H, W].
    fused_input = torch.stack(inputs, dim=1)
    args = (Cin, Cout, kernel_size)
    kwargs = {
        'stride': stride,
        'padding': padding,
        'dilation': dilation,
        'groups': groups,
        'bias': bias,
        'padding_mode': padding_mode,
        'device': device,
        'dtype': dtype,
    }
    serial_ops = [nn.Conv2d(*args, **kwargs) for _ in range(B)]
    fused_op = get_hfta_op_for(nn.Conv2d, B=B)(*args, **kwargs)
    # Copy each serial module's weights/bias into fused slot b.
    for b, op in enumerate(serial_ops):
      fused_op.snatch_parameters(op, b)
    expected = torch.stack([op(x) for op, x in zip(serial_ops, inputs)],
                           dim=1)
    actual = fused_op(fused_input)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
          population_threshold=1e-2,
      )
    except AssertionError as e:
      dump_error_msg(e)
def __init__(self,
             inplanes,
             planes,
             stride=1,
             downsample=None,
             norm_layer=None,
             track_running_stats=True,
             B=1):
  """Two-conv residual block (HFTA-fused operators when B > 0).

  Both self.conv1 and self.downsample downsample the input when stride != 1.
  """
  super(BasicBlock, self).__init__()
  if norm_layer is None:
    norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
  self.conv1 = conv3x3(inplanes, planes, stride, B=B)
  self.bn1 = norm_layer(planes, track_running_stats=track_running_stats)
  self.relu = nn.ReLU(inplace=True)
  self.conv2 = conv3x3(planes, planes, B=B)
  self.bn2 = norm_layer(planes, track_running_stats=track_running_stats)
  self.downsample = downsample
  self.stride = stride
def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout, B=1):
  """One transformer encoder layer: attention + FFN, each in a sublayer.

  :param hidden: hidden size of transformer
  :param attn_heads: number of attention heads
  :param feed_forward_hidden: FFN inner size, usually 4 * hidden
  :param dropout: dropout rate
  :param B: number of fused models
  """
  super().__init__()
  self.attention = MultiheadAttention(hidden, attn_heads, B=B)
  self.feed_forward = PositionwiseFeedForward(d_model=hidden,
                                              d_ff=feed_forward_hidden,
                                              dropout=dropout,
                                              B=B)
  self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout, B=B)
  self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout, B=B)
  self.dropout = get_hfta_op_for(nn.Dropout, B)(p=dropout)
def testcase(
    B=5,
    N=64,
    normalized_shape=(200, ),
    elementwise_affine=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Check that one fused LayerNorm matches B serial nn.LayerNorm modules."""
  with torch.no_grad():
    shape = [N] + list(normalized_shape)
    inputs = [
        torch.rand(shape, device=device, dtype=dtype) for _ in range(B)
    ]
    # Fused input carries the model dimension B in front: [B, N, ...].
    fused_input = torch.stack(inputs, dim=0)
    args = (normalized_shape, )
    kwargs = {
        'elementwise_affine': elementwise_affine,
        'device': device,
        'dtype': dtype,
    }
    serial_ops = [nn.LayerNorm(*args, **kwargs) for _ in range(B)]
    fused_op = get_hfta_op_for(nn.LayerNorm, B=B)(*args, **kwargs)
    # Copy each serial module's affine parameters into fused slot b.
    for b, op in enumerate(serial_ops):
      fused_op.snatch_parameters(op, b)
    expected = torch.stack([op(x) for op, x in zip(serial_ops, inputs)],
                           dim=0)
    actual = fused_op(fused_input)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
      )
    except AssertionError as e:
      dump_error_msg(e)
def __init__(self,
             vocab_size,
             hidden=768,
             n_layers=12,
             attn_heads=12,
             dropout=0.1,
             B=1):
  """BERT encoder: embeddings, n_layers transformer blocks, and LM head.

  :param vocab_size: vocab_size of total words
  :param hidden: BERT model hidden size
  :param n_layers: number of transformer blocks (layers)
  :param attn_heads: number of attention heads
  :param dropout: dropout rate
  :param B: number of fused models
  """
  super().__init__()
  self.hidden = hidden
  self.n_layers = n_layers
  self.attn_heads = attn_heads
  self.B = B
  # The paper uses an FFN inner size of 4 * hidden.
  self.feed_forward_hidden = hidden * 4
  # Embedding for BERT: sum of positional, segment, and token embeddings.
  self.embedding = BERTEmbedding(vocab_size=vocab_size,
                                 embed_size=hidden,
                                 B=B)
  self.transformer_blocks = nn.ModuleList([
      TransformerBlock(hidden, attn_heads, hidden * 4, dropout, B=B)
      for _ in range(n_layers)
  ])
  self.output = get_hfta_op_for(nn.Linear, B)(hidden, vocab_size)
def convert_ops(B, *torch_op_classes):
  """Lazily yield the HFTA-fused counterpart of each given torch op class."""
  return (get_hfta_op_for(cls, B=B) for cls in torch_op_classes)
def testcase_2d(
    num_features=128,
    eps=1e-5,
    momentum=0.1,
    affine=True,
    track_running_stats=True,
    B=3,
    N=8,
    HWin=28,
    train_test_steps=10,
    training=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Check a fused BatchNorm2d against B serial modules over several steps.

  Running stats (when tracked) evolve across steps, so the outputs must
  match at every step, not just the first.
  """
  C = num_features
  with torch.no_grad():
    args = (num_features, )
    kwargs = {
        'eps': eps,
        'momentum': momentum,
        'affine': affine,
        'track_running_stats': track_running_stats,
        'device': device,
        'dtype': dtype,
    }
    bn_array = [nn.BatchNorm2d(*args, **kwargs) for _ in range(B)]
    bn_fused = get_hfta_op_for(nn.BatchNorm2d, B=B)(*args, **kwargs)
    if track_running_stats:
      # Randomize the running stats so the comparison is non-trivial.
      rand_int = random.randint(0, 1024)
      for bn in bn_array:
        nn.init.normal_(bn.running_mean)
        nn.init.normal_(bn.running_var)
        bn.num_batches_tracked.fill_(rand_int)
    # Copy each serial module's parameters/buffers into fused slot b.
    for b in range(B):
      bn_fused.snatch_parameters(bn_array[b], b)
    mode = 'train' if training else 'eval'
    for bn in bn_array:
      getattr(bn, mode)()
    getattr(bn_fused, mode)()
    # Check that fused outputs stay identical over several training steps.
    for _ in range(train_test_steps):
      inputs = [
          torch.rand(N, C, HWin, HWin, device=device, dtype=dtype)
          for _ in range(B)
      ]
      fused_input = torch.stack(inputs, dim=1)
      expected = torch.stack([bn_array[b](inputs[b]) for b in range(B)],
                             dim=1)
      actual = bn_fused(fused_input)
      try:
        assert_allclose(
            actual.cpu().numpy(),
            expected.cpu().numpy(),
            rtol=1e-4,
        )
      except AssertionError as e:
        dump_error_msg(e)
def _make_layer(self,
                block,
                serial_block,
                planes,
                blocks,
                run_in_serial,
                stride=1,
                B=1):
  """Build one stage, choosing fused or serial variant per residual block.

  :param block: fused residual block class.
  :param serial_block: unfused (per-model) residual block class.
  :param planes: base channel count for this stage.
  :param blocks: number of residual blocks in the stage.
  :param run_in_serial: per-block flags; truthy means run that block serially.
  :param stride: stride of the first block (downsamples when != 1).
  :param B: number of fused models.
  """
  downsample = None
  norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
  # Fused and serial block variants must agree on channel expansion.
  assert block.expansion == serial_block.expansion
  if stride != 1 or self.inplanes != planes * block.expansion:
    # Projection shortcut; built with plain (B=0) ops when the first block
    # of the stage runs serially.
    if self.B > 0 and run_in_serial[0]:
      downsample = nn.Sequential(
          conv1x1(self.inplanes, planes * block.expansion, stride, B=0),
          nn.BatchNorm2d(planes * block.expansion),
      )
    else:
      downsample = nn.Sequential(
          conv1x1(self.inplanes, planes * block.expansion, stride, B=B),
          norm_layer(planes * block.expansion,
                     track_running_stats=self.track_running_stats),
      )
  layers = []
  # First block of the stage (the only one that may downsample).
  if self.B > 0 and run_in_serial[0]:
    current_block = serial_block(self.inplanes,
                                 planes,
                                 stride,
                                 downsample,
                                 nn.BatchNorm2d,
                                 B=B,
                                 track_running_stats=self.track_running_stats)
    # Serial blocks are additionally tracked in self.unfused_layers —
    # presumably so they can be handled per-model elsewhere; confirm.
    self.unfused_layers.append(current_block)
  else:
    current_block = block(self.inplanes,
                          planes,
                          stride,
                          downsample,
                          norm_layer,
                          B=B,
                          track_running_stats=self.track_running_stats)
  layers.append(current_block)
  self.inplanes = planes * block.expansion
  # Remaining blocks keep the channel count and spatial size fixed.
  for i in range(1, blocks):
    if self.B > 0 and run_in_serial[i]:
      current_block = serial_block(
          self.inplanes,
          planes,
          norm_layer=nn.BatchNorm2d,
          B=B,
          track_running_stats=self.track_running_stats)
      self.unfused_layers.append(current_block)
    else:
      current_block = block(
          self.inplanes,
          planes,
          norm_layer=norm_layer,
          B=B,
          track_running_stats=self.track_running_stats)
    layers.append(current_block)
  return nn.Sequential(*layers)
def __init__(self,
             config,
             block,
             serial_block,
             num_classes=10,
             zero_init_residual=False,
             track_running_stats=True,
             B=1):
  """ResNet in which selected parts run fused (HFTA) and the rest serially.

  :param config: dict with "layers" (blocks per stage) and "run_in_serial"
      (per-stage flag lists; index 4 holds [fc_serial, stem_serial] — based
      on the usage below, confirm against callers).
  :param block: fused residual block class.
  :param serial_block: unfused (per-model) residual block class.
  :param num_classes: output size of the final fully-connected layer.
  :param zero_init_residual: zero-init the last BN of each residual branch.
  :param track_running_stats: forwarded to every fused BatchNorm layer.
  :param B: number of fused models (0 selects the plain torch operators).
  """
  super(PartiallyFusedResNet, self).__init__()
  layers = config["layers"]
  run_in_serial = config["run_in_serial"]
  self.B = B
  self.track_running_stats = track_running_stats
  norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
  # Concrete operator classes so the isinstance() init loop below matches
  # both fused (B > 0) and plain (B == 0) variants.
  self._conv_layer = get_hfta_op_for(nn.Conv2d,
                                     B).func if B > 0 else nn.Conv2d
  self._norm_layer = get_hfta_op_for(nn.BatchNorm2d,
                                     B).func if B > 0 else nn.BatchNorm2d
  self._linear_layer = get_hfta_op_for(nn.Linear,
                                       B).func if B > 0 else nn.Linear
  self.inplanes = 64
  # Stem: conv + BN + ReLU + maxpool, serial or fused per run_in_serial[4][1].
  # NOTE(review): self.unfused_layers is appended to but never initialized
  # in this constructor — presumably created elsewhere; verify.
  if run_in_serial[4][1]:
    self.convBlock = SerialConvBlock(B, 3, self.inplanes)
    self.unfused_layers.append(self.convBlock)
  else:
    self.convBlock = nn.Sequential(
        get_hfta_op_for(nn.Conv2d, B=B)(3,
                                        self.inplanes,
                                        kernel_size=7,
                                        stride=2,
                                        padding=3,
                                        bias=False),
        norm_layer(self.inplanes,
                   track_running_stats=track_running_stats),
        nn.ReLU(inplace=True),
        get_hfta_op_for(nn.MaxPool2d, B=B)(kernel_size=3,
                                           stride=2,
                                           padding=1))
  # Four stages; stages 2-4 halve the spatial size via stride=2.
  self.layer1 = self._make_layer(block,
                                 serial_block,
                                 64,
                                 layers[0],
                                 run_in_serial[0],
                                 B=B)
  self.layer2 = self._make_layer(block,
                                 serial_block,
                                 128,
                                 layers[1],
                                 run_in_serial[1],
                                 stride=2,
                                 B=B)
  self.layer3 = self._make_layer(block,
                                 serial_block,
                                 256,
                                 layers[2],
                                 run_in_serial[2],
                                 stride=2,
                                 B=B)
  self.layer4 = self._make_layer(block,
                                 serial_block,
                                 512,
                                 layers[3],
                                 run_in_serial[3],
                                 stride=2,
                                 B=B)
  # Classification head, serial or fused per run_in_serial[4][0].
  if run_in_serial[4][0]:
    self.fc = SerialLinear(512 * block.expansion, num_classes, B=B)
    self.unfused_layers.append(self.fc)
  else:
    self.fc = get_hfta_op_for(nn.Linear, B)(512 * block.expansion,
                                            num_classes)
  # Standard ResNet init: He init for convs, (1, 0) for norm affine params.
  for m in self.modules():
    if isinstance(m, self._conv_layer):
      nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, self._norm_layer):
      nn.init.constant_(m.weight, 1)
      nn.init.constant_(m.bias, 0)
  # Zero-initialize the last BN in each residual branch,
  # so that the residual branch starts with zeros, and each residual block behaves like an identity.
  # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  if zero_init_residual:
    for m in self.modules():
      if isinstance(m, BasicBlock):
        nn.init.constant_(m.bn2.weight, 0)
# custom weights initialization called on netG and netD # NOTE(wangsh46): This is okay for HFTA, because torch.nn.init.normal_ (or # torch.Tensor.normal_ to be specific) is element-wise; so is # torch.nn.init.zeros_ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: torch.nn.init.normal_(m.weight, 0.0, 0.02) elif classname.find('BatchNorm') != -1: torch.nn.init.normal_(m.weight, 1.0, 0.02) torch.nn.init.zeros_(m.bias) ConvTranspose2d = get_hfta_op_for(nn.ConvTranspose2d, B=B) BatchNorm2d = get_hfta_op_for(nn.BatchNorm2d, B=B) ReLU = get_hfta_op_for(nn.ReLU, B=B) Tanh = get_hfta_op_for(nn.Tanh, B=B) class Generator(nn.Module): def __init__(self, ngpu): super(Generator, self).__init__() self.ngpu = ngpu self.main = nn.Sequential( # input is Z, going into a convolution ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), BatchNorm2d(ngf * 8, track_running_stats=(args.device != 'xla')), ReLU(True), # state size. (ngf*8) x 4 x 4
def testcase_ConvTranspose2d(
    B=3,
    N=32,
    Cin=4,
    Cout=16,
    kernel_size=3,
    HWin=28,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    bias=True,
    dilation=1,
    padding_mode='zeros',
    output_size=None,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Compare one fused ConvTranspose2d against B serial modules.

  Runs B random inputs through B independent nn.ConvTranspose2d modules and
  through one fused module whose per-slot parameters were copied from them,
  then checks the outputs match.
  """
  with torch.no_grad():
    x_array = [
        torch.rand(N, Cin, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused input layout: [N, B, Cin, H, W].
    x_fused = torch.cat([x.unsqueeze(1) for x in x_array], dim=1)
    args = (Cin, Cout, kernel_size)
    # Handle output_padding: PyTorch requires output_padding to be smaller
    # than either stride or dilation, so both are bumped to keep the module
    # construction valid.
    if output_padding != 0:
      stride = output_padding + 1
      dilation = output_padding + 1
    # Handle output_size argument for the forward function
    if output_size:
      # The hardcoded input 57 and 58 are the possible size given stride == 2
      stride = 2
      # NOTE(review): a length-3 output_size has its middle element (index 1,
      # presumably the fused B dim) dropped for the serial calls — confirm.
      output_size_arg = output_size if len(output_size) == 2 else (
          output_size[0:1] + output_size[2:])
    else:
      output_size_arg = None
    kwargs = {
        'stride': stride,
        'padding': padding,
        'output_padding': output_padding,
        'groups': groups,
        'bias': bias,
        'dilation': dilation,
        'padding_mode': padding_mode,
        'device': device,
        'dtype': dtype,
    }
    conv_array = [nn.ConvTranspose2d(*args, **kwargs) for _ in range(B)]
    conv_fused = get_hfta_op_for(nn.ConvTranspose2d, B=B)(*args, **kwargs)
    # Init weights and biases.
    for b in range(B):
      conv_fused.snatch_parameters(conv_array[b], b)
    y_array = [
        conv_array[b](x_array[b], output_size=output_size_arg)
        for b in range(B)
    ]
    y_fused_actual = conv_fused(x_fused, output_size=output_size)
    y_fused_expect = torch.cat([y.unsqueeze(1) for y in y_array], dim=1)
    try:
      assert_allclose(
          y_fused_actual.cpu().numpy(),
          y_fused_expect.cpu().numpy(),
          rtol=1e-4,
          population_threshold=1e-2,
      )
      # When output_size was forced, also verify the resolved shapes agree.
      if output_size:
        assert (
            y_fused_actual.shape == y_fused_expect.shape
        ), "The actual output size ({}) is different from the expected output size ({}).".format(
            y_fused_actual.shape, y_fused_expect.shape)
    except AssertionError as e:
      dump_error_msg(e)