コード例 #1
0
def get_mol_descriptors():
    """Compute a zero-padded tensor of per-atom descriptors for every drug.

    Runs the molecule model on each drug in ``drug_dataset`` individually and
    stacks the variable-length outputs into one padded tensor.

    Returns:
        sizes: list of per-molecule atom counts taken from ``MolGraph.scope``.
        mol_tensor: float32 tensor of shape (n_drugs, max_atoms, feature_dim)
            on ``args.device``, zero-padded along the atom dimension.

    NOTE(review): relies on module-level globals ``drug_dataset``, ``args``
    and ``molnet`` — confirm they are initialized before calling. A dead
    debug pass over the first four molecules (plus debug prints) was removed;
    it computed a batch and discarded the result.
    """
    outputs = []  # per-molecule model outputs, each of shape (n_atoms_i, feature_dim)
    sizes = []    # per-molecule atom counts from MolGraph.scope
    for j in range(len(drug_dataset.data)):
        smile_batch, path_batch = combine_data([drug_dataset[j]], args)
        path_input, path_mask = path_batch
        path_input = path_input.to(args.device)
        path_mask = path_mask.to(args.device)
        mol_graph = MolGraph(smile_batch, args, path_input, path_mask)
        sizes.append(mol_graph.scope[0][1])
        outputs.append(molnet.model.forward(mol_graph)[0])
    lengths = [output.size(0) for output in outputs]
    # Pad each output along the first (atom) dimension to the maximum length.
    mol_tensor = torch.zeros(len(outputs), max(lengths), outputs[0].size(1),
                             dtype=torch.float32)
    for i, output in enumerate(outputs):
        mol_tensor[i, :lengths[i], :] = output
    return sizes, mol_tensor.to(args.device)
コード例 #2
0
def setup_batch_ts(i, bsz):
    """Assemble one test-set mini-batch.

    Slices ``bsz`` indices starting at ``i`` out of ``batch_inds_ts`` and
    gathers the corresponding protein sequences, drug molecular graph and
    affinity labels.

    Returns:
        (protein int64 tensor, MolGraph, affinity float32 tensor), with the
        tensors placed on ``args.device``.
    """
    batch_ids = batch_inds_ts[i:i + bsz]
    protein_rows = proteins_ts[batch_ids, :]
    drugs = [drug_dataset[j] for j in drugs_ts_ind[batch_ids]]
    smile_batch, (path_input, path_mask) = combine_data(drugs, args)
    mol_graph = MolGraph(smile_batch, args,
                         path_input.to(args.device),
                         path_mask.to(args.device))
    proteins = torch.tensor(protein_rows, device=args.device).to(torch.int64)
    affinities = torch.tensor(affinity_ts[batch_ids],
                              device=args.device).to(torch.float32)
    return proteins, mol_graph, affinities
コード例 #3
0
def setup_batch(i, bsz):
    """Assemble one training mini-batch.

    Uses ``get_index_batch`` to pick the batch indices, then gathers the
    matching protein sequences, drug molecular graph and affinity labels.

    Returns:
        (protein int64 tensor, MolGraph, affinity float32 tensor), with the
        tensors placed on ``args.device``.
    """
    batch_ids = get_index_batch(i, bsz)
    protein_rows = proteins_tr[batch_ids, :]
    drugs = [drug_dataset[j] for j in drugs_tr_ind[batch_ids]]
    smile_batch, (path_input, path_mask) = combine_data(drugs, args)
    mol_graph = MolGraph(smile_batch, args,
                         path_input.to(args.device),
                         path_mask.to(args.device))
    proteins = torch.tensor(protein_rows, device=args.device).to(torch.int64)
    affinities = torch.tensor(affinity_tr[batch_ids],
                              device=args.device).to(torch.float32)
    return proteins, mol_graph, affinities
