# Example no. 1
# 0
def main():
    """Train an embedding network with the lifted metric loss.

    The run configuration comes from ``TrainConfig().parse()``.  A TF1-style
    graph is built (backbone, optionally L2-normalised embedding, lifted loss
    plus weighted regularisation), after which each iteration of the outer
    loop performs one pass over the shuffled training sessions, evaluates
    the validation set and writes a checkpoint.  Checkpoints, TensorBoard
    events, the configuration dump and the projector metadata all go to a
    time-stamped directory under ``cfg.result_root``.

    Raises:
        ValueError: if there are fewer training sessions than
            ``cfg.sess_per_batch`` (no batch could be formed).
        NotImplementedError: if ``cfg.network`` is neither "tsn" nor "rtsn".
    """

    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)

    # Seed both generators: the per-epoch random.shuffle below draws from the
    # stdlib RNG, which np.random.seed does not affect.
    np.random.seed(seed=cfg.seed)
    random.seed(cfg.seed)

    # prepare dataset
    train_set = prepare_dataset(cfg.feature_root, cfg.train_session, cfg.feat,
                                cfg.label_root)
    batch_per_epoch = len(train_set) // cfg.sess_per_batch
    if batch_per_epoch == 0:
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the epoch computation of the training loop.
        raise ValueError(
            "Need at least %d training sessions to form one batch, got %d" %
            (cfg.sess_per_batch, len(train_set)))

    val_set = prepare_dataset(cfg.feature_root, cfg.val_session, cfg.feat,
                              cfg.label_root)

    # Shared by the metadata writer and the projector configuration.
    metadata_path = os.path.join(result_dir, 'metadata_val.tsv')

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        # load backbone model
        if cfg.network == "tsn":
            model = networks.ConvTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "rtsn":
            model = networks.ConvRTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        else:
            # Without this branch an unknown name only failed later, with an
            # unbound-variable error on `model`.
            raise NotImplementedError("Unsupported network: %s" % cfg.network)

        # get the embedding
        input_ph = tf.placeholder(tf.float32,
                                  shape=[None, cfg.num_seg, None, None, None])
        label_ph = tf.placeholder(tf.float32, shape=[None])
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        model.forward(input_ph, dropout_ph)
        if cfg.normalized:
            embedding = tf.nn.l2_normalize(model.hidden,
                                           axis=-1,
                                           epsilon=1e-10)
        else:
            embedding = model.hidden

        # Variable for visualizing the embeddings.  It starts as a dummy
        # one-element vector; validate_shape=False lets the assign op replace
        # it with whatever shape the fed batch produces.
        emb_var = tf.Variable([0.0], name='embeddings')
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        # calculated for monitoring all-pair embedding distance
        diffs = utils.all_diffs_tf(embedding, embedding)
        all_dist = utils.cdist_tf(diffs)
        tf.summary.histogram('embedding_dists', all_dist)

        # Only the loss and the active count are consumed; the remaining
        # outputs were fetched every step but never used.
        metric_loss, num_active, _diff, _weights, _fp, _cn = \
            networks.lifted_loss(all_dist, label_ph, cfg.alpha)

        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = metric_loss + regularization_loss * cfg.lambda_l2

        tf.summary.scalar('active_ratio', num_active)
        tf.summary.scalar('learning_rate', lr_ph)
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.global_variables())

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = session_generator(feat_paths_ph,
                                       label_paths_ph,
                                       sess_per_batch=cfg.sess_per_batch,
                                       num_threads=2,
                                       shuffled=False,
                                       preprocess_func=model.prepare_input)
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # prepare validation data
        val_feats = []
        val_labels = []
        for session in val_set:
            eve_batch, lab_batch, _ = load_data_and_label(
                session[0], session[1], model.prepare_input_test
            )  # use prepare_input_test for testing time
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
        val_feats = np.concatenate(val_feats, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        print("Shape of val_feats: ", val_feats.shape)

        # generate metadata.tsv for visualize embedding
        with open(metadata_path, 'w') as fout:
            for v in val_labels:
                fout.write('%d\n' % int(v))

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        try:
            with sess.as_default():

                sess.run(tf.global_variables_initializer())

                # load pretrain model, if needed
                if cfg.model_path:
                    print("Restoring pretrained model: %s" % cfg.model_path)
                    saver.restore(sess, cfg.model_path)

                # The projector configuration does not change between
                # epochs, so it is written once instead of after every epoch.
                config = projector.ProjectorConfig()
                visual_embedding = config.embeddings.add()
                visual_embedding.tensor_name = emb_var.name
                visual_embedding.metadata_path = metadata_path
                projector.visualize_embeddings(summary_writer, config)

                ################## Training loop ##################
                epoch = -1
                while epoch < cfg.max_epochs - 1:
                    # The epoch index is derived from the (possibly restored)
                    # global step rather than from a local counter.
                    step = sess.run(global_step, feed_dict=None)
                    epoch = step // batch_per_epoch

                    # learning rate schedule, reference: "In defense of Triplet Loss"
                    if epoch < cfg.static_epochs:
                        learning_rate = cfg.learning_rate
                    else:
                        # float() keeps the exponent fractional even where
                        # "/" on two ints truncates (Python 2).
                        decay_progress = float(epoch - cfg.static_epochs) / (
                            cfg.max_epochs - cfg.static_epochs)
                        learning_rate = cfg.learning_rate * \
                            0.001**decay_progress

                    # prepare data for this epoch
                    random.shuffle(train_set)

                    feat_paths = [path[0] for path in train_set]
                    label_paths = [path[1] for path in train_set]
                    # Regroup the flat lists into rows of cfg.sess_per_batch
                    # paths; zip over one shared iterator drops any
                    # incomplete trailing row.
                    # https://stackoverflow.com/questions/10124751/convert-a-flat-list-to-list-of-list-in-python
                    feat_paths = list(
                        zip(*[iter(feat_paths)] * cfg.sess_per_batch))
                    label_paths = list(
                        zip(*[iter(label_paths)] * cfg.sess_per_batch))

                    sess.run(train_sess_iterator.initializer,
                             feed_dict={
                                 feat_paths_ph: feat_paths,
                                 label_paths_ph: label_paths
                             })

                    # for each epoch
                    batch_count = 1
                    while True:
                        try:
                            # First, sample sessions for a batch
                            start_time_select = time.time()
                            eve, _, lab = sess.run(next_train)
                            select_time1 = time.time() - start_time_select

                            # Second, select samples for a batch
                            batch_idx = select_batch(lab, cfg.batch_size)
                            eve = eve[batch_idx]
                            lab = lab[batch_idx]

                            # Third, perform training on a batch
                            start_time_train = time.time()
                            err, _, step, summ = sess.run(
                                [
                                    total_loss, train_op, global_step,
                                    summary_op
                                ],
                                feed_dict={
                                    input_ph: eve,
                                    dropout_ph: cfg.keep_prob,
                                    label_ph: np.squeeze(lab),
                                    lr_ph: learning_rate
                                })

                            train_time = time.time() - start_time_train
                            print ("Epoch: [%d][%d/%d]\tEvent num: %d\tSelect_time: %.3f\tTrain_time: %.3f\tLoss %.4f" % \
                                    (epoch+1, batch_count, batch_per_epoch, eve.shape[0], select_time1, train_time, err))

                            summary = tf.Summary(value=[
                                tf.Summary.Value(tag="train_loss",
                                                 simple_value=err),
                                tf.Summary.Value(tag="select_time1",
                                                 simple_value=select_time1)
                            ])
                            summary_writer.add_summary(summary, step)
                            summary_writer.add_summary(summ, step)

                            batch_count += 1

                        except tf.errors.OutOfRangeError:
                            # The iterator is exhausted: the epoch is over.
                            print("Epoch %d done!" % (epoch + 1))
                            break

                    # validation on val_set
                    print("Evaluating on validation set...")
                    val_embeddings, _ = sess.run([embedding, set_emb],
                                                 feed_dict={
                                                     input_ph: val_feats,
                                                     dropout_ph: 1.0
                                                 })
                    mAP, mPrec = utils.evaluate_simple(val_embeddings,
                                                       val_labels)

                    # NOTE(review): the second tag looks mangled by an e-mail
                    # obfuscation filter and the first is misspelled.  Both
                    # are kept as-is because renaming a tag splits the
                    # TensorBoard series; confirm the intended names first.
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="Valiation mAP",
                                         simple_value=mAP),
                        tf.Summary.Value(tag="Validation [email protected]",
                                         simple_value=mPrec)
                    ])
                    summary_writer.add_summary(summary, step)

                    # save model
                    saver.save(sess,
                               os.path.join(result_dir, cfg.name + '.ckpt'),
                               global_step=step)
        finally:
            # Flush buffered events and release the session even when
            # training is interrupted by an exception.
            summary_writer.close()
            sess.close()
