def run():
    """Train the AMS main model with two auxiliary model servers.

    Starts a "user_graph" and an "ad_graph" model server (each with a
    forward/backward UniqueCache sized to the task count), builds hash
    embeddings for the user features, feeds image sub-graph outputs as
    gear inputs to the main model, and runs the Adam training loop until
    the session stops.
    """
    # Model server for the user-side sub-graph; caches are sharded by task count.
    user_ms = xdl.ModelServer(
        "user_graph", user_graph_train, xdl.DataType.float,
        xdl.ModelServer.Forward.UniqueCache(xdl.get_task_num()),
        xdl.ModelServer.Backward.UniqueCache(xdl.get_task_num()))
    xdl.current_env().start_model_server(user_ms)
    # Model server for the ad-side sub-graph, configured identically.
    ad_ms = xdl.ModelServer(
        "ad_graph", ad_graph_train, xdl.DataType.float,
        xdl.ModelServer.Forward.UniqueCache(xdl.get_task_num()),
        xdl.ModelServer.Backward.UniqueCache(xdl.get_task_num()))
    xdl.current_env().start_model_server(ad_ms)
    batch = reader().read()
    # 16-dim hash embeddings with sum combiner; 2M is presumably the
    # feature-dimension/capacity hint — TODO confirm against xdl.embedding docs.
    user0 = xdl.embedding("user0", batch["user0"],
                          xdl.TruncatedNormal(stddev=0.001),
                          16, 2 * 1024 * 1024, "sum", vtype="hash")
    user1 = xdl.embedding("user1", batch["user1"],
                          xdl.TruncatedNormal(stddev=0.001),
                          16, 2 * 1024 * 1024, "sum", vtype="hash")
    ad0 = batch["ad0"]
    ad1 = batch["ad1"]
    # Run the image features through the model servers; py_func converts the
    # XDL segment layout to TF-style segment ids on the host.
    img0 = user_ms(batch["user_img"].ids)
    ids0 = xdl.py_func(to_tf_segment_id,
                       [batch["user_img"].segments], [np.int32])[0]
    img1 = ad_ms(batch["ad_img"].ids)
    ids1 = xdl.py_func(to_tf_segment_id,
                       [batch["ad_img"].segments], [np.int32])[0]
    label = batch['label']
    # Model-server outputs are wired in as gear inputs so their gradients
    # flow back to the auxiliary graphs.
    loss = ams_main(main_model)(user0, user1, ad0, ad1, label, ids0, ids1,
                                gear_inputs=[img0, img1])
    optimizer = xdl.Adam(0.0005).optimize()
    run_ops = [loss, optimizer]
    sess = xdl.TrainSession([])
    while not sess.should_stop():
        values = sess.run(run_ops)
        # sess.run may return None (e.g. on a skipped step); only log real values.
        if values is not None:
            print 'loss: ', values[0]
def train():
    """Build the embedding model and run SGD training until the session stops.

    Reads batches from the module-level ``reader``, builds two 8-dim hash
    embeddings, wires them into ``model`` together with the dense ``deep0``
    input, and optimizes with SGD (lr=0.5).  Loss is printed every 10 steps
    by an ``xdl.LoggerHook``.
    """
    batch = reader.read()
    # Hash-typed sparse embeddings with truncated-normal init.
    emb1 = xdl.embedding('emb1', batch['sparse0'],
                         xdl.TruncatedNormal(stddev=0.001), 8, 1024,
                         vtype='hash')
    emb2 = xdl.embedding('emb2', batch['sparse1'],
                         xdl.TruncatedNormal(stddev=0.001), 8, 1024,
                         vtype='hash')
    loss = model(batch['deep0'], [emb1, emb2], batch['label'])
    train_op = xdl.SGD(0.5).optimize()
    log_hook = xdl.LoggerHook(loss, "loss:{0}", 10)
    # FIX: the original created a hook-less xdl.TrainSession() up front and
    # then unconditionally replaced it with this hooked one; the dead session
    # has been removed so only the session actually used is constructed.
    sess = xdl.TrainSession(hooks=[log_hook])
    while not sess.should_stop():
        sess.run(train_op)
def get_xdl_initializer(name="glorot"):
    """Return an XDL weight initializer selected by name.

    Args:
        name: one of ``"const"`` (constant 0.3), ``"glorot"``
            (variance-scaling / Glorot-style, seed 3), or ``"normal"``
            (truncated normal, stddev 0.36, seed 3).

    Returns:
        The corresponding xdl initializer object.

    Raises:
        ValueError: if ``name`` is not a recognized initializer name.
            (The original silently returned ``None`` here, which only
            failed later inside xdl with a confusing error.)
    """
    # Lazy factories so the xdl objects are only built for the chosen name.
    factories = {
        "const": lambda: xdl.Constant(0.3),
        "glorot": lambda: xdl.VarianceScaling(
            scale=1.0, mode="fan_avg", distribution="normal", seed=3),
        "normal": lambda: xdl.TruncatedNormal(stddev=0.36, seed=3),
    }
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError("unknown initializer name: %r" % (name,))
    return factory()
