Пример #1
0
def subset_selection_test():
    """Print a time-of-day x scene count table over the training partition.

    Each cell counts the training ids whose (timeofday, scene) attribute pair,
    taken from the BDD100K train attribute CSV, falls in that row/column.
    Rows are time-of-day labels; columns are scene labels.
    """
    labels_path = '../../bdd100k/classification/labels/'
    bdd100k_labels_path = "../../bdd100k/labels/"
    val_json = '../../bdd100k/labels/bdd100k_labels_images_val.json'
    attr_tr_file = bdd100k_labels_path + 'bdd100k_labels_images_train_attributes.csv'
    train_labels = '../../bdd100k/classification/labels/train_ground_truth.csv'

    class_map_file = bu.class_mapping(input_json=val_json,
                                      output_csv=labels_path + 'class_mapping.csv')

    # Dataset for analysis (only the ids are needed here)
    tr_partition, _ = bu.get_ids_labels(train_labels, class_map_file)

    # Attribute maps plus the data_key -> attr_key translation function
    _, s, tod, wst_dk2ak = bu.wst_attribute_mapping(attr_tr_file)

    d_tod = analyse.DiscreteAttribute(tod)
    d_s = analyse.DiscreteAttribute(s)
    tod_labels = d_tod.get_labels()
    scene_labels = d_s.get_labels()

    scene_tod_distrib = np.zeros((len(tod_labels), len(scene_labels)))
    for data_key in tr_partition:
        attr_key = wst_dk2ak(data_key)
        x = d_tod.index_of(d_tod.labelof(attr_key))
        y = d_s.index_of(d_s.labelof(attr_key))
        scene_tod_distrib[x][y] += 1

    # Header row of scene labels, then one row of counts per time of day
    print("        " + " / ".join(scene_labels))
    for tod_label, row in zip(tod_labels, scene_tod_distrib):
        print(tod_label + "\t" + " ".join([str(val) for val in row]))
Пример #2
0
def retinanet_train(tiny=False):
    """Train RetinaNet on the full BDD100K split for every backbone in ``models``.

    Val and train annotation CSVs are generated with ``bu.annotate``, and a
    class-mapping CSV is written from the classes returned for the val split.
    Then, for each backbone name in the module-level ``models``, ImageNet
    weights are downloaded and the model is trained for 2 epochs.

    :param tiny: not used by this function; kept so existing callers keep
        working (``retinanet_tiny_train`` is the reduced-data variant).
    """
    labels_path = bdd100k_labels_path
    val_json = labels_path + 'bdd100k_labels_images_val.json'
    train_json = labels_path + 'bdd100k_labels_images_train.json'
    val_annot = labels_path + 'val_annotations.csv'
    train_annot = labels_path + 'train_annotations.csv'

    num_data = 70000

    classes = bu.annotate(val_json, val_annot, labels_path, bdd100k_val_path)
    cl_map_path = bu.class_mapping(classes,
                                   output_csv=labels_path +
                                   'class_mapping.csv')
    bu.annotate(train_json, train_annot, bdd100k_labels_path,
                bdd100k_train_path)

    # Hyper-parameters
    batch_size = 1
    # float() forces true division: under Python 2 ``int / int`` floors
    # before np.ceil runs, which would drop the last partial batch. Keras
    # documents steps_per_epoch as an integer, hence int().
    steps_per_epoch = int(np.ceil(num_data / float(batch_size)))

    for m in models:
        print('Generating %s backbone...' % m)
        backbone = kr_models.backbone(m)
        weights = backbone.download_imagenet()
        print('Creating generators...')
        tr_gen, val_gen = mt.create_generators(
            train_annotations=train_annot,
            val_annotations=val_annot,
            class_mapping=cl_map_path,
            base_dir='',
            preprocess_image=backbone.preprocess_image,
            batch_size=batch_size)
        print('Creating models...')
        model, training_model, _ = kr_train.create_models(
            backbone.retinanet, tr_gen.num_classes(), weights)
        print('Creating callbacks...')
        callbacks = mt.create_callbacks(model,
                                        batch_size,
                                        'test',
                                        tensorboard_dir=log_path)

        print('Training...')
        training_model.fit_generator(
            generator=tr_gen,
            steps_per_epoch=steps_per_epoch,
            epochs=2,
            verbose=1,
            callbacks=callbacks,
            workers=4,
            use_multiprocessing=True,
            max_queue_size=10,
            validation_data=val_gen)
