示例#1
0
 def __init__(self, vocab_size, embed_size, dropout=0.1, B=1):
   """Build the token/position/segment embedding tables.

     :param vocab_size: total vocab size
     :param embed_size: embedding size of token embedding
     :param dropout: dropout rate
     :param B: number of fused models (HFTA fusion factor)
   """
   super().__init__()
   embedding_cls = get_hfta_op_for(nn.Embedding, B)
   # Three tables: per-token, per-position (up to 512), per-segment (3).
   self.token = embedding_cls(vocab_size, embed_size)
   self.position = embedding_cls(512, embed_size)
   self.segment = embedding_cls(3, embed_size)
   self.dropout = get_hfta_op_for(nn.Dropout, B)(p=dropout)
   self.embed_size = embed_size
示例#2
0
def testcase(
    B=3,
    N=32,
    L=8,
    in_features=20,
    out_features=50,
    bias=True,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Compare a fused Linear against B independent nn.Linear modules."""
  with torch.no_grad():
    inputs = [
        torch.rand(N, L, in_features, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused layout stacks the B models along a new leading dimension.
    inputs_fused = torch.stack(inputs, dim=0)
    linears = [
        nn.Linear(in_features, out_features, bias=bias, device=device,
                  dtype=dtype) for _ in range(B)
    ]
    fused = get_hfta_op_for(nn.Linear, B=B)(in_features, out_features,
                                            bias=bias, device=device,
                                            dtype=dtype)
    # Copy each serial module's parameters into the fused module.
    for b, linear in enumerate(linears):
      fused.snatch_parameters(linear, b)
    expected = torch.stack(
        [linear(x) for linear, x in zip(linears, inputs)], dim=0)
    actual = fused(inputs_fused)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
          population_threshold=1e-3,
      )
    except AssertionError as e:
      dump_error_msg(e)
示例#3
0
def testcase_AdaptiveAvgPool2d(
        B=3,
        N=32,
        C=16,
        HWin=28,
        output_size=(16, 16),
        device=torch.device('cpu'),
        dtype=torch.float,
):
    """Compare a fused AdaptiveAvgPool2d against B serial ones."""
    with torch.no_grad():
        inputs = [
            torch.rand(N, C, HWin, HWin, device=device, dtype=dtype)
            for _ in range(B)
        ]
        # Fused layout stacks the B models along dim 1: [N, B, C, H, W].
        inputs_fused = torch.stack(inputs, dim=1)
        pools = [nn.AdaptiveAvgPool2d(output_size) for _ in range(B)]
        fused_pool = get_hfta_op_for(nn.AdaptiveAvgPool2d, B=B)(output_size)
        expected = torch.stack(
            [pool(x) for pool, x in zip(pools, inputs)], dim=1)
        actual = fused_pool(inputs_fused)
        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as e:
            dump_error_msg(e)
示例#4
0
def testcase_MaxPool2d(
        B=3,
        N=32,
        C=16,
        kernel_size=2,
        HWin=28,
        stride=None,
        padding=0,
        dilation=1,
        return_indices=False,
        ceil_mode=False,
        device=torch.device('cpu'),
        dtype=torch.float,
):
    """Compare a fused MaxPool2d against B serial nn.MaxPool2d modules.

    When return_indices is True, both the pooled values and the argmax
    indices are compared between the fused and serial runs.
    """
    with torch.no_grad():
        x_array = [
            torch.rand(N, C, HWin, HWin, device=device, dtype=dtype)
            for _ in range(B)
        ]
        # Fused layout stacks the B models along dim 1: [N, B, C, H, W].
        x_fused = torch.cat([x.unsqueeze(1) for x in x_array], dim=1)
        args = (kernel_size, )
        kwargs = {
            'stride': stride,
            'padding': padding,
            'dilation': dilation,
            'return_indices': return_indices,
            'ceil_mode': ceil_mode,
        }
        pool_array = [nn.MaxPool2d(*args, **kwargs) for _ in range(B)]
        pool_fused = get_hfta_op_for(nn.MaxPool2d, B=B)(*args, **kwargs)
        res_array = [pool_array[b](x_array[b]) for b in range(B)]
        res_fused_actual = pool_fused(x_fused)
        if return_indices:
            # Each serial result is a (values, indices) pair; split them.
            y_array, indices_array = tuple(zip(*res_array))
            y_fused_actual, indices_fused_actual = res_fused_actual
        else:
            y_array = res_array
            y_fused_actual = res_fused_actual
        y_fused_expect = torch.cat([y.unsqueeze(1) for y in y_array], dim=1)
        try:
            assert_allclose(
                y_fused_actual.cpu().numpy(),
                y_fused_expect.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as e:
            dump_error_msg(e)
        if return_indices:
            # Indices must match exactly per model as well.
            indices_fused_expect = torch.cat(
                [indices.unsqueeze(1) for indices in indices_array],
                dim=1,
            )
            try:
                assert_allclose(
                    indices_fused_actual.cpu().numpy(),
                    indices_fused_expect.cpu().numpy(),
                    rtol=1e-4,
                )
            except AssertionError as e:
                dump_error_msg(e)
示例#5
0
    def _make_layer(self, block, planes, blocks, stride=1, B=1):
        """Stack `blocks` residual blocks; only the first may downsample."""
        norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
        out_channels = planes * block.expansion

        # A projection shortcut is required whenever the spatial size or
        # the channel count changes across this stage.
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride, B=B),
                norm_layer(out_channels,
                           track_running_stats=self.track_running_stats),
            )
        else:
            downsample = None

        layers = [
            block(self.inplanes,
                  planes,
                  stride,
                  downsample,
                  norm_layer,
                  track_running_stats=self.track_running_stats,
                  B=B)
        ]
        # Subsequent blocks keep the expanded channel count and stride 1.
        self.inplanes = out_channels
        layers.extend(
            block(self.inplanes,
                  planes,
                  norm_layer=norm_layer,
                  track_running_stats=self.track_running_stats,
                  B=B) for _ in range(1, blocks))

        return nn.Sequential(*layers)
示例#6
0
def testcase(
        B=3,
        N=32,
        C=1024,
        HWin=16,
        p=0.4,
        device=torch.device('cpu'),
        dtype=torch.float,
):
    """Check fused Dropout2d zeroes whole channels at roughly rate p.

    Feeds all-ones inputs so a kept channel must sum to HWin**2 / (1 - p)
    (inverted-dropout scaling) and a dropped channel must sum to 0.
    """
    with torch.no_grad():
        x_array = [
            torch.ones(N, C, HWin, HWin, device=device, dtype=dtype)
            for _ in range(B)
        ]
        x_fused = torch.cat([x.unsqueeze(1) for x in x_array], dim=1)
        dropout2d_fused = get_hfta_op_for(torch.nn.Dropout2d, B=B)(p)
        y_fused = dropout2d_fused(x_fused)
        # Loop-invariant: sum of a kept all-ones channel after the
        # inverted-dropout rescale.
        kept_channel_sum = HWin**2 / (1 - p)
        for b in range(B):
            y = y_fused[:, b, :, :, :]
            assert y.size(0) == N
            assert y.size(1) == C
            for n in range(N):
                zero_channels = 0
                for c in range(C):
                    s = y[n, c].sum()
                    # Each channel either has all zeros or no zeros.
                    try:
                        assert_allclose(s.cpu(), kept_channel_sum, rtol=1e-4)
                    except AssertionError:
                        # Not a kept channel, so it must be all zeros.
                        assert_allclose(s.cpu(), 0, atol=1e-4)
                        zero_channels += 1
                # Fraction of dropped channels should be close to p.
                assert_allclose(zero_channels / C, p, rtol=2e-1)
示例#7
0
def conv1x1(in_planes, out_planes, stride=1, B=1):
    """1x1 convolution"""
    conv_cls = get_hfta_op_for(nn.Conv2d, B)
    return conv_cls(in_planes, out_planes, kernel_size=1, stride=stride,
                    bias=False)
示例#8
0
 def __init__(
     self,
     B=0,
     partially_fused=False,
     device=torch.device('cpu'),
     dtype=torch.float,
 ):
   """Small conv/linear test net; conv2/linear1 may stay unfused."""
   super(_TestNet, self).__init__()
   kwargs = {'device': device, 'dtype': dtype}
   self.conv1 = get_hfta_op_for(nn.Conv2d, B=B)(256, 128, 3, 3, **kwargs)
   if partially_fused:
     # Unfused variants: one independent module per model.
     self.conv2 = [nn.Conv2d(128, 256, 5, 5, **kwargs) for _ in range(B)]
     self.linear1 = [nn.Linear(500, 1000, **kwargs) for _ in range(B)]
   else:
     self.conv2 = get_hfta_op_for(nn.Conv2d, B=B)(128, 256, 5, 5, **kwargs)
     self.linear1 = get_hfta_op_for(nn.Linear, B=B)(500, 1000, **kwargs)
   self.linear2 = get_hfta_op_for(nn.Linear, B=B)(1000, 500, **kwargs)
   self.partially_fused = partially_fused
示例#9
0
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, B=1):
    """3x3 convolution with padding"""
    conv_cls = get_hfta_op_for(nn.Conv2d, B)
    # padding == dilation keeps the spatial size for a 3x3 kernel.
    return conv_cls(in_planes,
                    out_planes,
                    kernel_size=3,
                    stride=stride,
                    padding=dilation,
                    groups=groups,
                    bias=False,
                    dilation=dilation)
示例#10
0
def testcase(
        B=3,
        N=32,
        input_dim=(20, ),
        num_embeddings=200,
        embedding_dim=50,
        padding_idx=None,
        max_norm=None,
        norm_type=2.,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=torch.device('cpu'),
        x_dtype=torch.int,
        param_dtype=torch.float,
):
    """Compare a fused Embedding against B serial nn.Embedding modules."""
    with torch.no_grad():
        inputs = [
            torch.randint(
                num_embeddings,
                [N] + list(input_dim),
                device=device,
                dtype=x_dtype,
            ) for _ in range(B)
        ]
        # Fused layout stacks the B models along a new leading dimension.
        inputs_fused = torch.stack(inputs, dim=0)
        kwargs = {
            'padding_idx': padding_idx,
            'max_norm': max_norm,
            'norm_type': norm_type,
            'scale_grad_by_freq': scale_grad_by_freq,
            'sparse': sparse,
            '_weight': _weight,
            'device': device,
            'dtype': param_dtype,
        }
        embeddings = [
            nn.Embedding(num_embeddings, embedding_dim, **kwargs)
            for _ in range(B)
        ]
        fused = get_hfta_op_for(nn.Embedding, B=B)(num_embeddings,
                                                   embedding_dim, **kwargs)
        # Copy each serial module's weights into the fused module.
        for b, emb in enumerate(embeddings):
            fused.snatch_parameters(emb, b)
        expected = torch.stack(
            [emb(x) for emb, x in zip(embeddings, inputs)], dim=0)
        actual = fused(inputs_fused)
        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as e:
            dump_error_msg(e)
示例#11
0
 def __init__(self, d_model, dropout=0.1, max_len=5000, B=1):
   """Precompute sinusoidal positional encodings up to max_len."""
   super(PositionalEncoding, self).__init__()
   self.dropout = get_hfta_op_for(nn.Dropout, B)(p=dropout)
   self.B = B
   # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
   # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
   positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
   freqs = torch.exp(
       torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
   pe = torch.zeros(max_len, d_model)
   pe[:, 0::2] = torch.sin(positions * freqs)
   pe[:, 1::2] = torch.cos(positions * freqs)
   # Buffer, not a parameter: moves with the module but is never trained.
   self.register_buffer('pe', pe.unsqueeze(0))
示例#12
0
    def __init__(self,
                 block,
                 layers,
                 num_classes=10,
                 zero_init_residual=False,
                 track_running_stats=True,
                 B=1):
        """Build a ResNet with HFTA-fused conv/BN/linear ops.

        :param block: residual block class (e.g. BasicBlock)
        :param layers: blocks per stage (indexed 0..3)
        :param num_classes: output size of the final fc layer
        :param zero_init_residual: zero-init each block's last BN weight
        :param track_running_stats: forwarded to every BatchNorm layer
        :param B: number of fused models (B == 0 falls back to plain ops)
        """
        super(ResNet, self).__init__()
        self.B = B
        self.track_running_stats = track_running_stats
        norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
        # Concrete classes for the isinstance checks in the init loop below.
        # NOTE(review): `.func` presumably exposes the underlying class of
        # the fused-op wrapper — confirm against get_hfta_op_for.
        self._conv_layer = get_hfta_op_for(nn.Conv2d,
                                           B).func if B > 0 else nn.Conv2d
        self._norm_layer = get_hfta_op_for(nn.BatchNorm2d,
                                           B).func if B > 0 else nn.BatchNorm2d
        self._linear_layer = get_hfta_op_for(nn.Linear,
                                             B).func if B > 0 else nn.Linear

        self.inplanes = 64

        # Stem: 7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = get_hfta_op_for(nn.Conv2d, B=B)(3,
                                                     self.inplanes,
                                                     kernel_size=7,
                                                     stride=2,
                                                     padding=3,
                                                     bias=False)
        self.bn1 = norm_layer(self.inplanes,
                              track_running_stats=track_running_stats)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = get_hfta_op_for(nn.MaxPool2d, B=B)(kernel_size=3,
                                                          stride=2,
                                                          padding=1)
        # Four residual stages; _make_layer mutates self.inplanes, so the
        # call order matters.
        self.layer1 = self._make_layer(block, 64, layers[0], B=B)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, B=B)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, B=B)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, B=B)
        self.fc = get_hfta_op_for(nn.Linear, B)(512 * block.expansion,
                                                num_classes)
        for m in self.modules():
            if isinstance(m, self._conv_layer):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, self._norm_layer):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
示例#13
0
def testcase_Conv2d(
    B=3,
    N=32,
    Cin=4,
    Cout=16,
    kernel_size=3,
    HWin=28,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    padding_mode='zeros',
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Compare a fused Conv2d against B serial nn.Conv2d modules."""
  with torch.no_grad():
    inputs = [
        torch.rand(N, Cin, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused layout stacks the B models along dim 1: [N, B, Cin, H, W].
    inputs_fused = torch.stack(inputs, dim=1)
    conv_kwargs = {
        'stride': stride,
        'padding': padding,
        'dilation': dilation,
        'groups': groups,
        'bias': bias,
        'padding_mode': padding_mode,
        'device': device,
        'dtype': dtype,
    }
    convs = [
        nn.Conv2d(Cin, Cout, kernel_size, **conv_kwargs) for _ in range(B)
    ]
    fused = get_hfta_op_for(nn.Conv2d, B=B)(Cin, Cout, kernel_size,
                                            **conv_kwargs)
    # Copy each serial module's weights/biases into the fused module.
    for b, conv in enumerate(convs):
      fused.snatch_parameters(conv, b)
    expected = torch.stack([conv(x) for conv, x in zip(convs, inputs)], dim=1)
    actual = fused(inputs_fused)
    try:
      assert_allclose(
          actual.cpu().numpy(),
          expected.cpu().numpy(),
          rtol=1e-4,
          population_threshold=1e-2,
      )
    except AssertionError as e:
      dump_error_msg(e)
示例#14
0
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 norm_layer=None,
                 track_running_stats=True,
                 B=1):
        """Two-conv residual block, fused across B models via HFTA."""
        super(BasicBlock, self).__init__()
        norm_layer = (norm_layer if norm_layer is not None else
                      get_hfta_op_for(nn.BatchNorm2d, B))

        # conv1 (and the downsample shortcut, when given) carry the stride,
        # so both paths shrink the input identically when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride, B=B)
        self.bn1 = norm_layer(planes, track_running_stats=track_running_stats)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, B=B)
        self.bn2 = norm_layer(planes, track_running_stats=track_running_stats)
        self.downsample = downsample
        self.stride = stride
示例#15
0
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout, B=1):
        """One transformer encoder block: attention + position-wise FFN.

        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        :param B: number of fused models (HFTA fusion factor)
        """
        super().__init__()
        self.attention = MultiheadAttention(hidden, attn_heads, B=B)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout, B=B)
        # Residual/sublayer wrappers around attention and feed-forward.
        self.input_sublayer = SublayerConnection(
            size=hidden, dropout=dropout, B=B)
        self.output_sublayer = SublayerConnection(
            size=hidden, dropout=dropout, B=B)
        self.dropout = get_hfta_op_for(nn.Dropout, B)(p=dropout)
