def pos_cut(text):
    """Segment `text` into (word, tag) pairs with the backend picked by env flags.

    Backend selection, in priority order:
      * STANFORD_NLP -> stanford_nlp.pos_tag on the demojized text
      * BSEG         -> bseg.Cut on the demojized, GBK-encoded text
      * otherwise    -> jieba.posseg.cut on the raw text

    The first two backends pass their output through hack_emoji2; all three
    finish with merge_expression2. Finally, any word that is exactly one of the
    control characters 0x01, 0x02 or 0x03 gets its tag replaced by 'sep'.

    Returns the (mutated in place) result of merge_expression2.
    """
    separators = ('\x01', '\x02', '\x03')

    if gezi.env_has('STANFORD_NLP'):
        import emoji
        init_stanford_nlp()
        tagged = stanford_nlp.pos_tag(emoji.demojize(text))
        res = merge_expression2(hack_emoji2(tagged))
    elif gezi.env_has('BSEG'):
        import emoji
        init_bseg(use_pos=True)
        # text is decoded from utf8 before demojize, then handed to bseg as GBK.
        nodes = bseg.Cut(to_gbk(emoji.demojize(text.decode('utf8'))))
        tagged = [(to_utf8(node.word), pos_tags[node.type]) for node in nodes]
        res = merge_expression2(hack_emoji2(tagged))
    else:
        res = merge_expression2(list(jieba.posseg.cut(text)))

    # Re-tag separator control characters regardless of what the backend said.
    for idx, (word, _tag) in enumerate(res):
        if word in separators:
            res[idx] = (word, 'sep')
    return res
Exemple #2
0
def init(vocab_path_=None, append=None):
    """Lazily build the module-level `vocab` and `vocab_size`.

    Does nothing when `vocab` is already set. Otherwise:
      * with FLAGS.vocab_buckets set, builds a bucket-based Vocabulary;
      * else loads a file-based Vocabulary from the first truthy value of
        `vocab_path_`, FLAGS.vocab, or `<dirname(FLAGS.model_dir)>/vocab.txt`,
        and writes the chosen path back to FLAGS.vocab and `vocab_path`.

    Args:
      vocab_path_: optional explicit vocabulary file path.
      append: passed to Vocabulary as `append`. When None it falls back to
        FLAGS.vocab_append, and is forced to True if the VOCAB_APPEND env flag
        is present.

    `vocab_size` is capped by FLAGS.vocab_size when that flag is truthy. An
    assertion fails if the resulting size does not exceed
    FLAGS.num_reserved_ids.
    """
    global vocab, vocab_size, vocab_path
    if vocab is not None:
        return

    if FLAGS.vocab_buckets:
        vocab = Vocabulary(buckets=FLAGS.vocab_buckets)
    else:
        default_path = vocab_path_ or FLAGS.vocab
        vocab_path = default_path or gezi.dirname(FLAGS.model_dir) + '/vocab.txt'
        FLAGS.vocab = vocab_path
        logging.info('vocab:{}'.format(vocab_path))
        logging.info('NUM_RESERVED_IDS:{}'.format(FLAGS.num_reserved_ids))
        if append is None:
            append = FLAGS.vocab_append
            if gezi.env_has('VOCAB_APPEND'):
                append = True
        vocab = Vocabulary(vocab_path,
                           FLAGS.num_reserved_ids,
                           append=append,
                           max_words=FLAGS.vocab_max_words,
                           min_count=FLAGS.vocab_min_count)

    full_size = vocab.size()
    if FLAGS.vocab_size:
        vocab_size = min(full_size, FLAGS.vocab_size)
    else:
        vocab_size = full_size
    logging.info('vocab_size:{}'.format(vocab_size))
    assert vocab_size > FLAGS.num_reserved_ids, 'empty vocab, wrong vocab path? %s' % FLAGS.vocab

    # Log the special tokens so a wrong vocab file is easy to spot.
    logging.info('vocab_start:{} id:{}'.format(vocab.key(vocab.start_id()),
                                               vocab.start_id()))
    logging.info('vocab_end:{} id:{}'.format(vocab.key(vocab.end_id()),
                                             vocab.end_id()))
    logging.info('vocab_unk:{} id:{}'.format(vocab.key(vocab.unk_id()),
                                             vocab.unk_id()))
def word_cut(text):
    """Split `text` into a list of words using the backend picked by env flags.

    Backends are checked in this order; the first enabled one wins:
      1. SENTENCE_PIECE - sp.EncodeAsPieces; a leading '▁' marker is removed
         from the first piece (or the piece dropped if it is only the marker).
         The pieces are returned without merge_expression.
      2. STANFORD_NLP   - stanford_nlp.word_tokenize on demojized text.
      3. BSEG           - bseg.Segment on demojized, GBK-encoded text.
      4. JIEBA_POS      - jieba.posseg.cut, keeping only the words.
      5. default        - jieba.cut.

    Backends 2 and 3 fall back to jieba.cut on any exception, after printing
    the offending text to stderr. Backends 2-5 return merge_expression(...).
    """
    if gezi.env_has('SENTENCE_PIECE'):
        init_sp()
        pieces = sp.EncodeAsPieces(text)
        if pieces:
            head = pieces[0]
            if head == '▁':
                pieces = pieces[1:]
            elif head.startswith('▁'):
                pieces[0] = head[1:]
        return pieces

    if gezi.env_has('STANFORD_NLP'):
        import emoji
        init_stanford_nlp()
        try:
            tokens = hack_emoji(
                stanford_nlp.word_tokenize(emoji.demojize(text)))
            return merge_expression(tokens)
        except Exception:
            print('stnaord error text: %s' % text, file=sys.stderr)
            return merge_expression(list(jieba.cut(text)))

    if gezi.env_has('BSEG'):
        import emoji
        init_bseg()
        try:
            # NOTICE py2 need decode utf8 to get unicode as input to emoji
            raw = bseg.Segment(to_gbk(emoji.demojize(text.decode('utf8'))))
            tokens = hack_emoji([to_utf8(piece) for piece in raw])
            return merge_expression(tokens)
        except Exception:
            print('bseg error text: %s' % text, file=sys.stderr)
            return merge_expression(list(jieba.cut(text)))

    # TODO make a switch since jieba.posseg is much slower... then jieba.cut
    if gezi.env_has('JIEBA_POS'):
        tokens = [word for word, tag in jieba.posseg.cut(text)]
        return merge_expression(tokens)

    return merge_expression(list(jieba.cut(text)))
 def bseg_(text):
     """Run bseg NER on `text` and return a list of (utf8 word, entity name).

     The text is utf8-decoded, demojized and converted to GBK before cutting.
     With the BSEG_SUBNER env flag set, sub-NER nodes are used instead of the
     regular NER nodes.
     """
     bseg.Cut(to_gbk(emoji.demojize(text.decode('utf8'))))
     bseg.NerTag()
     use_sub_ner = gezi.env_has('BSEG_SUBNER')
     nodes = bseg.GetSubNerNodes() if use_sub_ner else bseg.GetNerNodes()
     return [(to_utf8(node.word), node.name) for node in nodes]
