# Shared imports for the snippets on this page. read_smiles_ring_data,
# read_smiles_multiclass, read_smiles_from_file, read_splits and get_loader
# are project-specific helpers that are not shown here.
import os

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train_dataloader(self):
    # REQUIRED
    train_dir = os.path.join(self.hparams.data, 'train')
    split_idx = 0
    raw_data = read_smiles_ring_data('%s/raw.csv' % self.hparams.data)
    data_splits = read_splits('%s/split_%d.txt' % (self.hparams.data, split_idx))
    train_dataset = get_loader(raw_data,
                               data_splits['train'],
                               self.hparams,
                               shuffle=True)

    if self.use_ddp:
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = None
    """
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=0,
                              sampler=train_sampler)
    """
    # Hand the sampler through and only shuffle when none is set; the
    # original rebuilt the loader without ever using train_sampler.
    train_loader = get_loader(raw_data,
                              data_splits['train'],
                              self.hparams,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler)

    return train_loader
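Every snippet on this page follows the same distributed-data-parallel recipe: build the dataset, wrap it in a DistributedSampler when DDP is active, and let the sampler take over shuffling. A minimal sketch of that recipe with a plain DataLoader; make_train_loader and the toy dataset are illustrative names, not part of the original code:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def make_train_loader(dataset, batch_size, num_workers, use_ddp):
    # DistributedSampler shards the dataset across processes and already
    # shuffles per epoch, so DataLoader's shuffle flag must be off
    # whenever a sampler is supplied (the two are mutually exclusive).
    sampler = DistributedSampler(dataset) if use_ddp else None
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=(sampler is None),
                      num_workers=num_workers,
                      sampler=sampler)

# Usage sketch: a toy tensor dataset stands in for the SMILES data.
toy = TensorDataset(torch.randn(8, 3))
loader = make_train_loader(toy, batch_size=2, num_workers=0, use_ddp=False)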
Example #2
    def val_dataloader(self):
        # OPTIONAL
        val_dir = os.path.join(self.hparams.data, 'valid')

        split_idx = 0
        if self.hparams.multi:
            raw_data = read_smiles_multiclass('%s/raw.csv' % self.hparams.data)
            n_classes = len(raw_data[0][1])
        else:
            raw_data = read_smiles_from_file('%s/raw.csv' % self.hparams.data)
            n_classes = 1

        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        val_dataset = get_loader(raw_data,
                                 data_splits['valid'],
                                 self.hparams,
                                 shuffle=False)

        if self.use_ddp:
            val_sampler = DistributedSampler(val_dataset)
        else:
            val_sampler = None

        val_loader = get_loader(raw_data,
                                data_splits['valid'],
                                self.hparams,
                                shuffle=False,
                                sampler=val_sampler,
                                batch_size=self.hparams.batch_size,
                                num_workers=self.hparams.n_workers)

        return val_loader
Example #3
    def test_dataloader(self):
        # OPTIONAL
        test_dir = os.path.join(self.hparams.data, 'test')

        split_idx = 0
        if self.hparams.multi:
            raw_data = read_smiles_multiclass('%s/raw.csv' % self.hparams.data)
            n_classes = len(raw_data[0][1])
        else:
            raw_data = read_smiles_from_file('%s/raw.csv' % self.hparams.data)
            n_classes = 1
        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        test_dataset = get_loader(raw_data,
                                  data_splits['test'],
                                  self.hparams,
                                  num_workers=0,
                                  shuffle=False)

        if self.use_ddp:
            test_sampler = DistributedSampler(test_dataset)
        else:
            test_sampler = None

        test_loader = get_loader(raw_data,
                                 data_splits['test'],
                                 self.hparams,
                                 num_workers=0,
                                 shuffle=(test_sampler is None),
                                 sampler=test_sampler)

        return test_loader
Example #4
    def test_dataloader(self):
        # OPTIONAL
        test_dir = os.path.join(self.hparams.data, 'test')
        
        split_idx = 0
        raw_data = read_smiles_ring_data('%s/raw.csv' % self.hparams.data)
        data_splits = read_splits('%s/split_%d.txt' % (self.hparams.data, split_idx))

        test_dataset = get_loader(raw_data,
                                  data_splits['test'],
                                  self.hparams,
                                  shuffle=False)  # test data should not be shuffled
        
        if self.use_ddp:
            test_sampler = DistributedSampler(test_dataset)
        else:
            test_sampler = None

        test_loader = DataLoader(dataset=test_dataset,
                                 batch_size=self.hparams.batch_size,
                                 shuffle=(test_sampler is None),
                                 num_workers=0,
                                 sampler=test_sampler)
        
        return test_loader
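get_loader is a project-specific helper that never appears on this page. Judging only from the call sites above, it takes the raw records, a list of split indices, the hparams namespace, and optional DataLoader-style keyword arguments. The sketch below is a guess made purely to keep the call sites readable; MolDataset and every default value are assumptions, not the project's actual code:

from torch.utils.data import DataLoader, Dataset

class MolDataset(Dataset):
    # Hypothetical: wraps the records selected by one split.
    def __init__(self, raw_data, indices):
        self.records = [raw_data[i] for i in indices]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]

def get_loader(raw_data, split, hparams, shuffle=False, sampler=None,
               batch_size=None, num_workers=0):
    # Assumed behaviour: select the split's records and hand them to a
    # standard DataLoader, letting an explicit sampler override shuffle.
    dataset = MolDataset(raw_data, split)
    return DataLoader(dataset,
                      batch_size=batch_size or hparams.batch_size,
                      shuffle=shuffle and sampler is None,
                      num_workers=num_workers,
                      sampler=sampler)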
Example #5
    def train_dataloader(self):
        split_idx = self.split_idx
        # REQUIRED
        train_dir = os.path.join(self.hparams.data, 'train')
        if self.hparams.multi:
            raw_data = read_smiles_multiclass('%s/raw.csv' % self.hparams.data)
            n_classes = len(raw_data[0][1])
            print(f'N_Classes: {n_classes}')
        else:
            raw_data = read_smiles_from_file('%s/raw.csv' % self.hparams.data)
            n_classes = 1
        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))
        train_dataset = get_loader(raw_data,
                                   data_splits['train'],
                                   self.hparams,
                                   shuffle=True)

        if self.use_ddp:
            train_sampler = DistributedSampler(train_dataset)
        else:
            train_sampler = None
        """
        train_loader = DataLoader(dataset=train_dataset, 
                                  batch_size=self.hparams.batch_size,
                                  shuffle=(train_sampler is None),
                                  num_workers=0,
                                  sampler=train_sampler)
        """
        # shuffle and a sampler are mutually exclusive, so only shuffle
        # when no DistributedSampler is in play.
        train_loader = get_loader(raw_data,
                                  data_splits['train'],
                                  self.hparams,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=self.hparams.batch_size,
                                  num_workers=self.hparams.n_workers)

        return train_loader
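All of these hooks live on a pytorch_lightning.LightningModule, and the Trainer calls them automatically during fit. A minimal, hypothetical harness showing where such a hook plugs in; the toy model and data stand in for the SMILES pipeline, and Trainer flags vary between Lightning versions:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).pow(2).mean()  # dummy loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Same hook as in the examples above, fed by a toy dataset.
        return DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2)

# enable_checkpointing requires Lightning >= 1.5; older releases spell
# these flags differently.
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(ToyModule())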