def main(args):
    """Evaluate a trained single-shot subdomain model on scaled samples.

    Restores the checkpointed model whose directory name encodes the
    hyper-parameters in ``args``, runs one batch of ``args.plot_figs`` samples
    through ``scale_four_point_interp`` -> model -> ``scale_back``, and saves a
    comparison figure ``eval_<i>.png`` for each sample.

    Args:
        args: argparse-style namespace with model/architecture, path, and
            plotting options (``arch``, ``model_saving_path``, ``plot_figs``,
            ``scale_sx``/``scale_sy``, ``starting_x``/``starting_y``, ...).

    Raises:
        ValueError: if ``args.arch`` is not a known architecture.
        FileNotFoundError: if no saved-model directory can be located.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        # BUG FIX: original `raise ("architecture {args.arch} ...")` raised a
        # plain str (a TypeError in Python 3) and was missing the f-prefix, so
        # the architecture name was never interpolated.
        raise ValueError(f"architecture {args.arch} hasn't been added!!")

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    print('Total trainable tensors:', num, flush=True)

    # Two historical directory-naming schemes are supported; prefer the newer
    # (hyper-parameter-rich) one and fall back to the older one.
    model_path1 = args.model_saving_path + args.model_name + \
        "_domain_size_" + str(args.domain_sizex) + "_" + str(args.domain_sizey) + \
        "_fmodes_" + str(args.f_modes) + \
        "_flayers_" + str(args.num_fourier_layers) + \
        "_Hidden_" + str(args.HIDDEN_DIM) + \
        "_f_padding_" + str(args.f_padding) + \
        "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    model_path2 = args.model_saving_path + args.model_name + \
        "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    model_path = model_path1
    if not os.path.isdir(model_path1):
        if not os.path.isdir(model_path2):
            # BUG FIX: original was `raise ("model path not found: ".model_path)`
            # -- an attribute lookup on a str literal, which raises
            # AttributeError instead of a meaningful error.
            raise FileNotFoundError("model path not found: " + model_path1)
        else:
            model_path = model_path2

    # load model:
    print("Restoring weights from ", model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/best_model.pt")
    # NOTE(review): the checkpoint stores the full model object (not just a
    # state_dict), so this replaces the freshly constructed model above.
    model = checkpoint['model']

    ds = DDM_Dataset(args.data_folder,
                     total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=args.plot_figs, shuffle=True,
                            num_workers=0)

    size_x = args.scale_sx
    size_y = args.scale_sy
    loss = []
    for sample_batched in DDM_loader:
        with torch.no_grad():
            DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']
            (top_bc_train, bottom_bc_train, left_bc_train, right_bc_train,
             x_batch_train, intep_field) = scale_four_point_interp(
                 DDM_img, DDM_Hy, args, device)
            # Mean |boundary| over the four edges: normalizes the network's
            # boundary inputs and rescales its output.
            bc_mean = 1 / 4 * (
                torch.mean(torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(left_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
            logits = model(x_batch_train / 4,
                           left_bc_train / bc_mean,
                           right_bc_train / bc_mean,
                           top_bc_train / bc_mean,
                           bottom_bc_train / bc_mean).reshape(
                               intep_field.shape) * bc_mean
            # reconstruct the whole field
            scaled_result = scale_back(logits.cpu(), args, device)
            # Per-sample mean absolute error against the ground-truth window.
            loss = torch.mean(torch.abs(
                scaled_result.contiguous().view(args.plot_figs, -1) -
                DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                       args.starting_y:args.starting_y + size_y]
                .contiguous().view(args.plot_figs, -1)),
                dim=1).numpy()
        break  # only the first batch is visualized

    for i in range(args.plot_figs):
        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(5, 2)
        axs[0, 0].imshow(DDM_img.cpu().numpy()[
            i, 0, args.starting_x:args.starting_x + size_x,
            args.starting_y:args.starting_y + size_y])
        axs[0, 1].imshow(x_batch_train.cpu().numpy()[i, 0, :, :])
        axs[1, 0].imshow(DDM_Hy.cpu().numpy()[
            i, 0, args.starting_x:args.starting_x + size_x,
            args.starting_y:args.starting_y + size_y])
        axs[1, 1].imshow(DDM_Hy.cpu().numpy()[
            i, 1, args.starting_x:args.starting_x + size_x,
            args.starting_y:args.starting_y + size_y])
        im = axs[2, 0].imshow(intep_field.cpu().numpy()[i, 0, :, :])
        axs[2, 0].set_title("scaled_gt", fontsize=4)
        plt.colorbar(im, ax=axs[2, 0])
        im = axs[2, 1].imshow(intep_field.cpu().numpy()[i, 1, :, :])
        # BUG FIX: original set axs[2, 0]'s title twice; this call clearly
        # belongs to the right-hand panel.
        axs[2, 1].set_title("scaled_gt", fontsize=4)
        plt.colorbar(im, ax=axs[2, 1])
        im = axs[3, 0].imshow(logits.cpu().numpy()[i, 0, :, :])
        axs[3, 0].set_title("logits", fontsize=4)
        plt.colorbar(im, ax=axs[3, 0])
        im = axs[3, 1].imshow(logits.cpu().numpy()[i, 1, :, :])
        axs[3, 1].set_title("logits", fontsize=4)
        plt.colorbar(im, ax=axs[3, 1])
        im = axs[4, 0].imshow(scaled_result.cpu().numpy()[i, 0, :, :])
        axs[4, 0].set_title("loss: " + str(loss[i]), fontsize=4)
        plt.colorbar(im, ax=axs[4, 0])
        im = axs[4, 1].imshow(scaled_result.cpu().numpy()[i, 1, :, :])
        axs[4, 1].set_title("loss: " + str(loss[i]), fontsize=4)
        plt.colorbar(im, ax=axs[4, 1])
        fig.savefig("eval_" + str(i) + ".png", dpi=1000)
def main(args):
    """Iteratively evaluate a trained subdomain model with DDM boundary updates.

    Restores the checkpointed model, splits each sample into an overlapping
    patch grid, then runs ``args.DDM_iters`` Schwarz-style iterations,
    exchanging Robin boundary data between patches via ``new_iter_bcs``.
    Saves one figure per sample (``eval_<sample_id>.png``, first 3 samples).

    Raises:
        ValueError: if ``args.arch`` is not a known architecture.
        FileNotFoundError: if the saved-model directory does not exist.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        # BUG FIX: original raised a bare str (TypeError in Python 3) and
        # was missing the f-prefix, so {args.arch} was never interpolated.
        raise ValueError(f"architecture {args.arch} hasn't been added!!")

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    print('Total trainable tensors:', num, flush=True)

    model_path = args.model_saving_path + args.model_name + \
        "_domain_size_" + str(args.domain_sizex) + "_" + str(args.domain_sizey) + \
        "_fmodes_" + str(args.f_modes) + \
        "_flayers_" + str(args.num_fourier_layers) + \
        "_Hidden_" + str(args.HIDDEN_DIM) + \
        "_f_padding_" + str(args.f_padding) + \
        "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        # BUG FIX: original `raise ("model path not found: ".model_path)` was
        # an attribute lookup on a str literal (AttributeError).
        raise FileNotFoundError("model path not found: " + model_path)

    # load model:
    print("Restoring weights from ", model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/best_model.pt")
    model = checkpoint['model']

    ds = DDM_Dataset(args.data_folder,
                     total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Patch stride and full reconstructed-window extent.
    stride_x = args.domain_sizex - args.overlap_pixels
    stride_y = args.domain_sizey - args.overlap_pixels
    size_x = args.domain_sizex + (args.x_patches - 1) * stride_x
    size_y = args.domain_sizey + (args.y_patches - 1) * stride_y

    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id > 2:
            break
        with torch.no_grad():
            DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

            # prepare the input batched subdomains to model:
            model_bs = args.x_patches * args.y_patches
            x_batch_train = [
                DDM_img[0, 0,
                        args.starting_x + i * stride_x:
                        args.starting_x + args.domain_sizex + i * stride_x,
                        args.starting_y + j * stride_y:
                        args.starting_y + args.domain_sizey + j * stride_y]
                for i in range(args.x_patches) for j in range(args.y_patches)]
            # BUG FIX: original reshaped to (..., domain_sizex, domain_sizex);
            # that only works for square subdomains and crashes otherwise.
            x_batch_train = torch.stack(x_batch_train).reshape(
                model_bs, 1, args.domain_sizex, args.domain_sizey).to(device)
            y_batch_train = [
                DDM_Hy[0, :,
                       args.starting_x + i * stride_x:
                       args.starting_x + args.domain_sizex + i * stride_x,
                       args.starting_y + j * stride_y:
                       args.starting_y + args.domain_sizey + j * stride_y]
                for i in range(args.x_patches) for j in range(args.y_patches)]
            # BUG FIX: same square-only reshape defect as above.
            y_batch_train = torch.stack(y_batch_train).reshape(
                model_bs, 2, args.domain_sizex, args.domain_sizey).to(device)

            (top_bc_train, bottom_bc_train, left_bc_train, right_bc_train,
             intep_field) = init_four_point_interp(DDM_Hy, prop2, args, device)

            plt.rcParams["font.size"] = "5"
            fig, axs = plt.subplots(int(args.DDM_iters / 5) + 3, 4)
            axs[0, 0].imshow(DDM_img[0, 0,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
            axs[0, 1].imshow(DDM_img[0, 0,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
            axs[1, 0].imshow(DDM_Hy[0, 0,
                                    args.starting_x:args.starting_x + size_x,
                                    args.starting_y:args.starting_y + size_y])
            axs[1, 1].imshow(DDM_Hy[0, 1,
                                    args.starting_x:args.starting_x + size_x,
                                    args.starting_y:args.starting_y + size_y])
            axs[2, 0].imshow(intep_field[0, :, :])
            axs[2, 1].imshow(intep_field[1, :, :])

            for k in range(args.DDM_iters):
                # Mean |boundary| over the four edges, used to normalize the
                # network's boundary inputs and rescale its output.
                bc_mean = 1 / 4 * (
                    torch.mean(torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(torch.abs(left_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
                logits = model(x_batch_train,
                               left_bc_train / bc_mean,
                               right_bc_train / bc_mean,
                               top_bc_train / bc_mean,
                               bottom_bc_train / bc_mean).reshape(
                                   y_batch_train.shape) * bc_mean

                # reconstruct the whole field
                intermediate_result = reconstruct(logits, args)
                loss = model.loss_fn(
                    intermediate_result.contiguous().view(1, -1),
                    DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y + size_y]
                    .contiguous().view(1, -1))

                # Plot a snapshot every 5 iterations.
                if k % 5 == 0:
                    img_idx = int(k / 5)
                    im = axs[img_idx + 3, 0].imshow(intermediate_result[0, :, :])
                    axs[img_idx + 3, 0].set_title("loss: " + str(loss), fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 0])
                    im = axs[img_idx + 3, 1].imshow(intermediate_result[1, :, :])
                    axs[img_idx + 3, 1].set_title("loss: " + str(loss), fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 1])
                    im = axs[img_idx + 3, 2].imshow(
                        reconstruct(y_batch_train, args)[0, :, :])
                    axs[img_idx + 3, 2].set_title("y_batch_train: " + str(loss),
                                                  fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 2])
                    im = axs[img_idx + 3, 3].imshow(
                        reconstruct(y_batch_train, args)[1, :, :])
                    axs[img_idx + 3, 3].set_title("y_batch_train: " + str(loss),
                                                  fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 3])

                # Then prepare the data for next iteration:
                top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = \
                    new_iter_bcs(logits, y_batch_train, args,
                                 robin_transform_0th, device)
            # calculate the loss using the ground truth
            fig.savefig("eval_" + str(sample_id) + ".png", dpi=1000)
def main(args, solver):
    """Evaluate a direct (non-iterative) numerical solver on DDM samples.

    Builds the overlapping patch grid plus Yee-style half-grid permittivity
    averages, calls ``solver.direct_solve`` once per outer iteration, and
    plots the error of the reconstructed field against ground truth for the
    first ``args.num_device`` samples (``<sample_id>.png``).

    Args:
        args: namespace with patching/plotting options.
        solver: object providing ``direct_solve`` and ``reconstruct``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    stride_x = args.domain_sizex - args.overlap_pixels
    stride_y = args.domain_sizey - args.overlap_pixels
    size_x = args.domain_sizex + (args.x_patches - 1) * stride_x
    size_y = args.domain_sizey + (args.y_patches - 1) * stride_y
    # NOTE(review): device_losses is allocated but only written by
    # commented-out code below; kept for parity with the sibling variants.
    device_losses = np.zeros((args.num_device, args.DDM_iters))

    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id >= args.num_device:
            break
        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        model_bs = args.x_patches * args.y_patches
        x_batch_train = [
            DDM_img[0, 0,
                    args.starting_x + i * stride_x:
                    args.starting_x + args.domain_sizex + i * stride_x,
                    args.starting_y + j * stride_y:
                    args.starting_y + args.domain_sizey + j * stride_y]
            for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey)

        # Yee-grid x-face permittivity: average of vertically adjacent cells,
        # one extra row (domain_sizex + 1).
        yeex_batch_train = [
            1 / 2 * (DDM_img[0, 0,
                             -1 + args.starting_x + i * stride_x:
                             0 + args.starting_x + args.domain_sizex + i * stride_x,
                             0 + args.starting_y + j * stride_y:
                             0 + args.starting_y + args.domain_sizey + j * stride_y] +
                     DDM_img[0, 0,
                             -2 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             0 + args.starting_y + j * stride_y:
                             0 + args.starting_y + args.domain_sizey + j * stride_y])
            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(
            model_bs, 1, args.domain_sizex + 1, args.domain_sizey)

        # Yee-grid y-face permittivity: average of horizontally adjacent
        # cells, one extra column (domain_sizey + 1).
        yeey_batch_train = [
            1 / 2 * (DDM_img[0, 0,
                             0 + args.starting_x + i * stride_x:
                             0 + args.starting_x + args.domain_sizex + i * stride_x,
                             -1 + args.starting_y + j * stride_y:
                             0 + args.starting_y + args.domain_sizey + j * stride_y] +
                     DDM_img[0, 0,
                             0 + args.starting_x + i * stride_x:
                             0 + args.starting_x + args.domain_sizex + i * stride_x,
                             -2 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y])
            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey + 1)

        y_batch_train = [
            DDM_Hy[0, :,
                   args.starting_x + i * stride_x:
                   args.starting_x + args.domain_sizex + i * stride_x,
                   args.starting_y + j * stride_y:
                   args.starting_y + args.domain_sizey + j * stride_y]
            for i in range(args.x_patches) for j in range(args.y_patches)]
        # BUG FIX: original reshaped to (..., domain_sizex, domain_sizex);
        # that only works for square subdomains and crashes otherwise.
        y_batch_train = torch.stack(y_batch_train).reshape(
            model_bs, 2, args.domain_sizex, args.domain_sizey)

        intep_field, patched_solved = init_four_point_interp(DDM_Hy, prop2, args)

        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 2, 2)
        im = axs[0, 0].imshow(DDM_img[0, 0,
                                      args.starting_x:args.starting_x + size_x,
                                      args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[0, 0])
        im = axs[0, 1].imshow(DDM_img[0, 0,
                                      args.starting_x:args.starting_x + size_x,
                                      args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[0, 1])
        im = axs[1, 0].imshow(DDM_Hy[0, 0,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[1, 0])
        im = axs[1, 1].imshow(DDM_Hy[0, 1,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[1, 1])

        # last_gs supports the (currently disabled) per-patch Padé iteration;
        # kept so the commented path can be re-enabled without changes.
        b, _, n, m = patched_solved.shape
        last_gs = np.zeros((b, n, m), dtype=np.csingle)

        for k in range(args.DDM_iters):
            # for now try brute force solving
            print("yee shape:", yeex_batch_train.shape, yeey_batch_train.shape)
            dual_m, phi_c_global, Phi_r = solver.direct_solve(
                x_batch_train, y_batch_train, yeex_batch_train,
                yeey_batch_train)

            # reconstruct the whole field
            intermediate_result = solver.reconstruct(phi_c_global, Phi_r)

            if k % args.div_k == 0:
                img_idx = int(k / args.div_k)
                # Plot the error (prediction - ground truth) per channel and
                # print its relative L1 magnitude.
                im = axs[img_idx + 2, 0].imshow(
                    intermediate_result[0, :, :] -
                    DDM_Hy[0, 0, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y + size_y].numpy())
                loss1 = np.mean(
                    np.abs(intermediate_result[0, :, :] -
                           DDM_Hy[0, 0,
                                  args.starting_x:args.starting_x + size_x,
                                  args.starting_y:args.starting_y + size_y]
                           .numpy())) / np.mean(
                    np.abs(DDM_Hy[0, 0,
                                  args.starting_x:args.starting_x + size_x,
                                  args.starting_y:args.starting_y + size_y]
                           .numpy()))
                plt.colorbar(im, ax=axs[img_idx + 2, 0])
                im = axs[img_idx + 2, 1].imshow(
                    intermediate_result[1, :, :] -
                    DDM_Hy[0, 1, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y + size_y].numpy())
                loss2 = np.mean(
                    np.abs(intermediate_result[1, :, :] -
                           DDM_Hy[0, 1,
                                  args.starting_x:args.starting_x + size_x,
                                  args.starting_y:args.starting_y + size_y]
                           .numpy())) / np.mean(
                    np.abs(DDM_Hy[0, 1,
                                  args.starting_x:args.starting_x + size_x,
                                  args.starting_y:args.starting_y + size_y]
                           .numpy()))
                print(f"loss1: {loss1}, loss2: {loss2}")
                plt.colorbar(im, ax=axs[img_idx + 2, 1])
        fig.savefig(str(sample_id) + ".png", dpi=500)
def main(args, solver):
    """Run the iterative Padé-transmission DDM solver and record history.

    For each of the first ``args.num_device`` samples, solves every patch
    with ``solver`` for ``args.DDM_iters`` iterations, exchanging complex
    transmission data between patches, and saves the field history, inputs,
    ground truth, and per-device relative-error curves to
    ``args.output_folder``.

    Args:
        args: namespace with patching, transmission, and output options.
        solver: object providing ``bc_pade_operator``,
            ``construct_matrices_complex``, and ``solve``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    model_bs = args.x_patches * args.y_patches
    stride_x = args.domain_sizex - args.overlap_pixels
    stride_y = args.domain_sizey - args.overlap_pixels
    size_x = args.domain_sizex + (args.x_patches - 1) * stride_x
    size_y = args.domain_sizey + (args.y_patches - 1) * stride_y

    device_losses = np.zeros((args.num_device, args.DDM_iters))
    # Full per-iteration field history; index 0 holds the initial guess.
    history_fields = np.zeros((len(DDM_loader), args.DDM_iters + 1, model_bs,
                               2, args.domain_sizex, args.domain_sizey))
    x_batch_trains = np.zeros((len(DDM_loader), model_bs, 1,
                               args.domain_sizex, args.domain_sizey))
    y_batch_trains = np.zeros((len(DDM_loader), model_bs, 2,
                               args.domain_sizex, args.domain_sizey))
    print("shapes:", history_fields.shape, x_batch_trains.shape,
          y_batch_trains.shape)

    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id >= args.num_device:
            break
        if sample_id % 20 == 0:
            print("sample_id: ", sample_id, flush=True)
        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        x_batch_train = [
            DDM_img[0, 0,
                    args.starting_x + i * stride_x:
                    args.starting_x + args.domain_sizex + i * stride_x,
                    args.starting_y + j * stride_y:
                    args.starting_y + args.domain_sizey + j * stride_y]
            for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey)

        # Yee-grid x-face permittivity: vertical neighbor average on the
        # interior (domain_sizex - 1 by domain_sizey - 2).
        yeex_batch_train = [
            1 / 2 * (DDM_img[0, 0,
                             0 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             1 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y] +
                     DDM_img[0, 0,
                             -1 + args.starting_x + i * stride_x:
                             -2 + args.starting_x + args.domain_sizex + i * stride_x,
                             1 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y])
            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 1, args.domain_sizey - 2)

        # Yee-grid y-face permittivity: horizontal neighbor average on the
        # interior (domain_sizex - 2 by domain_sizey - 1).
        yeey_batch_train = [
            1 / 2 * (DDM_img[0, 0,
                             1 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             0 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y] +
                     DDM_img[0, 0,
                             1 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             -1 + args.starting_y + j * stride_y:
                             -2 + args.starting_y + args.domain_sizey + j * stride_y])
            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 2, args.domain_sizey - 1)

        y_batch_train = [
            DDM_Hy[0, :,
                   args.starting_x + i * stride_x:
                   args.starting_x + args.domain_sizex + i * stride_x,
                   args.starting_y + j * stride_y:
                   args.starting_y + args.domain_sizey + j * stride_y]
            for i in range(args.x_patches) for j in range(args.y_patches)]
        # BUG FIX: original reshaped to (..., domain_sizex, domain_sizex);
        # that only works for square subdomains (and would also break the
        # assignment into y_batch_trains, which uses domain_sizey).
        y_batch_train = torch.stack(y_batch_train).reshape(
            model_bs, 2, args.domain_sizex, args.domain_sizey)

        intep_field, patched_solved = init_four_point_interp(DDM_Hy, prop2, args)

        # Renamed from `b` to avoid shadowing by the RHS vector `b` returned
        # by construct_matrices_complex inside the loop below.
        n_patches, _, n, m = patched_solved.shape
        last_gs = np.zeros((n_patches, n, m), dtype=np.csingle)

        history_fields[sample_id, 0, :, :, :, :] = patched_solved
        x_batch_trains[sample_id] = x_batch_train
        y_batch_trains[sample_id] = y_batch_train

        for k in range(args.DDM_iters):
            for idx in range(model_bs):
                # Boundary Padé operators: left, right, top, bottom, each
                # built from the permittivity averaged across the edge.
                ops = [
                    solver.bc_pade_operator(1 / 2 * (
                        x_batch_train[idx, 0, :, 0] +
                        x_batch_train[idx, 0, :, 1]).numpy()),
                    solver.bc_pade_operator(1 / 2 * (
                        x_batch_train[idx, 0, :, -2] +
                        x_batch_train[idx, 0, :, -1]).numpy()),
                    solver.bc_pade_operator(1 / 2 * (
                        x_batch_train[idx, 0, 1, :] +
                        x_batch_train[idx, 0, 0, :]).numpy()),
                    solver.bc_pade_operator(1 / 2 * (
                        x_batch_train[idx, 0, -2, :] +
                        x_batch_train[idx, 0, -1, :]).numpy())]
                # First iteration has no previous transmission data g.
                if k == 0:
                    g, alpha, beta, gamma, g_mul = \
                        trasmission_pade_g_alpha_beta_gamma_complex(
                            patched_solved, x_batch_train, idx,
                            args.transmission_func, args, ops)
                else:
                    g, alpha, beta, gamma, g_mul = \
                        trasmission_pade_g_alpha_beta_gamma_complex(
                            patched_solved, x_batch_train, idx,
                            args.transmission_func, args, ops, last_gs[idx])
                last_gs[idx] = g
                A, b = solver.construct_matrices_complex(
                    g, ops, yeex_batch_train[idx, 0], yeey_batch_train[idx, 0],
                    alpha, beta, gamma, g_mul)
                field_vec = torch.tensor(
                    solver.solve(A, b).reshape(
                        (args.domain_sizex, args.domain_sizey)),
                    dtype=torch.cfloat)
                field_vec_real = torch.real(field_vec)
                field_vec_imag = torch.imag(field_vec)
                solved = torch.stack([field_vec_real, field_vec_imag], dim=0)
                patched_solved[idx] = solved

            history_fields[sample_id, k + 1, :, :, :, :] = patched_solved

            # reconstruct the whole field
            intermediate_result = reconstruct(patched_solved, args)
            diff = intermediate_result.contiguous().view(1, -1) - \
                DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                       args.starting_y:args.starting_y + size_y] \
                .contiguous().view(1, -1)
            # Relative L1 error against the ground-truth window.
            loss = torch.mean(torch.abs(diff)) / \
                torch.mean(torch.abs(
                    DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y + size_y]
                    .contiguous().view(1, -1)))
            print(f"iter {k}, loss {loss}")
            device_losses[sample_id, k] = loss

    np.save(args.output_folder +
            f"/sx_{args.starting_x}_sy_{args.starting_y}_field_history.npy",
            history_fields)
    np.save(args.output_folder +
            f"/sx_{args.starting_x}_sy_{args.starting_y}_eps.npy",
            x_batch_trains)
    np.save(args.output_folder +
            f"/sx_{args.starting_x}_sy_{args.starting_y}_Hy_gt.npy",
            y_batch_trains)

    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    plt.savefig(args.output_folder +
                f"/sx_{args.starting_x}_sy_{args.starting_y}_device_loss.png",
                dpi=300)
    np.save(args.output_folder +
            f"/sx_{args.starting_x}_sy_{args.starting_y}_device_losses.npy",
            device_losses)
def main(args, solver):
    """Run the iterative Dirichlet-interior DDM solver with 2nd-order Robin BCs.

    For the first 3 samples: solves each patch interior directly from its
    current boundary field (real and imaginary channels separately), then
    updates the boundary data via ``new_iter_bcs`` with
    ``robin_transform_2nd``; snapshots are plotted every ``args.div_k``
    iterations into ``eval_<sample_id>.png``.

    Args:
        args: namespace with patching and plotting options.
        solver: object providing ``construct_matrices`` and ``solve``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    stride_x = args.domain_sizex - args.overlap_pixels
    stride_y = args.domain_sizey - args.overlap_pixels
    size_x = args.domain_sizex + (args.x_patches - 1) * stride_x
    size_y = args.domain_sizey + (args.y_patches - 1) * stride_y

    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id > 2:
            break
        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        model_bs = args.x_patches * args.y_patches
        x_batch_train = [
            DDM_img[0, 0,
                    args.starting_x + i * stride_x:
                    args.starting_x + args.domain_sizex + i * stride_x,
                    args.starting_y + j * stride_y:
                    args.starting_y + args.domain_sizey + j * stride_y]
            for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey)

        # Yee-grid x-face permittivity on the interior
        # (domain_sizex - 1 by domain_sizey - 2).
        yeex_batch_train = [
            1 / 2 * (DDM_img[0, 0,
                             0 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             1 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y] +
                     DDM_img[0, 0,
                             -1 + args.starting_x + i * stride_x:
                             -2 + args.starting_x + args.domain_sizex + i * stride_x,
                             1 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y])
            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 1, args.domain_sizey - 2)

        # Yee-grid y-face permittivity on the interior
        # (domain_sizex - 2 by domain_sizey - 1).
        yeey_batch_train = [
            1 / 2 * (DDM_img[0, 0,
                             1 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             0 + args.starting_y + j * stride_y:
                             -1 + args.starting_y + args.domain_sizey + j * stride_y] +
                     DDM_img[0, 0,
                             1 + args.starting_x + i * stride_x:
                             -1 + args.starting_x + args.domain_sizex + i * stride_x,
                             -1 + args.starting_y + j * stride_y:
                             -2 + args.starting_y + args.domain_sizey + j * stride_y])
            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 2, args.domain_sizey - 1)

        y_batch_train = [
            DDM_Hy[0, :,
                   args.starting_x + i * stride_x:
                   args.starting_x + args.domain_sizex + i * stride_x,
                   args.starting_y + j * stride_y:
                   args.starting_y + args.domain_sizey + j * stride_y]
            for i in range(args.x_patches) for j in range(args.y_patches)]
        # BUG FIX: original reshaped to (..., domain_sizex, domain_sizex);
        # that only works for square subdomains and crashes otherwise.
        y_batch_train = torch.stack(y_batch_train).reshape(
            model_bs, 2, args.domain_sizex, args.domain_sizey)

        (top_bc_train, bottom_bc_train, left_bc_train, right_bc_train,
         intep_field) = init_four_point_interp(DDM_Hy, prop2, args)

        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 3, 4)
        axs[0, 0].imshow(DDM_img[0, 0,
                                 args.starting_x:args.starting_x + size_x,
                                 args.starting_y:args.starting_y + size_y])
        axs[0, 1].imshow(DDM_img[0, 0,
                                 args.starting_x:args.starting_x + size_x,
                                 args.starting_y:args.starting_y + size_y])
        axs[1, 0].imshow(DDM_Hy[0, 0,
                                args.starting_x:args.starting_x + size_x,
                                args.starting_y:args.starting_y + size_y])
        axs[1, 1].imshow(DDM_Hy[0, 1,
                                args.starting_x:args.starting_x + size_x,
                                args.starting_y:args.starting_y + size_y])
        axs[2, 0].imshow(intep_field[0, :, :])
        axs[2, 1].imshow(intep_field[1, :, :])

        for k in range(args.DDM_iters):
            patched_solved = torch.zeros(y_batch_train.shape)
            # Scatter the four edge BCs into a full-size field tensor; only
            # the outermost ring is used by the solver.
            boundary_field = torch.zeros(y_batch_train.shape)
            boundary_field[:, :, 0:1, :] = top_bc_train[:, :, :, :]
            boundary_field[:, :, -1:, :] = bottom_bc_train[:, :, :, :]
            boundary_field[:, :, :, 0:1] = left_bc_train[:, :, :, :]
            boundary_field[:, :, :, -1:] = right_bc_train[:, :, :, :]

            for idx in range(model_bs):
                # Real and imaginary channels are solved independently.
                A_real, b_real = solver.construct_matrices(
                    boundary_field[idx, 0], yeex_batch_train[idx, 0],
                    yeey_batch_train[idx, 0])
                A_imag, b_imag = solver.construct_matrices(
                    boundary_field[idx, 1], yeex_batch_train[idx, 0],
                    yeey_batch_train[idx, 0])
                # NOTE(review): assumes solver.solve returns a torch tensor
                # (torch.stack below requires it) -- confirm against solver.
                field_vec_real = solver.solve(A_real, b_real).reshape(
                    (args.domain_sizex - 2, args.domain_sizey - 2))
                field_vec_imag = solver.solve(A_imag, b_imag).reshape(
                    (args.domain_sizex - 2, args.domain_sizey - 2))
                inner_field = torch.stack([field_vec_real, field_vec_imag],
                                          dim=0)
                # Re-assemble: boundary ring from the BCs, interior from the
                # solve.
                solved = torch.zeros((2, args.domain_sizex, args.domain_sizey))
                solved[:, 0, :] = torch.clone(boundary_field[idx, :, 0, :])
                solved[:, -1, :] = torch.clone(boundary_field[idx, :, -1, :])
                solved[:, :, 0] = torch.clone(boundary_field[idx, :, :, 0])
                solved[:, :, -1] = torch.clone(boundary_field[idx, :, :, -1])
                solved[:, 1:-1, 1:-1] = inner_field
                patched_solved[idx] = solved

            # reconstruct the whole field
            intermediate_result = reconstruct(patched_solved, args)
            loss = torch.mean(
                torch.abs(intermediate_result.contiguous().view(1, -1) -
                          DDM_Hy[:, :,
                                 args.starting_x:args.starting_x + size_x,
                                 args.starting_y:args.starting_y + size_y]
                          .contiguous().view(1, -1)))

            if k % args.div_k == 0:
                img_idx = int(k / args.div_k)
                im = axs[img_idx + 3, 0].imshow(intermediate_result[0, :, :])
                axs[img_idx + 3, 0].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 0])
                im = axs[img_idx + 3, 1].imshow(intermediate_result[1, :, :])
                axs[img_idx + 3, 1].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 1])
                im = axs[img_idx + 3, 2].imshow(
                    reconstruct(y_batch_train, args)[0, :, :])
                axs[img_idx + 3, 2].set_title("y_batch_train: " + str(loss),
                                              fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 2])
                im = axs[img_idx + 3, 3].imshow(
                    reconstruct(y_batch_train, args)[1, :, :])
                axs[img_idx + 3, 3].set_title("y_batch_train: " + str(loss),
                                              fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 3])

            # Then prepare the data for next iteration:
            top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = \
                new_iter_bcs(patched_solved, y_batch_train, args,
                             robin_transform_2nd)
        # calculate the loss using the ground truth
        fig.savefig("eval_" + str(sample_id) + ".png", dpi=3000)
def main(args, solver):
    """Run the classical (matrix-solver) DDM iteration on one sample and plot convergence.

    The global domain is tiled into args.x_patches * args.y_patches overlapping
    subdomains.  Each DDM iteration solves every subdomain with ``solver`` under
    2nd-order transmission boundary conditions, stitches the patches back into a
    global field, and measures the error against the ground-truth field.

    Args:
        args: parsed CLI namespace (domain/patch geometry, DDM_iters, div_k,
            alpha/beta, transmission_func, ...).
        solver: object providing ``construct_matrices_complex``, ``solve`` and
            ``diffL`` for the per-subdomain linear systems.

    Side effects: saves one convergence figure per processed sample and a
    summary loss-curve figure; prints the per-iteration loss.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    torch.manual_seed(42)  # fix shuffle order so the same sample is picked every run
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Extent of the stitched global domain covered by the overlapping patches.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex - args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey - args.overlap_pixels)

    device_losses = np.zeros((args.num_device, args.DDM_iters))
    sx, sy = args.domain_sizex, args.domain_sizey
    stride_x = sx - args.overlap_pixels
    stride_y = sy - args.overlap_pixels

    for sample_id, sample_batched in enumerate(DDM_loader):
        # NOTE(review): debugging leftover — only sample 4 is processed.
        if sample_id > 4:
            break
        if sample_id < 4:
            continue
        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # Cut the global structure/field into the batch of overlapping subdomains.
        model_bs = args.x_patches * args.y_patches
        struct_patches, yeex_patches, yeey_patches, field_patches = [], [], [], []
        for i in range(args.x_patches):
            for j in range(args.y_patches):
                x0 = args.starting_x + i * stride_x  # patch top-left corner (global coords)
                y0 = args.starting_y + j * stride_y
                struct_patches.append(DDM_img[0, 0, x0:x0 + sx, y0:y0 + sy])
                # Averages of the structure between neighbouring cells along x/y
                # (presumably Yee-grid permittivity averaging — confirm against solver).
                yeex_patches.append(1 / 2 * (
                    DDM_img[0, 0, x0:x0 + sx - 1, y0 + 1:y0 + sy - 1]
                    + DDM_img[0, 0, x0 - 1:x0 + sx - 2, y0 + 1:y0 + sy - 1]))
                yeey_patches.append(1 / 2 * (
                    DDM_img[0, 0, x0 + 1:x0 + sx - 1, y0:y0 + sy - 1]
                    + DDM_img[0, 0, x0 + 1:x0 + sx - 1, y0 - 1:y0 + sy - 2]))
                field_patches.append(DDM_Hy[0, :, x0:x0 + sx, y0:y0 + sy])
        x_batch_train = torch.stack(struct_patches).reshape(model_bs, 1, sx, sy)
        yeex_batch_train = torch.stack(yeex_patches).reshape(model_bs, 1, sx - 1, sy - 2)
        yeey_batch_train = torch.stack(yeey_patches).reshape(model_bs, 1, sx - 2, sy - 1)
        # BUG FIX: last axis was args.domain_sizex; patches are (sx, sy) and the
        # old reshape only worked by accident on square domains.
        y_batch_train = torch.stack(field_patches).reshape(model_bs, 2, sx, sy)

        intep_field, patched_solved = init_four_point_interp(DDM_Hy, prop2, args)

        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 3, 4)
        gx = slice(args.starting_x, args.starting_x + size_x)
        gy = slice(args.starting_y, args.starting_y + size_y)
        axs[0, 0].imshow(DDM_img[0, 0, gx, gy])
        axs[0, 1].imshow(DDM_img[0, 0, gx, gy])
        axs[1, 0].imshow(DDM_Hy[0, 0, gx, gy])
        axs[1, 1].imshow(DDM_Hy[0, 1, gx, gy])
        axs[2, 0].imshow(intep_field[0, :, :])
        axs[2, 1].imshow(intep_field[1, :, :])

        # Renamed from `b` to avoid being shadowed by the RHS vector `b` below.
        bs, _, n, m = patched_solved.shape
        last_gs = torch.zeros((bs, n, m), dtype=torch.cfloat)
        for k in range(args.DDM_iters):
            for idx in range(model_bs):
                # 2nd-order transmission condition; after the first iteration the
                # previous g is fed back for the update.
                if k == 0:
                    g, alpha, beta, gamma = trasmission_2nd_g_alpha_beta_gamma_complex(
                        patched_solved, x_batch_train, idx, args.transmission_func,
                        args, solver.diffL)
                else:
                    g, alpha, beta, gamma = trasmission_2nd_g_alpha_beta_gamma_complex(
                        patched_solved, x_batch_train, idx, args.transmission_func,
                        args, solver.diffL, last_gs[idx])
                last_gs[idx] = g
                A, b = solver.construct_matrices_complex(
                    g, yeex_batch_train[idx, 0], yeey_batch_train[idx, 0],
                    alpha, beta, gamma)
                field_vec = solver.solve(A, b).reshape((args.domain_sizex, args.domain_sizey))
                # Store the complex solution as two real channels (real, imag).
                patched_solved[idx] = torch.stack(
                    [torch.real(field_vec), torch.imag(field_vec)], dim=0)

            # Stitch the patches and measure the error against ground truth.
            intermediate_result = reconstruct(patched_solved, args)
            loss = torch.mean(
                torch.abs(intermediate_result.contiguous().view(1, -1)
                          - DDM_Hy[:, :, gx, gy].contiguous().view(1, -1)))
            print(f"iter {k}, loss {loss}")
            # BUG FIX: this assignment was commented out, so the summary plot
            # below rendered all zeros.  Guard the index since the debug filter
            # above may process a sample_id beyond num_device.
            if sample_id < args.num_device:
                device_losses[sample_id, k] = loss.item()

            if k % args.div_k == 0:
                img_idx = int(k / args.div_k)
                im = axs[img_idx + 3, 0].imshow(intermediate_result[0, :, :])
                axs[img_idx + 3, 0].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 0])
                im = axs[img_idx + 3, 1].imshow(intermediate_result[1, :, :])
                axs[img_idx + 3, 1].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 1])
                im = axs[img_idx + 3, 2].imshow(reconstruct(y_batch_train, args)[0, :, :])
                axs[img_idx + 3, 2].set_title("y_batch_train: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 2])
                im = axs[img_idx + 3, 3].imshow(reconstruct(y_batch_train, args)[1, :, :])
                axs[img_idx + 3, 3].set_title("y_batch_train: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 3])

        fig.savefig(f"alpha_{args.alpha}_beta{args.beta}_" + str(sample_id) + ".png", dpi=500)

    # Summary convergence curves across all processed devices.
    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    plt.savefig(f"alpha_{args.alpha}_beta{args.beta}_" + "device_losses.png", dpi=300)
def main(args):
    """Evaluate the fully-ML DDM pipeline on several devices (samples).

    A pretrained subdomain model (``s_model``) predicts each patch field from the
    patch structure and its four boundary conditions; between iterations either a
    traditional boundary-condition update (``new_iter_bcs``) or a pretrained
    boundary model (``new_iter_bcs_model``) produces the next transmission
    boundaries, depending on ``args.traditional_bc``.

    Args:
        args: parsed CLI namespace (arch, model paths, patch geometry, DDM_iters, ...).

    Raises:
        NotImplementedError: if ``args.arch`` is not a known architecture.
        FileNotFoundError: if either model directory does not exist.

    Side effects: saves one convergence figure per device and a summary
    loss-curve figure; prints parameter counts and total run time.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    device = torch.device('cuda')

    # Subdomain surrogate model.
    if args.arch == "UNet":
        s_model = UNet(args).to(device)
    elif args.arch == "Fourier":
        s_model = FNO_multimodal_2d(args).to(device)
    else:
        # BUG FIX: was `raise ("architecture {args.arch} ...")` — raising a str
        # is a TypeError in Python 3, and the missing f-prefix meant the
        # architecture name was never interpolated.
        raise NotImplementedError(f"architecture {args.arch} hasn't been added!!")

    s_model_path = args.s_model_saving_path + args.s_model_name + \
        "_domain_size_" + str(args.domain_sizex) + "_" + str(args.domain_sizey) + \
        "_fmodes_" + str(args.f_modes) + \
        "_flayers_" + str(args.num_fourier_layers) + \
        "_Hidden_" + str(args.HIDDEN_DIM) + \
        "_f_padding_" + str(args.f_padding) + \
        "_batch_size_" + str(args.s_batch_size) + "_lr_" + str(args.s_lr)
    if not os.path.isdir(s_model_path):
        # BUG FIX: was `print(...)` followed by a bare `raise`, which fails with
        # "No active exception to re-raise"; raise a descriptive error instead.
        raise FileNotFoundError(f"path error: {s_model_path}")

    b_model_path = args.b_model_saving_path + args.b_model_name + \
        "_b_hidden_" + str(args.b_hidden) + \
        "_b_layers_" + str(args.b_layers) + \
        "_boundary_thickness_" + str(args.boundary_thickness) + \
        "_domain_size_" + str(args.domain_sizex) + \
        "_b_batch_size_" + str(args.b_batch_size) + "_b_lr_" + str(args.b_lr)
    if not os.path.isdir(b_model_path):
        raise FileNotFoundError(f"path error: {b_model_path}")

    # Load the already trained subdomain model.
    print("Restoring subdomain model weights from ", s_model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(s_model_path + "/best_model.pt")
    s_model.load_state_dict(checkpoint['model'].state_dict())
    s_model.to(device)

    # Init and load the boundary model.
    b_model = boundary_model(args).to(device)
    print("Restoring boundary model weights from ", b_model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(b_model_path + "/best_model.pt")
    b_model.load_state_dict(checkpoint['model'].state_dict())
    b_model.to(device)

    # Trainable-parameter counts for both models.
    s_num = sum(np.prod(p.shape) for p in s_model.parameters() if p.requires_grad)
    b_num = sum(np.prod(p.shape) for p in b_model.parameters() if p.requires_grad)
    print('s_model total trainable tensors:', s_num, flush=True)
    print('b_model total trainable tensors:', b_num, flush=True)

    # Data loader
    ds = DDM_Dataset(args.data_folder, total_sample_number=args.total_sample_number)
    torch.manual_seed(42)  # fix shuffle order so the same devices are picked every run
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Extent of the stitched global domain covered by the overlapping patches.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex - args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey - args.overlap_pixels)

    device_losses = np.zeros((args.num_device, args.DDM_iters))
    start_time = timeit.default_timer()
    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id >= args.num_device:
            break
        with torch.no_grad():
            DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

            # Cut the global structure/field into the batch of overlapping subdomains.
            model_bs = args.x_patches * args.y_patches
            sx, sy = args.domain_sizex, args.domain_sizey
            stride_x = sx - args.overlap_pixels
            stride_y = sy - args.overlap_pixels
            struct_patches, field_patches = [], []
            for i in range(args.x_patches):
                for j in range(args.y_patches):
                    x0 = args.starting_x + i * stride_x  # patch top-left corner
                    y0 = args.starting_y + j * stride_y
                    struct_patches.append(DDM_img[0, 0, x0:x0 + sx, y0:y0 + sy])
                    field_patches.append(DDM_Hy[0, :, x0:x0 + sx, y0:y0 + sy])
            # BUG FIX: both reshapes previously used domain_sizex for the last
            # axis; patches are (sx, sy) so this only worked on square domains.
            x_batch_train = torch.stack(struct_patches).reshape(model_bs, 1, sx, sy).to(device)
            y_batch_train = torch.stack(field_patches).reshape(model_bs, 2, sx, sy).to(device)

            # Initial boundary guesses from four-point interpolation of the field.
            top_bc_train, bottom_bc_train, left_bc_train, right_bc_train, intep_field = \
                init_four_point_interp(DDM_Hy, prop2, args, device, return_bc=True)

            plt.rcParams["font.size"] = "5"
            fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 3, 2)
            gx = slice(args.starting_x, args.starting_x + size_x)
            gy = slice(args.starting_y, args.starting_y + size_y)
            im = axs[0, 0].imshow(DDM_img[0, 0, gx, gy])
            plt.colorbar(im, ax=axs[0, 0])
            im = axs[0, 1].imshow(DDM_img[0, 0, gx, gy])
            plt.colorbar(im, ax=axs[0, 1])
            im = axs[1, 0].imshow(DDM_Hy[0, 0, gx, gy])
            plt.colorbar(im, ax=axs[1, 0])
            im = axs[1, 1].imshow(DDM_Hy[0, 1, gx, gy])
            plt.colorbar(im, ax=axs[1, 1])
            im = axs[2, 0].imshow(intep_field[0, :, :])
            plt.colorbar(im, ax=axs[2, 0])
            im = axs[2, 1].imshow(intep_field[1, :, :])
            plt.colorbar(im, ax=axs[2, 1])

            # for debugging:
            # top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = boundary_from_gt(DDM_Hy, args, device)
            for k in range(args.DDM_iters):
                # Normalize boundary magnitudes so the model sees O(1) inputs;
                # the prediction is scaled back by the same factor.
                bc_mean = 1 / 4 * (
                    torch.mean(torch.abs(top_bc_train), dim=(2, 3), keepdim=True)
                    + torch.mean(torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True)
                    + torch.mean(torch.abs(left_bc_train), dim=(2, 3), keepdim=True)
                    + torch.mean(torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
                logits = s_model(x_batch_train, left_bc_train / bc_mean,
                                 right_bc_train / bc_mean, top_bc_train / bc_mean,
                                 bottom_bc_train / bc_mean).reshape(
                                     y_batch_train.shape) * bc_mean

                # Reconstruct the whole field and measure the error against ground truth.
                intermediate_result = reconstruct(logits, args)
                loss = s_model.loss_fn(
                    intermediate_result.contiguous().view(1, -1),
                    DDM_Hy[:, :, gx, gy].contiguous().view(1, -1))
                device_losses[sample_id, k] = loss

                if k % args.div_k == 0:
                    img_idx = int(k / args.div_k)
                    im = axs[img_idx + 3, 0].imshow(intermediate_result[0, :, :])
                    axs[img_idx + 3, 0].set_title("loss: " + str(loss), fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 0])
                    im = axs[img_idx + 3, 1].imshow(intermediate_result[1, :, :])
                    axs[img_idx + 3, 1].set_title("loss: " + str(loss), fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 1])

                # Prepare transmission boundaries for the next iteration:
                # either the traditional update rule or the learned boundary model.
                if args.traditional_bc == 1:
                    top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = new_iter_bcs(
                        logits, y_batch_train, args, args.bc_func, device=device)
                else:
                    top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = new_iter_bcs_model(
                        x_batch_train.reshape((args.x_patches, args.y_patches, 1,
                                               args.domain_sizex, args.domain_sizey)),
                        logits, y_batch_train, args, b_model, device=device)

            if args.traditional_bc == 1:
                fig.savefig("traditional_bc_" + str(sample_id) + ".png", dpi=1000)
            else:
                fig.savefig("all_ml_" + str(sample_id) + ".png", dpi=1000)

    end_time = timeit.default_timer()
    print(f"Run time for {args.num_device} devices each with {args.DDM_iters} iterations (note: this may include time for plotting and saving figure):", end_time - start_time)

    # Summary convergence curves across all devices.
    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    if args.traditional_bc == 1:
        plt.savefig("traditional_bc_device_loss.png", dpi=300)
    else:
        plt.savefig("all_ml_device_loss.png", dpi=300)