def __init__(self, data_io, save_interval=100):
    '''Initialise the hook's bookkeeping for the reader behind ``data_io``.

    Args:
      data_io: reader object; its ``ds_name`` attribute labels the state
        message created here.
      save_interval: stored on the instance as ``_save_interval``
        (defaults to 100).
    '''
    super(GlobalReaderStateHook, self).__init__()

    # Values reported by get_task_num() / get_task_index() for this process.
    self._worker_num = get_task_num()
    self._worker_index = get_task_index()

    self._save_interval = save_interval
    self._step = 0
    self._data_io = data_io

    # Fresh DSState message: tagged with the data source name, with the
    # epoch counter starting from zero.
    ds_state = io_state_pb2.DSState()
    ds_state.ds_name = data_io.ds_name
    ds_state.epochs = 0
    self._state = ds_state
def __init__(self, data_io, save_interval=100):
    '''Initialise the hook and create its reader-state variable.

    Args:
      data_io: reader object; its ``ds_name`` attribute prefixes the name
        of the state variable.
      save_interval: stored on the instance as ``_save_interval``
        (defaults to 100).
    '''
    # Values reported by get_task_num() / get_task_index() for this process.
    self._worker_num = get_task_num()
    self._worker_index = get_task_index()

    self._step = 0
    self._save_interval = save_interval
    self._data_io = data_io

    # One row of _STATE_SIZE int8 cells per worker, zero-initialised and
    # excluded from training.
    state_name = self._data_io.ds_name + "/reader_state"
    state_shape = [self._worker_num, _STATE_SIZE]
    self._state_var = Variable(
        name=state_name,
        shape=state_shape,
        dtype=DataType.int8,
        initializer=Zeros(),
        trainable=False,
    )
def train(self, input_fn, auc_fn=auc, auc_interval=100, auc_bucket_num=200,
          max_step=sys.maxint, checkpoint_interval=None, log_format=LOG_FMT,
          user_hooks=None):
    '''Build the training graph and run train steps until told to stop.

    Args:
      input_fn: zero-argument callable returning ``(data, labels)``.
      auc_fn: callable that builds the AUC op; invoked as
        ``auc_fn(logits, labels, num_thresholds=auc_bucket_num)``.
      auc_interval: interval given to the "auc" MetricsHook and to the
        MetricsPrinterHook.
      auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds``.
      max_step: upper bound on the number of ``sess.run`` calls made here.
      checkpoint_interval: when truthy, a CheckpointHook with this
        interval is installed, but only on the task whose index is 0.
      log_format: format handed to MetricsPrinterHook.
      user_hooks: optional extra hook, or a list of hooks.

    Raises:
      ArgumentError: if the model function returns fewer than two outputs.
    '''
    features, labels = input_fn()
    outputs = self._model_fn(features, labels)
    if len(outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    # The loss is unpacked but not consumed below: optimize() is called
    # without arguments.
    loss, logits = outputs[0], outputs[1]
    train_op = self._optimizer.optimize()
    auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)

    # Hook order: built-in metrics, caller hooks, reader hooks,
    # checkpointing, and finally the printer.
    hooks = [
        QpsMetricsHook(),
        MetricsHook("auc", auc_op, interval=auc_interval),
    ]
    if isinstance(user_hooks, list):
        hooks += user_hooks
    elif user_hooks is not None:
        hooks.append(user_hooks)

    reader_hooks = get_collection(READER_HOOKS)
    if reader_hooks is not None:
        hooks += reader_hooks

    if checkpoint_interval:
        if get_task_index() == 0:
            hooks.append(CheckpointHook(checkpoint_interval))
    hooks.append(MetricsPrinterHook(log_format, auc_interval))

    sess = TrainSession(hooks=hooks)
    steps_run = 0
    while not sess.should_stop() and steps_run < max_step:
        sess.run(train_op)
        steps_run += 1
