def __init__(self, n_choices, discrete_dim, embed_dim):
    """Create the token embedding table and the recurrent cell.

    Args:
        n_choices: number of rows in the token embedding table.
        discrete_dim: forwarded unchanged to the parent constructor.
        embed_dim: width of each embedding row; also used as both the
            input size and the hidden size of the LSTM cell.
    """
    super(CondRnnSampler, self).__init__(n_choices, discrete_dim, embed_dim)

    # torch.Tensor(rows, cols) only allocates storage; the values are
    # presumably filled in by glorot_uniform below -- confirm against helper.
    embed_table = torch.Tensor(n_choices, embed_dim)
    self.token_embed = nn.Parameter(embed_table)

    # Invoked before the LSTM cell is attached, so this call cannot have
    # touched the cell's weights.
    glorot_uniform(self)

    self.lstm = nn.LSTMCell(input_size=embed_dim, hidden_size=embed_dim)
def __init__(self, args):
    """Build the autoregressive token model from a configuration object.

    Args:
        args: configuration namespace. Attributes read here: ``device``,
            ``cell_type`` ('lstm' or 'gru'), ``rnn_state_proj``,
            ``rnn_layers``, ``embed_dim``, ``tok_type`` ('embed' or
            'onehot') and, only when ``rnn_state_proj`` is truthy,
            ``act_func``.

    Raises:
        NotImplementedError: if ``args.tok_type`` or ``args.cell_type`` is
            not one of the supported values.
        AssertionError: if the vocabulary does not map 'pad' to id 0.
    """
    super(RfillAutoreg, self).__init__()
    # NOTE(review): this runs before this method has assigned any parameter
    # or submodule -- confirm whether initialisation was meant to happen
    # here or after the layers below are created.
    glorot_uniform(self)

    self.DECISION_MASK = torch.tensor(DECISION_MASK).to(args.device)
    self.STATE_TRANS = torch.LongTensor(STATE_TRANS).to(args.device)
    self.cell_type = args.cell_type

    # Work on a private copy so the module-level vocabulary is not shared.
    self.vocab = deepcopy(RFILL_VOCAB)
    self.tok_start = self.vocab['|']
    self.tok_stop = self.vocab['eos']
    self.tok_pad = self.vocab['pad']
    assert self.tok_pad == 0
    # Reverse lookup: token id -> token string.
    self.inv_map = {tok_id: tok for tok, tok_id in self.vocab.items()}

    self.rnn_state_proj = args.rnn_state_proj
    self.rnn_layers = args.rnn_layers
    if self.rnn_state_proj:
        # One embed_dim-sized slice per RNN layer.
        self.ctx2h = MLP(args.embed_dim, [args.embed_dim * self.rnn_layers],
                         nonlinearity=args.act_func, act_last=args.act_func)
        if self.cell_type == 'lstm':
            # Second projection, created only for the LSTM cell type.
            self.ctx2c = MLP(args.embed_dim, [args.embed_dim * self.rnn_layers],
                             nonlinearity=args.act_func, act_last=args.act_func)

    if args.tok_type == 'embed':
        self.tok_embed = nn.Embedding(len(self.vocab), args.embed_dim)
        input_size = args.embed_dim
    elif args.tok_type == 'onehot':
        input_size = len(self.vocab)
        self.tok_embed = partial(self._get_onehot, vsize=input_size)
    else:
        # Previously fell through and failed later with an UnboundLocalError
        # on input_size; fail here with an explicit error instead.
        raise NotImplementedError(
            'unsupported tok_type: %r' % (args.tok_type,))

    if self.cell_type == 'lstm':
        self.rnn = nn.LSTM(input_size, args.embed_dim, self.rnn_layers,
                           bidirectional=False)
    elif self.cell_type == 'gru':
        self.rnn = nn.GRU(input_size, args.embed_dim, self.rnn_layers,
                          bidirectional=False)
    else:
        raise NotImplementedError(
            'unsupported cell_type: %r' % (self.cell_type,))

    # Maps an embed_dim-sized vector to one score per vocabulary entry.
    self.out_pred = nn.Linear(args.embed_dim, len(self.vocab))
def __init__(self, n_choices, discrete_dim, embed_dim):
    """Create the initial-state parameter and a stack of MLPs.

    Args:
        n_choices: forwarded unchanged to the parent constructor.
        discrete_dim: forwarded unchanged to the parent constructor.
        embed_dim: width of ``init_h`` and of each MLP's final layer.
    """
    super(MLPSampler, self).__init__(n_choices, discrete_dim, embed_dim)

    self.init_h = nn.Parameter(torch.Tensor(1, embed_dim))
    # Invoked before the MLPs below are attached, so this call cannot have
    # touched their weights.
    glorot_uniform(self)

    wide = embed_dim * 2
    # One MLP for each input width 1 .. self.discrete_dim - 1, built in
    # increasing order. self.discrete_dim is read from the instance --
    # presumably set by the parent constructor; verify.
    self.list_mods = nn.ModuleList(
        [MLP(in_dim, [wide, wide, embed_dim])
         for in_dim in range(1, self.discrete_dim)]
    )
def __init__(self, n_choices, discrete_dim):
    """Create the per-position logits and baseline parameters.

    Args:
        n_choices: number of columns in ``logits``.
        discrete_dim: number of rows in ``logits`` and number of columns
            in ``baselines``.
    """
    super(IidSampler, self).__init__()

    # Registration order (logits first, then baselines) is kept as-is.
    logit_table = torch.Tensor(discrete_dim, n_choices)
    self.logits = nn.Parameter(logit_table)

    baseline_row = torch.Tensor(1, discrete_dim)
    self.baselines = nn.Parameter(baseline_row)

    # Called last, after both parameters are attached to the module.
    glorot_uniform(self)