Example #1
import math

import numpy as np
import tensorflow as tf

def train(eclms, train_filenames, val_filenames, server, init, train_epochs=3, batch_size=1, keep_rate=1.):
    initial_step = 1
    train_step_per_epoch = math.ceil(len(train_filenames)/batch_size)
    val_step_per_epoch = math.ceil(len(val_filenames)/batch_size)
    VERBOSE_STEP = eclms[0].verbose_step

    # Average the towers' losses and build a single training op at
    # graph-construction time, so no new ops are created inside the loop.
    # One optimizer suffices: minimize() defaults to all trainable variables.
    total_loss = tf.reduce_mean([m.loss for m in eclms])
    train_op = eclms[0].optimizer.minimize(total_loss,
                                           global_step=eclms[0].global_step)

    saver = tf.train.Saver()
    min_validation_loss = float('inf')
    with tf.Session(server.target) as sess:
        sess.run(init)
        # minimize() adds optimizer slot variables that a pre-built `init`
        # op does not cover, so initialize them explicitly.
        sess.run(tf.variables_initializer(eclms[0].optimizer.variables()))
        writer = tf.summary.FileWriter("./graphs_1", sess.graph)
        for i in range(initial_step, initial_step + train_epochs):
            print('Epoch {:04}:'.format(i))
            train_losses = []
            val_losses = []

            for k in range(train_step_per_epoch//2):
                # Feed one minibatch to each tower, then take one combined
                # gradient step on the averaged loss.
                feed_dict = {}
                for eclm, j in zip(eclms, range(2*k, 2*k + 2)):
                    x, y, pw, y_pw = restore_data_batch(train_filenames[j*batch_size : (j + 1)*batch_size])
                    feed_dict.update({eclm.x: x, eclm.y: y, eclm.pw: pw, eclm.y_pw: y_pw,
                                      eclm.is_training: True, eclm.keep_rate: keep_rate})
                _, loss = sess.run([train_op, total_loss], feed_dict=feed_dict)

                if k % VERBOSE_STEP == 0:
                    print('     train_step {} - train_loss = {:0.7f}'.format(k, loss))
                train_losses.append(loss)
            train_losses = np.asarray(train_losses)
            avg_train_loss = np.mean(train_losses)
            print('Average Train Loss: {:0.7f}'.format(avg_train_loss))
            summary = tf.Summary()
            summary.value.add(tag="train_loss", simple_value=avg_train_loss)

            # Validate with a single tower; otherwise `eclm` would be a
            # leftover loop variable from the training loop above.
            eclm = eclms[0]
            for j in range(val_step_per_epoch):
                x, y, pw, y_pw = restore_data_batch(val_filenames[j*batch_size : (j + 1)*batch_size])
                loss = sess.run(eclm.loss,
                                feed_dict={eclm.x: x, eclm.y: y, eclm.pw: pw, eclm.y_pw: y_pw,
                                           eclm.is_training: False, eclm.keep_rate: keep_rate})
                if j % VERBOSE_STEP == 0:
                    print('     val_step {} - val_loss = {:0.7f}'.format(j, loss))
                val_losses.append(loss)
            val_losses = np.asarray(val_losses)
            avg_val_loss = np.mean(val_losses)
            if avg_val_loss < min_validation_loss:
                min_validation_loss = avg_val_loss
                saver.save(sess, "./checkpoint/best_model", global_step=i)
            print('Average Val Loss: {:0.7f}'.format(avg_val_loss))
            summary.value.add(tag="val_loss", simple_value=avg_val_loss)

            writer.add_summary(summary, global_step=i)
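
A minimal usage sketch for train follows. The ECLM constructor, the globbed
file lists, and the single-process tf.train.Server are illustrative
assumptions, not part of the example above; only train's own signature is.

import glob

import tensorflow as tf

# Two model towers; the ECLM class and its arguments are hypothetical here.
eclms = [ECLM(name='tower_{}'.format(t)) for t in range(2)]

# Single-process stand-in for a distributed server.
server = tf.train.Server.create_local_server()
init = tf.global_variables_initializer()

train_files = sorted(glob.glob('./data/train/*.dat'))  # assumed data layout
val_files = sorted(glob.glob('./data/val/*.dat'))

train(eclms, train_files, val_files, server, init,
      train_epochs=3, batch_size=1, keep_rate=0.8)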
Example #2
import math
import os

import numpy as np
import tensorflow as tf

def inference_all(sess, eclm, test_filenames, batch_size=1, keep_rate=1.0, inference_dir=None):
    if inference_dir:
        if not os.path.exists(inference_dir):
            os.makedirs(inference_dir)

    n_tests = len(test_filenames)
    steps = math.ceil(n_tests/batch_size)
    test_losses = []
    VERBOSE_STEP = eclm.verbose_step

    for i in range(steps):
        x, y, pw, y_pw = restore_data_batch(test_filenames[i*batch_size: (i + 1)*batch_size])
        loss, y_hat = sess.run([eclm.loss, eclm.y_hat],
                               feed_dict={eclm.x: x, eclm.y: y, eclm.pw: pw, eclm.y_pw: y_pw,
                                          eclm.is_training: False, eclm.keep_rate: keep_rate})
        if inference_dir:
            cache_data(y_hat, os.path.join(inference_dir, '{}.dat'.format(i)))
        if i % VERBOSE_STEP == 0:
            print('     test_{} - test_loss = {:0.7f}'.format(i, loss))
        test_losses.append(loss)
    test_losses = np.asarray(test_losses)
    avg_test_loss = np.mean(test_losses)
    print('Average Test Loss: {:0.7f}'.format(avg_test_loss))
    if inference_dir:
        cache_data(test_losses, os.path.join(inference_dir, 'loss.dat'))
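
A matching sketch for running inference from the best checkpoint saved in
Example #1. The ECLM construction is again a hypothetical placeholder; the
checkpoint directory matches the saver.save call in the training loop.

import glob

import tensorflow as tf

eclm = ECLM(name='tower_0')  # hypothetical constructor
saver = tf.train.Saver()

with tf.Session() as sess:
    # Restore the most recent "best_model" checkpoint written by train().
    saver.restore(sess, tf.train.latest_checkpoint('./checkpoint'))
    inference_all(sess, eclm, sorted(glob.glob('./data/test/*.dat')),
                  batch_size=1, keep_rate=1.0, inference_dir='./inference')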