def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        raise NotImplementedError(
            f"architecture {args.arch} hasn't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        raise FileNotFoundError("no folder found for path: " + model_path)

    print("Restoring weights from ", model_path + "/last_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/last_model.pt")
    start_epoch = checkpoint['epoch']
    model = checkpoint['model']
    model.lr_scheduler = checkpoint['lr_scheduler']
    model.optimizer = checkpoint['optimizer']
    # df = pd.read_csv(model_path + '/'+'df.csv')

    eval_ds = SimulationDataset(args.data_folder,
                                total_sample_number=args.total_sample_number)
    means = [1e-3, 1e-3]
    torch.manual_seed(42)
    eval_loader = DataLoader(eval_ds,
                             batch_size=args.eval_batch_size,
                             shuffle=True,
                             num_workers=0)
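    # Each batch from SimulationDataset is expected to be a dict with the keys
    # unpacked below: 'structure' (the 0/1 material pattern), 'field' (the
    # ground-truth Hy field, real/imag channels), 'top_bc' / 'bottom_bc' /
    # 'left_bc' / 'right_bc' (boundary strips), and 'yeex' / 'yeey'
    # (presumably Yee-grid material averages consumed by H_to_H).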

    total_num = len(eval_ds)
    print("total_num: ", total_num)

    Hy_mae_list_eval = torch.zeros((total_num, 1))
    Ex_mae_list_eval = torch.zeros((total_num, 1))
    Ez_mae_list_eval = torch.zeros((total_num, 1))
    H2nd_physreg_list_eval = torch.zeros((total_num, 1))
    E2nd_physreg_list_eval = torch.zeros((total_num, 1))

    x_eval = torch.zeros((total_num, 256))
    eff_eval = torch.zeros((total_num, 1))
    eff_pred_eval = torch.zeros((total_num, 1))
    eff_error_eval = torch.zeros((total_num, 1))

    div_num = args.eval_batch_size
    for loop_idx, sample_batched in enumerate(eval_loader):
        start_idx = loop_idx * div_num
        # clamp the last batch, which may be smaller than div_num
        end_idx = min(start_idx + div_num, total_num)

        print(start_idx, end_idx)
        x_batch_eval, y_batch_eval = sample_batched['structure'].to(device), sample_batched['field'].to(device)
        top_bc_eval, bottom_bc_eval = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
        left_bc_eval, right_bc_eval = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
        yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)

        with torch.no_grad():
            logits = model(x_batch_eval, top_bc_eval, bottom_bc_eval,
                           left_bc_eval,
                           right_bc_eval).reshape(y_batch_eval.shape)

            # normmae = model.loss_fn(logits[:,:,1:-1,1:-1].contiguous().view(args.eval_batch_size,-1), y_batch_eval[:,:,1:-1,1:-1].contiguous().view(args.eval_batch_size,-1))

            mae = torch.mean(torch.abs(logits[:, :, 1:-1, 1:-1] -
                                       y_batch_eval[:, :, 1:-1, 1:-1]),
                             dim=(1, 2, 3))
            normmae = mae / torch.mean(
                torch.abs(y_batch_eval[:, :, 1:-1, 1:-1]), dim=(1, 2, 3))
            # Hy MAE
            Hy_mae_list_eval[start_idx:end_idx, :] = normmae.reshape((-1, 1))
            # print("Hy MAE: ", torch.mean(Hy_mae_list_eval))

            # rescale the 0/1 pattern into dielectric constant
            pattern = (x_batch_eval * (n_Si - n_air) + n_air)**2
            # re-attach the ground-truth boundary ring around the predicted interior
            fields = logits[:, :, 1:-1, :]
            fields = torch.cat((y_batch_eval[:, :, 0:1, :], fields, y_batch_eval[:, :, -1:, :]), dim=2)
            fields = fields[:, :, :, 1:-1]
            fields = torch.cat((y_batch_eval[:, :, :, 0:1], fields, y_batch_eval[:, :, :, -1:]), dim=3)
            # H_to_H presumably applies the finite-difference Helmholtz update
            # to the rescaled field, so the residual below measures how
            # self-consistent the prediction is with the discretized PDE
            FD_Hy = H_to_H(-fields[:, 0] * means[0], -fields[:, 1] * means[1],
                           yeex, yeey, dL, omega, pattern)
            # report the raw residual; reg_norm is not defined in this script
            phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                      logits[:, 0, 1:-1, 1:-1])
            phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                      logits[:, 1, 1:-1, 1:-1])
            normphysreg = 0.5 * (phys_regR + phys_regI)
            H2nd_physreg_list_eval[start_idx:end_idx, :] = normphysreg.reshape(
                (-1, 1))

    # df = pd.DataFrame(columns=['normmae_mean_H','normmae_mean_Ex', 'normmae_mean_Ez', 'normmae_mean_avg', 'normphysreg_mean_H', 'normphysreg_mean_E', 'eff_error_mean'])
    df = pd.DataFrame(columns=['normmae_mean_H', 'normphysreg_mean_H'])

    # DataFrame.append was removed in pandas 2.0; build a one-row frame instead
    df = pd.concat([
        df,
        pd.DataFrame([{
            'normmae_mean_H': np.mean(Hy_mae_list_eval.numpy()),
            # 'normmae_mean_Ex': np.mean(Ex_mae_list_eval.numpy()),
            # 'normmae_mean_Ez': np.mean(Ez_mae_list_eval.numpy()),
            # 'normmae_mean_avg': np.mean(1/3*(Hy_mae_list_eval + Ex_mae_list_eval + Ez_mae_list_eval).numpy()),
            'normphysreg_mean_H': np.mean(H2nd_physreg_list_eval.numpy()),
            # 'normphysreg_mean_E': np.mean(E2nd_physreg_list_eval.numpy()),
            # 'eff_error_mean': np.mean(eff_error.cpu().numpy())
        }])
    ], ignore_index=True)

    df.to_csv(model_path + '/' + 'eval.csv', index=False)
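

# NOTE: each of the following `def main(args)` definitions shadows the one
# before it, so only the last definition in this file is live at import time.
# --- Variant: FNO-only evaluation that normalizes the boundary inputs by
# their mean magnitude before the forward pass and rescales the output back.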
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = FNO_multimodal_2d(args).to(device)
    model_path1 = args.model_saving_path + args.model_name + "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    model_path2 = args.model_saving_path + args.model_name + \
                                          "_domain_size_" + str(args.domain_sizex) + "_"+ str(args.domain_sizey) + \
                                          "_fmodes_" + str(args.f_modes) + \
                                          "_flayers_" + str(args.num_fourier_layers) + \
                                          "_Hidden_" + str(args.HIDDEN_DIM) + \
                                          "_f_padding_" + str(args.f_padding) + \
                                          "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)

    if os.path.isdir(model_path1):
        model_path = model_path1
    elif os.path.isdir(model_path2):
        model_path = model_path2
    else:
        raise FileNotFoundError("no folder found for path: " + model_path1 +
                                " or " + model_path2)
    
    print("Restoring weights from ", model_path+"/best_model.pt", flush=True)
    checkpoint = torch.load(model_path+"/best_model.pt")
    start_epoch=checkpoint['epoch']
    model.load_state_dict(checkpoint['model'].state_dict())
    model.lr_scheduler = checkpoint['lr_scheduler']
    model.optimizer = checkpoint['optimizer']
    # df = pd.read_csv(model_path + '/'+'df.csv')
        
    eval_ds = SimulationDataset(args.data_folder, total_sample_number = args.total_sample_number)
    torch.manual_seed(42)
    print("creating Dataloader:")
    eval_loader = DataLoader(eval_ds, batch_size=args.eval_batch_size, shuffle=True, num_workers=0)

    total_num = len(eval_ds)
    print("total_num: ", total_num)

    Hy_mae_list_eval = torch.zeros((total_num,1))
    Ex_mae_list_eval = torch.zeros((total_num,1))
    Ez_mae_list_eval = torch.zeros((total_num,1))
    H2nd_physreg_list_eval = torch.zeros((total_num,1))
    E2nd_physreg_list_eval = torch.zeros((total_num,1))

    x_eval = torch.zeros((total_num,256))
    eff_eval = torch.zeros((total_num,1))
    eff_pred_eval = torch.zeros((total_num,1))
    eff_error_eval = torch.zeros((total_num,1))

    div_num = args.eval_batch_size
    for loop_idx, sample_batched in enumerate(eval_loader):
        start_idx = loop_idx * div_num
        # clamp the last batch, which may be smaller than div_num
        end_idx = min(start_idx + div_num, total_num)
        
        print(start_idx, end_idx, flush=True)
        x_batch_eval, y_batch_eval = sample_batched['structure'].to(device), sample_batched['field'].to(device)
        top_bc_eval, bottom_bc_eval = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
        left_bc_eval, right_bc_eval = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
        yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)
                    
        with torch.no_grad():
            # pattern = (x_batch_eval*(n_Si - n_air) + n_air)**2
            pattern = x_batch_eval
            bc_mean = 1/4*( torch.mean(torch.abs(top_bc_eval),dim=(2,3),keepdim=True) +
                            torch.mean(torch.abs(bottom_bc_eval),dim=(2,3),keepdim=True) + 
                            torch.mean(torch.abs(left_bc_eval),dim=(2,3),keepdim=True) + 
                            torch.mean(torch.abs(right_bc_eval),dim=(2,3),keepdim=True))
            # with autocast():
            logits = model(pattern, top_bc_eval/bc_mean, bottom_bc_eval/bc_mean, left_bc_eval/bc_mean, right_bc_eval/bc_mean).reshape(y_batch_eval.shape)*bc_mean
            
            # logits = model(pattern, top_bc_eval, bottom_bc_eval, left_bc_eval, right_bc_eval).reshape(y_batch_eval.shape)
            
            # normmae = model.loss_fn(logits[:,:,1:-1,1:-1].contiguous().view(args.eval_batch_size,-1), y_batch_eval[:,:,1:-1,1:-1].contiguous().view(args.eval_batch_size,-1))
            
            mae = torch.mean(torch.abs(logits[:,:,1:-1,1:-1] - y_batch_eval[:,:,1:-1,1:-1]),dim=(1,2,3))
            normmae = mae/torch.mean(torch.abs(y_batch_eval[:,:,1:-1,1:-1]),dim=(1,2,3))
            # Hy MAE
            Hy_mae_list_eval[start_idx:end_idx,:]=normmae.reshape((-1,1))
            # print("Hy MAE: ", torch.mean(Hy_mae_list_eval))

            # fields = logits[:,:,1:-1, :]
            # fields = torch.cat((y_batch_eval[:, :, 0:1, :], fields, y_batch_eval[:, :, -1:, :]), dim=2);
            # fields = fields[:,:,:,1:-1]
            # fields = torch.cat((y_batch_eval[:, :, :, 0:1], fields, y_batch_eval[:, :, :, -1:]), dim=3);
            fields = logits
            FD_Hy = H_to_H(-fields[:, 0], -fields[:, 1], yeex, yeey, dL, wl=1050e-9)
            phys_regR = model.loss_fn(FD_Hy[:, 0], logits[:, 0, 1:-1, 1:-1])
            phys_regI = model.loss_fn(FD_Hy[:, 1], logits[:, 1, 1:-1, 1:-1])
            normphysreg = 0.5 * (phys_regR + phys_regI)
            H2nd_physreg_list_eval[start_idx:end_idx, :] = normphysreg.reshape((-1, 1))
        
    # df = pd.DataFrame(columns=['normmae_mean_H','normmae_mean_Ex', 'normmae_mean_Ez', 'normmae_mean_avg', 'normphysreg_mean_H', 'normphysreg_mean_E', 'eff_error_mean'])
    df = pd.DataFrame(columns=['normmae_mean_H','normphysreg_mean_H'])
  
    # DataFrame.append was removed in pandas 2.0; build a one-row frame instead
    df = pd.concat([
        df,
        pd.DataFrame([{
            'normmae_mean_H': np.mean(Hy_mae_list_eval.numpy()),
            # 'normmae_mean_Ex': np.mean(Ex_mae_list_eval.numpy()),
            # 'normmae_mean_Ez': np.mean(Ez_mae_list_eval.numpy()),
            # 'normmae_mean_avg': np.mean(1/3*(Hy_mae_list_eval + Ex_mae_list_eval + Ez_mae_list_eval).numpy()),
            'normphysreg_mean_H': np.mean(H2nd_physreg_list_eval.numpy()),
            # 'normphysreg_mean_E': np.mean(E2nd_physreg_list_eval.numpy()),
            # 'eff_error_mean': np.mean(eff_error.cpu().numpy())
        }])
    ], ignore_index=True)

    df.to_csv(model_path + '/'+'eval.csv',index=False)
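

# --- Variant: qualitative evaluation that plots the structure, ground-truth
# fields, predictions, and the model's boundary output for the first batches.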
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        raise NotImplementedError(
            f"architecture {args.arch} hasn't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        raise FileNotFoundError("no folder found for path: " + model_path)

    print("Restoring weights from ", model_path + "/last_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/last_model.pt")
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'].state_dict())
    model.lr_scheduler = checkpoint['lr_scheduler']
    model.optimizer = checkpoint['optimizer']
    # df = pd.read_csv(model_path + '/'+'df.csv')

    eval_ds = SimulationDataset(args.data_folder,
                                total_sample_number=args.total_sample_number)
    means = [1e-3, 1e-3]
    torch.manual_seed(42)
    eval_loader = DataLoader(eval_ds,
                             batch_size=args.eval_batch_size,
                             shuffle=True,
                             num_workers=0)

    total_num = len(eval_ds)
    print("total_num: ", total_num)

    count = 0
    for sample_batched in eval_loader:
        count += 1
        x_batch_eval, y_batch_eval = sample_batched['structure'].to(device), sample_batched['field'].to(device)
        top_bc_eval, bottom_bc_eval = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
        left_bc_eval, right_bc_eval = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
        yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)
        with torch.no_grad():
            # pattern = (x_batch_eval*(n_Si - n_air) + n_air)**2
            # logits, bc = model(pattern, top_bc_eval, bottom_bc_eval, left_bc_eval, right_bc_eval)

            logits, bc = model(x_batch_eval, top_bc_eval, bottom_bc_eval,
                               left_bc_eval, right_bc_eval)

            logits = logits.reshape(y_batch_eval.shape)

            # pattern = (x_batch_eval*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
            # fields = logits[:,:,1:-1, :]
            # fields = torch.cat((y_batch_eval[:, :, 0:1, :], fields, y_batch_eval[:, :, -1:, :]), dim=2);
            # fields = fields[:,:,:,1:-1]
            # fields = torch.cat((y_batch_eval[:, :, :, 0:1], fields, y_batch_eval[:, :, :, -1:]), dim=3);
            # FD_Hy = H_to_H(-fields[:, 0]*means[0], -fields[:, 1]*means[1], yeex, yeey, dL, omega, pattern);
            # phys_regR = model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0, 1:-1, 1:-1]);
            # phys_regI = model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1, 1:-1, 1:-1]);

            fig, axs = plt.subplots(7)
            im = axs[0].imshow(x_batch_eval.cpu().numpy()[0, 0, :, :])
            plt.colorbar(im, ax=axs[0])
            im = axs[1].imshow(y_batch_eval.cpu().numpy()[0, 0, :, :])
            plt.colorbar(im, ax=axs[1])
            im = axs[2].imshow(y_batch_eval.cpu().numpy()[0, 1, :, :])
            plt.colorbar(im, ax=axs[2])
            im = axs[3].imshow(logits.cpu().numpy()[0, 0, :, :])
            plt.colorbar(im, ax=axs[3])
            im = axs[4].imshow(logits.cpu().numpy()[0, 1, :, :])
            plt.colorbar(im, ax=axs[4])
            im = axs[5].imshow(bc.cpu().numpy()[0, 0, :, :])
            plt.colorbar(im, ax=axs[5])
            im = axs[6].imshow(bc.cpu().numpy()[0, 1, :, :])
            plt.colorbar(im, ax=axs[6])
            fig.savefig("test" + str(count) + ".png")

            # np.save("/home/users/chenkaim/scripts/models/MAML_EM_simulation/data_gen/boundary_CG_ieterative_gen/x_batch_eval_"+str(count)+".npy", x_batch_eval.cpu().numpy())
            # np.save("/home/users/chenkaim/scripts/models/MAML_EM_simulation/data_gen/boundary_CG_ieterative_gen/y_batch_eval_"+str(count)+".npy", y_batch_eval.cpu().numpy())
            # np.save("/home/users/chenkaim/scripts/models/MAML_EM_simulation/data_gen/boundary_CG_ieterative_gen/logits_"+str(count)+".npy", logits.cpu().numpy())

        if count > 3:
            break
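

# --- Variant: qualitative evaluation of the boundary-normalized FNO, plotting
# per-channel absolute-error maps alongside the fields.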
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = FNO_multimodal_2d(args).to(device)
    model_path1 = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr)
    model_path2 = args.model_saving_path + args.model_name + \
                                          "_domain_size_" + str(args.domain_sizex) + "_"+ str(args.domain_sizey) + \
                                          "_fmodes_" + str(args.f_modes) + \
                                          "_flayers_" + str(args.num_fourier_layers) + \
                                          "_Hidden_" + str(args.HIDDEN_DIM) + \
                                          "_f_padding_" + str(args.f_padding) + \
                                          "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)

    if os.path.isdir(model_path1):
        model_path = model_path1
    elif os.path.isdir(model_path2):
        model_path = model_path2
    else:
        raise FileNotFoundError("no folder found for path: " + model_path1 +
                                " or " + model_path2)

    print("Restoring weights from ", model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/best_model.pt")
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'].state_dict())
    model.lr_scheduler = checkpoint['lr_scheduler']
    model.optimizer = checkpoint['optimizer']
    # df = pd.read_csv(model_path + '/'+'df.csv')

    eval_ds = SimulationDataset(args.data_folder,
                                total_sample_number=args.total_sample_number)
    torch.manual_seed(123)
    eval_loader = DataLoader(eval_ds,
                             batch_size=args.eval_batch_size,
                             shuffle=True,
                             num_workers=0)

    total_num = len(eval_ds)
    print("total_num: ", total_num)

    count = 0
    for sample_batched in eval_loader:
        count += 1
        x_batch_eval, y_batch_eval = sample_batched['structure'].to(device), sample_batched['field'].to(device)
        top_bc_eval, bottom_bc_eval = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
        left_bc_eval, right_bc_eval = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
        yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)
        with torch.no_grad():
            # pattern = (x_batch_eval*(n_Si - n_air) + n_air)**2
            pattern = x_batch_eval
            bc_mean = 1 / 4 * (
                torch.mean(torch.abs(top_bc_eval), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(bottom_bc_eval), dim=(2, 3),
                           keepdim=True) +
                torch.mean(torch.abs(left_bc_eval), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(right_bc_eval), dim=(2, 3), keepdim=True))
            logits = model(pattern, top_bc_eval / bc_mean,
                           bottom_bc_eval / bc_mean, left_bc_eval / bc_mean,
                           right_bc_eval / bc_mean)

            logits = logits.reshape(y_batch_eval.shape) * bc_mean

            # pattern = (x_batch_eval*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
            # fields = logits[:,:,1:-1, :]
            # fields = torch.cat((y_batch_eval[:, :, 0:1, :], fields, y_batch_eval[:, :, -1:, :]), dim=2);
            # fields = fields[:,:,:,1:-1]
            # fields = torch.cat((y_batch_eval[:, :, :, 0:1], fields, y_batch_eval[:, :, :, -1:]), dim=3);
            # FD_Hy = H_to_H(-fields[:, 0]*means[0], -fields[:, 1]*means[1], yeex, yeey, dL, omega, pattern);
            # phys_regR = model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0, 1:-1, 1:-1]);
            # phys_regI = model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1, 1:-1, 1:-1]);

            fig, axs = plt.subplots(7)
            im = axs[0].imshow(x_batch_eval.cpu().numpy()[0, 0, 1:-1, 1:-1])
            plt.colorbar(im, ax=axs[0])
            im = axs[1].imshow(y_batch_eval.cpu().numpy()[0, 0, 1:-1, 1:-1])
            plt.colorbar(im, ax=axs[1])
            im = axs[2].imshow(y_batch_eval.cpu().numpy()[0, 1, 1:-1, 1:-1])
            plt.colorbar(im, ax=axs[2])
            im = axs[3].imshow(logits.cpu().numpy()[0, 0, 1:-1, 1:-1])
            plt.colorbar(im, ax=axs[3])
            im = axs[4].imshow(logits.cpu().numpy()[0, 1, 1:-1, 1:-1])
            plt.colorbar(im, ax=axs[4])
            im = axs[5].imshow(
                torch.abs(y_batch_eval - logits).cpu().numpy()[0, 0, 1:-1,
                                                               1:-1])
            plt.colorbar(im, ax=axs[5])
            axs[5].set_title("loss: " + str(
                torch.mean((torch.abs(y_batch_eval - logits)) /
                           torch.mean(torch.abs(y_batch_eval))).item()),
                             fontsize=4)
            im = axs[6].imshow(
                torch.abs(y_batch_eval - logits).cpu().numpy()[0, 1, 1:-1,
                                                               1:-1])
            plt.colorbar(im, ax=axs[6])
            axs[6].set_title("loss: " + str(
                torch.mean((torch.abs(y_batch_eval - logits)) /
                           torch.mean(torch.abs(y_batch_eval))).item()),
                             fontsize=4)
            fig.savefig("test" + str(count) + ".png",
                        dpi=500,
                        transparent=True)

            # np.save("/home/users/chenkaim/scripts/models/MAML_EM_simulation/data_gen/boundary_CG_ieterative_gen/x_batch_eval_"+str(count)+".npy", x_batch_eval.cpu().numpy())
            # np.save("/home/users/chenkaim/scripts/models/MAML_EM_simulation/data_gen/boundary_CG_ieterative_gen/y_batch_eval_"+str(count)+".npy", y_batch_eval.cpu().numpy())
            # np.save("/home/users/chenkaim/scripts/models/MAML_EM_simulation/data_gen/boundary_CG_ieterative_gen/logits_"+str(count)+".npy", logits.cpu().numpy())

        if count > 20:
            break
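

# --- Variant: training loop that optimizes the physics residual, with
# regConstScheduler rebalancing the physics weight each epoch.

# regConstScheduler itself is imported from elsewhere in this repo. Below is a
# minimal sketch, under the hypothetical name _example_reg_const_scheduler, of
# one plausible implementation: it keeps the physics term at a fixed fraction
# of the previous epoch's data loss. This is an assumption for illustration,
# not the actual scheduler.
def _example_reg_const_scheduler(epoch, args, last_data_loss, last_phys_loss):
    # Until a meaningful residual estimate exists, fall back to a unit weight.
    if epoch == 0 or float(last_phys_loss) == 0.0:
        return 1.0
    # Scale so that reg_norm * phys_loss is roughly 0.1 * data_loss.
    return 0.1 * float(last_data_loss) / float(last_phys_loss)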
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        raise NotImplementedError(
            f"architecture {args.arch} hasn't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        os.makedirs(model_path)

    ds = SimulationDataset(args.data_folder,
                           total_sample_number=args.total_sample_number)
    # fixed normalization constants; alternatively ds.means would give
    # [Hy_meanR, Hy_meanI, Ex_meanR, Ex_meanI, Ez_meanR, Ez_meanI]
    means = [1e-3, 1e-3]
    print("means: ", means)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print(
        "total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f"
        % (len(train_ds), len(test_ds), train_mean, test_mean),
        flush=True)

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()

    df = pd.DataFrame(columns=[
        'epoch', 'train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'
    ])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch = 0
    if (args.continue_train):
        print("Restoring weights from ",
              model_path + "/last_model.pt",
              flush=True)
        checkpoint = torch.load(model_path + "/last_model.pt")
        start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/' + 'df.csv')

    # scaler = GradScaler()

    best_loss = 1e4
    last_epoch_data_loss = 0.1
    last_epoch_physical_loss = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm = regConstScheduler(step, args, last_epoch_data_loss,
                                     last_epoch_physical_loss)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()

            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            top_bc_train, bottom_bc_train = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
            left_bc_train, right_bc_train = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
            yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)

            # with autocast():
            logits = model(x_batch_train, top_bc_train, bottom_bc_train,
                           left_bc_train,
                           right_bc_train).reshape(y_batch_train.shape)
            # data loss against the ground truth (disabled; this variant trains
            # on the physics residual alone)
            # loss = model.loss_fn(logits.contiguous().view(args.batch_size,-1), y_batch_train.contiguous().view(args.batch_size,-1))

            # Calculate the physical residual
            # rescale the 0/1 pattern into dielectric constant
            pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
            # re-attach the ground-truth boundary ring around the predicted interior
            fields = logits[:, :, 1:-1, :]
            fields = torch.cat((y_batch_train[:, :, 0:1, :], fields, y_batch_train[:, :, -1:, :]), dim=2)
            fields = fields[:, :, :, 1:-1]
            fields = torch.cat((y_batch_train[:, :, :, 0:1], fields, y_batch_train[:, :, :, -1:]), dim=3)
            FD_Hy = H_to_H(-fields[:, 0] * means[0], -fields[:, 1] * means[1],
                           yeex, yeey, dL, omega, pattern)
            phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                      logits[:, 0, 1:-1, 1:-1]) * reg_norm
            phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                      logits[:, 1, 1:-1, 1:-1]) * reg_norm
            loss = phys_regR + phys_regI

            # scaler.scale(loss).backward()
            # scaler.step(model.optimizer)
            # scaler.update()

            loss.backward()
            model.optimizer.step()

        #Save the weights at the end of each epoch
        checkpoint = {
            'epoch': step,
            'model': model,
            'optimizer': model.optimizer,
            'lr_scheduler': model.lr_scheduler
        }
        torch.save(checkpoint, model_path + "/last_model.pt")

        # evaluation
        train_loss = 0
        train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            top_bc_train, bottom_bc_train = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
            left_bc_train, right_bc_train = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
            yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)

            with torch.no_grad():
                logits = model(x_batch_train, top_bc_train, bottom_bc_train,
                               left_bc_train,
                               right_bc_train).reshape(y_batch_train.shape)

                loss = model.loss_fn(
                    logits[:, :, 1:-1,
                           1:-1].contiguous().view(args.batch_size, -1),
                    y_batch_train[:, :, 1:-1,
                                  1:-1].contiguous().view(args.batch_size, -1))

                # Calculate physical residue
                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                fields = logits[:, :, 1:-1, :]
                fields = torch.cat((y_batch_train[:, :, 0:1, :], fields,
                                    y_batch_train[:, :, -1:, :]),
                                   dim=2)
                fields = fields[:, :, :, 1:-1]
                fields = torch.cat(
                    (y_batch_train[:, :, :,
                                   0:1], fields, y_batch_train[:, :, :, -1:]),
                    dim=3)
                FD_Hy = H_to_H(-fields[:, 0] * means[0],
                               -fields[:, 1] * means[1], yeex, yeey, dL, omega,
                               pattern)
                # raw physics residual; reg_norm is applied once after averaging
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                          logits[:, 0, 1:-1, 1:-1])
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                          logits[:, 1, 1:-1, 1:-1])

                #loss = loss + phys_reg1 + phys_reg2 + phys_reg3;
                train_loss += loss
                train_phys_reg += 0.5 * (phys_regR + phys_regI)

        train_loss /= len(train_loader)
        train_phys_reg /= len(train_loader)
        train_phys_reg *= reg_norm

        test_loss = 0
        test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            top_bc_test, bottom_bc_test = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
            left_bc_test, right_bc_test = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
            yeex, yeey = sample_batched['yeex'].to(device), sample_batched['yeey'].to(device)

            with torch.no_grad():
                logits = model(x_batch_test, top_bc_test, bottom_bc_test,
                               left_bc_test,
                               right_bc_test).reshape(y_batch_test.shape)

                loss = model.loss_fn(
                    logits[:, :, 1:-1,
                           1:-1].contiguous().view(args.batch_size, -1),
                    y_batch_test[:, :, 1:-1,
                                 1:-1].contiguous().view(args.batch_size, -1))

                # Calculate physical residue
                pattern = (x_batch_test * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                fields = logits[:, :, 1:-1, :]
                fields = torch.cat(
                    (y_batch_test[:, :, 0:1, :], fields, y_batch_test[:, :,
                                                                      -1:, :]),
                    dim=2)
                fields = fields[:, :, :, 1:-1]
                fields = torch.cat(
                    (y_batch_test[:, :, :, 0:1], fields, y_batch_test[:, :, :,
                                                                      -1:]),
                    dim=3)
                FD_Hy = H_to_H(-fields[:, 0] * means[0],
                               -fields[:, 1] * means[1], yeex, yeey, dL, omega,
                               pattern)
                # raw physics residual; it was previously scaled here *and*
                # again after averaging, which double-counted reg_norm
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                          logits[:, 0, 1:-1, 1:-1])
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                          logits[:, 1, 1:-1, 1:-1])

                #loss = loss + phys_reg1 + phys_reg2 + phys_reg3;
                test_loss += loss
                test_phys_reg += 0.5 * (phys_regR + phys_regI)

        test_loss /= len(test_loader)
        test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        # keep the unscaled residual for regConstScheduler; scale once for reporting
        last_epoch_physical_loss = test_phys_reg.detach().clone()
        test_phys_reg *= reg_norm

        print(
            'train loss: %.5f, train phys reg: %.5f, test loss: %.5f, test phys reg: %.5f, last_physical_loss: %.5f'
            % (train_loss, train_phys_reg, test_loss, test_phys_reg,
               last_epoch_physical_loss),
            flush=True)

        model.lr_scheduler.step()

        # DataFrame.append was removed in pandas 2.0; build a one-row frame instead
        df = pd.concat([
            df,
            pd.DataFrame([{
                'epoch': step + 1,
                'lr': str(model.lr_scheduler.get_last_lr()),
                'train_loss': train_loss.item(),
                'train_phys_reg': train_phys_reg.item(),
                'test_loss': test_loss.item(),
                'test_phys_reg': test_phys_reg.item(),
            }])
        ], ignore_index=True)

        df.to_csv(model_path + '/' + 'df.csv', index=False)

        if (test_loss < best_loss):
            best_loss = test_loss
            checkpoint = {
                'epoch': step,
                'model': model,
                'optimizer': model.optimizer,
                'lr_scheduler': model.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/best_model.pt")
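

# A minimal entry-point sketch, assuming the usual argparse wiring for this
# script; the flag names mirror the attributes read above (args.arch, args.lr,
# args.epoch, ...), but the defaults here are assumptions, not the repo's.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--arch", type=str, default="Fourier")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--exp_decay", type=float, default=0.98)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=32)
    parser.add_argument("--total_sample_number", type=int, default=None)
    parser.add_argument("--data_folder", type=str, required=True)
    parser.add_argument("--model_saving_path", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--continue_train", action="store_true")
    main(parser.parse_args())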