def main():
    """Train the core embedding branch with triplet and cross-modal MSE losses.

    Builds a TF1 graph with three branches (``modality_core``,
    ``modality_sensors``, ``modality_segment``). It restores the sensors and
    segment branches from ``cfg.sensors_path`` / ``cfg.segment_path`` and
    optimizes only variables under ``modality_core``.

    The core embedding receives two losses:
    - a triplet loss on triplets mined from labeled sessions;
    - MSE losses that regress the sensors / segment embeddings from the core
      embedding, weighted by ``cfg.lambda_multimodal``. The weight is zero
      during the first ``cfg.multimodal_epochs``.

    Summaries, the configuration and checkpoints are written to a
    timestamped directory under ``cfg.result_root``.

    Raises:
        ValueError: if the training set holds fewer sessions than
            ``cfg.sess_per_batch`` (no batch can be formed).
        NotImplementedError: if ``cfg.network`` is not a supported backbone.
    """

    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)
    np.random.seed(seed=cfg.seed)
    # train_set is shuffled with the stdlib `random` module below; seed it as
    # well so the session order is reproducible for a given cfg.seed.
    random.seed(cfg.seed)

    # prepare dataset
    train_session = cfg.train_session
    train_set = prepare_multimodal_dataset(cfg.feature_root, train_session,
                                           cfg.feat, cfg.label_root)
    if cfg.task == "supervised":  # fully supervised task
        train_set = train_set[:cfg.label_num]
    batch_per_epoch = len(train_set) // cfg.sess_per_batch
    if batch_per_epoch == 0:
        # No full batch can be formed: the global step would never advance
        # and `step // batch_per_epoch` below would divide by zero.
        raise ValueError(
            "Training set has %d sessions, fewer than sess_per_batch=%d" %
            (len(train_set), cfg.sess_per_batch))
    # set: membership is tested once per event in the training loop
    labeled_session = set(train_session[:cfg.label_num])

    val_session = cfg.val_session
    val_set = prepare_multimodal_dataset(cfg.feature_root, val_session,
                                         cfg.feat, cfg.label_root)

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        ####################### Load models here ########################
        sensors_emb_dim = 32
        segment_emb_dim = 32

        with tf.variable_scope("modality_core"):
            # load backbone model
            if cfg.network == "convtsn":
                model_emb = networks.ConvTSN(n_seg=cfg.num_seg,
                                             emb_dim=cfg.emb_dim)
            elif cfg.network == "convrtsn":
                model_emb = networks.ConvRTSN(n_seg=cfg.num_seg,
                                              emb_dim=cfg.emb_dim)
            elif cfg.network == "convbirtsn":
                model_emb = networks.ConvBiRTSN(n_seg=cfg.num_seg,
                                                emb_dim=cfg.emb_dim)
            else:
                raise NotImplementedError

            input_ph = tf.placeholder(
                tf.float32, shape=[None, cfg.num_seg, None, None, None])
            dropout_ph = tf.placeholder(tf.float32, shape=[])
            model_emb.forward(input_ph,
                              dropout_ph)  # for lstm has variable scope

            # heads that map the core embedding onto the auxiliary embeddings
            with tf.variable_scope("sensors"):
                model_output_sensors = networks.OutputLayer(
                    n_input=cfg.emb_dim, n_output=sensors_emb_dim)
            with tf.variable_scope("segment"):
                model_output_segment = networks.OutputLayer(
                    n_input=cfg.emb_dim, n_output=segment_emb_dim)

        lambda_mul_ph = tf.placeholder(tf.float32, shape=[])
        with tf.variable_scope("modality_sensors"):
            model_emb_sensors = networks.RTSN(n_seg=cfg.num_seg,
                                              emb_dim=sensors_emb_dim)

            input_sensors_ph = tf.placeholder(tf.float32,
                                              shape=[None, cfg.num_seg, 8])
            model_emb_sensors.forward(input_sensors_ph, dropout_ph)

            # checkpoint names are stored without the "modality_sensors/" prefix
            var_list = {}
            for v in tf.global_variables():
                if v.op.name.startswith("modality_sensors"):
                    var_list[v.op.name.replace("modality_sensors/", "")] = v
            restore_saver_sensors = tf.train.Saver(var_list)

        with tf.variable_scope("modality_segment"):
            model_emb_segment = networks.RTSN(n_seg=cfg.num_seg,
                                              emb_dim=segment_emb_dim,
                                              n_input=357)

            input_segment_ph = tf.placeholder(tf.float32,
                                              shape=[None, cfg.num_seg, 357])
            model_emb_segment.forward(input_segment_ph, dropout_ph)

            # checkpoint names are stored without the "modality_segment/" prefix
            var_list = {}
            for v in tf.global_variables():
                if v.op.name.startswith("modality_segment"):
                    var_list[v.op.name.replace("modality_segment/", "")] = v
            restore_saver_segment = tf.train.Saver(var_list)

        ############################# Forward Pass #############################

        if cfg.normalized:
            embedding = tf.nn.l2_normalize(model_emb.hidden,
                                           axis=-1,
                                           epsilon=1e-10)
            embedding_sensors = tf.nn.l2_normalize(model_emb_sensors.hidden,
                                                   axis=-1,
                                                   epsilon=1e-10)
            embedding_segment = tf.nn.l2_normalize(model_emb_segment.hidden,
                                                   axis=-1,
                                                   epsilon=1e-10)
        else:
            embedding = model_emb.hidden
            embedding_sensors = model_emb_sensors.hidden
            embedding_segment = model_emb_segment.hidden

        # The last `unsup_num` rows of input_ph are the events paired with
        # the sensors/segment inputs; the leading rows are triplets.
        unsup_num = tf.shape(input_sensors_ph)[0]

        # variable for visualizing the embeddings
        emb_var = tf.Variable(tf.zeros([1116, cfg.emb_dim], dtype=tf.float32),
                              name='embeddings')
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        # calculated for monitoring all-pair embedding distance
        diffs = utils.all_diffs_tf(embedding, embedding)
        all_dist = utils.cdist_tf(diffs)
        tf.summary.histogram('embedding_dists', all_dist)

        # split embedding into anchor, positive and negative and calculate triplet loss
        anchor, positive, negative = tf.unstack(
            tf.reshape(embedding[:-unsup_num], [-1, 3, cfg.emb_dim]), 3, 1)
        metric_loss = networks.triplet_loss(anchor, positive, negative,
                                            cfg.alpha)

        model_output_sensors.forward(tf.nn.relu(embedding[-unsup_num:]),
                                     dropout_ph)
        logits_sensors = model_output_sensors.logits
        model_output_segment.forward(tf.nn.relu(embedding[-unsup_num:]),
                                     dropout_ph)
        logits_segment = model_output_segment.logits

        # MSE loss: each head regresses its own modality's embedding
        MSE_loss_sensors = tf.losses.mean_squared_error(
            embedding_sensors, logits_sensors) / sensors_emb_dim
        MSE_loss_segment = tf.losses.mean_squared_error(
            embedding_segment, logits_segment) / segment_emb_dim
        MSE_loss = MSE_loss_sensors + MSE_loss_segment
        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        # When the whole input is unsupervised there are no triplets, so the
        # triplet loss is left out.
        total_loss = tf.cond(
            tf.equal(unsup_num,
                     tf.shape(embedding)[0]), lambda: MSE_loss * lambda_mul_ph
            + regularization_loss * cfg.lambda_l2, lambda: metric_loss +
            MSE_loss * lambda_mul_ph + regularization_loss * cfg.lambda_l2)

        tf.summary.scalar('learning_rate', lr_ph)
        # only train the core branch
        train_var_list = [
            v for v in tf.global_variables()
            if v.op.name.startswith("modality_core")
        ]
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, train_var_list)

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        #########################################################################

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        feat2_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        feat3_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = multimodal_session_generator(
            feat_paths_ph,
            feat2_paths_ph,
            feat3_paths_ph,
            label_paths_ph,
            sess_per_batch=cfg.sess_per_batch,
            num_threads=2,
            shuffled=False,
            preprocess_func=[
                model_emb.prepare_input, model_emb_sensors.prepare_input,
                model_emb_segment.prepare_input
            ])
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # prepare validation data
        val_sess = []
        val_feats = []
        val_feats2 = []
        val_feats3 = []
        val_labels = []
        val_boundaries = []
        for session in val_set:
            session_id = os.path.basename(session[1]).split('_')[0]
            eve_batch, lab_batch, boundary = load_data_and_label(
                session[0], session[-1], model_emb.prepare_input_test
            )  # use prepare_input_test for testing time
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
            val_sess.extend([session_id] * eve_batch.shape[0])
            val_boundaries.extend(boundary)

            eve2_batch, _, _ = load_data_and_label(
                session[1], session[-1], model_emb_sensors.prepare_input_test)
            val_feats2.append(eve2_batch)

            eve3_batch, _, _ = load_data_and_label(
                session[2], session[-1], model_emb_segment.prepare_input_test)
            val_feats3.append(eve3_batch)
        val_feats = np.concatenate(val_feats, axis=0)
        val_feats2 = np.concatenate(val_feats2, axis=0)
        val_feats3 = np.concatenate(val_feats3, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        print("Shape of val_feats: ", val_feats.shape)

        # generate metadata.tsv for visualize embedding
        with open(os.path.join(result_dir, 'metadata_val.tsv'), 'w') as fout:
            fout.write('id\tlabel\tsession_id\tstart\tend\n')
            for i in range(len(val_sess)):
                fout.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    i, val_labels[i, 0], val_sess[i], val_boundaries[i][0],
                    val_boundaries[i][1]))

        #########################################################################

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        # Entering the session makes it the default and closes it on exit.
        with sess:

            sess.run(tf.global_variables_initializer())
            print("Restoring sensors model: %s" % cfg.sensors_path)
            restore_saver_sensors.restore(sess, cfg.sensors_path)
            print("Restoring segment model: %s" % cfg.segment_path)
            restore_saver_segment.restore(sess, cfg.segment_path)

            # load pretrain model, if needed
            if cfg.model_path:
                print("Restoring pretrained model: %s" % cfg.model_path)
                saver.restore(sess, cfg.model_path)

            ################## Training loop ##################
            epoch = -1
            while epoch < cfg.max_epochs - 1:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // batch_per_epoch

                # learning rate schedule, reference: "In defense of Triplet Loss"
                if epoch < cfg.static_epochs:
                    learning_rate = cfg.learning_rate
                else:
                    learning_rate = cfg.learning_rate * \
                            0.01**((epoch-cfg.static_epochs)/(cfg.max_epochs-cfg.static_epochs))

                # prepare data for this epoch
                random.shuffle(train_set)

                # group sessions into batches of sess_per_batch (remainder dropped)
                paths = list(zip(*[iter(train_set)] * cfg.sess_per_batch))

                feat_paths = [[p[0] for p in path] for path in paths]
                feat2_paths = [[p[1] for p in path] for path in paths]
                feat3_paths = [[p[2] for p in path] for path in paths]
                label_paths = [[p[-1] for p in path] for path in paths]

                sess.run(train_sess_iterator.initializer,
                         feed_dict={
                             feat_paths_ph: feat_paths,
                             feat2_paths_ph: feat2_paths,
                             feat3_paths_ph: feat3_paths,
                             label_paths_ph: label_paths
                         })

                # for each epoch
                batch_count = 1
                while True:
                    try:
                        ##################### Data loading ########################
                        start_time = time.time()
                        eve, eve_sensors, eve_segment, lab, batch_sess = sess.run(
                            next_train)

                        # for memory concern, 1000 events are used in maximum
                        if eve.shape[0] > 1000:
                            idx = np.random.permutation(eve.shape[0])[:1000]
                            eve = eve[idx]
                            eve_sensors = eve_sensors[idx]
                            eve_segment = eve_segment[idx]
                            lab = lab[idx]
                            batch_sess = batch_sess[idx]
                        load_time = time.time() - start_time

                        ##################### Triplet selection #####################
                        start_time = time.time()
                        # for labeled sessions, use facenet sampling
                        # FIXME: use decode again to get session_id str
                        labeled_mask = np.array(
                            [s.decode() in labeled_session
                             for s in batch_sess[:, 0]],
                            dtype=bool)
                        eve_labeled = eve[labeled_mask]
                        lab_labeled = lab[labeled_mask]
                        has_labeled = eve_labeled.shape[0] > 0

                        active_count = -1
                        labeled_triplets = None
                        if has_labeled:  # if labeled sessions exist
                            # Get the embeddings of all labeled events
                            n_labeled = eve_labeled.shape[0]
                            eve_embedding = np.zeros((n_labeled, cfg.emb_dim),
                                                     dtype='float32')
                            for start in range(0, n_labeled, cfg.batch_size):
                                end = min(start + cfg.batch_size, n_labeled)
                                eve_embedding[start:end] = sess.run(
                                    embedding,
                                    feed_dict={
                                        input_ph: eve_labeled[start:end],
                                        dropout_ph: 1.0
                                    })

                            # Second, sample triplets within sampled sessions
                            all_diff = utils.all_diffs(eve_embedding,
                                                       eve_embedding)
                            triplet_input_idx, active_count = utils.select_triplets_facenet(
                                lab_labeled,
                                utils.cdist(all_diff, metric=cfg.metric),
                                cfg.triplet_per_batch,
                                cfg.alpha,
                                num_negative=cfg.num_negative)

                            # Guard against both None and an empty index set.
                            if (triplet_input_idx is not None
                                    and len(triplet_input_idx) > 0):
                                labeled_triplets = eve_labeled[
                                    triplet_input_idx]

                        # for all sessions in the batch
                        perm_idx = np.random.permutation(eve.shape[0])
                        perm_idx = perm_idx[:min(3 * (len(perm_idx) // 3), 3 *
                                                 cfg.triplet_per_batch)]
                        mul_input = eve[perm_idx]

                        # The graph treats the last len(perm_idx) rows of
                        # input_ph as the events paired with the auxiliary
                        # modalities; labeled triplets (if any) go first.
                        if labeled_triplets is not None:
                            triplet_input = np.concatenate(
                                (labeled_triplets, mul_input), axis=0)
                        else:
                            triplet_input = mul_input
                        sensors_input = eve_sensors[perm_idx]
                        segment_input = eve_segment[perm_idx]

                        ##################### Start training  ########################

                        if epoch < cfg.multimodal_epochs:
                            # supervised initialization: the multimodal MSE
                            # terms get zero weight; batches without labeled
                            # sessions contribute nothing and are skipped.
                            if not has_labeled:
                                continue
                            lambda_mul = 0.0
                        else:
                            lambda_mul = cfg.lambda_multimodal

                        # All placeholders are fed in both phases: total_loss
                        # depends on the sensors and segment inputs through
                        # MSE_loss.
                        err, mse_err1, mse_err2, _, step, summ = sess.run(
                            [
                                total_loss, MSE_loss_sensors,
                                MSE_loss_segment, train_op, global_step,
                                summary_op
                            ],
                            feed_dict={
                                input_ph: triplet_input,
                                input_sensors_ph: sensors_input,
                                input_segment_ph: segment_input,
                                dropout_ph: cfg.keep_prob,
                                lambda_mul_ph: lambda_mul,
                                lr_ph: learning_rate
                            })
                        train_time = time.time() - start_time

                        print("%s\tEpoch: [%d][%d/%d]\tEvent num: %d\tLoad time: %.3f\tTrain_time: %.3f\tLoss %.4f" %
                              (cfg.name, epoch+1, batch_count, batch_per_epoch, eve.shape[0], load_time, train_time, err))

                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag="train_loss",
                                             simple_value=err),
                            tf.Summary.Value(tag="active_count",
                                             simple_value=active_count),
                            tf.Summary.Value(
                                tag="triplet_num",
                                simple_value=(triplet_input.shape[0] -
                                              sensors_input.shape[0]) // 3),
                            tf.Summary.Value(tag="MSE_loss_sensors",
                                             simple_value=mse_err1),
                            tf.Summary.Value(tag="MSE_loss_segment",
                                             simple_value=mse_err2)
                        ])

                        summary_writer.add_summary(summary, step)
                        summary_writer.add_summary(summ, step)

                        batch_count += 1

                    except tf.errors.OutOfRangeError:
                        print("Epoch %d done!" % (epoch + 1))
                        break

                # validation on val_set
                print("Evaluating on validation set...")
                val_err1, val_err2, val_embeddings, _ = sess.run(
                    [MSE_loss_sensors, MSE_loss_segment, embedding, set_emb],
                    feed_dict={
                        input_ph: val_feats,
                        input_sensors_ph: val_feats2,
                        input_segment_ph: val_feats3,
                        dropout_ph: 1.0
                    })
                mAP, mPrec = utils.evaluate_simple(val_embeddings, val_labels)

                # NOTE(review): the second tag text looks mangled by e-mail
                # obfuscation in the source this was copied from; confirm the
                # intended tag name before changing it.
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag="Valiation mAP", simple_value=mAP),
                    tf.Summary.Value(tag="Validation [email protected]",
                                     simple_value=mPrec),
                    tf.Summary.Value(tag="Validation mse loss sensors",
                                     simple_value=val_err1),
                    tf.Summary.Value(tag="Validation mse loss segment",
                                     simple_value=val_err2)
                ])
                summary_writer.add_summary(summary, step)
                print("Epoch: [%d]\tmAP: %.4f\tmPrec: %.4f" %
                      (epoch + 1, mAP, mPrec))

                # config for embedding visualization
                config = projector.ProjectorConfig()
                visual_embedding = config.embeddings.add()
                visual_embedding.tensor_name = emb_var.name
                visual_embedding.metadata_path = os.path.join(
                    result_dir, 'metadata_val.tsv')
                projector.visualize_embeddings(summary_writer, config)

                # save model
                saver.save(sess,
                           os.path.join(result_dir, cfg.name + '.ckpt'),
                           global_step=step)

            # flush pending events to disk and release the file handle
            summary_writer.close()
