Exemplo n.º 1
0
    def load_mca(self, model_path: str):
        """
        Restores a pretrained MCA model from disk.

        Arguments:
            model_path {String} -- Path to the model root folder (must
                contain model_params.json, smiles_language.pkl and a
                weights subfolder).
        """

        self.model_path = model_path

        # Read the stored hyperparameters.
        params_file = os.path.join(model_path, 'model_params.json')
        with open(params_file) as fp:
            model_params = json.load(fp)

        # Restore the SMILES language and build the input transforms.
        language_file = os.path.join(model_path, 'smiles_language.pkl')
        self.smiles_language = SMILESLanguage.load(language_file)
        self.transforms = self.compose_smiles_transforms(model_params)

        # Instantiate the architecture, restore the best checkpoint and
        # switch to inference mode.
        self.model = MODEL_FACTORY['mca'](model_params)
        weights_file = os.path.join(
            model_path, 'weights', 'best_ROC-AUC_mca.pt')
        self.model.load(weights_file, map_location=get_device())
        self.model.eval()
Exemplo n.º 2
0
    def __init__(self, generator, encoder, params, model_name, logger):
        """
        Constructor for the Reinforcement object.

        Args:
            generator (nn.Module): SMILES generator object.
            encoder (nn.Module): An encoder object.
            params (dict): dict with hyperparameter.
            model_name (str): name of the model.
            logger: a logger.

        Returns:
            object of type REINFORCE used for biasing the properties
            estimated by the predictor of trajectories produced by the
            generator to maximize the custom reward function get_reward.
        """

        super(Reinforce, self).__init__()

        # Both pretrained components start in inference mode.
        self.generator = generator
        self.generator.eval()
        self.encoder = encoder
        self.encoder.eval()

        self.logger = logger
        self.device = get_device()

        # Only the generator's decoder parameters are optimized during RL.
        self.optimizer = torch.optim.Adam(
            list(self.generator.decoder.parameters()),
            lr=params.get('learning_rate', 0.0001),
            eps=params.get('eps', 0.0001),
            weight_decay=params.get('weight_decay', 0.00001))

        self.model_name = model_name
        model_folder = params.get('model_folder', 'biased_models')
        self.model_path = os.path.join(model_folder, model_name)
        self.weights_path = os.path.join(self.model_path, 'weights/{}')

        self.smiles_to_tensor = ToTensor(self.device)

        # If model does not yet exist, create it.
        if not os.path.isdir(self.model_path):
            # First run: lay out the folder structure, persist the
            # untrained models and the hyperparameters.
            for subfolder in ('weights', 'results'):
                os.makedirs(
                    os.path.join(self.model_path, subfolder), exist_ok=True)
            # Save untrained models
            self.save('generator_epoch_0.pt', 'encoder_epoch_0.pt')

            params_file = os.path.join(self.model_path, 'model_params.json')
            with open(params_file, 'w') as fp:
                json.dump(params, fp)
        else:
            self.logger.warning(
                'Model exists already. Call model.load() to restore weights.')
Exemplo n.º 3
0
    def __init__(self, params, *args, **kwargs):
        """Constructor.
        Args:
            params (dict): A dictionary containing the parameters to build the
                dense Decoder.
        Items in params:
            stacked_dense_hidden_sizes (list[int], optional): Number of
                neurons in the stacked hidden layers. Defaults to
                [num_drug_features, 5000, 1000, 500].
            num_drug_features (int, optional): Number of features for molecule.
                Defaults to 512 (bits fingerprint).
            num_tasks (int, optional): Number of prediction tasks, i.e. output
                units of the final layer. Defaults to 12.
            activation_fn (string, optional): Activation function used in all
                layers for specification in ACTIVATION_FN_FACTORY.
                Defaults to 'relu'.
            batch_norm (bool, optional): Whether batch normalization is
                applied. Defaults to True.
            dropout (float, optional): Dropout probability in all
                except parametric layer. Defaults to 0.0.
            ensemble (string, optional): Type of the final EnsembleLayer.
                Defaults to 'score'.
            ensemble_size (int, optional): Number of parallel members in the
                final EnsembleLayer. Defaults to 5.
            loss_fn (string, optional): Loss function key for specification
                in LOSS_FN_FACTORY. Defaults to
                'binary_cross_entropy_ignore_nan_and_sum'.
            *args, **kwargs: positional and keyword arguments are ignored.
        """
        super(Dense, self).__init__(*args, **kwargs)

        self.device = get_device()
        self.params = params
        self.num_drug_features = params.get('num_drug_features', 512)
        self.num_tasks = params.get('num_tasks', 12)
        self.hidden_sizes = params.get(
            'stacked_dense_hidden_sizes',
            [self.num_drug_features, 5000, 1000, 500]
        )
        self.dropout = params.get('dropout', 0.0)
        self.act_fn = ACTIVATION_FN_FACTORY[
            params.get('activation_fn', 'relu')]
        # Stack of dense blocks mapping between consecutive hidden sizes.
        self.dense_layers = nn.ModuleList(
            [
                dense_layer(
                    self.hidden_sizes[ind],
                    self.hidden_sizes[ind + 1],
                    act_fn=self.act_fn,
                    dropout=self.dropout,
                    batch_norm=self.params.get('batch_norm', True)
                ).to(self.device) for ind in range(len(self.hidden_sizes) - 1)
            ]
        )

        # Sigmoid-activated ensemble head producing one score per task.
        self.final_dense = EnsembleLayer(
            typ=params.get('ensemble', 'score'),
            input_size=self.hidden_sizes[-1],
            output_size=self.num_tasks,
            ensemble_size=params.get('ensemble_size', 5),
            fn=ACTIVATION_FN_FACTORY['sigmoid']
        ).to(self.device)
        self.loss_fn = LOSS_FN_FACTORY[
            params.get('loss_fn', 'binary_cross_entropy_ignore_nan_and_sum')]
