示例#1
0
文件: train.py 项目: arfu2016/decaNLP
def main():
    """Parse arguments, prepare the data, and start a training run.

    Returns early (with ``None``) when ``arguments.parse()`` yields ``None``.
    When ``args.load`` is set, the saved dict is read from
    ``os.path.join(args.save, args.load)`` and its ``'field'`` entry is passed
    to ``prepare_data``. The run goes through ``Multiprocess`` when more
    than one GPU is listed in ``args.gpus``; otherwise ``run`` is called
    directly in this process.
    """
    args = arguments.parse()
    if args is None:
        return
    set_seed(args)
    logger = initialize_logger(args)
    # pformat pretty-prints the attribute dict of the argument namespace.
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    field = None
    save_dict = None
    if args.load is not None:
        checkpoint_path = os.path.join(args.save, args.load)
        logger.info(f'Loading field from {checkpoint_path}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        save_dict = torch.load(checkpoint_path)
        # Reuse the field object stored under the 'field' key.
        field = save_dict['field']
    field, train_sets, val_sets = prepare_data(args, field, logger)

    run_args = (field, train_sets, val_sets, save_dict)
    if len(args.gpus) > 1:
        # Multi-GPU: hand the run over to the Multiprocess wrapper.
        logger.info('Multiprocessing')
        Multiprocess(run, args).run(run_args)
        return
    logger.info('Processing')
    run(args, run_args, world_size=args.world_size)
示例#2
0
def test_do_tasks_after_error_raises_exc():
  """After a failing run, a second ``do_tasks`` must raise ``MultiprocessClosed``."""
  def failing_task():
    raise ValueError('error')

  pool = Multiprocess()
  pool.add_tasks(failing_task, [()])
  # The first run reports the task's failure.
  with pytest.raises(MultiprocessProcessException):
    pool.do_tasks()
  # The second run is rejected because the instance counts as closed.
  with pytest.raises(MultiprocessClosed):
    pool.do_tasks()
  pool.close()
示例#3
0
def test_do_tasks_twice():
  """``do_tasks`` may be called twice on one instance, with a task added before each call.

  The queue is expected to hold one item after the first run and two after
  the second. The instance is closed in ``finally`` so it is released even
  when an assertion fails.
  """
  def f(q):
    q.push(1)

  q = Queue()
  m = Multiprocess()
  try:
    m.add_tasks(f, [(q, )])
    m.do_tasks()
    assert (q.qsize() == 1)
    m.add_tasks(f, [(q, )])
    m.do_tasks()
    assert (q.qsize() == 2)
  finally:
    # Previously the instance was never closed; always release it here.
    m.close()
示例#4
0
def test_do_tasks_after_close_raises_exc():
  """Calling ``do_tasks`` on an already-closed instance must raise ``MultiprocessClosed``."""
  pool = Multiprocess()
  pool.add_tasks(lambda: 1, [()])
  pool.close()
  with pytest.raises(MultiprocessClosed):
    pool.do_tasks()
示例#5
0
def test_empty_fn():
  """A task that does nothing must run and close without raising."""
  def noop():
    return None

  pool = Multiprocess()
  pool.add_tasks(noop, [()])
  pool.do_tasks()
  pool.close()
示例#6
0
def test_lambda_pickling():
  """A lambda can be used as the task; its single push must land in the queue."""
  queue = Queue()
  # Deliberately a lambda: the test name indicates lambdas must be dispatchable.
  push_one = lambda q: q.push(1)
  pool = Multiprocess()
  pool.add_tasks(push_one, [(queue,)])
  pool.do_tasks()
  pool.close()
  assert queue.qsize() == 1
  assert queue.pop() == 1
  assert queue.qsize() == 0
示例#7
0
def test_fn_raises_exc_is_caught():
  """A task's exception must surface as ``MultiprocessProcessException`` carrying its message."""
  pool = Multiprocess()

  def failing_task():
    raise ValueError('unique')

  pool.add_tasks(failing_task, [()])
  with pytest.raises(MultiprocessProcessException) as caught:
    pool.do_tasks()
  # The original error text must be visible in the wrapping exception.
  message = str(caught.value)
  assert 'unique' in message
  pool.close()
示例#8
0
def test_queue():
  """A value pushed by a task must be readable from the queue afterwards."""
  queue = Queue()
  pool = Multiprocess()

  def push_one(target):
    target.push(1)

  pool.add_tasks(push_one, [(queue,)])
  pool.do_tasks()
  pool.close()
  # Exactly one item was pushed; popping it leaves the queue empty.
  assert queue.qsize() == 1
  assert queue.pop() == 1
  assert queue.qsize() == 0
def test_stdout(capsys):
  """Running a task must write nothing to stdout and end stderr with a 100% bar line."""
  pool = Multiprocess()
  sys.stderr.write("\r\n")
  pool.add_tasks(lambda: 1, [()])
  pool.do_tasks()
  pool.close()
  out, err = capsys.readouterr()
  assert out == ''
  # For each stderr line, keep only the text after the last carriage return.
  frames = [line.rsplit('\r', 1)[-1] for line in err.split('\n')]
  # Skip leading empty lines.
  first = 0
  while first < len(frames) and frames[first] == '':
    first += 1
  frames = frames[first:]
  assert len(frames) == 2
  assert frames[0].startswith('100%')
  assert frames[1] == ''
def test_disable_load_bar_works(capsys):
  """With ``show_loading_bar=False`` nothing may be written to stdout or stderr."""
  pool = Multiprocess(show_loading_bar=False)
  pool.add_tasks(lambda: 1, [()])
  pool.close()
  captured_out, captured_err = capsys.readouterr()
  assert captured_out == ''
  assert captured_err == ''
示例#11
0
def main():
    """Parse arguments, prepare the data, and start a training run.

    Returns early (with ``None``) when ``arguments.parse()`` yields ``None``.
    When ``args.load`` is set, the saved dict is read from
    ``os.path.join(args.save, args.load)`` and its ``'field'`` entry is passed
    to ``prepare_data``. The run goes through ``Multiprocess`` when more
    than one GPU is listed in ``args.gpus``; otherwise ``run`` is called
    directly in this process.
    """
    args = arguments.parse()
    if args is None:
        return
    set_seed(args)
    logger = initialize_logger(args)
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    field = None
    save_dict = None
    if args.load is not None:
        checkpoint_path = os.path.join(args.save, args.load)
        logger.info(f'Loading field from {checkpoint_path}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        save_dict = torch.load(checkpoint_path)
        field = save_dict['field']
    field, train_sets, val_sets = prepare_data(args, field, logger)

    run_args = (field, train_sets, val_sets, save_dict)
    if len(args.gpus) > 1:
        logger.info('Multiprocessing')
        Multiprocess(run, args).run(run_args)
        return
    logger.info('Processing')
    run(args, run_args, world_size=args.world_size)
示例#12
0
文件: train.py 项目: AhlamMD/decaNLP
def main():
    """Parse arguments, prepare the data, and start a training run.

    Does nothing when ``arguments.parse()`` returns ``None``. If ``args.load``
    is given, the saved dict at ``os.path.join(args.save, args.load)`` is
    loaded and its ``'field'`` entry is handed to ``prepare_data``. More than
    one entry in ``args.gpus`` selects the ``Multiprocess`` path; otherwise
    ``run`` is invoked directly.
    """
    args = arguments.parse()
    if args is None:
        return
    set_seed(args)
    logger = initialize_logger(args)
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    field = save_dict = None
    if args.load is not None:
        saved_path = os.path.join(args.save, args.load)
        logger.info(f'Loading field from {saved_path}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        save_dict = torch.load(saved_path)
        field = save_dict['field']
    field, train_sets, val_sets = prepare_data(args, field, logger)

    run_args = (field, train_sets, val_sets, save_dict)
    use_multiprocess = len(args.gpus) > 1
    if not use_multiprocess:
        logger.info('Processing')
        run(args, run_args, world_size=args.world_size)
        return
    logger.info('Multiprocessing')
    launcher = Multiprocess(run, args)
    launcher.run(run_args)
示例#13
0
def test_multiple_tasks():
  """One hundred tasks must each push their own number; every number must arrive."""
  def push_num(q, num):
    q.push(num)

  queue = Queue()
  pool = Multiprocess()
  nums = range(100)
  # One argument tuple per task: the shared queue plus the number to push.
  task_args = [(queue, num) for num in nums]
  pool.add_tasks(push_num, task_args)
  pool.do_tasks()
  pool.close()
  assert queue.qsize() == len(nums)
  # Arrival order is not checked, so compare as sets.
  popped = {queue.pop() for _ in range(queue.qsize())}
  assert popped == set(nums)
示例#14
0
def test_close_does_calls_do_task():
  """``close()`` alone, without an explicit ``do_tasks()``, must still run pending tasks."""
  def push_one(target):
    target.push(1)

  queue = Queue()
  pool = Multiprocess()
  pool.add_tasks(push_one, [(queue,)])
  # No do_tasks() call here on purpose.
  pool.close()
  assert queue.qsize() == 1
  assert queue.pop() == 1
  assert queue.qsize() == 0
def multiprocess(fn, arr_of_args, **kwargs):
    """Execute several tasks in parallel.

    Requires a function `fn` and an array of argument tuples `arr_of_args`,
    each representing a call to the function.

    Additionally, you can provide arguments the same as you would with
    `Multiprocess`; they are forwarded to its constructor unchanged.

    The `Multiprocess` instance is closed before this function returns,
    including when adding or running the tasks raises; the original exception
    still propagates to the caller.

    Example

        >>> # exec f(x) and f(y) in parallel
        >>> multiprocess(f, [(x,), (y,)])

    If you don't want a loading bar

        >>> multiprocess(f, [(x,), (y,)], show_loading_bar=False)
    """
    m = Multiprocess(**kwargs)
    try:
        m.add_tasks(fn, arr_of_args)
        m.do_tasks()
    finally:
        # Always close, so a failing task does not leave the instance open.
        m.close()
示例#16
0
def init_opt(args, model):
    """Build the Adam optimizer for ``model.params``.

    When ``args.transformer_lr`` is truthy, Adam is created with
    ``betas=(0.9, 0.98)`` and ``eps=1e-9``. Otherwise it is created with
    ``betas=(args.beta0, 0.999)`` and Adam's default ``eps``.

    Returns:
        The constructed ``torch.optim.Adam`` instance.
    """
    if args.transformer_lr:
        adam_kwargs = {'betas': (0.9, 0.98), 'eps': 1e-9}
    else:
        adam_kwargs = {'betas': (args.beta0, 0.999)}
    return torch.optim.Adam(model.params, **adam_kwargs)


if __name__ == '__main__':
    # Script entry point: parse arguments, prepare the data, start training.
    args = arguments.parse()
    set_seed(args)
    logger = initialize_logger(args)
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    field = save_dict = None
    if args.load is not None:
        checkpoint_path = os.path.join(args.save, args.load)
        logger.info(f'Loading field from {checkpoint_path}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        save_dict = torch.load(checkpoint_path)
        field = save_dict['field']
    field, train_sets, val_sets = prepare_data(args, field, logger)

    run_args = (field, train_sets, val_sets, save_dict)
    if len(args.gpus) <= 1:
        # Zero or one GPU listed: run directly in this process.
        logger.info('Processing')
        run(args, run_args, world_size=args.world_size)
    else:
        # Several GPUs listed: go through the Multiprocess wrapper.
        logger.info('Multiprocessing')
        mp = Multiprocess(run, args)
        mp.run(run_args)
示例#17
0
def test_close_twice_doesnt_raise_exc():
  """Closing the same instance twice in a row must not raise."""
  pool = Multiprocess()
  for _ in range(2):
    pool.close()
示例#18
0
def test_no_loading_bar():
  """An instance created with ``show_loading_bar=False`` must accept a task and close cleanly."""
  pool = Multiprocess(show_loading_bar=False)
  pool.add_tasks(lambda: 1, [()])
  pool.close()
示例#19
0
def init_opt(args, model):
    """Create the Adam optimizer over ``model.params``.

    A truthy ``args.transformer_lr`` selects ``betas=(0.9, 0.98)`` with
    ``eps=1e-9``. Any other value selects ``betas=(args.beta0, 0.999)`` and
    leaves ``eps`` at Adam's default.

    Returns:
        The constructed ``torch.optim.Adam`` instance.
    """
    if not args.transformer_lr:
        return torch.optim.Adam(model.params, betas=(args.beta0, 0.999))
    return torch.optim.Adam(model.params, betas=(0.9, 0.98), eps=1e-9)


if __name__ == '__main__':
    # Script entry point: parse arguments, prepare the data, start training.
    args = arguments.parse()
    set_seed(args)
    logger = initialize_logger(args)
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    field = None
    save_dict = None
    if args.load is not None:
        saved_path = os.path.join(args.save, args.load)
        logger.info(f'Loading field from {saved_path}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        save_dict = torch.load(saved_path)
        field = save_dict['field']
    field, train_sets, val_sets = prepare_data(args, field, logger)

    run_args = (field, train_sets, val_sets, save_dict)
    if len(args.gpus) <= 1:
        # Zero or one GPU listed: run directly in this process.
        logger.info('Processing')
        run(args, run_args, world_size=args.world_size)
    else:
        # Several GPUs listed: go through the Multiprocess wrapper.
        logger.info('Multiprocessing')
        mp = Multiprocess(run, args)
        mp.run(run_args)