def prepare_dataset(positive_rowgen,
                    negative_dsgen,
                    args=None,
                    random_state=None) -> DataSetGAN:
    """Produce the next combined positive/negative ``DataSetGAN``.

    Positive rows are drawn from `positive_rowgen` until one both matches
    the spec in `args` and receives a non-``None`` negative dataset from
    `negative_dsgen`; that combined dataset is returned.

    Args:
        positive_rowgen: Iterator of positive data rows.
        negative_dsgen: Primed generator; ``.send(pos_dsg)`` attaches
            negatives (or returns ``None`` to reject the positive).
        args: Optional run configuration (permutation, pc-identity target,
            dataset spec).
        random_state: Required when ``args.permute_positives`` is set.
    """
    permute = bool(args and args.permute_positives)
    predict_pc = bool(args and args.predict_pc_identity)
    if permute:
        assert random_state is not None
    while True:
        pos_row = next(positive_rowgen)
        # Target is either a structure-quality score derived from percent
        # identity, or the constant positive label 1.
        target = (pc_identity_to_structure_quality(pos_row.target)
                  if predict_pc else 1)
        pos_row = pos_row._replace(target=target)
        if permute:
            # Offset is computed over the gapless sequence length.
            seq_length = len(pos_row.sequence.replace("-", ""))
            pos_ds = row_to_dataset(
                pos_row, permute_amount=get_offset(seq_length, random_state))
        else:
            pos_ds = row_to_dataset(pos_row)
        pos_dsg = dataset_to_gan(pos_ds)
        if args and not dataset_matches_spec(pos_dsg, args):
            continue
        ds = negative_dsgen.send(pos_dsg)
        if ds is not None:
            return ds
def permute_and_slice_datagen(
    datagen_pos: Iterator[DataRow],
    datagen_neg: Optional[Generator[DataRow, Any, None]],
    methods: Tuple,
) -> Iterator[DataSetCollection]:
    """Yield ``(positives, negatives)`` pairs in batches of input rows.

    Each accepted positive gets one permuted negative plus, per remaining
    slice method, one extra negative (failures are logged and skipped).

    Args:
        datagen_pos: Source of positive data rows.
        datagen_neg: Row generator forwarded to ``get_negative_example``.
        methods: Negative-generation methods; must include ``"permute"``.

    Bug fix: the final partial batch is no longer silently dropped when
    `datagen_pos` is exhausted between 256-row flushes.
    """
    assert "permute" in methods
    slice_methods = [m for m in methods if m != "permute"]

    def _flush(batch):
        # Pair each positive with its permuted negative, then try to add one
        # extra negative per slice method; per-method failures are logged
        # and skipped (the row keeps whatever negatives succeeded).
        batch_neg = dataset.get_permuted_examples(batch)
        for pos, neg in zip(batch, batch_neg):
            pos_list = [pos]
            neg_list = [neg]
            for method in slice_methods:
                try:
                    other_neg = dataset.get_negative_example(
                        pos, method=method, rowgen=datagen_neg)
                    neg_list.append(other_neg)
                except (MaxNumberOfTriesExceededError,
                        SequenceTooLongError) as e:
                    logger.error("%s: %s", type(e), e)
                    continue
            yield pos_list, neg_list

    batch_pos = []
    for i, row in enumerate(datagen_pos):
        dataset_pos = dataset.row_to_dataset(row, target=1)
        if len(dataset_pos.seq) < settings.MIN_SEQUENCE_LENGTH:
            continue
        batch_pos.append(dataset_pos)
        # NOTE: the flush cadence counts *input* rows (including skipped
        # short sequences), preserving the original batching semantics.
        if (i + 1) % 256 == 0:
            yield from _flush(batch_pos)
            batch_pos = []
    # Flush the tail batch (previously dropped on generator exhaustion).
    if batch_pos:
        yield from _flush(batch_pos)