def run(is_training, files):
    """Train or evaluate the ESMM model and report CTR/CTCVR/CVR AUC.

    Builds per-feature hash embeddings for the user and ad feature lists,
    runs the model, and either trains with Adam (with checkpointing and QPS
    metrics hooks) or evaluates for ``test_batch_num`` batches.  AUC is
    accumulated for ctr, ctcvr, and cvr heads; the cvr AUC is filtered to
    samples where ctr_label == 1.
    """
    data_io = reader("esmm", files, 2, batch_size, 2, user_fn, ad_fn)
    batch = data_io.read()
    # One hash embedding per user-side feature name.
    user_embs = list()
    for fn in user_fn:
        emb = xdl.embedding('u_' + fn, batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size, 1000, 'sum', vtype='hash')
        user_embs.append(emb)
    # One hash embedding per ad-side feature name.
    ad_embs = list()
    for fn in ad_fn:
        emb = xdl.embedding('a_' + fn, batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size, 1000, 'sum', vtype='hash')
        ad_embs.append(emb)
    var_list = model(is_training)(ad_embs, user_embs,
                                  batch["indicators"][0], batch["label"])
    # Named fetch map; model(...) is assumed to return tensors in exactly
    # this order — TODO confirm against the model definition.
    keys = [
        'loss', 'ctr_prop', 'ctcvr_prop', 'cvr_prop', 'ctr_label',
        'ctcvr_label', 'cvr_label'
    ]
    run_vars = dict(zip(keys, list(var_list)))
    hooks = []
    if is_training:
        train_op = xdl.Adam(lr).optimize()
        hooks = get_collection(READER_HOOKS)
        if hooks is None:
            hooks = []
        # Only task 0 writes checkpoints (every 1000 steps).
        if xdl.get_task_index() == 0:
            ckpt_hook = xdl.CheckpointHook(1000)
            hooks.append(ckpt_hook)
        # train_op is fetched alongside the metrics under the None key;
        # its value is never read back out of value_map.
        run_vars.update({None: train_op})
        if is_debug > 1:
            print("=========gradients")
            grads = xdl.get_gradients()
            # NOTE(review): .keys() + .sort() implies a Python 2 list here.
            grads_keys = grads[''].keys()
            grads_keys.sort()
            for key in grads_keys:
                run_vars.update({"grads {}".format(key): grads[''][key]})
    # QPS/step metrics are logged for both train and eval runs.
    hooks.append(QpsMetricsHook())
    log_format = "lstep[%(lstep)s] gstep[%(gstep)s] " \
        "lqps[%(lqps)s] gqps[%(gqps)s]"
    hooks.append(MetricsPrinterHook(log_format, 100))
    ckpt = xdl.get_config("checkpoint", "ckpt")
    if ckpt is not None and len(ckpt) > 0:
        if int(xdl.get_task_index()) == 0:
            # Task 0 restores the checkpoint ...
            from xdl.python.training.saver import Saver
            saver = Saver()
            print("restore from %s" % ckpt)
            saver.restore(ckpt)
        else:
            # ... while the other workers sleep to give the restore time to
            # finish before they start running (no explicit barrier here).
            time.sleep(120)
    sess = xdl.TrainSession(hooks)
    if is_training:
        itr = 1
        ctr_auc = Auc('ctr')
        ctcvr_auc = Auc('ctcvr')
        cvr_auc = Auc('cvr')
        while not sess.should_stop():
            print('iter=', itr)
            # run_vars.keys()/.values() must enumerate in the same order for
            # this zip-back to be valid (true within a single dict snapshot).
            values = sess.run(run_vars.values())
            if not values:
                continue
            value_map = dict(zip(run_vars.keys(), values))
            print('loss=', value_map['loss'])
            ctr_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_auc.add(value_map['ctcvr_prop'],
                          value_map['ctcvr_label'])
            # CVR AUC only over clicked samples (ctr_label == 1).
            cvr_auc.add_with_filter(value_map['cvr_prop'],
                                    value_map['cvr_label'],
                                    np.where(value_map['ctr_label'] == 1))
            itr += 1
        ctr_auc.show()
        ctcvr_auc.show()
        cvr_auc.show()
    else:
        ctr_test_auc = Auc('ctr')
        ctcvr_test_auc = Auc('ctcvr')
        cvr_test_auc = Auc('cvr')
        # Fixed-length evaluation pass (Python 2 xrange).
        for i in xrange(test_batch_num):
            print('iter=', i + 1)
            values = sess.run(run_vars.values())
            value_map = dict(zip(run_vars.keys(), values))
            print('test_loss=', value_map['loss'])
            ctr_test_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_test_auc.add(value_map['ctcvr_prop'],
                               value_map['ctcvr_label'])
            cvr_test_auc.add_with_filter(value_map['cvr_prop'],
                                         value_map['cvr_label'],
                                         np.where(value_map['ctr_label'] == 1))
        ctr_test_auc.show()
        ctcvr_test_auc.show()
        cvr_test_auc.show()