def train_main(hparams):
    """
    Main training routine for the dot semantic network bot
    :return:
    """

    # -----------------------
    # INIT EXPERIMENT
    # ----------------------
    exp = Experiment(name=hparams.exp_name,
                     debug=hparams.debug,
                     description=hparams.exp_desc,
                     autosave=False,
                     save_dir=hparams.test_tube_dir)

    exp.add_argparse_meta(hparams)
    exp.save()

    # -----------------------
    # LOAD DATASET
    # ----------------------
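    # the dataset object is assumed to expose vocab_size, nb_tng and a
    # train_generator() yielding (context, utterance) batches of token ids
    # (the only attributes used in this routine)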
    udc_dataset = UDCDataset(vocab_path=hparams.vocab_path,
                             train_path=hparams.dataset_train_path,
                             test_path=hparams.dataset_test_path,
                             val_path=hparams.dataset_val_path,
                             max_seq_len=hparams.max_seq_len)

    # -----------------------
    # INIT TF VARS
    # ----------------------
    # input_x holds the chat history (context) token ids
    # input_y holds the candidate response token ids
    # no explicit labels are needed: the context and response at the same
    # batch index form the positive pair; the rest of the batch serves as
    # in-batch negatives
    input_x = tf.placeholder(dtype=tf.int32,
                             shape=[hparams.batch_size, None],
                             name='input_x')
    input_y = tf.placeholder(dtype=tf.int32,
                             shape=[hparams.batch_size, None],
                             name='input_y')

    # ----------------------
    # EMBEDDING LAYER
    # ----------------------
    # you can preload your own or learn in the network
    # in this case we'll just learn it in the network
    embedding = tf.get_variable(
        'embedding', [udc_dataset.vocab_size, hparams.embedding_dim])
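    # no explicit initializer is given, so tf.get_variable falls back to the
    # TF1 default (glorot uniform); the embedding matrix is learned jointly
    # with the rest of the network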

    # ----------------------
    # RESOLVE EMBEDDINGS
    # ----------------------
    # Look up the embeddings for each token id.
    embedding_x = tf.nn.embedding_lookup(embedding, input_x)
    embedding_y = tf.nn.embedding_lookup(embedding, input_y)

    # Generates 1 vector per training example.
    x = tf.reduce_sum(embedding_x, axis=1)
    y = tf.reduce_sum(embedding_y, axis=1)
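    # shapes: embedding_x / embedding_y are [batch_size, seq_len,
    # embedding_dim]; summing over axis 1 yields one bag-of-words vector of
    # shape [batch_size, embedding_dim] per context and per response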

    # ----------------------
    # OPTIMIZATION PROBLEM
    # ----------------------
    S = dot_product_scoring(x, y, is_training=True)
    K = tf.reduce_logsumexp(S, axis=1)
    loss = -tf.reduce_mean(tf.diag_part(S) - K)
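    # S is assumed to be the [batch_size, batch_size] matrix of scores between
    # every context and every response, so the diagonal holds the true pairs.
    # The loss is the mean softmax cross-entropy with in-batch negatives:
    #   loss = -1/B * sum_i (S[i, i] - log(sum_j exp(S[i, j])))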

    # allow the optimizer to be changed through hyperparameters
    optimizer = get_optimizer(hparams=hparams, minimize=loss)
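    # get_optimizer is assumed to return the training op (e.g. the result of
    # a tf.train Optimizer's minimize(loss)), which is why optimizer.run(...)
    # can be called directly below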

    # ----------------------
    # TF ADMIN (VAR INIT, SESS)
    # ----------------------
    sess = tf.Session()
    init_vars = tf.global_variables_initializer()
    sess.run(init_vars)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # ----------------------
    # TRAINING ROUTINE
    # ----------------------
    # admin vars
    nb_batches_served = 0
    eval_every_n_batches = hparams.eval_every_n_batches

    # placeholder metric values until the first evaluation
    train_err = 1000
    prec_at_1 = 0
    prec_at_2 = 0

    # iter for the needed epochs
    print('\n\n', '-' * 100,
          '\n  {} TRAINING\n'.format(hparams.exp_name.upper()), '-' * 100,
          '\n\n')
    for epoch in range(hparams.nb_epochs):
        print('training epoch:', epoch + 1)
        progbar = Progbar(target=udc_dataset.nb_tng, width=50)
        train_gen = udc_dataset.train_generator(batch_size=hparams.batch_size,
                                                max_epochs=1)

        # mini batches
        for batch_context, batch_utterance in train_gen:
            feed_dict = {input_x: batch_context, input_y: batch_utterance}

            # OPT: run one step of optimization
            optimizer.run(session=sess, feed_dict=feed_dict)
            # update loss metrics
            if nb_batches_served % eval_every_n_batches == 0:

                # calculate training error and precision@k on the current batch
                train_err = loss.eval(session=sess, feed_dict=feed_dict)
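                # test_precision_at_k is assumed to report the fraction of
                # rows whose true (diagonal) pair scores in the top k of S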
                prec_at_1 = test_precision_at_k(S, feed_dict, k=1, sess=sess)
                prec_at_2 = test_precision_at_k(S, feed_dict, k=2, sess=sess)

                # log the metrics to the experiment
                exp.add_metric_row({
                    'tng loss': train_err,
                    'P@1': prec_at_1,
                    'P@2': prec_at_2
                })

            nb_batches_served += 1

            progbar.add(n=len(batch_context),
                        values=[('train_err', train_err), ('P@1', prec_at_1),
                                ('P@2', prec_at_2)])

        # ----------------------
        # END OF EPOCH PROCESSING
        # ----------------------
        # calculate the val loss
        print('\nepoch complete...\n')
        check_val_stats(loss, S, udc_dataset, hparams, input_x, input_y, exp,
                        sess, epoch)

        # save model
        save_model(saver=saver, hparams=hparams, sess=sess, epoch=epoch)

        # save exp data
        exp.save()

    tf.reset_default_graph()
def train_main(hparams):
    """
    Main training routine for the dot semantic network bot
    :return:
    """

    # -----------------------
    # INIT EXPERIMENT
    # ----------------------
    exp = Experiment(name=hparams.exp_name,
                     debug=hparams.debug,
                     description=hparams.exp_desc,
                     autosave=False,
                     save_dir=hparams.test_tube_dir)

    exp.add_argparse_meta(hparams)
    exp.save()

    # -----------------------
    # LOAD DATASET
    # ----------------------
    udc_dataset = UDCDataset(vocab_path=hparams.vocab_path,
                             train_path=hparams.dataset_train_path,
                             test_path=hparams.dataset_test_path,
                             val_path=hparams.dataset_val_path,
                             max_seq_len=hparams.max_seq_len)

    # -----------------------
    # INIT TF VARS
    # ----------------------
    # context_ph holds the chat history (context) token ids
    # utterance_ph holds the candidate response token ids
    # no explicit labels placeholder is needed: the context and utterance at
    # the same batch index are assumed to form the positive pair
    context_ph = tf.placeholder(dtype=tf.int32,
                                shape=[hparams.batch_size, None],
                                name='context_seq_in')
    utterance_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[hparams.batch_size, None],
                                  name='utterance_seq_in')

    # ----------------------
    # EMBEDDING LAYER
    # ----------------------
    # you can preload your own or learn in the network
    # in this case we'll just learn it in the network
    embedding_layer = tf.Variable(tf.random_uniform(
        [udc_dataset.vocab_size, hparams.embedding_dim], -1.0, 1.0),
                                  name='embedding')

    # ----------------------
    # RESOLVE EMBEDDINGS
    # ----------------------
    # look up embeddings
    context_embedding = tf.nn.embedding_lookup(embedding_layer, context_ph)
    utterance_embedding = tf.nn.embedding_lookup(embedding_layer, utterance_ph)

    # average all the token embeddings (a plain sum may work slightly better);
    # this generates one vector per training example
    context_embedding_summed = tf.reduce_mean(context_embedding, axis=1)
    utterance_embedding_summed = tf.reduce_mean(utterance_embedding, axis=1)
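    # shapes: the lookups are [batch_size, max_seq_len, embedding_dim];
    # averaging over axis 1 gives one [batch_size, embedding_dim] vector per
    # context and per utterance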

    # ----------------------
    # OPTIMIZATION PROBLEM
    # ----------------------
    model, _, _, pred_opt = dot_semantic_nn(
        context=context_embedding_summed,
        utterance=utterance_embedding_summed,
        tng_mode=hparams.train_mode)
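    # dot_semantic_nn is assumed to return (loss, ..., ..., scores): `model`
    # is the scalar training loss minimized below and `pred_opt` holds the
    # pairwise context/utterance scores used by the precision@k checks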

    # allow the optimizer to be changed through hyperparameters
    optimizer = get_optimizer(hparams=hparams, minimize=model)

    # ----------------------
    # TF ADMIN (VAR INIT, SESS)
    # ----------------------
    sess = tf.Session()
    init_vars = tf.global_variables_initializer()
    sess.run(init_vars)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # ----------------------
    # TRAINING ROUTINE
    # ----------------------
    # admin vars
    nb_batches_served = 0
    eval_every_n_batches = hparams.eval_every_n_batches

    # placeholder metric values until the first evaluation
    train_err = 1000
    precision_at_1 = 0
    precision_at_2 = 0

    # iter for the needed epochs
    print('\n\n', '-' * 100,
          '\n  {} TRAINING\n'.format(hparams.exp_name.upper()), '-' * 100,
          '\n\n')
    for epoch in range(hparams.nb_epochs):
        print('training epoch:', epoch + 1)
        progbar = Progbar(target=udc_dataset.nb_tng, width=50)
        train_gen = udc_dataset.train_generator(batch_size=hparams.batch_size,
                                                max_epochs=1)

        # mini batches
        for batch_context, batch_utterance in train_gen:

            feed_dict = {
                context_ph: batch_context,
                utterance_ph: batch_utterance
            }

            # OPT: run one step of optimization
            optimizer.run(session=sess, feed_dict=feed_dict)
            # update loss metrics
            if nb_batches_served % eval_every_n_batches == 0:

                # calculate training error and precision@k on the current batch
                train_err = model.eval(session=sess, feed_dict=feed_dict)
                precision_at_1 = test_precision_at_k(pred_opt,
                                                     feed_dict,
                                                     k=1,
                                                     sess=sess)
                precision_at_2 = test_precision_at_k(pred_opt,
                                                     feed_dict,
                                                     k=2,
                                                     sess=sess)

                # log the metrics to the experiment
                exp.add_metric_row({
                    'tng loss': train_err,
                    'P@1': precision_at_1,
                    'P@2': precision_at_2
                })

            nb_batches_served += 1

            progbar.add(n=len(batch_context),
                        values=[('train_err', train_err),
                                ('P@1', precision_at_1),
                                ('P@2', precision_at_2)])

        # ----------------------
        # END OF EPOCH PROCESSING
        # ----------------------
        # calculate the val loss
        print('\nepoch complete...\n')
        check_val_stats(model, pred_opt, udc_dataset, hparams, context_ph,
                        utterance_ph, exp, sess, epoch)

        # save model
        save_model(saver=saver, hparams=hparams, sess=sess, epoch=epoch)

        # save exp data
        exp.save()
def train_main(hparams):
    """
    Main training routine for the dot semantic network bot
    :return:
    """

    # -----------------------
    # INIT EXPERIMENT
    # ----------------------
    exp = Experiment(name=hparams.exp_name,
                     debug=hparams.debug,
                     description=hparams.exp_desc,
                     autosave=False,
                     save_dir=hparams.test_tube_dir)

    exp.add_meta_tags(vars(hparams))

    # -----------------------
    # LOAD DATASET
    # ----------------------
    udc_dataset = UDCDataset(vocab_path=hparams.vocab_path,
                             train_path=hparams.dataset_train_path,
                             test_path=hparams.dataset_test_path,
                             val_path=hparams.dataset_val_path,
                             max_seq_len=hparams.max_seq_len)

    # -----------------------
    # INIT TF VARS
    # ----------------------
    # context_ph holds the raw chat history (context) strings
    # utterance_ph holds the raw candidate response strings
    # no explicit labels placeholder is needed: the context and utterance at
    # the same batch index are assumed to form the positive pair
    context_ph = tf.placeholder(dtype=tf.string,
                                shape=[hparams.batch_size],
                                name='context_seq_in')
    utterance_ph = tf.placeholder(dtype=tf.string,
                                  shape=[hparams.batch_size],
                                  name='utterance_seq_in')
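    # raw strings (not int token ids) are fed in this version because the
    # ELMo-style MyLayer below is assumed to tokenize and embed the text
    # internally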

    # ----------------------
    # EMBEDDING LAYER
    # ----------------------
    # you can preload your own embeddings or learn them in the network;
    # in this version the learned embedding table is left commented out and
    # ELMo-style embeddings (MyLayer, below) are used instead
    # embedding_layer = tf.Variable(tf.random_uniform([udc_dataset.vocab_size, hparams.embedding_dim], -1.0, 1.0), name='embedding')

    # elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)
    sess = tf.Session()

    K.set_session(sess)
    # initialize whatever variables/tables exist so far; the initializers are
    # run again below once the full graph has been built
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    # print('elmo')
    # context = list(udc_dataset['Context'])
    # elmo_text = elmo(context, signature="default", as_dict=True)
    # input_text = Input(shape=(100,), tensor= ,dtype="string")
    #custom_layer = MyLayer(output_dim=1024, trainable=True)(tf.convert_to_tensor(x, dtype='string'))
    # embedding = Lambda(ELMoEmbedding, output_shape=(1024, ))(input_text)

    # elmo_text = elmo(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["default"]
    #embedding_layer = Dense(256, activation='relu', kernel_regularizer=keras.regularizers.l2(0.001))(custom_layer)

    print('building embedding layers')
    # ----------------------
    # RESOLVE EMBEDDINGS
    # ----------------------
    # look up embeddings
    context_embedding_custom0 = MyLayer(output_dim=1024,
                                        trainable=True)(tf.slice(
                                            context_ph, [0], [1]))
    utterance_embedding_custom0 = MyLayer(output_dim=1024,
                                          trainable=True)(tf.slice(
                                              utterance_ph, [0], [1]))
    print('context embedding static shape:', context_embedding_custom0.shape)

    for batch_num in range(1, hparams.batch_size):

        context_embedding_custom = MyLayer(output_dim=1024,
                                           trainable=True)(tf.slice(
                                               context_ph, [batch_num], [1]))
        utterance_embedding_custom = MyLayer(output_dim=1024, trainable=True)(
            tf.slice(utterance_ph, [batch_num], [1]))
        context_embedding_custom0 = tf.concat(
            [context_embedding_custom0, context_embedding_custom], axis=0)
        utterance_embedding_custom0 = tf.concat(
            [utterance_embedding_custom0, utterance_embedding_custom], axis=0)
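    # note: each loop iteration builds a separate MyLayer instance (with its
    # own weights) over a single-row slice of the batch; the per-example
    # 1024-dim outputs are concatenated back into a [batch_size, 1024] tensor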

    print('concatenated context embedding static shape:',
          context_embedding_custom0.shape)

    #context_embedding_summed = tf.reduce_mean(context_embedding_custom0, axis=1)
    #utterance_embedding_summed = tf.reduce_mean(utterance_embedding_custom0, axis=1)
    #print('summed')
    #print(tf.shape(context_embedding_summed))

    #context_embedding = Dense(hparams.embedding_dim, activation='relu',
    #kernel_regularizer=keras.regularizers.l2(0.001))(
    #context_embedding_custom0)
    #utterance_embedding = Dense(hparams.embedding_dim, activation='relu',
    #kernel_regularizer=keras.regularizers.l2(0.001))(
    #utterance_embedding_custom0)
    #print('embedding')
    #print(tf.shape(context_embedding))

    # avg all embeddings (sum works better?)
    # this generates 1 vector per training example
    #context_embedding_summed = tf.reduce_mean(context_embedding, axis=1)
    #utterance_embedding_summed = tf.reduce_mean(utterance_embedding, axis=1)

    # ----------------------
    # OPTIMIZATION PROBLEM
    # ----------------------
    model, _, _, pred_opt = dot_semantic_nn(
        context=context_embedding_custom0,
        utterance=utterance_embedding_custom0,
        tng_mode=hparams.train_mode)

    # allow the optimizer to be changed through hyperparameters
    optimizer = get_optimizer(hparams=hparams, minimize=model)

    # ----------------------
    # TF ADMIN (VAR INIT, SESS)
    # ----------------------
    # reuse the session created above; run the initializers again now that
    # the embedding layers and optimizer variables are part of the graph
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # ----------------------
    # TRAINING ROUTINE
    # ----------------------
    # admin vars
    nb_batches_served = 0
    eval_every_n_batches = hparams.eval_every_n_batches

    # placeholder metric values until the first evaluation
    train_err = 1000
    precision_at_1 = 0
    precision_at_2 = 0

    # iter for the needed epochs
    print('\n\n', '-' * 100,
          '\n  {} TRAINING\n'.format(hparams.exp_name.upper()), '-' * 100,
          '\n\n')
    for epoch in range(hparams.nb_epochs):
        print('training epoch:', epoch + 1)
        progbar = Progbar(target=udc_dataset.nb_tng, width=50)
        train_gen = udc_dataset.train_generator(batch_size=hparams.batch_size,
                                                max_epochs=1)

        # mini batches
        for batch_context, batch_utterance in train_gen:

            feed_dict = {
                context_ph: batch_context,
                utterance_ph: batch_utterance
            }
            print("optimizer!")
            # OPT: run one step of optimization
            optimizer.run(session=sess, feed_dict=feed_dict)
            # update loss metrics
            if nb_batches_served % eval_every_n_batches == 0:
                # calculate training error and precision@k on the current batch
                train_err = model.eval(session=sess, feed_dict=feed_dict)
                precision_at_1 = test_precision_at_k(pred_opt,
                                                     feed_dict,
                                                     k=1,
                                                     sess=sess)
                precision_at_2 = test_precision_at_k(pred_opt,
                                                     feed_dict,
                                                     k=2,
                                                     sess=sess)

                # log the metrics to the experiment
                exp.add_metric_row({
                    'tng loss': train_err,
                    'P@1': precision_at_1,
                    'P@2': precision_at_2
                })

            nb_batches_served += 1

            progbar.add(n=len(batch_context),
                        values=[('train_err', train_err),
                                ('P@1', precision_at_1),
                                ('P@2', precision_at_2)])

        # ----------------------
        # END OF EPOCH PROCESSING
        # ----------------------
        # calculate the val loss
        print('\nepoch complete...\n')
        check_val_stats(model, pred_opt, udc_dataset, hparams, context_ph,
                        utterance_ph, exp, sess, epoch)

        # save model
        save_model(saver=saver, hparams=hparams, sess=sess, epoch=epoch)
def train_main(hparams):
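    """
    Main training routine for the dot semantic network bot.
    """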

    exp = Experiment(name=hparams.exp_name,
                     debug=hparams.debug,
                     description=hparams.exp_desc,
                     autosave=False,
                     save_dir=hparams.test_tube_dir)

    exp.add_argparse_meta(hparams)
    exp.save()

    udc_dataset = UDCDataset(vocab_path=hparams.vocab_path,
                             train_path=hparams.dataset_train_path,
                             test_path=hparams.dataset_test_path,
                             val_path=hparams.dataset_val_path,
                             max_seq_len=hparams.max_seq_len)

    context_ph = tf.placeholder(dtype=tf.int32,
                                shape=[hparams.batch_size, None],
                                name='context_seq_in')
    utterance_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[hparams.batch_size, None],
                                  name='utterance_seq_in')

    embedding_layer = tf.Variable(tf.random_uniform(
        [udc_dataset.vocab_size, hparams.embedding_dim], -1.0, 1.0),
                                  name='embedding')

    context_embedding = tf.nn.embedding_lookup(embedding_layer, context_ph)
    utterance_embedding = tf.nn.embedding_lookup(embedding_layer, utterance_ph)

    context_embedding_summed = tf.reduce_mean(context_embedding, axis=1)
    utterance_embedding_summed = tf.reduce_mean(utterance_embedding, axis=1)

    model, _, _, pred_opt = dot_semantic_nn(
        context=context_embedding_summed,
        utterance=utterance_embedding_summed,
        tng_mode=hparams.train_mode)

    optimizer = get_optimizer(hparams=hparams, minimize=model)

    sess = tf.Session()
    init_vars = tf.global_variables_initializer()
    sess.run(init_vars)

    saver = tf.train.Saver()

    nb_batches_served = 0
    eval_every_n_batches = hparams.eval_every_n_batches

    # placeholder metric values until the first evaluation
    train_err = 1000
    precision_at_1 = 0
    precision_at_2 = 0

    # iter for the needed epochs
    print('\n\n', '-' * 100,
          '\n  {} TRAINING\n'.format(hparams.exp_name.upper()), '-' * 100,
          '\n\n')
    for epoch in range(hparams.nb_epochs):
        print('training epoch:', epoch + 1)
        progbar = Progbar(target=udc_dataset.nb_tng, width=50)
        train_gen = udc_dataset.train_generator(batch_size=hparams.batch_size,
                                                max_epochs=1)

        for batch_context, batch_utterance in train_gen:

            feed_dict = {
                context_ph: batch_context,
                utterance_ph: batch_utterance
            }

            optimizer.run(session=sess, feed_dict=feed_dict)

            if nb_batches_served % eval_every_n_batches == 0:

                train_err = model.eval(session=sess, feed_dict=feed_dict)
                precision_at_1 = test_precision_at_k(pred_opt,
                                                     feed_dict,
                                                     k=1,
                                                     sess=sess)
                precision_at_2 = test_precision_at_k(pred_opt,
                                                     feed_dict,
                                                     k=2,
                                                     sess=sess)

                exp.add_metric_row({
                    'tng loss': train_err,
                    'P@1': precision_at_1,
                    'P@2': precision_at_2
                })

            nb_batches_served += 1

            progbar.add(n=len(batch_context),
                        values=[('train_err', train_err),
                                ('P@1', precision_at_1),
                                ('P@2', precision_at_2)])

        print('\nepoch complete...\n')
        check_val_stats(model, pred_opt, udc_dataset, hparams, context_ph,
                        utterance_ph, exp, sess, epoch)

        save_model(saver=saver, hparams=hparams, sess=sess, epoch=epoch)

        exp.save()