コード例 #1
0
    def optimize(self):
        """Prepare the log directory, run the search and report the best result.

        The search runs via ``self._optimize_sequential`` when
        ``self.n_parallels <= 1`` and via ``self._optimize_parallel``
        otherwise. Both receive the object returned by
        ``utils.save_elapsed_time``. Afterwards the stored configurations are
        reloaded with ``do_sort=True`` and the first entry is reported.

        Returns:
            tuple: ``(best_hp_conf, best_performance)`` where ``best_hp_conf``
            is ``hps_conf[0]`` and ``best_performance`` is ``losses[0][0]``
            as returned by ``self.hp_utils.load_hps_conf(do_sort=True)``.
        """
        log_path = self.hp_utils.save_path
        utils.create_log_dir(log_path)
        if not self.restart:
            utils.check_conflict(log_path, check=self.check)
        # Created a second time on purpose. NOTE(review): presumably because
        # check_conflict may remove the directory; confirm in utils.
        utils.create_log_dir(log_path)

        # Save a placeholder entry (one index per variable, None for every
        # objective) with record=False.
        empty_objectives = dict.fromkeys(self.hp_utils.y_names)
        var_indices = list(range(len(self.hp_utils.var_names)))
        self.hp_utils.save_hp_conf(var_indices, empty_objectives, None,
                                   record=False)

        self.n_jobs = self.get_n_jobs()
        self.check_double_submission()

        save_time = utils.save_elapsed_time(self.hp_utils.save_path,
                                            self.hp_utils.lock,
                                            self.verbose,
                                            self.print_freq)

        runner = (self._optimize_sequential
                  if self.n_parallels <= 1
                  else self._optimize_parallel)
        runner(save_time)

        hps_conf, losses = self.hp_utils.load_hps_conf(do_sort=True)
        best_hp_conf = hps_conf[0]
        best_performance = losses[0][0]
        self.print_optimized_result(best_hp_conf, best_performance)
        return best_hp_conf, best_performance
コード例 #2
0
    def optimize(self):
        """Prepare the log directory, run the search and report the best result.

        Dispatches to ``self._optimize_sequential`` when
        ``self.n_parallels <= 1`` and to ``self._optimize_parallel``
        otherwise. Then it reloads the stored configurations sorted
        (``do_sort=True``) and reports the first one.

        Returns:
            tuple: ``(best_hp_conf, best_performance)`` where
            ``best_performance`` is the whole ``losses[0]`` entry. Only its
            first element is passed to ``print_optimized_result``.
        """
        log_path = self.hp_utils.save_path
        utils.create_log_dir(log_path)
        if not self.restart:
            utils.check_conflict(log_path)
        # Created a second time on purpose. NOTE(review): presumably because
        # check_conflict may remove the directory; confirm in utils.
        utils.create_log_dir(log_path)

        self.n_jobs = self.get_n_jobs()

        runner = (self._optimize_sequential
                  if self.n_parallels <= 1
                  else self._optimize_parallel)
        runner()

        hps_conf, losses = self.hp_utils.load_hps_conf(do_sort=True)
        best_hp_conf = hps_conf[0]
        best_performance = losses[0]
        self.print_optimized_result(best_hp_conf, best_performance[0])
        return best_hp_conf, best_performance
