예제 #1
0
def test_train_iter2():
    """Record two batches with LOG_PERIOD=1 and check the logged iteration stats.

    After both updates, ``log_iter_stats(cur_epoch=1, cur_iter=1)`` is
    expected to report exactly the metric values of the second batch.
    """
    cfg = get_cfg()
    cfg.LOG_PERIOD = 1
    meter = ClevrerTrainMeter(5, cfg)

    first_batch = dict(top1_err=40.0, top5_err=20.0, mc_opt_err=55.0,
                       mc_q_err=78.0, loss_des=6.3, loss_mc=0.2)
    second_batch = dict(top1_err=41.0, top5_err=21.0, mc_opt_err=56.0,
                        mc_q_err=86.0, loss_des=6.7, loss_mc=1.2)
    for batch in (first_batch, second_batch):
        # Learning rate and mini-batch sizes are identical for both batches.
        meter.update_stats(lr=0.1, mb_size_des=10, mb_size_mc=10, **batch)

    stats = meter.log_iter_stats(cur_epoch=1, cur_iter=1)
    for key, expected in second_batch.items():
        assert stats[key] == expected
예제 #2
0
def load_config(args):
    """
    Given the arguments, load and initialize the configs.

    The defaults from ``get_cfg()`` are updated in this order:

    1. ``args.cfg_file`` is merged when it is not None.
    2. ``args.opts`` is merged when it is not None; these command-line
       overrides win over the file.
    3. ``num_shards``/``shard_id`` (only when both exist), ``rng_seed`` and
       ``output_dir`` are copied from ``args`` whenever the attribute exists.
       Existence is checked with ``hasattr``, so an attribute whose value is
       None still overwrites the config value.

    Finally ``cu.make_checkpoint_dir(cfg.OUTPUT_DIR)`` is called.

    Args:
        args (argument): parsed arguments. ``cfg_file`` and ``opts`` must be
            present; ``num_shards``, ``shard_id``, ``rng_seed`` and
            ``output_dir`` are optional.
    Returns:
        cfg: the config object created by ``get_cfg()``, with the overrides
            above applied.
    """
    # Setup cfg.
    cfg = get_cfg()
    # Load config from cfg.
    if args.cfg_file is not None:
        cfg.merge_from_file(args.cfg_file)
    # Load config from command line, overwrite config from opts.
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    # Inherit parameters from args.
    if hasattr(args, "num_shards") and hasattr(args, "shard_id"):
        cfg.NUM_SHARDS = args.num_shards
        cfg.SHARD_ID = args.shard_id
    if hasattr(args, "rng_seed"):
        cfg.RNG_SEED = args.rng_seed
    if hasattr(args, "output_dir"):
        cfg.OUTPUT_DIR = args.output_dir

    cu.make_checkpoint_dir(cfg.OUTPUT_DIR)
    return cfg
예제 #3
0
def build_teacher_model(cfg, gpu_id=None):
    """
    Builds the teacher model configured through ``cfg.KD``.

    The teacher gets its own config: the defaults from ``get_cfg()`` with the
    file at ``cfg.KD.CONFIG`` merged in and ``KD.ENABLE`` forced to True. The
    model class is looked up in ``MODEL_REGISTRY`` by the teacher config's
    ``MODEL.MODEL_NAME``. Device placement follows the *calling* config
    (``cfg.NUM_GPUS``), not the teacher config.

    Args:
        cfg (configs): configs of the calling run; must provide ``KD.CONFIG``
            and ``NUM_GPUS``. Details can be seen in
            slowfast/config/defaults.py.
        gpu_id (int, optional): CUDA device to place the model on. When None
            and GPUs are used, ``torch.cuda.current_device()`` is used.

    Returns:
        The teacher model:
        - moved to the selected GPU when ``cfg.NUM_GPUS`` is non-zero;
        - also wrapped in ``DistributedDataParallel`` when
          ``cfg.NUM_GPUS > 1``.
    """
    # Load teacher cfg
    teacher_cfg = get_cfg()
    teacher_cfg.merge_from_file(cfg.KD.CONFIG)
    teacher_cfg.KD.ENABLE = True

    # Construct the model
    model = MODEL_REGISTRY.get(teacher_cfg.MODEL.MODEL_NAME)(teacher_cfg)
    if cfg.NUM_GPUS:
        # Determine the GPU used by the current process
        cur_device = torch.cuda.current_device() if gpu_id is None else gpu_id
        # Transfer the model to the current GPU device
        model = model.cuda(device=cur_device)
        # Use multi-process data parallel model in the multi-gpu setting.
        # Nested inside the GPU branch so `cur_device` is always bound here
        # (NUM_GPUS > 1 implies NUM_GPUS is truthy, so behavior is unchanged).
        if cfg.NUM_GPUS > 1:
            # Make model replica operate on the current device
            model = torch.nn.parallel.DistributedDataParallel(
                module=model, device_ids=[cur_device], output_device=cur_device
            )
    return model
