Example #1
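# NOTE (assumption): this scraped snippet omits its module-level setup; it
# relies on something like `import configparser, logging, os`, `import chainer`,
# `logger = logging.getLogger(__name__)`, and the project helpers
# select_dataset, select_model, and evaluate_ppn.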
def main(args):
    logging.basicConfig(level=logging.INFO)

    config = configparser.ConfigParser()
    config_path = os.path.join(args.trained, "result", "config.ini")
    if not os.path.exists(config_path):
        raise Exception("config_path {} not found".format(config_path))
    logger.info("read {}".format(config_path))
    config.read(config_path, 'UTF-8')

    logger.info("setup devices")
    chainer.global_config.autotune = True
    chainer.config.cudnn_fast_batch_normalization = True

    logger.info("> get dataset {}".format(args.mode))
    mode_dict = {
        "train": "train_set",
        "val": "val_set",
        "test": "test_set",
    }
    return_type = mode_dict[args.mode]

    dataset, hand_param = select_dataset(config, [return_type, "hand_param"])

    logger.info("> hand_param = {}".format(hand_param))
    model = select_model(config, hand_param)

    logger.info("> size of dataset is {}".format(len(dataset)))
    model_path = os.path.expanduser(
        os.path.join(args.trained, "result", "bestmodel.npz"))

    logger.info("> restore model")
    logger.info("> model.device = {}".format(model.device))
    chainer.serializers.load_npz(model_path, model)
    evaluate_ppn(args.trained, model, dataset, hand_param)
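
Example #2

# NOTE (assumption): as in Example #1, module-level imports (configparser,
# logging, os, chainer, TransformDataset) and the project helpers
# select_dataset, select_model, evaluate_ppn, and the predict_* functions
# are assumed.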
def main(args):
    logging.basicConfig(level=logging.INFO)

    config = configparser.ConfigParser()
    config_path = os.path.join(args.trained, "pose", "config.ini")
    if not os.path.exists(config_path):
        raise Exception("config_path {} not found".format(config_path))
    logger.info("read {}".format(config_path))
    config.read(config_path, 'UTF-8')

    logger.info("setup devices")
    chainer.global_config.autotune = True
    chainer.config.cudnn_fast_batch_normalization = True

    logger.info("> get dataset {}".format(args.mode))
    mode_dict = {
        "train": "train_set",
        "val": "val_set",
        "test": "test_set",
    }
    return_type = mode_dict[args.mode]

    dataset, hand_param = select_dataset(config, [return_type, "hand_param"])

    logger.info("> hand_param = {}".format(hand_param))
    model = select_model(config, hand_param)
    # note: transformed_dataset is built here but never used below;
    # the predict_*/evaluate_ppn calls operate on the raw dataset
    transformed_dataset = TransformDataset(dataset, model.encode)

    logger.info("> size of dataset is {}".format(len(dataset)))
    model_path = os.path.expanduser(
        os.path.join(args.trained, "pose", "bestmodel.npz"))

    logger.info("> restore model")
    logger.info("> model.device = {}".format(model.device))
    chainer.serializers.load_npz(model_path, model)

    if config["model"]["name"] in ["ppn", "ppn_edge"]:
        if args.evaluate:
            evaluate_ppn(model, dataset, hand_param)
        else:
            predict_ppn(model, dataset, hand_param)
    elif config["model"]["name"] in ["rhd", "hm", "orinet"]:
        predict_heatmap(model, dataset, hand_param)
    elif config["model"]["name"] == "ganerated":
        predict_ganerated(model, dataset, hand_param)
    else:
        predict_sample(model, dataset, hand_param)
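
Example #3

# NOTE: this snippet appears to be a jupytext-converted notebook;
# "# #" lines are markdown headers and "# +" lines mark cell boundaries.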
# # visualize transformed dataset

# +
from collections import defaultdict

import numpy as np  # used below via np.max
from chainer.datasets import TransformDataset
from pose.models.selector import select_model
from pose.hand_dataset import common_dataset

config = defaultdict(dict)
config["model"]["name"] = "ganerated"

hand_param = {}  # hand_param was never initialized in the original cell
hand_param["inH"] = 224
hand_param["inW"] = 224
hand_param["inC"] = 3
hand_param["n_joints"] = common_dataset.NUM_KEYPOINTS
hand_param["edges"] = common_dataset.EDGES
model = select_model(config, hand_param)
transform_dataset = TransformDataset(dataset, model.encode)  # `dataset` comes from an earlier cell

# +
print(current_idx)  # current_idx is assumed to be defined in an earlier cell

rgb, hm, intermediate3d, rgb_joint = transform_dataset.get_example(current_idx)
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(121)
ax.imshow(np.max(hm, axis=0))
ax2 = fig.add_subplot(122, projection="3d")
ax2.scatter(*rgb_joint[:, ::-1].transpose())
plt.show()  # needed when running as a plain script rather than in a notebook
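
Example #4

# NOTE (assumption): this training script additionally assumes module-level
# imports such as shutil, chainer.training as training, and
# chainer.training.extensions as extensions, plus the project helpers
# parse_args, setup_devices, set_random_seed, save_files, select_dataset,
# and select_model.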
def main():
    args = parse_args()
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)
    config = configparser.ConfigParser()

    logger.info("read {}".format(args.config_path))
    config.read(args.config_path, "UTF-8")
    logger.info("setup devices")
    if chainer.backends.cuda.available:
        devices = setup_devices(config["training_param"]["gpus"])
    else:
        # cpu run
        devices = {"main": -1}
    seed = config.getint("training_param", "seed")
    logger.info("set random seed {}".format(seed))
    set_random_seed(devices, seed)

    result = os.path.expanduser(config["result"]["dir"])
    destination = os.path.join(result, "pose")
    logger.info("> copy code to {}".format(os.path.join(result, "src")))
    save_files(result)
    logger.info("> copy config file to {}".format(destination))
    if not os.path.exists(destination):
        os.makedirs(destination)
    shutil.copy(args.config_path, os.path.join(destination, "config.ini"))

    logger.info("{} chainer debug".format("enable" if args.debug else "disable"))
    chainer.set_debug(args.debug)
    chainer.global_config.autotune = True
    chainer.cuda.set_max_workspace_size(11388608)  # ~11 MB cuDNN workspace
    chainer.config.cudnn_fast_batch_normalization = True

    logger.info("> get dataset")
    train_set, val_set, hand_param = select_dataset(config, return_data=["train_set", "val_set", "hand_param"])
    model = select_model(config, hand_param)

    logger.info("> transform dataset")
    train_set = TransformDataset(train_set, model.encode)
    val_set = TransformDataset(val_set, model.encode)
    logger.info("> size of train_set is {}".format(len(train_set)))
    logger.info("> size of val_set is {}".format(len(val_set)))
    logger.info("> create iterators")
    batch_size = config.getint("training_param", "batch_size")
    n_processes = config.getint("training_param", "n_processes")

    train_iter = chainer.iterators.MultiprocessIterator(
        train_set, batch_size,
        n_processes=n_processes
    )
    test_iter = chainer.iterators.MultiprocessIterator(
        val_set, batch_size,
        repeat=False, shuffle=False,
        n_processes=n_processes,
    )

    logger.info("> setup optimizer")
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    logger.info("> setup parallel updater devices={}".format(devices))
    updater = training.updaters.ParallelUpdater(train_iter, optimizer, devices=devices)

    logger.info("> setup trainer")
    trainer = training.Trainer(
        updater,
        (config.getint("training_param", "train_iter"), "iteration"),
        destination,
    )

    logger.info("> setup extensions")
    trainer.extend(
        extensions.LinearShift("lr",
                               value_range=(config.getfloat("training_param", "learning_rate"), 0),
                               time_range=(0, config.getint("training_param", "train_iter"))
                               ),
        trigger=(1, "iteration")
    )

    trainer.extend(extensions.Evaluator(test_iter, model, device=devices["main"]))
    if extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport([
            "main/loss", "validation/main/loss",
        ], "epoch", file_name="loss.png"))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(extensions.PrintReport([
        "epoch", "elapsed_time", "lr",
        "main/loss", "validation/main/loss",
        "main/loss_resp", "validation/main/loss_resp",
        "main/loss_iou", "validation/main/loss_iou",
        "main/loss_coor", "validation/main/loss_coor",
        "main/loss_size", "validation/main/loss_size",
        "main/loss_limb", "validation/main/loss_limb",
        "main/loss_vect_cos", "validation/main/loss_vect_cos",
        "main/loss_vect_norm", "validation/main/loss_vect_norm",
        "main/loss_vect_square", "validation/main/loss_vect_square",
    ]))
    trainer.extend(extensions.ProgressBar())

    trainer.extend(extensions.snapshot(filename="best_snapshot"),
                   trigger=training.triggers.MinValueTrigger("validation/main/loss"))
    trainer.extend(extensions.snapshot_object(model, filename="bestmodel.npz"),
                   trigger=training.triggers.MinValueTrigger("validation/main/loss"))

    logger.info("> start training")
    trainer.run()
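
Example #4 calls parse_args(), which the page does not include. Below is a
minimal sketch of what it might look like, assuming only the two attributes
that main() actually reads (config_path and debug); the flag names are
hypothetical.

import argparse

def parse_args():
    # Hypothetical reconstruction: main() above only reads args.config_path
    # and args.debug, so this parser exposes just those two options.
    parser = argparse.ArgumentParser(description="train a hand pose estimation model")
    parser.add_argument("--config_path", type=str, default="config.ini",
                        help="path to the training config.ini")
    parser.add_argument("--debug", action="store_true",
                        help="enable chainer debug mode and DEBUG logging")
    return parser.parse_args()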