コード例 #1
0
def train():
    """Build the 'din_mogujie' training graph and hand it to ``model.run``.

    Reads the module-level ``model_type``; any value other than
    ``'din_mogujie'`` raises ``Exception``. Inside the ``'train'`` model
    scope the network ops are built, an Adam optimize op (lr=0.001) is
    appended, and a ``TrainSession`` is created with QPS/printer hooks.
    Only the task with index 0 also gets a checkpoint hook using the
    module-level ``save_interval``.
    """
    if model_type != 'din_mogujie':
        raise Exception('only support din_mogujie and dien')
    model = Model_DIN_MOGUJIE(EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE,
                              False, train_file, batch_size)

    with xdl.model_scope('train'):
        train_ops = model.build_network()
        learning_rate = 0.001
        # The optimize op is run together with the network ops.
        train_ops.append(xdl.Adam(learning_rate).optimize())

        log_format = "[%(time)s] lstep[%(lstep)s] gstep[%(gstep)s] lqps[%(lqps)s] gqps[%(gqps)s] loss[%(loss)s]"
        session_hooks = [QpsMetricsHook(), MetricsPrinterHook(log_format)]
        if xdl.get_task_index() == 0:
            session_hooks.append(xdl.CheckpointHook(save_interval))
        train_sess = xdl.TrainSession(hooks=session_hooks)

    # No 'test' model scope is built in this script; only training runs.
    model.run(train_ops, train_sess)
コード例 #2
0
def train(train_file=train_file,
          test_file=test_file,
          uid_voc=uid_voc,
          mid_voc=mid_voc,
          cat_voc=cat_voc,
          item_info=item_info,
          reviews_info=reviews_info,
          batch_size=128,
          maxlen=100,
          test_iter=700):
    """Build train and test graphs for ``Model_DIEN`` and run them.

    Args:
        train_file, test_file, uid_voc, mid_voc, cat_voc, item_info,
        reviews_info, batch_size, maxlen: forwarded positionally, in this
            order, to ``SampleIO``.
        test_iter: forwarded to ``model.run`` as ``test_iter``.

    The train session gets QPS and metrics-printer hooks; the task with
    index 0 additionally gets a checkpoint hook whose interval comes from
    the ``('checkpoint', 'save_interval')`` config entry. The test session
    is created without hooks.
    """
    model = Model_DIEN(
        EMBEDDING_DIM,
        HIDDEN_SIZE,
        ATTENTION_SIZE,
        LIGHT_EMBEDDING_DIM,
        LIGHT_HIDDEN_SIZE,
        LIGHT_ATTENTION_SIZE,
        use_rocket_training=use_rocket_training())
    sample_io = SampleIO(
        train_file, test_file, uid_voc, mid_voc, cat_voc,
        item_info, reviews_info, batch_size, maxlen,
        embedding_dim=EMBEDDING_DIM,
        light_embedding_dim=LIGHT_EMBEDDING_DIM)

    with xdl.model_scope('train'):
        train_ops = model.build_final_net(
            EMBEDDING_DIM, LIGHT_EMBEDDING_DIM, sample_io)
        learning_rate = 0.001
        # The optimize op is run together with the network ops.
        train_ops.append(xdl.Adam(learning_rate).optimize())

        log_format = "[%(time)s] lstep[%(lstep)s] gstep[%(gstep)s] lqps[%(lqps)s] gqps[%(gqps)s] loss[%(loss)s]"
        session_hooks = [QpsMetricsHook(), MetricsPrinterHook(log_format)]
        if xdl.get_task_index() == 0:
            save_every = xdl.get_config('checkpoint', 'save_interval')
            session_hooks.append(xdl.CheckpointHook(save_every))
        train_sess = xdl.TrainSession(hooks=session_hooks)

    with xdl.model_scope('test'):
        test_ops = model.build_final_net(
            EMBEDDING_DIM, LIGHT_EMBEDDING_DIM, sample_io, is_train=False)
        test_sess = xdl.TrainSession()

    model.run(train_ops, train_sess, test_ops, test_sess, test_iter=test_iter)
