def __init__(self, hooks=None):
    """Start the session environment and register hooks.

    Args:
      hooks: optional list (or tuple) of hook objects. The sequence is
        copied, so the caller's list -- and any collection it may have come
        from -- is never modified.

    Hooks found in the READER_HOOKS collection are appended after the given
    hooks, skipping any hook object that is already present. The combined
    list is stored in self._hooks and handed to the underlying
    SimpleSession.
    """
    current_env().sess_start()
    # Copy instead of aliasing: extending the caller's list in place mutated
    # it, and when the caller passed the READER_HOOKS collection itself the
    # extend below appended that list to itself, duplicating every hook.
    self._hooks = [] if hooks is None else list(hooks)
    reader_hooks = get_collection(READER_HOOKS)
    if reader_hooks is not None:
        # Callers often already include the reader hooks; de-duplicate by
        # identity so no hook object is registered twice.
        present = set(id(h) for h in self._hooks)
        for hook in reader_hooks:
            if id(hook) not in present:
                self._hooks.append(hook)
                present.add(id(hook))
    self._cur_scope = cur_model_scope()
    # Hand the combined list to the session; passing the raw `hooks`
    # argument dropped the reader hooks whenever hooks was None.
    self._session = SimpleSession(self._hooks)
    self._finish = False
 def train(self,
           input_fn,
           auc_fn=auc,
           auc_interval=100,
           auc_bucket_num=200,
           max_step=sys.maxint,
           checkpoint_interval=None,
           log_format=LOG_FMT,
           user_hooks=None):
     '''Build the training graph and run training steps.

     Calls input_fn() for (data, labels) and builds the model with
     self._model_fn. It then creates the optimizer's train op and an AUC op
     and runs the train op in a TrainSession. Running stops when the session
     reports it should stop or after max_step steps.

     Args:
       input_fn: callable taking no arguments and returning (data, labels).
       auc_fn: called as auc_fn(logits, labels, num_thresholds=...) to
         build the AUC op.
       auc_interval: interval of the AUC MetricsHook; also used as the
         interval of the MetricsPrinterHook.
       auc_bucket_num: value passed to auc_fn as num_thresholds.
       max_step: maximum number of training steps run by this call.
       checkpoint_interval: if truthy, a CheckpointHook with this interval
         is added on task 0.
       log_format: format string passed to MetricsPrinterHook.
       user_hooks: an extra hook, or a list/tuple of hooks, added after the
         built-in QPS and AUC hooks.

     Raises:
       ArgumentError: if model_fn returns fewer than two outputs.
     '''
     data, labels = input_fn()
     model_outputs = self._model_fn(data, labels)
     if len(model_outputs) < 2:
         raise ArgumentError("model_fn must return loss and logits")
     # model_outputs[0] is the loss; only the logits are needed here.
     logits = model_outputs[1]
     train_op = self._optimizer.optimize()
     auc_op = auc_fn(logits, labels, num_thresholds=auc_bucket_num)

     hooks = [QpsMetricsHook(),
              MetricsHook("auc", auc_op, interval=auc_interval)]
     if user_hooks is not None:
         # Accept either a single hook or a sequence of hooks.
         if isinstance(user_hooks, (list, tuple)):
             hooks.extend(user_hooks)
         else:
             hooks.append(user_hooks)
     reader_hooks = get_collection(READER_HOOKS)
     if reader_hooks is not None:
         hooks.extend(reader_hooks)
     if checkpoint_interval and get_task_index() == 0:
         # Only task 0 writes checkpoints.
         hooks.append(CheckpointHook(checkpoint_interval))
     hooks.append(MetricsPrinterHook(log_format, auc_interval))

     sess = TrainSession(hooks=hooks)
     step = 0
     while not sess.should_stop() and step < max_step:
         sess.run(train_op)
         step += 1
    def train_and_evaluate(self,
                           train_input_fn,
                           eval_input_fn,
                           eval_interval,
                           eval_steps,
                           checkpoint_interval,
                           auc_fn=auc,
                           auc_bucket_num=200,
                           train_hooks=None,
                           eval_hooks=None,
                           auc_interval=100,
                           log_interval=100,
                           log_format=LOG_FMT,
                           eval_log_format=EVAL_LOG_FMT,
                           max_step=sys.maxint):
        '''Alternate rounds of training and evaluation.

        Builds the model twice: once under model_scope('train'), fed by
        train_input_fn, and once under model_scope('test'), fed by
        eval_input_fn. Each copy gets its own TrainSession. Each round then
        runs up to eval_interval training steps followed by up to eval_steps
        evaluation steps. Rounds repeat until the training session reports
        it should stop or max_step training steps have run.

        Args:
          train_input_fn: callable returning (data, labels) for training.
          eval_input_fn: callable returning (data, labels) for evaluation.
          eval_interval: training steps per round; must be positive.
          eval_steps: maximum evaluation steps per round.
          checkpoint_interval: if truthy, a CheckpointHook with this
            interval is added to the training session on task 0.
          auc_fn: called as auc_fn(logits, labels, num_thresholds=...,
            namescope=...) to build the AUC ops.
          auc_bucket_num: value passed to auc_fn as num_thresholds.
          train_hooks: extra hook, or list/tuple of hooks, for training.
          eval_hooks: extra hook, or list/tuple of hooks, for evaluation.
          auc_interval: interval of the training AUC MetricsHook and of the
            training MetricsPrinterHook.
          log_interval: interval of the evaluation MetricsPrinterHook.
          log_format: format string for the training MetricsPrinterHook.
          eval_log_format: format string for the evaluation
            MetricsPrinterHook.
          max_step: maximum number of training steps run by this call.

        Raises:
          ValueError: if eval_interval is not positive.
          ArgumentError: if model_fn returns fewer than two outputs.
        '''
        if eval_interval <= 0:
            # A non-positive interval would never train between evaluations.
            raise ValueError(
                "eval_interval must be positive, got %r" % (eval_interval,))

        def _extend_hooks(hook_list, extra):
            # Merge a user-supplied hook or list/tuple of hooks into hook_list.
            if extra is None:
                return
            if isinstance(extra, (list, tuple)):
                hook_list.extend(extra)
            else:
                hook_list.append(extra)

        with model_scope('train'):
            datas, labels = train_input_fn()
            train_outputs = self._model_fn(datas, labels)
            if len(train_outputs) < 2:
                raise ArgumentError("model_fn must return loss and logits")
            logits = train_outputs[1]
            train_op = self._optimizer.optimize()
            auc_op = auc_fn(logits,
                            labels,
                            num_thresholds=auc_bucket_num,
                            namescope="train_auc")

            # A separate local name keeps the caller's `train_hooks` argument
            # intact. It used to be overwritten with [] and then the list
            # was extended with itself, duplicating the built-in hooks.
            train_hook_list = [QpsMetricsHook(),
                               MetricsHook("auc", auc_op,
                                           interval=auc_interval)]
            _extend_hooks(train_hook_list, train_hooks)
            reader_hooks = get_collection(READER_HOOKS)
            if reader_hooks is not None:
                train_hook_list.extend(reader_hooks)
            if checkpoint_interval and get_task_index() == 0:
                train_hook_list.append(CheckpointHook(checkpoint_interval))
            train_hook_list.append(MetricsPrinterHook(log_format,
                                                      auc_interval))
            train_sess = TrainSession(hooks=train_hook_list)

        with model_scope('test'):
            eval_datas, eval_labels = eval_input_fn()
            eval_outputs = self._model_fn(eval_datas, eval_labels)
            if len(eval_outputs) < 2:
                raise ArgumentError("model_fn must return loss and logits")
            eval_logits = eval_outputs[1]
            eval_auc_op = auc_fn(eval_logits,
                                 eval_labels,
                                 num_thresholds=auc_bucket_num,
                                 namescope="eval_auc")
            # Same fix as above for the caller's `eval_hooks` argument.
            eval_hook_list = [QpsMetricsHook(),
                              MetricsHook("auc", eval_auc_op, interval=1)]
            _extend_hooks(eval_hook_list, eval_hooks)
            eval_hook_list.append(MetricsPrinterHook(eval_log_format,
                                                     log_interval))
            eval_sess = TrainSession(hooks=eval_hook_list)

        lstep = 0
        while True:
            print('\n>>> start train at local step[%d]\n' % lstep)
            # Count steps per round explicitly. The old modulo test needed a
            # phantom lstep increment to escape, which cost one real training
            # step per round and skewed the max_step accounting.
            round_steps = 0
            while (not train_sess.should_stop()
                   and round_steps < eval_interval
                   and lstep < max_step):
                train_sess.run(train_op)
                lstep += 1
                round_steps += 1
            print('\n>>> start evaluate at local step[%d]\n' % lstep)
            eval_step = 0
            while not eval_sess.should_stop() and eval_step < eval_steps:
                eval_sess.run([])
                eval_step += 1
            if train_sess.should_stop() or lstep >= max_step:
                break
