Пример #1
0
def main(_):
    """Train ``FLAGS.model`` with PyTorch DataLoaders through melt's fit loop.

    The training loader is built from ``FLAGS.train_input``. When
    ``FLAGS.valid_input`` is set, two validation loaders are built over the
    same dataset: one with ``FLAGS.eval_batch_size`` and one with
    ``FLAGS.batch_size``. Without a validation input both are passed to
    ``fit`` as None.
    """
    FLAGS.torch_only = True
    melt.init()
    fit = melt.get_fit()

    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

    model_name = FLAGS.model
    model = getattr(base, model_name)()

    loss_fn = nn.BCEWithLogitsLoss()

    td = text_dataset.Dataset()
    train_files = gezi.list_files(FLAGS.train_input)
    train_ds = get_dataset(train_files, td)

    # pin_memory=True speeds things up a bit. num_workers=1 is very slow,
    # especially for validation, and around 4 workers seems to be enough.
    # Large values are risky: 12 sometimes works and sometimes hangs from
    # resource pressure. On a single GPU, more than 8 may make startup very slow.
    kwargs = {
        'num_workers': 8,
        'pin_memory': True,
        'collate_fn': lele.DictPadCollate()
    }

    train_dl = DataLoader(train_ds, FLAGS.batch_size, shuffle=True, **kwargs)

    # Validation is optional. Bind both loaders up front so the fit() call
    # below does not raise NameError when FLAGS.valid_input is empty.
    valid_dl = None
    valid_dl2 = None
    if FLAGS.valid_input:
        valid_files = gezi.list_files(FLAGS.valid_input)
        valid_ds = get_dataset(valid_files, td)
        valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size, **kwargs)
        valid_dl2 = DataLoader(valid_ds, FLAGS.batch_size, **kwargs)

    fit(
        model,
        loss_fn,
        dataset=train_dl,
        valid_dataset=valid_dl,
        valid_dataset2=valid_dl2,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        # Hard-coded off; FLAGS.write_valid is not consulted here.
        write_valid=False,
    )
Пример #2
0
def main(_):
  """Horovod-distributed training of ``FLAGS.model`` via melt's fit loop.

  ../input/train/* and ../input/valid/* are read through DistributedSampler
  instances configured with hvd.size() replicas and this process's hvd.rank().
  The validation data gets two loaders: one at FLAGS.eval_batch_size and one
  at FLAGS.batch_size.
  """
  FLAGS.torch_only = True
  melt.init()
  fit = melt.get_fit()

  FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

  model_name = FLAGS.model
  model = getattr(base, model_name)()
  model = model.cuda()

  loss_fn = nn.BCEWithLogitsLoss()

  td = text_dataset.Dataset()

  train_files = gezi.list_files('../input/train/*')
  train_ds = get_dataset(train_files, td)

  # One loader worker and no pinned memory for every loader below.
  num_workers = 1
  kwargs = {'num_workers': num_workers, 'pin_memory': False,
            'collate_fn': lele.DictPadCollate()}

  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_ds, num_replicas=hvd.size(), rank=hvd.rank())
  train_dl = DataLoader(train_ds, FLAGS.batch_size, sampler=train_sampler, **kwargs)

  valid_files = gezi.list_files('../input/valid/*')
  valid_ds = get_dataset(valid_files, td)

  # DistributedSampler supports shuffle=False from torch 1.2 on.
  valid_sampler = torch.utils.data.distributed.DistributedSampler(
      valid_ds, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)
  valid_sampler2 = torch.utils.data.distributed.DistributedSampler(
      valid_ds, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)

  valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size, sampler=valid_sampler, **kwargs)
  valid_dl2 = DataLoader(valid_ds, FLAGS.batch_size, sampler=valid_sampler2, **kwargs)

  fit(model,
      loss_fn,
      dataset=train_dl,
      valid_dataset=valid_dl,
      valid_dataset2=valid_dl2,
      eval_fn=ev.evaluate,
      valid_write_fn=ev.valid_write,
      # Hard-coded off; FLAGS.write_valid is not consulted here.
      write_valid=False,
     )
Пример #3
0
 def get_filenames(self):
     """Return the input files for ``self.subset``.

     ``self.subset`` selects which flag pattern is expanded with
     ``gezi.list_files``: 'train' -> FLAGS.train_input,
     'valid' -> FLAGS.valid_input, 'test' -> FLAGS.test_input.

     Raises:
         ValueError: if ``self.subset`` is not one of those three names.
     """
     wanted = self.subset
     if wanted not in ('train', 'valid', 'test'):
         raise ValueError('Invalid data subset "%s"' % wanted)
     flag_for_subset = {
         'train': 'train_input',
         'valid': 'valid_input',
         'test': 'test_input',
     }
     # Look the flag up lazily so only the selected subset's flag is read.
     pattern = getattr(FLAGS, flag_for_subset[wanted])
     return gezi.list_files(pattern)
Пример #4
0
def main(_):
    """Single-process training of ``FLAGS.model`` via melt's fit loop.

    Reads ../input/train/* and ../input/valid/*. Every loader uses
    lele.DictPadCollate and 12 worker processes. The two validation loaders
    use FLAGS.eval_batch_size and FLAGS.batch_size respectively.
    """
    import multiprocessing

    FLAGS.torch_only = True
    melt.init()
    fit = melt.get_fit()

    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

    model_name = FLAGS.model
    model = getattr(base, model_name)()

    loss_fn = nn.BCEWithLogitsLoss()

    td = TextDataset()
    train_files = gezi.list_files('../input/train/*')
    train_ds = get_dataset(train_files, td)

    # The worker count is pinned rather than derived from cpu_count, because
    # the process is easily killed with large worker counts. Log the value
    # actually used, with cpu_count alongside for reference.
    num_threads = 12
    logging.info('num_threads', num_threads, 'cpu_count',
                 multiprocessing.cpu_count())

    train_dl = DataLoader(train_ds,
                          FLAGS.batch_size,
                          shuffle=True,
                          num_workers=num_threads,
                          collate_fn=lele.DictPadCollate())
    valid_files = gezi.list_files('../input/valid/*')
    valid_ds = get_dataset(valid_files, td)
    valid_dl = DataLoader(valid_ds,
                          FLAGS.eval_batch_size,
                          collate_fn=lele.DictPadCollate(),
                          num_workers=num_threads)
    valid_dl2 = DataLoader(valid_ds,
                           FLAGS.batch_size,
                           collate_fn=lele.DictPadCollate(),
                           num_workers=num_threads)

    fit(
        model,
        loss_fn,
        dataset=train_dl,
        valid_dataset=valid_dl,
        valid_dataset2=valid_dl2,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        # Hard-coded off; FLAGS.write_valid is not consulted here.
        write_valid=False,
    )
Пример #5
0
 def get_filenames(self, subset=None):
     """Return the input files for ``subset``, defaulting to ``self.subset``.

     Args:
         subset: 'train', 'valid' or 'test'. When falsy, ``self.subset`` is
             used instead.

     Returns:
         The result of ``gezi.list_files`` on FLAGS.train_input,
         FLAGS.valid_input or FLAGS.test_input for the chosen subset. Returns
         None if the subset is not one of those names, or if listing the
         files raises.
     """
     # Fall back to the instance's subset. The fallback must rebind
     # ``subset`` itself, or the checks below would see None.
     subset = subset or self.subset
     if subset not in ('train', 'valid', 'test'):
         return None
     flag_for_subset = {
         'train': 'train_input',
         'valid': 'valid_input',
         'test': 'test_input',
     }
     try:
         return gezi.list_files(getattr(FLAGS, flag_for_subset[subset]))
     except Exception:
         # Best-effort lookup: any failure while listing files yields None.
         return None
Пример #6
0
def main(_):
    """Load behaviors/news data into module globals, then run ``build_features``.

    Steps:
      * Seed numpy and create ``FLAGS.out_dir/FLAGS.record_name`` (appended to
        FLAGS.out_dir in place) if it does not exist.
      * Read ``<in_dir>/<split>/behaviors.tsv`` into ``df``, shuffled when
        FLAGS.mark == 'train'.
      * Load the uid/did/category/entity vocabularies and the per-did start
        timestamps from ``<in_dir>/start_times.txt``.
      * Read ``<in_dir>/<split>/news.tsv`` into ``news_info`` (did -> row dict).
      * Map ``build_features`` over ``range(FLAGS.num_records)`` in a pool of
        FLAGS.num_records processes.

    ``split`` is FLAGS.mark, except that training on day 6 reads the 'dev' split.
    """
    global df, uid_vocab, did_vocab, uid_vocab2, did_vocab2
    global cat_vocab, scat_vocab, entity_vocab, entity_type_vocab
    global news_info

    np.random.seed(FLAGS.seed_)

    print('input', FLAGS.in_dir)

    FLAGS.out_dir += f'/{FLAGS.record_name}'
    if not os.path.exists(FLAGS.out_dir):
        print('make new dir: [%s]' % FLAGS.out_dir, file=sys.stderr)
        os.makedirs(FLAGS.out_dir)

    if FLAGS.train_by_day and FLAGS.shuffle_impressions:
        assert FLAGS.day is not None

    # Training on day 6 takes both behaviors and news from the dev split.
    split = 'dev' if FLAGS.mark == 'train' and FLAGS.day == 6 else FLAGS.mark

    behaviors_file = f'{FLAGS.in_dir}/{split}/behaviors.tsv'
    print('behaviors_file', behaviors_file)
    df = pd.read_csv(behaviors_file, sep='\t', names=behaviors_names)
    if FLAGS.mark == 'train':
        print('behaviors_df shuffle')
        df = df.sample(frac=1, random_state=FLAGS.seed_)

    uid_vocab = gezi.Vocab(f'{FLAGS.in_dir}/uid.txt')
    did_vocab = gezi.Vocab(f'{FLAGS.in_dir}/did.txt')
    uid_vocab2 = gezi.Vocab(f'{FLAGS.in_dir}/train/uid.txt')
    did_vocab2 = gezi.Vocab(f'{FLAGS.in_dir}/train/did.txt')
    cat_vocab = gezi.Vocab(f'{FLAGS.in_dir}/cat.txt')
    scat_vocab = gezi.Vocab(f'{FLAGS.in_dir}/sub_cat.txt')
    entity_vocab = gezi.Vocab(f'{FLAGS.in_dir}/entity.txt')
    entity_type_vocab = gezi.Vocab(f'{FLAGS.in_dir}/entity_type.txt')

    # start_timestamps is presumably a module-level dict. It is mutated, not
    # rebound, so it needs no global declaration (confirm).
    # Text files are read as UTF-8 explicitly, matching pandas' default
    # encoding for behaviors.tsv, instead of the platform locale.
    with open(f'{FLAGS.in_dir}/start_times.txt', encoding='utf-8') as f:
        for line in f:
            did, timestamp, _ = line.strip().split('\t')
            start_timestamps[did] = int(timestamp)

    news_file = f'{FLAGS.in_dir}/{split}/news.tsv'
    # First pass only counts lines so tqdm can show a total.
    with open(news_file, encoding='utf-8') as f:
        total = sum(1 for _ in f)

    news_info = {}
    with open(news_file, encoding='utf-8') as f:
        for line in tqdm(f, total=total, ascii=True, desc='news_info'):
            fields = line.strip('\n').split('\t')
            # Indexing (not zip) keeps the original IndexError on short rows.
            news_info[fields[0]] = {
                name: fields[i] for i, name in enumerate(news_names)
            }

    with Pool(FLAGS.num_records) as p:
        p.map(build_features, range(FLAGS.num_records))
Пример #7
0
def list_files(input):
    """List files matching ``input``, falling back to HDFS.

    Local matches from ``gezi.list_files`` take priority. If there are none,
    a path starting with ``START_DIR`` is passed to ``hdfs_listdir``; any
    other path yields an empty list.

    TODO: support hdfsGlob. The HDFS fallback only lists a directory and
    does not expand glob patterns.
    """
    matches = gezi.list_files(input)
    if matches:
        return matches
    # Only paths under START_DIR are handed to the HDFS directory listing.
    return hdfs_listdir(input) if input.startswith(START_DIR) else []
Пример #8
0
def main(_):
    """Iterate a record dataset once and report total vs. distinct ids.

    FLAGS.input is expanded and sorted. When FLAGS.type is 'test', 'valid'
    in the pattern is replaced by 'test'. Files ending in
    '<FLAGS.fold>.record' are dropped when FLAGS.fold is set. Every batch's
    'type' values are printed, followed by the number of ids seen and the
    number of distinct ids.
    """
    base_dir = FLAGS.base
    logging.set_logging_path('./mount/tmp/')
    # The ids2text vocab is read from two directory levels above the input pattern.
    vocab_path = os.path.join(os.path.dirname(os.path.dirname(FLAGS.input)),
                              'vocab.txt')
    ids2text.init(vocab_path)
    FLAGS.vocab = f'{base_dir}/vocab.txt'

    tf.set_random_seed(FLAGS.random_seed)

    input_ = FLAGS.input
    if FLAGS.type == 'test':
        input_ = input_.replace('valid', 'test')

    inputs = gezi.list_files(input_)
    inputs.sort()
    if FLAGS.fold is not None:
        held_out_suffix = '%d.record' % FLAGS.fold
        inputs = [x for x in inputs if not x.endswith(held_out_suffix)]

    print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)

    dataset = Dataset('train')
    # repeat=False: presumably a single pass, so the counts below see each record once.
    dataset = dataset.make_batch(FLAGS.batch_size_, inputs, repeat=False)

    print('dataset', dataset)

    ids = []

    timer = gezi.Timer('read record')
    for i, (x, y) in enumerate(dataset):
        ids.extend(x['id'].numpy())
        print(i, x['type'].numpy())

    # Total ids read vs. unique ids, to spot duplicated records.
    print(len(ids), len(set(ids)))
Пример #9
0
def get_num_records_print(files):
    """Count records across ``files`` while printing per-file progress.

    Args:
        files: a list of record file paths, or a pattern string that is
            expanded with ``gezi.list_files``.

    Returns:
        The total record count, which is also printed at the end.
    """
    if isinstance(files, str):
        files = gezi.list_files(files)
    file_total = len(files)
    record_total = 0
    for position, path in enumerate(files):
        count = get_num_records_single(path)
        # The fraction printed counts files finished before this one.
        print(path, count, '%.3f' % (position / file_total))
        record_total += count
    print('num_records:', record_total)
    return record_total
Пример #10
0
def main(_):
    """Single-process training of ``FLAGS.model`` via melt's fit loop.

    Reads ../input/train/* (shuffled, 12 workers) and ../input/valid/*. The
    two validation loaders use FLAGS.eval_batch_size and FLAGS.batch_size
    respectively. Example counts are logged for train and valid.
    """
    FLAGS.torch_only = True
    melt.init()
    fit = melt.get_fit()

    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

    model_name = FLAGS.model
    model = getattr(base, model_name)()

    loss_fn = nn.BCEWithLogitsLoss()

    td = TextDataset()
    train_files = gezi.list_files('../input/train/*')
    train_ds = get_dataset(train_files, td)

    # NOTE(review): unlike the other DataLoaders built from get_dataset(..., td)
    # in this module, no collate_fn (e.g. lele.DictPadCollate()) is passed here.
    # Confirm that torch's default collate can batch these samples.
    train_dl = DataLoader(train_ds,
                          FLAGS.batch_size,
                          shuffle=True,
                          num_workers=12)
    logging.info('num train examples', len(train_ds), len(train_dl))
    valid_files = gezi.list_files('../input/valid/*')
    valid_ds = get_dataset(valid_files, td)
    valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size)
    valid_dl2 = DataLoader(valid_ds, FLAGS.batch_size)
    logging.info('num valid examples', len(valid_ds), len(valid_dl))

    fit(
        model,
        loss_fn,
        dataset=train_dl,
        valid_dataset=valid_dl,
        valid_dataset2=valid_dl2,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        # Hard-coded off; FLAGS.write_valid is not consulted here.
        write_valid=False,
    )
