コード例 #1
0
def parse_args():
    """Build and parse the command line for keypoint-network training.

    ``--cfg`` is read first with ``parse_known_args`` (unknown flags are
    tolerated at that point) so ``update_config`` can load the experiment
    file before ``--frequent`` is declared, because ``--frequent`` takes its
    default from ``config.PRINT_FREQ``.

    Returns:
        argparse.Namespace with ``cfg``, ``frequent``, ``gpus`` and
        ``workers``.
    """
    arg_parser = argparse.ArgumentParser(description='Train keypoints network')

    # general: the config must be loaded before options that read it.
    arg_parser.add_argument('--cfg', help='experiment configure file name',
                            required=True, type=str)
    early_args, _unknown = arg_parser.parse_known_args()
    update_config(early_args.cfg)

    # training
    arg_parser.add_argument('--frequent', help='frequency of logging',
                            default=config.PRINT_FREQ, type=int)
    arg_parser.add_argument('--gpus', help='gpus', type=str)
    arg_parser.add_argument('--workers', help='num of dataloader workers',
                            type=int)

    return arg_parser.parse_args()
コード例 #2
0
ファイル: train.py プロジェクト: 1061136002/2D-TAN
def parse_args():
    """Parse command-line options for training the localization network.

    ``--cfg`` is read first with ``parse_known_args`` and handed to
    ``update_config``; the remaining training options are then registered
    and the full command line is parsed.

    Returns:
        argparse.Namespace with ``cfg``, ``gpus``, ``workers``, ``dataDir``,
        ``modelDir``, ``logDir``, ``verbose`` and ``tag``.
    """
    parser = argparse.ArgumentParser(description='Train localization network')

    # general
    parser.add_argument(
        '--cfg', default="experiments/charades/2D-TAN-16x16-K5L8-pool.yaml")
    cfg_only = parser.parse_known_args()[0]
    update_config(cfg_only.cfg)

    # training
    parser.add_argument('--gpus', help='gpus', type=str)
    parser.add_argument('--workers', help='num of dataloader workers',
                        type=int)
    # Directory overrides.
    parser.add_argument('--dataDir', help='data path', type=str)
    parser.add_argument('--modelDir', help='model path', type=str)
    parser.add_argument('--logDir', help='log path', type=str)
    parser.add_argument('--verbose', default=False, action="store_true",
                        help='print progress bar')
    parser.add_argument('--tag', help='tags shown in log', type=str)

    return parser.parse_args()
コード例 #3
0
def parse_args():
    """Collect the command-line options used to train Ocean.

    The YAML file named by ``--cfg`` (default
    ``experiments/train/Ocean.yaml``) is loaded via ``update_config`` as soon
    as it is known; the remaining options are then registered and the whole
    command line is parsed.

    Returns:
        argparse.Namespace with ``cfg``, ``gpus``, ``workers``, ``WORKDIR``,
        ``CHECKPOINT_DIR``, ``OUTPUT_DIR``, ``SUPERNET_PATH`` and ``DP``.
    """
    parser = argparse.ArgumentParser(description='Train Ocean')

    # general
    parser.add_argument('--cfg', type=str,
                        default='experiments/train/Ocean.yaml',
                        help='yaml configure file name')
    known = parser.parse_known_args()[0]
    update_config(known.cfg)

    parser.add_argument('--gpus', type=str, help='gpus')
    parser.add_argument('--workers', type=int,
                        help='num of dataloader workers')
    # Path overrides; each is an empty string unless given.
    for path_flag in ('--WORKDIR', '--CHECKPOINT_DIR', '--OUTPUT_DIR',
                      '--SUPERNET_PATH'):
        parser.add_argument(path_flag, type=str, default='')
    parser.add_argument('--DP', type=int, default=0)

    return parser.parse_args()
