def testcase_MaxPool2d(
    B=3,
    N=32,
    C=16,
    kernel_size=2,
    HWin=28,
    stride=None,
    padding=0,
    dilation=1,
    return_indices=False,
    ceil_mode=False,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused MaxPool2d op against B separate ``nn.MaxPool2d`` modules.

    B random inputs of shape (N, C, HWin, HWin) are pooled one by one with
    ``nn.MaxPool2d`` and, stacked along dim 1, in a single call to the op
    built by ``get_hfta_op_for(nn.MaxPool2d, B=B)``. The fused output is
    compared with the per-module outputs stacked along dim 1; when
    ``return_indices`` is set, the returned indices are compared the same way.
    A failed comparison is handed to ``dump_error_msg``.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N, C, HWin: Sizes used for each (N, C, HWin, HWin) input tensor.
        kernel_size, stride, padding, dilation, return_indices, ceil_mode:
            Forwarded unchanged to both the unfused and fused constructors.
        device: Device on which the random inputs are created.
        dtype: Dtype of the random inputs.
    """
    with torch.no_grad():
        inputs = []
        for _ in range(B):
            inputs.append(
                torch.rand(N, C, HWin, HWin, device=device, dtype=dtype))
        fused_input = torch.stack(inputs, dim=1)

        pool_kwargs = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )
        unfused_pools = [
            nn.MaxPool2d(kernel_size, **pool_kwargs) for _ in range(B)
        ]
        fused_pool = get_hfta_op_for(nn.MaxPool2d, B=B)(kernel_size,
                                                        **pool_kwargs)

        unfused_results = [
            pool(x) for pool, x in zip(unfused_pools, inputs)
        ]
        fused_result = fused_pool(fused_input)

        # With return_indices every result is a (values, indices) pair.
        if return_indices:
            expected_values = tuple(pair[0] for pair in unfused_results)
            expected_indices = tuple(pair[1] for pair in unfused_results)
            actual_values, actual_indices = fused_result
        else:
            expected_values = unfused_results
            actual_values = fused_result

        stacked_values = torch.stack(expected_values, dim=1)
        try:
            assert_allclose(
                actual_values.cpu().numpy(),
                stacked_values.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as err:
            dump_error_msg(err)

        if not return_indices:
            return

        stacked_indices = torch.stack(expected_indices, dim=1)
        try:
            assert_allclose(
                actual_indices.cpu().numpy(),
                stacked_indices.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as err:
            dump_error_msg(err)
def testcase(
    B=3,
    N=32,
    L=8,
    in_features=20,
    out_features=50,
    bias=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused Linear op against B separate ``nn.Linear`` modules.

    B random inputs of shape (N, L, in_features) are run through B
    ``nn.Linear`` modules and, stacked along dim 0, through the op built by
    ``get_hfta_op_for(nn.Linear, B=B)``. Before the forward passes,
    ``snatch_parameters(module, b)`` is called on the fused op for every
    module (presumably copying that module's parameters into slot ``b``; the
    method is defined outside this block). The fused output is compared with
    the per-module outputs stacked along dim 0, and a failed comparison is
    handed to ``dump_error_msg``.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N, L: Leading sizes of each (N, L, in_features) input tensor.
        in_features, out_features, bias: Forwarded to both constructors.
        device: Device for the inputs and, via kwargs, the modules.
        dtype: Dtype for the inputs and, via kwargs, the modules.
    """
    with torch.no_grad():
        inputs = []
        for _ in range(B):
            inputs.append(
                torch.rand(N, L, in_features, device=device, dtype=dtype))
        fused_input = torch.stack(inputs, dim=0)

        module_kwargs = dict(bias=bias, device=device, dtype=dtype)
        unfused_layers = [
            nn.Linear(in_features, out_features, **module_kwargs)
            for _ in range(B)
        ]
        fused_layer = get_hfta_op_for(nn.Linear, B=B)(
            in_features,
            out_features,
            **module_kwargs,
        )

        # Init weights and biases.
        for slot, layer in enumerate(unfused_layers):
            fused_layer.snatch_parameters(layer, slot)

        unfused_outputs = [
            layer(x) for layer, x in zip(unfused_layers, inputs)
        ]
        actual = fused_layer(fused_input)
        expected = torch.stack(unfused_outputs, dim=0)

        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
                population_threshold=1e-3,
            )
        except AssertionError as err:
            dump_error_msg(err)
