import logging
import time
from multiprocessing import Pool

import tensorflow as tf
from tensorflow.python.client.device_lib import list_local_devices

# Repo-local helpers (Evaluator, train_dir, get_logger, ensure_output_path,
# set_random_seed, get_graph_size, start_server, CustomSupervisor,
# SimpleStepScheduler, TrainingProgressLogger) are assumed to be imported
# from their respective modules.


def get_qualitative_evaluator(args, model, run_async):
    """Returns an Evaluator over the qualitative set, or None if not configured."""
    if args.qualitative_data_paths:
        return Evaluator(
            args, model,
            train_dir(args.output_path),
            args.qualitative_data_paths,
            dataset='qualitative_set',
            eval_set_size=args.qualitative_set_size or args.eval_set_size,
            qualitative_set_size=args.qualitative_set_size,
            run_async=run_async
        )
    else:
        return None
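
# Illustrative sketch (not part of the original source): wiring
# get_qualitative_evaluator with a synchronous run_async fallback, mirroring
# default_run_async in Trainer.__init__ below. ExampleArgs and its values are
# hypothetical stand-ins for the parsed command-line arguments.
def _example_get_qualitative_evaluator(model):
    from collections import namedtuple
    ExampleArgs = namedtuple('ExampleArgs', [
        'output_path', 'qualitative_data_paths',
        'qualitative_set_size', 'eval_set_size'
    ])
    example_args = ExampleArgs(
        output_path='/tmp/output',
        qualitative_data_paths=['/tmp/data/qualitative-00000-of-00001.tfrecord'],
        qualitative_set_size=10,
        eval_set_size=100
    )
    # Returns None when qualitative_data_paths is empty or not set.
    return get_qualitative_evaluator(
        example_args, model, run_async=lambda f, args: f(*args)
    )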
def test_should_not_fail_eval_in_session(self):
    # Test fixtures (ExampleModel, EXAMPLE_PROPS_1, BATCH_SIZE, DATA_PATHS,
    # DEFAULT_ARGS, DEFAULT_KWARGS, to_namedtuple) come from the test module.
    with tf.Graph().as_default():
        model = ExampleModel([EXAMPLE_PROPS_1] * BATCH_SIZE)
        tensors = model.build_train_graph(DATA_PATHS, BATCH_SIZE)
        evaluator = Evaluator(
            args=to_namedtuple(DEFAULT_ARGS, name='args'),
            model=model,
            **DEFAULT_KWARGS
        )
        evaluator.init()
        get_logger().info('starting session')
        with tf.Session() as session:
            coord = tf.train.Coordinator()
            # Keep the thread handles so the queue runners can be joined below.
            threads = tf.train.start_queue_runners(sess=session, coord=coord)
            get_logger().info('evaluating')
            session.run(tensors.initializer)
            evaluator.evaluate_in_session(session, tensors)
            # Stop and join the queue-runner threads so they do not leak
            # across tests.
            coord.request_stop()
            coord.join(threads)
        get_logger().info('done')
def write_predictions(args, model, cluster, task):
    """Writes predictions for the eval set (and the qualitative set, if configured)."""
    if not cluster or not task or task.type == 'master':
        pass  # Run locally.
    else:
        raise ValueError('invalid task_type %s' % (task.type,))
    if args.seed is not None:
        set_random_seed(args.seed)
    logger = get_logger()
    logger.info('Starting to write predictions on %s/%d', task.type, task.index)
    pool = Pool(processes=args.pool_size)

    def run_async(f, args):
        return pool.apply_async(f, args)

    qualitative_evaluator = get_qualitative_evaluator(
        args, model,
        run_async=run_async
    )
    if qualitative_evaluator:
        qualitative_evaluator.init()
        qualitative_evaluator.write_predictions()

    evaluator = Evaluator(
        args, model,
        train_dir(args.output_path),
        args.eval_data_paths,
        run_async=run_async
    )
    evaluator.init()
    evaluator.write_predictions()

    logger.info('Waiting for background tasks to finish')
    pool.close()
    pool.join()
    logger.info('Done writing predictions on %s/%d', task.type, task.index)
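
# Illustrative sketch (not part of the original source): invoking
# write_predictions for a local, single-process run. cluster=None takes the
# "run locally" branch above; the log calls still dereference task.type and
# task.index, so a minimal master task stand-in is passed rather than None.
def _example_write_predictions(args, model):
    from collections import namedtuple
    LocalTask = namedtuple('LocalTask', ['type', 'index'])
    write_predictions(args, model, cluster=None, task=LocalTask('master', 0))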
class Trainer(object):
    """Performs model training and optionally evaluation."""

    def __init__(self, args, model, cluster, task):
        self.args = args
        self.model = model
        self.cluster = cluster
        self.task = task
        self.run_async = None

        def default_run_async(f, args):
            return f(*args)

        def run_async(f, args):
            return (self.run_async or default_run_async)(f, args)

        self.evaluator = Evaluator(
            self.args, self.model,
            train_dir(self.args.output_path),
            self.args.eval_data_paths,
            'eval_set',
            run_async=run_async
        )
        self.train_evaluator = Evaluator(
            self.args, self.model,
            train_dir(self.args.output_path),
            self.args.train_data_paths,
            'train_set',
            run_async=run_async
        )
        self.qualitative_evaluator = get_qualitative_evaluator(
            self.args, self.model,
            run_async=run_async
        )
        self.min_train_eval_rate = args.min_train_eval_rate
        self.global_step = None
        self.last_save = 0

    def run_training(self):
        get_logger().info('creating async pool, pool size: %d', self.args.pool_size)
        pool = Pool(processes=self.args.pool_size)
        self.run_async = lambda f, args: pool.apply_async(f, args)
        self._do_run_training()
        get_logger().info('Waiting for tasks to complete')
        pool.close()
        pool.join()
        self.run_async = None

    def _do_run_training(self):
        """Runs a Master."""
        logger = get_logger()
        logger.info('tensorflow version: %s', tf.__version__)
        if self.args.seed is not None:
            set_random_seed(self.args.seed)

        self.train_evaluator.init()
        self.evaluator.init()
        if self.qualitative_evaluator:
            self.qualitative_evaluator.init()

        ensure_output_path(self.args.output_path)
        train_path = train_dir(self.args.output_path)
        # model_path = model_dir(self.args.output_path)
        is_master = self.task.type != 'worker'
        log_interval = self.args.log_interval_secs
        save_interval = self.args.save_interval_secs
        eval_interval = self.args.eval_interval_secs
        summary_interval = log_interval
        summary_freq = self.args.log_freq
        if is_master and self.task.index > 0:
            raise ValueError('Only one replica of master expected')

        if self.cluster:
            logging.info('Starting %s/%d', self.task.type, self.task.index)
            server = start_server(self.cluster, self.task)
            target = server.target
            device_fn = tf.train.replica_device_setter(
                ps_device='/job:ps',
                worker_device='/job:%s/task:%d' % (self.task.type, self.task.index),
                cluster=self.cluster
            )
            # We use a device_filter to limit the communication between this job
            # and the parameter servers, i.e., there is no need to directly
            # communicate with the other workers; attempting to do so can result
            # in reliability problems.
            device_filters = [
                '/job:ps',
                '/job:%s/task:%d' % (self.task.type, self.task.index)
            ]
            config = tf.ConfigProto(device_filters=device_filters)
        else:
            target = ''
            device_fn = ''
            config = None

        logger.info('batch_size: %s', self.args.batch_size)
        logger.info(
            'available devices: %s',
            ', '.join([
                '%s (%s)' % (x.name, x.device_type)
                for x in list_local_devices()
            ])
        )

        with tf.Graph().as_default() as graph:
            with tf.device(device_fn):
                # Build the training graph.
                logger.info('building graph...')
                tensors = self.model.build_train_graph(
                    self.args.train_data_paths,
                    self.args.batch_size
                )
                logger.info('done building graph, calculating graph size...')
                logger.info('graph_size: %s bytes', '{:,}'.format(get_graph_size()))

                # Create a saver for writing training checkpoints.
                saver = tf.train.Saver(
                    max_to_keep=self.args.save_max_to_keep
                )

                # Create a "supervisor", which oversees the training process.
                sv = CustomSupervisor(
                    model=self.model,
                    graph=graph,
                    is_chief=is_master,
                    logdir=train_path,
                    saver=saver,
                    # Write summary_ops by hand.
                    summary_op=None,
                    global_step=tensors.global_step
                )
                save_path = sv.save_path

                should_retry = True
                local_step = 0

                while should_retry:
                    try:
                        should_retry = False
                        with sv.managed_session(target, config=config) as session:
                            start_time = time.time()
                            now = start_time
                            global_step = session.run(tensors.global_step)

                            training_progress_logger = TrainingProgressLogger(
                                start_time,
                                global_step,
                                self.task
                            )
                            # The lambdas below deliberately bind `now`,
                            # `global_step` and `local_step` late, so each
                            # scheduled run sees the latest values.
                            log_scheduler = SimpleStepScheduler(
                                lambda: training_progress_logger.log(now, global_step, local_step),
                                min_interval=log_interval,
                                min_freq=self.args.log_freq,
                                step=global_step,
                                last_run=start_time
                            )

                            def do_save():
                                logger.info('saving model to %s (%s)', save_path, global_step)
                                saver.save(session, save_path, tensors.global_step)

                            save_scheduler = SimpleStepScheduler(
                                do_save,
                                min_interval=save_interval,
                                min_freq=self.args.save_freq,
                                step=global_step,
                                last_run=start_time
                            )

                            eval_train_scheduler = SimpleStepScheduler(
                                lambda: self.eval_train(session, tensors, global_step),
                                min_interval=eval_interval,
                                min_freq=self.args.eval_freq,
                                step=global_step,
                                last_run=start_time
                            )

                            schedulers = [
                                log_scheduler,
                                save_scheduler,
                                eval_train_scheduler
                            ]

                            if is_master:
                                eval_scheduler = SimpleStepScheduler(
                                    lambda: self.eval(global_step=global_step),
                                    min_interval=save_interval,
                                    min_freq=self.args.save_freq,
                                    step=global_step,
                                    last_run=start_time
                                )
                                schedulers = schedulers + [eval_scheduler]

                            summary_op = sv.summary_op if tensors.summary is None else tensors.summary
                            if summary_op is not None:
                                schedulers.append(SimpleStepScheduler(
                                    lambda: sv.summary_writer.add_summary(
                                        *session.run([summary_op, tensors.global_step])
                                    ),
                                    min_interval=summary_interval,
                                    min_freq=summary_freq,
                                    step=global_step,
                                    last_run=start_time
                                ))

                            # Loop until the supervisor shuts down or
                            # args.max_steps have completed.
                            max_steps = self.args.max_steps
                            while not sv.should_stop() and global_step < max_steps:
                                logging.info('global_step: %s', global_step)
                                try:
                                    # Run one step of the model.
                                    global_step = session.run(
                                        [tensors.global_step, tensors.train]
                                    )[0]
                                    logging.info('global_step: %s', global_step)
                                    local_step += 1
                                    now = time.time()
                                    for scheduler in schedulers:
                                        scheduler.step(now)
                                except tf.errors.AbortedError as e:
                                    should_retry = True
                                    logging.info('AbortedError (%s)', e)
                                except (KeyboardInterrupt, tf.errors.CancelledError):
                                    logging.info('cancelled')
                                    should_retry = False

                            logging.info('finished (is_master: %s)', is_master)

                            if is_master:
                                # Take the final checkpoint and compute the final accuracy.
                                now = time.time()
                                for scheduler in schedulers:
                                    scheduler.flush(now)

                    except tf.errors.AbortedError as e:
                        should_retry = True
                        logging.info('AbortedError (%s)', e)

                # Ask for all the services to stop.
                sv.stop()

    def eval_train(self, session, tensors, global_step):
        """Runs evaluation loop."""
        logging.info(
            'Eval, step %d:\n- on train set %s',
            global_step,
            self.model.format_metric_values(
                self.train_evaluator.evaluate_in_session(
                    session=session,
                    tensors=tensors
                )
            )
        )

    def eval(self, global_step=None):
        """Runs evaluation loop."""
        if self.qualitative_evaluator:
            logging.info(
                'Qualitative Eval, step %s:\n- on qualitative set %s',
                global_step,
                self.model.format_metric_values(self.qualitative_evaluator.evaluate())
            )
        logging.info(
            'Eval, step %s:\n- on eval set %s',
            global_step,
            self.model.format_metric_values(self.evaluator.evaluate())
        )
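
# Illustrative sketch (not the repo's actual implementation): a minimal
# scheduler consistent with how SimpleStepScheduler is called above. The
# contract inferred from the call sites is: step(now) runs `fn` whenever at
# least `min_interval` seconds have passed since the last run or `min_freq`
# steps have accumulated, and flush(now) forces one final run (e.g. a last
# checkpoint/eval at shutdown). All field names here are assumptions.
class ExampleStepScheduler(object):
    def __init__(self, fn, min_interval, min_freq, step, last_run):
        self.fn = fn
        self.min_interval = min_interval
        self.min_freq = min_freq
        self.last_step = step
        self.current_step = step
        self.last_run = last_run

    def _should_trigger(self, now):
        if self.min_interval and now - self.last_run >= self.min_interval:
            return True
        if self.min_freq and self.current_step - self.last_step >= self.min_freq:
            return True
        return False

    def step(self, now):
        self.current_step += 1
        if self._should_trigger(now):
            self.fn()
            self.last_run = now
            self.last_step = self.current_step

    def flush(self, now):
        self.fn()
        self.last_run = now
        self.last_step = self.current_step


# Illustrative sketch (not part of the original source): driving Trainer for a
# local, non-distributed run. cluster=None selects the single-process code
# path in _do_run_training, and a master task with index 0 satisfies its
# replica check; args and model come from the surrounding application.
def _example_run_training(args, model):
    from collections import namedtuple
    LocalTask = namedtuple('LocalTask', ['type', 'index'])
    trainer = Trainer(args, model, cluster=None, task=LocalTask('master', 0))
    trainer.run_training()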