示例#16
0
def testcase(
        B=5,
        N=64,
        normalized_shape=(200, ),
        elementwise_affine=True,
        device=torch.device('cpu'),
        dtype=torch.float,
):
    """Compare a fused LayerNorm against B serial nn.LayerNorm modules."""
    with torch.no_grad():
        inputs = [
            torch.rand([N] + list(normalized_shape),
                       device=device,
                       dtype=dtype) for _ in range(B)
        ]
        # Fused layout stacks the B models along a new leading dimension.
        inputs_fused = torch.stack(inputs, dim=0)
        layernorms = [
            nn.LayerNorm(normalized_shape,
                         elementwise_affine=elementwise_affine,
                         device=device,
                         dtype=dtype) for _ in range(B)
        ]
        fused = get_hfta_op_for(nn.LayerNorm, B=B)(
            normalized_shape,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype)
        # Copy each serial module's parameters into the fused module.
        for b, ln in enumerate(layernorms):
            fused.snatch_parameters(ln, b)
        expected = torch.stack(
            [ln(x) for ln, x in zip(layernorms, inputs)], dim=0)
        actual = fused(inputs_fused)
        try:
            assert_allclose(
                actual.cpu().numpy(),
                expected.cpu().numpy(),
                rtol=1e-4,
            )
        except AssertionError as e:
            dump_error_msg(e)
