import os
import timeit

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
from torch import nn, optim
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split

# SimulationDataset, UNet, FNO_multimodal_2d, train, H_to_H, regConstScheduler,
# four_point_interp_inpaint and the constants n_Si, n_air, n_sub, dL, omega, prop2
# are assumed to come from the project's own modules.


def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    if args.gpu_id != -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # config = []

    # device = torch.device('cuda')

    # model = None
    # if args.arch == "UNet":
    #     model = UNet(args).to(device)
    # elif args.arch == "Fourier":
    #     model = FNO_multimodal_2d(args).to(device)
    # else:
    #     raise("architecture {args.arch} hasn't been added!!")

    # # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    # optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)

    # tmp = filter(lambda x: x.requires_grad, model.parameters())
    # num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(model)

    # #for name, param in model.named_parameters():
    # #    print(name, param.size())
    # print('Total trainable tensors:', num, flush=True)

    # model_path = args.model_saving_path + args.model_name + \
    #                                       "_domain_size_" + str(args.domain_sizex) + "_"+ str(args.domain_sizey) + \
    #                                       "_fmodes_" + str(args.f_modes) + \
    #                                       "_flayers_" + str(args.num_fourier_layers) + \
    #                                       "_Hidden_" + str(args.HIDDEN_DIM) + \
    #                                       "_f_padding_" + str(args.f_padding) + \
    #                                       "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    # if not os.path.isdir(model_path):
    #     os.mkdir(model_path)

    ds = SimulationDataset(args.data_folder,
                           to_cuda=args.to_cuda,
                           total_sample_number=args.total_sample_number)
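    # Re-seed right before the split so the 90/10 train/test partition is reproducible
    # across runs.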
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    # train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    # test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)

    # train_mean = 0
    # test_mean = 0
    # # first get the mean-absolute-field value:
    # for sample_batched in train_loader:
    #     train_mean += torch.mean(torch.abs(sample_batched["field"]))
    # for sample_batched in test_loader:
    #     test_mean += torch.mean(torch.abs(sample_batched["field"]))
    # train_mean /= len(train_loader)
    # test_mean /= len(test_loader)

    # print("total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f" % (len(train_ds), len(test_ds), train_mean, test_mean), flush=True)

    # df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    # train_loss_history = []
    # train_phys_reg_history = []

    # test_loss_history = []
    # test_phys_reg_history = []

    # start_epoch=0
    # if (args.continue_train == "True"):
    #     print("Restoring weights from ", model_path+"/last_model.pt", flush=True)
    #     checkpoint = torch.load(model_path+"/last_model.pt")
    #     start_epoch=checkpoint['epoch']
    #     model = checkpoint['model']
    #     lr_scheduler = checkpoint['lr_scheduler']
    #     optimizer = checkpoint['optimizer']
    #     df = pd.read_csv(model_path + '/'+'df.csv')

    # scaler = GradScaler()

    # best_loss = 1e4
    # last_epoch_data_loss = 1.0
    # last_epoch_physical_loss = 1.0
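    # Single-node distributed launch: spawn one worker per GPU; torch.multiprocessing passes
    # the process rank as the first argument to train(), followed by (args, train_ds, test_ds).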
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8764'
    mp.spawn(train, nprocs=args.world_size, args=(args, train_ds, test_ds))

# Example 2

def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args)
    else:
        raise ("architecture {args.arch} hasn't been added!!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    loss_fn = model.loss_fn
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
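    # ExponentialLR multiplies the learning rate by args.exp_decay on every scheduler.step(),
    # which the training loop below calls once every args.lr_update_steps gradient updates.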

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = args.model_saving_path + args.model_name + \
                                          "_domain_size_" + str(args.domain_sizex) + "_"+ str(args.domain_sizey) + \
                                          "_fmodes_" + str(args.f_modes) + \
                                          "_flayers_" + str(args.num_fourier_layers) + \
                                          "_Hidden_" + str(args.HIDDEN_DIM) + \
                                          "_f_padding_" + str(args.f_padding) + \
                                          "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    ds = SimulationDataset(args.data_folder,
                           to_cuda=args.to_cuda,
                           total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print(
        "total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f"
        % (len(train_ds), len(test_ds), train_mean, test_mean),
        flush=True)

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()

    df = pd.DataFrame(columns=[
        'epoch', 'train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'
    ])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch = 0
    if (args.continue_train == "True"):
        print("Restoring weights from ",
              model_path + "/last_model.pt",
              flush=True)
        checkpoint = torch.load(model_path + "/last_model.pt")
        start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        lr_scheduler = checkpoint['lr_scheduler']
        optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/' + 'df.csv')

    # scaler = GradScaler()
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss = 1.0

    gradient_count = 0
    for step in range(start_epoch, args.epoch):
        epoch_start_time = timeit.default_timer()
        print("epoch: ", step, flush=True)
        # reg_norm = regConstScheduler(step, args, last_epoch_data_loss, last_epoch_physical_loss);
        # training
        for sample_batched in train_loader:
            gradient_count += 1
            optimizer.zero_grad()

            x_batch_train, y_batch_train = sample_batched['structure'], sample_batched['field']
            top_bc_train, bottom_bc_train = sample_batched['top_bc'], sample_batched['bottom_bc']
            left_bc_train, right_bc_train = sample_batched['left_bc'], sample_batched['right_bc']
            if args.to_cuda == False:
                x_batch_train = x_batch_train.to(device)
                y_batch_train = y_batch_train.to(device)
                top_bc_train = top_bc_train.to(device)
                bottom_bc_train = bottom_bc_train.to(device)
                left_bc_train = left_bc_train.to(device)
                right_bc_train = right_bc_train.to(device)
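            # Normalize the boundary-condition inputs by their mean absolute magnitude so the
            # network sees O(1) values; the prediction is rescaled by the same factor below.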
            bc_mean = 1 / 4 * (
                torch.mean(torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(
                    torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                torch.mean(torch.abs(left_bc_train), dim=(2, 3), keepdim=True)
                + torch.mean(
                    torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
            # with autocast():
            logits = model(x_batch_train, top_bc_train / bc_mean,
                           bottom_bc_train / bc_mean, left_bc_train / bc_mean,
                           right_bc_train / bc_mean).reshape(
                               y_batch_train.shape) * bc_mean
            #calculate the loss using the ground truth
            loss = loss_fn(
                logits.contiguous().view(args.batch_size, -1),
                y_batch_train.contiguous().view(args.batch_size, -1))

            # logits = logits[:,:,1:-1, :]
            # # print("loss: ", loss, flush=True)

            # # Calculate physical residue
            # # pattern = (x_batch_train*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
            # pattern = x_batch_train
            # fields = logits[:,:,1:-1, :]
            # fields = torch.cat((y_batch_train[:, :, 0:1, :], fields, y_batch_train[:, :, -1:, :]), dim=2);
            # fields = fields[:,:,:,1:-1]
            # fields = torch.cat((y_batch_train[:, :, :, 0:1], fields, y_batch_train[:, :, :, -1:]), dim=3);
            # FD_Hy = H_to_H(-fields[:, 0], -fields[:, 1], yeex, yeey, dL, omega, pattern);
            # phys_regR = loss_fn(FD_Hy[:, 0], logits[:, 0, 1:-1, 1:-1]);
            # phys_regI = loss_fn(FD_Hy[:, 1], logits[:, 1, 1:-1, 1:-1]);
            # loss += phys_regR + phys_regI;

            # scaler.scale(loss).backward()
            # scaler.step(optimizer)
            # scaler.update()

            loss.backward()
            optimizer.step()

            if gradient_count >= args.lr_update_steps:
                gradient_count = 0
                lr_scheduler.step()
                print('lr: ', lr_scheduler.get_last_lr())

        #Save the weights at the end of each epoch
        checkpoint = {
            'epoch': step,
            'model': model,
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler
        }
        torch.save(checkpoint, model_path + "/last_model.pt")

        # evaluation
        train_loss = 0
        # train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'], sample_batched['field']
            top_bc_train, bottom_bc_train = sample_batched['top_bc'], sample_batched['bottom_bc']
            left_bc_train, right_bc_train = sample_batched['left_bc'], sample_batched['right_bc']
            if args.to_cuda == False:
                x_batch_train = x_batch_train.to(device)
                y_batch_train = y_batch_train.to(device)
                top_bc_train = top_bc_train.to(device)
                bottom_bc_train = bottom_bc_train.to(device)
                left_bc_train = left_bc_train.to(device)
                right_bc_train = right_bc_train.to(device)

            with torch.no_grad():
                bc_mean = 1 / 4 * (
                    torch.mean(
                        torch.abs(top_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(bottom_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(left_bc_train), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(right_bc_train), dim=(2, 3), keepdim=True))
                # with autocast():
                logits = model(x_batch_train, top_bc_train / bc_mean,
                               bottom_bc_train / bc_mean, left_bc_train /
                               bc_mean, right_bc_train / bc_mean).reshape(
                                   y_batch_train.shape) * bc_mean

                loss = loss_fn(
                    logits.contiguous().view(args.batch_size, -1),
                    y_batch_train.contiguous().view(args.batch_size, -1))

                # Calculate physical residue
                # pattern = (x_batch_train*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
                # pattern = x_batch_train
                # fields = logits[:,:,1:-1, :]
                # fields = torch.cat((y_batch_train[:, :, 0:1, :], fields, y_batch_train[:, :, -1:, :]), dim=2);
                # fields = fields[:,:,:,1:-1]
                # fields = torch.cat((y_batch_train[:, :, :, 0:1], fields, y_batch_train[:, :, :, -1:]), dim=3);
                # FD_Hy = H_to_H(-fields[:, 0], -fields[:, 1], yeex, yeey, dL, omega, pattern);
                # phys_regR = loss_fn(FD_Hy[:, 0], logits[:, 0, 1:-1, 1:-1]);
                # phys_regI = loss_fn(FD_Hy[:, 1], logits[:, 1, 1:-1, 1:-1]);

                train_loss += loss
                # train_phys_reg += 0.5*(phys_regR + phys_regI);

        train_loss /= len(train_loader)
        # train_phys_reg /= len(train_loader)

        test_loss = 0
        # test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'], sample_batched['field']
            top_bc_test, bottom_bc_test = sample_batched['top_bc'], sample_batched['bottom_bc']
            left_bc_test, right_bc_test = sample_batched['left_bc'], sample_batched['right_bc']
            if args.to_cuda == False:
                x_batch_test = x_batch_test.to(device)
                y_batch_test = y_batch_test.to(device)
                top_bc_test = top_bc_test.to(device)
                bottom_bc_test = bottom_bc_test.to(device)
                left_bc_test = left_bc_test.to(device)
                right_bc_test = right_bc_test.to(device)

            with torch.no_grad():
                bc_mean = 1 / 4 * (
                    torch.mean(
                        torch.abs(top_bc_test), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(bottom_bc_test), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(left_bc_test), dim=(2, 3), keepdim=True) +
                    torch.mean(
                        torch.abs(right_bc_test), dim=(2, 3), keepdim=True))
                # with autocast():
                logits = model(x_batch_test, top_bc_test / bc_mean,
                               bottom_bc_test / bc_mean, left_bc_test /
                               bc_mean, right_bc_test / bc_mean).reshape(
                                   y_batch_test.shape) * bc_mean

                loss = loss_fn(
                    logits.contiguous().view(args.batch_size, -1),
                    y_batch_test.contiguous().view(args.batch_size, -1))

                # Calculate physical residue
                # pattern = x_batch_test
                # fields = logits[:,:,1:-1, :]
                # fields = torch.cat((y_batch_test[:, :, 0:1, :], fields, y_batch_test[:, :, -1:, :]), dim=2);
                # fields = fields[:,:,:,1:-1]
                # fields = torch.cat((y_batch_test[:, :, :, 0:1], fields, y_batch_test[:, :, :, -1:]), dim=3);
                # FD_Hy = H_to_H(-fields[:, 0], -fields[:, 1], yeex, yeey, dL, omega, pattern);
                # phys_regR = loss_fn(FD_Hy[:, 0], logits[:, 0, 1:-1, 1:-1]);
                # phys_regI = loss_fn(FD_Hy[:, 1], logits[:, 1, 1:-1, 1:-1]);

                test_loss += loss
                # test_phys_reg += 0.5*(phys_regR + phys_regI);

        test_loss /= len(test_loader)
        # test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        # last_epoch_physical_loss = test_phys_reg.detach().clone()
        # test_phys_reg *= reg_norm

        print('train loss: %.5f, test loss: %.5f' % (train_loss, test_loss),
              flush=True)

        # DataFrame.append was removed in pandas 2.0; build the row and concatenate instead.
        new_row = pd.DataFrame([{'epoch': step + 1,
                                 'lr': str(lr_scheduler.get_last_lr()),
                                 'train_loss': train_loss.item(),
                                 'test_loss': test_loss.item()}])
        df = pd.concat([df, new_row], ignore_index=True)

        df.to_csv(model_path + '/' + 'df.csv', index=False)

        if (test_loss < best_loss):
            best_loss = test_loss
            checkpoint = {
                'epoch': step,
                'model': model,
                'optimizer': optimizer,
                'lr_scheduler': lr_scheduler
            }
            torch.save(checkpoint, model_path + "/best_model.pt")
        epoch_stop_time = timeit.default_timer()
        print("epoch run time:", epoch_stop_time - epoch_start_time)
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO_multimodal_2d(args).to(device)
    else:
        raise("architecture {args.arch} hasn't been added!!")


    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    
    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    
    ds = SimulationDataset(args.data_folder, total_sample_number=args.total_sample_number)
    means = [1e-3, 1e-3]  # ds.means: [Hy_meanR, Hy_meanI, Ex_meanR, Ex_meanI, Ez_meanR, Ez_meanI]
    print("means: ", means)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(ds, [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])
    
    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    
    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print("total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f" % (len(train_ds), len(test_ds), train_mean, test_mean), flush=True)
    
    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()
    

    df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch = 0
    if args.continue_train:
        print("Restoring weights from ", model_path + "/last_model.pt", flush=True)
        checkpoint = torch.load(model_path + "/last_model.pt")
        start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/' + 'df.csv')
        
    # scaler = GradScaler()
    
    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm = regConstScheduler(step, args, last_epoch_data_loss, last_epoch_physical_loss)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()
            
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            top_bc_train, bottom_bc_train = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
            left_bc_train, right_bc_train = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
            
            # with autocast():
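            # This variant zeroes out the boundary-condition inputs, so the model is trained
            # to predict the field from the structure alone.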
            top_zeros = torch.zeros(top_bc_train.shape).to(device)
            bottom_zeros = torch.zeros(bottom_bc_train.shape).to(device)
            left_zeros = torch.zeros(left_bc_train.shape).to(device)
            right_zeros = torch.zeros(right_bc_train.shape).to(device)
            logits = model(x_batch_train, top_zeros, bottom_zeros, left_zeros, right_zeros).reshape(y_batch_train.shape)
            #calculate the loss using the ground truth
            loss = model.loss_fn(logits.contiguous().view(args.batch_size,-1), y_batch_train.contiguous().view(args.batch_size,-1))
            # logits = logits[:,:,1:-1, :]
            # # print("loss: ", loss, flush=True)
            
            # # Calculate physical residue
            # pattern = (x_batch_train*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
            # pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, pattern.shape[-1]], dtype = torch.float32, device = device)*n_sub**2, pattern), dim=2);
            # #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
            # fields = torch.cat((y_batch_train[:, :, 0:1, :], logits, y_batch_train[:, :, -1:, :]), dim=2);
            # FD_Hy = H_to_H(-fields[:, 0]*means[0], -fields[:, 1]*means[1], dL, omega, pattern);
            # #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0])/reg_norm;
            # #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1])/reg_norm;
            # phys_regR = model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0])*reg_norm;
            # phys_regI = model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1])*reg_norm;
            # loss += phys_regR + phys_regI;
            
            # scaler.scale(loss).backward()
            # scaler.step(model.optimizer)
            # scaler.update()

            loss.backward()
            model.optimizer.step()
           
        # Save the weights at the end of each epoch
        checkpoint = {
            'epoch': step,
            'model': model,
            'optimizer': model.optimizer,
            'lr_scheduler': model.lr_scheduler
        }
        torch.save(checkpoint, model_path + "/last_model.pt")


        # evaluation
        train_loss = 0
        train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            top_bc_train, bottom_bc_train = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
            left_bc_train, right_bc_train = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
            
            with torch.no_grad():
                top_zeros = torch.zeros(top_bc_train.shape).to(device)
                bottom_zeros = torch.zeros(bottom_bc_train.shape).to(device)
                left_zeros = torch.zeros(left_bc_train.shape).to(device)
                right_zeros = torch.zeros(right_bc_train.shape).to(device)
                
                logits = model(x_batch_train, top_zeros, bottom_zeros, left_zeros, right_zeros).reshape(y_batch_train.shape)
            
                loss = model.loss_fn(logits.contiguous().view(args.batch_size,-1), y_batch_train.contiguous().view(args.batch_size,-1))
                logits = logits[:, :, 1:-1, :]
                
                # Calculate physical residue
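                # H_to_H (a project-specific finite-difference operator) pushes the rescaled
                # predicted field through the discretized wave equation; its mismatch with the
                # prediction is used as a physics-consistency regularizer weighted by reg_norm.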
                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, pattern.shape[-1]], dtype=torch.float32, device=device) * n_sub**2, pattern), dim=2)
                # fields = logits  # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat((y_batch_train[:, :, 0:1, :], logits, y_batch_train[:, :, -1:, :]), dim=2)
                FD_Hy = H_to_H(-fields[:, 0] * means[0], -fields[:, 1] * means[1], dL, omega, pattern)
                # phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])/reg_norm
                # phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])/reg_norm
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0], fields[:, 0, 1:-1, :]) * reg_norm
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1], fields[:, 1, 1:-1, :]) * reg_norm

                # loss = loss + phys_reg1 + phys_reg2 + phys_reg3
                train_loss += loss
                train_phys_reg += 0.5 * (phys_regR + phys_regI)
                
        train_loss /= len(train_loader)
        train_phys_reg /= len(train_loader)

        test_loss = 0
        test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            top_bc_test, bottom_bc_test = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
            left_bc_test, right_bc_test = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
             
            with torch.no_grad():
                top_zeros = torch.zeros(top_bc_test.shape).to(device)
                bottom_zeros = torch.zeros(bottom_bc_test.shape).to(device)
                left_zeros = torch.zeros(left_bc_test.shape).to(device)
                right_zeros = torch.zeros(right_bc_test.shape).to(device)
                
                logits = model(x_batch_test, top_zeros, bottom_zeros, left_zeros, right_zeros).reshape(y_batch_test.shape)
                loss = model.loss_fn(logits.contiguous().view(args.batch_size,-1), y_batch_test.contiguous().view(args.batch_size,-1))
                logits = logits[:, :, 1:-1, :]
                # Calculate physical residue
                pattern = (x_batch_test * (n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, pattern.shape[-1]], dtype=torch.float32, device=device) * n_sub**2, pattern), dim=2)
                # fields = logits  # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat((y_batch_test[:, :, 0:1, :], logits, y_batch_test[:, :, -1:, :]), dim=2)
                FD_Hy = H_to_H(-fields[:, 0] * means[0], -fields[:, 1] * means[1], dL, omega, pattern)
                # phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])/reg_norm
                # phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])/reg_norm
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0], fields[:, 0, 1:-1, :])
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1], fields[:, 1, 1:-1, :])

                test_loss += loss
                test_phys_reg += 0.5 * (phys_regR + phys_regI)
                
        test_loss /= len(test_loader)
        test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        last_epoch_physical_loss = test_phys_reg.detach().clone()
        test_phys_reg *= reg_norm
        
        print('train loss: %.5f, train phys reg: %.5f, test loss: %.5f, test phys reg: %.5f, last_physical_loss: %.5f'
              % (train_loss, train_phys_reg, test_loss, test_phys_reg, last_epoch_physical_loss),
              flush=True)
    
        model.lr_scheduler.step()

        # DataFrame.append was removed in pandas 2.0; build the row and concatenate instead.
        new_row = pd.DataFrame([{'epoch': step + 1, 'lr': str(model.lr_scheduler.get_last_lr()),
                                 'train_loss': train_loss.item(),
                                 'train_phys_reg': train_phys_reg.item(),
                                 'test_loss': test_loss.item(),
                                 'test_phys_reg': test_phys_reg.item()}])
        df = pd.concat([df, new_row], ignore_index=True)

        df.to_csv(model_path + '/'+'df.csv',index=False)

        if test_loss < best_loss:
            best_loss = test_loss
            checkpoint = {
                'epoch': step,
                'model': model,
                'optimizer': model.optimizer,
                'lr_scheduler': model.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/best_model.pt")

# Example 4

def main(data_folder, N, batch_size):
    device = torch.device('cuda')

    ds = SimulationDataset(data_folder, total_sample_number=None)
    
    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    data_loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    

    top_bc = np.load(data_folder+"/"+"cropped_top_bc.npy")
    bottom_bc = np.load(data_folder+"/"+"cropped_bottom_bc.npy")
    left_bc = np.load(data_folder+"/"+"cropped_left_bc.npy")
    right_bc = np.load(data_folder+"/"+"cropped_right_bc.npy")
    print("load1")
    # gt = np.load(data_folder+"/"+"cropped_Hys.npy")
    # print("load2")

    # MAE1 = 0
    # MAE2 = 0
    # eval_num = 10000
    preprocess_field = np.zeros((top_bc.shape[0], 32,32,2))
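    # For each batch: copy the four boundary slices onto the edges of an empty 32x32x2 field,
    # fill the interior with four_point_interp_inpaint (project-specific helper that uses the
    # global prop2), and collect the results, which are saved to preprocessed.npy at the end.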

    for i, sample_batched in enumerate(data_loader):
        start = i * batch_size
        end = min((i + 1) * batch_size, top_bc.shape[0])

        if (i + 1) % 10 == 0:
            print("i+1: ", i + 1, "num_samples: ", (i + 1) * batch_size)

        top_bc_train, bottom_bc_train = sample_batched['top_bc'].to(device), sample_batched['bottom_bc'].to(device)
        left_bc_train, right_bc_train = sample_batched['left_bc'].to(device), sample_batched['right_bc'].to(device)
        
        # hy = gt[index]
        # print("shapes: ", t.shape, b.shape, l.shape, r.shape)

        initial = torch.tensor(np.zeros((end-start,32,32,2))).to(device)
        initial[:, 0, :, :] = top_bc_train.permute(0, 2, 3, 1).squeeze()
        initial[:, -1, :, :] = bottom_bc_train.permute(0, 2, 3, 1).squeeze()
        initial[:, :, 0, :] = left_bc_train.permute(0, 2, 3, 1).squeeze()
        initial[:, :, -1, :] = right_bc_train.permute(0, 2, 3, 1).squeeze()
       
        # processed = average_inpaint(initial.copy())
        # processed1 = four_point_interp_inpaint(initial.copy(), prop1)
        processed2 = four_point_interp_inpaint(initial, prop2)

        preprocess_field[start:end] = processed2.cpu()

        # MAE1 += np.mean(np.abs(processed1-hy))/np.mean(np.abs(hy))
        # MAE2 += np.mean(np.abs(processed2-hy))/np.mean(np.abs(hy))
        # fig, axs = plt.subplots(3,2)
        # im = axs[0,0].imshow(initial[:,:,0])
        # plt.colorbar(im, ax = axs[0,0])
        # im = axs[0,1].imshow(initial[:,:,1])
        # plt.colorbar(im, ax = axs[0,1])
        # im = axs[1,0].imshow(processed[:,:,0])
        # plt.colorbar(im, ax = axs[1,0])
        # im = axs[1,1].imshow(processed[:,:,1])
        # plt.colorbar(im, ax = axs[1,1])
        # im = axs[2,0].imshow(hy[:,:,0])
        # plt.colorbar(im, ax = axs[2,0])
        # im = axs[2,1].imshow(hy[:,:,1])
        # plt.colorbar(im, ax = axs[2,1])
        # fig.savefig("test"+ str(i) + ".png", dpi=300)

    # print("mae1, mae2: ", MAE1/eval_num, MAE2/eval_num)
    np.save(data_folder + "/" + "preprocessed.npy", preprocess_field)