示例#1
0
def main(options):
    print 'Load data stream'

    train_datastream = get_datastream(path=options['data_path'],
                                      which_set='train_si84',
                                      batch_size=options['batch_size'])
    valid_datastream = get_datastream(path=options['data_path'],
                                      which_set='test_dev93',
                                      batch_size=options['batch_size'])
    print 'Build and compile network'
    input_data = T.ftensor3('input_data')
    input_mask = T.fmatrix('input_mask')
    target_data = T.imatrix('target_data')
    target_mask = T.fmatrix('target_mask')

    network = build_network(
        input_data=input_data,
        input_mask=input_mask,
        num_inputs=options['num_inputs'],
        num_inner_units_list=options['num_inner_units_list'],
        num_factor_units_list=options['num_factor_units_list'],
        num_outer_units_list=options['num_outer_units_list'],
        num_outputs=options['num_outputs'],
        dropout_ratio=options['dropout_ratio'],
        use_layer_norm=options['use_layer_norm'],
        learn_init=options['learn_init'],
        grad_clipping=options['grad_clipping'])
    network_params = get_all_params(network, trainable=True)

    if options['reload_model']:
        print('Loading Parameters...')
        pretrain_network_params_val, pretrain_update_params_val, pretrain_total_batch_cnt = pickle.load(
            open(options['reload_model'], 'rb'))

        print('Applying Parameters...')
        set_model_param_value(network_params, pretrain_network_params_val)
    else:
        pretrain_update_params_val = None
        pretrain_total_batch_cnt = 0

    print 'Build network trainer'
    training_fn, trainer_params = set_network_trainer(
        input_data=input_data,
        input_mask=input_mask,
        target_data=target_data,
        target_mask=target_mask,
        network=network,
        updater=options['updater'],
        learning_rate=options['lr'],
        grad_max_norm=options['grad_norm'],
        l2_lambda=options['l2_lambda'],
        load_updater_params=pretrain_update_params_val)

    print 'Build network predictor'
    predict_fn = set_network_predictor(input_data=input_data,
                                       input_mask=input_mask,
                                       target_data=target_data,
                                       target_mask=target_mask,
                                       network=network)

    evaluation_history = [[[10.0, 10.0, 1.0], [10.0, 10.0, 1.0]]]
    check_early_stop = 0
    total_batch_cnt = 0

    print 'Start training'
    try:
        # for each epoch
        for e_idx in range(options['num_epochs']):
            # for each batch
            for b_idx, data in enumerate(
                    train_datastream.get_epoch_iterator()):
                total_batch_cnt += 1
                if pretrain_total_batch_cnt >= total_batch_cnt:
                    continue

                # get input, target data
                train_input = data

                # get output
                train_output = training_fn(*train_input)
                train_predict_cost = train_output[2]
                network_grads_norm = train_output[3]

                # show intermediate result
                if total_batch_cnt % options[
                        'train_disp_freq'] == 0 and total_batch_cnt != 0:
                    print '============================================================================================'
                    print 'Model Name: ', options['save_path'].split('/')[-1]
                    print '============================================================================================'
                    print 'Epoch: ', str(e_idx), ', Update: ', str(
                        total_batch_cnt)
                    print '--------------------------------------------------------------------------------------------'
                    print 'Prediction Cost: ', str(train_predict_cost)
                    print 'Gradient Norm: ', str(network_grads_norm)
                    print '--------------------------------------------------------------------------------------------'
                    print 'Train NLL: ', str(
                        evaluation_history[-1][0][0]), ', BPC: ', str(
                            evaluation_history[-1][0][1]), ', FER: ', str(
                                evaluation_history[-1][0][2])
                    print 'Valid NLL: ', str(
                        evaluation_history[-1][1][0]), ', BPC: ', str(
                            evaluation_history[-1][1][1]), ', FER: ', str(
                                evaluation_history[-1][1][2])

            # evaluation
            train_nll, train_bpc, train_per = network_evaluation(
                predict_fn, train_datastream)
            valid_nll, valid_bpc, valid_per = network_evaluation(
                predict_fn, valid_datastream)

            # check over-fitting
            if valid_per > evaluation_history[-1][1][2]:
                check_early_stop += 1.
            else:
                check_early_stop = 0.
                best_network_params_vals = get_model_param_values(
                    network_params)
                pickle.dump(
                    best_network_params_vals,
                    open(options['save_path'] + '_best_model.pkl', 'wb'))

            if check_early_stop > 10:
                print('Training Early Stopped')
                break

            # save results
            evaluation_history.append([[train_nll, train_bpc, train_per],
                                       [valid_nll, valid_bpc, valid_per]])
            numpy.savez(options['save_path'] + '_eval_history',
                        eval_history=evaluation_history)

            cur_network_params_val = get_model_param_values(network_params)
            cur_trainer_params_val = get_update_params_values(trainer_params)
            cur_total_batch_cnt = total_batch_cnt
            pickle.dump([
                cur_network_params_val, cur_trainer_params_val,
                cur_total_batch_cnt
            ], open(options['save_path'] + '_last_model.pkl', 'wb'))

    except KeyboardInterrupt:
        print('Training Interrupted')
        cur_network_params_val = get_model_param_values(network_params)
        cur_trainer_params_val = get_update_params_values(trainer_params)
        cur_total_batch_cnt = total_batch_cnt
        pickle.dump([
            cur_network_params_val, cur_trainer_params_val, cur_total_batch_cnt
        ], open(options['save_path'] + '_last_model.pkl', 'wb'))
