コード例 #1
    def update_progressbar(set_name):
        """Advance the console progress bar for ``set_name`` by one job.

        All state is stored as attributes on the function object itself
        (``current_set_name``, ``current_job_index``, ``total_jobs`` and
        ``pbar``). A fresh bar is created whenever the set name changes or the
        current bar has reached its total number of jobs.

        Args:
            set_name: name of the set being processed. ``"train"`` selects the
                coordinator's training job count; any other value selects the
                dev (validation) job count.
        """
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        # On the very first call the set-name comparison is already true, so the
        # not-yet-initialised job counters are never read (short-circuit ``or``).
        if (update_progressbar.current_set_name != set_name or
                update_progressbar.current_job_index == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            # presumably the coordinator has already advanced its epoch counter
            # by the time jobs are reported - verify against the coordinator
            current_epoch = coord._epoch-1

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            # NOTE(review): the original call continued onto a further line of
            # keyword arguments that is not available here; only ``max_value``
            # is known. Restore the remaining arguments if they are needed.
            update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs)

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(update_progressbar.current_job_index+1, force=True)

        update_progressbar.current_job_index += 1
コード例 #2
    def next_job(self, job):
        '''Sends a finished job back to the coordinator and retrieves in exchange the next one.

        Args:
            job (WorkerJob): job that was finished by a worker and whose results are to be
                digested by the coordinator

        Returns:
            WorkerJob. next job of one of the running epochs that will get
                associated with the worker from the finished job and put into state 'running'
        '''
        if self.is_chief:
            # Try to find the epoch the job belongs to
            epoch = next((candidate for candidate in self._epochs_running if candidate.id == job.epoch_id), None)
            if epoch:
                # We are going to manipulate things - let's avoid undefined state
                with self._lock:
                    # NOTE(review): the original comments here describe letting the epoch
                    # finish the job and moving a finished epoch from the 'running' to the
                    # 'done' collection, but the statements doing so are not present in
                    # this block - confirm against the full implementation.
                    # Check, if epoch is done now
                    if epoch.done():
                        log_info('%s' % epoch)
            else:
                # There was no running epoch found for this job - this should never happen.
                log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.' % (job.epoch_id, job.id))
            return self.get_job(job.worker)

        # We are a remote worker and have to hand over to the chief worker by HTTP
        result = self._talk_to_chief('', data=pickle.dumps(job))
        if result:
            # NOTE(review): unpickling data received over the network executes arbitrary
            # code if the peer is not trusted - only safe inside a trusted cluster.
            result = pickle.loads(result)
        return result
コード例 #3
    def _next_epoch(self):
        """Advance the coordination state machine by one epoch.

        First evaluates the early-stopping criterion on the recorded dev
        losses (only when early stopping and validation are both enabled and
        enough losses were collected). Then, while in train mode, either
        schedules the next training epoch - plus a validation epoch when the
        validation step matches - or leaves train mode once no training jobs
        are left.

        Returns:
            bool: True if at least one new epoch was appended to the running
                epochs (the epoch index is then incremented), False otherwise.
        """
        # State-machine of the coordination process

        # Indicates, if there were 'new' epoch(s) provided
        result = False

        # Make sure that early stop is enabled and validation part is enabled
        if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps):

            # Window of past losses, excluding the most recent one
            past_losses = self._dev_losses[-FLAGS.earlystop_nsteps:-1]
            # Calculate the mean of losses for past epochs
            mean_loss = np.mean(past_losses)
            # Calculate the standard deviation for losses from validation part in the past epochs
            std_loss = np.std(past_losses)
            # Update the list of losses incurred
            self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:]
            log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))

            # Check if validation loss has started increasing or is not decreasing substantially, making sure slight fluctuations don't bother the early stopping from working
            if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh):
                # Time to early stop
                log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
                self._dev_losses = []
                self._train = False

        if self._train:
            # We are in train mode
            if self._num_jobs_train_left > 0:
                # There are still jobs left
                num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train)
                self._num_jobs_train_left -= num_jobs_train

                # Let's try our best to keep the notion of curriculum learning

                # Append the training epoch
                self._epochs_running.append(Epoch(self, self._epoch, num_jobs_train, set_name='train'))

                if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
                    # The current epoch should also have a validation part
                    self._epochs_running.append(Epoch(self, self._epoch, self._num_jobs_dev, set_name='dev'))

                # Indicating that there were 'new' epoch(s) provided
                result = True
            else:
                # No jobs left, but still in train mode: concluding training.
                # This must only happen when the job budget is exhausted;
                # otherwise training would stop right after the first epoch.
                self._train = False

        if result:
            # Increment the epoch index
            self._epoch += 1
        return result
