示例#1
0
    def __init__(self, params):
        """Assemble the CSSE-gated encoder / bottleneck / decoder network.

        :param params: dict of architecture settings, for example
                       {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}

        ``params`` is updated in place while the stages are built; when this
        returns, ``params['num_channels']`` equals ``params['num_filters']``.
        Building a second model from the same dict would therefore construct
        its first encoder with ``num_channels == num_filters``.
        """
        super(CustomQuickNat, self).__init__()

        gate = se.SELayer.CSSE

        # Stage 1 sees the caller-supplied input channels; later encoders and
        # the bottleneck are built for ``num_filters`` input channels.
        self.encode1 = EncoderBlock(params, se_block_type=gate)
        params['num_channels'] = params['num_filters']
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, EncoderBlock(params, se_block_type=gate))
        self.bottleneck = DenseBlock(params, se_block_type=gate)

        # Decoders are built for twice the filter width (presumably upsampled
        # features concatenated with the matching skip connection - confirm).
        params['num_channels'] = 2 * params['num_filters']
        for stage in ('decode1', 'decode2', 'decode3', 'decode4'):
            setattr(self, stage, DecoderBlock(params, se_block_type=gate))

        params['num_channels'] = params['num_filters']
        self.classifier = sm.ClassifierBlock(params)
    def __init__(self, params):
        """Build a two-headed network: a segmentation decoder and a dense classifier.

        :param params: dict of architecture settings, for example
                       {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}

        ``params['num_channels']`` is overwritten in place (64, then 128, then
        64 again, which is its value on return). The stage widths 64/128 are
        fixed here rather than derived from ``params['num_filters']``.
        """
        super(QuickFCN, self).__init__()

        gate = se.SELayer.CSSE

        # Shared encoder path.
        self.encode1 = sm.EncoderBlock(params, se_block_type=gate)
        params['num_channels'] = 64
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, sm.EncoderBlock(params, se_block_type=gate))
        self.bottleneck = sm.DenseBlock(params, se_block_type=gate)

        # Segmentation head.
        params['num_channels'] = 128
        for stage in ('decode1', 'decode2', 'decode3', 'decode4'):
            setattr(self, stage, sm.DecoderBlock(params, se_block_type=gate))
        params['num_channels'] = 64
        self.segmenter = sm.ClassifierBlock(params)

        # Classification head: 40000 -> 25 -> 3. The 40000 input features are
        # hard-coded and must equal the flattened size the forward pass feeds
        # in (NOTE(review): e.g. 64 x 25 x 25 - confirm against forward()).
        head = [nn.Linear(40000, 25), nn.PReLU(), nn.Linear(25, 3)]
        self.classifier = nn.Sequential(*head)
 def __init__(self, params):
     """Build the three-stage SDnet conditioner with per-decoder squeeze convs.

     ``params`` is overwritten in place: num_channels=1 and num_filters=64 are
     forced up front, num_channels becomes 64 for the later encoders and the
     bottleneck, 128 for the decoders, and finally 64 for the classifier.
     """
     super(SDnetConditioner, self).__init__()

     params['num_channels'] = 1
     params['num_filters'] = 64
     self.encode1 = sm.SDnetEncoderBlock(params)

     params['num_channels'] = 64
     for stage in ('encode2', 'encode3'):
         setattr(self, stage, sm.SDnetEncoderBlock(params))
     self.bottleneck = sm.GenericBlock(params)

     # Each decoder is followed by a 1x1 conv from num_filters channels
     # (assumed to be the decoder's output width) down to a single channel.
     params['num_channels'] = 128
     for idx in (1, 2, 3):
         setattr(self, 'decode%d' % idx, sm.SDnetDecoderBlock(params))
         setattr(self, 'squeeze_conv_d%d' % idx,
                 nn.Conv2d(in_channels=params['num_filters'], out_channels=1,
                           kernel_size=(1, 1), padding=(0, 0), stride=1))

     params['num_channels'] = 64
     self.classifier = sm.ClassifierBlock(params)
     self.sigmoid = nn.Sigmoid()
    def __init__(self, params):
        """Build the 16-filter, four-stage SDnet conditioner.

        ``params`` is overwritten in place: num_channels=2 and num_filters=16
        are forced for the first encoder, and num_channels is 16 for every
        later block (which is also its value on return).
        """
        super(SDnetConditioner, self).__init__()

        params['num_channels'] = 2
        params['num_filters'] = 16
        self.encode1 = sm.SDnetEncoderBlock(params)

        params['num_channels'] = 16
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, sm.SDnetEncoderBlock(params))
        self.bottleneck = sm.GenericBlock(params)

        # NOTE(review): decoders are built with 16 input channels, i.e. no
        # extra width for a concatenated skip connection - confirm this is
        # what SDnetDecoderBlock expects.
        params['num_channels'] = 16
        for stage in ('decode1', 'decode2', 'decode3', 'decode4'):
            setattr(self, stage, sm.SDnetDecoderBlock(params))

        params['num_channels'] = 16
        self.classifier = sm.ClassifierBlock(params)
        self.sigmoid = nn.Sigmoid()

        # Linear map from num_filters (16) features to 64.
        self.fc_layer = nn.Linear(params['num_filters'], 64, bias=True)
