def __init__(self, nx, t_total, specs):
    """Build the encoder, decoder, and initial-pose modules from `specs`.

    Args:
        nx: per-frame feature dimension.
        t_total: number of frames the decoder rolls out.
        specs: hyper-parameter dict; 'nz' is required, everything else
            has a default.
    """
    super(VAErecV2, self).__init__()
    # Pull hyper-parameters out of the spec dict once, up front.
    latent = specs['nz']
    cell = specs.get('rnn_type', 'gru')
    bi_enc = specs.get('e_birnn', False)
    hid = specs.get('nx_rnn', 128)
    mlp_dims = specs.get('nh_mlp', [300, 200])
    self.nx = nx
    self.nz = latent
    self.t_total = t_total
    self.rnn_type = cell
    self.e_birnn = bi_enc
    self.use_drnn_mlp = specs.get('use_drnn_mlp', True)
    self.nx_rnn = hid
    self.nh_mlp = mlp_dims
    self.additive = specs.get('additive', False)
    # --- encoder: sequence RNN -> MLP -> (mu, logvar) heads ---
    self.e_rnn = RNN(nx, hid, bi_dir=bi_enc, cell_type=cell)
    self.e_mlp = MLP(hid, mlp_dims)
    self.e_mu = nn.Linear(self.e_mlp.out_dim, latent)
    self.e_logvar = nn.Linear(self.e_mlp.out_dim, latent)
    # --- decoder: step-mode RNN fed with (prev frame, z) ---
    if self.use_drnn_mlp:
        self.drnn_mlp = MLP(hid, mlp_dims + [hid], activation='relu')
    self.d_rnn = RNN(nx + hid, hid, cell_type=cell)
    self.d_mlp = MLP(hid, mlp_dims)
    self.d_out = nn.Linear(self.d_mlp.out_dim, nx)
    self.d_rnn.set_mode('step')
    # --- predictor for the frame that seeds the autoregressive rollout ---
    self.init_pose_mlp = MLP(hid, mlp_dims, activation='relu')
    self.init_pose_out = nn.Linear(self.init_pose_mlp.out_dim, nx)
def __init__(self, state_dim):
    """Build the encoder / stacked-LSTM / decoder stack (ERD layout)."""
    super().__init__()
    enc_dim, rnn_dim, dec_hidden = 500, 1000, 100
    self.state_dim = state_dim
    # encoder: state -> 500 (relu MLP) -> 500 (linear)
    self.encoder_mlp = MLP(state_dim, (enc_dim, ), 'relu')
    self.encoder_linear = nn.Linear(enc_dim, enc_dim)
    # two stacked recurrent layers
    self.lstm1 = RNN(enc_dim, rnn_dim, 'lstm')
    self.lstm2 = RNN(rnn_dim, rnn_dim, 'lstm')
    # decoder: 1000 -> 500 -> 100 (relu MLP) -> state (linear)
    self.decoder_mlp = MLP(rnn_dim, (enc_dim, dec_hidden), 'relu')
    self.decoder_linear = nn.Linear(dec_hidden, state_dim)
    self.mode = 'batch'
def __init__(self, cnn_feat_dim, state_dim, v_hdim=128, v_margin=10,
             v_net_type='lstm', v_net_param=None, s_hdim=None,
             s_net_type='id', dynamic_v=False):
    """Set up the video-context net, optional state net, and bookkeeping.

    Output feature size is `v_hdim + s_hdim` (state passes through
    unchanged when `s_net_type == 'id'`).
    """
    super().__init__()
    if s_hdim is None:
        s_hdim = state_dim
    self.mode = 'test'
    self.cnn_feat_dim = cnn_feat_dim
    self.state_dim = state_dim
    self.v_net_type = v_net_type
    self.v_hdim = v_hdim
    self.v_margin = v_margin
    self.s_net_type = s_net_type
    self.s_hdim = s_hdim
    self.dynamic_v = dynamic_v
    self.out_dim = v_hdim + s_hdim
    # temporal net over CNN features (causal in both variants)
    if v_net_type == 'lstm':
        self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=False)
    elif v_net_type == 'tcn':
        cfg = {} if v_net_param is None else v_net_param
        sizes = cfg.get('size', [64, 128])
        drop = cfg.get('dropout', 0.2)
        width = cfg.get('kernel_size', 3)
        assert sizes[-1] == v_hdim
        self.v_net = TemporalConvNet(cnn_feat_dim, sizes, kernel_size=width,
                                     dropout=drop, causal=True)
    if s_net_type == 'lstm':
        self.s_net = RNN(state_dim, s_hdim, s_net_type, bi_dir=False)
    # streaming state (test mode)
    self.v_out = None
    self.t = 0
    # training-only scatter/gather bookkeeping, filled in by initialize()
    self.indices = None
    self.s_scatter_indices = None
    self.s_gather_indices = None
    self.v_gather_indices = None
    self.cnn_feat_ctx = None
    self.num_episode = None
    self.max_episode_len = None
    self.set_mode('test')
