Example #1
def main():
    from jactorch.data.dataloader import JacDataLoader

    dataset = SyncableDataset()
    dataloader = JacDataLoader(dataset, batch_size=1, shuffle=False, num_workers=2,
                               worker_recv_fn=dataset.on_recv)

    for i, value in enumerate(dataloader):
        print(i, value)
        if i == 9:
            dataloader.send_to_worker({'global_index': 10086})
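
The SyncableDataset used above is not shown in the snippet. The sketch below is a hypothetical reconstruction: the class name and the 'global_index' field come from the example itself, while the assumption is that the handler passed as worker_recv_fn is invoked inside each worker process with the message dict sent via send_to_worker.

import torch
from torch.utils.data import Dataset


class SyncableDataset(Dataset):
    """Toy dataset whose sampling state can be updated from the main process."""

    def __init__(self):
        self.global_index = 0  # state that the main process may overwrite

    def on_recv(self, message):
        # Invoked in the worker process when the main process calls
        # dataloader.send_to_worker({...}).
        self.global_index = message['global_index']

    def __len__(self):
        return 100

    def __getitem__(self, index):
        # Each item reflects the most recent state received from the main process.
        return torch.tensor([index, self.global_index])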
Example #2
    def make_dataloader(self, batch_size, shuffle, drop_last, nr_workers):
        from jactorch.data.dataloader import JacDataLoader
        from jactorch.data.collate import VarLengthCollateV2

        collate_guide = {
            "scene": "skip",
            "objects_raw": "skip",
            "objects": "concat",
            "image_index": "skip",
            "image_filename": "skip",
            "program_raw": "skip",
            "program_seq": "skip",
            "program_tree": "skip",
            "program_qsseq": "skip",
            "program_qstree": "skip",
            "question_type": "skip",
            "answer": "skip",
        }

        gdef.update_collate_guide(collate_guide)

        return JacDataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=nr_workers,
            pin_memory=True,
            collate_fn=VarLengthCollateV2(collate_guide),
        )
Example #3
    def make_dataloader(self, batch_size, shuffle, drop_last, nr_workers):
        from jactorch.data.dataloader import JacDataLoader
        from jactorch.data.collate import VarLengthCollateV2

        collate_guide = {
            'scene': 'skip',
            'objects_raw': 'skip',
            'objects': 'concat',

            'image_index': 'skip',
            'image_filename': 'skip',

            'program_raw': 'skip',
            'program_seq': 'skip',
            'program_tree': 'skip',
            'program_qsseq': 'skip',
            'program_qstree': 'skip',

            'question_type': 'skip',
            'answer': 'skip',
        }

        gdef.update_collate_guide(collate_guide)

        return JacDataLoader(
            self, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
            num_workers=nr_workers, pin_memory=True,
            collate_fn=VarLengthCollateV2(collate_guide)
        )
Example #4
    def _prepare_dataset(self, epoch_size, mode):
        from jactorch.data.dataloader import JacDataLoader

        assert mode in ['train', 'test']
        if mode == 'train':
            batch_size = args.batch_size
            number = args.train_number
        else:
            batch_size = args.test_batch_size
            number = self.test_number

        # The actual number of instances in an epoch is epoch_size * batch_size.
        dataset = make_dataset(number, epoch_size * batch_size,
                               mode == 'train')
        dataloader = JacDataLoader(dataset,
                                   shuffle=True,
                                   batch_size=batch_size,
                                   num_workers=min(epoch_size, 4))
        self.data_iterator[mode] = iter(dataloader)
Example #5
    def test_jac_dataloader(self):
        ds = _FakeDataset()
        dl = JacDataLoader(ds,
                           num_workers=2,
                           worker_init_fn=_my_init_func,
                           worker_init_args=[('hello', ), ('world', )])
        res = list(dl)
        self.assertNotEqual(as_float(res[0]), as_float(res[1]))
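
_FakeDataset and _my_init_func are helpers from the test file and are not shown here. The sketch below is one possible reconstruction under two assumptions: that JacDataLoader calls the init function in each worker with the worker id followed by the matching tuple from worker_init_args, and that the two workers serve one item each, so values derived from the per-worker argument ('hello' vs. 'world') come out different.

import torch
from torch.utils.data import Dataset

_WORKER_TAG = 'main'


def _my_init_func(worker_id, tag):
    # Assumed call signature: worker id plus the per-worker tuple from worker_init_args.
    global _WORKER_TAG
    _WORKER_TAG = tag


class _FakeDataset(Dataset):
    def __len__(self):
        return 2

    def __getitem__(self, index):
        # The value depends only on the tag set by the worker's init function,
        # so items produced by different workers differ.
        return torch.tensor([float(sum(map(ord, _WORKER_TAG)))])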
Example #6
def main():
    import jacinle
    import torch.optim as optim
    from jactorch.data.dataloader import JacDataLoader, JacDataLoaderMultiGPUWrapper
    from jactorch.data.collate import VarLengthCollateV3

    dataset = MyDataset()
    dataloader = JacDataLoader(dataset,
                               batch_size=8,
                               collate_fn=VarLengthCollateV3({
                                   'x': 'concat',
                                   'y': 'concat'
                               }),
                               shuffle=True,
                               drop_last=True,
                               num_workers=0)
    dataloader = JacDataLoaderMultiGPUWrapper(dataloader, args.gpus)

    from jactorch.parallel import JacDataParallel
    model = MyModel()
    model = JacDataParallel(model,
                            user_scattered=True,
                            dict_gather_layout={
                                'z': 'concat',
                                'devices': 'skip'
                            })
    model.cuda()
    optimizer = optim.SGD(model.parameters(), 1e-4)

    from jactorch.train.env import TrainerEnv, default_reduce_func
    env = TrainerEnv(model, optimizer)

    # The reduce func only changes how the loss and the monitors are reduced.
    def custom_reduce_func(k, v):
        if '_max' in k:
            return v.max()
        elif '_sum' in k:
            return v.sum()
        else:
            return default_reduce_func(k, v)

    feed_dict = next(iter(dataloader))
    loss, monitors, outputs, _ = env.step(feed_dict,
                                          reduce_func=custom_reduce_func)

    # feed_dict is a List[Dict]; each dict contains four keys: x, y, x_length, and y_length.
    # The length of the list equals the number of GPUs.
    # All x's and y's are concatenated along the first dimension (the batch dimension).
    # The {x,y}_length entries are int-typed tensors recording the length of each item in the batch
    # (thus of size [batch_size]).
    jacinle.stprint(feed_dict)
    # outputs is a dict that gathers the outputs from all GPUs.
    # You can specify the gathering method via dict_gather_layout.
    # For a key mapped to 'concat', the output is the concatenation of the tensors from all GPUs.
    # An auxiliary tensor z_length is also added: an int64 tensor of size [nr_devs] that records
    # the size of dim 0 of the tensor on each device.
    # If you want maximal control over the outputs, specify 'skip'; in that case it outputs List[Tuple[str]].
    jacinle.stprint(outputs)
    jacinle.stprint(monitors)
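
MyDataset above is not shown; the sketch below is one minimal shape it could take, consistent with the {'x': 'concat', 'y': 'concat'} collate guide: each item is a dict with variable-length 'x' and 'y' tensors, which VarLengthCollateV3 then concatenates along the batch dimension and augments with x_length / y_length, as described in the comments above. The item count and feature sizes are placeholders.

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Toy dataset producing variable-length 'x' and 'y' tensors."""

    def __len__(self):
        return 1024

    def __getitem__(self, index):
        length = 3 + index % 5  # variable length per item (placeholder)
        return {
            'x': torch.randn(length, 16),
            'y': torch.randn(length, 16),
        }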
Example #7
def make_dataloader(dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=0,
                    pin_memory=False,
                    drop_last=False):
    from jactorch.data.dataloader import JacDataLoader
    # VarLengthCollate is assumed to be exposed from jactorch.data.collate,
    # alongside the V2/V3 collates used in the other examples.
    from jactorch.data.collate import VarLengthCollate

    return JacDataLoader(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         collate_fn=VarLengthCollate(['sent_f', 'sent_b'],
                                                     'pad'))
Example #8
    def make_dataloader(self, batch_size, shuffle, drop_last, nr_workers):
        from jactorch.data.dataloader import JacDataLoader
        from jactorch.data.collate import VarLengthCollateV2

        collate_guide = {
            'image_filename': 'skip',
        }

        return JacDataLoader(self,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             drop_last=drop_last,
                             num_workers=nr_workers,
                             pin_memory=True,
                             collate_fn=VarLengthCollateV2(collate_guide))
Example #9
    def make_dataloader(self, batch_size, shuffle, drop_last, nr_workers):
        from jactorch.data.collate import VarLengthCollateV2
        collate_fn = VarLengthCollateV2({
            'image_index': 'skip',
            'image_filename': 'skip',
            'question_raw': 'skip',
            'question_raw_tokenized': 'skip',
            'program_seq': 'skip',
            'program_qsseq': 'skip',
            'answer': 'skip'
        })

        from jactorch.data.dataloader import JacDataLoader
        return JacDataLoader(self,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             drop_last=drop_last,
                             num_workers=nr_workers,
                             pin_memory=True,
                             collate_fn=collate_fn)