parser.add_argument('--opt-level', type=str, default='O1')
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)

args = parser.parse_args()
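
# Hedged sketch of how these three apex AMP flags are typically consumed
# (assumes NVIDIA apex is installed; `model` and `optimizer` are placeholders
# for objects built later in the script):
#     from apex import amp
#     model, optimizer = amp.initialize(model, optimizer,
#                                       opt_level=args.opt_level,
#                                       keep_batchnorm_fp32=args.keep_batchnorm_fp32,
#                                       loss_scale=args.loss_scale)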

# ###################################  Setup for some configurations ###########################################
opt = TrainingOpt()
config = GetConfig(opt.config_name)

limbSeq = config.limbs_conn
dt_gt_mapping = config.dt_gt_mapping
flip_heat_ord = config.flip_heat_ord
flip_paf_ord = config.flip_paf_ord
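# Hedged sketch of what the flip orders are likely for: with test-time
# horizontal flipping, left/right channels must be remapped before the two
# predictions are averaged (an assumption about this repo's convention;
# `heat` / `heat_f` are placeholder (C, H, W) arrays for the predictions on
# the original and the flipped image):
#     unflipped = heat_f[flip_heat_ord][:, :, ::-1]  # remap L/R channels, undo the flip
#     fused = (heat + unflipped) / 2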
# ###############################################################################################################
soureconfig = COCOSourceConfig(opt.hdf5_train_data)
train_data = MyDataset(config, soureconfig, shuffle=False,
                       augment=True)  # shuffle in data loader
train_loader = DataLoader(
    train_data,
    batch_size=opt.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=16,
    pin_memory=True
)  # tune num_workers per machine; too many or too few workers both hurt throughput
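
# Hedged sketch of the consuming loop (the exact tuple layout depends on
# MyDataset.__getitem__; the names below are illustrative):
#     for images, mask_misses, labels in train_loader:
#         images = images.cuda(non_blocking=True)  # pin_memory=True makes this copy asynchronous
#         ...forward/backward on the batch...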


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
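
# Usage sketch for AverageMeter (`loss` and `images` stand in for
# training-loop variables):
#     losses = AverageMeter()
#     losses.update(loss.item(), images.size(0))
#     print('batch %.4f, running avg %.4f' % (losses.val, losses.avg))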

parser.add_argument('--print-freq', '-p', type=int, default=10,
                    metavar='N',
                    help='print frequency (default: 10)')

args = parser.parse_args()
# ##############################################################################################################
# ###################################  Setup for some configurations ###########################################
# ##############################################################################################################
# If the input size stays the same at every training step, enabling cudnn benchmark speeds up training.
torch.backends.cudnn.benchmark = True
use_cuda = torch.cuda.is_available()
checkpoint_path = args.checkpoint_path

# > TOCHECK: training configs
opt = TrainingOpt()
config = GetConfig(opt.config_name)
soureconfig = COCOSourceConfig(opt.hdf5_train_data)  # > 512.h5
train_data = MyDataset(config, soureconfig, shuffle=False,
                       augment=True)  # shuffle in data loader

soureconfig_val = COCOSourceConfig(opt.hdf5_val_data)
val_data = MyDataset(config, soureconfig_val, shuffle=False,
                     augment=False)  # no augmentation for validation; shuffling, if any, is left to the loader

best_loss = float('inf')
start_epoch = 0  # start from scratch, or resume from the last saved epoch
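
# Hedged resume sketch; the checkpoint keys are assumptions and must match
# however the checkpoint is actually written:
#     if os.path.isfile(checkpoint_path):
#         ckpt = torch.load(checkpoint_path, map_location='cpu')
#         start_epoch = ckpt.get('epoch', 0)
#         best_loss = ckpt.get('best_loss', best_loss)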

args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

args.gpu = 0
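# Hedged sketch of the usual follow-up, mirroring NVIDIA apex's launcher
# examples (assumes a `--local_rank` argument supplied by
# `python -m torch.distributed.launch`):
#     if args.distributed:
#         args.gpu = args.local_rank
#         torch.cuda.set_device(args.gpu)
#         torch.distributed.init_process_group(backend='nccl', init_method='env://')
#         args.world_size = torch.distributed.get_world_size()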
            # show the generated ground truth
            if show_image:
                show_labels = cv2.resize(labels.transpose((1, 2, 0)),
                                         image.shape[1::-1],  # cv2 dsize is (width, height)
                                         interpolation=cv2.INTER_CUBIC)
                # offsets = cv2.resize(offsets.transpose((1, 2, 0)), image.shape[:2], interpolation=cv2.INTER_NEAREST)
                mask_miss = np.repeat(mask_miss.transpose((1, 2, 0)),
                                      3,
                                      axis=2)
                # mask_miss = cv2.resize(mask_miss, image.shape[:2], interpolation=cv2.INTER_NEAREST)
                image = cv2.resize(image,
                                   mask_miss.shape[1::-1],  # cv2 dsize is (width, height)
                                   interpolation=cv2.INTER_NEAREST)
                plt.imshow(image[:, :, [2, 1, 0]])  # Opencv image format: BGR
                plt.imshow(labels.transpose((1, 2, 0))[:, :, 20],
                           alpha=0.5)  # mask_all
                # plt.imshow(show_labels[:, :, 3], alpha=0.5)  # mask_all
                plt.show()
                t = 2  # no-op; handy line for setting a debugger breakpoint
        print("produce %d samples per second: " % (batch / (time() - start)))

if __name__ == "__main__":
    config = GetConfig("Canonical")
    soureconfig = COCOSourceConfig(
        "../data/dataset/coco/link2coco2017/coco_val_dataset384.h5")

    val_client = MyDataset(config, soureconfig, shuffle=True,
                           augment=True)  # shuffle in data loader
    # test the data generator
    test_augmentation_speed(val_client, True)