コード例 #3
0
def main(args):
    """Evaluate a BaseNetwork once on the LFW verification protocol.

    This variant does no training: no training set is loaded, the network is
    initialized with 0 classes, and a single evaluation is logged at
    ``global_step = 0``. The "Start Training" banner is kept as-is.

    Results are appended to ``<log_dir>/lfw_result.txt`` and written as an
    ``lfw/accuracy`` summary. The summary writer is always closed so that
    this final summary is flushed to disk.

    Args:
        args: parsed command-line arguments; only ``args.config_file`` is read.
    """
    # I/O
    config_file = args.config_file
    config = utils.import_file(config_file, 'config')

    testset = utils.Dataset(config.test_dataset_path)

    network = BaseNetwork()
    # No training set is loaded here, so 0 is passed where the training
    # scripts pass trainset.num_classes.
    network.initialize(config, 0)

    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    try:
        if config.restore_model is not None:
            network.restore_model(config.restore_model, config.restore_scopes)

        # Set up LFW test protocol and load images
        print('Loading images...')
        lfwtest = LFWTest(testset.images)
        lfwtest.init_standard_proto(config.lfw_pairs_file)
        lfwtest.images = utils.preprocess(lfwtest.image_paths, config,
                                          is_training=False)

        #
        # Main Loop
        #
        print('\nStart Training\nname: %s\n# epochs: %d\nepoch_size: %d\nbatch_size: %d\n'
              % (config.name, config.num_epochs, config.epoch_size,
                 config.batch_size))
        global_step = 0

        # Testing on LFW
        print('Testing on Neetis LFW protocol...')
        embeddings = network.extract_feature(lfwtest.images, config.batch_size)
        print(type(embeddings))

        accuracy_embeddings, threshold_embeddings = lfwtest.test_standard_proto(
            embeddings)
        print('Embeddings Accuracy: %2.4f Threshold %2.3f'
              % (accuracy_embeddings, threshold_embeddings))

        with open(os.path.join(log_dir, 'lfw_result.txt'), 'at') as f:
            f.write('%d\t%.5f\n' % (global_step, accuracy_embeddings))
        summary = tf.Summary()
        summary.value.add(tag='lfw/accuracy', simple_value=accuracy_embeddings)
        summary_writer.add_summary(summary, global_step)
    finally:
        # Flush pending events and release the event file. Without this, the
        # summary added just above may never reach disk before exit.
        summary_writer.close()
コード例 #4
0
def run():
    """Run the multi-agent training episodes and report progress.

    Parses arguments with ``argparser()``, creates the log directory, and
    calls ``utils.start(args.http_port)`` (presumably starting the server the
    agents talk to — TODO confirm). Then it runs ``args.n_episode`` episodes.
    Every ``args.print_interval`` episodes it prints the average success rate
    and episode duration since the last report. Every
    ``args.checkpoint_interval`` episodes it saves agent checkpoints.

    ``utils.close()`` is always called once ``utils.start`` has succeeded,
    even if constructing the environment or an episode raises.
    """
    args = argparser()

    path = utils.create_log_dir(sys.argv)
    utils.start(args.http_port)

    try:
        env = Env(args)
        agents = [Agent(args) for _ in range(args.n_agent)]
        master = Master(args)

        for agent in agents:
            master.add_agent(agent)
        master.add_env(env)

        success_list = []
        time_list = []

        for idx in range(args.n_episode):
            episode = idx + 1
            print('=' * 80)
            print("Episode {}".format(episode))
            # Reset the server's stack and timer.
            print("서버를 초기화하는중...")
            master.reset(path)

            # Start the episode, then train the agents.
            master.start()
            master.train()
            print('=' * 80)
            success_list.append(master.infos["is_success"])
            time_list.append(master.infos["end_time"] - master.infos["start_time"])

            if episode % args.print_interval == 0:
                print("=" * 80)
                # Fixed-point `.2f`: a bare `.2` means two *significant*
                # digits, which rendered e.g. a 123.4 s average as "1.2e+02".
                print("EPISODE {}: Avg. Success Rate / Time: {:.2f} / {:.2f}".format(
                    episode, np.mean(success_list), np.mean(time_list)))
                success_list.clear()
                time_list.clear()
                print("=" * 80)

            if episode % args.checkpoint_interval == 0:
                utils.save_checkpoints(path, agents, episode)

        if args.visual:
            visualize(path, args)
        print("끝")
    finally:
        utils.close()