# Example no. 2
# 0
def main():
    """Train an embedding network with a triplet or lifted-structure loss.

    The run configuration comes from ``TrainConfig().parse()``.  Only the
    first ``cfg.label_num`` training sessions are used.  The loss is chosen
    by ``cfg.loss`` ("triplet" or "lifted") from ``loss_tf``.  Each iteration
    of the outer loop performs one pass over the shuffled training sessions
    (at most ``cfg.event_per_batch`` events per batch), evaluates the
    validation set and writes a checkpoint.  Checkpoints, TensorBoard events,
    the configuration dump and the projector metadata go to a time-stamped
    directory under ``cfg.result_root``.

    Raises:
        ValueError: if fewer than ``cfg.sess_per_batch`` training sessions
            remain (no batch could be formed).
        NotImplementedError: if ``cfg.network``, ``cfg.feat`` or ``cfg.loss``
            holds an unsupported value.
    """

    # Load configurations and write to config.txt
    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)

    # Seed both generators: the per-epoch random.shuffle below draws from the
    # stdlib RNG, which np.random.seed does not affect.
    np.random.seed(seed=cfg.seed)
    random.seed(cfg.seed)

    # prepare dataset
    train_set = prepare_dataset(cfg.feature_root, cfg.train_session, cfg.feat,
                                cfg.label_root)
    train_set = train_set[:cfg.label_num]
    batch_per_epoch = len(train_set) // cfg.sess_per_batch
    if batch_per_epoch == 0:
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the epoch computation of the training loop.
        raise ValueError(
            "Need at least %d training sessions to form one batch, got %d" %
            (cfg.sess_per_batch, len(train_set)))

    val_set = prepare_dataset(cfg.feature_root, cfg.val_session, cfg.feat,
                              cfg.label_root)

    # Shared by the metadata writer and the projector configuration.
    metadata_path = os.path.join(result_dir, 'metadata_val.tsv')

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        label_ph = tf.placeholder(tf.int32, shape=[None], name="label")
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        ####################### Define model here ########################

        # Load embedding model
        if cfg.network == "tsn":
            model_emb = networks.TSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "rtsn":
            model_emb = networks.RTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "convtsn":
            model_emb = networks.ConvTSN(n_seg=cfg.num_seg,
                                         emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb = networks.ConvRTSN(n_seg=cfg.num_seg,
                                          emb_dim=cfg.emb_dim,
                                          n_h=cfg.n_h,
                                          n_w=cfg.n_w,
                                          n_C=cfg.n_C,
                                          n_input=cfg.n_input)
        elif cfg.network == "convbirtsn":
            model_emb = networks.ConvBiRTSN(n_seg=cfg.num_seg,
                                            emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError

        # get the embedding; the input rank depends on the feature type
        if cfg.feat == "sensors" or cfg.feat == "segment":
            input_ph = tf.placeholder(tf.float32,
                                      shape=[None, cfg.num_seg, None])
        elif cfg.feat == "resnet" or cfg.feat == "segment_down":
            input_ph = tf.placeholder(
                tf.float32, shape=[None, cfg.num_seg, None, None, None])
        else:
            # Without this branch an unknown feature type only failed on the
            # next statement, with an unbound-variable error on `input_ph`.
            raise NotImplementedError
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        model_emb.forward(input_ph, dropout_ph)
        if cfg.normalized:
            embedding = tf.nn.l2_normalize(model_emb.hidden,
                                           axis=-1,
                                           epsilon=1e-10)
        else:
            embedding = model_emb.hidden

        # Use tensorflow implementation for loss functions
        if cfg.loss == 'triplet':
            metric_loss, active_count = loss_tf.triplet_semihard_loss(
                labels=label_ph, embeddings=embedding, margin=cfg.alpha)
        elif cfg.loss == 'lifted':
            metric_loss, active_count = loss_tf.lifted_struct_loss(
                labels=label_ph, embeddings=embedding, margin=cfg.alpha)
        else:
            raise NotImplementedError

        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = metric_loss + regularization_loss * cfg.lambda_l2

        tf.summary.scalar('learning_rate', lr_ph)
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.global_variables())

        ####################### Define data loader ############################

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = session_generator(feat_paths_ph,
                                       label_paths_ph,
                                       sess_per_batch=cfg.sess_per_batch,
                                       num_threads=2,
                                       shuffled=False,
                                       preprocess_func=model_emb.prepare_input)
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # Prepare validation data
        val_sess = []
        val_feats = []
        val_labels = []
        val_boundaries = []
        for session in val_set:
            # The session id is the part of the label file name before the
            # first underscore.
            session_id = os.path.basename(session[1]).split('_')[0]
            eve_batch, lab_batch, boundary = load_data_and_label(
                session[0], session[-1], model_emb.prepare_input_test
            )  # use prepare_input_test for testing time
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
            val_sess.extend([session_id] * eve_batch.shape[0])
            val_boundaries.extend(boundary)
        val_feats = np.concatenate(val_feats, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        print("Shape of val_feats: ", val_feats.shape)

        # generate metadata.tsv for visualize embedding
        with open(metadata_path, 'w') as fout:
            fout.write('id\tlabel\tsession_id\tstart\tend\n')
            for i, sid in enumerate(val_sess):
                fout.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    i, val_labels[i, 0], sid, val_boundaries[i][0],
                    val_boundaries[i][1]))

        # Variable for visualizing the embeddings (one row per validation
        # event); created after train_op so it is not handed to the optimizer.
        emb_var = tf.Variable(tf.zeros([val_feats.shape[0], cfg.emb_dim]),
                              name='embeddings')
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        # calculated for monitoring all-pair embedding distance
        diffs = utils.all_diffs_tf(embedding, embedding)
        all_dist = utils.cdist_tf(diffs)
        tf.summary.histogram('embedding_dists', all_dist)

        summary_op = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=10)

        #########################################################################

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        try:
            with sess.as_default():

                sess.run(tf.global_variables_initializer())

                # load pretrain model, if needed
                if cfg.model_path:
                    print("Restoring pretrained model: %s" % cfg.model_path)
                    saver.restore(sess, cfg.model_path)

                # The projector configuration does not change between
                # epochs, so it is written once instead of after every epoch.
                config = projector.ProjectorConfig()
                visual_embedding = config.embeddings.add()
                visual_embedding.tensor_name = emb_var.name
                visual_embedding.metadata_path = metadata_path
                projector.visualize_embeddings(summary_writer, config)

                ################## Training loop ##################
                epoch = -1
                while epoch < cfg.max_epochs - 1:
                    # The epoch index is derived from the (possibly restored)
                    # global step rather than from a local counter.
                    step = sess.run(global_step, feed_dict=None)
                    epoch = step // batch_per_epoch

                    # learning rate schedule, reference: "In defense of Triplet Loss"
                    if epoch < cfg.static_epochs:
                        learning_rate = cfg.learning_rate
                    else:
                        # float() keeps the exponent fractional even where
                        # "/" on two ints truncates (Python 2).
                        decay_progress = float(epoch - cfg.static_epochs) / (
                            cfg.max_epochs - cfg.static_epochs)
                        learning_rate = cfg.learning_rate * \
                            0.01**decay_progress

                    # prepare data for this epoch
                    random.shuffle(train_set)

                    feat_paths = [path[0] for path in train_set]
                    label_paths = [path[1] for path in train_set]
                    # Regroup the flat lists into rows of cfg.sess_per_batch
                    # paths; zip over one shared iterator drops any
                    # incomplete trailing row.
                    # https://stackoverflow.com/questions/10124751/convert-a-flat-list-to-list-of-list-in-python
                    feat_paths = list(
                        zip(*[iter(feat_paths)] * cfg.sess_per_batch))
                    label_paths = list(
                        zip(*[iter(label_paths)] * cfg.sess_per_batch))

                    sess.run(train_sess_iterator.initializer,
                             feed_dict={
                                 feat_paths_ph: feat_paths,
                                 label_paths_ph: label_paths
                             })

                    # for each epoch
                    batch_count = 1
                    while True:
                        try:
                            # Get a batch
                            start_time_select = time.time()

                            eve, _, lab = sess.run(next_train)
                            # for memory concern, cfg.event_per_batch events are used in maximum
                            if eve.shape[0] > cfg.event_per_batch:
                                idx = np.random.permutation(
                                    eve.shape[0])[:cfg.event_per_batch]
                                eve = eve[idx]
                                lab = lab[idx]

                            select_time = time.time() - start_time_select

                            start_time_train = time.time()

                            # perform training on the batch
                            err, _, step, summ = sess.run(
                                [
                                    total_loss, train_op, global_step,
                                    summary_op
                                ],
                                feed_dict={
                                    input_ph: eve,
                                    label_ph: np.squeeze(lab),
                                    dropout_ph: cfg.keep_prob,
                                    lr_ph: learning_rate
                                })

                            train_time = time.time() - start_time_train

                            print ("%s\tEpoch: [%d][%d/%d]\tEvent num: %d\tSelect_time: %.3f\tTrain_time: %.3f\tLoss %.4f" % \
                                    (cfg.name, epoch+1, batch_count, batch_per_epoch, eve.shape[0], select_time, train_time, err))

                            summary = tf.Summary(value=[
                                tf.Summary.Value(tag="train_loss",
                                                 simple_value=err),
                            ])
                            summary_writer.add_summary(summary, step)
                            summary_writer.add_summary(summ, step)

                            batch_count += 1

                        except tf.errors.OutOfRangeError:
                            # The iterator is exhausted: the epoch is over.
                            print("Epoch %d done!" % (epoch + 1))
                            break

                    # validation on val_set
                    print("Evaluating on validation set...")
                    val_embeddings, _ = sess.run([embedding, set_emb],
                                                 feed_dict={
                                                     input_ph: val_feats,
                                                     dropout_ph: 1.0
                                                 })
                    mAP, mPrec, recall = utils.evaluate_simple(
                        val_embeddings, val_labels)
                    # NOTE(review): the last tag looks mangled by an e-mail
                    # obfuscation filter and the first is misspelled.  Both
                    # are kept as-is because renaming a tag splits the
                    # TensorBoard series; confirm the intended names first.
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="Valiation mAP",
                                         simple_value=mAP),
                        tf.Summary.Value(tag="Validation Recall@1",
                                         simple_value=recall),
                        tf.Summary.Value(tag="Validation [email protected]",
                                         simple_value=mPrec)
                    ])
                    summary_writer.add_summary(summary, step)
                    print("Epoch: [%d]\tmAP: %.4f\tmPrec: %.4f" %
                          (epoch + 1, mAP, mPrec))

                    # save model
                    saver.save(sess,
                               os.path.join(result_dir, cfg.name + '.ckpt'),
                               global_step=step)
        finally:
            # Flush buffered events and release the session even when
            # training is interrupted by an exception.
            summary_writer.close()
            sess.close()
