def construct_batch_of_segments_from_one_sample_image(self, sample):
    """
    Build one image per contiguous (left, right) penup span of a drawing.

    See construct_batch_of_segments_from_one_sample_stroke for more details

    Args:
        sample (dict): one data point from DrawingAsImage...Dataset
            contains fp's and n_segments. Only 'post_seg_fp' is read here;
            the whole dict is forwarded to self.ds._construct_rank_image.

    Returns:
        batch: [C, n_segs, H, W] Tensor, passed through nn_utils.move_to_cuda
        n_penups: int, the <end> value parsed from the "<start>-<end>.<ext>" file name
        seg_lens: list of 1's, length n_segs (dummy lengths, not used)
        seg_idx_map: dict mapping (left_idx, right_idx) penup tuples to seg_idx in batch

    Raises:
        ValueError: if the file stem is not of the form "<start>-<end>" with an
            integer <end>.
    """
    # e.g. data/quickdraw/precurrentpost/data/pig/5598031527280640/7-10.jpg
    fn = os.path.basename(sample['post_seg_fp'])
    # splitext drops exactly the extension. str.strip('.jpg') removes any of the
    # *characters* '.', 'j', 'p', 'g' from both ends, which only works by accident
    # for purely numeric stems and breaks for any other extension.
    stem, _ = os.path.splitext(fn)
    _, end = stem.split('-')
    n_penups = int(end)

    batch = []
    seg_idx_map = {}  # maps tuple of (left_idx, right_idx) in terms of penups to seg_idx in batch
    for i in range(n_penups):  # i is left index
        for j in range(i + 1, n_penups + 1):  # j is right index
            # the next free slot in the batch is this segment's index
            seg_idx_map[(i, j)] = len(batch)
            batch.append(self.ds._construct_rank_image(i, j, n_penups, sample))

    seg_lens = [1] * len(batch)  # dummy lengths (not used)

    batch = torch.Tensor(np.stack(batch))  # [n_segs, C, H, W]
    batch = batch.transpose(0, 1)  # [C, n_segs, H, W]
    batch = nn_utils.move_to_cuda(batch)

    return batch, n_penups, seg_lens, seg_idx_map
def one_forward_pass(self, batch):
    """
    Teacher-force the decoder on one batch and compute its reconstruction loss.

    Args:
        batch: tuple from DataLoaders, unpacked as
            (strokes, stroke_lens, cats, cats_idx); strokes is unpacked as
            [max_len, bsz, 5]

    Returns:
        dict where 'loss': float Tensor must exist; 'loss_R' holds the same Tensor
    """
    strokes, stroke_lens, _cats, _cats_idx = batch
    _max_len, bsz, _ = strokes.size()

    # Prepend a start-of-sequence row to every sequence: [max_len + 1, bsz, 5]
    sos = torch.Tensor([[0, 0, 1, 0, 0]] * bsz).unsqueeze(0)
    sos = nn_utils.move_to_cuda(sos)
    decoder_in = torch.cat([sos, strokes], dim=0)

    # Decode (final hidden / cell states are not needed for the loss)
    xy, q, _hidden, _cell = self.dec(decoder_in, stroke_lens=stroke_lens)

    # Targets and loss
    mask, dxdy, p = self.dec.make_target(strokes, stroke_lens)
    recon = self.dec.reconstruction_loss(mask, dxdy, p, xy, q)

    return {'loss': recon, 'loss_R': recon}
def inference_pass(self, vae_zs, cats_idx):
    """
    Generate text from latent vectors by seeding the decoder state with enc(z).

    Args:
        vae_zs: [bsz, z_dim]
        cats_idx: [bsz] LongTensor

    Returns:
        decoded_probs: [bsz, max_len, vocab]
        decoded_ids: [bsz, max_len]
        decoded_texts: list of strs
    """
    bsz = vae_zs.size(0)

    # Encode z once and tile it over decoder layers for both hidden and cell
    z_emb = self.enc(vae_zs).unsqueeze(0)  # [1, bsz, enc_dim]
    n_layers = self.dec.num_layers
    hidden = z_emb.repeat(n_layers, 1, 1)  # [dec_num_layers, bsz, enc_dim]
    cell = z_emb.repeat(n_layers, 1, 1)  # [dec_num_layers, bsz, enc_dim]

    # First decoder input: one SOS id per batch element
    sos_column = torch.LongTensor([SOS_ID] * bsz).unsqueeze(1)  # [bsz, 1]
    init_ids = nn_utils.move_to_cuda(sos_column)
    init_ids.transpose_(0, 1)  # [1, bsz]

    probs, ids, texts = self.dec.generate(
        self.text_embedding,
        category_embedding=self.category_embedding,
        categories=cats_idx,
        init_ids=init_ids,
        hidden=hidden,
        cell=cell,
        pad_id=PAD_ID,
        eos_id=EOS_ID,
        max_len=200,  # TODO: set max_len to max_len on data
        decode_method=self.hp.decode_method,
        tau=self.hp.tau,
        k=self.hp.k,
        idx2token=self.tr_loader.dataset.idx2token)

    return probs, ids, texts
def one_forward_pass(self, batch):
    """
    Return loss and other items of interest for one forward pass

    Args:
        batch: 11-tuple from the loader, unpacked in this order: input_ids,
            target, input_mask, num_sentences_per_school, url, perfrl, perwht,
            share_singleparent, totenrl, share_collegeplus, mail_returnrate.
            input_ids and input_mask are indexed as 3-D tensors; the remaining
            tensors are indexed along their first dimension.

    Returns:
        dict: 'loss': float Tensor must exist. When self.hp.adv_terms is
            non-empty this is whatever compute_loss_adv_for_grad_reversal returns.

    Raises:
        ValueError: if self.hp.model_type is not 'meanbert' or 'robert', or if
            an entry of self.hp.adv_terms does not name one of the confound
            tensors in the batch.
    """
    (input_ids, target, input_mask, num_sentences_per_school, url, perfrl,
     perwht, share_singleparent, totenrl, share_collegeplus,
     mail_returnrate) = batch

    # Sort the batch by descending sentence count and apply the same
    # permutation to every other per-school tensor.
    num_sentences_per_school, perm = torch.sort(num_sentences_per_school,
                                                descending=True)
    num_sentences_per_school = nn_utils.move_to_cuda(num_sentences_per_school)
    input_ids = nn_utils.move_to_cuda(input_ids[perm, :, :])
    input_mask = nn_utils.move_to_cuda(input_mask[perm, :, :])

    def as_column(values):
        # reorder, add a trailing singleton dim, and move to the GPU helper
        return nn_utils.move_to_cuda(values[perm].unsqueeze(1))

    target = as_column(target)
    confounds = {
        'perfrl': as_column(perfrl),
        'perwht': as_column(perwht),
        'share_singleparent': as_column(share_singleparent),
        'totenrl': as_column(totenrl),
        'share_collegeplus': as_column(share_collegeplus),
        'mail_returnrate': as_column(mail_returnrate),
    }

    # Both supported model types are called identically.
    if self.hp.model_type not in ('meanbert', 'robert'):
        raise ValueError(
            'Unsupported model_type: {!r}'.format(self.hp.model_type))
    predicted_target, predicted_confounds = self.model(
        input_ids, num_sentences_per_school,
        attention_mask=input_mask)  # [bsz] (n_outcomes)

    if len(self.hp.adv_terms) > 0:
        # Look the adversarial terms up by name instead of eval()-ing config
        # strings: same result for plain variable names, no code execution.
        unknown = [t for t in self.hp.adv_terms if t not in confounds]
        if unknown:
            raise ValueError(
                'Unknown adv_terms {}; expected names from {}'.format(
                    unknown, sorted(confounds)))
        actual_adv = [confounds[t] for t in self.hp.adv_terms]
        predicted_adv = [
            predicted_confounds[:, i].unsqueeze(1)
            for i in range(predicted_confounds.size(1))
        ]
        losses = self.compute_loss_adv_for_grad_reversal(
            predicted_target, target, predicted_adv, actual_adv)
    else:
        losses = {'loss': self.compute_loss(predicted_target, target)}

    return losses
def construct_batch_of_segments_from_one_sample_stroke(self, strokes):
    """
    Cut a stroke sequence into every contiguous span between penup boundaries.

    Args:
        strokes: [len, 5] Tensor (converted to numpy internally)

    Returns:
        batch: [n_pts (seq_len), n_segs, 5] FloatTensor, zero-padded per segment
        n_penups: int
        seg_lens: list of ints, length n_segs
        seg_idx_map: dict
            Maps penup_idx tuples to seg_idx
            Example with 5 penups
                {(0, 1): 0, (0, 2): 1, (0, 3): 2, (0, 4): 3, (0, 5): 4,
                 (1, 2): 5, (1, 3): 6, (1, 4): 7, (1, 5): 8,
                 (2, 3): 9, (2, 4): 10, (2, 5): 11,
                 (3, 4): 12, (3, 5): 13,
                 (4, 5): 14}
    """
    n_pts = strokes.size(0)
    points = strokes.cpu().numpy()

    # rows where the 4th stroke-5 entry (penup) is set
    penup_rows = np.where(points[:, 3] == 1)[0].tolist()
    n_penups = len(penup_rows)
    n_segs = n_penups * (n_penups + 1) // 2

    padded = np.zeros((n_segs, n_pts, 5))
    seg_lens = []
    seg_idx_map = {}  # (left_idx, right_idx) in terms of penups -> seg_idx in batch

    boundaries = [0] + penup_rows  # row 0 acts as a dummy left boundary
    for left, first_row in enumerate(boundaries[:-1]):
        for right in range(left + 1, len(boundaries)):
            # both boundary rows are included in the segment
            segment = points[first_row:boundaries[right] + 1]
            slot = len(seg_lens)
            padded[slot, :len(segment), :] = segment
            seg_lens.append(len(segment))
            seg_idx_map[(left, right)] = slot

    batch = torch.Tensor(padded).transpose(0, 1)  # [n_pts, n_segs, 5]
    batch = nn_utils.move_to_cuda(batch)

    return batch, n_penups, seg_lens, seg_idx_map