def main(_):
    """Segment every row of the CSV given as sys.argv[1] and write the result.

    Reads the `id` and `content` columns, calls `seg(id, content, out, counter)`
    for each row, and writes to `<input>.seg.jieba.mix.txt` (or
    `.seg.bseg.mix.txt` when the BSEG env flag is set). If the output file
    already exists it is opened in append mode and the ids it already contains
    are counted and reported. The word counter is saved next to the input as
    `<input>.pos.mix.vocab`.

    Rows whose segmentation raises are skipped; only the first traceback is
    printed, and the total error count and ratio are reported at the end.

    Args:
      _: unused (argv placeholder supplied by the app runner).
    """
    global vocab
    vocab = gezi.Vocabulary(FLAGS.vocab)

    ifile = sys.argv[1]
    if not gezi.env_has('BSEG'):
        ofile = ifile.replace('.csv', '.seg.jieba.mix.txt')
    else:
        ofile = ifile.replace('.csv', '.seg.bseg.mix.txt')

    counter = WordCounter(most_common=0, min_count=1)
    vocab2 = ifile.replace('.csv', '.pos.mix.vocab')

    ids_set = set()
    fm = 'w'
    if os.path.exists(ofile):
        fm = 'a'
        # Close the handle deterministically instead of leaking it.
        with open(ofile) as done:
            for line in done:
                ids_set.add(line.split('\t')[0])

    print('%s already done %d' % (ofile, len(ids_set)))

    num_errs = 0
    with open(ofile, fm) as out:
        df = pd.read_csv(ifile, lineterminator='\n')
        contents = df['content'].values
        ids = df['id'].values
        # NOTE(review): ids_set is only reported, never used to skip rows, so
        # in append mode already-written ids are segmented and written again.
        for id_, content in tqdm(zip(ids, contents), total=len(df),
                                 ascii=True):
            try:
                seg(id_, content, out, counter)
            except Exception:
                # Print only the first traceback to keep the log readable.
                if num_errs == 0:
                    print(traceback.format_exc())
                num_errs += 1
                continue

    counter.save(vocab2)
    # An empty CSV must not crash the final report with ZeroDivisionError.
    total = len(df)
    ratio = num_errs / total if total else 0.0
    print('num_errs:', num_errs, 'ratio:', ratio)
