Example #1
0
class SiameseModel:
    def __init__(self, model_config, train_config, mode='train'):
        self.model_config = model_config
        self.train_config = train_config
        self.mode = mode
        assert mode in ['train', 'validation', 'inference']

        if self.mode == 'train':
            self.data_config = self.train_config['train_data_config']
        elif self.mode == 'validation':
            self.data_config = self.train_config['validation_data_config']

        self.dataloader = None
        self.exemplars = None
        self.instances = None
        self.response = None
        self.batch_loss = None
        self.total_loss = None
        self.init_fn = None
        self.global_step = None

    def is_training(self):
        """Returns true if the model is built for training mode"""
        return self.mode == 'train'

    def build_inputs(self):
        """Input fetching and batching

    Outputs:
      self.exemplars: image batch of shape [batch, hz, wz, 3]
      self.instances: image batch of shape [batch, hx, wx, 3]
    """
        if self.mode in ['train', 'validation']:
            # Putting data loading and preprocessing on the CPU is substantially faster.
            with tf.device("/cpu:0"):
                self.dataloader = DataLoader(self.data_config,
                                             self.is_training())
                self.dataloader.build()
                exemplars, instances = self.dataloader.get_one_batch()

                exemplars = tf.to_float(exemplars)
                instances = tf.to_float(instances)
        else:
            self.examplar_feed = tf.placeholder(shape=[None, None, None, 3],
                                                dtype=tf.uint8,
                                                name='examplar_input')
            self.instance_feed = tf.placeholder(shape=[None, None, None, 3],
                                                dtype=tf.uint8,
                                                name='instance_input')
            exemplars = tf.to_float(self.examplar_feed)
            instances = tf.to_float(self.instance_feed)

        self.exemplars = exemplars
        self.instances = instances

    def build_image_embeddings(self, reuse=False):
        """Builds the image model subgraph and generates image embeddings

    Inputs:
      self.exemplars: A tensor of shape [batch, hz, wz, 3]
      self.instances: A tensor of shape [batch, hx, wx, 3]

    Outputs:
      self.exemplar_embeds: A Tensor of shape [batch, hz_embed, wz_embed, embed_dim]
      self.instance_embeds: A Tensor of shape [batch, hx_embed, wx_embed, embed_dim]
    """
        config = self.model_config['embed_config']
        # =============================================================================
        #     arg_scope = convolutional_alexnet_arg_scope(config,
        #                                                 trainable=config['train_embedding'],
        #                                                 is_training=self.is_training())
        # =============================================================================
        # The embedding is frozen here (trainable=False, is_training=False); the
        # commented block above is the original configurable version.
        arg_scope = convolutional_alexnet_arg_scope(config,
                                                    trainable=False,
                                                    is_training=False)

        @functools.wraps(convolutional_alexnet)
        def embedding_fn(images, reuse=False):
            with slim.arg_scope(arg_scope):
                return convolutional_alexnet(images, reuse=reuse)

        self.exemplar_embeds_c5, self.exemplar_embeds_c4, self.exemplar_embeds_c3, _ = embedding_fn(
            self.exemplars, reuse=reuse)
        self.instance_embeds_c5, self.instance_embeds_c4, self.instance_embeds_c3, _ = embedding_fn(
            self.instances, reuse=True)

# =============================================================================
#     self.exemplar_embeds_c5, self.exemplar_embeds_c4, self.exemplar_embeds_c3 = embed_alexnet(self.exemplars)
#     self.instance_embeds_c5, self.instance_embeds_c4, self.instance_embeds_c3 = embed_alexnet(self.instances)
# =============================================================================

    def build_template(self):
        # The template is simply the feature of the exemplar image in SiamFC.
        self.templates_c5 = self.exemplar_embeds_c5
        self.templates_c4 = self.exemplar_embeds_c4
        self.templates_c3 = self.exemplar_embeds_c3

    def build_detection(self, reuse=False):
        with tf.variable_scope('detection', reuse=reuse):

            def _translation_match(x, z):
                # Translation match for one example within a batch.
                x = tf.expand_dims(x, 0)   # [1, in_height, in_width, in_channels]
                z = tf.expand_dims(z, -1)  # [filter_height, filter_width, in_channels, 1]
                return tf.nn.conv2d(x, z,
                                    strides=[1, 1, 1, 1],
                                    padding='VALID',
                                    name='translation_match')

            output_c5 = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                                  (self.instance_embeds_c5, self.templates_c5),
                                  dtype=self.instance_embeds_c5.dtype)
            output_c4 = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                                  (self.instance_embeds_c4, self.templates_c4),
                                  dtype=self.instance_embeds_c4.dtype)
            output_c3 = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                                  (self.instance_embeds_c3, self.templates_c3),
                                  dtype=self.instance_embeds_c3.dtype)
            # Squeeze out the singleton batch and channel dims; each map is of
            # shape e.g. [8, 15, 15].
            output_c5 = tf.squeeze(output_c5, [1, 4])
            output_c4 = tf.squeeze(output_c4, [1, 4])
            output_c3 = tf.squeeze(output_c3, [1, 4])
            # Debug prints; currently unused since the control_dependencies hook
            # below is commented out.
            print_op_c5 = tf.Print(output_c5, [output_c5])
            print_op_c4 = tf.Print(output_c4, [output_c4])
            print_op_c3 = tf.Print(output_c3, [output_c3])

            #ALEX_DICT = parse_tf_model("/home/travail/dev/GitRepo/CBSiamFC/Logs/SiamFC/track_model_checkpoints/Alex_v1/model.ckpt-332499")
            # Adjust score, this is required to make training possible.
            config = self.model_config['adjust_response_config']
            # =============================================================================
            #       self.bias_c5 = tf.get_variable('biases_a_c5', [1],    #[-9.63422871]
            #                              dtype=tf.float32,
            #                              initializer=tf.constant_initializer(0.0),#=tf.constant_initializer(ALEX_DICT['detection/biases_c5']),
            #                              trainable=config['train_bias'])
            #       self.bias_c4 = tf.get_variable('biases_a_c4', [1],    #[-5.29178524]
            #                              dtype=tf.float32,
            #                              initializer=tf.constant_initializer(0.0),#tf.constant_initializer(ALEX_DICT['detection/biases_c4']),
            #                              trainable=config['train_bias'])
            #       self.bias_c3 = tf.get_variable('biases_a_c3', [1],    #[-4.51134348]
            #                              dtype=tf.float32,
            #                              initializer=tf.constant_initializer(0.0),#tf.constant_initializer(ALEX_DICT['detection/biases_c3']),
            #                              trainable=config['train_bias'])
            # =============================================================================
            self.bias_c5 = tf.get_variable('biases_s_c5', [1],  # [-9.63422871]
                                           dtype=tf.float32,
                                           initializer=tf.constant_initializer(0.0),
                                           trainable=config['train_bias'])
            self.bias_c4 = tf.get_variable('biases_s_c4', [1],  # [-5.29178524]
                                           dtype=tf.float32,
                                           initializer=tf.constant_initializer(0.0),
                                           trainable=config['train_bias'])
            self.bias_c3 = tf.get_variable('biases_s_c3', [1],  # [-4.51134348]
                                           dtype=tf.float32,
                                           initializer=tf.constant_initializer(0.0),
                                           trainable=config['train_bias'])
            # =============================================================================
            #       with tf.control_dependencies([print_op_c3]):#455
            # =============================================================================
            # =============================================================================
            #       response_c5 = 1e-2*output_c5 + self.bias_c5
            #       response_c4 = 1e-3*output_c4 + self.bias_c4
            #       response_c3 = 1e-3*output_c3 + self.bias_c3
            # =============================================================================
            # =============================================================================
            #       ## for training alexnet se
            #       response_c5 = 1e-3*output_c5 + self.bias_c5
            #       response_c4 = 1e-4*output_c4 + self.bias_c4
            #       response_c3 = 1e-5*output_c3 + self.bias_c3
            # =============================================================================
            # For training the SiamFC SE variant: scale the raw correlation scores
            # into a range where the sigmoid loss can train.
            response_c5 = 1e-3 * output_c5 + self.bias_c5
            response_c4 = 1e-3 * output_c4 + self.bias_c4
            response_c3 = 1e-3 * output_c3 + self.bias_c3
            # =============================================================================
            #       response_c5 = config['scale'] * output_c5 -4.52620411
            #       response_c4 = config['scale'] * output_c4 -0.03678114
            #       response_c3 = config['scale'] * output_c3 -0.49341503
            # =============================================================================

            # weight maps for response from each layer
            # =============================================================================
            #       response_size = response_c5.get_shape().as_list()[1:3]
            #       map_c5 = tf.get_variable('map_c5', response_size,
            #                                     dtype=tf.float32,
            #                                     initializer=tf.constant_initializer(0.5, dtype=tf.float32),
            #                                     trainable=True)
            #       map_c4 = tf.get_variable('map_c4', response_size,
            #                                     dtype=tf.float32,
            #                                     initializer=tf.constant_initializer(0.5, dtype=tf.float32),
            #                                     trainable=True)
            #       map_c3 = tf.get_variable('map_c3', response_size,
            #                                     dtype=tf.float32,
            #                                     initializer=tf.constant_initializer(0.5, dtype=tf.float32),
            #                                     trainable=True)
            # =============================================================================
            # =============================================================================
            #       self.weight_c5 = tf.get_variable('weight_a_c5', [1], #[ 0.71658146]
            #                              dtype=tf.float32,
            #                              initializer=tf.constant_initializer(0.5, dtype=tf.float32),
            #                              trainable=True)
            #       self.weight_c4 = tf.get_variable('weight_a_c4', [1], #[ 0.04511292]
            #                              dtype=tf.float32,
            #                              initializer=tf.constant_initializer(0.5, dtype=tf.float32),
            #                              trainable=True)
            #       self.weight_c3 = tf.get_variable('weight_a_c3', [1], #[ 0.0067619]
            #                              dtype=tf.float32,
            #                              initializer=tf.constant_initializer(0.5, dtype=tf.float32),
            #                              trainable=True)
            #       response_c5 = tf.multiply(response_c5, self.weight_c5)
            #       response_c4 = tf.multiply(response_c4, self.weight_c4)
            #       response_c3 = tf.multiply(response_c3, self.weight_c3)
            # =============================================================================
            # =============================================================================
            #       response_c5_max = tf.reduce_max(response_c5)
            #       response_c4_max = tf.reduce_max(response_c4)
            #       response_c3_max = tf.reduce_max(response_c3)
            #       self.response_c5 = tf.div(response_c5, response_c5_max)
            #       self.response_c4 = tf.div(response_c4, response_c4_max)
            #       self.response_c3 = tf.div(response_c3, response_c3_max)
            #       self.response = 0.3*response_c5+0.6*response_c4+0.1*response_c3
            # =============================================================================
            # =============================================================================
            #       self.response = response_c5*0.6+response_c4*0.3+response_c3*0.1
            # =============================================================================
            self.response_c5 = response_c5
            self.response_c4 = response_c4
            self.response_c3 = response_c3

    def build_loss(self):
        response_c5 = self.response_c5
        response_c4 = self.response_c4
        response_c3 = self.response_c3
        response_size = response_c5.get_shape().as_list()[1:3]  # [height, width]

        gt = construct_gt_score_maps(
            response_size, self.data_config['batch_size'],
            self.model_config['embed_config']['stride'],
            self.train_config['gt_config'])

        with tf.name_scope('Loss'):
            loss_c5 = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=response_c5, labels=gt)
            loss_c4 = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=response_c4, labels=gt)
            loss_c3 = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=response_c3, labels=gt)

            with tf.name_scope('Balance_weights'):
                n_pos = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 1)))
                n_neg = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 0)))
                w_pos = 0.5 / n_pos
                w_neg = 0.5 / n_neg
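                # Positives and negatives each get a total weight of 0.5, so the
                # spatial sum taken later is a class-balanced average: e.g. with 13
                # positive cells in a 15x15 map, w_pos = 0.5/13 and w_neg = 0.5/212.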
                class_weights = tf.where(tf.equal(gt, 1),
                                         w_pos * tf.ones_like(gt),
                                         tf.ones_like(gt))
                class_weights = tf.where(tf.equal(gt, 0),
                                         w_neg * tf.ones_like(gt),
                                         class_weights)
                loss_c5 = loss_c5 * class_weights
                loss_c4 = loss_c4 * class_weights
                loss_c3 = loss_c3 * class_weights

            # Note that we use reduce_sum instead of reduce_mean since the loss has
            # already been normalized by class_weights over the spatial dimensions.
            loss_c5 = tf.reduce_sum(loss_c5, [1, 2])
            loss_c4 = tf.reduce_sum(loss_c4, [1, 2])
            loss_c3 = tf.reduce_sum(loss_c3, [1, 2])

            batch_loss_c5 = tf.reduce_mean(loss_c5, name='batch_loss_c5')
            batch_loss_c4 = tf.reduce_mean(loss_c4, name='batch_loss_c4')
            batch_loss_c3 = tf.reduce_mean(loss_c3, name='batch_loss_c3')
            tf.losses.add_loss(batch_loss_c5)
            tf.losses.add_loss(batch_loss_c4)
            tf.losses.add_loss(batch_loss_c3)

            total_loss = tf.losses.get_total_loss()
            self.batch_loss_c5 = batch_loss_c5
            self.batch_loss_c4 = batch_loss_c4
            self.batch_loss_c3 = batch_loss_c3
            self.total_loss = total_loss

            tf.summary.image('exemplar', self.exemplars, family=self.mode)
            tf.summary.image('instance', self.instances, family=self.mode)

            mean_batch_loss_c5, update_op1_c5 = tf.metrics.mean(batch_loss_c5)
            mean_batch_loss_c4, update_op1_c4 = tf.metrics.mean(batch_loss_c4)
            mean_batch_loss_c3, update_op1_c3 = tf.metrics.mean(batch_loss_c3)
            mean_total_loss, update_op2 = tf.metrics.mean(total_loss)
            with tf.control_dependencies(
                [update_op1_c5, update_op1_c4, update_op1_c3, update_op2]):
                tf.summary.scalar('batch_loss_c5',
                                  mean_batch_loss_c5,
                                  family=self.mode)
                tf.summary.scalar('batch_loss_c4',
                                  mean_batch_loss_c4,
                                  family=self.mode)
                tf.summary.scalar('batch_loss_c3',
                                  mean_batch_loss_c3,
                                  family=self.mode)
                tf.summary.scalar('total_loss',
                                  mean_total_loss,
                                  family=self.mode)

            if self.mode == 'train':
                tf.summary.image('GT',
                                 tf.reshape(gt[0], [1] + response_size + [1]),
                                 family='GT')
            tf.summary.image('Response_c5',
                             tf.expand_dims(tf.sigmoid(response_c5), -1),
                             family=self.mode)
            tf.summary.histogram('Response_c5',
                                 self.response_c5,
                                 family=self.mode)
            tf.summary.image('Response_c4',
                             tf.expand_dims(tf.sigmoid(response_c4), -1),
                             family=self.mode)
            tf.summary.histogram('Response_c4',
                                 self.response_c4,
                                 family=self.mode)
            tf.summary.image('Response_c3',
                             tf.expand_dims(tf.sigmoid(response_c3), -1),
                             family=self.mode)
            tf.summary.histogram('Response_c3',
                                 self.response_c3,
                                 family=self.mode)


# =============================================================================
#       # Two more metrics to monitor the performance of training
#       tf.summary.scalar('center_score_error', center_score_error(response), family=self.mode)
#       tf.summary.scalar('center_dist_error', center_dist_error(response), family=self.mode)
# =============================================================================

    def setup_global_step(self):
        global_step = tf.Variable(initial_value=0,
                                  name='global_step',
                                  trainable=False,
                                  collections=[
                                      tf.GraphKeys.GLOBAL_STEP,
                                      tf.GraphKeys.GLOBAL_VARIABLES
                                  ])

        self.global_step = global_step

    def setup_embedding_initializer(self):
        """Sets up the function to restore embedding variables from checkpoint."""
        embed_config = self.model_config['embed_config']
        if embed_config['embedding_checkpoint_file']:
            # Restore Siamese FC models from .mat model files
            initialize = load_mat_model(
                embed_config['embedding_checkpoint_file'],
                'convolutional_alexnet/', 'detection/')

            def restore_fn(sess):
                tf.logging.info(
                    "Restoring embedding variables from checkpoint file %s",
                    embed_config['embedding_checkpoint_file'])
                sess.run([initialize])

            self.init_fn = restore_fn

    def build(self, reuse=False):
        """Creates all ops for training and evaluation"""
        with tf.name_scope(self.mode):
            self.build_inputs()
            self.build_image_embeddings(reuse=reuse)
            self.build_template()
            self.build_detection(reuse=reuse)
            self.setup_embedding_initializer()

            if self.mode in ['train', 'validation']:
                self.build_loss()

            if self.is_training():
                self.setup_global_step()
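

# Usage sketch (an illustrative assumption, not part of the original file): how
# this class is typically driven, given the `model_config`/`train_config` dicts
# the module already expects and a queue-based DataLoader as in SiamFC-TensorFlow.
def _example_train_step(model_config, train_config):
    g = tf.Graph()
    with g.as_default():
        model = SiameseModel(model_config, train_config, mode='train')
        model.build()
        with tf.Session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())  # state for tf.metrics.mean
            if model.init_fn is not None:
                model.init_fn(sess)  # restore pretrained embedding variables
            coord = tf.train.Coordinator()
            tf.train.start_queue_runners(sess=sess, coord=coord)
            return sess.run(model.total_loss)
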
Example #2
0
class CompSiamModel:
	def __init__(self, model_config, train_config, mode='train'):
		self.model_config = model_config
		self.train_config = train_config
		self.mode = mode
		assert mode in ['train', 'validation', 'inference']

		if self.mode == 'train':
			self.data_config = self.train_config['train_data_config']
		elif self.mode == 'validation':
			self.data_config = self.train_config['validation_data_config']

		self.dataloader = None
		self.keyFrame = None
		self.searchFrame = None
		self.response = None
		self.batch_loss = None
		self.total_loss = None
		self.init_fn = None
		self.global_step = None

	def is_training(self):
		return self.mode == 'train'



	# Fetch one batch from the pickled data: the first (key) frame and the
	# next (search) frame.
	def build_inputs(self):
		if self.mode in ['train', 'validation']:
			with tf.device("/cpu:0"):  # Putting data loading and preprocessing on the CPU is substantially faster
				self.dataloader = DataLoader(self.data_config, self.is_training())
				self.dataloader.build()
				keyFrame, searchFrame = self.dataloader.get_one_batch()
				keyFrame = tf.to_float(keyFrame)
				searchFrame = tf.to_float(searchFrame)
		else:
			self.examplar_feed = tf.placeholder(
				shape=[None, None, None, 3], dtype=tf.uint8, name='examplar_input')
			self.searchFrame_feed = tf.placeholder(
				shape=[None, None, None, 3], dtype=tf.uint8, name='searchFrame_input')
			keyFrame = tf.to_float(self.examplar_feed)
			searchFrame = tf.to_float(self.searchFrame_feed)

		# Images are rescaled to curb the exploding-gradient problem *NOT SUGGESTED IN PAPER*.
		self.keyFrame = keyFrame / 128
		self.searchFrame = searchFrame / 128



	# Create separate VGG networks for the key frame and the search frame.
	def build_image_nets(self, vgg_pretrain=None, reuse=False):

		config = self.model_config['embed_config']

		# key frame network 
		vgg_keyFrame = vgg19.Vgg19(vgg_pretrain)
		vgg_keyFrame.build(self.keyFrame)


		# search frame network
		vgg_searchFrame = vgg19.Vgg19(vgg_pretrain)
		vgg_searchFrame.build(self.searchFrame)

		self.keyFrame_net =  vgg_keyFrame
		self.searchFrame_net =  vgg_searchFrame



	# Create the cross-correlation map between two feature maps.
	def build_detection(self, curr_searchFrame_embeds, curr_templates, reuse=False):
		# The `reuse` argument is effectively ignored: AUTO_REUSE shares the
		# detection variables (the bias) across all five blocks.
		with tf.variable_scope('detection', reuse=tf.AUTO_REUSE):
			def _translation_match(x, z):
				# Translation match for one example within a batch.
				x = tf.expand_dims(x, 0)  # [1, in_height, in_width, in_channels]
				z = tf.expand_dims(z, -1)  # [filter_height, filter_width, in_channels, 1]
				return tf.nn.conv2d(x, z, strides=[1, 1, 1, 1], padding='VALID', name='translation_match')



			output = tf.map_fn(
				lambda x: _translation_match(x[0], x[1]),
				(curr_searchFrame_embeds, curr_templates),
				dtype=curr_searchFrame_embeds.dtype)

			output = tf.squeeze(output, [1, 4])  # of shape e.g., [8, 15, 15]


			# Adjust score; this is required to make training possible.
			config = self.model_config['adjust_response_config']
			bias = tf.get_variable(
				'biases', [1], dtype=tf.float32,
				initializer=tf.constant_initializer(0.0, dtype=tf.float32),
				trainable=config['train_bias'])
			response = config['scale'] * output + bias

			# The response is the cross-correlation map between the two frames.
			return response

			

	# Build the 5 VGG blocks: record FLOPs and a cross-correlation map for each,
	# as described in the paper.
	def build_blocks(self):

		keyFrame_net = self.keyFrame_net
		searchFrame_net = self.searchFrame_net

		# block 1
		self.block1_keyFrame_embed = keyFrame_net.pool1
		self.block1_searchFrame_embed = searchFrame_net.pool1
		block1_flops = 2 * searchFrame_net.flops1


		block1_cross_corr = self.build_detection(self.block1_searchFrame_embed, self.block1_keyFrame_embed,reuse=True)

		# block 2
		block2_keyFrame_embed = keyFrame_net.pool2
		block2_searchFrame_embed = searchFrame_net.pool2
		block2_flops = 2 * searchFrame_net.flops2

	 
		block2_cross_corr = self.build_detection(block2_searchFrame_embed, block2_keyFrame_embed, reuse=False)


		# block 3
		block3_keyFrame_embed = keyFrame_net.pool3
		block3_searchFrame_embed = searchFrame_net.pool3
		block3_flops = 2 * searchFrame_net.flops3


		block3_cross_corr = self.build_detection(block3_searchFrame_embed, block3_keyFrame_embed,reuse=False)


		# block 4
		block4_keyFrame_embed = keyFrame_net.pool4
		block4_searchFrame_embed = searchFrame_net.pool4
		block4_flops = 2 * searchFrame_net.flops4


		block4_cross_corr = self.build_detection(block4_searchFrame_embed, block4_keyFrame_embed,reuse=False)


		# block 5
		block5_keyFrame_embed = keyFrame_net.pool5
		block5_searchFrame_embed = searchFrame_net.pool5
		block5_flops = 2 * searchFrame_net.flops5


		block5_cross_corr = self.build_detection(block5_searchFrame_embed, block5_keyFrame_embed,reuse=True)


		# number of flops for each block in vgg net
		self.flops_metric =  [block1_flops,block2_flops,block3_flops,block4_flops,block5_flops]

		# cross correlation maps for each block
		self.cross_corr = [block1_cross_corr, block2_cross_corr, block3_cross_corr, block4_cross_corr, block5_cross_corr]
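		# Each cross-correlation map has shape [batch, h_i, w_i] (cf. the
		# "[8, 15, 15]" example above); spatial sizes differ across blocks because
		# every pooling stage halves the feature resolution.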



	# Loss between a block's cross-correlation map and the ground-truth score map
	# of two neighbouring frames.
	def block_loss(self, block_cross_corr):

		cross_corr_size = block_cross_corr.get_shape().as_list()[1:3]  # [height, width]
		print("the batch size ", self.data_config['batch_size'])

		# ground-truth score map
		gt = construct_gt_score_maps(
			cross_corr_size, self.data_config['batch_size'],
			self.model_config['embed_config']['stride'],
			self.train_config['gt_config'])

		
		with tf.name_scope('Loss'):
			# Softmax cross-entropy measures the loss, as mentioned in the paper.
			loss = tf.losses.softmax_cross_entropy(gt, block_cross_corr)
		return loss


	# Shallow feature extractor, used to save computation time... non-differentiable.
	# (Author's note: not sure if each part of it is implemented the right way.)
	def shallow_feature_extractor(self):
		with tf.Session() as sess:
			sess.run(tf.global_variables_initializer())
			cross_corr = sess.run(self.cross_corr)
			# Flatten each cross-corr map for computing kurtosis and entropy.
			cross_corr_1d = [np.reshape(i, [i.shape[0], -1]) for i in cross_corr]
			# No TensorFlow op was found for kurtosis/entropy, so they are computed
			# outside the graph.
			kurtosis_val = [kurtosis(i, axis=1) for i in cross_corr_1d]
			entropy_val = [entropy(i.T) for i in cross_corr_1d]

			# expand dimensions
			kurtosis_val = np.expand_dims(np.array(kurtosis_val), axis=2)
			entropy_val = np.expand_dims(np.array(entropy_val), axis=2)

			# first five values of each cross-corr map
			first_five = np.array([i[:, :5] for i in cross_corr_1d])

			# five largest values
			max_five = np.array([np.sort(i, axis=1)[:, ::-1][:, :5] for i in cross_corr_1d])

			shallow_features = np.concatenate([kurtosis_val, entropy_val, first_five, max_five], axis=2)
			self.shallow_features = tf.Variable(shallow_features, trainable=False)
			# print(max_five[0,0,:],shallow_features.shape)
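			# Because kurtosis and entropy are computed with NumPy/SciPy on values
			# pulled out of the graph via sess.run, no gradient flows back through
			# these features; hence the "non-differentiable" note above.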
	

	# Adaptive computation with early stopping: hard gating used during evaluation.
	def act_early_stop(self, thresh=0.5):
		# Returned when computation has already stopped: keep the previous block's values.
		def same(i):
			curr_cross = self.cross_corr[i - 1]
			curr_shallow = self.shallow_features[i - 1]
			curr_budgetedscore = self.gStarFunc(i - 1)
			curr_flops = self.flops_metric[i - 1]
			key = i
			return curr_cross, curr_shallow, curr_flops, curr_budgetedscore, key

		# Advance to block i while the budgeted confidence score is below the threshold.
		def next(i):
			curr_cross = self.cross_corr[i]
			curr_shallow = self.shallow_features[i]
			curr_budgetedscore = self.gStarFunc(i)
			curr_flops = self.flops_metric[i]
			key = i
			key += 1
			return curr_cross, curr_shallow, curr_flops, curr_budgetedscore, key
		key = 0
		# Run the early stopping over the five blocks.
		for i in range(5):
			if i == 0:
				return_values = next(key)
			else:
				# The main op for early stopping.
				return_values = tf.cond(finished_batch, lambda: same(i), lambda: next(i))
			curr_cross, curr_shallow, curr_flops, curr_budgetedscore, key = return_values
			# Per-example boolean: stop once the budgeted score exceeds the threshold.
			finished = curr_budgetedscore > thresh
			# AND over the batch.
			finished_batch = tf.reduce_all(finished)
		
		# final values
		final_cross_corr = curr_cross
		final_flops = curr_flops
		return final_cross_corr, final_flops, key - 1




	# The gating function g: a sigmoid-activated dense layer over the shallow features.
	def gFunction(self):
		with tf.variable_scope('gating', reuse=False):
			self.gFuncResults = tf.layers.dense(self.shallow_features, 1, activation=tf.sigmoid)

	# Budgeted gating function.
	def gStarFunc(self, i):
		if i < 4:
			gStarSum = 0
			for j in range(i):
				gStarSum = gStarSum + self.gStarFunc(j)
			gStarValue = (1 - gStarSum) * self.gFuncResults[i]
			return gStarValue
		elif i == 4:
			gStarSum = 0
			for j in range(4):
				gStarSum = gStarSum + self.gStarFunc(j)
			# The last block takes the probability mass left over by earlier gates.
			return 1 - gStarSum
		else:
			return 0
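	# For i < 4 the recursion computes g*_i = (1 - sum_{j<i} g*_j) * g_i: each
	# block's gate value is weighted by the probability that no earlier block has
	# halted, so the g*_i form a halting distribution over the five blocks.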

	# Gate loss formula from the paper.
	def gateLoss(self, lamda=0.5):
		total_gate_loss = 0
		# Table of incremental additional cost per block, as given in the paper.
		p_table = [1, 1.43, 3.35, 3.4, 0.95]
		tracking_loss = 0
		computational_loss = 0
		for i in range(5):
			gStarVal = self.gStarFunc(i)
			
			# tracking loss
			tracking_loss += gStarVal* self.block_losses[i]
			
			# computation loss
			computational_loss += p_table[i]* gStarVal
		
		# lamda trades off the tracking loss against the computational loss.
		total_gate_loss = tracking_loss + lamda * computational_loss
		self.total_gate_loss = tf.reduce_mean(total_gate_loss)
		print(self.total_gate_loss)

	# Intermediate supervision loss over all blocks... also prone to exploding gradients.
	def build_block_loss(self):
		cross_corr_arr = self.cross_corr
		self.block_losses = [self.block_loss(i) for i in cross_corr_arr]
		# Total loss: tf.losses.softmax_cross_entropy already registered each block
		# loss in the tf.GraphKeys.LOSSES collection.
		self.total_loss = tf.losses.get_total_loss()

	# Evaluation, which includes ACT (adaptive computation with early stopping).
	def evaluate(self, vgg_pretrain=None, thresh=0.5):
		with tf.name_scope("validate"):
			self.build_inputs()
			self.build_image_nets(vgg_pretrain=vgg_pretrain, reuse=True)
			self.build_blocks()
			self.shallow_feature_extractor()
			self.gFunction()
			self.final_cross_corr, self.final_flops, self.stop_index = self.act_early_stop(thresh=thresh)

	# Training graph construction.
	def build(self, reuse=False, vgg_pretrain=None):
		with tf.name_scope(self.mode):
			self.build_inputs()
			self.build_image_nets(reuse=reuse, vgg_pretrain=vgg_pretrain)
			self.build_blocks()
			self.shallow_feature_extractor()
			self.gFunction()
			self.build_block_loss()
			self.gateLoss()
			print("done")
Example #3
0
class MBST:
    
    def __init__(self, configuration):
        self.model_config = configuration.MODEL_CONFIG
        self.train_config = configuration.TRAIN_CONFIG
        self.data_config = self.train_config['train_data_config']
        self.mode = "train"
        
        
    def build_inputs(self):
        self.dataloader = DataLoader(self.data_config, True)
        self.dataloader.build()
        exemplars, instances, clusters = self.dataloader.get_one_batch()
        self.exemplars = tf.to_float(exemplars)
        self.instances = tf.to_float(instances)
        self.classid = clusters[0]
        
    def build_embedding(self):

        self.templates = tf.case(
                pred_fn_pairs=[
                        (tf.equal(self.classid, '0', name="eq0"), lambda : embed_fn_0(self.exemplars)),
                        (tf.equal(self.classid, '1', name="eq1"), lambda : embed_fn_1(self.exemplars)),
                        (tf.equal(self.classid, '2', name="eq2"), lambda : embed_fn_2(self.exemplars)),
                        (tf.equal(self.classid, '3', name="eq3"), lambda : embed_fn_3(self.exemplars)),
                        (tf.equal(self.classid, '4', name="eq4"), lambda : embed_fn_4(self.exemplars)),
                        (tf.equal(self.classid, '5', name="eq5"), lambda : embed_fn_5(self.exemplars)),
                        (tf.equal(self.classid, '6', name="eq6"), lambda : embed_fn_6(self.exemplars)),
                        (tf.equal(self.classid, '7', name="eq7"), lambda : embed_fn_7(self.exemplars)),
                        (tf.equal(self.classid, '8', name="eq8"), lambda : embed_fn_8(self.exemplars)),
                        (tf.equal(self.classid, '9', name="eq9"), lambda : embed_fn_9(self.exemplars)),
                        (tf.equal(self.classid, 'ori', name="eq_ori"), lambda : embed_ori(self.exemplars))],
                        exclusive=False,
                        name="case1")
        self.instance_embeds = tf.case(
                pred_fn_pairs=[
                        (tf.equal(self.classid, '0', name="eq0"), lambda : embed_fn_0(self.instances)),
                        (tf.equal(self.classid, '1', name="eq1"), lambda : embed_fn_1(self.instances)),
                        (tf.equal(self.classid, '2', name="eq2"), lambda : embed_fn_2(self.instances)),
                        (tf.equal(self.classid, '3', name="eq3"), lambda : embed_fn_3(self.instances)),
                        (tf.equal(self.classid, '4', name="eq4"), lambda : embed_fn_4(self.instances)),
                        (tf.equal(self.classid, '5', name="eq5"), lambda : embed_fn_5(self.instances)),
                        (tf.equal(self.classid, '6', name="eq6"), lambda : embed_fn_6(self.instances)),
                        (tf.equal(self.classid, '7', name="eq7"), lambda : embed_fn_7(self.instances)),
                        (tf.equal(self.classid, '8', name="eq8"), lambda : embed_fn_8(self.instances)),
                        (tf.equal(self.classid, '9', name="eq9"), lambda : embed_fn_9(self.instances)),
                        (tf.equal(self.classid, 'ori', name="eq_ori"), lambda : embed_ori(self.instances))],
                        exclusive=False,
                        name="case2")
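        # tf.case runs the branch of the first predicate that evaluates to true,
        # so the batch's cluster id ('0'..'9' or 'ori') selects one
        # cluster-specific embedding network for both exemplars and instances.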

    def build_detection(self, reuse=False):
        with tf.variable_scope('detection', reuse=reuse):
            def _translation_match(x, z):  # translation match for one example within a batch
                x = tf.expand_dims(x, 0)  # [1, in_height, in_width, in_channels]
                z = tf.expand_dims(z, -1)  # [filter_height, filter_width, in_channels, 1]
                return tf.nn.conv2d(x, z, strides=[1, 1, 1, 1], padding='VALID', name='translation_match')

            output = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                               (self.instance_embeds, self.templates),
                               dtype=self.instance_embeds.dtype)
            output = tf.squeeze(output, [1, 4])  # of shape e.g., [8, 15, 15]

            # Adjust score, this is required to make training possible.
            config = self.model_config['adjust_response_config']
            self.bias = tf.get_variable('biases', [1],
                                        dtype=tf.float32,
                                        initializer=tf.constant_initializer(0.0, dtype=tf.float32),
                                        trainable=config['train_bias'])
            response = config['scale'] * output + self.bias
            self.response = response

    def build_loss(self):
        response = self.response
        response_size = response.get_shape().as_list()[1:3]  # [height, width]

        self.gt = construct_gt_score_maps(response_size,
                                          self.data_config['batch_size'],
                                          self.model_config['embed_config']['stride'],
                                          self.train_config['gt_config'])

        with tf.name_scope('Loss'):
            loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=response,
                                                           labels=self.gt)
            with tf.name_scope('Balance_weights'):
                n_pos = tf.reduce_sum(tf.to_float(tf.equal(self.gt[0], 1)))
                n_neg = tf.reduce_sum(tf.to_float(tf.equal(self.gt[0], 0)))
                w_pos = 0.5 / n_pos
                w_neg = 0.5 / n_neg
                class_weights = tf.where(tf.equal(self.gt, 1),
                                         w_pos * tf.ones_like(self.gt),
                                         tf.ones_like(self.gt))
                class_weights = tf.where(tf.equal(self.gt, 0),
                                         w_neg * tf.ones_like(self.gt),
                                         class_weights)
                loss = loss * class_weights

            # Note that we use reduce_sum instead of reduce_mean since the loss has
            # already been normalized by class_weights over the spatial dimensions.
            loss = tf.reduce_sum(loss, [1, 2])
            
            batch_loss = tf.reduce_mean(loss, name='batch_loss')
            tf.losses.add_loss(batch_loss)
        
            total_loss = tf.losses.get_total_loss()
            self.batch_loss = batch_loss
            self.total_loss = total_loss

    def setup_global_step(self):
        global_step = tf.Variable(
                initial_value=0,
                name='global_step',
                trainable=False,
                collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

        self.global_step = global_step

    def build(self, reuse=False):
        """Creates all ops for training and evaluation"""
        with tf.name_scope(self.mode):
            self.build_inputs()
            self.build_embedding()
            self.build_detection(reuse=reuse)
            self.build_loss()
            self.setup_global_step()
Example #4
0
class SiameseModel:
  def __init__(self, model_config, train_config, mode='train'):
    self.model_config = model_config
    self.train_config = train_config
    self.mode = mode
    assert mode in ['train', 'validation', 'inference']

    if self.mode == 'train':
      self.data_config = self.train_config['train_data_config']
    elif self.mode == 'validation':
      self.data_config = self.train_config['validation_data_config']

    self.dataloader = None
    self.exemplars = None
    self.instances = None
    self.response = None
    self.batch_loss = None
    self.total_loss = None
    self.init_fn = None
    self.global_step = None

  def is_training(self):
    """Returns true if the model is built for training mode"""
    return self.mode == 'train'

  def build_inputs(self):
    """Input fetching and batching

    Outputs:
      self.exemplars: image batch of shape [batch, hz, wz, 3]
      self.instances: image batch of shape [batch, hx, wx, 3]
    """
    if self.mode in ['train', 'validation']:
      with tf.device("/cpu:0"):  # Putting data loading and preprocessing on the CPU is substantially faster
        self.dataloader = DataLoader(self.data_config, self.is_training())
        self.dataloader.build()
        exemplars, instances = self.dataloader.get_one_batch()

        exemplars = tf.to_float(exemplars)
        instances = tf.to_float(instances)
    else:
      self.examplar_feed = tf.placeholder(shape=[None, None, None, 3],
                                          dtype=tf.uint8,
                                          name='examplar_input')
      self.instance_feed = tf.placeholder(shape=[None, None, None, 3],
                                          dtype=tf.uint8,
                                          name='instance_input')
      exemplars = tf.to_float(self.examplar_feed)
      instances = tf.to_float(self.instance_feed)

    self.exemplars = exemplars
    self.instances = instances

  def build_image_embeddings(self, reuse=False):
    """Builds the image model subgraph and generates image embeddings

    Inputs:
      self.exemplars: A tensor of shape [batch, hz, wz, 3]
      self.instances: A tensor of shape [batch, hx, wx, 3]

    Outputs:
      self.exemplar_embeds: A Tensor of shape [batch, hz_embed, wz_embed, embed_dim]
      self.instance_embeds: A Tensor of shape [batch, hx_embed, wx_embed, embed_dim]
    """
    config = self.model_config['embed_config']
    arg_scope = convolutional_alexnet_arg_scope(config,
                                                trainable=config['train_embedding'],
                                                is_training=self.is_training())

    @functools.wraps(convolutional_alexnet)
    def embedding_fn(images, reuse=False):
      with slim.arg_scope(arg_scope):
        return convolutional_alexnet(images, reuse=reuse)

    self.exemplar_embeds, _ = embedding_fn(self.exemplars, reuse=reuse)
    self.instance_embeds, _ = embedding_fn(self.instances, reuse=True)

  def build_template(self):
    # The template is simply the feature of the exemplar image in SiamFC.
    self.templates = self.exemplar_embeds

  def build_detection(self, reuse=False):
    with tf.variable_scope('detection', reuse=reuse):
      def _translation_match(x, z):  # translation match for one example within a batch
        x = tf.expand_dims(x, 0)  # [1, in_height, in_width, in_channels]
        z = tf.expand_dims(z, -1)  # [filter_height, filter_width, in_channels, 1]
        return tf.nn.conv2d(x, z, strides=[1, 1, 1, 1], padding='VALID', name='translation_match')

      output = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                         (self.instance_embeds, self.templates),
                         dtype=self.instance_embeds.dtype)
      output = tf.squeeze(output, [1, 4])  # of shape e.g., [8, 15, 15]

      # Adjust score, this is required to make training possible.
      config = self.model_config['adjust_response_config']
      bias = tf.get_variable('biases', [1],
                             dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0, dtype=tf.float32),
                             trainable=config['train_bias'])
      response = config['scale'] * output + bias
      self.response = response

  def build_loss(self):
    response = self.response
    response_size = response.get_shape().as_list()[1:3]  # [height, width]

    gt = construct_gt_score_maps(response_size,
                                 self.data_config['batch_size'],
                                 self.model_config['embed_config']['stride'],
                                 self.train_config['gt_config'])

    with tf.name_scope('Loss'):
      loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=response,
                                                     labels=gt)

      with tf.name_scope('Balance_weights'):
        n_pos = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 1)))
        n_neg = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 0)))
        w_pos = 0.5 / n_pos
        w_neg = 0.5 / n_neg
        class_weights = tf.where(tf.equal(gt, 1),
                                 w_pos * tf.ones_like(gt),
                                 tf.ones_like(gt))
        class_weights = tf.where(tf.equal(gt, 0),
                                 w_neg * tf.ones_like(gt),
                                 class_weights)
        loss = loss * class_weights

      # Note that we use reduce_sum instead of reduce_mean since the loss has
      # already been normalized by class_weights over the spatial dimensions.
      loss = tf.reduce_sum(loss, [1, 2])

      batch_loss = tf.reduce_mean(loss, name='batch_loss')
      tf.losses.add_loss(batch_loss)

      total_loss = tf.losses.get_total_loss()
      self.batch_loss = batch_loss
      self.total_loss = total_loss

      tf.summary.image('exemplar', self.exemplars, family=self.mode)
      tf.summary.image('instance', self.instances, family=self.mode)

      mean_batch_loss, update_op1 = tf.metrics.mean(batch_loss)
      mean_total_loss, update_op2 = tf.metrics.mean(total_loss)
      with tf.control_dependencies([update_op1, update_op2]):
        tf.summary.scalar('batch_loss', mean_batch_loss, family=self.mode)
        tf.summary.scalar('total_loss', mean_total_loss, family=self.mode)

      if self.mode == 'train':
        tf.summary.image('GT', tf.reshape(gt[0], [1] + response_size + [1]), family='GT')
      tf.summary.image('Response', tf.expand_dims(tf.sigmoid(response), -1), family=self.mode)
      tf.summary.histogram('Response', self.response, family=self.mode)

      # Two more metrics to monitor the performance of training
      tf.summary.scalar('center_score_error', center_score_error(response), family=self.mode)
      tf.summary.scalar('center_dist_error', center_dist_error(response), family=self.mode)

  def setup_global_step(self):
    global_step = tf.Variable(
      initial_value=0,
      name='global_step',
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

    self.global_step = global_step

  def setup_embedding_initializer(self):
    """Sets up the function to restore embedding variables from checkpoint."""
    embed_config = self.model_config['embed_config']
    if embed_config['embedding_checkpoint_file']:
      # Restore Siamese FC models from .mat model files
      initialize = load_mat_model(embed_config['embedding_checkpoint_file'],
                                  'convolutional_alexnet/', 'detection/')

      def restore_fn(sess):
        tf.logging.info("Restoring embedding variables from checkpoint file %s",
                        embed_config['embedding_checkpoint_file'])
        sess.run([initialize])

      self.init_fn = restore_fn

  def build(self, reuse=False):
    """Creates all ops for training and evaluation"""
    with tf.name_scope(self.mode):
      self.build_inputs()
      self.build_image_embeddings(reuse=reuse)
      self.build_template()
      self.build_detection(reuse=reuse)
      self.setup_embedding_initializer()

      if self.mode in ['train', 'validation']:
        self.build_loss()

      if self.is_training():
        self.setup_global_step()
Example #5
0
class SiameseModel:
  def __init__(self, model_config, train_config, mode='train'):
    self.model_config = model_config
    self.train_config = train_config
    self.mode = mode
    assert mode in ['train', 'validation', 'inference']

    if self.mode == 'train':
      self.data_config = self.train_config['train_data_config']
    elif self.mode == 'validation':
      self.data_config = self.train_config['validation_data_config']

    self.dataloader = None
    self.exemplars = None
    self.instances = None
    self.response = None
    self.batch_loss = None
    self.total_loss = None
    self.init_fn = None
    self.global_step = None

  def is_training(self):
    """Returns true if the model is built for training mode"""
    return self.mode == 'train'

  def build_inputs(self):
    """Input fetching and batching

    Outputs:
      self.exemplars: image batch of shape [batch, hz, wz, 3]
      self.instances: image batch of shape [batch, hx, wx, 3]
    """
    if self.mode in ['train', 'validation']:
      with tf.device("/cpu:0"):  # Putting data loading and preprocessing on the CPU is substantially faster
        self.dataloader = DataLoader(self.data_config, self.is_training())
        self.dataloader.build()
        exemplars, instances = self.dataloader.get_one_batch()

        exemplars = tf.to_float(exemplars)
        instances = tf.to_float(instances)
    else:
      self.examplar_feed = tf.placeholder(shape=[None, None, None, 3],
                                          dtype=tf.uint8,
                                          name='examplar_input')
      self.instance_feed = tf.placeholder(shape=[None, None, None, 3],
                                          dtype=tf.uint8,
                                          name='instance_input')
      exemplars = tf.to_float(self.examplar_feed)
      instances = tf.to_float(self.instance_feed)

    self.exemplars = exemplars
    self.instances = instances

  def build_image_embeddings(self, reuse=False):
    """Builds the image model subgraph and generates image embeddings

    Inputs:
      self.exemplars: A tensor of shape [batch, hz, wz, 3]
      self.instances: A tensor of shape [batch, hx, wx, 3]

    Outputs:
      self.exemplar_embeds: A Tensor of shape [batch, hz_embed, wz_embed, embed_dim]
      self.instance_embeds: A Tensor of shape [batch, hx_embed, wx_embed, embed_dim]
    """
    config = self.model_config['embed_config']
    arg_scope = sa_siam_arg_scope(config,
                                  trainable=config['train_embedding'],
                                  is_training=self.is_training())
    with slim.arg_scope(arg_scope):
      self.exemplar_embeds, _ = sa_siam(inputs=self.exemplars, is_example=True, reuse=reuse, sa_siam_config=self.model_config['sa_siam_config'])
      self.instance_embeds, _ = sa_siam(inputs=self.instances, is_example=False, reuse=True, sa_siam_config=self.model_config['sa_siam_config'])

  def build_template(self):
    # The template is simply the feature of the exemplar image in SiamFC.
    self.templates = self.exemplar_embeds

  def build_detection(self, reuse=False):
    with tf.variable_scope('detection', reuse=reuse):
      def _translation_match(x, z):  # translation match for one example within a batch
        x = tf.expand_dims(x, 0)  # [1, in_height, in_width, in_channels]
        z = tf.expand_dims(z, -1)  # [filter_height, filter_width, in_channels, 1]
        return tf.nn.conv2d(x, z, strides=[1, 1, 1, 1], padding='VALID', name='translation_match')

      output = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                         (self.instance_embeds, self.templates),
                         dtype=self.instance_embeds.dtype)
      output = tf.squeeze(output, [1, 4])  # of shape e.g., [8, 15, 15]

      # Adjust score, this is required to make training possible.
      config = self.model_config['adjust_response_config']
      bias = tf.get_variable('biases', [1],
                             dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0, dtype=tf.float32),
                             trainable=config['train_bias'])
      response = config['scale'] * output + bias
      self.response = response

  def build_loss(self):
    response = self.response
    response_size = response.get_shape().as_list()[1:3]  # [height, width]

    gt = construct_gt_score_maps(response_size,
                                 self.data_config['batch_size'],
                                 self.model_config['embed_config']['stride'],
                                 self.train_config['gt_config'])

    with tf.name_scope('Loss'):
      loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=response,
                                                     labels=gt)

      with tf.name_scope('Balance_weights'):
        n_pos = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 1)))
        n_neg = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 0)))
        w_pos = 0.5 / n_pos
        w_neg = 0.5 / n_neg
        class_weights = tf.where(tf.equal(gt, 1),
                                 w_pos * tf.ones_like(gt),
                                 tf.ones_like(gt))
        class_weights = tf.where(tf.equal(gt, 0),
                                 w_neg * tf.ones_like(gt),
                                 class_weights)
        loss = loss * class_weights

      # Note that we use reduce_sum instead of reduce_mean since the loss has
      # already been normalized by class_weights over the spatial dimensions.
      loss = tf.reduce_sum(loss, [1, 2])

      batch_loss = tf.reduce_mean(loss, name='batch_loss')
      tf.losses.add_loss(batch_loss)

      total_loss = tf.losses.get_total_loss()
      self.batch_loss = batch_loss
      self.total_loss = total_loss

      tf.summary.image('exemplar', self.exemplars, family=self.mode)
      tf.summary.image('instance', self.instances, family=self.mode)

      mean_batch_loss, update_op1 = tf.metrics.mean(batch_loss)
      mean_total_loss, update_op2 = tf.metrics.mean(total_loss)
      with tf.control_dependencies([update_op1, update_op2]):
        tf.summary.scalar('batch_loss', mean_batch_loss, family=self.mode)
        tf.summary.scalar('total_loss', mean_total_loss, family=self.mode)

      if self.mode == 'train':
        tf.summary.image('GT', tf.reshape(gt[0], [1] + response_size + [1]), family='GT')
      tf.summary.image('Response', tf.expand_dims(tf.sigmoid(response), -1), family=self.mode)
      tf.summary.histogram('Response', self.response, family=self.mode)

      # Two more metrics to monitor the performance of training
      tf.summary.scalar('center_score_error', center_score_error(response), family=self.mode)
      tf.summary.scalar('center_dist_error', center_dist_error(response), family=self.mode)

  def setup_global_step(self):
    global_step = tf.Variable(
      initial_value=0,
      name='global_step',
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

    self.global_step = global_step

  def setup_embedding_initializer(self):
    """Sets up the function to restore embedding variables from checkpoint."""
    embed_config = self.model_config['embed_config']
    if embed_config['embedding_checkpoint_file']:
      # Restore Siamese FC models from .mat model files
      initialize = load_mat_model(embed_config['embedding_checkpoint_file'],
                                  'sa_siam/appearance_net/', 'detection/')

      def restore_fn(sess):
        tf.logging.info("Restoring embedding variables from checkpoint file %s",
                        embed_config['embedding_checkpoint_file'])
        sess.run([initialize])

      self.init_fn = restore_fn

  def build(self, reuse=False):
    """Creates all ops for training and evaluation"""
    with tf.name_scope(self.mode):
      self.build_inputs()
      self.build_image_embeddings(reuse=reuse)
      self.build_template()
      self.build_detection(reuse=reuse)
      self.setup_embedding_initializer()

      if self.mode in ['train', 'validation']:
        self.build_loss()

      if self.is_training():
        self.setup_global_step()
Example #6
0
def main():
    # Select gpu to run
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = train_config['gpu_select']
    gpu_list = [int(i) for i in train_config['gpu_select'].split(',')]

    # Create training directory
    os.makedirs(train_config['train_dir'], exist_ok=True)
    os.makedirs(train_config['checkpoint_dir'], exist_ok=True)
    os.makedirs(train_config['log_dir'], exist_ok=True)
    os.makedirs(train_config['config_saver_dir'], exist_ok=True)

    # Save configurations .json in train_dir
    save_cfgs(train_config['config_saver_dir'], model_config, train_config)

    g = tf.Graph()
    with g.as_default():
        global_step = tf.Variable(initial_value=0,
                                  name='global_step',
                                  trainable=False,
                                  collections=[
                                      tf.GraphKeys.GLOBAL_STEP,
                                      tf.GraphKeys.GLOBAL_VARIABLES
                                  ])
        learning_rate = _configure_learning_rate(train_config, global_step)
        # Log the learning rate to TensorBoard scalars.
        tf.summary.scalar('learning_rate', learning_rate)
        opt = _configure_optimizer(train_config, learning_rate)
        tower_grads = []  # gradients collected from each GPU tower

        # Build dataloaders
        ## training dataloader
        train_data_config = train_config['train_data_config']
        with tf.device("/cpu:0"):
            train_dataloader = DataLoader(train_data_config, is_training=True)
            train_dataloader.build()
        ## validation dataloader
        validate_data_config = train_config['validation_data_config']
        with tf.device("/cpu:0"):
            validate_dataloader = DataLoader(validate_data_config,
                                             is_training=False)
            validate_dataloader.build()

        # Build the network on multiple GPUs within a single graph
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(len(gpu_list)):
                # build the same graph on each GPU
                with tf.device('/gpu:%d' % i):
                    # name scope to distinguish ops created on different GPUs
                    with tf.name_scope('model_%d' % i):
                        if i == 0:
                            train_image_instances, train_label = \
                                train_dataloader.get_one_batch()
                            train_image_instances = tf.to_float(train_image_instances)
                            train_label = tf.to_int32(train_label)
                            model = ModelConstruct(model_config,
                                                   train_config,
                                                   train_image_instances,
                                                   train_label,
                                                   mode='train')
                            model.build()
                            with tf.device("/cpu:0"):
                                tf.summary.scalar('total_loss',
                                                  model.total_loss,
                                                  family='train')
                                tf.summary.scalar('acc',
                                                  model.acc,
                                                  family='train')
                            # validation model: shares weights with the training model
                            validate_image_instances, validate_label = \
                                validate_dataloader.get_one_batch()
                            validate_image_instances = tf.to_float(validate_image_instances)
                            validate_label = tf.to_int32(validate_label)
                            model_va = ModelConstruct(model_config,
                                                      train_config,
                                                      validate_image_instances,
                                                      validate_label,
                                                      mode='validation')
                            model_va.build(reuse=True)
                            with tf.device("/cpu:0"):
                                tf.summary.scalar('total_loss',
                                                  model_va.total_loss,
                                                  family='validation')
                                tf.summary.scalar('acc',
                                                  model_va.acc,
                                                  family='validation')
                        else:
                            train_image_instances, train_label = \
                                train_dataloader.get_one_batch()
                            train_image_instances = tf.to_float(train_image_instances)
                            train_label = tf.to_int32(train_label)
                            model = ModelConstruct(model_config,
                                                   train_config,
                                                   train_image_instances,
                                                   train_label,
                                                   mode='train')
                            model.build(reuse=True)

                        tf.get_variable_scope().reuse_variables()
                        grad = opt.compute_gradients(model.total_loss)
                        tower_grads.append(grad)

        mean_grads = average_gradients(tower_grads)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
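        # Gating train_op on UPDATE_OPS ensures batch-norm moving-average
        # updates run with every training step.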
        with tf.control_dependencies(update_ops):
            train_op = opt.apply_gradients(mean_grads, global_step=global_step)

        # save checkpoint
        saver = tf.train.Saver(
            tf.global_variables(),
            max_to_keep=train_config['max_checkpoints_to_keep'])
        # summary writer; passing the graph also logs it for TensorBoard
        summary_writer = tf.summary.FileWriter(train_config['log_dir'], g)
        summary_op = tf.summary.merge_all()

        global_variables_init_op = tf.global_variables_initializer()
        local_variables_init_op = tf.local_variables_initializer()
        g.finalize()  # Finalize graph to avoid adding ops by mistake

        # Dynamically allocate GPU memory
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(gpu_options=gpu_options)

        sess = tf.Session(config=sess_config)
        model_path = tf.train.latest_checkpoint(train_config['checkpoint_dir'])
        # resume from the latest checkpoint, or start a fresh run
        if not model_path:
            sess.run(global_variables_init_op)
            sess.run(local_variables_init_op)
            start_step = 0
        else:
            logging.info('Restore from last checkpoint: {}'.format(model_path))
            sess.run(local_variables_init_op)
            saver.restore(sess, model_path)
            start_step = tf.train.global_step(sess, global_step) + 1

        # Training loop
        start_time = time.time()
        data_config = train_config['train_data_config']
        total_steps = int(data_config['epoch'] *
                          data_config['num_examples_per_epoch'] /
                          (data_config['batch_size'] * len(gpu_list)))
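        # e.g. with hypothetical values epoch=10, num_examples_per_epoch=64000,
        # batch_size=8 and 2 GPUs: 10 * 64000 / (8 * 2) = 40000 steps.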
        logging.info('Train for {} steps'.format(total_steps))
        for step in range(start_step, total_steps):
            _, loss = sess.run([
                train_op,
                model.total_loss,
            ])
            if step % train_config['log_every_n_steps'] == 0:
                logging.info(
                    '{}-->step {:d} - ({:.2f}%), total loss = {:.2f} '.format(
                        datetime.now(), step,
                        float(step) / total_steps * 100, loss))
            # update the TensorBoard summaries every tensorboard_summary_every_n_steps steps
            if step % train_config['tensorboard_summary_every_n_steps'] == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
            # save a checkpoint every save_model_every_n_step steps and at the final step
            if step % train_config['save_model_every_n_step'] == 0 or (
                    step + 1) == total_steps:
                saver.save(sess,
                           os.path.join(train_config['checkpoint_dir'],
                                        'model.ckpt'),
                           global_step=step)
        duration = time.time() - start_time
        m, s = divmod(duration, 60)
        h, m = divmod(m, 60)
        print('The total training loop finished after {:d}h:{:02d}m:{:02d}s'.
              format(int(h), int(m), int(s)))
        sess.close()
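
Example #6 relies on an average_gradients helper that is not shown in this
snippet. A minimal sketch in the style of the standard TensorFlow multi-tower
pattern follows; this is an assumption about its shape, not necessarily the
project's exact implementation.

def average_gradients(tower_grads):
    """Averages (gradient, variable) pairs across GPU towers."""
    average_grads = []
    # zip(*tower_grads) regroups the pairs so that each iteration sees one
    # variable together with its gradient from every tower
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        mean_grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # The variable object is shared across towers; take it from the first
        average_grads.append((mean_grad, grad_and_vars[0][1]))
    return average_grads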