import os

import numpy as np
# np.random.seed(123)
import tensorflow as tf
# tf.random.set_random_seed(123)
import deepchem as dc
from deepchem.molnet import load_tox21
from deepchem.models.graph_models import GraphConvModel, WeaveModel

model_dir = "/home/minys/weave"
# drop any stale featurized cache so the featurizer below reruns
os.system(f'rm -r {model_dir}/tox21-featurized')
featurizer = 'Weave'  # or 'GraphConv'
modelcls = WeaveModel  # or GraphConvModel

# Load Tox21 dataset
tox21_tasks, tox21_datasets, transformers = load_tox21(
    featurizer=featurizer, data_dir=model_dir, save_dir=model_dir)
train_dataset, valid_dataset, test_dataset = tox21_datasets
print(train_dataset.data_dir)
print(valid_dataset.data_dir)

# Fit models
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")

# Batch size of models
batch_size = 64
model = modelcls(len(tox21_tasks), batch_size=batch_size, mode='classification')
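# --- Usage sketch (not in the original script): training and scoring the model
# built above with DeepChem's standard fit/evaluate API; nb_epoch=10 is an
# assumed value, not taken from the source.
model.fit(train_dataset, nb_epoch=10)
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)
print("Train ROC-AUC:", train_scores)
print("Valid ROC-AUC:", valid_scores)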
import itertools
import pickle
import time

import numpy as np
import jax.numpy as jnp
from jax import grad, jit
from jax import random as jrandom
from jax.experimental import optimizers
from sklearn.metrics import roc_auc_score

from deepchem.molnet import load_tox21

# Project-local helpers used below: parse_arguments, seed_everything,
# GCNPredicator, clipped_sigmoid, collate_fn, EarlyStopping. Hedged stand-ins
# for two of these follow the script.


def main():
    args = parse_arguments()
    # fix seed
    seed_everything(args.seed)

    # load tox21 dataset
    tox21_tasks, tox21_datasets, _ = load_tox21(featurizer='AdjacencyConv',
                                                reload=True)
    train_dataset, valid_dataset, test_dataset = tox21_datasets

    # define hyperparams
    rng = jrandom.PRNGKey(args.seed)
    # model params
    hidden_feats = [64, 64, 64]
    activation, batchnorm, dropout = None, None, None  # use default
    predicator_hidden_feats = 32
    pooling_method = 'mean'
    predicator_dropout = None  # use default
    n_out = 1  # binary classification
    # training params
    lr = args.lr
    num_epochs = args.epochs
    batch_size = args.batch_size
    task = args.task
    early_stop_patience = args.early_stop

    # setup model
    init_fun, predict_fun = \
        GCNPredicator(hidden_feats=hidden_feats, activation=activation,
                      batchnorm=batchnorm, dropout=dropout,
                      pooling_method=pooling_method,
                      predicator_hidden_feats=predicator_hidden_feats,
                      predicator_dropout=predicator_dropout, n_out=n_out)

    # init params
    rng, init_key = jrandom.split(rng)
    sample_node_feat = train_dataset.X[0][1]
    input_shape = sample_node_feat.shape
    _, init_params = init_fun(init_key, input_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
    opt_state = opt_init(init_params)

    @jit
    def predict(params, inputs):
        """Predict the logits."""
        preds = predict_fun(params, *inputs)
        logits = clipped_sigmoid(preds)
        return logits

    # define training loss
    @jit
    def loss(params, batch):
        """Compute the loss (binary cross entropy)."""
        inputs, targets = batch[:-1], batch[-1]
        logits = predict(params, inputs)
        loss = -jnp.mean(targets * jnp.log(logits)
                         + (1 - targets) * jnp.log(1 - logits))
        return loss

    # define training update
    @jit
    def update(i, opt_state, batch):
        """Update the params."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    print("Starting training...")
    task_index = tox21_tasks.index(task)
    itercount = itertools.count()
    early_stop = EarlyStopping(patience=early_stop_patience)
    for epoch in range(num_epochs):
        # train
        start_time = time.time()
        for original_batch in train_dataset.iterbatches(batch_size=batch_size):
            rng, key = jrandom.split(rng)
            batch = collate_fn(original_batch, task_index, key, True)
            opt_state = update(next(itercount), opt_state, batch)
        epoch_time = time.time() - start_time

        # valid
        params = get_params(opt_state)
        y_score, y_true, valid_loss = [], [], []
        for original_batch in valid_dataset.iterbatches(batch_size=batch_size):
            rng, key = jrandom.split(rng)
            batch = collate_fn(original_batch, task_index, key, False)
            y_score.extend(predict(params, batch[:-1]))
            y_true.extend(batch[-1])
            valid_loss.append(loss(params, batch))
        score = roc_auc_score(y_true, y_score)

        # log
        print(f"Iter {epoch}/{num_epochs} ({epoch_time:.4f} s) "
              f"valid loss: {np.mean(valid_loss):.4f} "
              f"valid roc_auc score: {score:.4f}")

        # check early stopping
        early_stop.update(score, params)
        if early_stop.is_train_stop:
            print("Early stopping...")
            break

    # test
    y_score, y_true = [], []
    best_params = early_stop.best_params
    for original_batch in test_dataset.iterbatches(batch_size=batch_size):
        rng, key = jrandom.split(rng)
        batch = collate_fn(original_batch, task_index, key, False)
        y_score.extend(predict(best_params, batch[:-1]))
        y_true.extend(batch[-1])
    score = roc_auc_score(y_true, y_score)
    print(f'Test roc_auc score: {score:.4f}')

    # save best params
    with open('./best_params.pkl', 'wb') as f:
        pickle.dump(best_params, f)
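# --- Hedged sketches (assumptions, not from the original script) ---
# clipped_sigmoid and parse_arguments are referenced above but defined
# elsewhere in the repo; these are minimal stand-ins matching that usage.
import argparse

import jax


def clipped_sigmoid(x, eps=1e-7):
    """Sigmoid clipped away from 0 and 1 so jnp.log in the loss stays finite."""
    return jnp.clip(jax.nn.sigmoid(x), eps, 1.0 - eps)


def parse_arguments():
    """One plausible CLI matching the args consumed by main(); defaults are illustrative."""
    parser = argparse.ArgumentParser(description='Single-task Tox21 GCN in JAX')
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--task', type=str, default='NR-AR')  # any Tox21 task name
    parser.add_argument('--early_stop', type=int, default=10)
    return parser.parse_args()


if __name__ == '__main__':
    main()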
from __future__ import division
from __future__ import unicode_literals

import os
import shutil

import numpy as np
import deepchem as dc
from deepchem.molnet import load_tox21
from sklearn.linear_model import LogisticRegression

# Only for debug!
np.random.seed(123)

# Load Tox21 dataset
n_features = 1024
tox21_tasks, tox21_datasets, transformers = load_tox21()
train_dataset, valid_dataset, test_dataset = tox21_datasets

# Fit models
metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)


def model_builder(model_dir_logreg):
    sklearn_model = LogisticRegression(
        penalty="l2", C=1. / 0.05, class_weight="balanced", n_jobs=-1)
    return dc.models.sklearn_models.SklearnModel(sklearn_model,
                                                 model_dir_logreg)


model = dc.models.multitask.SingletaskToMultitask(tox21_tasks, model_builder)

# Fit trained model
model.fit(train_dataset)
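# --- Usage sketch (not in the original script): scoring the multitask wrapper
# with the metric defined above via DeepChem's standard evaluate API.
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)
print("Train ROC-AUC:", train_scores)
print("Valid ROC-AUC:", valid_scores)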
import numpy as np

np.random.seed(123)
import tensorflow as tf

tf.set_random_seed(123)
import deepchem as dc
from deepchem.molnet import load_tox21
from deepchem.models.tensorgraph.models.graph_models import PetroskiSuchModel

model_dir = "/tmp/graph_conv"

# Load Tox21 dataset
tox21_tasks, tox21_datasets, transformers = load_tox21(
    featurizer='AdjacencyConv')
train_dataset, valid_dataset, test_dataset = tox21_datasets
print(train_dataset.data_dir)
print(valid_dataset.data_dir)

# Fit models
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")

# Batch size of models
batch_size = 128
model = PetroskiSuchModel(
    len(tox21_tasks), batch_size=batch_size, mode='classification')
model.fit(train_dataset, nb_epoch=10)
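# --- Usage sketch (not in the original script): scoring the fitted model with
# the metric defined above via DeepChem's standard evaluate API.
valid_scores = model.evaluate(valid_dataset, [metric], transformers)
print("Valid ROC-AUC:", valid_scores)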
# Tail of the feed_dict generator inside sluice_model(); the construction of
# the TensorGraph model and of d, labels, task_weights, atom_features,
# degree_slice, membership, and deg_adjs is elided above this fragment.
            d[task_weights] = w_b
            multiConvMol = ConvMol.agglomerate_mols(X_b)
            d[atom_features] = multiConvMol.get_atom_features()
            d[degree_slice] = multiConvMol.deg_slice
            d[membership] = multiConvMol.membership
            for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
                d[deg_adjs[i - 1]] = multiConvMol.get_deg_adjacency_lists()[i]
            yield d

    return model, feed_dict_generator, labels, task_weights


model_dir = "tmp/graphconv"

# Load Tox21 dataset
tox21_tasks, tox21_datasets, transformers = load_tox21(featurizer='GraphConv')
train_dataset, valid_dataset, test_dataset = tox21_datasets
print(train_dataset.data_dir)
print(valid_dataset.data_dir)

# Fit models
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")

# Batch size of models
batch_size = 100
num_epochs = 10
model, generator, labels, task_weights = sluice_model(batch_size, tox21_tasks)
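# --- Hedged usage sketch (assumption): TensorGraph-era training with the
# generator returned above; generator(dataset, batch_size, epochs) follows
# DeepChem's TensorGraph sluice example, and fit_generator/evaluate_generator
# are assumed to match that era's TensorGraph API.
model.fit_generator(generator(train_dataset, batch_size, num_epochs))
valid_scores = model.evaluate_generator(
    generator(valid_dataset, batch_size, 1), [metric],
    transformers, labels=labels, weights=[task_weights])
print("Valid ROC-AUC:", valid_scores)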
import pickle
import time
from typing import Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import optix

from deepchem.molnet import load_tox21

# Project-local helpers and type aliases used below (defined elsewhere in this
# repo): parse_arguments, seed_everything, GCNPredicator, bce_with_logits,
# collate_fn, multi_task_roc_auc_score, EarlyStopping, State, Batch, OptState.
# Hedged stand-ins for bce_with_logits and collate_fn follow the script.


def main():
    args = parse_arguments()
    # fix seed
    seed_everything(args.seed)

    # load tox21 dataset
    tox21_tasks, tox21_datasets, _ = load_tox21(featurizer='AdjacencyConv',
                                                reload=True)
    train_dataset, valid_dataset, test_dataset = tox21_datasets

    # define hyperparams
    rng_seq = hk.PRNGSequence(args.seed)
    # model params
    in_feats = train_dataset.X[0][1].shape[1]
    hidden_feats = [64, 64, 32]
    activation, batch_norm, dropout = None, None, None  # use default
    predicator_hidden_feats = 32
    pooling_method = 'mean'
    predicator_dropout = 0.2
    n_out = len(tox21_tasks)
    # training params
    lr = args.lr
    num_epochs = args.epochs
    batch_size = args.batch_size
    early_stop_patience = args.early_stop

    # setup model
    def forward(node_feats: np.ndarray, adj: np.ndarray,
                is_training: bool) -> jnp.ndarray:
        """Forward application of the GCN."""
        model = GCNPredicator(in_feats=in_feats, hidden_feats=hidden_feats,
                              activation=activation, batch_norm=batch_norm,
                              dropout=dropout, pooling_method=pooling_method,
                              predicator_hidden_feats=predicator_hidden_feats,
                              predicator_dropout=predicator_dropout,
                              n_out=n_out)
        preds = model(node_feats, adj, is_training)
        return preds

    model = hk.transform_with_state(forward)
    optimizer = optix.adam(learning_rate=lr)

    # define training loss
    def train_loss(params: hk.Params, state: State,
                   batch: Batch) -> Tuple[float, State]:
        """Compute the loss."""
        inputs, targets = batch
        preds, new_state = model.apply(params, state, next(rng_seq), *inputs,
                                       True)
        loss = bce_with_logits(preds, targets)
        return loss, new_state

    # define training update
    @jax.jit
    def update(params: hk.Params, state: State, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, State, OptState]:
        """Update the params."""
        # NB: under jax.jit, next(rng_seq) is drawn once at trace time and
        # then reused on every subsequent call.
        (_, new_state), grads = jax.value_and_grad(train_loss, has_aux=True)(
            params, state, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, new_state, new_opt_state

    # define evaluate metrics
    @jax.jit
    def evaluate(params: hk.Params, state: State,
                 batch: Batch) -> Tuple[jnp.ndarray, float, np.ndarray]:
        """Compute evaluate metrics."""
        inputs, targets = batch
        preds, _ = model.apply(params, state, next(rng_seq), *inputs, False)
        loss = bce_with_logits(preds, targets)
        return preds, loss, targets

    print("Starting training...")
    early_stop = EarlyStopping(patience=early_stop_patience)
    batch_init_data = (jnp.zeros((batch_size, *train_dataset.X[0][1].shape)),
                       jnp.zeros((batch_size, *train_dataset.X[0][0].shape)),
                       True)
    params, state = model.init(next(rng_seq), *batch_init_data)
    opt_state = optimizer.init(params)
    for epoch in range(num_epochs):
        # train
        start_time = time.time()
        for original_batch in train_dataset.iterbatches(batch_size=batch_size):
            batch = collate_fn(original_batch)
            params, state, opt_state = update(params, state, opt_state, batch)
        epoch_time = time.time() - start_time

        # valid
        y_score, y_true, valid_loss = [], [], []
        for original_batch in valid_dataset.iterbatches(batch_size=batch_size):
            batch = collate_fn(original_batch)
            preds, loss, targets = evaluate(params, state, batch)
            y_score.extend(preds)
            valid_loss.append(loss)
            y_true.extend(targets)
        score, _ = multi_task_roc_auc_score(np.array(y_true),
                                            np.array(y_score))

        # log
        print(f"Iter {epoch}/{num_epochs} ({epoch_time:.4f} s) "
              f"valid loss: {np.mean(valid_loss):.4f} "
              f"valid roc_auc score: {score:.4f}")

        # check early stopping
        early_stop.update(score, (params, state))
        if early_stop.is_train_stop:
            print("Early stopping...")
            break

    # test
    y_score, y_true = [], []
    best_checkpoints = early_stop.best_checkpoints
    for original_batch in test_dataset.iterbatches(batch_size=batch_size):
        batch = collate_fn(original_batch)
        logits, _, targets = evaluate(*best_checkpoints, batch)
        y_score.extend(logits)
        y_true.extend(targets)
    score, scores = multi_task_roc_auc_score(np.array(y_true),
                                             np.array(y_score))
    print(f'Test mean roc_auc score: {score:.4f}')
    print(f'Test all roc_auc score: {str(scores)}')

    # save best checkpoints
    with open('./best_checkpoints.pkl', 'wb') as f:
        pickle.dump(best_checkpoints, f)
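# --- Hedged sketches (assumptions, not from the original script): minimal
# stand-ins for two helpers referenced above, matching how main() uses them.
def bce_with_logits(logits, targets):
    """Numerically stable binary cross entropy computed on raw logits."""
    return jnp.mean(jnp.maximum(logits, 0) - logits * targets
                    + jnp.log1p(jnp.exp(-jnp.abs(logits))))


def collate_fn(original_batch):
    """Map a DeepChem iterbatches tuple to the ((node_feats, adj), targets)
    layout consumed above; per the script's own indexing, each X entry from the
    'AdjacencyConv' featurizer is an (adjacency, node_features) pair."""
    X, y, _, _ = original_batch
    adj = jnp.asarray(np.stack([x[0] for x in X]))         # (batch, N, N)
    node_feats = jnp.asarray(np.stack([x[1] for x in X]))  # (batch, N, in_feats)
    return (node_feats, adj), jnp.asarray(y)


if __name__ == '__main__':
    main()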