def forward(self, strokes, stroke_lens=None, hidden_cell=None, output_all=True):
    """
    Run the LSTM over the inputs and project each step to (dx, dy) plus pen-state probabilities.

    Args:
        strokes: [max_len + 1, bsz, input_dim] (+ 1 for sos)
        stroke_lens: list of ints (unused here; sequences are not packed)
        hidden_cell: tuple of [n_layers, bsz, dec_dim]; zero states are created
            when None
        output_all: bool (unused... for compatability with SketchRNNDecoderGMM)

    Returns:
        xy: [len, bsz, 2], first two output features
        q: [len, bsz, 3], softmax over the remaining three output features
            (pen states as a categorical distribution)
        hidden: [n_layers, bsz, dim]
        cell: [n_layers, bsz, dim]
    """
    bsz = strokes.size(1)

    if hidden_cell is None:
        # zero initial state, one tensor each for hidden and cell
        n_layers = self.lstm.num_layers
        zero_hidden = torch.zeros(n_layers, bsz, self.dec_dim)
        zero_cell = torch.zeros(n_layers, bsz, self.dec_dim)
        hidden_cell = (nn_utils.move_to_cuda(zero_hidden),
                       nn_utils.move_to_cuda(zero_cell))

    # The whole padded batch is fed as-is (no pack_padded_sequence), so the
    # output keeps the same number of time steps as the input.
    lstm_out, (hidden, cell) = self.lstm(strokes, hidden_cell)

    projected = self.fc_out_xypen(lstm_out)  # [len, bsz, 5]
    xy = projected[:, :, :2]  # [len, bsz, 2]
    pen_logits = projected[:, :, 2:]  # [len, bsz, 3]
    q = F.softmax(pen_logits, dim=-1)

    return xy, q, hidden, cell
def one_forward_pass(self, batch):
    """
    Return loss and other items of interest for one forward pass

    The latent (z, mu, sigma_hat) comes from self.retrieve(...) rather than
    from encoding the strokes directly.

    Args:
        batch: tuple from loader, unpacked as
            (strokes, stroke_lens, cats, cats_idx); strokes is unpacked as
            [max_len, bsz, 5]

    Returns:
        dict where 'loss': float Tensor must exist; also 'loss_KL' and 'loss_R'
    """
    strokes, stroke_lens, cats, cats_idx = batch
    max_len, bsz, _ = strokes.size()

    # Encode (via retrieval)
    z, mu, sigma_hat = self.retrieve(strokes, cats, cats_idx)  # [bsz, enc_dim]

    # Create inputs to decoder
    sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # start of sequence
    sos = nn_utils.move_to_cuda(sos)
    inputs_init = torch.cat([sos, strokes], 0)  # [max_len + 1, bsz, 5]
    z_stack = torch.stack([z] * (max_len + 1), dim=0)  # [max_len + 1, bsz, z_dim]
    dec_inputs = torch.cat([inputs_init, z_stack], 2)  # [max_len + 1, bsz, z_dim + 5]

    # init hidden and cell states is tanh(fc(z)) (Page 3)
    hidden, cell = torch.split(torch.tanh(self.fc_z_to_hc(z)), self.hp.dec_dim, 1)
    hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
    # TODO: if we want multiple layers, we need to replicate hidden and cell n_layers times

    # Decode
    _, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
        dec_inputs, output_all=True, hidden_cell=hidden_cell)

    # Calculate losses
    mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
    kl_out = self.kullback_leibler_loss(sigma_hat, mu, self.hp.KL_min,
                                        self.hp.wKL, self.eta_step)
    # kullback_leibler_loss may return a single Tensor or a
    # (LKL_final, LKL_thresh, LKL) tuple; adding a tuple to loss_R would fail.
    # The first entry is the weighted, thresholded term used for training.
    loss_KL = kl_out[0] if isinstance(kl_out, (tuple, list)) else kl_out
    loss_R = self.dec.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y,
                                          sigma_x, sigma_y, rho_xy, q)
    loss = loss_KL + loss_R
    result = {'loss': loss, 'loss_KL': loss_KL, 'loss_R': loss_R}

    return result
def create_transformer_padding_masks(src_lens=None, tgt_lens=None):
    """
    Build boolean key-padding masks where True marks a position to ignore.

    Used to handle variable length sequences within a batch.

    Args:
        src_lens: list of length bsz, or None to skip the src/memory masks
        tgt_lens: list of length bsz, or None to skip the tgt mask

    Returns:
        src_key_padding_mask: [bsz, max_src_len] bool Tensor (or None)
        tgt_key_padding_mask: [bsz, max_tgt_len] bool Tensor (or None)
        memory_key_padding_mask: [bsz, max_src_len] bool Tensor (or None);
            built from the same values as the src mask
    """

    def build_mask(lens):
        # True everywhere at or beyond each sequence's length
        padding = torch.zeros(len(lens), max(lens)).bool()
        for row, seq_len in enumerate(lens):
            padding[row, seq_len:] = 1
        return padding

    src_key_padding_mask = None
    tgt_key_padding_mask = None
    memory_key_padding_mask = None

    # Src and memory masks share the same source tensor
    if src_lens is not None:
        src_padding = build_mask(src_lens)
        src_key_padding_mask = nn_utils.move_to_cuda(src_padding)
        memory_key_padding_mask = nn_utils.move_to_cuda(src_padding)

    # Tgt mask
    if tgt_lens is not None:
        tgt_key_padding_mask = nn_utils.move_to_cuda(build_mask(tgt_lens))

    return src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask
def preprocess_batch_from_data_loader(self, batch):
    """
    Convert tensors to cuda and convert to [len, bsz, ...] instead of [bsz, len, ...]

    Args:
        batch: iterable of items from the data loader. Tensor items (including
            Tensor subclasses) are passed through nn_utils.move_to_cuda and, if
            they have more than one dimension, get their first two dimensions
            swapped. All other items are passed through unchanged.

    Returns:
        list with one entry per item in batch, in the same order
    """
    preprocessed = []
    for item in batch:
        # isinstance (not an exact type match) so Tensor subclasses are handled too
        if isinstance(item, torch.Tensor):
            item = nn_utils.move_to_cuda(item)
            if item.dim() > 1:
                # NOTE(review): in-place; if move_to_cuda returns its input
                # unchanged, the loader's tensor is modified as well - confirm
                item.transpose_(0, 1)
        preprocessed.append(item)
    return preprocessed
def one_forward_pass(self, batch):
    """
    Encode strokes to z, decode with teacher forcing, and return the VAE losses.

    Args:
        batch: tuple from loader, unpacked as
            (strokes, stroke_lens, cats, cats_idx); strokes is unpacked as
            [max_len, bsz, 5]

    Returns:
        dict where 'loss': float Tensor must exist. Also contains 'loss_R',
        'loss_KL_final', 'loss_KL_thresh', 'loss_KL' and 'loss_KL_eta'
        (self.eta_step wrapped in a Tensor, for logging).
    """
    strokes, stroke_lens, cats, cats_idx = batch
    max_len, bsz, _ = strokes.size()
    n_steps = max_len + 1  # + 1 for the start-of-sequence row

    # Encode
    z, mu, sigma_hat = self.enc(strokes)  # each [bsz, z_dim]

    # Decoder inputs: [sos; strokes] concatenated with z at every step
    sos = torch.Tensor([[0, 0, 1, 0, 0]] * bsz).unsqueeze(0)  # start of sequence
    sos = nn_utils.move_to_cuda(sos)
    stroke_inputs = torch.cat([sos, strokes], 0)  # [max_len + 1, bsz, 5]
    z_per_step = torch.stack([z] * n_steps, dim=0)  # [max_len + 1, bsz, z_dim]
    dec_inputs = torch.cat([stroke_inputs, z_per_step], 2)  # [max_len + 1, bsz, 5 + z_dim]

    if self.hp.use_categories_dec:
        cat_embs = self.category_embedding(cats_idx)  # [bsz, cat_dim]
        cat_embs = cat_embs.repeat(dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, cat_dim]
        dec_inputs = torch.cat([dec_inputs, cat_embs], dim=2)  # [max_len + 1, bsz, 5 + z_dim + cat_dim]

    # init hidden and cell states is tanh(fc(z)) (Page 3)
    init_state = torch.tanh(self.fc_z_to_hc(z))
    hidden, cell = torch.split(init_state, self.hp.dec_dim, 1)
    hidden_cell = (hidden.unsqueeze(0).contiguous(),
                   cell.unsqueeze(0).contiguous())
    # TODO: if we want multiple layers, we need to replicate hidden and cell n_layers times

    # Decode
    _, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
        dec_inputs, output_all=True, hidden_cell=hidden_cell)

    # Calculate losses
    mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
    loss_KL_final, loss_KL_thresh, loss_KL = self.kullback_leibler_loss(
        sigma_hat, mu, self.hp.KL_min, self.hp.wKL, self.eta_step)
    loss_R = self.dec.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y,
                                          sigma_x, sigma_y, rho_xy, q)

    return {
        'loss': loss_KL_final + loss_R,
        'loss_R': loss_R,
        'loss_KL_final': loss_KL_final,
        'loss_KL_thresh': loss_KL_thresh,
        'loss_KL': loss_KL,
        # "loss" for logging purposes; convert to Tensor because .item() is called on each value
        'loss_KL_eta': torch.Tensor([self.eta_step]),
    }
def make_target(self, strokes, stroke_lens, M):
    """
    Append an end-of-sequence row and split stroke-5 data into loss targets.

    stroke_lens is used to create the mask for each sequence.

    Args:
        strokes: [max_len, bsz, 5]
        stroke_lens: list of ints
        M: int, number of mixtures

    Returns:
        mask: [max_len + 1, bsz], 1 for the first stroke_lens[i] steps of sequence i
        dx: [max_len + 1, bsz, num_mixtures]
        dy: [max_len + 1, bsz, num_mixtures]
        p: [max_len + 1, bsz, 3]
    """
    max_len, bsz, _ = strokes.size()

    # add eos
    eos = torch.Tensor([[0, 0, 0, 0, 1]] * bsz).unsqueeze(0)  # [1, bsz, 5]
    eos = nn_utils.move_to_cuda(eos)
    strokes = torch.cat([strokes, eos], 0)  # [max_len + 1, bsz, 5]

    # per-sequence validity mask from stroke_lens
    mask = torch.zeros(max_len + 1, bsz)
    for col, n_valid in enumerate(stroke_lens):
        mask[:n_valid, col] = 1
    mask = nn_utils.move_to_cuda(mask).detach()

    values = strokes.data
    # offsets are replicated once per mixture component
    dx = torch.stack([values[:, :, 0]] * M, 2).detach()
    dy = torch.stack([values[:, :, 1]] * M, 2).detach()
    # the three pen-state channels
    pen_channels = [values[:, :, k].detach() for k in (2, 3, 4)]
    p = torch.stack(pen_channels, 2)

    return mask, dx, dy, p
def generate_square_subsequent_mask(size):
    """
    Build a [size, size] float mask that blocks attention to future positions.

    Entries strictly above the diagonal are float('-inf'); the diagonal and
    everything below it are 0.0.
    """
    # triu(..., diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest
    blocked = torch.full((size, size), float('-inf'), dtype=torch.float)
    mask = torch.triu(blocked, diagonal=1)
    return nn_utils.move_to_cuda(mask)