コード例 #4
0
def _str2bool(value):
    """Interpret a command-line string as a boolean.

    Accepts yes/true/t/y/1 and no/false/f/n/0 (case-insensitive). Real bools
    are passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
        a clean usage error instead of silently accepting it.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


def parse_args():
    """Parse options for the keypoints network.

    A first pass (``parse_known_args``) reads the general options so that
    ``update_config`` can load ``--cfg``. Then the training options are
    added (``--frequent`` defaults to ``config.PRINT_FREQ``), and the full
    command line is parsed.

    Returns:
        argparse.Namespace. ``flip`` is always a bool: an explicit value such
        as ``--flip False`` is parsed as a boolean rather than kept as a
        (truthy) string.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    # general
    # It can be mobilenetv2, shufflenetv2, mnasnet, resnet18, darts, nasnet, pairnas
    parser.add_argument('--net',
                        help='network name (mobilenetv2, shufflenetv2, '
                             'mnasnet, resnet18, darts, nasnet, pairnas)',
                        default='pairnas',
                        type=str)
    # NOTE(review): the default extension is spelled '.ymal', not '.yaml';
    # kept as-is — confirm the file on disk really uses this spelling.
    parser.add_argument(
        '--cfg', default='experiments/coco_256x192_d256x3_adam_lr1e-3.ymal')
    parser.add_argument('--dataset_path',
                        default='/raid/huangsh/datasets/MSCOCO2017/')
    parser.add_argument('--gpu', type=str, default='0')
    # Without a type, any explicit value (even "False") became a truthy str.
    parser.add_argument('--flip', help='use flip test', default=True,
                        type=_str2bool)
    parser.add_argument('--post_process',
                        help='use post process',
                        action='store_true')
    known_args, _ = parser.parse_known_args()
    # update config
    update_config(known_args.cfg)
    # training
    parser.add_argument('--workers',
                        help='num of dataloader workers',
                        type=int,
                        default=16)
    parser.add_argument('--frequent',
                        help='frequency of logging',
                        default=config.PRINT_FREQ,
                        type=int)
    parser.add_argument('--use-detect-bbox',
                        help='use detect bbox',
                        action='store_true')
    parser.add_argument('--shift-heatmap',
                        help='shift heatmap',
                        action='store_true')
    parser.add_argument('--coco-bbox-file',
                        help='coco detection bbox file',
                        type=str)

    args = parser.parse_args()

    return args