示例#2
0
def main(options):
    #################
    # build network #
    #################
    print 'Build and compile network'
    # input data
    input_data = T.ftensor3('input_data')
    input_mask = T.fmatrix('input_mask')

    # target data
    target_data = T.imatrix('target_data')
    target_mask = T.fmatrix('target_mask')

    # network
    network = build_network(input_data=input_data,
                            input_mask=input_mask,
                            num_inputs=options['num_inputs'],
                            num_units_list=options['num_units_list'],
                            num_outputs=options['num_outputs'],
                            dropout_ratio=options['dropout_ratio'],
                            use_layer_norm=options['use_layer_norm'],
                            weight_noise=options['weight_noise'],
                            learn_init=options['learn_init'],
                            grad_clipping=options['grad_clipping'])
    network_params = get_all_params(network, trainable=True)

    ###################
    # Load Parameters #
    ###################
    if options['reload_model']:
        print('Loading Parameters...')
        pretrain_network_params_val,  pretrain_update_params_val, pretrain_total_batch_cnt = pickle.load(open(options['reload_model'], 'rb'))

        print('Applying Parameters...')
        set_model_param_value(network_params, pretrain_network_params_val)
    else:
        pretrain_update_params_val = None
        pretrain_total_batch_cnt = 0

    #########################
    # build network trainer #
    #########################
    print 'Build network trainer'
    training_fn, trainer_params = set_network_trainer(input_data=input_data,
                                                      input_mask=input_mask,
                                                      target_data=target_data,
                                                      target_mask=target_mask,
                                                      network=network,
                                                      updater=options['updater'],
                                                      learning_rate=options['lr'],
                                                      grad_max_norm=options['grad_norm'],
                                                      l2_lambda=options['l2_lambda'],
                                                      load_updater_params=pretrain_update_params_val)

    ###########################
    # build network predictor #
    ###########################
    print 'Build network predictor'
    predict_fn = set_network_predictor(input_data=input_data,
                                       input_mask=input_mask,
                                       target_data=target_data,
                                       target_mask=target_mask,
                                       network=network)

    ################
    # load dataset #
    ################
    print 'Load data stream'
    train_dataset, train_datastream = timit_datastream(path=options['data_path'],
                                                       which_set='train',
                                                       pool_size=options['pool_size'],
                                                       maximum_frames=options['max_total_frames'],
                                                       local_copy=False)
    valid_dataset, valid_datastream = timit_datastream(path=options['data_path'],
                                                       which_set='dev',
                                                       pool_size=options['pool_size'],
                                                       maximum_frames=options['max_total_frames'],
                                                       local_copy=False)

    phone_dict = train_dataset.get_phoneme_dict()
    phoneme_dict = {k: phone_to_phoneme_dict[v] if v in phone_to_phoneme_dict else v for k, v in phone_dict.iteritems()}
    black_list = ['<START>', '<STOP>', 'q', '<END>']


    ##################
    # start training #
    ##################
    evaluation_history =[[[1000.0, 1.0], [1000.0, 1.0]]]
    check_early_stop = 0
    total_batch_cnt = 0

    print 'Start training'
    try:
        # for each epoch
        for e_idx in range(options['num_epochs']):
            # for each batch
            for b_idx, data in enumerate(train_datastream.get_epoch_iterator()):
                total_batch_cnt += 1
                if pretrain_total_batch_cnt>=total_batch_cnt:
                    continue
                # get input, target data
                train_input = data

                # get output
                train_output = training_fn(*train_input)
                train_ctc_cost = train_output[0]
                train_cost_per_char = train_output[1]
                train_regularizer_cost = train_output[2]
                network_grads_norm = train_output[3]

                # show intermediate result
                if total_batch_cnt%options['train_disp_freq'] == 0 and total_batch_cnt!=0:
                    print '============================================================================================'
                    print 'Model Name: ', options['save_path'].split('/')[-1]
                    print '============================================================================================'
                    print 'Epoch: ', str(e_idx), ', Update: ', str(total_batch_cnt)
                    print 'CTC Cost: ', str(train_ctc_cost)
                    print 'Per Char Cost: ', str(train_cost_per_char)
                    print 'Regularizer Cost: ', str(train_regularizer_cost)
                    print 'Gradient Norm: ', str(network_grads_norm)
                    print '============================================================================================'
                    print 'Train CTC Cost: ', str(evaluation_history[-1][0][0]), ', PER: ', str(evaluation_history[-1][0][-1])
                    print 'Valid CTC Cost: ', str(evaluation_history[-1][1][0]), ', PER: ', str(evaluation_history[-1][1][-1])

            # evaluation
            train_ctc_cost, train_per = network_evaluation(predict_fn=predict_fn,
                                                           data_stream=train_datastream,
                                                           phoneme_dict=phoneme_dict,
                                                           black_list=black_list)
            valid_ctc_cost, valid_per = network_evaluation(predict_fn=predict_fn,
                                                           data_stream=valid_datastream,
                                                           phoneme_dict=phoneme_dict,
                                                           black_list=black_list)

            # check over-fitting
            if valid_per>evaluation_history[-1][1][-1]:
                check_early_stop += 1.
            else:
                check_early_stop = 0.
                best_network_params_vals = get_model_param_values(network_params)
                pickle.dump(best_network_params_vals,
                            open(options['save_path'] + '_best_model.pkl', 'wb'))

            if check_early_stop>10:
                print('Training Early Stopped')
                break

            # save results
            evaluation_history.append([[train_ctc_cost, train_per],
                                       [valid_ctc_cost, valid_per]])
            numpy.savez(options['save_path'] + '_eval_history',
                        eval_history=evaluation_history)

            cur_network_params_val = get_model_param_values(network_params)
            cur_trainer_params_val = get_update_params_values(trainer_params)
            cur_total_batch_cnt = total_batch_cnt
            pickle.dump([cur_network_params_val, cur_trainer_params_val, cur_total_batch_cnt],
                        open(options['save_path'] + '_last_model.pkl', 'wb'))

    except KeyboardInterrupt:
        print('Training Interrupted')
        cur_network_params_val = get_model_param_values(network_params)
        cur_trainer_params_val = get_update_params_values(trainer_params)
        cur_total_batch_cnt = total_batch_cnt
        pickle.dump([cur_network_params_val, cur_trainer_params_val, cur_total_batch_cnt],
                    open(options['save_path'] + '_last_model.pkl', 'wb'))