示例#17
0
    def __init__(self,
                 vocab_size,
                 hidden=768,
                 n_layers=12,
                 attn_heads=12,
                 dropout=0.1,
                 B=1):
        """BERT encoder: embedding, n_layers transformer blocks, output head.

        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        :param B: number of fused models (HFTA fusion factor)
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.B = B

        # The paper uses 4 * hidden for the feed-forward hidden size.
        self.feed_forward_hidden = 4 * hidden

        # Sum of positional, segment and token embeddings.
        self.embedding = BERTEmbedding(vocab_size=vocab_size,
                                       embed_size=hidden,
                                       B=B)

        # Deep stack of transformer encoder blocks.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(hidden, attn_heads, 4 * hidden, dropout, B=B)
            for _ in range(n_layers)
        ])

        self.output = get_hfta_op_for(nn.Linear, B)(hidden, vocab_size)
示例#18
0
def convert_ops(B, *torch_op_classes):
    """Lazily yield the HFTA-fused counterpart of each torch op class."""
    return (get_hfta_op_for(cls, B=B) for cls in torch_op_classes)
示例#19
0
def testcase_2d(
        num_features=128,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        B=3,
        N=8,
        HWin=28,
        train_test_steps=10,
        training=True,
        device=torch.device('cpu'),
        dtype=torch.float,
):
    """Compare a fused BatchNorm2d against B serial nn.BatchNorm2d modules.

    Runs `train_test_steps` forward passes so the running statistics (when
    tracked) are exercised across multiple batches, not just one.
    """
    C = num_features
    with torch.no_grad():

        args = (num_features, )
        kwargs = {
            'eps': eps,
            'momentum': momentum,
            'affine': affine,
            'track_running_stats': track_running_stats,
            'device': device,
            'dtype': dtype,
        }
        bn_array = [nn.BatchNorm2d(*args, **kwargs) for _ in range(B)]
        bn_fused = get_hfta_op_for(nn.BatchNorm2d, B=B)(*args, **kwargs)
        if track_running_stats:
            # Randomize the running stats so snatch_parameters must copy
            # them too (the all-zeros/ones defaults could hide a bug).
            rand_int = random.randint(0, 1024)
            for bn in bn_array:
                nn.init.normal_(bn.running_mean)
                nn.init.normal_(bn.running_var)
                bn.num_batches_tracked.fill_(rand_int)
        # Init weights and biases.
        for b in range(B):
            bn_fused.snatch_parameters(bn_array[b], b)

        # Plain loops (not comprehensions) for mode-switching side effects.
        if training:
            for bn in bn_array:
                bn.train()
            bn_fused.train()
        else:
            for bn in bn_array:
                bn.eval()
            bn_fused.eval()

        # Check whether fused outputs stay equal over several steps.
        for _ in range(train_test_steps):
            x_array = [
                torch.rand(N, C, HWin, HWin, device=device, dtype=dtype)
                for _ in range(B)
            ]
            x_fused = torch.cat([x.unsqueeze(1) for x in x_array], dim=1)

            y_array = [bn_array[b](x_array[b]) for b in range(B)]
            y_fused_actual = bn_fused(x_fused)
            y_fused_expect = torch.cat([y.unsqueeze(1) for y in y_array],
                                       dim=1)
            try:
                assert_allclose(
                    y_fused_actual.cpu().numpy(),
                    y_fused_expect.cpu().numpy(),
                    rtol=1e-4,
                )
            except AssertionError as e:
                dump_error_msg(e)
示例#20
0
    def _make_layer(self,
                    block,
                    serial_block,
                    planes,
                    blocks,
                    run_in_serial,
                    stride=1,
                    B=1):
        """Stack `blocks` residual blocks, per-block fused or serial.

        run_in_serial[i] selects, for the i-th block of this stage, the
        unfused `serial_block` (also tracked in self.unfused_layers)
        instead of the fused `block`.
        """
        downsample = None
        norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
        assert block.expansion == serial_block.expansion

        # Projection shortcut when spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.B > 0 and run_in_serial[0]:
                # Serial path: plain torch ops (conv1x1 with B=0, plain BN).
                downsample = nn.Sequential(
                    conv1x1(self.inplanes,
                            planes * block.expansion,
                            stride,
                            B=0),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            else:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes,
                            planes * block.expansion,
                            stride,
                            B=B),
                    norm_layer(planes * block.expansion,
                               track_running_stats=self.track_running_stats),
                )

        layers = []
        # First block of the stage carries the stride and the downsample.
        if self.B > 0 and run_in_serial[0]:
            current_block = serial_block(
                self.inplanes,
                planes,
                stride,
                downsample,
                nn.BatchNorm2d,
                B=B,
                track_running_stats=self.track_running_stats)
            self.unfused_layers.append(current_block)
        else:
            current_block = block(self.inplanes,
                                  planes,
                                  stride,
                                  downsample,
                                  norm_layer,
                                  B=B,
                                  track_running_stats=self.track_running_stats)
        layers.append(current_block)

        # Remaining blocks keep the expanded channel count and stride 1.
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            if self.B > 0 and run_in_serial[i]:
                current_block = serial_block(
                    self.inplanes,
                    planes,
                    norm_layer=nn.BatchNorm2d,
                    B=B,
                    track_running_stats=self.track_running_stats)
                self.unfused_layers.append(current_block)
            else:
                current_block = block(
                    self.inplanes,
                    planes,
                    norm_layer=norm_layer,
                    B=B,
                    track_running_stats=self.track_running_stats)
            layers.append(current_block)

        return nn.Sequential(*layers)
示例#21
0
    def __init__(self,
                 config,
                 block,
                 serial_block,
                 num_classes=10,
                 zero_init_residual=False,
                 track_running_stats=True,
                 B=1):
        """ResNet in which selected blocks run unfused (serial) under HFTA.

        :param config: dict with "layers" (blocks per stage) and
            "run_in_serial" (per-stage flags selecting serial_block)
        :param block: fused residual block class
        :param serial_block: unfused residual block class
        :param num_classes: output size of the final fc layer
        :param zero_init_residual: zero-init each block's last BN weight
        :param track_running_stats: forwarded to every BatchNorm layer
        :param B: number of fused models (B == 0 falls back to plain ops)
        """
        super(PartiallyFusedResNet, self).__init__()
        layers = config["layers"]
        run_in_serial = config["run_in_serial"]
        self.B = B
        self.track_running_stats = track_running_stats
        norm_layer = get_hfta_op_for(nn.BatchNorm2d, B)
        # Concrete classes for the isinstance checks in the init loop below.
        self._conv_layer = get_hfta_op_for(nn.Conv2d,
                                           B).func if B > 0 else nn.Conv2d
        self._norm_layer = get_hfta_op_for(nn.BatchNorm2d,
                                           B).func if B > 0 else nn.BatchNorm2d
        self._linear_layer = get_hfta_op_for(nn.Linear,
                                             B).func if B > 0 else nn.Linear

        self.inplanes = 64
        # run_in_serial[4] covers the stem conv block ([1]) and fc ([0]).
        # NOTE(review): self.unfused_layers is appended to below but is not
        # initialized in this __init__ — presumably created by a base class
        # or as a class attribute; confirm it exists before these appends.
        if run_in_serial[4][1]:
            self.convBlock = SerialConvBlock(B, 3, self.inplanes)
            self.unfused_layers.append(self.convBlock)
        else:
            self.convBlock = nn.Sequential(
                get_hfta_op_for(nn.Conv2d, B=B)(3,
                                                self.inplanes,
                                                kernel_size=7,
                                                stride=2,
                                                padding=3,
                                                bias=False),
                norm_layer(self.inplanes,
                           track_running_stats=track_running_stats),
                nn.ReLU(inplace=True),
                get_hfta_op_for(nn.MaxPool2d, B=B)(kernel_size=3,
                                                   stride=2,
                                                   padding=1))

        # Four residual stages; _make_layer mutates self.inplanes, so the
        # call order matters.
        self.layer1 = self._make_layer(block,
                                       serial_block,
                                       64,
                                       layers[0],
                                       run_in_serial[0],
                                       B=B)
        self.layer2 = self._make_layer(block,
                                       serial_block,
                                       128,
                                       layers[1],
                                       run_in_serial[1],
                                       stride=2,
                                       B=B)
        self.layer3 = self._make_layer(block,
                                       serial_block,
                                       256,
                                       layers[2],
                                       run_in_serial[2],
                                       stride=2,
                                       B=B)
        self.layer4 = self._make_layer(block,
                                       serial_block,
                                       512,
                                       layers[3],
                                       run_in_serial[3],
                                       stride=2,
                                       B=B)
        if run_in_serial[4][0]:
            self.fc = SerialLinear(512 * block.expansion, num_classes, B=B)
            self.unfused_layers.append(self.fc)
        else:
            self.fc = get_hfta_op_for(nn.Linear, B)(512 * block.expansion,
                                                    num_classes)

        for m in self.modules():
            if isinstance(m, self._conv_layer):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, self._norm_layer):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
示例#22
0

# custom weights initialization called on netG and netD
# NOTE(wangsh46): This is okay for HFTA, because torch.nn.init.normal_ (or
# torch.Tensor.normal_ to be specific) is element-wise; so is
# torch.nn.init.zeros_
def weights_init(m):
    """DCGAN-style init, applied via net.apply(weights_init).

    Conv layers get N(0, 0.02) weights; BatchNorm layers get N(1, 0.02)
    weights and zero biases.  All element-wise inits, so they are safe
    for HFTA-fused modules as well.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


