Example #1
# Imports assumed by this snippet: standard library plus PyTorch, and the
# repo-internal helpers (Logger, AverageMeter, time_string, convert_secs2time,
# evaluate_all_datasets) from the surrounding project.
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Text

import torch


def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
         splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any],
         to_evaluate_indexes: tuple, cover_mode: bool):

  log_dir = save_dir / 'logs'
  log_dir.mkdir(parents=True, exist_ok=True)
  logger = Logger(str(log_dir), os.getpid(), False)

  logger.log('xargs : seeds      = {:}'.format(seeds))
  logger.log('xargs : cover_mode = {:}'.format(cover_mode))
  logger.log('-' * 100)

  logger.log(
    'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
    + ' ({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
  for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
    logger.log(
      '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
  logger.log('--->>> optimization config : {:}'.format(opt_config))
  #to_evaluate_indexes = list(range(srange[0], srange[1] + 1))

  start_time, epoch_time = time.time(), AverageMeter()
  for i, index in enumerate(to_evaluate_indexes):
    channelstr = nets[index]
    logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(
      time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15))
    logger.log('{:} {:} {:}'.format('-' * 15, channelstr, '-' * 15))

    # test this arch on different datasets with different seeds
    has_continue = False
    for seed in seeds:
      to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
      if to_save_name.exists():
        if cover_mode:
          logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
          os.remove(str(to_save_name))
        else:
          logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
          has_continue = True
          continue
      results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
      torch.save(results, to_save_name)
      logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}]  ===>>> {:}'.format(
        time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
    # measure elapsed time
    if not has_continue: epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True))
    logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True)))
    logger.log('{:}'.format('*' * 100))
    logger.log('{:}   {:74s}   {:}'.format(
      '*' * 10,
      '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(
        i, len(to_evaluate_indexes), index, len(nets), need_time),
      '*' * 10))
    logger.log('{:}'.format('*' * 100))

  logger.close()
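A minimal invocation sketch for this entry point; all concrete values below are hypothetical placeholders, and the repo-internal helpers used inside (Logger, evaluate_all_datasets, and so on) must be importable from the surrounding project:

from pathlib import Path

nets = ['8:16:32', '16:32:64']  # hypothetical channel strings
main(save_dir=Path('./output/size-search'),
     workers=4,
     datasets=['cifar10'],
     xpaths=['./data/cifar10'],          # hypothetical dataset path
     splits=[0],
     seeds=[777, 888],
     nets=nets,
     opt_config={'epochs': 12},          # hypothetical optimization config
     to_evaluate_indexes=tuple(range(len(nets))),
     cover_mode=False)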
Example #2
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.deterministic = True
  #torch.backends.cudnn.benchmark = True
  torch.set_num_threads( workers )
  
  save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells'])
  logger   = Logger(str(save_dir), 0, False)
  if model_str in CellArchitectures:
    arch   = CellArchitectures[model_str]
    logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
  else:
    try:
      arch = CellStructure.str2structure(model_str)
    except Exception:
      raise ValueError('Invalid model string : {:}. It cannot be found or parsed.'.format(model_str))
  assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
  logger.log('Start train-evaluate {:}'.format(arch.tostr()))
  logger.log('arch_config : {:}'.format(arch_config))

  start_time, seed_time = time.time(), AverageMeter()
  for _is, seed in enumerate(seeds):
    logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed))
    to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
    if to_save_name.exists():
      logger.log('Find the existing file {:}, directly load!'.format(to_save_name))
      checkpoint = torch.load(to_save_name)
    else:
      logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
      torch.save(checkpoint, to_save_name)
    # log information
    logger.log('{:}'.format(checkpoint['info']))
    all_dataset_keys = checkpoint['all_dataset_keys']
    for dataset_key in all_dataset_keys:
      logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15))
      dataset_info = checkpoint[dataset_key]
      #logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
      logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param']))
      logger.log('config : {:}'.format(dataset_info['config']))
      logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
      last_epoch = dataset_info['total_epoch'] - 1
      train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
      valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
      logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch]))
    # measure elapsed time
    seed_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) )
    logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed, need_time))
  logger.close()
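For context, a hedged invocation sketch; the cell string follows the NAS-Bench-201-style encoding, and every concrete value here is an illustrative assumption:

train_single_model(save_dir='./output/specifics',
                   workers=4,
                   datasets=['cifar10'],
                   xpaths=['./data/cifar10'],  # hypothetical dataset path
                   splits=[0],
                   use_less=False,
                   seeds=[777],
                   model_str='|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_3x3~2|',
                   arch_config={'channel': 16, 'num_cells': 5})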
Example #3
def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  #torch.backends.cudnn.benchmark = True
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( workers )

  assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
  
  if use_less:
    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
  else:
    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
  logger  = Logger(str(sub_dir), 0, False)

  all_archs = meta_info['archs']
  assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total'])
  assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1])
  if arch_index == -1:
    to_evaluate_indexes = list(range(srange[0], srange[1]+1))
  else:
    to_evaluate_indexes = [arch_index]
  logger.log('xargs : seeds      = {:}'.format(seeds))
  logger.log('xargs : arch_index = {:}'.format(arch_index))
  logger.log('xargs : cover_mode = {:}'.format(cover_mode))
  logger.log('-'*100)

  logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode))
  for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
    logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
  logger.log('--->>> architecture config : {:}'.format(arch_config))
  

  start_time, epoch_time = time.time(), AverageMeter()
  for i, index in enumerate(to_evaluate_indexes):
    arch = all_archs[index]
    logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15))
    #logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
    logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15))
  
    # test this arch on different datasets with different seeds
    has_continue = False
    for seed in seeds:
      to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
      if to_save_name.exists():
        if cover_mode:
          logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
          os.remove(str(to_save_name))
        else:
          logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
          has_continue = True
          continue
      results = evaluate_all_datasets(CellStructure.str2structure(arch),
                                      datasets, xpaths, splits, use_less, seed,
                                      arch_config, workers, logger)
      torch.save(results, to_save_name)
      logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
    # measure elapsed time
    if not has_continue: epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) )
    logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) ))
    logger.log('{:}'.format('*'*100))
    logger.log('{:}   {:74s}   {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10))
    logger.log('{:}'.format('*'*100))

  logger.close()
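Both bounds of srange are inclusive (note the srange[1] + 1 in the range call above), so a worker given srange=(0, 99) evaluates exactly 100 architectures unless arch_index pins a single one. A sharding sketch under that convention, with an illustrative total of 15625 candidates split across four hypothetical jobs:

total, num_jobs = 15625, 4  # illustrative values
chunk = (total + num_jobs - 1) // num_jobs
ranges = [(j * chunk, min((j + 1) * chunk, total) - 1) for j in range(num_jobs)]
# -> [(0, 3906), (3907, 7813), (7814, 11720), (11721, 15624)]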
Example #4
def main(args):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  prepare_seed(args.rand_seed)

  logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
  logger = Logger(args.save_path, logstr)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow  version : {}".format(PIL.__version__))
  logger.log("PyTorch version : {}".format(torch.__version__))
  logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

  # General Data Augmentation
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
  assert not args.arg_flip, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
  train_transform  = [transforms.PreCrop(args.pre_crop_expand)]
  train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
  train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
  #if args.arg_flip:
  #  train_transform += [transforms.AugHorizontalFlip()]
  if args.rotate_max:
    train_transform += [transforms.AugRotate(args.rotate_max)]
  train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
  train_transform += [transforms.ToTensor(), normalize]
  train_transform  = transforms.Compose( train_transform )

  eval_transform  = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),  transforms.ToTensor(), normalize])
  assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)
  
  # Model Configure Load
  model_config = load_configure(args.model_config, logger)
  args.sigma   = args.sigma * args.scale_eval
  logger.log('Real Sigma : {:}'.format(args.sigma))

  # Training Dataset
  train_data   = Dataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
  train_data.load_list(args.train_lists, args.num_pts, True)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)


  # Evaluation Dataloader
  eval_loaders = []
  if args.eval_vlists is not None:
    for eval_vlist in args.eval_vlists:
      eval_vdata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_vdata.load_list(eval_vlist, args.num_pts, True)
      eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_vloader, True))

  if args.eval_ilists is not None:
    for eval_ilist in args.eval_ilists:
      eval_idata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_idata.load_list(eval_ilist, args.num_pts, True)
      eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_iloader, False))

  # Define network
  logger.log('configure : {:}'.format(model_config))
  net = obtain_model(model_config, args.num_pts + 1)
  assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
  logger.log("=> network :\n {}".format(net))

  logger.log('Training-data : {:}'.format(train_data))
  for i, eval_loader in enumerate(eval_loaders):
    eval_loader, is_video = eval_loader
    logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))
    
  logger.log('arguments : {:}'.format(args))

  opt_config = load_configure(args.opt_config, logger)

  if hasattr(net, 'specify_parameter'):
    net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
  else:
    net_param_dict = net.parameters()

  optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
  logger.log('criterion : {:}'.format(criterion))
  net, criterion = net.cuda(), criterion.cuda()
  net = torch.nn.DataParallel(net)
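  # Note: DataParallel wraps the model, so every key in net.state_dict() gains
  # a 'module.' prefix; checkpoints saved below inherit that naming.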

  last_info = logger.last_info()
  if last_info.exists():
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch'] + 1
    checkpoint  = torch.load(last_info['last_checkpoint'])
    assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done" .format(logger.last_info(), checkpoint['epoch']))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0


  if args.eval_once:
    logger.log("=> only evaluate the model once")
    eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
    logger.close() ; return


  # Main Training and Evaluation Loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(start_epoch, opt_config.epochs):

    scheduler.step()
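    # scheduler.step() at the top of the epoch follows the pre-1.1 PyTorch
    # convention this code targets; newer releases expect optimizer.step()
    # to run before scheduler.step() within each epoch.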
    need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
    epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
    LRs       = scheduler.get_lr()
    logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

    # train for one epoch
    train_loss, train_nme = train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
    # log the results    
    logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))

    # remember best prec@1 and save checkpoint
    save_path = save_checkpoint({
          'epoch': epoch,
          'args' : deepcopy(args),
          'arch' : model_config.arch,
          'state_dict': net.state_dict(),
          'scheduler' : scheduler.state_dict(),
          'optimizer' : optimizer.state_dict(),
          }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

    last_info = save_checkpoint({
          'epoch': epoch,
          'last_checkpoint': save_path,
          }, logger.last_info(), logger)

    eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
    
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.close()
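The resume logic above relies on a two-level checkpoint: a small last-info file holding the epoch number and the path of the full checkpoint, which in turn stores the model, optimizer, and scheduler state. A stripped-down sketch of the same pattern; the function name and file layout are hypothetical:

import torch
from pathlib import Path


def resume_from_last_info(net, optimizer, scheduler, last_info_path):
    # Return the epoch to resume from; 0 when no previous run exists.
    if not Path(last_info_path).exists():
        return 0
    last_info = torch.load(last_info_path)
    checkpoint = torch.load(last_info['last_checkpoint'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    return last_info['epoch'] + 1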
Example #5
def train_single_model(
    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    torch.set_num_threads(workers)

    save_dir = (
        Path(save_dir)
        / "specifics"
        / "{:}-{:}-{:}-{:}".format(
            "LESS" if use_less else "FULL",
            model_str,
            arch_config["channel"],
            arch_config["num_cells"],
        )
    )
    logger = Logger(str(save_dir), 0, False)
    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log(
            "The model string is found in pre-defined architecture dict : {:}".format(
                model_str
            )
        )
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except Exception:
            raise ValueError(
                "Invalid model string : {:}. It cannot be found or parsed.".format(
                    model_str
                )
            )
    assert arch.check_valid_op(
        get_search_spaces("cell", "full")
    ), "{:} has the invalid op.".format(arch)
    logger.log("Start train-evaluate {:}".format(arch.tostr()))
    logger.log("arch_config : {:}".format(arch_config))

    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log(
            "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
                _is, len(seeds), seed
            )
        )
        to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
        if to_save_name.exists():
            logger.log(
                "Find the existing file {:}, directly load!".format(to_save_name)
            )
            checkpoint = torch.load(to_save_name)
        else:
            logger.log(
                "Does not find the existing file {:}, train and evaluate!".format(
                    to_save_name
                )
            )
            checkpoint = evaluate_all_datasets(
                arch,
                datasets,
                xpaths,
                splits,
                use_less,
                seed,
                arch_config,
                workers,
                logger,
            )
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log("{:}".format(checkpoint["info"]))
        all_dataset_keys = checkpoint["all_dataset_keys"]
        for dataset_key in all_dataset_keys:
            logger.log(
                "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
            )
            dataset_info = checkpoint[dataset_key]
            # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log(
                "Flops = {:} MB, Params = {:} MB".format(
                    dataset_info["flop"], dataset_info["param"]
                )
            )
            logger.log("config : {:}".format(dataset_info["config"]))
            logger.log(
                "Training State (finish) = {:}".format(dataset_info["finish-train"])
            )
            last_epoch = dataset_info["total_epoch"] - 1
            train_acc1es, train_acc5es = (
                dataset_info["train_acc1es"],
                dataset_info["train_acc5es"],
            )
            valid_acc1es, valid_acc5es = (
                dataset_info["valid_acc1es"],
                dataset_info["valid_acc5es"],
            )
            logger.log(
                "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
                    train_acc1es[last_epoch],
                    train_acc5es[last_epoch],
                    100 - train_acc1es[last_epoch],
                    valid_acc1es[last_epoch],
                    valid_acc5es[last_epoch],
                    100 - valid_acc1es[last_epoch],
                )
            )
        # measure elapsed time
        seed_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
        )
        logger.log(
            "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
                _is, len(seeds), seed, need_time
            )
        )
    logger.close()
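Note the reproducibility settings at the top of this snippet: cuDNN is forced deterministic and benchmark mode stays off, trading throughput for run-to-run stability across seeds. A sketch that bundles those conventions together with RNG seeding; the helper name is hypothetical:

import random

import numpy as np
import torch


def set_reproducible(seed: int, workers: int) -> None:
    # Mirrors the settings above: deterministic cuDNN, a fixed thread count,
    # and seeded Python / NumPy / PyTorch RNGs.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_num_threads(workers)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)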
Example #6
def main(args):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    prepare_seed(args.rand_seed)

    logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
    logger = Logger(args.save_path, logstr)
    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow  version : {}".format(PIL.__version__))
    logger.log("PyTorch version : {}".format(torch.__version__))
    logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

    # General Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    assert not args.arg_flip, 'The flip is : {}, rotate is {}'.format(
        args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [
        transforms.TrainScale2WH((args.crop_width, args.crop_height))
    ]
    train_transform += [
        transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)
    ]
    #if args.arg_flip:
    #  train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [
        transforms.AugCrop(args.crop_width, args.crop_height,
                           args.crop_perturb_max, mean_fill)
    ]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])
    assert (
        args.scale_min + args.scale_max
    ) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(
        args.scale_min, args.scale_max, args.scale_eval)

    # Model Configure Load
    model_config = load_configure(args.model_config, logger)
    args.sigma = args.sigma * args.scale_eval
    logger.log('Real Sigma : {:}'.format(args.sigma))

    # Training Dataset
    train_data = VDataset(train_transform, args.sigma, model_config.downsample,
                          args.heatmap_type, args.data_indicator,
                          args.video_parser)
    train_data.load_list(args.train_lists, args.num_pts, True)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # Evaluation Dataloader
    eval_loaders = []
    if args.eval_vlists is not None:
        for eval_vlist in args.eval_vlists:
            eval_vdata = IDataset(eval_transform, args.sigma,
                                  model_config.downsample, args.heatmap_type,
                                  args.data_indicator)
            eval_vdata.load_list(eval_vlist, args.num_pts, True)
            eval_vloader = torch.utils.data.DataLoader(
                eval_vdata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_vloader, True))

    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = IDataset(eval_transform, args.sigma,
                                  model_config.downsample, args.heatmap_type,
                                  args.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, True)
            eval_iloader = torch.utils.data.DataLoader(
                eval_idata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_iloader, False))

    # Define network
    lk_config = load_configure(args.lk_config, logger)
    logger.log('model configure : {:}'.format(model_config))
    logger.log('LK configure : {:}'.format(lk_config))
    net = obtain_model(model_config, lk_config, args.num_pts + 1)
    assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(
        model_config.downsample, net.downsample)
    logger.log("=> network :\n {}".format(net))

    logger.log('Training-data : {:}'.format(train_data))
    for i, eval_loader in enumerate(eval_loaders):
        eval_loader, is_video = eval_loader
        logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(
            i, len(eval_loaders), 'video' if is_video else 'image',
            eval_loader.dataset))

    logger.log('arguments : {:}'.format(args))

    opt_config = load_configure(args.opt_config, logger)

    if hasattr(net, 'specify_parameter'):
        net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
    else:
        net_param_dict = net.parameters()

    optimizer, scheduler, criterion = obtain_optimizer(net_param_dict,
                                                       opt_config, logger)
    logger.log('criterion : {:}'.format(criterion))
    net, criterion = net.cuda(), criterion.cuda()
    net = torch.nn.DataParallel(net)

    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint['epoch'], (
            'Last-Info is not right {:} vs {:}'.format(
                last_info, checkpoint['epoch']))
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
            logger.last_info(), checkpoint['epoch']))
    elif args.init_model is not None:
        init_model = Path(args.init_model)
        assert init_model.exists(), 'init-model {:} does not exist'.format(
            init_model)
        checkpoint = torch.load(init_model)
        checkpoint = remove_module_dict(checkpoint['state_dict'], True)
        net.module.detector.load_state_dict(checkpoint)
        logger.log("=> initialize the detector : {:}".format(init_model))
        start_epoch = 0
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    detector = torch.nn.DataParallel(net.module.detector)

    eval_results = eval_all(args, eval_loaders, detector, criterion,
                            'start-eval', logger, opt_config)
    if args.eval_once:
        logger.log("=> only evaluate the model once")
        logger.close()
        return

    # Main Training and Evaluation Loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(start_epoch, opt_config.epochs):

        scheduler.step()
        need_time = convert_secs2time(
            epoch_time.avg * (opt_config.epochs - epoch), True)
        epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
        LRs = scheduler.get_lr()
        logger.log(
            '\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(
                time_string(), epoch_str, need_time, min(LRs), max(LRs),
                opt_config))

        # train for one epoch
        train_loss = train(args, train_loader, net, criterion, optimizer,
                           epoch_str, logger, opt_config, lk_config,
                           epoch >= lk_config.start)
        # log the results
        logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}'.format(
            time_string(), epoch_str, train_loss))

        # remember best prec@1 and save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch,
                'args': deepcopy(args),
                'arch': model_config.arch,
                'state_dict': net.state_dict(),
                'detector': detector.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            logger.path('model') /
            '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

        last_info = save_checkpoint(
            {
                'epoch': epoch,
                'last_checkpoint': save_path,
            }, logger.last_info(), logger)

        eval_results = eval_all(args, eval_loaders, detector, criterion,
                                epoch_str, logger, opt_config)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.close()
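All of these loops estimate remaining time the same way: an AverageMeter tracks the mean per-iteration wall time, and the estimate is that mean multiplied by the iterations left. A self-contained sketch of the pattern, assuming nothing from the repo:

import time

class AverageMeter:
    # Minimal running-average tracker mirroring how the snippets above use it.
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val):
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

start, meter, total = time.time(), AverageMeter(), 10
for i in range(total):
    time.sleep(0.01)  # stand-in for one epoch of work
    meter.update(time.time() - start)
    start = time.time()
    print('time left: {:.2f}s'.format(meter.avg * (total - i - 1)))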