Пример #1
0
def main():
    """Parse the training arguments, log the run setup, and launch training.

    The logger writes to the output directory taken from the loaded config.
    The arguments, environment info, raw config file contents (when a
    config file path was given), and the resolved config are all logged
    before control is handed to ``launch_job`` with the ``train`` callable.
    """
    args = parse_train_args()
    cfg = load_config(args)

    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())

    config_path = args.config_file
    logger.info("Loaded configuration file {}".format(config_path))
    if config_path:
        # Echo the raw file so the log records exactly what was loaded.
        with open(config_path, "r") as config_handle:
            logger.info("\n" + config_handle.read())
    logger.info("Running with config:\n{}".format(cfg))

    launch_job(args, cfg, train)
Пример #2
0
def main():
    """Parse the test arguments, log the run setup, and launch evaluation.

    The logger is created via ``logging.setup_logging`` with the config's
    ``OUTPUT_DIR``. The arguments, environment info, raw config file
    contents (when a config file path was given), and the resolved config
    are logged before ``launch_job`` is called with the ``test`` callable.
    """
    args = parse_test_args()
    cfg = load_test_config(args)

    logger = logging.setup_logging(__name__, output_dir=cfg.OUTPUT_DIR)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())

    config_path = args.config_file
    logger.info("Loaded configuration file {}".format(config_path))
    if config_path:
        # Echo the raw file so the log records exactly what was loaded.
        with open(config_path, "r") as config_handle:
            logger.info("\n" + config_handle.read())
    logger.info("Running with config:\n{}".format(cfg))

    launch_job(cfg=cfg, init_method=args.init_method, func=test)
Пример #3
0
def main():
    """Command-line entry point for the two-stream (RGB + RGBDiff) test.

    Positional arguments, in order: RGB config file, RGB pretrained model,
    RGBDiff config file, RGBDiff pretrained model. ``--output`` selects the
    directory used for logging (default ``./outputs/test``).

    Raises:
        ValueError: if either file of a modality (config or pretrained
            model) is not an existing regular file.
    """
    parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    parser.add_argument("rgb_config_file",
                        default="",
                        metavar="RGB_CONFIG_FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('rgb_pretrained',
                        default="",
                        metavar='RGB_PRETRAINED_FILE',
                        help="path to pretrained model",
                        type=str)
    parser.add_argument("rgbdiff_config_file",
                        default="",
                        metavar="RGBDIFF_CONFIG_FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('rgbdiff_pretrained',
                        default="",
                        metavar='RGBDIFF_PRETRAINED_FILE',
                        help="path to pretrained model",
                        type=str)
    parser.add_argument('--output', default="./outputs/test", type=str)
    args = parser.parse_args()

    # Both files of a modality are required. The original RGB check used
    # `and`, which let a single missing RGB file slip through; it now
    # matches the RGBDIFF check below.
    if not os.path.isfile(args.rgb_config_file) or not os.path.isfile(
            args.rgb_pretrained):
        raise ValueError('需要输入RGB模态配置文件和预训练模型路径')
    if not os.path.isfile(args.rgbdiff_config_file) or not os.path.isfile(
            args.rgbdiff_pretrained):
        raise ValueError('需要输入RGBDIFF模态配置文件和预训练模型路径')

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output, exist_ok=True)
    logger = logging.setup_logging(output_dir=args.output)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())

    test(args)
Пример #4
0
def main():
    """Command-line entry point for the single-stream test.

    Takes a config file and a pretrained model as positional arguments,
    an optional ``--output`` directory, and trailing ``opts`` that are
    merged into the config. The pretrained path and output directory are
    written into the config before it is frozen and passed to ``test``.

    Raises:
        ValueError: if the config file or the pretrained model is not an
            existing regular file.
    """
    parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    parser.add_argument(
        "config_file",
        default="",
        metavar="CONFIG_FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        'pretrained',
        default="",
        metavar='PRETRAINED_FILE',
        help="path to pretrained model",
        type=str,
    )
    parser.add_argument('--output', default="./outputs/test", type=str)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # Both inputs must exist before the config is touched.
    required_files = (args.config_file, args.pretrained)
    if not all(os.path.isfile(path) for path in required_files):
        raise ValueError('需要输入配置文件和预训练模型路径')

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.MODEL.PRETRAINED = args.pretrained
    cfg.OUTPUT.DIR = args.output
    cfg.freeze()

    output_dir = cfg.OUTPUT.DIR
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logger = setup_logger("TSN", save_dir=output_dir)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())

    config_path = args.config_file
    logger.info("Loaded configuration file {}".format(config_path))
    if config_path:
        # Echo the raw file so the log records exactly what was loaded.
        with open(config_path, "r") as config_handle:
            logger.info("\n" + config_handle.read())
    logger.info("Running with config:\n{}".format(cfg))

    test(cfg)
Пример #5
0
Файл: train.py Проект: ZJCV/TRN
def main():
    """Command-line entry point for multi-process training.

    Parses the training options, merges the optional config file and the
    trailing ``opts`` overrides into ``cfg``, freezes it, sets up logging
    in ``cfg.OUTPUT.DIR``, and spawns ``args.gpus`` worker processes that
    run ``train`` with ``(args, cfg)`` as extra arguments.

    ``args.world_size`` is set to ``gpus * nodes`` before spawning.

    ``MASTER_ADDR`` and ``MASTER_PORT`` default to ``localhost`` and
    ``14028`` but are only set when not already present in the
    environment, so a multi-node run (``--nodes`` > 1) can point every
    node at the real master host.
    """
    parser = argparse.ArgumentParser(description='TSN Training With PyTorch')
    parser.add_argument("--config_file",
                        default="",
                        metavar="FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Print logs every log_step')
    parser.add_argument('--save_step',
                        default=2500,
                        type=int,
                        help='Save checkpoint every save_step')
    parser.add_argument('--stop_save', default=False, action='store_true')
    parser.add_argument(
        '--eval_step',
        default=2500,
        type=int,
        help='Evaluate dataset every eval_step, disabled when eval_step < 0')
    parser.add_argument('--stop_eval', default=False, action='store_true')
    parser.add_argument('--resume',
                        default=False,
                        action='store_true',
                        help='Resume training')
    parser.add_argument('--use_tensorboard', default=1, type=int)

    parser.add_argument('-n',
                        '--nodes',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of machines (default: 1)')
    parser.add_argument('-g',
                        '--gpus',
                        default=1,
                        type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr',
                        '--nr',
                        default=0,
                        type=int,
                        help='ranking within the nodes')

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # exist_ok avoids the check-then-create race when several launchers
    # start at once against the same output directory.
    os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)
    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Loaded configuration file {}".format(args.config_file))
    if args.config_file:
        # Echo the raw file so the log records exactly what was loaded.
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    args.world_size = args.gpus * args.nodes
    # Keep any rendezvous address/port already supplied by the caller;
    # unconditionally forcing localhost made multi-node runs impossible.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '14028')
    # NOTE(review): `mp` is presumably torch.multiprocessing, whose spawn()
    # passes the process index as the first argument to `train` — confirm.
    mp.spawn(train, nprocs=args.gpus, args=(args, cfg))