def evaluate():
    # Evaluation only: load a pretrained checkpoint and write the super-resolved
    # RGB outputs for the validation set to disk.
    ck = util.checkpoint(args)
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
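    # Make cuDNN deterministic so runs are reproducible (at some cost in speed).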
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    ck.write_log(str(args))
    # (ch_out, ch_in, k, k, stride, padding)
    config = [('conv2d', [32, 16, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('+1', [True]), ('conv2d', [3, 32, 3, 3, 1, 1])]
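    # NOTE: the '+1' entry is assumed to be interpreted by the Learner built from
    # this config (presumably a skip/residual connection that adds the block input
    # back in before the final conv2d); see the Meta implementation for details.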

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    params = torch.load(
        r'/flush5/sho092/Robust_learning/experiment/'
        r'2020-07-14-16:58:35_k0_metalr0.001_updatelr0.01_batchsz100000_updateStep7/'
        r'model/model_200.pt')

    DL_MSI = dl.StereoMSIDatasetLoader(args)
    dv = DL_MSI.valid_loader
    maml.net.load_state_dict(params, strict=False)
    maml.net.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for idx, (valid_ms, valid_rgb) in enumerate(dv):
            valid_ms, valid_rgb = prepare([valid_ms, valid_rgb])
            sr_rgb = maml.net(valid_ms)
            # Quick sanity check on the output range before clamping.
            print(sr_rgb.max(), sr_rgb.min())
            sr_rgb = torch.clamp(sr_rgb, 0, 1)

            # Save the super-resolved output as an 8-bit RGB image (CHW -> HWC).
            imsave(
                '../experiment/{}.png'.format(idx),
                np.uint8(sr_rgb.detach().cpu().squeeze().permute(1, 2, 0).numpy() *
                         255))
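
# A minimal sketch of what the prepare() helper used above presumably does:
# move tensors onto the CUDA device. The name _prepare_sketch and its body are
# illustrative assumptions; the repository's own prepare() may additionally
# handle precision casting or CPU fallback.
def _prepare_sketch(tensors, device=torch.device('cuda')):
    # Move every tensor in the list onto the target device.
    return [t.to(device) for t in tensors]
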

def main():
    # Meta-training entry point: trains the meta-learner and periodically
    # evaluates it on the validation set and saves checkpoints.
    ck = util.checkpoint(args)
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    ck.write_log(str(args))
    # (ch_out, ch_in, k, k, stride, padding)
    config = [('conv2d', [32, 16, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('+1', [True]), ('conv2d', [3, 32, 3, 3, 1, 1])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count the trainable parameters of the meta-learner.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    ck.write_log(str(maml))
    ck.write_log('Total trainable parameters: {}'.format(num))

    # For this dataset loader, batchsz means the total number of training episodes.
    DL_MSI = dl.StereoMSIDatasetLoader(args)
    db = DL_MSI.train_loader
    dv = DL_MSI.valid_loader

    psnr = []
    l1_loss = []
    psnr_valid = []
    for epoch, (spt_ms, spt_rgb, qry_ms, qry_rgb) in enumerate(db):

        if epoch >= args.epoch:  # stop after args.epoch episodes
            break
        spt_ms, spt_rgb, qry_ms, qry_rgb = (spt_ms.to(device),
                                            spt_rgb.to(device),
                                            qry_ms.to(device),
                                            qry_rgb.to(device))

        # The optimization step is carried out inside the Meta (meta-learner) class.
        accs, train_loss = maml(spt_ms, spt_rgb, qry_ms, qry_rgb, epoch)
        maml.scheduler.step()

        if epoch % args.print_every == 0:
            log_epoch = 'epoch: {} \ttraining acc: {}'.format(epoch, accs)
            ck.write_log(log_epoch)
            psnr.append(accs)
            l1_loss.append(train_loss)
            ck.plot_loss(psnr, l1_loss, epoch, args.print_every)
            if epoch % args.save_every == 0:
                with torch.no_grad():
                    ck.save(maml.net, maml.meta_optim, epoch)
                    eval_psnr = 0  # accumulated PSNR over the validation set
                    for idx, (valid_ms, valid_rgb) in enumerate(dv):
                        valid_ms, valid_rgb = prepare([valid_ms, valid_rgb])
                        sr_rgb = maml.net(valid_ms)
                        sr_rgb = torch.clamp(sr_rgb, 0, 1)
                        eval_psnr += errors.find_psnr(valid_rgb, sr_rgb)
                    psnr_valid.append(eval_psnr / len(dv))  # mean PSNR over validation batches
                    ck.plot_psnr(psnr_valid, epoch, args.save_every)
                    ck.write_log('Max PSNR is: {}'.format(max(psnr_valid)))
                    # Save one example output from the last validation batch.
                    imsave(
                        './{}/validation/img_{}.png'.format(ck.dir, epoch),
                        np.uint8(sr_rgb[0, :, :, :].permute(
                            1, 2, 0).cpu().detach().numpy() * 255))
    ck.done()
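
# A minimal reference for the metric used above, assuming errors.find_psnr computes
# peak signal-to-noise ratio for images scaled to [0, 1]. The name _psnr_sketch is
# hypothetical; the repository's implementation may differ (e.g. a different peak
# value or per-image averaging).
def _psnr_sketch(target, output):
    # PSNR = 10 * log10(peak^2 / MSE), with peak = 1.0 for images in [0, 1].
    mse = torch.mean((target - output) ** 2)
    return 10 * torch.log10(1.0 / mse)


if __name__ == '__main__':
    # Run meta-training by default (assumes args is parsed at module level
    # elsewhere in this file); call evaluate() instead to test a saved checkpoint.
    main()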