def mean_per_class_accuracy(y_true, y_pred, num_classes=3):
    """Return the unweighted mean of the per-class recall.

    For each class id ``c`` in ``range(num_classes)``:
    - count the samples whose true class (argmax of ``y_true``) is ``c``;
    - take the fraction of them whose predicted class (argmax of ``y_pred``)
      is also ``c``.
    These per-class values are then averaged.

    Args:
        y_true: targets with classes on the last axis.
        y_pred: predicted scores with classes on the last axis.
        num_classes: number of class ids (0 .. num_classes-1) to average over.
            The default of 3 reproduces the previous hard-coded behaviour.

    Returns:
        Scalar tensor holding the mean per-class recall.

    Note:
        A class with no true samples in the batch contributes 0, because its
        denominator is clamped to 1. It is still counted in the average.
    """
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # Whether a sample is correct does not depend on the class being scored.
    correct = K.cast(K.equal(class_id_true, class_id_preds), 'int32')

    total = 0
    for class_id in range(num_classes):
        # Masking on the *true* label gives recall; mask on class_id_preds
        # instead to get per-class precision.
        accuracy_mask = K.cast(K.equal(class_id_true, class_id), 'int32')
        class_acc = K.sum(correct * accuracy_mask) / K.maximum(K.sum(accuracy_mask), 1)
        total = total + class_acc
    return total / num_classes
def adj_loss(y_true, y_pred):
    """Categorical cross-entropy plus a mean-absolute adjacency penalty.

    Builds adjacency matrices with ``SimpleAdjMat(batch_size, pixel_distance)``
    for (prediction, truth) and for (truth, prediction). Each is reduced to an
    L1 norm along axis 1, and ``lambda_loss * mean(|pred_norm - true_norm|)``
    is added to the categorical cross-entropy.

    Side effects: stores the two loss terms in the module globals
    ``adj_loss_value`` and ``categ_loss`` (presumably for logging).

    NOTE(review): ``k.argmax`` has no gradient, so presumably only
    ``categ_loss`` drives training. Confirm this is intended.
    """
    global adj_loss_value
    global categ_loss

    Y_pred = k.argmax(y_pred)
    Y_true = k.argmax(y_true)
    adj = SimpleAdjMat(batch_size, pixel_distance)
    adj_pred = tf.norm(tensor=adj.adj_mat(Y_pred, Y_true), ord=1, axis=1)
    adj_true = tf.norm(tensor=adj.adj_mat(Y_true, Y_pred), ord=1, axis=1)

    # L1 penalty: mean absolute difference. k.abs gives the same value as
    # sqrt(d * d) without the needless square/sqrt round trip.
    adj_loss_value = lambda_loss * k.mean(k.abs(adj_pred - adj_true))
    categ_loss = k.categorical_crossentropy(y_true, y_pred)
    return adj_loss_value + categ_loss
def fractional_accuracy(y_true, y_pred):
    """Fraction of argmax predictions that agree with the argmax targets.

    Hits and misses are each summed over the last axis that remains after
    argmax, then averaged. The result is hits / (hits + misses).
    """
    true_ids = K.argmax(y_true, axis=-1)
    pred_ids = K.argmax(y_pred, axis=-1)

    def _mean_count(condition):
        # Sum the 0/1 indicator over the last axis, then average the rest.
        return K.mean(K.sum(K.cast(condition, tf.float32), axis=-1))

    hits = _mean_count(K.equal(true_ids, pred_ids))
    misses = _mean_count(K.not_equal(true_ids, pred_ids))
    return hits / (hits + misses)
def captcha_metric(y_true, y_pred):
    """Per-character accuracy for flattened captcha outputs.

    Both tensors are reshaped into rows of length ``alphabet`` (a name read
    from the enclosing scope), one row per character. Returns the mean rate
    at which the row-wise argmax of the prediction matches the target.
    """
    rows_pred = K.reshape(y_pred, (-1, alphabet))
    rows_true = K.reshape(y_true, (-1, alphabet))
    same_char = K.equal(K.argmax(rows_pred, axis=1), K.argmax(rows_true, axis=1))
    return K.mean(K.cast(same_char, 'float32'))
def class2_accuracy(y_true, y_pred):
    """Precision for class id 2.

    Considers only the samples predicted as class 2 (argmax of ``y_pred``)
    and returns the fraction whose true class (argmax of ``y_true``) is also 2.
    Returns 0 when nothing in the batch is predicted as class 2, because the
    denominator is clamped to 1.
    """
    true_ids = K.argmax(y_true, axis=-1)
    pred_ids = K.argmax(y_pred, axis=-1)
    predicted_as_2 = K.cast(K.equal(pred_ids, 2), 'int32')
    hits = K.cast(K.equal(true_ids, pred_ids), 'int32') * predicted_as_2
    return K.sum(hits) / K.maximum(K.sum(predicted_as_2), 1)
def mask_acc(y_true, y_pred):
    """Argmax accuracy that skips samples whose true class id is 0.

    Returns matches / kept, with the number of kept samples clamped to at
    least 1 so that an all-zero batch yields 0 rather than NaN.
    """
    true_ids = K.argmax(y_true, axis=-1)
    pred_ids = K.argmax(y_pred, axis=-1)
    keep = K.cast(K.not_equal(true_ids, 0), "int32")
    hits = keep * K.cast(K.equal(true_ids, pred_ids), "int32")
    return K.sum(hits) / K.maximum(K.sum(keep), 1)
def accuracy_ignore_padding(y_true, y_pred):
    """Argmax accuracy over positions whose target class id is not 0.

    Class id 0 (argmax of ``y_true``) is treated as padding and excluded.
    Returns 0.0 instead of NaN when a batch contains only padding.
    """
    prediction = backend.argmax(y_pred, axis=-1)
    target = backend.argmax(y_true, axis=-1)
    # Build the padding mask once and weight the matches with it. This avoids
    # the private ``backend.array_ops.boolean_mask``.
    not_padding = backend.cast(backend.not_equal(target, 0), backend.floatx())
    matches = backend.cast(backend.equal(prediction, target), backend.floatx()) * not_padding
    return backend.sum(matches) / backend.maximum(backend.sum(not_padding), 1.0)
def fallback_metric(self, y_true, y_pred):
    """Hierarchy-aware accuracy with a fallback for low-confidence samples.

    For each sample:
    - If the top score in ``y_pred`` is above ``self.threshold``, the sample
      is correct when argmax(y_pred) == argmax(y_true).
    - Otherwise, take the two highest-scoring classes and look up their rows
      in ``K.transpose(self.hierarchy.A)``. Combine those rows with a bitwise
      AND. The sample is correct when the combined row and the true class's
      row have a dot product greater than 1.

    Returns the batch mean of these 0/1 scores as float32.
    """
    # NOTE(review): rows of A^T are treated as 0/1 "lineage" vectors, i.e.
    # the ancestors of each node. That reading comes from the author's
    # comments, not from visible code. Confirm against the hierarchy class.
    top_score = K.max(y_pred, axis=-1)
    threshold_tensor = tf.fill(tf.shape(top_score), self.threshold)
    is_confident = tf.cast(top_score > threshold_tensor, tf.int32)
    is_unsure = tf.cast(top_score <= threshold_tensor, tf.int32)

    true_idx = K.argmax(y_true, -1)
    pred_idx = K.argmax(y_pred, -1)
    confident_hit = tf.cast(math_ops.equal(true_idx, pred_idx), tf.int32)

    lineage_table = K.transpose(self.hierarchy.A)
    _, top2 = tf.math.top_k(y_pred, k=2)
    top2_lineages = K.cast(tf.gather(lineage_table, top2), tf.int32)
    # Nodes shared by the lineages of both top-2 predictions.
    shared_lineage = tf.bitwise.bitwise_and(top2_lineages[:, 0], top2_lineages[:, 1])
    true_lineage = K.cast(tf.gather(lineage_table, K.argmax(y_true)), tf.int32)
    # Number of shared nodes that also lie on the true lineage.
    overlap_score = K.batch_dot(shared_lineage, true_lineage)
    unsure_hit = tf.squeeze(tf.cast(overlap_score > 1, tf.int32))

    correct = (tf.math.multiply(is_confident, confident_hit)
               + tf.math.multiply(is_unsure, unsure_hit))
    return K.mean(K.cast(correct, tf.float32))
def sparse_accuracy_ignoring_last_label(y_true, y_pred):
    """Sparse argmax accuracy that skips positions carrying the ignore label.

    ``y_true`` holds integer class ids. The id equal to ``nb_classes`` (one
    past the last class scored by ``y_pred``) marks positions to ignore; every
    other position counts towards the accuracy of argmax(y_pred).

    Returns 0.0 when every position is ignored, instead of 0/0 = NaN.
    """
    nb_classes = K.int_shape(y_pred)[-1]
    y_pred = K.reshape(y_pred, (-1, nb_classes))
    # tf.cast works in TF1 and TF2; tf.to_int32 / tf.to_float are TF1-only.
    y_true = K.one_hot(tf.cast(K.flatten(y_true), tf.int32), nb_classes + 1)
    # The extra last one-hot column is set only for the ignore label.
    legal_labels = tf.logical_not(tf.cast(y_true[:, -1], tf.bool))
    y_true = y_true[:, :-1]
    hits = tf.logical_and(
        legal_labels,
        K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)))
    n_legal = K.sum(tf.cast(legal_labels, tf.float32))
    return K.sum(tf.cast(hits, tf.float32)) / K.maximum(n_legal, 1.0)
