Example No. 1
0
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize a PropUnitCell.

        The cell owns a two-block Bottleneck residual stack (built by
        ``make_resblock``) and a single convolution whose input is the
        residual stack's channels concatenated with ``hidden_dim`` hidden
        channels, producing ``4 * hidden_dim`` output channels.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel. A single int is expanded to a
            square ``(k, k)`` kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(PropUnitCell, self).__init__()

        # Accept a bare int as well as a (height, width) pair; indexing an
        # int below would otherwise raise TypeError.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # k // 2 padding keeps the spatial size unchanged for odd kernels
        # (the convolution below uses the default stride of 1).
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # Use the channel count reported by make_resblock instead of
        # re-deriving it by hand. NOTE(review): assumed to equal
        # 32 * Bottleneck.expansion, as the original hard-coded — confirm.
        self.resblock, resblock_channels = make_resblock(self.input_dim,
                                                         32,
                                                         blocks=2,
                                                         stride=1,
                                                         block=Bottleneck)

        # Input = residual features concatenated with the hidden state.
        # NOTE(review): the 4 * hidden_dim outputs presumably hold four
        # ConvLSTM-style gates that forward() splits — confirm there.
        self.conv = nn.Conv2d(in_channels=resblock_channels + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)
Example No. 2
0
    def __init__(self, in_channel):
        """
        Build the AdaMatting network: encoder, GCN shortcuts, trimap (T) and
        alpha (A) decoders, propagation unit and task-uncertainty parameters.

        Parameters
        ----------
        in_channel: int
            Number of channels of the input tensor fed to ``encoder_conv``.
        """
        super(AdaMatting, self).__init__()

        def upscale_block(in_channels, out_channels):
            """Conv-BN-ReLU widening to ``out_channels * 2**2`` maps, then
            PixelShuffle(2) -> ``out_channels`` maps at twice the resolution."""
            widened = out_channels * (2 ** 2)
            return nn.Sequential(
                nn.Conv2d(in_channels, widened, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(widened),
                nn.ReLU(inplace=True),
                nn.PixelShuffle(2)
            )

        # Encoder
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        encoder_inplanes = 64
        self.encoder_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.encoder_resblock1, encoder_inplanes = make_resblock(encoder_inplanes, 64, blocks=3, stride=2, block=Bottleneck)
        self.encoder_resblock2, encoder_inplanes = make_resblock(encoder_inplanes, 128, blocks=4, stride=2, block=Bottleneck)
        self.encoder_resblock3, encoder_inplanes = make_resblock(encoder_inplanes, 256, blocks=6, stride=2, block=Bottleneck)

        # This loop runs before the shortcuts/decoders are created, so only the
        # encoder modules above receive this custom initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal(0, sqrt(2 / n)) with n = k_h * k_w * out_channels
                # (He-style init over the fan-out).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

        # Shortcuts
        self.shortcut_shallow = GCN(64, 64)
        self.shortcut_middle = GCN(64 * Bottleneck.expansion, 64 * Bottleneck.expansion)
        self.shortcut_deep = GCN(128 * Bottleneck.expansion, 128 * Bottleneck.expansion)

        # T-decoder (3-channel output after the last stage).
        # NOTE(review): every stage, including the last, ends in BN + ReLU, so
        # the final outputs are non-negative — confirm this is intended.
        self.t_decoder_upscale1 = upscale_block(256 * Bottleneck.expansion, 512)
        self.t_decoder_upscale2 = upscale_block(512, 256)
        self.t_decoder_upscale3 = upscale_block(256, 64)
        self.t_decoder_upscale4 = upscale_block(64, 3)

        # A-decoder (1-channel output after the last stage).
        self.a_decoder_upscale1 = upscale_block(256 * Bottleneck.expansion, 512)
        self.a_decoder_upscale2 = upscale_block(512, 256)
        self.a_decoder_upscale3 = upscale_block(256, 64)
        self.a_decoder_upscale4 = upscale_block(64, 1)

        # Propagation unit: 4 + 3 + 1 input channels -> 1 output channel.
        # (A recurrent PropUnit variant was tried earlier and replaced by this
        # plain conv stack.)
        self.prop_unit = nn.Sequential(
            nn.Conv2d(4 + 3 + 1, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True),
        )

        # Task uncertainty loss parameters, both initialised to log(16).
        self.log_sigma_t_sqr = nn.Parameter(torch.log(torch.tensor([16.0])))
        self.log_sigma_a_sqr = nn.Parameter(torch.log(torch.tensor([16.0])))
Example No. 3
0
    def __init__(self, in_channel):
        """
        Build the AdaMatting network with residual, boundary-refined shortcuts.

        Parameters
        ----------
        in_channel: int
            Number of channels of the input tensor fed to ``encoder_conv``.
        """
        super(AdaMatting, self).__init__()

        def upscale_block(in_channels, out_channels):
            """Two decoder units widening to ``out_channels * 2**2`` maps, then
            PixelShuffle(2) -> ``out_channels`` maps at twice the resolution."""
            widened = out_channels * (2 ** 2)
            return nn.Sequential(
                self.decoder_unit(in_channels, widened),
                self.decoder_unit(widened, widened),
                nn.PixelShuffle(2))

        # Encoder
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(in_channel,
                      64,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        encoder_inplanes = 64
        self.encoder_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.encoder_resblock1, encoder_inplanes = make_resblock(
            encoder_inplanes, 64, blocks=3, stride=2, block=Bottleneck)
        self.encoder_resblock2, encoder_inplanes = make_resblock(
            encoder_inplanes, 128, blocks=3, stride=2, block=Bottleneck)
        self.encoder_resblock3, encoder_inplanes = make_resblock(
            encoder_inplanes, 256, blocks=3, stride=2, block=Bottleneck)

        # Custom init. It runs before the shortcuts and decoders exist, so only
        # the encoder built above is affected.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                # Layers built with bias=False have ``bias is None``;
                # nn.init.constant_ cannot fill None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Boundary refinement, one per shortcut level.
        self.br1 = BR(64)
        self.br2 = BR(64 * Bottleneck.expansion)
        self.br3 = BR(128 * Bottleneck.expansion)

        # Residual boundary shortcuts: one Bottleneck block per level that keeps
        # that level's channel count (planes * expansion == width) and
        # resolution (stride=1), followed by the matching BR module.
        # The original called ``self.br1(<module>)`` in __init__, which
        # crashes. It also never assigned shortcut_middle, and it chained
        # stride-2 blocks whose output width did not match the BR widths.
        # NOTE(review): widths are chosen to equal the BR widths. The
        # decoder stages consume exactly the previous stage's PixelShuffle
        # output (512, 256, 64 channels), which suggests shortcuts are added
        # element-wise in forward(). Confirm there.
        self.shortcut_shallow_intial, _ = make_resblock(
            64, 64 // Bottleneck.expansion, blocks=1, stride=1, block=Bottleneck)
        self.shortcut_middle_initial, _ = make_resblock(
            64 * Bottleneck.expansion, 64, blocks=1, stride=1, block=Bottleneck)
        self.shortcut_deep_initial, _ = make_resblock(
            128 * Bottleneck.expansion, 128, blocks=1, stride=1, block=Bottleneck)
        # Compose residual block -> BR (the same pattern the earlier GCN-based
        # variant used). The parts stay reachable under their own names.
        self.shortcut_shallow = nn.Sequential(self.shortcut_shallow_intial, self.br1)
        self.shortcut_middle = nn.Sequential(self.shortcut_middle_initial, self.br2)
        self.shortcut_deep = nn.Sequential(self.shortcut_deep_initial, self.br3)

        # T-decoder (3-channel output after the last stage).
        self.t_decoder_upscale1 = upscale_block(256 * Bottleneck.expansion, 512)
        self.t_decoder_upscale2 = upscale_block(512, 256)
        self.t_decoder_upscale3 = upscale_block(256, 64)
        self.t_decoder_upscale4 = upscale_block(64, 3)

        # A-decoder (1-channel output after the last stage).
        self.a_decoder_upscale1 = upscale_block(256 * Bottleneck.expansion, 512)
        self.a_decoder_upscale2 = upscale_block(512, 256)
        self.a_decoder_upscale3 = upscale_block(256, 64)
        self.a_decoder_upscale4 = upscale_block(64, 1)

        # Propagation unit: 3 + 1 + 1 input channels -> 1 output channel.
        # (A recurrent PropUnit variant was tried earlier and replaced by this
        # plain conv stack.)
        self.prop_unit = nn.Sequential(
            nn.Conv2d(3 + 1 + 1,
                      64,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=True),
        )

        # Task uncertainty loss parameters, both initialised to log(16).
        self.log_sigma_t_sqr = nn.Parameter(torch.log(torch.tensor([16.0])))
        self.log_sigma_a_sqr = nn.Parameter(torch.log(torch.tensor([16.0])))