Exemplo n.º 1
0
class Attention(object):
    """Attention-based video-captioning model built with Theano.

    Bundles parameter initialisation, symbolic graph construction,
    log-likelihood evaluation and the full training loop.
    """

    # Column indices into each ``history_errs`` row appended by ``train``:
    # [eidx, uidx, train_perp, train_err, valid_perp, valid_err,
    #  test_perp, test_err, valid B1..B4, METEOR, ROUGE_L, CIDEr,
    #  test B1..B4, METEOR, ROUGE_L, CIDEr]
    _COL_VALID_ERR = 5
    _COL_VALID_B4 = 11

    def __init__(self, channel=None):
        """Create the layer / prediction helpers.

        :param channel: optional object whose ``save()`` is called after
            every validation round (may be None).
        """
        self.rng_numpy, self.rng_theano = get_two_rngs()
        self.layers = Layers()
        self.predict = Predict()
        self.channel = channel

    @staticmethod
    def _improves_on_history(history_errs, column, value, higher_is_better):
        """Return True if ``value`` beats every earlier row of ``column``.

        ``history_errs`` must already contain the current row as its last
        element; only the rows before it are compared. When there is no
        earlier row, the current result counts as an improvement.
        """
        if len(history_errs) <= 1:
            return True
        previous = np.asarray(history_errs)[:-1, column]
        if higher_is_better:
            return bool(value > previous.max())
        return bool(value < previous.min())

    def load_params(self, path, params):
        """Overwrite every entry of ``params`` with the array stored in ``path``.

        :param path: ``.npz`` archive (as written by ``np.savez``).
        :param params: dict of parameter name -> value; updated in place
            and also returned.
        :raises Warning: if a key of ``params`` is missing from the archive.
        """
        # ``with`` closes the underlying zip file once the arrays are read.
        with np.load(path) as archive:
            # Iterate over a snapshot of the keys since values are replaced.
            for name in list(params.keys()):
                if name not in archive:
                    raise Warning('%s is not in the archive' % name)
                params[name] = archive[name]
        return params

    def init_params(self, options):
        """Create the initial (numpy) parameters of the model.

        :param options: model options; reads ``n_words``, ``dim_word``,
            ``ctx_dim``, ``mu_dim``, ``tu_dim`` and ``ctx2out``.
        :returns: OrderedDict mapping parameter name -> initial value.
        """
        # all parameters
        params = OrderedDict()
        # embedding
        params['Wemb'] = norm_weight(options['n_words'], options['dim_word'])

        ctx_dim = options['ctx_dim']

        # init_state, init_cell: used in build_model to map the mean
        # context to the initial state/memory of mu_lstm
        params = self.layers.get_layer('ff')[0](params, nin=ctx_dim, nout=options['mu_dim'],
                                                prefix='ff_state')
        params = self.layers.get_layer('ff')[0](params, nin=ctx_dim, nout=options['mu_dim'],
                                                prefix='ff_memory')

        # decoder: LSTM over words (tu_lstm), attention over the context,
        # then an LSTM fed with the attended context (mu_lstm)
        params = self.layers.get_layer('lstm')[0](params, nin=options['dim_word'],
                                                  dim=options['tu_dim'], prefix='tu_lstm')
        params = self.layers.get_layer('attend')[0](params, nin=options['tu_dim'],
                                                    dimctx=ctx_dim, prefix='attend')
        params = self.layers.get_layer('lstm_concat')[0](options, params, nin=options['tu_dim'],
                                                         dim=options['mu_dim'], dimctx=ctx_dim,
                                                         prefix='mu_lstm')

        # readout
        params = self.layers.get_layer('ff')[0](params, nin=options['mu_dim'], nout=options['dim_word'],
                                                prefix='ff_logit_lstm')
        if options['ctx2out']:
            params = self.layers.get_layer('ff')[0](params, nin=ctx_dim, nout=options['dim_word'],
                                                    prefix='ff_logit_ctx')

        params = self.layers.get_layer('ff')[0](params, nin=options['dim_word'], nout=options['n_words'],
                                                prefix='ff_logit')
        return params

    def build_model(self, tparams, options):
        """Build the symbolic training graph.

        :param tparams: dict of shared variables (output of ``init_tparams``).
        :param options: model options (``dim_word``, ``use_dropout``,
            ``prev2out``, ``ctx2out``, ...).
        :returns: ``(trng, use_noise, x, mask, ctx, mask_ctx, alphas, cost,
            extra, test)`` where ``cost`` is the masked per-sample negative
            log-likelihood and ``extra`` is ``[probs, alphas, betas]``.
        """
        trng = RandomStreams(1234)
        use_noise = theano.shared(np.float32(0.))
        # description string: #words x #samples
        x = tensor.matrix('x', dtype='int64')
        mask = tensor.matrix('mask', dtype='float32')
        # context: #samples x #annotations x dim
        ctx = tensor.tensor3('ctx', dtype='float32')
        mask_ctx = tensor.matrix('mask_ctx', dtype='float32')

        n_timesteps = x.shape[0]
        n_samples = x.shape[1]

        # index into the word embedding matrix, shift it forward in time
        # (step t sees the embedding of word t-1, step 0 sees zeros)
        emb = tparams['Wemb'][x.flatten()].reshape(
                [n_timesteps, n_samples, options['dim_word']])
        emb_shifted = tensor.zeros_like(emb)
        emb_shifted = tensor.set_subtensor(emb_shifted[1:], emb[:-1])
        emb = emb_shifted

        ctx_ = ctx
        # Mean over annotations: the sum covers *all* annotations but is
        # divided by the number of valid ones, so this assumes padded
        # annotations are zero -- TODO confirm against data_engine.
        counts = mask_ctx.sum(-1).dimshuffle(0, 'x')
        ctx_mean = ctx_.sum(1) / counts

        # initial state/cell
        init_state = self.layers.get_layer('ff')[1](tparams, ctx_mean,
                                                    activ='tanh', prefix='ff_state')
        init_memory = self.layers.get_layer('ff')[1](tparams, ctx_mean,
                                                     activ='tanh', prefix='ff_memory')

        # decoder
        tu_lstm = self.layers.get_layer('lstm')[1](tparams, emb, mask=mask, prefix='tu_lstm')
        attend = self.layers.get_layer('attend')[1](tparams, tu_lstm[0], ctx_)
        mu_lstm = self.layers.get_layer('lstm_concat')[1](options, tparams, tu_lstm[0],
                                                          mask=mask, ctxs=attend[1],
                                                          one_step=False,
                                                          init_state=init_state,
                                                          init_memory=init_memory,
                                                          trng=trng,
                                                          use_noise=use_noise,
                                                          prefix='mu_lstm')

        proj_h = mu_lstm[0]
        betas = mu_lstm[2]
        ctxs = mu_lstm[3]
        alphas = attend[0]
        if options['use_dropout']:
            proj_h = self.layers.dropout_layer(proj_h, use_noise, trng)

        # compute word probabilities
        logit = self.layers.get_layer('ff')[1](tparams, proj_h, activ='linear',
                                               prefix='ff_logit_lstm')
        if options['prev2out']:
            logit += emb
        if options['ctx2out']:
            logit += self.layers.get_layer('ff')[1](tparams, ctxs, activ='linear',
                                                    prefix='ff_logit_ctx')
        logit = tanh(logit)
        if options['use_dropout']:
            logit = self.layers.dropout_layer(logit, use_noise, trng)

        # (t,m,n_words)
        logit = self.layers.get_layer('ff')[1](tparams, logit,
                                               activ='linear', prefix='ff_logit')
        logit_shp = logit.shape
        # (t*m, n_words)
        probs = tensor.nnet.softmax(logit.reshape([logit_shp[0] * logit_shp[1],
                                                   logit_shp[2]]))

        # cost: per-token NLL of the target word, masked and summed over time
        x_flat = x.flatten()  # (t*m,)
        cost = -tensor.log(probs[tensor.arange(x_flat.shape[0]), x_flat] + 1e-8)
        cost = cost.reshape([x.shape[0], x.shape[1]])
        cost = (cost * mask).sum(0)

        extra = [probs, alphas, betas]
        test = [attend[1]]
        return trng, use_noise, x, mask, ctx, mask_ctx, alphas, cost, extra, test

    def pred_probs(self, whichset, f_log_probs, verbose=True):
        """Evaluate the model on one data split.

        :param whichset: 'train', 'valid' or 'test'.
        :param f_log_probs: compiled function returning the per-sample
            log-probability of each description.
        :param verbose: print progress to stdout.
        :returns: ``(mean negative log-prob per sample, per-word perplexity)``.
        :raises NotImplementedError: for an unknown ``whichset``.
        """
        if whichset == 'train':
            tags, iterator = self.engine.train, self.engine.kf_train
        elif whichset == 'valid':
            tags, iterator = self.engine.valid, self.engine.kf_valid
        elif whichset == 'test':
            tags, iterator = self.engine.test, self.engine.kf_test
        else:
            raise NotImplementedError()

        probs = []
        NLL = []
        L = []
        n_done = 0
        n_samples = np.sum([len(index) for index in iterator])
        for index in iterator:
            tag = [tags[i] for i in index]
            x, mask, ctx, ctx_mask, vid_names = data_engine.prepare_data(
                self.engine, tag)
            log_probs = f_log_probs(x, mask, ctx, ctx_mask)
            # number of unmasked words in each description
            L.append(mask.sum(0).tolist())
            NLL.append((-1 * log_probs).tolist())
            probs.append(log_probs.tolist())
            n_done += len(tag)
            if verbose:
                sys.stdout.write('\rComputing LL on %d/%d examples' % (
                    n_done, n_samples))
                sys.stdout.flush()
        print('')
        probs = flatten_list_of_list(probs)
        NLL = flatten_list_of_list(NLL)
        L = flatten_list_of_list(L)
        # 2 ** (nats / ln 2) == exp(total NLL / total words)
        perp = 2 ** (np.sum(NLL) / np.sum(L) / np.log(2))
        return -1 * np.mean(probs), perp

    def train(self,
              random_seed=1234,
              reload_=False,
              verbose=True,
              debug=True,
              save_model_dir='',
              from_dir=None,
              # dataset
              dataset='youtube2text',
              video_feature='googlenet',
              K=10,
              OutOf=240,
              # network
              dim_word=256, # word vector dimensionality
              ctx_dim=-1, # context vector dimensionality, auto set
              tu_dim=512,
              mu_dim=1024,
              vu_dim=1024,
              n_layers_out=1,
              n_layers_init=1,
              prev2out=False,
              ctx2out=False,
              selector=False,
              n_words=100000,
              maxlen=100, # maximum length of the description
              use_dropout=False,
              isGlobal=False,
              # training
              patience=10,
              max_epochs=5000,
              decay_c=0.,
              alpha_c=0.,
              alpha_entropy_r=0.,
              lrate=0.01,
              optimizer='adadelta',
              clip_c=2.,
              # minibatch
              batch_size = 64,
              valid_batch_size = 64,
              dispFreq=100,
              validFreq=10,
              saveFreq=10, # save the parameters after every saveFreq updates
              sampleFreq=10, # generate some samples after every sampleFreq updates
              # metric
              metric='blue'
              ):
        """Build the model, train it, and periodically validate and checkpoint.

        Every keyword argument is snapshotted into ``model_options`` (and
        pickled to ``<save_model_dir>model_options.pkl``) before any other
        local variable is created.

        :returns: ``(train_err, valid_err, test_err)`` from the last
            validation round (-1 when no validation ran).
        :raises IOError: if ``reload_`` is set and the checkpoint is missing.
        """
        self.rng_numpy, self.rng_theano = get_two_rngs()

        # Must stay the first local binding: it snapshots the arguments only.
        model_options = locals().copy()
        if 'self' in model_options:
            del model_options['self']
        model_options = validate_options(model_options)
        with open('%smodel_options.pkl' % save_model_dir, 'wb') as f:
            pkl.dump(model_options, f)

        print('Loading data')
        self.engine = data_engine.Movie2Caption('attention', dataset,
                                                video_feature,
                                                batch_size, valid_batch_size,
                                                maxlen, n_words,
                                                K, OutOf)
        model_options['ctx_dim'] = self.engine.ctx_dim

        print('init params')
        t0 = time.time()
        params = self.init_params(model_options)

        # reloading: the same checkpoint provides params and history
        model_saved = None
        if reload_:
            model_saved = os.path.join(from_dir, 'model_best_so_far.npz')
            if not os.path.isfile(model_saved):
                raise IOError('cannot reload, %s does not exist' % model_saved)
            print("Reloading model params...")
            params = self.load_params(model_saved, params)

        tparams = init_tparams(params)
        if verbose:
            print(list(tparams.keys()))

        trng, use_noise, x, mask, ctx, mask_ctx, alphas, cost, extra, test = \
            self.build_model(tparams, model_options)

        if debug:
            print('buliding test')
            test_fun = theano.function([x, mask, ctx, mask_ctx],
                                       test,
                                       name='f_test',
                                       on_unused_input='ignore')

        print('buliding sampler')
        f_init, f_next = self.predict.build_sampler(self.layers, tparams, model_options, use_noise, trng)

        # before any regularizer
        print('building f_log_probs')
        f_log_probs = theano.function([x, mask, ctx, mask_ctx], -cost,
                                      profile=False, on_unused_input='ignore')

        cost = cost.mean()
        if decay_c > 0.:
            decay_c = theano.shared(np.float32(decay_c), name='decay_c')
            weight_decay = 0.
            for vv in tparams.values():
                weight_decay += (vv ** 2).sum()
            weight_decay *= decay_c
            cost += weight_decay

        if alpha_c > 0.:
            alpha_c = theano.shared(np.float32(alpha_c), name='alpha_c')
            alpha_reg = alpha_c * ((1. - alphas.sum(0)) ** 2).sum(0).mean()
            cost += alpha_reg

        if alpha_entropy_r > 0:
            alpha_entropy_r = theano.shared(np.float32(alpha_entropy_r),
                                            name='alpha_entropy_r')
            alpha_reg_2 = alpha_entropy_r * (-tensor.sum(alphas *
                        tensor.log(alphas + 1e-8), axis=-1)).sum(0).mean()
            cost += alpha_reg_2
        else:
            alpha_reg_2 = tensor.zeros_like(cost)
        print('building f_alpha')
        f_alpha = theano.function([x, mask, ctx, mask_ctx],
                                  [alphas, alpha_reg_2],
                                  name='f_alpha',
                                  on_unused_input='ignore')

        print('compute grad')
        grads = tensor.grad(cost, wrt=itemlist(tparams))
        if clip_c > 0.:
            # rescale all gradients jointly when their global norm exceeds clip_c
            g2 = 0.
            for g in grads:
                g2 += (g ** 2).sum()
            grads = [tensor.switch(g2 > (clip_c ** 2),
                                   g / tensor.sqrt(g2) * clip_c,
                                   g)
                     for g in grads]

        lr = tensor.scalar(name='lr')
        print('build train fns')
        # NOTE(review): ``optimizer`` is evaluated as a Python expression; it
        # must only ever name a trusted optimizer function (e.g. 'adadelta').
        f_grad_shared, f_update = eval(optimizer)(lr, tparams, grads,
                                                  [x, mask, ctx, mask_ctx], cost,
                                                  extra + grads)

        print('compilation took %.4f sec' % (time.time() - t0))
        print('Optimization')

        history_errs = []
        # reload history from the same checkpoint the params came from
        if reload_:
            print('loading history error...')
            with np.load(model_saved) as archive:
                history_errs = archive['history_errs'].tolist()

        bad_counter = 0

        processes = None
        queue = None
        rqueue = None
        shared_params = None

        uidx = 0
        uidx_best_blue = 0
        uidx_best_valid_err = 0
        estop = False
        best_p = unzip(tparams)
        best_blue_valid = 0
        best_valid_err = 999
        alphas_ratio = []
        # Defaults so the final summary works even if no epoch / validation ran.
        eidx = 0
        train_err = valid_err = test_err = -1
        # Exponential moving average of the training cost (display only).
        train_error = None
        # Which log-likelihood passes run at validation time.
        compute_train_ll = False
        compute_valid_ll = True
        compute_test_ll = False
        for eidx in range(max_epochs):
            n_samples = 0
            train_costs = []
            grads_record = []
            print('Epoch  %s' % eidx)
            for idx in self.engine.kf_train:
                tags = [self.engine.train[index] for index in idx]
                n_samples += len(tags)
                uidx += 1
                use_noise.set_value(1.)

                pd_start = time.time()
                x, mask, ctx, ctx_mask, vid_names = data_engine.prepare_data(
                    self.engine, tags)
                pd_duration = time.time() - pd_start
                # Skip empty batches before anything consumes ``x``.
                if x is None:
                    print('Minibatch with zero sample under length  %s' % maxlen)
                    continue

                if debug:
                    datas = test_fun(x, mask, ctx, ctx_mask)
                    for item in datas:
                        print(item[0].shape)

                ud_start = time.time()
                # rvals = [cost] + extra (probs, alphas, betas) + grads
                rvals = f_grad_shared(x, mask, ctx, ctx_mask)
                cost = rvals[0]
                betas = rvals[3]
                grads = rvals[4:]
                grads, NaN_keys = grad_nan_report(grads, tparams)
                if len(grads_record) >= 5:
                    del grads_record[0]
                grads_record.append(grads)
                if NaN_keys != []:
                    print('grads contain NaN')
                    import pdb; pdb.set_trace()
                if np.isnan(cost) or np.isinf(cost):
                    print('NaN detected in cost')
                    import pdb; pdb.set_trace()
                # update params
                f_update(lrate)
                ud_duration = time.time() - ud_start

                if train_error is None:
                    train_error = cost
                else:
                    train_error = train_error * 0.95 + cost * 0.05
                train_costs.append(cost)

                if np.mod(uidx, dispFreq) == 0:
                    print('Epoch  %s , Update  %s , Train cost mean so far %s , '
                          'betas mean %s , fetching data time spent (sec) %s , '
                          'update time spent (sec) %s' % (
                              eidx, uidx, train_error, np.round(betas.mean(), 3),
                              np.round(pd_duration, 3), np.round(ud_duration, 3)))
                    alphas, reg = f_alpha(x, mask, ctx, ctx_mask)
                    print('alpha ratio %.3f, reg %.3f' % (
                        alphas.min(-1).mean() / (alphas.max(-1)).mean(), reg))

                # NOTE: saveFreq is currently unused; checkpoints are written
                # during validation instead.

                if np.mod(uidx, sampleFreq) == 0:
                    use_noise.set_value(0.)
                    print('------------- sampling from train ----------')
                    self.predict.sample_execute(self.engine, model_options, tparams,
                                                f_init, f_next, x, ctx, ctx_mask, trng, vid_names)

                    print('------------- sampling from valid ----------')
                    # pick any validation minibatch uniformly at random
                    valid_idx = self.engine.kf_valid[
                        np.random.randint(0, len(self.engine.kf_valid))]
                    valid_tags = [self.engine.valid[index] for index in valid_idx]
                    x_s, mask_s, ctx_s, mask_ctx_s, vid_names_s = \
                        data_engine.prepare_data(self.engine, valid_tags)
                    self.predict.sample_execute(self.engine, model_options, tparams,
                                                f_init, f_next, x_s, ctx_s, mask_ctx_s, trng, vid_names_s)
                    # end of sample

                if validFreq != -1 and np.mod(uidx, validFreq) == 0:
                    t0_valid = time.time()
                    alphas, _ = f_alpha(x, mask, ctx, ctx_mask)
                    ratio = alphas.min(-1).mean() / (alphas.max(-1)).mean()
                    alphas_ratio.append(ratio)
                    np.savetxt(save_model_dir + 'alpha_ratio.txt', alphas_ratio)

                    current_params = unzip(tparams)
                    np.savez(save_model_dir + 'model_current.npz',
                             history_errs=history_errs, **current_params)

                    use_noise.set_value(0.)
                    train_err = -1
                    train_perp = -1
                    valid_err = -1
                    valid_perp = -1
                    test_err = -1
                    test_perp = -1

                    if not debug:
                        if compute_train_ll:
                            print('computing cost on trainset')
                            train_err, train_perp = self.pred_probs(
                                'train', f_log_probs,
                                verbose=model_options['verbose'])
                        else:
                            train_err = 0.
                            train_perp = 0.
                        if compute_valid_ll:
                            print('validating...')
                            valid_err, valid_perp = self.pred_probs(
                                'valid', f_log_probs,
                                verbose=model_options['verbose'])
                        else:
                            valid_err = 0.
                            valid_perp = 0.
                        if compute_test_ll:
                            print('testing...')
                            test_err, test_perp = self.pred_probs(
                                'test', f_log_probs,
                                verbose=model_options['verbose'])
                        else:
                            test_err = 0.
                            test_perp = 0.

                    blue_t0 = time.time()
                    scores, processes, queue, rqueue, shared_params = \
                        metrics.compute_score(model_type='attention',
                                              model_archive=current_params,
                                              options=model_options,
                                              engine=self.engine,
                                              save_dir=save_model_dir,
                                              beam=5, n_process=5,
                                              whichset='both',
                                              on_cpu=False,
                                              processes=processes, queue=queue, rqueue=rqueue,
                                              shared_params=shared_params, metric=metric,
                                              one_time=False,
                                              f_init=f_init, f_next=f_next, model=self.predict
                                              )

                    # Order matches the history_errs column layout above.
                    score_keys = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
                                  'METEOR', 'ROUGE_L', 'CIDEr']
                    valid_row = [scores['valid'][key] for key in score_keys]
                    test_row = [scores['test'][key] for key in score_keys]
                    valid_B4, valid_meteor = valid_row[3], valid_row[4]
                    test_B4, test_meteor = test_row[3], test_row[4]
                    print('computing meteor/blue score used %.4f sec, '
                          'blue score: %.1f, meteor score: %.1f' % (
                              time.time() - blue_t0, valid_B4, valid_meteor))
                    history_errs.append([eidx, uidx, train_perp, train_err,
                                         valid_perp, valid_err,
                                         test_perp, test_err]
                                        + valid_row + test_row)
                    np.savetxt(save_model_dir + 'train_valid_test.txt',
                               history_errs, fmt='%.3f')
                    print('save validation results to %s' % save_model_dir)

                    # save the current model when it has the best valid Bleu_4
                    if self._improves_on_history(history_errs, self._COL_VALID_B4,
                                                 valid_B4, higher_is_better=True):
                        best_blue_valid = valid_B4
                        uidx_best_blue = uidx
                        sys.stdout.write('Saving to %s... ' % save_model_dir)
                        np.savez(
                            save_model_dir + 'model_best_blue_or_meteor.npz',
                            history_errs=history_errs, **current_params)

                    # best-by-valid-error model drives early stopping
                    if self._improves_on_history(history_errs, self._COL_VALID_ERR,
                                                 valid_err, higher_is_better=False):
                        best_p = unzip(tparams)
                        bad_counter = 0
                        best_valid_err = valid_err
                        uidx_best_valid_err = uidx

                        sys.stdout.write('Saving to %s... ' % save_model_dir)
                        np.savez(
                            save_model_dir + 'model_best_so_far.npz',
                            history_errs=history_errs, **best_p)
                        with open('%smodel_options.pkl' % save_model_dir, 'wb') as f:
                            pkl.dump(model_options, f)
                        print('Done')
                    else:
                        bad_counter += 1
                        print('history best  %s' % np.array(history_errs)[:, self._COL_VALID_ERR].min())
                        print('bad_counter  %s' % bad_counter)
                        print('patience  %s' % patience)
                        if bad_counter > patience:
                            print('Early Stop!')
                            estop = True
                            break

                    if test_B4 > 0.48 and test_meteor > 0.32:
                        sys.stdout.write('Saving to %s... ' % save_model_dir)
                        np.savez(
                            save_model_dir + 'model_' + str(uidx) + '.npz',
                            history_errs=history_errs, **current_params)

                    if self.channel:
                        self.channel.save()

                    print('Train  %s Valid  %s Test  %s best valid err so far %s' % (
                        train_err, valid_err, test_err, best_valid_err))
                    print('valid took %.2f sec' % (time.time() - t0_valid))
                    # end of validation
                if debug:
                    break
            if estop:
                break
            if debug:
                break

            # end for loop over minibatches
            print('This epoch has seen %d samples, train cost %.2f' % (
                n_samples, np.mean(train_costs) if train_costs else float('nan')))
        # end for loop over epochs
        print('Optimization ended.')
        if best_p is not None:
            zipp(best_p, tparams)

        print('stopped at epoch %d, minibatch %d, '
              'curent Train %.2f, current Valid %.2f, current Test %.2f ' % (
                  eidx, uidx, np.mean(train_err), np.mean(valid_err), np.mean(test_err)))
        params = copy.copy(best_p)
        np.savez(save_model_dir + 'model_best.npz',
                 train_err=train_err,
                 valid_err=valid_err, test_err=test_err, history_errs=history_errs,
                 **params)

        if history_errs:
            history = np.asarray(history_errs)
            best_valid_idx = history[:, self._COL_VALID_ERR].argmin()
            np.savetxt(save_model_dir + 'train_valid_test.txt', history, fmt='%.4f')
            print('final best exp  %s' % history[best_valid_idx])

        return train_err, valid_err, test_err
