# Edge-feature width: second dimension of the sample graph's "feat" edge tensor.
e_dim = sample_item[0].edata["feat"].shape[1]
# Number of global features: length of the second element of the sample item.
g_dim = len(sample_item[1])

# Training loader: shuffled each epoch, batched with `collate_pair`.
loader_train = DataLoader(
    data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_pair,
    num_workers=NUM_WORKERS,
)

# One output per column of `values` (presumably the training targets; a_dim,
# sample_item and values are defined before this section).
model = MPNNPredictor(
    node_in_feats=a_dim,
    edge_in_feats=e_dim,
    global_feats=g_dim,
    n_tasks=values.shape[1],
    num_step_message_passing=N_MESSPASS,
    # No output activation is passed to the model during training.
    output_f=None,
).to(DEVICE)

opt = Adam(model.parameters(), lr=INITIAL_LR)

LOGGER.info("Now training...")
# Accumulates everything train_loop returns across all epochs, flattened into one list.
train_losses = []
for epoch_no in range(N_EPOCHS):
    # epoch_no is 0-based; shown 1-based. NOTE(review): uses print rather than
    # LOGGER like the line above -- consider unifying.
    print("Train epoch {}/{}...".format(epoch_no + 1, N_EPOCHS))
    t_l = train_loop(loader_train, model, loss_fn, opt)
    # presumably per-batch losses for this epoch -- confirm against train_loop
    train_losses.extend(t_l)
# NOTE(review): imp_rf is not used within this section; presumably filled later.
imp_rf = []
# MPNN weights for this dataset (file name suggests a model trained without explicit Hs).
model_pt = os.path.join(MODELS_PATH, f"{data}_noHs.pt")
# Baseline RF model loaded via `load` (source of `load` not visible -- confirm,
# e.g. joblib). Not used within this section.
model_rf = load(os.path.join(BASELINE_MODELS_PATH, f"rf_{data}.pt"))

# The pickled dataset is a 2-tuple; only the first element (InChIs) is kept.
# NOTE(review): pickle.load executes arbitrary code from the file; only load trusted data.
with open(
    os.path.join(DATA_PATH, f"{data}", f"data_{data}.pt"), "rb"
) as handle:
    inchis, _ = pickle.load(handle)

# Sigmoid output only for the "cyp" dataset (presumably a classification task).
output_f = torch.sigmoid if data == "cyp" else None
# Hard-coded feature dimensions must match the saved checkpoint; load_state_dict
# (strict by default) raises on a parameter-shape mismatch.
# NOTE(review): num_step_message_passing is not passed here, unlike the training
# code that passes N_MESSPASS -- confirm the MPNNPredictor default matches.
model = MPNNPredictor(
    node_in_feats=49,
    edge_in_feats=10,
    global_feats=4,
    n_tasks=1,
    output_f=output_f,
).to(DEVICE)
model.load_state_dict(torch.load(model_pt, map_location=DEVICE))

for inchi in tqdm(inchis):
    mol = MolFromInchi(inchi)
    # Compute attributions for every importance version 1..N_VERSIONS.
    for version in range(1, N_VERSIONS + 1):
        # molecule_importance returns a 5-tuple; only the 3rd (atom importance)
        # and 5th (global importance) elements are kept.
        _, _, atom_importance, _, global_importance = molecule_importance(
            mol, model, version=version
        )
        # imp / g_imp are defined outside this section; presumably lists of
        # N_VERSIONS inner lists indexed by version - 1.
        imp[version - 1].append(atom_importance)
        g_imp[version - 1].append(global_importance)
def predict(
    inchis, w_path, n_tasks=1, batch_size=32, output_f=None, add_hs=False, progress=True
):
    """Predict values for a list of `inchis` given model weights `w_path`.

    Parameters
    ----------
    inchis : list
        A list of InChI strings to predict values for.
    w_path : str
        Path to a model ``state_dict`` saved with ``torch.save``.
    n_tasks : int, optional
        Number of model output tasks, by default 1.
    batch_size : int, optional
        Number of molecules per inference batch, by default 32.
    output_f : callable, optional
        Activation function applied on the output layer if necessary,
        by default None.
    add_hs : bool, optional
        Forwarded to ``GraphData`` when building the molecular graphs
        (presumably adds explicit hydrogens); should match the setting the
        weights were trained with. By default False.
    progress : bool, optional
        Show a progress bar over batches, by default True.

    Returns
    -------
    torch.Tensor
        Per-batch model outputs concatenated along the first dimension, on
        CPU, in the same order as `inchis`.
    """
    data = GraphData(inchis, train=False, add_hs=add_hs)
    # Infer the model's input widths from the first featurized molecule.
    sample_item = data[0]
    a_dim = sample_item[0].ndata["feat"].shape[1]
    e_dim = sample_item[0].edata["feat"].shape[1]
    g_dim = len(sample_item[1])

    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=False,  # keep predictions aligned with the order of `inchis`
        collate_fn=collate_pair_prod,
        num_workers=NUM_WORKERS,
    )
    if progress:
        loader = tqdm(loader)

    model = MPNNPredictor(
        node_in_feats=a_dim,
        edge_in_feats=e_dim,
        global_feats=g_dim,
        n_tasks=n_tasks,
        num_step_message_passing=N_MESSPASS,
        output_f=output_f,
    ).to(DEVICE)
    # NOTE: torch.load unpickles the file; only load trusted weight files.
    model.load_state_dict(torch.load(w_path, map_location=DEVICE))
    # Switch to inference mode so dropout / batch-norm layers (if any) behave
    # deterministically instead of as in training.
    model.eval()

    yhats = []
    # One no-grad context for the whole pass instead of re-entering per batch.
    with torch.no_grad():
        for g, g_feat in loader:
            g = g.to(DEVICE)
            g_feat = g_feat.to(DEVICE)
            yhats.append(model(g, g_feat).cpu())
    return torch.cat(yhats)
args = parser.parse_args()

# Report which device prediction and attribution will run on.
if torch.cuda.is_available():
    LOGGER.info(
        f"Using device {torch.cuda.get_device_name()} for prediction and feature attribution."
    )
else:
    LOGGER.warning(
        "A CUDA-capable device was not found. Using CPU for prediction and feature attribution, which can take considerably longer."
    )

LOGGER.info("Loading model...")
# NOTE(review): other MPNNPredictor constructions pass
# num_step_message_passing=N_MESSPASS (and sometimes output_f); neither is passed
# here, so the constructor defaults apply -- confirm they match the checkpoint.
model = MPNNPredictor(
    node_in_feats=args.node_in_feats,
    edge_in_feats=args.edge_in_feats,
    global_feats=args.global_feats,
    n_tasks=1,
).to(DEVICE)
# NOTE: torch.load unpickles the file; only load trusted model files.
model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
LOGGER.info(f"Model {args.model_path} successfully loaded!")

# Processing ligands
if args.smi.endswith(".smi"):
    # Read-only access is all that is needed; "r+" would also demand write
    # permission and fail on read-only input files.
    with open(args.smi, "r") as handle:
        ligands = [line.rstrip("\n") for line in handle]
    # Drop trailing blank lines so they are not treated as ligands; the
    # emptiness guard also avoids an IndexError on an empty file.
    while ligands and ligands[-1] == "":
        ligands.pop()