Example #1
    def parse(self):
        args = self.parse_args()
        date = time.strftime('%b_%d')
        args.run_dir = args.exp_dir + '/' + args.exp_name + '/' + date \
            + f'/{args.net}_noise_train{args.noise_levels_train}_'\
            f'test{args.noise_levels_test}_{args.transform}_'\
            f'epochs{args.epochs}_bs{args.batch_size}_lr{args.lr}'
        args.ckpt_dir = args.run_dir + '/checkpoints'

        if not args.post:
            mkdirs([args.run_dir, args.ckpt_dir])

        # seed
        if args.seed is None:
            args.seed = random.randint(1, 10000)
        print("Random Seed: ", args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True

        print('Arguments:')
        pprint(vars(args))

        if not args.post:
            with open(args.run_dir + "/args.txt", 'w') as args_file:
                json.dump(vars(args), args_file, indent=4)

        return args
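The parse() methods in Examples #1–#3, #5, and #7 rely on a custom Parser class and a mkdirs helper that the snippets do not show. A minimal sketch of that assumed scaffolding follows (the flag list and defaults are illustrative placeholders, not the authors' code):

import argparse
import os


def mkdirs(*directories):
    # The examples call this both as mkdirs([run_dir, ckpt_dir]) and as
    # mkdirs(run_dir, ckpt_dir), so accept a single list/tuple or several paths.
    if len(directories) == 1 and isinstance(directories[0], (list, tuple)):
        directories = directories[0]
    for directory in directories:
        os.makedirs(directory, exist_ok=True)


class Parser(argparse.ArgumentParser):
    def __init__(self):
        super().__init__(description='training options')
        # Flags referenced by the parse() methods in these examples;
        # names match the attributes used above, defaults are placeholders.
        self.add_argument('--exp-dir', type=str, default='./experiments')
        self.add_argument('--exp-name', type=str, default='run')
        self.add_argument('--epochs', type=int, default=400)
        self.add_argument('--batch-size', type=int, default=8)
        self.add_argument('--lr', type=float, default=1e-3)
        self.add_argument('--seed', type=int, default=None)
        self.add_argument('--post', action='store_true')
        # ... plus the model/data flags each example interpolates into run_dir

    # parse() is then implemented as shown in the examples.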
Example #2
    def parse(self):
        args = self.parse_args()

        args.run_dir = args.exp_dir + '/' + args.exp_name \
            + '/kle{}/ntrain{}_blocks{}_growth{}_nif{}_drop{}_batch{}_lr{}_wd{}_epochs{}'.format(
                args.kle, args.ntrain, args.blocks, args.growth_rate,
                args.init_features, args.drop_rate, args.batch_size,
                args.lr, args.weight_decay, args.epochs
            )
        args.ckpt_dir = args.run_dir + '/checkpoints'
        mkdirs([args.run_dir, args.ckpt_dir])

        assert args.epochs % args.ckpt_freq == 0, \
            'epochs must be divisible by ckpt_freq'

        # seed
        if args.seed is None:
            args.seed = random.randint(1, 10000)
        print("Random Seed: ", args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

        print('Arguments:')
        pprint(vars(args))

        if not args.post:
            with open(args.run_dir + "/args.txt", 'w') as args_file:
                json.dump(vars(args), args_file, indent=4)

        return args
Example #3
    def parse(self):
        args = self.parse_args()

        args.run_dir = args.exp_dir + '/' + args.exp_name \
            + '/kle{}/nsamples{}_ntrain{}_batch{}_lr{}_noiselr{}_epochs{}'.format(
                args.kle, args.n_samples, args.ntrain, args.batch_size, args.lr,
                args.lr_noise, args.epochs)

        args.ckpt_dir = args.run_dir + '/checkpoints'
        mkdirs([args.run_dir, args.ckpt_dir])

        assert args.epochs % args.ckpt_freq == 0, \
            'epochs must be divisible by ckpt_freq'
        assert args.ntrain % args.batch_size == 0, \
            'number of training data must be divisible by batch size'

        # seed
        if args.seed is None:
            args.seed = random.randint(1, 10000)
        print("Random Seed: ", args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

        print('Arguments:')
        pprint(vars(args))

        if not args.post:
            with open(args.run_dir + "/args.txt", 'w') as args_file:
                json.dump(vars(args), args_file, indent=4)

        return args
Example #4
    def parse(self, save=True):
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # train or test

        str_ids = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        for str_id in str_ids:
            gpu_id = int(str_id)
            if gpu_id >= 0:
                self.opt.gpu_ids.append(gpu_id)

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        misc.mkdirs(expr_dir)
        if save and not self.opt.continue_train:
            file_name = os.path.join(expr_dir, 'opt.txt')
            with open(file_name, 'wt') as opt_file:
                opt_file.write('------------ Options -------------\n')
                for k, v in sorted(args.items()):
                    opt_file.write('%s: %s\n' % (str(k), str(v)))
                opt_file.write('-------------- End ----------------\n')
        return self.opt
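The gpu_ids handling in Example #4 converts a comma-separated CLI string into a list of non-negative device indices before calling torch.cuda.set_device. A small self-contained sketch of that idiom (parse_gpu_ids is a hypothetical helper, not part of the snippet):

# Stand-alone sketch of Example #4's gpu_ids parsing: a '0,1,-1' style
# string becomes a list of usable device indices; negatives are dropped.
def parse_gpu_ids(gpu_ids_str):
    return [int(s) for s in gpu_ids_str.split(',') if int(s) >= 0]


print(parse_gpu_ids('0,1'))   # [0, 1]
print(parse_gpu_ids('-1'))    # []  -> run on CPU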
Example #5
    def parse(self):
        args = self.parse_args()
        tid = 2
        hparams = f'{args.data}_{tid}_run{args.run}_bs{args.batch_size}'
        #        if args.debug:
        #            hparams = 'debug/' + hparams
        args.run_dir = args.exp_dir + '/' + args.exp_name + '/' + hparams
        args.ckpt_dir = args.run_dir + '/checkpoints'
        # print(args.run_dir)
        # print(args.ckpt_dir)
        mkdirs(args.run_dir, args.ckpt_dir)

        # assert args.ntrain % args.batch_size == 0 and \
        #     args.ntest % args.test_batch_size == 0

        if args.seed is None:
            args.seed = random.randint(1, 10000)
        print("Random Seed: ", args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

        print('Arguments:')
        pprint(vars(args))
        with open(args.run_dir + "/args.txt", 'w') as args_file:
            json.dump(vars(args), args_file, indent=4)

        return args
Example #6
    def _save(self, args):
        expr_dir = os.path.join(self._opt.checkpoints_dir, self._opt.name)
        print(expr_dir)
        misc.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt_%s.txt' % ('train' if self.is_train else 'test'))
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
Example #7
    def parse(self):
        args = self.parse_args()
        args.LU_decompose = not args.no_LU_decompose
        assert len(args.enc_blocks) == len(args.flow_blocks)
        hparams = f'kle{args.kle}_ntrain{args.ntrain}_'\
            f'ENC_blocks{args.enc_blocks}_FLOW_blocks{args.flow_blocks}_'\
            f'wb{args.weight_bound}_beta{args.beta}_'\
            f'batch{args.batch_size}_lr{args.lr}_epochs{args.epochs}'
        if args.debug:
            hparams = 'debug/' + hparams
        if args.data_init:
            hparams = hparams + '_data_init'
        args.run_dir = args.exp_dir + '/' + args.exp_name + '/' + hparams
        args.ckpt_dir = args.run_dir + '/checkpoints'
        args.train_dir = args.run_dir + '/training'
        args.pred_dir = args.train_dir + '/predictions'
        mkdirs(args.run_dir, args.ckpt_dir, args.train_dir, args.pred_dir)

        if args.seed is None:
            args.seed = random.randint(1, 10000)
        print("Random Seed: ", args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

        args_file = args.run_dir + "/args.txt"
        if os.path.isfile(args_file):
            # args.resume, args.ckpt directly overwrite
            if args.ckpt_epoch is None and args.resume:
                with open(args_file, 'r') as args_f:
                    args_old = argparse.Namespace(**json.load(args_f))
                args.ckpt_epoch = args_old.ckpt_epoch
        else:
            with open(args_file, 'w') as args_f:
                json.dump(vars(args), args_f, indent=4)
        print('Arguments:')
        pprint(vars(args))
        return args
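Example #12 below reloads a finished run with load_args(run_dir). A minimal sketch of such a helper, assuming args.txt was written with json.dump(vars(args), ...) as in the parse() methods above:

import argparse
import json


def load_args(run_dir):
    # Rebuild the argparse.Namespace serialized to <run_dir>/args.txt.
    with open(run_dir + '/args.txt', 'r') as args_file:
        return argparse.Namespace(**json.load(args_file))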
Example #8
        pprint(vars(args))

        if not args.post:
            with open(args.run_dir + "/args.txt", 'w') as args_file:
                json.dump(vars(args), args_file, indent=4)

        return args


args = Parser().parse()
device = torch.device(
    f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

args.train_dir = args.run_dir + "/training"
args.pred_dir = args.train_dir + "/predictions"
mkdirs([args.train_dir, args.pred_dir])
if args.net == 'unet':
    model = UnetN2N(args.in_channels, args.out_channels).to(device)
elif args.net == 'unetv2':
    model = UnetN2Nv2(args.in_channels, args.out_channels).to(device)

if args.debug:
    print(model)
    print(model.model_size)

if args.transform == 'four_crop':
    # wide field images may have complete noise in center-crop case
    transform = transforms.Compose([
        transforms.FiveCrop(args.imsize),
        transforms.Lambda(lambda crops: torch.stack(
            [fluore_to_tensor(crop) for crop in crops[:4]])),
Example #9
        print('Arguments:')
        pprint(vars(args))
        with open(args.run_dir + "/args.txt", 'w') as args_file:
            json.dump(vars(args), args_file, indent=4)

        return args


args = Parser().parse()
device = torch.device(
    f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")
print(device)
args.train_dir = args.run_dir + '/training'
args.pred_dir = args.train_dir + '/predictions'
mkdirs(args.train_dir, args.pred_dir)

data, ref = LoadData(device)

model = DenseED(in_channels=1,
                out_channels=2,
                imsize=args.imsize,
                blocks=args.blocks,
                growth_rate=args.growth_rate,
                init_features=args.init_features,
                drop_rate=args.drop_rate,
                out_activation=None,
                upsample=args.upsample).to(device)
# modelname = "experiments/codec/mixed_residual/debug/SIMP_penal3_plus1_1_run1_bs64/checkpoints/model_epoch5300.pth"
# model = torch.load(modelname)
Example #10
def main():
    parser = argparse.ArgumentParser(description='CNN to solve PDE')
    parser.add_argument('--exp-dir',
                        type=str,
                        default='./experiments/solver',
                        help='directory to save experiments')
    parser.add_argument('--nonlinear',
                        action='store_true',
                        default=False,
                        help='set True for nonlinear PDE')
    # data
    parser.add_argument('--data-dir',
                        type=str,
                        default="./datasets",
                        help='directory to dataset')
    parser.add_argument('--data',
                        type=str,
                        default='grf',
                        choices=['grf', 'channelized', 'warped_grf'],
                        help='data type')
    parser.add_argument('--kle', type=int, default=512, help='# kle terms')
    parser.add_argument('--imsize', type=int, default=64, help='image size')
    parser.add_argument('--idx',
                        type=int,
                        default=8,
                        help='idx of input, please use 0 ~ 999')
    parser.add_argument('--alpha1',
                        type=float,
                        default=1.0,
                        help='coefficient for the squared term')
    parser.add_argument('--alpha2',
                        type=float,
                        default=1.0,
                        help='coefficient for the cubic term')
    # latent size: (nz, sz, sz)
    parser.add_argument('--nz',
                        type=int,
                        default=1,
                        help='# feature maps of latent z')
    # parser.add_argument('--sz', type=int, default=16, help='feature map size of latent z')
    parser.add_argument('--blocks',
                        type=int,
                        nargs='+',
                        default=[8, 6],
                        help='# layers in each dense block of the decoder')
    parser.add_argument('--weight-bound',
                        type=float,
                        default=10,
                        help='weight for boundary condition loss')
    parser.add_argument('--lr', type=float, default=0.5, help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=500,
                        help='# epochs to train')
    parser.add_argument('--test-freq',
                        type=int,
                        default=50,
                        help='every # epoch to test')
    parser.add_argument('--ckpt-freq',
                        type=int,
                        default=250,
                        help='every # epoch to save model')
    parser.add_argument('--cmap', type=str, default='jet', help='color map')
    parser.add_argument(
        '--same-scale',
        action='store_true',
        help='true for setting noise to be same scale as output')
    parser.add_argument('--animate',
                        action='store_true',
                        help='true to plot animate figures')
    parser.add_argument('--cuda', type=int, default=1, help='cuda number')
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        help='True for verbose output')

    args = parser.parse_args()
    pprint(vars(args))
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")
    dataset = f'{args.data}_kle{args.kle}' if args.data == 'grf' else args.data
    hyparams = f'{dataset}_idx{args.idx}_dz{args.nz}_blocks{args.blocks}_'\
        f'lr{args.lr}_wb{args.weight_bound}_epochs{args.epochs}'

    if args.nonlinear:
        from utils.fenics import solve_nonlinear_poisson
        exp_name = 'conv_mixed_residual_nonlinear'
        from models.darcy import conv_constitutive_constraint_nonlinear as constitutive_constraint
        hyparams = hyparams + f'_alpha1_{args.alpha1}_alpha2_{args.alpha2}'
    else:
        exp_name = 'conv_mixed_residual'
        from models.darcy import conv_constitutive_constraint as constitutive_constraint

    run_dir = args.exp_dir + '/' + exp_name + '/' + hyparams
    mkdirs(run_dir)
    # load data
    assert args.idx < 1000
    if args.data == 'grf':
        assert args.kle in [512, 128, 1024, 2048]
        ntest = 1000 if args.kle == 512 else 1024
        hdf5_file = args.data_dir + f'/{args.imsize}x{args.imsize}/kle{args.kle}_lhs{ntest}_test.hdf5'
    elif args.data == 'warped_grf':
        hdf5_file = args.data_dir + f'/{args.imsize}x{args.imsize}/warped_gp_ng64_n1000.hdf5'
    elif args.data == 'channelized':
        hdf5_file = args.data_dir + f'/{args.imsize}x{args.imsize}/channel_ng64_n512_test.hdf5'
    else:
        raise ValueError('No dataset found for the specified parameters')
    print(f'dataset: {hdf5_file}')
    with h5py.File(hdf5_file, 'r') as f:
        input_data = f['input'][()]
        output_data = f['output'][()]
        print(f'input: {input_data.shape}')
        print(f'output: {output_data.shape}')
    # permeability, (1, 1, 64, 64)
    perm_arr = input_data[[args.idx]]
    # pressure, flux_hor, flux_ver, (3, 64, 64)
    if args.nonlinear:
        # solve nonlinear Darcy for perm_arr with FEniCS
        output_file = run_dir + '/output_fenics.npy'
        if os.path.isfile(output_file):
            output_arr = np.load(output_file)
            print('Loaded solved output field')
        else:
            print('Solve nonlinear poisson with FEniCS...')
            output_arr = solve_nonlinear_poisson(perm_arr[0, 0], args.alpha1,
                                                 args.alpha2, run_dir)
            np.save(output_file, output_arr)
    else:
        output_arr = output_data[args.idx]
    print('output shape: ', output_arr.shape)
    # model
    model = Decoder(args.nz, out_channels=3, blocks=args.blocks).to(device)
    print(f'model size: {model.model_size}')

    fixed_latent = torch.randn(1, args.nz, 16, 16).to(device) * 0.5
    perm_tensor = torch.FloatTensor(perm_arr).to(device)

    sobel_filter = SobelFilter(args.imsize, correct=True, device=device)
    optimizer = optim.LBFGS(model.parameters(),
                            lr=args.lr,
                            max_iter=20,
                            history_size=50)

    logger = {}
    logger['loss'] = []

    def train(epoch):
        model.train()

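        # torch.optim.LBFGS may evaluate the objective several times per step,
        # so the forward/backward pass lives in a closure that
        # optimizer.step(closure) can call repeatedly.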
        def closure():
            optimizer.zero_grad()
            output = model(fixed_latent)
            if args.nonlinear:
                energy = constitutive_constraint(perm_tensor, output,
                    sobel_filter, args.alpha1, args.alpha2) \
                    + continuity_constraint(output, sobel_filter)
            else:
                energy = constitutive_constraint(
                    perm_tensor, output, sobel_filter) + continuity_constraint(
                        output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss = energy + loss_boundary * args.weight_bound
            loss.backward()
            if args.verbose:
                print(f'epoch {epoch}: loss {loss.item():.6f}, '
                      f'energy {energy.item():.6f}, diri {loss_dirichlet.item():.6f}, '
                      f'neum {loss_neumann.item():.6f}')
            return loss

        loss = optimizer.step(closure)
        loss_value = loss.item() if not isinstance(loss, float) else loss
        logger['loss'].append(loss_value)
        print(f'epoch {epoch}: loss {loss_value:.6f}')
        if epoch % args.ckpt_freq == 0:
            torch.save(model.state_dict(),
                       run_dir + "/model_epoch{}.pth".format(epoch))

    def test(epoch):
        if epoch % args.epochs == 0 or epoch % args.test_freq == 0:
            output = model(fixed_latent)
            output = to_numpy(output)
            if args.animate:
                i_plot = epoch // args.test_freq
                plot_prediction_det_animate2(run_dir,
                                             output_arr,
                                             output[0],
                                             epoch,
                                             args.idx,
                                             i_plot,
                                             plot_fn='imshow',
                                             cmap=args.cmap,
                                             same_scale=args.same_scale)
            else:
                plot_prediction_det(run_dir,
                                    output_arr,
                                    output[0],
                                    epoch,
                                    args.idx,
                                    plot_fn='imshow',
                                    cmap=args.cmap,
                                    same_scale=args.same_scale)
            np.save(run_dir + f'/epoch{epoch}.npy', output[0])

    print('start training...')
    dryrun = False
    tic = time.time()
    for epoch in range(1, args.epochs + 1):
        if not dryrun:
            train(epoch)
        test(epoch)
    print(
        f'Finished optimization for {args.epochs} epochs using {(time.time()-tic)/60:.3f} minutes'
    )
    save_stats(run_dir, logger, 'loss')
    # save input
    plt.imshow(perm_arr[0, 0])
    plt.colorbar()
    plt.savefig(run_dir + '/input.png')
    plt.close()
Example #11
        with open(args.run_dir + "/args.txt", 'w') as args_file:
            json.dump(vars(args), args_file, indent=4)

        return args


if __name__ == '__main__':

    args = Parser().parse()
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

    args.train_dir = args.run_dir + '/training'
    args.pred_dir = args.train_dir + '/predictions'
    mkdirs(args.train_dir, args.pred_dir)

    model = DenseED(in_channels=1,
                    out_channels=3,
                    imsize=args.imsize,
                    blocks=args.blocks,
                    growth_rate=args.growth_rate,
                    init_features=args.init_features,
                    drop_rate=args.drop_rate,
                    out_activation=None,
                    upsample=args.upsample).to(device)
    if args.debug:
        print(model)
    # if start from ckpt
    if args.ckpt_epoch is not None:
        ckpt_file = args.run_dir + f'/checkpoints/model_epoch{args.ckpt_epoch}.pth'
Example #12
parser.add_argument('--cuda', type=int, default=0, help='gpu #')
args_post = parser.parse_args()

run_dir = './experiments/cglow/reverse_kld/kle100_ntrain4096_ENC_blocks[3, 4, 4]_'\
    'FLOW_blocks[6, 6, 6]_wb50_beta150.0_batch32_lr0.0015_epochs400'

if args_post.run_dir is not None:
    run_dir = args_post.run_dir

device = torch.device(f'cuda:{args_post.cuda}' if torch.cuda.is_available() else 'cpu')
# load the args for pre-trained run
args = load_args(run_dir)
args.device = device
args.post_dir = args.run_dir + f'/post_proc_mc{args_post.n_mc}_nsamples{args_post.n_samples}_'\
    f'varsamples{args_post.var_samples}_temp{args_post.temperature}'
mkdirs(args.post_dir)

# load the pre-trained model
cglow = MultiScaleCondGlow(img_size=args.imsize, 
                        x_channels=args.x_channels, 
                        y_channels=args.y_channels, 
                        enc_blocks=args.enc_blocks, 
                        flow_blocks=args.flow_blocks, 
                        LUdecompose=args.LU_decompose,
                        squeeze_factor=2,
                        data_init=args.data_init)
print(cglow.model_size)
ckpt_file = args.ckpt_dir + f"/model_epoch{args.epochs}.pth"
checkpoint = torch.load(ckpt_file, map_location='cpu')
cglow.load_state_dict(checkpoint['model_state_dict'])
cglow = cglow.to(device)
Example #13
    ]
    image_types = [
        'TwoPhoton_BPAE_R', 'TwoPhoton_BPAE_G', 'TwoPhoton_BPAE_B',
        'TwoPhoton_MICE'
    ]
# image_types = ['test_mix']
run_dir = args_test.pretrain_dir + f'/{args_test.model}'

with open(run_dir + '/args.txt') as args_file:
    args = Namespace(**json.load(args_file))
pprint(args)
if args_test.no_cuda:
    test_dir = run_dir + '/benchmark_cpu'
else:
    test_dir = run_dir + '/benchmark_gpu'
mkdirs(test_dir)

if args_test.model == 'dncnn':
    model = DnCNN(depth=args.depth,
                  n_channels=args.width,
                  image_channels=1,
                  use_bnorm=True,
                  kernel_size=3)
elif args_test.model == 'n2n':
    model = UnetN2N(args.in_channels, args.out_channels)

if args.debug:
    print(model)
    print(module_size(model))
model.load_state_dict(
    torch.load(run_dir + f'/checkpoints/model_epoch{args.epochs}.pth',
def main():
    parser = argparse.ArgumentParser(description='CNN to solve PDE')
    parser.add_argument('--exp-dir', type=str, default='./experiments/solver', help='directory to save experiments')
    # data
    parser.add_argument('--data-dir', type=str, default="./datasets", help='directory to dataset')
    parser.add_argument('--data', type=str, default='grf', choices=['grf', 'channelized', 'warped_grf'], help='data type')
    parser.add_argument('--kle', type=int, default=512, help='# kle terms')
    parser.add_argument('--imsize', type=int, default=64, help='image size')
    parser.add_argument('--idx', type=int, default=8, help='idx of input, please use 0 ~ 999')
    parser.add_argument('--alpha1', type=float, default=1.0, help='coefficient for the squared term')
    parser.add_argument('--alpha2', type=float, default=1.0, help='coefficient for the cubic term')
    parser.add_argument('--dim-hidden', type=int, default=512, help='# nodes in each hidden layer')
    parser.add_argument('--layers-hidden', type=int, default=8, help='# hidden layers')
    parser.add_argument('--off-grid', action='store_true', help='set True to use off-grid collocation points')
    parser.add_argument('--n-colloc', type=int, default=4096, help='# collocation points')
    parser.add_argument('--weight-bound', type=float, default=10, help='weight for boundary condition loss')
    parser.add_argument('--lr', type=float, default=0.5, help='learning rate')
    parser.add_argument('--epochs', type=int, default=2000, help='# epochs to train')
    parser.add_argument('--test-freq', type=int, default=50, help='every # epoch to test')
    parser.add_argument('--ckpt-freq', type=int, default=250, help='every # epoch to save model')
    parser.add_argument('--cmap', type=str, default='jet', help='color map')
    parser.add_argument('--same-scale', action='store_true', help='true for setting noise to be same scale as output')
    parser.add_argument('--animate', action='store_true', help='true to plot animate figures')
    parser.add_argument('--cuda', type=int, default=2, help='cuda number')
    parser.add_argument('-v', '--verbose', action='store_true', help='True for verbose output')


    args = parser.parse_args()
    pprint(vars(args))
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")
 
    exp_name = 'fc_mixed_residual'
    dataset = f'{args.data}_kle{args.kle}' if args.data == 'grf' else args.data
    hyparams = f'{dataset}_idx{args.idx}_dhid{args.dim_hidden}_lhid{args.layers_hidden}_'\
        f'alpha1_{args.alpha1}_alpha2_{args.alpha2}_'\
        f'lr{args.lr}_wb{args.weight_bound}_epochs{args.epochs}_ongrid_{not args.off_grid}_ncolloc{args.n_colloc}'

    run_dir = args.exp_dir + '/' + exp_name + '/' + hyparams
    mkdirs(run_dir)
    # load data
    assert args.idx < 1000
    if args.data == 'grf':
        assert args.kle in [512, 128, 1024, 2048]
        ntest = 1000 if args.kle == 512 else 1024
        hdf5_file = args.data_dir + f'/{args.imsize}x{args.imsize}/kle{args.kle}_lhs{ntest}_test.hdf5'
    elif args.data == 'warped_grf':
        hdf5_file = args.data_dir + f'/{args.imsize}x{args.imsize}/warped_gp_ng64_n1000.hdf5'
    elif args.data == 'channelized':
        assert args.idx < 512
        hdf5_file = args.data_dir + f'/{args.imsize}x{args.imsize}/channel_ng64_n512_test.hdf5'
    else:
        raise ValueError('No dataset found for the specified parameters')
    print(f'dataset: {hdf5_file}')
    with h5py.File(hdf5_file, 'r') as f:
        input_data = f['input'][()]
        output_data = f['output'][()]
        print(f'input: {input_data.shape}')    
        print(f'output: {output_data.shape}') 
    # permeability, (1, 1, 64, 64)
    perm_arr = input_data[[args.idx]]
    # pressure, flux_hor, flux_ver, (3, 64, 64)
    output_arr = output_data[args.idx]

    def to_tensor_gpu(*numpy_seq):
        # x: numpy array --> tensor on GPU
        return (torch.FloatTensor(x).to(device) for x in numpy_seq)

    # define networks
    net_u = CPPN(dim_in=2, dim_out=3, dim_hidden=args.dim_hidden, 
        layers_hidden=args.layers_hidden).to(device)
    print(net_u)
    print(net_u._model_size())
    optimizer = optim.LBFGS(net_u.parameters(), 
                            lr=args.lr, max_iter=20, history_size=50)
            
    logger = {}    
    logger['loss'] = []
    ngrids = [args.imsize, args.imsize]
    sampler = SampleSpatial2d(int(ngrids[0]), int(ngrids[1]))
    colloc_on_grid = not args.off_grid
    # for batch optimization
    x_colloc = sampler.colloc(colloc_on_grid, n_samples=args.n_colloc).to(device)
    x_dirichlet = torch.cat((sampler.left(on_grid=False, n_samples=256), 
        sampler.right(on_grid=False, n_samples=256)), 0).to(device)
    y_dirichlet = torch.cat((torch.ones(256, 1), torch.zeros(256, 1)), 0).to(device)
    x_neumann = torch.cat((sampler.top(colloc_on_grid), 
        sampler.bottom(colloc_on_grid)), 0).to(device)
    print(sampler.coordinates_no_boundary.shape)

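    # to_tensor_gpu returns a generator expression, so the trailing comma on
    # the left-hand side unpacks its single element into K_true_tensor.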
    K_true_tensor, = to_tensor_gpu(perm_arr.reshape(-1, 1))
    if args.verbose:
        print('x_colloc: {}'.format(x_colloc.shape))
        print('x_dirc: {}'.format(x_dirichlet.shape))
        print('y_dirc: {}'.format(y_dirichlet.shape))
        print('x_neumann: {}'.format(x_neumann.shape))

    def train(epoch):
        net_u.train()
        def closure():
            optimizer.zero_grad()
            loss_colloc = mixed_residual_fc(net_u, x_colloc, K_true_tensor, 
                args.verbose, rand_colloc=args.off_grid)
            # loss_colloc = 0
            loss_dirichlet = F.mse_loss(net_u(x_dirichlet)[:, [0]], y_dirichlet)
            loss_neumann = neumann_boundary_mixed(net_u, x_neumann)
            loss = loss_colloc + args.weight_bound * (loss_dirichlet + loss_neumann)
            loss.backward()
            if args.verbose:
                print(f'epoch {epoch}: colloc {loss_colloc:.6f}, '
                    f'diri {loss_dirichlet:.6f}, neum {loss_neumann:.6f}')
            return loss
        loss = optimizer.step(closure)
        loss_value = loss.item() if not isinstance(loss, float) else loss
        logger['loss'].append(loss_value)
        print('epoch {}: loss {:.10f}'.format(epoch, loss_value))

        if epoch % args.ckpt_freq == 0:
            torch.save(net_u.state_dict(), run_dir + "/model_epoch{}.pth".format(epoch))
                
    def test(epoch):
        if epoch % args.epochs == 0 or epoch % args.test_freq == 0:
            # plot the solution
            xx, yy = np.meshgrid(np.arange(ngrids[0]), np.arange(ngrids[1]))
            x_test = xx.flatten()[:, None] / ngrids[1]
            y_test = yy.flatten()[:, None] / ngrids[0]
            x_test, y_test = to_tensor_gpu(x_test, y_test)
            net_u.eval()
            x_test.requires_grad = True
            y_test.requires_grad = True
            xy_test = torch.cat((y_test, x_test), 1)
            y_pred = net_u(xy_test)
            target = output_arr
            # three outputs of net_u (channels 0, 1, 2): u, flux_y, flux_x
            u_pred = y_pred[:, 0].detach().cpu().numpy().reshape(*ngrids)
            u_y = y_pred[:, 1].detach().cpu().numpy().reshape(*ngrids)
            u_x = y_pred[:, 2].detach().cpu().numpy().reshape(*ngrids)
            prediction = np.stack((u_pred, u_x, u_y))
            # prediction = y_pred.view(*ngrids, -1).transpose(0, 1).permute(2, 1, 0).detach().cpu().numpy()
            if args.animate:
                i_plot = epoch // args.test_freq
                plot_prediction_det_animate2(run_dir, target, prediction, epoch, args.idx, i_plot,
                    plot_fn='imshow', cmap=args.cmap, same_scale=args.same_scale)
            else:
                plot_prediction_det(run_dir, target, prediction, epoch, args.idx, 
                    plot_fn='imshow', cmap=args.cmap, same_scale=args.same_scale)
            np.save(run_dir + f'/epoch{epoch}.npy', prediction)
           
    print('start training...')
    dryrun = False
    tic = time.time()
    for epoch in range(1, args.epochs + 1):
        if not dryrun:
            train(epoch)
        test(epoch)
    print(f'Finished training {args.epochs} epochs in {(time.time()-tic)/60:.3f} minutes')
    save_stats(run_dir, logger, 'loss')

    # save input
    plt.close()
    plt.imshow(np.log(perm_arr[0, 0]))
    plt.colorbar()
    plt.savefig(run_dir + '/input_logK.png')
    plt.close()

    # super-resolution prediction on a finer grid
    ngrids = (640, 640)
    xx, yy = np.meshgrid(np.arange(ngrids[0]), np.arange(ngrids[1]))
    x_test = torch.FloatTensor(np.stack((yy.flatten() / (ngrids[1]-1), 
        xx.flatten() / (ngrids[0]-1)), 1)).to(device)
    net_u.eval()
    u_pred = net_u(x_test)
    u_pred = u_pred[:, 0].reshape(*ngrids).detach().cpu().numpy() 
    plt.contourf(u_pred, 65)
    plt.colorbar()
    plt.savefig(run_dir + '/solution_HR.png')
    plt.close()
noise_levels = [1]
image_types = [
    'Confocal_BPAE_R', 'Confocal_BPAE_G', 'Confocal_BPAE_B', 'Confocal_FISH'
]

data_dir = args_test.data_root
run_dir = args_test.pretrain_dir + f'/{args_test.model}'

with open(run_dir + '/args.txt') as args_file:
    args = Namespace(**json.load(args_file))
pprint(args)
if args_test.no_cuda:
    test_dir = run_dir + '/example_cpu'
else:
    test_dir = run_dir + '/example_gpu'
mkdirs(test_dir)

if args_test.model == 'dncnn':
    model = DnCNN(depth=args.depth,
                  n_channels=args.width,
                  image_channels=1,
                  use_bnorm=True,
                  kernel_size=3)
elif args_test.model == 'n2n':
    model = UnetN2N(args.in_channels, args.out_channels)

if args.debug:
    print(model)
    print(module_size(model))
model.load_state_dict(
    torch.load(run_dir + f'/checkpoints/model_epoch{args.epochs}.pth',