Пример #3
0
def retinanet_tiny_train():
    """Train RetinaNet on a reduced BDD100K subset for every backbone in ``models``.

    ``bu.annotate_tiny`` builds train/val annotation CSVs from the val JSON
    (overwriting existing ones), and a class-mapping CSV is written. Each
    backbone is then trained for 50 epochs, with snapshots going to
    ``retinanet_h5_path`` and TensorBoard logs to ``log_path``.
    """
    labels_path = bdd100k_labels_path
    val_json = labels_path + 'bdd100k_labels_images_val.json'

    num_data = 7000
    batch_size = 1
    # float() forces true division: under Python 2 ``int / int`` floors
    # before np.ceil runs, which would drop the last partial batch. Keras
    # documents steps_per_epoch as an integer, hence int().
    steps_per_epoch = int(np.ceil(num_data / float(batch_size)))

    train_annot, val_annot = bu.annotate_tiny(val_json,
                                              labels_path,
                                              bdd100k_val_path,
                                              overwrite=True)
    cl_map_path = bu.class_mapping(input_json=val_json,
                                   output_csv=labels_path +
                                   'class_mapping.csv')

    for m in models:
        print('Generating %s backbone...' % m)
        backbone = kr_models.backbone(m)
        weights = backbone.download_imagenet()
        print('Creating generators...')
        tr_gen, val_gen = bu.create_generators(
            train_annotations=train_annot,
            val_annotations=val_annot,
            class_mapping=cl_map_path,
            base_dir='',
            preprocess_image=backbone.preprocess_image,
            batch_size=batch_size)
        print('Creating models...')
        model, training_model, _ = kr_train.create_models(
            backbone.retinanet, tr_gen.num_classes(), weights)
        print('Creating callbacks...')
        callbacks = bu.create_callbacks(model,
                                        batch_size,
                                        snapshots_path=retinanet_h5_path,
                                        tensorboard_dir=log_path,
                                        backbone=m,
                                        dataset_type='bdd10k')

        print('Training...')
        training_model.fit_generator(
            generator=tr_gen,
            steps_per_epoch=steps_per_epoch,
            epochs=50,
            verbose=1,
            callbacks=callbacks,
            workers=1,
            use_multiprocessing=False,
            max_queue_size=10,
            validation_data=val_gen)
Пример #4
0
def bdd100k_analysis(model_file, do_plot_boxes=False):
    """Analyse ``model_file`` on the BDD100K validation set, per attribute.

    The attribute descriptors (weather, scene, timeofday, box_size) are
    filled in by ``bdd100k_model_analysis``. When ``do_plot_boxes`` is true,
    the 'score' metric is plotted per attribute.

    :returns: the dict of attribute descriptors, keyed by attribute name.
    """
    class_map_file = bu.class_mapping(input_json=val_json, output_csv=labels_path + 'class_mapping.csv')

    # Only the id -> label lookup is needed; the partition list is unused
    _, val_labels = bu.get_ids_labels(val_labels_csv, class_map_file)

    # Attribute value maps, each with its data_key -> attr_key function
    weather, scene, timeofday, wst_dk2ak = bu.wst_attribute_mapping(attr_val_file)
    box_size, box_size_dk2ak = bu.box_size_attribute_mapping(box_val_file, box_val_json)

    # (name, value map, dk2ak), in the order the descriptors are inserted
    attribute_sources = [
        ('weather', weather, wst_dk2ak),
        ('scene', scene, wst_dk2ak),
        ('timeofday', timeofday, wst_dk2ak),
        ('box_size', box_size, box_size_dk2ak),
    ]
    attributes = {}
    for attr_name, attr_map, dk2ak in attribute_sources:
        # Fresh lists per entry so descriptors never share mutable state
        attributes[attr_name] = {'name': attr_name,
                                 'map': attr_map,
                                 'dk2ak': dk2ak,
                                 'd_attribute': None,
                                 'metrics': ['score', 'acc'],
                                 'weaks': [[], []]}

    bdd100k_model_analysis(model_file, attributes, val_labels)

    if do_plot_boxes:
        plotting.plot_discrete_attribute_scores(attributes, 'score', model_file)

    return attributes
