Ejemplo n.º 1
0
class VAE(nn.Module):
    """Conditional recurrent VAE: encodes (x, y) to a latent z and decodes y step by step.

    ``x`` and ``y`` are run through separate ``RNN`` encoders; their summaries are
    fused by an MLP into the mean and log-variance of ``z``.  The decoder is an
    ``RNN`` in 'step' mode that is fed ``[h_x, z, previous output]`` at every step
    for ``horizon`` steps.

    NOTE(review): sequences look time-major (axis 0 = time, axis 1 = batch), since
    ``x[-1]`` is used as the first decoder input and ``x.shape[1]`` as the number of
    prior samples -- confirm against ``RNN``.
    """

    def __init__(self, nx, ny, nz, horizon, specs):
        """Build encoder/decoder sub-networks.

        Args:
            nx: feature size given to the ``x`` encoder RNN.
            ny: feature size given to the ``y`` encoder RNN and produced by the decoder.
            nz: latent dimension.
            horizon: number of decoding steps.
            specs: dict of optional settings (``rnn_type``, ``x_birnn``, ``e_birnn``,
                ``use_drnn_mlp``, ``nh_rnn``, ``nh_mlp``); missing keys use the
                defaults below.
        """
        super().__init__()
        self.nx, self.ny, self.nz = nx, ny, nz
        self.horizon = horizon

        # Optional hyper-parameters.
        rnn_type = specs.get('rnn_type', 'lstm')
        x_birnn = specs.get('x_birnn', True)
        e_birnn = specs.get('e_birnn', True)
        nh_rnn = specs.get('nh_rnn', 128)
        nh_mlp = specs.get('nh_mlp', [300, 200])
        self.rnn_type = rnn_type
        self.x_birnn = x_birnn
        self.e_birnn = e_birnn
        self.use_drnn_mlp = specs.get('use_drnn_mlp', False)
        self.nh_rnn = nh_rnn
        self.nh_mlp = nh_mlp

        # Encoder: one RNN per input stream, then an MLP over the concatenated summaries.
        self.x_rnn = RNN(nx, nh_rnn, bi_dir=x_birnn, cell_type=rnn_type)
        self.e_rnn = RNN(ny, nh_rnn, bi_dir=e_birnn, cell_type=rnn_type)
        self.e_mlp = MLP(2 * nh_rnn, nh_mlp)
        self.e_mu = nn.Linear(self.e_mlp.out_dim, nz)
        self.e_logvar = nn.Linear(self.e_mlp.out_dim, nz)

        # Decoder: optional MLP producing the initial hidden state, then a step-wise RNN.
        if self.use_drnn_mlp:
            self.drnn_mlp = MLP(nh_rnn, nh_mlp + [nh_rnn], activation='tanh')
        self.d_rnn = RNN(ny + nz + nh_rnn, nh_rnn, cell_type=rnn_type)
        self.d_mlp = MLP(nh_rnn, nh_mlp)
        self.d_out = nn.Linear(self.d_mlp.out_dim, ny)
        self.d_rnn.set_mode('step')

    def encode_x(self, x):
        """Summarise ``x``: mean over axis 0 of the RNN output if bidirectional, else its last entry."""
        if not self.x_birnn:
            return self.x_rnn(x)[-1]
        return self.x_rnn(x).mean(dim=0)

    def encode_y(self, y):
        """Summarise ``y``: mean over axis 0 of the RNN output if bidirectional, else its last entry."""
        if not self.e_birnn:
            return self.e_rnn(y)[-1]
        return self.e_rnn(y).mean(dim=0)

    def encode(self, x, y):
        """Return ``(mu, logvar)`` of the latent computed from both sequences."""
        summary = torch.cat((self.encode_x(x), self.encode_y(y)), dim=1)
        feat = self.e_mlp(summary)
        mu = self.e_mu(feat)
        logvar = self.e_logvar(feat)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x, z):
        """Roll the decoder out for ``horizon`` steps and return the stacked outputs.

        The first step is fed ``x[-1]`` as the previous output; later steps are fed
        the decoder's own previous prediction.
        """
        h_x = self.encode_x(x)
        batch = z.shape[0]
        if self.use_drnn_mlp:
            # Initial hidden state derived from the x summary.
            self.d_rnn.initialize(batch_size=batch, hx=self.drnn_mlp(h_x))
        else:
            self.d_rnn.initialize(batch_size=batch)

        outputs = []
        prev = None
        for _ in range(self.horizon):
            y_p = x[-1] if prev is None else prev
            step_in = torch.cat([h_x, z, y_p], dim=1)
            hidden = self.d_mlp(self.d_rnn(step_in))
            prev = self.d_out(hidden)
            outputs.append(prev)
        return torch.stack(outputs)

    def forward(self, x, y):
        """Return ``(decoded y, mu, logvar)``; z is sampled in training mode, else set to ``mu``."""
        mu, logvar = self.encode(x, y)
        if self.training:
            z = self.reparameterize(mu, logvar)
        else:
            z = mu
        return self.decode(x, z), mu, logvar

    def sample_prior(self, x):
        """Decode from a latent drawn from a standard normal, one sample per ``x.shape[1]`` entry."""
        n_samples = x.shape[1]
        z = torch.randn((n_samples, self.nz), device=x.device)
        return self.decode(x, z)
