def __init__(self, ftr_model, d_model, im_size, input_patch, output_patch,
             n_frames, bw=False, denoise_model=None, batch_size=8):
    """Transformer for denoising.

    Args:
        ftr_model: feature model, stored as ``self.ftr_model``.
        d_model: per-frame feature channels; ``denoise_model_preproc_a`` and
            ``conv`` take ``d_model * n_frames`` input channels.
        im_size (int): the length and width of the square input image.
        input_patch (int): kernel size of ``unfold_input``.
        output_patch (int): stride of ``unfold_input``; the unfold padding is
            ``(input_patch - output_patch) // 2``.
        n_frames (int): number of frames; also the number of transforms
            handed to the contrastive loss.
        bw (bool): if True, the attention value dim and ``conv`` output use
            1 channel, otherwise 3.
        denoise_model: optional denoiser, stored as ``self.denoise_model``.
        batch_size (int): batch size passed to ``ClBlockLoss``.
    """
    super(TransformerNetwork32, self).__init__()
    self.d_model = d_model
    self.ftr_model = ftr_model
    print(d_model * n_frames, d_model)

    # -- denoiser pre-processing convs --
    self.denoise_model_preproc_a = nn.Conv2d(d_model * n_frames,
                                             d_model * n_frames, 3, 1, 1)
    self.denoise_model_preproc_b = nn.Conv2d(d_model, 3, 3, 1, 1)
    self.denoise_model = denoise_model
    self.std = 5. / 255

    # -- attention backbone --
    dropout = 0.1
    self.perform = Performer(d_model, 4, 4)  # earlier setting: (8, 4)
    self.clusters = np.load(f"{settings.ROOT_PATH}/data/kmeans_centers.npy")

    # -- contrastive learning loss (one transform per frame) --
    hyperparams = edict()
    hyperparams.temperature = 0.1
    self.simclr_loss = ClBlockLoss(hyperparams, n_frames, batch_size)

    # Single-head attention whose values have image channels (1 or 3).
    nhead = 1
    vdim = 1 if bw else 3
    self.attn = nn.MultiheadAttention(d_model, nhead, dropout, vdim=vdim)

    self.input_patch, self.output_patch = input_patch, output_patch
    self.im_size = im_size

    # Output head: 3x3 conv, stride 1, padding 1 keeps the spatial size
    # (A in, A out). Earlier experiments used padding 0 (shrink by 2 px) or
    # stride 2 (halve the size).
    out_chann = 1 if bw else 3
    self.conv = nn.Sequential(
        nn.Conv2d(d_model * n_frames, out_chann, 3, 1, 1))
    self.end_conv = nn.Sequential(
        nn.Conv2d(3, 3, 1), nn.LeakyReLU(), nn.Conv2d(3, 3, 1))

    # Extract input_patch x input_patch windows every output_patch pixels.
    self.unfold_input = nn.Unfold(kernel_size=input_patch,
                                  dilation=1,
                                  padding=(input_patch - output_patch) // 2,
                                  stride=output_patch)
