Esempio n. 1
0
def load_train_config(args):
    """Build the frozen training config from a config file, CLI opts and flags.

    Merges ``args.config_file`` (if given) and the ``KEY VALUE`` overrides in
    ``args.opts`` into the module-level ``cfg``. It then applies individual
    command-line flags, scales the learning rate by the GPU count, freezes the
    config and makes sure the output directory exists.

    Args:
        args: parsed command-line namespace. Reads ``config_file``, ``opts``,
            ``log_step``, ``save_step``, ``eval_step``, ``resume``,
            ``use_tensorboard``, ``gpus``, ``nodes`` and ``nr``. For the
            integer step/topology flags, ``-1`` means "keep the config value".

    Returns:
        The module-level ``cfg`` object, now frozen.
    """
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    # Command-line "KEY VALUE" pairs override values from the config file.
    cfg.merge_from_list(args.opts)

    # -1 is the "not given on the command line" sentinel for these flags.
    if args.log_step != -1:
        cfg.TRAIN.LOG_STEP = args.log_step
    if args.save_step != -1:
        cfg.TRAIN.SAVE_STEP = args.save_step
    if args.eval_step != -1:
        cfg.TRAIN.EVAL_STEP = args.eval_step

    if args.resume:
        cfg.TRAIN.RESUME = True
    if not args.use_tensorboard:
        cfg.TRAIN.USE_TENSORBOARD = False

    if args.gpus != -1:
        cfg.NUM_GPUS = args.gpus
    if args.nodes != -1:
        cfg.NUM_NODES = args.nodes
    if args.nr != -1:
        cfg.RANK_ID = args.nr

    num_gpus = cfg.NUM_GPUS
    if num_gpus > 1:
        # Scale the base LR linearly with the GPU count.
        # NOTE(review): only NUM_GPUS is used here, not NUM_GPUS * NUM_NODES;
        # confirm this is intended for multi-node runs.
        cfg.OPTIMIZER.LR *= num_gpus

    cfg.freeze()

    # exist_ok avoids the check-then-create race (FileExistsError) when
    # several processes create the same output directory concurrently.
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    return cfg
Esempio n. 2
0
def test_slowonly():
    """Smoke-test the output shape of the SlowOnly ResNet3D-50 backbone.

    Builds the model from the UCF101 SlowOnly config and runs one random clip
    of shape (1, 3, 4, 224, 224) through it. Asserts the output shape is
    (1, 2048, 4, 7, 7).
    """
    # Work on a private copy of the shared config. Freezing the global ``cfg``
    # here would leak into other tests: a later merge_from_file on a frozen
    # yacs node raises. defrost() guards against a global already frozen by
    # other code.
    local_cfg = cfg.clone()
    local_cfg.defrost()
    local_cfg.merge_from_file('configs/slowonly_r3d50_ucf101_rgb_224x4_dense.yaml')
    local_cfg.freeze()

    model = resnet3d_50_slowonly(local_cfg)
    print(model)

    # Presumably (N, C, T, H, W) layout; TODO confirm against the model.
    data = torch.randn(1, 3, 4, 224, 224)
    # Only the output shape is checked, so no autograd graph is needed.
    with torch.no_grad():
        outputs = model(data)
    print(outputs.shape)

    assert outputs.shape == (1, 2048, 4, 7, 7)
Esempio n. 3
0
def load_config(args):
    """Build the frozen config from a config file and CLI overrides.

    Args:
        args: parsed command-line namespace. Reads ``config_file`` (optional
            path), ``opts`` (``KEY VALUE`` override list) and ``gpus``.

    Returns:
        The module-level ``cfg`` object, now frozen. ``cfg.OUTPUT.DIR`` is
        created if missing.
    """
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    # Command-line "KEY VALUE" pairs override values from the config file.
    cfg.merge_from_list(args.opts)

    if args.gpus > 1:
        # Scale optimizer hyper-parameters with the GPU count.
        # NOTE(review): weight decay is scaled together with the LR; confirm
        # this is intended, since the effective decay then grows with gpus.
        cfg.OPTIMIZER.LR *= args.gpus
        cfg.OPTIMIZER.WEIGHT_DECAY *= args.gpus
        cfg.LR_SCHEDULER.COSINE_ANNEALING_LR.MINIMAL_LR *= args.gpus

    cfg.freeze()

    # exist_ok avoids the check-then-create race (FileExistsError) when
    # several processes create the same output directory concurrently.
    os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)

    return cfg
Esempio n. 4
0
def test_slowfast():
    """Smoke-test the output shapes of the SlowFast ResNet3D-50 backbone.

    Builds the model from the UCF101 SlowFast config and runs one random clip
    of shape (1, 3, 32, 224, 224) through it. Asserts the model returns two
    feature maps of shapes (1, 2048, 4, 7, 7) and (1, 256, 32, 7, 7).
    """
    # Work on a private copy of the shared config. Freezing the global ``cfg``
    # here would leak into other tests: a later merge_from_file on a frozen
    # yacs node raises. defrost() guards against a global already frozen by
    # other code.
    local_cfg = cfg.clone()
    local_cfg.defrost()
    local_cfg.merge_from_file('configs/slowfast_r3d50_ucf101_rgb_224x32_dense.yaml')
    local_cfg.freeze()

    model = resnet3d_50_slowfast(local_cfg)
    print(model)

    # Presumably (N, C, T, H, W) layout; TODO confirm against the model.
    data = torch.randn(1, 3, 32, 224, 224)
    # Only output shapes are checked, so no autograd graph is needed.
    with torch.no_grad():
        outputs = model(data)
    print(len(outputs))
    print(outputs[0].shape)
    print(outputs[1].shape)

    # The test expects two outputs: index 0 with 2048 channels / 4 frames,
    # index 1 with 256 channels / 32 frames.
    assert len(outputs) == 2
    assert outputs[0].shape == (1, 2048, 4, 7, 7)
    assert outputs[1].shape == (1, 256, 32, 7, 7)