コード例 #4
def export():
    # Purpose (from the line below): restore trained variables into a simpler
    # inference graph and export it for serving, optionally as a TF Lite model.
    # NOTE(review): this block is not valid Python as it stands - several lines
    # appear to be missing (docstring quotes, `else:`/`try:` lines, bodies of an
    # `if` and of `do_graph_freeze`, the end of a list literal and the body of
    # the final `except`). The notes below mark each spot; the code itself is
    # left untouched.
    # NOTE(review): the next line reads like docstring text whose quotes were lost.
    Restores the trained variables into a simpler graph that will be exported for serving.
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        session = tf.Session(config=Config.session_config)

        # Build the inference graph for a single-item batch and derive the
        # comma-separated input/output node names and colon-separated input shapes
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        # Tensors contribute their op name, Operations their own name
        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())

        if not FLAGS.export_tflite:
            # Map variable names to variables, skipping the 'previous_state_' ones
            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
            # NOTE(review): everything below up to the second `mapping = ...` sits in the
            # same branch as shown, so the first mapping is immediately overwritten and
            # `mapping` is undefined when export_tflite is set - presumably an `else:`
            # line is missing before `def fixup`; confirm.
            # Create a saver using variables from the above newly created graph
            def fixup(name):
                # Rename variables under 'rnn/lstm_cell/' to 'lstm_fused_cell/'
                if name.startswith('rnn/lstm_cell/'):
                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
                return name

            mapping = {fixup(v.op.name): v for v in tf.global_variables()}

        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        output_filename = 'output_graph.pb'
        # NOTE(review): a `try:` matching the `except RuntimeError` at the end of
        # this function is not present; presumably it opened around here.
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                # NOTE(review): only the log message is present; no removal happens here.
            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

            # NOTE(review): the body of this `if` is missing.
            if not os.path.isdir(FLAGS.export_dir):

            # NOTE(review): the body of `do_graph_freeze` is missing.
            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):

            if not FLAGS.export_tflite:
                # Regular export: freeze straight to the output graph path, blacklisting
                # the previous_state_c / previous_state_h variables
                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
                # NOTE(review): the lines below look like the TF Lite branch - an
                # `else:` line is presumably missing before them; confirm.
                temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
                do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
                # Plain attribute holder describing the conversion of the temporary
                # frozen graph into the .tflite output file
                class TFLiteFlags():
                    def __init__(self):
                        self.graph_def_file = temp_freeze
                        self.inference_type = 'FLOAT'
                        self.input_arrays   = input_names
                        self.input_shapes   = input_shapes
                        self.output_arrays  = output_names
                        self.output_file    = output_tflite_path
                        self.output_format  = 'TFLITE'

                        # NOTE(review): this list literal is never closed; further
                        # entries and the closing bracket appear to be missing.
                        default_empty = [
                            'default_ranges_min', 'default_ranges_max',
                        # Every name in default_empty becomes an attribute set to None
                        for e in default_empty:
                            self.__dict__[e] = None

                flags = TFLiteFlags()
                # NOTE(review): no conversion call using `flags` is present before this message.
                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

            log_info('Models exported at %s' % (FLAGS.export_dir)) 
        # NOTE(review): no matching `try:` and no handler body are present for this clause.
        except RuntimeError as e:
コード例 #5
 def _end_training(self):
     """Finalize the training stopwatch and log the resulting training time."""
     # stopwatch() receives the value stored when training started and its
     # result replaces that value as the recorded training time
     elapsed = stopwatch(self._training_time)
     self._training_time = elapsed
     message = 'FINISHED Optimization - training time: %s' % format_duration(self._training_time)
     log_info(message)
コード例 #6
    def start_coordination(self, model_feeder, step=0):
        '''Starts to coordinate epochs and jobs among workers on base of
        data-set sizes, the (global) step and FLAGS parameters.

        Args:
            model_feeder (ModelFeeder): data-sets to be used for coordinated training

        Kwargs:
            step (int): global step of a loaded model to determine starting point
        '''
        with self._lock:

            # Number of GPUs per worker - fixed for now by local reality or cluster setup
            gpus_per_worker = len(Config.available_devices)

            # Number of batches processed per job per worker
            batches_per_job  = gpus_per_worker * max(1, FLAGS.iters_per_worker)

            # Number of batches per global step
            batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg)

            # Number of global steps per epoch - to be at least 1
            steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step)

            # The start epoch of our training
            self._epoch = step // steps_per_epoch

            # Number of additional 'jobs' trained already 'on top of' our start epoch
            jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job

            # Total number of train/dev jobs covering their respective whole sets (one epoch)
            self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job)
            self._num_jobs_dev   = max(1, model_feeder.dev.total_batches   // batches_per_job)

            if FLAGS.epoch < 0:
                # A negative epoch means to add its absolute number to the epochs already computed
                self._target_epoch = self._epoch + abs(FLAGS.epoch)
            else:
                # A non-negative epoch is taken as the absolute target epoch
                self._target_epoch = FLAGS.epoch

            # State variables
            # We only have to train, if we are told so and are not at the target epoch yet
            self._train = FLAGS.train and self._target_epoch > self._epoch

            if self._train:
                # The total number of jobs for all additional epochs to be trained
                # Will be decremented for each job that is produced/put into state 'open'
                self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained
                log_info('STARTING Optimization')
                self._training_time = stopwatch()

            # Important for debugging
            log_debug('step: %d' % step)
            log_debug('epoch: %d' % self._epoch)
            log_debug('target epoch: %d' % self._target_epoch)
            log_debug('steps per epoch: %d' % steps_per_epoch)
            log_debug('number of batches in train set: %d' % model_feeder.train.total_batches)
            log_debug('batches per job: %d' % batches_per_job)
            log_debug('batches per step: %d' % batches_per_step)
            log_debug('number of jobs in train set: %d' % self._num_jobs_train)
            log_debug('number of jobs already trained in first epoch: %d' % jobs_trained)


        # The coordinator is ready to serve
        self.started = True