Exemplo n.º 1
0
def main(configuration, is_chief=False):
    """Train an encoder-decoder translation model with periodic BLEU validation.

    Builds the trainer and sampler graphs, then loops over the training
    stream; checkpoints, prints samples, and validates on a dev set at
    configured intervals, persisting the best-BLEU model.

    Args:
        configuration: dict of hyper-parameters, paths and schedule knobs.
            Keys read here: l1_reg_weight, l2_reg_weight,
            softmax_output_num_sampled, reload, seq_len_src, beam_size,
            with_reconstruction, with_fast_training, finish_after,
            fast_training_iterations, fix_base_parameters,
            with_tied_weights, save_freq, sample_freq, val_burn_in,
            val_burn_in_fine, valid_freq, valid_freq_fine, valid_out,
            saveto, saveto_best.
        is_chief: if True, this worker initializes all variables and
            optionally reloads saved model state before training.

    NOTE(review): depends on module-level names not visible in this chunk
    (K, tf, args, time, os, logger, EncoderDecoder, BeamSearch, Sampler,
    BleuValidator, DStream, get_devtest_stream) — presumably imported or
    defined at the top of the file; confirm before moving this function.
    """
    l1_reg_weight = configuration['l1_reg_weight']
    l2_reg_weight = configuration['l2_reg_weight']
    # Placeholders are time-major: (time_steps, nb_samples).
    src = K.placeholder(shape=(None, None), dtype='int32')
    src_mask = K.placeholder(shape=(None, None))
    trg = K.placeholder(shape=(None, None), dtype='int32')
    trg_mask = K.placeholder(shape=(None, None))

    # Scalar iteration counter, used for fast training of new parameters.
    ite = K.placeholder(ndim=0)

    enc_dec = EncoderDecoder(**configuration)

    softmax_output_num_sampled = configuration['softmax_output_num_sampled']

    enc_dec.build_trainer(
        src,
        src_mask,
        trg,
        trg_mask,
        ite,
        l1_reg_weight=l1_reg_weight,
        l2_reg_weight=l2_reg_weight,
        softmax_output_num_sampled=softmax_output_num_sampled)

    enc_dec.build_sampler()

    # Only the chief worker initializes variables and loads model state,
    # so that multiple workers do not clobber each other's parameters.

    if is_chief:
        # NOTE(review): tf.initialize_all_variables() is deprecated since
        # TF 0.12 in favor of tf.global_variables_initializer() — confirm
        # the TensorFlow version pinned by this project before changing.
        init_op = tf.initialize_all_variables()
        init_fn = K.function(inputs=[], outputs=[init_op])
        init_fn([])

        if configuration['reload']:
            enc_dec.load()

    # Stochastic single-beam search: used only to print training samples.
    sample_search = BeamSearch(enc_dec=enc_dec,
                               configuration=configuration,
                               beam_size=1,
                               maxlen=configuration['seq_len_src'],
                               stochastic=True)

    # Deterministic beam search for BLEU validation; target hypotheses may
    # run up to 3x the source length.
    valid_search = BeamSearch(enc_dec=enc_dec,
                              configuration=configuration,
                              beam_size=configuration['beam_size'],
                              maxlen=3 * configuration['seq_len_src'],
                              stochastic=False)

    sampler = Sampler(sample_search, **configuration)
    bleuvalidator = BleuValidator(valid_search, **configuration)

    # train function
    train_fn = enc_dec.train_fn

    # Optional fast-training function bound only when reconstruction plus
    # fast training are enabled; referenced in the warm-up branch below
    # under the same pair of flags.
    if configuration['with_reconstruction'] and configuration[
            'with_fast_training']:
        fast_train_fn = enc_dec.fast_train_fn

    # train data
    ds = DStream(**configuration)

    # valid data
    vs = get_devtest_stream(data_type='valid',
                            input_file=None,
                            **configuration)

    # NOTE(review): `args` is a module-level global (presumably the parsed
    # CLI namespace) — confirm it is defined before main() is invoked.
    iters = args.start
    valid_bleu_best = -1
    epoch_best = -1
    iters_best = -1
    max_epochs = configuration['finish_after']

    # TODO: use global iter and only the chief can save the model
    for epoch in range(max_epochs):
        for x, x_mask, y, y_mask in ds.get_iterator():
            last_time = time.time()
            # During the first `fast_training_iterations` updates, use the
            # fast-training path (new parameters only) when configured.
            if configuration['with_reconstruction'] and configuration[
                    'with_fast_training'] and iters < configuration[
                        'fast_training_iterations']:
                if configuration['fix_base_parameters'] and not configuration[
                        'with_tied_weights']:
                    tc = fast_train_fn([x.T, x_mask.T, y.T, y_mask.T])
                else:
                    # Also feed the update count (the `ite` placeholder) —
                    # presumably a warm-up/annealing schedule; verify
                    # against EncoderDecoder.build_trainer.
                    tc = fast_train_fn([x.T, x_mask.T, y.T, y_mask.T, iters])
            else:
                # Batches arrive batch-major; transpose to time-major to
                # match the placeholders above.
                tc = train_fn([x.T, x_mask.T, y.T, y_mask.T])
            cur_time = time.time()
            iters += 1
            logger.info(
                'epoch %d \t updates %d train cost %.4f use time %.4f' %
                (epoch, iters, tc[0], cur_time - last_time))

            if iters % configuration['save_freq'] == 0:
                enc_dec.save()

            if iters % configuration['sample_freq'] == 0:
                sampler.apply(x, y)

            # No validation at all until the burn-in point.
            if iters < configuration['val_burn_in']:
                continue

            # Coarse validation frequency first, then a finer frequency
            # once past the fine burn-in threshold.
            if (iters <= configuration['val_burn_in_fine'] and iters % configuration['valid_freq'] == 0) \
               or (iters > configuration['val_burn_in_fine'] and iters % configuration['valid_freq_fine'] == 0):
                valid_bleu = bleuvalidator.apply(vs,
                                                 configuration['valid_out'])
                # Archive the validation output and latest checkpoint.
                # NOTE(review): shell interpolation of config-supplied
                # paths — assumes trusted, whitespace-free values.
                os.system('mkdir -p results/%d' % iters)
                os.system('mv %s* %s results/%d' %
                          (configuration['valid_out'], configuration['saveto'],
                           iters))
                logger.info(
                    'valid_test \t epoch %d \t updates %d valid_bleu %.4f' %
                    (epoch, iters, valid_bleu))
                # Track and persist the best-BLEU model seen so far.
                if valid_bleu > valid_bleu_best:
                    valid_bleu_best = valid_bleu
                    epoch_best = epoch
                    iters_best = iters
                    enc_dec.save(path=configuration['saveto_best'])

    logger.info('final result: epoch %d \t updates %d valid_bleu_best %.4f' %
                (epoch_best, iters_best, valid_bleu_best))
