def main(args):
    """Train and validate an ``hg1``/``hg2``/``hg8`` model on YogiDB data.

    The function loads an image set and its annotations from YogiDB, builds
    training and validation loaders over the same data, optionally resumes
    from a checkpoint, then runs the train/validate loop, logging metrics and
    saving a checkpoint after every epoch.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``image_set_name``, ``crop``, ``inp_res``, ``out_res``, ``mode``,
            ``train_batch``, ``test_batch``, ``workers``, ``checkpoint``,
            ``arch``, ``optimizer``, ``lr``, ``momentum``, ``weight_decay``,
            ``data_identifier``, ``resume``, ``start_epoch``, ``epochs``,
            ``schedule``, ``gamma``, ``debug`` and ``snapshot``.
            ``args.start_epoch`` is overwritten when resuming.

    Raises:
        Exception: If ``args.arch`` is not ``'hg1'``, ``'hg2'`` or ``'hg8'``.
        AssertionError: If ``args.resume`` is set but is not an existing file.
    """
    # Function-scope import so this block does not depend on the file header.
    import copy

    # Signature of the dataset class, as recorded by the original author:
    # Generic(data.Dataset)(image_set, annotations, is_train=True,
    #     inp_res=256, out_res=64, sigma=1, scale_factor=0, rot_factor=0,
    #     label_type='Gaussian', rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)
    annotations_source = 'basic-thresholder'

    # Get the data from yogi
    db_obj = YogiDB(config.db_url)
    imageset = db_obj.get_filtered(ImageSet, name=args.image_set_name)
    annotations = db_obj.get_annotations(
        image_set_name=args.image_set_name,
        annotation_source=annotations_source)

    # The number of output classes is the leading dimension of the first
    # annotation's 'joint_self' entry.
    pts = torch.Tensor(annotations[0]['joint_self'])
    num_classes = pts.size(0)

    # Cropping is enabled only when a truthy crop size is supplied; the
    # fallback size of 512 is still forwarded to the dataset.
    crop = bool(args.crop)
    crop_size = args.crop if crop else 512

    # Using the default RGB mean and std dev as 0
    # NOTE(review): a zero std dev would divide by zero if Generic normalises
    # by it - confirm how Generic uses rgb_stddev.
    rgb_mean = torch.as_tensor([0.0, 0.0, 0.0])
    rgb_stddev = torch.as_tensor([0.0, 0.0, 0.0])

    dataset = Generic(image_set=imageset,
                      inp_res=args.inp_res,
                      out_res=args.out_res,
                      annotations=annotations,
                      mode=args.mode,
                      crop=crop,
                      crop_size=crop_size,
                      rgb_mean=rgb_mean,
                      rgb_stddev=rgb_stddev)

    train_dataset = dataset
    train_dataset.is_train = True
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # Validation needs its own object: aliasing `dataset` here would make the
    # assignment below flip is_train to False for the training loader too.
    # A shallow copy shares the underlying image set and annotations.
    val_dataset = copy.copy(dataset)
    val_dataset.is_train = False
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # Select the hardware device to use for inference.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    # (presumably the epoch helpers re-enable them where needed - confirm)
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False, num_classes=num_classes)
    else:
        # Report the attribute that was actually dispatched on (args.arch).
        raise Exception('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    if args.optimizer == "Adam":
        # torch.optim.Adam has no `momentum` keyword (it uses `betas`), so
        # args.momentum is only forwarded to RMSprop.
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    else:
        optimizer = RMSprop(model.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    best_acc = 0

    # optionally resume from a checkpoint
    title = args.data_identifier + ' ' + args.arch
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss',
                          'Train Acc', 'Val Acc'])

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr,
                                  args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader, model,
                                                  device, optimizer)

        # evaluate on validation set; in debug mode extra arguments
        # (a flag, a CSV path and the 1-based epoch) are passed through.
        if args.debug == 1:
            valid_loss, valid_acc, predictions, _ = do_validation_epoch(
                val_loader, model, device, False, True,
                os.path.join(args.checkpoint, 'debug.csv'), epoch + 1)
        else:
            valid_loss, valid_acc, predictions, _ = do_validation_epoch(
                val_loader, model, device, False)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss,
                       train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
def main(args):
    """Train and validate an ``hg1``/``hg2``/``hg8`` model with ``Mpii`` data.

    Builds the model and an RMSprop optimiser, optionally restores both from
    a checkpoint, then for each epoch trains, validates, writes the metrics
    to the log (text and SVG plot) and saves a checkpoint.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``checkpoint``, ``arch``, ``lr``, ``momentum``, ``weight_decay``,
            ``resume``, ``start_epoch``, ``epochs``, ``image_path``,
            ``train_batch``, ``test_batch``, ``workers``, ``schedule``,
            ``gamma`` and ``snapshot``. ``args.start_epoch`` is overwritten
            when resuming.

    Raises:
        Exception: If ``args.arch`` is not ``'hg1'``, ``'hg2'`` or ``'hg8'``.
        AssertionError: If ``args.resume`` is set but is not an existing file.
    """
    # Pick the compute device; cuDNN autotuning is only switched on for CUDA.
    if not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True

    # Gradients are off globally from here on.
    # (presumably the epoch helpers re-enable them where needed - confirm)
    torch.set_grad_enabled(False)

    # Make sure the checkpoint directory exists before anything writes to it.
    os.makedirs(args.checkpoint, exist_ok=True)
    log_path = os.path.join(args.checkpoint, 'log.txt')

    # Reject unknown architectures up front, then build the requested one.
    if args.arch not in ('hg1', 'hg2', 'hg8'):
        raise Exception('unrecognised model architecture: ' + args.arch)
    if args.arch == 'hg1':
        net = hg1(pretrained=False)
    elif args.arch == 'hg2':
        net = hg2(pretrained=False)
    else:
        net = hg8(pretrained=False)
    model = DataParallel(net).to(device)

    optimizer = RMSprop(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    best_acc = 0

    # Either restore model/optimiser state and continue the existing log, or
    # start a fresh log with its column headers.
    if not args.resume:
        logger = Logger(log_path)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
    else:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        logger = Logger(log_path, resume=True)

    # Data loaders: shuffled for training, fixed order for validation.
    train_dataset = Mpii(args.image_path, is_train=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    val_dataset = Mpii(args.image_path, is_train=False)
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    plot_path = os.path.join(args.checkpoint, 'log.svg')

    # Main train/validate loop.
    lr = args.lr
    for epoch in trange(args.start_epoch, args.epochs,
                        desc='Overall', ascii=True):
        lr = adjust_learning_rate(
            optimizer, epoch, lr, args.schedule, args.gamma)

        # One pass over the training data.
        train_loss, train_acc = do_training_epoch(
            train_loader, model, device, Mpii.DATA_INFO, optimizer,
            acc_joints=Mpii.ACC_JOINTS)

        # One pass over the validation data.
        valid_loss, valid_acc, predictions = do_validation_epoch(
            val_loader, model, device, Mpii.DATA_INFO, False,
            acc_joints=Mpii.ACC_JOINTS)

        # Report metrics via tqdm.write.
        tqdm.write(
            f'[{epoch + 1:3d}/{args.epochs:3d}] lr={lr:0.2e} '
            f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
            f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

        # Persist the metrics row and refresh the accuracy plot.
        row = [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc]
        logger.append(row)
        logger.plot_to_file(plot_path, ['Train Acc', 'Val Acc'])

        # Track the best validation accuracy and write the checkpoint.
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        state = {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(state, predictions, is_best,
                        checkpoint=args.checkpoint,
                        snapshot=args.snapshot)

    logger.close()