Ejemplo n.º 1
0
    def init_networks(self):
        """Build the validation graph: input placeholders plus the embedding net.

        Creates placeholders from the configured image/output sizes and wires a
        HourglassNet3D-based network whose two embedding outputs are
        L2-normalized along the channel axis. Only the validation graph is
        built here (``is_training=False``).
        """
        network_image_size = self.image_size
        network_output_size = self.output_size

        # Placeholder shapes depend on where the channel dimension lives;
        # None leaves the number of instances unspecified.
        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1, self.num_frames] + network_image_size),
                ('instances_merged', [None] + network_output_size),
                ('instances_bac', [None] + network_output_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [None]),
                ('instances_bac', network_output_size + [None])
            ])

        # build val graph (batch dimension fixed to 1 via shape_prefix)
        val_placeholders = create_placeholders(data_generator_entries,
                                               shape_prefix=[1])
        self.data_val = val_placeholders['image']

        with tf.variable_scope('net'):
            self.embeddings_0, self.embeddings_1 = network(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format,
                actual_network=HourglassNet3D,
                is_training=False)
            # 'axis' replaces the deprecated 'dim' kwarg of l2_normalize;
            # NOTE(review): assumes TF >= 1.3 where 'axis' exists — consistent
            # with the softmax(axis=...) usage elsewhere in this codebase.
            self.embeddings_normalized_0 = tf.nn.l2_normalize(
                self.embeddings_0, axis=self.channel_axis)
            self.embeddings_normalized_1 = tf.nn.l2_normalize(
                self.embeddings_1, axis=self.channel_axis)
            self.embeddings_cropped_val = (self.embeddings_normalized_0,
                                           self.embeddings_normalized_1)
