class Classifier(Predictor):
  model_name = 'classifier'

  def __init__(self, mark=None, net_type=Feedforward):
    Predictor.__init__(self, mark, net_type)
    # Private attributes
    self._probabilities = TensorSlot(self, 'Probability')
    # TODO: to be deprecated
    # self._evaluation_group = Group(self, self._metric, self._probabilities,
    #                                name='evaluation group')

  @with_graph
  def build(self, optimizer=None, loss='cross_entropy', metric='accuracy',
            batch_metric=None, eval_metric=None, **kwargs):
    Predictor.build(self, optimizer=optimizer, loss=loss, metric=metric,
                    batch_metric=batch_metric, eval_metric=eval_metric,
                    **kwargs)

  def _build(self, optimizer=None, metric=None, **kwargs):
    # TODO: ... do some compromise
    hub.block_validation = True
    # TODO: if the last layer is not a softmax layer, add one to the model
    # if not (isinstance(self.last_function, Activation)
    #         and self.last_function.abbreviation == 'softmax'):
    #   self.add(Activation('softmax'))

    # Call parent's method to build using the default loss function
    # -- cross entropy
    Predictor._build(self, optimizer=optimizer, metric=metric, **kwargs)
    assert self.outputs.activated
    # Plug tensor into probabilities slot
    self._probabilities.plug(self.outputs.tensor)

  @with_graph
  def classify(self, data, batch_size=None, extractor=None,
               return_probs=False, verbose=False):
    probs = self.evaluate(self._probabilities.tensor, data, batch_size,
                          extractor, verbose=verbose)
    if return_probs: return probs

    # TODO: make the data shape clear here and add comments
    if self.input_type is InputTypes.RNN_BATCH:
      preds = [np.argmax(p, axis=-1) for p in probs]
    else:
      preds = np.argmax(probs, axis=-1)
    return preds

  @with_graph
  def evaluate_model(self, data, batch_size=None, extractor=None, **kwargs):
    """This method is a mess."""
    if hub.take_down_confusion_matrix:
      # TODO: (william) please refactor this method
      cm = self.evaluate_pro(
        data, batch_size, verbose=kwargs.get('verbose', False),
        show_class_detail=True, show_confusion_matrix=True)
      # Take down confusion matrix
      from tframe import context
      agent = context.trainer.model.agent
      agent.take_notes('Confusion Matrix on {}:'.format(data.name), False)
      agent.take_notes('\n' + cm.matrix_table().content)
      agent.take_notes('Evaluation Result on {}:'.format(data.name), False)
      agent.take_notes('\n' + cm.make_table().content)
      return cm.accuracy

    # If not necessary, use Predictor's evaluate_model method
    metric_is_accuracy = self.eval_metric.name.lower() == 'accuracy'
    if not metric_is_accuracy:
      return super().evaluate_model(data, batch_size, **kwargs)

    console.show_status('Evaluating classifier on {} ...'.format(data.name))

    acc_slot = self.metrics_manager.get_slot_by_name('accuracy')
    assert isinstance(acc_slot, MetricSlot)
    acc_foreach = acc_slot.quantity_definition.quantities
    results = self.evaluate(acc_foreach, data, batch_size, extractor,
                            verbose=hub.val_progress_bar)
    if self.input_type is InputTypes.RNN_BATCH:
      results = np.concatenate([y.flatten() for y in results])
    accuracy = np.mean(results) * 100

    # Show and return accuracy
    console.supplement('Accuracy on {} is {:.3f}%'.format(data.name, accuracy))
    return accuracy

  @with_graph
  def evaluate_pro(self, data_set, batch_size=None, verbose=False, **kwargs):
    """Evaluate the model and give results calculated based on TP, TN, FP,
    and FN. The 'extractor' argument is not considered because the developer
    forgot what it was used for.

    :param data_set: an instance of DataSet containing at least features
                     and targets
    :param batch_size: set this value if your (G)RAM is not big enough to
                       handle the whole data set
    :param verbose: whether to show status or a progress bar
    :param kwargs: other options, recognized in the 'Get options' block below
    """
    # Get options
    show_confusion_matrix = kwargs.get('show_confusion_matrix', False)
    show_class_detail = kwargs.get('show_class_detail', False)
    export_false = kwargs.get('export_false', False)
    top_k = kwargs.get('export_top_k', 3)

    # Check data_set before getting the model prediction
    assert self.input_type is InputTypes.BATCH
    assert isinstance(data_set, DataSet)
    assert data_set.features is not None and data_set.targets is not None

    # -------------------------------------------------------------------------
    # Calculate predicted classes and corresponding probabilities
    # -------------------------------------------------------------------------
    probs = self.classify(
      data_set, batch_size, return_probs=True, verbose=verbose)
    # This provides the information necessary for image viewer presentation,
    # i.e., the sorted probabilities for each class
    probs_sorted = np.fliplr(np.sort(probs, axis=-1))
    class_sorted = np.fliplr(np.argsort(probs, axis=-1))
    preds = class_sorted[:, 0]
    truths = np.ravel(data_set.dense_labels)

    # Produce confusion matrix
    cm = ConfusionMatrix(
      num_classes=data_set.num_classes,
      class_names=data_set.properties.get(pedia.classes, None))
    cm.fill(preds, truths)

    # Print evaluation results
    if show_confusion_matrix:
      console.show_info('Confusion Matrix:')
      console.write_line(cm.matrix_table(kwargs.get('cell_width', None)))
    console.show_info('Evaluation Result:')
    console.write_line(cm.make_table(
      decimal=4, class_details=show_class_detail))

    # Visualize false set if specified
    if export_false:
      indices = np.argwhere(preds != truths).flatten()
      false_set = data_set[indices]
      false_set.properties[pedia.predictions] = preds[indices]
      false_set.properties[pedia.top_k_label] = class_sorted[indices, :top_k]
      false_set.properties[pedia.top_k_prob] = probs_sorted[indices, :top_k]
      return cm, false_set
    else:
      return cm
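
# --- Usage sketch (illustrative, not from the original source) --------------
# A minimal sketch of driving the Classifier API above, assuming `model` is a
# Classifier whose layers have already been added and `data_set` is a tframe
# DataSet with features and targets; both names are placeholders.
def _classifier_usage_sketch(model, data_set):
  assert isinstance(model, Classifier)
  # build() wires cross-entropy loss and the 'accuracy' metric by default
  model.build(optimizer=tf.train.AdamOptimizer(1e-3))
  # Dense class predictions; pass return_probs=True for raw probabilities
  preds = model.classify(data_set, batch_size=128)
  # Full TP/TN/FP/FN-based report, optionally with a confusion matrix
  cm = model.evaluate_pro(data_set, batch_size=128,
                          show_confusion_matrix=True, show_class_detail=True)
  return preds, cm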
  def _build(self, loss='cross_entropy', optimizer=None, metric=None,
             metric_is_like_loss=True, metric_name='Metric'):
    Feedforward._build(self)
    # Check shapes of branch outputs
    output_shape = self._check_branch_outputs()
    # Initiate targets placeholder
    self._plug_target_in(output_shape)

    # Define output tensors
    for i, output in enumerate(self.branch_outputs):
      if i == 0 or not self.strict_residual:
        output_tensor = output
      else:
        output_tensor = output + self._boutputs[i - 1].tensor
      slot = TensorSlot(self, name='output_{}'.format(i + 1))
      slot.plug(output_tensor)
      self._boutputs.append(slot)

    # Define loss tensors
    loss_function = losses.get(loss)
    with tf.name_scope('Loss'):
      for i, output in enumerate(self._boutputs):
        assert isinstance(output, TensorSlot)
        loss_tensor = loss_function(self._targets.tensor, output.tensor)
        slot = TensorSlot(self, name='loss_{}'.format(i + 1))
        slot.plug(loss_tensor)
        self._losses.append(slot)
        # Add summary
        if hub.summary:
          name = 'loss_sum_{}'.format(i + 1)
          sum_slot = SummarySlot(self, name)
          sum_slot.plug(tf.summary.scalar(name, loss_tensor))
          self._train_step_summaries.append(sum_slot)

    # Define metric tensors
    metric_function = metrics.get(metric)
    if metric_function is not None:
      with tf.name_scope('Metric'):
        for i, output in enumerate(self._boutputs):
          assert isinstance(output, TensorSlot)
          metric_tensor = metric_function(
            self._targets.tensor, output.tensor)
          slot = Metric(self, name='metric_{}'.format(i + 1))
          slot.plug(metric_tensor, as_loss=metric_is_like_loss,
                    symbol='{}{}'.format(metric_name, i + 1))
          self._metrics.append(slot)
          # Add summary
          if hub.summary:
            name = 'metric_sum_{}'.format(i + 1)
            sum_slot = SummarySlot(self, name)
            sum_slot.plug(tf.summary.scalar(name, metric_tensor))
            self._validation_summaries.append(sum_slot)

    # Define train step
    self._define_train_step(optimizer)

    # Define groups
    # TODO: when training a single branch with summaries on, an error may
    # .. occur because the summaries of higher branches cannot get values
    act_summaries = []
    if hub.monitor_preact:
      slot = SummarySlot(self, 'act_summary')
      slot.plug(tf.summary.merge(
        tf.get_collection(pedia.train_step_summaries)))
      act_summaries.append(slot)
    self._update_group = Group(
      self, *self._losses, *self._train_steps, *self._train_step_summaries,
      *act_summaries)
    self._validate_group = Group(
      self, *self._metrics, *self._validation_summaries)
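
# --- Illustrative sketch (not from the original source) ---------------------
# The `strict_residual` path in `_build` above stacks branch outputs
# cumulatively: slot i holds its branch's raw output plus the accumulated
# tensor of slot i - 1. Plain numpy stands in for TensorFlow here.
import numpy as np

def stack_branch_outputs(branch_outputs, strict_residual=True):
  """branch_outputs: list of arrays, one raw output per branch."""
  stacked = []
  for i, output in enumerate(branch_outputs):
    if i == 0 or not strict_residual:
      stacked.append(output)
    else:
      # Cumulative residual sum, mirroring
      # `output + self._boutputs[i - 1].tensor`
      stacked.append(output + stacked[i - 1])
  return stacked

# Three scalar "branch outputs" yield slots holding 1, 1 + 2 = 3, 3 + 3 = 6
print(stack_branch_outputs(
  [np.float32(1), np.float32(2), np.float32(3)]))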
class Predictor(Feedforward, Recurrent):
  """A feedforward or a recurrent predictor"""
  model_name = 'Predictor'

  def __init__(self, mark=None, net_type=Feedforward):
    """
    Construct a Predictor
    :param mark: model mark
    :param net_type: one of (Feedforward, Recurrent)
    """
    if net_type not in (Feedforward, Recurrent):
      raise TypeError('!! Unknown net type')
    self.master = net_type
    # Attributes
    self._targets = TensorSlot(self, 'targets')
    self._val_targets = TensorSlot(self, 'val_targets')
    # Call parent's constructor
    net_type.__init__(self, mark)

  # region : Properties

  @property
  def affix(self):
    if self.master is Feedforward: return 'forward'
    assert self.master is Recurrent
    return 'recurrent'

  @property
  def description(self):
    return '{}: {}'.format(self.master.__name__, self.structure_string())

  @property
  def input_type(self):
    if self.master is Feedforward: return InputTypes.BATCH
    else: return InputTypes.RNN_BATCH

  # endregion : Properties

  # region : Build

  @with_graph
  def build_as_regressor(
      self, optimizer=None, loss='euclid', metric='rms_ratio',
      metric_name='Err %'):
    self.build(optimizer=optimizer, loss=loss, metric=metric,
               metric_name=metric_name)

  @with_graph
  def build(self, optimizer=None, loss='euclid', metric=None,
            batch_metric=None, eval_metric=None, **kwargs):
    context.metric_name = 'unknown'  # TODO: to be deprecated
    Model.build(self, optimizer=optimizer, loss=loss, metric=metric,
                batch_metric=batch_metric, eval_metric=eval_metric, **kwargs)

  def _build(self, optimizer=None, loss='euclid', metric=None, **kwargs):
    # For some RNN predictors, their last step is counted as the only output,
    # e.g. RNNs for sequence classification tasks
    last_only = kwargs.pop('last_only', False)

    if hub.use_gather_indices:
      # Initiate gather_indices placeholder
      assert context.gather_indices is None
      context.gather_indices = tf.placeholder(
        tf.int32, [None, 2], 'gather_indices')
      tf.add_to_collection(pedia.default_feed_dict, context.gather_indices)

    # Get loss quantity before building
    self.loss_quantity = losses.get(loss, last_only)
    # This is for calculating loss inside a while-loop
    context.loss_function = self.loss_quantity.function

    # Call parent's build method to link the network.
    # Usually the output tensor has been plugged into the Model._outputs slot
    self.master._build(self)
    assert self.outputs.activated

    # Initiate targets and add them to the collection
    self._plug_target_in(self.outputs.shape_list)

    # Define loss. Some TensorFlow APIs only support computing loss from
    # logits
    with tf.name_scope('Loss'):
      loss_tensor = self.loss_quantity(
        self._targets.tensor, self.outputs.tensor)
      # TODO: with or without regularization loss?
      if hub.summary:
        tf.add_to_collection(
          pedia.train_step_summaries,
          tf.summary.scalar('loss_sum', loss_tensor))
      # Try to add extra loss which is calculated by the corresponding net
      # .. regularization loss is included
      if self.extra_loss is not None:
        loss_tensor = tf.add(loss_tensor, self.extra_loss)
      # Plug in
      self.loss.plug(loss_tensor, quantity_def=self.loss_quantity)

    # <monitor_grad_step_02: register loss and plug grad_ops in>
    if hub.monitor_weight_grads:
      context.monitor.register_loss(loss_tensor)
      self.grads_slot.plug(context.monitor.grad_ops_list)
      self._update_group.add(self.grads_slot)

    # Monitor general tensors (currently only activation is included)
    if hub.export_activations and context.monitor.tensor_fetches:
      self.general_tensor_slot.plug(context.monitor.tensor_fetches)
      self._update_group.add(self.general_tensor_slot)

    # Initialize metric
    if metric is not None:
      checker.check_type_v2(metric, (str, Quantity))
      # Create a placeholder for val_targets if necessary.
      # Common targets will be plugged into the val_target slot by default
      self._plug_val_target_in(kwargs.get('val_targets', None))

      with tf.name_scope('Metric'):
        self._metrics_manager.initialize(
          metric, last_only, self._val_targets.tensor,
          self._outputs.tensor, **kwargs)

    # Merge summaries
    self._merge_summaries()
    # Define train step
    self._define_train_step(optimizer)

  def _plug_target_in(self, shape):
    dtype = hub.dtype
    if hub.target_dim != 0: shape[-1] = hub.target_dim
    # If target is a sparse label
    if hub.target_dtype is not None: dtype = hub.target_dtype
    # if hub.target_dim == 1: dtype = tf.int32  # TODO: X
    # Handle the recurrent situation
    if self._targets.tensor is not None:
      # The targets placeholder has been plugged in the
      # Recurrent._build_while_free method
      assert self.master == Recurrent
      return
    target_tensor = tf.placeholder(dtype, shape, name='targets')
    self._targets.plug(target_tensor, collection=pedia.default_feed_dict)

  def _plug_val_target_in(self, val_targets):
    if val_targets is None:
      self._val_targets = self._targets
    else:
      assert isinstance(val_targets, str)
      val_target_tensor = tf.placeholder(
        hub.dtype, self.outputs.shape_list, name=val_targets)
      self._val_targets.plug(
        val_target_tensor, collection=pedia.default_feed_dict)

  # endregion : Build

  # region : Train

  def update_model(self, data_batch, **kwargs):
    if self.master is Feedforward:
      return Feedforward.update_model(self, data_batch, **kwargs)
    # Update recurrent model
    feed_dict = self._get_default_feed_dict(data_batch, is_training=True)
    results = self._update_group.run(feed_dict)
    self.set_buffers(results.pop(self._state_slot), is_training=True)

    # TODO: BETA
    assert not hub.use_rtrl
    if hub.use_rtrl:
      self._gradient_buffer_array = results.pop(self._grad_buffer_slot)
    if hub.test_grad:
      delta = results.pop(self.grad_delta_slot)
      _ = None
    return results

  # endregion : Train

  # region : Public Methods

  @with_graph
  def predict(self, data, batch_size=None, extractor=None, **kwargs):
    return self.evaluate(
      self._outputs.tensor, data, batch_size, extractor, **kwargs)

  @with_graph
  def evaluate_model(self, data, batch_size=None, dynamic=False, **kwargs):
    """The word `evaluate` in this method name means something different
    from its use in the `self.evaluate` method. Here only eval_metric will
    be evaluated and the result will be printed on the terminal."""
    # Check metric
    if not self.eval_metric.activated:
      raise AssertionError('!! Metric not defined')

    # Do dynamic evaluation if necessary
    if dynamic:
      from tframe.trainers.eval_tools.dynamic_eval import (
        DynamicEvaluator as de)
      de.dynamic_evaluate(self, data, kwargs.get('val_set', None),
                          kwargs.get('delay', None))
      return

    # If hub.val_progress_bar is True, this message will be shown in
    # the model.evaluate method
    if not hub.val_progress_bar:
      console.show_status('Evaluating on {} ...'.format(data.name))
    # Use the val_progress_bar option here temporarily
    result = self.validate_model(
      data, batch_size, allow_sum=False,
      verbose=hub.val_progress_bar)[self.eval_metric]
    console.supplement('{} = {}'.format(
      self.eval_metric.symbol, hub.decimal_str(result, hub.val_decimals)))

    return result

  # endregion : Public Methods

  # region : Private Methods

  def _evaluate_batch(self, fetch_list, data_batch, **kwargs):
    return self.master._evaluate_batch(
      self, fetch_list, data_batch, **kwargs)

  def _get_default_feed_dict(self, batch, is_training):
    return self.master._get_default_feed_dict(self, batch, is_training)
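
# --- Usage sketch (illustrative, not from the original source) --------------
# How the Predictor front door is typically driven, assuming `model` is a
# Predictor whose layers have already been added; `test_set` is a
# hypothetical tframe DataSet.
def _predictor_usage_sketch(model, test_set):
  assert isinstance(model, Predictor)
  # Wires the 'euclid' loss and the 'rms_ratio' metric symbolized as 'Err %'
  model.build_as_regressor(optimizer=tf.train.AdamOptimizer(1e-4))
  # Raw network outputs for a whole data set (batched under the hood)
  outputs = model.predict(test_set, batch_size=64)
  # Computes, prints, and returns the registered eval_metric
  err = model.evaluate_model(test_set, batch_size=64)
  return outputs, err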
class Classifier(Predictor):
  model_name = 'classifier'

  def __init__(self, mark=None, net_type=Feedforward):
    Predictor.__init__(self, mark, net_type)
    # Private attributes
    self._probabilities = TensorSlot(self, 'Probability')
    # TODO: to be deprecated
    # self._evaluation_group = Group(self, self._metric, self._probabilities,
    #                                name='evaluation group')

  @with_graph
  def build(self, optimizer=None, loss='cross_entropy', metric='accuracy',
            batch_metric=None, eval_metric=None, **kwargs):
    Predictor.build(
      self, optimizer=optimizer, loss=loss, metric=metric,
      batch_metric=batch_metric, eval_metric=eval_metric, **kwargs)

  def _build(self, optimizer=None, metric=None, **kwargs):
    # TODO: ... do some compromise
    hub.block_validation = True
    # TODO: if the last layer is not a softmax layer, add one to the model
    # if not (isinstance(self.last_function, Activation)
    #         and self.last_function.abbreviation == 'softmax'):
    #   self.add(Activation('softmax'))

    # Call parent's method to build using the default loss function
    # -- cross entropy
    Predictor._build(self, optimizer=optimizer, metric=metric, **kwargs)
    assert self.outputs.activated
    # Plug tensor into probabilities slot
    self._probabilities.plug(self.outputs.tensor)

  @with_graph
  def evaluate_model(self, data, batch_size=None, extractor=None,
                     export_false=False, **kwargs):
    # If not necessary, use Predictor's evaluate_model method
    metric_is_accuracy = self.eval_metric.name.lower() == 'accuracy'
    if not export_false or not metric_is_accuracy:
      result = super().evaluate_model(data, batch_size, **kwargs)
      if metric_is_accuracy: result *= 100
      return result

    console.show_status('Evaluating classifier on {} ...'.format(data.name))

    acc_slot = self.metrics_manager.get_slot_by_name('accuracy')
    assert isinstance(acc_slot, MetricSlot)
    acc_foreach = acc_slot.quantity_definition.quantities
    results = self.evaluate(acc_foreach, data, batch_size, extractor,
                            verbose=hub.val_progress_bar)
    if self.input_type is InputTypes.RNN_BATCH:
      results = np.concatenate([y.flatten() for y in results])
    accuracy = np.mean(results) * 100

    # Show accuracy
    console.supplement('Accuracy on {} is {:.3f}%'.format(data.name, accuracy))

    # The export_false option is valid for images only
    if export_false and accuracy < 100.0:
      assert self.input_type is InputTypes.BATCH
      assert isinstance(data, DataSet)
      assert data.features is not None and data.targets is not None

      top_k = hub.export_top_k if hub.export_top_k > 0 else 3

      probs = self.classify(data, batch_size, extractor, return_probs=True)
      probs_sorted = np.fliplr(np.sort(probs, axis=-1))
      class_sorted = np.fliplr(np.argsort(probs, axis=-1))
      preds = class_sorted[:, 0]

      false_indices = np.argwhere(results == 0).flatten()
      false_preds = preds[false_indices]
      probs_sorted = probs_sorted[false_indices, :top_k]
      class_sorted = class_sorted[false_indices, :top_k]

      false_set = data[false_indices]
      false_set.properties[pedia.predictions] = false_preds
      false_set.properties[pedia.top_k_label] = class_sorted
      false_set.properties[pedia.top_k_prob] = probs_sorted

      from tframe.data.images.image_viewer import ImageViewer
      vr = ImageViewer(false_set)
      vr.show()

    # Return accuracy
    return accuracy

  @with_graph
  def classify(self, data, batch_size=None, extractor=None,
               return_probs=False, verbose=False):
    probs = self.evaluate(
      self._probabilities.tensor, data, batch_size, extractor,
      verbose=verbose)
    if return_probs: return probs
    if self.input_type is InputTypes.RNN_BATCH:
      preds = [np.argmax(p, axis=-1) for p in probs]
    else:
      preds = np.argmax(probs, axis=-1)
    return preds
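
# --- Usage sketch (illustrative, not from the original source) --------------
# The export_false branch above only fires when the eval metric is accuracy
# and some samples are misclassified; it then pops tframe's ImageViewer over
# the false set with top-k labels and probabilities attached. `image_model`
# and `test_images` are hypothetical placeholders.
def _export_false_sketch(image_model, test_images):
  assert isinstance(image_model, Classifier)
  acc = image_model.evaluate_model(
    test_images, batch_size=64, export_false=True)
  # If acc < 100.0, an ImageViewer window opens showing each misclassified
  # image together with its predicted and true labels
  return acc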
class Classifier(Predictor):
  def __init__(self, mark=None, net_type=Feedforward):
    Predictor.__init__(self, mark, net_type)
    # Private attributes
    self._probabilities = TensorSlot(self, 'Probability')
    self._evaluation_group = Group(
      self, self._metric, self._probabilities, name='evaluation group')

  @with_graph
  def build(self, optimizer=None, metric='accuracy', **kwargs):
    Predictor.build(
      self, optimizer=optimizer, loss='cross_entropy', metric=metric,
      metric_is_like_loss=False, metric_name='Accuracy')

  def _build(self, optimizer=None, metric=None, **kwargs):
    # TODO: ... do some compromise
    hub.block_validation = True
    # If the last layer is not a softmax layer, add one to the model
    if not (isinstance(self.last_function, Activation)
            and self.last_function.abbreviation == 'softmax'):
      self.add(Activation('softmax'))

    # Call parent's method to build using the default loss function
    # -- cross entropy
    Predictor._build(self, optimizer=optimizer, metric=metric, **kwargs)
    assert self.outputs.activated
    # Plug tensor into probabilities slot
    self._probabilities.plug(self.outputs.tensor)

  def evaluate_model(self, data, batch_size=None, extractor=None,
                     export_false=False, **kwargs):
    # Feed data set into model and get results
    false_sample_list = []
    false_label_list = []
    true_label_list = []
    num_samples = 0

    console.show_status('Evaluating classifier ...')
    for batch in self.get_data_batches(data, batch_size):
      assert isinstance(batch, DataSet) and batch.targets is not None
      # Get predictions
      preds = self._classify_batch(batch, extractor)
      # Get true labels in dense format
      if batch.targets.shape[-1] > 1:
        targets = batch.targets.reshape(-1, batch.targets.shape[-1])
      else:
        targets = batch.targets
      num_samples += len(targets)
      true_labels = misc.convert_to_dense_labels(targets)
      if len(true_labels) < len(preds):
        assert len(true_labels) == 1
        true_labels = np.concatenate((true_labels,) * len(preds))
      # Select false samples
      false_indices = np.argwhere(preds != true_labels)
      if false_indices.size == 0: continue
      features = batch.features
      if self.input_type is InputTypes.RNN_BATCH:
        features = np.reshape(features, [-1, *features.shape[2:]])
      false_indices = np.reshape(false_indices, false_indices.size)
      false_sample_list.append(features[false_indices])
      false_label_list.append(preds[false_indices])
      true_label_list.append(true_labels[false_indices])

    # Concatenate
    if len(false_sample_list) > 0:
      false_sample_list = np.concatenate(false_sample_list)
      false_label_list = np.concatenate(false_label_list)
      true_label_list = np.concatenate(true_label_list)

    # Show accuracy
    accuracy = (num_samples - len(false_sample_list)) / num_samples * 100
    console.supplement('Accuracy on {} is {:.2f}%'.format(data.name, accuracy))

    # Try to export false samples
    if export_false and accuracy < 100:
      false_set = DataSet(features=false_sample_list,
                          targets=true_label_list)
      if hasattr(data, 'properties'):
        false_set.properties = data.properties
      false_set.data_dict[pedia.predictions] = false_label_list
      from tframe.data.images.image_viewer import ImageViewer
      vr = ImageViewer(false_set)
      vr.show()

  def classify(self, data, batch_size=None, extractor=None):
    predictions = []
    for batch in self.get_data_batches(data, batch_size):
      preds = self._classify_batch(batch, extractor)
      if isinstance(preds, int): preds = [preds]
      predictions.append(preds)
    return np.concatenate(predictions)

  def _classify_batch(self, batch, extractor):
    assert isinstance(batch, DataSet) and batch.features is not None
    batch = self._sanity_check_before_use(batch)
    feed_dict = self._get_default_feed_dict(batch, is_training=False)
    probs = self._probabilities.run(feed_dict)
    if self.input_type is InputTypes.RNN_BATCH:
      assert len(probs.shape) == 3
      probs = np.reshape(probs, (-1, probs.shape[2]))
    if extractor is None:
      preds = misc.convert_to_dense_labels(probs)
    else:
      preds = extractor(probs)
    return preds
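
# --- Usage sketch (illustrative, not from the original source) --------------
# The `extractor` hook in `_classify_batch` above maps a [num_samples,
# num_classes] probability matrix to predictions; when it is None, dense
# (argmax) labels are used. A hypothetical thresholded extractor, marking
# low-confidence rows with -1:
import numpy as np

def confident_argmax(probs, threshold=0.9):
  preds = np.argmax(probs, axis=-1)  # same as the dense-label default
  preds[np.max(probs, axis=-1) < threshold] = -1  # not confident enough
  return preds

# model.classify(data_set, batch_size=128, extractor=confident_argmax)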
class Predictor(Feedforward, Recurrent):
  """A feedforward or a recurrent predictor"""
  model_name = 'Predictor'

  def __init__(self, mark=None, net_type=Feedforward):
    """
    Construct a Predictor
    :param mark: model mark
    :param net_type: one of (Feedforward, Recurrent)
    """
    if net_type not in (Feedforward, Recurrent):
      raise TypeError('!! Unknown net type')
    self.master = net_type
    # Call parent's constructor
    net_type.__init__(self, mark)
    # Attributes
    self._targets = TensorSlot(self, 'targets')

  # region : Properties

  @property
  def description(self):
    return '{}: {}'.format(self.master.__name__, self.structure_string())

  @property
  def input_type(self):
    if self.master is Feedforward: return InputTypes.BATCH
    else: return InputTypes.RNN_BATCH

  # endregion : Properties

  # region : Build

  @with_graph
  def build_as_regressor(
      self, optimizer=None, loss='euclid', metric='rms_ratio',
      metric_is_like_loss=True, metric_name='Err %'):
    self.build(
      optimizer=optimizer, loss=loss, metric=metric,
      metric_name=metric_name, metric_is_like_loss=metric_is_like_loss)

  @with_graph
  def build(self, optimizer=None, loss='euclid', metric=None,
            metric_is_like_loss=True, metric_name='Metric', **kwargs):
    Model.build(
      self, optimizer=optimizer, loss=loss, metric=metric,
      metric_name=metric_name, metric_is_like_loss=metric_is_like_loss)

  def _build(self, optimizer=None, loss='euclid', metric=None,
             metric_is_like_loss=True, metric_name='Metric', **kwargs):
    # Call parent's build method.
    # Usually the output tensor has been plugged into the Model._outputs slot
    self.master._build(self)
    assert self.outputs.activated

    # Initiate targets and add them to the collection
    self._plug_target_in(self.outputs.shape_list)

    # Define loss
    loss_function = losses.get(loss)
    with tf.name_scope('Loss'):
      if loss == 'cross_entropy':
        output_tensor = self.logits_tensor
        assert output_tensor is not None
      else:
        output_tensor = self.outputs.tensor
      loss_tensor = loss_function(self._targets.tensor, output_tensor)
      # TODO: with or without regularization loss?
      if hub.summary:
        tf.add_to_collection(
          pedia.train_step_summaries,
          tf.summary.scalar('loss_sum', loss_tensor))
      # Try to add regularization loss
      reg_loss = self.regularization_loss
      if reg_loss is not None: loss_tensor += reg_loss
      # Plug in
      self.loss.plug(loss_tensor)

    # Define metric
    if metric is not None:
      metric_function = metrics.get(metric)
      with tf.name_scope('Metric'):
        metric_tensor = metric_function(
          self._targets.tensor, self._outputs.tensor)
        self._metric.plug(metric_tensor, as_loss=metric_is_like_loss,
                          symbol=metric_name)
        if hub.summary:
          tf.add_to_collection(
            pedia.validation_summaries,
            tf.summary.scalar('metric_sum', self._metric.tensor))

    # Merge summaries
    self._merge_summaries()
    # Define train step
    self._define_train_step(optimizer)

  def _plug_target_in(self, shape):
    target_tensor = tf.placeholder(hub.dtype, shape, name='targets')
    self._targets.plug(target_tensor, collection=pedia.default_feed_dict)

  # endregion : Build

  # region : Train

  # TODO
  # def begin_round(self, **kwargs):
  #   if self.master is Recurrent:
  #     th = kwargs.get('th')
  #     assert isinstance(th, TrainerHub)
  #     self.reset_state(th.batch_size)

  def update_model(self, data_batch, **kwargs):
    if self.master is Feedforward:
      return Feedforward.update_model(self, data_batch, **kwargs)
    # Update recurrent model
    feed_dict = self._get_default_feed_dict(data_batch, is_training=True)
    results = self._update_group.run(feed_dict)
    self._state_array = results.pop(self._state)
    return results

  # endregion : Train

  # region : Public Methods

  def predict(self, data, batch_size=None, extractor=None, **kwargs):
    outputs = []
    for batch in self.get_data_batches(data, batch_size):
      batch = self._sanity_check_before_use(batch)
      feed_dict = self._get_default_feed_dict(batch, is_training=False)
      output = self._outputs.run(feed_dict)
      if extractor is not None: output = extractor(output)
      outputs.append(output)
    return np.concatenate(outputs)

  def evaluate_model(self, data, batch_size=None, **kwargs):
    # Check metric
    if not self.metric.activated:
      raise AssertionError('!! Metric not defined')
    # Show status
    console.show_status('Evaluating {} ...'.format(data.name))
    result = self.validate_model(
      data, batch_size, allow_sum=False)[self.metric]
    console.supplement('{} = {:.3f}'.format(self.metric.symbol, result))

  # endregion : Public Methods

  # region : Private Methods

  def _get_default_feed_dict(self, batch, is_training):
    feed_dict = Feedforward._get_default_feed_dict(self, batch, is_training)
    if self.master is Recurrent:
      assert isinstance(batch, DataSet)
      # If a new sequence begins while training, reset state
      if is_training:
        if batch.should_reset_state:
          if hub.notify_when_reset: console.write_line('- ' * 40)
          self.reset_state(batch.size)
        if batch.should_partially_reset_state:
          if hub.notify_when_reset:
            if batch.reset_values is not None:
              info = [(i, v) for i, v in zip(
                batch.reset_batch_indices, batch.reset_values)]
            else:
              info = batch.reset_batch_indices
            console.write_line('{}'.format(info))
          self.reset_part_state(
            batch.reset_batch_indices, batch.reset_values)
      # If not training, always feed a zero state to the model
      batch_size = None if is_training else batch.size
      feed_dict.update(self._get_state_dict(batch_size=batch_size))
    return feed_dict