Exemplo n.º 2
0
class Model(object):
    def __init__(self):
        self.layers = Layers()

    def init_params(self, options):
        # all parameters
        params = OrderedDict()
        # embedding
        ctx_dim_c = 4096
        params['Wemb'] = utils.norm_weight(options['n_words'],
                                           options['dim_word'])

        ctx_dim = options['ctx_dim']

        params = self.layers.get_layer('ff')[0](options,
                                                params,
                                                prefix='ff_state',
                                                nin=ctx_dim,
                                                nout=options['dim'])
        params = self.layers.get_layer('ff')[0](options,
                                                params,
                                                prefix='ff_memory',
                                                nin=ctx_dim,
                                                nout=options['dim'])
        params = self.layers.get_layer('ff')[0](options,
                                                params,
                                                prefix='ff_state_c',
                                                nin=ctx_dim_c,
                                                nout=options['dim'])
        params = self.layers.get_layer('ff')[0](options,
                                                params,
                                                prefix='ff_memory_c',
                                                nin=ctx_dim_c,
                                                nout=options['dim'])

        # decoder: LSTM
        params = self.layers.get_layer('lstm_cond')[0](options,
                                                       params,
                                                       prefix='bo_lstm',
                                                       nin=options['dim_word'],
                                                       dim=options['dim'],
                                                       dimctx=ctx_dim)
        params = self.layers.get_layer('lstm')[0](params,
                                                  nin=options['dim'],
                                                  dim=options['dim'],
                                                  prefix='to_lstm')

        # readout
        params = self.layers.get_layer('ff')[0](options,
                                                params,
                                                prefix='ff_logit_bo',
                                                nin=options['dim'],
                                                nout=options['dim_word'])
        if options['ctx2out']:
            params = self.layers.get_layer('ff')[0](options,
                                                    params,
                                                    prefix='ff_logit_ctx',
                                                    nin=ctx_dim,
                                                    nout=options['dim_word'])
            params = self.layers.get_layer('ff')[0](options,
                                                    params,
                                                    prefix='ff_logit_ctx_c',
                                                    nin=ctx_dim_c,
                                                    nout=options['dim_word'])
            params = self.layers.get_layer('ff')[0](options,
                                                    params,
                                                    prefix='ff_logit_to',
                                                    nin=options['dim'],
                                                    nout=options['dim_word'])

        params = self.layers.get_layer('ff')[0](options,
                                                params,
                                                prefix='ff_logit',
                                                nin=options['dim_word'],
                                                nout=options['n_words'])
        return params

    def build_model(self, tparams, options):
        trng = RandomStreams(1234)
        use_noise = theano.shared(numpy.float32(0.))
        # description string: #words x #samples
        x = tensor.matrix('x', dtype='int64')
        mask = tensor.matrix('mask', dtype='float32')
        # context: #samples x #annotations x dim
        ctx = tensor.tensor3('ctx', dtype='float32')
        mask_ctx = tensor.matrix('mask_ctx', dtype='float32')
        ctx_c = tensor.tensor3('ctx_c', dtype='float32')
        mask_ctx_c = tensor.matrix('mask_ctx_c', dtype='float32')
        n_timesteps = x.shape[0]
        n_samples = x.shape[1]

        # index into the word embedding matrix, shift it forward in time
        emb = tparams['Wemb'][x.flatten()].reshape(
            [n_timesteps, n_samples, options['dim_word']])
        emb_shifted = tensor.zeros_like(emb)
        emb_shifted = tensor.set_subtensor(emb_shifted[1:], emb[:-1])
        emb = emb_shifted
        counts = mask_ctx.sum(-1).dimshuffle(0, 'x')

        ctx_ = ctx
        ctx_c_ = ctx_c

        ctx0 = ctx_
        ctx_mean = ctx0.sum(1) / counts

        ctx0_c = ctx_c_
        ctx_mean_c = ctx0_c.sum(1) / counts

        # initial state/cell
        init_state = self.layers.get_layer('ff')[1](tparams,
                                                    ctx_mean,
                                                    options,
                                                    prefix='ff_state',
                                                    activ='tanh')
        init_memory = self.layers.get_layer('ff')[1](tparams,
                                                     ctx_mean,
                                                     options,
                                                     prefix='ff_memory',
                                                     activ='tanh')
        init_state_c = self.layers.get_layer('ff')[1](tparams,
                                                      ctx_mean_c,
                                                      options,
                                                      prefix='ff_state_c',
                                                      activ='tanh')
        init_memory_c = self.layers.get_layer('ff')[1](tparams,
                                                       ctx_mean_c,
                                                       options,
                                                       prefix='ff_memory_c',
                                                       activ='tanh')

        init_state += init_state_c
        init_memory += init_memory_c

        # decoder
        bo_lstm = self.layers.get_layer('lstm_cond')[1](
            tparams,
            emb,
            options,
            prefix='bo_lstm',
            mask=mask,
            context=ctx0,
            context_c=ctx0_c,
            one_step=False,
            init_state=init_state,
            init_memory=init_memory,
            trng=trng,
            use_noise=use_noise)
        to_lstm = self.layers.get_layer('lstm')[1](tparams,
                                                   bo_lstm[0],
                                                   mask=mask,
                                                   one_step=False,
                                                   prefix='to_lstm')

        bo_lstm_h = bo_lstm[0]
        to_lstm_h = to_lstm[0]
        alphas = bo_lstm[2]
        alphas_c = bo_lstm[3]
        ctxs = bo_lstm[4]
        ctxs_c = bo_lstm[5]
        weight = bo_lstm[6]
        if options['use_dropout']:
            bo_lstm_h = self.layers.dropout_layer(bo_lstm_h, use_noise, trng)
            to_lstm_h = self.layers.dropout_layer(to_lstm_h, use_noise, trng)

        # compute word probabilities
        logit = self.layers.get_layer('ff')[1](tparams,
                                               bo_lstm_h,
                                               options,
                                               prefix='ff_logit_bo',
                                               activ='linear')
        if options['prev2out']:
            logit += emb
        if options['ctx2out']:
            betas = weight[:, :, 2]
            #betas = betas.reshape([betas.shape[1],betas.shape[2]])
            to_lstm_h *= betas[:, :, None]
            ctxs_beta = self.layers.get_layer('ff')[1](tparams,
                                                       ctxs,
                                                       options,
                                                       prefix='ff_logit_ctx',
                                                       activ='linear')
            ctxs_beta_c = self.layers.get_layer('ff')[1](
                tparams,
                ctxs_c,
                options,
                prefix='ff_logit_ctx_c',
                activ='linear')
            to_lstm_h = self.layers.get_layer('ff')[1](tparams,
                                                       to_lstm_h,
                                                       options,
                                                       prefix='ff_logit_to',
                                                       activ='linear')
            logit = logit + ctxs_beta + ctxs_beta_c + to_lstm_h
        logit = utils.tanh(logit)

        if options['use_dropout']:
            logit = self.layers.dropout_layer(logit, use_noise, trng)

        # (t,m,n_words)
        logit = self.layers.get_layer('ff')[1](tparams,
                                               logit,
                                               options,
                                               prefix='ff_logit',
                                               activ='linear')
        logit_shp = logit.shape
        # (t*m, n_words)
        probs = tensor.nnet.softmax(
            logit.reshape([logit_shp[0] * logit_shp[1], logit_shp[2]]))
        # cost
        x_flat = x.flatten()  # (t*m,)
        cost = -tensor.log(probs[tensor.arange(x_flat.shape[0]), x_flat] +
                           1e-8)

        cost = cost.reshape([x.shape[0], x.shape[1]])
        cost = (cost * mask).sum(0)
        extra = [probs, alphas, alphas_c, weight[:, :, 0], weight[:, :, 1]]

        return trng, use_noise, x, mask, ctx, mask_ctx, ctx_c, mask_ctx_c, cost, extra

    def build_sampler(self, tparams, options, use_noise, trng, mode=None):
        # context: #annotations x dim
        ctx0 = tensor.matrix('ctx_sampler', dtype='float32')
        # ctx0.tag.test_value = numpy.random.uniform(size=(50,1024)).astype('float32')
        ctx_mask = tensor.vector('ctx_mask', dtype='float32')
        # ctx_mask.tag.test_value = numpy.random.binomial(n=1,p=0.5,size=(50,)).astype('float32')
        ctx0_c = tensor.matrix('ctx_sampler_c', dtype='float32')
        # ctx0.tag.test_value = numpy.random.uniform(size=(50,1024)).astype('float32')
        ctx_mask_c = tensor.vector('ctx_mask_c', dtype='float32')

        ctx_ = ctx0
        counts = ctx_mask.sum(-1)

        ctx = ctx_
        ctx_mean = ctx.sum(0) / counts

        ctx_c_ = ctx0_c
        counts_c = ctx_mask_c.sum(-1)

        ctx_c = ctx_c_
        ctx_mean_c = ctx_c.sum(0) / counts_c

        # ctx_mean = ctx.mean(0)
        ctx = ctx.dimshuffle('x', 0, 1)
        # initial state/cell
        bo_init_state = self.layers.get_layer('ff')[1](tparams,
                                                       ctx_mean,
                                                       options,
                                                       prefix='ff_state',
                                                       activ='tanh')
        bo_init_memory = self.layers.get_layer('ff')[1](tparams,
                                                        ctx_mean,
                                                        options,
                                                        prefix='ff_memory',
                                                        activ='tanh')

        bo_init_state_c = self.layers.get_layer('ff')[1](tparams,
                                                         ctx_mean_c,
                                                         options,
                                                         prefix='ff_state_c',
                                                         activ='tanh')
        bo_init_memory_c = self.layers.get_layer('ff')[1](tparams,
                                                          ctx_mean_c,
                                                          options,
                                                          prefix='ff_memory_c',
                                                          activ='tanh')

        bo_init_state += bo_init_state_c
        bo_init_memory += bo_init_memory_c

        to_init_state = tensor.alloc(0., options['dim'])
        to_init_memory = tensor.alloc(0., options['dim'])
        init_state = [bo_init_state, to_init_state]
        init_memory = [bo_init_memory, to_init_memory]

        print 'Building f_init...',
        f_init = theano.function([ctx0, ctx_mask, ctx0_c, ctx_mask_c],
                                 [ctx0] + init_state + init_memory,
                                 name='f_init',
                                 on_unused_input='ignore',
                                 profile=False,
                                 mode=mode)
        print 'Done'

        x = tensor.vector('x_sampler', dtype='int64')
        init_state = [
            tensor.matrix('bo_init_state', dtype='float32'),
            tensor.matrix('to_init_state', dtype='float32')
        ]
        init_memory = [
            tensor.matrix('bo_init_memory', dtype='float32'),
            tensor.matrix('to_init_memory', dtype='float32')
        ]

        # if it's the first word, emb should be all zero
        emb = tensor.switch(x[:, None] < 0,
                            tensor.alloc(0., 1, tparams['Wemb'].shape[1]),
                            tparams['Wemb'][x])

        bo_lstm = self.layers.get_layer('lstm_cond')[1](
            tparams,
            emb,
            options,
            prefix='bo_lstm',
            mask=None,
            context=ctx,
            context_c=ctx_c,
            one_step=True,
            init_state=init_state[0],
            init_memory=init_memory[0],
            trng=trng,
            use_noise=use_noise,
            mode=mode)
        to_lstm = self.layers.get_layer('lstm')[1](tparams,
                                                   bo_lstm[0],
                                                   mask=None,
                                                   one_step=True,
                                                   init_state=init_state[1],
                                                   init_memory=init_memory[1],
                                                   prefix='to_lstm')
        next_state = [bo_lstm[0], to_lstm[0]]
        next_memory = [bo_lstm[1], to_lstm[0]]

        bo_lstm_h = bo_lstm[0]
        to_lstm_h = to_lstm[0]
        alphas = bo_lstm[2]
        alphas_c = bo_lstm[3]
        ctxs = bo_lstm[4]
        ctxs_c = bo_lstm[5]
        weight = bo_lstm[6]
        if options['use_dropout']:
            bo_lstm_h = self.layers.dropout_layer(bo_lstm_h, use_noise, trng)
            to_lstm_h = self.layers.dropout_layer(to_lstm_h, use_noise, trng)

        logit = self.layers.get_layer('ff')[1](tparams,
                                               bo_lstm_h,
                                               options,
                                               prefix='ff_logit_bo',
                                               activ='linear')
        if options['prev2out']:
            logit += emb
        if options['ctx2out']:
            betas = weight[:, 2]
            # betas = betas.reshape([betas.shape[1],betas.shape[2]])
            to_lstm_h *= betas[:, None]
            ctxs_beta = self.layers.get_layer('ff')[1](tparams,
                                                       ctxs,
                                                       options,
                                                       prefix='ff_logit_ctx',
                                                       activ='linear')
            ctxs_beta_c = self.layers.get_layer('ff')[1](
                tparams,
                ctxs_c,
                options,
                prefix='ff_logit_ctx_c',
                activ='linear')
            to_lstm_h = self.layers.get_layer('ff')[1](tparams,
                                                       to_lstm_h,
                                                       options,
                                                       prefix='ff_logit_to',
                                                       activ='linear')
            logit = logit + ctxs_beta + ctxs_beta_c + to_lstm_h
        logit = utils.tanh(logit)
        if options['use_dropout']:
            logit = self.layers.dropout_layer(logit, use_noise, trng)

        logit = self.layers.get_layer('ff')[1](tparams,
                                               logit,
                                               options,
                                               prefix='ff_logit',
                                               activ='linear')
        logit_shp = logit.shape
        next_probs = tensor.nnet.softmax(logit)
        next_sample = trng.multinomial(pvals=next_probs).argmax(1)

        # next word probability
        print 'building f_next...'
        f_next = theano.function(
            [x, ctx0, ctx_mask, ctx0_c, ctx_mask_c] + init_state + init_memory,
            [next_probs, next_sample] + next_state + next_memory,
            name='f_next',
            profile=False,
            mode=mode,
            on_unused_input='ignore')
        print 'Done'
        return f_init, f_next

    def gen_sample(self,
                   tparams,
                   f_init,
                   f_next,
                   ctx0,
                   ctx0_c,
                   ctx_mask,
                   ctx_mask_c,
                   options,
                   trng=None,
                   k=1,
                   maxlen=30,
                   stochastic=False,
                   restrict_voc=False):
        '''
        ctx0: (26,1024)
        ctx_mask: (26,)

        restrict_voc: set the probability of outofvoc words with 0, renormalize
        '''

        if k > 1:
            assert not stochastic, 'Beam search does not support stochastic sampling'

        sample = []
        sample_score = []
        if stochastic:
            sample_score = 0

        live_k = 1
        dead_k = 0

        hyp_samples = [[]] * live_k
        hyp_scores = numpy.zeros(live_k).astype('float32')
        hyp_states = []
        hyp_memories = []

        # [(26,1024),(512,),(512,)]
        rval = f_init(ctx0, ctx_mask, ctx0_c, ctx_mask_c)
        ctx0 = rval[0]

        next_state = []
        next_memory = []
        n_layers_lstm = 2

        for lidx in xrange(n_layers_lstm):
            next_state.append(rval[1 + lidx])
            next_state[-1] = next_state[-1].reshape(
                [live_k, next_state[-1].shape[0]])
        for lidx in xrange(n_layers_lstm):
            next_memory.append(rval[1 + n_layers_lstm + lidx])
            next_memory[-1] = next_memory[-1].reshape(
                [live_k, next_memory[-1].shape[0]])
        next_w = -1 * numpy.ones((1, )).astype('int64')
        # next_state: [(1,512)]
        # next_memory: [(1,512)]
        for ii in xrange(maxlen):
            # return [(1, 50000), (1,), (1, 512), (1, 512)]
            # next_w: vector
            # ctx: matrix
            # ctx_mask: vector
            # next_state: [matrix]
            # next_memory: [matrix]
            rval = f_next(*([next_w, ctx0, ctx_mask, ctx0_c, ctx_mask_c] +
                            next_state + next_memory))
            next_p = rval[0]
            if restrict_voc:
                raise NotImplementedError()
            next_w = rval[1]  # already argmax sorted
            next_state = []
            for lidx in xrange(n_layers_lstm):
                next_state.append(rval[2 + lidx])
            next_memory = []
            for lidx in xrange(n_layers_lstm):
                next_memory.append(rval[2 + n_layers_lstm + lidx])
            if stochastic:
                sample.append(next_w[0])  # take the most likely one
                sample_score += next_p[0, next_w[0]]
                if next_w[0] == 0:
                    break
            else:
                # the first run is (1,50000)
                cand_scores = hyp_scores[:, None] - numpy.log(next_p)
                cand_flat = cand_scores.flatten()
                ranks_flat = cand_flat.argsort()[:(k - dead_k)]

                voc_size = next_p.shape[1]
                trans_indices = ranks_flat / voc_size  # index of row
                word_indices = ranks_flat % voc_size  # index of col
                costs = cand_flat[ranks_flat]

                new_hyp_samples = []
                new_hyp_scores = numpy.zeros(k - dead_k).astype('float32')
                new_hyp_states = []
                for lidx in xrange(n_layers_lstm):
                    new_hyp_states.append([])
                new_hyp_memories = []
                for lidx in xrange(n_layers_lstm):
                    new_hyp_memories.append([])

                for idx, [ti, wi] in enumerate(zip(trans_indices,
                                                   word_indices)):
                    new_hyp_samples.append(hyp_samples[ti] + [wi])
                    new_hyp_scores[idx] = copy.copy(costs[idx])
                    for lidx in xrange(n_layers_lstm):
                        new_hyp_states[lidx].append(
                            copy.copy(next_state[lidx][ti]))
                    for lidx in xrange(n_layers_lstm):
                        new_hyp_memories[lidx].append(
                            copy.copy(next_memory[lidx][ti]))

                # check the finished samples
                new_live_k = 0
                hyp_samples = []
                hyp_scores = []
                hyp_states = []
                for lidx in xrange(n_layers_lstm):
                    hyp_states.append([])
                hyp_memories = []
                for lidx in xrange(n_layers_lstm):
                    hyp_memories.append([])

                for idx in xrange(len(new_hyp_samples)):
                    if new_hyp_samples[idx][-1] == 0:
                        sample.append(new_hyp_samples[idx])
                        sample_score.append(new_hyp_scores[idx])
                        dead_k += 1
                    else:
                        new_live_k += 1
                        hyp_samples.append(new_hyp_samples[idx])
                        hyp_scores.append(new_hyp_scores[idx])
                        for lidx in xrange(n_layers_lstm):
                            hyp_states[lidx].append(new_hyp_states[lidx][idx])
                        for lidx in xrange(n_layers_lstm):
                            hyp_memories[lidx].append(
                                new_hyp_memories[lidx][idx])
                hyp_scores = numpy.array(hyp_scores)
                live_k = new_live_k

                if new_live_k < 1:
                    break
                if dead_k >= k:
                    break

                next_w = numpy.array([w[-1] for w in hyp_samples])
                next_state = []
                for lidx in xrange(n_layers_lstm):
                    next_state.append(numpy.array(hyp_states[lidx]))
                next_memory = []
                for lidx in xrange(n_layers_lstm):
                    next_memory.append(numpy.array(hyp_memories[lidx]))

        if not stochastic:
            # dump every remaining one
            if live_k > 0:
                for idx in xrange(live_k):
                    sample.append(hyp_samples[idx])
                    sample_score.append(hyp_scores[idx])

        return sample, sample_score, next_state, next_memory

    def pred_probs(self, engine, whichset, f_log_probs, verbose=True):

        probs = []
        n_done = 0
        NLL = []
        L = []
        if whichset == 'train':
            tags = engine.train
            iterator = engine.kf_train
        elif whichset == 'valid':
            tags = engine.valid
            iterator = engine.kf_valid
        elif whichset == 'test':
            tags = engine.test
            iterator = engine.kf_test
        else:
            raise NotImplementedError()
        n_samples = numpy.sum([len(index) for index in iterator])
        for index in iterator:
            tag = [tags[i] for i in index]
            x, mask, ctx, ctx_mask, ctx_c, ctx_mask_c = prepare_data(
                engine, tag)
            pred_probs = f_log_probs(x, mask, ctx, ctx_mask, ctx_c, ctx_mask_c)
            L.append(mask.sum(0).tolist())
            NLL.append((-1 * pred_probs).tolist())
            probs.append(pred_probs.tolist())
            n_done += len(tag)
            if verbose:
                sys.stdout.write('\rComputing LL on %d/%d examples' %
                                 (n_done, n_samples))
                sys.stdout.flush()
        print
        probs = utils.flatten_list_of_list(probs)
        NLL = utils.flatten_list_of_list(NLL)
        L = utils.flatten_list_of_list(L)
        perp = 2**(numpy.sum(NLL) / numpy.sum(L) / numpy.log(2))
        return -1 * numpy.mean(probs), perp

    def sample_execute(self, engine, options, tparams, f_init, f_next, x, ctx,
                       ctx_c, ctx_mask, ctx_mask_c, trng):
        stochastic = False
        for jj in xrange(numpy.minimum(10, x.shape[1])):
            sample, score, _, _ = self.gen_sample(tparams,
                                                  f_init,
                                                  f_next,
                                                  ctx[jj],
                                                  ctx_c[jj],
                                                  ctx_mask[jj],
                                                  ctx_mask_c[jj],
                                                  options,
                                                  trng=trng,
                                                  k=5,
                                                  maxlen=30,
                                                  stochastic=stochastic)
            if not stochastic:
                best_one = numpy.argmin(score)
                sample = sample[best_one]
            else:
                sample = sample
            print 'Truth ', jj, ': ',
            for vv in x[:, jj]:
                if vv == 0:
                    break
                if vv in engine.word_idict:
                    print engine.word_idict[vv],
                else:
                    print 'UNK',
            print
            for kk, ss in enumerate([sample]):
                print 'Sample (', jj, ') ', ': ',
                for vv in ss:
                    if vv == 0:
                        break
                    if vv in engine.word_idict:
                        print engine.word_idict[vv],
                    else:
                        print 'UNK',
            print