示例#5
0
    def __init__(self, params):
        """Build the 16-filter SDnet conditioner with a bottleneck squeeze conv.

        ``params`` is overwritten in place: num_channels=2 and num_filters=16
        for the first encoder, then num_channels=16 for every later block
        (its value on return).
        """
        super(SDnetConditioner, self).__init__()

        params['num_channels'] = 2
        params['num_filters'] = 16
        self.encode1 = sm.SDnetEncoderBlock(params)

        params['num_channels'] = 16
        self.encode2 = sm.SDnetEncoderBlock(params)
        self.encode3 = sm.SDnetEncoderBlock(params)
        self.encode4 = sm.SDnetEncoderBlock(params)
        self.bottleneck = sm.GenericBlock(params)
        # 1x1 convolution from num_filters (16) channels down to one channel.
        self.squeeze_conv_bn = nn.Conv2d(in_channels=params['num_filters'],
                                         out_channels=1,
                                         kernel_size=(1, 1),
                                         padding=(0, 0),
                                         stride=1)

        params['num_channels'] = 16
        self.decode1 = sm.SDnetDecoderBlock(params)
        self.decode2 = sm.SDnetDecoderBlock(params)
        self.decode3 = sm.SDnetDecoderBlock(params)
        self.decode4 = sm.SDnetDecoderBlock(params)

        params['num_channels'] = 16
        self.classifier = sm.ClassifierBlock(params)
        self.sigmoid = nn.Sigmoid()