예제 #4
0
def get_init_params_cfg():
    """Return a config pre-filled with this run's initial parameters.

    Starts from ``get_cfg()`` and overrides, in order: TRAIN, DATA, SOLVER,
    MODEL, a few top-level runtime options, WORD_EMB and CLEVRERMAIN.
    """
    def _override(node, **values):
        # Assign every keyword as an attribute of `node`, in keyword order.
        for key, value in values.items():
            setattr(node, key, value)

    cfg = get_cfg()

    _override(
        cfg.TRAIN,
        ENABLE=True,
        ONLY_DES=True,
        DATASET="Clevrer_des",
        BATCH_SIZE=8,
        EVAL_PERIOD=1,
        CHECKPOINT_PERIOD=1,
        AUTO_RESUME=True,
        TRAIN_STATS_FILE="./train_stats_hyper.txt",
    )
    _override(
        cfg.DATA,
        RESIZE_H=224,
        RESIZE_W=224,
        NUM_FRAMES=25,
        SAMPLING_RATE=5,
        TRAIN_JITTER_SCALES=[256, 320],
        TRAIN_CROP_SIZE=224,
        TEST_CROP_SIZE=224,
        INPUT_CHANNEL_NUM=[3],
        PATH_TO_DATA_DIR="/datasets/clevrer",
        PATH_PREFIX="/datasets/clevrer",
        MAX_TRAIN_LEN=None,
        MAX_VAL_LEN=None,
    )
    _override(
        cfg.SOLVER,
        BASE_LR=0.001,
        LR_POLICY="cosine",
        COSINE_END_LR=0.00001,
        EPOCH_CYCLE=10.0,
        MAX_EPOCH=10,
        MOMENTUM=0.9,
        NESTEROV=True,
        WEIGHT_DECAY=0.00001,
        WARMUP_EPOCHS=0.0,
        WARMUP_START_LR=0.01,
        OPTIMIZING_METHOD="sgd",
    )
    _override(
        cfg.MODEL,
        ARCH="CNN_SEP_LSTM",
        MODEL_NAME="CNN_SEP_LSTM",
    )
    _override(
        cfg,
        NUM_GPUS=1,
        LOG_PERIOD=100,
        OUTPUT_DIR="./",
        RNG_SEED=42,
    )
    _override(
        cfg.WORD_EMB,
        USE_PRETRAINED_EMB=True,
        TRAINABLE=True,
        GLOVE_PATH='/datasets/word_embs/glove.6B/glove.6B.50d.txt',
        EMB_DIM=50,
    )
    _override(
        cfg.CLEVRERMAIN,
        LSTM_HID_DIM=256,
        T_DROPOUT=0.1,
    )

    return cfg