Пример #5
0
def bdd100k_compare(model_file, ref_file, attribute, metric):
    """Box-plot ``metric`` per value of ``attribute`` for two models.

    Both ``model_file`` and ``ref_file`` are analysed on the BDD100K
    validation set. For each label of ``attribute``, the metric values of the
    two models are compared side by side. Series names carry each label's
    share of the attribute distribution.
    """
    class_map_file = bu.class_mapping(input_json=val_json, output_csv=labels_path + 'class_mapping.csv')

    # Only the id -> label lookup is needed here
    _, val_labels = bu.get_ids_labels(val_labels_csv, class_map_file)

    # Attribute mapping and data_key to attr_key function (dk2ak)
    weather, scene, timeofday, wst_dk2ak = bu.wst_attribute_mapping(attr_val_file)
    box_size, box_size_dk2ak = bu.box_size_attribute_mapping(box_val_file, box_val_json)

    attributes = {}
    for attr_name, attr_map, dk2ak in [('weather', weather, wst_dk2ak),
                                       ('scene', scene, wst_dk2ak),
                                       ('timeofday', timeofday, wst_dk2ak),
                                       ('box_size', box_size, box_size_dk2ak)]:
        attributes[attr_name] = {'name': attr_name,
                                 'map': attr_map,
                                 'dk2ak': dk2ak,
                                 'd_attribute': None,
                                 'metrics': ['score', 'acc'],
                                 'weaks': [[], []]}

    bdd100k_model_analysis(model_file, attributes, val_labels)
    d_attr = attributes[attribute]['d_attribute']
    labels = d_attr.get_labels()
    distrib = [int(v) for v in d_attr.get_distribution()]
    series = [d_attr.get_metric_value_list(metric, label) for label in labels]

    # float() is required: with ints, Python 2 '/' floors every share to 0.
    total = float(sum(distrib))
    series_names = ["%s (%1.2f)" % (label, count / total if total else 0.0)
                    for label, count in zip(labels, distrib)]
    d_attr.flush()

    bdd100k_model_analysis(ref_file, attributes, val_labels)
    # Re-read the descriptor: the analysis may have replaced it
    d_attr_ref = attributes[attribute]['d_attribute']
    series_ref = [d_attr_ref.get_metric_value_list(metric, label) for label in labels]

    plotting.n_box_plot_compare(series, series_ref, series_names, metric)
Пример #6
0
def bdd100k_sel_partition_test():
    """Print how many ids ``analyse.select_ft_data`` keeps from the fine-tune slice.

    The fine-tune slice is entries [500000:1000000] of the training id list.
    The selection is driven by the DenseNet121 checkpoint named below.
    """
    classif_labels_dir = '../../bdd100k/classification/labels/'
    train_gt_csv = '../../bdd100k/classification/labels/train_ground_truth.csv'
    val_json_path = '../../bdd100k/labels/bdd100k_labels_images_val.json'
    selector_model = 'densenet121_bdd100k_cl0-500k_20ep_woda_ep20_vl0.22.hdf5'

    mapping_csv = bu.class_mapping(input_json=val_json_path,
                                   output_csv=classif_labels_dir + 'class_mapping.csv')

    train_ids = bu.get_ids_labels(train_gt_csv, mapping_csv)[0]
    candidates = train_ids[500000:1000000]
    selected = analyse.select_ft_data(selector_model, candidates)

    print('selection res=', len(selected))
Пример #7
0
def _color_cube_from_images(img_paths, resolution=4):
    """Accumulate a ColorDensityCube over the images found at ``img_paths``."""
    cube = metrics_color.ColorDensityCube(resolution=resolution)
    for path in img_paths:
        img = cv2.imread(path)
        # cv2.imread returns None (it does not raise) for unreadable files
        if img is None:
            print('Warning: could not read image', path)
            continue
        cube.feed(img)
    return cube