def _get_internal_validation_dataset(
        args: Args, method: str,
        random_state: np.random.RandomState) -> List[DataSetGAN]:
    """Assemble a fixed-size list of validation ``DataSetGAN``s.

    Positive rows are drawn from shuffled parquet files and each is paired
    with one negative sequence by the ``negative_sequence_adder`` coroutine.
    Rows are skipped until ``args.validation_num_sequences`` acceptable
    datasets have been collected.

    Args:
        args: Run configuration (data paths, sequence-length bounds,
            ``validation_num_sequences``).
        method: Negative-sequence generation method forwarded to
            ``negative_sequence_adder``.
        random_state: Shared randomness for both shuffled row generators.

    Returns:
        A list of exactly ``args.validation_num_sequences`` datasets.
    """
    # Parquet column -> DataRow attribute mapping ("distances" maps to None;
    # presumably kept under its original name — confirm in the rowgen helpers).
    columns = {
        "qseq": "sequence",
        "residue_idx_1_corrected": "adjacency_idx_1",
        "residue_idx_2_corrected": "adjacency_idx_2",
        "distances": None,
    }
    # NOTE(review): positives come from *training_data_path* while negatives
    # come from *validation_data_path* — plausibly intentional for an
    # "internal" validation set, but worth confirming.
    rowgen_pos = iter_datarows_shuffled(
        sorted(args.training_data_path.glob("database_id=*/*.parquet")),
        columns=columns,
        random_state=random_state,
    )

    rowgen_neg = gen_datarows_shuffled(
        sorted(args.validation_data_path.glob("database_id=*/*.parquet")),
        columns=columns,
        random_state=random_state,
    )
    # Prime the generator (advance to its first yield point).
    next(rowgen_neg)

    nsa = negative_sequence_adder(rowgen_neg,
                                  method,
                                  num_sequences=1,
                                  keep_pos=True,
                                  random_state=random_state)
    # Prime the coroutine so that `.send()` below is legal.
    next(nsa)

    dataset: List[DataSetGAN] = []
    while len(dataset) < args.validation_num_sequences:
        pos_row = next(rowgen_pos)
        pos_ds = dataset_to_gan(row_to_dataset(pos_row, 1))
        # Filter out bad datasets: enforce sequence-length bounds ...
        n_aa = len(pos_ds.seqs[0])
        if not (args.min_seq_length <= n_aa < args.max_seq_length):
            logger.debug(f"Skipping because wrong sequence length: {n_aa}.")
            continue
        # ... and require at least one interaction once the diagonal band is
        # removed (bandwidth 3 — presumably near-diagonal self contacts).
        adj_nodiag = remove_eye_sparse(pos_ds.adjs[0], 3)
        n_interactions = adj_nodiag.nnz
        if n_interactions <= 0:
            logger.debug(
                f"Skipping because too few interactions: {n_interactions}.")
            continue
        # Ask the coroutine to attach a negative sequence; ``None`` means it
        # could not produce one for this positive.
        ds = nsa.send(pos_ds)
        if ds is None:
            logger.debug("Skipping this sequence...")
            continue
        dataset.append(ds)

    assert len(dataset) == args.validation_num_sequences
    return dataset
def slice_datagen(datagen_pos: Iterator[DataRow],
                  datagen_neg: Generator[DataRow, Any, None],
                  methods: Tuple) -> Iterator[DataSetCollection]:
    """Yield ``([positive], negatives)`` with one negative per method.

    Rows shorter than ``settings.MIN_SEQUENCE_LENGTH`` are skipped; if any
    method fails to produce a negative, the whole row is discarded
    (all-or-nothing) and the error is logged.
    """
    for row in datagen_pos:
        pos = dataset.row_to_dataset(row, target=1)
        if len(pos.seq) < settings.MIN_SEQUENCE_LENGTH:
            continue  # too short to be usable
        try:
            negs = [
                dataset.get_negative_example(pos, method=m, rowgen=datagen_neg)
                for m in methods
            ]
        except (MaxNumberOfTriesExceededError, SequenceTooLongError) as e:
            logger.error("%s: %s", type(e), e)
            continue
        yield [pos], negs
