コード例 #1
0
    def test_dataloader(self):
        """Build the loader for the 'test' entry of split file 0.

        Reads ``raw.csv`` and ``split_0.txt`` from the directory named by
        ``self.hparams.data``, passes the 'test' split to ``get_loader`` and
        wraps the result in a ``DataLoader`` with batch size
        ``self.hparams.batch_size``.

        Returns:
            The ``DataLoader`` for the test split. A ``DistributedSampler``
            is attached when ``self.use_ddp`` is truthy.
        """
        # OPTIONAL
        split_idx = 0
        raw_data = read_smiles_multiclass('%s/raw.csv' % self.hparams.data)
        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        # Evaluation data is not shuffled so that test runs see a fixed order.
        # NOTE(review): the get_loader result is used as the ``dataset`` of a
        # DataLoader below -- confirm it returns a dataset and not an
        # already-batched loader.
        test_dataset = get_loader(raw_data,
                                  data_splits['test'],
                                  self.hparams,
                                  shuffle=False)

        test_sampler = DistributedSampler(test_dataset) if self.use_ddp else None

        test_loader = DataLoader(dataset=test_dataset,
                                 batch_size=self.hparams.batch_size,
                                 shuffle=False,
                                 num_workers=0,
                                 sampler=test_sampler)

        return test_loader
コード例 #2
0
    def train_dataloader(self):
        """Build the shuffled training loader for split file 0.

        Reads ``raw.csv`` and ``split_0.txt`` from the directory named by
        ``self.hparams.data`` and hands the 'train' split to ``get_loader``.

        Returns:
            Whatever ``get_loader`` returns for the 'train' split with
            ``shuffle=True``.
        """
        # REQUIRED
        split_idx = 0
        raw_data = read_smiles_multiclass('%s/raw.csv' % self.hparams.data)
        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        # NOTE(review): no DistributedSampler is passed to get_loader, so when
        # self.use_ddp is set every process presumably iterates the full
        # training split -- confirm whether get_loader should take a sampler.
        train_loader = get_loader(raw_data,
                                  data_splits['train'],
                                  self.hparams,
                                  shuffle=True)

        return train_loader
コード例 #3
0
    def test_dataloader(self):
        """Build the loader for the 'test' entry of split file 0.

        ``raw.csv`` in ``self.hparams.data`` is parsed with
        ``read_smiles_multiclass`` when ``self.hparams.multi`` is truthy and
        with ``read_smiles_from_file`` otherwise. The 'test' split from
        ``split_0.txt`` is then handed to ``get_loader``.

        Returns:
            The ``get_loader`` result for the test split, unshuffled, with
            ``num_workers=0``. A ``DistributedSampler`` is passed along when
            ``self.use_ddp`` is truthy, otherwise ``sampler`` is ``None``.
        """
        # OPTIONAL
        split_idx = 0
        csv_path = '%s/raw.csv' % self.hparams.data
        if self.hparams.multi:
            raw_data = read_smiles_multiclass(csv_path)
        else:
            raw_data = read_smiles_from_file(csv_path)
        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        test_sampler = None
        if self.use_ddp:
            # The sampler is constructed from a separate get_loader result;
            # build it only under DDP so the split is not prepared twice
            # otherwise.
            # NOTE(review): confirm the get_loader result is a suitable source
            # for DistributedSampler (one element per sample).
            test_dataset = get_loader(raw_data,
                                      data_splits['test'],
                                      self.hparams,
                                      num_workers=0,
                                      shuffle=False)
            test_sampler = DistributedSampler(test_dataset)

        # Evaluation data is not shuffled so that test runs see a fixed order.
        test_loader = get_loader(raw_data,
                                 data_splits['test'],
                                 self.hparams,
                                 num_workers=0,
                                 shuffle=False,
                                 sampler=test_sampler)

        return test_loader
コード例 #4
0
    def val_dataloader(self):
        """Build the loader for the 'valid' entry of split file 0.

        ``raw.csv`` in ``self.hparams.data`` is parsed with
        ``read_smiles_multiclass`` when ``self.hparams.multi`` is truthy and
        with ``read_smiles_from_file`` otherwise. The 'valid' split from
        ``split_0.txt`` is then handed to ``get_loader``.

        Returns:
            The ``get_loader`` result for the validation split, unshuffled,
            using ``self.hparams.batch_size`` and ``self.hparams.n_workers``.
            A ``DistributedSampler`` is passed along when ``self.use_ddp`` is
            truthy, otherwise ``sampler`` is ``None``.
        """
        # OPTIONAL
        split_idx = 0
        csv_path = '%s/raw.csv' % self.hparams.data
        if self.hparams.multi:
            raw_data = read_smiles_multiclass(csv_path)
        else:
            raw_data = read_smiles_from_file(csv_path)

        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        val_sampler = None
        if self.use_ddp:
            # The sampler is constructed from a separate get_loader result;
            # build it only under DDP so the split is not prepared twice
            # otherwise.
            # NOTE(review): confirm the get_loader result is a suitable source
            # for DistributedSampler (one element per sample).
            val_dataset = get_loader(raw_data,
                                     data_splits['valid'],
                                     self.hparams,
                                     shuffle=True,
                                     batch_size=self.hparams.batch_size,
                                     num_workers=self.hparams.n_workers)
            val_sampler = DistributedSampler(val_dataset)

        # Batch size and worker count are given to the loader that is actually
        # returned, so it does not fall back to get_loader's defaults.
        val_loader = get_loader(raw_data,
                                data_splits['valid'],
                                self.hparams,
                                shuffle=False,
                                sampler=val_sampler,
                                batch_size=self.hparams.batch_size,
                                num_workers=self.hparams.n_workers)

        return val_loader
