def __init__(self, nx, t_total, specs):
        """Build encoder/decoder submodules for the recurrent VAE.

        nx: per-frame feature dimension; t_total: number of frames the
        decoder generates; specs: hyper-parameter dict (defaults below).
        NOTE(review): submodule assignment order fixes nn.Module parameter
        registration order -- do not reorder.
        """
        super(VAErecV2, self).__init__()
        self.nx = nx
        # latent dimension (required key in specs)
        self.nz = nz = specs['nz']
        self.t_total = t_total
        self.rnn_type = rnn_type = specs.get('rnn_type', 'gru')
        self.e_birnn = e_birnn = specs.get('e_birnn', False)
        self.use_drnn_mlp = specs.get('use_drnn_mlp', True)
        self.nx_rnn = nx_rnn = specs.get('nx_rnn', 128)
        self.nh_mlp = nh_mlp = specs.get('nh_mlp', [300, 200])
        self.additive = specs.get('additive', False)
        # encode
        self.e_rnn = RNN(nx, nx_rnn, bi_dir=e_birnn, cell_type=rnn_type)
        self.e_mlp = MLP(nx_rnn, nh_mlp)
        self.e_mu = nn.Linear(self.e_mlp.out_dim, nz)
        self.e_logvar = nn.Linear(self.e_mlp.out_dim, nz)
        # decode
        if self.use_drnn_mlp:
            # maps the encoder state to the decoder RNN's initial state
            self.drnn_mlp = MLP(nx_rnn, nh_mlp + [nx_rnn], activation='relu')
        self.d_rnn = RNN(nx + nx_rnn, nx_rnn, cell_type=rnn_type)
        self.d_mlp = MLP(nx_rnn, nh_mlp)
        self.d_out = nn.Linear(self.d_mlp.out_dim, nx)
        # decoder RNN is driven one frame at a time
        self.d_rnn.set_mode('step')

        # head that predicts the initial pose when no seed frame is given
        self.init_pose_mlp = MLP(nx_rnn, nh_mlp, activation='relu')
        self.init_pose_out = nn.Linear(self.init_pose_mlp.out_dim, nx)
# Example #2
# 0
 def __init__(self, state_dim):
     """Build ERD-style encoder / stacked-LSTM / decoder layers.

     state_dim: input and output feature dimension.
     NOTE(review): submodule assignment order fixes nn.Module parameter
     registration order -- do not reorder.
     """
     super().__init__()
     self.state_dim = state_dim
     # two-stage encoder: ReLU MLP then linear projection to 500-d
     self.encoder_mlp = MLP(state_dim, (500, ), 'relu')
     self.encoder_linear = nn.Linear(500, 500)
     # stacked recurrent core (500 -> 1000 -> 1000)
     self.lstm1 = RNN(500, 1000, 'lstm')
     self.lstm2 = RNN(1000, 1000, 'lstm')
     # decoder maps the 1000-d recurrent state back to state_dim
     self.decoder_mlp = MLP(1000, (500, 100), 'relu')
     self.decoder_linear = nn.Linear(100, state_dim)
     # 'batch' = whole-sequence processing mode
     self.mode = 'batch'
    def __init__(self,
                 cnn_feat_dim,
                 state_dim,
                 v_hdim=128,
                 v_margin=10,
                 v_net_type='lstm',
                 v_net_param=None,
                 s_hdim=None,
                 s_net_type='id',
                 dynamic_v=False):
        """Build visual-context and state feature networks.

        cnn_feat_dim: per-frame CNN feature size; v_margin: number of
        context frames preceding each episode; s_net_type 'id' passes the
        state through unchanged.
        """
        super().__init__()
        # default the state feature size to the raw state dimension
        s_hdim = state_dim if s_hdim is None else s_hdim
        self.mode = 'test'
        self.cnn_feat_dim = cnn_feat_dim
        self.state_dim = state_dim
        self.v_net_type = v_net_type
        self.v_hdim = v_hdim
        self.v_margin = v_margin
        self.s_net_type = s_net_type
        self.s_hdim = s_hdim
        self.dynamic_v = dynamic_v
        # output is concatenated (visual, state) features
        self.out_dim = v_hdim + s_hdim

        if v_net_type == 'lstm':
            self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=False)
        elif v_net_type == 'tcn':
            if v_net_param is None:
                v_net_param = {}
            tcn_size = v_net_param.get('size', [64, 128])
            dropout = v_net_param.get('dropout', 0.2)
            kernel_size = v_net_param.get('kernel_size', 3)
            # last TCN channel count must equal the visual feature size
            assert tcn_size[-1] == v_hdim
            self.v_net = TemporalConvNet(cnn_feat_dim,
                                         tcn_size,
                                         kernel_size=kernel_size,
                                         dropout=dropout,
                                         causal=True)

        if s_net_type == 'lstm':
            self.s_net = RNN(state_dim, s_hdim, s_net_type, bi_dir=False)

        self.v_out = None
        self.t = 0
        # training only
        self.indices = None
        self.s_scatter_indices = None
        self.s_gather_indices = None
        self.v_gather_indices = None
        self.cnn_feat_ctx = None
        self.num_episode = None
        self.max_episode_len = None
        self.set_mode('test')
