Esempio n. 1
0
    def __init__(self,
                 ftr_model,
                 d_model,
                 im_size,
                 input_patch,
                 output_patch,
                 n_frames,
                 bw=False,
                 denoise_model=None,
                 batch_size=8):
        """Transformer for denoising.

        Args:
            ftr_model: feature model, stored unchanged on ``self.ftr_model``.
            d_model: channel width shared by the preprocessing convs, the
                Performer and the attention layer.
            im_size: int, the length and width of the square input image.
            input_patch: kernel size of the input ``nn.Unfold``.
            output_patch: stride of the input ``nn.Unfold``.
            n_frames: number of frames; scales the conv input channels
                (``d_model * n_frames``) and is passed to ``ClBlockLoss`` as
                the number of transforms.
            bw: if True, the attention value dimension and the main conv
                output use 1 channel instead of 3.
            denoise_model: optional model stored on ``self.denoise_model``.
            batch_size: batch size passed to ``ClBlockLoss``.
        """
        super(TransformerNetwork32, self).__init__()
        self.d_model = d_model
        self.ftr_model = ftr_model
        # One-time debug output of the channel counts used below.
        print(d_model * n_frames, d_model)

        # -- 3x3 convs with stride 1 / padding 1 keep the spatial size --
        stacked_channels = d_model * n_frames
        self.denoise_model_preproc_a = nn.Conv2d(stacked_channels,
                                                 stacked_channels,
                                                 kernel_size=3,
                                                 stride=1,
                                                 padding=1)
        self.denoise_model_preproc_b = nn.Conv2d(d_model,
                                                 3,
                                                 kernel_size=3,
                                                 stride=1,
                                                 padding=1)
        self.denoise_model = denoise_model
        # Presumably a noise std on a [0, 1] intensity scale -- TODO confirm.
        self.std = 5. / 255

        self.perform = Performer(d_model, 4, 4)
        self.clusters = np.load(
            f"{settings.ROOT_PATH}/data/kmeans_centers.npy")

        # -- contrastive learning loss (one transform per frame) --
        hyperparams = edict()
        hyperparams.temperature = 0.1
        self.simclr_loss = ClBlockLoss(hyperparams, n_frames, batch_size)

        # -- single-head attention; values are image channels --
        nhead = 1
        dropout = 0.1
        out_chann = 1 if bw else 3
        self.attn = nn.MultiheadAttention(d_model,
                                          nhead,
                                          dropout,
                                          vdim=out_chann)
        self.input_patch, self.output_patch = input_patch, output_patch
        self.im_size = im_size

        # -- conv settings; maintain imsize -- ( A in, A out )
        self.conv = nn.Sequential(
            nn.Conv2d(stacked_channels,
                      out_chann,
                      kernel_size=3,
                      stride=1,
                      padding=1))

        # -- 1x1 convs on 3-channel maps --
        self.end_conv = nn.Sequential(nn.Conv2d(3, 3, 1), nn.LeakyReLU(),
                                      nn.Conv2d(3, 3, 1))

        # -- patch extraction: pad by half of the input/output patch gap and
        #    step by one output patch --
        unfold_padding = (input_patch - output_patch) // 2
        self.unfold_input = nn.Unfold(kernel_size=input_patch,
                                      dilation=1,
                                      padding=unfold_padding,
                                      stride=output_patch)