Exemple #6
0
def tf_train_flow(
        train_once_fn,
        model_dir=None,
        log_dir=None,
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=None,
        freeze_graph=False,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        valid_interval_epochs=0,
        inference_fn=None,
        inference_interval_epochs=0,
        init_fn=None,
        restore_fn=None,
        restore_include=None,
        restore_exclude=None,
        save_all_scope=False,  #TODO save load from restore scope only but svae all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        output_collection_names=None,
        output_node_names=None,
        learning_rate=None,  #not use yet, just use in train_once
        learning_rate_patience=None,
        learning_rate_decay_factor=None,
        write_during_train=True,
        model=None,
        sess=None):
    """Graph-mode training loop with checkpoint restore, periodic save and stop rules.

    Flow:
      1. Work out which variables to restore/save (intersected with the
         variable names found in the checkpoint under `model_dir`).
      2. Run the global + local variable initializers.
      3. If `model_dir` holds a checkpoint, restore it (calling `restore_fn`
         first when given). Otherwise, when not running under horovod, try a
         tf.train.Checkpoint restore from `<model_dir>/ckpt`; if nothing is
         found call `init_fn(sess)` when given.
      4. Repeatedly call
         `train_once_fn(sess, step, is_start=..., fixed_step=..., num_epochs=...,
         model_path=..., use_horovod=...)` until it returns True, `num_steps`
         steps have run, or `num_epochs` epochs have passed. These conditions
         are signalled internally by raising tf.errors.OutOfRangeError.
      5. Save step checkpoints every `save_interval_steps` steps and epoch
         checkpoints (under `<model_dir>/epoch`) every `save_interval_epochs`
         epochs; optionally freeze the graph and run `inference_fn`.

    Under horovod (detected via OMPI_COMM_WORLD_RANK) only rank 0 restores and
    saves; the start step is broadcast from rank 0.

    Several paths terminate the process with exit(0): the METRIC env flag,
    the "load and resave" mode (`num_steps < 0` or `num_steps < restored step`),
    a non-'train' FLAGS.work_mode, one-step mode (`num_epochs < 0`), and normal
    completion of the requested epochs/steps.

    Changes versus the previous implementation:
      * The epoch-save condition now tests `save_interval_epochs` (it tested
        `save_interval_steps` and then multiplied by a possibly-None
        `save_interval_epochs`, raising TypeError).
      * `num_steps=None` / `num_epochs=None` no longer raise TypeError on
        Python 3 when compared with 0; None means "no limit / not set".
    """
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    # model_dir_ keeps the caller's value for building paths; model_dir is
    # cleared on non-zero horovod ranks so they neither restore nor save.
    model_dir_ = model_dir
    if use_horovod and hvd.rank() != 0:
        model_dir = None

    if sess is None:
        #TODO melt.get_session is global session but may cause non close at last
        sess = melt.get_session()

    if FLAGS.use_tpu:
        sess.run(tpu.initialize_system())

    if model_dir:
        if model:
            checkpoint = tf.train.Checkpoint(model=model)
            ckpt_dir = model_dir + '/ckpt'
            checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

        #this is usefull for you use another model with another scope, and just load and restore/save initalize your scope vars!
        #this is not for finetune but mainly for like using another model as in predict like this introducing graph other model scope and ignore here
        if not variables_to_restore:
            variables_to_restore = slim.get_variables_to_restore(
                include=restore_include, exclude=restore_exclude)

        if not variables_to_save:
            variables_to_save = variables_to_restore
        if save_all_scope:
            variables_to_save = None

        logging.info('variables_to_restore from %s' % model_dir)
        #load all var in checkpoint try to save all var(might more then original checkpoint) if not specifiy variables_to_save
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)

        # TODO has someproblem say  tf.Variable 'r_net/text_encoder/cudnn_rnn/cu_dnngru/recurrent_kernel/adam_v:0' even though in checkpoint I have renated it as ignore/rnet
        variables_to_restore_from_model = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)
        if not variables_to_restore:
            variables_to_restore = variables_to_restore_from_model
        else:
            # Only restore variables that actually exist in the checkpoint.
            variables_to_restore = [
                v for v in variables_to_restore
                if v in variables_to_restore_from_model
            ]
        if restore_exclude:
            for excl in restore_exclude:
                variables_to_restore = [
                    v for v in variables_to_restore if not excl in v.name
                ]
        #--tf 1.6 adadelta will have same vars...
        variables_to_restore = list(set(variables_to_restore))
        logging.info('variables_to_restore', [
            x for x in variables_to_restore if not 'OptimizeLoss' in x.name
        ][:100])

    ##finally remove global_step since melt.apps.train will handle it!
    global_step = tf.train.get_or_create_global_step()

    # TODO fixme if step, step2.. and in checkpoint step then here will be step2...
    loader = tf.train.Saver(var_list=variables_to_restore)

    logging.info('max models to keep {}, keep every {} hours'.format(
        max_models_keep, save_interval_seconds / 3600.0))
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())

    ##--mostly this will be fine except for using assistant predictor, initialize again! will make assistant predictor wrong
    ##so assume to all run init op! if using assistant predictor, make sure it use another session

    timer = gezi.Timer('sess run init_op in melt.tf_train_flow')

    # notice
    sess.run(init_op)

    timer.print_elapsed()

    #pre_step means the step last saved, train without pretrained,then -1
    pre_step = -1
    fixed_pre_step = -1  #fixed pre step is for epoch num to be correct if you change batch size
    pre_epoch = None
    if model_dir:
        model_path = _get_model_path(model_dir, save_model)
        model_dir = gezi.get_dir(
            model_dir)  #incase you pass ./model/model-ckpt1000 -> ./model

        if model_path is not None:
            if not restore_from_latest:
                logging.info('using recent but not latest model')
                model_path = melt.recent_checkpoint(model_dir)
            model_name = os.path.basename(model_path)
            timer = gezi.Timer(
                'Loading and training from existing model [%s]' % model_path)
            if restore_fn is not None:
                restore_fn(sess)
            loader.restore(sess, model_path)
            timer.print()
            # TODO check ..
            pre_step = sess.run(tf.train.get_global_step()) - 1
            pre_epoch = melt.get_model_epoch(
                model_path
            ) if FLAGS.global_epoch is None else FLAGS.global_epoch
            fixed_pre_step = pre_step
        else:
            latest_checkpoint = None
            if not use_horovod:  #now will hang
                try:
                    # NOTE: ckpt_dir/checkpoint only exist when `model` was
                    # given; otherwise the NameError lands in the except below.
                    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                    if latest_checkpoint:
                        logging.info(
                            'Try start from eager trained mode, latest checkpoint:',
                            latest_checkpoint)
                        checkpoint.restore(latest_checkpoint).run_restore_ops(
                            session=sess)

                        pre_epoch = int(latest_checkpoint.split('-')[-1])
                        # TODO check
                        pre_step = sess.run(tf.train.get_global_step()) - 1
                        fixed_pre_step = pre_step
                        logging.info('Start step is:', pre_step)
                except Exception:
                    logging.info(
                        'Something wrong with restore from eager trained model'
                    )
                if latest_checkpoint is None:
                    logging.info('Train all start step 0')
                    #like use image model, build image graph, reload first train, and then will go to same checkpoint all varaible just restore will ok
                    #for finetune from loading other model init
                    if init_fn is not None:
                        init_fn(sess)

    if gezi.env_has('METRIC'):
        l = metric_eval_fn(model_path)
        print(list(zip(l[1], l[0])))
        exit(0)

    try:
        learning_rate = tf.get_collection('learning_rate')[-1]
        learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
        sess.run(tf.assign(learning_rate,
                           learning_rate * learning_rate_weight))
    except Exception:
        # if not using weight_decay but using optimizer decay then will go here as learning rate is a tensor can not assign
        pass

    try:
        logging.info('Actual start global step:',
                     sess.run(global_step), 'learning rate:',
                     sess.run(learning_rate), 'learning_rate_weight:',
                     sess.run(learning_rate_weight))
    except Exception:
        # Best-effort logging only; the collections above may be absent.
        pass

    if model_dir_:
        epoch_dir = os.path.join(model_dir_, 'epoch')
        gezi.try_mkdir(epoch_dir)
        checkpoint_path = os.path.join(model_dir_, 'model.ckpt')

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if use_horovod:
        bcast = hvd.broadcast_global_variables(0)
        sess.run(bcast)

    only_one_step = False
    try:
        if use_horovod:
            ## TODO FIXME why bcast here not work ? simple test work see tests/bcast.py
            temp = np.array([pre_step, fixed_pre_step])
            comm.Bcast(temp, root=0)
            pre_step = temp[0]
            fixed_pre_step = temp[1]

        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1

        #hack just for save one model after load
        # None means "not set"; guard it so Python 3 does not raise on `None < 0`.
        if (num_steps is not None and num_steps < 0) or (num_steps and
                                                         num_steps < step):
            logging.info('just load and resave then exit')
            model_path_ = _get_checkpoint_path(checkpoint_path,
                                               step,
                                               num_steps_per_epoch,
                                               epoch=pre_epoch)
            saver.save(sess, model_path_, global_step=step + 1)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step + 1,
                                  output_collection_names, output_node_names)
            sess.close()
            exit(0)

        if num_epochs is not None and num_epochs < 0:
            only_one_step = True
            logging.info('just run one step')

        if FLAGS.work_mode != 'train':
            assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
            if 'valid' in FLAGS.work_mode:
                vals, names = metric_eval_fn(FLAGS.model_dir)
                logging.info(list(zip(names, vals)))
            if 'test' in FLAGS.work_mode:
                inference_fn(FLAGS.model_dir)
            exit(0)

        #early_stop = True #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow 5 non decrease eval loss epochs  before stop
        epoch_saved_step = 0
        while not coord.should_stop():
            # model_step_path is only handed to train_once_fn on steps where a
            # validation write is due.
            model_step_path = None
            if model_dir_:
                model_path_ = os.path.join(
                    epoch_dir, 'model.ckpt-%.2f' %
                    (fixed_step / float(num_steps_per_epoch)))
                model_step_path_ = model_path_ + '-' + str(step)
                if (write_during_train and metric_eval_fn is not None
                        and valid_interval_epochs and fixed_step %
                        int(num_steps_per_epoch * valid_interval_epochs) == 0):
                    model_step_path = model_step_path_
                else:
                    model_step_path = None

            if step == 0:
                model_step_path = None

            stop = train_once_fn(
                sess,
                step,
                is_start=(step == start),
                fixed_step=fixed_step,
                num_epochs=num_epochs,
                model_path=model_step_path,
                use_horovod=use_horovod,
                ## TODO FIXME this line will cause   tensorflow.python.framework.errors_impl.NotFoundError: Resource localhost/save_counter/N10tensorflow3VarE does not exist.
            )

            if only_one_step:
                stop = True

            step += 1
            fixed_step += 1

            if save_model and step and model_dir:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer(
                        'save model step %d to %s' % (step, checkpoint_path),
                        False)
                    model_path_ = _get_checkpoint_path(checkpoint_path,
                                                       fixed_step,
                                                       num_steps_per_epoch)
                    saver.save(sess, model_path_, global_step=step)
                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)
                    timer.print_elapsed()

                # Epoch checkpoints: must be gated on save_interval_epochs,
                # which is the value multiplied below (it may be None).
                if save_interval_epochs and num_steps_per_epoch and fixed_step % int(
                        num_steps_per_epoch * save_interval_epochs) == 0:
                    # TODO only epoch in name not sep ?
                    epoch_saved_step = step
                    model_path_ = os.path.join(
                        epoch_dir, 'model.ckpt-%.2f' %
                        (fixed_step / float(num_steps_per_epoch)))
                    model_step_path = model_path_ + '-' + str(step)
                    epoch_saver.save(sess, model_path_, global_step=step)

                    ## TODO FIXME do not support tf.keras save currently with horovod

                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)

                if write_during_train:
                    if inference_fn is not None and inference_interval_epochs and fixed_step % int(
                            num_steps_per_epoch *
                            inference_interval_epochs) == 0:
                        model_step_path = model_path_ + '-' + str(step)
                        try:
                            inference_fn(model_path=model_step_path)
                        except Exception:
                            logging.info(traceback.format_exc())

            # Stop conditions are reported by raising OutOfRangeError, which
            # the handler below turns into a final save and exit.
            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            max_num_epochs = num_epochs
            if max_num_epochs and num_steps_per_epoch and fixed_step / num_steps_per_epoch > max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
    except tf.errors.OutOfRangeError:
        # if run 2 epoch and we have just epoch saved, do not need to save only 1 step more model
        if (step - epoch_saved_step > 1) and not (
                step == start
        ) and save_model and step % save_interval_steps != 0 and model_dir:
            model_path_ = _get_checkpoint_path(checkpoint_path, step,
                                               num_steps_per_epoch)
            saver.save(sess, model_path_, global_step=step)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step,
                                  output_collection_names, output_node_names)
            if log_dir != model_dir:
                assert log_dir
                command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                print(command, file=sys.stderr)
                os.system(command)
        if only_one_step:
            logging.info('Done one step')
            exit(0)

        if (num_epochs and fixed_step / num_steps_per_epoch >= num_epochs) or (
                num_steps and step == start + num_steps):
            logging.info('Done training for %.3f epochs, %d steps.' %
                         (fixed_step / num_steps_per_epoch, step))
            #FIXME becase coord.join seems not work,  RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            logging.info('Should not stop, but stopped at epoch: %.3f' %
                         (fixed_step / num_steps_per_epoch))
            logging.info(traceback.format_exc())
    finally:
        coord.request_stop()

    coord.join(threads, stop_grace_period_secs=5)
    #FIMXE due to use melt.get_session(global not handle del well)
    #Exception TypeError: "'NoneType' object is not callable" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7f6cf33cd450>> ignored
    if FLAGS.use_tpu:
        sess.run(tpu.shutdown_system())
    sess.close()
