示例#1
0
 def __init__(self, data_io, save_interval=100):
     """Initialise the hook's counters and a fresh ``DSState`` message.

     Args:
       data_io: reader object; its ``ds_name`` attribute is copied into
         the state message.
       save_interval: value stored as ``_save_interval`` (default 100).
     """
     super(GlobalReaderStateHook, self).__init__()
     # Values returned by get_task_num() / get_task_index(), in that order.
     self._worker_num, self._worker_index = get_task_num(), get_task_index()
     self._step = 0
     self._save_interval = save_interval
     self._data_io = data_io
     # Fill in the state message locally, then publish it on the instance.
     ds_state = io_state_pb2.DSState()
     ds_state.ds_name = data_io.ds_name
     ds_state.epochs = 0
     self._state = ds_state
示例#2
0
 def __init__(self, data_io, save_interval=100):
     """Record worker info and create the ``<ds_name>/reader_state`` variable.

     Args:
       data_io: reader object; its ``ds_name`` prefixes the variable name.
       save_interval: value stored as ``_save_interval`` (default 100).
     """
     # NOTE(review): no base-class __init__ is invoked here - confirm that
     # is intentional.
     self._worker_num, self._worker_index = get_task_num(), get_task_index()
     self._step = 0
     self._save_interval = save_interval
     self._data_io = data_io
     # One row of _STATE_SIZE int8 values per worker, zero-initialised and
     # excluded from training.
     var_name = self._data_io.ds_name + "/reader_state"
     var_shape = [self._worker_num, _STATE_SIZE]
     self._state_var = Variable(name=var_name,
                                shape=var_shape,
                                dtype=DataType.int8,
                                initializer=Zeros(),
                                trainable=False)
示例#3
0
 def train(self,
           input_fn,
           auc_fn=auc,
           auc_interval=100,
           auc_bucket_num=200,
           max_step=sys.maxint,
           checkpoint_interval=None,
           log_format=LOG_FMT,
           user_hooks=None):
     """Build the model, assemble the hook list and run the training loop.

     Args:
       input_fn: callable returning a ``(data, labels)`` pair.
       auc_fn: callable invoked as
         ``auc_fn(logits, labels, num_thresholds=auc_bucket_num)``.
       auc_interval: interval given to the "auc" ``MetricsHook`` and to the
         ``MetricsPrinterHook``.
       auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds``.
       max_step: upper bound on the number of ``sess.run`` calls.
       checkpoint_interval: when truthy, task index 0 also registers a
         ``CheckpointHook`` with this interval.
       log_format: format handed to ``MetricsPrinterHook``.
       user_hooks: a single hook or a list of hooks, registered right after
         the built-in metric hooks.

     Raises:
       ArgumentError: if the model function returns fewer than two outputs.
     """
     data, labels = input_fn()
     outputs = self._model_fn(data, labels)
     if len(outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")
     # The loss is fetched alongside the logits but not used further here.
     loss, logits = outputs[0], outputs[1]
     train_op = self._optimizer.optimize()
     auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)

     # Built-in metric hooks first, then caller hooks, then reader hooks.
     hooks = [QpsMetricsHook(),
              MetricsHook("auc", auc_op, interval=auc_interval)]
     if user_hooks is not None:
         hooks += user_hooks if isinstance(user_hooks, list) else [user_hooks]
     reader_hooks = get_collection(READER_HOOKS)
     if reader_hooks is not None:
         hooks += reader_hooks
     if checkpoint_interval and get_task_index() == 0:
         hooks.append(CheckpointHook(checkpoint_interval))
     hooks.append(MetricsPrinterHook(log_format, auc_interval))

     sess = TrainSession(hooks=hooks)
     step = 0
     # should_stop() is consulted before the step bound on every iteration.
     while not sess.should_stop() and step < max_step:
         sess.run(train_op)
         step += 1