Exemplo n.º 4
0
def main(train_affinity_filepath, test_affinity_filepath, receptor_filepath,
         ligand_filepath, model_path, params_filepath, training_name,
         smiles_language_filepath):
    """Train a bimodal ligand/receptor drug-affinity model.

    Args:
        train_affinity_filepath (str): Path to the training affinity labels.
        test_affinity_filepath (str): Path to the test affinity labels.
        receptor_filepath (str): Path to the receptor (protein) sequences.
        ligand_filepath (str): Path to the ligand SMILES (.smi) file.
        model_path (str): Root folder in which the model directory is created.
        params_filepath (str): Path to a JSON file of hyperparameters.
        training_name (str): Name of the run; also the model subfolder name.
        smiles_language_filepath (str): Path to a pretrained SMILES tokenizer.
            If '', the tokenizer bundled next to ``metadata`` is used.
    """

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")
    device = get_device()

    # Load languages
    if smiles_language_filepath == '':
        # Fall back to the SMILES language shipped with the package.
        smiles_language_filepath = os.path.join(
            os.sep,
            *metadata.__file__.split(os.sep)[:-1], 'smiles_language')
    smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath)
    smiles_language.set_encoding_transforms(
        randomize=None,
        add_start_and_stop=params.get('ligand_start_stop_token', True),
        padding=params.get('ligand_padding', True),
        # NOTE(review): a boolean default for a padding *length* looks odd;
        # presumably an int (or None) is expected here -- confirm upstream.
        padding_length=params.get('ligand_padding_length', True),
        device=device,
    )
    smiles_language.set_smiles_transforms(
        augment=params.get('augment_smiles', False),
        canonical=params.get('smiles_canonical', False),
        kekulize=params.get('smiles_kekulize', False),
        all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        remove_bonddir=params.get('smiles_remove_bonddir', False),
        remove_chirality=params.get('smiles_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', False))

    if params.get('receptor_embedding', 'learned') == 'predefined':
        protein_language = ProteinFeatureLanguage(
            features=params.get('predefined_embedding', 'blosum'))
    else:
        protein_language = ProteinLanguage()

    if params.get('ligand_embedding', 'learned') == 'one_hot':
        logger.warning(
            'ligand_embedding_size parameter in param file is ignored in '
            'one_hot embedding setting, ligand_vocabulary_size used instead.')
    if params.get('receptor_embedding', 'learned') == 'one_hot':
        logger.warning(
            'receptor_embedding_size parameter in param file is ignored in '
            'one_hot embedding setting, receptor_vocabulary_size used instead.'
        )

    # Assemble datasets
    train_dataset = DrugAffinityDataset(
        drug_affinity_filepath=train_affinity_filepath,
        smi_filepath=ligand_filepath,
        protein_filepath=receptor_filepath,
        protein_language=protein_language,
        smiles_language=smiles_language,
        smiles_padding=params.get('ligand_padding', True),
        smiles_padding_length=params.get('ligand_padding_length', None),
        smiles_add_start_and_stop=params.get('ligand_add_start_stop', True),
        smiles_augment=params.get('augment_smiles', False),
        smiles_canonical=params.get('smiles_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('receptor_padding', True),
        protein_padding_length=params.get('receptor_padding_length', None),
        protein_add_start_and_stop=params.get('receptor_add_start_stop', True),
        protein_augment_by_revert=params.get('protein_augment', False),
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager',
        iterate_dataset=params.get('iterate_dataset', False),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0),
    )

    # Test set mirrors the training set, but with augmentation disabled.
    test_dataset = DrugAffinityDataset(
        drug_affinity_filepath=test_affinity_filepath,
        smi_filepath=ligand_filepath,
        protein_filepath=receptor_filepath,
        protein_language=protein_language,
        smiles_language=smiles_language,
        smiles_padding=params.get('ligand_padding', True),
        smiles_padding_length=params.get('ligand_padding_length', None),
        smiles_add_start_and_stop=params.get('ligand_add_start_stop', True),
        smiles_augment=False,
        smiles_canonical=params.get('smiles_test_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('receptor_padding', True),
        protein_padding_length=params.get('receptor_padding_length', None),
        protein_add_start_and_stop=params.get('receptor_add_start_stop', True),
        protein_augment_by_revert=False,
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager',
        iterate_dataset=params.get('iterate_dataset', False),
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0),
    )
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.')

    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    params.update({
        'ligand_vocabulary_size':
        (train_dataset.smiles_dataset.smiles_language.number_of_tokens),
        'receptor_vocabulary_size':
        protein_language.number_of_tokens,
    })
    logger.info(
        f'Receptor vocabulary size is {protein_language.number_of_tokens} and '
        f'ligand vocabulary size is {train_dataset.smiles_dataset.smiles_language.number_of_tokens}'
    )
    model_fn = params.get('model_fn', 'bimodal_mca')
    model = MODEL_FACTORY[model_fn](params).to(device)
    model._associate_language(smiles_language)
    model._associate_language(protein_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mca.pt')):
        logger.info('Found existing model, restoring now...')
        try:
            model.load(os.path.join(model_dir, 'weights', 'best_mca.pt'))

            with open(os.path.join(model_dir, 'results', 'mse.json'),
                      'r') as f:
                info = json.load(f)

                # Values were serialized with str(); cast back to float so
                # the numeric comparisons in the training loop cannot raise
                # a TypeError (float vs. str).
                max_roc_auc = float(info['best_roc_auc'])
                min_loss = float(info['test_loss'])

        except Exception:
            min_loss, max_roc_auc = 100, 0
    else:
        min_loss, max_roc_auc = 100, 0

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer',
                                             'adam')](model.parameters(),
                                                      lr=params.get(
                                                          'lr', 0.001))
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters: {num_params}')
    logger.info(f'Model: {model}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()

    model.save(save_top_model.format('epoch', '0', model_fn))

    # Safe defaults for the final summary, in case a metric never improves
    # during training (otherwise the variables below would be unbound).
    ep_loss, ep_roc = -1, -1
    loss_roc_auc, roc_auc_loss = float('nan'), float('nan')

    for epoch in range(params['epochs']):

        model.train()
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (ligands, receptors, y) in enumerate(train_loader):
            if ind % 100 == 0:
                logger.info(f'Batch {ind}/{len(train_loader)}')
            # NOTE(review): inputs are not moved to device here (the test
            # loop does) -- presumably the dataset already places tensors on
            # `device`; confirm before changing.
            y_hat, pred_dict = model(ligands, receptors)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info("\t **** TRAINING ****   "
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f"loss: {train_loss / len(train_loader):.5f}. "
                    f"This took {time() - t:.1f} secs.")
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (ligands, receptors, y) in enumerate(test_loader):
                y_hat, pred_dict = model(ligands.to(device),
                                         receptors.to(device))
                predictions.append(y_hat)
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Average the accumulated loss exactly once. (The original code
        # divided by len(test_loader) a second time before logging, which
        # under-reported the loss used for checkpoint selection.)
        test_loss = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        avg_precision = average_precision_score(labels, predictions)

        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss:.5f}, ROC-AUC: {test_roc_auc:.3f}, "
            f"Average precision: {avg_precision:.3f}.")

        def save(path, metric, typ, val=None):
            """Persist model weights, metric info and predictions."""
            model.save(path.format(typ, metric, model_fn))
            info = {
                'best_roc_auc': str(max_roc_auc),
                'test_loss': str(min_loss),
            }
            with open(os.path.join(model_dir, 'results', metric + '.json'),
                      'w') as f:
                json.dump(info, f)
            np.save(
                os.path.join(model_dir, 'results', metric + '_preds.npy'),
                np.vstack([predictions, labels]),
            )
            if typ == 'best':
                logger.info(f'\t New best performance in "{metric}"'
                            f' with value : {val:.7f} in epoch: {epoch}')

        if test_roc_auc > max_roc_auc:
            max_roc_auc = test_roc_auc
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            ep_roc = epoch
            roc_auc_loss = test_loss

        if test_loss < min_loss:
            min_loss = test_loss
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
            loss_roc_auc = test_roc_auc
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))
    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} '
                f'\t (ROC-AUC was {loss_roc_auc:4f}) \n \t'
                f'ROC-AUC = {max_roc_auc:.4f} in epoch {ep_roc} '
                f'\t (Loss was {roc_auc_loss:4f})')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
