def main(*args):
    """Train, benchmark, run inference with, and optionally export a model.

    Options are first collected from the parsed ``tf.flags`` and validated by
    ``check_args``. They are then layered with parameter files under
    ``parameters/``:
    ``root.<ext>`` is applied first and can be overridden by the
    model-specific file (``parameters/<model>.<ext>``, or the explicit file
    given by ``opt.p``).

    Args:
        *args: Not used by this body (presumably accepted so the function
            can serve as an app-runner callback -- TODO confirm).

    Raises:
        RuntimeError: If the dataset config file ``opt.data_config`` does
            not exist.
    """
    flags = tf.flags.FLAGS
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    check_args(opt)
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _suffix in ('json', 'yaml'):  # for compatibility
        # Apply a 2-stage (master-slave) configuration: values from the
        # shared root file can be overridden by the model-specific file.
        model_config_root = Path(f'parameters/root.{_suffix}')
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path(f'parameters/{opt.model}.{_suffix}')
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    # Default to an empty mapping: a model without its own parameter
    # section is then built with its defaults instead of failing on
    # ``opt.update(None)`` / ``**None``.
    model_params = opt.get(opt.model, {})
    opt.update(model_params)
    model = get_model(opt.model)(**model_params)
    root = '{}/{}'.format(opt.save_dir, model.name)
    if opt.comment:
        root += '_' + opt.comment
    opt.root = root
    verbosity = tf.logging.DEBUG if opt.v else tf.logging.INFO
    # The trainer class is taken from the model itself (`trainer`
    # attribute) rather than mapped manually.
    trainer = model.trainer
    train_data, test_data, infer_data = fetch_datasets(data_config_file, opt)
    train_config, test_config, infer_config = init_loader_config(opt)
    test_config.subdir = test_data.name
    infer_config.subdir = 'infer'
    # start fitting!
    dump(opt)
    with trainer(model, root, verbosity) as t:
        # prepare loaders; all share the same thread count
        loader = partial(QuickLoader, n_threads=opt.threads)
        train_loader = loader(train_data, 'train', train_config,
                              augmentation=True)
        val_loader = loader(train_data, 'val', train_config, crop='center',
                            steps_per_epoch=1)
        test_loader = loader(test_data, 'test', test_config)
        infer_loader = loader(infer_data, 'infer', infer_config)
        # fit
        t.fit([train_loader, val_loader], train_config)
        # validate
        t.benchmark(test_loader, test_config)
        # do inference
        t.infer(infer_loader, infer_config)
        if opt.export:
            t.export(opt.root + '/exported', opt.freeze)
Example #2
0
def main(*args):
    """Train, benchmark, run inference with, and optionally export a model.

    Options are collected from ``tf.flags`` (marked as parsed first) and
    validated by ``check_args``. They are then layered with
    ``parameters/root.<ext>`` followed by ``parameters/<model>.<ext>``, so
    model-specific values override the shared root values.

    Args:
        *args: Not used by this body (presumably accepted so the function
            can serve as an app-runner callback -- TODO confirm).

    Raises:
        RuntimeError: If the dataset config file ``opt.data_config`` does
            not exist.
    """
    flags = tf.flags.FLAGS
    flags.mark_as_parsed()
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    check_args(opt)
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _suffix in ('json', 'yaml'):
        # Apply a 2-stage (master-slave) configuration: the root (master)
        # values can be overridden by the model-specific (slave) file.
        model_config_root = Path('parameters/{}.{}'.format('root', _suffix))
        model_config_file = Path('parameters/{}.{}'.format(opt.model, _suffix))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    # Default to an empty mapping so a model without a parameter section is
    # built with its defaults instead of failing on ``**None``.
    model_params = opt.get(opt.model, {})
    opt.update(model_params)
    model = get_model(opt.model)(**model_params)
    root = '{}/{}_sc{}_c{}'.format(opt.save_dir, model.name, opt.scale, opt.channel)
    if opt.comment:
        root += '_' + opt.comment
    opt.root = root
    verbosity = tf.logging.DEBUG if opt.v else tf.logging.INFO
    # map model to trainer, manually
    if opt.model == 'zssr':
        trainer = ZSSR
    elif opt.model == 'frvsr':
        trainer = FRVSR
    else:
        trainer = VSR
    train_data, test_data, infer_data = fetch_datasets(data_config_file, opt)
    train_config, test_config, infer_config = init_loader_config(opt)
    test_config.subdir = test_data.name
    # start fitting!
    with trainer(model, root, verbosity) as t:
        # prepare loaders; all share the same thread count
        loader = partial(QuickLoader, n_threads=opt.threads)
        train_loader = loader(train_data, 'train', train_config, augmentation=True)
        # Validation deliberately runs without random augmentation so the
        # per-epoch validation metric is measured on unaltered samples.
        val_loader = loader(train_data, 'val', train_config, crop='center', steps_per_epoch=1)
        test_loader = loader(test_data, 'test', test_config)
        infer_loader = loader(infer_data, 'infer', infer_config)
        # fit
        t.fit([train_loader, val_loader], train_config)
        # validate
        t.benchmark(test_loader, test_config)
        # do inference
        t.infer(infer_loader, infer_config)
        if opt.export:
            t.export(opt.root)