Esempio n. 2
0
    def __init__(self, ftr_model, d_model, im_size, input_patch, output_patch, n_frames, bw=False, denoise_model = None, batch_size = 8):
        """Transformer for denoising (three stacked attention layers on 3-channel tokens).

        Args:
            ftr_model: feature model; stored inside a plain dict under the
                key ``'spoof'`` on ``self.ftr_model``.
            d_model: channel width for the preprocessing convs and the
                input size of the ``color_code`` linear layer.
            im_size: int, the length and width of the square input image.
            input_patch: kernel size of the input ``nn.Unfold``.
            output_patch: stride of the input ``nn.Unfold``.
            n_frames: number of frames; scales the conv channels, the
                positional-encoder length and the number of transforms
                given to ``ClBlockLoss``.
            bw: not used by this constructor; kept so the signature stays
                compatible with the sibling networks.
            denoise_model: optional model stored on ``self.denoise_model``.
            batch_size: batch size passed to ``ClBlockLoss``.
        """
        super(TransformerNetwork32_dip_v2, self).__init__()
        self.d_model = d_model
        # Wrapped in a dict; presumably so nn.Module does not register
        # ftr_model as a submodule -- TODO confirm.
        self.ftr_model = {'spoof':ftr_model}
        # One-time debug output of the channel counts used below.
        print(d_model*n_frames,d_model)

        # -- 3x3 convs with stride 1 / padding 1 keep the spatial size --
        stacked_channels = d_model * n_frames
        self.denoise_model_preproc_a = nn.Conv2d(stacked_channels, stacked_channels, kernel_size=3, stride=1, padding=1)
        self.denoise_model_preproc_b = nn.Conv2d(d_model, 3, kernel_size=3, stride=1, padding=1)
        self.denoise_model = denoise_model
        # Presumably a noise std on a [0, 1] intensity scale -- TODO confirm.
        self.std = 5./255

        # -- behaviour switches read elsewhere in the class --
        self.use_pos_enc = True
        self.use_pos_enc_values = False
        self.use_all_loss = False
        self.all_loss_v = "v1"

        # -- maps d_model features to 3 channels --
        self.color_code = nn.Linear(d_model,3)

        self.clusters = np.load(f"{settings.ROOT_PATH}/data/kmeans_centers.npy")

        # -- contrastive learning loss (one transform per frame) --
        hyperparams = edict()
        hyperparams.temperature = 0.1
        self.simclr_loss = ClBlockLoss(hyperparams, n_frames, batch_size)

        # -- attention stack operating on 3-dimensional embeddings --
        attn_dim = 3
        nhead = 1
        dropout = 0.0
        self.pos_enc = PositionalEncoder(attn_dim,32*32*n_frames)
        # NOTE(review): `qkv_same_params` is not a keyword of the stock
        # torch.nn.MultiheadAttention; this presumably relies on a patched
        # or project-local variant that keeps separate q/k/v weights -- confirm.
        self.attn_1 = nn.MultiheadAttention(attn_dim,nhead,dropout,qkv_same_params=False)
        self.attn_2 = nn.MultiheadAttention(attn_dim,nhead,dropout,qkv_same_params=False)
        self.attn_3 = nn.MultiheadAttention(attn_dim,nhead,dropout,qkv_same_params=False)
        # Plain list for indexed access; the layers themselves are already
        # registered through the attn_1/attn_2/attn_3 attributes.
        self.attn = [self.attn_1,self.attn_2,self.attn_3]

        # -- keep output to be shifted inputs: the value and output
        #    projections are fixed to the identity (zero bias) and frozen --
        for attn_layer in self.attn:
            attn_layer.v_proj_weight.data = torch.eye(attn_dim)
            attn_layer.v_proj_weight.requires_grad_(False)
            attn_layer.out_proj.weight.data = torch.eye(attn_dim)
            attn_layer.out_proj.bias.data = torch.zeros_like(attn_layer.out_proj.bias.data)
            attn_layer.out_proj.requires_grad_(False)

        self.input_patch,self.output_patch = input_patch,output_patch
        self.im_size = im_size

        # -- 1x1 convs on 3-channel maps, followed by a small UNet --
        self.end_conv = nn.Sequential(nn.Conv2d(3,3,1),nn.LeakyReLU(),nn.Conv2d(3,3,1))
        self.unet = UNet_small(3)

        # -- patch extraction: pad by half of the input/output patch gap and
        #    step by one output patch --
        unfold_padding = (input_patch - output_patch) // 2
        self.unfold_input = nn.Unfold(kernel_size=input_patch, dilation=1, padding=unfold_padding, stride=output_patch)
