Example #1
0
    def __init__(self, outer_nc, inner_nc, opt, innerCos_list, shift_list, mask_global, input_nc, \
                 submodule=None, shift_layer=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
                 use_spectral_norm=False, layer_to_last=3):
        """Assemble one U-Net level whose middle variant carries a shift-triple layer.

        The finished layer stack is stored in ``self.model`` (an ``nn.Sequential``):

        * outermost: Conv2d -> submodule -> ReLU -> ConvTranspose2d(inner_nc * 2 -> outer_nc) -> Tanh
        * innermost: LeakyReLU -> Conv2d -> ReLU -> ConvTranspose2d(inner_nc -> outer_nc) -> norm
        * middle:    LeakyReLU -> Conv2d -> norm -> submodule -> ReLU -> InnerCos
                     -> InnerShiftTriple -> ConvTranspose2d(inner_nc * 3 -> outer_nc) -> norm

        ``outermost`` takes precedence when both position flags are set.

        One InnerShiftTriple and one InnerCos are ALWAYS created. Both get ``mask_global``
        via ``set_mask`` and are appended to ``shift_list`` / ``innerCos_list``, whatever the
        position. They are only inserted into ``self.model`` for middle blocks.

        Args:
            outer_nc: output channels of the transposed conv; also the default input width.
            inner_nc: output channels of the strided Conv2d.
            opt: options object. Reads gpu_ids, shift_sz, stride, mask_thred,
                triple_weight, strength and skip.
            innerCos_list: list the InnerCos layer is appended to.
            shift_list: list the InnerShiftTriple layer is appended to.
            mask_global: mask handed to both layers' ``set_mask``.
            input_nc: input channels of the strided Conv2d; ``None`` means ``outer_nc``.
            submodule: next-inner level (not used by the innermost block).
            shift_layer: accepted for interface compatibility; not read here.
            outermost / innermost: position of this level in the U-Net.
            norm_layer: normalization factory, called with a channel count.
            use_spectral_norm: forwarded to the project's ``spectral_norm`` helper
                for both convolutions (presumably toggles the wrapping; verify).
            layer_to_last: forwarded to both InnerShiftTriple and InnerCos.
        """
        super(UnetSkipConnectionShiftBlock, self).__init__()
        self.outermost = outermost
        in_channels = outer_nc if input_nc is None else input_nc

        # Building blocks. Construction order is kept stable on purpose: modules with
        # randomly initialised parameters consume the RNG as they are created.
        down_conv = spectral_norm(
            nn.Conv2d(in_channels, inner_nc, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        down_act = nn.LeakyReLU(0.2, True)
        down_bn = norm_layer(inner_nc)
        up_act = nn.ReLU(True)
        up_bn = norm_layer(outer_nc)

        # Device tag understood by InnerShiftTriple / InnerCos.
        device = 'gpu' if len(opt.gpu_ids) else 'cpu'

        shift = InnerShiftTriple(opt.shift_sz, opt.stride, opt.mask_thred, opt.triple_weight,
                                 layer_to_last=layer_to_last, device=device)
        shift.set_mask(mask_global)
        shift_list.append(shift)

        # Latent-constraint layer; it needs the same mask as the shift layer.
        inner_cos = InnerCos(strength=opt.strength, skip=opt.skip,
                             layer_to_last=layer_to_last, device=device)
        inner_cos.set_mask(mask_global)
        innerCos_list.append(inner_cos)

        # Only the width entering the transposed conv depends on the position:
        # the innermost level has no skip concatenation, the outermost sees the usual x2,
        # and middle levels see x3 because of the shift triple.
        if outermost:
            up_in_channels = inner_nc * 2
        elif innermost:
            up_in_channels = inner_nc
        else:
            up_in_channels = inner_nc * 3

        up_conv = spectral_norm(
            nn.ConvTranspose2d(up_in_channels, outer_nc, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)

        if outermost:
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            # No submodule and no norm on the down path at the innermost level.
            layers = [down_act, down_conv, up_act, up_conv, up_bn]
        else:
            # InnerCos sits right before the shift layer so that its latent gradient
            # reaches the features that feed the shift.
            layers = [down_act, down_conv, down_bn, submodule,
                      up_act, inner_cos, shift, up_conv, up_bn]

        self.model = nn.Sequential(*layers)
    def __init__(self, outer_nc, inner_nc, innerCos_list=None, shift_list=None, mask_global=None, input_nc=None, opt=None,\
                 submodule=None, shift_layer=False, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, layer_to_last=3):
        """Assemble one level of an Inception-style U-Net, optionally with a shift layer.

        The finished layer stack is stored in ``self.model`` (an ``nn.Sequential``):

        * outermost: Conv2d -> submodule -> ReLU -> ConvTranspose2d(inner_nc * 2 -> outer_nc) -> Tanh
        * innermost: LeakyReLU -> InceptionDown -> ReLU -> InceptionUp(inner_nc -> outer_nc) -> norm
        * middle:    LeakyReLU -> InceptionDown -> norm -> submodule -> ReLU -> InnerCos
                     -> InnerShiftTriple -> InceptionUp(inner_nc * 3 -> outer_nc) -> norm
                     [-> Dropout(0.5) when ``use_dropout``]

        ``outermost`` takes precedence when both position flags are set.

        When ``shift_layer`` is true, one InnerShiftTriple and one InnerCos are created.
        Both receive ``mask_global`` via ``set_mask`` and are appended to ``shift_list`` /
        ``innerCos_list``. They are only inserted into ``self.model`` for middle blocks.

        Args:
            outer_nc: output-channel argument of the upsampling layer; also the default input width.
            inner_nc: output-channel argument of the downsampling layer.
            innerCos_list: list that receives the InnerCos layer (required with ``shift_layer``).
            shift_list: list that receives the InnerShiftTriple layer (required with ``shift_layer``).
            mask_global: mask handed to both layers' ``set_mask``.
            input_nc: input channels of the downsampling layer; ``None`` means ``outer_nc``.
            opt: options object (required with ``shift_layer``). Reads shift_sz, stride,
                mask_thred, triple_weight, strength and skip.
            submodule: next-inner level (not used by the innermost block).
            shift_layer: create and register the shift / latent-constraint layers.
                A middle block needs this, because its InceptionUp expects inner_nc * 3 channels.
            outermost / innermost: position of this level in the U-Net.
            norm_layer: normalization factory, called with a channel count.
            use_dropout: append ``nn.Dropout(0.5)`` to middle blocks.
            layer_to_last: forwarded to both InnerShiftTriple and InnerCos.

        Raises:
            ValueError: if a middle block is requested without ``shift_layer``, or if
                ``shift_layer`` is set but ``opt``, ``shift_list`` or ``innerCos_list`` is None.
        """
        super(InceptionShiftUnetSkipConnectionBlock, self).__init__()

        # Validate up front. Previously a middle block without a shift layer failed late
        # with an UnboundLocalError on `innerCosBefore`/`shift`, after building all modules.
        if not (outermost or innermost) and not shift_layer:
            raise ValueError(
                "InceptionShiftUnetSkipConnectionBlock: a middle block (neither outermost "
                "nor innermost) requires shift_layer=True")
        if shift_layer and (opt is None or shift_list is None or innerCos_list is None):
            raise ValueError(
                "InceptionShiftUnetSkipConnectionBlock: shift_layer=True requires opt, "
                "shift_list and innerCos_list")

        self.outermost = outermost
        if input_nc is None:
            input_nc = outer_nc

        if shift_layer:
            # Shift layer working on the inner_nc-channel features of this level.
            shift = InnerShiftTriple(opt.shift_sz,
                                     opt.stride,
                                     opt.mask_thred,
                                     opt.triple_weight,
                                     layer_to_last=layer_to_last)

            shift.set_mask(mask_global)
            shift_list.append(shift)

            # Latent-constraint layer. It must use the same layer_to_last as the shift layer
            # (it was hard-coded to 3 before), and it needs the mask too.
            innerCosBefore = InnerCos(strength=opt.strength,
                                      skip=opt.skip,
                                      layer_to_last=layer_to_last)
            innerCosBefore.set_mask(mask_global)
            innerCos_list.append(innerCosBefore)

        # Earlier equivalent: nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downconv = InceptionDown(input_nc, inner_nc)

        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        # Positions differ mainly in the upsampling layer; the outermost ends in Tanh.
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)
            # NOTE(review): this replaces the InceptionDown built above, which is discarded
            # for the outermost block (its construction is kept so parameter-init order
            # stays unchanged).
            downconv = nn.Conv2d(input_nc,
                                 inner_nc,
                                 kernel_size=4,
                                 stride=2,
                                 padding=1)

            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Innermost: upsampling input is inner_nc (no skip concatenation), no submodule,
            # and no norm on the down path.
            upconv = InceptionUp(inner_nc, outer_nc)

            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            # Middle block (shift_layer is guaranteed by the validation above). The shift
            # triple widens the features, hence inner_nc * 3 rather than inner_nc * 2.
            upconv = InceptionUp(inner_nc * 3, outer_nc)

            down = [downrelu, downconv, downnorm]
            up = [uprelu, innerCosBefore, shift, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)