Example #1
    def next_seqs(self):
        """ Construct array of sequences for this stream chunk. """

        # extract next sequences from generator
        seqs_1hot = []
        stream_end = self.stream_start + self.stream_size
        for si in range(self.stream_start, stream_end):
            try:
                seqs_1hot.append(next(self.seqs_gen))
            except StopIteration:
                # generator is exhausted; no further sequences remain
                break

        # initialize ensemble
        seqs_1hot_ens = []

        # add rc/shifts
        for seq_1hot in seqs_1hot:
            for shift in self.shifts:
                seq_1hot_aug = dna_io.hot1_augment(seq_1hot, shift=shift)
                seqs_1hot_ens.append(seq_1hot_aug)
                if self.rc:
                    seq_1hot_aug = dna_io.hot1_rc(seq_1hot_aug)
                    seqs_1hot_ens.append(seq_1hot_aug)

        seqs_1hot_ens = np.array(seqs_1hot_ens, dtype='float32')
        return seqs_1hot_ens
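
The ensemble built above yields 2 * len(self.shifts) augmented copies per sequence when self.rc is set: every shift, plus the reverse complement of each shifted sequence. The toy sketch below reimplements the two transforms on a one-hot (length, 4) array purely for illustration; it is not dna_io's implementation, and it assumes the channel order A, C, G, T.

import numpy as np

def toy_shift(seq_1hot, shift):
    """Shift a one-hot (length, 4) sequence, zero-filling the vacated edge."""
    out = np.zeros_like(seq_1hot)
    if shift > 0:
        out[shift:] = seq_1hot[:-shift]
    elif shift < 0:
        out[:shift] = seq_1hot[-shift:]
    else:
        out[:] = seq_1hot
    return out

def toy_rc(seq_1hot):
    """Reverse complement: flip positions and swap the A<->T and C<->G channels."""
    return seq_1hot[::-1, ::-1]

# 'AACG' one-hot, channel order A, C, G, T
seq = np.zeros((4, 4), dtype='float32')
seq[0, 0] = seq[1, 0] = seq[2, 1] = seq[3, 2] = 1

print(toy_shift(seq, 1))            # sequence moved one position right, zeros in front
print(toy_rc(seq).argmax(axis=1))   # [1 2 3 3], i.e. 'CGTT' == revcomp('AACG')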
Example #2
    def next(self, fwdrc=True, shift=0):
        """ Load the next batch from the HDF5. """
        Xb = None
        Yb = None
        NAb = None
        Nb = 0

        stop = self.start + self.batch_size
        if self.start < self.num_seqs:
            # full or partial batch
            if stop <= self.num_seqs:
                Nb = self.batch_size
            else:
                Nb = self.num_seqs - self.start

            # initialize
            Xb = np.zeros((Nb, self.seq_len, self.seq_depth), dtype='float32')
            if self.Yf is not None:
                if self.Yf.dtype == np.uint8:
                    ytype = 'int32'
                else:
                    ytype = 'float32'

                Yb = np.zeros(
                    (Nb, self.seq_len // self.pool_width, self.num_targets),
                    dtype=ytype)
                NAb = np.zeros((Nb, self.seq_len // self.pool_width),
                               dtype='bool')

            # copy data
            for i in range(Nb):
                si = self.order[self.start + i]
                Xb[i] = self.Xf[si]

                # fix N positions: rows that sum to zero are ambiguous bases,
                # so spread uniform probability 1/seq_depth across them
                Xbi_n = (Xb[i].sum(axis=1) == 0)
                Xb[i] = Xb[i] + (1 / self.seq_depth) * Xbi_n.repeat(
                    self.seq_depth).reshape(self.seq_len, self.seq_depth)

                if self.Yf is not None:
                    Yb[i] = np.nan_to_num(self.Yf[si])

                    if self.NAf is not None:
                        NAb[i] = self.NAf[si]

        # reverse complement and shift
        if Xb is not None:
            Xb = dna_io.hot1_augment(Xb, fwdrc, shift)
        if not fwdrc:
            if Yb is not None:
                Yb = Yb[:, ::-1, :]
            if NAb is not None:
                NAb = NAb[:, ::-1]

        # update start
        self.start = min(stop, self.num_seqs)

        return Xb, Yb, NAb, Nb
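
The "fix N positions" step above spreads uniform probability 1/seq_depth over positions whose one-hot row is all zeros (ambiguous 'N' bases). A minimal standalone demonstration of the same trick, using broadcasting in place of the repeat(...).reshape(...) combination (the two are equivalent here):

import numpy as np

x = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0],
              [0, 0, 0, 0],   # an 'N' base: no channel set
              [0, 0, 0, 1]], dtype='float32')

seq_depth = x.shape[1]
n_mask = (x.sum(axis=1) == 0)             # True where the one-hot row is empty
x += (1.0 / seq_depth) * n_mask[:, None]  # broadcast 0.25 into the empty rows

print(x[2])  # [0.25 0.25 0.25 0.25]; the one-hot rows are untouched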
Example #3
    def next(self, rc=False, shift=0):
        """ Load the next batch from the dataset iterator. """
        try:
            d = self.session.run(self._next_element)

            Xb = d['sequence']
            Yb = d['label']
            NAb = d['na']
            Nb = Xb.shape[0]

            # apply shift and reverse complement
            # (hot1_augment takes a fwdrc flag, so rc must be inverted;
            #  the shift applies whether or not rc is requested)
            if Xb is not None:
                Xb = dna_io.hot1_augment(Xb, not rc, shift)
            if rc:
                if Yb is not None:
                    Yb = Yb[:, ::-1, :]
                if NAb is not None:
                    NAb = NAb[:, ::-1]

            return Xb, Yb, NAb, Nb

        except tf.errors.OutOfRangeError:
            return None, None, None, None
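
Since this next() signals exhaustion by returning None rather than raising, a caller loops until Xb comes back as None. A minimal consumption sketch, assuming batcher (a hypothetical name) is an instance of the class above:

Xb, Yb, NAb, Nb = batcher.next(rc=False, shift=0)
while Xb is not None:
    # ... consume the batch (Xb, Yb, NAb) here ...
    Xb, Yb, NAb, Nb = batcher.next(rc=False, shift=0)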
Example #4
    def _predict_ensemble(self,
                          sess,
                          fd,
                          Xb,
                          ensemble_fwdrc,
                          ensemble_shifts,
                          mc_n,
                          ds_indexes=None,
                          target_indexes=None,
                          return_var=False,
                          return_all=False,
                          penultimate=False):

        # determine predictions length
        preds_length = self.preds_length
        if ds_indexes is not None:
            preds_length = len(ds_indexes)

        # determine num targets
        if penultimate:
            num_targets = self.hp.cnn_params[-1].filters
        else:
            num_targets = self.hp.num_targets
            if target_indexes is not None:
                num_targets = len(target_indexes)

        # initialize batch predictions
        preds_batch = np.zeros((Xb.shape[0], preds_length, num_targets),
                               dtype='float32')

        if return_var:
            preds_batch_var = np.zeros(preds_batch.shape, dtype='float32')
        else:
            preds_batch_var = None

        if return_all:
            all_n = mc_n * len(ensemble_fwdrc)
            preds_all = np.zeros(
                (Xb.shape[0], preds_length, num_targets, all_n),
                dtype='float32')
        else:
            preds_all = None

        running_i = 0

        for ei in range(len(ensemble_fwdrc)):
            # construct sequence
            Xb_ensemble = hot1_augment(Xb, ensemble_fwdrc[ei],
                                       ensemble_shifts[ei])

            # update feed dict
            fd[self.inputs] = Xb_ensemble

            # for each monte carlo (or non-mc single) iteration
            for mi in range(mc_n):
                # print('ei=%d, mi=%d, fwdrc=%d, shifts=%d' % (ei, mi, ensemble_fwdrc[ei], ensemble_shifts[ei]), flush=True)

                # predict
                if penultimate:
                    preds_ei = sess.run(self.penultimate_op, feed_dict=fd)
                else:
                    preds_ei = sess.run(self.preds_op, feed_dict=fd)

                # reverse
                if not ensemble_fwdrc[ei]:
                    preds_ei = preds_ei[:, ::-1, :]

                # down-sample
                if ds_indexes is not None:
                    preds_ei = preds_ei[:, ds_indexes, :]
                if target_indexes is not None:
                    preds_ei = preds_ei[:, :, target_indexes]

                # save previous mean
                preds_batch1 = preds_batch

                # update mean
                preds_batch = self.running_mean(preds_batch1, preds_ei,
                                                running_i + 1)

                # update variance sum
                if return_var:
                    preds_batch_var = self.running_varsum(
                        preds_batch_var, preds_ei, preds_batch1, preds_batch)

                # save iteration
                if return_all:
                    preds_all[:, :, :, running_i] = preds_ei[:, :, :]

                # update running index
                running_i += 1

        return preds_batch, preds_batch_var, preds_all
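
_predict_ensemble accumulates its ensemble statistics through self.running_mean and self.running_varsum, which are not shown in this example. For the returned variance sum to be consistent with the mean updates, those helpers would follow Welford's online algorithm; the sketch below is that assumption made concrete, checked against NumPy:

import numpy as np

def running_mean(mean_prev, x, n):
    """Welford online mean after the n-th sample x."""
    return mean_prev + (x - mean_prev) / n

def running_varsum(varsum_prev, x, mean_prev, mean_new):
    """Welford sum of squared deviations (M2); divide by n-1 for variance."""
    return varsum_prev + (x - mean_prev) * (x - mean_new)

# sanity check against numpy
xs = np.random.randn(8, 3).astype('float32')
mean = np.zeros(3, dtype='float32')
varsum = np.zeros(3, dtype='float32')
for i, x in enumerate(xs):
    mean_prev = mean
    mean = running_mean(mean_prev, x, i + 1)
    varsum = running_varsum(varsum, x, mean_prev, mean)

assert np.allclose(mean, xs.mean(axis=0), atol=1e-5)
assert np.allclose(varsum / (len(xs) - 1), xs.var(axis=0, ddof=1), atol=1e-4)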
Example #5
    def _gradients_ensemble(self,
                            sess,
                            fd,
                            Xb,
                            ensemble_fwdrc,
                            ensemble_shifts,
                            mc_n,
                            return_var=False,
                            return_all=False):
        """ Compute gradients over an ensemble of input augmentations.

      In
       sess: TensorFlow session
       fd: feed dict
       Xb: input data
       ensemble_fwdrc:
       ensemble_shifts:
       mc_n:
       return_var:
       return_all: Return all ensemble predictions.

      Out
       preds:
       layer_reprs:
       layer_grads
    """

        # initialize batch predictions
        preds = np.zeros((Xb.shape[0], self.preds_length, self.hp.num_targets),
                         dtype='float32')

        # initialize layer representations and gradients
        layer_reprs = []
        layer_grads = []
        for lii in range(len(self.grad_layers)):
            li = self.grad_layers[lii]
            layer_seq_len = self.layer_reprs[li].shape[1].value
            layer_units = self.layer_reprs[li].shape[2].value

            lr = np.zeros((Xb.shape[0], layer_seq_len, layer_units),
                          dtype='float16')
            layer_reprs.append(lr)

            lg = np.zeros(
                (self.hp.num_targets, Xb.shape[0], layer_seq_len, layer_units),
                dtype='float32')
            layer_grads.append(lg)

        # initialize variance
        if return_var:
            preds_var = np.zeros(preds.shape, dtype='float32')

            layer_reprs_var = []
            layer_grads_var = []
            for lii in range(len(self.grad_layers)):
                layer_reprs_var.append(
                    np.zeros(layer_reprs[lii].shape, dtype='float32'))
                layer_grads_var.append(
                    np.zeros(layer_grads[lii].shape, dtype='float32'))
        else:
            preds_var = None
            layer_grads_var = [None] * len(self.grad_layers)

        # initialize all-saving arrays
        if return_all:
            all_n = mc_n * len(ensemble_fwdrc)
            preds_all = np.zeros(
                (Xb.shape[0], self.preds_length, self.hp.num_targets, all_n),
                dtype='float32')

            layer_reprs_all = []
            layer_grads_all = []
            for lii in range(len(self.grad_layers)):
                ls = tuple(list(layer_reprs[lii].shape) + [all_n])
                layer_reprs_all.append(np.zeros(ls, dtype='float32'))

                ls = tuple(list(layer_grads[lii].shape) + [all_n])
                layer_grads_all.append(np.zeros(ls, dtype='float32'))
        else:
            preds_all = None
            layer_grads_all = [None] * len(self.grad_layers)

        running_i = 0

        for ei in range(len(ensemble_fwdrc)):
            # construct sequence
            Xb_ensemble = hot1_augment(Xb, ensemble_fwdrc[ei],
                                       ensemble_shifts[ei])

            # update feed dict
            fd[self.inputs] = Xb_ensemble

            # for each monte carlo (or non-mc single) iteration
            for mi in range(mc_n):
                # print('ei=%d, mi=%d, fwdrc=%d, shifts=%d' % \
                #       (ei, mi, ensemble_fwdrc[ei], ensemble_shifts[ei]),
                #       flush=True)

                ##################################################
                # prediction

                # predict
                preds_ei, layer_reprs_ei = sess.run(
                    [self.preds_op, self.layer_reprs], feed_dict=fd)

                # reverse
                if not ensemble_fwdrc[ei]:
                    preds_ei = preds_ei[:, ::-1, :]

                # save previous mean
                preds1 = preds

                # update mean
                preds = self.running_mean(preds1, preds_ei, running_i + 1)

                # update variance sum
                if return_var:
                    preds_var = self.running_varsum(preds_var, preds_ei,
                                                    preds1, preds)

                # save iteration
                if return_all:
                    preds_all[:, :, :, running_i] = preds_ei[:, :, :]

                ##################################################
                # representations

                for lii in range(len(self.grad_layers)):
                    li = self.grad_layers[lii]

                    # reverse
                    if not ensemble_fwdrc[ei]:
                        layer_reprs_ei[li] = layer_reprs_ei[li][:, ::-1, :]

                    # save previous mean
                    layer_reprs_lii1 = layer_reprs[lii]

                    # update mean
                    layer_reprs[lii] = self.running_mean(
                        layer_reprs_lii1, layer_reprs_ei[li], running_i + 1)

                    # update variance sum
                    if return_var:
                        layer_reprs_var[lii] = self.running_varsum(
                            layer_reprs_var[lii], layer_reprs_ei[li],
                            layer_reprs_lii1, layer_reprs[lii])

                    # save iteration
                    if return_all:
                        layer_reprs_all[lii][:, :, :,
                                             running_i] = layer_reprs_ei[li]

                ##################################################
                # gradients

                # compute gradients for each target individually
                for ti in range(self.hp.num_targets):
                    # compute gradients
                    layer_grads_ti_ei = sess.run(self.grad_ops[ti],
                                                 feed_dict=fd)

                    for lii in range(len(self.grad_layers)):
                        # reverse
                        if not ensemble_fwdrc[ei]:
                            layer_grads_ti_ei[lii] = layer_grads_ti_ei[
                                lii][:, ::-1, :]

                        # save previous mean
                        layer_grads_lii_ti1 = layer_grads[lii][ti]

                        # update mean
                        layer_grads[lii][ti] = self.running_mean(
                            layer_grads_lii_ti1, layer_grads_ti_ei[lii],
                            running_i + 1)

                        # update variance sum
                        if return_var:
                            layer_grads_var[lii][ti] = self.running_varsum(
                                layer_grads_var[lii][ti],
                                layer_grads_ti_ei[lii], layer_grads_lii_ti1,
                                layer_grads[lii][ti])

                        # save iteration
                        if return_all:
                            layer_grads_all[lii][
                                ti, :, :, :,
                                running_i] = layer_grads_ti_ei[lii]

                # update running index
                running_i += 1

        if return_var:
            return (preds, preds_var), (layer_reprs,
                                        layer_reprs_var), (layer_grads,
                                                           layer_grads_var)
        elif return_all:
            return (preds, preds_all), (layer_reprs,
                                        layer_reprs_all), (layer_grads,
                                                           layer_grads_all)
        else:
            return preds, layer_reprs, layer_grads
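
Note that under return_var the second element of each returned pair is a running sum of squared deviations, not a variance. A hedged sketch of the conversion a caller would perform (varsum_to_variance is a hypothetical helper, not part of this class):

def varsum_to_variance(varsum, mc_n, ensemble_fwdrc):
    """Convert a Welford squared-deviation sum to an unbiased sample variance."""
    n = mc_n * len(ensemble_fwdrc)  # total ensemble iterations; needs n >= 2
    return varsum / (n - 1)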
Example #6
    def test_stochastic(self):
        # get HDF5 data
        hdf5_open = h5py.File(self.data_h5, 'r')
        hdf5_seqs = hdf5_open['valid_in']
        hdf5_targets = hdf5_open['valid_out']

        # get TFR data
        tfr_pattern = '%s/tfrecords/valid-0.tfr' % self.tfr_data_dir
        next_op = make_data_op(tfr_pattern, self.seq_length,
                               self.target_length)

        # define augmentation
        augment_shifts = [-2, -1, 0, 1, 2]
        next_op = augmentation.augment_stochastic(next_op, True,
                                                  augment_shifts)

        # initialize counters
        augment_counts = {}
        for fwdrc in [True, False]:
            for shift in augment_shifts:
                augment_counts[(fwdrc, shift)] = 0

        # choose # sequences
        max_seqs = min(64, hdf5_seqs.shape[0])

        # iterate over data
        si = 0
        with tf.Session() as sess:
            next_datum = sess.run(next_op)
            while next_datum and si < max_seqs:
                # parse TFRecord
                seqs_tfr = next_datum['sequence'][0]
                targets_tfr = next_datum['label'][0]

                # parse HDF5
                seqs_h5 = hdf5_seqs[si].astype('float32')
                targets_h5 = hdf5_targets[si].astype('float32')

                # expand dim
                seqs1_h5 = np.reshape(seqs_h5,
                                      (1, seqs_h5.shape[0], seqs_h5.shape[1]))

                # check augmentations for matches
                matched = False
                for fwdrc in [True, False]:
                    for shift in augment_shifts:
                        # modify sequence
                        seqs_h5_aug = dna_io.hot1_augment(
                            seqs1_h5, fwdrc, shift)[0]

                        # modify targets
                        if fwdrc:
                            targets_h5_aug = targets_h5
                        else:
                            targets_h5_aug = targets_h5[::-1, :]

                        # check match
                        if np.array_equal(seqs_tfr,
                                          seqs_h5_aug) and np.allclose(
                                              targets_tfr, targets_h5_aug):
                            #  print(si, fwdrc, shift)
                            matched = True
                            augment_counts[(fwdrc, shift)] += 1

                # assert augmentation found
                self.assertTrue(matched)

                try:
                    next_datum = sess.run(next_op)
                    si += 1
                except tf.errors.OutOfRangeError:
                    next_datum = False

        hdf5_open.close()

        # verify all augmentations appear
        for fwdrc in [True, False]:
            for shift in augment_shifts:
                # print(fwdrc, shift, augment_counts[(fwdrc,shift)])
                self.assertGreater(augment_counts[(fwdrc, shift)], 0)
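
The closing assertion is inherently stochastic: with 2 strand choices x 5 shifts = 10 augmentation states drawn per sequence, some state could by chance never be observed. Assuming augment_stochastic samples the states uniformly (suggested by the test, but not shown here), a quick union-bound estimate of the flake rate for 64 sequences:

states = 10   # 2 strands x 5 shifts
n_seqs = 64   # max_seqs above

p_one_missing = (1 - 1 / states) ** n_seqs  # ~0.0012: one fixed state never appears
p_any_missing = states * p_one_missing      # union bound: ~1.2% chance of any flake
print(p_one_missing, p_any_missing)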