Example #1
  def infer_deploy(self, model, funcs=[]):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
      # Default graph
      self.graph = graph
      # Session
      config = tf.ConfigProto(allow_soft_placement=True)
      config.gpu_options.allow_growth = True

      devices = self.ctx.devices if len(self.ctx.devices) > 0 else self.devices
      config.gpu_options.visible_device_list = ','.join(str(x) for x in devices) if len(devices) > 0 else ''

      self.sess = tf.Session(graph=graph, config=config)


      #######################
      # Config model_deploy #
      #######################
      deploy_config = tfmodel_deploy.DeploymentConfig(num_clones=1,
                                                      devices=[],
                                                      clone_on_cpu=getattr(self, 'clone_on_cpu', False),
                                                      replica_id=0,
                                                      num_replicas=1,
                                                      num_ps_tasks=0)
      
      # init some info
      with tf.device(deploy_config.inputs_device()):
        #############################
        ####    define model input ##
        #############################
        data_queue = None
        if self.ctx.ant is not None:
          with tf.variable_scope('input'):
            data_queue = self.ctx.model.model_input(self.is_training)
            if data_queue is not None:
              self._has_model_input = True

        #############################
        ####    define model       ##
        #############################
        func = model.model_fn
        @functools.wraps(func)
        def network_fn(*args, **kwargs):
          res = func(self.is_training, *args, **kwargs)
          return res
  
        #######################
        # Create model clones #
        #######################
        self.clones = tfmodel_deploy.create_clones(deploy_config, network_fn, [data_queue] if data_queue is not None else None)

        # create other func
        for create_func in funcs:
          self._create_funcs.append(create_func())

        # write graph
        tf.train.write_graph(graph.as_graph_def(), self.dump_dir, 'infer_graph.pbtxt')
        # svg_graph = _convert_to_svg_graph(os.path.join(self.dump_dir, 'infer_graph.pbtxt'),
        #                                   self.dump_dir,
        #                                   ['input'])
        # if svg_graph is not None:
        #   self.ctx.job.send({'DATA': {'GRAPH': svg_graph}})

        # Global initialization
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

        custom_dataset_queue = tf.get_collection('CUSTOM_DATASET_QUEUE')
        if len(custom_dataset_queue) > 0:
          custom_dataset_queue[0].coord = self.coord
          custom_threads = custom_dataset_queue[0].start_threads(self.sess)
          self.threads.extend(custom_threads)

        # Restore from checkpoint
        restore_fns = _get_init_fn(self, model, self.dump_dir, self.ctx)
        if restore_fns is not None:
          for restore_fn in restore_fns:
            restore_fn(self.sess)

        # Restore from custom auxiliary init funcs
        for func in self._aux_init_funcs:
          func(self.sess)

        # model saver
        # model_variables = slim.get_model_variables() if model.model_variables is None else model.model_variables
        self.saver = tf.train.Saver(max_to_keep=1)

        # snapshot
        self.snapshot(0)

        # Value ops
        self.val_ops = self.clones[0].outputs
        if type(self.val_ops) != list and type(self.val_ops) != tuple:
          self.val_ops = [self.val_ops]
        if type(self.val_ops) == tuple:
          self.val_ops = list(self.val_ops)

        # Append recorder model fn
        if self.ctx.recorder is not None and self.ctx.recorder.model_fn is not None:
          self.val_ops.append(self.ctx.recorder.model_fn)
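
A minimal usage sketch (not part of the source) of how the ops assembled by infer_deploy could be driven; it assumes the owning object exposes sess, val_ops, coord and threads exactly as built above:

def run_inference(trainer, num_batches):
  # Each sess.run pulls one batch through the 'input' queue and the model clone.
  outputs = [trainer.sess.run(trainer.val_ops) for _ in range(num_batches)]
  # Shut down the queue-runner threads started above.
  trainer.coord.request_stop()
  trainer.coord.join(trainer.threads)
  return outputs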
Example #2
  def training_deploy(self, model, funcs=[]):
    # Horovod: initialize Horovod (prepare the MPI environment)
    if self.is_distribute_training:
      import horovod.tensorflow as hvd
      hvd.init()

      # reset num_clones = 1
      self.num_clones = 1
      self.rank = hvd.rank()
      self.local_rank = hvd.local_rank()

      get_global_context().quiet = False if self.rank == 0 else True

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
      # Default graph
      self.graph = graph
      # Session
      config = tf.ConfigProto(allow_soft_placement=True)
      config.gpu_options.allow_growth = True
      devices = self.ctx.devices if len(self.ctx.devices) > 0 else self.devices
      config.gpu_options.visible_device_list = ','.join(str(x) for x in devices) if len(devices) > 0 else ''
      self.sess = tf.Session(graph=graph, config=config)
      
      #######################
      # Config model deploy #
      #######################
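      # Under Horovod, clone_id_map={0: local_rank} pins the single clone onto this process's local device.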
      deploy_config = tfmodel_deploy.DeploymentConfig(num_clones=self.num_clones,
                                                      devices=[],
                                                      clone_on_cpu=self.clone_on_cpu,
                                                      replica_id=self.replica_id,
                                                      num_replicas=self.worker_replicas,
                                                      num_ps_tasks=self.num_ps_tasks,
                                                      clone_id_map={0:self.local_rank} if self.is_distribute_training else {})

      # init some info
      with tf.device(deploy_config.inputs_device()):
        # Create global_step
        with tf.device(deploy_config.variables_device()):
          global_step = slim.get_or_create_global_step()

        ###################################
        ####    define model input (CPU) ##
        ###################################
        with tf.variable_scope('input'):
          data_queue = self.ctx.model.model_input(self.is_training)
          if data_queue is not None:
            self._has_model_input = True
        
        ###################################
        ####    define model (CPU or GPU) #
        ###################################
        func = model.model_fn
        @functools.wraps(func)
        def network_fn(*args, **kwargs):
          res = func(self.is_training, *args, **kwargs)
          if kwargs['clone'] == 0:
            # 1.step save graph file
            tf.train.write_graph(self.sess.graph_def, self.dump_dir, 'graph.pbtxt')
            
            # # 2.step transfer to local graph net
            # logger.info('build model graph svg')
            # svg_graph = _convert_to_svg_graph(os.path.join(self.dump_dir, 'graph.pbtxt'),
            #                                   self.dump_dir,
            #                                   ['input'])
            # if svg_graph is not None:
            #   self.ctx.job.send({'DATA': {'GRAPH': svg_graph}})
          return res

        ####################################
        ####### Create summary      ########
        ####################################
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        ####################################
        ####### Create model clones ########
        ####################################
        self.clones = tfmodel_deploy.create_clones(deploy_config,
                                                   network_fn,
                                                   [data_queue] if data_queue is not None else None)
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
          summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # create other func
        for create_func in funcs:
          self._create_funcs.append(create_func())

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
          # total number of training samples
          num_samples = self.num_samples if self.num_samples > 0 else self.ctx.data_source.size

          # Horovod: adjust learning rate based on number of GPUs
          self.lr = _configure_learning_rate(self, num_samples, global_step)
          summaries.add(tf.summary.scalar('learning_rate', self.lr))

          # config optimizer
          optimizer = _configure_optimizer(self, self.lr)

          # Horovod: add Horovod Distributed Optimizer
          if self.is_distribute_training:
            optimizer = hvd.DistributedOptimizer(optimizer)

        # Variables to train.
        variables_to_train = _get_variables_to_train(self)
        
        with tf.control_dependencies(self.model_dependence):
          # Train_tensor
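          # optimize_clones (slim-style model_deploy) sums per-clone losses and gradients into one total loss and gradient list.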
          total_loss, clones_gradients = \
            tfmodel_deploy.optimize_clones(self.clones,
                                           optimizer,
                                           regularization_losses=None if self.regularization_loss else [],
                                           var_list=variables_to_train)

          summaries.add(tf.summary.scalar('total_loss', total_loss))

          # Create gradient updates.
          grad_updates = optimizer.apply_gradients(clones_gradients,
                                                   global_step=global_step)

        # Value ops
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
          self.val_ops = tf.identity(total_loss, name='train_op')
  
          if self.clones[0].outputs is not None:
            self.val_ops = [self.val_ops]
            if type(self.clones[0].outputs) == list:
              self.val_ops.extend(self.clones[0].outputs)
            elif type(self.clones[0].outputs) == tuple:
              self.val_ops.extend(list(self.clones[0].outputs))
            else:
              self.val_ops.append(self.clones[0].outputs)

          if type(self.val_ops) != list:
            self.val_ops = [self.val_ops]

        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))

        # Merge all summaries together.
        self.summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if self.summary_op is not None:
          val_ops_temp = [self.summary_op]
          val_ops_temp.extend(self.val_ops)
          self.val_ops = val_ops_temp

        # summary write
        if not os.path.exists(os.path.join(self.dump_dir, 'summary')):
          os.makedirs(os.path.join(self.dump_dir, 'summary'))

        self.train_writer = tf.summary.FileWriter(os.path.join(self.dump_dir, 'summary'), graph)

        # Global initialization
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        # Coordinator and queue-runner threads that drive the TF1 input pipeline
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)
        
        custom_dataset_queue = tf.get_collection('CUSTOM_DATASET_QUEUE')
        if len(custom_dataset_queue) > 0:
          custom_dataset_queue[0].coord = self.coord
          custom_threads = custom_dataset_queue[0].start_threads(self.sess)
          self.threads.extend(custom_threads)
        
        # Training saver
        # model_variables = slim.get_model_variables() if model.model_variables is None else model.model_variables
        self.saver = tf.train.Saver(max_to_keep=2)

        # Restore from checkpoint
        if not self.is_distribute_training or (self.is_distribute_training and self.rank == 0):
          restore_fns = _get_init_fn(self, model, self.dump_dir, self.ctx)
          if restore_fns is not None:
            for restore_fn in restore_fns:
              restore_fn(self.sess)

          # Restore from custom auxiliary init funcs
          for func in self._aux_init_funcs:
            func(self.sess)

        # Restore from auxiliary checkpoints
        for auxilary_scope, auxilary_checkpoint in self.auxilary_checkpoints.items():
          self.restore_scopy_from(model, auxilary_scope, auxilary_checkpoint)

        # Horovod: broadcast rank-0's initial variable values to all workers so they start in sync
        if self.is_distribute_training:
          bgv = hvd.BroadcastGlobalVariablesHook(0)
          bgv.begin()
          bgv.after_create_session(self.sess, self.coord)
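
A hypothetical training-loop sketch (not from the source) for the ops built above; when summary_op exists it is prepended to val_ops, so the first fetched element is the serialized summary and the second is the total loss:

def run_training(trainer, max_iter, summary_every=100):
  for step in range(max_iter):
    fetched = trainer.sess.run(trainer.val_ops)
    if trainer.summary_op is not None and step % summary_every == 0:
      # fetched[0]: merged summary, fetched[1]: total_loss (the 'train_op' identity)
      trainer.train_writer.add_summary(fetched[0], global_step=step)
  trainer.coord.request_stop()
  trainer.coord.join(trainer.threads)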
Example #3
    def infer_deploy(self, model):
        tf.logging.set_verbosity(tf.logging.INFO)
        with tf.Graph().as_default() as graph:
            # Default graph
            self.graph = graph
            # Session
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(graph=graph, config=config)

            #######################
            # Config model_deploy #
            #######################
            deploy_config = tfmodel_deploy.DeploymentConfig(
                num_clones=1,
                devices=self.devices,
                clone_on_cpu=getattr(self, 'clone_on_cpu', False),
                replica_id=0,
                num_replicas=1,
                num_ps_tasks=0)

            # init some info
            with tf.device(deploy_config.inputs_device()):
                #############################
                ####    define model input ##
                #############################
                with tf.variable_scope('input'):
                    data_queue = self.ctx.model.model_input(
                        self.is_training, self.ctx.data_source)

                #############################
                ####    define model       ##
                #############################
                func = model.model_fn

                @functools.wraps(func)
                def network_fn(*args, **kwargs):
                    res = func(self.is_training, *args, **kwargs)
                    return res

                #######################
                # Create model clones #
                #######################
                self.clones = tfmodel_deploy.create_clones(
                    deploy_config, network_fn, [data_queue])

                # write graph
                tf.train.write_graph(graph.as_graph_def(), self.dump_dir,
                                     'infer_graph.pbtxt')
                svg_graph = _convert_to_svg_graph(
                    os.path.join(self.dump_dir, 'infer_graph.pbtxt'),
                    self.dump_dir, ['input'])
                if svg_graph is not None:
                    self.ctx.job.send({'DATA': {'GRAPH': svg_graph}})

                # Global initialization
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

                self.coord = tf.train.Coordinator()
                self.threads = tf.train.start_queue_runners(sess=self.sess,
                                                            coord=self.coord)

                # Restore from checkpoint
                restore_fns = _get_init_fn(self, self.dump_dir, self.ctx)
                if restore_fns is not None:
                    for restore_fn in restore_fns:
                        restore_fn(self.sess)

                # model saver
                model_variables = (slim.get_model_variables()
                                   if self.model_variables is None
                                   else self.model_variables)
                self.saver = tf.train.Saver(var_list=model_variables,
                                            max_to_keep=1)

                # snapshot
                self.snapshot(0)

                # Value ops
                self.val_ops = self.clones[0].outputs
                if type(self.val_ops) != list and type(self.val_ops) != tuple:
                    self.val_ops = [self.val_ops]
                if type(self.val_ops) == tuple:
                    self.val_ops = list(self.val_ops)

                # Append recorder model fn
                if self.ctx.recorder is not None and self.ctx.recorder.model_fn is not None:
                    self.val_ops.append(self.ctx.recorder.model_fn)
Example #4
    def training_deploy(self, model, *args, **kwargs):
        tf.logging.set_verbosity(tf.logging.INFO)
        with tf.Graph().as_default() as graph:
            # default graph
            self.graph = graph

            # session
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(graph=graph, config=config)

            #######################
            # Config model deploy #
            #######################
            deploy_config = tfmodel_deploy.DeploymentConfig(
                num_clones=1,
                devices=self.devices,
                clone_on_cpu=self.clone_on_cpu,
                replica_id=self.replica_id,
                num_replicas=self.worker_replicas,
                num_ps_tasks=self.num_ps_tasks)

            # init some info
            with tf.device(deploy_config.inputs_device()):
                #############################
                #### Define model input #####
                #############################
                with tf.variable_scope('input'):
                    data_queue = self.ctx.model.model_input(self.is_training)
                    if data_queue is not None:
                        self._has_model_input = True

                func = model.model_fn

                @functools.wraps(func)
                def network_fn(*args, **kwargs):
                    #
                    logger.info('building computing graph')
                    res = func(self.is_training, *args, **kwargs)
                    tf.train.write_graph(self.sess.graph_def, self.dump_dir,
                                         'graph.pbtxt')
                    return res

                ####################################
                ####### Create summary      ########
                ####################################
                summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

                #############################
                # Create model clones #######
                #############################
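                # The extra {'trainer': self} dict is presumably forwarded to network_fn as keyword arguments.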
                self.clones = tfmodel_deploy.create_clones(
                    deploy_config, network_fn,
                    [data_queue] if data_queue is not None else None,
                    {'trainer': self})

                #########################################
                # Configure the optimization procedure. #
                #########################################
                with tf.device(deploy_config.optimizer_device()):
                    num_samples = self.num_samples if self.num_samples > 0 else self.ctx.data_source.size
                    trainable_variables = self.get_variables_to_train()  # all
                    update_ops = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS)  # all

                    for loss in tf.get_collection(
                            tf.GraphKeys.LOSSES, deploy_config.clone_scope(0)):
                        summaries.add(
                            tf.summary.scalar('losses/%s' % loss.op.name,
                                              loss))

                    for loss_name, loss_config in kwargs.items():
                        loss_scope = loss_config['scope']

                        # config loss log
                        self.loss_log[loss_name] = MovingAverage(10)

                        # Extract loss variable
                        self.loss_list[loss_name] = graph.get_tensor_by_name(
                            '{}:0'.format(loss_name))

                        if 'learning_rate' in loss_config:
                            self.lr_list[loss_name] = graph.get_tensor_by_name(
                                '{}:0'.format(loss_config['learning_rate']))
                            self.global_step_list[loss_name] = None
                        else:
                            global_step = tf.Variable(0, trainable=False)
                            self.lr_list[
                                loss_name] = self.configure_learning_rate(
                                    num_samples, global_step)
                            self.global_step_list[loss_name] = global_step

                        summaries.add(
                            tf.summary.scalar('%s_learning_rate' % loss_name,
                                              self.lr_list[loss_name]))

                        # config optimization procedure
                        self.optimizer_list[
                            loss_name] = self.configure_optimizer(
                                self.lr_list[loss_name])

                        # config related variables
                        for loss_scope_name in loss_scope.split(','):
                            if loss_name not in self.var_list:
                                self.var_list[loss_name] = []
                            self.var_list[loss_name].extend([
                                var for var in trainable_variables
                                if loss_scope_name in var.name
                            ])

                        # get update ops
                        for loss_scope_name in loss_scope.split(','):
                            if loss_name not in self.update_list:
                                self.update_list[loss_name] = []
                            self.update_list[loss_name].extend([
                                var for var in update_ops
                                if loss_scope_name in var.name
                            ])

                summaries |= set(
                    tf.get_collection(tf.GraphKeys.SUMMARIES,
                                      deploy_config.clone_scope(0)))

                # Merge all summaries together.
                self.summary_op = tf.summary.merge(list(summaries),
                                                   name='summary_op')

                # Variables to train.
                for loss_name, loss_var in self.loss_list.items():
                    optimizer = self.optimizer_list[loss_name]

                    with tf.name_scope(self.clones[0].scope):
                        with tf.device(self.clones[0].device):
                            clone_grad = optimizer.compute_gradients(
                                loss_var, var_list=self.var_list[loss_name])
                            grad_update = optimizer.apply_gradients(
                                clone_grad,
                                global_step=self.global_step_list[loss_name]
                                if self.global_step_list[loss_name] is not None
                                else None)

                            self.update_list[loss_name].append(grad_update)

                            with tf.control_dependencies(
                                [tf.group(*self.update_list[loss_name])]):
                                train_op = tf.identity(loss_var,
                                                       name='train_op_%s' %
                                                       loss_name)
                                self.trainop_list[loss_name] = train_op

                                if self.clones[0].outputs is not None:
                                    if type(self.clones[0].outputs) == dict:
                                        for k, v in self.clones[0].outputs.items():
                                            if type(k) != str:
                                                k = k.name.replace(':0', '')

                                            if k == loss_name:
                                                self.trainop_list[loss_name] = [train_op]
                                                if type(v) == list or type(v) == tuple:
                                                    self.trainop_list[loss_name].extend(list(v))
                                                else:
                                                    self.trainop_list[loss_name].append(v)

                                if type(self.trainop_list[loss_name]) != list:
                                    self.trainop_list[loss_name] = [self.trainop_list[loss_name]]

                                if self.summary_op is not None:
                                    val_ops_temp = [self.summary_op]
                                    val_ops_temp.extend(self.trainop_list[loss_name])
                                    self.trainop_list[loss_name] = val_ops_temp

                # summary write
                if not os.path.exists(os.path.join(self.dump_dir, 'summary')):
                    os.makedirs(os.path.join(self.dump_dir, 'summary'))

                self.train_writer = tf.summary.FileWriter(
                    os.path.join(self.dump_dir, 'summary'), graph)

                # Global initialization
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())
                # coord
                self.coord = tf.train.Coordinator()
                self.threads = tf.train.start_queue_runners(sess=self.sess,
                                                            coord=self.coord)

                custom_dataset_queue = tf.get_collection(
                    'CUSTOM_DATASET_QUEUE')
                if len(custom_dataset_queue) > 0:
                    custom_dataset_queue[0].coord = self.coord
                    custom_threads = custom_dataset_queue[0].start_threads(
                        self.sess)
                    self.threads.extend(custom_threads)

                # Training saver
                model_variables = (slim.get_model_variables()
                                   if model.model_variables is None
                                   else model.model_variables)
                self.saver = tf.train.Saver(var_list=model_variables,
                                            max_to_keep=2)

                # Restore from checkpoint
                restore_fns = self.get_init_fn(model, self.dump_dir, self.ctx)
                if restore_fns is not None:
                    for restore_fn in restore_fns:
                        restore_fn(self.sess)

                # Restore from auxiliary checkpoints
                for auxilary_scope, auxilary_checkpoint in self.auxilary_checkpoints.items():
                    self.restore_scopy_from(model, auxilary_scope,
                                            auxilary_checkpoint)
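
A hypothetical sketch (not in the source) of how the per-loss train ops collected in self.trainop_list could be alternated each step, e.g. for multi-loss or adversarial training; it assumes the attributes built above:

def run_alternating_training(trainer, max_iter, summary_every=100):
    for step in range(max_iter):
        for loss_name, ops in trainer.trainop_list.items():
            fetched = trainer.sess.run(ops)
            # summary_op, when present, was prepended to each per-loss op list above.
            if trainer.summary_op is not None and step % summary_every == 0:
                trainer.train_writer.add_summary(fetched[0], global_step=step)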
Example #5
    def training_deploy(self, model):
        tf.logging.set_verbosity(tf.logging.INFO)
        with tf.Graph().as_default() as graph:
            # Default graph
            self.graph = graph
            # Session
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(graph=graph, config=config)

            #######################
            # Config model deploy #
            #######################
            deploy_config = tfmodel_deploy.DeploymentConfig(
                num_clones=self.num_clones,
                devices=self.devices,
                clone_on_cpu=self.clone_on_cpu,
                replica_id=self.replica_id,
                num_replicas=self.worker_replicas,
                num_ps_tasks=self.num_ps_tasks)

            # init some info
            with tf.device(deploy_config.inputs_device()):
                #############################
                ####    define model input ##
                #############################
                with tf.variable_scope('input'):
                    data_queue = self.ctx.model.model_input(
                        self.is_training, self.ctx.data_source)

                #############################
                ####    define model       ##
                #############################
                func = model.model_fn

                @functools.wraps(func)
                def network_fn(*args, **kwargs):
                    res = func(self.is_training, *args, **kwargs)
                    if kwargs['clone'] == 0:
                        # 1.step save graph file
                        tf.train.write_graph(self.sess.graph_def,
                                             self.dump_dir, 'graph.pbtxt')

                        # 2.step transfer to local graph net
                        logger.info('build model graph svg')
                        svg_graph = _convert_to_svg_graph(
                            os.path.join(self.dump_dir, 'graph.pbtxt'),
                            self.dump_dir, ['input'])
                        if svg_graph is not None:
                            self.ctx.job.send({'DATA': {'GRAPH': svg_graph}})
                    return res

                #######################
                # Create model clones #
                #######################
                self.clones = tfmodel_deploy.create_clones(
                    deploy_config, network_fn,
                    [data_queue] if data_queue is not None else None,
                    {'trainer': self})
                first_clone_scope = deploy_config.clone_scope(0)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)
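                # UPDATE_OPS of the first clone (e.g. batch-norm moving-average updates) are grouped into the train op below.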

                # Create global_step
                with tf.device(deploy_config.variables_device()):
                    global_step = slim.get_or_create_global_step()

                #########################################
                # Configure the optimization procedure. #
                #########################################
                with tf.device(deploy_config.optimizer_device()):
                    num_samples = self.num_samples if self.num_samples > 0 else self.ctx.data_source.size
                    self.lr = _configure_learning_rate(self, num_samples,
                                                       global_step)
                    optimizer = _configure_optimizer(self, self.lr)

                # Variables to train.
                variables_to_train = _get_variables_to_train(self)

                with tf.control_dependencies(self.model_dependence):
                    # Train_tensor
                    total_loss, clones_gradients = \
                      tfmodel_deploy.optimize_clones(self.clones,
                                                     optimizer,
                                                     regularization_losses=None if self.regularization_loss else [],
                                                     var_list=variables_to_train)

                    # Create gradient updates.
                    grad_updates = optimizer.apply_gradients(
                        clones_gradients, global_step=global_step)

                # Value ops
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    self.val_ops = tf.identity(total_loss, name='train_op')

                    if self.clones[0].outputs is not None:
                        self.val_ops = [self.val_ops]
                        if type(self.clones[0].outputs) == list:
                            self.val_ops.extend(self.clones[0].outputs)
                        elif type(self.clones[0].outputs) == tuple:
                            self.val_ops.extend(list(self.clones[0].outputs))
                        else:
                            self.val_ops.append(self.clones[0].outputs)

                # Global initialization
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())
                # coord
                self.coord = tf.train.Coordinator()
                self.threads = tf.train.start_queue_runners(sess=self.sess,
                                                            coord=self.coord)

                custom_dataset_queue = tf.get_collection(
                    'CUSTOM_DATASET_QUEUE')
                if len(custom_dataset_queue) > 0:
                    custom_dataset_queue[0].coord = self.coord
                    custom_threads = custom_dataset_queue[0].start_threads(
                        self.sess)
                    self.threads.extend(custom_threads)

                # Training saver
                model_variables = (slim.get_model_variables()
                                   if self.model_variables is None
                                   else self.model_variables)
                self.saver = tf.train.Saver(var_list=model_variables,
                                            max_to_keep=2)

                # Restore from checkpoint
                restore_fns = _get_init_fn(self, self.dump_dir, self.ctx)
                if restore_fns is not None:
                    for restore_fn in restore_fns:
                        restore_fn(self.sess)