class ERDNet(nn.Module):
    """Encoder-Recurrent-Decoder network.

    An MLP encoder, two stacked LSTMs, and an MLP decoder back to the
    state dimension. In 'batch' mode the input is a (T, B, D) sequence
    that is flattened for the per-step layers; in 'step' mode the input
    is fed through unchanged.
    """

    def __init__(self, state_dim):
        super().__init__()
        enc_dim, rnn_dim, dec_hidden = 500, 1000, 100
        self.state_dim = state_dim
        self.encoder_mlp = MLP(state_dim, (enc_dim, ), 'relu')
        self.encoder_linear = nn.Linear(enc_dim, enc_dim)
        self.lstm1 = RNN(enc_dim, rnn_dim, 'lstm')
        self.lstm2 = RNN(rnn_dim, rnn_dim, 'lstm')
        self.decoder_mlp = MLP(rnn_dim, (enc_dim, dec_hidden), 'relu')
        self.decoder_linear = nn.Linear(dec_hidden, state_dim)
        self.mode = 'batch'

    def initialize(self, mode):
        """Switch both LSTMs to `mode` and reset their hidden state."""
        self.mode = mode
        for cell in (self.lstm1, self.lstm2):
            cell.set_mode(mode)
        for cell in (self.lstm1, self.lstm2):
            cell.initialize()

    def forward(self, x):
        in_batch = self.mode == 'batch'
        if in_batch:
            # remember the batch dimension, then flatten (T, B, D) -> (T*B, D)
            seq_batch = x.shape[1]
            x = x.view(-1, x.shape[-1])
        x = self.encoder_linear(self.encoder_mlp(x))
        if in_batch:
            # restore the time-major layout for the recurrent layers
            x = x.view(-1, seq_batch, x.shape[-1])
        x = self.lstm2(self.lstm1(x))
        if in_batch:
            x = x.view(-1, x.shape[-1])
        return self.decoder_linear(self.decoder_mlp(x))
def __init__(self, out_dim, v_hdim, cnn_fdim, no_cnn=False,
             frame_shape=(3, 224, 224), mlp_dim=(300, 200),
             cnn_type='resnet', v_net_type='lstm', v_net_param=None,
             cnn_rs=True, causal=False):
    """Video regression net: per-frame CNN -> temporal net -> MLP head.

    With `no_cnn=True` the CNN is skipped (features are precomputed).
    `causal=True` makes the temporal net one-directional/causal.
    """
    super().__init__()
    self.out_dim = out_dim
    self.cnn_fdim = cnn_fdim
    self.v_hdim = v_hdim
    self.no_cnn = no_cnn
    self.frame_shape = frame_shape
    # per-frame feature extractor
    if no_cnn:
        self.cnn = None
    elif cnn_type == 'resnet':
        self.cnn = ResNet(cnn_fdim, running_stats=cnn_rs)
    elif cnn_type == 'mobile':
        self.cnn = MobileNet(cnn_fdim)
    self.v_net_type = v_net_type
    # temporal aggregator over frame features
    if v_net_type == 'lstm':
        self.v_net = RNN(cnn_fdim, v_hdim, v_net_type, bi_dir=not causal)
    elif v_net_type == 'tcn':
        cfg = {} if v_net_param is None else v_net_param
        channels = cfg.get('size', [64, 128])
        drop = cfg.get('dropout', 0.2)
        width = cfg.get('kernel_size', 3)
        assert channels[-1] == v_hdim
        self.v_net = TemporalConvNet(cnn_fdim, channels, kernel_size=width,
                                     dropout=drop, causal=causal)
    # regression head
    self.mlp = MLP(v_hdim, mlp_dim, 'relu')
    self.linear = nn.Linear(self.mlp.out_dim, out_dim)