Esempio n. 5
0
def load_test_config(args):
    """Build the frozen evaluation config from CLI arguments.

    Args:
        args: parsed command-line namespace. Reads ``config_file``,
            ``pretrained`` (checkpoint path), ``output`` (output directory),
            and ``gpus``/``nodes``/``nr``, where ``-1`` means "keep the config
            value".

    Returns:
        The module-level ``cfg`` object, now frozen. ``cfg.OUTPUT_DIR`` is
        created if missing.

    Raises:
        ValueError: if the config file or the pretrained model file does not
            exist.
    """
    # Message: "config file and pretrained model path are required".
    if not os.path.isfile(args.config_file) or not os.path.isfile(args.pretrained):
        raise ValueError('需要输入配置文件和预训练模型路径')

    cfg.merge_from_file(args.config_file)
    cfg.MODEL.PRETRAINED = args.pretrained
    cfg.OUTPUT_DIR = args.output

    # -1 is the "not given on the command line" sentinel for these flags.
    # NOTE(review): these keys are NODES/RANK, while a training loader may use
    # NUM_NODES/RANK_ID; confirm which keys the config schema defines.
    if args.gpus != -1:
        cfg.NUM_GPUS = args.gpus
    if args.nodes != -1:
        cfg.NODES = args.nodes
    if args.nr != -1:
        cfg.RANK = args.nr
    cfg.freeze()

    # exist_ok avoids the check-then-create race (FileExistsError) when
    # several processes create the same output directory concurrently.
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    return cfg
Esempio n. 6
0
def main():
    """Command-line entry point for evaluating a trained model.

    Parses a config file path, a pretrained checkpoint path, an output
    directory and optional ``KEY VALUE`` config overrides, then builds the
    frozen module-level ``cfg``. It sets up a logger writing into
    ``cfg.OUTPUT.DIR``, logs the environment and configuration, and calls
    ``test(cfg)``.

    Raises:
        ValueError: if the config file or the pretrained model file does not
            exist.
    """
    parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    parser.add_argument("config_file", default="", metavar="CONFIG_FILE",
                        help="path to config file", type=str)
    parser.add_argument('pretrained', default="", metavar='PRETRAINED_FILE',
                        help="path to pretrained model", type=str)
    parser.add_argument('--output', default="./outputs/test", type=str)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # Message: "config file and pretrained model path are required".
    if not os.path.isfile(args.config_file) or not os.path.isfile(args.pretrained):
        raise ValueError('需要输入配置文件和预训练模型路径')

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # CLI paths take precedence over both the file and the overrides.
    cfg.MODEL.PRETRAINED = args.pretrained
    cfg.OUTPUT.DIR = args.output
    cfg.freeze()

    # exist_ok avoids the check-then-create race (FileExistsError) when
    # several processes create the same output directory concurrently.
    os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)
    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Loaded configuration file {}".format(args.config_file))
    if args.config_file:
        # Log the raw config file contents for reproducibility.
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    test(cfg)
Esempio n. 7
0
File: train.py Progetto: ZJCV/TRN
def main():
    """Command-line entry point for (distributed) training.

    Parses training options and ``KEY VALUE`` config overrides, then builds
    the frozen module-level ``cfg`` and sets up a logger in ``cfg.OUTPUT.DIR``.
    It spawns ``args.gpus`` worker processes that run
    ``train(process_index, args, cfg)``.
    """
    parser = argparse.ArgumentParser(description='TSN Training With PyTorch')
    parser.add_argument("--config_file",
                        default="",
                        metavar="FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Print logs every log_step')
    parser.add_argument('--save_step',
                        default=2500,
                        type=int,
                        help='Save checkpoint every save_step')
    parser.add_argument('--stop_save', default=False, action='store_true')
    parser.add_argument(
        '--eval_step',
        default=2500,
        type=int,
        help='Evaluate dataset every eval_step, disabled when eval_step < 0')
    parser.add_argument('--stop_eval', default=False, action='store_true')
    parser.add_argument('--resume',
                        default=False,
                        action='store_true',
                        help='Resume training')
    parser.add_argument('--use_tensorboard', default=1, type=int)

    # Distributed topology.
    parser.add_argument('-n',
                        '--nodes',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of machines (default: 1)')
    parser.add_argument('-g',
                        '--gpus',
                        default=1,
                        type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr',
                        '--nr',
                        default=0,
                        type=int,
                        help='ranking within the nodes')

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # exist_ok avoids the check-then-create race (FileExistsError) if the
    # directory is created concurrently by another process.
    os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)
    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Loaded configuration file {}".format(args.config_file))
    if args.config_file:
        # Log the raw config file contents for reproducibility.
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    # Total number of processes across all nodes.
    args.world_size = args.gpus * args.nodes
    # Rendezvous address for the process group. setdefault keeps the old
    # single-machine defaults, but respects MASTER_ADDR/MASTER_PORT already
    # exported by the launcher (required when nodes > 1).
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '14028')
    # One worker process per local GPU; spawn passes the process index first.
    mp.spawn(train, nprocs=args.gpus, args=(args, cfg))