コード例 #4
0
def run_epoch(data_loader,
              model,
              optimizer,
              stat_names,
              args,
              mode,
              write_path=None):
    """Run one train/eval epoch of the property predictor.

    Args:
        data_loader: iterable yielding (smiles_list, labels_list, path_tuple).
        model: property predictor, called as ``model(mol_graph, stats_tracker)``.
        optimizer: optimizer stepped via ``train_utils.backprop_grads``.
        stat_names: statistic names for ``data_utils.stats_tracker``.
        args: config namespace (device, use_paths, batch_splits, loss_type, ...).
        mode: 'train' enables gradient updates; anything else evaluates.
        write_path: optional file path; per-sample predictions are appended.

    Returns:
        dict of accumulated statistics from the stats tracker.

    Raises:
        ValueError: if ``args.loss_type`` is not 'mse', 'mae' or 'ce'.
    """
    training = mode == 'train'
    prop_predictor = model
    prop_predictor.train() if training else prop_predictor.eval()

    if write_path is not None:
        write_file = open(write_path, 'w+')
    stats_tracker = data_utils.stats_tracker(stat_names)

    batch_split_idx = 0  # batches accumulated since the last optimizer step
    all_pred_logits, all_labels = [], []  # Used to compute Acc, AUC
    for batch_idx, batch_data in enumerate(
            tqdm.tqdm(data_loader, dynamic_ncols=True)):
        if training and batch_split_idx % args.batch_splits == 0:
            optimizer.zero_grad()
        batch_split_idx += 1

        smiles_list, labels_list, path_tuple = batch_data
        path_input, path_mask = path_tuple
        if args.use_paths:
            path_input = path_input.to(args.device)
            path_mask = path_mask.to(args.device)

        n_data = len(smiles_list)
        mol_graph = MolGraph(smiles_list, args, path_input, path_mask)

        pred_logits = prop_predictor(mol_graph, stats_tracker).squeeze(1)
        labels = torch.tensor(labels_list, device=args.device)

        if args.loss_type == 'ce':  # memory issues
            # Detach so the stored tensors do not keep the autograd graph of
            # every batch alive for the whole epoch during training.
            all_pred_logits.append(pred_logits.detach())
            all_labels.append(labels)

        if args.loss_type == 'mse':
            loss = nn.MSELoss()(input=pred_logits, target=labels)
        elif args.loss_type == 'mae':
            loss = nn.L1Loss()(input=pred_logits, target=labels)
        elif args.loss_type == 'ce':
            pred_probs = nn.Sigmoid()(pred_logits)
            loss = nn.BCELoss()(pred_probs, labels)
        else:
            # Explicit error instead of `assert False` (asserts vanish under -O).
            raise ValueError('Unknown loss_type: %s' % args.loss_type)
        stats_tracker.add_stat('loss', loss.item() * n_data, n_data)
        loss = loss / args.batch_splits  # scale for gradient accumulation

        if args.loss_type == 'mae':
            mae = torch.mean(torch.abs(pred_logits - labels))
            stats_tracker.add_stat('mae', mae.item() * n_data, n_data)

        if training:
            loss.backward()

            if batch_split_idx % args.batch_splits == 0:
                train_utils.backprop_grads(prop_predictor, optimizer,
                                           stats_tracker, args)
                batch_split_idx = 0

        if write_path is not None:
            # detach() is required: when gradients are enabled, .numpy() on a
            # grad-requiring tensor raises a RuntimeError.
            write_utils.write_props(write_file, smiles_list, labels_list,
                                    pred_logits.detach().cpu().numpy())

    if training and batch_split_idx != 0:
        train_utils.backprop_grads(model, optimizer, stats_tracker,
                                   args)  # Any remaining

    if args.loss_type == 'ce':
        all_pred_logits = torch.cat(all_pred_logits, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        pred_probs = nn.Sigmoid()(all_pred_logits).detach().cpu().numpy()
        all_labels = all_labels.detach().cpu().numpy()
        acc = train_utils.compute_acc(pred_probs, all_labels)
        auc = train_utils.compute_auc(pred_probs, all_labels)
        stats_tracker.add_stat('acc', acc, 1)
        stats_tracker.add_stat('auc', auc, 1)

    if write_path is not None:
        write_file.close()
    return stats_tracker.get_stats()
コード例 #5
0
def run_epoch(data_loader,
              model,
              optimizer,
              stat_names,
              args,
              mode,
              write_path=None):
    """Run one train/eval epoch of the atom-pair predictor.

    Args:
        data_loader: iterable yielding (smiles_list, labels_list, path_tuple),
            where each labels entry is an (atom_pairs_idx, labels) pair.
        model: atom predictor, called as
            ``model(mol_graph, atom_pairs_idx, stats_tracker)``.
        optimizer: optimizer stepped via ``train_utils.backprop_grads``.
        stat_names: statistic names for ``data_utils.stats_tracker``.
        args: config namespace (device, use_paths, batch_splits, n_classes, ...).
        mode: 'train' enables gradient updates; anything else evaluates.
        write_path: optional file path; per-pair predictions are appended.

    Returns:
        dict of accumulated statistics from the stats tracker.
    """
    training = mode == 'train'
    atom_predictor = model
    atom_predictor.train() if training else atom_predictor.eval()

    if write_path is not None:
        write_file = open(write_path, 'w+')
    stats_tracker = data_utils.stats_tracker(stat_names)

    batch_split_idx = 0  # batches accumulated since the last optimizer step
    all_pred_logits, all_labels = [], []  # Used to compute Acc, AUC
    for batch_idx, batch_data in enumerate(
            tqdm.tqdm(data_loader, dynamic_ncols=True)):
        if training and batch_split_idx % args.batch_splits == 0:
            optimizer.zero_grad()
        batch_split_idx += 1

        smiles_list, labels_list, path_tuple = batch_data
        path_input, path_mask = path_tuple
        if args.use_paths:
            path_input = path_input.to(args.device)
            path_mask = path_mask.to(args.device)

        n_data = len(smiles_list)
        mol_graph = MolGraph(smiles_list, args, path_input, path_mask)
        atom_pairs_idx, labels = zip(*labels_list)

        pred_logits = atom_predictor(mol_graph, atom_pairs_idx,
                                     stats_tracker).squeeze(1)
        labels = [torch.tensor(x, device=args.device) for x in labels]
        labels = torch.cat(labels, dim=0)

        # Detach so the stored tensors do not keep the autograd graph of
        # every batch alive for the whole epoch during training.
        all_pred_logits.append(pred_logits.detach())
        all_labels.append(labels)

        if args.n_classes > 1:
            pred_probs = nn.Softmax(dim=1)(pred_logits)
            loss = F.cross_entropy(input=pred_logits, target=labels)
        else:
            pred_probs = nn.Sigmoid()(pred_logits)
            loss = nn.BCELoss()(pred_probs, labels.float())

        stats_tracker.add_stat('loss', loss.item() * n_data, n_data)
        loss = loss / args.batch_splits  # scale for gradient accumulation

        if write_path is not None:
            # Pass a detached copy: probabilities are only written here, and a
            # grad-requiring tensor cannot be converted to numpy downstream.
            write_ring_output(write_file, smiles_list, atom_pairs_idx, labels,
                              pred_probs.detach(), args.n_classes)

        if training:
            loss.backward()

            if batch_split_idx % args.batch_splits == 0:
                train_utils.backprop_grads(atom_predictor, optimizer,
                                           stats_tracker, args)
                batch_split_idx = 0

    if training and batch_split_idx != 0:
        train_utils.backprop_grads(model, optimizer, stats_tracker,
                                   args)  # Any remaining

    all_pred_logits = torch.cat(all_pred_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    if args.n_classes > 1:
        pred_probs = nn.Softmax(dim=1)(all_pred_logits).detach().cpu().numpy()
    else:
        pred_probs = nn.Sigmoid()(all_pred_logits).detach().cpu().numpy()
    all_labels = all_labels.detach().cpu().numpy()
    acc = train_utils.compute_acc(pred_probs, all_labels, args.n_classes)

    if args.n_classes > 1:
        auc = 0  # multi-class AUC not computed
    else:
        auc = train_utils.compute_auc(pred_probs, all_labels)
    stats_tracker.add_stat('acc', acc, 1)
    stats_tracker.add_stat('auc', auc, 1)

    if write_path is not None:
        write_file.close()
    return stats_tracker.get_stats()