def _get_mutation_dataset(mutation_class: str,
                          data_path: Path) -> List[DataSetGAN]:
    """Load mutation rows and attach the mutated sequence as a scored pair.

    Each positive dataset gains one extra sequence with the point mutation
    applied and a target equal to the mutation's experimental score.

    Args:
        mutation_class: Mutation dataset identifier passed to
            ``get_rowgen_mut``.
        data_path: Root directory of the mutation data.

    Raises:
        ValueError: If a mutation's wild-type residue does not match the
            sequence at the indicated (1-based) position.
    """
    mutation_datarows = get_rowgen_mut(mutation_class, data_path)
    mutation_datasets = (dataset_to_gan(row_to_dataset(row, target=1))
                         for row in mutation_datarows)

    mutation_dsg = []
    for pos_ds in mutation_datasets:
        assert pos_ds.meta is not None  # type narrowing; meta is required
        mutation = pos_ds.meta["mutation"]  # e.g. "A42G"
        mutation_idx = int(mutation[1:-1]) - 1  # 1-based -> 0-based
        neg_seq = bytearray(pos_ds.seqs[0])
        # Validate with a real exception: an `assert` here would be stripped
        # under `python -O`, silently admitting corrupt mutation records.
        if neg_seq[mutation_idx] != ord(mutation[0]):
            raise ValueError(
                f"Wild-type mismatch for mutation {mutation!r}: sequence has "
                f"{chr(neg_seq[mutation_idx])!r} at position "
                f"{mutation_idx + 1}.")
        neg_seq[mutation_idx] = ord(mutation[-1])
        ds = pos_ds._replace(seqs=pos_ds.seqs + [neg_seq],
                             targets=pos_ds.targets + [pos_ds.meta["score"]])
        mutation_dsg.append(ds)

    return mutation_dsg
def get_mutation_datagen(mutation_class: str, data_path: Path) -> DataGen:
    """Return a zero-argument generator factory over mutation pairs.

    Precomputes ``([positive], [negative])`` dataset pairs — the negative
    being the sequence with the point mutation applied and the mutation's
    score as its target — and returns a ``datagen`` callable that yields
    them on each invocation.

    Args:
        mutation_class: Mutation dataset identifier passed to
            ``get_rowgen_mut``.
        data_path: Root directory of the mutation data.

    Raises:
        ValueError: If a mutation's wild-type residue does not match the
            sequence at the indicated (1-based) position.
    """
    mutation_datarows = get_rowgen_mut(mutation_class, data_path)
    mutation_datasets = (dataset.row_to_dataset(row, target=1)
                         for row in mutation_datarows)

    mutation_dsc = []
    for pos_ds in mutation_datasets:
        mutation = pos_ds.meta["mutation"]  # type: ignore  # e.g. "A42G"
        mutation_idx = int(mutation[1:-1]) - 1  # 1-based -> 0-based
        neg_seq = bytearray(pos_ds.seq)
        # Validate with a real exception: an `assert` here would be stripped
        # under `python -O`, silently admitting corrupt mutation records.
        if neg_seq[mutation_idx] != ord(mutation[0]):
            raise ValueError(
                f"Wild-type mismatch for mutation {mutation!r}: sequence has "
                f"{chr(neg_seq[mutation_idx])!r} at position "
                f"{mutation_idx + 1}.")
        neg_seq[mutation_idx] = ord(mutation[-1])
        neg_ds = DataSet(neg_seq, pos_ds.adj,
                         pos_ds.meta["score"])  # type: ignore
        mutation_dsc.append(([pos_ds], [neg_ds]))

    def datagen():
        # Fresh iteration over the precomputed pairs on every call.
        for dvc in mutation_dsc:
            yield dvc

    return datagen