예제 #5
0
    def pre_proc_config(cfg: CN, dct: Dict = None):
        """
        Add any pre processing based on cfg.

        The same procedure is applied to two sub-models: the SlowFast model
        (``"mdl.sf_mdl_name"``) and the transformer decoder
        (``"mdl.tx_dec_mdl_name"``). For each one:

        - the sub-model name is resolved to a config file path;
        - that path is written back into ``cfg``;
        - the loaded sub-config is attached to ``cfg`` as ``cfg.sf_mdl`` or
          ``cfg.tx_dec`` respectively.

        Args:
            cfg: config to update; it is modified in place and returned.
            dct: optional overrides. When it contains a sub-model name key,
                that value takes precedence over the one stored in ``cfg``.

        Returns:
            The same ``cfg`` object with ``sf_mdl`` and ``tx_dec`` set.
        """
        def upd_sub_mdl(
            cfg: CN,
            sub_mdl_default_cfg: CN,
            sub_mdl_name_key: str,
            sub_mdl_file_key: str,
            sub_mdl_mapper: Dict,
            new_dct: Dict,
        ):
            # Resolve one sub-model's config file, record its path in `cfg`
            # under `sub_mdl_file_key`, and return the loaded config as a CN.
            if new_dct is not None and sub_mdl_name_key in new_dct:
                sub_mdl_name = new_dct[sub_mdl_name_key]
            else:
                sub_mdl_name = CfgProcessor.get_val_from_cfg(
                    cfg, sub_mdl_name_key)

            assert sub_mdl_name in sub_mdl_mapper, (
                f"Unknown sub-model {sub_mdl_name!r} for key "
                f"{sub_mdl_name_key!r}; known: {list(sub_mdl_mapper)}")
            sub_mdl_file = sub_mdl_mapper[sub_mdl_name]
            assert Path(sub_mdl_file).exists(), (
                f"Config file {sub_mdl_file!s} for sub-model "
                f"{sub_mdl_name!r} does not exist")
            CfgProcessor.update_one_full_key(cfg,
                                             {sub_mdl_file_key: sub_mdl_file},
                                             full_key=sub_mdl_file_key)

            # merge_from_file mutates `sub_mdl_default_cfg` in place.
            sub_mdl_default_cfg.merge_from_file(sub_mdl_file)
            # Dump to YAML and re-load it as plain Python data; the result is
            # a fresh object, so it can be handed to CN without a copy.
            sub_mdl_cfg = yaml.safe_load(sub_mdl_default_cfg.dump())
            return CN(sub_mdl_cfg)

        sf_mdl_cfg_default = get_cfg()
        cfg.sf_mdl = upd_sub_mdl(
            cfg,
            sf_mdl_cfg_default,
            "mdl.sf_mdl_name",
            "mdl.sf_mdl_cfg_file",
            sf_mdl_to_cfg_fpath_dct,
            dct,
        )
        tx_dec_default = get_default_tx_dec_cfg()
        cfg.tx_dec = upd_sub_mdl(
            cfg,
            tx_dec_default,
            "mdl.tx_dec_mdl_name",
            "mdl.tx_dec_cfg_file",
            tx_to_cfg_fpath_dct,
            dct,
        )
        return cfg
예제 #6
0
파일: parser.py 프로젝트: cthorey/SlowFast
def load_config_from_file(cfg_file=None):
    """
    Load and initialize the configs, optionally from a config file.

    The defaults from ``get_cfg()`` are used as the base. After loading,
    ``cu.make_checkpoint_dir(cfg.OUTPUT_DIR)`` is called.

    Args:
        cfg_file (str, optional): path of a config file to merge on top of
            the defaults. When None, the defaults are used unchanged.
    Returns:
        cfg: the config object created by ``get_cfg()``, with ``cfg_file``
            merged in when it was given.
    """
    # Setup cfg.
    cfg = get_cfg()
    # Load config from cfg.
    if cfg_file is not None:
        cfg.merge_from_file(cfg_file)
    # Create the checkpoint dir.
    cu.make_checkpoint_dir(cfg.OUTPUT_DIR)
    return cfg
