Example #1
# NOTE: excerpt of a training/inference entry point; assumes the surrounding
# package provides hg, on_cuda, load_weights, create_dataloader, step, Mode,
# process_folder, save_checkpoint, get_max_img_id, DrosophilaDataset, and a
# module-level logger.
import torch
from torch.utils.data import DataLoader
def main(args):
    logger.debug("Creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = hg(
        num_stacks=args.stacks,
        num_blocks=args.blocks,
        num_classes=args.num_classes,
        num_feats=args.features,
        inplanes=args.inplanes,
        init_stride=args.stride,
    )
    model = on_cuda(torch.nn.DataParallel(model))
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
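    # ReduceLROnPlateau lowers the learning rate (default factor 0.1) once the
    # metric it is stepped with, here the validation loss, has not improved
    # for `patience` epochs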
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           verbose=True,
                                                           patience=5)

    if args.resume:
        load_weights(model, args.resume)

    logger.debug("Total params: %.2fM" %
                 (sum(p.numel() for p in model.parameters()) / 1000000.0))

    if args.unlabeled:
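        # inference over an unlabeled folder: deterministic order, no
        # augmentation, and no dropped images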
        loader = DataLoader(
            DrosophilaDataset(
                data_folder=args.data_folder,
                train=False,
                sigma=args.sigma,
                session_id_train_list=None,
                folder_train_list=None,
                img_res=args.img_res,
                hm_res=args.hm_res,
                augmentation=False,
                evaluation=True,
                unlabeled=args.unlabeled,
                num_classes=args.num_classes,
                max_img_id=min(get_max_img_id(args.unlabeled),
                               args.max_img_id),
                output_folder=args.output_folder,
            ),
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            drop_last=False,
        )
        pred, heatmap = process_folder(
            model,
            loader,
            args.unlabeled,
            args.output_folder,
            args.overwrite,
            num_classes=args.num_classes,
            acc_joints=args.acc_joints,
        )
        return pred, heatmap
    else:
        train_loader, val_loader = create_dataloader()
        best_acc = 0
        for epoch in range(args.start_epoch, args.epochs):
            # ReduceLROnPlateau mutates the optimizer's lr in place, so read
            # the current value rather than logging the stale args.lr
            lr = optimizer.param_groups[0]["lr"]
            logger.debug("\nEpoch: %d | LR: %.8f" % (epoch + 1, lr))

            _, _, _, _, _ = step(
                loader=train_loader,
                model=model,
                optimizer=optimizer,
                mode=Mode.train,
                heatmap=False,
                epoch=epoch,
                num_classes=args.num_classes,
                acc_joints=args.acc_joints,
            )
            val_pred, _, val_loss, val_acc, val_mse = step(
                loader=val_loader,
                model=model,
                optimizer=optimizer,
                mode=Mode.test,
                heatmap=False,
                epoch=epoch,
                num_classes=args.num_classes,
                acc_joints=args.acc_joints,
            )
            scheduler.step(val_loss)
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                    "image_shape": args.img_res,
                    "heatmap_shape": args.hm_res,
                },
                val_pred,
                is_best,
                checkpoint=args.checkpoint,
                snapshot=args.snapshot,
            )
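A minimal sketch of driving this entry point programmatically. Every Namespace field below is an illustrative placeholder, not the project's actual defaults:

import argparse

# hypothetical values; the project's CLI parser defines the real defaults
args = argparse.Namespace(
    arch="hg", stacks=2, blocks=1, num_classes=19,
    features=256, inplanes=64, stride=2,
    lr=2.5e-4, momentum=0.0, weight_decay=0.0,
    resume="", unlabeled=None, acc_joints=None,
    start_epoch=0, epochs=100,
    img_res=(256, 512), hm_res=(64, 128),
    checkpoint="checkpoint/", snapshot=1,
)
main(args)  # a falsy `unlabeled` selects the supervised training branch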
Example #2
# NOTE: excerpt of the df3d pose2d entry point; assumes the surrounding
# package provides models, Logger, train, validate, save_checkpoint,
# create_dataloader, get_max_img_id, read_camera_order, save_dict, config,
# and a module-level best_acc.
import os
from os.path import isfile, join
from logging import getLogger

import cv2
import torch
import matplotlib.pyplot as plt
from matplotlib.pyplot import savefig
from torch.utils.data import DataLoader

import deepfly.pose2d.datasets
def main(args):
    global best_acc

    # create model
    getLogger('df3d').debug("Creating model '{}', stacks={}, blocks={}".format(
            args.arch, args.stacks, args.blocks
        )
    )
    model = models.__dict__[args.arch](
        num_stacks=args.stacks,
        num_blocks=args.blocks,
        num_classes=args.num_classes,
        num_feats=args.features,
        inplanes=args.inplanes,
        init_stride=args.stride,
    )

    model = torch.nn.DataParallel(model).cuda()
    # reduction='mean' replaces the deprecated size_average=True
    criterion = torch.nn.MSELoss(reduction='mean').cuda()
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=5
    )

    # optionally resume from a checkpoint
    title = "Drosophila-" + args.arch
    if args.resume:
        if isfile(args.resume):
            getLogger('df3d').debug("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if "mpii" in args.resume and not args.unlabeled:  # weights for sh trained on mpii dataset
                getLogger('df3d').debug("Removing input/output layers")
                ignore_weight_list_template = [
                    "module.score.{}.bias",
                    "module.score.{}.weight",
                    "module.score_.{}.weight",
                ]
                ignore_weight_list = [
                    template.format(i)
                    for i in range(8)
                    for template in ignore_weight_list_template
                ]
                for k in ignore_weight_list:
                    if k in checkpoint["state_dict"]:
                        checkpoint["state_dict"].pop(k)

                state = model.state_dict()
                state.update(checkpoint["state_dict"])
                # log parameter names only; dumping full tensors floods the log
                getLogger('df3d').debug(sorted(state.keys()))
                getLogger('df3d').debug(sorted(checkpoint["state_dict"].keys()))
                model.load_state_dict(state, strict=False)
            elif "mpii" in args.resume and args.unlabeled:
                model.load_state_dict(checkpoint['state_dict'], strict=False)
            else:
                pretrained_dict = checkpoint["state_dict"]
                model.load_state_dict(pretrained_dict, strict=True)

                args.start_epoch = checkpoint["epoch"]
                args.img_res = checkpoint["image_shape"]
                args.hm_res = checkpoint["heatmap_shape"]
                getLogger('df3d').debug("Loading the optimizer")
                getLogger('df3d').debug(
                    "Setting image resolution and heatmap resolution: {} {}".format(
                        args.img_res, args.hm_res
                    )
                )
            getLogger('df3d').debug(
                "Loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            raise FileNotFoundError(
                "no checkpoint found at '{}'".format(args.resume)
            )

    # prepare loggers
    if not args.unlabeled:
        logger = Logger(join(args.checkpoint, "log.txt"), title=title)
        logger.set_names(
            [
                "Epoch",
                "LR",
                "Train Loss",
                "Val Loss",
                "Train Acc",
                "Val Acc",
                "Val Mse",
                "Val Jump",
            ]
        )

    # cudnn.benchmark = True
    getLogger('df3d').debug("Total params: %.2fM" % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    if args.unlabeled:
        # strip a leading slash so the folder joins cleanly with
        # args.data_folder below
        if args.unlabeled[0] == '/':
            args.unlabeled = args.unlabeled[1:]
        unlabeled_folder = args.unlabeled
        getLogger('df3d').debug("Unlabeled folder: {}".format(unlabeled_folder))
        max_img_id = get_max_img_id(unlabeled_folder)
        # args.num_images_max may be unset; cap the image range only when given
        if getattr(args, "num_images_max", None) is not None:
            max_img_id = min(max_img_id, args.num_images_max - 1)
        getLogger('df3d').debug('Going to process {} images'.format(max_img_id+1))
        unlabeled_loader = DataLoader(
            deepfly.pose2d.datasets.Drosophila(
                data_folder=args.data_folder,
                train=False,
                sigma=args.sigma,
                session_id_train_list=None,
                folder_train_list=None,
                img_res=args.img_res,
                hm_res=args.hm_res,
                augmentation=False,
                evaluation=True,
                unlabeled=unlabeled_folder,
                num_classes=args.num_classes,
                max_img_id=max_img_id,
            ),
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            drop_last=False,
        )

        valid_loss, valid_acc, val_pred, val_score_maps, mse, jump_acc = validate(
            unlabeled_loader, 0, model, criterion, args, save_path=unlabeled_folder
        )
        unlabeled_folder_replace = unlabeled_folder.replace("/", "-")
        getLogger('df3d').debug(f"val_score_maps have shape: {val_score_maps.shape}")

        getLogger('df3d').debug("Saving Results, flipping heatmaps")
        cid_to_reverse = config["flip_cameras"]  # camera id to reverse predictions and heatmaps
        cidread2cid, cid2cidread = read_camera_order(os.path.join(unlabeled_folder, 'df3d'))
        cid_read_to_reverse = [cid2cidread[cid] for cid in cid_to_reverse]
        getLogger('df3d').debug(
            "Flipping heatmaps for images with cam_id: {}".format(
                cid_read_to_reverse
            )
        )
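        # predictions hold normalized x in [0, 1], so the mirror is x -> 1 - x;
        # the heatmaps are mirrored below with cv2.flip(..., 1), a flip around
        # the vertical axis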
        val_pred[cid_read_to_reverse, :, :, 0] = (
            1 - val_pred[cid_read_to_reverse, :, :, 0]
        )
        for cam_id in cid_read_to_reverse:
            for img_id in range(val_score_maps.shape[1]):
                for j_id in range(val_score_maps.shape[2]):
                    val_score_maps[cam_id, img_id, j_id, :, :] = cv2.flip(
                        val_score_maps[cam_id, img_id, j_id, :, :], 1
                    )

        save_dict(
            val_pred,
            os.path.join(
                args.data_folder,
                unlabeled_folder,
                args.output_folder,
                "preds_{}.pkl".format(unlabeled_folder_replace),
            ),
        )

        getLogger('df3d').debug("Finished saving results")
    else:
        train_loader, val_loader = create_dataloader()
        for epoch in range(args.start_epoch, args.epochs):
            # ReduceLROnPlateau adjusts the optimizer's lr in place, so read
            # the current value rather than the stale args.lr
            lr = optimizer.param_groups[0]["lr"]
            getLogger('df3d').debug("\nEpoch: %d | LR: %.8f" % (epoch + 1, lr))

            # train for one epoch
            train_loss, train_acc, train_predictions, train_mse, train_mse_jump = train(
                train_loader, epoch, model, optimizer, criterion, args
            )
            # # evaluate on validation set
            valid_loss, valid_acc, val_pred, val_score_maps, mse, jump_acc = validate(
                val_loader, epoch, model, criterion, args, save_path=args.unlabeled
            )
            scheduler.step(valid_loss)
            # append logger file
            logger.append(
                [
                    epoch + 1,
                    lr,
                    train_loss,
                    valid_loss,
                    train_acc,
                    valid_acc,
                    mse,
                    jump_acc,
                ]
            )

            # remember best acc and save checkpoint
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                    "multiview": args.multiview,
                    "image_shape": args.img_res,
                    "heatmap_shape": args.hm_res,
                },
                val_pred,
                is_best,
                checkpoint=args.checkpoint,
            )

            fig = plt.figure()
            logger.plot(["Train Acc", "Val Acc"])
            savefig(os.path.join(args.checkpoint, "log.eps"))
            plt.close(fig)
        logger.close()
    return val_score_maps, val_pred
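The resume branch above uses a common PyTorch pattern: drop incompatible keys from a pretrained state_dict, merge the remainder into the model's own state, and load with strict=False. A self-contained sketch of the idea follows; the model and shapes are hypothetical, and where the branch above drops mismatched keys by name, this sketch filters by shape instead:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))  # stand-in network
pretrained = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
pretrained["1.weight"] = torch.zeros(5, 8)  # head trained for 5 classes, not 2
pretrained["1.bias"] = torch.zeros(5)

# keep only keys whose shapes still match, mirroring the mpii branch above
state = model.state_dict()
compatible = {
    k: v for k, v in pretrained.items()
    if k in state and v.shape == state[k].shape
}
state.update(compatible)
model.load_state_dict(state, strict=False)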