Пример #1
0
    sys.path.append(str(child_dir))

from custom_metrics import ModelMetrics
from model_logging import ModelCheckpointCustom
from print_stats import PrintStats
from tensorboard_callback import TensorboardCallback
from losses import get_loss
from optimizers import get_optimizer

from models.model import get_model
from config import TRAIN_DATA_PATH
from utils.utils import Params, TrainOps, Logging, data_split
from generator_2d import DataGenerator

# Load the hyper-parameter config and point it at the training data location.
params = Params("params.json")
params.data_path = TRAIN_DATA_PATH

# Run-output helper rooted at ./logs.
# NOTE(review): this module-level name shadows the stdlib `logging` module.
logging = Logging("./logs", params)

# Sample ids are the entries of <data_path>/images. os.listdir returns them in
# arbitrary order.
# NOTE(review): presumably data_split shuffles/sorts deterministically. Confirm,
# otherwise the split may differ between machines.
ids = os.listdir(os.path.join(params.data_path, "images"))
train_ids, validation_ids, test_ids = data_split(ids, params)

# Create this run's output directory and record its path on params so later
# consumers (and the saved config) see it.
logging.create_model_directory()
params.model_directory = logging.model_directory

# saving model config file to model output dir
# (presumably serialises params, including model_directory set above; verify in Logging)
logging.save_dict_to_json(logging.model_directory + "/config.json")

