Example #1
0
def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size,
                   num_workers, ctx):
    """Build the CenterNet training and validation data loaders.

    Args:
        net: Network whose ``scale`` attribute is forwarded to the training
            transform as ``scale_factor``.
        train_dataset: Training dataset; must expose ``classes`` and
            ``transform``.
        val_dataset: Validation dataset; must expose ``transform``.
        data_shape: Side length used for both the input width and height.
        batch_size: Number of samples per batch, shared by both loaders.
        num_workers: Worker count passed to both ``DataLoader`` instances.
        ctx: Accepted for interface compatibility; not used in this function.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    width = height = data_shape
    class_count = len(train_dataset.classes)

    # The training transform emits six per-sample fields (presumably the
    # image plus CenterNet target arrays -- confirm against the transform);
    # every one of them is batched with Stack().
    train_stackers = [Stack() for _ in range(6)]
    train_batchify = Tuple(train_stackers)
    train_transform = CenterNetDefaultTrainTransform(
        width, height, num_class=class_count, scale_factor=net.scale)
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(train_transform),
        batch_size,
        True,  # shuffle training samples
        batchify_fn=train_batchify,
        last_batch='rollover',
        num_workers=num_workers)

    # Validation: first field is stacked; second field is padded with -1
    # (presumably variable-length ground-truth labels -- verify).
    val_batchify = Tuple(Stack(), Pad(pad_val=-1))
    val_transform = CenterNetDefaultValTransform(width, height)
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(val_transform),
        batch_size,
        False,  # keep validation order deterministic
        batchify_fn=val_batchify,
        last_batch='keep',
        num_workers=num_workers)
    return train_loader, val_loader
Example #2
0
def get_dataloader(val_dataset, data_shape, batch_size, num_workers):
    """Build the CenterNet validation data loader.

    Args:
        val_dataset: Validation dataset; must expose ``transform``.
        data_shape: Side length used for both the input width and height.
        batch_size: Number of samples per batch.
        num_workers: Worker count passed to the ``DataLoader``.

    Returns:
        A ``gluon.data.DataLoader`` over the transformed validation set,
        unshuffled, keeping the final partial batch.
    """
    side = data_shape
    # First field is stacked; second is padded with -1 (presumably
    # variable-length labels -- verify against the transform output).
    collate = Tuple(Stack(), Pad(pad_val=-1))
    transformed = val_dataset.transform(CenterNetDefaultValTransform(side, side))
    return gluon.data.DataLoader(
        transformed,
        batch_size,
        False,
        last_batch='keep',
        num_workers=num_workers,
        batchify_fn=collate,
    )