コード例 #5
0
def main(args):
    """Train a BaseNetwork and evaluate it on LFW after every epoch.

    Each epoch runs ``config.epoch_size`` training steps, logging a summary
    every ``config.summary_interval`` steps. Then it evaluates on the standard
    LFW protocol, appends ``global_step`` and accuracy to
    ``<log_dir>/lfw_result.txt``, writes an ``lfw/accuracy`` summary, and
    saves the model. The summary writer is always closed at the end (also on
    error) so pending summaries are flushed to disk.

    Args:
        args: parsed command-line arguments; only ``args.config_file`` is read.
    """
    # I/O
    config_file = args.config_file
    config = utils.import_file(config_file, 'config')

    trainset = utils.Dataset(config.train_dataset_path)
    testset = utils.Dataset(config.test_dataset_path)

    network = BaseNetwork()
    network.initialize(config, trainset.num_classes)

    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    try:
        if config.restore_model is not None:
            network.restore_model(config.restore_model, config.restore_scopes)

        # Set up LFW test protocol and load images
        print('Loading images...')
        lfwtest = LFWTest(testset.images)
        lfwtest.init_standard_proto(config.lfw_pairs_file)
        lfwtest.images = utils.preprocess(lfwtest.image_paths,
                                          config,
                                          is_training=False)

        trainset.start_batch_queue(config, True)

        #
        # Main Loop
        #
        print('\nStart Training\nname: %s\n# epochs: %d\nepoch_size: %d\nbatch_size: %d\n'
              % (config.name, config.num_epochs, config.epoch_size,
                 config.batch_size))
        global_step = 0
        start_time = time.time()

        for epoch in range(config.num_epochs):

            # Training
            for step in range(config.epoch_size):
                # Prepare input
                learning_rate = utils.get_updated_learning_rate(
                    global_step, config)
                batch = trainset.pop_batch_queue()

                wl, sm, global_step = network.train(batch['images'],
                                                    batch['labels'],
                                                    learning_rate,
                                                    config.keep_prob)

                # Display
                if step % config.summary_interval == 0:
                    duration = time.time() - start_time
                    start_time = time.time()
                    utils.display_info(epoch, step, duration, wl)
                    summary_writer.add_summary(sm, global_step=global_step)

            # Testing on LFW
            print('Testing on standard LFW protocol...')
            embeddings = network.extract_feature(lfwtest.images,
                                                 config.batch_size)
            accuracy_embeddings, threshold_embeddings = lfwtest.test_standard_proto(
                embeddings)
            print('Embeddings Accuracy: %2.4f Threshold %2.3f' %
                  (accuracy_embeddings, threshold_embeddings))

            with open(os.path.join(log_dir, 'lfw_result.txt'), 'at') as f:
                f.write('%d\t%.5f\n' % (global_step, accuracy_embeddings))
            summary = tf.Summary()
            summary.value.add(tag='lfw/accuracy',
                              simple_value=accuracy_embeddings)
            summary_writer.add_summary(summary, global_step)

            # Save the model
            network.save_model(log_dir, global_step)
    finally:
        # Flush and release the event file; otherwise the last epoch's
        # summaries may be lost when the process exits.
        summary_writer.close()