def testcase_AdaptiveAvgPool2d(
    B=3,
    N=32,
    C=16,
    HWin=28,
    output_size=(16, 16),
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused AdaptiveAvgPool2d op against B separate modules.

    B random inputs of shape (N, C, HWin, HWin) are pooled one by one with
    ``nn.AdaptiveAvgPool2d(output_size)`` and, stacked along dim 1, in a
    single call to the op built by
    ``get_hfta_op_for(nn.AdaptiveAvgPool2d, B=B)``. The fused output is
    compared with the per-module outputs stacked along dim 1, and a failed
    comparison is handed to ``dump_error_msg``.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N, C, HWin: Sizes used for each (N, C, HWin, HWin) input tensor.
        output_size: Forwarded to both the unfused and fused constructors.
        device: Device on which the random inputs are created.
        dtype: Dtype of the random inputs.
    """
    with torch.no_grad():
        inputs = []
        for _ in range(B):
            inputs.append(
                torch.rand(N, C, HWin, HWin, device=device, dtype=dtype))
        fused_input = torch.stack(inputs, dim=1)

        unfused_pools = [nn.AdaptiveAvgPool2d(output_size) for _ in range(B)]
        fused_pool = get_hfta_op_for(nn.AdaptiveAvgPool2d, B=B)(output_size)

        unfused_outputs = [
            pool(x) for pool, x in zip(unfused_pools, inputs)
        ]
        actual = fused_pool(fused_input)
        expected = torch.stack(unfused_outputs, dim=1)

        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as err:
            dump_error_msg(err)
def testcase(
    B=3,
    N=32,
    input_dim=(20, ),
    num_embeddings=200,
    embedding_dim=50,
    padding_idx=None,
    max_norm=None,
    norm_type=2.,
    scale_grad_by_freq=False,
    sparse=False,
    _weight=None,
    device=torch.device('cpu'),
    x_dtype=torch.int,
    param_dtype=torch.float,
):
    """Check a fused Embedding op against B separate ``nn.Embedding`` modules.

    B random integer index tensors of shape ``[N] + list(input_dim)``, drawn
    with ``torch.randint(num_embeddings, ...)``, are looked up in B
    ``nn.Embedding`` modules and, stacked along dim 0, in the op built by
    ``get_hfta_op_for(nn.Embedding, B=B)``. Before the forward passes,
    ``snatch_parameters(module, b)`` is called on the fused op for every
    module (presumably copying that module's parameters into slot ``b``; the
    method is defined outside this block). The fused output is compared with
    the per-module outputs stacked along dim 0, and a failed comparison is
    handed to ``dump_error_msg``.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N: Leading size of every index tensor.
        input_dim: Remaining sizes of every index tensor.
        num_embeddings, embedding_dim: Positional constructor arguments;
            ``num_embeddings`` is also the first argument to
            ``torch.randint``.
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight:
            Forwarded unchanged to both constructors.
        device: Device for the index tensors and, via kwargs, the modules.
        x_dtype: Dtype of the index tensors.
        param_dtype: Passed to the constructors as ``dtype``.
    """
    with torch.no_grad():
        index_shape = [N, *input_dim]
        inputs = []
        for _ in range(B):
            inputs.append(
                torch.randint(
                    num_embeddings,
                    index_shape,
                    device=device,
                    dtype=x_dtype,
                ))
        fused_input = torch.stack(inputs, dim=0)

        module_kwargs = dict(
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=param_dtype,
        )
        unfused_tables = [
            nn.Embedding(num_embeddings, embedding_dim, **module_kwargs)
            for _ in range(B)
        ]
        fused_table = get_hfta_op_for(nn.Embedding, B=B)(
            num_embeddings,
            embedding_dim,
            **module_kwargs,
        )

        # Init weights and biases.
        for slot, table in enumerate(unfused_tables):
            fused_table.snatch_parameters(table, slot)

        unfused_outputs = [
            table(x) for table, x in zip(unfused_tables, inputs)
        ]
        actual = fused_table(fused_input)
        expected = torch.stack(unfused_outputs, dim=0)

        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as err:
            dump_error_msg(err)
def testcase_Conv2d(
    B=3,
    N=32,
    Cin=4,
    Cout=16,
    kernel_size=3,
    HWin=28,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    padding_mode='zeros',
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused Conv2d op against B separate ``nn.Conv2d`` modules.

    B random inputs of shape (N, Cin, HWin, HWin) are convolved one by one
    with ``nn.Conv2d`` and, stacked along dim 1, in a single call to the op
    built by ``get_hfta_op_for(nn.Conv2d, B=B)``. Before the forward passes,
    ``snatch_parameters(module, b)`` is called on the fused op for every
    module (presumably copying that module's parameters into slot ``b``; the
    method is defined outside this block). The fused output is compared with
    the per-module outputs stacked along dim 1, and a failed comparison is
    handed to ``dump_error_msg``.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N, HWin: Batch and spatial sizes of each (N, Cin, HWin, HWin) input.
        Cin, Cout, kernel_size: Positional constructor arguments.
        stride, padding, dilation, groups, bias, padding_mode:
            Forwarded unchanged to both constructors.
        device: Device for the inputs and, via kwargs, the modules.
        dtype: Dtype for the inputs and, via kwargs, the modules.
    """
    with torch.no_grad():
        inputs = []
        for _ in range(B):
            inputs.append(
                torch.rand(N, Cin, HWin, HWin, device=device, dtype=dtype))
        fused_input = torch.stack(inputs, dim=1)

        conv_kwargs = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        unfused_convs = [
            nn.Conv2d(Cin, Cout, kernel_size, **conv_kwargs)
            for _ in range(B)
        ]
        fused_conv = get_hfta_op_for(nn.Conv2d, B=B)(
            Cin,
            Cout,
            kernel_size,
            **conv_kwargs,
        )

        # Init weights and biases.
        for slot, conv in enumerate(unfused_convs):
            fused_conv.snatch_parameters(conv, slot)

        unfused_outputs = [
            conv(x) for conv, x in zip(unfused_convs, inputs)
        ]
        actual = fused_conv(fused_input)
        expected = torch.stack(unfused_outputs, dim=1)

        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
                population_threshold=1e-2,
            )
        except AssertionError as err:
            dump_error_msg(err)
