Example #1
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            seq_length=seq_len)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate,
                                     momentum=momentum,
                                     optimizer=optimizer)

    n_epoch = begin_epoch
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        module.init_optimizer(kvstore='device',
                              optimizer=args.config.get('train', 'optimizer'),
                              optimizer_params={
                                  'clip_gradient': clip_gradient,
                                  'wd': weight_decay
                              },
                              force_init=True)

    reset_optimizer()

    while True:

        if n_epoch >= num_epoch:
            break

        eval_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            if (nbatch + 1) % show_every == 0:
                module.update_metric(eval_metric, data_batch.label)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        for nbatch, data_batch in enumerate(data_val):
            module.update_metric(eval_metric, data_batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()
        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
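
Aside: the loop above only talks to the scheduler through plain attributes — it writes effective_sample_count once per batch, and later variants reassign learning_rate once per epoch. The stub below is a minimal sketch of that contract, not the repository's SimpleLRScheduler; the class name, the default momentum, and the __call__ behaviour are assumptions for illustration only.

class MinimalLRScheduler(object):
    """Hypothetical stand-in mimicking the attributes do_training touches."""

    def __init__(self, learning_rate, momentum=0.9, optimizer='sgd'):
        self.learning_rate = learning_rate       # the epoch loop may reassign this directly
        self.momentum = momentum
        self.optimizer = optimizer
        self.effective_sample_count = None       # written once per data_batch above

    def __call__(self, num_update):
        # an MXNet optimizer queries its lr_scheduler as a callable
        return self.learning_rate


scheduler = MinimalLRScheduler(learning_rate=0.005, momentum=0.9)
scheduler.effective_sample_count = 128           # what the training loop does per batch
print(scheduler(1))                              # -> 0.005
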
Example #2
    # set parameters from the data section (common)
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'Define mode in the cfg file first; it must be one of train, predict, or load.')

    # get meta file where character to number conversions are defined
    language = args.config.get('data', 'language')
    labelUtil = LabelUtil.getInstance()
    if language == "en":
        labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
    else:
        raise Exception("Error: Language Type: %s" % language)
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    batch_size = args.config.getint('common', 'batch_size')

    # check that the number of GPUs is a positive divisor of the batch size
    if batch_size % num_gpu != 0:
        raise Exception('num_gpu should be a positive divisor of batch_size')

    if mode == "predict":
        data_train, args = load_data(args)
    elif mode == "train" or mode == "load":
        data_train, data_val, args = load_data(args)

    # log current config
    config_logger = ConfigLogger(log)
    config_logger(args.config)
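
This snippet reads everything through args.config, which is assumed here to behave like Python's standard ConfigParser. The fragment below sketches the kind of cfg content those get/getint calls expect; the section and option names come from the code above, while the values (and the use of read_string) are placeholders for illustration.

from configparser import ConfigParser

# Hypothetical excerpt of the kind of .cfg file this code expects; the values
# are only placeholders.
EXAMPLE_CFG = """
[common]
mode = train
batch_size = 8

[data]
language = en
"""

config = ConfigParser()
config.read_string(EXAMPLE_CFG)

mode = config.get('common', 'mode')
assert mode in ['train', 'predict', 'load']
batch_size = config.getint('common', 'batch_size')
print(mode, batch_size, config.get('data', 'language'))
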
Example #3
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum, optimizer=optimizer)

    n_epoch = begin_epoch
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        module.init_optimizer(kvstore='device',
                              optimizer=args.config.get('train', 'optimizer'),
                              optimizer_params={'clip_gradient': clip_gradient,
                                                'wd': weight_decay},
                              force_init=True)

    reset_optimizer()

    while True:

        if n_epoch >= num_epoch:
            break

        eval_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            if (nbatch+1) % show_every == 0:
                module.update_metric(eval_metric, data_batch.label)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        for nbatch, data_batch in enumerate(data_val):
            module.update_metric(eval_metric, data_batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()
        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
Example #4
    def __init__(self, args):
        self.args = args
        # set parameters from the data section (common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check that the number of GPUs is a positive divisor of the batch size for data parallelism
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(log)
        self.config_logger(args.config)

        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))

        default_bucket_key = self.buckets[-1]
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)
        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        # all_layers = symbol.get_internals()
        # s_sym = all_layers['concat36457_output']
        # sm = mx.sym.SoftmaxOutput(data=s_sym, name='softmax')

        # self.model = STTBucketingModule(
        #     sym_gen=self.model_loaded,
        #     default_bucket_key=default_bucket_key,
        #     context=self.contexts
        # )
        s_mod = mx.mod.BucketingModule(sym_gen=self.model_loaded,
                                       context=self.contexts,
                                       default_bucket_key=default_bucket_key)

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        self.init_states = prepare_data_template.prepare_data(self.args)
        self.width = self.args.config.getint('data', 'width')
        self.height = self.args.config.getint('data', 'height')
        s_mod.bind(data_shapes=[
            ('data',
             (self.batch_size, default_bucket_key, self.width * self.height))
        ] + self.init_states,
                   for_training=False)

        s_mod.set_params(self.arg_params,
                         self.aux_params,
                         allow_extra=True,
                         allow_missing=True)
        for bucket in self.buckets:
            provide_data = [
                ('data', (self.batch_size, bucket, self.width * self.height))
            ] + self.init_states
            s_mod.switch_bucket(bucket_key=bucket, data_shapes=provide_data)

        self.model = s_mod

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            log.info("vacab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            log.info("language model: "
                     "is_character_based = %d," % lm_char_based +
                     " max_order = %d," % lm_max_order +
                     " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=km.score)
            self.scorer = km.score
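
Before switching buckets, the constructor derives one ('data', (batch_size, bucket, width * height)) shape per bucket on top of the RNN init states. The sketch below restates that bookkeeping in plain Python with made-up shapes; the init-state name and the numeric values are assumptions, only the shape layout mirrors the code above.

import json

# placeholder values standing in for what the cfg file would provide
batch_size, width, height = 4, 10, 161
init_states = [('l0_init_h', (batch_size, 1024))]   # hypothetical RNN init state
buckets = json.loads('[400, 800, 1600]')            # 'arch' -> 'buckets' is stored as JSON text
default_bucket_key = buckets[-1]

provide_data_per_bucket = {
    bucket: [('data', (batch_size, bucket, width * height))] + init_states
    for bucket in buckets
}
print(default_bucket_key, provide_data_per_bucket[400])
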
Example #5
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil.getInstance().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))


    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=int((nbatch + 1) / save_checkpoint_every_n_batch) - 1,
                                       save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forwarding with is_train=False leads to a high CER when batch norm is used
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate = learning_rate / learning_rate_annealing

    log.info('FINISH')
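
The mid-epoch checkpoint call above numbers its files with epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1), i.e. a zero-based counter that advances once every save_checkpoint_every_n_batch batches. Below is a small standalone check of that arithmetic; the interval of 1000 is just an example value.

def batch_checkpoint_index(nbatch, save_checkpoint_every_n_batch):
    # Same arithmetic as the save_checkpoint call above: 0 for the first
    # checkpoint of the epoch, 1 for the second, and so on.
    return int((nbatch + 1) / save_checkpoint_every_n_batch) - 1

every = 1000
assert batch_checkpoint_index(999, every) == 0     # saved at nbatch == 999
assert batch_checkpoint_index(1999, every) == 1    # saved at nbatch == 1999
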
Example #6
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                module.save_checkpoint(
                    prefix=get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forwarding with is_train=False leads to a high CER when batch norm is used
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate = learning_rate / learning_rate_annealing

    log.info('FINISH')
Example #7
def do_train(args, dataA_iter, dataB_iter, G_A, G_B, D_A, D_B, G_A_trainer,
             G_B_trainer, D_A_trainer, D_B_trainer, loss1, loss2):

    num_iteration = args.config.getint('train', 'num_iteration')
    lambda_cyc = args.config.getfloat('train', 'lambda_cyc')
    lambda_id = args.config.getfloat('train', 'lambda_id')
    feat_dim = args.config.getint('data', 'feat_dim')
    segment_length = args.config.getint('train', 'segment_length')
    show_loss_every = args.config.getint('train', 'show_loss_every')
    G_learning_rate = args.config.getfloat('train', 'G_learning_rate')
    D_learning_rate = args.config.getfloat('train', 'D_learning_rate')
    source_speaker = args.config.get('data', 'source_speaker')
    target_speaker = args.config.get('data', 'target_speaker')
    contexts = parse_contexts(args)
    G_lr_decay = G_learning_rate / 200000
    D_lr_decay = D_learning_rate / 200000

    dataA_iter.reset()
    dataB_iter.reset()

    label = nd.zeros((1, 1), ctx=contexts[0])

    loss_cyc_A = 0
    loss_cyc_B = 0
    loss_D_A_fake = 0
    loss_D_B_fake = 0
    loss_D_A = 0
    loss_D_B = 0

    for p in G_A.collect_params():
        G_A.collect_params()[p].grad_req = 'add'
    for p in G_B.collect_params():
        G_B.collect_params()[p].grad_req = 'add'
    for p in D_A.collect_params():
        D_A.collect_params()[p].grad_req = 'add'
    for p in D_B.collect_params():
        D_B.collect_params()[p].grad_req = 'add'

    for iter in range(num_iteration):

        if iter == 10000:
            lambda_id = 0

        if iter >= 200000:
            G_A_trainer.set_learning_rate(G_A_trainer.learning_rate -
                                          G_lr_decay)
            G_B_trainer.set_learning_rate(G_B_trainer.learning_rate -
                                          G_lr_decay)
            D_A_trainer.set_learning_rate(D_A_trainer.learning_rate -
                                          D_lr_decay)
            D_B_trainer.set_learning_rate(D_B_trainer.learning_rate -
                                          D_lr_decay)

        G_A.collect_params().zero_grad()
        G_B.collect_params().zero_grad()
        D_A.collect_params().zero_grad()
        D_B.collect_params().zero_grad()

        inputA = dataA_iter.next()
        inputB = dataB_iter.next()
        inputA = inputA.as_in_context(contexts[0])
        inputB = inputB.as_in_context(contexts[0])

        ##############################################################################
        # Train Generator
        ##############################################################################

        # calculate loss for inputA
        inputA.attach_grad()
        with autograd.record():
            fakeB_tmp = G_A(inputA)
        fakeB = fakeB_tmp.copy()
        fakeB.attach_grad()
        with autograd.record():
            cycleA_tmp = G_B(fakeB)
        cycleA = cycleA_tmp.copy()
        cycleA.attach_grad()
        with autograd.record():
            L_cycleA = loss1(cycleA, inputA)
        L_cycleA.backward()
        cycleA_grad = cycleA.grad * lambda_cyc
        cycleA_tmp.backward(cycleA_grad)

        label[:] = 1

        fakeB.attach_grad()
        with autograd.record():
            fakeB_D = nd.reshape(fakeB, (1, 1, feat_dim, segment_length))
            pred = D_B(fakeB_D)
            DlossB = loss2(pred, label)
        DlossB.backward()
        fakeB_grad_D = fakeB.grad
        fakeB_tmp.backward(fakeB_grad_D)

        # calculate loss for inputB
        inputB.attach_grad()
        with autograd.record():
            fakeA_tmp = G_B(inputB)
        fakeA = fakeA_tmp.copy()
        fakeA.attach_grad()
        with autograd.record():
            cycleB_tmp = G_A(fakeA)
        cycleB = cycleB_tmp.copy()
        cycleB.attach_grad()
        with autograd.record():
            L_cycleB = loss1(cycleB, inputB)
        L_cycleB.backward()
        cycleB_grad = cycleB.grad * lambda_cyc
        cycleB_tmp.backward(cycleB_grad)

        label[:] = 1
        fakeA.attach_grad()
        with autograd.record():
            fakeA_D = nd.reshape(fakeA, (1, 1, feat_dim, segment_length))
            pred = D_A(fakeA_D)
            DlossA = loss2(pred, label)
        DlossA.backward()
        fakeA_grad_D = fakeA.grad
        fakeA_tmp.backward(fakeA_grad_D)

        # identity loss
        inputB.attach_grad()
        with autograd.record():
            indenB_tmp = G_A(inputB)
        indenB = indenB_tmp.copy()
        indenB.attach_grad()
        with autograd.record():
            L = loss1(indenB, inputB)
        L.backward()
        indenB_grad = indenB.grad * lambda_id
        indenB_tmp.backward(indenB_grad)

        inputA.attach_grad()
        with autograd.record():
            indenA_tmp = G_B(inputA)
        indenA = indenA_tmp.copy()
        indenA.attach_grad()
        with autograd.record():
            L = loss1(indenA, inputA)
        L.backward()
        indenA_grad = indenA.grad * lambda_id
        indenA_tmp.backward(indenA_grad)

        ##############################################################################
        # Train Discriminator and Update
        ##############################################################################

        fakeB = G_A(inputA)
        fakeA = G_B(inputB)

        def train_discriminator(modD, modD_trainer, real, fake):
            label[:] = 1
            real.attach_grad()
            with autograd.record():
                real = nd.reshape(real, (1, 1, feat_dim, segment_length))
                pred = modD(real)
                L_true = loss2(pred, label)
            L_true.backward()

            label[:] = 0
            fake.attach_grad()
            with autograd.record():
                fake = nd.reshape(fake, (1, 1, feat_dim, segment_length))
                pred = modD(fake)
                L_fake = loss2(pred, label)
            L_fake.backward()

            L = L_fake + L_true

            modD_trainer.step(1)
            return L / 2

        lossD_A = train_discriminator(D_A, D_A_trainer, inputA, fakeA)
        lossD_B = train_discriminator(D_B, D_B_trainer, inputB, fakeB)

        ##############################################################################
        # Update Generator
        ##############################################################################

        G_A_trainer.step(1)
        G_B_trainer.step(1)

        loss_cyc_A += L_cycleA.asnumpy()[0]
        loss_cyc_B += L_cycleB.asnumpy()[0]
        loss_D_B_fake += DlossB.asnumpy()[0]
        loss_D_A_fake += DlossA.asnumpy()[0]
        loss_D_A += lossD_A.asnumpy()[0]
        loss_D_B += lossD_B.asnumpy()[0]

        if iter % show_loss_every == 0 and iter != 0:
            loss_cyc_A /= show_loss_every
            loss_cyc_B /= show_loss_every
            loss_D_B_fake /= show_loss_every
            loss_D_A_fake /= show_loss_every
            loss_D_A /= show_loss_every
            loss_D_B /= show_loss_every
            logging.info(
                '[%s] | iter[%d] | loss_cyc_A:%f | loss_cyc_B:%f | loss_D_B_fake:%f | loss_D_A_fake:%f | loss_D_A:%f | loss_D_B:%f',
                time.ctime(), iter, loss_cyc_A, loss_cyc_B, loss_D_B_fake,
                loss_D_A_fake, loss_D_A, loss_D_B)
            loss_cyc_A = 0
            loss_cyc_B = 0
            loss_D_B_fake = 0
            loss_D_A_fake = 0
            loss_D_A = 0
            loss_D_B = 0

            # save all four networks, encoding the hyper-parameters in the file name
            param_file = (source_speaker + '-' + target_speaker +
                          '_mgc-' + str(feat_dim) +
                          '_iteration-' + str(iter) +
                          '_seglen-' + str(segment_length) +
                          '_lambda-' + str(lambda_cyc) + '-' + str(lambda_id) +
                          '_lr-' + str(G_learning_rate) + '-' + str(D_learning_rate) +
                          '.params')
            for net_name, net in (('G_A', G_A), ('G_B', G_B),
                                  ('D_A', D_A), ('D_B', D_B)):
                net.collect_params().save('checkpoints/' + net_name + '/' +
                                          net_name + '_' + param_file)
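
From iteration 200000 onward the four trainers' learning rates are reduced by a constant amount per step (the initial rate divided by 200000), so they reach zero around iteration 400000. The helper below is only a restatement of that arithmetic in plain Python, not code from the repository; the 0.0002 rates in the example calls are made-up values.

def decayed_lr(initial_lr, iteration, decay_start=200000, decay_steps=200000):
    # constant until decay_start, then shrink by initial_lr / decay_steps per
    # iteration (the G_lr_decay / D_lr_decay amounts above), clamped at zero
    if iteration < decay_start:
        return initial_lr
    step = initial_lr / decay_steps
    return max(initial_lr - step * (iteration - decay_start), 0.0)

print(decayed_lr(0.0002, 100000))   # before decay starts: unchanged
print(decayed_lr(0.0002, 300000))   # about halfway through the decay
print(decayed_lr(0.0002, 400000))   # essentially zero at the end of the schedule
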
Example #8
    def __init__(self):
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided, '
                            'e.g. python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from the data section (common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check that the number of GPUs is a positive divisor of the batch size for data parallelism
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                        default_bucket_key=default_bucket_key,
                                        context=self.contexts)

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        self.model.bind(data_shapes=[
            ('data', (self.batch_size, default_bucket_key, width * height))
        ] + init_states,
                        label_shapes=[
                            ('label',
                             (self.batch_size,
                              self.args.config.getint('arch',
                                                      'max_label_length')))
                        ],
                        for_training=True)

        _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        self.model.set_params(self.arg_params,
                              self.aux_params,
                              allow_extra=True,
                              allow_missing=True)

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vacab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.scorer = km.score
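
The seeding convention above treats -1 in the cfg file as "leave this seed unfixed" and falls back to a time-derived seed for parameter initialization. The sketch below mirrors that logic without importing MXNet; mx_seed_fn stands in for mx.random.seed, and print is used only so the example runs anywhere.

from datetime import datetime
import numpy as np

def apply_seeds(random_seed, mx_random_seed, mx_seed_fn):
    # -1 in the cfg file means "leave this seed unfixed"
    if random_seed != -1:
        np.random.seed(random_seed)              # shuffling of the data list
    if mx_random_seed != -1:
        mx_seed_fn(mx_random_seed)               # parameter initialization
    else:
        mx_seed_fn(hash(datetime.now()))         # fresh, time-derived seed each run

# mx_seed_fn would be mx.random.seed in the real code; print keeps this runnable anywhere
apply_seeds(random_seed=123, mx_random_seed=-1, mx_seed_fn=print)
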
Example #9
def do_training(args, module, data_train, data_val, begin_epoch=0, kv=None):
    from distutils.dir_util import mkpath

    host_name = socket.gethostname()
    log = LogUtil().getlogger()
    mkpath(os.path.dirname(config_util.get_checkpoint_path(args)))

    # seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    val_batch_size = args.config.getint('common', 'val_batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = config_util.parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=val_batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_start = args.config.getfloat('train', 'learning_rate_start')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')
    lr_factor = args.config.getfloat('train', 'lr_factor')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # kv = mx.kv.create(kvstore_option)
    # data = mx.io.ImageRecordIter(num_parts=kv.num_workers, part_index=kv.rank)
    # # a.set_optimizer(optimizer)
    # updater = mx.optimizer.get_updater(optimizer)
    # a._set_updater(updater=updater)

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        prefix = args.config.get('common', 'prefix')
        if os.path.isabs(prefix):
            model_path = config_util.get_checkpoint_path(args).rsplit(
                "/", 1)[0] + "/" + str(model_name[:-5])
        # symbol, data_names, label_names = module(1600)
        model = mx.mod.BucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)

        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # model.set_params(arg_params2, aux_params, allow_missing=True, allow_extra=True)

        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))
        # model_file = args.config.get('common', 'model_file')
        # model_name = os.path.splitext(model_file)[0]
        # model_num_epoch = int(model_name[-4:])
        # model_path = 'checkpoints/' + str(model_name[:-5])
        # _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # module.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    # lr, lr_scheduler = _get_lr_scheduler(args, kv)

    def reset_optimizer(force_init=False):
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kv,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    # tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    learning_rate_pre = 0
    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info(host_name + '---------train---------')

        step_epochs = [
            int(l)
            for l in args.config.get('train', 'lr_step_epochs').split(',')
        ]
        # warm up to step_epochs[0] if step_epochs[0] > 0
        if n_epoch < step_epochs[0]:
            learning_rate_cur = learning_rate_start + n_epoch * (
                learning_rate - learning_rate_start) / step_epochs[0]
        else:
            # scaling lr every epoch
            if len(step_epochs) == 1:
                learning_rate_cur = learning_rate
                for s in range(n_epoch):
                    learning_rate_cur /= learning_rate_annealing
            # scaling lr by step_epochs[1:]
            else:
                learning_rate_cur = learning_rate
                for s in step_epochs[1:]:
                    if n_epoch > s:
                        learning_rate_cur *= lr_factor

        if learning_rate_pre and args.config.getboolean(
                'train', 'momentum_correction'):
            lr_scheduler.learning_rate = learning_rate_cur * learning_rate_cur / learning_rate_pre
        else:
            lr_scheduler.learning_rate = learning_rate_cur
        learning_rate_pre = learning_rate_cur
        log.info("n_epoch %d's lr is %.7f" %
                 (n_epoch, lr_scheduler.learning_rate))
        summary_writer.add_scalar('lr', lr_scheduler.learning_rate, n_epoch)
        for nbatch, data_batch in enumerate(data_train):

            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                # loss_metric.set_audio_paths(data_batch.index)
                module.update_metric(loss_metric, data_batch.label)
                # print("loss=========== %.2f" % loss_metric.get_batch_loss())
            # summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                save_checkpoint(
                    module,
                    prefix=config_util.get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info(host_name + '---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forwarding with is_train=False leads to a high CER when batch norm is used
            module.forward(data_batch, is_train=True)
            eval_metric.set_audio_paths(data_batch.index)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, val_ctc_loss = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d), ctc_loss=%f", n_epoch,
                 val_cer, int(val_n_label - val_l_dist), val_n_label,
                 val_ctc_loss)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        summary_writer.add_scalar('loss validation', val_ctc_loss, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'
        np.random.seed(n_epoch)
        data_train.reset()
        data_train.is_first_epoch = False

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            save_checkpoint(module,
                            prefix=config_util.get_checkpoint_path(args),
                            epoch=n_epoch,
                            save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
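
The epoch loop above recomputes the learning rate from scratch every epoch: linear warm-up from learning_rate_start over the first step_epochs[0] epochs, then either per-epoch division by learning_rate_annealing (single-entry step list) or a one-off lr_factor multiplication for every step epoch already passed. The function below restates that arithmetic as pure Python so it can be checked in isolation; the numeric arguments in the example calls are invented.

def lr_for_epoch(n_epoch, learning_rate, learning_rate_start,
                 learning_rate_annealing, lr_factor, step_epochs):
    # linear warm-up from learning_rate_start to learning_rate over step_epochs[0] epochs
    if n_epoch < step_epochs[0]:
        return learning_rate_start + n_epoch * (
            learning_rate - learning_rate_start) / step_epochs[0]
    if len(step_epochs) == 1:
        # anneal every epoch
        lr = learning_rate
        for _ in range(n_epoch):
            lr /= learning_rate_annealing
        return lr
    # otherwise scale by lr_factor once for each step epoch already passed
    lr = learning_rate
    for s in step_epochs[1:]:
        if n_epoch > s:
            lr *= lr_factor
    return lr

print(lr_for_epoch(2, 1e-3, 1e-4, 1.1, 0.1, [5]))        # still warming up
print(lr_for_epoch(12, 1e-3, 1e-4, 1.1, 0.1, [5, 10]))   # after the step at epoch 10
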
Example #10
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,
                            is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,
                            is_logging=enable_logging_train_metric, is_epoch_end=False)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    n_epoch = begin_epoch

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))


    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        if optimizer == "sgd":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        elif optimizer == "adam":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    #'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        else:
            raise Exception('Supported optimizers are sgd and adam. If you want to use others, define them in train.py')
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)

    #tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    while True:

        if n_epoch >= num_epoch:
            break

        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=int((nbatch + 1) / save_checkpoint_every_n_batch) - 1,
                                       save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forwarding with is_train=False leads to a high CER when batch norm is used
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'validation CER could not be read from the eval metric'

        data_train.reset()

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        # learning rate annealing: from the second epoch on, the scheduler runs at learning_rate / learning_rate_annealing
        lr_scheduler.learning_rate = learning_rate / learning_rate_annealing

    log.info('FINISH')
Beispiel #11
0
    def __init__(self):
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided, '
                            'e.g. python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

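        # data generator: load the stored feats_mean/feats_std files and use them to normalize input features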
        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # the number of GPUs should be a positive divisor of the batch size for data-parallel training
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(95))
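        # build the character/label mapping and record the resulting number of output classes in the config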
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        # self.model = STTBucketingModule(
        #     sym_gen=self.model_loaded,
        #     default_bucket_key=default_bucket_key,
        #     context=self.contexts
        # )

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
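        # unroll the network once per bucket length and save each symbol as JSON under checkpoints/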
        for bucket in self.buckets:
            net, init_state_names, ll = self.model_loaded(bucket)
            net.save('checkpoints/%s-symbol.json' % bucket)
        input_shapes = dict([('data', (self.batch_size, default_bucket_key, width * height))] +
                            init_states +
                            [('label', (1, 18))])
        # self.executor = net.simple_bind(ctx=mx.cpu(), **input_shapes)

        # self.model.bind(data_shapes=[('data', (self.batch_size, default_bucket_key, width * height))] + init_states,
        #                 label_shapes=[
        #                     ('label', (self.batch_size, self.args.config.getint('arch', 'max_label_length')))],
        #                 for_training=True)

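        # load the pretrained checkpoint, cut the graph at the concat layer and attach a SoftmaxOutput head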
        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        all_layers = symbol.get_internals()
        concat = all_layers['concat36457_output']
        sm = mx.sym.SoftmaxOutput(data=concat, name='softmax')
        self.executor = sm.simple_bind(ctx=mx.cpu(), **input_shapes)
        # self.model.set_params(self.arg_params, self.aux_params, allow_extra=True, allow_missing=True)

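        # copy the weights that match by name from the loaded checkpoint into the bound executor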
        for key in self.executor.arg_dict.keys():
            if key in self.arg_params:
                self.arg_params[key].copyto(self.executor.arg_dict[key])
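        # pair the sorted state names with the executor's extra outputs (everything after the softmax output)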
        init_state_names.remove('data')
        init_state_names.sort()
        self.states_dict = dict(
            zip(init_state_names, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(
            (self.batch_size, default_bucket_key, width * height))

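        # decoding scorer: use the swig-wrapped beam-search Scorer with the configured KenLM model if available,
        # otherwise fall back to scoring with the plain kenlm Python package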
        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vacab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                             num_gpu=self.num_gpu,
                                             is_logging=True,
                                             scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                             num_gpu=self.num_gpu,
                                             is_logging=True,
                                             scorer=km.score)