Example #1
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        default_collate(arr)

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: default_collate(arr))
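
For contrast with the failing cases above, here is a minimal sketch (not part of the original test) of the path default_collate does support: numeric NumPy arrays are converted to tensors and stacked.

import numpy as np
from torch.utils.data.dataloader import default_collate

out = default_collate([np.array([1.0, 2.0]), np.array([3.0, 4.0])])
print(out)  # tensor([[1., 2.], [3., 4.]], dtype=torch.float64)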
Example #2
def eqaCollateSeq2seq(batch):
    transposed = list(zip(*batch))
    idx_batch = default_collate(transposed[0])
    questions_batch = default_collate(transposed[1])
    answers_batch = default_collate(transposed[2])
    images_batch = default_collate(transposed[3])
    actions_in_batch = default_collate(transposed[4])
    actions_out_batch = default_collate(transposed[5])
    action_lengths_batch = default_collate(transposed[6])
    mask_batch = default_collate(transposed[7])

    return [
        idx_batch, questions_batch, answers_batch, images_batch,
        actions_in_batch, actions_out_batch, action_lengths_batch, mask_batch
    ]
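
The zip(*batch) idiom above transposes a batch of per-sample tuples into per-field sequences so each field can be collated independently. A minimal sketch of the idea with toy data (the two-field layout is hypothetical):

import torch
from torch.utils.data.dataloader import default_collate

batch = [(torch.tensor(0), torch.tensor([1.0])),
         (torch.tensor(1), torch.tensor([2.0]))]
ids, feats = zip(*batch)              # per-field tuples
print(default_collate(list(ids)))     # tensor([0, 1])
print(default_collate(list(feats)))   # tensor([[1.], [2.]])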
Example #3
 def my_collate_fn(batch):
     batch = list(filter(lambda x: x is not None, batch))
     if len(batch) == 0:
         print("No valid data!!!")
         batch = [[torch.from_numpy(np.zeros([1, 1]))]]
     return default_collate(batch)
Example #4
def collate_fn(batch):
    # remove audio from the batch
    batch = [(d[0], d[2]) for d in batch]
    return default_collate(batch)
Example #5
def my_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        return None
    else:
        return default_collate(batch)
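
This None-filtering collate (which recurs throughout these examples) is typically paired with a dataset whose __getitem__ returns None for unreadable samples. A hedged usage sketch, with a toy dataset standing in for the real one:

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

class SometimesBroken(Dataset):
    """Toy dataset that fails to load every other sample."""
    def __len__(self):
        return 8
    def __getitem__(self, i):
        return None if i % 2 else torch.tensor([float(i)])

def my_collate(batch):
    batch = [x for x in batch if x is not None]
    return default_collate(batch) if batch else None

loader = DataLoader(SometimesBroken(), batch_size=4, collate_fn=my_collate)
for b in loader:
    print(b.shape)  # torch.Size([2, 1]) -- only the valid samples survive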
Example #6
def my_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return dataloader.default_collate(batch)
Example #7
def my_collate_split(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) < 1:
        return None, None, None, None, None, None

    return default_collate(batch)
Example #8
 def __call__(self, batch: Dict[str,
                                torch.Tensor]) -> Dict[str, torch.Tensor]:
     batch = default_collate(batch)
     batch = cutmix(batch, self.alpha)
     return batch
Example #9
 def collate_fn(batch):
     batch = [s for s in batch if s is not None]
     return default_collate(batch)
Example #10
def center_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    batchindexgetter = itemgetter(2)
    batch = list(filter(lambda x: batchindexgetter(x) % 4 == 0, batch))
    return default_collate(batch)
Example #11
 def dense_collate(data_list):
     batch = Batch()
     for key in data_list[0].keys:
         batch[key] = default_collate([d[key] for d in data_list])
     return batch
Example #12
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(
            transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(
            transforms.RandomCutmix(num_classes,
                                    p=1.0,
                                    alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](weights=args.weights,
                                                    num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    if args.norm_weight_decay is None:
        parameters = model.parameters()
    else:
        param_groups = torchvision.ops._utils.split_normalization_params(model)
        wd_groups = [args.norm_weight_decay, args.weight_decay]
        parameters = [{
            "params": p,
            "weight_decay": w
        } for p, w in zip(param_groups, wd_groups) if p]

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(parameters,
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        eps=0.0316,
                                        alpha=0.9)
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters,
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
    else:
        raise RuntimeError(
            f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported."
        )

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs)
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported.")

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=args.lr_warmup_decay,
                total_iters=args.lr_warmup_epochs)
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer,
                factor=args.lr_warmup_decay,
                total_iters=args.lr_warmup_epochs)
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[args.lr_warmup_epochs])
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We treat Dataset_size as constant for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp,
                                                   device=device,
                                                   decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema,
                     criterion,
                     data_loader_test,
                     device=device,
                     log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema,
                     criterion,
                     data_loader_test,
                     device=device,
                     log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir,
                                         f"model_{epoch}.pth"))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
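
The mixup/cutmix collate_fn above relies on default_collate turning a list of (image, label) pairs into a stacked image tensor plus a label tensor, which are then unpacked into the batch-level transform. A small sketch of that contract with toy tensors (no torchvision required):

import torch
from torch.utils.data.dataloader import default_collate

batch = [(torch.rand(3, 8, 8), 0), (torch.rand(3, 8, 8), 1)]
images, targets = default_collate(batch)  # stacked images, label tensor
print(images.shape, targets)  # torch.Size([2, 3, 8, 8]) tensor([0, 1])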
Example #13
def my_collate(batch):
    batch = [b for b in batch if b is not None]
    return default_collate(batch)
Example #14
def custom_collate_fn(batch):
    items = list(zip(*batch))
    items[0] = default_collate(items[0])
    items[1] = default_collate(items[1])
    items[2] = default_collate(items[2])
    return items
Example #15
    def __getitem__(self,
                    cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch, with the batch size automatically determined using the constraints
        of max_frames and max_cuts.
        """
        validate_for_asr(cuts)

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        inputs, _ = self.input_strategy(cuts)

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate([
                {"text": supervision.text}
                for sequence_idx, cut in enumerate(cuts)
                for supervision in cut.supervisions
            ]),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment for c in cuts
            for s in c.supervisions)
        if has_word_alignments:
            # TODO: might need to refactor BatchIO API to move the following conditional logic
            #       into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
            #       that returns either num_frames or num_samples depending on the strategy).
            words, starts, ends = [], [], []
            frame_shift = cuts[0].frame_shift
            sampling_rate = cuts[0].sampling_rate
            if frame_shift is None:
                try:
                    frame_shift = self.input_strategy.extractor.frame_shift
                except AttributeError:
                    raise ValueError(
                        "Can't determine the frame_shift -- it is not present in "
                        "either the cuts or the input_strategy.")
            for c in cuts:
                for s in c.supervisions:
                    words.append(
                        [aliword.symbol for aliword in s.alignment["word"]])
                    starts.append([
                        compute_num_frames(
                            aliword.start,
                            frame_shift=frame_shift,
                            sampling_rate=sampling_rate,
                        ) for aliword in s.alignment["word"]
                    ])
                    ends.append([
                        compute_num_frames(
                            aliword.end,
                            frame_shift=frame_shift,
                            sampling_rate=sampling_rate,
                        ) for aliword in s.alignment["word"]
                    ])
            batch["supervisions"]["word"] = words
            batch["supervisions"]["word_start"] = starts
            batch["supervisions"]["word_end"] = ends

        return batch
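
Note how default_collate behaves on the list of supervision dicts above: string values are left as lists rather than tensorized, so the result is an ordinary dict that can be updated in place. A minimal sketch (sample text hypothetical):

from torch.utils.data.dataloader import default_collate

sups = [{"text": "hello"}, {"text": "world"}]
print(default_collate(sups))  # {'text': ['hello', 'world']}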
Example #16
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)
Example #17
 def custom_collate(self, batch):
     return default_collate(batch)
Example #18
def run_conditional(model, dsets):
    if len(dsets.datasets) > 1:
        split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
        dset = dsets.datasets[split]
    else:
        dset = next(iter(dsets.datasets.values()))
    batch_size = 1
    start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
                                          min_value=0,
                                          max_value=len(dset)-batch_size)
    indices = list(range(start_index, start_index+batch_size))

    example = default_collate([dset[i] for i in indices])

    x = model.get_input("image", example).to(model.device)

    cond_key = model.cond_stage_key
    c = model.get_input(cond_key, example).to(model.device)

    scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
    if scale_factor != 1.0:
        x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
        c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")

    quant_z, z_indices = model.encode_to_z(x)
    quant_c, c_indices = model.encode_to_c(c)

    cshape = quant_z.shape

    xrec = model.first_stage_model.decode(quant_z)
    st.write("image: {}".format(x.shape))
    st.image(bchw_to_st(x), clamp=True, output_format="PNG")
    st.write("image reconstruction: {}".format(xrec.shape))
    st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")

    if cond_key == "segmentation":
        # get image from segmentation mask
        num_classes = c.shape[1]
        c = torch.argmax(c, dim=1, keepdim=True)
        c = torch.nn.functional.one_hot(c, num_classes=num_classes)
        c = c.squeeze(1).permute(0, 3, 1, 2).float()
        c = model.cond_stage_model.to_rgb(c)

    st.write(f"{cond_key}: {tuple(c.shape)}")
    st.image(bchw_to_st(c), clamp=True, output_format="PNG")

    idx = z_indices

    half_sample = st.sidebar.checkbox("Image Completion", value=False)
    if half_sample:
        start = idx.shape[1]//2
    else:
        start = 0

    idx[:, start:] = 0
    idx = idx.reshape(cshape[0], cshape[2], cshape[3])
    start_i = start // cshape[3]
    start_j = start % cshape[3]

    if not half_sample and quant_z.shape == quant_c.shape:
        st.info("Setting idx to c_indices")
        idx = c_indices.clone().reshape(cshape[0], cshape[2], cshape[3])

    cidx = c_indices
    cidx = cidx.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
    st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")

    temperature = st.number_input("Temperature", value=1.0)
    top_k = st.number_input("Top k", value=100)
    sample = st.checkbox("Sample", value=True)
    update_every = st.number_input("Update every", value=75)

    st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")

    animate = st.checkbox("animate")
    if animate:
        import imageio
        outvid = "sampling.mp4"
        writer = imageio.get_writer(outvid, fps=25)
    elapsed_t = st.empty()
    info = st.empty()
    st.text("Sampled")
    if st.button("Sample"):
        output = st.empty()
        start_t = time.time()
        for i in range(start_i, cshape[2]):
            if i <= 8:
                local_i = i
            elif cshape[2]-i < 8:
                local_i = 16-(cshape[2]-i)
            else:
                local_i = 8
            for j in range(start_j, cshape[3]):
                if j <= 8:
                    local_j = j
                elif cshape[3]-j < 8:
                    local_j = 16-(cshape[3]-j)
                else:
                    local_j = 8

                i_start = i - local_i
                i_end = i_start + 16
                j_start = j - local_j
                j_end = j_start + 16
                elapsed_t.text(f"Time: {time.time() - start_t} seconds")
                info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
                patch = idx[:, i_start:i_end, j_start:j_end]
                patch = patch.reshape(patch.shape[0], -1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                patch = torch.cat((cpatch, patch), dim=1)
                logits, _ = model.transformer(patch[:, :-1])
                logits = logits[:, -256:, :]
                logits = logits.reshape(cshape[0], 16, 16, -1)
                logits = logits[:, local_i, local_j, :]

                logits = logits / temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                idx[:, i, j] = ix

                if (i * cshape[3] + j) % update_every == 0:
                    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)

                    xstart = bchw_to_st(xstart)
                    output.image(xstart, clamp=True, output_format="PNG")

                    if animate:
                        writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))

        xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
        xstart = bchw_to_st(xstart)
        output.image(xstart, clamp=True, output_format="PNG")
        #save_img(xstart, "full_res_sample.png")
        if animate:
            writer.close()
            st.video(outvid)
Example #19
def collate_data(batch):
    if isinstance(batch[0], str):
        return batch
    else:
        return default_collate(batch)
Example #20
def my_collate(batch):
    batch = [d for d in batch if d['4_lanes'] is True]
    return default_collate(batch)
Example #21
 def collate_fn(batch):
     return (default_collate([b[0] for b in batch]),
             [b[1] for b in batch])
Example #22
 def my_collate(batch):
     batch = [x for x in batch if x is not None]
     return default_collate(batch)
Example #23
 def collater(self, samples):
     # For now this only supports datasets with the same underlying collater implementation
     if hasattr(self.datasets[0], 'collater'):
         return self.datasets[0].collater(samples)
     else:
         return default_collate(samples)
Example #24
def safe_collate(batch):
    """Return batch without any None values"""
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)
Example #25
 def collate_fn(self, x: List[Tuple]):
     # if self.config.experiment.batch_size > 1:
     s2s_data, cmax_data = [item[0] for item in x], [item[1] for item in x]
     tensors, dates = [item[:-2] for item in s2s_data], [item[-2:] for item in s2s_data]
     return [[*default_collate(tensors), *list(zip(*dates))], default_collate(cmax_data)]
Example #26
def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :type:`~torchie.parallel.DataContainer`. There are 3 cases.

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes
    """

    if not isinstance(batch, collections.abc.Sequence):
        raise TypeError("{} is not supported.".format(type(batch)))

    if isinstance(batch[0], DataContainer):
        assert len(batch) % samples_per_gpu == 0
        stacked = []
        if batch[0].cpu_only:
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
            return DataContainer(stacked,
                                 batch[0].stack,
                                 batch[0].padding_value,
                                 cpu_only=True)
        elif batch[0].stack:
            for i in range(0, len(batch), samples_per_gpu):
                assert isinstance(batch[i].data, torch.Tensor)

                if batch[i].pad_dims is not None:
                    ndim = batch[i].dim()
                    assert ndim > batch[i].pad_dims
                    max_shape = [0 for _ in range(batch[i].pad_dims)]
                    for dim in range(1, batch[i].pad_dims + 1):
                        max_shape[dim - 1] = batch[i].size(-dim)
                    for sample in batch[i:i + samples_per_gpu]:
                        for dim in range(0, ndim - batch[i].pad_dims):
                            assert batch[i].size(dim) == sample.size(dim)
                        for dim in range(1, batch[i].pad_dims + 1):
                            max_shape[dim - 1] = max(max_shape[dim - 1],
                                                     sample.size(-dim))
                    padded_samples = []
                    for sample in batch[i:i + samples_per_gpu]:
                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
                        for dim in range(1, batch[i].pad_dims + 1):
                            pad[2 * dim -
                                1] = max_shape[dim - 1] - sample.size(-dim)
                        padded_samples.append(
                            F.pad(sample.data, pad,
                                  value=sample.padding_value))
                    stacked.append(default_collate(padded_samples))
                elif batch[i].pad_dims is None:
                    stacked.append(
                        default_collate([
                            sample.data
                            for sample in batch[i:i + samples_per_gpu]
                        ]))
                else:
                    raise ValueError(
                        "pad_dims should be either None or integers (1-3)")

        else:
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
    elif isinstance(batch[0], collections.abc.Sequence):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    elif isinstance(batch[0], collections.abc.Mapping):
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            for key in batch[0]
        }
    else:
        return default_collate(batch)
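
The Sequence and Mapping branches above mirror what default_collate itself does: transpose the batch and recurse per field. A minimal sketch of the Mapping case without DataContainer (sample data hypothetical):

import torch
from torch.utils.data.dataloader import default_collate

samples = [
    {"img": torch.zeros(3, 4), "label": 1},
    {"img": torch.ones(3, 4), "label": 0},
]
# Collate each key independently, like the Mapping branch above.
batch = {k: default_collate([d[k] for d in samples]) for k in samples[0]}
print(batch["img"].shape)  # torch.Size([2, 3, 4])
print(batch["label"])      # tensor([1, 0])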
Example #27
def filter_unk_collate(batch):
    batch = list(filter(lambda x: np.sum(x['ans_scores']) > 0, batch))
    return default_collate(batch)
Example #28
def label_squeezing_collate_fn(batch):
    x, y = default_collate(batch)
    return x, y.long().squeeze()
Example #29
def ignore_exceptions_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)
Example #30
def ignore_exceptions_collate(batch):
    # Drop missing samples and (presumably) double-precision tensors; the
    # original compared type(x) to torch.double, which never matches a type.
    batch = list(
        filter(lambda x: x is not None and
               not (torch.is_tensor(x) and x.dtype == torch.double), batch))
    return default_collate(batch)
Example #31
def my_collate(batch):
    from torch.utils.data.dataloader import default_collate
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)
Example #32
def collate_fn(batch):
    return default_collate(batch)
Example #33
 def my_collate(self, batch):
     batch = list(filter(lambda x: x is not None, batch))
     return default_collate(batch)