def __init__(self, hooks=None):
    """Start a session-backed runner.

    Args:
        hooks: a single hook list or None. Reader hooks registered in the
            READER_HOOKS collection are appended automatically.
    """
    current_env().sess_start()
    # Copy the caller's list so extending with reader hooks below does not
    # mutate the argument the caller still holds.
    self._hooks = [] if hooks is None else list(hooks)
    reader_hooks = get_collection(READER_HOOKS)
    if reader_hooks is not None:
        self._hooks.extend(reader_hooks)
    self._cur_scope = cur_model_scope()
    # BUG FIX: the original passed the raw `hooks` argument here, which
    # dropped the reader hooks merged above and could hand `None` to
    # SimpleSession. Use the fully assembled hook list instead.
    self._session = SimpleSession(self._hooks)
    self._finish = False
def train(self, input_fn, auc_fn=auc, auc_interval=100, auc_bucket_num=200,
          max_step=sys.maxint, checkpoint_interval=None, log_format=LOG_FMT,
          user_hooks=None):
    """Build the training graph from `input_fn`/`model_fn` and run the loop.

    Args:
        input_fn: callable returning (data, labels).
        auc_fn: metric builder applied to (logits, labels).
        auc_interval: how often (in steps) the AUC metric hook fires.
        auc_bucket_num: number of thresholds for the AUC computation.
        max_step: stop after this many local steps.
        checkpoint_interval: if set, worker 0 checkpoints at this interval.
        log_format: format string for the metrics printer hook.
        user_hooks: extra hook or list of hooks to attach to the session.

    Raises:
        ArgumentError: if `model_fn` yields fewer than two outputs.
    """
    features, labels = input_fn()
    outputs = self._model_fn(features, labels)
    if len(outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    loss, logits = outputs[0], outputs[1]
    train_op = self._optimizer.optimize()
    auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)

    # Assemble session hooks: metrics first, then user hooks, reader hooks,
    # optional checkpointing (worker 0 only), and finally the printer.
    hooks = [QpsMetricsHook(),
             MetricsHook("auc", auc_op, interval=auc_interval)]
    if user_hooks is not None:
        hooks.extend(user_hooks if isinstance(user_hooks, list)
                     else [user_hooks])
    reader_hooks = get_collection(READER_HOOKS)
    if reader_hooks is not None:
        hooks.extend(reader_hooks)
    if checkpoint_interval and get_task_index() == 0:
        hooks.append(CheckpointHook(checkpoint_interval))
    hooks.append(MetricsPrinterHook(log_format, auc_interval))

    sess = TrainSession(hooks=hooks)
    step = 0
    while not sess.should_stop() and step < max_step:
        sess.run(train_op)
        step += 1
