Example #1
0
    label_mean_args = {
        "gpu_chunk": 32768,
        "lst_file": file_label_mean,
        "file_format": file_format,
        "separate_lines": True,
        "has_labels": False
    }

    test_sets = DataReadStream(test_data_args, feat_dim)
    label_mean_sets = DataReadStream(label_mean_args, label_dim)
    return (init_states, test_sets, label_mean_sets)


if __name__ == '__main__':
    # Script entry point: parse the arguments, dump the configuration to
    # stderr, build the data streams and read the settings from the config.
    args = parse_args()
    args.config.write(sys.stderr)

    # NOTE: the decoding method is read from the [train] section's 'method'
    # option.
    decoding_method = args.config.get('train', 'method')
    contexts = parse_contexts(args)

    init_states, test_sets, label_mean_sets = prepare_data(args)
    # First element of every init_states entry (presumably the state name --
    # confirm against prepare_data).
    state_names = [x[0] for x in init_states]

    # Integer settings from the [train], [arch] and [data] sections.
    batch_size = args.config.getint('train', 'batch_size')
    num_hidden = args.config.getint('arch', 'num_hidden')
    num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
    num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
    feat_dim = args.config.getint('data', 'xdim')
    label_dim = args.config.getint('data', 'ydim')
    # Output file name, read as a plain string.
    out_file = args.config.get('data', 'out_file')
Example #2
0
                                                 data_names=data_names, label_names=label_names,
                                                 load_optimizer_states=True)
        else:
            model_loaded = mx.module.Module.load(prefix=model_path, epoch=model_num_epoch, context=contexts,
                                                 data_names=data_names, label_names=label_names,
                                                 load_optimizer_states=False)

    return model_loaded, model_num_epoch


if __name__ == '__main__':
    # Script entry point: at least one command-line argument is required.
    if len(sys.argv) <= 1:
        raise Exception('cfg file path must be provided. ex)python main.py --configfile examplecfg.cfg')
    # Seed MXNet's random generator with the hash of the current timestamp.
    mx.random.seed(hash(datetime.now()))
    # set parameters from cfg file
    # NOTE(review): the usage text above shows a "--configfile" flag, yet
    # sys.argv[1] is passed to parse_args as-is -- confirm how parse_args
    # interprets it.
    args = parse_args(sys.argv[1])

    # Logger writing to the file named by [common] log_filename.
    log_filename = args.config.get('common', 'log_filename')
    log = LogUtil(filename=log_filename).getlogger()

    # Read the run mode from the [common] section; only 'train', 'predict'
    # and 'load' are accepted.
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.')

    # get meta file where character to number conversions are defined
    language = args.config.get('data', 'language')
    labelUtil = LabelUtil.getInstance()
    # For English, load the mapping from the bundled CSV resource.
    if language == "en":
        labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
Example #3
0
            reset_optimizer()
            module.set_params(*last_params)
        else:
            last_params = module.get_params()
            last_acc = curr_acc
            n_epoch += 1

            # save checkpoints
            mx.model.save_checkpoint(get_checkpoint_path(args), n_epoch,
                                     module.symbol, *last_params)

        if n_epoch == num_epoch:
            break

if __name__ == '__main__':
    # Script entry point: parse the arguments, dump the configuration to
    # stdout, build the data sets and read the settings from the config.
    args = parse_args()
    args.config.write(sys.stdout)

    # Training method name from the [train] section.
    training_method = args.config.get('train', 'method')
    contexts = parse_contexts(args)

    init_states, train_sets, dev_sets = prepare_data(args)
    # First element of every init_states entry (presumably the state name --
    # confirm against prepare_data).
    state_names = [x[0] for x in init_states]

    # Integer settings from the [train], [arch] and [data] sections.
    batch_size = args.config.getint('train', 'batch_size')
    num_hidden = args.config.getint('arch', 'num_hidden')
    num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
    num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
    feat_dim = args.config.getint('data', 'xdim')
    label_dim = args.config.getint('data', 'ydim')
Example #4
0
import config_util
# Parse the command-line arguments, load the configuration referenced by
# args.config and print the loaded object.
args = config_util.parse_args()
config = config_util.load_config(args.config)
print(config)
Example #5
0
#!/usr/bin/python
import os, sys
import codecs

import numpy as np
import mxnet as mx
from mxnet import gluon, nd

