Exemple #1
0
    def train_fit(self):
        """
        Build the network on TFRecord input tensors, compile it and train it.

        Callbacks are taken from ``self._prepare_callback()`` and the optimizer
        from ``self._get_optimizer()``. Training and validation tensors are read
        from ``self.tf_train_path`` and ``self.tf_eval_path``. When
        ``self.weight`` is not None the weights are loaded from it before
        compiling.

        :return: None (the object returned by ``fit`` is not exposed)
        """
        record_reader = TFRecordUtility(self.output_len)

        # Callbacks first, then the optimizer.
        fit_callbacks = self._prepare_callback()
        opt = self._get_optimizer()

        # Input pipelines: one (images, landmarks) tensor pair per split.
        train_x, train_y = record_reader.create_training_tensor_points(
            tfrecord_filename=self.tf_train_path,
            batch_size=self.BATCH_SIZE)
        val_x, val_y = record_reader.create_training_tensor_points(
            tfrecord_filename=self.tf_eval_path,
            batch_size=self.BATCH_SIZE)

        # The training image tensor is passed both as `train_images` and as the
        # model's `input_tensor`.
        net = CNNModel().get_model(
            train_images=train_x,
            arch=self.arch,
            num_output_layers=self.num_output_layers,
            output_len=self.output_len,
            input_tensor=train_x,
            inp_shape=None)
        if self.weight is not None:
            net.load_weights(self.weight)

        # Gather the compile arguments (same evaluation order as the call).
        losses = self._generate_loss()
        targets = self._generate_target_tensors(train_y)
        weights_per_loss = self._generate_loss_weights()
        net.compile(loss=losses,
                    optimizer=opt,
                    metrics=['mse', 'mae'],
                    target_tensors=targets,
                    loss_weights=weights_per_loss)

        print('< ========== Start Training ============= >')

        net.fit(train_x,
                train_y,
                epochs=self.EPOCHS,
                steps_per_epoch=self.STEPS_PER_EPOCH,
                validation_data=(val_x, val_y),
                validation_steps=self.STEPS_PER_VALIDATION_EPOCH,
                verbose=1,
                callbacks=fit_callbacks)
    def train(self, teachers_arch, teachers_weight_files, teachers_weight_loss,
              teachers_tf_train_paths, student_weight_file):
        """
        Train the student network with a loss built from teacher networks.

        Every teacher is created with ``CNNModel.get_model`` and has its
        weights loaded from the matching file. The teachers are then passed to
        ``Custom_losses.custom_teacher_student_loss``, whose result is the loss
        the student model is compiled with.

        :param teachers_arch: an array containing architecture of teacher networks
        :param teachers_weight_files: an array containing teachers h5 files
            (indexed in step with ``teachers_arch``)
        :param teachers_weight_loss: an array containing weight of teachers model in loss function
        :param teachers_tf_train_paths: an array containing path of train tf records
            (indexed in step with ``teachers_arch``)
        :param student_weight_file: student h5 weight path; when None no
            weights are loaded into the student model
        :return: None
        """

        tf_record_util = TFRecordUtility(self.output_len)
        c_loss = Custom_losses()

        # -------------------------------------
        #      preparing teacher models
        # -------------------------------------
        teacher_models = []
        cnn = CNNModel()
        for idx, teacher_arch in enumerate(teachers_arch):
            # Only the image tensor is used to build a teacher; the landmark
            # tensor from the same record is not needed here.
            teacher_images, _ = tf_record_util.create_training_tensor_points(
                tfrecord_filename=teachers_tf_train_paths[idx],
                batch_size=self.BATCH_SIZE)
            teacher_model = cnn.get_model(train_images=teacher_images,
                                          arch=teacher_arch,
                                          num_output_layers=1,
                                          output_len=self.output_len,
                                          input_tensor=None)

            teacher_model.load_weights(teachers_weight_files[idx])
            teacher_models.append(teacher_model)

        # ---------------------------------
        #      creating student model
        # ---------------------------------
        # retrieve tf data
        train_images, train_landmarks = tf_record_util.create_training_tensor_points(
            tfrecord_filename=self.tf_train_path, batch_size=self.BATCH_SIZE)
        validation_images, validation_landmarks = tf_record_util.create_training_tensor_points(
            tfrecord_filename=self.tf_eval_path, batch_size=self.BATCH_SIZE)

        # create model
        student_model = cnn.get_model(train_images=train_images,
                                      arch=self.arch,
                                      num_output_layers=1,
                                      output_len=self.output_len,
                                      input_tensor=train_images,
                                      inp_shape=None)
        if student_weight_file is not None:
            student_model.load_weights(student_weight_file)

        # prepare callbacks
        callbacks_list = self._prepare_callback()

        # define optimizers
        optimizer = Adam(lr=1e-3,
                         beta_1=0.9,
                         beta_2=0.999,
                         decay=1e-5,
                         amsgrad=False)

        # create loss
        # The map is read from "map_orig<dataset_name>"; an earlier variant of
        # this code read "map_aug<dataset_name>" instead.
        # NOTE: pickle.load can execute arbitrary code contained in the file,
        # so this must only be pointed at map files produced by this project.
        # The `with` block closes the file even if unpickling fails.
        with open("map_orig" + self.dataset_name, 'rb') as map_file:
            landmark_img_map = pickle.load(map_file)

        # `bath_size` (sic) is the keyword name this call has always used.
        loss_func = c_loss.custom_teacher_student_loss(
            img_path=self.img_path,
            lnd_img_map=landmark_img_map,
            teacher_models=teacher_models,
            teachers_weight_loss=teachers_weight_loss,
            bath_size=self.BATCH_SIZE,
            num_points=self.output_len,
            ds_name=self.dataset_name,
            loss_type=0)

        # compiling model
        student_model.compile(loss=loss_func,
                              optimizer=optimizer,
                              metrics=['mse', 'mae'],
                              target_tensors=train_landmarks)

        print('< ========== Start Training Student============= >')
        student_model.fit(
            train_images,
            train_landmarks,
            epochs=self.EPOCHS,
            steps_per_epoch=self.STEPS_PER_EPOCH,
            validation_data=(validation_images, validation_landmarks),
            validation_steps=self.STEPS_PER_VALIDATION_EPOCH,
            verbose=1,
            callbacks=callbacks_list)