Example #1
    def __init__(self, args, model, cluster, task):
        self.args = args
        self.model = model
        self.cluster = cluster
        self.task = task
        self.run_async = None

        def default_run_async(f, args):
            return f(*args)

        def run_async(f, args):
            return (self.run_async or default_run_async)(f, args)

        self.evaluator = Evaluator(
            self.args, self.model,
            train_dir(self.args.output_path),
            self.args.eval_data_paths,
            'eval_set',
            run_async=run_async
        )
        self.train_evaluator = Evaluator(
            self.args, self.model,
            train_dir(self.args.output_path),
            self.args.train_data_paths,
            'train_set',
            run_async=run_async
        )
        self.qualitative_evaluator = get_qualitative_evaluator(
            self.args,
            self.model,
            run_async=run_async
        )
        self.min_train_eval_rate = args.min_train_eval_rate
        self.global_step = None
        self.last_save = 0
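
The run_async attribute above is deliberately left as None in __init__ and only bound to a pool later (see run_training in Example #5), so the evaluators created here pick up the pool through the closure. A minimal standalone sketch of that late-binding pattern, using stdlib multiprocessing and hypothetical names (not part of the original module):

from multiprocessing import Pool

def _square(x):
    # Must be a module-level function so the pool can pickle it.
    return x * x

class AsyncRunner(object):
    def __init__(self):
        self.run_async = None  # a pool-backed callable is attached later

    def _default_run_async(self, f, args):
        return f(*args)  # synchronous fallback

    def call(self, f, args):
        # Uses the pool if one has been attached, otherwise runs inline.
        return (self.run_async or self._default_run_async)(f, args)

if __name__ == '__main__':
    runner = AsyncRunner()
    print(runner.call(_square, (3,)))        # inline call, returns 9
    pool = Pool(processes=2)
    runner.run_async = lambda f, args: pool.apply_async(f, args)
    print(runner.call(_square, (4,)).get())  # pool call returns an AsyncResult, .get() gives 16
    pool.close()
    pool.join()

Note the asymmetry the original code also has: the synchronous fallback returns the function's value directly, while the pool path returns an AsyncResult.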
Example #2
def get_qualitative_evaluator(args, model, run_async):
    if args.qualitative_data_paths:
        return Evaluator(
            args,
            model,
            train_dir(args.output_path),
            args.qualitative_data_paths,
            dataset='qualitative_set',
            eval_set_size=args.qualitative_set_size or args.eval_set_size,
            qualitative_set_size=args.qualitative_set_size,
            run_async=run_async
        )
    else:
        return None
Example #3
    def test_should_not_fail_eval_in_session(self):
        with tf.Graph().as_default():
            model = ExampleModel([EXAMPLE_PROPS_1] * BATCH_SIZE)
            tensors = model.build_train_graph(DATA_PATHS, BATCH_SIZE)

            evaluator = Evaluator(args=to_namedtuple(DEFAULT_ARGS,
                                                     name='args'),
                                  model=model,
                                  **DEFAULT_KWARGS)
            evaluator.init()

            get_logger().info('starting session')
            with tf.Session() as session:
                coord = tf.train.Coordinator()
                tf.train.start_queue_runners(sess=session, coord=coord)
                get_logger().info('evaluating')
                session.run(tensors.initializer)
                evaluator.evaluate_in_session(session, tensors)
            get_logger().info('done')
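
to_namedtuple in this test is a project helper that is not shown in the listing; assuming it simply turns the DEFAULT_ARGS dict into an immutable namedtuple so the evaluator can use attribute access, a minimal hypothetical sketch would be:

import collections

def to_namedtuple(d, name='namedtuple'):
    # Build a namedtuple type from the dict keys and fill it with the values,
    # so args.batch_size works where the dict had d['batch_size'].
    return collections.namedtuple(name, d.keys())(**d)

# Usage: to_namedtuple({'batch_size': 10, 'seed': 42}, name='args').batch_size == 10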
Example #4
def write_predictions(args, model, cluster, task):
    if not cluster or not task or task.type == 'master':
        pass  # Run locally.
    else:
        raise ValueError('invalid task_type %s' % (task.type,))

    if args.seed is not None:
        set_random_seed(args.seed)

    logger = get_logger()
    logger.info('Starting to write predictions on %s/%d', task.type, task.index)
    pool = Pool(processes=args.pool_size)

    def run_async(f, args):
        return pool.apply_async(f, args)

    qualitative_evaluator = get_qualitative_evaluator(
        args,
        model,
        run_async=run_async
    )
    if qualitative_evaluator:
        qualitative_evaluator.init()
        qualitative_evaluator.write_predictions()

    evaluator = Evaluator(
        args, model, train_dir(args.output_path), args.eval_data_paths,
        run_async=run_async
    )
    evaluator.init()
    evaluator.write_predictions()

    logger.info('Waiting for background tasks to finish')
    pool.close()
    pool.join()
    logger.info('Done writing predictions on %s/%d', task.type, task.index)
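
write_predictions submits work through pool.apply_async and then uses pool.close() followed by pool.join() to wait for the background writes before logging completion. The same lifecycle in a minimal stdlib-only illustration (the worker function is a hypothetical stand-in, not from the original module):

from multiprocessing import Pool

def _write_one(path):
    # Stand-in for a background prediction-writing task.
    return 'wrote %s' % path

if __name__ == '__main__':
    pool = Pool(processes=2)
    results = [pool.apply_async(_write_one, (p,)) for p in ['a.csv', 'b.csv']]
    pool.close()  # no further tasks will be submitted
    pool.join()   # block until every queued task has finished
    print([r.get() for r in results])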
Example #5
class Trainer(object):
    """Performs model training and optionally evaluation."""

    def __init__(self, args, model, cluster, task):
        self.args = args
        self.model = model
        self.cluster = cluster
        self.task = task
        self.run_async = None

        def default_run_async(f, args):
            return f(*args)

        def run_async(f, args):
            return (self.run_async or default_run_async)(f, args)

        self.evaluator = Evaluator(
            self.args, self.model,
            train_dir(self.args.output_path),
            self.args.eval_data_paths,
            'eval_set',
            run_async=run_async
        )
        self.train_evaluator = Evaluator(
            self.args, self.model,
            train_dir(self.args.output_path),
            self.args.train_data_paths,
            'train_set',
            run_async=run_async
        )
        self.qualitative_evaluator = get_qualitative_evaluator(
            self.args,
            self.model,
            run_async=run_async
        )
        self.min_train_eval_rate = args.min_train_eval_rate
        self.global_step = None
        self.last_save = 0

    def run_training(self):
        get_logger().info('creating async pool, pool size: %d', self.args.pool_size)
        pool = Pool(processes=self.args.pool_size)
        self.run_async = lambda f, args: pool.apply_async(f, args)
        self._do_run_training()
        get_logger().info('Waiting for tasks to complete')
        pool.close()
        pool.join()
        self.run_async = None

    def _do_run_training(self):
        """Runs a Master."""
        logger = get_logger()

        logger.info('tensorflow version: %s', tf.__version__)

        if self.args.seed is not None:
            set_random_seed(self.args.seed)
        self.train_evaluator.init()
        self.evaluator.init()
        if self.qualitative_evaluator:
            self.qualitative_evaluator.init()
        ensure_output_path(self.args.output_path)
        train_path = train_dir(self.args.output_path)
        # model_path = model_dir(self.args.output_path)
        is_master = self.task.type != 'worker'
        log_interval = self.args.log_interval_secs
        save_interval = self.args.save_interval_secs
        eval_interval = self.args.eval_interval_secs
        summary_interval = log_interval
        summary_freq = self.args.log_freq
        if is_master and self.task.index > 0:
            raise ValueError('Only one replica of master expected')

        if self.cluster:
            logging.info('Starting %s/%d', self.task.type, self.task.index)
            server = start_server(self.cluster, self.task)
            target = server.target
            device_fn = tf.train.replica_device_setter(
                ps_device='/job:ps',
                worker_device='/job:%s/task:%d' % (self.task.type, self.task.index),
                cluster=self.cluster
            )
            # We use a device_filter to limit the communication between this job
            # and the parameter servers, i.e., there is no need to directly
            # communicate with the other workers; attempting to do so can result
            # in reliability problems.
            device_filters = [
                '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index)
            ]
            config = tf.ConfigProto(device_filters=device_filters)
        else:
            target = ''
            device_fn = ''
            config = None

        logger.info('batch_size: %s', self.args.batch_size)

        logger.info(
            'available devices: %s',
            ', '.join([
                '%s (%s)' % (x.name, x.device_type)
                for x in list_local_devices()
            ])
        )

        with tf.Graph().as_default() as graph:
            with tf.device(device_fn):
                # Build the training graph.
                logger.info('building graph...')
                tensors = self.model.build_train_graph(
                    self.args.train_data_paths,
                    self.args.batch_size)

                logger.info('done building graph, calculating graph size...')
                logger.info('graph_size: %s bytes', '{:,}'.format(get_graph_size()))

                # Create a saver for writing training checkpoints.
                saver = tf.train.Saver(
                    max_to_keep=self.args.save_max_to_keep
                )

        # Create a "supervisor", which oversees the training process.
        sv = CustomSupervisor(
            model=self.model,
            graph=graph,
            is_chief=is_master,
            logdir=train_path,
            saver=saver,
            # Write summary_ops by hand.
            summary_op=None,
            global_step=tensors.global_step
        )

        save_path = sv.save_path

        should_retry = True
        local_step = 0

        while should_retry:
            try:
                should_retry = False
                with sv.managed_session(target, config=config) as session:
                    start_time = time.time()
                    now = start_time

                    global_step = session.run(tensors.global_step)
                    training_progress_logger = TrainingProgressLogger(
                        start_time,
                        global_step,
                        self.task
                    )

                    log_scheduler = SimpleStepScheduler(
                        lambda: training_progress_logger.log(now, global_step, local_step),
                        min_interval=log_interval,
                        min_freq=self.args.log_freq,
                        step=global_step,
                        last_run=start_time
                    )

                    def do_save():
                        logger.info('saving model to %s (%s)', save_path, global_step)
                        saver.save(session, save_path, tensors.global_step)

                    save_scheduler = SimpleStepScheduler(
                        do_save,
                        min_interval=save_interval,
                        min_freq=self.args.save_freq,
                        step=global_step,
                        last_run=start_time
                    )

                    eval_train_scheduler = SimpleStepScheduler(
                        lambda: self.eval_train(session, tensors, global_step),
                        min_interval=eval_interval,
                        min_freq=self.args.eval_freq,
                        step=global_step,
                        last_run=start_time
                    )

                    schedulers = [
                        log_scheduler,
                        save_scheduler,
                        eval_train_scheduler
                    ]

                    if is_master:
                        eval_scheduler = SimpleStepScheduler(
                            lambda: self.eval(global_step=global_step),
                            min_interval=save_interval,
                            min_freq=self.args.save_freq,
                            step=global_step,
                            last_run=start_time
                        )
                        schedulers = schedulers + [eval_scheduler]

                    summary_op = sv.summary_op if tensors.summary is None else tensors.summary
                    if summary_op is not None:
                        schedulers.append(SimpleStepScheduler(
                            lambda: sv.summary_writer.add_summary(
                                *session.run([summary_op, tensors.global_step])
                            ),
                            min_interval=summary_interval,
                            min_freq=summary_freq,
                            step=global_step,
                            last_run=start_time
                        ))

                    # Loop until the supervisor shuts down or args.max_steps have
                    # completed.
                    max_steps = self.args.max_steps
                    while not sv.should_stop() and global_step < max_steps:
                        logging.info("global_step: %s", global_step)
                        try:
                            # Run one step of the model.
                            global_step = session.run([tensors.global_step, tensors.train])[0]
                            logging.info("global_step: %s", global_step)
                            local_step += 1

                            now = time.time()
                            for scheduler in schedulers:
                                scheduler.step(now)

                        except tf.errors.AbortedError as e:
                            should_retry = True
                            logging.info('AbortedError (%s)', e)
                        except (KeyboardInterrupt, tf.errors.CancelledError):
                            logging.info('cancelled')
                            should_retry = False

                    logging.info('finished (is_master: %s)', is_master)

                    if is_master:
                        # Take the final checkpoint and compute the final accuracy.
                        now = time.time()
                        for scheduler in schedulers:
                            scheduler.flush(now)

            except tf.errors.AbortedError as e:
                should_retry = True
                logging.info('AbortedError (%s)', e)

        # Ask for all the services to stop.
        sv.stop()

    def eval_train(self, session, tensors, global_step):
        """Runs evaluation loop."""
        logging.info(
            'Eval, step %d:\n- on train set %s',
            global_step,
            self.model.format_metric_values(
                self.train_evaluator.evaluate_in_session(
                    session=session,
                    tensors=tensors
                )
            )
        )

    def eval(self, global_step=None):
        """Runs evaluation loop."""
        if self.qualitative_evaluator:
            logging.info(
                'Qualitative Eval, step %s:\n- on eval set %s',
                global_step,
                self.model.format_metric_values(self.qualitative_evaluator.evaluate())
            )
        logging.info(
            'Eval, step %s:\n- on eval set %s',
            global_step,
            self.model.format_metric_values(self.evaluator.evaluate())
        )
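
All periodic work in Example #5 (logging, checkpointing, train/eval runs, summaries) goes through SimpleStepScheduler, whose implementation is not part of this listing. Its constructor arguments suggest it fires its callback once both a minimum number of steps (min_freq) and a minimum wall-clock interval (min_interval) have elapsed, with flush forcing a final run; the sketch below is an assumption about that behaviour, not the project's actual class:

import time

class StepSchedulerSketch(object):
    # Hypothetical re-creation of the scheduling behaviour used above.

    def __init__(self, fn, min_interval, min_freq, step=0, last_run=None):
        self.fn = fn
        self.min_interval = min_interval  # seconds between runs
        self.min_freq = min_freq          # steps between runs
        self.current_step = step
        self.last_step = step
        self.last_run = last_run if last_run is not None else time.time()

    def step(self, now):
        self.current_step += 1
        due_by_steps = (self.current_step - self.last_step) >= self.min_freq
        due_by_time = (now - self.last_run) >= self.min_interval
        if due_by_steps and due_by_time:
            self.fn()
            self.last_step = self.current_step
            self.last_run = now

    def flush(self, now):
        # Force one final run, e.g. to write the last checkpoint and evaluation.
        self.fn()
        self.last_step = self.current_step
        self.last_run = now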