Ejemplo n.º 2
0
    def init_networks(self):
        """Build the validation graph for single-channel image input.

        Creates a batch-1 image placeholder shaped per ``data_format``, runs
        the shared-weight network on it, and exposes the raw and softmaxed
        predictions as attributes.
        """
        image_size = self.image_size

        # Channel dimension goes first or last depending on data_format.
        channels_first = self.data_format == 'channels_first'
        image_shape = [1] + image_size if channels_first else image_size + [1]
        data_generator_entries = OrderedDict([('image', image_shape)])

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build val graph (batch dimension fixed to 1)
        val_placeholders = create_placeholders(data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        (self.prediction_val,
         self.local_prediction_val,
         self.spatial_prediction_val) = training_net(self.data_val,
                                                     num_labels=self.num_labels,
                                                     is_training=False,
                                                     data_format=self.data_format)
        # Softmax over the class/label dimension: axis 1 for NC..., axis 4 for ...C.
        softmax_axis = 1 if channels_first else 4
        self.prediction_softmax_val = tf.nn.softmax(self.prediction_val, axis=softmax_axis)
Ejemplo n.º 3
0
    def initNetworks(self):
        """Build the SCN training and validation graphs.

        Sets up the training data queue, learnable per-landmark sigmas for
        heatmap generation, the training loss and Momentum optimizer, and a
        weight-sharing validation graph (via tf.make_template).

        Fix: ``self.loss_reg`` was previously only assigned when
        ``self.reg_constant > 0`` but is read unconditionally when building
        ``self.train_losses`` and ``self.val_losses``, raising AttributeError
        for ``reg_constant <= 0``. It is now always assigned.
        """
        net = tf.make_template('scn', self.network)

        # Data-generator entry shapes; landmarks are (num_landmarks, 4) rows,
        # heatmap/image shapes are reversed with the channel dimension placed
        # according to data_format.
        if self.data_format == 'channels_first':
            if self.generate_landmarks:
                data_generator_entries = OrderedDict([
                    ('image',
                     [self.image_channels] + list(reversed(self.image_size))),
                    ('landmarks', [self.num_landmarks, 4])
                ])
            else:
                data_generator_entries = OrderedDict([
                    ('image',
                     [self.image_channels] + list(reversed(self.image_size))),
                    ('heatmaps',
                     [self.num_landmarks] + list(reversed(self.heatmap_size)))
                ])
        else:
            if self.generate_landmarks:
                data_generator_entries = OrderedDict([
                    ('image',
                     list(reversed(self.image_size)) + [self.image_channels]),
                    ('landmarks', [self.num_landmarks, 4])
                ])
            else:
                data_generator_entries = OrderedDict([
                    ('image',
                     list(reversed(self.image_size)) + [self.image_channels]),
                    ('heatmaps',
                     list(reversed(self.heatmap_size)) + [self.num_landmarks])
                ])

        # Learnable per-landmark Gaussian sigmas for the generated heatmaps.
        sigmas = tf.get_variable('sigmas', [self.num_landmarks],
                                 initializer=tf.constant_initializer(
                                     self.heatmap_sigma))
        mean_sigmas = tf.reduce_mean(sigmas)
        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        placeholders = self.train_queue.dequeue()
        image = placeholders[0]

        if self.generate_landmarks:
            # Heatmap targets are generated on the fly from landmark coords.
            target_landmarks = placeholders[1]
            target_heatmaps = generate_heatmap_target(
                list(reversed(self.heatmap_size)),
                target_landmarks,
                sigmas,
                scale=self.sigma_scale,
                normalize=True,
                data_format=self.data_format)
            # Regularize sigmas, masked by landmark validity (column 0).
            loss_sigmas = self.sigma_regularization * tf.nn.l2_loss(
                sigmas * target_landmarks[0, :, 0])
        else:
            target_heatmaps = placeholders[1]
            loss_sigmas = self.sigma_regularization * tf.nn.l2_loss(sigmas)
        heatmaps, _, _ = net(image,
                             num_heatmaps=self.num_landmarks,
                             is_training=True,
                             data_format=self.data_format)
        self.loss_net = self.loss_function(heatmaps, target_heatmaps)

        # Depend on UPDATE_OPS so batch-norm-style statistics are refreshed.
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
            else:
                # Always define loss_reg: it is reported in train/val losses
                # below regardless of reg_constant.
                self.loss_reg = tf.constant(0, dtype=tf.float32)
            self.loss = self.loss_net + self.loss_reg + loss_sigmas

        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg),
                                         ('loss_sigmas', loss_sigmas),
                                         ('mean_sigmas', mean_sigmas)])
        self.optimizer = tf.train.MomentumOptimizer(
            learning_rate=self.learning_rate, momentum=0.99,
            use_nesterov=True).minimize(self.loss)

        # build val graph (batch dimension fixed to 1)
        val_placeholders = create_placeholders(data_generator_entries,
                                               shape_prefix=[1])
        self.image_val = val_placeholders['image']
        if self.generate_landmarks:
            self.target_landmarks_val = val_placeholders['landmarks']
            self.target_heatmaps_val = generate_heatmap_target(
                list(reversed(self.heatmap_size)),
                self.target_landmarks_val,
                sigmas,
                scale=self.sigma_scale,
                normalize=True,
                data_format=self.data_format)
            loss_sigmas_val = self.sigma_regularization * tf.nn.l2_loss(
                sigmas * self.target_landmarks_val[0, :, 0])
        else:
            self.target_heatmaps_val = val_placeholders['heatmaps']
            loss_sigmas_val = self.sigma_regularization * tf.nn.l2_loss(sigmas)
        self.heatmaps_val, self.heatmals_local_val, self.heatmaps_global_val = net(
            self.image_val,
            num_heatmaps=self.num_landmarks,
            is_training=False,
            data_format=self.data_format)
        # Correctly-spelled alias for the typo'd attribute above; the old name
        # is kept for backward compatibility with existing callers.
        self.heatmaps_local_val = self.heatmals_local_val

        # losses
        self.loss_val = self.loss_function(self.heatmaps_val,
                                           self.target_heatmaps_val)
        self.val_losses = OrderedDict([('loss', self.loss_val),
                                       ('loss_reg', self.loss_reg),
                                       ('loss_sigmas', loss_sigmas_val),
                                       ('mean_sigmas', mean_sigmas)])
