Example #1
0
    def __init__(self, n_choices, discrete_dim, embed_dim):
        """Build the token embedding table and the LSTM cell of the RNN sampler.

        Args:
          n_choices: number of rows of ``token_embed`` (one per token value);
            also forwarded to the base-class constructor.
          discrete_dim: forwarded to the base-class constructor.
          embed_dim: width of each token embedding and of the LSTM cell's
            input and hidden state; also forwarded to the base class.
        """
        super(CondRnnSampler, self).__init__(n_choices, discrete_dim, embed_dim)

        # torch.Tensor(rows, cols) allocates uninitialized storage; the
        # glorot_uniform(self) call below is presumably what fills it
        # (assumes it visits every registered parameter -- confirm).
        token_table = torch.Tensor(n_choices, embed_dim)
        self.token_embed = nn.Parameter(token_table)
        glorot_uniform(self)

        # NOTE(review): registered after glorot_uniform(self), so that call
        # cannot reach the LSTM weights and they keep PyTorch's default
        # LSTMCell initialization -- confirm this ordering is intentional.
        hidden = embed_dim
        self.lstm = nn.LSTMCell(hidden, hidden)
Example #2
0
 def __init__(self, args):
     """Build the autoregressive RFILL decoder.

     Args:
       args: configuration namespace. Fields read here: ``device``,
         ``cell_type`` ('lstm' or 'gru'), ``rnn_state_proj``,
         ``rnn_layers``, ``embed_dim``, ``act_func`` and ``tok_type``
         ('embed' or 'onehot').

     Raises:
       NotImplementedError: if ``args.tok_type`` or ``args.cell_type`` is
         not one of the supported values listed above.
     """
     super(RfillAutoreg, self).__init__()
     # NOTE(review): this runs before any parameter or submodule below is
     # registered, so it presumably has nothing to initialize -- confirm
     # whether it was meant to run at the end of __init__. Left in place to
     # keep the current initialization behavior unchanged.
     glorot_uniform(self)
     self.DECISION_MASK = torch.tensor(DECISION_MASK).to(args.device)
     self.STATE_TRANS = torch.LongTensor(STATE_TRANS).to(args.device)
     self.cell_type = args.cell_type
     # Private copy so later mutation cannot leak into the shared RFILL_VOCAB.
     self.vocab = deepcopy(RFILL_VOCAB)
     self.tok_start = self.vocab['|']
     self.tok_stop = self.vocab['eos']
     self.tok_pad = self.vocab['pad']
     # The pad token must have id 0; presumably other code relies on this.
     assert self.tok_pad == 0
     # Reverse lookup: token id -> token string (later keys win on duplicate ids).
     self.inv_map = {self.vocab[key]: key for key in self.vocab}
     self.rnn_state_proj = args.rnn_state_proj
     self.rnn_layers = args.rnn_layers
     if self.rnn_state_proj:
         # Output width embed_dim * rnn_layers: presumably one initial hidden
         # (and, for LSTM, cell) state slice per RNN layer -- confirm at use site.
         self.ctx2h = MLP(args.embed_dim, [args.embed_dim * self.rnn_layers], nonlinearity=args.act_func, act_last=args.act_func)
         if self.cell_type == 'lstm':
             self.ctx2c = MLP(args.embed_dim, [args.embed_dim * self.rnn_layers], nonlinearity=args.act_func, act_last=args.act_func)
     if args.tok_type == 'embed':
         self.tok_embed = nn.Embedding(len(self.vocab), args.embed_dim)
         input_size = args.embed_dim
     elif args.tok_type == 'onehot':
         input_size = len(self.vocab)
         self.tok_embed = partial(self._get_onehot, vsize=input_size)
     else:
         # Without this branch input_size stays unbound and construction
         # fails later with an UnboundLocalError.
         raise NotImplementedError('unsupported tok_type: %s' % args.tok_type)
     if self.cell_type == 'lstm':
         self.rnn = nn.LSTM(input_size, args.embed_dim, self.rnn_layers, bidirectional=False)
     elif self.cell_type == 'gru':
         self.rnn = nn.GRU(input_size, args.embed_dim, self.rnn_layers, bidirectional=False)
     else:
         raise NotImplementedError('unsupported cell_type: %s' % self.cell_type)
     # Maps each RNN output to one score per vocabulary entry.
     self.out_pred = nn.Linear(args.embed_dim, len(self.vocab))
    def __init__(self, n_choices, discrete_dim, embed_dim):
        """Build the learned initial state and the per-position MLP stack.

        Args:
          n_choices: forwarded to the base-class constructor.
          discrete_dim: forwarded to the base-class constructor; the stack
            below reads it back as ``self.discrete_dim``.
          embed_dim: width of ``init_h``; also sets the hidden sizes
            (2*embed_dim, 2*embed_dim, embed_dim) of every MLP.
        """
        super(MLPSampler, self).__init__(n_choices, discrete_dim, embed_dim)
        self.init_h = nn.Parameter(torch.Tensor(1, embed_dim))
        # Runs before list_mods exists, so only parameters registered so far
        # (init_h plus any from the base class) are visible to it.
        glorot_uniform(self)

        # One MLP per position k in 1..discrete_dim-1, taking k inputs
        # (presumably the k previously sampled values -- confirm at use site).
        # A fresh hidden-size list is built for every MLP.
        self.list_mods = nn.ModuleList([
            MLP(in_dim, [embed_dim * 2, embed_dim * 2, embed_dim])
            for in_dim in range(1, self.discrete_dim)
        ])
 def __init__(self, n_choices, discrete_dim):
     """Allocate independent per-position logits and a baseline row.

     Args:
       n_choices: size of the second dimension of ``logits``.
       discrete_dim: number of rows of ``logits`` and number of columns
         of ``baselines``.
     """
     super(IidSampler, self).__init__()
     logit_shape = (discrete_dim, n_choices)
     baseline_shape = (1, discrete_dim)
     self.logits = nn.Parameter(torch.Tensor(*logit_shape))
     self.baselines = nn.Parameter(torch.Tensor(*baseline_shape))
     # Both tensors above are allocated uninitialized; this presumably fills
     # every registered parameter (confirm glorot_uniform's contract).
     glorot_uniform(self)