Пример #11
0
def main(_):
  """Print the second batch of the record inputs (does nothing else for type 'dump').

  Logging, the ids2text vocab and the bucketing flags are always set up, and
  the inputs are listed and filtered. Batch printing only runs when
  FLAGS.type is not 'dump'. For type 'test', 'train' in FLAGS.input is
  swapped for 'test'. Files ending in '<FLAGS.fold>.record' are dropped when
  FLAGS.fold is set.
  """
  logging.set_logging_path('./mount/tmp/')
  data_root = os.path.dirname(os.path.dirname(FLAGS.input))
  ids2text.init(os.path.join(data_root, 'vocab.txt'))
  FLAGS.vocab = './mount/temp/kaggle/toxic/tfrecords/glove/vocab.txt'

  # Hard-coded overrides of the length-bucketing flags.
  FLAGS.length_index = 2
  FLAGS.buckets = '100,400'
  FLAGS.batch_sizes = '64,64,32'

  if FLAGS.type == 'test':
    pattern = FLAGS.input.replace('train', 'test')
  else:
    pattern = FLAGS.input

  inputs = gezi.list_files(pattern)
  inputs.sort()
  if FLAGS.fold is not None:
    held_out = '%d.record' % FLAGS.fold
    inputs = [path for path in inputs if not path.endswith(held_out)]

  if FLAGS.type == 'dump':
    return

  print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)

  dataset = Dataset('valid').make_batch(FLAGS.batch_size_, inputs)

  print('dataset', dataset)

  timer = gezi.Timer('read record')
  for step, (x, y) in enumerate(dataset):
    # Only the batch at step 1 is shown; the loop stops right after it.
    if step % 10 != 1:
      continue
    print(y[0])
    print(x['comment'][0])
    print(ids2text.ids2text(x['comment'][0], sep='|'))
    print(x['comment_str'][0])
    break
Пример #12
0
def main(_):
  """Count records under FLAGS.input, optionally writing the total to disk.

  When FLAGS.threads == 1, files are counted sequentially with
  melt.get_num_records_print and the elapsed time is printed. Otherwise
  ``deal_file`` is mapped over the expanded file list in a pool of
  FLAGS.threads processes, and the total is read from ``counter.value``.
  When FLAGS.write_count is set, the total is written to num_records.txt in
  the directory part of FLAGS.input.
  """
  timer = gezi.Timer()
  input_pattern = FLAGS.input

  if FLAGS.threads == 1:
    num_records = melt.get_num_records_print(input_pattern)
    print(timer.elapsed())
  else:
    files = gezi.list_files(input_pattern)
    print(files)
    # The context manager tears the pool down even if deal_file raises.
    with multiprocessing.Pool(processes=FLAGS.threads) as pool:
      pool.map(deal_file, files)
      pool.close()
      pool.join()

    # counter is presumably a shared multiprocessing value that deal_file
    # increments (confirm).
    num_records = counter.value
    print('num_records:', num_records)

  if FLAGS.write_count:
    outdir = os.path.dirname(input_pattern)
    output = '%s/num_records.txt' % outdir
    print('write to %s' % output)
    # Close the file deterministically so the count is flushed to disk.
    with open(output, 'w') as out:
      out.write(str(num_records))
Пример #13
0
def main(_):
  """Horovod-distributed training of ``FLAGS.model`` for two epochs.

  ../input/train/* is sharded with a shuffling DistributedSampler and
  ../input/valid/* with a non-shuffling one, both over hvd.size() replicas.
  The initial model and optimizer state are broadcast from rank 0. Adamax is
  wrapped in hvd.DistributedOptimizer, then each epoch runs train() followed
  by test().
  """
  FLAGS.torch_only = True

  melt.init()

  # Fixed evaluation batch size; FLAGS.valid_multiplier is not applied here.
  FLAGS.eval_batch_size = 512

  model_name = FLAGS.model
  model = getattr(base, model_name)()
  model = model.cuda()

  loss_fn = nn.BCEWithLogitsLoss()

  td = text_dataset.Dataset()

  train_files = gezi.list_files('../input/train/*')
  train_ds = get_dataset(train_files, td)

  # Loader tuning notes:
  # - num_workers=1 can take ~2 minutes to start, but 0 is much slower than 1
  #   (startup alone takes over a minute).
  # - pin_memory has little effect: a tiny speedup on one GPU, and on multiple
  #   GPUs pin_memory=False was actually faster.
  # - On multiple GPUs, num_workers mainly drives resource usage; too many
  #   workers may keep the job from starting at all.
  num_workers = 1
  kwargs = {'num_workers': num_workers, 'pin_memory': False,
            'collate_fn': lele.DictPadCollate()}

  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_ds, num_replicas=hvd.size(), rank=hvd.rank())
  train_dl = DataLoader(train_ds, FLAGS.batch_size, sampler=train_sampler, **kwargs)

  valid_files = gezi.list_files('../input/valid/*')
  valid_ds = get_dataset(valid_files, td)

  # DistributedSampler supports shuffle=False from torch 1.2 on.
  valid_sampler = torch.utils.data.distributed.DistributedSampler(
      valid_ds, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)
  valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size, sampler=valid_sampler, **kwargs)

  optimizer = optim.Adamax(model.parameters(), lr=0.1)
  # Start every rank from rank 0's weights and optimizer state.
  hvd.broadcast_parameters(model.state_dict(), root_rank=0)
  hvd.broadcast_optimizer_state(optimizer, root_rank=0)

  optimizer = hvd.DistributedOptimizer(optimizer,
                                       named_parameters=model.named_parameters())

  for epoch in range(2):
    # A shuffling DistributedSampler reuses the same permutation every epoch
    # unless set_epoch() is called before iterating.
    train_sampler.set_epoch(epoch)
    train(epoch, model, loss_fn, train_dl, optimizer)
    test(model, loss_fn, valid_dl)
Пример #14
0
def inputs(files,
           decode,
           batch_size=64,
           num_epochs=None,
           num_threads=12,
           shuffle=True,
           batch_join=True,
           shuffle_batch=True,
           min_after_dequeue=None,
           seed=None,
           fix_random=False,
           no_random=False,
           fix_sequence=False,
           allow_smaller_final_batch=False,
           num_prefetch_batches=None,
           name='input'):
    """Build a TF1 queue-based input pipeline that reads ``files`` num_epochs times.

    For each record the pipeline will:
      1. read and ``decode`` the serialized example,
      2. shuffle the decoded values (unless randomness is disabled),
      3. return them batched.

    Args:
      files: list of record file paths, or a pattern string expanded with
        ``gezi.list_files``.
      decode: user-defined function applied to each serialized example
        (see the example below).
      batch_size: examples per batch.
      num_epochs: passes over ``files``. A falsy value is normalized to None.
      num_threads: number of parallel ``read_decode`` readers when
        ``batch_join`` is True, otherwise ``num_threads`` for the batcher.
      shuffle: shuffle the filename queue.
      batch_join: use ``tf.train.shuffle_batch_join`` over ``num_threads``
        readers.
      shuffle_batch: when ``batch_join`` is False, use
        ``tf.train.shuffle_batch`` (True) or ``tf.train.batch`` (False).
      min_after_dequeue: shuffle buffer floor. Defaults to
        ``melt.tfrecords.read.MIN_AFTER_QUEUE``.
      seed: seed for the filename queue and batch shuffling.
      fix_random: reproducible shuffling. Seed defaults to 1024; forces
        shuffle=True, batch_join=False, shuffle_batch=True, num_threads=1.
      no_random: no shuffling at all, single thread, no batch_join.
      fix_sequence: implies ``no_random`` and ``allow_smaller_final_batch``.
      allow_smaller_final_batch: allow a final batch smaller than
        ``batch_size``.
      num_prefetch_batches: queue capacity beyond ``min_after_dequeue``,
        measured in batches. Defaults to ``num_threads + 3``.
      name: name scope for the created ops.

    Example ``decode``::

      features = tf.parse_single_example(
          serialized_example,
          features={
              'feature': tf.FixedLenFeature([], tf.string),
              'name': tf.FixedLenFeature([], tf.string),
              'comment_str': tf.FixedLenFeature([], tf.string),
              'comment': tf.FixedLenFeature([], tf.string),
              'num_words': tf.FixedLenFeature([], tf.int64),
          })
      feature = tf.decode_raw(features['feature'], tf.float32)
      feature.set_shape([IMAGE_FEATURE_LEN])
      comment = tf.decode_raw(features['comment'], tf.int64)
      comment.set_shape([COMMENT_MAX_WORDS])
      name = features['name']
      comment_str = features['comment_str']
      num_words = features['num_words']
      return name, feature, comment_str, comment, num_words

    Returns:
      list of tensors
    """
    if isinstance(files, str):
        files = gezi.list_files(files)

    if not min_after_dequeue:
        min_after_dequeue = melt.tfrecords.read.MIN_AFTER_QUEUE
    if not num_epochs:
        num_epochs = None

    if fix_random:
        if seed is None:
            seed = 1024
        shuffle = True
        # Whether batch_join could stay True here is unchecked.
        batch_join = False
        # A fixed order is obtained with either shuffle_batch=True and a
        # single thread, or shuffle_batch=False with any thread count.
        # shuffle_batch=True was also observed to be quicker, and a single
        # thread quicker than 12.
        shuffle_batch = True
        num_threads = 1

    if fix_sequence:
        no_random = True
        allow_smaller_final_batch = True

    if no_random:
        shuffle = False
        batch_join = False
        shuffle_batch = False
        num_threads = 1

    with tf.name_scope(name):
        filename_queue = tf.train.string_input_producer(files,
                                                        num_epochs=num_epochs,
                                                        shuffle=shuffle,
                                                        seed=seed)

        # min_after_dequeue sets how big a buffer we randomly sample from:
        # bigger means better shuffling, but slower start-up and more memory.
        # capacity must exceed min_after_dequeue; the excess bounds prefetch.
        # Recommendation: min_after_dequeue + (num_threads + margin) * batch_size.
        # TODO: cifar10 always uses 3 * batch_size of prefetch; compare.
        if not num_prefetch_batches:
            num_prefetch_batches = num_threads + 3
        capacity = min_after_dequeue + num_prefetch_batches * batch_size

        if batch_join:
            # `range` rather than the Python 2-only `xrange`.
            batch_list = [
                read_decode(filename_queue, decode)
                for _ in range(num_threads)
            ]
            batch = tf.train.shuffle_batch_join(
                batch_list,
                batch_size=batch_size,
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                seed=seed,
                allow_smaller_final_batch=allow_smaller_final_batch)
        else:
            # num_threads is already 1 here whenever fix_random was set.
            serialized_example = read_decode(filename_queue, decode)
            if shuffle_batch:
                batch = tf.train.shuffle_batch(
                    serialized_example,
                    batch_size=batch_size,
                    num_threads=num_threads,
                    capacity=capacity,
                    min_after_dequeue=min_after_dequeue,
                    seed=seed,
                    allow_smaller_final_batch=allow_smaller_final_batch)
            else:
                batch = tf.train.batch(
                    serialized_example,
                    batch_size=batch_size,
                    num_threads=num_threads,
                    capacity=capacity,
                    allow_smaller_final_batch=allow_smaller_final_batch)

        return batch
Пример #15
0
def get_num_records(files):
    """Return the total number of records across ``files``.

    Args:
        files: a list of record file paths, or a pattern string that is
            expanded with ``gezi.list_files``.
    """
    if isinstance(files, str):
        files = gezi.list_files(files)
    # A generator expression avoids building a throwaway list of counts.
    return sum(get_num_records_single(file) for file in files)