示例#4
0
 def evaluate(self,
              input_fn,
              checkpoint_version="",
              log_format=EVAL_LOG_FMT,
              log_interval=100,
              max_step=sys.maxint,
              auc_fn=auc,
              auc_bucket_num=200,
              user_hooks=None):
     """Restore a checkpoint, build the model and run the evaluation loop.

     Args:
       input_fn: callable returning a ``(data, labels)`` pair.
       checkpoint_version: passed to ``Saver.restore`` on task index 0.
       log_format: format handed to ``MetricsPrinterHook``.
       log_interval: interval handed to ``MetricsPrinterHook``.
       max_step: upper bound on the number of ``sess.run`` calls.
       auc_fn: callable invoked as
         ``auc_fn(logits, labels, num_thresholds=auc_bucket_num)``.
       auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds`` and to
         ``reset_auc_variables_op``.
       user_hooks: a single hook or a list of hooks, registered after the
         built-in metric hooks.

     Raises:
       ArgumentError: if the model function returns fewer than two outputs.
     """
     from xdl.python.training.saver import Saver
     # Only task index 0 performs the restore.
     if get_task_index() == 0:
         Saver().restore(checkpoint_version)

     data, labels = input_fn()
     outputs = self._model_fn(data, labels)
     if len(outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")
     logits = outputs[1]
     auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)

     hooks = [QpsMetricsHook(),
              MetricsHook("auc", auc_op, interval=1)]
     if user_hooks is not None:
         hooks += user_hooks if isinstance(user_hooks, list) else [user_hooks]
     hooks.append(MetricsPrinterHook(log_format, log_interval))

     sess = TrainSession(hooks=hooks)
     # The reset op is run only when the default auc function is in use.
     if auc_fn is auc:
         sess.run(reset_auc_variables_op(auc_bucket_num))

     step = 0
     # Each iteration runs an empty fetch list; should_stop() is checked
     # before the step bound.
     while not sess.should_stop() and step < max_step:
         sess.run([])
         step += 1