def bdd100k_cc_analysis(n_images=20000):
    """Compare color-density cubes of best- vs worst-scored validation images.

    For each model, prediction scores are read from
    ``<model>_predictions.csv`` under ``csv_path``. Color cubes are built
    from the ``n_images`` top- and bottom-scored images, then plotted along
    with their difference; the difference plot is saved.

    :param n_images: number of images per series (top and bottom).
    """
    model_files = [
        'densenet121_bdd100k_cl0-500k_20ep_woda_ep20_vl0.22.hdf5',
        # 'resnet50_bdd100k_cl0-500k_20ep_woda_ep13_vl0.27.hdf5',
        # 'mobilenet_bdd100k_cl0-500k_20ep_woda_ep15_vl0.24.hdf5',
        # 'mobilenetv2_bdd100k_cl0-500k_20ep_woda_ep17_vl0.22.hdf5',
        # 'nasnet_bdd100k_cl0-500k_20ep_woda_ep17_vl0.24.hdf5',
    ]

    class_map_file = bu.class_mapping(input_json=val_json, output_csv=labels_path + 'class_mapping.csv')

    # Dataset for analysis (only the id -> label lookup is used)
    _, val_labels = bu.get_ids_labels(val_labels_csv, class_map_file)

    for m in model_files:
        # test_subset creation
        pr_file = '.'.join(m.split('.')[:-1]) + '_predictions.csv'
        _, y_scores, img_ids = dt.get_scores_from_file(csv_path + pr_file, val_labels)
        top_n_args, bot_n_args = dt.get_topbot_n_args(n_images, y_scores)

        cc_high = _color_cube_from_images(img_ids[arg] for arg in top_n_args)
        print('high sum', np.sum(cc_high.get_cube().flatten()))
        cc_high.normalize()
        cc_high.plot_cube()

        cc_low = _color_cube_from_images(img_ids[arg] for arg in bot_n_args)
        print('low sum', np.sum(cc_low.get_cube().flatten()))
        cc_low.normalize()
        cc_low.plot_cube()

        cc_diff = cc_high.substract(cc_low, 'value')
        print('diff sum', np.sum(cc_diff.get_cube().flatten()))
        print('diff mean', np.mean(cc_diff.get_cube().flatten()))
        cc_diff.normalize()
        cc_diff.plot_cube(title='Color cube analysis difference (' + str(n_images) + ' images/series)',
                          normalize=True,
                          save=True)
Пример #8
0
def bdd100k_local_finetune_test(model_files):
    """Run reference and attribute-selective fine-tunings for each model.

    For every model, two reference fine-tunings (tags ``ref2``/``ref3``) run
    for 30 epochs on the first 300000 training ids, validated on the first
    100000 validation ids.

    Models whose file name does not contain ``densenet121`` are also
    fine-tuned for 15 epochs on four subsets. Each subset is picked by
    ``analyse.select_ft_data`` from training ids [500000:1000000] (daytime,
    night, highway, city street) and split 80/20 into train/validation. Each
    subset is run twice (tags ``<name>2``/``<name>3``).

    :param model_files: checkpoint file names, relative to ``h5_path``.
    """
    labels_path = '../../bdd100k/classification/labels/'
    train_labels_csv = '../../bdd100k/classification/labels/train_ground_truth.csv'
    val_labels_csv = '../../bdd100k/classification/labels/val_ground_truth.csv'
    val_json = '../../bdd100k/labels/bdd100k_labels_images_val.json'

    # Parameters
    params = {
        'dim': (64, 64, 3),
        'batch_size': 32,
        'n_classes': 10,
        'shuffle': True
    }
    ref_epochs = 30
    # Separate variable: the selective runs used to overwrite ``epochs`` with
    # 15, so every model after the first non-DenseNet one had its reference
    # fine-tuning silently cut from 30 to 15 epochs.
    sel_epochs = 15
    # (attribute, value, tag prefix, extra select_ft_data kwargs)
    selections = [
        ('timeofday', 'daytime', 'daytime', {'do_plot_boxes': False}),
        ('timeofday', 'night', 'night', {}),
        ('scene', 'highway', 'highway', {}),
        ('scene', 'city street', 'city_street', {}),
    ]

    class_map_file = bu.class_mapping(input_json=val_json,
                                      output_csv=labels_path +
                                      'class_mapping.csv')

    # Datasets
    tr_partition, tr_labels = bu.get_ids_labels(train_labels_csv,
                                                class_map_file)
    val_partition, val_labels = bu.get_ids_labels(val_labels_csv,
                                                  class_map_file)
    ft_partition = tr_partition[500000:1000000]

    def finetune_twice(model_path, train_gen, val_gen, epochs, tag):
        """Run ``mt.ft`` twice, tagged ``tag + '2'`` then ``tag + '3'``."""
        for run_suffix in ('2', '3'):
            mt.ft(model_path,
                  train_gen,
                  val_gen,
                  epochs,
                  save_history=True,
                  tag=tag + run_suffix)

    for model_file in model_files:
        model_path = h5_path + model_file

        ref_generator = mt.DataGenerator(tr_partition[:300000], tr_labels,
                                         **params)
        val_generator = mt.DataGenerator(val_partition[:100000],
                                         val_labels, **params)
        finetune_twice(model_path, ref_generator, val_generator,
                       ref_epochs, 'ref')

        # DenseNet121 models only get the reference fine-tunings
        if 'densenet121' in model_file:
            continue

        for attr_name, attr_value, tag, sel_kwargs in selections:
            sel_partition = analyse.select_ft_data(model_file,
                                                   ft_partition,
                                                   attr_name,
                                                   attr_value,
                                                   **sel_kwargs)
            sp = 4 * len(sel_partition) // 5  # 80/20 split point
            ft_generator = mt.DataGenerator(sel_partition[:sp],
                                            tr_labels, **params)
            sel_val_generator = mt.DataGenerator(sel_partition[sp:],
                                                 tr_labels, **params)
            finetune_twice(model_path, ft_generator, sel_val_generator,
                           sel_epochs, tag)
