Example #1
def __init__(self, model, checkpoint_dir, **params):
    self.graph = tf.Graph()
    with self.graph.as_default():
        self.global_step = tf.train.get_or_create_global_step()
        self.ops = model(**params)
        self.sess = tf.Session()
        saver = tf.train.Saver()
        # Restore the weights from the newest checkpoint in the directory.
        ckpt = utils.find_latest_checkpoint(checkpoint_dir)
        saver.restore(self.sess, ckpt)
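
All four examples revolve around utils.find_latest_checkpoint, which is a
project helper, not part of TensorFlow. A minimal sketch of what it could
look like, assuming checkpoints are located by globbing their .meta files
and the prefix with the highest global step wins (the signature and the
default glob term are assumptions):

import glob
import os


def find_latest_checkpoint(checkpoint_dir, glob_term='model.ckpt-*.meta'):
    """Hypothetical reimplementation: return the checkpoint prefix with the
    highest step, e.g. '<dir>/model.ckpt-12345' (without '.meta')."""
    metas = glob.glob(os.path.join(checkpoint_dir, glob_term))
    if not metas:
        raise FileNotFoundError('No checkpoint found in %s' % checkpoint_dir)
    # The step number sits between 'ckpt-' and '.meta' in the filename.
    step = lambda path: int(path.rsplit('-', 1)[1].split('.')[0])
    return max(metas, key=step)[:-len('.meta')]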
Example #2
def __init__(self, model, checkpoint_dir, **params):
    self.graph = tf.Graph()
    with self.graph.as_default():
        self.global_step = tf.train.get_or_create_global_step()
        self.ops = model(**params)
        # Search the per-stage subdirectories for the newest checkpoint.
        ckpt = utils.find_latest_checkpoint(checkpoint_dir,
                                            'stage*/model.ckpt-*.meta')
        # The monitored session restores the checkpoint on construction.
        self.sess = tf.train.SingularMonitoredSession(
            checkpoint_filename_with_path=ckpt)
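
Unlike Example #1, there is no explicit tf.train.Saver here:
SingularMonitoredSession restores the variables itself when given
checkpoint_filename_with_path. That argument must be a checkpoint prefix
rather than a .meta file, which is why the helper strips the suffix. An
illustrative sketch (the directory layout is made up):

# Made-up staged layout that 'stage*/model.ckpt-*.meta' would match:
#   train_dir/stage_3/model.ckpt-12000.meta
#   train_dir/stage_4/model.ckpt-20000.meta
ckpt = utils.find_latest_checkpoint('train_dir', 'stage*/model.ckpt-*.meta')
# The highest global step wins and '.meta' is stripped:
# ckpt == 'train_dir/stage_4/model.ckpt-20000'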
Example #3
def eval_mode(self, ckpt=None):
    self.session = tf.Session(config=utils.get_config())
    saver = tf.train.Saver()
    if ckpt is None:
        # Default to the newest checkpoint in the model's own directory.
        ckpt = utils.find_latest_checkpoint(self.checkpoint_dir)
    else:
        ckpt = os.path.abspath(ckpt)
    saver.restore(self.session, ckpt)
    self.tmp.step = self.session.run(self.step)
    print('Eval model %s at global_step %d' %
          (self.__class__.__name__, self.tmp.step))
    return self
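
utils.get_config() is another project helper. Example #4 below builds its
session config inline; a plausible sketch is that get_config packages the
same settings (this is an assumption, mirroring Example #4):

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2


def get_config():
    # Hypothetical reconstruction based on the inline config in Example #4.
    config = tf.ConfigProto()
    config.allow_soft_placement = True      # tolerate multi-GPU placement
    config.gpu_options.allow_growth = True  # do not reserve all GPU memory
    config.graph_options.rewrite_options.layout_optimizer = (
        rewriter_config_pb2.RewriterConfig.OFF)
    return config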
Example #4
def train(self, dataset, schedule):
    assert isinstance(schedule, TrainSchedule)
    batch = FLAGS.batch
    # Resume from the most advanced stage checkpoint, if any.
    resume_step = utils.get_latest_global_step_in_subdir(self.checkpoint_dir)
    # Convert resumed steps to images seen to locate the phase to resume in.
    phase_start = schedule.phase_index(resume_step * batch)
    # Each progressive-growing stage checkpoints into its own subdirectory.
    checkpoint_dir = lambda stage: os.path.join(self.checkpoint_dir,
                                                'stage_%d' % stage)

    for phase in schedule.schedule[phase_start:]:
        print('Resume step %d  Phase %dK:%dK  LOD %d:%d' %
              (resume_step, phase.nimg_start >> 10, phase.nimg_stop >> 10,
               phase.lod_start, phase.lod_stop))
        assert isinstance(phase, TrainPhase)

        def lod_fn():
            # Level of detail as a function of the number of images seen.
            return phase.lod(self.nimg_cur)

        with dataset.graph.as_default():
            train_data = dataset.train.batch(batch)
            train_data = train_data.prefetch(64)
            # Drive the tf.data pipeline with the dataset's own session.
            train_data = iter(as_iterator(train_data, dataset.sess))

        with tf.Graph().as_default():
            global_step = tf.train.get_or_create_global_step()
            ops = self.model(dataset=dataset,
                             lod_start=phase.lod_start,
                             lod_stop=phase.lod_stop,
                             lod_max=schedule.lod_max,
                             total_steps=schedule.total_nimg // batch,
                             **self.params)
            self.add_summaries(dataset, ops, lod_fn, **self.params)
            stop_hook = tf.train.StopAtStepHook(
                last_step=phase.nimg_stop // batch)
            report_hook = utils.HookReport(FLAGS.report_kimg << 10, batch)
            config = tf.ConfigProto()
            if len(utils.get_available_gpus()) > 1:
                config.allow_soft_placement = True
            if FLAGS.log_device_placement:
                config.log_device_placement = True
            config.gpu_options.allow_growth = True
            # Disable Grappler's layout optimizer for this graph.
            config.graph_options.rewrite_options.layout_optimizer = (
                rewriter_config_pb2.RewriterConfig.OFF)

            # When growing the model, load the previously trained layer
            # weights into the new stage's graph.
            stage_step_last = utils.get_latest_global_step(
                checkpoint_dir(phase.lod_stop - 1))
            stage_step = utils.get_latest_global_step(
                checkpoint_dir(phase.lod_stop))
            if stage_step_last and not stage_step:
                last_checkpoint = utils.find_latest_checkpoint(
                    checkpoint_dir(phase.lod_stop - 1))
                # Map each previously trained scope onto the identically
                # named scope in the new graph.
                tf.train.init_from_checkpoint(
                    last_checkpoint,
                    {x: x for x in self.stage_scopes(phase.lod_stop - 1)})

            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=checkpoint_dir(phase.lod_stop),
                    config=config,
                    hooks=[stop_hook],
                    chief_only_hooks=[report_hook],
                    save_checkpoint_secs=600,
                    save_summaries_steps=(FLAGS.save_kimg << 10) // batch) as sess:
                self.sess = sess
                # tf_sess presumably exposes the raw tf.Session beneath the
                # monitored session, so hooks do not fire on these reads.
                self.nimg_cur = batch * self.tf_sess.run(global_step)
                while not sess.should_stop():
                    self.train_step(train_data, lod_fn(), ops)
                    resume_step = self.tf_sess.run(global_step)
                    self.nimg_cur = batch * resume_step
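
The resume logic also relies on utils.get_latest_global_step and
utils.get_latest_global_step_in_subdir. These are project helpers as well;
a plausible sketch, reusing the hypothetical find_latest_checkpoint from
Example #1 and assuming the step is parsed from the checkpoint filename:

import glob
import os


def get_latest_global_step(checkpoint_dir):
    # Hypothetical sketch. Returning 0 when no checkpoint exists yet is what
    # the 'if stage_step_last and not stage_step:' test above relies on.
    try:
        ckpt = find_latest_checkpoint(checkpoint_dir)
    except FileNotFoundError:
        return 0
    return int(ckpt.rsplit('-', 1)[1])


def get_latest_global_step_in_subdir(checkpoint_dir):
    # Scan every stage subdirectory and keep the highest step found.
    subdirs = glob.glob(os.path.join(checkpoint_dir, '*/'))
    return max([get_latest_global_step(x) for x in subdirs] or [0])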