コード例 #1
0
def preprocess_batch_graphs(batched_inputs, device):
    """Collate, slot by slot, the graphs of several samples into batches.

    ``batched_inputs`` holds one entry per sample, and each entry is a
    sequence of graphs (one per input "slot"; assumption — confirm against
    callers). For every slot position the graphs of all samples are gathered,
    each has ``.to(device)`` called on it, they are collated through
    ``DataLoader`` and the collated result is moved to ``device``.

    Args:
        batched_inputs: Sequence of per-sample graph sequences.
        device: Target passed to every ``.to()`` call.

    Returns:
        list: One collated batch per slot position, in slot order.
    """
    n_samples = len(batched_inputs)
    collated = []
    for slot_graphs in zip(*batched_inputs):
        # Called for its side effect only; the return value is discarded
        # (presumably an in-place move for graph objects — confirm).
        for graph in slot_graphs:
            graph.to(device)
        # DataLoader args are presumably (dataset, batch_size, shuffle); with
        # batch_size == n_samples the first batch holds the whole slot.
        first_batch = next(iter(DataLoader(slot_graphs, n_samples, False)))
        collated.append(first_batch.to(device))
    return collated
コード例 #2
0
def DecLoader(data, batch_size, shuffle=True, device='cuda'):
    """Lazily yield collated mini-batches from a list of multi-graph samples.

    Each element of ``data`` is expected to be a sequence of graphs, one per
    input slot (assumption — confirm against callers). Samples are consumed
    ``batch_size`` at a time. For every slot, the graphs of the current
    mini-batch have ``.to(device)`` called on them and are collated with
    ``DataLoader``. A trailing group of fewer than ``batch_size`` samples is
    dropped.

    Args:
        data: Mutable sequence of samples. When ``shuffle`` is true it is
            shuffled in place, so the caller's list is reordered.
        batch_size: Number of samples per yielded mini-batch; must be >= 1.
        shuffle: Whether to shuffle ``data`` before batching.
        device: Anything ``torch.device`` accepts (e.g. ``'cuda'``, ``'cpu'``).

    Yields:
        list: One collated batch per slot position, in slot order.

    Raises:
        ValueError: If ``batch_size`` < 1, which would otherwise loop forever.
            Because this is a generator, it is raised on the first ``next()``.
    """
    # Local import: the ``shuffle`` parameter shadows ``random.shuffle``.
    import random

    device = torch.device(device)
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    if shuffle:
        random.shuffle(data)
    while len(data) >= batch_size:
        batch, data = data[:batch_size], data[batch_size:]
        batch_list = []
        for graph_batch in zip(*batch):
            # Return value discarded: presumably an in-place move for graph
            # objects (confirm); the collated batch is moved again below.
            for graph in graph_batch:
                graph.to(device)
            # DataLoader args are presumably (dataset, batch_size, shuffle);
            # the first batch therefore contains the whole slot group.
            loader = DataLoader(graph_batch, batch_size, False)
            batch_list.append(next(iter(loader)).to(device))
        yield batch_list