def preprocess_batch_from_data_loader(self, batch):
    """
    Make strokes time-major floats, convert lengths to a list, and move tensors to cuda.

    Args:
        batch: tuple of
            strokes: [bsz, max_len, 5] Tensor
            stroke_lens: Tensor of lengths (converted with .numpy().tolist())
            cats: list of strs (categories)
            cats_idx: [bsz] LongTensor

    Returns:
        strokes: [max_len, bsz, 5] float Tensor
        stroke_lens: list of ints
        cats: list of strs (unchanged)
        cats_idx: [bsz]
    """
    strokes, stroke_lens, cats, cats_idx = batch

    time_major = strokes.transpose(0, 1).float()
    time_major = nn_utils.move_to_cuda(time_major)
    lens_as_list = stroke_lens.numpy().tolist()
    cats_idx_moved = nn_utils.move_to_cuda(cats_idx)

    return time_major, lens_as_list, cats, cats_idx_moved
def __init__(self, hp, save_dir, skip_data=False):
    """
    Set up the parent model, then load the retrieval set and the query network.

    Every item of the train/valid/test ProgressionPairDataset splits is read;
    its text is grouped by category and its stroke-5 array is written into one
    column of self.retrieval_vals.

    Args:
        hp: hyperparameters; reads fixed_mem, z_dim and lr here
        save_dir: forwarded to the parent constructor
        skip_data: forwarded to the parent constructor
    """
    super().__init__(hp, save_dir, skip_data=skip_data)

    # 1. Load retrieval set
    self.idx2cat, self.cat2idx = build_category_index_nodata()
    self.cat_to_retrieval_text = defaultdict(list)
    self.n_ret_per_cat = 251
    # [time, category block * n_ret_per_cat, 5]; 201 and 35 are hard-coded sizes
    self.retrieval_vals = np.zeros((201, 35 * self.n_ret_per_cat, 5))

    longest = 0
    for split in ['train', 'valid', 'test']:
        ds = ProgressionPairDataset(split)
        for i in range(len(ds)):
            item = ds[i]
            stroke5, text, category = item[0], item[2], item[4]
            cat_idx = self.cat2idx[category]
            longest = max(stroke5.shape[0], longest)

            texts_for_cat = self.cat_to_retrieval_text[category]
            texts_for_cat.append(text)
            # The offset within the category block is the count *after* the
            # append, i.e. it starts at 1.
            # NOTE(review): offset 0 of each block therefore stays zero - confirm intended
            column = cat_idx * self.n_ret_per_cat + len(texts_for_cat)
            self.retrieval_vals[:len(stroke5), column, :] = stroke5

    # Trim unused time steps, then convert to a tensor
    self.retrieval_vals = self.retrieval_vals[:longest, :, :]
    self.retrieval_vals = nn_utils.move_to_cuda(
        torch.Tensor(self.retrieval_vals))
    self.retrieval_vals.requires_grad = not self.hp.fixed_mem

    # Query network: (z_dim + 1) features -> one score per retrieval slot
    self.query_ff = nn.Sequential(
        nn.Linear(self.hp.z_dim + 1, self.hp.z_dim),
        nn.ReLU(),
        nn.Linear(self.hp.z_dim, self.n_ret_per_cat))
    self.query_ff.cuda()
    self.optimizers.append(optim.Adam(self.query_ff.parameters(), hp.lr))
def _calc_instruction_to_strokes_score(self, batch_of_segs, seg_lens, texts, cats_idx):
    """
    P(S|I). Note that it's the prob, not the loss (NLL) returned by the model.

    Args:
        batch_of_segs: [n_pts (seq_len), n_segs, 5] CudaFloatTensor
        seg_lens: list of ints, length n_segs
        texts (list): n_segs list of strings
        cats_idx: list of the same int, length n_segs

    Returns:
        scores: (n_segs) np array, exp(-loss) of the per-segment losses
    """
    token_ids = [map_sentence_to_index(text, self.token2idx) for text in texts]

    # Construct inputs to instruction_to_strokes model
    bsz = batch_of_segs.size(1)
    text_lens = [len(ids) for ids in token_ids]
    # zero-padded [max_text_len, bsz] id matrix, one column per text
    padded_ids = np.zeros((max(text_lens), bsz))
    for col, ids in enumerate(token_ids):
        padded_ids[:len(ids), col] = ids
    text_indices = nn_utils.move_to_cuda(torch.LongTensor(padded_ids))

    cats = [''] * bsz  # dummy
    urls = [''] * bsz  # dummy
    batch = (batch_of_segs, seg_lens, texts, text_lens, text_indices, cats,
             cats_idx, urls)

    with torch.no_grad():
        result = self.instruction_to_strokes.one_forward_pass(
            batch, average_loss=False)  # [n_segs]?

    # float32 doesn't serialize to json for some reason
    losses = result['loss'].cpu().numpy().astype(np.float64)
    return np.exp(-losses)  # map losses (NLL) to probs
def compute_unlikelihood_loss(self, logits, text_lens):
    """
    Penalise the model's running vocabulary distribution for drifting from self.vocab_prob.

    self.model_vocab_prob is a sliding window with one row per recent batch.
    Each call drops the oldest row, appends this batch's masked mean token
    distribution, and (once the window is full) returns a loss built from
    mean_window * log(mean_window / self.vocab_prob) and log(1 - probs).

    Args:
        logits: [len, bsz, vocab]
        text_lens: sequence of ints, length bsz; only the first text_lens[i]
            time steps of sequence i contribute

    Returns:
        scalar Tensor. Zero until the oldest window row is non-zero, i.e. until
        enough batches have been seen.
    """
    loss = torch.tensor(0.)

    # Shift the window by one row. Detach so gradients only flow through the
    # current batch's row, not the whole history.
    updated = self.model_vocab_prob.clone().detach()
    updated[:updated.size(0) - 1, :] = self.model_vocab_prob[1:, :].detach()
    self.model_vocab_prob = updated

    # compute models' vocab prob in current batch
    probs = F.softmax(logits, dim=-1)  # [len, bsz, vocab]
    logits_len, bsz, _ = logits.size()
    # Allocate on the same device as logits, since the mask is multiplied with probs.
    mask = torch.zeros(logits_len, bsz, device=logits.device)  # [len, bsz]
    for i in range(bsz):
        # Index the batch column too; without it every sequence was unmasked
        # up to the longest text in the batch.
        mask[:text_lens[i], i] = 1
    mask = mask.unsqueeze(-1)  # [len, bsz, 1], broadcasts over vocab
    batch_model_prob = (probs * mask).mean(dim=0).mean(dim=0)  # [vocab]
    self.model_vocab_prob[-1] = batch_model_prob

    # Only compute after having seen n batches
    if self.model_vocab_prob[0].sum() == 0:
        return loss

    cur_model_vocab = self.model_vocab_prob.mean(dim=0)  # [vocab]
    mismatch = cur_model_vocab * torch.log(
        cur_model_vocab / self.vocab_prob.detach())
    unlikelihood = torch.log(1 - probs) * mask  # [len, bsz, vocab]
    loss = -(mismatch * unlikelihood).mean()
    loss *= 500  # mixing parameter

    return loss
def one_forward_pass(self, batch):
    """
    Teacher-force the GMM decoder on one batch and return its reconstruction loss.

    Args:
        batch: tuple from DataLoaders, unpacked as
            (strokes, stroke_lens, cats, cats_idx); strokes is unpacked as
            [max_len, bsz, 5]

    Returns:
        dict where 'loss': float Tensor must exist; 'loss_R' holds the same Tensor

    Raises:
        Exception: if the loss contains NaN or +/- infinity
    """
    strokes, stroke_lens, cats, cats_idx = batch
    _max_len, bsz, _ = strokes.size()

    # Create inputs to decoder: prepend start of sequence
    sos = torch.Tensor([[0, 0, 1, 0, 0]] * bsz).unsqueeze(0)
    sos = nn_utils.move_to_cuda(sos)
    dec_inputs = torch.cat([sos, strokes], 0)  # [max_len + 1, bsz, 5]

    if self.hp.use_categories_dec:
        cat_embs = self.category_embedding(cats_idx)  # [bsz, cat_dim]
        cat_embs = cat_embs.repeat(dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, cat_dim]
        dec_inputs = torch.cat([dec_inputs, cat_embs], dim=2)  # [max_len + 1, bsz, 5 + cat_dim]

    # Decode
    _, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
        dec_inputs, output_all=True)

    # Calculate losses
    mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
    loss = self.dec.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y,
                                        sigma_x, sigma_y, rho_xy, q)

    # Fail fast on NaN / inf / -inf so a diverged run does not keep training
    if not torch.isfinite(loss).all():
        raise Exception('Nan in SketchRNnDecoderGMMOnly forward pass')

    return {'loss': loss, 'loss_R': loss}
def kullback_leibler_loss(self, sigma_hat, mu, KL_min, wKL, eta_step):
    """
    Calculate KL loss -- (eq. 10, 11)

    Args:
        sigma_hat: [bsz, z_dim]
        mu: [bsz, z_dim]
        KL_min: float, lower bound applied to the raw KL term
        wKL: float, weight on the KL term
        eta_step: float, additional multiplier on the KL term

    Returns:
        LKL_final: wKL * eta_step * LKL_thresh
        LKL_thresh: elementwise max of LKL and KL_min
        LKL: raw KL term, averaged over bsz * z_dim
    """
    bsz, z_dim = sigma_hat.size()

    per_element = 1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat)
    LKL = -0.5 * torch.sum(per_element) / float(bsz * z_dim)

    # constant floor for the KL term (no gradient)
    floor = nn_utils.move_to_cuda(torch.Tensor([KL_min])).detach()
    LKL_thresh = torch.max(LKL, floor)

    LKL_final = wKL * eta_step * LKL_thresh
    return LKL_final, LKL_thresh, LKL