from config_util import parse_args, get_checkpoint_path, parse_contexts
from model import Generator, Discriminator
from data_iter import SentenceIter
import ttspacker
# Module-level packer instance; it sits outside the __main__ guard, so it is
# created on import as well as when the file is run as a script.
packer = ttspacker.ttspacker()
if __name__ == '__main__':
    # parse_args is always called with the fixed argument 'default.cfg'.
    args = parse_args('default.cfg')
    # Source/target scp entries from the [test] section.
    source_scp = args.config.get('test', 'test_source_scp')
    target_scp = args.config.get('test', 'test_target_scp')
    feat_dim = args.config.getint('data', 'feat_dim')
    # Training settings from the [train] section.
    segment_length = args.config.getint('train', 'segment_length')
    num_iteration = args.config.getint('train', 'num_iteration')
    G_learning_rate = args.config.getfloat('train', 'G_learning_rate')
    D_learning_rate = args.config.getfloat('train', 'D_learning_rate')
    momentum = args.config.getfloat('train', 'momentum')
    source_speaker = args.config.get('data', 'source_speaker')
    target_speaker = args.config.get('data', 'target_speaker')
    # Loss weights (presumably cycle-consistency and identity terms --
    # confirm against the training code).
    lambda_cyc = args.config.getfloat('train', 'lambda_cyc')
    lambda_id = args.config.getfloat('train', 'lambda_id')
    # NOTE: the local names say input/output while the config keys say
    # source_gv/target_gv.
    input_gv = args.config.get('data', 'source_gv')
    output_gv = args.config.get('data', 'target_gv')
    # Integer options from the [test] section (presumably the checkpoint
    # iterations to load for each generator -- TODO confirm).
    G_A_check_iter = args.config.getint('test', 'G_A_check_iter')
    G_B_check_iter = args.config.getint('test', 'G_B_check_iter')
