Code example #1
  def __init__(self, config, loss, params):
    self._lr = get_shared_floatX(config.learning_rate, 'lr')
    self._t = get_shared_floatX(1, 't')
    self._all_m_tm1 = []
    self._all_v_tm1 = []
    self._updates = [(self._t, self._t + 1)]

    if config.lr_decay:
      lr_coef = tt.pow(config.lr_decay, (self._t - 1) // config.lr_decay_freq)
      self._updates.append((self._lr, lr_coef * config.learning_rate))

    grads = theano.grad(loss, params)
    #grads = theano.grad(loss, params, disconnected_inputs='ignore')

    self._global_grad_norm = tt.sqrt(tt.sum(tt.stack([tt.sum(g**2.) for g in grads])))
    if config.max_grad_norm:
      global_clip_factor = ifelse(tt.lt(self._global_grad_norm, config.max_grad_norm),
        cast_floatX_np(1.),
        cast_floatX(config.max_grad_norm/self._global_grad_norm))
      # global_clip_factor = tt.minimum(cast_floatX(config.max_grad_norm/self._global_grad_norm), cast_floatX_np(1))
      grads = [global_clip_factor * g for g in grads]

    lr_t = self._lr * \
      clip_sqrt(1 - tt.pow(config.adam_beta2, self._t)) / (1 - tt.pow(config.adam_beta1, self._t))

    for p, g in zip(params, grads):
      m_tm1 = get_shared_floatX(np.zeros_like(p.get_value()), 'adam_m_' + p.name)
      v_tm1 = get_shared_floatX(np.zeros_like(p.get_value()), 'adam_v_' + p.name)
      self._all_m_tm1.append(m_tm1)
      self._all_v_tm1.append(v_tm1)
      m_t = config.adam_beta1 * m_tm1 + (1-config.adam_beta1) * g
      v_t = config.adam_beta2 * v_tm1 + (1-config.adam_beta2) * tt.sqr(g)
      delta_t = -lr_t * m_t / (clip_sqrt(v_t) + config.adam_eps)
      p_t = p + delta_t
      self._updates += [(m_tm1, m_t), (v_tm1, v_t), (p, p_t)]
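For reference, the update computed by these lines is the standard Adam rule with the bias correction folded into the per-step learning rate, matching the code above (alpha = config.learning_rate, beta_1 = config.adam_beta1, beta_2 = config.adam_beta2, epsilon = config.adam_eps):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \alpha_t = \alpha \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)
    \theta_t = \theta_{t-1} - \alpha_t \, m_t / (\sqrt{v_t} + \epsilon)

When config.max_grad_norm is set, the gradients g_t are first rescaled by min(1, max_grad_norm / ||g||_2), where ||g||_2 is the global norm over all parameters.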
Code example #2
  def __init__(self, config, loss, params):
    self._lr = get_shared_floatX(config.learning_rate, 'lr')
    self._t = get_shared_floatX(1, 't')
    self._all_m_tm1 = []
    self._all_v_tm1 = []
    self._updates = [(self._t, self._t + 1)]

    if config.lr_decay:
      lr_coef = tt.pow(config.lr_decay, (self._t - 1) // config.lr_decay_freq)
      self._updates.append((self._lr, lr_coef * config.learning_rate))

    grads = theano.grad(loss, params)

    self._global_grad_norm = tt.sqrt(tt.sum(tt.stack([tt.sum(g**2.) for g in grads])))
    if config.max_grad_norm:
      global_clip_factor = ifelse(tt.lt(self._global_grad_norm, config.max_grad_norm),
        cast_floatX_np(1.),
        cast_floatX(config.max_grad_norm/self._global_grad_norm))
      grads = [global_clip_factor * g for g in grads]

    lr_t = self._lr * \
      clip_sqrt(1 - tt.pow(config.adam_beta2, self._t)) / (1 - tt.pow(config.adam_beta1, self._t))

    for p, g in zip(params, grads):
      m_tm1 = get_shared_floatX(np.zeros_like(p.get_value()), 'adam_m_' + p.name)
      v_tm1 = get_shared_floatX(np.zeros_like(p.get_value()), 'adam_v_' + p.name)
      self._all_m_tm1.append(m_tm1)
      self._all_v_tm1.append(v_tm1)
      m_t = config.adam_beta1 * m_tm1 + (1-config.adam_beta1) * g
      v_t = config.adam_beta2 * v_tm1 + (1-config.adam_beta2) * tt.sqr(g)
      delta_t = -lr_t * m_t / (clip_sqrt(v_t) + config.adam_eps)
      p_t = p + delta_t
      self._updates += [(m_tm1, m_t), (v_tm1, v_t), (p, p_t)]
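Examples #1 through #5 rely on a few project-specific helpers whose definitions are not shown on this page. The sketches below are plausible stand-ins inferred from how the helpers are used, not the repository's actual code:

import numpy as np
import theano
import theano.tensor as tt

def cast_floatX_np(x):
    # Cast a numpy value/array to Theano's configured float dtype.
    return np.asarray(x, dtype=theano.config.floatX)

def cast_floatX(x):
    # Symbolic cast of a Theano expression to floatX.
    return tt.cast(x, theano.config.floatX)

def get_shared_floatX(value, name):
    # Shared variable initialized with a floatX copy of `value`.
    return theano.shared(cast_floatX_np(value), name=name)

def clip_sqrt(x):
    # sqrt guarded against tiny negative inputs caused by floating-point error.
    return tt.sqrt(tt.clip(x, 0.0, np.inf))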
Code example #3
 def make_param_from_value(self, name, value):
     if name in self._params:
         param = self._params[name]
         if value.shape != param.get_value().shape:
             raise AssertionError(
                 'parameter {} re-use attempt with mis-matching shapes: '
                 'existing shape {}, requested shape {}'.format(
                     name,
                     param.get_value().shape, value.shape))
         return param
     param = get_shared_floatX(value, name)
     self._params[name] = param
     if self._ema:
         self._param_shadows[name] = get_shared_floatX(
             value, '_shadow_' + name)
     return param
Code example #4
 def make_param_from_value(self, name, value):
   if name in self._params:
     param = self._params[name]
     if value.shape != param.get_value().shape:
       raise AssertionError('parameter {} re-use attempt with mis-matching shapes: '
         'existing shape {}, requested shape {}'.format(
           name, param.get_value().shape, value.shape))
     return param
   param = get_shared_floatX(value, name)
   self._params[name] = param
   return param
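A hypothetical call site for the two variants above (the `model` instance and the parameter name/shape are made up for illustration). Re-using a name with a matching shape returns the cached shared variable, a mismatching shape raises AssertionError, and in example #3 an EMA shadow copy is also registered on first creation when self._ema is set:

import numpy as np

w_init = np.random.uniform(-0.05, 0.05, size=(128, 64)).astype('float32')
w = model.make_param_from_value('w_proj', w_init)        # first use: creates and caches
w_again = model.make_param_from_value('w_proj', w_init)  # same name, same shape
assert w is w_again                                      # the cached parameter is returned
# model.make_param_from_value('w_proj', np.zeros((64, 64), dtype='float32'))
#   -> would raise AssertionError because of the shape mismatch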
Code example #5
  def __init__(self, config, loss, params):
    self._lr = get_shared_floatX(config.learning_rate, 'lr')
    self._t = get_shared_floatX(1, 't')
    self._updates = [(self._t, self._t + 1)]

    if config.lr_decay:
      lr_coef = tt.pow(config.lr_decay, (self._t - 1) // config.lr_decay_freq)
      self._updates.append((self._lr, lr_coef * config.learning_rate))

    grads = theano.grad(loss, params)
    #grads = theano.grad(loss, params, disconnected_inputs='ignore')

    self._global_grad_norm = tt.sqrt(tt.sum(tt.stack([tt.sum(g**2.) for g in grads])))
    if config.max_grad_norm:
      global_clip_factor = ifelse(tt.lt(self._global_grad_norm, config.max_grad_norm),
        cast_floatX_np(1.),
        cast_floatX(config.max_grad_norm/self._global_grad_norm))
      # global_clip_factor = tt.minimum(cast_floatX(config.max_grad_norm/self._global_grad_norm), cast_floatX_np(1))
      grads = [global_clip_factor * g for g in grads]

    for p, g in zip(params, grads):
      delta_t = -self._lr * g
      p_t = p + delta_t
      self._updates += [(p, p_t)]
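Example #5 is a plain gradient-descent variant of the same optimizer: it keeps the learning-rate decay and the global-norm clipping but drops the Adam moment estimates, so each parameter update reduces to

    \theta_t = \theta_{t-1} - \eta_t \cdot \min\left(1, \; c / \lVert g \rVert_2\right) \cdot g_t

where eta_t is the (possibly decayed) config.learning_rate and c is config.max_grad_norm.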
Code example #6
  def __init__(self, config, data):
    self.init_start(config)
    # cuda optimized batched dot product
    batched_dot = tt.batched_dot if config.device == 'cpu' else theano.sandbox.cuda.blas.batched_dot

    ###################################################
    # Load all data onto GPU
    ###################################################

    emb_val = data.word_emb_data.word_emb                                               # (voc size, emb_dim)
    first_known_word = data.word_emb_data.first_known_word
    assert config.emb_dim == emb_val.shape[1]
    assert first_known_word > 0
    emb_val[:first_known_word] = 0 
    if config.learn_single_unk:
      first_unknown_word = data.word_emb_data.first_unknown_word
      known_emb = get_shared_floatX(emb_val[:first_unknown_word], 'known_emb')          # (num known words, emb_dim)
      single_unk_emb = self.make_param('single_unk_emb', (config.emb_dim,), 'uniform')  # (emb_dim,)
      emb = tt.concatenate([known_emb, tt.shape_padleft(single_unk_emb)], axis=0)       # (num known words + 1, emb_dim)
    else:
      emb = get_shared_floatX(emb_val, 'emb')                                           # (voc size, emb_dim)

    trn_ctxs, trn_ctx_masks, trn_ctx_lens, trn_qtns, trn_qtn_masks, trn_qtn_lens, trn_qtn_ctx_idxs, \
      trn_anss, trn_ans_stts, trn_ans_ends = _gpu_dataset('trn', data.trn, config)

    dev_ctxs, dev_ctx_masks, dev_ctx_lens, dev_qtns, dev_qtn_masks, dev_qtn_lens, dev_qtn_ctx_idxs, \
      dev_anss, dev_ans_stts, dev_ans_ends = _gpu_dataset('dev', data.dev, config)

    tst_ctxs, tst_ctx_masks, tst_ctx_lens, tst_qtns, tst_qtn_masks, tst_qtn_lens, tst_qtn_ctx_idxs, \
      tst_anss, tst_ans_stts, tst_ans_ends = _gpu_dataset('tst', data.tst, config)

    ###################################################
    # Map input given to interface functions to an actual mini batch
    ###################################################

    qtn_idxs = tt.ivector('qtn_idxs')                           # (batch_size,)
    batch_size = qtn_idxs.size

    dataset_ctxs = tt.imatrix('dataset_ctxs')                   # (num contexts in dataset, max_p_len of dataset)
    dataset_ctx_masks = tt.imatrix('dataset_ctx_masks')         # (num contexts in dataset, max_p_len of dataset)
    dataset_ctx_lens = tt.ivector('dataset_ctx_lens')           # (num contexts in dataset,)
    dataset_qtns = tt.imatrix('dataset_qtns')                   # (num questions in dataset, max_q_len of dataset)
    dataset_qtn_masks = tt.imatrix('dataset_qtn_masks')         # (num questions in dataset, max_q_len of dataset)
    dataset_qtn_lens = tt.ivector('dataset_qtn_lens')           # (num questions in dataset,)
    dataset_qtn_ctx_idxs = tt.ivector('dataset_qtn_ctx_idxs')   # (num questions in dataset,)
    dataset_anss = tt.ivector('dataset_anss')                   # (num questions in dataset,)
    dataset_ans_stts = tt.ivector('dataset_ans_stts')           # (num questions in dataset,)
    dataset_ans_ends = tt.ivector('dataset_ans_ends')           # (num questions in dataset,)

    ctx_idxs = dataset_qtn_ctx_idxs[qtn_idxs]                   # (batch_size,)
    p_lens = dataset_ctx_lens[ctx_idxs]                         # (batch_size,)
    max_p_len = p_lens.max()
    p = dataset_ctxs[ctx_idxs][:,:max_p_len].T                  # (max_p_len, batch_size)
    p_mask = dataset_ctx_masks[ctx_idxs][:,:max_p_len].T        # (max_p_len, batch_size)
    float_p_mask = cast_floatX(p_mask)

    q_lens = dataset_qtn_lens[qtn_idxs]                         # (batch_size,)
    max_q_len = q_lens.max()
    q = dataset_qtns[qtn_idxs][:,:max_q_len].T                  # (max_q_len, batch_size)
    q_mask = dataset_qtn_masks[qtn_idxs][:,:max_q_len].T        # (max_q_len, batch_size)
    float_q_mask = cast_floatX(q_mask)

    a = dataset_anss[qtn_idxs]                                  # (batch_size,)
    a_stt = dataset_ans_stts[qtn_idxs]                          # (batch_size,)
    a_end = dataset_ans_ends[qtn_idxs]                          # (batch_size,)

    ###################################################
    # RaSoR
    ###################################################

    ff_dim = config.ff_dims[-1]

    p_emb = emb[p]        # (max_p_len, batch_size, emb_dim)
    q_emb = emb[q]        # (max_q_len, batch_size, emb_dim)

    p_star_parts = [p_emb]
    p_star_dim = config.emb_dim

    ############ q indep

    if config.ablation in [None, 'only_q_indep']:

      # (max_q_len, batch_size, 2*hidden_dim)
      q_indep_h = self.stacked_bi_lstm('q_indep_lstm', q_emb, float_q_mask,
        config.num_bilstm_layers, config.emb_dim, config.hidden_dim,
        config.lstm_drop_x, config.lstm_drop_h,
        couple_i_and_f = config.lstm_couple_i_and_f,
        learn_initial_state = config.lstm_learn_initial_state,
        tie_x_dropout = config.lstm_tie_x_dropout,
        sep_x_dropout = config.lstm_sep_x_dropout,
        sep_h_dropout = config.lstm_sep_h_dropout,
        w_init = config.lstm_w_init,
        u_init = config.lstm_u_init,
        forget_bias_init = config.lstm_forget_bias_init,
        other_bias_init = config.default_bias_init)

      # (max_q_len, batch_size, ff_dim)     # contains junk where masked
      q_indep_ff = self.ff('q_indep_ff', q_indep_h, [2*config.hidden_dim] + config.ff_dims,
        'relu', config.ff_drop_x, bias_init=config.default_bias_init)
      if config.extra_drop_x:
        q_indep_ff = self.dropout(q_indep_ff, config.extra_drop_x)
      w_q = self.make_param('w_q', (ff_dim,), 'uniform')
      q_indep_scores = tt.dot(q_indep_ff, w_q)                                    # (max_q_len, batch_size)
      q_indep_weights = softmax_columns_with_mask(q_indep_scores, float_q_mask)   # (max_q_len, batch_size)
      q_indep = tt.sum(tt.shape_padright(q_indep_weights) * q_indep_h, axis=0)    # (batch_size, 2*hidden_dim)
      q_indep_repeated = tt.extra_ops.repeat(                                     # (max_p_len, batch_size, 2*hidden_dim)
        tt.shape_padleft(q_indep), max_p_len, axis=0)

      p_star_parts.append(q_indep_repeated)
      p_star_dim += 2 * config.hidden_dim
    
    ############ q aligned

    if config.ablation in [None, 'only_q_align']:

      if config.q_aln_ff_tie:
        q_align_ff_p_name = q_align_ff_q_name = 'q_align_ff'
      else:
        q_align_ff_p_name = 'q_align_ff_p'
        q_align_ff_q_name = 'q_align_ff_q'
      # (max_p_len, batch_size, ff_dim)     # contains junk where masked
      q_align_ff_p = self.ff(q_align_ff_p_name, p_emb, [config.emb_dim] + config.ff_dims,
        'relu', config.ff_drop_x, bias_init=config.default_bias_init)
      # (max_q_len, batch_size, ff_dim)     # contains junk where masked
      q_align_ff_q = self.ff(q_align_ff_q_name, q_emb, [config.emb_dim] + config.ff_dims,
        'relu', config.ff_drop_x, bias_init=config.default_bias_init)

      # http://deeplearning.net/software/theano/library/tensor/basic.html#theano.tensor.batched_dot
      # https://groups.google.com/d/msg/theano-users/yBh27AJGq2E/vweiLoXADQAJ
      q_align_ff_p_shuffled = q_align_ff_p.dimshuffle((1,0,2))                    # (batch_size, max_p_len, ff_dim)
      q_align_ff_q_shuffled = q_align_ff_q.dimshuffle((1,2,0))                    # (batch_size, ff_dim, max_q_len)
      q_align_scores = batched_dot(q_align_ff_p_shuffled, q_align_ff_q_shuffled)  # (batch_size, max_p_len, max_q_len)

      p_mask_shuffled = float_p_mask.dimshuffle((1,0,'x'))                        # (batch_size, max_p_len, 1)
      q_mask_shuffled = float_q_mask.dimshuffle((1,'x',0))                        # (batch_size, 1, max_q_len)
      pq_mask = p_mask_shuffled * q_mask_shuffled                                 # (batch_size, max_p_len, max_q_len)

      q_align_weights = softmax_depths_with_mask(q_align_scores, pq_mask)         # (batch_size, max_p_len, max_q_len)
      q_emb_shuffled = q_emb.dimshuffle((1,0,2))                                  # (batch_size, max_q_len, emb_dim)
      q_align = batched_dot(q_align_weights, q_emb_shuffled)                      # (batch_size, max_p_len, emb_dim)
      q_align_shuffled = q_align.dimshuffle((1,0,2))                              # (max_p_len, batch_size, emb_dim)
    
      p_star_parts.append(q_align_shuffled)
      p_star_dim += config.emb_dim

    ############ passage-level bi-lstm

    p_star = tt.concatenate(p_star_parts, axis=2)     # (max_p_len, batch_size, p_star_dim)

    # (max_p_len, batch_size, 2*hidden_dim)
    p_level_h = self.stacked_bi_lstm('p_level_lstm', p_star, float_p_mask,
      config.num_bilstm_layers, p_star_dim, config.hidden_dim,
      config.lstm_drop_x, config.lstm_drop_h,
      couple_i_and_f = config.lstm_couple_i_and_f,
      learn_initial_state = config.lstm_learn_initial_state,
      tie_x_dropout = config.lstm_tie_x_dropout,
      sep_x_dropout = config.lstm_sep_x_dropout,
      sep_h_dropout = config.lstm_sep_h_dropout,
      w_init = config.lstm_w_init,
      u_init = config.lstm_u_init,
      forget_bias_init = config.lstm_forget_bias_init,
      other_bias_init = config.default_bias_init)

    if config.sep_stt_end_drop:
      p_level_h_for_stt = self.dropout(p_level_h, config.ff_drop_x)
      p_level_h_for_end = self.dropout(p_level_h, config.ff_drop_x)
    else:
      p_level_h_for_stt = p_level_h_for_end = self.dropout(p_level_h, config.ff_drop_x)

    # Having a single FF hidden layer allows computing the FF over the concatenation of the
    # span-start and span-end hidden states by applying the linear transformation to each of
    # them separately rather than to their concatenation.
    assert len(config.ff_dims) == 1

    if config.objective in ['span_multinomial', 'span_binary']:

      ############ scores

      p_stt_lin = self.linear(                              # (max_p_len, batch_size, ff_dim)
        'p_stt_lin', p_level_h_for_stt, 2*config.hidden_dim, ff_dim, bias_init=config.default_bias_init)
      p_end_lin = self.linear(                              # (max_p_len, batch_size, ff_dim)
        'p_end_lin', p_level_h_for_end, 2*config.hidden_dim, ff_dim, with_bias=False)

      # (batch_size, max_p_len*max_ans_len, ff_dim), (batch_size, max_p_len*max_ans_len)
      span_lin_reshaped, span_masks_reshaped = _span_sums(
        p_stt_lin, p_end_lin, p_lens, max_p_len, batch_size, ff_dim, config.max_ans_len)

      span_ff_reshaped = tt.nnet.relu(span_lin_reshaped)    # (batch_size, max_p_len*max_ans_len, ff_dim)
      w_a = self.make_param('w_a', (ff_dim,), 'uniform')
      span_scores_reshaped = tt.dot(span_ff_reshaped, w_a)  # (batch_size, max_p_len*max_ans_len)

      ############ classification

      classification_func = _span_multinomial_classification if config.objective == 'span_multinomial' else \
        _span_binary_classification
      # (batch_size,), (batch_size), (batch_size,)
      xents, accs, a_hats = classification_func(span_scores_reshaped, span_masks_reshaped, a)
      loss = xents.mean()
      acc = accs.mean()
      # (batch_size,), (batch_size)
      ans_hat_start_word_idxs, ans_hat_end_word_idxs = _tt_ans_idx_to_ans_word_idxs(a_hats, config.max_ans_len)

    elif config.objective == 'span_endpoints':

      ############ scores

      # note that dropout was already applied when assigning to p_level_h_for_stt/end
      p_stt_ff = self.ff(                                                 # (max_p_len, batch_size, ff_dim)
        'p_stt_ff', p_level_h_for_stt, [2*config.hidden_dim] + [ff_dim],
        'relu', dropout_ps=None, bias_init=config.default_bias_init)
      p_end_ff = self.ff(                                                 # (max_p_len, batch_size, ff_dim)
        'p_end_ff', p_level_h_for_end, [2*config.hidden_dim] + [ff_dim],
        'relu', dropout_ps=None, bias_init=config.default_bias_init)

      w_a_stt = self.make_param('w_a_stt', (ff_dim,), 'uniform')
      w_a_end = self.make_param('w_a_end', (ff_dim,), 'uniform')
      word_stt_scores = tt.dot(p_stt_ff, w_a_stt)                         # (max_p_len, batch_size)
      word_end_scores = tt.dot(p_end_ff, w_a_end)                         # (max_p_len, batch_size)

      ############ classification

      stt_log_probs, stt_xents = _word_multinomial_classification(        # (batch_size, max_p_len), (batch_size,)
        word_stt_scores.T, float_p_mask.T, a_stt)
      end_log_probs, end_xents = _word_multinomial_classification(        # (batch_size, max_p_len), (batch_size,)
        word_end_scores.T, float_p_mask.T, a_end)

      xents = stt_xents + end_xents                                       # (batch_size,)
      loss = xents.mean()

      ############ finding highest P(span) = P(span start) * P(span end)

      end_log_probs = end_log_probs.dimshuffle((1,0,'x'))                 # (max_p_len, batch_size, 1)
      stt_log_probs = stt_log_probs.dimshuffle((1,0,'x'))                 # (max_p_len, batch_size, 1)
      # (batch_size, max_p_len*max_ans_len, 1), (batch_size, max_p_len*max_ans_len)
      span_log_probs_reshaped, span_masks_reshaped = _span_sums(
        stt_log_probs, end_log_probs, p_lens, max_p_len, batch_size, 1, config.max_ans_len)

      span_log_probs_reshaped = span_log_probs_reshaped.reshape(          # (batch_size, max_p_len*max_ans_len)
        (batch_size, max_p_len*config.max_ans_len))
      a_hats = argmax_with_mask(                                          # (batch_size,)
        span_log_probs_reshaped, span_masks_reshaped)
      accs = cast_floatX(tt.eq(a_hats, a))                                # (batch_size,)

      acc = accs.mean()
      # (batch_size,), (batch_size)
      ans_hat_start_word_idxs, ans_hat_end_word_idxs = _tt_ans_idx_to_ans_word_idxs(a_hats, config.max_ans_len)

    else:
      raise AssertionError('unsupported objective')

    ############ optimization

    opt = AdamOptimizer(config, loss, self._params.values())
    updates = opt.get_updates()
    global_grad_norm = opt.get_global_grad_norm()
    self.get_lr_value = lambda: opt.get_lr_value()

    ############ interface

    trn_givens = {
      self._is_training : np.int32(1), 
      dataset_ctxs: trn_ctxs,
      dataset_ctx_masks: trn_ctx_masks,
      dataset_ctx_lens: trn_ctx_lens,
      dataset_qtns: trn_qtns,
      dataset_qtn_masks: trn_qtn_masks,
      dataset_qtn_lens: trn_qtn_lens,
      dataset_qtn_ctx_idxs: trn_qtn_ctx_idxs,
      dataset_anss: trn_anss,
      dataset_ans_stts: trn_ans_stts,
      dataset_ans_ends: trn_ans_ends}

    dev_givens = {
      self._is_training : np.int32(0), 
      dataset_ctxs: dev_ctxs,
      dataset_ctx_masks: dev_ctx_masks,
      dataset_ctx_lens: dev_ctx_lens,
      dataset_qtns: dev_qtns,
      dataset_qtn_masks: dev_qtn_masks,
      dataset_qtn_lens: dev_qtn_lens,
      dataset_qtn_ctx_idxs: dev_qtn_ctx_idxs,
      dataset_anss: dev_anss,
      dataset_ans_stts: dev_ans_stts,
      dataset_ans_ends: dev_ans_ends}

    tst_givens = {
      self._is_training : np.int32(0), 
      dataset_ctxs: tst_ctxs,
      dataset_ctx_masks: tst_ctx_masks,
      dataset_ctx_lens: tst_ctx_lens,
      dataset_qtns: tst_qtns,
      dataset_qtn_masks: tst_qtn_masks,
      dataset_qtn_lens: tst_qtn_lens,
      dataset_qtn_ctx_idxs: tst_qtn_ctx_idxs}
      #dataset_anss: tst_anss,
      #dataset_ans_stts: tst_ans_stts,
      #dataset_ans_ends: tst_ans_ends}

    self.train = theano.function(
      [qtn_idxs],
      [loss, acc, global_grad_norm],
      givens = trn_givens,
      updates = updates,
      on_unused_input = 'ignore')
      #mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=True))

    self.eval_dev = theano.function(
      [qtn_idxs],
      [loss, acc, ans_hat_start_word_idxs, ans_hat_end_word_idxs],
      givens = dev_givens,
      updates = None,
      on_unused_input = 'ignore')

    self.eval_tst = theano.function(
      [qtn_idxs],
      [ans_hat_start_word_idxs, ans_hat_end_word_idxs],
      givens = tst_givens,
      updates = None,
      on_unused_input = 'ignore')
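A hypothetical driver loop for the compiled interface of example #6 (num_trn_qtns, num_epochs, and batch_size are assumptions, not fields of the real config): the compiled functions take only a vector of question indices, because the datasets themselves are bound through `givens`.

import numpy as np

rng = np.random.RandomState(0)
qtn_ids = np.arange(num_trn_qtns, dtype=np.int32)   # number of training questions (assumed)
for epoch in range(num_epochs):                     # assumed epoch count
    rng.shuffle(qtn_ids)
    for i in range(0, len(qtn_ids), batch_size):    # assumed batch size
        batch = qtn_ids[i:i + batch_size]
        loss_val, acc_val, grad_norm_val = model.train(batch)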
Code example #7
File: model.py, Project: Levstyle/RaSoR
    def __init__(self, config, data):
        self.init_start(config)
        # cuda optimized batched dot product
        batched_dot = tt.batched_dot if config.device == 'cpu' else theano.sandbox.cuda.blas.batched_dot

        ###################################################
        # Load all data onto GPU
        ###################################################

        emb_val = data.word_emb_data.word_emb  # (voc size, emb_dim)
        first_known_word = data.word_emb_data.first_known_word
        assert config.emb_dim == emb_val.shape[1]
        assert first_known_word > 0
        emb_val[:first_known_word] = 0
        if config.learn_single_unk:
            first_unknown_word = data.word_emb_data.first_unknown_word
            known_emb = get_shared_floatX(
                emb_val[:first_unknown_word],
                'known_emb')  # (num known words, emb_dim)
            single_unk_emb = self.make_param('single_unk_emb',
                                             (config.emb_dim, ),
                                             'uniform')  # (emb_dim,)
            emb = tt.concatenate(
                [known_emb, tt.shape_padleft(single_unk_emb)],
                axis=0)  # (num known words + 1, emb_dim)
        else:
            emb = get_shared_floatX(emb_val, 'emb')  # (voc size, emb_dim)

        if config.is_train:
            trn_ds_vec = data.trn.vectorized
            trn_ctxs, trn_ctx_masks, trn_ctx_lens = _gpu_sequences(
                'trn_ctxs', trn_ds_vec.ctxs, trn_ds_vec.ctx_lens)
            trn_qtns, trn_qtn_masks, trn_qtn_lens = _gpu_sequences(
                'trn_qtns', trn_ds_vec.qtns, trn_ds_vec.qtn_lens)
            trn_anss = _gpu_trn_answers('trn_anss', trn_ds_vec.anss,
                                        trn_ds_vec.num_anss,
                                        config.max_ans_len)

            dev_ds_vec = data.dev.vectorized
            dev_ctxs, dev_ctx_masks, dev_ctx_lens = _gpu_sequences(
                'dev_ctxs', dev_ds_vec.ctxs, dev_ds_vec.ctx_lens)
            dev_qtns, dev_qtn_masks, dev_qtn_lens = _gpu_sequences(
                'dev_qtns', dev_ds_vec.qtns, dev_ds_vec.qtn_lens)
            dev_anss, dev_ans_masks = _gpu_dev_answers('dev_anss',
                                                       dev_ds_vec.anss,
                                                       dev_ds_vec.num_anss,
                                                       config.max_ans_len)
        else:
            tst_ds_vec = data.tst.vectorized
            tst_ctxs, tst_ctx_masks, tst_ctx_lens = _gpu_sequences(
                'tst_ctxs', tst_ds_vec.ctxs, tst_ds_vec.ctx_lens)
            tst_qtns, tst_qtn_masks, tst_qtn_lens = _gpu_sequences(
                'tst_qtns', tst_ds_vec.qtns, tst_ds_vec.qtn_lens)

        ###################################################
        # Map input given to interface functions to an actual mini batch
        ###################################################

        in_sample_idxs = tt.ivector('in_sample_idxs')  # (batch_size,)
        in_ctxs = tt.imatrix(
            'in_ctxs')  # (num samples in dataset, max_p_len of dataset)
        in_ctx_masks = tt.imatrix(
            'in_ctx_masks')  # (num samples in dataset, max_p_len of dataset)
        in_ctx_lens = tt.ivector('in_ctx_lens')  # (num samples in dataset,)
        in_qtns = tt.imatrix(
            'in_qtns')  # (num samples in dataset, max_q_len of dataset)
        in_qtn_masks = tt.imatrix(
            'in_qtn_masks')  # (num samples in dataset, max_q_len of dataset)
        in_qtn_lens = tt.ivector('in_qtn_lens')  # (num samples in dataset,)
        batch_size = in_sample_idxs.size

        p_lens = in_ctx_lens[in_sample_idxs]  # (batch_size,)
        max_p_len = p_lens.max()
        p = in_ctxs[in_sample_idxs][:, :max_p_len].T  # (max_p_len, batch_size)
        p_mask = in_ctx_masks[
            in_sample_idxs][:, :max_p_len].T  # (max_p_len, batch_size)
        float_p_mask = cast_floatX(p_mask)

        q_lens = in_qtn_lens[in_sample_idxs]  # (batch_size,)
        max_q_len = q_lens.max()
        q = in_qtns[in_sample_idxs][:, :max_q_len].T  # (max_q_len, batch_size)
        q_mask = in_qtn_masks[
            in_sample_idxs][:, :max_q_len].T  # (max_q_len, batch_size)
        float_q_mask = cast_floatX(q_mask)

        if config.is_train:
            in_trn_anss = tt.ivector(
                'in_trn_anss')  # (num samples in train dataset,)
            in_dev_anss = tt.imatrix(
                'in_dev_anss'
            )  # (num samples in dev dataset, max_num_ans of dev dataset)
            in_dev_ans_masks = tt.imatrix(
                'in_dev_ans_masks'
            )  # (num samples in dev dataset, max_num_ans of dev dataset)

            trn_a = in_trn_anss[in_sample_idxs]  # (batch_size,)
            dev_a = in_dev_anss[in_sample_idxs]  # (batch_size, max_num_ans)
            dev_a_mask = in_dev_ans_masks[
                in_sample_idxs]  # (batch_size, max_num_ans)

        ###################################################
        # RaSoR
        ###################################################

        ############ embed words

        p_emb = emb[p]  # (max_p_len, batch_size, emb_dim)
        q_emb = emb[q]  # (max_q_len, batch_size, emb_dim)

        ############ q indep

        # (max_q_len, batch_size, 2*hidden_dim)
        q_indep_h = self.stacked_bi_lstm(
            'q_indep_lstm',
            q_emb,
            float_q_mask,
            config.num_bilstm_layers,
            config.emb_dim,
            config.hidden_dim,
            config.lstm_drop_x,
            config.lstm_drop_h,
            couple_i_and_f=config.lstm_couple_i_and_f,
            learn_initial_state=config.lstm_learn_initial_state,
            tie_x_dropout=config.lstm_tie_x_dropout,
            sep_x_dropout=config.lstm_sep_x_dropout,
            sep_h_dropout=config.lstm_sep_h_dropout,
            w_init=config.lstm_w_init,
            u_init=config.lstm_u_init,
            forget_bias_init=config.lstm_forget_bias_init,
            other_bias_init=config.default_bias_init)

        ff_dim = config.ff_dims[-1]
        # (max_q_len, batch_size, ff_dim)     # contains junk where masked
        q_indep_ff = self.ff('q_indep_ff',
                             q_indep_h,
                             [2 * config.hidden_dim] + config.ff_dims,
                             'relu',
                             config.ff_drop_x,
                             bias_init=config.default_bias_init)
        if config.extra_drop_x:
            q_indep_ff = self.dropout(q_indep_ff, config.extra_drop_x)
        w_q = self.make_param('w_q', (ff_dim, ), 'uniform')
        q_indep_scores = tt.dot(q_indep_ff, w_q)  # (max_q_len, batch_size)
        q_indep_weights = softmax_columns_with_mask(
            q_indep_scores, float_q_mask)  # (max_q_len, batch_size)
        q_indep = tt.sum(tt.shape_padright(q_indep_weights) * q_indep_h,
                         axis=0)  # (batch_size, 2*hidden_dim)

        ############ q aligned

        if config.q_aln_ff_tie:
            q_align_ff_p_name = q_align_ff_q_name = 'q_align_ff'
        else:
            q_align_ff_p_name = 'q_align_ff_p'
            q_align_ff_q_name = 'q_align_ff_q'
        # (max_p_len, batch_size, ff_dim)     # contains junk where masked
        q_align_ff_p = self.ff(q_align_ff_p_name,
                               p_emb, [config.emb_dim] + config.ff_dims,
                               'relu',
                               config.ff_drop_x,
                               bias_init=config.default_bias_init)
        # (max_q_len, batch_size, ff_dim)     # contains junk where masked
        q_align_ff_q = self.ff(q_align_ff_q_name,
                               q_emb, [config.emb_dim] + config.ff_dims,
                               'relu',
                               config.ff_drop_x,
                               bias_init=config.default_bias_init)

        # http://deeplearning.net/software/theano/library/tensor/basic.html#theano.tensor.batched_dot
        # https://groups.google.com/d/msg/theano-users/yBh27AJGq2E/vweiLoXADQAJ
        q_align_ff_p_shuffled = q_align_ff_p.dimshuffle(
            (1, 0, 2))  # (batch_size, max_p_len, ff_dim)
        q_align_ff_q_shuffled = q_align_ff_q.dimshuffle(
            (1, 2, 0))  # (batch_size, ff_dim, max_q_len)
        q_align_scores = batched_dot(
            q_align_ff_p_shuffled,
            q_align_ff_q_shuffled)  # (batch_size, max_p_len, max_q_len)

        p_mask_shuffled = float_p_mask.dimshuffle(
            (1, 0, 'x'))  # (batch_size, max_p_len, 1)
        q_mask_shuffled = float_q_mask.dimshuffle(
            (1, 'x', 0))  # (batch_size, 1, max_q_len)
        pq_mask = p_mask_shuffled * q_mask_shuffled  # (batch_size, max_p_len, max_q_len)

        q_align_weights = softmax_depths_with_mask(
            q_align_scores, pq_mask)  # (batch_size, max_p_len, max_q_len)
        q_emb_shuffled = q_emb.dimshuffle(
            (1, 0, 2))  # (batch_size, max_q_len, emb_dim)
        q_align = batched_dot(
            q_align_weights,
            q_emb_shuffled)  # (batch_size, max_p_len, emb_dim)

        ############ p star

        q_align_shuffled = q_align.dimshuffle(
            (1, 0, 2))  # (max_p_len, batch_size, emb_dim)
        q_indep_repeated = tt.extra_ops.repeat(  # (max_p_len, batch_size, 2*hidden_dim)
            tt.shape_padleft(q_indep),
            max_p_len,
            axis=0)
        p_star = tt.concatenate(  # (max_p_len, batch_size, 2*emb_dim + 2*hidden_dim)
            [p_emb, q_align_shuffled, q_indep_repeated],
            axis=2)

        ############ passage-level bi-lstm

        # (max_p_len, batch_size, 2*hidden_dim)
        p_level_h = self.stacked_bi_lstm(
            'p_level_lstm',
            p_star,
            float_p_mask,
            config.num_bilstm_layers,
            2 * config.emb_dim + 2 * config.hidden_dim,
            config.hidden_dim,
            config.lstm_drop_x,
            config.lstm_drop_h,
            couple_i_and_f=config.lstm_couple_i_and_f,
            learn_initial_state=config.lstm_learn_initial_state,
            tie_x_dropout=config.lstm_tie_x_dropout,
            sep_x_dropout=config.lstm_sep_x_dropout,
            sep_h_dropout=config.lstm_sep_h_dropout,
            w_init=config.lstm_w_init,
            u_init=config.lstm_u_init,
            forget_bias_init=config.lstm_forget_bias_init,
            other_bias_init=config.default_bias_init)

        ############ span scores

        if config.sep_stt_end_drop:
            p_level_h_for_stt = self.dropout(p_level_h, config.ff_drop_x)
            p_level_h_for_end = self.dropout(p_level_h, config.ff_drop_x)
        else:
            p_level_h_for_stt = p_level_h_for_end = self.dropout(
                p_level_h, config.ff_drop_x)

        # Having a single FF hidden layer allows computing the FF over the concatenation of the
        # span-start and span-end hidden states by applying the linear transformation to each of
        # them separately (more efficient).
        assert len(config.ff_dims) == 1
        # (max_p_len, batch_size, ff_dim)
        p_stt = self.linear('p_stt',
                            p_level_h_for_stt,
                            2 * config.hidden_dim,
                            ff_dim,
                            bias_init=config.default_bias_init)
        # (max_p_len, batch_size, ff_dim)
        p_end = self.linear('p_end',
                            p_level_h_for_end,
                            2 * config.hidden_dim,
                            ff_dim,
                            with_bias=False)

        p_end_zero_padded = tt.concatenate(  # (max_p_len+max_ans_len-1, batch_size, ff_dim)
            [p_end,
             tt.zeros((config.max_ans_len - 1, batch_size, ff_dim))],
            axis=0)
        p_max_ans_len_range = tt.shape_padleft(  # (1, max_ans_len)
            tt.arange(config.max_ans_len))
        p_offsets = tt.shape_padright(tt.arange(max_p_len))  # (max_p_len, 1)
        p_end_idxs = p_max_ans_len_range + p_offsets  # (max_p_len, max_ans_len)
        p_end_idxs_flat = p_end_idxs.flatten()  # (max_p_len*max_ans_len,)

        p_ends = p_end_zero_padded[
            p_end_idxs_flat]  # (max_p_len*max_ans_len, batch_size, ff_dim)
        p_ends = p_ends.reshape(  # (max_p_len, max_ans_len, batch_size, ff_dim)
            (max_p_len, config.max_ans_len, batch_size, ff_dim))

        p_stt_shuffled = p_stt.dimshuffle(
            (0, 'x', 1, 2))  # (max_p_len, 1, batch_size, ff_dim)

        p_stt_end_lin = p_stt_shuffled + p_ends  # (max_p_len, max_ans_len, batch_size, ff_dim)
        p_stt_end = tt.nnet.relu(
            p_stt_end_lin)  # (max_p_len, max_ans_len, batch_size, ff_dim)

        w_a = self.make_param('w_a', (ff_dim, ), 'uniform')
        span_scores = tt.dot(p_stt_end,
                             w_a)  # (max_p_len, max_ans_len, batch_size)

        ############ span masks

        p_lens_shuffled = p_lens.dimshuffle('x', 'x', 0)  # (1, 1, batch_size)
        p_end_idxs_shuffled = p_end_idxs.dimshuffle(
            0, 1, 'x')  # (max_p_len, max_ans_len, 1)
        span_masks = tt.lt(
            p_end_idxs_shuffled,
            p_lens_shuffled)  # (max_p_len, max_ans_len, batch_size)

        span_scores_reshaped = span_scores.dimshuffle(
            (2, 0, 1)).reshape(  # (batch_size, max_p_len*max_ans_len)
                (batch_size, -1))
        span_masks_reshaped = span_masks.dimshuffle(
            (2, 0, 1)).reshape(  # (batch_size, max_p_len*max_ans_len)
                (batch_size, -1))
        span_masks_reshaped = cast_floatX(span_masks_reshaped)

        if config.is_train:

            ############ loss

            # (batch_size,), (batch_size), (batch_size,)
            trn_xents, trn_accs, trn_a_hat = _single_answer_classification(
                span_scores_reshaped, span_masks_reshaped, trn_a)
            trn_loss, trn_acc = trn_xents.mean(), trn_accs.mean()

            # (batch_size,), (batch_size), (batch_size,), (batch_size), (batch_size,)
            dev_min_xents, dev_prx_xents, dev_max_accs, dev_prx_accs, dev_a_hat = _multi_answer_classification(
                span_scores_reshaped, span_masks_reshaped, dev_a, dev_a_mask)
            dev_min_loss, dev_prx_loss, dev_max_acc, dev_prx_acc = \
              dev_min_xents.mean(), dev_prx_xents.mean(), dev_max_accs.mean(), dev_prx_accs.mean()
            dev_ans_hat_start_word_idxs, dev_ans_hat_end_word_idxs = \
              _tt_ans_idx_to_ans_word_idxs(dev_a_hat, config.max_ans_len)

            ############ optimization

            opt = AdamOptimizer(config, trn_loss, self._params.values())
            updates = opt.get_updates()
            trn_global_grad_norm = opt.get_global_grad_norm()
            self.get_lr_value = lambda: opt.get_lr_value()

            ############ interface

            train_givens = {
                self._is_training: np.int32(1),
                in_ctxs: trn_ctxs,
                in_ctx_masks: trn_ctx_masks,
                in_ctx_lens: trn_ctx_lens,
                in_qtns: trn_qtns,
                in_qtn_masks: trn_qtn_masks,
                in_qtn_lens: trn_qtn_lens,
                in_trn_anss: trn_anss
            }
            dev_givens = {
                self._is_training: np.int32(0),
                in_ctxs: dev_ctxs,
                in_ctx_masks: dev_ctx_masks,
                in_ctx_lens: dev_ctx_lens,
                in_qtns: dev_qtns,
                in_qtn_masks: dev_qtn_masks,
                in_qtn_lens: dev_qtn_lens,
                in_dev_anss: dev_anss,
                in_dev_ans_masks: dev_ans_masks
            }

            self.train = theano.function(
                [in_sample_idxs], [trn_loss, trn_acc, trn_global_grad_norm],
                givens=train_givens,
                updates=updates)
            #mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=True))

            self.eval_dev = theano.function([in_sample_idxs], [
                dev_min_loss, dev_prx_loss, dev_max_acc, dev_prx_acc,
                dev_ans_hat_start_word_idxs, dev_ans_hat_end_word_idxs
            ],
                                            givens=dev_givens,
                                            updates=None)

        else:  # config.is_train is False

            tst_a_hat = _no_answer_classification(
                span_scores_reshaped, span_masks_reshaped)  # (batch_size,)
            tst_ans_hat_start_word_idxs, tst_ans_hat_end_word_idxs = _tt_ans_idx_to_ans_word_idxs(
                tst_a_hat, config.max_ans_len)

            tst_givens = {
                self._is_training: np.int32(0),
                in_ctxs: tst_ctxs,
                in_ctx_masks: tst_ctx_masks,
                in_ctx_lens: tst_ctx_lens,
                in_qtns: tst_qtns,
                in_qtn_masks: tst_qtn_masks,
                in_qtn_lens: tst_qtn_lens
            }

            self.eval_tst = theano.function(
                [in_sample_idxs],
                [tst_ans_hat_start_word_idxs, tst_ans_hat_end_word_idxs],
                givens=tst_givens,
                updates=None)
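A small self-contained NumPy illustration (not part of the repository) of the span-indexing trick used in example #7: every passage position i is paired with candidate end offsets 0..max_ans_len-1, and indices that run past the passage fall into the zero padding of length max_ans_len-1 that was concatenated onto p_end.

import numpy as np

max_p_len, max_ans_len = 5, 3
p_end_idxs = np.arange(max_ans_len)[None, :] + np.arange(max_p_len)[:, None]
print(p_end_idxs)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]
#  [4 5 6]]
# Rows index span starts, columns index answer lengths; entries 5 and 6 point into
# the zero padding, and the corresponding spans are masked out via span_masks.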