import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

# Project-local helpers (MolGraph, combine_data, data_utils, train_utils,
# write_utils, write_ring_output) and globals (args, drug_dataset, molnet,
# proteins_tr/proteins_ts, drugs_tr_ind/drugs_ts_ind, affinity_tr/affinity_ts,
# batch_inds_ts, get_index_batch) are assumed to be defined elsewhere in this
# repo.


def get_mol_descriptors():
    """Encode every drug in drug_dataset and return the per-molecule atom
    counts together with a zero-padded tensor of atom-level descriptors."""
    # Sanity check: encode a small 4-molecule batch and inspect the output.
    smile_batch, path_batch = combine_data(
        [drug_dataset[0], drug_dataset[1], drug_dataset[2], drug_dataset[3]],
        args)
    path_input, path_mask = path_batch
    path_input = path_input.to(args.device)
    path_mask = path_mask.to(args.device)
    mol_graph = MolGraph(smile_batch, args, path_input, path_mask)
    output, _ = molnet.model.forward(mol_graph)
    print(output.size())
    print(mol_graph)
    print(mol_graph.scope)

    # Encode each molecule individually so the per-molecule outputs can be
    # padded to a common length.
    outputs = []
    sizes = []
    for j in range(len(drug_dataset.data)):
        smile_batch, path_batch = combine_data([drug_dataset[j]], args)
        path_input, path_mask = path_batch
        path_input = path_input.to(args.device)
        path_mask = path_mask.to(args.device)
        mol_graph = MolGraph(smile_batch, args, path_input, path_mask)
        sizes.append(mol_graph.scope[0][1])  # number of atoms in molecule j
        outputs.append(molnet.model.forward(mol_graph)[0])

    # Zero-pad all outputs to the longest molecule along the atom dimension.
    lengths = [output.size(0) for output in outputs]
    print(outputs[0].size(), outputs[1].size(), max(lengths), len(outputs))
    mol_tensor = torch.zeros(len(outputs), max(lengths), outputs[0].size(1),
                             dtype=torch.float32)
    print(mol_tensor.size())
    print(sizes)
    for i, output in enumerate(outputs):
        mol_tensor[i, :lengths[i], :] = output
    return sizes, mol_tensor.to(args.device)
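
# Usage sketch (assumption, not from the original source): the per-molecule
# sizes returned above can be turned into a boolean padding mask over the
# padded descriptor tensor. The helper name is hypothetical.
def example_descriptor_mask(sizes, mol_tensor):
    """Return a (n_molecules, max_atoms) bool mask that is True at the valid
    atom rows of the zero-padded tensor from get_mol_descriptors()."""
    idx = torch.arange(mol_tensor.size(1), device=mol_tensor.device)
    sizes_t = torch.tensor(sizes, device=mol_tensor.device)
    # Broadcast (1, max_atoms) < (n_molecules, 1) -> (n_molecules, max_atoms).
    return idx.unsqueeze(0) < sizes_t.unsqueeze(1)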
def setup_batch_ts(i, bsz):
    """Build one test batch: protein token sequences, a MolGraph over the
    batch's drugs, and the corresponding affinity targets."""
    inds = batch_inds_ts[i:i + bsz]
    pr_seq = proteins_ts[inds, :]
    smile_batch, path_batch = combine_data(
        [drug_dataset[j] for j in drugs_ts_ind[inds]], args)
    path_input, path_mask = path_batch
    path_input = path_input.to(args.device)
    path_mask = path_mask.to(args.device)
    mol_graph = MolGraph(smile_batch, args, path_input, path_mask)
    return (torch.tensor(pr_seq, device=args.device).to(torch.int64),
            mol_graph,
            torch.tensor(affinity_ts[inds], device=args.device).to(torch.float32))
def setup_batch(i, bsz):
    """Build one training batch, mirroring setup_batch_ts() but drawing
    indices from the training split via get_index_batch()."""
    inds = get_index_batch(i, bsz)
    pr_seq = proteins_tr[inds, :]
    smile_batch, path_batch = combine_data(
        [drug_dataset[j] for j in drugs_tr_ind[inds]], args)
    path_input, path_mask = path_batch
    path_input = path_input.to(args.device)
    path_mask = path_mask.to(args.device)
    mol_graph = MolGraph(smile_batch, args, path_input, path_mask)
    return (torch.tensor(pr_seq, device=args.device).to(torch.int64),
            mol_graph,
            torch.tensor(affinity_tr[inds], device=args.device).to(torch.float32))
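
# Usage sketch (assumption, not from the original script): one pass over the
# training pairs driven by setup_batch(). The joint protein-drug model call
# and the n_train/bsz names are hypothetical.
def example_train_pass(model, optimizer, n_train, bsz=32):
    for i in range(0, n_train, bsz):
        optimizer.zero_grad()
        pr_seq, mol_graph, affinity = setup_batch(i, bsz)
        # pr_seq: (bsz, seq_len) int64 protein tokens
        # mol_graph: MolGraph over the batch's drugs
        # affinity: (bsz,) float32 binding-affinity targets
        pred = model(pr_seq, mol_graph)  # hypothetical joint model
        loss = nn.MSELoss()(pred.squeeze(-1), affinity)
        loss.backward()
        optimizer.step()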
def run_epoch(data_loader, model, optimizer, stat_names, args, mode,
              write_path=None):
    """One pass over data_loader for molecular property prediction, with
    gradient accumulation over args.batch_splits mini-batches."""
    training = mode == 'train'
    prop_predictor = model
    prop_predictor.train() if training else prop_predictor.eval()

    if write_path is not None:
        write_file = open(write_path, 'w+')

    stats_tracker = data_utils.stats_tracker(stat_names)
    batch_split_idx = 0
    all_pred_logits, all_labels = [], []  # used to compute acc and AUC

    for batch_idx, batch_data in enumerate(
            tqdm.tqdm(data_loader, dynamic_ncols=True)):
        # Zero grads only at the start of each accumulation window.
        if training and batch_split_idx % args.batch_splits == 0:
            optimizer.zero_grad()
        batch_split_idx += 1

        smiles_list, labels_list, path_tuple = batch_data
        path_input, path_mask = path_tuple
        if args.use_paths:
            path_input = path_input.to(args.device)
            path_mask = path_mask.to(args.device)
        n_data = len(smiles_list)

        mol_graph = MolGraph(smiles_list, args, path_input, path_mask)
        pred_logits = prop_predictor(mol_graph, stats_tracker).squeeze(1)
        labels = torch.tensor(labels_list, device=args.device)

        if args.loss_type == 'ce':  # accumulated on device; can be memory-heavy
            all_pred_logits.append(pred_logits)
            all_labels.append(labels)

        if args.loss_type == 'mse':
            loss = nn.MSELoss()(input=pred_logits, target=labels)
        elif args.loss_type == 'mae':
            loss = nn.L1Loss()(input=pred_logits, target=labels)
        elif args.loss_type == 'ce':
            pred_probs = nn.Sigmoid()(pred_logits)
            loss = nn.BCELoss()(pred_probs, labels)
        else:
            raise ValueError('Unsupported loss type: %s' % args.loss_type)
        stats_tracker.add_stat('loss', loss.item() * n_data, n_data)
        loss = loss / args.batch_splits  # scale for gradient accumulation

        if args.loss_type == 'mae':
            mae = torch.mean(torch.abs(pred_logits - labels))
            stats_tracker.add_stat('mae', mae.item() * n_data, n_data)

        if training:
            loss.backward()
            if batch_split_idx % args.batch_splits == 0:
                train_utils.backprop_grads(prop_predictor, optimizer,
                                           stats_tracker, args)
                batch_split_idx = 0

        if write_path is not None:
            write_utils.write_props(write_file, smiles_list, labels_list,
                                    pred_logits.cpu().numpy())

    if training and batch_split_idx != 0:  # flush any remaining gradients
        train_utils.backprop_grads(model, optimizer, stats_tracker, args)

    if args.loss_type == 'ce':
        all_pred_logits = torch.cat(all_pred_logits, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        pred_probs = nn.Sigmoid()(all_pred_logits).detach().cpu().numpy()
        all_labels = all_labels.detach().cpu().numpy()
        acc = train_utils.compute_acc(pred_probs, all_labels)
        auc = train_utils.compute_auc(pred_probs, all_labels)
        stats_tracker.add_stat('acc', acc, 1)
        stats_tracker.add_stat('auc', auc, 1)

    if write_path is not None:
        write_file.close()
    return stats_tracker.get_stats()
def run_epoch(data_loader, model, optimizer, stat_names, args, mode,
              write_path=None):
    """One pass over data_loader for atom-pair (ring) classification. Note
    that this shares its name with the property-prediction run_epoch above
    and would shadow it if both were defined in the same module."""
    training = mode == 'train'
    atom_predictor = model
    atom_predictor.train() if training else atom_predictor.eval()

    if write_path is not None:
        write_file = open(write_path, 'w+')

    stats_tracker = data_utils.stats_tracker(stat_names)
    batch_split_idx = 0
    all_pred_logits, all_labels = [], []  # used to compute acc and AUC

    for batch_idx, batch_data in enumerate(
            tqdm.tqdm(data_loader, dynamic_ncols=True)):
        # Zero grads only at the start of each accumulation window.
        if training and batch_split_idx % args.batch_splits == 0:
            optimizer.zero_grad()
        batch_split_idx += 1

        smiles_list, labels_list, path_tuple = batch_data
        path_input, path_mask = path_tuple
        if args.use_paths:
            path_input = path_input.to(args.device)
            path_mask = path_mask.to(args.device)
        n_data = len(smiles_list)

        mol_graph = MolGraph(smiles_list, args, path_input, path_mask)
        atom_pairs_idx, labels = zip(*labels_list)
        pred_logits = atom_predictor(mol_graph, atom_pairs_idx,
                                     stats_tracker).squeeze(1)

        labels = [torch.tensor(x, device=args.device) for x in labels]
        labels = torch.cat(labels, dim=0)
        all_pred_logits.append(pred_logits)
        all_labels.append(labels)

        if args.n_classes > 1:
            pred_probs = nn.Softmax(dim=1)(pred_logits)
            loss = F.cross_entropy(input=pred_logits, target=labels)
        else:
            pred_probs = nn.Sigmoid()(pred_logits)
            loss = nn.BCELoss()(pred_probs, labels.float())
        stats_tracker.add_stat('loss', loss.item() * n_data, n_data)
        loss = loss / args.batch_splits  # scale for gradient accumulation

        if write_path is not None:
            write_ring_output(write_file, smiles_list, atom_pairs_idx, labels,
                              pred_probs, args.n_classes)

        if training:
            loss.backward()
            if batch_split_idx % args.batch_splits == 0:
                train_utils.backprop_grads(atom_predictor, optimizer,
                                           stats_tracker, args)
                batch_split_idx = 0

    if training and batch_split_idx != 0:  # flush any remaining gradients
        train_utils.backprop_grads(model, optimizer, stats_tracker, args)

    all_pred_logits = torch.cat(all_pred_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    if args.n_classes > 1:
        pred_probs = nn.Softmax(dim=1)(all_pred_logits).detach().cpu().numpy()
    else:
        pred_probs = nn.Sigmoid()(all_pred_logits).detach().cpu().numpy()
    all_labels = all_labels.detach().cpu().numpy()

    acc = train_utils.compute_acc(pred_probs, all_labels, args.n_classes)
    if args.n_classes > 1:
        auc = 0  # AUC is only computed for the binary case
    else:
        auc = train_utils.compute_auc(pred_probs, all_labels)
    stats_tracker.add_stat('acc', acc, 1)
    stats_tracker.add_stat('auc', auc, 1)

    if write_path is not None:
        write_file.close()
    return stats_tracker.get_stats()
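
# Usage sketch (assumption): a minimal outer loop driving run_epoch over
# several epochs; train_loader, dev_loader, and n_epochs are hypothetical.
def example_training_loop(model, optimizer, train_loader, dev_loader, args,
                          n_epochs=10):
    stat_names = ['loss', 'acc', 'auc']
    for epoch in range(n_epochs):
        train_stats = run_epoch(train_loader, model, optimizer, stat_names,
                                args, mode='train')
        with torch.no_grad():  # eval mode never calls backward()
            dev_stats = run_epoch(dev_loader, model, optimizer, stat_names,
                                  args, mode='eval')
        print('epoch %d: train %s, dev %s' % (epoch, train_stats, dev_stats))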