def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and, if available, CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader configured from ``config``.

    Batch size comes from ``config.TRAIN.BATCH_SIZE`` and the worker count
    from ``config.MISC.WORKERS``; pinned memory is always enabled.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.TRAIN.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=config.MISC.WORKERS,
        pin_memory=True)


def _evaluate_and_log(net, loader, logger, tag_prefix, is_target):
    """Run ``evaluate`` on ``loader`` and log both returned metrics.

    Args:
        net: model passed through to ``evaluate``.
        loader: validation DataLoader to evaluate on.
        logger: project logger; receives the scalars and is forwarded to
            ``evaluate``.
        tag_prefix: string prepended to the scalar tags so that each
            validation set gets its own curves.
        is_target: forwarded unchanged to ``evaluate``.

    Returns:
        The ``(pck05, pck2)`` pair returned by ``evaluate``.
    """
    pck05, pck2 = evaluate(
        net, loader,
        img_size=config.MODEL.IMG_SIZE,
        vis=True,
        logger=logger,
        disp_interval=config.MISC.DISP_INTERVAL,
        # Ground truth is only requested on the very first evaluation.
        show_gt=(logger.global_step == 0),
        is_target=is_target)
    # Metrics are logged as percentages.
    logger.add_scalar(tag_prefix + 'PCK@0.05', pck05 * 100)
    logger.add_scalar(tag_prefix + 'PCK@0.2', pck2 * 100)
    return pck05, pck2