def main():
    """Train an embedding network through a pair-verification objective.

    The run configuration comes from ``TrainConfig().parse()``.  Consecutive
    rows of the embedding batch are treated as (A, B) pairs and passed to a
    ``networks.PairSim`` head trained with sparse softmax cross-entropy.
    From epoch ``cfg.negative_epochs`` on, every batch is followed by a
    second update on the pairs returned by ``hard_pairs``.  Only the first
    three entries of ``cfg.val_session`` are used for validation.  After
    training, the predictions of the last validation pass are written to
    ``val_results.txt`` in the time-stamped result directory, next to the
    checkpoints and TensorBoard events.

    Raises:
        ValueError: if there are fewer training sessions than
            ``cfg.sess_per_batch`` (no batch could be formed).
        NotImplementedError: if ``cfg.network`` or ``cfg.feat`` holds an
            unsupported value.
    """

    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)

    # Seed both generators: the per-epoch random.shuffle below draws from the
    # stdlib RNG, which np.random.seed does not affect.
    np.random.seed(seed=cfg.seed)
    random.seed(cfg.seed)

    # prepare dataset
    train_set = prepare_dataset(cfg.feature_root, cfg.train_session, cfg.feat,
                                cfg.label_root)
    batch_per_epoch = len(train_set) // cfg.sess_per_batch
    if batch_per_epoch == 0:
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the epoch computation of the training loop.
        raise ValueError(
            "Need at least %d training sessions to form one batch, got %d" %
            (cfg.sess_per_batch, len(train_set)))

    # Validation is restricted to the first three configured sessions.
    val_session = cfg.val_session[:3]
    val_set = prepare_dataset(cfg.feature_root, val_session, cfg.feat,
                              cfg.label_root)

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)

        # subtract global_step by 1 if needed (for hard negative mining, keep global_step unchanged)
        subtract_global_step_op = tf.assign(global_step, global_step - 1)

        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        # load backbone model
        if cfg.network == "tsn":
            model_emb = networks.TSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "rtsn":
            model_emb = networks.RTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "convtsn":
            model_emb = networks.ConvTSN(n_seg=cfg.num_seg,
                                         emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb = networks.ConvRTSN(n_seg=cfg.num_seg,
                                          emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError

        model_ver = networks.PairSim(n_input=cfg.emb_dim)

        # get the embedding; the input rank depends on the feature type
        if cfg.feat == "sensors":
            input_ph = tf.placeholder(tf.float32,
                                      shape=[None, cfg.num_seg, None])
        elif cfg.feat == "resnet":
            input_ph = tf.placeholder(
                tf.float32, shape=[None, cfg.num_seg, None, None, None])
        else:
            # Without this branch an unknown feature type only failed later,
            # with an unbound-variable error on `input_ph`.
            raise NotImplementedError
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        label_ph = tf.placeholder(tf.int32, shape=[None])
        model_emb.forward(input_ph, dropout_ph)
        embedding = model_emb.hidden

        # Group consecutive embedding rows into (A, B) pairs.  The original
        # unstacked this tensor along axis 1 and stacked the two halves back
        # on the same axis, which reproduces the reshaped tensor.
        pairs = tf.reshape(embedding, [-1, 2, cfg.emb_dim])

        model_ver.forward(pairs, dropout_ph)
        logits = model_ver.logits
        prob = model_ver.prob
        pred = tf.argmax(logits, -1)

        ver_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_ph,
                                                           logits=logits))

        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = ver_loss + regularization_loss * cfg.lambda_l2

        tf.summary.scalar('learning_rate', lr_ph)
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.global_variables())

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = session_generator(feat_paths_ph,
                                       label_paths_ph,
                                       sess_per_batch=cfg.sess_per_batch,
                                       num_threads=2,
                                       shuffled=False,
                                       preprocess_func=model_emb.prepare_input)
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # prepare validation data
        val_sess = []
        val_feats = []
        val_labels = []
        val_boundaries = []
        for session in val_set:
            # The session id is the part of the label file name before the
            # first underscore.
            session_id = os.path.basename(session[1]).split('_')[0]
            eve_batch, lab_batch, boundary = load_data_and_label(
                session[0], session[1], model_emb.prepare_input_test
            )  # use prepare_input_test for testing time
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
            val_sess.extend([session_id] * eve_batch.shape[0])
            val_boundaries.extend(boundary)
        val_feats = np.concatenate(val_feats, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)

        # generate metadata.tsv for visualize embedding
        with open(os.path.join(result_dir, 'metadata_val.tsv'), 'w') as fout:
            fout.write('id\tlabel\tsession_id\tstart\tend\n')
            for i, sid in enumerate(val_sess):
                fout.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    i, val_labels[i, 0], sid, val_boundaries[i][0],
                    val_boundaries[i][1]))

        # Replace the per-event validation data by pair data: val_idx holds
        # the event indices with the two members of pair i at positions 2*i
        # and 2*i + 1, and val_labels becomes the per-pair label.
        val_idx, val_labels = random_pairs(val_labels, 1000000, test=True)
        val_feats = val_feats[val_idx]
        val_labels = np.asarray(val_labels, dtype='int32')
        print("Shape of val_feats: ", val_feats.shape)

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        # Results of the most recent validation pass; they stay None when
        # the training loop does not run a single epoch.
        val_acc = None
        val_prob = None

        try:
            with sess.as_default():

                sess.run(tf.global_variables_initializer())

                # load pretrain model, if needed
                if cfg.model_path:
                    print("Restoring pretrained model: %s" % cfg.model_path)
                    saver.restore(sess, cfg.model_path)

                ################## Training loop ##################
                epoch = -1
                while epoch < cfg.max_epochs - 1:
                    # The epoch index is derived from the (possibly restored)
                    # global step rather than from a local counter.
                    step = sess.run(global_step, feed_dict=None)
                    epoch = step // batch_per_epoch

                    # learning rate schedule, reference: "In defense of Triplet Loss"
                    if epoch < cfg.static_epochs:
                        learning_rate = cfg.learning_rate
                    else:
                        # float() keeps the exponent fractional even where
                        # "/" on two ints truncates (Python 2).
                        decay_progress = float(epoch - cfg.static_epochs) / (
                            cfg.max_epochs - cfg.static_epochs)
                        learning_rate = cfg.learning_rate * \
                            0.001**decay_progress

                    # prepare data for this epoch
                    random.shuffle(train_set)

                    feat_paths = [path[0] for path in train_set]
                    label_paths = [path[1] for path in train_set]
                    # Regroup the flat lists into rows of cfg.sess_per_batch
                    # paths; zip over one shared iterator drops any
                    # incomplete trailing row.
                    # https://stackoverflow.com/questions/10124751/convert-a-flat-list-to-list-of-list-in-python
                    feat_paths = list(
                        zip(*[iter(feat_paths)] * cfg.sess_per_batch))
                    label_paths = list(
                        zip(*[iter(label_paths)] * cfg.sess_per_batch))

                    sess.run(train_sess_iterator.initializer,
                             feed_dict={
                                 feat_paths_ph: feat_paths,
                                 label_paths_ph: label_paths
                             })

                    # for each epoch
                    batch_count = 1
                    while True:
                        try:
                            # Hierarchical sampling (same as fast rcnn)
                            start_time_select = time.time()

                            # First, sample sessions for a batch
                            eve, _, lab = sess.run(next_train)

                            select_time1 = time.time() - start_time_select

                            # select pairs for training
                            pair_idx, train_labels = random_pairs(
                                lab, cfg.batch_size, cfg.num_negative)

                            train_input = eve[pair_idx]
                            train_labels = np.asarray(train_labels,
                                                      dtype='int32')
                            select_time2 = time.time(
                            ) - start_time_select - select_time1

                            start_time_train = time.time()
                            # perform training on the selected pairs
                            err, y_pred, y_prob, _, step, summ = sess.run(
                                [
                                    total_loss, pred, prob, train_op,
                                    global_step, summary_op
                                ],
                                feed_dict={
                                    input_ph: train_input,
                                    label_ph: train_labels,
                                    dropout_ph: cfg.keep_prob,
                                    lr_ph: learning_rate
                                })
                            acc = accuracy_score(train_labels, y_pred)

                            negative_count = 0
                            if epoch >= cfg.negative_epochs:
                                hard_idx, hard_labels, negative_count = hard_pairs(
                                    train_labels, y_prob, 0.5)
                                if negative_count > 0:
                                    hard_input = train_input[hard_idx]
                                    hard_labels = np.asarray(hard_labels,
                                                             dtype='int32')

                                    # Undo the increment of the regular
                                    # update first, so the extra update below
                                    # leaves one global step per batch.
                                    sess.run(subtract_global_step_op)
                                    # The loss and predictions of this extra
                                    # update were fetched but never used.
                                    _, step = sess.run(
                                        [train_op, global_step],
                                        feed_dict={
                                            input_ph: hard_input,
                                            label_ph: hard_labels,
                                            dropout_ph: cfg.keep_prob,
                                            lr_ph: learning_rate
                                        })

                            train_time = time.time() - start_time_train

                            print ("%s\tEpoch: [%d][%d/%d]\tEvent num: %d\tSelect_time1: %.3f\tSelect_time2: %.3f\tTrain_time: %.3f\tLoss: %.4f" % \
                                    (cfg.name, epoch+1, batch_count, batch_per_epoch, eve.shape[0], select_time1, select_time2, train_time, err))

                            summary = tf.Summary(value=[
                                tf.Summary.Value(tag="train_loss",
                                                 simple_value=err),
                                tf.Summary.Value(tag="acc", simple_value=acc),
                                tf.Summary.Value(tag="negative_count",
                                                 simple_value=negative_count)
                            ])
                            summary_writer.add_summary(summary, step)
                            summary_writer.add_summary(summ, step)

                            batch_count += 1

                        except tf.errors.OutOfRangeError:
                            # The iterator is exhausted: the epoch is over.
                            print("Epoch %d done!" % (epoch + 1))
                            break

                    # validation on val_set
                    print("Evaluating on validation set...")
                    val_err, val_pred, val_prob = sess.run(
                        [total_loss, pred, prob],
                        feed_dict={
                            input_ph: val_feats,
                            label_ph: val_labels,
                            dropout_ph: 1.0
                        })
                    val_acc = accuracy_score(val_labels, val_pred)

                    # NOTE(review): "Valiation" is misspelled; kept as-is
                    # because renaming a tag splits the TensorBoard series.
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="Valiation acc",
                                         simple_value=val_acc),
                        tf.Summary.Value(tag="Validation loss",
                                         simple_value=val_err)
                    ])
                    summary_writer.add_summary(summary, step)

                    # save model
                    saver.save(sess,
                               os.path.join(result_dir, cfg.name + '.ckpt'),
                               global_step=step)
        finally:
            # Flush buffered events and release the session even when
            # training is interrupted by an exception.
            summary_writer.close()
            sess.close()

        # print log for analysis; skipped when no validation pass was run,
        # in which case there is nothing to report.
        if val_prob is not None:
            with open(os.path.join(result_dir, 'val_results.txt'),
                      'w') as fout:
                fout.write("acc = %.4f\n" % val_acc)
                fout.write("label\tprob_0\tprob_1\tA_idx\tB_idx\n")
                for i in range(val_prob.shape[0]):
                    fout.write("%d\t%.4f\t%.4f\t%d\t%d\n" %
                               (val_labels[i], val_prob[i, 0], val_prob[i, 1],
                                val_idx[2 * i], val_idx[2 * i + 1]))
