def main():
    """Training entry point: build the config, log the setup, launch the job.

    Arguments come from ``parse_train_args`` and the configuration from
    ``load_config``. A logger named "TSN" is created with
    ``cfg.OUTPUT.DIR`` as its save directory and used to record the parsed
    arguments, the environment report, the raw config file contents (when
    ``args.config_file`` is set) and the resolved configuration. Finally
    ``train`` is handed to ``launch_job`` together with ``args`` and ``cfg``.
    """
    args = parse_train_args()
    cfg = load_config(args)
    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)

    # Record the invocation and the runtime environment.
    logger.info(args)
    env_report = collect_env_info()
    logger.info("Environment info:\n" + env_report)

    # Echo the raw config file (if any) ahead of the resolved configuration.
    loaded_msg = "Loaded configuration file {}".format(args.config_file)
    logger.info(loaded_msg)
    if args.config_file:
        with open(args.config_file, "r") as config_fp:
            raw_config = config_fp.read()
            logger.info("\n" + raw_config)
    resolved_msg = "Running with config:\n{}".format(cfg)
    logger.info(resolved_msg)

    launch_job(args, cfg, train)
def main():
    """Testing entry point: build the test config, log the setup, launch the job.

    Arguments come from ``parse_test_args`` and the configuration from
    ``load_test_config``. The logger is obtained from
    ``logging.setup_logging`` with ``cfg.OUTPUT_DIR`` as output directory and
    records the parsed arguments, the environment report, the raw config file
    contents (when ``args.config_file`` is set) and the resolved
    configuration. ``test`` is then passed to ``launch_job`` along with the
    config and ``args.init_method``.
    """
    args = parse_test_args()
    cfg = load_test_config(args)
    logger = logging.setup_logging(__name__, output_dir=cfg.OUTPUT_DIR)

    # Invocation details first, then the environment report.
    logger.info(args)
    environment = "Environment info:\n" + collect_env_info()
    logger.info(environment)

    config_banner = "Loaded configuration file {}".format(args.config_file)
    logger.info(config_banner)
    if args.config_file:
        # Dump the config file verbatim, prefixed with a newline.
        with open(args.config_file, "r") as stream:
            file_text = stream.read()
            logger.info("\n" + file_text)

    summary = "Running with config:\n{}".format(cfg)
    logger.info(summary)

    launch_job(cfg=cfg, init_method=args.init_method, func=test)
def main():
    """Command-line entry point for testing with RGB and RGBDiff inputs.

    Positional arguments (all required):
        rgb_config_file, rgb_pretrained: config file and pretrained model
            for the RGB modality.
        rgbdiff_config_file, rgbdiff_pretrained: config file and pretrained
            model for the RGBDiff modality.
    Optional:
        --output: directory for outputs (default ``./outputs/test``).

    Raises:
        ValueError: if ANY of the four paths is not an existing file. The
            RGB pair is checked first, then the RGBDiff pair.

    After validation the output directory is created if needed, a logger is
    obtained from ``logging.setup_logging``, the arguments and environment
    report are logged, and ``test(args)`` is called.
    """
    parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    parser.add_argument("rgb_config_file", default="", metavar="RGB_CONFIG_FILE",
                        help="path to config file", type=str)
    parser.add_argument('rgb_pretrained', default="", metavar='RGB_PRETRAINED_FILE',
                        help="path to pretrained model", type=str)
    parser.add_argument("rgbdiff_config_file", default="", metavar="RGBDIFF_CONFIG_FILE",
                        help="path to config file", type=str)
    parser.add_argument('rgbdiff_pretrained', default="", metavar='RGBDIFF_PRETRAINED_FILE',
                        help="path to pretrained model", type=str)
    parser.add_argument('--output', default="./outputs/test", type=str)
    args = parser.parse_args()

    # Both files of a modality are required. The previous RGB check used
    # `and`, which only rejected the input when BOTH files were missing and
    # let a single missing file slip through; it now mirrors the RGBDiff check.
    # Message: "RGB modality config file and pretrained model path are required".
    if not os.path.isfile(args.rgb_config_file) or not os.path.isfile(args.rgb_pretrained):
        raise ValueError('需要输入RGB模态配置文件和预训练模型路径')
    # Message: "RGBDIFF modality config file and pretrained model path are required".
    if not os.path.isfile(args.rgbdiff_config_file) or not os.path.isfile(args.rgbdiff_pretrained):
        raise ValueError('需要输入RGBDIFF模态配置文件和预训练模型路径')

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output, exist_ok=True)

    logger = logging.setup_logging(output_dir=args.output)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())

    test(args)
