Example #1
    def build_train_op(self, global_step):
        '''
            Builds two train ops, one for the Generator and one for the Discriminator. These can 
            be run independently any number of times, and each time will increase the global_step.

            Args:
                global_step: A Tensor to be incremented
            Returns:
                [ g_train_op, d_train_op, g_lnorm_op ]
        '''
        if not self.model_built or not self.losses_built:
            raise RuntimeError(
                "Cannot build optimizers until 'build_model' ({0}) and 'get_losses' {1} are run"
                .format(self.model_built, self.losses_built))
        self.global_step = global_step
        self.global_step_copy = tf.identity(global_step,
                                            name='global_step_copy')

        # Create the optimizer train_op for the generator
        self.g_optimizer = optimize.build_optimizer(
            global_step=self.global_step, cfg=self.cfg)
        self.g_vars = slim.get_variables(
            scope='encoder', collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        self.g_vars += slim.get_variables(
            scope='decoder', collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        self.g_train_op = optimize.create_train_op(
            self.loss_g_total,
            self.g_optimizer,
            variables_to_train=self.g_vars,
            update_global_step=True)

        self.g_lnorm_op = optimize.create_train_op(
            self.l1_loss,
            self.g_optimizer,
            variables_to_train=self.g_vars,
            update_global_step=True)

        # Create a train_op for the discriminator
        if 'discriminator_learning_args' in self.cfg:  # use dedicated discriminator hyperparameters
            discriminator_learning_args = self.cfg[
                'discriminator_learning_args']
        else:
            discriminator_learning_args = self.cfg
        self.d_optimizer = optimize.build_optimizer(
            global_step=self.global_step, cfg=discriminator_learning_args)
        self.d_vars = slim.get_variables(
            scope='discriminator', collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        self.d_vars += slim.get_variables(
            scope='discriminator_1',
            collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        self.d_train_op = slim.learning.create_train_op(
            self.loss_d_total,
            self.d_optimizer,
            variables_to_train=self.d_vars)

        self.train_op = [self.g_train_op, self.d_train_op, self.g_lnorm_op]
        self.train_op_built = True
        return self.train_op
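
A hypothetical cfg sketch for the lookup above: when 'discriminator_learning_args' is present, the discriminator optimizer is built from that sub-dict; otherwise the shared cfg is reused. Key names other than 'discriminator_learning_args' are assumptions, since optimize.build_optimizer is not shown in these examples.

# Hypothetical cfg; only the 'discriminator_learning_args' key is taken from
# the code above. 'learning_rate' is an assumed key for build_optimizer.
cfg = {
    'learning_rate': 2e-4,            # generator optimizer settings
    'discriminator_learning_args': {
        'learning_rate': 1e-4,        # e.g. a slower discriminator; assumption only
    },
}
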
Example #2
    def build_train_op( self, global_step ):
        '''
            Builds train ops for discriminative task
            
            Args:
                global_step: A Tensor to be incremented
            Returns:
                [ loss_op, accuracy ]
        '''
        if not self.model_built or self.total_loss is None:
            raise RuntimeError(
                "Cannot build optimizers until 'build_model' ({0}) and 'get_losses' ({1}) are run"
                .format(self.model_built, self.total_loss is not None))
        self.global_step = global_step

        # Create the optimizer and the train op
        self.optimizer = optimize.build_optimizer(
            global_step=self.global_step, cfg=self.cfg)
        if 'clip_norm' in self.cfg:
            self.loss_op = optimize.create_train_op(
                self.total_loss, self.optimizer, update_global_step=True,
                clip_gradient_norm=self.cfg['clip_norm'])
        elif self.is_training:
            self.loss_op = optimize.create_train_op(
                self.total_loss, self.optimizer, update_global_step=True)
        else:
            self.loss_op = optimize.create_train_op(
                self.total_loss, self.optimizer, is_training=False,
                update_global_step=True)

        self.train_op = [self.loss_op, self.accuracy]
        self.train_op_built = True
        return self.train_op
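
optimize.create_train_op itself is not shown in any of these examples. Below is a minimal sketch of how such a wrapper over slim.learning.create_train_op might look, assuming update_global_step merely toggles whether slim increments the step (the is_training flag seen above is project-specific and omitted here):

# Sketch only, under the assumptions above; NOT the project's actual optimize module.
import tensorflow as tf
import tensorflow.contrib.slim as slim

def create_train_op(total_loss, optimizer, variables_to_train=None,
                    update_global_step=True, clip_gradient_norm=0):
    # Passing global_step=None tells slim not to increment the step counter.
    global_step = tf.train.get_global_step() if update_global_step else None
    return slim.learning.create_train_op(
        total_loss, optimizer,
        global_step=global_step,
        variables_to_train=variables_to_train,
        clip_gradient_norm=clip_gradient_norm)
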
Example #3
    def build_train_op(self, global_step):
        if not self.model_built or self.total_loss is None:
            raise RuntimeError(
                "Cannot build optimizers until 'build_model' ({0}) and 'get_losses' ({1}) are run"
                .format(self.model_built, self.total_loss is not None))
        self.global_step = global_step
        self.optimizer = optimizers.build_optimizer(global_step=global_step,
                                                    cfg=self.cfg)
        self.train_op = slim.learning.create_train_op(self.total_loss,
                                                      self.optimizer)
        return self.train_op
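
Since this variant hands back a single slim train op, it could be driven directly by slim.learning.train. A hypothetical driver (model, the logdir, and the step count are placeholders, not from the source):

# Hypothetical driver; `model` is assumed to be an instance of the class above.
import tensorflow as tf
import tensorflow.contrib.slim as slim

global_step = tf.train.get_or_create_global_step()
train_op = model.build_train_op(global_step)
slim.learning.train(train_op, logdir='/tmp/model', number_of_steps=1000)
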
Example #4
    def build_train_op(self, global_step):
        '''
            Builds the train op for a discriminative task, restricted to the
            encoder variables.

            Args:
                global_step: A Tensor to be incremented
            Returns:
                [ loss_op, 0 ]  (the 0 is a placeholder where other variants
                return an accuracy op)
        '''
        if not self.model_built or self.total_loss is None:
            raise RuntimeError(
                "Cannot build optimizers until 'build_model' ({0}) and 'get_losses' ({1}) are run"
                .format(self.model_built, self.total_loss is not None))
        self.global_step = global_step

        vars_to_train = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.encoder_scope)
        
        # Create the optimizer and the train op over the encoder variables only
        self.optimizer = optimize.build_optimizer(
            global_step=self.global_step, cfg=self.cfg)
        if 'clip_norm' in self.cfg:
            self.loss_op = optimize.create_train_op(
                self.total_loss, self.optimizer,
                update_global_step=True, clip_gradient_norm=self.cfg['clip_norm'],
                variables_to_train=vars_to_train)
        elif self.is_training:
            self.loss_op = optimize.create_train_op(
                self.total_loss, self.optimizer,
                update_global_step=True, variables_to_train=vars_to_train)
        else:
            self.loss_op = optimize.create_train_op(
                self.total_loss, self.optimizer,
                is_training=False, update_global_step=True,
                variables_to_train=vars_to_train)

        self.train_op = [self.loss_op, 0]  # 0 is a placeholder for accuracy
        self.train_op_built = True
        return self.train_op
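
The only functional difference from Example #2 is the variables_to_train restriction. A self-contained TF1 sketch of how that scope filter behaves (the scope names here are made up for illustration):

# Standalone illustration of scope-filtered trainable variables (TF1).
import tensorflow as tf

with tf.variable_scope('encoder'):
    w = tf.get_variable('w', shape=[3, 3])
with tf.variable_scope('head'):
    v = tf.get_variable('v', shape=[3])

vars_to_train = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='encoder')
# Only the encoder variable is selected; gradients never touch 'head/v'.
assert w in vars_to_train and v not in vars_to_train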