示例#6
0
    def __init__(self, params):
        """Build the 16-filter SDnet conditioner with per-decoder linear maps.

        ``params`` is overwritten in place: num_channels=2 and num_filters=16
        for the first encoder, then num_channels=16 for every later block
        (its value on return). After each decoder ``decode<k>`` a
        ``channel_conv_d<k>`` linear layer maps ``num_filters`` (16) features
        to 64.
        """
        super(SDnetConditioner, self).__init__()

        params['num_channels'] = 2
        params['num_filters'] = 16
        self.encode1 = sm.SDnetEncoderBlock(params)

        params['num_channels'] = 16
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, sm.SDnetEncoderBlock(params))
        self.bottleneck = sm.GenericBlock(params)

        params['num_channels'] = 16
        # Decoder and its linear map are created pairwise, in stage order.
        for idx in (1, 2, 3, 4):
            setattr(self, 'decode%d' % idx, sm.SDnetDecoderBlock(params))
            setattr(self, 'channel_conv_d%d' % idx,
                    nn.Linear(params['num_filters'], 64, bias=True))

        params['num_channels'] = 16
        self.classifier = sm.ClassifierBlock(params)
        self.sigmoid = nn.Sigmoid()
    def __init__(self, params):
        """Build QuickResNet: a ResUltNet, a transposed-conv stack and a
        CSSE-gated encoder/decoder.

        :param params: dict of architecture settings, for example
                       {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}

        ``params['num_channels']`` is overwritten in place (64, 128, then 64,
        its value on return).
        """
        super(QuickResNet, self).__init__()

        # NOTE(review): ResUltNet receives the very same ``params`` dict the
        # encoders below read. If ResUltNet mutates it, encode1 is built from
        # the mutated values - confirm this sharing is intended.
        self.resultnet = ResUltNet(params)

        # "Un-bottleneck" stack of ConvTranspose2d -> BatchNorm2d -> ReLU.
        # Each entry is (in_channels, out_channels, stride, padding) with
        # kernel size 4; starting from a 1x1 map the spatial size becomes
        # 4x4, 8x8, 16x16 and finally 32x32 with 3 channels.
        upsample_plan = ((64, 32, 1, 0),
                         (32, 16, 2, 1),
                         (16, 8, 2, 1),
                         (8, 3, 2, 1))
        layers = []
        for c_in, c_out, stride, pad in upsample_plan:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad,
                                             bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        self.unbottle = nn.Sequential(*layers)

        gate = se.SELayer.CSSE
        self.encode1 = sm.EncoderBlock(params, se_block_type=gate)
        params['num_channels'] = 64
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, sm.EncoderBlock(params, se_block_type=gate))
        self.bottleneck = sm.DenseBlock(params, se_block_type=gate)
        # 1x1 conv from params['num_channels'] (64 at this point) to 3 channels.
        self.conv1 = torch.nn.Conv2d(params['num_channels'], 3, kernel_size=1)

        params['num_channels'] = 128
        for stage in ('decode1', 'decode2', 'decode3', 'decode4'):
            setattr(self, stage, sm.DecoderBlock(params, se_block_type=gate))
        params['num_channels'] = 64
        self.classifier = sm.ClassifierBlock(params)
 def __init__(self, params):
     """Build a three-stage encoder/decoder segmentor with a sigmoid output.

     Encoders and the bottleneck are built without an explicit
     ``se_block_type``; the decoders explicitly use ``se.SELayer.NONE``.
     ``params['num_channels']`` is overwritten in place (1, then 64, then
     128, then 64 - its value on return).
     """
     super(Segmentor, self).__init__()

     params['num_channels'] = 1
     self.encode1 = sm.EncoderBlock(params)

     params['num_channels'] = 64
     for stage in ('encode2', 'encode3'):
         setattr(self, stage, sm.EncoderBlock(params))
     self.bottleneck = sm.DenseBlock(params)

     params['num_channels'] = 128
     no_gate = se.SELayer.NONE
     for stage in ('decode1', 'decode2', 'decode3'):
         setattr(self, stage, sm.DecoderBlock(params, se_block_type=no_gate))

     params['num_channels'] = 64
     self.classifier = sm.ClassifierBlock(params)
     self.sigmoid = nn.Sigmoid()
示例#9
0
 def __init__(self, params):
     """Build the three-stage SDnet segmentor.

     ``params`` is overwritten in place: num_channels=1 and num_filters=64
     for the first encoder, num_channels=64 for the remaining encoders and
     the bottleneck, 128 for the decoders, and 64 for the classifier (its
     value on return).
     """
     super(SDnetSegmentor, self).__init__()

     params['num_channels'] = 1
     params['num_filters'] = 64
     self.encode1 = sm.SDnetEncoderBlock(params)

     params['num_channels'] = 64
     for stage in ('encode2', 'encode3'):
         setattr(self, stage, sm.SDnetEncoderBlock(params))
     self.bottleneck = sm.GenericBlock(params)

     params['num_channels'] = 128
     for stage in ('decode1', 'decode2', 'decode3'):
         setattr(self, stage, sm.SDnetDecoderBlock(params))

     params['num_channels'] = 64
     self.classifier = sm.ClassifierBlock(params)
