Example #1
def main():
    # Init logger
    if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'seed-{}-{}.log'.format(args.manualSeed,
                                             time_for_file())), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    print_log('------------ Options -------------', log)
    for k, v in sorted(vars(args).items()):
        print_log('Parameter : {:20} = {:}'.format(k, v), log)
    print_log('-------------- End ----------------', log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("Pillow version : {}".format(PIL.__version__), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # General Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.5, 0.5, 0.5]])
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(
        args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [
        transforms.TrainScale2WH((args.crop_width, args.crop_height))
    ]
    train_transform += [
        transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)
    ]
    #if args.arg_flip:
    #  train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [
        transforms.AugCrop(args.crop_width, args.crop_height,
                           args.crop_perturb_max, mean_fill)
    ]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])
    assert (
        args.scale_min + args.scale_max
    ) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(
        args.scale_min, args.scale_max, args.scale_eval)

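    # Heatmap settings: ground-truth heatmaps are produced at 1/8 of the input resolution,
    # and sigma is rescaled to the evaluation scale.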
    args.downsample = 8  # By default
    args.sigma = args.sigma * args.scale_eval

    train_data = datasets.GeneralDataset(train_transform, args.sigma,
                                         args.downsample, args.heatmap_type,
                                         args.dataset_name)
    train_data.load_list(args.train_list, args.num_pts, True)
    if args.convert68to49:
        train_data.convert68to49()
    elif args.convert68to51:
        train_data.convert68to51()
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

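    # Build one evaluation loader per list in args.eval_lists (no shuffling).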
    eval_loaders = []
    if args.eval_lists is not None:
        for eval_list in args.eval_lists:
            eval_data = datasets.GeneralDataset(eval_transform, args.sigma,
                                                args.downsample,
                                                args.heatmap_type,
                                                args.dataset_name)
            eval_data.load_list(eval_list, args.num_pts, True)
            if args.convert68to49:
                eval_data.convert68to49()
            elif args.convert68to51:
                eval_data.convert68to51()
            eval_loader = torch.utils.data.DataLoader(
                eval_data,
                batch_size=args.eval_batch,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append(eval_loader)

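    # When converting 68-point annotations to 49 or 51 points, validate the flags and update num_pts.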
    if args.convert68to49 or args.convert68to51:
        assert args.num_pts == 68, 'The format of num-pts is not right : {}'.format(
            args.num_pts)
        assert args.convert68to49 + args.convert68to51 == 1, 'Only support one convert'
        if args.convert68to49: args.num_pts = 49
        else: args.num_pts = 51

    args.modelconfig = models.ModelConfig(train_data.NUM_PTS + 1,
                                          args.cpm_stage, args.pretrain,
                                          args.argmax_size)

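    # If no pre-trained cycle model path is given, train one first on the unpaired A/B image lists.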
    if args.cycle_model_path is None:
        # define the network
        itnetwork = models.itn_model(args.modelconfig, args, log)

        cycledata = datasets.CycleDataset(train_transform, args.dataset_name)
        cycledata.set_a(args.cycle_a_lists)
        cycledata.set_b(args.cycle_b_lists)
        print_log('Cycle-data initialize done : {}'.format(cycledata), log)

        args.cycle_model_path = procedure.train_cycle_gan(
            cycledata, itnetwork, args, log)
    assert osp.isdir(
        args.cycle_model_path), '{} does not exist or is not dir.'.format(
            args.cycle_model_path)

    # start train itn-cpm model
    itn_cpm = models.__dict__[args.arch](args.modelconfig,
                                         args.cycle_model_path)
    procedure.train_san_epoch(args, itn_cpm, train_loader, eval_loaders, log)

    log.close()
Example #2
def main():
    # Init logger
    if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
    log = open(
        os.path.join(
            args.save_path,
            'cluster_seed_{}_{}.txt'.format(args.manualSeed, time_for_file())),
        'w')
    print_log('save path : {}'.format(args.save_path), log)
    print_log('------------ Options -------------', log)
    for k, v in sorted(vars(args).items()):
        print_log('Parameter : {:20} = {:}'.format(k, v), log)
    print_log('-------------- End ----------------', log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("Pillow version : {}".format(PIL.__version__), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # General Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])

    args.downsample = 8  # By default
    args.sigma = args.sigma * args.scale_eval

    data = datasets.GeneralDataset(transform, args.sigma, args.downsample,
                                   args.heatmap_type, args.dataset_name)
    data.load_list(args.train_list, args.num_pts, True)
    loader = torch.utils.data.DataLoader(data,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

    # Load all lists
    all_lines = {}
    for file_path in args.train_list:
        listfile = open(file_path, 'r')
        listdata = listfile.read().splitlines()
        listfile.close()
        for line in listdata:
            temp = line.split(' ')
            assert len(temp) == 6 or len(
                temp) == 7, 'This line has the wrong format : {}'.format(line)
            image_path = temp[0]
            all_lines[image_path] = line

    assert args.n_clusters >= 2, 'The cluster number must be at least 2'
    resnet = models.resnet152(True).cuda()
    all_features = []
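    # Run ResNet-152 over every image and collect its features for the k-means clustering below.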
    for i, (inputs, target, mask, points, image_index, label_sign,
            ori_size) in enumerate(loader):
        with torch.no_grad():  # inference only; volatile Variables are deprecated
            features, classifications = resnet(inputs.cuda())
        features = features.cpu().numpy()
        all_features.append(features)
        if i % args.print_freq == 0:
            print_log(
                '{} {}/{} extract features'.format(time_string(), i,
                                                   len(loader)), log)
    all_features = np.concatenate(all_features, axis=0)
    kmeans_result = KMeans(n_clusters=args.n_clusters,
                           n_jobs=args.workers).fit(all_features)
    print_log('kmeans [{}] calculate done'.format(args.n_clusters), log)
    labels = kmeans_result.labels_.copy()

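    # Count the members of each cluster and order the cluster labels by size (smallest first).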
    cluster_idx = []
    for iL in range(args.n_clusters):
        indexes = np.where(labels == iL)[0]
        cluster_idx.append(len(indexes))
    cluster_idx = np.argsort(cluster_idx)

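    # For each cluster, drop outliers with filter_cluster and write the remaining annotation lines
    # to a per-cluster .lst file.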
    for iL in range(args.n_clusters):
        ilabel = cluster_idx[iL]
        indexes = np.where(labels == ilabel)
        if isinstance(indexes, tuple) or isinstance(indexes, list):
            indexes = indexes[0]
        cluster_features = all_features[indexes, :].copy()
        filtered_index = filter_cluster(indexes.copy(), cluster_features, 0.8)

        print_log(
            '{:} [{:2d} / {:2d}] has {:4d} / {:4d} -> {:4d} = {:.2f} images '.
            format(time_string(), iL, args.n_clusters, indexes.size, len(data),
                   len(filtered_index), indexes.size * 1. / len(data)), log)
        indexes = filtered_index.copy()
        save_dir = osp.join(
            args.save_path,
            'cluster-{:02d}-{:02d}'.format(iL, args.n_clusters))
        save_path = save_dir + '.lst'
        #if not osp.isdir(save_path): os.makedirs( save_path )
        print_log('save into {}'.format(save_path), log)
        txtfile = open(save_path, 'w')
        for idx in indexes:
            image_path = data.datas[idx]
            assert image_path in all_lines, 'Cannot find {}'.format(image_path)
            txtfile.write('{}\n'.format(all_lines[image_path]))
            #basename = osp.basename( image_path )
            #os.system( 'cp {} {}'.format(image_path, save_dir) )
        txtfile.close()
Example #3
def main(xargs):

    save_dir = Path(xargs.log_dir) / time_for_file()
    save_dir.mkdir(parents=True, exist_ok=True)

    print('save dir : {:}'.format(save_dir))
    print('xargs : {:}'.format(xargs))

    if xargs.dataset == 'cifar-10':
        train_data = reader_creator(xargs.data_path, 'data_batch', True, False)
        test__data = reader_creator(xargs.data_path, 'test_batch', False,
                                    False)
        class_num = 10
        print('create cifar-10  dataset')
    elif xargs.dataset == 'cifar-100':
        train_data = reader_creator(xargs.data_path, 'train', True, False)
        test__data = reader_creator(xargs.data_path, 'test', False, False)
        class_num = 100
        print('create cifar-100 dataset')
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))

    train_reader = paddle.batch(paddle.reader.shuffle(train_data,
                                                      buf_size=5000),
                                batch_size=xargs.batch_size)

    # Reader for testing: a separate dataset from the training set.
    test_reader = paddle.batch(test__data, batch_size=xargs.batch_size)

    place = fluid.CUDAPlace(0)

    main_program = fluid.default_main_program()
    star_program = fluid.default_startup_program()

    # programs
    predict = inference_program(xargs.model_name, class_num)
    [loss, accuracy] = train_program(predict)
    print('training program setup done')
    test_program = main_program.clone(for_test=True)
    print('testing  program setup done')

    #infer_writer = SummaryWriter( str(save_dir / 'infer') )
    #infer_writer.add_paddle_graph(fluid_program=fluid.default_main_program(), verbose=True)
    #infer_writer.close()
    #print(test_program.to_string(True))

    #learning_rate = fluid.layers.cosine_decay(learning_rate=xargs.lr, step_each_epoch=xargs.step_each_epoch, epochs=xargs.epochs)
    #learning_rate = fluid.layers.cosine_decay(learning_rate=0.1, step_each_epoch=196, epochs=300)
    learning_rate = cosine_decay_with_warmup(xargs.lr, xargs.step_each_epoch,
                                             xargs.epochs)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=0.9,
        regularization=fluid.regularizer.L2Decay(0.0005),
        use_nesterov=True)
    optimizer.minimize(loss)

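    # Create the executor and data feeder, then run the startup program to initialize parameters.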
    exe = fluid.Executor(place)

    feed_var_list_loop = [
        main_program.global_block().var('pixel'),
        main_program.global_block().var('label')
    ]
    feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
    exe.run(star_program)

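    # Main training loop: feed mini-batches, track running loss/accuracy, evaluate after every epoch,
    # and export an inference model.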
    start_time, epoch_time = time.time(), AverageMeter()
    for iepoch in range(xargs.epochs):
        losses, accuracies, steps = AverageMeter(), AverageMeter(), 0
        for step_id, train_data in enumerate(train_reader()):
            tloss, tacc, xlr = exe.run(
                main_program,
                feed=feeder.feed(train_data),
                fetch_list=[loss, accuracy, learning_rate])
            tloss, tacc, xlr = float(tloss), float(tacc) * 100, float(xlr)
            steps += 1
            losses.update(tloss, len(train_data))
            accuracies.update(tacc, len(train_data))
            if step_id % 100 == 0:
                print(
                    '{:} [{:03d}/{:03d}] [{:03d}] lr = {:.7f}, loss = {:.4f} ({:.4f}), accuracy = {:.2f} ({:.2f}), error={:.2f}'
                    .format(time_string(), iepoch, xargs.epochs, step_id, xlr,
                            tloss, losses.avg, tacc, accuracies.avg,
                            100 - accuracies.avg))
        test_loss, test_acc = evaluation(test_program, test_reader,
                                         [loss, accuracy], place)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.avg * (xargs.epochs - iepoch), True))
        print(
            '{:}x[{:03d}/{:03d}] {:} train-loss = {:.4f}, train-accuracy = {:.2f}, test-loss = {:.4f}, test-accuracy = {:.2f} test-error = {:.2f} [{:} steps per epoch]\n'
            .format(time_string(), iepoch, xargs.epochs, need_time, losses.avg,
                    accuracies.avg, test_loss, test_acc, 100 - test_acc,
                    steps))
        if isinstance(predict, list):
            fluid.io.save_inference_model(str(save_dir / 'inference_model'),
                                          ["pixel"], predict, exe)
        else:
            fluid.io.save_inference_model(str(save_dir / 'inference_model'),
                                          ["pixel"], [predict], exe)
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    print('finish training and evaluation with {:} epochs in {:}'.format(
        xargs.epochs, convert_secs2time(epoch_time.sum, True)))
Example #4
def main():

  # Init logger
  if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'cluster_seed_{}_{}.txt'.format(args.manualSeed, time_for_file())), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  print_log('------------ Options -------------', log)
  for k, v in sorted(vars(args).items()):
    print_log('Parameter : {:20} = {:}'.format(k, v), log)
  print_log('-------------- End ----------------', log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Pillow version : {}".format(PIL.__version__), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # finetune resnet-152 to train style-discriminative features
  resnet = models.resnet152(True, num_classes=4)
  resnet = torch.nn.DataParallel(resnet)
  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(resnet.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.decay)
  cls_train_dir = args.style_train_root
  cls_eval_dir = args.style_eval_root
  assert osp.isdir(cls_train_dir), 'Not a directory : {}'.format(cls_train_dir)
  # train data loader
  vision_normalize = visiontransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  print_log('Training dir : {}'.format(cls_train_dir), log)
  print_log('Evaluate dir : {}'.format(cls_eval_dir), log)
  cls_train_dataset = visiondatasets.ImageFolder(
        cls_train_dir,
        visiontransforms.Compose([
            visiontransforms.RandomResizedCrop(224),
            visiontransforms.RandomHorizontalFlip(),
            visiontransforms.ToTensor(),
            vision_normalize,
        ]))

  cls_train_loader = torch.utils.data.DataLoader(
        cls_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

  if cls_eval_dir is not None:
    assert osp.isdir(cls_eval_dir), 'Not a directory : {}'.format(cls_eval_dir)
    val_loader = torch.utils.data.DataLoader(
        visiondatasets.ImageFolder(cls_eval_dir, visiontransforms.Compose([
            visiontransforms.Resize(256),
            visiontransforms.CenterCrop(224),
            visiontransforms.ToTensor(),
            vision_normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
  else: val_loader = None

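  # Fine-tune the style classifier: one training pass per epoch, evaluating on the held-out set when available.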
  for epoch in range(args.epochs):
    learning_rate = adjust_learning_rate(optimizer, epoch, args)
    print_log('epoch : [{}/{}] lr={}'.format(epoch, args.epochs, learning_rate), log)
    top1, losses = AverageMeter(), AverageMeter()
    resnet.train()
    for i, (inputs, target) in enumerate(cls_train_loader):
      #target = target.cuda(async=True)
      # compute output
      _, output = resnet(inputs)
      loss = criterion(output, target)

      # measure accuracy and record loss
      prec1 = accuracy(output.data, target, topk=[1])
      top1.update(prec1.item(), inputs.size(0))
      losses.update(loss.item(), inputs.size(0))
      # compute gradient and do SGD step
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      if i % args.print_freq == 0 or i+1 == len(cls_train_loader):
        print_log(' [Train={:03d}] [{:}] [{:3d}/{:3d}] accuracy : {:.1f}, loss : {:.4f}, input:{:}, output:{:}'.format(epoch, time_string(), i, len(cls_train_loader), top1.avg, losses.avg, inputs.size(), output.size()), log)

    if val_loader is None: continue

    # evaluate the model
    resnet.eval()
    top1, losses = AverageMeter(), AverageMeter()
    for i, (inputs, target) in enumerate(val_loader):
      #target = target.cuda(async=True)
      # compute output
      with torch.no_grad():
        _, output = resnet(inputs)
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=[1])
      top1.update(prec1.item(), inputs.size(0))
      losses.update(loss.item(), inputs.size(0))
      if i % args.print_freq_eval == 0 or i+1 == len(val_loader):
        print_log(' [Evalu={:03d}] [{:}] [{:3d}/{:3d}] accuracy : {:.1f}, loss : {:.4f}, input:{:}, output:{:}'.format(epoch, time_string(), i, len(val_loader), top1.avg, losses.avg, inputs.size(), output.size()), log)
    

  # extract features
  resnet.eval()
  # General Data Augmentation
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
  transform  = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),  transforms.ToTensor(), normalize])

  args.downsample = 8 # By default
  args.sigma = args.sigma * args.scale_eval
  data = datasets.GeneralDataset(transform, args.sigma, args.downsample, args.heatmap_type, args.dataset_name)
  data.load_list(args.train_list, args.num_pts, True)
  loader = torch.utils.data.DataLoader(data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

  # Load all lists
  all_lines = {}
  for file_path in args.train_list:
    listfile = open(file_path, 'r')
    listdata = listfile.read().splitlines()
    listfile.close()
    for line in listdata:
      temp = line.split(' ')
      assert len(temp) == 6  or len(temp) == 7, 'This line has the wrong format : {}'.format(line)
      image_path = temp[0]
      all_lines[ image_path ] = line

  assert args.n_clusters >= 2, 'The cluster number must be at least 2'
  all_features = []
  for i, (inputs, target, mask, points, image_index, label_sign, ori_size) in enumerate(loader):
    with torch.no_grad():
      features, _ = resnet(inputs)
      features = features.cpu().numpy()
    all_features.append( features )
    if i % args.print_freq == 0:
      print_log('{} {}/{} extract features'.format(time_string(), i, len(loader)), log)
  all_features = np.concatenate(all_features, axis=0)
  kmeans_result = KMeans(n_clusters=args.n_clusters, n_jobs=args.workers).fit( all_features )
  print_log('kmeans [{}] calculate done'.format(args.n_clusters), log)
  labels = kmeans_result.labels_.copy()

  cluster_idx = []
  for iL in range(args.n_clusters):
    indexes = np.where( labels == iL )[0]
    cluster_idx.append( len(indexes) )
  cluster_idx = np.argsort(cluster_idx)
    
  for iL in range(args.n_clusters):
    ilabel = cluster_idx[iL]
    indexes = np.where( labels == ilabel )
    if isinstance(indexes, tuple) or isinstance(indexes, list): indexes = indexes[0]
    #cluster_features = all_features[indexes,:].copy()
    #filtered_index = filter_cluster(indexes.copy(), cluster_features, 0.8)
    filtered_index = indexes.copy()

    print_log('{:} [{:2d} / {:2d}] has {:4d} / {:4d} -> {:4d} = {:.2f} images'.format(time_string(), iL, args.n_clusters, indexes.size, len(data), len(filtered_index), indexes.size*1./len(data)), log)
    indexes = filtered_index.copy()
    save_dir = osp.join(args.save_path, 'cluster-{:02d}-{:02d}'.format(iL, args.n_clusters))
    save_path = save_dir + '.lst'
    #if not osp.isdir(save_path): os.makedirs( save_path )
    print_log('save into {}'.format(save_path), log)
    txtfile = open( save_path , 'w')
    for idx in indexes:
      image_path = data.datas[idx]
      assert image_path in all_lines, 'Cannot find {}'.format(image_path)
      txtfile.write('{}\n'.format(all_lines[image_path]))
      #basename = osp.basename( image_path )
      #os.system( 'cp {} {}'.format(image_path, save_dir) )
    txtfile.close()
Example #5
def main(args):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    prepare_seed(args.rand_seed)
    logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
    logger = Logger(args.save_path, logstr)
    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow  version : {}".format(PIL.__version__))
    logger.log("PyTorch version : {}".format(torch.__version__))
    logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

    # General Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
    train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
    # if args.arg_flip:
    #  train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    eval_transform = transforms.Compose(
        [transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),
         transforms.ToTensor(), normalize])
    assert (args.scale_min + args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(
        args.scale_min, args.scale_max, args.scale_eval)

    # Model Configure Load
    model_config = load_configure(args.model_config, logger)
    args.sigma = args.sigma * args.scale_eval
    logger.log('Real Sigma : {:}'.format(args.sigma))

    # Training Dataset
    train_data = GeneralDataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
    train_data.load_list(args.train_lists, args.num_pts, True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                               shuffle=True,num_workers=args.workers,
                                               pin_memory=True)
    # Evaluation Dataloader
    eval_loaders = []

    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = GeneralDataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type,
                                 args.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, True)
            eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                       num_workers=args.workers, pin_memory=True)
            eval_loaders.append((eval_iloader, False))

    # Define network
    logger.log('configure : {:}'.format(model_config))
    net = obtain_model(model_config, args.num_pts + 1)

    assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(
        model_config.downsample, net.downsample)
    logger.log("=> network :\n {}".format(net))

    logger.log('Training-data : {:}'.format(train_data))
    for i, eval_loader in enumerate(eval_loaders):
        eval_loader, is_video = eval_loader
        logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders),
                                                                          'video' if is_video else 'image',
                                                                          eval_loader.dataset))
    logger.log('arguments : {:}'.format(args))
    opt_config = load_configure(args.opt_config, logger)

    if hasattr(net, 'specify_parameter'):
        net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
    else:
        net_param_dict = net.parameters()

    optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
    logger.log('criterion : {:}'.format(criterion))
    net, criterion = net.cuda(), criterion.cuda()
    net = torch.nn.DataParallel(net)

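    # Resume from the last checkpoint if a last-info file exists; otherwise start from epoch 0.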
    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        last_info = torch.load(str(last_info))
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info,
                                                                                                     checkpoint[
                                                                                                         'epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(logger.last_info(), checkpoint['epoch']))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

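    # Optional evaluation-only mode: evaluate all loaders once and exit.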
    if args.eval_once:
        logger.log("=> only evaluate the model once")
        eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
        logger.close()
        return

    # Main Training and Evaluation Loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(start_epoch, opt_config.epochs):
        scheduler.step()
        need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs - epoch), True)
        epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
        LRs = scheduler.get_lr()
        logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str,
                                                                                            need_time, min(LRs),
                                                                                            max(LRs), opt_config))

        # train for one epoch
        train_loss, train_nme = train(args, train_loader, net, criterion,
                                      optimizer, epoch_str, logger, opt_config)
        # log the results
        logger.log(
            '==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss,
                                                                              train_nme * 100))

        # remember best prec@1 and save checkpoint
        save_path = save_checkpoint({
            'epoch': epoch,
            'args': deepcopy(args),
            'arch': model_config.arch,
            'state_dict': net.state_dict(),
            'scheduler': scheduler.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, str(logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str)), logger)

        last_info = save_checkpoint({
            'epoch': epoch,
            'last_checkpoint': save_path,
        }, str(logger.last_info()), logger)

        eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.close()