def main(): logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir) summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir) print = logger.info print(cfg) num_gpus = torch.cuda.device_count() if cfg.dist: device = torch.device('cuda:%d' % cfg.local_rank) if cfg.dist else torch.device('cuda') torch.cuda.set_device(cfg.local_rank) dist.init_process_group(backend='nccl', init_method='env://', world_size=num_gpus, rank=cfg.local_rank) else: device = torch.device('cuda') print('==> Preparing data..') train_dataset = ImgNet_split(root=os.path.join(cfg.data_dir, 'train'), transform=imgnet_transform(is_training=True)) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=num_gpus, rank=cfg.local_rank) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size, shuffle=not cfg.dist, num_workers=cfg.num_workers, sampler=train_sampler if cfg.dist else None) val_dataset = ImgNet_split(root=os.path.join(cfg.data_dir, 'val'), transform=imgnet_transform(is_training=False)) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers) print('==> Building model..') genotype = torch.load(os.path.join(cfg.ckpt_dir, 'genotype.pickle'))['genotype'] model = NetworkImageNet(genotype, cfg.init_ch, cfg.num_cells, cfg.auxiliary, num_classes=1000) if not cfg.dist: model = nn.DataParallel(model).to(device) else: # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) model = model.to(device) model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.local_rank, ], output_device=cfg.local_rank) optimizer = torch.optim.SGD(model.parameters(), cfg.lr, momentum=0.9, weight_decay=cfg.wd) criterion = CrossEntropyLabelSmooth(num_classes=1000, epsilon=cfg.label_smooth).to(device) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97) warmup = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=2) # Training def train(epoch): model.train() start_time = time.time() for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) outputs, outputs_aux = model(inputs) loss = criterion(outputs, targets) loss_aux = criterion(outputs_aux, targets) loss += cfg.auxiliary * loss_aux optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 5.0) optimizer.step() if batch_idx % cfg.log_interval == 0: step = len(train_loader) * epoch + batch_idx duration = time.time() - start_time print('[%d/%d - %d/%d] cls_loss= %.5f (%d samples/sec)' % (epoch, cfg.max_epochs, batch_idx, len(train_loader), loss.item(), cfg.batch_size * cfg.log_interval / duration)) start_time = time.time() summary_writer.add_scalar('cls_loss', loss.item(), step) summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step) def val(epoch): # switch to evaluate mode model.eval() top1 = 0 top5 = 0 with torch.no_grad(): for i, (inputs, targets) in enumerate(val_loader): inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) output, _ = model(inputs) # measure accuracy and record loss _, pred = output.data.topk(5, dim=1, largest=True, sorted=True) pred = pred.t() correct = pred.eq(targets.view(1, -1).expand_as(pred)) top1 += correct[:1].view(-1).float().sum(0, keepdim=True).item() top5 += correct[:5].view(-1).float().sum(0, keepdim=True).item() top1 *= 100 / len(val_dataset) top5 *= 100 / len(val_dataset) print(' Precision@1 ==> %.2f%% Precision@1: %.2f%%\n' % (top1, top5)) summary_writer.add_scalar('Precision@1', top1, epoch) summary_writer.add_scalar('Precision@5', top5, epoch) return for epoch in range(cfg.max_epochs): print('\nEpoch: %d lr: %.5f drop_path_prob: %.3f' % (epoch, scheduler.get_lr()[0], cfg.drop_path_prob * epoch / cfg.max_epochs)) model.module.drop_path_prob = cfg.drop_path_prob * epoch / cfg.max_epochs train_sampler.set_epoch(epoch) train(epoch) val(epoch) if epoch < 5: warmup.step(epoch) else: scheduler.step(epoch) # move to here after pytorch1.1.0 print(model.module.genotype()) if cfg.local_rank == 0: torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7')) summary_writer.close() count_parameters(model) count_flops(model, input_size=224)
get_hourglass = \ {'large_hourglass': exkp(n=5, nstack=2, dims=[256, 256, 384, 384, 384, 512], modules=[2, 2, 2, 2, 2, 4]), 'small_hourglass': exkp(n=5, nstack=1, dims=[256, 256, 384, 384, 384, 512], modules=[2, 2, 2, 2, 2, 4])} if __name__ == '__main__': from collections import OrderedDict from utils.utils import count_parameters, count_flops, load_model def hook(self, input, output): print(output.data.cpu().numpy().shape) # pass net = get_hourglass['large_hourglass'] load_model(net, '../ckpt/pretrain/checkpoint.t7') count_parameters(net) count_flops(net, input_size=512) for m in net.modules(): if isinstance(m, nn.Conv2d): m.register_forward_hook(hook) with torch.no_grad(): y = net(torch.randn(2, 3, 512, 512).cuda()) # print(y.size())
def main(): logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir) summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir) print = logger.info print(cfg) num_gpus = torch.cuda.device_count() if cfg.dist: device = torch.device( 'cuda:%d' % cfg.local_rank) if cfg.dist else torch.device('cuda') torch.cuda.set_device(cfg.local_rank) dist.init_process_group(backend='nccl', init_method='env://', world_size=num_gpus, rank=cfg.local_rank) else: device = torch.device('cuda') print('==> Preparing data..') cifar = 100 if 'cifar100' in cfg.log_name else 10 train_dataset = CIFAR_split(cifar=cifar, root=cfg.data_dir, split='train', ratio=1.0, transform=cifar_search_transform( is_training=True, cutout=cfg.cutout)) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=num_gpus, rank=cfg.local_rank) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size, shuffle=not cfg.dist, num_workers=cfg.num_workers, sampler=train_sampler if cfg.dist else None) test_dataset = CIFAR_split( cifar=cifar, root=cfg.data_dir, split='test', transform=cifar_search_transform(is_training=False)) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers) print('==> Building model..') print(os.path.join(cfg.ckpt_dir, 'seed-14880-best-genotype.pth')) #genotype = torch.load(os.path.join(cfg.ckpt_dir, 'seed-14880-best-genotype.pth')) genotype = seed14880 model = NetworkCIFAR(genotype, cfg.init_ch, cfg.num_cells, cfg.auxiliary, num_classes=cifar) if not cfg.dist: model = nn.DataParallel(model).to(device) else: # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) model = model.to(device) model = nn.parallel.DistributedDataParallel( model, device_ids=[ cfg.local_rank, ], output_device=cfg.local_rank) optimizer = torch.optim.SGD(model.parameters(), cfg.lr, momentum=0.9, weight_decay=cfg.wd) criterion = nn.CrossEntropyLoss().to(device) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, cfg.max_epochs) # Training def train(epoch): model.train() start_time = time.time() for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) # very important outputs, outputs_aux = model(inputs) loss = criterion(outputs, targets) loss_aux = criterion(outputs_aux, targets) loss += cfg.auxiliary * loss_aux optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 5.0) optimizer.step() if batch_idx % cfg.log_interval == 0: step = len(train_loader) * epoch + batch_idx duration = time.time() - start_time print('[%d/%d - %d/%d] cls_loss= %.5f (%d samples/sec)' % (epoch, cfg.max_epochs, batch_idx, len(train_loader), loss.item(), cfg.batch_size * cfg.log_interval / duration)) start_time = time.time() summary_writer.add_scalar('cls_loss', loss.item(), step) summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step) def test(epoch): model.eval() correct = 0 with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(test_loader): inputs, targets = inputs.to(device), targets.to( device, non_blocking=True) outputs, _ = model(inputs) _, predicted = torch.max(outputs.data, 1) correct += predicted.eq(targets.data).cpu().sum().item() acc = 100. * correct / len(test_loader.dataset) print(' Precision@1 ==> %.2f%% \n' % acc) summary_writer.add_scalar('Precision@1', acc, global_step=epoch) return for epoch in range(cfg.max_epochs): print('\nEpoch: %d lr: %.5f drop_path_prob: %.3f' % (epoch, scheduler.get_lr()[0], cfg.drop_path_prob * epoch / cfg.max_epochs)) model._modules[ 'module'].drop_path_prob = cfg.drop_path_prob * epoch / cfg.max_epochs train_sampler.set_epoch(epoch) train(epoch) test(epoch) scheduler.step(epoch) # move to here after pytorch1.1.0 #print(model.module.genotype()) if cfg.local_rank == 0: torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7')) summary_writer.close() count_parameters(model) count_flops(model, input_size=32)