import pandas as pd

from projects.ai2018.sentiment.prepare import filter

from tqdm import tqdm
import traceback

# Sentence boundary markers; seg() below adds both to `counter` on every call.
START_WORD = '<S>'
END_WORD = '</S>'

# Module-level frequency counters shared by every seg() call.
# `counter` receives the first element of each pair returned by gezi.cut and
# `counter2` the second one (only when the cut type is not 'word').
counter = WordCounter(most_common=0, min_count=1)
counter2 = WordCounter(most_common=0, min_count=1)

# Emitted at import time, to stderr so it does not mix with stdout output.
print('seg_method:', FLAGS.seg_method, file=sys.stderr)

# When the SENTENCE_PIECE env flag is set, a model path flag is mandatory and
# the sentence-piece segmenter is initialised once, at import time.
if gezi.env_has('SENTENCE_PIECE'):
  assert FLAGS.sp_path 
  gezi.segment.init_sp(FLAGS.sp_path)

def seg(id, text, out, type):
  text = filter.filter(text)
  counter.add(START_WORD)
  counter.add(END_WORD)
  l = gezi.cut(text, type)

  if type != 'word':
    for x, y in l:
      counter.add(x)
      counter2.add(y)
    words = ['%s|%s' % (x, y) for x,y in l]
  else:
