def train_with_data_generator(dataset_root_dir=GAN_DATA_ROOT_DIR, weights_file=None):
    """Train the O-Net from its shuffled HDF5 shards and save the final weights.

    Args:
        dataset_root_dir: Root data directory; the O-Net shards are read from
            its ``o_net`` subdirectory.
        weights_file: Optional weights path forwarded unchanged to
            ``train_o_net_with_data_generator``.
    """
    net_name = 'o_net'
    # Hyper-parameters come from module-level configuration constants.
    batch_size, epochs, learning_rate = BATCH_SIZE, ONET_EPOCHS, ONET_LEARNING_RATE

    net_dir = os.path.join(dataset_root_dir, net_name)
    # Positional order matters: DataGenerator takes pos, neg, part, landmarks.
    shard_files = ('pos_shuffle.h5', 'neg_shuffle.h5', 'part_shuffle.h5', 'landmarks_shuffle.h5')
    shard_paths = [os.path.join(net_dir, shard) for shard in shard_files]

    generator = DataGenerator(*shard_paths, batch_size, im_size=NET_SIZE['o_net'])
    batches = generator.generate()
    n_steps = generator.steps_per_epoch()

    callbacks, model_file = create_callbacks_model_file(net_name, epochs)
    trained = train_o_net_with_data_generator(
        batches,
        n_steps,
        initial_epoch=0,
        epochs=epochs,
        lr=learning_rate,
        callbacks=callbacks,
        weights_file=weights_file,
    )
    trained.save_weights(model_file)
def train_with_data_generator(dataset_root_dir=GAN_DATA_ROOT_DIR, model_file=model_file, weights_file=None):
    """Train the network selected by the module-level ``net_name`` and save its weights.

    ``net_name`` picks both the shuffled HDF5 shards
    (``<category>_shuffle_<net_name>.h5``) and the model class: ``'Pnet'`` ->
    ``Pnet``, ``'Rnet'`` -> ``Rnet``, anything else -> ``Onet``.

    Args:
        dataset_root_dir: Directory that holds the shuffled shard files.
        model_file: Destination path for the trained weights. The default is
            the module-level ``model_file``, captured at definition time.
        weights_file: Optional weights to load before training starts.
    """
    batch_size = 64 * 7
    epochs = 30
    learning_rate = 0.001

    # The generator's image size must match the selected network. Using 12
    # (the P-Net size) for every network would feed R-Net/O-Net the wrong
    # resolution. The mapping mirrors the model dispatch below.
    # NOTE(review): 24/48 are the standard MTCNN R-Net/O-Net input sizes;
    # confirm against the Rnet/Onet definitions.
    im_size = {'Pnet': 12, 'Rnet': 24}.get(net_name, 48)

    pos_dataset_path = os.path.join(dataset_root_dir, 'pos_shuffle_%s.h5' % (net_name))
    neg_dataset_path = os.path.join(dataset_root_dir, 'neg_shuffle_%s.h5' % (net_name))
    part_dataset_path = os.path.join(dataset_root_dir, 'part_shuffle_%s.h5' % (net_name))
    landmarks_dataset_path = os.path.join(dataset_root_dir, 'landmarks_shuffle_%s.h5' % (net_name))

    data_generator = DataGenerator(pos_dataset_path, neg_dataset_path, part_dataset_path,
                                   landmarks_dataset_path, batch_size, im_size=im_size)
    data_gen = data_generator.generate()
    steps_per_epoch = data_generator.steps_per_epoch()

    if net_name == 'Pnet':
        _net = Pnet()
    elif net_name == 'Rnet':
        _net = Rnet()
    else:
        _net = Onet()

    _net_model = _net.model(training=True)
    _net_model.summary()
    if weights_file is not None:
        # Warm-start from previously saved weights.
        _net_model.load_weights(weights_file)

    _net_model.compile(Adam(lr=learning_rate), loss=_net.loss_func,
                       metrics=[_net.accuracy, _net.recall])
    _net_model.fit_generator(data_gen, steps_per_epoch=steps_per_epoch,
                             initial_epoch=0, epochs=epochs)
    _net_model.save_weights(model_file)