Example no. 1
 def train_loop(self, training_op: tf.Operation, summary_op: tf.Operation):
     self.sum_writer.add_session_log(
         tf.SessionLog(status=tf.SessionLog.START))
     with self.sess.as_default():
         # initialization
         coord = tf.train.Coordinator(
         )  # start coordinator and threads for queues
         threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
         # training loop
         for step in range(self.max_steps):
             i = self.step.eval()  # the current step
             # evaluate
             if i % self.config.eval_freq == 0:
                 self.evaluate()
                 # run training and summary
             _, summary = self.sess.run([training_op, summary_op])
             # sess.run(training_op)
             if i % self.config.sum_save_freq == 0:
                 self.sum_writer.add_summary(summary, i)
             # log every so often
             if step % self.config.log_freq == 0:
                 self.log()
                 # print("Step ", i, ": ", last_eval_accuracy.eval())
                 self.save_checkpoint()
         coord.request_stop()
         coord.join(threads, stop_grace_period_secs=60)
         self.sum_writer.add_session_log(
             tf.SessionLog(status=tf.SessionLog.STOP))
Example no. 2
def reload_checkpoint_if_exists(sess, saver, train_writer, validation_writer,
                                test_writer):
    """
    restore existing model from checkpoint data
    """
    global_step = -1
    if FLAGS.continue_run:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # extract global_step from it.
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("checkpoint found at step %d", global_step)
            # ensure that the writers ignore saved summaries that occurred after the last checkpoint but before a crash
            train_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START), global_step)
            validation_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START), global_step)
            test_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START), global_step)
        else:
            print('No checkpoint file found')
    return global_step
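
The comment in the example above captures the idea behind most of these snippets: re-emitting a SessionLog.START event at the restored global step tells event-file readers such as TensorBoard to discard any summaries a previous, crashed run wrote after that step. A minimal sketch of the same resume pattern, using a hypothetical helper with illustrative log_dir / checkpoint_dir arguments (TF 1.x API):

import tensorflow as tf

def resume_summary_writer(sess, saver, log_dir, checkpoint_dir):
    """Restore the latest checkpoint and mark the restart in the event file."""
    writer = tf.summary.FileWriter(log_dir)
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    step = 0
    if ckpt_path is not None:
        saver.restore(sess, ckpt_path)
        # checkpoint filenames conventionally end in '-<global_step>'
        step = int(ckpt_path.split('-')[-1])
        # summaries written past `step` by the previous run are now stale
        writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), global_step=step)
    return writer, step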
Example no. 3
    def load(self, step=0):

        print(sys.path)

        # checkpoint_dir = '/Users/cc/Project/Lean/Launcher/bin/Debug/python/oracle/data/'
        checkpoint_dir = './data'

        try:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            self.learn_step_counter = int(
                os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        except Exception:  # no checkpoint found or unparsable checkpoint path
            ckpt = None

        if not (ckpt and ckpt.model_checkpoint_path):
            print('Cannot find any saved sess in checkpoint_dir')
            #sys.exit(2)
        else:
            try:
                # self.saver = tf.train.Saver()
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
                self.summary_writer.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=step)
                print('Sess restored successfully: {}'.format(
                    ckpt.model_checkpoint_path))
            except Exception as e:
                print('Failed to load sess: {}'.format(str(e)))
                # sys.exit(2)
                self.learn_step_counter = 1
Example no. 4
    def after_run(self, run_context, run_values):
        """Add summaries/profiling if requested."""
        _ = run_context
        if not self._summary_writer:
            return

        stale_global_step = run_values.results["global_step"]
        global_step = stale_global_step + 1
        if self._next_step is None or self._request_summary or \
                self._request_profile:
            global_step = run_context.session.run(self._global_step_tensor)

        if self._next_step is None:
            self._summary_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START), global_step)
        if self._request_summary:
            self._timer.update_last_triggered_step(global_step)
            if "summary" in run_values.results:
                for summary in run_values.results["summary"]:
                    self._summary_writer.add_summary(summary, global_step)

        if self._request_profile:
            self._profile_timer.update_last_triggered_step(global_step)
            self._summary_writer.add_run_metadata(run_values.run_metadata,
                                                  "step{}".format(global_step),
                                                  global_step=global_step)
            print("Added profiling for step {}.".format(global_step))
        self._next_step = global_step + 1
Example no. 5
    def initialize_tf_variables(self):
        """
        Initialize tensorflow variables (either initializes them from scratch or restores from checkpoint).
        
        :return: updated TeLL session
        """
        session = self.tf_session
        checkpoint = self.workspace.get_checkpoint()
        #
        # Initialize or load variables
        #
        with Timer(name="Initializing variables"):
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())

        if checkpoint is not None:
            # restore from checkpoint
            self.tf_saver.restore(session, checkpoint)
            # get step number from checkpoint
            step = session.run(self.__global_step_placeholder) + 1
            self.global_step = step
            # reopen summaries
            for _, summary in self.tf_summaries.items():
                summary.reopen()
                summary.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=step)
            print("Resuming from checkpoint '{}' at iteration {}".format(
                checkpoint, step))
        else:
            for _, summary in self.tf_summaries.items():
                summary.add_graph(session.graph)

        return self
Example no. 6
    def testSessionLogStartMessageDiscardsExpiredEvents(self):
        """Test that SessionLog.START message discards expired events.

    This discard logic is preferred over the out-of-order step discard logic,
    but this logic can only be used for event protos which have the SessionLog
    enum, which was introduced to event.proto for file_version >= brain.Event:2.
    """
        gen = _EventGenerator(self)
        acc = ea.EventAccumulator(gen)
        gen.AddEvent(
            tf.Event(wall_time=0, step=1, file_version='brain.Event:2'))

        gen.AddScalar('s1', wall_time=1, step=100, value=20)
        gen.AddScalar('s1', wall_time=1, step=200, value=20)
        gen.AddScalar('s1', wall_time=1, step=300, value=20)
        gen.AddScalar('s1', wall_time=1, step=400, value=20)

        gen.AddScalar('s2', wall_time=1, step=202, value=20)
        gen.AddScalar('s2', wall_time=1, step=203, value=20)

        slog = tf.SessionLog(status=tf.SessionLog.START)
        gen.AddEvent(tf.Event(wall_time=2, step=201, session_log=slog))
        acc.Reload()
        self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200])
        self.assertEqual([x.step for x in acc.Scalars('s2')], [])
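
The test above pins down the discard semantics: once a SessionLog.START event arrives, the EventAccumulator drops every buffered value logged at or after the START step (here the s1 points at 300 and 400 and both s2 points). A small sketch reproducing this against a real event file, assuming the standalone tensorboard package is available and using a hypothetical logdir:

import tensorflow as tf
from tensorboard.backend.event_processing import event_accumulator

logdir = '/tmp/sessionlog_demo'  # hypothetical path
writer = tf.summary.FileWriter(logdir)
writer.add_summary(
    tf.Summary(value=[tf.Summary.Value(tag='s1', simple_value=1.0)]), 100)
writer.add_summary(
    tf.Summary(value=[tf.Summary.Value(tag='s1', simple_value=2.0)]), 300)
# a restart at step 200 marks the step-300 point as stale
writer.add_session_log(tf.SessionLog(status=tf.SessionLog.START), 200)
writer.close()

acc = event_accumulator.EventAccumulator(logdir)
acc.Reload()
print([e.step for e in acc.Scalars('s1')])  # expected: [100]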
Example no. 7
def train_validate_test(session, experiment):
    """
    """
    FLAGS = tf.app.flags.FLAGS

    model = experiment['model']

    source_ckpt_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    if source_ckpt_path is None:
        session.run(tf.global_variables_initializer())
    else:
        tf.train.Saver().restore(session, source_ckpt_path)

    step = session.run(model['step'])

    # NOTE: discard summaries logged beyond the restored step (they have not happened yet)
    experiment['scribe'].add_session_log(
        tf.SessionLog(status=tf.SessionLog.START), global_step=step)

    while session.run(model['step']) < FLAGS.training_stop_step:
        train(session, experiment)

        save(session, experiment)

        validate(session, experiment)

        test(session, experiment)
Example no. 8
 def __init__(self, log_dir, start_step=0):
     """Create a summary writer logging to log_dir."""
     self.writer = tf.summary.FileWriter(log_dir)
     if start_step != 0:
         self.writer.add_session_log(
             tf.SessionLog(status=tf.SessionLog.START),
             global_step=start_step)
Example no. 9
def summary_thread(coord, mngr, sess, path, rstrd):
    """ Summary thread entry point.
  Args:
    coord Coordinator to use
    mngr  Graph manager to use
    sess  Session to use
    path  Path to the manager to use
    rstrd Whether the model was just restored from a checkpoint
  """
    global args
    delta = args.summary_delta
    period = args.summary_period
    if delta < 0 and period < 0:  # Effectively disabled
        tools.info("Summary saving is effectively disabled")
        return
    if mngr.summary_tn is None:
        tools.warning("No summary to save")
        return
    if rstrd:
        last_step = sess.run(mngr.step)
        last_time = time.time()
    else:
        last_step = -delta
        last_time = -period
    # Save summaries
    with mngr.graph.as_default():
        with tf.summary.FileWriter(args.summary_dir,
                                   graph=mngr.graph) as writer:
            writer.add_session_log(tf.SessionLog(status=tf.SessionLog.START),
                                   sess.run(mngr.step))
            while True:
                time.sleep(config.thread_idle_delay)
                step = sess.run(mngr.step)
                now = time.time()
                stop = coord.should_stop()
                if stop or (delta >= 0 and step - last_step >= delta) or (
                        period >= 0. and now - last_time >= period):
                    writer.add_summary(sess.run(mngr.summary_tn), step)
                    tools.info("Summaries saved (took " +
                               repr(time.time() - now) + " s)")
                    last_step = sess.run(mngr.step)
                    last_time = time.time()
                    if stop:
                        break
            writer.add_session_log(tf.SessionLog(status=tf.SessionLog.STOP),
                                   step)
Example no. 10
def train():
    """
    """
    # tensorflow
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)
    ckpt_target_path = os.path.join(FLAGS.ckpt_dir_path, 'model.ckpt')

    srcnn = build_srcnn()
    summaries = build_summaries(srcnn)

    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    with tf.Session() as session:
        if ckpt_source_path is None:
            session.run(tf.global_variables_initializer())
        else:
            tf.train.Saver().restore(session, ckpt_source_path)

        # discard stale summaries the previous run wrote past the restored step
        step = session.run(srcnn['step'])

        reporter.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), global_step=step)

        # make dataset reader work
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        while True:
            # discriminator
            fetches = {
                'loss': srcnn['loss'],
                'step': srcnn['step'],
                'trainer': srcnn['trainer'],
            }

            if step % 500 == 0:
                fetches['summary'] = summaries['summary_plus']
            else:
                fetches['summary'] = summaries['summary_part']

            fetched = session.run(fetches)

            step = fetched['step']

            reporter.add_summary(fetched['summary'], step)

            if step % 100 == 0:
                print('loss[{}]: {}'.format(step, fetched['loss']))

            if step % 5000 == 0:
                tf.train.Saver().save(
                    session,
                    ckpt_target_path,
                    global_step=srcnn['step'])

        coord.request_stop()
        coord.join(threads)
Example no. 11
    def __init__(self,
                 sess,
                 model_fn,
                 input_size,
                 num_action,
                 game,
                 restore=False,
                 discount=0.99,
                 lr=1e-4,
                 vf_coef=0.25,
                 ent_coef=1e-3,
                 clip_grads=1.,
                 agenttype="vpg"):
        self.sess, self.discount = sess, discount
        self.vf_coef, self.ent_coef = vf_coef, ent_coef
        self.game = game
        self.global_step_tensor = tf.Variable(0,
                                              trainable=False,
                                              name='global_step')
        self.agenttype = agenttype

        if game == "Pong-v0":
            (self.policy,
             self.value), self.inputs = model_fn(input_size, num_action)
            #print(sample(self.policy))
            self.action = sample(self.policy)
        else:
            (self.policy, self.value), self.inputs = model_fn(num_action)
            self.action = sample(self.policy)
        loss_fn, loss_val, self.loss_inputs = self._loss_func()

        self.step = tf.Variable(0, trainable=False)
        #opt = tf.train.RMSPropOptimizer(learning_rate=lr, decay=0.99, epsilon=1e-5)
        opt = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5)
        #self.train_op = layers.optimize_loss(loss=loss_fn, optimizer=opt, learning_rate=None, global_step= self.global_step_tensor, clip_gradients=clip_grads)
        #self.train_op_val = layers.optimize_loss(loss=loss_val, optimizer=opt, learning_rate=None, global_step= self.global_step_tensor, clip_gradients=clip_grads)
        self.train_op = layers.optimize_loss(
            loss=loss_fn,
            optimizer=opt,
            learning_rate=None,
            global_step=self.global_step_tensor)
        self.train_op_val = layers.optimize_loss(
            loss=loss_val,
            optimizer=opt,
            learning_rate=None,
            global_step=self.global_step_tensor)
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()
        if restore:
            self.saver.restore(
                self.sess, tf.train.latest_checkpoint('weights/' + self.game))

        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter('logs/' + self.game,
                                                    graph=None)
        self.summary_writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), sess.run(self.step))
Example no. 12
    def testAddingSummaryAndGraph(self):
        test_dir = self._CleanTestDir("basics")
        sw = tf.train.SummaryWriter(test_dir)

        sw.add_session_log(tf.SessionLog(status=SessionLog.START), 1)
        sw.add_summary(
            tf.Summary(value=[tf.Summary.Value(tag="mee", simple_value=10.0)]),
            10)
        sw.add_summary(
            tf.Summary(value=[tf.Summary.Value(tag="boo", simple_value=20.0)]),
            20)
        with tf.Graph().as_default() as g:
            tf.constant([0], name="zero")
        gd = g.as_graph_def()
        sw.add_graph(gd, global_step=30)
        sw.close()
        rr = self._EventsReader(test_dir)

        # The first event should list the file_version.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals("brain.Event:2", ev.file_version)

        # The next event should be the START message.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals(1, ev.step)
        self.assertEquals(SessionLog.START, ev.session_log.status)

        # The next event should have the value 'mee=10.0'.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals(10, ev.step)
        self.assertProtoEquals(
            """
      value { tag: 'mee' simple_value: 10.0 }
      """, ev.summary)

        # The next event should have the value 'boo=20.0'.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals(20, ev.step)
        self.assertProtoEquals(
            """
      value { tag: 'boo' simple_value: 20.0 }
      """, ev.summary)

        # The next event should have the graph_def.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals(30, ev.step)
        ev_graph = tf.GraphDef()
        ev_graph.ParseFromString(ev.graph_def)
        self.assertProtoEquals(gd, ev_graph)

        # We should be done.
        self.assertRaises(StopIteration, lambda: next(rr))
Example no. 13
def main():
    restore_model = args.restore
    print(restore_model)
    seq_len = args.seq_len
    batch_size = args.batch_size
    num_epoch = args.epochs
    batches_per_epoch = 1000

    batch_generator = BatchGenerator(batch_size, seq_len)
    g, vs = create_graph(batch_generator.num_letters,
                         batch_size,
                         num_units=args.units,
                         lstm_layers=args.lstm_layers,
                         window_mixtures=args.window_mixtures,
                         output_mixtures=args.output_mixtures)

    with tf.Session(graph=g) as sess:
        model_saver = tf.train.Saver(max_to_keep=2)
        if restore_model:
            model_file = tf.train.latest_checkpoint(
                os.path.join(restore_model, 'models'))
            experiment_path = restore_model
            epoch = int(model_file.split('-')[-1]) + 1
            model_saver.restore(sess, model_file)
        else:
            sess.run(tf.global_variables_initializer())
            experiment_path = next_experiment_path()
            epoch = 0

        summary_writer = tf.summary.FileWriter(experiment_path,
                                               graph=g,
                                               flush_secs=10)
        summary_writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START),
            global_step=epoch * batches_per_epoch)

        for e in range(epoch, num_epoch):
            print('\nEpoch {}'.format(e))
            for b in range(1, batches_per_epoch + 1):
                coords, seq, reset, needed = batch_generator.next_batch()
                if needed:
                    sess.run(vs.reset_states, feed_dict={vs.reset: reset})
                l, s, _ = sess.run([vs.loss, vs.summary, vs.train_step],
                                   feed_dict={
                                       vs.coordinates: coords,
                                       vs.sequence: seq
                                   })
                summary_writer.add_summary(s,
                                           global_step=e * batches_per_epoch +
                                           b)
                print('\r[{:5d}/{:5d}] loss = {}'.format(
                    b, batches_per_epoch, l),
                      end='')

            model_saver.save(sess,
                             os.path.join(experiment_path, 'models', 'model'),
                             global_step=e)
Example no. 14
    def testCloseAndReopen(self):
        test_dir = self._CleanTestDir("close_and_reopen")
        sw = tf.train.SummaryWriter(test_dir)
        sw.add_session_log(tf.SessionLog(status=SessionLog.START), 1)
        sw.close()
        # Sleep at least one second to make sure we get a new event file name.
        time.sleep(1.2)
        sw.reopen()
        sw.add_session_log(tf.SessionLog(status=SessionLog.START), 2)
        sw.close()

        # We should now have 2 events files.
        event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
        self.assertEquals(2, len(event_paths))

        # Check the first file contents.
        rr = tf.train.summary_iterator(event_paths[0])
        # The first event should list the file_version.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals("brain.Event:2", ev.file_version)
        # The next event should be the START message.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals(1, ev.step)
        self.assertEquals(SessionLog.START, ev.session_log.status)
        # We should be done.
        self.assertRaises(StopIteration, lambda: next(rr))

        # Check the second file contents.
        rr = tf.train.summary_iterator(event_paths[1])
        # The first event should list the file_version.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals("brain.Event:2", ev.file_version)
        # The next event should be the START message.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEquals(2, ev.step)
        self.assertEquals(SessionLog.START, ev.session_log.status)
        # We should be done.
        self.assertRaises(StopIteration, lambda: next(rr))
Example no. 15
    def __init__(self, sess, agent_eA, agent_eB, model_fn, config, weight_dir, log_dir, args):
        gg = tf.Graph()
        with gg.as_default():
            self.config, self.discount = config, args.discount
            self.vf_coef, self.ent_coef = args.vf_coef, args.ent_coef
            self.weight_dir = weight_dir
            self.log_dir = log_dir
            self.save_interval = args.save_interval if args is not None else 500
            self.args = args
            (self.policy, self.value), self.inputs = model_fn(config)
            self.action = [sample(p) for p in self.policy]
            loss_distill, self.loss_distill_inputs, self.distill_policyLoss_scalar, self.distill_value_loss_scalar = self._distill_loss_func()

            self.step = tf.Variable(0, trainable=False)
            if self.args.lr_decay:
                self.lr = tf.train.exponential_decay(args.lr, self.step, args.lr_decay_step, args.lr_decay_rate, staircase=True)
            else:
                self.lr = args.lr
            self.next_best = args.save_best_start
            self.best_interval = args.save_best_inc
            if args.optimizer == 'adam':
                opt = tf.train.AdamOptimizer(
                    learning_rate=self.lr, beta1=args.beta1, beta2=args.beta2, epsilon=args.epsilon)
            elif args.optimizer == 'rmsprop':
                opt = tf.train.RMSPropOptimizer(
                    learning_rate=self.lr, decay=args.decay, momentum=args.momentum, epsilon=args.epsilon)

            self.train_op_distill = layers.optimize_loss(
                loss=loss_distill, optimizer=opt, learning_rate=None, global_step=self.step)

        sess = self.sess = tf.Session(graph=gg)
        print('weights restored from', args.paths[-1])

        # restore the weights from the latest checkpoint
        with gg.as_default():              
            self.saver = tf.train.Saver(max_to_keep=args.num_snapshot) # the main difference
            print('restore the model from', tf.train.get_checkpoint_state(args.paths[-1]).model_checkpoint_path)
            self.saver.restore(self.sess, tf.train.get_checkpoint_state(args.paths[-1]).model_checkpoint_path)
            if not self.args.Student_restore:
                self.sess.run(tf.variables_initializer(opt.variables()))
                self.sess.run(tf.assign(self.step, 0))
                print('the global step starts from', sess.run(self.step))
                print ('the optimizer is initialized')          
            self.summary_op = tf.summary.merge_all()
            self.summary_writer = tf.summary.FileWriter(self.log_dir, graph=None)
            self.summary_writer.add_session_log(tf.SessionLog(status=tf.SessionLog.START), sess.run(self.step))

        self.eA_input = agent_eA.input_e
        self.eB_input = agent_eB.input_e
Example no. 16
    def restore_or_init(self):
        self.saver = tf.train.Saver()
        checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        if checkpoint:
            self.saver.restore(self.sess, checkpoint)

            if self.training_enabled:
                # merge with previous summary session
                self.summary_writer.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START), self.sess.run(self.global_step))
        else:
            self.sess.run(tf.global_variables_initializer())
        # this call locks the computational graph into read-only state,
        # as a safety measure against memory leaks caused by mistakenly adding new ops to it
        self.sess.graph.finalize()
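
The finalize() call at the end of this example is worth highlighting: a finalized graph rejects any new op, so accidental graph growth inside a long training loop (a common source of slowdowns and memory leaks in TF 1.x) fails fast instead of accumulating silently. A minimal sketch of that behavior:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant(1.0)

sess = tf.Session(graph=graph)
sess.graph.finalize()  # lock the graph into read-only state

try:
    with graph.as_default():
        tf.constant(2.0)  # adding an op to a finalized graph raises
except RuntimeError as err:
    print('graph is read-only:', err)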
Example no. 17
    def __init__(self,
                 sess,
                 model_fn,
                 config,
                 restore=False,
                 discount=0.99,
                 lr=1e-4,
                 vf_coef=0.25,
                 ent_coef=1e-3,
                 clip_grads=1.):
        self.sess, self.config, self.discount = sess, config, discount
        self.vf_coef, self.ent_coef = vf_coef, ent_coef

        (self.policy, self.value), self.inputs = model_fn(
            config
        )  # policy[0] is the one-hot action head, shape = [None, 524]; policy[1:] are 13 one-hot argument heads with shapes [[None, dim1], [None, dim2], ...]; the spatial feature dimension is 1024
        self.action = [sample(p) for p in self.policy]  # sampled actions are returned as numeric indices

        with tf.variable_scope('loss'):
            loss_fn, self.loss_inputs = self._loss_func()

        with tf.variable_scope('train'):
            self.step = tf.Variable(0, trainable=False)
            opt = tf.train.RMSPropOptimizer(learning_rate=lr,
                                            decay=0.99,
                                            epsilon=1e-5)
            # opt = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5)
            self.train_op = layers.optimize_loss(loss=loss_fn,
                                                 optimizer=opt,
                                                 learning_rate=None,
                                                 global_step=self.step,
                                                 clip_gradients=clip_grads)
            self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()
        if restore:
            self.saver.restore(
                self.sess,
                tf.train.latest_checkpoint('weights/' + self.config.full_id()))
            if self.config.imitation:
                self.sess.run(tf.assign(self.step, 0))

        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter('logs/' +
                                                    self.config.full_id(),
                                                    graph=None)
        self.summary_writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), sess.run(self.step))
Example no. 18
 def after_run(self, run_context, run_values):
     if not self._summary_writer:
         return
     stale_global_step = run_values.results['global_step']
     global_step = stale_global_step + 1
     if self._next_step is None or self._request_summary:
         global_step = run_context.session.run(self._global_step_tensor)
     if self._next_step is None:
         self._summary_writer.add_session_log(tf.SessionLog(status=tf.SessionLog.START), global_step)
     if self._request_summary:
         self._timer.update_last_triggered_step(global_step)
         if 'summary' in run_values.results:
             for summary in run_values.results['summary']:
                 self._summary_writer.add_summary(summary, global_step)
                 self._summary_writer.flush()
     self._next_step = global_step
Example no. 19
    def __init__(self,
                 sess,
                 model_fn,
                 config,
                 restore=False,
                 discount=0.99,
                 lr=1e-4,
                 vf_coef=0.25,
                 ent_coef=1e-3,
                 clip_grads=1.,
                 save_best_only=False,
                 train=True):
        self.sess, self.config, self.discount = sess, config, discount
        self.vf_coef, self.ent_coef = vf_coef, ent_coef
        self.save_best_only = save_best_only

        (self.policy, self.value), self.inputs = model_fn(config, train)
        self.action = [sample(p) for p in self.policy]
        loss_fn, self.loss_inputs = self._loss_func()

        self.step = tf.Variable(0, trainable=False)

        if config.map == "FindAndDefeatZerglings":
            opt = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5)
        else:
            opt = tf.train.RMSPropOptimizer(learning_rate=lr,
                                            decay=0.99,
                                            epsilon=1e-5)
        self.train_op = layers.optimize_loss(loss=loss_fn,
                                             optimizer=opt,
                                             learning_rate=None,
                                             global_step=self.step,
                                             clip_gradients=clip_grads)
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()
        if restore:
            self.saver.restore(
                self.sess,
                tf.train.latest_checkpoint('weights/' + self.config.full_id()))

        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter('logs/' +
                                                    self.config.full_id(),
                                                    graph=None)
        self.summary_writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), sess.run(self.step))
Example no. 20
    def initialize_tf_variables(self, reset_optimizer_on_restore=False):
        """
        Initialize tensorflow variables (either initializes them from scratch or restores from checkpoint).
        
        :param reset_optimizer_on_restore: Flag indicating whether to reset the optimizer(s) given that this 
            function call includes a restore operation. 
        
        :return: updated TeLL session
        """

        session = self.tf_session
        checkpoint = self.workspace.get_checkpoint()
        #
        # Initialize or load variables
        #
        with Timer(name="Initializing variables"):
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())

        if checkpoint is not None:
            # restore from checkpoint
            self.tf_saver.restore(session, checkpoint)
            # get step number from checkpoint
            step = session.run(self.__global_step_placeholder) + 1
            self.global_step = step
            # reopen summaries
            for _, summary in self.tf_summaries.items():
                summary.reopen()
                summary.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=step)
            print("Resuming from checkpoint '{}' at iteration {}".format(
                checkpoint, step))

            if self.config.get_value('optimizer', None) is not None:
                if reset_optimizer_on_restore:
                    if isinstance(self.tf_optimizer, list):
                        for optimizer in self.tf_optimizer:
                            self.reset_optimizer(optimizer)
                    else:
                        self.reset_optimizer(self.tf_optimizer)
        else:
            for _, summary in self.tf_summaries.items():
                summary.add_graph(session.graph)

        return self
Example no. 21
    def restore_or_init(self):
        """ creates saver, loads from checkpoint if one exists. otherwise setup graph """
        self.saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(self.checkpoint_path)  # ?
        if ckpt:
            self.saver.restore(self.sess, ckpt)

            if self.training_enabled:
                # merge with the previous summary session (the global step itself is restored by the saver)
                self.summary_writer.add_session_log(  # discards stale events written after this step
                    tf.SessionLog(status=tf.SessionLog.START),
                    self.sess.run(self.global_step))

        else:
            self.sess.run(tf.global_variables_initializer())
        # this call locks the computational graph into read-only state,
        # as a safety measure against memory leaks caused by mistakenly adding new ops to it
        self.sess.graph.finalize()
Example no. 22
def train(model, data_iterator):
    """
    """
    FLAGS = tf.app.flags.FLAGS

    summaries = build_summaries(model)

    source_ckpt_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
    target_ckpt_path = os.path.join(FLAGS.ckpt_path, 'model.ckpt')

    reporter = tf.summary.FileWriter(FLAGS.logs_path)

    with tf.Session() as session:
        if source_ckpt_path is None:
            session.run(tf.global_variables_initializer())
        else:
            tf.train.Saver().restore(session, source_ckpt_path)

        step = session.run(model['step'])

        # NOTE: discard summaries logged beyond the restored step (they have not happened yet)
        reporter.add_session_log(tf.SessionLog(status=tf.SessionLog.START),
                                 global_step=step)

        session.run(data_iterator.initializer)

        while step < FLAGS.max_training_steps:
            fetch = {
                'step': model['step'],
                'optimizer': model['optimizer'],
                'summary_loss': summaries['loss'],
            }

            fetched = session.run(fetch)

            step = fetched['step']

            if 'summary_loss' in fetched:
                reporter.add_summary(fetched['summary_loss'], step)

        reporter.flush()

        tf.train.Saver().save(session, target_ckpt_path, global_step=step)
Example no. 23
def main():
    seq_len = 256
    batch_size = 64
    epochs = 30
    batches_per_epoch = 1000

    batch_generator = BatchGenerator(batch_size, seq_len)
    g, vs = create_graph(batch_generator.num_letters, batch_size)

    with tf.Session(graph=g) as sess:
        model_saver = tf.train.Saver(max_to_keep=2)
        sess.run(tf.global_variables_initializer())
        model_path = get_model_path()

        summary_writer = tf.summary.FileWriter(model_path,
                                               graph=g,
                                               flush_secs=10)
        summary_writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), global_step=0)
        for e in range(epochs):
            print('\n{} : Epoch {}'.format(datetime.datetime.now().time(), e))
            for b in range(1, batches_per_epoch + 1):
                coordinates, labels, reset, to_reset = batch_generator.next_batch(
                )
                if to_reset:
                    sess.run(vs.reset_states, feed_dict={vs.reset: reset})
                loss, s, _ = sess.run([vs.loss, vs.summary, vs.train_step],
                                      feed_dict={
                                          vs.coordinates: coordinates,
                                          vs.sequence: labels
                                      })
                summary_writer.add_summary(s,
                                           global_step=e * batches_per_epoch +
                                           b)
                print('\r[{:5d}/{:5d}] loss = {}'.format(
                    b, batches_per_epoch, loss),
                      end='')

            model_saver.save(sess,
                             os.path.join(model_path, 'models', 'model'),
                             global_step=e)
Example no. 24
def train():
    """
    """
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)

    xx_real = build_image_batch_reader(FLAGS.x_images_dir_path,
                                       FLAGS.batch_size)

    yy_real = build_image_batch_reader(FLAGS.y_images_dir_path,
                                       FLAGS.batch_size)

    image_pool = {}

    model = build_cycle_gan(xx_real, yy_real, '')

    summaries = build_summaries(model)

    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())

        if ckpt_source_path is not None:
            tf.train.Saver().restore(session, ckpt_source_path)

        # discard stale summaries the previous run wrote past the restored step
        step = session.run(model['step'])

        reporter.add_session_log(tf.SessionLog(status=tf.SessionLog.START),
                                 global_step=step)

        # make dataset reader work
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        while train_one_step(model, summaries, image_pool, reporter):
            pass

        coord.request_stop()
        coord.join(threads)
Example no. 25
    def __init__(self, sess, model, num_action, restore=False, discount=0.99, lr=1e-3, clip_grads=1.):
        self.sess, self.discount = sess, discount
        self.num_action = num_action

        self.global_step_tensor = tf.Variable(0, trainable=False, name='global_step')
        self.network, self.inputs = model(num_action)

        self.loss_val, self.loss_inputs = self._loss_func()
        self.step = tf.Variable(0, trainable=False)

        #opt = tf.train.RMSPropOptimizer(learning_rate=lr, decay=0.99, epsilon=1e-5)
        opt = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5)
        self.train_op = layers.optimize_loss(loss=self.loss_val, optimizer=opt, learning_rate = None, global_step= self.global_step_tensor, clip_gradients=clip_grads)
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()
        if restore:
            self.saver.restore(self.sess, tf.train.latest_checkpoint('weights/Q_nn'))
        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter('logs/Q_nn', graph=None)
        self.summary_writer.add_session_log(tf.SessionLog(status=tf.SessionLog.START), sess.run(self.step))
Example no. 26
    def __init__(self, sess=None, gamma=0.8, epsilon=0.9):
        self.gamma = gamma
        self.epsilon = epsilon
        self.action_dim = len(smart_actions)
        self.state_dim = len(STATE)
        self.network()
        self.step = tf.Variable(0, trainable=False)
        self.sess = sess
        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

        if os.path.isdir(DATA_FILE + '/' + WEIGHT_DIR):
            self.saver.restore(
                self.sess,
                tf.train.latest_checkpoint(DATA_FILE + '/' + WEIGHT_DIR))

        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(DATA_FILE + '/' + LOG_DIR,
                                                    graph=sess.graph)
        self.summary_writer.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START), sess.run(self.step))
Example no. 27
    def logLoss(self,
                nmbr_steps,
                episode,
                loss,
                test=False,
                sess=None,
                agent_id=0):
        """ Logs the rewards to a tensorboard file.

        Args:
          nmbr_steps: Number of steps performed this episode.
          episode: number of the current episode.
          loss: total loss accrued over the episode.
          sess: A tensorflow session
          agent_id: ID of agent, not used currently.
        """
        tf.logging.info('Finished Episode {} with loss {}'.format(
            episode, loss))
        reward = loss / nmbr_steps
        if self._first_entry:
            self._logger_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START), global_step=episode)
            self._logger_writer.add_graph(sess.graph)
            self._first_entry = False

        if test:
            sess.run([self._test_assign_reward_op],
                     feed_dict={self._test_placeholder_ep_reward: reward})
        else:
            sess.run([self._assign_reward_op],
                     feed_dict={
                         self._placeholder_ep_reward: reward,
                     })
        if episode > self._episode:
            summary = sess.run(self._merged)
            self._logger_writer.add_summary(summary, episode)
            self._logger_writer.flush()
            self._step_rewards = []
        self._episode = episode
Example no. 28
    def testSessionLogSummaries(self):
        data = [
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.START),
                'step': 0
            },
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.CHECKPOINT),
                'step': 1
            },
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.CHECKPOINT),
                'step': 2
            },
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.CHECKPOINT),
                'step': 3
            },
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.STOP),
                'step': 4
            },
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.START),
                'step': 5
            },
            {
                'session_log': tf.SessionLog(status=tf.SessionLog.STOP),
                'step': 6
            },
        ]

        self._WriteScalarSummaries(data)
        units = efi.get_inspection_units(self.logdir)
        self.assertEqual(1, len(units))
        printable = efi.get_dict_to_print(units[0].field_to_obs)
        self.assertEqual(printable['sessionlog:start']['steps'], [0, 5])
        self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6])
        self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3)
Example no. 29
def train():
    """
    build and train the pix2pix model.
    """
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)
    ckpt_target_path = os.path.join(FLAGS.ckpt_dir_path, 'model.ckpt')

    source_images, target_images = build_dataset_reader()

    model = build_pix2pix(
        source_images,
        target_images,
        FLAGS.lambda_value,
        FLAGS.learning_rate,
        True)

    summaries = build_summaries(model)

    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        if ckpt_source_path is not None:
            tf.train.Saver().restore(session, ckpt_source_path)

        # discard stale summaries the previous run wrote past the restored step
        step = session.run(model['step'])

        reporter.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START),
            global_step=step)

        # make dataset reader work
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        while True:
            fetches = {
                'step': model['step'],
                'd_trainer': model['d_trainer'],
                'summary': summaries['summary'],
            }

            fetched = session.run(fetches)

            if fetched['step'] % 100 == 0:
                reporter.add_summary(fetched['summary'], fetched['step'])
                print(fetched['step'])

            if fetched['step'] % 10000 == 0:
                tf.train.Saver().save(
                    session, ckpt_target_path, global_step=model['step'])

            fetches = {
                'g_trainer': model['g_trainer'],
            }

            fetched = session.run(fetches)

        coord.request_stop()
        coord.join(threads)
Example no. 30
#init_all = tf.initialize_all_variables()
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sv = tf.train.Supervisor(logdir=log_dir, save_summaries_secs=0, saver=None)
Data = namedtuple('Data', ['x1', 'y1', 'x2', 'y2', 'flow', 'feature_matches1', 'mask1', 'feature_matches2', 'mask2'])
with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True,gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction))) as sess:
    #sess.run(init_all)
    #threads = tf.train.start_queue_runners(sess=sess)
    if args.restore: 
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        logger.info('restoring {}'.format(tf.train.latest_checkpoint(model_dir)))
    else:
        restorer.restore(sess, checkpoint_file)

    st_step = max(0,sess.run(global_step))
    sv.summary_writer.add_session_log(tf.SessionLog(status=tf.SessionLog.START), global_step=st_step-1)
    time_start = time.time()
    tot_time = 0
    tot_train_time = 0

    for i in range(st_step, training_iter):
        batch_x1s, batch_y1s, batch_x2s, batch_y2s, batch_flows, batch_feature_matches1, batch_mask1, batch_feature_matches2, batch_mask2 = sess.run(
            [x1_batch, y1_batch, x2_batch, y2_batch, flow_batch, feature_matches1_batch, mask1_batch, feature_matches2_batch, mask2_batch])
        if (i > no_theta_iter):
            use_theta = 0
        else:
            use_theta = 1
        if (i >= do_temp_loss_iter):
            use_temp = 1
        else:
            use_temp = 0