示例#3
0
def main(options):
    #################
    # build network #
    #################
    print 'Build and compile network'
    # input data
    input_data = T.ftensor3('input_data')
    input_mask = T.fmatrix('input_mask')

    # target data
    target_data = T.imatrix('target_data')
    target_mask = T.fmatrix('target_mask')

    # network
    network = build_network(
        input_data=input_data,
        input_mask=input_mask,
        num_inputs=options['num_inputs'],
        num_inner_units_list=options['num_inner_units_list'],
        num_factor_units_list=options['num_factor_units_list'],
        num_outer_units_list=options['num_outer_units_list'],
        num_outputs=options['num_outputs'],
        dropout_ratio=options['dropout_ratio'],
        use_layer_norm=options['use_layer_norm'],
        weight_noise=options['weight_noise'],
        learn_init=True,
        grad_clipping=0.0)
    network_params = get_all_params(network, trainable=True)

    ###################
    # load parameters #
    ###################
    if options['load_params']:
        print 'Load parameters into network'

    #########################
    # build network trainer #
    #########################
    print 'Build network trainer'
    training_fn, trainer_params = set_network_trainer(
        input_data=input_data,
        input_mask=input_mask,
        target_data=target_data,
        target_mask=target_mask,
        network=network,
        updater=options['updater'],
        learning_rate=options['lr'],
        grad_max_norm=options['grad_norm'],
        l2_lambda=options['l2_lambda'],
        load_updater_params=options['updater_params'])

    ###########################
    # build network predictor #
    ###########################
    print 'Build network predictor'
    predict_fn = set_network_predictor(input_data=input_data,
                                       input_mask=input_mask,
                                       target_data=target_data,
                                       target_mask=target_mask,
                                       network=network)

    ################
    # load dataset #
    ################
    print 'Load data stream'
    train_datastream = framewise_timit_datastream(
        path=options['data_path'],
        which_set='train',
        batch_size=options['batch_size'],
        local_copy=False)
    valid_datastream = framewise_timit_datastream(
        path=options['data_path'],
        which_set='test',
        batch_size=options['batch_size'],
        local_copy=False)

    ##################
    # start training #
    ##################
    evaluation_history = [[[1000.0, 1000.0, 1.0], [1000.0, 1000.0, 1.0]]]
    check_early_stop = 0
    total_batch_cnt = 0

    print 'Start training'
    try:
        # for each epoch
        for e_idx in range(options['num_epochs']):
            # for each batch
            for b_idx, data in enumerate(
                    train_datastream.get_epoch_iterator()):
                # get input, target data
                train_input = data

                # get output
                train_output = training_fn(*train_input)
                train_predict_cost = train_output[2]
                train_regularizer_cost = train_output[3]
                network_grads_norm = train_output[4]

                # count batch
                total_batch_cnt += 1

                # show intermediate result
                if total_batch_cnt % options[
                        'train_disp_freq'] == 0 and total_batch_cnt != 0:
                    print '============================================================================================'
                    print 'Model Name: ', options['save_path'].split('/')[-1]
                    print '============================================================================================'
                    print 'Epoch: ', str(e_idx), ', Update: ', str(
                        total_batch_cnt)
                    print 'Prediction Cost: ', str(train_predict_cost)
                    print 'Regularizer Cost: ', str(train_regularizer_cost)
                    print 'Gradient Norm: ', str(network_grads_norm)
                    print '============================================================================================'
                    print 'Train NLL: ', str(
                        evaluation_history[-1][0][0]), ', BPC: ', str(
                            evaluation_history[-1][0][1]), ', PER: ', str(
                                evaluation_history[-1][0][2])
                    print 'Valid NLL: ', str(
                        evaluation_history[-1][1][0]), ', BPC: ', str(
                            evaluation_history[-1][1][1]), ', PER: ', str(
                                evaluation_history[-1][1][2])

            # evaluation
            train_nll, train_bpc, train_per = network_evaluation(
                predict_fn, train_datastream)
            valid_nll, valid_bpc, valid_per = network_evaluation(
                predict_fn, valid_datastream)

            # check over-fitting
            if valid_per > evaluation_history[-1][1][2]:
                check_early_stop += 1.
            else:
                check_early_stop = 0.
                best_network_params_vals = get_model_param_values(
                    network_params)
                pickle.dump(
                    best_network_params_vals,
                    open(options['save_path'] + '_best_model.pkl', 'wb'))

            if check_early_stop > 3:
                print('Training Early Stopped')
                break

            # save results
            evaluation_history.append([[train_nll, train_bpc, train_per],
                                       [valid_nll, valid_bpc, valid_per]])
            numpy.savez(options['save_path'] + '_eval_history',
                        eval_history=evaluation_history)

            cur_network_params_val = get_model_param_values(network_params)
            cur_trainer_params_val = get_update_params_values(trainer_params)
            cur_total_batch_cnt = total_batch_cnt
            pickle.dump([
                cur_network_params_val, cur_trainer_params_val,
                cur_total_batch_cnt
            ], open(options['save_path'] + '_last_model.pkl', 'wb'))

    except KeyboardInterrupt:
        print('Training Interrupted')
        cur_network_params_val = get_model_param_values(network_params)
        cur_trainer_params_val = get_update_params_values(trainer_params)
        cur_total_batch_cnt = total_batch_cnt
        pickle.dump([
            cur_network_params_val, cur_trainer_params_val, cur_total_batch_cnt
        ], open(options['save_path'] + '_last_model.pkl', 'wb'))
