Example #1
def train(
        dim_word=100,
        dim_word_src=200,
        enc_dim=1000,
        dec_dim=1000,  # the number of LSTM units
        patience=-1,  # early stopping patience
        max_epochs=5000,
        finish_after=-1,  # finish after this many updates
        decay_c=0.,  # L2 regularization penalty
        alpha_c=0.,  # alignment regularization
        clip_c=-1.,  # gradient clipping threshold
        lrate=0.01,  # learning rate
        n_words_src=100000,  # source vocabulary size
        n_words=100000,  # target vocabulary size
        maxlen=100,  # maximum length of the source sequence
        maxlen_trg=None,  # maximum length of the target sequence (defaults to 10 * maxlen)
        maxlen_sample=1000,
        optimizer='rmsprop',
        batch_size=16,
        valid_batch_size=16,
        sort_size=20,
        save_path=None,
        save_file_name='model',
        save_best_models=0,
        dispFreq=100,
        validFreq=100,
        saveFreq=1000,  # save the parameters after every saveFreq updates
        sampleFreq=-1,
        verboseFreq=10000,
        datasets=[
            '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok',
            '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'
        ],
        valid_datasets=[
            '../data/dev/newstest2011.en.tok',
            '../data/dev/newstest2011.fr.tok'
        ],
        dictionaries=[
            '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl',
            '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'
        ],
        source_word_level=0,
        target_word_level=0,
        use_dropout=False,
        re_load=False,
        re_load_old_setting=False,
        uidx=None,
        eidx=None,
        cidx=None,
        layers=None,
        save_every_saveFreq=0,
        save_burn_in=20000,
        use_bpe=0,
        init_params=None,
        build_model=None,
        build_sampler=None,
        gen_sample=None,
        **kwargs):

    if maxlen_trg is None:
        maxlen_trg = maxlen * 10
    # Model options
    model_options = locals().copy()
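    # keep only plain hyper-parameters in model_options; the builder callables below
    # are wiring, not part of the saved configuration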
    del model_options['init_params']
    del model_options['build_model']
    del model_options['build_sampler']
    del model_options['gen_sample']

    # load dictionaries and invert them
    worddicts = [None] * len(dictionaries)
    worddicts_r = [None] * len(dictionaries)
    for ii, dd in enumerate(dictionaries):
        with open(dd, 'rb') as f:
            worddicts[ii] = cPickle.load(f)
        worddicts_r[ii] = dict()
        for kk, vv in worddicts[ii].iteritems():
            worddicts_r[ii][vv] = kk

    print 'Building model'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    file_name = '%s%s.npz' % (save_path, save_file_name)
    best_file_name = '%s%s.best.npz' % (save_path, save_file_name)
    opt_file_name = '%s%s%s.npz' % (save_path, save_file_name, '.grads')
    best_opt_file_name = '%s%s%s.best.npz' % (save_path, save_file_name,
                                              '.grads')
    model_name = '%s%s.pkl' % (save_path, save_file_name)
    params = init_params(model_options)
    cPickle.dump(model_options, open(model_name, 'wb'))
    history_errs = []

    # reload options
    if re_load and os.path.exists(file_name):
        print 'You are reloading your experiment.. do not panic dude..'
        if re_load_old_setting:
            with open(model_name, 'rb') as f:
                model_options = cPickle.load(f)
        params = load_params(file_name, params)
        # reload history
        model = numpy.load(file_name)
        history_errs = list(model['history_errs'])
        if uidx is None:
            uidx = model['uidx']
        if eidx is None:
            eidx = model['eidx']
        if cidx is None:
            cidx = model['cidx']
    else:
        if uidx is None:
            uidx = 0
        if eidx is None:
            eidx = 0
        if cidx is None:
            cidx = 0

    print 'Loading data'
    train = TextIterator(source=datasets[0],
                         target=datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=batch_size,
                         sort_size=sort_size)
    valid = TextIterator(source=valid_datasets[0],
                         target=valid_datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=valid_batch_size,
                         sort_size=sort_size)

    # create shared variables for parameters
    tparams = init_tparams(params)

    trng, use_noise, \
        x, x_mask, y, y_mask, \
        opt_ret, \
        cost = \
        build_model(tparams, model_options)
    inps = [x, x_mask, y, y_mask]

    print 'Building sampler...\n',
    f_init, f_next = build_sampler(tparams, model_options, trng, use_noise)
    #print 'Done'

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, cost, profile=profile)
    print 'Done'
    if re_load:
        use_noise.set_value(0.)
        valid_errs = pred_probs(f_log_probs,
                                prepare_data,
                                model_options,
                                valid,
                                verboseFreq=verboseFreq)
        valid_err = valid_errs.mean()

        if numpy.isnan(valid_err):
            import ipdb
            ipdb.set_trace()

        print 'Reload sanity check: Valid ', valid_err

    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv**2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # regularize the alpha weights
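    # (pushes the total attention each source position receives, summed over target
    #  steps, toward the target-to-source length ratio)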
    if alpha_c > 0. and not model_options['decoder'].endswith('simple'):
        alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c')
        alpha_reg = alpha_c * (
            (tensor.cast(y_mask.sum(0) // x_mask.sum(0), 'float32')[:, None] -
             opt_ret['dec_alphas'].sum(0))**2).sum(1).mean()
        cost += alpha_reg

    # after all regularizers - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile)
    print 'Done'

    print 'Computing gradient...',
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    print 'Done'

    if clip_c > 0:
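        # gradient_clipping (helper not shown here) presumably rescales gradients whose
        # global norm exceeds clip_c and flags non-finite or clipped gradients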
        grads, not_finite, clipped = gradient_clipping(grads, tparams, clip_c)
    else:
        not_finite = 0
        clipped = 0

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
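    # eval(optimizer) looks up the optimizer by name (e.g. 'rmsprop') in this module's
    # namespace; it is expected to return compiled Theano update functions plus the
    # optimizer's shared state (toptparams)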
    if re_load and os.path.exists(file_name):
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(
                lr,
                tparams,
                grads,
                inps,
                cost=cost,
                not_finite=not_finite,
                clipped=clipped,
                file_name=opt_file_name)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(
                lr, tparams, grads, inps, cost=cost, file_name=opt_file_name)
    else:
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(
                lr,
                tparams,
                grads,
                inps,
                cost=cost,
                not_finite=not_finite,
                clipped=clipped)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr,
                                                                  tparams,
                                                                  grads,
                                                                  inps,
                                                                  cost=cost)
    print 'Done'

    print 'Optimization'
    best_p = None
    bad_counter = 0

    if validFreq == -1:
        validFreq = len(train[0]) / batch_size
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size

    # Training loop
    ud_start = time.time()
    estop = False

    if re_load:
        print "Checkpointed minibatch number: %d" % cidx
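        # fast-forward the training iterator past the minibatches consumed before the checkpoint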
        for cc in xrange(cidx):
            if numpy.mod(cc, 1000) == 0:
                print "Jumping [%d / %d] examples" % (cc, cidx)
            train.next()

    for epoch in xrange(max_epochs):
        n_samples = 0
        NaN_grad_cnt = 0
        NaN_cost_cnt = 0
        clipped_cnt = 0
        if re_load:
            re_load = 0
        else:
            cidx = 0

        for x, y in train:
            cidx += 1
            uidx += 1
            use_noise.set_value(1.)

            x, x_mask, y, y_mask, n_x = prepare_data(x,
                                                     y,
                                                     maxlen=maxlen,
                                                     maxlen_trg=maxlen_trg,
                                                     n_words_src=n_words_src,
                                                     n_words=n_words)
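            # prepare_data pads the minibatch into (time, batch) matrices with float masks;
            # n_x is the number of sentence pairs kept (those under the length limits)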
            n_samples += n_x

            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                uidx = max(uidx, 0)
                continue

            # compute cost, grads and copy grads to shared variables
            if clip_c > 0:
                cost, not_finite, clipped = f_grad_shared(x, x_mask, y, y_mask)
            else:
                cost = f_grad_shared(x, x_mask, y, y_mask)

            if clipped:
                clipped_cnt += 1

            # check for bad numbers, usually we remove non-finite elements
            # and continue training - but not done here
            if numpy.isnan(cost) or numpy.isinf(cost):
                NaN_cost_cnt += 1

            if not_finite:
                NaN_grad_cnt += 1
                continue

            # do the update on parameters
            f_update(lrate)

            if numpy.isnan(cost) or numpy.isinf(cost):
                continue

            if float(NaN_grad_cnt) > max_epochs * 0.5 or float(
                    NaN_cost_cnt) > max_epochs * 0.5:
                print 'Too many NaNs, abort training'
                return 1., 1., 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                ud = time.time() - ud_start
                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'NaN_in_grad', NaN_grad_cnt,\
                      'NaN_in_cost', NaN_cost_cnt, 'Gradient_clipped', clipped_cnt, 'UD ', ud
                ud_start = time.time()

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0 and sampleFreq != -1:
                # FIXME: random selection?
                for jj in xrange(numpy.minimum(5, x.shape[1])):
                    stochastic = True
                    use_noise.set_value(0.)
                    sample, score = gen_sample(tparams,
                                               f_init,
                                               f_next,
                                               x[:, jj][:, None],
                                               model_options,
                                               trng=trng,
                                               k=1,
                                               maxlen=maxlen_sample,
                                               stochastic=stochastic,
                                               argmax=False)
                    print
                    print 'Source ', jj, ': ',
                    if source_word_level:
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                if use_bpe:
                                    print(worddicts_r[0][vv]).replace(
                                        '@@', ''),
                                else:
                                    print worddicts_r[0][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        source_ = []
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                source_.append(worddicts_r[0][vv])
                            else:
                                source_.append('UNK')
                        print "".join(source_)
                    print 'Truth ', jj, ' : ',
                    if target_word_level:
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print(worddicts_r[1][vv]).replace(
                                        '@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        truth_ = []
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                truth_.append(worddicts_r[1][vv])
                            else:
                                truth_.append('UNK')
                        print "".join(truth_)
                    print 'Sample ', jj, ': ',
                    if stochastic:
                        ss = sample
                    else:
                        score = score / numpy.array([len(s) for s in sample])
                        ss = sample[score.argmin()]
                    if target_word_level:
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print(worddicts_r[1][vv]).replace(
                                        '@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        sample_ = []
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                sample_.append(worddicts_r[1][vv])
                            else:
                                sample_.append('UNK')
                        print "".join(sample_)
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                valid_errs = pred_probs(f_log_probs,
                                        prepare_data,
                                        model_options,
                                        valid,
                                        verboseFreq=verboseFreq)
                valid_err = valid_errs.mean()
                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    best_optp = unzip(toptparams)
                    bad_counter = 0

                if saveFreq != validFreq and save_best_models:
                    numpy.savez(best_file_name,
                                history_errs=history_errs,
                                uidx=uidx,
                                eidx=eidx,
                                cidx=cidx,
                                **best_p)
                    numpy.savez(best_opt_file_name, **best_optp)

                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min() and patience != -1:
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    import ipdb
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # save the best model so far
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if not os.path.exists(save_path):
                    os.mkdir(save_path)

                params = unzip(tparams)
                optparams = unzip(toptparams)
                numpy.savez(file_name,
                            history_errs=history_errs,
                            uidx=uidx,
                            eidx=eidx,
                            cidx=cidx,
                            **params)
                numpy.savez(opt_file_name, **optparams)

                if save_every_saveFreq and (uidx >= save_burn_in):
                    this_file_name = '%s%s.%d.npz' % (save_path,
                                                      save_file_name, uidx)
                    this_opt_file_name = '%s%s%s.%d.npz' % (
                        save_path, save_file_name, '.grads', uidx)
                    numpy.savez(this_file_name,
                                history_errs=history_errs,
                                uidx=uidx,
                                eidx=eidx,
                                cidx=cidx,
                                **params)
                    numpy.savez(this_opt_file_name,
                                history_errs=history_errs,
                                uidx=uidx,
                                eidx=eidx,
                                cidx=cidx,
                                **optparams)
                    if best_p is not None and saveFreq != validFreq:
                        this_best_file_name = '%s%s.%d.best.npz' % (
                            save_path, save_file_name, uidx)
                        numpy.savez(this_best_file_name,
                                    history_errs=history_errs,
                                    uidx=uidx,
                                    eidx=eidx,
                                    cidx=cidx,
                                    **best_p)
                print 'Done...',
                print 'Saved to %s' % file_name

            # finish after this many updates
            if uidx >= finish_after and finish_after != -1:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples
        eidx += 1

        if estop:
            break

    use_noise.set_value(0.)
    valid_err = pred_probs(f_log_probs, prepare_data, model_options,
                           valid).mean()

    print 'Valid ', valid_err

    params = unzip(tparams)
    optparams = unzip(toptparams)
    file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
    opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads',
                                       uidx)
    numpy.savez(file_name,
                history_errs=history_errs,
                uidx=uidx,
                eidx=eidx,
                cidx=cidx,
                **params)
    numpy.savez(opt_file_name, **optparams)
    if best_p is not None and saveFreq != validFreq:
        best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
        best_opt_file_name = '%s%s%s.%d.best.npz' % (save_path, save_file_name,
                                                     '.grads', uidx)
        numpy.savez(best_file_name,
                    history_errs=history_errs,
                    uidx=uidx,
                    eidx=eidx,
                    cidx=cidx,
                    **best_p)
        numpy.savez(best_opt_file_name, **best_optp)

    return valid_err
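A minimal driver sketch for a train() function of this shape (the module name nmt, the import of train itself, and every path and hyper-parameter below are assumptions for illustration, not taken from the example):

from nmt import train, init_params, build_model, build_sampler, gen_sample

if __name__ == '__main__':
    # Hypothetical call: the four callables are the required wiring; all other
    # values here are placeholders.
    valid_cost = train(dim_word=256,
                       enc_dim=512,
                       dec_dim=512,
                       optimizer='adadelta',
                       batch_size=32,
                       save_path='./models/',
                       save_file_name='baseline',
                       datasets=['train.src.tok', 'train.trg.tok'],
                       valid_datasets=['dev.src.tok', 'dev.trg.tok'],
                       dictionaries=['train.src.tok.pkl', 'train.trg.tok.pkl'],
                       init_params=init_params,
                       build_model=build_model,
                       build_sampler=build_sampler,
                       gen_sample=gen_sample)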
Example #2
def train(
      dim_word=100,
      dim_word_src=200,
      enc_dim=1000,
      dec_dim=1000,  # the number of LSTM units
      patience=-1,  # early stopping patience
      max_epochs=5000,
      finish_after=-1,  # finish after this many updates
      decay_c=0.,  # L2 regularization penalty
      alpha_c=0.,  # alignment regularization
      clip_c=-1.,  # gradient clipping threshold
      lrate=0.01,  # learning rate
      n_words_src=100000,  # source vocabulary size
      n_words=100000,  # target vocabulary size
      maxlen=100,  # maximum length of the source sequence
      maxlen_trg=None,  # maximum length of the target sequence (defaults to 10 * maxlen)
      maxlen_sample=1000,
      optimizer='rmsprop',
      batch_size=16,
      valid_batch_size=16,
      sort_size=20,
      save_path=None,
      save_file_name='model',
      save_best_models=0,
      dispFreq=100,
      validFreq=100,
      saveFreq=1000,   # save the parameters after every saveFreq updates
      sampleFreq=-1,
      verboseFreq=10000,
      datasets=[
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'],
      valid_datasets=['../data/dev/newstest2011.en.tok',
                      '../data/dev/newstest2011.fr.tok'],
      dictionaries=[
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'],
      source_word_level=0,
      target_word_level=0,
      use_dropout=False,
      re_load=False,
      re_load_old_setting=False,
      uidx=None,
      eidx=None,
      cidx=None,
      layers=None,
      save_every_saveFreq=0,
      save_burn_in=20000,
      use_bpe=0,
      init_params=None,
      build_model=None,
      build_sampler=None,
      gen_sample=None,
      c_lb=2.,
      st_estimator=None,
      learn_t=False,
      shuffle_dataset=False,
      only_use_w=False,
      nb_cumulate=1,
      repeat_actions=False,
      decoder_type="You have to set this.",
      layer_norm=False,
      planning_do_layerNorm=False,
      **kwargs
    ):



    if maxlen_trg is None:
        maxlen_trg = maxlen * 10
    # Model options
    model_options = locals().copy()
    del model_options['init_params']
    del model_options['build_model']
    del model_options['build_sampler']
    del model_options['gen_sample']

    # load dictionaries and invert them
    worddicts = [None] * len(dictionaries)
    worddicts_r = [None] * len(dictionaries)
    for ii, dd in enumerate(dictionaries):
        with open(dd, 'rb') as f:
            worddicts[ii] = cPickle.load(f)
        worddicts_r[ii] = dict()
        for kk, vv in worddicts[ii].iteritems():
            worddicts_r[ii][vv] = kk

    print 'Building model'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    file_name = '%s%s.npz' % (save_path, save_file_name)
    best_file_name = '%s%s.best.npz' % (save_path, save_file_name)
    opt_file_name = '%s%s%s.npz' % (save_path, save_file_name, '.grads')
    best_opt_file_name = '%s%s%s.best.npz' % (save_path, save_file_name, '.grads')
    model_name = '%s%s.pkl' % (save_path, save_file_name)
    params = init_params(model_options)
    cPickle.dump(model_options, open(model_name, 'wb'))
    history_errs = []
    debug_variables = []

    # reload options
    if re_load and os.path.exists(file_name):
        print 'You are reloading your experiment.. do not panic dude..'
        if re_load_old_setting:
            with open(model_name, 'rb') as f:
                model_options = cPickle.load(f)
        params = load_params(file_name, params)
        # reload history
        model = numpy.load(file_name)
        history_errs = list(model['history_errs'])
        if 'debug_variables' in model.keys():
            debug_variables = list(model['debug_variables'])

        if uidx is None:
            uidx = model['uidx']
        if eidx is None:
            eidx = model['eidx']
        if cidx is None:
            cidx = model['cidx']
    else:
        if uidx is None:
            uidx = 0
        if eidx is None:
            eidx = 0
        if cidx is None:
            cidx = 0

    print 'Loading data'

    if shuffle_dataset:
        print "We will shuffle the data after each epoch."
    else:
        print "We won't shuffle the data after each epoch."

    train = TextIterator(source=datasets[0],
                         target=datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=batch_size,
                         sort_size=sort_size,
                         shuffle_per_epoch=shuffle_dataset)

    valid = TextIterator(source=valid_datasets[0],
                         target=valid_datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=valid_batch_size,
                         sort_size=sort_size)


    #print "testing the shuffling dataset..."
    #sh_start = time.time()
    #train.reset()
    #print "took {} sec.".format(time.time() - sh_start)
    #sys.exit(0)


    # create shared variables for parameters

    tparams = init_tparams(params)

    trng, use_noise, \
        x, x_mask, y, y_mask, \
        opt_ret, \
        cost = \
        build_model(tparams, model_options)
    inps = [x, x_mask, y, y_mask]



    print 'Building sampler...\n',
    f_init, f_next = build_sampler(tparams, model_options, trng, use_noise)
    #print 'Done'


    # before any regularizer
    print 'Building f_log_probs...',

    # collect any extra decoder updates (e.g. from REINFORCE-style stochastic units)
    # so they can be passed to theano.function
    up = OrderedDict()
    if 'dec_updates' in opt_ret:
        up = opt_ret['dec_updates']

    f_log_probs = theano.function(inps, cost, profile=profile, updates=up)
    print 'Done'
    if re_load:
        use_noise.set_value(0.)
        valid_errs = pred_probs(f_log_probs, prepare_data,
                                model_options, valid, verboseFreq=verboseFreq)
        valid_err = valid_errs.mean()

        if numpy.isnan(valid_err):
            import ipdb
            ipdb.set_trace()

        print 'Reload sanity check: Valid ', valid_err

    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv ** 2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # regularize the alpha weights
    #TODO: Will have to check if this still applies
    if alpha_c > 0. and not model_options['decoder'].endswith('simple'):
        alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c')
        alpha_reg = alpha_c * (
            (tensor.cast(y_mask.sum(0) // x_mask.sum(0), 'float32')[:, None] -
             opt_ret['dec_alphas'].sum(0))**2).sum(1).mean()
        cost += alpha_reg

    # The commit penalty
    commit_penalty = None
    f_planning = None
    pre_cost = cost
    commits = None

    doing_planning = 'dec_samples' in opt_ret
    if doing_planning:

        probs = opt_ret['dec_probs']
        nb_ex = probs.shape[0]
        #probs = probs.flatten()
        #entropy = -1 * (probs * tensor.log(probs)).sum()/nb_ex
        #commit_penalty = -c_lb*entropy
        # sum over the plan dimension, then apply the target mask
        commit_penalty = ((c_lb*((1./model_options['kwargs']['plan_step'] - probs)**2).sum(axis=-1)*y_mask).sum(axis=0))
        commit_penalty = (commit_penalty/(y_mask.sum(axis=0) + 1e-4)).mean()
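        # squared penalty (weighted by c_lb) pulling each step's plan distribution toward
        # the uniform value 1/plan_step, masked and averaged over valid target positions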

        cost += commit_penalty

        commits = opt_ret['dec_commits']
        commits = ((commits[:, :, 0]*y_mask).sum(axis=0)/(y_mask.sum(axis=0) + 1e-4)).mean() # Sum over timestep, average over minibatch

    #planning function
        cost_output = [cost, opt_ret['dec_probs'], opt_ret['dec_commits'], opt_ret['dec_action_plans'], x, y, commit_penalty]
        f_planning = theano.function(inps, cost_output, profile=profile, updates=up)

    # after all regularizers - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile, updates=up)

    print 'Done'

    print 'Computing gradient...',

    #The gradient for the commitment plan (REINFORCE)
    known_grads = None
    new_updates = OrderedDict()
    if st_estimator == "REINFORCE":
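        # known_grads injects the REINFORCE estimator's gradient for the sampled
        # commitment decisions so tensor.grad can backpropagate through the stochastic
        # path (the helper's exact behaviour is assumed from its name)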
        known_grads, new_updates = stochastic_unit.REINFORCEMENT().bprop(opt_ret['dec_probs'],
                                                                     opt_ret['dec_samples'], cost, OrderedDict())
        up.update(new_updates)

    grads = tensor.grad(cost, wrt=itemlist(tparams), known_grads=known_grads)
    grads = [g.astype("float32") for g in grads]

    #Debug output
    debug_output = []

    """
    if st_estimator is not None:
        # Debug variables
        try:
            sub_grads_c = tensor.grad(commit_penalty, wrt=[tparams['decoder_planning_commit_ww']], known_grads=known_grads)
            sub_grads_p = tensor.grad(pre_cost, wrt=[tparams['decoder_planning_commit_ww']], known_grads=known_grads)
            sub_grads_all = tensor.grad(cost, wrt=[tparams['decoder_planning_commit_ww']], known_grads=known_grads)#

            sub_grads_c = tensor.mean(tensor.abs_(sub_grads_c[0]))
            sub_grads_p = tensor.mean(tensor.abs_(sub_grads_p[0]))
            sub_grads_all = tensor.mean(tensor.abs_(sub_grads_all[0]))
            #debug_output = [sub_grads_all, sub_grads_c, sub_grads_p, commits, opt_ret['dec_temperature']]
        except KeyError as e:
            print e
            print "Continuing anyway."
    """

    print 'Done'

    if clip_c > 0:
        grads, not_finite, clipped = gradient_clipping(grads, tparams, clip_c)
    else:
        not_finite = 0
        clipped = 0

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',

    if re_load and os.path.exists(file_name):
        if clip_c > 0:
            f_grad_shared, f_update_algo, f_update_param, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                                not_finite=not_finite,
                                                                                nb_cumulate=nb_cumulate,
                                                                                clipped=clipped,
                                                                                file_name=opt_file_name,
                                                                                other_updates=up, other_outputs=debug_output)
        else:
            f_grad_shared, f_update_algo, f_update_param, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  file_name=opt_file_name, nb_cumulate=nb_cumulate,
                                                                  other_updates=up, other_outputs=debug_output)
    else:
        if clip_c > 0:
            f_grad_shared, f_update_algo, f_update_param, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                                       not_finite=not_finite, clipped=clipped,
                                                                                       nb_cumulate=nb_cumulate,
                                                                                       other_updates=up, other_outputs=debug_output)
        else:
            f_grad_shared, f_update_algo, f_update_param, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                                       nb_cumulate=nb_cumulate,
                                                                                       other_updates=up, other_outputs=debug_output)

    print 'Done'

    print 'Optimization'
    best_p = None
    bad_counter = 0

    if validFreq == -1:
        validFreq = len(train[0]) / batch_size
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size

    # Training loop
    ud_start = time.time()
    estop = False

    if re_load:
        print "Checkpointed minibatch number: %d" % cidx
        for cc in xrange(cidx):
            if numpy.mod(cc, 1000) == 0:
                print "Jumping [%d / %d] examples" % (cc, cidx)
            train.next()

    for epoch in xrange(max_epochs):
        n_samples = 0
        NaN_grad_cnt = 0
        NaN_cost_cnt = 0
        clipped_cnt = 0

        if re_load:
            re_load = 0
        else:
            cidx = 0

        for x, y in train:

            cidx += 1
            uidx += 1
            use_noise.set_value(1.)

            x, x_mask, y, y_mask, n_x = prepare_data(x, y, maxlen=maxlen,
                                                     maxlen_trg=maxlen_trg,
                                                     n_words_src=n_words_src,
                                                     n_words=n_words)

            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                uidx = max(uidx, 0)
                continue

            n_samples += n_x

            #with open("debug_shape.txt", 'a') as ff:
            #    ff.write("And one:")
            #    ff.write(str(x.shape))
            #    ff.write(str(y.shape))
            #    ff.write(str(uidx))

            output = f_grad_shared(x, x_mask, y, y_mask)

            debug_output = []
            if clip_c > 0:
                cost = output[0]
                not_finite = output[1]
                clipped = output[2]
                debug_output = output[3:]
            else:
                cost = output[0]
                debug_output = output[1:]

            if clipped:
                clipped_cnt += 1

            # check for bad numbers, usually we remove non-finite elements
            # and continue training - but not done here
            if numpy.isnan(cost) or numpy.isinf(cost):
                NaN_cost_cnt += 1

            if not_finite:
                NaN_grad_cnt += 1
                continue

            # update the algorithm
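            # (gradients are accumulated across nb_cumulate minibatches; the parameters
            #  themselves are only updated below on every nb_cumulate-th step)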
            gnorm = f_update_algo(lrate)

            if (uidx % nb_cumulate) == nb_cumulate - 1:
                # do the update on parameters
                f_update_param()


            if numpy.isnan(cost) or numpy.isinf(cost):
                continue

            if float(NaN_grad_cnt) > max_epochs * 0.5 or float(NaN_cost_cnt) > max_epochs * 0.5:
                print 'Too many NaNs, abort training'
                return 1., 1., 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                ud = time.time() - ud_start
                print 'Epoch ', eidx, 'Update ', uidx, "Seen ", n_samples,'Cost ', cost, 'NaN_in_grad', NaN_grad_cnt,\
                      'NaN_in_cost', NaN_cost_cnt, 'Gradient_clipped', clipped_cnt, 'UD ', ud
                ud_start = time.time()

                print "Debug values:", debug_output
                print ""
                # For now we don't save all of them
                debug_variables.append(debug_output)

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0 and sampleFreq != -1:
                # FIXME: random selection?
                for jj in xrange(numpy.minimum(5, x.shape[1])):
                    stochastic = True
                    use_noise.set_value(0.)
                    res = gen_sample(tparams, f_init, f_next,
                                               x[:, jj][:, None],
                                               model_options, trng=trng, k=1,
                                               maxlen=maxlen_sample,
                                               stochastic=stochastic,
                                               argmax=False)
                    sample = res[0]
                    score = res[1]
                    print
                    print 'Source ', jj, ': ',
                    if source_word_level:
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                if use_bpe:
                                    print (worddicts_r[0][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[0][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        source_ = []
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                source_.append(worddicts_r[0][vv])
                            else:
                                source_.append('UNK')
                        print "".join(source_)
                    print 'Truth ', jj, ' : ',
                    if target_word_level:
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        truth_ = []
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                truth_.append(worddicts_r[1][vv])
                            else:
                                truth_.append('UNK')
                        print "".join(truth_)
                    print 'Sample ', jj, ': ',
                    if stochastic:
                        ss = sample
                    else:
                        score = score / numpy.array([len(s) for s in sample])
                        ss = sample[score.argmin()]
                    if target_word_level:
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        sample_ = []
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                sample_.append(worddicts_r[1][vv])
                            else:
                                sample_.append('UNK')
                        print "".join(sample_)
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)

                try:
                    valid_errs = pred_probs(f_log_probs, prepare_data,
                                        model_options, valid, verboseFreq=verboseFreq)
                    valid_err = valid_errs.mean()
                except MemoryError as e:
                    print "Memory error! ", e
                    valid_err = history_errs[-1]

                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    best_optp = unzip(toptparams)
                    bad_counter = 0

                if saveFreq != validFreq and save_best_models:
                    numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **best_p)
                    numpy.savez(best_opt_file_name, **best_optp)

                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min() and patience != -1:
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    import ipdb
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # save the best model so far
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if not os.path.exists(save_path):
                    os.mkdir(save_path)

                params = unzip(tparams)
                optparams = unzip(toptparams)
                numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                            cidx=cidx, debug_variables=debug_variables, **params)
                numpy.savez(opt_file_name, **optparams)

                if save_every_saveFreq and (uidx >= save_burn_in):
                    this_file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
                    this_opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
                    numpy.savez(this_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, debug_variables=debug_variables, **params)
                    numpy.savez(this_opt_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, debug_variables=debug_variables,  **toptparams)

                    if best_p is not None: #and saveFreq != validFreq:
                        this_best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
                        this_best_grad_file_name = '%s%s.%d.best.grads.npz' % (save_path, save_file_name, uidx)
                        numpy.savez(this_best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                    cidx=cidx, debug_variables=debug_variables, **best_p)
                        numpy.savez(this_best_grad_file_name, **best_optp)

                print 'Done...',
                print 'Saved to %s' % file_name

                if doing_planning:
                    print "Saving a batch planning example"
                    this_file_name = '%s%s.planning_%d.pkl' % (save_path, save_file_name, uidx)
                    planning_examples = pred_planning(f_planning, prepare_data,
                                model_options, valid, verboseFreq=verboseFreq)

                    print "Cost of commitment: {} ({} commits)".format(planning_examples[0][-1], planning_examples[0][-1]/c_lb)

                    #cPickle.dump(planning_examples, open(this_file_name, "wb"))
                    #import ipdb
                    #ipdb.set_trace()
                    #grads_d = debug_grads(f_grad_debug, prepare_data,
                    #                              model_options, valid, verboseFreq=verboseFreq)

            # finish after this many updates
            if uidx >= finish_after and finish_after != -1:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples
        eidx += 1

        if estop:
            break

    use_noise.set_value(0.)
    valid_err = pred_probs(f_log_probs, prepare_data,
                           model_options, valid).mean()

    print 'Valid ', valid_err

    params = unzip(tparams)
    optparams = unzip(toptparams)
    file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
    opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
    numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **params)
    numpy.savez(opt_file_name, **optparams)
    if best_p is not None and saveFreq != validFreq:
        best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
        best_opt_file_name = '%s%s%s.%d.best.npz' % (save_path, save_file_name, '.grads',uidx)
        numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **best_p)
        numpy.savez(best_opt_file_name, **best_optp)

    return valid_err
Example #3
File: main.py  Project: gschen/ESIM
def train(
        dim_word=100,  # word vector dimensionality
        dim=100,  # the number of GRU units
        encoder='lstm',  # encoder model
        decoder='lstm',  # decoder model
        patience=10,  # early stopping patience
        max_epochs=5000,
        finish_after=10000000,  # finish after this many updates
        decay_c=0.,  # L2 regularization penalty
        clip_c=-1.,  # gradient clipping threshold
        lrate=0.01,  # learning rate
        n_words=100000,  # vocabulary size
        maxlen=100,  # maximum sequence length
        optimizer='adadelta',
        batch_size=16,
        valid_batch_size=16,
        saveto='model.npz',
        LoadFrom='',
        dispFreq=100,
        validFreq=1000,
        saveFreq=1000,  # save the parameters after every saveFreq updates
        use_dropout=False,
        reload_=False,
        test=1,  # print verbose information for debug but slow speed
        datasets=[],
        valid_datasets=[],
        test_datasets=[],
        test_matched_datasets=[],
        test_mismatched_datasets=[],
        dictionary='',
        embedding='',  # pretrain embedding file, such as word2vec, GLOVE
):
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    #model_options
    model_options = locals().copy()
    model_options[
        'alphabet'] = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
    model_options['l_alphabet'] = len(model_options['alphabet'])
    model_options['dim_char_emb'] = 15
    model_options['char_nout'] = 100
    model_options['char_k_rows'] = 5
    model_options['char_k_cols'] = model_options['dim_char_emb']

    #load dictionary and invert them
    with open(dictionary, 'rb') as f:
        worddicts = pkl.load(f, encoding='iso-8859-1')
    worddicts_r = dict()
    for word in worddicts:
        worddicts_r[worddicts[word]] = word

    logger.debug(pprint.pformat(model_options))

    time.sleep(0.1)
    print('Loading data')

    #return (3,batch_size,-1)
    train = TextIterator(datasets[0],
                         datasets[1],
                         datasets[2],
                         dictionary,
                         n_words=n_words,
                         batch_size=batch_size)
    train_valid = TextIterator(datasets[0],
                               datasets[1],
                               datasets[2],
                               dictionary,
                               n_words=n_words,
                               batch_size=valid_batch_size,
                               shuffle=False)
    valid = TextIterator(valid_datasets[0],
                         valid_datasets[1],
                         valid_datasets[2],
                         dictionary,
                         n_words=n_words,
                         batch_size=valid_batch_size,
                         shuffle=False)
    test = TextIterator(test_datasets[0],
                        test_datasets[1],
                        test_datasets[2],
                        dictionary,
                        n_words=n_words,
                        batch_size=valid_batch_size,
                        shuffle=False)
    test_matched = TextIterator(test_matched_datasets[0],
                                test_matched_datasets[1],
                                test_matched_datasets[2],
                                dictionary,
                                n_words=n_words,
                                batch_size=valid_batch_size,
                                shuffle=False)
    test_mismatched = TextIterator(test_mismatched_datasets[0],
                                   test_mismatched_datasets[1],
                                   test_mismatched_datasets[2],
                                   dictionary,
                                   n_words=n_words,
                                   batch_size=valid_batch_size,
                                   shuffle=False)
    print('Building model')
    opt_ret, cost, pred, probs = build_model(model_options, worddicts)
    op = tf.train.AdamOptimizer(model_options['lrate'],
                                beta1=0.9,
                                beta2=0.999,
                                epsilon=1e-8).minimize(cost)
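    # NOTE: the feed_dict placeholders used below (use_noise, word_x1, word_x1_mask,
    # char_x1, ...) are assumed to be created at module level when build_model is defined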

    uidx = 0
    eidx = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if model_options['reload_']:
            saver = tf.train.Saver()
            saver.restore(sess, model_options['LoadFrom'])
            print('Reload done!')
        train_loss = 0
        while True:
            try:
                x1, x2, label = train.next()
            except:
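                # the iterator raising StopIteration marks the end of an epoch; note that
                # this bare except also silently swallows any other error from train.next()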
                eidx += 1
                print(eidx)
                continue
            _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                x1,
                x2,
                label,
                worddicts_r,
                model_options['alphabet'],
                maxlen=maxlen)
            ud_start = time.time()
            _cost, _pred, _prob, _ = sess.run(
                [cost, pred, probs, op],
                feed_dict={
                    use_noise: True,
                    word_x1: _x1,
                    word_x1_mask: _x1_mask,
                    char_x1: _char_x1,
                    word_x2: _x2,
                    word_x2_mask: _x2_mask,
                    char_x2: _char_x2,
                    char_x1_mask: _char_x1_mask,
                    char_x2_mask: _char_x2_mask,
                    y: _y
                })
            ud = time.time() - ud_start
            uidx += 1
            train_loss += _cost
            if uidx % model_options['dispFreq'] == 0:
                logger.debug('Epoch {0} Update {1} Cost {2} UD {3}'.format(
                    eidx,
                    uidx,
                    train_loss / model_options['dispFreq'],
                    ud,
                ))
                train_loss = 0
            if uidx % model_options['validFreq'] == 0:
                valid_cost = 0
                valid_pred = []
                valid_label = []
                n_valid_samples = 0
                test_cost = 0
                test_pred = []
                test_label = []
                n_test_samples = 0
                while True:
                    try:
                        x1, x2, label = valid.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        valid_cost += _cost * len(label)
                        valid_pred.extend(_pred)
                        valid_label.extend(_y)
                        n_valid_samples += len(label)
                        print('Seen %d samples' % n_valid_samples)
                    except:
                        break

                while True:
                    try:
                        x1, x2, label = test.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        test_cost += _cost * len(label)
                        test_pred.extend(_pred)
                        test_label.extend(_y)
                        n_test_samples += len(label)
                        print('Seen %d samples' % n_test_samples)
                    except:
                        print('Valid cost', valid_cost / len(valid_label))
                        print(
                            'Valid accuracy',
                            numpy.mean(
                                numpy.array(valid_pred) == numpy.array(
                                    valid_label)))
                        print('Test cost', test_cost / len(test_label))
                        print(
                            'Test accuracy',
                            numpy.mean(
                                numpy.array(test_pred) == numpy.array(
                                    test_label)))
                        break
            if uidx % model_options['test'] == 0:
                mismatched_result = []
                matched_result = []
                while True:
                    try:
                        x1, x2, label = test_mismatched.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        mismatched_result.extend(_pred)
                        print(len(mismatched_result))
                    except StopIteration:  # mismatched test iterator exhausted
                        break
                while True:
                    try:
                        x1, x2, label = test_matched.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        matched_result.extend(_pred)
                        print(len(matched_result))
                    except StopIteration:  # matched test iterator exhausted
                        break
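                # Map integer predictions to NLI label strings and write one
                # submission CSV per split; the running index stands in for the
                # pairID column expected by the submission format.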
                index = 0
                a = []
                b = []
                dic = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
                for i in mismatched_result:
                    a.append((index, dic[i]))
                    index += 1
                for i in matched_result:
                    b.append((index, dic[i]))
                    index += 1
                a = pd.DataFrame(a)
                a.columns = ['pairID', 'gold_label']
                a.to_csv('sub_mismatched_' + str(uidx) + '.csv', index=False)
                b = pd.DataFrame(b)
                b.columns = ['pairID', 'gold_label']
                b.to_csv('sub_matched_' + str(uidx) + '.csv', index=False)
                print('submission ' + str(uidx) + ' done!')
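            # Periodic TensorFlow checkpoint: every saveFreq updates the session
            # is saved under model_options['saveto'] tagged with the update count.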
            if uidx % model_options['saveFreq'] == 0:
                saver = tf.train.Saver()
                save_path = saver.save(
                    sess, model_options['saveto'] + '_' + str(uidx))
                print("Model saved in file: %s" % save_path)
Example #4
0
File: nmt.py  Project: BloodD/dl4mt-cdec
def train(
      dim_word=100,
      dim_word_src=200,
      enc_dim=1000,
      dec_dim=1000,  # the number of LSTM units
      patience=-1,  # early stopping patience
      max_epochs=5000,
      finish_after=-1,  # finish after this many updates
      decay_c=0.,  # L2 regularization penalty
      alpha_c=0.,  # alignment regularization
      clip_c=-1.,  # gradient clipping threshold
      lrate=0.01,  # learning rate
      n_words_src=100000,  # source vocabulary size
      n_words=100000,  # target vocabulary size
      maxlen=100,  # maximum length of the source sequence
      maxlen_trg=None,  # maximum length of the target sequence (defaults to 10 * maxlen)
      maxlen_sample=1000,
      optimizer='rmsprop',
      batch_size=16,
      valid_batch_size=16,
      sort_size=20,
      save_path=None,
      save_file_name='model',
      save_best_models=0,
      dispFreq=100,
      validFreq=100,
      saveFreq=1000,   # save the parameters after every saveFreq updates
      sampleFreq=-1,
      verboseFreq=10000,
      datasets=[
          'data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'],
      valid_datasets=['../data/dev/newstest2011.en.tok',
                      '../data/dev/newstest2011.fr.tok'],
      dictionaries=[
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'],
      source_word_level=0,
      target_word_level=0,
      use_dropout=False,
      re_load=False,
      re_load_old_setting=False,
      uidx=None,
      eidx=None,
      cidx=None,
      layers=None,
      save_every_saveFreq=0,
      save_burn_in=20000,
      use_bpe=0,
      init_params=None,
      build_model=None,
      build_sampler=None,
      gen_sample=None,
      **kwargs
    ):

    if maxlen_trg is None:
        maxlen_trg = maxlen * 10
    # Model options
    model_options = locals().copy()
    del model_options['init_params']
    del model_options['build_model']
    del model_options['build_sampler']
    del model_options['gen_sample']

    # load dictionaries and invert them
    worddicts = [None] * len(dictionaries)
    worddicts_r = [None] * len(dictionaries)
    for ii, dd in enumerate(dictionaries):
        with open(dd, 'rb') as f:
            worddicts[ii] = cPickle.load(f)
        worddicts_r[ii] = dict()
        for kk, vv in worddicts[ii].iteritems():
            worddicts_r[ii][vv] = kk

    print 'Building model'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    file_name = '%s%s.npz' % (save_path, save_file_name)
    best_file_name = '%s%s.best.npz' % (save_path, save_file_name)
    opt_file_name = '%s%s%s.npz' % (save_path, save_file_name, '.grads')
    best_opt_file_name = '%s%s%s.best.npz' % (save_path, save_file_name, '.grads')
    model_name = '%s%s.pkl' % (save_path, save_file_name)
    params = init_params(model_options)
    cPickle.dump(model_options, open(model_name, 'wb'))
    history_errs = []

    # reload options
    if re_load and os.path.exists(file_name):
        print 'You are reloading your experiment.. do not panic dude..'
        if re_load_old_setting:
            with open(model_name, 'rb') as f:
                model_options = cPickle.load(f)
        params = load_params(file_name, params)
        # reload history
        model = numpy.load(file_name)
        history_errs = list(model['history_errs'])
        if uidx is None:
            uidx = model['uidx']
        if eidx is None:
            eidx = model['eidx']
        if cidx is None:
            cidx = model['cidx']
    else:
        if uidx is None:
            uidx = 0
        if eidx is None:
            eidx = 0
        if cidx is None:
            cidx = 0

    print 'Loading data'
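    # TextIterator streams the tokenized parallel corpora in minibatches,
    # mapping tokens to ids with the dictionaries loaded above; sort_size
    # presumably sets the size of the buffer used for length-based sorting.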
    train = TextIterator(source=datasets[0],
                         target=datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=batch_size,
                         sort_size=sort_size)
    valid = TextIterator(source=valid_datasets[0],
                         target=valid_datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=valid_batch_size,
                         sort_size=sort_size)

    # create shared variables for parameters
    tparams = init_tparams(params)

    trng, use_noise, \
        x, x_mask, y, y_mask, \
        opt_ret, \
        cost = \
        build_model(tparams, model_options)
    inps = [x, x_mask, y, y_mask]

    print 'Building sampler...\n',
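    # build_sampler is expected to return the step functions (f_init, f_next)
    # used by gen_sample for the periodic decoding previews further below.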
    f_init, f_next = build_sampler(tparams, model_options, trng, use_noise)
    #print 'Done'

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, cost, profile=profile)
    print 'Done'
    if re_load:
        use_noise.set_value(0.)
        valid_errs = pred_probs(f_log_probs, prepare_data,
                                model_options, valid, verboseFreq=verboseFreq)
        valid_err = valid_errs.mean()

        if numpy.isnan(valid_err):
            import ipdb
            ipdb.set_trace()

        print 'Reload sanity check: Valid ', valid_err

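    # Average the per-sample cost over the minibatch before adding any
    # regularization terms.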
    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv ** 2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # regularize the alpha weights
    if alpha_c > 0. and not model_options['decoder'].endswith('simple'):
        alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c')
        alpha_reg = alpha_c * (
            (tensor.cast(y_mask.sum(0) // x_mask.sum(0), 'float32')[:, None] -
             opt_ret['dec_alphas'].sum(0))**2).sum(1).mean()
        cost += alpha_reg

    # after all regularizers - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile)
    print 'Done'

    print 'Computing gradient...',
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    print 'Done'

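    # Gradient clipping: gradient_clipping presumably rescales the gradients
    # when their norm exceeds clip_c and flags non-finite values via not_finite.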
    if clip_c > 0:
        grads, not_finite, clipped = gradient_clipping(grads, tparams, clip_c)
    else:
        not_finite = 0
        clipped = 0

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
    if re_load and os.path.exists(file_name):
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  not_finite=not_finite, clipped=clipped,
                                                                  file_name=opt_file_name)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  file_name=opt_file_name)
    else:
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  not_finite=not_finite, clipped=clipped)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost)
    print 'Done'

    print 'Optimization'
    best_p = None
    bad_counter = 0

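    # A frequency of -1 appears intended to mean roughly once per epoch; with
    # the defaults (validFreq=100, saveFreq=1000) these branches are not taken.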
    if validFreq == -1:
        validFreq = len(train[0]) / batch_size
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size

    # Training loop
    ud_start = time.time()
    estop = False

    if re_load:
        print "Checkpointed minibatch number: %d" % cidx
        for cc in xrange(cidx):
            if numpy.mod(cc, 1000) == 0:
                print "Jumping [%d / %d] examples" % (cc, cidx)
            train.next()

    for epoch in xrange(max_epochs):
        n_samples = 0
        NaN_grad_cnt = 0
        NaN_cost_cnt = 0
        clipped_cnt = 0
        if re_load:
            re_load = 0
        else:
            cidx = 0

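        # Main minibatch loop: pad and mask each batch, take one optimizer
        # step, and track NaN costs/gradients and clipped updates for the
        # periodic log line.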
        for x, y in train:
            cidx += 1
            uidx += 1
            use_noise.set_value(1.)

            x, x_mask, y, y_mask, n_x = prepare_data(x, y, maxlen=maxlen,
                                                     maxlen_trg=maxlen_trg,
                                                     n_words_src=n_words_src,
                                                     n_words=n_words)
            n_samples += n_x

            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                uidx = max(uidx, 0)
                continue

            # compute cost, grads and copy grads to shared variables
            if clip_c > 0:
                cost, not_finite, clipped = f_grad_shared(x, x_mask, y, y_mask)
            else:
                cost = f_grad_shared(x, x_mask, y, y_mask)

            if clipped:
                clipped_cnt += 1

            # check for bad numbers, usually we remove non-finite elements
            # and continue training - but not done here
            if numpy.isnan(cost) or numpy.isinf(cost):
                NaN_cost_cnt += 1

            if not_finite:
                NaN_grad_cnt += 1
                continue

            # do the update on parameters
            f_update(lrate)

            if numpy.isnan(cost) or numpy.isinf(cost):
                continue

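            # Abort if too many minibatches produced NaN gradients or costs
            # (the threshold here is tied to max_epochs).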
            if float(NaN_grad_cnt) > max_epochs * 0.5 or float(NaN_cost_cnt) > max_epochs * 0.5:
                print 'Too many NaNs, abort training'
                return 1., 1., 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                ud = time.time() - ud_start
                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'NaN_in_grad', NaN_grad_cnt,\
                      'NaN_in_cost', NaN_cost_cnt, 'Gradient_clipped', clipped_cnt, 'UD ', ud
                ud_start = time.time()

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0 and sampleFreq != -1:
                # FIXME: random selection?
                for jj in xrange(numpy.minimum(5, x.shape[1])):
                    stochastic = True
                    use_noise.set_value(0.)
                    sample, score = gen_sample(tparams, f_init, f_next,
                                               x[:, jj][:, None],
                                               model_options, trng=trng, k=1,
                                               maxlen=maxlen_sample,
                                               stochastic=stochastic,
                                               argmax=False)
                    print
                    print 'Source ', jj, ': ',
                    if source_word_level:
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                if use_bpe:
                                    print (worddicts_r[0][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[0][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        source_ = []
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                source_.append(worddicts_r[0][vv])
                            else:
                                source_.append('UNK')
                        print "".join(source_)
                    print 'Truth ', jj, ' : ',
                    if target_word_level:
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        truth_ = []
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                truth_.append(worddicts_r[1][vv])
                            else:
                                truth_.append('UNK')
                        print "".join(truth_)
                    print 'Sample ', jj, ': ',
                    if stochastic:
                        ss = sample
                    else:
                        score = score / numpy.array([len(s) for s in sample])
                        ss = sample[score.argmin()]
                    if target_word_level:
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        sample_ = []
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                sample_.append(worddicts_r[1][vv])
                            else:
                                sample_.append('UNK')
                        print "".join(sample_)
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                valid_errs = pred_probs(f_log_probs, prepare_data,
                                        model_options, valid, verboseFreq=verboseFreq)
                valid_err = valid_errs.mean()
                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    best_optp = unzip(toptparams)
                    bad_counter = 0

                if saveFreq != validFreq and save_best_models:
                    numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **best_p)
                    numpy.savez(best_opt_file_name, **best_optp)

                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min() and patience != -1:
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    import ipdb
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # periodically save the current parameters and optimizer state
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if not os.path.exists(save_path):
                    os.mkdir(save_path)

                params = unzip(tparams)
                optparams = unzip(toptparams)
                numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                            cidx=cidx, **params)
                numpy.savez(opt_file_name, **optparams)

                if save_every_saveFreq and (uidx >= save_burn_in):
                    this_file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
                    this_opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
                    numpy.savez(this_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **params)
                    numpy.savez(this_opt_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **params)
                    if best_p is not None and saveFreq != validFreq:
                        this_best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
                        numpy.savez(this_best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                    cidx=cidx, **best_p)
                print 'Done...',
                print 'Saved to %s' % file_name

            # finish after this many updates
            if uidx >= finish_after and finish_after != -1:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples
        eidx += 1

        if estop:
            break

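    # Final validation pass after training, then save the last parameters and
    # optimizer state under filenames tagged with the final update count.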
    use_noise.set_value(0.)
    valid_err = pred_probs(f_log_probs, prepare_data,
                           model_options, valid).mean()

    print 'Valid ', valid_err

    params = unzip(tparams)
    optparams = unzip(toptparams)
    file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
    opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
    numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **params)
    numpy.savez(opt_file_name, **optparams)
    if best_p is not None and saveFreq != validFreq:
        best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
        best_opt_file_name = '%s%s%s.%d.best.npz' % (save_path, save_file_name, '.grads',uidx)
        numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **best_p)
        numpy.savez(best_opt_file_name, **best_optp)

    return valid_err