Example #1
def get_dataloader(data_dir, batch_size, num_workers, input_size, mean, std,
                   distributed):
    """Get dataloader."""
    def val_batch_fn(batch, device):
        data = batch[0].to(device)
        scale = batch[1]
        center = batch[2]
        score = batch[3]
        imgid = batch[4]
        return data, scale, center, score, imgid

    val_dataset = COCOKeyPoints(data_dir,
                                aspect_ratio=4. / 3.,
                                splits=('person_keypoints_val2017',))

    meanvec = [float(i) for i in mean.split(',')]
    stdvec = [float(i) for i in std.split(',')]
    transform_val = SimplePoseDefaultValTransform(
        num_joints=val_dataset.num_joints,
        joint_pairs=val_dataset.joint_pairs,
        image_size=input_size,
        mean=meanvec,
        std=stdvec)
    val_tmp = val_dataset.transform(transform_val)
    sampler = make_data_sampler(val_tmp, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler,
                                      batch_size=batch_size,
                                      drop_last=False)
    val_data = data.DataLoader(val_tmp,
                               batch_sampler=batch_sampler,
                               num_workers=num_workers)

    return val_dataset, val_data, val_batch_fn
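
Every example below relies on a make_data_sampler helper that is not shown here. A minimal sketch of what it could look like, assuming the usual pattern of switching between a DistributedSampler and plain random/sequential sampling, is:

from torch.utils import data


def make_data_sampler(dataset, shuffle, distributed):
    """Pick a sampler; shard the dataset across processes when distributed."""
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)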
Example #2
def get_dataloader(val_dataset, batch_size, num_workers, distributed, coco=False):
    """Get dataloader."""
    if coco:
        batchify_fn = Tuple(*[Append() for _ in range(3)], Empty())
    else:
        batchify_fn = Tuple(*[Append() for _ in range(3)])
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler, collate_fn=batchify_fn,
                                 num_workers=num_workers)
    return val_loader
Example #3
def get_dataloader(val_dataset, batch_size, num_workers, distributed, coco=False):
    """Get dataloader."""
    if coco:
        batchify_fn = Tuple(Stack(), Pad(pad_val=-1), Empty())
    else:
        batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler, collate_fn=batchify_fn,
                                 num_workers=num_workers)
    return val_loader
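
The Tuple, Stack, Pad, Append, and Empty batchify helpers used in Examples #2 and #3 come from the surrounding project and are not shown. The sketch below is an assumed reconstruction of the three most common ones (Tuple, Stack, Pad) as collate_fn building blocks; Append and Empty are presumably similar helpers that keep per-sample items in a list or emit a placeholder field.

import torch


class Stack:
    """Stack same-shaped samples into one batch tensor."""
    def __call__(self, items):
        return torch.stack([torch.as_tensor(x) for x in items], dim=0)


class Pad:
    """Pad variable-length samples (e.g. gt boxes) to a common length, then stack."""
    def __init__(self, axis=0, pad_val=-1):
        self.axis = axis
        self.pad_val = pad_val

    def __call__(self, items):
        tensors = [torch.as_tensor(x) for x in items]
        max_len = max(t.shape[self.axis] for t in tensors)
        padded = []
        for t in tensors:
            missing = max_len - t.shape[self.axis]
            if missing > 0:
                pad_shape = list(t.shape)
                pad_shape[self.axis] = missing
                t = torch.cat([t, t.new_full(pad_shape, self.pad_val)], dim=self.axis)
            padded.append(t)
        return torch.stack(padded, dim=0)


class Tuple:
    """Apply one collate function per field of the sample tuple."""
    def __init__(self, *fns):
        self.fns = fns

    def __call__(self, batch):
        return tuple(fn([sample[i] for sample in batch])
                     for i, fn in enumerate(self.fns))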
Example #4
def get_dataloader(batch_size, num_workers, data_root, distributed):
    transform_test = transforms.Compose([
        transforms_cv.ToTensor(),
        transforms_cv.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])
    val_dataset = CIFAR10(root=data_root, train=False, transform=transform_test, download=True)

    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler, num_workers=num_workers)
    return val_loader
Example #5
def get_dataloader(opt, distributed):
    input_size = opt.input_size
    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
    resize = int(math.ceil(input_size / crop_ratio))
    transform_test = transforms_cv.Compose([
        transforms_cv.Resize((resize, resize)),
        transforms_cv.CenterCrop(input_size),
        transforms_cv.ToTensor(),
        transforms_cv.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_dataset = ImageNet(opt.data_dir, train=False, transform=transform_test)

    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=opt.batch_size, drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler, num_workers=opt.num_workers)
    return val_loader
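
As a quick check of the resize arithmetic above: with the common settings input_size=224 and crop_ratio=0.875, images are first resized to 256x256 and then center-cropped back to 224x224.

import math

input_size, crop_ratio = 224, 0.875   # typical ImageNet evaluation values
resize = int(math.ceil(input_size / crop_ratio))
assert resize == 256                  # Resize(256) followed by CenterCrop(224)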
Example #6
    def __init__(self, args):
        self.device = torch.device(args.device)
        # network
        net_name = '_'.join(('yolo3', args.network, args.dataset))
        self.save_prefix = net_name
        self.net = get_model(net_name, pretrained_base=True)
        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        if args.resume.strip():
            logger.info("Resume from the model {}".format(args.resume))
            self.net.load_state_dict(torch.load(args.resume.strip()))
        else:
            logger.info("Init from base net {}".format(args.network))
        classes, anchors = self.net.num_class, self.net.anchors
        self.net.set_nms(nms_thresh=0.45, nms_topk=400)
        if args.label_smooth:
            self.net._target_generator._label_smooth = True
        self.net.to(self.device)
        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net, device_ids=[args.local_rank], output_device=args.local_rank)

        # dataset and dataloader
        train_dataset = get_train_data(args.dataset, args.mixup)
        width, height = args.data_shape, args.data_shape
        # stack the six fixed-size targets, pad the variable-length gt boxes
        batchify_fn = Tuple(*([Stack() for _ in range(6)] + [Pad(axis=0, pad_val=-1)]))
        train_dataset = train_dataset.transform(
            YOLO3DefaultTrainTransform(width, height, classes, anchors, mixup=args.mixup))
        args.per_iter = len(train_dataset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(train_dataset)
        else:
            sampler = data.RandomSampler(train_dataset)
        train_sampler = data.sampler.BatchSampler(sampler=sampler, batch_size=args.batch_size,
                                                  drop_last=False)
        train_sampler = IterationBasedBatchSampler(train_sampler, num_iterations=args.max_iter)
        if args.no_random_shape:
            self.train_loader = data.DataLoader(train_dataset, batch_sampler=train_sampler, pin_memory=True,
                                                collate_fn=batchify_fn, num_workers=args.num_workers)
        else:
            transform_fns = [YOLO3DefaultTrainTransform(x * 32, x * 32, classes, anchors, mixup=args.mixup)
                             for x in range(10, 20)]
            self.train_loader = RandomTransformDataLoader(transform_fns, train_dataset, batch_sampler=train_sampler,
                                                          collate_fn=batchify_fn, num_workers=args.num_workers)
        if args.eval_epoch > 0:
            # TODO: rewrite it
            val_dataset, self.metric = get_test_data(args.dataset)
            val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
            val_dataset = val_dataset.transform(YOLO3DefaultValTransform(width, height))
            val_sampler = make_data_sampler(val_dataset, False, args.distributed)
            val_batch_sampler = data.BatchSampler(val_sampler, args.test_batch_size, False)
            self.val_loader = data.DataLoader(val_dataset, batch_sampler=val_batch_sampler,
                                              collate_fn=val_batchify_fn, num_workers=args.num_workers)

        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr, momentum=args.momentum,
                                   weight_decay=args.wd)
        if args.lr_mode == 'cos':
            self.scheduler = WarmupCosineLR(optimizer=self.optimizer, T_max=args.max_iter,
                                            warmup_factor=args.warmup_factor, warmup_iters=args.warmup_iters)
        elif args.lr_mode == 'step':
            lr_decay = float(args.lr_decay)
            milestones = sorted([float(ls) * args.per_iter for ls in args.lr_decay_epoch.split(',') if ls.strip()])
            self.scheduler = WarmupMultiStepLR(optimizer=self.optimizer, milestones=milestones, gamma=lr_decay,
                                               warmup_factor=args.warmup_factor, warmup_iters=args.warmup_iters)
        else:
            raise ValueError('illegal scheduler type')
        self.args = args
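
Examples #6 and #9 both wrap the epoch-style BatchSampler in an IterationBasedBatchSampler that is not shown. A minimal sketch, assuming the usual semantics of re-iterating the inner sampler until a fixed number of batches has been produced, is:

from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSampler(BatchSampler):
    """Yield batches from a wrapped BatchSampler until num_iterations is reached."""

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            # let a DistributedSampler reshuffle between passes over the data
            if hasattr(self.batch_sampler.sampler, 'set_epoch'):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations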
Example #7
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    data_kwargs = {
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'transform': input_transform
    }

    val_dataset = get_segmentation_dataset(args.dataset,
                                           split=args.split,
                                           mode=args.mode,
                                           **data_kwargs)
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler,
                                      batch_size=args.batch_size,
                                      drop_last=False)
    val_data = data.DataLoader(val_dataset,
                               batch_sampler=batch_sampler,
                               num_workers=args.num_workers)
    if args.multi:
        evaluator = MultiEvalModel(model, val_dataset.num_class)
    else:
        evaluator = SegEvalModel(model)
    metric = SegmentationMetric(val_dataset.num_class)

    metric = validate(evaluator, val_data, metric, device)
    ptutil.synchronize()
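
The validate helper called in this fragment is not shown; the sketch below is only an assumed reconstruction based on how evaluator, val_data, metric, and device are used here, and the real evaluator interface in the repository may differ.

import torch


def validate(evaluator, val_data, metric, device):
    metric.reset()
    with torch.no_grad():
        for image, target in val_data:
            image, target = image.to(device), target.to(device)
            output = evaluator(image)       # forward pass through the eval wrapper
            metric.update(target, output)   # accumulate pixAcc / mIoU statistics
    return metric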
Example #8
    def train(self):
        train_dataset = CIFAR10(root=os.path.join(self.cfg.data_root,
                                                  'cifar10'),
                                train=True,
                                transform=self.transform_train,
                                download=True)
        train_sampler = make_data_sampler(train_dataset, True,
                                          self.distributed)
        train_batch_sampler = data.sampler.BatchSampler(
            train_sampler, self.cfg.batch_size, True)
        train_data = data.DataLoader(train_dataset,
                                     num_workers=self.cfg.num_workers,
                                     batch_sampler=train_batch_sampler)

        val_dataset = CIFAR10(root=os.path.join(self.cfg.data_root, 'cifar10'),
                              train=False,
                              transform=self.transform_test)
        val_sampler = make_data_sampler(val_dataset, False, self.distributed)
        val_batch_sampler = data.sampler.BatchSampler(val_sampler,
                                                      self.cfg.batch_size,
                                                      False)
        val_data = data.DataLoader(val_dataset,
                                   num_workers=self.cfg.num_workers,
                                   batch_sampler=val_batch_sampler)

        optimizer = optim.SGD(self.net.parameters(),
                              nesterov=True,
                              lr=self.cfg.lr,
                              weight_decay=self.cfg.wd,
                              momentum=self.cfg.momentum)
        metric = Accuracy()
        train_metric = Accuracy()
        loss_fn = nn.CrossEntropyLoss()
        if is_main_process():
            train_history = TrainingHistory(
                ['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0
        best_val_score = 0

        for epoch in range(self.cfg.num_epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            if lr_decay_count < len(self.lr_decay_epoch) and \
                    epoch == self.lr_decay_epoch[lr_decay_count]:
                set_learning_rate(
                    optimizer,
                    get_learning_rate(optimizer) * self.cfg.lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                image = batch[0].to(self.device)
                label = batch[1].to(self.device)

                output = self.net(image)
                loss = loss_fn(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                train_metric.update(label, output)
                iteration += 1

            metric = self.validate(val_data, metric)
            synchronize()
            train_loss /= num_batch
            train_loss = reduce_list(all_gather(train_loss))
            name, acc = accumulate_metric(train_metric)
            name, val_acc = accumulate_metric(metric)
            if is_main_process():
                train_history.update([1 - acc, 1 - val_acc])
                train_history.plot(save_path='%s/%s_history.png' %
                                   (self.plot_path, self.cfg.model))
                if val_acc > best_val_score:
                    best_val_score = val_acc
                    torch.save(
                        self.net.module.state_dict() if self.distributed else
                        self.net.state_dict(), '%s/%.4f-cifar-%s-%d-best.pth' %
                        (self.save_dir, best_val_score, self.cfg.model, epoch))
                logging.info(
                    '[Epoch %d] train=%f val=%f loss=%f time: %f' %
                    (epoch, acc, val_acc, train_loss, time.time() - tic))

                if self.save_period and self.cfg.save_dir and (
                        epoch + 1) % self.save_period == 0:
                    torch.save(
                        self.net.module.state_dict() if self.distributed else
                        self.net.state_dict(), '%s/cifar10-%s-%d.pth' %
                        (self.save_dir, self.cfg.model, epoch))

        if is_main_process() and self.save_period and self.save_dir:
            torch.save(
                self.net.module.state_dict() if self.distributed else
                self.net.state_dict(), '%s/cifar10-%s-%d.pth' %
                (self.save_dir, self.cfg.model, self.cfg.num_epochs - 1))
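
The set_learning_rate / get_learning_rate helpers used in the epoch loop above are not shown; assumed minimal versions simply read and rewrite the learning rate of every optimizer parameter group:

def get_learning_rate(optimizer):
    return optimizer.param_groups[0]['lr']


def set_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr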
Example #9
    def __init__(self, args):
        self.device = torch.device(args.device)
        self.save_prefix = '_'.join((args.model, args.backbone, args.dataset))
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            num_workers=args.workers)
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, args.test_batch_size, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=args.workers,
                pin_memory=True)

        # create network
        if args.model_zoo is not None:
            self.net = get_model(args.model_zoo, pretrained=True)
        else:
            self.net = get_segmentation_model(model=args.model,
                                              dataset=args.dataset,
                                              backbone=args.backbone,
                                              aux=args.aux,
                                              dilated=args.dilated,
                                              jpu=args.jpu,
                                              crop_size=args.crop_size)
        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                self.net.load_state_dict(torch.load(args.resume))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if args.ohem:
            min_kept = args.batch_size * args.crop_size**2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                         min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(
                args.aux, aux_weight=args.aux_weight)

        # optimizer and lr scheduling
        params_list = [
            {'params': self.net.base1.parameters(), 'lr': args.lr},
            {'params': self.net.base2.parameters(), 'lr': args.lr},
            {'params': self.net.base3.parameters(), 'lr': args.lr},
        ]
        if hasattr(self.net, 'others'):
            for name in self.net.others:
                params_list.append({
                    'params': getattr(self.net, name).parameters(),
                    'lr': args.lr * 10
                })
        if hasattr(self.net, 'JPU'):
            params_list.append({
                'params': self.net.JPU.parameters(),
                'lr': args.lr * 10
            })
        self.optimizer = optim.SGD(params_list,
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters,
                                      power=0.9)

        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)
        self.args = args
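
The WarmupPolyLR scheduler constructed above is also project-specific. A sketch consistent with the arguments used here (T_max, warmup_factor, warmup_iters, power), assuming linear warmup followed by polynomial decay, could look like:

from torch.optim.lr_scheduler import _LRScheduler


class WarmupPolyLR(_LRScheduler):
    """Linear warmup for warmup_iters steps, then polynomial decay toward zero at T_max."""

    def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3,
                 warmup_iters=500, power=0.9, last_epoch=-1):
        self.T_max = T_max
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_iters:
            # ramp from warmup_factor * base_lr up to base_lr
            alpha = self.last_epoch / max(1, self.warmup_iters)
            factor = self.warmup_factor * (1 - alpha) + alpha
        else:
            # polynomial decay over the remaining iterations
            progress = (self.last_epoch - self.warmup_iters) / max(1, self.T_max - self.warmup_iters)
            factor = max(0.0, 1.0 - progress) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]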