コード例 #1
0
ファイル: train.py プロジェクト: ekrim/glow
def train(param, x, y):
    """Fit a ``RealNVP`` flow to the rows of ``x`` by maximum likelihood.

    Minimises the negative mean log-likelihood
    ``-flow.log_prob(batch).mean()`` with Adam. It prints progress and
    checkpoints the model's ``state_dict`` to ``flow_model.pytorch`` in the
    current working directory after every full pass over the data.

    Args:
        param: Hyper-parameter container. This function reads
            ``param.batch_size``, ``param.lr`` and ``param.total_it``.
        x: 2-D array with one sample per row (presumably a NumPy array; it
            must support ``.shape`` and ``.astype``). It is cast to float32
            before batching.
        y: Unused; accepted for call-site compatibility.

    ``param.total_it`` is a budget of *samples*, not optimizer steps: the
    counter grows by each batch's size. The budget is only checked between
    passes over the data, so the pass that crosses it still runs to
    completion.

    Raises:
        ValueError: If ``x`` has no rows while ``param.total_it > 0``. An
            empty loader never advances the sample counter, so the training
            loop would otherwise never terminate.
    """
    if x.shape[0] == 0 and param.total_it > 0:
        raise ValueError('cannot train on empty data: x has no rows')

    dim_in = x.shape[1]
    # Prefer the first CUDA device when one is available.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    dataloader = DataLoader(torch.from_numpy(x.astype(np.float32)),
                            batch_size=param.batch_size,
                            shuffle=True,
                            num_workers=2)

    flow = RealNVP(dim_in, device)
    flow.to(device)
    flow.train()

    # Optimise only the trainable parameters.
    optimizer = torch.optim.Adam(
        [p for p in flow.parameters() if p.requires_grad], lr=param.lr)

    # `it` counts samples seen so far (not optimizer steps).
    # `print_cnt` counts samples since the last progress line.
    it, print_cnt = 0, 0
    while it < param.total_it:

        for data in dataloader:

            loss = -flow.log_prob(data.to(device)).mean()

            optimizer.zero_grad()
            # NOTE(review): retain_graph=True looks unnecessary, because a
            # fresh graph is built every step, and it keeps intermediate
            # buffers alive longer. It may be working around graph state
            # built inside RealNVP, so confirm before removing it.
            loss.backward(retain_graph=True)
            optimizer.step()

            batch_n = data.shape[0]
            it += batch_n
            print_cnt += batch_n
            # PRINT_FREQ is a module-level constant defined outside this
            # function.
            if print_cnt > PRINT_FREQ:
                print('it {:d} -- loss {:.03f}'.format(it, loss.item()))
                print_cnt = 0

        # Checkpoint after every full pass over the data.
        torch.save(flow.state_dict(), 'flow_model.pytorch')
コード例 #2
0
ファイル: main.py プロジェクト: TinyVolt/normalizing-flows
def train_and_eval(flow, epochs, lr, train_loader, test_loader,
                   target_distribution):
    """Train ``flow`` with Adam for ``epochs`` epochs, tracking losses.

    After each epoch, ``eval_loss`` is computed on both the training and
    the test loader. ``train`` and ``eval_loss`` are helpers defined
    outside this function.

    Returns:
        A tuple ``(flow, train_losses, test_losses)``, where each list has
        one entry per epoch.
    """
    n_params = sum(weight.numel() for weight in flow.parameters())
    print('no of parameters is', n_params)

    opt = torch.optim.Adam(flow.parameters(), lr=lr)
    history_train = []
    history_test = []

    # Epochs are numbered from 1 for display.
    for ep in range(1, epochs + 1):
        print('Starting epoch:', ep, 'of', epochs)
        train(flow, train_loader, opt, target_distribution)
        history_train.append(
            eval_loss(flow, train_loader, target_distribution))
        history_test.append(
            eval_loss(flow, test_loader, target_distribution))

    return flow, history_train, history_test


if __name__ == '__main__':
    # Script entry point: build a RealNVP flow and train it against a
    # standard normal N(0, 1). Then report per-epoch losses and save the
    # weights. `device`, `RealNVP`, `INPUT_H`, `INPUT_W`, `train_loader`
    # and `test_loader` are module-level names defined elsewhere.
    print('Device is:', device)
    from torch.distributions.normal import Normal
    import numpy as np  # unused here; left in place in case other module code relies on the global

    flow = RealNVP(INPUT_H, INPUT_W).to(device)

    # Float32 scalar parameters placed on the same device as the model.
    target_distribution = Normal(
        loc=torch.tensor(0).float().to(device),
        scale=torch.tensor(1).float().to(device),
    )

    # 100 epochs of Adam at a learning rate of 5e-4.
    flow, train_losses, test_losses = train_and_eval(
        flow=flow,
        epochs=100,
        lr=5e-4,
        train_loader=train_loader,
        test_loader=test_loader,
        target_distribution=target_distribution,
    )

    print('train losses are', train_losses)
    print('test losses are', test_losses)
    torch.save(flow.state_dict(), 'trained_weights.pt')