# Example no. 4
# 0
def main():
    """Train an embedding network with a triplet loss plus a verification loss.

    Reads the run configuration from ``TrainConfig().parse()``, builds the
    TensorFlow graph (embedding backbone selected by ``cfg.network`` and a
    ``PairSim2`` verification head), then repeatedly makes one pass over the
    shuffled training sessions followed by an evaluation on the validation
    set.  After each pass the validation embeddings are stored for the
    TensorBoard projector and a checkpoint is written into a time-stamped
    directory under ``cfg.result_root``.

    Raises:
        NotImplementedError: if ``cfg.network`` or ``cfg.feat`` is not one of
            the values handled below.
    """

    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)
    np.random.seed(seed=cfg.seed)

    # prepare dataset
    train_session = cfg.train_session
    train_set = prepare_dataset(cfg.feature_root, train_session, cfg.feat,
                                cfg.label_root)
    batch_per_epoch = len(train_set) // cfg.sess_per_batch

    val_session = cfg.val_session
    val_set = prepare_dataset(cfg.feature_root, val_session, cfg.feat,
                              cfg.label_root)

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        # load backbone model
        if cfg.network == "tsn":
            model_emb = networks.TSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "rtsn":
            model_emb = networks.RTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "convtsn":
            model_emb = networks.ConvTSN(n_seg=cfg.num_seg,
                                         emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb = networks.ConvRTSN(n_seg=cfg.num_seg,
                                          emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError

        # multitask loss (verification)
        model_ver = networks.PairSim2(n_input=cfg.emb_dim)

        # get the embedding
        if cfg.feat == "sensors":
            input_ph = tf.placeholder(tf.float32,
                                      shape=[None, cfg.num_seg, None])
        elif cfg.feat == "resnet":
            input_ph = tf.placeholder(
                tf.float32, shape=[None, cfg.num_seg, None, None, None])
        else:
            # fail here with a clear message instead of a NameError on
            # `input_ph` further down
            raise NotImplementedError("unsupported feature type: %s" %
                                      cfg.feat)
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        model_emb.forward(input_ph, dropout_ph)
        if cfg.normalized:
            embedding = tf.nn.l2_normalize(model_emb.hidden,
                                           axis=-1,
                                           epsilon=1e-10)
        else:
            embedding = model_emb.hidden

        # variable for visualizing the embeddings
        emb_var = tf.Variable([0.0], name='embeddings')
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        # calculated for monitoring all-pair embedding distance
        diffs = utils.all_diffs_tf(embedding, embedding)
        all_dist = utils.cdist_tf(diffs)
        tf.summary.histogram('embedding_dists', all_dist)

        # split embedding into anchor, positive and negative and calculate
        # triplet loss (the input batch is laid out as consecutive triplets)
        anchor, positive, negative = tf.unstack(
            tf.reshape(embedding, [-1, 3, cfg.emb_dim]), 3, 1)
        metric_loss = networks.triplet_loss(anchor, positive, negative,
                                            cfg.alpha)

        # verification loss: (anchor, positive) pairs are labelled 1,
        # (anchor, negative) pairs are labelled 0
        pos_pairs = tf.concat(
            [tf.expand_dims(anchor, axis=1),
             tf.expand_dims(positive, axis=1)],
            axis=1)
        pos_label = tf.ones((tf.shape(pos_pairs)[0], ), tf.int32)
        neg_pairs = tf.concat(
            [tf.expand_dims(anchor, axis=1),
             tf.expand_dims(negative, axis=1)],
            axis=1)
        neg_label = tf.zeros((tf.shape(neg_pairs)[0], ), tf.int32)

        ver_pairs = tf.concat([pos_pairs, neg_pairs], axis=0)
        ver_label = tf.concat([pos_label, neg_label], axis=0)

        model_ver.forward(ver_pairs, dropout_ph)
        logits = model_ver.logits
        pred = tf.argmax(logits, -1)

        ver_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=ver_label,
                                                           logits=logits))

        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = (metric_loss + cfg.lambda_ver * ver_loss +
                      regularization_loss * cfg.lambda_l2)

        tf.summary.scalar('learning_rate', lr_ph)
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.global_variables())

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = session_generator(feat_paths_ph,
                                       label_paths_ph,
                                       sess_per_batch=cfg.sess_per_batch,
                                       num_threads=2,
                                       shuffled=False,
                                       preprocess_func=model_emb.prepare_input)
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # prepare validation data
        val_sess = []
        val_feats = []
        val_labels = []
        val_boundaries = []
        for session in val_set:
            session_id = os.path.basename(session[1]).split('_')[0]
            eve_batch, lab_batch, boundary = load_data_and_label(
                session[0], session[1], model_emb.prepare_input_test
            )  # use prepare_input_test for testing time
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
            val_sess.extend([session_id] * eve_batch.shape[0])
            val_boundaries.extend(boundary)
        val_feats = np.concatenate(val_feats, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        print("Shape of val_feats: ", val_feats.shape)

        # generate metadata.tsv for visualize embedding
        with open(os.path.join(result_dir, 'metadata_val.tsv'), 'w') as fout:
            fout.write('id\tlabel\tsession_id\tstart\tend\n')
            for i, session_id in enumerate(val_sess):
                fout.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    i, val_labels[i, 0], session_id, val_boundaries[i][0],
                    val_boundaries[i][1]))

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        # `sess.as_default()` does not close the session when the context
        # exits, so release the session and close the event file explicitly,
        # also when training is interrupted by an exception.
        try:
            with sess.as_default():

                sess.run(tf.global_variables_initializer())

                # load pretrain model, if needed
                if cfg.model_path:
                    print("Restoring pretrained model: %s" % cfg.model_path)
                    saver.restore(sess, cfg.model_path)

                ################## Training loop ##################
                epoch = -1
                while epoch < cfg.max_epochs - 1:
                    # the epoch number is derived from the global step, so it
                    # only advances with training steps that actually ran
                    step = sess.run(global_step, feed_dict=None)
                    epoch = step // batch_per_epoch

                    # learning rate schedule, reference: "In defense of Triplet Loss"
                    if epoch < cfg.static_epochs:
                        learning_rate = cfg.learning_rate
                    else:
                        # float() keeps the exponent fractional even where
                        # `/` would be integer division (Python 2)
                        decay_progress = float(epoch - cfg.static_epochs) / (
                            cfg.max_epochs - cfg.static_epochs)
                        learning_rate = cfg.learning_rate * \
                            0.001**decay_progress

                    # prepare data for this epoch
                    random.shuffle(train_set)

                    feat_paths = [path[0] for path in train_set]
                    label_paths = [path[1] for path in train_set]
                    # group the flat path lists into tuples of
                    # cfg.sess_per_batch entries; zip() drops an incomplete
                    # trailing group
                    feat_paths = list(
                        zip(*[iter(feat_paths)] * cfg.sess_per_batch))
                    label_paths = list(
                        zip(*[iter(label_paths)] * cfg.sess_per_batch))

                    sess.run(train_sess_iterator.initializer,
                             feed_dict={
                                 feat_paths_ph: feat_paths,
                                 label_paths_ph: label_paths
                             })

                    # for each epoch
                    batch_count = 1
                    while True:
                        try:
                            # Hierarchical sampling (same as fast rcnn)
                            start_time_select = time.time()

                            # First, sample sessions for a batch
                            eve, _, lab = sess.run(next_train)

                            select_time1 = time.time() - start_time_select

                            # Get the embeddings of all events, evaluated in
                            # chunks of cfg.batch_size with dropout disabled
                            eve_embedding = np.zeros(
                                (eve.shape[0], cfg.emb_dim), dtype='float32')
                            for start in range(0, eve.shape[0],
                                               cfg.batch_size):
                                end = min(start + cfg.batch_size,
                                          eve.shape[0])
                                emb = sess.run(embedding,
                                               feed_dict={
                                                   input_ph: eve[start:end],
                                                   dropout_ph: 1.0
                                               })
                                eve_embedding[start:end] = emb

                            # Second, sample triplets within sampled sessions
                            triplet_input, negative_count = select_triplets_facenet(
                                eve,
                                lab,
                                eve_embedding,
                                cfg.triplet_per_batch,
                                cfg.alpha,
                                metric=cfg.metric)

                            select_time2 = time.time(
                            ) - start_time_select - select_time1

                            # batches without any selected triplet are
                            # skipped (no training step, no summary)
                            if triplet_input is not None:
                                start_time_train = time.time()
                                # perform training on the selected triplets
                                err, metric_err, ver_err, y_pred, _, step, summ = sess.run(
                                    [
                                        total_loss, metric_loss, ver_loss,
                                        pred, train_op, global_step,
                                        summary_op
                                    ],
                                    feed_dict={
                                        input_ph: triplet_input,
                                        dropout_ph: cfg.keep_prob,
                                        lr_ph: learning_rate
                                    })

                                train_time = time.time() - start_time_train

                                # calculate accuracy; labels follow the graph
                                # layout: positive pairs first, then negatives
                                num_triplets = triplet_input.shape[0] // 3
                                batch_label = np.hstack(
                                    (np.ones((num_triplets, ),
                                             dtype='int32'),
                                     np.zeros((num_triplets, ),
                                              dtype='int32')))
                                acc = accuracy_score(batch_label, y_pred)
                                print(
                                    "%s\tEpoch: [%d][%d/%d]\tEvent num: %d\tTriplet num: %d\tSelect_time1: %.3f\tSelect_time2: %.3f\tTrain_time: %.3f\tLoss %.4f"
                                    % (cfg.name, epoch + 1, batch_count,
                                       batch_per_epoch, eve.shape[0],
                                       triplet_input.shape[0], select_time1,
                                       select_time2, train_time, err))

                                summary = tf.Summary(value=[
                                    tf.Summary.Value(tag="train_loss",
                                                     simple_value=err),
                                    tf.Summary.Value(tag="metric_loss",
                                                     simple_value=metric_err),
                                    tf.Summary.Value(tag="ver_loss",
                                                     simple_value=ver_err),
                                    tf.Summary.Value(tag="acc",
                                                     simple_value=acc),
                                    tf.Summary.Value(
                                        tag="negative_count",
                                        simple_value=negative_count)
                                ])
                                summary_writer.add_summary(summary, step)
                                summary_writer.add_summary(summ, step)

                            batch_count += 1

                        except tf.errors.OutOfRangeError:
                            # the session iterator is exhausted
                            print("Epoch %d done!" % (epoch + 1))
                            break

                    # validation on val_set
                    print("Evaluating on validation set...")
                    val_embeddings, _ = sess.run([embedding, set_emb],
                                                 feed_dict={
                                                     input_ph: val_feats,
                                                     dropout_ph: 1.0
                                                 })
                    mAP, mPrec = utils.evaluate_simple(val_embeddings,
                                                       val_labels)

                    # NOTE(review): the second tag looks mangled by e-mail
                    # obfuscation of the page this code was copied from;
                    # confirm the intended name before renaming it.
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="Validation mAP",
                                         simple_value=mAP),
                        tf.Summary.Value(tag="Validation [email protected]",
                                         simple_value=mPrec)
                    ])
                    summary_writer.add_summary(summary, step)

                    # config for embedding visualization
                    config = projector.ProjectorConfig()
                    visual_embedding = config.embeddings.add()
                    visual_embedding.tensor_name = emb_var.name
                    visual_embedding.metadata_path = os.path.join(
                        result_dir, 'metadata_val.tsv')
                    projector.visualize_embeddings(summary_writer, config)

                    # save model
                    saver.save(sess,
                               os.path.join(result_dir, cfg.name + '.ckpt'),
                               global_step=step)
        finally:
            summary_writer.close()
            sess.close()
