def __init__(self, params, parser=None):
        """Build the layers of a single-agent, VRNN-style recurrent latent model.

        Args:
            params: mapping that must contain the sizes 'y_dim', 'z_dim',
                'h_dim', 'rnn_dim' and 'n_layers'.
            parser: optional object forwarded to ``parse_model_params``
                (presumably an ``argparse.ArgumentParser`` -- confirm).

        NOTE(review): layer sizes are read from the raw ``params`` argument,
        not from ``self.params``; the two can only differ if
        ``parse_model_params`` returns a different mapping -- confirm.
        """
        super().__init__()

        self.model_args = ['y_dim', 'z_dim', 'h_dim', 'rnn_dim', 'n_layers']
        self.params = parse_model_params(self.model_args, params, parser)
        self.params_str = get_params_str(self.model_args, params)

        y_dim, z_dim, h_dim, rnn_dim, n_layers = (
            params[key]
            for key in ('y_dim', 'z_dim', 'h_dim', 'rnn_dim', 'n_layers'))

        def hidden_mlp(in_features):
            # Linear(in_features -> h_dim), ReLU, Linear(h_dim -> h_dim), ReLU.
            return nn.Sequential(
                nn.Linear(in_features, h_dim), nn.ReLU(),
                nn.Linear(h_dim, h_dim), nn.ReLU())

        def positive_head(out_features):
            # Linear(h_dim -> out_features) then Softplus, so outputs are > 0.
            return nn.Sequential(nn.Linear(h_dim, out_features), nn.Softplus())

        # Keep this assignment order: nn.Linear / nn.GRU initialise their
        # weights from the global RNG when constructed, and attribute order
        # fixes the order of registered parameters.

        # Encoder over an input of width y_dim + y_dim + rnn_dim.
        self.enc = hidden_mlp(y_dim + y_dim + rnn_dim)
        self.enc_mean = nn.Linear(h_dim, z_dim)
        self.enc_std = positive_head(z_dim)

        # Prior over an input of width y_dim + rnn_dim.
        self.prior = hidden_mlp(y_dim + rnn_dim)
        self.prior_mean = nn.Linear(h_dim, z_dim)
        self.prior_std = positive_head(z_dim)

        # Decoder over an input of width y_dim + z_dim + rnn_dim; its heads
        # emit y_dim values.
        self.dec = hidden_mlp(y_dim + z_dim + rnn_dim)
        self.dec_mean = nn.Linear(h_dim, y_dim)
        self.dec_std = positive_head(y_dim)

        # Recurrence over a (y_dim + z_dim)-wide per-step input.
        self.rnn = nn.GRU(y_dim + z_dim, rnn_dim, n_layers)
    def __init__(self, params, parser=None):
        """Build per-agent macro-level and micro-level (VRNN-style) layers.

        Each per-agent component is an ``nn.ModuleList`` holding ``n_agents``
        independently constructed modules; ``gru_macro`` is a single GRU
        whose input width is ``m_dim * n_agents``.

        Args:
            params: mapping with the sizes 'x_dim', 'y_dim', 'z_dim',
                'h_dim', 'm_dim', 'rnn_micro_dim', 'rnn_macro_dim',
                'n_layers' and 'n_agents'.
            parser: optional object forwarded to ``parse_model_params``
                (presumably an ``argparse.ArgumentParser`` -- confirm).

        NOTE(review): sizes are read from the raw ``params`` argument, not
        from ``self.params``; confirm ``parse_model_params`` returns the same
        mapping (or one with identical values).
        """
        super().__init__()

        self.model_args = [
            'x_dim', 'y_dim', 'z_dim', 'h_dim', 'm_dim',
            'rnn_micro_dim', 'rnn_macro_dim', 'n_layers', 'n_agents',
        ]
        self.params = parse_model_params(self.model_args, params, parser)
        self.params_str = get_params_str(self.model_args, params)

        (x_dim, y_dim, z_dim, h_dim, m_dim,
         rnn_micro_dim, rnn_macro_dim, n_layers, n_agents) = (
            params[key] for key in (
                'x_dim', 'y_dim', 'z_dim', 'h_dim', 'm_dim',
                'rnn_micro_dim', 'rnn_macro_dim', 'n_layers', 'n_agents'))

        def per_agent(build):
            # One freshly constructed module per agent, built in agent order.
            return nn.ModuleList([build() for _ in range(n_agents)])

        def hidden_mlp(in_features):
            # Linear(in_features -> h_dim), ReLU, Linear(h_dim -> h_dim), ReLU.
            return nn.Sequential(
                nn.Linear(in_features, h_dim), nn.ReLU(),
                nn.Linear(h_dim, h_dim), nn.ReLU())

        def positive_head(out_features):
            # Linear(h_dim -> out_features) then Softplus, so outputs are > 0.
            return nn.Sequential(nn.Linear(h_dim, out_features), nn.Softplus())

        def macro_head():
            # (y_dim + rnn_macro_dim)-wide input -> log-probabilities over
            # m_dim classes (LogSoftmax along the last dimension).
            return nn.Sequential(
                nn.Linear(y_dim + rnn_macro_dim, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, m_dim),
                nn.LogSoftmax(dim=-1))

        # Keep this assignment order: modules draw their initial weights from
        # the global RNG when constructed, and attribute order fixes the order
        # of registered parameters.
        self.dec_macro = per_agent(macro_head)

        # Encoder over an input of width x_dim + m_dim + rnn_micro_dim.
        self.enc = per_agent(lambda: hidden_mlp(x_dim + m_dim + rnn_micro_dim))
        self.enc_mean = per_agent(lambda: nn.Linear(h_dim, z_dim))
        self.enc_std = per_agent(lambda: positive_head(z_dim))

        # Prior over an input of width m_dim + rnn_micro_dim.
        self.prior = per_agent(lambda: hidden_mlp(m_dim + rnn_micro_dim))
        self.prior_mean = per_agent(lambda: nn.Linear(h_dim, z_dim))
        self.prior_std = per_agent(lambda: positive_head(z_dim))

        # Decoder over y_dim + m_dim + z_dim + rnn_micro_dim; heads emit x_dim.
        self.dec = per_agent(
            lambda: hidden_mlp(y_dim + m_dim + z_dim + rnn_micro_dim))
        self.dec_mean = per_agent(lambda: nn.Linear(h_dim, x_dim))
        self.dec_std = per_agent(lambda: positive_head(x_dim))

        # Per-agent micro GRUs and one macro GRU over all agents' m_dim slots.
        self.gru_micro = per_agent(
            lambda: nn.GRU(x_dim + z_dim, rnn_micro_dim, n_layers))
        self.gru_macro = nn.GRU(m_dim * n_agents, rnn_macro_dim, n_layers)
