Exemplo n.º 1
0
def train():
    """
    Train the ``Summary`` network with binary cross-entropy, checkpointing each epoch.

    Configuration is read from the module-level ``args`` namespace: ``batch_size``,
    ``max_len``, ``lr``, ``epoch`` (exclusive upper bound of the epoch range) and
    ``save_path``. For resuming, it also reads ``model_name`` and ``load_model``.
    Batches come from ``bachify_data('train')``. Each batch is a 4-tuple whose
    first two items are the model input and the binary labels.

    After every epoch a checkpoint dict
    ``{'model': state_dict, 'loss': avg_loss, 'epochs': epoch}`` is written to
    ``<args.save_path>/<epoch>_epoch_model.tar``.

    :return: None
    """
    logging.info("Start Training!")
    corpus_type = 'train'
    summary = Summary(args.batch_size, args.max_len)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(summary.parameters(), lr=args.lr)
    summary.train()

    start_epoch = 0
    # NOTE(review): the guard tests ``args.model_name`` but the file loaded is
    # ``args.load_model`` -- confirm this pairing is intended.
    if args.model_name:
        checkpoint = torch.load(args.load_model)
        # The checkpoint holds a state_dict (see torch.save below), not a module.
        # Load it into the existing model so ``summary`` stays callable and the
        # optimizer keeps pointing at the same parameter tensors.
        summary.load_state_dict(checkpoint['model'])
        # Resume with the epoch after the last completed one (including epoch 0).
        start_epoch = checkpoint['epochs'] + 1

    for epoch in range(start_epoch, args.epoch):
        epoch_loss = 0.0
        batch_num = 0
        for batch_df, batch_label, _, _ in bachify_data(corpus_type):
            batch_df = torch.tensor(batch_df)
            binary_output = summary(batch_df)
            # BCELoss expects the target in the same dtype as the prediction.
            batch_label = torch.tensor(batch_label, dtype=binary_output.dtype)

            # calculate loss
            loss = criterion(binary_output, batch_label)
            # backward() accumulates into .grad, so clear the previous batch's gradients.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() yields a plain float, so this batch's autograd graph can be freed.
            epoch_loss += loss.item()
            batch_num += 1

        # Guard against an empty batch generator instead of dividing by zero.
        avg_loss = epoch_loss / batch_num if batch_num else float("nan")
        logging.info("Epoch %s: Total loss is %s, Avg loss is %s", epoch, epoch_loss, avg_loss)

        # store model: create the target directory, then write the checkpoint file inside it
        os.makedirs(args.save_path, exist_ok=True)
        model_path = os.path.join(args.save_path, "{}_epoch_model.tar".format(epoch))
        torch.save({
            'model': summary.state_dict(),
            'loss': avg_loss,
            "epochs": epoch
        }, model_path)

    logging.info("Finish Training!")