예제 #1
0
def evaluate(model,
             dataset,
             eval_fn,
             model_path=None,
             names=None,
             write_fn=None,
             write_streaming=False,
             num_steps_per_epoch=None,
             suffix='.valid',
             sep=','):
    """Run ``model`` over ``dataset``, dump the predictions and score them.

    Args:
        model: Callable invoked as ``model(x)`` for every batch. When
            ``FLAGS.torch`` is set, ``model.eval()`` is called first.
        dataset: Iterable yielding ``(x, y)`` batches; ``x['id']`` is read
            to obtain the example ids of the batch.
        eval_fn: Metric function returning ``(vals, names)``. It is called
            according to its number of positional parameters:
            4 -> ``eval_fn(labels, predicts, ids=ids, model_path=model_path)``;
            3 -> ``eval_fn(labels, predicts, ids)`` if one parameter is named
            ``ids``, otherwise ``eval_fn(labels, predicts, model_path)``;
            anything else -> ``eval_fn(labels, predicts)``.
        model_path: Prefix for the output files. When falsy, nothing is
            written to disk.
        names: Optional header fields written as the first line of the
            streamed prediction file.
        write_fn: Optional writer. In streaming mode it is called per
            example as ``write_fn(id, label, predict, out)``; otherwise once
            as ``write_fn(ids, labels, predicts, ofile)``.
        write_streaming: Write predictions batch by batch while iterating.
            Forced to ``True`` when ``write_fn`` is not given.
        num_steps_per_epoch: Passed to ``tqdm`` as ``total``.
        suffix: Appended to ``model_path`` to build the prediction file name.
        sep: Field separator for the header and the default per-example rows.

    Returns:
        The ``(vals, names)`` pair returned by ``eval_fn``.
    """
    if FLAGS.torch:
        model.eval()
    if not write_fn:
        # Without a custom writer only the built-in per-row format exists.
        write_streaming = True

    predicts_list = []
    labels_list = []
    ids_list = []
    ofile = model_path + suffix if model_path else None

    out = None
    if write_streaming and ofile:
        out = open(ofile, 'w', encoding='utf-8')

    # try/finally so the prediction file is closed even if the model or a
    # writer raises in the middle of the loop.
    try:
        if out is not None and names is not None:
            print(*names, sep=sep, file=out)

        for x, y in tqdm(dataset, total=num_steps_per_epoch, ascii=True):
            if FLAGS.torch:
                x, y = to_torch(x, y)

            predicts = model(x)
            if FLAGS.torch:
                predicts = predicts.detach().cpu()
                y = y.detach().cpu()

            predicts_list.append(predicts)
            labels_list.append(y)
            if FLAGS.torch:
                ids = gezi.decode(x['id'])
            else:
                ids = gezi.decode(x['id'].numpy())
            ids_list.append(ids)

            if out is None:
                continue
            for example_id, label, predict in zip(ids, y.numpy(),
                                                  predicts.numpy()):
                if write_fn is None:
                    # Scalars are wrapped so they can be star-expanded.
                    if not gezi.iterable(label):
                        label = [label]
                    if not gezi.iterable(predict):
                        predict = [predict]
                    print(example_id, *label, *predict, sep=sep, file=out)
                else:
                    write_fn(example_id, label, predict, out)
    finally:
        if out is not None:
            out.close()

    predicts = np.concatenate(predicts_list)
    labels = np.concatenate(labels_list)
    ids = np.concatenate(ids_list)

    if not write_streaming and ofile:
        write_fn(ids, labels, predicts, ofile)

    # getfullargspec replaces inspect.getargspec, which was removed in
    # Python 3.11; ``.args`` has the same meaning.
    eval_args = inspect.getfullargspec(eval_fn).args
    if len(eval_args) == 4:
        vals, names = eval_fn(labels, predicts, ids=ids, model_path=model_path)
    elif len(eval_args) == 3:
        if 'ids' in eval_args:
            vals, names = eval_fn(labels, predicts, ids)
        else:
            # Previously this case left vals/names unbound; the third
            # parameter is taken to be the model path.
            vals, names = eval_fn(labels, predicts, model_path)
    else:
        vals, names = eval_fn(labels, predicts)

    if model_path:
        with open(model_path + '.valid.metrics', 'w') as metrics_out:
            for val, name in zip(vals, names):
                print(name, val, sep='\t', file=metrics_out)

    return vals, names