# Example #4
# 0
class ERDNet(nn.Module):
    """Encoder-Recurrent-Decoder network.

    An MLP+linear encoder feeds two stacked LSTMs whose output is decoded
    back to ``state_dim``.  In 'batch' mode the input is a (time, batch,
    feature) sequence that is flattened for the feed-forward stages and
    reshaped back around the recurrent core; in 'step' mode inputs are
    processed as-is, one step at a time.
    """

    def __init__(self, state_dim):
        """state_dim: input and output feature dimension."""
        super().__init__()
        self.state_dim = state_dim
        # feed-forward encoder: ReLU MLP followed by a linear projection
        self.encoder_mlp = MLP(state_dim, (500, ), 'relu')
        self.encoder_linear = nn.Linear(500, 500)
        # stacked recurrent core
        self.lstm1 = RNN(500, 1000, 'lstm')
        self.lstm2 = RNN(1000, 1000, 'lstm')
        # feed-forward decoder back to the state dimension
        self.decoder_mlp = MLP(1000, (500, 100), 'relu')
        self.decoder_linear = nn.Linear(100, state_dim)
        self.mode = 'batch'

    def initialize(self, mode):
        """Set the operating mode and reset both LSTMs' hidden state."""
        self.mode = mode
        for lstm in (self.lstm1, self.lstm2):
            lstm.set_mode(mode)
        for lstm in (self.lstm1, self.lstm2):
            lstm.initialize()

    def forward(self, x):
        """Encode, run the recurrent core, and decode the input."""
        batched = self.mode == 'batch'
        if batched:
            # remember the batch size, then flatten (T, B, F) -> (T*B, F)
            n_batch = x.shape[1]
            h = x.view(-1, x.shape[-1])
        else:
            h = x
        h = self.encoder_linear(self.encoder_mlp(h))
        if batched:
            # restore the time axis for the recurrent layers
            h = h.view(-1, n_batch, h.shape[-1])
        h = self.lstm2(self.lstm1(h))
        if batched:
            # flatten again for the feed-forward decoder
            h = h.view(-1, h.shape[-1])
        return self.decoder_linear(self.decoder_mlp(h))