コード例 #3
0
                        help='Window overlay size')

    args = parser.parse_args()

    training_set = ACERTA_FP(set_split='training', split=args.split, windowed=args.no_windowed, window_size=args.window_size, window_overlay=args.window_overlay,
                            input_type=args.input_type, condition=args.condition, adj_threshold=args.adj_threshold)
    test_set = ACERTA_FP(set_split='test', split=args.split, windowed=args.no_windowed, window_size=args.window_size, window_overlay=args.window_overlay,
                             input_type=args.input_type, condition=args.condition, adj_threshold=args.adj_threshold)
    
    train_loader = DataLoader(training_set, shuffle=True, drop_last=True,
                                batch_size=args.training_batch)
    test_loader = DataLoader(test_set, shuffle=False, drop_last=False,
                                batch_size=args.test_batch)
    

    nfeat = train_loader.__iter__().__next__()['input_anchor']['x'].shape[1]
    print("NFEAT: ",nfeat)
    print("Model: ",args.model)
    print("Scheduler: On") if args.no_scheduler else print("Scheduler: Off")
    if not args.no_windowed and args.input_type=='RST': print("Window: On")


    elif args.model == 'gcn_cheby':
        model = Siamese_GeoChebyConv(nfeat=nfeat,
                                     nhid=args.hidden,
                                     nclass=1,
                                     dropout=args.dropout)
        criterion = ContrastiveLoss(args.loss_margin)

    elif args.model == 'gcn_cheby_bce':
        model = Siamese_GeoChebyConv_Read(nfeat=nfeat,
コード例 #4
0
class GraphDataset(object):
    """Index-able wrapper around a graph dataset with a stateful batch cursor.

    Single items are read straight from the wrapped ``dataset``. Batches come
    from a ``DataLoader`` built over it. ``next_batch`` tracks how many
    examples the current epoch has handed out and restarts the loader once
    another full batch would run past ``num_examples``.

    Args:
        dataset: Sized, integer-indexable dataset (supports ``len`` and ``[]``).
        batch_size: Batch size for the internal loader.
        shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` flag.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.shuffle = shuffle
        self._dataset = dataset
        self._batch_size = batch_size
        self._dataloader = DataLoader(dataset, batch_size=self._batch_size, shuffle=self.shuffle)
        self._dataloader_iter = iter(self._dataloader)
        # Number of examples handed out by next_batch in the current epoch.
        self._index_in_epoch = 0

    def enforce_batch(self, batch_size):
        """Rebuild the loader with ``batch_size`` and start a fresh epoch.

        The current ``self.shuffle`` setting is used for the new loader.
        """
        self._batch_size = batch_size
        self._dataloader = DataLoader(self._dataset, batch_size=self._batch_size, shuffle=self.shuffle)
        self._dataloader_iter = iter(self._dataloader)
        # The new iterator starts from the beginning, so the cursor must too;
        # a stale cursor would make next_batch restart the epoch too early.
        self._index_in_epoch = 0

    @property
    def num_batches(self):
        """Number of full batches in one pass (a partial batch is not counted)."""
        return len(self._dataset) // self._batch_size

    @property
    def num_examples(self):
        """Number of examples in the wrapped dataset."""
        return len(self._dataset)

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        """Return example ``idx``; negative indices count from the end.

        Raises:
            TypeError: If ``idx`` is not an ``int``.
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if not isinstance(idx, int):
            raise TypeError('dataset indices must be integers, not ' + str(type(idx)))
        size = len(self)
        # ``>=``: index ``size`` is already one past the last valid element.
        if idx >= size or idx < -size:
            raise IndexError('dataset index out of range')
        if idx < 0:
            idx += size
        return self._dataset[idx]

    def __iter__(self):
        """Yield the examples in index order via ``dataset[i]``."""
        for i in range(len(self)):
            yield self._dataset[i]

    def batches(self):
        """Yield ``num_batches`` consecutive ``next_batch`` results."""
        for _ in range(self.num_batches):
            yield self.next_batch(self._batch_size, shuffle=self.shuffle)

    def next_batch(self, batch_size, shuffle=True):
        """Return the next ``(batch, batch.y)`` pair from the loader.

        If ``batch_size`` differs from the current batch size, ``self.shuffle``
        is set to ``shuffle`` and the loader is rebuilt (a fresh epoch).
        Otherwise ``shuffle`` is ignored. The loader is restarted whenever the
        cursor plus one batch would exceed ``num_examples``.

        Returns:
            tuple: ``(data, data.y)`` for the batch produced by the loader.
        """
        if batch_size != self._batch_size:
            self.shuffle = shuffle
            self.enforce_batch(batch_size)  # also resets the epoch cursor

        # Restart the epoch when the remaining examples cannot fill a batch.
        if self._index_in_epoch + self._batch_size > self.num_examples:
            self._dataloader_iter = iter(self._dataloader)
            self._index_in_epoch = 0
        try:
            data = next(self._dataloader_iter)
        except StopIteration:
            # Loader ran out out-of-step with the cursor: start a new epoch
            # and keep the cursor consistent with the fresh iterator.
            self._dataloader_iter = iter(self._dataloader)
            self._index_in_epoch = 0
            data = next(self._dataloader_iter)
        self._index_in_epoch += self._batch_size
        return data, data.y