def _make_torch_data_loaders(args):
    """
    Helper function to produce data loaders for training.

    :param args: Command line arguments.
    :return: Training dataset, training data loader, validation data loader.
    """
    train_dataset = datasets.Mpii(
        'stacked_hourglass/data/mpii/mpii_annotations.json',
        'stacked_hourglass/data/mpii/images',
        sigma=args.sigma,
        label_type=args.label_type,
        augment_data=args.augment_training_data,
        args=args)
    val_dataset = datasets.Mpii(
        'stacked_hourglass/data/mpii/mpii_annotations.json',
        'stacked_hourglass/data/mpii/images',
        sigma=args.sigma,
        label_type=args.label_type,
        train=False,
        augment_data=False,
        args=args)

    if args.use_horovod:
        # Shard the data across Horovod workers: each rank sees a distinct
        # subset. `sampler` and `shuffle` are mutually exclusive, so shuffling
        # is delegated to the DistributedSampler.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.workers,
            pin_memory=True)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.test_batch_size,
            sampler=val_sampler,
            num_workers=args.workers,
            pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    return train_dataset, train_loader, val_loader
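
# Note: when iterating with a DistributedSampler, PyTorch expects
# `sampler.set_epoch(epoch)` to be called at the start of every epoch so that
# shuffling differs across epochs. A minimal sketch of how a caller might
# drive the loaders returned above (the surrounding training loop is assumed,
# not shown in this file):
#
#     train_dataset, train_loader, val_loader = _make_torch_data_loaders(args)
#     for epoch in range(args.epochs):
#         if args.use_horovod:
#             train_loader.sampler.set_epoch(epoch)
#         ...  # run one training epoch
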
def main(args):
    global best_acc

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=args.num_classes)
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    # (`size_average=True` is deprecated; `reduction='mean'` is equivalent)
    criterion = torch.nn.MSELoss(reduction='mean').cuda()
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            # fall back to a fresh log so `logger` is always defined
            print("=> no checkpoint found at '{}'".format(args.resume))
            logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss',
                              'Train Acc', 'Val Acc'])
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss',
                          'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json',
                      'data/mpii/images',
                      sigma=args.sigma,
                      label_type=args.label_type),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json',
                      'data/mpii/images',
                      sigma=args.sigma,
                      label_type=args.label_type,
                      train=False),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.evaluate:
        print('\nEvaluation only')
        loss, acc, predictions = validate(val_loader, model, criterion,
                                          args.num_classes, args.debug,
                                          args.flip)
        save_pred(predictions, checkpoint=args.checkpoint)
        return

    from time import sleep

    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        sleep(2)  # short pause between epochs
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = validate(val_loader, model,
                                                      criterion,
                                                      args.num_classes,
                                                      args.debug, args.flip)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc,
                       valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
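
# `adjust_learning_rate` is called above but not defined in this file. The
# sketch below (a hypothetical helper, not the repository's implementation)
# shows the behavior assumed by the call site: step decay that multiplies the
# LR by `gamma` whenever `epoch` hits a milestone listed in `schedule`.
def _adjust_learning_rate_sketch(optimizer, epoch, lr, schedule, gamma):
    """Decay `lr` by `gamma` when `epoch` reaches a milestone in `schedule`."""
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr
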
def main(args):
    global best_acc

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=len(args.index_classes))
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = models.loss.UniLoss(a_points=args.a_points)
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    if args.resume:
        if isfile(args.resume):
            _logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # restore the command-line LR rather than the checkpointed one
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr
                print(param_group['lr'])
            _logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=False)
            logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss',
                              'Train Acc', 'Val Acc'])
        else:
            # fall back to a fresh log so `logger` is always defined
            _logger.info("=> no checkpoint found at '{}'".format(args.resume))
            logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss',
                              'Train Acc', 'Val Acc'])
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss',
                          'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    _logger.info('    Total params: %.2fM' %
                 (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json',
                      'data/mpii/images',
                      sigma=args.sigma,
                      label_type=args.label_type,
                      _idx=args.index_classes,
                      direct=True,
                      n_points=args.n_points),
        batch_size=args.train_batch,
        shuffle=True,
        collate_fn=datasets.mpii.mycollate,
        num_workers=args.workers,
        pin_memory=False)

    val_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json',
                      'data/mpii/images',
                      sigma=args.sigma,
                      label_type=args.label_type,
                      _idx=args.index_classes,
                      train=False,
                      direct=True),
        batch_size=args.test_batch,
        shuffle=False,
        collate_fn=datasets.mpii.mycollate,
        num_workers=args.workers,
        pin_memory=False)

    if args.evaluate:
        _logger.warning('\nEvaluation only')
        loss, acc, predictions = validate(val_loader,
                                          model,
                                          criterion,
                                          len(args.index_classes),
                                          False,
                                          args.flip,
                                          _logger,
                                          evaluate_only=True)
        save_pred(predictions, checkpoint=args.checkpoint)
        return

    # queues for the multi-threaded training path (consumed by `train`)
    inqueues = []
    outqueues = []

    valid_accs = []
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        _logger.warning('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(inqueues, outqueues, train_loader,
                                      model, criterion, optimizer, args.debug,
                                      args.flip, args.clip, _logger)

        # evaluate on validation set
        with torch.no_grad():
            valid_loss, valid_acc, predictions = validate(
                val_loader, model, criterion, len(args.index_classes),
                args.debug, args.flip, _logger)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc,
                       valid_acc])
        valid_accs.append(valid_acc)
        # adaptive schedule: when `args.schedule` starts with -1, queue an LR
        # decay once validation accuracy plateaus, i.e. the mean of the last
        # four epochs fails to beat the mean of the four before by ~1%
        if args.schedule[0] == -1:
            if len(valid_accs) > 8:
                if sum(valid_accs[-4:]) / 4 * 0.99 < sum(valid_accs[-8:-4]) / 4:
                    args.schedule.append(epoch + 1)
                    valid_accs = []

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=1)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
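
# A minimal, self-contained sketch of the plateau rule used in the loop above
# (a hypothetical helper, not part of the repository): returns True when the
# mean of the last four accuracies does not exceed the mean of the preceding
# four by at least ~1%.
def _plateaued_sketch(accs):
    """Plateau test over a history of per-epoch validation accuracies."""
    if len(accs) <= 8:
        return False
    recent = sum(accs[-4:]) / 4
    previous = sum(accs[-8:-4]) / 4
    return recent * 0.99 < previous

# e.g. _plateaued_sketch([74.0] * 9) -> True (no improvement over 9 epochs)
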