Esempio n. 5
0
def _pairwise_pddm_scores(sess, score_op, input_ph, dropout_ph, feats,
                          batch_size):
    """Evaluate ``score_op`` for every unordered pair of rows of ``feats``.

    Each pair ``(a, b)`` is fed to the graph as the index triple
    ``(a, b, b)``, and ``score_op`` is expected to yield one value per fed
    triple.  Pairs are processed in chunks of ``batch_size`` with the dropout
    placeholder set to 1.0.

    Returns:
        A float32 array of shape ``(len(feats), len(feats))`` holding the
        score of each pair at both ``[a, b]`` and ``[b, a]``; entries that
        belong to no pair (the diagonal) stay NaN.
    """
    num_items = feats.shape[0]
    scores = np.zeros((num_items, num_items), dtype='float32') * np.nan
    pairs = list(itertools.combinations(range(num_items), 2))
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        triple_idx = []
        for first, second in chunk:
            triple_idx.extend([first, second, second])
        values = sess.run(score_op,
                          feed_dict={
                              input_ph: feats[triple_idx],
                              dropout_ph: 1.0
                          })
        for (first, second), value in zip(chunk, values):
            scores[first, second] = value
            scores[second, first] = value
    return scores


def main():
    """Train an embedding network jointly with a PDDM similarity head.

    Reads the run configuration from ``TrainConfig().parse()``, keeps the
    first ``cfg.label_num`` training sessions, builds the TensorFlow graph
    (embedding backbone selected by ``cfg.network`` plus a ``PDDM`` head) and
    then repeatedly makes one pass over the shuffled training sessions
    followed by an evaluation on the validation set.  Triplets are selected
    from the pairwise PDDM scores of each sampled batch.  After each pass the
    validation metrics are logged, the validation embeddings are stored for
    the TensorBoard projector and a checkpoint is written into a time-stamped
    directory under ``cfg.result_root``.

    Raises:
        NotImplementedError: if ``cfg.network`` or ``cfg.feat`` is not one of
            the values handled below.
    """

    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)
    np.random.seed(seed=cfg.seed)

    # prepare dataset
    train_session = cfg.train_session
    train_set = prepare_dataset(cfg.feature_root, train_session, cfg.feat,
                                cfg.label_root)
    # only the first cfg.label_num sessions are used for training
    train_set = train_set[:cfg.label_num]
    batch_per_epoch = len(train_set) // cfg.sess_per_batch

    val_session = cfg.val_session
    val_set = prepare_dataset(cfg.feature_root, val_session, cfg.feat,
                              cfg.label_root)

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        # load backbone model
        if cfg.network == "tsn":
            model_emb = networks.TSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "rtsn":
            model_emb = networks.RTSN(n_seg=cfg.num_seg,
                                      emb_dim=cfg.emb_dim,
                                      n_input=cfg.n_input)
        elif cfg.network == "convtsn":
            model_emb = networks.ConvTSN(n_seg=cfg.num_seg,
                                         emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb = networks.ConvRTSN(n_seg=cfg.num_seg,
                                          emb_dim=cfg.emb_dim,
                                          n_h=cfg.n_h,
                                          n_w=cfg.n_w,
                                          n_C=cfg.n_C,
                                          n_input=cfg.n_input)
        elif cfg.network == "convbirtsn":
            model_emb = networks.ConvBiRTSN(n_seg=cfg.num_seg,
                                            emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError
        model_ver = networks.PDDM(n_input=cfg.emb_dim)

        # get the embedding
        if cfg.feat == "sensors" or cfg.feat == "segment":
            input_ph = tf.placeholder(tf.float32,
                                      shape=[None, cfg.num_seg, None])
        elif cfg.feat == "resnet" or cfg.feat == "segment_down":
            input_ph = tf.placeholder(
                tf.float32, shape=[None, cfg.num_seg, None, None, None])
        else:
            # fail here with a clear message instead of a NameError on
            # `input_ph` further down
            raise NotImplementedError("unsupported feature type: %s" %
                                      cfg.feat)
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        model_emb.forward(input_ph, dropout_ph)
        if cfg.normalized:
            embedding = tf.nn.l2_normalize(model_emb.hidden,
                                           axis=-1,
                                           epsilon=1e-10)
        else:
            embedding = model_emb.hidden

        # variable for visualizing the embeddings
        emb_var = tf.Variable([0.0], name='embeddings')
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        # calculated for monitoring all-pair embedding distance
        diffs = utils.all_diffs_tf(embedding, embedding)
        all_dist = utils.cdist_tf(diffs)
        tf.summary.histogram('embedding_dists', all_dist)

        # split embedding into anchor, positive and negative and calculate
        # triplet loss (the input batch is laid out as consecutive triplets)
        anchor, positive, negative = tf.unstack(
            tf.reshape(embedding, [-1, 3, cfg.emb_dim]), 3, 1)
        metric_loss = networks.triplet_loss(anchor, positive, negative,
                                            cfg.alpha)

        # PDDM outputs for the (anchor, positive) and (anchor, negative) pairs
        model_ver.forward(tf.stack((anchor, positive), axis=1))
        pddm_ap = model_ver.prob[:, 0]
        model_ver.forward(tf.stack((anchor, negative), axis=1))
        pddm_an = model_ver.prob[:, 0]
        # hinge on the difference of the two outputs with a fixed 0.6 margin
        pddm_loss = tf.reduce_mean(
            tf.maximum(tf.add(tf.subtract(pddm_ap, pddm_an), 0.6), 0.0), 0)

        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = (pddm_loss + 0.5 * metric_loss +
                      regularization_loss * cfg.lambda_l2)

        tf.summary.scalar('learning_rate', lr_ph)
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.global_variables())

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = session_generator(feat_paths_ph,
                                       label_paths_ph,
                                       sess_per_batch=cfg.sess_per_batch,
                                       num_threads=2,
                                       shuffled=False,
                                       preprocess_func=model_emb.prepare_input)
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # prepare validation data
        val_feats = []
        val_labels = []
        for session in val_set:
            eve_batch, lab_batch, _ = load_data_and_label(
                session[0], session[1], model_emb.prepare_input_test
            )  # use prepare_input_test for testing time
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
        val_feats = np.concatenate(val_feats, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        print("Shape of val_feats: ", val_feats.shape)

        # generate metadata.tsv for visualize embedding
        with open(os.path.join(result_dir, 'metadata_val.tsv'), 'w') as fout:
            for v in val_labels:
                fout.write('%d\n' % int(v))

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        # `sess.as_default()` does not close the session when the context
        # exits, so release the session and close the event file explicitly,
        # also when training is interrupted by an exception.
        try:
            with sess.as_default():

                sess.run(tf.global_variables_initializer())

                # load pretrain model, if needed
                if cfg.model_path:
                    print("Restoring pretrained model: %s" % cfg.model_path)
                    saver.restore(sess, cfg.model_path)

                ################## Training loop ##################
                epoch = -1
                while epoch < cfg.max_epochs - 1:
                    # the epoch number is derived from the global step
                    step = sess.run(global_step, feed_dict=None)
                    epoch = step // batch_per_epoch

                    # learning rate schedule, reference: "In defense of Triplet Loss"
                    if epoch < cfg.static_epochs:
                        learning_rate = cfg.learning_rate
                    else:
                        # float() keeps the exponent fractional even where
                        # `/` would be integer division (Python 2)
                        decay_progress = float(epoch - cfg.static_epochs) / (
                            cfg.max_epochs - cfg.static_epochs)
                        learning_rate = cfg.learning_rate * \
                            0.001**decay_progress

                    # prepare data for this epoch
                    random.shuffle(train_set)

                    feat_paths = [path[0] for path in train_set]
                    label_paths = [path[1] for path in train_set]
                    # group the flat path lists into tuples of
                    # cfg.sess_per_batch entries; zip() drops an incomplete
                    # trailing group
                    feat_paths = list(
                        zip(*[iter(feat_paths)] * cfg.sess_per_batch))
                    label_paths = list(
                        zip(*[iter(label_paths)] * cfg.sess_per_batch))

                    sess.run(train_sess_iterator.initializer,
                             feed_dict={
                                 feat_paths_ph: feat_paths,
                                 label_paths_ph: label_paths
                             })

                    # for each epoch
                    batch_count = 1
                    while True:
                        try:
                            # Hierarchical sampling (same as fast rcnn)
                            start_time_select = time.time()

                            # First, sample sessions for a batch
                            eve, _, lab = sess.run(next_train)

                            select_time1 = time.time() - start_time_select

                            # Get the similarity of all events
                            sim_prob = _pairwise_pddm_scores(
                                sess, pddm_ap, input_ph, dropout_ph, eve,
                                cfg.batch_size)

                            # Second, sample triplets within sampled sessions
                            triplet_selected, active_count = utils.select_triplets_facenet(
                                lab, sim_prob, cfg.triplet_per_batch,
                                cfg.alpha)

                            select_time2 = time.time(
                            ) - start_time_select - select_time1

                            start_time_train = time.time()
                            # flatten the selected index triplets into one
                            # index list to gather the network input
                            triplet_input_idx = [
                                idx for triplet in triplet_selected
                                for idx in triplet
                            ]
                            triplet_input = eve[triplet_input_idx]
                            # perform training on the selected triplets
                            err, _, step, summ = sess.run(
                                [
                                    total_loss, train_op, global_step,
                                    summary_op
                                ],
                                feed_dict={
                                    input_ph: triplet_input,
                                    dropout_ph: cfg.keep_prob,
                                    lr_ph: learning_rate
                                })

                            train_time = time.time() - start_time_train
                            print(
                                "%s\tEpoch: [%d][%d/%d]\tEvent num: %d\tTriplet num: %d\tSelect_time1: %.3f\tSelect_time2: %.3f\tTrain_time: %.3f\tLoss %.4f"
                                % (cfg.name, epoch + 1, batch_count,
                                   batch_per_epoch, eve.shape[0],
                                   triplet_input.shape[0] // 3, select_time1,
                                   select_time2, train_time, err))

                            summary = tf.Summary(value=[
                                tf.Summary.Value(tag="train_loss",
                                                 simple_value=err),
                                tf.Summary.Value(tag="active_count",
                                                 simple_value=active_count),
                                tf.Summary.Value(
                                    tag="triplet_num",
                                    simple_value=triplet_input.shape[0] // 3)
                            ])
                            summary_writer.add_summary(summary, step)
                            summary_writer.add_summary(summ, step)

                            batch_count += 1

                        except tf.errors.OutOfRangeError:
                            # the session iterator is exhausted
                            print("Epoch %d done!" % (epoch + 1))
                            break

                    # validation on val_set
                    print("Evaluating on validation set...")
                    val_embeddings, _ = sess.run([embedding, set_emb],
                                                 feed_dict={
                                                     input_ph: val_feats,
                                                     dropout_ph: 1.0
                                                 })
                    mAP, mPrec = utils.evaluate_simple(val_embeddings,
                                                       val_labels)

                    val_sim_prob = _pairwise_pddm_scores(
                        sess, pddm_ap, input_ph, dropout_ph, val_feats,
                        cfg.batch_size)

                    # average precision of the PDDM ranking, averaged over
                    # all validation samples with a label > 0 as queries
                    mAP_PDDM = 0.0
                    count = 0
                    for i in range(val_labels.shape[0]):
                        if val_labels[i] > 0:
                            # drop the query itself from labels and scores
                            temp_labels = np.delete(val_labels, i, 0)
                            temp = np.delete(val_sim_prob, i, 1)
                            mAP_PDDM += average_precision_score(
                                np.squeeze(temp_labels == val_labels[i, 0]),
                                np.squeeze(1 - temp[i]))
                            count += 1
                    # without any query sample keep 0.0 instead of dividing
                    # by zero
                    if count > 0:
                        mAP_PDDM /= count

                    # NOTE(review): the last tag looks mangled by e-mail
                    # obfuscation of the page this code was copied from;
                    # confirm the intended name before renaming it.
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="Validation mAP",
                                         simple_value=mAP),
                        tf.Summary.Value(tag="Validation mAP_PDDM",
                                         simple_value=mAP_PDDM),
                        tf.Summary.Value(tag="Validation [email protected]",
                                         simple_value=mPrec)
                    ])
                    summary_writer.add_summary(summary, step)
                    print(
                        "Epoch: [%d]\tmAP: %.4f\tmPrec: %.4f\tmAP_PDDM: %.4f"
                        % (epoch + 1, mAP, mPrec, mAP_PDDM))

                    # config for embedding visualization
                    config = projector.ProjectorConfig()
                    visual_embedding = config.embeddings.add()
                    visual_embedding.tensor_name = emb_var.name
                    visual_embedding.metadata_path = os.path.join(
                        result_dir, 'metadata_val.tsv')
                    projector.visualize_embeddings(summary_writer, config)

                    # save model
                    saver.save(sess,
                               os.path.join(result_dir, cfg.name + '.ckpt'),
                               global_step=step)
        finally:
            summary_writer.close()
            sess.close()