Esempio n. 1
0
def generate_accuracy_plots(exp_dir, output_dir, plot, key='accuracy',
                            file_basename='accuracy', comparison_dir=None,
                            start_iter=1,
                            latex_report=None, output_name='output'):
    """Write the accuracy report of an experiment and optionally plot it.

    A report is requested from ``log_parse.generate_accuracy_report`` for
    ``exp_dir`` and for every directory in ``comparison_dir``.  Only the
    report text of ``exp_dir`` is written to disk, as
    ``<output_dir>/<file_basename>.log``.  When ``plot`` is true, two curves
    per directory (labelled "train" and "valid") are drawn on one figure,
    which is saved as a PDF in ``output_dir``.

    Args:
        exp_dir: main experiment directory.
        output_dir: directory receiving the ``.log`` file and the PDF.
        plot: if true, generate the figure; otherwise only write the log.
        key: quantity passed to the report generator; also used as the
            y-axis label and in the figure title.
        file_basename: base name of the files that are written.
        comparison_dir: optional list of further experiment directories
            whose curves are added to the plot.
        start_iter: rows whose first column is below this value are not
            plotted.  Must be >= 1.
        latex_report: optional object; if given, its ``add_figure`` method
            is called with the PDF path and a caption.
        output_name: forwarded to the report generator and used in the
            figure title and (after ``latex_compliant_name``) in the PDF
            file name.

    Raises:
        AssertionError: if ``start_iter`` is smaller than 1.
        Exception: if ``plot`` is true and a directory yields no data rows.
    """
    assert start_iter >= 1

    fig = plt.figure() if plot else None
    plots = []

    comparison_dir = [] if comparison_dir is None else comparison_dir
    dirs = [exp_dir] + comparison_dir
    try:
        for index, cur_dir in enumerate(dirs):
            [accuracy_report, _accuracy_times,
             accuracy_data] = log_parse.generate_accuracy_report(
                 cur_dir, key, output_name)
            if index == 0:
                # this is the main experiment directory
                with open("{0}/{1}.log".format(output_dir,
                                               file_basename), "w") as f:
                    f.write(accuracy_report)

            if not plot:
                continue

            # Cycle through the palette so that a long list of comparison
            # directories does not run off the end of g_plot_colors.
            color_val = g_plot_colors[index % len(g_plot_colors)]
            data = np.array(accuracy_data)
            if data.shape[0] == 0:
                raise Exception("Couldn't find any rows for the accuracy plot")
            # Column 0 is the x-axis value; drop rows before start_iter.
            data = data[data[:, 0] >= start_iter, :]
            plot_handle, = plt.plot(data[:, 0], data[:, 1], color=color_val,
                                    linestyle="--",
                                    label="train {0}".format(cur_dir))
            plots.append(plot_handle)
            plot_handle, = plt.plot(data[:, 0], data[:, 2], color=color_val,
                                    label="valid {0}".format(cur_dir))
            plots.append(plot_handle)

        if plot:
            plt.xlabel('Iteration')
            plt.ylabel(key)
            # The legend is placed below the axes; it is pushed further down
            # for every additional directory because each adds two entries.
            lgd = plt.legend(handles=plots, loc='lower center',
                             bbox_to_anchor=(0.5, -0.2 + len(dirs) * -0.1),
                             ncol=1, borderaxespad=0.)
            plt.grid(True)
            fig.suptitle("{0} plot for {1}".format(key, output_name))
            figfile_name = '{0}/{1}_{2}.pdf'.format(
                output_dir, file_basename,
                latex_compliant_name(output_name))
            plt.savefig(figfile_name, bbox_extra_artists=(lgd,),
                        bbox_inches='tight')
            if latex_report is not None:
                latex_report.add_figure(
                    figfile_name,
                    "Plot of {0} vs iterations for {1}".format(key,
                                                               output_name))
    finally:
        # Release the figure even on error; pyplot otherwise keeps every
        # figure alive until it is explicitly closed.
        if fig is not None:
            plt.close(fig)
