コード例 #1
0
ファイル: plot_utils.py プロジェクト: lmc00/TFG
def multi_training_plots(timestamps, legend_loc='upper right'):
    """
    Compare the loss and accuracy metrics of one or more timestamped trainings.

    For each timestamp, ``paths.timestamp`` is set to it and the files
    ``stats.json`` (from ``paths.get_stats_dir()``) and ``conf.json`` (from
    ``paths.get_conf_dir()``) are loaded. The curves are drawn on a 2x2 grid:
    training loss, training accuracy, validation loss, validation accuracy.
    Validation curves are only drawn for runs whose configuration has
    ``conf['training']['use_validation']`` set.

    Note that ``paths.timestamp`` is left pointing at the last timestamp
    processed.

    Parameters
    ----------
    timestamps : str, or list of strs
        Timestamp (or list of timestamps) identifying the trainings to plot.
        A single string is treated as a one-element list.
    legend_loc: str
        Legend position
    """
    # A bare string must be wrapped, otherwise the loop below would iterate
    # over its characters. (`timestamps is str` compared against the type
    # object itself and was therefore never true.)
    if isinstance(timestamps, str):
        timestamps = [timestamps]

    fig, axs = plt.subplots(2, 2, figsize=(16, 16))
    axs = axs.flatten()

    for ts in timestamps:

        # Set the timestamp so that the `paths` getters resolve to this run
        paths.timestamp = ts

        # Load training statistics
        stats_path = os.path.join(paths.get_stats_dir(), 'stats.json')
        with open(stats_path) as f:
            stats = json.load(f)

        # Load training configuration
        conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
        with open(conf_path) as f:
            conf = json.load(f)

        # Training
        axs[0].plot(stats['epoch'], stats['loss'], label=ts)
        axs[1].plot(stats['epoch'], stats['acc'], label=ts)

        # Validation
        if conf['training']['use_validation']:
            axs[2].plot(stats['epoch'], stats['val_loss'], label=ts)
            axs[3].plot(stats['epoch'], stats['val_acc'], label=ts)

    # Accuracy plots share a fixed [0, 1] range so runs are comparable
    axs[1].set_ylim([0, 1])
    axs[3].set_ylim([0, 1])

    # Label the x axis of every subplot (previously only axs[0] was labelled)
    for ax in axs:
        ax.set_xlabel('Epochs')

    axs[0].set_title('Training Loss')
    axs[1].set_title('Training Accuracy')
    axs[2].set_title('Validation Loss')
    axs[3].set_title('Validation Accuracy')

    # A single legend is enough: every subplot uses the same labels
    axs[0].legend(loc=legend_loc)
