Example #1
# Imports assumed by the examples in this section; SimulationDataset, UNet,
# FNO2d, regConstScheduler, H_to_H, Hz_to_Ex, Hz_to_Ey, E_to_E and the physical
# constants (n_Si, n_air, n_sub, dL, dx, omega, Nz, ...) come from the
# surrounding project and are not shown here.
import math
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split


def updateDataloader(args, dataSizer, val_error_list):

    train_nums = dataSizer.updatePools(val_error_list)

    ds = SimulationDataset(args.data_folder, train_nums,
                           total_sample_number = args.total_sample_number)

    torch.manual_seed(42)
    train_ds, test_ds = random_split(ds, [int(0.95*len(ds)), len(ds) - int(0.95*len(ds))])

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)

    train_mean = 0
    test_mean = 0

    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    return train_loader, test_loader, train_mean, test_mean, train_nums
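
A minimal driver sketch for updateDataloader (hypothetical: the DataSizer constructor and its arguments are assumptions; only the call signature above comes from the original code):

# Hypothetical usage: rebuild the loaders between epochs as the
# per-condition validation errors change.
dataSizer = DataSizer(initial_pool_sizes=[2000, 2000, 2000, 2000])
val_error_list = [1.0, 1.0, 1.0, 1.0]
for epoch in range(args.epoch):
    train_loader, test_loader, train_mean, test_mean, train_nums = \
        updateDataloader(args, dataSizer, val_error_list)
    # ... train for one epoch, then refresh val_error_list ...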
Example #2
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    ds = SimulationDataset(args.data_folder,
                           total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    train_3k = np.stack(
        [train_ds[i]['field'] for i in range(min(3000, len(train_ds)))])
    test_3k = np.stack(
        [test_ds[i]['field'] for i in range(min(3000, len(test_ds)))])
    print("train_3k.shape: ", train_3k.shape)
    print("test_3k.shape: ", test_3k.shape)
    np.save(
        "/scratch/groups/jonfan/UNet/data_for_model_evaluation/30k_forward_train_3k.npy",
        train_3k)
    np.save(
        "/scratch/groups/jonfan/UNet/data_for_model_evaluation/30k_forward_test_3k.npy",
        test_3k)
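
A hedged sketch of how args might be constructed for the export script above; only --data_folder and --total_sample_number are read by it, and the defaults are assumptions:

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_folder", type=str, required=True)
    parser.add_argument("--total_sample_number", type=int, default=None)
    main(parser.parse_args())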
Example #3
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    elif args.arch == "Fourier":
        model = FNO2d(args).to(device)
    else:
        raise NotImplementedError(f"architecture {args.arch} hasn't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    print('Total trainable tensors:', num, flush=True)

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
    os.makedirs(model_path, exist_ok=True)
    
    ds = SimulationDataset(args.data_folder, total_sample_number = args.total_sample_number)
    means = [1e-3, 1e-3]  # hard-coded stand-in for ds.means: [Hy_meanR, Hy_meanI, Ex_meanR, Ex_meanI, Ez_meanR, Ez_meanI]
    print("means: ", means)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(ds, [int(0.9*len(ds)), len(ds) - int(0.9*len(ds))])
    
    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    
    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print("total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f" % (len(train_ds), len(test_ds), train_mean, test_mean), flush=True)
    
    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()
    

    df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch=0
    if (args.continue_train==1):
        print("Restoring weights from ", model_path+"/last_model.pt", flush=True)
        checkpoint = torch.load(model_path+"/last_model.pt")
        start_epoch=checkpoint['epoch']
        model.load_state_dict(checkpoint['model'].state_dict())
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        model.to(device)
        df = pd.read_csv(model_path + '/'+'df.csv')
    else:
        with open(model_path + '/' + 'config.txt', 'w') as f:
            f.write('\n'.join(sys.argv[1:]))
            f.write('\n' + str(model))
            f.write(f'\nmodel Total trainable tensors: {num}')
        
    # scaler = GradScaler()
    
    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm = regConstScheduler(step, args, last_epoch_data_loss, last_epoch_physical_loss)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()
            
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            # with autocast():
            logits = model(x_batch_train).reshape(y_batch_train.shape)
            #calculate the loss using the ground truth
            # flatten by the actual batch dimension: the last batch can be
            # smaller than args.batch_size, which would break a fixed-size view
            loss = model.loss_fn(logits.contiguous().view(logits.size(0), -1), y_batch_train.contiguous().view(y_batch_train.size(0), -1))
            logits = logits[:, :, 1:-1, :]
            # print("loss: ", loss, flush=True)
            
            # Calculate physical residue
            # pattern = (x_batch_train*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
            # pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, 256], dtype = torch.float32, device = device)*n_sub**2, pattern), dim=2);
            # #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
            # fields = torch.cat((y_batch_train[:, :, 0:1, :], logits, y_batch_train[:, :, -1:, :]), dim=2);
            # FD_Hy = H_to_H(-fields[:, 0]*means[0], -fields[:, 1]*means[1], dL, omega, pattern);
            # #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0])/reg_norm;
            # #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1])/reg_norm;
            # phys_regR = model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0])*reg_norm;
            # phys_regI = model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1])*reg_norm;
            # loss += phys_regR + phys_regI;
            
            # scaler.scale(loss).backward()
            # scaler.step(model.optimizer)
            # scaler.update()

            loss.backward()
            model.optimizer.step()
           
        #Save the weights at the end of each epoch
        checkpoint = {
                        'epoch': step,
                        'model': model,
                        'optimizer': model.optimizer,
                        'lr_scheduler': model.lr_scheduler
                     }
        torch.save(checkpoint, model_path+"/last_model.pt")


        # evaluation
        train_loss = 0
        train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            with torch.no_grad():
                logits = model(x_batch_train).reshape(y_batch_train.shape)
                loss = model.loss_fn(logits.contiguous().view(logits.size(0), -1), y_batch_train.contiguous().view(y_batch_train.size(0), -1))
                logits = logits[:, :, 1:-1, :]
                
                # Calculate physical residue
                pattern = (x_batch_train*(n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, 256], dtype=torch.float32, device=device)*n_sub**2, pattern), dim=2)
                #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat((y_batch_train[:, :, 0:1, :], logits, y_batch_train[:, :, -1:, :]), dim=2)
                FD_Hy = H_to_H(-fields[:, 0]*means[0], -fields[:, 1]*means[1], dL, omega, pattern)
                #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])/reg_norm;
                #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])/reg_norm;
                phys_regR = model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])*reg_norm
                phys_regI = model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])*reg_norm

                #loss = loss + phys_reg1 + phys_reg2 + phys_reg3;
                train_loss += loss
                train_phys_reg += 0.5*(phys_regR + phys_regI)
                
        train_loss /= len(train_loader)
        train_phys_reg /= len(train_loader)

        test_loss = 0
        test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            with torch.no_grad():
                logits = model(x_batch_test).reshape(y_batch_test.shape)
                loss = model.loss_fn(logits.contiguous().view(logits.size(0), -1), y_batch_test.contiguous().view(y_batch_test.size(0), -1))
                logits = logits[:, :, 1:-1, :]
                # Calculate physical residue
                pattern = (x_batch_test*(n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, 256], dtype=torch.float32, device=device)*n_sub**2, pattern), dim=2)
                #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat((y_batch_test[:, :, 0:1, :], logits, y_batch_test[:, :, -1:, :]), dim=2)
                FD_Hy = H_to_H(-fields[:, 0]*means[0], -fields[:, 1]*means[1], dL, omega, pattern)
                #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])/reg_norm;
                #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])/reg_norm;
                phys_regR = model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])
                phys_regI = model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])

                test_loss += loss
                test_phys_reg += 0.5*(phys_regR + phys_regI)
                
        test_loss /= len(test_loader)
        test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        last_epoch_physical_loss = test_phys_reg.detach().clone()
        test_phys_reg *= reg_norm
        
        print('train loss: %.5f, train phys reg: %.5f, test loss: %.5f, test phys reg: %.5f, last_physical_loss: %.5f' % (train_loss, train_phys_reg, test_loss, test_phys_reg, last_epoch_physical_loss), flush=True)
    
        model.lr_scheduler.step()

        # pandas removed DataFrame.append in 2.0, so append rows via concat
        row = {'epoch': step+1, 'lr': str(model.lr_scheduler.get_last_lr()),
               'train_loss': train_loss.item(),
               'train_phys_reg': train_phys_reg.item()/reg_norm if reg_norm != 0 else train_phys_reg.item(),
               'test_loss': test_loss.item(),
               'test_phys_reg': test_phys_reg.item()/reg_norm if reg_norm != 0 else test_phys_reg.item()}
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)

        df.to_csv(model_path + '/'+'df.csv',index=False)

        if(test_loss<best_loss):
            best_loss = test_loss
            checkpoint = {
                            'epoch': step,
                            'model': model,
                            'optimizer': model.optimizer,
                            'lr_scheduler': model.lr_scheduler
                         }
            torch.save(checkpoint, model_path+"/best_model.pt")
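
The training loop above calls a regConstScheduler helper that is not shown. A minimal sketch of what it could look like, assuming the intent is to ramp in the physics term and keep it at a fixed fraction of the data loss (the signature matches Examples #3 and #4; the warm-up length and target ratio are assumptions):

def regConstScheduler(step, args, last_epoch_data_loss, last_epoch_physical_loss,
                      warmup_epochs=10, target_ratio=0.1):
    # Hypothetical reconstruction: zero during warm-up, then a constant that
    # keeps the physics residual at roughly target_ratio of the data loss.
    if step < warmup_epochs or float(last_epoch_physical_loss) == 0.0:
        return 0.0
    return target_ratio * float(last_epoch_data_loss) / float(last_epoch_physical_loss)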
Example #4
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    else:
        raise NotImplementedError("architectures other than UNet haven't been added!")


    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-7, amsgrad=True, weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    
    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL=5
    TEST_PRINT_INTERVAL=SUMMARY_INTERVAL*5
    ITER_SAVE_INTERVAL=300
    EPOCH_SAVE_INTERVAL=5

    model_path = (args.model_saving_path + args.model_name
                  + "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr)
                  + "_data_" + str(args.data_folder.split('/')[-1])
                  + "_HIDDEN_DIM_" + str(args.HIDDEN_DIM)
                  + "_regNorm_" + str(args.reg_norm))
    os.makedirs(model_path, exist_ok=True)
    
    ds = SimulationDataset(args.data_folder, total_sample_number = args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(ds, [int(0.9*len(ds)), len(ds) - int(0.9*len(ds))])

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print("total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f" % (len(train_ds), len(test_ds), train_mean, test_mean), flush=True)
    

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()
    

    df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch=0
    if (args.continue_train):
        print("Restoring weights from ", model_path+"/last_model.pt", flush=True)
        checkpoint = torch.load(model_path+"/last_model.pt")
        start_epoch=checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/'+'df.csv')
        
    scaler = GradScaler()

    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm = regConstScheduler(step, args, last_epoch_data_loss, last_epoch_physical_loss)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()
            
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            with autocast():
                logits = model(x_batch_train, bn_training=True)
                #calculate the loss using the ground truth
                loss = model.loss_fn(logits, y_batch_train)
                # print("loss: ", loss, flush=True)

                pattern = (x_batch_train*(n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                fields = logits  # predicted fields
                FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)

                phys_reg = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])*reg_norm
                loss = loss + phys_reg            

            # run backward and the optimizer step outside the autocast block,
            # as recommended for mixed-precision training
            scaler.scale(loss).backward()
            scaler.step(model.optimizer)
            scaler.update()
            # loss.backward()
            # model.optimizer.step()

        #Save the weights at the end of each epoch
        checkpoint = {
                        'epoch': step,
                        'model': model,
                        'optimizer': model.optimizer,
                        'lr_scheduler': model.lr_scheduler
                     }
        torch.save(checkpoint, model_path+"/last_model.pt")


        # evaluation
        train_loss = 0
        train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            with torch.no_grad():
                logits = model(x_batch_train, bn_training=False)
                loss = model.loss_fn(logits, y_batch_train)
                
                # Calculate physical residue
                pattern = (x_batch_train*(n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                fields = logits
                FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                phys_reg = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])*reg_norm

                train_loss += loss
                train_phys_reg += phys_reg

        train_loss /= len(train_loader)*train_mean
        train_phys_reg /= len(train_loader)

        test_loss = 0
        test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            with torch.no_grad():
                logits = model(x_batch_test, bn_training=False)
                loss = model.loss_fn(logits, y_batch_test)

                # Calculate physical residue
                pattern = (x_batch_test*(n_Si - n_air) + n_air)**2  # rescale the 0/1 pattern into dielectric constant
                fields = logits
                FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                phys_reg = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])

                test_loss += loss
                test_phys_reg += phys_reg
        test_loss /= len(test_loader)*test_mean
        test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        last_epoch_physical_loss = test_phys_reg.detach().clone()

        test_phys_reg *= reg_norm
        

        print('train loss: %.5f, test loss: %.5f, train phys reg: %.5f, test phys reg: %.5f, last_physical_loss: %.5f' % (train_loss, test_loss, train_phys_reg, test_phys_reg, last_epoch_physical_loss), flush=True)

            
        model.lr_scheduler.step()

        # pandas removed DataFrame.append in 2.0, so append rows via concat
        row = {'epoch': step+1, 'lr': str(model.lr_scheduler.get_last_lr()),
               'train_loss': train_loss.item(),
               'train_phys_reg': train_phys_reg.item(),
               'test_loss': test_loss.item(),
               'test_phys_reg': test_phys_reg.item()}
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)

        df.to_csv(model_path + '/'+'df.csv',index=False)

        if(test_loss<best_loss):
            best_loss = test_loss
            checkpoint = {
                            'epoch': step,
                            'model': model,
                            'optimizer': model.optimizer,
                            'lr_scheduler': model.lr_scheduler
                         }
            torch.save(checkpoint, model_path+"/best_model.pt")
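
Example #4 relies on autocast/GradScaler from torch.cuda.amp. For reference, the canonical mixed-precision step looks like this (a generic sketch, not code from this project): the forward pass and loss run under autocast, while the scaled backward pass and optimizer step run outside it.

scaler = GradScaler()
for x, y in loader:                    # hypothetical loader
    optimizer.zero_grad()
    with autocast():                   # forward + loss in mixed precision
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()      # backward outside the autocast block
    scaler.step(optimizer)             # unscales gradients, then steps
    scaler.update()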
Example #5
def main(args):

    device = torch.device('cuda')
    ################## gaussian filter #######################
    # Set these to whatever you want for your gaussian filter
    kernel_size = 9
    pad = 4
    sigma = 4

    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    x_cord = torch.arange(kernel_size)
    x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1)

    mean = (kernel_size - 1) / 2.
    variance = sigma**2.

    # Calculate the 2-dimensional gaussian kernel which is
    # the product of two gaussian distributions for two different
    # variables (in this case called x and y)
    gaussian_kernel = (1./(2.*math.pi*variance)) *\
                      torch.exp(
                          -torch.sum((xy_grid - mean)**2., dim=-1) /\
                          (2*variance)
                      )
    # Make sure sum of values in gaussian kernel equals 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to 2d depthwise convolutional weight
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)

    channels = 2
    gaussian_kernel2 = gaussian_kernel.repeat(channels, 1, 1, 1).to(device)

    gaussian_filter2 = nn.Conv2d(in_channels=channels,
                                 out_channels=channels,
                                 kernel_size=kernel_size,
                                 groups=channels,
                                 bias=False,
                                 padding=(pad, pad),
                                 padding_mode="replicate")

    gaussian_filter2.weight.data = gaussian_kernel2
    gaussian_filter2.weight.requires_grad = False

    channels = 1
    gaussian_kernel1 = gaussian_kernel.repeat(channels, 1, 1, 1).to(device)

    gaussian_filter1 = nn.Conv2d(in_channels=channels,
                                 out_channels=channels,
                                 kernel_size=kernel_size,
                                 groups=channels,
                                 bias=False,
                                 padding=(pad, pad),
                                 padding_mode="replicate")

    gaussian_filter1.weight.data = gaussian_kernel1
    gaussian_filter1.weight.requires_grad = False
    ##########################################################
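    # Illustrative sanity check (not in the original code): the kernel is
    # normalized and the padding mode is "replicate", so blurring a constant
    # image should return the same constant up to float error:
    #   ones = torch.ones(1, 1, 64, 64, device=device)
    #   assert torch.allclose(gaussian_filter1(ones), ones, atol=1e-5)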

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    else:
        raise NotImplementedError("architectures other than UNet haven't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 eps=1e-7,
                                 amsgrad=True,
                                 weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr) + "_data_" + str(
            args.data_folder.split('/')[-1]) + "_HIDDEN_DIM_" + str(
                args.HIDDEN_DIM) + "_regNorm_" + str(args.reg_norm)
    os.makedirs(model_path, exist_ok=True)

    ds = SimulationDataset(args.data_folder,
                           total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print(
        "total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f"
        % (len(train_ds), len(test_ds), train_mean, test_mean),
        flush=True)

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()

    df = pd.DataFrame(columns=[
        'epoch', 'train_loss', 'train_phys_reg_H', 'train_phys_reg_Ex',
        'train_phys_reg_Ez', 'test_loss', 'test_phys_reg_H',
        'test_phys_reg_Ex', 'test_phys_reg_Ez'
    ])

    start_epoch = 0
    if (args.continue_train):
        print("Restoring weights from ",
              model_path + "/last_model.pt",
              flush=True)
        checkpoint = torch.load(model_path + "/last_model.pt")
        start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/' + 'df.csv')

    scaler = GradScaler()

    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss_H = 1.0
    last_epoch_physical_loss_Ex = 1.0
    last_epoch_physical_loss_Ez = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm_H, reg_norm_Ex, reg_norm_Ez = regConstScheduler(
            step, args, last_epoch_data_loss, last_epoch_physical_loss_H,
            last_epoch_physical_loss_Ex, last_epoch_physical_loss_Ez)
        #print("reg_norm_H: ", reg_norm_H, "reg_norm_Ex: ", reg_norm_Ex, "reg_norm_Ez: ", reg_norm_Ez)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()

            x_batch_train, y_batch_train = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)
            with autocast():
                logits = model(x_batch_train, bn_training=True)
                #calculate the loss using the ground truth
                loss = model.loss_fn(logits, y_batch_train)
                # print("loss: ", loss, flush=True)

                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat(
                    (torch.ones([pattern.shape[0], 1, 1, 256],
                                dtype=torch.float32,
                                device=device) * n_sub**2, pattern),
                    dim=2)

                filtered_logits = gaussian_filter2(logits)
                filtered_pattern = gaussian_filter1(pattern)

                # FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                FD_H = H_to_H(-filtered_logits[:, 0], -filtered_logits[:, 1],
                              dL, omega, filtered_pattern)

                # FD_Ex = Hz_to_Ex(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # FD_Ez = Hz_to_Ey(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                FD_Ex = Hz_to_Ex(-filtered_logits[:, 0, :, :],
                                 -filtered_logits[:, 1, :, :], dx * 1e-9,
                                 omega, filtered_pattern)
                FD_Ez = Hz_to_Ey(-filtered_logits[:, 0, :, :],
                                 -filtered_logits[:, 1, :, :], dx * 1e-9,
                                 omega, filtered_pattern)

                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                filtered_pattern = gaussian_filter1(pattern)
                FD2_Ex, FD2_Ez = E_to_E(FD_Ez[:, 0, 0:-1, :], FD_Ez[:, 1, 0:-1, :],
                                        FD_Ex[:, 0, :, :], FD_Ex[:, 1, :, :],
                                        dx * 1e-9, omega,
                                        filtered_pattern[:, :, 0:-1, :])

                phys_reg_H = model.loss_fn(
                    FD_H, logits[:, :, 1:(Nz - 1), :]) * reg_norm_H
                # phys_reg_Ex = model.loss_fn(FD_Ex[:, :, 1:-1, :],FD2_Ex)*reg_norm_Ex
                # phys_reg_Ez = model.loss_fn(FD_Ez[:, :, 1:-1, :],FD2_Ez)*reg_norm_Ez

                # loss = loss + phys_reg_H + phys_reg_Ex + phys_reg_Ez
                loss = loss + phys_reg_H

            # run backward and the optimizer step outside the autocast block,
            # as recommended for mixed-precision training
            scaler.scale(loss).backward()
            scaler.step(model.optimizer)
            scaler.update()
            # loss.backward()
            # model.optimizer.step()

        #Save the weights at the end of each epoch
        checkpoint = {
            'epoch': step,
            'model': model,
            'optimizer': model.optimizer,
            'lr_scheduler': model.lr_scheduler
        }
        torch.save(checkpoint, model_path + "/last_model.pt")

        # evaluation
        train_loss = 0
        train_phys_reg_H = 0
        train_phys_reg_Ex = 0
        train_phys_reg_Ez = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)

            with torch.no_grad():
                logits = model(x_batch_train, bn_training=False)
                loss = model.loss_fn(logits, y_batch_train)

                # # Calculate physical residue
                # pattern = (x_batch_train*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
                # fields = logits;
                # FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                # pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, 256], dtype = torch.float32, device = device)*n_sub**2, pattern), dim=2);
                # FD_Ex = Hz_to_Ex(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # FD_Ez = Hz_to_Ey(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # pattern = (x_batch_train*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
                # FD2_Ex, FD2_Ez = E_to_E(FD_Ez[:,0,0:-1,:], FD_Ez[:,1,0:-1,:], FD_Ex[:,0,:,:], FD_Ex[:,1,:,:], dx*1e-9, omega, pattern[:,:,0:-1,:])

                # phys_reg_H = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])*reg_norm_H
                # phys_reg_Ex = model.loss_fn(FD_Ex[:, :, 1:-1, :],FD2_Ex)*reg_norm_Ex
                # phys_reg_Ez = model.loss_fn(FD_Ez[:, :, 1:-1, :],FD2_Ez)*reg_norm_Ez

                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat(
                    (torch.ones([pattern.shape[0], 1, 1, 256],
                                dtype=torch.float32,
                                device=device) * n_sub**2, pattern),
                    dim=2)

                filtered_logits = gaussian_filter2(logits)
                filtered_pattern = gaussian_filter1(pattern)

                # FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                FD_H = H_to_H(-filtered_logits[:, 0], -filtered_logits[:, 1],
                              dL, omega, filtered_pattern)

                # FD_Ex = Hz_to_Ex(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # FD_Ez = Hz_to_Ey(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                FD_Ex = Hz_to_Ex(-filtered_logits[:, 0, :, :],
                                 -filtered_logits[:, 1, :, :], dx * 1e-9,
                                 omega, filtered_pattern)
                FD_Ez = Hz_to_Ey(-filtered_logits[:, 0, :, :],
                                 -filtered_logits[:, 1, :, :], dx * 1e-9,
                                 omega, filtered_pattern)

                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                filtered_pattern = gaussian_filter1(pattern)
                FD2_Ex, FD2_Ez = E_to_E(FD_Ez[:, 0, 0:-1, :], FD_Ez[:, 1, 0:-1, :],
                                        FD_Ex[:, 0, :, :], FD_Ex[:, 1, :, :],
                                        dx * 1e-9, omega,
                                        filtered_pattern[:, :, 0:-1, :])

                phys_reg_H = model.loss_fn(
                    FD_H, logits[:, :, 1:(Nz - 1), :]) * reg_norm_H
                phys_reg_Ex = model.loss_fn(FD_Ex[:, :, 1:-1, :],
                                            FD2_Ex) * reg_norm_Ex
                phys_reg_Ez = model.loss_fn(FD_Ez[:, :, 1:-1, :],
                                            FD2_Ez) * reg_norm_Ez

                train_loss += loss
                train_phys_reg_H += phys_reg_H
                train_phys_reg_Ex += phys_reg_Ex
                train_phys_reg_Ez += phys_reg_Ez

        train_loss /= len(train_loader) * train_mean
        train_phys_reg_H /= len(train_loader)
        train_phys_reg_Ex /= len(train_loader)
        train_phys_reg_Ez /= len(train_loader)

        test_loss = 0
        test_phys_reg_H = 0
        test_phys_reg_Ex = 0
        test_phys_reg_Ez = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)

            with torch.no_grad():
                logits = model(x_batch_test, bn_training=False)
                loss = model.loss_fn(logits, y_batch_test)

                # # Calculate physical residue
                # pattern = (x_batch_test*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
                # fields = logits;
                # FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                # pattern = torch.cat((torch.ones([pattern.shape[0], 1, 1, 256], dtype = torch.float32, device = device)*n_sub**2, pattern), dim=2);
                # FD_Ex = Hz_to_Ex(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # FD_Ez = Hz_to_Ey(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # pattern = (x_batch_test*(n_Si - n_air) + n_air)**2; # rescale the 0/1 pattern into dielectric constant
                # FD2_Ex, FD2_Ez = E_to_E(FD_Ez[:,0,0:-1,:], FD_Ez[:,1,0:-1,:], FD_Ex[:,0,:,:], FD_Ex[:,1,:,:], dx*1e-9, omega, pattern[:,:,0:-1,:])

                # phys_reg_H = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])
                # phys_reg_Ex = model.loss_fn(FD_Ex[:, :, 1:-1, :],FD2_Ex)
                # phys_reg_Ez = model.loss_fn(FD_Ez[:, :, 1:-1, :],FD2_Ez)

                pattern = (x_batch_test * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat(
                    (torch.ones([pattern.shape[0], 1, 1, 256],
                                dtype=torch.float32,
                                device=device) * n_sub**2, pattern),
                    dim=2)

                filtered_logits = gaussian_filter2(logits)
                filtered_pattern = gaussian_filter1(pattern)

                # FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                FD_H = H_to_H(-filtered_logits[:, 0], -filtered_logits[:, 1],
                              dL, omega, filtered_pattern)

                # FD_Ex = Hz_to_Ex(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                # FD_Ez = Hz_to_Ey(-fields[:,0,:,:], -fields[:,1,:,:], dx*1e-9, omega, pattern);
                FD_Ex = Hz_to_Ex(-filtered_logits[:, 0, :, :],
                                 -filtered_logits[:, 1, :, :], dx * 1e-9,
                                 omega, filtered_pattern)
                FD_Ez = Hz_to_Ey(-filtered_logits[:, 0, :, :],
                                 -filtered_logits[:, 1, :, :], dx * 1e-9,
                                 omega, filtered_pattern)

                pattern = (x_batch_test * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                filtered_pattern = gaussian_filter1(pattern)
                FD2_Ex, FD2_Ez = E_to_E(FD_Ez[:, 0, 0:-1, :], FD_Ez[:, 1, 0:-1, :],
                                        FD_Ex[:, 0, :, :], FD_Ex[:, 1, :, :],
                                        dx * 1e-9, omega,
                                        filtered_pattern[:, :, 0:-1, :])

                phys_reg_H = model.loss_fn(FD_H, logits[:, :, 1:(Nz - 1), :])
                phys_reg_Ex = model.loss_fn(FD_Ex[:, :, 1:-1, :], FD2_Ex)
                phys_reg_Ez = model.loss_fn(FD_Ez[:, :, 1:-1, :], FD2_Ez)

                test_loss += loss
                test_phys_reg_H += phys_reg_H
                test_phys_reg_Ex += phys_reg_Ex
                test_phys_reg_Ez += phys_reg_Ez

        test_loss /= len(test_loader) * test_mean
        test_phys_reg_H /= len(test_loader)
        test_phys_reg_Ex /= len(test_loader)
        test_phys_reg_Ez /= len(test_loader)
        last_epoch_data_loss = test_loss
        last_epoch_physical_loss_H = test_phys_reg_H.detach().clone()
        last_epoch_physical_loss_Ex = test_phys_reg_Ex.detach().clone()
        last_epoch_physical_loss_Ez = test_phys_reg_Ez.detach().clone()

        test_phys_reg_H *= reg_norm_H
        test_phys_reg_Ex *= reg_norm_Ex
        test_phys_reg_Ez *= reg_norm_Ez

        print(
            'train loss: %.5f, test loss: %.5f, thisE_phys_reg_H: %.5f, thisE_phys_reg_Ex: %.5f, thisE_phys_reg_Ez: %.5f'
            % (train_loss, test_loss, last_epoch_physical_loss_H,
               last_epoch_physical_loss_Ex, last_epoch_physical_loss_Ez),
            flush=True)

        model.lr_scheduler.step()

        # pandas removed DataFrame.append in 2.0, so append rows via concat
        row = {
            'epoch': step + 1,
            'lr': str(model.lr_scheduler.get_last_lr()),
            'train_loss': train_loss.item(),
            'train_phys_reg_H': train_phys_reg_H.item(),
            'train_phys_reg_Ex': train_phys_reg_Ex.item(),
            'train_phys_reg_Ez': train_phys_reg_Ez.item(),
            'test_loss': test_loss.item(),
            'thisE_phys_reg_H': last_epoch_physical_loss_H.item(),
            'thisE_phys_reg_Ex': last_epoch_physical_loss_Ex.item(),
            'thisE_phys_reg_Ez': last_epoch_physical_loss_Ez.item()
        }
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)

        df.to_csv(model_path + '/' + 'df.csv', index=False)

        if (test_loss < best_loss):
            best_loss = test_loss
            checkpoint = {
                'epoch': step,
                'model': model,
                'optimizer': model.optimizer,
                'lr_scheduler': model.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/best_model.pt")
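
The hand-rolled depthwise Gaussian filter in Example #5 could likely be replaced with torchvision's built-in transform; a sketch, assuming torchvision is available (note its internal padding may differ from the "replicate" padding used above):

from torchvision.transforms import GaussianBlur

# kernel_size=9 and a fixed sigma=4 mirror the constants in Example #5;
# the transform blurs each channel of an (N, C, H, W) tensor independently.
blur = GaussianBlur(kernel_size=9, sigma=4.0)
filtered_logits = blur(logits)
filtered_pattern = blur(pattern)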
Example #6
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    else:
        raise NotImplementedError("architectures other than UNet haven't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 eps=1e-7,
                                 amsgrad=True,
                                 weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    ds = SimulationDataset(args.data_folder,
                           total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print(
        "total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f"
        % (len(train_ds), len(test_ds), train_mean, test_mean),
        flush=True)

    start_epoch = 0

    reg_norm = 5
    # training
    for sample_batched in train_loader:

        x_batch_train, y_batch_train = sample_batched['structure'].to(
            device), sample_batched['field'].to(device)
        with autocast():
            logits = model(x_batch_train, bn_training=True)
            #calculate the loss using the ground truth
            loss = model.loss_fn(logits, y_batch_train)
            print("loss: ", loss, flush=True)

            pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
            # rescale the 0/1 pattern into dielectric constant
            fields = logits
            # predicted fields
            FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)

            phys_reg = model.loss_fn(FD_H, fields[:, :, 1:(Nz - 1), :]) / reg_norm
            print("phys_reg: ", phys_reg)
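
Every example calls model.loss_fn, which is attached to the model elsewhere in the project. A plausible minimal stand-in, assuming a plain MSE objective (an assumption; the original loss is not shown):

# Hypothetical: attach an MSE objective so model.loss_fn(pred, target)
# works the way the training loops above use it.
model.loss_fn = nn.MSELoss()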
Example #7
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    else:
        raise NotImplementedError("architectures other than UNet haven't been added!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 eps=1e-7,
                                 amsgrad=True,
                                 weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr)
    os.makedirs(model_path, exist_ok=True)

    ds = SimulationDataset(args.data_folder,
                           total_sample_number=args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    print("total training samples: %d, total test samples: %d" %
          (len(train_ds), len(test_ds)),
          flush=True)
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0)

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()

    df = pd.DataFrame(columns=[
        'epoch', 'train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'
    ])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch = 0
    if (args.continue_train):
        print("Restoring weights from ",
              model_path + "/last_model.pt",
              flush=True)
        checkpoint = torch.load(model_path + "/last_model.pt")
        start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/' + 'df.csv')

    best_loss = 1e4
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()

            x_batch_train, y_batch_train = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)
            logits = model(x_batch_train, bn_training=True)
            #calculate the loss using the ground truth
            loss = model.loss_fn(logits, y_batch_train)
            # print("loss: ", loss, flush=True)

            loss.backward()
            model.optimizer.step()
        #Save the weights at the end of each epoch
        checkpoint = {
            'epoch': step,
            'model': model,
            'optimizer': model.optimizer,
            'lr_scheduler': model.lr_scheduler
        }
        torch.save(checkpoint, model_path + "/last_model.pt")

        # evaluation
        train_loss = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)

            with torch.no_grad():
                logits = model(x_batch_train, bn_training=False)
                loss = model.loss_fn(logits, y_batch_train)
                train_loss += loss
        train_loss /= len(train_loader)

        test_loss = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)

            with torch.no_grad():
                logits = model(x_batch_test, bn_training=False)
                loss = model.loss_fn(logits, y_batch_test)
                test_loss += loss
        test_loss /= len(test_loader)

        print('train loss: %.5f, test loss: %.5f' % (train_loss, test_loss),
              flush=True)

        model.lr_scheduler.step()

        # pandas removed DataFrame.append in 2.0, so append rows via concat;
        # .item() stores scalars (not tensor reprs) in the CSV, and the key
        # 'train_phys_reg' matches the columns declared above
        row = {
            'epoch': step + 1,
            'lr': str(model.lr_scheduler.get_last_lr()),
            'train_loss': train_loss.item(),
            'train_phys_reg': 0,
            'test_loss': test_loss.item(),
            'test_phys_reg': 0
        }
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)

        df.to_csv(model_path + '/' + 'df.csv', index=False)

        if (test_loss < best_loss):
            best_loss = test_loss
            checkpoint = {
                'epoch': step,
                'model': model,
                'optimizer': model.optimizer,
                'lr_scheduler': model.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/best_model.pt")
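
The checkpoints above pickle whole model/optimizer objects, which ties the saved files to the exact class definitions at save time. A more portable pattern (a generic sketch, not from the original repo) saves and restores state_dicts instead:

# Saving
torch.save({
    'epoch': step,
    'model_state': model.state_dict(),
    'optimizer_state': model.optimizer.state_dict(),
    'scheduler_state': model.lr_scheduler.state_dict(),
}, model_path + "/last_model.pt")

# Restoring
checkpoint = torch.load(model_path + "/last_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state'])
model.optimizer.load_state_dict(checkpoint['optimizer_state'])
model.lr_scheduler.load_state_dict(checkpoint['scheduler_state'])
start_epoch = checkpoint['epoch'] + 1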