def __init__(self, hp, save_dir=None):
    """
    Build data loaders, embeddings, encoder/decoder, optional losses and the optimizer.

    Args:
        hp: hyperparameters. hp.enc_dim is overwritten with hp.dim when it is -1.
        save_dir: forwarded to the parent constructor

    NOTE(review): block nesting below was reconstructed from a whitespace-collapsed
    source - confirm against the original file.
    """
    super().__init__(hp, save_dir)

    self.tr_loader = self.get_data_loader('train', shuffle=True)
    self.val_loader = self.get_data_loader('valid', shuffle=False)
    self.end_epoch_loader = self.val_loader

    #
    # Model
    #
    # -1 means "use hp.dim" for the encoder dimension
    hp.enc_dim = hp.dim if (hp.enc_dim == -1) else hp.enc_dim
    if hp.use_layer_norm:
        # layer norm enc lstm is 1 layer, so h and c get split up
        # to initialize decoder
        assert hp.enc_dim == hp.dim * hp.n_dec_layers

    self.token_embedding = nn.Embedding(self.tr_loader.dataset.vocab_size, hp.dim)
    self.models.append(self.token_embedding)
    self.category_embedding = None
    if (self.hp.use_categories_enc) or (hp.use_categories_dec):
        # 35 is the hard-coded number of embedding rows (categories)
        self.category_embedding = nn.Embedding(35, hp.dim)
        self.models.append(self.category_embedding)
    if self.hp.rank_imgs_text:
        # bilinear map from two hp.dim-sized inputs to a single output
        self.rank_bilin_mod = torch.nn.Bilinear(hp.dim, hp.dim, 1)
        self.models.append(self.rank_bilin_mod)

    # Encoder decoder
    if hp.model_type.endswith('lstm'):
        if hp.drawing_type == 'image':
            # one channel per comma-separated entry of hp.images
            self.n_channels = len(hp.images.split(','))
            self.enc = StrokeAsImageEncoderCNN(hp.cnn_type, self.n_channels, hp.dim)
        else:  # drawing_type is stroke
            # encoders may be different
            if hp.model_type == 'cnn_lstm':
                self.enc = StrokeEncoderCNN(
                    n_feat_maps=hp.dim, input_dim=5, emb_dim=hp.enc_dim, dropout=hp.dropout,
                    use_categories=hp.use_categories_enc)
                # raise NotImplementedError('use_categories_enc=true not implemented for CNN encoder')
            elif hp.model_type == 'transformer_lstm':
                self.enc = StrokeEncoderTransformer(
                    5, hp.enc_dim, num_layers=hp.n_enc_layers, dropout=hp.dropout,
                    use_categories=hp.use_categories_enc,
                )
            elif hp.model_type == 'lstm':
                self.enc = StrokeEncoderLSTM(
                    5, hp.enc_dim, num_layers=hp.n_enc_layers, dropout=hp.dropout, batch_first=False,
                    use_categories=hp.use_categories_enc, use_layer_norm=hp.use_layer_norm,
                    rec_dropout=hp.rec_dropout)

        if hp.use_mem:
            self.mem = SketchMem(base_mem_size=hp.base_mem_size,
                                 category_mem_size=hp.category_mem_size,
                                 mem_dim=hp.mem_dim,
                                 input_dim=hp.dim,
                                 output_dim=hp.dim)
            self.models.append(self.mem)

        # decoder is lstm
        # decoder input grows by hp.dim for each enabled conditioning signal
        dec_input_dim = hp.dim
        if hp.condition_on_hc:
            dec_input_dim += hp.dim
        if hp.use_categories_dec:
            dec_input_dim += hp.dim
        self.dec = InstructionDecoderLSTM(
            dec_input_dim, hp.dim, num_layers=hp.n_dec_layers, dropout=hp.dropout, batch_first=False,
            condition_on_hc=hp.condition_on_hc, use_categories=hp.use_categories_dec,
            use_layer_norm=hp.use_layer_norm, rec_dropout=hp.rec_dropout)

        self.models.extend([self.enc, self.dec])
    elif hp.model_type == 'transformer':
        if hp.use_categories_enc or hp.use_categories_dec:
            raise NotImplementedError(
                'Use categories not implemented for Transformer')

        # project 5-dim stroke points to the model dimension
        self.strokes_input_fc = nn.Linear(5, hp.dim)
        self.pos_enc = PositionalEncoder(hp.dim, max_seq_len=250)
        self.transformer = nn.Transformer(
            d_model=hp.dim, dim_feedforward=hp.dim * 4, nhead=2, activation='relu',
            num_encoder_layers=hp.n_enc_layers, num_decoder_layers=hp.n_dec_layers,
            dropout=hp.dropout,
        )
        # Xavier-initialise every parameter with more than one dimension
        for p in self.transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.models.extend(
            [self.strokes_input_fc, self.pos_enc, self.transformer])

    for model in self.models:
        model.cuda()

    # Additional loss
    if hp.unlikelihood_loss:
        # assert hp.model_type == 'cnn_lstm'

        # load true vocab distribution
        token2idx = self.tr_loader.dataset.token2idx
        vocab_prob = utils.load_file(INSTRUCTIONS_VOCAB_DISTRIBUTION_PATH)
        self.vocab_prob = torch.zeros(len(token2idx)).fill_(
            1e-6)  # fill with eps
        for token, prob in vocab_prob.items():
            try:
                idx = token2idx[token]
                self.vocab_prob[idx] = prob
            except KeyError as e:
                # token from the distribution file is missing from the
                # training vocab (e.g. 'lion'); print it and keep the eps value
                print(e)
                continue
        self.vocab_prob = nn_utils.move_to_cuda(self.vocab_prob)  # [vocab]

        # create running of vocab distribution
        n_past = 256  # number of minibatches to store
        self.model_vocab_prob = torch.zeros(n_past, len(token2idx))  # [n, vocab]
        self.model_vocab_prob = nn_utils.move_to_cuda(
            self.model_vocab_prob)

    # Optimizers
    self.optimizers.append(optim.Adam(self.parameters(), hp.lr))

    self.scorers = [InstructionScorer('rouge')]
def inference_pass(self, strokes, stroke_lens, cats_idx):
    """
    Encode a batch of drawings and decode instruction text for each.

    Args:
        strokes: [len, bsz, 5]; when hp.drawing_type == 'image' this is actually
            a batch of images [C, B, H, W]
        stroke_lens: list of ints; for images this slot carries the
            (rank_imgs, rank_imgs_pref) tuple, which is not used at inference
        cats_idx: [bsz] LongTensor

    Returns:
        decoded_probs: [bsz, max_len, vocab]
        decoded_ids: [bsz, max_len]
        decoded_texts: list of strs
    """
    bsz = strokes.size(1)

    def make_init_ids():
        # one SOS id per batch element, shaped [1, bsz]
        ids = nn_utils.move_to_cuda(
            torch.LongTensor([SOS_ID] * bsz).unsqueeze(1))  # [bsz, 1]
        ids.transpose_(0, 1)
        return ids

    def state_from_embedding(embedded):
        # optionally add the memory read-out, then tile over decoder layers
        if self.hp.use_mem:
            embedded = embedded + self.mem(embedded, cats_idx)  # [bsz, mem_dim]
        embedded = embedded.unsqueeze(0)  # [1, bsz, dim]
        hidden_state = embedded.repeat(self.dec.num_layers, 1, 1)  # [n_layers, bsz, dim]
        cell_state = embedded.repeat(self.dec.num_layers, 1, 1)  # [n_layers, bsz, dim]
        return hidden_state, cell_state

    model_type = self.hp.model_type
    if model_type in ['cnn_lstm', 'transformer_lstm', 'lstm']:
        mem_emb = None

        if self.hp.drawing_type == 'image':
            # Rank images are only an auxiliary training loss; inference
            # just embeds the images.
            hidden, cell = state_from_embedding(self.enc(strokes))  # enc -> [bsz, dim]
        elif model_type == 'cnn_lstm':
            embedded = self.enc(strokes, stroke_lens,
                                category_embedding=self.category_embedding,
                                categories=cats_idx)  # [bsz, dim]
            hidden, cell = state_from_embedding(embedded)
        elif model_type == 'transformer_lstm':
            hidden = self.enc(strokes, stroke_lens,
                              category_embedding=self.category_embedding,
                              categories=cats_idx)  # [bsz, dim]
            hidden = hidden.unsqueeze(0)  # [1, bsz, dim]
            hidden = hidden.repeat(self.dec.num_layers, 1, 1)  # [n_layers, bsz, dim]
            cell = hidden.clone()  # [n_layers, bsz, dim]
        elif model_type == 'lstm':
            # the encoder returns (outputs, (hidden, cell)); only the state is used
            _, (hidden, cell) = self.enc(
                strokes, stroke_lens,
                category_embedding=self.category_embedding,
                categories=cats_idx,
                split_hc=self.hp.n_dec_layers)

        # Create init input
        init_ids = make_init_ids()  # [1, bsz]

        decoded_probs, decoded_ids, decoded_texts = self.dec.generate(
            self.token_embedding,
            category_embedding=self.category_embedding,
            categories=cats_idx,
            mem_emb=mem_emb,
            init_ids=init_ids,
            hidden=hidden,
            cell=cell,
            pad_id=PAD_ID,
            eos_id=EOS_ID,
            max_len=25,
            decode_method=self.hp.decode_method,
            tau=self.hp.tau,
            k=self.hp.k,
            idx2token=self.tr_loader.dataset.idx2token,
        )

    elif model_type == 'transformer':
        strokes_emb = self.strokes_input_fc(strokes)  # [max_stroke_len, bsz, dim]
        src_input_embs = scale_add_pos_emb(strokes_emb, self.pos_enc)  # [max_stroke_len, bsz, dim]

        init_ids = make_init_ids()  # [1, bsz]
        # computed as in the original code path; not passed on below
        init_embs = self.token_embedding(init_ids)  # [1, bsz, dim]

        decoded_probs, decoded_ids, decoded_texts = transformer_generate(
            self.transformer,
            self.token_embedding,
            self.pos_enc,
            src_input_embs=src_input_embs,
            input_lens=stroke_lens,
            init_ids=init_ids,
            pad_id=PAD_ID,
            eos_id=EOS_ID,
            max_len=100,
            decode_method=self.hp.decode_method,
            tau=self.hp.tau,
            k=self.hp.k,
            idx2token=self.tr_loader.dataset.idx2token)

    return decoded_probs, decoded_ids, decoded_texts
