Example #1
    if config.MODEL_TYPE == "treeconvlstm":         # Tree ConvLSTM model (guard assumed; the source snippet begins mid-chain).
        model = TreeConvlstm(params)

    elif config.MODEL_TYPE == "convlstm":           # Baseline: a simple ConvLSTM model.
        model = ConvLSTM(params["cnn_params"])

    elif config.MODEL_TYPE == "graphlstm":          # Graph-based LSTM model (SOTA).
        params = config.GraphLSTMParams().__dict__
        model = GraphLSTM(params)

    plt.ion()                                       # Plotting object for the validation AP.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.draw()

    for e in range(config.EPOCHS):                               # For every epoch...
        x, w, y = batch_generator.get_batch("train")    # Read training data.
        loss = model.fit(x, w, y)                       # Perform update on the training data. Also compute loss.
        print("EPOCH: {}".format(e))                    # Print the epoch number.
        print("Loss: " + str(float(loss)))              # Print the loss value for the current training epoch.
        print("")

        x, w, y = batch_generator.get_batch("validation")   # Get the validation set.
        y_hat = model.predict(x, w)                         # Compute the predictions for the validation set.
        pr_curve(y_hat, y, ax)                              # Compute the AP score on the validation set.

    x, w, y = batch_generator.get_batch("test")   # Get the test set.
    y_hat = model.predict(x, w)                   # Compute the predictions for the test set.
    pr_curve(y_hat, y, ax)                        # Compute the AP score on the test set.

    plt.ioff()
    plt.show()
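
Example #1 assumes a batch_generator whose get_batch(split) returns an (inputs, weights, labels) triple for the requested split. A minimal sketch of that interface (the class layout, shapes, and random-sampling policy below are assumptions, not the original implementation):

import numpy as np

class BatchGenerator:
    # Sketch only: maps split names ("train"/"validation"/"test") to arrays.
    def __init__(self, splits, batch_size=32):
        self.splits = splits          # e.g. {"train": (x, w, y), ...}
        self.batch_size = batch_size

    def get_batch(self, split):
        # Sample a random mini-batch from the requested split.
        x, w, y = self.splits[split]
        idx = np.random.randint(0, len(x), size=self.batch_size)
        return x[idx], w[idx], y[idx]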
Example #2
import os

import tensorflow as tf


# FLAGS, BatchGenerator, and ModelManager are defined in the surrounding project.
class Classifier:
    def __init__(self):
        self._build()
        self.generator = BatchGenerator()

    def _build(self):
        self.label_placeholder = tf.placeholder(dtype=tf.int32,
                                                shape=[None, 10])
        self.dropout_rate = tf.placeholder_with_default(1.,
                                                        shape=None,
                                                        name='dropout_rate_ph')
        self.learning_rate = tf.placeholder_with_default(0.0001,
                                                         shape=None,
                                                         name='learning_rate')
        self.avr_loss = tf.placeholder(dtype=tf.float32,
                                       shape=None,
                                       name='avr_loss')
        self.avr_accuracy = tf.placeholder(dtype=tf.float32,
                                           shape=None,
                                           name='avr_accuracy')
        self.model_manager = ModelManager(self.dropout_rate)

        self.input_placeholder, self.output = self.model_manager.models[
            FLAGS.model]()
        #self.output = tf.nn.softmax(output_kernel)

        self.global_step = tf.Variable(0, trainable=False)

        self.loss = tf.losses.softmax_cross_entropy(self.label_placeholder,
                                                    self.output)

        _, self.accuracy = tf.metrics.accuracy(
            tf.argmax(self.label_placeholder, axis=1),
            tf.argmax(self.output, axis=1))

        self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(
            self.loss, global_step=self.global_step)

        with tf.name_scope('tensorboard_scalars'):
            self.scalar_step = tf.summary.scalar('step', self.global_step)

            scalar_loss = tf.summary.scalar('loss', self.loss)
            scalar_accuracy = tf.summary.scalar('accuracy', self.accuracy)

            self.scalars_irl = tf.summary.merge([scalar_loss, scalar_accuracy])

            scalar_avr_loss = tf.summary.scalar('scalar_avr_loss',
                                                self.avr_loss)
            scalar_avr_accuracy = tf.summary.scalar('scalar_avr_accuracy',
                                                    self.avr_accuracy)

            self.scalar_avr = tf.summary.merge(
                [scalar_avr_loss, scalar_avr_accuracy])

    def train(self):

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            self.saver = tf.train.Saver()
            train_writer, valid_writer, avr_train_writer, avr_valid_writer = self._get_writers(
                sess)
            try:
                lr = FLAGS.lr
                for epoch in range(FLAGS.epochs):

                    avr_accuracy = []
                    avr_loss = []
                    val_avr_accuracy = []
                    val_avr_loss = []

                    for i in range(self.generator.get_iters_count()):
                        _, accuracy, loss, summary_str = self._train_on_batch(
                            sess, [
                                self.train_step, self.accuracy, self.loss,
                                self.scalars_irl
                            ], FLAGS.dropout, lr)
                        train_writer.add_summary(
                            summary_str,
                            i + epoch * self.generator.get_iters_count())

                        #print('acc:', accuracy, 'loss:', loss)
                        avr_accuracy.append(accuracy)
                        avr_loss.append(loss)
                    train_steps_count = len(avr_loss)

                    train_avr = sess.run(
                        [self.scalar_avr],
                        feed_dict={
                            self.avr_accuracy:
                            sum(avr_accuracy) / train_steps_count,
                            self.avr_loss: sum(avr_loss) / train_steps_count,
                        })[0]
                    avr_train_writer.add_summary(train_avr, epoch)

                    for i in range(self.generator.get_val_iters_count()):
                        accuracy, loss, summary_str = self._train_on_batch(
                            sess, [self.accuracy, self.loss, self.scalars_irl],
                            1, lr, 'valid')
                        print('validation acc:', accuracy, 'loss:', loss)
                        valid_writer.add_summary(
                            summary_str,
                            i + epoch * self.generator.get_val_iters_count())
                        val_avr_accuracy.append(accuracy)
                        val_avr_loss.append(loss)
                    val_steps_count = len(val_avr_loss)
                    val_avr = sess.run(
                        [self.scalar_avr],
                        feed_dict={
                            self.avr_accuracy:
                            sum(val_avr_accuracy) / val_steps_count,
                            self.avr_loss: sum(val_avr_loss) / val_steps_count,
                        })[0]
                    avr_valid_writer.add_summary(val_avr, epoch)
                    lr *= FLAGS.decay
                    self._save(sess, epoch)
            except KeyboardInterrupt:
                self._save(sess, 10000)

    def _get_writers(self, sess):
        train_writer = tf.summary.FileWriter(
            'logs/{}_{}_{}/train'.format(self.__class__.__name__,
                                         FLAGS.batch_size, FLAGS.info),
            sess.graph)
        valid_writer = tf.summary.FileWriter('logs/{}_{}_{}/val'.format(
            self.__class__.__name__, FLAGS.batch_size, FLAGS.info))
        avr_train_writer = tf.summary.FileWriter(
            'logs/{}_{}_{}/avr_train'.format(self.__class__.__name__,
                                             FLAGS.batch_size, FLAGS.info))
        avr_valid_writer = tf.summary.FileWriter(
            'logs/{}_{}_{}/avr_val'.format(self.__class__.__name__,
                                           FLAGS.batch_size, FLAGS.info))
        return train_writer, valid_writer, avr_train_writer, avr_valid_writer

    def _train_on_batch(self, sess, fetches, dropout, lr, mode='train'):
        try:
            batch_x, batch_y = self.generator.get_batch(mode)
        except Exception:
            # Retry once; this assumes the generator recovers (e.g. wraps around).
            print(f'Problem with batch in {mode} part.')
            batch_x, batch_y = self.generator.get_batch(mode)

        result = sess.run(fetches,
                          feed_dict={
                              self.input_placeholder: batch_x,
                              self.label_placeholder: batch_y,
                              self.dropout_rate: dropout,
                              self.learning_rate: lr
                          })
        return result

    def _save(self, sess, epoch):
        path = "sessions/{}_{}_{}/graph.ckpt".format(self.__class__.__name__,
                                                     FLAGS.batch_size,
                                                     FLAGS.info)
        # Create the session directory, not a directory named graph.ckpt.
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))
        save_path = self.saver.save(sess, path, epoch)
        print("Model saved in path: %s" % save_path)
Example #3
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np


# RNNModel, BatchGenerator, get_model_file, idx_arr_to_str, perform_validation_run
# and the *_DIR / *_INTERVAL constants are defined in the surrounding module.
def train_rnn(training_articles, testing_articles, n_epochs, batch_size,
              seq_length, char_skip, dropout_pkeep, force_retrain):
    print "[ INFO] Parsing training articles..."
    training_batch_generator = BatchGenerator(training_articles, batch_size,
                                              seq_length, char_skip)

    print "[ INFO] Parsing validation articles..."
    validation_batch_generator = BatchGenerator(testing_articles, batch_size,
                                                seq_length, char_skip)

    model_file = get_model_file()
    if model_file and not force_retrain:
        rnn_model = RNNModel.load_from_model_file(model_file)
        state_file = os.path.join(MODEL_SAVE_DIR, 'saved-vars.npz')
        if not os.path.exists(state_file):
            raise IOError("Numpy state file does not exist")
        saved_vars = np.load(state_file)
        istate = saved_vars['cell-state']
        training_batch_generator.restore_state_dict(**saved_vars)
        print "[ INFO] Resuming training from epoch %d, global step %d" % (
            training_batch_generator.n_epochs, rnn_model.training_step_num)
    else:
        print "[ INFO] Initializing RNN"
        rnn_model = RNNModel(max_seq_length=seq_length)
        rnn_model.init_network()
        istate = np.zeros(shape=(rnn_model.n_layers, 2, batch_size,
                                 rnn_model.cell_size))

    log_dir = os.path.join(
        LOG_DIR,
        'training_%s' % datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"))
    os.makedirs(log_dir)
    log_file = open(os.path.join(log_dir, 'log.txt'), 'w')

    validation_accuracies = list()
    validation_losses = list()
    validation_steps = list()

    while training_batch_generator.n_epochs < n_epochs:
        batch, labels, seq_length_arr, istate = training_batch_generator.get_batch(
            istate)
        pred, ostate, acc = rnn_model.process_training_batch(
            batch, labels, seq_length_arr, istate, dropout_pkeep)

        if rnn_model.training_step_num % DISPLAY_INTERVAL == 0:
            print "[ INFO] Accuracy at step %d (epoch %d): %.3f" % (
                rnn_model.training_step_num,
                training_batch_generator.n_epochs + 1, acc)
            print "[ INFO] Prediction of first sample in minibatch: %s" % idx_arr_to_str(
                pred[0])

        if rnn_model.training_step_num % TEXT_PREDICTION_LOG_INTERVAL == 0:
            log_file.write("Text prediction at step %d:\n" %
                           rnn_model.training_step_num)
            for i in range(batch_size):
                log_file.write(idx_arr_to_str(pred[i]) + '\n')
            log_file.write(
                "-----------------------------------------------------\n")

        if rnn_model.training_step_num % MODEL_SAVE_INTERVAL == 0:
            print "[ INFO] Saving model..."
            rnn_model.tf_saver.save(rnn_model.session,
                                    os.path.join(MODEL_SAVE_DIR, MODEL_PREFIX),
                                    global_step=rnn_model.training_step_num)

            # also save the cell state and counters of the BatchGenerator
            vars_to_store = training_batch_generator.get_state_dict()
            vars_to_store.update({'cell-state': ostate})
            np.savez(os.path.join(MODEL_SAVE_DIR, 'saved-vars.npz'),
                     **vars_to_store)

        if rnn_model.training_step_num % VALIDATION_INTERVAL == 0:
            print "[ INFO] Starting validation run"
            avg_loss, avg_accuracy = perform_validation_run(
                rnn_model, validation_batch_generator)
            validation_steps.append(rnn_model.training_step_num)
            validation_accuracies.append(avg_accuracy)
            validation_losses.append(avg_loss)

            plt.plot(validation_steps, validation_accuracies, label='accuracy')
            plt.plot(validation_steps, validation_losses, label='loss')

            plt.xlabel('Training Step')
            plt.yticks(np.arange(0., 1.05, 0.05))
            plt.legend(loc='upper left')
            plt.grid(True)
            plt.savefig(
                os.path.join(log_dir, 'validation_loss-accuracy-plot.png'))
            plt.close()

        istate = ostate

    log_file.close()
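
train_rnn threads the recurrent cell state through get_batch(istate) and checkpoints it next to the model, so training can resume mid-corpus. A minimal sketch of that stateful contract (class name, fields, and batching policy here are assumptions):

import numpy as np

class StatefulBatchGenerator(object):
    def __init__(self, data, batch_size, seq_length):
        self.data = data                  # 1-D array of character ids
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.pos = 0
        self.n_epochs = 0

    def get_batch(self, istate):
        # Return (inputs, targets, lengths, state); the caller feeds the
        # returned state back in so the RNN state tracks the text stream.
        span = self.batch_size * self.seq_length
        if self.pos + span + 1 > len(self.data):
            self.pos, self.n_epochs = 0, self.n_epochs + 1
        chunk = self.data[self.pos:self.pos + span + 1]
        self.pos += span
        batch = chunk[:-1].reshape(self.batch_size, self.seq_length)
        labels = chunk[1:].reshape(self.batch_size, self.seq_length)
        lengths = np.full(self.batch_size, self.seq_length, dtype=np.int32)
        return batch, labels, lengths, istate

    def get_state_dict(self):
        # Everything needed to resume iteration after a restart.
        return {'pos': self.pos, 'n_epochs': self.n_epochs}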
Example #4

                                     loss,
                                     transformer_net.trainable_weights,
                                     lr=args.lr,
                                     dec=args.lr_decay))
    get_loss = theano.function([], loss)

    # Run the optimization loop.
    train_losses, val_losses = [], []
    with tqdm(desc="Training",
              file=sys.stdout,
              ncols=100,
              total=args.train_iterations,
              ascii=False,
              unit="iteration") as trbar:
        for tri in range(args.train_iterations):
            X.set_value(train_batch_generator.get_batch(), borrow=True)
            loss = optim_step().item()
            train_losses.append(loss)
            trbar.set_description("Training (loss {:.3g})".format(loss))
            trbar.update(1)

            if (tri + 1) % args.val_every == 0 or (tri +
                                                   1) == args.train_iterations:
                batch_val_losses = []
                n_val = 0
                with tqdm(desc="Validating",
                          file=sys.stdout,
                          ncols=100,
                          total=args.val_iterations,
                          ascii=False,
                          unit="iteration",