# action = 'evaluate'

# validation dataset
# Build the COCO 'val' split and wrap it in a batched generator.
# (u.get_dataset and prepare are project helpers defined elsewhere;
# their exact return types are not visible from this file.)
dataset_val = u.get_dataset(coco_path, 'val')
gen_val = prepare(dataset_val, epochs, batch_size, input_shape, output_shape)

#  fuu = next(gen_val)
#  import ipdb; ipdb.set_trace()

# train dataset
# Same preparation for the 'train' split.
dataset_train = u.get_dataset(coco_path, 'train')
gen_train = prepare(
    dataset_train, epochs, batch_size, input_shape, output_shape)
# Checkpoint callback: saves the model when the tracked metric improves.
# NOTE(review): the filename template here uses "{acc:...}", while the second
# configuration later in this file uses "{avgacc:...}" — confirm which
# placeholder ModelSaveBestAvgAcc actually substitutes; an unsupported key
# would fail at save time, not at construction.
callback = ModelSaveBestAvgAcc(
    filepath="model-{epoch:02d}-{acc:.2f}.hdf5",
    verbose=True
)

# One focal-loss instance per model output (four outputs in this setup).
losses = [binary_focal_loss(gamma=2.) for _ in range(4)]

# Build the network and compile it for training.
model = get_model(input_shape)
model.compile(
    # NOTE(review): `lr` is the legacy Keras optimizer argument name; recent
    # Keras versions expect `learning_rate` — confirm against the installed
    # version before upgrading.
    optimizer=opt.Adam(lr=1e-4),
    loss=losses,  # expects one loss per model output
    metrics=['accuracy']
)

model.summary()
# action = 'evaluate'

# validation dataset
# NOTE(review): from here down the script re-builds the same val/train
# datasets and generators, rebinding dataset_val/gen_val/dataset_train/
# gen_train — this looks like a second experiment configuration pasted into
# the same script, shadowing the first. Confirm whether the earlier
# configuration is still needed or is dead code.
dataset_val = u.get_dataset(coco_path, 'val')
gen_val = prepare(dataset_val, epochs, batch_size, input_shape, output_shape)

#  fuu = next(gen_val)
#  import ipdb; ipdb.set_trace()

#  train dataset
dataset_train = u.get_dataset(coco_path, 'train')
gen_train = prepare(dataset_train, epochs, batch_size, input_shape,
                    output_shape)

# Checkpoint callback for this configuration: filename embeds the averaged
# accuracy ("{avgacc:...}"), and an additional save condition is supplied via
# filter_val('fmeasure') — filter_val is a project helper whose semantics are
# not visible from this file.
callback = ModelSaveBestAvgAcc(filepath="model-{epoch:02d}-{avgacc:.2f}.hdf5",
                               verbose=True,
                               cond=filter_val('fmeasure'))

# Single-output configuration: exactly one focal-loss instance.
losses = [binary_focal_loss(gamma=2.)]

# Build a two-pass graph around a single shared network.
input_tensor = layers.Input(shape=input_shape)
model = get_model(input_tensor=input_tensor)

# Mask the input with the network's own prediction, then run the SAME model
# instance on the masked input — calling `model(x)` reuses its weights, so
# both passes share parameters.
# NOTE(review): presumably an iterative-refinement scheme; the elementwise
# Multiply assumes model.output has the same shape as input_tensor — confirm.
x = layers.Multiply()([input_tensor, model.output])
x = model(x)

#  import ipdb; ipdb.set_trace()
# Rebind `model` to the unrolled two-pass graph; the inner network remains
# reachable as a layer of this outer Model.
model = models.Model(inputs=input_tensor, outputs=x)