Esempio n. 1
0
    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        if not self.use_softpool:
            branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        else:
            if self.pad:
                branch_pool = F.pad(x, (1, 1, 1, 1), 'constant', 0)
            else:
                branch_pool = F.pad(x, (0, 0, 0, 0), 'constant', 0)
            branch_pool = soft_pool2d(branch_pool, kernel_size=3, stride=1)

        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs
Esempio n. 2
0
    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        if not self.use_softpool:
            branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        else:
            if self.pad:
                branch_pool = F.pad(x, (1, 1, 1, 1), 'constant', 0)
            else:
                branch_pool = F.pad(x, (0, 0, 0, 0), 'constant', 0)
            branch_pool = soft_pool2d(branch_pool, kernel_size=3, stride=1)

        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs
Esempio n. 3
0
    def _forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        if not self.use_softpool:
            branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        else:
            branch_pool = soft_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs
Esempio n. 4
0
 def forward(self, x):
     """Pool ``x``, run it through conv0, conv1 and fc, and return the result.

     Pipeline: a 5x5, stride-3 pooling (``soft_pool2d`` when
     ``self.use_softpool`` is truthy, ``F.avg_pool2d`` otherwise), then
     ``self.conv0``, ``self.conv1``, adaptive average pooling to 1x1,
     flattening from dim 1 onward, and ``self.fc``.

     The shape notes below are carried over from the original annotations
     and assume an N x 768 x 17 x 17 input -- TODO confirm against caller.
     """
     # N x 768 x 17 x 17
     if self.use_softpool:
         pooled = soft_pool2d(x, kernel_size=5, stride=3)
     else:
         pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
     # N x 768 x 5 x 5
     reduced = self.conv0(pooled)
     # N x 128 x 5 x 5
     expanded = self.conv1(reduced)
     # N x 768 x 1 x 1
     # Adaptive average pooling
     squeezed = F.adaptive_avg_pool2d(expanded, (1, 1))
     # N x 768 x 1 x 1
     features = torch.flatten(squeezed, 1)
     # N x 768
     # N x 1000 after the final layer
     return self.fc(features)
Esempio n. 5
0
except Exception as e:
    print('\033[91m' + '> FAILED' + '\033[0m')
    print(e)

# Status lines are wrapped in ANSI SGR escape codes: 93/92/91 select the
# colour of the announcement / PASSED / FAILED text and 0 resets it.

# Compare the 1D results held in pl_1d_cpu and pl_1d_gpu (presumably produced
# by earlier CPU and GPU runs -- verify against the preceding code).
print('\033[93m' + '> Checking 1D CPU-GPU output similarities ...' + '\033[0m')
try:
    check_close_enough(pl_1d_cpu.data, pl_1d_gpu.data)
    print('\033[92m' + '> PASSED' + '\033[0m' + '\n')
except Exception as e:
    # Any exception is reported instead of aborting the script. This includes
    # a NameError if either result variable was never assigned.
    print('\033[91m' + '> FAILED' + '\033[0m')
    print(e, '\n')

################## 2D FORWARD ##################
# Run soft_pool2d on x_2d as-is with its default arguments; the result is
# kept in the module-level name pl_2d_cpu.
print('\033[93m' + '> Checking 2D CPU ...' + '\033[0m')
try:
    pl_2d_cpu = soft_pool2d(x_2d)
    print('\033[92m' + '> PASSED' + '\033[0m')
except Exception as e:
    # On failure pl_2d_cpu stays unassigned; the error is only printed.
    print('\033[91m' + '> FAILED' + '\033[0m')
    print(e)

# Same call on a CUDA copy of x_2d; the result is kept in pl_2d_gpu.
print('\033[93m' + '> Checking 2D GPU ...' + '\033[0m')
try:
    pl_2d_gpu = soft_pool2d(x_2d.cuda())
    print('\033[92m' + '> PASSED' + '\033[0m')
except Exception as e:
    # Also covers a failing .cuda() call (e.g. no usable CUDA device);
    # pl_2d_gpu then stays unassigned.
    print('\033[91m' + '> FAILED' + '\033[0m')
    print(e)

print('\033[93m' + '> Checking 2D CPU-GPU output similarities ...' + '\033[0m')
try: