def summarize(self, write_example=True):
    """
    Writes gradient and activation summaries plus the current edit distance
    to the train writer. Optionally decodes one example and prints the
    network output next to the target.
    """
    with self.g.as_default():
        is_training(False, session=self.sess)
        iter_step = self.get_global_step()
        fetches = [
            self.loss, self.grad_summ, self.activation_summ,
            self.edit_distance
        ]
        if write_example:
            fetches.append(self.pred)
        vals = self._rnn_roll(fetches)
        edit_distance = np.mean(vals[3]).item()

        summaries = [x[0] for x in vals[1:3]]
        for summary in summaries:
            self.train_writer.add_summary(summary, global_step=iter_step)
        self.train_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(
                    tag="train/edit_distance", simple_value=edit_distance)
            ]),
            global_step=iter_step)

        if write_example:
            out_net = [decode_sparse(x)[0] for x in vals[-1]]
            target = self.decode_target(0, pad=self.block_size_y)
            for a, b in zip(out_net, target):
                print("OUTPUT:", a)
                print("TARGET:", b)
                print('----')
def train_minibatch(self, log_every=100, trace_every=10000):
    """
    Runs one training minibatch and performs all required bookkeeping
    (summaries, timing statistics, logging).

    Args:
        log_every: how many steps between logging to train_writer and stdout.
        trace_every: how many steps between full-trace executions (slow).
            When trace_every > 0, a full trace is also always performed at
            global step 25.

    Returns:
        The total loss for this minibatch.
    """
    with self.g.as_default():
        is_training(True, session=self.sess)
        tt = perf_counter()
        iter_step, _ = self.sess.run([self.global_step, self.load_train])
        self.sess.run([self.inc_gs])
        # Exponential moving averages of dequeue, between-batch, and batch times.
        self.dequeue_time = 0.8 * self.dequeue_time + 0.2 * (perf_counter() - tt)

        if trace_every > 0:
            if (iter_step > 0 and iter_step % trace_every == 0) or iter_step == 25:
                self.trace_level = tf.RunOptions.FULL_TRACE

        self.bbt = 0.8 * self.bbt + 0.2 * (perf_counter() - self.bbt_clock)
        tt = perf_counter()
        fetches = [self.train_op, self.loss, self.reg]
        vals = self._rnn_roll(add_fetch=fetches, timeline_suffix="ctc_loss")
        self.batch_time = 0.8 * self.batch_time + 0.2 * (perf_counter() - tt)
        self.bbt_clock = perf_counter()

        loss = np.sum(vals[1]).item()
        reg_loss = np.sum(vals[2]).item()
        self.train_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="train/loss", simple_value=loss),
                tf.Summary.Value(tag="train/reg_loss", simple_value=reg_loss),
                tf.Summary.Value(tag="input/batch_time", simple_value=self.batch_time),
                tf.Summary.Value(tag="input/dequeue_time", simple_value=self.dequeue_time),
                tf.Summary.Value(tag="input/between_batch_time", simple_value=self.bbt),
            ]),
            global_step=iter_step)

        if iter_step % log_every == 0:
            train_summ, y_len = self.sess.run([self.train_summ, self.Y_len])
            self.train_writer.add_summary(train_summ, global_step=iter_step)
            self.logger.info(
                "%4d loss %6.3f reg_loss %6.3f bt %.3f, bbt %.3f, avg_y_len %.3f dequeue %.3fs",
                iter_step, loss, reg_loss, self.batch_time, self.bbt,
                np.mean(y_len), self.dequeue_time)

        self.trace_level = tf.RunOptions.NO_TRACE
        return loss
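# A minimal usage sketch for the training loop (not part of the original file).
# It assumes a fully constructed model instance `model` with its input queues
# already running; the surrounding trainer script and step counts are illustrative.
#
#     for step in range(100000):
#         loss = model.train_minibatch(log_every=100)
#         if step > 0 and step % 1000 == 0:
#             model.summarize(write_example=False)
#             avg_loss, avg_edit = model.run_validation(num_batches=5)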
def run_validation(self, num_batches=5):
    """
    Runs validation.

    Args:
        num_batches: number of batches to run validation on.

    Returns:
        A tuple (average loss, average edit distance) over the validation set.
    """
    with self.g.as_default():
        is_training(False, session=self.sess)
        losses = []
        reg_losses = []
        edit_distances = []
        for _ in range(num_batches):
            _, iter_step, summ = self.sess.run(
                [self.load_test, self.global_step, self.test_queue_size])
            self.test_writer.add_summary(summ, global_step=iter_step)
            loss, reg_loss, edit_distance = self._rnn_roll(
                add_fetch=[self.loss, self.reg, self.edit_distance],
                timeline_suffix="ctc_val_loss")
            losses.append(np.sum(loss))
            reg_losses.append(np.sum(reg_loss))
            edit_distances.append(edit_distance)

        avg_loss = np.mean(losses).item()
        avg_reg_loss = np.mean(reg_losses).item()
        avg_edit_distance = np.mean(edit_distances).item()
        self.test_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="train/loss", simple_value=avg_loss),
                tf.Summary.Value(tag="train/reg_loss", simple_value=avg_reg_loss),
                tf.Summary.Value(tag="train/edit_distance", simple_value=avg_edit_distance),
            ]),
            global_step=iter_step)
        self.logger.info(
            "%4d validation loss %6.3f edit_distance %.3f in %.3fs",
            iter_step, avg_loss, avg_edit_distance,
            perf_counter() - self.bbt_clock)
        self.bbt_clock = perf_counter()
        return avg_loss, avg_edit_distance
def run_validation_full(self, frac, verbose=False, fasta_out_dir=None, ref=None):
    """
    Runs full validation on the test set using the whole sequence length.

    Args:
        frac: fraction of the test set to evaluate; if an int, the number of
            test cases to evaluate.
        verbose: whether to log per-read details.
        fasta_out_dir: optional directory to write basecalled fasta output to.
        ref: optional reference to align against.

    Returns:
        A dict with mean ('mu'), standard deviation ('std'), and standard
        error ('se') for edit distance and accuracy.
    """
    test_root = os.path.join(input_readers.root_dir_default, 'test')
    items = np.array(glob(os.path.join(test_root, '*.fa')))
    np.random.shuffle(items)
    if isinstance(frac, int):
        items = items[:frac]
    else:
        items = items[:int(len(items) * frac)]
    self.logger.info("Running validation on %d examples", len(items))

    nedit = np.zeros(items.shape, dtype=np.float32)
    acc = np.zeros(items.shape, dtype=np.float32)
    n = len(acc)
    total_time = 1e-6
    total_bases_read = 0
    cigar_stat = defaultdict(int)

    with self.g.as_default():
        is_training(False, session=self.sess)
        pbar = tqdm(items)
        for i, fast5_path in enumerate(pbar):
            t = perf_counter()
            nedit[i], acc[i], read_len, cigar_read_stat = self.get_aligement(
                fast5_path,
                input_readers.find_ref(fast5_path),
                verbose=verbose,
                fasta_out_dir=fasta_out_dir,
                ref=ref)
            total_time += perf_counter() - t
            total_bases_read += read_len
            for k in ['=', 'X', 'I', 'D']:
                cigar_stat[k] += cigar_read_stat[k]

            # Running mean, std, and ~95% confidence interval (mean +/- 2 SE).
            mu_edit, mu_acc = np.mean(nedit[:i + 1]), np.mean(acc[:i + 1])
            std_edit, std_acc = np.std(nedit[:i + 1]), np.std(acc[:i + 1])
            se_edit, se_acc = std_edit / np.sqrt(i + 1), std_acc / np.sqrt(i + 1)
            pbar.set_postfix(
                stat="avg edit %.4f s %.4f CI <%.4f, %.4f> avg_acc %.4f s %.4f CI <%.4f, %.4f> %.2f bps"
                % (mu_edit, std_edit, mu_edit - 2 * se_edit,
                   mu_edit + 2 * se_edit, mu_acc, std_acc,
                   mu_acc - 2 * se_acc, mu_acc + 2 * se_acc,
                   total_bases_read / total_time))

        mu_edit, mu_acc = np.mean(nedit), np.mean(acc)
        std_edit, std_acc = np.std(nedit), np.std(acc)
        se_edit = std_edit / np.sqrt(len(nedit))
        se_acc = std_acc / np.sqrt(len(nedit))
        self.logger.info(
            "step: %d [samples %d] avg edit %.4f s %.4f CI <%.4f, %.4f> avg_acc %.4f s %.4f CI <%.4f, %.4f> %.3f bps",
            self.get_global_step(), n, mu_edit, std_edit,
            mu_edit - 2 * se_edit, mu_edit + 2 * se_edit, mu_acc, std_acc,
            mu_acc - 2 * se_acc, mu_acc + 2 * se_acc,
            total_bases_read / total_time)

        total_cigar_elements = np.sum(list(cigar_stat.values()))
        for k in ['=', 'X', 'I', 'D']:
            self.logger.info("Step %d; %s %.2f%%", self.get_global_step(), k,
                             100 * cigar_stat[k] / total_cigar_elements)
        self.logger.info("Speed = %.3f bases per second",
                         total_bases_read / total_time)
        return {
            'edit': {
                'mu': mu_edit.item(),
                'std': std_edit.item(),
                'se': se_edit.item()
            },
            'accuracy': {
                'mu': mu_acc.item(),
                'std': std_acc.item(),
                'se': se_acc.item()
            }
        }
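# Illustrative call (hypothetical script context, not part of the original file):
# evaluate 25 random test reads and print the accuracy with a ~95% interval.
#
#     stats = model.run_validation_full(25, fasta_out_dir=model.log_dir)
#     print(stats['accuracy']['mu'], '+/-', 2 * stats['accuracy']['se'])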
def basecall_singal(self,
                    fast5_path,
                    signal,
                    start_pad,
                    write_logits=None,
                    block_size=100000,
                    pad=512,
                    debug=False):
    """
    Basecalls a raw signal and returns the decoded sequence.

    The signal is padded with `pad` samples on each side, split into chunks of
    `block_size` samples (plus `pad` of context on both ends), run through the
    network, and the per-chunk logits are trimmed and concatenated before
    decoding. If `write_logits` is set, the logits and run metadata are stored
    under 'Analyses/MinCall/Logits' in the fast5 file.
    """
    with self.g.as_default():
        is_training(False, session=self.sess)
        t = perf_counter()
        original_len = signal.shape[1]
        original_signal = signal
        signal = np.pad(signal, ((0, 0), (pad, pad), (0, 0)), 'constant')

        # Overlapping chunks: each carries `pad` samples of context on both sides.
        chunks = [
            signal[:, i - pad:i + block_size + pad, :]
            for i in range(pad, signal.shape[1] - 2 * pad, block_size)
        ]
        np.testing.assert_equal(
            np.sum([chunk.shape[1] - 2 * pad for chunk in chunks]),
            original_len)
        if debug:
            np.testing.assert_equal(
                np.concatenate([c[:, pad:-pad, :] for c in chunks], axis=1),
                original_signal)

        logits_all = []
        for chunk in chunks:
            logits = self.sess.run(
                self.logits,
                feed_dict={
                    self.X_batch: chunk,
                    self.X_batch_len: np.array(chunk.shape[1]).reshape([1]),
                    self.block_idx: 0,
                    self.batch_size_var: 1,
                    self.block_size_x_tensor: chunk.shape[1]
                })
            # Trim the padded context from both ends of this chunk's logits.
            cut = pad // self.shrink_factor
            logits = logits[cut:-cut]
            logits_all.append(logits)
        logits = np.concatenate(logits_all)
        assert logits.shape[0] in [
            original_len // self.shrink_factor,
            original_len // self.shrink_factor + 1,
        ]

        basecalled = self.sess.run(
            self.dense_pred,
            feed_dict={
                self.logits: logits,
                self.X_batch_len:
                np.array([self.shrink_factor * logits.shape[0]]).reshape([1])
            }).ravel()

        if write_logits:
            with h5py.File(fast5_path, 'a') as h5:
                h5_path = 'Analyses/MinCall/Logits'
                logits = np.squeeze(logits, axis=1)
                try:
                    h5.create_dataset(h5_path, data=logits)
                except RuntimeError:
                    # Dataset already exists; overwrite it.
                    del h5[h5_path]
                    h5.create_dataset(h5_path, data=logits)
                h5[h5_path].attrs['start_pad'] = start_pad
                h5[h5_path].attrs['model_logdir'] = self.log_dir
                h5[h5_path].attrs['run_id'] = self.run_id
                h5[h5_path].attrs['in_data_classname'] = type(self.in_data).__name__
                h5[h5_path].attrs['shrink_factor'] = self.shrink_factor
                fname = os.path.join(self.log_dir, 'model_hyperparams.json')
                with open(fname, 'r') as f:
                    hyper = json.load(f)
                for k, v in hyper.items():
                    h5[h5_path].attrs[k] = v

        self.logger.debug("Basecalled %s in %.3fs", fast5_path, perf_counter() - t)
        basecalled = "".join(util.decode(basecalled))
        return basecalled
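# A minimal usage sketch for basecalling a single read (not part of the original
# file). Assumptions: `model` is a constructed instance, `raw` is a signal array
# of shape [1, T, 1] already preprocessed the way the input pipeline expects, and
# `start_pad` matches the trimming applied upstream; these names are illustrative.
#
#     seq = model.basecall_singal("/path/to/read.fast5", raw, start_pad=0,
#                                 write_logits=False)
#     print(len(seq), seq[:50])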