Exemplo n.º 5
0
def main(
    train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
    smi_filepath, gene_filepath, smiles_language_filepath, model_path,
    params_filepath, training_name
):
    """Train a drug-sensitivity (IC50-style) prediction model.

    Args:
        train_sensitivity_filepath (str): Path to the training labels.
        test_sensitivity_filepath (str): Path to the test labels.
        gep_filepath (str): Path to the gene expression data.
        smi_filepath (str): Path to the SMILES (.smi) file.
        gene_filepath (str): Path to a pickled gene list.
        smiles_language_filepath (str): Path to a pickled SMILES language.
        model_path (str): Root folder in which the model directory is created.
        params_filepath (str): Path to a JSON file of hyperparameters.
        training_name (str): Name of the run; also the model subfolder name.
    """

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        drug_sensitivity_processing_parameters=params.get(
            'drug_sensitivity_processing_parameters', {}
        ),
        augment=params.get('augment_smiles', True),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        gene_expression_standardize=params.get(
            'gene_expression_standardize', True
        ),
        gene_expression_min_max=params.get('gene_expression_min_max', False),
        gene_expression_processing_parameters=params.get(
            'gene_expression_processing_parameters', {}
        ),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    # Test set reuses the train preprocessing parameters so both splits are
    # scaled identically; augmentation is off by default.
    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        drug_sensitivity_processing_parameters=params.get(
            'drug_sensitivity_processing_parameters',
            train_dataset.drug_sensitivity_processing_parameters
        ),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        gene_expression_standardize=params.get(
            'gene_expression_standardize', True
        ),
        gene_expression_min_max=params.get('gene_expression_min_max', False),
        gene_expression_processing_parameters=params.get(
            'gene_expression_processing_parameters',
            train_dataset.gene_expression_dataset.processing
        ),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.'
    )

    device = get_device()
    logger.info(
        f'Device for data loader is {train_dataset.device} and for '
        f'model is {device}'
    )
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    params.update({  # yapf: disable
        'number_of_genes': len(gene_list),  # yapf: disable
        'smiles_vocabulary_size': smiles_language.number_of_tokens,
        'drug_sensitivity_processing_parameters':
            train_dataset.drug_sensitivity_processing_parameters,
        'gene_expression_processing_parameters':
            train_dataset.gene_expression_dataset.processing
    })

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    model._associate_language(smiles_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mse_mca.pt')):
        logger.info('Found existing model, restoring now...')
        # Load the same checkpoint whose existence was just checked.
        # (The original loaded 'mse_best_mca.pt', which never exists under
        # the '{typ}_{metric}_{model_fn}' naming and made restoring crash.)
        model.load(os.path.join(model_dir, 'weights', 'best_mse_mca.pt'))

        with open(os.path.join(model_dir, 'results', 'mse.json'), 'r') as f:
            info = json.load(f)

            # Values were serialized with str(); cast back to float so the
            # numeric comparisons in the training loop cannot raise a
            # TypeError (float vs. str).
            min_rmse = float(info['best_rmse'])
            max_pearson = float(info['best_pearson'])
            min_loss = float(info['test_loss'])

    else:
        min_loss, min_rmse, max_pearson = 100, 1000, 0

    # Define optimizer
    optimizer = (
        OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')]
        (model.parameters(), lr=params.get('lr', 0.01))
    )

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()

    model.save(
        save_top_model.format('epoch', '0', params.get('model_fn', 'mca'))
    )

    # Safe defaults for the final summary, in case a metric never improves
    # during training (otherwise the variables below would be unbound).
    ep_loss, ep_pearson = -1, -1
    min_loss_pearson, max_pearson_loss = float('nan'), float('nan')

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(
                torch.squeeze(smiles.to(device)), gep.to(device)
            )
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            "\t **** TRAINING ****   "
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {train_loss / len(train_loader):.5f}. "
            f"This took {time() - t:.1f} secs."
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(
                    torch.squeeze(smiles.to(device)), gep.to(device)
                )
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds]
        )
        labels = np.array([l.cpu() for label in labels for l in label])
        test_pearson_a = pearsonr(
            torch.Tensor(predictions), torch.Tensor(labels)
        )
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}"
        )

        def save(path, metric, typ, val=None):
            """Persist model weights, metric info and predictions."""
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            with open(
                os.path.join(model_dir, 'results', metric + '.json'), 'w'
            ) as f:
                json.dump(info, f)
            np.save(
                os.path.join(model_dir, 'results', metric + '_preds.npy'),
                np.vstack([predictions, labels])
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in "{metric}"'
                    f' with value : {val:.7f} in epoch: {epoch}'
                )

        def update_info():
            """Build the metric snapshot that is dumped alongside weights."""
            return {
                'best_rmse': str(min_rmse),
                'best_pearson': str(float(max_pearson)),
                'test_loss': str(min_loss),
                'predictions': [float(p) for p in predictions]
            }

        if test_loss_a < min_loss:
            min_rmse = test_rmse_a
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            info = update_info()
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            info = update_info()
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            # Refresh info so the periodic dump is defined and up to date
            # (the original could dump a stale or even unbound `info`).
            info = update_info()
            save(save_top_model, 'epoch', str(epoch))
    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
        f'\t (Pearson was {min_loss_pearson:4f}) \n \t'
        f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
        f'\t (Loss was {max_pearson_loss:2f})'
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
Exemplo n.º 6
0
 def __init__(self):
     """Initialize the instance and cache the torch device to run on."""
     # get_device() is a project helper; presumably returns CUDA when
     # available, else CPU -- confirm against its definition.
     self.device = get_device()
