def train():
    """Build the DIN_MOGUJIE model and run a training-only session.

    Reads module-level configuration: model_type, EMBEDDING_DIM,
    HIDDEN_SIZE, ATTENTION_SIZE, train_file, batch_size, save_interval.

    Raises:
        Exception: if model_type is not 'din_mogujie'.
    """
    if model_type == 'din_mogujie':
        model = Model_DIN_MOGUJIE(EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE,
                                  False, train_file, batch_size)
    else:
        # FIX: the old message claimed dien was also supported, but this
        # entry point only builds din_mogujie.
        raise Exception('only support din_mogujie')
    with xdl.model_scope('train'):
        train_ops = model.build_network()
        lr = 0.001  # Adam learning rate (Adagrad was the other candidate)
        train_ops.append(xdl.Adam(lr).optimize())
        log_format = "[%(time)s] lstep[%(lstep)s] gstep[%(gstep)s] lqps[%(lqps)s] gqps[%(gqps)s] loss[%(loss)s]"
        hooks = [QpsMetricsHook(), MetricsPrinterHook(log_format)]
        if xdl.get_task_index() == 0:
            # Only the chief worker (task index 0) writes checkpoints.
            hooks.append(xdl.CheckpointHook(save_interval))
        train_sess = xdl.TrainSession(hooks=hooks)
    # Train-only entry point: no test ops / test session are built here.
    model.run(train_ops, train_sess)
def train(train_file=train_file, test_file=test_file, uid_voc=uid_voc,
          mid_voc=mid_voc, cat_voc=cat_voc, item_info=item_info,
          reviews_info=reviews_info, batch_size=128, maxlen=100,
          test_iter=700):
    """Build a DIEN model (optionally rocket-trained) plus its SampleIO
    feed, then run interleaved train/test sessions via model.run.

    Defaults for the data arguments come from same-named module globals;
    test_iter controls how often model.run evaluates on the test ops.
    """
    model = Model_DIEN(EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE,
                       LIGHT_EMBEDDING_DIM, LIGHT_HIDDEN_SIZE,
                       LIGHT_ATTENTION_SIZE,
                       use_rocket_training=use_rocket_training())
    sample_io = SampleIO(train_file, test_file, uid_voc, mid_voc, cat_voc,
                         item_info, reviews_info, batch_size, maxlen,
                         embedding_dim=EMBEDDING_DIM,
                         light_embedding_dim=LIGHT_EMBEDDING_DIM)
    log_format = ("[%(time)s] lstep[%(lstep)s] gstep[%(gstep)s] "
                  "lqps[%(lqps)s] gqps[%(gqps)s] loss[%(loss)s]")
    with xdl.model_scope('train'):
        train_ops = model.build_final_net(EMBEDDING_DIM, LIGHT_EMBEDDING_DIM,
                                          sample_io)
        # Adam optimizer; Adagrad was the alternative considered.
        train_ops.append(xdl.Adam(0.001).optimize())
        hooks = [QpsMetricsHook(), MetricsPrinterHook(log_format)]
        if xdl.get_task_index() == 0:
            # Only the chief worker checkpoints, at the configured interval.
            save_interval = xdl.get_config('checkpoint', 'save_interval')
            hooks.append(xdl.CheckpointHook(save_interval))
        train_sess = xdl.TrainSession(hooks=hooks)
    with xdl.model_scope('test'):
        test_ops = model.build_final_net(EMBEDDING_DIM, LIGHT_EMBEDDING_DIM,
                                         sample_io, is_train=False)
        test_sess = xdl.TrainSession()
    model.run(train_ops, train_sess, test_ops, test_sess, test_iter=test_iter)