예제 #7
0
def main():
    """Re-save a model checkpoint, optionally converted to half precision.

    Steps:

    1. Build the model from ``args.cfg_file``.
    2. Load ``args.checkpoint`` into it through ``cu.load_test_checkpoint``.
    3. Call ``model.half()`` when ``args.half`` is set.
    4. Write ``{"model_state": model.state_dict()}`` to ``args.save``.
    """
    args = parser_args()
    print(args)
    # Read every needed argument up front.
    cfg_path, ckpt_path, out_path, use_half = (
        args.cfg_file, args.checkpoint, args.save, args.half)

    cfg = get_cfg()
    cfg.merge_from_file(cfg_path)
    cfg.TEST.CHECKPOINT_FILE_PATH = ckpt_path

    print("simplifier model!\n")
    with torch.no_grad():
        net = build_model(cfg)
        net.eval()
        cu.load_test_checkpoint(cfg, net)
        if use_half:
            net.half()
        with open(out_path, 'wb') as out_file:
            torch.save({"model_state": net.state_dict()}, out_file)
예제 #8
0
def main():
    """Smoke-test the Kinetics training loader by printing batch contents.

    Setup:

    - builds a ``Kinetics`` dataset for the ``"train"`` split with
      ``TEST.NUM_ENSEMBLE_VIEWS = 2``;
    - sets an empty ``DATA.PATH_TO_DATA_DIR``, which presumably must be
      filled in before this is useful (TODO confirm);
    - wraps the dataset in a DataLoader with batch size 2, no shuffling and
      ``drop_last=True``.

    For every batch it prints ``len(inputs)``, the shapes of ``inputs[0]``
    and ``inputs[1]``, and the labels.
    """
    cfg = get_cfg()
    split = "train"
    cfg.DATA.PATH_TO_DATA_DIR = ''
    cfg.TEST.NUM_ENSEMBLE_VIEWS = 2
    dataset = Kinetics(cfg, split)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        shuffle=False,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=True,
    )
    # Each batch unpacks into (inputs, labels, <third item, unused here>).
    for inputs, labels, _ in train_loader:
        print(len(inputs))
        print(inputs[0].shape)
        print(inputs[1].shape)
        print(labels)
예제 #9
0
def test_train_iter4():
    """Check iteration logging with LOG_PERIOD=2.

    - After two updates, ``log_iter_stats(cur_epoch=1, cur_iter=2)`` is
      expected to return None.
    - After a third update, ``log_iter_stats(cur_epoch=0, cur_iter=3)`` is
      expected to report 15.0 for every tracked metric.
    """
    metric_names = ("top1_err", "top5_err", "mc_opt_err",
                    "mc_q_err", "loss_des", "loss_mc")

    cfg = get_cfg()
    cfg.LOG_PERIOD = 2
    meter = ClevrerTrainMeter(5, cfg)

    def record(**metrics):
        # Same learning rate and mini-batch sizes for every batch.
        meter.update_stats(**metrics, lr=0.1, mb_size_des=10, mb_size_mc=10)

    record(top1_err=40.0, top5_err=20.0, mc_opt_err=55.0,
           mc_q_err=78.0, loss_des=6.3, loss_mc=0.2)
    record(**dict.fromkeys(metric_names, 10.0))
    assert meter.log_iter_stats(cur_epoch=1, cur_iter=2) is None

    record(**dict.fromkeys(metric_names, 20.0))
    stats = meter.log_iter_stats(cur_epoch=0, cur_iter=3)
    for name in metric_names:
        assert stats[name] == 15.0
