def test_model(self, test_model_dir):
    # test or validation
    self.ckp.write_log('\nEvaluation:')
    self.ckp.add_log(torch.zeros(1))  # (torch.zeros(1, len(self.scale)))
    self.model.eval()
    with torch.no_grad():
        eval_acc = 0
        for im_idx, im_dict in enumerate(self.loader_results, 1):
            lr = im_dict['im_lr']
            hr = im_dict['im_hr']
            lr, hr = self.prepare([lr, hr])
            sr, sr_ = self.model(lr)
            # sr = torch.clamp(sr, 0, 1)
            eval_acc += errors.find_psnr(sr, hr)

            # Save the super-resolved output as an 8-bit TIFF.
            im_sr = np.float64(
                normalise01(sr[0, :, :, :].permute(1, 2, 0).cpu().numpy()))
            im_sr = im_sr / im_sr.max()
            im_sr = np.uint8(im_sr * 255)
            imsave(test_model_dir + '/im_sr_{}.tiff'.format(im_idx + 275),
                   im_sr)
            print("Image: {}".format(im_idx))

        psnr = eval_acc / len(self.loader_results)
        return psnr
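# errors.find_psnr is defined elsewhere in the repository. For reference, a
# minimal PSNR helper with the same call signature could look like the sketch
# below; the helper name, the [0, 1] value range and the whole-batch MSE are
# assumptions, not the project's actual implementation (it reuses the module's
# existing torch import).
def _find_psnr_sketch(sr, hr, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between two image batches in [0, 1]."""
    mse = torch.mean((sr - hr) ** 2)
    if mse == 0:
        return float('inf')
    return (10.0 * torch.log10(max_val ** 2 / mse)).item()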
def test(self):
    # test or validation
    epoch = self.epoch()
    self.ckp.write_log('\nEvaluation:')
    scale = 2
    self.ckp.add_log(torch.zeros(1))  # (torch.zeros(1, len(self.scale)))
    self.model.eval()
    timer_test = utility.timer()
    with torch.no_grad():
        eval_acc = 0    # accumulated PSNR
        valid_loss = 0  # total loss, computed with the training objective
        for im_idx, im_dict in enumerate(self.loader_test, 1):
            lr = im_dict['im_lr']
            hr = im_dict['im_hr']
            lr, hr = self.prepare([lr, hr])
            sr, sr_ = self.model(lr)
            sr = torch.clamp(sr, 0, 1)
            sr_ = torch.clamp(sr_, 0, 1)
            self.lr_valid = np.average(
                lr[0, :, :, :].permute(1, 2, 0).cpu().numpy(), axis=2)
            self.hr_valid = hr[0, :, :, :].permute(1, 2, 0).cpu().numpy()
            self.sr_valid = sr[0, :, :, :].permute(
                1, 2, 0).cpu().detach().numpy()
            # sr = utility.quantize(sr, self.args.rgb_range)
            save_list = [sr]
            # do some processing on sr, hr or modify find_psnr()
            eval_acc += errors.find_psnr(sr, hr)
            save_list.extend([lr, hr])
            loss = self.loss.valid_loss(sr, sr_, hr)
            valid_loss += loss.item()
            # save the sr images of the last epoch
            if self.args.save_results and epoch == self.args.epochs:
                self.ckp.save_results("image_{}_sr".format(im_idx),
                                      save_list, scale)

        self.ckp.log_accuracy[-1] = valid_loss / len(self.loader_test)
        self.ckp.log[-1] = eval_acc / len(self.loader_test)
        best = self.ckp.log.max(0)
        self.ckp.write_log(
            '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                self.args.data_test, scale, self.ckp.log[-1],
                best[0].item(), epoch))

    # ckp.save stores the loss, the model and the loss plot; it is defined in
    # the Checkpoint class.
    self.ckp.write_log('Total time: {:.2f}s\n'.format(timer_test.toc()),
                       refresh=True)
    if not self.args.test_only:
        # self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
        self.ckp.save(self, epoch, is_best=False)
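# self.loss.valid_loss(sr, sr_, hr) comes from the Loss class, which is not
# part of this excerpt. A minimal sketch of a validation loss with the same
# signature, assuming it mirrors the smooth-L1 objective used during
# meta-training and sums a term for the final output sr and the intermediate
# output sr_; both the weighting and the role of sr_ are assumptions (the
# sketch reuses the module's torch.nn.functional import as F).
def _valid_loss_sketch(sr, sr_, hr):
    """Hypothetical validation loss on the final and intermediate outputs."""
    return F.smooth_l1_loss(sr, hr) + F.smooth_l1_loss(sr_, hr)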
def main():
    ck = util.checkpoint(args)

    # Seed everything for reproducibility.
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    ck.write_log(str(args))
    # t = str(int(time.time()))
    # t = args.save_name
    # os.mkdir('./{}'.format(t))

    # Learner configuration: ('conv2d', [ch_out, ch_in, k, k, stride, padding])
    config = [('conv2d', [32, 16, 3, 3, 1, 1]),
              ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('relu', [True]),
              ('+1', [True]),
              ('conv2d', [3, 32, 3, 3, 1, 1])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count the trainable tensors.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    ck.write_log(str(maml))
    ck.write_log('Total trainable tensors: {}'.format(num))

    # batchsz here means the total number of episodes.
    DL_MSI = dl.StereoMSIDatasetLoader(args)
    db = DL_MSI.train_loader
    dv = DL_MSI.valid_loader

    psnr = []
    l1_loss = []
    psnr_valid = []
    for epoch, (spt_ms, spt_rgb, qry_ms, qry_rgb) in enumerate(db):
        if epoch >= args.epoch:
            break
        spt_ms, spt_rgb, qry_ms, qry_rgb = (spt_ms.to(device),
                                            spt_rgb.to(device),
                                            qry_ms.to(device),
                                            qry_rgb.to(device))

        # The optimisation itself is carried out inside the Meta learner.
        accs, train_loss = maml(spt_ms, spt_rgb, qry_ms, qry_rgb, epoch)
        maml.scheduler.step()

        if epoch % args.print_every == 0:
            log_epoch = 'epoch: {} \ttraining acc: {}'.format(epoch, accs)
            ck.write_log(log_epoch)
            psnr.append(accs)
            l1_loss.append(train_loss)
            ck.plot_loss(psnr, l1_loss, epoch, args.print_every)

        if epoch % args.save_every == 0:
            with torch.no_grad():
                ck.save(maml.net, maml.meta_optim, epoch)
                eval_psnr = 0  # accumulated validation PSNR
                for idx, (valid_ms, valid_rgb) in enumerate(dv):
                    valid_ms, valid_rgb = prepare([valid_ms, valid_rgb])
                    sr_rgb = maml.net(valid_ms)
                    sr_rgb = torch.clamp(sr_rgb, 0, 1)
                    eval_psnr += errors.find_psnr(valid_rgb, sr_rgb)
                # Average over the 25 validation images and plot the PSNR curve.
                psnr_valid.append(eval_psnr / 25)
                ck.plot_psnr(psnr_valid, epoch, args.save_every)
                ck.write_log('Max PSNR is: {}'.format(max(psnr_valid)))
                imsave('./{}/validation/img_{}.png'.format(ck.dir, epoch),
                       np.uint8(sr_rgb[0, :, :, :].permute(
                           1, 2, 0).cpu().detach().numpy() * 255))

    ck.done()
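# The validation loop above calls a module-level prepare() that is not shown
# in this excerpt. A minimal sketch, assuming it simply moves a list of
# tensors onto the GPU, mirroring the .to(device) calls used on the training
# batch; the body (and the sketch name) is an assumption, not the project's
# actual helper.
def _prepare_sketch(tensors, device=torch.device('cuda')):
    """Move a list of tensors to the target device."""
    return [t.to(device) for t in tensors]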
def forward(self, spt_ms, spt_rgb, qry_ms, qry_rgb, epoch):
    """
    task_num: number of tasks per batch.
    setsz:    number of support (training) pairs per task.
    querysz:  number of query (test) pairs per task.
    :param spt_ms:  [task_num, setsz, 16, h, w]
    :param spt_rgb: [task_num, setsz, 3, h, w]
    :param qry_ms:  [task_num, querysz, 16, h, w]
    :param qry_rgb: [task_num, querysz, 3, h, w]
    :return: query PSNR after the last update step and the meta-loss
    """
    spt_ms = spt_ms.squeeze()
    spt_rgb = spt_rgb.squeeze()
    qry_ms = qry_ms.squeeze()
    qry_rgb = qry_rgb.squeeze()
    task_num, setsz, c, h, w = spt_ms.size()
    _, querysz, c, _, _ = qry_ms.size()

    # losses_q[k] is the query loss after k inner-loop (gradient descent) steps
    losses_q = [0 for _ in range(self.update_step + 1)]
    # corrects[k] is the query PSNR after k inner-loop steps
    corrects = [0 for _ in range(self.update_step + 1)]

    # Halve the inner-loop learning rate every 2000 epochs up to epoch 4000.
    if epoch < 4001:
        if (epoch % 2000 == 0) and (epoch > 1):
            decay = 2  # (epoch // 5) + 1
            self.update_lr = self.update_lr / decay
            print('inner loop lr is: ', self.update_lr)

    for i in range(task_num):
        # 1. run the i-th task and compute the loss for k=0 (k is the update step)
        logits = self.net(spt_ms[i], vars=None, bn_training=True)
        loss = F.smooth_l1_loss(logits, spt_rgb[i])

        # Gradient of the support loss w.r.t. the current parameters, and one
        # gradient step to obtain the fast weights theta'.
        grad = torch.autograd.grad(loss, self.net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0],
                zip(grad, self.net.parameters())))

        # Query loss and PSNR before the first update (k=0). These values are
        # only logged, so no gradients are tracked.
        with torch.no_grad():
            logits_q = self.net(qry_ms[i], self.net.parameters(),
                                bn_training=True)
            loss_q = F.smooth_l1_loss(logits_q, qry_rgb[i])
            losses_q[0] += loss_q
            # In the classification version of MAML, logits_q would go through
            # a softmax to become pred_q; here the prediction is used directly
            # and PSNR plays the role of accuracy.
            pred_q = logits_q
            correct = errors.find_psnr(pred_q, qry_rgb[i])
            corrects[0] = corrects[0] + correct

        # Query loss and PSNR after the first update (k=1), again logged only.
        with torch.no_grad():
            logits_q = self.net(qry_ms[i], fast_weights, bn_training=True)
            loss_q = F.smooth_l1_loss(logits_q, qry_rgb[i])
            losses_q[1] += loss_q
            pred_q = logits_q
            correct = errors.find_psnr(pred_q, qry_rgb[i])
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step):
            # 1. run the i-th task and compute the loss for k = 1..K-1
            logits = self.net(spt_ms[i], fast_weights, bn_training=True)
            loss = F.smooth_l1_loss(logits, spt_rgb[i])
            # 2. compute the gradient w.r.t. theta_pi
            grad = torch.autograd.grad(loss, fast_weights)
            # 3. theta_pi = theta_pi - train_lr * grad
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, fast_weights)))

            logits_q = self.net(qry_ms[i], fast_weights, bn_training=True)
            self.valid_img = logits_q
            # loss_q is overwritten each step; only the loss on the last
            # update step (losses_q[-1]) is used for the meta-update.
            loss_q = F.smooth_l1_loss(logits_q, qry_rgb[i])
            losses_q[k + 1] += loss_q

            with torch.no_grad():
                pred_q = logits_q
                correct = errors.find_psnr(pred_q, qry_rgb[i])
                corrects[k + 1] = corrects[k + 1] + correct

    # End of all tasks: average the last-step query loss across tasks.
    loss_q = losses_q[-1] / task_num
    # self.log[-1] += loss.item()

    # Optimise the meta-parameters theta. The inner-loop (Learner) updates use
    # the support/training set of each task, whereas the meta-update below is
    # taken with respect to the query/test set of each episode.
    self.meta_optim.zero_grad()
    # Backpropagate the final query loss through the fast-weight updates above
    # to the initial parameters theta.
    loss_q.backward()
    # print('meta update')
    # for p in self.net.parameters()[:5]:
    #     print(torch.norm(p).item())
    self.meta_optim.step()

    accs = np.average(np.array(corrects[-1]))  # / (querysz * task_num)
    print('outer loop lr is: ', self.get_lr(self.meta_optim))

    return accs, loss_q
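# forward() above follows the standard MAML pattern: "fast weights" are built
# functionally from the current parameters and a gradient step, the query loss
# is evaluated under those fast weights, and that loss drives the meta-update
# of the initial parameters. Below is a self-contained toy sketch of that
# pattern on a single linear model; the model, data shapes and learning rate
# are illustrative assumptions, not the project's Learner/Meta classes.
import torch
import torch.nn.functional as F


def _maml_inner_step_sketch(update_lr=0.01):
    w = torch.randn(3, 16, requires_grad=True)             # initial parameters theta
    x_spt, y_spt = torch.randn(8, 16), torch.randn(8, 3)   # support set of one task
    x_qry, y_qry = torch.randn(8, 16), torch.randn(8, 3)   # query set of the same task

    # Support loss and one gradient step -> fast weights theta'.
    # create_graph=True keeps second-order gradients; omit it for a
    # first-order variant.
    spt_loss = F.smooth_l1_loss(x_spt @ w.t(), y_spt)
    (grad,) = torch.autograd.grad(spt_loss, (w,), create_graph=True)
    w_fast = w - update_lr * grad

    # Query loss under theta'; backward() reaches the original theta, which is
    # what the outer (meta) optimiser would then step on.
    qry_loss = F.smooth_l1_loss(x_qry @ w_fast.t(), y_qry)
    qry_loss.backward()
    return w.grad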