Exemple #1
0
def run():
    """Build the two-tower image model and train it until the session stops.

    Steps, in order:
      1. Create and start one ``xdl.ModelServer`` for the user graph and one
         for the ad graph.
      2. Read a batch and build the ``user0`` / ``user1`` hash embeddings.
      3. Query both model servers with the image ids of the batch and convert
         the image segments with ``to_tf_segment_id``.
      4. Build the loss through ``ams_main(main_model)`` and an Adam optimizer.
      5. Run ``[loss, optimizer]`` in a ``TrainSession`` loop, printing the
         loss whenever the session returns values.

    Relies on names defined elsewhere in the file: ``xdl``, ``np``, ``reader``,
    ``user_graph_train``, ``ad_graph_train``, ``to_tf_segment_id``,
    ``ams_main`` and ``main_model``.
    """

    def start_server(graph_name, graph_fn):
        # Both servers share the same configuration apart from name/graph.
        server = xdl.ModelServer(
            graph_name, graph_fn, xdl.DataType.float,
            xdl.ModelServer.Forward.UniqueCache(xdl.get_task_num()),
            xdl.ModelServer.Backward.UniqueCache(xdl.get_task_num()))
        xdl.current_env().start_model_server(server)
        return server

    def user_embedding(feature):
        # The embedding is named after the batch feature it is built from.
        return xdl.embedding(feature,
                             batch[feature],
                             xdl.TruncatedNormal(stddev=0.001),
                             16,
                             2 * 1024 * 1024,
                             "sum",
                             vtype="hash")

    def segment_ids(feature):
        return xdl.py_func(to_tf_segment_id, [batch[feature].segments],
                           [np.int32])[0]

    # Start order is preserved: user server first, then ad server.
    user_ms = start_server("user_graph", user_graph_train)
    ad_ms = start_server("ad_graph", ad_graph_train)
    batch = reader().read()

    user0 = user_embedding("user0")
    user1 = user_embedding("user1")

    ad0 = batch["ad0"]
    ad1 = batch["ad1"]
    img0 = user_ms(batch["user_img"].ids)
    ids0 = segment_ids("user_img")
    img1 = ad_ms(batch["ad_img"].ids)
    ids1 = segment_ids("ad_img")
    label = batch['label']

    loss = ams_main(main_model)(user0,
                                user1,
                                ad0,
                                ad1,
                                label,
                                ids0,
                                ids1,
                                gear_inputs=[img0, img1])

    optimizer = xdl.Adam(0.0005).optimize()
    run_ops = [loss, optimizer]

    sess = xdl.TrainSession([])
    while not sess.should_stop():
        values = sess.run(run_ops)
        if values is not None:
            # Single-argument call form works on both Python 2 and 3. The two
            # spaces reproduce the output of the former
            # `print 'loss: ', values[0]` statement exactly.
            print('loss:  %s' % (values[0],))
Exemple #2
0
def train():
    """Train the two-embedding model with SGD, logging the loss via a hook.

    Reads one batch from the module-level ``reader``, builds two hash
    embeddings (``emb1`` from ``batch['sparse0']`` and ``emb2`` from
    ``batch['sparse1']``), feeds them with ``batch['deep0']`` and
    ``batch['label']`` into ``model`` to get the loss, and runs the SGD
    train op until the session reports it should stop.

    Relies on names defined elsewhere in the file: ``xdl``, ``reader`` and
    ``model``.
    """
    batch = reader.read()

    emb1 = xdl.embedding('emb1', batch['sparse0'],
                         xdl.TruncatedNormal(stddev=0.001), 8, 1024,
                         vtype='hash')
    emb2 = xdl.embedding('emb2', batch['sparse1'],
                         xdl.TruncatedNormal(stddev=0.001), 8, 1024,
                         vtype='hash')
    loss = model(batch['deep0'], [emb1, emb2], batch['label'])
    train_op = xdl.SGD(0.5).optimize()

    # Only one session is created, after the graph and the hook exist. An
    # earlier hook-less TrainSession built before the graph was never used
    # and was immediately replaced by this one.
    log_hook = xdl.LoggerHook(loss, "loss:{0}", 10)
    sess = xdl.TrainSession(hooks=[log_hook])
    while not sess.should_stop():
        sess.run(train_op)
Exemple #3
0
def get_xdl_initializer(name="glorot"):
    """Return the XDL initializer registered under ``name``.

    Supported names:
      * ``"const"``  -> ``xdl.Constant(0.3)``
      * ``"glorot"`` -> ``xdl.VarianceScaling`` (scale 1.0, ``fan_avg``,
        normal distribution, seed 3)
      * ``"normal"`` -> ``xdl.TruncatedNormal`` (stddev 0.36, seed 3)

    Any other name yields ``None``.
    """
    if name == "const":
        return xdl.Constant(0.3)

    if name == "glorot":
        glorot_kwargs = dict(scale=1.0,
                             mode="fan_avg",
                             distribution="normal",
                             seed=3)
        return xdl.VarianceScaling(**glorot_kwargs)

    if name == "normal":
        return xdl.TruncatedNormal(stddev=0.36, seed=3)

    # Unknown names fall through without raising.
    return None