Exemplo n.º 7
0
def main(
    train_scores_filepath: str, test_scores_filepath: str, smi_filepath: str,
    model_path: str, params_filepath: str, training_name: str
):
    """Finetune a pretrained MCA-type toxicity model on new annotated SMILES.

    Restores model parameters, SMILES language and weights from
    ``model_path``, adapts the network via ``update_mca_model`` according to
    the hyperparameters in ``params_filepath``, then trains and checkpoints
    the best ROC-AUC / average-precision / loss models under
    ``<model_path>/finetuned``.
    """

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()
    device = get_device()

    # Restore pretrained model
    logger.info('Start model restoring.')
    try:
        with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
            params = json.load(fp)
        smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
        # Try weight restoring
        try:
            weight_path = os.path.join(
                model_path, 'weights',
                params.get('weights_name', 'best_ROC-AUC_mca.pt')
            )
            model.load(weight_path)
        except Exception:
            # Fallback: load whichever file happens to be first in weights/.
            try:
                wp = os.listdir(os.path.join(model_path, 'weights'))[0]
                logger.info(
                    f"Weights {weight_path} not found. Try restore {wp}"
                )
                model.load(os.path.join(model_path, 'weights', wp))
            except Exception:
                raise Exception('Error in weight loading.')

    except Exception:
        # NOTE(review): this broad catch discards the original traceback;
        # `raise ... from err` would preserve the root cause.
        raise Exception(
            'Error in model restoring. model_path should point to the model root '
            'folder that contains a model_params.json, a smiles_language.pkl and a '
            'weights folder.'
        )
    logger.info('Model restored. Now starting to craft it for the task')

    # Process parameter file:
    model_params = {}
    with open(params_filepath) as fp:
        model_params.update(json.load(fp))

    # Freeze/extend layers per the finetuning hyperparameters.
    model = update_mca_model(model, model_params)

    logger.info('Model set up.')
    for idx, (name, param) in enumerate(model.named_parameters()):
        logger.info(
            (idx, name, param.shape, f'Gradients: {param.requires_grad}')
        )

    # All finetuning artifacts go into a subfolder of the pretrained model.
    ft_model_path = os.path.join(model_path, 'finetuned')
    os.makedirs(os.path.join(ft_model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(ft_model_path, 'results'), exist_ok=True)

    logger.info('Now start data preprocessing...')

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )

    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )

    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(ft_model_path, 'smiles_language.pkl'))

    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    # Template filled by save() as weights/<typ>_<metric>_<model_fn>.pt
    save_top_model = os.path.join(ft_model_path, 'weights/{}_{}_{}.pt')

    num_t_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    num_params = sum(p.numel() for p in model.parameters())
    params.update({'params': num_params, 'trainable_params': num_t_params})
    logger.info(
        f'Number of parameters: {num_params} (trainable {num_t_params}).'
    )

    # Define optimizer only for those layers which require gradients
    optimizer = (
        OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=params.get('lr', 0.00001)
        )
    )
    # Dump params.json file with updated parameters.
    with open(os.path.join(ft_model_path, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):

            smiles = torch.squeeze(smiles.to(device))
            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING ****   '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):

                smiles = torch.squeeze(smiles.to(device))

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, , roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )
        info = {
            'test_auc': test_roc_auc_a,
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            # NOTE(review): 'test_auc' is a duplicate key (same value as
            # above); the later occurrence silently wins — drop one of them.
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            """Checkpoint weights; log when a new best metric is reached."""
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(
                    f'\t New best performance in {metric}'
                    f' with value : {val:.7f} in epoch: {epoch+1}'
                )

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(ft_model_path, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(ft_model_path, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    # NOTE(review): `ep_loss` (and `save`) are only bound inside the epoch
    # loop — with params['epochs'] == 0 the lines below raise NameError.
    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
Exemplo n.º 8
0
def update_mca_model(model: MCAMultiTask, params: dict) -> MCAMultiTask:
    """
    Receives a pretrained model (instance of MCAMultiTask), modifies it and
    returns the updated object.

    Args:
        model (MCAMultiTask): Pretrained model to be modified.
        params (dict): Hyperparameter file for the modifications. Needs to include:
            - number_of_tunable_layers (how many layers should not be frozen. If
                number exceeds number of existing layers, all layers are tuned.)
            - fresh_dense_sizes (a list of fresh dense layers to be plugged in at
                the end).
            - num_tasks (number of classification tasks being performed).

    Returns:
        MCAMultiTask: Modified model for finetuning.
    """

    if not isinstance(model, MCAMultiTask):
        raise TypeError(
            f'Wrong model type, was {type(model)}, not MCAMultiTask.'
        )

    # Freeze the correct layers and add new ones.
    # Not strictly speaking all layers, but all param matrices, gradient-req or not.
    num_layers = sum(1 for _ in model.parameters())
    num_to_tune = params['number_of_tunable_layers']
    if num_to_tune > num_layers:
        logger.warning(
            f'Model has {num_layers} tunable layers. Given # is larger: {num_to_tune}.'
        )
        num_to_tune = num_layers
    # BUGFIX: copy the list so we do not mutate the caller's params dict
    # (the original code did `fresh_sizes.insert(0, ...)` on the very list
    # stored in params['fresh_dense_sizes']).
    fresh_sizes = list(params['fresh_dense_sizes'])
    logger.info(
        f'Model has {num_layers} layers. {num_to_tune} will be finetuned, '
        f'{len(fresh_sizes)} fresh ones will be added (sizes: {fresh_sizes}).'
    )
    # Count the ensemble layers (will be replaced anyways)
    num_ensemble_layers = len(
        list(
            filter(lambda tpl: 'ensemble' in tpl[0], model.named_parameters())
        )
    )
    # Freeze everything except the last `num_to_tune` non-ensemble layers.
    for idx, (name, param) in enumerate(model.named_parameters()):
        if idx < num_layers - num_to_tune - num_ensemble_layers:
            param.requires_grad = False

    # Add more dense layers; the first fresh layer consumes the pretrained
    # model's last hidden size.
    fresh_sizes.insert(0, model.hidden_sizes[-1])
    model.dense_layers = nn.Sequential(
        model.dense_layers,
        nn.Sequential(
            OrderedDict(
                [
                    (
                        'fresh_dense_{}'.format(ind),
                        dense_layer(
                            fresh_sizes[ind],
                            fresh_sizes[ind + 1],
                            act_fn=ACTIVATION_FN_FACTORY[
                                params.get('activation_fn', 'relu')],
                            dropout=params.get('dropout', 0.5),
                            batch_norm=params.get('batch_norm', True)
                        ).to(get_device())
                    ) for ind in range(len(fresh_sizes) - 1)
                ]
            )
        )
    )

    # Replace final layer to match the new number of tasks.
    model.num_tasks = params['num_tasks']
    model.final_dense = EnsembleLayer(
        typ=params.get('ensemble', 'score'),
        input_size=fresh_sizes[-1],
        output_size=model.num_tasks,
        ensemble_size=params.get('ensemble_size', 5),
        fn=ACTIVATION_FN_FACTORY['sigmoid']
    )

    return model
Exemplo n.º 9
0
def main(train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
         smi_filepath, gene_filepath, smiles_language_filepath, model_path,
         params_filepath, training_name):
    """Train a drug-sensitivity regression model (SMILES + gene expression).

    Builds DrugSensitivityDatasets from the given filepaths, trains the
    model selected by params['model_fn'] and checkpoints the best-MSE and
    best-Pearson weights under ``<model_path>/<training_name>``.
    """

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=params.get(
                                                   'num_workers', 0))

    # NOTE(review): the test dataset reuses the same `augment_smiles` default
    # (True) as training — confirm augmentation is intended at test time.
    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager')
    # NOTE(review): shuffle=True and drop_last=True on the *test* loader is
    # unusual — it discards the last partial batch and makes eval order
    # nondeterministic; verify this is intentional.
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=params['batch_size'],
                                              shuffle=True,
                                              drop_last=True,
                                              num_workers=params.get(
                                                  'num_workers', 0))

    params.update({
        'number_of_genes': len(gene_list),
        'smiles_vocabulary_size': smiles_language.number_of_tokens
    })

    # Set the tensorboard logger
    tb_logger = Logger(os.path.join(model_dir, 'tb_logs'))
    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    # Template filled by save() as weights/<typ>_<metric>_<model_fn>.pt
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)

    # Define optimizer
    optimizer = (OPTIMIZER_FACTORY[params.get('optimizer',
                                              'Adam')](model.parameters(),
                                                       lr=params.get(
                                                           'lr', 0.01)))
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    # NOTE(review): min_loss sentinel of 100 assumes the test loss always
    # drops below 100 in the first epochs — a larger sentinel would be safer.
    min_loss, max_pearson = 100, 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(torch.squeeze(smiles.to(device)),
                                     gep.to(device))
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info("\t **** TRAINING ****   "
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f"loss: {train_loss / len(train_loader):.5f}. "
                    f"This took {time() - t:.1f} secs.")
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(torch.squeeze(smiles.to(device)),
                                         gep.to(device))
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        # Flatten the per-batch tensors into 1-D numpy arrays.
        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds])
        labels = np.array([l.cpu() for label in labels for l in label])
        # NOTE(review): formatted below with ':.3f', so `pearsonr` is
        # presumably a scalar-returning helper (scipy.stats.pearsonr returns
        # a (r, p-value) tuple and would break the format) — confirm.
        test_pearson_a = pearsonr(predictions, labels)
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        np.save(os.path.join(model_dir, 'results', 'preds.npy'),
                np.hstack([predictions, labels]))
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}")

        # TensorBoard logging of scalars.
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
        }
        for tag, value in info.items():
            tb_logger.scalar_summary(tag, value, epoch + 1)

        def save(path, metric, typ, val=None):
            """Checkpoint weights; log when a new best metric is reached."""
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in "{metric}"'
                            f' with value : {val:.7f} in epoch: {epoch}')

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))
    # NOTE(review): ep_loss/ep_pearson (and `save`) are only bound inside the
    # loop — with params['epochs'] == 0, or if no epoch ever improves the
    # metrics, the lines below raise NameError.
    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} '
                f'\t (Pearson was {min_loss_pearson:4f}) \n \t'
                f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
                f'\t (Loss was {max_pearson_loss:2f})')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