示例#10
0
    def __init__(self, params):
        """Build the QuickNAT encoder/decoder network with CSSE-gated blocks.

        :param params: dict of architecture settings, for example
                       {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}

        Stage widths are derived from ``params['num_filters']`` (each block is
        assumed to emit ``num_filters`` channels): encoders 2-4, the
        bottleneck and the classifier take ``num_filters`` input channels,
        the decoders ``2 * num_filters``. For the default ``num_filters=64``
        that is 64 / 128. ``params['num_channels']`` is overwritten in place
        and equals ``params['num_filters']`` on return.
        """
        super(QuickNat, self).__init__()

        num_filters = params['num_filters']

        self.encode1 = sm.EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        # Later encoders consume the previous block's output width.
        params['num_channels'] = num_filters
        self.encode2 = sm.EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.encode3 = sm.EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.encode4 = sm.EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.bottleneck = sm.DenseBlock(params, se_block_type=se.SELayer.CSSE)

        # Decoder input = upsampled features + skip connection (2x width).
        params['num_channels'] = 2 * num_filters
        self.decode1 = sm.DecoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.decode2 = sm.DecoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.decode3 = sm.DecoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.decode4 = sm.DecoderBlock(params, se_block_type=se.SELayer.CSSE)

        params['num_channels'] = num_filters
        self.classifier = sm.ClassifierBlock(params)
示例#11
0
    def __init__(self):
        """Build the CSSE encoder/decoder backbone from module-level settings.

        Takes no arguments: it reads and mutates the module-level ``params``
        dict and reads ``pretrained_num_classes`` (both defined outside this
        view - TODO confirm they exist before instantiation). It sets
        ``params['num_class']`` and leaves ``params['num_channels'] == 64``.
        """
        super(Network, self).__init__()

        params['num_channels'] = 1
        params['num_class'] = pretrained_num_classes
        gate = se.SELayer.CSSE
        self.encode1 = sm.EncoderBlock(params, se_block_type=gate)

        params['num_channels'] = 64
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, sm.EncoderBlock(params, se_block_type=gate))
        self.bottleneck = sm.DenseBlock(params, se_block_type=gate)

        params['num_channels'] = 128
        for stage in ('decode1', 'decode2', 'decode3', 'decode4'):
            setattr(self, stage, sm.DecoderBlock(params, se_block_type=gate))

        params['num_channels'] = 64
        self.classifier = sm.ClassifierBlock(params)
示例#12
0
 def __init__(self, params):
     """Build the four-stage SDnet segmentor with a Softmax2d output.

     ``params`` is overwritten in place: num_channels=1 and num_filters=64
     for the first encoder, num_channels=64 for the remaining encoders and
     the bottleneck, 128 for the decoders, and 64 for the classifier (its
     value on return).
     """
     super(SDnetSegmentor, self).__init__()

     params['num_channels'] = 1
     params['num_filters'] = 64
     self.encode1 = sm.SDnetEncoderBlock(params)

     params['num_channels'] = 64
     self.encode2 = sm.SDnetEncoderBlock(params)
     self.encode3 = sm.SDnetEncoderBlock(params)
     self.encode4 = sm.SDnetEncoderBlock(params)
     self.bottleneck = sm.GenericBlock(params)

     params['num_channels'] = 128
     self.decode1 = sm.SDnetDecoderBlock(params)
     self.decode2 = sm.SDnetDecoderBlock(params)
     self.decode3 = sm.SDnetDecoderBlock(params)
     self.decode4 = sm.SDnetDecoderBlock(params)

     params['num_channels'] = 64
     self.classifier = sm.ClassifierBlock(params)
     self.soft_max = nn.Softmax2d()
    def __init__(self, params):
        """Build the SDnet segmentor whose later stages take 64 + 16 channels.

        ``params`` is overwritten in place: num_channels=1 and num_filters=64
        for the first encoder, then num_channels=80 (64 + 16) for encoders
        2-4, the bottleneck and all four decoders, and 64 for the classifier
        (its value on return).
        """
        super(SDnetSegmentor, self).__init__()

        params['num_channels'] = 1
        params['num_filters'] = 64
        self.encode1 = sm.SDnetEncoderBlock(params)

        # NOTE(review): the extra 16 channels presumably come from a
        # 16-filter conditioner concatenated onto the 64 segmentor features -
        # confirm against forward().
        params['num_channels'] = 64 + 16
        for stage in ('encode2', 'encode3', 'encode4'):
            setattr(self, stage, sm.SDnetEncoderBlock(params))
        self.bottleneck = sm.GenericBlock(params)
        for stage in ('decode1', 'decode2', 'decode3', 'decode4'):
            setattr(self, stage, sm.SDnetDecoderBlock(params))

        params['num_channels'] = 64
        self.classifier = sm.ClassifierBlock(params)
        self.soft_max = nn.Softmax2d()
        self.sigmoid = nn.Sigmoid()