Exemple #4
0
def _accumulate_auc(value_map, ctr_auc, ctcvr_auc, cvr_auc):
    """Feed one batch of fetched predictions and labels into the AUC trackers."""
    ctr_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
    ctcvr_auc.add(value_map['ctcvr_prop'], value_map['ctcvr_label'])
    # CVR is only scored on the rows whose CTR label equals 1.
    cvr_auc.add_with_filter(value_map['cvr_prop'],
                            value_map['cvr_label'],
                            np.where(value_map['ctr_label'] == 1))


def run(is_training, files):
    """Build the ESMM graph and either train it or evaluate it.

    Args:
        is_training: when truthy, an Adam train op is added to the fetches,
            a checkpoint hook is installed on task 0, and the loop runs until
            the session reports it should stop. Otherwise exactly
            ``test_batch_num`` evaluation steps are attempted.
        files: passed through to ``reader`` as the input file specification.

    In both modes the CTR / CTCVR / CVR AUC trackers are updated after each
    step and shown at the end.

    Relies on names defined elsewhere in the file: ``xdl``, ``np``, ``time``,
    ``reader``, ``model``, ``Auc``, ``get_collection``, ``READER_HOOKS``,
    ``QpsMetricsHook``, ``MetricsPrinterHook``, ``batch_size``, ``user_fn``,
    ``ad_fn``, ``embed_size``, ``lr``, ``is_debug`` and ``test_batch_num``.
    """
    data_io = reader("esmm", files, 2, batch_size, 2, user_fn, ad_fn)
    batch = data_io.read()

    def embed(prefix, fn):
        # One sum-combined hash embedding per feature name.
        return xdl.embedding(prefix + fn,
                             batch[fn],
                             xdl.TruncatedNormal(stddev=0.001),
                             embed_size,
                             1000,
                             'sum',
                             vtype='hash')

    user_embs = [embed('u_', fn) for fn in user_fn]
    ad_embs = [embed('a_', fn) for fn in ad_fn]

    var_list = model(is_training)(ad_embs, user_embs, batch["indicators"][0],
                                  batch["label"])
    keys = [
        'loss', 'ctr_prop', 'ctcvr_prop', 'cvr_prop', 'ctr_label',
        'ctcvr_label', 'cvr_label'
    ]
    run_vars = dict(zip(keys, list(var_list)))

    hooks = []
    if is_training:
        train_op = xdl.Adam(lr).optimize()
        hooks = get_collection(READER_HOOKS)
        if hooks is None:
            hooks = []
        if xdl.get_task_index() == 0:
            ckpt_hook = xdl.CheckpointHook(1000)
            hooks.append(ckpt_hook)

        # The train op is fetched together with the metrics; its result is
        # stored under the key None and never read back.
        run_vars[None] = train_op

    if is_debug > 1:
        print("=========gradients")
        grads = xdl.get_gradients()
        # sorted() replaces the former keys()/list.sort() pair.
        for key in sorted(grads[''].keys()):
            run_vars["grads {}".format(key)] = grads[''][key]

    hooks.append(QpsMetricsHook())
    log_format = "lstep[%(lstep)s] gstep[%(gstep)s] " \
                 "lqps[%(lqps)s] gqps[%(gqps)s]"
    hooks.append(MetricsPrinterHook(log_format, 100))

    ckpt = xdl.get_config("checkpoint", "ckpt")
    if ckpt is not None and len(ckpt) > 0:
        if int(xdl.get_task_index()) == 0:
            from xdl.python.training.saver import Saver
            saver = Saver()
            print("restore from %s" % ckpt)
            saver.restore(ckpt)
        else:
            # NOTE(review): presumably gives task 0 time to finish restoring
            # the checkpoint before the other tasks start -- confirm.
            time.sleep(120)

    sess = xdl.TrainSession(hooks)

    # run_vars is not modified past this point, so the fetch list and its
    # matching key list are built once instead of on every step.
    fetch_keys = list(run_vars.keys())
    fetches = list(run_vars.values())

    if is_training:
        itr = 1
        ctr_auc = Auc('ctr')
        ctcvr_auc = Auc('ctcvr')
        cvr_auc = Auc('cvr')
        while not sess.should_stop():
            print('iter=', itr)
            values = sess.run(fetches)
            if not values:
                continue
            value_map = dict(zip(fetch_keys, values))
            print('loss=', value_map['loss'])
            _accumulate_auc(value_map, ctr_auc, ctcvr_auc, cvr_auc)
            itr += 1
        ctr_auc.show()
        ctcvr_auc.show()
        cvr_auc.show()
    else:
        ctr_test_auc = Auc('ctr')
        ctcvr_test_auc = Auc('ctcvr')
        cvr_test_auc = Auc('cvr')
        for i in xrange(test_batch_num):
            print('iter=', i + 1)
            values = sess.run(fetches)
            if not values:
                # Same guard as the training loop: a step that returns
                # nothing is skipped instead of crashing in zip().
                continue
            value_map = dict(zip(fetch_keys, values))
            print('test_loss=', value_map['loss'])
            _accumulate_auc(value_map, ctr_test_auc, ctcvr_test_auc,
                            cvr_test_auc)
        ctr_test_auc.show()
        ctcvr_test_auc.show()
        cvr_test_auc.show()