def one_forward_pass_imagecnn_lstm(self, batch):
    """
    Forward pass for the image-CNN encoder + LSTM decoder, with optional auxiliary losses.

    Args:
        batch: tuple unpacked as (imgs, (rank_imgs, rank_imgs_pref), texts,
            text_lens, text_indices_w_sos_eos, cats, cats_idx, urls);
            imgs is unpacked as [C, bsz, H, W]

    Returns:
        dict with 'loss' (used for backward), 'loss_decode' (detached copy of the
        decoding loss) and, when enabled, 'loss_unlikelihood' and
        'loss_rank_imgs_text' (detached, for logging)
    """
    (imgs, (rank_imgs, rank_imgs_pref), texts, text_lens,
     text_indices_w_sos_eos, cats, cats_idx, urls) = batch

    # Encode images; optionally add the memory read-out
    embedded = self.enc(imgs)  # [bsz, dim]
    if self.hp.use_mem:
        embedded = embedded + self.mem(embedded, cats_idx)  # [bsz, mem_dim]
    mem_emb = None

    # Same embedding initialises hidden and cell of every decoder layer
    embedded = embedded.unsqueeze(0)  # [1, bsz, dim]
    n_layers = self.dec.num_layers
    hidden = embedded.repeat(n_layers, 1, 1)  # [n_layers, bsz, dim]
    cell = embedded.repeat(n_layers, 1, 1)  # [n_layers, bsz, dim]

    # Decode
    texts_emb = self.token_embedding(text_indices_w_sos_eos)  # [max_text_len + 2, bsz, dim]
    logits, texts_hidden = self.dec(
        texts_emb,
        text_lens,
        hidden=hidden,
        cell=cell,
        token_embedding=self.token_embedding,
        category_embedding=self.category_embedding,
        categories=cats_idx,
        mem_emb=mem_emb)  # logits: [max_text_len + 2, bsz, vocab]; h/c

    loss = self.compute_loss(logits, text_indices_w_sos_eos, PAD_ID)
    result = {'loss': loss, 'loss_decode': loss.clone().detach()}

    if self.hp.unlikelihood_loss:
        loss_UL = self.compute_unlikelihood_loss(logits, text_lens)
        result['loss'] += loss_UL  # for backward
        result['loss_unlikelihood'] = loss_UL.clone().detach()  # for logging

    if self.hp.rank_imgs_text:
        # The nested tuple is not moved by preprocess(), so move it here
        rank_imgs = nn_utils.move_to_cuda(rank_imgs)
        rank_imgs_pref = nn_utils.move_to_cuda(rank_imgs_pref)

        # embed rank images
        C, bsz, H, W = imgs.size()
        n_rank = self.hp.n_rank_imgs
        # [bsz, rank_n_imgs, C, H, W] -> [bsz * rank_n_imgs, C, H, W]
        flat_rank_imgs = rank_imgs.view(bsz * n_rank, C, H, W)
        # [C, bsz * rank_n_imgs, H, W] (CNN expects batch second)
        flat_rank_imgs = flat_rank_imgs.transpose(0, 1)
        rank_imgs_emb = self.enc(flat_rank_imgs)  # [bsz * rank_n_imgs, dim]
        rank_imgs_emb = rank_imgs_emb.view(bsz, n_rank, -1)  # [bsz, rank_n_imgs, dim]

        # Compute loss against the last layer's hidden state -> [bsz, dim]
        last_text_hidden = texts_hidden[0][-1, :, :]
        loss_rank = self.rank_imgs_text_loss(rank_imgs_emb, last_text_hidden,
                                             rank_imgs_pref)
        result['loss'] += loss_rank  # for backward
        result['loss_rank_imgs_text'] = loss_rank.clone().detach()  # for logging

    return result
def forward(self, strokes, hidden_cell=None):
    """
    Encode a batch of stroke sequences into a sampled latent vector.

    Args:
        strokes: [max_len, bsz, input_dim] (input_size == isz == 5)
        hidden_cell: tuple of [n_layers * n_directions, bsz, dim];
            zero-initialized when None

    Returns:
        z: [bsz, z_dim]
        mu: [bsz, z_dim]
        sigma_hat [bsz, z_dim] (used to calculate KL loss, eq. 10)
    """
    bsz = strokes.size(1)
    num_directions = 2 if self.bidirectional else 1

    def last_layer_state(h):
        """[n_layers * n_directions, bsz, enc_dim] -> [bsz, n_directions * enc_dim] of the top layer."""
        per_layer = h.view(self.num_layers, num_directions, bsz, self.enc_dim)
        top = per_layer[-1]  # [num_directions, bsz, hsz]
        return top.transpose(0, 1).reshape(bsz, -1)

    # Start from an all-zero state on the first forward pass
    if hidden_cell is None:
        state_shape = (self.num_layers * num_directions, bsz, self.enc_dim)
        hidden = nn_utils.move_to_cuda(torch.zeros(*state_shape))
        cell = nn_utils.move_to_cuda(torch.zeros(*state_shape))
        hidden_cell = (hidden, cell)

    # Run the encoder LSTM(s)
    # http://pytorch.org/docs/master/nn.html#torch.nn.LSTM
    if self.use_layer_norm:
        # Two separate LSTMs: one reads the sequence forward, the other reads
        # it reversed; their top-layer states are averaged.
        _, (hidden_f, _) = self.lstm_f(strokes, hidden_cell)
        last_hidden_f = last_layer_state(hidden_f)

        strokes_b = torch.flip(strokes, [0])
        _, (hidden_b, _) = self.lstm_b(strokes_b, hidden_cell)
        last_hidden_b = last_layer_state(hidden_b)

        last_hidden = torch.stack([last_hidden_f, last_hidden_b]).mean(dim=0)
    else:
        _, (hidden, _) = self.lstm(strokes, hidden_cell)
        last_hidden = last_layer_state(hidden)

    # Project the final state to the latent Gaussian's parameters
    mu = self.fc_mu(last_hidden)  # [bsz, z_dim]
    sigma_hat = self.fc_sigma(last_hidden)  # [bsz, z_dim]

    # Reparameterization: z = mu + sigma * N with N ~ N(0, 1);
    # exp(sigma_hat / 2) turns sigma_hat into a non-negative std
    sigma = torch.exp(sigma_hat / 2.)
    noise = nn_utils.move_to_cuda(torch.randn_like(sigma))
    z = mu + sigma * noise  # [bsz, z_dim]

    # Note we return sigma_hat, not sigma, to be used in KL-loss (eq. 10)
    return z, mu, sigma_hat
def transformer_generate(
        transformer,
        token_embedding,
        pos_enc,
        src_input_embs=None,
        input_lens=None,
        init_ids=None,
        pad_id=None,
        eos_id=None,
        max_len=100,
        decode_method=None,
        tau=None,
        k=None,
        idx2token=None,
):
    """
    Decode up to max_len symbols by feeding previous output as next input.

    Args:
        transformer: nn.Transformer
        token_embedding: nn.Embedding(vocab, dim)
        pos_enc: PositionalEncoder module
        src_input_embs: [input_len, bsz, dim]
        input_lens: list of ints
        init_ids: [init_len, bsz] (e.g. SOS ids)
        pad_id: int
        eos_id: int (id for EOS_ID token)
        max_len: int (maximum number of generated tokens, excluding init_ids)
        decode_method: str (how to sample words given probabilities; 'greedy', 'sample')
        tau: float (temperature for softmax)
        k: int (for sampling or beam search)
        idx2token: dict (optional; when None no texts are produced)

    Returns:
        decoded_probs: [bsz, max_len, vocab]
        decoded_ids: [bsz, max_len]
        decoded_texts: list of strs (empty when idx2token is None)
    """
    init_len, bsz = init_ids.size()
    # The logits below are computed against the embedding table, so its row
    # count is the vocab size (this also works when idx2token is None).
    vocab_size = token_embedding.weight.size(0)

    # Encode inputs
    src_key_padding_mask, _, memory_key_padding_mask = create_transformer_padding_masks(
        src_lens=input_lens)
    memory = transformer.encoder(
        src_input_embs,
        src_key_padding_mask=src_key_padding_mask)  # [input_len, bsz, dim]

    # Track which sequences have generated eos_id
    rows_with_eos = nn_utils.move_to_cuda(torch.zeros(bsz).long())
    pad_ids = nn_utils.move_to_cuda(torch.Tensor(bsz).fill_(pad_id)).long()
    pad_prob = nn_utils.move_to_cuda(torch.zeros(bsz, vocab_size))  # one hot for pad id
    pad_prob[:, pad_id] = 1

    # Buffers hold the initial ids followed by max_len generated steps
    total_len = init_len + max_len
    decoded_probs = nn_utils.move_to_cuda(torch.zeros(total_len, bsz, vocab_size))
    decoded_ids = nn_utils.move_to_cuda(torch.zeros(total_len, bsz).long())
    decoded_ids[:init_len, :] = init_ids

    # Generate. The loop runs for max_len steps *after* the initial ids so that
    # every slot of the returned [bsz, max_len] tensors can be filled.
    for t in range(init_len, total_len):
        # Causal mask: only moved to the right device. Its dtype is left
        # untouched; casting it to the (integer) dtype of decoded_ids is
        # rejected by current attention implementations.
        tgt_mask = generate_square_subsequent_mask(t).to(decoded_ids.device)

        # pass through TransformerDecoder
        cur_dec_input = token_embedding(decoded_ids[:t, :])  # [t, bsz, dim]
        cur_dec_input = scale_add_pos_emb(cur_dec_input, pos_enc)  # [t, bsz, dim]
        dec_outputs = transformer.decoder(
            cur_dec_input,
            memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask)  # [t, bsz, dim]

        # Only the last position predicts the next token, so only that row is
        # projected onto the vocab (output embedding tied to the input one).
        logits = torch.matmul(dec_outputs[-1], token_embedding.weight.t())  # [bsz, vocab]
        prob = nn_utils.logits_to_prob(logits, tau=tau)  # [bsz, vocab]
        prob, ids = nn_utils.prob_to_vocab_id(
            prob, decode_method, k=k)  # prob: [bsz, vocab]; ids: [bsz, k]
        ids = ids[:, 0]  # keep the first of the k candidates

        # If a sequence (row) has already produced an eos_id *earlier*,
        # replace its id/prob with pad
        finished = rows_with_eos == 1
        prob = torch.where(finished.unsqueeze(1), pad_prob, prob)  # unsqueeze to broadcast
        ids = torch.where(finished, pad_ids, ids)
        decoded_probs[t, :, :] = prob
        decoded_ids[t, :] = ids

        # Update for next iteration in loop
        rows_with_eos = rows_with_eos | (ids == eos_id).long()

        # Terminate early if all sequences have generated eos
        if rows_with_eos.sum().item() == bsz:
            break

    # Remove initial input to decoder
    decoded_probs = decoded_probs[init_len:, :, :]
    decoded_ids = decoded_ids[init_len:, :]

    # TODO: remove this once InstructionDecoderLSTM is refactored to return
    # [len, bsz] instead of [bsz, len]
    decoded_probs = decoded_probs.transpose(0, 1)  # [bsz, max_len, vocab]
    decoded_ids = decoded_ids.transpose(0, 1)  # [bsz, max_len]

    # Convert to strings, cutting each row at its first EOS
    decoded_texts = []
    if idx2token is not None:
        for row in decoded_ids.tolist():  # one host transfer instead of one per token
            tokens = []
            for token_id in row:
                if token_id == eos_id:
                    break
                tokens.append(idx2token[token_id])
            decoded_texts.append(' '.join(tokens))

    return decoded_probs, decoded_ids, decoded_texts