示例#4
0
def main(options):
    print 'Build and compile network'
    input_data = T.ftensor3('input_data')
    input_mask = T.fmatrix('input_mask')
    target_data = T.imatrix('target_data')
    target_mask = T.fmatrix('target_mask')

    network_outputs = build_network(
        input_data=input_data,
        input_mask=input_mask,
        num_inputs=options['num_inputs'],
        num_outputs=options['num_outputs'],
        num_inner_units_list=options['num_inner_units_list'],
        num_outer_units_list=options['num_outer_units_list'],
        use_peepholes=options['use_peepholes'],
        use_layer_norm=options['use_layer_norm'],
        learn_init=options['learn_init'],
        grad_clipping=options['grad_clip'])

    network = network_outputs[-1]
    inner_loop_layers = network_outputs[:-1]

    network_params = get_all_params(network, trainable=True)

    print("number of parameters in model: %d" %
          count_params(network, trainable=True))

    if options['reload_model']:
        print('Loading Parameters...')
        [
            pretrain_network_params_val, pretrain_update_params_val,
            pretrain_total_batch_cnt
        ] = pickle.load(open(options['reload_model'], 'rb'))

        print('Applying Parameters...')
        set_model_param_value(network_params, pretrain_network_params_val)
    else:
        pretrain_update_params_val = None
        pretrain_total_batch_cnt = 0

    print 'Build network trainer'
    train_lr = theano.shared(convert_to_floatX(options['lr']))
    training_fn, trainer_params = set_network_trainer(
        input_data=input_data,
        input_mask=input_mask,
        target_data=target_data,
        target_mask=target_mask,
        num_outputs=options['num_outputs'],
        network=network,
        inner_loop_layers=inner_loop_layers,
        updater=options['updater'],
        learning_rate=train_lr,
        grad_max_norm=options['grad_norm'],
        l2_lambda=options['l2_lambda'],
        load_updater_params=pretrain_update_params_val)

    print 'Build network predictor'
    predict_fn = set_network_predictor(input_data=input_data,
                                       input_mask=input_mask,
                                       target_data=target_data,
                                       target_mask=target_mask,
                                       num_outputs=options['num_outputs'],
                                       network=network)

    # evaluation
    if options['reload_model']:
        train_eval_datastream = get_datastream(
            path=options['data_path'],
            norm_path=options['norm_data_path'],
            which_set='train_si84',
            batch_size=options['eval_batch_size'])
        valid_eval_datastream = get_datastream(
            path=options['data_path'],
            norm_path=options['norm_data_path'],
            which_set='test_dev93',
            batch_size=options['eval_batch_size'])
        train_nll, train_bpc, train_fer = network_evaluation(
            predict_fn, train_eval_datastream)
        valid_nll, valid_bpc, valid_fer = network_evaluation(
            predict_fn, valid_eval_datastream)
        print '======================================================='
        print 'Train NLL: ', str(train_nll), ', FER: ', str(train_fer)
        print 'Valid NLL: ', str(valid_nll), ', FER: ', str(valid_fer)
        print '======================================================='

    print 'Load data stream'
    train_datastream = get_datastream(path=options['data_path'],
                                      norm_path=options['norm_data_path'],
                                      which_set='train_si84',
                                      batch_size=options['batch_size'])

    print 'Start training'
    if os.path.exists(options['save_path'] + '_eval_history.npz'):
        evaluation_history = numpy.load(
            options['save_path'] +
            '_eval_history.npz')['eval_history'].tolist()
    else:
        evaluation_history = [[[100.0, 100.0, 1.0], [100.0, 100.0, 1.0]]]

    total_batch_cnt = 0
    start_time = time.time()
    try:
        # for each epoch
        for e_idx in range(options['num_epochs']):
            # for each batch
            for b_idx, data in enumerate(
                    train_datastream.get_epoch_iterator()):
                total_batch_cnt += 1

                if pretrain_total_batch_cnt >= total_batch_cnt:
                    continue

                # get input, target data
                input_data = data[0].astype(floatX)
                input_mask = data[1].astype(floatX)

                # get target data
                target_data = data[2]
                target_mask = data[3].astype(floatX)

                # get output
                train_output = training_fn(input_data, input_mask, target_data,
                                           target_mask)
                train_predict_cost = train_output[0]
                network_grads_norm = train_output[1]
                train_sf_cost0 = train_output[2]
                train_sf_cost1 = train_output[3]
                train_sf_cost2 = train_output[4]

                print('=====================================================')
                print(total_batch_cnt, train_predict_cost, network_grads_norm)
                print(train_sf_cost0, train_sf_cost1, train_sf_cost2)

                if numpy.isnan(train_predict_cost) or numpy.isnan(
                        network_grads_norm):
                    print('update cnt: ', total_batch_cnt)
                    print('NaN detected: ', train_predict_cost,
                          network_grads_norm)
                    raw_input()

                # show intermediate result
                if total_batch_cnt % options[
                        'train_disp_freq'] == 0 and total_batch_cnt != 0:
                    best_idx = numpy.asarray(evaluation_history)[:, 1,
                                                                 2].argmin()
                    print '============================================================================================'
                    print 'Model Name: ', options['save_path'].split('/')[-1]
                    print '============================================================================================'
                    print 'Epoch: ', str(e_idx), ', Update: ', str(
                        total_batch_cnt), ', Time: ', str(time.time() -
                                                          start_time)
                    print '--------------------------------------------------------------------------------------------'
                    print 'Prediction Cost: ', str(train_predict_cost)
                    print 'Gradient Norm: ', str(network_grads_norm)
                    print '--------------------------------------------------------------------------------------------'
                    print 'Learn Rate: ', str(train_lr.get_value())
                    print '--------------------------------------------------------------------------------------------'
                    print 'Train NLL: ', str(
                        evaluation_history[-1][0][0]), ', BPC: ', str(
                            evaluation_history[-1][0][1]), ', FER: ', str(
                                evaluation_history[-1][0][2])
                    print 'Valid NLL: ', str(
                        evaluation_history[-1][1][0]), ', BPC: ', str(
                            evaluation_history[-1][1][1]), ', FER: ', str(
                                evaluation_history[-1][1][2])
                    print '--------------------------------------------------------------------------------------------'
                    print 'Best NLL: ', str(
                        evaluation_history[best_idx][1][0]), ', BPC: ', str(
                            evaluation_history[best_idx][1]
                            [1]), ', FER: ', str(
                                evaluation_history[best_idx][1][2])
                    start_time = time.time()

                # # evaluation
                # if total_batch_cnt%options['train_eval_freq'] == 0 and total_batch_cnt!=0:
                #     train_eval_datastream = get_datastream(path=options['data_path'],
                #                                            norm_path=options['norm_data_path'],
                #                                            which_set='train_si84',
                #                                            batch_size=options['eval_batch_size'])
                #     valid_eval_datastream = get_datastream(path=options['data_path'],
                #                                            norm_path=options['norm_data_path'],
                #                                            which_set='test_dev93',
                #                                            batch_size=options['eval_batch_size'])
                #     train_nll, train_bpc, train_fer = network_evaluation(predict_fn,
                #                                                          train_eval_datastream)
                #     valid_nll, valid_bpc, valid_fer = network_evaluation(predict_fn,
                #                                                          valid_eval_datastream)
                #
                #     # check over-fitting
                #     if valid_fer<numpy.asarray(evaluation_history)[:, 1, 2].min():
                #         best_network_params_vals = get_model_param_values(network_params)
                #         pickle.dump(best_network_params_vals,
                #                     open(options['save_path'] + '_best_model.pkl', 'wb'))
                #
                #     # save results
                #     evaluation_history.append([[train_nll, train_bpc, train_fer],
                #                                [valid_nll, valid_bpc, valid_fer]])
                #     numpy.savez(options['save_path'] + '_eval_history',
                #                 eval_history=evaluation_history)

                # save network
                if total_batch_cnt % options[
                        'train_save_freq'] == 0 and total_batch_cnt != 0:
                    cur_network_params_val = get_model_param_values(
                        network_params)
                    cur_trainer_params_val = get_update_params_values(
                        trainer_params)
                    cur_total_batch_cnt = total_batch_cnt
                    pickle.dump([
                        cur_network_params_val, cur_trainer_params_val,
                        cur_total_batch_cnt
                    ],
                                open(
                                    options['save_path'] +
                                    str(total_batch_cnt).zfill(10) +
                                    '_model.pkl', 'wb'))

    except KeyboardInterrupt:
        print 'Training Interrupted'
        cur_network_params_val = get_model_param_values(network_params)
        cur_trainer_params_val = get_update_params_values(trainer_params)
        cur_total_batch_cnt = total_batch_cnt
        pickle.dump([
            cur_network_params_val, cur_trainer_params_val, cur_total_batch_cnt
        ], open(options['save_path'] + '_last_model.pkl', 'wb'))