Esempio n. 2
0
def train(args, run_opts, background_process_handler):
    """ The main function for training.

    The work is organised in numbered stages; a stage is skipped when its
    number is lower than ``args.stage``:

        -5  initialise a basic "raw" network from configs/init.config
        -4  generate egs (only when ``args.egs_dir`` is None)
        -3  compute the preconditioning matrix
        -2  compute the pre-softmax prior scale vector
        -1  prepare the initial acoustic model
        0 .. num_iters-1   one training iteration each
        num_iters          combine models
        num_iters+1        compute average posteriors and adjust priors

    Afterwards the experiment directory is optionally cleaned up and an
    accuracy report is written to ``<args.dir>/accuracy.report`` (and mailed
    when ``args.email`` is set).  If ``args.exit_stage`` is reached inside
    the iteration loop the function returns immediately, without running
    combination, cleanup or reporting.

    Args:
        args: a Namespace object with the required parameters
            obtained from the function process_args()
        run_opts: RunOpts object obtained from the process_args()
        background_process_handler: passed through unchanged to
            train_lib.common.train_one_iteration() and
            train_lib.common.combine_models()

    Raises:
        Exception: if a required variable is missing from the configs/vars
            file, or if ``args.num_jobs_final`` exceeds the number of
            archives in the egs directory.
    """

    arg_string = pprint.pformat(vars(args))
    logger.info("Arguments for the experiment\n{0}".format(arg_string))

    # Set some variables.
    num_jobs = common_lib.get_number_of_jobs(args.ali_dir)
    feat_dim = common_lib.get_feat_dim(args.feat_dir)
    ivector_dim = common_lib.get_ivector_dim(args.online_ivector_dir)

    # split the training data into parts for individual jobs
    # we will use the same number of jobs as that used for alignment
    common_lib.split_data(args.feat_dir, num_jobs)
    shutil.copy('{0}/tree'.format(args.ali_dir), args.dir)

    with open('{0}/num_jobs'.format(args.dir), 'w') as f:
        f.write(str(num_jobs))

    config_dir = '{0}/configs'.format(args.dir)
    var_file = '{0}/vars'.format(config_dir)

    variables = common_train_lib.parse_generic_config_vars_file(var_file)

    # Set some variables.
    try:
        model_left_context = variables['model_left_context']
        model_right_context = variables['model_right_context']
        # this is really the number of times we add layers to the network for
        # discriminative pretraining
        num_hidden_layers = variables['num_hidden_layers']
    except KeyError as e:
        raise Exception("KeyError {0}: Variables need to be defined in "
                        "{1}".format(str(e), '{0}/configs'.format(args.dir)))

    left_context = args.chunk_left_context + model_left_context
    right_context = args.chunk_right_context + model_right_context

    # Initialize as "raw" nnet, prior to training the LDA-like preconditioning
    # matrix.  This first config just does any initial splicing that we do;
    # we do this as it's a convenient way to get the stats for the 'lda-like'
    # transform.

    if (args.stage <= -5):
        logger.info("Initializing a basic network for estimating "
                    "preconditioning matrix")
        common_lib.run_job(
            """{command} {dir}/log/nnet_init.log \
                    nnet3-init --srand=-2 {dir}/configs/init.config \
                    {dir}/init.raw""".format(command=run_opts.command,
                                             dir=args.dir))

    default_egs_dir = '{0}/egs'.format(args.dir)
    if (args.stage <= -4) and args.egs_dir is None:
        logger.info("Generating egs")

        train_lib.acoustic_model.generate_egs(
            data=args.feat_dir, alidir=args.ali_dir, egs_dir=default_egs_dir,
            left_context=left_context, right_context=right_context,
            run_opts=run_opts,
            frames_per_eg=args.frames_per_eg,
            srand=args.srand,
            egs_opts=args.egs_opts,
            cmvn_opts=args.cmvn_opts,
            online_ivector_dir=args.online_ivector_dir,
            samples_per_iter=args.samples_per_iter,
            transform_dir=args.transform_dir,
            stage=args.egs_stage)

    if args.egs_dir is None:
        egs_dir = default_egs_dir
    else:
        egs_dir = args.egs_dir

    # Only frames_per_eg and num_archives are needed here; the contexts
    # returned by the verification are not used by this function.
    [_egs_left_context, _egs_right_context,
     frames_per_eg, num_archives] = (
        common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim,
                                        left_context, right_context))
    assert(args.frames_per_eg == frames_per_eg)

    if (args.num_jobs_final > num_archives):
        raise Exception('num_jobs_final cannot exceed the number of archives '
                        'in the egs directory')

    # copy the properties of the egs to dir for
    # use during decoding
    common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)

    if (args.stage <= -3):
        logger.info('Computing the preconditioning matrix for input features')

        train_lib.common.compute_preconditioning_matrix(
            args.dir, egs_dir, num_archives, run_opts,
            max_lda_jobs=args.max_lda_jobs,
            rand_prune=args.rand_prune)

    if (args.stage <= -2):
        logger.info("Computing initial vector for FixedScaleComponent before"
                    " softmax, using priors^{prior_scale} and rescaling to"
                    " average 1".format(
                        prior_scale=args.presoftmax_prior_scale_power))

        common_train_lib.compute_presoftmax_prior_scale(
                args.dir, args.ali_dir, num_jobs, run_opts,
                presoftmax_prior_scale_power=args.presoftmax_prior_scale_power)

    if (args.stage <= -1):
        logger.info("Preparing the initial acoustic model.")
        train_lib.acoustic_model.prepare_initial_acoustic_model(
            args.dir, args.ali_dir, run_opts)

    # set num_iters so that as close as possible, we process the data
    # $num_epochs times, i.e. $num_iters*$avg_num_jobs) ==
    # $num_epochs*$num_archives, where
    # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2.
    num_archives_expanded = num_archives * args.frames_per_eg
    num_archives_to_process = args.num_epochs * num_archives_expanded
    num_archives_processed = 0
    # Floor division plus int() keeps num_iters an integer regardless of the
    # interpreter's '/' semantics or of num_epochs being a float; range()
    # below rejects anything else.
    num_iters = int((num_archives_to_process * 2)
                    // (args.num_jobs_initial + args.num_jobs_final))

    models_to_combine = common_train_lib.verify_iterations(
        num_iters, args.num_epochs,
        num_hidden_layers, num_archives_expanded,
        args.max_models_combine, args.add_layers_period,
        args.num_jobs_final)

    def learning_rate(iteration, current_num_jobs, num_archives_processed):
        # Thin wrapper that fills in the arguments which stay constant over
        # the whole run.
        return common_train_lib.get_learning_rate(iteration, current_num_jobs,
                                                  num_iters,
                                                  num_archives_processed,
                                                  num_archives_to_process,
                                                  args.initial_effective_lrate,
                                                  args.final_effective_lrate)

    logger.info("Training will run for {0} epochs = "
                "{1} iterations".format(args.num_epochs, num_iters))

    for iteration in range(num_iters):
        if (args.exit_stage is not None) and (iteration == args.exit_stage):
            logger.info("Exiting early due to --exit-stage {0}".format(
                iteration))
            return
        # Interpolate linearly from num_jobs_initial to num_jobs_final and
        # round to the nearest integer.
        current_num_jobs = int(0.5 + args.num_jobs_initial
                               + (args.num_jobs_final - args.num_jobs_initial)
                               * float(iteration) / num_iters)

        if args.stage <= iteration:
            logger.info("On iteration {0}, learning rate is {1}.".format(
                iteration, learning_rate(iteration, current_num_jobs,
                                         num_archives_processed)))

            train_lib.common.train_one_iteration(
                dir=args.dir,
                iter=iteration,
                srand=args.srand,
                egs_dir=egs_dir,
                num_jobs=current_num_jobs,
                num_archives_processed=num_archives_processed,
                num_archives=num_archives,
                learning_rate=learning_rate(iteration, current_num_jobs,
                                            num_archives_processed),
                minibatch_size=args.minibatch_size,
                frames_per_eg=args.frames_per_eg,
                num_hidden_layers=num_hidden_layers,
                add_layers_period=args.add_layers_period,
                left_context=left_context,
                right_context=right_context,
                momentum=args.momentum,
                max_param_change=args.max_param_change,
                shuffle_buffer_size=args.shuffle_buffer_size,
                run_opts=run_opts,
                background_process_handler=background_process_handler)

            if args.cleanup:
                # do a clean up everything but the last 2 models, under
                # certain conditions
                common_train_lib.remove_model(
                    args.dir, iteration - 2, num_iters, models_to_combine,
                    args.preserve_model_interval)

            if args.email is not None:
                # NOTE(review): if reporting_interval is a fraction the
                # product is generally not integral, so the modulo test
                # below may only match on iteration 0; an interval of 0
                # would raise ZeroDivisionError -- confirm intended values.
                reporting_iter_interval = num_iters * args.reporting_interval
                if iteration % reporting_iter_interval == 0:
                    # lets do some reporting
                    [report, _times, _data] = (
                        nnet3_log_parse.generate_accuracy_report(args.dir))
                    message = report
                    subject = ("Update : Expt {dir} : "
                               "Iter {iter}".format(dir=args.dir,
                                                    iter=iteration))
                    common_lib.send_mail(message, subject, args.email)

        # Counted even for skipped iterations so that a restarted run sees
        # the same value as an uninterrupted one.
        num_archives_processed = num_archives_processed + current_num_jobs

    if args.stage <= num_iters:
        logger.info("Doing final combination to produce final.mdl")
        train_lib.common.combine_models(
            dir=args.dir, num_iters=num_iters,
            models_to_combine=models_to_combine,
            egs_dir=egs_dir,
            left_context=left_context, right_context=right_context,
            run_opts=run_opts,
            background_process_handler=background_process_handler)

    if args.stage <= num_iters + 1:
        logger.info("Getting average posterior for purposes of "
                    "adjusting the priors.")
        avg_post_vec_file = train_lib.common.compute_average_posterior(
            dir=args.dir, iter='combined', egs_dir=egs_dir,
            num_archives=num_archives,
            left_context=left_context, right_context=right_context,
            prior_subset_size=args.prior_subset_size, run_opts=run_opts)

        logger.info("Re-adjusting priors based on computed posteriors")
        combined_model = "{dir}/combined.mdl".format(dir=args.dir)
        final_model = "{dir}/final.mdl".format(dir=args.dir)
        train_lib.common.adjust_am_priors(args.dir, combined_model,
                                          avg_post_vec_file, final_model,
                                          run_opts)

    if args.cleanup:
        logger.info("Cleaning up the experiment directory "
                    "{0}".format(args.dir))
        remove_egs = args.remove_egs
        if args.egs_dir is not None:
            # this egs_dir was not created by this experiment so we will not
            # delete it
            remove_egs = False

        common_train_lib.clean_nnet_dir(
            nnet_dir=args.dir, num_iters=num_iters, egs_dir=egs_dir,
            preserve_model_interval=args.preserve_model_interval,
            remove_egs=remove_egs)

    # do some reporting
    [report, _times, _data] = nnet3_log_parse.generate_accuracy_report(
        args.dir)
    if args.email is not None:
        common_lib.send_mail(report, "Update : Expt {0} : "
                                     "complete".format(args.dir), args.email)

    with open("{dir}/accuracy.report".format(dir=args.dir), "w") as f:
        f.write(report)

    common_lib.run_job("steps/info/nnet3_dir_info.pl "
                       "{0}".format(args.dir))
Esempio n. 3
0
def train(args, run_opts, background_process_handler):
    """ The main function for training.

    The work is organised in numbered stages; a stage is skipped when its
    number is lower than ``args.stage``:

        -5  initialise a basic "raw" network from configs/init.config
        -4  generate egs from targets (only when ``args.egs_dir`` is None)
        -3  compute the preconditioning matrix (only when the config
            variable ``add_lda`` is true)
        -1  prepare the initial network
        0 .. num_iters-1   one training iteration each
        num_iters          combine models
        num_iters+1        compute average posteriors (only when the config
                           variable ``include_log_softmax`` is true)

    Afterwards the experiment directory is optionally cleaned up and an
    accuracy report is written to ``<args.dir>/accuracy.report`` (and mailed
    when ``args.email`` is set).  If ``args.exit_stage`` is reached inside
    the iteration loop the function returns immediately, without running
    combination, cleanup or reporting.

    Args:
        args: a Namespace object with the required parameters
            obtained from the function process_args()
        run_opts: RunOpts object obtained from the process_args()
        background_process_handler: passed through unchanged to
            train_lib.common.train_one_iteration() and
            train_lib.common.combine_models()

    Raises:
        Exception: if a required variable is missing from the configs/vars
            file, if the dense target dimension disagrees with the
            configured ``num_targets``, or if ``args.num_jobs_final``
            exceeds the number of archives in the egs directory.
    """

    arg_string = pprint.pformat(vars(args))
    logger.info("Arguments for the experiment\n{0}".format(arg_string))

    # Set some variables.
    feat_dim = common_lib.get_feat_dim(args.feat_dir)
    ivector_dim = common_lib.get_ivector_dim(args.online_ivector_dir)

    config_dir = '{0}/configs'.format(args.dir)
    var_file = '{0}/vars'.format(config_dir)

    variables = common_train_lib.parse_generic_config_vars_file(var_file)

    # Set some variables.
    try:
        model_left_context = variables['model_left_context']
        model_right_context = variables['model_right_context']
        # this is really the number of times we add layers to the network for
        # discriminative pretraining
        num_hidden_layers = variables['num_hidden_layers']
        add_lda = common_lib.str_to_bool(variables['add_lda'])
        include_log_softmax = common_lib.str_to_bool(
            variables['include_log_softmax'])
    except KeyError as e:
        raise Exception("KeyError {0}: Variables need to be defined in "
                        "{1}".format(str(e), '{0}/configs'.format(args.dir)))

    left_context = args.chunk_left_context + model_left_context
    right_context = args.chunk_right_context + model_right_context

    # Initialize as "raw" nnet, prior to training the LDA-like preconditioning
    # matrix.  This first config just does any initial splicing that we do;
    # we do this as it's a convenient way to get the stats for the 'lda-like'
    # transform.

    if (args.stage <= -5):
        logger.info("Initializing a basic network")
        common_lib.run_job(
            """{command} {dir}/log/nnet_init.log \
                    nnet3-init --srand=-2 {dir}/configs/init.config \
                    {dir}/init.raw""".format(command=run_opts.command,
                                             dir=args.dir))

    default_egs_dir = '{0}/egs'.format(args.dir)
    if (args.stage <= -4) and args.egs_dir is None:
        logger.info("Generating egs")

        if args.use_dense_targets:
            target_type = "dense"
            try:
                num_targets = int(variables['num_targets'])
                if (common_lib.get_feat_dim_from_scp(args.targets_scp)
                        != num_targets):
                    raise Exception("Mismatch between num-targets provided to "
                                    "script vs configs")
            except KeyError:
                # For dense targets the variable is optional; -1 is passed
                # on when it is not defined.
                num_targets = -1
        else:
            target_type = "sparse"
            try:
                num_targets = int(variables['num_targets'])
            except KeyError as e:
                raise Exception("KeyError {0}: Variables need to be defined "
                                "in {1}".format(
                                    str(e), '{0}/configs'.format(args.dir)))

        train_lib.raw_model.generate_egs_using_targets(
            data=args.feat_dir, targets_scp=args.targets_scp,
            egs_dir=default_egs_dir,
            left_context=left_context, right_context=right_context,
            valid_left_context=left_context, valid_right_context=right_context,
            run_opts=run_opts,
            frames_per_eg=args.frames_per_eg,
            srand=args.srand,
            egs_opts=args.egs_opts,
            cmvn_opts=args.cmvn_opts,
            online_ivector_dir=args.online_ivector_dir,
            samples_per_iter=args.samples_per_iter,
            transform_dir=args.transform_dir,
            stage=args.egs_stage,
            target_type=target_type,
            num_targets=num_targets)

    if args.egs_dir is None:
        egs_dir = default_egs_dir
    else:
        egs_dir = args.egs_dir

    # Only frames_per_eg and num_archives are needed here; the contexts
    # returned by the verification are not used by this function.
    [_egs_left_context, _egs_right_context,
     frames_per_eg, num_archives] = (
        common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim,
                                        left_context, right_context))
    assert(args.frames_per_eg == frames_per_eg)

    if (args.num_jobs_final > num_archives):
        raise Exception('num_jobs_final cannot exceed the number of archives '
                        'in the egs directory')

    # copy the properties of the egs to dir for
    # use during decoding
    common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)

    if (add_lda and args.stage <= -3):
        logger.info('Computing the preconditioning matrix for input features')

        train_lib.common.compute_preconditioning_matrix(
            args.dir, egs_dir, num_archives, run_opts,
            max_lda_jobs=args.max_lda_jobs,
            rand_prune=args.rand_prune)

    if (args.stage <= -1):
        logger.info("Preparing the initial network.")
        common_train_lib.prepare_initial_network(args.dir, run_opts)

    # set num_iters so that as close as possible, we process the data
    # $num_epochs times, i.e. $num_iters*$avg_num_jobs) ==
    # $num_epochs*$num_archives, where
    # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2.
    num_archives_expanded = num_archives * args.frames_per_eg
    num_archives_to_process = args.num_epochs * num_archives_expanded
    num_archives_processed = 0
    # Floor division plus int() keeps num_iters an integer regardless of the
    # interpreter's '/' semantics or of num_epochs being a float; range()
    # below rejects anything else.
    num_iters = int((num_archives_to_process * 2)
                    // (args.num_jobs_initial + args.num_jobs_final))

    models_to_combine = common_train_lib.verify_iterations(
        num_iters, args.num_epochs,
        num_hidden_layers, num_archives_expanded,
        args.max_models_combine, args.add_layers_period,
        args.num_jobs_final)

    def learning_rate(iteration, current_num_jobs, num_archives_processed):
        # Thin wrapper that fills in the arguments which stay constant over
        # the whole run.
        return common_train_lib.get_learning_rate(iteration, current_num_jobs,
                                                  num_iters,
                                                  num_archives_processed,
                                                  num_archives_to_process,
                                                  args.initial_effective_lrate,
                                                  args.final_effective_lrate)

    logger.info("Training will run for {0} epochs = "
                "{1} iterations".format(args.num_epochs, num_iters))

    for iteration in range(num_iters):
        if (args.exit_stage is not None) and (iteration == args.exit_stage):
            logger.info("Exiting early due to --exit-stage {0}".format(
                iteration))
            return
        # Interpolate linearly from num_jobs_initial to num_jobs_final and
        # round to the nearest integer.
        current_num_jobs = int(0.5 + args.num_jobs_initial
                               + (args.num_jobs_final - args.num_jobs_initial)
                               * float(iteration) / num_iters)

        if args.stage <= iteration:
            train_lib.common.train_one_iteration(
                dir=args.dir,
                iter=iteration,
                srand=args.srand,
                egs_dir=egs_dir,
                num_jobs=current_num_jobs,
                num_archives_processed=num_archives_processed,
                num_archives=num_archives,
                learning_rate=learning_rate(iteration, current_num_jobs,
                                            num_archives_processed),
                dropout_edit_string=common_lib.get_dropout_edit_string(
                    args.dropout_schedule,
                    float(num_archives_processed) / num_archives_to_process,
                    iteration),
                minibatch_size=args.minibatch_size,
                frames_per_eg=args.frames_per_eg,
                num_hidden_layers=num_hidden_layers,
                add_layers_period=args.add_layers_period,
                left_context=left_context,
                right_context=right_context,
                momentum=args.momentum,
                max_param_change=args.max_param_change,
                shuffle_buffer_size=args.shuffle_buffer_size,
                run_opts=run_opts,
                get_raw_nnet_from_am=False,
                background_process_handler=background_process_handler)

            if args.cleanup:
                # do a clean up everything but the last 2 models, under
                # certain conditions
                common_train_lib.remove_model(
                    args.dir, iteration - 2, num_iters, models_to_combine,
                    args.preserve_model_interval,
                    get_raw_nnet_from_am=False)

            if args.email is not None:
                # NOTE(review): if reporting_interval is a fraction the
                # product is generally not integral, so the modulo test
                # below may only match on iteration 0; an interval of 0
                # would raise ZeroDivisionError -- confirm intended values.
                reporting_iter_interval = num_iters * args.reporting_interval
                if iteration % reporting_iter_interval == 0:
                    # lets do some reporting
                    [report, _times, _data] = (
                        nnet3_log_parse.generate_accuracy_report(args.dir))
                    message = report
                    subject = ("Update : Expt {dir} : "
                               "Iter {iter}".format(dir=args.dir,
                                                    iter=iteration))
                    common_lib.send_mail(message, subject, args.email)

        # Counted even for skipped iterations so that a restarted run sees
        # the same value as an uninterrupted one.
        num_archives_processed = num_archives_processed + current_num_jobs

    if args.stage <= num_iters:
        logger.info("Doing final combination to produce final.raw")
        train_lib.common.combine_models(
            dir=args.dir, num_iters=num_iters,
            models_to_combine=models_to_combine, egs_dir=egs_dir,
            left_context=left_context, right_context=right_context,
            run_opts=run_opts,
            background_process_handler=background_process_handler,
            get_raw_nnet_from_am=False)

    if include_log_softmax and args.stage <= num_iters + 1:
        logger.info("Getting average posterior for purposes of "
                    "adjusting the priors.")
        train_lib.common.compute_average_posterior(
            dir=args.dir, iter='final', egs_dir=egs_dir,
            num_archives=num_archives,
            left_context=left_context, right_context=right_context,
            prior_subset_size=args.prior_subset_size, run_opts=run_opts,
            get_raw_nnet_from_am=False)

    if args.cleanup:
        logger.info("Cleaning up the experiment directory "
                    "{0}".format(args.dir))
        remove_egs = args.remove_egs
        if args.egs_dir is not None:
            # this egs_dir was not created by this experiment so we will not
            # delete it
            remove_egs = False

        common_train_lib.clean_nnet_dir(
            nnet_dir=args.dir, num_iters=num_iters, egs_dir=egs_dir,
            preserve_model_interval=args.preserve_model_interval,
            remove_egs=remove_egs,
            get_raw_nnet_from_am=False)

    # do some reporting
    [report, _times, _data] = nnet3_log_parse.generate_accuracy_report(
        args.dir)
    if args.email is not None:
        common_lib.send_mail(report, "Update : Expt {0} : "
                                     "complete".format(args.dir), args.email)

    with open("{dir}/accuracy.report".format(dir=args.dir), "w") as f:
        f.write(report)

    common_lib.run_job("steps/info/nnet3_dir_info.pl "
                       "{0}".format(args.dir))
Esempio n. 4
0
def train(args, run_opts, background_process_handler):
    """ The main function for training.

    Args:
        args: a Namespace object with the required parameters
            obtained from the function process_args()
        run_opts: RunOpts object obtained from the process_args()
    """

    arg_string = pprint.pformat(vars(args))
    logger.info("Arguments for the experiment\n{0}".format(arg_string))

    # Check files
    chain_lib.check_for_required_files(args.feat_dir, args.tree_dir,
                                       args.lat_dir)

    # Set some variables.
    num_jobs = common_lib.get_number_of_jobs(args.tree_dir)
    feat_dim = common_lib.get_feat_dim(args.feat_dir)
    ivector_dim = common_lib.get_ivector_dim(args.online_ivector_dir)

    # split the training data into parts for individual jobs
    # we will use the same number of jobs as that used for alignment
    common_lib.split_data(args.feat_dir, num_jobs)
    shutil.copy('{0}/tree'.format(args.tree_dir), args.dir)
    with open('{0}/num_jobs'.format(args.dir), 'w') as f:
        f.write(str(num_jobs))

    config_dir = '{0}/configs'.format(args.dir)
    var_file = '{0}/vars'.format(config_dir)

    variables = common_train_lib.parse_generic_config_vars_file(var_file)

    # Set some variables.
    try:
        model_left_context = variables['model_left_context']
        model_right_context = variables['model_right_context']
        # this is really the number of times we add layers to the network for
        # discriminative pretraining
        num_hidden_layers = variables['num_hidden_layers']
    except KeyError as e:
        raise Exception("KeyError {0}: Variables need to be defined in "
                        "{1}".format(str(e), '{0}/configs'.format(args.dir)))

    left_context = args.chunk_left_context + model_left_context
    right_context = args.chunk_right_context + model_right_context

    # Initialize as "raw" nnet, prior to training the LDA-like preconditioning
    # matrix.  This first config just does any initial splicing that we do;
    # we do this as it's a convenient way to get the stats for the 'lda-like'
    # transform.
    if (args.stage <= -6):
        logger.info("Creating phone language-model")
        chain_lib.create_phone_lm(args.dir, args.tree_dir, run_opts,
                                  lm_opts=args.lm_opts)

    if (args.stage <= -5):
        logger.info("Creating denominator FST")
        chain_lib.create_denominator_fst(args.dir, args.tree_dir, run_opts)

    if (args.stage <= -4):
        logger.info("Initializing a basic network for estimating "
                    "preconditioning matrix")
        common_lib.run_kaldi_command(
            """{command} {dir}/log/nnet_init.log \
                    nnet3-init --srand=-2 {dir}/configs/init.config \
                    {dir}/init.raw""".format(command=run_opts.command,
                                             dir=args.dir))

    egs_left_context = left_context + args.frame_subsampling_factor/2
    egs_right_context = right_context + args.frame_subsampling_factor/2

    default_egs_dir = '{0}/egs'.format(args.dir)
    if (args.stage <= -3) and args.egs_dir is None:
        logger.info("Generating egs")
        # this is where get_egs.sh is called.
        chain_lib.generate_chain_egs(
            dir=args.dir, data=args.feat_dir,
            lat_dir=args.lat_dir, egs_dir=default_egs_dir,
            left_context=egs_left_context,
            right_context=egs_right_context,
            run_opts=run_opts,
            left_tolerance=args.left_tolerance,
            right_tolerance=args.right_tolerance,
            frame_subsampling_factor=args.frame_subsampling_factor,
            alignment_subsampling_factor=args.alignment_subsampling_factor,
            frames_per_eg=args.chunk_width,
            srand=args.srand,
            egs_opts=args.egs_opts,
            cmvn_opts=args.cmvn_opts,
            online_ivector_dir=args.online_ivector_dir,
            frames_per_iter=args.frames_per_iter,
            transform_dir=args.transform_dir,
            stage=args.egs_stage)

    if args.egs_dir is None:
        egs_dir = default_egs_dir
    else:
        egs_dir = args.egs_dir

    [egs_left_context, egs_right_context,
     frames_per_eg, num_archives] = (
        common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim,
                                        egs_left_context, egs_right_context))
    assert(args.chunk_width == frames_per_eg)
    num_archives_expanded = num_archives * args.frame_subsampling_factor

    if (args.num_jobs_final > num_archives_expanded):
        raise Exception('num_jobs_final cannot exceed the '
                        'expanded number of archives')

    # copy the properties of the egs to dir for
    # use during decoding
    common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)

    if (args.stage <= -2):
        logger.info('Computing the preconditioning matrix for input features')

        chain_lib.compute_preconditioning_matrix(
            args.dir, egs_dir, num_archives, run_opts,
            max_lda_jobs=args.max_lda_jobs,
            rand_prune=args.rand_prune)

    if (args.stage <= -1):
        logger.info("Preparing the initial acoustic model.")
        chain_lib.prepare_initial_acoustic_model(args.dir, run_opts)

    with open("{0}/frame_subsampling_factor".format(args.dir), "w") as f:
        f.write(str(args.frame_subsampling_factor))

    # set num_iters so that as close as possible, we process the data
    # $num_epochs times, i.e. $num_iters*$avg_num_jobs) ==
    # $num_epochs*$num_archives, where
    # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2.
    num_archives_to_process = args.num_epochs * num_archives_expanded
    num_archives_processed = 0
    num_iters = ((num_archives_to_process * 2)
                 / (args.num_jobs_initial + args.num_jobs_final))

    models_to_combine = common_train_lib.verify_iterations(
        num_iters, args.num_epochs,
        num_hidden_layers, num_archives_expanded,
        args.max_models_combine, args.add_layers_period,
        args.num_jobs_final)

    def learning_rate(iter, current_num_jobs, num_archives_processed):
        return common_train_lib.get_learning_rate(iter, current_num_jobs,
                                                  num_iters,
                                                  num_archives_processed,
                                                  num_archives_to_process,
                                                  args.initial_effective_lrate,
                                                  args.final_effective_lrate)

    min_deriv_time = None
    max_deriv_time = None
    if args.deriv_truncate_margin is not None:
        min_deriv_time = -args.deriv_truncate_margin - model_left_context
        max_deriv_time = (args.chunk_width - 1 + args.deriv_truncate_margin
                          + model_right_context)

    logger.info("Training will run for {0} epochs = "
                "{1} iterations".format(args.num_epochs, num_iters))

    for iter in range(num_iters):
        if (args.exit_stage is not None) and (iter == args.exit_stage):
            logger.info("Exiting early due to --exit-stage {0}".format(iter))
            return
        current_num_jobs = int(0.5 + args.num_jobs_initial
                               + (args.num_jobs_final - args.num_jobs_initial)
                               * float(iter) / num_iters)

        if args.stage <= iter:
            model_file = "{dir}/{iter}.mdl".format(dir=args.dir, iter=iter)
            shrinkage_value = 1.0
            if args.shrink_value != 1.0:
                shrinkage_value = (args.shrink_value
                                   if common_train_lib.do_shrinkage(
                                        iter, model_file,
                                        args.shrink_saturation_threshold)
                                   else 1
                                   )
            logger.info("On iteration {0}, learning rate is {1} and "
                        "shrink value is {2}.".format(
                            iter, learning_rate(iter, current_num_jobs,
                                                num_archives_processed),
                            shrinkage_value))

            chain_lib.train_one_iteration(
                dir=args.dir,
                iter=iter,
                srand=args.srand,
                egs_dir=egs_dir,
                num_jobs=current_num_jobs,
                num_archives_processed=num_archives_processed,
                num_archives=num_archives,
                learning_rate=learning_rate(iter, current_num_jobs,
                                            num_archives_processed),
                shrinkage_value=shrinkage_value,
                num_chunk_per_minibatch=args.num_chunk_per_minibatch,
                num_hidden_layers=num_hidden_layers,
                add_layers_period=args.add_layers_period,
                left_context=left_context,
                right_context=right_context,
                apply_deriv_weights=args.apply_deriv_weights,
                min_deriv_time=min_deriv_time,
                max_deriv_time=max_deriv_time,
                l2_regularize=args.l2_regularize,
                xent_regularize=args.xent_regularize,
                leaky_hmm_coefficient=args.leaky_hmm_coefficient,
                momentum=args.momentum,
                max_param_change=args.max_param_change,
                shuffle_buffer_size=args.shuffle_buffer_size,
                frame_subsampling_factor=args.frame_subsampling_factor,
                truncate_deriv_weights=args.truncate_deriv_weights,
                run_opts=run_opts,
                background_process_handler=background_process_handler)

            if args.cleanup:
                # do a clean up everythin but the last 2 models, under certain
                # conditions
                common_train_lib.remove_model(
                    args.dir, iter-2, num_iters, models_to_combine,
                    args.preserve_model_interval)

            if args.email is not None:
                reporting_iter_interval = num_iters * args.reporting_interval
                if iter % reporting_iter_interval == 0:
                    # lets do some reporting
                    [report, times, data] = (
                        nnet3_log_parse.generate_accuracy_report(
                            args.dir, "log-probability"))
                    message = report
                    subject = ("Update : Expt {dir} : "
                               "Iter {iter}".format(dir=args.dir, iter=iter))
                    common_lib.send_mail(message, subject, args.email)

        num_archives_processed = num_archives_processed + current_num_jobs

    if args.stage <= num_iters:
        logger.info("Doing final combination to produce final.mdl")
        chain_lib.combine_models(
            dir=args.dir, num_iters=num_iters,
            models_to_combine=models_to_combine,
            num_chunk_per_minibatch=args.num_chunk_per_minibatch,
            egs_dir=egs_dir,
            left_context=left_context, right_context=right_context,
            leaky_hmm_coefficient=args.leaky_hmm_coefficient,
            l2_regularize=args.l2_regularize,
            xent_regularize=args.xent_regularize,
            run_opts=run_opts,
            background_process_handler=background_process_handler)

    if args.cleanup:
        logger.info("Cleaning up the experiment directory "
                    "{0}".format(args.dir))
        remove_egs = args.remove_egs
        if args.egs_dir is not None:
            # this egs_dir was not created by this experiment so we will not
            # delete it
            remove_egs = False

        common_train_lib.clean_nnet_dir(
            args.dir, num_iters, egs_dir,
            preserve_model_interval=args.preserve_model_interval,
            remove_egs=remove_egs)

    # do some reporting
    [report, times, data] = nnet3_log_parse.generate_accuracy_report(
        args.dir, "log-probability")
    if args.email is not None:
        common_lib.send_mail(report, "Update : Expt {0} : "
                                     "complete".format(args.dir), args.email)

    with open("{dir}/accuracy.report".format(dir=args.dir), "w") as f:
        f.write(report)

    common_lib.run_kaldi_command("steps/info/nnet3_dir_info.pl "
                                 "{0}".format(args.dir))