Exemplo n.º 10
0
    def __init__(self, params, *args, **kwargs):
        """Constructor.

        Args:
            params (dict): A dictionary containing the parameter to built the
                dense Decoder.
                TODO params should become actual arguments (use **params).
        Items in params:
            num_filters (list[int], optional): Numbers of filters to learn per
                convolutional layer. Defaults to [10, 20, 50].
            kernel_sizes (list[list[int]], optional): Sizes of kernels per
                convolutional layer. Defaults to  [
                    [3, params['smiles_embedding_size']],
                    [5, params['smiles_embedding_size']],
                    [11, params['smiles_embedding_size']]
                ]
            activation_fn (string, optional): Activation function used in all
                layers for specification in ACTIVATION_FN_FACTORY.
                Defaults to 'relu'.
            batch_norm (bool, optional): Whether batch normalization is
                applied. Defaults to False.
            dropout (float, optional): Dropout probability in all
                except parametric layer. Defaults to 0.0.
            *args, **kwargs: positional and keyword arguments are ignored.
        """

        super(CNN, self).__init__(*args, **kwargs)

        # Model Parameter
        self.device = get_device()
        self.params = params
        self.loss_fn = LOSS_FN_FACTORY[params.get('loss_fn', 'cnn')]

        self.kernel_sizes = params.get('kernel_sizes',
                                       [[3, params['smiles_embedding_size']],
                                        [5, params['smiles_embedding_size']],
                                        [11, params['smiles_embedding_size']]])

        # Prepend the single input channel: num_filters[i] -> num_filters[i+1]
        # are the in/out channels of conv layer i.
        self.num_filters = [1] + params.get('num_filters', [10, 20, 50])

        # BUGFIX: the original checked `self.filters` (never assigned) which
        # raised AttributeError; the intended check is one conv layer per
        # kernel size, i.e. len(num_filters) - 1 transitions.
        if len(self.num_filters) - 1 != len(self.kernel_sizes):
            raise ValueError(
                'Length of filter and kernel size lists do not match.')

        self.smiles_embedding = nn.Embedding(
            self.params['smiles_vocabulary_size'],
            self.params['smiles_embedding_size'],
            scale_grad_by_freq=params.get('embed_scale_grad', False))

        self.dropout = params.get('dropout', 0.0)
        self.act_fn = ACTIVATION_FN_FACTORY[params.get('activation_fn',
                                                       'relu')]

        # BUGFIX: the original iterated over `self.channel_inputs` (never
        # assigned). Also wrap in nn.ModuleList so the conv layers register
        # as submodules and their parameters show up in model.parameters().
        self.conv_layers = nn.ModuleList([
            convolutional_layer(self.num_filters[layer],
                                self.num_filters[layer + 1],
                                self.kernel_sizes[layer])
            for layer in range(len(self.num_filters) - 1)
        ])

        # NOTE(review): `self.num_tasks` is not assigned in this method —
        # presumably set by a base class or mixin before use; confirm.
        self.final_dense = nn.Linear(
            (self.num_filters[-1] * self.params['smiles_embedding_size']),
            self.num_tasks)
        self.final_act_fn = ACTIVATION_FN_FACTORY['sigmoid']
