Example #1
def train(num_training_iterations, report_interval):
	"""Trains the DNC and periodically reports the loss."""

	dataset = repeat_copy.RepeatCopy(
			FLAGS.num_bits, FLAGS.batch_size,
			FLAGS.min_length, FLAGS.max_length,
			FLAGS.min_repeats, FLAGS.max_repeats)
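	# dataset() returns a DatasetTensors namedtuple (observations, target,
	# mask); each sess.run of these tensors samples a fresh random batch.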
	dataset_tensors = dataset()

	output_logits = run_model(dataset_tensors.observations, dataset.target_size)
	# Used for visualization.
	output = tf.round(
			tf.expand_dims(dataset_tensors.mask, -1) * tf.sigmoid(output_logits))

	train_loss = dataset.cost(output_logits, dataset_tensors.target, dataset_tensors.mask)

	# Set up optimizer with global norm clipping.
	trainable_variables = tf.trainable_variables()
	grads, _ = tf.clip_by_global_norm(
			tf.gradients(train_loss, trainable_variables), FLAGS.max_grad_norm)

	global_step = tf.get_variable(
			name="global_step",
			shape=[],
			dtype=tf.int64,
			initializer=tf.zeros_initializer(),
			trainable=False,
			collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP])
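	# Registering the variable in the GLOBAL_STEP collection lets the
	# checkpoint hooks and the monitored session find the step counter.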

	optimizer = tf.train.RMSPropOptimizer(
			FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon)
	train_step = optimizer.apply_gradients(
			zip(grads, trainable_variables), global_step=global_step)

	saver = tf.train.Saver()

	if FLAGS.checkpoint_interval > 0:
		hooks = [
				tf.train.CheckpointSaverHook(
						checkpoint_dir=FLAGS.checkpoint_dir,
						save_steps=FLAGS.checkpoint_interval,
						saver=saver)
		]
	else:
		hooks = []

	# Train.
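	# SingularMonitoredSession restores the newest checkpoint found in
	# checkpoint_dir (if any) and runs the registered hooks.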
	with tf.train.SingularMonitoredSession(
			hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir) as sess:

		start_iteration = sess.run(global_step)
		total_loss = 0

		for train_iteration in range(start_iteration, num_training_iterations):
			_, loss = sess.run([train_step, train_loss])
			total_loss += loss

			if (train_iteration + 1) % report_interval == 0:
				dataset_tensors_np, output_np = sess.run([dataset_tensors, output])
				dataset_string = dataset.to_human_readable(dataset_tensors_np, output_np)
				tf.logging.info("%d: Avg training loss %f.\n%s", train_iteration, total_loss / report_interval, dataset_string)
				total_loss = 0
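All three examples assume the module-level imports and command-line flags of DeepMind's dnc repository (train.py), plus a run_model helper that unrolls the DNC core over the observations. A minimal preamble sketch, with illustrative defaults rather than values taken from these snippets:

# Assumed preamble for these snippets; flag defaults are illustrative.
import sys

import tensorflow as tf

from dnc import repeat_copy

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer("num_bits", 4, "Dimensionality of each vector to copy.")
tf.flags.DEFINE_integer("batch_size", 16, "Batch size for training.")
tf.flags.DEFINE_integer("min_length", 1, "Lower limit on number of vectors per pattern.")
tf.flags.DEFINE_integer("max_length", 2, "Upper limit on number of vectors per pattern.")
tf.flags.DEFINE_integer("min_repeats", 1, "Lower limit on number of copy repeats.")
tf.flags.DEFINE_integer("max_repeats", 2, "Upper limit on number of copy repeats.")
tf.flags.DEFINE_float("max_grad_norm", 50, "Gradient clipping norm limit.")
tf.flags.DEFINE_float("learning_rate", 1e-4, "Optimizer learning rate.")
tf.flags.DEFINE_float("optimizer_epsilon", 1e-10, "Epsilon used by the optimizer.")
tf.flags.DEFINE_string("checkpoint_dir", "/tmp/tf/dnc", "Checkpointing directory.")
tf.flags.DEFINE_integer("checkpoint_interval", -1, "Checkpointing step interval.")
tf.flags.DEFINE_integer("num_training_iterations", 100000, "Number of training iterations.")
tf.flags.DEFINE_integer("report_interval", 100, "Iterations between loss reports.")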
Example #2
def main(unused_argv):
    dataset = repeat_copy.RepeatCopy(FLAGS.num_bits, FLAGS.batch_size,
                                     FLAGS.min_length, FLAGS.max_length,
                                     FLAGS.min_repeats, FLAGS.max_repeats)

    input_size = FLAGS.num_bits + 2
    output_size = FLAGS.num_bits + 1
    batch_size = FLAGS.batch_size
    wrapper = DNCWrapper(input_size, output_size, batch_size, dataset.cost)
    tf.logging.set_verbosity(tf.logging.INFO)  # Print INFO log messages.
    # Build the tensors for one batch of observations.
    # TODO: support feeding a single observation.
    dataset_tensors = dataset()

    # Train.
    with tf.Session() as sess:
        wrapper.compile(sess,
                        save_dir="./../../saves/simple/repeat_copy/",
                        save_name="repeat_copy")
        try:
            wrapper.restore()
        except Exception:
            # Restoring failed (e.g., no checkpoint yet); train from scratch.
            pass

        total_loss = 0
        for i in range(1, FLAGS.num_training_iterations + 1):
            dt = sess.run(dataset_tensors)
            loss, prediction, output = wrapper.train(dt.observations,
                                                     dt.target, dt.mask)

            total_loss += loss
            if i % FLAGS.report_interval == 0:
                dataset_string = ""  #dataset.to_human_readable(dataset_tensors_np, output_np)
                tf.logging.info("%d: Avg training loss %f.\n%s", i,
                                total_loss / FLAGS.report_interval,
                                dataset_string)
                total_loss = 0

        wrapper.save()
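DNCWrapper here is project-specific, not part of the dnc library. The skeleton below is a sketch of the interface the call sites imply (compile, restore, train, save); every name and signature in it is an assumption inferred from this example.

# Hypothetical skeleton of DNCWrapper, inferred from the call sites above.
import os

import tensorflow as tf


class DNCWrapper(object):
    """Bundles a DNC model, its loss, and checkpointing behind one object."""

    def __init__(self, input_size, output_size, batch_size, cost_fn):
        # The real class would build the DNC core, input placeholders, a
        # cost_fn-based loss, and a clipped-gradient train op here.
        self._cost_fn = cost_fn

    def compile(self, sess, save_dir, save_name):
        # Bind a session, prepare a Saver, and initialize variables.
        self._sess = sess
        self._saver = tf.train.Saver()
        self._save_path = os.path.join(save_dir, save_name)
        sess.run(tf.global_variables_initializer())

    def restore(self):
        # Raises if no checkpoint exists, hence the try/except at the call site.
        self._saver.restore(self._sess, self._save_path)

    def train(self, observations, target, mask):
        # One optimization step; returns (loss, prediction, output).
        raise NotImplementedError("Model construction is omitted in this sketch.")

    def save(self):
        self._saver.save(self._sess, self._save_path)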
Example #3
def train(num_training_iterations, report_interval):
    # from DNC-master/train.py
    dataset = repeat_copy.RepeatCopy(FLAGS.num_bits,
                                     FLAGS.batch_size,
                                     FLAGS.min_length,
                                     FLAGS.max_length,
                                     FLAGS.min_repeats,
                                     FLAGS.max_repeats,
                                     time_average_cost=True)
    dataset_tensors = dataset()
    data_length = tf.shape(dataset_tensors.observations)[0]
    # Zero-pad each batch along the time axis to a fixed length so the model
    # always sees constant shapes; data_max_length is assumed to be defined at
    # module level as an upper bound on the sequence length.
    zeros_obs = tf.zeros(
        [data_max_length - data_length, FLAGS.batch_size, FLAGS.num_bits + 2])
    obs_new = tf.concat([dataset_tensors.observations, zeros_obs], 0)
    output_logits, ins_L2_norm, ins_sequence, ins_prob = run_model(
        obs_new, dataset.target_size)
    # Slice the logits back to the true sequence length before the loss.
    output_logits = output_logits[:data_length, :, :]
    # Used for visualization.
    output = tf.round(
        tf.expand_dims(dataset_tensors.mask, -1) * tf.sigmoid(output_logits))

    train_loss = dataset.cost(output_logits, dataset_tensors.target,
                              dataset_tensors.mask)
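    # Apply the ins_L2_norm penalty only once the task loss has dropped below
    # a hand-picked threshold of 10.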
    train_loss += tf.where(train_loss < 10, 1., 0.) * ins_L2_norm
    # Set up optimizer with global norm clipping.
    trainable_variables = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(
        tf.gradients(train_loss, trainable_variables), FLAGS.max_grad_norm)
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.int64,
        initializer=tf.zeros_initializer(),
        trainable=False,
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP])
    # optimizer = tf.train.RMSPropOptimizer(
    #     FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon)
    optimizer = tf.train.AdamOptimizer()
    train_step = optimizer.apply_gradients(zip(grads, trainable_variables),
                                           global_step=global_step)
    hooks = []
    # scalar
    tf.summary.scalar('loss', train_loss)
    hooks.append(
        tf.train.SummarySaverHook(save_steps=5,
                                  output_dir=FLAGS.checkpoint_dir + "/logs",
                                  summary_op=tf.summary.merge_all()))
    # saver
    saver = tf.train.Saver(max_to_keep=50)
    if FLAGS.checkpoint_interval > 0:
        hooks.append(
            tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                         save_steps=FLAGS.checkpoint_interval,
                                         saver=saver))
    # Train.
    with tf.train.SingularMonitoredSession(
            hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir) as sess:
        start_iteration = sess.run(global_step)
        total_loss = 0

        for train_iteration in range(start_iteration, num_training_iterations):
            _, loss = sess.run([train_step, train_loss])
            total_loss += loss

            if (train_iteration + 1) % report_interval == 0:
                dataset_tensors_np, output_np, ins_sequence_np, ins_prob_np = sess.run(
                    [dataset_tensors, output, ins_sequence, ins_prob])
                dataset_string = dataset.to_human_readable(
                    dataset_tensors_np, output_np)
                print("%d: Avg training loss %f.  %s\n  %s\n  %s" %
                      (train_iteration, total_loss / report_interval,
                       dataset_string, ins_sequence_np, ins_prob_np))
                sys.stdout.flush()
                total_loss = 0
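
The distinctive step in Example #3 is the shape bookkeeping: the variable-length batch is zero-padded along the time axis up to a fixed data_max_length before run_model, and the logits are sliced back to the true length before the loss. A standalone sketch of just that pattern, with made-up shapes and an identity stand-in for the model:

# Isolated pad-then-slice sketch; shapes and the identity "model" are made up.
import tensorflow as tf

data_max_length = 10  # fixed graph-time maximum, assumed known up front
batch_size, num_bits = 2, 4

obs = tf.random_uniform([7, batch_size, num_bits + 2])  # time-major batch
data_length = tf.shape(obs)[0]

# Pad the time axis with zeros up to the fixed maximum...
zeros_obs = tf.zeros([data_max_length - data_length, batch_size, num_bits + 2])
obs_new = tf.concat([obs, zeros_obs], 0)  # shape [10, 2, 6]

logits = obs_new  # stand-in for run_model(obs_new, ...)
# ...then slice the outputs back to the true length before computing the loss.
logits = logits[:data_length, :, :]  # shape [7, 2, 6]

with tf.Session() as sess:
    print(sess.run(tf.shape(logits)))  # [7 2 6]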