コード例 #5
0
ファイル: train_prop.py プロジェクト: giblets2570/tcvaemolgen
def main(hparams):
    """Fit a ``PropPredictor`` once per split round, then run the test pass.

    The number of classes is taken from ``<hparams.data>/raw.csv`` as the
    length of the second field of the first record. Training is logged
    through a ``CometLogger`` configured from the ``COMET_KEY`` and
    ``COMET_WKSP`` environment variables; indexing ``os.environ`` raises
    ``KeyError`` if either is missing.

    Args:
        hparams: Run configuration. This function reads ``data``,
            ``dataset``, ``distributed_backend``, ``max_nb_epochs``,
            ``gpus``, ``nodes`` and ``n_rounds`` from it, and also passes the
            whole object to ``PropPredictor`` and ``load_shortest_paths``.
    """
    # init module
    records = read_smiles_multiclass('%s/raw.csv' % hparams.data)
    model = PropPredictor(hparams, n_classes=len(records[0][1]))
    model.version = hparams.dataset

    load_shortest_paths(hparams)

    logger_options = dict(
        api_key=os.environ["COMET_KEY"],
        experiment_name=f'{hparams.dataset}-{str(time.time())}',
        log_graph=True,
        project_name="tcvaemolgen",
        workspace=os.environ["COMET_WKSP"],
    )
    comet_logger = CometLogger(**logger_options)

    # most basic trainer, uses good defaults
    trainer_options = dict(
        check_val_every_n_epoch=1,
        default_save_path=f'data/05_model_outputs/{hparams.dataset}',
        distributed_backend=hparams.distributed_backend,
        max_nb_epochs=hparams.max_nb_epochs,
        early_stop_callback=None,
        gpus=hparams.gpus,
        gradient_clip_val=10,
        nb_gpu_nodes=hparams.nodes,
        logger=comet_logger,
        log_save_interval=100,
        row_log_interval=10,
        show_progress_bar=True,
        track_grad_norm=2,
    )
    trainer = Trainer(**trainer_options)

    # The same model and trainer are reused for every round; only the split
    # index stored on the model changes between fits.
    for round_idx in range(hparams.n_rounds):
        model.split_idx = round_idx
        log.info(f'Split {round_idx}')
        trainer.fit(model)

    trainer.test()
コード例 #6
0
    def train_dataloader(self):
        """Build the training loader for the split selected by ``self.split_idx``.

        ``raw.csv`` in ``self.hparams.data`` is parsed with
        ``read_smiles_multiclass`` when ``self.hparams.multi`` is truthy (the
        class count is then printed) and with ``read_smiles_from_file``
        otherwise. The 'train' entry of ``split_<split_idx>.txt`` is handed to
        ``get_loader``.

        Returns:
            The ``get_loader`` result for the training split, using
            ``self.hparams.batch_size`` and ``self.hparams.n_workers``. When
            ``self.use_ddp`` is truthy a ``DistributedSampler`` is passed and
            ``shuffle`` is ``False``; otherwise ``shuffle`` is ``True`` and
            ``sampler`` is ``None``.
        """
        split_idx = self.split_idx
        # REQUIRED
        csv_path = '%s/raw.csv' % self.hparams.data
        if self.hparams.multi:
            raw_data = read_smiles_multiclass(csv_path)
            n_classes = len(raw_data[0][1])
            print(f'N_Classes: {n_classes}')
        else:
            raw_data = read_smiles_from_file(csv_path)
        data_splits = read_splits('%s/split_%d.txt' %
                                  (self.hparams.data, split_idx))

        train_sampler = None
        if self.use_ddp:
            # The sampler is constructed from a separate get_loader result;
            # build it only under DDP so the split is not prepared twice
            # otherwise.
            # NOTE(review): confirm the get_loader result is a suitable source
            # for DistributedSampler (one element per sample).
            train_dataset = get_loader(raw_data,
                                       data_splits['train'],
                                       self.hparams,
                                       shuffle=True)
            train_sampler = DistributedSampler(train_dataset)

        # torch's DataLoader raises ValueError when shuffle=True is combined
        # with an explicit sampler, so shuffling is requested only when no
        # sampler is in use (presumably get_loader forwards both options).
        train_loader = get_loader(raw_data,
                                  data_splits['train'],
                                  self.hparams,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=self.hparams.batch_size,
                                  num_workers=self.hparams.n_workers)

        return train_loader