# Example #2
# 0
def main():
    """Train a CUB-200-2011 attribute embedding with triplet and PDDM losses.

    Loads 312-dimensional attribute vectors and labels from ``.npy`` files
    under ``/mnt/work/CUB_200_2011/data``. It projects them with
    ``networks.OutputLayer`` and trains on a combined loss: the PDDM ranking
    loss, plus 0.5 times the triplet loss, plus L2 regularization.

    Each epoch samples one image batch of 5-10 images per class. It scores
    every pair in the batch with the PDDM head and mines triplets from those
    scores. Validation loss is logged every 100 epochs; embedding mAP, PDDM mAP
    and a projector config are produced every 1000 epochs. Checkpoints are
    written to a timestamped directory under ``cfg.result_root``.
    """

    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)
    np.random.seed(seed=cfg.seed)

    # prepare dataset
    data_dir = '/mnt/work/CUB_200_2011/data'
    att_train = np.load(os.path.join(data_dir, 'att_train.npy'))
    val_att = np.load(os.path.join(data_dir, 'att_test.npy'))
    label_train = np.load(os.path.join(data_dir, 'label_train.npy'))
    label_train -= 1  # make labels start from 0
    # NOTE(review): val_labels are not shifted like label_train. Only label
    # equality matters for validation, except the `> 0` filter in the PDDM
    # mAP loop below; confirm which convention that filter expects.
    val_labels = np.load(os.path.join(data_dir, 'label_test.npy'))

    # group training sample indices by class id (insertion order preserved)
    class_idx_dict = {}
    for i, l in enumerate(label_train):
        class_idx_dict.setdefault(int(l), []).append(i)
    class_ids = list(class_idx_dict.keys())

    val_triplet_idx = select_triplets_random(val_labels, 1000)

    # generate metadata.tsv for visualize embedding
    with open(os.path.join(result_dir, 'metadata_val.tsv'), 'w') as fout:
        for l in val_labels:
            fout.write('{}\n'.format(int(l)))

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        # load backbone model
        model_emb = networks.OutputLayer(n_input=312, n_output=cfg.emb_dim)
        model_ver = networks.PDDM(n_input=cfg.emb_dim)

        # get the embedding
        input_ph = tf.placeholder(tf.float32, shape=[None, 312])
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        model_emb.forward(input_ph, dropout_ph)
        if cfg.normalized:
            embedding = tf.nn.l2_normalize(model_emb.logits,
                                           axis=-1,
                                           epsilon=1e-10)
        else:
            embedding = model_emb.logits

        # variable for visualizing the embeddings; only assigned, never
        # optimized, so keep it out of the trainable set
        emb_var = tf.Variable([0.0], name='embeddings', trainable=False)
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        # split embedding into anchor, positive and negative and calculate triplet loss
        anchor, positive, negative = tf.unstack(
            tf.reshape(embedding, [-1, 3, cfg.emb_dim]), 3, 1)
        metric_loss = networks.triplet_loss(anchor, positive, negative,
                                            cfg.alpha)

        # PDDM scores for (anchor, positive) and (anchor, negative) pairs.
        # The loss below and the `1 - score` ranking at validation treat a
        # higher score as "less similar".
        model_ver.forward(tf.stack((anchor, positive), axis=1))
        pddm_ap = model_ver.prob[:, 0]
        model_ver.forward(tf.stack((anchor, negative), axis=1))
        pddm_an = model_ver.prob[:, 0]
        pddm_loss = tf.reduce_mean(
            tf.maximum(tf.add(tf.subtract(pddm_ap, pddm_an), 0.6), 0.0), 0)

        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = pddm_loss + 0.5 * metric_loss + regularization_loss * cfg.lambda_l2

        tf.summary.scalar('learning_rate', lr_ph)
        # trainable variables only: global_step and emb_var have no gradient
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.trainable_variables())

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        def _pairwise_pddm(feats):
            """Return an (n, n) float32 matrix of PDDM scores for all row pairs.

            Each unordered pair (i, j) is fed as the triplet (i, j, j), so the
            anchor-positive head ``pddm_ap`` scores it. The matrix is filled
            symmetrically and the diagonal stays NaN.
            """
            n = feats.shape[0]
            scores = np.full((n, n), np.nan, dtype='float32')
            # Same pair order as itertools.combinations, but without building
            # a Python list of tuples (quadratic memory for the validation set).
            rows, cols = np.triu_indices(n, k=1)
            for start in range(0, len(rows), cfg.batch_size):
                r = rows[start:start + cfg.batch_size]
                c = cols[start:start + cfg.batch_size]
                triplet_idx = np.stack((r, c, c), axis=1).reshape(-1)
                prob = sess.run(pddm_ap,
                                feed_dict={
                                    input_ph: feats[triplet_idx],
                                    dropout_ph: 1.0
                                })
                scores[r, c] = prob
                scores[c, r] = prob
            return scores

        # Entering the session makes it the default and closes it on exit.
        with sess:

            sess.run(tf.global_variables_initializer())
            # defined up front so validation can log even if no training step
            # has run yet (no triplets mined)
            step = sess.run(global_step)

            ################## Training loop ##################
            for epoch in range(cfg.max_epochs):

                # learning rate schedule, reference: "In defense of Triplet Loss"
                if epoch < cfg.static_epochs:
                    learning_rate = cfg.learning_rate
                else:
                    learning_rate = cfg.learning_rate * \
                            0.001**((epoch-cfg.static_epochs)/(cfg.max_epochs-cfg.static_epochs))

                # sample images: add 5-10 images from unseen classes until the
                # batch is full, or until every class has been used (otherwise
                # a batch_size larger than the dataset would loop forever)
                class_in_batch = set()
                idx_batch = np.array([], dtype=np.int32)
                while (len(idx_batch) < cfg.batch_size
                       and len(class_in_batch) < len(class_ids)):
                    sampled_class = np.random.choice(class_ids)
                    if sampled_class not in class_in_batch:
                        class_in_batch.add(sampled_class)
                        subsample_size = np.random.choice(range(5, 11))
                        subsample = np.random.permutation(
                            class_idx_dict[sampled_class])[:subsample_size]
                        idx_batch = np.append(idx_batch, subsample)
                idx_batch = idx_batch[:cfg.batch_size]

                feat_batch = att_train[idx_batch]
                lab_batch = label_train[idx_batch]

                # Get the similarity of all events
                sim_prob = _pairwise_pddm(feat_batch)

                triplet_input_idx, active_count = select_triplets_facenet(
                    lab_batch,
                    sim_prob,
                    cfg.triplet_per_batch,
                    cfg.alpha,
                    num_negative=cfg.num_negative)

                if triplet_input_idx is not None:
                    triplet_input = feat_batch[triplet_input_idx]

                    # perform training on the selected triplets
                    err, _, step, summ = sess.run(
                        [total_loss, train_op, global_step, summary_op],
                        feed_dict={
                            input_ph: triplet_input,
                            dropout_ph: cfg.keep_prob,
                            lr_ph: learning_rate
                        })

                    print("%s\tEpoch: %d\tImages num: %d\tTriplet num: %d\tLoss %.4f" %
                          (cfg.name, epoch+1, feat_batch.shape[0], triplet_input.shape[0]//3, err))

                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_loss", simple_value=err),
                        tf.Summary.Value(tag="active_count",
                                         simple_value=active_count),
                        tf.Summary.Value(tag="images_num",
                                         simple_value=feat_batch.shape[0]),
                        tf.Summary.Value(tag="triplet_num",
                                         simple_value=triplet_input.shape[0] //
                                         3)
                    ])
                    summary_writer.add_summary(summary, step)
                    summary_writer.add_summary(summ, step)

                # validation on val_set
                if (epoch + 1) % 100 == 0:
                    print("Evaluating on validation set...")
                    val_err = sess.run(total_loss,
                                       feed_dict={
                                           input_ph: val_att[val_triplet_idx],
                                           dropout_ph: 1.0
                                       })

                    # write immediately so the 1000-epoch summary below does
                    # not overwrite it
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="Valiation loss",
                                         simple_value=val_err),
                    ])
                    summary_writer.add_summary(summary, step)
                    print("Epoch: [%d]\tloss: %.4f" % (epoch + 1, val_err))

                    if (epoch + 1) % 1000 == 0:
                        val_embeddings, _ = sess.run([embedding, set_emb],
                                                     feed_dict={
                                                         input_ph: val_att,
                                                         dropout_ph: 1.0
                                                     })
                        mAP, mPrec, recall = utils.evaluate_simple(
                            val_embeddings, val_labels)

                        val_sim_prob = _pairwise_pddm(val_att)

                        # leave-one-out mAP ranking by PDDM score
                        mAP_PDDM = 0.0
                        count = 0
                        for i in range(val_labels.shape[0]):
                            if val_labels[i] > 0:
                                temp_labels = np.delete(val_labels, i, 0)
                                # only row i is needed; deleting column i of
                                # the whole matrix per query is O(n^2) each
                                temp_row = np.delete(val_sim_prob[i], i)
                                mAP_PDDM += average_precision_score(
                                    np.squeeze(temp_labels == val_labels[i]),
                                    np.squeeze(1 - temp_row))
                                count += 1
                        mAP_PDDM = mAP_PDDM / count if count else float('nan')

                        print(
                            "Epoch: [%d]\tmAP: %.4f\trecall: %.4f\tmAP_PDDM: %.4f"
                            % (epoch + 1, mAP, recall, mAP_PDDM))
                        # NOTE(review): the last tag text looks mangled by
                        # e-mail obfuscation in the copied source; confirm the
                        # intended tag name before changing it.
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag="Valiation mAP",
                                             simple_value=mAP),
                            tf.Summary.Value(tag="Validation Recall@1",
                                             simple_value=recall),
                            tf.Summary.Value(tag="Validation PDDM_mAP",
                                             simple_value=mAP_PDDM),
                            tf.Summary.Value(tag="Validation [email protected]",
                                             simple_value=mPrec)
                        ])
                        summary_writer.add_summary(summary, step)

                        # config for embedding visualization
                        config = projector.ProjectorConfig()
                        visual_embedding = config.embeddings.add()
                        visual_embedding.tensor_name = emb_var.name
                        visual_embedding.metadata_path = os.path.join(
                            result_dir, 'metadata_val.tsv')
                        projector.visualize_embeddings(summary_writer, config)

                    # save model
                    saver.save(sess,
                               os.path.join(result_dir, cfg.name + '.ckpt'),
                               global_step=step)

            # flush pending events to disk and release the file handle
            summary_writer.close()