Exemplo n.º 3
0
class Model(object):
    def __init__(self):
        """Attach the ``Layers`` helper used to look up layer functions."""
        layers = Layers()
        self.layers = layers

    def init_params(self, options):
        """Create the initial values of every model parameter.

        Parameters are created in a fixed order:
        word embedding, the two mean-context projections, the bottom
        (``bo_lstm``) and top (``to_lstm``) LSTMs, then the readout layers.
        The order is kept stable because each initializer may draw from
        a random number generator.

        Args:
            options: config dict; reads ``vocab_size``, ``word_dim``,
                ``ctx_dim``, ``lstm_dim`` and ``ctx2out``.

        Returns:
            The parameter mapping (started as an ``OrderedDict``) returned
            by the last layer initializer.
        """
        ff_init = self.layers.get_layer('ff')[0]
        vocab_size = options['vocab_size']
        word_dim = options['word_dim']
        ctx_dim = options['ctx_dim']
        lstm_dim = options['lstm_dim']

        params = OrderedDict()
        # word embedding matrix
        params['Wemb'] = utils.norm_weight(vocab_size, word_dim)

        # projections of the mean context to the initial LSTM state / cell
        for prefix in ('ff_state', 'ff_memory'):
            params = ff_init(options, params, prefix=prefix,
                             nin=ctx_dim, nout=lstm_dim)

        # decoder: context-conditioned bottom LSTM, plain top LSTM on top
        params = self.layers.get_layer('lstm_cond')[0](
            options, params, prefix='bo_lstm',
            nin=word_dim, dim=lstm_dim, dimctx=ctx_dim)
        params = self.layers.get_layer('lstm')[0](
            params, nin=lstm_dim, dim=lstm_dim, prefix='to_lstm')

        # readout layers as (prefix, nin, nout), in creation order
        readout = [('ff_logit_bo', lstm_dim, word_dim)]
        if options['ctx2out']:
            readout.append(('ff_logit_ctx', ctx_dim, word_dim))
            readout.append(('ff_logit_to', lstm_dim, word_dim))
        readout.append(('ff_logit', word_dim, vocab_size))
        for prefix, n_in, n_out in readout:
            params = ff_init(options, params, prefix=prefix,
                             nin=n_in, nout=n_out)
        return params

    def build_model(self, tfparams, options, x, mask, ctx, ctx_mask):
        """Build the teacher-forced training graph and its per-sample cost.

        Args:
            tfparams: dict of TF variables named as in ``init_params``.
            options: config dict; reads ``use_dropout``, ``prev2out`` and
                ``ctx2out``.
            x: word ids laid out as (timestep, sample).
            mask: token mask with the same layout as ``x``.
            ctx: context features; summed over axis 1 for the mean context.
            ctx_mask: context mask; summed over its last axis to count the
                valid context positions of each sample.

        Returns:
            ``(use_noise, cost, extra)``:
            - ``use_noise`` is the non-trainable boolean dropout switch.
            - ``cost`` is the masked ``-log p(word)`` summed over timesteps,
              one value per sample.
            - ``extra`` is ``[probs, alphas, betas]``.
        """
        ff_apply = self.layers.get_layer('ff')[1]

        use_noise = tf.Variable(False, dtype=tf.bool, trainable=False,
                                name="use_noise")
        x_shape = tf.shape(x)
        # timestep / batch sizes (not read further below)
        n_timesteps, n_samples = x_shape[0], x_shape[1]

        # Look up word vectors and shift them one step forward in time, so
        # step t is fed the embedding of word t-1 (step 0 gets zeros).
        word_vecs = tf.nn.embedding_lookup(tfparams['Wemb'], x,
                                           name="inputs_emb_lookup")
        word_vecs_shape = tf.shape(word_vecs)
        shift_rows = tf.expand_dims(tf.range(1, word_vecs_shape[0]), axis=1)
        emb = tf.scatter_nd(shift_rows, word_vecs[:-1], word_vecs_shape)

        with tf.name_scope("ctx_mean"):
            with tf.name_scope("counts"):
                counts = tf.expand_dims(
                    tf.reduce_sum(ctx_mask, axis=-1,
                                  name="reduce_sum_ctx_mask"), 1)
            # Mean pooling over axis 1. The sum covers every position, so
            # this presumably assumes padded context rows are zero.
            ctx_mean = tf.reduce_sum(ctx, axis=1,
                                     name="reduce_sum_ctx") / counts

        # initial state / cell of the bottom LSTM
        with tf.name_scope("init_state"):
            init_state = ff_apply(tfparams, ctx_mean, options,
                                  prefix='ff_state', activ='tanh')
        with tf.name_scope("init_memory"):
            init_memory = ff_apply(tfparams, ctx_mean, options,
                                   prefix='ff_memory', activ='tanh')

        with tf.name_scope("bo_lstm"):
            bo_out = self.layers.get_layer('lstm_cond')[1](
                tfparams, emb, options,
                prefix='bo_lstm', mask=mask, context=ctx, one_step=False,
                init_state=init_state, init_memory=init_memory,
                use_noise=use_noise)
        with tf.name_scope("to_lstm"):
            to_out = self.layers.get_layer('lstm')[1](
                tfparams, bo_out[0], mask=mask, one_step=False,
                prefix='to_lstm')

        bo_h = bo_out[0]     # bottom LSTM hidden states
        to_h = to_out[0]     # top LSTM hidden states
        alphas = bo_out[2]   # attention weights
        ctxs = bo_out[3]     # attended contexts
        betas = bo_out[4]    # per-step gate used by ctx2out
        if options['use_dropout']:
            bo_h = self.layers.dropout_layer(bo_h, use_noise)
            to_h = self.layers.dropout_layer(to_h, use_noise)

        # word logits before the vocabulary projection
        logit = ff_apply(tfparams, bo_h, options, prefix='ff_logit_bo',
                         activ='linear')
        if options['prev2out']:
            logit = logit + emb
        if options['ctx2out']:
            to_h = to_h * (1 - betas[:, :, None])
            ctx_term = ff_apply(tfparams, ctxs, options,
                                prefix='ff_logit_ctx', activ='linear')
            ctx_term = ctx_term + ff_apply(tfparams, to_h, options,
                                           prefix='ff_logit_to',
                                           activ='linear')
            logit = logit + ctx_term
        logit = utils.tanh(logit)
        if options['use_dropout']:
            logit = self.layers.dropout_layer(logit, use_noise)
        logit = ff_apply(tfparams, logit, options, prefix='ff_logit',
                         activ='linear')  # (t, m, vocab_size)

        # softmax over the vocabulary with time and batch flattened together
        logit_shape = tf.shape(logit)
        flat_logit = tf.reshape(
            logit, [logit_shape[0] * logit_shape[1], logit_shape[2]])
        probs = tf.nn.softmax(flat_logit)  # (t*m, vocab_size)

        # pick the probability of each target word; 1e-8 guards log(0)
        x_flat = tf.reshape(x, [x_shape[0] * x_shape[1]])  # (t*m,)
        n_tokens = tf.shape(x_flat)[0]
        gather_indices = tf.stack([tf.range(n_tokens), x_flat], axis=1)
        cost = -tf.log(tf.gather_nd(probs, gather_indices) + 1e-8)
        cost = tf.reshape(cost, [x_shape[0], x_shape[1]])  # (t, m)
        # mask out padding and sum over timesteps -> one cost per sample
        cost = tf.reduce_sum(cost * mask, axis=0)
        return use_noise, cost, [probs, alphas, betas]

    def build_sampler(self,
                      tfparams,
                      options,
                      use_noise,
                      ctx0,
                      ctx_mask,
                      x,
                      bo_init_state_sampler,
                      to_init_state_sampler,
                      bo_init_memory_sampler,
                      to_init_memory_sampler,
                      mode=None):
        # ctx: # frames x ctx_dim
        ctx_ = ctx0
        counts = tf.reduce_sum(ctx_mask, axis=-1)  # scalar

        ctx = ctx_
        ctx_mean = tf.reduce_sum(ctx, axis=0) / counts  # (2048,)
        ctx = tf.expand_dims(ctx, 0)  # (1,28,2048)

        # initial state/cell
        bo_init_state = self.layers.get_layer('ff')[1](tfparams,
                                                       ctx_mean,
                                                       options,
                                                       prefix='ff_state',
                                                       activ='tanh')  # (512,)
        bo_init_memory = self.layers.get_layer('ff')[1](tfparams,
                                                        ctx_mean,
                                                        options,
                                                        prefix='ff_memory',
                                                        activ='tanh')  # (512,)
        to_init_state = tf.zeros(
            shape=(options['lstm_dim'], ),
            dtype=tf.float32)  # DOUBT : constant or not? # (512,)
        to_init_memory = tf.zeros(shape=(options['lstm_dim'], ),
                                  dtype=tf.float32)  # (512,)
        init_state = [bo_init_state, to_init_state]
        init_memory = [bo_init_memory, to_init_memory]

        print 'building f_init...',
        f_init = [ctx0] + init_state + init_memory
        print 'done'

        init_state = [bo_init_state_sampler, to_init_state_sampler]
        init_memory = [bo_init_memory_sampler, to_init_memory_sampler]

        # # if it's the first word, embedding should be all zero
        emb = tf.cond(
            tf.reduce_any(x[:, None] < 0), lambda: tf.zeros(
                shape=(1, tfparams['Wemb'].shape[1]), dtype=tf.float32),
            lambda: tf.nn.embedding_lookup(tfparams['Wemb'], x))  # (m,512)

        bo_lstm = self.layers.get_layer('lstm_cond')[1](
            tfparams,
            emb,
            options,
            prefix='bo_lstm',
            mask=None,
            context=ctx,
            one_step=True,
            init_state=init_state[0],
            init_memory=init_memory[0],
            use_noise=use_noise,
            mode=mode)
        to_lstm = self.layers.get_layer('lstm')[1](tfparams,
                                                   bo_lstm[0],
                                                   mask=None,
                                                   one_step=True,
                                                   init_state=init_state[1],
                                                   init_memory=init_memory[1],
                                                   prefix='to_lstm')
        next_state = [bo_lstm[0], to_lstm[0]]
        next_memory = [bo_lstm[1], to_lstm[0]]

        bo_lstm_h = bo_lstm[0]  # (1,512)
        to_lstm_h = to_lstm[0]  # (1,512)
        alphas = bo_lstm[2]  # (1,28)
        ctxs = bo_lstm[3]  # (1,2048)
        betas = bo_lstm[4]  # (1,)
        if options['use_dropout']:
            bo_lstm_h = self.layers.dropout_layer(bo_lstm_h, use_noise)
            to_lstm_h = self.layers.dropout_layer(to_lstm_h, use_noise)
        # compute word probabilities
        logit = self.layers.get_layer('ff')[1](
            tfparams, bo_lstm_h, options, prefix='ff_logit_bo',
            activ='linear')  # (1,512)*(512,512) = (1,512)
        if options['prev2out']:
            logit += emb
        if options['ctx2out']:
            to_lstm_h *= (1 - betas[:, None])  # (1,512)*(1,1) = (1,512)
            ctxs_beta = self.layers.get_layer('ff')[1](
                tfparams, ctxs, options, prefix='ff_logit_ctx',
                activ='linear')  # (1,2048)*(2048,512) = (1,512)
            ctxs_beta += self.layers.get_layer('ff')[1](
                tfparams,
                to_lstm_h,
                options,
                prefix='ff_logit_to',
                activ='linear')  # (1,512)+((1,512)*(512,512)) = (1,512)
            logit += ctxs_beta
        logit = utils.tanh(logit)  # (1,512)
        if options['use_dropout']:
            logit = self.layers.dropout_layer(logit, use_noise)
        # (1,n_words)
        logit = self.layers.get_layer('ff')[1](
            tfparams, logit, options, prefix='ff_logit',
            activ='linear')  # (1,512)*(512,vocab_size) = (1,vocab_size)
        next_probs = tf.nn.softmax(logit)
        # next_sample = trng.multinomial(pvals=next_probs).argmax(1)    # INCOMPLETE , DOUBT : why is multinomial needed?
        next_sample = tf.multinomial(
            next_probs, 1)  # draw samples with given probabilities (1,1)
        next_sample_shape = tf.shape(next_sample)
        next_sample = tf.reshape(next_sample, [next_sample_shape[0]])
        # next word probability
        print 'building f_next...',
        f_next = [next_probs, next_sample] + next_state + next_memory
        print 'done'
        return f_init, f_next

    def gen_sample(self,
                   sess,
                   tfparams,
                   f_init,
                   f_next,
                   ctx0,
                   ctx_mask,
                   options,
                   k=1,
                   maxlen=30,
                   stochastic=False,
                   restrict_voc=False):
        '''Decode one caption by sampling or by beam search.

        ctx0: (28,2048) (f, dim_ctx)
        ctx_mask: (28,) (f, )
        k: beam width (must be 1 when ``stochastic`` is True).
        maxlen: maximum number of decoding steps.
        stochastic: take the word drawn by the sampler at each step instead
            of running beam search.
        restrict_voc: set the probability of outofvoc words with 0,
            renormalize (not implemented; raises NotImplementedError).

        Returns (sample, sample_score, next_state, next_memory):
          - stochastic: ``sample`` is a flat list of word ids, stopping
            after the first 0. ``sample_score`` is the sum of the per-step
            probabilities of the drawn words (not log-probabilities).
          - beam search: ``sample`` is a list of hypotheses (lists of word
            ids). ``sample_score`` holds their accumulated negative
            log-probabilities (lower is better).
          ``next_state`` / ``next_memory`` are the last per-layer LSTM
          state / memory arrays.
        '''
        if k > 1:
            assert not stochastic, 'Beam search does not support stochastic sampling'

        n_layers_lstm = 2  # bottom (bo) and top (to) LSTM

        sample = []
        sample_score = 0 if stochastic else []

        live_k = 1
        dead_k = 0
        hyp_samples = [[]] * live_k
        hyp_scores = np.zeros(live_k).astype('float32')

        # [(28,2048),(512,),(512,),(512,),(512,)]
        rval = sess.run(f_init,
                        feed_dict={
                            "ctx_sampler:0": ctx0,
                            "ctx_mask_sampler:0": ctx_mask
                        })
        ctx0 = rval[0]

        # add a leading hypothesis axis: (dim,) -> (1, dim)
        next_state = [
            rval[1 + lidx].reshape([live_k, rval[1 + lidx].shape[0]])
            for lidx in xrange(n_layers_lstm)
        ]
        next_memory = [
            rval[1 + n_layers_lstm + lidx].reshape(
                [live_k, rval[1 + n_layers_lstm + lidx].shape[0]])
            for lidx in xrange(n_layers_lstm)
        ]
        # a negative id tells the sampler graph "no previous word"
        next_w = -1 * np.ones((1, )).astype('int32')

        for ii in xrange(maxlen):
            # f_next -> [probs (live_k, vocab), sample (live_k,),
            #            bo_state, to_state, bo_memory, to_memory]
            rval = sess.run(f_next,
                            feed_dict={
                                "x_sampler:0": next_w,
                                "ctx_sampler:0": ctx0,
                                "ctx_mask_sampler:0": ctx_mask,
                                'bo_init_state_sampler:0': next_state[0],
                                'to_init_state_sampler:0': next_state[1],
                                'bo_init_memory_sampler:0': next_memory[0],
                                'to_init_memory_sampler:0': next_memory[1]
                            })
            next_p = rval[0]
            if restrict_voc:
                raise NotImplementedError()
            next_w = rval[1]  # word drawn by the sampler graph
            next_state = list(rval[2:2 + n_layers_lstm])
            next_memory = list(rval[2 + n_layers_lstm:2 + 2 * n_layers_lstm])

            if stochastic:
                # keep the drawn word and add its probability to the score
                sample.append(next_w[0])
                sample_score += next_p[0, next_w[0]]
                if next_w[0] == 0:
                    break
                continue

            # --- beam search step ---
            # the first run is (1,vocab_size)
            cand_scores = hyp_scores[:, None] - np.log(next_p)
            cand_flat = cand_scores.flatten()
            ranks_flat = cand_flat.argsort()[:(k - dead_k)]

            voc_size = next_p.shape[1]
            # Floor division keeps the hypothesis index an integer under
            # both Python 2 and true division.
            trans_indices = ranks_flat // voc_size  # index of row
            word_indices = ranks_flat % voc_size  # index of col
            costs = cand_flat[ranks_flat]

            new_hyp_samples = []
            new_hyp_scores = np.zeros(k - dead_k).astype('float32')
            new_hyp_states = [[] for _ in xrange(n_layers_lstm)]
            new_hyp_memories = [[] for _ in xrange(n_layers_lstm)]
            for idx, (ti, wi) in enumerate(zip(trans_indices, word_indices)):
                new_hyp_samples.append(hyp_samples[ti] + [wi])
                new_hyp_scores[idx] = copy.copy(costs[idx])
                for lidx in xrange(n_layers_lstm):
                    new_hyp_states[lidx].append(
                        copy.copy(next_state[lidx][ti]))
                    new_hyp_memories[lidx].append(
                        copy.copy(next_memory[lidx][ti]))

            # move hypotheses ending in word 0 to the finished set
            hyp_samples = []
            hyp_scores = []
            hyp_states = [[] for _ in xrange(n_layers_lstm)]
            hyp_memories = [[] for _ in xrange(n_layers_lstm)]
            for idx, hyp in enumerate(new_hyp_samples):
                if hyp[-1] == 0:
                    sample.append(hyp)
                    sample_score.append(new_hyp_scores[idx])
                    dead_k += 1
                else:
                    hyp_samples.append(hyp)
                    hyp_scores.append(new_hyp_scores[idx])
                    for lidx in xrange(n_layers_lstm):
                        hyp_states[lidx].append(new_hyp_states[lidx][idx])
                        hyp_memories[lidx].append(
                            new_hyp_memories[lidx][idx])
            hyp_scores = np.array(hyp_scores)
            live_k = len(hyp_samples)

            if live_k < 1 or dead_k >= k:
                break

            next_w = np.array([w[-1] for w in hyp_samples])
            next_state = [np.array(s) for s in hyp_states]
            next_memory = [np.array(m) for m in hyp_memories]

        if not stochastic:
            # dump every remaining (unfinished) hypothesis
            for idx in xrange(live_k):
                sample.append(hyp_samples[idx])
                sample_score.append(hyp_scores[idx])

        return sample, sample_score, next_state, next_memory

    def pred_probs(self, sess, engine, whichset, f_log_probs, verbose=True):
        probs = []
        n_done = 0
        NLL = []
        L = []
        if whichset == 'train':
            tags = engine.train_data_ids
            iterator = engine.kf_train
        elif whichset == 'val':
            tags = engine.val_data_ids
            iterator = engine.kf_val
        elif whichset == 'test':
            tags = engine.test_data_ids
            iterator = engine.kf_test
        else:
            raise NotImplementedError()
        n_samples = np.sum([len(index) for index in iterator])
        for index in iterator:
            tag = [tags[i] for i in index]
            x, mask, ctx, ctx_mask = prepare_data(engine, tag, mode=whichset)

            pred_probs = sess.run(f_log_probs,
                                  feed_dict={
                                      "word_seq_x:0": x,
                                      "word_seq_mask:0": mask,
                                      "ctx:0": ctx,
                                      "ctx_mask:0": ctx_mask
                                  })

            L.append(mask.sum(0).tolist())
            NLL.append((-1 * pred_probs).tolist())
            probs.append(pred_probs.tolist())
            n_done += len(tag)
            if verbose:
                sys.stdout.write('\rComputing LL on %d/%d examples' %
                                 (n_done, n_samples))
                sys.stdout.flush()
        print ""
        probs = utils.flatten_list_of_list(probs)
        NLL = utils.flatten_list_of_list(NLL)
        L = utils.flatten_list_of_list(L)
        perp = 2**(np.sum(NLL) / np.sum(L) / np.log(2))
        return -1 * np.mean(probs), perp

    def sample_execute(self, sess, engine, options, tfparams, f_init, f_next,
                       x, ctx, ctx_mask):
        stochastic = not options['beam_search']
        if stochastic:
            beam = 1
        else:
            beam = 5
        # x = (t,64)
        # ctx = (64,28,2048)
        # ctx_mask = (64,28)
        for jj in xrange(np.minimum(10, x.shape[1])):
            sample, score, _, _ = self.gen_sample(sess,
                                                  tfparams,
                                                  f_init,
                                                  f_next,
                                                  ctx[jj],
                                                  ctx_mask[jj],
                                                  options,
                                                  k=beam,
                                                  maxlen=30,
                                                  stochastic=stochastic)
            if not stochastic:
                best_one = np.argmin(score)
                sample = sample[best_one]
            else:
                sample = sample
            print 'Truth ', jj, ': ',
            for vv in x[:, jj]:
                if vv == 0:
                    break
                if vv in engine.reverse_vocab:
                    print engine.reverse_vocab[vv],
                else:
                    print 'UNK',
            print ""
            for kk, ss in enumerate([sample]):
                print 'Sample (', jj, ') ', ': ',
                for vv in ss:
                    if vv == 0:
                        break
                    if vv in engine.reverse_vocab:
                        print engine.reverse_vocab[vv],
                    else:
                        print 'UNK',
            print ""