def main():
    """Command-line entry point for testing a single model.

    Positional arguments:
        config_file: path to the config file.
        pretrained: path to the pretrained model.
        opts: remaining command-line tokens, merged into the config.
    Optional:
        --output: output directory (default ``./outputs/test``).

    Raises:
        ValueError: if ``config_file`` or ``pretrained`` is not an existing
            file.

    The global ``cfg`` is updated from the config file and ``opts``, given the
    pretrained path and output directory, then frozen. The output directory
    is created when absent, the setup is logged through the "TSN" logger,
    and ``test(cfg)`` is invoked.
    """
    arg_parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    arg_parser.add_argument(
        "config_file",
        default="",
        metavar="CONFIG_FILE",
        help="path to config file",
        type=str,
    )
    arg_parser.add_argument(
        'pretrained',
        default="",
        metavar='PRETRAINED_FILE',
        help="path to pretrained model",
        type=str,
    )
    arg_parser.add_argument('--output', default="./outputs/test", type=str)
    arg_parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = arg_parser.parse_args()

    # Both inputs must be existing files.
    # Message: "a config file and a pretrained model path are required".
    if not (os.path.isfile(args.config_file) and os.path.isfile(args.pretrained)):
        raise ValueError('需要输入配置文件和预训练模型路径')

    # Layer the config: file first, then command-line overrides, then the
    # values taken directly from the parsed arguments.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.MODEL.PRETRAINED = args.pretrained
    cfg.OUTPUT.DIR = args.output
    cfg.freeze()

    if not os.path.exists(cfg.OUTPUT.DIR):
        os.makedirs(cfg.OUTPUT.DIR)

    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)
    env_report = collect_env_info()
    logger.info("Environment info:\n" + env_report)

    loaded_msg = "Loaded configuration file {}".format(args.config_file)
    logger.info(loaded_msg)
    if args.config_file:
        with open(args.config_file, "r") as config_fp:
            raw_text = config_fp.read()
            logger.info("\n" + raw_text)
    running_msg = "Running with config:\n{}".format(cfg)
    logger.info(running_msg)

    test(cfg)
def main():
    """Command-line entry point for (multi-process) training.

    Options:
        --config_file: optional config file merged into the global ``cfg``.
        --log_step / --save_step / --eval_step: step intervals.
        --stop_save / --stop_eval / --resume: boolean switches.
        --use_tensorboard: integer flag (default 1).
        -n/--nodes, -g/--gpus, -nr/--nr: number of machines, GPUs per node
            and the rank of this node.
        opts: remaining command-line tokens, merged into ``cfg``.

    The config is merged and frozen, the output directory is created, and
    the setup is logged through the "TSN" logger. ``args.world_size`` is set
    to ``gpus * nodes`` and ``train`` is spawned ``args.gpus`` times via
    ``mp.spawn`` with ``(args, cfg)`` as extra arguments.

    ``MASTER_ADDR`` / ``MASTER_PORT`` default to ``localhost`` / ``14028``
    but values already present in the environment are respected, so a
    launcher can point the processes at another address.
    NOTE(review): how ``train`` consumes these variables is not visible
    here - presumably through env-based process-group init; confirm.
    """
    parser = argparse.ArgumentParser(description='TSN Training With PyTorch')
    parser.add_argument("--config_file", default="", metavar="FILE",
                        help="path to config file", type=str)
    parser.add_argument('--log_step', default=10, type=int,
                        help='Print logs every log_step')
    parser.add_argument('--save_step', default=2500, type=int,
                        help='Save checkpoint every save_step')
    parser.add_argument('--stop_save', default=False, action='store_true')
    parser.add_argument('--eval_step', default=2500, type=int,
                        help='Evaluate dataset every eval_step, disabled when eval_step < 0')
    parser.add_argument('--stop_eval', default=False, action='store_true')
    parser.add_argument('--resume', default=False, action='store_true',
                        help='Resume training')
    parser.add_argument('--use_tensorboard', default=1, type=int)
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of machines (default: 1)')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # File values first, then command-line overrides.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)

    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())
    # Only claim a config file was loaded when one was actually supplied.
    if args.config_file:
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    args.world_size = args.gpus * args.nodes
    # Keep the historical defaults, but do not clobber an address/port that
    # the caller already exported.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '14028')
    mp.spawn(train, nprocs=args.gpus, args=(args, cfg))