# Exemple #7
# 0
def buffered_permuted_sequence_adder(
    rowgen: RowGen,
    num_sequences: int,
    keep_pos: bool = False,
    random_state: Optional[np.random.RandomState] = None,
) -> Generator[Optional[DataSetGAN], DataSetGAN, None]:
    """Coroutine draft: attach permuted negative sequences from a buffer.

    Args:
        rowgen: Used for **pre-populating** the generator only!
        num_sequences: Number of sequences to generate in each iteration.
        keep_pos: If True, the positive sequences/targets are kept in the
            emitted dataset alongside the negatives.
        random_state: Randomness source; a fresh ``RandomState`` is created
            when ``None``.

    Raises:
        NotImplementedError: Always — the body below is an unfinished draft.
    """
    # Everything after this raise is unreachable draft code, retained for a
    # future implementation.
    raise NotImplementedError

    if random_state is None:
        random_state = np.random.RandomState()

    # Pre-populate the buffer with 512 negative (target=0) sequences.
    seq_buffer = [row_to_dataset(r, 0).seq for r in itertools.islice(rowgen, 512)]
    negative_dsg = None
    while True:
        dsg = yield negative_dsg
        seq = dsg.seqs[0]
        # Concatenate the buffer and cut num_sequences rotated windows of
        # the same length as the positive sequence.
        negative_seq_big = b"".join(seq_buffer)
        negative_seqs = []
        for _ in range(num_sequences):
            offset = random_state.randint(0, len(negative_seq_big) - len(seq))
            negative_seq = (negative_seq_big[offset:] + negative_seq_big[:offset])[: len(seq)]
            negative_seqs.append(negative_seq)
        negative_dsg = dsg._replace(
            seqs=(dsg.seqs if keep_pos else []) + negative_seqs,
            targets=(dsg.targets if keep_pos else []) + [0] * num_sequences,
        )
        # Fold the new positive into the buffer and reshuffle.
        # Bug fix: dropped the trailing `random_state.pop()` — RandomState
        # has no `pop` method, so it would have raised AttributeError on the
        # first full iteration once the NotImplementedError guard is removed.
        seq_buffer.append(seq)
        random_state.shuffle(seq_buffer)
def generate_batch(
    args: Args,
    net: nn.Module,
    positive_rowgen: RowGen,
    negative_ds_gen: Optional[DataSetGenM] = None,
):
    """Generate a positive and a negative dataset batch.

    Accepts rows from `positive_rowgen` until `args.batch_size` datasets
    match the spec, converting each through `net.dataset_to_datavar`.

    Returns:
        ``(pos_seq, neg_seq, adjs)`` — sequences concatenated along dim 2;
        ``neg_seq`` is ``None`` when no negative generator is supplied.
    """
    collect_negatives = negative_ds_gen is not None
    pos_chunks = []
    neg_chunks = []
    adjs = []
    total_len = 0  # kept for the alternative length-based stop rule below
    accepted = 0
    # TODO: 128 comes from the fact that we tested with sequences 64-256 AA
    # in length:  while total_len < (args.batch_size * 128):
    while accepted < args.batch_size:
        pos_row = next(positive_rowgen)
        pos_ds = dataset_to_gan(row_to_dataset(pos_row, 1))
        if not dataset_matches_spec(pos_ds, args):
            continue
        pos_dv = net.dataset_to_datavar(pos_ds)
        pos_chunks.append(pos_dv.seqs)
        adjs.append(pos_dv.adjs)
        if collect_negatives:
            # NOTE(review): unlike other consumers, the result of `.send()`
            # is not checked for None here — presumably this generator
            # always produces a dataset; confirm.
            neg_dv = net.dataset_to_datavar(negative_ds_gen.send(pos_ds))
            neg_chunks.append(neg_dv.seqs)
        total_len += pos_dv.seqs.shape[2]
        accepted += 1
    expected_len = sum(adj[0].shape[1] for adj in adjs)
    pos_seq = torch.cat([chunk.data for chunk in pos_chunks], 2)
    assert pos_seq.shape[2] == expected_len
    if collect_negatives:
        neg_seq = torch.cat([chunk.data for chunk in neg_chunks], 2)
        assert neg_seq.shape[2] == expected_len
    else:
        neg_seq = None
    return pos_seq, neg_seq, adjs
# Exemple #9
# 0
 # NOTE(review): fragment from a scraped listing — `input_df` and
 # `row_to_dataset` come from an enclosing scope that is not visible here,
 # and the one-space indent suggests this belongs inside another function.
 def datagen():
     # Presumably yields one DataSet per dataframe row with target=0
     # (second positional argument matches `target` elsewhere) — confirm.
     for row in input_df.itertuples():
         dataset = row_to_dataset(row, 0)
         yield dataset