def build(self, **kwargs):
  # Smooth out flags before important actions
  hub.smooth_out_conflicts()

  # Initialize pruner if necessary
  if any([hub.prune_on, hub.weights_mask_on, hub.etch_on,
          hub.force_to_use_pruner]):
    # Import here to prevent circular import (temporarily)
    from tframe.operators.prune.pruner import Pruner
    tfr.context.pruner = Pruner(self)

  # If an optimizer is not provided here, try hub.get_optimizer();
  # this requires that th.optimizer and th.learning_rate have been provided
  if 'optimizer' not in kwargs:
    kwargs['optimizer'] = hub.get_optimizer()

  # Call successor's _build method
  self._build(**kwargs)
  # Initialize monitor
  self._init_monitor()
  # Set built flag
  self._built = True

  # Show build info
  console.show_status('Model built successfully:')
  self.agent.take_notes('Model built successfully')
  self.agent.take_notes('Structure:', date_time=False)
  # Description may be a model structure
  description = self.description
  if not isinstance(description, (tuple, list)):
    description = [description]
  for line in description:
    assert isinstance(line, str)
    console.supplement(line)
    self.agent.take_notes(line, date_time=False)

  # Add metric slots to the update group
  batch_metric = kwargs.get('batch_metric', [])
  if batch_metric:
    if not isinstance(batch_metric, (tuple, list)):
      batch_metric = [batch_metric]
    for metric_str in batch_metric:
      assert isinstance(metric_str, str)
      metric_slot = self.metrics_manager.get_slot_by_name(metric_str)
      self._update_group.add(metric_slot)

  # Register eval_metric if provided
  eval_metric = kwargs.get('eval_metric', None)
  if eval_metric is not None:
    assert isinstance(eval_metric, str)
    self.metrics_manager.register_eval_slot(eval_metric)
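
# Usage sketch (illustrative, not part of the original source): how `build`
# is assumed to be driven from a training script. `model` stands for an
# instance of the surrounding Model subclass; 'adam', the learning rate, and
# the 'accuracy' slot name are placeholder assumptions, not values taken
# from this file.
def _example_build_call(model):
  # Configure the hub so hub.get_optimizer() can construct an optimizer
  hub.optimizer = 'adam'
  hub.learning_rate = 0.001
  # batch_metric adds the named slot to the per-batch update group;
  # eval_metric registers it as the evaluation criterion (both assume a
  # slot named 'accuracy' exists in model.metrics_manager)
  model.build(batch_metric='accuracy', eval_metric='accuracy')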
def _define_train_step(self, optimizer=None, var_list=None):
  """TODO: should be modified for tframe.optimizer
     self._train_step will be plugged in only here
  """
  if not self._loss.activated:
    raise AssertionError('!! loss has not been activated yet')
  with tf.name_scope('Optimizer'):
    if optimizer is None:
      optimizer = hub.get_optimizer()
      console.show_status(
        'Optimizer defined in trainer hub initialized.', '++')
    # TODO: BETA
    if hub.use_rtrl:
      raise AssertionError('use_rtrl option has been deprecated')
      # NOTE: the lines below are unreachable after the deprecation guard
      from tframe.optimizers.rtrl_opt import RealTimeOptimizer
      optimizer = RealTimeOptimizer(self, optimizer)
    self._optimizer = optimizer
    self.set_train_step(var_list)
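
# Minimal stand-in sketching the contract hub.get_optimizer() is assumed to
# fulfil above: turn hub.optimizer / hub.learning_rate into a TF-1.x
# optimizer object. This is an illustration under those assumptions, not the
# trainer hub's actual implementation; the real hub supports more options.
def _example_get_optimizer(name='adam', learning_rate=0.001):
  import tensorflow as tf  # tframe targets the TF-1.x API
  if name == 'adam':
    return tf.train.AdamOptimizer(learning_rate)
  if name == 'sgd':
    return tf.train.GradientDescentOptimizer(learning_rate)
  raise KeyError('!! Unknown optimizer name `{}`'.format(name))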