Esempio n. 3
0
def train_loop(cfg, model_disc, model_rec, model_noise, optimizer_noise,
               optimizer_rec, optimizer_disc, model_simclr, criterion,
               train_loader, epoch, num_epochs, fixed_noise):
    """Run one pass over ``train_loader`` of adversarial denoiser training.

    Each batch first updates the discriminator on noisy inputs versus
    re-noised reconstructions (with a gradient penalty), then updates the
    reconstruction and noise models with a reconstruction loss and,
    optionally, the discriminator score. Which models are stepped on a given
    batch is decided by ``get_disc_update_bool``, ``get_rec_update_bool`` and
    ``get_noisy_update_bool``.

    Args:
        cfg: config object. Reads ``device``, ``hyper_params``, ``N``,
            ``batch_size`` and ``use_simclr``; ``log_interval`` and
            ``train_gen_interval`` are overwritten on it (3 and 500).
        model_disc: discriminator.
        model_rec: reconstruction model.
        model_noise: noise model, called as ``model_noise(fake, residual)``.
        optimizer_noise: optimizer stepped when the noise model is updated.
        optimizer_rec: optimizer stepped when the reconstruction is updated.
        optimizer_disc: optimizer stepped when the discriminator is updated.
        model_simclr: embedding model; put in eval mode and frozen here.
        criterion: unused; kept for signature compatibility.
        train_loader: iterable of ``(noisy_imgs, target)`` batches where the
            first two dims of ``noisy_imgs`` are ``(N, BS)``.
        epoch: current epoch index.
        num_epochs: total number of epochs.
        fixed_noise: fixed input passed to ``model_rec`` for image snapshots.

    Returns:
        ``(D_losses, G_losses, img_list)``: per-batch discriminator and
        generator loss tensors (detached), and a list of image grids.
    """
    device = cfg.device
    # The intervals are forced on the caller's cfg on every call.
    cfg.log_interval = 3
    cfg.train_gen_interval = 500

    simclr_loss = ClBlockLoss(cfg.hyper_params, 2 * cfg.N, cfg.batch_size)
    model_simclr.eval()
    for p in model_simclr.parameters():
        p.requires_grad = False

    # "all": the discriminator sees embeddings; any other enabled value: the
    # embeddings are only used for the reconstruction loss.
    use_simclr_all = cfg.use_simclr == "all"
    use_simclr_rec = cfg.use_simclr not in ("none", False)
    emb_shape = (2, 32, 32)

    img_list = []
    D_losses = []
    G_losses = []
    one = torch.FloatTensor([1.]).to(device)
    mone = one * -1.

    for batch_idx, (noisy_imgs, _target) in enumerate(train_loader):

        # -- setup --
        N, BS = noisy_imgs.shape[:2]
        frame_view = tuple(noisy_imgs.shape)  # (N, BS, ...)
        flat_view = (N * BS, ) + frame_view[2:]  # frames folded into batch
        flat_emb_view = (N * BS, ) + emb_shape
        block_emb_view = (N, BS) + emb_shape

        # ----------------------------
        # (1) Maximize Discriminator
        # ----------------------------
        for p in model_disc.parameters():
            p.requires_grad = True

        model_disc.zero_grad()
        noisy_imgs = noisy_imgs.to(device).view(flat_view)
        disc_update = get_disc_update_bool(batch_idx, epoch)

        # -- (i) real images
        noisy_input = noisy_imgs
        if use_simclr_all:
            # NOTE(review): model_simclr receives the flat view here but the
            # (N, BS, ...) view at every other call site -- confirm which
            # layout it expects.
            noisy_emb, _ = model_simclr(noisy_imgs)
            noisy_emb = noisy_emb.view(flat_emb_view)
            noisy_input = noisy_emb.detach()
        output = model_disc(noisy_input).view(-1)
        err_disc_real = output.mean(0).view(1)
        if disc_update:
            err_disc_real.backward(one)
        D_x = output.mean().item()

        # -- (ii) fake images
        fake = model_rec(noisy_imgs).view(flat_view)
        fake_noisy = model_noise(fake, fake - noisy_imgs)
        fake_input = fake_noisy
        if use_simclr_all:
            fake_emb, _ = model_simclr(fake_noisy.view(frame_view))
            fake_emb = fake_emb.view(flat_emb_view)
            fake_input = fake_emb.detach()
        output = model_disc(fake_input).view(-1)
        err_disc_fake = output.mean(0).view(1)
        if disc_update:
            # NOTE(review): when use_simclr != "all", fake_input shares its
            # autograd graph with the generator step below. If disc_update
            # and rec_update can both be true for the same batch, this
            # backward frees that graph -- confirm the schedules never
            # overlap, or detach fake_noisy for the discriminator.
            grad_penalty = calc_gradient_penalty(cfg, model_disc, noisy_input,
                                                 fake_input)
            grad_penalty.backward()
            err_disc_fake.backward(mone)
        D_G_z1 = output.mean().item()
        error_disc = err_disc_real - err_disc_fake
        if disc_update:
            optimizer_disc.step()

        # -----------------------
        # (2) Maximize Generator
        # -----------------------
        for p in model_disc.parameters():
            p.requires_grad = False

        # Zero placeholders, reported when the noise model is not updated.
        error_gen = (error_disc - error_disc)
        D_G_z2 = 0.0

        model_rec.zero_grad()
        model_noise.zero_grad()

        noisy_update = get_noisy_update_bool(batch_idx, epoch)
        # A noise-model update always comes with a reconstruction update.
        rec_update = get_rec_update_bool(batch_idx, epoch) or noisy_update

        gen_loss = 0
        if rec_update:
            # include reconstruction loss.
            if use_simclr_rec:
                if not use_simclr_all:
                    noisy_emb, _ = model_simclr(noisy_imgs.view(frame_view))
                    fake_emb, _ = model_simclr(fake_noisy.view(frame_view))
                # Both embeddings are viewed as (N, BS, 2, 32, 32) before
                # concatenation; previously only fake_emb was reshaped in the
                # "all" mode, so torch.cat got tensors of different rank.
                embs = torch.cat([
                    fake_emb.view(block_emb_view),
                    noisy_emb.view(block_emb_view)
                ],
                                 dim=0)
                rec_loss = simclr_loss(embs)
            else:
                # Compare each reconstructed frame with the next noisy frame
                # (cyclic shift by one along the frame axis).
                offset_idx = [(i + 1) % N for i in range(N)]
                fake_offset = fake.view(frame_view)[offset_idx]
                rec_loss = F.mse_loss(fake_offset,
                                      noisy_imgs.view(frame_view))
            gen_loss = rec_loss.view(1)

        if noisy_update:
            # In "all" mode the discriminator gets the flat (N*BS, 2, 32, 32)
            # embedding, the same layout it was trained on in step (1).
            disc_in = fake_emb.view(flat_emb_view) if use_simclr_all else fake_noisy
            output = model_disc(disc_in).view(-1)
            error_gen = output.mean(0).view(1)
            gen_loss = gen_loss + error_gen
            D_G_z2 = output.mean().item()

        if rec_update:  # noisy_update implies rec_update
            gen_loss.backward()
        if noisy_update:
            optimizer_noise.step()
        if rec_update:
            optimizer_rec.step()

        if (batch_idx % cfg.log_interval) == 0:
            print(
                "[%d/%d][%d/%d] Loss_D: %.4f\tLoss_G %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f"
                % (epoch, num_epochs, batch_idx, len(train_loader),
                   error_disc.item(), error_gen.item(), D_x, D_G_z1, D_G_z2))

        # Store detached tensors so the history does not keep every batch's
        # autograd graph alive.
        D_losses.append(error_disc.detach())
        G_losses.append(error_gen.detach())

        # Check how the generator is doing by saving G's output on fixed_noise
        is_final_batch = (epoch == num_epochs - 1) and (
            batch_idx == len(train_loader) - 1)
        if (batch_idx % cfg.train_gen_interval == 0) or is_final_batch:
            with torch.no_grad():
                snapshot = model_rec(fixed_noise).detach().cpu()
            img_list.append(
                vutils.make_grid(snapshot, padding=2, normalize=True,
                                 nrow=16))

    return D_losses, G_losses, img_list
