def __init__(self, args, max_stack_depth, bracket_types):
    self.args = args
    self.max_stack_depth = max_stack_depth
    self.bracket_types = bracket_types
    self.vocab, self.ids = utils.get_vocab_of_bracket_types(bracket_types)
    self.vocab_list = sorted(self.vocab.keys(), key=lambda x: self.vocab[x])
    self.distributions = {}
Example #2
  def __init__(self, args):
    self.args = args
    self.observation_class = namedtuple('observation', ['sentence'])
    self.vocab, _ = utils.get_vocab_of_bracket_types(args['language']['bracket_types'])
    args['language']['vocab_size'] = len(self.vocab)
    self.batch_size = args['training']['batch_size']

    train_dataset_path = utils.get_corpus_paths_of_args(args)['train']
    dev_dataset_path = utils.get_corpus_paths_of_args(args)['dev']
    test_dataset_path = utils.get_corpus_paths_of_args(args)['test']

    self.train_dataset = ObservationIterator(self.load_tokenized_dataset(train_dataset_path, 'train'))
    self.dev_dataset = ObservationIterator(self.load_tokenized_dataset(dev_dataset_path, 'dev'))
    self.test_dataset = ObservationIterator(self.load_tokenized_dataset(test_dataset_path, 'test'))
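
The ObservationIterator class used above is not shown in this example. Below is a minimal sketch of what such a wrapper might look like, assuming it simply holds a list of observation tuples in memory and exposes them as a PyTorch Dataset; the actual class in the source repository may differ.

from torch.utils.data import Dataset

class ObservationIterator(Dataset):
  """Hypothetical minimal wrapper around a list of observations."""

  def __init__(self, observations):
    self.observations = observations

  def __len__(self):
    return len(self.observations)

  def __getitem__(self, idx):
    return self.observations[idx]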
Example #3
def get_dyckkm_simplernn_mlogk_params(k, m):
    """
  Returns Simple RNN parameters for O(m\log k) construction
  generating Dyck-(k,m)
  """
    assert math.log2(k).is_integer()
    vocab = utils.get_vocab_of_bracket_types(k)[0]
    print(vocab)
    vocab_size = len(vocab)
    open_bracket_indices = list(range(0, k))
    close_bracket_indices = list(range(k, 2 * k))
    end_symbol_index = 2 * k
    slot_size = 3 * (int(math.log2(k))) - 1
    #slot_size = k
    print('Slot size', 2 * slot_size)
    hidden_size = slot_size * m
    print('hidden size', 2 * hidden_size)
    num_stack_states = m

    # Vocabulary
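    # (Note: torch.eye gives one-hot input embeddings; the END symbol, index 2*k,
    # is presumably output-only here, so it gets no embedding row.)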
    embedding_weights = torch.eye(vocab_size - 1, vocab_size - 1)

    # W (recurrent matrix)
    matrixDown = torch.FloatTensor(
        [[1 if x == y - slot_size else 0 for x in range(slot_size * m)]
         for y in range(slot_size * m)])
    matrixUp = torch.FloatTensor(
        [[1 if x == y + slot_size else 0 for x in range(slot_size * m)]
         for y in range(slot_size * m)])
    Wtop = torch.cat((matrixDown, matrixDown), 1)
    Wbottom = torch.cat((matrixUp, matrixUp), 1)
    W = 2 * torch.cat((Wtop, Wbottom), 0)
    print('W', W.shape, W)

    # U (input matrix)
    efficient_embeddings, efficient_softmaxes = get_efficient_ctilde_u_mtx(k)
    one_slot_k = torch.ones(slot_size, k)
    zero_mslot_k = torch.zeros((m - 1) * slot_size, k)
    zero_slot_k = torch.zeros(slot_size, k)
    one_mslot_k = torch.ones((m - 1) * slot_size, k)
    U = torch.cat((
        torch.cat((2 * efficient_embeddings, -2 * one_slot_k), 1),
        torch.cat((zero_mslot_k, -2 * one_mslot_k), 1),
        torch.cat((-2 * one_slot_k, zero_slot_k), 1),
        torch.cat((-2 * one_mslot_k, zero_mslot_k), 1),
    ), 0)
    print('U', U)

    # b (bias)
    bias = -torch.ones(hidden_size * 2)
    print('b', bias)

    # softmax
    softmax_mtx = torch.zeros(len(vocab), hidden_size).float()
    softmax_bias = torch.zeros(len(vocab))

    # close-i
    for i, index_of_i in enumerate(close_bracket_indices):
        softmax_mtx[index_of_i, 0:slot_size] = efficient_softmaxes[i]
        #softmax_mtx[index_of_i,i] = 1
        softmax_bias[index_of_i] = -0.5

    # open-i
    for i, index_of_i in enumerate(open_bracket_indices):
        for j in range(slot_size):
            softmax_mtx[index_of_i, slot_size * (m - 1) + j] = -1
        softmax_bias[index_of_i] = 0.5

    # end
    for i in range(slot_size):
        for j in range(num_stack_states):
            softmax_mtx[end_symbol_index, slot_size * j + i] = -1
    softmax_bias[end_symbol_index] = 0.5
    softmax_mtx = torch.cat((softmax_mtx, softmax_mtx), 1)
    #softmax_bias = torch.cat((softmax_bias, softmax_bias),0)
    print('softmax')
    print('softmax mtx', softmax_mtx)
    print('softmax bias', softmax_bias)

    return (embedding_weights, 1e4 * W, 1e4 * U, 1e4 * bias, 1e4 * softmax_mtx,
            1e4 * softmax_bias)
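
A usage sketch, illustrative only and not the repository's own loading code: the returned tensors are shaped so that they could be copied into a single-layer torch.nn.RNN plus a linear decoder. The nonlinearity and the decision to place the single recurrent bias entirely in bias_ih_l0 are assumptions made here.

import torch

k, m = 4, 3
emb, W, U, b, softmax_mtx, softmax_bias = get_dyckkm_simplernn_mlogk_params(k, m)

rnn = torch.nn.RNN(input_size=U.shape[1], hidden_size=W.shape[0], batch_first=True)
decoder = torch.nn.Linear(W.shape[0], softmax_mtx.shape[0])
with torch.no_grad():
    rnn.weight_ih_l0.copy_(U)   # input-to-hidden, shape (2*hidden_size, 2k)
    rnn.weight_hh_l0.copy_(W)   # hidden-to-hidden, shape (2*hidden_size, 2*hidden_size)
    rnn.bias_ih_l0.copy_(b)     # single bias placed here (assumption)
    rnn.bias_hh_l0.zero_()
    decoder.weight.copy_(softmax_mtx)
    decoder.bias.copy_(softmax_bias)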
Example #4
def get_dyckkm_lstm_mlogk_params(k=4, m=3):
  """
  Returns LSTM parameters for the O(m log k) construction
  generating Dyck-(k,m).
  """
  assert math.log2(k).is_integer()
  vocab = utils.get_vocab_of_bracket_types(k)[0]
  vocab_size = len(vocab)
  open_bracket_indices = list(range(0,k))
  close_bracket_indices = list(range(k,2*k))
  end_symbol_index = 2*k
  slot_size = 3*(int(math.log2(k)))-1
  print('Slot size', slot_size)
  hidden_size = slot_size*m
  print('hidden size', hidden_size)
  num_stack_states = m

  # Vocabulary
  embedding_weights = torch.eye(vocab_size, vocab_size)
  #print(embedding_weights, embedding_weights.shape)

  # Input gate
  input_gate_hi_mtx = torch.stack(tuple((sum([embedding_weights[i] for i in open_bracket_indices]) for _ in range(hidden_size))))
  input_gate_hh_rows = []
  input_gate_bias_vals = []
  for i in range(m):
    if i == 0:
      neg_ones = -1*torch.ones(slot_size*num_stack_states)
      row = neg_ones
    else:
      ones = torch.ones(slot_size)
      neg_ones = -1*ones
      zeros_initial = torch.zeros(max(slot_size*(i-1),0))
      zeros_final = torch.zeros(max(slot_size*(num_stack_states-2-max(i-1,0)),0))
      row = torch.cat((zeros_initial, ones, neg_ones, zeros_final))
    input_gate_hh_rows.extend([row for _ in range(slot_size)])
  #for thing in input_gate_hh_rows:
  #  print(len(thing), thing)
  #print([len(x) for x in input_gate_hh_rows])
  input_gate_hh_mtx = torch.stack(input_gate_hh_rows)
  input_gate_bias = torch.tensor([-0.5 if i < slot_size else -1.5 for i in range(hidden_size)])
  print('input')
  print('b_i', input_gate_bias)
  print('W_i', input_gate_hh_mtx)
  print('U_i', input_gate_hi_mtx)

  # Forget gate
  forget_gate_hi_mtx = -torch.stack(tuple((sum([embedding_weights[i] for i in close_bracket_indices]) for _ in range(hidden_size))))
  forget_gate_hh_rows = []
  for i in range(num_stack_states):
    neg_ones = -1*torch.ones(slot_size)
    zeros_initial = torch.zeros(slot_size*i)
    zeros_final = torch.zeros(max(slot_size*(num_stack_states-1-i), 0))
    row = torch.cat((zeros_initial, neg_ones, zeros_final))
    forget_gate_hh_rows.extend([row for _ in range(slot_size)])
  forget_gate_hh_mtx = torch.stack(forget_gate_hh_rows)
  forget_gate_bias = torch.ones(hidden_size)*0.5*TANH_OF_1 + 1
  print('forget')
  print('b_f', forget_gate_bias)
  print('W_f', forget_gate_hh_mtx)
  print('U_f', forget_gate_hi_mtx)

  # Output gate
  output_gate_hi_mtx = torch.stack(tuple((sum((embedding_weights[i] for i in close_bracket_indices)) for _ in range(hidden_size))))
  output_gate_hh_rows = []
  for i in range(num_stack_states):
    neg_max = -num_stack_states*torch.ones(slot_size*max(i-1, 0))
    neg_ones = -torch.ones(slot_size*min(i+1,2))
    zeros = torch.zeros(slot_size*(num_stack_states-i-1))
    row = torch.cat((neg_max, neg_ones, zeros))
    output_gate_hh_rows.extend([row for _ in range(slot_size)])
  #for row in output_gate_hh_rows:
  #  print(len(row), row)
  output_gate_hh_mtx = torch.stack(output_gate_hh_rows).transpose(0,1)
  output_gate_bias = 0.5*torch.ones(hidden_size)
  print('output')
  print('b_o', output_gate_bias)
  print('W_o', output_gate_hh_mtx)
  print('U_o', output_gate_hi_mtx)

  # New cell candidate
  #new_cell_hi_mtx = torch.stack(tuple(itertools.chain.from_iterable((embedding_weights[i] for i in open_bracket_indices) for x in range(num_stack_states))))
  efficient_embeddings, efficient_softmaxes = get_efficient_ctilde_u_mtx(k)
  new_cell_hi_mtx = torch.stack(tuple(itertools.chain.from_iterable(efficient_embeddings for x in range(num_stack_states))))
  new_cell_hh_mtx = torch.zeros(hidden_size, hidden_size)
  new_cell_bias = torch.zeros(hidden_size)
  print('new cell')
  print('b_\\tilde{c}', new_cell_bias, new_cell_bias.shape)
  print('W_\\tilde{c}', new_cell_hh_mtx, new_cell_hh_mtx.shape)
  print('U_\\tilde{c}', new_cell_hi_mtx, new_cell_hi_mtx.shape)

  # Softmax
  softmax_mtx = torch.zeros(len(vocab), hidden_size).float()
  softmax_bias = torch.zeros(len(vocab))

  # close-i
  for i,index_of_i in enumerate(close_bracket_indices):
    for j in range(num_stack_states):
      #softmax_mtx[index_of_i,k*j+ i] = 1
      softmax_mtx[index_of_i,slot_size*j:slot_size*(j+1)] = efficient_softmaxes[i]
    softmax_bias[index_of_i] = -TANH_OF_1*0.5
  # open-i
  for i, index_of_i in enumerate(open_bracket_indices):
    for j in range(slot_size):
      softmax_mtx[index_of_i,slot_size*(m-1)+j] = -1
    softmax_bias[index_of_i] = 0.5*TANH_OF_1
  # end
  for i in range(slot_size):
    for j in range(num_stack_states):
      softmax_mtx[end_symbol_index,slot_size*j+i] = -1
  softmax_bias[end_symbol_index]= 0.5*TANH_OF_1
  print('softmax')
  print('softmax_mtx', softmax_mtx)
  print('softmax_bias', softmax_bias)

  return (embedding_weights, 1e4*input_gate_hh_mtx, 1e4*input_gate_hi_mtx, 1e4*input_gate_bias,
      1e4*output_gate_hh_mtx, 1e4*output_gate_hi_mtx, 1e4*output_gate_bias,
      1e4*forget_gate_hh_mtx, 1e4*forget_gate_hi_mtx, 1e4*forget_gate_bias,
      1e4*new_cell_hh_mtx, 1e4*new_cell_hi_mtx, 1e4*new_cell_bias,
      1e4*softmax_mtx, 1e4*softmax_bias)
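
The 1e4 scaling applied to every weight and bias in the return statement above presumably serves to saturate the gate nonlinearities, so sigmoid and tanh behave like hard 0/1 (or ±1) switches. A quick check of that effect:

import torch

x = torch.tensor([0.5, -0.5]) * 1e4
print(torch.sigmoid(x))  # ~tensor([1., 0.])
print(torch.tanh(x))     # ~tensor([ 1., -1.])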
Example #5
def get_dyckkm_lstm_mk_params(k=5, m=3):
  """
  Returns LSTM parameters for the O(mk) construction
  generating Dyck-(k,m).
  """
  vocab = utils.get_vocab_of_bracket_types(k)[0]
  vocab_size = len(vocab)
  open_bracket_indices = list(range(0,k))
  close_bracket_indices = list(range(k,2*k))
  end_symbol_index = 2*k
  hidden_size = k*m
  num_stack_states = m

  # Vocabulary
  embedding_weights = torch.eye(vocab_size, vocab_size)

  # Input gate
  input_gate_hi_mtx = torch.stack(tuple((sum([embedding_weights[i] for i in open_bracket_indices]) for _ in range(hidden_size))))
  input_gate_hh_rows = []
  input_gate_bias_vals = []
  for i in range(m):
    if i == 0:
      neg_ones = -1*torch.ones(k*num_stack_states)
      row = neg_ones
    else:
      ones = torch.ones(k)
      neg_ones = -1*ones
      zeros_initial = torch.zeros(max(k*(i-1),0))
      zeros_final = torch.zeros(max(k*(num_stack_states-2-max(i-1,0)),0))
      row = torch.cat((zeros_initial, ones, neg_ones, zeros_final))
    input_gate_hh_rows.extend([row for _ in range(k)])
  input_gate_hh_mtx = torch.stack(input_gate_hh_rows)
  input_gate_bias = torch.tensor([-0.5 if i < k else -1.5 for i in range(hidden_size)])
  print('input')
  print(input_gate_hi_mtx)
  print(input_gate_bias)
  print(input_gate_hh_mtx)

  # Forget gate
  forget_gate_hi_mtx = -torch.stack(tuple((sum([embedding_weights[i] for i in close_bracket_indices]) for _ in range(hidden_size))))
  forget_gate_hh_rows = []
  for i in range(num_stack_states):
    neg_ones = -1*torch.ones(k)
    zeros_initial = torch.zeros(k*i)
    zeros_final = torch.zeros(max(k*(num_stack_states-1-i), 0))
    row = torch.cat((zeros_initial, neg_ones, zeros_final))
    forget_gate_hh_rows.extend([row for _ in range(k)])
  forget_gate_hh_mtx = torch.stack(forget_gate_hh_rows)
  forget_gate_bias = torch.ones(hidden_size)*0.5*TANH_OF_1 + 1
  print('forget')
  print(forget_gate_bias)
  print(forget_gate_hh_mtx)
  print(forget_gate_hi_mtx)

  # Output gate
  output_gate_hi_mtx = torch.stack(tuple((sum((embedding_weights[i] for i in close_bracket_indices)) for _ in range(hidden_size))))
  output_gate_hh_rows = []
  for i in range(num_stack_states):
    neg_max = -num_stack_states*torch.ones(k*max(i-1, 0))
    neg_ones = -torch.ones(k*min(i+1,2))
    zeros = torch.zeros(k*(num_stack_states-i-1))
    row = torch.cat((neg_max, neg_ones, zeros))
    output_gate_hh_rows.extend([row for _ in range(k)])
  output_gate_hh_mtx = torch.stack(output_gate_hh_rows).transpose(0,1)
  output_gate_bias = 0.5*torch.ones(hidden_size)
  print('output')
  print(output_gate_bias)
  print(output_gate_hh_mtx)

  # New cell candidate
  new_cell_hi_mtx = torch.stack(tuple(itertools.chain.from_iterable((embedding_weights[i] for i in open_bracket_indices) for x in range(num_stack_states))))
  new_cell_hh_mtx = torch.zeros(hidden_size, hidden_size)
  new_cell_bias = torch.zeros(hidden_size)
  print('new cell')
  print(new_cell_bias)
  print(new_cell_hh_mtx)
  print(new_cell_hi_mtx)

  # Softmax
  softmax_mtx = torch.zeros(len(vocab), hidden_size).float()
  softmax_bias = torch.zeros(len(vocab))

  # close-i
  for i,index_of_i in enumerate(close_bracket_indices):
    for j in range(num_stack_states):
      softmax_mtx[index_of_i, k*j + i] = 1
    softmax_bias[index_of_i] = -TANH_OF_1*0.5
  # open-i
  for i, index_of_i in enumerate(open_bracket_indices):
    for j in range(k):
      softmax_mtx[index_of_i,k*(m-1)+j] = -1
    softmax_bias[index_of_i] = 0.5*TANH_OF_1
  # end
  for i in range(k):
    for j in range(num_stack_states):
      softmax_mtx[end_symbol_index,k*j+i] = -1
  softmax_bias[end_symbol_index]= 0.5*TANH_OF_1
  print('softmax')
  print(softmax_mtx)

  return (embedding_weights, 1e4*input_gate_hh_mtx, 1e4*input_gate_hi_mtx, 1e4*input_gate_bias,
      1e4*output_gate_hh_mtx, 1e4*output_gate_hi_mtx, 1e4*output_gate_bias,
      1e4*forget_gate_hh_mtx, 1e4*forget_gate_hi_mtx, 1e4*forget_gate_bias,
      1e4*new_cell_hh_mtx, 1e4*new_cell_hi_mtx, 1e4*new_cell_bias,
      1e4*softmax_mtx, 1e4*softmax_bias)
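
A usage sketch for this O(mk) construction, illustrative only (the repository presumably has its own loading code): every *_hh matrix is (hidden_size, hidden_size) and every *_hi matrix is (hidden_size, vocab_size), so the tensors can be packed into torch.nn.LSTM's parameter layout, which concatenates gates in the order input, forget, cell, output.

import torch

(emb, i_hh, i_hi, i_b,
 o_hh, o_hi, o_b,
 f_hh, f_hi, f_b,
 g_hh, g_hi, g_b,
 softmax_mtx, softmax_bias) = get_dyckkm_lstm_mk_params(k=5, m=3)

hidden_size = i_hh.shape[0]
lstm = torch.nn.LSTM(input_size=i_hi.shape[1], hidden_size=hidden_size, batch_first=True)
decoder = torch.nn.Linear(hidden_size, softmax_mtx.shape[0])
with torch.no_grad():
  # PyTorch stacks LSTM gate parameters as (W_ii | W_if | W_ig | W_io).
  lstm.weight_ih_l0.copy_(torch.cat((i_hi, f_hi, g_hi, o_hi), 0))
  lstm.weight_hh_l0.copy_(torch.cat((i_hh, f_hh, g_hh, o_hh), 0))
  lstm.bias_ih_l0.copy_(torch.cat((i_b, f_b, g_b, o_b), 0))
  lstm.bias_hh_l0.zero_()
  decoder.weight.copy_(softmax_mtx)
  decoder.bias.copy_(softmax_bias)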