Example #6
0
    def __init__(self):
        """Configure the instance from the cfg file named in ``sys.argv[1]``.

        In order, this seeds the NumPy and MXNet random generators, creates
        the logger, reads the [common]/[arch] settings, overrides several
        [arch] options in the in-memory config, loads the model and binds it
        as an ``STTBucketingModule``, copies the checkpoint parameters into
        the module and finally stores a language-model scorer on
        ``self.scorer``.

        Raises:
            Exception: if no command-line argument was supplied.
        """
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'ex)python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        # (a value of -1 leaves NumPy's generator unseeded)
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        # (a value of -1 falls back to the hash of the current timestamp)
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # Devices to run on; their count is kept as the number of GPUs.
        # NOTE(review): the original comment here mentioned a character map
        # meta file, but nothing of the sort is loaded at this point.

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # NOTE(review): the original comment announced a check that the GPU
        # count divides the batch size; no such check is performed here.
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        # Override [arch] options in the in-memory config. This happens after
        # the config was logged and before load_model reads self.args.
        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        # The language is fixed to "zh" regardless of the cfg file.
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        # n_classes is taken from the label utility's count.
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        # model_loaded is handed over as the module's symbol generator.
        self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                        default_bucket_key=default_bucket_key,
                                        context=self.contexts)

        # The module named by [arch] arch_file is imported at run time; its
        # prepare_data() result is appended to the data shapes below.
        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        # Bind with a (batch_size, default_bucket_key, width * height) data
        # input and a (batch_size, max_label_length) label input.
        self.model.bind(data_shapes=[
            ('data', (self.batch_size, default_bucket_key, width * height))
        ] + init_states,
                        label_shapes=[
                            ('label',
                             (self.batch_size,
                              self.args.config.getint('arch',
                                                      'max_label_length')))
                        ],
                        for_training=True)

        # Load the checkpoint parameters (the stored symbol is discarded) and
        # copy them into the module, tolerating extra and missing entries.
        _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        self.model.set_params(self.arg_params,
                              self.aux_params,
                              allow_extra=True,
                              allow_missing=True)

        # Prefer the swig_wrapper Scorer; when swig_wrapper cannot be
        # imported, fall back to the kenlm Python module. Note that
        # self.scorer is a Scorer object in the first case and the bound
        # method kenlm.Model.score in the second.
        try:
            from swig_wrapper import Scorer

            # The vocabulary is passed to the scorer as UTF-8 encoded bytes.
            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            # NOTE(review): "vacab_list" in the log message below is a typo
            # for "vocab_list".
            self.log.info("vacab_list len is %d" % len(vocab_list))
            # 0.26 and 0.1 are presumably the language-model weight
            # parameters (alpha/beta) -- TODO confirm against swig_wrapper.
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.scorer = km.score
Example #7
0
    def __init__(self):
        """Configure the instance from the cfg file named in ``sys.argv[1]``.

        In order, this seeds the NumPy and MXNet random generators, creates
        the logger, loads the stored feature mean/std into a
        ``DataGenerator``, overrides several [arch] options in the in-memory
        config, loads the model, saves one symbol JSON file per configured
        bucket, binds a CPU executor on a softmax placed over an internal
        layer of the checkpoint symbol, copies the checkpoint arguments into
        that executor and builds ``self.eval_metric`` with a language-model
        scorer.

        Raises:
            Exception: if no command-line argument was supplied.
            ValueError: if the ``[arch] buckets`` option is empty. The state
                names are taken from the last bucket, so at least one bucket
                is required.
        """
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'ex)python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        config = self.args.config

        # Random seeds; -1 means "not configured".
        self.random_seed = config.getint('common', 'random_seed')
        self.mx_random_seed = config.getint('common', 'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization;
        # without a configured seed, use the hash of the current timestamp
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))

        # set log file name
        self.log_filename = config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = config.get('common', 'mode')

        # Feature normalisation statistics are read from text files located
        # via generate_file_path(save_dir, model_name, ...).
        save_dir = 'checkpoints'
        model_name = config.get('common', 'prefix')
        max_freq = config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(config.get('arch', 'buckets'))
        # Fail fast: the state names used further down come from the last
        # bucket, so an empty list would otherwise surface much later as an
        # opaque NameError.
        if not self.buckets:
            raise ValueError(
                "the 'buckets' option of the [arch] section must list at "
                "least one bucket")

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = config.getint('common', 'batch_size')
        self.is_batchnorm = config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(config)

        # Override [arch] options in the in-memory config. This happens after
        # the config was logged and before load_model reads self.args.
        default_bucket_key = 1600
        config.set('arch', 'max_t_count', str(default_bucket_key))
        config.set('arch', 'max_label_length', str(95))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = config.getboolean('common', 'is_bi_graphemes')
        # The language is fixed to "zh" regardless of the cfg file.
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        config.set('arch', 'n_classes', str(self.labelUtil.get_count()))
        self.max_t_count = config.getint('arch', 'max_t_count')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        # The module named by [arch] arch_file is imported at run time; its
        # prepare_data() result is appended to the input shapes below.
        from importlib import import_module
        prepare_data_template = import_module(config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = config.getint('data', 'width')
        height = config.getint('data', 'height')

        # Save one symbol file per bucket next to the checkpoints.
        # init_state_names deliberately keeps the value from the last bucket.
        for bucket in self.buckets:
            net, init_state_names, _ = self.model_loaded(bucket)
            net.save('%s/%s-symbol.json' % (save_dir, bucket))

        # NOTE(review): the (1, 18) label shape is hard-coded and unrelated to
        # batch_size / max_label_length -- confirm it is intended.
        input_shapes = dict(
            [('data', (self.batch_size, default_bucket_key, width * height))]
            + init_states
            + [('label', (1, 18))])

        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        # NOTE(review): 'concat36457_output' is an auto-generated layer name
        # tied to one particular checkpoint; another checkpoint will most
        # likely need a different name.
        concat = symbol.get_internals()['concat36457_output']
        sm = mx.sym.SoftmaxOutput(data=concat, name='softmax')
        self.executor = sm.simple_bind(ctx=mx.cpu(), **input_shapes)

        # Copy every checkpoint argument that the executor also declares.
        for key, target in self.executor.arg_dict.items():
            if key in self.arg_params:
                self.arg_params[key].copyto(target)

        # Map the sorted state names (everything except 'data') onto the
        # executor outputs that follow the first one.
        init_state_names.remove('data')
        init_state_names.sort()
        self.states_dict = dict(
            zip(init_state_names, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(
            (self.batch_size, default_bucket_key, width * height))

        # Prefer the swig_wrapper Scorer; when swig_wrapper cannot be
        # imported, fall back to scoring with the kenlm Python module.
        try:
            from swig_wrapper import Scorer

            # The vocabulary is passed to the scorer as UTF-8 encoded bytes.
            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vocab_list len is %d" % len(vocab_list))
            # 0.26 and 0.1 are presumably the language-model weight
            # parameters (alpha/beta) -- TODO confirm against swig_wrapper.
            _ext_scorer = Scorer(0.26, 0.1,
                                 config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            scorer = _ext_scorer
        except ImportError:
            import kenlm
            scorer = kenlm.Model(config.get('common', 'kenlm')).score

        self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                         num_gpu=self.num_gpu,
                                         is_logging=True,
                                         scorer=scorer)