def get_dataloader(data_dir, batch_size, num_workers, input_size, mean, std,
                   distributed):
    """Get dataloader."""

    def val_batch_fn(batch, device):
        # unpack an evaluation batch: image tensor plus keypoint metadata
        data = batch[0].to(device)
        scale = batch[1]
        center = batch[2]
        score = batch[3]
        imgid = batch[4]
        return data, scale, center, score, imgid

    # note the trailing comma: `splits` is a one-element tuple, not a string
    val_dataset = COCOKeyPoints(data_dir, aspect_ratio=4. / 3.,
                                splits=('person_keypoints_val2017',))
    meanvec = [float(i) for i in mean.split(',')]
    stdvec = [float(i) for i in std.split(',')]
    transform_val = SimplePoseDefaultValTransform(
        num_joints=val_dataset.num_joints,
        joint_pairs=val_dataset.joint_pairs,
        image_size=input_size,
        mean=meanvec,
        std=stdvec)
    val_tmp = val_dataset.transform(transform_val)
    sampler = make_data_sampler(val_tmp, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size,
                                      drop_last=False)
    val_data = data.DataLoader(val_tmp, batch_sampler=batch_sampler,
                               num_workers=num_workers)
    return val_dataset, val_data, val_batch_fn
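# Every loader in this section calls a `make_data_sampler` helper that is not
# defined here. A minimal sketch of what such a helper could look like, under
# the assumption that it only chooses between a distributed, a shuffling, and
# a sequential sampler (inferred from the call sites, not this project's
# actual code):
import torch.utils.data as data


def make_data_sampler(dataset, shuffle, distributed):
    if distributed:
        # each process draws a distinct shard of the dataset
        return data.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)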
def get_dataloader(val_dataset, batch_size, num_workers, distributed,
                   coco=False):
    """Get dataloader."""
    if coco:
        batchify_fn = Tuple(*[Append() for _ in range(3)], Empty())
    else:
        batchify_fn = Tuple(*[Append() for _ in range(3)])
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size,
                                      drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=num_workers)
    return val_loader
def get_dataloader(val_dataset, batch_size, num_workers, distributed,
                   coco=False):
    """Get dataloader."""
    if coco:
        batchify_fn = Tuple(Stack(), Pad(pad_val=-1), Empty())
    else:
        batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size,
                                      drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=num_workers)
    return val_loader
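# The `Tuple`/`Stack`/`Pad` collate helpers used above (and the
# `Append`/`Empty` variants in the previous loader) mirror gluoncv-style
# batchify functions. A minimal sketch of plausible re-implementations,
# assuming each dataset sample is a tuple of per-field arrays; these are
# illustrations, not the project's actual classes:
import torch


class Stack(object):
    """Stack same-shaped fields into one batch tensor."""

    def __call__(self, batch):
        return torch.stack([torch.as_tensor(x) for x in batch], dim=0)


class Pad(object):
    """Pad variable-length fields (e.g. per-image labels) to a common shape."""

    def __init__(self, axis=0, pad_val=0):
        self._axis = axis
        self._pad_val = pad_val

    def __call__(self, batch):
        batch = [torch.as_tensor(x) for x in batch]
        max_len = max(x.shape[self._axis] for x in batch)
        out = []
        for x in batch:
            pad_shape = list(x.shape)
            pad_shape[self._axis] = max_len - x.shape[self._axis]
            pad = torch.full(pad_shape, self._pad_val, dtype=x.dtype)
            out.append(torch.cat([x, pad], dim=self._axis))
        return torch.stack(out, dim=0)


class Tuple(object):
    """Apply one batchify function per field of the sample tuple."""

    def __init__(self, *fns):
        self._fns = fns

    def __call__(self, batch):
        # transpose [(img, label, ...), ...] into per-field lists
        fields = list(zip(*batch))
        return tuple(fn(field) for fn, field in zip(self._fns, fields))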
def get_dataloader(batch_size, num_workers, data_root, distributed):
    """Get the CIFAR-10 validation dataloader."""
    transform_test = transforms.Compose([
        transforms_cv.ToTensor(),
        # standard CIFAR-10 channel statistics
        transforms_cv.Normalize([0.4914, 0.4822, 0.4465],
                                [0.2023, 0.1994, 0.2010])
    ])
    val_dataset = CIFAR10(root=data_root, train=False,
                          transform=transform_test, download=True)
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=batch_size,
                                      drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler,
                                 num_workers=num_workers)
    return val_loader
def get_dataloader(opt, distributed):
    input_size = opt.input_size
    # default to the common 0.875 center-crop ratio (e.g. 224/256)
    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
    resize = int(math.ceil(input_size / crop_ratio))
    transform_test = transforms_cv.Compose([
        transforms_cv.Resize((resize, resize)),
        transforms_cv.CenterCrop(input_size),
        transforms_cv.ToTensor(),
        transforms_cv.Normalize([0.485, 0.456, 0.406],
                                [0.229, 0.224, 0.225])
    ])
    val_dataset = ImageNet(opt.data_dir, train=False,
                           transform=transform_test)
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler,
                                      batch_size=opt.batch_size,
                                      drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler,
                                 num_workers=opt.num_workers)
    return val_loader
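# Hypothetical usage of the ImageNet loader above; the attribute names follow
# what the function reads from `opt`, and the values are only illustrative:
from argparse import Namespace

opt = Namespace(input_size=224, crop_ratio=0.875, batch_size=64,
                num_workers=4, data_dir='~/.torch/datasets/imagenet')
# resize = ceil(224 / 0.875) = 256, i.e. the classic resize-256/crop-224 setup
val_loader = get_dataloader(opt, distributed=False)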
def __init__(self, args):
    self.device = torch.device(args.device)

    # network
    net_name = '_'.join(('yolo3', args.network, args.dataset))
    self.save_prefix = net_name
    self.net = get_model(net_name, pretrained_base=True)
    if args.distributed:
        self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
    if args.resume.strip():
        logger.info("Resume from the model {}".format(args.resume))
        self.net.load_state_dict(torch.load(args.resume.strip()))
    else:
        logger.info("Init from base net {}".format(args.network))
    classes, anchors = self.net.num_class, self.net.anchors
    self.net.set_nms(nms_thresh=0.45, nms_topk=400)
    if args.label_smooth:
        self.net._target_generator._label_smooth = True
    self.net.to(self.device)
    if args.distributed:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # dataset and dataloader
    train_dataset = get_train_data(args.dataset, args.mixup)
    width, height = args.data_shape, args.data_shape
    batchify_fn = Tuple(*([Stack() for _ in range(6)] +
                          [Pad(axis=0, pad_val=-1) for _ in range(1)]))
    train_dataset = train_dataset.transform(
        YOLO3DefaultTrainTransform(width, height, classes, anchors,
                                   mixup=args.mixup))
    args.per_iter = len(train_dataset) // (args.num_gpus * args.batch_size)
    args.max_iter = args.epochs * args.per_iter
    if args.distributed:
        sampler = data.DistributedSampler(train_dataset)
    else:
        sampler = data.RandomSampler(train_dataset)
    train_sampler = data.sampler.BatchSampler(sampler=sampler,
                                              batch_size=args.batch_size,
                                              drop_last=False)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=args.max_iter)
    if args.no_random_shape:
        self.train_loader = data.DataLoader(train_dataset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            collate_fn=batchify_fn,
                                            num_workers=args.num_workers)
    else:
        # multi-scale training: re-pick the input resolution (320..608) per batch
        transform_fns = [YOLO3DefaultTrainTransform(x * 32, x * 32, classes,
                                                    anchors,
                                                    mixup=args.mixup)
                         for x in range(10, 20)]
        self.train_loader = RandomTransformDataLoader(
            transform_fns, train_dataset, batch_sampler=train_sampler,
            collate_fn=batchify_fn, num_workers=args.num_workers)

    if args.eval_epoch > 0:
        # TODO: rewrite it
        val_dataset, self.metric = get_test_data(args.dataset)
        val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
        val_dataset = val_dataset.transform(
            YOLO3DefaultValTransform(width, height))
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = data.BatchSampler(val_sampler,
                                              args.test_batch_size, False)
        self.val_loader = data.DataLoader(val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          collate_fn=val_batchify_fn,
                                          num_workers=args.num_workers)

    # optimizer and lr scheduling
    self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr,
                               momentum=args.momentum, weight_decay=args.wd)
    if args.lr_mode == 'cos':
        self.scheduler = WarmupCosineLR(optimizer=self.optimizer,
                                        T_max=args.max_iter,
                                        warmup_factor=args.warmup_factor,
                                        warmup_iters=args.warmup_iters)
    elif args.lr_mode == 'step':
        lr_decay = float(args.lr_decay)
        milestones = sorted([float(ls) * args.per_iter for ls in
                             args.lr_decay_epoch.split(',') if ls.strip()])
        self.scheduler = WarmupMultiStepLR(optimizer=self.optimizer,
                                           milestones=milestones,
                                           gamma=lr_decay,
                                           warmup_factor=args.warmup_factor,
                                           warmup_iters=args.warmup_iters)
    else:
        raise ValueError('illegal scheduler type')

    self.args = args
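# `IterationBasedBatchSampler`, used here and in the segmentation trainer
# below, is defined elsewhere. A minimal sketch in the spirit of
# maskrcnn-benchmark's sampler of the same name: it re-iterates the wrapped
# batch sampler until a fixed number of iterations is reached (assumed
# behavior, inferred from the call sites):
from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSampler(BatchSampler):

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # reshuffle per pass when the underlying sampler supports it
            if hasattr(self.batch_sampler.sampler, 'set_epoch'):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations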
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])
data_kwargs = {
    'base_size': args.base_size,
    'crop_size': args.crop_size,
    'transform': input_transform
}
val_dataset = get_segmentation_dataset(args.dataset, split=args.split,
                                       mode=args.mode, **data_kwargs)
sampler = make_data_sampler(val_dataset, False, distributed)
batch_sampler = data.BatchSampler(sampler=sampler, batch_size=args.batch_size,
                                  drop_last=False)
val_data = data.DataLoader(val_dataset, batch_sampler=batch_sampler,
                           num_workers=args.num_workers)
if args.multi:
    evaluator = MultiEvalModel(model, val_dataset.num_class)
else:
    evaluator = SegEvalModel(model)
metric = SegmentationMetric(val_dataset.num_class)

metric = validate(evaluator, val_data, metric, device)
ptutil.synchronize()
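# `validate` above is defined elsewhere in the script. A plausible minimal
# version, assuming each batch is an (image, target) pair and the evaluator
# maps images to per-pixel class scores (names and shapes are assumptions):
import torch


def validate(evaluator, val_data, metric, device):
    with torch.no_grad():
        for image, target in val_data:
            image = image.to(device)
            target = target.to(device)
            pred = evaluator(image)
            metric.update(target, pred)
    return metric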
def train(self):
    train_dataset = CIFAR10(root=os.path.join(self.cfg.data_root, 'cifar10'),
                            train=True, transform=self.transform_train,
                            download=True)
    train_sampler = make_data_sampler(train_dataset, True, self.distributed)
    train_batch_sampler = data.sampler.BatchSampler(train_sampler,
                                                    self.cfg.batch_size, True)
    train_data = data.DataLoader(train_dataset,
                                 num_workers=self.cfg.num_workers,
                                 batch_sampler=train_batch_sampler)
    val_dataset = CIFAR10(root=os.path.join(self.cfg.data_root, 'cifar10'),
                          train=False, transform=self.transform_test)
    val_sampler = make_data_sampler(val_dataset, False, self.distributed)
    val_batch_sampler = data.sampler.BatchSampler(val_sampler,
                                                  self.cfg.batch_size, False)
    val_data = data.DataLoader(val_dataset, num_workers=self.cfg.num_workers,
                               batch_sampler=val_batch_sampler)

    optimizer = optim.SGD(self.net.parameters(), nesterov=True,
                          lr=self.cfg.lr, weight_decay=self.cfg.wd,
                          momentum=self.cfg.momentum)
    metric = Accuracy()
    train_metric = Accuracy()
    loss_fn = nn.CrossEntropyLoss()
    if is_main_process():
        train_history = TrainingHistory(['training-error',
                                         'validation-error'])
    iteration = 0
    lr_decay_count = 0
    best_val_score = 0

    for epoch in range(self.cfg.num_epochs):
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)

        # step-decay the learning rate at the scheduled epochs
        if epoch == self.lr_decay_epoch[lr_decay_count]:
            set_learning_rate(optimizer,
                              get_learning_rate(optimizer) *
                              self.cfg.lr_decay)
            lr_decay_count += 1

        for i, batch in enumerate(train_data):
            image = batch[0].to(self.device)
            label = batch[1].to(self.device)
            output = self.net(image)
            loss = loss_fn(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_metric.update(label, output)
            iteration += 1

        metric = self.validate(val_data, metric)
        synchronize()
        train_loss /= num_batch
        train_loss = reduce_list(all_gather(train_loss))
        name, acc = accumulate_metric(train_metric)
        name, val_acc = accumulate_metric(metric)
        if is_main_process():
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (self.plot_path, self.cfg.model))
            if val_acc > best_val_score:
                best_val_score = val_acc
                torch.save(self.net.state_dict(),
                           '%s/%.4f-cifar-%s-%d-best.pth' %
                           (self.save_dir, best_val_score, self.cfg.model,
                            epoch))
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss,
                          time.time() - tic))
            if (self.save_period and self.cfg.save_dir
                    and (epoch + 1) % self.save_period == 0):
                torch.save(self.net.module.state_dict() if self.distributed
                           else self.net.state_dict(),
                           '%s/cifar10-%s-%d.pth' %
                           (self.save_dir, self.cfg.model, epoch))

    if is_main_process() and self.save_period and self.save_dir:
        torch.save(self.net.module.state_dict() if self.distributed
                   else self.net.state_dict(),
                   '%s/cifar10-%s-%d.pth' %
                   (self.save_dir, self.cfg.model, self.cfg.num_epochs - 1))
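# The step decay above relies on `get_learning_rate`/`set_learning_rate`
# helpers that are not shown in this section. A minimal sketch, assuming one
# learning rate is shared by all parameter groups:
def get_learning_rate(optimizer):
    return optimizer.param_groups[0]['lr']


def set_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr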
def __init__(self, args):
    self.device = torch.device(args.device)
    self.save_prefix = '_'.join((args.model, args.backbone, args.dataset))

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    # dataset and dataloader
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_segmentation_dataset(args.dataset, split=args.train_split,
                                        mode='train', **data_kwargs)
    args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
    args.max_iter = args.epochs * args.per_iter
    if args.distributed:
        sampler = data.DistributedSampler(trainset)
    else:
        sampler = data.RandomSampler(trainset)
    train_sampler = data.sampler.BatchSampler(sampler, args.batch_size, True)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=args.max_iter)
    self.train_loader = data.DataLoader(trainset,
                                        batch_sampler=train_sampler,
                                        pin_memory=True,
                                        num_workers=args.workers)
    if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
        valset = get_segmentation_dataset(args.dataset, split='val',
                                          mode='val', **data_kwargs)
        val_sampler = make_data_sampler(valset, False, args.distributed)
        val_batch_sampler = data.sampler.BatchSampler(val_sampler,
                                                      args.test_batch_size,
                                                      False)
        self.valid_loader = data.DataLoader(valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

    # create network
    if args.model_zoo is not None:
        self.net = get_model(args.model_zoo, pretrained=True)
    else:
        self.net = get_segmentation_model(model=args.model,
                                          dataset=args.dataset,
                                          backbone=args.backbone,
                                          aux=args.aux,
                                          dilated=args.dilated,
                                          jpu=args.jpu,
                                          crop_size=args.crop_size)
    if args.distributed:
        self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
    self.net.to(self.device)

    # resume checkpoint if needed
    if args.resume is not None:
        if os.path.isfile(args.resume):
            self.net.load_state_dict(torch.load(args.resume))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))

    # create criterion
    if args.ohem:
        min_kept = args.batch_size * args.crop_size ** 2 // 16
        self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                     min_kept=min_kept,
                                                     use_weight=False)
    else:
        self.criterion = MixSoftmaxCrossEntropyLoss(
            args.aux, aux_weight=args.aux_weight)

    # optimizer and lr scheduling
    params_list = [{'params': self.net.base1.parameters(), 'lr': args.lr},
                   {'params': self.net.base2.parameters(), 'lr': args.lr},
                   {'params': self.net.base3.parameters(), 'lr': args.lr}]
    if hasattr(self.net, 'others'):
        for name in self.net.others:
            params_list.append({
                'params': getattr(self.net, name).parameters(),
                'lr': args.lr * 10
            })
    if hasattr(self.net, 'JPU'):
        params_list.append({
            'params': self.net.JPU.parameters(),
            'lr': args.lr * 10
        })
    self.optimizer = optim.SGD(params_list, lr=args.lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)
    self.scheduler = WarmupPolyLR(self.optimizer, T_max=args.max_iter,
                                  warmup_factor=args.warmup_factor,
                                  warmup_iters=args.warmup_iters, power=0.9)
    if args.distributed:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # evaluation metrics
    self.metric = SegmentationMetric(trainset.num_class)
    self.args = args
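# `WarmupPolyLR` follows the usual "poly" schedule for semantic segmentation:
# linear warmup for the first `warmup_iters` steps, then
# lr = base_lr * (1 - iter / T_max) ** power. A minimal sketch under that
# assumption (the project's real class may differ in details):
from torch.optim.lr_scheduler import _LRScheduler


class WarmupPolyLR(_LRScheduler):

    def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3,
                 warmup_iters=500, power=0.9, last_epoch=-1):
        self.T_max = T_max
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.power = power
        super(WarmupPolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_iters:
            # linearly ramp from warmup_factor * base_lr up to base_lr
            alpha = self.last_epoch / self.warmup_iters
            factor = self.warmup_factor * (1 - alpha) + alpha
        else:
            progress = min(self.last_epoch / float(self.T_max), 1.0)
            factor = (1 - progress) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]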