示例#14
0
    def __init__(self, params):
        """Build the octave-convolution variant of QuickNAT.

        :param params: dict of architecture settings, for example
                       {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}

        Layout: a regular EncoderBlock, two OctaveEncoderBlocks, an
        OctaveDenseBlock bottleneck, two OctaveDecoderBlocks, one regular
        DecoderBlock and a ClassifierBlock. Each block's SE type is read from
        ``params['se_block']`` at construction time.
        ``params['num_channels']`` is overwritten in place and equals
        ``params['num_filters']`` on return.
        """
        super(QuickOct, self).__init__()
        print("NUMBER OF CHANNEL", params['num_channels'])

        self.encode1 = sm.EncoderBlock(params,
                                       se_block_type=params['se_block'])

        params['num_channels'] = params['num_filters']
        for stage in ('encode2', 'encode3'):
            setattr(self, stage,
                    sm.OctaveEncoderBlock(params,
                                          se_block_type=params['se_block']))
        # A fourth encoder stage is disabled in this variant.
        self.bottleneck = sm.OctaveDenseBlock(params,
                                              se_block_type=params['se_block'])

        params['num_channels'] = params['num_filters'] * 2
        for stage in ('decode1', 'decode2'):
            setattr(self, stage,
                    sm.OctaveDecoderBlock(params,
                                          se_block_type=params['se_block']))
        self.decode3 = sm.DecoderBlock(params,
                                       se_block_type=params['se_block'])
        # A fourth decoder stage is disabled in this variant.

        params['num_channels'] = params['num_filters']
        self.classifier = sm.ClassifierBlock(params)
示例#15
0
 def __init__(self):
     """Wrap a ``Network`` backbone and add a ClassifierBlock head.

     Takes no arguments: it mutates the module-level ``params`` dict
     (setting ``'num_class'``) and reads the module-level ``num_class`` -
     both are defined outside this view.
     """
     super(Classifier, self).__init__()
     self.network = Network()
     head_params = params  # alias of the module-level dict, not a copy
     head_params['num_class'] = num_class
     self.classifier = sm.ClassifierBlock(head_params)
