Example #1
import os

import torch

# DataProcessor, DeepAR, gaussian_likelihood_loss, load_checkpoint,
# evaluate and _forward are project-local helpers assumed to be importable.


def train(raw, flags):
    data_processor = DataProcessor(flags.forecast_length, flags.batch_size,
                                   flags.window)
    train_loader, val_loader = data_processor.get_train_test_data(
        raw, flags.validation_ratio)

    model = DeepAR(cov_dim=data_processor.num_features,
                   hidden_dim=flags.num_units,
                   num_layers=flags.num_layers,
                   num_class=len(raw['type'].unique()),
                   embedding_dim=flags.embedding_size,
                   batch_first=True,
                   dropout=flags.dropout)

    opt = torch.optim.Adam(model.parameters(), lr=flags.learning_rate)

    teacher_ratio = flags.teacher_ratio
    loss_history = []
    loss_fn = gaussian_likelihood_loss

    model, opt, start_epoch = load_checkpoint(flags.checkpoint_path, model,
                                              opt)
    if start_epoch >= flags.num_epochs:
        print('start_epoch (%d) >= num_epochs (%d); no further training will run.'
              % (start_epoch, flags.num_epochs))
    epoch = start_epoch
    # TODO: add early stop
    for epoch in range(start_epoch, flags.num_epochs):
        for step, data in enumerate(train_loader):
            avg_loss, _ = _forward(data, model, loss_fn, flags.window,
                                   flags.forecast_length, teacher_ratio)
            # Store a plain float so the computation graph can be freed.
            loss_history.append(avg_loss.item())
            opt.zero_grad()
            avg_loss.backward()
            opt.step()
            teacher_ratio *= flags.teacher_ratio_decay
        validation_loss = evaluate(val_loader, model, loss_fn, flags.window,
                                   flags.forecast_length)
        print('Epoch: %d' % epoch)
        print('Training Loss: %.3f' % avg_loss)
        print('Validation Loss: %.3f' % validation_loss)
        print('Teacher ratio: %.3f' % teacher_ratio)
        print()

    print('Model training completed; saving model to %s' % flags.checkpoint_path)
    state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict()
    }
    # makedirs handles nested paths and avoids the exists/mkdir race.
    os.makedirs(flags.checkpoint_path, exist_ok=True)
    torch.save(state, os.path.join(flags.checkpoint_path, 'model.pt'))
    data_processor.save(flags.checkpoint_path)
    return model, loss_history
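The snippet references `gaussian_likelihood_loss` and `load_checkpoint` without defining them. For reference only, here is a minimal sketch of each: the first assumes the conventional DeepAR-style Gaussian negative log-likelihood (the `(mu, sigma, target)` argument order is an assumption), and the second simply mirrors the `state` dict written at the end of `train`. The project's actual helpers may differ.

def gaussian_likelihood_loss(mu, sigma, target):
    # Negative log-likelihood of `target` under N(mu, sigma^2),
    # averaged over all elements. Sketch only.
    dist = torch.distributions.Normal(mu, sigma)
    return -dist.log_prob(target).mean()


def load_checkpoint(path, model, opt):
    # Resume from <path>/model.pt when it exists; otherwise start
    # from epoch 0. Sketch matching the checkpoint format above.
    ckpt_file = os.path.join(path, 'model.pt')
    if not os.path.exists(ckpt_file):
        return model, opt, 0
    state = torch.load(ckpt_file)
    model.load_state_dict(state['state_dict'])
    opt.load_state_dict(state['optimizer'])
    return model, opt, state['epoch']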
Example #2
import os

import torch

# DataProcessor, LSTM, SMAPE, load_checkpoint, evaluate and _forward
# are project-local helpers assumed to be importable.


def train(raw, flags):
    data_processor = DataProcessor(flags.forecast_length, flags.batch_size,
                                   flags.window)
    train_loader, val_loader = data_processor.get_train_test_data(
        raw, if_scale=True, val_ratio=flags.validation_ratio)

    model = LSTM(data_processor.num_features,
                 flags.num_units,
                 output_dim=flags.output_dim,
                 num_layers=flags.num_layers,
                 batch_first=True,
                 dropout=flags.dropout)

    if flags.loss == 'mse':
        loss_fn = torch.nn.MSELoss()
    else:
        loss_fn = SMAPE()

    opt = torch.optim.Adam(model.parameters(), lr=flags.learning_rate)

    teacher_ratio = flags.teacher_ratio
    loss_history = []

    model, opt, start_epoch = load_checkpoint(flags.checkpoint_path, model,
                                              opt)
    if start_epoch >= flags.num_epochs:
        print('start_epoch (%d) >= num_epochs (%d); no further training will run.'
              % (start_epoch, flags.num_epochs))
    epoch = start_epoch
    # TODO: add early stop
    for epoch in range(start_epoch, flags.num_epochs):
        for step, data in enumerate(train_loader):
            avg_loss, _, acc = _forward(data, model, loss_fn, flags.window,
                                        flags.forecast_length, True,
                                        teacher_ratio)
            # Store a plain float so the computation graph can be freed.
            loss_history.append(avg_loss.item())
            opt.zero_grad()
            avg_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
            teacher_ratio *= flags.teacher_ratio_decay
        val_loss, val_acc = evaluate(val_loader, model, loss_fn, flags.window,
                                     flags.forecast_length)
        print('Epoch: %d' % epoch)
        print('Training Loss: %.3f' % avg_loss)
        print('Training Avg Accuracy: %.3f' % acc)
        print('Validation Loss: %.3f' % val_loss)
        print('Validation Accuracy: %.3f' % val_acc)
        print('Teacher ratio: %.3f' % teacher_ratio)
        # Skip parameters that received no gradient this step.
        print('Mean abs gradient: %.3f' % torch.mean(torch.stack(
            [torch.mean(torch.abs(p.grad)) for p in model.parameters()
             if p.grad is not None])))
        print()

    print('Model training completed; saving model to %s' % flags.checkpoint_path)
    state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict()
    }
    # makedirs handles nested paths and avoids the exists/mkdir race.
    os.makedirs(flags.checkpoint_path, exist_ok=True)
    torch.save(state, os.path.join(flags.checkpoint_path, 'model.pt'))
    data_processor.save(flags.checkpoint_path)
    return model, loss_history
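Example #2 selects between `torch.nn.MSELoss` and a project-local `SMAPE` criterion that is not shown above. A minimal sketch of `SMAPE` as an `nn.Module`, assuming the standard symmetric-MAPE formula with an epsilon for numerical stability (the real class may scale or reduce differently):

class SMAPE(torch.nn.Module):
    # mean(2 * |y - y_hat| / (|y| + |y_hat| + eps)), values in [0, 2].
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        num = 2.0 * torch.abs(pred - target)
        den = torch.abs(pred) + torch.abs(target) + self.eps
        return torch.mean(num / den)

Either `train` is driven by a `flags` object whose field names can be read off the code above. An illustrative call, where every value is a placeholder (not taken from the source) and `raw` is whatever frame `DataProcessor` expects:

from argparse import Namespace

flags = Namespace(
    forecast_length=24, window=168, batch_size=64,
    validation_ratio=0.1, num_units=128, num_layers=2,
    output_dim=1, dropout=0.1, loss='mse', learning_rate=1e-3,
    teacher_ratio=1.0, teacher_ratio_decay=0.99, num_epochs=50,
    checkpoint_path='./checkpoints')
model, loss_history = train(raw, flags)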