Example #1
File: train.py Project: devzhk/PINO
    def train_adam(self,
                   optimizer,
                   alpha=100.0, beta=100.0,
                   iter_num=10,
                   path='beltrami', name='test.pt',
                   scheduler=None, re=1.0):
        self.model.train()
        # Collocation points need gradients so the PDE residual can be
        # computed with autograd.
        self.col_xyzt.requires_grad = True
        mse = torch.nn.MSELoss()
        pbar = tqdm(range(iter_num), dynamic_ncols=True, smoothing=0.01)
        for e in pbar:
            optimizer.zero_grad()
            # Also clear gradients accumulated on the collocation tensor;
            # optimizer.zero_grad() only touches model parameters.
            zero_grad(self.col_xyzt)

            # Boundary loss on the velocity components. Slicing the last
            # axis assumes the channel layout (u, v, w, p); the original
            # sliced the first axis, which picks points rather than fields.
            pred_bd_uvwp = self.model(self.bd_xyzt)
            bd_loss = mse(pred_bd_uvwp[..., 0:3], self.bd_uvwp[..., 0:3])

            # Initial-condition loss on the same components.
            pred_ini_uvwp = self.model(self.ini_xyzt)
            ini_loss = mse(pred_ini_uvwp[..., 0:3], self.ini_uvwp[..., 0:3])

            # PDE residual loss at the collocation points; the Reynolds
            # number re enters the residual.
            pred_col_uvwp = self.model(self.col_xyzt)
            f_loss = self.loss_f(pred_col_uvwp, self.col_xyzt, re=re)

            # Weighted total: boundary and initial terms scaled by alpha
            # and beta.
            total_loss = alpha * bd_loss + beta * ini_loss + f_loss
            total_loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            pbar.set_description(
                (
                    f'Total loss: {total_loss.item():.6f}, f loss: {f_loss.item():.7f}, '
                    f'boundary loss: {bd_loss.item():.7f}, initial loss: {ini_loss.item():.7f}'
                )
            )
            if e % 500 == 0:
                u_err, v_err, w_err = self.eval_error()
                print(f'u error: {u_err}, v error: {v_err}, w error: {w_err}')
        save_checkpoint(path, name, self.model)
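
The method above leans on two project helpers that are not shown: zero_grad, which clears gradients accumulated on the collocation tensor, and save_checkpoint. A minimal sketch of what zero_grad presumably does, assuming it takes a single leaf tensor:

import torch

def zero_grad(tensor):
    # optimizer.zero_grad() resets only parameter gradients; a leaf input
    # tensor with requires_grad=True keeps accumulating .grad across
    # iterations and has to be cleared by hand.
    if tensor.grad is not None:
        tensor.grad.detach_()
        tensor.grad.zero_()
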
Example #2
        # (Snippet begins mid-batch, inside the per-epoch training loop.)
        optimizer.step()

        # Accumulate per-batch terms for the epoch summary below.
        train_l2 += loss_u.item()
        test_l2 += loss.item()
        train_pino += loss_f.item()
        train_loss += total_loss.item()

    scheduler.step()

    # if ep % step_size == 0:
    #     plt.imsave('%s/y_%d.png' % (image_dir, ep), y[0, :, :].cpu().numpy())
    #     plt.imsave('%s/out_%d.png' % (image_dir, ep), out[0, :, :, 0].cpu().numpy())

    t2 = default_timer()
    pbar.set_description((
        f'Time cost: {t2 - t1:.2f}; Train f error: {train_pino:.5f}; Train l2 error: {train_l2:.5f}. '
        f'Test l2 error: {test_l2:.5f}'))
    if wandb and log:
        wandb.log({
            'Train f error': train_pino,
            'Train L2 error': train_l2,
            'Train loss': train_loss,
            'Test L2 error': test_l2,
            'Time cost': t2 - t1
        })

save_checkpoint(ckpt_dir, name, model, optimizer)

# 80 pretrain, 100 epoch
# 100
# 6401 x 256 x 256 x 128
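
Examples #1 and #2 both finish with save_checkpoint, once with only the model and once with the optimizer as well. A plausible sketch of such a helper, assuming it writes state dicts under a checkpoints/ root; the real signature and layout belong to the PINO project:

import os
import torch

def save_checkpoint(path, name, model, optimizer=None):
    # Sketch only: persist the model (and optionally the optimizer) state
    # dicts so training can be resumed or the model reloaded for eval.
    ckpt_dir = os.path.join('checkpoints', path)
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {'model': model.state_dict()}
    if optimizer is not None:
        state['optim'] = optimizer.state_dict()
    torch.save(state, os.path.join(ckpt_dir, name))
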
Example #3
        ))  # (snippet begins mid-statement; the opening of this call is truncated)
        if wandb and log:
            wandb.log({
                'Train f error': f_error,
                'Train IC error': ic_error,
                'Train BC error': bc_error,
                'Test L2 error': test_error,
                'Total loss': total_train_loss,
                'u error': u_error,
                'v error': v_error
            })
        scheduler.step()
    return model


if __name__ == '__main__':
    log = True
    if wandb and log:
        wandb.init(project='PINO-NS40-NSFnet',
                   entity='hzzheng-pino',
                   group='with pressure',
                   tags=['4x50'])

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    datapath = 'data/NS_fine_Re40_s64_T1000.npy'
    dataset = NS40data(datapath, nx=64, nt=64, sub=1, sub_t=1, N=1000, index=1)
    # Four hidden layers of width 50 (the '4x50' tag): inputs (x, y, t),
    # outputs presumably (u, v, p) given the 'with pressure' group.
    layers = [3, 50, 50, 50, 50, 3]
    model = FCNet(layers).to(device)
    model = train_adam(model, dataset, device)
    save_checkpoint('checkpoints/pinns', name='NS40.pt', model=model)
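
Example #3 builds FCNet from the width list [3, 50, 50, 50, 50, 3], but the class itself is not shown. A common reading, assuming a plain fully connected tanh network (this is a sketch, not the project's definition):

import torch.nn as nn

class FCNet(nn.Module):
    # Stack Linear layers of the given widths with Tanh in between;
    # e.g. [3, 50, 50, 50, 50, 3] maps three inputs to three outputs
    # through four hidden layers of width 50.
    def __init__(self, layers):
        super().__init__()
        blocks = []
        for i in range(len(layers) - 1):
            blocks.append(nn.Linear(layers[i], layers[i + 1]))
            if i < len(layers) - 2:
                blocks.append(nn.Tanh())
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)
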
Example #4
def train():
    batch_size = 20
    learning_rate = 0.001

    epochs = 2500
    step_size = 100
    gamma = 0.25

    modes = 12  # 20
    width = 32  # 64

    config_defaults = {
        'ntrain': 800,
        'nlabels': 10,
        'ntest': 200,
        'lr': learning_rate,
        'batch_size': batch_size,
        'modes': modes,
        'width': width
    }
    wandb.init(config=config_defaults, tags=['Epoch'])
    config = wandb.config
    print('config: ', config)

    ntrain = config.ntrain
    nlabels = config.nlabels
    ntest = config.ntest

    image_dir = 'figs/FDM-burgers'
    os.makedirs(image_dir, exist_ok=True)
    ckpt_dir = 'Burgers-FDM'
    name = f'PINO_FDM_burgers_N{ntrain}_L{nlabels}-{ntest}.pt'

    train_loader = constructor.make_loader(n_sample=ntrain,
                                           batch_size=batch_size,
                                           train=True)
    test_loader = constructor.make_loader(n_sample=ntest,
                                          batch_size=batch_size,
                                          train=False)
    if config.nlabels > 0:
        # Note: the original wrapped this loader with sample_data(), an
        # endless sampler; the epoch-style `for` loop and len() used
        # below need the plain finite loader, so the wrapper is dropped.
        supervised_loader = constructor.make_loader(n_sample=nlabels,
                                                    batch_size=batch_size,
                                                    start=ntrain,
                                                    train=True)
    else:
        supervised_loader = None
    layers = [
        width * 2 // 4, width * 3 // 4, width * 3 // 4, width * 4 // 4,
        width * 4 // 4
    ]
    modes = [modes * (5 - i) // 4 for i in range(4)]

    model = FNN2d(modes1=modes, modes2=modes, width=width,
                  layers=layers).to(device)
    num_param = count_params(model)
    print('Number of model parameters', num_param)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[500, 1000, 2000], gamma=gamma)

    myloss = LpLoss(size_average=True)
    pbar = tqdm(range(epochs), dynamic_ncols=True, smoothing=0.01)

    for ep in pbar:
        model.train()
        t1 = default_timer()
        train_pino = 0.0
        train_l2 = 0.0
        train_loss = 0.0
        dp_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
            # loss_u penalizes the mismatch with the initial condition
            # u0 = x[:, 0, :, 0]; loss_f is the PDE residual. The data
            # term is weighted 10x.
            loss_u, loss_f = PINO_loss(out, x[:, 0, :, 0])
            total_loss = loss_u * 10 + loss_f

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            train_l2 += loss.item()
            train_pino += loss_f.item()
            train_loss += total_loss.item()

        # Guard against the label-free case (nlabels == 0), where no
        # supervised loader exists.
        if supervised_loader is not None:
            for x, y in supervised_loader:
                x, y = x.to(device), y.to(device)

                out = model(x)
                datapoint_loss = myloss(out.view(batch_size, -1),
                                        y.view(batch_size, -1))

                optimizer.zero_grad()
                datapoint_loss.backward()
                optimizer.step()

                dp_loss += datapoint_loss.item()
        scheduler.step()

        model.eval()
        test_l2 = 0.0
        test_pino = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)

                out = model(x)

                test_l2 += myloss(out.view(batch_size, -1),
                                  y.view(batch_size, -1)).item()
                test_u, test_f = PINO_loss(out, x[:, 0, :, 0])
                test_pino += test_f.item()  # accumulate (the original overwrote each batch)

        train_l2 /= len(train_loader)
        test_l2 /= len(test_loader)
        train_pino /= len(train_loader)
        test_pino /= len(test_loader)
        train_loss /= len(train_loader)
        if supervised_loader is not None:
            dp_loss /= len(supervised_loader)

        t2 = default_timer()
        pbar.set_description((
            f'Time cost: {t2 - t1:.2f}; Train f error: {train_pino:.5f}; Train l2 error: {train_l2:.5f}. '
            f'Test f error: {test_pino:.5f}; Test l2 error: {test_l2:.5f}'))
        if wandb:
            wandb.log({
                'Train f error': train_pino,
                'Train L2 error': train_l2,
                'Train DP error': dp_loss,
                'Train loss': train_loss,
                'Test f error': test_pino,
                'Test L2 error': test_l2,
                'Time cost': t2 - t1
            })

    save_checkpoint(ckpt_dir, name, model, optimizer)
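
myloss above is an LpLoss instance constructed with size_average=True. In FNO/PINO-style codebases this is usually a relative Lp norm computed per sample and then reduced over the batch; a minimal sketch under that assumption (not necessarily the project's exact class):

import torch

class LpLoss:
    # Relative Lp loss: ||pred - target||_p / ||target||_p per sample,
    # averaged (or summed) over the batch.
    def __init__(self, p=2, size_average=True):
        self.p = p
        self.size_average = size_average

    def __call__(self, pred, target):
        # pred, target: (batch, flattened features)
        diff = torch.norm(pred - target, p=self.p, dim=1)
        ref = torch.norm(target, p=self.p, dim=1)
        rel = diff / ref
        return rel.mean() if self.size_average else rel.sum()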