def train_crf_adam(params,
                   X,
                   y,
                   C,
                   num_epochs,
                   learning_rate,
                   l2_lambda,
                   test_xy,
                   model_name,
                   beta1=0.9,
                   beta2=0.999,
                   epsilon=1e-8,
                   amsgrad=False):
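    """Train the CRF with Adam (or AMSGrad when ``amsgrad=True``).

    ``params`` is the flat parameter vector that is split into the emission
    matrix W and the transition matrix T; the trained parameters are written
    to ``result/<model_name>.txt``.
    """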
    num_words = len(X)
    test_X, test_Y = test_xy

    W = matricize_W(params)
    T = matricize_Tij(params)

    print("Starting SGD optimizaion")

    # make results reproducible
    np.random.seed(1)

    # Adam state: first-moment (M) and second-moment (R) estimates
    M = {
        'W': np.zeros_like(W, dtype=np.float32),
        'T': np.zeros_like(T, dtype=np.float32)
    }
    R = {
        'W': np.zeros_like(W, dtype=np.float32),
        'T': np.zeros_like(T, dtype=np.float32)
    }

    if amsgrad:
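        # AMSGrad keeps the element-wise maximum of the second-moment estimates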
        R_hat = {
            'W': np.zeros_like(W, dtype=np.float32),
            'T': np.zeros_like(T, dtype=np.float32)
        }

    start = time.time()
    step = 1  # Adam time step, used for bias correction
    for epoch in range(num_epochs):
        print("Begin epoch %d" % (epoch + 1))

        indices = np.arange(0, num_words, step=1, dtype=int)
        np.random.shuffle(indices)

        for i, word_index in enumerate(indices):
            W_grad, T_grad = gradient_per_word(X,
                                               y,
                                               W,
                                               T,
                                               word_index,
                                               concat_grads=False)

            # Adam first-moment (momentum) update: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
            M['W'] = beta1 * M['W'] + (1 - beta1) * W_grad
            M['T'] = beta1 * M['T'] + (1 - beta1) * T_grad

            # Adam second-moment (RMSProp-style) update: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2
            R['W'] = beta2 * R['W'] + (1 - beta2) * W_grad**2
            R['T'] = beta2 * R['T'] + (1 - beta2) * T_grad**2

            if not amsgrad:
                # Adam update
                # bias-corrected moment estimates
                m_k_w = M['W'] / (1 - beta1**step)
                m_k_t = M['T'] / (1 - beta1**step)
                r_k_w = R['W'] / (1 - beta2**step)
                r_k_t = R['T'] / (1 - beta2**step)

                # per-parameter step sizes (epsilon guards against division by zero)
                lr_w = learning_rate / (np.sqrt(r_k_w) + epsilon)
                lr_t = learning_rate / (np.sqrt(r_k_t) + epsilon)

                # perform the Adam update with L2 regularization on W and T
                W -= lr_w * (C * m_k_w + l2_lambda * W)
                T -= lr_t * (C * m_k_t + l2_lambda * T)

            else:
                # AMSGrad Update
                R_hat['W'] = np.maximum(R_hat['W'], R['W'])
                R_hat['T'] = np.maximum(R_hat['T'], R['T'])

                # bias-corrected moment estimates
                m_k_w = M['W'] / (1 - beta1**step)
                m_k_t = M['T'] / (1 - beta1**step)
                r_k_w = R_hat['W'] / (1 - beta2**step)
                r_k_t = R_hat['T'] / (1 - beta2**step)

                lr_w = learning_rate / (np.sqrt(r_k_w) + epsilon)
                lr_t = learning_rate / (np.sqrt(r_k_t) + epsilon)

                # perform the AMSGrad update with L2 regularization on W and T
                W -= lr_w * (C * m_k_w + l2_lambda * W)
                T -= lr_t * (C * m_k_t + l2_lambda * T)

                # carry the running maxima forward as the new second-moment state
                R['W'] = R_hat['W']
                R['T'] = R_hat['T']

            #learning_rate *= 0.99
            #learning_rate = max(learning_rate, 1e-5)
            #
            #print("W norm", np.linalg.norm(W))
            #print("T norm", np.linalg.norm(T))
            #print("lr", learning_rate)
            #print("W grad stats", np.linalg.norm(W_grad), np.min(W_grad), np.max(W_grad), np.mean(W_grad), np.std(W_grad))
            #print("T grad stats", np.linalg.norm(T_grad), np.min(T_grad), np.max(T_grad), np.mean(T_grad), np.std(T_grad))
            #print()

            step += 1
            # cap the step count; beyond this the bias-correction factors are effectively 1
            step = min(step, int(1e6))

        params = np.concatenate((W.flatten(), T.flatten()))
        print("W norm", np.linalg.norm(W))
        print("T norm", np.linalg.norm(T))
        logloss = optimization_function(params, X, y, C, num_words)
        print("Logloss : ", logloss)

        if (epoch + 1) % 5 == 0:
            print('*' * 80)
            print("Computing metrics after end of epoch %d" % (epoch + 1))
            # print evaluation metrics every 5 epochs
            train_loss = optimization_function(params, X, y, C, num_words)

            y_preds = decode_crf(test_X, W, T)
            word_acc, char_acc = compute_word_char_accuracy_score(
                y_preds, test_Y)

            print(
                "Epoch %d | Train loss = %0.8f | Word Accuracy = %0.5f | Char Accuracy = %0.5f"
                % (epoch + 1, train_loss, word_acc, char_acc))
            print('*' * 80, '\n')

    # merge the two weight matrices into a single flat parameter vector
    out = np.concatenate((W.flatten(), T.flatten()))
    print("Total time: ", end='')
    print(time.time() - start)

    with open("result/" + model_name + ".txt", "w") as text_file:
        for i, elt in enumerate(out):
            text_file.write(str(elt) + "\n")
    # BETA2 = 0.999
    #
    # train_crf_adam(params, X_train, y_train, C=1, l2_lambda=L2_LAMBDA,
    #               num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE,
    #               test_xy=test_xy, model_name='adam-6',
    #                beta2=BETA2, amsgrad=AMSGRAD)

    # evaluate a previously trained model ('sgd-2'); this block uses the
    # module-level X_train, y_train, X_test and y_test arrays rather than
    # this function's arguments
    params = get_trained_model_parameters('sgd-2')
    w = matricize_W(params)
    t = matricize_Tij(params)

    print(
        "Function value: ",
        optimization_function(params,
                              X_train,
                              y_train,
                              C=1,
                              limit=len(X_train)))
    # accuracy on the test set
    y_preds = decode_crf(X_test, w, t)

    with open("result/prediction-sgd.txt", "w") as text_file:
        for word in y_preds:
            # write the predicted labels as 1-indexed characters, one per line
            for char_label in word:
                text_file.write(str(char_label + 1))
                text_file.write("\n")

    print("Test accuracy : ",
          compute_word_char_accuracy_score(y_preds, y_test))


def train_crf_sgd(params, X, y, C, num_epochs, learning_rate, l2_lambda,
                  test_xy, model_name):
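    """Train the CRF with plain SGD: per-word gradient steps with L2
    regularization and a multiplicatively decaying learning rate."""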
    num_words = len(X)
    test_X, test_Y = test_xy

    W = matricize_W(params)
    T = matricize_Tij(params)

    print("Starting SGD optimizaion")

    # make results reproducible
    np.random.seed(1)

    start = time.time()
    for epoch in range(num_epochs):
        print("Begin epoch %d" % (epoch + 1))

        indices = np.arange(0, num_words, step=1, dtype=int)
        np.random.shuffle(indices)

        for i, word_index in enumerate(indices):
            W_grad, T_grad = gradient_per_word(X,
                                               y,
                                               W,
                                               T,
                                               word_index,
                                               concat_grads=False)

            W_grad = -W_grad + l2_lambda * W
            T_grad = -T_grad + l2_lambda * T

            #W_grad /= np.linalg.norm(W_grad)
            #T_grad /= np.linalg.norm(T_grad)

            # perform SGD update
            W -= learning_rate * (W_grad)
            T -= learning_rate * (T_grad)

            if i % 100 == 0:
                params = np.concatenate((W.flatten(), T.flatten()))
                logloss = optimization_function(params, X, y, C, num_words)
                print("Epoch %d Iter %d : Logloss : " % (epoch + 1, i),
                      logloss)
                print("W norm", np.linalg.norm(W))
                print("T norm", np.linalg.norm(T))

            learning_rate *= 0.999
            learning_rate = max(learning_rate, 1e-4)

            #print("W norm", np.linalg.norm(W))
            #print("T norm", np.linalg.norm(T))
            #print("lr", learning_rate)
            #print("W grad stats", np.linalg.norm(W_grad), np.min(W_grad), np.max(W_grad), np.mean(W_grad), np.std(W_grad))
            #print("T grad stats", np.linalg.norm(T_grad), np.min(T_grad), np.max(T_grad), np.mean(T_grad), np.std(T_grad))
            #print()

        print("Learning rate : ", learning_rate)

        params = np.concatenate((W.flatten(), T.flatten()))
        logloss = optimization_function(params, X, y, C, num_words)
        print("Logloss : ", logloss)

        if (epoch + 1) % 5 == 0:
            print('*' * 80)
            print("Computing metrics after end of epoch %d" % (epoch + 1))
            # print evaluation metrics every 5 epochs
            train_loss = optimization_function(params, X, y, C, num_words)

            y_preds = decode_crf(test_X, W, T)
            word_acc, char_acc = compute_word_char_accuracy_score(
                y_preds, test_Y)

            print(
                "Epoch %d | Train loss = %0.8f | Word Accuracy = %0.5f | Char Accuracy = %0.5f"
                % (epoch + 1, train_loss, word_acc, char_acc))
            print('*' * 80, '\n')

    # merge the two weight matrices into a single flat parameter vector
    out = np.concatenate((W.flatten(), T.flatten()))
    print("Total time: ", end='')
    print(time.time() - start)

    with open("result/" + model_name + ".txt", "w") as text_file:
        for i, elt in enumerate(out):
            text_file.write(str(elt) + "\n")


def train_crf_sgd_momentum(params, X, y, C, num_epochs, learning_rate,
                           l2_lambda, test_xy, model_name):
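    """Train the CRF with SGD plus momentum (an exponential running average of
    the per-word gradients) and early stopping once that average stops
    changing between epochs."""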
    num_words = len(X)
    test_X, test_Y = test_xy

    W = matricize_W(params)
    T = matricize_Tij(params)

    print("Starting SGD optimizaion")

    # make results reproducible
    np.random.seed(1)

    W_running_avg = np.zeros_like(W)
    T_running_avg = np.zeros_like(T)

    MOMENTUM = 0.9

    # previous-epoch snapshots of the running averages, used for the early-stopping check
    W_old_avg = np.copy(W_running_avg)
    T_old_avg = np.copy(T_running_avg)

    start = time.time()
    for epoch in range(num_epochs):
        print("Begin epoch %d" % (epoch + 1))

        indices = np.arange(0, num_words, step=1, dtype=int)
        np.random.shuffle(indices)

        for i, word_index in enumerate(indices):
            W_grad, T_grad = d_optimization_function_word(
                X, y, W, T, word_index, C, l2_lambda)

            W_running_avg = MOMENTUM * W_running_avg + (1. - MOMENTUM) * W_grad
            T_running_avg = MOMENTUM * T_running_avg + (1. - MOMENTUM) * T_grad

            # perform SGD update
            W -= learning_rate * W_running_avg
            T -= learning_rate * T_running_avg

            if i % 1000 == 0:
                params = np.concatenate((W.flatten(), T.flatten()))
                logloss = optimization_function(params, X, y, C, l2_lambda,
                                                num_words)
                print("Epoch %d Iter %d : Logloss : " % (epoch + 1, i),
                      logloss)
                print("W norm", np.linalg.norm(W))
                print("T norm", np.linalg.norm(T))

            learning_rate *= 0.999
            learning_rate = max(learning_rate, 5e-4)

            #print("W norm", np.linalg.norm(W))
            #print("T norm", np.linalg.norm(T))
            #print("lr", learning_rate)
            #print("W grad stats", np.linalg.norm(W_grad), np.min(W_grad), np.max(W_grad), np.mean(W_grad), np.std(W_grad))
            #print("T grad stats", np.linalg.norm(T_grad), np.min(T_grad), np.max(T_grad), np.mean(T_grad), np.std(T_grad))
            #print()

        print("Learning rate : ", learning_rate)

        params = np.concatenate((W.flatten(), T.flatten()))
        logloss = optimization_function(params, X, y, C, l2_lambda, num_words)
        print("Logloss : ", logloss)
        print()

        # early stopping: stop once the squared norm of the running gradient
        # average barely changes between epochs
        moving_sum = np.sum(W_running_avg**2) + np.sum(T_running_avg**2)
        old_moving_sum = np.sum(W_old_avg**2) + np.sum(T_old_avg**2)

        if abs(moving_sum - old_moving_sum) < 1e-3:
            print("Gradient difference is small enough. Early stopping..")
            break
        else:
            W_old_avg = W_running_avg
            T_old_avg = T_running_avg

        # if (epoch + 1) % 5 == 0:
        #     print('*' * 80)
        #     print("Computing metrics after end of epoch %d" % (epoch + 1))
        #     # print evaluation metrics every 1000 steps of SGD
        #     train_loss = optimization_function(params, X, y, C, l2_lambda, num_words)
        #
        #     y_preds = decode_crf(test_X, W, T)
        #     word_acc, char_acc = compute_word_char_accuracy_score(y_preds, test_Y)
        #
        #     print("Epoch %d | Train loss = %0.8f | Word Accuracy = %0.5f | Char Accuracy = %0.5f" %
        #           (epoch + 1, train_loss, word_acc, char_acc))
        #     print('*' * 80, '\n')

    print('*' * 80)
    print("Computing metrics after end of epoch")
    # print evaluation metrics every 1000 steps of SGD
    train_loss = optimization_function(params, X, y, C, l2_lambda, num_words)

    y_preds = decode_crf(test_X, W, T)
    word_acc, char_acc = compute_word_char_accuracy_score(y_preds, test_Y)

    print(
        "Epoch %d | Train loss = %0.8f | Word Accuracy = %0.5f | Char Accuracy = %0.5f"
        % (epoch + 1, train_loss, word_acc, char_acc))
    print('*' * 80, '\n')

    # merge the two weight matrices into a single flat parameter vector
    out = np.concatenate((W.flatten(), T.flatten()))
    print("Total time: ", end='')
    print(time.time() - start)

    with open("result/" + model_name + ".txt", "w") as text_file:
        for i, elt in enumerate(out):
            text_file.write(str(elt) + "\n")