Exemplo n.º 1
0
    def _init_nets(self):
        """Create the reflection and transmission networks.

        Both networks are built by ``skip`` with the same 5-level
        configuration. Each is then cast with ``Module.type``: to
        ``torch.cuda.FloatTensor`` when CUDA is available, otherwise to
        ``torch.FloatTensor``.

        Sets ``self.reflection_net`` and ``self.transmission_net``.
        """
        # Use the GPU when present, otherwise the CPU, so the networks can
        # still be built on machines without CUDA.
        if torch.cuda.is_available():
            data_type = torch.cuda.FloatTensor
        else:
            data_type = torch.FloatTensor
        pad = 'reflection'
        # The second ``skip`` argument follows the first image's leading
        # dimension (presumably channel-first images -- TODO confirm).
        out_channels = self.images[0].shape[0]

        def build_net():
            # Fresh list literals on every call, so the two networks never
            # share configuration objects.
            return skip(self.input_depth,
                        out_channels,
                        num_channels_down=[8, 16, 32, 64, 128],
                        num_channels_up=[8, 16, 32, 64, 128],
                        num_channels_skip=[0, 0, 0, 4, 4],
                        upsample_mode='bilinear',
                        filter_size_down=5,
                        filter_size_up=5,
                        need_sigmoid=True,
                        need_bias=True,
                        pad=pad,
                        act_fun='LeakyReLU')

        # Keep the construction order (reflection first) in case ``skip``
        # draws its initial weights from the global RNG.
        self.reflection_net = build_net().type(data_type)
        self.transmission_net = build_net().type(data_type)
Exemplo n.º 2
0
    def _init_nets(self):
        """Create two identically configured ``skip`` networks.

        Both use a 3-level configuration with ``3`` as the second
        (presumably output-channel) argument. Each is cast with
        ``Module.type``: to ``torch.cuda.FloatTensor`` when CUDA is
        available, otherwise to ``torch.FloatTensor``.

        Sets ``self.net1`` and ``self.net2``.
        """
        # Use the GPU when present, otherwise the CPU, so the networks can
        # still be built on machines without CUDA.
        if torch.cuda.is_available():
            data_type = torch.cuda.FloatTensor
        else:
            data_type = torch.FloatTensor
        pad = 'reflection'

        def build_net():
            # Fresh list literals on every call, so the two networks never
            # share configuration objects.
            return skip(self.input_depth,
                        3,
                        num_channels_down=[8, 16, 32],
                        num_channels_up=[8, 16, 32],
                        num_channels_skip=[0, 0, 0],
                        filter_size_down=3,
                        filter_size_up=3,
                        upsample_mode='bilinear',
                        need_sigmoid=True,
                        need_bias=True,
                        pad=pad,
                        act_fun='LeakyReLU')

        # Keep the construction order (net1 first) in case ``skip`` draws
        # its initial weights from the global RNG.
        self.net1 = build_net().type(data_type)
        self.net2 = build_net().type(data_type)