def multitask_accuracy(y_true, y_pred):
    """Batch-wise argmax accuracy over all columns except the last.

    The last column of both ``y_true`` and ``y_pred`` is dropped (presumably
    it belongs to another task head; confirm). For each sample, the argmax
    of the remaining columns is compared. Returns the batch mean of the
    matches.
    """
    matches = K.equal(K.argmax(y_true[:, :-1], axis=-1),
                      K.argmax(y_pred[:, :-1], axis=-1))
    # Cast explicitly instead of relying on K.mean's implicit bool handling.
    return K.mean(K.cast(matches, K.floatx()))
def f(y_true, y_pred):
    """Precision for ``interested_class_id`` (a name from the enclosing scope).

    Considers only the samples predicted as that class and returns the
    fraction that are truly that class. The denominator is clamped to 1.
    """
    true_ids = K.argmax(y_true, axis=-1)
    pred_ids = K.argmax(y_pred, axis=-1)
    # Masking on pred_ids gives precision; mask on true_ids for recall.
    selected = K.cast(K.equal(pred_ids, interested_class_id), 'int32')
    hits = K.cast(K.equal(true_ids, pred_ids), 'int32') * selected
    return K.sum(hits) / K.maximum(
        K.sum(selected), 1)
def custom_accuracy(y_true, y_pred):
    """Accuracy for a two-head output where both heads must be right.

    Columns 0-5 and 6-11 of ``y_true`` / ``y_pred`` are two separate 6-way
    classifications. A sample counts as correct only when the argmax of both
    column groups matches its target.
    """
    def _head_correct(start):
        # Per-sample argmax match for the 6 columns beginning at ``start``.
        true_ids = backend.argmax(y_true[:, start:start + 6], axis=-1)
        pred_ids = backend.argmax(y_pred[:, start:start + 6], axis=-1)
        return backend.equal(true_ids, pred_ids)

    # Use logical_and: equal(a, b) would also be True when both heads are
    # wrong. Slicing (rather than tf.slice with a fixed batch_size) handles
    # a final batch that is smaller than batch_size.
    both_correct = tf.logical_and(_head_correct(0), _head_correct(6))
    return backend.mean(backend.cast(both_correct, backend.floatx()))
def mzz_metrics(y_true, y_pred):
    '''Top-1 accuracy that also accepts multi-hot targets.

    :param y_true: a tensor of shape (batch_size, num_class) with the ground
        truth labels (one-hot, or multi-hot when a sample has several labels)
    :param y_pred: a tensor of shape (batch_size, num_class) with the
        predicted scores
    :return: scalar float tensor: the fraction of samples whose top predicted
        class equals the highest-scored class among their true labels
    '''
    pred_indices = K.argmax(y_pred, axis=1)
    # Zero out the scores of non-target classes. The argmax is then the
    # best-scored true label, so any label of a multi-hot target can match.
    # NOTE(review): if y_pred is exactly 0 on every true label, this picks
    # index 0.
    real = K.argmax(y_true * y_pred, axis=1)
    # Mean of matches replaces the TF1-only tf.count_nonzero arithmetic.
    return K.mean(K.cast(K.equal(pred_indices, real), K.floatx()))
def adj_loss(y_true, y_pred):
    """Categorical cross-entropy plus per-class adjacency penalties.

    For class ids 0 and 1, builds ``SingleAdjMat(batch_size, <id>,
    pixel_distance)`` adjacency matrices for (prediction, truth) and for
    (truth, prediction). Each is reduced to an L1 norm along axis 1. The loss
    adds ``lambda_loss * (mse_0 + mse_1)``, where ``mse_i`` is the mean
    squared difference of the two norms for class id i.

    Side effects: stores the two loss terms in the module globals
    ``adj_loss_value`` and ``categ_loss``.
    """
    global adj_loss_value
    global categ_loss

    Y_pred = k.argmax(y_pred)
    Y_true = k.argmax(y_true)
    adj0 = SingleAdjMat(batch_size, 0, pixel_distance)
    adj1 = SingleAdjMat(batch_size, 1, pixel_distance)

    def _row_l1(adj, first, second):
        # L1 norm along axis 1 of adj.adj_mat(first, second).
        return tf.norm(tensor=adj.adj_mat(first, second), ord=1, axis=1)

    adj_pred0 = _row_l1(adj0, Y_pred, Y_true)
    adj_pred1 = _row_l1(adj1, Y_pred, Y_true)
    adj_true0 = _row_l1(adj0, Y_true, Y_pred)
    adj_true1 = _row_l1(adj1, Y_true, Y_pred)

    # k.mean without an axis already reduces to a scalar, so no k.sum is
    # needed afterwards.
    tmp0 = k.mean(k.square(adj_pred0 - adj_true0))
    tmp1 = k.mean(k.square(adj_pred1 - adj_true1))
    adj_loss_value = lambda_loss * (tmp0 + tmp1)
    categ_loss = k.categorical_crossentropy(y_true, y_pred)
    return adj_loss_value + categ_loss
def predict_img_bgr_prob(self, img_bgr, img_size=448):
    """Predict the rotation angle of a BGR image.

    Despite its name, this returns an angle, not probabilities. The result
    is ``argmax(probs) * 90 % 360``, a multiple of 90 in [0, 360).

    Steps: convert the image to RGB, resize it to ``(img_size, img_size)``,
    apply ``preprocess_input``, run ``self.model.predict`` on a one-image
    batch, and decode the first result.
    """
    size = (img_size, img_size)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_rgb = cv2.resize(img_rgb, size)
    batch = preprocess_input(np.expand_dims(img_rgb, axis=0))
    predictions = self.model.predict(batch)
    probs = predictions[0]
    # model.predict returns NumPy arrays, so decode with np.argmax. K.argmax
    # would add a TF op on every call, and int() of it fails in graph mode.
    angle = int(np.argmax(probs, axis=-1)) * 90 % 360
    return angle
def _class_weights_map_fn(*data):
    """Convert `class_weight` to `sample_weight`.

    Looks up each sample's class in ``class_weight_tensor`` (from the
    enclosing scope). If sample weights are present, the result is
    multiplied into them; otherwise it becomes the sample weights.
    """
    x, y, sw = unpack_x_y_sample_weight(data)

    if nest.is_sequence(y):
        raise ValueError(
            "`class_weight` is only supported for Models with a single output.")
    if y.shape.rank > 2:
        raise ValueError(
            "`class_weight` not supported for 3+ dimensional targets.")

    def _from_columns():
        return backend.argmax(y, axis=1)

    def _from_ids():
        return math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64)

    # Rank-2 targets with more than one column are presumably one-hot, so
    # take the argmax; otherwise the targets already hold class ids.
    multi_column = y.shape.rank == 2 and backend.shape(y)[1] > 1
    y_classes = smart_cond.smart_cond(multi_column, _from_columns, _from_ids)

    cw = array_ops.gather_v2(class_weight_tensor, y_classes)
    if sw is None:
        return x, y, cw

    cw = math_ops.cast(cw, sw.dtype)
    sw, cw = expand_1d((sw, cw))
    # `class_weight` and `sample_weight` are multiplicative.
    return x, y, sw * cw
def filter_boxes(box_confidence, boxes, box_class_probs, threshold=.6):
    """
    Keep the YOLO boxes whose best class score reaches ``threshold``.

    Each box is scored per class as object confidence times class
    probability. The best class and its score are selected, and boxes
    scoring below the threshold are dropped.

    Parameters
    ----------
    box_confidence: tf.Tensor
        Probability estimate for whether each box contains any object.
    boxes: tf.Tensor
        Bounding boxes
    box_class_probs: tf.Tensor
        Probability distribution estimate for each box over class labels.
    threshold: float
        Minimum best-class score for a box to be kept.

    Returns
    -------
    _boxes, scores, classes: (tf.Tensor, tf.Tensor, tf.Tensor)
    """
    per_class_score = box_confidence * box_class_probs
    best_class = K.argmax(per_class_score, axis=-1)
    best_score = K.max(per_class_score, axis=-1)
    keep = best_score >= threshold
    return (
        tf.boolean_mask(boxes, keep),
        tf.boolean_mask(best_score, keep),
        tf.boolean_mask(best_class, keep),
    )
def linear_unbin_layer(tnsr):
    """Map a binned (one-hot / softmax) tensor back to a value in [-1, 1].

    The argmax bin index ``b`` is rescaled as ``b * (2 / 14) - 1``. This
    presumably assumes 15 bins indexed 0..14, so bin 0 maps to -1 and
    bin 14 maps to 1. Confirm against the binning side.
    """
    bin_width = K.constant((2 / 14), dtype='float32')
    offset = K.constant(1, dtype='float32')
    b = K.cast(K.argmax(tnsr), dtype='float32')
    # Apply the bin width. Without it the result is b - 1, a bin index
    # shifted by one, rather than a value in [-1, 1].
    return b * bin_width - offset
def accuracy(y_true, y_pred):
    """Per-sample sparse accuracy.

    ``y_true`` holds class ids. If it has the same rank as ``y_pred`` (for
    example shape (num_samples, 1)), its last axis is squeezed first. Returns
    a ``K.floatx()`` tensor of 1.0 / 0.0 per sample.
    """
    if K.ndim(y_true) == K.ndim(y_pred):
        y_true = K.squeeze(y_true, -1)
    predicted_labels = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    return K.cast(K.equal(y_true, predicted_labels), K.floatx())
def compute_mask_loss(boxes, masks, annotations, masks_target, width, height, iou_threshold=0.5, mask_size=(28, 28)):
    """Binary cross-entropy mask loss for boxes that match an annotation.

    Steps:
    1. Match each box to the annotation it overlaps most (``overlap``).
    2. Keep only boxes whose best IoU is >= ``iou_threshold``.
    3. Crop each matched annotation's target mask to the box and resize it
       to ``mask_size``.
    4. Select the predicted mask channel for the annotation's label, which
       is read from ``annotations[:, 4]``.
    5. Return the summed binary cross-entropy divided by
       (num_kept * mask_height * mask_width), with the divisor clamped to 1.

    Args:
        boxes: boxes indexed as (x1, y1, x2, y2) in pixel coordinates
            (columns 0-3).
        masks: predicted masks. Presumably (num_boxes, H, W, num_classes),
            given the (0, 3, 1, 2) transpose and the per-label gather; confirm.
        annotations: annotation rows; column 4 holds the class label.
        masks_target: target masks, presumably (num_annotations, H_img, W_img).
            A channel axis is added before crop_and_resize; confirm.
        width, height: image size used to normalise box coordinates.
        iou_threshold: minimum IoU for a box to contribute to the loss.
        mask_size: output size of the cropped target masks.

    Returns:
        Scalar mask loss.
    """
    # iou presumably has shape (num_boxes, num_annotations).
    iou = overlap(boxes, annotations)
    argmax_overlaps_inds = K.argmax(iou, axis=1)
    max_iou = K.max(iou, axis=1)

    # keep boxes whose best IoU is >= iou_threshold (default 0.5)
    indices = tf.where(K.greater_equal(max_iou, iou_threshold))
    boxes = tf.gather_nd(boxes, indices)
    masks = tf.gather_nd(masks, indices)
    argmax_overlaps_inds = K.cast(tf.gather_nd(argmax_overlaps_inds, indices), 'int32')
    labels = K.cast(K.gather(annotations[:, 4], argmax_overlaps_inds), 'int32')

    # normalise boxes to [0, 1] in the (y1, x1, y2, x2) order that
    # tf.image.crop_and_resize expects
    # NOTE(review): only y2/x2 get 1 subtracted; presumably they are
    # exclusive max coordinates. Confirm.
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    boxes = K.stack([
        y1 / (K.cast(height, dtype=K.floatx()) - 1),
        x1 / (K.cast(width, dtype=K.floatx()) - 1),
        (y2 - 1) / (K.cast(height, dtype=K.floatx()) - 1),
        (x2 - 1) / (K.cast(width, dtype=K.floatx()) - 1),
    ], axis=1)

    # crop and resize masks_target
    # append a fake channel dimension
    masks_target = K.expand_dims(masks_target, axis=3)
    # box_ind = matched annotation index, so each crop is taken from the
    # target mask of the annotation that box was matched to
    masks_target = tf.image.crop_and_resize(
        masks_target,
        boxes,
        argmax_overlaps_inds,
        mask_size
    )
    masks_target = masks_target[:, :, :, 0]  # remove fake channel dimension

    # gather the predicted masks using the annotation label
    # (move the class axis next to the box axis, then pick [box, label])
    masks = tf.transpose(masks, (0, 3, 1, 2))
    label_indices = K.stack([tf.range(K.shape(labels)[0]), labels], axis=1)
    masks = tf.gather_nd(masks, label_indices)

    # compute mask loss, averaged per pixel over all kept masks
    mask_loss = K.binary_crossentropy(masks_target, masks)
    normalizer = K.shape(masks)[0] * K.shape(masks)[1] * K.shape(masks)[2]
    # clamp to 1 so an empty selection does not divide by zero
    normalizer = K.maximum(K.cast(normalizer, K.floatx()), 1)
    mask_loss = K.sum(mask_loss) / normalizer

    return mask_loss
def predict_img_bgr(self, img_bgr):
    """Predict the rotation angle of a BGR image.

    The image is passed through ``rotate_img_with_bound`` with 90 and then
    -90 degrees. It is then scored with ``self.predict_img_bgr_prob``, which
    returns an angle. The angle is normalised with ``self.format_angle``.
    """
    # NOTE(review): +90 followed by -90 looks like a round trip. It is kept
    # because rotate_img_with_bound may change size/padding; confirm intent.
    img_bgr = rotate_img_with_bound(img_bgr, 90)
    img_bgr = rotate_img_with_bound(img_bgr, -90)
    # predict_img_bgr_prob already returns an int angle (argmax * 90 % 360),
    # not a probability vector. Calling K.argmax on that scalar would fail.
    angle = int(self.predict_img_bgr_prob(img_bgr)) % 360
    angle = self.format_angle(angle)
    return angle
def adj_loss(y_true, y_pred):
    """Categorical cross-entropy plus an L1 adjacency penalty.

    Builds ``WeightedAdjMat(batch_size, pixel_distance)`` matrices for
    (prediction, truth) and for (truth, prediction) and reduces each to an L1
    norm along axis 1. Then adds ``lambda_loss * mean(|pred - true|)``.

    Side effects: stores the two loss terms in the module globals
    ``adj_loss_value`` and ``categ_loss``.
    """
    global adj_loss_value
    global categ_loss

    pred_ids = k.argmax(y_pred)
    true_ids = k.argmax(y_true)
    adj = WeightedAdjMat(batch_size, pixel_distance)
    pred_norm = tf.norm(tensor=adj.adj_mat(pred_ids, true_ids), ord=1, axis=1)
    true_norm = tf.norm(tensor=adj.adj_mat(true_ids, pred_ids), ord=1, axis=1)

    # L1: mean absolute difference of the two norms.
    adj_loss_value = lambda_loss * k.mean(k.abs(pred_norm - true_norm))
    categ_loss = k.categorical_crossentropy(y_true, y_pred)
    return adj_loss_value + categ_loss
def sparse_categorical_accuracy(y_true, y_pred):
    """Per-pixel accuracy for semantic image segmentation.

    This exists because the built-in Keras accuracy metrics do not handle
    the tensor shapes used here.

    Args:
        y_true: float32 tensor of true labels, shape
            (-1, img_height * img_width).
        y_pred: float32 tensor of softmax probabilities, shape
            (-1, img_height * img_width, nb_classes).

    Returns:
        ``K.floatx()`` tensor holding 1.0 where the argmax prediction equals
        the label and 0.0 elsewhere.
    """
    predicted = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    hits = K.equal(y_true, predicted)
    return K.cast(hits, K.floatx())
def process_yolo_layer_output(feats, anchors, num_classes, input_shape, image_shape):
    """Turn one YOLO conv output into flat boxes, best scores and class ids.

    The per-class score is box confidence times class probability. Each box
    keeps its best class and that class's score. Boxes are rescaled to the
    original image size and flattened to shape (-1, 4).
    """
    box_xy, box_wh, box_confidence, box_class_probs = yolo_head(
        feats, anchors, num_classes, input_shape)
    per_class_score = box_confidence * box_class_probs
    box_classes = K.reshape(K.argmax(per_class_score, axis=-1), [-1])
    highest_box_scores = K.reshape(K.max(per_class_score, axis=-1), [-1])
    scaled = scale_boxes_to_original_image_size(box_xy, box_wh, image_shape)
    return K.reshape(scaled, [-1, 4]), highest_box_scores, box_classes
def adj_loss(y_true, y_pred):
    """Categorical cross-entropy plus a squared adjacency penalty.

    Builds ``adj_mat_func(batch_size)`` matrices for (prediction, truth) and
    for (truth, prediction) and reduces each to an L1 norm along axis 1. Then
    adds ``lambda_loss * mean((pred - true) ** 2)``.

    Side effects: stores the two loss terms in the module globals
    ``adj_loss_value`` and ``categ_loss``.
    """
    global adj_loss_value
    global categ_loss

    pred_ids = k.argmax(y_pred)
    true_ids = k.argmax(y_true)
    adj = adj_mat_func(batch_size)
    pred_norm = tf.norm(tensor=adj.adj_mat(pred_ids, true_ids), ord=1, axis=1)
    true_norm = tf.norm(tensor=adj.adj_mat(true_ids, pred_ids), ord=1, axis=1)

    # L2: mean squared difference of the two norms.
    diff = pred_norm - true_norm
    adj_loss_value = lambda_loss * k.mean(diff * diff)
    categ_loss = k.categorical_crossentropy(y_true, y_pred)
    return adj_loss_value + categ_loss
def predict(self, data: List[str] = None, model_path: str = None):
    """Run prediction on ``data``, loading or training a model if needed.

    If no model is loaded:
    - with ``model_path``, load the model from it and reset the tokenizer
      with ``Tokenizer.from_vocab()``;
    - without it, build and train a new model.

    Prints, for each sample, the argmax index of each per-character score
    vector.

    Returns:
        A list with one list of int indices per sample. The method used to
        return None, so callers that ignore the result are unaffected.
    """
    import numpy as np  # only needed to decode the NumPy predictions

    if self.model is None and model_path is not None:
        print(f"Loading model from {model_path}.")
        self.model = load_model(model_path)
        self.data_sequence.tokenizer = Tokenizer.from_vocab()
    elif self.model is None:
        print("No model file provided. Training new model.")
        self.build()
        self.train()

    pred_sequence = PredictionSequence(self.data_sequence, data)
    predictions = self.model.predict_generator(pred_sequence, steps=len(pred_sequence))

    decoded = []
    for index, sample in enumerate(pred_sequence.samples):
        # predictions is a NumPy array. K.argmax would build TF tensors and
        # print tensor objects instead of the indices.
        prediction = [int(np.argmax(char, axis=-1)) for char in predictions[index]]
        print(f"Predicted for sample {sample}: {prediction}")
        decoded.append(prediction)
    return decoded
def cal_si_snr_with_pit(source, estimate_source, padding):
    """Scale-invariant SNR under the best speaker permutation (PIT).

    :param source: reference signals, shape [bs, C, time] (C speakers)
    :param estimate_source: estimated signals, same shape as ``source``
        (asserted)
    :param padding: passed to ``get_mask`` to build a mask that zeroes padded
        samples (exact format not visible here; confirm)
    :return: tuple of
        - max_snr: [bs], mean SI-SNR in dB over speakers for the best
          permutation;
        - perms: [C!, C, C], one-hot encoding of all speaker permutations;
        - max_snr_idx: [bs, C!], one-hot selection of the best permutation
          (not an integer index).

    NOTE(review): if callers pass NumPy arrays, the ``*=`` lines below modify
    them in place; TF tensors are simply rebound.
    """
    assert source.shape == estimate_source.shape
    # spk1 must be a static int: it feeds range()/one_hot depth below
    bs, spk1, time = source.shape

    mask = get_mask(source, padding)
    source *= mask
    estimate_source *= mask

    # remove the DC offset of every signal (scale-invariance requires
    # zero-mean signals)
    mean_target = k.mean(source, axis=[2], keepdims=True)
    mean_estimate = k.mean(estimate_source, axis=[2], keepdims=True)
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate

    # re-mask so padded positions stay zero after the mean subtraction
    zero_mean_target *= mask  # [bs, spkt, time]
    zero_mean_estimate *= mask  # [bs, spke, time]

    # <s_hat_i, s_j> for every (estimate i, target j) pair
    pair_wise_dot = tf.matmul(zero_mean_estimate, zero_mean_target, transpose_b=True)  # [bs, spke, spkt]
    s_target_energy = k.sum(zero_mean_target**2, axis=-1, keepdims=True) + eps
    s_target_energy = s_target_energy[:, tf.newaxis, :, :]  # [bs, 1, spkt, 1]

    # s_target_energy = k.sum(s_target**2, axis=-1, keepdims=True) + eps # [bs, spkt, spke, 1]
    # projection of each estimate onto each target:
    # s_target = <s', s>s / ||s||^2  -> [bs, spke, spkt, time]
    pair_wise_proj = pair_wise_dot[
        ..., tf.newaxis] * zero_mean_target[:, tf.newaxis, :, :] / s_target_energy
    e_noise = zero_mean_estimate[:, :, tf.newaxis, :] - pair_wise_proj
    # SI-SNR in dB for every (estimate, target) pair -> [bs, spke, spkt]
    pair_wise_si_snr = k.sum(pair_wise_proj**2, axis=-1) / (k.sum(e_noise**2, axis=-1) + eps)
    pair_wise_si_snr = 10 * tf.math.log(pair_wise_si_snr + eps) / tf.math.log(10.)

    # permutations, [C!, C]
    perms = tf.cast(list(permutations(range(spk1))), tf.int64)
    length = perms.shape[0]
    perms = tf.one_hot(perms, depth=spk1)
    # perms [C!, C , C]; perms[p, i, j] == 1 iff permutation p pairs estimate i
    # with target j, so the einsum sums each permutation's pairwise SI-SNRs
    snr_set = tf.einsum("bij,pij->bp", pair_wise_si_snr, perms)
    max_snr_idx = k.argmax(snr_set, axis=-1)  # [B,]
    max_snr_idx = tf.one_hot(max_snr_idx, depth=length)
    # select the best permutation's total and average it over speakers
    max_snr = max_snr_idx * snr_set
    max_snr = k.sum(max_snr, axis=-1) / tf.cast(spk1, tf.float32)
    return max_snr, perms, max_snr_idx
def loss(pred):
    """Semi-supervised clustering loss over labelled examples (closure).

    Reads ``len_actual_labels`` and ``self`` from the enclosing scope. It
    groups labelled example indices by gold label, finds which examples are
    assigned to each predicted cluster, accumulates ``predicted_distance``
    over pairs, and divides by the number of predicted clusters.

    NOTE(review): as written, this does not measure what its comments describe:
    - Assuming ``pred`` is 2-D, ``pred_labels`` is 1-D, so each ``cluster``
      from ``tf.where`` has shape (n, 1). ``cluster.shape[1]`` is then 1, the
      inner loops never see ``tens_i != tens_j``, and ``predicted_distance``
      always returns 1.0.
    - ``i``/``j`` passed to ``predicted_distance`` are positions inside the
      gold cluster, not example indices (``labeled_cluster[i]``?).
    - ``range(len(labeled_cluster) + 1)`` goes one past the end and includes
      i == j pairs.
    - The result does not depend on ``pred`` and so gives no gradient.
    - ``cluster[...] == i_tens`` inside ``if`` only works eagerly.
    Confirm the intended definition before relying on this loss.
    """
    # actual_labels[g] = indices of labelled examples with gold label g + 1
    # (falsy labels are treated as unlabelled and skipped)
    actual_labels = [[] for _ in range(len_actual_labels)]
    for i, actual_label in enumerate(self._y_labeled):
        if actual_label:
            actual_labels[int(actual_label) - 1].append(i)
    pred_labels = K.argmax(pred, axis=1)
    # predicted_labels[c] = coordinates (tf.where) of examples predicted as c
    predicted_labels = []
    for i in range(self._n_clusters):
        indices = tf.where(tf.equal(pred_labels, i))
        predicted_labels.append(indices)

    def predicted_distance(i, j):
        # binary metric
        # returns 0. if i and j are found in the same predicted cluster,
        # 1. otherwise (see the review note in the docstring)
        i_tens, j_tens = tf.dtypes.cast(i, tf.float32), tf.dtypes.cast(j, tf.float32)
        for cluster in predicted_labels:
            # y = tf.split(cluster, cluster.shape[1], axis=1)
            for tens_i in range(cluster.shape[1]):
                for tens_j in range(cluster.shape[1]):
                    if tens_i != tens_j:
                        if cluster[tens_i] == i_tens and cluster[tens_j] == j_tens:
                            return tf.dtypes.cast(0., tf.float32)
        return tf.dtypes.cast(1., tf.float32)

    loss = 0.
    for labeled_cluster in actual_labels:
        # for each gold cluster find if its objects are in the same predicted cluster
        for i in range(len(labeled_cluster)):
            for j in range(len(labeled_cluster) + 1):
                loss += predicted_distance(i, j)
    # ss_loss = 1. - 1e1 * loss / (self._batch_size - 1) # 1. - 1e1 * batch_size * sum / nCr(batch_size, 2)
    # ss_loss = loss / sum([len(labeled_cluster) * len(labeled_cluster) for labeled_cluster in actual_labels[1:]])
    # normalise by the number of predicted clusters (== self._n_clusters)
    ss_loss = loss / tf.dtypes.cast(len(predicted_labels), tf.float32)
    return ss_loss
def mean_iou(y_true, y_pred):
    """Mean intersection-over-union over the classes seen in the batch.

    Args:
        y_true: true labels, tensor with shape (-1, num_labels)
        y_pred: predicted label probabilities from a softmax layer, tensor
            with shape (-1, num_labels, num_classes)

    Classes run over ``range(num_classes)``, where ``num_classes`` comes from
    the enclosing scope. A class counts only if its union is non-zero, i.e. it
    appears in ``y_true`` or in argmax(``y_pred``). Returns 0 if no class is
    seen.
    """
    # Use plain constant accumulators. K.variable would create new variables
    # on every call, which bloats the graph and can fail inside tf.function
    # (variables may only be created on the first trace).
    iou_sum = K.constant(0.0)
    seen_classes = K.constant(0.0)
    y_pred_sparse = K.argmax(y_pred, axis=-1)

    for c in range(0, num_classes):
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        pred_c = K.cast(K.equal(y_pred_sparse, c), K.floatx())
        intersect = true_c * pred_c
        union = true_c + pred_c - intersect
        intersect_sum = K.sum(intersect)
        union_sum = K.sum(union)
        # NaN when union_sum == 0, but the switch below never selects it then.
        iou = intersect_sum / union_sum
        union_sum_is_zero = K.equal(union_sum, 0)
        iou_sum = K.switch(union_sum_is_zero, iou_sum, iou_sum + iou)
        seen_classes = K.switch(union_sum_is_zero, seen_classes, seen_classes + 1)

    # seen_classes is 0 only when no valid class occurs in the true or
    # predicted labels; return iou_sum (0) then instead of dividing by zero.
    return K.switch(K.equal(seen_classes, 0), iou_sum, iou_sum / seen_classes)
def call(self, inputs, **kwargs):
    """Mask capsule outputs.

    - Given ``[capsules, mask]`` (a list or tuple of two tensors): multiply
      the capsules by the mask, with a trailing axis added to the mask. When
      ``self.resize_masks`` is set, the mask is first resized bicubically to
      the capsules' spatial size (dims 1 and 2 of a 5-D tensor).
    - Given a single 3-D tensor: build a one-hot mask from the entry along
      axis 1 with the largest L2 norm over the last axis, and return the
      flattened masked tensor.
    - Any other single tensor is returned unchanged.
    """
    if isinstance(inputs, (list, tuple)):
        assert len(inputs) == 2
        capsules, mask = inputs
        _, hei, wid, _, _ = capsules.get_shape()
        if self.resize_masks:
            mask = tf.image.resize_bicubic(mask, (hei.value, wid.value))
        mask = K.expand_dims(mask, -1)
        # NOTE(review): capsules was unpacked as 5-D above, so this ndims == 3
        # branch cannot run; kept to mirror the single-input path.
        if capsules.get_shape().ndims == 3:
            masked = K.batch_flatten(mask * capsules)
        else:
            masked = mask * capsules
    else:
        if inputs.get_shape().ndims == 3:
            # L2 norm of each capsule vector; the longest one is kept.
            x = K.sqrt(K.sum(K.square(inputs), -1))
            mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])
            masked = K.batch_flatten(K.expand_dims(mask, -1) * inputs)
        else:
            masked = inputs

    return masked