예제 #1
0
파일: train.py 프로젝트: wwwht/AtrousPose
def get_dataloader(train=True):
    """Build the AtrousPose data loader for the training or validation split.

    Relies on the module-level names ``json_path``, ``data_dir``, ``mask_dir``,
    ``BTACH_SIZE`` and ``params_transform`` defined elsewhere in the file.

    Args:
        train: ``True`` for the training split, ``False`` for validation.
            Also controls shuffling, so validation batches keep a fixed order.

    Returns:
        ``(dataloader, len(dataloader))`` -- the loader and its ``len()``.
    """
    dataloader = get_loader(json_path, data_dir,
                            mask_dir, 384, 8,  # 384 input size, 8 feat_stride
                            'atrous_pose', BTACH_SIZE,  # (sic) global name
                            params_transform=params_transform,  # a dictionary
                            shuffle=train, training=train, num_workers=1,
                            coco=False)

    split = 'train' if train else 'valid'
    print('{} dataloader len: {}'.format(split, len(dataloader)))

    return dataloader, len(dataloader)
예제 #2
0
    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        Args:
            val: latest value (presumably a per-batch metric such as a loss).
            n: weight of ``val``, presumably the number of samples it covers.

        Sets ``self.val`` to ``val``, adds ``val * n`` to ``self.sum`` and
        ``n`` to ``self.count``, then recomputes ``self.avg``. While
        ``self.count`` is zero (e.g. ``n=0`` on a fresh meter) ``self.avg`` is
        left unchanged instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count


print("Loading dataset...")
# Training loader: input size 256, feature stride 4, 'rtpose' preprocessing.
train_data = get_loader(
    args.json_path, args.data_dir, args.mask_dir,
    256, 4,
    preprocess='rtpose',
    batch_size=args.batch_size,
    params_transform=params_transform,
    training=True,
    shuffle=True,
    num_workers=16,
)
print('train dataset len: %d' % len(train_data.dataset))