Exemplo n.º 3
0
    def _init_nets(self):
        """Create the reflection, transmission and two alpha networks.

        All four networks share one 5-level ``skip`` configuration. They
        differ only in the second (presumably output-channel) argument:
        ``3`` for reflection/transmission and ``1`` for the alpha networks.
        Each is cast with ``Module.type``: to ``torch.cuda.FloatTensor``
        when CUDA is available, otherwise to ``torch.FloatTensor``.

        Sets ``self.reflection_net``, ``self.transmission_net``,
        ``self.alpha1`` and ``self.alpha2``.
        """
        data_type = (torch.cuda.FloatTensor if torch.cuda.is_available()
                     else torch.FloatTensor)

        def make(n_out):
            # Build one network with the shared configuration and cast it.
            # The list literals are re-created on every call.
            net = skip(self.input_depth, n_out,
                       num_channels_down=[8, 16, 32, 64, 128],
                       num_channels_up=[8, 16, 32, 64, 128],
                       num_channels_skip=[0, 0, 0, 4, 4],
                       upsample_mode='bilinear',
                       filter_size_down=5,
                       filter_size_up=5,
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')
            return net.type(data_type)

        # Construction order is significant if ``skip`` initialises weights
        # from the global RNG, so it is kept as-is.
        self.reflection_net = make(3)
        self.transmission_net = make(3)
        self.alpha1 = make(1)
        self.alpha2 = make(1)
Exemplo n.º 4
0
    def _init_nets(self):
        """Create the per-image clean networks, the mask network and the
        watermark network.

        Each network is cast with ``Module.type``: to
        ``torch.cuda.FloatTensor`` when CUDA is available, otherwise to
        ``torch.FloatTensor``.

        Sets:
            ``self.clean_nets``: a list with one 4-level ``skip`` network per
                entry of ``self.images``.
            ``self.mask_net``: a 3-level ``skip`` network with ``1`` as the
                second (presumably output-channel) argument.
            ``self.watermark_net``: a 4-level ``skip`` network with
                3x3 filters.
        """
        # Use the GPU when present, otherwise the CPU, so the networks can
        # still be built on machines without CUDA.
        if torch.cuda.is_available():
            data_type = torch.cuda.FloatTensor
        else:
            data_type = torch.FloatTensor
        pad = 'reflection'

        # Build every clean network first, then the mask and watermark
        # networks. This order is kept in case ``skip`` uses the global RNG.
        cleans = [
            skip(self.input_depth,
                 3,
                 num_channels_down=[8, 16, 32, 64],
                 num_channels_up=[8, 16, 32, 64],
                 num_channels_skip=[0, 0, 0, 4],
                 upsample_mode='bilinear',
                 filter_size_down=5,
                 filter_size_up=5,
                 need_sigmoid=True,
                 need_bias=True,
                 pad=pad,
                 act_fun='LeakyReLU') for _ in self.images
        ]

        self.clean_nets = [clean.type(data_type) for clean in cleans]

        mask_net = skip(self.input_depth,
                        1,
                        num_channels_down=[8, 16, 32],
                        num_channels_up=[8, 16, 32],
                        num_channels_skip=[0, 0, 4],
                        upsample_mode='bilinear',
                        filter_size_down=3,
                        filter_size_up=3,
                        need_sigmoid=True,
                        need_bias=True,
                        pad=pad,
                        act_fun='LeakyReLU')

        self.mask_net = mask_net.type(data_type)

        watermark = skip(self.input_depth,
                         3,
                         num_channels_down=[8, 16, 32, 64],
                         num_channels_up=[8, 16, 32, 64],
                         num_channels_skip=[0, 0, 0, 4],
                         upsample_mode='bilinear',
                         filter_size_down=3,
                         filter_size_up=3,
                         need_sigmoid=True,
                         need_bias=True,
                         pad=pad,
                         act_fun='LeakyReLU')

        self.watermark_net = watermark.type(data_type)
Exemplo n.º 5
0
    def _init_nets(self):
        """Create the mask network and the sound network.

        ``self.mask1_net`` is built by ``skip_mask_vec``. ``self.sound1_net``
        is built by ``skip`` with ``need_sigmoid=False`` and
        ``need_relu=True``. Both use the first image's leading dimension as
        their second argument. Each is cast with ``Module.type``: to
        ``torch.cuda.FloatTensor`` when CUDA is available, otherwise to
        ``torch.FloatTensor``.
        """
        # Use the GPU when present, otherwise the CPU, so the networks can
        # still be built on machines without CUDA.
        if torch.cuda.is_available():
            data_type = torch.cuda.FloatTensor
        else:
            data_type = torch.FloatTensor
        pad = 'reflection'
        # Presumably the channel count of channel-first inputs -- TODO confirm.
        out_channels = self.images[0].shape[0]

        mask1_net = skip_mask_vec(self.input_depth,
                                  out_channels,
                                  num_channels_down=[16, 32, 64],
                                  num_channels_up=[16, 32, 64],
                                  num_channels_skip=[0, 0, 4],
                                  upsample_mode='bilinear',
                                  filter_size_down=5,
                                  filter_size_up=5,
                                  need_bias=True,
                                  pad=pad,
                                  act_fun='LeakyReLU')

        self.mask1_net = mask1_net.type(data_type)

        sound1_net = skip(self.input_depth,
                          out_channels,
                          num_channels_down=[16, 32, 64],
                          num_channels_up=[16, 32, 64],
                          num_channels_skip=[0, 0, 4],
                          upsample_mode='bilinear',
                          filter_size_down=5,
                          filter_size_up=5,
                          need_sigmoid=False,
                          need_relu=True,
                          need_bias=True,
                          pad=pad,
                          act_fun='LeakyReLU')

        self.sound1_net = sound1_net.type(data_type)
Exemplo n.º 6
0
    def _init_nets(self):
        """Create the two clean networks and the watermark network.

        ``self.clean1_net`` and ``self.clean2_net`` share one 5-level
        ``skip`` configuration. ``self.watermark_net`` is a 3-level ``skip``
        network with ``need_sigmoid=False``. Each is cast with
        ``Module.type``: to ``torch.cuda.FloatTensor`` when CUDA is
        available, otherwise to ``torch.FloatTensor``.
        """
        # Use the GPU when present, otherwise the CPU, so the networks can
        # still be built on machines without CUDA.
        if torch.cuda.is_available():
            data_type = torch.cuda.FloatTensor
        else:
            data_type = torch.FloatTensor
        pad = 'reflection'

        def build_clean():
            # Fresh list literals on every call, so the two clean networks
            # never share configuration objects.
            return skip(self.input_depth,
                        3,
                        num_channels_down=[8, 16, 32, 64, 128],
                        num_channels_up=[8, 16, 32, 64, 128],
                        num_channels_skip=[0, 0, 0, 4, 4],
                        upsample_mode='bilinear',
                        filter_size_down=5,
                        filter_size_up=5,
                        need_sigmoid=True,
                        need_bias=True,
                        pad=pad,
                        act_fun='LeakyReLU')

        # Order (clean1, clean2, watermark) is kept in case ``skip`` draws
        # its initial weights from the global RNG.
        self.clean1_net = build_clean().type(data_type)
        self.clean2_net = build_clean().type(data_type)

        watermark = skip(self.input_depth,
                         3,
                         num_channels_down=[8, 16, 32],
                         num_channels_up=[8, 16, 32],
                         num_channels_skip=[0, 0, 4],
                         upsample_mode='bilinear',
                         filter_size_down=3,
                         filter_size_up=3,
                         need_sigmoid=False,
                         need_bias=True,
                         pad=pad,
                         act_fun='LeakyReLU')

        self.watermark_net = watermark.type(data_type)
Exemplo n.º 7
0
    os.makedirs(args.output_path)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# train prior: load the target image and turn it into a tensor with an extra
# leading axis (dimension 1 is then read as the channel count below).
img = prepare_image(args.img)

prior_input = torch.from_numpy(img)
prior_input = prior_input.unsqueeze(0).to(device)
# Accumulator for reconstructions (name suggests one per AE sample -- confirm).
rec_samples = []
for ae_sample in range(1, args.prior_samples + 1):
    ae = skip(prior_input.size(1),
              prior_input.size(1),
              num_channels_down=[8, 16, 32],
              num_channels_up=[8, 16, 32],
              num_channels_skip=[0, 0, 0],
              upsample_mode='bilinear',
              filter_size_down=3,
              filter_size_up=3,
              need_sigmoid=True,
              need_bias=True,
              pad='reflection',
              act_fun='LeakyReLU').to(device)

    # TODO: tune lr?
    optimizer = optim.Adam(ae.parameters(), lr=0.0001)
    loss_fn = torch.nn.L1Loss()

    print("--- Train AE sample % d-----" % ae_sample)
    for _ in tqdm(range(args.prior_iters)):
        optimizer.zero_grad()

        loss = loss_fn(ae(prior_input), prior_input)