Example #1
 def setUp(self):
     super().setUp()
     import_all_modules_for_register()
     main_root = os.environ['MAIN_ROOT']
     main_root = Path(main_root)
     self.config_file = main_root.joinpath(
         'egs/mock_text_cls_data/text_cls/v1/config/han-cls.yml')
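These setUp snippets share an implicit test-module preamble that the excerpts omit. A minimal sketch of what they assume (the delta import paths and the MAIN_ROOT value here are assumptions, not verbatim from the project):

import os
from pathlib import Path

# Assumed import paths; the actual delta package layout may differ.
from delta import utils
from delta.utils.register import import_all_modules_for_register

# The egs/ run scripts are expected to export MAIN_ROOT as the repository
# root; without it, os.environ['MAIN_ROOT'] raises KeyError in setUp.
os.environ.setdefault('MAIN_ROOT', '/path/to/delta')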
Example #2
def main(_):
  ''' main func '''
  FLAGS = app.flags.FLAGS  #pylint: disable=invalid-name

  logging.info("config is {}".format(FLAGS.config))
  logging.info("mode is {}".format(FLAGS.mode))
  logging.info("gpu is {}".format(FLAGS.gpu))
  assert FLAGS.config, 'give a config.yaml'
  assert FLAGS.mode, 'give mode eval, infer or eval_and_infer'

  os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu  #selects a specific device

  #create dataset
  if FLAGS.mode == 'infer':
    mode = utils.INFER
  else:
    mode = utils.EVAL

  # load config
  config = utils.load_config(FLAGS.config)

  # process config
  import_all_modules_for_register()
  solver_name = config['solver']['name']
  solver = registers.solver[solver_name](config)
  config = solver.config

  eval_obj = ASREvaluate(config, gpu_str=FLAGS.gpu, mode=mode)
  eval_obj.predict()
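main() above reads FLAGS.config, FLAGS.mode and FLAGS.gpu, but the flag definitions sit outside the excerpt. A hedged sketch of what they might look like with TF1-style app.flags (the flag names match the code above; the help strings are illustrative):

import tensorflow as tf

app = tf.app  # TF1 alias; main() accesses flags via app.flags.FLAGS

app.flags.DEFINE_string('config', None, 'path to config.yaml')
app.flags.DEFINE_string('mode', None, 'eval, infer or eval_and_infer')
app.flags.DEFINE_string('gpu', '', 'value for CUDA_VISIBLE_DEVICES, e.g. "0"')

if __name__ == '__main__':
  app.run(main)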
Example #3
 def setUp(self):
   ''' set up'''
   import_all_modules_for_register()
   main_root = os.environ['MAIN_ROOT']
   main_root = Path(main_root)
   self.config_file = main_root.joinpath(
       'egs/mock_text_seq_label_data/seq-label/v1/config/seq-label-mock.yml')
Example #4
 def setUp(self):
     """ set up"""
     main_root = os.environ['MAIN_ROOT']
     main_root = Path(main_root)
     self.config_file = main_root.joinpath(
         'egs/mock_text_match_data/text_match/v1/config/rnn-match-mock.yml')
     import_all_modules_for_register()
Example #5
 def setUp(self):
   ''' set up'''
   import_all_modules_for_register()
   main_root = os.environ['MAIN_ROOT']
   main_root = Path(main_root)
   self.config_file = main_root.joinpath(
       'egs/mock_text_nlu_joint_data/nlu-joint/v1/config/nlu_joint.yml')
Example #6
 def setUp(self):
   main_root = os.environ['MAIN_ROOT']
   main_root = Path(main_root)
   self.config_file = main_root.joinpath(
       'egs/mock_text_seq2seq_data/seq2seq/v1/config/transformer-s2s.yml')
   self.config = utils.load_config(self.config_file)
   import_all_modules_for_register()
Example #7
 def setUp(self):
   super().setUp()
   package_root = Path(PACKAGE_ROOT_DIR)
   self.config_file = package_root.joinpath(
       '../egs/mock_text_match_data/text_match/v1/config/rnn-match-mock.yml')
   self.config = utils.load_config(self.config_file)
   import_all_modules_for_register()
Example #8
 def setUp(self):
     ''' set up'''
     import_all_modules_for_register()
     main_root = os.environ['MAIN_ROOT']
     main_root = Path(main_root)
     self.config_file = main_root.joinpath(
         'egs/mock_text_seq2seq_data/nlp1/config/transformer-s2s.yml')
Example #9
 def setUp(self):
     super().setUp()
     import_all_modules_for_register()
     package_root = Path(PACKAGE_ROOT_DIR)
     self.config_file = package_root.joinpath(
         '../egs/mock_text_nlu_joint_data/nlu-joint/v1/config/nlu_joint.yml'
     )
Example #10
 def setUp(self):
     main_root = os.environ['MAIN_ROOT']
     main_root = Path(main_root)
     self.config_file = main_root.joinpath(
         'egs/mock_text_match_data/nlp1/config/rnn-match-mock.yml')
     self.config = utils.load_config(self.config_file)
     import_all_modules_for_register()
Example #11
def main(_):
  ''' main func '''
  FLAGS = app.flags.FLAGS  #pylint: disable=invalid-name
  logging.info("config: {}".format(FLAGS.config))
  logging.info("mode: {}".format(FLAGS.mode))
  logging.info("gpu_visible: {}".format(FLAGS.gpu))
  assert FLAGS.config, 'pls give a config.yaml'
  assert FLAGS.mode, 'pls give mode [eval|infer|eval_and_infer]'
  os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu  #selects a specific device

  #create dataset
  mode = utils.INFER if FLAGS.mode == 'infer' else utils.EVAL

  # load config
  config = utils.load_config(FLAGS.config)

  # process config
  import_all_modules_for_register()
  solver_name = config['solver']['name']
  logging.info(f"sovler: {solver_name}")
  solver = registers.solver[solver_name](config)
  config = solver.config

  # Evaluate
  evaluate_name = config['serving']['name']
  logging.info(f"evaluate: {evaluate_name}")
  evaluate = registers.serving[evaluate_name](
      config, gpu_str=FLAGS.gpu, mode=mode)

  if FLAGS.debug:
    evaluate.debug()
  evaluate.predict()
Example #12
 def setUp(self):
     super().setUp()
     import_all_modules_for_register()
     package_root = Path(PACKAGE_ROOT_DIR)
     self.config_file = package_root.joinpath(
         '../egs/mock_text_seq_label_data/seq-label/v1/config/seq-label-mock.yml'
     )
Example #13
 def setUp(self):
     super().setUp()
     package_root = Path(PACKAGE_ROOT_DIR)
     self.config_file = package_root.joinpath(
         '../egs/mock_text_seq2seq_data/seq2seq/v1/config/transformer-s2s.yml'
     )
     self.config = utils.load_config(self.config_file)
     import_all_modules_for_register()
Example #14
 def setUp(self):
     super().setUp()
     main_root = os.environ['MAIN_ROOT']
     main_root = Path(main_root)
     self.config_file = main_root.joinpath(
         'egs/mock_text_nlu_joint_data/nlu-joint/v1/config/nlu_joint.yml')
     self.config = utils.load_config(self.config_file)
     import_all_modules_for_register()
Example #15
File: main.py Project: mabaochang/delta
def main(argv):
  """
    main function
  """
  # pylint: disable=unused-argument

  if FLAGS.log_debug:
    logging.set_verbosity(logging.DEBUG)
  else:
    logging.set_verbosity(logging.INFO)

  # load config
  config = utils.load_config(FLAGS.config)
  set_seed(config)

  import_all_modules_for_register()

  solver_name = config['solver']['name']
  solver = registers.solver[solver_name](config)

  # config after process
  config = solver.config

  task_name = config['data']['task']['name']
  task_class = registers.task[task_name]

  logging.info("CMD: {}".format(FLAGS.cmd))
  if FLAGS.cmd == 'train':
    solver.train()
  elif FLAGS.cmd == 'train_and_eval':
    solver.train_and_eval()
  elif FLAGS.cmd == 'eval':
    solver.eval()
  elif FLAGS.cmd == 'infer':
    solver.infer(yield_single_examples=False)
  elif FLAGS.cmd == 'export_model':
    solver.export_model()
  elif FLAGS.cmd == 'gen_feat':
    assert config['data']['task'][
        'suffix'] == '.npy', 'wav input does not need feature extraction'
    paths = []
    for mode in [utils.TRAIN, utils.EVAL, utils.INFER]:
      paths += config['data'][mode]['paths']
    task = task_class(config, utils.INFER)
    task.generate_feat(paths, dry_run=FLAGS.dry_run)
  elif FLAGS.cmd == 'gen_cmvn':
    logging.info(
        '''using infer pipeline to compute cmvn of train_paths, and stride must be 1'''
    )
    paths = config['data'][utils.TRAIN]['paths']
    segments = config['data'][utils.TRAIN]['segments']
    config['data'][utils.INFER]['paths'] = paths
    config['data'][utils.INFER]['segments'] = segments
    task = task_class(config, utils.INFER)
    task.generate_cmvn(dry_run=FLAGS.dry_run)
  else:
    raise ValueError("Not support mode: {}".format(FLAGS.cmd))
Example #16
    def test_all(self):
        # train and eval
        import_all_modules_for_register()
        solver = RawMatchSolver(self.config)
        solver.train_and_eval()
        model_path = solver.get_generated_model_path()
        self.assertNotEqual(model_path, None)

        # infer
        solver.first_eval = True
        solver.infer()
        res_file = self.config["solver"]["postproc"].get("res_file", "")
        self.assertTrue(os.path.exists(res_file))

        # export model
        solver.export_model()

        export_path_base = self.config["solver"]["service"]["model_path"]
        model_version = self.config["solver"]["service"]["model_version"]
        export_path = os.path.join(tf.compat.as_bytes(export_path_base),
                                   tf.compat.as_bytes(model_version))
        export_path = os.path.abspath(export_path)
        logging.info("Load exported model from: {}".format(export_path))

        # load the model and run
        graph = tf.Graph()
        with graph.as_default():  # pylint: disable=not-context-manager
            with self.cached_session(use_gpu=False, force_gpu=False) as sess:
                tf.saved_model.loader.load(
                    sess, [tf.saved_model.tag_constants.SERVING], export_path)

                input_sentence_tensor_left = graph.get_operation_by_name(
                    "input_sent_left").outputs[0]
                input_sentence_tensor_right = graph.get_operation_by_name(
                    "input_sent_right").outputs[0]
                score_tensor = graph.get_operation_by_name("score").outputs[0]

                score = sess.run(score_tensor,
                                 feed_dict={
                                     input_sentence_tensor_left:
                                     ["I love china"],
                                     input_sentence_tensor_right:
                                     ["I am lovely"]
                                 })
                logging.info("score: {}".format(score))
Example #17
  def setUp(self):
    ''' set up'''
    self.conf_str = '''
    data:
      train:
        paths:
        - null 
        segments: null
      eval:
        paths:
        - null
        segments: null
      infer:
        paths:
        - null 
        segments: null
      task:
        name: SpeechClsTask
        suffix: .npy # file suffix
        audio:
          dry_run: false # not save feat
          # params
          clip_size: 30 # clip len in seconds
          stride: 0.5 # stride in ratio of clip_size
          sr: 8000 # sample rate
          winlen: 0.025 # window len
          winstep: 0.01 # window stride
          nfft: 512 # fft bins, default: 512
          lowfreq: 0
          highfreq: null # default: null, 200 points for 800 nfft, 400 points for 1600 nfft
          preemph: 0.97 # default: 0.97
          # extractor
          feature_extractor: tffeat # `tffeat` uses the TF feature_extraction .so library, `pyfeat` uses python_speech_features
          feature_name: fbank # fbank or spec
          save_feat_path: null  # null for dump feat with same dir of wavs
          feature_size: 40 # extract feature size
          add_delta_deltas: true # delta deltas
          # log power
          log_powspec: false # true, save log power spec; otherwise save power spec
          # cmvn
          cmvn: true # apply cmvn or generate cmvn
          cmvn_path: ./cmvn_conflict.npy # cmvn file
        text:
          enable: False
          vocab_path: /vocab/chars5004_attention.txt
          vocab_size: 5004 # vocab size
          max_text_len: 100 # max length for text
        classes:
          num: 2
          vocab:
            normal: 0
            conflict: 1
        num_parallel_calls: 12
        num_prefetch_batch: 2
        shuffle_buffer_size: 200000
        need_shuffle: true
    solver:
      name: EmotionSolver
      optimizer:
        name: adam
        epochs: 5 # maximum epochs
        batch_size: 32 # number of elements in a training batch
        loss: CrossEntropyLoss
        label_smoothing: 0.0 # label smoothing rate
        learning_rate:
          rate: 0.0001 # learning rate of Adam optimizer
          type:  exp_decay # learning rate type
          decay_rate: 0.99  # the lr decay rate
          decay_steps: 100  # the lr decay_step for optimizer
        clip_global_norm: 3.0 # clip global norm
        multitask: False # whether is multi-task
      metrics:
        pos_label: 1 # int, same to sklearn
        cals:
        - name: AccuracyCal
          arguments: null 
        - name: ConfusionMatrixCal
          arguments: null
        - name: PrecisionCal
          arguments:
            average: 'binary'
        - name: RecallCal
          arguments:
            average: 'binary'
        - name: F1ScoreCal
          arguments:
            average: 'binary'
      saver:
        model_path: "ckpt/emotion-speech-cls/test"
        max_to_keep: 10
        save_checkpoints_steps: 100
        keep_checkpoint_every_n_hours: 10000
        checkpoint_every: 100 # the step to save checkpoint
        summary: false
        save_summary_steps: 100
        eval_on_dev_every_secs: 1
        print_every: 10
        resume_model_path: ""
    '''
    import_all_modules_for_register()
    #tempdir = tempfile.mkdtemp()
    tempdir = self.get_temp_dir()

    config_path = str(Path(tempdir).joinpath("speech_task.yaml"))
    logging.info("config path: {}".format(config_path))
    with open(config_path, 'w', encoding='utf-8') as f:  #pylint: disable=invalid-name
      f.write(self.conf_str)

    dataset_path = Path(tempdir).joinpath("data")
    if not dataset_path.exists():
      dataset_path.mkdir()
    positive_path = dataset_path.joinpath("conflict")
    if not positive_path.exists():
      positive_path.mkdir()
    negative_path = dataset_path.joinpath("normal")
    if not negative_path.exists():
      negative_path.mkdir()

    wav_path = Path(os.environ['MAIN_ROOT']).joinpath(
        'delta/data/feat/python_speech_features/english.wav')
    for i in range(10):
      pos_file = positive_path.joinpath("{}.wav".format(i))
      neg_file = negative_path.joinpath("{}.wav".format(i))
      shutil.copyfile(str(wav_path), str(pos_file))
      shutil.copyfile(str(wav_path), str(neg_file))

    config = utils.load_config(config_path)
    config['data']['train']['paths'] = [str(dataset_path)]
    config['data']['eval']['paths'] = [str(dataset_path)]
    config['data']['infer']['paths'] = [str(dataset_path)]
    logging.info("config: {}".format(config))

    solver_name = config['solver']['name']
    self.solver = registers.solver[solver_name](config)

    # config after process
    self.config = self.solver.config

    task_name = self.config['data']['task']['name']
    self.task_class = registers.task[task_name]
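Examples #17, #19 and #20 all follow the same write-then-load pattern: dump an inline YAML string into a temp file, then parse it back. A condensed sketch, assuming utils.load_config behaves like a plain YAML loader:

import tempfile
from pathlib import Path

import yaml  # assumption: utils.load_config is equivalent to yaml.safe_load here

conf_str = '''
solver:
  name: DemoSolver
'''
config_path = Path(tempfile.mkdtemp()).joinpath('demo.yaml')
config_path.write_text(conf_str, encoding='utf-8')

config = yaml.safe_load(config_path.read_text(encoding='utf-8'))
assert config['solver']['name'] == 'DemoSolver'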
Example #18
 def setUp(self):
     ''' set up '''
     import_all_modules_for_register()
Example #19
  def setUp(self):
    super().setUp()
    self.conf_str = '''
    data:
      train:
        paths: null
        segments: null
      eval:
        paths: null
        segments: null
      infer:
        paths: null
        segments: null
      task:
        dummy: true # dummy inputs 
        name: AsrSeqTask
        type: asr # asr, tts
        audio:
          dry_run: false # not save feat
        src:
          max_len: 3000 # max length for frames
          subsampling_factor: 1
          preprocess_conf: null
        tgt:
          max_len: 100 # max length for target tokens
        vocab:
          type: char # char, bpe, wpm, word
          size: 3653 # vocab size in vocab_file
          path: '/nfs/cold_project/dataset/opensource/librispeech/espnet/egs/hkust/asr1/data/lang_1char/train_nodup_sp_units.txt' # path to vocab(default: 'vocab
        batch:
          batch_size: 32 # number of elements in a training batch
          batch_bins: 0 # maximum number of bins (frames x dim) in a training batch
          batch_frames_in: 0 # maximum number of input frames in a training batch
          batch_frames_out: 0 # maximum number of output frames in a training batch
          batch_frames_inout: 0 # maximum number of input+output frames in a training batch
          batch_strategy: auto # strategy to count maximum size of batch (supports 4 values: "auto", "seq", "frame", "bin")
        batch_mode: false # true, user controls batching; false, `generate` yields one example
        num_parallel_calls: 12
        num_prefetch_batch: 2
        shuffle_buffer_size: 200000
        need_shuffle: true
        sortagrad: true
        batch_sort_key: 'input' # shuffle, input, output for asr and tts, and sortagrad for asr
        num_batches: 0 # for debugging

    model:
      name: CTCAsrModel
      type: keras # raw, keras or eager model
      net:
        structure:
          encoder:
            name:
            filters: # equal number of cnn layers
            - 128
            - 512
            - 512
            filter_size: # equal number of cnn layers
            - [5, 3]
            - [5, 3]
            - [5, 3]
            filter_stride: # equal number of cnn layers
            - [1, 1]
            - [1, 1]
            - [1, 1]
            pool_size: # equal number of cnn layers
            - [4, 4]
            - [1, 2]
            - [1, 2]
            num_filters: 128
            linear_num: 786 # hidden number of linear layer
            cell_num: 128 # cell units of the lstm
            hidden1: 64 # number of hidden units of fully connected layer
            attention: false # whether to use attention; false means use max-pooling
            attention_size: 128 # attention_size
            use_lstm_layer: false # whether to use lstm layer; false means no lstm layer
            use_dropout: true # whether to use dropout layer
            dropout_rate: 0.2
            use_bn: true # whether to use bn layer
          decoder:
            name: 
          attention:
            name:
    solver:
      name: AsrSolver
      quantization:
        enable: false # whether to quantize the model
        quant_delay: 0 # Number of steps after which weights and activations are quantized during training
      adversarial:
        enable: false # whether to use adversarial training
        adv_alpha: 0.5 # adversarial loss alpha
        adv_epslion: 0.1 # adversarial example epsilon
      model_average:
        enable: false # use average model
        var_avg_decay: 0.99 # the decay rate of variables
      distilling:
        enable: false 
        name : Teacher
        loss : DistillationLoss
        temperature: 5
        alpha: 0.5
        teacher_model: null # frozen_graph.pb
      optimizer:
        name: adam
        epochs: 5 # maximum epochs
        loss: CTCLoss 
        label_smoothing: 0.0 # label smoothing rate
        learning_rate:
          rate: 0.0001 # learning rate of Adam optimizer
          type:  exp_decay # learning rate type
          decay_rate: 0.99  # the lr decay rate
          decay_steps: 100  # the lr decay_step for optimizer
        clip_global_norm: 3.0 # clip global norm
        multitask: False # whether is multi-task
        early_stopping: # keras early stopping
          enable: true
          monitor: val_loss
          min_delta: 0
          patience: 5
      metrics:
        pos_label: 1 # int, same to sklearn
        cals:
        - name: AccuracyCal
          arguments: null 
        - name: ConfusionMatrixCal
          arguments: null
        - name: PrecisionCal
          arguments:
            average: 'binary'
        - name: RecallCal
          arguments:
            average: 'binary'
        - name: F1ScoreCal
          arguments:
            average: 'binary'
      postproc:
          enbale: false
          name: EmoPostProc
          log_verbose: false 
          eval: true # compute metrics
          infer: true  # get predict results
          pred_path: null # None for `model_path`/infer, dumps infer output to this dir
          thresholds:
              - 0.5
          smoothing:
              enable: true
              count: 2
      saver:
        model_path: "ckpt/asr-seq/test"
        max_to_keep: 10
        save_checkpoints_steps: 100
        keep_checkpoint_every_n_hours: 10000
        checkpoint_every: 100 # the step to save checkpoint
        summary: false
        save_summary_steps: 100
        eval_on_dev_every_secs: 1
        print_every: 10
        resume_model_path: ""
      run_config:
        debug: false # use tfdbg
        tf_random_seed: null # 0-2**32; null is None, try to read data from /dev/urandom if available or seed from the clock otherwise
        allow_soft_placement: true
        log_device_placement: false
        intra_op_parallelism_threads: 10
        inter_op_parallelism_threads: 10
        allow_growth: true
        log_step_count_steps: 100 #The frequency, in number of global steps, that the global step/sec and the loss will be logged during training.
      run_options:
        trace_level: 3 # 0: no trace, 1: software trace, 2: hardware trace, 3: full trace
        inter_op_thread_pool: -1
        report_tensor_allocations_upon_oom: true
    
    serving:
      enable: false 
      name : Evaluate
      model: null # saved model dir, ckpt dir, or frozen_model.pb
      inputs: 'inputs:0'
      outpus: 'softmax_output:0'   
    '''
    import_all_modules_for_register()
    tempdir = self.get_temp_dir()

    config_path = str(Path(tempdir).joinpath("asr_seq.yaml"))
    logging.info("config path: {}".format(config_path))
    with open(config_path, 'w', encoding='utf-8') as f:  #pylint: disable=invalid-name
      f.write(self.conf_str)

    self.config = utils.load_config(config_path)
    self.mode = utils.TRAIN
    self.batch_size = 4
    self.config['solver']['optimizer']['batch_size'] = self.batch_size

    # generate dummy data
    nexamples = 10
    generate_json_data(self.config, self.mode, nexamples)
Example #20
  def setUp(self):
    ''' set up'''
    import_all_modules_for_register()
    self.conf_str = '''
    data:
      train:
        paths:
        - ''
      eval:
        paths:
        - ''
      infer:
        paths:
        - ''
      task:
        name: SpeakerClsTask
        data_type: KaldiDataDirectory
        suffix: .npy # file suffix
        audio:
          dry_run: false # not save feat
          # params
          clip_size: 3 # clip len in seconds
          stride: 0.5 # stride in ratio of clip_size
          sr: 8000 # sample rate
          winlen: 0.025 # window len
          winstep: 0.01 # window stride
          nfft: 512 # fft bins, default: 512
          lowfreq: 0
          highfreq: null # default: null, 200 points for 800 nfft, 400 points for 1600 nfft
          preemph: 0.97 # default: 0.97
          # extractor
          feature_extractor: tffeat # `tffeat` uses the TF feature_extraction .so library, `pyfeat` uses python_speech_features
          save_feat_path: null  # null for dump feat with same dir of wavs
          # fbank
          save_fbank: true # save fbank or power spec
          feature_size: 23 # extract feature size
          add_delta_deltas: false # delta deltas
          # log power
          log_powspec: false # true, save log power spec; otherwise save power spec
          # cmvn
          cmvn: true # apply cmvn or generate cmvn
          cmvn_path: ./cmvn_speaker.npy # cmvn file
        classes:
          num: 2 
          vocab: null
        num_parallel_calls: 12
        num_prefetch_batch: 2
        shuffle_buffer_size: 200000
        need_shuffle: true

    model:
      name: SpeakerCRNNRawModel
      type: raw # raw, keras or eager model
      net:
        structure:
          embedding_size: 2
          filters: # equal number of cnn layers
          - 2
          filter_size: # equal number of cnn layers
          - [1, 1]
          filter_stride: # equal number of cnn layers
          - [1, 1]
          pool_size: # equal number of cnn layers
          - [8, 8]
          tdnn_contexts:
          - 3
          - 3
          tdnn_dims:
          - 128
          - 128
          num_filters: 2
          linear_num: 2 # hidden number of linear layer
          cell_num: 2 # cell units of the lstm
          hidden1: 2 # number of hidden units of fully connected layer
          attention: false # whether to use attention; false means use max-pooling
          attention_size: 64 # attention_size
          use_lstm_layer: false # whether to use lstm layer; false means no lstm layer
          use_dropout: true # whether to use dropout layer
          dropout_rate: 0.2
          use_bn: true # whether to use bn layer

          score_threshold: 0.5 # threshold to predict POS example
          threshold: 3 # threshold to predict POS example

    solver:
      name: SpeakerSolver
      quantization:
        enable: false # whether to quantize the model
        quant_delay: 0 # Number of steps after which weights and activations are quantized during training
      adversarial:
        enable: false # whether to use adversarial training
        adv_alpha: 0.5 # adversarial loss alpha
        adv_epslion: 0.1 # adversarial example epsilon
      model_average:
        enable: false # use average model
        var_avg_decay: 0.99 # the decay rate of variables
      optimizer:
        name: adam
        epochs: 5 # maximum epochs
        batch_size: 4 # number of elements in a training batch
        loss: CrossEntropyLoss
        label_smoothing: 0.0 # label smoothing rate
        learning_rate:
          rate: 0.0001 # learning rate of Adam optimizer
          type:  exp_decay # learning rate type
          decay_rate: 0.99  # the lr decay rate
          decay_steps: 100  # the lr decay_step for optimizer
        clip_global_norm: 3.0 # clip global norm
        multitask: False # whether is multi-task
      metrics:
        pos_label: 1 # int, same to sklearn
        cals:
        - name: AccuracyCal
          arguments: null
        - name: ConfusionMatrixCal
          arguments: null
        - name: PrecisionCal
          arguments:
            average: 'binary'
        - name: RecallCal
          arguments:
            average: 'binary'
        - name: F1ScoreCal
          arguments:
            average: 'binary'
      postproc:
          name: SpeakerPostProc
          log_verbose: false
          eval: true # compute metrics
          infer: true  # get predict results
          pred_path: null # None for `model_path`/infer, dumps infer output to this dir
          thresholds:
              - 0.5
          smoothing:
              enable: true
              count: 2
      saver:
        model_path: "ckpt/emotion-speech-cls/test"
        max_to_keep: 10
        save_checkpoints_steps: 10
        keep_checkpoint_every_n_hours: 10000
        checkpoint_every: 10 # the step to save checkpoint
        summary: false
        save_summary_steps: 5
        eval_on_dev_every_secs: 1
        print_every: 10
        resume_model_path: ""
      run_config:
        debug: false # use tfdbg
        tf_random_seed: null # 0-2**32; null is None, try to read data from /dev/urandom if available or seed from the clock otherwise
        allow_soft_placement: true
        log_device_placement: false
        intra_op_parallelism_threads: 1
        inter_op_parallelism_threads: 1
        allow_growth: true
        log_step_count_steps: 1 #The frequency, in number of global steps, that the global step/sec and the loss will be logged during training.
      distilling:
        enable: false
        name : Teacher
        loss : DistillationLoss
        temperature: 5
        alpha: 0.5
        teacher_model: ''

    serving:
      enable: true
      name : Evaluate
      model: '' # saved model dir, ckpt dir, or frozen_model.pb
      inputs: 'inputs:0'
      outpus: 'softmax_output:0'
    '''

    # write config to file
    tempdir = self.get_temp_dir()
    #tempdir = 'bar'
    os.makedirs(tempdir, exist_ok=True)

    config_path = str(Path(tempdir).joinpath('speaker_task.yaml'))
    logging.info("config path: {}".format(config_path))
    with open(config_path, 'w', encoding='utf-8') as f:  #pylint: disable=invalid-name
      f.write(self.conf_str)

    # load config
    config = utils.load_config(config_path)
    logging.info("config: {}".format(config))

    # edit path in config
    dataset_path = Path(tempdir).joinpath('data')
    if not dataset_path.exists():
      dataset_path.mkdir()
    dataset_path_str = str(dataset_path)
    config['data']['train']['paths'] = [dataset_path_str]
    config['data']['eval']['paths'] = [dataset_path_str]
    config['data']['infer']['paths'] = [dataset_path_str]

    # generate dummy data
    feat_dim = config['data']['task']['audio']['feature_size']
    kaldi_dir_utils.gen_dummy_data_dir(
        dataset_path_str, 2, 2, feat_dim=feat_dim)

    solver_name = config['solver']['name']
    self.solver = registers.solver[solver_name](config)

    # config after process
    self.config = self.solver.config
Example #21
 def setUp(self):
   super().setUp()
   import_all_modules_for_register()