Esempio n. 4
0
    def __init__(
        self,
        net,
        image_size,
        batch_size=2,
        rand_batch_size=2,
        hidden_layer=-2,
        projection_size=256,
        projection_hidden_size=4096,
        augment_fn=None,
        augment_fn2=None,
        moving_average_decay=0.99,
        use_momentum=True,
        patch_helper=None,
    ):
        """Wrap ``net`` with online/target encoders, a predictor and a SimCLR loss.

        Args:
            net: backbone network; also determines the device of this module.
            image_size: side length of the square mock image used for the
                warm-up forward pass.
            batch_size: batch size passed to ``ClBlockLoss``.
            rand_batch_size: batch size of the mock warm-up tensor.
            hidden_layer: layer selector forwarded to ``NetWrapper``.
            projection_size: projection size for ``NetWrapper`` and the
                predictor ``MLP``.
            projection_hidden_size: hidden size for ``NetWrapper`` and the
                predictor ``MLP``.
            augment_fn: first augmentation; falls back to an empty
                ``torch.nn.Sequential``.
            augment_fn2: second augmentation; falls back to the first one.
            moving_average_decay: decay given to the ``EMA`` updater.
            use_momentum: stored on ``self.use_momentum``.
            patch_helper: stored on ``self.patch_helper``.
        """
        super().__init__()
        self.net = net
        self.patch_helper = patch_helper

        # In our function, color is an important property, so the default
        # augmentation pipeline is an empty Sequential.
        empty_pipeline = torch.nn.Sequential()
        self.augment1 = default(augment_fn, empty_pipeline)
        self.augment2 = default(augment_fn2, self.augment1)

        def null(image):
            return image

        # -- noise models; the chooser picks between Gaussian noise and
        #    leaving the image untouched --
        self.gn = AddGaussianNoise(std=75.)
        self.pn = AddPoissonNoiseBW(4.)
        self.choose_noise = PickOnlyOne([self.gn, null])

        # -- online branch --
        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
        )

        # -- target branch; the encoder starts out unset --
        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(
            projection_size,
            projection_size,
            projection_hidden_size,
        )

        # -- SimCLR loss, used because BYOL was not working well --
        loss_cfg = edict()
        loss_cfg.temperature = 1
        self.simclr_loss = ClBlockLoss(loss_cfg, 2, batch_size)

        # -- move this wrapper onto the same device as the wrapped network --
        device = get_module_device(net)
        self.to(device)

        # -- run a non-negative mock image batch through forward() to
        #    instantiate singleton parameters --
        mock_shape = (rand_batch_size, 3, image_size, image_size)
        mock_images = torch.randn(mock_shape, device=device).abs()
        self.forward(mock_images)