# Example #5
# 0
    def __init__(self,
                 out_dim,
                 v_hdim,
                 cnn_fdim,
                 no_cnn=False,
                 frame_shape=(3, 224, 224),
                 mlp_dim=(300, 200),
                 cnn_type='resnet',
                 v_net_type='lstm',
                 v_net_param=None,
                 cnn_rs=True,
                 causal=False):
        """Build per-frame CNN + temporal aggregator + MLP output head.

        no_cnn skips the CNN (precomputed features are expected); causal
        selects a unidirectional LSTM / causal TCN instead of a
        bidirectional / acausal one.
        """
        super().__init__()
        self.out_dim = out_dim
        self.cnn_fdim = cnn_fdim
        self.v_hdim = v_hdim
        self.no_cnn = no_cnn
        self.frame_shape = frame_shape
        if no_cnn:
            self.cnn = None
        elif cnn_type == 'resnet':
            # cnn_rs toggles batch-norm running statistics
            self.cnn = ResNet(cnn_fdim, running_stats=cnn_rs)
        elif cnn_type == 'mobile':
            self.cnn = MobileNet(cnn_fdim)

        self.v_net_type = v_net_type
        if v_net_type == 'lstm':
            self.v_net = RNN(cnn_fdim, v_hdim, v_net_type, bi_dir=not causal)
        elif v_net_type == 'tcn':
            if v_net_param is None:
                v_net_param = {}
            tcn_size = v_net_param.get('size', [64, 128])
            dropout = v_net_param.get('dropout', 0.2)
            kernel_size = v_net_param.get('kernel_size', 3)
            # last TCN channel count must equal the temporal feature size
            assert tcn_size[-1] == v_hdim
            self.v_net = TemporalConvNet(cnn_fdim,
                                         tcn_size,
                                         kernel_size=kernel_size,
                                         dropout=dropout,
                                         causal=causal)
        self.mlp = MLP(v_hdim, mlp_dim, 'relu')
        self.linear = nn.Linear(self.mlp.out_dim, out_dim)
 def __init__(self,
              cnn_feat_dim,
              v_hdim=128,
              v_margin=10,
              v_net_type='lstm',
              v_net_param=None,
              causal=False):
     """Build a temporal network over precomputed CNN features.

     v_margin: number of context frames preceding an episode; causal
     selects a unidirectional LSTM / causal TCN.
     """
     super().__init__()
     self.mode = 'test'
     self.cnn_feat_dim = cnn_feat_dim
     self.v_net_type = v_net_type
     self.v_hdim = v_hdim
     self.v_margin = v_margin
     if v_net_type == 'lstm':
         self.v_net = RNN(cnn_feat_dim,
                          v_hdim,
                          v_net_type,
                          bi_dir=not causal)
     elif v_net_type == 'tcn':
         if v_net_param is None:
             v_net_param = {}
         tcn_size = v_net_param.get('size', [64, 128])
         dropout = v_net_param.get('dropout', 0.2)
         kernel_size = v_net_param.get('kernel_size', 3)
         # last TCN channel count must equal the visual feature size
         assert tcn_size[-1] == v_hdim
         self.v_net = TemporalConvNet(cnn_feat_dim,
                                      tcn_size,
                                      kernel_size=kernel_size,
                                      dropout=dropout,
                                      causal=causal)
     self.v_out = None
     self.t = 0
     # training only
     self.indices = None
     self.scatter_indices = None
     self.gather_indices = None
     self.cnn_feat_ctx = None
class VAErecV2(nn.Module):
    """Recurrent VAE over fixed-length sequences.

    Encodes a time-major sequence x (T x batch x nx) into a latent z of
    size nz, then autoregressively decodes t_total frames.  When no seed
    frame is supplied, the first frame is predicted from z itself via a
    dedicated init-pose head.
    """

    def __init__(self, nx, t_total, specs):
        """nx: per-frame feature size; t_total: frames to decode;
        specs: hyper-parameter dict (defaults below)."""
        super(VAErecV2, self).__init__()
        self.nx = nx
        self.nz = nz = specs['nz']
        self.t_total = t_total
        self.rnn_type = rnn_type = specs.get('rnn_type', 'gru')
        self.e_birnn = e_birnn = specs.get('e_birnn', False)
        self.use_drnn_mlp = specs.get('use_drnn_mlp', True)
        self.nx_rnn = nx_rnn = specs.get('nx_rnn', 128)
        self.nh_mlp = nh_mlp = specs.get('nh_mlp', [300, 200])
        self.additive = specs.get('additive', False)
        # encode
        self.e_rnn = RNN(nx, nx_rnn, bi_dir=e_birnn, cell_type=rnn_type)
        self.e_mlp = MLP(nx_rnn, nh_mlp)
        self.e_mu = nn.Linear(self.e_mlp.out_dim, nz)
        self.e_logvar = nn.Linear(self.e_mlp.out_dim, nz)
        # decode
        if self.use_drnn_mlp:
            # maps the encoder state to the decoder RNN's initial state
            self.drnn_mlp = MLP(nx_rnn, nh_mlp + [nx_rnn], activation='relu')
        self.d_rnn = RNN(nx + nx_rnn, nx_rnn, cell_type=rnn_type)
        self.d_mlp = MLP(nx_rnn, nh_mlp)
        self.d_out = nn.Linear(self.d_mlp.out_dim, nx)
        # decoder RNN is driven one frame at a time in decode()
        self.d_rnn.set_mode('step')

        # init-pose head: predicts the first frame from z
        # NOTE(review): init_pose_mlp is built with nx_rnn input size but
        # decode() feeds it z (nz-dim); this assumes nz == nx_rnn -- confirm.
        self.init_pose_mlp = MLP(nx_rnn, nh_mlp, activation='relu')
        self.init_pose_out = nn.Linear(self.init_pose_mlp.out_dim, nx)

    def encode_x(self, x):
        """Summarize sequence x into one hidden vector per batch item."""
        if self.e_birnn:
            # bidirectional encoder: average hidden states over time
            h_x = self.e_rnn(x).mean(dim=0)
        else:
            # unidirectional encoder: take the final hidden state
            h_x = self.e_rnn(x)[-1]
        return h_x

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior q(z|x)."""
        h_x = self.encode_x(x)
        h = self.e_mlp(h_x)
        return self.e_mu(h), self.e_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, x_p=None):
        """Autoregressively decode t_total frames conditioned on z.

        x_p: optional seed frame; when None the first frame is predicted
        from z by the init-pose head and fed back in.
        """
        # FIX: was `x_p == None`, which triggers elementwise tensor
        # comparison when a tensor is passed; an identity check is intended.
        if x_p is None:
            h_init_pose = self.init_pose_mlp(z)
            x = self.init_pose_out(h_init_pose)
            x_p = x  # feed the predicted first frame back as input

        self.d_rnn.initialize(batch_size=z.shape[0])
        x_rec = []

        for i in range(self.t_total):
            # condition each step on the previous frame and the latent
            rnn_in = torch.cat([x_p, z], dim=1)
            h = self.d_rnn(rnn_in)
            h = self.d_mlp(h)
            x_i = self.d_out(h)
            x_rec.append(x_i)
            x_p = x_i
        x_rec = torch.stack(x_rec)

        return x_rec

    def forward(self, x):
        """Encode x, sample z (posterior mean at eval time), and decode."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar) if self.training else mu
        return self.decode(z), mu, logvar

    def sample_prior(self, x):
        """Decode from z ~ N(0, I); x only supplies batch size and device."""
        z = torch.randn((x.shape[1], self.nz), device=x.device)
        return self.decode(z)

    def step(self, model):
        """No-op hook kept for interface compatibility."""
        pass