def __init__(self, ftr_model, d_model, im_size, input_patch, output_patch,
             n_frames, bw=False, denoise_model=None, batch_size=8):
    """Transformer for denoising (DIP variant, v2).

    Args:
        ftr_model: feature model; kept inside a dict under key ``'spoof'``.
        d_model: width of the incoming per-frame features (``color_code``
            maps ``d_model -> 3``).
        im_size (int): the length and width of the square input image.
        input_patch (int): kernel size of ``unfold_input``.
        output_patch (int): stride of ``unfold_input``.
        n_frames (int): number of frames; sizes the positional encoder and
            the contrastive loss.
        bw (bool): accepted for interface compatibility; unused here.
        denoise_model: optional denoiser, stored as ``self.denoise_model``.
        batch_size (int): batch size passed to ``ClBlockLoss``.
    """
    super(TransformerNetwork32_dip_v2, self).__init__()
    self.d_model = d_model
    # Wrapped in a dict, presumably so nn.Module does not register the
    # feature model as a submodule -- confirm intent.
    self.ftr_model = {'spoof': ftr_model}
    print(d_model * n_frames, d_model)

    # -- denoiser pre-processing convs --
    self.denoise_model_preproc_a = nn.Conv2d(d_model * n_frames,
                                             d_model * n_frames, 3, 1, 1)
    self.denoise_model_preproc_b = nn.Conv2d(d_model, 3, 3, 1, 1)
    self.denoise_model = denoise_model
    self.std = 5. / 255

    dropout = 0.0
    self.use_pos_enc = True
    self.use_pos_enc_values = False
    self.use_all_loss = False
    self.all_loss_v = "v1"
    self.color_code = nn.Linear(d_model, 3)

    self.clusters = np.load(f"{settings.ROOT_PATH}/data/kmeans_centers.npy")

    # -- contrastive learning loss (one transform per frame) --
    hyperparams = edict()
    hyperparams.temperature = 0.1
    self.simclr_loss = ClBlockLoss(hyperparams, n_frames, batch_size)

    # -- three single-head attention stages over 3-channel tokens --
    # NOTE(review): hard-codes 32x32 frames; presumably should track im_size.
    self.pos_enc = PositionalEncoder(3, 32 * 32 * n_frames)
    nhead = 1
    # NOTE(review): `qkv_same_params` is not a torch.nn.MultiheadAttention
    # argument; presumably `nn` is a project fork here -- confirm.
    self.attn_1 = nn.MultiheadAttention(3, nhead, dropout, qkv_same_params=False)
    self.attn_2 = nn.MultiheadAttention(3, nhead, dropout, qkv_same_params=False)
    self.attn_3 = nn.MultiheadAttention(3, nhead, dropout, qkv_same_params=False)
    # Plain list for convenient indexing; the modules are already registered
    # through the attn_1..attn_3 attributes above.
    self.attn = [self.attn_1, self.attn_2, self.attn_3]

    # Keep each stage's output a weighted mix of (shifted) input values:
    # value projection and output projection are frozen identities.
    with torch.no_grad():
        for stage in self.attn:
            stage.v_proj_weight.copy_(torch.eye(3))
            stage.v_proj_weight.requires_grad_(False)
            stage.out_proj.weight.copy_(torch.eye(3))
            stage.out_proj.bias.zero_()
            stage.out_proj.requires_grad_(False)

    self.input_patch, self.output_patch = input_patch, output_patch
    self.im_size = im_size

    # The mean acts as the denoiser in this variant, so there is no `conv` head.
    self.end_conv = nn.Sequential(
        nn.Conv2d(3, 3, 1), nn.LeakyReLU(), nn.Conv2d(3, 3, 1))
    self.unet = UNet_small(3)

    # Extract input_patch x input_patch windows every output_patch pixels.
    self.unfold_input = nn.Unfold(kernel_size=input_patch,
                                  dilation=1,
                                  padding=(input_patch - output_patch) // 2,
                                  stride=output_patch)