# Module-level fused-op aliases so the model code below can use them like
# the regular torch.nn classes.
# NOTE(review): `B` is not defined in this chunk — presumably a module
# global (e.g. derived from parsed CLI args); confirm before reuse.
ConvTranspose2d = get_hfta_op_for(nn.ConvTranspose2d, B=B)
BatchNorm2d = get_hfta_op_for(nn.BatchNorm2d, B=B)
ReLU = get_hfta_op_for(nn.ReLU, B=B)
Tanh = get_hfta_op_for(nn.Tanh, B=B)


class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            BatchNorm2d(ngf * 8, track_running_stats=(args.device != 'xla')),
            ReLU(True),
            # state size. (ngf*8) x 4 x 4
示例#23
0
def testcase_ConvTranspose2d(
    B=3,
    N=32,
    Cin=4,
    Cout=16,
    kernel_size=3,
    HWin=28,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    bias=True,
    dilation=1,
    padding_mode='zeros',
    output_size=None,
    device=torch.device('cpu'),
    dtype=torch.float,
):
  """Compare a fused ConvTranspose2d against B serial nn.ConvTranspose2d.

  When output_padding is nonzero, stride and dilation are bumped so the
  configuration stays valid (output_padding must be smaller than stride
  or dilation).  When output_size is given, it is forwarded to forward()
  to pin the otherwise ambiguous output shape, and the shapes are also
  compared.
  """
  with torch.no_grad():
    x_array = [
        torch.rand(N, Cin, HWin, HWin, device=device, dtype=dtype)
        for _ in range(B)
    ]
    # Fused layout stacks the B models along dim 1: [N, B, Cin, H, W].
    x_fused = torch.cat([x.unsqueeze(1) for x in x_array], dim=1)
    args = (Cin, Cout, kernel_size)

    # Handle output_padding
    if output_padding != 0:
      stride = output_padding + 1
      dilation = output_padding + 1

    # Handle output_size argument for the forward function
    if output_size:
      # The hardcoded input 57 and 58 are the possible size given stride == 2
      stride = 2
      # Serial modules take a per-model size; drop the fusion dimension
      # when a fused-shaped size was passed.
      # NOTE(review): assumes len(output_size) is either 2 or 3 here —
      # confirm against the callers.
      output_size_arg = output_size if len(output_size) == 2 else (
          output_size[0:1] + output_size[2:])
    else:
      output_size_arg = None

    kwargs = {
        'stride': stride,
        'padding': padding,
        'output_padding': output_padding,
        'groups': groups,
        'bias': bias,
        'dilation': dilation,
        'padding_mode': padding_mode,
        'device': device,
        'dtype': dtype,
    }
    conv_array = [nn.ConvTranspose2d(*args, **kwargs) for _ in range(B)]
    conv_fused = get_hfta_op_for(nn.ConvTranspose2d, B=B)(*args, **kwargs)
    # Init weights and biases.
    for b in range(B):
      conv_fused.snatch_parameters(conv_array[b], b)

    y_array = [
        conv_array[b](x_array[b], output_size=output_size_arg) for b in range(B)
    ]
    y_fused_actual = conv_fused(x_fused, output_size=output_size)
    y_fused_expect = torch.cat([y.unsqueeze(1) for y in y_array], dim=1)
    try:
      assert_allclose(
          y_fused_actual.cpu().numpy(),
          y_fused_expect.cpu().numpy(),
          rtol=1e-4,
          population_threshold=1e-2,
      )
      if output_size:
        assert (
            y_fused_actual.shape == y_fused_expect.shape
        ), "The actual output size ({}) is different from the expected output size ({}).".format(
            y_fused_actual.shape, y_fused_expect.shape)
    except AssertionError as e:
      dump_error_msg(e)