def fit(self, train_loader, dim=None, test_loader=None):
    '''Train the model on the provided data loader.

    Arguments:
        train_loader (DataLoader): the training data
        dim (int): the target dimensionality
        test_loader (DataLoader): the data for validation
    '''
    # Estimate means and (cross-)covariances of the instantaneous (x) and
    # time-lagged (y) data from the training set.
    self.x_mean, self.y_mean = _get_mean(train_loader)
    self.cxx, self.cxy, self.cyy = _get_covariance(
        train_loader, self.x_mean, self.y_mean)
    if self.symmetrize:
        # Enforce the symmetries C_xx = C_yy and C_xy = C_xy^T.
        self.cxx = 0.5 * (self.cxx + self.cyy)
        self.cyy.copy_(self.cxx)
        self.cxy = 0.5 * (self.cxy + self.cxy.t())
    # Build the whitening transform and keep its projection matrices.
    self.transformer = _Transform(
        x_mean=self.x_mean, x_covariance=self.cxx,
        y_mean=self.y_mean, y_covariance=self.cyy)
    self.ixx = self.transformer.x.mul
    self.iyy = self.transformer.y.mul
    # SVD of the whitened cross-covariance yields the encoder/decoder pair.
    u, s, v = _svd(self.ixx.mm(self.cxy.mm(self.iyy)))
    if dim is None:
        dim = s.size()[0]
    self.decoder_matrix = v[:, :dim]
    self.encoder_matrix = u.t()[:dim, :]
    if self.kinetic_map:
        # Scale the latent coordinates by the singular values (kinetic map).
        self.encoder_matrix = _diag(s[:dim]).mm(self.encoder_matrix)
    else:
        self.decoder_matrix = self.decoder_matrix.mm(_diag(s[:dim]))
    self.koopman_matrix = _Variable(
        self.decoder_matrix.mm(self.encoder_matrix))
    return self.get_loss(train_loader), self.get_loss(test_loader)
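# The block below is a minimal, self-contained sketch (not part of this repo)
# of the linear algebra performed in fit() above: whiten C_xx and C_yy, take
# the SVD of the whitened cross-covariance, and split the singular vectors
# into an encoder/decoder pair whose product gives the Koopman matrix. The
# helper name `koopman_from_covariances` and the use of torch.linalg are
# assumptions for illustration; the repo's _svd/_Transform helpers are not
# reproduced here.
import torch

def koopman_from_covariances(cxx, cxy, cyy, dim):
    def inv_sqrt(c):
        # Symmetric inverse square root via eigendecomposition.
        vals, vecs = torch.linalg.eigh(c)
        return vecs @ torch.diag(vals.clamp_min(1e-12).rsqrt()) @ vecs.t()
    ixx, iyy = inv_sqrt(cxx), inv_sqrt(cyy)
    u, s, vh = torch.linalg.svd(ixx @ cxy @ iyy)
    encoder = u.t()[:dim, :]      # whitened x -> latent
    decoder = vh.t()[:, :dim]     # latent -> whitened y
    # Absorbing diag(s) into either factor (the kinetic_map switch above)
    # leaves this product unchanged.
    koopman = decoder @ torch.diag(s[:dim]) @ encoder
    return encoder, decoder, koopman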
def _reparameterize(self, mu, lv):
    # VAE reparameterization trick: sample z = mu + eps * std with
    # std = exp(0.5 * log_var), so gradients flow through mu and lv
    # rather than through the random draw.
    if self.training:
        std = lv.mul(0.5).exp_()
        eps = _Variable(_randn(*std.size()))
        if self.use_cuda:
            eps = eps.cuda()
        return eps.mul(std).add_(mu)
    else:
        return mu
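# Hedged, self-contained illustration (not repo code) of the reparameterization
# used above: z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I). The noise is
# sampled outside the graph, so gradients reach mu and logvar only.
import torch

mu = torch.zeros(4, 8, requires_grad=True)
logvar = torch.zeros(4, 8, requires_grad=True)
eps = torch.randn(4, 8)
z = mu + eps * (0.5 * logvar).exp()   # same formula as _reparameterize
z.sum().backward()                    # gradients flow into mu and logvar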
def forward(self, m, h, use_message):
    if use_message:
        debuglogger.debug('Using message')
        return self.rnn(m, h)
    else:
        debuglogger.debug('Ignoring message, using blank instead...')
        # Feed a zero vector of the same shape so the recurrent update still
        # runs, but carries no information from the other agent.
        blank_msg = _Variable(torch.zeros_like(m.data))
        if self.use_cuda:
            blank_msg = blank_msg.cuda()
        return self.rnn(blank_msg, h)
def __call__(self, x, variable=False, **kwargs):
    # Center the data if a mean (`sub`) is available.
    try:
        x.sub_(self.sub[None, :])
    except AttributeError:
        pass
    # Apply the whitening/projection matrix (`mul`) if available.
    try:
        x = x.mm(self.mul)
    except AttributeError:
        pass
    if variable:
        return _Variable(x, **kwargs)
    return x
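# Minimal sketch (assumption, not the repo's _Transform) of what the call above
# does to a batch: subtract a stored mean in place, then right-multiply by a
# whitening matrix. `mean` and `whiten` are placeholder tensors.
import torch

mean = torch.zeros(3)
whiten = torch.eye(3)
batch = torch.randn(5, 3)
batch.sub_(mean[None, :])    # in-place centering, as in __call__
batch = batch.mm(whiten)     # projection, as in __call__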
def forward(self, y_scores, h_c, desc, training):
    '''
    desc = \sum_i y_scores_i desc_i
    w_hat = tanh(W_h h_c + W_d desc)
    w = bernoulli(sig(w_hat)) or round(sig(w_hat))
    '''
    # y_scores: batch_size x num_classes
    # desc: batch_size x num_classes x hid_dim
    # h_c: batch_size x hid_dim
    batch_size, num_classes = y_scores.size()
    y_broadcast = y_scores.unsqueeze(2).expand(
        batch_size, num_classes, self.hid_dim)
    debuglogger.debug(f'y_broadcast: {y_broadcast.size()}')
    # debuglogger.debug(f'y_broadcast: {y_broadcast}')
    debuglogger.debug(f'desc: {desc.size()}')
    # Weight descriptions based on current predictions
    desc = torch.mul(y_broadcast, desc).sum(1).squeeze(1)
    debuglogger.debug(f'desc: {desc.size()}')
    # desc: batch_size x hid_dim
    h_w = F.tanh(self.w_h(h_c) + self.w_d(desc))
    w_scores = self.w(h_w)
    if self.use_binary:
        w_probs = F.sigmoid(w_scores)
        if training:
            # Sample each message bit from Bernoulli(sigmoid(score)).
            # debuglogger.info(f"Training...")
            probs_ = w_probs.data.cpu().numpy()
            rand_num = np.random.rand(*probs_.shape)
            # debuglogger.debug(f'rand_num: {rand_num}')
            # debuglogger.info(f'probs: {probs_}')
            w_binary = _Variable(
                torch.from_numpy((rand_num < probs_).astype('float32')))
        else:
            # At eval time, round the probabilities deterministically.
            # debuglogger.info(f"Eval mode, rounding...")
            w_binary = torch.round(w_probs).detach()
        if w_probs.is_cuda:
            w_binary = w_binary.cuda()
        w_feats = w_binary
        # debuglogger.debug(f'w_binary: {w_binary}')
    else:
        w_feats = w_scores
        w_probs = None
    # debuglogger.info(f'Message : {w_feats}')
    return w_feats, w_probs
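# Standalone sketch (assumption) of the bit-sampling policy used above: during
# training each bit is drawn as Bernoulli(sigmoid(score)); at evaluation time
# the probabilities are rounded deterministically. CPU tensors only, for brevity.
import numpy as np
import torch

def sample_bits(w_scores, training):
    w_probs = torch.sigmoid(w_scores)
    if training:
        rand = np.random.rand(*w_probs.shape)
        bits = torch.from_numpy(
            (rand < w_probs.detach().numpy()).astype('float32'))
    else:
        bits = torch.round(w_probs).detach()
    return bits, w_probs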
def fit(self, train_loader, dim=None, test_loader=None):
    '''Train the model on the provided data loader.

    Arguments:
        train_loader (DataLoader): the training data
        dim (int): the target dimensionality
        test_loader (DataLoader): the data for validation
    '''
    # Only the statistics of x are used here; the y statistics returned by
    # the helpers are not stored.
    self.x_mean, y_mean = _get_mean(train_loader)
    self.cxx, cxy, cyy = _get_covariance(train_loader, self.x_mean, y_mean)
    self.transformer = _Transform(x_mean=self.x_mean, y_mean=self.x_mean)
    # SVD of the covariance matrix gives the leading components.
    u, s, v = _svd(self.cxx)
    if dim is None:
        dim = s.size()[0]
    self.decoder_matrix = u[:, :dim]
    self.encoder_matrix = v.t()[:dim, :]
    self.score_matrix = _Variable(
        self.decoder_matrix.mm(self.encoder_matrix))
    return self.get_loss(train_loader), self.get_loss(test_loader)
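# Self-contained sketch (assumption) of the rank-`dim` projection implied by
# score_matrix above: SVD of the covariance of centered data, keep the leading
# singular vectors, and use decoder @ encoder as a low-rank map, i.e. plain PCA.
import torch

x = torch.randn(100, 8)
xc = x - x.mean(0, keepdim=True)
u, s, vh = torch.linalg.svd(xc.t().mm(xc) / (x.size(0) - 1))
dim = 3
score_matrix = u[:, :dim].mm(vh[:dim, :])               # decoder @ encoder
x_hat = xc.mm(score_matrix) + x.mean(0, keepdim=True)   # low-rank reconstruction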
def forward(self, x, m, t, desc, use_message, batch_size, training):
    """
    Update state:
        h_z = message_processor(m, h_z)
    Image processing:
        h_i = image_processor(x, h_z)
    Image attention (https://arxiv.org/pdf/1502.03044.pdf):
        \beta_i = U tanh(W_r h_z + W_x x_i)
        \alpha = 1 / |x| if t == 0
        \alpha = softmax(\beta) otherwise
        x = \sum_i \alpha_i x_i
    Combine image and message information:
        h_c = text_im_combine(h_z, h_i)
    Text processing:
        desc_proc = text_processor(desc)
    STOP bit:
        s_hat = W_s h_c
        s = bernoulli(sig(s_hat)) or round(sig(s_hat))
    Predictions:
        y_i = f_y(h_c, desc_proc_i)
    Generate message:
        m_out = message_generator(y, h_c, desc_proc)
    Communication:
        desc = \sum_i y_i t_i
        w_hat = tanh(W_h h_c + W_d t)
        w = bernoulli(sig(w_hat)) or round(sig(w_hat))

    Args:
        x: Image features.
        m: Communication from the other agent.
        t: (attention) Timestep. Used to change the attention equation on the
            first iteration.
        desc: List of description vectors used in communication and predictions.
        batch_size: Size of the batch.
        training: Whether the agent is training or not.
    Output:
        s, s_probs: A STOP bit and its associated probability, indicating
            whether the agent has decided to make a selection. The conversation
            will continue until both agents have selected STOP.
        w, w_probs: A binary message and the probability of each bit in the
            message being ``1``.
        y: A prediction for each class described in the descriptions.
        r: An estimate of the reward the agent will receive.
    """
    debuglogger.debug('Input sizes...')
    debuglogger.debug(f'x: {x.size()}')
    debuglogger.debug(f'm: {m.size()}')
    debuglogger.debug(f'm: {m}')
    debuglogger.debug(f'desc: {desc.size()}')
    # Initialize hidden state if necessary
    if self.h_z is None:
        self.h_z = self.initial_state(batch_size)
    # Process message sent from the other agent
    self.h_z = self.message_processor(m, self.h_z, use_message)
    debuglogger.debug(f'h_z: {self.h_z.size()}')
    # Process the image
    h_i = self.image_processor(x, self.h_z, t)
    debuglogger.debug(f'h_i: {h_i.size()}')
    # Combine the image and message info into a single vector
    h_c = self.text_im_combine(torch.cat([self.h_z, h_i], dim=1))
    debuglogger.debug(f'h_c: {h_c.size()}')
    # Process the texts
    # desc: bs x num_classes x desc_dim
    # desc_proc: bs x num_classes x hid_dim
    desc_proc = self.text_processor(desc)
    debuglogger.debug(f'desc_proc: {desc_proc.size()}')
    # Estimate the reward
    r = self.reward_estimator(h_c)
    debuglogger.debug(f'r: {r.size()}')
    # Calculate stop bits
    s_score = self.s(h_c)
    s_prob = F.sigmoid(s_score)
    debuglogger.debug(f's_score: {s_score.size()}')
    debuglogger.debug(f's_prob: {s_prob.size()}')
    if training:
        # Sample stop decisions from Bernoulli(s_prob)
        prob_ = s_prob.data.cpu().numpy()
        rand_num = np.random.rand(*prob_.shape)
        # debuglogger.debug(f'rand_num: {rand_num}')
        # debuglogger.debug(f'prob: {prob_}')
        s_binary = _Variable(
            torch.from_numpy((rand_num < prob_).astype('float32')))
        if self.use_cuda:
            s_binary = s_binary.cuda()
    else:
        # Infer decisions deterministically
        s_binary = torch.round(s_prob).detach()
    debuglogger.debug(f'stop decisions: {s_binary.size()}')
    # debuglogger.debug(f'stop decisions: {s_binary}')
    # Predict classes
    # y: batch_size x num_classes
    y = self.predict_classes(h_c, desc_proc, batch_size)
    y_scores = F.softmax(y, dim=1).detach()
    debuglogger.debug(f'y_scores: {y_scores.size()}')
    # debuglogger.debug(f'y_scores: {y_scores}')
    # Generate message
    w, w_probs = self.message_generator(y_scores, h_c, desc_proc, training)
    debuglogger.debug(f'w: {w.size()}')
    debuglogger.debug(f'w_probs: {w_probs.size()}')
    return (s_binary, s_prob), (w, w_probs), y, r
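# Sketch (assumption, not the repo's image_processor) of the attention step in
# the docstring above: score each spatial feature against the hidden state,
# use a uniform distribution at t == 0 and a softmax otherwise, then pool.
# W_r, W_x and U are assumed to be nn.Linear modules of compatible sizes.
import torch
import torch.nn.functional as F

def attend(x_feats, h_z, W_r, W_x, U, t):
    # x_feats: (batch, n_locations, feat_dim), h_z: (batch, hid_dim)
    beta = U(torch.tanh(W_r(h_z).unsqueeze(1) + W_x(x_feats))).squeeze(-1)
    if t == 0:
        alpha = torch.full_like(beta, 1.0 / beta.size(1))  # uniform weights
    else:
        alpha = F.softmax(beta, dim=1)
    return (alpha.unsqueeze(-1) * x_feats).sum(1)          # weighted pooling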
def initial_state(self, batch_size):
    h = _Variable(torch.zeros(batch_size, self.h_dim))
    if self.use_cuda:
        h = h.cuda()
    return h
im_from_scratch = True
agent = Agent(im_feature_type, im_feat_dim, h_dim, m_dim, desc_dim,
              num_classes, s_dim, use_binary, use_attn, attn_dim, use_MLP,
              cuda, im_from_scratch, dropout)
print(agent)
total_params = sum([
    functools.reduce(lambda x, y: x * y, p.size(), 1.0)
    for p in agent.parameters()
])
image_proc_params = sum([
    functools.reduce(lambda x, y: x * y, p.size(), 1.0)
    for p in agent.image_processor.parameters()
])
print(
    f'Total params: {total_params}, image proc params: {image_proc_params}'
)
# Dummy inputs for a quick forward-pass check of the agent over two steps.
x = _Variable(torch.ones(batch_size, 3, im_feat_dim, im_feat_dim))
m = _Variable(torch.ones(batch_size, m_dim))
desc = _Variable(torch.ones(batch_size, num_classes, desc_dim))
for i in range(2):
    s, w, y, r = agent(x, m, i, desc, use_message, batch_size, training)
    # print(f's_binary: {s[0]}')
    # print(f's_probs: {s[1]}')
    # print(f'w_binary: {w[0]}')
    # print(f'w_probs: {w[1]}')
    # print(f's_binary: {s[0]}')
    # print(f's_probs: {s[1]}')
    # print(f'y: {y}')
    # print(f'r: {r}')
def initial_state(self, batch_size):
    # Fresh all-zeros hidden state, placed on the same GPU as the model's
    # parameters when CUDA is enabled.
    h = _Variable(torch.zeros(batch_size, self.h_dim))
    if self.use_cuda:
        h = h.cuda(next(self.parameters()).device.index)
    return h