Beispiel #4
0
def get_sparse_grads(name):
    """Return the entry stored under `name` in the sparse-gradient registry.

    The registry is the first element of the 'sparse_grad' collection
    (presumably a mapping from variable name to its sparse gradient --
    confirm against where the collection is populated).
    """
    sparse_grad_maps = get_collection('sparse_grad')
    grads_by_name = sparse_grad_maps[0]
    return grads_by_name[name]
Beispiel #5
0
def run(is_training, files):
    """Build the ESMM model over `files`, then train it or evaluate it.

    Reads batches with the "esmm" reader and builds one hashed, sum-combined
    embedding per user feature ('u_' prefix) and per ad feature ('a_'
    prefix). It then builds the model and runs it in a TrainSession.
    - Training: an Adam train op is added, and batches run until the session
      stops.
    - Evaluation: up to test_batch_num batches are run.
    CTR, CTCVR and CVR AUCs are accumulated and shown at the end. The CVR AUC
    only counts samples whose ctr_label is 1.

    Relies on module-level settings defined elsewhere in this script:
    batch_size, user_fn, ad_fn, embed_size, lr, is_debug, test_batch_num,
    model, reader, Auc.

    Args:
      is_training: True to train, False to evaluate.
      files: input files handed to the reader.
    """
    data_io = reader("esmm", files, 2, batch_size, 2, user_fn, ad_fn)
    batch = data_io.read()

    def _embed(prefix, feature_names):
        # Build one embedding per feature, all with identical settings.
        return [xdl.embedding(prefix + fn,
                              batch[fn],
                              xdl.TruncatedNormal(stddev=0.001),
                              embed_size,
                              1000,
                              'sum',
                              vtype='hash')
                for fn in feature_names]

    user_embs = _embed('u_', user_fn)
    ad_embs = _embed('a_', ad_fn)

    var_list = model(is_training)(ad_embs, user_embs, batch["indicators"][0],
                                  batch["label"])
    keys = [
        'loss', 'ctr_prop', 'ctcvr_prop', 'cvr_prop', 'ctr_label',
        'ctcvr_label', 'cvr_label'
    ]
    run_vars = dict(zip(keys, list(var_list)))

    hooks = []
    if is_training:
        train_op = xdl.Adam(lr).optimize()
        # Copy the collection. The returned list may be the collection's own
        # storage, and appending to it would leak these hooks into
        # READER_HOOKS (and into any later TrainSession).
        hooks = list(get_collection(READER_HOOKS) or [])
        if xdl.get_task_index() == 0:
            hooks.append(xdl.CheckpointHook(1000))
        # Fetching the train op under the None key makes every sess.run
        # also perform an optimizer step.
        run_vars[None] = train_op

    if is_debug > 1:
        print("=========gradients")
        grads = xdl.get_gradients()
        # sorted() works on Python 2 and 3; keys().sort() is Python-2-only.
        for key in sorted(grads[''].keys()):
            run_vars["grads {}".format(key)] = grads[''][key]

    hooks.append(QpsMetricsHook())
    log_format = "lstep[%(lstep)s] gstep[%(gstep)s] " \
                 "lqps[%(lqps)s] gqps[%(gqps)s]"
    hooks.append(MetricsPrinterHook(log_format, 100))

    ckpt = xdl.get_config("checkpoint", "ckpt")
    if ckpt is not None and len(ckpt) > 0:
        if int(xdl.get_task_index()) == 0:
            from xdl.python.training.saver import Saver
            saver = Saver()
            print("restore from %s" % ckpt)
            saver.restore(ckpt)
        else:
            # Non-chief tasks wait a fixed 120 s -- presumably for task 0 to
            # finish restoring; TODO confirm whether a real barrier exists.
            time.sleep(120)

    sess = xdl.TrainSession(hooks)

    # Fix the fetch order once so names and fetched values stay paired.
    names = list(run_vars.keys())
    fetches = [run_vars[name] for name in names]

    ctr_auc = Auc('ctr')
    ctcvr_auc = Auc('ctcvr')
    cvr_auc = Auc('cvr')

    def _accumulate(value_map):
        # Add one batch of predictions/labels to the three AUC accumulators.
        ctr_auc.add(value_map['ctr_prop'], value_map['ctr_label'])
        ctcvr_auc.add(value_map['ctcvr_prop'], value_map['ctcvr_label'])
        # CVR AUC only counts samples with ctr_label == 1.
        cvr_auc.add_with_filter(value_map['cvr_prop'],
                                value_map['cvr_label'],
                                np.where(value_map['ctr_label'] == 1))

    if is_training:
        itr = 1
        while not sess.should_stop():
            print('iter=', itr)
            values = sess.run(fetches)
            if not values:
                # Nothing fetched (presumably the session just finished);
                # should_stop() decides whether to continue.
                continue
            value_map = dict(zip(names, values))
            print('loss=', value_map['loss'])
            _accumulate(value_map)
            itr += 1
    else:
        for i in xrange(test_batch_num):
            print('iter=', i + 1)
            values = sess.run(fetches)
            if not values:
                # Input ran out before test_batch_num batches; stop instead of
                # failing on zip(names, None).
                break
            value_map = dict(zip(names, values))
            print('test_loss=', value_map['loss'])
            _accumulate(value_map)

    ctr_auc.show()
    ctcvr_auc.show()
    cvr_auc.show()