Пример #16
0
def train(Dataset,
          model,
          loss_fn,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          dataset=None,
          valid_dataset=None,
          test_dataset=None,
          sep=','):
    if Dataset is None:
        assert dataset
    if FLAGS.torch:
        # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model.to(device)

    input_ = FLAGS.train_input
    inputs = gezi.list_files(input_)
    inputs.sort()

    all_inputs = inputs

    #batch_size = FLAGS.batch_size
    batch_size = melt.batch_size()

    num_gpus = melt.num_gpus()

    #batch_size = max(batch_size, 1)
    #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
    batch_size_ = batch_size

    if FLAGS.fold is not None:
        inputs = [
            x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)
            and not x.endswith('%d.tfrecord' % FLAGS.fold)
        ]
        # if FLAGS.valid_input:
        #   inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)]
    logging.info('inputs', len(inputs), inputs[:100])
    num_folds = FLAGS.num_folds or len(inputs) + 1

    train_dataset_ = dataset or Dataset('train')
    train_dataset = train_dataset_.make_batch(batch_size, inputs)
    num_examples = train_dataset_.num_examples_per_epoch('train')
    num_all_examples = num_examples

    valid_inputs = None
    if FLAGS.valid_input:
        valid_inputs = gezi.list_files(FLAGS.valid_input)
    else:
        if FLAGS.fold is not None:
            #valid_inputs = [x for x in all_inputs if x not in inputs]
            if not FLAGS.test_aug:
                valid_inputs = [
                    x for x in all_inputs if not 'aug' in x and x not in inputs
                ]
            else:
                valid_inputs = [
                    x for x in all_inputs if 'aug' in x and x not in inputs
                ]

    logging.info('valid_inputs', valid_inputs)

    if valid_inputs:
        valid_dataset_ = valid_dataset or Dataset('valid')
        valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
        valid_dataset2 = valid_dataset_.make_batch(batch_size_,
                                                   valid_inputs,
                                                   repeat=True)
    else:
        valid_datsset = None
        valid_dataset2 = None

    if num_examples:
        if FLAGS.fold is not None:
            num_examples = int(num_examples * (num_folds - 1) / num_folds)
        num_steps_per_epoch = -(-num_examples // batch_size)
    else:
        num_steps_per_epoch = None
    logging.info('num_train_examples:', num_examples)

    num_valid_examples = None
    if FLAGS.valid_input:
        num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
        num_valid_steps_per_epoch = -(
            -num_valid_examples // batch_size_) if num_valid_examples else None
    else:
        if FLAGS.fold is not None:
            if num_examples:
                num_valid_examples = int(num_all_examples * (1 / num_folds))
                num_valid_steps_per_epoch = -(-num_valid_examples //
                                              batch_size_)
            else:
                num_valid_steps_per_epoch = None
    logging.info('num_valid_examples:', num_valid_examples)

    if FLAGS.test_input:
        test_inputs = gezi.list_files(FLAGS.test_input)
        #test_inputs = [x for x in test_inputs if not 'aug' in x]
        logging.info('test_inputs', test_inputs)
    else:
        test_inputs = None

    num_test_examples = None
    if test_inputs:
        test_dataset_ = test_dataset or Dataset('test')
        test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
        num_test_examples = test_dataset_.num_examples_per_epoch('test')
        num_test_steps_per_epoch = -(
            -num_test_examples // batch_size_) if num_test_examples else None
    else:
        test_dataset = None
    logging.info('num_test_examples:', num_test_examples)

    summary = tf.contrib.summary
    # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch')
    # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train')
    # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid')
    writer = summary.create_file_writer(FLAGS.log_dir)
    writer_train = summary.create_file_writer(FLAGS.log_dir)
    writer_valid = summary.create_file_writer(FLAGS.log_dir)
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")

    tf.add_to_collection('learning_rate', learning_rate)

    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    try:
        learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
    except Exception:
        learning_rate_weights = None

    # ckpt dir save models one per epoch
    ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
    os.system('mkdir -p %s' % ckpt_dir)
    # HACK ckpt dir is actually save mini epoch like when you set save_interval_epochs=0.1, this is usefull when you training large dataset
    ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
    os.system('mkdir -p %s' % ckpt_dir2)

    #TODO FIXME now I just changed tf code so to not by default save only latest 5
    # refer to https://github.com/tensorflow/tensorflow/issues/22036
    # manager = tf.contrib.checkpoint.CheckpointManager(
    #     checkpoint, directory=ckpt_dir, max_to_keep=5)
    # latest_checkpoint = manager.latest_checkpoint

    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if latest_checkpoint:
        logging.info('Latest checkpoint:', latest_checkpoint)
    else:
        latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
        logging.info('Latest checkpoint:', latest_checkpoint)

    if os.path.exists(FLAGS.model_dir + '.index'):
        latest_checkpoint = FLAGS.model_dir

    if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
        #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
        latest_checkpoint = FLAGS.model_dir
        #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint)

    checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
    checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

    if not FLAGS.torch:
        try:
            optimizer = optimizer or melt.get_optimizer(
                FLAGS.optimizer)(learning_rate)
        except Exception:
            logging.warning(
                f'Fail to using {FLAGS.optimizer} use adam instead')
            optimizer = melt.get_optimizer('adam')(learning_rate)

        # TODO...
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                model=model,
                optimizer=optimizer,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                model=model,
                optimizer=optimizer,
                global_step=global_step)

        checkpoint.restore(latest_checkpoint)
        checkpoint2 = copy.deepcopy(checkpoint)

        start_epoch = int(
            latest_checkpoint.split('-')
            [-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
    else:
        # TODO torch with learning rate adjust
        if optimizer is None:
            import lele
            is_dynamic_opt = True
            if FLAGS.optimizer == 'noam':
                optimizer = lele.training.optimizers.NoamOpt(
                    128, 2, 4000, torch.optim.Adamax(model.parameters(), lr=0))
            elif FLAGS.optimizer == 'bert':
                num_train_steps = int(
                    num_steps_per_epoch *
                    (FLAGS.num_decay_epochs or FLAGS.num_epochs))
                num_warmup_steps = FLAGS.warmup_steps or int(
                    num_train_steps * FLAGS.warmup_proportion)
                logging.info('num_train_steps', num_train_steps,
                             'num_warmup_steps', num_warmup_steps,
                             'warmup_proportion', FLAGS.warmup_proportion)
                optimizer = lele.training.optimizers.BertOpt(
                    FLAGS.learning_rate, FLAGS.min_learning_rate,
                    num_train_steps, num_warmup_steps,
                    torch.optim.Adamax(model.parameters(), lr=0))
            else:
                is_dynamic_opt = False
                optimizer = torch.optim.Adamax(
                    param_groups if param_groups else model.parameters(),
                    lr=FLAGS.learning_rate)

        start_epoch = 0
        latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(
            FLAGS.model_dir, 'latest.pyt')
        if not os.path.exists(latest_path):
            latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
        if os.path.exists(latest_path):
            logging.info('loading torch model from', latest_path)
            checkpoint = torch.load(latest_path)
            if not FLAGS.torch_finetune:
                start_epoch = checkpoint['epoch']
                step = checkpoint['step']
                global_step.assign(step + 1)
            load_torch_model(model, latest_path)
            if FLAGS.torch_load_optimizer:
                optimizer.load_state_dict(checkpoint['optimizer'])

        # TODO by this way restart can not change learning rate..
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                global_step=global_step)

        try:
            checkpoint.restore(latest_checkpoint)
            checkpoint2 = copy.deepcopy(checkpoint)
        except Exception:
            pass

    if FLAGS.torch and is_dynamic_opt:
        optimizer._step = global_step.numpy()

    #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
    #model.save('./weight3.hd5')
    logging.info('optimizer:', optimizer)

    if FLAGS.torch_lr:
        learning_rate.assign(optimizer.rate(1))
    if FLAGS.torch:
        learning_rate.assign(optimizer.param_groups[0]['lr'])
        logging.info('learning rate got from pytorch latest.py as',
                     learning_rate)

    learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
    if learning_rate_weights is not None:
        learning_rate_weights.assign(learning_rate_weights *
                                     FLAGS.learning_rate_start_factor)

    # TODO currently not support 0.1 epoch.. like this
    num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

    will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ
    if start_epoch == 0 and not 'EVFIRST' in os.environ and will_valid:
        will_valid = False

    if start_epoch > 0 and will_valid:
        will_valid = True

    if will_valid:
        logging.info('----------valid')
        if FLAGS.torch:
            model.eval()
        names = None
        if evaluate_fn is not None:
            vals, names = evaluate_fn(model, valid_dataset,
                                      tf.train.latest_checkpoint(ckpt_dir),
                                      num_valid_steps_per_epoch)
        elif eval_fn:
            model_path = None if not write_valid else latest_checkpoint
            names = valid_names if valid_names is not None else [
                infer_names[0]
            ] + [x + '_y' for x in infer_names[1:]
                 ] + infer_names[1:] if infer_names else None

            logging.info('model_path:', model_path, 'model_dir:',
                         FLAGS.model_dir)
            vals, names = evaluate(model,
                                   valid_dataset,
                                   eval_fn,
                                   model_path,
                                   names,
                                   valid_write_fn,
                                   write_streaming,
                                   num_valid_steps_per_epoch,
                                   suffix=valid_suffix,
                                   sep=sep)
        if names:
            logging.info2(
                'epoch:%d/%d' % (start_epoch, num_epochs),
                ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

    if FLAGS.work_mode == 'valid':
        exit(0)

    if 'test' in FLAGS.work_mode:
        logging.info('--------test/inference')
        if test_dataset:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint
                # logging.info('model_path', model_path)
                assert latest_checkpoint
                inference(model,
                          test_dataset,
                          latest_checkpoint,
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          suffix=infer_suffix)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)
        exit(0)

    if 'SHOW' in os.environ:
        num_epochs = start_epoch + 1

    class PytObj(object):
        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        def __init__(self):
            self._val = 0.
            self.count = 0

            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            if not self.count:
                val = 0
            else:
                val = self._val / self.count
            # TODO just for compact with tf ..
            return PytObj(val)

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
    timer = gezi.Timer()
    num_insts = 0

    if FLAGS.learning_rate_decay_factor > 0:
        #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?'
        #NOTICE if you do finetune or other things which might change batch_size then you'd better direclty set num_steps_per_decay
        #since global step / decay_steps will not be correct epoch as num_steps per epoch changed
        #so if if you change batch set you have to reset global step as fixed step
        assert FLAGS.num_steps_per_decay or (
            FLAGS.num_epochs_per_decay and num_steps_per_epoch
        ), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch'
        decay_steps = FLAGS.num_steps_per_decay or int(
            num_steps_per_epoch * FLAGS.num_epochs_per_decay)
        decay_start_step = FLAGS.decay_start_step or int(
            num_steps_per_epoch * FLAGS.decay_start_epoch)
        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        logging.info(
            'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
            .format(FLAGS.learning_rate_decay_factor,
                    FLAGS.num_epochs_per_decay, decay_steps,
                    FLAGS.decay_start_epoch, decay_start_step))

    for epoch in range(start_epoch, num_epochs):
        melt.set_global('epoch', '%.4f' % (epoch))

        if FLAGS.torch:
            model.train()

        epoch_loss_avg = Mean()
        epoch_valid_loss_avg = Mean()

        #for i, (x, y) in tqdm(enumerate(train_dataset), total=num_steps_per_epoch, ascii=True):
        for i, (x, y) in enumerate(train_dataset):
            if FLAGS.torch:
                x, y = to_torch(x, y)
                if is_dynamic_opt:
                    learning_rate.assign(optimizer.rate())

            #print(x, y)

            if not FLAGS.torch:
                loss, grads = melt.eager.grad(model, x, y, loss_fn)
                grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
                optimizer.apply_gradients(zip(grads, model.variables))
            else:
                optimizer.zero_grad()
                if 'training' in inspect.getargspec(loss_fn).args:
                    loss = loss_fn(model, x, y, training=True)
                else:
                    loss = loss_fn(model, x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               FLAGS.clip_gradients)
                optimizer.step()

            global_step.assign_add(1)

            epoch_loss_avg(loss)  # add current batch loss

            if FLAGS.torch:
                del loss

            batch_size_ = list(
                x.values())[0].shape[FLAGS.batch_size_dim] if type(x) == type(
                    {}) else x.shape[FLAGS.batch_size_dim]
            num_insts += int(batch_size_)
            if global_step.numpy() % FLAGS.interval_steps == 0:
                #checkpoint.save(checkpoint_prefix)
                elapsed = timer.elapsed()
                steps_per_second = FLAGS.interval_steps / elapsed
                instances_per_second = num_insts / elapsed
                num_insts = 0

                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)

                if valid_dataset2:
                    try:
                        x, y = next(iter(valid_dataset2))
                    except Exception:
                        # TODO FIXME how.. iterate stop restart.., here hack for my iterator see projects/lm/dataset
                        x, y = next(iter(valid_dataset2))

                    if FLAGS.torch:
                        x, y = to_torch(x, y)
                        model.eval()
                    valid_loss = loss_fn(model, x, y)
                    epoch_valid_loss_avg(valid_loss)
                    if FLAGS.torch:
                        model.train()

                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' %
                        elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' %
                        num_gpus, 'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second, '%s' %
                        epoch_time_info, 'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
                        'valid_loss:[%.4f]' %
                        epoch_valid_loss_avg.result().numpy())
                    if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                        with writer_valid.as_default(
                        ), summary.always_record_summaries():
                            #summary.scalar('step/loss', epoch_valid_loss_avg.result().numpy())
                            summary.scalar(
                                'loss/eval',
                                epoch_valid_loss_avg.result().numpy())
                            writer_valid.flush()
                else:
                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' %
                        elapsed, 'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus,
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy())

                if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                    with writer_train.as_default(
                    ), summary.always_record_summaries():
                        #summary.scalar('step/loss', epoch_loss_avg.result().numpy())
                        summary.scalar('loss/train_avg',
                                       epoch_loss_avg.result().numpy())
                        summary.scalar('learning_rate', learning_rate.numpy())
                        summary.scalar('batch_size', batch_size_)
                        summary.scalar('epoch', melt.epoch())
                        summary.scalar('steps_per_second', steps_per_second)
                        summary.scalar('instances_per_second',
                                       instances_per_second)
                        writer_train.flush()

                    if FLAGS.log_dir != FLAGS.model_dir:
                        assert FLAGS.log_dir
                        command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir,
                                                              FLAGS.model_dir)
                        print(command, file=sys.stderr)
                        os.system(command)

            if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy(
            ) and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
                if FLAGS.torch:
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset, None,
                                              num_valid_steps_per_epoch)
                elif eval_fn:
                    names = valid_names if valid_names is not None else [
                        infer_names[0]
                    ] + [x + '_y' for x in infer_names[1:]
                         ] + infer_names[1:] if infer_names else None
                    vals, names = evaluate(model,
                                           valid_dataset,
                                           eval_fn,
                                           None,
                                           names,
                                           valid_write_fn,
                                           write_streaming,
                                           num_valid_steps_per_epoch,
                                           sep=sep)
                if vals and names:
                    with writer_valid.as_default(
                    ), summary.always_record_summaries():
                        for name, val in zip(names, vals):
                            summary.scalar(f'step/valid/{name}', val)
                        writer_valid.flush()

                if FLAGS.torch:
                    if not FLAGS.torch_lr:
                        # control learning rate by tensorflow learning rate
                        for param_group in optimizer.param_groups:
                            # important learning rate decay
                            param_group['lr'] = learning_rate.numpy()

                    model.train()

                if names and vals:
                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'valid_step:%d' % global_step.numpy(), 'valid_metrics',
                        [
                            '%s:%.5f' % (name, val)
                            for name, val in zip(names, vals)
                        ])

            # if i == 5:
            #   print(i, '---------------------save')
            #   print(len(model.trainable_variables))
            ## TODO FIXME seems save weighs value not ok... not the same as checkpoint save
            #   model.save_weights(os.path.join(ckpt_dir, 'weights'))
            #   checkpoint.save(checkpoint_prefix)
            #   exit(0)

            if global_step.numpy() % FLAGS.save_interval_steps == 0:
                if FLAGS.torch:
                    state = {
                        'epoch':
                        epoch,
                        'step':
                        global_step.numpy(),
                        'state_dict':
                        model.state_dict() if not hasattr(model, 'module') else
                        model.module.state_dict(),
                        'optimizer':
                        optimizer.state_dict(),
                    }
                    torch.save(state,
                               os.path.join(FLAGS.model_dir, 'latest.pyt'))

            # TODO fixme why if both checpoint2 and chekpoint used... not ok..
            if FLAGS.save_interval_epochs and FLAGS.save_interval_epochs < 1 and global_step.numpy(
            ) % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                #if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                checkpoint2.save(checkpoint_prefix2)
                if FLAGS.torch:
                    state = {
                        'epoch':
                        epoch,
                        'step':
                        global_step.numpy(),
                        'state_dict':
                        model.state_dict() if not hasattr(model, 'module') else
                        model.module.state_dict(),
                        'optimizer':
                        optimizer.state_dict(),
                    }
                    torch.save(state,
                               tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

            if FLAGS.learning_rate_decay_factor > 0:
                if global_step.numpy(
                ) >= decay_start_step and global_step.numpy(
                ) % decay_steps == 0:
                    lr = max(
                        learning_rate.numpy() *
                        FLAGS.learning_rate_decay_factor,
                        FLAGS.min_learning_rate)
                    if lr < learning_rate.numpy():
                        learning_rate.assign(lr)
                        if FLAGS.torch:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate.numpy()

            if epoch == start_epoch and i == 0:
                try:
                    if not FLAGS.torch:
                        logging.info(model.summary())
                except Exception:
                    traceback.print_exc()
                    logging.info(
                        'Fail to do model.summary() may be you have layer define in init but not used in call'
                    )
                if 'SHOW' in os.environ:
                    exit(0)

        logging.info2(
            'epoch:%d/%d' % (epoch + 1, num_epochs),
            'step:%d' % global_step.numpy(), 'batch_size:[%d]' % batch_size,
            'gpus:[%d]' % num_gpus, 'lr:[%.8f]' % learning_rate.numpy(),
            'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
            'valid_loss::[%.4f]' % epoch_valid_loss_avg.result().numpy())

        timer = gezi.Timer(
            f'save model to {checkpoint_prefix}-{checkpoint.save_counter.numpy() + 1}',
            False)
        checkpoint.save(checkpoint_prefix)
        if FLAGS.torch and FLAGS.save_interval_epochs == 1:
            state = {
                'epoch':
                epoch + 1,
                'step':
                global_step.numpy(),
                'state_dict':
                model.state_dict()
                if not hasattr(model, 'module') else model.module.state_dict(),
                'optimizer':
                optimizer.state_dict(),
            }
            torch.save(state, tf.train.latest_checkpoint(ckpt_dir) + '.pyt')

        timer.print_elapsed()

        if valid_dataset and (epoch + 1) % FLAGS.valid_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()

            vals, names = None, None
            if evaluate_fn is not None:
                vals, names = evaluate_fn(model, valid_dataset,
                                          tf.train.latest_checkpoint(ckpt_dir),
                                          num_valid_steps_per_epoch)
            elif eval_fn:
                model_path = None if not write_valid else tf.train.latest_checkpoint(
                    ckpt_dir)
                names = valid_names if valid_names is not None else [
                    infer_names[0]
                ] + [x + '_y' for x in infer_names[1:]
                     ] + infer_names[1:] if infer_names else None

                vals, names = evaluate(model,
                                       valid_dataset,
                                       eval_fn,
                                       model_path,
                                       names,
                                       valid_write_fn,
                                       write_streaming,
                                       num_valid_steps_per_epoch,
                                       suffix=valid_suffix,
                                       sep=sep)

            if vals and names:
                logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                              'step:%d' % global_step.numpy(),
                              'epoch_valid_metrics', [
                                  '%s:%.5f' % (name, val)
                                  for name, val in zip(names, vals)
                              ])

        with writer.as_default(), summary.always_record_summaries():
            temp = global_step.value()
            global_step.assign(epoch + 1)
            summary.scalar('epoch/train/loss', epoch_loss_avg.result().numpy())
            if valid_dataset:
                if FLAGS.torch:
                    model.eval()
                if vals and names:
                    for name, val in zip(names, vals):
                        summary.scalar(f'epoch/valid/{name}', val)
            writer.flush()
            global_step.assign(temp)

        if test_dataset and (epoch + 1) % FLAGS.inference_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                inference(model,
                          test_dataset,
                          tf.train.latest_checkpoint(ckpt_dir),
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          suffix=infer_suffix,
                          sep=sep)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)

    if FLAGS.log_dir != FLAGS.model_dir:
        assert FLAGS.log_dir
        command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
        print(command, file=sys.stderr)
        os.system(command)
        command = 'rm -rf %s/latest.pyt.*' % (FLAGS.model_dir)
        print(command, file=sys.stderr)
        os.system(command)