예제 #10
0
def test_train_iter5():
    """Check iteration logging with LOG_PERIOD=3.

    Three batches are recorded, with every metric set to 30.0, then 10.0,
    then 20.0.

    - After the first two, ``log_iter_stats(cur_epoch=1, cur_iter=1)`` is
      expected to return None.
    - After the third, ``log_iter_stats(cur_epoch=0, cur_iter=5)`` is
      expected to report 20.0 for every metric.
    """
    metric_names = ("top1_err", "top5_err", "mc_opt_err",
                    "mc_q_err", "loss_des", "loss_mc")

    cfg = get_cfg()
    cfg.LOG_PERIOD = 3
    train_meter = ClevrerTrainMeter(5, cfg)

    def update_all(value):
        # Record one batch in which every tracked metric equals `value`.
        train_meter.update_stats(**dict.fromkeys(metric_names, value),
                                 lr=0.1, mb_size_des=10, mb_size_mc=10)

    update_all(30.0)
    update_all(10.0)
    stats = train_meter.log_iter_stats(cur_epoch=1, cur_iter=1)
    assert stats is None

    update_all(20.0)
    stats = train_meter.log_iter_stats(cur_epoch=0, cur_iter=5)
    for key in metric_names:
        # A real message (not `print(...)`, which evaluates to None).
        assert stats[key] == 20.0, f"{key}: expected 20.0, got {stats[key]}"
예제 #11
0
def test_train_epoch_only_des():
    """Check epoch stats when no batch contains multiple-choice samples.

    Three batches are recorded with every metric set to 30.0, 10.0 and 20.0,
    and ``mb_size_mc=0`` each time. The epoch stats are expected to:

    - report 20.0 for the description metrics;
    - omit ``mc_opt_err``, ``mc_q_err`` and ``loss_mc`` entirely.
    """
    cfg = get_cfg()
    cfg.LOG_PERIOD = 3
    train_meter = ClevrerTrainMeter(5, cfg)
    for value in (30.0, 10.0, 20.0):
        train_meter.update_stats(top1_err=value,
                                 top5_err=value,
                                 mc_opt_err=value,
                                 mc_q_err=value,
                                 loss_des=value,
                                 loss_mc=value,
                                 lr=0.1,
                                 mb_size_des=10,
                                 mb_size_mc=0)

    stats = train_meter.log_epoch_stats(cur_epoch=1)
    for key in ("top1_err", "top5_err", "loss_des"):
        assert stats[key] == 20.0, f"{key}: expected 20.0, got {stats[key]}"
    for key in ("mc_opt_err", "mc_q_err", "loss_mc"):
        assert key not in stats, f"{key} unexpectedly present: {stats[key]}"
예제 #12
0
    )
    parser.add_argument(
        "--tag",
        help="tag",
        default="slowfast_r101",
        type=str,
    )
    parser.add_argument(
        "--test_tag",
        help="test_tag",
        default="1",
        type=str,
    )

    args = parser.parse_args()
    cfg = get_cfg()
    # Merge train configs
    cfg.merge_from_file(os.path.join('logdir', args.tag, 'config.yaml'))
    # Merge test configs
    if args.cfg_file is not None:
        cfg.merge_from_file(args.cfg_file)

    print("Test using", args.tag)
    cfg.OUTPUT_DIR = os.path.join('logdir', args.tag)
    cfg.TEST_OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, "test_" + args.test_tag)
    # Make dir
    if not os.path.exists(cfg.TEST_OUTPUT_DIR):
        os.makedirs(cfg.TEST_OUTPUT_DIR)
    # Save test cfg
    copyfile(args.cfg_file,
             os.path.join(cfg.TEST_OUTPUT_DIR, 'test_config.yaml'))
예제 #13
0
def basic_sf_cfg(rel_yml_path):
    """Return a default config with a bundled SlowFast YAML file merged in.

    Args:
        rel_yml_path: path of the YAML file relative to the ``configs``
            directory that sits next to the ``slowfast`` package directory.
            This assumes ``slowfast.__file__`` is the package's
            ``__init__.py``.
    """
    config = get_cfg()
    project_root = Path(slowfast.__file__).parents[1]
    config.merge_from_file(project_root / f'configs/{rel_yml_path}')
    return config
