Code example #1
def main(_):
    """Single-process PyTorch training entry point.

    Builds the model (looked up by name from `base`), wraps train/valid data
    in DataLoaders, and hands everything to melt's fit loop.

    Args:
      _: unused positional argument (absl.app passes argv here).
    """
    FLAGS.torch_only = True
    melt.init()
    fit = melt.get_fit()

    # Evaluation uses a larger batch; the multiplier is a flag.
    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

    model_name = FLAGS.model
    # Model class is resolved dynamically from the `base` module by name.
    model = getattr(base, model_name)()

    loss_fn = nn.BCEWithLogitsLoss()

    td = text_dataset.Dataset()
    train_files = gezi.list_files(FLAGS.train_input)
    train_ds = get_dataset(train_files, td)

    # pin_memory speeds up host->GPU copies a bit. 8 workers is the sweet
    # spot here: 1 worker is very slow (especially for validation), while
    # large counts (e.g. 12) sometimes hang or exhaust resources.
    kwargs = {
        'num_workers': 8,
        'pin_memory': True,
        'collate_fn': lele.DictPadCollate()
    }

    train_dl = DataLoader(train_ds, FLAGS.batch_size, shuffle=True, **kwargs)

    # BUGFIX: valid_dl/valid_dl2 were previously bound only inside the
    # `if FLAGS.valid_input:` branch but passed to fit() unconditionally,
    # raising NameError when no validation input was configured.
    valid_dl = None
    valid_dl2 = None
    if FLAGS.valid_input:
        valid_files = gezi.list_files(FLAGS.valid_input)
        valid_ds = get_dataset(valid_files, td)
        # Large-batch loader for evaluation, train-batch loader for
        # in-training validation.
        valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size, **kwargs)
        valid_dl2 = DataLoader(valid_ds, FLAGS.batch_size, **kwargs)

    fit(
        model,
        loss_fn,
        dataset=train_dl,
        valid_dataset=valid_dl,
        valid_dataset2=valid_dl2,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        write_valid=False,
    )
Code example #2
def main(_):
  """Distributed (horovod) PyTorch training entry point.

  Moves the model to GPU, shards train/valid datasets with
  DistributedSampler so each rank sees its own slice, then runs melt's
  fit loop.

  Args:
    _: unused positional argument (absl.app passes argv here).
  """
  FLAGS.torch_only = True
  melt.init()
  fit = melt.get_fit()

  FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

  model_name = FLAGS.model
  # Model class is resolved dynamically from the `base` module by name.
  model = getattr(base, model_name)()
  model = model.cuda()

  loss_fn = nn.BCEWithLogitsLoss()

  td = text_dataset.Dataset()

  train_files = gezi.list_files('../input/train/*')
  train_ds = get_dataset(train_files, td)

  # NOTE(review): single worker and pin_memory=False — larger values were
  # apparently flaky under horovod; confirm before raising.
  num_workers = 1
  kwargs = {'num_workers': num_workers, 'pin_memory': False,
            'collate_fn': lele.DictPadCollate()}

  # BUGFIX: removed the dead `train_sampler = train_ds` assignment that was
  # immediately overwritten by the sampler below.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_ds, num_replicas=hvd.size(), rank=hvd.rank())

  train_dl = DataLoader(train_ds, FLAGS.batch_size, sampler=train_sampler, **kwargs)

  valid_files = gezi.list_files('../input/valid/*')
  valid_ds = get_dataset(valid_files, td)

  # (A duplicated `kwargs['num_workers'] = 1` pair was removed; the value is
  # already 1.)
  # DistributedSampler supports shuffle=False from torch 1.2.
  valid_sampler = torch.utils.data.distributed.DistributedSampler(
      valid_ds, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)
  valid_sampler2 = torch.utils.data.distributed.DistributedSampler(
      valid_ds, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)

  # Large-batch loader for evaluation, train-batch loader for in-training
  # validation.
  valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size, sampler=valid_sampler, **kwargs)
  valid_dl2 = DataLoader(valid_ds, FLAGS.batch_size, sampler=valid_sampler2, **kwargs)

  fit(model,
      loss_fn,
      dataset=train_dl,
      valid_dataset=valid_dl,
      valid_dataset2=valid_dl2,
      eval_fn=ev.evaluate,
      valid_write_fn=ev.valid_write,
      write_valid=False,
     )
