def _build_graph(self, images, labels, mode):
    """Assemble the TensorFlow graph for the CIFAR model.

    Populates `self.predictions`, `self.cost`, `self.accuracy`,
    `self.eval_op`, `self.saver` and `self.init`. When training,
    `self.global_step` is set and the train op is built as well.

    Args:
      images: A 4-D image Tensor, passed straight to `build_model`.
      labels: A 2-D labels Tensor; the argmax over axis 1 is taken as the
        true class for the accuracy metric.
      mode: String naming the run mode (e.g. 'train', 'valid', 'test').
        Training behaviour is enabled whenever the substring 'train'
        occurs in it.
    """
    training = 'train' in mode
    if training:
        self.global_step = tf.train.get_or_create_global_step()

    # Forward pass and loss (before any weight-decay term).
    logits = build_model(images, self.num_classes, training, self.hparams)
    self.predictions, self.cost = helper_utils.setup_loss(logits, labels)

    # Streaming accuracy: `eval_op` updates the metric, `accuracy` reads it.
    true_classes = tf.argmax(labels, 1)
    predicted_classes = tf.argmax(self.predictions, 1)
    self.accuracy, self.eval_op = tf.metrics.accuracy(
        true_classes, predicted_classes)
    self._calc_num_trainable_params()

    # Add the L2 weight-decay term to the cost.
    decay_rate = self.hparams.weight_decay_rate
    self.cost = helper_utils.decay_weights(self.cost, decay_rate)

    if training:
        self._build_train_op()

    # Checkpointing for this child model: the saver is created on the CPU
    # and retains the two most recent checkpoints.
    with tf.device('/cpu:0'):
        self.saver = tf.train.Saver(max_to_keep=2)

    # One op that initializes both global and local variables.
    init_ops = (
        tf.global_variables_initializer(),
        tf.local_variables_initializer(),
    )
    self.init = tf.group(*init_ops)
def _build_graph(self, images, labels, mode):
    """Assemble the TensorFlow graph for the CIFAR model.

    Populates `self.predictions`, `self.cost_` (loss before weight decay),
    `self.cost` (loss with weight decay), `self.logits`,
    `self.hiddens_norm`, `self.logit_norm`, `self.accuracy`,
    `self.eval_op`, `self.saver`, `self.ckpt_saver` and `self.init`.
    When training, `self.global_step` is set and the train op is built.

    Args:
      images: A 4-D image Tensor, passed straight to `build_model`.
      labels: A 2-D labels Tensor; the argmax over axis 1 is taken as the
        true class for the accuracy metric.
      mode: String naming the run mode (e.g. 'train', 'valid', 'test').
        Training behaviour is enabled whenever the substring 'train'
        occurs in it.
    """
    training = 'train' in mode
    if training:
        self.global_step = tf.train.get_or_create_global_step()

    if self.hparams.use_gamma_swish:
        # Create one trainable beta and one trainable gamma scalar for each
        # of 13 layers. The handles are not kept here.
        # NOTE(review): presumably the model looks these up by name --
        # confirm against build_model.
        for layer_idx in range(13):
            tf.Variable(
                [self.hparams.init_beta],
                trainable=True,
                dtype=tf.float32,
                name='swish_beta_layer_{}'.format(layer_idx))
            tf.Variable(
                [self.hparams.init_gamma],
                trainable=True,
                dtype=tf.float32,
                name='swish_gamma_layer_{}'.format(layer_idx))

    # Forward pass; this model variant also returns hidden activations.
    logits, hiddens = build_model(
        images, self.num_classes, training, self.hparams)
    self.predictions, self.cost_ = helper_utils.setup_loss(logits, labels)

    # Mean squared magnitude of the hidden activations and of the logits.
    self.hiddens_norm = tf.reduce_mean(hiddens ** 2)
    self.logits = logits
    self.logit_norm = tf.reduce_mean(logits ** 2)

    # Streaming accuracy: `eval_op` updates the metric, `accuracy` reads it.
    true_classes = tf.argmax(labels, 1)
    predicted_classes = tf.argmax(self.predictions, 1)
    self.accuracy, self.eval_op = tf.metrics.accuracy(
        true_classes, predicted_classes)
    self._calc_num_trainable_params()
    self.compute_flops_per_example()

    # Add the L2 weight-decay term on top of the undecayed loss.
    decay_rate = self.hparams.weight_decay_rate
    self.cost = helper_utils.decay_weights(self.cost_, decay_rate)

    if training:
        self._build_train_op()

    # Checkpointing for this child model, with both savers created on the
    # CPU: `saver` retains the two most recent checkpoints, `ckpt_saver`
    # retains up to 100.
    with tf.device('/cpu:0'):
        self.saver = tf.train.Saver(max_to_keep=2)
        self.ckpt_saver = tf.train.Saver(max_to_keep=100)

    # One op that initializes both global and local variables.
    init_ops = (
        tf.global_variables_initializer(),
        tf.local_variables_initializer(),
    )
    self.init = tf.group(*init_ops)
def _build_graph(self, images, labels, mode):
    """Build the TF graph for the CIFAR model and store its handles on self.

    Sets `self.predictions`, `self.cost`, `self.accuracy`, `self.eval_op`,
    `self.saver` and `self.init`. In training mode it also sets
    `self.global_step` and builds the train op.

    Args:
      images: A 4-D image Tensor, handed to `build_model` unchanged.
      labels: A 2-D labels Tensor; its argmax over axis 1 serves as the
        ground-truth class for the accuracy metric.
      mode: String describing the run mode (e.g. 'train', 'valid', 'test').
        Any value containing the substring 'train' selects training mode.
    """
    in_train_mode = 'train' in mode
    if in_train_mode:
        self.global_step = tf.train.get_or_create_global_step()

    # Model forward pass followed by the loss without weight decay.
    model_logits = build_model(
        images, self.num_classes, in_train_mode, self.hparams)
    self.predictions, self.cost = helper_utils.setup_loss(
        model_logits, labels)

    # Accuracy metric: the second return value is the op that updates it.
    label_ids = tf.argmax(labels, 1)
    prediction_ids = tf.argmax(self.predictions, 1)
    accuracy_value, accuracy_update = tf.metrics.accuracy(
        label_ids, prediction_ids)
    self.accuracy = accuracy_value
    self.eval_op = accuracy_update
    self._calc_num_trainable_params()

    # Add the L2 weight-decay term to the cost.
    self.cost = helper_utils.decay_weights(
        self.cost, self.hparams.weight_decay_rate)

    if in_train_mode:
        self._build_train_op()

    # Checkpointing for this child model: the saver is created on the CPU
    # and retains the two most recent checkpoints.
    with tf.device('/cpu:0'):
        self.saver = tf.train.Saver(max_to_keep=2)

    # Group the global and local variable initializers into a single op.
    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    self.init = tf.group(global_init, local_init)