# validation data
valid_data = get_loader(args.json_path,
                        args.data_dir,
                        args.mask_dir,
                        256,
                        4,
                        preprocess='rtpose',
                        params_transform=params_transform,
예제 #3
0
    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        Args:
            val: latest value (presumably a per-batch metric such as a loss).
            n: weight of ``val``, presumably the number of samples it covers.

        Sets ``self.val`` to ``val``, adds ``val * n`` to ``self.sum`` and
        ``n`` to ``self.count``, then recomputes ``self.avg``. While
        ``self.count`` is zero (e.g. ``n=0`` on a fresh meter) ``self.avg`` is
        left unchanged instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count


print("Loading dataset...")
# Training loader: input size 368, feature stride 8, 'vgg' preprocessing.
# TODO update data loaders
train_data = get_loader(
    args.json_path, args.data_dir, args.mask_dir,
    368, 8,
    preprocess='vgg',
    batch_size=args.batch_size,
    training=True,
    shuffle=True,
    params_transform=params_transform,
    num_workers=8,
)
print('train dataset len: %d' % len(train_data.dataset))

# validation data
valid_data = get_loader(args.json_path,
                        args.data_dir,
                        args.mask_dir,
                        368,
                        8,
                        preprocess='vgg',
                        training=False,
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        Args:
            val: latest value (presumably a per-batch metric such as a loss).
            n: weight of ``val``, presumably the number of samples it covers.

        Sets ``self.val`` to ``val``, adds ``val * n`` to ``self.sum`` and
        ``n`` to ``self.count``, then recomputes ``self.avg``. While
        ``self.count`` is zero (e.g. ``n=0`` on a fresh meter) ``self.avg`` is
        left unchanged instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count


print("Loading dataset...")
# Training loader: input size 368, feature stride 8, 'vgg' preprocessing.
train_data = get_loader(
    args.json_path, args.data_dir, args.mask_dir,
    368, 8,
    preprocess='vgg',
    batch_size=args.batch_size,
    training=True,
    shuffle=True,
)
print('train dataset len: %d' % len(train_data.dataset))

# validation data
valid_data = get_loader(args.json_path,
                        args.data_dir,
                        args.mask_dir,
                        368,
                        8,
                        preprocess='vgg',
                        training=False,
                        batch_size=args.batch_size,
예제 #5
0

def adjust_learning_rate(optimizer, epoch, base_lr=None, gamma=0.1,
                         step_size=30):
    """Set every param group's learning rate to a step-decayed value.

    The learning rate is ``base_lr * gamma ** (epoch // step_size)``. With the
    defaults this is the initial LR decayed by 10 every 30 epochs.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts with
            an ``'lr'`` key (presumably a torch optimizer).
        epoch: current epoch index.
        base_lr: initial learning rate; defaults to the module-level
            ``args.lr`` when ``None``.
        gamma: multiplicative decay factor applied at every step.
        step_size: number of epochs between decays; must be positive.

    Returns:
        The learning rate that was written into every param group.

    Raises:
        ValueError: if ``step_size`` is not positive.
    """
    if step_size <= 0:
        raise ValueError('step_size must be positive, got %r' % (step_size,))
    if base_lr is None:
        base_lr = args.lr
    lr = base_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


print("Loading dataset...")
# Training loader built from the module-level path/size settings,
# using 'vgg' preprocessing.
train_data = get_loader(
    json_path, data_dir, mask_dir,
    inp_size, feat_stride,
    preprocess='vgg',
    batch_size=batch_size,
    training=True,
    shuffle=True,
)
print('train dataset len: %d' % len(train_data.dataset))

# validation data
valid_data = get_loader(json_path,
                        data_dir,
                        mask_dir,
                        inp_size,
                        feat_stride,
                        preprocess='vgg',
                        training=False,
                        batch_size=batch_size,
예제 #6
0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        Args:
            val: latest value (presumably a per-batch metric such as a loss).
            n: weight of ``val``, presumably the number of samples it covers.

        Sets ``self.val`` to ``val``, adds ``val * n`` to ``self.sum`` and
        ``n`` to ``self.count``, then recomputes ``self.avg``. While
        ``self.count`` is zero (e.g. ``n=0`` on a fresh meter) ``self.avg`` is
        left unchanged instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count


print("Loading dataset...")
# Training loader: input size 368, feature stride 8, 'rtpose' preprocessing.
train_data = get_loader(
    args.json_path, args.data_dir, args.mask_dir,
    368, 8,
    preprocess='rtpose',
    batch_size=args.batch_size,
    training=True,
    shuffle=True,
    num_workers=16,
)
print('train dataset len: %d' % len(train_data.dataset))

# validation data
valid_data = get_loader(args.json_path,
                        args.data_dir,
                        args.mask_dir,
                        368,
                        8,
                        preprocess='rtpose',
                        training=False,
                        batch_size=args.batch_size,
예제 #7
0
def load_data_(dataset, arch, json_path, data_dir, mask_dir, batch_size,
               workers):
    """Build training and validation loaders for a pose-estimation model.

    Args:
        dataset: dataset name; must be a member of ``DATASETS_NAMES``.
        arch: ``'shufflenetv2'``, ``'vgg19'`` or ``'hourglass'``. Selects the
            input size, feature stride, ``get_loader`` preprocess mode and the
            ``params_transform`` settings.
        json_path: forwarded unchanged to ``get_loader``.
        data_dir: forwarded unchanged to ``get_loader``.
        mask_dir: forwarded unchanged to ``get_loader``.
        batch_size: batch size for both loaders.
        workers: ``num_workers`` for both loaders.

    Returns:
        ``(train_loader, valid_loader, valid_loader, input_shape)`` where
        ``input_shape`` is ``(1, 3, input_size, input_size)``.

    Raises:
        ValueError: if ``dataset`` or ``arch`` is not supported.
    """
    if dataset not in DATASETS_NAMES:
        raise ValueError('load_data does not support dataset %s' % (dataset,))

    # Per-architecture settings: (input size, feature stride, preprocess mode).
    arch_settings = {
        'shufflenetv2': (368, 8, 'rtpose'),
        'vgg19': (368, 8, 'vgg'),
        'hourglass': (256, 4, 'rtpose'),
    }
    if arch not in arch_settings:
        raise ValueError('load_data does not support arch %s' % (arch,))
    input_size, feat_stride, preprocess = arch_settings[arch]

    input_shape = (1, 3, input_size, input_size)

    # Augmentation parameters shared by every architecture.
    # NOTE(review): 'center_perterb_max' (sic) is presumably the exact key the
    # transform code reads -- do not rename it here alone.
    params_transform = {
        'mode': 5,
        'scale_min': 0.5,
        'scale_max': 1.1,
        'scale_prob': 1,
        'target_dist': 0.6,
        'max_rotate_degree': 40,
        'center_perterb_max': 40,
        'flip_prob': 0.5,
        'np': 56,
    }
    # Architecture-specific target parameters. shufflenetv2 sets no
    # 'limb_width'; presumably the consumer falls back to a default -- confirm.
    params_transform.update({
        'shufflenetv2': {'sigma': 7.0},
        'hourglass': {'sigma': 4.416, 'limb_width': 1.289},
        'vgg19': {'sigma': 7.0, 'limb_width': 1.},
    }[arch])

    def _make_loader(training):
        # Training batches are shuffled; validation batches keep a fixed order.
        return get_loader(json_path,
                          data_dir,
                          mask_dir,
                          input_size,
                          feat_stride,
                          preprocess=preprocess,
                          batch_size=batch_size,
                          params_transform=params_transform,
                          shuffle=training,
                          training=training,
                          num_workers=workers)

    train_loader = _make_loader(training=True)
    valid_loader = _make_loader(training=False)

    # The validation loader is returned twice, presumably standing in for a
    # separate test loader -- verify against callers.
    return train_loader, valid_loader, valid_loader, input_shape