def train(self, input_fn, auc_fn=auc, auc_interval=100, auc_bucket_num=200,
          max_step=sys.maxint, checkpoint_interval=None, log_format=LOG_FMT,
          user_hooks=None):
    '''Run the training loop for the network built by self._model_fn.

    Args:
        input_fn: callable returning (data, labels).
        auc_fn: metric builder producing the AUC op.
        auc_interval: steps between AUC reports.
        max_step: upper bound on local training steps.
        checkpoint_interval: checkpoint period (chief worker only).
        log_format: format string for the metrics printer.
        user_hooks: extra hook, or list of hooks, to install.
    '''
    data, labels = input_fn()
    model_outputs = self._model_fn(data, labels)
    if len(model_outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    loss, logits = model_outputs[0], model_outputs[1]
    train_op = self._optimizer.optimize()
    auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)
    hooks = [QpsMetricsHook(),
             MetricsHook("auc", auc_op, interval=auc_interval)]
    if user_hooks is not None:
        hooks.extend(user_hooks if isinstance(user_hooks, list)
                     else [user_hooks])
    reader_hooks = get_collection(READER_HOOKS)
    if reader_hooks is not None:
        hooks.extend(reader_hooks)
    if checkpoint_interval and get_task_index() == 0:
        # Only the chief worker writes checkpoints.
        hooks.append(CheckpointHook(checkpoint_interval))
    hooks.append(MetricsPrinterHook(log_format, auc_interval))
    sess = TrainSession(hooks=hooks)
    step = 0
    while not sess.should_stop() and step < max_step:
        sess.run(train_op)
        step += 1