示例#5
0
def main(options):
    print 'Build and compile network'
    input_data = T.ftensor3('input_data')
    input_mask = T.fmatrix('input_mask')
    target_data = T.imatrix('target_data')
    target_mask = T.fmatrix('target_mask')

    skip_scale = theano.shared(convert_to_floatX(options['skip_scale']))

    network, rand_layer_list = build_network(
        input_data=input_data,
        input_mask=input_mask,
        num_inputs=options['num_inputs'],
        num_units_list=options['num_units_list'],
        num_outputs=options['num_outputs'],
        skip_scale=skip_scale,
        dropout_ratio=options['dropout_ratio'],
        weight_noise=options['weight_noise'],
        use_layer_norm=options['use_layer_norm'],
        peepholes=options['peepholes'],
        learn_init=options['learn_init'],
        grad_clipping=options['grad_clipping'],
        gradient_steps=options['gradient_steps'],
        use_projection=options['use_projection'])

    network_params = get_all_params(network, trainable=True)

    print("number of parameters in model: %d" %
          count_params(network, trainable=True))

    if options['reload_model']:
        print('Loading Parameters...')
        pretrain_network_params_val, pretrain_update_params_val, pretrain_total_batch_cnt = pickle.load(
            open(options['reload_model'], 'rb'))

        print('Applying Parameters...')
        set_model_param_value(network_params, pretrain_network_params_val)
    else:
        pretrain_update_params_val = None
        pretrain_total_batch_cnt = 0

    print 'Build network trainer'
    training_fn, trainer_params = set_network_trainer(
        input_data=input_data,
        input_mask=input_mask,
        target_data=target_data,
        target_mask=target_mask,
        num_outputs=options['num_outputs'],
        network=network,
        rand_layer_list=rand_layer_list,
        updater=options['updater'],
        learning_rate=options['lr'],
        grad_max_norm=options['grad_norm'],
        l2_lambda=options['l2_lambda'],
        load_updater_params=pretrain_update_params_val)

    print 'Build network predictor'
    predict_fn = set_network_predictor(input_data=input_data,
                                       input_mask=input_mask,
                                       target_data=target_data,
                                       target_mask=target_mask,
                                       num_outputs=options['num_outputs'],
                                       network=network)

    print 'Load data stream'
    train_datastream = get_datastream(path=options['data_path'],
                                      norm_path=options['norm_data_path'],
                                      which_set='train_si84',
                                      batch_size=options['batch_size'])

    print 'Start training'
    if os.path.exists(options['save_path'] + '_eval_history.npz'):
        evaluation_history = numpy.load(
            options['save_path'] +
            '_eval_history.npz')['eval_history'].tolist()
    else:
        evaluation_history = [[[10.0, 10.0, 1.0], [10.0, 10.0, 1.0]]]
    early_stop_flag = False
    early_stop_cnt = 0
    total_batch_cnt = 0

    try:
        # for each epoch
        for e_idx in range(options['num_epochs']):
            # for each batch
            for b_idx, data in enumerate(
                    train_datastream.get_epoch_iterator()):
                total_batch_cnt += 1
                if pretrain_total_batch_cnt >= total_batch_cnt:
                    continue

                # get input, target data
                input_data = data[0].astype(floatX)
                input_mask = data[1].astype(floatX)

                # get target data
                target_data = data[2]
                target_mask = data[3].astype(floatX)

                # get output
                train_output = training_fn(input_data, input_mask, target_data,
                                           target_mask)
                train_predict_cost = train_output[0]
                network_grads_norm = train_output[1]
                skip_means = train_output[2:]

                # show intermediate result
                if total_batch_cnt % options[
                        'train_disp_freq'] == 0 and total_batch_cnt != 0:
                    # pdb.set_trace()
                    best_idx = numpy.asarray(evaluation_history)[:, 1,
                                                                 2].argmin()
                    print '============================================================================================'
                    print 'Model Name: ', options['save_path'].split('/')[-1]
                    print '============================================================================================'
                    print 'Epoch: ', str(e_idx), ', Update: ', str(
                        total_batch_cnt)
                    print '--------------------------------------------------------------------------------------------'
                    print 'Prediction Cost: ', str(train_predict_cost)
                    print 'Gradient Norm: ', str(network_grads_norm)
                    print '--------------------------------------------------------------------------------------------'
                    print 'Skip Ratio: ', skip_means
                    print 'Skip Scale: ', str(skip_scale.get_value())
                    print '--------------------------------------------------------------------------------------------'
                    print 'Train NLL: ', str(
                        evaluation_history[-1][0][0]), ', BPC: ', str(
                            evaluation_history[-1][0][1]), ', FER: ', str(
                                evaluation_history[-1][0][2])
                    print 'Valid NLL: ', str(
                        evaluation_history[-1][1][0]), ', BPC: ', str(
                            evaluation_history[-1][1][1]), ', FER: ', str(
                                evaluation_history[-1][1][2])
                    print '--------------------------------------------------------------------------------------------'
                    print 'Best NLL: ', str(
                        evaluation_history[best_idx][1][0]), ', BPC: ', str(
                            evaluation_history[best_idx][1]
                            [1]), ', FER: ', str(
                                evaluation_history[best_idx][1][2])

                # evaluation
                if total_batch_cnt % options[
                        'train_eval_freq'] == 0 and total_batch_cnt != 0:
                    train_eval_datastream = get_datastream(
                        path=options['data_path'],
                        norm_path=options['norm_data_path'],
                        which_set='train_si84',
                        batch_size=options['eval_batch_size'])
                    valid_eval_datastream = get_datastream(
                        path=options['data_path'],
                        norm_path=options['norm_data_path'],
                        which_set='test_dev93',
                        batch_size=options['eval_batch_size'])
                    train_nll, train_bpc, train_fer = network_evaluation(
                        predict_fn, train_eval_datastream)
                    valid_nll, valid_bpc, valid_fer = network_evaluation(
                        predict_fn, valid_eval_datastream)

                    # check over-fitting
                    if valid_fer > numpy.asarray(evaluation_history)[:, 1,
                                                                     2].min():
                        early_stop_cnt += 1.
                    else:
                        early_stop_cnt = 0.
                        best_network_params_vals = get_model_param_values(
                            network_params)
                        pickle.dump(
                            best_network_params_vals,
                            open(options['save_path'] + '_best_model.pkl',
                                 'wb'))

                    if early_stop_cnt > 10:
                        early_stop_flag = True
                        break

                    # save results
                    evaluation_history.append(
                        [[train_nll, train_bpc, train_fer],
                         [valid_nll, valid_bpc, valid_fer]])
                    numpy.savez(options['save_path'] + '_eval_history',
                                eval_history=evaluation_history)

                # save network
                if total_batch_cnt % options[
                        'train_save_freq'] == 0 and total_batch_cnt != 0:
                    cur_network_params_val = get_model_param_values(
                        network_params)
                    cur_trainer_params_val = get_update_params_values(
                        trainer_params)
                    cur_total_batch_cnt = total_batch_cnt
                    pickle.dump([
                        cur_network_params_val, cur_trainer_params_val,
                        cur_total_batch_cnt
                    ], open(options['save_path'] + '_last_model.pkl', 'wb'))

                if total_batch_cnt % 1000 == 0 and total_batch_cnt != 0:
                    skip_scale.set_value(
                        convert_to_floatX(skip_scale.get_value() * 1.01))

            if early_stop_flag:
                break

    except KeyboardInterrupt:
        print 'Training Interrupted'
        cur_network_params_val = get_model_param_values(network_params)
        cur_trainer_params_val = get_update_params_values(trainer_params)
        cur_total_batch_cnt = total_batch_cnt
        pickle.dump([
            cur_network_params_val, cur_trainer_params_val, cur_total_batch_cnt
        ], open(options['save_path'] + '_last_model.pkl', 'wb'))