def evaluate(self, input_fn, checkpoint_version="", log_format=EVAL_LOG_FMT,
             log_interval=100, max_step=sys.maxint, auc_fn=auc,
             auc_bucket_num=200, user_hooks=None):
    '''Restore a checkpoint and run evaluation steps that report AUC.

    Args:
      input_fn: zero-argument callable returning ``(data, labels)``.
      checkpoint_version: passed to ``Saver.restore`` (task 0 only).
      log_format: format handed to MetricsPrinterHook.
      log_interval: interval handed to MetricsPrinterHook.
      max_step: upper bound on the number of ``sess.run`` calls in the
        evaluation loop.
      auc_fn: callable that builds the AUC op; invoked as
        ``auc_fn(logits, labels, num_thresholds=auc_bucket_num)``.
      auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds`` and to
        ``reset_auc_variables_op``.
      user_hooks: optional extra hook, or a list of hooks.

    Raises:
      ArgumentError: if the model function returns fewer than two outputs.
    '''
    from xdl.python.training.saver import Saver

    # Only the task with index 0 performs the restore.
    if get_task_index() == 0:
        ckpt_saver = Saver()
        ckpt_saver.restore(checkpoint_version)

    features, labels = input_fn()
    outputs = self._model_fn(features, labels)
    if len(outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    logits = outputs[1]
    auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)

    hooks = [
        QpsMetricsHook(),
        MetricsHook("auc", auc_op, interval=1),
    ]
    if isinstance(user_hooks, list):
        hooks += user_hooks
    elif user_hooks is not None:
        hooks.append(user_hooks)
    hooks.append(MetricsPrinterHook(log_format, log_interval))

    sess = TrainSession(hooks=hooks)
    # The reset op is only run when the default `auc` builder is in use.
    if auc_fn is auc:
        sess.run(reset_auc_variables_op(auc_bucket_num))

    steps_run = 0
    while not sess.should_stop() and steps_run < max_step:
        # Nothing is fetched explicitly; the hooks do the reporting.
        sess.run([])
        steps_run += 1