def __init__(self, cnn_feat_dim, v_hdim=128, v_margin=10,
             v_net_type='lstm', v_net_param=None, causal=False):
    """Temporal net over precomputed CNN features, with streaming state."""
    super().__init__()
    self.mode = 'test'
    self.cnn_feat_dim = cnn_feat_dim
    self.v_net_type = v_net_type
    self.v_hdim = v_hdim
    self.v_margin = v_margin
    if v_net_type == 'lstm':
        self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=not causal)
    elif v_net_type == 'tcn':
        cfg = {} if v_net_param is None else v_net_param
        sizes = cfg.get('size', [64, 128])
        drop = cfg.get('dropout', 0.2)
        width = cfg.get('kernel_size', 3)
        assert sizes[-1] == v_hdim
        self.v_net = TemporalConvNet(cnn_feat_dim, sizes, kernel_size=width,
                                     dropout=drop, causal=causal)
    # streaming state (test mode)
    self.v_out = None
    self.t = 0
    # training-only scatter/gather bookkeeping
    self.indices = None
    self.scatter_indices = None
    self.gather_indices = None
    self.cnn_feat_ctx = None
class VAErecV2(nn.Module):
    """Recurrent VAE that reconstructs a pose sequence of length `t_total`.

    The encoder RNN summarizes a time-major input sequence x (indexed as
    x[-1] / x.shape[1], so presumably (T, batch, nx) -- confirm against
    callers) into a latent Gaussian (mu, logvar). The decoder predicts an
    initial frame from z and then autoregressively rolls out `t_total`
    frames, feeding each prediction back as the next input.
    """

    def __init__(self, nx, t_total, specs):
        """
        Args:
            nx: per-frame feature dimension.
            t_total: number of frames the decoder generates.
            specs: hyper-parameter dict; 'nz' is required, the rest have
                defaults ('rnn_type', 'e_birnn', 'use_drnn_mlp', 'nx_rnn',
                'nh_mlp', 'additive').
        """
        super(VAErecV2, self).__init__()
        self.nx = nx
        self.nz = nz = specs['nz']
        self.t_total = t_total
        self.rnn_type = rnn_type = specs.get('rnn_type', 'gru')
        self.e_birnn = e_birnn = specs.get('e_birnn', False)
        self.use_drnn_mlp = specs.get('use_drnn_mlp', True)
        self.nx_rnn = nx_rnn = specs.get('nx_rnn', 128)
        self.nh_mlp = nh_mlp = specs.get('nh_mlp', [300, 200])
        self.additive = specs.get('additive', False)
        # encoder
        self.e_rnn = RNN(nx, nx_rnn, bi_dir=e_birnn, cell_type=rnn_type)
        self.e_mlp = MLP(nx_rnn, nh_mlp)
        self.e_mu = nn.Linear(self.e_mlp.out_dim, nz)
        self.e_logvar = nn.Linear(self.e_mlp.out_dim, nz)
        # decoder (step-mode RNN fed with (prev frame, z))
        if self.use_drnn_mlp:
            self.drnn_mlp = MLP(nx_rnn, nh_mlp + [nx_rnn], activation='relu')
        self.d_rnn = RNN(nx + nx_rnn, nx_rnn, cell_type=rnn_type)
        self.d_mlp = MLP(nx_rnn, nh_mlp)
        self.d_out = nn.Linear(self.d_mlp.out_dim, nx)
        self.d_rnn.set_mode('step')
        # predicts the frame that seeds the autoregressive rollout
        self.init_pose_mlp = MLP(nx_rnn, nh_mlp, activation='relu')
        self.init_pose_out = nn.Linear(self.init_pose_mlp.out_dim, nx)

    def encode_x(self, x):
        """Summarize sequence x into one hidden vector per batch item."""
        if self.e_birnn:
            # bi-directional encoder: average hidden states over time
            h_x = self.e_rnn(x).mean(dim=0)
        else:
            # uni-directional encoder: take the last hidden state
            h_x = self.e_rnn(x)[-1]
        return h_x

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior q(z|x)."""
        h = self.e_mlp(self.encode_x(x))
        return self.e_mu(h), self.e_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, x_p=None):
        """Autoregressively generate `t_total` frames conditioned on z.

        Args:
            z: latent code, shape (batch, nz). NOTE(review): the decoder
                RNN expects nx + nx_rnn inputs and z is concatenated with
                an nx-dim frame, so configs implicitly require
                nz == nx_rnn -- confirm.
            x_p: optional first frame; when None it is predicted from z.

        Returns:
            Tensor of shape (t_total, batch, nx).
        """
        # Fix: use `is None` -- `x_p == None` invokes elementwise tensor
        # comparison when a tensor is passed.
        if x_p is None:
            h_init_pose = self.init_pose_mlp(z)
            # feed the predicted first frame as the decoder's first input
            x_p = self.init_pose_out(h_init_pose)
        self.d_rnn.initialize(batch_size=z.shape[0])
        x_rec = []
        for i in range(self.t_total):
            rnn_in = torch.cat([x_p, z], dim=1)
            h = self.d_rnn(rnn_in)
            h = self.d_mlp(h)
            x_i = self.d_out(h)
            x_rec.append(x_i)
            x_p = x_i  # feed prediction back as the next input
        return torch.stack(x_rec)

    def forward(self, x):
        """Encode x, sample z (stochastic only while training), decode.

        Returns:
            (reconstruction, mu, logvar).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar) if self.training else mu
        return self.decode(z), mu, logvar

    def sample_prior(self, x):
        """Decode a sequence from z ~ N(0, I); x supplies batch size/device."""
        z = torch.randn((x.shape[1], self.nz), device=x.device)
        return self.decode(z)

    def step(self, model):
        # Intentional no-op hook.
        pass