Code example #3
File: torch-only-train.py  Project: liurht/wenzheng
def main(_):
    """Single-process PyTorch training with a fixed DataLoader worker pool.

    Args:
      _: unused positional argument (absl.app passes argv here).
    """
    FLAGS.torch_only = True
    melt.init()
    fit = melt.get_fit()

    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

    model_name = FLAGS.model
    # Model class is resolved dynamically from the `base` module by name.
    model = getattr(base, model_name)()

    loss_fn = nn.BCEWithLogitsLoss()

    td = TextDataset()
    train_files = gezi.list_files('../input/train/*')
    train_ds = get_dataset(train_files, td)

    # BUGFIX: the old code computed int(multiprocessing.cpu_count() * 0.3),
    # logged it, then unconditionally discarded it in favor of the constant
    # below — that dead computation (and its misleading log line) is gone.
    # Too many workers made the process easy to get Killed (OOM).
    num_threads = 12
    logging.info('num_threads', num_threads)

    train_dl = DataLoader(train_ds,
                          FLAGS.batch_size,
                          shuffle=True,
                          num_workers=num_threads,
                          collate_fn=lele.DictPadCollate())

    valid_files = gezi.list_files('../input/valid/*')
    valid_ds = get_dataset(valid_files, td)
    # Large-batch loader for evaluation, train-batch loader for in-training
    # validation.
    valid_dl = DataLoader(valid_ds,
                          FLAGS.eval_batch_size,
                          collate_fn=lele.DictPadCollate(),
                          num_workers=num_threads)
    valid_dl2 = DataLoader(valid_ds,
                           FLAGS.batch_size,
                           collate_fn=lele.DictPadCollate(),
                           num_workers=num_threads)

    fit(
        model,
        loss_fn,
        dataset=train_dl,
        valid_dataset=valid_dl,
        valid_dataset2=valid_dl2,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        write_valid=False,
    )
Code example #4
def main(_):
    """Single-process PyTorch training using default-collated DataLoaders.

    Args:
      _: unused positional argument (absl.app passes argv here).
    """
    FLAGS.torch_only = True
    melt.init()
    fit = melt.get_fit()

    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier

    model_name = FLAGS.model
    # Model class is resolved dynamically from the `base` module by name.
    model = getattr(base, model_name)()

    loss_fn = nn.BCEWithLogitsLoss()

    td = TextDataset()
    train_files = gezi.list_files('../input/train/*')
    train_ds = get_dataset(train_files, td)

    train_dl = DataLoader(train_ds,
                          FLAGS.batch_size,
                          shuffle=True,
                          num_workers=12)
    # NOTE(review): multi-arg logging.info assumes melt/gezi's print-style
    # logger, not the stdlib one — confirm.
    logging.info('num train examples', len(train_ds), len(train_dl))
    valid_files = gezi.list_files('../input/valid/*')
    valid_ds = get_dataset(valid_files, td)
    # NOTE(review): unlike train_dl these use the default collate_fn and a
    # single worker — verify that is intentional.
    valid_dl = DataLoader(valid_ds, FLAGS.eval_batch_size)
    valid_dl2 = DataLoader(valid_ds, FLAGS.batch_size)
    logging.info('num valid examples', len(valid_ds), len(valid_dl))
    # BUGFIX: removed leftover debug statement `print(dir(valid_dl))`.

    fit(
        model,
        loss_fn,
        dataset=train_dl,
        valid_dataset=valid_dl,
        valid_dataset2=valid_dl2,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        write_valid=False,
    )
Code example #5
File: torch-train.py  Project: shykoe/wenzheng
def main(_):
    """Training entry point that lets melt build the input pipeline.

    Unlike the DataLoader variants, this passes a Dataset *class* to fit()
    and melt constructs the loaders itself.

    Args:
      _: unused positional argument (absl.app passes argv here).
    """
    FLAGS.torch = True
    melt.init()
    fit = melt.get_fit()

    FLAGS.eval_batch_size = 512 * FLAGS.valid_multiplier
    print('---------eval_batch_size', FLAGS.eval_batch_size)

    model_name = FLAGS.model
    # Model class is resolved dynamically from the `base` module by name.
    model = getattr(base, model_name)()

    # Idiom fix: use `in` directly instead of `not ... in`.
    # Pick the TFRecord-backed dataset when the input path says so.
    Dataset = TFRecordDataset if 'tfrecord' in FLAGS.train_input else TextDataset

    loss_fn = nn.BCEWithLogitsLoss()

    fit(
        model,
        loss_fn,
        Dataset,
        eval_fn=ev.evaluate,
        valid_write_fn=ev.valid_write,
        write_valid=False)