import os
from logging import getLogger
from os.path import isfile, join

import cv2
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

import deepfly.pose2d.datasets

# The remaining names used below (hg, on_cuda, load_weights, models, Logger,
# DrosophilaDataset, get_max_img_id, process_folder, create_dataloader, step,
# Mode, train, validate, save_checkpoint, save_dict, savefig, read_camera_order,
# config, best_acc) are assumed to come from the surrounding package; their
# import lines are not part of this excerpt.

logger = getLogger('df3d')


def main(args):
    logger.debug(
        "Creating model '{}', stacks={}, blocks={}".format(
            args.arch, args.stacks, args.blocks
        )
    )
    model = hg(
        num_stacks=args.stacks,
        num_blocks=args.blocks,
        num_classes=args.num_classes,
        num_feats=args.features,
        inplanes=args.inplanes,
        init_stride=args.stride,
    )
    model = on_cuda(torch.nn.DataParallel(model))
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=5
    )
    if args.resume:
        load_weights(model, args.resume)
    logger.debug(
        "Total params: %.2fM"
        % (sum(p.numel() for p in model.parameters()) / 1000000.0)
    )

    if args.unlabeled:
        # inference on an unlabeled folder
        loader = DataLoader(
            DrosophilaDataset(
                data_folder=args.data_folder,
                train=False,
                sigma=args.sigma,
                session_id_train_list=None,
                folder_train_list=None,
                img_res=args.img_res,
                hm_res=args.hm_res,
                augmentation=False,
                evaluation=True,
                unlabeled=args.unlabeled,
                num_classes=args.num_classes,
                max_img_id=min(get_max_img_id(args.unlabeled), args.max_img_id),
                output_folder=args.output_folder,
            ),
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            drop_last=False,
        )
        pred, heatmap = process_folder(
            model,
            loader,
            args.unlabeled,
            args.output_folder,
            args.overwrite,
            num_classes=args.num_classes,
            acc_joints=args.acc_joints,
        )
        return pred, heatmap
    else:
        # supervised training
        train_loader, val_loader = create_dataloader()
        lr = args.lr
        best_acc = 0
        for epoch in range(args.start_epoch, args.epochs):
            logger.debug("\nEpoch: %d | LR: %.8f" % (epoch + 1, lr))
            # train for one epoch, then evaluate on the validation set
            step(
                loader=train_loader,
                model=model,
                optimizer=optimizer,
                mode=Mode.train,
                heatmap=False,
                epoch=epoch,
                num_classes=args.num_classes,
                acc_joints=args.acc_joints,
            )
            val_pred, _, val_loss, val_acc, val_mse = step(
                loader=val_loader,
                model=model,
                optimizer=optimizer,
                mode=Mode.test,
                heatmap=False,
                epoch=epoch,
                num_classes=args.num_classes,
                acc_joints=args.acc_joints,
            )
            scheduler.step(val_loss)
            # remember best acc and save checkpoint
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_acc": best_acc,
                    "image_shape": args.img_res,
                    "heatmap_shape": args.hm_res,
                },
                val_pred,
                is_best,
                checkpoint=args.checkpoint,
                snapshot=args.snapshot,
            )
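

# `on_cuda` and `load_weights` are used above but not defined in this excerpt.
# The two sketches below are minimal, hypothetical stand-ins, not the package's
# confirmed implementations.
def on_cuda(module):
    # Assumed behavior: move to the GPU only when one is available, so the
    # code also runs on CPU-only machines.
    return module.cuda() if torch.cuda.is_available() else module


def load_weights(model, checkpoint_path):
    # Assumed behavior: restore weights from a checkpoint written by
    # save_checkpoint(); the "state_dict" key matches the dictionaries
    # saved in both main() variants in this file.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)

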
def main(args):
    # A second, more verbose variant of main(); note that it shadows the
    # definition above if both are kept in a single module.
    global best_acc

    # create model
    getLogger('df3d').debug(
        "Creating model '{}', stacks={}, blocks={}".format(
            args.arch, args.stacks, args.blocks
        )
    )
    model = models.__dict__[args.arch](
        num_stacks=args.stacks,
        num_blocks=args.blocks,
        num_classes=args.num_classes,
        num_feats=args.features,
        inplanes=args.inplanes,
        init_stride=args.stride,
    )
    model = torch.nn.DataParallel(model).cuda()
    # reduction='mean' replaces the deprecated size_average=True
    criterion = torch.nn.MSELoss(reduction='mean').cuda()
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=5
    )

    # optionally resume from a checkpoint
    title = "Drosophila-" + args.arch
    if args.resume:
        if isfile(args.resume):
            getLogger('df3d').debug("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if "mpii" in args.resume and not args.unlabeled:
                # weights for the stacked hourglass trained on the MPII dataset:
                # drop the input/output scoring layers, whose shapes do not
                # match the Drosophila joint set
                getLogger('df3d').debug("Removing input/output layers")
                ignore_weight_list_template = [
                    "module.score.{}.bias",
                    "module.score.{}.weight",
                    "module.score_.{}.weight",
                ]
                ignore_weight_list = [
                    template.format(i)
                    for i in range(8)
                    for template in ignore_weight_list_template
                ]
                for k in ignore_weight_list:
                    if k in checkpoint["state_dict"]:
                        checkpoint["state_dict"].pop(k)
                state = model.state_dict()
                state.update(checkpoint["state_dict"])
                # getLogger('df3d').debug(model.state_dict())
                # getLogger('df3d').debug(checkpoint["state_dict"])
                model.load_state_dict(state, strict=False)
            elif "mpii" in args.resume and args.unlabeled:
                model.load_state_dict(checkpoint["state_dict"], strict=False)
            else:
                model.load_state_dict(checkpoint["state_dict"], strict=True)
            args.start_epoch = checkpoint["epoch"]
            args.img_res = checkpoint["image_shape"]
            args.hm_res = checkpoint["heatmap_shape"]
            getLogger('df3d').debug(
                "Setting image resolution and heatmap resolution: {} {}".format(
                    args.img_res, args.hm_res
                )
            )
            getLogger('df3d').debug(
                "Loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume)
            )

    # prepare loggers
    if not args.unlabeled:
        logger = Logger(join(args.checkpoint, "log.txt"), title=title)
        logger.set_names(
            [
                "Epoch",
                "LR",
                "Train Loss",
                "Val Loss",
                "Train Acc",
                "Val Acc",
                "Val Mse",
                "Val Jump",
            ]
        )

    # cudnn.benchmark = True
    getLogger('df3d').debug(
        "Total params: %.2fM"
        % (sum(p.numel() for p in model.parameters()) / 1000000.0)
    )

    if args.unlabeled:
        # strip a stray leading slash, otherwise os.path.join below would
        # treat the folder as an absolute path
        if args.unlabeled[0] == '/':
            args.unlabeled = args.unlabeled[1:]
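        # `get_max_img_id` is not defined in this excerpt; it is assumed to
        # scan the folder and return the largest image index present, e.g.
        # (hypothetical sketch; the filename pattern is an assumption):
        #
        #     def get_max_img_id(folder):
        #         ids = [int(m.group(1)) for f in os.listdir(folder)
        #                for m in [re.search(r"_(\d+)\.jpg$", f)] if m]
        #         return max(ids)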
        unlabeled_folder = args.unlabeled
        getLogger('df3d').debug("Unlabeled folder: {}".format(unlabeled_folder))
        max_img_id = get_max_img_id(unlabeled_folder)
        if getattr(args, "num_images_max", None) is not None:
            # cap the number of processed images when requested
            max_img_id = min(max_img_id, args.num_images_max - 1)
        getLogger('df3d').debug("Going to process {} images".format(max_img_id + 1))
        unlabeled_loader = DataLoader(
            deepfly.pose2d.datasets.Drosophila(
                data_folder=args.data_folder,
                train=False,
                sigma=args.sigma,
                session_id_train_list=None,
                folder_train_list=None,
                img_res=args.img_res,
                hm_res=args.hm_res,
                augmentation=False,
                evaluation=True,
                unlabeled=unlabeled_folder,
                num_classes=args.num_classes,
                max_img_id=max_img_id,
            ),
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            drop_last=False,
        )
        valid_loss, valid_acc, val_pred, val_score_maps, mse, jump_acc = validate(
            unlabeled_loader, 0, model, criterion, args, save_path=unlabeled_folder
        )
        unlabeled_folder_replace = unlabeled_folder.replace("/", "-")
        getLogger('df3d').debug(f"val_score_maps have shape: {val_score_maps.shape}")
        getLogger('df3d').debug("Saving results, flipping heatmaps")
        # camera ids whose predictions and heatmaps must be mirrored
        cid_to_reverse = config["flip_cameras"]
        cidread2cid, cid2cidread = read_camera_order(
            os.path.join(unlabeled_folder, 'df3d')
        )
        cid_read_to_reverse = [cid2cidread[cid] for cid in cid_to_reverse]
        getLogger('df3d').debug(
            "Flipping heatmaps for images with cam_id: {}".format(cid_read_to_reverse)
        )
        # mirror the x coordinate of the predictions
        val_pred[cid_read_to_reverse, :, :, 0] = (
            1 - val_pred[cid_read_to_reverse, :, :, 0]
        )
        # mirror the heatmaps horizontally
        for cam_id in cid_read_to_reverse:
            for img_id in range(val_score_maps.shape[1]):
                for j_id in range(val_score_maps.shape[2]):
                    val_score_maps[cam_id, img_id, j_id, :, :] = cv2.flip(
                        val_score_maps[cam_id, img_id, j_id, :, :], 1
                    )
        save_dict(
            val_pred,
            os.path.join(
                args.data_folder,
                unlabeled_folder,
                args.output_folder,
                "preds_{}.pkl".format(unlabeled_folder_replace),
            ),
        )
        getLogger('df3d').debug("Finished saving results")
    else:
        train_loader, val_loader = create_dataloader()
        lr = args.lr
        for epoch in range(args.start_epoch, args.epochs):
            # lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
            getLogger('df3d').debug("\nEpoch: %d | LR: %.8f" % (epoch + 1, lr))

            # train for one epoch
            train_loss, train_acc, train_predictions, train_mse, train_mse_jump = train(
                train_loader, epoch, model, optimizer, criterion, args
            )

            # evaluate on the validation set
            valid_loss, valid_acc, val_pred, val_score_maps, mse, jump_acc = validate(
                val_loader, epoch, model, criterion, args, save_path=args.unlabeled
            )
            scheduler.step(valid_loss)

            # append logger file
            logger.append(
                [
                    epoch + 1,
                    lr,
                    train_loss,
                    valid_loss,
                    train_acc,
                    valid_acc,
                    mse,
                    jump_acc,
                ]
            )

            # remember best acc and save checkpoint
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_acc": best_acc,
                    "multiview": args.multiview,
                    "image_shape": args.img_res,
                    "heatmap_shape": args.hm_res,
                },
                val_pred,
                is_best,
                checkpoint=args.checkpoint,
            )

            fig = plt.figure()
            logger.plot(["Train Acc", "Val Acc"])
            savefig(os.path.join(args.checkpoint, "log.eps"))
            plt.close(fig)

        logger.close()

    return val_score_maps, val_pred
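

# `save_dict` is used above but not defined in this excerpt. The file
# extension ("preds_*.pkl") suggests pickle serialization; the sketch below
# is a minimal, hypothetical stand-in, not the package's confirmed helper.
def save_dict(obj, path):
    import pickle

    # create intermediate folders, then pickle the predictions
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)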