# Generators
train_generator = DataGenerator(train_ids, params=params, is_training=True)
validation_generator = DataGenerator(validation_ids,
Пример #2
0
def _split_cv_ids(train_ids, validation_ids, test_ids, cv_iteration,
                  train_fraction=0.75):
    """Hold out one test record and redistribute the remaining test ids.

    ``test_ids[cv_iteration]`` becomes the single held-out record. Every other
    id in ``test_ids`` is shuffled with the global ``random`` state. The first
    ``train_fraction`` share of them is appended to the training ids and the
    rest to the validation ids. The input lists are not modified.

    Returns:
        A tuple ``(held_out_id, train_ids, validation_ids)``.

    Raises:
        IndexError: if ``cv_iteration`` does not index into ``test_ids``.
    """
    try:
        held_out_id = test_ids[cv_iteration]
    except IndexError as err:
        raise IndexError(
            f"cv_iteration={cv_iteration} is out of range for "
            f"{len(test_ids)} test records") from err

    extra_ids = [id_ for id_ in test_ids if id_ != held_out_id]
    random.shuffle(extra_ids)

    split_at = int(len(extra_ids) * train_fraction)
    return (held_out_id,
            train_ids + extra_ids[:split_at],
            validation_ids + extra_ids[split_at:])


def _train_one_epoch(model, loss_fn, optimizer, model_metrics,
                     train_generator):
    """Run one pass of gradient updates over ``train_generator``.

    The "train" metric states are updated batch by batch. The function returns
    the learning rate reported by the optimizer after the last update, or
    ``None`` if the generator yielded no batches.
    """
    current_lr = None
    for x_batch_train, y_batch_train in tqdm(train_generator,
                                             total=len(train_generator)):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss = loss_fn(y_batch_train, logits)

        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # NOTE(review): `_decayed_lr` is a private OptimizerV2 method. It is
        # kept because it resolves both constant and scheduled learning rates.
        current_lr = optimizer._decayed_lr(tf.float32).numpy()
        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write(
            f"Opt Iteration: {optimizer.iterations.numpy()} "
            f"learning rate: {current_lr} loss: {float(loss.numpy()):.2f}")

        # Update training metric.
        model_metrics.update_metric_states(y_batch_train, logits, mode="train")
    return current_lr


def _run_validation(model, loss_fn, model_metrics, validation_generator):
    """Evaluate ``model`` on ``validation_generator`` and update "val" metrics.

    Returns the (unweighted) mean of the per-batch losses, or ``None`` when the
    generator yields no batches.
    """
    batch_losses = []
    for x_batch_val, y_batch_val in validation_generator:
        val_logits = model(x_batch_val, training=False)
        batch_losses.append(float(loss_fn(y_batch_val, val_logits).numpy()))
        model_metrics.update_metric_states(y_batch_val, val_logits, mode="val")
    if not batch_losses:
        return None
    return sum(batch_losses) / len(batch_losses)


def main(flags):
    """Train and validate one leave-one-out cross-validation fold.

    The record at index ``flags.cfs_cv_iteration`` of the test split is held
    out. The remaining test records are shuffled and divided 75/25 between
    the training and validation sets. The model is then trained for
    ``params.num_epochs`` epochs with a custom ``tf.GradientTape`` loop.
    Metric, TensorBoard, checkpoint and stats callbacks all use
    ``<flags.save_model_dir>/<held-out record name without ".png">`` as the
    model directory.

    Args:
        flags: parsed command-line options. Only ``cfs_cv_iteration`` and
            ``save_model_dir`` are read.

    Raises:
        IndexError: if ``flags.cfs_cv_iteration`` is out of range for the
            test split.
        ValueError: if the training generator yields no batches.
    """
    params = Params("params.json")
    params.data_path = TRAIN_DATA_PATH
    params.cv_iteration = flags.cfs_cv_iteration

    # Named run_logger so the stdlib `logging` module is not shadowed.
    run_logger = Logging(flags.save_model_dir, params)

    ids = os.listdir(os.path.join(params.data_path, "images"))
    train_ids, validation_ids, test_ids = data_split(ids, params)

    held_out_id, train_ids, validation_ids = _split_cv_ids(
        train_ids, validation_ids, test_ids, params.cv_iteration)

    # log test id
    params.test_id = held_out_id
    print("Test records is: ", held_out_id)

    print(f"Number of training samples: {len(train_ids)}, "
          f"number of validation samples: {len(validation_ids)}, "
          f"number of test sample: {1}")

    # Strip only a trailing ".png" (the old .replace() removed it anywhere).
    record_name = held_out_id
    if record_name.endswith(".png"):
        record_name = record_name[:-len(".png")]
    run_logger.create_model_directory(
        model_dir=f"{flags.save_model_dir}/{record_name}")
    params.model_directory = run_logger.model_directory

    # saving model config file to model output dir
    run_logger.save_dict_to_json(run_logger.model_directory + "/config.json")

    # Generators
    # NOTE(review): if DataGenerator is a keras.utils.Sequence that reshuffles
    # in on_epoch_end(), only Model.fit calls that hook, not this custom loop.
    # Confirm whether it should be invoked once per epoch here.
    train_generator = DataGenerator(train_ids, params=params, is_training=True)
    validation_generator = DataGenerator(validation_ids,
                                         params=params,
                                         is_training=False)
    if len(train_generator) == 0:
        raise ValueError("Training generator yields no batches; check the "
                         "data split and batch size.")

    trainops = TrainOps(params, num_records=len(train_generator))

    optimizer = get_optimizer(params, trainops)
    loss_fn = get_loss(params)

    model_metrics = ModelMetrics(params)
    tb_callback = TensorboardCallback(model_dir=params.model_directory)
    model_checkpoint = ModelCheckpointCustom(monitor="val_acc",
                                             model_dir=params.model_directory,
                                             mode="max")
    print_stats = PrintStats(params=params)

    # get model
    model = get_model(params)

    for epoch in range(params.num_epochs):
        current_lr = _train_one_epoch(model, loss_fn, optimizer,
                                      model_metrics, train_generator)

        train_result_dict = model_metrics.result_metrics(mode="train")
        tb_callback.on_epoch_end(epoch=epoch,
                                 logging_dict=train_result_dict,
                                 lr=current_lr)

        # Run a validation loop at the end of each epoch.
        val_loss = _run_validation(model, loss_fn, model_metrics,
                                   validation_generator)
        if val_loss is None:
            print("validation loss is: n/a (no validation batches)")
        else:
            print(f"validation loss is: {val_loss:.2f}")

        val_result_dict = model_metrics.result_metrics(mode="val")

        tb_callback.on_epoch_end(epoch=epoch, logging_dict=val_result_dict)
        model_checkpoint.on_epoch_end(epoch,
                                      model,
                                      logging_dict=val_result_dict)
        print_stats.on_epoch_end(epoch,
                                 train_dict=train_result_dict,
                                 validation_dict=val_result_dict,
                                 lr=current_lr)

        # Reset train and val metric states so the next epoch starts fresh.
        model_metrics.reset_metric_states(mode="train")
        model_metrics.reset_metric_states(mode="val")