class DenseNet(nn.Module):
    def __init__(self):
        super(DenseNet, self).__init__()
        input_channels = 1
        conv_channels = 32
        down_structure = [2, 2, 2]
        output_channels = 4  #4#2
        act_fn = config["act_fn"]  # module-level experiment config (defined outside this snippet)
        norm_fn = config["norm_fn"]
        self.features = nn.Sequential()
        self.features.add_module(
            "init_conv",
            nn.Conv3d(input_channels,
                      conv_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True))
        self.features.add_module("init_norm", norm_fn(conv_channels))
        self.features.add_module("init_act", act_fn())
        self.dropblock = LinearScheduler(DropBlock3D(drop_prob=0.,
                                                     block_size=5),
                                         start_value=0.,
                                         stop_value=0.5,
                                         nr_steps=5e3)

        channels = conv_channels
        self.features.add_module('drop_block',
                                 DropBlock3D(drop_prob=0.1, block_size=5))

        for i, num_layers in enumerate(down_structure):
            for j in range(num_layers):
                conv_layer = ConvBlock(channels)
                self.features.add_module(
                    "block{}_layer{}".format(i + 1, j + 1), conv_layer)
                channels = conv_layer.out_channels

            # downsample
            trans_layer = TransmitBlock(
                channels, is_last_layer=(i == len(down_structure) - 1))
            self.features.add_module("transition{}".format(i + 1), trans_layer)
            channels = trans_layer.out_channels

        self.classifier = nn.Linear(channels, output_channels)

    def forward(self, x, **return_opts):
        self.dropblock.step()
        batch_size, _, z, h, w = x.size()

        features = self.dropblock(self.features(x))
        # print("features", features.size())
        pooled = F.adaptive_avg_pool3d(features, 1).view(batch_size, -1)
        # print("pooled", pooled.size())
        scores = self.classifier(pooled)
        # print("scored", scores.size())

        if len(return_opts) == 0:
            return scores
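
# --- Usage sketch (not from the original source). Assumes the ConvBlock / TransmitBlock
# helpers, the module-level `config`, and dropblock's DropBlock3D / LinearScheduler used
# above are available; the input shape is purely illustrative (single-channel 3D patch).
import torch

model = DenseNet()
volume = torch.randn(2, 1, 32, 32, 32)   # (batch, channel, D, H, W)
logits = model(volume)                   # -> (2, 4) class scores
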
class ResNetCustom(ResNet):
    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 drop_prob=0.,
                 block_size=5):
        # intentionally skip ResNet.__init__ (only nn.Module.__init__ runs); the layers are
        # rebuilt below and only _make_layer is reused from the parent class
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.dropblock = LinearScheduler(DropBlock2D(drop_prob=drop_prob,
                                                     block_size=block_size),
                                         start_value=0.,
                                         stop_value=drop_prob,
                                         nr_steps=5e3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        self.dropblock.step()  # increment number of iterations

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.dropblock(self.layer1(x))
        x = self.dropblock(self.layer2(x))
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)

        return x
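
# --- Usage sketch (not from the original source). Assumes ResNet/BasicBlock come from an
# older torchvision.models.resnet whose _make_layer relies only on self.inplanes (the
# super() call above deliberately skips ResNet.__init__) and that dropblock provides
# DropBlock2D / LinearScheduler. The [2, 2, 2, 2] configuration mirrors ResNet-18.
import torch
from torchvision.models.resnet import BasicBlock

model = ResNetCustom(BasicBlock, [2, 2, 2, 2], num_classes=10, drop_prob=0.3, block_size=5)
logits = model(torch.randn(4, 3, 224, 224))   # DropBlock is applied after layer1 and layer2
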
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, drop_prob=0., block_size=5):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropblock = LinearScheduler(
            DropBlock2D(drop_prob=drop_prob, block_size=block_size, att=True),
            start_value=0.,
            stop_value=drop_prob,
            nr_steps=5e4
        )
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        self.dropblock.step()
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropblock(self.layer1(out))
        out1 = out
        out = self.dropblock(self.layer2(out))
        out2 = out
        out = self.layer3(out)
        out3 = out
        out = self.layer4(out)
        out4 = out
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return [out, out1.cpu().detach().numpy(),
                out2.cpu().detach().numpy(),
                out3.cpu().detach().numpy(),
                out4.cpu().detach().numpy()]
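
# --- Usage sketch (not from the original source). Assumes conv3x3, a BasicBlock-style
# `block` (constructor (in_planes, planes, stride), class attribute `expansion`), and the
# customized DropBlock2D that accepts an `att` flag are defined above. The 32x32 input
# matches the fixed 4x4 average pooling (CIFAR-style).
import torch

model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, drop_prob=0.1, block_size=5)
logits, f1, f2, f3, f4 = model(torch.randn(8, 3, 32, 32))   # logits tensor + numpy feature maps
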
# Example 4
class ResNet(nn.Module):
    def __init__(self,
                 block,
                 num_blocks,
                 num_classes=10,
                 drop_prob=0.,
                 block_size=5,
                 nr_steps=5000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropblock = LinearScheduler(DropBlock2D(drop_prob=drop_prob,
                                                     block_size=block_size,
                                                     att=True),
                                         start_value=0.,
                                         stop_value=drop_prob,
                                         nr_steps=nr_steps)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        self.dropblock.step()
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropblock(self.layer1(out))
        out = self.dropblock(self.layer2(out))
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def get_mask(self, x, prob):
        # self.dropblock.step()
        self.dropblock.dropblock.drop_prob = prob
        out = F.relu(self.bn1(self.conv1(x)))
        out, blk1 = self.dropblock.get_mask(self.layer1(out))
        out, blk2 = self.dropblock.get_mask(self.layer2(out))
        # out = self.layer3(out)
        # out = self.layer4(out)
        # out = F.avg_pool2d(out, 4)
        # out = out.view(out.size(0), -1)
        # out = self.linear(out)
        return out, blk1, blk2
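
# --- Usage sketch of the mask-inspection path (not from the original source); same
# assumptions as the previous example (conv3x3, a BasicBlock-style block, and a
# DropBlock2D / LinearScheduler pair that exposes get_mask).
import torch

model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, drop_prob=0.1, nr_steps=1000)
images = torch.randn(8, 3, 32, 32)
logits = model(images)                                   # regular classification forward
feats, blk1, blk2 = model.get_mask(images, prob=0.1)     # DropBlock masks from layer1/layer2
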
# Example 5
class FluxNet(nn.Module):
    default_filters = (8, 16, 24, 32, 64)

    def __init__(self,
                 in_channel=1,
                 filters=default_filters,
                 aspp_dilation_ratio=1,
                 symmetric=True,
                 use_flux_head=True,
                 use_skeleton_head=False):
        super().__init__()

        # encoding path
        self.symmetric = symmetric
        if self.symmetric:
            self.layer1_E = nn.Sequential(
                conv3d_bn_relu(in_planes=in_channel,
                               out_planes=filters[0],
                               kernel_size=(5, 5, 5),
                               stride=1,
                               padding=(2, 2, 2)),
                conv3d_bn_relu(in_planes=filters[0],
                               out_planes=filters[0],
                               kernel_size=(3, 3, 3),
                               stride=1,
                               padding=(1, 1, 1)),
                residual_block_3d(filters[0], filters[0], projection=False))
        else:
            self.layer1_E = nn.Sequential(
                conv3d_bn_relu(in_planes=in_channel,
                               out_planes=filters[0],
                               kernel_size=(1, 5, 5),
                               stride=1,
                               padding=(0, 2, 2)),
                conv3d_bn_relu(in_planes=filters[0],
                               out_planes=filters[0],
                               kernel_size=(1, 3, 3),
                               stride=1,
                               padding=(0, 1, 1)),
                residual_block_2d(filters[0], filters[0], projection=False))

        self.layer2_E = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[0],
                           out_planes=filters[1],
                           kernel_size=(3, 3, 3),
                           stride=1,
                           padding=(1, 1, 1)),
            residual_block_3d(filters[1], filters[1], projection=False))
        self.layer3_E = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[1],
                           out_planes=filters[2],
                           kernel_size=(3, 3, 3),
                           stride=1,
                           padding=(1, 1, 1)),
            residual_block_3d(filters[2], filters[2], projection=False))
        self.layer4_E = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[2],
                           out_planes=filters[3],
                           kernel_size=(3, 3, 3),
                           stride=1,
                           padding=(1, 1, 1)),
            residual_block_3d(filters[3], filters[3], projection=False))

        # center ASPP block
        self.center = ASPP(
            filters[3], filters[4],
            [[2, int(2 * aspp_dilation_ratio),
              int(2 * aspp_dilation_ratio)],
             [3, int(3 * aspp_dilation_ratio),
              int(3 * aspp_dilation_ratio)],
             [5, int(5 * aspp_dilation_ratio),
              int(5 * aspp_dilation_ratio)]])

        # decoding path
        self.layer1_D_flux = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[1],
                           out_planes=filters[0],
                           kernel_size=(1, 1, 1),
                           stride=1,
                           padding=(0, 0, 0)),
            residual_block_3d(filters[0], filters[0], projection=False),
            conv3d_bn_non(in_planes=filters[0],
                          out_planes=3,
                          kernel_size=(3, 3, 3),
                          stride=1,
                          padding=(1, 1, 1)))

        self.layer1_D_skeleton = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[1],
                           out_planes=filters[0],
                           kernel_size=(1, 1, 1),
                           stride=1,
                           padding=(0, 0, 0)),
            residual_block_3d(filters[0], filters[0], projection=False),
            conv3d_bn_non(in_planes=filters[0],
                          out_planes=1,
                          kernel_size=(3, 3, 3),
                          stride=1,
                          padding=(1, 1, 1)))

        self.layer2_D = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[2],
                           out_planes=filters[1],
                           kernel_size=(1, 1, 1),
                           stride=1,
                           padding=(0, 0, 0)),
            residual_block_3d(filters[1], filters[1], projection=False))
        self.layer3_D = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[3],
                           out_planes=filters[2],
                           kernel_size=(1, 1, 1),
                           stride=1,
                           padding=(0, 0, 0)),
            residual_block_3d(filters[2], filters[2], projection=False))
        self.layer4_D = nn.Sequential(
            conv3d_bn_relu(in_planes=filters[4],
                           out_planes=filters[3],
                           kernel_size=(1, 1, 1),
                           stride=1,
                           padding=(0, 0, 0)),
            residual_block_3d(filters[3], filters[3], projection=False))

        # downsample pooling
        self.down = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.down_aniso = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        # upsampling
        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='nearest')
        self.up_aniso = nn.Upsample(scale_factor=(1, 2, 2), mode='nearest')

        self.dropblock: LinearScheduler = None

        self.use_flux_head = use_flux_head
        self.use_skeleton_head = use_skeleton_head

        #initialization
        ortho_init(self)

    def init_dropblock(self, start_value, stop_value, nr_steps, block_size):
        self.dropblock = LinearScheduler(DropBlock3D(drop_prob=stop_value,
                                                     block_size=block_size),
                                         start_value=start_value,
                                         stop_value=stop_value,
                                         nr_steps=nr_steps)

    def forward(self,
                x,
                call_dropblock_step=False,
                get_penultimate_layer=False):

        if self.dropblock and call_dropblock_step:
            self.dropblock.step()

        # encoding path
        z1 = self.layer1_E(x)
        z1 = self.dropblock(z1) if self.dropblock else z1
        x = self.down(z1) if self.symmetric else self.down_aniso(z1)

        z2 = self.layer2_E(x)
        z2 = self.dropblock(z2) if self.dropblock else z2
        x = self.down(z2)

        z3 = self.layer3_E(x)
        x = self.down(z3)

        z4 = self.layer4_E(x)

        #center ASPP block
        x = self.center(z4)

        # decoding path
        x = self.layer4_D(x)

        x = self.up(x)
        x = self.layer3_D(x)
        x = x + z3

        x = self.up(x)
        x = self.layer2_D(x)
        x = x + z2

        x = self.up(x) if self.symmetric else self.up_aniso(x)

        output = dict()
        if self.use_flux_head:
            flux = self.layer1_D_flux(x)
            # TODO share the penultimate layer code for skeleton head and flux head
            # for i, layer in enumerate(self.layer1_D_flux):
            #     x = layer(x)
            #     if get_penultimate_layer and i == 0:
            #         output['penultimate_layer'] = x
            flux = torch.tanh(flux)
            output['flux'] = flux
        if self.use_skeleton_head:
            skeleton = self.layer1_D_skeleton(x)
            skeleton = torch.sigmoid(skeleton)
            output['skeleton'] = skeleton

        if not output:
            raise ValueError("Neither flux or skeleton head was specified.")

        return output
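
# --- Usage sketch (not from the original source). Assumes conv3d_bn_relu, conv3d_bn_non,
# residual_block_2d/3d, ASPP, ortho_init and dropblock's DropBlock3D / LinearScheduler are
# available. The patch size is illustrative; in the symmetric case each spatial dimension
# must survive the three 2x poolings, i.e. be divisible by 8.
import torch

model = FluxNet(in_channel=1, use_flux_head=True, use_skeleton_head=True)
model.init_dropblock(start_value=0., stop_value=0.3, nr_steps=1000, block_size=5)
out = model(torch.randn(1, 1, 32, 64, 64), call_dropblock_step=True)
flux, skeleton = out['flux'], out['skeleton']   # (1, 3, 32, 64, 64) and (1, 1, 32, 64, 64)
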
class ResNetDropblock(nn.Module):
    # ResNet (18/34/50) backbone with DropBlock optionally applied after selected stages
    def __init__(self, output_class=7, model_path=None, resnet_type=34, drop_block=False, drop_prob=0.5, drop_pos=None, layer_depth=4, drop_block_size=7):
        super(ResNetDropblock, self).__init__()

        assert resnet_type in [18, 34, 50]
        assert layer_depth in [1, 2, 3, 4]
        if resnet_type == 18:
            self.base = resnet18(pretrained=False, num_classes=1000)
            # state_dict = torch.load('resnet18.pth')
            if layer_depth == 4:
                last_fc_in_channel = 512 * 1
            elif layer_depth == 3:
                last_fc_in_channel = 256 * 1
            elif layer_depth == 2:
                last_fc_in_channel = 128 * 1
            else:  # elif layer_depth == 1:
                last_fc_in_channel = 64 * 1

        elif resnet_type == 34:
            self.base = resnet34(pretrained=False, num_classes=1000)
            # state_dict = torch.load('resnet34.pth')
            if layer_depth == 4:
                last_fc_in_channel = 512 * 1
            elif layer_depth == 3:
                last_fc_in_channel = 256 * 1
            elif layer_depth == 2:
                last_fc_in_channel = 128 * 1
            else:  # elif layer_depth == 1:
                last_fc_in_channel = 64 * 1
        else:  # elif resnet_type == 50:
            self.base = resnet50(pretrained=False, num_classes=1000)
            # state_dict = torch.load('resnet50.pth')
            if layer_depth == 4:
                last_fc_in_channel = 512 * 4
            elif layer_depth == 3:
                last_fc_in_channel = 256 * 4
            elif layer_depth == 2:
                last_fc_in_channel = 128 * 4
            else:  # elif layer_depth == 1:
                last_fc_in_channel = 64 * 4

        # self.base.load_state_dict(state_dict)

        # def weight_init(m):
        #     if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        #         nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        #         if m.bias is not None:
        #             nn.init.zeros_(m.bias)
        #
        # self.base.layer3.apply(weight_init)
        # self.base.layer4.apply(weight_init)

        self.base.fc = nn.Linear(in_features=last_fc_in_channel, out_features=output_class, bias=True)

        if drop_block:
            self.dropblock = LinearScheduler(
                DropBlock2D(drop_prob=drop_prob, block_size=drop_block_size),
                start_value=0.,
                stop_value=drop_prob,
                nr_steps=300
            )
        else:
            self.dropblock = nn.Sequential()

        self.drop_pos = drop_pos
        self.layer_depth = layer_depth

        if model_path is not None:
            self.model_path = model_path
            assert '.pth' in self.model_path or '.pkl' in self.model_path
            self.init_weight()

    def init_weight(self):
        print('Loading weights into state dict...')
        self.load_state_dict(torch.load(self.model_path, map_location=lambda storage, loc: storage))
        print('Finished!')

    def forward(self, x):
        if not isinstance(self.dropblock, nn.Sequential):
            self.dropblock.step()

        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        x = self.base.maxpool(x)

        if self.drop_pos == 1 or self.drop_pos is None:
            x = self.dropblock(self.base.layer1(x))
        else:
            x = self.base.layer1(x)

        if self.layer_depth > 1:
            if self.drop_pos == 2 or self.drop_pos is None:
                x = self.dropblock(self.base.layer2(x))
            else:
                x = self.base.layer2(x)

            if self.layer_depth > 2:
                if self.drop_pos == 3 or self.drop_pos is None:
                    x = self.dropblock(self.base.layer3(x))
                else:
                    x = self.base.layer3(x)

                if self.layer_depth > 3:
                    if self.drop_pos == 4 or self.drop_pos is None:
                        x = self.dropblock(self.base.layer4(x))
                    else:
                        x = self.base.layer4(x)

        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)
        out = self.base.fc(x)

        return out
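
# --- Usage sketch (not from the original source). Assumes resnet18/resnet34/resnet50 are
# the torchvision constructors imported at module level and that dropblock provides
# DropBlock2D / LinearScheduler; all hyperparameters below are illustrative.
import torch

model = ResNetDropblock(output_class=7, resnet_type=34, drop_block=True,
                        drop_prob=0.5, drop_pos=None, layer_depth=4)
logits = model(torch.randn(2, 3, 224, 224))   # -> (2, 7); DropBlock after every stage
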
# Example 7
class ResNet(nn.Module):
    def __init__(self, cfg, block, layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)  #150*150
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.cfg = cfg
        #---------- new structure called DropBlock, 11th April -------------------------
        if self.cfg.MODEL.BACKBONE.DROP_BLOCK:
            drop_prob = 0.5
            block_size = 3
            self.dropblock = LinearScheduler(DropBlock2D(
                drop_prob=drop_prob, block_size=block_size),
                                             start_value=0.,
                                             stop_value=drop_prob,
                                             nr_steps=5)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  #75*75
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  #38*38
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  #19*19
        #self.layer4 = self._make_layer(block, 512,layers[3] , stride=1)         #10*10
        self.ex_layer0 = self._make_layer(block, 512, 2, stride=2)  #10*10
        #self.maxpoo2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 2. extra_layers (ReLU will be used in the forward function) 10th April, Xiaoyu Zhu
        #         if cfg.MODEL.BACKBONE.DEPTH>34:
        #             self.ex_layer1 = nn.Sequential(BasicBlock_modified(2048,512))
        #             #self.ex_layer1 = self._make_extra_layers(2048,512,3,1) #5*5
        #         else:
        #             self.ex_layer1 = nn.Sequential(BasicBlock_modified(512,512))
        self.ex_layer1 = self._make_layer(block, 512, 1, stride=2)  #5*5
        #self.maxpoo3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.ex_layer2 = self._make_layer(block, 256, 1, stride=2)  #3*3
        #self.maxpoo4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.ex_layer3 = self._make_layer(block, 128, 1, stride=2)

        #         if cfg.MODEL.BACKBONE.DEPTH>34:
        #             self.ex_layer3 = self._make_extra_layers(256*4,128,[2,3],0)
        #         else:
        #             #self.ex_layer3 = self._make_extra_layers(256,128,[2,3],0)
        #             self.ex_layer3 = self._make_layer(block, 128, 1, stride=2)

        #BasicBlock_modified(inplanes=256, planes=128, kernel=[2,3],stride=2, padding=0)

        # Kaiming-normal weight init applied after the default initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    # construct layer/stage conv2_x,conv3_x,conv4_x,conv5_x
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        # when to need downsample
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        # inplanes expand for next block
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_extra_layers(self, input_channels, output_channels, k,
                           p):  #10thApril,Xiaoyu Zhu
        layers = []
        layers.append(
            torch.nn.Conv2d(in_channels=input_channels,
                            out_channels=output_channels,
                            kernel_size=k,
                            stride=1,
                            padding=1))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.BatchNorm2d(output_channels))
        #layers.append(nn.Dropout(0.5))
        layers.append(
            torch.nn.Conv2d(in_channels=output_channels,
                            out_channels=output_channels,
                            kernel_size=k,
                            stride=2,
                            padding=p))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.BatchNorm2d(output_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.cfg.MODEL.BACKBONE.DROP_BLOCK:
            self.dropblock.step()  # increment number of iterations
        out_features = []
        #print('Input',x.shape) # Original 300*300; for rectangular input size: 320*240
        x = self.conv1(x)
        #print('self.conv1',x.shape) #150*150; 160*120
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        #print('self.maxpool',x.shape) #75*75; 80*60
        if self.cfg.MODEL.BACKBONE.DROP_BLOCK:
            x = self.dropblock(self.layer1(x))  # added 11th April
        else:
            x = self.layer1(x)
        #print('layer1',x.shape) #80*60
        #out_features.append(x) #Add new feature map 60*80
        if self.cfg.MODEL.BACKBONE.DROP_BLOCK:
            x = self.dropblock(self.layer2(x))
        else:
            x = self.layer2(x)
        #print('layer2',x.shape)  #38*38 output[0]; 30*40
        out_features.append(x)
        x = self.layer3(x)
        #print('layer3',x.shape) #19*19 output[1]; 15*20
        out_features.append(x)
        #x = self.layer4(x)
        x = self.ex_layer0(x)
        #x = self.maxpoo2(x)
        #print('layer4',x.shape) #10*10 output[2]; 8*10
        out_features.append(x)
        #For other output: 10thApril,Xiaoyu Zhu
        x = self.ex_layer1(x)
        #x = self.maxpoo3(x)
        #print('ex_layer1',x.shape) #5*5 output[3] ;4*5
        out_features.append(x)
        x = self.ex_layer2(x)
        #x = self.maxpoo4(x)
        #print('ex_layer2',x.shape) #3*2 output[4] ;2*3
        out_features.append(x)
        x = self.ex_layer3(x)
        #print('ex_layer3',x.shape) #1*1 output[5]
        out_features.append(x)
        #-----------------------------------Old Version-------------------
        #         #For other outputs:
        #         for i, f in enumerate(self.extra_layer): #i means the index and f means the function of the layer
        #             if i== len(self.extra_layer):
        #                 x= f(x)
        #             else:
        #                 x= torch.nn.functional.relu(f(x),inplace=True)
        #             if i % 2 == 1:
        #                 out_features.append(x)
        #-----------------------------------------------------------------
        return tuple(out_features)
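
# --- Usage sketch (not from the original source). Assumes a BasicBlock-style `block` is
# defined above; the detectron-style cfg is faked here with SimpleNamespace so that only
# the cfg.MODEL.BACKBONE.DROP_BLOCK flag this class reads actually exists.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(MODEL=SimpleNamespace(BACKBONE=SimpleNamespace(DROP_BLOCK=True)))
backbone = ResNet(cfg, BasicBlock, [2, 2, 2])            # three stages plus SSD-style extras
feature_maps = backbone(torch.randn(1, 3, 300, 300))     # tuple of six feature maps
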
class CNN(nn.Module):
    def __init__(self, grayscale=False):
        super(CNN, self).__init__()

        # hyperparameters
        # Activation = Mish()
        Activation = nn.ReLU()
        drop_prob = 0.2
        block_size = 5

        kernel_sizes = np.array([7, 3, 3, 3, 3])
        conv_strides = np.array([3, 1, 1, 1, 1])
        pad_sizes = (kernel_sizes - 1) / 2
        out_channels = np.array([64, 128, 256, 256, 512])
        pool_sizes = np.array([2, 2, 2, 2, 2])
        pool_strides = np.array([2, 2, 2, 2, 2])

        fcs = [512, 100]  #512, 100

        if do_dropblock:  # module-level flag, defined outside this class
            self.Dropblock = LinearScheduler(DropBlock2D(
                drop_prob=drop_prob, block_size=block_size),
                                             start_value=0.,
                                             stop_value=drop_prob,
                                             nr_steps=5000)

        conv_layers = []
        assert len(kernel_sizes) == len(out_channels) == len(
            pool_sizes), "inconsistent layer length"
        layers = range(len(kernel_sizes))
        input_dim = 1 if grayscale else 3
        out_channels = np.insert(out_channels, 0, input_dim)
        for l in layers:
            conv_layers.append(
                nn.Conv2d(out_channels[l],
                          out_channels[l + 1],
                          kernel_sizes[l],
                          stride=conv_strides[l],
                          padding=int(pad_sizes[l])))
            conv_layers.append(nn.BatchNorm2d(out_channels[l + 1]))
            # conv_layers.append(nn.Dropout(0.1))
            if do_dropblock:
                conv_layers.append(self.Dropblock)
            conv_layers.append(Activation)
            if pool_sizes[l]:
                conv_layers.append(
                    nn.MaxPool2d(int(pool_sizes[l]),
                                 stride=int(pool_strides[l])))

        self.conv = nn.Sequential(*conv_layers)

        # compute input shape of FCs
        x = torch.zeros([1, input_dim, 224, 224])
        x = self.conv(x)
        x = x.view(1, -1)
        input_FC = x.size(1)

        FCs = []
        layers = range(len(fcs))
        # input length: input_FC
        fcs = np.insert(fcs, 0, input_FC)
        for l in layers:
            FCs.append(nn.Linear(fcs[l], fcs[l + 1]))
            FCs.append(Activation)
            FCs.append(nn.BatchNorm1d(fcs[l + 1]))
            FCs.append(nn.Dropout(0.5))

        self.fc = nn.Sequential(*FCs)

        self.out = nn.Linear(fcs[-1], 3)  # three categories: class A, B, C

    def forward(self, x):
        if do_dropblock:
            self.Dropblock.step()  # increment number of iterations
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flattened
        if not do_dropblock:
            x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(x)
        x = self.out(x)
        return x

    def get_conv_output(self, x):
        if do_dropblock:
            self.Dropblock.step()  # increment number of iterations
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flattened
        return x.cpu().numpy()
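
# --- Usage sketch (not from the original source). Assumes the module-level `do_dropblock`
# flag and, when it is enabled, dropblock's DropBlock2D / LinearScheduler are defined
# above. The 224x224 input matches the shape probe in __init__.
import torch

model = CNN(grayscale=False)
model.eval()                                   # BatchNorm / dropout in inference mode
scores = model(torch.randn(4, 3, 224, 224))    # -> (4, 3): classes A, B, C
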