class VideoForecastNet(nn.Module):
    """Fuses a visual context feature with the current state feature.

    'test' mode consumes one step at a time using a fixed visual summary;
    'train' mode operates on a flattened batch of episodes and uses
    precomputed scatter/gather index maps (built in initialize) to move
    between the flat layout and a padded (episode, time) layout.
    Output size is v_hdim + s_hdim.
    """

    def __init__(self,
                 cnn_feat_dim,
                 state_dim,
                 v_hdim=128,
                 v_margin=10,
                 v_net_type='lstm',
                 v_net_param=None,
                 s_hdim=None,
                 s_net_type='id',
                 dynamic_v=False):
        """cnn_feat_dim: per-frame CNN feature size; v_margin: number of
        context frames preceding each episode; s_net_type 'id' passes the
        state through unchanged; dynamic_v recomputes the visual feature
        at every step instead of once per episode.
        """
        super().__init__()
        # default the state feature size to the raw state dimension
        s_hdim = state_dim if s_hdim is None else s_hdim
        self.mode = 'test'
        self.cnn_feat_dim = cnn_feat_dim
        self.state_dim = state_dim
        self.v_net_type = v_net_type
        self.v_hdim = v_hdim
        self.v_margin = v_margin
        self.s_net_type = s_net_type
        self.s_hdim = s_hdim
        self.dynamic_v = dynamic_v
        self.out_dim = v_hdim + s_hdim

        if v_net_type == 'lstm':
            self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=False)
        elif v_net_type == 'tcn':
            if v_net_param is None:
                v_net_param = {}
            tcn_size = v_net_param.get('size', [64, 128])
            dropout = v_net_param.get('dropout', 0.2)
            kernel_size = v_net_param.get('kernel_size', 3)
            # last TCN channel count must equal the visual feature size
            assert tcn_size[-1] == v_hdim
            self.v_net = TemporalConvNet(cnn_feat_dim,
                                         tcn_size,
                                         kernel_size=kernel_size,
                                         dropout=dropout,
                                         causal=True)

        if s_net_type == 'lstm':
            self.s_net = RNN(state_dim, s_hdim, s_net_type, bi_dir=False)

        self.v_out = None
        self.t = 0
        # training only
        self.indices = None
        self.s_scatter_indices = None
        self.s_gather_indices = None
        self.v_gather_indices = None
        self.cnn_feat_ctx = None
        self.num_episode = None
        self.max_episode_len = None
        self.set_mode('test')

    def set_mode(self, mode):
        """Switch between 'train' (whole-batch) and step-wise operation."""
        self.mode = mode
        if self.s_net_type == 'lstm':
            if mode == 'train':
                self.s_net.set_mode('batch')
            else:
                self.s_net.set_mode('step')

    def initialize(self, x):
        """Prepare per-episode context before forward passes.

        test: x is a sequence of CNN features; the first v_margin frames
        are summarized into self.v_out.
        train: x is (masks, cnn_feat, v_metas) for a flattened rollout
        batch; builds the index maps and the CNN-feature context tensor.
        NOTE(review): masks == 0 appears to mark the last step of each
        episode, and v_metas rows look like (expert index, start frame);
        confirm against the caller.
        """
        if self.mode == 'test':
            self.v_out = self.forward_v_net(x.unsqueeze(1)[:self.v_margin])[-1]
            if self.s_net_type == 'lstm':
                self.s_net.initialize()
            self.t = 0
        elif self.mode == 'train':
            masks, cnn_feat, v_metas = x
            device, dtype = masks.device, masks.dtype
            # episode end positions in the flat batch
            end_indice = np.where(masks.cpu().numpy() == 0)[0]
            v_metas = v_metas[end_indice, :]
            num_episode = len(end_indice)
            # prepend -1 so consecutive differences give episode lengths
            end_indice = np.insert(end_indice, 0, -1)
            max_episode_len = int(np.diff(end_indice).max())
            self.num_episode = num_episode
            self.max_episode_len = max_episode_len
            # map each flat index to its slot in the padded
            # (episode, time) layout
            self.indices = np.arange(masks.shape[0])
            for i in range(1, num_episode):
                start_index = end_indice[i] + 1
                end_index = end_indice[i + 1] + 1
                self.indices[
                    start_index:end_index] += i * max_episode_len - start_index
            # v_margin context frames per episode, plus the episode frames
            # themselves when the visual feature is recomputed every step
            self.cnn_feat_ctx = np.zeros(
                (self.v_margin +
                 max_episode_len if self.dynamic_v else self.v_margin,
                 num_episode, self.cnn_feat_dim))
            for i in range(num_episode):
                exp_ind, start_ind = v_metas[i, :]
                self.cnn_feat_ctx[:self.v_margin,
                                  i, :] = cnn_feat[exp_ind][start_ind - self.
                                                            v_margin:start_ind]
            # NOTE(review): tensor / LongTensor / zeros are presumably
            # torch aliases imported at the top of the file -- confirm.
            self.cnn_feat_ctx = tensor(self.cnn_feat_ctx,
                                       dtype=dtype,
                                       device=device)
            self.s_scatter_indices = LongTensor(
                np.tile(self.indices[:, None], (1, self.state_dim))).to(device)
            self.s_gather_indices = LongTensor(
                np.tile(self.indices[:, None], (1, self.s_hdim))).to(device)
            self.v_gather_indices = LongTensor(
                np.tile(self.indices[:, None], (1, self.v_hdim))).to(device)

    def forward(self, x):
        """Return concatenated (visual, state) features per input row."""
        if self.mode == 'test':
            if self.s_net_type == 'lstm':
                x = self.s_net(x)
            x = torch.cat((self.v_out, x), dim=1)
            self.t += 1
        elif self.mode == 'train':
            if self.dynamic_v:
                # recompute the visual feature at every step of each episode
                v_ctx = self.forward_v_net(self.cnn_feat_ctx)[self.v_margin:]
            else:
                # one visual feature per episode, repeated over time
                v_ctx = self.forward_v_net(self.cnn_feat_ctx)[[-1]]
                v_ctx = v_ctx.repeat(self.max_episode_len, 1, 1)
            v_ctx = v_ctx.transpose(0, 1).contiguous().view(-1, self.v_hdim)
            v_out = torch.gather(v_ctx, 0, self.v_gather_indices)
            if self.s_net_type == 'lstm':
                # scatter flat states into the padded layout, run the LSTM
                # over the time axis, then gather back to the flat layout
                s_ctx = zeros(
                    (self.num_episode * self.max_episode_len, self.state_dim),
                    device=x.device)
                s_ctx.scatter_(0, self.s_scatter_indices, x)
                s_ctx = s_ctx.view(-1, self.max_episode_len,
                                   self.state_dim).transpose(0,
                                                             1).contiguous()
                s_ctx = self.s_net(s_ctx).transpose(0, 1).contiguous().view(
                    -1, self.s_hdim)
                s_out = torch.gather(s_ctx, 0, self.s_gather_indices)
            else:
                s_out = x
            x = torch.cat((v_out, s_out), dim=1)
        return x

    def forward_v_net(self, x):
        """Run the temporal net; a TCN expects (batch, channel, time)."""
        if self.v_net_type == 'tcn':
            x = x.permute(1, 2, 0).contiguous()
        x = self.v_net(x)
        if self.v_net_type == 'tcn':
            x = x.permute(2, 0, 1).contiguous()
        return x