# NOTE: standard imports assumed by this script; the project helpers
# (Helper, create_loss_fn, Yolo_Precision, Yolo_Recall, assert_element_shape,
# write_arguments_to_file, INFO, ERROR and the yolo_mobilev2-style network
# builders) come from the repo's own modules.
from datetime import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import TensorBoard
from tensorflow_model_optimization.python.core.api.sparsity import keras as sparsity


def main(args, train_set, class_num, pre_ckpt, model_def, depth_multiplier,
         is_augmenter, image_size, output_size, batch_size, rand_seed,
         max_nrof_epochs, init_learning_rate, learning_rate_decay_factor,
         obj_weight, noobj_weight, wh_weight, obj_thresh, iou_thresh,
         vaildation_split, log_dir, is_prune, initial_sparsity,
         final_sparsity, end_epoch, frequency):
    # Build paths
    log_dir = (Path(log_dir) /
               datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))  # type: Path
    ckpt_weights = log_dir / 'yolo_weights.h5'
    ckpt = log_dir / 'yolo_model.h5'
    if not log_dir.exists():
        log_dir.mkdir(parents=True)
    write_arguments_to_file(args, str(log_dir / 'args.txt'))

    # Build utils
    h = Helper(f'data/{train_set}_img_ann.npy', class_num,
               f'data/{train_set}_anchor.npy',
               np.reshape(np.array(image_size), (-1, 2)),
               np.reshape(np.array(output_size), (-1, 2)), vaildation_split)
    h.set_dataset(batch_size, rand_seed, is_training=(is_augmenter == 'True'))

    # Build network
    network = eval(model_def)  # type: yolo_mobilev2
    yolo_model, yolo_model_warpper = network(
        [image_size[0], image_size[1], 3], len(h.anchors[0]), class_num,
        alpha=depth_multiplier)

    if pre_ckpt is not None and pre_ckpt not in ('None', ''):
        if 'h5' in pre_ckpt:
            yolo_model_warpper.load_weights(str(pre_ckpt))
            print(INFO, f' Load CKPT {str(pre_ckpt)}')
        else:
            print(ERROR, ' Pre CKPT path is invalid')

    # Prune the model
    pruning_params = {
        'pruning_schedule':
            sparsity.PolynomialDecay(initial_sparsity=initial_sparsity,
                                     final_sparsity=final_sparsity,
                                     begin_step=0,
                                     end_step=h.train_epoch_step * end_epoch,
                                     frequency=frequency)
    }

    if is_prune == 'True':
        train_model = sparsity.prune_low_magnitude(yolo_model_warpper,
                                                   **pruning_params)
    else:
        train_model = yolo_model_warpper

    train_model.compile(
        keras.optimizers.Adam(lr=init_learning_rate,
                              decay=learning_rate_decay_factor),
        loss=[
            create_loss_fn(h, obj_thresh, iou_thresh, obj_weight,
                           noobj_weight, wh_weight, layer)
            for layer in range(len(train_model.output)
                               if isinstance(train_model.output, list) else 1)
        ],
        metrics=[
            Yolo_Precision(obj_thresh, name='p'),
            Yolo_Recall(obj_thresh, name='r')
        ])

    """ NOTE: fix the dataset output shape """
    shapes = (train_model.input.shape, tuple(h.output_shapes))
    h.train_dataset = h.train_dataset.apply(assert_element_shape(shapes))
    h.test_dataset = h.test_dataset.apply(assert_element_shape(shapes))

    """ Callbacks """
    if is_prune == 'True':
        cbs = [
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=str(log_dir), profile_batch=0)
        ]
    else:
        cbs = [TensorBoard(str(log_dir), update_freq='batch', profile_batch=3)]

    # Training
    try:
        train_model.fit(h.train_dataset,
                        epochs=max_nrof_epochs,
                        steps_per_epoch=h.train_epoch_step,
                        callbacks=cbs,
                        validation_data=h.test_dataset,
                        validation_steps=int(h.test_epoch_step *
                                             h.validation_split))
    except KeyboardInterrupt:
        pass

    if is_prune == 'True':
        # Save the stripped model so the pruning wrappers are removed from
        # the checkpoint and the trained sparse weights are kept.
        final_model = sparsity.strip_pruning(train_model)
        prune_ckpt = log_dir / 'yolo_prune_model.h5'
        keras.models.save_model(final_model,
                                str(prune_ckpt),
                                include_optimizer=False)
        print()
        print(INFO, f' Save Pruned Model as {str(prune_ckpt)}')
    else:
        keras.models.save_model(yolo_model, str(ckpt))
        print()
        print(INFO, f' Save Model as {str(ckpt)}')
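# A minimal standalone sketch (not called by the training code above, step
# values illustrative only) showing how the tfmot PolynomialDecay schedule
# used in `main` ramps the target sparsity between begin_step and end_step.
def _sketch_pruning_schedule():
    schedule = sparsity.PolynomialDecay(initial_sparsity=0.5,
                                        final_sparsity=0.9,
                                        begin_step=0,
                                        end_step=1000,
                                        frequency=100)
    for step in (0, 250, 500, 750, 1000):
        should_prune, target = schedule(tf.constant(step))
        # `should_prune` is True only on steps aligned with `frequency`;
        # `target` rises from 0.5 toward 0.9 along a cubic curve.
        print(step, bool(should_prune), round(float(target), 3))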
# NOTE: standard imports assumed by this script; the project helpers
# (Dataset, DATASET_MODE, MAPCallback, YoloLoss, ModelFactory, BACKBONE, OPT,
# get_classes, get_anchors and the *_yolo_body builders) come from the
# repo's own modules.
import datetime
import os
import zipfile

import numpy as np
import tensorflow as tf


def train(FLAGS):
    """Train yolov3 with different backbones."""
    prune = FLAGS['prune']
    opt = FLAGS['opt']
    backbone = FLAGS['backbone']
    log_dir = FLAGS['log_directory'] or os.path.join(
        'logs',
        str(backbone).split('.')[1].lower() + str(datetime.date.today()))
    if not tf.io.gfile.exists(log_dir):
        tf.io.gfile.mkdir(log_dir)
    batch_size = FLAGS['batch_size']
    train_dataset_glob = FLAGS['train_dataset']
    val_dataset_glob = FLAGS['val_dataset']
    test_dataset_glob = FLAGS['test_dataset']
    freeze = FLAGS['freeze']
    freeze_step = FLAGS['epochs'][0]
    train_step = FLAGS['epochs'][1]

    if opt == OPT.DEBUG:
        tf.config.experimental_run_functions_eagerly(True)
        tf.debugging.set_log_device_placement(True)
        tf.get_logger().setLevel('DEBUG')
    elif opt == OPT.XLA:
        config = tf.ConfigProto()
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        sess = tf.Session(config=config)
        tf.keras.backend.set_session(sess)

    class_names = get_classes(FLAGS['classes_path'])
    num_classes = len(class_names)
    anchors = get_anchors(FLAGS['anchors_path'])
    input_shape = FLAGS['input_size']  # multiple of 32, hw
    model_path = FLAGS['model']
    if model_path and not model_path.endswith('.h5'):
        model_path = tf.train.latest_checkpoint(model_path)
    lr = FLAGS['learning_rate']
    tpu_address = FLAGS['tpu_address']
    if tpu_address is not None:
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_address)
        tf.config.experimental_connect_to_host(cluster_resolver.master())
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        strategy = tf.distribute.MirroredStrategy(devices=FLAGS['gpus'])
    batch_size = batch_size * strategy.num_replicas_in_sync

    train_dataset_builder = Dataset(train_dataset_glob, batch_size, anchors,
                                    num_classes, input_shape)
    train_dataset, train_num = train_dataset_builder.build()
    val_dataset_builder = Dataset(val_dataset_glob,
                                  batch_size,
                                  anchors,
                                  num_classes,
                                  input_shape,
                                  mode=DATASET_MODE.VALIDATE)
    val_dataset, val_num = val_dataset_builder.build()

    map_callback = MAPCallback(test_dataset_glob, input_shape, anchors,
                               class_names)
    logging = tf.keras.callbacks.TensorBoard(write_graph=False,
                                             log_dir=log_dir,
                                             write_images=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(
        log_dir, 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'),
                                                    monitor='val_loss',
                                                    save_weights_only=False,
                                                    save_best_only=False,
                                                    period=1)
    cos_lr = tf.keras.callbacks.LearningRateScheduler(
        lambda epoch, _: tf.keras.experimental.CosineDecay(
            lr[1], train_step)(epoch - freeze_step).numpy(), 1)
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=(freeze_step + train_step) // 10,
        verbose=0)

    if tf.version.VERSION.startswith('1.'):
        loss = [
            lambda y_true, yolo_output: YoloLoss(
                y_true, yolo_output, 0, anchors, print_loss=True)
        ]
    else:
        loss = [
            YoloLoss(idx, anchors, print_loss=False)
            for idx in range(len(anchors) // 3)
        ]

    with strategy.scope():
        # factory = ModelFactory(tf.keras.layers.Input(shape=(*input_shape, 3)),
        #                        weights_path=model_path)
        factory = ModelFactory(tf.keras.layers.Input(shape=(*input_shape, 3)))
        if backbone == BACKBONE.MOBILENETV2:
            model = factory.build(mobilenetv2_yolo_body,
                                  20,
                                  len(anchors) // 3,  # anchors per output scale
                                  num_classes,
                                  alpha=1.0)
        elif backbone == BACKBONE.DARKNET53:
            model = factory.build(darknet_yolo_body, 185,
                                  len(anchors) // 3, num_classes)
        elif backbone == BACKBONE.EFFICIENTNET:
            FLAGS['model_name'] = 'efficientnet-b4'
            model = factory.build(efficientnet_yolo_body,
                                  20,  # todo
                                  FLAGS['model_name'],
                                  len(anchors) // 3,  # anchors per output scale
                                  batch_norm_momentum=0.9,
                                  batch_norm_epsilon=1e-3,
                                  num_classes=num_classes,
                                  drop_connect_rate=0.2,
                                  data_format="channels_first")
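    # A minimal sketch (illustrative values, kept as a comment so it does not
    # run during training) of the schedule behind the `cos_lr` callback
    # defined above: once the frozen stage ends, the learning rate decays
    # from lr[1] to 0 along a cosine curve over `train_step` epochs.
    #
    #   decay = tf.keras.experimental.CosineDecay(1e-4, 40)
    #   for epoch in (10, 20, 30, 50):          # freeze_step = 10
    #       print(epoch, float(decay(epoch - 10)))
    #   # -> 1e-04, ~8.5e-05, 5e-05, 0.0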
    if prune:
        from tensorflow_model_optimization.python.core.api.sparsity import keras as sparsity
        end_step = np.ceil(1.0 * train_num / batch_size).astype(
            np.int32) * train_step
        new_pruning_params = {
            'pruning_schedule':
                sparsity.PolynomialDecay(initial_sparsity=0.5,
                                         final_sparsity=0.9,
                                         begin_step=0,
                                         end_step=end_step,
                                         frequency=1000)
        }
        pruned_model = sparsity.prune_low_magnitude(model, **new_pruning_params)
        pruned_model.compile(optimizer=tf.keras.optimizers.Adam(lr[0],
                                                                epsilon=1e-8),
                             loss=loss)
        pruned_model.fit(train_dataset,
                         epochs=train_step,
                         initial_epoch=0,
                         steps_per_epoch=max(1, train_num // batch_size),
                         callbacks=[
                             checkpoint, cos_lr, logging, map_callback,
                             early_stopping
                         ],
                         validation_data=val_dataset,
                         validation_steps=max(1, val_num // batch_size))
        model = sparsity.strip_pruning(pruned_model)
        pruned_weights_path = os.path.join(
            log_dir,
            str(backbone).split('.')[1].lower() +
            '_trained_weights_pruned.h5')
        model.save_weights(pruned_weights_path)
        with zipfile.ZipFile(pruned_weights_path + '.zip',
                             'w',
                             compression=zipfile.ZIP_DEFLATED) as f:
            f.write(pruned_weights_path)
        return

    # Train with frozen layers first, to get a stable loss.
    # Adjust the number of epochs to your dataset; this step is enough to
    # obtain a reasonable model.
    if freeze:
        with strategy.scope():
            model.compile(optimizer=tf.keras.optimizers.Adam(lr[0],
                                                             epsilon=1e-8),
                          loss=loss)
        model.fit(train_dataset,
                  epochs=freeze_step,
                  initial_epoch=0,
                  steps_per_epoch=max(1, train_num // batch_size),
                  callbacks=[logging, checkpoint],
                  validation_data=val_dataset,
                  validation_steps=max(1, val_num // batch_size))
        model.save_weights(
            os.path.join(
                log_dir,
                str(backbone).split('.')[1].lower() +
                '_trained_weights_stage_1.h5'))
    # Unfreeze and continue training to fine-tune.
    # Train longer if the result is not good.
    else:
        for i in range(len(model.layers)):
            model.layers[i].trainable = True
        with strategy.scope():
            model.compile(optimizer=tf.keras.optimizers.Adam(lr[1],
                                                             epsilon=1e-8),
                          loss=loss)  # recompile to apply the change
        print('Unfreeze all of the layers.')
        model.fit(
            train_dataset,
            epochs=train_step + freeze_step,
            initial_epoch=freeze_step,
            steps_per_epoch=max(1, train_num // batch_size),
            callbacks=[
                checkpoint, cos_lr, logging, early_stopping
                # map_callback
            ],
            validation_data=val_dataset,
            validation_steps=max(1, val_num // batch_size))
        model.save_weights(
            os.path.join(
                log_dir,
                str(backbone).split('.')[1].lower() +
                '_trained_weights_final.h5'))
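
# A minimal sketch for sanity-checking a pruned checkpoint produced by the
# prune branch above: strip_pruning leaves the pruned weights as exact zeros,
# so counting them verifies that the schedule reached its final_sparsity.
def weight_sparsity(model):
    """Fraction of weights in `model` that are exactly zero."""
    weights = model.get_weights()
    zeros = sum(int(np.sum(w == 0)) for w in weights)
    total = sum(w.size for w in weights)
    return zeros / total

# Example (hypothetical usage after the pruning branch):
#   model = sparsity.strip_pruning(pruned_model)
#   print(f'achieved sparsity: {weight_sparsity(model):.2%}')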