# network
kwargs = {}
module_list = []
if args.use_fpn:
    module_list.append('fpn')
if args.norm_layer is not None:
    module_list.append(args.norm_layer)
net_name = '_'.join(('faster_rcnn', *module_list, args.network, args.dataset))
args.save_prefix += net_name
if args.pretrained.lower() in ['true', '1', 'yes', 't']:
    net = model_zoo.get_model(net_name, pretrained=True, root=args.root, **kwargs)
else:
    net = model_zoo.get_model(net_name, pretrained=False, **kwargs)
    # args.pretrained holds a path to a local checkpoint in this branch
    net.load_state_dict(torch.load(args.pretrained.strip()))
net.to(device)

# testing data
val_dataset, val_metric = get_dataset(net.short, net.max_size, args.dataset)
val_data = get_dataloader(val_dataset, args.batch_size, args.num_workers,
                          distributed, args.dataset == 'coco')
classes = val_dataset.classes  # class names

# testing
val_metric = validate(net, val_data, device, val_metric, args.dataset == 'coco')
synchronize()
names, values = accumulate_metric(val_metric)
if is_main_process():
    for k, v in zip(names, values):
        print(k, v)
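# The evaluation script above relies on a few distributed helpers
# (is_main_process, synchronize, all_gather) that are defined elsewhere in the
# project. A minimal sketch of what they might look like, assuming they are
# thin wrappers around torch.distributed, is given below; the project's actual
# implementations may differ.
import torch
import torch.distributed as dist


def is_main_process():
    # Rank 0 is treated as the main process; single-GPU runs are always "main".
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def synchronize():
    # Barrier across all workers so metric accumulation sees every rank's results.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        dist.barrier()


def all_gather(data):
    # Gather an arbitrary picklable object from every worker into a list.
    if not dist.is_available() or not dist.is_initialized():
        return [data]
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    dist.all_gather_object(gathered, data)
    return gathered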
def train(self):
    # CIFAR-10 train/val datasets with distributed-aware samplers
    train_dataset = CIFAR10(root=os.path.join(self.cfg.data_root, 'cifar10'),
                            train=True, transform=self.transform_train, download=True)
    train_sampler = make_data_sampler(train_dataset, True, self.distributed)
    train_batch_sampler = data.sampler.BatchSampler(train_sampler, self.cfg.batch_size, True)
    train_data = data.DataLoader(train_dataset, num_workers=self.cfg.num_workers,
                                 batch_sampler=train_batch_sampler)

    val_dataset = CIFAR10(root=os.path.join(self.cfg.data_root, 'cifar10'),
                          train=False, transform=self.transform_test)
    val_sampler = make_data_sampler(val_dataset, False, self.distributed)
    val_batch_sampler = data.sampler.BatchSampler(val_sampler, self.cfg.batch_size, False)
    val_data = data.DataLoader(val_dataset, num_workers=self.cfg.num_workers,
                               batch_sampler=val_batch_sampler)

    optimizer = optim.SGD(self.net.parameters(), nesterov=True, lr=self.cfg.lr,
                          weight_decay=self.cfg.wd, momentum=self.cfg.momentum)
    metric = Accuracy()
    train_metric = Accuracy()
    loss_fn = nn.CrossEntropyLoss()
    if is_main_process():
        train_history = TrainingHistory(['training-error', 'validation-error'])

    iteration = 0
    lr_decay_count = 0
    best_val_score = 0

    for epoch in range(self.cfg.num_epochs):
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)

        # step the learning rate at the scheduled decay epochs
        if epoch == self.lr_decay_epoch[lr_decay_count]:
            set_learning_rate(optimizer, get_learning_rate(optimizer) * self.cfg.lr_decay)
            lr_decay_count += 1

        for i, batch in enumerate(train_data):
            image = batch[0].to(self.device)
            label = batch[1].to(self.device)

            output = self.net(image)
            loss = loss_fn(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_metric.update(label, output)
            iteration += 1

        # validate, then reduce loss and accuracy across all processes
        metric = self.validate(val_data, metric)
        synchronize()
        train_loss /= num_batch
        train_loss = reduce_list(all_gather(train_loss))
        name, acc = accumulate_metric(train_metric)
        name, val_acc = accumulate_metric(metric)

        if is_main_process():
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' % (self.plot_path, self.cfg.model))

            # keep the best checkpoint seen so far
            if val_acc > best_val_score:
                best_val_score = val_acc
                torch.save(self.net.state_dict(),
                           '%s/%.4f-cifar-%s-%d-best.pth' %
                           (self.save_dir, best_val_score, self.cfg.model, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            # periodic checkpointing
            if self.save_period and self.cfg.save_dir and (epoch + 1) % self.save_period == 0:
                torch.save(self.net.module.state_dict() if self.distributed else self.net.state_dict(),
                           '%s/cifar10-%s-%d.pth' % (self.save_dir, self.cfg.model, epoch))

    # save the final model once training is done
    if is_main_process() and self.save_period and self.save_dir:
        torch.save(self.net.module.state_dict() if self.distributed else self.net.state_dict(),
                   '%s/cifar10-%s-%d.pth' % (self.save_dir, self.cfg.model, self.cfg.num_epochs - 1))
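# The training loop above also uses a few small utilities (make_data_sampler,
# get_learning_rate, set_learning_rate, reduce_list) that live elsewhere in the
# project. A rough sketch of plausible implementations follows, assuming
# standard torch.utils.data samplers and that every parameter group shares one
# learning rate; treat it as illustrative rather than the project's exact code.
import torch.utils.data as data
from torch.utils.data.distributed import DistributedSampler


def make_data_sampler(dataset, shuffle, distributed):
    # Use a DistributedSampler in multi-process runs so each rank sees a
    # distinct shard; otherwise fall back to plain random/sequential sampling.
    if distributed:
        return DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.sampler.RandomSampler(dataset)
    return data.sampler.SequentialSampler(dataset)


def get_learning_rate(optimizer):
    # Read the current learning rate from the first parameter group.
    return optimizer.param_groups[0]['lr']


def set_learning_rate(optimizer, lr):
    # Apply the same learning rate to every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def reduce_list(values):
    # Average a list of per-rank scalars (e.g. the gathered training loss).
    return sum(values) / len(values)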