def one_forward_pass(self, batch):
    """
    Return loss and other items of interest for one forward pass.

    The stroke decoder is always trained with the reconstruction loss. How
    the instructions are used depends on the hyperparameters:
        - hp.instruction_set in ['stack', 'stack_leaves']: the drawing is cut
          into segments at pen-ups and each segment's mean decoder output is
          matched against the encoding of that segment's instruction.
        - hp.instruction_set in ['toplevel', 'toplevel_leaves', 'leaves']:
          hp.cond_instructions == 'decinputs' concatenates the instruction
          embedding to every decoder input; 'match' matches the mean decoder
          output against the instruction (hp.loss_match 'triplet' or 'decode').

    Args:
        batch: tuple from DataLoaders

    Returns:
        dict where 'loss': float Tensor must exist. 'loss_R' (detached) is
        always present; 'loss_match' (detached) when
        hp.cond_instructions == 'match'.

    Raises:
        NotImplementedError: stack instructions with hp.loss_match == 'decode'
        ValueError: unsupported hyperparameter combination, or a drawing
            without any pen-up delimited segment
        Exception: the final loss is NaN or infinite
    """
    strokes, stroke_lens, texts, text_lens, text_indices, cats, cats_idx, urls = batch
    # batch is 1st dimension (not 0th) due to preprocess_batch()

    margin = 0.1  # alpha of the triplet loss

    def triplet_match_loss(anchors, positives):
        """Hinged mean of (positive - negative + margin) squared distances.

        Negatives are the positives in a random order. torch.clamp keeps the
        result on the loss' device and lets a NaN propagate to the final
        check, whereas Python's max() would silently replace a NaN by 0.
        """
        pos = (anchors - positives) ** 2  # [n, enc_dim]
        negatives = positives[torch.randperm(pos.size(0)), :]  # [n, enc_dim]
        neg = (anchors - negatives) ** 2  # [n, enc_dim]
        return torch.clamp((pos - neg).mean() + margin, min=0.0)

    # Create base inputs to decoder
    _, bsz, _ = strokes.size()
    sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # start of sequence
    sos = nn_utils.move_to_cuda(sos)
    dec_inputs = torch.cat([sos, strokes], dim=0)  # sos at the beginning; [max_len + 1, bsz, 5]

    if self.hp.use_categories_dec:
        cat_embs = self.category_embedding(cats_idx)  # [bsz, cat_dim]
        cat_embs = cat_embs.repeat(dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, cat_dim]
        dec_inputs = torch.cat([dec_inputs, cat_embs], dim=2)  # [max_len + 1, bsz, 5 + cat_dim]

    loss_match = None  # set by the matching branches below

    #
    # Encode instructions, decode
    #
    if self.hp.instruction_set in ['stack', 'stack_leaves']:
        # text_indices: [max_seq_len, bsz, max_instruction_len],
        # text_lens: [max_seq_len, bsz]

        # decoder's hidden states are "matched" with language representations
        outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
            dec_inputs, output_all=True)  # outputs: [max_seq_len, bsz, dim]

        # For each sequence in batch, divide drawing into segments (based on
        # penup strokes). For each segment, the mean decoder output is paired
        # with the encoded instruction stack for that segment.
        pen_up_flags = strokes[:, :, 3].cpu().numpy()  # [max_len, bsz]; one transfer for the batch
        all_instruction_embs = []
        all_seg_hiddens = []
        for i in range(bsz):
            penups = np.where(pen_up_flags[:, i] == 1)[0].tolist()
            if not penups or penups[0] != 0:  # first element could already be 0
                penups = [0] + penups
            # TODO: find other place that I do [0] + penups, see if I need to
            # account for the case where the first element is 0
            seg_starts = penups[:-1]  # the last penup only closes the last segment
            if not seg_starts:
                raise ValueError(
                    f'Batch item {i} has no pen-up delimited segment')

            # Encode instruction stacks. The instruction is the same at every
            # timestep within a segment, so take it at the segment start.
            instructions = torch.stack(
                [text_indices[start_idx, i, :] for start_idx in seg_starts],
                dim=1)  # [max_instruction_len, n_segs] (n_segs is the "batch" for the encoder)
            instructions_lens = [
                text_lens[start_idx, i].item() for start_idx in seg_starts
            ]
            instructions = instructions[:max(instructions_lens), :]  # encoder requires this
            # all segs are from the same drawing (i.e. same category)
            cur_cats_idx = [cats_idx[i] for _ in instructions_lens]
            instruction_embs = self.enc(
                instructions,
                instructions_lens,
                self.text_embedding,
                category_embedding=None,
                categories=cur_cats_idx)  # [n_segs, dim]
            all_instruction_embs.append(instruction_embs)

            # Mean decoder output over each segment
            seg_hiddens = [
                outputs[start_idx:end_idx + 1, i, :].mean(dim=0)  # [dim]
                for start_idx, end_idx in zip(penups[:-1], penups[1:])
            ]
            all_seg_hiddens.append(torch.stack(seg_hiddens, dim=0))  # [n_segs, dim]

        # Concatenate all segs across all batch items
        if self.hp.loss_match == 'triplet':
            all_instruction_embs = torch.cat(all_instruction_embs, dim=0)  # [n_total_segs, enc_dim]
            all_seg_hiddens = torch.cat(all_seg_hiddens, dim=0)  # [n_total_segs, dec_dim]
            all_seg_hiddens = self.fc_dec(all_seg_hiddens)  # [n_total_segs, enc_dim]
            loss_match = triplet_match_loss(all_seg_hiddens, all_instruction_embs)
        elif self.hp.loss_match == 'decode':
            raise NotImplementedError

    # TODO: check if text_indices is correct
    elif self.hp.instruction_set in ['toplevel', 'toplevel_leaves', 'leaves']:
        if self.hp.cond_instructions == 'decinputs':
            # Concatenate instruction embedding to every time step
            # text_indices: [len, bsz], text_lens: [bsz]
            instructions_emb = self.enc(
                text_indices,
                text_lens,
                self.text_embedding,
                category_embedding=self.category_embedding,
                categories=cats_idx)  # [bsz, enc_dim]
            instructions_emb = instructions_emb.unsqueeze(0)  # [1, bsz, dim]
            instructions_emb = instructions_emb.repeat(
                dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, dim]
            dec_inputs = torch.cat([dec_inputs, instructions_emb], dim=2)  # [max_len + 1, bsz, inp_dim]
            outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                dec_inputs, output_all=True)

        elif self.hp.cond_instructions == 'match':
            # Match decoder's hidden representations to encoded language
            outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                dec_inputs, output_all=True)  # outputs: [max_seq_len, bsz, dim]
            outputs = outputs.mean(dim=0)  # [bsz, dec_dim]

            if self.hp.loss_match == 'triplet':
                # text_indices: [len, bsz], text_lens: [bsz]
                instructions_emb = self.enc(
                    text_indices,
                    text_lens,
                    self.text_embedding,
                    category_embedding=self.category_embedding,
                    categories=cats_idx)  # [bsz, enc_dim]
                outputs = self.fc_dec(outputs)  # [bsz, enc_dim]
                loss_match = triplet_match_loss(outputs, instructions_emb)

            elif self.hp.loss_match == 'decode':
                hidden = nn_utils.move_to_cuda(
                    torch.zeros(self.ins_dec.num_layers, bsz, self.ins_dec.hidden_dim))
                cell = nn_utils.move_to_cuda(
                    torch.zeros(self.ins_dec.num_layers, bsz, self.ins_dec.hidden_dim))

                # Decode the instruction text
                texts_emb = self.text_embedding(text_indices)  # [len, bsz, dim]
                logits, _ = self.ins_dec(
                    texts_emb,
                    text_lens,
                    hidden=hidden,
                    cell=cell,
                    token_embedding=self.text_embedding,
                    category_embedding=self.category_embedding,
                    categories=cats_idx)
                loss_match = self.compute_dec_loss(logits, text_indices)
        else:
            raise ValueError(
                f'Unsupported cond_instructions: {self.hp.cond_instructions}')
    else:
        raise ValueError(
            f'Unsupported instruction_set: {self.hp.instruction_set}')

    #
    # Calculate reconstruction and final loss
    #
    mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
    loss_R = self.dec.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y,
                                          sigma_x, sigma_y, rho_xy, q)

    if self.hp.cond_instructions == 'decinputs':
        loss = loss_R
        result = {'loss': loss, 'loss_R': loss_R.clone().detach()}
    elif self.hp.cond_instructions == 'match':
        if loss_match is None:
            raise ValueError(
                f'Unsupported loss_match: {self.hp.loss_match}')
        loss = loss_R + loss_match
        result = {
            'loss': loss,
            'loss_R': loss_R.clone().detach(),
            'loss_match': loss_match.clone().detach()
        }
    else:
        raise ValueError(
            f'Unsupported cond_instructions: {self.hp.cond_instructions}')

    # loss != loss is only true for NaN
    non_finite = (loss != loss) | (loss == float('inf')) | (loss == float('-inf'))
    if non_finite.any():
        raise Exception('Nan in SketchRNnDecoderGMMOnly forward pass')

    return result
def segment_sample(self, sample, dataset):
    """
    Build the hierarchical segmentation of one drawing.

    All candidate segments of the drawing are scored first; the segment
    spanning the whole drawing becomes the root entry and self.split() then
    subdivides it recursively.

    Args:
        sample: batch of samples from DataLoader of Strokedataset (batch_size=1)
            when s2i_hp.drawing_type is 'stroke'; a dict-like data point with
            a 'category' key when it is 'image'
        dataset: str ('ndjson' or 'progressionpair')

    Returns:
        segmented: list of dicts; the root entry has the keys
            'left', 'right', 'score', 'text', 'id', 'parent'

    Raises:
        NotImplementedError: drawing_type 'image' with dataset 'ndjson'

    Note:
        There are several different indices / mappings.
        1) penups (i.e. number of penups). One penup = one segment.
        2) seg_idx. Indexes into batch_of_segs.
            seg_idx_map: (left_penup, right_penup) -> seg_idx
            e.g. (0,1) -> 0, (0,2) -> 1, (0,3) -> 2,
                 (1,2) -> 3, (1,3) -> 4, (2,3) -> 5
        3) parchild_idx. Indexes into parchild_scores, i.e. P(S | [I1, I2]).
            leftrightsegidx_to_parchildidx:
                (left_seg_idx, right_seg_idx) -> par_child_idx
    """
    drawing_type = self.s2i_hp.drawing_type

    if drawing_type == 'stroke':
        if dataset == 'ndjson':
            strokes, stroke_lens, cats, cats_idx = sample
        elif dataset == 'progressionpair':
            (strokes, stroke_lens, texts, text_lens, text_indices_w_sos_eos,
             cats, cats_idx, urls) = sample

        strokes = strokes.transpose(0, 1).float()  # [len, 1, 5]
        strokes = nn_utils.move_to_cuda(strokes).squeeze(1)  # [len, 5]
        segs, n_penups, seg_lens, seg_idx_map = \
            self.construct_batch_of_segments_from_one_sample_stroke(strokes)

        # every candidate segment shares the drawing's category
        cats_idx = nn_utils.move_to_cuda(cats_idx.repeat(len(seg_lens)))

    elif drawing_type == 'image':
        if dataset == 'ndjson':
            raise NotImplementedError
        elif dataset == 'progressionpair':
            segs, n_penups, seg_lens, seg_idx_map = \
                self.construct_batch_of_segments_from_one_sample_image(sample)

            # every candidate segment shares the drawing's category
            cat_id = self.ds.cat2idx[sample['category']]
            cats_idx = nn_utils.move_to_cuda(
                torch.LongTensor([cat_id] * len(seg_lens)))

    (seg_scores, seg_texts, parchild_scores,
     leftrightsegidx_to_parchildidx) = self.calculate_seg_scores(
         segs, seg_lens, cats_idx, seg_idx_map)

    # Top level segmentation: initial instruction for the entire sequence
    root_idx = seg_idx_map[(0, n_penups)]
    root = {
        'left': 0,
        'right': n_penups,
        'score': seg_scores[root_idx],
        'text': seg_texts[root_idx],
        'id': uuid4().hex,
        'parent': ''
    }

    # Recursively segment
    return self.split(0, n_penups, seg_idx_map, seg_scores, seg_texts,
                      [root], parchild_scores,
                      leftrightsegidx_to_parchildidx)
