Example 1
def model_fit(
        model: EmbeddingModel,
        training_set: tf.data.Dataset,
        validation_set: tf.data.Dataset,
        export_path: str,
        log_dir: str,
        hparams: Dict[hp.HParam, Any],
        epochs: int = 20,
        verbose: int = 1,
        worker: int = 4,
) -> None:

    # tensorboard logging for standard metrics
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        profile_batch=(300, 320)
    )

    # tensorboard logging for hyperparameters
    keras_callback = hp.KerasCallback(
        writer=log_dir,
        hparams=hparams,
        trial_id=log_dir
    )

    metrics = [
        tf.keras.metrics.AUC(name='auc'),
    ]
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        optimizer='adam',
        metrics=metrics)

    _ = model.fit(
        training_set,
        validation_data=validation_set,
        epochs=epochs,
        verbose=verbose,
        callbacks=[tensorboard_callback, keras_callback],
        workers=worker)

    model.summary()
    tf.saved_model.save(
        obj=model,
        export_dir=os.path.join(export_path, '1')
    )


def main(
    train_path: str,
    test_path: str,
    layer_dir: str,
    log_dir: str,
    export_dir: str,
    batch_size: int,
    embedding_dim_base: int,
    epochs: int,
) -> None:

    # init hparams
    hparams = hparams_init(epochs=epochs,
                           batch_size=batch_size,
                           log_dir=log_dir)
    lst_feature = list(CAT_COLUMNS)  # copy, so the module-level list is not mutated
    lst_feature.append('numeric')

    # data preparation
    train, validate = train_test_prep(train_path=train_path,
                                      test_path=test_path,
                                      batch_size=batch_size)

    # model
    binary_classifier = EmbeddingModel(
        lst_features=lst_feature,
        layer_dir=layer_dir,
        name='binary_classifier',
        embedding_dim_base=embedding_dim_base,
    )

    # training
    model_fit(
        model=binary_classifier,
        training_set=train,
        validation_set=validate,
        export_path=export_dir,
        log_dir=log_dir,
        hparams=hparams,
        epochs=epochs,
    )
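
The helper `hparams_init` used above is not part of this excerpt. A minimal sketch of what it might look like, assuming the hyperparameters are registered through the TensorBoard HParams plugin (the HParam names and domains below are assumptions):

# Hypothetical sketch of hparams_init for the TensorBoard HParams plugin.
from typing import Any, Dict

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp


def hparams_init(epochs: int, batch_size: int, log_dir: str) -> Dict[hp.HParam, Any]:
    hp_epochs = hp.HParam('epochs', hp.IntInterval(1, 1000))
    hp_batch_size = hp.HParam('batch_size', hp.Discrete([32, 64, 128, 256]))
    # Register the experiment layout so TensorBoard can group trials.
    with tf.summary.create_file_writer(log_dir).as_default():
        hp.hparams_config(
            hparams=[hp_epochs, hp_batch_size],
            metrics=[hp.Metric('auc', display_name='AUC')])
    return {hp_epochs: epochs, hp_batch_size: batch_size}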
Example 3
def main(cfg):

    # Configure the GPU devices to allocate memory on demand (memory growth)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:

                tf.config.experimental.set_memory_growth(gpu, True)

        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)


    # TODO: enable loggers
    # Define the location where the TensorBoard logs are stored
    train_summary_writer = tf.summary.create_file_writer(cfg.dirs.train_log)


    # Prepare the DATASET
    # Load the data from the CSV file
    pids, fids = load_dataset(cfg.dirs.csv_file, cfg.dirs.images)

    # Get all the labels
    unique_pids = np.unique(pids)
    print("Unique labels: {}".format(len(unique_pids)))
    # Build a dataset where every epoch covers all PID values,
    # distributed uniformly.
    if len(unique_pids) < cfg.model.batch.P:
        unique_pids = np.tile(unique_pids, int(np.ceil(cfg.model.batch.P / len(unique_pids))))
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    # Take random values
    dataset = dataset.shuffle(len(unique_pids))
    # Force the dataset size to be a multiple of the batch size
    dataset = dataset.take((len(unique_pids) // cfg.model.batch.P) * cfg.model.batch.P)
    dataset = dataset.repeat(None)

    # For each PID, sample cfg.model.batch.K images
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=cfg.model.batch.K))

    # Ungroup the cfg.model.batch.K samples for more efficient image loading
    dataset = dataset.unbatch()

    # Now load the images, resizing them to the model's input size

    net_input_size = (cfg.model.input.height, cfg.model.input.width)
    pre_crop_size = (cfg.model.crop.height, cfg.model.crop.width)

    # Check whether cropping is requested

    image_size = pre_crop_size if cfg.model.crop else net_input_size

    # Resize the images
    dataset = dataset.map(lambda fid, pid: fid_to_image(
        fid, pid, cfg.dirs.images, image_size), num_parallel_calls=cfg.loading_threads)

    # If CROP is enabled, crop down to the network input size

    if cfg.model.crop:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_crop(im, net_input_size + (3,)),
                               fid,
                               pid))



    # Load a predefined model and prepare it for training
    # Set the phase of our TF environment: 0 = test, 1 = train
    tf.keras.backend.set_learning_phase(1)
    emb_model = EmbeddingModel(cfg)
    # Group the dataset into batches of batch_size
    # and preprocess the images for the emb_model
    batch_size = cfg.model.batch.P * cfg.model.batch.K
    dataset = dataset.map(lambda im, fid, pid: (emb_model.preprocess_input(im), fid, pid))
    dataset = dataset.batch(batch_size)

    print('Batch-size: {}'.format(batch_size))

    # Prefetch the upcoming batches.
    # This improves latency and throughput at the cost of the extra
    # memory needed to store the next batches.
    dataset = dataset.prefetch(2)


    # Set up the optimizer and the learning-rate schedule
    if 0 <= cfg.model.fit.decay_start_iteration < cfg.model.fit.epochs:
        cfg.model.fit.lr = tf.optimizers.schedules.PolynomialDecay(
            cfg.model.fit.lr, cfg.model.fit.epochs, end_learning_rate=1e-7)

    if cfg.model.fit.optimizer == 'adam':
        optimizer = tf.keras.optimizers.Adam(cfg.model.fit.lr)
    elif cfg.model.fit.optimizer == 'SGD':
        optimizer = tf.keras.optimizers.SGD(cfg.model.fit.lr, momentum=0.9)
    else:
        raise NotImplementedError('Invalid optimizer {}'.format(cfg.model.fit.optimizer))


    # Start training

    start_step = 0
    dataset_iter = iter(dataset)

    # Define the checkpoints
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=emb_model)
    manager = tf.train.CheckpointManager(ckpt, cfg.dirs.checkpoint, max_to_keep=3)

    # Restore the latest checkpoint
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # Record graph and profiler data for TensorBoard
    tf.summary.trace_on(graph=True, profiler=True)

    # Context manager to easily switch grappler optimizer options on and off
    @contextlib.contextmanager
    def options(options):
        old_opts = tf.config.optimizer.get_experimental_options()
        tf.config.optimizer.set_experimental_options(options)
        try:
            yield
        finally:
            tf.config.optimizer.set_experimental_options(old_opts)
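
    # Example usage (hypothetical): temporarily disable a graph optimization
    # for a single training step:
    #   with options({'layout_optimizer': False}):
    #       batch_loss = train_step(images, pids, i)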




    @tf.function(experimental_relax_shapes=True)
    def train_step(images, pids, iteration):
        cfg = emb_model.cfg
        with tf.GradientTape() as tape:
            # Get the feature (embedding) vectors for the batch
            batch_embedding = emb_model(images)
            # Apply L2 normalization if configured
            if emb_model.l2_embedding:
                batch_embedding = tf.nn.l2_normalize(batch_embedding, -1)
            # Apply the configured loss function
            if cfg.model.fit.loss == 'semi_hard_triplet':
                embedding_loss = triplet_semihard_loss(batch_embedding, pids, cfg.model.fit.margin)
            elif cfg.model.fit.loss == 'hard_triplet':
                embedding_loss = batch_hard(batch_embedding, pids, cfg.model.fit.margin, cfg.model.fit.metric)
            elif cfg.model.fit.loss == 'lifted_loss':
                embedding_loss = lifted_loss(pids, batch_embedding, margin=cfg.model.fit.margin)
            elif cfg.model.fit.loss == 'contrastive_loss':
                assert batch_size % 2 == 0
                assert cfg.model.batch.K == 4  # Can work with other numbers but will need tuning

                contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], cfg.model.batch.P // 2)
                for i in range(cfg.model.batch.P // 2):
                    contrastive_idx[i * 8:i * 8 + 8] += i * 8

                contrastive_idx = np.expand_dims(contrastive_idx, 1)
                batch_embedding_ordered = tf.gather_nd(batch_embedding, contrastive_idx)
                pids_ordered = tf.gather_nd(pids, contrastive_idx)
                # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding_ordered, [-1, 2, cfg.model.embedding_dim]), 2,
                    1)
                # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

                fixed_labels = np.tile([1, 0, 0, 1], cfg.model.batch.P // 2)
                # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
                # print(fixed_labels)
                labels = tf.constant(fixed_labels)
                # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
                embedding_loss = contrastive_loss(labels, embeddings_anchor, embeddings_positive,
                                                margin=cfg.model.fit.margin)
            elif cfg.model.fit.loss == 'angular_loss':
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, cfg.model.embedding_dim]), 2,
                    1)
                # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
                # Add the parameter for the maximum angle that can be
                # formed between epn and ean
                embedding_loss = angular_loss(pids, embeddings_anchor, embeddings_positive,
                                              degree=cfg.model.fit.alpha,
                                              batch_size=cfg.model.batch.P, with_l2reg=True)

            elif cfg.model.fit.loss == 'npairs_loss':
                assert cfg.model.batch.K == 2  # Single positive pair per class
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, cfg.model.embedding_dim]), 2, 1)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                pids = tf.reshape(pids, [-1])
                embedding_loss = npairs_loss(pids, embeddings_anchor, embeddings_positive)

            else:
                raise NotImplementedError('Invalid Loss {}'.format(cfg.model.fit.loss))
            loss_mean = tf.reduce_mean(embedding_loss)


        gradients = tape.gradient(loss_mean, emb_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, emb_model.trainable_variables))
        # Log data for TensorBoard
        with train_summary_writer.as_default():
            tf.summary.scalar('Loss mean', loss_mean, step=iteration)
            # tf.summary.image("Training data", images, step=iteration)
            # tf.summary.scalar('Learning Rate', optimizer.lr, step=iteration)

        return embedding_loss

    #print('Starting training from iteration {}.'.format(start_step))
    with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(ckpt.step.numpy(), cfg.model.fit.epochs):
                # for batch_idx, batch in enumerate():
                start_time = time.time()
                images, fids, pids = next(dataset_iter)
                # strategy = tf.distribute.MirroredStrategy()
                # strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
                # with strategy.scope():
                # TODO: read https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops
                batch_loss = train_step(images, pids, i)
                elapsed_time = time.time() - start_time
                seconds_todo = (cfg.model.fit.epochs - i) * elapsed_time
                print('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'.format(
                    i,
                    tf.reduce_min(batch_loss).numpy(), tf.reduce_mean(batch_loss).numpy(), tf.reduce_max(batch_loss).numpy(),
                    # cfg.model.batch.K - 1, float(b_prec_at_k),
                    timedelta(seconds=int(seconds_todo)),
                    elapsed_time))

                ckpt.step.assign_add(1)
                if cfg.checkpoint_frequency > 0 and i % cfg.checkpoint_frequency == 0:

                    # uncomment if you want to save the emb_model weights separately
                    # emb_model.save_weights(os.path.join(cfg.dirs.checkpoint, 'emb_model_weights_{0:04d}.w'.format(i)))
                    manager.save()

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break
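
The `sample_k_fids_for_pid` function called when building the dataset is not shown in this excerpt. A sketch of one common implementation (following the triplet-reid pattern, repeating a PID's images when it has fewer than K of them; treat this as an assumption, not the author's exact code):

# Assumed sketch of sample_k_fids_for_pid.
def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
    # All file IDs belonging to this PID.
    possible_fids = tf.boolean_mask(all_fids, tf.equal(all_pids, pid))
    # Pad by repetition so batch_k samples can always be drawn.
    count = tf.shape(possible_fids)[0]
    padded_count = tf.cast(
        tf.math.ceil(batch_k / tf.cast(count, tf.float32)), tf.int32) * count
    full_range = tf.math.mod(tf.range(padded_count), count)
    # Shuffle the (padded) index range and keep the first batch_k entries.
    shuffled = tf.random.shuffle(full_range)
    selected_fids = tf.gather(possible_fids, shuffled[:batch_k])
    return selected_fids, tf.fill([batch_k], pid)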
Example 4
def main(cfg):

    # TODO: enable loggers
    # Define the location where the TensorBoard logs are stored
    cbir_summary_writer = tf.summary.create_file_writer(cfg.dirs.cbir_log)

    if not os.path.exists(cfg.dirs.csv_file):
        raise IOError('Dataset file not found: {}'.format(
            cfg.dirs.csv_file))

    # Load the label names

    labels_id, labels_name = load_labels(cfg.dirs.labels_file)
    # print(labels_id)
    # print(labels_name)
    # Load our catalogue data

    # Load the file with the features and the pids of our
    # catalogue (image database)

    with h5py.File(cfg.dirs.embeddings_file, 'r') as db:
        db_embs = db['emb'][()]
        db_pids = db['pids'][()]
        db_fids = db['fids'][()]

    # Assign an idx to each image

    idxs = np.arange(len(db_fids))
    # Create the dataset from our catalogue
    dataset = tf.data.Dataset.from_tensor_slices((db_embs, db_pids, idxs))

    # Now load the images, resizing them to the model's input size
    net_input_size = (cfg.model.input.height, cfg.model.input.width)
    pre_crop_size = (cfg.model.crop.height, cfg.model.crop.width)
    # Check whether cropping is requested
    image_size = pre_crop_size if cfg.model.crop else net_input_size
    # Resize the query image
    query_image = load_image(query_file, image_size)

    # If CROP is enabled, crop down to the network input size
    if cfg.model.crop:
        query_image = tf.image.random_crop(query_image, net_input_size + (3, ))
    # Load a predefined model and prepare it for inference
    # Set the phase of our TF environment: 0 = test, 1 = train
    tf.keras.backend.set_learning_phase(0)
    emb_model = EmbeddingModel(cfg)
    # Preprocess the query image for the emb_model
    query_image_preprocessed = emb_model.preprocess_input(query_image)

    # Define the checkpoints
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model)
    manager = tf.train.CheckpointManager(ckpt,
                                         cfg.dirs.checkpoint,
                                         max_to_keep=1)
    # Restore the latest checkpoint
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # For each of the returned images we record the label and the distances.
    matched_labels = []
    distances = []
    retrieved_idxs = []
    # Retrieve the images from the catalogue that are closest to the query
    t0 = time.time()
    # Get the embedding vector of our image
    query_image_preprocessed = tf.reshape(
        query_image_preprocessed,
        [1, cfg.model.input.height, cfg.model.input.width, 3])

    query = emb_model(query_image_preprocessed)

    # Compute the distances
    dataset = dataset.map(lambda embs, pid, idx: calculate_distances(
        embs, pid, idx, query, type='Euclidian'))
    # Drop the feature vectors to speed up the computations
    dataset_without_embs = dataset.map(
        lambda embs, pid, idx, distance: remove_embs(embs, pid, idx, distance))

    # TODO: find a way to sort with TensorFlow on the GPU
    # the array layout is pid, idx, distance
    distance_with_labels = np.array(
        list(dataset_without_embs.as_numpy_iterator()))
    sorted_distance_with_labels = distance_with_labels[
        distance_with_labels[:, 2].argsort()]

    # Get the indices, distances and labels
    # of only the number of results we return (n_retrievals)
    sorted_idxs = np.array(sorted_distance_with_labels[:n_retrievals,
                                                       1]).astype('int')
    sorted_distances = np.array(
        sorted_distance_with_labels[:n_retrievals, 2]).astype('float32')
    print(sorted_distances)

    # Get the label names of the retrieved images
    sorted_labels = sorted_distance_with_labels[:n_retrievals, 0]
    sorted_labels_names = [
        get_label_name(pid, labels_id, labels_name) for pid in sorted_labels
    ]
    print("Labels of retrieved images: {}".format(sorted_labels_names))

    retrievals = tf.data.Dataset.from_tensor_slices(
        (db_fids[sorted_idxs], np.asarray(db_pids, dtype=int)[sorted_idxs]))
    # Resize the images
    retrievals = retrievals.map(
        lambda fid, pid: fid_to_image(fid, pid, cfg.dirs.images, image_size),
        num_parallel_calls=cfg.loading_threads)
    # print(list(retrievals.as_numpy_iterator()))
    # If CROP is enabled, crop down to the network input size

    # if cfg.model.crop:
    #     retrievals = retrievals.map(lambda im, fid, pid:
    #                           (tf.image.random_crop(im, net_input_size + (3,)),
    #                            fid,
    #                            pid))
    # Batch the retrieved images
    retrievals = retrievals.batch(cfg.n_retrievals)
    retrievals_iter = iter(retrievals)
    retrievals_images, retrievals_fid, retrievals_pid = next(retrievals_iter)

    query_image = tf.reshape(
        query_image, [1, cfg.model.input.height, cfg.model.input.width, 3])
    summary_retrievals = show_images(query_image, [query_file], [-1],
                                     retrievals_images, retrievals_fid,
                                     retrievals_pid, sorted_distances)

    print(retrievals_fid)
    print(retrievals_pid)
    t1 = time.time()
    print('Time to retrieve the images: %.4f s' % (t1 - t0))

    with cbir_summary_writer.as_default():
        tf.summary.image("Retrieved Images",
                         plot_to_image(summary_retrievals),
                         step=1)
        tf.summary.scalar('Retrieval time', t1 - t0, step=1)
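
The `calculate_distances` and `remove_embs` helpers mapped over the catalogue are not included in this excerpt. A minimal sketch, assuming the query is compared to each stored embedding with the Euclidean norm (names and tuple layout follow their usage above):

# Assumed helpers: Euclidean distance to the query, then drop the embedding.
def calculate_distances(embs, pid, idx, query, type='Euclidian'):
    # query has shape (1, D); squeeze it to compare against one embedding.
    distance = tf.norm(embs - tf.squeeze(query), ord='euclidean')
    return embs, pid, idx, distance


def remove_embs(embs, pid, idx, distance):
    # Keep only the scalar fields needed for sorting.
    return pid, idx, distance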
Example 5
        train_query_paper_ids = train_paper_ids[start_split:end_split]

        start_split = args.local_rank * val_split_size
        end_split = (args.local_rank + 1) * val_split_size

        val_query_paper_ids = val_paper_ids[start_split:end_split]

    else:
        train_query_paper_ids = train_paper_ids[:]
        val_query_paper_ids = val_paper_ids[:]

    train_triplet_dataset = TripletIterableDataset(
        dataset, train_paper_ids, set(train_paper_ids),
        args.train_triplets_per_epoch, args)
    test_triplet_dataset = TripletIterableDataset(
        dataset, val_paper_ids, set(train_paper_ids + val_paper_ids),
        args.eval_triplets_per_epoch, args)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
    collater = TripletCollater(tokenizer, args.max_seq_len)

    model = EmbeddingModel(args)

    trainer = Trainer(model,
                      train_triplet_dataset,
                      test_triplet_dataset,
                      args,
                      data_collater=collater)

    trainer.train()
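
`TripletCollater` is referenced but not defined here. A sketch under the assumption that each dataset item is an (anchor, positive, negative) text triple and that each side is tokenized into a padded PyTorch batch:

# Hypothetical TripletCollater for (anchor, positive, negative) text triples.
class TripletCollater:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, batch):
        anchors, positives, negatives = zip(*batch)
        encode = lambda texts: self.tokenizer(
            list(texts), padding=True, truncation=True,
            max_length=self.max_seq_len, return_tensors='pt')
        # One padded batch per side of the triple.
        return encode(anchors), encode(positives), encode(negatives)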
Example 6
def main(argv):
    # Verify that parameters are set correctly.
    args = parser.parse_args(argv)

    if not os.path.exists(args.dataset):
        print('Dataset file not found: {}'.format(args.dataset))
        return

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'

    os_utils.touch_dir(os.path.join(args.experiment_root, args.foldername))

    log_file = os.path.join(args.experiment_root, args.foldername, "embed")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('embed')

    args.filename = os.path.join(args.experiment_root, args.foldername,
                                 args.filename)
    var_filepath = os.path.join(args.experiment_root, args.foldername,
                                args.filename[:-3] + '_var.txt')
    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
                  'evaluation.')
        args.image_root = args.image_root or args_resumed['image_root']
    else:
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
            print(
                'ERROR: Test time augmentation is performed but no aggregator '
                'was specified.')
            exit(1)
    else:
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')
            exit(1)

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        fid,
        tf.constant('dummy'),
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    if args.flip_augment:
        dataset = dataset.map(flip_augment)
        dataset = dataset.unbatch()
        modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    if args.crop_augment == 'center':
        dataset = dataset.map(lambda im, fid, pid:
                              (five_crops(im, net_input_size)[0], fid, pid))
        modifiers = [o + '_center' for o in modifiers]
    elif args.crop_augment == 'five':
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.stack(five_crops(im, net_input_size)),
                               tf.stack([fid] * 5), tf.stack([pid] * 5)))
        dataset = dataset.unbatch()
        modifiers = [
            o + m for o in modifiers for m in [
                '_center', '_top_left', '_top_right', '_bottom_left',
                '_bottom_right'
            ]
        ]
    elif args.crop_augment == 'avgpool':
        modifiers = [o + '_avgpool' for o in modifiers]
    else:
        modifiers = [o + '_resize' for o in modifiers]

    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)
    tf.keras.backend.set_learning_phase(0)

    with h5py.File(args.filename, 'w') as f_out:

        ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model)
        manager = tf.train.CheckpointManager(ckpt,
                                             osp.join(args.experiment_root,
                                                      'tf_ckpts'),
                                             max_to_keep=1)
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)

        # for batch_idx,batch in enumerate(dataset):
        dataset_iter = iter(dataset)
        for start_idx in count(step=args.batch_size):

            try:
                images, _, _ = next(dataset_iter)
                emb = emb_model(images)
                emb_storage[start_idx:start_idx + len(emb)] += emb
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
            except StopIteration:
                break  # This just indicates the end of the dataset.

        if not args.quiet:
            print("Done with embedding, aggregating augmentations...",
                  flush=True)

        if len(modifiers) > 1:
            # Pull out the augmentations into a separate first dimension.
            emb_storage = emb_storage.reshape(len(data_fids), len(modifiers),
                                              -1)
            emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug,FID,128D)

            # Store the embedding of all individual variants too.
            emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

            # Aggregate according to the specified parameter.
            emb_storage = AGGREGATORS[args.aggregator](emb_storage)

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # Store information about the produced augmentation and in case no crop
        # augmentation was used, if the images are resized or avg pooled.
        f_out.create_dataset('augmentation_types',
                             data=np.asarray(modifiers, dtype='|S'))
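
`AGGREGATORS` maps the `--aggregator` flag to a reduction over the augmentation axis of the (Aug, FID, D) array. A plausible definition (the exact set of keys is an assumption):

# Assumed aggregator table: reduce the augmentation axis (axis 0).
AGGREGATORS = {
    'mean': lambda emb: np.mean(emb, axis=0),
    'max': lambda emb: np.max(emb, axis=0),
    # Re-normalize after averaging; useful for L2-normalized embeddings.
    'normalized_mean': lambda emb: (
        np.mean(emb, axis=0) /
        np.linalg.norm(np.mean(emb, axis=0), axis=-1, keepdims=True)),
}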
Example 7
def main(cfg):

    # TODO: enable loggers
    # Define the location where the TensorBoard logs are stored
    eval_summary_writer = tf.summary.create_file_writer(cfg.dirs.eval_log)

    if not os.path.exists(cfg.dirs.csv_file):
        raise IOError('Dataset file not found: {}'.format(
            cfg.dirs.csv_file))

    # Load our catalogue data

    # Load the file with the features and the pids of our
    # catalogue (image database)

    with h5py.File(cfg.dirs.embeddings_file, 'r') as db:
        db_embs = db['emb'][()]
        db_pids = db['pids'][()]
        db_fids = db['fids'][()]

    # Assign an idx to each image

    idxs = np.arange(len(db_fids))
    # Create the dataset from our catalogue
    dataset = tf.data.Dataset.from_tensor_slices((db_embs, db_pids, idxs))

    # Load the data from the test file
    test_pids, test_fids = load_dataset(cfg.dirs.test_file, cfg.dirs.images)
    # Prepare a dataset to run through the model and get the
    # feature vectors of all our test images
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_fids, np.asarray(test_pids, dtype=int)))
    # Take random values
    test_dataset = test_dataset.shuffle(len(test_fids))
    # Now load the images, resizing them to the model's input size
    net_input_size = (cfg.model.input.height, cfg.model.input.width)
    pre_crop_size = (cfg.model.crop.height, cfg.model.crop.width)
    # Check whether cropping is requested
    image_size = pre_crop_size if cfg.model.crop else net_input_size
    # Resize the images
    test_dataset = test_dataset.map(
        lambda fid, pid: fid_to_image(fid, pid, cfg.dirs.images, image_size),
        num_parallel_calls=cfg.loading_threads)
    # If CROP is enabled, crop down to the network input size
    if cfg.model.crop:
        test_dataset = test_dataset.map(lambda im, fid, pid: (
            tf.image.random_crop(im, net_input_size + (3, )), fid, pid))
    # Load a predefined model and prepare it for evaluation
    # Set the phase of our TF environment: 0 = test, 1 = train
    tf.keras.backend.set_learning_phase(0)
    emb_model = EmbeddingModel(cfg)
    # Preprocess the images for the emb_model
    # (keeping the original image alongside for visualization)
    test_dataset = test_dataset.map(
        lambda im, fid, pid: (emb_model.preprocess_input(im), fid, pid, im))

    # For each trial we take one test image to use as our query image
    # TODO: consider other ways of choosing the number of trials
    test_dataset = test_dataset.batch(1)

    # Prefetch the upcoming batches.
    # This improves latency and throughput at the cost of the extra
    # memory needed to store the next batches.
    test_dataset = test_dataset.prefetch(1)
    test_dataset_iter = iter(test_dataset)

    # Define the checkpoints
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model)
    manager = tf.train.CheckpointManager(ckpt,
                                         cfg.dirs.checkpoint,
                                         max_to_keep=1)
    # Restore the latest checkpoint
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # For each of the returned images we record the label and the distances.
    matched_labels = []
    distances = []
    retrieved_idxs = []
    # Walk over the query images and retrieve the closest catalogue images
    total_t0 = time.time()
    # Repeat this process n_test_samples times
    maps = []
    for i in range(cfg.n_test_samples):
        t0 = time.time()
        # Get the embedding vector of our image
        test_image, test_fid, test_pid, test_original = next(test_dataset_iter)
        query = emb_model(test_image)
        # Compute the distances
        dataset_distances = dataset.map(
            lambda embs, pid, idx: calculate_distances(
                embs, pid, idx, query, type='Euclidian'))
        # Drop the feature vectors
        dataset_without_embs = dataset_distances.map(
            lambda embs, pid, idx, distance: remove_embs(
                embs, pid, idx, distance))
        dataset_without_embs = dataset_without_embs.map(
            lambda pid, idx, distance: boolean_label(pid, idx, distance,
                                                     test_pid))
        # TODO: find a way to sort with TensorFlow
        # the array layout is pid, idx, distance, boolean_label
        distance_with_labels = np.array(
            list(dataset_without_embs.as_numpy_iterator()))
        sorted_distance_with_labels = distance_with_labels[
            distance_with_labels[:, 2].argsort()]
        # Get the indices, distances and labels
        # of only the number of results we return (n_retrievals)
        sorted_idxs = np.array(sorted_distance_with_labels[:n_retrievals,
                                                           1]).astype('int')
        sorted_labels = np.array(sorted_distance_with_labels[:n_retrievals,
                                                             3]).astype('int')
        sorted_distances = np.array(
            sorted_distance_with_labels[:n_retrievals, 2]).astype('float32')

        # Compute our AP@k
        # k is the number of retrieved images (n_retrievals)
        print("Query label: {}".format(test_pid))
        print("Labels of retrieved images: {}".format(
            sorted_distance_with_labels[:n_retrievals, 0]))

        ap, aps = APatk(sorted_labels)
        maps.append(ap)

        # retrieved_idxs.append(sorted_idxs)
        # distances.append(sorted_distances)
        # matched_labels.append(sorted_labels)

        # Fetch the retrieved images
        retrievals = tf.data.Dataset.from_tensor_slices(
            (db_fids[sorted_idxs], np.asarray(db_pids,
                                              dtype=int)[sorted_idxs]))
        # Resize the images
        retrievals = retrievals.map(lambda fid, pid: fid_to_image(
            fid, pid, cfg.dirs.images, image_size),
                                    num_parallel_calls=cfg.loading_threads)

        # If CROP is enabled, crop down to the network input size

        # if cfg.model.crop:
        #     retrievals = retrievals.map(lambda im, fid, pid:
        #                           (tf.image.random_crop(im, net_input_size + (3,)),
        #                            fid,
        #                            pid))
        # Batch the retrieved images
        retrievals = retrievals.batch(cfg.n_retrievals)
        retrievals_iter = iter(retrievals)
        retrievals_images, retrievals_fid, retrievals_pid = next(
            retrievals_iter)

        summary_retrievals = show_images(test_original,
                                         test_fid,
                                         test_pid,
                                         retrievals_images,
                                         retrievals_fid,
                                         retrievals_pid,
                                         sorted_distances,
                                         columns=4,
                                         ap=ap,
                                         aps=aps)
        t1 = time.time()
        with eval_summary_writer.as_default():
            tf.summary.image("Retrieved Images",
                             plot_to_image(summary_retrievals),
                             step=i)
            tf.summary.scalar("ap", ap, step=i)
            tf.summary.scalar('Retrieval time', t1 - t0, step=i)

    total_t1 = time.time()
    print('Total time to retrieve the images: %.4f s' % (total_t1 - total_t0))
    # output = np.stack((distances, matched_labels, retrieved_idxs), axis=-1)
    # score = label_ranking_average_precision_score(matched_labels, distances)
    # print('Model score: %.2f %%' % (score*100))
    mean_ap = np.average(np.array(maps))
    print('MAP: {:.2f}%'.format(mean_ap * 100))
    with eval_summary_writer.as_default():
        tf.summary.scalar('Total time', total_t1 - total_t0, step=1)
        tf.summary.scalar('Score mAP', mean_ap, step=1)
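
The `APatk` function is not shown in this excerpt. A sketch computing average precision over the k retrieved results, where `sorted_labels` is a 0/1 relevance vector; it returns both the AP and the list of per-hit precisions, matching how `ap` and `aps` are passed to `show_images` above:

# Assumed AP@k: mean of the precision values at each relevant rank.
def APatk(sorted_labels):
    hits = 0
    precisions = []
    for k, relevant in enumerate(sorted_labels, start=1):
        if relevant:
            hits += 1
            precisions.append(hits / k)
    ap = float(np.mean(precisions)) if precisions else 0.0
    return ap, precisions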
Example 8
def main(cfg):

    if not os.path.exists(cfg.dirs.csv_file):
        raise IOError('Dataset file not found: {}'.format(
            cfg.dirs.csv_file))
    # Prepare the DATASET
    # Load the data from the CSV file
    pids, fids = load_dataset(cfg.dirs.csv_file, cfg.dirs.images)

    # Prepare a dataset to run through the model and get the
    # feature vectors of our whole database
    dataset = tf.data.Dataset.from_tensor_slices(fids)
    # Assign each image its pid
    dataset = dataset.map(lambda fid: find_pid(fid, fids, pids))

    # Now load the images, resizing them to the model's input size
    net_input_size = (cfg.model.input.height, cfg.model.input.width)
    pre_crop_size = (cfg.model.crop.height, cfg.model.crop.width)

    # Check whether cropping is requested

    image_size = pre_crop_size if cfg.model.crop else net_input_size
    # Resize the images
    dataset = dataset.map(
        lambda fid, pid: fid_to_image(fid, pid, cfg.dirs.images, image_size),
        num_parallel_calls=cfg.loading_threads)

    # If CROP is enabled, crop down to the network input size

    if cfg.model.crop:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))
    # Load a predefined model and prepare it for inference
    # Set the phase of our TF environment: 0 = test, 1 = train
    tf.keras.backend.set_learning_phase(0)
    emb_model = EmbeddingModel(cfg)

    # Group the dataset into batches of batch_size
    # and preprocess the images for the emb_model
    # Also build the thumbnails
    thumb_size = (28, 28)
    dataset = dataset.map(lambda im, fid, pid:
                          (im, fid, pid, tf.image.resize(im, thumb_size)))

    dataset = dataset.map(lambda im, fid, pid, thumb: (
        emb_model.preprocess_input(im), fid, pid, thumb))

    # dataset = dataset.map(lambda im, fid, pid, thumb: (im, fid, pid, tf.image.convert_image_dtype(thumb, dtype=tf.uint8, saturate=False)))

    dataset = dataset.batch(cfg.batch_size)
    print('Batch-size: {}'.format(cfg.batch_size))

    # Prefetch the upcoming batches.
    # This improves latency and throughput at the cost of the extra
    # memory needed to store the next batches.
    dataset = dataset.prefetch(2)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']

    # TODO: analyze whether this is necessary
    # if args.flip_augment:
    #     dataset = dataset.map(flip_augment)
    #     dataset = dataset.apply(tf.contrib.data.unbatch())
    #     modifiers = [o + m for m in ['', '_flip'] for o in modifiers]
    #
    # if args.crop_augment == 'center':
    #     dataset = dataset.map(lambda im, fid, pid:
    #         (five_crops(im, net_input_size)[0], fid, pid))
    #     modifiers = [o + '_center' for o in modifiers]
    # elif args.crop_augment == 'five':
    #     dataset = dataset.map(lambda im, fid, pid: (
    #         tf.stack(five_crops(im, net_input_size)),
    #         tf.stack([fid]*5),
    #         tf.stack([pid]*5)))
    #     dataset = dataset.apply(tf.contrib.data.unbatch())
    #     modifiers = [o + m for o in modifiers for m in [
    #         '_center', '_top_left', '_top_right', '_bottom_left', '_bottom_right']]
    # elif args.crop_augment == 'avgpool':
    #     modifiers = [o + '_avgpool' for o in modifiers]
    # else:
    #     modifiers = [o + '_resize' for o in modifiers]

    # Start extracting all the image features from the dataset
    with h5py.File(cfg.dirs.embeddings_file, 'w') as f_out:
        # Define the checkpoints
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model)
        manager = tf.train.CheckpointManager(ckpt,
                                             cfg.dirs.checkpoint,
                                             max_to_keep=1)
        # Restore the latest checkpoint
        ckpt.restore(manager.latest_checkpoint)

        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        # Initialize the storage with zeros
        emb_storage = np.zeros(
            ((len(fids) * len(modifiers)), cfg.model.embedding_dim),
            np.float32)
        thumbnails = np.zeros(
            ((len(fids) * len(modifiers)), thumb_size[0], thumb_size[1], 3),
            np.float32)
        # Walk over the whole dataset, storing the
        # feature vectors as we go
        dataset_iter = iter(dataset)
        for start_idx in count(step=cfg.batch_size):
            try:
                images, _, _, thumbs = next(dataset_iter)
                emb = emb_model(images)
                emb_storage[start_idx:start_idx + len(emb)] += emb
                thumbnails[start_idx:start_idx + len(thumbs)] += thumbs
                print('\rCreating feature vectors {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
            except StopIteration:
                break  # This just indicates the end of the dataset
        # Store the vector of class identifiers
        pids_dataset = f_out.create_dataset('pids',
                                            data=np.array(pids, np.int64))
        # Store the image file names of the classes.
        # First encode the names as ASCII, because of an h5py issue
        # handling UTF-8.
        fids = np.array([item.encode('ascii') for item in fids])

        fids_dataset = f_out.create_dataset('fids', data=fids)
        # Store the feature vectors
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

    sprite = create_sprite(thumbnails)

    sprite = Image.fromarray(np.uint8(sprite))
    sprite_file = os.path.join(cfg.dirs.emb_log, "sprites.png")
    sprite.save(sprite_file)
    # Save the checkpoint for the projector

    emb_storage = tf.Variable(emb_storage, name="embeddings")

    ckpt_emb = tf.train.Checkpoint(embeddings=emb_storage)
    ckpt_emb.save(os.path.join(cfg.dirs.emb_log, "embedding.ckpt"))
    metadata = os.path.join(cfg.dirs.emb_log, 'metadata.tsv')
    with open(metadata, 'w') as metadata_file:
        for row in pids:
            metadata_file.write('%d\n' % int(row))

    # Generate a visualization of the dataset
    # using the thumbnails taken from it.
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
    embedding.tensor_name = 'embeddings/.ATTRIBUTES/VARIABLE_VALUE'
    embedding.metadata_path = metadata
    embedding.sprite.image_path = sprite_file
    embedding.sprite.single_image_dim.extend(thumb_size)
    # Define the location where the TensorBoard logs are stored
    # emb_summary_writer = tf.summary.create_file_writer(cfg.dirs.emb_log)
    projector.visualize_embeddings(cfg.dirs.emb_log, config)
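
`create_sprite` tiles the thumbnails into the single sprite image the TensorBoard projector expects. A minimal sketch, assuming row-major tiling into a square grid:

# Assumed sprite builder: pack (N, H, W, 3) thumbnails into one square grid.
def create_sprite(images):
    n = int(np.ceil(np.sqrt(images.shape[0])))
    # Pad with black images so the grid is exactly n * n.
    pad = ((0, n * n - images.shape[0]), (0, 0), (0, 0), (0, 0))
    images = np.pad(images, pad, mode='constant')
    h, w = images.shape[1], images.shape[2]
    # (n*n, H, W, 3) -> (n, n, H, W, 3) -> (n*H, n*W, 3)
    images = images.reshape((n, n, h, w, 3))
    images = images.transpose((0, 2, 1, 3, 4)).reshape((n * h, n * w, 3))
    return images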
Example 9
def main(argv):

    args = parser.parse_args(argv)

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # tf.compat.v1.disable_eager_execution()

    # physical_devices = tf.config.experimental.list_physical_devices('GPU')
    # tf.config.experimental.set_memory_growth(physical_devices[0], True)

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    if len(unique_pids) < args.batch_p:
        unique_pids = np.tile(unique_pids,
                              int(np.ceil(args.batch_p / len(unique_pids))))
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.unbatch()

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.

    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Create the model and an embedding head.
    tf.keras.backend.set_learning_phase(1)
    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    # all_trainable_variables = embedding_head.trainable_variables+base_model.trainable_variables

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.optimizers.schedules.PolynomialDecay(
            args.learning_rate, args.train_iterations, end_learning_rate=1e-7)
    else:
        learning_rate = args.learning_rate

    if args.optimizer == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)
    else:
        raise NotImplementedError('Invalid optimizer {}'.format(
            args.optimizer))

    @tf.function
    def train_step(images, pids):

        with tf.GradientTape() as tape:
            batch_embedding = emb_model(images)
            if args.loss == 'semi_hard_triplet':
                embedding_loss = triplet_semihard_loss(batch_embedding, pids,
                                                       args.margin)
            elif args.loss == 'hard_triplet':
                embedding_loss = batch_hard(batch_embedding, pids, args.margin,
                                            args.metric)
            elif args.loss == 'lifted_loss':
                embedding_loss = lifted_loss(pids,
                                             batch_embedding,
                                             margin=args.margin)
            elif args.loss == 'contrastive_loss':
                assert batch_size % 2 == 0
                assert args.batch_k == 4  # Can work with other numbers but will need tuning

                contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7],
                                          args.batch_p // 2)
                for i in range(args.batch_p // 2):
                    contrastive_idx[i * 8:i * 8 + 8] += i * 8

                contrastive_idx = np.expand_dims(contrastive_idx, 1)
                batch_embedding_ordered = tf.gather_nd(batch_embedding,
                                                       contrastive_idx)
                pids_ordered = tf.gather_nd(pids, contrastive_idx)
                # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding_ordered,
                               [-1, 2, args.embedding_dim]), 2, 1)
                # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

                fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2)
                # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
                # print(fixed_labels)
                labels = tf.constant(fixed_labels)
                # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
                embedding_loss = contrastive_loss(labels,
                                                  embeddings_anchor,
                                                  embeddings_positive,
                                                  margin=args.margin)
            elif args.loss == 'angular_loss':
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
                embedding_loss = angular_loss(pids,
                                              embeddings_anchor,
                                              embeddings_positive,
                                              batch_size=args.batch_p,
                                              with_l2reg=True)

            elif args.loss == 'npairs_loss':
                assert args.batch_k == 2  # Single positive pair per class
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                pids = tf.reshape(pids, [-1])
                embedding_loss = npairs_loss(pids, embeddings_anchor,
                                             embeddings_positive)

            else:
                raise NotImplementedError('Invalid Loss {}'.format(args.loss))
            loss_mean = tf.reduce_mean(embedding_loss)

        gradients = tape.gradient(loss_mean, emb_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients,
                                      emb_model.trainable_variables))

        return embedding_loss

    # sess = tf.compat.v1.Session()
    # start_step = sess.run(global_step)
    # checkpoint_saver = tf.train.Saver(max_to_keep=2)
    start_step = 0
    log.info('Starting training from iteration {}.'.format(start_step))
    dataset_iter = iter(dataset)

    ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                               optimizer=optimizer,
                               net=emb_model)
    manager = tf.train.CheckpointManager(ckpt,
                                         osp.join(args.experiment_root,
                                                  'tf_ckpts'),
                                         max_to_keep=3)

    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
        for i in range(ckpt.step.numpy(), args.train_iterations):
            # for batch_idx, batch in enumerate():
            start_time = time.time()
            images, fids, pids = next(dataset_iter)
            batch_loss = train_step(images, pids)
            elapsed_time = time.time() - start_time
            seconds_todo = (args.train_iterations - i) * elapsed_time
            # print(tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy())
            log.info(
                'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'
                .format(
                    i,
                    tf.reduce_min(batch_loss).numpy(),
                    tf.reduce_mean(batch_loss).numpy(),
                    tf.reduce_max(batch_loss).numpy(),
                    # args.batch_k - 1, float(b_prec_at_k),
                    timedelta(seconds=int(seconds_todo)),
                    elapsed_time))

            ckpt.step.assign_add(1)
            if (args.checkpoint_frequency > 0
                    and i % args.checkpoint_frequency == 0):

                # uncomment if you want to save the model weight separately
                # emb_model.save_weights(os.path.join(args.experiment_root, 'model_weights_{0:04d}.w'.format(i)))

                manager.save()

            # Stop the main-loop at the end of the step, if requested.
            if u.interrupted:
                log.info("Interrupted on request!")
                break
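
For reference, the `batch_hard` loss selected by `--loss hard_triplet` follows the "In Defense of the Triplet Loss" formulation: for each anchor, take the furthest positive and the closest negative within the batch. A sketch assuming a plain Euclidean metric and a hard margin (the actual implementation behind this snippet may differ, e.g. by using a soft-plus margin):

# Sketch of a batch-hard triplet loss (assumed Euclidean metric).
def batch_hard(embeddings, pids, margin, metric='euclidean'):
    # Pairwise Euclidean distances, shape (B, B); epsilon keeps sqrt stable.
    diffs = tf.expand_dims(embeddings, 1) - tf.expand_dims(embeddings, 0)
    dists = tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=-1) + 1e-12)
    same_identity = tf.equal(tf.expand_dims(pids, 1), tf.expand_dims(pids, 0))
    eye = tf.cast(tf.eye(tf.shape(pids)[0]), tf.bool)
    positive_mask = tf.logical_and(same_identity, tf.logical_not(eye))
    # Hardest positive: the furthest sample with the same PID.
    furthest_positive = tf.reduce_max(
        tf.where(positive_mask, dists, tf.zeros_like(dists)), axis=1)
    # Hardest negative: the closest sample with a different PID
    # (same-PID entries are pushed above the current maximum first).
    big = tf.reduce_max(dists) + 1.0
    closest_negative = tf.reduce_min(
        tf.where(same_identity, tf.ones_like(dists) * big, dists), axis=1)
    return tf.maximum(furthest_positive - closest_negative + margin, 0.0)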