Exemple #8
0
  for ch in sentence:
    try:
      gbk_ch = ch.decode('utf8').encode('gbk')
      l.append(gbk_ch)
    except Exception:
      if l:
        chs = chnormalizer.Normalize(''.join(l), toLower=to_lower, toSimplified=to_simplify, toHalf=to_half).decode('gbk').encode('utf8')
        l = []
        res.append(chs)
      res.append(ch.encode('utf8'))
  if l:
    chs = chnormalizer.Normalize(''.join(l), toLower=to_lower, toSimplified=to_simplify, toHalf=to_half).decode('gbk').encode('utf8')
    res.append(chs)
  return ''.join(res)

if gezi.encoding == 'gbk' or gezi.env_has('BAIDU_SEG'):
  import libgezi # must include this not sure why..
  import libstring_util as su
  def get_single_cns(text):
    """Extract the Chinese part of `text` and convert it with su.to_cnvec."""
    chinese_only = su.extract_chinese(text)
    return su.to_cnvec(chinese_only)
  
  def is_single_cn(word):
    """Return True if GBK-encoded `word` sorts within U+4E00..U+9FFF.

    Undecodable bytes are ignored. The comparison is a plain string range
    check on the decoded text.
    """
    decoded = word.decode('gbk', 'ignore')
    if decoded < u'\u4e00':
      return False
    return decoded <= u'\u9fff'

  def get_single_chars(text):
    """Split GBK-encoded `text` into per-character GBK strings.

    Undecodable bytes are ignored; characters that are empty after stripping
    whitespace are dropped.
    """
    chars = []
    for ch in text.decode('gbk', 'ignore'):
      stripped = ch.encode('gbk').strip()
      if stripped:
        chars.append(stripped)
    return chars

else:
  def get_single_cns(text):