Example #3
0
def main():
    """Train a model on one dataset and optionally export it.

    Known command-line options come from ``parser``. They are layered with
    ``Parameters/root.<ext>`` and then the model-specific parameter file
    (``Parameters/<model>.<ext>``, or ``opt.p`` when given). Unknown
    command-line arguments are passed to ``suppress_opt_by_args`` together
    with the model's parameter section (presumably to override individual
    model parameters -- TODO confirm).

    Raises:
        RuntimeError: If the dataset config file ``flags.data_config`` does
            not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)

    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # Apply a 2-stage (master-slave) configuration: the root (master)
        # values can be overridden by the model-specific (slave) file.
        model_config_root = Path('Parameters/root.{}'.format(_ext))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path('Parameters/{}.{}'.format(
                opt.model, _ext))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    model_params = opt.get(opt.model, {})
    # Apply command-line overrides BEFORE merging into `opt`. Otherwise the
    # model would be built with the overridden values while `opt` (and
    # therefore `train_config`, which is built from **opt) keeps the stale
    # file values.
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    model = get_model(flags.model)(**model_params)
    if flags.cuda:
        model.cuda()
    root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        root += '_' + flags.comment
    verbosity = logging.DEBUG if flags.verbose else logging.INFO
    trainer = model.trainer

    datasets = load_datasets(data_config_file)
    dataset = datasets[flags.dataset.upper()]

    train_config = Config(crop=opt.train_data_crop,
                          feature_callbacks=[],
                          label_callbacks=[],
                          convert_to='rgb',
                          **opt)
    if opt.channel == 1:
        train_config.convert_to = 'gray'
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    train_config.random_val = not opt.traced_val
    train_config.cuda = flags.cuda

    if opt.verbose:
        dump(opt)
    with trainer(model, root, verbosity, opt.pth) as t:
        if opt.seed is not None:
            t.set_seed(opt.seed)
        tloader = QuickLoader(dataset, 'train', train_config, True,
                              flags.thread)
        # Validation uses batch 1, its own crop and a fixed number of steps.
        vloader = QuickLoader(dataset,
                              'val',
                              train_config,
                              False,
                              flags.thread,
                              batch=1,
                              crop=opt.val_data_crop,
                              steps_per_epoch=opt.val_num)
        t.fit([tloader, vloader], train_config)
        if opt.export:
            t.export(opt.export)
Example #4
0
def main():
  """Benchmark or run inference with a trained model on test data.

  Each entry in ``flags.test`` is first looked up as a dataset name in the
  dataset config. If any lookup fails, every entry is instead treated as a
  file glob pattern, and plain inference (no benchmark) is run on the
  matched images. Outputs are saved via the ``save_image`` callback under
  the run's root directory.

  Raises:
    RuntimeError: If the dataset config file ``flags.data_config`` does not
      exist.
  """
  flags, args = parser.parse_known_args()
  opt = Config()
  for pair in flags._get_kwargs():
    opt.setdefault(*pair)
  data_config_file = Path(flags.data_config)
  if not data_config_file.exists():
    raise RuntimeError("dataset config file doesn't exist!")
  for _ext in ('json', 'yaml', 'yml'):  # for compat
    # Apply a 2-stage (master-slave) configuration: the root (master) values
    # can be overridden by the model-specific (slave) file.
    model_config_root = Path('Parameters/root.{}'.format(_ext))
    if opt.p:
      model_config_file = Path(opt.p)
    else:
      model_config_file = Path('Parameters/{}.{}'.format(opt.model, _ext))
    if model_config_root.exists():
      opt.update(Config(str(model_config_root)))
    if model_config_file.exists():
      opt.update(Config(str(model_config_file)))

  model_params = opt.get(opt.model, {})
  # Command-line overrides are applied before merging, so `opt` and the
  # model constructor see the same values.
  suppress_opt_by_args(model_params, *args)
  opt.update(model_params)
  model = get_model(flags.model)(**model_params)
  if flags.cuda:
    model.cuda()
  root = f'{flags.save_dir}/{flags.model}'
  if flags.comment:
    root += '_' + flags.comment
  verbosity = logging.DEBUG if flags.verbose else logging.INFO
  trainer = model.trainer

  datasets = load_datasets(data_config_file)
  try:
    test_datas = [datasets[t.upper()] for t in flags.test]
    run_benchmark = True
  except KeyError:
    # Not (all) known dataset names: treat each entry as a glob pattern.
    test_datas = []
    for pattern in flags.test:
      test_data = Dataset(test=_glob_absolute_pattern(pattern),
                          mode='pil-image1', modcrop=False)
      # Walk up from *this* pattern (not the whole `flags.test` list, which
      # `Path` cannot accept) to its nearest existing directory, and use
      # that directory's name as the output sub-folder. Stop at the
      # filesystem root, where `parent == self`.
      father = Path(pattern)
      while not father.is_dir():
        if father.parent == father:
          break
        father = father.parent
      test_data.name = father.stem
      test_datas.append(test_data)
    run_benchmark = False

  if opt.verbose:
    dump(opt)
  for test_data in test_datas:
    loader_config = Config(convert_to='rgb',
                           feature_callbacks=[], label_callbacks=[],
                           output_callbacks=[], **opt)
    loader_config.batch = 1
    loader_config.subdir = test_data.name
    loader_config.output_callbacks += [
      save_image(root, flags.output_index, flags.auto_rename)]
    if opt.channel == 1:
      loader_config.convert_to = 'gray'

    with trainer(model, root, verbosity, flags.pth) as t:
      if flags.seed is not None:
        t.set_seed(flags.seed)
      loader = QuickLoader(test_data, 'test', loader_config,
                           n_threads=flags.thread)
      loader_config.epoch = flags.epoch
      if run_benchmark:
        t.benchmark(loader, loader_config)
      else:
        t.infer(loader, loader_config)