def evaluate(self, input_fn, checkpoint_version="", log_format=EVAL_LOG_FMT,
             log_interval=100, max_step=sys.maxint, auc_fn=auc,
             auc_bucket_num=200, user_hooks=None):
    '''Restore a checkpoint (chief worker only) and run an evaluation loop.

    Args:
        input_fn: callable returning (data, labels).
        checkpoint_version: version string passed to Saver.restore.
        log_format: format string for the metrics printer.
        log_interval: steps between metric log lines.
        max_step: upper bound on evaluation steps.
        auc_fn: metric builder producing the AUC op.
        user_hooks: extra hook, or list of hooks, to install.
    '''
    from xdl.python.training.saver import Saver
    if get_task_index() == 0:
        Saver().restore(checkpoint_version)
    data, labels = input_fn()
    model_outputs = self._model_fn(data, labels)
    if len(model_outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    auc_op = auc_fn(model_outputs[1], labels, num_thresholds=auc_bucket_num)
    hooks = [QpsMetricsHook(), MetricsHook("auc", auc_op, interval=1)]
    if user_hooks is not None:
        hooks.extend(user_hooks if isinstance(user_hooks, list)
                     else [user_hooks])
    hooks.append(MetricsPrinterHook(log_format, log_interval))
    sess = TrainSession(hooks=hooks)
    if auc_fn is auc:
        # Default AUC implementation keeps bucket accumulators: clear them
        # so this evaluation starts from a clean slate.
        sess.run(reset_auc_variables_op(auc_bucket_num))
    step = 0
    while not sess.should_stop() and step < max_step:
        sess.run([])
        step += 1
def train(train_file=train_file, test_file=test_file, uid_voc=uid_voc,
          mid_voc=mid_voc, cat_voc=cat_voc, item_info=item_info,
          reviews_info=reviews_info, batch_size=128, maxlen=100,
          test_iter=700):
    """Train a DIN or DIEN model, chosen via the 'model' config entry.

    Defaults for the data arguments come from same-named module globals;
    test_iter controls how often model.run evaluates on the test ops.

    Raises:
        Exception: if the configured model is neither 'din' nor 'dien'.
    """
    model_name = xdl.get_config('model')
    if model_name == 'din':
        model = Model_DIN(EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)
    elif model_name == 'dien':
        model = Model_DIEN(EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)
    else:
        raise Exception('only support din and dien')
    sample_io = SampleIO(train_file, test_file, uid_voc, mid_voc, cat_voc,
                         item_info, reviews_info, batch_size, maxlen,
                         EMBEDDING_DIM)
    log_format = ("[%(time)s] lstep[%(lstep)s] gstep[%(gstep)s] "
                  "lqps[%(lqps)s] gqps[%(gqps)s] loss[%(loss)s]")
    with xdl.model_scope('train'):
        train_ops = model.build_final_net(EMBEDDING_DIM, sample_io)
        # Adam optimizer; Adagrad was the alternative considered.
        train_ops.append(xdl.Adam(0.001).optimize())
        hooks = [QpsMetricsHook(), MetricsPrinterHook(log_format)]
        if xdl.get_task_index() == 0:
            # Only the chief worker checkpoints, at the configured interval.
            save_interval = xdl.get_config('checkpoint', 'save_interval')
            hooks.append(xdl.CheckpointHook(save_interval))
        train_sess = xdl.TrainSession(hooks=hooks)
    with xdl.model_scope('test'):
        test_ops = model.build_final_net(EMBEDDING_DIM, sample_io,
                                         is_train=False)
        test_sess = xdl.TrainSession()
    print('=' * 10 + 'start train' + '=' * 10)
    model.run(train_ops, train_sess, test_ops, test_sess,
              test_iter=test_iter)
def predict(self, input_fn, checkpoint_version="", log_format=PREDICT_LOG_FMT,
            log_interval=100, max_step=sys.maxint, user_hooks=None):
    '''Restore a checkpoint (chief worker only) and stream predictions.

    Args:
        input_fn: callable returning (data, labels).
        checkpoint_version: version string passed to Saver.restore.
        log_format: format string for the metrics printer.
        log_interval: steps between metric log lines.
        max_step: upper bound on prediction steps.
        user_hooks: extra hook, or list of hooks, to install.
    '''
    from xdl.python.training.saver import Saver
    if get_task_index() == 0:
        Saver().restore(checkpoint_version)
    data, labels = input_fn()
    model_outputs = self._model_fn(data, labels)
    if len(model_outputs) < 2:
        raise ArgumentError("model_fn must return loss and logits")
    # Logits are surfaced through a MetricsHook so each step's predictions
    # are reported by the printer hook.
    hooks = [QpsMetricsHook(),
             MetricsHook("prediction", model_outputs[1], interval=1)]
    if user_hooks is not None:
        hooks.extend(user_hooks if isinstance(user_hooks, list)
                     else [user_hooks])
    hooks.append(MetricsPrinterHook(log_format, log_interval))
    sess = TrainSession(hooks=hooks)
    step = 0
    while not sess.should_stop() and step < max_step:
        sess.run([])
        step += 1
def train_and_evaluate(self, train_input_fn, eval_input_fn, eval_interval,
                       eval_steps, checkpoint_interval, auc_fn=auc,
                       auc_bucket_num=200, train_hooks=None, eval_hooks=None,
                       auc_interval=100, log_interval=100, log_format=LOG_FMT,
                       eval_log_format=EVAL_LOG_FMT, max_step=sys.maxint):
    '''Alternate training and evaluation phases in a single job.

    Trains until the local step hits a multiple of eval_interval, then runs
    eval_steps evaluation steps, and repeats until the session stops or
    max_step is reached.

    Args:
        train_input_fn / eval_input_fn: callables returning (data, labels).
        eval_interval: train steps between evaluation phases.
        eval_steps: number of steps per evaluation phase.
        checkpoint_interval: checkpoint period (chief worker only).
        auc_fn: metric builder producing the AUC ops.
        train_hooks / eval_hooks: extra hook, or list of hooks, per phase.
        auc_interval / log_interval: metric reporting periods.
        max_step: upper bound on local training steps.

    Raises:
        ArgumentError: if model_fn does not return (loss, logits, ...).
    '''
    with model_scope('train'):
        datas, labels = train_input_fn()
        train_outputs = self._model_fn(datas, labels)
        if len(train_outputs) < 2:
            raise ArgumentError("model_fn must return loss and logits")
        logits = train_outputs[1]  # train_outputs[0] is the loss (unused here)
        train_op = self._optimizer.optimize()
        auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num,
                        namescope="train_auc")
        # BUG FIX: the original rebound the `train_hooks` parameter to []
        # and then extended that list with itself, which duplicated the
        # built-in hooks and silently discarded any user-supplied hooks.
        # Accumulate into a separate local list instead.
        hooks = [QpsMetricsHook(),
                 MetricsHook("auc", auc_op, interval=auc_interval)]
        if train_hooks is not None:
            if isinstance(train_hooks, list):
                hooks.extend(train_hooks)
            else:
                hooks.append(train_hooks)
        reader_hooks = get_collection(READER_HOOKS)
        if reader_hooks is not None:
            hooks.extend(reader_hooks)
        if checkpoint_interval and get_task_index() == 0:
            # Only the chief worker writes checkpoints.
            hooks.append(CheckpointHook(checkpoint_interval))
        hooks.append(MetricsPrinterHook(log_format, auc_interval))
        train_sess = TrainSession(hooks=hooks)
    with model_scope('test'):
        eval_datas, eval_labels = eval_input_fn()
        eval_outputs = self._model_fn(eval_datas, eval_labels)
        if len(eval_outputs) < 2:
            raise ArgumentError("model_fn must return loss and logits")
        eval_logits = eval_outputs[1]
        eval_auc_op = auc_fn(eval_logits, eval_labels,
                             num_thresholds=auc_bucket_num,
                             namescope="eval_auc")
        # Same shadowing bug fixed for the evaluation hooks.
        ehooks = [QpsMetricsHook(),
                  MetricsHook("auc", eval_auc_op, interval=1)]
        if eval_hooks is not None:
            if isinstance(eval_hooks, list):
                ehooks.extend(eval_hooks)
            else:
                ehooks.append(eval_hooks)
        ehooks.append(MetricsPrinterHook(eval_log_format, log_interval))
        eval_sess = TrainSession(hooks=ehooks)
    lstep = 0
    while True:
        print('\n>>> start train at local step[%d]\n' % lstep)
        while not train_sess.should_stop() \
                and (lstep == 0 or lstep % eval_interval != 0) \
                and lstep < max_step:
            train_sess.run(train_op)
            lstep = lstep + 1
        # Bump past the eval_interval boundary so the next train phase's
        # loop condition is satisfiable again (original behavior preserved).
        lstep = lstep + 1
        eval_step = 0
        print('\n>>> start evaluate at local step[%d]\n' % lstep)
        while not eval_sess.should_stop() and eval_step < eval_steps:
            eval_sess.run([])
            eval_step = eval_step + 1
        if train_sess.should_stop() or lstep >= max_step:
            break
