save_best_only=True,
    verbose=1)
# Stop training once val_accuracy has gone 10 epochs without improving.
cats_early_stopping = EarlyStopping(monitor='val_accuracy', patience=10, verbose=1)
# Shrink the learning rate when the monitored metric stalls for 5 epochs.
cats_reduce_lr = ReduceLROnPlateau(verbose=1, patience=5)
"""FIT MODEL"""
print('============ binary model fit ============')
binary_model.fit(
    train_generator,
    epochs=100,
    steps_per_epoch=np.ceil(len(train_df) / TRAIN_BATCH_SIZE),
    validation_data=val_dataset,
    callbacks=[binary_checkpoint, binary_early_stopping, binary_reduce_lr])
# load the best (on validation) weights from .fit() phase
binary_model.load_weights(
    filepath=f'advanced_weights_binary_{model_name}.hdf5')

print('============ dogs model fit ============')
# Train the dogs-breed classifier with checkpointing, early stopping and
# LR reduction on plateau.
dogs_model.fit(
    dogs_train_generator,
    epochs=100,
    # Keras documents steps_per_epoch as an integer; np.ceil returns a
    # float (np.float64), so cast explicitly.
    steps_per_epoch=int(np.ceil(len(dogs_train_df) / TRAIN_BATCH_SIZE)),
    validation_data=dogs_val_dataset,
    callbacks=[dogs_checkpoint, dogs_early_stopping, dogs_reduce_lr])
# load the best (on validation) weights from .fit() phase
dogs_model.load_weights(filepath=f'advanced_weights_dogs_{model_name}.hdf5')

print('============ cats model fit ============')
cats_model.fit(
    cats_train_generator,
    epochs=100,
elif model_name == 'xception':
    model = Xception(weights=None, classes=num_of_classes)
else:
    raise ValueError(f"not supported model name {model_name}")

"""MODEL PARAMS AND CALLBACKS"""
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
checkpoint = ModelCheckpoint(filepath=f'flat_weights_{model_name}.hdf5', save_best_only=True, verbose=1)
early_stopping = EarlyStopping(monitor='val_accuracy', patience=10, verbose=1)
reduce_lr = ReduceLROnPlateau(patience=5, verbose=1)

"""FIT MODEL"""
print('============ fit flat model ============')
model.fit(train_generator, epochs=100, steps_per_epoch=np.ceil(len(train_df) / TRAIN_BATCH_SIZE),
          validation_data=val_dataset, callbacks=[checkpoint, early_stopping, reduce_lr])
# load the best (on validation) weights from .fit() phase
model.load_weights(filepath=f'flat_weights_{model_name}.hdf5')

print('============ predict flat model ============')
# .predict() (rather than .evaluate()) keeps the raw per-sample outputs,
# which enables derived analyses such as a confusion matrix.
classes = train_generator.class_indices
# Invert label -> index into index -> label for decoding predictions.
inverted_classes = {index: label for label, index in classes.items()}
predictions = model.predict(test_dataset)
# Collapse per-class scores to the index of the highest-scoring class.
predictions = tf.argmax(predictions, axis=-1).numpy()
inverted_class_predictions = [inverted_classes[index] for index in predictions]

test_df['flat_prediction'] = inverted_class_predictions

# Fraction of test rows whose predicted breed matches the ground truth.
hits = test_df[test_df['breed'] == test_df['flat_prediction']]
accuracy = len(hits) / len(test_df)
print(f'\n#RESULTS {model_name}# Flat Animal breed accuracy: {accuracy}. BATCH_SIZE: {TRAIN_BATCH_SIZE}')