예제 #2
0
파일: train.py 프로젝트: liurht/wenzheng
def evaluate(model, dataset, eval_fn, model_path=None, 
             names=None, write_fn=None, write_streaming=False,
             num_steps=None, num_examples=None,
             suffix='.valid', sep=','):
  """Run ``model`` over ``dataset``, dump the predictions and score them.

  Args:
    model: Callable invoked as ``model(x)`` for every batch. If it has an
      ``eval`` attribute, ``model.eval()`` is called first.
    dataset: Iterable yielding ``(x, y)`` batches; ``x['id']`` is read to
      obtain the example ids of the batch.
    eval_fn: Metric function returning ``(vals, names)``. It is called
      according to its number of positional parameters:
      4 -> ``eval_fn(labels, predicts, ids=ids, model_path=model_path)``;
      3 -> ``eval_fn(labels, predicts, ids)`` if one parameter is named
      ``ids``, otherwise ``eval_fn(labels, predicts, model_path)``;
      anything else -> ``eval_fn(labels, predicts)``.
    model_path: Prefix for the output files. When falsy, nothing is written.
    names: Optional header fields written as the first line of the streamed
      prediction file.
    write_fn: Optional writer. In streaming mode it is called per example as
      ``write_fn(id, label, predict, out)``; otherwise once as
      ``write_fn(ids, labels, predicts, ofile)``.
    write_streaming: Write predictions batch by batch while iterating.
      Forced to ``True`` when ``write_fn`` is not given.
    num_steps: Passed to ``tqdm`` as ``total``.
    num_examples: In the non-horovod path the concatenated ids, predictions
      and labels are truncated to this many rows (``None`` keeps all rows).
    suffix: Appended to ``model_path`` to build the prediction file name.
    sep: Field separator for the header and the default per-example rows.

  Returns:
    The ``(vals, names)`` pair returned by ``eval_fn``.
  """
  if hasattr(model, 'eval'):
    model.eval()
  if not write_fn:
    # Without a custom writer only the built-in per-row format exists.
    write_streaming = True

  predicts_list = []
  labels_list = []
  ids_list = []
  ofile = model_path + suffix if model_path else None

  # NOTE(review): in streaming mode every process opens the same ofile; under
  # horovod this looks like it would be written by all ranks -- confirm.
  out = None
  if write_streaming and ofile:
    out = open(ofile, 'w', encoding='utf-8')

  # try/finally so the prediction file is closed even if the model or a
  # writer raises in the middle of the loop.
  try:
    if out is not None and names is not None:
      print(*names, sep=sep, file=out)

    for x, y in tqdm(dataset, total=num_steps, ascii=True):
      if FLAGS.torch:
        x, y = to_torch(x, y)

      predicts = model(x)
      if FLAGS.torch:
        predicts = predicts.detach().cpu()
        y = y.detach().cpu()

      predicts_list.append(predicts)
      labels_list.append(y)
      if FLAGS.torch:
        ids = gezi.decode(x['id'])
      else:
        ids = gezi.decode(x['id'].numpy())
      ids_list.append(ids)

      if out is None:
        continue
      for example_id, label, predict in zip(ids, y.numpy(), predicts.numpy()):
        if write_fn is None:
          # Scalars are wrapped so they can be star-expanded.
          if not gezi.iterable(label):
            label = [label]
          if not gezi.iterable(predict):
            predict = [predict]
          print(example_id, *label, *predict, sep=sep, file=out)
        else:
          write_fn(example_id, label, predict, out)
  finally:
    if out is not None:
      out.close()

  if FLAGS.use_horovod and FLAGS.horovod_eval:
    # TODO check eager mode ok...
    # Dummy allreduce issued before gathering; presumably acts as a sync
    # point across workers -- confirm.
    tensor = tf.constant(0) if not FLAGS.torch else torch.zeros(0)
    hvd.allreduce(tensor)
    # Here, for horovod with multiple GPUs, the dataset is not in repeat mode.
    ids_list = comm.allgather(np.concatenate(ids_list))
    predicts_list = comm.allgather(np.concatenate(predicts_list))
    labels_list = comm.allgather(np.concatenate(labels_list))
    comm.barrier()

    ids2 = np.concatenate(ids_list)
    predicts2 = np.concatenate(predicts_list)
    labels2 = np.concatenate(labels_list)
    # With batch parsing and no repeat mode the final batch keeps the full
    # batch size and is padded with empty ('') ids; drop those rows here.
    keep = [i for i, example_id in enumerate(ids2) if not example_id == '']
    ids = np.array([ids2[i] for i in keep])
    predicts = np.array([predicts2[i] for i in keep])
    labels = np.array([labels2[i] for i in keep])
  else:
    # Concatenate the per-batch lists, e.g. [[512,], [512,], ...] ->
    # [512 * num_batches] or [[512, 3], [512, 3], ...] -> [512 * num_batches, 3].
    predicts = np.concatenate(predicts_list)[:num_examples]
    labels = np.concatenate(labels_list)[:num_examples]
    try:
      ids = np.concatenate(ids_list)[:num_examples]
    except Exception:
      # Best-effort fallback to placeholder ids. num_examples defaults to
      # None, in which case the number of predictions is used instead of
      # failing on ``['0'] * None``.
      num_ids = num_examples if num_examples is not None else len(predicts)
      ids = ['0'] * num_ids
  
  if not write_streaming and ofile and (not FLAGS.use_horovod or hvd.rank() == 0):
    write_fn(ids, labels, predicts, ofile)
    
  # getfullargspec replaces inspect.getargspec, which was removed in
  # Python 3.11; ``.args`` has the same meaning.
  eval_args = inspect.getfullargspec(eval_fn).args
  if len(eval_args) == 4:
    vals, names = eval_fn(labels, predicts, ids=ids, model_path=model_path)
  elif len(eval_args) == 3:
    if 'ids' in eval_args:
      vals, names = eval_fn(labels, predicts, ids)
    else:
      vals, names = eval_fn(labels, predicts, model_path)
  else:
    vals, names = eval_fn(labels, predicts)
  
  # Only rank 0 writes the metrics file when running under horovod.
  if model_path and (not FLAGS.use_horovod or hvd.rank() == 0):
    with open(model_path + '.valid.metrics', 'w') as metrics_out:
      for val, name in zip(vals, names):
        print(name, val, sep='\t', file=metrics_out)

  return vals, names
