def _build_graph(self, images, labels, mode):
        """Constructs the TF graph for the model.

        Builds the network logits and the loss (with weight decay added),
        calls ``self._build_train_op()`` in training mode, and creates a
        checkpoint saver and an initialization op, storing each on ``self``.

        Args:
          images: A 4-D image Tensor.
          labels: A 2-D labels Tensor.
          mode: string indicating the run mode (e.g., 'train', 'valid',
            'test'). Any mode containing the substring 'train' is treated
            as training.

        Attributes set on ``self``:
          global_step: the variable returned by
            ``tf.train.get_or_create_global_step()`` (training mode only).
          predictions, cost: as returned by ``helper_utils.setup_loss``;
            ``cost`` is then replaced by the weight-decayed cost.
          saver: a ``tf.train.Saver`` with ``max_to_keep=10``.
          init: a grouped op running the global and local variable
            initializers.
        """
        # Substring test: any mode containing 'train' enables training.
        is_training = 'train' in mode
        if is_training:
            self.global_step = tf.train.get_or_create_global_step()

        logits = build_model(images, self.num_classes, is_training,
                             self.hparams)
        self.predictions, self.cost = helper_utils.setup_loss(logits, labels)

        # NOTE(review): presumably records the trainable-parameter count;
        # the helper is not visible here — verify.
        self._calc_num_trainable_params()

        # Adds L2 weight decay to the cost
        self.cost = helper_utils.decay_weights(self.cost,
                                               self.hparams.weight_decay_rate)

        if is_training:
            self._build_train_op()

        # Setup checkpointing for this child model.
        # Keep at most the 10 most recent checkpoints (max_to_keep=10);
        # the saver's ops are placed on the CPU.
        with tf.device('/cpu:0'):
            self.saver = tf.train.Saver(max_to_keep=10)

        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
示例#2
0
    def _build_graph(self, images, labels, mode):
        """Constructs the TF graph, using a locally defined network when known.

        If ``self.hparams.model_name`` names a network this class provides
        (currently only 'resnet18'), the graph is built here. Otherwise the
        parent class's ``_build_graph`` is used unchanged.

        Args:
          images: image Tensor fed to the network.
          labels: labels Tensor passed to ``helper_utils.setup_loss``.
          mode: string indicating the run mode (e.g., 'train', 'valid',
            'test'). Any mode containing the substring 'train' is treated
            as training.

        Attributes set on ``self`` when the local network is used:
          global_step (training mode only), predictions, cost (with weight
          decay applied), saver (``max_to_keep=10``) and init.
        """
        # Single source of truth for the names this class handles. The
        # membership test and the construction read the same mapping, so
        # adding a name can never leave the model unassigned.
        model_builders = {'resnet18': Resnet18}
        model_name = self.hparams.model_name
        if model_name not in model_builders:
            super(Model, self)._build_graph(images, labels, mode)
            return

        model = model_builders[model_name](self.num_classes)
        # Substring test: any mode containing 'train' enables training.
        is_training = 'train' in mode
        if is_training:
            self.global_step = tf.train.get_or_create_global_step()

        logits = model(images, is_training)
        self.predictions, self.cost = helper_utils.setup_loss(logits, labels)

        # NOTE(review): presumably records the trainable-parameter count;
        # the helper is not visible here — verify.
        self._calc_num_trainable_params()

        # Adds L2 weight decay to the cost
        self.cost = helper_utils.decay_weights(self.cost,
                                               self.hparams.weight_decay_rate)

        if is_training:
            self._build_train_op()

        # Setup checkpointing for this child model.
        # Keep at most the 10 most recent checkpoints (max_to_keep=10);
        # the saver's ops are placed on the CPU.
        with tf.device('/cpu:0'):
            self.saver = tf.train.Saver(max_to_keep=10)

        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
示例#3
0
文件: model.py 项目: tobyclh/pba
    def _build_graph(self, images, labels, mode):
        """Constructs the TF graph for the model.

        Args:
          images: A 4-D image Tensor
          labels: A 2-D labels Tensor.
          mode: string indicating training mode ( e.g., 'train', 'valid', 'test').
        """
        is_training = 'train' in mode
        if is_training:
            self.global_step = 0

        logits = build_model(images, self.num_classes, is_training,
                             self.hparams)
        self.predictions, self.cost = helper_utils.setup_loss(logits, labels)

        self._calc_num_trainable_params()

        # Adds L2 weight decay to the cost
        self.cost = helper_utils.decay_weights(self.cost,
                                               self.hparams.weight_decay_rate)

        if is_training:
            self._build_train_op()