def _assert_params_conv2d(fused_op, op, b, fused=True):
    """Compare the parameters of ``op`` with those held by ``fused_op``.

    When ``fused`` is true, slice ``b`` of ``fused_op.weight`` is compared
    with ``op.weight`` and, if ``fused_op.bias`` is not None, slice ``b`` of
    ``fused_op.bias`` with ``op.bias``. When ``fused`` is false the check is
    delegated to ``_assert_params_unfused`` (defined outside this block).
    Any ``AssertionError`` raised along the way is handed to
    ``dump_error_msg``.

    Args:
        fused_op: Object exposing ``weight`` and ``bias`` attributes that are
            indexed by ``b`` on the fused path.
        op: Object exposing ``weight`` and ``bias`` to compare against.
        b: Index selecting the slice of the fused parameters.
        fused: Selects the fused (default) or unfused comparison path.
    """
    try:
        if not fused:
            _assert_params_unfused(fused_op, op, b)
            return

        tolerances = {'rtol': 1e-4, 'population_threshold': 1e-2}

        fused_weight = fused_op.weight.data[b].cpu().numpy()
        plain_weight = op.weight.data.cpu().numpy()
        assert_allclose(fused_weight, plain_weight, **tolerances)

        # Nothing more to compare when the fused op has no bias.
        if fused_op.bias is None:
            return

        fused_bias = fused_op.bias.data[b].cpu().numpy()
        plain_bias = op.bias.data.cpu().numpy()
        assert_allclose(fused_bias, plain_bias, **tolerances)
    except AssertionError as err:
        dump_error_msg(err)