Exemplo n.º 11
0
def main(train_scores_filepath,
         test_scores_filepath,
         smi_filepath,
         smiles_language_filepath,
         model_path,
         params_filepath,
         training_name,
         embedding_path=None):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    if embedding_path:
        params['embedding_path'] = embedding_path

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    logger.info('Start data preprocessing...')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Prepare FP processing
    if params.get('model_fn', 'mca') == 'dense':
        morgan_transform = Compose([
            SMILESToMorganFingerprints(radius=params.get('fp_radius', 2),
                                       bits=params.get('num_drug_features',
                                                       512),
                                       chirality=params.get(
                                           'fp_chirality', True)),
            ToTensor(get_device())
        ])

        def smiles_tensor_batch_to_fp(smiles):
            """ To abuse SMILES dataset for FP usage"""
            out = torch.Tensor(smiles.shape[0],
                               params.get('num_drug_features', 256))
            for ind, tensor in enumerate(smiles):
                smiles = smiles_language.token_indexes_to_smiles(
                    tensor.tolist())
                out[ind, :] = torch.squeeze(morgan_transform(smiles))
            return out

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True))

    # include arg label_columns if data file has any unwanted columns (such as index) to be ignored.
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=params.get(
                                                   'num_workers', 0))

    if (params.get('uncertainty', True)
            and params.get('augment_test_data', False)):
        raise ValueError(
            'Epistemic uncertainty evaluation not supported if augmentation '
            'is not enabled for test data.')

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False))

    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(model_dir, 'smiles_language.pkl'))

    logger.info(smiles_dataset._dataset.transform)
    logger.info(smiles_test_dataset._dataset.transform)

    # include arg label_columns if data file has any unwanted columns (such as index) to be ignored.
    test_dataset = AnnotatedDataset(annotations_filepath=test_scores_filepath,
                                    dataset=smiles_test_dataset,
                                    device=get_device())

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=params['batch_size'],
                                              shuffle=False,
                                              drop_last=False,
                                              num_workers=params.get(
                                                  'num_workers', 0))

    if params.get('confidence', False):
        smiles_conf_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for epistemic uncertainity estimate
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            randomize=params.get('randomize', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            selfies=params.get('selfies', False))
        conf_dataset = AnnotatedDataset(
            annotations_filepath=test_scores_filepath,
            dataset=smiles_conf_dataset,
            device=get_device())
        conf_loader = torch.utils.data.DataLoader(
            dataset=conf_dataset,
            batch_size=params['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=params.get('num_workers', 0))

    if not params.get('embedding', 'learned') == 'pretrained':
        params.update(
            {'smiles_vocabulary_size': smiles_language.number_of_tokens})

    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    logger.info(model)
    logger.info(model.loss_fn.class_weights)

    logger.info('Parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Define optimizer
    optimizer = (OPTIMIZER_FACTORY[params.get('optimizer',
                                              'adam')](model.parameters(),
                                                       lr=params.get(
                                                           'lr', 0.00001)))
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):

            smiles = torch.squeeze(smiles.to(device))
            # Transform smiles to FP if needed
            if params.get('model_fn', 'mca') == 'dense':
                smiles = smiles_tensor_batch_to_fp(smiles).to(device)

            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info('\t **** TRAINING ****   '
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f'loss: {train_loss / len(train_loader):.5f}. '
                    f'This took {time() - t:.1f} secs.')
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):

                smiles = torch.squeeze(smiles.to(device))
                # Transform smiles to FP if needed
                if params.get('model_fn', 'mca') == 'dense':
                    smiles = smiles_tensor_batch_to_fp(smiles).to(device)

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                #   modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions)

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, , roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}')
        info = {
            'test_auc': test_roc_auc_a,
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in {metric}'
                            f' with value : {val:.7f} in epoch: {epoch+1}')

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(os.path.join(model_dir, 'results', 'best_predictions.npy'),
                    predictions)
            with open(os.path.join(model_dir, 'results', 'metrics.json'),
                      'w') as f:
                json.dump(info, f)
            if params.get('confidence', False):
                # Compute uncertainity estimates and save them
                epistemic_conf = monte_carlo_dropout(model,
                                                     regime='loader',
                                                     loader=conf_loader)
                aleatoric_conf = test_time_augmentation(model,
                                                        regime='loader',
                                                        loader=conf_loader)

                np.save(
                    os.path.join(model_dir, 'results', 'epistemic_conf.npy'),
                    epistemic_conf)
                np.save(
                    os.path.join(model_dir, 'results', 'aleatoric_conf.npy'),
                    aleatoric_conf)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score})
            save(save_top_model, 'precision-recall score', 'best',
                 max_precision_recall_score)

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} ')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
Exemplo n.º 12
0
def main(model_path, smi_filepath, labels_filepath, output_folder,
         smiles_language_filepath, model_id, confidence):
    """Evaluate a pretrained toxicity predictor on a SMILES dataset.

    Restores the model (``model_params.json`` + best ROC-AUC weights) from
    ``model_path``, scores the molecules in ``smi_filepath`` against the
    annotations in ``labels_filepath`` and writes predictions, labels,
    layer-averaged attention scores and (optionally) uncertainty estimates
    to ``<output_folder>/results.csv`` (raw attention to
    ``attention_raw.npy``).

    Args:
        model_path: Directory with 'model_params.json' and a 'weights'
            subfolder holding 'best_ROC-AUC_<model_id>.pt'.
        smi_filepath: Path to the .smi file with molecules to score.
        labels_filepath: Path to the annotations (labels) file.
        output_folder: Output directory; created if missing.
        smiles_language_filepath: Path to a pickled SMILESLanguage. Pass '.'
            to use the language stored alongside the model.
        model_id: Key into MODEL_FACTORY selecting the architecture.
        confidence: If truthy, additionally compute epistemic (MC-dropout)
            and aleatoric (test-time-augmentation) uncertainty estimates.

    Raises:
        KeyError: If ``model_id`` is not a key of MODEL_FACTORY.
    """

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger('eval_toxicity')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
        params.update(json.load(fp))

    # Create model directory
    os.makedirs(output_folder, exist_ok=True)

    if model_id not in MODEL_FACTORY.keys():
        raise KeyError(
            f'Model ID: Pass one of {MODEL_FACTORY.keys()}, not {model_id}')

    device = get_device()
    weights_path = os.path.join(model_path, 'weights',
                                f'best_ROC-AUC_{model_id}.pt')

    # Restore model
    model = MODEL_FACTORY[model_id](params).to(device)
    if os.path.isfile(weights_path):
        try:
            model.load(weights_path, map_location=device)
        except Exception:
            logger.error(f'Error in model restoring from {weights_path}')
    else:
        logger.info(f'Did not find weights at {weights_path}, '
                    f'name weights: "best_ROC-AUC_{model_id}.pt".')
    # NOTE(review): if weight loading fails (or no file is found) evaluation
    # silently proceeds with randomly initialized weights — confirm intended.
    model.eval()

    logger.info('Model restored. Model specs & parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    # Load language
    if smiles_language_filepath == '.':
        smiles_language_filepath = os.path.join(model_path,
                                                'smiles_language.pkl')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Assemble datasets. Test-time transforms come from the stored params;
    # randomization is explicitly disabled for the deterministic pass.
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False),
    )
    logger.info(
        f'SMILES Padding length is {smiles_dataset._dataset.padding_length}.'
        'Consider setting manually if this looks wrong.')
    dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                               dataset=smiles_dataset,
                               device=device)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=10,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=0)
    if confidence:
        # Two extra loaders over the same data: one with SMILES augmentation
        # on (aleatoric), one with it off (epistemic / MC dropout).
        smiles_aleatoric_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for aleatoric uncertainity estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        ale_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_aleatoric_dataset,
                                       device=device)
        ale_loader = torch.utils.data.DataLoader(dataset=ale_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)
        smiles_epi_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=False,  # Natively false for epistemic uncertainity estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        epi_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_epi_dataset,
                                       device=device)
        epi_loader = torch.utils.data.DataLoader(dataset=epi_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)

    logger.info(f'Device for data loader is {dataset.device} and for '
                f'model is {device}')
    # Start evaluation
    logger.info('Evaluation about to start...\n')

    preds, labels, attention_scores, smiles = [], [], [], []
    for idx, (smiles_batch, labels_batch) in enumerate(loader):
        pred, pred_dict = model(smiles_batch.to(device))
        # NOTE(review): .squeeze() drops the batch dimension when the last
        # batch has size 1, so .tolist() would then yield per-task scalars
        # instead of one row — confirm dataset size is never ≡ 1 (mod 10).
        preds.extend(pred.detach().squeeze().tolist())
        attention_scores.extend(
            torch.stack(pred_dict['smiles_attention'], dim=1).detach())
        smiles.extend([
            smiles_language.token_indexes_to_smiles(s.tolist())
            for s in smiles_batch
        ])
        labels.extend(labels_batch.squeeze().tolist())
    # Scores are now 3D: num_samples x num_att_layers x padding_length
    attention = torch.stack(attention_scores, dim=0).numpy()

    if confidence:
        # Compute uncertainity estimates and save them
        epistemic_conf, epistemic_pred = monte_carlo_dropout(model,
                                                             regime='loader',
                                                             loader=epi_loader)
        aleatoric_conf, aleatoric_pred = test_time_augmentation(
            model, regime='loader', loader=ale_loader)
        # One column per task for each confidence/prediction matrix.
        epi_conf_df = pd.DataFrame(
            data=epistemic_conf.numpy(),
            columns=[
                f'epistemic_conf_{i}' for i in range(epistemic_conf.shape[1])
            ],
        )
        ale_conf_df = pd.DataFrame(
            data=aleatoric_conf.numpy(),
            columns=[
                f'aleatoric_conf_{i}' for i in range(aleatoric_conf.shape[1])
            ],
        )
        epi_pred_df = pd.DataFrame(
            data=epistemic_pred.numpy(),
            columns=[
                f'epistemic_pred_{i}' for i in range(epistemic_pred.shape[1])
            ],
        )
        ale_pred_df = pd.DataFrame(
            data=aleatoric_pred.numpy(),
            columns=[
                f'aleatoric_pred_{i}' for i in range(aleatoric_pred.shape[1])
            ],
        )

    logger.info(f'Shape of attention scores {attention.shape}.')
    np.save(os.path.join(output_folder, 'attention_raw.npy'), attention)
    # Average attention over the attention-layer axis for the CSV report.
    attention_avg = np.mean(attention, axis=1)
    att_df = pd.DataFrame(
        data=attention_avg,
        columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])],
    )
    pred_df = pd.DataFrame(
        data=preds,
        columns=[f'pred_{i}' for i in range(len(preds[0]))],
    )
    lab_df = pd.DataFrame(
        data=labels,
        columns=[f'label_{i}' for i in range(len(labels[0]))],
    )
    df = pd.concat([pred_df, lab_df], axis=1)
    if confidence:
        df = pd.concat(
            [df, epi_conf_df, ale_conf_df, epi_pred_df, ale_pred_df], axis=1)
    df = pd.concat([df, att_df], axis=1)
    df.insert(0, 'SMILES', smiles)
    df.to_csv(os.path.join(output_folder, 'results.csv'), index=False)

    logger.info('Done, shutting down.')
