def __init__(self, model, checkpoint_dir, **params):
    """Build the model in a private graph and restore the newest checkpoint.

    Args:
        model: callable that builds the graph ops; invoked as model(**params).
        checkpoint_dir: directory searched for the latest checkpoint to restore.
        **params: forwarded verbatim to `model`.
    """
    self.graph = tf.Graph()
    with self.graph.as_default():
        # Graph construction: global step first, then the model's ops.
        self.global_step = tf.train.get_or_create_global_step()
        self.ops = model(**params)
        self.sess = tf.Session()
        # Locate the newest checkpoint, then restore all variables into the session.
        newest_ckpt = utils.find_latest_checkpoint(checkpoint_dir)
        restorer = tf.train.Saver()
        restorer.restore(self.sess, newest_ckpt)
def __init__(self, model, checkpoint_dir, **params):
    """Build the model in a private graph and attach a monitored session.

    The session is restored from the newest staged checkpoint found under
    `checkpoint_dir` (per-stage subdirectories, see the glob below).

    Args:
        model: callable that builds the graph ops; invoked as model(**params).
        checkpoint_dir: root directory holding stage*/ checkpoint subdirectories.
        **params: forwarded verbatim to `model`.
    """
    self.graph = tf.Graph()
    with self.graph.as_default():
        self.global_step = tf.train.get_or_create_global_step()
        self.ops = model(**params)
        # Newest checkpoint across all training stages.
        staged_ckpt = utils.find_latest_checkpoint(checkpoint_dir, 'stage*/model.ckpt-*.meta')
        # SingularMonitoredSession restores the checkpoint itself; no Saver needed.
        self.sess = tf.train.SingularMonitoredSession(
            checkpoint_filename_with_path=staged_ckpt)
def eval_mode(self, ckpt=None):
    """Open an evaluation session restored from a checkpoint.

    Args:
        ckpt: optional checkpoint path; when None, the latest checkpoint in
            self.checkpoint_dir is used.

    Returns:
        self, so calls can be chained.
    """
    self.session = tf.Session(config=utils.get_config())
    restorer = tf.train.Saver()
    # Explicit path wins; otherwise fall back to the newest checkpoint on disk.
    target = (utils.find_latest_checkpoint(self.checkpoint_dir)
              if ckpt is None else os.path.abspath(ckpt))
    restorer.restore(self.session, target)
    # Cache the restored global step for reporting.
    self.tmp.step = self.session.run(self.step)
    print('Eval model %s at global_step %d' % (self.__class__.__name__, self.tmp.step))
    return self
def train(self, dataset, schedule):
    """Run progressive training phase by phase, resuming from checkpoints.

    For each phase in `schedule` (skipping phases already completed according
    to the checkpoint on disk), a fresh graph is built at that phase's level
    of detail (LOD) and trained until the phase's image budget is reached.
    Each LOD stage checkpoints into its own 'stage_%d' subdirectory.

    Args:
        dataset: data source; provides .graph, .sess and a .train tf.data pipeline.
        schedule: a TrainSchedule describing phases, LOD range and total images.
    """
    assert isinstance(schedule, TrainSchedule)
    batch = FLAGS.batch
    # Resume point: latest global step found in any stage subdirectory.
    resume_step = utils.get_latest_global_step_in_subdir(self.checkpoint_dir)
    # Convert steps to images seen to find which phase to resume in.
    phase_start = schedule.phase_index(resume_step * batch)
    # Per-stage checkpoint directory, one per LOD stage.
    checkpoint_dir = lambda stage: os.path.join(self.checkpoint_dir, 'stage_%d' % stage)
    for phase in schedule.schedule[phase_start:]:
        # nimg_* are image counts; >> 10 renders them in thousands (K).
        print('Resume step %d Phase %dK:%dK LOD %d:%d' %
              (resume_step, phase.nimg_start >> 10, phase.nimg_stop >> 10,
               phase.lod_start, phase.lod_stop))
        assert isinstance(phase, TrainPhase)

        # Redefined each iteration so it closes over the current phase;
        # reads self.nimg_cur, which the training loop below keeps updated.
        def lod_fn():
            return phase.lod(self.nimg_cur)

        # Input pipeline lives in the dataset's own graph/session.
        with dataset.graph.as_default():
            train_data = dataset.train.batch(batch)
            train_data = train_data.prefetch(64)
            train_data = iter(as_iterator(train_data, dataset.sess))
        # Fresh graph per phase: the model topology changes with the LOD.
        with tf.Graph().as_default():
            global_step = tf.train.get_or_create_global_step()
            ops = self.model(dataset=dataset,
                             lod_start=phase.lod_start,
                             lod_stop=phase.lod_stop,
                             lod_max=schedule.lod_max,
                             total_steps=schedule.total_nimg // batch,
                             **self.params)
            self.add_summaries(dataset, ops, lod_fn, **self.params)
            # Stop once this phase's image budget is exhausted (in steps).
            stop_hook = tf.train.StopAtStepHook(last_step=phase.nimg_stop // batch)
            report_hook = utils.HookReport(FLAGS.report_kimg << 10, batch)
            config = tf.ConfigProto()
            if len(utils.get_available_gpus()) > 1:
                config.allow_soft_placement = True
            if FLAGS.log_device_placement:
                config.log_device_placement = True
            config.gpu_options.allow_growth = True
            # Layout optimizer disabled — presumably it interferes with this
            # model's ops; TODO confirm the original motivation.
            config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
            # When growing the model, load the previously trained layer weights.
            stage_step_last = utils.get_latest_global_step(checkpoint_dir(phase.lod_stop - 1))
            stage_step = utils.get_latest_global_step(checkpoint_dir(phase.lod_stop))
            if stage_step_last and not stage_step:
                # New stage with no checkpoint yet: warm-start the shared
                # scopes (identity mapping) from the previous stage's weights.
                last_checkpoint = utils.find_latest_checkpoint(checkpoint_dir(phase.lod_stop - 1))
                tf.train.init_from_checkpoint(
                    last_checkpoint,
                    {x: x for x in self.stage_scopes(phase.lod_stop - 1)})
            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=checkpoint_dir(phase.lod_stop),
                    config=config,
                    hooks=[stop_hook],
                    chief_only_hooks=[report_hook],
                    save_checkpoint_secs=600,
                    save_summaries_steps=(FLAGS.save_kimg << 10) // batch) as sess:
                self.sess = sess
                # NOTE(review): self.tf_sess is presumably the raw session
                # unwrapped from the monitored session — confirm in the class.
                self.nimg_cur = batch * self.tf_sess.run(global_step)
                while not sess.should_stop():
                    self.train_step(train_data, lod_fn(), ops)
                    # Track progress: resume_step feeds the next phase's print,
                    # nimg_cur drives lod_fn's LOD interpolation.
                    resume_step = self.tf_sess.run(global_step)
                    self.nimg_cur = batch * resume_step