예제 #14
0
def test_val_epoch():
    """Check ClevrerValMeter epoch stats before and after ``reset()``.

    Three batches (``mb_size_des=10``, ``mb_size_mc=5``) are recorded and the
    epoch stats are compared against expected values. The meter is then
    reset, and a second set of three batches is checked the same way.
    """
    cfg = get_cfg()
    cfg.LOG_PERIOD = 4
    val_meter = ClevrerValMeter(5, cfg)

    def check(stats, expected):
        # Compare every expected metric; report the failing key and value.
        for key, value in expected.items():
            assert stats[key] == value, (
                f"{key}: expected {value}, got {stats[key]}")

    val_meter.update_stats(top1_err=30.0,
                           top5_err=30.0,
                           mc_opt_err=60.0,
                           mc_q_err=60.0,
                           loss_des=30.0,
                           loss_mc=60.0,
                           mb_size_des=10,
                           mb_size_mc=5)
    val_meter.update_stats(top1_err=10.0,
                           top5_err=10.0,
                           mc_opt_err=10.0,
                           mc_q_err=10.0,
                           loss_des=10.0,
                           loss_mc=10.0,
                           mb_size_des=10,
                           mb_size_mc=5)
    val_meter.update_stats(top1_err=20.0,
                           top5_err=20.0,
                           mc_opt_err=20.0,
                           mc_q_err=20.0,
                           loss_des=20.0,
                           loss_mc=20.0,
                           mb_size_des=10,
                           mb_size_mc=5)
    stats = val_meter.log_epoch_stats(cur_epoch=1)
    check(stats, {"top1_err": 20.0, "top5_err": 20.0, "mc_opt_err": 30.0,
                  "mc_q_err": 30.0, "loss_des": 20.0, "loss_mc": 30.0})

    val_meter.reset()
    val_meter.update_stats(top1_err=60.0,
                           top5_err=30.0,
                           mc_opt_err=30.0,
                           mc_q_err=30.0,
                           loss_des=30.0,
                           loss_mc=30.0,
                           mb_size_des=10,
                           mb_size_mc=5)
    val_meter.update_stats(top1_err=10.0,
                           top5_err=10.0,
                           mc_opt_err=10.0,
                           mc_q_err=10.0,
                           loss_des=10.0,
                           loss_mc=10.0,
                           mb_size_des=10,
                           mb_size_mc=5)
    val_meter.update_stats(top1_err=20.0,
                           top5_err=20.0,
                           mc_opt_err=20.0,
                           mc_q_err=20.0,
                           loss_des=20.0,
                           loss_mc=20.0,
                           mb_size_des=10,
                           mb_size_mc=5)
    stats = val_meter.log_epoch_stats(cur_epoch=2)
    check(stats, {"top1_err": 30.0, "top5_err": 20.0, "mc_opt_err": 20.0,
                  "mc_q_err": 20.0, "loss_des": 20.0, "loss_mc": 20.0})
예제 #15
0
def load_config(args):
    """Load defaults from ``get_cfg()`` and overlay the YAML file ``args.cfg_file``.

    Args:
        args: object with a ``cfg_file`` attribute, the path of a YAML file.
    Returns:
        The object returned by ``get_cfg()``, updated with the file's
        top-level keys. An empty file leaves the defaults untouched.
    """
    # Prefer the libyaml-backed loader. Fall back to its pure-Python
    # equivalent when PyYAML was built without libyaml, in which case
    # `yaml.CLoader` does not exist.
    loader = getattr(yaml, "CLoader", yaml.Loader)
    # NOTE(review): Loader/CLoader can construct arbitrary Python objects.
    # Only use this on trusted config files, or switch to CSafeLoader/SafeLoader.
    with open(args.cfg_file, "r", encoding="utf-8") as f:
        raw_cfg = yaml.load(f, Loader=loader)
    cfg = get_cfg()
    # An empty YAML document loads as None; `update(None)` would raise TypeError.
    if raw_cfg is not None:
        # NOTE(review): dict.update is shallow. A nested section in the file
        # replaces the whole default section (as a plain dict) instead of
        # being merged key by key. Confirm this is intended.
        cfg.update(raw_cfg)
    return cfg