def test_model(self, test_model_dir):  # test or validation
     self.ckp.write_log('\nEvaluation:')
     self.ckp.add_log(torch.zeros(1))  #(torch.zeros(1, len(self.scale)))
     self.model.eval()
     with torch.no_grad():
         eval_acc = 0
         for im_idx, im_dict in enumerate(self.loader_results, 1):
             lr = im_dict['im_lr']
             hr = im_dict['im_hr']
             lr, hr = self.prepare([lr, hr])
             sr, sr_ = self.model(lr)
             #sr = torch.clamp(sr, 0, 1)
             eval_acc += errors.find_psnr(sr, hr)
             # always save the super-resolved output for visual inspection
             im_sr = np.float64(
                 normalise01(sr[0, :, :, :].permute(1, 2, 0).cpu().numpy()))
             im_sr = np.uint8(im_sr / im_sr.max() * 255)
             imsave(
                 test_model_dir + '/im_sr_{}.tiff'.format(im_idx + 275),
                 im_sr)
             print("Image: {}".format(im_idx))
         psnr = eval_acc / len(self.loader_results)
     return psnr
 def test(self):  # test or validation
     epoch = self.epoch()
     self.ckp.write_log('\nEvaluation:')
     scale = 2
     self.ckp.add_log(torch.zeros(1))  #(torch.zeros(1, len(self.scale)))
     self.model.eval()
     timer_test = utility.timer()
     with torch.no_grad():
         eval_acc = 0  # accumulated PSNR
         valid_loss = 0  # accumulated validation loss (same criterion as training)
         for im_idx, im_dict in enumerate(self.loader_test, 1):
             lr = im_dict['im_lr']
             hr = im_dict['im_hr']
             lr, hr = self.prepare([lr, hr])
             sr, sr_ = self.model(lr)
             sr = torch.clamp(sr, 0, 1)
             sr_ = torch.clamp(sr_, 0, 1)
             # keep HWC numpy copies of the current pair (presumably for visualisation)
             self.lr_valid = np.average(
                 lr[0, :, :, :].permute(1, 2, 0).cpu().numpy(), axis=2)
             self.hr_valid = hr[0, :, :, :].permute(1, 2, 0).cpu().numpy()
             self.sr_valid = sr[0, :, :, :].permute(1, 2, 0).cpu().detach().numpy()
             # sr = utility.quantize(sr, self.args.rgb_range)
             save_list = [sr]
             # do some processing on sr, hr or modify find_psnr()
             eval_acc += errors.find_psnr(sr, hr)
             save_list.extend([lr, hr])
             loss = self.loss.valid_loss(sr, sr_, hr)
             valid_loss += loss.item()
             # save the sr images of the last epoch
             if self.args.save_results and epoch == self.args.epochs:
                 self.ckp.save_results("image_{}_sr".format(im_idx),
                                       save_list, scale)
         self.ckp.log_accuracy[-1] = (valid_loss / len(self.loader_test))
         self.ckp.log[-1] = eval_acc / len(self.loader_test)
         best = self.ckp.log.max(0)
         self.ckp.write_log(
             '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                 self.args.data_test, scale, self.ckp.log[-1],
                  best[0].item(), best[1].item() + 1))
      # ckp.save stores the loss and the model and draws the loss plots
      # (all defined in the Checkpoint class)
     self.ckp.write_log('Total time: {:.2f}s\n'.format(timer_test.toc()),
                        refresh=True)
     if not self.args.test_only:
         #            self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
         self.ckp.save(self, epoch, is_best=False)
def main():
    ck = util.checkpoint(args)
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


    ck.write_log(str(args))
    # t = str(int(time.time()))
    # t = args.save_name
    # os.mkdir('./{}'.format(t))
    # (ch_out, ch_in, k, k, stride, padding)
    config = [('conv2d', [32, 16, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('+1', [True]), ('conv2d', [3, 32, 3, 3, 1, 1])]
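    # Each tuple is consumed by the Learner as (out_ch, in_ch, k_h, k_w, stride,
    # padding); the '+1' entry presumably marks a residual/skip connection handled
    # by the Learner. The stack maps the 16-channel multispectral input to a
    # 3-channel RGB output via four 3x3 conv+ReLU blocks and a final 3x3 projection.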

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # count the trainable parameters of the meta-learner
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    ck.write_log(str(maml))
    ck.write_log('Total trainable parameters: {}'.format(num))

    # the train loader yields one meta-batch (support/query sets) per iteration
    DL_MSI = dl.StereoMSIDatasetLoader(args)
    db = DL_MSI.train_loader
    dv = DL_MSI.valid_loader

    psnr = []
    l1_loss = []
    psnr_valid = []
    for epoch, (spt_ms, spt_rgb, qry_ms, qry_rgb) in enumerate(db):

        if epoch >= args.epoch:  # stop after the configured number of meta-iterations
            break
        spt_ms, spt_rgb, qry_ms, qry_rgb = (spt_ms.to(device),
                                            spt_rgb.to(device),
                                            qry_ms.to(device),
                                            qry_rgb.to(device))
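        # spt_* form the support set used for inner-loop adaptation; qry_* form
        # the query set whose loss drives the meta-update inside maml.forward().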

        # optimization is carried out inside meta_learner class, maml.
        accs, train_loss = maml(spt_ms, spt_rgb, qry_ms, qry_rgb, epoch)
        maml.scheduler.step()

        if epoch % args.print_every == 0:
            log_epoch = 'epoch: {} \ttraining PSNR: {}'.format(epoch, accs)
            ck.write_log(log_epoch)
            psnr.append(accs)
            l1_loss.append(train_loss)
            ck.plot_loss(psnr, l1_loss, epoch, args.print_every)
            if epoch % args.save_every == 0:
                with torch.no_grad():
                    ck.save(maml.net, maml.meta_optim, epoch)
                    eval_psnr = 0  # accumulated validation PSNR
                    for idx, (valid_ms, valid_rgb) in enumerate(dv):
                        #print('idx', idx)
                        valid_ms, valid_rgb = prepare([valid_ms, valid_rgb])
                        sr_rgb = maml.net(valid_ms)
                        sr_rgb = torch.clamp(sr_rgb, 0, 1)
                        eval_psnr += errors.find_psnr(valid_rgb, sr_rgb)
                    # mean validation PSNR (the loader presumably yields 25 images)
                    psnr_valid.append(eval_psnr / 25)
                    ck.plot_psnr(psnr_valid, epoch, args.save_every)
                    ck.write_log('Max PSNR is: {}'.format(max(psnr_valid)))
                    imsave(
                        './{}/validation/img_{}.png'.format(ck.dir, epoch),
                        np.uint8(sr_rgb[0, :, :, :].permute(
                            1, 2, 0).cpu().detach().numpy() * 255))
    ck.done()
    def forward(self, spt_ms, spt_rgb, qry_ms, qry_rgb, epoch):
        """
        :b:             number of tasks/batches.
        :setsz:         number of training pairs?
        :querysz        number of test pairs for few shot
        :param spt_ms:    [task_num, setsz, 16, h, w]
        :param spt_rgb:   [task_num, querysz, 3, h, w] 
        :param qry_ms:    [task_num, setsz, 16, h, w]
        :param qry_rgb:   [task_num, querysz, 3, h, w]

        :return:
        """

        spt_ms = spt_ms.squeeze()
        spt_rgb = spt_rgb.squeeze()
        qry_ms = qry_ms.squeeze()
        qry_rgb = qry_rgb.squeeze()

        task_num, setsz, c, h, w = spt_ms.size()
        _, querysz, c, _, _ = qry_ms.size()
        # losses_q[k]: query loss after k inner-loop steps, accumulated over tasks
        losses_q = [0 for _ in range(self.update_step + 1)]
        # corrects[k]: query PSNR after k inner-loop steps, accumulated over tasks
        corrects = [0 for _ in range(self.update_step + 1)]
        # halve the inner-loop learning rate every 2000 meta-iterations, up to epoch 4000
        if epoch < 4001:
            if (epoch % 2000 == 0) and (epoch > 1):
                decay = 2  # alternative schedule: (epoch // 5) + 1
                self.update_lr = self.update_lr / decay
        print('inner loop lr is: ', self.update_lr)
        for i in range(task_num):

            # 1. run the i-th task and compute loss for k=0, k is update step
            logits = self.net(spt_ms[i], vars=None, bn_training=True)
            loss = F.smooth_l1_loss(logits, spt_rgb[i])
            # TODO: keep a log of these losses, shape task_num x k
            #print(loss.item())
            # gradient of the support loss w.r.t. the current network parameters
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters())))
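            # fast_weights now holds one inner-loop SGD step on the support set:
            #   theta_i' = theta - update_lr * grad_theta(L_support)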
            # The next two no_grad blocks are for logging only: they record the
            # query loss/PSNR with the original parameters (before any inner
            # update) and with fast_weights (after the first inner update).
            with torch.no_grad():
                # query output with the unadapted parameters
                logits_q = self.net(qry_ms[i],
                                    self.net.parameters(),
                                    bn_training=True)
                loss_q = F.smooth_l1_loss(logits_q, qry_rgb[i])
                losses_q[0] += loss_q  # accumulate the pre-update query loss across tasks

                # In the original classification MAML, logits_q went through a
                # softmax to give pred_q; here the network output is already the
                # reconstructed RGB image, so PSNR is computed on it directly.
                pred_q = logits_q
                correct = errors.find_psnr(pred_q, qry_rgb[i])
                corrects[0] = corrects[0] + correct

            # query loss and PSNR after the first inner update (fast_weights)
            with torch.no_grad():
                # query output with the adapted fast_weights
                logits_q = self.net(qry_ms[i], fast_weights, bn_training=True)
                loss_q = F.smooth_l1_loss(logits_q, qry_rgb[i])
                losses_q[1] += loss_q
                pred_q = logits_q
                correct = errors.find_psnr(pred_q, qry_rgb[i])
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                logits = self.net(spt_ms[i], fast_weights, bn_training=True)
                loss = F.smooth_l1_loss(logits, spt_rgb[i])
                # 2. compute grad on theta_pi
                grad = torch.autograd.grad(loss, fast_weights)
                # 3. theta_pi = theta_pi - train_lr * grad
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0],
                        zip(grad, fast_weights)))

                logits_q = self.net(qry_ms[i], fast_weights, bn_training=True)
                self.valid_img = logits_q
                # only losses_q[-1] (the query loss after the final inner step)
                # feeds the meta-update below; earlier entries are for logging
                loss_q = F.smooth_l1_loss(logits_q, qry_rgb[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = logits_q
                    # convert to numpy
                    correct = errors.find_psnr(pred_q, qry_rgb[i])
                    corrects[k + 1] = corrects[k + 1] + correct

        # end of all tasks: average the final-step query loss over the tasks
        loss_q = losses_q[-1] / task_num
        # self.log[-1] += loss.item()
        # Meta-update of theta: the inner loop adapts to each task's support
        # (training) set, while this meta-update optimises the loss on each
        # task's query (test) set.
        self.meta_optim.zero_grad()
        loss_q.backward()  # meta-gradient: backprop the final query loss to the meta-parameters
        # print('meta update')
        # for p in self.net.parameters()[:5]:
        # 	print(torch.norm(p).item())
        self.meta_optim.step()
        # query PSNR after the final inner step, accumulated over the tasks in this batch
        accs = np.average(np.array(corrects[-1]))  # / (querysz * task_num)
        print('outer loop (meta) lr is: ', self.get_lr(self.meta_optim))
        return accs, loss_q
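
# For reference only: a minimal sketch of the PSNR computation that an
# errors.find_psnr helper might perform on tensors scaled to [0, 1]. The
# project's actual implementation lives in the errors module and may differ
# (e.g. in peak value, border cropping, or batch averaging).
def _psnr_sketch(sr, hr, peak=1.0):
    """Peak signal-to-noise ratio (dB) between two same-shaped tensors."""
    mse = torch.mean((sr - hr) ** 2)
    return (10 * torch.log10(peak ** 2 / mse)).item()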