コード例 #2
0
def train_fn(TIMESTAMP, CONF):
    """
    Run a complete training for the given timestamp and configuration and
    save the resulting statistics, configuration and model.

    The function sets ``paths.timestamp`` and ``paths.CONF``, creates the
    directory tree and backs up the splits (via ``utils``), loads the train
    (and optionally validation) split, builds the data generators, creates
    and compiles the model, fits it with ``fit_generator`` and finally writes
    ``stats.json``, the configuration and ``final_model.h5``.

    Parameters
    ----------
    TIMESTAMP : presumably str -- TODO confirm
        Identifier of the training run; assigned to ``paths.timestamp`` and
        stored in the saved stats.
    CONF : dict
        Nested configuration dict. It is mutated in place:
        ``CONF['training']['use_validation']`` is forced to False when no
        validation data is used, ``CONF['model']['preprocess_mode']`` is set,
        ``CONF['training']['batch_size']`` is capped to the number of training
        samples, ``CONF['model']['num_classes']`` is filled in when None, and
        ``CONF['dataset']['mean_RGB']`` / ``['std_RGB']`` are computed when
        ``mean_RGB`` is None.

    Raises
    ------
    AssertionError
        If ``CONF['model']['num_classes']`` is smaller than the largest label
        found in the train (or validation) split.
    """

    # Make the run's timestamp and configuration available to the `paths`
    # module, whose getters are used throughout this function.
    paths.timestamp = TIMESTAMP
    paths.CONF = CONF

    utils.create_dir_tree()
    utils.backup_splits()

    # Load the training data
    X_train, y_train = load_data_splits(
        splits_dir=paths.get_ts_splits_dir(),
        im_dir=paths.get_images_dir(),
        use_location=CONF['training']['use_location'],
        split_name='train')

    # Load the validation data (only if requested AND a 'val.txt' file exists
    # in the splits directory)
    if (CONF['training']['use_validation']) and ('val.txt' in os.listdir(
            paths.get_ts_splits_dir())):
        X_val, y_val = load_data_splits(
            splits_dir=paths.get_ts_splits_dir(),
            im_dir=paths.get_images_dir(),
            use_location=CONF['training']['use_location'],
            split_name='val')
    else:
        print('No validation data.')
        X_val, y_val = None, None
        # Turn the flag off so that every later step (and the saved
        # configuration) consistently skips validation.
        CONF['training']['use_validation'] = False

    # Load the class names
    class_names = load_class_names(splits_dir=paths.get_ts_splits_dir())

    # Update the configuration
    # The preprocessing mode is looked up from the model name.
    CONF['model']['preprocess_mode'] = model_utils.model_modes[CONF['model']
                                                               ['modelname']]
    # The batch size can never exceed the number of training samples.
    CONF['training']['batch_size'] = min(CONF['training']['batch_size'],
                                         len(X_train))

    # Default the number of classes to the number of class names
    if CONF['model']['num_classes'] is None:
        CONF['model']['num_classes'] = len(class_names)

    # Sanity-check the labels against the configured number of classes.
    # NOTE(review): `>=` accepts a maximum label equal to num_classes; if the
    # labels are zero-based this should probably be `>` -- confirm.
    # NOTE(review): asserts are stripped when Python runs with -O.
    assert CONF['model']['num_classes'] >= np.amax(
        y_train
    ), "Your train.txt file has more categories than those defined in classes.txt"
    if CONF['training']['use_validation']:
        assert CONF['model']['num_classes'] >= np.amax(
            y_val
        ), "Your val.txt file has more categories than those defined in classes.txt"

    # Compute the class weights
    if CONF['training']['use_class_weights']:
        class_weights = compute_classweights(
            y_train, max_dim=CONF['model']['num_classes'])
    else:
        class_weights = None

    # Compute the mean and std RGB values
    # (only triggered by a missing mean; a user-supplied mean keeps whatever
    # std_RGB is already in the configuration)
    if CONF['dataset']['mean_RGB'] is None:
        CONF['dataset']['mean_RGB'], CONF['dataset'][
            'std_RGB'] = compute_meanRGB(X_train)

    # Create data generator for train and val sets
    train_gen = data_sequence(X_train,
                              y_train,
                              batch_size=CONF['training']['batch_size'],
                              num_classes=CONF['model']['num_classes'],
                              im_size=CONF['model']['image_size'],
                              mean_RGB=CONF['dataset']['mean_RGB'],
                              std_RGB=CONF['dataset']['std_RGB'],
                              preprocess_mode=CONF['model']['preprocess_mode'],
                              aug_params=CONF['augmentation']['train_mode'])
    # Number of batches per epoch, rounding up so the last partial batch counts
    train_steps = int(np.ceil(len(X_train) / CONF['training']['batch_size']))

    if CONF['training']['use_validation']:
        val_gen = data_sequence(
            X_val,
            y_val,
            batch_size=CONF['training']['batch_size'],
            num_classes=CONF['model']['num_classes'],
            im_size=CONF['model']['image_size'],
            mean_RGB=CONF['dataset']['mean_RGB'],
            std_RGB=CONF['dataset']['std_RGB'],
            preprocess_mode=CONF['model']['preprocess_mode'],
            aug_params=CONF['augmentation']['val_mode'])
        val_steps = int(np.ceil(len(X_val) / CONF['training']['batch_size']))
    else:
        val_gen = None
        val_steps = None

    # Launch the training
    # The timer starts here, so the reported training time includes model
    # creation and compilation as well as fitting.
    t0 = time.time()

    # Create the model and compile it
    model, base_model = model_utils.create_model(CONF)

    # Get a list of the top layer variables that should not be applied a lr_multiplier
    # i.e. the trainable variables of `model` that do not belong to `base_model`.
    # NOTE(review): this is computed before the base layers are frozen below;
    # presumably freezing would change `trainable_variables` -- keep the order.
    base_vars = [var.name for var in base_model.trainable_variables]
    all_vars = [var.name for var in model.trainable_variables]
    top_vars = set(all_vars) - set(base_vars)
    top_vars = list(top_vars)

    # Set trainable layers
    # In 'fast' mode every layer of the base model is frozen.
    if CONF['training']['mode'] == 'fast':
        for layer in base_model.layers:
            layer.trainable = False

    # NOTE(review): `lr_mult` / `excluded_vars` semantics live in `customAdam`,
    # which is not visible here -- presumably a 0.1 learning-rate multiplier
    # for every variable not listed in `excluded_vars`; confirm.
    model.compile(optimizer=customAdam(lr=CONF['training']['initial_lr'],
                                       amsgrad=True,
                                       lr_mult=0.1,
                                       excluded_vars=top_vars),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # When validation is disabled, `val_gen` and `val_steps` are both None.
    history = model.fit_generator(generator=train_gen,
                                  steps_per_epoch=train_steps,
                                  epochs=CONF['training']['epochs'],
                                  class_weight=class_weights,
                                  validation_data=val_gen,
                                  validation_steps=val_steps,
                                  callbacks=utils.get_callbacks(CONF),
                                  verbose=1,
                                  max_queue_size=5,
                                  workers=4,
                                  use_multiprocessing=True,
                                  initial_epoch=0)

    # Saving everything
    print('Saving data to {} folder.'.format(paths.get_timestamped_dir()))
    print('Saving training stats ...')
    stats = {
        'epoch': history.epoch,
        'training time (s)': round(time.time() - t0, 2),
        'timestamp': TIMESTAMP
    }
    # Merge in the per-epoch metrics recorded by the fit history
    stats.update(history.history)
    # Convert the stats before dumping them to JSON (see `json_friendly`)
    stats = json_friendly(stats)
    stats_dir = paths.get_stats_dir()
    with open(os.path.join(stats_dir, 'stats.json'), 'w') as outfile:
        json.dump(stats, outfile, sort_keys=True, indent=4)

    print('Saving the configuration ...')
    model_utils.save_conf(CONF)

    print('Saving the model to h5...')
    fpath = os.path.join(paths.get_checkpoints_dir(), 'final_model.h5')
    # The optimizer state is not stored in the saved file.
    model.save(fpath, include_optimizer=False)

    # Protobuf export is currently disabled:
    # print('Saving the model to protobuf...')
    # fpath = os.path.join(paths.get_checkpoints_dir(), 'final_model.proto')
    # model_utils.save_to_pb(model, fpath)

    print('Finished')