コード例 #3
0
 def train(self,
           input_fn,
           auc_fn=auc,
           auc_interval=100,
           auc_bucket_num=200,
           max_step=sys.maxint,
           checkpoint_interval=None,
           log_format=LOG_FMT,
           user_hooks=None):
     '''Build the training graph and run the optimize op in a loop.

     Args:
       input_fn: zero-argument callable returning ``(data, labels)``.
       auc_fn: callable invoked as
         ``auc_fn(logits, labels, num_thresholds=auc_bucket_num)``.
       auc_interval: ``interval`` of the "auc" MetricsHook; also passed
         as the second argument of MetricsPrinterHook.
       auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds``.
       max_step: upper bound on the number of ``sess.run`` calls.
       checkpoint_interval: when truthy, the task with index 0 adds a
         CheckpointHook built from this value.
       log_format: first argument of MetricsPrinterHook.
       user_hooks: a single hook or a list of hooks to add, or None.

     Raises:
       ArgumentError: if the model function returns fewer than two
         outputs (loss and logits are expected).
     '''
     features, targets = input_fn()
     outputs = self._model_fn(features, targets)
     if len(outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")
     logits = outputs[1]
     train_op = self._optimizer.optimize()
     auc_op = auc_fn(logits, targets, num_thresholds=auc_bucket_num)

     session_hooks = [
         QpsMetricsHook(),
         MetricsHook("auc", auc_op, interval=auc_interval),
     ]
     if user_hooks is not None:
         # Accept either one hook or a list of hooks.
         extra = user_hooks if isinstance(user_hooks, list) else [user_hooks]
         session_hooks.extend(extra)
     collected = get_collection(READER_HOOKS)
     if collected is not None:
         session_hooks.extend(collected)
     if checkpoint_interval and get_task_index() == 0:
         session_hooks.append(CheckpointHook(checkpoint_interval))
     session_hooks.append(MetricsPrinterHook(log_format, auc_interval))

     sess = TrainSession(hooks=session_hooks)
     steps_run = 0
     while not (sess.should_stop() or steps_run >= max_step):
         sess.run(train_op)
         steps_run += 1
コード例 #4
0
 def evaluate(self,
              input_fn,
              checkpoint_version="",
              log_format=EVAL_LOG_FMT,
              log_interval=100,
              max_step=sys.maxint,
              auc_fn=auc,
              auc_bucket_num=200,
              user_hooks=None):
     '''Restore a checkpoint and run the evaluation loop.

     Args:
       input_fn: zero-argument callable returning ``(data, labels)``.
       checkpoint_version: passed to ``Saver.restore`` on the task with
         index 0.
       log_format: first argument of MetricsPrinterHook.
       log_interval: second argument of MetricsPrinterHook.
       max_step: upper bound on the number of ``sess.run`` calls.
       auc_fn: callable invoked as
         ``auc_fn(logits, labels, num_thresholds=auc_bucket_num)``.
       auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds`` and
         to ``reset_auc_variables_op``.
       user_hooks: a single hook or a list of hooks to add, or None.

     Raises:
       ArgumentError: if the model function returns fewer than two
         outputs (loss and logits are expected).
     '''
     from xdl.python.training.saver import Saver
     if get_task_index() == 0:
         Saver().restore(checkpoint_version)

     features, targets = input_fn()
     outputs = self._model_fn(features, targets)
     if len(outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")
     auc_op = auc_fn(outputs[1], targets, num_thresholds=auc_bucket_num)

     session_hooks = [
         QpsMetricsHook(),
         MetricsHook("auc", auc_op, interval=1),
     ]
     if user_hooks is not None:
         # Accept either one hook or a list of hooks.
         extra = user_hooks if isinstance(user_hooks, list) else [user_hooks]
         session_hooks.extend(extra)
     session_hooks.append(MetricsPrinterHook(log_format, log_interval))

     sess = TrainSession(hooks=session_hooks)
     if auc_fn is auc:
         # Only the default auc implementation gets its variables reset.
         sess.run(reset_auc_variables_op(auc_bucket_num))

     steps_run = 0
     while not (sess.should_stop() or steps_run >= max_step):
         # No explicit ops are requested; the loop just steps the session.
         sess.run([])
         steps_run += 1
コード例 #5
0
def train(train_file=train_file,
          test_file=test_file,
          uid_voc=uid_voc,
          mid_voc=mid_voc,
          cat_voc=cat_voc,
          item_info=item_info,
          reviews_info=reviews_info,
          batch_size=128,
          maxlen=100,
          test_iter=700):
    """Build train and test graphs for a DIN or DIEN model and run them.

    The model class is chosen from the ``'model'`` config entry
    (``'din'`` or ``'dien'``); any other value raises ``Exception``.

    Args:
        train_file, test_file, uid_voc, mid_voc, cat_voc, item_info,
        reviews_info, batch_size, maxlen: forwarded positionally, in this
            order, to ``SampleIO`` (followed by ``EMBEDDING_DIM``).
        test_iter: forwarded to ``model.run`` as ``test_iter``.
    """
    if xdl.get_config('model') == 'din':
        model_cls = Model_DIN
    elif xdl.get_config('model') == 'dien':
        model_cls = Model_DIEN
    else:
        raise Exception('only support din and dien')
    model = model_cls(EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)

    sample_io = SampleIO(
        train_file, test_file, uid_voc, mid_voc, cat_voc,
        item_info, reviews_info, batch_size, maxlen, EMBEDDING_DIM)

    with xdl.model_scope('train'):
        train_ops = model.build_final_net(EMBEDDING_DIM, sample_io)
        learning_rate = 0.001
        # The optimize op is run together with the network ops.
        train_ops.append(xdl.Adam(learning_rate).optimize())

        log_format = "[%(time)s] lstep[%(lstep)s] gstep[%(gstep)s] lqps[%(lqps)s] gqps[%(gqps)s] loss[%(loss)s]"
        session_hooks = [QpsMetricsHook(), MetricsPrinterHook(log_format)]
        if xdl.get_task_index() == 0:
            # Only task 0 writes checkpoints.
            save_every = xdl.get_config('checkpoint', 'save_interval')
            session_hooks.append(xdl.CheckpointHook(save_every))
        train_sess = xdl.TrainSession(hooks=session_hooks)

    with xdl.model_scope('test'):
        test_ops = model.build_final_net(
            EMBEDDING_DIM, sample_io, is_train=False)
        test_sess = xdl.TrainSession()

    banner_edge = '=' * 10
    print(banner_edge + 'start train' + banner_edge)
    model.run(train_ops, train_sess, test_ops, test_sess, test_iter=test_iter)
コード例 #6
0
 def predict(self,
             input_fn,
             checkpoint_version="",
             log_format=PREDICT_LOG_FMT,
             log_interval=100,
             max_step=sys.maxint,
             user_hooks=None):
     '''Restore a checkpoint and run the prediction loop.

     The model's logits are registered under the metric name
     "prediction" and reported through MetricsPrinterHook.

     Args:
       input_fn: zero-argument callable returning ``(data, labels)``.
       checkpoint_version: passed to ``Saver.restore`` on the task with
         index 0.
       log_format: first argument of MetricsPrinterHook.
       log_interval: second argument of MetricsPrinterHook.
       max_step: upper bound on the number of ``sess.run`` calls.
       user_hooks: a single hook or a list of hooks to add, or None.

     Raises:
       ArgumentError: if the model function returns fewer than two
         outputs (loss and logits are expected).
     '''
     from xdl.python.training.saver import Saver
     if get_task_index() == 0:
         Saver().restore(checkpoint_version)

     features, targets = input_fn()
     outputs = self._model_fn(features, targets)
     if len(outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")

     session_hooks = [
         QpsMetricsHook(),
         MetricsHook("prediction", outputs[1], interval=1),
     ]
     if user_hooks is not None:
         # Accept either one hook or a list of hooks.
         extra = user_hooks if isinstance(user_hooks, list) else [user_hooks]
         session_hooks.extend(extra)
     session_hooks.append(MetricsPrinterHook(log_format, log_interval))

     sess = TrainSession(hooks=session_hooks)
     steps_run = 0
     while not (sess.should_stop() or steps_run >= max_step):
         # No explicit ops are requested; the loop just steps the session.
         sess.run([])
         steps_run += 1
コード例 #7
0
    def train_and_evaluate(self,
                           train_input_fn,
                           eval_input_fn,
                           eval_interval,
                           eval_steps,
                           checkpoint_interval,
                           auc_fn=auc,
                           auc_bucket_num=200,
                           train_hooks=None,
                           eval_hooks=None,
                           auc_interval=100,
                           log_interval=100,
                           log_format=LOG_FMT,
                           eval_log_format=EVAL_LOG_FMT,
                           max_step=sys.maxint):
        '''Alternate between training phases and evaluation rounds.

        A training graph is built in the 'train' model scope and an
        evaluation graph in the 'test' model scope, each with its own
        TrainSession. Training runs until the local step counter reaches a
        multiple of ``eval_interval`` (or the session stops, or
        ``max_step`` is reached); then up to ``eval_steps`` evaluation
        runs are executed, and the cycle repeats.

        Args:
          train_input_fn: zero-argument callable returning
            ``(data, labels)`` for training.
          eval_input_fn: zero-argument callable returning
            ``(data, labels)`` for evaluation.
          eval_interval: an evaluation round starts whenever the local
            step counter is a non-zero multiple of this value.
          eval_steps: maximum number of evaluation runs per round.
          checkpoint_interval: when truthy, the task with index 0 adds a
            CheckpointHook built from this value to the train session.
          auc_fn: callable invoked as ``auc_fn(logits, labels,
            num_thresholds=auc_bucket_num, namescope=...)``.
          auc_bucket_num: forwarded to ``auc_fn`` as ``num_thresholds``.
          train_hooks: a single hook or a list of hooks to add to the
            train session, or None.
          eval_hooks: a single hook or a list of hooks to add to the
            evaluation session, or None.
          auc_interval: ``interval`` of the training "auc" MetricsHook;
            also the second argument of the training MetricsPrinterHook.
          log_interval: second argument of the evaluation
            MetricsPrinterHook.
          log_format: format for the training MetricsPrinterHook.
          eval_log_format: format for the evaluation MetricsPrinterHook.
          max_step: upper bound on the local step counter.

        Raises:
          ArgumentError: if the model function returns fewer than two
            outputs (loss and logits are expected).
        '''
        with model_scope('train'):
            datas, labels = train_input_fn()
            train_outputs = self._model_fn(datas, labels)
            if len(train_outputs) < 2:
                raise ArgumentError("model_fn must return loss and logits")
            logits = train_outputs[1]
            train_op = self._optimizer.optimize()
            auc_op = auc_fn(logits,
                            labels,
                            num_thresholds=auc_bucket_num,
                            namescope="train_auc")

            # Build the session hooks in a separate list. Rebinding the
            # `train_hooks` parameter here would discard the caller's
            # hooks and make the later extend() duplicate the built-ins.
            train_session_hooks = [
                QpsMetricsHook(),
                MetricsHook("auc", auc_op, interval=auc_interval),
            ]
            if train_hooks is not None:
                if isinstance(train_hooks, list):
                    train_session_hooks.extend(train_hooks)
                else:
                    train_session_hooks.append(train_hooks)
            reader_hooks = get_collection(READER_HOOKS)
            if reader_hooks is not None:
                train_session_hooks.extend(reader_hooks)
            if checkpoint_interval and get_task_index() == 0:
                train_session_hooks.append(CheckpointHook(checkpoint_interval))
            train_session_hooks.append(
                MetricsPrinterHook(log_format, auc_interval))
            train_sess = TrainSession(hooks=train_session_hooks)

        with model_scope('test'):
            eval_datas, eval_labels = eval_input_fn()
            eval_outputs = self._model_fn(eval_datas, eval_labels)
            if len(eval_outputs) < 2:
                raise ArgumentError("model_fn must return loss and logits")
            eval_logits = eval_outputs[1]
            eval_auc_op = auc_fn(eval_logits,
                                 eval_labels,
                                 num_thresholds=auc_bucket_num,
                                 namescope="eval_auc")

            # Same reasoning as for the train hooks: never overwrite the
            # `eval_hooks` parameter.
            eval_session_hooks = [
                QpsMetricsHook(),
                MetricsHook("auc", eval_auc_op, interval=1),
            ]
            if eval_hooks is not None:
                if isinstance(eval_hooks, list):
                    eval_session_hooks.extend(eval_hooks)
                else:
                    eval_session_hooks.append(eval_hooks)
            eval_session_hooks.append(
                MetricsPrinterHook(eval_log_format, log_interval))
            eval_sess = TrainSession(hooks=eval_session_hooks)

        lstep = 0
        while True:
            print('\n>>> start train at local step[%d]\n' % lstep)
            while not train_sess.should_stop() and (lstep == 0 or lstep % eval_interval != 0) \
                  and lstep < max_step:
                train_sess.run(train_op)
                lstep = lstep + 1
            # NOTE(review): this extra increment moves the counter off the
            # multiple of eval_interval so the next training phase can
            # start. It also means the counter runs one ahead of the real
            # number of train steps after every phase. Kept as-is to
            # preserve the existing evaluation cadence.
            lstep = lstep + 1
            eval_step = 0
            print('\n>>> start evaluate at local step[%d]\n' % lstep)
            while not eval_sess.should_stop() and eval_step < eval_steps:
                eval_sess.run([])
                eval_step = eval_step + 1
            if train_sess.should_stop() or lstep >= max_step:
                break
コード例 #8
0
def run(is_training, files):
    """Build the "esmm" graph over `files` and run training or testing.

    When `is_training` is true, an Adam optimize op is added to the
    fetched variables and the session is run until it reports it should
    stop. Otherwise exactly `test_batch_num` batches are evaluated. In
    both modes CTR, CTCVR and CVR AUC accumulators are updated per batch
    and shown at the end.

    Relies on module-level names not visible here (`reader`, `model`,
    `user_fn`, `ad_fn`, `batch_size`, `embed_size`, `lr`, `is_debug`,
    `test_batch_num`, `Auc`, ...).

    NOTE(review): this function depends on Python 2 semantics
    (`dict.keys()` returning a sortable list, `xrange`).
    """

    data_io = reader("esmm", files, 2, batch_size, 2, user_fn, ad_fn)
    batch = data_io.read()

    # One embedding per user feature name, named 'u_<feature>'.
    user_embs = list()
    for fn in user_fn:
        emb = xdl.embedding('u_' + fn,
                            batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size,
                            1000,
                            'sum',
                            vtype='hash')
        user_embs.append(emb)

    # One embedding per ad feature name, named 'a_<feature>'.
    ad_embs = list()
    for fn in ad_fn:
        emb = xdl.embedding('a_' + fn,
                            batch[fn],
                            xdl.TruncatedNormal(stddev=0.001),
                            embed_size,
                            1000,
                            'sum',
                            vtype='hash')
        ad_embs.append(emb)

    var_list = model(is_training)(ad_embs, user_embs, batch["indicators"][0],
                                  batch["label"])
    # The model outputs are paired positionally with these names; this
    # presumes the model returns them in exactly this order -- TODO confirm.
    keys = [
        'loss', 'ctr_prop', 'ctcvr_prop', 'cvr_prop', 'ctr_label',
        'ctcvr_label', 'cvr_label'
    ]
    run_vars = dict(zip(keys, list(var_list)))

    hooks = []
    if is_training:
        train_op = xdl.Adam(lr).optimize()
        hooks = get_collection(READER_HOOKS)
        if hooks is None:
            hooks = []
        # NOTE(review): `hooks` may be the collection's own list, in which
        # case the appends below also mutate that collection -- confirm
        # against get_collection.
        if xdl.get_task_index() == 0:
            ckpt_hook = xdl.CheckpointHook(1000)
            hooks.append(ckpt_hook)

        # The train op is stored under the key None so it is fetched in
        # the same sess.run call as the metrics; its value is never read.
        run_vars.update({None: train_op})

    if is_debug > 1:
        print("=========gradients")
        grads = xdl.get_gradients()
        grads_keys = grads[''].keys()
        # In-place sort needs keys() to return a list (Python 2).
        grads_keys.sort()
        for key in grads_keys:
            run_vars.update({"grads {}".format(key): grads[''][key]})

    hooks.append(QpsMetricsHook())
    log_format = "lstep[%(lstep)s] gstep[%(gstep)s] " \
                 "lqps[%(lqps)s] gqps[%(gqps)s]"
    hooks.append(MetricsPrinterHook(log_format, 100))

    # Optional restore: only task 0 restores from the configured
    # checkpoint; every other task sleeps for 120 seconds instead
    # (presumably to let task 0 finish restoring -- TODO confirm).
    ckpt = xdl.get_config("checkpoint", "ckpt")
    if ckpt is not None and len(ckpt) > 0:
        if int(xdl.get_task_index()) == 0:
            from xdl.python.training.saver import Saver
            saver = Saver()
            print("restore from %s" % ckpt)
            saver.restore(ckpt)
        else:
            time.sleep(120)

    sess = xdl.TrainSession(hooks)

    if is_training:
        itr = 1
        ctr_auc = Auc('ctr')
        ctcvr_auc = Auc('ctcvr')
        cvr_auc = Auc('cvr')
        while not sess.should_stop():
            print('iter=', itr)
            # values() and keys() of the same unmodified dict iterate in
            # matching order, so zipping them below re-pairs each result
            # with its name.
            values = sess.run(run_vars.values())
            if not values:
                # Nothing returned for this step; `itr` is not advanced.
                continue
            value_map = dict(zip(run_vars.keys(), values))
            print('loss=', value_map['loss'])
            ctr_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_auc.add(value_map['ctcvr_prop'], value_map['ctcvr_label'])
            # CVR AUC is accumulated only at positions where ctr_label == 1.
            cvr_auc.add_with_filter(value_map['cvr_prop'],
                                    value_map['cvr_label'],
                                    np.where(value_map['ctr_label'] == 1))
            itr += 1
        ctr_auc.show()
        ctcvr_auc.show()
        cvr_auc.show()
    else:
        ctr_test_auc = Auc('ctr')
        ctcvr_test_auc = Auc('ctcvr')
        cvr_test_auc = Auc('cvr')
        # Testing runs a fixed number of batches rather than polling
        # sess.should_stop().
        for i in xrange(test_batch_num):
            print('iter=', i + 1)
            values = sess.run(run_vars.values())
            value_map = dict(zip(run_vars.keys(), values))
            print('test_loss=', value_map['loss'])
            ctr_test_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
            ctcvr_test_auc.add(value_map['ctcvr_prop'],
                               value_map['ctcvr_label'])
            # CVR AUC is accumulated only at positions where ctr_label == 1.
            cvr_test_auc.add_with_filter(value_map['cvr_prop'],
                                         value_map['cvr_label'],
                                         np.where(value_map['ctr_label'] == 1))
        ctr_test_auc.show()
        ctcvr_test_auc.show()
        cvr_test_auc.show()