Пример #17
0
def main(_):
  """Dump per-example scalar features of a record set to a CSV file.

  Command-line arguments are read from ``sys.argv`` (the positional
  parameter is unused):
    argv[1]: input location, expanded with ``gezi.list_files``.
    argv[2]: output CSV path.

  Each feature whose per-batch array has shape ``(bs,)`` or ``(bs, 1)``
  becomes one CSV column; features with any other shape (or empty) are
  dropped, and ``object``-dtype arrays are converted with ``gezi.decode``.

  Exits early:
    * with status 1 when ``melt.get_num_records`` reports no records;
    * with status 0 when ``ofile`` already has ``total`` rows and either
      ``FLAGS.title`` is falsy or the CSV has a ``title`` column;
    * with status 0 after the first batch when ``ofile`` already has
      ``total`` rows and exactly the column set that would be written.
  """
  in_dir = sys.argv[1]
  files = gezi.list_files(in_dir)
  total = melt.get_num_records(files)
  print('total', total, file=sys.stderr)

  if not total:
    sys.exit(1)

  ofile = sys.argv[2]
  df = None
  if ofile and gezi.non_empty(os.path.realpath(ofile)):
    try:
      df = pd.read_csv(ofile)
      if len(df) == total and (not FLAGS.title or 'title' in df.columns):
        print(f'infos file {ofile} exits do nothing', file=sys.stderr)
        # SystemExit is a BaseException, so the handler below does not catch it.
        sys.exit(0)
      else:
        print('num_done:', len(df), file=sys.stderr)
    except Exception as e:
      # Best effort: an existing file that cannot be read/checked is simply
      # regenerated, but say why instead of failing silently.
      print(f'ignore existing {ofile}: {e!r}', file=sys.stderr)

  print('write to', ofile, file=sys.stderr)

  FLAGS.batch_size = FLAGS.batch_size_
  batch_size = FLAGS.batch_size

  if tf.__version__ < '2':
    tf.compat.v1.enable_eager_execution()

  dataset = Dataset('valid')
  print('---batch_size', dataset.batch_size, FLAGS.batch_size, melt.batch_size(), file=sys.stderr)

  batches = dataset.make_batch(batch_size=batch_size, filenames=files, repeat=False)

  # Ceiling division: number of batches needed to cover `total` records.
  num_steps = -int(-total // batch_size)
  print('----num_steps', num_steps, file=sys.stderr)
  m = defaultdict(list)
  for i, (x, _) in tqdm(enumerate(batches), total=num_steps, ascii=True, desc='loop'):
    bs = len(x['id'])
    # Iterate over a snapshot of the keys because unusable ones are deleted.
    for key in list(x.keys()):
      x[key] = x[key].numpy()
      if not len(x[key]):
        del x[key]
        continue
      if x[key].shape == (bs, 1):
        # NOTE(review): if gezi.squeeze behaves like np.squeeze, a batch with
        # bs == 1 becomes shape () here and the feature is dropped for that
        # batch only, leaving columns of unequal length — confirm.
        x[key] = gezi.squeeze(x[key])
      if x[key].shape != (bs,):
        del x[key]
        continue
      # `np.object` was removed in NumPy 1.24; compare with the builtin type.
      if x[key].dtype == object:
        x[key] = gezi.decode(x[key])
      m[key].append(x[key])
    if i == 0:
      if df is not None and len(df) == total and set(m.keys()) == set(df.columns):
        print(f'infos file {ofile} exits do nothing', file=sys.stderr)
        sys.exit(0)

  columns = {key: np.concatenate(parts, 0) for key, parts in m.items()}
  df = pd.DataFrame(columns)
  df.to_csv(ofile, index=False)
Пример #18
0
def inputs(
        files,
        decode_fn,
        batch_size=64,
        num_epochs=None,
        num_threads=None,
        buffer_size=15000,  #change from 1000 to 15000
        dynamic_pad=True,
        shuffle_files=True,
        batch_join=True,
        shuffle_batch=True,
        min_after_dequeue=None,
        seed=None,
        enqueue_many=False,
        fix_random=False,
        no_random=False,
        fix_sequence=False,
        allow_smaller_final_batch=True,
        num_prefetch_batches=None,
        bucket_boundaries=None,
        length_index=None,
        length_key=None,
        length_fn=None,
        bucket_batch_sizes=None,
        repeat=True,
        initializable=False,
        filter_fn=None,
        balance_pos_neg=False,
        pos_filter_fn=None,
        neg_filter_fn=None,
        count_fn=None,
        return_iterator=False,
        Dataset=None,
        batch_parse=False,  #by default will be line parse
        hvd_shard=True,
        shard_by_files=False,
        training=True,
        simple_parse=False,
        repeat_then_shuffle=False,
        name='input'):
    """Reads input data num_epochs times.
  for sparse input here will do:
  1. read serialized_example
  2. shuffle serialized_examples
  3. decdoe batch_serialized_examples
  notice read_sparse.inputs and also be used for dense inputs,but if you 
  only need to decode part from serialized_example, then read.inputs will 
  be better, less to put to suffle
  #--------decode example, can refer to libsvm-decode.py
  # def decode(batch_serialized_examples):
  #   features = tf.parse_example(
  #       batch_serialized_examples,
  #       features={
  #           'label' : tf.FixedLenFeature([], tf.int64),
  #           'index' : tf.VarLenFeature(tf.int64),
  #           'value' : tf.VarLenFeature(tf.float32),
  #       })

  #   label = features['label']
  #   index = features['index']
  #   value = features['value']

  #   return label, index, value 

  #string_input_reducer will shuffle files
  #shuffle will read file by file and shuffle withn file(in shuffle queue) 
  #shuffle_batch_join will read multiple files and shuffle in shuffle queue(from many files)

  To get fixed sequence 
  shuffle=False  so by this way the sequence is as your data input unchange
  or
  shuffle=True
  seed=1024 #set
  batch_join=False  by this way you have fixed random, so get same result
  NOTICE, shuffle=True,seed=1024,batch_join=True will not get same result
  shuffle=False,seed=1024,batch_join=True also, so batch_join seems seed only control inqueue random, can not get fixed result

  for no random -> fixed result set shuffle=False wihch will force batch_join=False then use batch
  for fixed random ->  shuffle=True, seed set or  fix_random=True
  read-records.py show above ok, but train-evaluate.py show not, only shuffle=False can get fixed result.. @FIXME strange
  for train-evaluate.py it looks you can set shuffle in string_input_producer True, but then must use batch,
  batch_join and shuffle_batch join all not fixed even with seed set, may be due to trainset two inputs read ?
  for read-records.py batch_join will be fixed, shuffle_batch_join not 

  defualt parmas will give max random...

  Args:
  decode: user defined decode 
  min_after_dequeue: set to >2w for production train, suggesed will be 0.4 * num_instances, but also NOTICE do not exceed mem
  #--default parmas will make most randomness
  shuffle_files: wehter shuffle file 
  shuffle_batch: batch or shuffle_batch
  batch_join: wether to use multiple reader or use one reader mutlitple thread
  fix_random: if True make at most random which can fix random result
  allow_smaller_final_batch: set True usefull if you want verify on small d

  great article http://d0evi1.com/tensorflow/ds_performance/
  https://www.tensorflow.org/versions/master/performance/ds_performance
  """
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ
    if use_horovod:
        if FLAGS.torch:
            import horovod.torch as hvd
        else:
            import horovod.tensorflow as hvd

    def shard(d):
        return d.shard(hvd.size(), hvd.rank())

    # Choose to use cpu outside input function like in d.py
    #with tf.device('/cpu:0'):
    if isinstance(files, str):
        files = gezi.list_files(files)
    assert len(files) > 0

    if use_horovod and not hvd_shard and training:
        assert len(files) % hvd.size() == 0, '{} {} {}'.format(
            len(files), files, hvd.size())
        files_ = []
        for i in range(len(files)):
            if i % hvd.size() == hvd.rank():
                files_.append(files[i])
        files = files_
        print('----------train-files', files)
        #exit(0)

    if not num_threads:
        try:
            import multiprocessing
            num_threads = multiprocessing.cpu_count()
            logging.info('num_threads as multiprocessing.cpu_count',
                         num_threads)
        except Exception:
            num_threads = 12
            logging.info('num_threads set by default', num_threads)

    if 'batch_size' in inspect.getargspec(decode_fn).args:
        decode_fn_ = decode_fn

        def decode_function(example):
            return decode_fn_(example, batch_size)

        decode_fn = decode_function

    if simple_parse and training:
        # for multiple gpu horovod run seem this much better, might due to repeat then shuffle better TODO
        d = Dataset(files)
        if use_horovod and hvd_shard:
            d = shard(d)
        d = d.repeat(num_epochs).shuffle(
            batch_size * 1024).batch(batch_size).map(
                decode_fn, num_parallel_calls=num_threads).prefetch(9)
        return d.make_one_shot_iterator()

    if not min_after_dequeue:
        min_after_dequeue = melt.tfrecords.read.MIN_AFTER_QUEUE

    if not num_epochs:
        num_epochs = None

    if fix_random:
        if seed is None:
            seed = 1024
        shuffle_files = True
        batch_join = False  #check can be True ?

        shuffle_batch = True
        num_threads = 1

    if fix_sequence:
        no_random = True
        allow_smaller_final_batch = True
        num_threads = 1

    if no_random:
        shuffle_files = False
        batch_join = False
        shuffle_batch = False

    drop_remainder = False if allow_smaller_final_batch else True
    #drop_remainder = True

    if not num_prefetch_batches:
        num_prefetch_batches = num_threads + 3

    if buffer_size is None:
        # ... Too small ? but 1024 will cause starup slow
        buffer_size = min_after_dequeue + num_prefetch_batches * batch_size
        #buffer_size = 1024 * batch_size

    with tf.name_scope(name):
        # https://github.com/tensorflow/tensorflow/issues/14857
        Dataset = Dataset or tf.data.TFRecordDataset
        if not shuffle_files or len(files) == 1:
            d = Dataset(files)
            if use_horovod and hvd_shard:
                d = shard(d)
        else:
            d = tf.data.Dataset.list_files(files)
            # here shard by files, not work good, especially for text line dataset with hrovod
            if use_horovod and shard_by_files:
                d = shard(d)
            if shuffle_files:
                d = d.shuffle(len(files), seed=seed)
                d = d.interleave(Dataset,
                                 cycle_length=num_threads,
                                 block_length=1)
            if use_horovod and not shard_by_files:
                d = shard(d)

            #repeat_then_shuffle = True
            ## below on tf doc, but shard by files will cause problem, especially horovod mutlitple gpu, still not fully understand
            # Be sure to shard before you use any randomizing operator (such as shuffle).
            # Generally it is best if the shard operator is used early in the d pipeline. For example, when reading from a set of TFRecord files, shard before converting the d to input samples. This avoids reading every file on every worker. The following is an example of an efficient sharding strategy within a complete pipeline:
            # d = Dataset.list_files(pattern)
            # d = d.shard(num_workers, worker_index)
            # d = d.repeat(num_epochs)
            # d = d.shuffle(shuffle_buffer_size)
            # d = d.interleave(tf.data.TFRecordDataset,
            #                  cycle_length=num_readers, block_length=1)
            # d = d.map(parser_fn, num_parallel_calls=num_map_threads)

            # # # TODO still need shuffle here ?
            # # #https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
            # d = d.shuffle(num_files)\
            #                   .apply(tf.contrib.data.parallel_interleave(
            #                         Dataset,
            #                         cycle_length=num_threads))

        if repeat and repeat_then_shuffle:
            d = d.repeat(num_epochs)

        # must batch then map if use pyfunc which you might use batch_parse, here batch_parse means batch parse otherwise slower but simple and powerfull...
        if not batch_parse:
            d = d.map(decode_fn, num_parallel_calls=num_threads)

        try:
            #shapes = d._output_shapes
            shapes = tf.data.get_output_shapes(d)
        except Exception:
            shapes = None

        #logging.info('datast decode shapes', shapes)

        ## Has bug.. seems as least not work with bucket not sure without bucket ok or not
        if balance_pos_neg:
            # https://stackoverflow.com/questions/46938530/produce-balanced-mini-batch-with-d-api/49283371#49283371
            ds_pos = d.filter(pos_filter_fn).repeat()
            ds_neg = d.filter(neg_filter_fn)

            # def _concat(x, y):
            #   return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
            # d = tf.data.Dataset.zip((ds_pos, ds_neg))
            # d = d.map(_concat)

            d = tf.data.Dataset.zip((ds_pos, ds_neg))
            # Each input element will be converted into a two-element `Dataset` using
            # `Dataset.from_tensors()` and `Dataset.concatenate()`, then `Dataset.flat_map()`
            # will flatten the resulting `Dataset`s into a single `Dataset`.
            d = d.flat_map(lambda ex_pos, ex_neg: tf.data.Dataset.from_tensors(
                ex_pos).concatenate(tf.data.Dataset.from_tensors(ex_neg)))

        #https://github.com/tensorflow/tensorflow/issues/14451
        # count_fn for over sample
        if count_fn is not None:
            d = d.flat_map(lambda x, y: tf.data.Dataset.from_tensors(
                (x, y)).repeat(tf.to_int64(count_fn(x, y))))

        # filter fn for under sample
        # if under_sample_filter_fn is not None:
        #   d = d.filter(under_sample_filter_fn)

        if filter_fn is not None:
            d = d.filter(filter_fn)

        if shuffle_batch:
            logging.info('shuffle with buffer_size', buffer_size, 'seed', seed)
            d = d.shuffle(buffer_size=buffer_size, seed=seed)

        # shuffle then repeat
        if repeat and not repeat_then_shuffle:
            d = d.repeat(num_epochs)

        # d = d.interleave(Dataset,
        #                       cycle_length=num_threads, block_length=16)

        # https://stackoverflow.com/questions/46444018/meaning-of-buffer-size-in-d-map-d-prefetch-and-d-shuffle
        #d = d.prefetch(buffer_size)
        d = d.prefetch(num_prefetch_batches * batch_size)
        #d = d.prefetch(num_prefetch_batches)

        # #https://github.com/HKUST-KnowComp/R-Net/blob/master/util.py
        # #https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py
        if bucket_boundaries:
            # TODO remove support for length index, use use length key!
            assert length_key is not None or length_index is not None, 'forget to set length key  or length index ?'
            if not isinstance(bucket_boundaries, (list, tuple)):
                boundaries = [
                    int(x) for x in bucket_boundaries.split(',') if x.strip()
                ]
            else:
                boundaries = bucket_boundaries
            logging.info('bucket_boundaries', boundaries)
            with tf.name_scope("bucket_by_seq_length"):

                def example_to_bucket_id(*args, **kw):
                    """Return int64 id of the length bucket for this example."""
                    #assert length_index is not None
                    if length_key is None:
                        try:
                            x = list(args[0])[length_index]
                        except Exception:
                            x = args[length_index]
                    else:
                        try:
                            x = args[0][length_key]
                        except Exception:
                            x = args[length_key]

                    seq_length = tf.reduce_sum(
                        tf.cast(tf.cast(x, tf.bool), tf.int32))

                    buckets_min = [np.iinfo(np.int32).min] + boundaries
                    buckets_max = boundaries + [np.iinfo(np.int32).max]
                    conditions_c = tf.logical_and(
                        tf.less_equal(buckets_min, seq_length),
                        tf.less(seq_length, buckets_max))
                    bucket_id = tf.reduce_min(tf.where(conditions_c))
                    return bucket_id

                if not bucket_batch_sizes:

                    def batching_fn(bucket_id, grouped_d):
                        return grouped_d.padded_batch(batch_size,
                                                      padded_shapes=(shapes))

                    ## TODO larger window better hsku squad doing this like below, shuffle can be better ?
                    ## NOTICE!! shuffle may be slow start fill queue can remove not hurt performance ?
                    d = d.apply(
                        tf.contrib.data.group_by_window(
                            example_to_bucket_id,
                            batching_fn,
                            window_size=5 * batch_size)).shuffle(
                                (len(boundaries) + 1) * 25)

                    ## tenor2tensor doing this, no shuffle ? also it seems like window_func for different bounds
                    ## with different batch_size ?
                    # d = d.apply(
                    #   tf.contrib.data.group_by_window(example_to_bucket_id, batching_fn, batch_size)).shuffle((len(boundaries) + 1) * 25)
                else:
                    # TEST OK
                    # test ok ie buckets[400] batch_sizes[64, 32]
                    if not isinstance(bucket_batch_sizes, (list, tuple)):
                        bucket_batch_sizes = [
                            int(x) for x in bucket_batch_sizes.split(',')
                            if x.strip()
                        ]

                    logging.info('bucket_batche_sizes', bucket_batch_sizes)
                    assert len(boundaries) + 1 == len(bucket_batch_sizes)

                    def window_size_fn(bucket_id):
                        # window size = batch size
                        batch_sizes = tf.constant(bucket_batch_sizes,
                                                  dtype=tf.int64)
                        window_size = batch_sizes[bucket_id]
                        # * 5 will make reading slower
                        window_size *= 5
                        return window_size

                    def batching_fn(bucket_id, grouped_d):
                        batch_sizes = tf.constant(bucket_batch_sizes,
                                                  dtype=tf.int64)
                        batch_size = batch_sizes[bucket_id]
                        #return padded_batch(grouped_d, batch_size, padded_shapes=None)
                        return grouped_d.padded_batch(batch_size,
                                                      padded_shapes=(shapes))

                    # shuffle will make start slower might fill
                    d = d.apply(
                        tf.contrib.data.group_by_window(
                            example_to_bucket_id, batching_fn, None,
                            window_size_fn)).shuffle(
                                (len(boundaries) + 1) * 25)
        else:
            # no bucket
            if dynamic_pad and not batch_parse:
                d = d.padded_batch(batch_size,
                                   padded_shapes=(shapes),
                                   drop_remainder=drop_remainder)
            else:
                d = d.batch(batch_size, drop_remainder=drop_remainder)
                if batch_parse:
                    d = d.map(decode_fn, num_parallel_calls=num_threads)

        # if not allow_smaller_final_batch:
        #   # https://github.com/tensorflow/tensorflow/issues/13745 d.apply(tf.contrib.data.batch_and_drop_remainder(10)).
        #   d = d.filter(lambda x, *args, **kw: tf.equal(tf.shape(x)[0], batch_size))

    # TODO save iterator ?
    ## Create saveable object from iterator.
    #saveable = tf.contrib.data.make_saveable_from_iterator(iterator)

    # Save the iterator state by adding it to the saveable objects collection.
    #tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    #try:
        if tf.executing_eagerly():
            # TODO store iterator for eager
            return d
        else:
            if repeat and not initializable:
                iterator = d.make_one_shot_iterator()
                saveable = tf.contrib.data.make_saveable_from_iterator(
                    iterator)
                tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
                if return_iterator:
                    return iterator
                ops = iterator.get_next()
                return ops
            else:
                iterator = d.make_initializable_iterator()
                saveable = tf.contrib.data.make_saveable_from_iterator(
                    iterator)
                tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
                return iterator
Пример #19
0
def inputs(files,
           decode_fn,
           batch_size=64,
           num_epochs=None,
           num_threads=12,
           shuffle_files=True,
           batch_join=True,
           shuffle_batch=True,
           min_after_dequeue=None,
           seed=None,
           enqueue_many=False,
           fix_random=False,
           no_random=False,
           fix_sequence=False,
           allow_smaller_final_batch=False,
           num_prefetch_batches=None,
           dynamic_pad=False,
           bucket_boundaries=None,
           length_index=None,
           length_fn=None,
           name='input'):
    """Read serialized examples num_epochs times, batch them, then decode.

    TF1 queue-runner pipeline for sparse input:
      1. read serialized examples from ``files``;
      2. batch (and optionally shuffle) the *serialized* examples;
      3. decode the whole batch at once with ``decode_fn``.

    This also works for dense inputs, but if only part of each example needs
    to be decoded, decoding per example before batching is cheaper because
    less data has to sit in the shuffle queue.

    Example ``decode_fn`` (compare libsvm-decode.py)::

        def decode(batch_serialized_examples):
            features = tf.parse_example(
                batch_serialized_examples,
                features={
                    'label': tf.FixedLenFeature([], tf.int64),
                    'index': tf.VarLenFeature(tf.int64),
                    'value': tf.VarLenFeature(tf.float32),
                })
            return features['label'], features['index'], features['value']

    Notes on randomness (empirical observations of the original author):
      * ``string_input_producer`` shuffles the file order when
        ``shuffle_files`` is True.
      * ``shuffle_batch`` reads one file at a time and shuffles within its
        queue; ``shuffle_batch_join`` reads several files and shuffles
        across them.
      * Fixed, unshuffled order: ``no_random=True`` forces
        ``shuffle_files``, ``batch_join`` and ``shuffle_batch`` to False, so
        plain ``tf.train.batch`` is used.
      * Reproducible randomness: ``fix_random=True``, or
        ``shuffle_files=True`` with ``seed`` set and ``batch_join=False``.
        With ``batch_join=True`` results were not reproducible even with a
        seed (the seed seems to control only the enqueue side).
      * The above held in read-records.py, but in train-evaluate.py only
        ``shuffle_files=False`` gave fixed results (FIXME, unexplained;
        possibly because that script builds two inputs).
      * The defaults give maximum randomness.

    Args:
      files: list of file paths, or a pattern string expanded with
        ``gezi.list_files``. Must resolve to at least one file.
      decode_fn: callable applied to the batched serialized examples. If
        None, the batched serialized examples are returned unchanged.
      batch_size: examples per batch.
      num_epochs: passes over ``files``; any falsy value means unlimited.
      num_threads: number of readers when ``batch_join`` is True, otherwise
        the number of enqueue threads of the batch op.
      shuffle_files: whether ``string_input_producer`` shuffles file order.
      batch_join: use one reader per thread (``*_batch_join``) instead of a
        single reader shared by several threads.
      shuffle_batch: use the shuffling batch ops instead of plain batching.
      min_after_dequeue: shuffle-queue floor; falsy values fall back to
        ``melt.tfrecords.read.MIN_AFTER_QUEUE``. For production training the
        author suggests > 20000 (about 0.4 * num_instances), memory
        permitting.
      seed: random seed for file and batch shuffling.
      enqueue_many: forwarded to the TF batch op.
      fix_random: force reproducible shuffling (seed defaults to 1024,
        single reader, single thread, shuffle_batch on).
      no_random: disable all shuffling.
      fix_sequence: ``no_random`` plus a single thread and
        ``allow_smaller_final_batch=True``.
      allow_smaller_final_batch: let the last batch be smaller; useful when
        verifying on a small dataset.
      num_prefetch_batches: extra queue capacity in batches; defaults to
        ``num_threads + 3``.
      dynamic_pad: pad variable-length tensors per batch; forces the
        non-shuffling batch ops, which are the ones accepting this option.
      bucket_boundaries, length_index, length_fn: accepted but not used by
        this implementation.
      name: name scope for the created ops.

    Returns:
      ``decode_fn(batch_serialized_examples)``, or the batched serialized
      examples when ``decode_fn`` is None.
    """
    if isinstance(files, str):
        files = gezi.list_files(files)

    assert len(files) > 0, 'no input files found'

    if not min_after_dequeue:
        min_after_dequeue = melt.tfrecords.read.MIN_AFTER_QUEUE
    if not num_epochs:
        num_epochs = None

    if fix_random:
        if seed is None:
            seed = 1024
        shuffle_files = True
        batch_join = False  # TODO: check whether batch_join=True can also be fixed
        # Observed to give fixed results: shuffle_batch=True with
        # num_threads=1, or shuffle_batch=False with any num_threads.
        # shuffle_batch=True was found fast even single-threaded (oddly,
        # faster than 12 threads in models/image-text-sim/read_records).
        shuffle_batch = True
        num_threads = 1

    if fix_sequence:
        no_random = True
        allow_smaller_final_batch = True
        num_threads = 1

    if no_random:
        shuffle_files = False
        batch_join = False
        shuffle_batch = False

    if dynamic_pad:
        # Only tf.train.batch / batch_join accept dynamic_pad.
        shuffle_batch = False

    with tf.name_scope(name):
        filename_queue = tf.train.string_input_producer(files,
                                                        num_epochs=num_epochs,
                                                        shuffle=shuffle_files,
                                                        seed=seed)

        # min_after_dequeue is how big a buffer we randomly sample from:
        # bigger means better shuffling but slower start-up and more memory.
        # capacity must be larger than min_after_dequeue; the excess bounds
        # how much is prefetched. Recommended:
        #   min_after_dequeue + (num_threads + small safety margin) * batch_size
        # TODO: cifar10 always uses 3 * batch_size; check which is better.
        if not num_prefetch_batches:
            num_prefetch_batches = num_threads + 3
        capacity = min_after_dequeue + num_prefetch_batches * batch_size
        if batch_join:
            # One independent reader per thread. `range` (not `xrange`) so
            # this works on Python 3 as well as Python 2.
            batch_list = [_read(filename_queue) for _ in range(num_threads)]
            if shuffle_batch:
                batch_serialized_examples = tf.train.shuffle_batch_join(
                    batch_list,
                    batch_size=batch_size,
                    capacity=capacity,
                    min_after_dequeue=min_after_dequeue,
                    seed=seed,
                    enqueue_many=enqueue_many,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    name='shuffle_batch_join_queue')
            else:
                batch_serialized_examples = tf.train.batch_join(
                    batch_list,
                    batch_size=batch_size,
                    capacity=capacity,
                    enqueue_many=enqueue_many,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    dynamic_pad=dynamic_pad,
                    name='batch_join_queue')
        else:
            # Single reader; presumably _read returns the tensor(s) for one
            # serialized example (converted to a list) -- TODO confirm.
            serialized_example = list(_read(filename_queue))
            # FIXME: randomness is limited on this path; for fixed-random
            # behaviour see image-text-sim/train-evaluate-fixrandom.py.
            if shuffle_batch:
                batch_serialized_examples = tf.train.shuffle_batch(
                    serialized_example,
                    batch_size=batch_size,
                    num_threads=num_threads,
                    capacity=capacity,
                    min_after_dequeue=min_after_dequeue,
                    seed=seed,
                    enqueue_many=enqueue_many,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    name='shuffle_batch_queue')
            else:
                batch_serialized_examples = tf.train.batch(
                    serialized_example,
                    batch_size=batch_size,
                    # TODO: for truly fixed order use num_threads=1.
                    num_threads=num_threads,
                    capacity=capacity,
                    enqueue_many=enqueue_many,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    dynamic_pad=dynamic_pad,
                    name='batch_queue')

        return decode_fn(
            batch_serialized_examples
        ) if decode_fn is not None else batch_serialized_examples
Пример #20
0
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ==============================================================================
#          \file   read-torch-dataset.py
#        \author   chenghuige
#          \date   2019-08-03 14:08:33.314862
#   \Description
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os

from torch.utils.data import DataLoader
import gezi

from pyt.dataset import *
from text_dataset import Dataset as TD

# Smoke test: build the torch dataset over the validation files and print the
# first four batches (5 examples each) to eyeball what the loader yields.
files = gezi.list_files('../input/valid/*')
td = TD()
ds = get_dataset(files, td)
dl = DataLoader(ds, batch_size=5)
for i, d in enumerate(dl):
    print(i, d)
    # Batch indices 0..3 have been shown; stop pulling from the loader.
    if i >= 3:
        break
def inputs(files,
           decode_fn,
           batch_size=64,
           num_epochs=None,
           num_threads=12,
           shuffle_files=True,
           batch_join=True,
           shuffle_batch=True,
           min_after_dequeue=None,
           seed=None,
           enqueue_many=False,
           fix_random=False,
           no_random=False,
           fix_sequence=False,
           allow_smaller_final_batch=False,
           num_prefetch_batches=None,
           dynamic_pad=False,
           bucket_boundaries=None,
           length_index=None,
           length_fn=None,
           keep_fn=None,
           name='input'):
    """Read and decode examples num_epochs times, then batch them.

    TF1 queue-runner pipeline:
      1. read each serialized example and decode it with ``decode_fn``;
      2. (optionally) shuffle the decoded values;
      3. return batches of decoded values.

    Example ``decode_fn``::

        def decode(serialized_example):
            features = tf.parse_single_example(
                serialized_example,
                features={
                    'feature': tf.FixedLenFeature([], tf.string),
                    'name': tf.FixedLenFeature([], tf.string),
                    'comment_str': tf.FixedLenFeature([], tf.string),
                    'comment': tf.FixedLenFeature([], tf.string),
                    'num_words': tf.FixedLenFeature([], tf.int64),
                })
            feature = tf.decode_raw(features['feature'], tf.float32)
            feature.set_shape([IMAGE_FEATURE_LEN])
            comment = tf.decode_raw(features['comment'], tf.int64)
            comment.set_shape([COMMENT_MAX_WORDS])
            return (features['name'], feature, features['comment_str'],
                    comment, features['num_words'])

    Args:
      files: list of file paths, or a pattern string expanded with
        ``gezi.list_files``. Must resolve to at least one file.
      decode_fn: per-example decoder, passed to ``_read_decode``.
      batch_size: examples per batch.
      num_epochs: passes over ``files``; any falsy value means unlimited.
      num_threads: number of readers when ``batch_join`` is True, otherwise
        the number of enqueue threads of the batch op.
      shuffle_files: whether ``string_input_producer`` shuffles file order.
      batch_join: use one reader per thread (``*_batch_join``). Note that
        bucketing is only applied when this is False.
      shuffle_batch: use the shuffling batch ops instead of plain batching.
      min_after_dequeue: shuffle-queue floor; falsy values fall back to
        ``melt.tfrecords.read.MIN_AFTER_QUEUE``.
      seed: random seed for file and batch shuffling.
      enqueue_many: forwarded to the non-bucketing TF batch ops.
      fix_random: force reproducible shuffling (seed defaults to 1024,
        single reader, single thread, shuffle_batch on).
      no_random: disable all shuffling and use a single thread.
      fix_sequence: ``no_random`` plus ``allow_smaller_final_batch=True``.
      allow_smaller_final_batch: let the last batch be smaller.
      num_prefetch_batches: extra queue capacity in batches; defaults to
        ``num_threads + 3``.
      dynamic_pad: pad variable-length tensors per batch; forces the
        non-shuffling batch ops, which are the ones accepting this option.
      bucket_boundaries: optional list/tuple of ints, or a comma-separated
        string of ints (blank entries ignored). When set and ``batch_join``
        is False, examples are batched with
        ``tf.contrib.training.bucket_by_sequence_length`` (always dynamically
        padded; ``shuffle_batch``/``seed``/``min_after_dequeue`` are not used
        on that path).
      length_index: index into the decoded example list holding the sequence
        length used for bucketing.
      length_fn: used when ``length_index`` is None; called with the decoded
        example list and must return its length.
      keep_fn: optional predicate on the decoded example list deciding
        whether to keep it when bucketing; by default examples with length
        >= 1 are kept.
      name: name scope for the created ops.

    Returns:
      list of batched tensors
    """
    if isinstance(files, str):
        files = gezi.list_files(files)

    assert len(files) > 0, 'no input files found'

    if not min_after_dequeue:
        min_after_dequeue = melt.tfrecords.read.MIN_AFTER_QUEUE
    if not num_epochs:
        num_epochs = None

    if fix_random:
        if seed is None:
            seed = 1024
        shuffle_files = True
        batch_join = False  # TODO: check whether batch_join=True can also be fixed
        # Observed to give fixed results: shuffle_batch=True with
        # num_threads=1, or shuffle_batch=False with any num_threads.
        # shuffle_batch=True was found fast even single-threaded (oddly,
        # faster than 12 threads in models/image-text-sim/read_records).
        shuffle_batch = True
        num_threads = 1

    if fix_sequence:
        no_random = True
        allow_smaller_final_batch = True

    if no_random:
        shuffle_files = False
        batch_join = False
        shuffle_batch = False
        num_threads = 1

    if dynamic_pad:
        # Only tf.train.batch / batch_join accept dynamic_pad.
        shuffle_batch = False

    with tf.name_scope(name):
        filename_queue = tf.train.string_input_producer(files,
                                                        num_epochs=num_epochs,
                                                        shuffle=shuffle_files,
                                                        seed=seed)

        # min_after_dequeue is how big a buffer we randomly sample from:
        # bigger means better shuffling but slower start-up and more memory.
        # capacity must be larger than min_after_dequeue; the excess bounds
        # how much is prefetched. Recommended:
        #   min_after_dequeue + (num_threads + small safety margin) * batch_size
        # TODO: cifar10 always uses 3 * batch_size; check which is better.
        if not num_prefetch_batches:
            num_prefetch_batches = num_threads + 3

        capacity = min_after_dequeue + num_prefetch_batches * batch_size

        if batch_join:
            # One independent read+decode pipeline per thread. `range` (not
            # `xrange`) so this works on Python 3 as well as Python 2.
            batch_list = [
                _read_decode(filename_queue, decode_fn, thread_id)
                for thread_id in range(num_threads)
            ]
            if shuffle_batch:
                batch = tf.train.shuffle_batch_join(
                    batch_list,
                    batch_size=batch_size,
                    capacity=capacity,
                    min_after_dequeue=min_after_dequeue,
                    seed=seed,
                    enqueue_many=enqueue_many,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    name='shuffle_batch_join_queue')
            else:
                batch = tf.train.batch_join(
                    batch_list,
                    batch_size=batch_size,
                    capacity=capacity,
                    enqueue_many=enqueue_many,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    dynamic_pad=dynamic_pad,
                    name='batch_join_queue')
        else:
            # Single reader; presumably _read_decode returns the decoded
            # tensors of one example (converted to a list) -- TODO confirm.
            decoded_example = list(_read_decode(filename_queue, decode_fn))
            # fix_random already forced num_threads = 1 above, so no extra
            # adjustment is needed here.
            if bucket_boundaries:
                if not isinstance(bucket_boundaries, (list, tuple)):
                    # Skip blank entries (e.g. a trailing ", ") instead of
                    # failing on int(' ').
                    bucket_boundaries = [
                        int(x) for x in bucket_boundaries.split(',')
                        if x.strip()
                    ]
                if length_index is not None:
                    input_length = decoded_example[length_index]
                else:
                    assert length_fn is not None, 'you must set length_index or pass length_fn'
                    input_length = length_fn(decoded_example)
                # By default drop empty sequences.
                keep_input = input_length >= 1 if keep_fn is None else keep_fn(
                    decoded_example)
                _, batch = tf.contrib.training.bucket_by_sequence_length(
                    input_length=tf.to_int32(input_length),
                    bucket_boundaries=bucket_boundaries,
                    tensors=decoded_example,
                    batch_size=batch_size,
                    keep_input=keep_input,
                    num_threads=num_threads,
                    dynamic_pad=True,
                    capacity=capacity,
                    allow_smaller_final_batch=allow_smaller_final_batch,
                    name="bucket_queue")
            else:
                if shuffle_batch:
                    batch = tf.train.shuffle_batch(
                        decoded_example,
                        batch_size=batch_size,
                        num_threads=num_threads,
                        capacity=capacity,
                        min_after_dequeue=min_after_dequeue,
                        seed=seed,
                        enqueue_many=enqueue_many,
                        allow_smaller_final_batch=allow_smaller_final_batch,
                        name='shuffle_batch_queue')
                else:
                    # http://honggang.io/2016/08/19/tensorflow-data-reading/
                    batch = tf.train.batch(
                        decoded_example,
                        batch_size=batch_size,
                        num_threads=num_threads,
                        capacity=capacity,
                        enqueue_many=enqueue_many,
                        allow_smaller_final_batch=allow_smaller_final_batch,
                        dynamic_pad=dynamic_pad,
                        name='batch_queue')

        return batch
Пример #22
0
def train(model, 
          loss_fn,
          Dataset=None,  
          dataset=None,
          valid_dataset=None,
          valid_dataset2=None,
          test_dataset=None,
          evaluate_fn=None, 
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          sep=','):
  """Train `model` eagerly (TensorFlow eager or PyTorch) with periodic eval/inference.

  Either `Dataset` or a ready-made `dataset` must be supplied.

  Args:
    model: the model to train; a torch module when FLAGS.torch, otherwise a
      TF model whose `call` may take a `training` argument.
    loss_fn: called as loss_fn(y, y_) on the TF path and loss_fn(y_, y) on
      the torch path (y_ = model output).
    Dataset: class instantiated as Dataset('train'); the instance must provide
      make_batch(...) and num_examples_per_epoch(subset).
    dataset: pre-built training iterable of (x, y) batches. Only supported with
      FLAGS.torch_only; it must expose `.dataset` supporting len()
      (presumably a torch DataLoader).
    valid_dataset: iterable used for full metric evaluation; when passed in it
      must expose `.dataset` supporting len().
    valid_dataset2: iterable from which single batches are drawn for the
      periodic valid-loss log line; its iterator is restarted when exhausted.
    test_dataset: iterable used for inference; when passed in it must expose
      `.dataset` supporting len().
    evaluate_fn: custom evaluation callable; takes precedence over eval_fn.
    inference_fn: custom inference callable replacing the default `inference`.
    eval_fn: metric function forwarded to `evaluate`.
    write_valid: when False, `evaluate` receives model_path=None.
    valid_names, infer_names, infer_debug_names: column names for written results.
    valid_write_fn, infer_write_fn: custom writers forwarded to evaluate/inference.
    valid_suffix, infer_suffix: result file suffixes forwarded to evaluate/inference.
    write_streaming: forwarded to evaluate/inference.
    optimizer: optional pre-built optimizer; built from FLAGS.optimizer when None.
    param_groups: torch parameter groups for the default Adamax optimizer.
    init_fn: accepted for interface compatibility; not used here.
    sep: separator forwarded to evaluate/inference.

  Side effects: creates `<model_dir>/ckpt` and `<model_dir>/ckpt2`, writes
  summaries to FLAGS.log_dir, saves checkpoints, and calls exit(0) in
  valid/test modes, under SHOW, or after the final epoch.
  """
  use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

  if Dataset is None:
    assert dataset
  logging.info('Dataset', Dataset, 'dataset', dataset, 'valid_dataset', valid_dataset, 'test_dataset', test_dataset, loss_fn)

  if FLAGS.torch:
    torch.manual_seed(FLAGS.seed or 0)
    if torch.cuda.device_count():
      torch.cuda.manual_seed(FLAGS.seed or 0)
    if use_horovod:
      # Horovod init for torch is disabled here. NOTE(review): `hvd` is used
      # below but not imported in this function; presumably a module global.
      pass
    elif torch.cuda.device_count() > 1:
      # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
      model = torch.nn.DataParallel(model)
    model.to(device)

  inputs = gezi.list_files(FLAGS.train_input)
  inputs.sort()
  all_inputs = inputs

  batch_size = melt.batch_size()
  num_gpus = melt.num_gpus()
  # Batch size for valid/test data; falls back to the training batch size.
  eval_batch_size = FLAGS.eval_batch_size or batch_size

  if dataset is None:
    if FLAGS.fold is not None:
      # Files of the current fold are excluded from training (they may be used
      # for validation below).
      inputs = [x for x in inputs
                if not x.endswith('%d.record' % FLAGS.fold) and not x.endswith('%d.tfrecord' % FLAGS.fold)]
    logging.info('inputs', len(inputs), inputs[:100])
    dataset = Dataset('train')
    assert len(inputs) > 0
    train_dataset = dataset.make_batch(batch_size, inputs, simple_parse=FLAGS.simple_parse)
    num_examples = dataset.num_examples_per_epoch('train')
  else:
    assert FLAGS.torch_only, 'only torch only currently support input dataset not Dataset class type, because we do not have len function there'
    train_dataset = dataset
    num_examples = len(train_dataset.dataset)
  num_folds = FLAGS.num_folds or len(inputs) + 1
  num_all_examples = num_examples

  # ----- validation data
  valid_inputs = None
  if valid_dataset is None:
    if FLAGS.valid_input:
      valid_inputs = gezi.list_files(FLAGS.valid_input)
    elif FLAGS.fold is not None:
      train_files = set(inputs)
      if not FLAGS.test_aug:
        valid_inputs = [x for x in all_inputs if 'aug' not in x and x not in train_files]
      else:
        valid_inputs = [x for x in all_inputs if 'aug' in x and x not in train_files]
    logging.info('valid_inputs', valid_inputs)

  num_valid_examples = None
  num_valid_steps_per_epoch = None
  valid_dataset2_iter = None
  if valid_dataset is not None:
    num_valid_examples = len(valid_dataset.dataset)
    num_valid_steps_per_epoch = -(-num_valid_examples // eval_batch_size) if num_valid_examples else None
    # valid_dataset2 is optional even when valid_dataset is given.
    if valid_dataset2 is not None:
      valid_dataset2_iter = iter(valid_dataset2)
  elif valid_inputs:
    valid_dataset = dataset.make_batch(eval_batch_size, valid_inputs, subset='valid', hvd_shard=FLAGS.horovod_eval)
    valid_dataset2 = dataset.make_batch(batch_size, valid_inputs, subset='valid', repeat=True, initializable=False, hvd_shard=False)
    valid_dataset2_iter = iter(valid_dataset2)
  else:
    valid_dataset = None
    valid_dataset2 = None

  # ----- steps per epoch (ceil division)
  if num_examples:
    if FLAGS.fold is not None:
      num_examples = int(num_examples * (num_folds - 1) / num_folds)
    num_steps_per_epoch = -(-num_examples // batch_size)
  else:
    num_steps_per_epoch = None
  logging.info('num_train_examples:', num_examples)
  if use_horovod and num_examples:
    num_steps_per_epoch = -(-num_examples // (batch_size * hvd.size()))

  if num_valid_examples is None:
    if FLAGS.valid_input:
      num_valid_examples = dataset.num_examples_per_epoch('valid')
      num_valid_steps_per_epoch = -(-num_valid_examples // eval_batch_size) if num_valid_examples else None
    elif FLAGS.fold is not None and num_examples:
      num_valid_examples = int(num_all_examples * (1 / num_folds))
      num_valid_steps_per_epoch = -(-num_valid_examples // eval_batch_size)
  if use_horovod and FLAGS.horovod_eval and num_valid_examples:
    num_valid_steps_per_epoch = -(-num_valid_examples // (eval_batch_size * hvd.size()))
  logging.info('num_valid_examples:', num_valid_examples)

  # ----- test data
  test_inputs = None
  if test_dataset is None and FLAGS.test_input:
    test_inputs = gezi.list_files(FLAGS.test_input)
    logging.info('test_inputs', test_inputs)

  num_test_examples = None
  if test_dataset is not None:
    num_test_examples = len(test_dataset.dataset)
  elif test_inputs:
    test_dataset = dataset.make_batch(eval_batch_size, test_inputs, subset='test')
    num_test_examples = dataset.num_examples_per_epoch('test')
  num_test_steps_per_epoch = -(-num_test_examples // eval_batch_size) if num_test_examples else None
  if use_horovod and FLAGS.horovod_eval and num_test_examples:
    num_test_steps_per_epoch = -(-num_test_examples // (eval_batch_size * hvd.size()))
  logging.info('num_test_examples:', num_test_examples)

  # tf.summary.FileWriter is not compatible with eager execution, hence tf.contrib.summary.
  # All three writers share FLAGS.log_dir.
  summary = tf.contrib.summary
  writer = summary.create_file_writer(FLAGS.log_dir)
  writer_train = summary.create_file_writer(FLAGS.log_dir)
  writer_valid = summary.create_file_writer(FLAGS.log_dir)
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)

  # The 'learning_rate_weight' collection must already be populated by the caller.
  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except IndexError:
    learning_rate_weights = None

  # ckpt: checkpoints per epoch; ckpt2: mini-epoch checkpoints (save_interval_epochs < 1).
  ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
  os.makedirs(ckpt_dir, exist_ok=True)
  ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
  os.makedirs(ckpt_dir2, exist_ok=True)

  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir) or tf.train.latest_checkpoint(ckpt_dir2)
  logging.info('Latest checkpoint:', latest_checkpoint)

  # model_dir may itself be a checkpoint prefix.
  if os.path.exists(FLAGS.model_dir + '.index'):
    latest_checkpoint = FLAGS.model_dir

  if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
    latest_checkpoint = FLAGS.model_dir

  checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

  def _make_tf_checkpoint(**extra):
    """Build a tf.train.Checkpoint over the learning-rate variables, global step and `extra`."""
    tracked = dict(learning_rate=learning_rate,
                   learning_rate_weight=learning_rate_weight,
                   global_step=global_step,
                   **extra)
    if learning_rate_weights is not None:
      tracked['learning_rate_weights'] = learning_rate_weights
    return tf.train.Checkpoint(**tracked)

  # Only the noam/bert torch wrappers are "dynamic" (they expose rate() and _step).
  is_dynamic_opt = False
  if not FLAGS.torch:
    try:
      optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    except Exception:
      logging.warning(f'Fail to using {FLAGS.optimizer} use adam instead')
      optimizer = melt.get_optimizer('adam')(learning_rate)

    checkpoint = _make_tf_checkpoint(model=model, optimizer=optimizer)
    checkpoint.restore(latest_checkpoint)
    checkpoint2 = copy.deepcopy(checkpoint)

    # Checkpoint names look like '.../ckpt-<N>'.
    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    # https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py
    if optimizer is None:
      import lele
      is_dynamic_opt = True
      if FLAGS.optimizer == 'noam':
        optimizer_ = torch.optim.Adamax(model.parameters(), lr=0)
        if use_horovod:
          optimizer_ = hvd.DistributedOptimizer(optimizer_)
        optimizer = lele.training.optimizers.NoamOpt(128, 2, 4000, optimizer_)
      elif FLAGS.optimizer == 'bert':
        num_train_steps = int(num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs))
        if FLAGS.warmup_steps and use_horovod:
          FLAGS.warmup_steps = max(int(FLAGS.warmup_steps / hvd.size()), 1)
        num_warmup_steps = FLAGS.warmup_steps or int(num_steps_per_epoch * FLAGS.warmup_epochs) or int(num_train_steps * FLAGS.warmup_proportion) 
        logging.info('num_train_steps', num_train_steps, 'num_warmup_steps', num_warmup_steps, 'warmup_proportion', FLAGS.warmup_proportion)
        optimizer_ = torch.optim.Adamax(model.parameters(), lr=0)
        if use_horovod:
          optimizer_ = hvd.DistributedOptimizer(optimizer_)
        optimizer = lele.training.optimizers.BertOpt(
                            FLAGS.learning_rate, 
                            FLAGS.min_learning_rate,
                            num_train_steps,
                            num_warmup_steps,
                            optimizer_
                            )
      else:
        is_dynamic_opt = False
        optimizer = torch.optim.Adamax(param_groups if param_groups else model.parameters(), lr=FLAGS.learning_rate)
        if use_horovod:
          optimizer = hvd.DistributedOptimizer(optimizer)
          optimizer_ = optimizer
    elif use_horovod:
      optimizer = hvd.DistributedOptimizer(optimizer)
      optimizer_ = optimizer

    start_epoch = 0
    latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(FLAGS.model_dir, 'latest.pyt')
    if not os.path.exists(latest_path):
      latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
    if os.path.exists(latest_path):
      logging.info('loading torch model from', latest_path)
      torch_state = torch.load(latest_path)
      if not FLAGS.torch_finetune:
        start_epoch = torch_state['epoch']
        saved_step = torch_state['step']
        global_step.assign(saved_step + 1)
      load_torch_model(model, latest_path)
      if FLAGS.torch_load_optimizer:
        optimizer.load_state_dict(torch_state['optimizer'])

    # TODO by this way restart can not change learning rate..
    checkpoint = _make_tf_checkpoint()
    try:
      checkpoint.restore(latest_checkpoint)
      checkpoint2 = copy.deepcopy(checkpoint)
    except Exception:
      # Best effort: keep training, but make sure checkpoint2 exists for saving.
      logging.warning(f'Fail to restore/copy tf checkpoint from {latest_checkpoint}')
      checkpoint2 = checkpoint

  if FLAGS.torch and is_dynamic_opt:
    optimizer._step = global_step.numpy()

  logging.info('optimizer:', optimizer)

  if FLAGS.torch_lr:
    learning_rate.assign(optimizer.rate(1))
  if FLAGS.torch:
    learning_rate.assign(optimizer.param_groups[0]['lr'])
    logging.info('learning rate got from pytorch latest.py as', learning_rate.numpy())

  learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
  if learning_rate_weights is not None:
    learning_rate_weights.assign(learning_rate_weights * FLAGS.learning_rate_start_factor)

  # TODO currently not support 0.1 epoch.. like this
  num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

  def _epoch_progress():
    """Fractional number of epochs completed (0 when the epoch size is unknown)."""
    return global_step.numpy() / num_steps_per_epoch if num_steps_per_epoch else 0.

  def _steps_every(epochs_interval):
    """Convert a (fractional) epoch interval to a step interval >= 1, or None if disabled/unknown."""
    if not num_steps_per_epoch or not epochs_interval:
      return None
    return max(int(num_steps_per_epoch * epochs_interval), 1)

  def _valid_names():
    """Names for validation outputs: `valid_names`, else derived from `infer_names`, else None."""
    if valid_names is not None:
      return valid_names
    if infer_names:
      return [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:]
    return None

  def _torch_state():
    """Snapshot of the torch training state saved as a .pyt file."""
    return {
        'epoch': int(_epoch_progress()),
        'step': global_step.numpy(),
        'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

  # ----- optional evaluation before training starts
  evfirst = gezi.get_env('EVFIRST')
  if evfirst == '1':
    will_valid = True
  elif evfirst == '0':
    will_valid = False
  else:
    will_valid = (bool(valid_dataset) and FLAGS.work_mode != 'test'
                  and 'SHOW' not in os.environ and 'QUICK' not in os.environ
                  and global_step.numpy() != 0)

  if will_valid:
    logging.info('----------valid')
    if hasattr(model, 'eval'):
      model.eval()
    vals, names = None, None
    if evaluate_fn is not None:
      vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch)
    elif eval_fn:
      model_path = None if not write_valid else latest_checkpoint
      names = _valid_names()
      logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir)
      vals, names = evaluate(model, valid_dataset, eval_fn, model_path, 
                             names, valid_write_fn, write_streaming,
                             num_valid_steps_per_epoch, num_valid_examples,
                             suffix=valid_suffix, sep=sep)
    if names:
      logging.info2('epoch:%.2f/%d step:%d' % (_epoch_progress(), num_epochs, global_step.numpy()), 
                    ['%s:%.4f' % (name, val) for name, val in zip(names, vals)])

    if FLAGS.work_mode == 'valid' or gezi.get_env('METRIC') == '1':
      exit(0)

  if 'test' in FLAGS.work_mode or gezi.get_env('TEST') == '1' or gezi.get_env('INFER') == '1':
    logging.info('--------test/inference')
    if test_dataset:
      if hasattr(model, 'eval'):
        model.eval()
      if inference_fn is None:
        assert latest_checkpoint
        inference(model, test_dataset, latest_checkpoint, 
                  infer_names, infer_debug_names, infer_write_fn, write_streaming,
                  num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix, sep=sep)
      else:
        inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch)
    exit(0)

  if 'SHOW' in os.environ:
    num_epochs = start_epoch + 1

  class PytObj(object):
    """Wraps a plain value so it exposes `.numpy()` like an eager tensor."""
    def __init__(self, x):
      self.x = x
    def numpy(self):
      return self.x

  class PytMean(object):
    """Running mean of torch scalars mimicking tfe.metrics.Mean.

    The first call after `result()` starts a new averaging window.
    """
    def __init__(self):
      self._val = 0. 
      self.count = 0
      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      val = self._val / self.count if self.count else 0
      # Wrapped for compatibility with the tf metric's result().numpy().
      return PytObj(val)

  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean

  num_insts = 0

  if FLAGS.learning_rate_decay_factor > 0:
    # NOTICE if you finetune with a different batch_size, set num_steps_per_decay
    # directly: global_step / decay_steps no longer maps to epochs otherwise.
    assert FLAGS.num_steps_per_decay or (FLAGS.num_epochs_per_decay and num_steps_per_epoch), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch'
    decay_steps = FLAGS.num_steps_per_decay or int(num_steps_per_epoch * FLAGS.num_epochs_per_decay)    
    decay_start_step = FLAGS.decay_start_step or int(num_steps_per_epoch * FLAGS.decay_start_epoch)
    logging.info('learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'.format(
        FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay, decay_steps, FLAGS.decay_start_epoch, decay_start_step))

  # Step intervals derived from epoch-based flags; None disables the action.
  save_every = _steps_every(FLAGS.save_interval_epochs)
  valid_every = _steps_every(FLAGS.valid_interval_epochs)
  test_every = _steps_every(FLAGS.inference_interval_epochs)

  #-------------------------start training
  if hasattr(model, 'train'):
    model.train()

  if use_horovod:
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer_, root_rank=0)

  timer = gezi.Timer()
  loss_avg = Mean()

  # Only TF models whose call() declares `training` receive that argument.
  pass_training = (not FLAGS.torch and hasattr(model, 'call')
                   and 'training' in inspect.getfullargspec(model.call).args)

  def loss_fn_(x, y, training=True):
    """Run the model on x and return the loss in the backend's loss_fn argument order."""
    y_ = model(x, training=training) if pass_training else model(x)
    if not FLAGS.torch:
      return loss_fn(y, y_)
    return loss_fn(y_, y)

  num_epochs = num_epochs or 0
  # NOTE(review): with FLAGS.torch_only this is at most one pass over
  # train_dataset; multi-epoch runs presumably rely on the iterable itself — verify.
  loops = min(num_epochs, 1) if FLAGS.torch_only else 1
  for _ in range(loops):
    for i, (x, y) in enumerate(train_dataset):
      if FLAGS.torch:
        x, y = to_torch(x, y)
        if is_dynamic_opt:
          learning_rate.assign(optimizer.rate())

      if not FLAGS.torch:
        loss, grads = melt.eager.grad(model, x, y, loss_fn)
        grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # https://github.com/horovod/horovod/blob/master/examples/tensorflow_mnist_eager.py
        # Horovod: broadcast variables after the first gradient step so that
        # optimizer slots exist and all workers start consistent.
        if use_horovod and i == 0:
          hvd.broadcast_variables(model.variables, root_rank=0)
          hvd.broadcast_variables(optimizer.variables(), root_rank=0)
      else:
        optimizer.zero_grad()
        loss = loss_fn_(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                        FLAGS.clip_gradients)
        optimizer.step()

      global_step.assign_add(1)
      loss_avg(loss)

      cur_batch_size = list(x.values())[0].shape[FLAGS.batch_size_dim] if isinstance(x, dict) else x.shape[FLAGS.batch_size_dim]
      num_insts += int(cur_batch_size)

      # ----- periodic progress log (+ one-batch valid loss)
      if global_step.numpy() % FLAGS.interval_steps == 0:
        elapsed = timer.elapsed()
        steps_per_second = FLAGS.interval_steps / elapsed
        instances_per_second = num_insts / elapsed
        num_insts = 0

        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
          epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)

        valid_loss = None
        if valid_dataset2:
          # Keep one iterator so each log line sees a new batch; restart it when
          # a non-repeating loader (e.g. a torch DataLoader) is exhausted.
          try:
            vx, vy = next(valid_dataset2_iter)
          except StopIteration:
            valid_dataset2_iter = iter(valid_dataset2)
            vx, vy = next(valid_dataset2_iter)
          if FLAGS.torch:
            vx, vy = to_torch(vx, vy)
          if hasattr(model, 'eval'):
            model.eval()
          if FLAGS.torch:
            with torch.no_grad():
              valid_loss = loss_fn_(vx, vy, training=False).item()
          else:
            valid_loss = loss_fn_(vx, vy, training=False).numpy()
          if hasattr(model, 'train'):
            model.train()

        if not use_horovod or hvd.rank() == 0:
          stats = ['step:%d' % global_step.numpy(), 
                   'elapsed:[%.2f]' % elapsed,
                   'batch_size:[%d]' % cur_batch_size,
                   'gpus:[%d]' % num_gpus, 
                   'batches/s:[%.2f]' % steps_per_second,
                   'insts/s:[%d]' % instances_per_second,
                   '%s' % epoch_time_info,
                   'lr:[%.6f]' % learning_rate.numpy(),
                   'train_loss:[%.4f]' % loss_avg.result().numpy()]
          if valid_loss is not None:
            stats.append('valid_loss:[%.4f]' % valid_loss)
          logging.info2('epoch:%.2f/%d' % (_epoch_progress(), num_epochs), *stats)

          if global_step.numpy() % FLAGS.valid_interval_steps == 0:
            if valid_loss is not None:
              with writer_valid.as_default(), summary.always_record_summaries():
                summary.scalar('loss/valid', valid_loss)
                writer_valid.flush()
            with writer_train.as_default(), summary.always_record_summaries():
              summary.scalar('loss/train_avg', loss_avg.result().numpy())
              summary.scalar('learning_rate', learning_rate.numpy())
              summary.scalar('other/batch_size', cur_batch_size)
              summary.scalar('other/epoch', melt.epoch())
              summary.scalar('perf/steps_per_second', steps_per_second)
              summary.scalar('perf/instances_per_second', instances_per_second)
              writer_train.flush()

      # ----- step-based metric evaluation
      if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy() and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
        if hasattr(model, 'eval'):
          model.eval()
        vals, names = None, None
        if evaluate_fn is not None:
          vals, names = evaluate_fn(model, valid_dataset, None, num_valid_steps_per_epoch)
        elif eval_fn:
          names = _valid_names()
          vals, names = evaluate(model, valid_dataset, eval_fn, None, 
                                 names, valid_write_fn, write_streaming,
                                 num_valid_steps_per_epoch, num_valid_examples, sep=sep)

        if FLAGS.torch and not FLAGS.torch_lr:
          # control learning rate by tensorflow learning rate
          for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate.numpy()
        if hasattr(model, 'train'):  
          model.train()

        if (not use_horovod or hvd.rank() == 0) and vals and names:
          with writer_valid.as_default(), summary.always_record_summaries():
            for name, val in zip(names, vals):
              summary.scalar(f'step_eval/{name}', val)
            writer_valid.flush()
          logging.info2('epoch:%.2f/%d' % (_epoch_progress(), num_epochs),  
                        'valid_step:%d' % global_step.numpy(),
                        'valid_metrics',
                        ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

      # ----- checkpointing (rank 0 only)
      if not use_horovod or hvd.rank() == 0:
        if FLAGS.torch and global_step.numpy() % FLAGS.save_interval_steps == 0:
          torch.save(_torch_state(), os.path.join(FLAGS.model_dir, 'latest.pyt'))

        # TODO fixme why if both checpoint2 and chekpoint used... not ok..
        if save_every is not None and global_step.numpy() % save_every == 0:
          checkpoint2.save(checkpoint_prefix2) 
          if FLAGS.torch:
            torch.save(_torch_state(), tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

      # ----- learning rate decay (never below FLAGS.min_learning_rate)
      if FLAGS.learning_rate_decay_factor > 0:
        if global_step.numpy() >= decay_start_step and global_step.numpy() % decay_steps == 0:
          lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor, FLAGS.min_learning_rate)
          if lr < learning_rate.numpy():
            learning_rate.assign(lr)
            if FLAGS.torch:
              for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate.numpy()

      if i == 0:
        try:
          if not FLAGS.torch:
            logging.info(model.summary())
          else:
            logging.info(model)
        except Exception:
          traceback.print_exc()
          logging.info('Fail to do model.summary() may be you have layer define in init but not used in call')
        if 'SHOW' in os.environ:
          exit(0)

      # ----- epoch-based validation
      if valid_dataset and valid_every is not None and global_step.numpy() % valid_every == 0:
        if hasattr(model, 'eval'):
          model.eval()

        vals, names = None, None
        if evaluate_fn is not None:
          vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch)
        elif eval_fn:
          model_path = None if not write_valid else tf.train.latest_checkpoint(ckpt_dir)
          print('---------metric evaluate step', global_step.numpy(), 'model_path:', model_path)
          names = _valid_names()
          vals, names = evaluate(model, valid_dataset, eval_fn, model_path, 
                                 names, valid_write_fn, write_streaming,
                                 num_valid_steps_per_epoch, num_valid_examples, suffix=valid_suffix, sep=sep)

        if not use_horovod or hvd.rank() == 0:
          if vals and names:
            logging.info2('epoch:%.2f/%d' % (_epoch_progress(), num_epochs), 
                          'step:%d' % global_step.numpy(),
                          'valid_metrics',
                          ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

          with writer.as_default(), summary.always_record_summaries():
            # Temporarily set global_step to the validation index so the
            # summaries are indexed per validation round, then restore it.
            temp = global_step.value()
            global_step.assign(int(global_step.numpy() / valid_every))
            if vals and names:
              for name, val in zip(names, vals):
                summary.scalar(f'eval/{name}', val)
            writer.flush()
            global_step.assign(temp)

        # Resume training mode after evaluation.
        if hasattr(model, 'train'):
          model.train()

      # ----- epoch-based inference
      if test_dataset and test_every is not None and global_step.numpy() % test_every == 0:
        if hasattr(model, 'eval'):
          model.eval()
        if inference_fn is None:
          inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), 
                    infer_names, infer_debug_names, infer_write_fn, write_streaming,
                    num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix, sep=sep)
        else:
          inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch)
        if hasattr(model, 'train'):
          model.train()

      if (num_epochs and num_steps_per_epoch and (global_step.numpy() % num_steps_per_epoch) == 0
          and int(global_step.numpy() / num_steps_per_epoch) == num_epochs):
        logging.info(f'Finshed training of {num_epochs} epochs')
        exit(0)
