Example #1
    def fit(self, train_loader, dim=None, test_loader=None):
        '''Train the model on the provided data loader.

        Arguments:
            train_loader (DataLoader): the training data
            dim (int): the target dimensionality
            test_loader (DataLoader): the data for validation
        '''
        # Estimate means and (cross-)covariance matrices from the training data
        self.x_mean, self.y_mean = _get_mean(train_loader)
        self.cxx, self.cxy, self.cyy = _get_covariance(
            train_loader, self.x_mean, self.y_mean)
        if self.symmetrize:
            # Enforce cxx == cyy and a symmetric cross-covariance
            self.cxx = 0.5 * (self.cxx + self.cyy)
            self.cyy.copy_(self.cxx)
            self.cxy = 0.5 * (self.cxy + self.cxy.t())
        self.transformer = _Transform(
            x_mean=self.x_mean, x_covariance=self.cxx,
            y_mean=self.y_mean, y_covariance=self.cyy)
        # Whitening matrices taken from the fitted transform
        self.ixx = self.transformer.x.mul
        self.iyy = self.transformer.y.mul
        # Singular value decomposition of the whitened cross-covariance
        u, s, v = _svd(self.ixx.mm(self.cxy.mm(self.iyy)))
        if dim is None:
            dim = s.size()[0]
        self.decoder_matrix = v[:, :dim]
        self.encoder_matrix = u.t()[:dim, :]
        if self.kinetic_map:
            # Fold the singular values into the encoder (kinetic map scaling) ...
            self.encoder_matrix = _diag(s[:dim]).mm(self.encoder_matrix)
        else:
            # ... otherwise fold them into the decoder
            self.decoder_matrix = self.decoder_matrix.mm(_diag(s[:dim]))
        # Low-rank approximation of the (whitened) Koopman matrix
        self.koopman_matrix = _Variable(
            self.decoder_matrix.mm(self.encoder_matrix))
        return self.get_loss(train_loader), self.get_loss(test_loader)
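For reference, the decomposition carried out above can be written as a stand-alone routine: whiten the cross-covariance, take its truncated SVD, and contract the factors into a low-rank Koopman matrix. The sketch below assumes the covariance matrices are given; the inverse-square-root whitening helper is a hypothetical stand-in for whatever the fitted _Transform stores in its mul attributes.

import torch

def koopman_from_covariances(cxx, cxy, cyy, dim, kinetic_map=True):
    # Hypothetical whitening helper: inverse matrix square root of a
    # symmetric positive-definite covariance matrix.
    def inv_sqrt(c):
        eigvals, eigvecs = torch.linalg.eigh(c)
        return eigvecs @ torch.diag(eigvals.rsqrt()) @ eigvecs.t()

    ixx, iyy = inv_sqrt(cxx), inv_sqrt(cyy)
    # SVD of the whitened cross-covariance
    u, s, vh = torch.linalg.svd(ixx @ cxy @ iyy, full_matrices=False)
    encoder = u.t()[:dim, :]
    decoder = vh.t()[:, :dim]
    if kinetic_map:
        encoder = torch.diag(s[:dim]) @ encoder   # scale the encoder by the singular values
    else:
        decoder = decoder @ torch.diag(s[:dim])   # or scale the decoder instead
    return decoder @ encoder                      # low-rank Koopman matrix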
Example #2
 def _reparameterize(self, mu, lv):
     if self.training:
         std = lv.mul(0.5).exp_()
         eps = _Variable(_randn(*std.size()))
         if self.use_cuda:
             eps = eps.cuda()
         return eps.mul(std).add_(mu)
     else:
         return mu
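The same reparameterization trick (z = mu + sigma * eps with eps ~ N(0, I)) can be written against the current torch API, where randn_like replaces the _Variable/_randn pair and also handles device placement. Names in this sketch are illustrative.

import torch

def reparameterize(mu, logvar, training=True):
    # During training, sample z = mu + sigma * eps with eps ~ N(0, I);
    # at evaluation time, return the mean deterministically.
    if not training:
        return mu
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)   # same shape, dtype and device as std
    return mu + eps * std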
Example #3
 def forward(self, m, h, use_message):
     if use_message:
         debuglogger.debug(f'Using message')
         return self.rnn(m, h)
     else:
         debuglogger.debug(f'Ignoring message, using blank instead...')
         blank_msg = _Variable(torch.zeros_like(m.data))
         if self.use_cuda:
             blank_msg = blank_msg.cuda()
         return self.rnn(blank_msg, h)
Example #4
 def __call__(self, x, variable=False, **kwargs):
     try:
         x.sub_(self.sub[None, :])
     except AttributeError:
         pass
     try:
         x = x.mm(self.mul)
     except AttributeError:
         pass
     if variable:
         return _Variable(x, **kwargs)
     return x
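The try/except AttributeError blocks above make both steps optional: if the transform object has no sub (centering) or mul (linear map) attribute, that step is silently skipped. A minimal sketch of the same behaviour with explicit checks; the class and attribute names are illustrative, and the centering is done out of place rather than with sub_.

import torch

class AffineTransform:
    def __init__(self, sub=None, mul=None):
        self.sub = sub   # optional offset to subtract, shape (d,)
        self.mul = mul   # optional linear map to apply, shape (d, k)

    def __call__(self, x):
        if self.sub is not None:
            x = x - self.sub[None, :]   # center each row
        if self.mul is not None:
            x = x.mm(self.mul)          # project with the linear map
        return x

# Usage: center a batch and project it (identity projection for illustration)
x = torch.randn(32, 8)
transform = AffineTransform(sub=x.mean(0), mul=torch.eye(8))
z = transform(x)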
Example #5
 def forward(self, y_scores, h_c, desc, training):
     '''
         desc = \sum_i y_scores_i desc_i
         w_hat = tanh(W_h h_c + W_d desc)
         w = bernoulli(sig(w_hat)) or round(sig(w_hat))
     '''
     # y_scores: batch_size x num_classes
     # desc: batch_size x num_classes x hid_dim
     # h_c: batch_size x hid_dim
     batch_size, num_classes = y_scores.size()
     y_broadcast = y_scores.unsqueeze(2).expand(batch_size, num_classes,
                                                self.hid_dim)
     debuglogger.debug(f'y_broadcast: {y_broadcast.size()}')
     # debuglogger.debug(f'y_broadcast: {y_broadcast}')
     debuglogger.debug(f'desc: {desc.size()}')
     # Weight descriptions based on current predictions
     desc = torch.mul(y_broadcast, desc).sum(1).squeeze(1)
     debuglogger.debug(f'desc: {desc.size()}')
     # desc: batch_size x hid_dim
     h_w = F.tanh(self.w_h(h_c) + self.w_d(desc))
     w_scores = self.w(h_w)
     if self.use_binary:
         w_probs = F.sigmoid(w_scores)
         if training:
             # debuglogger.info(f"Training...")
             probs_ = w_probs.data.cpu().numpy()
             rand_num = np.random.rand(*probs_.shape)
             # debuglogger.debug(f'rand_num: {rand_num}')
             # debuglogger.info(f'probs: {probs_}')
             w_binary = _Variable(
                 torch.from_numpy((rand_num < probs_).astype('float32')))
         else:
             # debuglogger.info(f"Eval mode, rounding...")
             w_binary = torch.round(w_probs).detach()
         if w_probs.is_cuda:
             w_binary = w_binary.cuda()
         w_feats = w_binary
         # debuglogger.debug(f'w_binary: {w_binary}')
     else:
         w_feats = w_scores
         w_probs = None
     # debuglogger.info(f'Message : {w_feats}')
     return w_feats, w_probs
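The training branch above samples the binary message by moving the probabilities to NumPy and back; with current torch the same Bernoulli draw can be made directly on the tensor's device. A minimal sketch, with illustrative function and tensor names; both branches are detached, as in the original.

import torch

def sample_binary_message(w_scores, training=True):
    w_probs = torch.sigmoid(w_scores)          # per-bit probability of emitting a 1
    if training:
        w_binary = torch.bernoulli(w_probs)    # stochastic bits, stays on w_probs' device
    else:
        w_binary = torch.round(w_probs)        # deterministic bits at evaluation time
    return w_binary.detach(), w_probs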
Example #6
    def fit(self, train_loader, dim=None, test_loader=None):
        '''Train the model on the provided data loader.

        Arguments:
            train_loader (DataLoader): the training data
            dim (int): the target dimensionality
            test_loader (DataLoader): the data for validation
        '''
        # Estimate the mean and the instantaneous covariance from the training data
        self.x_mean, y_mean = _get_mean(train_loader)
        self.cxx, cxy, cyy = _get_covariance(train_loader, self.x_mean, y_mean)
        self.transformer = _Transform(x_mean=self.x_mean, y_mean=self.x_mean)
        # SVD of the (symmetric) covariance yields the principal directions
        u, s, v = _svd(self.cxx)
        if dim is None:
            dim = s.size()[0]
        self.decoder_matrix = u[:, :dim]
        self.encoder_matrix = v.t()[:dim, :]
        # Projection onto the leading `dim` components
        self.score_matrix = _Variable(
            self.decoder_matrix.mm(self.encoder_matrix))
        return self.get_loss(train_loader), self.get_loss(test_loader)
Example #7
    def forward(self, x, m, t, desc, use_message, batch_size, training):
        """
        Update State:
            h_z = message_processor(m, h_z)

        Image processing
            h_i = image_processor(x, h_z)
            Image Attention (https://arxiv.org/pdf/1502.03044.pdf):
                \beta_i = U tanh(W_r h_z + W_x x_i)
                \alpha = 1 / |x|        if t == 0
                \alpha = softmax(\beta) otherwise
                x = \sum_i \alpha x_i

        Combine Image and Message information
            h_c = text_im_combine(h_z, h_i)

        Text processing
            desc_proc = text_processor(desc)

        STOP Bit:
            s_hat = W_s h_c
            s = bernoulli(sig(s_hat)) or round(sig(s_hat))

        Predictions:
            y_i = f_y(h_c, desc_proc_i)

        Generate message:
            m_out = message_generator(y, h_c, desc_proc)
            Communication:
                desc = \sum_i y_i desc_proc_i
                w_hat = tanh(W_h h_c + W_d desc)
                w = bernoulli(sig(w_hat)) or round(sig(w_hat))

        Args:
            x: Image features.
            m: Communication received from the other agent.
            t: (attention) Timestep; used to switch the attention equation on the first iteration.
            desc: List of description vectors used in communication and predictions.
            use_message: Whether to condition on the received message or on a blank placeholder.
            batch_size: Size of the batch.
            training: Whether the agent is training or not.
        Output:
            s, s_probs: A STOP bit and its associated probability, indicating whether the agent has decided to make a selection. The conversation will continue until both agents have selected STOP.
            w, w_probs: A binary message and the probability of each bit in the message being ``1``.
            y: A prediction for each class described in the descriptions.
            r: An estimate of the reward the agent will receive
        """
        debuglogger.debug(f'Input sizes...')
        debuglogger.debug(f'x: {x.size()}')
        debuglogger.debug(f'm: {m.size()}')
        debuglogger.debug(f'm: {m}')
        debuglogger.debug(f'desc: {desc.size()}')

        # Initialize hidden state if necessary
        if self.h_z is None:
            self.h_z = self.initial_state(batch_size)

        # Process message sent from the other agent
        self.h_z = self.message_processor(m, self.h_z, use_message)
        debuglogger.debug(f'h_z: {self.h_z.size()}')

        # Process the image
        h_i = self.image_processor(x, self.h_z, t)
        debuglogger.debug(f'h_i: {h_i.size()}')

        # Combine the image and message info to a single vector
        h_c = self.text_im_combine(torch.cat([self.h_z, h_i], dim=1))
        debuglogger.debug(f'h_c: {h_c.size()}')

        # Process the texts
        # desc: bs x num_classes x desc_dim
        # desc_proc:    bs x num_classes x hid_dim
        desc_proc = self.text_processor(desc)
        debuglogger.debug(f'desc_proc: {desc_proc.size()}')

        # Estimate the reward
        r = self.reward_estimator(h_c)
        debuglogger.debug(f'r: {r.size()}')

        # Calculate stop bits
        s_score = self.s(h_c)
        s_prob = F.sigmoid(s_score)
        debuglogger.debug(f's_score: {s_score.size()}')
        debuglogger.debug(f's_prob: {s_prob.size()}')
        if training:
            # Sample decisions
            prob_ = s_prob.data.cpu().numpy()
            rand_num = np.random.rand(*prob_.shape)
            # debuglogger.debug(f'rand_num: {rand_num}')
            # debuglogger.debug(f'prob: {prob_}')
            s_binary = _Variable(
                torch.from_numpy((rand_num < prob_).astype('float32')))
            if self.use_cuda:
                s_binary = s_binary.cuda()
        else:
            # Infer decisions
            s_binary = torch.round(s_prob).detach()
        debuglogger.debug(f'stop decisions: {s_binary.size()}')
        # debuglogger.debug(f'stop decisions: {s_binary}')

        # Predict classes
        # y: batch_size * num_classes
        y = self.predict_classes(h_c, desc_proc, batch_size)
        y_scores = F.softmax(y, dim=1).detach()
        debuglogger.debug(f'y_scores: {y_scores.size()}')
        # debuglogger.debug(f'y_scores: {y_scores}')

        # Generate message
        w, w_probs = self.message_generator(y_scores, h_c, desc_proc, training)
        debuglogger.debug(f'w: {w.size()}')
        debuglogger.debug(f'w_probs: {w_probs.size()}')

        return (s_binary, s_prob), (w, w_probs), y, r
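Per the docstring, the exchange continues until both agents have emitted a STOP bit. A minimal driver loop under that reading; agent_a, agent_b, their image inputs, the blank initial messages, and max_turns are all hypothetical names, not part of the original module, and each agent keeps its own h_z across turns.

m_a = m_b = _Variable(torch.zeros(batch_size, m_dim))   # blank initial messages
for t in range(max_turns):
    (s_a, _), (m_a, _), y_a, _ = agent_a(x_a, m_b, t, desc, use_message, batch_size, training)
    (s_b, _), (m_b, _), y_b, _ = agent_b(x_b, m_a, t, desc, use_message, batch_size, training)
    # Stop once both agents have raised their STOP bit for every item in the batch
    if (s_a * s_b).data.min() == 1:
        break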
Example #8
 def initial_state(self, batch_size):
     h = _Variable(torch.zeros(batch_size, self.h_dim))
     if self.use_cuda:
         h = h.cuda()
     return h
Example #9
    im_from_scratch = True
    agent = Agent(im_feature_type, im_feat_dim, h_dim, m_dim, desc_dim,
                  num_classes, s_dim, use_binary, use_attn, attn_dim, use_MLP,
                  cuda, im_from_scratch, dropout)
    print(agent)
    total_params = sum([
        functools.reduce(lambda x, y: x * y, p.size(), 1.0)
        for p in agent.parameters()
    ])
    image_proc_params = sum([
        functools.reduce(lambda x, y: x * y, p.size(), 1.0)
        for p in agent.image_processor.parameters()
    ])
    print(
        f'Total params: {total_params}, image proc params: {image_proc_params}'
    )
    x = _Variable(torch.ones(batch_size, 3, im_feat_dim, im_feat_dim))
    m = _Variable(torch.ones(batch_size, m_dim))
    desc = _Variable(torch.ones(batch_size, num_classes, desc_dim))

    for i in range(2):
        s, w, y, r = agent(x, m, i, desc, use_message, batch_size, training)
        # print(f's_binary: {s[0]}')
        # print(f's_probs: {s[1]}')
        # print(f'w_binary: {w[0]}')
        # print(f'w_probs: {w[1]}')
        # print(f's_binary: {s[0]}')
        # print(f's_probs: {s[1]}')
        # print(f'y: {y}')
        # print(f'r: {r}')
Example #10
 def initial_state(self, batch_size):
     h = _Variable(torch.zeros(batch_size, self.h_dim))
     if self.use_cuda:
         h = h.cuda(next(self.parameters()).device.index)
     return h
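On recent PyTorch versions, the Variable wrapper and the explicit .cuda(device_index) call are no longer needed; the hidden state can be allocated directly on whichever device the module's parameters live on. A minimal sketch of that variant (same method name, otherwise illustrative):

def initial_state(self, batch_size):
    # Allocate the hidden state on the same device as the module's parameters.
    device = next(self.parameters()).device
    return torch.zeros(batch_size, self.h_dim, device=device)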