def main(args):
    """Evaluate a restored model on one batch of DDM samples and save a
    comparison figure "eval_<i>.png" per plotted sample.

    Each figure stacks: structure crop, scaled model input, ground-truth
    field (both channels), interpolated/scaled ground truth, raw model
    logits, and the scaled-back prediction annotated with its mean
    absolute error against the ground-truth crop.

    Args:
        args: argument namespace with architecture, paths, plotting and
            domain-size settings (see the project's arg parser).

    Raises:
        ValueError: if ``args.arch`` names an unknown architecture.
        FileNotFoundError: if neither candidate checkpoint directory exists.
    """
    # Fix all RNG seeds so evaluation is reproducible.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)
    device = torch.device('cuda')

    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        # The original raised a plain string (a TypeError in Python 3) and
        # never interpolated args.arch; raise a real exception instead.
        raise ValueError(f"architecture {args.arch} hasn't been added!!")

    # Report the number of trainable parameters for a quick sanity check.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    num = sum(map(lambda p: np.prod(p.shape), trainable))
    print(model)
    print('Total trainable tensors:', num, flush=True)

    # Two historical naming schemes for the checkpoint directory: a long
    # form embedding the FNO hyper-parameters and a short legacy form.
    model_path1 = (args.model_saving_path + args.model_name +
                   "_domain_size_" + str(args.domain_sizex) + "_" +
                   str(args.domain_sizey) +
                   "_fmodes_" + str(args.f_modes) +
                   "_flayers_" + str(args.num_fourier_layers) +
                   "_Hidden_" + str(args.HIDDEN_DIM) +
                   "_f_padding_" + str(args.f_padding) +
                   "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr))

    model_path2 = (args.model_saving_path + args.model_name +
                   "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr))

    if os.path.isdir(model_path1):
        model_path = model_path1
    elif os.path.isdir(model_path2):
        model_path = model_path2
    else:
        # The original raised a string and evaluated ".".model_path (an
        # AttributeError on str); fail with an informative exception.
        raise FileNotFoundError(
            f"model path not found: {model_path1} (nor {model_path2})")

    # The checkpoint stores the whole pickled model object, not a state_dict.
    print("Restoring weights from ", model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/best_model.pt")
    model = checkpoint['model']

    ds = DDM_Dataset(args.data_folder,
                     total_sample_number=args.total_sample_number)
    # Re-seed so the shuffled loader yields the same first batch every run.
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds,
                            batch_size=args.plot_figs,
                            shuffle=True,
                            num_workers=0)

    size_x = args.scale_sx
    size_y = args.scale_sy
    loss = []
    # Only the first batch is evaluated (note the break at the end); the
    # tensors bound inside the loop are reused for plotting afterwards.
    for sample_batched in DDM_loader:
        with torch.no_grad():
            DDM_img = sample_batched['structure']
            DDM_Hy = sample_batched['field']

            (top_bc_train, bottom_bc_train, left_bc_train, right_bc_train,
             x_batch_train, intep_field) = scale_four_point_interp(
                 DDM_img, DDM_Hy, args, device)

            # Mean absolute boundary magnitude, used to normalize the
            # boundary conditions going in and to rescale the output.
            bc_mean = 1 / 4 * (
                torch.mean(torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(left_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(right_bc_train), dim=(2, 3), keepdim=True))

            logits = model(x_batch_train / 4, left_bc_train / bc_mean,
                           right_bc_train / bc_mean, top_bc_train / bc_mean,
                           bottom_bc_train / bc_mean).reshape(
                               intep_field.shape) * bc_mean

            # Reconstruct the whole field and compute a per-sample MAE
            # against the matching crop of the ground truth.
            scaled_result = scale_back(logits.cpu(), args, device)
            gt_crop = DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                             args.starting_y:args.starting_y + size_y]
            loss = torch.mean(torch.abs(
                scaled_result.contiguous().view(args.plot_figs, -1) -
                gt_crop.contiguous().view(args.plot_figs, -1)),
                              dim=1).numpy()
        break

    for i in range(args.plot_figs):
        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(5, 2)
        # Row 0: structure crop and the scaled model input.
        axs[0, 0].imshow(
            DDM_img.cpu().numpy()[i, 0,
                                  args.starting_x:args.starting_x + size_x,
                                  args.starting_y:args.starting_y + size_y])
        axs[0, 1].imshow(x_batch_train.cpu().numpy()[i, 0, :, :])
        # Row 1: ground-truth field, both channels.
        axs[1, 0].imshow(
            DDM_Hy.cpu().numpy()[i, 0,
                                 args.starting_x:args.starting_x + size_x,
                                 args.starting_y:args.starting_y + size_y])
        axs[1, 1].imshow(
            DDM_Hy.cpu().numpy()[i, 1,
                                 args.starting_x:args.starting_x + size_x,
                                 args.starting_y:args.starting_y + size_y])
        # Row 2: interpolated (scaled) ground truth.
        im = axs[2, 0].imshow(intep_field.cpu().numpy()[i, 0, :, :])
        axs[2, 0].set_title("scaled_gt", fontsize=4)
        plt.colorbar(im, ax=axs[2, 0])

        im = axs[2, 1].imshow(intep_field.cpu().numpy()[i, 1, :, :])
        # The original re-titled axs[2, 0] here by mistake; title the
        # right-hand axis it just drew on.
        axs[2, 1].set_title("scaled_gt", fontsize=4)
        plt.colorbar(im, ax=axs[2, 1])

        # Row 3: raw model logits.
        im = axs[3, 0].imshow(logits.cpu().numpy()[i, 0, :, :])
        axs[3, 0].set_title("logits", fontsize=4)
        plt.colorbar(im, ax=axs[3, 0])

        im = axs[3, 1].imshow(logits.cpu().numpy()[i, 1, :, :])
        axs[3, 1].set_title("logits", fontsize=4)
        plt.colorbar(im, ax=axs[3, 1])

        # Row 4: scaled-back prediction with its per-sample loss.
        im = axs[4, 0].imshow(scaled_result.cpu().numpy()[i, 0, :, :])
        axs[4, 0].set_title("loss: " + str(loss[i]), fontsize=4)
        plt.colorbar(im, ax=axs[4, 0])

        im = axs[4, 1].imshow(scaled_result.cpu().numpy()[i, 1, :, :])
        axs[4, 1].set_title("loss: " + str(loss[i]), fontsize=4)
        plt.colorbar(im, ax=axs[4, 1])

        fig.savefig("eval_" + str(i) + ".png", dpi=1000)
# Exemplo n.º 2
# 0
def main(args):
    """Run iterative DDM inference with a restored model on a few samples.

    For each of the first three loader samples, the structure/field is cut
    into ``x_patches * y_patches`` overlapping subdomains, boundary
    conditions are iterated through the network for ``args.DDM_iters``
    rounds, and every 5th iterate is plotted (with its loss) into
    "eval_<sample_id>.png".

    Args:
        args: argument namespace with architecture, paths, patching and
            iteration settings (see the project's arg parser).

    Raises:
        ValueError: if ``args.arch`` names an unknown architecture.
        FileNotFoundError: if the checkpoint directory does not exist.
    """
    # Fix all RNG seeds so evaluation is reproducible.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)
    device = torch.device('cuda')

    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        # The original raised a plain string (a TypeError in Python 3) and
        # never interpolated args.arch; raise a real exception instead.
        raise ValueError(f"architecture {args.arch} hasn't been added!!")

    # Report the number of trainable parameters for a quick sanity check.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    num = sum(map(lambda p: np.prod(p.shape), trainable))
    print(model)
    print('Total trainable tensors:', num, flush=True)

    # Checkpoint directory name embeds the model hyper-parameters.
    model_path = (args.model_saving_path + args.model_name +
                  "_domain_size_" + str(args.domain_sizex) + "_" +
                  str(args.domain_sizey) +
                  "_fmodes_" + str(args.f_modes) +
                  "_flayers_" + str(args.num_fourier_layers) +
                  "_Hidden_" + str(args.HIDDEN_DIM) +
                  "_f_padding_" + str(args.f_padding) +
                  "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr))

    if not os.path.isdir(model_path):
        # The original raised a string and evaluated ".".model_path (an
        # AttributeError on str); fail with an informative exception.
        raise FileNotFoundError(f"model path not found: {model_path}")

    # The checkpoint stores the whole pickled model object, not a state_dict.
    print("Restoring weights from ", model_path + "/best_model.pt", flush=True)
    checkpoint = torch.load(model_path + "/best_model.pt")
    model = checkpoint['model']

    ds = DDM_Dataset(args.data_folder,
                     total_sample_number=args.total_sample_number)
    # Re-seed so the shuffled loader yields the same samples every run.
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Full reconstructed domain size given patch count and overlap.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex -
                                                         args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey -
                                                         args.overlap_pixels)
    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id > 2:
            break
        with torch.no_grad():
            DDM_img = sample_batched['structure']
            DDM_Hy = sample_batched['field']

            # Cut the structure/field into x_patches*y_patches overlapping
            # subdomains of (domain_sizex, domain_sizey).
            model_bs = args.x_patches * args.y_patches
            stride_x = args.domain_sizex - args.overlap_pixels
            stride_y = args.domain_sizey - args.overlap_pixels
            x_batch_train = [
                DDM_img[0, 0,
                        args.starting_x + i * stride_x:
                        args.starting_x + args.domain_sizex + i * stride_x,
                        args.starting_y + j * stride_y:
                        args.starting_y + args.domain_sizey + j * stride_y]
                for i in range(args.x_patches) for j in range(args.y_patches)
            ]
            # The original reshaped to (..., domain_sizex, domain_sizex),
            # which only works for square patches; the last dim must be
            # domain_sizey to match the slices above.
            x_batch_train = torch.stack(x_batch_train).reshape(
                model_bs, 1, args.domain_sizex, args.domain_sizey).to(device)

            y_batch_train = [
                DDM_Hy[0, :,
                       args.starting_x + i * stride_x:
                       args.starting_x + args.domain_sizex + i * stride_x,
                       args.starting_y + j * stride_y:
                       args.starting_y + args.domain_sizey + j * stride_y]
                for i in range(args.x_patches) for j in range(args.y_patches)
            ]
            # Same square-patch fix as for x_batch_train.
            y_batch_train = torch.stack(y_batch_train).reshape(
                model_bs, 2, args.domain_sizex, args.domain_sizey).to(device)

            top_bc_train, bottom_bc_train, left_bc_train, right_bc_train, intep_field = init_four_point_interp(
                DDM_Hy, prop2, args, device)

            # Figure layout: rows 0-1 show structure and ground truth,
            # row 2 the initial interpolation, then one row per 5 iters.
            plt.rcParams["font.size"] = "5"
            fig, axs = plt.subplots(int(args.DDM_iters / 5) + 3, 4)
            axs[0, 0].imshow(DDM_img[0, 0,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
            axs[0, 1].imshow(DDM_img[0, 0,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
            axs[1, 0].imshow(DDM_Hy[0, 0,
                                    args.starting_x:args.starting_x + size_x,
                                    args.starting_y:args.starting_y + size_y])
            axs[1, 1].imshow(DDM_Hy[0, 1,
                                    args.starting_x:args.starting_x + size_x,
                                    args.starting_y:args.starting_y + size_y])
            axs[2, 0].imshow(intep_field[0, :, :])
            axs[2, 1].imshow(intep_field[1, :, :])

            for k in range(args.DDM_iters):
                # Mean absolute boundary magnitude, used to normalize the
                # boundary conditions in and rescale the prediction out.
                bc_mean = 1 / 4 * (
                    torch.mean(torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(torch.abs(left_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
                logits = model(x_batch_train, left_bc_train / bc_mean,
                               right_bc_train / bc_mean, top_bc_train /
                               bc_mean, bottom_bc_train / bc_mean).reshape(
                                   y_batch_train.shape) * bc_mean

                # Reconstruct the whole field and score it against the
                # ground-truth crop.
                intermediate_result = reconstruct(logits, args)
                loss = model.loss_fn(
                    intermediate_result.contiguous().view(1, -1),
                    DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y +
                           size_y].contiguous().view(1, -1))

                if k % 5 == 0:
                    img_idx = int(k / 5)
                    # Columns 0-1: current reconstruction (both channels);
                    # columns 2-3: reconstructed ground-truth patches.
                    im = axs[img_idx + 3, 0].imshow(
                        intermediate_result[0, :, :])
                    axs[img_idx + 3, 0].set_title("loss: " + str(loss),
                                                  fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 0])

                    im = axs[img_idx + 3, 1].imshow(
                        intermediate_result[1, :, :])
                    axs[img_idx + 3, 1].set_title("loss: " + str(loss),
                                                  fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 1])

                    im = axs[img_idx + 3, 2].imshow(
                        reconstruct(y_batch_train, args)[0, :, :])
                    axs[img_idx + 3,
                        2].set_title("y_batch_train: " + str(loss), fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 2])

                    im = axs[img_idx + 3, 3].imshow(
                        reconstruct(y_batch_train, args)[1, :, :])
                    axs[img_idx + 3,
                        3].set_title("y_batch_train: " + str(loss), fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 3])

                # Exchange boundary conditions for the next DDM iteration.
                top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = new_iter_bcs(
                    logits, y_batch_train, args, robin_transform_0th, device)

            fig.savefig("eval_" + str(sample_id) + ".png", dpi=1000)
# Exemplo n.º 3
# 0
def main(args, solver):
    """Run the direct (non-learned) DDM solver on a few dataset samples.

    For the first ``args.num_device`` samples, cuts the structure/field
    into overlapping patches, calls ``solver.direct_solve`` each DDM
    iteration, reconstructs the full field, and saves per-sample error
    figures named "<sample_id>.png".

    Args:
        args: argument namespace with patching/iteration settings.
        solver: object providing ``direct_solve`` and ``reconstruct``.
    """
    # Fix all RNG seeds so sampling is reproducible.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    # Re-seed so the shuffled loader yields the same samples every run.
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    # Full reconstructed domain size given patch count and overlap.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex -
                                                         args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey -
                                                         args.overlap_pixels)
    # Per-sample, per-iteration losses (filled only in commented-out code).
    device_losses = np.zeros((args.num_device, args.DDM_iters))
    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id >= args.num_device:
            break
        # if sample_id>2:
        #   break
        # if sample_id<2:
        #   continue

        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        model_bs = args.x_patches * args.y_patches
        x_batch_train = [DDM_img[0, 0, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                        args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey)

        # Yee-grid x-averaged permittivity per patch (staggered-grid value
        # between vertically adjacent cells).
        # yeex_batch_train = [1/2*(DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
        #                                        1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
        #                          DDM_img[0, 0,-1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -2+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
        #                                        1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
        #                     for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = [1/2*(DDM_img[0, 0,-1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) :  0+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) :  0+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0,-2+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) :  0+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(
            model_bs, 1, args.domain_sizex + 1, args.domain_sizey)

        # Yee-grid y-averaged permittivity per patch (staggered between
        # horizontally adjacent cells).
        # yeey_batch_train = [1/2*(DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
        #                                        0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
        #                          DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
        #                                       -1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -2+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
        #                     for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = [1/2*(DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) :  0+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                              -1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) :  0+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) :  0+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                              -2+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey + 1)

        y_batch_train = [DDM_Hy[0, :, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                       args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]  for i in range(args.x_patches) for j in range(args.y_patches)]
        # NOTE(review): this reshape uses domain_sizex twice; presumably the
        # last dim should be domain_sizey (works only for square patches) —
        # TODO confirm against the slicing above.
        y_batch_train = torch.stack(y_batch_train).reshape(
            model_bs, 2, args.domain_sizex, args.domain_sizex)

        intep_field, patched_solved = init_four_point_interp(
            DDM_Hy, prop2, args)

        # for debugging:
        # patched_solved = y_batch_train

        # Figure layout: rows 0-1 show structure and ground truth, then one
        # row every args.div_k iterations with the error maps.
        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 2, 2)
        im = axs[0,
                 0].imshow(DDM_img[0, 0,
                                   args.starting_x:args.starting_x + size_x,
                                   args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[0, 0])
        im = axs[0,
                 1].imshow(DDM_img[0, 0,
                                   args.starting_x:args.starting_x + size_x,
                                   args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[0, 1])
        im = axs[1, 0].imshow(DDM_Hy[0, 0,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[1, 0])
        im = axs[1, 1].imshow(DDM_Hy[0, 1,
                                     args.starting_x:args.starting_x + size_x,
                                     args.starting_y:args.starting_y + size_y])
        plt.colorbar(im, ax=axs[1, 1])
        # im = axs[2,0].imshow(intep_field[0, :, :])
        # plt.colorbar(im, ax=axs[2,0])
        # im = axs[2,1].imshow(intep_field[1, :, :])
        # plt.colorbar(im, ax=axs[2,1])

        # Complex transmission terms carried over between iterations by the
        # (commented-out) per-patch Padé solver path below.
        b, _, n, m = patched_solved.shape
        last_gs = np.zeros((b, n, m), dtype=np.csingle)
        for k in range(args.DDM_iters):
            # Per-patch iterative Padé-operator solve — disabled in favor of
            # the brute-force direct solve below.
            # for idx in range(model_bs):
            #     # left, right, top, bottom
            #     ops = [solver.bc_pade_operator(1/2*(x_batch_train[idx,0, :, 0]+x_batch_train[idx,0, :, 1]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, :, -2]+x_batch_train[idx,0, :,-1]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, 1, :]+x_batch_train[idx,0, 0, :]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, -2, :]+x_batch_train[idx,0,-1, :]).numpy())]
            #     if k==0:
            #         g, alpha, beta, gamma, g_mul = trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, idx, args.transmission_func, args, ops)
            #     else:
            #         g, alpha, beta, gamma, g_mul = trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, idx, args.transmission_func, args, ops, last_gs[idx])
            #     last_gs[idx] = g

            #     A,b = solver.construct_matrices_complex(g, ops, yeex_batch_train[idx,0], yeey_batch_train[idx,0], alpha, beta, gamma, g_mul)

            #     field_vec = torch.tensor(solver.solve(A, b).reshape((args.domain_sizex, args.domain_sizey)), dtype=torch.cfloat)
            #     field_vec_real = torch.real(field_vec)
            #     field_vec_imag = torch.imag(field_vec)

            #     solved = torch.stack([field_vec_real, field_vec_imag], dim=0)
            #     patched_solved[idx] = solved

            # for now try brute force solving
            print("yee shape:", yeex_batch_train.shape, yeey_batch_train.shape)
            dual_m, phi_c_global, Phi_r = solver.direct_solve(
                x_batch_train, y_batch_train, yeex_batch_train,
                yeey_batch_train)

            # reconstruct the whole field
            intermediate_result = solver.reconstruct(phi_c_global, Phi_r)

            # print("shapes: ",intermediate_result.shape, DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].shape )
            # diff = intermediate_result.contiguous().view(1,-1) - DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].contiguous().view(1,-1)
            # loss = torch.mean(torch.abs(diff)) / \
            #        torch.mean(torch.abs(DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].contiguous().view(1,-1)))
            # print(f"iter {k}, loss {loss}")

            # device_losses[sample_id, k] = loss
            if k % args.div_k == 0:
                img_idx = int(k / args.div_k)

                # Plot the error map (prediction minus ground truth) and a
                # relative L1 error for each of the two channels.
                # im = axs[img_idx+2,0].imshow(intermediate_result[0,:,:])
                im = axs[img_idx + 2, 0].imshow(
                    intermediate_result[0, :, :] -
                    DDM_Hy[0, 0, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y + size_y].numpy())

                loss1 = np.mean(
                    np.abs(intermediate_result[0, :, :] -
                           DDM_Hy[0, 0, args.starting_x:args.starting_x +
                                  size_x, args.starting_y:args.starting_y +
                                  size_y].numpy())
                ) / np.mean(
                    np.abs(DDM_Hy[0, 0, args.starting_x:args.starting_x +
                                  size_x, args.starting_y:args.starting_y +
                                  size_y].numpy()))
                # axs[img_idx+3,0].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 2, 0])
                # im.set_clim(-4,4)

                # im = axs[img_idx+2,1].imshow(intermediate_result[1,:,:])
                im = axs[img_idx + 2, 1].imshow(
                    intermediate_result[1, :, :] -
                    DDM_Hy[0, 1, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y + size_y].numpy())

                loss2 = np.mean(
                    np.abs(intermediate_result[1, :, :] -
                           DDM_Hy[0, 1, args.starting_x:args.starting_x +
                                  size_x, args.starting_y:args.starting_y +
                                  size_y].numpy())
                ) / np.mean(
                    np.abs(DDM_Hy[0, 1, args.starting_x:args.starting_x +
                                  size_x, args.starting_y:args.starting_y +
                                  size_y].numpy()))
                print(f"loss1: {loss1}, loss2: {loss2}")
                # axs[img_idx+3,1].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 2, 1])

        fig.savefig(str(sample_id) + ".png", dpi=500)
# Exemplo n.º 4
# 0
def main(args, solver):
    """Run the per-patch iterative Padé DDM solver over the whole dataset.

    For the first ``args.num_device`` samples, solves each subdomain with
    ``solver`` for ``args.DDM_iters`` iterations while exchanging
    transmission conditions, records the field history / inputs / ground
    truth to .npy files, and saves a per-device relative-error curve.

    Args:
        args: argument namespace with patching/iteration/output settings.
        solver: object providing ``bc_pade_operator``,
            ``construct_matrices_complex`` and ``solve``.
    """
    # Fix all RNG seeds so sampling is reproducible.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number = None)
    # Re-seed so the shuffled loader yields the same samples every run.
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    # Full reconstructed domain size given patch count and overlap.
    model_bs = args.x_patches*args.y_patches
    size_x = args.domain_sizex+(args.x_patches-1)*(args.domain_sizex-args.overlap_pixels)
    size_y = args.domain_sizey+(args.y_patches-1)*(args.domain_sizey-args.overlap_pixels)
    device_losses = np.zeros((args.num_device, args.DDM_iters))

    # Pre-allocated archives: per-sample field iterates plus the matching
    # permittivity patches and ground-truth field patches.
    history_fields=np.zeros((len(DDM_loader), args.DDM_iters+1,model_bs,2,args.domain_sizex,args.domain_sizey))
    x_batch_trains=np.zeros((len(DDM_loader), model_bs,1,args.domain_sizex,args.domain_sizey))
    y_batch_trains=np.zeros((len(DDM_loader), model_bs,2,args.domain_sizex,args.domain_sizey))

    print("shapes:", history_fields.shape, x_batch_trains.shape, y_batch_trains.shape)

    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id>=args.num_device:
            break
        if sample_id%20 == 0:
            print("sample_id: ", sample_id, flush=True)

        DDM_img, DDM_Hy= sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        x_batch_train = [DDM_img[0, 0, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                        args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(model_bs,1,args.domain_sizex,args.domain_sizey)

        # Yee-grid x-averaged permittivity per patch; note the interior-only
        # slicing, hence the (sizex-1, sizey-2) reshape.
        yeex_batch_train = [1/2*(DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0,-1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -2+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(model_bs,1,args.domain_sizex-1,args.domain_sizey-2)

        # Yee-grid y-averaged permittivity per patch, (sizex-2, sizey-1).
        yeey_batch_train = [1/2*(DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                              -1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -2+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(model_bs,1,args.domain_sizex-2,args.domain_sizey-1)

        y_batch_train = [DDM_Hy[0, :, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                       args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]  for i in range(args.x_patches) for j in range(args.y_patches)]
        # NOTE(review): this reshape uses domain_sizex twice; presumably the
        # last dim should be domain_sizey (works only for square patches) —
        # TODO confirm against the slicing above.
        y_batch_train = torch.stack(y_batch_train).reshape(model_bs,2,args.domain_sizex,args.domain_sizex)

        intep_field, patched_solved = init_four_point_interp(DDM_Hy, prop2, args)

        # Complex transmission terms carried over between DDM iterations.
        b, _, n, m = patched_solved.shape
        last_gs = np.zeros((b,n,m), dtype=np.csingle)

        # history_fields=np.zeros((args.DDM_iters+1,b,2,n,m))
        # history_fields[0, :, :, :, :] = patched_solved
        history_fields[sample_id, 0, :, :, :, :] = patched_solved
        x_batch_trains[sample_id] = x_batch_train
        y_batch_trains[sample_id] = y_batch_train


        for k in range(args.DDM_iters):
            # Solve each subdomain independently with the current
            # transmission conditions, then update patched_solved in place.
            for idx in range(model_bs):
                # left, right, top, bottom
                ops = [solver.bc_pade_operator(1/2*(x_batch_train[idx,0, :, 0]+x_batch_train[idx,0, :, 1]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, :, -2]+x_batch_train[idx,0, :,-1]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, 1, :]+x_batch_train[idx,0, 0, :]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, -2, :]+x_batch_train[idx,0,-1, :]).numpy())]
                # First iteration has no previous transmission term g.
                if k==0:
                    g, alpha, beta, gamma, g_mul = trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, idx, args.transmission_func, args, ops)
                else:
                    g, alpha, beta, gamma, g_mul = trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, idx, args.transmission_func, args, ops, last_gs[idx])
                last_gs[idx] = g

                A,b = solver.construct_matrices_complex(g, ops, yeex_batch_train[idx,0], yeey_batch_train[idx,0], alpha, beta, gamma, g_mul)

                # Split the complex solution into real/imag channels.
                field_vec = torch.tensor(solver.solve(A, b).reshape((args.domain_sizex, args.domain_sizey)), dtype=torch.cfloat)
                field_vec_real = torch.real(field_vec)
                field_vec_imag = torch.imag(field_vec)

                solved = torch.stack([field_vec_real, field_vec_imag], dim=0)
                patched_solved[idx] = solved

            history_fields[sample_id, k+1, :, :, :, :] = patched_solved

            # reconstruct the whole field
            intermediate_result = reconstruct(patched_solved, args)
            # print("shapes: ",intermediate_result.shape, DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].shape )
            # Relative L1 error of the reconstruction vs. the ground truth.
            diff = intermediate_result.contiguous().view(1,-1) - DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].contiguous().view(1,-1)
            loss = torch.mean(torch.abs(diff)) / \
                   torch.mean(torch.abs(DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].contiguous().view(1,-1)))
            print(f"iter {k}, loss {loss}")
            device_losses[sample_id, k] = loss

    # Persist the iterate history, inputs and ground truth for offline use.
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_field_history.npy", history_fields)
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_eps.npy", x_batch_trains)
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_Hy_gt.npy", y_batch_trains)

    # Convergence curves: one line per device, log-scale relative error.
    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    plt.savefig(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_device_loss.png", dpi=300)
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_device_losses.npy", device_losses)
def main(args, solver):
    """Overlapping-Schwarz DDM evaluation with a direct real-valued patch solver.

    Splits the global domain of each sample into ``args.x_patches *
    args.y_patches`` overlapping subdomains.  Each DDM iteration solves every
    patch (real and imaginary channels separately) with Dirichlet boundary
    data taken from the previous iterate, reconstructs the global field, and
    plots snapshots together with the error against the ground truth DDM_Hy.

    Args:
        args: namespace carrying patch geometry (domain_sizex/y, x/y_patches,
            overlap_pixels, starting_x/y), iteration controls (DDM_iters,
            div_k) and data_folder.
        solver: object providing ``construct_matrices`` and ``solve`` for a
            single real-valued subdomain problem.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Size of the full region covered by the overlapping patch grid.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex -
                                                         args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey -
                                                         args.overlap_pixels)
    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id > 2:  # evaluate only the first three samples
            break

        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        model_bs = args.x_patches * args.y_patches
        x_batch_train = [DDM_img[0, 0, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                        args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey)

        # Structure averaged onto the x- and y-staggered (Yee-style) points.
        yeex_batch_train = [1/2*(DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0,-1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -2+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 1, args.domain_sizey - 2)

        yeey_batch_train = [1/2*(DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                              -1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -2+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 2, args.domain_sizey - 1)

        y_batch_train = [DDM_Hy[0, :, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                       args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]  for i in range(args.x_patches) for j in range(args.y_patches)]
        # BUG FIX: the last reshape axis was domain_sizex twice; the patch
        # width is domain_sizey (the old code crashed whenever
        # domain_sizex != domain_sizey).
        y_batch_train = torch.stack(y_batch_train).reshape(
            model_bs, 2, args.domain_sizex, args.domain_sizey)

        top_bc_train, bottom_bc_train, left_bc_train, right_bc_train, intep_field = init_four_point_interp(
            DDM_Hy, prop2, args)

        # Figure layout: rows 0-2 show structure, ground truth and the
        # initial interpolated field; later rows hold per-iteration snapshots.
        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 3, 4)
        axs[0,
            0].imshow(DDM_img[0, 0, args.starting_x:args.starting_x + size_x,
                              args.starting_y:args.starting_y + size_y])
        axs[0,
            1].imshow(DDM_img[0, 0, args.starting_x:args.starting_x + size_x,
                              args.starting_y:args.starting_y + size_y])
        axs[1, 0].imshow(DDM_Hy[0, 0, args.starting_x:args.starting_x + size_x,
                                args.starting_y:args.starting_y + size_y])
        axs[1, 1].imshow(DDM_Hy[0, 1, args.starting_x:args.starting_x + size_x,
                                args.starting_y:args.starting_y + size_y])
        axs[2, 0].imshow(intep_field[0, :, :])
        axs[2, 1].imshow(intep_field[1, :, :])

        for k in range(args.DDM_iters):
            patched_solved = torch.zeros(y_batch_train.shape)

            # Assemble the Dirichlet boundary ring of every patch from the
            # four per-edge boundary tensors.
            boundary_field = torch.zeros(y_batch_train.shape)
            boundary_field[:, :, 0:1, :] = top_bc_train[:, :, :, :]
            boundary_field[:, :, -1:, :] = bottom_bc_train[:, :, :, :]
            boundary_field[:, :, :, 0:1] = left_bc_train[:, :, :, :]
            boundary_field[:, :, :, -1:] = right_bc_train[:, :, :, :]

            for idx in range(model_bs):

                # Real and imaginary channels are solved as two independent
                # real-valued systems on the patch interior.
                A_real, b_real = solver.construct_matrices(
                    boundary_field[idx, 0], yeex_batch_train[idx, 0],
                    yeey_batch_train[idx, 0])
                A_imag, b_imag = solver.construct_matrices(
                    boundary_field[idx, 1], yeex_batch_train[idx, 0],
                    yeey_batch_train[idx, 0])
                field_vec_real = solver.solve(A_real, b_real).reshape(
                    (args.domain_sizex - 2, args.domain_sizey - 2))
                field_vec_imag = solver.solve(A_imag, b_imag).reshape(
                    (args.domain_sizex - 2, args.domain_sizey - 2))

                # Re-attach the boundary ring around the solved interior.
                inner_field = torch.stack([field_vec_real, field_vec_imag],
                                          dim=0)
                solved = torch.zeros((2, args.domain_sizex, args.domain_sizey))
                solved[:, 0, :] = torch.clone(boundary_field[idx, :, 0, :])
                solved[:, -1, :] = torch.clone(boundary_field[idx, :, -1, :])
                solved[:, :, 0] = torch.clone(boundary_field[idx, :, :, 0])
                solved[:, :, -1] = torch.clone(boundary_field[idx, :, :, -1])
                solved[:, 1:-1, 1:-1] = inner_field

                patched_solved[idx] = solved

            # reconstruct the whole field
            intermediate_result = reconstruct(patched_solved, args)
            # Mean absolute error against the ground truth over the region.
            loss = torch.mean(
                torch.abs(intermediate_result.contiguous().view(1, -1) -
                          DDM_Hy[:, :, args.starting_x:args.starting_x +
                                 size_x, args.starting_y:args.starting_y +
                                 size_y].contiguous().view(1, -1)))

            if k % args.div_k == 0:
                img_idx = int(k / args.div_k)
                im = axs[img_idx + 3, 0].imshow(intermediate_result[0, :, :])
                axs[img_idx + 3, 0].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 0])

                im = axs[img_idx + 3, 1].imshow(intermediate_result[1, :, :])
                axs[img_idx + 3, 1].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 1])

                im = axs[img_idx + 3,
                         2].imshow(reconstruct(y_batch_train, args)[0, :, :])
                axs[img_idx + 3, 2].set_title("y_batch_train: " + str(loss),
                                              fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 2])

                im = axs[img_idx + 3,
                         3].imshow(reconstruct(y_batch_train, args)[1, :, :])
                axs[img_idx + 3, 3].set_title("y_batch_train: " + str(loss),
                                              fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 3])

            # Then prepare the data for next iteration:
            top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = new_iter_bcs(
                patched_solved, y_batch_train, args, robin_transform_2nd)

        fig.savefig("eval_" + str(sample_id) + ".png", dpi=3000)
# Exemplo n.º 6 — 0  (extraction artifact separating scraped code examples;
# commented out so the file remains valid Python)
def main(args, solver):
    """DDM evaluation with a complex-valued patch solver and 2nd-order
    transmission conditions.

    For one selected sample, each iteration builds the transmission data ``g``
    (plus the alpha/beta/gamma coefficients) for every patch from the previous
    iterate, solves the complex patch system, reconstructs the global field,
    prints the mean absolute error against the ground truth DDM_Hy, and plots
    snapshots every ``args.div_k`` iterations.

    Args:
        args: namespace with patch geometry, iteration controls (DDM_iters,
            div_k), transmission settings (transmission_func, alpha, beta)
            and data_folder.
        solver: object providing ``construct_matrices_complex``, ``solve``
            and the differential operator ``diffL``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number=None)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Size of the full region covered by the overlapping patch grid.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex -
                                                         args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey -
                                                         args.overlap_pixels)
    device_losses = np.zeros((args.num_device, args.DDM_iters))
    for sample_id, sample_batched in enumerate(DDM_loader):
        # Only the sample with index 4 is evaluated.
        if sample_id > 4:
            break
        if sample_id < 4:
            continue

        DDM_img, DDM_Hy = sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        model_bs = args.x_patches * args.y_patches
        x_batch_train = [DDM_img[0, 0, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                        args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(
            model_bs, 1, args.domain_sizex, args.domain_sizey)

        # Structure averaged onto the x- and y-staggered (Yee-style) points.
        yeex_batch_train = [1/2*(DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0,-1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -2+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 1, args.domain_sizey - 2)

        yeey_batch_train = [1/2*(DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                              -1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -2+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(
            model_bs, 1, args.domain_sizex - 2, args.domain_sizey - 1)

        y_batch_train = [DDM_Hy[0, :, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                       args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]  for i in range(args.x_patches) for j in range(args.y_patches)]
        # BUG FIX: the last reshape axis was domain_sizex twice; the patch
        # width is domain_sizey (the old code crashed whenever
        # domain_sizex != domain_sizey).
        y_batch_train = torch.stack(y_batch_train).reshape(
            model_bs, 2, args.domain_sizex, args.domain_sizey)

        intep_field, patched_solved = init_four_point_interp(
            DDM_Hy, prop2, args)

        # Figure layout: rows 0-2 show structure, ground truth and the
        # initial interpolated field; later rows hold per-iteration snapshots.
        plt.rcParams["font.size"] = "5"
        fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 3, 4)
        axs[0,
            0].imshow(DDM_img[0, 0, args.starting_x:args.starting_x + size_x,
                              args.starting_y:args.starting_y + size_y])
        axs[0,
            1].imshow(DDM_img[0, 0, args.starting_x:args.starting_x + size_x,
                              args.starting_y:args.starting_y + size_y])
        axs[1, 0].imshow(DDM_Hy[0, 0, args.starting_x:args.starting_x + size_x,
                                args.starting_y:args.starting_y + size_y])
        axs[1, 1].imshow(DDM_Hy[0, 1, args.starting_x:args.starting_x + size_x,
                                args.starting_y:args.starting_y + size_y])
        axs[2, 0].imshow(intep_field[0, :, :])
        axs[2, 1].imshow(intep_field[1, :, :])

        # Keep the previous iteration's transmission data per patch.
        # (Renamed the unpacked batch dim from `b` so it is not shadowed by
        # the RHS vector `b` returned by the solver below.)
        n_patches, _, n, m = patched_solved.shape
        last_gs = torch.zeros((n_patches, n, m), dtype=torch.cfloat)
        for k in range(args.DDM_iters):
            for idx in range(model_bs):
                # First iteration has no previous g; later iterations pass it.
                if k == 0:
                    g, alpha, beta, gamma = trasmission_2nd_g_alpha_beta_gamma_complex(
                        patched_solved, x_batch_train, idx,
                        args.transmission_func, args, solver.diffL)
                else:
                    g, alpha, beta, gamma = trasmission_2nd_g_alpha_beta_gamma_complex(
                        patched_solved, x_batch_train, idx,
                        args.transmission_func, args, solver.diffL,
                        last_gs[idx])
                last_gs[idx] = g

                A, b = solver.construct_matrices_complex(
                    g, yeex_batch_train[idx, 0], yeey_batch_train[idx, 0],
                    alpha, beta, gamma)

                field_vec = solver.solve(A, b).reshape(
                    (args.domain_sizex, args.domain_sizey))
                field_vec_real = torch.real(field_vec)
                field_vec_imag = torch.imag(field_vec)

                solved = torch.stack([field_vec_real, field_vec_imag], dim=0)
                patched_solved[idx] = solved

            # reconstruct the whole field
            intermediate_result = reconstruct(patched_solved, args)
            # Mean absolute error against the ground truth over the region.
            loss = torch.mean(
                torch.abs(intermediate_result.contiguous().view(1, -1) -
                          DDM_Hy[:, :, args.starting_x:args.starting_x +
                                 size_x, args.starting_y:args.starting_y +
                                 size_y].contiguous().view(1, -1)))
            print(f"iter {k}, loss {loss}")
            # device_losses[sample_id, k] = loss
            if k % args.div_k == 0:
                img_idx = int(k / args.div_k)
                im = axs[img_idx + 3, 0].imshow(intermediate_result[0, :, :])
                axs[img_idx + 3, 0].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 0])

                im = axs[img_idx + 3, 1].imshow(intermediate_result[1, :, :])
                axs[img_idx + 3, 1].set_title("loss: " + str(loss), fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 1])

                im = axs[img_idx + 3,
                         2].imshow(reconstruct(y_batch_train, args)[0, :, :])
                axs[img_idx + 3, 2].set_title("y_batch_train: " + str(loss),
                                              fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 2])

                im = axs[img_idx + 3,
                         3].imshow(reconstruct(y_batch_train, args)[1, :, :])
                axs[img_idx + 3, 3].set_title("y_batch_train: " + str(loss),
                                              fontsize=4)
                plt.colorbar(im, ax=axs[img_idx + 3, 3])

        fig.savefig(f"alpha_{args.alpha}_beta{args.beta}_" + str(sample_id) +
                    ".png",
                    dpi=500)

    # NOTE(review): device_losses is never filled (the assignment above is
    # commented out, possibly because sample_id=4 can exceed num_device), so
    # this plot currently shows all zeros — confirm intent before relying on it.
    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    plt.savefig(f"alpha_{args.alpha}_beta{args.beta}_" + "device_losses.png",
                dpi=300)
def main(args):
    """Fully ML-driven DDM evaluation.

    Loads a pre-trained subdomain model (``s_model``) and boundary model
    (``b_model``), then for each of ``args.num_device`` samples iterates the
    DDM loop: predict every patch field with s_model from the current
    (normalized) boundary conditions, reconstruct the global field, record
    the loss against the ground truth, and generate next-iteration boundary
    data either with a traditional transmission condition
    (``args.traditional_bc == 1``) or with b_model.

    Args:
        args: namespace with model paths/hyper-parameters, patch geometry
            and iteration controls.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
        FileNotFoundError: if a saved-model directory does not exist.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)
    device = torch.device('cuda')

    s_model = None
    if args.arch == "UNet":
        s_model = UNet(args).to(device)
    elif args.arch == "Fourier":
        s_model = FNO_multimodal_2d(args).to(device)
    else:
        # BUG FIX: raising a plain string is a TypeError in Python 3, and the
        # message lacked the f-prefix so {args.arch} was never interpolated.
        raise ValueError(f"architecture {args.arch} hasn't been added!!")

    s_model_path = args.s_model_saving_path + args.s_model_name + \
                                          "_domain_size_" + str(args.domain_sizex) + "_"+ str(args.domain_sizey) + \
                                          "_fmodes_" + str(args.f_modes) + \
                                          "_flayers_" + str(args.num_fourier_layers) + \
                                          "_Hidden_" + str(args.HIDDEN_DIM) + \
                                          "_f_padding_" + str(args.f_padding) + \
                                          "_batch_size_" + str(args.s_batch_size) + "_lr_" + str(args.s_lr)
    if not os.path.isdir(s_model_path):
        print(f"path error: {s_model_path}")
        # BUG FIX: a bare `raise` outside an except block is itself an error
        # (RuntimeError: no active exception); raise an explicit exception.
        raise FileNotFoundError(s_model_path)

    b_model_path = args.b_model_saving_path + args.b_model_name + \
                                          "_b_hidden_" + str(args.b_hidden) + \
                                          "_b_layers_" + str(args.b_layers) + \
                                          "_boundary_thickness_" + str(args.boundary_thickness) + \
                                          "_domain_size_" + str(args.domain_sizex) + \
                                          "_b_batch_size_" + str(args.b_batch_size) + "_b_lr_" + str(args.b_lr)

    if not os.path.isdir(b_model_path):
        print(f"path error: {b_model_path}")
        raise FileNotFoundError(b_model_path)

    # load the already trained s_model (subdomain model):
    print("Restoring subdomain model weights from ",
          s_model_path + "/best_model.pt",
          flush=True)
    checkpoint = torch.load(s_model_path + "/best_model.pt")
    s_model.load_state_dict(checkpoint['model'].state_dict())
    s_model.to(device)

    # init b_model (boundary model):
    b_model = boundary_model(args).to(device)
    print("Restoring boundary model weights from ",
          b_model_path + "/best_model.pt",
          flush=True)
    checkpoint = torch.load(b_model_path + "/best_model.pt")
    b_model.load_state_dict(checkpoint['model'].state_dict())
    b_model.to(device)

    # Report trainable-parameter counts for both models.
    tmp = filter(lambda x: x.requires_grad, s_model.parameters())
    s_num = sum(map(lambda x: np.prod(x.shape), tmp))

    tmp = filter(lambda x: x.requires_grad, b_model.parameters())
    b_num = sum(map(lambda x: np.prod(x.shape), tmp))

    print('s_model total trainable tensors:', s_num, flush=True)
    print('b_model total trainable tensors:', b_num, flush=True)

    # data loader
    ds = DDM_Dataset(args.data_folder,
                     total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # Size of the full region covered by the overlapping patch grid.
    size_x = args.domain_sizex + (args.x_patches - 1) * (args.domain_sizex -
                                                         args.overlap_pixels)
    size_y = args.domain_sizey + (args.y_patches - 1) * (args.domain_sizey -
                                                         args.overlap_pixels)
    device_losses = np.zeros((args.num_device, args.DDM_iters))

    start_time = timeit.default_timer()
    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id >= args.num_device:
            break
        with torch.no_grad():

            DDM_img, DDM_Hy = sample_batched['structure'], sample_batched[
                'field']

            # prepare the input batched subdomains to model:
            model_bs = args.x_patches * args.y_patches
            x_batch_train = [DDM_img[0, 0, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                            args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] for i in range(args.x_patches) for j in range(args.y_patches)]
            # BUG FIX: both reshapes below used domain_sizex for the last
            # axis; the patch width is domain_sizey (the old code crashed
            # whenever domain_sizex != domain_sizey).
            x_batch_train = torch.stack(x_batch_train).reshape(
                model_bs, 1, args.domain_sizex, args.domain_sizey).to(device)

            y_batch_train = [DDM_Hy[0, :, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                           args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]  for i in range(args.x_patches) for j in range(args.y_patches)]
            y_batch_train = torch.stack(y_batch_train).reshape(
                model_bs, 2, args.domain_sizex, args.domain_sizey).to(device)

            top_bc_train, bottom_bc_train, left_bc_train, right_bc_train, intep_field = init_four_point_interp(
                DDM_Hy, prop2, args, device, return_bc=True)

            # Figure layout: rows 0-2 show structure, ground truth and the
            # initial interpolated field; later rows hold iteration snapshots.
            plt.rcParams["font.size"] = "5"
            fig, axs = plt.subplots(int(args.DDM_iters / args.div_k) + 3, 2)
            im = axs[0, 0].imshow(
                DDM_img[0, 0, args.starting_x:args.starting_x + size_x,
                        args.starting_y:args.starting_y + size_y])
            plt.colorbar(im, ax=axs[0, 0])
            im = axs[0, 1].imshow(
                DDM_img[0, 0, args.starting_x:args.starting_x + size_x,
                        args.starting_y:args.starting_y + size_y])
            plt.colorbar(im, ax=axs[0, 1])
            im = axs[1, 0].imshow(
                DDM_Hy[0, 0, args.starting_x:args.starting_x + size_x,
                       args.starting_y:args.starting_y + size_y])
            plt.colorbar(im, ax=axs[1, 0])
            im = axs[1, 1].imshow(
                DDM_Hy[0, 1, args.starting_x:args.starting_x + size_x,
                       args.starting_y:args.starting_y + size_y])
            plt.colorbar(im, ax=axs[1, 1])
            im = axs[2, 0].imshow(intep_field[0, :, :])
            plt.colorbar(im, ax=axs[2, 0])
            im = axs[2, 1].imshow(intep_field[1, :, :])
            plt.colorbar(im, ax=axs[2, 1])

            for k in range(args.DDM_iters):

                # Normalize boundary inputs by their mean magnitude so the
                # model sees a consistent scale; rescale its output back.
                bc_mean = 1 / 4 * (
                    torch.mean(
                        torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(left_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
                logits = s_model(x_batch_train, left_bc_train / bc_mean,
                                 right_bc_train / bc_mean, top_bc_train /
                                 bc_mean, bottom_bc_train / bc_mean).reshape(
                                     y_batch_train.shape) * bc_mean

                # reconstruct the whole field
                intermediate_result = reconstruct(logits, args)
                loss = s_model.loss_fn(
                    intermediate_result.contiguous().view(1, -1),
                    DDM_Hy[:, :, args.starting_x:args.starting_x + size_x,
                           args.starting_y:args.starting_y +
                           size_y].contiguous().view(1, -1))

                device_losses[sample_id, k] = loss
                if k % args.div_k == 0:
                    img_idx = int(k / args.div_k)
                    im = axs[img_idx + 3,
                             0].imshow(intermediate_result[0, :, :])
                    axs[img_idx + 3, 0].set_title("loss: " + str(loss),
                                                  fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 0])

                    im = axs[img_idx + 3,
                             1].imshow(intermediate_result[1, :, :])
                    axs[img_idx + 3, 1].set_title("loss: " + str(loss),
                                                  fontsize=4)
                    plt.colorbar(im, ax=axs[img_idx + 3, 1])

                # Then prepare the data for next iteration:
                if args.traditional_bc == 1:
                    top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = new_iter_bcs(
                        logits,
                        y_batch_train,
                        args,
                        args.bc_func,
                        device=device)
                else:
                    top_bc_train, bottom_bc_train, left_bc_train, right_bc_train = new_iter_bcs_model(
                        x_batch_train.reshape(
                            (args.x_patches, args.y_patches, 1,
                             args.domain_sizex, args.domain_sizey)),
                        logits,
                        y_batch_train,
                        args,
                        b_model,
                        device=device)

            if args.traditional_bc == 1:
                fig.savefig("traditional_bc_" + str(sample_id) + ".png",
                            dpi=1000)
            else:
                fig.savefig("all_ml_" + str(sample_id) + ".png", dpi=1000)

    end_time = timeit.default_timer()
    print(
        f"Run time for {args.num_device} devices each with {args.DDM_iters} iterations (note: this may include time for plotting and saving figure):",
        end_time - start_time)

    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    if args.traditional_bc == 1:
        plt.savefig("traditional_bc_device_loss.png", dpi=300)
    else:
        plt.savefig("all_ml_device_loss.png", dpi=300)