Пример #9
0
def bdd100k_global_finetune_test(model_files):
    """Fine-tune each model on selected data and on a reference slice.

    For every model, ``analyse.select_ft_data`` picks ids from training
    entries [500000:1000000]. The model is then fine-tuned twice from its
    saved checkpoint, each run for 30 epochs and validated on the first
    100000 validation ids:

    * on up to 300000 selected ids, with checkpoints ``<model>_ftep..`` and
      history ``<model>_ft_hist.pkl``;
    * on the whole reference slice, with checkpoints ``<model>_refep..`` and
      history ``<model>_ref_hist.pkl``.

    :param model_files: checkpoint file names, relative to ``h5_path``.
    """
    import os

    labels_path = '../../bdd100k/classification/labels/'
    train_labels = '../../bdd100k/classification/labels/train_ground_truth.csv'
    val_labels_csv = '../../bdd100k/classification/labels/val_ground_truth.csv'
    val_json = '../../bdd100k/labels/bdd100k_labels_images_val.json'

    # Parameters
    params = {
        'dim': (64, 64, 3),
        'batch_size': 32,
        'n_classes': 10,
        'shuffle': False
    }

    n_test_data = 100000
    n_sel_data = 300000
    epochs = 30

    class_map_file = bu.class_mapping(input_json=val_json,
                                      output_csv=labels_path +
                                      'class_mapping.csv')

    # Datasets
    tr_partition, tr_labels = bu.get_ids_labels(train_labels, class_map_file)
    val_partition, val_labels = bu.get_ids_labels(val_labels_csv,
                                                  class_map_file)
    ft_partition = tr_partition[500000:1000000]

    def train_from_checkpoint(model_file, generator, validation_generator,
                              checkpoint_pattern, history_path):
        """Reload ``model_file``, train it on ``generator``, pickle its history."""
        model = load_model(h5_path + model_file)
        model.compile('adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

        checkpoint = ModelCheckpoint(checkpoint_pattern,
                                     monitor='val_acc',
                                     verbose=0,
                                     save_best_only=True,
                                     save_weights_only=False,
                                     mode='auto')

        history = model.fit_generator(generator=generator,
                                      validation_data=validation_generator,
                                      verbose=1,
                                      epochs=epochs,
                                      use_multiprocessing=True,
                                      workers=6,
                                      callbacks=[checkpoint])

        # Pickle only the metrics dict: a Keras History object keeps a
        # reference to the model, which generally cannot be pickled.
        with open(history_path, 'wb') as fd:
            pickle.dump(history.history, fd)

    for model_file in model_files:
        # splitext drops exactly the extension, whereas str.rstrip('.hdf5')
        # strips any trailing run of the characters '.', 'h', 'd', 'f', '5'.
        model_base = os.path.splitext(model_file)[0]

        sel_partition = analyse.select_ft_data(
            model_file, ft_partition)  # Selected data partition

        # Generators
        finetune_generator = mt.DataGenerator(sel_partition[:n_sel_data],
                                              tr_labels, **params)
        reference_generator = mt.DataGenerator(
            tr_partition[500000:500000 + len(ft_partition)], tr_labels,
            **params)
        validation_generator = mt.DataGenerator(val_partition[:n_test_data],
                                                val_labels, **params)

        # Train model on selected dataset
        train_from_checkpoint(model_file, finetune_generator,
                              validation_generator,
                              model_base + '_ftep{epoch:02d}_vl{val_loss:.2f}.hdf5',
                              model_base + '_ft_hist.pkl')

        # Train model on ref dataset
        train_from_checkpoint(model_file, reference_generator,
                              validation_generator,
                              model_base + '_refep{epoch:02d}_vl{val_loss:.2f}.hdf5',
                              model_base + '_ref_hist.pkl')
Пример #10
0
def load_model_test(model_files, overwrite=False, write_predictions=False):
    """Load each model file and print its Keras summary.

    If ``write_predictions`` is true, each model also predicts on the first
    100000 validation ids and writes ``<model>_predictions.csv`` under
    ``csv_path``; it then prints the accuracy and the confusion matrix. That
    pass was previously unreachable behind an early ``return``. The default
    keeps the summary-only behaviour.

    :param model_files: model file names, relative to ``h5_path``.
    :param overwrite: rewrite prediction files that already exist.
    :param write_predictions: run the prediction/evaluation pass.
    """
    if not write_predictions:
        for model_file in model_files:
            m = load_model(h5_path + model_file)
            m.summary()
        return

    import os
    from datetime import datetime

    labels_path = '../../bdd100k/classification/labels/'
    val_labels_csv = '../../bdd100k/classification/labels/val_ground_truth.csv'
    val_json = '../../bdd100k/labels/bdd100k_labels_images_val.json'

    # Parameters
    params = {
        'dim': (64, 64, 3),
        'batch_size': 32,
        'n_classes': 10,
        'shuffle': False
    }

    n_test_data = 100000

    class_map_file = bu.class_mapping(input_json=val_json,
                                      output_csv=labels_path +
                                      'class_mapping.csv')

    # Datasets
    val_partition, val_labels = bu.get_ids_labels(val_labels_csv,
                                                  class_map_file)
    eval_ids = val_partition[:n_test_data]

    # Generators
    validation_generator = mt.DataGenerator(eval_ids, val_labels, **params)

    # Class frequencies over the evaluated ids, looked up in partition order
    # (dict.values() order is not tied to val_partition).
    eval_classes = [val_labels[data_id] for data_id in eval_ids]
    n_eval = float(len(eval_classes)) or 1.0
    label_distrib = [eval_classes.count(k) / n_eval
                     for k in range(params['n_classes'])]
    print(label_distrib)

    for model_file in model_files:
        predictions_file = '.'.join(
            model_file.split('.')[:-1]) + '_predictions.csv'
        predictions_path = csv_path + predictions_file

        # Check the path actually written below, and move on to the next
        # model instead of aborting the whole run.
        if os.path.isfile(predictions_path) and not overwrite:
            print('File ' + predictions_file + ' already exists. Not written.')
            continue

        start_time = datetime.now()
        m = load_model(h5_path + model_file)
        print('File successfully loaded', model_file, 'in',
              str(datetime.now() - start_time))

        print('Writing predictions')
        start_time = datetime.now()
        y_predicted = m.predict_generator(validation_generator)

        with open(predictions_path, 'w') as out_pr:
            for data_id, y in zip(val_partition, y_predicted):
                out_pr.write(data_id + ',' + str(y.tolist()) + '\n')

        predicted_classes = np.argmax(y_predicted, axis=1)

        print('Predictions successfully written', model_file, 'in',
              str(datetime.now() - start_time))
        # The generator may yield fewer predictions than ids; align on its count
        true_classes = [
            val_labels[data_id] for data_id in val_partition[:len(y_predicted)]
        ]
        acc = metrics.accuracy(predicted_classes, true_classes)
        print('acc=', acc)
        print(sk_metrics.confusion_matrix(true_classes, predicted_classes))
Пример #11
0
def train_bdd100k_cl():
    """Train a 64x64 BDD100K crop classifier for every backbone in ``models``.

    Each backbone gets a GlobalAveragePooling + Dense(1024) + softmax head.
    It is trained for 20 epochs on the first 500000 training ids and
    validated on the first 100000 validation ids. Checkpoints go under
    ``h5_path``, with a name derived from ``mt.weight_file_name``.
    """
    labels_path = '../../bdd100k/classification/labels/'
    train_labels_csv = '../../bdd100k/classification/labels/train_ground_truth.csv'
    val_labels_csv = '../../bdd100k/classification/labels/val_ground_truth.csv'
    val_json = '../../bdd100k/labels/bdd100k_labels_images_val.json'

    epochs = 20

    # Parameters
    params = {
        'dim': (64, 64, 3),
        'batch_size': 32,
        'n_classes': 10,
        'shuffle': True
    }

    class_map_file = bu.class_mapping(input_json=val_json,
                                      output_csv=labels_path +
                                      'class_mapping.csv')

    # Datasets (CSV paths and label dicts now use distinct names)
    val_partition, val_labels = bu.get_ids_labels(val_labels_csv, class_map_file)
    tr_partition, tr_labels = bu.get_ids_labels(train_labels_csv, class_map_file)

    # Generators
    training_generator = mt.DataGenerator(tr_partition[:500000], tr_labels,
                                          **params)
    validation_generator = mt.DataGenerator(val_partition[:100000], val_labels,
                                            **params)
    print(len(training_generator))

    for m in models:
        weight_file = h5_path + mt.weight_file_name(m,
                                                    'bdd100k_cl0-500k',
                                                    epochs,
                                                    data_augmentation=False)
        print("Building: " + weight_file)
        if m in ('mobilenet', 'mobilenetv2', 'nasnet'):
            # Build the ImageNet-weighted net at 224x224, then copy its
            # weights layer by layer into a 64x64 twin (presumably because
            # ImageNet weights are not offered at 64x64 for these).
            imagenet_model = mt.model_struct(m, (224, 224, 3),
                                             params['n_classes'],
                                             weights='imagenet',
                                             include_top=False)
            base_model = mt.model_struct(m,
                                         params['dim'],
                                         params['n_classes'],
                                         weights=None,
                                         include_top=False)
            print("Loading weights...")
            # Skip layer 0 (the input layer), whose shape differs
            for new_layer, layer in zip(base_model.layers[1:],
                                        imagenet_model.layers[1:]):
                new_layer.set_weights(layer.get_weights())
        else:
            base_model = mt.model_struct(m,
                                         params['dim'],
                                         params['n_classes'],
                                         weights='imagenet',
                                         include_top=False)

        print("Configuring top layers")
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024, activation='relu')(x)
        predictions = Dense(params['n_classes'], activation='softmax')(x)
        model = Model(inputs=base_model.input, outputs=predictions)
        model.summary()

        model.compile('adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

        # Remove a literal '.h5' suffix; str.rstrip('.h5') would also eat any
        # trailing '.', 'h' or '5' characters belonging to the name itself.
        checkpoint_base = weight_file[:-3] if weight_file.endswith('.h5') else weight_file
        checkpoint = ModelCheckpoint(checkpoint_base +
                                     '_ep{epoch:02d}_vl{val_loss:.2f}.hdf5',
                                     monitor='val_acc',
                                     verbose=0,
                                     save_best_only=True,
                                     save_weights_only=False,
                                     mode='auto')

        # Train model on dataset
        model.fit_generator(generator=training_generator,
                            validation_data=validation_generator,
                            verbose=1,
                            epochs=epochs,
                            use_multiprocessing=True,
                            workers=6,
                            callbacks=[checkpoint])