def main():
    """Train the pose network on the source dataset, or only evaluate it.

    Everything is driven by the module-level ``config``:

    * builds the source training set and the source/target validation sets;
    * creates the network and an Adam optimizer, optionally resuming both
      from ``config.MODEL.RESUME``;
    * if ``config.EVALUATE`` is set, evaluates once on the target validation
      set, prints the result and returns;
    * otherwise trains until ``logger.global_step`` reaches
      ``config.TRAIN.MAX_ITER``, evaluating on both validation sets and
      saving a checkpoint every ``config.MISC.TEST_INTERVAL`` steps.

    Raises:
        ValueError: if ``config.MISC.TEST_INTERVAL`` is 0, or if the training
            loader yields no batches.
    """
    cudnn.benchmark = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Compare by value (not identity) and raise explicitly so the check
    # survives ``python -O``; a zero interval would otherwise fail later in
    # the modulo test inside the training loop.
    if config.MISC.TEST_INTERVAL == 0:
        raise ValueError('Illegal setting: config.MISC.TEST_INTERVAL = 0!')

    # A falsy seed (e.g. 0 or None) leaves all RNGs unseeded.
    if config.MISC.RANDOM_SEED:
        _seed_everything(config.MISC.RANDOM_SEED)

    train_transformer = Mytransforms.Compose([
        Mytransforms.KeyAreaCrop(20),
        Mytransforms.RandomRotate(40),
        Mytransforms.TestResized(config.MODEL.IMG_SIZE),
        Mytransforms.RandomHorizontalFlip(),
    ])
    # Same crop/resize as training, without the random rotate/flip steps.
    test_transformer = Mytransforms.Compose([
        Mytransforms.KeyAreaCrop(20),
        Mytransforms.TestResized(config.MODEL.IMG_SIZE),
    ])

    # Only the source training set is loaded; no target-domain training data
    # is used here.
    source_dset = HandKptDataset(
        config.DATA.SOURCE.TRAIN.DIR, config.DATA.SOURCE.TRAIN.LBL_FILE,
        stride=config.MODEL.HEATMAP_STRIDE, transformer=train_transformer)
    source_val_dset = HandKptDataset(
        config.DATA.SOURCE.VAL.DIR, config.DATA.SOURCE.VAL.LBL_FILE,
        stride=config.MODEL.HEATMAP_STRIDE, transformer=test_transformer)
    target_val_dset = HandKptDataset(
        config.DATA.TARGET.VAL.DIR, config.DATA.TARGET.VAL.LBL_FILE,
        stride=config.MODEL.HEATMAP_STRIDE, transformer=test_transformer)

    train_loader = _build_loader(source_dset, shuffle=True)
    source_val_loader = _build_loader(source_val_dset, shuffle=False)
    target_val_loader = _build_loader(target_val_dset, shuffle=False)

    logger = Logger(
        ckpt_path=os.path.join(config.DATA.CKPT_PATH, config.PROJ_NAME),
        tsbd_path=os.path.join(config.DATA.VIZ_PATH, config.PROJ_NAME))

    net = pose_resnet.get_pose_net(config).to(device)
    optimizer = torch.optim.Adam(
        net.parameters(), config.TRAIN.BASE_LR,
        weight_decay=config.TRAIN.WEIGHT_DECAY)

    input_shape = (config.TRAIN.BATCH_SIZE, 3,
                   config.MODEL.IMG_SIZE, config.MODEL.IMG_SIZE)
    logger.add_graph(net, input_shape, device)

    # Truthiness also covers RESUME being None, not just an empty string.
    if config.MODEL.RESUME:
        print("=> loading checkpoint '{}'".format(config.MODEL.RESUME))
        # map_location lets a checkpoint saved on GPU be resumed on a
        # CPU-only machine (and vice versa).
        resume_ckpt = torch.load(config.MODEL.RESUME, map_location=device)
        net.load_state_dict(resume_ckpt['net'])
        optimizer.load_state_dict(resume_ckpt['optim'])
        config.TRAIN.START_ITERS = resume_ckpt['iter']
        logger.global_step = resume_ckpt['iter']
        logger.best_metric_val = resume_ckpt['best_metric_val']

    # Wrapped after loading: checkpoints store the unwrapped module's weights
    # (see ``net.module.state_dict()`` below).
    net = torch.nn.DataParallel(net)

    if config.EVALUATE:
        pck05, pck2 = evaluate(
            net, target_val_loader,
            img_size=config.MODEL.IMG_SIZE,
            vis=True,
            logger=logger,
            disp_interval=config.MISC.DISP_INTERVAL)
        print("=> validate PCK@0.05 = {}, PCK@0.2 = {}".format(
            pck05 * 100, pck2 * 100))
        return

    # Without this guard an empty loader would make the outer ``while`` loop
    # spin forever, because global_step would never advance.
    if len(train_loader) == 0:
        raise ValueError('Training loader is empty: no batches to train on.')

    # reduction='none' keeps per-element losses; they are summed and divided
    # by the batch size below.
    criterion = nn.SmoothL1Loss(reduction='none').to(device)

    with tqdm.tqdm(desc='Train iter', ncols=80,
                   total=config.TRAIN.MAX_ITER,
                   initial=config.TRAIN.START_ITERS) as total_progress_bar:
        while logger.global_step < config.TRAIN.MAX_ITER:
            epoch_bar = tqdm.tqdm(
                train_loader, total=len(train_loader),
                desc='Current epoch', ncols=80, leave=False)
            for stu_inputs, stu_heatmap, _ in epoch_bar:
                stu_inputs = stu_inputs.to(device)
                stu_heatmap = stu_heatmap.to(device)

                stu_heats = net(stu_inputs)
                loss = (criterion(stu_heats, stu_heatmap).sum()
                        / stu_inputs.size(0))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Periodic validation and checkpointing.
                if logger.global_step % config.MISC.TEST_INTERVAL == 0:
                    # Each validation set / threshold gets its own tag so the
                    # four curves do not collide under one name.
                    _evaluate_and_log(net, source_val_loader, logger,
                                      tag_prefix='src_', is_target=False)
                    pck05, _ = _evaluate_and_log(
                        net, target_val_loader, logger,
                        tag_prefix='tgt_', is_target=True)
                    # The target-set pck05 is the metric used for
                    # checkpoint selection.
                    logger.save_ckpt(state={
                        'net': net.module.state_dict(),
                        'optim': optimizer.state_dict(),
                        'iter': logger.global_step,
                        'best_metric_val': logger.best_metric_val,
                    }, cur_metric_val=pck05)
                    # NOTE(review): evaluate() is not visible here. If it
                    # switches the network to eval mode without restoring
                    # it, a net.train() call is needed at this point.

                logger.step(1)
                total_progress_bar.update(1)

                # The loss is logged after the step counter has advanced.
                logger.add_scalar('regress_loss', loss.item())

                # Stop exactly at MAX_ITER instead of finishing the epoch;
                # the outer condition is only re-checked between epochs.
                if logger.global_step >= config.TRAIN.MAX_ITER:
                    break