def testcase(
    B=5,
    N=64,
    normalized_shape=(200, ),
    elementwise_affine=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused LayerNorm op against B separate ``nn.LayerNorm`` modules.

    B random inputs of shape ``[N] + list(normalized_shape)`` are normalized
    one by one with ``nn.LayerNorm`` and, stacked along dim 0, in a single
    call to the op built by ``get_hfta_op_for(nn.LayerNorm, B=B)``. Before
    the forward passes, ``snatch_parameters(module, b)`` is called on the
    fused op for every module (presumably copying that module's parameters
    into slot ``b``; the method is defined outside this block). The fused
    output is compared with the per-module outputs stacked along dim 0, and a
    failed comparison is handed to ``dump_error_msg``.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N: Leading size of every input tensor.
        normalized_shape: Remaining input sizes; also the positional
            constructor argument.
        elementwise_affine: Forwarded unchanged to both constructors.
        device: Device for the inputs and, via kwargs, the modules.
        dtype: Dtype for the inputs and, via kwargs, the modules.
    """
    with torch.no_grad():
        input_shape = [N, *normalized_shape]
        inputs = []
        for _ in range(B):
            inputs.append(torch.rand(input_shape, device=device, dtype=dtype))
        fused_input = torch.stack(inputs, dim=0)

        norm_kwargs = dict(
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )
        unfused_norms = [
            nn.LayerNorm(normalized_shape, **norm_kwargs) for _ in range(B)
        ]
        fused_norm = get_hfta_op_for(nn.LayerNorm, B=B)(
            normalized_shape,
            **norm_kwargs,
        )

        # Init weights and biases.
        for slot, norm in enumerate(unfused_norms):
            fused_norm.snatch_parameters(norm, slot)

        unfused_outputs = [
            norm(x) for norm, x in zip(unfused_norms, inputs)
        ]
        actual = fused_norm(fused_input)
        expected = torch.stack(unfused_outputs, dim=0)

        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as err:
            dump_error_msg(err)
def testcase_2d(
    num_features=128,
    eps=1e-5,
    momentum=0.1,
    affine=True,
    track_running_stats=True,
    B=3,
    N=8,
    HWin=28,
    train_test_steps=10,
    training=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused BatchNorm2d op against B separate ``nn.BatchNorm2d``.

    B ``nn.BatchNorm2d`` modules and one op built by
    ``get_hfta_op_for(nn.BatchNorm2d, B=B)`` are created with the same
    arguments. When ``track_running_stats`` is set, each unfused module gets
    normally-distributed ``running_mean``/``running_var`` and a shared random
    ``num_batches_tracked``. ``snatch_parameters(module, b)`` is then called
    on the fused op for every module (presumably copying that module's state
    into slot ``b``; the method is defined outside this block).

    All modules are put in train or eval mode according to ``training``, and
    for ``train_test_steps`` iterations fresh random inputs of shape
    (N, num_features, HWin, HWin) are fed to both sides; the fused output is
    compared with the per-module outputs stacked along dim 1. A failed
    comparison is handed to ``dump_error_msg``.

    Args:
        num_features, eps, momentum, affine, track_running_stats:
            Forwarded unchanged to both constructors.
        B: Number of independent inputs/modules that are fused together.
        N, HWin: Batch and spatial sizes of each input tensor.
        train_test_steps: Number of forward passes that are compared.
        training: Train mode when true, eval mode otherwise.
        device: Device for the inputs and, via kwargs, the modules.
        dtype: Dtype for the inputs and, via kwargs, the modules.
    """
    with torch.no_grad():
        bn_kwargs = dict(
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
            device=device,
            dtype=dtype,
        )
        unfused_bns = [
            nn.BatchNorm2d(num_features, **bn_kwargs) for _ in range(B)
        ]
        fused_bn = get_hfta_op_for(nn.BatchNorm2d, B=B)(
            num_features,
            **bn_kwargs,
        )

        if track_running_stats:
            # One shared batch counter, randomized running statistics.
            batches_tracked = random.randint(0, 1024)
            for module in unfused_bns:
                nn.init.normal_(module.running_mean)
                nn.init.normal_(module.running_var)
                module.num_batches_tracked.fill_(batches_tracked)

        # Init weights and biases.
        for slot, module in enumerate(unfused_bns):
            fused_bn.snatch_parameters(module, slot)

        if training:
            for module in unfused_bns:
                module.train()
            fused_bn.train()
        else:
            for module in unfused_bns:
                module.eval()
            fused_bn.eval()

        # check whether fused outputs are same in several training steps
        for _ in range(train_test_steps):
            inputs = []
            for _ in range(B):
                inputs.append(
                    torch.rand(
                        N,
                        num_features,
                        HWin,
                        HWin,
                        device=device,
                        dtype=dtype,
                    ))
            fused_input = torch.stack(inputs, dim=1)

            unfused_outputs = [
                module(x) for module, x in zip(unfused_bns, inputs)
            ]
            actual = fused_bn(fused_input)
            expected = torch.stack(unfused_outputs, dim=1)

            try:
                assert_allclose(
                    actual.cpu().numpy(),
                    expected.cpu().numpy(),
                    rtol=1e-4,
                )
            except AssertionError as err:
                dump_error_msg(err)
def testcase_ConvTranspose2d(
    B=3,
    N=32,
    Cin=4,
    Cout=16,
    kernel_size=3,
    HWin=28,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    bias=True,
    dilation=1,
    padding_mode='zeros',
    output_size=None,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Check a fused ConvTranspose2d op against B separate modules.

    B random inputs of shape (N, Cin, HWin, HWin) are run through B
    ``nn.ConvTranspose2d`` modules and, stacked along dim 1, through the op
    built by ``get_hfta_op_for(nn.ConvTranspose2d, B=B)``. Before the forward
    passes, ``snatch_parameters(module, b)`` is called on the fused op for
    every module (presumably copying that module's parameters into slot
    ``b``; the method is defined outside this block). The fused output is
    compared with the per-module outputs stacked along dim 1; when
    ``output_size`` is given the two shapes are also required to match. A
    failed check is handed to ``dump_error_msg``.

    Two arguments are adjusted before the modules are built:
      * a non-zero ``output_padding`` overrides ``stride`` and ``dilation``
        with ``output_padding + 1``;
      * a truthy ``output_size`` forces ``stride`` to 2.

    Args:
        B: Number of independent inputs/modules that are fused together.
        N, HWin: Batch and spatial sizes of each (N, Cin, HWin, HWin) input.
        Cin, Cout, kernel_size: Positional constructor arguments.
        stride, padding, output_padding, groups, bias, dilation, padding_mode:
            Forwarded to both constructors (after the overrides above).
        output_size: Passed as-is to the fused forward call. The unfused
            modules receive it unchanged when it has two entries, otherwise
            with the entry at index 1 removed.
        device: Device for the inputs and, via kwargs, the modules.
        dtype: Dtype for the inputs and, via kwargs, the modules.
    """
    with torch.no_grad():
        inputs = []
        for _ in range(B):
            inputs.append(
                torch.rand(N, Cin, HWin, HWin, device=device, dtype=dtype))
        fused_input = torch.stack(inputs, dim=1)

        # Handle output_padding
        if output_padding != 0:
            stride = dilation = output_padding + 1

        # Handle output_size argument for the forward function
        if output_size:
            # The hardcoded input 57 and 58 are the possible size given
            # stride == 2
            stride = 2
            if len(output_size) == 2:
                unfused_output_size = output_size
            else:
                # Drop the entry at index 1 for the unfused modules.
                unfused_output_size = output_size[0:1] + output_size[2:]
        else:
            unfused_output_size = None

        conv_kwargs = dict(
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        unfused_convs = [
            nn.ConvTranspose2d(Cin, Cout, kernel_size, **conv_kwargs)
            for _ in range(B)
        ]
        fused_conv = get_hfta_op_for(nn.ConvTranspose2d, B=B)(
            Cin,
            Cout,
            kernel_size,
            **conv_kwargs,
        )

        # Init weights and biases.
        for slot, conv in enumerate(unfused_convs):
            fused_conv.snatch_parameters(conv, slot)

        unfused_outputs = [
            conv(x, output_size=unfused_output_size)
            for conv, x in zip(unfused_convs, inputs)
        ]
        actual = fused_conv(fused_input, output_size=output_size)
        expected = torch.stack(unfused_outputs, dim=1)

        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
                population_threshold=1e-2,
            )
            if output_size:
                shape_error = "The actual output size ({}) is different from the expected output size ({}).".format(
                    actual.shape, expected.shape)
                assert actual.shape == expected.shape, shape_error
        except AssertionError as err:
            dump_error_msg(err)