def predict(self, input_fn, checkpoint_version="", log_format=PREDICT_LOG_FMT,
            log_interval=100, max_step=sys.maxint, user_hooks=None):
    '''Restore a checkpoint and run steps that report the model's logits.

    Args:
      input_fn: zero-argument callable returning ``(data, labels)``.
      checkpoint_version: passed to ``Saver.restore`` (task 0 only).
      log_format: format handed to MetricsPrinterHook.
      log_interval: interval handed to MetricsPrinterHook.
      max_step: upper bound on the number of ``sess.run`` calls in the
        prediction loop.
      user_hooks: optional extra hook, or a list of hooks.

    Raises:
      ArgumentError: if the model function returns fewer than two outputs.
    '''
    from xdl.python.training.saver import Saver

    # Only the task with index 0 performs the restore.
    if get_task_index() == 0:
        ckpt_saver = Saver()
        ckpt_saver.restore(checkpoint_version)

    features, labels = input_fn()
    outputs = self._model_fn(features, labels)
    if len(outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    logits = outputs[1]

    # The logits are exposed through a metrics hook named "prediction".
    hooks = [
        QpsMetricsHook(),
        MetricsHook("prediction", logits, interval=1),
    ]
    if isinstance(user_hooks, list):
        hooks += user_hooks
    elif user_hooks is not None:
        hooks.append(user_hooks)
    hooks.append(MetricsPrinterHook(log_format, log_interval))

    sess = TrainSession(hooks=hooks)
    steps_run = 0
    while not sess.should_stop() and steps_run < max_step:
        # Nothing is fetched explicitly; the hooks do the reporting.
        sess.run([])
        steps_run += 1
def __init__(self, data_io, save_interval=100):
    '''Initialise the hook and create its reader state/offset variables.

    Args:
      data_io: reader object; its ``ds_name`` attribute prefixes the names
        of the variables created here.
      save_interval: stored on the instance as ``_save_interval`` and used
        to derive the ``io_ratio`` given to ``variable_info``
        (defaults to 100).
    '''
    super(ReaderStateHook, self).__init__()

    # Values reported by get_task_num() / get_task_index() for this process.
    self._worker_num = get_task_num()
    self._worker_index = get_task_index()

    self._step = 0
    self._save_interval = save_interval
    # Single zeroed row of _STATE_SIZE int8 cells.
    self._last_state = np.zeros((1, _STATE_SIZE), dtype=np.int8)
    self._data_io = data_io

    # NOTE(review): both variables are created inside the variable_info
    # context here -- confirm that `_offset_var` is meant to be covered by
    # the io_ratio hint as well.
    io_hint = xdl.python.framework.variable.variable_info(
        io_ratio=0.5 / save_interval)
    with io_hint:
        state_name = self._data_io.ds_name + "/reader_state"
        self._state_var = Variable(
            name=state_name,
            shape=[self._worker_num, _STATE_SIZE],
            dtype=DataType.int8,
            initializer=Zeros(),
            trainable=False,
        )
        offset_name = self._data_io.ds_name + '/reader_offset'
        self._offset_var = Variable(
            name=offset_name,
            shape=[self._worker_num],
            dtype=DataType.int64,
            initializer=Zeros(),
            trainable=False,
            vtype='index',
        )
def train_and_evaluate(self, train_input_fn, eval_input_fn, eval_interval,
                       eval_steps, checkpoint_interval, auc_fn=auc,
                       auc_bucket_num=200, train_hooks=None, eval_hooks=None,
                       auc_interval=100, log_interval=100, log_format=LOG_FMT,
                       eval_log_format=EVAL_LOG_FMT, max_step=sys.maxint):
    '''Alternate between rounds of training and rounds of evaluation.

    A training graph is built under ``model_scope('train')`` and an
    evaluation graph under ``model_scope('test')``, each with its own
    TrainSession. Training runs until the local step counter reaches a
    non-zero multiple of ``eval_interval`` (or ``max_step``, or the session
    asks to stop), then up to ``eval_steps`` evaluation steps are run, and
    the cycle repeats.

    Args:
      train_input_fn: zero-argument callable returning ``(data, labels)``
        for training.
      eval_input_fn: zero-argument callable returning ``(data, labels)``
        for evaluation.
      eval_interval: training pauses for evaluation whenever the local
        step counter hits a non-zero multiple of this value.
      eval_steps: maximum number of evaluation steps per round.
      checkpoint_interval: when truthy, a CheckpointHook with this
        interval is installed on the training session, but only on the
        task whose index is 0.
      auc_fn: callable that builds the AUC ops; invoked with
        ``num_thresholds=auc_bucket_num`` and a ``namescope`` of
        "train_auc" / "eval_auc".
      auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds``.
      train_hooks: optional extra hook, or list of hooks, for the
        training session.
      eval_hooks: optional extra hook, or list of hooks, for the
        evaluation session.
      auc_interval: interval for the training "auc" MetricsHook and the
        training MetricsPrinterHook.
      log_interval: interval for the evaluation MetricsPrinterHook.
      log_format: format for the training MetricsPrinterHook.
      eval_log_format: format for the evaluation MetricsPrinterHook.
      max_step: training stops once the local step counter reaches this
        value.

    Raises:
      ArgumentError: if the model function returns fewer than two outputs
        for either graph.
    '''
    with model_scope('train'):
        datas, labels = train_input_fn()
        train_outputs = self._model_fn(datas, labels)
        if len(train_outputs) < 2:
            raise ArgumentError("model_fn must return loss and logits")
        logits = train_outputs[1]
        train_op = self._optimizer.optimize()
        auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num,
                        namescope="train_auc")

        # Build the session's hooks in a separate list so that the
        # caller-supplied `train_hooks` are merged in rather than being
        # overwritten (which also used to duplicate the built-in hooks).
        train_sess_hooks = [
            QpsMetricsHook(),
            MetricsHook("auc", auc_op, interval=auc_interval),
        ]
        if train_hooks is not None:
            if isinstance(train_hooks, list):
                train_sess_hooks.extend(train_hooks)
            else:
                train_sess_hooks.append(train_hooks)
        reader_hooks = get_collection(READER_HOOKS)
        if reader_hooks is not None:
            train_sess_hooks.extend(reader_hooks)
        if checkpoint_interval and get_task_index() == 0:
            train_sess_hooks.append(CheckpointHook(checkpoint_interval))
        train_sess_hooks.append(MetricsPrinterHook(log_format, auc_interval))
        train_sess = TrainSession(hooks=train_sess_hooks)

    with model_scope('test'):
        eval_datas, eval_labels = eval_input_fn()
        eval_outputs = self._model_fn(eval_datas, eval_labels)
        if len(eval_outputs) < 2:
            raise ArgumentError("model_fn must return loss and logits")
        eval_logits = eval_outputs[1]
        eval_auc_op = auc_fn(eval_logits, eval_labels,
                             num_thresholds=auc_bucket_num,
                             namescope="eval_auc")

        # Same merging rule as for the training hooks above.
        eval_sess_hooks = [
            QpsMetricsHook(),
            MetricsHook("auc", eval_auc_op, interval=1),
        ]
        if eval_hooks is not None:
            if isinstance(eval_hooks, list):
                eval_sess_hooks.extend(eval_hooks)
            else:
                eval_sess_hooks.append(eval_hooks)
        eval_sess_hooks.append(
            MetricsPrinterHook(eval_log_format, log_interval))
        eval_sess = TrainSession(hooks=eval_sess_hooks)

    lstep = 0
    while True:
        print('\n>>> start train at local step[%d]\n' % lstep)
        # Train until the counter lands on a non-zero multiple of
        # eval_interval, max_step is reached, or the session stops.
        while not train_sess.should_stop() \
                and (lstep == 0 or lstep % eval_interval != 0) \
                and lstep < max_step:
            train_sess.run(train_op)
            lstep = lstep + 1
        # Move the counter off the eval boundary; otherwise the next
        # training round would exit immediately on the modulo test.
        # NOTE(review): this bump is not matched by a train step, so the
        # local step counter runs one ahead per round.
        lstep = lstep + 1

        eval_step = 0
        print('\n>>> start evaluate at local step[%d]\n' % lstep)
        while not eval_sess.should_stop() and eval_step < eval_steps:
            # Nothing is fetched explicitly; the hooks do the reporting.
            eval_sess.run([])
            eval_step = eval_step + 1

        if train_sess.should_stop() or lstep >= max_step:
            break