コード例 #6
0
                        help='Number of projections to test')
    parser.add_argument('-e',
                        '--eps',
                        nargs='?',
                        default=0.001,
                        type=float,
                        help='Epsilon for entropic gw')

    args = parser.parse_args()
    all_samples = args.all_samples
    numItermax = args.nitermax
    projs = args.proj
    epsilon = args.eps
    niterentro = args.it_entro

    log_dir = utils.create_log_dir(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    results = {}

    results['all_samples'] = all_samples
    results['projs'] = projs
    results['epsilon'] = epsilon

    results['all_gw'] = []
    results['t_all_gw'] = []
    results['all_converged'] = []

    results['all_entro_gw'] = []
    results['t_all_entro_gw'] = []
コード例 #7
0
ファイル: train.py プロジェクト: juanmc2005/MetricAMI
from os.path import join
import utils
import builders
import args
from core.base import Trainer
from core.plugins.logging import TrainLogger, MetricFileLogger, HeaderPrinter
from core.plugins.storage import BestModelSaver, ModelLoader
from metrics import KNNF1ScoreMetric, Evaluator

# Script arguments
# NOTE: this rebinds the name `args` from the imported `args` module to the
# object returned by get_args(); every later use of `args` sees the parsed
# arguments, not the module.
args = args.get_args()

# Create directory to save plots, models, results, etc
# (utils.create_log_dir returns the path that is printed and reused below)
log_path = utils.create_log_dir(args.exp_path, args)
print(f"Logging to {log_path}")

# Dump all script arguments to <log_path>/config.cfg
utils.dump_params(join(log_path, 'config.cfg'), args)

# Set custom seed (before the dataset/model construction that follows)
utils.set_custom_seed(args.seed)

# Load dataset and create model
print(f"[Model: {args.model.upper()}]")
print(f"[Loss: {args.loss.upper()}]")
print('[Loading Dataset and Model...]')

# Embedding dim is 768 based on BERT, we use the same for LSTM to be fair
# Classes are 6 because we include a 'non-misogyny class'
nfeat, nclass = 768, 6
config = builders.build_config(args, nfeat, nclass)
コード例 #8
0
ファイル: train_sibling.py プロジェクト: smilejx/DocFace
def main(args):
    """Train a SiblingNetwork and report TAR at fixed FARs after every epoch.

    Each epoch runs ``config.epoch_size`` training steps. Every step receives
    a switch vector from ``utils.zero_one_switch``. Then the test set is
    embedded and ``utils.test_roc`` is evaluated at FAR = 1e-4, 1e-3 and
    1e-2. Each TAR/FAR pair is printed, appended to ``<log_dir>/result.txt``
    and written as a ``test/tar_<i>`` summary. Finally the model is saved.
    The summary writer is always closed so pending summaries are flushed.

    Args:
        args: parsed command-line arguments; only ``args.config_file`` is read.
    """
    # I/O
    config_file = args.config_file
    config = utils.import_file(config_file, 'config')

    trainset = utils.Dataset(config.train_dataset_path)
    testset = utils.Dataset(config.test_dataset_path)

    network = SiblingNetwork()
    network.initialize(config, trainset.num_classes)

    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    try:
        if config.restore_model:
            network.restore_model(config.restore_model, config.restore_scopes)

        # Set up test protocol and load images
        print('Loading images...')
        testset.separate_template_and_probes()
        testset.images = utils.preprocess(testset.images, config,
                                          is_training=False)

        trainset.start_batch_queue(config, True)

        #
        # Main Loop
        #
        print('\nStart Training\nname: %s\n# epochs: %d\nepoch_size: %d\nbatch_size: %d\n'
              % (config.name, config.num_epochs, config.epoch_size,
                 config.batch_size))
        global_step = 0
        start_time = time.time()

        for epoch in range(config.num_epochs):

            # Training
            for step in range(config.epoch_size):
                # Prepare input
                learning_rate = utils.get_updated_learning_rate(global_step,
                                                                config)
                image_batch, label_batch = trainset.pop_batch_queue()

                switch_batch = utils.zero_one_switch(len(image_batch))
                wl, sm, global_step = network.train(image_batch, label_batch,
                                                    switch_batch,
                                                    learning_rate,
                                                    config.keep_prob)

                # Display
                if step % config.summary_interval == 0:
                    duration = time.time() - start_time
                    start_time = time.time()
                    utils.display_info(epoch, step, duration, wl)
                    summary_writer.add_summary(sm, global_step=global_step)

            # Testing
            print('Testing...')
            switch = utils.zero_one_switch(len(testset.images))
            embeddings = network.extract_feature(testset.images, switch,
                                                 config.batch_size)
            tars, fars, _ = utils.test_roc(embeddings, FARs=[1e-4, 1e-3, 1e-2])
            with open(os.path.join(log_dir, 'result.txt'), 'at') as f:
                for i, (tar, far) in enumerate(zip(tars, fars)):
                    line = '[%d] TAR: %2.4f FAR %2.3f' % (epoch + 1, tar, far)
                    print(line)
                    f.write(line + '\n')
                    summary = tf.Summary()
                    summary.value.add(tag='test/tar_%d' % i, simple_value=tar)
                    summary_writer.add_summary(summary, global_step)

            # Save the model
            network.save_model(log_dir, global_step)
    finally:
        # Flush and release the event file so the last summaries reach disk.
        summary_writer.close()
コード例 #9
0
# Register the --config_file option (the parser itself is created earlier).
parser.add_argument('--config_file',
                    help='Path to training configuration file',
                    type=str)
config_file = parser.parse_args().config_file
# I/O
# Presumably loads the config file as a module named 'config' — TODO confirm.
config = utils.import_file(config_file, 'config')

# Training list for the current fold: <splits_path>/train_<fold_number>.txt
trainset = utils.Dataset(config.splits_path + '/train_' +
                         str(config.fold_number) + '.txt')
# NOTE(review): the positional True presumably means is_training — confirm.
trainset.images = utils.preprocess(trainset.images, config, True)

network = Network()
network.initialize(config, trainset.num_classes)

# Initialization for running
log_dir = utils.create_log_dir(config, config_file)
summary_writer = tf.summary.FileWriter(log_dir, network.graph)
if config.restore_model:
    network.restore_model(config.restore_model, config.restore_scopes)

# Load gallery and probe file_list
print('Loading images...')
probes = []
# Gallery list; it is not filled within this excerpt (presumably further below).
gal = []
# One stripped entry per line of <splits_path>/fold_<n>/probe_1.txt
with open(
        config.splits_path + '/fold_' + str(config.fold_number) +
        '/probe_1.txt', 'r') as f:
    for line in f:
        probes.append(line.strip())

probe_set = evaluate.ImageSet(probes, config)
コード例 #10
0
ファイル: nested_cv_fgw.py プロジェクト: stjordanis/FGW
    if args.alpha==-8000 and not args.automatic_cv_alpha:
        raise AlphaMustBeDefinedError('You must set alpha via -a or use automatic grid via -cva')
    
    
    name='fgw'+'_'+args.dataset_name+'_feature_metric_'+args.feature_metric+'_structure_metric_'+args.structure_metric
    if args.wl_feature>0:
        name=name+'_wl_'+str(args.wl_feature)
    name=name+args.optionnal_name
    
    try:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
    except OSError:
        raise

    log_dir=create_log_dir(args)
  
    if args.Csvm !=-1:
        Clist=[args.Csvm]
    else:
        Clist=list(np.logspace(-4,4,15))
        
    if args.automatic_cv_alpha :
        nb=15
        N=int(15/3)
        a=np.logspace(-6,-1,N)
        c=1-a
        b=np.array(list(set(np.linspace(a[0],c[0],N)).difference(set((a[0],c[0])))))
        alphalist=np.concatenate((a,b,c))
        alpha_list=list(np.sort(np.append([0,1],alphalist)))
    else:
コード例 #11
0
def main(_):
    """Train a keyword-spotting model, validate periodically, then test.

    All configuration comes from the module-level ``FLAGS``. The training
    schedule is the concatenation of the ``--how_many_training_steps`` phases,
    each using the matching rate from ``--learning_rate``. Every
    ``FLAGS.eval_step_interval`` steps (and at the last step) the full
    validation split is evaluated. When the validation accuracy improves, a
    checkpoint is written under ``<train_dir>/best``. After training, the
    accuracy and confusion matrix on the testing split are logged.

    Args:
        _: unused positional argument (presumably the argv passed by
            ``tf.app.run`` — TODO confirm).

    Raises:
        Exception: if ``--how_many_training_steps`` and ``--learning_rate``
            have different numbers of entries.
    """
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Set up the log directory / log file through the project helper.
    import utils
    utils.create_log_dir(FLAGS.train_dir, FLAGS.log_path)

    # Begin by making sure we have the training data we need. If you already
    # have training data of your own, use `--data_url= ` on the command line
    # to avoid downloading.
    # Model settings: number of words, sample rate, clip duration, window
    # size, window stride and number of DCT (MFCC) coefficients.
    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    # Data URL, local data dir, silence share, unknown-word share, wanted
    # words, validation share, testing share.
    audio_processor = input_data.AudioProcessor(FLAGS.data_url, FLAGS.data_dir,
                                                FLAGS.silence_percentage,
                                                FLAGS.unknown_percentage,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.validation_percentage,
                                                FLAGS.testing_percentage,
                                                model_settings)

    # Size of one input feature vector.
    fingerprint_size = model_settings['fingerprint_size']
    # Number of output labels.
    label_count = model_settings['label_count']
    # Maximum random time shift, converted from milliseconds to samples.
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed
    # by lower levels towards the end, the number of steps and learning rates
    # can be specified as comma-separated lists to define the rate at each
    # stage. For example --how_many_training_steps=10000,3000
    # --learning_rate=0.001,0.0001 will run 13,000 training loops in total,
    # with a rate of 0.001 for the first 10,000, and 0.0001 for the final
    # 3,000.
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    # Placeholder for the input features.
    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    # Build the model: `logits` are the per-class scores and `dropout_prob` is
    # a placeholder that this script always feeds with 1.0.
    logits, dropout_prob = models.create_model(fingerprint_input,
                                               model_settings,
                                               FLAGS.model_architecture,
                                               FLAGS.model_size_info,
                                               is_training=True)

    # Define loss and optimizer
    # Placeholder for the ground-truth labels, one column per class.
    ground_truth_input = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms
    # of numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the
    # graph: mean softmax cross-entropy is the training objective.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth_input,
                                                    logits=logits))
    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    # Make the train op depend on UPDATE_OPS (e.g. batch-norm statistics).
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.name_scope('train'), tf.control_dependencies(
            update_ops), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        train_op = tf.train.AdamOptimizer(learning_rate_input)
        # NOTE(review): with its default global_step argument,
        # create_train_op presumably also increments the global step. Together
        # with increment_global_step below, the stored step would then advance
        # by 2 per training step, which the `// 2` on restore compensates for.
        train_step = slim.learning.create_train_op(cross_entropy_mean,
                                                   train_op)

    # Index of the highest-scoring class: predicted vs. expected.
    predicted_indices = tf.argmax(logits, 1)
    expected_indices = tf.argmax(ground_truth_input, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)

    # Confusion matrix: rows are expected labels, columns predicted labels.
    confusion_matrix = tf.confusion_matrix(expected_indices,
                                           predicted_indices,
                                           num_classes=label_count)

    # Accuracy is the mean of the 0/1 correctness vector.
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to FLAGS.summaries_dir.
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    # Parameter counts
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()),
                         params))
    tf.logging.info('Total number of Parameters: %d ', num_params)

    # Resume from a checkpoint if one is given; otherwise start at step 1.
    start_step = 1
    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        # Halve only the restored value (see the NOTE on create_train_op).
        # Halving unconditionally turned a fresh start of 1 into 0, which ran
        # one extra step and validated/saved a "best" checkpoint at step 0.
        start_step = global_step.eval(session=sess) // 2
    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    best_accuracy = 0
    training_steps_max = np.sum(training_steps_list)
    for training_step in range(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is: the rate of the first
        # phase whose cumulative step count reaches training_step.
        training_steps_sum = 0
        for phase_steps, phase_rate in zip(training_steps_list,
                                           learning_rates_list):
            training_steps_sum += phase_steps
            if training_step <= training_steps_sum:
                learning_rate_value = phase_rate
                break

        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)

        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_prob: 1.0
            })

        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.2f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))

        # Validate every eval_step_interval steps and at the final step.
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in range(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))

                # Run a validation step and capture training summaries for
                # TensorBoard with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                # Weight each batch by its real size (the last may be short).
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.2f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

            # Save the model checkpoint when validation accuracy improves.
            if total_accuracy > best_accuracy:
                best_accuracy = total_accuracy
                checkpoint_path = os.path.join(
                    FLAGS.train_dir, 'best', FLAGS.model_architecture + '_' +
                    str(int(best_accuracy * 10000)) + '.ckpt')
                tf.logging.info('Saving best model to "%s-%d"',
                                checkpoint_path, training_step)
                saver.save(sess, checkpoint_path, global_step=training_step)
            tf.logging.info('So far the best validation accuracy is %.2f%%' %
                            (best_accuracy * 100))

    # Flush the training/validation event files; they are not used below.
    train_writer.close()
    validation_writer.close()

    # Evaluate on the testing split.
    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in range(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.2f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
コード例 #12
0
def train(config_file, counter):
    """Train a Network on one split and evaluate rank-1/rank-5 identification.

    Data are read from ``<config.splits_path>/split<counter>``: the training
    list ``train_<fold>.txt`` and the probe/gallery lists under
    ``fold_<fold>/``. After every epoch, probe and gallery features are
    extracted, ``evaluate.identify`` is run, rank-1/rank-5 are printed and
    logged as summaries, and the model is saved. Finally
    ``<log_dir>/result.txt`` is copied to
    ``log/result_<model_version>_<counter>.txt``.

    Args:
        config_file: path of the training configuration file.
        counter: split index appended to ``config.splits_path``.
    """
    # I/O
    config = utils.import_file(config_file, 'config')
    splits_path = config.splits_path + '/split{}'.format(counter)
    fold = str(config.fold_number)

    trainset = utils.Dataset(splits_path + '/train_' + fold + '.txt')
    trainset.images = utils.preprocess(trainset.images, config, True)

    network = Network()
    network.initialize(config, trainset.num_classes)

    # Initialization for running
    log_dir = utils.create_log_dir(config, config_file)
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir, network.graph)
    try:
        if config.restore_model:
            network.restore_model(config.restore_model, config.restore_scopes)

        # Load gallery and probe file lists (one entry per line).
        print('Loading images...')
        fold_dir = splits_path + '/fold_' + fold
        with open(fold_dir + '/probe_1.txt', 'r') as f:
            probes = [line.strip() for line in f]
        probe_set = evaluate.ImageSet(probes, config)
        with open(fold_dir + '/gal_1.txt', 'r') as f:
            gal = [line.strip() for line in f]
        gal_set = evaluate.ImageSet(gal, config)

        trainset.start_batch_queue(config, True)

        #
        # Main Loop
        #
        print('\nStart Training\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'
              .format(config.num_epochs, config.epoch_size, config.batch_size))

        global_step = 0
        start_time = time.time()
        for epoch in range(config.num_epochs):
            # Training
            for step in range(config.epoch_size):
                # Prepare input
                learning_rate = utils.get_updated_learning_rate(global_step,
                                                                config)
                image_batch, label_batch = trainset.pop_batch_queue()

                wl, sm, global_step = network.train(image_batch, label_batch,
                                                    learning_rate,
                                                    config.keep_prob)

                # Display
                if step % config.summary_interval == 0:
                    duration = time.time() - start_time
                    start_time = time.time()
                    utils.display_info(epoch, step, duration, wl)
                    summary_writer.add_summary(sm, global_step=global_step)

            # Testing
            print('Testing...')
            probe_set.extract_features(network, len(probes))
            gal_set.extract_features(network, len(gal))

            rank1, rank5 = evaluate.identify(log_dir, probe_set, gal_set)
            print('rank-1: {:.3f}, rank-5: {:.3f}'.format(rank1[0], rank5[0]))

            # Output test result. Use the compat.v1 Summary proto to match the
            # compat.v1 FileWriter above; the bare `tf.Summary` does not exist
            # in TF2.
            summary = tf.compat.v1.Summary()
            summary.value.add(tag='identification/rank1', simple_value=rank1[0])
            summary.value.add(tag='identification/rank5', simple_value=rank5[0])
            summary_writer.add_summary(summary, global_step)

            # Save the model
            network.save_model(log_dir, global_step)
    finally:
        # Flush and release the event file so the last summaries reach disk.
        summary_writer.close()

    results_copy = os.path.join(
        'log', 'result_{}_{}.txt'.format(config.model_version, counter))
    shutil.copyfile(os.path.join(log_dir, 'result.txt'), results_copy)