Example #1
    def summarize(self, write_example=True):
        with self.g.as_default():
            is_training(False, session=self.sess)
        iter_step = self.get_global_step()
        fetches = [
            self.loss, self.grad_summ, self.activation_summ, self.edit_distance
        ]
        if write_example:
            fetches.append(self.pred)
        vals = self._rnn_roll(fetches)

        # vals lines up with fetches: [loss, grad_summ, activation_summ,
        # edit_distance], plus pred when write_example is set.
        edit_distance = np.mean(vals[3]).item()
        summaries = [x[0] for x in vals[1:3]]
        for summary in summaries:
            self.train_writer.add_summary(summary, global_step=iter_step)

        self.train_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="train/edit_distance",
                                 simple_value=edit_distance)
            ]),
            global_step=iter_step)

        if write_example:
            out_net = [decode_sparse(x)[0] for x in vals[-1]]
            target = self.decode_target(0, pad=self.block_size_y)
            for a, b in zip(out_net, target):
                print("OUTPUT:", a)
                print("TARGET:", b)
                print('----')
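
These methods write scalars with the manual tf.Summary protobuf API rather than graph-side summary ops. A minimal standalone sketch of that pattern (TF 1.x; the log directory and values are placeholders, not from the examples):

    import tensorflow as tf

    # Build the Summary protobuf by hand and pass it to a FileWriter;
    # no session run is needed for simple scalars.
    writer = tf.summary.FileWriter('/tmp/example_logs')
    writer.add_summary(
        tf.Summary(value=[
            tf.Summary.Value(tag="train/edit_distance", simple_value=0.25)
        ]),
        global_step=0)
    writer.flush()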
Example #2
    def train_minibatch(self, log_every=100, trace_every=10000):
        """
            Trains minibatch and performs all required operations

            Args:
                log_every: how many time steps to log train_writer and stdout
                trace_every: how many time steps to perform full trace execution (slow). It always performs it at global step 25.
        """
        with self.g.as_default():
            is_training(True, session=self.sess)
        tt = perf_counter()
        iter_step, _ = self.sess.run([self.global_step, self.load_train])
        self.sess.run([self.inc_gs])
        self.dequeue_time = (0.8 * self.dequeue_time +
                             0.2 * (perf_counter() - tt))

        if trace_every > 0:
            if iter_step == 25 or (iter_step > 0
                                   and iter_step % trace_every == 0):
                self.trace_level = tf.RunOptions.FULL_TRACE

        self.bbt = 0.8 * self.bbt + 0.2 * (perf_counter() - self.bbt_clock)
        tt = perf_counter()
        fetches = [self.train_op, self.loss, self.reg]
        vals = self._rnn_roll(add_fetch=fetches, timeline_suffix="ctc_loss")
        self.batch_time = 0.8 * self.batch_time + 0.2 * (perf_counter() - tt)
        self.bbt_clock = perf_counter()

        loss = np.sum(vals[1]).item()
        reg_loss = np.sum(vals[2]).item()

        self.train_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="train/loss", simple_value=loss),
                tf.Summary.Value(tag="train/reg_loss", simple_value=reg_loss),
                tf.Summary.Value(tag="input/batch_time",
                                 simple_value=self.batch_time),
                tf.Summary.Value(tag="input/dequeue_time",
                                 simple_value=self.dequeue_time),
                tf.Summary.Value(tag="input/between_batch_time",
                                 simple_value=self.bbt),
            ]),
            global_step=iter_step)

        if iter_step % log_every == 0:
            train_summ, y_len = self.sess.run([self.train_summ, self.Y_len])
            self.train_writer.add_summary(train_summ, global_step=iter_step)

            self.logger.info(
                "%4d loss %6.3f reg_loss %6.3f bt %.3f, bbt %.3f, avg_y_len %.3f dequeue %.3fs",
                iter_step, loss, reg_loss, self.batch_time, self.bbt,
                np.mean(y_len), self.dequeue_time)

        self.trace_level = tf.RunOptions.NO_TRACE
        return loss
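
train_minibatch smooths its timing measurements (batch_time, dequeue_time, bbt) with an exponential moving average in which each new sample carries 20% weight. A toy standalone sketch of that smoothing (the sleep stands in for one training step):

    from time import perf_counter, sleep

    ema = 0.0
    for _ in range(5):
        t = perf_counter()
        sleep(0.01)  # stand-in for one training step
        # Each new sample gets 20% weight, so the estimate tracks recent
        # timings while damping single-step noise.
        ema = 0.8 * ema + 0.2 * (perf_counter() - t)
    print("smoothed step time: %.4fs" % ema)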
Example #3
    def run_validation(self, num_batches=5):
        """
        Runs validation.

        Args:
        num_batches: Number of batches to run validation.

        Returns:
            A tuple (average loss, average edit distance) on validation set
        """
        with self.g.as_default():
            is_training(False, session=self.sess)
        losses = []
        reg_losses = []
        edit_distances = []
        for _ in range(num_batches):
            _, iter_step, summ = self.sess.run(
                [self.load_test, self.global_step, self.test_queue_size])
            self.test_writer.add_summary(summ, global_step=iter_step)
            loss, reg_loss, edit_distance = self._rnn_roll(
                add_fetch=[self.loss, self.reg, self.edit_distance],
                timeline_suffix="ctc_val_loss")
            losses.append(np.sum(loss))
            reg_losses.append(np.sum(reg_loss))
            edit_distances.append(edit_distance)
        avg_loss = np.mean(losses).item()
        avg_reg_loss = np.mean(reg_losses).item()
        avg_edit_distance = np.mean(edit_distances).item()
        self.test_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="train/loss", simple_value=avg_loss),
                tf.Summary.Value(tag="train/reg_loss",
                                 simple_value=avg_reg_loss),
                tf.Summary.Value(tag="train/edit_distance",
                                 simple_value=avg_edit_distance),
            ]),
            global_step=iter_step)
        self.logger.info(
            "%4d validation loss %6.3f edit_distance %.3f in %.3fs",
            iter_step, avg_loss, avg_edit_distance,
            perf_counter() - self.bbt_clock)
        self.bbt_clock = perf_counter()
        return avg_loss, avg_edit_distance
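
Taken together, train_minibatch, summarize, and run_validation suggest a driver loop along these lines. This is a hypothetical sketch: the function name, model argument, and schedule constants are assumptions, not part of the examples.

    def train(model, n_steps=10000, summarize_every=500, validate_every=1000):
        # Hypothetical driver loop for the methods shown above.
        for step in range(n_steps):
            loss = model.train_minibatch(log_every=100)
            if step > 0 and step % summarize_every == 0:
                model.summarize(write_example=False)
            if step > 0 and step % validate_every == 0:
                avg_loss, avg_edit = model.run_validation(num_batches=5)
        return loss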
Example #4
0
    def run_validation_full(self,
                            frac,
                            verbose=False,
                            fasta_out_dir=None,
                            ref=None):
        """
            Runs full validation on test set with whole sequence_length
            Args:
                frac: fraction of test set to evaluate, or number of test cases if int
        """

        test_root = os.path.join(input_readers.root_dir_default, 'test')
        items = np.array(glob(os.path.join(test_root, '*.fa')))
        np.random.shuffle(items)

        if isinstance(frac, int):
            items = items[:frac]
        else:
            items = items[:int(len(items) * frac)]
        self.logger.info("Running validation on %d examples", len(items))

        nedit = np.zeros(items.shape, dtype=np.float32)
        acc = np.zeros(items.shape, dtype=np.float32)
        n = len(acc)
        total_time = 1e-6  # avoid division by zero before the first read
        total_bases_read = 0
        cigar_stat = defaultdict(int)

        with self.g.as_default():
            is_training(False, session=self.sess)
            pbar = tqdm(items)
            for i, fast5_path in enumerate(pbar):
                t = perf_counter()
                nedit[i], acc[i], read_len, cigar_read_stat = self.get_aligement(
                    fast5_path,
                    input_readers.find_ref(fast5_path),
                    verbose=verbose,
                    fasta_out_dir=fasta_out_dir,
                    ref=ref)
                total_time += perf_counter() - t
                total_bases_read += read_len
                for k in ['=', 'X', 'I', 'D']:
                    cigar_stat[k] += cigar_read_stat[k]

                mu_edit, mu_acc = np.mean(nedit[:i + 1]), np.mean(acc[:i + 1])
                std_edit, std_acc = np.std(nedit[:i + 1]), np.std(acc[:i + 1])
                se_edit = std_edit / np.sqrt(i + 1)
                se_acc = std_acc / np.sqrt(i + 1)
                pbar.set_postfix(stat=(
                    "avg edit %.4f s %.4f CI <%.4f, %.4f> "
                    "avg_acc %.4f s %.4f CI <%.4f, %.4f> %.2f bps" %
                    (mu_edit, std_edit, mu_edit - 2 * se_edit,
                     mu_edit + 2 * se_edit, mu_acc, std_acc,
                     mu_acc - 2 * se_acc, mu_acc + 2 * se_acc,
                     total_bases_read / total_time)))

        mu_edit, mu_acc = np.mean(nedit), np.mean(acc)
        std_edit, std_acc = np.std(nedit), np.std(acc)
        se_edit, se_acc = std_edit / np.sqrt(n), std_acc / np.sqrt(n)

        self.logger.info(
            "step: %d [samples %d] avg edit %.4f s %.4f CI <%.4f, %.4f> avg_acc %.4f s %.4f CI <%.4f, %.4f> %.3f bps",
            self.get_global_step(), n, mu_edit, std_edit,
            mu_edit - 2 * se_edit, mu_edit + 2 * se_edit, mu_acc, std_acc,
            mu_acc - 2 * se_acc, mu_acc + 2 * se_acc,
            total_bases_read / total_time)

        total_cigar_elements = np.sum(list(cigar_stat.values()))
        for k in ['=', 'X', 'I', 'D']:
            self.logger.info("Step %d; %s %.2f%%", self.get_global_step(), k,
                             100 * cigar_stat[k] / total_cigar_elements)
        self.logger.info("Speed = %.3f bases per second",
                         total_bases_read / total_time)

        return {
            'edit': {
                'mu': mu_edit.item(),
                'std': std_edit.item(),
                'se': se_edit.item()
            },
            'accuracy': {
                'mu': mu_acc.item(),
                'std': std_acc.item(),
                'se': se_acc.item()
            }
        }
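
The progress-bar and final statistics report a roughly 95% confidence interval: with n reads, the standard error is std / sqrt(n) and the interval is mean +/- 2*se. A numpy sketch with synthetic accuracies (the data here is random and purely illustrative):

    import numpy as np

    acc = np.random.default_rng(0).uniform(0.80, 0.95, size=100)
    mu, std = acc.mean(), acc.std()
    se = std / np.sqrt(len(acc))  # standard error of the mean
    print("avg_acc %.4f s %.4f CI <%.4f, %.4f>" %
          (mu, std, mu - 2 * se, mu + 2 * se))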
Example #5
    def basecall_signal(self,
                        fast5_path,
                        signal,
                        start_pad,
                        write_logits=None,
                        block_size=100000,
                        pad=512,
                        debug=False):
        with self.g.as_default():
            is_training(False, session=self.sess)
        t = perf_counter()

        original_len = signal.shape[1]
        original_signal = signal
        # Zero-pad both ends so every block_size window carries `pad`
        # samples of context on each side.
        signal = np.pad(signal, ((0, 0), (pad, pad), (0, 0)), 'constant')
        chunks = [
            signal[:, i - pad:i + block_size + pad, :]
            for i in range(pad, signal.shape[1] - 2 * pad, block_size)
        ]

        # Sanity check: the trimmed chunks must cover the signal exactly once.
        np.testing.assert_equal(
            np.sum([chunk.shape[1] - 2 * pad for chunk in chunks]),
            original_len)

        if debug:
            np.testing.assert_equal(
                np.concatenate([c[:, pad:-pad, :] for c in chunks], axis=1),
                original_signal)

        logits_all = []
        for chunk in chunks:
            logits = self.sess.run(
                self.logits,
                feed_dict={
                    self.X_batch: chunk,
                    self.X_batch_len: np.array([chunk.shape[1]]),
                    self.block_idx: 0,
                    self.batch_size_var: 1,
                    self.block_size_x_tensor: chunk.shape[1],
                })
            # Trim the padded context back out of the downsampled logits.
            cut = pad // self.shrink_factor
            logits = logits[cut:-cut]
            logits_all.append(logits)

        logits = np.concatenate(logits_all)
        assert logits.shape[0] in [
            original_len // self.shrink_factor,
            original_len // self.shrink_factor + 1,
        ]
        basecalled = self.sess.run(
            self.dense_pred,
            feed_dict={
                self.logits: logits,
                self.X_batch_len:
                    np.array([self.shrink_factor * logits.shape[0]]),
            }).ravel()

        if write_logits:
            with h5py.File(fast5_path, 'a') as h5:
                h5_path = 'Analyses/MinCall/Logits'
                logits = np.squeeze(logits, axis=1)
                try:
                    h5.create_dataset(h5_path, data=logits)
                except RuntimeError:
                    # Dataset already exists; overwrite it.
                    del h5[h5_path]
                    h5.create_dataset(h5_path, data=logits)

                h5[h5_path].attrs['start_pad'] = start_pad
                h5[h5_path].attrs['model_logdir'] = self.log_dir
                h5[h5_path].attrs['run_id'] = self.run_id
                h5[h5_path].attrs['in_data_classname'] = type(
                    self.in_data).__name__
                h5[h5_path].attrs['shrink_factor'] = self.shrink_factor

                fname = os.path.join(self.log_dir, 'model_hyperparams.json')
                with open(fname, 'r') as f:
                    hyper = json.load(f)
                    for k, v in hyper.items():
                        h5[h5_path].attrs[k] = v

        self.logger.debug("Basecalled %s in %.3f", fast5_path,
                          perf_counter() - t)

        basecalled = "".join(util.decode(basecalled))
        return basecalled
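
The chunking scheme in basecall_signal can be checked in isolation: pad the signal, slide block_size-wide windows with pad samples of context on each side, trim the context, and the pieces concatenate back to the original. A numpy-only toy sketch (block_size and pad are small illustrative values, not the method's defaults):

    import numpy as np

    signal = np.arange(25, dtype=np.float32)[None, :, None]  # (1, T, 1)
    block_size, pad = 10, 3
    padded = np.pad(signal, ((0, 0), (pad, pad), (0, 0)), 'constant')
    chunks = [
        padded[:, i - pad:i + block_size + pad, :]
        for i in range(pad, padded.shape[1] - 2 * pad, block_size)
    ]
    # Trimming `pad` from both ends of every chunk reconstructs the input.
    rebuilt = np.concatenate([c[:, pad:-pad, :] for c in chunks], axis=1)
    np.testing.assert_equal(rebuilt, signal)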