Ejemplo n.º 4
0
    def init_networks(self):
        """
        Init training and validation networks.

        Builds the padded training queue, the LSTM embedding network with
        shared weights, the losses and Adam optimizer with a piecewise-constant
        learning rate, and a single-frame validation graph that exposes the
        LSTM states explicitly.

        Fix: the ``channels_last`` guard used ``assert '...'`` on a non-empty
        string, which is always truthy and therefore never fired; it now
        raises NotImplementedError as intended.
        """
        network_image_size = list(reversed(self.image_size))
        # With bitwise-packed instance images the instance dim is exactly 1;
        # otherwise it is left unspecified (None).
        num_instances = 1 if self.bitwise_instance_image else None
        num_instances_val = None

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1, self.num_frames] + network_image_size),
                ('instances_merged',
                 [num_instances, self.num_frames] + network_image_size),
                ('instances_bac',
                 [num_instances, self.num_frames] + network_image_size)
            ])
            data_generator_entries_test_cropped_single_frame = OrderedDict([
                ('image', [1] + network_image_size),
                ('instances_merged', [num_instances_val] + network_image_size),
                ('instances_bac', [num_instances_val] + network_image_size)
            ])
            embedding_normalization_function = lambda x: tf.nn.l2_normalize(
                x, dim=self.channel_axis)
        else:
            # Was: assert 'channels_last not supported' — a truthy string that
            # never fails; raise so unsupported configs stop here.
            raise NotImplementedError('channels_last not supported')
        data_generator_types = {
            'image': tf.float32,
            'instances_merged': self.bitwise_instances_image_type,
            'instances_bac': self.bitwise_instances_image_type
        }

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGeneratorPadding(
            self.dataset_train,
            self.coord,
            data_generator_entries,
            batch_size=self.batch_size,
            data_types=data_generator_types,
            n_threads=4)

        data, instances_tra, instances_bac = self.train_queue.dequeue()
        embeddings_tuple = training_net(
            data,
            num_outputs_embedding=self.num_embeddings,
            is_training=True,
            data_format=self.data_format,
            actual_network=self.actual_network,
            **self.network_parameters)

        # Downstream code expects a tuple of embeddings even for one output.
        if not isinstance(embeddings_tuple, tuple):
            embeddings_tuple = (embeddings_tuple, )

        loss_reg = get_reg_loss(self.reg_constant, True)

        with tf.name_scope('train_loss'):
            train_losses_dict = self.losses(
                embeddings_tuple,
                instances_tra,
                instances_bac,
                bitwise_instances=self.bitwise_instance_image)
            train_losses_dict['loss_reg'] = loss_reg
            self.loss = tf.reduce_sum(list(train_losses_dict.values()))
            self.train_losses = train_losses_dict

        # solver: piecewise-constant LR schedule driven by a global step that
        # resumes from the current iteration.
        global_step = tf.Variable(self.current_iter)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.optimizer = optimizer.minimize(self.loss, global_step=global_step)

        # initialize variables
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())

        print('Variables')
        for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            print(i)

        # build val graph: single cropped frame with explicit LSTM states
        val_placeholders_cropped = create_placeholders(
            data_generator_entries_test_cropped_single_frame, shape_prefix=[1])
        self.data_cropped_val = val_placeholders_cropped['image']
        self.instances_cropped_tra_val = val_placeholders_cropped[
            'instances_merged']
        self.instances_cropped_bac_val = val_placeholders_cropped[
            'instances_bac']
        with tf.variable_scope('net/rnn', reuse=True):
            output_tuple = network_single_frame_with_lstm_states(
                self.data_cropped_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format,
                actual_network=self.actual_network,
                **self.network_parameters)
            self.lstm_input_states_cropped_val = output_tuple[0]
            self.lstm_output_states_cropped_val = output_tuple[1]
            self.embeddings_cropped_val = output_tuple[2:]

        if not isinstance(self.embeddings_cropped_val, tuple):
            self.embeddings_cropped_val = (self.embeddings_cropped_val, )

        with tf.variable_scope('loss'):
            val_losses_dict = self.losses(self.embeddings_cropped_val,
                                          self.instances_cropped_tra_val,
                                          self.instances_cropped_bac_val,
                                          bitwise_instances=False)
            val_losses_dict['loss_reg'] = loss_reg
            self.loss_val = tf.reduce_sum(list(val_losses_dict.values()))
            self.val_losses = val_losses_dict

        # NOTE(review): normalization is applied when normalized_embeddings is
        # False — presumably "network did not normalize, so do it here";
        # confirm against callers.
        if not self.normalized_embeddings:
            self.embeddings_cropped_val = tuple([
                embedding_normalization_function(e)
                for e in self.embeddings_cropped_val
            ])