Example #1
def train_cnn():
    """Training CNN model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.embedding_dim)

    logger.info("✔︎ Validation data processing...")
    validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.embedding_dim)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train_front, x_train_behind, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_validation_front, x_validation_behind, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and cnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn = TextCNN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=y_train.shape[1],
                vocab_size=VOCAB_SIZE,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
                num_filters=FLAGS.num_filters,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=cnn.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(cnn.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=cnn.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should be like(1490175368): ")  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, please re-input: ")
                logger.info("✔︎ The format of your input is legal, now loading to next step...")
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary, acc_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load cnn model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(cnn.global_step)
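            # current_step is 0 for a fresh run and the restored value when resuming from a checkpoint.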

            def train_step(x_batch_front, x_batch_behind, y_batch):
                """A single training step"""
                feed_dict = {
                    cnn.input_x_front: x_batch_front,
                    cnn.input_x_behind: x_batch_behind,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    cnn.is_training: True
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, cnn.global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}".format(step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_batch_front, x_batch_behind, y_batch, writer=None):
                """Evaluates model on a validation set"""
                feed_dict = {
                    cnn.input_x_front: x_batch_front,
                    cnn.input_x_behind: x_batch_behind,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0,
                    cnn.is_training: False
                }
                step, summaries, loss, accuracy, recall, precision, f1, auc = sess.run(
                    [cnn.global_step, validation_summary_op, cnn.loss, cnn.accuracy,
                     cnn.recall, cnn.precision, cnn.F1, cnn.AUC], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}, recall {3:g}, precision {4:g}, f1 {5:g}, AUC {6}"
                            .format(step, loss, accuracy, recall, precision, f1, auc))
                if writer:
                    writer.add_summary(summaries, step)

                return accuracy

            # Generate batches
            batches = dh.batch_iter(
                list(zip(x_train_front, x_train_behind, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train_front) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch in batches:
                x_batch_front, x_batch_behind, y_batch = zip(*batch)
                train_step(x_batch_front, x_batch_behind, y_batch)
                current_step = tf.train.global_step(sess, cnn.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    accuracy = validation_step(x_validation_front, x_validation_behind, y_validation,
                                               writer=validation_summary_writer)
                    best_saver.handle(accuracy, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #2
def train_han():
    """Training HAN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.num_classes, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.num_classes, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                lstm_hidden_size=args.lstm_dim,
                fc_hidden_size=args.fc_dim,
                num_classes=args.num_classes,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=han.global_step, decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if OPTION == 'R':
                # Load han model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: args.dropout_rate,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), args.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0

                eval_pre_tk = [0.0] * args.topK
                eval_rec_tk = [0.0] * args.topK
                eval_F1_tk = [0.0] * args.topK

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(args.topK)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_val,
                        han.input_y: y_batch_val,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch_val:
                        true_onehot_labels.append(i)
                    for j in scores:
                        predicted_onehot_scores.append(j)

                    # Predict by threshold
                    batch_predicted_onehot_labels_ts = \
                        dh.get_onehot_label_threshold(scores=scores, threshold=args.threshold)

                    for k in batch_predicted_onehot_labels_ts:
                        predicted_onehot_labels_ts.append(k)

                    # Predict by topK
                    for top_num in range(args.topK):
                        batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores, top_num=top_num+1)

                        for i in batch_predicted_onehot_labels_tk:
                            predicted_onehot_labels_tk[top_num].append(i)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1
                eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels),
                                              y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
                                           y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
                                      y_pred=np.array(predicted_onehot_labels_ts), average='micro')

                for top_num in range(args.topK):
                    eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
                                                           y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                           average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
                                                        y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                        average='micro')
                    eval_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
                                                   y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                         y_score=np.array(predicted_onehot_scores), average='micro')
                # Calculate the average PR
                eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
                                                   y_score=np.array(predicted_onehot_scores), average='micro')

                return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
                       eval_pre_tk, eval_rec_tk, eval_F1_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), args.batch_size, args.epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / args.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                    eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
                        validation_step(x_val, y_val, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                .format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                                .format(eval_pre_ts, eval_rec_ts, eval_F1_ts))

                    # Predict by topK
                    logger.info("Predict by topK:")
                    for top_num in range(args.topK):
                        logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
                                    .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F1_tk[top_num]))
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
Example #3
def train_rcnn():
    """Training RCNN model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes,
                                         FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("✔︎ Validation data processing...")
    val_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes,
                                       FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_val, y_val = dh.pad_data(val_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and rcnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            rcnn = TextRCNN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=FLAGS.num_classes,
                vocab_size=VOCAB_SIZE,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
                num_filters=FLAGS.num_filters,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=rcnn.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(rcnn.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=rcnn.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should be like(1490175368): ")  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, please re-input: ")
                logger.info("✔︎ The format of your input is legal, now loading to next step...")
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", rcnn.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load rcnn model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(rcnn.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    rcnn.input_x: x_batch,
                    rcnn.input_y: y_batch,
                    rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    rcnn.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, rcnn.global_step, train_summary_op, rcnn.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0

                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        rcnn.input_x: x_batch_val,
                        rcnn.input_y: y_batch_val,
                        rcnn.dropout_keep_prob: 1.0,
                        rcnn.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [rcnn.global_step, validation_summary_op, rcnn.scores, rcnn.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch_val:
                        true_onehot_labels.append(i)
                    for j in scores:
                        predicted_onehot_scores.append(j)

                    # Predict by threshold
                    batch_predicted_onehot_labels_ts = \
                        dh.get_onehot_label_threshold(scores=scores, threshold=FLAGS.threshold)

                    for k in batch_predicted_onehot_labels_ts:
                        predicted_onehot_labels_ts.append(k)

                    # Predict by topK
                    for top_num in range(FLAGS.top_num):
                        batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores, top_num=top_num+1)

                        for i in batch_predicted_onehot_labels_tk:
                            predicted_onehot_labels_tk[top_num].append(i)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1 (threshold & topK)
                eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels),
                                              y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
                                           y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_F_ts = f1_score(y_true=np.array(true_onehot_labels),
                                     y_pred=np.array(predicted_onehot_labels_ts), average='micro')

                for top_num in range(FLAGS.top_num):
                    eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
                                                           y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                           average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
                                                        y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                        average='micro')
                    eval_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
                                                  y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                  average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                         y_score=np.array(predicted_onehot_scores), average='micro')
                # Calculate the average PR
                eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
                                                   y_score=np.array(predicted_onehot_scores), average='micro')

                return eval_loss, eval_auc, eval_prc, eval_rec_ts, eval_pre_ts, eval_F_ts, \
                       eval_rec_tk, eval_pre_tk, eval_F_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, rcnn.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                    eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
                        validation_step(x_val, y_val, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                .format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F {2:g}"
                                .format(eval_pre_ts, eval_rec_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
                                    .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F_tk[top_num]))
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #4
def train_sann():
    """Training RNN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load word2vec model
    word2idx, embedding_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args, args.train_file, word2idx)
    val_data = dh.load_data_and_labels(args, args.validation_file, word2idx)

    # Build a graph and sann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            sann = TextSANN(sequence_length=args.pad_seq_len,
                            vocab_size=len(word2idx),
                            embedding_type=args.embedding_type,
                            embedding_size=args.embedding_dim,
                            lstm_hidden_size=args.lstm_dim,
                            attention_unit_size=args.attention_dim,
                            attention_hops_size=args.attention_hops_dim,
                            fc_hidden_size=args.fc_dim,
                            num_classes=args.num_classes,
                            l2_reg_lambda=args.l2_lambda,
                            pretrained_embedding=embedding_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=args.learning_rate,
                    global_step=sann.global_step,
                    decay_steps=args.decay_steps,
                    decay_rate=args.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(sann.loss))
                grads, _ = tf.clip_by_global_norm(grads,
                                                  clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, vars),
                    global_step=sann.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", sann.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir,
                                                num_to_keep=3,
                                                maximize=True)

            if OPTION == 'R':
                # Load sann model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(sann.global_step)

            def train_step(batch_data):
                """A single training step."""
                x_f, x_b, y_onehot = zip(*batch_data)

                feed_dict = {
                    sann.input_x_front: x_f,
                    sann.input_x_behind: x_b,
                    sann.input_y: y_onehot,
                    sann.dropout_keep_prob: args.dropout_rate,
                    sann.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, sann.global_step, train_summary_op, sann.loss],
                    feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(val_loader, writer=None):
                """Evaluates model on a validation set."""
                batches_validation = dh.batch_iter(
                    list(create_input_data(val_loader)), args.batch_size, 1)

                eval_counter, eval_loss = 0, 0.0
                true_labels = []
                predicted_scores = []
                predicted_labels = []

                for batch_validation in batches_validation:
                    x_f, x_b, y_onehot = zip(*batch_validation)
                    feed_dict = {
                        sann.input_x_front: x_f,
                        sann.input_x_behind: x_b,
                        sann.input_y: y_onehot,
                        sann.dropout_keep_prob: 1.0,
                        sann.is_training: False
                    }
                    step, summaries, predictions, cur_loss = sess.run([
                        sann.global_step, validation_summary_op,
                        sann.topKPreds, sann.loss
                    ], feed_dict)

                    # Prepare for calculating metrics
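                    # topKPreds is assumed to be a (top-k scores, top-k indices) pair,
                    # so element [0] of each entry is the top-1 score / predicted label.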
                    for i in y_onehot:
                        true_labels.append(np.argmax(i))
                    for j in predictions[0]:
                        predicted_scores.append(j[0])
                    for k in predictions[1]:
                        predicted_labels.append(k[0])

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1
                eval_acc = accuracy_score(y_true=np.array(true_labels),
                                          y_pred=np.array(predicted_labels))
                eval_pre = precision_score(y_true=np.array(true_labels),
                                           y_pred=np.array(predicted_labels),
                                           average='micro')
                eval_rec = recall_score(y_true=np.array(true_labels),
                                        y_pred=np.array(predicted_labels),
                                        average='micro')
                eval_F1 = f1_score(y_true=np.array(true_labels),
                                   y_pred=np.array(predicted_labels),
                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_labels),
                                         y_score=np.array(predicted_scores),
                                         average='micro')

                return eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc

            # Generate batches
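            # create_input_data yields (front, behind, one-hot label) triples,
            # matching how batches are unpacked in train_step / validation_step.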
            batches_train = dh.batch_iter(list(create_input_data(train_data)),
                                          args.batch_size, args.epochs)
            num_batches_per_epoch = int(
                (len(train_data['f_pad_seqs']) - 1) / args.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                train_step(batch_train)
                current_step = tf.train.global_step(sess, sann.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc = \
                        validation_step(val_data, writer=validation_summary_writer)
                    logger.info(
                        "All Validation set: Loss {0:g} | Acc {1:g} | Precision {2:g} | "
                        "Recall {3:g} | F1 {4:g} | AUC {5:g}".format(
                            eval_loss, eval_acc, eval_pre, eval_rec, eval_F1,
                            eval_auc))
                    best_saver.handle(eval_acc, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info(
                        "Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
Example #5
def train():
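    """Training TCN model."""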
    train_students, train_max_num_problems, train_max_skill_num = dh.read_data_from_csv_file(
        FLAGS.train_data_path)
    test_students, test_max_num_problems, test_max_skill_num = dh.read_data_from_csv_file(
        FLAGS.valid_data_path)
    max_num_steps = max(train_max_num_problems, test_max_num_problems)
    max_num_skills = max(train_max_skill_num, test_max_skill_num)
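    # Build the `same` / `differ` structures from the information data file once;
    # they are fed unchanged into every training and validation step below.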
    file_name = FLAGS.information_data_path
    same, differ = eb.embedding(file_name)

    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            tcn = TCN(batch_size=FLAGS.batch_size,
                      num_steps=max_num_steps,
                      num_skills=max_num_skills)

            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=FLAGS.learning_rate,
                    global_step=tcn.global_step,
                    decay_steps=(len(train_students) // FLAGS.batch_size + 1) *
                    FLAGS.decay_steps,
                    decay_rate=FLAGS.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                train_op = optimizer.minimize(tcn.loss,
                                              global_step=tcn.global_step,
                                              name="train_op")

            if FLAGS.train_or_restore == 'R':
                MODEL = input(
                    "Please input the checkpoints model you want to restore, "
                    "it should be like (1490175368): ")

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input(
                        "The format of your input is invalid, please re-input: "
                    )
                logger.info(
                    "The format of your input is valid, now moving to the next step..."
                )
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            loss_summary = tf.summary.scalar("loss", tcn.loss)

            train_summary_op = tf.summary.merge([loss_summary])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir,
                                                num_to_keep=3,
                                                maximize=True)

            if FLAGS.train_or_restore == 'R':
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
            current_step = sess.run(tcn.global_step)

            def train_step(x, next_id, target_id, target_correctness, same,
                           differ):
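                """A single training step; returns the batch predictions and loss."""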
                feed_dict = {
                    tcn.input_data: x,
                    tcn.next_id: next_id,
                    tcn.target_id: target_id,
                    tcn.target_correctness: target_correctness,
                    tcn.dropout_keep_prob: FLAGS.keep_prob,
                    tcn.is_training: True,
                    tcn.same: same,
                    tcn.differ: differ
                }
                _, step, summaries, pred, loss = sess.run([
                    train_op, tcn.global_step, train_summary_op, tcn.pred,
                    tcn.loss
                ], feed_dict)
                train_summary_writer.add_summary(summaries, step)
                return pred, loss

            def validation_step(x, next_id, target_id, target_correctness,
                                same, differ):
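                """Evaluates the model on one validation batch; returns predictions."""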
                feed_dict = {
                    tcn.input_data: x,
                    tcn.next_id: next_id,
                    tcn.target_id: target_id,
                    tcn.target_correctness: target_correctness,
                    tcn.dropout_keep_prob: 1.0,
                    tcn.is_training: False,
                    tcn.same: same,
                    tcn.differ: differ
                }
                step, summaries, pred, loss = sess.run([
                    tcn.global_step, validation_summary_op, tcn.pred, tcn.loss
                ], feed_dict)
                validation_summary_writer.add_summary(summaries, step)
                return pred

            run_time = []
            m_rmse = 1
            m_r2 = 0
            m_acc = 0
            m_auc = 0
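            # Each student's sequence is unrolled one step at a time: the input id is
            # offset by max_num_skills when the answer was wrong, `next_id` holds the
            # following problem id, and the label is the correctness of that next problem.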
            for iii in range(FLAGS.epochs):
                random.shuffle(train_students)
                a = datetime.now()
                data_size = len(train_students)
                index = 0
                actual_labels = []
                pred_labels = []
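                # Full batches first; students left over at the end of the epoch are
                # handled in a single smaller batch after this loop.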
                while (index + FLAGS.batch_size < data_size):
                    x = np.zeros((FLAGS.batch_size, max_num_steps))
                    next_id = np.zeros((FLAGS.batch_size, max_num_steps))
                    target_id = []
                    target_correctness = []
                    for i in range(FLAGS.batch_size):
                        student = train_students[index + i]
                        problem_ids = student[1]
                        correctness = student[2]
                        for j in range(len(problem_ids) - 1):
                            problem_id = int(problem_ids[j])

                            if (int(correctness[j]) == 0):
                                x[i, j] = problem_id + max_num_skills
                            else:
                                x[i, j] = problem_id

                            next_id[i, j] = int(problem_ids[j + 1])
                            target_id.append(i * max_num_steps + j)
                            target_correctness.append(int(correctness[j + 1]))
                            actual_labels.append(int(correctness[j + 1]))
                    index += FLAGS.batch_size
                    pred, loss = train_step(x, next_id, target_id,
                                            target_correctness, same, differ)
                    for p in pred:
                        pred_labels.append(p)
                    current_step = tf.train.global_step(sess, tcn.global_step)
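                # Train on the remaining students that did not fill a complete batch.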
                ll = data_size - index
                x = np.zeros((ll, max_num_steps))
                next_id = np.zeros((ll, max_num_steps))
                target_id = []
                target_correctness = []
                for i in range(ll):
                    student = train_students[index + i]
                    problem_ids = student[1]
                    correctness = student[2]
                    for j in range(len(problem_ids) - 1):
                        problem_id = int(problem_ids[j])
                        if (int(correctness[j]) == 0):
                            x[i, j] = problem_id + max_num_skills
                        else:
                            x[i, j] = problem_id
                        next_id[i, j] = int(problem_ids[j + 1])
                        target_id.append(i * max_num_steps + j)
                        target_correctness.append(int(correctness[j + 1]))
                        actual_labels.append(int(correctness[j + 1]))

                pred, loss = train_step(x, next_id, target_id,
                                        target_correctness, same, differ)
                for p in pred:
                    pred_labels.append(p)
                current_step = tf.train.global_step(sess, tcn.global_step)
                b = datetime.now()
                e_time = (b - a).total_seconds()
                run_time.append(e_time)
                rmse = sqrt(mean_squared_error(actual_labels, pred_labels))
                fpr, tpr, thresholds = metrics.roc_curve(actual_labels,
                                                         pred_labels,
                                                         pos_label=1)
                auc = metrics.auc(fpr, tpr)
                r2 = r2_score(actual_labels, pred_labels)
                pred_score = np.greater_equal(pred_labels, 0.5)
                pred_score = pred_score.astype(int)
                pred_score = np.equal(actual_labels, pred_score)
                acc = np.mean(pred_score.astype(int))
                logger.info(
                    "epochs {0}: rmse {1:g}  auc {2:g}  r2 {3:g}  acc {4:g}"
                    .format((iii + 1), rmse, auc, r2, acc))

                if ((iii + 1) % FLAGS.evaluation_interval == 0):
                    logger.info("\nEvaluation:")
                    data_size = len(test_students)
                    index = 0
                    actual_labels = []
                    pred_labels = []
                    while (index + FLAGS.batch_size < data_size):
                        x = np.zeros((FLAGS.batch_size, max_num_steps))
                        next_id = np.zeros((FLAGS.batch_size, max_num_steps))
                        target_id = []
                        target_correctness = []
                        for i in range(FLAGS.batch_size):
                            student = test_students[index + i]
                            problem_ids = student[1]
                            correctness = student[2]
                            for j in range(len(problem_ids) - 1):
                                problem_id = int(problem_ids[j])
                                if (int(correctness[j]) == 0):
                                    x[i, j] = problem_id + max_num_skills
                                else:
                                    x[i, j] = problem_id
                                next_id[i, j] = int(problem_ids[j + 1])
                                target_id.append(i * max_num_steps + j)
                                target_correctness.append(
                                    int(correctness[j + 1]))
                                actual_labels.append(int(correctness[j + 1]))
                        index += FLAGS.batch_size
                        pred = validation_step(x, next_id, target_id,
                                               target_correctness, same,
                                               differ)
                        for p in pred:
                            pred_labels.append(p)
                    ll = data_size - index
                    x = np.zeros((ll, max_num_steps))
                    next_id = np.zeros((ll, max_num_steps))
                    target_id = []
                    target_correctness = []
                    for i in range(ll):
                        student = test_students[index + i]
                        problem_ids = student[1]
                        correctness = student[2]
                        for j in range(len(problem_ids) - 1):
                            problem_id = int(problem_ids[j])
                            if (int(correctness[j]) == 0):
                                x[i, j] = problem_id + max_num_skills
                            else:
                                x[i, j] = problem_id
                            next_id[i, j] = int(problem_ids[j + 1])
                            target_id.append(i * max_num_steps + j)
                            target_correctness.append(int(correctness[j + 1]))
                            actual_labels.append(int(correctness[j + 1]))
                    pred = validation_step(x, next_id, target_id,
                                           target_correctness, same, differ)
                    for p in pred:
                        pred_labels.append(p)
                    rmse = sqrt(mean_squared_error(actual_labels, pred_labels))
                    fpr, tpr, thresholds = metrics.roc_curve(actual_labels,
                                                             pred_labels,
                                                             pos_label=1)
                    auc = metrics.auc(fpr, tpr)
                    r2 = r2_score(actual_labels, pred_labels)
                    pred_score = np.greater_equal(pred_labels, 0.5)
                    pred_score = pred_score.astype(int)
                    pred_score = np.equal(actual_labels, pred_score)
                    acc = np.mean(pred_score.astype(int))
                    logger.info(
                        "VALIDATION {0}: rmse {1:g}  auc {2:g}  r2 {3:g}  acc {4:g}"
                        .format((iii + 1) / FLAGS.evaluation_interval, rmse,
                                auc, r2, acc))

                    if rmse < m_rmse:
                        m_rmse = rmse
                    if auc > m_auc:
                        m_auc = auc
                    if acc > m_acc:
                        m_acc = acc
                    if r2 > m_r2:
                        m_r2 = r2

                    best_saver.handle(auc, sess, current_step)
                if ((iii + 1) % FLAGS.checkpoint_every == 0):
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))

                logger.info("Epoch {0} has finished!".format(iii + 1))

            logger.info("running time analysis: epoch{0}, avg_time{1}".format(
                len(run_time), np.mean(run_time)))
            logger.info(
                "max: rmse {0:g}  auc {1:g}  r2 {2:g}   acc{3:g} ".format(
                    m_rmse, m_auc, m_r2, m_acc))
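
Each epoch above pools every step-level label and predicted probability before computing the summary metrics. For reference, here is a minimal, self-contained sketch of that metric computation, assuming only numpy and scikit-learn are available (the helper name epoch_metrics is illustrative and not part of the original code):

import numpy as np
from math import sqrt
from sklearn import metrics
from sklearn.metrics import mean_squared_error, r2_score

def epoch_metrics(actual_labels, pred_labels, threshold=0.5):
    """RMSE, AUC, R^2 and thresholded accuracy from pooled labels and predictions."""
    actual = np.asarray(actual_labels)
    pred = np.asarray(pred_labels, dtype=float)
    rmse = sqrt(mean_squared_error(actual, pred))
    fpr, tpr, _ = metrics.roc_curve(actual, pred, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    r2 = r2_score(actual, pred)
    # Binarize the predicted probabilities at the threshold, then compare with the labels
    acc = np.mean((pred >= threshold).astype(int) == actual)
    return rmse, auc, r2, acc

# epoch_metrics([1, 0, 1, 0], [0.9, 0.4, 0.7, 0.6]) -> (rmse, auc, r2, acc)
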
def train():
    """Training model."""

    # Load student interaction sequences and training parameters
    logger.info("Loading data...")

    logger.info("Training data processing...")
    train_students, train_max_num_problems, train_max_skill_num = dh.read_data_from_csv_file(
        FLAGS.train_data_path)

    logger.info("Validation data processing...")
    test_students, test_max_num_problems, test_max_skill_num = dh.read_data_from_csv_file(
        FLAGS.test_data_path)
    max_num_steps = max(train_max_num_problems, test_max_num_problems)
    max_num_skills = max(train_max_skill_num, test_max_skill_num)

    # Build a graph and CKT object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            ckt = CKT(
                batch_size=FLAGS.batch_size,
                num_steps=max_num_steps,
                num_skills=max_num_skills,
                hidden_size=FLAGS.hidden_size,
            )

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=FLAGS.learning_rate,
                    global_step=ckt.global_step,
                    decay_steps=(len(train_students) // FLAGS.batch_size + 1) *
                    FLAGS.decay_steps,
                    decay_rate=FLAGS.decay_rate,
                    staircase=True)
                # learning_rate = tf.train.piecewise_constant(FLAGS.epochs, boundaries=[7,10], values=[0.005, 0.0005, 0.0001])
                optimizer = tf.train.AdamOptimizer(learning_rate)
                #optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
                #grads, vars = zip(*optimizer.compute_gradients(ckt.loss))
                #grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                #train_op = optimizer.apply_gradients(zip(grads, vars), global_step=ckt.global_step, name="train_op")
                train_op = optimizer.minimize(ckt.loss,
                                              global_step=ckt.global_step,
                                              name="train_op")

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input(
                    "Please input the checkpoint model you want to restore, "
                    "it should look like 1490175368: "
                )  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input(
                        "The format of your input is invalid, please re-enter: "
                    )
                logger.info(
                    "The input format is valid, proceeding to the next step..."
                )
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", ckt.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir,
                                                num_to_keep=3,
                                                maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load ckt model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

            current_step = sess.run(ckt.global_step)

            def train_step(x, xx, l, next_id, target_id, target_correctness,
                           target_id2, target_correctness2):
                """A single training step"""

                #print(ability)
                feed_dict = {
                    ckt.input_data: x,
                    ckt.input_skill: xx,
                    ckt.l: l,
                    ckt.next_id: next_id,
                    ckt.target_id: target_id,
                    ckt.target_correctness: target_correctness,
                    ckt.target_id2: target_id2,
                    ckt.target_correctness2: target_correctness2,
                    ckt.dropout_keep_prob: FLAGS.keep_prob,
                    ckt.is_training: True
                }
                _, step, summaries, pred, loss = sess.run([
                    train_op, ckt.global_step, train_summary_op, ckt.pred,
                    ckt.loss
                ], feed_dict)

                logger.info("step {0}: loss {1:g} ".format(step, loss))
                train_summary_writer.add_summary(summaries, step)
                return pred

            def validation_step(x, xx, l, next_id, target_id,
                                target_correctness, target_id2,
                                target_correctness2):
                """Evaluates model on a validation set"""

                feed_dict = {
                    ckt.input_data: x,
                    ckt.input_skill: xx,
                    ckt.l: l,
                    ckt.next_id: next_id,
                    ckt.target_id: target_id,
                    ckt.target_correctness: target_correctness,
                    ckt.target_id2: target_id2,
                    ckt.target_correctness2: target_correctness2,
                    ckt.dropout_keep_prob: 1.0,
                    ckt.is_training: False
                }
                step, summaries, pred, loss = sess.run([
                    ckt.global_step, validation_summary_op, ckt.pred, ckt.loss
                ], feed_dict)
                validation_summary_writer.add_summary(summaries, step)
                return pred

            # Training loop. For each batch...

            run_time = []
            m_rmse = 1
            m_r2 = 0
            m_acc = 0
            m_auc = 0
            for iii in range(FLAGS.epochs):
                random.shuffle(train_students)
                a = datetime.now()
                data_size = len(train_students)
                index = 0
                actual_labels = []
                pred_labels = []
                while (index + FLAGS.batch_size < data_size):
                    x = np.zeros((FLAGS.batch_size, max_num_steps))
                    xx = np.zeros((FLAGS.batch_size, max_num_steps))
                    next_id = np.zeros((FLAGS.batch_size, max_num_steps))
                    l = np.ones(
                        (FLAGS.batch_size, max_num_steps, max_num_skills))
                    target_id = []
                    target_correctness = []
                    target_id2 = []
                    target_correctness2 = []
                    for i in range(FLAGS.batch_size):
                        student = train_students[index + i]
                        problem_ids = student[1]
                        correctness = student[2]
                        correct_num = np.zeros(max_num_skills)
                        answer_count = np.ones(max_num_skills)
                        for j in range(len(problem_ids) - 1):
                            problem_id = int(problem_ids[j])

                            if (int(correctness[j]) == 0):
                                x[i, j] = problem_id + max_num_skills
                            else:
                                x[i, j] = problem_id
                                correct_num[problem_id] += 1
                            l[i, j] = correct_num / answer_count
                            answer_count[problem_id] += 1
                            xx[i, j] = problem_id
                            next_id[i, j] = int(problem_ids[j + 1])

                            target_id.append(i * max_num_steps + j)
                            target_correctness.append(int(correctness[j + 1]))
                            actual_labels.append(int(correctness[j + 1]))
                        target_id2.append(i * max_num_steps + j)
                        target_correctness2.append(int(correctness[j + 1]))

                    index += FLAGS.batch_size
                    #print(ability)
                    pred = train_step(x, xx, l, next_id, target_id,
                                      target_correctness, target_id2,
                                      target_correctness2)
                    for p in pred:
                        pred_labels.append(p)
                    current_step = tf.train.global_step(sess, ckt.global_step)
                b = datetime.now()
                e_time = (b - a).total_seconds()
                run_time.append(e_time)
                rmse = sqrt(mean_squared_error(actual_labels, pred_labels))
                fpr, tpr, thresholds = metrics.roc_curve(actual_labels,
                                                         pred_labels,
                                                         pos_label=1)
                auc = metrics.auc(fpr, tpr)
                #calculate r^2
                r2 = r2_score(actual_labels, pred_labels)
                pred_score = np.greater_equal(pred_labels, 0.5)
                pred_score = pred_score.astype(int)
                pred_score = np.equal(actual_labels, pred_score)
                acc = np.mean(pred_score.astype(int))
                logger.info(
                    "epochs {0}: rmse {1:g}  auc {2:g}  r2 {3:g}  acc {4:g}".
                    format((iii + 1), rmse, auc, r2, acc))

                if ((iii + 1) % FLAGS.evaluation_interval == 0):
                    logger.info("\nEvaluation:")

                    data_size = len(test_students)
                    index = 0
                    actual_labels = []
                    pred_labels = []
                    while (index + FLAGS.batch_size < data_size):
                        x = np.zeros((FLAGS.batch_size, max_num_steps))
                        xx = np.zeros((FLAGS.batch_size, max_num_steps))
                        next_id = np.zeros((FLAGS.batch_size, max_num_steps))
                        l = np.ones(
                            (FLAGS.batch_size, max_num_steps, max_num_skills))
                        target_id = []
                        target_correctness = []
                        target_id2 = []
                        target_correctness2 = []
                        for i in range(FLAGS.batch_size):
                            student = test_students[index + i]
                            problem_ids = student[1]
                            correctness = student[2]
                            correct_num = np.zeros(max_num_skills)
                            answer_count = np.ones(max_num_skills)
                            for j in range(len(problem_ids) - 1):
                                problem_id = int(problem_ids[j])

                                if (int(correctness[j]) == 0):
                                    x[i, j] = problem_id + max_num_skills
                                else:
                                    x[i, j] = problem_id
                                    correct_num[problem_id] += 1
                                l[i, j] = correct_num / answer_count
                                answer_count[problem_id] += 1
                                xx[i, j] = problem_id
                                next_id[i, j] = int(problem_ids[j + 1])
                                target_id.append(i * max_num_steps + j)
                                target_correctness.append(
                                    int(correctness[j + 1]))
                                actual_labels.append(int(correctness[j + 1]))
                            target_id2.append(i * max_num_steps + j)
                            target_correctness2.append(int(correctness[j + 1]))

                        index += FLAGS.batch_size
                        pred = validation_step(x, xx, l, next_id, target_id,
                                               target_correctness, target_id2,
                                               target_correctness2)
                        for p in pred:
                            pred_labels.append(p)

                    rmse = sqrt(mean_squared_error(actual_labels, pred_labels))
                    fpr, tpr, thresholds = metrics.roc_curve(actual_labels,
                                                             pred_labels,
                                                             pos_label=1)
                    auc = metrics.auc(fpr, tpr)
                    #calculate r^2
                    r2 = r2_score(actual_labels, pred_labels)
                    pred_score = np.greater_equal(pred_labels, 0.5)
                    pred_score = pred_score.astype(int)
                    pred_score = np.equal(actual_labels, pred_score)
                    acc = np.mean(pred_score.astype(int))

                    logger.info(
                        "VALIDATION {0}: rmse {1:g}  auc {2:g}  r2 {3:g}  acc {4:g}"
                        .format((iii + 1) / FLAGS.evaluation_interval, rmse,
                                auc, r2, acc))

                    if rmse < m_rmse:
                        m_rmse = rmse
                    if auc > m_auc:
                        m_auc = auc
                    if acc > m_acc:
                        m_acc = acc
                    if r2 > m_r2:
                        m_r2 = r2

                    best_saver.handle(auc, sess, current_step)
                if ((iii + 1) % FLAGS.checkpoint_every == 0):
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))

                logger.info("Epoch {0} has finished!".format(iii + 1))

            logger.info("running time analysis: epoch{0}, avg_time{1}".format(
                len(run_time), np.mean(run_time)))
            logger.info(
                "max: rmse {0:g}  auc {1:g}  r2 {2:g}   acc{3:g} ".format(
                    m_rmse, m_auc, m_r2, m_acc))
    logger.info("Done.")
Example #7
0
def train_lmlp():
    """Training LMLP model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes_list, FLAGS.total_classes,
                                         FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("✔︎ Validation data processing...")
    val_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes_list, FLAGS.total_classes,
                                       FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train, y_train, y_train_tuple = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_val, y_val, y_val_tuple = dh.pad_data(val_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and lmlp object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            lmlp = eval(FLAGS.lmlp_type)(
                sequence_length=FLAGS.pad_seq_len,
                num_classes_list=list(map(int, FLAGS.num_classes_list.split(','))),
                total_classes=FLAGS.total_classes,
                vocab_size=VOCAB_SIZE,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix,
                alpha=FLAGS.alpha,
                beta=FLAGS.beta)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=lmlp.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(lmlp.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=lmlp.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should be like(1490175368): ")  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, please re-input: ")
                logger.info("✔︎ The format of your input is legal, now loading to next step...")
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", lmlp.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load lmlp model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(lmlp.global_step)

            def train_step(x_batch, y_batch, y_batch_tuple):
                """A single training step"""
                y_batch_first = [i[0] for i in y_batch_tuple]
                y_batch_second = [j[1] for j in y_batch_tuple]
                y_batch_third = [k[2] for k in y_batch_tuple]

                feed_dict = {
                    lmlp.input_x: x_batch,
                    lmlp.input_y_first: y_batch_first,
                    lmlp.input_y_second: y_batch_second,
                    lmlp.input_y_third: y_batch_third,
                    lmlp.input_y: y_batch,
                    lmlp.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    lmlp.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, lmlp.global_step, train_summary_op, lmlp.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, y_val_tuple, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(
                    list(zip(x_val, y_val, y_val_tuple)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_auc = 0, 0.0, 0.0
                eval_rec_ts, eval_pre_ts, eval_F_ts = 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num
                val_scores = []

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val, y_batch_val_tuple = zip(*batch_validation)

                    y_batch_val_first = [i[0] for i in y_batch_val_tuple]
                    y_batch_val_second = [j[1] for j in y_batch_val_tuple]
                    y_batch_val_third = [k[2] for k in y_batch_val_tuple]

                    feed_dict = {
                        lmlp.input_x: x_batch_val,
                        lmlp.input_y_first: y_batch_val_first,
                        lmlp.input_y_second: y_batch_val_second,
                        lmlp.input_y_third: y_batch_val_third,
                        lmlp.input_y: y_batch_val,
                        lmlp.dropout_keep_prob: 1.0,
                        lmlp.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [lmlp.global_step, validation_summary_op, lmlp.scores, lmlp.loss], feed_dict)

                    for predicted_scores in scores:
                        val_scores.append(predicted_scores)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_pre_ts, cur_F_ts = 0.0, 0.0, 0.0

                    for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                        rec_inc_ts, pre_inc_ts = dh.cal_metric(predicted_label_threshold, y_batch_val[index])
                        cur_rec_ts, cur_pre_ts = cur_rec_ts + rec_inc_ts, cur_pre_ts + pre_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_val)
                    cur_pre_ts = cur_pre_ts / len(y_batch_val)

                    cur_F_ts = dh.cal_F(cur_rec_ts, cur_pre_ts)

                    eval_rec_ts, eval_pre_ts = eval_rec_ts + cur_rec_ts, eval_pre_ts + cur_pre_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_pre_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(predicted_labels_topK):
                            rec_inc_tk, pre_inc_tk = dh.cal_metric(predicted_label_topK, y_batch_val[index])
                            cur_rec_tk[top_num], cur_pre_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, cur_pre_tk[top_num] + pre_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_val)
                        cur_pre_tk[top_num] = cur_pre_tk[top_num] / len(y_batch_val)

                        cur_F_tk[top_num] = dh.cal_F(cur_rec_tk[top_num], cur_pre_tk[top_num])

                        eval_rec_tk[top_num], eval_pre_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], eval_pre_tk[top_num] + cur_pre_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                # Calculate the average AUC
                val_scores = np.array(val_scores)
                y_val = np.array(y_val)
                missing_labels_num = 0
                for index in range(FLAGS.total_classes):
                    y_true = y_val[:, index]
                    y_score = val_scores[:, index]
                    try:
                        eval_auc = eval_auc + roc_auc_score(y_true=y_true, y_score=y_score)
                    except ValueError:
                        # roc_auc_score raises ValueError when a label column has only one class
                        missing_labels_num += 1

                eval_auc = eval_auc / (FLAGS.total_classes - missing_labels_num)
                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_pre_ts = float(eval_pre_ts / eval_counter)
                eval_F_ts = dh.cal_F(eval_rec_ts, eval_pre_ts)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter)
                    eval_pre_tk[top_num] = float(eval_pre_tk[top_num] / eval_counter)
                    eval_F_tk[top_num] = dh.cal_F(eval_rec_tk[top_num], eval_pre_tk[top_num])

                return eval_loss, eval_auc, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train, y_train_tuple)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train, y_batch_train_tuple = zip(*batch_train)
                train_step(x_batch_train, y_batch_train, y_batch_train_tuple)
                current_step = tf.train.global_step(sess, lmlp.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
                        validation_step(x_val, y_val, y_val_tuple, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g} | AUC {1:g}".format(eval_loss, eval_auc))

                    # Predict by threshold
                    logger.info("☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                                .format(eval_rec_ts, eval_pre_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: Recall {1:g}, Precision {2:g}, F {3:g}"
                                    .format(top_num+1, eval_rec_tk[top_num], eval_pre_tk[top_num], eval_F_tk[top_num]))
                    best_saver.handle(eval_auc, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #8
0
def train_tarnn():
    """Training TARNN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train_content, x_train_question, x_train_option, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val_content, x_val_question, x_val_option, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and tarnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            tarnn = TextTARNN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                rnn_hidden_size=args.rnn_dim,
                rnn_type=args.rnn_type,
                rnn_layers=args.rnn_layers,
                attention_type=args.attention_type,
                fc_hidden_size=args.fc_dim,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=tarnn.global_step, decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(tarnn.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=tarnn.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", tarnn.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=False)

            if OPTION == 'R':
                # Load tarnn model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(tarnn.global_step)

            def train_step(x_batch_content, x_batch_question, x_batch_option, y_batch):
                """A single training step"""
                feed_dict = {
                    tarnn.input_x_content: x_batch_content,
                    tarnn.input_x_question: x_batch_question,
                    tarnn.input_x_option: x_batch_option,
                    tarnn.input_y: y_batch,
                    tarnn.dropout_keep_prob: args.dropout_rate,
                    tarnn.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, tarnn.global_step, train_summary_op, tarnn.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val_content, x_val_question, x_val_option, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val_content, x_val_question, x_val_option, y_val)),
                                                   args.batch_size, 1)

                eval_counter, eval_loss = 0, 0.0
                true_labels = []
                predicted_scores = []

                for batch_validation in batches_validation:
                    x_batch_content, x_batch_question, x_batch_option, y_batch = zip(*batch_validation)
                    feed_dict = {
                        tarnn.input_x_content: x_batch_content,
                        tarnn.input_x_question: x_batch_question,
                        tarnn.input_x_option: x_batch_option,
                        tarnn.input_y: y_batch,
                        tarnn.dropout_keep_prob: 1.0,
                        tarnn.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [tarnn.global_step, validation_summary_op, tarnn.scores, tarnn.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch:
                        true_labels.append(i)
                    for j in scores:
                        predicted_scores.append(j)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate PCC & DOA
                pcc, doa = dh.evaluation(true_labels, predicted_scores)
                # Calculate RMSE
                rmse = mean_squared_error(true_labels, predicted_scores) ** 0.5
                r2 = r2_score(true_labels, predicted_scores)

                return eval_loss, pcc, doa, rmse, r2

            # Generate batches
            batches_train = dh.batch_iter(list(zip(x_train_content, x_train_question, x_train_option, y_train)),
                                          args.batch_size, args.epochs)

            num_batches_per_epoch = int((len(y_train) - 1) / args.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train_content, x_batch_train_question, x_batch_train_option, y_batch_train = zip(*batch_train)
                train_step(x_batch_train_content, x_batch_train_question, x_batch_train_option, y_batch_train)
                current_step = tf.train.global_step(sess, tarnn.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, pcc, doa, rmse, r2 = validation_step(x_val_content, x_val_question, x_val_option, y_val,
                                                                    writer=validation_summary_writer)
                    logger.info("All Validation set: Loss {0:g} | PCC {1:g} | DOA {2:g} | RMSE {3:g} | R2 {4:g}"
                                .format(eval_loss, pcc, doa, rmse, r2))
                    best_saver.handle(rmse, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
Example #9
0
def train(ckpt_path, log_path, class_path, decay_steps=2000, decay_rate=0.8):
    """ Function to train the model.
		ckpt_path: string, path for saving/restoring the model
		log_path: string, path for saving the training/validation logs
		class_path: string, path for the classes of the dataset
		decay_steps: int, steps after which the learning rate is to be decayed
		decay_rate: float, rate to carrying out exponential decay
	"""

    # Getting the anchors
    anchors = read_anchors(config.anchors_path)
    if not os.path.exists(config.data_dir):
        os.mkdir(config.data_dir)

    classes = get_classes(class_path)

    # Building the training pipeline
    graph = tf.get_default_graph()

    with graph.as_default():

        # Getting the training data
        with tf.name_scope('data_parser/'):
            train_reader = Parser('train',
                                  config.data_dir,
                                  config.anchors_path,
                                  config.output_dir,
                                  config.num_classes,
                                  input_shape=config.input_shape,
                                  max_boxes=config.max_boxes)
            train_data = train_reader.build_dataset(config.train_batch_size //
                                                    config.subdivisions)
            train_iterator = train_data.make_one_shot_iterator()

            val_reader = Parser('val',
                                config.data_dir,
                                config.anchors_path,
                                config.output_dir,
                                config.num_classes,
                                input_shape=config.input_shape,
                                max_boxes=config.max_boxes)
            val_data = val_reader.build_dataset(config.val_batch_size //
                                                config.subdivisions)
            val_iterator = val_data.make_one_shot_iterator()

            is_training = tf.placeholder(
                dtype=tf.bool, shape=[], name='train_flag'
            )  # Used for different behaviour of batch normalization
            mode = tf.placeholder(dtype=tf.int16, shape=[], name='mode_flag')

            def train():
                # images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = train_iterator.get_next()
                return train_iterator.get_next()

            def valid():
                # images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = val_iterator.get_next()
                return val_iterator.get_next()

            images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = tf.cond(
                pred=tf.equal(mode, 1),
                true_fn=train,
                false_fn=valid,
                name='train_val_cond')

            images.set_shape([None, config.input_shape, config.input_shape, 3])
            bbox.set_shape([None, config.max_boxes, 5])

            grid_shapes = [
                config.input_shape // 32, config.input_shape // 16,
                config.input_shape // 8
            ]
            draw_box(images, bbox)

        # Extracting the pre-defined yolo graph from the darknet cfg file
        if not os.path.exists(ckpt_path):
            os.mkdir(ckpt_path)
        output = yolo(images, is_training, config.yolov3_cfg_path,
                      config.num_classes)

        # Declaring the parameters for GT
        with tf.name_scope('Targets'):
            bbox_true_13.set_shape([
                None, grid_shapes[0], grid_shapes[0], 3, 5 + config.num_classes
            ])
            bbox_true_26.set_shape([
                None, grid_shapes[1], grid_shapes[1], 3, 5 + config.num_classes
            ])
            bbox_true_52.set_shape([
                None, grid_shapes[2], grid_shapes[2], 3, 5 + config.num_classes
            ])
        y_true = [bbox_true_13, bbox_true_26, bbox_true_52]

        # Compute Loss
        with tf.name_scope('Loss_and_Detect'):
            yolo_loss = compute_loss(output,
                                     y_true,
                                     anchors,
                                     config.num_classes,
                                     print_loss=False)
            l2_loss = tf.losses.get_regularization_loss()
            loss = yolo_loss + l2_loss
            yolo_loss_summary = tf.summary.scalar('yolo_loss', yolo_loss)
            l2_loss_summary = tf.summary.scalar('l2_loss', l2_loss)
            total_loss_summary = tf.summary.scalar('Total_loss', loss)

        # Declaring the parameters for training the model
        with tf.name_scope('train_parameters'):
            epoch_loss = []
            global_step = tf.Variable(0, trainable=False, name='global_step')
            learning_rate = tf.train.exponential_decay(config.learning_rate,
                                                       global_step,
                                                       decay_steps, decay_rate)
            tf.summary.scalar('learning rate', learning_rate)

        # Define optimizer for minimizing the computed loss
        with tf.name_scope('Optimizer'):
            #optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=config.momentum)
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                if config.pre_train:
                    train_vars = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope='yolo')
                    grads = optimizer.compute_gradients(loss=loss,
                                                        var_list=train_vars)
                    gradients = [(tf.placeholder(dtype=tf.float32,
                                                 shape=grad[1].get_shape()),
                                  grad[1]) for grad in grads]
                    gradients = gradients * config.subdivisions
                    train_step = optimizer.apply_gradients(
                        grads_and_vars=gradients, global_step=global_step)
                else:
                    grads = optimizer.compute_gradients(loss=loss)
                    gradients = [(tf.placeholder(dtype=tf.float32,
                                                 shape=grad[1].get_shape()),
                                  grad[1]) for grad in grads]
                    gradients = gradients * config.subdivisions
                    train_step = optimizer.apply_gradients(
                        grads_and_vars=gradients, global_step=global_step)


        #################################### Training loop ############################################################
        # A saver object for saving the model
        best_ckpt_saver = checkmate.BestCheckpointSaver(save_dir=ckpt_path,
                                                        num_to_keep=5)
        summary_op = tf.summary.merge_all()
        summary_op_valid = tf.summary.merge(
            [yolo_loss_summary, l2_loss_summary, total_loss_summary])
        init_op = tf.global_variables_initializer()

        # Defining some train loop dependencies
        gpu_config = tf.ConfigProto(log_device_placement=False)
        gpu_config.gpu_options.allow_growth = True
        sess = tf.Session(config=gpu_config)
        tf.logging.set_verbosity(tf.logging.ERROR)
        train_summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, 'train'), sess.graph)
        val_summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, 'val'), sess.graph)

        # Restoring the model
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('Restoring model ', checkmate.get_best_checkpoint(ckpt_path))
            tf.train.Saver().restore(sess,
                                     checkmate.get_best_checkpoint(ckpt_path))
            print('Model Loaded!')
        elif config.pre_train is True:
            load_ops = load_weights(tf.global_variables(scope='darknet53'),
                                    config.darknet53_weights_path)
            sess.run(load_ops)
        else:
            sess.run(init_op)

        print('Uninitialized variables: ',
              sess.run(tf.report_uninitialized_variables()))

        epochbar = tqdm(range(config.Epoch))
        for epoch in epochbar:
            epochbar.set_description('Epoch %s of %s' % (epoch, config.Epoch))
            mean_loss_train = []
            mean_loss_valid = []

            trainbar = tqdm(range(config.train_num // config.train_batch_size))
            for k in trainbar:
                total_grad = []
                for minibatch in range(config.subdivisions):
                    train_summary, loss_train, grads_and_vars = sess.run(
                        [summary_op, loss, grads],
                        feed_dict={
                            is_training: True,
                            mode: 1
                        })
                    total_grad += grads_and_vars

                feed_dict = {is_training: True, mode: 1}
                for i in range(len(gradients)):
                    feed_dict[gradients[i][0]] = total_grad[i][0]
                # print(np.shape(feed_dict))

                _ = sess.run(train_step, feed_dict=feed_dict)
                train_summary_writer.add_summary(train_summary, epoch)
                train_summary_writer.flush()
                mean_loss_train.append(loss_train)
                trainbar.set_description('Train loss: %s' % str(loss_train))

            print('Validating.....')
            valbar = tqdm(range(config.val_num // config.val_batch_size))
            for k in valbar:

                val_summary, loss_valid = sess.run([summary_op_valid, loss],
                                                   feed_dict={
                                                       is_training: False,
                                                       mode: 0
                                                   })

                val_summary_writer.add_summary(val_summary, epoch)
                val_summary_writer.flush()
                mean_loss_valid.append(loss_valid)
                valbar.set_description('Validation loss: %s' % str(loss_valid))

            mean_loss_train = np.mean(mean_loss_train)
            mean_loss_valid = np.mean(mean_loss_valid)

            print('\n')
            print('Train loss after %d epochs is: %f' %
                  (epoch + 1, mean_loss_train))
            print('Validation loss after %d epochs is: %f' %
                  (epoch + 1, mean_loss_valid))
            print('\n\n')

            if ((epoch + 1) % 3) == 0:
                best_ckpt_saver.handle(mean_loss_valid, sess,
                                       tf.constant(epoch))

        print('Tuning Completed!!')
        train_summary_writer.close()
        val_summary_writer.close()
        sess.close()
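

# The training loop above runs `config.subdivisions` forward/backward passes per optimizer step
# and feeds the collected gradient values back through placeholders before apply_gradients()
# runs. Below is a minimal, self-contained sketch of that accumulation idea (here with an
# explicit sum/average rather than the duplicated placeholder list used above), assuming
# TensorFlow 1.x and a toy one-variable loss; `_gradient_accumulation_sketch` is illustrative
# only and not part of the original project.
def _gradient_accumulation_sketch(subdivisions=4):
    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None, 1])
    w = tf.Variable([[0.0]], dtype=tf.float32)
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - 1.0))

    optimizer = tf.train.AdamOptimizer(0.01)
    grads_and_vars = optimizer.compute_gradients(loss)
    # One placeholder per gradient so the accumulated values can be fed back in
    grad_placeholders = [(tf.placeholder(tf.float32, shape=v.get_shape()), v)
                         for _, v in grads_and_vars]
    train_step = optimizer.apply_gradients(grad_placeholders)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        accumulated = [np.zeros(v.get_shape().as_list(), dtype=np.float32)
                       for _, v in grads_and_vars]
        for _ in range(subdivisions):
            micro_batch = np.random.rand(8, 1).astype(np.float32)
            grad_values = sess.run([g for g, _ in grads_and_vars],
                                   feed_dict={x: micro_batch})
            accumulated = [a + g for a, g in zip(accumulated, grad_values)]
        # Apply one update with the gradients averaged over the micro-batches
        feed = {ph: a / subdivisions
                for (ph, _), a in zip(grad_placeholders, accumulated)}
        sess.run(train_step, feed_dict=feed)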
Exemple #10
0
def train_hmidp():
    """Training hmdip model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file,
                                         FLAGS.embedding_dim,
                                         data_aug_flag=False)

    logger.info("✔︎ Validation data processing...")
    val_data = dh.load_data_and_labels(FLAGS.validation_data_file,
                                       FLAGS.embedding_dim,
                                       data_aug_flag=False)

    logger.info("✔︎ Training data padding...")
    x_train_content, x_train_question, x_train_option, y_train = dh.pad_data(
        train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_val_content, x_val_question, x_val_option, y_val = dh.pad_data(
        val_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(
        FLAGS.embedding_dim)

    # Build a graph and hmidp object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            hmidp = TextHMIDP(
                sequence_length=list(map(int, FLAGS.pad_seq_len.split(','))),
                vocab_size=VOCAB_SIZE,
                fc_hidden_size=FLAGS.fc_hidden_size,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
                num_filters=list(map(int, FLAGS.num_filters.split(','))),
                pooling_size=FLAGS.pooling_size,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=FLAGS.learning_rate,
                    global_step=hmidp.global_step,
                    decay_steps=FLAGS.decay_steps,
                    decay_rate=FLAGS.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(hmidp.loss))
                grads, _ = tf.clip_by_global_norm(grads,
                                                  clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, vars),
                    global_step=hmidp.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input(
                    "☛ Please input the checkpoints model you want to restore, "
                    "it should be like(1490175368): "
                )  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input(
                        "✘ The format of your input is illegal, please re-input: "
                    )
                logger.info(
                    "✔︎ The format of your input is legal, now loading to next step..."
                )
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", hmidp.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir,
                                                num_to_keep=3,
                                                maximize=False)

            if FLAGS.train_or_restore == 'R':
                # Load hmidp model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(hmidp.global_step)

            def train_step(x_batch_content, x_batch_question, x_batch_option,
                           y_batch):
                """A single training step"""
                feed_dict = {
                    hmidp.input_x_content: x_batch_content,
                    hmidp.input_x_question: x_batch_question,
                    hmidp.input_x_option: x_batch_option,
                    hmidp.input_y: y_batch,
                    hmidp.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    hmidp.is_training: True
                }
                _, step, summaries, loss = sess.run([
                    train_op, hmidp.global_step, train_summary_op, hmidp.loss
                ], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val_content,
                                x_val_question,
                                x_val_option,
                                y_val,
                                writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(
                    list(
                        zip(x_val_content, x_val_question, x_val_option,
                            y_val)), FLAGS.batch_size, 1)

                eval_counter, eval_loss = 0, 0.0

                true_labels = []
                predicted_scores = []

                for batch_validation in batches_validation:
                    x_batch_content, x_batch_question, x_batch_option, y_batch = zip(
                        *batch_validation)
                    feed_dict = {
                        hmidp.input_x_content: x_batch_content,
                        hmidp.input_x_question: x_batch_question,
                        hmidp.input_x_option: x_batch_option,
                        hmidp.input_y: y_batch,
                        hmidp.dropout_keep_prob: 1.0,
                        hmidp.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run([
                        hmidp.global_step, validation_summary_op, hmidp.scores,
                        hmidp.loss
                    ], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch:
                        true_labels.append(i)
                    for j in scores:
                        predicted_scores.append(j)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate PCC & DOA
                pcc, doa = dh.evaluation(true_labels, predicted_scores)
                # Calculate RMSE
                rmse = mean_squared_error(true_labels, predicted_scores)**0.5

                return eval_loss, pcc, doa, rmse

            # Generate batches
            batches_train = dh.batch_iter(
                list(
                    zip(x_train_content, x_train_question, x_train_option,
                        y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int(
                (len(y_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train_content, x_batch_train_question, x_batch_train_option, y_batch_train = zip(
                    *batch_train)
                train_step(x_batch_train_content, x_batch_train_question,
                           x_batch_train_option, y_batch_train)
                current_step = tf.train.global_step(sess, hmidp.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, pcc, doa, rmse = validation_step(
                        x_val_content,
                        x_val_question,
                        x_val_option,
                        y_val,
                        writer=validation_summary_writer)
                    logger.info(
                        "All Validation set: Loss {0:g} | PCC {1:g} | DOA {2:g} | RMSE {3:g}"
                        .format(eval_loss, pcc, doa, rmse))
                    best_saver.handle(rmse, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info(
                        "✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info(
                        "✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
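
# `dh.batch_iter` is not shown in this example; the loop above only relies on it yielding
# shuffled mini-batches of (content, question, option, label) tuples for a fixed number of
# epochs. The generator below is a rough sketch of such a helper, given as an assumption about
# its behaviour rather than the actual implementation.
def _batch_iter_sketch(data, batch_size, num_epochs, shuffle=True):
    import random
    num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1
    for _ in range(num_epochs):
        order = list(range(len(data)))
        if shuffle:
            random.shuffle(order)
        for batch_num in range(num_batches_per_epoch):
            start = batch_num * batch_size
            end = min(start + batch_size, len(data))
            # Each yielded batch can be unpacked with zip(*batch), as in the loop above
            yield [data[i] for i in order[start:end]]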
Exemple #11
0
def train(ckpt_path, log_path, class_path):
    """ Function to train the model.
		ckpt_path: string, path for saving/restoring the model
		log_path: string, path for saving the training/validation logs
		class_path: string, path for the classes of the dataset
		decay_steps: int, steps after which the learning rate is to be decayed
		decay_rate: float, rate to carrying out exponential decay
	"""

    # Getting the anchors
    anchors = read_anchors(config.anchors_path)

    classes = get_classes(class_path)

    # Tiny YOLO detects at two scales (6 anchors); the full model detects at three (9 anchors)
    yolo_tiny = (anchors.shape[0] // 3 == 2)

    # Building the training pipeline
    graph = tf.get_default_graph()

    with graph.as_default():

        # Getting the training data
        with tf.name_scope('data_parser/'):
            train_reader = Parser('train',
                                  config.anchors_path,
                                  config.output_dir,
                                  config.num_classes,
                                  input_shape=config.input_shape,
                                  max_boxes=config.max_boxes)
            train_data = train_reader.build_dataset(config.train_batch_size //
                                                    config.subdivisions)
            train_iterator = train_data.make_one_shot_iterator()

            val_reader = Parser('val',
                                config.anchors_path,
                                config.output_dir,
                                config.num_classes,
                                input_shape=config.input_shape,
                                max_boxes=config.max_boxes)
            val_data = val_reader.build_dataset(config.val_batch_size //
                                                config.subdivisions)
            val_iterator = val_data.make_one_shot_iterator()

            is_training = tf.placeholder(
                dtype=tf.bool, shape=[], name='train_flag'
            )  # Used for different behaviour of batch normalization
            mode = tf.placeholder(dtype=tf.int16, shape=[], name='mode_flag')

            def train():
                # images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = train_iterator.get_next()
                return train_iterator.get_next()

            def valid():
                # images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = val_iterator.get_next()
                return val_iterator.get_next()

            if yolo_tiny:
                images, bbox, bbox_true_13, bbox_true_26 = tf.cond(
                    pred=tf.equal(mode, 1),
                    true_fn=train,
                    false_fn=valid,
                    name='train_val_data')
                grid_shapes = [
                    config.input_shape // 32, config.input_shape // 16
                ]
            else:
                images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = tf.cond(
                    pred=tf.equal(mode, 1),
                    true_fn=train,
                    false_fn=valid,
                    name='train_val_data')
                grid_shapes = [
                    config.input_shape // 32, config.input_shape // 16,
                    config.input_shape // 8
                ]

            images.set_shape([None, config.input_shape, config.input_shape, 3])
            bbox.set_shape([None, config.max_boxes, 5])

            # image_summary = draw_box(images, bbox)

        # Extracting the pre-defined yolo graph from the darknet cfg file
        if not os.path.exists(ckpt_path):
            os.mkdir(ckpt_path)
        output = yolo(images, is_training, config.yolov3_cfg_path,
                      config.num_classes)

        # Declaring the parameters for GT
        with tf.name_scope('Targets'):
            if yolo_tiny:
                bbox_true_13.set_shape([
                    None, grid_shapes[0], grid_shapes[0],
                    config.num_anchors_per_scale, 5 + config.num_classes
                ])
                bbox_true_26.set_shape([
                    None, grid_shapes[1], grid_shapes[1],
                    config.num_anchors_per_scale, 5 + config.num_classes
                ])
                y_true = [bbox_true_13, bbox_true_26]
            else:
                bbox_true_13.set_shape([
                    None, grid_shapes[0], grid_shapes[0],
                    config.num_anchors_per_scale, 5 + config.num_classes
                ])
                bbox_true_26.set_shape([
                    None, grid_shapes[1], grid_shapes[1],
                    config.num_anchors_per_scale, 5 + config.num_classes
                ])
                bbox_true_52.set_shape([
                    None, grid_shapes[2], grid_shapes[2],
                    config.num_anchors_per_scale, 5 + config.num_classes
                ])
                y_true = [bbox_true_13, bbox_true_26, bbox_true_52]

        # Compute Loss
        with tf.name_scope('Loss_and_Detect'):
            loss_scale, yolo_loss, xy_loss, wh_loss, obj_loss, noobj_loss, conf_loss, class_loss = compute_loss(
                output,
                y_true,
                anchors,
                config.num_classes,
                config.input_shape,
                ignore_threshold=config.ignore_thresh)
            loss = yolo_loss
            exponential_moving_average_op = tf.train.ExponentialMovingAverage(
                config.weight_decay).apply(
                    var_list=tf.trainable_variables())  # For regularisation
            scale1_loss_summary = tf.summary.scalar('scale_loss_1',
                                                    loss_scale[0],
                                                    family='Loss')
            scale2_loss_summary = tf.summary.scalar('scale_loss_2',
                                                    loss_scale[1],
                                                    family='Loss')
            yolo_loss_summary = tf.summary.scalar('yolo_loss',
                                                  yolo_loss,
                                                  family='Loss')
            # total_loss_summary = tf.summary.scalar('Total_loss', loss, family='Loss')
            xy_loss_summary = tf.summary.scalar('xy_loss',
                                                xy_loss,
                                                family='Loss')
            wh_loss_summary = tf.summary.scalar('wh_loss',
                                                wh_loss,
                                                family='Loss')
            obj_loss_summary = tf.summary.scalar('obj_loss',
                                                 obj_loss,
                                                 family='Loss')
            noobj_loss_summary = tf.summary.scalar('noobj_loss',
                                                   noobj_loss,
                                                   family='Loss')
            conf_loss_summary = tf.summary.scalar('confidence_loss',
                                                  conf_loss,
                                                  family='Loss')
            class_loss_summary = tf.summary.scalar('class_loss',
                                                   class_loss,
                                                   family='Loss')

        # Declaring the parameters for training the model
        with tf.name_scope('train_parameters'):
            global_step = tf.Variable(0, trainable=False, name='global_step')

            def learning_rate_scheduler(learning_rate,
                                        scheduler_name,
                                        global_step,
                                        decay_steps=100):
                if scheduler_name == 'exponential':
                    lr = tf.train.exponential_decay(
                        learning_rate,
                        global_step,
                        decay_steps,
                        decay_rate,
                        staircase=True,
                        name='exponential_learning_rate')
                    return tf.maximum(lr, config.learning_rate_lower_bound)
                elif scheduler_name == 'polynomial':
                    lr = tf.train.polynomial_decay(
                        learning_rate,
                        global_step,
                        decay_steps,
                        config.learning_rate_lower_bound,
                        power=0.8,
                        cycle=True,
                        name='polynomial_learning_rate')
                    return tf.maximum(lr, config.learning_rate_lower_bound)
                elif scheduler_name == 'cosine':
                    lr = tf.train.cosine_decay(learning_rate,
                                               global_step,
                                               decay_steps,
                                               alpha=0.5,
                                               name='cosine_learning_rate')
                    return tf.maximum(lr, config.learning_rate_lower_bound)
                elif scheduler_name == 'linear':
                    return tf.convert_to_tensor(learning_rate,
                                                name='linear_learning_rate')
                else:
                    raise ValueError(
                        'Unsupported learning rate scheduler\n[supported types: exponential, polynomial, cosine, linear]'
                    )

            if config.use_warm_up:
                learning_rate = tf.cond(
                    pred=tf.less(
                        global_step,
                        config.burn_in_epochs *
                        (config.train_num // config.train_batch_size)),
                    true_fn=lambda: learning_rate_scheduler(
                        config.init_learning_rate, config.warm_up_lr_scheduler,
                        global_step),
                    false_fn=lambda: learning_rate_scheduler(
                        config.learning_rate,
                        config.lr_scheduler,
                        global_step,
                        decay_steps=500))
            else:
                learning_rate = learning_rate_scheduler(config.learning_rate,
                                                        config.lr_scheduler,
                                                        global_step,
                                                        decay_steps=2000)

            tf.summary.scalar('learning rate', learning_rate)

        # Define optimizer for minimizing the computed loss
        with tf.name_scope('Optimizer'):
            # optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=config.momentum)
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            # optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate, momentum=config.momentum)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                if config.pre_train:
                    train_vars = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope='yolo')
                else:
                    train_vars = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES)

                grads = optimizer.compute_gradients(loss=loss,
                                                    var_list=train_vars)
                gradients = [(tf.placeholder(dtype=tf.float32,
                                             shape=grad[1].get_shape()),
                              grad[1]) for grad in grads]
                optimizing_op = optimizer.apply_gradients(
                    grads_and_vars=gradients, global_step=global_step)
                # optimizing_op = optimizer.minimize(loss=loss, global_step=global_step)

            with tf.control_dependencies([optimizing_op]):
                with tf.control_dependencies([exponential_moving_average_op]):
                    train_op_with_mve = tf.no_op()
            train_op = train_op_with_mve


        #################################### Training loop ####################################
        # A saver object for saving the model
        best_ckpt_saver_train = checkmate.BestCheckpointSaver(
            save_dir=ckpt_path + 'train/', num_to_keep=5)
        best_ckpt_saver_valid = checkmate.BestCheckpointSaver(
            save_dir=ckpt_path + 'valid/', num_to_keep=5)
        summary_op = tf.summary.merge_all()
        summary_op_valid = tf.summary.merge([
            yolo_loss_summary, xy_loss_summary, wh_loss_summary,
            obj_loss_summary, noobj_loss_summary, conf_loss_summary,
            class_loss_summary, scale1_loss_summary, scale2_loss_summary
        ])

        init_op = tf.global_variables_initializer()

        # Defining some train loop dependencies
        gpu_config = tf.ConfigProto(log_device_placement=False)
        gpu_config.gpu_options.allow_growth = True
        sess = tf.Session(config=gpu_config)
        tf.logging.set_verbosity(tf.logging.ERROR)
        train_summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, 'train'), sess.graph)
        val_summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, 'val'), sess.graph)

        # Restoring the model
        ckpt = tf.train.get_checkpoint_state(ckpt_path + 'valid/')
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('Restoring model ',
                  checkmate.get_best_checkpoint(ckpt_path + 'valid/'))
            tf.train.Saver().restore(
                sess, checkmate.get_best_checkpoint(ckpt_path + 'valid/'))
            print('Model Loaded!')
        elif config.pre_train is True:
            sess.run(init_op)
            load_ops = load_weights(tf.global_variables(scope='darknet53'),
                                    config.darknet53_weights_path)
            sess.run(load_ops)
        else:
            sess.run(init_op)

        print('Uninitialized variables: ',
              sess.run(tf.report_uninitialized_variables()))

        epochbar = tqdm(range(config.Epoch))
        for epoch in epochbar:
            epochbar.set_description('Epoch %s of %s' % (epoch, config.Epoch))
            mean_loss_train = []
            mean_loss_valid = []

            trainbar = tqdm(range(config.train_num // config.train_batch_size))
            for k in trainbar:
                all_grads_and_vars = []
                for minibatch in range(config.train_batch_size //
                                       config.subdivisions):
                    num_steps, train_summary, loss_train, grads_and_vars = sess.run(
                        [global_step, summary_op, loss, grads],
                        feed_dict={
                            is_training: True,
                            mode: 1
                        })

                    all_grads_and_vars += grads_and_vars

                    train_summary_writer.add_summary(train_summary, epoch)
                    train_summary_writer.flush()
                    mean_loss_train.append(loss_train)
                    trainbar.set_description('Train loss: %s' %
                                             str(loss_train))

                # Sum the gradients gathered over the micro-batches so that a single
                # accumulated update is applied
                feed_dict = {is_training: True, mode: 1}
                accumulated = [g for g, _ in all_grads_and_vars[:len(gradients)]]
                for i in range(len(gradients), len(all_grads_and_vars)):
                    accumulated[i % len(gradients)] += all_grads_and_vars[i][0]
                for i in range(len(gradients)):
                    feed_dict[gradients[i][0]] = accumulated[i]
                # print(np.shape(feed_dict))

                _ = sess.run(train_op, feed_dict=feed_dict)

            print('Validating.....')
            valbar = tqdm(range(config.val_num // config.val_batch_size))
            for k in valbar:
                for minibatch in range(config.train_batch_size //
                                       config.subdivisions):
                    val_summary, loss_valid = sess.run(
                        [summary_op_valid, loss],
                        feed_dict={
                            is_training: False,
                            mode: 0
                        })
                    val_summary_writer.add_summary(val_summary, epoch)
                    val_summary_writer.flush()
                    mean_loss_valid.append(loss_valid)
                    valbar.set_description('Validation loss: %s' %
                                           str(loss_valid))

            mean_loss_train = np.mean(mean_loss_train)
            mean_loss_valid = np.mean(mean_loss_valid)

            print('\n')
            print('Train loss after %d epochs is: %f' %
                  (epoch + 1, mean_loss_train))
            print('Validation loss after %d epochs is: %f' %
                  (epoch + 1, mean_loss_valid))
            print('\n\n')

            if (config.use_warm_up):
                if (num_steps > config.burn_in_epochs *
                    (config.train_num // config.train_batch_size)):
                    best_ckpt_saver_train.handle(mean_loss_train, sess,
                                                 global_step)
                    best_ckpt_saver_valid.handle(mean_loss_valid, sess,
                                                 global_step)
                else:
                    continue
            else:
                best_ckpt_saver_train.handle(mean_loss_train, sess,
                                             global_step)
                best_ckpt_saver_valid.handle(mean_loss_valid, sess,
                                             global_step)

        print('Tuning Completed!!')
        train_summary_writer.close()
        val_summary_writer.close()
        sess.close()
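
# The warm-up logic above picks between two schedules with tf.cond: the warm-up scheduler while
# global_step is still inside the burn-in window, and the main decayed scheduler afterwards.
# A minimal sketch of that switch, assuming TensorFlow 1.x; the constants are illustrative and
# not taken from the project's config.
def _warmup_schedule_sketch(burn_in_steps=1000):
    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()
    warmup_lr = tf.constant(1e-4, name='warmup_learning_rate')
    main_lr = tf.train.exponential_decay(1e-3, global_step,
                                         decay_steps=2000, decay_rate=0.9,
                                         staircase=True)
    # Before burn_in_steps the warm-up rate is used, afterwards the decayed rate
    return tf.cond(tf.less(global_step, burn_in_steps),
                   true_fn=lambda: warmup_lr,
                   false_fn=lambda: main_lr)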
Exemple #12
0
def train(ckpt_path, log_path, class_path):
	""" Function to train the model.
		ckpt_path: string, path for saving/restoring the model
		log_path: string, path for saving the training/validation logs
		class_path: string, path for the classes of the dataset
		decay_steps: int, steps after which the learning rate is to be decayed
		decay_rate: float, rate to carrying out exponential decay
	"""


	# Getting the anchors
	anchors = read_anchors(config.anchors_path)
	if not os.path.exists(config.data_dir):
		os.mkdir(config.data_dir)

	classes = get_classes(class_path)

	# Building the training pipeline
	graph = tf.get_default_graph()

	with graph.as_default():

		# Getting the training data
		with tf.name_scope('data_parser/'):
			train_reader = Parser('train', config.data_dir, config.anchors_path, config.output_dir, 
				config.num_classes, input_shape=config.input_shape, max_boxes=config.max_boxes)
			train_data = train_reader.build_dataset(config.train_batch_size//config.subdivisions)
			train_iterator = train_data.make_one_shot_iterator()

			val_reader = Parser('val', config.data_dir, config.anchors_path, config.output_dir, 
				config.num_classes, input_shape=config.input_shape, max_boxes=config.max_boxes)
			val_data = val_reader.build_dataset(config.val_batch_size)
			val_iterator = val_data.make_one_shot_iterator()


			is_training = tf.placeholder(dtype=tf.bool, shape=[], name='train_flag') # Used for different behaviour of batch normalization
			mode = tf.placeholder(dtype=tf.int16, shape=[], name='mode_flag')


			def train():
				return train_iterator.get_next()
			def valid():
				return val_iterator.get_next()


			images, labels = tf.cond(pred=tf.equal(mode, 1), true_fn=train, false_fn=valid, name='train_val_data')
			grid_shapes = [config.input_shape // 32, config.input_shape // 16, config.input_shape // 8]

			images.set_shape([None, config.input_shape, config.input_shape, 3])
			labels.set_shape([None, config.max_boxes, 5])  # assuming labels are padded to max_boxes entries ('required_shape' was undefined)

			# image_summary = draw_box(images, bbox, file_name)

		if not os.path.exists(ckpt_path):
			os.mkdir(ckpt_path)

		model = model(images, is_training, config.num_classes, config.num_anchors_per_scale, config.weight_decay, config.norm_decay)
		output, model_layers = model.forward()

		print('Summary of the created model.......\n')
		for layer in model_layers:
			print(layer)

		# Declaring the parameters for GT
		with tf.name_scope('Targets'):
			### GT PROCESSING ###
			pass  # Ground-truth processing omitted in this example; y_true is expected to be built here

		# Compute Loss
		with tf.name_scope('Loss_and_Detect'):
			loss_scale,summaries = compute_loss(output, y_true, config.num_classes, ignore_threshold=config.ignore_thresh)
			exponential_moving_average_op = tf.train.ExponentialMovingAverage(config.weight_decay).apply(var_list=tf.trainable_variables())
			loss = model_loss
			model_loss_summary = tf.summary.scalar('model_loss', summaries, family='Losses')


		# Declaring the parameters for training the model
		with tf.name_scope('train_parameters'):
			global_step = tf.Variable(0, trainable=False, name='global_step')

			def learning_rate_scheduler(learning_rate, scheduler_name, global_step, decay_steps=100):
				if scheduler_name == 'exponential':
					lr =  tf.train.exponential_decay(learning_rate, global_step,
						decay_steps, decay_rate, staircase=True, name='exponential_learning_rate')
					return tf.maximum(lr, config.learning_rate_lower_bound)
				elif scheduler_name == 'polynomial':
					lr =  tf.train.polynomial_decay(learning_rate, global_step,
						decay_steps, config.learning_rate_lower_bound, power=0.8, cycle=True, name='polynomial_learning_rate')
					return tf.maximum(lr, config.learning_rate_lower_bound)
				elif scheduler_name == 'cosine':
					lr = tf.train.cosine_decay(learning_rate, global_step,
						decay_steps, alpha=0.5, name='cosine_learning_rate')
					return tf.maximum(lr, config.learning_rate_lower_bound)
				elif scheduler_name == 'linear':
					return tf.convert_to_tensor(learning_rate, name='linear_learning_rate')
				else:
					raise ValueError('Unsupported learning rate scheduler\n[supported types: exponential, polynomial, cosine, linear]')


			if config.use_warm_up:
				learning_rate = tf.cond(pred=tf.less(global_step, config.burn_in_epochs * (config.train_num // config.train_batch_size)),
					true_fn=lambda: learning_rate_scheduler(config.init_learning_rate, config.warm_up_lr_scheduler, global_step),
					false_fn=lambda: learning_rate_scheduler(config.learning_rate, config.lr_scheduler, global_step, decay_steps=2000))
			else:
				learning_rate = learning_rate_scheduler(config.learning_rate, config.lr_scheduler, global_step=global_step, decay_steps=2000)

			tf.summary.scalar('learning rate', learning_rate, family='Train_Parameters')


		# Define optimizer for minimizing the computed loss
		with tf.name_scope('Optimizer'):
			optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=config.momentum)
			# optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
			# optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate, momentum=config.momentum)
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				# grads = optimizer.compute_gradients(loss=loss)
				# gradients = [(tf.placeholder(dtype=tf.float32, shape=grad[1].get_shape()), grad[1]) for grad in grads]
				# train_step = optimizer.apply_gradients(grads_and_vars=gradients, global_step=global_step)
				optimizing_op = optimizer.minimize(loss=loss, global_step=global_step)
			
			with tf.control_dependencies([optimizing_op]):
				with tf.control_dependencies([exponential_moving_average_op]):
					train_op_with_mve = tf.no_op()
			train_op = train_op_with_mve



#################################### Training loop ############################################################
		# A saver object for saving the model
		best_ckpt_saver_train = checkmate.BestCheckpointSaver(save_dir=ckpt_path+'train/', num_to_keep=5)
		best_ckpt_saver_valid = checkmate.BestCheckpointSaver(save_dir=ckpt_path+'valid/', num_to_keep=5)
		summary_op = tf.summary.merge_all()
		summary_op_valid = tf.summary.merge([model_loss_summary])
		init_op = tf.global_variables_initializer()


		
		# Defining some train loop dependencies
		gpu_config = tf.ConfigProto(log_device_placement=False)
		gpu_config.gpu_options.allow_growth = True
		sess = tf.Session(config=gpu_config)
		tf.logging.set_verbosity(tf.logging.ERROR)
		train_summary_writer = tf.summary.FileWriter(os.path.join(log_path, 'train'), sess.graph)
		val_summary_writer = tf.summary.FileWriter(os.path.join(log_path, 'val'), sess.graph)

		# print(sess.run(receptive_field))  # disabled: 'receptive_field' is not defined in this example
		
		# Restoring the model
		ckpt = tf.train.get_checkpoint_state(ckpt_path+'train/')
		if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
			print('Restoring model ', checkmate.get_best_checkpoint(ckpt_path+'train/'))
			tf.train.Saver().restore(sess, checkmate.get_best_checkpoint(ckpt_path+'train/'))
			print('Model Loaded!')
		else:
			sess.run(init_op)

		print('Uninitialized variables: ', sess.run(tf.report_uninitialized_variables()))


		epochbar = tqdm(range(config.Epoch))
		for epoch in epochbar:
			epochbar.set_description('Epoch %s of %s' % (epoch, config.Epoch))
			mean_loss_train = []
			mean_loss_valid = []

			trainbar = tqdm(range(config.train_num//config.train_batch_size))
			for k in trainbar:

				num_steps, train_summary, loss_train, _ = sess.run([global_step, summary_op, loss,
					train_op], feed_dict={is_training: True, mode: 1})

				train_summary_writer.add_summary(train_summary, epoch)
				train_summary_writer.flush()
				mean_loss_train.append(loss_train)
				trainbar.set_description('Train loss: %s' %str(loss_train))


			print('Validating.....')
			valbar = tqdm(range(config.val_num//config.val_batch_size))
			for k in valbar:
				val_summary, loss_valid = sess.run([summary_op_valid, loss], feed_dict={is_training: False, mode: 0})
				val_summary_writer.add_summary(val_summary, epoch)
				val_summary_writer.flush()
				mean_loss_valid.append(loss_valid)
				valbar.set_description('Validation loss: %s' %str(loss_valid))

			mean_loss_train = np.mean(mean_loss_train)
			mean_loss_valid = np.mean(mean_loss_valid)

			print('\n')
			print('Train loss after %d epochs is: %f' %(epoch+1, mean_loss_train))
			print('Validation loss after %d epochs is: %f' %(epoch+1, mean_loss_valid))
			print('\n\n')

			if (config.use_warm_up):
				if (num_steps > config.burn_in_epochs * (config.train_num // config.train_batch_size)):
					best_ckpt_saver_train.handle(mean_loss_train, sess, global_step)
					best_ckpt_saver_valid.handle(mean_loss_valid, sess, global_step)
				else:
					continue
			else:
				best_ckpt_saver_train.handle(mean_loss_train, sess, global_step)
				best_ckpt_saver_valid.handle(mean_loss_valid, sess, global_step)

		print('Tuning Completed!!')
		train_summary_writer.close()
		val_summary_writer.close()
		sess.close()
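
# In both YOLO examples the final train_op is a tf.no_op() that only executes once both the
# optimizer update and the ExponentialMovingAverage update have run, thanks to the nested
# control dependencies. A minimal sketch of that chaining, assuming TensorFlow 1.x; the function
# below is illustrative only.
def _ema_train_op_sketch():
	import tensorflow as tf

	w = tf.Variable(0.0)
	loss = tf.square(w - 1.0)
	global_step = tf.train.get_or_create_global_step()
	ema_op = tf.train.ExponentialMovingAverage(decay=0.99).apply(tf.trainable_variables())
	optimizing_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)
	# The no_op only runs after both the weight update and the EMA update have run
	with tf.control_dependencies([optimizing_op]):
		with tf.control_dependencies([ema_op]):
			train_op = tf.no_op(name='train_op_with_ema')
	return train_op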





def main():
	""" main function which calls all the other required functions for training """
	os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
	os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_num)
	os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Must be set before TensorFlow starts to take effect
	train(config.model_dir, config.logs_dir, config.classes_path)



if __name__ == '__main__':
	main()
Exemple #13
0
def train():
    """Training RMIDP model."""
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.word2vec_file)
    val_data = dh.load_data_and_labels(args.validation_file, args.word2vec_file)

    logger.info("Data padding...")
    train_dataset = dh.MyData(train_data, args.pad_seq_len, device)
    val_dataset = dh.MyData(val_data, args.pad_seq_len, device)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Load word2vec model
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Init network
    logger.info("Init nn...")
    net = RMIDP(args, VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix).to(device)

    print("Model's state_dict:")
    for param_tensor in net.state_dict():
        print(param_tensor, "\t", net.state_dict()[param_tensor].size())

    criterion = Loss()
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.l2_lambda)

    if OPTION == 'T':
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir, num_to_keep=args.num_checkpoints, maximize=False)
        logger.info("Writing to {0}\n".format(out_dir))
    elif OPTION == 'R':
        timestamp = input("[Input] Please input the checkpoints model you want to restore: ")
        while not (timestamp.isdigit() and len(timestamp) == 10):
            timestamp = input("[Warning] The format of your input is illegal, please re-input: ")
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir, num_to_keep=args.num_checkpoints, maximize=False)
        logger.info("Writing to {0}\n".format(out_dir))
        checkpoint = torch.load(out_dir)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    logger.info("Training...")
    writer = SummaryWriter('summary')

    def eval_model(val_loader, epoch):
        """
        Evaluate on the validation set.
        """
        net.eval()
        eval_loss = 0.0
        true_labels, predicted_scores = [], []
        for batch in val_loader:
            x_val_fb_content, x_val_fb_question, x_val_fb_option, \
            x_val_fb_clens, x_val_fb_qlens, x_val_fb_olens, y_val_fb = batch

            logits, scores = net(x_val_fb_content, x_val_fb_question, x_val_fb_option)
            avg_batch_loss = criterion(scores, y_val_fb)
            eval_loss = eval_loss + avg_batch_loss.item()
            for i in y_val_fb[0].tolist():
                true_labels.append(i)
            for j in scores[0].tolist():
                predicted_scores.append(j)

        # Calculate the Metrics
        eval_rmse = mean_squared_error(true_labels, predicted_scores) ** 0.5
        eval_r2 = r2_score(true_labels, predicted_scores)
        eval_pcc, eval_doa = dh.evaluation(true_labels, predicted_scores)
        eval_loss = eval_loss / len(val_loader)
        cur_value = eval_rmse
        logger.info("All Validation set: Loss {0:g} | PCC {1:.4f} | DOA {2:.4f} | RMSE {3:.4f} | R2 {4:.4f}"
                    .format(eval_loss, eval_pcc, eval_doa, eval_rmse, eval_r2))
        writer.add_scalar('validation loss', eval_loss, epoch)
        writer.add_scalar('validation PCC', eval_pcc, epoch)
        writer.add_scalar('validation DOA', eval_doa, epoch)
        writer.add_scalar('validation RMSE', eval_rmse, epoch)
        writer.add_scalar('validation R2', eval_r2, epoch)
        return cur_value

    for epoch in tqdm(range(args.epochs), desc="Epochs:", leave=True):
        # Training step
        batches = trange(len(train_loader), desc="Batches", leave=True)
        for batch_cnt, batch in zip(batches, train_loader):
            net.train()
            x_train_fb_content, x_train_fb_question, x_train_fb_option, \
            x_train_fb_clens, x_train_fb_qlens, x_train_fb_olens, y_train_fb = batch

            optimizer.zero_grad()   # Zero the gradients; otherwise they accumulate across backward passes
            logits, scores = net(x_train_fb_content, x_train_fb_question, x_train_fb_option)
            avg_batch_loss = criterion(scores, y_train_fb)
            avg_batch_loss.backward()
            optimizer.step()    # Parameter updating
            batches.set_description("Batches (Loss={:.4f})".format(avg_batch_loss.item()))
            logger.info('[epoch {0}, batch {1}] loss: {2:.4f}'.format(epoch + 1, batch_cnt, avg_batch_loss.item()))
            writer.add_scalar('training loss', avg_batch_loss, batch_cnt)
        # Evaluation step
        cur_value = eval_model(val_loader, epoch)
        saver.handle(cur_value, net, optimizer, epoch)
    writer.close()

    logger.info('Training Finished.')
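
# `dh.MyData` is not shown in this example; the DataLoader only needs it to behave like a
# map-style torch Dataset whose items are the already padded tensors unpacked in the loops
# above. The class below is a rough, simplified sketch of such a Dataset (an assumption about
# the helper, with made-up field names), not the original implementation.
import torch
from torch.utils.data import Dataset

class _PaddedTextDatasetSketch(Dataset):
    """Illustrative map-style dataset that right-pads token-id sequences to a fixed length."""
    def __init__(self, sequences, labels, pad_seq_len, device):
        self.sequences = sequences    # list of lists of token ids
        self.labels = labels          # list of float scores
        self.pad_seq_len = pad_seq_len
        self.device = device

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = list(self.sequences[idx])[:self.pad_seq_len]
        seq = seq + [0] * (self.pad_seq_len - len(seq))    # pad with the <PAD> id 0
        x = torch.tensor(seq, dtype=torch.long, device=self.device)
        y = torch.tensor(self.labels[idx], dtype=torch.float, device=self.device)
        return x, y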
Exemple #14
0
def train():
    """Training QuesNet model."""
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file)
    val_data = dh.load_data_and_labels(args.validation_file)

    logger.info("Data padding...")
    train_dataset = dh.MyData(train_data.activity, train_data.timestep,
                              train_data.labels)
    val_dataset = dh.MyData(val_data.activity, val_data.timestep,
                            val_data.labels)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=dh.collate_fn)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            collate_fn=dh.collate_fn)

    # Load word2vec model
    COURSE_SIZE = dh.course2vec(args.course2vec_file)

    # Init network
    logger.info("Init nn...")
    net = MOOCNet(args, COURSE_SIZE).to(device)

    # weights_init(model=net)
    # print_weight(model=net)

    print("Model's state_dict:")
    for param_tensor in net.state_dict():
        print(param_tensor, "\t", net.state_dict()[param_tensor].size())

    criterion = Loss()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.l2_lambda)

    if OPTION == 'T':
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir,
                                       num_to_keep=args.num_checkpoints,
                                       maximize=False)
        logger.info("Writing to {0}\n".format(out_dir))
    elif OPTION == 'R':
        timestamp = input(
            "[Input] Please input the checkpoints model you want to restore: ")
        while not (timestamp.isdigit() and len(timestamp) == 10):
            timestamp = input(
                "[Warning] The format of your input is illegal, please re-input: "
            )
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir,
                                       num_to_keep=args.num_checkpoints,
                                       maximize=False)
        logger.info("Writing to {0}\n".format(out_dir))
        checkpoint = torch.load(out_dir)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    logger.info("Training...")
    writer = SummaryWriter('summary')

    def eval_model(val_loader, epoch):
        """
        Evaluate on the validation set.
        """
        net.eval()
        eval_loss = 0.0
        true_labels, predicted_scores, predicted_labels = [], [], []
        for batch in val_loader:
            x_val, tsp_val, y_val = create_input_data(batch)
            logits, scores = net(x_val, tsp_val)
            avg_batch_loss = criterion(scores, y_val)
            eval_loss = eval_loss + avg_batch_loss.item()
            for i in y_val.tolist():
                true_labels.append(i)
            for j in scores.tolist():
                predicted_scores.append(j)
                if j >= args.threshold:
                    predicted_labels.append(1)
                else:
                    predicted_labels.append(0)

        # Calculate the Metrics
        eval_acc = accuracy_score(true_labels, predicted_labels)
        eval_pre = precision_score(true_labels, predicted_labels)
        eval_rec = recall_score(true_labels, predicted_labels)
        eval_F1 = f1_score(true_labels, predicted_labels)
        eval_auc = roc_auc_score(true_labels, predicted_scores)
        eval_prc = average_precision_score(true_labels, predicted_scores)
        eval_loss = eval_loss / len(val_loader)
        cur_value = eval_F1
        logger.info(
            "All Validation set: Loss {0:g} | ACC {1:.4f} | PRE {2:.4f} | REC {3:.4f} | F1 {4:.4f} | AUC {5:.4f} | PRC {6:.4f}"
            .format(eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc,
                    eval_prc))
        writer.add_scalar('validation loss', eval_loss, epoch)
        writer.add_scalar('validation ACC', eval_acc, epoch)
        writer.add_scalar('validation PRECISION', eval_pre, epoch)
        writer.add_scalar('validation RECALL', eval_rec, epoch)
        writer.add_scalar('validation F1', eval_F1, epoch)
        writer.add_scalar('validation AUC', eval_auc, epoch)
        writer.add_scalar('validation PRC', eval_prc, epoch)
        return cur_value

    for epoch in tqdm(range(args.epochs), desc="Epochs:", leave=True):
        # Training step
        batches = trange(len(train_loader), desc="Batches", leave=True)
        for batch_cnt, batch in zip(batches, train_loader):
            net.train()
            x_train, tsp_train, y_train = create_input_data(batch)
            optimizer.zero_grad()  # Zero the gradients; otherwise they accumulate across backward passes
            logits, scores = net(x_train, tsp_train)
            # TODO
            avg_batch_loss = criterion(scores, y_train)
            avg_batch_loss.backward()
            optimizer.step()  # Parameter updating
            batches.set_description("Batches (Loss={:.4f})".format(
                avg_batch_loss.item()))
            logger.info('[epoch {0}, batch {1}] loss: {2:.4f}'.format(
                epoch + 1, batch_cnt, avg_batch_loss.item()))
            writer.add_scalar('training loss', avg_batch_loss, batch_cnt)
        # Evaluation step
        cur_value = eval_model(val_loader, epoch)
        saver.handle(cur_value, net, optimizer, epoch)
    writer.close()

    logger.info('Training Finished.')
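
# The DataLoaders above pass `collate_fn=dh.collate_fn` so that variable-length activity
# sequences can be batched; the helper itself is not shown. Below is a minimal sketch of such a
# collate function built on torch.nn.utils.rnn.pad_sequence, given as an assumption about what
# the helper does rather than the original code.
import torch
from torch.nn.utils.rnn import pad_sequence

def _collate_fn_sketch(batch):
    """Pads each sequence in the batch to the longest one and stacks the labels."""
    activities, timesteps, labels = zip(*batch)
    activities = pad_sequence([torch.as_tensor(a, dtype=torch.long) for a in activities],
                              batch_first=True, padding_value=0)
    timesteps = pad_sequence([torch.as_tensor(t, dtype=torch.float) for t in timesteps],
                             batch_first=True, padding_value=0.0)
    labels = torch.as_tensor(labels, dtype=torch.float)
    return activities, timesteps, labels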