Example #1
def eval_mos(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.training_percentage,
        FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    uid_count = audio_processor.num_uids

    print("Label count: %d uid count: %d" % (label_count, uid_count))
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    net = lambda scope: lambda inp: models.create_model(
        inp, model_settings, FLAGS.model_architecture,
        is_training=True, scope=scope)

    # Define loss and optimizer
    ground_truth_label = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')
    ground_truth_style = tf.placeholder(tf.float32, [None, uid_count],
                                        name='groundtruth_style')

    with tf.variable_scope('crossgrad'):
        label_net = net('label')
        label_embedding = label_net(fingerprint_input)

    with tf.variable_scope(''):
        cross_entropy, label_logits = losses.mos(label_embedding,
                                                 ground_truth_style,
                                                 ground_truth_label,
                                                 label_count, uid_count,
                                                 fingerprint_input, FLAGS)
    # Trainable combination weights: squashed to (0, 1) by a sigmoid and then
    # L2-normalized.
    tune_var = tf.get_variable("tune_var", [FLAGS.num_uids])
    c_wts = tf.sigmoid(tune_var)
    c_wts /= tf.norm(c_wts)

    with tf.variable_scope('', reuse=True):
        emb_mat = tf.get_variable("emb_mat", [uid_count, FLAGS.num_uids])

        logits_for_tune = losses.mos_tune(label_embedding, c_wts,
                                          ground_truth_label, label_count,
                                          fingerprint_input, FLAGS)
        # logits_dir would be an identical mos_tune subgraph (variables are
        # reused), so reuse the tensor instead of rebuilding it.
        logits_dir = logits_for_tune


    probs_for_tune = tf.nn.softmax(logits_for_tune, axis=1)
    probs_dir = tf.nn.softmax(logits_dir, axis=1)
    agg_prob = tf.reduce_sum(probs_for_tune, axis=0)
    agg_prob2 = tf.reduce_sum(probs_dir, axis=0)
    # The first label is silence, which is absent for many uids, so the
    # target distribution U is uniform over the remaining labels.
    np_u = np.ones([label_count], dtype=np.float32) / (label_count - 1)
    np_u[0] = 0
    U = tf.constant(np_u)

    # Shift logits so each row's minimum is zero (used by loss4 below).
    _l = logits_for_tune - tf.expand_dims(
        tf.reduce_min(logits_for_tune, axis=1), axis=1)
    # loss1 (match the aggregate prediction to U) and loss2 (reward confident
    # predictions) are alternative objectives; only loss4 is minimized.
    loss1 = tf.reduce_sum(
        tf.abs(agg_prob -
               (U * tf.cast(tf.shape(probs_for_tune)[0], tf.float32))))
    loss2 = -tf.reduce_mean(
        tf.reduce_sum(probs_for_tune * tf.one_hot(
            tf.argmax(probs_for_tune, axis=1), depth=label_count),
                      axis=1))
    loss4 = -tf.reduce_mean(tf.reduce_sum(probs_for_tune * _l, axis=1))
    loss_for_tune = loss4

    # Same objectives computed from logits_dir; loss_dir is only consumed by
    # the commented-out diagnostic further below.
    _l = logits_dir - tf.expand_dims(tf.reduce_min(logits_dir, axis=1), axis=1)
    _l /= tf.expand_dims(tf.reduce_max(_l, 1), 1)
    loss1 = tf.reduce_sum(
        tf.abs(agg_prob2 -
               (U * tf.cast(tf.shape(probs_dir)[0], tf.float32))))
    loss2 = -tf.reduce_mean(
        tf.reduce_sum(probs_dir * tf.one_hot(tf.argmax(probs_dir, axis=1),
                                             depth=label_count),
                      axis=1))
    loss4 = -tf.reduce_mean(tf.reduce_sum(probs_dir * _l, axis=1))
    loss_dir = loss4

    predicted_tuned_indices = tf.argmax(logits_for_tune, axis=1)
    tune_acc = tf.reduce_mean(
        tf.cast(
            tf.equal(predicted_tuned_indices, tf.argmax(ground_truth_label,
                                                        1)), tf.float32))

    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    predicted_indices = tf.argmax(label_logits, 1)
    expected_indices = tf.argmax(ground_truth_label, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    lvars = [
        var for var in tf.global_variables() if var.name.find("tune_var") < 0
    ]
    saver = tf.train.Saver(lvars)

    # Momentum optimizer with Nesterov acceleration (lr=0.1, momentum=0.1).
    opt = tf.train.MomentumOptimizer(0.1, 0.1, use_nesterov=True)
    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        tune_op = opt.minimize(loss_for_tune, var_list=[tune_var])

    tf.global_variables_initializer().run()

    print(tf.trainable_variables())
    sys.stdout.flush()

    if FLAGS.model_dir:
        saver.restore(sess, FLAGS.model_dir)
        start_step = global_step.eval(session=sess)

    ph_uidx = tf.placeholder(tf.int32, [])
    tf_emb = tf.nn.embedding_lookup(emb_mat, [ph_uidx])
    tf_emb = tf.sigmoid(tf_emb)
    tf_emb /= tf.expand_dims(tf.norm(tf_emb, axis=1), axis=1)

    set_size = audio_processor.set_size('training')
    tf.logging.info('set_size=%d', set_size)

    # Train accuracy is close to 100%, so the full training-split evaluation
    # loop is skipped here.

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    _t1, _t2, _t3 = 0, 0, 0
    _i1, _i2, _i3 = 0, 0, 0
    agg_acc = np.zeros([uid_count])
    for test_fingerprints, test_ground_truth in audio_processor.get_data_per_domain(
            model_settings, 0.0, 0.0, 0, 'validation', sess):
        sess.run(tune_var.initializer)
        label_dist = np.sum(test_ground_truth, axis=0)
        # Skip domains containing silence examples, and cap at 20 domains.
        if label_dist[0] > 0 or _i1 >= 20:
            continue

        print("Num samples: %d label dist: %s" %
              (len(test_fingerprints), label_dist))
        _ln = len(test_ground_truth)

        max_test = -1
        ruid = np.random.randint(0, uid_count)
        all_acc = []
        for uidx in tqdm.tqdm(range(uid_count)):
            uids = np.zeros([_ln, uid_count])
            uids[np.arange(_ln), uidx] = 1

            test_accuracy, conf_matrix = sess.run(
                [evaluation_step, confusion_matrix],
                feed_dict={
                    fingerprint_input: test_fingerprints,
                    ground_truth_label: test_ground_truth,
                    ground_truth_style: uids,
                })
            all_acc.append(test_accuracy)

            if test_accuracy > max_test:
                max_test = test_accuracy
                best_wt = sess.run(tf_emb, feed_dict={ph_uidx: uidx})[0]
            #print ("Test acc: %0.4f" % test_accuracy)
            if uidx == ruid:
                base = test_accuracy

        _t3 += base
        _i3 += 1

        _t1 += max_test
        _i1 += 1.

        agg_acc += np.array(sorted(all_acc))
        print("Base Test acc: %0.4f" % (base))
        print("Best Test acc: %0.4f -- wt: %s" % (max_test, best_wt))

        fd = {
            fingerprint_input: test_fingerprints,
            ground_truth_label: test_ground_truth,
        }
        #     sess.run(tf.assign(tune_var, best_wt))
        #     np_loss_for_best = sess.run(loss_dir, feed_dict=fd)
        #     print ("Loss for best wt: %f" % np_loss_for_best)
        sess.run(tune_var.initializer)
        # Gradient-tune the combination weights on this domain's batch.
        for _it in range(100):
            _, np_l, np_wts, np_wts2 = sess.run(
                [tune_op, loss_for_tune, c_wts, tune_var], feed_dict=fd)
            if _it % 100 == 0:  # printed once, at the first iteration
                print("Loss: %f wts: %s %s" % (np_l, np_wts, np_wts2))
        np_tuned_acc, np_preds = sess.run(
            [tune_acc,
             tf.one_hot(predicted_tuned_indices, label_count)],
            feed_dict=fd)

        _t2 += np_tuned_acc
        _i2 += 1
        print("Tuned acc: %f dist: %s" %
              (100 * np_tuned_acc, np.sum(np_preds, axis=0)))
        #print (conf_matrix)
    print("Defau Avg test accuracy: %f over %d domains" % ((_t3 / _i3), _i3))
    print("Brute Avg test accuracy: %f over %d domains" % ((_t1 / _i1), _i1))
    print("Tuned Avg test accuracy: %f over %d domains" % ((_t2 / _i2), _i2))
    agg_acc /= _i1
    for pi in range(0, 110, 10):
        print(pi, np.percentile(agg_acc, pi))
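
Note: two pieces of this example are easy to miss in the graph code: c_wts is tune_var squashed by a sigmoid and L2-normalized, and loss4 is an unsupervised confidence objective over min-shifted logits. A minimal NumPy restatement of both (a standalone sketch, not project code):

    import numpy as np

    def combination_weights(tune_var):
        # Sigmoid squashes entries to (0, 1); L2 normalization mirrors
        # c_wts = sigmoid(tune_var) / ||sigmoid(tune_var)|| above.
        s = 1.0 / (1.0 + np.exp(-tune_var))
        return s / np.linalg.norm(s)

    def loss4(logits):
        # Row-wise softmax probabilities.
        e = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs = e / e.sum(axis=1, keepdims=True)
        # Shift each row so its minimum logit is zero, then reward probability
        # mass placed on high-logit classes (negated: minimizing it increases
        # prediction confidence).
        shifted = logits - logits.min(axis=1, keepdims=True)
        return -np.mean(np.sum(probs * shifted, axis=1))

    rng = np.random.RandomState(0)
    print(combination_weights(rng.randn(4)))  # unit-norm positive weights
    print(loss4(rng.randn(8, 12)))            # scalar tuning objective
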
Example #2
File: eval_mos2.py Project: vihari/CSD
def get_initialization(_):
    # Start a new TensorFlow session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.training_percentage,
        FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    uid_count = audio_processor.num_uids

    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    net = lambda scope: lambda inp: models.create_model(
        inp, model_settings, FLAGS.model_architecture,
        is_training=True, scope=scope)

    # Define loss and optimizer
    ground_truth_label = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')
    ground_truth_style = tf.placeholder(tf.float32, [None, uid_count],
                                        name='groundtruth_style')

    with tf.variable_scope('crossgrad'):
        label_net = net('label')
        label_embedding = label_net(fingerprint_input)

    with tf.variable_scope(''):
        cross_entropy, label_logits = losses.mos(label_embedding,
                                                 ground_truth_style,
                                                 ground_truth_label,
                                                 label_count, uid_count,
                                                 fingerprint_input, FLAGS)

    tune_var = tf.get_variable("tune_var", [FLAGS.num_uids])
    c_wts = tf.sigmoid(tune_var)
    c_wts /= tf.norm(c_wts)

    with tf.variable_scope('', reuse=True):
        emb_mat = tf.get_variable("emb_mat", [uid_count, FLAGS.num_uids])

        tf_reprs = losses.mos_tune_project(label_embedding)
        logits_for_tune, tf_reprs = losses.mos_tune(tf_reprs, c_wts,
                                                    ground_truth_label,
                                                    label_count, FLAGS)

        sm_w = tf.get_variable("sm_w",
                               shape=[FLAGS.num_uids, 128, label_count])
        sm_bias = tf.get_variable("sm_bias",
                                  shape=[FLAGS.num_uids, label_count])

    probs_for_tune = tf.nn.softmax(logits_for_tune, axis=1)
    agg_prob = tf.reduce_sum(probs_for_tune, axis=0)
    # Supervised objective: plain cross-entropy against the true labels.
    loss_for_tune = tf.losses.softmax_cross_entropy(
        onehot_labels=ground_truth_label, logits=logits_for_tune)

    predicted_tuned_indices = tf.argmax(logits_for_tune, axis=1)
    tune_acc = tf.reduce_mean(
        tf.cast(
            tf.equal(predicted_tuned_indices, tf.argmax(ground_truth_label,
                                                        1)), tf.float32))

    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    predicted_indices = tf.argmax(logits_for_tune, 1)
    expected_indices = tf.argmax(ground_truth_label, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    lvars = [
        var for var in tf.global_variables() if var.name.find("tune_var") < 0
    ]
    saver = tf.train.Saver(lvars)

    opt = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.mom, use_nesterov=True)

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        tune_op = opt.minimize(loss_for_tune, var_list=[tune_var])

    tf.global_variables_initializer().run()

    print(tf.trainable_variables())
    sys.stdout.flush()

    if FLAGS.model_dir:
        saver.restore(sess, FLAGS.model_dir)
        start_step = global_step.eval(session=sess)

    ph_uidx = tf.placeholder(tf.int32, [])
    tf_emb = tf.nn.embedding_lookup(emb_mat, [ph_uidx])
    tf_emb = tf.sigmoid(tf_emb)
    tf_emb /= tf.expand_dims(tf.norm(tf_emb, axis=1), axis=1)

    set_size = audio_processor.set_size('training')
    tf.logging.info('set_size=%d', set_size)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)

    nsteps = 1000
    batch_size = 128
    best_val = -1
    best = None
    for _iter in range(nsteps):
        train_fingerprints, train_ground_truth, train_uids = audio_processor.get_data(
            batch_size, 0, model_settings, 0, 0, 0, 'training', sess)
        fd = {
            fingerprint_input: train_fingerprints,
            ground_truth_label: train_ground_truth,
        }

        _, np_l, np_wts, np_wts2 = sess.run(
            [tune_op, loss_for_tune, c_wts, tune_var], feed_dict=fd)

        if _iter % 50 == 0:
            print("Loss: %f wts: %s %s" % (np_l, np_wts, np_wts2))
            val, test = -1, -1
            for si, split in enumerate(['validation', 'testing']):
                inps, labels, _ = audio_processor.get_data(
                    -1, 0, model_settings, 0.0, 0.0, 0, split, sess)
                fd = {fingerprint_input: inps, ground_truth_label: labels}
                np_acc = sess.run(evaluation_step, feed_dict=fd)
                print("Split: %s -- acc: %f" % (split, np_acc))
                if si == 0:
                    val = np_acc
                else:
                    test = np_acc
            if val > best_val:
                best_val = val
                best = (val, test, np_wts, np_wts2)
    print("Best validation accuracy %f -- %s" % (best_val, best))
Example #3
File: eval_mos2.py Project: vihari/CSD
def eval_vanilla(_):
    """
  Use all the label data to fine-tune the combination weights
  """
    # Start a new TensorFlow session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.training_percentage,
        FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    uid_count = audio_processor.num_uids

    print("Label count: %d uid count: %d" % (label_count, uid_count))
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    net = lambda scope: lambda inp: models.create_model(
        inp, model_settings, FLAGS.model_architecture,
        is_training=True, scope=scope)

    # Define loss and optimizer
    ground_truth_label = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')
    ground_truth_style = tf.placeholder(tf.float32, [None, uid_count],
                                        name='groundtruth_style')

    with tf.variable_scope('crossgrad'):
        label_net = net('label')
        label_embedding = label_net(fingerprint_input)

    with tf.variable_scope(''):
        cross_entropy, label_logits = losses.mos(label_embedding,
                                                 ground_truth_style,
                                                 ground_truth_label,
                                                 label_count, uid_count,
                                                 fingerprint_input, FLAGS)
    tune_var = tf.get_variable("tune_var", [FLAGS.num_uids])
    c_wts = tf.sigmoid(tune_var)
    c_wts /= tf.norm(c_wts)

    with tf.variable_scope('', reuse=True):
        emb_mat = tf.get_variable("emb_mat", [uid_count, FLAGS.num_uids])

        tf_reprs = losses.mos_tune_project(label_embedding)
        logits_for_tune, tf_reprs = losses.mos_tune(tf_reprs, c_wts,
                                                    ground_truth_label,
                                                    label_count, FLAGS)

        sm_w = tf.get_variable("sm_w",
                               shape=[FLAGS.num_uids, 128, label_count])
        sm_bias = tf.get_variable("sm_bias",
                                  shape=[FLAGS.num_uids, label_count])

    probs_for_tune = tf.nn.softmax(logits_for_tune, axis=1)
    agg_prob = tf.reduce_sum(probs_for_tune, axis=0)
    # the first label is silence which is not present for many ids
    loss_for_tune = tf.losses.softmax_cross_entropy(
        onehot_labels=ground_truth_label, logits=logits_for_tune)

    predicted_tuned_indices = tf.argmax(logits_for_tune, axis=1)
    tune_acc = tf.reduce_mean(
        tf.cast(
            tf.equal(predicted_tuned_indices, tf.argmax(ground_truth_label,
                                                        1)), tf.float32))

    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    predicted_indices = tf.argmax(label_logits, 1)
    expected_indices = tf.argmax(ground_truth_label, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    lvars = [
        var for var in tf.global_variables() if var.name.find("tune_var") < 0
    ]
    saver = tf.train.Saver(lvars)

    opt = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.mom, use_nesterov=True)

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        tune_op = opt.minimize(loss_for_tune, var_list=[tune_var])

    tf.global_variables_initializer().run()

    print(tf.trainable_variables())
    sys.stdout.flush()

    if FLAGS.model_dir:
        saver.restore(sess, FLAGS.model_dir)
        start_step = global_step.eval(session=sess)

    ph_uidx = tf.placeholder(tf.int32, [])
    tf_emb = tf.nn.embedding_lookup(emb_mat, [ph_uidx])
    tf_emb = tf.sigmoid(tf_emb)
    tf_emb /= tf.expand_dims(tf.norm(tf_emb, axis=1), axis=1)

    set_size = audio_processor.set_size('training')
    tf.logging.info('set_size=%d', set_size)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)

    for split in ['validation', 'testing']:
        _t1, _t2, _t3 = 0, 0, 0
        _i1, _i2, _i3 = 0, 0, 0
        for test_fingerprints, test_ground_truth in audio_processor.get_data_per_domain(
                model_settings, 0.0, 0.0, 0, split, sess):

            sess.run(tune_var.initializer)
            label_dist = np.sum(test_ground_truth, axis=0)
            if label_dist[0] > 0:
                continue

            print("Num samples: %d label dist: %s" %
                  (len(test_fingerprints), label_dist))
            _ln = len(test_ground_truth)

            max_test = -1
            ruid = np.random.randint(0, uid_count)

            fd = {
                fingerprint_input: test_fingerprints,
                ground_truth_label: test_ground_truth,
            }

            num_try = 10
            agg_acc = np.zeros([num_try + 1])
            all_acc = []
            for uidx in tqdm.tqdm(
                    np.random.choice(np.arange(uid_count),
                                     num_try).tolist() + [ruid]):
                uids = np.zeros([_ln, uid_count])
                uids[np.arange(_ln), uidx] = 1

                test_accuracy, conf_matrix = sess.run(
                    [evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: test_fingerprints,
                        ground_truth_label: test_ground_truth,
                        ground_truth_style: uids,
                    })
                all_acc.append(test_accuracy)

                if test_accuracy > max_test:
                    max_test = test_accuracy
                    best_wt = sess.run(tf_emb, feed_dict={ph_uidx: uidx})[0]
                #print ("Test acc: %0.4f" % test_accuracy)
                if uidx == ruid:
                    base = test_accuracy

            _t3 += base
            _i3 += 1

            _t1 += max_test
            _i1 += 1.

            agg_acc += np.array(sorted(all_acc))
            print("Base Test acc: %0.4f" % (base))
            print("Best Test acc: %0.4f -- wt: %s" % (max_test, best_wt))

            np_emb_wt = sess.run(emb_mat)
            rand_emb = np_emb_wt[ruid, :]
            sess.run(tf.assign(tune_var, rand_emb))

            for _it in range(FLAGS.iters):
                _, np_l, np_wts, np_wts2 = sess.run(
                    [tune_op, loss_for_tune, c_wts, tune_var], feed_dict=fd)

                if _it % 100 == 0:
                    print("Loss: %f wts: %s %s" % (np_l, np_wts, np_wts2))
            np_tuned_acc, np_preds = sess.run(
                [tune_acc,
                 tf.one_hot(predicted_tuned_indices, label_count)],
                feed_dict=fd)

            _t2 += np_tuned_acc
            _i2 += 1
            print("Tuned acc: %f dist: %s" %
                  (100 * np_tuned_acc, np.sum(np_preds, axis=0)))
            #print (conf_matrix)
        print("Defau Avg test accuracy: %f over %d domains" %
              ((_t3 / _i3), _i3))
        print("Brute Avg test accuracy: %f over %d domains" %
              ((_t1 / _i1), _i1))
        print("Tuned Avg test accuracy: %f over %d domains" %
              ((_t2 / _i2), _i2))
        print("")
Example #4
File: eval_mos2.py Project: vihari/CSD
def eval_mos(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    val_fname, test_fname = 'extra-data/all_paths.txt', 'extra-data/all_paths.txt'
    audio_processor = input_data.make_processor(val_fname, test_fname,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.data_dir, model_settings)
    audio_processor.num_uids = FLAGS.training_percentage + 2

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    uid_count = audio_processor.num_uids

    print("Label count: %d uid count: %d" % (label_count, uid_count))

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    net = lambda scope: lambda inp: models.create_model(
        inp, model_settings, FLAGS.model_architecture,
        is_training=True, scope=scope)

    # Define loss and optimizer
    ground_truth_label = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')
    ground_truth_style = tf.placeholder(tf.float32, [None, uid_count],
                                        name='groundtruth_style')
    ph_reprs = tf.placeholder(tf.float32, [None, 128],
                              name='representations_input')

    with tf.variable_scope('crossgrad'):
        label_net = net('label')
        label_embedding = label_net(fingerprint_input)

    with tf.variable_scope(''):
        cross_entropy, label_logits = losses.mos(label_embedding,
                                                 ground_truth_style,
                                                 ground_truth_label,
                                                 label_count, uid_count,
                                                 fingerprint_input, FLAGS)
    tune_var = tf.get_variable("tune_var", [FLAGS.num_uids],
                               initializer=tf.zeros_initializer)

    c_wts = tf.sigmoid(tune_var)
    c_wts /= tf.norm(c_wts)

    with tf.variable_scope('', reuse=True):
        emb_mat = tf.get_variable("emb_mat", [uid_count, FLAGS.num_uids])

        tf_reprs = losses.mos_tune_project(label_embedding)
        logits_for_tune, _ = losses.mos_tune(tf_reprs, c_wts,
                                             ground_truth_label, label_count,
                                             FLAGS)

        logits_for_tune2, _ = losses.mos_tune(ph_reprs, c_wts,
                                              ground_truth_label, label_count,
                                              FLAGS)

        sm_w = tf.get_variable("sm_w",
                               shape=[FLAGS.num_uids, 128, label_count])

    probs_for_tune = tf.nn.softmax(logits_for_tune, axis=1)

    loss_for_tune = tf.losses.softmax_cross_entropy(
        onehot_labels=ground_truth_label, logits=logits_for_tune)
    loss_for_tune2 = tf.losses.softmax_cross_entropy(
        onehot_labels=ground_truth_label, logits=logits_for_tune2)

    predicted_tuned_indices = tf.argmax(logits_for_tune, axis=1)
    tune_acc = tf.reduce_mean(
        tf.cast(
            tf.equal(predicted_tuned_indices, tf.argmax(ground_truth_label,
                                                        1)), tf.float32))

    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    predicted_indices = tf.argmax(label_logits, 1)
    expected_indices = tf.argmax(ground_truth_label, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    common_var = [
        var for var in tf.all_variables() if var.name.find('common_var') >= 0
    ][0]

    lvars = [
        var for var in tf.global_variables() if var.name.find("tune_var") < 0
    ]
    saver = tf.train.Saver(lvars)

    opt = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.mom, use_nesterov=True)

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        tune_op = opt.minimize(loss_for_tune, var_list=[tune_var])
        tune_op2 = opt.minimize(loss_for_tune2, var_list=[tune_var])

    tf.global_variables_initializer().run()

    print(tf.trainable_variables())
    sys.stdout.flush()

    if FLAGS.model_dir:
        saver.restore(sess, FLAGS.model_dir)
        start_step = global_step.eval(session=sess)

    np_common_var = sess.run(common_var)

    ph_uidx = tf.placeholder(tf.int32, [])
    tf_emb = tf.nn.embedding_lookup(emb_mat, [ph_uidx])
    tf_emb = tf.sigmoid(tf_emb)
    tf_emb /= tf.expand_dims(tf.norm(tf_emb, axis=1), axis=1)

    set_size = audio_processor.set_size('training')
    tf.logging.info('set_size=%d', set_size)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    wfname = "results.txt"
    wf = open(wfname, "a")
    params = "|".join([
        str(_) for _ in [
            "mos", FLAGS.lr, FLAGS.mom, FLAGS.iters, FLAGS.inneriters,
            FLAGS.nseeds
        ]
    ])
    wf.write("|" + params)

    # Select, for each predicted label cluster, the points closest to the
    # cluster mean; these act as pseudo-labeled seed examples.
    def select_centers(np_probs, np_reprs, N=20):
        preds = np.argmax(np_probs, axis=1)
        mean_vecs = {}
        count = {}
        for _p in preds:
            count[_p] = count.get(_p, 0) + 1
        for _p in np.unique(preds):
            mean_vecs[_p] = np.mean(np_reprs[np.where(np.equal(preds, _p))],
                                    axis=0)
        label_avg = np.array([mean_vecs[_p] for _p in preds])

        dists = np.linalg.norm(np_reprs - label_avg, axis=1)
        # Keep the N instances nearest to their class centroid.
        return np.argpartition(dists, kth=N)[:N]

    def get_centering_data(np_probs, np_reprs):
        preds = np.argmax(np_probs, axis=1)
        mean_vecs = {}
        count = {}
        for _p in preds:
            count[_p] = count.get(_p, 0) + 1
        for _p in np.unique(preds):
            mean_vecs[_p] = np.mean(np_reprs[np.where(np.equal(preds, _p))],
                                    axis=0)
        label_avg = np.array([mean_vecs[_p] for _p in preds])

        x, y = [], []
        for _p in np.unique(preds):
            if count[_p] >= 0:  # always true: no minimum-count threshold
                _l = np.zeros(np.shape(np_probs)[-1])
                _l[_p] = 1
                x.append(mean_vecs[_p])
                y.append(_l)
        x, y = np.array(x), np.array(y)

        return np.array(x), np.array(y)

    for split in ['validation', 'testing']:
        _t1, _t2, _t3, _t4 = 0, 0, 0, 0
        _i1, _i2, _i3, _i4 = 0, 0, 0, 0
        num_try = 10
        agg_acc = np.zeros([num_try + 1])
        #   agg_acc = np.zeros(uid_count)
        for test_fingerprints, test_ground_truth in audio_processor.get_data_per_domain(
                model_settings, 0.0, 0.0, 0, split, sess):
            sess.run(tune_var.initializer)
            label_dist = np.sum(test_ground_truth, axis=0)
            # ignoring the one domain which has only silence examples
            if label_dist[0] > 0:
                continue

            print("Num samples: %d label dist: %s" %
                  (len(test_fingerprints), label_dist))
            _ln = len(test_ground_truth)

            max_test = -1
            ruid = np.random.randint(0, uid_count)
            all_acc = []
            # Try num_try random domain ids plus the random baseline ruid.
            for uidx in tqdm.tqdm(
                    np.random.choice(np.arange(uid_count),
                                     num_try).tolist() + [ruid]):
                uids = np.zeros([_ln, uid_count])
                uids[np.arange(_ln), uidx] = 1

                test_accuracy, conf_matrix = sess.run(
                    [evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: test_fingerprints,
                        ground_truth_label: test_ground_truth,
                        ground_truth_style: uids,
                    })
                all_acc.append(test_accuracy)

                if test_accuracy > max_test:
                    max_test = test_accuracy
                    best_wt = sess.run(tf_emb, feed_dict={ph_uidx: uidx})[0]
                #print ("Test acc: %0.4f" % test_accuracy)
                if uidx == ruid:
                    base = test_accuracy

            _t3 += base
            _i3 += 1

            _t1 += max_test
            _i1 += 1.

            agg_acc += np.array(sorted(all_acc))

            fd = {
                fingerprint_input: test_fingerprints,
                ground_truth_label: test_ground_truth,
            }
            np_emb_wt = sess.run(emb_mat)
            rand_emb = np_emb_wt[ruid, :]  # kept for reference; unused below
            # Initialize tune_var from the learned common component.
            sess.run(tf.assign(tune_var, np_common_var))
            np_tuned_acc = sess.run(tune_acc, feed_dict=fd)
            _t4 += np_tuned_acc
            _i4 += 1

            print("Base Test acc: %0.4f" % (base))
            print("Common Test acc: %0.4f" % (np_tuned_acc))
            print("Best Test acc: %0.4f -- wt: %s" % (max_test, best_wt))

            for _it in range(FLAGS.iters):
                np_probs, np_reprs = sess.run([probs_for_tune, tf_reprs],
                                              feed_dict=fd)
                # Seed selection via class centers; the else branch below is
                # disabled.
                if True:
                    c_idxs = select_centers(np_probs, np_reprs, FLAGS.nseeds)

                    all_input = fd[fingerprint_input]
                    all_labels = np_probs
                    new_inp_pl, new_label_pl = [], []
                    for ci in c_idxs:
                        new_inp_pl.append(all_input[ci])
                        _pred = np.argmax(all_labels[ci])
                        z = np.zeros([label_count], np.int32)
                        z[_pred] = 1
                        new_label_pl.append(z)
                    select_insts = {}
                    select_insts[fingerprint_input] = new_inp_pl
                    select_insts[ground_truth_label] = new_label_pl

                    for _ in range(FLAGS.inneriters):
                        _, np_l, np_wts, np_wts2 = sess.run(
                            [tune_op, loss_for_tune, c_wts, tune_var],
                            feed_dict=select_insts)
                else:
                    new_inp_pl, new_label_pl = get_centering_data(
                        np_probs, np_reprs)
                    select_insts = {}
                    select_insts[ph_reprs] = new_inp_pl
                    select_insts[ground_truth_label] = new_label_pl

                    for _ in range(FLAGS.inneriters):
                        _, np_l, np_wts, np_wts2 = sess.run(
                            [tune_op2, loss_for_tune2, c_wts, tune_var],
                            feed_dict=select_insts)

                if _it % 50 == 0:
                    print("Loss: %f wts: %s %s" % (np_l, np_wts, np_wts2))
            np_tuned_acc, np_preds = sess.run(
                [tune_acc,
                 tf.one_hot(predicted_tuned_indices, label_count)],
                feed_dict=fd)

            _t2 += np_tuned_acc
            _i2 += 1
            print("Tuned acc: %f dist: %s" %
                  (100 * np_tuned_acc, np.sum(np_preds, axis=0)))
            #print (conf_matrix)
        print("Defau Avg test accuracy: %f over %d domains" %
              ((_t3 / _i3), _i3))
        print("Brute Avg test accuracy: %f over %d domains" %
              ((_t1 / _i1), _i1))
        print("Tuned Avg test accuracy: %f over %d domains" %
              ((_t2 / _i2), _i2))
        print("Common Avg test accuracy: %f over %d domains" %
              ((_t4 / _i4), _i4))
        fields = [
            "%0.2f" % (100 * _)
            for _ in [(_t3 / _i3), (_t1 / _i1), (_t2 / _i2), (_t4 / _i4)]
        ]
        wf.write("|" + "|".join(fields))

        agg_acc /= _i1
        for pi in range(0, 110, 10):
            print(pi, np.percentile(agg_acc, pi))
    wf.write("|\n")
    wf.close()
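
Note: select_centers is the heart of the tuning loop in this example: it pseudo-labels a batch with argmax predictions and keeps the N points nearest the mean representation of their predicted class. A standalone restatement with a toy check (same logic, compacted):

    import numpy as np

    def select_centers(np_probs, np_reprs, N=20):
        # Distance of each representation to the mean representation of its
        # predicted class; keep the N closest points as seeds.
        preds = np.argmax(np_probs, axis=1)
        means = {p: np_reprs[preds == p].mean(axis=0)
                 for p in np.unique(preds)}
        label_avg = np.array([means[p] for p in preds])
        dists = np.linalg.norm(np_reprs - label_avg, axis=1)
        return np.argpartition(dists, kth=N)[:N]

    # Two tight clusters plus a few spread-out points: the selected seeds
    # come from the cluster cores, not the outliers.
    rng = np.random.RandomState(0)
    reprs = np.concatenate([rng.randn(50, 8) * 0.1,
                            rng.randn(50, 8) * 0.1 + 5.0,
                            rng.randn(10, 8) * 3.0 + 2.5])
    probs = np.zeros((110, 2))
    probs[:50, 0] = probs[50:, 1] = 1.0
    print(sorted(select_centers(probs, reprs, N=10)))
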
Example #5
File: eval_mos2.py Project: vihari/CSD
def eval_common(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.training_percentage,
        FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    uid_count = audio_processor.num_uids

    print("Label count: %d uid count: %d" % (label_count, uid_count))

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    net = lambda scope: lambda inp: models.create_model(
        inp, model_settings, FLAGS.model_architecture,
        is_training=True, scope=scope)

    # Define loss and optimizer
    ground_truth_label = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')
    ground_truth_style = tf.placeholder(tf.float32, [None, uid_count],
                                        name='groundtruth_style')
    ph_reprs = tf.placeholder(tf.float32, [None, 128],
                              name='representations_input')

    with tf.variable_scope('crossgrad'):
        label_net = net('label')
        label_embedding = label_net(fingerprint_input)

    with tf.variable_scope(''):
        cross_entropy, _, label_logits = losses.mos(label_embedding,
                                                    ground_truth_style,
                                                    ground_truth_label,
                                                    label_count, uid_count,
                                                    fingerprint_input, FLAGS)

    predicted_indices = tf.argmax(label_logits, 1)
    expected_indices = tf.argmax(ground_truth_label, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    common_var = [
        var for var in tf.all_variables() if var.name.find('common_var') >= 0
    ][0]

    lvars = [
        var for var in tf.global_variables() if var.name.find("tune_var") < 0
    ]
    saver = tf.train.Saver(lvars)

    tf.global_variables_initializer().run()

    sys.stdout.flush()

    if FLAGS.model_dir:
        saver.restore(sess, FLAGS.model_dir)
        start_step = global_step.eval(session=sess)

    np_common_var = sess.run(common_var)

    set_size = audio_processor.set_size('training')
    tf.logging.info('set_size=%d', set_size)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    wfname = "results.org"
    wf = open(wfname, "a")
    #   wf.write(FLAGS.model_dir + "\n")
    params = "|".join([str(_) for _ in ["mos"]])
    #   wf.write("|" + params)

    for split in ['validation', 'testing']:
        test_fingerprints, test_ground_truth, _ = audio_processor.get_data(
            -1, 0, model_settings, 0.0, 0.0, 0, split, sess)
        np_acc = sess.run(evaluation_step,
                          feed_dict={
                              fingerprint_input: test_fingerprints,
                              ground_truth_label: test_ground_truth
                          })
        print("Accuracy: %f" % np_acc)
Example #6
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    # Begin by making sure we have the training data we need. If you already have
    # training data of your own, use `--data_url= ` on the command line to avoid
    # downloading.
    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.training_percentage,
        FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    uid_count = audio_processor.num_uids

    print("Label count: %d uid count: %d" % (label_count, uid_count))
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed
    # by lower levels towards the end, training is split into phases: the
    # number of steps and the learning rate for each phase are given by the
    # parallel lists training_steps_list and learning_rates_list below (here,
    # a main phase at --learning_rate followed by 1,200 fine-tuning steps at
    # 1e-4).
    set_size = audio_processor.set_size('training')
    tf.logging.info('Train set_size=%d', set_size)

    lr = float(FLAGS.learning_rate)
    FLAGS.how_many_training_steps = (FLAGS.how_many_epochs *
                                     set_size) // FLAGS.batch_size
    print("Running for %d training steps with lr: %f" %
          (FLAGS.how_many_training_steps, lr))

    training_steps_list = [FLAGS.how_many_training_steps, 1200]
    learning_rates_list = [lr, 1e-4]
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    net = lambda scope: lambda inp: models.create_model(
        inp, model_settings, FLAGS.model_architecture,
        is_training=True, scope=scope)

    # Define loss and optimizer
    ground_truth_label = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')
    ground_truth_style = tf.placeholder(tf.float32, [None, uid_count],
                                        name='groundtruth_style')

    # the supervised bunch
    if FLAGS.model in [
            "cg", "ucgv3", "ucgv6", "ucgv7", "ucgv8", "ucgv8_wosx", "ucgv11",
            "ucgv12", "ucgv13"
    ]:
        with tf.variable_scope('crossgrad'):
            label_net_fn = net('label')
            style_net_fn = net('style')

            label_embedding = label_net_fn(fingerprint_input)
            style_embedding = style_net_fn(fingerprint_input)

        scales = {"epsilon": FLAGS.epsilon, "alpha": FLAGS.alpha}

        if FLAGS.model == "cg":
            cross_entropy, label_logits = losses.cg_losses(
                label_embedding,
                style_embedding,
                ground_truth_style,
                ground_truth_label,
                uid_count,
                label_count,
                fingerprint_input,
                scales,
                label_net_fn=label_net_fn,
                style_net_fn=style_net_fn)
    elif FLAGS.model == "mos" or FLAGS.model == 'mos2' or FLAGS.model == "simple":
        with tf.variable_scope('crossgrad'):
            label_net = net('label')
            label_embedding = label_net(fingerprint_input)

        if FLAGS.model == "mos":
            cross_entropy, _, label_logits = losses.mos(
                label_embedding, ground_truth_style, ground_truth_label,
                label_count, uid_count, fingerprint_input, FLAGS)
        elif FLAGS.model == "mos2":
            cross_entropy, _, label_logits = losses.mos2(
                label_embedding, ground_truth_style, ground_truth_label,
                label_count, uid_count, fingerprint_input, FLAGS)
        elif FLAGS.model == "simple":
            cross_entropy, label_logits = losses.simple(
                label_embedding, ground_truth_style, ground_truth_label,
                label_count, uid_count, fingerprint_input, FLAGS)

    else:
        raise NotImplementedError('Unknown model: %s' % FLAGS.model)

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES):
        tf.summary.scalar('losses/%s' % loss.op.name, loss)

    total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('total_loss', total_loss)

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        train_step = tf.train.MomentumOptimizer(
            learning_rate_input, momentum=0.9,
            use_nesterov=True).minimize(total_loss)
        emb_mat = [
            var for var in tf.trainable_variables()
            if var.name.find('emb_mat') >= 0
        ]

    predicted_indices = tf.argmax(label_logits, 1)
    expected_indices = tf.argmax(ground_truth_label, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    print(tf.trainable_variables())
    sys.stdout.flush()

    start_step = 1

    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)
    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    training_steps_max = np.sum(training_steps_list)
    sess.graph.finalize()
    best_ind_val_acc, total_accuracy = -1, -1
    for training_step in xrange(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        training_steps_sum = 0
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break
        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth, train_uids = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)

        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy, total_loss,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_label: train_ground_truth,
                ground_truth_style: train_uids,
                learning_rate_input: learning_rate_value
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            set_size = audio_processor.set_size('ind-validation')
            total_accuracy = 0
            for i in xrange(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth, validation_uids = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'ind-validation', sess))
                # Run a validation step and capture training summaries for TensorBoard
                # with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_label: validation_ground_truth,
                        ground_truth_style: validation_uids,
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
            tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

        # Save a checkpoint whenever in-domain validation accuracy improves.
        if total_accuracy > best_ind_val_acc:
            best_ind_val_acc = total_accuracy
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    for i in xrange(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth, test_uids = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)

        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_label: test_ground_truth,
                ground_truth_style: test_uids,
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
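
Note: the learning-rate lookup at the top of the training loop implements a piecewise-constant schedule over the phases in training_steps_list. Extracted as a standalone helper (the function name is ours) with the 10,000/3,000-step example from the comment above:

    def learning_rate_for_step(step, steps_list, rates_list):
        # Piecewise-constant schedule: phase i lasts steps_list[i] steps and
        # uses rates_list[i]. E.g. steps [10000, 3000] with rates
        # [0.001, 0.0001] run 13,000 steps total: 0.001 for the first 10,000,
        # then 0.0001 for the final 3,000.
        boundary = 0
        for steps, rate in zip(steps_list, rates_list):
            boundary += steps
            if step <= boundary:
                return rate
        return rates_list[-1]

    assert learning_rate_for_step(1, [10000, 3000], [1e-3, 1e-4]) == 1e-3
    assert learning_rate_for_step(10001, [10000, 3000], [1e-3, 1e-4]) == 1e-4
    assert learning_rate_for_step(13000, [10000, 3000], [1e-3, 1e-4]) == 1e-4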