コード例 #5
0
def parse_args():
    """Parse the command line for keypoint-network training/testing.

    ``--cfg`` is mandatory and is loaded with ``update_config`` during a
    first ``parse_known_args`` pass. ``--frequent`` defaults to
    ``config.PRINT_FREQ``, so it is declared only after that load.

    Returns:
        argparse.Namespace with the config path, logging frequency, device,
        worker count, model file, boolean test-time switches and the COCO
        detection bbox file.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')

    # general
    parser.add_argument('--cfg', help='experiment configure file name',
                        required=True, type=str)
    update_config(parser.parse_known_args()[0].cfg)

    # training
    parser.add_argument('--frequent', help='frequency of logging',
                        default=config.PRINT_FREQ, type=int)
    parser.add_argument('--gpus', help='gpus', type=str)
    parser.add_argument('--workers', help='num of dataloader workers',
                        type=int)
    parser.add_argument('--model-file', help='model state file', type=str)
    # On/off switches, registered in the same order as before.
    for flag, text in (('--use-detect-bbox', 'use detect bbox'),
                       ('--flip-test', 'use flip test'),
                       ('--post-process', 'use post process'),
                       ('--shift-heatmap', 'shift heatmap')):
        parser.add_argument(flag, help=text, action='store_true')
    parser.add_argument('--coco-bbox-file', help='coco detection bbox file',
                        type=str)

    return parser.parse_args()
コード例 #6
0
def pose_inference_on_image(cfg_path, image):
    """Run the 3D pose network on a single image.

    Loads the experiment config from ``cfg_path``, builds the pose3d ResNet,
    restores its weights from ``config.MODEL.RESUME`` and predicts joint
    locations for ``image``.

    Args:
        cfg_path: path of the experiment config passed to ``update_config``.
        image: 3-channel image array in BGR order (it is converted with
            ``cv2.COLOR_BGR2RGB``). It is resized to a square of
            ``config.MODEL.IMAGE_SIZE[0]`` pixels.

    Returns:
        (image, preds): the resized RGB image, and the first sample's output
        of ``get_joint_location_result`` restricted to its first three
        columns.
    """
    # Load config file
    update_config(cfg_path)
    torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED
    image_size = config.MODEL.IMAGE_SIZE[0]

    # Create Model
    model = models.pose3d_resnet.get_pose_net(config, is_train=False)
    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus)
    print('Created model...')

    checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    model.load_state_dict(checkpoint)
    # When CUDA is available, DataParallel requires the wrapped module's
    # parameters to live on device_ids[0]; left on the CPU, forward() raises.
    # (Inputs are scattered to the GPUs by DataParallel itself.)
    if torch.cuda.is_available():
        model = model.cuda(gpus[0])
    model.eval()
    print('Loaded pre-trained weights...')

    image = cv2.resize(image, (image_size, image_size))

    _, _, img_channels = image.shape
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_patch = convert_cvimg_to_tensor(image)

    mean = np.array([123.675, 116.280, 103.530])
    std = np.array([58.395, 57.120, 57.375])

    # Per-channel normalization. Assigning back into the slice keeps
    # img_patch's own dtype; assumes img_patch is a (C, H, W) ndarray —
    # TODO confirm against convert_cvimg_to_tensor.
    for n_c in range(img_channels):
        img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
    img_patch = torch.from_numpy(img_patch)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        preds = model(img_patch[None, ...])
        preds = get_joint_location_result(image_size, image_size,
                                          preds)[0, :, :3]

    return image, preds
コード例 #7
0
def parse_args():
    """Parse options for the keypoints network.

    The general options (``--net``, ``--dataset_path``, ``--cfg``,
    ``--model``, ``--gpu``) are pre-parsed so ``update_config`` can load
    ``--cfg``. Afterwards ``--frequent`` (default ``config.PRINT_FREQ``) and
    ``--workers`` (default 16) are added and the full command line is parsed.

    Returns:
        argparse.Namespace with all of the above options.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    # Network to build; it can be mobilenetv2, shufflenetv2, mnasnet, darts,
    # pairnas, nasnet
    parser.add_argument('--net', default='pairnas', type=str)
    parser.add_argument('--dataset_path',
                        default='/raid/huangsh/datasets/MSCOCO2017/')
    parser.add_argument(
        '--cfg', default='experiments/coco_256x192_d256x3_adam_lr1e-3.ymal')
    parser.add_argument('--model', default='posenet')
    parser.add_argument('--gpu', type=str, default='0')

    # Load the config before declaring options whose defaults read it.
    update_config(parser.parse_known_args()[0].cfg)

    # training
    parser.add_argument('--frequent', help='frequency of logging',
                        default=config.PRINT_FREQ, type=int)
    parser.add_argument('--workers', help='num of dataloader workers',
                        type=int, default=16)
    return parser.parse_args()
コード例 #8
0
import torch
from torch.utils.data import DataLoader

import lib.models.models as models
from lib.utils import print_speed, load_pretrain, save_model
from lib.dataset import SiamFCDataset
from lib.core.config import config, update_config
from lib.core.function import siamfc_train

# argparse and pprint are used below but are missing from this snippet's
# imports, which would raise NameError at startup.
import argparse
import pprint

# The only option: the YAML experiment config to load.
parser = argparse.ArgumentParser()
parser.add_argument('--cfg',
                    required=True,
                    type=str,
                    help='yaml configure file name')
args = parser.parse_args()
update_config(args.cfg)

# Echo the loaded config into the run's output.
print('Config:')
print(pprint.pformat(config))
print()

# Instantiate the model class named by config.SIAMFC.TRAIN.MODEL, then apply
# load_pretrain (presumably loads pretrained weights — confirm in lib.utils).
model = models.__dict__[config.SIAMFC.TRAIN.MODEL]()
model = load_pretrain(model, config.SIAMFC.TRAIN.PRETRAIN)

# Only hand parameters that require gradients to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.SGD(trainable_params,
                            config.SIAMFC.TRAIN.LR,
                            momentum=config.SIAMFC.TRAIN.MOMENTUM,
                            weight_decay=config.SIAMFC.TRAIN.WEIGHT_DECAY)
