Example 1
# NOTE: assumes torch and numpy (as np) are already imported, and that MLP, HNN,
# get_dataset and L2_loss are provided by the surrounding HNN/SymODEN codebase.
def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")

    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, 400, output_dim, args.nonlinearity)
    model = HNN(args.input_dim, differentiable_model=nn_model,
                field_type=args.field_type, baseline=args.baseline)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # the data API differs here: to keep the comparison fair, generate the
    # data the same way as in SymODEN and compute the time derivatives from
    # the generated trajectories
    us = [0.0]
    data = get_dataset(seed=args.seed, save_dir=args.save_dir,
                       rad=args.rad, us=us, samples=50, timesteps=45)
    
    # arrange data: keep only the (q, p) coordinates and estimate dx/dt by
    # forward finite differences along the time dimension

    train_x, t_eval = data['x'][0, :, :, 0:2], data['t']
    test_x, t_eval = data['test_x'][0, :, :, 0:2], data['t']

    train_dxdt = (train_x[1:, :, :] - train_x[:-1, :, :]) / (t_eval[1] - t_eval[0])
    test_dxdt = (test_x[1:, :, :] - test_x[:-1, :, :]) / (t_eval[1] - t_eval[0])

    # drop the last timestep (no derivative available there) and flatten to (N, 2)
    train_x = train_x[:-1, :, :].reshape((-1, 2))
    test_x = test_x[:-1, :, :].reshape((-1, 2))
    train_dxdt = train_dxdt.reshape((-1, 2))
    test_dxdt = test_dxdt.reshape((-1, 2))
    
    x = torch.tensor(train_x, requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(test_x, requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(train_dxdt)
    test_dxdt = torch.Tensor(test_dxdt)

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps+1):
        
        # train step
        dxdt_hat = model.rk4_time_derivative(x) if args.use_rk4 else model.time_derivative(x)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()
        
        # evaluate on test data
        test_dxdt_hat = model.rk4_time_derivative(test_x) if args.use_rk4 else model.time_derivative(test_x)
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
        .format(train_dist.mean().item(), train_dist.std().item()/np.sqrt(train_dist.shape[0]),
                test_dist.mean().item(), test_dist.std().item()/np.sqrt(test_dist.shape[0])))

    return model, stats
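
Both examples read their hyperparameters from an args object. The sketch below shows one way such an object could be assembled for a quick call to the function in Example 1; the attribute names are taken from the code above, but the specific values are illustrative assumptions only (Example 2 additionally reads args.hidden_dim).

from argparse import Namespace

# Hypothetical settings; only the attribute names come from the train()
# function above, the values are illustrative assumptions.
args = Namespace(
    seed=0,
    input_dim=2,              # (q, p) for a one-degree-of-freedom system
    nonlinearity='tanh',
    baseline=False,           # False -> train the HNN variant
    field_type='solenoidal',
    learn_rate=1e-3,
    total_steps=2000,
    print_every=200,
    use_rk4=False,
    verbose=True,
    rad=False,
    save_dir='./data',        # forwarded to get_dataset
)

model, stats = train(args)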
Example 2
# NOTE: as in Example 1, torch/np imports and the MLP, HNN, get_dataset and
# L2_loss helpers are assumed to come from the surrounding codebase.
def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.
              baseline else "Training HNN model:")

    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, args.hidden_dim, output_dim,
                   args.nonlinearity)
    model = HNN(args.input_dim,
                differentiable_model=nn_model,
                field_type=args.field_type,
                baseline=args.baseline)
    optim = torch.optim.Adam(model.parameters(),
                             args.learn_rate,
                             weight_decay=1e-4)

    # arrange data
    data = get_dataset(seed=args.seed)
    x = torch.tensor(data['x'], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(data['test_x'],
                          requires_grad=True,
                          dtype=torch.float32)
    dxdt = torch.Tensor(data['dx'])
    test_dxdt = torch.Tensor(data['test_dx'])

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):

        # train step
        dxdt_hat = (model.rk4_time_derivative(x) if args.use_rk4
                    else model.time_derivative(x))
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # evaluate on test data
        test_dxdt_hat = (model.rk4_time_derivative(test_x) if args.use_rk4
                         else model.time_derivative(test_x))
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(
                step, loss.item(), test_loss.item()))

    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print(
        'Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
        .format(train_dist.mean().item(),
                train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                test_dist.mean().item(),
                test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return model, stats
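
The two versions differ mainly in how they obtain the target derivatives: Example 2 relies on get_dataset returning them directly as data['dx'] and data['test_dx'], while Example 1 estimates them from the trajectories by forward finite differences. Neither listing defines L2_loss; a minimal stand-in consistent with how it is used here (a scalar loss between predicted and target derivative tensors) is a plain mean squared error:

def L2_loss(u, v):
    # mean squared error over all elements, matching the scalar loss
    # expected by the training loops above
    return ((u - v) ** 2).mean()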