Пример #23
0
def train(Dataset,
          model,
          loss_fn,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          sep=','):
  """Train `model` for FLAGS.num_epochs epochs, feeding tf.data batches into a torch loop.

  Training records come from FLAGS.train_input. When FLAGS.fold is set, the
  records of that fold are excluded. State is restored from the latest
  checkpoint under FLAGS.model_dir + '/ckpt' when one exists. Each epoch
  re-initializes the input iterator and runs over it. The running loss is
  printed every FLAGS.interval_steps steps, and a summary line is printed at
  the end of every epoch.

  NOTE: the inner loop uses torch APIs only (model.train(),
  optimizer.zero_grad(), loss.backward(), to_torch). In practice this
  function therefore expects FLAGS.torch to be set.

  Args:
    Dataset: dataset class. It is constructed as Dataset('train') and must
      provide num_examples_per_epoch(mode) and make_batch(...).
    model: the model to train.
    loss_fn: called as loss_fn(model, x, y). It must return a tensor
      supporting .backward() and .item().
    evaluate_fn, inference_fn, eval_fn, write_valid, valid_names,
    infer_names, infer_debug_names, valid_write_fn, infer_write_fn,
    valid_suffix, infer_suffix, write_streaming, sep: accepted for
      interface compatibility. They are currently unused, because
      validation and inference are not wired into this loop.

  Raises:
    NotImplementedError: if melt reports more than one GPU.
  """
  import itertools

  if FLAGS.torch and torch.cuda.is_available():
    model.cuda()

  inputs = gezi.list_files(FLAGS.train_input)
  inputs.sort()

  batch_size = FLAGS.batch_size

  num_gpus = melt.num_gpus()
  if num_gpus > 1:
    # Raise explicitly: an `assert` guard disappears under `python -O`.
    raise NotImplementedError('Eager mode train currently not support for num gpus > 1')

  if FLAGS.fold is not None:
    # Hold out the records belonging to fold FLAGS.fold.
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  logging.info('inputs', inputs)

  dataset = Dataset('train')
  num_examples = dataset.num_examples_per_epoch('train')

  # TODO: validation/test datasets are not set up here, so the evaluation
  # related parameters of this function are unused.

  if num_examples:
    if FLAGS.fold is not None:
      # Presumably examples are spread evenly over len(inputs) + 1 fold
      # files, one of which was held out above.
      num_examples = int(num_examples * (len(inputs) / (len(inputs) + 1)))
    num_steps_per_epoch = -(-num_examples // batch_size)  # ceil division
  else:
    # Unknown epoch size: each epoch runs until the input iterator is exhausted.
    num_steps_per_epoch = None

  summary = tf.contrib.summary
  # NOTE(review): these writers are created but nothing is written to them
  # in this function. They are kept for their construction side effects.
  writer = summary.create_file_writer(FLAGS.model_dir)
  writer_train = summary.create_file_writer(FLAGS.model_dir)
  writer_valid = summary.create_file_writer(FLAGS.model_dir)
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)

  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except IndexError:
    # The collection is empty: no per-weight learning rates were registered.
    learning_rate_weights = None

  ckpt_dir = FLAGS.model_dir + '/ckpt'

  # TODO FIXME: tf code was patched locally so that it no longer keeps only the
  # latest 5 checkpoints by default.
  # See https://github.com/tensorflow/tensorflow/issues/22036
  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  logging.info('Latest checkpoint:', latest_checkpoint)

  # Objects tracked by the tf checkpoint in both the tf and torch paths.
  ckpt_objects = dict(learning_rate=learning_rate,
                      learning_rate_weight=learning_rate_weight)
  if learning_rate_weights is not None:
    ckpt_objects['learning_rate_weights'] = learning_rate_weights

  if not FLAGS.torch:
    optimizer = melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    checkpoint = tf.train.Checkpoint(**ckpt_objects,
                                     model=model,
                                     optimizer=optimizer,
                                     global_step=global_step)

    # A checkpoint given directly as FLAGS.model_dir (<model_dir>.index
    # exists) takes precedence over the latest one in ckpt_dir.
    if os.path.exists(FLAGS.model_dir + '.index'):
      latest_checkpoint = FLAGS.model_dir

    checkpoint.restore(latest_checkpoint)

    # The numeric suffix after the last '-' of the checkpoint path is used as
    # the epoch to resume from.
    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    optimizer = torch.optim.Adamax(model.parameters(), lr=FLAGS.learning_rate)

    if latest_checkpoint:
      # Torch state is loaded from '<latest tf checkpoint>.pyt'.
      state = torch.load(latest_checkpoint + '.pyt')
      start_epoch = state['epoch']
      model.load_state_dict(state['state_dict'])
      optimizer.load_state_dict(state['optimizer'])
      model.eval()
    else:
      start_epoch = 0

    checkpoint = tf.train.Checkpoint(**ckpt_objects, global_step=global_step)

  # TODO currently not support 0.1 epoch.. like this
  num_epochs = FLAGS.num_epochs

  class PytObj(object):
    """Wraps a plain value so callers can call `.numpy()` as on a tf tensor."""

    def __init__(self, x):
      self.x = x

    def numpy(self):
      return self.x

  class PytMean(object):
    """Running mean of torch scalars, mimicking the call/result API of tfe.metrics.Mean.

    Calling result() arms a reset: the next __call__ clears the accumulated
    values first. Each result() therefore reports the mean since the
    previous result().
    """

    def __init__(self):
      self._val = 0.
      self.count = 0
      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      val = self._val / self.count if self.count else 0
      # Wrapped only for compatibility with the tf metric's `.numpy()` access.
      return PytObj(val)

  # TODO consider multiple gpu for torch

  iterator = dataset.make_batch(batch_size, inputs, repeat=False, initializable=False)
  x_, y_ = iterator.get_next()

  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
  epoch_loss_avg = Mean()

  # The tf session is used only to pull input batches, so it is given no GPUs.
  sess = melt.get_session(device_count={'GPU': 0})
  step = 0
  for epoch in range(start_epoch, num_epochs):
    melt.set_global('epoch', '%.4f' % (epoch))
    sess.run(iterator.initializer)

    model.train()

    # With an unknown epoch size, loop until the iterator raises OutOfRangeError
    # instead of calling range(None).
    steps = range(num_steps_per_epoch) if num_steps_per_epoch else itertools.count()

    # FIXME TODO: reportedly still runs out of memory.
    try:
      for _ in tqdm(steps, total=num_steps_per_epoch, ascii=True):
        x, y = sess.run([x_, y_])
        x, y = to_torch(x, y)

        optimizer.zero_grad()
        loss = loss_fn(model, x, y)
        loss.backward()
        optimizer.step()

        epoch_loss_avg(loss)

        if step % FLAGS.interval_steps == 0:
          print(step, epoch_loss_avg.result().numpy())

        step += 1
    except tf.errors.OutOfRangeError:
      # The input ran out before num_steps_per_epoch steps, so the epoch is over.
      pass

    # Report every epoch, whether it ended by step count or by exhausting the input.
    print('epoch:%d loss:%f' % (epoch, epoch_loss_avg.result().numpy()))
