def map_at_k(predictions, actuals, k):
    """
    Computes the MAP at k
    :param predictions: array, predicted values
    :param actuals: array, actual values
    :param k: int, value to compute the metric at
    :returns MAP: float, the score at k
    """
    avg_prec = []
    for i in range(1, k + 1):
        prec = precision_at_k(predictions, actuals, i)
        avg_prec.append(prec)
    return np.mean(avg_prec)


def main():
    n_epochs = FLAGS.epochs
    n_hidden = FLAGS.size
    NEG_COUNT = FLAGS.neg_count
    CORRUPT_RATIO = FLAGS.corrupt

    # Load the Chicago RSVP event data and user-group memberships
    event_data = ds.EventData(ds.rsvp_chicago_file, ug_dataset.user_group_chicago_file)
    users = event_data.get_train_users()

    n_inputs = event_data.n_events
    n_users = event_data.n_users
    n_outputs = event_data.n_events

    # CDAE: a denoising autoencoder that adds a learned per-user embedding
    # to the hidden layer, hence the extra n_users/user_id inputs
    model = CDAEAutoEncoder(n_inputs, n_hidden, n_outputs, n_users,
                            FLAGS.hidden_fn, FLAGS.output_fn,
                            learning_rate=0.001)
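    # Cap this process at 25% of the GPU's memory (allocated lazily via
    # allow_growth) so the device can be shared with other jobs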
    tf_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.25,
                                  allow_growth=True))
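    # Supervisor manages checkpoints and summaries under FLAGS.save_dir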
    sv = tf.train.Supervisor(logdir=FLAGS.save_dir)
    set_logging_config(FLAGS.save_dir)

    with sv.prepare_or_wait_for_session(config=tf_config) as sess:
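        # Creates the session, initializing variables from scratch or
        # restoring from the latest checkpoint in save_dir if one exists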
        prev_epoch_loss = 0.0
        for epoch in range(n_epochs):
            # Inputs are re-corrupted each epoch: additive Gaussian noise or
            # multiplicative mask-out/drop-out noise, controlled by CORRUPT_RATIO
            epoch_loss = 0.0
            users = shuffle(users)

            tf.logging.info("Training the model...")
            for user_id in users:
                x, y, item = event_data.get_user_train_events(
                                                    user_id, NEG_COUNT, CORRUPT_RATIO)
                train_event_index = event_data.get_user_train_event_index(user_id)
                user_index = event_data.get_user_index(user_id)

                # Loss is only computed on the events used as inputs; each
                # (row, item) pair indexes one entry of the output matrix
                gather_indices = list(zip(range(len(y)), item))
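                # e.g. item = [3, 17, 42] -> [(0, 3), (1, 17), (2, 42)]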

                # Run one training step on this user's events
                batch_loss, _ = sess.run([model.loss, model.train], {
                    model.x: x.toarray().astype(np.float32),
                    model.gather_indices: gather_indices,
                    model.user_id: user_index,
                    model.y: y,
                    model.dropout: 0.8  # keep probability; 1.0 disables dropout
                })
               
                epoch_loss += batch_loss
               
            tf.logging.info("Epoch: {:>16}       Loss: {:>10,.6f}".format("%s/%s" % (epoch, n_epochs),
                                                                epoch_loss))
            tf.logging.info("")
            if prev_epoch_loss != 0 and abs(epoch_loss - prev_epoch_loss) < 1:
                tf.logging.info("Decaying learning rate...")
                model.decay_learning_rate(sess, 0.5)
            prev_epoch_loss = epoch_loss

        # Evaluate the model on the CV set
        cv_users = event_data.get_cv_users()
        precision = []
        recall = []
        mean_avg_prec = []
        ndcg = []
        eval_at = [5, 10]
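        # One row per evaluated user; columns line up with the cutoffs in eval_at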

        tf.logging.info("Evaluating on the CV set...")
        valid_test_users = 0
        train_users = set(event_data.get_train_users())
        for user_id in cv_users:
            # Only score users that were present in the training data
            if user_id in train_users:
                valid_test_users += 1
                test_event_index = event_data.get_user_cv_event_index(user_id)

                x, _, _ = event_data.get_user_train_events(user_id, 0, 0)
                user_index = event_data.get_user_index(user_id)
                # Compute score
                score = sess.run(model.outputs, {
                    model.x: x.toarray().astype(np.float32),
                    model.user_id: user_index,
                    model.dropout: 1.0
                })[0]  # We only do one sample at a time, take 0 index

                # argsort is ascending, so the highest-scoring events sit at
                # the end; the *_at_k helpers read the top k from the tail
                index = np.argsort(score)

                # Per-cutoff metrics for this user
                preck = []
                recallk = []
                mapk = []
                ndcgk = []
                for k in eval_at:
                    preck.append(precision_at_k(index, test_event_index, k))
                    recallk.append(recall_at_k(index, test_event_index, k))
                    mapk.append(map_at_k(index, test_event_index, k))
                    ndcgk.append(ndcg_at_k(index, test_event_index, k))

                precision.append(preck)
                recall.append(recallk)
                mean_avg_prec.append(mapk)
                ndcg.append(ndcgk)

        if valid_test_users > 0:
            # Unpack the lists zip(*[[1,2], [3, 4]]) => [1,3], [2,4]
            avg_precision_5, avg_precision_10 = zip(*precision)
            avg_precision_5, avg_precision_10 = np.mean(avg_precision_5), np.mean(avg_precision_10)

            avg_recall_5, avg_recall_10 = zip(*recall)
            avg_recall_5, avg_recall_10 = np.mean(avg_recall_5), np.mean(avg_recall_10)

            avg_map_5, avg_map_10 = zip(*mean_avg_prec)
            avg_map_5, avg_map_10 = np.mean(avg_map_5), np.mean(avg_map_10)

            avg_ndcg_5, avg_ndcg_10 = zip(*ndcg)
            avg_ndcg_5, avg_ndcg_10 = np.mean(avg_ndcg_5), np.mean(avg_ndcg_10)

            tf.logging.info(f"Precision@5: {avg_precision_5:>10.6f}       Precision@10: {avg_precision_10:>10.6f}")
            tf.logging.info(f"Recall@5:    {avg_recall_5:>10.6f}       Recall@10:    {avg_recall_10:>10.6f}")
            tf.logging.info(f"MAP@5:       {avg_map_5:>10.6f}       MAP@10:       {avg_map_10:>10.6f}")
            tf.logging.info(f"NDCG@5:      {avg_ndcg_5:>10.6f}       NDCG@10:      {avg_ndcg_10:>10.6f}")
            tf.logging.info("")
        
        # Evaluate on the test set
        tf.logging.info("Evaluating on the test set...")
        valid_test_users = 0
        precision = []
        recall = []
        mean_avg_prec = []
        ndcg = []
        eval_at = [5, 10]
        train_users = set(event_data.get_train_users())
        for user_id in event_data.get_test_users():
            # Only score users that were present in the training data
            if user_id in train_users:
                valid_test_users += 1
                test_event_index = event_data.get_user_test_event_index(user_id)

                x, _, _ = event_data.get_user_train_events(user_id, 0, 0)
                user_index = event_data.get_user_index(user_id)
                # Compute score
                score = sess.run(model.outputs, {
                    model.x: x.toarray().astype(np.float32),
                    model.user_id: user_index,
                    model.dropout: 1.0
                })[0]  # We only do one sample at a time, take 0 index

                # argsort is ascending, so the highest-scoring events sit at
                # the end; the *_at_k helpers read the top k from the tail
                index = np.argsort(score)

                # Per-cutoff metrics for this user
                preck = []
                recallk = []
                mapk = []
                ndcgk = []
                for k in eval_at:
                    preck.append(precision_at_k(index, test_event_index, k))
                    recallk.append(recall_at_k(index, test_event_index, k))
                    mapk.append(map_at_k(index, test_event_index, k))
                    ndcgk.append(ndcg_at_k(index, test_event_index, k))

                precision.append(preck)
                recall.append(recallk)
                mean_avg_prec.append(mapk)
                ndcg.append(ndcgk)

        if valid_test_users > 0:
            # Unpack the lists zip(*[[1,2], [3, 4]]) => [1,3], [2,4]
            avg_precision_5, avg_precision_10 = zip(*precision)
            avg_precision_5, avg_precision_10 = np.mean(avg_precision_5), np.mean(avg_precision_10)

            avg_recall_5, avg_recall_10 = zip(*recall)
            avg_recall_5, avg_recall_10 = np.mean(avg_recall_5), np.mean(avg_recall_10)

            avg_map_5, avg_map_10 = zip(*mean_avg_prec)
            avg_map_5, avg_map_10 = np.mean(avg_map_5), np.mean(avg_map_10)

            avg_ndcg_5, avg_ndcg_10 = zip(*ndcg)
            avg_ndcg_5, avg_ndcg_10 = np.mean(avg_ndcg_5), np.mean(avg_ndcg_10)

            tf.logging.info(f"Precision@5: {avg_precision_5:>10.6f}       Precision@10: {avg_precision_10:>10.6f}")
            tf.logging.info(f"Recall@5:    {avg_recall_5:>10.6f}       Recall@10:    {avg_recall_10:>10.6f}")
            tf.logging.info(f"MAP@5:       {avg_map_5:>10.6f}       MAP@10:       {avg_map_10:>10.6f}")
            tf.logging.info(f"NDCG@5:      {avg_ndcg_5:>10.6f}       NDCG@10:      {avg_ndcg_10:>10.6f}")
            tf.logging.info("")
    sv.request_stop()