Exemple #9
0
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
    """Run one training step with periodic evaluation, metrics and summaries.

    State that must survive between calls (timers, the merged summary op, the
    summary writer, the running average loss, learning-rate patience
    bookkeeping and the epoch-level summary step) is stored as attributes on
    the ``train_once`` function object itself.

    Args:
      sess: session used for every ``sess.run`` call.
      step: current global step; also published through ``melt.set_global``.
      ops: list of training ops. ``ops[1]`` is read as the learning rate and
        everything from index 2 on is treated as loss/metric values. ``None``
        skips the training step.
      names: display names for the training results.
      gen_feed_dict_fn: optional callable returning the training feed dict.
      deal_results_fn: optional callable invoked with the training results on
        logging steps; its return value becomes the stop flag.
      interval_steps: log training information every this many steps.
      eval_ops: ops run on evaluation steps.
      eval_names: display names for the evaluation results; derived from
        ``names`` when omitted.
      gen_eval_feed_dict_fn: optional callable returning the eval feed dict.
      deal_eval_results_fn: accepted for compatibility; not used here.
      valid_interval_steps: run ``eval_ops`` every this many steps.
      print_time: include timing information in the log line.
      print_avg_loss: accumulate and log the running average of the results.
      model_dir: accepted for compatibility; not used here.
      log_dir: directory for summaries; falsy disables summary writing.
      is_start: marks the first call, which forces an evaluation step.
      num_steps_per_epoch: used to derive the fractional epoch.
      metric_eval_fn: callable returning ``(results, names)``,
        ``(results, names, summaries)`` or a mapping of metrics.
      metric_eval_interval_steps: call ``metric_eval_fn`` every this many
        steps.
      summary_excls: substrings; summary ops whose name contains any of them
        are left out of the merged summary op.
      fixed_step: step used instead of ``step`` for the epoch computation.
      eval_loops: how many times ``eval_ops`` is run per evaluation step.
      learning_rate: non-None enables patience based learning-rate decay.
      learning_rate_patience: number of non-improving metric evaluations
        tolerated before the learning rate is decayed.
      learning_rate_decay_factor: multiplier in (0, 1) applied on decay.
      num_epochs: total number of epochs, only used for display.
      model_path: when set, metric evaluation runs on every call and is
        reported under the ``eval`` prefix with an epoch-level step.
      use_horovod: ignored; see the note at the top of the body.

    Returns:
      The stop flag: ``False`` unless ``deal_results_fn`` returned something
      else on a logging step.
    """
    # NOTE: the argument value is overridden; horovod mode is detected purely
    # from the OpenMPI environment variable.
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    timer = gezi.Timer()
    if print_time and not hasattr(train_once, 'timer'):
        train_once.timer = Timer()
        train_once.eval_timer = Timer()
        train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    # Values produced on some paths only and read again by the summary code at
    # the end; pre-binding them lets that code test for None instead of
    # failing with a NameError.
    eval_results = None
    eval_names_ = None
    train_average_loss = None
    names_ = None

    if eval_names is None:
        if names:
            eval_names = ['eval/' + x for x in names]

    if names:
        names = ['train/' + x for x in names]

    # NOTE(review): names derived from `names` above already carry 'eval/', so
    # they end up prefixed twice here. Kept as is to preserve existing tags.
    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
    summary_str = []

    eval_str = ''
    if is_eval_step:
        # Lazily build the merged summary op and the writer on first use.
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    # Keep an op only when none of the exclusion patterns
                    # occurs in its name (and keep it exactly once).
                    summary_ops = [
                        op
                        for op in tf.get_collection(tf.GraphKeys.SUMMARIES)
                        if not any(summary_excl in op.name
                                   for summary_excl in summary_excls)
                    ]
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        if eval_ops:
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )

                # With horovod every rank issues the same sess.run, so the
                # summary op is not attached in that mode.
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY') or use_horovod:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                if use_horovod:
                    sess.run(hvd.allreduce(tf.constant(0)))

                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                eval_str = 'valid:{}'.format(
                    melt.parse_results(eval_loss, eval_names_))

                assert len(eval_loss) > 0
                if not use_horovod or hvd.rank() == 0:
                    melt.set_global('eval_loss',
                                    melt.parse_results(eval_loss, eval_names_))

    metric_evaluate = False

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    # EVFIRST=0 suppresses, any other EVFIRST value forces, the metric
    # evaluation on the very first call.
    if 'EVFIRST' in os.environ:
        if os.environ['EVFIRST'] == '0':
            if is_start:
                metric_evaluate = False
        else:
            if is_start:
                metric_evaluate = True

    if step == 0 or 'QUICK' in os.environ:
        metric_evaluate = False

    if metric_evaluate:
        if use_horovod:
            print('------------metric evaluate step', step, model_path,
                  hvd.rank())
        # Forward model_path only when the callback declares that parameter.
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            metric_eval_fn_ = metric_eval_fn
        else:
            metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

        try:
            l = metric_eval_fn_()
            if isinstance(l, tuple):
                num_returns = len(l)
                if num_returns == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    assert num_returns == 3, 'retrun 1,2,3 ok 4.. not ok'
                    evaluate_results, evaluate_names, evaluate_summaries = l
            else:
                # Mapping result; assumes it maps metric name -> value
                # (TODO confirm against the callers).
                evaluate_names, evaluate_results = tuple(zip(*l.items()))
                evaluate_summaries = None
        except Exception:
            logging.info('Do nothing for metric eval fn with exception:\n',
                         traceback.format_exc())
            # Nothing usable was produced: skip every later consumer of the
            # metric results for this call.
            metric_evaluate = False

    if metric_evaluate:
        if not use_horovod or hvd.rank() == 0:
            logging.info2('{} valid_step:{} {}:{}'.format(
                epoch_str, step, 'valid_metrics',
                melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            # The first metric is the one monitored for improvement.
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.deacy_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.deacy_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.deacy_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        feed_dict[K.learning_phase()] = 1
        # The summary op is fetched together with the training ops only when
        # no eval ops are configured (otherwise the eval run supplies it).
        if eval_ops is not None or not log_dir or not hasattr(
                train_once,
                'summary_op') or train_once.summary_op is None or use_horovod:
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]

        # results[0] is the train op output, results[1] the learning rate.
        learning_rate = results[1]
        results = results[2:]

        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        interval_ok = not use_horovod or hvd.local_rank() == 0
        if interval_steps and step % interval_steps == 0 and interval_ok:
            # The accumulator only exists once print_avg_loss has been on.
            if hasattr(train_once, 'avg_loss'):
                train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                melt.set_global('duration', '%.2f' % duration)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                names_ = melt.adjust_names(train_average_loss, names)
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
            info.write(eval_str)
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            if log_dir:
                summary = tf.Summary()
                if not eval_ops or eval_results is None:
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'valid' if not eval_names else ''
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    # The average is only available on logging steps with
                    # print_avg_loss enabled.
                    if train_average_loss is not None and names_ is not None:
                        try:
                            melt.add_summarys(summary,
                                              train_average_loss,
                                              names_,
                                              suffix='train_avg')
                        except Exception:
                            # Best effort, as before: a failing average-loss
                            # summary must not abort the training step.
                            pass
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary,
                                     melt.batch_size(),
                                     'batch_size',
                                     prefix='other')
                    melt.add_summary(summary,
                                     melt.epoch(),
                                     'epoch',
                                     prefix='other')
                    if steps_per_second:
                        melt.add_summary(summary,
                                         steps_per_second,
                                         'steps_per_second',
                                         prefix='perf')
                    if instances_per_second:
                        melt.add_summary(summary,
                                         instances_per_second,
                                         'instances_per_second',
                                         prefix='perf')
                    if hours_per_epoch:
                        melt.add_summary(summary,
                                         hours_per_epoch,
                                         'hours_per_epoch',
                                         prefix='perf')

                if metric_evaluate:
                    prefix = 'step_eval'
                    if model_path:
                        # Epoch-level evaluation is indexed by its own counter
                        # instead of the global step.
                        prefix = 'eval'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step
                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()
        elif metric_evaluate and log_dir:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            summary_writer = train_once.summary_writer
            prefix = 'step_eval'
            if model_path:
                prefix = 'eval'
                if not hasattr(train_once, 'epoch_step'):
                    # After a restart the counter is re-derived from the
                    # current epoch and the validation interval in epochs.
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()

    # Report the stop flag on every path, not only on evaluation steps.
    return stop
import numpy as np

import gezi
from gezi import WordCounter

import pandas as pd

from projects.ai2018.sentiment.prepare import filter

from tqdm import tqdm
import traceback

