def train_fit(self):
    """Train the model configured on this instance (no teacher networks).

    Builds train/validation input tensors from the TFRecord files at
    ``self.tf_train_path`` and ``self.tf_eval_path``, constructs the network
    with ``CNNModel.get_model`` on top of the training image tensor,
    optionally loads weights from ``self.weight``, compiles it with the
    loss / target tensors / loss weights produced by this class's helpers,
    and runs ``model.fit``.

    :return: the value returned by ``model.fit`` (the Keras ``History``
        object), so callers can inspect per-epoch losses and metrics.
    """
    tf_record_util = TFRecordUtility(self.output_len)

    # Callbacks and optimizer come from helpers on this class.
    callbacks_list = self._prepare_callback()
    optimizer = self._get_optimizer()

    # Train and validation (images, landmarks) tensors read from TFRecords.
    train_images, train_landmarks = tf_record_util.create_training_tensor_points(
        tfrecord_filename=self.tf_train_path, batch_size=self.BATCH_SIZE)
    validation_images, validation_landmarks = tf_record_util.create_training_tensor_points(
        tfrecord_filename=self.tf_eval_path, batch_size=self.BATCH_SIZE)

    # The model is built directly on the training image tensor.
    cnn = CNNModel()
    model = cnn.get_model(train_images=train_images,
                          arch=self.arch,
                          num_output_layers=self.num_output_layers,
                          output_len=self.output_len,
                          input_tensor=train_images,
                          inp_shape=None)
    if self.weight is not None:
        # Optionally start from previously saved weights.
        model.load_weights(self.weight)

    # Labels are supplied through `target_tensors` (TF1-style Keras compile
    # argument) rather than through feed data.
    model.compile(
        loss=self._generate_loss(),
        optimizer=optimizer,
        metrics=['mse', 'mae'],
        target_tensors=self._generate_target_tensors(train_landmarks),
        loss_weights=self._generate_loss_weights())

    print('< ========== Start Training ============= >')
    history = model.fit(train_images,
                        train_landmarks,
                        epochs=self.EPOCHS,
                        steps_per_epoch=self.STEPS_PER_EPOCH,
                        validation_data=(validation_images, validation_landmarks),
                        validation_steps=self.STEPS_PER_VALIDATION_EPOCH,
                        verbose=1,
                        callbacks=callbacks_list)
    return history
def train(self, teachers_arch, teachers_weight_files, teachers_weight_loss,
          teachers_tf_train_paths, student_weight_file):
    """Train a student network using a loss built from pre-trained teachers.

    Each teacher network is built and loaded from its weight file. The
    student network is then built on the training image tensor and compiled
    with ``Custom_losses.custom_teacher_student_loss``, which receives the
    teacher models. Finally ``student_model.fit`` is run.

    :param teachers_arch: an array containing architecture of teacher networks
    :param teachers_weight_files: an array containing teachers h5 files
    :param teachers_weight_loss: an array containing weight of teachers model in loss function
    :param teachers_tf_train_paths: an array containing path of train tf records
    :param student_weight_file: student h5 weight path, or None to start from
        the freshly built model's weights
    :return: the value returned by ``student_model.fit`` (the Keras
        ``History`` object)
    :raises ValueError: if ``teachers_weight_files`` or
        ``teachers_tf_train_paths`` has fewer entries than ``teachers_arch``
    """
    tf_record_util = TFRecordUtility(self.output_len)
    c_loss = Custom_losses()
    cnn = CNNModel()

    # ---------------------------------------------------------------
    # Prepare teacher models.
    # ---------------------------------------------------------------
    n_teachers = len(teachers_arch)
    if len(teachers_weight_files) < n_teachers or len(teachers_tf_train_paths) < n_teachers:
        raise ValueError(
            'teachers_weight_files and teachers_tf_train_paths must each provide '
            'at least one entry per element of teachers_arch')

    teacher_models = []
    for arch, weight_file, tf_path in zip(teachers_arch, teachers_weight_files,
                                          teachers_tf_train_paths):
        # Only the image tensor is needed to build the teacher graph; the
        # landmarks tensor returned alongside it is not used.
        teacher_images, _ = tf_record_util.create_training_tensor_points(
            tfrecord_filename=tf_path, batch_size=self.BATCH_SIZE)
        model = cnn.get_model(train_images=teacher_images,
                              arch=arch,
                              num_output_layers=1,
                              output_len=self.output_len,
                              input_tensor=None)
        model.load_weights(weight_file)
        teacher_models.append(model)

    # ---------------------------------------------------------------
    # Create student model.
    # ---------------------------------------------------------------
    train_images, train_landmarks = tf_record_util.create_training_tensor_points(
        tfrecord_filename=self.tf_train_path, batch_size=self.BATCH_SIZE)
    validation_images, validation_landmarks = tf_record_util.create_training_tensor_points(
        tfrecord_filename=self.tf_eval_path, batch_size=self.BATCH_SIZE)

    student_model = cnn.get_model(train_images=train_images,
                                  arch=self.arch,
                                  num_output_layers=1,
                                  output_len=self.output_len,
                                  input_tensor=train_images,
                                  inp_shape=None)
    if student_weight_file is not None:
        student_model.load_weights(student_weight_file)

    callbacks_list = self._prepare_callback()
    optimizer = Adam(lr=1e-3, beta_1=0.9, beta_2=0.999, decay=1e-5, amsgrad=False)

    # Landmark -> image map consumed by the teacher/student loss.
    # NOTE(review): pickle.load can execute arbitrary code; this file must
    # come from a trusted, locally generated source.
    with open("map_orig" + self.dataset_name, 'rb') as map_file:
        landmark_img_map = pickle.load(map_file)

    # `bath_size` (sic) is the keyword name this loss factory is called with.
    loss_func = c_loss.custom_teacher_student_loss(
        img_path=self.img_path,
        lnd_img_map=landmark_img_map,
        teacher_models=teacher_models,
        teachers_weight_loss=teachers_weight_loss,
        bath_size=self.BATCH_SIZE,
        num_points=self.output_len,
        ds_name=self.dataset_name,
        loss_type=0)

    student_model.compile(loss=loss_func,
                          optimizer=optimizer,
                          metrics=['mse', 'mae'],
                          target_tensors=train_landmarks)

    print('< ========== Start Training Student============= >')
    history = student_model.fit(
        train_images,
        train_landmarks,
        epochs=self.EPOCHS,
        steps_per_epoch=self.STEPS_PER_EPOCH,
        validation_data=(validation_images, validation_landmarks),
        validation_steps=self.STEPS_PER_VALIDATION_EPOCH,
        verbose=1,
        callbacks=callbacks_list)
    return history