示例#5
0
 def predict(self,
             input_fn,
             checkpoint_version="",
             log_format=PREDICT_LOG_FMT,
             log_interval=100,
             max_step=sys.maxint,
             user_hooks=None):
     """Restore a checkpoint, build the model and run the prediction loop.

     Args:
       input_fn: callable returning a ``(data, labels)`` pair.
       checkpoint_version: passed to ``Saver.restore`` on task index 0.
       log_format: format handed to ``MetricsPrinterHook``.
       log_interval: interval handed to ``MetricsPrinterHook``.
       max_step: upper bound on the number of ``sess.run`` calls.
       user_hooks: a single hook or a list of hooks, registered after the
         built-in metric hooks.

     Raises:
       ArgumentError: if the model function returns fewer than two outputs.
     """
     from xdl.python.training.saver import Saver
     # Only task index 0 performs the restore.
     if get_task_index() == 0:
         Saver().restore(checkpoint_version)

     data, labels = input_fn()
     outputs = self._model_fn(data, labels)
     if len(outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")
     logits = outputs[1]

     # The logits are exposed through a metrics hook named "prediction".
     hooks = [QpsMetricsHook(),
              MetricsHook("prediction", logits, interval=1)]
     if user_hooks is not None:
         hooks += user_hooks if isinstance(user_hooks, list) else [user_hooks]
     hooks.append(MetricsPrinterHook(log_format, log_interval))

     sess = TrainSession(hooks=hooks)
     step = 0
     # Each iteration runs an empty fetch list; should_stop() is checked
     # before the step bound.
     while not sess.should_stop() and step < max_step:
         sess.run([])
         step += 1
示例#6
0
 def __init__(self, data_io, save_interval=100):
     """Initialise counters and create the reader state/offset variables.

     Args:
       data_io: reader object; its ``ds_name`` prefixes both variable names.
       save_interval: stored as ``_save_interval`` and used as the divisor
         of the ``io_ratio`` given to ``variable_info`` (default 100).
     """
     super(ReaderStateHook, self).__init__()
     self._worker_num, self._worker_index = get_task_num(), get_task_index()
     self._step = 0
     self._save_interval = save_interval
     self._last_state = np.zeros((1, _STATE_SIZE), dtype=np.int8)
     self._data_io = data_io

     ratio = 0.5 / save_interval
     # Both variables are created while the variable_info context is active.
     with xdl.python.framework.variable.variable_info(io_ratio=ratio):
         # One row of _STATE_SIZE int8 values per worker.
         state_name = self._data_io.ds_name + "/reader_state"
         self._state_var = Variable(name=state_name,
                                    shape=[self._worker_num, _STATE_SIZE],
                                    dtype=DataType.int8,
                                    initializer=Zeros(),
                                    trainable=False)
         # One int64 entry per worker, created with vtype='index'.
         offset_name = self._data_io.ds_name + '/reader_offset'
         self._offset_var = Variable(name=offset_name,
                                     shape=[self._worker_num],
                                     dtype=DataType.int64,
                                     initializer=Zeros(),
                                     trainable=False,
                                     vtype='index')
示例#7
0
    def train_and_evaluate(self,
                           train_input_fn,
                           eval_input_fn,
                           eval_interval,
                           eval_steps,
                           checkpoint_interval,
                           auc_fn=auc,
                           auc_bucket_num=200,
                           train_hooks=None,
                           eval_hooks=None,
                           auc_interval=100,
                           log_interval=100,
                           log_format=LOG_FMT,
                           eval_log_format=EVAL_LOG_FMT,
                           max_step=sys.maxint):
        """Alternate between training rounds and evaluation rounds.

        A training graph is built under ``model_scope('train')`` and an
        evaluation graph under ``model_scope('test')``, each with its own
        ``TrainSession``. Training then runs until the local step reaches a
        multiple of ``eval_interval`` (or ``max_step`` / session stop), after
        which up to ``eval_steps`` evaluation runs are executed.

        Args:
          train_input_fn: callable returning ``(data, labels)`` for training.
          eval_input_fn: callable returning ``(data, labels)`` for evaluation.
          eval_interval: training pauses for evaluation whenever the local
            step is a non-zero multiple of this value.
          eval_steps: maximum number of evaluation runs per round.
          checkpoint_interval: when truthy, task index 0 registers a
            ``CheckpointHook`` with this interval on the training session.
          auc_fn: callable invoked as ``auc_fn(logits, labels,
            num_thresholds=auc_bucket_num, namescope=...)``.
          auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds``.
          train_hooks: a single hook or a list of hooks added to the training
            session after the built-in metric hooks.
          eval_hooks: a single hook or a list of hooks added to the
            evaluation session after the built-in metric hooks.
          auc_interval: interval for the training "auc" ``MetricsHook`` and
            the training ``MetricsPrinterHook``.
          log_interval: interval for the evaluation ``MetricsPrinterHook``.
          log_format: format for the training ``MetricsPrinterHook``.
          eval_log_format: format for the evaluation ``MetricsPrinterHook``.
          max_step: upper bound on the local step counter.

        Raises:
          ArgumentError: if the model function returns fewer than two
            outputs for either the training or the evaluation inputs.
        """
        with model_scope('train'):
            datas, labels = train_input_fn()
            train_outputs = self._model_fn(datas, labels)
            if len(train_outputs) < 2:
                raise ArgumentError("model_fn must return loss and logits")
            logits = train_outputs[1]
            train_op = self._optimizer.optimize()
            auc_op = auc_fn(logits,
                            labels,
                            num_thresholds=auc_bucket_num,
                            namescope="train_auc")

            # Collect hooks in a separate local list: rebinding the
            # `train_hooks` parameter would discard the caller's hooks and
            # make the merge below extend the list with itself.
            train_hook_list = [
                QpsMetricsHook(),
                MetricsHook("auc", auc_op, interval=auc_interval),
            ]
            if train_hooks is not None:
                if isinstance(train_hooks, list):
                    train_hook_list.extend(train_hooks)
                else:
                    train_hook_list.append(train_hooks)
            reader_hooks = get_collection(READER_HOOKS)
            if reader_hooks is not None:
                train_hook_list.extend(reader_hooks)
            if checkpoint_interval and get_task_index() == 0:
                train_hook_list.append(CheckpointHook(checkpoint_interval))
            train_hook_list.append(
                MetricsPrinterHook(log_format, auc_interval))
            train_sess = TrainSession(hooks=train_hook_list)

        with model_scope('test'):
            eval_datas, eval_labels = eval_input_fn()
            eval_outputs = self._model_fn(eval_datas, eval_labels)
            if len(eval_outputs) < 2:
                raise ArgumentError("model_fn must return loss and logits")
            eval_logits = eval_outputs[1]
            eval_auc_op = auc_fn(eval_logits,
                                 eval_labels,
                                 num_thresholds=auc_bucket_num,
                                 namescope="eval_auc")
            # Same pattern as above: keep the `eval_hooks` parameter intact.
            eval_hook_list = [
                QpsMetricsHook(),
                MetricsHook("auc", eval_auc_op, interval=1),
            ]
            if eval_hooks is not None:
                if isinstance(eval_hooks, list):
                    eval_hook_list.extend(eval_hooks)
                else:
                    eval_hook_list.append(eval_hooks)
            eval_hook_list.append(MetricsPrinterHook(eval_log_format,
                                                     log_interval))
            eval_sess = TrainSession(hooks=eval_hook_list)

        lstep = 0
        while True:
            print('\n>>> start train at local step[%d]\n' % lstep)
            # Train until the step hits a non-zero multiple of eval_interval,
            # the step bound is reached, or the session asks to stop.
            while not train_sess.should_stop() and (lstep == 0 or lstep % eval_interval != 0) \
                  and lstep < max_step:
                train_sess.run(train_op)
                lstep = lstep + 1
            # Bump the counter without running a step so the next training
            # round is not blocked by the `lstep % eval_interval` test.
            lstep = lstep + 1
            eval_step = 0
            print('\n>>> start evaluate at local step[%d]\n' % lstep)
            while not eval_sess.should_stop() and eval_step < eval_steps:
                eval_sess.run([])
                eval_step = eval_step + 1
            if train_sess.should_stop() or lstep >= max_step:
                break