예제 #3
0
def inference(model,
              dataset,
              model_path,
              names=None,
              debug_names=None,
              write_fn=None,
              write_streaming=False,
              num_steps_per_epoch=None,
              suffix='.infer',
              sep=','):
    """Run ``model`` over ``dataset`` and write its predictions to disk.

    Predictions go to ``model_path + suffix``; a second file with an extra
    ``.debug`` suffix is used only when ``write_fn`` takes four positional
    parameters.

    Args:
        model: Callable invoked as ``model(x)`` for every batch. When
            ``FLAGS.torch`` is set, ``model.eval()`` is called first.
        dataset: Iterable yielding ``(x, _)`` batches; ``x['id']`` is read
            to obtain the example ids of the batch.
        model_path: Prefix of the output file names.
        names: Optional header fields written as the first line of the
            streamed prediction file.
        debug_names: Optional header fields for the streamed debug file.
        write_fn: Optional writer. In streaming mode it is called per
            example as ``write_fn(id, predict, out[, out_debug])``;
            otherwise once as ``write_fn(ids, predicts, ofile[, ofile2])``.
        write_streaming: Write predictions batch by batch while iterating.
            Forced to ``True`` when ``write_fn`` is not given.
        num_steps_per_epoch: Passed to ``tqdm`` as ``total``.
        suffix: Appended to ``model_path`` to build the prediction file name.
        sep: Field separator for the headers and the default per-example rows.
    """
    if FLAGS.torch:
        model.eval()
    if not write_fn:
        # Without a custom writer only the built-in per-row format exists.
        write_streaming = True
    ofile = model_path + suffix
    ofile2 = ofile + '.debug'

    # A four-parameter write_fn also receives the debug output.
    # getfullargspec replaces inspect.getargspec (removed in Python 3.11).
    wants_debug = bool(write_fn) and \
        len(inspect.getfullargspec(write_fn).args) == 4

    out = None
    out_debug = None
    predicts_list = []
    ids_list = []
    # try/finally so both files are closed even if opening the second one,
    # the model, or a writer raises.
    try:
        if write_streaming:
            if wants_debug:
                out_debug = open(ofile2, 'w', encoding='utf-8')
            out = open(ofile, 'w', encoding='utf-8')
            if names is not None:
                print(*names, sep=sep, file=out)
            if debug_names and out_debug:
                print(*debug_names, sep=sep, file=out_debug)

        for (x, _) in tqdm(dataset, total=num_steps_per_epoch, ascii=True):
            if FLAGS.torch:
                x = to_torch(x)
            predicts = model(x)
            if FLAGS.torch:
                predicts = predicts.detach().cpu()
            # The ids are passed through gezi.decode (in py3 they arrive as
            # bytes rather than str).
            if FLAGS.torch:
                ids = gezi.decode(x['id'])
            else:
                ids = gezi.decode(x['id'].numpy())

            if not write_streaming:
                predicts_list.append(predicts)
                ids_list.append(ids)
                continue

            for example_id, predict in zip(ids, predicts.numpy()):
                if write_fn is None:
                    # Scalars are wrapped so they can be star-expanded.
                    if not gezi.iterable(predict):
                        predict = [predict]
                    print(example_id, *predict, sep=sep, file=out)
                elif out_debug:
                    write_fn(example_id, predict, out, out_debug)
                else:
                    write_fn(example_id, predict, out)
    finally:
        if out is not None:
            out.close()
        if out_debug is not None:
            out_debug.close()

    if not write_streaming:
        predicts = np.concatenate(predicts_list)
        ids = np.concatenate(ids_list)

        if wants_debug:
            write_fn(ids, predicts, ofile, ofile2)
        else:
            write_fn(ids, predicts, ofile)
