Example #1
# imports required by this function: standard library, numpy/theano, and Merlin modules
import logging
import math
import pickle
import sys
import time

import numpy
import theano

# Merlin-specific classes; the module paths below follow the Merlin toolkit's source
# layout and may need adjusting for a different project structure. The "main.train_DNN"
# and "plotting" loggers are assumed to be Merlin's plotting-capable logger class,
# which provides create_plot / add_plot_point / save_plot.
from models.dnn import DNN
from utils.providers import ListDataProvider
from logplot.logging_plotting import MultipleSeriesPlot


def train_DNN(train_xy_file_list, valid_xy_file_list,
              nnets_file_name, n_ins, n_outs, ms_outs, hyper_params, buffer_size, plot=False):
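    """Fine-tune a neural network acoustic model and save the best one found.

    train_xy_file_list and valid_xy_file_list are (x_file_list, y_file_list) tuples
    of input/output feature files; n_ins and n_outs are the input and output feature
    dimensions; ms_outs lists the per-stream output dimensions and must sum to n_outs;
    hyper_params is a dict of training settings (see the keys read below). The model
    with the lowest validation loss is pickled to nnets_file_name and that loss is
    returned.
    """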
    # get loggers for this function
    # this one writes to both console and file
    logger = logging.getLogger("main.train_DNN")
    logger.debug('Starting train_DNN')

    if plot:
        # this one takes care of plotting duties
        plotlogger = logging.getLogger("plotting")
        # create an (empty) plot of training convergence, ready to receive data points
        logger.create_plot('training convergence', MultipleSeriesPlot)

    try:
        assert numpy.sum(ms_outs) == n_outs
    except AssertionError:
        logger.critical('the sum of the multi-stream outputs does not equal %d' % (n_outs))
        raise

    ####parameters#####
    finetune_lr = numpy.asarray(hyper_params['learning_rate'], dtype='float32')
    training_epochs = int(hyper_params['training_epochs'])
    batch_size = int(hyper_params['batch_size'])
    l1_reg = float(hyper_params['l1_reg'])
    l2_reg = float(hyper_params['l2_reg'])
    #     private_l2_reg  = float(hyper_params['private_l2_reg'])
    warmup_epoch = int(hyper_params['warmup_epoch'])
    momentum = float(hyper_params['momentum'])
    warmup_momentum = float(hyper_params['warmup_momentum'])

    use_rprop = int(hyper_params['use_rprop'])

    hidden_layers_sizes = hyper_params['hidden_layer_size']

    #     stream_weights       = hyper_params['stream_weights']
    #     private_hidden_sizes = hyper_params['private_hidden_sizes']

    buffer_utt_size = buffer_size
    early_stop_epoch = int(hyper_params['early_stop_epochs'])

    hidden_activation = hyper_params['hidden_activation']
    output_activation = hyper_params['output_activation']

    #     stream_lr_weights = hyper_params['stream_lr_weights']
    #     use_private_hidden = hyper_params['use_private_hidden']

    model_type = hyper_params['model_type']

    ## use a switch to turn pretraining on or off
    ## pretraining may not help much; in that case, turn it off to save time
    do_pretraining = hyper_params['do_pretraining']
    pretraining_epochs = int(hyper_params['pretraining_epochs'])
    pretraining_lr = float(hyper_params['pretraining_lr'])

    # round the buffer size down to a whole number of minibatches
    buffer_size = int(buffer_size / batch_size) * batch_size

    ###################
    (train_x_file_list, train_y_file_list) = train_xy_file_list
    (valid_x_file_list, valid_y_file_list) = valid_xy_file_list

    logger.debug('Creating training   data provider')
    train_data_reader = ListDataProvider(x_file_list=train_x_file_list, y_file_list=train_y_file_list, n_ins=n_ins,
                                         n_outs=n_outs, buffer_size=buffer_size, shuffle=True)

    logger.debug('Creating validation data provider')
    valid_data_reader = ListDataProvider(x_file_list=valid_x_file_list, y_file_list=valid_y_file_list, n_ins=n_ins,
                                         n_outs=n_outs, buffer_size=buffer_size, shuffle=False)

    # load the first partitions to obtain the Theano shared variables that hold the
    # current chunk of data, then rewind both readers before training starts
    shared_train_set_xy, temp_train_set_x, temp_train_set_y = train_data_reader.load_next_partition()
    train_set_x, train_set_y = shared_train_set_xy
    shared_valid_set_xy, temp_valid_set_x, temp_valid_set_y = valid_data_reader.load_next_partition()
    valid_set_x, valid_set_y = shared_valid_set_xy
    train_data_reader.reset()
    valid_data_reader.reset()

    ## temporarily we use the training set as pretrain_set_x;
    ## we still need to support arbitrary data for pretraining
    pretrain_set_x = train_set_x

    # numpy random generator
    numpy_rng = numpy.random.RandomState(123)
    logger.info('building the model')

    dnn_model = None
    pretrain_fn = None  ## not all models support pretraining right now
    train_fn = None
    valid_fn = None
    valid_model = None  ## valid_fn and valid_model are the same; reserved to compute multi-stream distortion
    if model_type == 'DNN':
        dnn_model = DNN(numpy_rng=numpy_rng, n_ins=n_ins, n_outs=n_outs,
                        l1_reg=l1_reg, l2_reg=l2_reg,
                        hidden_layers_sizes=hidden_layers_sizes,
                        hidden_activation=hidden_activation,
                        output_activation=output_activation,
                        use_rprop=use_rprop, rprop_init_update=finetune_lr)
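        # build_finetune_functions compiles the Theano training and validation
        # functions: train_fn(minibatch_index, lr, momentum) performs one minibatch
        # update and returns its error; valid_fn() returns the validation losses
        # that are averaged in the training loop below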
        train_fn, valid_fn = dnn_model.build_finetune_functions(
            (train_set_x, train_set_y), (valid_set_x, valid_set_y), batch_size=batch_size)

    else:
        logger.critical('%s type NN model is not supported!' % (model_type))
        raise ValueError('%s type NN model is not supported!' % (model_type))

    logger.info('fine-tuning the %s model' % (model_type))

    start_time = time.time()

    best_dnn_model = dnn_model
    best_validation_loss = sys.float_info.max
    previous_loss = sys.float_info.max

    early_stop = 0
    epoch = 0
    previous_finetune_lr = finetune_lr

    while (epoch < training_epochs):
        epoch = epoch + 1

        # learning-rate schedule: keep the initial rate (with warmup momentum) for
        # the first warmup_epoch epochs, then halve the rate after every epoch
        current_momentum = momentum
        current_finetune_lr = finetune_lr
        if epoch <= warmup_epoch:
            current_finetune_lr = finetune_lr
            current_momentum = warmup_momentum
        else:
            current_finetune_lr = previous_finetune_lr * 0.5

        previous_finetune_lr = current_finetune_lr

        train_error = []
        sub_start_time = time.time()

        # iterate over the training data one buffer-sized partition at a time,
        # copying each partition into the Theano shared variables before training
        while (not train_data_reader.is_finish()):
            shared_train_set_xy, temp_train_set_x, temp_train_set_y = train_data_reader.load_next_partition()
            train_set_x.set_value(numpy.asarray(temp_train_set_x, dtype=theano.config.floatX), borrow=True)
            train_set_y.set_value(numpy.asarray(temp_train_set_y, dtype=theano.config.floatX), borrow=True)

            n_train_batches = train_set_x.get_value().shape[0] // batch_size

            logger.debug('this partition: %d frames (divided into %d batches of size %d)' % (
                train_set_x.get_value(borrow=True).shape[0], n_train_batches, batch_size))

            for minibatch_index in range(n_train_batches):
                this_train_error = train_fn(minibatch_index, current_finetune_lr, current_momentum)
                train_error.append(this_train_error)

                if numpy.isnan(this_train_error):
                    logger.warning('training error over minibatch %d of %d was %s' % (
                        minibatch_index + 1, n_train_batches, this_train_error))

        train_data_reader.reset()

        logger.debug('calculating validation loss')
        validation_losses = valid_fn()
        this_validation_loss = numpy.mean(validation_losses)

        # this has a possible bias if the minibatches were not all of identical size
        # but it should not be significant if the minibatches are small
        this_train_valid_loss = numpy.mean(train_error)

        sub_end_time = time.time()

        loss_difference = this_validation_loss - previous_loss

        logger.info('epoch %i, validation error %f, train error %f, time spent %.2f' % (
            epoch, this_validation_loss, this_train_valid_loss, (sub_end_time - sub_start_time)))
        if plot:
            plotlogger.add_plot_point('training convergence', 'validation set', (epoch, this_validation_loss))
            plotlogger.add_plot_point('training convergence', 'training set', (epoch, this_train_valid_loss))
            plotlogger.save_plot('training convergence', title='Progress of training and validation error',
                                 xlabel='epochs', ylabel='error')

        if this_validation_loss < best_validation_loss:
            best_dnn_model = dnn_model
            best_validation_loss = this_validation_loss
            logger.debug('validation loss decreased, so saving model')
            early_stop = 0
        else:
            logger.debug('validation loss did not improve')
            early_stop += 1

        if early_stop >= early_stop_epoch:
            # too many consecutive epochs without surpassing the best model
            logger.debug('stopping early')
            break

        if math.isnan(this_validation_loss):
            # validation loss has become NaN; no point in continuing
            break

        previous_loss = this_validation_loss

    end_time = time.time()

    with open(nnets_file_name, 'wb') as f:
        pickle.dump(best_dnn_model, f)

    logger.info('overall training time: %.2fm, validation error %f' % (
        (end_time - start_time) / 60., best_validation_loss))

    if plot:
        plotlogger.save_plot('training convergence', title='Final training and validation error', xlabel='epochs',
                             ylabel='error')

    return best_validation_loss
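
Below is a minimal, hypothetical usage sketch (not taken from the source project): the hyper_params keys are exactly the ones read by train_DNN above, but all values, file lists, feature dimensions, and the output path are illustrative placeholders.

# Hypothetical example call -- file lists, dimensions, and the output path are
# placeholders; only the hyper_params keys are taken from train_DNN above.
hyper_params = {
    'learning_rate': 0.002,
    'training_epochs': 25,
    'batch_size': 256,
    'l1_reg': 0.0,
    'l2_reg': 1e-5,
    'warmup_epoch': 10,
    'momentum': 0.9,
    'warmup_momentum': 0.3,
    'use_rprop': 0,
    'hidden_layer_size': [1024, 1024, 1024, 1024],
    'early_stop_epochs': 5,
    'hidden_activation': 'tanh',
    'output_activation': 'linear',
    'model_type': 'DNN',
    'do_pretraining': False,
    'pretraining_epochs': 10,
    'pretraining_lr': 0.0001,
}

# placeholder per-utterance feature file lists (inputs and outputs must be paired)
train_x_files = ['train/utt_0001.lab', 'train/utt_0002.lab']
train_y_files = ['train/utt_0001.cmp', 'train/utt_0002.cmp']
valid_x_files = ['valid/utt_0001.lab']
valid_y_files = ['valid/utt_0001.cmp']

best_loss = train_DNN((train_x_files, train_y_files),
                      (valid_x_files, valid_y_files),
                      nnets_file_name='dnn_model.pkl',
                      n_ins=425,            # input feature dimension (placeholder)
                      n_outs=187,           # output feature dimension (placeholder)
                      ms_outs=[187],        # must sum to n_outs (checked by the assert above)
                      hyper_params=hyper_params,
                      buffer_size=200000)   # frames held in memory at once

# the best model (lowest validation loss) is pickled to nnets_file_name and can be
# reloaded later, e.g. with pickle.load(open('dnn_model.pkl', 'rb'))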