#assert gezi.env_has('BSEG')

import six
# Using the BSEG segmenter requires running under Python 2.
if gezi.env_has('BSEG'):
    assert six.PY2

# Module-level vocabulary handle; seg() below calls vocab.has(word), so this
# has to be assigned a real vocabulary object before seg() is used.
vocab = None


def seg(id, text, out, counter):
    text = filter.filter(text)
    words = []
    for i, word in enumerate(gezi.cut(text)):
        counter.add(str(i))
        if vocab.has(word) and not word.isdigit():
            words.append('%s|%d' % (word, i))
        else:
            if six.PY2:
                for ch in word.decode('utf8'):
        def Segment(self, text, method='default'):
            """Segment utf8 ``text`` with the strategy named by ``method``.

            The text is converted to gbk before segmentation and every piece
            is converted back to utf8 on return. ``'default'``, ``'all'`` and
            ``'full'`` select the same strategy. With the JIEBA_SEG switch set,
            text that cannot be converted to gbk is handed to
            ``JiebaSegmentor`` instead.

            Raises:
              ValueError: if ``method`` is not one of the supported names.
            """
            if not gezi.env_has('JIEBA_SEG'):
                text = text.decode('utf8').encode('gbk', 'ignore')
            else:
                utf8_text = text
                try:
                    text = text.decode('utf8').encode('gbk')
                except Exception:
                    return JiebaSegmentor().Segment(utf8_text, method=method)

            def split_digits(word):
                # Every digit becomes its own piece; runs of other characters
                # stay together.
                pieces = []
                pending = ''
                for c in word:
                    if not c.isdigit():
                        pending += c
                        continue
                    if pending:
                        pieces.append(pending)
                    pieces.append(c)
                    pending = ''
                if pending:
                    pieces.append(pending)
                return pieces

            def basic_digit(gbk_text):
                pieces = []
                for w in seg.Segment(gbk_text, libsegment.SEG_BASIC):
                    pieces += split_digits(w)
                return pieces

            # Lambdas keep every attribute/global lookup lazy, so only the
            # selected strategy is touched.
            handlers = {
                'default': lambda t: self.segment(t),
                'all': lambda t: self.segment(t),
                'full': lambda t: self.segment(t),
                'phrase_single': lambda t: self.segment_phrase_single(t),
                'phrase_single_all':
                lambda t: self.segment_phrase_single_all(t),
                'phrase': lambda t: seg.Segment(t),
                'basic': lambda t: seg.Segment(t, libsegment.SEG_BASIC),
                'basic_digit': basic_digit,
                'basic_single': lambda t: self.segment_basic_single(t),
                'basic_single_all':
                lambda t: self.segment_basic_single_all(t),
                'merge_newword':
                lambda t: seg.Segment(t, libsegment.SEG_MERGE_NEWWORD),
                'merge_newword_single':
                lambda t: self.segment_merge_newword_single(t),
                'seq_all': lambda t: self.segment_seq_all(t),
                'en': lambda t: segment_en(t),
                'tokenize': lambda t: tokenize(t),
                'char': lambda t: segment_char(t),
                'tab': lambda t: t.strip().split('\t'),
                'white_space': lambda t: t.strip().split(),
            }

            handler = handlers.get(method)
            if handler is None:
                raise ValueError('%s not supported' % method)
            l = handler(text)

            return [x.decode('gbk').encode('utf8') for x in l]
