def main():
    """Entry point: parse CLI args, build the TSN model and data loaders,
    then run the train/validate loop with periodic checkpointing.

    Relies on module-level globals defined elsewhere in this file:
    ``parser``, ``device``, ``writer``, ``best_prec1`` (assumed initialised
    at module scope before this runs — TODO confirm), plus the helpers
    ``train``, ``validate`` and ``save_checkpoint``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Map dataset name to its number of action classes.
    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    # Number of consecutive frames sampled per segment for each modality.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        data_length = 5  # generate 5 displacement map, using 6 RGB images

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type, dropout=args.dropout,
                new_length=data_length)
    model = model.to(device)

    # Cache preprocessing parameters exposed by the model before it is
    # (possibly) wrapped in DataParallel, which would hide these attributes.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()

    if device.type == 'cuda':
        model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Optionally resume training state from a checkpoint file.
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            # Fixed: this message previously printed args.evaluate instead
            # of the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code: RGBDiff normalises implicitly via frame differencing,
    # so it uses the identity transform instead of mean/std normalisation.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet("", args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff", "CV"]
                   else args.flow_prefix + "{}_{:05d}.jpg",
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       # BNInception expects BGR channel order and 0-255 input,
                       # hence the roll/div toggles below.
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet("", args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff", "CV"]
                   else args.flow_prefix + "{}_{:05d}.jpg",
                   random_shift=False,  # deterministic sampling for validation
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, args.lr_steps, gamma=0.1)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(0, args.epochs):
        # NOTE(review): stepping the scheduler before training follows the
        # pre-PyTorch-1.1 convention; it also fast-forwards the LR through
        # the epochs skipped below when resuming. Left unchanged to preserve
        # the existing LR schedule — confirm against the PyTorch version used.
        scheduler.step()
        if epoch < args.start_epoch:
            continue

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

    writer.close()
def create_model(args):
    """Create a PyTorch TSN model from the parsed command-line arguments.

    Args:
        args: argparse namespace. Fields read here:
            pretrained (bool): recorded for logging only — NOTE(review): it is
                NOT forwarded to the TSN constructor; confirm whether pretrained
                weights are handled inside TSN via ``base_model``.
            dataset (str): dataset name; must be in ``SUPPORTED_DATASETS``
                (compared case-insensitively).
            num_classes (int): number of output classes.
            load_serialized (bool): if set, skip DataParallel wrapping.
            gpus: device IDs for DataParallel; ``-1`` forces CPU even when
                CUDA is available.
            arch (str): base-model architecture name.
            num_segments, modality, consensus_type, dropout, no_partialbn:
                forwarded to the TSN constructor.

    Returns:
        The model, moved to 'cuda' or 'cpu', with ``arch``, ``dataset`` and
        ``is_parallel`` attributes attached for later introspection.

    Raises:
        ValueError: if the dataset is unsupported, or the dataset/arch pair
            cannot be constructed.
    """
    pretrained = args.pretrained
    num_classes = args.num_classes
    parallel = not args.load_serialized
    device_ids = args.gpus
    arch = args.arch

    dataset = args.dataset.lower()
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError('Dataset {} is not supported'.format(dataset))

    try:
        model = TSN(num_classes, args.num_segments, args.modality,
                    base_model=arch,
                    consensus_type=args.consensus_type,
                    dropout=args.dropout,
                    partial_bn=not args.no_partialbn)
    except ValueError:
        # Re-raise with a message that identifies the failing combination.
        raise ValueError(
            'Could not recognize dataset {} and arch {} pair'.format(
                dataset, arch))

    msglogger.info("=> created a %s%s model with the %s dataset" %
                   ('pretrained ' if pretrained else '', arch, dataset))

    if torch.cuda.is_available() and device_ids != -1:
        device = 'cuda'
        if parallel:
            # For alexnet/vgg-style nets only the feature extractor is
            # parallelised; the classifier head stays on one device.
            if arch.startswith('alexnet') or ('vgg' in arch):
                model.features = torch.nn.DataParallel(
                    model.features, device_ids=device_ids)
            else:
                model = torch.nn.DataParallel(model, device_ids=device_ids)
        model.is_parallel = parallel
    else:
        device = 'cpu'
        model.is_parallel = False

    # Cache some attributes which describe the model
    model.arch = arch
    model.dataset = dataset
    return model.to(device)