Esempio n. 5
0
def train(args, run_opts, background_process_handler):
    """ The main function for training a "raw" nnet3 network.

    The work is split into stages gated by ``args.stage``: initialise a
    basic network (stage -4), generate egs (stage -3, skipped when
    ``args.egs_dir`` is given), compute the preconditioning matrix
    (stage -2, only when the config vars set ``add_lda``), prepare the
    initial network (stage -1), run the training iterations
    (stages 0 .. num_iters - 1), do the final model combination
    (stage num_iters) and, when the config vars set
    ``include_log_softmax``, compute the average posterior
    (stage num_iters + 1).  An accuracy report is always written to
    ``{dir}/accuracy.report`` at the end, unless ``args.exit_stage``
    caused an early return.

    Args:
        args: a Namespace object with the required parameters
            obtained from the function process_args()
        run_opts: RunOpts object obtained from the process_args()
        background_process_handler: passed through unchanged to the
            per-iteration training call and to the model combination.

    Raises:
        Exception: if a required variable is missing from
            ``{dir}/configs/vars``, if the dimension of the dense targets
            does not match ``num_targets`` from the configs, or if
            ``args.num_jobs_final`` exceeds the number of archives in the
            egs directory.
    """

    arg_string = pprint.pformat(vars(args))
    logger.info("Arguments for the experiment\n{0}".format(arg_string))

    # Dimensions used further down to validate the egs directory.
    feat_dim = common_lib.get_feat_dim(args.feat_dir)
    ivector_dim = common_lib.get_ivector_dim(args.online_ivector_dir)

    config_dir = '{0}/configs'.format(args.dir)
    var_file = '{0}/vars'.format(config_dir)

    variables = common_train_lib.parse_generic_config_vars_file(var_file)

    # Pull the values needed below out of the parsed config vars.
    try:
        model_left_context = variables['model_left_context']
        model_right_context = variables['model_right_context']
        # this is really the number of times we add layers to the network for
        # discriminative pretraining
        num_hidden_layers = variables['num_hidden_layers']
        add_lda = common_lib.str_to_bool(variables['add_lda'])
        include_log_softmax = common_lib.str_to_bool(
            variables['include_log_softmax'])
    except KeyError as e:
        raise Exception("KeyError {0}: Variables need to be defined in "
                        "{1}".format(str(e), config_dir))

    left_context = args.chunk_left_context + model_left_context
    right_context = args.chunk_right_context + model_right_context

    # Initialize as "raw" nnet, prior to training the LDA-like preconditioning
    # matrix.  This first config just does any initial splicing that we do;
    # we do this as it's a convenient way to get the stats for the 'lda-like'
    # transform.

    if args.stage <= -4:
        logger.info("Initializing a basic network")
        common_lib.run_job("""{command} {dir}/log/nnet_init.log \
                    nnet3-init --srand=-2 {dir}/configs/init.config \
                    {dir}/init.raw""".format(command=run_opts.command,
                                             dir=args.dir))

    default_egs_dir = '{0}/egs'.format(args.dir)
    if (args.stage <= -3) and args.egs_dir is None:
        logger.info("Generating egs")

        if args.use_dense_targets:
            target_type = "dense"
            try:
                num_targets = int(variables['num_targets'])
                if (common_lib.get_feat_dim_from_scp(args.targets_scp) !=
                        num_targets):
                    raise Exception("Mismatch between num-targets provided to "
                                    "script vs configs")
            except KeyError:
                # num_targets is optional for dense targets; -1 is passed
                # on to the egs generation in that case.
                num_targets = -1
        else:
            target_type = "sparse"
            try:
                num_targets = int(variables['num_targets'])
            except KeyError as e:
                raise Exception("KeyError {0}: Variables need to be defined "
                                "in {1}".format(str(e), config_dir))

        train_lib.raw_model.generate_egs_using_targets(
            data=args.feat_dir,
            targets_scp=args.targets_scp,
            egs_dir=default_egs_dir,
            left_context=left_context,
            right_context=right_context,
            valid_left_context=left_context + args.chunk_width,
            valid_right_context=right_context + args.chunk_width,
            run_opts=run_opts,
            frames_per_eg=args.chunk_width,
            srand=args.srand,
            egs_opts=args.egs_opts,
            cmvn_opts=args.cmvn_opts,
            online_ivector_dir=args.online_ivector_dir,
            samples_per_iter=args.samples_per_iter,
            transform_dir=args.transform_dir,
            stage=args.egs_stage,
            target_type=target_type,
            num_targets=num_targets)

    # A user-supplied egs dir takes precedence over the one generated above.
    if args.egs_dir is None:
        egs_dir = default_egs_dir
    else:
        egs_dir = args.egs_dir

    [egs_left_context, egs_right_context, frames_per_eg, num_archives
     ] = (common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim,
                                          left_context, right_context))
    assert (args.chunk_width == frames_per_eg)

    if args.num_jobs_final > num_archives:
        raise Exception('num_jobs_final cannot exceed the number of archives '
                        'in the egs directory')

    # copy the properties of the egs to dir for
    # use during decoding
    common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)

    if add_lda and args.stage <= -2:
        logger.info('Computing the preconditioning matrix for input features')

        train_lib.common.compute_preconditioning_matrix(
            args.dir,
            egs_dir,
            num_archives,
            run_opts,
            max_lda_jobs=args.max_lda_jobs,
            rand_prune=args.rand_prune)

    if args.stage <= -1:
        logger.info("Preparing the initial network.")
        common_train_lib.prepare_initial_network(args.dir, run_opts)

    # set num_iters so that as close as possible, we process the data
    # $num_epochs times, i.e. $num_iters*$avg_num_jobs) ==
    # $num_epochs*$num_archives, where
    # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2.
    num_archives_to_process = args.num_epochs * num_archives
    num_archives_processed = 0
    # Floor division plus int() keeps num_iters an integer under Python 3
    # true division (and with a fractional num_epochs); range() below
    # rejects floats.
    num_iters = int((num_archives_to_process * 2)
                    // (args.num_jobs_initial + args.num_jobs_final))

    models_to_combine = common_train_lib.verify_iterations(
        num_iters, args.num_epochs, num_hidden_layers, num_archives,
        args.max_models_combine, args.add_layers_period, args.num_jobs_final)

    def learning_rate(iteration, current_num_jobs, num_archives_processed):
        # Thin wrapper binding the arguments that stay fixed for the run.
        return common_train_lib.get_learning_rate(iteration, current_num_jobs,
                                                  num_iters,
                                                  num_archives_processed,
                                                  num_archives_to_process,
                                                  args.initial_effective_lrate,
                                                  args.final_effective_lrate)

    # Derivative truncation window; stays None (disabled) unless a margin
    # was requested.
    min_deriv_time = None
    max_deriv_time = None
    if args.deriv_truncate_margin is not None:
        min_deriv_time = -args.deriv_truncate_margin - model_left_context
        max_deriv_time = (args.chunk_width - 1 + args.deriv_truncate_margin +
                          model_right_context)

    logger.info("Training will run for {0} epochs = "
                "{1} iterations".format(args.num_epochs, num_iters))

    for iteration in range(num_iters):
        if (args.exit_stage is not None) and (iteration == args.exit_stage):
            logger.info(
                "Exiting early due to --exit-stage {0}".format(iteration))
            return
        # Linearly interpolate the job count from num_jobs_initial to
        # num_jobs_final, rounding to the nearest integer.
        current_num_jobs = int(0.5 + args.num_jobs_initial +
                               (args.num_jobs_final - args.num_jobs_initial) *
                               float(iteration) / num_iters)

        # Iterations below args.stage are skipped (resuming a run), but the
        # archive counter after this block is still advanced for them.
        if args.stage <= iteration:
            model_file = "{dir}/{iter}.raw".format(dir=args.dir,
                                                   iter=iteration)

            shrinkage_value = 1.0
            if args.shrink_value != 1.0:
                shrinkage_value = (args.shrink_value
                                   if common_train_lib.do_shrinkage(
                                       iteration,
                                       model_file,
                                       args.shrink_saturation_threshold,
                                       get_raw_nnet_from_am=False) else 1)
            logger.info("On iteration {0}, learning rate is {1} and "
                        "shrink value is {2}.".format(
                            iteration,
                            learning_rate(iteration, current_num_jobs,
                                          num_archives_processed),
                            shrinkage_value))

            train_lib.common.train_one_iteration(
                dir=args.dir,
                iter=iteration,
                srand=args.srand,
                egs_dir=egs_dir,
                num_jobs=current_num_jobs,
                num_archives_processed=num_archives_processed,
                num_archives=num_archives,
                learning_rate=learning_rate(iteration, current_num_jobs,
                                            num_archives_processed),
                shrinkage_value=shrinkage_value,
                minibatch_size=args.num_chunk_per_minibatch,
                num_hidden_layers=num_hidden_layers,
                add_layers_period=args.add_layers_period,
                left_context=left_context,
                right_context=right_context,
                min_deriv_time=min_deriv_time,
                max_deriv_time=max_deriv_time,
                momentum=args.momentum,
                max_param_change=args.max_param_change,
                shuffle_buffer_size=args.shuffle_buffer_size,
                cv_minibatch_size=args.cv_minibatch_size,
                run_opts=run_opts,
                get_raw_nnet_from_am=False,
                background_process_handler=background_process_handler)

            if args.cleanup:
                # clean up everything but the last 2 models, under certain
                # conditions
                common_train_lib.remove_model(args.dir,
                                              iteration - 2,
                                              num_iters,
                                              models_to_combine,
                                              args.preserve_model_interval,
                                              get_raw_nnet_from_am=False)

            if args.email is not None:
                reporting_iter_interval = num_iters * args.reporting_interval
                if iteration % reporting_iter_interval == 0:
                    # let's do some reporting
                    [report, times, data
                     ] = (nnet3_log_parse.generate_accuracy_report(args.dir))
                    message = report
                    subject = ("Update : Expt {dir} : "
                               "Iter {iter}".format(dir=args.dir,
                                                    iter=iteration))
                    common_lib.send_mail(message, subject, args.email)

        num_archives_processed = num_archives_processed + current_num_jobs

    if args.stage <= num_iters:
        logger.info("Doing final combination to produce final.raw")
        train_lib.common.combine_models(
            dir=args.dir,
            num_iters=num_iters,
            models_to_combine=models_to_combine,
            egs_dir=egs_dir,
            left_context=left_context,
            right_context=right_context,
            run_opts=run_opts,
            chunk_width=args.chunk_width,
            background_process_handler=background_process_handler,
            get_raw_nnet_from_am=False)

    if include_log_softmax and args.stage <= num_iters + 1:
        logger.info("Getting average posterior for purposes of "
                    "adjusting the priors.")
        train_lib.common.compute_average_posterior(
            dir=args.dir,
            iter='final',
            egs_dir=egs_dir,
            num_archives=num_archives,
            left_context=left_context,
            right_context=right_context,
            prior_subset_size=args.prior_subset_size,
            run_opts=run_opts,
            get_raw_nnet_from_am=False)

    if args.cleanup:
        logger.info("Cleaning up the experiment directory "
                    "{0}".format(args.dir))
        remove_egs = args.remove_egs
        if args.egs_dir is not None:
            # this egs_dir was not created by this experiment so we will not
            # delete it
            remove_egs = False

        common_train_lib.clean_nnet_dir(
            nnet_dir=args.dir,
            num_iters=num_iters,
            egs_dir=egs_dir,
            preserve_model_interval=args.preserve_model_interval,
            remove_egs=remove_egs,
            get_raw_nnet_from_am=False)

    # do some reporting
    [report, times, data] = nnet3_log_parse.generate_accuracy_report(args.dir)
    if args.email is not None:
        common_lib.send_mail(
            report, "Update : Expt {0} : "
            "complete".format(args.dir), args.email)

    with open("{dir}/accuracy.report".format(dir=args.dir), "w") as f:
        f.write(report)

    common_lib.run_job("steps/info/nnet3_dir_info.pl " "{0}".format(args.dir))
Esempio n. 6
0
def train(args, run_opts, background_process_handler):
    """ The main function for training a chain acoustic model.

    The work is split into stages gated by ``args.stage``: create the
    phone language model (stage -6), create the denominator FST (stage -5),
    initialise a basic network (stage -4), generate egs (stage -3, skipped
    when ``args.egs_dir`` is given), compute the preconditioning matrix
    (stage -2), prepare the initial acoustic model (stage -1), run the
    training iterations (stages 0 .. num_iters - 1) and do the final model
    combination (stage num_iters).  An accuracy report is always written to
    ``{dir}/accuracy.report`` at the end, unless ``args.exit_stage`` caused
    an early return.

    Args:
        args: a Namespace object with the required parameters
            obtained from the function process_args()
        run_opts: RunOpts object obtained from the process_args()
        background_process_handler: passed through unchanged to the
            per-iteration training call and to the model combination.

    Raises:
        Exception: if a required variable is missing from
            ``{dir}/configs/vars`` or if ``args.num_jobs_final`` exceeds the
            expanded number of archives.
    """

    arg_string = pprint.pformat(vars(args))
    logger.info("Arguments for the experiment\n{0}".format(arg_string))

    # Check files
    chain_lib.check_for_required_files(args.feat_dir, args.tree_dir,
                                       args.lat_dir)

    # Set some variables.
    num_jobs = common_lib.get_number_of_jobs(args.tree_dir)
    feat_dim = common_lib.get_feat_dim(args.feat_dir)
    ivector_dim = common_lib.get_ivector_dim(args.online_ivector_dir)

    # split the training data into parts for individual jobs
    # we will use the same number of jobs as that used for alignment
    common_lib.split_data(args.feat_dir, num_jobs)
    shutil.copy('{0}/tree'.format(args.tree_dir), args.dir)
    with open('{0}/num_jobs'.format(args.dir), 'w') as f:
        f.write(str(num_jobs))

    config_dir = '{0}/configs'.format(args.dir)
    var_file = '{0}/vars'.format(config_dir)

    variables = common_train_lib.parse_generic_config_vars_file(var_file)

    # Pull the values needed below out of the parsed config vars.
    try:
        model_left_context = variables['model_left_context']
        model_right_context = variables['model_right_context']
        # this is really the number of times we add layers to the network for
        # discriminative pretraining
        num_hidden_layers = variables['num_hidden_layers']
    except KeyError as e:
        raise Exception("KeyError {0}: Variables need to be defined in "
                        "{1}".format(str(e), config_dir))

    left_context = args.chunk_left_context + model_left_context
    right_context = args.chunk_right_context + model_right_context

    if args.stage <= -6:
        logger.info("Creating phone language-model")
        chain_lib.create_phone_lm(args.dir, args.tree_dir, run_opts,
                                  lm_opts=args.lm_opts)

    if args.stage <= -5:
        logger.info("Creating denominator FST")
        chain_lib.create_denominator_fst(args.dir, args.tree_dir, run_opts)

    # Initialize as "raw" nnet, prior to training the LDA-like preconditioning
    # matrix.  This first config just does any initial splicing that we do;
    # we do this as it's a convenient way to get the stats for the 'lda-like'
    # transform.
    if args.stage <= -4:
        logger.info("Initializing a basic network for estimating "
                    "preconditioning matrix")
        common_lib.run_kaldi_command(
            """{command} {dir}/log/nnet_init.log \
                    nnet3-init --srand=-2 {dir}/configs/init.config \
                    {dir}/init.raw""".format(command=run_opts.command,
                                             dir=args.dir))

    # Widen the egs context by half the subsampling factor.  Floor division
    # keeps the contexts integral under Python 3 true division.
    egs_left_context = left_context + args.frame_subsampling_factor // 2
    egs_right_context = right_context + args.frame_subsampling_factor // 2

    default_egs_dir = '{0}/egs'.format(args.dir)
    if (args.stage <= -3) and args.egs_dir is None:
        logger.info("Generating egs")
        # this is where get_egs.sh is called.
        chain_lib.generate_chain_egs(
            dir=args.dir, data=args.feat_dir,
            lat_dir=args.lat_dir, egs_dir=default_egs_dir,
            left_context=egs_left_context,
            right_context=egs_right_context,
            run_opts=run_opts,
            left_tolerance=args.left_tolerance,
            right_tolerance=args.right_tolerance,
            frame_subsampling_factor=args.frame_subsampling_factor,
            alignment_subsampling_factor=args.alignment_subsampling_factor,
            frames_per_eg=args.chunk_width,
            srand=args.srand,
            egs_opts=args.egs_opts,
            cmvn_opts=args.cmvn_opts,
            online_ivector_dir=args.online_ivector_dir,
            frames_per_iter=args.frames_per_iter,
            transform_dir=args.transform_dir,
            stage=args.egs_stage)

    # A user-supplied egs dir takes precedence over the one generated above.
    if args.egs_dir is None:
        egs_dir = default_egs_dir
    else:
        egs_dir = args.egs_dir

    [egs_left_context, egs_right_context,
     frames_per_eg, num_archives] = (
        common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim,
                                        egs_left_context, egs_right_context))
    assert(args.chunk_width == frames_per_eg)
    # The iteration schedule below is based on this expanded archive count.
    num_archives_expanded = num_archives * args.frame_subsampling_factor

    if args.num_jobs_final > num_archives_expanded:
        raise Exception('num_jobs_final cannot exceed the '
                        'expanded number of archives')

    # copy the properties of the egs to dir for
    # use during decoding
    common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)

    if args.stage <= -2:
        logger.info('Computing the preconditioning matrix for input features')

        chain_lib.compute_preconditioning_matrix(
            args.dir, egs_dir, num_archives, run_opts,
            max_lda_jobs=args.max_lda_jobs,
            rand_prune=args.rand_prune)

    if args.stage <= -1:
        logger.info("Preparing the initial acoustic model.")
        chain_lib.prepare_initial_acoustic_model(args.dir, run_opts)

    with open("{0}/frame_subsampling_factor".format(args.dir), "w") as f:
        f.write(str(args.frame_subsampling_factor))

    # set num_iters so that as close as possible, we process the data
    # $num_epochs times, i.e. $num_iters*$avg_num_jobs) ==
    # $num_epochs*$num_archives, where
    # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2.
    num_archives_to_process = args.num_epochs * num_archives_expanded
    num_archives_processed = 0
    # Floor division plus int() keeps num_iters an integer under Python 3
    # true division (and with a fractional num_epochs); range() below
    # rejects floats.
    num_iters = int((num_archives_to_process * 2)
                    // (args.num_jobs_initial + args.num_jobs_final))

    models_to_combine = common_train_lib.verify_iterations(
        num_iters, args.num_epochs,
        num_hidden_layers, num_archives_expanded,
        args.max_models_combine, args.add_layers_period,
        args.num_jobs_final)

    def learning_rate(iteration, current_num_jobs, num_archives_processed):
        # Thin wrapper binding the arguments that stay fixed for the run.
        return common_train_lib.get_learning_rate(iteration, current_num_jobs,
                                                  num_iters,
                                                  num_archives_processed,
                                                  num_archives_to_process,
                                                  args.initial_effective_lrate,
                                                  args.final_effective_lrate)

    # Derivative truncation window; stays None (disabled) unless a margin
    # was requested.
    min_deriv_time = None
    max_deriv_time = None
    if args.deriv_truncate_margin is not None:
        min_deriv_time = -args.deriv_truncate_margin - model_left_context
        max_deriv_time = (args.chunk_width - 1 + args.deriv_truncate_margin
                          + model_right_context)

    logger.info("Training will run for {0} epochs = "
                "{1} iterations".format(args.num_epochs, num_iters))

    for iteration in range(num_iters):
        if (args.exit_stage is not None) and (iteration == args.exit_stage):
            logger.info(
                "Exiting early due to --exit-stage {0}".format(iteration))
            return
        # Linearly interpolate the job count from num_jobs_initial to
        # num_jobs_final, rounding to the nearest integer.
        current_num_jobs = int(0.5 + args.num_jobs_initial
                               + (args.num_jobs_final - args.num_jobs_initial)
                               * float(iteration) / num_iters)

        # Iterations below args.stage are skipped (resuming a run), but the
        # archive counter after this block is still advanced for them.
        if args.stage <= iteration:
            model_file = "{dir}/{iter}.mdl".format(dir=args.dir,
                                                   iter=iteration)
            shrinkage_value = 1.0
            if args.shrink_value != 1.0:
                shrinkage_value = (args.shrink_value
                                   if common_train_lib.do_shrinkage(
                                        iteration, model_file,
                                        args.shrink_saturation_threshold)
                                   else 1
                                   )

            chain_lib.train_one_iteration(
                dir=args.dir,
                iter=iteration,
                srand=args.srand,
                egs_dir=egs_dir,
                num_jobs=current_num_jobs,
                num_archives_processed=num_archives_processed,
                num_archives=num_archives,
                learning_rate=learning_rate(iteration, current_num_jobs,
                                            num_archives_processed),
                # The dropout schedule is evaluated at the fraction of the
                # data processed so far.
                dropout_edit_string=common_lib.get_dropout_edit_string(
                    args.dropout_schedule,
                    float(num_archives_processed) / num_archives_to_process,
                    iteration),
                shrinkage_value=shrinkage_value,
                num_chunk_per_minibatch=args.num_chunk_per_minibatch,
                num_hidden_layers=num_hidden_layers,
                add_layers_period=args.add_layers_period,
                left_context=left_context,
                right_context=right_context,
                apply_deriv_weights=args.apply_deriv_weights,
                min_deriv_time=min_deriv_time,
                max_deriv_time=max_deriv_time,
                l2_regularize=args.l2_regularize,
                xent_regularize=args.xent_regularize,
                leaky_hmm_coefficient=args.leaky_hmm_coefficient,
                momentum=args.momentum,
                max_param_change=args.max_param_change,
                shuffle_buffer_size=args.shuffle_buffer_size,
                frame_subsampling_factor=args.frame_subsampling_factor,
                truncate_deriv_weights=args.truncate_deriv_weights,
                run_opts=run_opts,
                background_process_handler=background_process_handler)

            if args.cleanup:
                # clean up everything but the last 2 models, under certain
                # conditions
                common_train_lib.remove_model(
                    args.dir, iteration - 2, num_iters, models_to_combine,
                    args.preserve_model_interval)

            if args.email is not None:
                reporting_iter_interval = num_iters * args.reporting_interval
                if iteration % reporting_iter_interval == 0:
                    # let's do some reporting
                    [report, times, data] = (
                        nnet3_log_parse.generate_accuracy_report(
                            args.dir, "log-probability"))
                    message = report
                    subject = ("Update : Expt {dir} : "
                               "Iter {iter}".format(dir=args.dir,
                                                    iter=iteration))
                    common_lib.send_mail(message, subject, args.email)

        num_archives_processed = num_archives_processed + current_num_jobs

    if args.stage <= num_iters:
        logger.info("Doing final combination to produce final.mdl")
        chain_lib.combine_models(
            dir=args.dir, num_iters=num_iters,
            models_to_combine=models_to_combine,
            num_chunk_per_minibatch=args.num_chunk_per_minibatch,
            egs_dir=egs_dir,
            left_context=left_context, right_context=right_context,
            leaky_hmm_coefficient=args.leaky_hmm_coefficient,
            l2_regularize=args.l2_regularize,
            xent_regularize=args.xent_regularize,
            run_opts=run_opts,
            background_process_handler=background_process_handler)

    if args.cleanup:
        logger.info("Cleaning up the experiment directory "
                    "{0}".format(args.dir))
        remove_egs = args.remove_egs
        if args.egs_dir is not None:
            # this egs_dir was not created by this experiment so we will not
            # delete it
            remove_egs = False

        common_train_lib.clean_nnet_dir(
            args.dir, num_iters, egs_dir,
            preserve_model_interval=args.preserve_model_interval,
            remove_egs=remove_egs)

    # do some reporting
    [report, times, data] = nnet3_log_parse.generate_accuracy_report(
        args.dir, "log-probability")
    if args.email is not None:
        common_lib.send_mail(report, "Update : Expt {0} : "
                                     "complete".format(args.dir), args.email)

    with open("{dir}/accuracy.report".format(dir=args.dir), "w") as f:
        f.write(report)

    common_lib.run_kaldi_command("steps/info/nnet3_dir_info.pl "
                                 "{0}".format(args.dir))