def main():
    """Entry point: parse args, build data loaders and the PeleeNet model,
    optionally resume / load pretrained weights, then evaluate, tune, or train.

    Relies on module-level globals defined elsewhere in the file:
    ``parser`` (argparse parser), ``best_acc1`` (assumed initialized, e.g. to 0,
    at module level — TODO confirm), plus ``PeleeNet``, ``train``, ``validate``
    and ``save_checkpoint``.
    """
    global args, best_acc1
    args = parser.parse_args()
    print('args:', args)

    # Single-node runs have world_size <= 1; anything larger is distributed.
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # ---- Validation data loading -------------------------------------------
    valdir = os.path.join(args.data, 'val')
    # ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            # Resize slightly larger than the crop so CenterCrop has margin.
            transforms.Resize(args.input_dim + 32),
            transforms.CenterCrop(args.input_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Class count is inferred from the validation folder layout.
    num_classes = len(val_dataset.classes)
    print('Total classes: ', num_classes)

    # ---- Model creation -----------------------------------------------------
    print("=> creating model '{}'".format(args.arch))
    if args.arch == 'peleenet':
        model = PeleeNet(num_classes=num_classes)
    else:
        # Only PeleeNet is implemented; anything else falls back to it.
        print(
            "=> unsupported model '{}'. creating PeleeNet by default.".format(
                args.arch))
        model = PeleeNet(num_classes=num_classes)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        # DataParallel will divide the batch across visible GPUs.
        model = torch.nn.DataParallel(model)

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ---- Optionally resume from a checkpoint --------------------------------
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.pretrained:
        if os.path.isfile(args.weights):
            # map_location='cpu' so weights load even without a GPU.
            checkpoint = torch.load(args.weights,
                                    map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint['state_dict'])
            # FIX: report the actual weights path, not the `pretrained` flag.
            print("=> loaded checkpoint '{}' (epoch {}, acc@1 {})".format(
                args.weights, checkpoint['epoch'], checkpoint['best_acc1']))
        else:
            # FIX: the file that was checked is args.weights, not args.resume.
            print("=> no checkpoint found at '{}'".format(args.weights))

    # ---- Evaluate-only mode -------------------------------------------------
    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # ---- Quantization tuning mode (terminates the process) ------------------
    if args.tune:
        model.eval()
        model.module.fuse_model()
        import ilit
        tuner = ilit.Tuner("./conf.yaml")
        tuner.tune(model)
        exit(0)

    # ---- Training data loading ----------------------------------------------
    traindir = os.path.join(args.data, 'train')
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.input_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # Shuffle only when no sampler drives the ordering.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # ---- Training loop ------------------------------------------------------
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion)

        # remember best Acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
# NOTE(review): orphan fragment of a *second* training loop — its enclosing
# `for epoch ...` is not visible in this chunk, and every name it uses
# (train_set, val_set, optimizer_conv, criteria1, schedule, writer, count,
# best_acc, train_warmup, add_summery, save_step) is defined elsewhere.
# Indentation below is reconstructed from data-flow: train_loss exists only
# in the else-branch, while val_top1 is needed by the save step every epoch,
# so validation is placed at loop level — TODO confirm against the original.
if epoch < 2:
    # Warm-up phase for the first two epochs; presumably uses a gentler
    # schedule — verify against train_warmup's definition.
    count = train_warmup(train_set, val_set, model, optimizer_conv, criteria1,
                         epoch, writer, count)
else:
    count, train_loss = train(train_set, val_set, model, optimizer_conv,
                              criteria1, writer, count, epoch)
    print("train_loss : {0}, lr : {1}".format(
        train_loss, optimizer_conv.param_groups[0]['lr']))
    # Scheduler steps on the training loss (ReduceLROnPlateau-style usage —
    # TODO confirm scheduler type).
    schedule.step(train_loss)
    #schedule.step()
# Validate and checkpoint every epoch, warm-up included.
val_top1, val_top3, val_top5, val_loss = validate(
    val_set, model, criteria1)
# "add_summery" [sic] — external helper; appears to log scalars to `writer`.
writer = add_summery(writer, 'val', val_loss, val_top1, val_top5, count)
if_best_model = (val_top1 > best_acc)
best_acc = max(val_top1, best_acc)
# NOTE(review): no separator between the epoch number and "checkpoint", so
# files are named e.g. "epoch_3checkpoint.pth.tar" — confirm intended.
filepath = 'weights/epoch_' + str(epoch) + 'checkpoint.pth.tar'
save_step(
    {
        'epoch': epoch + 1,
        'arch': 'peleenet',
        'state_dict': model.state_dict(),
        'acc': val_top1,
    }, if_best_model, 'test', filename=filepath)