Ejemplo n.º 1
0
def build_data_loader():
    """Create the training ``DataLoader`` from the global ``cfg``.

    A ``TrkDataset`` is built and wrapped with the batch size and worker
    count read from ``cfg.TRAIN``. No sampler is supplied (``sampler=None``)
    and no ``shuffle`` argument is passed, so the loader uses PyTorch's
    default index order.

    Returns:
        DataLoader over the tracking training dataset.
    """
    logger.info("build train dataset")
    dataset = TrkDataset()
    logger.info("build dataset done")

    return DataLoader(
        dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKERS,
        pin_memory=True,
        sampler=None,
    )
Ejemplo n.º 2
0
def get_train_dataflow():
    """Build the training dataflow with data augmentation.

    ``TrkDataset`` is mapped through ``TrainingDataPreprocessor(cfg)``,
    batched to ``cfg.TRAIN.BATCH_SIZE``, and wrapped in ``TPIterableDataset``.

    The mapping runs in the current process when ``cfg.TRAIN.NUM_WORKERS`` is
    0 or 1. For larger values it runs in that many worker processes via
    ``MultiProcessMapDataZMQ``.

    Returns:
        TPIterableDataset yielding batches of preprocessed samples.
    """
    ds = TrkDataset()
    train_preproc = TrainingDataPreprocessor(cfg)

    num_workers = cfg.TRAIN.NUM_WORKERS
    # Treat 0 (PyTorch's "load in the main process" convention) like 1. The
    # previous `== 1` check sent 0 to MultiProcessMapDataZMQ, which cannot do
    # useful work with zero worker processes.
    if num_workers <= 1:
        ds = MapData(ds, train_preproc)
    else:
        ds = MultiProcessMapDataZMQ(ds, num_workers, train_preproc)
    ds = BatchData(ds, cfg.TRAIN.BATCH_SIZE)
    return TPIterableDataset(ds)
Ejemplo n.º 3
0
def build_data_loader():
    """Build the training DataLoader after reloading the config and dataset modules.

    ``pysot.core.config`` and ``pysot.datasets.dataset`` are reloaded. Then
    ``TrkDataset`` is imported locally, so the class from the reloaded module
    is the one that gets instantiated. Batch size and worker count come from
    ``cfg.TRAIN``. A ``DistributedSampler`` is used when
    ``get_world_size() > 1``; otherwise no sampler is passed.

    Returns:
        DataLoader over a new ``TrkDataset`` instance.
    """
    logger.info("build train dataset")
    # NOTE(review): assumes `reload` is importlib.reload, which re-executes the
    # module body. If pysot.core.config creates `cfg` at import time, the reload
    # rebinds pysot.core.config.cfg to a new object. The `cfg` name used below
    # is bound in this module's globals and is not updated. Confirm that the
    # reloaded dataset and this loader are meant to read the same config.
    reload(pysot.core.config)
    reload(pysot.datasets.dataset)
    # Local import picks up TrkDataset from the freshly reloaded module rather
    # than any stale module-level binding.
    from pysot.datasets.dataset import TrkDataset
    train_dataset = TrkDataset()
    logger.info("build dataset done")

    train_sampler = None
    # Shard the dataset across processes only when running distributed.
    if get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.TRAIN.BATCH_SIZE,
                              num_workers=cfg.TRAIN.NUM_WORKERS,
                              pin_memory=True,
                              sampler=train_sampler)
    return train_loader
Ejemplo n.º 4
0
def build_data_loader():
    """Build the training DataLoader; all loader settings come from ``cfg``.

    Returns:
        A ``DataLoader`` over ``TrkDataset()``. When running with more than
        one process (``get_world_size() > 1``), it uses a
        ``DistributedSampler``; otherwise no sampler is supplied.
    """
    logger.info("build train dataset")
    # Per the original author's note, TrkDataset generates anchor positions for
    # the feature map and loads training annotations from JSON. The data is
    # grouped per video: each video has at least one target, and each target
    # has at least one annotated frame. It also sets data-augmentation
    # parameters. None of this is verifiable from this file; TODO confirm
    # against TrkDataset.
    dataset = TrkDataset()
    logger.info("build dataset done")

    is_distributed = get_world_size() > 1
    sampler = DistributedSampler(dataset) if is_distributed else None

    loader = DataLoader(dataset,
                        batch_size=cfg.TRAIN.BATCH_SIZE,
                        num_workers=cfg.TRAIN.NUM_WORKERS,
                        pin_memory=True,
                        sampler=sampler)
    return loader
Ejemplo n.º 5
0
def build_data_loader(mode='train'):
    """Build a DataLoader for the training or validation dataset.

    Args:
        mode: ``'train'`` builds a ``TrkDataset``. Any other value builds a
            ``ValDataset``.

    Returns:
        DataLoader using ``cfg.TRAIN.BATCH_SIZE`` and ``cfg.TRAIN.NUM_WORKERS``
        for both modes. A ``DistributedSampler`` is used when
        ``get_world_size() > 1``; otherwise no sampler is supplied.
    """
    if mode != 'train':
        logger.info("build val dataset")
        dataset = ValDataset()
    else:
        logger.info("build train dataset")
        dataset = TrkDataset()
    logger.info("build dataset done")

    sampler = DistributedSampler(dataset) if get_world_size() > 1 else None

    return DataLoader(dataset,
                      batch_size=cfg.TRAIN.BATCH_SIZE,
                      num_workers=cfg.TRAIN.NUM_WORKERS,
                      pin_memory=True,
                      sampler=sampler)
Ejemplo n.º 6
0
def build_data_loader():
    """Build the training DataLoader over ``TrkDataset``.

    Returns:
        DataLoader configured from ``cfg.TRAIN``. It uses a
        ``DistributedSampler`` when ``get_world_size() > 1``; otherwise no
        sampler is supplied.
    """
    logger.info("build train dataset")
    train_set = TrkDataset()
    logger.info("build dataset done")

    if get_world_size() > 1:
        sampler = DistributedSampler(train_set)
    else:
        sampler = None

    # DataLoader arguments used here:
    #   dataset     - an existing PyTorch dataset or a custom one (TrkDataset here)
    #   batch_size  - number of samples per batch
    #   num_workers - number of loader worker processes; 0 loads in the main process
    #   pin_memory  - if True, tensors are copied into CUDA pinned memory
    #                 before being returned
    #   sampler     - strategy for drawing sample indices
    # Not passed (PyTorch defaults apply):
    #   collate_fn  - how individual samples are merged into a batch
    #   timeout     - max time to wait for a batch from workers before raising
    loader_kwargs = dict(
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKERS,
        pin_memory=True,
        sampler=sampler,
    )
    return DataLoader(train_set, **loader_kwargs)