Example #1
import torch
import torch.nn as nn

# The UNet_* variants and ConvLSTM used below are project-local modules,
# assumed to be imported from elsewhere in the repository.
class RNN_UNet_Config1(nn.Module):
    def __init__(self, model="ProgressiveDilated", model_path=None, kernel_size=3, num_layers=1):
        super().__init__()
        
        if model == "ProgressiveDilated":
            self.unet = UNet_ProgressiveDilated(in_channels=1, out_channels=3)
        elif model == "Dilated":
            self.unet = UNet_Dilated(in_channels=1, out_channels=3)
        elif model == "Dilated":
            self.unet = UNet_Baseline(in_channels=1, out_channels=3)
        elif model == "Original_with_BatchNorm":
            self.unet = UNet_Original_with_BatchNorm(in_channels=1, out_channels=3)
        else:
            self.unet = UNet_Original(in_channels=1, out_channels=3)
            
        if model_path:
            print("load model -- mode: {}".format(model))
            self.unet.load_state_dict(torch.load(model_path))
        
        self.convlstm_forward = ConvLSTM(input_size=(256, 256), input_dim=3, 
                                         hidden_dim=3, kernel_size=(kernel_size, kernel_size), 
                                         num_layers=num_layers, batch_first=False, 
                                         bias=True, return_all_layers=False)
        
        self.convlstm_backward = ConvLSTM(input_size=(256, 256), input_dim=3, 
                                         hidden_dim=3, kernel_size=(kernel_size, kernel_size), 
                                         num_layers=num_layers, batch_first=False, 
                                         bias=True, return_all_layers=False)
        
        self.last_conv = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=3, padding=1)
        
    def forward(self, x):

        # Per-slice UNet prediction: (N, 1, H, W) -> (N, 3, H, W)
        out_unet = self.unet(x)
        # Treat the batch as a sequence for the ConvLSTMs: -> (T=N, B=1, 3, H, W)
        x = out_unet[None].permute(1, 0, 2, 3, 4)
        
        # Run the sequence through forward- and backward-time ConvLSTMs
        out_forward = self.convlstm_forward(x)
        out_backward = self.convlstm_backward(torch.flip(x, (0,)))

        # Fuse both directions along channels (NB: the backward stream is left in
        # reversed time order here) and project back to 3 channels
        out = torch.cat([out_forward[0][0][0], out_backward[0][0][0]], dim=1)
        out = self.last_conv(out)
        
        return out_unet, out
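
A hypothetical smoke test (an addition, not from the original); it assumes the UNet_* variants and the ConvLSTM implementation referenced above are importable, and that ConvLSTM returns (layer_outputs, last_states) in the shape the indexing in forward expects:

model = RNN_UNet_Config1(model="Original")   # any unknown key falls through to UNet_Original
x = torch.randn(4, 1, 256, 256)              # 4 slices; the batch axis doubles as time
out_unet, out = model(x)
print(out_unet.shape, out.shape)             # expected: torch.Size([4, 3, 256, 256]) for both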
Example #2
# __init__ of a bidirectional ConvLSTM config (the enclosing class statement is
# omitted in this snippet); unlike Example #1, the ConvLSTMs here run on
# 32-channel feature maps rather than on the 3-channel UNet output.
def __init__(self, model="ProgressiveDilated", model_path=None, kernel_size=3, num_layers=1,
             train_unet_decoder=False, train_unet=False):
    super().__init__()

    if model == "ProgressiveDilated":
        self.unet = UNet_ProgressiveDilated(in_channels=1, out_channels=3)
    elif model == "Dilated":
        self.unet = UNet_Dilated(in_channels=1, out_channels=3)
    elif model == "Baseline":
        self.unet = UNet_Baseline(in_channels=1, out_channels=3)
    elif model == "Original_with_BatchNorm":
        self.unet = UNet_Original_with_BatchNorm(in_channels=1, out_channels=3)
    else:
        self.unet = UNet_Original(in_channels=1, out_channels=3)

    if model_path:
        print("Loading pretrained UNet weights -- model: {}".format(model))
        self.unet.load_state_dict(torch.load(model_path))

    self.train_unet_decoder = train_unet_decoder
    self.train_unet = train_unet

    self.convlstm_forward = ConvLSTM(input_size=(256, 256), input_dim=32,
                                     hidden_dim=32, kernel_size=(kernel_size, kernel_size),
                                     num_layers=num_layers, batch_first=False,
                                     bias=True, return_all_layers=False)

    self.convlstm_backward = ConvLSTM(input_size=(256, 256), input_dim=32,
                                      hidden_dim=32, kernel_size=(kernel_size, kernel_size),
                                      num_layers=num_layers, batch_first=False,
                                      bias=True, return_all_layers=False)

    self.last_conv = nn.Conv2d(in_channels=32 * 2, out_channels=3, kernel_size=3, padding=1)
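
The train_unet / train_unet_decoder flags introduced here gate gradient flow inside forward (see Example #5). A minimal standalone illustration of that gating pattern, with a made-up layer:

layer = nn.Conv2d(1, 1, kernel_size=3, padding=1)
with torch.set_grad_enabled(False):          # e.g. train_unet=False: encoder runs frozen
    y = layer(torch.randn(1, 1, 8, 8))
print(y.requires_grad)                       # False -- no gradients flow back into this layer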
Example #3
# __init__ of a temporal 3D-convolution config (the enclosing class statement is
# omitted in this snippet); temporal context is mixed in with Conv3d blocks
# instead of ConvLSTMs.
def __init__(self, t_shift, model="ProgressiveDilated", model_path=None, kernel_size=3,
             train_unet_decoder=False, train_unet=False):
    super().__init__()

    if model == "ProgressiveDilated":
        self.unet = UNet_ProgressiveDilated(in_channels=1, out_channels=3)
    elif model == "Dilated":
        self.unet = UNet_Dilated(in_channels=1, out_channels=3)
    elif model == "Baseline":
        self.unet = UNet_Baseline(in_channels=1, out_channels=3)
    elif model == "Original_with_BatchNorm":
        self.unet = UNet_Original_with_BatchNorm(in_channels=1, out_channels=3)
    else:
        self.unet = UNet_Original(in_channels=1, out_channels=3)

    if model_path:
        print("Loading pretrained UNet weights -- model: {}".format(model))
        self.unet.load_state_dict(torch.load(model_path))

    self.train_unet_decoder = train_unet_decoder
    self.train_unet = train_unet

    # Two 3D convolutions with "same" padding (odd t_shift / kernel_size assumed),
    # so the slice count and spatial size are preserved.
    self.conv_temp_block = nn.Sequential(
        nn.Conv3d(in_channels=32, out_channels=32,
                  kernel_size=(t_shift, kernel_size, kernel_size),
                  padding=((t_shift - 1) // 2, (kernel_size - 1) // 2, (kernel_size - 1) // 2)),
        nn.BatchNorm3d(32),
        nn.PReLU(),
        nn.Conv3d(in_channels=32, out_channels=32,
                  kernel_size=(t_shift, kernel_size, kernel_size),
                  padding=((t_shift - 1) // 2, (kernel_size - 1) // 2, (kernel_size - 1) // 2))
    )
    self.last_conv = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, padding=1)
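
A quick standalone check (dummy sizes, not from the original) that the padding above is "same" for odd t_shift and kernel_size, so the slice count T survives the temporal block:

conv = nn.Conv3d(32, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1))  # t_shift = kernel_size = 3
x = torch.randn(1, 32, 5, 64, 64)            # (B, C, T, H, W)
print(conv(x).shape)                         # torch.Size([1, 32, 5, 64, 64]) -- T preserved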
Example #4
class Temporal_UNet_Config1_2(nn.Module):
    def __init__(self,
                 t_shift,
                 model="ProgressiveDilated",
                 model_path=None,
                 kernel_size=3,
                 train_unet_decoder=False,
                 train_unet=False):
        super().__init__()

        if model == "ProgressiveDilated":
            self.unet = UNet_ProgressiveDilated(in_channels=1, out_channels=3)
        elif model == "Dilated":
            self.unet = UNet_Dilated(in_channels=1, out_channels=3)
        elif model == "Baseline":
            self.unet = UNet_Baseline(in_channels=1, out_channels=3)
        elif model == "Original_with_BatchNorm":
            self.unet = UNet_Original_with_BatchNorm(in_channels=1,
                                                     out_channels=3)
        else:
            self.unet = UNet_Original(in_channels=1, out_channels=3)

        if model_path:
            print("load model -- mode: {}".format(model))
            self.unet.load_state_dict(torch.load(model_path))

        self.train_unet_decoder = train_unet_decoder
        self.train_unet = train_unet

        # Single 3D convolution with "same" padding (odd t_shift / kernel_size
        # assumed), so slice count and spatial size are preserved.
        self.conv_temp_block = nn.Sequential(
            nn.Conv3d(in_channels=32,
                      out_channels=32,
                      kernel_size=(t_shift, kernel_size, kernel_size),
                      padding=((t_shift - 1) // 2,
                               (kernel_size - 1) // 2,
                               (kernel_size - 1) // 2)))

    def forward(self, x):

        # 1. UNet encoder (gradients gated by train_unet)
        with torch.set_grad_enabled(self.train_unet):
            encoding = []
            for conv_block in self.unet.model.conv_down:
                x = conv_block(x)

                if self.unet.model.downsample_type == 'conv_stride':
                    encoding.append(x)
                if self.unet.model.downsample_type == 'maxpool':
                    encoding.append(x)
                    x = self.unet.model.maxpool(x)

        # 2. UNet bottleneck and decoder (gated by train_unet_decoder or train_unet)
        with torch.set_grad_enabled(self.train_unet_decoder
                                    or self.train_unet):
            if self.unet.model.advanced_bottleneck:
                x = self.unet.model.conv_middle_part1(x)
                x = self.unet.model.conv_middle_part2(x)
            else:
                x = self.unet.model.conv_middle(x)

            for i, conv_block in enumerate(self.unet.model.conv_up):
                x = self.unet.model.upsample(x)
                x = torch.cat([x, encoding[::-1][i]], dim=1)
                x = conv_block(x)

        # 3. Temporal mixing: move channels ahead of the slice axis and add a
        # singleton batch dim, (T, C, H, W) -> (1, C, T, H, W)
        x = x.permute(1, 0, 2, 3)
        x = x[None]

        # Temporal convolutional block, then back to (T, C, H, W)
        x = self.conv_temp_block(x)
        x = x[0]
        x = x.permute(1, 0, 2, 3)

        # Output head; the same tensor is returned twice to preserve the
        # (unet_output, refined_output) interface of the other configs
        out_unet = self.unet.model.conv_last(x)

        return out_unet, out_unet
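
The reshuffle in step 3 is a lossless round trip; a standalone check with dummy sizes (an addition, not from the original):

x = torch.randn(5, 32, 64, 64)               # (T, C, H, W): T slices of decoder features
y = x.permute(1, 0, 2, 3)[None]              # -> (1, C, T, H, W), the Conv3d layout
z = y[0].permute(1, 0, 2, 3)                 # -> back to (T, C, H, W)
print(torch.equal(x, z))                     # True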
Example #5
class RNN_UNet_Config1_1(nn.Module):
    def __init__(self, model="ProgressiveDilated", model_path=None, kernel_size=3, num_layers=1,
                 train_unet_decoder=False, train_unet=False):
        super().__init__()
        
        if model == "ProgressiveDilated":
            self.unet = UNet_ProgressiveDilated(in_channels=1, out_channels=3)
        elif model == "Dilated":
            self.unet = UNet_Dilated(in_channels=1, out_channels=3)
        elif model == "Dilated":
            self.unet = UNet_Baseline(in_channels=1, out_channels=3)
        elif model == "Original_with_BatchNorm":
            self.unet = UNet_Original_with_BatchNorm(in_channels=1, out_channels=3)
        else:
            self.unet = UNet_Original(in_channels=1, out_channels=3)
            
        if model_path:
            print("load model -- mode: {}".format(model))
            self.unet.load_state_dict(torch.load(model_path))
        
        self.train_unet_decoder = train_unet_decoder
        self.train_unet = train_unet
        
        self.convlstm_forward = ConvLSTM(input_size=(256, 256), input_dim=32, 
                                         hidden_dim=32, kernel_size=(kernel_size, kernel_size), 
                                         num_layers=num_layers, batch_first=False, 
                                         bias=True, return_all_layers=False)
        
        self.convlstm_backward = ConvLSTM(input_size=(256, 256), input_dim=32, 
                                         hidden_dim=32, kernel_size=(kernel_size, kernel_size), 
                                         num_layers=num_layers, batch_first=False, 
                                         bias=True, return_all_layers=False)
        
        self.last_conv = nn.Conv2d(in_channels=32*2, out_channels=3, kernel_size=3, padding=1)
        
    def forward(self, x):

        # 1. UNet encoder (gradients gated by train_unet)
        with torch.set_grad_enabled(self.train_unet):
            encoding = []
            for conv_block in self.unet.model.conv_down:
                x = conv_block(x)

                if self.unet.model.downsample_type == 'conv_stride':
                    encoding.append(x)
                if self.unet.model.downsample_type == 'maxpool':
                    encoding.append(x)
                    x = self.unet.model.maxpool(x)
        
        # 2. UNet bottleneck and decoder (gated by train_unet_decoder or train_unet)
        with torch.set_grad_enabled(self.train_unet_decoder or self.train_unet):
            if self.unet.model.advanced_bottleneck:
                x = self.unet.model.conv_middle_part1(x)
                x = self.unet.model.conv_middle_part2(x)
            else:
                x = self.unet.model.conv_middle(x)

            for i, conv_block in enumerate(self.unet.model.conv_up):
                x = self.unet.model.upsample(x)
                x = torch.cat([x, encoding[::-1][i]], dim=1)
                x = conv_block(x)

            out_unet = self.unet.model.conv_last(x)
        
        # 3. Bidirectional ConvLSTM over the slice axis:
        # (T, 32, H, W) -> (T, B=1, 32, H, W)
        x = x[None].permute(1, 0, 2, 3, 4)

        out_forward = self.convlstm_forward(x)
        out_backward = self.convlstm_backward(torch.flip(x, (0,)))

        # Fuse both directions along channels (NB: the backward stream is left in
        # reversed time order here) and project back to 3 channels
        out = torch.cat([out_forward[0][0][0], out_backward[0][0][0]], dim=1)
        out = self.last_conv(out)
        
        return out_unet, out
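
The fusion in step 3 reduces to a channel concat plus a 3x3 head; a minimal sketch with dummy tensors standing in for the two ConvLSTM outputs (shapes assumed from the indexing above):

fwd = torch.randn(4, 32, 256, 256)           # stand-in for out_forward[0][0][0]
bwd = torch.randn(4, 32, 256, 256)           # stand-in for out_backward[0][0][0]
head = nn.Conv2d(32 * 2, 3, kernel_size=3, padding=1)
print(head(torch.cat([fwd, bwd], dim=1)).shape)   # torch.Size([4, 3, 256, 256])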