def ner_cut(text):
    """Cut ``text`` into ``(word, tag)`` pairs using a NER tagger.

    With the STANFORD_NLP switch set the stanford tagger is used; otherwise
    bseg is used, with jieba as a fallback when bseg's output covers too
    little of the input. Words equal to ``'\\x01'``, ``'\\x02'`` or ``'\\x03'``
    are tagged ``'sep'`` in the returned list.
    """
    import emoji

    if gezi.env_has('STANFORD_NLP'):
        init_stanford_nlp()
        tagged = merge_expression2(stanford_nlp.ner(emoji.demojize(text)))
    else:
        init_bseg(use_ner=True)

        def bseg_(piece):
            bseg.Cut(to_gbk(emoji.demojize(piece.decode('utf8'))))
            bseg.NerTag()
            if gezi.env_has('BSEG_SUBNER'):
                nodes = bseg.GetSubNerNodes()
            else:
                nodes = bseg.GetNerNodes()
            return [(to_utf8(node.word), node.name) for node in nodes]

        # HACK: bseg word ner fails on long text (718 was seen to fail), so
        # long input is fed to it in chunks built from jieba tokens.
        MAX_LEN = 500
        jieba_cuts = None
        if len(text) >= MAX_LEN:
            jieba_cuts = [tok.encode('utf8') for tok in jieba.cut(text)]
            tagged = []
            pending = []
            pending_len = 0
            for tok in jieba_cuts:
                pending.append(tok)
                pending_len += len(tok)
                if pending_len >= MAX_LEN:
                    tagged.extend(bseg_(''.join(pending)))
                    pending = []
                    pending_len = 0
            if pending:
                tagged.extend(bseg_(''.join(pending)))
        else:
            tagged = bseg_(text)

        # Fall back to plain jieba tokens when the tagged words cover less
        # than 80% of the characters of the input.
        covered = ''.join([word for word, _ in tagged]).decode('utf8')
        if len(covered) < len(text.decode('utf8')) * 0.8:
            print('warning, bad cut for turn back to use jieba cut:',
                  text,
                  file=sys.stderr)
            if not jieba_cuts:
                jieba_cuts = [tok.encode('utf8') for tok in jieba.cut(text)]
            tagged = [(tok, 'NOR') for tok in jieba_cuts]

    tagged = hack_emoji2(tagged)
    res = merge_expression2(tagged)

    separators = ('\x01', '\x02', '\x03')
    for i, (w, t) in enumerate(res):
        if w in separators:
            res[i] = (w, 'sep')
    return res
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    eval_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
):

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = ['eval/' + x for x in names]

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or eval_interval_steps and step % eval_interval_steps == 0
    summary_str = []

    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        if eval_ops is not None:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # might use EVAL_NO_SUMMARY if some old code has problem TODO CHECK
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY'):
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False

                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                logging.info2('{} eval_step:{} eval_metrics:{}'.format(
                    epoch_str, step,
                    melt.parse_results(eval_loss, eval_names_)))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                melt.set_global('eval_loss',
                                melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != eval_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
    if ((step == 0) and
        (not 'EVFIRST' in os.environ)) or ('QUICK' in os.environ) or (
            'EVFIRST' in os.environ and os.environ['EVFIRST'] == '0'):
        metric_evaluate = False

    if metric_evaluate:
        # TODO better
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            l = metric_eval_fn()
            if len(l) == 2:
                evaluate_results, evaluate_names = l
                evaluate_summaries = None
            else:
                evaluate_results, evaluate_names, evaluate_summaries = l
        else:
            try:
                l = metric_eval_fn(model_path=model_path)
                if len(l) == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    evaluate_results, evaluate_names, evaluate_summaries = l
            except Exception:
                logging.info('Do nothing for metric eval fn with exception:\n',
                             traceback.format_exc())

        logging.info2('{} valid_step:{} {}:{}'.format(
            epoch_str, step,
            'valid_metrics' if model_path is None else 'epoch_valid_metrics',
            melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.deacy_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.deacy_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.deacy_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTE: ops[2] (the loss) must be a scalar, otherwise the results are wrong.
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once, 'summary_op') or train_once.summary_op is None:
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            #try:
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        # results[0] is assumed to be the train_op output and results[1] the learning rate
        learning_rate = results[1]
        results = results[2:]

        # TODO: should support average loss and other averaged evaluations, like test.
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
                if interval_steps != eval_interval_steps:
                    train_once.avg_loss2 = AvgScore()
            # results was sliced above (train_op and learning_rate dropped), so it should now start at the loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)
            if interval_steps != eval_interval_steps:
                train_once.avg_loss2.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        if is_start or interval_steps and step % interval_steps == 0:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.3f} '.format(duration)
                melt.set_global('duration', '%.3f' % duration)
                info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = ' 1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))

            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                # if not hasattr(train_once, 'summary_op'):
                #   melt.print_summary_ops()
                #   if summary_excls is None:
                #     train_once.summary_op = tf.summary.merge_all()
                #   else:
                #     summary_ops = []
                #     for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                #       for summary_excl in summary_excls:
                #         if not summary_excl in op.name:
                #           summary_ops.append(op)
                #     print('filtered summary_ops:')
                #     for op in summary_ops:
                #       print(op)
                #     train_once.summary_op = tf.summary.merge(summary_ops)

                #   print('-------------summary_op', train_once.summary_op)

                #   #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                #   train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                #   tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

                summary = tf.Summary()
                # # Strategy at eval_interval_steps: if there is an eval dataset, TensorBoard evaluates on the eval dataset;
                # # if there is no eval dataset, it evaluates on the train set, but with an eval dataset the train loss is still monitored.
                # assert train_once.summary_train_op is None
                # if train_once.summary_train_op is not None:
                #   summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
                #   train_once.summary_writer.add_summary(summary_str, step)

                if eval_ops is None:
                    # #get train loss, for every batch train
                    # if train_once.summary_op is not None:
                    #   #timer2 = gezi.Timer('sess run')
                    #   try:
                    #     # TODO FIXME so this means one more train batch step without adding to global step counter ?! so should move it earlier
                    #     summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
                    #   except Exception:
                    #     if not hasattr(train_once, 'num_summary_errors'):
                    #       logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict) fail')
                    #       train_once.num_summary_errors = 1
                    #       logging.warning(traceback.format_exc())
                    #     summary_str = ''
                    #   # #timer2.print()
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    # #get eval loss for every batch eval, then add train loss for eval step average loss
                    # try:
                    #   summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
                    # except Exception:
                    #   if not hasattr(train_once, 'num_summary_errors'):
                    #     logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
                    #     train_once.num_summary_errors = 1
                    #     logging.warning(traceback.format_exc())
                    #   summary_str = ''
                    # all single-value results are added to the summary here rather than via tf.scalar_summary
                    #summary.ParseFromString(summary_str)
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'eval' if not eval_names else ''
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    melt.add_summarys(summary,
                                      train_average_loss,
                                      names_,
                                      suffix='train_avg')
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary, melt.batch_size(), 'batch_size')
                    melt.add_summary(summary, melt.epoch(), 'epoch')
                    if steps_per_second:
                        melt.add_summary(summary, steps_per_second,
                                         'steps_per_second')
                    if instances_per_second:
                        melt.add_summary(summary, instances_per_second,
                                         'instances_per_second')
                    if hours_per_epoch:
                        melt.add_summary(summary, hours_per_epoch,
                                         'hours_per_epoch')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step/valid'
                    if model_path:
                        prefix = 'epoch/valid'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step

                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()

                #timer_.print()

            # if print_time:
            #   full_duration = train_once.eval_timer.elapsed()
            #   if metric_evaluate:
            #     metric_full_duration = train_once.metric_eval_timer.elapsed()
            #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
            #   #info.write('duration:{:.3f} '.format(timer.elapsed()))
            #   duration = timer.elapsed()
            #   info.write('duration:{:.3f} '.format(duration))
            #   info.write(full_duration_str)
            #   info.write('eval_time_ratio:{:.3f} '.format(duration/full_duration))
            #   if metric_evaluate:
            #     info.write('metric_time_ratio:{:.3f} '.format(duration/metric_full_duration))
            # #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
            # logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d'%step, info.getvalue()))
            return stop
        elif metric_evaluate:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step/valid'
            if model_path:
                prefix = 'epoch/valid'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()