Exemplo n.º 13
0
import logging
from typing import Iterable

import torch
import torch.nn as nn
from paccmann_predictor.utils.utils import get_device

logger = logging.getLogger(__name__)

# Resolve the torch device once at import time; shared module-wide default.
DEVICE = get_device()


class BCEIgnoreNaN(nn.Module):
    """Wrapper for the BCE loss function that ignores NaN labels."""

    def __init__(self, reduction: str, class_weights: tuple = (1, 1)) -> None:
        """
        Args:
            reduction (str): Reduction applied in loss function. Either sum
                or mean.
            class_weights (tuple, optional): Class weights for loss function.
                Defaults to (1, 1), i.e. equal class weights.

        Raises:
            ValueError: If reduction is neither 'sum' nor 'mean'.
        """
        super(BCEIgnoreNaN, self).__init__()

        # Validate before building any state (original runtime message kept).
        if reduction != 'sum' and reduction != 'mean':
            raise ValueError(
                f'Chose reduction type as mean or sum, not {reduction}'
            )
        self.reduction = reduction

        # Reduction is applied downstream (after NaN masking), hence 'none'.
        self.loss = nn.BCELoss(reduction='none')

        # Bug fix: class_weights was accepted and documented but never
        # stored. Callers elsewhere read and overwrite
        # `loss_fn.class_weights`, so expose it as an attribute.
        self.class_weights = class_weights