def train_and_evaluate(self, train_input_fn, eval_input_fn, eval_interval,
                       eval_steps, checkpoint_interval, auc_fn=auc,
                       auc_bucket_num=200, train_hooks=None, eval_hooks=None,
                       auc_interval=100, log_interval=100, log_format=LOG_FMT,
                       eval_log_format=EVAL_LOG_FMT, max_step=sys.maxint):
    """Alternate training and evaluation until the reader stops or
    `max_step` local steps have run.

    Args:
        train_input_fn / eval_input_fn: callables returning (data, labels).
        eval_interval: run an evaluation pass every this many train steps.
        eval_steps: number of batches per evaluation pass.
        checkpoint_interval: if set, worker 0 checkpoints at this interval.
        auc_fn: metric builder applied to (logits, labels).
        auc_bucket_num: number of thresholds for the AUC computation.
        train_hooks / eval_hooks: extra hook or list of hooks per session.
        auc_interval: train-AUC metric interval (steps).
        log_interval: eval metrics printer interval (steps).
        log_format / eval_log_format: printer format strings.
        max_step: stop after this many local train steps.

    Raises:
        ArgumentError: if `model_fn` yields fewer than two outputs.
    """
    with model_scope('train'):
        datas, labels = train_input_fn()
        train_outputs = self._model_fn(datas, labels)
        if len(train_outputs) < 2:
            raise ArgumentError("model_fn must return loss and logits")
        loss = train_outputs[0]
        logits = train_outputs[1]
        train_op = self._optimizer.optimize()
        auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num,
                        namescope="train_auc")
        # BUG FIX: the original rebound the `train_hooks` parameter to a new
        # empty list and then extended that list with itself, which silently
        # discarded every user-supplied training hook. Build the session's
        # hook list under a distinct name.
        session_train_hooks = [
            QpsMetricsHook(),
            MetricsHook("auc", auc_op, interval=auc_interval)]
        if train_hooks is not None:
            if isinstance(train_hooks, list):
                session_train_hooks.extend(train_hooks)
            else:
                session_train_hooks.append(train_hooks)
        reader_hooks = get_collection(READER_HOOKS)
        if reader_hooks is not None:
            session_train_hooks.extend(reader_hooks)
        if checkpoint_interval and get_task_index() == 0:
            session_train_hooks.append(CheckpointHook(checkpoint_interval))
        session_train_hooks.append(MetricsPrinterHook(log_format,
                                                      auc_interval))
        train_sess = TrainSession(hooks=session_train_hooks)
    with model_scope('test'):
        eval_datas, eval_labels = eval_input_fn()
        eval_outputs = self._model_fn(eval_datas, eval_labels)
        if len(eval_outputs) < 2:
            raise ArgumentError("model_fn must return loss and logits")
        eval_logits = eval_outputs[1]
        eval_auc_op = auc_fn(eval_logits, eval_labels,
                             num_thresholds=auc_bucket_num,
                             namescope="eval_auc")
        # BUG FIX: same parameter-shadowing bug as above for `eval_hooks`.
        session_eval_hooks = [
            QpsMetricsHook(),
            MetricsHook("auc", eval_auc_op, interval=1)]
        if eval_hooks is not None:
            if isinstance(eval_hooks, list):
                session_eval_hooks.extend(eval_hooks)
            else:
                session_eval_hooks.append(eval_hooks)
        session_eval_hooks.append(MetricsPrinterHook(eval_log_format,
                                                     log_interval))
        eval_sess = TrainSession(hooks=session_eval_hooks)
    lstep = 0
    while True:
        print('\n>>> start train at local step[%d]\n' % lstep)
        # Train until the next eval boundary (lstep divisible by
        # eval_interval), the reader stops, or max_step is reached.
        while not train_sess.should_stop() \
                and (lstep == 0 or lstep % eval_interval != 0) \
                and lstep < max_step:
            train_sess.run(train_op)
            lstep = lstep + 1
        # Step past the eval boundary so the train loop resumes next round
        # instead of exiting immediately on the divisibility check.
        lstep = lstep + 1
        eval_step = 0
        print('\n>>> start evaluate at local step[%d]\n' % lstep)
        while not eval_sess.should_stop() and eval_step < eval_steps:
            eval_sess.run([])
            eval_step = eval_step + 1
        if train_sess.should_stop() or lstep >= max_step:
            break
def get_sparse_grads(name):
    """Return the sparse gradient registered under `name` in the
    'sparse_grad' collection (first entry of the collection)."""
    sparse_grad_map = get_collection('sparse_grad')[0]
    return sparse_grad_map[name]
