import numpy as np
import torch
import torch.nn as nn

# RNN, MLP, and TemporalConvNet are this repo's own sequence/MLP wrappers;
# the import paths below are assumptions about where they live.
from models.rnn import RNN
from models.mlp import MLP
from models.tcn import TemporalConvNet


class VAE(nn.Module):
    """Sequence-to-sequence conditional VAE: encodes past context x and future y
    into a latent z, then decodes the future autoregressively over `horizon` steps."""

    def __init__(self, nx, ny, nz, horizon, specs):
        super(VAE, self).__init__()
        self.nx = nx
        self.ny = ny
        self.nz = nz
        self.horizon = horizon
        self.rnn_type = rnn_type = specs.get('rnn_type', 'lstm')
        self.x_birnn = x_birnn = specs.get('x_birnn', True)
        self.e_birnn = e_birnn = specs.get('e_birnn', True)
        self.use_drnn_mlp = specs.get('use_drnn_mlp', False)
        self.nh_rnn = nh_rnn = specs.get('nh_rnn', 128)
        self.nh_mlp = nh_mlp = specs.get('nh_mlp', [300, 200])
        # encode
        self.x_rnn = RNN(nx, nh_rnn, bi_dir=x_birnn, cell_type=rnn_type)
        self.e_rnn = RNN(ny, nh_rnn, bi_dir=e_birnn, cell_type=rnn_type)
        self.e_mlp = MLP(2 * nh_rnn, nh_mlp)
        self.e_mu = nn.Linear(self.e_mlp.out_dim, nz)
        self.e_logvar = nn.Linear(self.e_mlp.out_dim, nz)
        # decode
        if self.use_drnn_mlp:
            self.drnn_mlp = MLP(nh_rnn, nh_mlp + [nh_rnn], activation='tanh')
        self.d_rnn = RNN(ny + nz + nh_rnn, nh_rnn, cell_type=rnn_type)
        self.d_mlp = MLP(nh_rnn, nh_mlp)
        self.d_out = nn.Linear(self.d_mlp.out_dim, ny)
        self.d_rnn.set_mode('step')

    def encode_x(self, x):
        # bidirectional RNN: average hidden states over time; otherwise take the last step
        if self.x_birnn:
            h_x = self.x_rnn(x).mean(dim=0)
        else:
            h_x = self.x_rnn(x)[-1]
        return h_x

    def encode_y(self, y):
        if self.e_birnn:
            h_y = self.e_rnn(y).mean(dim=0)
        else:
            h_y = self.e_rnn(y)[-1]
        return h_y

    def encode(self, x, y):
        h_x = self.encode_x(x)
        h_y = self.encode_y(y)
        h = torch.cat((h_x, h_y), dim=1)
        h = self.e_mlp(h)
        return self.e_mu(h), self.e_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x, z):
        h_x = self.encode_x(x)
        if self.use_drnn_mlp:
            # initialize the decoder RNN's hidden state from the context encoding
            h_d = self.drnn_mlp(h_x)
            self.d_rnn.initialize(batch_size=z.shape[0], hx=h_d)
        else:
            self.d_rnn.initialize(batch_size=z.shape[0])
        y = []
        for i in range(self.horizon):
            # feed back the previous output (last context frame at step 0)
            y_p = x[-1] if i == 0 else y_i
            rnn_in = torch.cat([h_x, z, y_p], dim=1)
            h = self.d_rnn(rnn_in)
            h = self.d_mlp(h)
            y_i = self.d_out(h)
            y.append(y_i)
        y = torch.stack(y)
        return y

    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar) if self.training else mu
        return self.decode(x, z), mu, logvar

    def sample_prior(self, x):
        z = torch.randn((x.shape[1], self.nz), device=x.device)
        return self.decode(x, z)
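
# --- Illustrative usage sketch (not in the original file). All dims below are
# made-up examples; inputs are time-major, i.e. x is (t_past, batch, nx) and
# y is (t_future, batch, ny), matching the RNN wrappers used above.
def _vae_usage_sketch():
    specs = {'rnn_type': 'lstm', 'x_birnn': True, 'nh_rnn': 128, 'nh_mlp': [300, 200]}
    vae = VAE(nx=48, ny=48, nz=32, horizon=30, specs=specs)
    x = torch.randn(25, 8, 48)         # 25 context frames, batch of 8
    y = torch.randn(30, 8, 48)         # 30 future frames to reconstruct
    y_rec, mu, logvar = vae(x, y)      # posterior pass (training uses reparameterization)
    y_new = vae.sample_prior(x)        # z ~ N(0, I) gives a novel future
    # standard ELBO terms: reconstruction error + KL(q(z|x, y) || N(0, I))
    kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    loss = (y_rec - y).pow(2).mean() + kl
    return loss
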
class VideoForecastNet(nn.Module):
    """Fuses a video context code (CNN features passed through an LSTM or TCN)
    with a per-step state feature; supports batched training over padded episodes."""

    def __init__(self, cnn_feat_dim, state_dim, v_hdim=128, v_margin=10, v_net_type='lstm',
                 v_net_param=None, s_hdim=None, s_net_type='id', dynamic_v=False):
        super().__init__()
        s_hdim = state_dim if s_hdim is None else s_hdim
        self.mode = 'test'
        self.cnn_feat_dim = cnn_feat_dim
        self.state_dim = state_dim
        self.v_net_type = v_net_type
        self.v_hdim = v_hdim
        self.v_margin = v_margin
        self.s_net_type = s_net_type
        self.s_hdim = s_hdim
        self.dynamic_v = dynamic_v
        self.out_dim = v_hdim + s_hdim
        # video feature network
        if v_net_type == 'lstm':
            self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=False)
        elif v_net_type == 'tcn':
            if v_net_param is None:
                v_net_param = {}
            tcn_size = v_net_param.get('size', [64, 128])
            dropout = v_net_param.get('dropout', 0.2)
            kernel_size = v_net_param.get('kernel_size', 3)
            assert tcn_size[-1] == v_hdim
            self.v_net = TemporalConvNet(cnn_feat_dim, tcn_size, kernel_size=kernel_size,
                                         dropout=dropout, causal=True)
        # state feature network
        if s_net_type == 'lstm':
            self.s_net = RNN(state_dim, s_hdim, s_net_type, bi_dir=False)
        self.v_out = None
        self.t = 0
        # training only
        self.indices = None
        self.s_scatter_indices = None
        self.s_gather_indices = None
        self.v_gather_indices = None
        self.cnn_feat_ctx = None
        self.num_episode = None
        self.max_episode_len = None
        self.set_mode('test')

    def set_mode(self, mode):
        self.mode = mode
        if self.s_net_type == 'lstm':
            if mode == 'train':
                self.s_net.set_mode('batch')
            else:
                self.s_net.set_mode('step')

    def initialize(self, x):
        if self.mode == 'test':
            # cache the video code computed from the first v_margin CNN frames
            self.v_out = self.forward_v_net(x.unsqueeze(1)[:self.v_margin])[-1]
            if self.s_net_type == 'lstm':
                self.s_net.initialize()
            self.t = 0
        elif self.mode == 'train':
            masks, cnn_feat, v_metas = x
            device, dtype = masks.device, masks.dtype
            # episode boundaries are marked by zeros in the mask
            end_indice = np.where(masks.cpu().numpy() == 0)[0]
            v_metas = v_metas[end_indice, :]
            num_episode = len(end_indice)
            end_indice = np.insert(end_indice, 0, -1)
            max_episode_len = int(np.diff(end_indice).max())
            self.num_episode = num_episode
            self.max_episode_len = max_episode_len
            # map each flat sample index to its slot in the padded (episode, t) layout
            self.indices = np.arange(masks.shape[0])
            for i in range(1, num_episode):
                start_index = end_indice[i] + 1
                end_index = end_indice[i + 1] + 1
                self.indices[start_index:end_index] += i * max_episode_len - start_index
            # collect the CNN feature context preceding each episode
            self.cnn_feat_ctx = np.zeros((self.v_margin + max_episode_len if self.dynamic_v
                                          else self.v_margin, num_episode, self.cnn_feat_dim))
            for i in range(num_episode):
                exp_ind, start_ind = v_metas[i, :]
                self.cnn_feat_ctx[:self.v_margin, i, :] = cnn_feat[exp_ind][start_ind - self.v_margin: start_ind]
            self.cnn_feat_ctx = torch.tensor(self.cnn_feat_ctx, dtype=dtype, device=device)
            self.s_scatter_indices = torch.LongTensor(np.tile(self.indices[:, None], (1, self.state_dim))).to(device)
            self.s_gather_indices = torch.LongTensor(np.tile(self.indices[:, None], (1, self.s_hdim))).to(device)
            self.v_gather_indices = torch.LongTensor(np.tile(self.indices[:, None], (1, self.v_hdim))).to(device)

    def forward(self, x):
        if self.mode == 'test':
            if self.s_net_type == 'lstm':
                x = self.s_net(x)
            x = torch.cat((self.v_out, x), dim=1)
            self.t += 1
        elif self.mode == 'train':
            if self.dynamic_v:
                # per-step video codes for every episode timestep
                v_ctx = self.forward_v_net(self.cnn_feat_ctx)[self.v_margin:]
            else:
                # single video code per episode, broadcast over time
                v_ctx = self.forward_v_net(self.cnn_feat_ctx)[[-1]]
                v_ctx = v_ctx.repeat(self.max_episode_len, 1, 1)
            v_ctx = v_ctx.transpose(0, 1).contiguous().view(-1, self.v_hdim)
            v_out = torch.gather(v_ctx, 0, self.v_gather_indices)
            if self.s_net_type == 'lstm':
                # scatter flat samples into the padded episode layout, run the RNN, gather back
                s_ctx = torch.zeros((self.num_episode * self.max_episode_len, self.state_dim), device=x.device)
                s_ctx.scatter_(0, self.s_scatter_indices, x)
                s_ctx = s_ctx.view(-1, self.max_episode_len, self.state_dim).transpose(0, 1).contiguous()
                s_ctx = self.s_net(s_ctx).transpose(0, 1).contiguous().view(-1, self.s_hdim)
                s_out = torch.gather(s_ctx, 0, self.s_gather_indices)
            else:
                s_out = x
            x = torch.cat((v_out, s_out), dim=1)
        return x

    def forward_v_net(self, x):
        # TCN expects (batch, channels, time); the RNN wrapper expects (time, batch, channels)
        if self.v_net_type == 'tcn':
            x = x.permute(1, 2, 0).contiguous()
        x = self.v_net(x)
        if self.v_net_type == 'tcn':
            x = x.permute(2, 0, 1).contiguous()
        return x
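
# --- Illustrative usage sketch (not in the original file): a test-time rollout.
# Dims are made-up; in 'test' mode the net first digests v_margin CNN frames via
# initialize(), then fuses one state per step with the cached video code.
def _forecast_net_usage_sketch():
    net = VideoForecastNet(cnn_feat_dim=256, state_dim=59, v_hdim=128, v_margin=10)
    net.set_mode('test')
    cnn_feat = torch.randn(10, 256)    # v_margin frames of CNN features
    net.initialize(cnn_feat)           # caches net.v_out from the video context
    for _ in range(5):
        s = torch.randn(1, 59)         # current state, batch of 1
        out = net(s)                   # (1, v_hdim + s_hdim) fused feature
    return out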