# ---- Example #6 ----
                    print '--------------------------------------------------------------------------------------------'
                    print 'Train NLL: ', str(evaluation_history[-1][0][0]), ', BPC: ', str(evaluation_history[-1][0][1]), ', FER: ', str(evaluation_history[-1][0][2])
                    print 'Valid NLL: ', str(evaluation_history[-1][1][0]), ', BPC: ', str(evaluation_history[-1][1][1]), ', FER: ', str(evaluation_history[-1][1][2])

            # evaluation
            train_nll, train_bpc, train_per = network_evaluation(predict_fn,
                                                                 train_datastream)
            valid_nll, valid_bpc, valid_per = network_evaluation(predict_fn,
                                                                 valid_datastream)

            # check over-fitting
            if valid_per>evaluation_history[-1][1][2]:
                check_early_stop += 1.
            else:
                check_early_stop = 0.
                best_network_params_vals = get_model_param_values(network_params)
                pickle.dump(best_network_params_vals,
                            open(options['save_path'] + '_best_model.pkl', 'wb'))

            if check_early_stop>10:
                print('Training Early Stopped')
                break

            # save results
            evaluation_history.append([[train_nll, train_bpc, train_per],
                                       [valid_nll, valid_bpc, valid_per]])
            numpy.savez(options['save_path'] + '_eval_history',
                        eval_history=evaluation_history)

            cur_network_params_val = get_model_param_values(network_params)
            cur_trainer_params_val = get_update_params_values(trainer_params)