Example #1
    def optimize(self):
        utils.create_log_dir(self.hp_utils.save_path)
        if not self.restart:
            utils.check_conflict(self.hp_utils.save_path, check=self.check)
        utils.create_log_dir(self.hp_utils.save_path)
        dy = {y_name: None for y_name in self.hp_utils.y_names}
        self.hp_utils.save_hp_conf(list(range(len(self.hp_utils.var_names))),
                                   dy,
                                   None,
                                   record=False)

        self.n_jobs = self.get_n_jobs()
        self.check_double_submission()

        save_time = utils.save_elapsed_time(self.hp_utils.save_path,
                                            self.hp_utils.lock, self.verbose,
                                            self.print_freq)

        if self.n_parallels <= 1:
            self._optimize_sequential(save_time)
        else:
            self._optimize_parallel(save_time)

        hps_conf, losses = self.hp_utils.load_hps_conf(do_sort=True)
        best_hp_conf, best_performance = hps_conf[0], losses[0][0]
        self.print_optimized_result(best_hp_conf, best_performance)

        return best_hp_conf, best_performance
Example #2
    def optimize(self):
        utils.create_log_dir(self.hp_utils.save_path)
        if not self.restart:
            utils.check_conflict(self.hp_utils.save_path)
        utils.create_log_dir(self.hp_utils.save_path)

        self.n_jobs = self.get_n_jobs()

        if self.n_parallels <= 1:
            self._optimize_sequential()
        else:
            self._optimize_parallel()

        hps_conf, losses = self.hp_utils.load_hps_conf(do_sort=True)
        best_hp_conf, best_performance = hps_conf[0], losses[0]
        self.print_optimized_result(best_hp_conf, best_performance[0])

        return best_hp_conf, best_performance
Example #3
def main(args):
    # I/O
    config_file = args.config_file
    config = utils.import_file(config_file, 'config')

    #trainset = utils.Dataset(config.train_dataset_path)
    testset = utils.Dataset(config.test_dataset_path)

    network = BaseNetwork()
    network.initialize(config, 0)  # trainset.num_classes


    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model is not None:
        network.restore_model(config.restore_model, config.restore_scopes)

    # Set up LFW test protocol and load images
    print('Loading images...')
    lfwtest = LFWTest(testset.images)
    lfwtest.init_standard_proto(config.lfw_pairs_file)
    lfwtest.images = utils.preprocess(lfwtest.image_paths, config, is_training=False)


    #trainset.start_batch_queue(config, True)


    #
    # Main Loop
    #
    print('\nStart Training\nname: %s\n# epochs: %d\nepoch_size: %d\nbatch_size: %d\n'\
        % (config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0

    # Testing on LFW
    print('Testing on Neetis LFW protocol...')
    embeddings = network.extract_feature(lfwtest.images, config.batch_size)
    print(type(embeddings))

    accuracy_embeddings, threshold_embeddings = lfwtest.test_standard_proto(embeddings)
    print('Embeddings Accuracy: %2.4f Threshold %2.3f' % (accuracy_embeddings, threshold_embeddings))

    with open(os.path.join(log_dir,'lfw_result.txt'),'at') as f:
        f.write('%d\t%.5f\n' % (global_step,accuracy_embeddings))
    summary = tf.Summary()
    summary.value.add(tag='lfw/accuracy', simple_value=accuracy_embeddings)
    summary_writer.add_summary(summary, global_step)
Example #4
def run():
    args = argparser()

    path = utils.create_log_dir(sys.argv)
    utils.start(args.http_port)

    env = Env(args)
    agents = [Agent(args) for _ in range(args.n_agent)]
    master = Master(args)

    for agent in agents:
        master.add_agent(agent)
    master.add_env(env)

    success_list = []
    time_list = []

    for idx in range(args.n_episode):
        print('=' * 80)
        print("Episode {}".format(idx + 1))
        # Reset the server's stack and timer
        print("Resetting the server...")
        master.reset(path)

        # Start the episode
        master.start()
        # Train the agents
        master.train()
        print('=' * 80)
        success_list.append(master.infos["is_success"])
        time_list.append(master.infos["end_time"] - master.infos["start_time"])

        if (idx + 1) % args.print_interval == 0:
            print("=" * 80)
            print("EPISODE {}: Avg. Success Rate / Time: {:.2} / {:.2}".format(
                idx + 1, np.mean(success_list), np.mean(time_list)))
            success_list.clear()
            time_list.clear()
            print("=" * 80)

        if (idx + 1) % args.checkpoint_interval == 0:
            utils.save_checkpoints(path, agents, idx + 1)

    if args.visual:
        visualize(path, args)
    print("끝")
    utils.close()
Example #5
def main(args):
    # I/O
    config_file = args.config_file
    config = utils.import_file(config_file, 'config')

    trainset = utils.Dataset(config.train_dataset_path)
    testset = utils.Dataset(config.test_dataset_path)

    network = BaseNetwork()
    network.initialize(config, trainset.num_classes)

    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model is not None:
        network.restore_model(config.restore_model, config.restore_scopes)

    # Set up LFW test protocol and load images
    print('Loading images...')
    lfwtest = LFWTest(testset.images)
    lfwtest.init_standard_proto(config.lfw_pairs_file)
    lfwtest.images = utils.preprocess(lfwtest.image_paths,
                                      config,
                                      is_training=False)

    trainset.start_batch_queue(config, True)

    #
    # Main Loop
    #
    print('\nStart Training\nname: %s\n# epochs: %d\nepoch_size: %d\nbatch_size: %d\n'\
        % (config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()

    for epoch in range(config.num_epochs):

        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(
                global_step, config)
            batch = trainset.pop_batch_queue()

            wl, sm, global_step = network.train(batch['images'],
                                                batch['labels'], learning_rate,
                                                config.keep_prob)

            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
                summary_writer.add_summary(sm, global_step=global_step)

        # Testing on LFW
        print('Testing on standard LFW protocol...')
        embeddings = network.extract_feature(lfwtest.images, config.batch_size)
        accuracy_embeddings, threshold_embeddings = lfwtest.test_standard_proto(
            embeddings)
        print('Embeddings Accuracy: %2.4f Threshold %2.3f' %
              (accuracy_embeddings, threshold_embeddings))

        with open(os.path.join(log_dir, 'lfw_result.txt'), 'at') as f:
            f.write('%d\t%.5f\n' % (global_step, accuracy_embeddings))
        summary = tf.Summary()
        summary.value.add(tag='lfw/accuracy', simple_value=accuracy_embeddings)
        summary_writer.add_summary(summary, global_step)

        # Save the model
        network.save_model(log_dir, global_step)
Example #6
                        help='Number of projections to test')
    parser.add_argument('-e',
                        '--eps',
                        nargs='?',
                        default=0.001,
                        type=float,
                        help='Epsilon for entropic gw')

    args = parser.parse_args()
    all_samples = args.all_samples
    numItermax = args.nitermax
    projs = args.proj
    epsilon = args.eps
    niterentro = args.it_entro

    log_dir = utils.create_log_dir(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    results = {}

    results['all_samples'] = all_samples
    results['projs'] = projs
    results['epsilon'] = epsilon

    results['all_gw'] = []
    results['t_all_gw'] = []
    results['all_converged'] = []

    results['all_entro_gw'] = []
    results['t_all_entro_gw'] = []
Example #7
from os.path import join
import utils
import builders
import args
from core.base import Trainer
from core.plugins.logging import TrainLogger, MetricFileLogger, HeaderPrinter
from core.plugins.storage import BestModelSaver, ModelLoader
from metrics import KNNF1ScoreMetric, Evaluator

# Script arguments
args = args.get_args()

# Create directory to save plots, models, results, etc
log_path = utils.create_log_dir(args.exp_path, args)
print(f"Logging to {log_path}")

# Dump all script arguments
utils.dump_params(join(log_path, 'config.cfg'), args)

# Set custom seed
utils.set_custom_seed(args.seed)

# Load dataset and create model
print(f"[Model: {args.model.upper()}]")
print(f"[Loss: {args.loss.upper()}]")
print('[Loading Dataset and Model...]')

# Embedding dim is 768 based on BERT, we use the same for LSTM to be fair
# Classes are 6 because we include a 'non-misogyny class'
nfeat, nclass = 768, 6
config = builders.build_config(args, nfeat, nclass)
Example #8
def main(args):
    # I/O
    config_file = args.config_file
    config = utils.import_file(config_file, 'config')

    trainset = utils.Dataset(config.train_dataset_path)
    testset = utils.Dataset(config.test_dataset_path)

    network = SiblingNetwork()
    network.initialize(config, trainset.num_classes)


    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    # Set up test protocol and load images
    print('Loading images...')
    testset.separate_template_and_probes()
    testset.images = utils.preprocess(testset.images, config, is_training=False)


    trainset.start_batch_queue(config, True)


    #
    # Main Loop
    #
    print('\nStart Training\nname: %s\n# epochs: %d\nepoch_size: %d\nbatch_size: %d\n'\
        % (config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()

    for epoch in range(config.num_epochs):

        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            image_batch, label_batch = trainset.pop_batch_queue()
        
            switch_batch = utils.zero_one_switch(len(image_batch))
            wl, sm, global_step = network.train(image_batch, label_batch, switch_batch, learning_rate, config.keep_prob)

            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
                summary_writer.add_summary(sm, global_step=global_step)

        # Testing
        print('Testing...')
        switch = utils.zero_one_switch(len(testset.images))
        embeddings = network.extract_feature(testset.images, switch, config.batch_size)
        tars, fars, _ = utils.test_roc(embeddings, FARs=[1e-4, 1e-3, 1e-2])
        with open(os.path.join(log_dir,'result.txt'),'at') as f:
            for i in range(len(tars)):
                print('[%d] TAR: %2.4f FAR %2.3f' % (epoch+1, tars[i], fars[i]))
                f.write('[%d] TAR: %2.4f FAR %2.3f\n' % (epoch+1, tars[i], fars[i]))
                summary = tf.Summary()
                summary.value.add(tag='test/tar_%d'%i, simple_value=tars[i])
                summary_writer.add_summary(summary, global_step)

        # Save the model
        network.save_model(log_dir, global_step)
Example #9
parser.add_argument('--config_file',
                    help='Path to training configuration file',
                    type=str)
config_file = parser.parse_args().config_file
# I/O
config = utils.import_file(config_file, 'config')

trainset = utils.Dataset(config.splits_path + '/train_' +
                         str(config.fold_number) + '.txt')
trainset.images = utils.preprocess(trainset.images, config, True)

network = Network()
network.initialize(config, trainset.num_classes)

# Initialization for running
log_dir = utils.create_log_dir(config, config_file)
summary_writer = tf.summary.FileWriter(log_dir, network.graph)
if config.restore_model:
    network.restore_model(config.restore_model, config.restore_scopes)

# Load gallery and probe file_list
print('Loading images...')
probes = []
gal = []
with open(
        config.splits_path + '/fold_' + str(config.fold_number) +
        '/probe_1.txt', 'r') as f:
    for line in f:
        probes.append(line.strip())

probe_set = evaluate.ImageSet(probes, config)
Example #10
    if args.alpha == -8000 and not args.automatic_cv_alpha:
        raise AlphaMustBeDefinedError('You must set alpha via -a or use automatic grid via -cva')

    name = 'fgw' + '_' + args.dataset_name + '_feature_metric_' + args.feature_metric + '_structure_metric_' + args.structure_metric
    if args.wl_feature > 0:
        name = name + '_wl_' + str(args.wl_feature)
    name = name + args.optionnal_name
    
    try:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
    except OSError:
        raise

    log_dir = create_log_dir(args)

    if args.Csvm != -1:
        Clist = [args.Csvm]
    else:
        Clist = list(np.logspace(-4, 4, 15))

    if args.automatic_cv_alpha:
        nb = 15
        N = int(15 / 3)
        a = np.logspace(-6, -1, N)
        c = 1 - a
        b = np.array(list(set(np.linspace(a[0], c[0], N)).difference(set((a[0], c[0])))))
        alphalist = np.concatenate((a, b, c))
        alpha_list = list(np.sort(np.append([0, 1], alphalist)))
    else:
Example #11
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Get the TF logger and write its output to the specified log file
    import utils
    utils.create_log_dir(FLAGS.train_dir, FLAGS.log_path)

    # Begin by making sure we have the training data we need. If you already have
    # training data of your own, use `--data_url= ` on the command line to avoid
    # downloading.
    # Model settings: word list, sample rate, clip duration, window size, window stride, MFCC coefficient count
    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    # Data URL, local data directory, silence percentage, unknown-word percentage, validation split, test split
    audio_processor = input_data.AudioProcessor(FLAGS.data_url, FLAGS.data_dir,
                                                FLAGS.silence_percentage,
                                                FLAGS.unknown_percentage,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.validation_percentage,
                                                FLAGS.testing_percentage,
                                                model_settings)

    # Size of the input features (fingerprint)
    fingerprint_size = model_settings['fingerprint_size']
    # Number of labels
    label_count = model_settings['label_count']
    # Time shift in samples
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    # Empirically, accuracy rises quickly at the start of training and improves more slowly near the
    # end, so the learning rate is scheduled over time: large at first, smaller later. For example,
    # with 13,000 total steps, the first 10,000 use a rate of 0.001 and the final 3,000 use 0.0001, e.g.:
    # training_steps_list=10000,10000,10000
    # learning_rates_list=0.0005, 0.0001, 0.00002
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    # Placeholder for the input fingerprints
    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    # Create the model
    # logits: predicted label scores; dropout_prob: dropout placeholder
    logits, dropout_prob = models.create_model(fingerprint_input,
                                               model_settings,
                                               FLAGS.model_architecture,
                                               FLAGS.model_size_info,
                                               is_training=True)

    # Define loss and optimizer
    # Placeholder for the ground-truth labels
    ground_truth_input = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    # Cross-entropy loss (the objective to optimize)
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth_input,
                                                    logits=logits))
    # Summaries are mainly for visualization; scalar() records this value as a scalar
    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    # Batch normalization (BN) update ops
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.name_scope('train'), tf.control_dependencies(
            update_ops), tf.control_dependencies(control_dependencies):
        # Learning rate placeholder
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        # Optimizer
        train_op = tf.train.AdamOptimizer(learning_rate_input)
        # Build the train op from the cross-entropy loss and the optimizer
        train_step = slim.learning.create_train_op(cross_entropy_mean,
                                                   train_op)

    #    train_step = tf.train.GradientDescentOptimizer(
    #        learning_rate_input).minimize(cross_entropy_mean)

    # Indices of the most probable labels: predicted and expected
    predicted_indices = tf.argmax(logits, 1)
    expected_indices = tf.argmax(ground_truth_input, 1)
    # 0/1 values indicating whether each prediction is correct
    correct_prediction = tf.equal(predicted_indices, expected_indices)

    # Confusion matrix: shows model performance directly; the more mass on the diagonal, the better
    # https://www.zhihu.com/question/36883196
    confusion_matrix = tf.confusion_matrix(expected_indices,
                                           predicted_indices,
                                           num_classes=label_count)

    # Accuracy: the mean of the 0/1 correctness values
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    # Global step counter
    global_step = tf.train.get_or_create_global_step()
    # Op to increment the global step by 1
    increment_global_step = tf.assign(global_step, global_step + 1)

    # Model checkpoint saver
    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /Users/zoutai/ML_KWS/retrain_logs (by default)
    # Save all summaries to file
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    # Initialize global variables
    tf.global_variables_initializer().run()

    # Parameter counts
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()),
                         params))
    # print('Total number of Parameters: ', num_params)
    tf.logging.info('Total number of Parameters: %d ', num_params)

    start_step = 1

    # If a starting checkpoint is provided, resume from it; otherwise train from scratch
    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    start_step = start_step // 2
    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    best_accuracy = 0
    # Total number of training steps
    training_steps_max = np.sum(training_steps_list)
    for training_step in range(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        # Cumulative step count across the training phases
        training_steps_sum = 0

        # Find which phase the current step falls in
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break

        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)

        # 1. Training
        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_prob: 1.0
            })

        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.2f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))

        # 2. Validation
        # True when this is the final training step
        is_last_step = (training_step == training_steps_max)
        # Evaluate every eval_step_interval steps (typically 500), and once at the final step
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            # Run over the validation set
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in range(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))

                # Run a validation step and capture training summaries for TensorBoard
                # with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.2f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

            # Save the model checkpoint when validation accuracy improves
            if total_accuracy > best_accuracy:
                best_accuracy = total_accuracy
                checkpoint_path = os.path.join(
                    FLAGS.train_dir, 'best', FLAGS.model_architecture + '_' +
                    str(int(best_accuracy * 10000)) + '.ckpt')
                tf.logging.info('Saving best model to "%s-%d"',
                                checkpoint_path, training_step)
                saver.save(sess, checkpoint_path, global_step=training_step)
            tf.logging.info('So far the best validation accuracy is %.2f%%' %
                            (best_accuracy * 100))

    # 3. Testing
    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in range(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.2f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
Example #12
def train(config_file, counter):
    # I/O
    config = utils.import_file(config_file, 'config')
    splits_path = config.splits_path + '/split{}'.format(counter)

    trainset = utils.Dataset(splits_path + '/train_' + str(config.fold_number) + '.txt')
    trainset.images = utils.preprocess(trainset.images, config, True)

    network = Network()
    network.initialize(config, trainset.num_classes)

    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    # Load gallery and probe file_list
    print('Loading images...')
    probes = []
    gal = []
    with open(splits_path + '/fold_' + str(config.fold_number) + '/probe_1.txt', 'r') as f:
        for line in f:
            probes.append(line.strip())

    probe_set = evaluate.ImageSet(probes, config)
    #probe_set.extract_features(network, len(probes))
    #
    with open(splits_path + '/fold_'+ str(config.fold_number) + '/gal_1.txt', 'r') as f:
        for line in f:
            gal.append(line.strip())
    gal_set = evaluate.ImageSet(gal, config)
    #gal_set.extract_features(network, len(gal))

    trainset.start_batch_queue(config, True)

    #
    # Main Loop
    #
    print('\nStart Training\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'.\
        format(config.num_epochs, config.epoch_size, config.batch_size))

    global_step = 0
    start_time = time.time()
    for epoch in range(config.num_epochs):
        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            image_batch, label_batch = trainset.pop_batch_queue()

            wl, sm, global_step = network.train(image_batch, label_batch, learning_rate, config.keep_prob)

            # Display
            if step % config.summary_interval == 0:
                # visualize.scatter2D(_prelogits[:,:2], _label_batch, _pgrads[0][:,:2])
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
                summary_writer.add_summary(sm, global_step=global_step)

        # Testing
        print('Testing...')
        probe_set.extract_features(network, len(probes))
        gal_set.extract_features(network, len(gal))

        rank1, rank5 = evaluate.identify(log_dir, probe_set, gal_set)
        print('rank-1: {:.3f}, rank-5: {:.3f}'.format(rank1[0], rank5[0]))
        
        # Output test result
        summary = tf.Summary()
        summary.value.add(tag='identification/rank1', simple_value=rank1[0])
        summary.value.add(tag='identification/rank5', simple_value=rank5[0])
        summary_writer.add_summary(summary, global_step)

        # Save the model
        network.save_model(log_dir, global_step)
    results_copy = os.path.join('log/result_{}_{}.txt'.format(config.model_version, counter))
    shutil.copyfile(os.path.join(log_dir,'result.txt'), results_copy)
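
Note on the helper itself: every example above calls a project-specific utils.create_log_dir, and the signature differs from project to project (a save path, a config object plus its config file, an argparse namespace, or the raw command line). As a rough orientation only, here is a minimal hypothetical sketch, assuming the helper simply creates a timestamped directory and returns its path; the name and arguments below are illustrative and do not correspond to any single example above.

import os
import time


def create_log_dir(base_dir, name='run'):
    """Create a timestamped log directory under base_dir (if needed) and return its path."""
    # Hypothetical sketch only: real projects add config dumps, conflict checks, etc.
    log_dir = os.path.join(base_dir, '{}_{}'.format(name, time.strftime('%Y%m%d-%H%M%S')))
    os.makedirs(log_dir, exist_ok=True)  # do not fail if the directory already exists
    return log_dir

A caller would then write something like log_dir = create_log_dir('./logs', name='experiment1') and pass log_dir to its summary writer or checkpoint saver, as the examples above do with their own versions of the helper.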