def forward(self, strokes, stroke_lens=None, output_all=True, hidden_cell=None):
    """
    Run the decoder LSTM and map its states to GMM and pen parameters.

    Args:
        strokes: [len, bsz, input_dim (e.g. dim + 5)]
        stroke_lens: unused here; accepted so all decoders share a signature
        output_all: boolean, return output at every timestep or just the last
        hidden_cell: tuple of [n_layers, bsz, dec_dim]; zeros when None

    Returns:
        In the shapes below, len_out is the input length when output_all is
        True and 1 otherwise.

        outputs: [len, bsz, dec_dim] top-layer LSTM output at every timestep
        pi: weights for each mixture [len_out, bsz, M]
        mu_x: mean x for each mixture [len_out, bsz, M]
        mu_y: mean y for each mixture [len_out, bsz, M]
        sigma_x: var x for each mixture [len_out, bsz, M]
        sigma_y: var y for each mixture [len_out, bsz, M]
        rho_xy: covariance for each mixture [len_out, bsz, M]
        q: [len_out, bsz, 3]
            models p (3 pen strokes in stroke-5) as categorical distribution (page 3);
        hidden: [n_layers, bsz, dec_dim] last hidden state
        cell: [n_layers, bsz, dec_dim] last cell state
    """
    bsz = strokes.size(1)

    if hidden_cell is None:  # init
        hidden = nn_utils.move_to_cuda(
            torch.zeros(self.num_layers, bsz, self.dec_dim))
        cell = nn_utils.move_to_cuda(
            torch.zeros(self.num_layers, bsz, self.dec_dim))
        hidden_cell = (hidden, cell)

    outputs, (hidden, cell) = self.lstm(strokes, hidden_cell)

    # Pass hidden state at each step to fully connected layer (Fig 2, Eq. 4)
    # When training, lstm receives the whole input: use all of its outputs.
    # When generating, the input is just the last generated sample: use the
    # final state of the *top* layer only. (Flattening all layers of `hidden`
    # would yield n_layers * bsz rows and break the reshape below as soon as
    # the LSTM has more than one layer.)
    if output_all:
        len_out = outputs.size(0)
        features = outputs.reshape(-1, self.dec_dim)  # [len * bsz, dec_dim]
    else:
        len_out = 1
        features = hidden[-1]  # [bsz, dec_dim]

    # y: [len_out * bsz, 6 * M + 3]
    # (6 per mixture: weight + 5 bivariate-normal params; 3 pen logits; see page 3)
    y = self.fc_params(features)

    # Separate pen and mixture params. Column 6 * m + j of y holds param j
    # of mixture m.
    n_mixture_cols = 6 * self.M
    params_mixture = y[:, :n_mixture_cols].reshape(len_out, bsz, self.M, 6)
    params_pen = y[:, n_mixture_cols:]  # pen up/down/end; [len_out * bsz, 3]
    pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy = (
        params_mixture[..., j] for j in range(6))  # each [len_out, bsz, M]

    # Map raw outputs to valid distribution parameters
    pi = F.softmax(pi, dim=-1)  # mixture weights sum to 1
    mu_x = mu_x.contiguous()
    mu_y = mu_y.contiguous()
    sigma_x = torch.exp(sigma_x)  # Eq. 6: strictly positive
    sigma_y = torch.exp(sigma_y)
    rho_xy = torch.tanh(rho_xy)  # Eq. 6: correlation in (-1, 1)
    q = F.softmax(params_pen, dim=-1).reshape(len_out, bsz, 3)  # Eq. 7

    # TODO: refactor all instances to unpack outputs
    # TODO: rename outputs as all_hidden?
    return outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell
def generate_and_save(self, data_loader, epoch, n_gens, outputs_path=None):
    """
    Generate and save drawings.

    For up to n_gens batches, a drawing is sampled step by step from the
    decoder (conditioned on the batch's instruction when
    hp.cond_instructions == 'decinputs'). Generated drawings, ground-truth
    drawings and ground-truth texts are then written to outputs_path in
    chunks of 25 (one 5 x 5 grid per image file).

    Only the first item of every batch is generated / saved (see
    sample_next_state), so the data loader is expected to use batch_size=1.

    Args:
        data_loader: DataLoader yielding batches for preprocess_batch_from_data_loader()
        epoch: int (only used in the output file names)
        n_gens: int (number of drawings to generate)
        outputs_path: str (directory for the output files; required)

    Raises:
        ValueError: outputs_path is None
    """
    if outputs_path is None:
        raise ValueError('outputs_path must be provided')

    n = 0
    gen_strokes = []
    gt_strokes = []
    gt_texts = []
    for batch in data_loader:
        batch = self.preprocess_batch_from_data_loader(batch)
        strokes, stroke_lens, texts, text_lens, text_indices, cats, cats_idx, urls = batch
        max_len, bsz, _ = strokes.size()

        # generate until end of sequence or maximum sequence length
        seq_x = []  # delta-x
        seq_y = []  # delta-y
        seq_pen = []  # pen-down
        with torch.no_grad():  # sampling only, no graph needed
            if self.hp.cond_instructions == 'decinputs':
                # Encode instructions
                # text_indices: [len, bsz], text_lens: [bsz]
                z = self.enc(text_indices,
                             text_lens,
                             self.text_embedding,
                             category_embedding=self.category_embedding,
                             categories=cats_idx)  # [bsz, enc_dim]

            hidden_cell = (
                nn_utils.move_to_cuda(torch.zeros(1, bsz, self.hp.dec_dim)),
                nn_utils.move_to_cuda(torch.zeros(1, bsz, self.hp.dec_dim)))

            # initialize state with start of sequence stroke-5 stroke
            sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # [1 (len), bsz, 5 (stroke-5)]
            s = nn_utils.move_to_cuda(sos)

            for _ in range(max_len):
                if self.hp.cond_instructions == 'decinputs':
                    # input is last state, z, and hidden_cell
                    dec_input = torch.cat([s, z.unsqueeze(0)], dim=2)  # [1 (len), 1 (bsz), 5 + z_dim]
                elif self.hp.cond_instructions == 'match':
                    # input is last state and hidden_cell
                    dec_input = s  # [1, bsz (1), 5]

                # hack because VAE was trained with use_categories_dec=True
                # but didn't actually have a category embedding
                if self.hp.use_categories_dec and hasattr(self, 'category_embedding'):
                    cat_embs = self.category_embedding(cats_idx)  # [bsz (1), cat_dim]
                    dec_input = torch.cat([dec_input, cat_embs.unsqueeze(0)], dim=2)  # [1, 1, dim]

                outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell = \
                    self.dec(dec_input, stroke_lens=stroke_lens, output_all=False, hidden_cell=hidden_cell)
                hidden_cell = (hidden, cell)  # for next time step

                # sample next state
                s, dx, dy, pen_up, eos = self.sample_next_state(
                    pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q)
                seq_x.append(dx)
                seq_y.append(dy)
                seq_pen.append(pen_up)

                if eos:  # done drawing
                    break

        # get in format to draw image
        # Cumulative sum because seq_x and seq_y are deltas, so get x (or y) at each stroke
        sample_x = np.cumsum(seq_x, 0)
        sample_y = np.cumsum(seq_y, 0)
        sample_pen = np.array(seq_pen)
        sequence = np.stack([sample_x, sample_y, sample_pen]).T

        # Original drawing as well. Batch index 0 because sample_next_state
        # etc. only use the 0-th batch item; last index: 0 = dx, 1 = dy, 3 = pen up
        strokes_x = np.cumsum(strokes[:, 0, 0].cpu().numpy())
        strokes_y = np.cumsum(strokes[:, 0, 1].cpu().numpy())
        strokes_pen = strokes[:, 0, 3].cpu().numpy()
        strokes_out = np.stack([strokes_x, strokes_y, strokes_pen]).T

        gen_strokes.append(sequence)
        gt_strokes.append(strokes_out)
        gt_texts.append(texts[0])  # 0 because batch size is 1

        n += 1
        if n == n_gens:
            break

    # save grid drawings, one file per chunk of rowcol_size ** 2 drawings
    rowcol_size = 5
    chunk_size = rowcol_size**2
    for i in range(0, len(gen_strokes), chunk_size):
        output_fp = os.path.join(outputs_path,
                                 f'e{epoch}_gen{i}-{i+chunk_size}.jpg')
        save_multiple_strokes_as_img(gen_strokes[i:i + chunk_size], output_fp)

        output_fp = os.path.join(outputs_path,
                                 f'e{epoch}_gt{i}-{i+chunk_size}.jpg')
        save_multiple_strokes_as_img(gt_strokes[i:i + chunk_size], output_fp)

        # save texts
        output_fp = os.path.join(outputs_path,
                                 f'e{epoch}_texts{i}-{i+chunk_size}.json')
        utils.save_file(gt_texts[i:i + chunk_size], output_fp)
