Example #1
def validate(funcs, options, iterator, verbose=False):
    probs = []

    n_done = 0
    for k, (sx1, sy1, sx2, sy2) in enumerate(iterator):
        x1, x1_mask = prepare_data(sx1, 500, options['voc_sizes'][0])
        y1, y1_mask = prepare_data(sy1, 500, options['voc_sizes'][1])
        x2, x2_mask = prepare_data(sx2, 500, options['voc_sizes'][2])
        y2, y2_mask = prepare_data(sy2, 500, options['voc_sizes'][3])
        ty12, ty12_mask = prepare_cross(sy1, sy2, y1.shape[0])

        inps = [x1, x1_mask, y1, y1_mask,
                x2, x2_mask, y2, y2_mask,
                ty12, ty12_mask]
        if options['use_coverage']:
            if not options.get('nn_coverage', False):
                inps += [numpy.zeros((y2.shape[1], y2.shape[0]), dtype='float32')]
            else:
                inps += [numpy.zeros((y2.shape[1], y2.shape[0], options['cov_dim']), dtype='float32')]

        pprobs = funcs['valid'](*inps)
        for pp in pprobs:
            probs.append(pp)
        n_done += len(sx1)

        if verbose:
            print >>sys.stderr, '%d samples computed' % n_done

    return numpy.array(probs)
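Both `validate` variants rely on a `prepare_data` helper that is not shown on this page. Purely as a point of reference, here is a minimal sketch (named `prepare_data_sketch` to make clear it is an assumption, not the actual implementation) of what such a helper typically does: pad a list of token-id sequences into a time-major matrix plus a float32 mask, dropping sequences longer than `maxlen`.

import numpy

def prepare_data_sketch(seqs, maxlen, voc_size):
    # drop sequences longer than maxlen, as the calls above suggest
    seqs = [s for s in seqs if len(s) < maxlen]
    if len(seqs) == 0:
        return None, None
    lengths = [len(s) for s in seqs]
    n_samples = len(seqs)
    max_t = max(lengths) + 1  # one extra step for an end-of-sequence marker
    x = numpy.zeros((max_t, n_samples), dtype='int64')
    x_mask = numpy.zeros((max_t, n_samples), dtype='float32')
    for i, s in enumerate(seqs):
        x[:lengths[i], i] = [min(w, voc_size - 1) for w in s]  # clip out-of-range ids
        x_mask[:lengths[i] + 1, i] = 1.
    return x, x_mask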
Example #2
def validate(funcs, options, iterator, verbose=False):
    probs = []

    n_done = 0
    for k, inputs in enumerate(iterator):
        xs, xs_mask = zip(*[prepare_data(inputs[k], 500, options['voc_sizes'][k])
                            for k in range(0, options['n_inputs'], 2)])
        ys, ys_mask = zip(*[prepare_data(inputs[k], 500, options['voc_sizes'][k])
                            for k in range(1, options['n_inputs'], 2)])

        tys, tys_mask = zip(*[prepare_cross(inputs[1], inputs[k], ys[0].shape[0])
                              for k in range(3, options['n_inputs'], 2)])
        tys = list(tys)
        lens = 0
        for k in range(len(tys)):
            tys[k] += lens
            lens += ys[k + 1].shape[0]

        inps = []
        for k in range(len(xs)):
            inps += [xs[k], xs_mask[k], ys[k], ys_mask[k]]
        for k in range(len(tys)):
            inps += [tys[k], tys_mask[k]]

        if options['use_coverage']:
            if not options.get('nn_coverage', False):
                lens = 0
                for k in range(1, len(ys)):
                    lens += ys[k].shape[0]
                inps += [numpy.zeros((ys[1].shape[1], lens), dtype='float32')]  # initial coverage
            else:
                raise NotImplementedError

        pprobs = funcs['valid'](*inps)
        for pp in pprobs:
            probs.append(pp)
        n_done += len(inputs[0])

        if verbose:
            print >>sys.stderr, '%d samples computed' % n_done

    return numpy.array(probs)
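The offset loop in this example (and in the fragments below) shifts each block of cross-alignment indices returned by `prepare_cross` so that they address the correct rows once the target sequences are concatenated along the time axis. A tiny self-contained illustration with hypothetical index arrays and target lengths:

import numpy

tys = [numpy.array([0, 2, 3]), numpy.array([1, 1, 5])]  # hypothetical cross indices
target_lens = [4, 6]  # hypothetical lengths of the corresponding target sequences

lens = 0
for k in range(len(tys)):
    tys[k] = tys[k] + lens   # shift this block past the blocks before it
    lens += target_lens[k]

print(tys)  # [array([0, 2, 3]), array([5, 5, 9])]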
Example #3
            valid_errs = validate(funcs, model_options, valid, False)
            valid_err  = float(valid_errs.mean())
            history_errs.append(valid_err)

            if numpy.isnan(valid_err):
                print 'NaN detected'
                sys.exit(-1)

            print 'Valid ', valid_err
            if monitor:
                try:
                    monitor.push({'valid': float(str(valid_err))}, step=int(uidx))
                except Exception, e:
                    print e

        xs, xs_mask = zip(*[prepare_data(inputs[k], 500, model_options['voc_sizes'][k])
                            for k in range(0, model_options['n_inputs'], 2)])
        ys, ys_mask = zip(*[prepare_data(inputs[k], 500, model_options['voc_sizes'][k])
                            for k in range(1, model_options['n_inputs'], 2)])

        tys, tys_mask = zip(*[prepare_cross(inputs[1], inputs[k], ys[0].shape[0])
                            for k in range(3, model_options['n_inputs'], 2)])
        tys = list(tys)

        # add an offset so each tys[k] indexes into the concatenated target sequences
        lens = 0
        for k in range(len(tys)):
            tys[k] += lens
            lens += ys[k+1].shape[0]

        inps = []
Example #4
            valid_err  = float(valid_errs.mean())
            history_errs.append(valid_err)

            if numpy.isnan(valid_err):
                print 'NaN detected'
                sys.exit(-1)

            print 'Valid ', valid_err
            if monitor:
                try:
                    monitor.push({'valid': float(str(valid_err))}, step=int(uidx))
                except Exception, e:
                    print e

        # training
        x1, x1_mask = prepare_data(sx1, model_options['maxlen'], model_options['voc_sizes'][0])
        y1, y1_mask = prepare_data(sy1, model_options['maxlen'], model_options['voc_sizes'][1])
        x2, x2_mask = prepare_data(sx2, model_options['maxlen'], model_options['voc_sizes'][2])
        y2, y2_mask = prepare_data(sy2, model_options['maxlen'], model_options['voc_sizes'][3])
        ty12, ty12_mask = prepare_cross(sy1, sy2, y1.shape[0])

        v = model_options.get('drop', 1)
        if v < 1:
            drops = (model_options['rng'].rand(ty12_mask.shape[1]) < v)[None, :].astype('float32')
            ty12_mask *= drops
            print 'keep {} of {} retrieved pairs'.format(int(drops.sum()), drops.shape[1]),

        inps = [x1, x1_mask, y1, y1_mask,
                x2, x2_mask, y2, y2_mask,
                ty12, ty12_mask]
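The `drop` option in this fragment zeroes whole batch columns of the cross mask at random, so a subset of the retrieved pairs is ignored for that update. A small self-contained illustration with hypothetical shapes and keep rate:

import numpy

rng = numpy.random.RandomState(0)
keep_prob = 0.7                                    # hypothetical value of model_options['drop']
ty12_mask = numpy.ones((5, 8), dtype='float32')    # hypothetical (time, batch) mask

# one Bernoulli(keep_prob) draw per batch column, broadcast over all time steps
drops = (rng.rand(ty12_mask.shape[1]) < keep_prob)[None, :].astype('float32')
ty12_mask *= drops
print('keep %d of %d retrieved pairs' % (int(drops.sum()), ty12_mask.shape[1]))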
Example #5
def train(dim_word=100,  # word vector dimensionality
          dim=1000,  # the number of GRU units
          encoder='gru',
          patience=10,  # early stopping patience
          max_epochs=5000,
          finish_after=10000000,  # finish after this many updates
          dispFreq=100,
          decay_c=0.,  # L2 weight decay penalty
          lrate=0.01,
          n_words=100000,  # vocabulary size
          vocab_dim=100000,  # Size of M, C
          memory_dim=1000,  # Dimension of memory
          memory_size=15,  # n_back to attend
          maxlen=100,  # maximum length of the description
          optimizer='rmsprop',
          batch_size=16,
          valid_batch_size=16,
          saveto='model.npz',
          validFreq=1000,
          saveFreq=1000,  # save the parameters after every saveFreq updates
          sampleFreq=100,  # generate some samples after every sampleFreq
          dataset='/data/lisatmp3/chokyun/wikipedia/extracted/wiki.tok.txt.gz',
          valid_dataset='../data/dev/newstest2011.en.tok',
          dictionary='/data/lisatmp3/chokyun/wikipedia/extracted/'
          'wiki.tok.txt.gz.pkl',
          use_dropout=False,
          reload_=False):

    # Model options
    model_options = locals().copy()

    # Theano random stream
    trng = RandomStreams(1234)

    # load dictionary
    with open(dictionary, 'rb') as f:
        worddicts = pkl.load(f)

    # invert dictionary
    worddicts_r = dict()
    for kk, vv in worddicts.iteritems():
        worddicts_r[vv] = kk

    # reload options
    if reload_ and os.path.exists(saveto):
        with open('%s.pkl' % saveto, 'rb') as f:
            model_options = pkl.load(f)

    print 'Loading data'
    train = TextIterator(dataset,
                         dictionary,
                         n_words_source=n_words,
                         batch_size=batch_size,
                         maxlen=maxlen)
    valid = TextIterator(valid_dataset,
                         dictionary,
                         n_words_source=n_words,
                         batch_size=valid_batch_size,
                         maxlen=maxlen)

    # initialize RMN
    rmn_ = RMN(model_options)

    print 'Building model'
    rmn_.init_params()

    # reload parameters
    if reload_ and os.path.exists(saveto):
        rmn_.load_params(saveto)

    # create shared variables for parameters
    tparams = rmn_.tparams

    # build the symbolic computational graph
    use_noise, x, x_mask, opt_ret, cost = rmn_.build_model()
    inps = [x, x_mask]

    print 'Building sampler'
    f_next = rmn_.build_sampler(trng)

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, cost, profile=profile)
    print 'Done'

    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv ** 2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # after any regularizer - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile)
    print 'Done'

    print 'Computing gradient...',
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    print 'Done'

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = getattr(importlib.import_module('optimizer'), optimizer)
    f_grad_shared, f_update = optimizer(lr, tparams, grads, inps, cost)
    print 'Done'

    print 'Optimization'

    history_errs = []
    uidx = 0
    estop = False
    bad_counter = 0

    # reload history
    if reload_ and os.path.exists(saveto):
        history_errs = list(numpy.load(saveto)['history_errs'])
        uidx = numpy.load(saveto)['uidx']
    best_p = None

    if validFreq == -1:
        validFreq = len(train[0])/batch_size
    if saveFreq == -1:
        saveFreq = len(train[0])/batch_size
    if sampleFreq == -1:
        sampleFreq = len(train[0])/batch_size

    # Training loop
    for eidx in xrange(max_epochs):
        n_samples = 0

        for x in train:
            n_samples += len(x)
            uidx += 1
            use_noise.set_value(1.)

            # pad batch and create mask
            x, x_mask = prepare_data(x, maxlen=maxlen, n_words=n_words)

            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                continue

            ud_start = time.time()

            # compute cost, grads and copy grads to shared variables
            cost = f_grad_shared(x, x_mask)

            # do the update on parameters
            f_update(lrate)

            ud = time.time() - ud_start

            # check for bad numbers
            if numpy.isnan(cost) or numpy.isinf(cost):
                print 'NaN or Inf detected'
                return 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'UD ', ud

            # save the best model so far
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if best_p is not None:
                    params = best_p
                else:
                    params = unzip(tparams)
                numpy.savez(saveto, history_errs=history_errs, uidx=uidx,
                            **params)
                pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'))
                print 'Done'

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0:
                # FIXME: random selection?
                for jj in xrange(5):
                    sample, score = rmn_.gen_sample(tparams, f_next,
                                                    trng=trng, maxlen=30,
                                                    argmax=False)
                    print 'Sample ', jj, ': ',
                    ss = sample
                    for vv in ss:
                        if vv == 0:
                            break
                        if vv in worddicts_r:
                            print worddicts_r[vv],
                        else:
                            print 'UNK',
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                valid_errs = rmn_.pred_probs(valid, f_log_probs, prepare_data)
                valid_err = valid_errs.mean()
                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    bad_counter = 0
                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min():
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # finish after this many updates
            if uidx >= finish_after:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples

        if estop:
            break

    if best_p is not None:
        zipp(best_p, tparams)

    use_noise.set_value(0.)
    valid_err = rmn_.pred_probs(valid, f_log_probs, prepare_data).mean()

    print 'Valid ', valid_err

    params = copy.copy(best_p)
    numpy.savez(saveto, zipped_params=best_p,
                history_errs=history_errs,
                uidx=uidx,
                **params)

    return valid_err
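The validation block above implements patience-based early stopping: the best parameters are kept, and training stops once the validation error has failed to improve on the best value recorded at least `patience` checks earlier for more than `patience` consecutive checks. A minimal standalone sketch of the same rule with hypothetical error values:

import numpy

patience = 3
history_errs = []
bad_counter = 0

# hypothetical validation errors, one per validation step
for valid_err in [0.9, 0.8, 0.82, 0.81, 0.83, 0.84, 0.85, 0.86, 0.87]:
    history_errs.append(valid_err)
    if valid_err <= numpy.array(history_errs).min():
        bad_counter = 0  # new best so far: reset the counter
    if len(history_errs) > patience and \
            valid_err >= numpy.array(history_errs)[:-patience].min():
        bad_counter += 1
        if bad_counter > patience:
            print('Early Stop!')
            break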
Example #6
def train(
        dim_word=100,  # word vector dimensionality
        dim=100,  # the number of GRU units
        encoder='lstm',  # encoder model
        decoder='lstm',  # decoder model
        patience=10,  # early stopping patience
        max_epochs=5000,
        finish_after=10000000,  # finish after this many updates
        decay_c=0.,  # L2 regularization penalty
        clip_c=-1.,  # gradient clipping threshold
        lrate=0.01,  # learning rate
        n_words=100000,  # vocabulary size
        maxlen=100,  # maximum length of the description
        optimizer='adadelta',
        batch_size=16,
        valid_batch_size=16,
        saveto='model.npz',
        LoadFrom='',
        dispFreq=100,
        validFreq=1000,
        saveFreq=1000,  # save the parameters after every saveFreq updates
        use_dropout=False,
        reload_=False,
        test=1,  # print verbose debug information (slows training)
        datasets=[],
        valid_datasets=[],
        test_datasets=[],
        test_matched_datasets=[],
        test_mismatched_datasets=[],
        dictionary='',
        embedding='',  # pretrained embedding file, such as word2vec or GloVe
):
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    # model options
    model_options = locals().copy()
    model_options['alphabet'] = (
        " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
        ",;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}")
    model_options['l_alphabet'] = len(model_options['alphabet'])
    model_options['dim_char_emb'] = 15
    model_options['char_nout'] = 100
    model_options['char_k_rows'] = 5
    model_options['char_k_cols'] = model_options['dim_char_emb']

    # load the dictionary and invert it
    with open(dictionary, 'rb') as f:
        worddicts = pkl.load(f, encoding='iso-8859-1')
    worddicts_r = dict()
    for word in worddicts:
        worddicts_r[worddicts[word]] = word

    logger.debug(pprint.pformat(model_options))

    time.sleep(0.1)
    print('Loading data')

    # each batch yields three parallel lists: (x1, x2, label)
    train = TextIterator(datasets[0],
                         datasets[1],
                         datasets[2],
                         dictionary,
                         n_words=n_words,
                         batch_size=batch_size)
    train_valid = TextIterator(datasets[0],
                               datasets[1],
                               datasets[2],
                               dictionary,
                               n_words=n_words,
                               batch_size=valid_batch_size,
                               shuffle=False)
    valid = TextIterator(valid_datasets[0],
                         valid_datasets[1],
                         valid_datasets[2],
                         dictionary,
                         n_words=n_words,
                         batch_size=valid_batch_size,
                         shuffle=False)
    test = TextIterator(test_datasets[0],
                        test_datasets[1],
                        test_datasets[2],
                        dictionary,
                        n_words=n_words,
                        batch_size=valid_batch_size,
                        shuffle=False)
    test_matched = TextIterator(test_matched_datasets[0],
                                test_matched_datasets[1],
                                test_matched_datasets[2],
                                dictionary,
                                n_words=n_words,
                                batch_size=valid_batch_size,
                                shuffle=False)
    test_mismatched = TextIterator(test_mismatched_datasets[0],
                                   test_mismatched_datasets[1],
                                   test_mismatched_datasets[2],
                                   dictionary,
                                   n_words=n_words,
                                   batch_size=valid_batch_size,
                                   shuffle=False)
    print('Building model')
    opt_ret, cost, pred, probs = build_model(model_options, worddicts)
    op = tf.train.AdamOptimizer(model_options['lrate'],
                                beta1=0.9,
                                beta2=0.999,
                                epsilon=1e-8).minimize(cost)

    uidx = 0
    eidx = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if model_options['reload_']:
            saver = tf.train.Saver()
            saver.restore(sess, model_options['LoadFrom'])
            print('Reload done!')
        train_loss = 0
        while True:
            try:
                x1, x2, label = train.next()
            except:
                # iterator exhausted: count an epoch and keep going
                eidx += 1
                print(eidx)
                continue
            _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                x1,
                x2,
                label,
                worddicts_r,
                model_options['alphabet'],
                maxlen=maxlen)
            ud_start = time.time()
            _cost, _pred, _prob, _ = sess.run(
                [cost, pred, probs, op],
                feed_dict={
                    use_noise: True,
                    word_x1: _x1,
                    word_x1_mask: _x1_mask,
                    char_x1: _char_x1,
                    word_x2: _x2,
                    word_x2_mask: _x2_mask,
                    char_x2: _char_x2,
                    char_x1_mask: _char_x1_mask,
                    char_x2_mask: _char_x2_mask,
                    y: _y
                })
            ud = time.time() - ud_start
            uidx += 1
            train_loss += _cost
            if uidx % model_options['dispFreq'] == 0:
                logger.debug('Epoch {0} Update {1} Cost {2} UD {3}'.format(
                    eidx,
                    uidx,
                    train_loss / model_options['dispFreq'],
                    ud,
                ))
                train_loss = 0
            if uidx % model_options['validFreq'] == 0:
                valid_cost = 0
                valid_pred = []
                valid_label = []
                n_valid_samples = 0
                test_cost = 0
                test_pred = []
                test_label = []
                n_test_samples = 0
                while True:
                    try:
                        x1, x2, label = valid.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        valid_cost += _cost * len(label)
                        valid_pred.extend(_pred)
                        valid_label.extend(_y)
                        n_valid_samples += len(label)
                        print('Seen %d samples' % n_valid_samples)
                    except:
                        # validation set exhausted
                        break

                while True:
                    try:
                        x1, x2, label = test.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        test_cost += _cost * len(label)
                        test_pred.extend(_pred)
                        test_label.extend(_y)
                        n_test_samples += len(label)
                        print('Seen %d samples' % n_test_samples)
                    except:
                        # test set exhausted: report validation and test metrics
                        print('Valid cost', valid_cost / len(valid_label))
                        print(
                            'Valid accuracy',
                            numpy.mean(
                                numpy.array(valid_pred) == numpy.array(
                                    valid_label)))
                        print('Test cost', test_cost / len(test_label))
                        print(
                            'Test accuracy',
                            numpy.mean(
                                numpy.array(test_pred) == numpy.array(
                                    test_label)))
                        break
            if uidx % model_options['test'] == 0:
                mismatched_result = []
                matched_result = []
                while True:
                    try:
                        x1, x2, label = test_mismatched.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        mismatched_result.extend(_pred)
                        print(len(mismatched_result))
                    except:
                        # mismatched test set exhausted
                        break
                while True:
                    try:
                        x1, x2, label = test_matched.next()
                        _x1, _x1_mask, _char_x1, _char_x1_mask, _x2, _x2_mask, _char_x2, _char_x2_mask, lengths_x, lengths_y, _y = prepare_data(
                            x1,
                            x2,
                            label,
                            worddicts_r,
                            model_options['alphabet'],
                            maxlen=maxlen)
                        _cost, _pred, _prob = sess.run(
                            [cost, pred, probs],
                            feed_dict={
                                use_noise: False,
                                word_x1: _x1,
                                word_x1_mask: _x1_mask,
                                char_x1: _char_x1,
                                word_x2: _x2,
                                word_x2_mask: _x2_mask,
                                char_x2: _char_x2,
                                char_x1_mask: _char_x1_mask,
                                char_x2_mask: _char_x2_mask,
                                y: _y
                            })
                        matched_result.extend(_pred)
                        print(len(matched_result))
                    except:
                        # matched test set exhausted
                        break
                index = 0
                a = []
                b = []
                dic = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
                for i in mismatched_result:
                    a.append((index, dic[i]))
                    index += 1
                for i in matched_result:
                    b.append((index, dic[i]))
                    index += 1
                a = pd.DataFrame(a)
                a.columns = ['pairID', 'gold_label']
                a.to_csv('sub_mismatched_' + str(uidx) + '.csv', index=False)
                b = pd.DataFrame(b)
                b.columns = ['pairID', 'gold_label']
                b.to_csv('sub_matched_' + str(uidx) + '.csv', index=False)
                print('submission ' + str(uidx) + ' done!')
            if uidx % model_options['saveFreq'] == 0:
                saver = tf.train.Saver()
                save_path = saver.save(
                    sess, model_options['saveto'] + '_' + str(uidx))
                print("Model saved in file: %s" % save_path)
Example #7
def train(
        dim_word=100,  # word vector dimensionality
        dim=1000,  # the number of GRU units
        encoder='gru',
        patience=10,  # early stopping patience
        max_epochs=5000,
        finish_after=10000000,  # finish after this many updates
        dispFreq=100,
        decay_c=0.,  # L2 weight decay penalty
        lrate=0.01,
        n_words=100000,  # vocabulary size
        vocab_dim=100000,  # Size of M, C
        memory_dim=1000,  # Dimension of memory
        memory_size=15,  # n_back to attend
        maxlen=100,  # maximum length of the description
        optimizer='rmsprop',
        batch_size=16,
        valid_batch_size=16,
        saveto='model.npz',
        validFreq=1000,
        saveFreq=1000,  # save the parameters after every saveFreq updates
        sampleFreq=100,  # generate some samples after every sampleFreq
        dataset='/data/lisatmp3/chokyun/wikipedia/extracted/wiki.tok.txt.gz',
        valid_dataset='../data/dev/newstest2011.en.tok',
        dictionary='/data/lisatmp3/chokyun/wikipedia/extracted/'
    'wiki.tok.txt.gz.pkl',
        use_dropout=False,
        reload_=False):

    # Model options
    model_options = locals().copy()

    # Theano random stream
    trng = RandomStreams(1234)

    # load dictionary
    with open(dictionary, 'rb') as f:
        worddicts = pkl.load(f)

    # invert dictionary
    worddicts_r = dict()
    for kk, vv in worddicts.iteritems():
        worddicts_r[vv] = kk

    # reload options
    if reload_ and os.path.exists(saveto):
        with open('%s.pkl' % saveto, 'rb') as f:
            model_options = pkl.load(f)

    print 'Loading data'
    train = TextIterator(dataset,
                         dictionary,
                         n_words_source=n_words,
                         batch_size=batch_size,
                         maxlen=maxlen)
    valid = TextIterator(valid_dataset,
                         dictionary,
                         n_words_source=n_words,
                         batch_size=valid_batch_size,
                         maxlen=maxlen)

    # initialize RMN
    rmn_ = RMN(model_options)

    print 'Building model'
    rmn_.init_params()

    # reload parameters
    if reload_ and os.path.exists(saveto):
        rmn_.load_params(saveto)

    # create shared variables for parameters
    tparams = rmn_.tparams

    # build the symbolic computational graph
    use_noise, x, x_mask, opt_ret, cost = rmn_.build_model()
    inps = [x, x_mask]

    print 'Building sampler'
    f_next = rmn_.build_sampler(trng)

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, cost, profile=profile)
    print 'Done'

    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv**2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # after any regularizer - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile)
    print 'Done'

    print 'Computing gradient...',
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    print 'Done'

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = getattr(importlib.import_module('optimizer'), optimizer)
    f_grad_shared, f_update = optimizer(lr, tparams, grads, inps, cost)
    print 'Done'

    print 'Optimization'

    history_errs = []
    uidx = 0
    estop = False
    bad_counter = 0

    # reload history
    if reload_ and os.path.exists(saveto):
        history_errs = list(numpy.load(saveto)['history_errs'])
        uidx = numpy.load(saveto)['uidx']
    best_p = None

    if validFreq == -1:
        validFreq = len(train[0]) / batch_size
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size
    if sampleFreq == -1:
        sampleFreq = len(train[0]) / batch_size

    # Training loop
    for eidx in xrange(max_epochs):
        n_samples = 0

        for x in train:
            n_samples += len(x)
            uidx += 1
            use_noise.set_value(1.)

            # pad batch and create mask
            x, x_mask = prepare_data(x, maxlen=maxlen, n_words=n_words)

            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                continue

            ud_start = time.time()

            # compute cost, grads and copy grads to shared variables
            cost = f_grad_shared(x, x_mask)

            # do the update on parameters
            f_update(lrate)

            ud = time.time() - ud_start

            # check for bad numbers
            if numpy.isnan(cost) or numpy.isinf(cost):
                print 'NaN or Inf detected'
                return 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'UD ', ud

            # save the best model so far
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if best_p is not None:
                    params = best_p
                else:
                    params = unzip(tparams)
                numpy.savez(saveto,
                            history_errs=history_errs,
                            uidx=uidx,
                            **params)
                pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'))
                print 'Done'

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0:
                # FIXME: random selection?
                for jj in xrange(5):
                    sample, score = rmn_.gen_sample(tparams,
                                                    f_next,
                                                    trng=trng,
                                                    maxlen=30,
                                                    argmax=False)
                    print 'Sample ', jj, ': ',
                    ss = sample
                    for vv in ss:
                        if vv == 0:
                            break
                        if vv in worddicts_r:
                            print worddicts_r[vv],
                        else:
                            print 'UNK',
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                valid_errs = rmn_.pred_probs(valid, f_log_probs, prepare_data)
                valid_err = valid_errs.mean()
                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    bad_counter = 0
                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min():
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # finish after this many updates
            if uidx >= finish_after:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples

        if estop:
            break

    if best_p is not None:
        zipp(best_p, tparams)

    use_noise.set_value(0.)
    valid_err = rmn_.pred_probs(valid, f_log_probs, prepare_data).mean()

    print 'Valid ', valid_err

    params = copy.copy(best_p)
    numpy.savez(saveto,
                zipped_params=best_p,
                history_errs=history_errs,
                uidx=uidx,
                **params)

    return valid_err