Пример #24
0
def main(_):
  """Inspect record datasets under FLAGS.base, dispatching on FLAGS.type.

  Supported values of FLAGS.type:
    'debug'     -- iterate the FLAGS.input records and print the content_str
                   of every example whose id is '573'.
    'dump'      -- run `deal` over {base}/train/*record and
                   {base}/test/*record, then pickle the collected dicts to
                   {base}/info.pkl and {base}/info.test.pkl.
    'show_info' -- load those two pickles and print length and unk-token
                   statistics, plus a few docs that contain unk tokens.

  Raises:
    ValueError: for any other FLAGS.type.
  """
  base = FLAGS.base
  logging.set_logging_path('./mount/tmp/')
  vocab_path = f'{base}/vocab.txt'
  ids2text.init(vocab_path)
  FLAGS.vocab = vocab_path

  input_ = FLAGS.input
  if FLAGS.type == 'test':
    # NOTE(review): 'test' is not handled by the dispatch below, so this
    # path always ends in ValueError. Confirm whether that is intended.
    input_ = input_.replace('valid', 'test')

  inputs = gezi.list_files(input_)
  inputs.sort()
  if FLAGS.fold is not None:
    # Hold out the records belonging to fold FLAGS.fold.
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  def _show_info(infos):
    """Print length / unk statistics for `infos` and show up to 6 docs containing unks."""
    num_docs = len(infos)
    if not num_docs:
      # Avoid ZeroDivisionError and errors from np.max/np.min on an empty list.
      print('num docs:', num_docs)
      return
    lens = [len(infos[key]['content']) for key in infos]
    unks = [list(infos[key]['content']).count(FLAGS.unk_id) for key in infos]
    total_len = sum(lens)
    print('num unks per doc:', sum(unks) / num_docs)
    print('num doc with unk ratio:', len([x for x in unks if x != 0]) / num_docs)
    # sum(unks) <= total_len, so a zero total also means zero unks.
    print('un unk tokens ratio:', sum(unks) / total_len if total_len else 0.)
    print('len max:', np.max(lens))
    print('len min:', np.min(lens))
    print('len mean:', np.mean(lens))
    print('num docs:', num_docs)

    num_show = 0
    # `unks` is in the same order as `infos` iteration, so its counts can be reused.
    for key, num_unks in zip(infos, unks):
      if num_unks > 0:
        print(infos[key])
        print(ids2text.ids2text(infos[key]['content']))
        num_show += 1
        if num_show > 5:
          break

  if FLAGS.type == 'debug':
    print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)

    dataset = Dataset('valid')
    # NOTE(review): FLAGS.batch_size_ (trailing underscore) is not defined in
    # this file. Confirm it is a real flag and not a typo for FLAGS.batch_size.
    dataset = dataset.make_batch(FLAGS.batch_size_, inputs)

    print('dataset', dataset)

    timer = gezi.Timer('read record')
    for x, y in dataset:
      x['id'] = gezi.decode(x['id'].numpy())
      x['content_str'] = gezi.decode(x['content_str'].numpy())
      for j, id_ in enumerate(x['id']):
        if id_ == '573':
          print(id_, x['content_str'][j])
  elif FLAGS.type == 'dump':
    valid_infos = {}
    test_infos = {}
    # TODO notice train and valid also share ids, so saving only valid is ok.
    # About 120000 docs, but the first 15000 train docs share ids with valid,
    # so for those ids only the valid result is saved currently.
    inputs = gezi.list_files(f'{base}/train/*record')
    dataset = Dataset('valid')
    dataset = dataset.make_batch(1, inputs)
    deal(dataset, valid_infos)
    print('after valid', len(valid_infos))

    for key in valid_infos:
      print(valid_infos[key])
      print(ids2text.ids2text(valid_infos[key]['content']))
      break

    ofile = f'{base}/info.pkl'
    with open(ofile, 'wb') as out:
      pickle.dump(valid_infos, out)

    # Free the valid dict before building the test one.
    del valid_infos

    inputs = gezi.list_files(f'{base}/test/*record')
    dataset = Dataset('test')
    dataset = dataset.make_batch(1, inputs)
    deal(dataset, test_infos)
    print('after test', len(test_infos))

    ofile = ofile.replace('.pkl', '.test.pkl')
    with open(ofile, 'wb') as out:
      pickle.dump(test_infos, out)
    for key in test_infos:
      print(test_infos[key])
      print(ids2text.ids2text(test_infos[key]['content']))
      break
  elif FLAGS.type == 'show_info':
    # NOTE: pickle.load executes arbitrary code. Only load files produced by
    # the 'dump' branch above.
    for path, header in ((f'{base}/info.pkl', None),
                         (f'{base}/info.test.pkl', '--------------for test info:')):
      if header:
        print(header)
      with open(path, 'rb') as f:
        infos = pickle.load(f)
      _show_info(infos)
      # Release each dict before loading the next one.
      del infos
  else:
    raise ValueError(FLAGS.type)
Пример #25
0
        return x, y.to(device)
    if y is not None:
        y = torch_(y)

    if not isinstance(x, dict):
        x = torch_(x)
    else:
        for key in x:
            x[key] = torch_(x[key])
    if y is None:
        return x
    else:
        return x, y


# Smoke test: build a DataLoader over the small training split and print the
# first 'id' and 'value' entry of every batch.
files = gezi.list_files('../input/train.small/*')
td = TD()
ds = get_dataset(files, td)
# Batch size 2; DictPadCollate presumably pads the per-key fields of each
# example dict to a common length. TODO confirm against lele.
dl = DataLoader(ds, 2, collate_fn=lele.DictPadCollate())
# Number of examples, number of batches, and (again) number of examples.
print(len(ds), len(dl), len(dl.dataset))
for i, (x, y) in enumerate(dl):
    print(i, x['id'][0], x['value'][0])
    # Optional debugging output (disabled):
    # #print('--------------', d)
    # print(x['index'].shape)
    # print(x['field'].shape)
    # print(x['value'].shape)
    # print(x['id'].shape)
    # print(y.shape)
    #print(x)
    # for key in x:
    #   print(key, type(x[key][0]), type(x[key]), x[key][0].dtype)