def run(is_training, files):
    """Build the ESMM network from `files` and run a train or test loop.

    Reads module-level config: batch_size, user_fn, ad_fn, embed_size, lr,
    is_debug, test_batch_num. Python 2 only (`xrange`, list-returning
    `dict.keys()` that supports `.sort()`).

    NOTE(review): correctness of the zip of run_vars.keys()/values() relies
    on both views enumerating in the same order — true per dict semantics,
    but fragile if run_vars is mutated between the two calls.
    """
    data_io = reader("esmm", files, 2, batch_size, 2, user_fn, ad_fn)
    batch = data_io.read()
    # Hash embeddings for the user-side features ('u_' prefixed).
    user_embs = list()
    for fn in user_fn:
        emb = xdl.embedding('u_' + fn, batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size, 1000, 'sum', vtype='hash')
        user_embs.append(emb)
    # Hash embeddings for the ad-side features ('a_' prefixed).
    ad_embs = list()
    for fn in ad_fn:
        emb = xdl.embedding('a_' + fn, batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size, 1000, 'sum', vtype='hash')
        ad_embs.append(emb)
    # model(is_training) returns the network-building callable; its outputs
    # are assumed to line up one-to-one with `keys` below — TODO confirm.
    var_list = model(is_training)(ad_embs, user_embs,
                                  batch["indicators"][0], batch["label"])
    keys = [
        'loss', 'ctr_prop', 'ctcvr_prop', 'cvr_prop', 'ctr_label',
        'ctcvr_label', 'cvr_label'
    ]
    run_vars = dict(zip(keys, list(var_list)))
    hooks = []
    if is_training:
        train_op = xdl.Adam(lr).optimize()
        hooks = get_collection(READER_HOOKS)
        if hooks is None:
            hooks = []
        if xdl.get_task_index() == 0:
            # Only the chief worker (task 0) writes checkpoints.
            ckpt_hook = xdl.CheckpointHook(1000)
            hooks.append(ckpt_hook)
        # None key: run train_op every step without exposing its value.
        run_vars.update({None: train_op})
    if is_debug > 1:
        # Debug mode: also fetch every gradient tensor each step.
        print("=========gradients")
        grads = xdl.get_gradients()
        grads_keys = grads[''].keys()
        grads_keys.sort()
        for key in grads_keys:
            run_vars.update({"grads {}".format(key): grads[''][key]})
    hooks.append(QpsMetricsHook())
    log_format = "lstep[%(lstep)s] gstep[%(gstep)s] " \
                 "lqps[%(lqps)s] gqps[%(gqps)s]"
    hooks.append(MetricsPrinterHook(log_format, 100))
    ckpt = xdl.get_config("checkpoint", "ckpt")
    if ckpt is not None and len(ckpt) > 0:
        if int(xdl.get_task_index()) == 0:
            # Chief restores the checkpoint; import kept local as in the
            # rest of this file.
            from xdl.python.training.saver import Saver
            saver = Saver()
            print("restore from %s" % ckpt)
            saver.restore(ckpt)
        else:
            # Non-chief workers wait a fixed 120s for the restore to finish
            # (crude barrier — no positive confirmation of completion).
            time.sleep(120)
    sess = xdl.TrainSession(hooks)
    if is_training:
        itr = 1
        ctr_auc = Auc('ctr')
        ctcvr_auc = Auc('ctcvr')
        cvr_auc = Auc('cvr')
        # Train until the session signals stop; accumulate streaming AUCs.
        while not sess.should_stop():
            print('iter=', itr)
            values = sess.run(run_vars.values())
            if not values:
                # Empty result (e.g. end-of-data step): skip metrics update.
                continue
            value_map = dict(zip(run_vars.keys(), values))
            print('loss=', value_map['loss'])
            ctr_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_auc.add(value_map['ctcvr_prop'],
                          value_map['ctcvr_label'])
            # CVR AUC is computed only over clicked impressions
            # (ctr_label == 1), per the ESMM formulation.
            cvr_auc.add_with_filter(value_map['cvr_prop'],
                                    value_map['cvr_label'],
                                    np.where(value_map['ctr_label'] == 1))
            itr += 1
        ctr_auc.show()
        ctcvr_auc.show()
        cvr_auc.show()
    else:
        # Test mode: fixed number of batches, same three AUC metrics.
        ctr_test_auc = Auc('ctr')
        ctcvr_test_auc = Auc('ctcvr')
        cvr_test_auc = Auc('cvr')
        for i in xrange(test_batch_num):
            print('iter=', i + 1)
            values = sess.run(run_vars.values())
            value_map = dict(zip(run_vars.keys(), values))
            print('test_loss=', value_map['loss'])
            ctr_test_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_test_auc.add(value_map['ctcvr_prop'],
                               value_map['ctcvr_label'])
            cvr_test_auc.add_with_filter(value_map['cvr_prop'],
                                         value_map['cvr_label'],
                                         np.where(value_map['ctr_label'] == 1))
        ctr_test_auc.show()
        ctcvr_test_auc.show()
        cvr_test_auc.show()