def main():
    """Evaluate the fused (core + sensors) embedding on the test sessions.

    Two backbones of type ``cfg.network`` are built on the same input
    placeholder, one under the ``modality_core`` variable scope and one under
    ``modality_sensors``; each is restored from its own checkpoint
    (``cfg.model_path`` and ``cfg.sensors_path``).  Every test session is
    embedded by both models, the two embeddings are concatenated per event,
    and the retrieval metrics returned by ``evaluate`` are printed and pickled
    to ``results.pkl`` in the directory of ``cfg.model_path``.

    Raises:
        NotImplementedError: if ``cfg.network`` is not "convtsn"/"convrtsn".
    """
    cfg = EvalConfig().parse()
    np.random.seed(seed=cfg.seed)

    test_session = cfg.test_session
    test_set = prepare_dataset(cfg.feature_root, test_session, cfg.feat,
                               cfg.label_root)

    ####################### Load models here ########################

    input_ph = tf.placeholder(tf.float32,
                              shape=[None, cfg.num_seg, None, None, None])
    dropout_ph = tf.placeholder(tf.float32, shape=[])

    with tf.variable_scope("modality_core"):
        # load backbone model
        if cfg.network == "convtsn":
            model_emb = networks.ConvTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb = networks.ConvRTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError

        # forward() runs inside the scope so the variables it creates
        # (e.g. recurrent ones) are placed under "modality_core".
        model_emb.forward(input_ph, dropout_ph)

        # Map the scope-less checkpoint names onto the scoped variables.
        restore_saver = tf.train.Saver({
            v.op.name.replace("modality_core/", ""): v
            for v in tf.global_variables()
            if v.op.name.startswith("modality_core")})

    with tf.variable_scope("modality_sensors"):
        # NOTE(review): OutputLayer is told its input is 128-d while the
        # sensors backbone is built with emb_dim=cfg.emb_dim; the two only
        # agree when cfg.emb_dim == 128 -- confirm against the checkpoint.
        sensors_emb_dim = 128
        if cfg.network == "convtsn":
            model_emb_sensors = networks.ConvTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb_sensors = networks.ConvRTSN(n_seg=cfg.num_seg, emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError
        model_output_sensors = networks.OutputLayer(n_input=sensors_emb_dim, n_output=8)

        model_emb_sensors.forward(input_ph, dropout_ph)
        model_output_sensors.forward(tf.nn.relu(model_emb_sensors.hidden), dropout_ph)

        # Same prefix-stripping trick as for the core model.
        restore_saver_sensors = tf.train.Saver({
            v.op.name.replace("modality_sensors/", ""): v
            for v in tf.global_variables()
            if v.op.name.startswith("modality_sensors")})

    ############################# Forward Pass #############################

    # get embeddings (each row normalised to unit L2 norm)
    embedding = tf.nn.l2_normalize(model_emb.hidden, axis=-1, epsilon=1e-10)
    if cfg.use_output:
        if cfg.normalized:
            # axis=-1 is required: without it l2_normalize divides by the
            # norm of the whole batch tensor instead of each row.
            embedding_sensors = tf.nn.l2_normalize(
                model_output_sensors.logits, axis=-1, epsilon=1e-10)
        else:
            embedding_sensors = model_output_sensors.logits
    else:
        embedding_sensors = tf.nn.l2_normalize(model_emb_sensors.hidden, axis=-1, epsilon=1e-10)

    #########################################################################

    # Testing
    if cfg.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

    gpu_options = tf.GPUOptions(allow_growth=True)
    # The session context manager installs it as the default session and
    # closes it once all embeddings have been computed.
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        # load the model (note that model_path already contains snapshot number)
        restore_saver.restore(sess, cfg.model_path)
        print ("Restoring the model: {}".format(os.path.basename(cfg.model_path)))
        restore_saver_sensors.restore(sess, cfg.sensors_path)
        print ("Restoring the model: {}".format(os.path.basename(cfg.sensors_path)))

        eve_embeddings = []
        sensors_embeddings = []
        labels = []
        for i, session in enumerate(test_set):
            session_id = os.path.basename(session[1]).split('_')[0]
            print ("{0} / {1}: {2}".format(i, len(test_set), session_id))

            # use prepare_input_test for testing time
            eve_batch, lab_batch, _ = load_data_and_label(
                session[0], session[1], model_emb.prepare_input_test)

            emb, emb_s = sess.run([embedding, embedding_sensors],
                                  feed_dict={input_ph: eve_batch, dropout_ph: 1.0})

            eve_embeddings.append(emb)
            sensors_embeddings.append(emb_s)
            labels.append(lab_batch)

        eve_embeddings = np.concatenate(eve_embeddings, axis=0)
        sensors_embeddings = np.concatenate(sensors_embeddings, axis=0)
        labels = np.concatenate(labels, axis=0)

    # evaluate the results on the concatenated (fused) embeddings
    fused_embeddings = np.concatenate((eve_embeddings, sensors_embeddings), axis=1)
    mAP, mAP_event, mPrec, confusion, count, recall = evaluate(fused_embeddings, np.squeeze(labels))

    # macro mAP: unweighted mean of the per-event mAPs
    mAP_macro = sum(mAP_event.values()) / float(len(mAP_event))

    print ("%d events with dim %d for evaluation." % (labels.shape[0], fused_embeddings.shape[1]))
    print ("mAP = {}".format(mAP))
    print ("mAP_macro = {}".format(mAP_macro))
    print ("[email protected] = {}".format(mPrec))
    print ("Recall@1 = {}, Recall@10 = {}, Recall@100 = {}".format(recall[0], recall[1], recall[2]))

    keys = confusion['labels']
    for i, key in enumerate(keys):
        if key not in mAP_event:
            continue
        print ("Event {0}: {1}, ratio = {2}, mAP = {3}, [email protected] = {4}".format(
            key,
            honda_num2labels[key],
            float(count[i]) / np.sum(count),
            mAP_event[key],
            confusion['confusion_matrix'][i, i]))

    # store results (the with-block guarantees the file is flushed and closed)
    with open(os.path.join(os.path.dirname(cfg.model_path), "results.pkl"), 'wb') as fout:
        pkl.dump({"mAP": mAP,
                  "mAP_macro": mAP_macro,
                  "mAP_event": mAP_event,
                  "mPrec": mPrec,
                  "confusion": confusion,
                  "count": count,
                  "recall": recall},
                 fout)
def main():
    """Train a backbone embedding to regress a second modality's features.

    The backbone ``cfg.network`` embeds the first feature file of each
    session (``session[0]``, via ``model_emb.prepare_input``).  An
    ``OutputLayer`` on top of the ReLU'd hidden state predicts the second
    feature file (``session[1]``, passed through ``utils.mean_pool_input``),
    trained with an MSE loss plus ``cfg.lambda_l2``-scaled regularisation.
    After every epoch the whole validation set is run once: retrieval mAP /
    precision are logged for both the normalised embedding and the raw
    prediction, embeddings are exported for the TensorBoard projector, and a
    checkpoint is saved into a timestamped directory under
    ``cfg.result_root``.

    Raises:
        ValueError: if there are fewer training sessions than
            ``cfg.sess_per_batch`` (an epoch would contain no batches).
        NotImplementedError: if ``cfg.network`` is not "convtsn"/"convrtsn".
    """
    cfg = TrainConfig().parse()
    print(cfg.name)
    result_dir = os.path.join(
        cfg.result_root,
        cfg.name + '_' + datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    utils.write_configure_to_file(cfg, result_dir)
    np.random.seed(seed=cfg.seed)

    # prepare dataset
    train_session = cfg.train_session
    train_set = prepare_multimodal_dataset(cfg.feature_root, train_session,
                                           cfg.feat, cfg.label_root)
    batch_per_epoch = len(train_set) // cfg.sess_per_batch
    if batch_per_epoch == 0:
        # Fail early: the epoch counter below divides by batch_per_epoch.
        raise ValueError(
            "Need at least sess_per_batch=%d training sessions, got %d"
            % (cfg.sess_per_batch, len(train_set)))

    val_session = cfg.val_session
    # Only the first modality is fed to the network at test time; the
    # second one is loaded as the regression target.
    val_set = prepare_multimodal_dataset(
        cfg.feature_root, val_session, cfg.feat,
        cfg.label_root)

    # construct the graph
    with tf.Graph().as_default():
        tf.set_random_seed(cfg.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_ph = tf.placeholder(tf.float32, name='learning_rate')

        # load backbone model
        if cfg.network == "convtsn":
            model_emb = networks.ConvTSN(n_seg=cfg.num_seg,
                                         emb_dim=cfg.emb_dim)
        elif cfg.network == "convrtsn":
            model_emb = networks.ConvRTSN(n_seg=cfg.num_seg,
                                          emb_dim=cfg.emb_dim)
        else:
            raise NotImplementedError

        input_ph = tf.placeholder(tf.float32,
                                  shape=[None, cfg.num_seg, None, None, None])
        # regression target: features of the second modality (cfg.feat[1])
        output_ph = tf.placeholder(tf.float32,
                                   shape=(None, ) + cfg.feat_dim[cfg.feat[1]])
        dropout_ph = tf.placeholder(tf.float32, shape=[])
        model_emb.forward(input_ph, dropout_ph)
        hidden = model_emb.hidden
        embedding = tf.nn.l2_normalize(hidden,
                                       axis=-1,
                                       epsilon=1e-10)

        # variable for visualizing the embeddings (shape set on assignment)
        emb_var = tf.Variable([0.0], name='embeddings')
        set_emb = tf.assign(emb_var, embedding, validate_shape=False)

        model_output = networks.OutputLayer(
            n_input=cfg.emb_dim, n_output=cfg.feat_dim[cfg.feat[1]][0])
        model_output.forward(tf.nn.relu(hidden), dropout_ph)
        logits = model_output.logits

        # MSE loss
        MSE_loss = tf.losses.mean_squared_error(output_ph, logits)
        regularization_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = MSE_loss + regularization_loss * cfg.lambda_l2

        tf.summary.scalar('learning_rate', lr_ph)
        # NOTE(review): all *global* variables (incl. global_step and emb_var)
        # are handed to utils.optimize; tf.trainable_variables() may be what
        # is intended -- confirm against utils.optimize.
        train_op = utils.optimize(total_loss, global_step, cfg.optimizer,
                                  lr_ph, tf.global_variables())

        saver = tf.train.Saver(max_to_keep=10)

        summary_op = tf.summary.merge_all()

        #########################################################################

        # session iterator for session sampling
        feat_paths_ph = tf.placeholder(tf.string,
                                       shape=[None, cfg.sess_per_batch])
        feat2_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        label_paths_ph = tf.placeholder(tf.string,
                                        shape=[None, cfg.sess_per_batch])
        train_data = multimodal_session_generator(
            feat_paths_ph,
            feat2_paths_ph,
            label_paths_ph,
            sess_per_batch=cfg.sess_per_batch,
            num_threads=2,
            shuffled=False,
            preprocess_func=[model_emb.prepare_input, utils.mean_pool_input])
        train_sess_iterator = train_data.make_initializable_iterator()
        next_train = train_sess_iterator.get_next()

        # prepare validation data
        val_sess = []
        val_feats = []
        val_feats2 = []
        val_labels = []
        val_boundaries = []
        for session in val_set:
            session_id = os.path.basename(session[1]).split('_')[0]
            # use prepare_input_test for testing time
            eve_batch, lab_batch, boundary = load_data_and_label(
                session[0], session[-1], model_emb.prepare_input_test)
            val_feats.append(eve_batch)
            val_labels.append(lab_batch)
            val_sess.extend([session_id] * eve_batch.shape[0])
            val_boundaries.extend(boundary)

            # second modality, preprocessed the same way as the training target
            eve2_batch, _, _ = load_data_and_label(session[1], session[-1],
                                                   utils.mean_pool_input)
            val_feats2.append(eve2_batch)
        val_feats = np.concatenate(val_feats, axis=0)
        val_feats2 = np.concatenate(val_feats2, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        print("Shape of val_feats: ", val_feats.shape)

        # generate metadata.tsv for visualize embedding
        with open(os.path.join(result_dir, 'metadata_val.tsv'), 'w') as fout:
            fout.write('id\tlabel\tsession_id\tstart\tend\n')
            for i in range(len(val_sess)):
                fout.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    i, val_labels[i, 0], val_sess[i], val_boundaries[i][0],
                    val_boundaries[i][1]))

        #########################################################################

        # Start running the graph
        if cfg.gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        summary_writer = tf.summary.FileWriter(result_dir, sess.graph)

        with sess.as_default():

            sess.run(tf.global_variables_initializer())

            # load pretrain model, if needed
            if cfg.model_path:
                print("Restoring pretrained model: %s" % cfg.model_path)
                saver.restore(sess, cfg.model_path)

            ################## Training loop ##################
            epoch = -1
            while epoch < cfg.max_epochs - 1:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // batch_per_epoch

                # learning rate schedule, reference: "In defense of Triplet Loss"
                # constant for static_epochs, then exponential decay towards
                # 0.001 * learning_rate at max_epochs.
                if epoch < cfg.static_epochs:
                    learning_rate = cfg.learning_rate
                else:
                    # float() keeps the exponent fractional under Python 2 too
                    decay_frac = float(epoch - cfg.static_epochs) / (
                        cfg.max_epochs - cfg.static_epochs)
                    learning_rate = cfg.learning_rate * 0.001**decay_frac

                # prepare data for this epoch
                random.shuffle(train_set)

                # group sessions into consecutive tuples of sess_per_batch;
                # zip() drops a trailing incomplete group
                paths = list(zip(*[iter(train_set)] * cfg.sess_per_batch))

                feat_paths = [[p[0] for p in path] for path in paths]
                feat2_paths = [[p[1] for p in path] for path in paths]
                label_paths = [[p[-1] for p in path] for path in paths]

                sess.run(train_sess_iterator.initializer,
                         feed_dict={
                             feat_paths_ph: feat_paths,
                             feat2_paths_ph: feat2_paths,
                             label_paths_ph: label_paths
                         })

                # for each epoch
                batch_count = 1
                while True:
                    try:
                        ##################### Data loading ########################
                        start_time = time.time()
                        eve, eve2, lab = sess.run(next_train)
                        load_time = time.time() - start_time

                        ##################### Start training  ########################
                        start_time = time.time()

                        err, _, step, summ = sess.run(
                            [total_loss, train_op, global_step, summary_op],
                            feed_dict={
                                input_ph: eve,
                                output_ph: eve2,
                                dropout_ph: cfg.keep_prob,
                                lr_ph: learning_rate
                            })
                        train_time = time.time() - start_time

                        print("%s\tEpoch: [%d][%d/%d]\tEvent num: %d\tLoad time: %.3f\tTrain_time: %.3f\tLoss %.4f"
                              % (cfg.name, epoch + 1, batch_count, batch_per_epoch,
                                 eve.shape[0], load_time, train_time, err))

                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag="train_loss",
                                             simple_value=err),
                        ])
                        summary_writer.add_summary(summary, step)
                        summary_writer.add_summary(summ, step)

                        batch_count += 1

                    except tf.errors.OutOfRangeError:
                        # the session iterator is exhausted: epoch finished
                        print("Epoch %d done!" % (epoch + 1))
                        break

                # validation on val_set (whole set in a single run; set_emb
                # also stores the embeddings for the projector)
                print("Evaluating on validation set...")
                val_err, val_embeddings, val_pred, _ = sess.run(
                    [total_loss, embedding, logits, set_emb],
                    feed_dict={
                        input_ph: val_feats,
                        output_ph: val_feats2,
                        dropout_ph: 1.0
                    })
                mAP, mPrec = utils.evaluate_simple(val_embeddings, val_labels)
                # use prediction for retrieval
                mAP2, mPrec2 = utils.evaluate_simple(val_pred, val_labels)

                summary = tf.Summary(value=[
                    tf.Summary.Value(tag="Valiation mAP", simple_value=mAP),
                    tf.Summary.Value(tag="Validation [email protected]",
                                     simple_value=mPrec),
                    tf.Summary.Value(tag="Validation mAP 2",
                                     simple_value=mAP2),
                    tf.Summary.Value(tag="Validation [email protected] 2",
                                     simple_value=mPrec2),
                    tf.Summary.Value(tag="Validation loss",
                                     simple_value=val_err)
                ])
                summary_writer.add_summary(summary, step)

                # config for embedding visualization
                config = projector.ProjectorConfig()
                visual_embedding = config.embeddings.add()
                visual_embedding.tensor_name = emb_var.name
                visual_embedding.metadata_path = os.path.join(
                    result_dir, 'metadata_val.tsv')
                projector.visualize_embeddings(summary_writer, config)

                # save model
                saver.save(sess,
                           os.path.join(result_dir, cfg.name + '.ckpt'),
                           global_step=step)