Example #1
File: train.py Project: devzhk/PINO
def eval_error(self):
    # Evaluate relative errors of the predicted velocity components (u, v, w)
    # at the collocation points, without tracking gradients.
    lploss = LpLoss()
    self.model.eval()
    with torch.no_grad():
        pred_uvwp = self.model(self.col_xyzt)
        u_error = lploss(pred_uvwp[:, 0], self.col_uvwp[:, 0])
        v_error = lploss(pred_uvwp[:, 1], self.col_uvwp[:, 1])
        w_error = lploss(pred_uvwp[:, 2], self.col_uvwp[:, 2])
    return u_error.item(), v_error.item(), w_error.item()
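The examples here repeatedly use an LpLoss helper that is not shown. The following is a minimal sketch of a relative Lp loss of the kind commonly used in FNO/PINO-style code; the constructor arguments and reduction behaviour are assumptions, and the project's actual class may differ.

import torch

class LpLoss:
    # Sketch of a relative Lp loss (per-sample relative norm of the error).
    # The project's LpLoss may use different defaults and reductions.
    def __init__(self, p=2, size_average=True):
        self.p = p
        self.size_average = size_average

    def __call__(self, pred, target):
        batch = pred.shape[0]
        diff = torch.norm(pred.reshape(batch, -1) - target.reshape(batch, -1),
                          self.p, dim=1)
        ref = torch.norm(target.reshape(batch, -1), self.p, dim=1)
        rel = diff / ref
        return rel.mean() if self.size_average else rel.sum()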
Example #2
layers = [
    width * 2 // 4, width * 3 // 4, width * 3 // 4, width * 4 // 4,
    width * 4 // 4
]
modes = [modes * (5 - i) // 4 for i in range(4)]

model = FNN2d(modes1=modes, modes2=modes, width=width,
              layers=layers).to(device)
num_param = count_params(model)
print('Number of model parameters', num_param)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
milestones = [i * 1000 for i in range(1, 5)]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=milestones,
                                                 gamma=gamma)

myloss = LpLoss(size_average=True)
pbar = tqdm(range(epochs), dynamic_ncols=True, smoothing=0.01)

for ep in pbar:
    model.train()
    t1 = default_timer()
    train_pino = 0.0
    train_l2 = 0.0
    train_loss = 0.0
    test_l2 = 0.0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)

        loss = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
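This excerpt also calls count_params, which is not defined here. A minimal sketch under the usual meaning (total number of trainable parameters); the project's own helper may treat complex-valued spectral weights differently:

def count_params(model):
    # Sketch: sum the element counts of all trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)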
Example #3
def train_adam(model, dataset, device):
    alpha = 100
    beta = 100
    epoch_num = 3000
    dataloader = DataLoader(dataset,
                            batch_size=5000,
                            shuffle=True,
                            drop_last=True)

    model.train()
    criterion = LpLoss(size_average=True)
    mse = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=0.0005)
    milestones = [100, 500, 1500, 2000]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.9)
    bd_x, bd_y, bd_t, bd_vor, u_gt, v_gt = dataset.get_boundary()
    bd_x, bd_y, bd_t, bd_vor, u_gt, v_gt = bd_x.to(device), bd_y.to(device), bd_t.to(device), \
                                           bd_vor.to(device), u_gt.to(device), v_gt.to(device)
    pbar = tqdm(range(epoch_num), dynamic_ncols=True, smoothing=0.01)

    set_grad([bd_x, bd_y, bd_t])
    for e in pbar:
        total_train_loss = 0.0
        bc_error = 0.0
        ic_error = 0.0
        f_error = 0.0
        model.train()
        for x, y, t, vor, true_u, true_v in dataloader:
            optimizer.zero_grad()
            # initial condition
            u, v, _ = net_NS(bd_x, bd_y, bd_t, model)
            loss_ic = mse(u, u_gt.view(-1)) + mse(v, v_gt.view(-1))
            #  boundary condition
            loss_bc = boundary_loss(model, 100)

            # collocation points
            x, y, t, vor, true_u, true_v = x.to(device), y.to(device), t.to(device), \
                                           vor.to(device), true_u.to(device), true_v.to(device)
            set_grad([x, y, t])
            u, v, p = net_NS(x, y, t, model)
            # velu_loss = criterion(u, true_u)
            # velv_loss = criterion(v, true_v)
            res_x, res_y, evp3 = resf_NS(u, v, p, x, y, t, re=40)
            loss_f = mse(res_x, torch.zeros_like(res_x)) \
                     + mse(res_y, torch.zeros_like(res_y)) \
                     + mse(evp3, torch.zeros_like(evp3))

            total_loss = loss_f + loss_bc * alpha + loss_ic * beta
            total_loss.backward()
            optimizer.step()

            total_train_loss += total_loss.item()
            bc_error += loss_bc.item()
            ic_error += loss_ic.item()
            f_error += loss_f.item()
        total_train_loss /= len(dataloader)
        bc_error /= len(dataloader)
        ic_error /= len(dataloader)
        f_error /= len(dataloader)

        u_error = 0.0
        v_error = 0.0
        test_error = 0.0
        model.eval()
        for x, y, t, vor, true_u, true_v in dataloader:
            x, y, t, vor, true_u, true_v = x.to(device), y.to(device), t.to(device), \
                                           vor.to(device), true_u.to(device), true_v.to(device)
            set_grad([x, y, t])
            u, v, _ = net_NS(x, y, t, model)
            pred_vor = vel2vor(u, v, x, y)
            velu_loss = criterion(u, true_u)
            velv_loss = criterion(v, true_v)
            test_loss = criterion(pred_vor, vor)
            u_error += velu_loss.item()
            v_error += velv_loss.item()
            test_error += test_loss.item()

        u_error /= len(dataloader)
        v_error /= len(dataloader)
        test_error /= len(dataloader)
        pbar.set_description((
            f'Train f error: {f_error:.5f}; Train IC error: {ic_error:.5f}. '
            f'Train loss: {total_train_loss:.5f}; Test l2 error: {test_error:.5f}'
        ))
        if wandb and log:
            wandb.log({
                'Train f error': f_error,
                'Train IC error': ic_error,
                'Train BC error': bc_error,
                'Test L2 error': test_error,
                'Total loss': total_train_loss,
                'u error': u_error,
                'v error': v_error
            })
        scheduler.step()
    return model
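Example #3 relies on set_grad and resf_NS to build the PDE residual with automatic differentiation, and neither helper is shown. The sketch below illustrates the general autograd pattern for a 2D incompressible Navier-Stokes residual in velocity-pressure form; the helper names, the exact residual terms, and the absence of a forcing term are assumptions, so the project's resf_NS may differ.

import torch

def set_grad(tensors):
    # Mark coordinate tensors as requiring gradients so network outputs
    # can be differentiated with respect to them.
    for t in tensors:
        t.requires_grad_(True)

def _d(out, wrt):
    # First derivative d(out)/d(wrt) via autograd, keeping the graph so
    # higher-order derivatives can be taken.
    return torch.autograd.grad(out, wrt, grad_outputs=torch.ones_like(out),
                               create_graph=True)[0]

def ns_residual(u, v, p, x, y, t, re=40):
    # Momentum residuals plus the divergence (continuity) constraint.
    u_t, u_x, u_y = _d(u, t), _d(u, x), _d(u, y)
    v_t, v_x, v_y = _d(v, t), _d(v, x), _d(v, y)
    p_x, p_y = _d(p, x), _d(p, y)
    u_xx, u_yy = _d(u_x, x), _d(u_y, y)
    v_xx, v_yy = _d(v_x, x), _d(v_y, y)
    res_x = u_t + u * u_x + v * u_y + p_x - (u_xx + u_yy) / re
    res_y = v_t + u * v_x + v * v_y + p_y - (v_xx + v_yy) / re
    div = u_x + v_y
    return res_x, res_y, div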
Example #4
def train():
    batch_size = 20
    learning_rate = 0.001

    epochs = 2500
    step_size = 100
    gamma = 0.25

    modes = 12  # 20
    width = 32  # 64

    config_defaults = {
        'ntrain': 800,
        'nlabels': 10,
        'ntest': 200,
        'lr': learning_rate,
        'batch_size': batch_size,
        'modes': modes,
        'width': width
    }
    wandb.init(config=config_defaults, tags=['Epoch'])
    config = wandb.config
    print('config: ', config)

    ntrain = config.ntrain
    nlabels = config.nlabels
    ntest = config.ntest

    image_dir = 'figs/FDM-burgers'
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)
    ckpt_dir = 'Burgers-FDM'
    name = 'PINO_FDM_burgers_N' + \
           str(ntrain) + '_L' + str(nlabels) + '-' + str(ntest) + '.pt'

    train_loader = constructor.make_loader(n_sample=ntrain,
                                           batch_size=batch_size,
                                           train=True)
    test_loader = constructor.make_loader(n_sample=ntest,
                                          batch_size=batch_size,
                                          train=False)
    if config.nlabels > 0:
        supervised_loader = constructor.make_loader(n_sample=nlabels,
                                                    batch_size=batch_size,
                                                    start=ntrain,
                                                    train=True)
        supervised_loader = sample_data(loader=supervised_loader)
    else:
        supervised_loader = None
    layers = [
        width * 2 // 4, width * 3 // 4, width * 3 // 4, width * 4 // 4,
        width * 4 // 4
    ]
    modes = [modes * (5 - i) // 4 for i in range(4)]

    model = FNN2d(modes1=modes, modes2=modes, width=width,
                  layers=layers).to(device)
    num_param = count_params(model)
    print('Number of model parameters', num_param)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[500, 1000, 2000], gamma=gamma)

    myloss = LpLoss(size_average=True)
    pbar = tqdm(range(epochs), dynamic_ncols=True, smoothing=0.01)

    for ep in pbar:
        model.train()
        t1 = default_timer()
        train_pino = 0.0
        train_l2 = 0.0
        train_loss = 0.0
        dp_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
            loss_u, loss_f = PINO_loss(out, x[:, 0, :, 0])
            total_loss = loss_u * 10 + loss_f

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            train_l2 += loss.item()
            train_pino += loss_f.item()
            train_loss += total_loss.item()

        if supervised_loader is not None:
            for x, y in supervised_loader:
                x, y = x.to(device), y.to(device)

                out = model(x)
                datapoint_loss = myloss(out.view(batch_size, -1),
                                        y.view(batch_size, -1))

                optimizer.zero_grad()
                datapoint_loss.backward()
                optimizer.step()

                dp_loss += datapoint_loss.item()
        scheduler.step()

        model.eval()
        test_l2 = 0.0
        test_pino = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)

                out = model(x)

                test_l2 += myloss(out.view(batch_size, -1),
                                  y.view(batch_size, -1)).item()
                test_u, test_f = PINO_loss(out, x[:, 0, :, 0])
                test_pino += test_f.item()

        train_l2 /= len(train_loader)
        test_l2 /= len(test_loader)
        train_pino /= len(train_loader)
        test_pino /= len(test_loader)
        train_loss /= len(train_loader)
        if supervised_loader is not None:
            dp_loss /= len(supervised_loader)

        t2 = default_timer()
        pbar.set_description((
            f'Time cost: {t2 - t1:.2f}; Train f error: {train_pino:.5f}; Train l2 error: {train_l2:.5f}. '
            f'Test f error: {test_pino:.5f}; Test l2 error: {test_l2:.5f}'))
        if wandb:
            wandb.log({
                'Train f error': train_pino,
                'Train L2 error': train_l2,
                'Train DP error': dp_loss,
                'Train loss': train_loss,
                'Test f error': test_pino,
                'Test L2 error': test_l2,
                'Time cost': t2 - t1
            })

    save_checkpoint(ckpt_dir, name, model, optimizer)
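save_checkpoint is not defined in this excerpt. Below is a minimal sketch consistent with the call above, assuming it simply serialises the model and optimizer state to ckpt_dir/name; the project's helper may store additional state or use a different path layout.

import os
import torch

def save_checkpoint(ckpt_dir, name, model, optimizer=None):
    # Sketch: persist model (and optionally optimizer) state dicts.
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {'model': model.state_dict()}
    if optimizer is not None:
        state['optim'] = optimizer.state_dict()
    torch.save(state, os.path.join(ckpt_dir, name))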