def run(is_training, files):
    """Build and run the ESMM model over `files`.

    In training mode: optimizes with Adam, optionally restores a checkpoint,
    and accumulates CTR / CTCVR / CVR AUC until the reader stops.
    In eval mode: runs `test_batch_num` batches and reports the same AUCs.

    Relies on module-level config: `batch_size`, `embed_size`, `user_fn`,
    `ad_fn`, `lr`, `is_debug`, `test_batch_num`, `model`.
    """
    # Reader over the input files; presumably (name, files, threads,
    # batch_size, epochs, feature lists) — TODO confirm against `reader`.
    data_io = reader("esmm", files, 2, batch_size, 2, user_fn, ad_fn)
    batch = data_io.read()
    # Hash-typed sum-pooled embeddings, one per user-side feature.
    user_embs = list()
    for fn in user_fn:
        emb = xdl.embedding('u_' + fn, batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size, 1000, 'sum', vtype='hash')
        user_embs.append(emb)
    # Same for ad-side features.
    ad_embs = list()
    for fn in ad_fn:
        emb = xdl.embedding('a_' + fn, batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size, 1000, 'sum', vtype='hash')
        ad_embs.append(emb)
    # The model returns outputs positionally matching `keys` below.
    var_list = model(is_training)(ad_embs, user_embs,
                                  batch["indicators"][0], batch["label"])
    keys = [
        'loss', 'ctr_prop', 'ctcvr_prop', 'cvr_prop', 'ctr_label',
        'ctcvr_label', 'cvr_label'
    ]
    run_vars = dict(zip(keys, list(var_list)))
    hooks = []
    if is_training:
        train_op = xdl.Adam(lr).optimize()
        hooks = get_collection(READER_HOOKS)
        if hooks is None:
            hooks = []
        # Only worker 0 writes checkpoints.
        if xdl.get_task_index() == 0:
            ckpt_hook = xdl.CheckpointHook(1000)
            hooks.append(ckpt_hook)
        # Keyed by None so train_op runs with the fetches but its result
        # is not looked up by name afterwards.
        run_vars.update({None: train_op})
        if is_debug > 1:
            print("=========gradients")
            grads = xdl.get_gradients()
            # Python 2 idiom: .keys() returns a sortable list.
            grads_keys = grads[''].keys()
            grads_keys.sort()
            for key in grads_keys:
                run_vars.update({"grads {}".format(key): grads[''][key]})
    hooks.append(QpsMetricsHook())
    log_format = "lstep[%(lstep)s] gstep[%(gstep)s] " \
        "lqps[%(lqps)s] gqps[%(gqps)s]"
    hooks.append(MetricsPrinterHook(log_format, 100))
    # Optional checkpoint restore: worker 0 restores, others wait so they
    # do not start before the restore completes.
    ckpt = xdl.get_config("checkpoint", "ckpt")
    if ckpt is not None and len(ckpt) > 0:
        if int(xdl.get_task_index()) == 0:
            from xdl.python.training.saver import Saver
            saver = Saver()
            print("restore from %s" % ckpt)
            saver.restore(ckpt)
        else:
            time.sleep(120)
    sess = xdl.TrainSession(hooks)
    if is_training:
        itr = 1
        ctr_auc = Auc('ctr')
        ctcvr_auc = Auc('ctcvr')
        cvr_auc = Auc('cvr')
        while not sess.should_stop():
            print('iter=', itr)
            # NOTE: relies on run_vars.values()/keys() enumerating in the
            # same order between the two calls (holds while the dict is
            # not modified in between).
            values = sess.run(run_vars.values())
            if not values:
                continue
            value_map = dict(zip(run_vars.keys(), values))
            print('loss=', value_map['loss'])
            ctr_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_auc.add(value_map['ctcvr_prop'],
                          value_map['ctcvr_label'])
            # CVR AUC is computed only on clicked samples (ctr_label == 1).
            cvr_auc.add_with_filter(value_map['cvr_prop'],
                                    value_map['cvr_label'],
                                    np.where(value_map['ctr_label'] == 1))
            itr += 1
        ctr_auc.show()
        ctcvr_auc.show()
        cvr_auc.show()
    else:
        ctr_test_auc = Auc('ctr')
        ctcvr_test_auc = Auc('ctcvr')
        cvr_test_auc = Auc('cvr')
        for i in xrange(test_batch_num):
            print('iter=', i + 1)
            values = sess.run(run_vars.values())
            value_map = dict(zip(run_vars.keys(), values))
            print('test_loss=', value_map['loss'])
            ctr_test_auc.add(value_map['ctr_prop'],
                             value_map['ctr_label'])
            ctcvr_test_auc.add(value_map['ctcvr_prop'],
                               value_map['ctcvr_label'])
            cvr_test_auc.add_with_filter(value_map['cvr_prop'],
                                         value_map['cvr_label'],
                                         np.where(value_map['ctr_label'] == 1))
        ctr_test_auc.show()
        ctcvr_test_auc.show()
        cvr_test_auc.show()