def sample_next_state(self, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q):
    """
    Sample the next drawing step from the decoder's current output.

    A mixture component is drawn from the temperature-adjusted weights, a
    (dx, dy) offset is sampled from that component's bivariate normal, and
    the pen state is drawn from q. Note that the returned state is different
    from the stroke-5 format.

    NOTE: currently only operates on the first item in the batch, and only on
    the last timestep of the given tensors.

    Args:
        pi: [len, bsz, M]
        mu_x: [len, bsz, M]
        mu_y: [len, bsz, M]
        sigma_x: [len, bsz, M]
        sigma_y: [len, bsz, M]
        rho_xy: [len, bsz, M]
        q: [len, bsz, 3]

        When used during generation, len should be 1 (decoding step by step)

    Returns:
        s: [1, 1, 5]
        dx: as returned by sample_bivariate_normal
        dy: as returned by sample_bivariate_normal
        pen_up: bool
        eos: bool
    """
    # TODO: this method (and sample_bivariate_normal) does not produce a
    # sample for every item in the batch, only for this one.
    batch_item = 0
    _, _, n_mixtures = pi.size()

    # Pick a mixture component. The temperature is applied in log space and
    # the weights are re-normalized, following adjust_temp() in magenta's
    # sketch_rnn/model.py (rather than plainly dividing as in eq. 8).
    log_weights = np.log(pi.data[-1, batch_item, :].cpu().numpy()) / self.hp.temperature
    log_weights -= log_weights.max()
    weights = np.exp(log_weights)
    weights /= weights.sum()  # [M]
    mixture_idx = np.random.choice(n_mixtures, p=weights)

    # Parameters of the chosen component (scalars)
    chosen = [
        param.data[-1, batch_item, mixture_idx]
        for param in (mu_x, mu_y, sigma_x, sigma_y, rho_xy)
    ]

    # Next x and y offsets from the chosen bivariate normal
    dx, dy = self.sample_bivariate_normal(*chosen, greedy=False)

    # Pen state
    # TODO: the temperature is not applied to q (magenta doesn't either)
    pen_probs = q.data[-1, batch_item, :].cpu().numpy()  # [3]
    pen_idx = np.random.choice(3, p=pen_probs)

    # Assemble the next state vector: [dx, dy, one-hot pen state]
    next_state = torch.zeros(5)
    next_state[0] = dx
    next_state[1] = dy
    next_state[pen_idx + 2] = 1
    s = nn_utils.move_to_cuda(next_state).view(1, 1, -1)

    return s, dx, dy, pen_idx == 1, pen_idx == 2
def one_forward_pass(self, batch, average_loss=True):
    """
    Return loss and other items of interest for one forward pass.

    The instruction is encoded and conditions the stroke decoder either by
    being appended to every decoder input (hp.cond_instructions ==
    'decinputs') or by initializing the decoder's hidden and cell state
    (hp.cond_instructions == 'initdec').

    Args:
        batch: tuple from DataLoaders
        average_loss (bool): whether to average loss per batch item
            - Current use case: Segmentation model computes loss per segment.
              Batches are a batch of segments for one example.

    Returns:
        dict where 'loss': float Tensor must exist ('loss_R' holds the same Tensor)

    Raises:
        Exception: the loss contains NaN or an infinite value
    """
    strokes, stroke_lens, texts, text_lens, text_indices, cats, cats_idx, urls = batch

    # Base decoder inputs: strokes preceded by the start-of-sequence stroke
    _, bsz, _ = strokes.size()
    sos_stroke = torch.Tensor([0, 0, 1, 0, 0])
    sos = nn_utils.move_to_cuda(
        torch.stack([sos_stroke] * bsz).unsqueeze(0))  # [1, bsz, 5]
    dec_inputs = torch.cat([sos, strokes], dim=0)  # [max_len + 1, bsz, 5]

    #
    # Encode instructions, decode
    #
    # text_indices: [len, bsz], text_lens: [bsz]
    instruction_emb = self.enc(text_indices,
                               text_lens,
                               self.text_embedding,
                               category_embedding=None,
                               categories=cats_idx)  # [bsz, dim]

    cond = self.hp.cond_instructions
    if cond == 'decinputs':
        # Method 1: concatenate instruction embedding to every time step
        n_steps = dec_inputs.size(0)
        tiled = instruction_emb.unsqueeze(0).repeat(n_steps, 1, 1)  # [max_len + 1, bsz, dim]
        dec_inputs = torch.cat([dec_inputs, tiled], dim=2)  # [max_len + 1, bsz, 5 + dim]
        outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
            dec_inputs, output_all=True)
    elif cond == 'initdec':
        # Method 2: initialize decoder's hidden state with instruction embedding
        init_hidden = instruction_emb.unsqueeze(0)  # [1, bsz, dim]
        outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
            dec_inputs,
            output_all=True,
            hidden_cell=(init_hidden, init_hidden.clone()))

    #
    # Calculate losses
    #
    mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
    loss = self.dec.reconstruction_loss(mask,
                                        dx,
                                        dy,
                                        p,
                                        pi,
                                        mu_x,
                                        mu_y,
                                        sigma_x,
                                        sigma_y,
                                        rho_xy,
                                        q,
                                        average_loss=average_loss)

    # loss != loss is only true for NaN
    is_nan = (loss != loss).any()
    is_inf = (loss == float('inf')).any() or (loss == float('-inf')).any()
    if is_nan or is_inf:
        raise Exception('Nan in SketchRNnDecoderGMMOnly forward pass')

    return {'loss': loss, 'loss_R': loss}
def generate_and_save(self, data_loader, epoch, n_gens, outputs_path=None):
    """
    Generate drawings and save them together with the ground truth.

    For up to n_gens batches, a drawing is generated step by step with the
    model selected by hp.model_type ('vae', 'decodergmm' or 'decoderlstm').
    Generated and ground-truth drawings are then written to outputs_path in
    chunks of 25 (one 5 x 5 grid per image file).

    Only the first item of every batch is generated / saved (see
    sample_next_state), so the data loader is expected to use batch_size=1.

    Args:
        data_loader: DataLoader yielding batches for preprocess_batch_from_data_loader()
        epoch: int (only used in the output file names)
        n_gens: int (number of drawings to generate)
        outputs_path: str (directory for the output files; required)

    Raises:
        ValueError: outputs_path is None
    """
    if outputs_path is None:
        raise ValueError('outputs_path must be provided')

    n = 0
    gen_strokes = []
    gt_strokes = []
    for batch in data_loader:
        batch = self.preprocess_batch_from_data_loader(batch)
        strokes, stroke_lens, cats, cats_idx = batch
        max_len, bsz, _ = strokes.size()

        # generate until end of sequence or maximum sequence length
        seq_x = []  # delta-x
        seq_y = []  # delta-y
        seq_pen = []  # pen-down
        with torch.no_grad():  # sampling only, no graph needed
            # Encode
            if self.hp.model_type == 'vae':
                z, _, _ = self.enc(strokes)  # z: [bsz, 128]
                # init hidden and cell states is tanh(fc(z)) (Page 3)
                hidden, cell = torch.split(torch.tanh(self.fc_z_to_hc(z)),
                                           self.hp.dec_dim, dim=1)
                hidden_cell = (hidden.unsqueeze(0).contiguous(),
                               cell.unsqueeze(0).contiguous())
            elif 'decoder' in self.hp.model_type:
                hidden_cell = (
                    nn_utils.move_to_cuda(torch.zeros(1, bsz, self.hp.dec_dim)),
                    nn_utils.move_to_cuda(torch.zeros(1, bsz, self.hp.dec_dim)))

            # initialize state with start of sequence stroke-5 stroke
            sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # [1 (len), bsz, 5 (stroke-5)]
            s = nn_utils.move_to_cuda(sos)

            for _ in range(max_len):
                if self.hp.model_type in ['vae', 'decodergmm']:
                    if self.hp.model_type == 'vae':
                        # input is last state, z, and hidden_cell
                        dec_input = torch.cat([s, z.unsqueeze(0)], dim=2)  # [1 (len), 1 (bsz), 5 + z_dim (128)]
                    else:
                        # decodergmm: input is last state and hidden_cell
                        dec_input = s  # [1, bsz (1), 5]

                    # hack because VAE was trained with use_categories_dec=True
                    # but didn't actually have a category embedding
                    if self.hp.use_categories_dec and hasattr(self, 'category_embedding'):
                        cat_embs = self.category_embedding(cats_idx)  # [bsz (1), cat_dim]
                        dec_input = torch.cat([dec_input, cat_embs.unsqueeze(0)], dim=2)  # [1, 1, dim]
                        # dim = 5 + cat_dim if decodergmm, 5 + z_dim + cat_dim if vae

                    outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell = \
                        self.dec(dec_input, stroke_lens=stroke_lens, output_all=False, hidden_cell=hidden_cell)
                    hidden_cell = (hidden, cell)  # for next time step

                    # sample next state
                    s, dx, dy, pen_up, eos = self.sample_next_state(
                        pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q)

                elif self.hp.model_type == 'decoderlstm':
                    # input is last state and hidden_cell
                    xy, q, hidden, cell = self.dec(s, stroke_lens=stroke_lens,
                                                   output_all=False, hidden_cell=hidden_cell)
                    hidden_cell = (hidden, cell)

                    # last timestep, first batch item, x / y
                    dx, dy = xy[-1, 0, 0].item(), xy[-1, 0, 1].item()
                    pen_state = q[-1, 0, :].max(dim=0)[1].item()  # most likely pen state
                    pen_up = pen_state == 1  # max index is the 2nd one (penup)
                    eos = pen_state == 2  # max index is the 3rd one (eos)

                    # Feed the predicted step back as the next input.
                    # Otherwise the decoder would see the SOS token at every
                    # step. Same layout as in sample_next_state.
                    next_state = torch.zeros(5)
                    next_state[0] = dx
                    next_state[1] = dy
                    next_state[pen_state + 2] = 1
                    s = nn_utils.move_to_cuda(next_state).view(1, 1, -1)

                seq_x.append(dx)
                seq_y.append(dy)
                seq_pen.append(pen_up)

                if eos:  # done drawing
                    break

        # get in format to draw image
        # Cumulative sum because seq_x and seq_y are deltas, so get x (or y) at each stroke
        sample_x = np.cumsum(seq_x, 0)
        sample_y = np.cumsum(seq_y, 0)
        sample_pen = np.array(seq_pen)
        sequence = np.stack([sample_x, sample_y, sample_pen]).T

        # Original drawing as well. Batch index 0 because sample_next_state
        # etc. only use the 0-th batch item; last index: 0 = dx, 1 = dy, 3 = pen up
        strokes_x = np.cumsum(strokes[:, 0, 0].cpu().numpy())
        strokes_y = np.cumsum(strokes[:, 0, 1].cpu().numpy())
        strokes_pen = strokes[:, 0, 3].cpu().numpy()
        strokes_out = np.stack([strokes_x, strokes_y, strokes_pen]).T

        gen_strokes.append(sequence)
        gt_strokes.append(strokes_out)

        n += 1
        if n == n_gens:
            break

    # save grid drawings, one file per chunk of rowcol_size ** 2 drawings
    rowcol_size = 5
    chunk_size = rowcol_size ** 2
    for i in range(0, len(gen_strokes), chunk_size):
        output_fp = os.path.join(outputs_path, f'e{epoch}_gen{i}-{i+chunk_size}.jpg')
        save_multiple_strokes_as_img(gen_strokes[i:i+chunk_size], output_fp)

        output_fp = os.path.join(outputs_path, f'e{epoch}_gt{i}-{i+chunk_size}.jpg')
        save_multiple_strokes_as_img(gt_strokes[i:i+chunk_size], output_fp)