예제 #3
0
    def __init__(self, params, parser=None):
        """Build a shared encoder, per-agent decoders, a discrim head and a GRU.

        Args:
            params: mapping with the sizes 'x_dim', 'y_dim', 'z_dim',
                'h_dim', 'rnn_dim', 'n_layers' and 'n_agents'.
            parser: optional object forwarded to ``parse_model_params``
                (presumably an ``argparse.ArgumentParser`` -- confirm).

        NOTE(review): sizes are read from the raw ``params`` argument, not
        from ``self.params``; confirm ``parse_model_params`` returns the same
        mapping (or one with identical values).
        """
        super().__init__()

        self.model_args = [
            'x_dim', 'y_dim', 'z_dim', 'h_dim', 'rnn_dim', 'n_layers',
            'n_agents'
        ]
        self.params = parse_model_params(self.model_args, params, parser)
        self.params_str = get_params_str(self.model_args, params)

        (x_dim, y_dim, z_dim, h_dim,
         rnn_dim, n_layers, n_agents) = (
            params[key] for key in ('x_dim', 'y_dim', 'z_dim', 'h_dim',
                                    'rnn_dim', 'n_layers', 'n_agents'))

        def hidden_mlp(in_features):
            # Linear(in_features -> h_dim), ReLU, Linear(h_dim -> h_dim), ReLU.
            return nn.Sequential(
                nn.Linear(in_features, h_dim), nn.ReLU(),
                nn.Linear(h_dim, h_dim), nn.ReLU())

        def positive_head(out_features):
            # Linear(h_dim -> out_features) then Softplus, so outputs are > 0.
            return nn.Sequential(nn.Linear(h_dim, out_features), nn.Softplus())

        def per_agent(build):
            # One freshly constructed module per agent, built in agent order.
            return nn.ModuleList([build() for _ in range(n_agents)])

        # Keep this assignment order: modules draw their initial weights from
        # the global RNG when constructed, and attribute order fixes the order
        # of registered parameters.

        # Single encoder over an rnn_dim-wide input, with z_dim-sized heads.
        self.enc = hidden_mlp(rnn_dim)
        self.enc_mean = nn.Linear(h_dim, z_dim)
        self.enc_std = positive_head(z_dim)

        # Per-agent decoders over z_dim + rnn_dim; heads emit x_dim values.
        self.dec = per_agent(lambda: hidden_mlp(z_dim + rnn_dim))
        self.dec_mean = per_agent(lambda: nn.Linear(h_dim, x_dim))
        self.dec_std = per_agent(lambda: positive_head(x_dim))

        # "discrim" head: same shapes as the encoder, separate parameters.
        self.discrim = hidden_mlp(rnn_dim)
        self.discrim_mean = nn.Linear(h_dim, z_dim)
        self.discrim_std = positive_head(z_dim)

        # Recurrence over a y_dim-wide per-step input.
        self.rnn = nn.GRU(y_dim, rnn_dim, n_layers)