import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# BiGRU, Transformer, EarlyStopping and LabelSmoothingLoss are assumed to be
# defined elsewhere in this project.


def train_bigru(train_iterator, val_iterator, test_iterator):
    # Hyperparameters
    hidden_size = 8
    vocab_size = len(train_iterator.word2index)
    n_extra_feat = 10
    output_size = 2
    n_layers = 1
    dropout = 0.5
    learning_rate = 0.001
    epochs = 40
    spatial_dropout = True
    bidirectional = True

    # Load the pre-trained GloVe embedding weights
    weights = np.load('glove/weights-biGRU-glove.npy')

    # Check whether the system supports CUDA and pick the device once
    CUDA = torch.cuda.is_available()
    device = torch.device('cuda' if CUDA else 'cpu')

    model = BiGRU(hidden_size, vocab_size, n_extra_feat, weights, output_size,
                  n_layers, dropout, spatial_dropout, bidirectional)

    # Move the model to the GPU if one is available
    model.to(device)

    model.add_loss_fn(nn.NLLLoss())

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.add_optimizer(optimizer)

    model.add_device(device)

    # Early stopping on the validation loss (patience of two epochs)
    early_stop = EarlyStopping(wait_epochs=2)

    # Track per-epoch metrics for later inspection
    train_losses_list, train_avg_loss_list, train_accuracy_list = [], [], []
    eval_avg_loss_list, eval_accuracy_list, conf_matrix_list = [], [], []

    for epoch in range(epochs):

        print('\nStart epoch [{}/{}]'.format(epoch + 1, epochs))

        train_losses, train_avg_loss, train_accuracy = model.train_model(
            train_iterator)

        train_losses_list.append(train_losses)
        train_avg_loss_list.append(train_avg_loss)
        train_accuracy_list.append(train_accuracy)

        _, eval_avg_loss, eval_accuracy, conf_matrix = model.evaluate_model(
            val_iterator)

        eval_avg_loss_list.append(eval_avg_loss)
        eval_accuracy_list.append(eval_accuracy)
        conf_matrix_list.append(conf_matrix)

        print('\nEpoch [{}/{}]: Train accuracy: {:.3f}. Train loss: {:.4f}. '
              'Evaluation accuracy: {:.3f}. Evaluation loss: {:.4f}'.format(
                  epoch + 1, epochs, train_accuracy, train_avg_loss,
                  eval_accuracy, eval_avg_loss))

        if early_stop.stop(eval_avg_loss, model, delta=0.003):
            break

    # Final evaluation on the held-out test set
    _, test_avg_loss, test_accuracy, test_conf_matrix = model.evaluate_model(
        test_iterator)
    print('Test accuracy: {:.3f}. Test loss: {:.4f}'.format(
        test_accuracy, test_avg_loss))


def train_transformer(train_iterator, val_iterator, test_iterator):
    # Hyperparameters
    batch_size = 32
    vocab_size = len(train_iterator.word2index)
    dmodel = 64
    output_size = 2
    padding_idx = train_iterator.word2index['<PAD>']
    n_layers = 4
    ffnn_hidden_size = dmodel * 2
    heads = 8
    pooling = 'max'
    dropout = 0.5
    label_smoothing = 0.1
    learning_rate = 0.001
    epochs = 30

    # Check whether the system supports CUDA
    CUDA = torch.cuda.is_available()

    # Scan the training set once for the longest sequence; max_len is passed
    # to the Transformer, typically to size its positional encodings
    max_len = 0
    for batch in train_iterator:
        batch_max = int(max(batch['x_lengths']))
        if batch_max > max_len:
            max_len = batch_max

    model = Transformer(vocab_size, dmodel, output_size, max_len, padding_idx,
                        n_layers, ffnn_hidden_size, heads, pooling, dropout)

    # Move the model to the GPU if one is available
    device = torch.device('cuda' if CUDA else 'cpu')
    model.to(device)

    # Use label smoothing if requested; otherwise fall back to plain NLL loss
    if label_smoothing:
        loss_fn = LabelSmoothingLoss(output_size, label_smoothing)
    else:
        loss_fn = nn.NLLLoss()
    model.add_loss_fn(loss_fn)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.add_optimizer(optimizer)

    model.add_device(device)

    # Collect the hyperparameters in one place so the run is self-describing
    params = {
        'batch_size': batch_size,
        'dmodel': dmodel,
        'n_layers': n_layers,
        'ffnn_hidden_size': ffnn_hidden_size,
        'heads': heads,
        'pooling': pooling,
        'dropout': dropout,
        'label_smoothing': label_smoothing,
        'learning_rate': learning_rate
    }

    # TensorBoard writers for the training and validation loss curves
    train_writer = SummaryWriter('runs/transformer_train')
    val_writer = SummaryWriter('runs/transformer_val')
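
    # A sketch, assuming only the standard SummaryWriter API: log the collected
    # hyperparameters as text so each run records its own configuration
    train_writer.add_text('hyperparameters', str(params))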

    # Early stopping on the validation loss (patience of three epochs)
    early_stop = EarlyStopping(wait_epochs=3)

    # Track per-epoch metrics for later inspection
    train_losses_list, train_avg_loss_list, train_accuracy_list = [], [], []
    eval_avg_loss_list, eval_accuracy_list, conf_matrix_list = [], [], []

    try:
        for epoch in range(epochs):

            print('\nStart epoch [{}/{}]'.format(epoch + 1, epochs))

            train_losses, train_avg_loss, train_accuracy = model.train_model(
                train_iterator)

            train_losses_list.append(train_losses)
            train_avg_loss_list.append(train_avg_loss)
            train_accuracy_list.append(train_accuracy)

            _, eval_avg_loss, eval_accuracy, conf_matrix = model.evaluate_model(
                val_iterator)

            eval_avg_loss_list.append(eval_avg_loss)
            eval_accuracy_list.append(eval_accuracy)
            conf_matrix_list.append(conf_matrix)

            print('\nEpoch [{}/{}]: Train accuracy: {:.3f}. Train loss: {:.4f}. '
                  'Evaluation accuracy: {:.3f}. Evaluation loss: {:.4f}'.format(
                      epoch + 1, epochs, train_accuracy, train_avg_loss,
                      eval_accuracy, eval_avg_loss))

            # Log the per-epoch losses so both curves can be compared in TensorBoard
            train_writer.add_scalar('Training loss', train_avg_loss, epoch)
            val_writer.add_scalar('Validation loss', eval_avg_loss, epoch)

            # Stop early when the validation loss has stopped improving
            if early_stop.stop(eval_avg_loss, model, delta=0.003):
                break
    finally:
        # Flush and close the writers even if training is interrupted
        train_writer.close()
        val_writer.close()

    # Final evaluation on the held-out test set
    _, test_avg_loss, test_accuracy, test_conf_matrix = model.evaluate_model(
        test_iterator)
    print('Test accuracy: {:.3f}. Test loss: {:.4f}'.format(
        test_accuracy, test_avg_loss))
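

# Example usage (a sketch: assumes the project's data iterators, which expose a
# `word2index` mapping and yield batches containing an 'x_lengths' entry):
#
#     train_bigru(train_iterator, val_iterator, test_iterator)
#     train_transformer(train_iterator, val_iterator, test_iterator)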