def train_loop(cfg, model_disc, model_rec, model_noise, optimizer_noise,
               optimizer_rec, optimizer_disc, model_simclr, criterion,
               train_loader, epoch, num_epochs, fixed_noise):
    """Run one epoch of adversarial reconstruction/noise-model training.

    Each batch does (1) a discriminator step on real noisy images vs. the
    re-noised reconstruction ``model_noise(model_rec(x), ...)`` and (2) a
    generator step combining a reconstruction loss (contrastive via the frozen
    ``model_simclr`` or MSE against the next frame) with the adversarial loss.
    Which steps actually update is decided by the ``get_*_update_bool``
    helpers.

    Note: ``cfg.log_interval`` and ``cfg.train_gen_interval`` are overwritten
    (3 and 500). ``criterion`` and the loader targets are unused.

    Returns:
        (D_losses, G_losses, img_list): per-batch discriminator/generator
        losses as detached 1-element tensors, and image grids of
        ``model_rec(fixed_noise)`` sampled every ``train_gen_interval``
        batches and on the final batch of the last epoch.
    """
    device = cfg.device
    # need: cfg.log_interval, cfg.train_gen_interval
    cfg.log_interval = 3
    cfg.train_gen_interval = 500

    simclr_all = cfg.use_simclr == "all"
    simclr_any = cfg.use_simclr not in ("none", False)
    simclr_loss = ClBlockLoss(cfg.hyper_params, 2 * cfg.N, cfg.batch_size)
    model_simclr.eval()
    for p in model_simclr.parameters():
        p.requires_grad = False

    img_list, D_losses, G_losses = [], [], []
    one = torch.ones(1, device=device)
    mone = -one

    for batch_idx, (noisy_imgs, _target) in enumerate(train_loader):
        # -- setup: noisy_imgs is (N frames, BS, ...) --
        N, BS = noisy_imgs.shape[:2]
        p_shape = noisy_imgs.shape[2:]
        p_view = (N * BS,) + p_shape       # frames folded into the batch
        dshape = (N, BS) + p_shape
        n_view = noisy_imgs.shape          # layout model_simclr is fed
        emb_view = (N * BS, 2, 32, 32)
        emb_seq_view = (N, BS, 2, 32, 32)

        # ----------------------------
        # (1) Discriminator step
        # ----------------------------
        for p in model_disc.parameters():
            p.requires_grad = True
        model_disc.zero_grad()
        noisy_imgs = noisy_imgs.to(device).view(p_view)
        disc_update = get_disc_update_bool(batch_idx, epoch)

        # -- (i) real images
        noisy_input = noisy_imgs
        if simclr_all:
            # Same input layout as every other model_simclr call.
            noisy_emb, _ = model_simclr(noisy_imgs.view(n_view))
            noisy_emb = noisy_emb.view(emb_view)
            noisy_input = noisy_emb.detach()
        output = model_disc(noisy_input).view(-1)
        err_disc_real = output.mean(0).view(1)
        if disc_update:
            err_disc_real.backward(one)
        D_x = output.mean().item()

        # -- (ii) fake images
        fake = model_rec(noisy_imgs).view(p_view)
        fake_noisy = model_noise(fake, fake - noisy_imgs)
        if simclr_all:
            fake_emb, _ = model_simclr(fake_noisy.view(n_view))
            fake_emb = fake_emb.view(emb_view)
            fake_input = fake_emb
        else:
            fake_input = fake_noisy
        # Detach: the discriminator backward must not free the generator graph,
        # which step (2) backpropagates through again.
        fake_input = fake_input.detach()
        output = model_disc(fake_input).view(-1)
        err_disc_fake = output.mean(0).view(1)
        if disc_update:
            grad_penalty = calc_gradient_penalty(cfg, model_disc,
                                                 noisy_input, fake_input)
            grad_penalty.backward()
            err_disc_fake.backward(mone)
        D_G_z1 = output.mean().item()
        error_disc = err_disc_real - err_disc_fake
        if disc_update:
            optimizer_disc.step()

        # ----------------------------
        # (2) Generator step
        # ----------------------------
        for p in model_disc.parameters():
            p.requires_grad = False
        error_gen = torch.zeros_like(error_disc)  # reported when no adv. update
        D_G_z2 = 0.0

        model_rec.zero_grad()
        model_noise.zero_grad()
        noisy_update = get_noisy_update_bool(batch_idx, epoch)
        # An adversarial update always comes with a reconstruction update.
        rec_update = get_rec_update_bool(batch_idx, epoch) or noisy_update

        gen_loss = 0
        if rec_update:
            if simclr_any:
                if not simclr_all:
                    noisy_emb, _ = model_simclr(noisy_imgs.view(n_view))
                    noisy_emb = noisy_emb.view(emb_view)
                    fake_emb, _ = model_simclr(fake_noisy.view(n_view))
                    fake_emb = fake_emb.view(emb_view)
                # Separate per-frame views; fake_emb itself stays 4-D for
                # the discriminator below.
                embs = torch.cat([fake_emb.view(emb_seq_view),
                                  noisy_emb.view(emb_seq_view)], dim=0)
                rec_loss = simclr_loss(embs)
            else:
                # Reconstruction of frame i+1 (cyclically) vs noisy frame i.
                offset_idx = [(i + 1) % N for i in range(N)]
                fake_offset = fake.view(dshape)[offset_idx]
                rec_loss = F.mse_loss(fake_offset, noisy_imgs.view(dshape))
            gen_loss = rec_loss.view(1)

        if noisy_update:
            adv_input = fake_emb if simclr_all else fake_noisy
            output = model_disc(adv_input).view(-1)
            error_gen = output.mean(0).view(1)
            gen_loss = gen_loss + error_gen
            D_G_z2 = output.mean().item()

        if rec_update:  # covers noisy_update as well
            gen_loss.backward()
        if noisy_update:
            optimizer_noise.step()
        if rec_update:
            optimizer_rec.step()

        if batch_idx % cfg.log_interval == 0:
            print(
                "[%d/%d][%d/%d] Loss_D: %.4f\tLoss_G %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f"
                % (epoch, num_epochs, batch_idx, len(train_loader),
                   error_disc.item(), error_gen.item(), D_x, D_G_z1, D_G_z2))

        # Store detached values so each batch's autograd graph can be freed.
        D_losses.append(error_disc.detach())
        G_losses.append(error_gen.detach())

        # Check how the generator is doing by saving G's output on fixed_noise
        last_batch = (epoch == num_epochs - 1
                      and batch_idx == len(train_loader) - 1)
        if batch_idx % cfg.train_gen_interval == 0 or last_batch:
            with torch.no_grad():
                samples = model_rec(fixed_noise).detach().cpu()
            img_list.append(
                vutils.make_grid(samples, padding=2, normalize=True, nrow=16))

    return D_losses, G_losses, img_list