Exemplo n.º 2
0
            # added by Zhaopeng Tu, 2017-11-29
            # for layer_norm with adam, we explicitly update parameters
            # will be merged in the future
            # NOTE(review): fragment of a training loop whose enclosing
            # function starts before this chunk — update_fn, iters, tc,
            # epoch, etc. are bound earlier, outside this view.
            if enc_dec.with_layernorm:
                # 0.001 is presumably the Adam learning rate — confirm
                # against update_fn's definition.
                update_fn(0.001)

            cur_time = time.time()
            iters += 1
            logger.info('epoch %d \t updates %d train cost %.4f use time %.4f'
                        %(epoch, iters, tc[0], cur_time-last_time))

            if iters % configuration['save_freq'] == 0:
                enc_dec.save()

            if iters % configuration['sample_freq'] == 0:
                sampler.apply(x, y)

            # No validation at all until the burn-in point.
            if iters < configuration['val_burn_in']:
                continue

            # Coarse validation frequency first, then a finer frequency
            # once past the fine burn-in threshold.
            if (iters <= configuration['val_burn_in_fine'] and iters % configuration['valid_freq'] == 0) \
               or (iters > configuration['val_burn_in_fine'] and iters % configuration['valid_freq_fine'] == 0):
                valid_bleu = bleuvalidator.apply(vs, configuration['valid_src'], configuration['valid_out'])
                # Archive validation output and latest checkpoint.
                # NOTE(review): shell interpolation of config paths —
                # assumes trusted, whitespace-free values.
                os.system('mkdir -p out/%d' % iters)
                os.system('mv %s* %s out/%d' % (configuration['valid_out'], configuration['saveto'], iters))
                logger.info('valid_test \t epoch %d \t updates %d valid_bleu %.4f'
                        %(epoch, iters, valid_bleu))
                # Track the best-BLEU score seen so far (the best-model
                # save, if any, falls outside this chunk).
                if valid_bleu > valid_bleu_best:
                    valid_bleu_best = valid_bleu
                    epoch_best = epoch
                    iters_best = iters
Exemplo n.º 3
0
    # NOTE(review): fragment of a training function whose header lies
    # before this chunk — ds, train_fn, sampler, bleuvalidator, vs,
    # vs_hist, iters, valid_bleu_best, etc. are bound earlier.
    max_epochs = configuration['finish_after']

    for epoch in range(max_epochs):
        # This variant carries an extra history stream (x_hist) alongside
        # the source — a document/context-aware model, presumably.
        for x, x_mask, x_hist, x_mask_hist, y, y_mask in ds.get_iterator():
            last_time = time.time()
            # Batches arrive batch-major; transpose to time-major.
            tc = train_fn(x.T, x_mask.T, x_hist.T, x_mask_hist.T, y.T, y_mask.T)
            cur_time = time.time()
            iters += 1
            logger.info('epoch %d \t updates %d train cost %.4f use time %.4f'
                        %(epoch, iters, tc[0], cur_time-last_time))

            if iters % configuration['save_freq'] == 0:
                enc_dec.save()

            if iters % configuration['sample_freq'] == 0:
                sampler.apply(x, x_hist, y)

            # No validation at all until the burn-in point.
            if iters < configuration['val_burn_in']:
                continue

            # Coarse validation frequency first, then a finer frequency
            # once past the fine burn-in threshold.
            if (iters <= configuration['val_burn_in_fine'] and iters % configuration['valid_freq'] == 0) \
               or (iters > configuration['val_burn_in_fine'] and iters % configuration['valid_freq_fine'] == 0):
                valid_bleu = bleuvalidator.apply(vs, vs_hist, configuration['valid_out'])
                # Archive validation output and latest checkpoint.
                # NOTE(review): shell interpolation of config paths —
                # assumes trusted, whitespace-free values.
                os.system('mkdir -p results/%d' % iters)
                os.system('mv %s* %s results/%d' % (configuration['valid_out'], configuration['saveto'], iters))
                logger.info('valid_test \t epoch %d \t updates %d valid_bleu %.4f'
                        %(epoch, iters, valid_bleu))
                # Track the best-BLEU score seen so far (the best-model
                # save, if any, falls outside this chunk).
                if valid_bleu > valid_bleu_best:
                    valid_bleu_best = valid_bleu
                    epoch_best = epoch
                    iters_best = iters