예제 #4
0
파일: train.py 프로젝트: liurht/wenzheng
def inference(model, dataset, model_path, 
              names=None, debug_names=None, 
              write_fn=None, write_streaming=False,
              num_steps=None, num_examples=None,
              suffix='.infer', sep=','):
  """Run ``model`` over ``dataset`` and write its predictions to disk.

  Predictions go to ``model_path + suffix``; a second file with an extra
  ``.debug`` suffix is used only when ``write_fn`` takes four positional
  parameters.

  Args:
    model: Callable invoked as ``model(x)`` for every batch. If it has an
      ``eval`` attribute, ``model.eval()`` is called first.
    dataset: Iterable yielding ``(x, _)`` batches; ``x['id']`` is read to
      obtain the example ids of the batch.
    model_path: Prefix of the output file names.
    names: Optional header fields written as the first line of the streamed
      prediction file.
    debug_names: Optional header fields for the streamed debug file.
    write_fn: Optional writer. In streaming mode it is called per example as
      ``write_fn(id, predict, out[, out_debug])``; otherwise once as
      ``write_fn(ids, predicts, ofile[, ofile2])``.
    write_streaming: Write predictions batch by batch while iterating.
      Forced to ``True`` when ``write_fn`` is not given.
    num_steps: Passed to ``tqdm`` as ``total``.
    num_examples: In the non-horovod, non-streaming path the concatenated
      ids and predictions are truncated to this many rows (``None`` keeps
      all rows).
    suffix: Appended to ``model_path`` to build the prediction file name.
    sep: Field separator for the headers and the default per-example rows.
  """
  # Was ``has_attr``, which is not a builtin and raised NameError.
  if hasattr(model, 'eval'):
    model.eval()
  if not write_fn:
    # Without a custom writer only the built-in per-row format exists.
    write_streaming = True
  ofile = model_path + suffix
  ofile2 = ofile + '.debug'

  # A four-parameter write_fn also receives the debug output.
  # getfullargspec replaces inspect.getargspec (removed in Python 3.11).
  wants_debug = bool(write_fn) and \
      len(inspect.getfullargspec(write_fn).args) == 4

  out = None
  out_debug = None
  predicts_list = []
  ids_list = []
  # try/finally so both files are closed even if opening the second one,
  # the model, or a writer raises.
  try:
    if write_streaming:
      if wants_debug:
        out_debug = open(ofile2, 'w', encoding='utf-8')
      out = open(ofile, 'w', encoding='utf-8')
      if names is not None:
        print(*names, sep=sep, file=out)
      if debug_names and out_debug:
        print(*debug_names, sep=sep, file=out_debug)

    for (x, _) in tqdm(dataset, total=num_steps, ascii=True):
      if FLAGS.torch:
        x = to_torch(x)
      predicts = model(x)
      if FLAGS.torch:
        predicts = predicts.detach().cpu()
      # The ids are passed through gezi.decode (in py3 they arrive as bytes
      # rather than str).
      if FLAGS.torch:
        ids = gezi.decode(x['id'])
      else:
        ids = gezi.decode(x['id'].numpy())

      if not write_streaming:
        predicts_list.append(predicts)
        ids_list.append(ids)
        continue

      for example_id, predict in zip(ids, predicts.numpy()):
        if write_fn is None:
          # Scalars are wrapped so they can be star-expanded.
          if not gezi.iterable(predict):
            predict = [predict]
          print(example_id, *predict, sep=sep, file=out)
        elif out_debug:
          write_fn(example_id, predict, out, out_debug)
        else:
          write_fn(example_id, predict, out)
  finally:
    if out is not None:
      out.close()
    if out_debug is not None:
      out_debug.close()

  if not write_streaming:
    if FLAGS.use_horovod and FLAGS.horovod_eval:
      # Dummy allreduce issued before gathering; presumably acts as a sync
      # point across workers -- confirm.
      tensor = tf.constant(0) if not FLAGS.torch else torch.zeros(0)
      hvd.allreduce(tensor)
      # Here, for horovod with multiple GPUs, the dataset is not in repeat mode.
      ids_list = comm.allgather(np.concatenate(ids_list))
      predicts_list = comm.allgather(np.concatenate(predicts_list))
      comm.barrier()

      ids2 = np.concatenate(ids_list)
      predicts2 = np.concatenate(predicts_list)
      # With batch parsing and no repeat mode the final batch keeps the full
      # batch size and is padded with empty ('') ids; drop those rows here.
      keep = [i for i, example_id in enumerate(ids2) if not example_id == '']
      ids = np.array([ids2[i] for i in keep])
      predicts = np.array([predicts2[i] for i in keep])
    else:
      # Concatenate the per-batch lists, e.g. [[512,], [512,], ...] ->
      # [512 * num_batches] or [[512, 3], ...] -> [512 * num_batches, 3].
      predicts = np.concatenate(predicts_list)[:num_examples]
      try:
        ids = np.concatenate(ids_list)[:num_examples]
      except Exception:
        # Best-effort fallback to placeholder ids. num_examples defaults to
        # None, in which case the number of predictions is used instead of
        # failing on ``['0'] * None``.
        num_ids = num_examples if num_examples is not None else len(predicts)
        ids = ['0'] * num_ids

    # Only rank 0 writes the output when running under horovod.
    if (not FLAGS.use_horovod or hvd.rank() == 0):
      if wants_debug:
        write_fn(ids, predicts, ofile, ofile2)
      else:
        write_fn(ids, predicts, ofile)