class VideoForecastNet(nn.Module):
    """Fuses a video-context feature with a state feature.

    'test' mode streams one state per call after priming the video net on
    the first `v_margin` CNN-feature frames; 'train' mode processes a
    flattened batch of variable-length episodes using precomputed
    scatter/gather index maps into a (num_episode, max_episode_len) grid.
    Output feature size is `v_hdim + s_hdim`.
    """

    def __init__(self, cnn_feat_dim, state_dim, v_hdim=128, v_margin=10,
                 v_net_type='lstm', v_net_param=None, s_hdim=None,
                 s_net_type='id', dynamic_v=False):
        super().__init__()
        # state feature defaults to pass-through size
        s_hdim = state_dim if s_hdim is None else s_hdim
        self.mode = 'test'
        self.cnn_feat_dim = cnn_feat_dim
        self.state_dim = state_dim
        self.v_net_type = v_net_type
        self.v_hdim = v_hdim
        self.v_margin = v_margin
        self.s_net_type = s_net_type
        self.s_hdim = s_hdim
        self.dynamic_v = dynamic_v
        self.out_dim = v_hdim + s_hdim
        # temporal net over CNN features; both variants are causal here
        if v_net_type == 'lstm':
            self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=False)
        elif v_net_type == 'tcn':
            if v_net_param is None:
                v_net_param = {}
            tcn_size = v_net_param.get('size', [64, 128])
            dropout = v_net_param.get('dropout', 0.2)
            kernel_size = v_net_param.get('kernel_size', 3)
            assert tcn_size[-1] == v_hdim
            self.v_net = TemporalConvNet(cnn_feat_dim, tcn_size,
                                         kernel_size=kernel_size,
                                         dropout=dropout, causal=True)
        if s_net_type == 'lstm':
            self.s_net = RNN(state_dim, s_hdim, s_net_type, bi_dir=False)
        self.v_out = None
        self.t = 0
        # training only: scatter/gather bookkeeping filled in by initialize()
        self.indices = None
        self.s_scatter_indices = None
        self.s_gather_indices = None
        self.v_gather_indices = None
        self.cnn_feat_ctx = None
        self.num_episode = None
        self.max_episode_len = None
        self.set_mode('test')

    def set_mode(self, mode):
        """Switch 'test'/'train'; the state LSTM follows (step vs batch)."""
        self.mode = mode
        if self.s_net_type == 'lstm':
            if mode == 'train':
                self.s_net.set_mode('batch')
            else:
                self.s_net.set_mode('step')

    def initialize(self, x):
        """Prepare per-rollout state.

        test mode: x is a (T, cnn_feat_dim) feature sequence; prime the
        video net on the first v_margin frames and keep its last output.

        train mode: x is (masks, cnn_feat, v_metas); mask zeros mark
        episode ends. Builds index maps that place each flattened sample
        into a padded (num_episode, max_episode_len) grid, plus the CNN
        feature context windows for each episode.
        """
        if self.mode == 'test':
            # (T, D) -> (T, 1, D); last output of the margin window
            self.v_out = self.forward_v_net(x.unsqueeze(1)[:self.v_margin])[-1]
            if self.s_net_type == 'lstm':
                self.s_net.initialize()
            self.t = 0
        elif self.mode == 'train':
            masks, cnn_feat, v_metas = x
            device, dtype = masks.device, masks.dtype
            # episode boundaries: mask == 0 marks the last step of an episode
            end_indice = np.where(masks.cpu().numpy() == 0)[0]
            # keep one meta row per episode (taken at its end index);
            # presumably (experiment index, start frame) -- confirm
            v_metas = v_metas[end_indice, :]
            num_episode = len(end_indice)
            end_indice = np.insert(end_indice, 0, -1)
            max_episode_len = int(np.diff(end_indice).max())
            self.num_episode = num_episode
            self.max_episode_len = max_episode_len
            # map flattened sample i -> slot in the padded episode grid
            self.indices = np.arange(masks.shape[0])
            for i in range(1, num_episode):
                start_index = end_indice[i] + 1
                end_index = end_indice[i + 1] + 1
                self.indices[
                    start_index:end_index] += i * max_episode_len - start_index
            # context windows: v_margin frames before each episode start
            # (plus the episode itself when dynamic_v)
            self.cnn_feat_ctx = np.zeros(
                (self.v_margin + max_episode_len
                 if self.dynamic_v else self.v_margin,
                 num_episode, self.cnn_feat_dim))
            for i in range(num_episode):
                exp_ind, start_ind = v_metas[i, :]
                self.cnn_feat_ctx[:self.v_margin, i, :] = \
                    cnn_feat[exp_ind][start_ind - self.v_margin:start_ind]
            # NOTE(review): tensor/LongTensor/zeros appear to come from a
            # star import of torch elsewhere in the file -- confirm
            self.cnn_feat_ctx = tensor(self.cnn_feat_ctx, dtype=dtype,
                                       device=device)
            self.s_scatter_indices = LongTensor(
                np.tile(self.indices[:, None], (1, self.state_dim))).to(device)
            self.s_gather_indices = LongTensor(
                np.tile(self.indices[:, None], (1, self.s_hdim))).to(device)
            self.v_gather_indices = LongTensor(
                np.tile(self.indices[:, None], (1, self.v_hdim))).to(device)

    def forward(self, x):
        """Return concatenated (video feature, state feature) per sample."""
        if self.mode == 'test':
            if self.s_net_type == 'lstm':
                x = self.s_net(x)
            x = torch.cat((self.v_out, x), dim=1)
            self.t += 1
        elif self.mode == 'train':
            if self.dynamic_v:
                # one video feature per episode step
                v_ctx = self.forward_v_net(self.cnn_feat_ctx)[self.v_margin:]
            else:
                # a single per-episode feature, repeated for every step
                v_ctx = self.forward_v_net(self.cnn_feat_ctx)[[-1]]
                v_ctx = v_ctx.repeat(self.max_episode_len, 1, 1)
            # (len, ep, D) -> (ep*len, D), then pick rows for real samples
            v_ctx = v_ctx.transpose(0, 1).contiguous().view(-1, self.v_hdim)
            v_out = torch.gather(v_ctx, 0, self.v_gather_indices)
            if self.s_net_type == 'lstm':
                # scatter flattened states into the padded episode grid,
                # run the LSTM per episode, then gather back
                s_ctx = zeros(
                    (self.num_episode * self.max_episode_len, self.state_dim),
                    device=x.device)
                s_ctx.scatter_(0, self.s_scatter_indices, x)
                s_ctx = s_ctx.view(-1, self.max_episode_len,
                                   self.state_dim).transpose(0, 1).contiguous()
                s_ctx = self.s_net(s_ctx).transpose(0, 1).contiguous().view(
                    -1, self.s_hdim)
                s_out = torch.gather(s_ctx, 0, self.s_gather_indices)
            else:
                s_out = x
            x = torch.cat((v_out, s_out), dim=1)
        return x

    def forward_v_net(self, x):
        """Run the video net; TCN needs (batch, channel, time) layout."""
        if self.v_net_type == 'tcn':
            x = x.permute(1, 2, 0).contiguous()
        x = self.v_net(x)
        if self.v_net_type == 'tcn':
            x = x.permute(2, 0, 1).contiguous()
        return x