def __init__(
    self,
    net,
    image_size,
    batch_size=2,
    rand_batch_size=2,
    hidden_layer=-2,
    projection_size=256,
    projection_hidden_size=4096,
    augment_fn=None,
    augment_fn2=None,
    moving_average_decay=0.99,
    use_momentum=True,
    patch_helper=None,
):
    """Wrap ``net`` with online/target encoders, a predictor and a SimCLR loss.

    Colour is treated as an important property, so the default augmentation
    is an empty ``torch.nn.Sequential``. The usual colour jitter, grayscale,
    flip, blur, resized-crop and normalisation augmentations are deliberately
    left out.

    Args:
        net: backbone network; its device decides where this module lives.
        image_size: side length of the square mock input used to build
            lazily-created layers.
        batch_size: batch size given to the contrastive loss.
        rand_batch_size: batch size of the mock input.
        hidden_layer: layer of ``net`` that ``NetWrapper`` projects from.
        projection_size / projection_hidden_size: projector/predictor widths.
        augment_fn / augment_fn2: optional augmentations; the second falls
            back to the first.
        moving_average_decay: decay handed to the EMA updater.
        use_momentum: stored as ``self.use_momentum``.
        patch_helper: stored as ``self.patch_helper``.
    """
    super().__init__()
    self.net = net
    self.patch_helper = patch_helper

    identity_aug = torch.nn.Sequential()
    self.augment1 = default(augment_fn, identity_aug)
    self.augment2 = default(augment_fn2, self.augment1)

    def null(image):
        return image

    gaussian, poisson = AddGaussianNoise(std=75.), AddPoissonNoiseBW(4.)
    self.gn = gaussian
    self.pn = poisson
    # Either add Gaussian noise or leave the image untouched.
    self.choose_noise = PickOnlyOne([gaussian, null])

    self.online_encoder = NetWrapper(
        net, projection_size, projection_hidden_size, layer=hidden_layer)
    self.use_momentum = use_momentum
    self.target_encoder = None
    self.target_ema_updater = EMA(moving_average_decay)
    self.online_predictor = MLP(
        projection_size, projection_size, projection_hidden_size)

    # SimCLR loss is used since BYOL was not training well.
    clr_hparams = edict()
    clr_hparams.temperature = 1
    self.simclr_loss = ClBlockLoss(clr_hparams, 2, batch_size)

    # Move the wrapper onto the backbone's device.
    device = get_module_device(net)
    self.to(device)

    # One forward pass on a non-negative random batch instantiates any
    # lazily-built singleton parameters.
    mock_batch = torch.randn(
        rand_batch_size, 3, image_size, image_size, device=device).abs()
    self.forward(mock_batch)