def focal_loss(logits, labels, alpha, gamma=2, name='focal_loss'):
  """Focal loss for multi-class classification.

  :param logits: A float32 tensor of shape [batch_size, num_class].
  :param labels: An int32 tensor of shape [batch_size, num_class] (one-hot)
      or [batch_size] (sparse class indices).
  :param alpha: A 1D float32 tensor for the focal loss alpha hyper-parameter.
  :param gamma: A scalar for the focal loss gamma hyper-parameter.
  :param name: Name under which the loss is written as a summary scalar.
  Returns: A scalar tensor: the focal loss summed over the batch.
  """
  if len(labels.shape) == 1:
    # Sparse labels: expand to one-hot so they align with the logits layout.
    labels = tf.one_hot(labels, logits.shape[-1])
  labels = tf.to_float(labels)
  # `dim` is the deprecated alias of `axis` in tf.nn.softmax.
  y_pred = tf.nn.softmax(logits, axis=-1)
  # Cross entropy scaled by the focal modulation alpha * (1 - p)^gamma.
  per_class_loss = -labels * tf.log(y_pred)
  per_class_loss *= alpha * ((1 - y_pred)**gamma)
  loss = tf.reduce_sum(per_class_loss)
  if tf.executing_eagerly():
    tf.contrib.summary.scalar(name, loss)
  else:
    tf.summary.scalar(name, loss)
  return loss
def clip_gradients(self, grads_and_vars, clip_ratio, multitask=False):
  """Clip gradients by global norm and log per-variable histograms.

  :param grads_and_vars: Iterable of (gradient, variable) pairs. A `zip`
      object is accepted; in that case a `zip` object is returned so the
      caller's expected type is preserved.
  :param clip_ratio: Global-norm clipping threshold; falsy disables clipping.
  :param multitask: If True, clip with
      tf.contrib.opt.clip_gradients_by_global_norm instead of
      tf.clip_by_global_norm.
  Returns: The (possibly clipped) (gradient, variable) pairs, plus summary
      side effects (histograms per gradient and the global norm scalar).
  """
  is_zip_obj = False
  if isinstance(grads_and_vars, zip):
    # A zip object is single-use; materialize it so we can iterate twice.
    grads_and_vars = list(grads_and_vars)
    is_zip_obj = True

  with tf.variable_scope('grad'):
    for grad, var in grads_and_vars:
      if grad is not None:
        # var.name ends with ':0'; strip the output index for the tag.
        if tf.executing_eagerly():
          tf.contrib.summary.histogram(var.name[:-2], grad)
        else:
          tf.summary.histogram(var.name[:-2], grad)
      else:
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        logging.debug('%s gradient is None', var.name)

  # No clipping requested.
  if not clip_ratio:
    if is_zip_obj:
      # Hand back a zip object, matching what the caller passed in.
      grads, variables = zip(*grads_and_vars)
      grads_and_vars = zip(grads, variables)
    return grads_and_vars

  if multitask:
    grad_and_var_clipped, global_norm = \
        tf.contrib.opt.clip_gradients_by_global_norm(
            grads_and_vars, clip_ratio)
  else:
    gradients, variables = zip(*grads_and_vars)
    clipped, global_norm = tf.clip_by_global_norm(gradients, clip_ratio)
    grad_and_var_clipped = zip(clipped, variables)

  if tf.executing_eagerly():
    tf.contrib.summary.scalar('gradient/global_norm', global_norm)
  else:
    tf.summary.scalar('gradient/global_norm', global_norm)
  return grad_and_var_clipped
def summary_writer(logdir,
                   graph=None,
                   max_queue=10,
                   flush_secs=120,
                   graph_def=None,
                   filename_suffix=None,
                   session=None,
                   name=None):
  """Create a summary writer appropriate for the current execution mode."""
  if not tf.executing_eagerly():
    # Graph mode: the classic FileWriter takes flush interval in seconds.
    return tf.summary.FileWriter(logdir, graph, max_queue, flush_secs,
                                 graph_def, filename_suffix, session)
  # Eager mode: the contrib writer expects the flush interval in millis.
  return tf.contrib.summary.create_file_writer(
      logdir,
      max_queue=max_queue,
      flush_millis=flush_secs * 1000,
      filename_suffix=filename_suffix,
      name=name)
def test_fbank(self):
  """Fbank op yields a rank-3 feature tensor for a sample wav file."""
  wav_path = str(Path(PACKAGE_OPS_DIR).joinpath('data/sm1_cln.wav'))
  with self.cached_session(use_gpu=False, force_gpu=False):
    input_data, sample_rate = ReadWav.params().instantiate()(wav_path)
    fbank_config = {
        'window_length': 0.025,
        'output_type': 1,
        'frame_length': 0.010,
        'snip_edges': True
    }
    fbank_op = Fbank.params(fbank_config).instantiate()
    fbank_test = fbank_op(input_data, sample_rate)
    self.assertEqual(tf.rank(fbank_test).eval(), 3)
    # Dump a small corner of the features for manual inspection.
    if tf.executing_eagerly():
      print(fbank_test.numpy()[0:2, 0:6, 0])
    else:
      print(fbank_test.eval()[0:2, 0:6, 0])
def testSomeTFSymbols(self):
  """A handful of TF 1.x compat symbols are still reachable."""
  # Graph mode is the default here; eager must be off.
  self.assertFalse(tf.executing_eagerly())
  for symbol in (tf.logging, tf.flags):
    self.assertIsNotNone(symbol)
  self.assertIs(tf.Defun, function.Defun)
def image(name, tensor, max_images=3):
  """Write an image summary, dispatching on execution mode."""
  if not tf.executing_eagerly():
    tf.summary.image(name, tensor, max_outputs=max_images)
  else:
    tf.contrib.summary.image(name, tensor, max_images=max_images)
def audio(name, tensor, sample_rate, max_outputs=3):
  """Write an audio summary, dispatching on execution mode."""
  if not tf.executing_eagerly():
    tf.summary.audio(name, tensor, sample_rate, max_outputs)
  else:
    tf.contrib.summary.audio(name, tensor, sample_rate, max_outputs)
def text(name, tensor):
  """Write a text summary, dispatching on execution mode."""
  if not tf.executing_eagerly():
    tf.summary.text(name, tensor)
  else:
    tf.contrib.summary.text(name, tensor)
def histogram(name, values):
  """Write a histogram summary, dispatching on execution mode."""
  if not tf.executing_eagerly():
    tf.summary.histogram(name, values)
  else:
    tf.contrib.summary.histogram(name, values)
def scalar(name, value):  # pylint: disable=redefined-outer-name
  """Write a scalar summary, dispatching on execution mode."""
  if not tf.executing_eagerly():
    tf.summary.scalar(name, value)
  else:
    tf.contrib.summary.scalar(name, value)
def flush(writer=None, name=None):
  """Flush pending summaries for the given writer."""
  if not tf.executing_eagerly():
    tf.summary.flush(writer, name)  # pylint: disable=no-member
  else:
    tf.contrib.summary.flush(writer, name)