Ejemplo n.º 2
0
class VideoForecastNet(nn.Module):
    """Combines a video-context feature with a (optionally recurrent) state feature.

    A video network (``RNN`` or ``TemporalConvNet``) is run over ``v_margin`` frames of
    CNN features to obtain a context vector; that vector is concatenated with the
    state input (passed through an ``RNN`` when ``s_net_type == 'lstm'``, unchanged
    otherwise).  The module has two modes:

    * ``'test'``: one step at a time; ``initialize`` caches the context and
      ``forward`` appends it to each incoming state.
    * ``'train'``: a flat batch holding several episodes; ``initialize`` builds index
      tensors that map the flat batch into a padded (episode, step) layout and back.
    """

    def __init__(self, cnn_feat_dim, state_dim, v_hdim=128, v_margin=10, v_net_type='lstm', v_net_param=None,
                 s_hdim=None, s_net_type='id', dynamic_v=False):
        """Create the video and state sub-networks.

        Args:
            cnn_feat_dim: size of one CNN feature vector.
            state_dim: size of one state vector.
            v_hdim: output size of the video network.
            v_margin: number of context frames fed to the video network.
            v_net_type: ``'lstm'`` or ``'tcn'``; any other value creates no ``v_net``.
            v_net_param: optional dict for the TCN (``size``, ``dropout``, ``kernel_size``).
            s_hdim: output size of the state network; defaults to ``state_dim``.
            s_net_type: ``'lstm'`` to create a state RNN; otherwise states pass through.
            dynamic_v: in train mode, allocate context for every step instead of one vector.
        """
        super().__init__()
        if s_hdim is None:
            s_hdim = state_dim
        self.mode = 'test'
        self.cnn_feat_dim, self.state_dim = cnn_feat_dim, state_dim
        self.v_net_type, self.v_hdim, self.v_margin = v_net_type, v_hdim, v_margin
        self.s_net_type, self.s_hdim = s_net_type, s_hdim
        self.dynamic_v = dynamic_v
        self.out_dim = v_hdim + s_hdim

        # Video network.
        if v_net_type == 'lstm':
            self.v_net = RNN(cnn_feat_dim, v_hdim, v_net_type, bi_dir=False)
        elif v_net_type == 'tcn':
            tcn_cfg = {} if v_net_param is None else v_net_param
            channels = tcn_cfg.get('size', [64, 128])
            drop_p = tcn_cfg.get('dropout', 0.2)
            k_size = tcn_cfg.get('kernel_size', 3)
            # The last TCN layer must produce v_hdim channels.
            assert channels[-1] == v_hdim
            self.v_net = TemporalConvNet(cnn_feat_dim, channels, kernel_size=k_size, dropout=drop_p, causal=True)

        # State network (only for the recurrent variant).
        if s_net_type == 'lstm':
            self.s_net = RNN(state_dim, s_hdim, s_net_type, bi_dir=False)

        # Step-mode state.
        self.v_out = None
        self.t = 0
        # Batch-mode (training) state, filled in by initialize().
        self.indices = None
        self.s_scatter_indices = None
        self.s_gather_indices = None
        self.v_gather_indices = None
        self.cnn_feat_ctx = None
        self.num_episode = None
        self.max_episode_len = None
        self.set_mode('test')

    def set_mode(self, mode):
        """Record ``mode`` and switch the state RNN to 'batch' (train) or 'step' (anything else)."""
        self.mode = mode
        if self.s_net_type != 'lstm':
            return
        self.s_net.set_mode('batch' if mode == 'train' else 'step')

    def initialize(self, x):
        """Prepare cached context for the current mode.

        In ``'test'`` mode ``x`` is a tensor of CNN features; its first ``v_margin``
        entries are run through the video network and the last output is cached.
        In ``'train'`` mode ``x`` is a ``(masks, cnn_feat, v_metas)`` triple; positions
        where ``masks == 0`` are treated as episode ends.
        """
        if self.mode == 'test':
            ctx_frames = x.unsqueeze(1)[:self.v_margin]
            self.v_out = self.forward_v_net(ctx_frames)[-1]
            if self.s_net_type == 'lstm':
                self.s_net.initialize()
            self.t = 0
            return
        if self.mode != 'train':
            return

        masks, cnn_feat, v_metas = x
        device, dtype = masks.device, masks.dtype

        # Episode boundaries: a zero mask marks the last step of an episode.
        episode_ends = np.where(masks.cpu().numpy() == 0)[0]
        v_metas = v_metas[episode_ends, :]
        n_ep = len(episode_ends)
        bounds = np.insert(episode_ends, 0, -1)
        max_len = int(np.diff(bounds).max())
        self.num_episode = n_ep
        self.max_episode_len = max_len

        # Map each flat position to its slot in a (n_ep * max_len) padded layout.
        self.indices = np.arange(masks.shape[0])
        for ep, (prev_end, cur_end) in enumerate(zip(bounds[1:-1], bounds[2:]), start=1):
            first = prev_end + 1
            last = cur_end + 1
            self.indices[first:last] += ep * max_len - first

        # Context frames per episode; only the first v_margin rows are filled here.
        if self.dynamic_v:
            ctx_len = self.v_margin + max_len
        else:
            ctx_len = self.v_margin
        ctx = np.zeros((ctx_len, n_ep, self.cnn_feat_dim))
        for ep in range(n_ep):
            exp_ind, start_ind = v_metas[ep, :]
            ctx[:self.v_margin, ep, :] = cnn_feat[exp_ind][start_ind - self.v_margin: start_ind]
        self.cnn_feat_ctx = tensor(ctx, dtype=dtype, device=device)

        # Index tensors broadcast across the feature axis for scatter_/gather.
        column = self.indices[:, None]

        def tiled(width):
            return LongTensor(np.tile(column, (1, width))).to(device)

        self.s_scatter_indices = tiled(self.state_dim)
        self.s_gather_indices = tiled(self.s_hdim)
        self.v_gather_indices = tiled(self.v_hdim)

    def forward(self, x):
        """Return the video context concatenated with the state feature along dim 1.

        For a mode other than ``'test'`` or ``'train'`` the input is returned unchanged.
        """
        if self.mode == 'test':
            if self.s_net_type == 'lstm':
                x = self.s_net(x)
            out = torch.cat((self.v_out, x), dim=1)
            self.t += 1
            return out
        if self.mode != 'train':
            return x

        # Video context laid out as (n_ep * max_len, v_hdim), then gathered per flat position.
        v_seq = self.forward_v_net(self.cnn_feat_ctx)
        if self.dynamic_v:
            v_ctx = v_seq[self.v_margin:]
        else:
            v_ctx = v_seq[[-1]].repeat(self.max_episode_len, 1, 1)
        v_flat = v_ctx.transpose(0, 1).contiguous().view(-1, self.v_hdim)
        v_out = torch.gather(v_flat, 0, self.v_gather_indices)

        if self.s_net_type == 'lstm':
            # Scatter states into the padded layout, run the RNN, gather back.
            padded = zeros((self.num_episode * self.max_episode_len, self.state_dim), device=x.device)
            padded.scatter_(0, self.s_scatter_indices, x)
            seq_first = padded.view(-1, self.max_episode_len, self.state_dim).transpose(0, 1).contiguous()
            hidden = self.s_net(seq_first)
            s_flat = hidden.transpose(0, 1).contiguous().view(-1, self.s_hdim)
            s_out = torch.gather(s_flat, 0, self.s_gather_indices)
        else:
            s_out = x
        return torch.cat((v_out, s_out), dim=1)

    def forward_v_net(self, x):
        """Run the video network; for the TCN, permute axes (1, 2, 0) before and (2, 0, 1) after."""
        if self.v_net_type != 'tcn':
            return self.v_net(x)
        tcn_in = x.permute(1, 2, 0).contiguous()
        return self.v_net(tcn_in).permute(2, 0, 1).contiguous()