Exemplo n.º 14
0
    def __init__(self, params: dict, *args, **kwargs):
        """Constructor.

        Args:
            params (dict): A dictionary containing the parameters to build
                the model.

        Items in params:
            smiles_embedding_size (int): dimension of tokens' embedding.
            smiles_vocabulary_size (int): size of the tokens vocabulary.
            activation_fn (string, optional): Activation function used in all
                layers for specification in ACTIVATION_FN_FACTORY.
                Defaults to 'relu'.
            batch_norm (bool, optional): Whether batch normalization is
                applied. Defaults to True.
            dropout (float, optional): Dropout probability in all
                except parametric layer. Defaults to 0.5.
            filters (list[int], optional): Numbers of filters to learn per
                convolutional layer. Defaults to [64, 64, 64].
            kernel_sizes (list[list[int]], optional): Sizes of kernels per
                convolutional layer. Defaults to  [
                    [3, params['smiles_embedding_size']],
                    [5, params['smiles_embedding_size']],
                    [11, params['smiles_embedding_size']]
                ]
                NOTE: The kernel sizes should match the dimensionality of the
                smiles_embedding_size, so if the latter is 8, the images are
                t x 8, then treat the 8 embedding dimensions like channels
                in an RGB image.
            multiheads (list[int], optional): Amount of attentive multiheads
                per SMILES embedding. Should have len(filters)+1.
                Defaults to [4, 4, 4, 4].
            stacked_dense_hidden_sizes (list[int], optional): Sizes of the
                hidden dense layers. Defaults to [1024, 512].
            smiles_attention_size (int, optional): size of the attentive layer
                for the smiles sequence. Defaults to 64.

        Example params:
        ```
        {
            "smiles_attention_size": 8,
            "smiles_vocabulary_size": 28,
            "smiles_embedding_size": 8,
            "filters": [128, 128],
            "kernel_sizes": [[3, 8], [5, 8]],
            "multiheads": [4, 4, 4],
            "stacked_dense_hidden_sizes": [1024, 512]
        }
        ```
        """
        super(MCAMultiTask, self).__init__(*args, **kwargs)

        # Model Parameter
        self.device = get_device()
        self.params = params
        self.num_tasks = params.get('num_tasks', 12)
        self.smiles_attention_size = params.get('smiles_attention_size', 64)

        # Model architecture (hyperparameter)
        self.multiheads = params.get('multiheads', [4, 4, 4, 4])
        self.filters = params.get('filters', [64, 64, 64])

        self.dropout = params.get('dropout', 0.5)
        self.use_batch_norm = self.params.get('batch_norm', True)
        self.act_fn = ACTIVATION_FN_FACTORY[params.get('activation_fn',
                                                       'relu')]

        # Build the model. First the embeddings
        if params.get('embedding', 'learned') == 'learned':

            self.smiles_embedding = nn.Embedding(
                self.params['smiles_vocabulary_size'],
                self.params['smiles_embedding_size'],
                scale_grad_by_freq=params.get('embed_scale_grad', False))
        elif params.get('embedding', 'learned') == 'one_hot':
            self.smiles_embedding = nn.Embedding(
                self.params['smiles_vocabulary_size'],
                self.params['smiles_vocabulary_size'])
            # Plug in one hot-vectors and freeze weights
            self.smiles_embedding.load_state_dict({
                'weight':
                torch.nn.functional.one_hot(
                    torch.arange(self.params['smiles_vocabulary_size']))
            })
            self.smiles_embedding.weight.requires_grad = False

        elif params.get('embedding', 'learned') == 'pretrained':
            # Load the pretrained embeddings
            try:
                with open(params['embedding_path'], 'rb') as f:
                    embeddings = pickle.load(f)
            except KeyError:
                raise KeyError('Path for embeddings is missing in params.')

            # Plug into layer
            self.smiles_embedding = nn.Embedding(embeddings.shape[0],
                                                 embeddings.shape[1])
            self.smiles_embedding.load_state_dict(
                {'weight': torch.Tensor(embeddings)})
            if params.get('fix_embeddings', True):
                self.smiles_embedding.weight.requires_grad = False

        else:
            raise ValueError(f"Unknown embedding type: {params['embedding']}")

        # Default kernel widths span the full embedding dimension (treated
        # like channels of an image, see docstring NOTE).
        self.kernel_sizes = params.get(
            'kernel_sizes', [[3, self.smiles_embedding.weight.shape[1]],
                             [5, self.smiles_embedding.weight.shape[1]],
                             [11, self.smiles_embedding.weight.shape[1]]])

        # First dense input = concat of all attention heads over the raw
        # embedding plus the heads over each conv layer's filters.
        # NOTE(review): docstring documents 'stacked_dense_hidden_sizes' but
        # the code reads 'stacked_hidden_sizes' — confirm the intended key.
        self.hidden_sizes = ([
            self.multiheads[0] * self.smiles_embedding.weight.shape[1] +
            sum([h * f for h, f in zip(self.multiheads[1:], self.filters)])
        ] + params.get('stacked_hidden_sizes', [1024, 512]))

        if len(self.filters) != len(self.kernel_sizes):
            raise ValueError(
                'Length of filter and kernel size lists do not match.')
        if len(self.filters) + 1 != len(self.multiheads):
            raise ValueError(
                'Length of filter and multihead lists do not match')

        self.convolutional_layers = nn.Sequential(
            OrderedDict([(f'convolutional_{index}',
                          convolutional_layer(
                              num_kernel,
                              kernel_size,
                              act_fn=self.act_fn,
                              dropout=self.dropout,
                              batch_norm=self.use_batch_norm).to(self.device))
                         for index, (num_kernel, kernel_size) in enumerate(
                             zip(self.filters, self.kernel_sizes))]))

        # One projection + alpha pair per attention head, on the embedding
        # (layer 0) and on each convolutional output (layers 1..n).
        smiles_hidden_sizes = [self.smiles_embedding.weight.shape[1]
                               ] + self.filters
        self.smiles_projections = nn.Sequential(
            OrderedDict([
                (f'smiles_projection_{self.multiheads[0]*layer+index}',
                 smiles_projection(smiles_hidden_sizes[layer],
                                   self.smiles_attention_size))
                for layer in range(len(self.multiheads))
                for index in range(self.multiheads[layer])
            ]))
        self.alpha_projections = nn.Sequential(
            OrderedDict([(f'alpha_projection_{self.multiheads[0]*layer+index}',
                          alpha_projection(self.smiles_attention_size))
                         for layer in range(len(self.multiheads))
                         for index in range(self.multiheads[layer])]))

        if self.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(self.hidden_sizes[0])

        self.dense_layers = nn.Sequential(
            OrderedDict([
                ('dense_{}'.format(ind),
                 dense_layer(self.hidden_sizes[ind],
                             self.hidden_sizes[ind + 1],
                             act_fn=self.act_fn,
                             dropout=self.dropout,
                             batch_norm=self.use_batch_norm).to(self.device))
                for ind in range(len(self.hidden_sizes) - 1)
            ]))

        if params.get('ensemble', 'None') not in ['score', 'prob', 'None']:
            raise NotImplementedError(
                "Choose ensemble type from ['score', 'prob', 'None']")
        if params.get('ensemble', 'None') == 'None':
            params['ensemble_size'] = 1

        # NOTE(review): default 'score' here differs from the 'None' default
        # used in the checks above — harmless when 'ensemble' is set, but
        # confirm which default is intended when it is absent.
        self.final_dense = EnsembleLayer(typ=params.get('ensemble', 'score'),
                                         input_size=self.hidden_sizes[-1],
                                         output_size=self.num_tasks,
                                         ensemble_size=params.get(
                                             'ensemble_size', 5),
                                         fn=ACTIVATION_FN_FACTORY['sigmoid'])

        self.loss_fn = LOSS_FN_FACTORY[
            params.get('loss_fn', 'binary_cross_entropy_ignore_nan_and_sum')
        ]   # yapf: disable
        # Set class weights manually
        if 'binary_cross_entropy_ignore_nan' in params.get(
                'loss_fn', 'binary_cross_entropy_ignore_nan_and_sum'):
            self.loss_fn.class_weights = params.get('class_weights', [1, 1])