def _sample_support_slices(volume, labelmap, query_label, num_support):
    """Binarise ``labelmap`` for ``query_label`` and pick ``num_support`` slices.

    Only slices whose binary mask has more than 10 foreground pixels are
    eligible. From those, ``num_support`` indexes are taken from an evenly
    spaced grid over the eligible slices, each shifted forward by half the
    average spacing.

    :param volume: float tensor whose first axis indexes slices.
    :param labelmap: 3-D label tensor whose first axis indexes slices.
    :param query_label: label id that becomes foreground (1.0); every other
        label becomes 0.0.
    :param num_support: number of slices to return (indexes may repeat when
        fewer than ``num_support`` slices are eligible).
    :returns: ``(volume, labelmap)`` restricted to the selected slices.
    :raises ValueError: if ``num_support`` < 1 or no slice is eligible.
    """
    if num_support < 1:
        raise ValueError("num_support must be >= 1, got %r" % (num_support,))

    labelmap = (labelmap == query_label).type(torch.FloatTensor)
    batch, _, _ = labelmap.size()
    has_class = torch.sum(labelmap.view(batch, -1), dim=1) > 10
    labelmap = labelmap[has_class]
    volume = volume[has_class]

    num_eligible = len(volume)
    if num_eligible == 0:
        raise ValueError("no support slice contains more than 10 pixels of "
                         "label %r" % (query_label,))

    # Build num_support + 1 grid points, shift them by half a bin, then drop
    # the last one (after the shift it can fall past the final slice).
    indexes = np.round(
        np.linspace(0, num_eligible - 1, num_support + 1)).astype(int)
    indexes += (num_eligible // num_support) // 2
    indexes = indexes[:-1]
    return volume[indexes], labelmap[indexes]


def train(train_params, common_params, data_params, net_params):
    """Fine-tune a new classifier head of a pre-trained model on a support set.

    One annotated volume is loaded and its label map binarised for a single
    query label; ``num_support`` slices containing that label become the
    training data. The model stored at ``train_params['pre_trained_path']``
    is loaded with ``torch.load``, all of its parameters are frozen, and its
    ``classifier`` is replaced by a fresh, trainable ``sm.ClassifierBlock``.
    After training, ``solver.save_best_model`` writes to
    ``<save_model_dir>/finetuned_segmentor_<fold>.pth.tar``.

    Optional keys (defaults keep the previous hard-coded behaviour):
        train_params['query_label']         -- label id to segment (8)
        train_params['num_support']         -- support slices to use (10)
        data_params['support_volume_path']  -- .mat file with the support
                                               volume (the Visceral path below)

    Side effects: sets ``train_params['exp_name']`` and
    ``net_params['num_channels'] = 64``.

    :raises ValueError: if ``num_support`` < 1 or no slice of the support
        volume contains more than 10 pixels of the query label.
    """
    query_label = train_params.get('query_label', 8)
    num_support = train_params.get('num_support', 10)
    support_path = data_params.get(
        'support_volume_path',
        "/home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/"
        "Visceral/10000132_1_CTce_ThAb.mat")

    support_volume, support_labelmap, _, _ = du.load_and_preprocess(
        support_path,
        orientation='AXI',
        remap_config="WholeBody")
    # Insert a singleton axis at position 1 unless the volume is already 4-D.
    if len(support_volume.shape) != 4:
        support_volume = support_volume[:, np.newaxis, :, :]
    support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
    support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)

    support_volume, support_labelmap = _sample_support_slices(
        support_volume, support_labelmap, query_label, num_support)

    # The third array is all ones, shaped like the labels
    # (NOTE(review): presumably per-pixel weights - confirm with ImdbData).
    train_data = ImdbData(support_volume.numpy(), support_labelmap.numpy(),
                          np.ones_like(support_labelmap.numpy()))

    model_prefix = 'finetuned_segmentor_'
    folds = ['fold4']
    for fold in folds:
        final_model_path = os.path.join(common_params['save_model_dir'],
                                        model_prefix + fold + '.pth.tar')

        train_params['exp_name'] = model_prefix + fold

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=train_params['train_batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True)
        # NOTE(review): validation reuses the training support slices, so its
        # scores measure fit to the support set, not generalisation.
        val_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=train_params['val_batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True)

        # torch.load unpickles the file: only load trusted checkpoints.
        few_shot_model = torch.load(train_params['pre_trained_path'])
        for param in few_shot_model.parameters():
            param.requires_grad = False
        # Swap in a fresh head; its parameters are the only trainable ones.
        net_params['num_channels'] = 64
        few_shot_model.classifier = sm.ClassifierBlock(net_params)
        for param in few_shot_model.classifier.parameters():
            param.requires_grad = True

        solver = Solver(
            few_shot_model,
            device=common_params['device'],
            num_class=net_params['num_class'],
            optim_args={
                "lr": train_params['learning_rate'],
                "weight_decay": train_params['optim_weight_decay'],
                "momentum": train_params['momentum']
            },
            model_name=common_params['model_name'],
            exp_name=train_params['exp_name'],
            labels=data_params['labels'],
            log_nth=train_params['log_nth'],
            num_epochs=train_params['num_epochs'],
            lr_scheduler_step_size=train_params['lr_scheduler_step_size'],
            lr_scheduler_gamma=train_params['lr_scheduler_gamma'],
            use_last_checkpoint=train_params['use_last_checkpoint'],
            log_dir=common_params['log_dir'],
            exp_dir=common_params['exp_dir'])

        solver.train(train_loader, val_loader)

        solver.save_best_model(final_model_path)
        print("final model saved @ " + str(final_model_path))
示例#17
0
    def __init__(self, params):
        """Build the two-branch (segmentation / classification) network with
        cross-stitch mixing weights between matching stages.

        :param params: dict of architecture settings, for example
                       {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}

        ``params['num_channels']`` is overwritten in place; on return it
        equals ``params['num_filters']``.

        For each stage tag k in ('1', '2', '3', 'b') four tensors of shape
        (1, num_filters, 1, 1) are created: ``cross<k>ss`` = a,
        ``cross<k>sc`` = 1 - a, ``cross<k>cc`` = b and ``cross<k>cs`` = 1 - b,
        where a and b come from ``uniform_(0.05, 0.95)``. They are registered
        as non-persistent buffers so that ``.to()`` / ``.cuda()`` move them
        together with the module while ``state_dict()`` is unchanged. They
        are not trainable parameters.
        """
        super(SoftQuickFCN, self).__init__()

        print('num channels: ', params['num_channels'])

        gate = se.SELayer.CSSE

        def add_cross_stitch(tag):
            # Random draws happen in the order "s" then "c", right after the
            # stage's blocks are built, so seeded runs stay reproducible.
            s_init = torch.FloatTensor(1, 1).uniform_(0.05, 0.95)
            c_init = torch.FloatTensor(1, 1).uniform_(0.05, 0.95)
            width = params['num_filters']
            for suffix, value in (('ss', s_init), ('sc', 1 - s_init),
                                  ('cc', c_init), ('cs', 1 - c_init)):
                self.register_buffer('cross' + tag + suffix,
                                     torch.ones(1, width, 1, 1) * value,
                                     persistent=False)

        self.encode1_seg = EncoderBlock(params, se_block_type=gate)
        self.encode1_class = EncoderBlock(params, se_block_type=gate)
        add_cross_stitch('1')

        params['num_channels'] = params['num_filters']
        self.encode2_seg = EncoderBlock(params, se_block_type=gate)
        self.encode2_class = EncoderBlock(params, se_block_type=gate)
        add_cross_stitch('2')

        self.encode3_seg = EncoderBlock(params, se_block_type=gate)
        self.encode3_class = EncoderBlock(params, se_block_type=gate)
        add_cross_stitch('3')

        self.bottleneck_seg = DenseBlock(params, se_block_type=gate)
        self.bottleneck_class = DenseBlock(params, se_block_type=gate)
        add_cross_stitch('b')

        params['num_channels'] = 2 * params['num_filters']
        self.decode1_seg = DecoderBlock(params, se_block_type=gate)
        self.decode2_seg = DecoderBlock(params, se_block_type=gate)
        self.decode3_seg = DecoderBlock(params, se_block_type=gate)

        params['num_channels'] = params['num_filters']
        self.classifier_seg = sm.ClassifierBlock(params)

        ############Classification Task############
        # NOTE(review): assumes the classification branch is flattened from a
        # num_filters x 50 x 50 feature map - confirm against forward().
        self.classifier_class = nn.Sequential(
            nn.Linear(params['num_channels'] * 50 * 50, 25), nn.PReLU(),
            nn.Linear(25, 3))