コード例 #9
0
def parse_args():
    """Parse options for training the unsupervised human pose estimation network.

    The config named by ``--cfg`` is loaded *first*, or the built-in config is
    kept when ``--cfg`` is empty. Only then are the remaining options
    declared, so the defaults of ``--data_dir``, ``--log_dir`` and
    ``--workers`` are read from the loaded config rather than from the
    pre-load defaults. After parsing, ``update_dir`` is called with the
    chosen model/log/data directories and the debug flag.

    Returns:
        argparse.Namespace with all parsed options.
    """
    cfg_default = '../cfg/default.yaml'

    # First pass: read only --cfg. add_help=False leaves -h/--help to the
    # full parser below, so the help output still lists every option.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument('--cfg', default=cfg_default)
    pre_args, _ = pre_parser.parse_known_args()
    if pre_args.cfg:
        update_config(pre_args.cfg)
    else:
        print("Using default config...")

    # Passed as description= (a positional string would set `prog`, which
    # replaces the program name in the usage line).
    parser = argparse.ArgumentParser(
        description="Train the unsupervised human pose estimation network")
    parser.add_argument(
        '--cfg',
        help="Specify the path of the config(*.yaml)",
        default=cfg_default)
    parser.add_argument(
        '--use_gt',
        action='store_true',
        help='Specify whether to use 2d gt / predictions as inputs')
    parser.add_argument('--model_dir',
                        help='Specify the directory of pretrained model',
                        default='')
    parser.add_argument('--data_dir',
                        help="Specify the directory of data",
                        default=config.DATA_DIR)
    parser.add_argument('--log_dir',
                        help='Specify the directory of output',
                        default=config.LOG_DIR)

    parser.add_argument('--dataset_name',
                        help="Specify which dataset to use",
                        choices=["h36m", "mpi"],
                        default="h36m")
    # type=int: without it an explicit value reached the loader as a str.
    parser.add_argument(
        '--workers',
        help="Specify the number of workers for data loading",
        type=int,
        default=config.NUM_WORKERS)
    parser.add_argument('--gpu',
                        help="Specify the gpu to use for training",
                        default='')
    parser.add_argument('--debug',
                        action='store_true',
                        help="Turn on the debug mode")
    parser.add_argument(
        '--print_info',
        action='store_true',
        help="Whether to print detailed information in tqdm processing")
    parser.add_argument(
        '--eval',
        action='store_true',
        help="Evaluate the model on the dataset(i.e. generate: joint_3d_pre)")
    parser.add_argument(
        '--eval_suffix',
        default=None,
        help="Specify the suffix to save predictions on 3D in evaluation mode")
    parser.add_argument('--pretrain',
                        default='',
                        help="Whether to use pretrain model")
    parser.add_argument('--finetune_rotater',
                        action='store_true',
                        help="Load pretrained model and finetune rotater")
    parser.add_argument('--print_interval', type=int, default=50)
    args = parser.parse_args()
    # NOTE(review): update_dir is not visible here; presumably it writes these
    # directories back into config. Since the defaults above now come from
    # the loaded cfg, a DATA_DIR/LOG_DIR set in the yaml is no longer
    # replaced by the pre-load defaults — confirm against update_dir.
    update_dir(args.model_dir, args.log_dir, args.data_dir, args.debug)
    return args
コード例 #10
0
import torch
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

import cv2
import numpy as np
import matplotlib.pyplot as plt

import argparse

# NOTE(review): `config`, `update_config` and `get_pose_net` are not imported
# in this snippet; presumably they come from the project's lib package —
# confirm the imports.

# Load the experiment config first so every setting below reads the
# configured values instead of the defaults in place before the load.
update_config('pretrained/384x288_d256x3_adam_lr1e-3.yaml')
# Force flip testing on. This is set after update_config so that a
# TEST.FLIP_TEST entry in the yaml cannot silently turn it off again.
config.TEST.FLIP_TEST = True

# `cudnn` is torch.backends.cudnn (imported under that alias above).
cudnn.benchmark = config.CUDNN.BENCHMARK
cudnn.deterministic = config.CUDNN.DETERMINISTIC
cudnn.enabled = config.CUDNN.ENABLED

model = get_pose_net(config, is_train=False)
model.load_state_dict(torch.load('pretrained/pose_resnet_50_384x288.pth.tar'))
# This script only runs inference, so put the network in evaluation mode.
model.eval()

gpus = [int(i) for i in config.GPUS.split(',')]
model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

# Per-channel normalization constants applied after ToTensor.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
toTensor = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean, std)])

def get_keypoints(input_image):
    '''
    Calculates keypoints based on resnet
    Input: Image