def test_forward(self):
    batch = 16
    len1, len2 = 21, 24
    seq_len1 = torch.randint(low=len1 - 10, high=len1 + 1, size=(batch,)).long()
    seq_len2 = torch.randint(low=len2 - 10, high=len2 + 1, size=(batch,)).long()

    mask1 = []
    for w in seq_len1:
        mask1.append([1] * w.item() + [0] * (len1 - w.item()))
    mask1 = torch.FloatTensor(mask1)
    mask2 = []
    for w in seq_len2:
        mask2.append([1] * w.item() + [0] * (len2 - w.item()))
    mask2 = torch.FloatTensor(mask2)

    d = 200  # hidden dimension
    l = 20   # number of perspectives
    test1 = torch.randn(batch, len1, d)
    test2 = torch.randn(batch, len2, d)
    test1 = test1 * mask1.view(-1, len1, 1).expand(-1, len1, d)
    test2 = test2 * mask2.view(-1, len2, 1).expand(-1, len2, d)

    test1_fw, test1_bw = torch.split(test1, d // 2, dim=-1)
    test2_fw, test2_bw = torch.split(test2, d // 2, dim=-1)

    ml_fw = BiMpmMatching.from_params(Params({"is_forward": True, "num_perspectives": l}))
    ml_bw = BiMpmMatching.from_params(Params({"is_forward": False, "num_perspectives": l}))

    vecs_p_fw, vecs_h_fw = ml_fw(test1_fw, mask1, test2_fw, mask2)
    vecs_p_bw, vecs_h_bw = ml_bw(test1_bw, mask1, test2_bw, mask2)
    vecs_p, vecs_h = (torch.cat(vecs_p_fw + vecs_p_bw, dim=2),
                      torch.cat(vecs_h_fw + vecs_h_bw, dim=2))

    assert vecs_p.size() == torch.Size([batch, len1, 10 + 10 * l])
    assert vecs_h.size() == torch.Size([batch, len2, 10 + 10 * l])
    assert (ml_fw.get_output_dim() == ml_bw.get_output_dim()
            == vecs_p.size(2) // 2 == vecs_h.size(2) // 2)
def __init__(self, X_in, y_in):
    # Check if we have torch.LongTensor inputs (assume NumPy array otherwise)
    if not isinstance(X_in, torch.LongTensor):
        X_in = torch.from_numpy(X_in.astype('int64')).long()
    if not isinstance(y_in, torch.LongTensor):
        y_in = torch.from_numpy(y_in.astype('int64')).long()
    self.X_in = torch.split(X_in, 1, dim=0)
    self.y_in = torch.split(y_in, 1, dim=0)
def forward(self, input):
    projected = self.layer(input)
    non_lin, gate = torch.split(projected, self.input_dim, -1)
    non_lin = self.activation(non_lin)
    gate = self.gate(gate)
    combined = gate * input + (1 - gate) * non_lin
    return combined
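# A minimal self-contained sketch of the highway pattern above: one linear layer
# produces 2 * input_dim features and torch.split separates the transform half
# from the gate half. The class name and module choices here are illustrative
# assumptions, not taken from the original code.
import torch
import torch.nn as nn


class HighwaySketch(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.layer = nn.Linear(input_dim, 2 * input_dim)  # transform + gate in one projection
        self.activation = nn.ReLU()
        self.gate = nn.Sigmoid()

    def forward(self, input):
        projected = self.layer(input)
        non_lin, gate = torch.split(projected, self.input_dim, -1)
        non_lin = self.activation(non_lin)
        gate = self.gate(gate)
        return gate * input + (1 - gate) * non_lin


# usage: the shape is preserved, so such layers can be stacked freely
h = HighwaySketch(8)
assert h(torch.randn(4, 8)).shape == (4, 8)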
def forward(self, featuremap, boxes, box_ind):
    """
    RoIAlign based on crop_and_resize.
    See more details on
    https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301
    :param featuremap: NxCxHxW
    :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization**
    :param box_ind: M
    :return: MxCxoHxoW
    """
    x1, y1, x2, y2 = torch.split(boxes, 1, dim=1)
    image_height, image_width = featuremap.size()[2:4]

    if self.transform_fpcoor:
        spacing_w = (x2 - x1) / float(self.crop_width)
        spacing_h = (y2 - y1) / float(self.crop_height)

        nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1)
        ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1)
        nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1)
        nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1)

        boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1)
    else:
        x1 = x1 / float(image_width - 1)
        x2 = x2 / float(image_width - 1)
        y1 = y1 / float(image_height - 1)
        y2 = y2 / float(image_height - 1)
        boxes = torch.cat((y1, x1, y2, x2), 1)

    boxes = boxes.detach().contiguous()
    box_ind = box_ind.detach()
    return CropAndResizeFunction(self.crop_height, self.crop_width,
                                 self.extrapolation_value)(featuremap, boxes, box_ind)
def calc_pdparam(self, x, evaluate=True):
    '''Calculate pdparams for multi-action by chunking the network logits output'''
    x = torch.cat(torch.split(x, self.state_dims, dim=1)).unsqueeze_(dim=1)
    pdparam = SARSA.calc_pdparam(self, x, evaluate=evaluate)
    return pdparam
def calc_pdparam(self, x, evaluate=True):
    '''Calculate pdparams for multi-action by chunking the network logits output'''
    pdparam = super(MultitaskDQN, self).calc_pdparam(x, evaluate=evaluate)
    pdparam = torch.cat(torch.split(pdparam, self.action_dims, dim=1))
    return pdparam
def forward(self, q, k, v, attn_mask=None):
    d_k, d_v = self.d_k, self.d_v
    n_head = self.n_head
    residual = q

    mb_size, len_q, q_hidden_size = q.size()
    mb_size, len_k, k_hidden_size = k.size()
    mb_size, len_v, v_hidden_size = v.size()

    # treat as a (n_head) size batch
    q_s = q.repeat(n_head, 1, 1).view(n_head, -1, q_hidden_size)  # n_head x (mb_size*len_q) x d_model
    k_s = k.repeat(n_head, 1, 1).view(n_head, -1, k_hidden_size)  # n_head x (mb_size*len_k) x d_model
    v_s = v.repeat(n_head, 1, 1).view(n_head, -1, v_hidden_size)  # n_head x (mb_size*len_v) x d_model

    # treat the result as a (n_head * mb_size) size batch
    q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)  # (n_head*mb_size) x len_q x d_k
    k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)  # (n_head*mb_size) x len_k x d_k
    v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)  # (n_head*mb_size) x len_v x d_v

    # perform attention, result size = (n_head * mb_size) x len_q x d_v
    outputs, attns = self.attention.forward(q_s, k_s, v_s,
                                            attn_mask=attn_mask.repeat(n_head, 1, 1))

    # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
    outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

    # project back to residual size
    outputs = self.proj.forward(outputs)
    outputs = self.dropout(outputs)

    return self.layer_norm(outputs + residual), attns
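# Sanity check of the split/cat head merge used above: after stacking heads
# along the batch dim, splitting by mb_size and concatenating on the last dim
# turns (n_head*mb, len, d_v) into (mb, len, n_head*d_v). All sizes below are
# invented for the demonstration.
import torch

n_head, mb_size, len_q, d_v = 4, 2, 5, 8
outputs = torch.randn(n_head * mb_size, len_q, d_v)
merged = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)
assert merged.shape == (mb_size, len_q, n_head * d_v)
# head h of example b lives at merged[b, :, h*d_v:(h+1)*d_v]
assert torch.equal(merged[0, :, :d_v], outputs[0])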
def forward(self, input_, hx):
    """
    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: A tuple (h_0, c_0), which contains the initial hidden
            and cell state, where the size of both states is
            (batch, hidden_size).

    Returns:
        h_1, c_1: Tensors containing the next hidden and cell state.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)
    bias_batch = (self.bias.unsqueeze(0)
                  .expand(batch_size, *self.bias.size()))
    wh = torch.mm(h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)
    bn_wh = self.bn_hh(wh)
    bn_wi = self.bn_ih(wi)
    f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
                             split_size_or_sections=self.hidden_size, dim=1)
    c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
    h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1))
    return h_1, c_1
def occlusion_sensitivity(
    model, images, ids, mean=None, patch=35, stride=1, n_batches=128
):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure A5 on page 17

    Originally proposed in:
    "Visualizing and Understanding Convolutional Networks"
    https://arxiv.org/abs/1311.2901
    """
    torch.set_grad_enabled(False)
    model.eval()
    mean = mean if mean else 0
    patch_H, patch_W = patch if isinstance(patch, Sequence) else (patch, patch)
    pad_H, pad_W = patch_H // 2, patch_W // 2

    # Padded image
    images = F.pad(images, (pad_W, pad_W, pad_H, pad_H), value=mean)
    B, _, H, W = images.shape
    new_H = (H - patch_H) // stride + 1
    new_W = (W - patch_W) // stride + 1

    # Prepare sampling grids
    anchors = []
    grid_h = 0
    while grid_h <= H - patch_H:
        grid_w = 0
        while grid_w <= W - patch_W:
            anchors.append((grid_h, grid_w))  # record before stepping so (grid_h, 0) is included
            grid_w += stride
        grid_h += stride

    # Baseline score without occlusion
    baseline = model(images).detach().gather(1, ids)

    # Compute per-pixel logits
    scoremaps = []
    for i in tqdm(range(0, len(anchors), n_batches), leave=False):
        batch_images = []
        batch_ids = []
        for grid_h, grid_w in anchors[i : i + n_batches]:
            images_ = images.clone()
            images_[..., grid_h : grid_h + patch_H, grid_w : grid_w + patch_W] = mean
            batch_images.append(images_)
            batch_ids.append(ids)
        batch_images = torch.cat(batch_images, dim=0)
        batch_ids = torch.cat(batch_ids, dim=0)
        scores = model(batch_images).detach().gather(1, batch_ids)
        scoremaps += list(torch.split(scores, B))

    diffmaps = torch.cat(scoremaps, dim=1) - baseline
    diffmaps = diffmaps.view(B, new_H, new_W)

    return diffmaps
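# Why the loop above splits scores by B: the occluded copies are stacked along
# the batch dim, so the model returns (n_positions*B, 1) scores, and
# torch.split(scores, B) recovers one (B, 1) block per occlusion position.
# Sizes here are invented.
import torch

B, n_positions = 2, 3
scores = torch.arange(B * n_positions, dtype=torch.float).view(-1, 1)
per_position = torch.split(scores, B)
assert len(per_position) == n_positions and per_position[0].shape == (B, 1)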
def forward(self, tensors):
    # tensors must all be the same shape, let's say (batch_size, timesteps, dim)
    assert self.num_tensors == len(tensors)
    normed_weights = torch.nn.functional.softmax(
        torch.cat([p for p in self.scalar_parameters]), dim=0)
    normed_weights = torch.split(normed_weights, split_size_or_sections=1)
    pieces = []
    for weight, tensor in zip(normed_weights, tensors):
        pieces.append(weight * tensor)
    return self.gamma * sum(pieces)
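# Tiny numeric check of the weighted-mix pattern above: softmax-normalized
# scalars are split back into single-element weights, one per input tensor.
# The parameter setup is an illustrative assumption.
import torch

scalar_parameters = [torch.nn.Parameter(torch.zeros(1)) for _ in range(3)]
tensors = [torch.ones(2, 4) * i for i in range(3)]
normed = torch.nn.functional.softmax(torch.cat(scalar_parameters), dim=0)
weights = torch.split(normed, 1)
mixed = sum(w * t for w, t in zip(weights, tensors))
# with equal (zero) logits every weight is 1/3, so the mix is the mean
assert torch.allclose(mixed, torch.ones(2, 4))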
def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :return: The matrix (size [N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    max_ymin = torch.max(ymin1.repeat(1, ymin2.size(0)),
                         torch.transpose(ymin2, 0, 1).repeat(ymin1.size(0), 1))  # PyTorch's bug
    min_ymax = torch.min(ymax1.repeat(1, ymax2.size(0)),
                         torch.transpose(ymax2, 0, 1).repeat(ymax1.size(0), 1))  # PyTorch's bug
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1.repeat(1, xmin2.size(0)),
                         torch.transpose(xmin2, 0, 1).repeat(xmin1.size(0), 1))  # PyTorch's bug
    min_xmax = torch.min(xmax1.repeat(1, xmax2.size(0)),
                         torch.transpose(xmax2, 0, 1).repeat(xmax1.size(0), 1))  # PyTorch's bug
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
def filter_shard_state(state, shard_size=None):
    for k, v in state.items():
        if shard_size is None:
            yield k, v

        if v is not None:
            v_split = []
            if isinstance(v, torch.Tensor):
                for v_chunk in torch.split(v, shard_size):
                    v_chunk = v_chunk.data.clone()
                    v_chunk.requires_grad = v.requires_grad
                    v_split.append(v_chunk)
            yield k, (v, v_split)
def dis_update(self, images_a, images_b, hyperparameters):
    self.dis.zero_grad()
    x_aa, x_ba, x_ab, x_bb, shared = self.gen(images_a, images_b)
    data_a = torch.cat((images_a, x_ba), 0)
    data_b = torch.cat((images_b, x_ab), 0)
    res_a, res_b = self.dis(data_a, data_b)
    # res_true_a, res_true_b = self.dis(images_a, images_b)
    # res_fake_a, res_fake_b = self.dis(x_ba, x_ab)
    for it, (this_a, this_b) in enumerate(zip(res_a, res_b)):  # itertools.izip in the original Python 2 code
        out_a = nn.functional.sigmoid(this_a)
        out_b = nn.functional.sigmoid(this_b)
        out_true_a, out_fake_a = torch.split(out_a, out_a.size(0) // 2, 0)
        out_true_b, out_fake_b = torch.split(out_b, out_b.size(0) // 2, 0)
        out_true_n = out_true_a.size(0)
        out_fake_n = out_fake_a.size(0)
        all1 = Variable(torch.ones((out_true_n)).cuda(self.gpu))
        all0 = Variable(torch.zeros((out_fake_n)).cuda(self.gpu))
        ad_true_loss_a = nn.functional.binary_cross_entropy(out_true_a, all1)
        ad_true_loss_b = nn.functional.binary_cross_entropy(out_true_b, all1)
        ad_fake_loss_a = nn.functional.binary_cross_entropy(out_fake_a, all0)
        ad_fake_loss_b = nn.functional.binary_cross_entropy(out_fake_b, all0)
        if it == 0:
            ad_loss_a = ad_true_loss_a + ad_fake_loss_a
            ad_loss_b = ad_true_loss_b + ad_fake_loss_b
        else:
            ad_loss_a += ad_true_loss_a + ad_fake_loss_a
            ad_loss_b += ad_true_loss_b + ad_fake_loss_b
        true_a_acc = _compute_true_acc(out_true_a)
        true_b_acc = _compute_true_acc(out_true_b)
        fake_a_acc = _compute_fake_acc(out_fake_a)
        fake_b_acc = _compute_fake_acc(out_fake_b)
        exec('self.dis_true_acc_%d = 0.5 * (true_a_acc + true_b_acc)' % it)
        exec('self.dis_fake_acc_%d = 0.5 * (fake_a_acc + fake_b_acc)' % it)
    loss = hyperparameters['gan_w'] * (ad_loss_a + ad_loss_b)
    loss.backward()
    self.dis_opt.step()
    self.dis_loss = loss.data.cpu().numpy()[0]
    return
def shards(state, shard_size, eval_only=False):
    """
    Args:
        state: A dictionary which corresponds to the output of
               *LossCompute._make_shard_state(). The values for
               those keys are Tensor-like or None.
        shard_size: The maximum size of the shards yielded by the model.
        eval_only: If True, only yield the state, nothing else.
                   Otherwise, yield shards.

    Yields:
        Each yielded shard is a dict.

    Side effect:
        After the last shard, this function does back-propagation.
    """
    if eval_only:
        yield filter_shard_state(state)
    else:
        # non_none: the subdict of the state dictionary where the values
        # are not None.
        non_none = dict(filter_shard_state(state, shard_size))

        # Now, the iteration:
        # state is a dictionary of sequences of tensor-like but we
        # want a sequence of dictionaries of tensors.
        # First, unzip the dictionary into a sequence of keys and a
        # sequence of tensor-like sequences.
        keys, values = zip(*((k, [v_chunk for v_chunk in v_split])
                             for k, (_, v_split) in non_none.items()))

        # Now, yield a dictionary for each shard. The keys are always
        # the same. values is a sequence of length #keys where each
        # element is a sequence of length #shards. We want to iterate
        # over the shards, not over the keys: therefore, the values need
        # to be re-zipped by shard and then each shard can be paired
        # with the keys.
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))

        # Assumed backprop'd
        variables = []
        for k, (v, v_split) in non_none.items():
            if isinstance(v, torch.Tensor) and state[k].requires_grad:
                variables.extend(zip(torch.split(state[k], shard_size),
                                     [v_chunk.grad for v_chunk in v_split]))
        inputs, grads = zip(*variables)
        torch.autograd.backward(inputs, grads)
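# How torch.split drives the sharding above: a flat shard_size yields
# ceil(N / shard_size) chunks along dim 0, the last one possibly smaller.
# Example values are arbitrary.
import torch

state_value = torch.arange(10).float()
shards_of_v = torch.split(state_value, 4)
assert [s.numel() for s in shards_of_v] == [4, 4, 2]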
def forward(self, input, hidden_state):
    hidden, c = hidden_state  # hidden and c are images with several channels
    combined = torch.cat((input, hidden), 1)  # concatenate in the channel dimension
    A = self.conv(combined)
    (ai, af, ao, ag) = torch.split(A, self.num_features, dim=1)  # it should return 4 tensors
    i = torch.sigmoid(ai)
    f = torch.sigmoid(af)
    o = torch.sigmoid(ao)
    g = torch.tanh(ag)

    next_c = f * c + i * g
    next_h = o * torch.tanh(next_c)
    return next_h, next_c
def forward(self, input_, c_input, hx):
    """
    Args:
        batch = 1
        input_: A (batch, input_size) tensor containing input features.
        c_input: A list with size c_num, each element is the input ct
            from a skip word (batch, hidden_size).
        hx: A tuple (h_0, c_0), which contains the initial hidden
            and cell state, where the size of both states is
            (batch, hidden_size).

    Returns:
        h_1, c_1: Tensors containing the next hidden and cell state.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)
    # assert(batch_size == 1)
    bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
    wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)
    i, o, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
    i = torch.sigmoid(i)
    g = torch.tanh(g)
    o = torch.sigmoid(o)
    c_num = len(c_input)
    if c_num == 0:
        f = 1 - i
        c_1 = f * c_0 + i * g
        h_1 = o * torch.tanh(c_1)
    else:
        c_input_var = torch.cat(c_input, 0)
        alpha_bias_batch = (self.alpha_bias.unsqueeze(0).expand(batch_size, *self.alpha_bias.size()))
        c_input_var = c_input_var.squeeze(1)  # (c_num, hidden_dim)
        alpha_wi = torch.addmm(self.alpha_bias, input_, self.alpha_weight_ih).expand(c_num, self.hidden_size)
        alpha_wh = torch.mm(c_input_var, self.alpha_weight_hh)
        alpha = torch.sigmoid(alpha_wi + alpha_wh)
        # alpha = i concat alpha
        alpha = torch.exp(torch.cat([i, alpha], 0))
        alpha_sum = alpha.sum(0)
        # alpha = softmax for each hidden element
        alpha = torch.div(alpha, alpha_sum)
        merge_i_c = torch.cat([g, c_input_var], 0)
        c_1 = merge_i_c * alpha
        c_1 = c_1.sum(0).unsqueeze(0)
        h_1 = o * torch.tanh(c_1)
    return h_1, c_1
def forward(self, input_tensor, cur_state):
    h_cur, c_cur = cur_state

    combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

    combined_conv = self.conv(combined)
    cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
    i = torch.sigmoid(cc_i)
    f = torch.sigmoid(cc_f)
    o = torch.sigmoid(cc_o)
    g = torch.tanh(cc_g)

    c_next = f * c_cur + i * g
    h_next = o * torch.tanh(c_next)

    return h_next, c_next
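# A runnable miniature of the ConvLSTM gate computation above: a single conv
# emits 4 * hidden_dim channels that torch.split slices into the i, f, o, g
# gates. The hyperparameters are placeholders, not the original module's.
import torch
import torch.nn as nn

hidden_dim, input_dim = 3, 2
conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim, kernel_size=3, padding=1)

x = torch.randn(1, input_dim, 8, 8)
h_cur = torch.zeros(1, hidden_dim, 8, 8)
c_cur = torch.zeros(1, hidden_dim, 8, 8)

combined_conv = conv(torch.cat([x, h_cur], dim=1))
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, hidden_dim, dim=1)
c_next = torch.sigmoid(cc_f) * c_cur + torch.sigmoid(cc_i) * torch.tanh(cc_g)
h_next = torch.sigmoid(cc_o) * torch.tanh(c_next)
assert h_next.shape == (1, hidden_dim, 8, 8)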
def node_forward(self, inputs, child_c, child_h):
    child_h_sum = torch.mean(child_h, dim=0, keepdim=True)

    iou = self.ioux(inputs) + self.iouh(child_h_sum)
    i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
    i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u)

    f = F.sigmoid(
        self.fh(child_h) +
        self.fx(inputs).repeat(len(child_h), 1)
    )
    fc = torch.mul(f, child_c)

    c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
    h = torch.mul(o, F.tanh(c))
    return c, h
def forward(self, tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
            mask: torch.Tensor = None) -> torch.Tensor:
    """
    Compute a weighted average of the ``tensors``.  The input tensors can be any shape
    with at least two dimensions, but must all be the same shape.

    When ``do_layer_norm=True``, the ``mask`` is required input.  If the ``tensors`` are
    dimensioned  ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
    ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
    ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.

    When ``do_layer_norm=False`` the ``mask`` is ignored.
    """
    if len(tensors) != self.mixture_size:
        raise ConfigurationError("{} tensors were passed, but the module was initialized to "
                                 "mix {} tensors.".format(len(tensors), self.mixture_size))

    def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
        tensor_masked = tensor * broadcast_mask
        mean = torch.sum(tensor_masked) / num_elements_not_masked
        variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
        return (tensor - mean) / torch.sqrt(variance + 1E-12)

    normed_weights = torch.nn.functional.softmax(torch.cat([parameter for parameter
                                                            in self.scalar_parameters]), dim=0)
    normed_weights = torch.split(normed_weights, split_size_or_sections=1)

    if not self.do_layer_norm:
        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * tensor)
        return self.gamma * sum(pieces)

    else:
        mask_float = mask.float()
        broadcast_mask = mask_float.unsqueeze(-1)
        input_dim = tensors[0].size(-1)
        num_elements_not_masked = torch.sum(mask_float) * input_dim

        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * _do_layer_norm(tensor,
                                                  broadcast_mask, num_elements_not_masked))
        return self.gamma * sum(pieces)
def forward(self, x, y, y_mask):
    """Input shapes:
        x = batch * len1 * h
        y = batch * len2 * h
        y_mask = batch * len2
    Output shapes:
        matched_seq = batch * len1 * h
    """
    # Project vectors
    x_s = x.repeat(self.n_head, 1, 1).view(self.n_head, -1, x.size(2))  # n_head * (batch x len1) * input_size
    y_s = y.repeat(self.n_head, 1, 1).view(self.n_head, -1, y.size(2))  # n_head * (batch x len2) * input_size
    x_s = torch.bmm(x_s, self.w)  # n_head * (batch x len1) * hidden_size
    y_s = torch.bmm(y_s, self.w)  # n_head * (batch x len2) * hidden_size
    x_s = x_s.view(-1, x.size(1), self.hidden_size)  # (n_head x batch) * len1 * hidden_size
    y_s = y_s.view(-1, y.size(1), self.hidden_size)  # (n_head x batch) * len2 * hidden_size
    x_s_proj = F.relu(x_s)
    y_s_proj = F.relu(y_s)

    # Compute scores
    scores = x_s_proj.bmm(y_s_proj.transpose(2, 1))  # (n_head x batch) * len1 * len2

    # Mask padding
    y_mask = y_mask.unsqueeze(1).repeat(self.n_head, x.size(1), 1)
    scores.data.masked_fill_(y_mask.data, -float('inf'))

    # Normalize with softmax
    alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=1)
    alpha = alpha_flat.view(-1, x.size(1), y.size(1))

    # Take weighted average
    matched_seq = alpha.bmm(y_s)  # (n_head x batch) * len1 * hidden_size
    matched_seq = torch.cat(torch.split(matched_seq, x.size(0), dim=0), dim=-1)  # batch x len1 x (n_head * hidden_size)
    return matched_seq
def test_elmo_bilm_can_handle_higher_dimensional_input_with_cache(self):
    sentences = [["This", "is", "a", "sentence"],
                 ["Here", "'s", "one"],
                 ["Another", "one"]]
    vocab, tensor = self.get_vocab_and_both_elmo_indexed_ids(sentences)
    words_to_cache = list(vocab.get_token_to_index_vocabulary("tokens").keys())
    elmo_bilm = Elmo(self.options_file, self.weight_file, 1, vocab_to_cache=words_to_cache)
    elmo_bilm.eval()
    individual_dim = elmo_bilm(tensor["character_ids"], tensor["tokens"])
    elmo_bilm = Elmo(self.options_file, self.weight_file, 1, vocab_to_cache=words_to_cache)
    elmo_bilm.eval()
    expanded_word_ids = torch.stack([tensor["tokens"] for _ in range(4)], dim=1)
    expanded_char_ids = torch.stack([tensor["character_ids"] for _ in range(4)], dim=1)
    expanded_result = elmo_bilm(expanded_char_ids, expanded_word_ids)
    split_result = [x.squeeze(1) for x in
                    torch.split(expanded_result["elmo_representations"][0], 1, dim=1)]
    for expanded in split_result:
        numpy.testing.assert_array_almost_equal(
            expanded.data.cpu().numpy(),
            individual_dim["elmo_representations"][0].data.cpu().numpy())
def forward(self, input_, hx):
    """
    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: A tuple (h_0, c_0), which contains the initial hidden
            and cell state, where the size of both states is
            (batch, hidden_size).

    Returns:
        c_1: Tensor containing the next cell state. This variant computes
            no output gate, so only the cell state is returned.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)
    bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
    wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)
    f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
    c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
    return c_1
def forward(self, x, char_x, lens):
    B, T = x.shape
    # get the padding mask
    mask = x.gt(0)
    # get word embedding vectors
    x = self.embed(x)
    # get character embedding vectors
    char_x = self.char_lstm(char_x[mask])
    char_x = pad_sequence(torch.split(char_x, lens.tolist()), True)

    # concatenate the word and character representations
    x = torch.cat((x, char_x), dim=-1)
    x = self.drop(x)

    x = pack_padded_sequence(x, lens, True)
    x, _ = self.word_lstm(x)
    x, _ = pad_packed_sequence(x, True)
    x = self.drop(x)

    return self.out(x)
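# The char_x regrouping above in isolation: a flat (total_words, d) tensor is
# split into per-sentence word counts and padded back into a batch. The
# numbers below are invented examples.
import torch
from torch.nn.utils.rnn import pad_sequence

total = torch.randn(7, 5)  # 7 words across the whole batch, d = 5
lens = [3, 4]              # words per sentence
batched = pad_sequence(torch.split(total, lens), True)
assert batched.shape == (2, 4, 5)  # padded to the longest sentence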
def forward(self, k, q):
    if len(q.shape) == 2:  # q_len missing
        q = torch.unsqueeze(q, dim=1)
    if len(k.shape) == 2:  # k_len missing
        k = torch.unsqueeze(k, dim=1)
    mb_size = k.shape[0]  # ?
    k_len = k.shape[1]
    q_len = q.shape[1]
    # k: (?, k_len, embed_dim,)
    # q: (?, q_len, embed_dim,)
    # kx: (n_head, ?*k_len, embed_dim) -> (n_head*?, k_len, hidden_dim)
    # qx: (n_head, ?*q_len, embed_dim) -> (n_head*?, q_len, hidden_dim)
    # score: (n_head*?, q_len, k_len,)
    # output: (?, q_len, out_dim,)
    kx = k.repeat(self.n_head, 1, 1).view(self.n_head, -1, self.embed_dim)  # (n_head, ?*k_len, embed_dim)
    qx = q.repeat(self.n_head, 1, 1).view(self.n_head, -1, self.embed_dim)  # (n_head, ?*q_len, embed_dim)
    kx = torch.bmm(kx, self.w_kx).view(-1, k_len, self.hidden_dim)  # (n_head*?, k_len, hidden_dim)
    qx = torch.bmm(qx, self.w_qx).view(-1, q_len, self.hidden_dim)  # (n_head*?, q_len, hidden_dim)
    if self.score_function == 'scaled_dot_product':
        kt = kx.permute(0, 2, 1)
        qkt = torch.bmm(qx, kt)
        score = torch.div(qkt, math.sqrt(self.hidden_dim))
    elif self.score_function == 'mlp':
        kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
        qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
        kq = torch.cat((kxx, qxx), dim=-1)  # (n_head*?, q_len, k_len, hidden_dim*2)
        score = F.tanh(torch.matmul(kq, self.weight))
    elif self.score_function == 'bi_linear':
        qw = torch.matmul(qx, self.weight)
        kt = kx.permute(0, 2, 1)
        score = torch.bmm(qw, kt)
    else:
        raise RuntimeError('invalid score_function')
    score = F.softmax(score, dim=-1)
    output = torch.bmm(score, kx)  # (n_head*?, q_len, hidden_dim)
    output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1)  # (?, q_len, n_head*hidden_dim)
    output = self.proj(output)  # (?, q_len, out_dim)
    output = self.dropout(output)
    return output
def forward(self, x):
    x_shape = x.size()  # (b, c, h, w)
    offset = self.offset_filter(x)  # (b, 2*c, h, w)
    offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1)  # (b, c, h, w)
    offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
    offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
    if not self.input_shape or self.input_shape != x_shape:
        self.input_shape = x_shape
        grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]),
                                     np.linspace(-1, 1, x_shape[2]))  # (h, w)
        grid_w = torch.Tensor(grid_w)
        grid_h = torch.Tensor(grid_h)
        if self.cuda:
            grid_w = grid_w.cuda()
            grid_h = grid_h.cuda()
        self.grid_w = nn.Parameter(grid_w)
        self.grid_h = nn.Parameter(grid_h)
    offset_w = offset_w + self.grid_w  # (b*c, h, w)
    offset_h = offset_h + self.grid_h  # (b*c, h, w)
    x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1)  # (b*c, 1, h, w)
    x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3))  # (b*c, h, w)
    x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))  # (b, c, h, w)
    x = self.regular_filter(x)
    return x
def forward(
    self,
    queries,
    keys,
    time_mask,
    attn_mask,
    time_matrix_K,
    time_matrix_V,
    abs_pos_K,
    abs_pos_V,
):
    """Forward function.

    Args:
        queries ([type]): [description]
        keys ([type]): [description]
        time_mask ([type]): [description]
        attn_mask ([type]): [description]
        time_matrix_K ([type]): [description]
        time_matrix_V ([type]): [description]
        abs_pos_K ([type]): [description]
        abs_pos_V ([type]): [description]

    Returns:
        [type]: [description]
    """
    Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)

    # head dim * batch dim for parallelization (h*N, T, C/h)
    Q_ = torch.cat(torch.split(Q, self.head_size, dim=2), dim=0)
    K_ = torch.cat(torch.split(K, self.head_size, dim=2), dim=0)
    V_ = torch.cat(torch.split(V, self.head_size, dim=2), dim=0)

    time_matrix_K_ = torch.cat(torch.split(time_matrix_K, self.head_size, dim=3), dim=0)
    time_matrix_V_ = torch.cat(torch.split(time_matrix_V, self.head_size, dim=3), dim=0)
    abs_pos_K_ = torch.cat(torch.split(abs_pos_K, self.head_size, dim=2), dim=0)
    abs_pos_V_ = torch.cat(torch.split(abs_pos_V, self.head_size, dim=2), dim=0)

    # batched channel-wise matmul to generate attention weights
    attn_weights = Q_.matmul(torch.transpose(K_, 1, 2))
    attn_weights += Q_.matmul(torch.transpose(abs_pos_K_, 1, 2))
    attn_weights += time_matrix_K_.matmul(Q_.unsqueeze(-1)).squeeze(-1)

    # seq length adaptive scaling
    attn_weights = attn_weights / (K_.shape[-1] ** 0.5)

    # key masking: -2^32 leads to leaking, inf leads to nan
    # (0 * inf = nan, then reduce_sum([nan, ...]) = nan)
    # time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
    time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)
    time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])
    attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
    paddings = torch.ones(attn_weights.shape) * (-(2 ** 32) + 1)  # -1e23 # float('-inf')
    paddings = paddings.to("cuda")
    attn_weights = torch.where(time_mask, paddings, attn_weights)  # True: pick padding
    attn_weights = torch.where(attn_mask, paddings, attn_weights)  # enforcing causality

    attn_weights = self.softmax(attn_weights)
    # the code below invalidates pytorch backward rules
    # attn_weights = torch.where(time_mask, paddings, attn_weights)  # weird query mask in tf impl
    # https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/4
    # attn_weights[attn_weights != attn_weights] = 0  # rm nan for the -inf-into-softmax case

    attn_weights = self.dropout(attn_weights)

    outputs = attn_weights.matmul(V_)
    outputs += attn_weights.matmul(abs_pos_V_)
    outputs += (
        attn_weights.unsqueeze(2)
        .matmul(time_matrix_V_)
        .reshape(outputs.shape)
        .squeeze(2)
    )

    # (num_head * N, T, C / num_head) -> (N, T, C)
    outputs = torch.cat(torch.split(outputs, Q.shape[0], dim=0), dim=2)  # div batch_size

    return outputs
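# Shape walk-through of the (N, T, C) -> (h*N, T, C/h) -> (N, T, C) round trip
# used in the attention above; head_size and the tensor sizes below are
# invented for the demonstration.
import torch

N, T, C, head_num = 2, 7, 12, 3
head_size = C // head_num
Q = torch.randn(N, T, C)
Q_ = torch.cat(torch.split(Q, head_size, dim=2), dim=0)  # (h*N, T, C/h)
assert Q_.shape == (head_num * N, T, head_size)
back = torch.cat(torch.split(Q_, N, dim=0), dim=2)       # (N, T, C) again
assert torch.equal(back, Q)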
def forward(self, x):
    split = torch.split(x, self.half, 1)
    out1 = self.IN(split[0].contiguous())
    out2 = self.BN(split[1].contiguous())
    out = torch.cat((out1, out2), 1)
    return out
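# The channel split above in miniature: the first self.half channels go
# through InstanceNorm, the rest through BatchNorm, and the halves are
# re-concatenated. Channel counts here are example values.
import torch
import torch.nn as nn

half = 4
IN = nn.InstanceNorm2d(half, affine=True)
BN = nn.BatchNorm2d(half)
x = torch.randn(2, 2 * half, 16, 16)
split = torch.split(x, half, 1)
out = torch.cat((IN(split[0].contiguous()), BN(split[1].contiguous())), 1)
assert out.shape == x.shape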
def _upward_downward(self, layer, direction, inputs, tree, idx):
    # check to see whether this node has been computed on this
    # layer in this direction, if so short circuit the rest of
    # this function and return that result
    if idx in self.hidden_state[layer][direction]:
        h_t = self.hidden_state[layer][direction][idx]
        c_t = self.cell_state[layer][direction][idx]
        return h_t, c_t

    x_t = self._construct_x_t(layer, inputs, idx, tree)
    oidx, (h_prev, c_prev) = self._construct_previous(layer, direction, inputs, tree, idx)

    if self.bias:
        Wih, Whh, bih, bhh = self._get_parameters(layer, direction)
        fcio_t_raw = torch.matmul(Whh, h_prev) +\
                     torch.matmul(Wih, x_t[:, None]) +\
                     bhh[:, None] + bih[:, None]
    else:
        Wih, Whh = self._get_parameters(layer, direction)
        fcio_t_raw = torch.matmul(Whh, h_prev) +\
                     torch.matmul(Wih, x_t[:, None])

    f_t_raw, c_hat_t_raw, i_t_raw, o_t_raw = torch.split(fcio_t_raw,
                                                         self.hidden_size, dim=0)

    f_t = F.sigmoid(f_t_raw)
    gated_children = torch.mul(f_t, c_prev)
    gated_children = torch.sum(gated_children, 1, keepdim=False)

    c_hat_t_raw = torch.sum(c_hat_t_raw, 1, keepdim=False)
    i_t_raw = torch.sum(i_t_raw, 1, keepdim=False)
    o_t_raw = torch.sum(o_t_raw, 1, keepdim=False)

    c_hat_t = self.__class__.nonlinearity(c_hat_t_raw)
    i_t = F.sigmoid(i_t_raw)
    o_t = F.sigmoid(o_t_raw)

    c_t = gated_children + torch.mul(i_t, c_hat_t)
    h_t = torch.mul(o_t, self.__class__.nonlinearity(c_t))

    if self.dropout:
        dropout = Dropout(p=self.dropout)
        h_t = dropout(h_t)
        c_t = dropout(c_t)

    self.hidden_state[layer][direction][idx] = h_t
    self.cell_state[layer][direction][idx] = c_t

    if direction == 'up' and self.bidirectional:
        self._upward_downward(layer, 'down', inputs, tree, idx)

    return h_t, c_t
def train(x, y):
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()
        for accumulation_index in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                                x[counter], y[counter], train_G=False,
                                split_D=config['split_D'])

            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):
        z_.sample_()
        y_.sample_()
        D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
        G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
        G_loss.backward()

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        print('using modified ortho reg in G')  # Debug print to indicate we're using ortho reg in G
        # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])
    G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {'G_loss': float(G_loss.item()),
           'D_loss_real': float(D_loss_real.item()),
           'D_loss_fake': float(D_loss_fake.item())}
    # Return G's loss and the components of D's loss.
    return out
def train_ctr_phase(self):

    self.max_epoch = -1
    self.max_val_score = 0.0
    for epoch in range(self.args.epochs):

        if epoch == 0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
            data_match_tensor, label_match_tensor = self.get_match_function(epoch)

        penalty_same_ctr = 0
        penalty_diff_ctr = 0
        penalty_same_hinge = 0
        penalty_diff_hinge = 0
        train_acc = 0.0
        train_size = 0

        perm = torch.randperm(data_match_tensor.size(0))
        data_match_tensor_split = torch.split(data_match_tensor[perm],
                                              self.args.batch_size, dim=0)
        label_match_tensor_split = torch.split(label_match_tensor[perm],
                                               self.args.batch_size, dim=0)
        print('Split Matched Data: ', len(data_match_tensor_split),
              data_match_tensor_split[0].shape, len(label_match_tensor_split))

        # Batch iteration over single epoch
        for batch_idx, (x_e, y_e, d_e, idx_e) in enumerate(self.train_dataset):
            # print('Batch Idx: ', batch_idx)

            self.opt.zero_grad()
            loss_e = torch.tensor(0.0).to(self.cuda)

            x_e = x_e.to(self.cuda)
            y_e = torch.argmax(y_e, dim=1).to(self.cuda)
            d_e = torch.argmax(d_e, dim=1).numpy()

            same_ctr_loss = torch.tensor(0.0).to(self.cuda)
            diff_ctr_loss = torch.tensor(0.0).to(self.cuda)
            same_hinge_loss = torch.tensor(0.0).to(self.cuda)
            diff_hinge_loss = torch.tensor(0.0).to(self.cuda)

            if epoch > self.args.penalty_s:
                # To cover the varying size of the last batch for
                # data_match_tensor_split, label_match_tensor_split
                total_batch_size = len(data_match_tensor_split)
                if batch_idx >= total_batch_size:
                    break
                curr_batch_size = data_match_tensor_split[batch_idx].shape[0]

                # data_match = data_match_tensor[idx].to(cuda)
                data_match = data_match_tensor_split[batch_idx].to(self.cuda)
                data_match = data_match.view(data_match.shape[0] * data_match.shape[1],
                                             data_match.shape[2], data_match.shape[3],
                                             data_match.shape[4])
                feat_match = self.phi(data_match)

                # label_match = label_match_tensor[idx].to(self.cuda)
                label_match = label_match_tensor_split[batch_idx].to(self.cuda)
                label_match = label_match.view(label_match.shape[0] * label_match.shape[1])

                # Creating tensor of shape (domain size, total domains, feat size)
                if len(feat_match.shape) == 4:
                    feat_match = feat_match.view(curr_batch_size, len(self.train_domains),
                                                 feat_match.shape[1] * feat_match.shape[2] * feat_match.shape[3])
                else:
                    feat_match = feat_match.view(curr_batch_size, len(self.train_domains),
                                                 feat_match.shape[1])

                label_match = label_match.view(curr_batch_size, len(self.train_domains))

                data_match = data_match.view(curr_batch_size, len(self.train_domains),
                                             data_match.shape[1], data_match.shape[2],
                                             data_match.shape[3])

                # Contrastive Loss
                same_neg_counter = 1
                diff_neg_counter = 1
                for y_c in range(self.args.out_classes):

                    pos_indices = label_match[:, 0] == y_c
                    neg_indices = label_match[:, 0] != y_c
                    pos_feat_match = feat_match[pos_indices]
                    neg_feat_match = feat_match[neg_indices]

                    # if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
                    #     print('Weird! Positive Matches are more than the negative matches?',
                    #           pos_feat_match.shape[0], neg_feat_match.shape[0])

                    # If no instances of label y_c in the current batch then continue
                    if pos_feat_match.shape[0] == 0 or neg_feat_match.shape[0] == 0:
                        continue

                    # Iterating over anchors from different domains
                    for d_i in range(pos_feat_match.shape[1]):
                        if torch.sum(torch.isnan(neg_feat_match)):
                            print('Non Reshaped X2 is Nan')
                            sys.exit()

                        diff_neg_feat_match = neg_feat_match.view(
                            neg_feat_match.shape[0] * neg_feat_match.shape[1],
                            neg_feat_match.shape[2])

                        if torch.sum(torch.isnan(diff_neg_feat_match)):
                            print('Reshaped X2 is Nan')
                            sys.exit()

                        neg_dist = embedding_dist(pos_feat_match[:, d_i, :],
                                                  diff_neg_feat_match[:, :],
                                                  self.args.pos_metric,
                                                  self.args.tau, xent=True)
                        if torch.sum(torch.isnan(neg_dist)):
                            print('Neg Dist Nan')
                            sys.exit()

                        # Iterating pos dist for current anchor
                        for d_j in range(pos_feat_match.shape[1]):
                            if d_i != d_j:
                                pos_dist = 1.0 - embedding_dist(pos_feat_match[:, d_i, :],
                                                                pos_feat_match[:, d_j, :],
                                                                self.args.pos_metric)
                                pos_dist = pos_dist / self.args.tau
                                if torch.sum(torch.isnan(neg_dist)):
                                    print('Pos Dist Nan')
                                    sys.exit()

                                if torch.sum(torch.isnan(torch.log(torch.exp(pos_dist) + neg_dist))):
                                    print('Xent Nan')
                                    sys.exit()

                                # print('Pos Dist', pos_dist)
                                # print('Log Dist ', torch.log(torch.exp(pos_dist) + neg_dist))
                                diff_hinge_loss += -1 * torch.sum(
                                    pos_dist - torch.log(torch.exp(pos_dist) + neg_dist))
                                diff_ctr_loss += torch.sum(neg_dist)
                                diff_neg_counter += pos_dist.shape[0]

                same_ctr_loss = same_ctr_loss / same_neg_counter
                diff_ctr_loss = diff_ctr_loss / diff_neg_counter
                same_hinge_loss = same_hinge_loss / same_neg_counter
                diff_hinge_loss = diff_hinge_loss / diff_neg_counter

                penalty_same_ctr += float(same_ctr_loss)
                penalty_diff_ctr += float(diff_ctr_loss)
                penalty_same_hinge += float(same_hinge_loss)
                penalty_diff_hinge += float(diff_hinge_loss)

                loss_e += ((epoch - self.args.penalty_s) /
                           (self.args.epochs - self.args.penalty_s)) * diff_hinge_loss

            loss_e.backward(retain_graph=False)
            self.opt.step()

            del same_ctr_loss
            del diff_ctr_loss
            del same_hinge_loss
            del diff_hinge_loss
            torch.cuda.empty_cache()

        print('Train Loss Ctr : ', penalty_same_ctr, penalty_diff_ctr,
              penalty_same_hinge, penalty_diff_hinge)
        print('Done Training for epoch: ', epoch)

        if (epoch + 1) % 10 == 0:
            from evaluation.match_eval import MatchEval
            test_method = MatchEval(self.args, self.train_dataset, self.val_dataset,
                                    self.test_dataset, self.base_res_dir,
                                    self.run, self.cuda)
            # Compute test metrics: Mean Rank
            test_method.phi = self.phi
            test_method.get_metric_eval()

            # Save the model's weights post training
            if test_method.metric_score['TopK Perfect Match Score'] > self.max_val_score:
                self.max_val_score = test_method.metric_score['TopK Perfect Match Score']
                self.max_epoch = epoch
                self.save_model_ctr_phase(epoch)

            print('Current Best Epoch: ', self.max_epoch,
                  ' with TopK Overlap: ', self.max_val_score)
def forward(self, x):
    size1 = x.size(1) // 2  # integer division; torch.split needs an int chunk size
    x = torch.split(x, size1, 1)
    x = torch.max(x[0], x[1])
    return x
def expected_calls(self, data, model, mode, callback):
    if not hasattr(model, callback.attr_name):
        return []
    if callback.split_channels == (2, 1) or callback.split_channels == ([2, 1], None):
        pytest.skip("incompatible test")

    data1, data2 = data
    data1 = prepare_image(data1)
    data2 = prepare_image(data2)
    B, _, _, _ = data1.shape

    img1 = [data1]
    img2 = [data2]
    name1 = [f"{mode}/{callback.name}"]
    name2 = [f"{mode}/{callback.name}"]
    img = [img1, img2]
    name = [name1, name2]

    # channel splitting
    for pos in range(2):
        if callback.split_channels[pos]:
            img[pos] = []
            name[pos] = []
            img_new, name_new = [], []
            splits = torch.split(data[pos], callback.split_channels[pos], dim=-3)
            for i, s in enumerate(splits):
                if isinstance(callback.name, str):
                    n = f"{mode}/{callback.name}_{i}"
                else:
                    n = f"{mode}/{callback.name[i]}"
                name_new.append(n)
                img_new.append(s)
            img[pos] = img_new
            name[pos] = name_new

    if len(img[0]) != len(img[1]):
        if len(img[0]) == 1:
            img[0] = img[0] * len(img[1])
        elif len(img[1]) == 1:
            img[1] = img[1] * len(img[0])
        else:
            raise RuntimeError()

    for pos in range(2):
        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            needs_resize = [i.shape[-2] > H_max or i.shape[-1] > W_max for i in img[pos]]
            img[pos] = [
                F.interpolate(i, target, mode=resize_mode) if resize else i
                for i, resize in zip(img[pos], needs_resize)
            ]

    for pos in range(2):
        if (colormap := callback.colormap[pos]):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img[pos])
            img[pos] = [
                apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i
                for cmap, i in zip(colormap, img[pos])
            ]
class TestVisualizeCallback:
    callback_cls = VisualizeCallback

    # training, validation, or testing mode
    @pytest.fixture(params=["train", "val", "test"])
    def mode(self, request):
        return request.param

    @pytest.fixture
    def data_shape(self):
        return 2, 3, 32, 32

    @pytest.fixture
    def data(self, data_shape):
        data = create_image(*data_shape)
        return data

    @pytest.fixture(autouse=True)
    def model(self, request, mocker, callback, data, mode, trainer):
        if hasattr(request, "param"):
            step = request.param.pop("step", 10)
            epoch = request.param.pop("epoch", 1)
        else:
            step = 10
            epoch = 1
        model = mocker.MagicMock(name="module")
        model.current_epoch = epoch
        model.global_step = step
        if callback.attr_name is not None:
            setattr(model, callback.attr_name, data)

        if mode == "train":
            attr = "on_train_batch_end"
        elif mode == "val":
            attr = "on_validation_batch_end"
        elif mode == "test":
            attr = "on_test_batch_end"
        else:
            raise ValueError(f"{mode}")
        callback.trigger = lambda: getattr(callback, attr)(trainer, model)
        return model

    @pytest.fixture
    def callback(self, request):
        cls = self.callback_cls
        init_signature = inspect.signature(cls)
        defaults = {
            k: v.default
            for k, v in init_signature.parameters.items()
            if v.default is not inspect.Parameter.empty
        }
        if hasattr(request, "param"):
            name = request.param.get("name", "image")
            defaults.update(request.param)
        else:
            name = "image"
        defaults["name"] = name
        callback = cls(**defaults)
        return callback

    @pytest.fixture
    def logger_func(self, model):
        return model.logger.experiment.add_images

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []
        data = prepare_image(data)
        B, _, H, W = data.shape
        img = [data]
        name = [f"{mode}/{callback.name}"]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                if isinstance(callback.name, str):
                    n = f"{mode}/{callback.name}_{i}"
                else:
                    n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            scale_factor = []
            for i in img:
                H, W = i.shape[-2:]
                height_ratio, width_ratio = H / H_max, W / W_max
                s = 1 / max(height_ratio, width_ratio)
                scale_factor.append(s)
            img = [
                F.interpolate(i, scale_factor=s, mode=resize_mode) if s < 1 else i
                for i, s in zip(img, scale_factor)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i
                for cmap, i in zip(colormap, img)
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        if callback.as_uint8:
            img = [to_8bit(i, same_on_batch=not callback.per_img_norm) for i in img]

        step = [model.current_epoch if callback.epoch_counter else model.global_step] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    # encoder_padding_mask = encoder_out['encoder_padding_mask']
    encoder_out = encoder_out['encoder_out']

    bsz, seqlen, _ = prev_output_tokens.size()

    # get outputs from encoder
    encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
    # srclen = encoder_outs.size(0)

    x = prev_output_tokens
    # x = F.dropout(x, p=self.dropout_in, training=self.training)

    # B x T x C -> T x B x C, e.g. 10, 32, 16
    x = x.transpose(0, 1)

    # for_logging = ((x*self.all_stds)+self.all_means).cpu().detach().numpy()
    # # from fairseq import pdb; pdb.set_trace()
    # wandb.log(
    #     {'mean_input_velocities': for_logging[:,:,1::self.num_var_per_segment].mean(),
    #      'mean_input_velocities_1': for_logging[:,:,1].mean(),
    #      'mean_input_velocities_4': for_logging[:,:,-3].mean(),
    #      'mean_input_densities': for_logging[:,:,0::self.num_var_per_segment].mean(),
    #      'mean_input_density_1': for_logging[:,:,0].mean()
    #      # 'mean_onramp_flows': for_logging[:,:,::self.num_var_per_segment].mean(),
    #      # 'mean_offramp_flows': for_logging[:,:,::self.num_var_per_segment].mean()
    #      }
    # )

    if self.encoder_hidden_proj != None:
        prev_hiddens = self.encoder_hidden_proj(encoder_hiddens[0, :, :])
        prev_cells = self.encoder_cell_proj(encoder_cells[0, :, :])
    else:
        prev_hiddens = encoder_hiddens[0, :, :]
        prev_cells = encoder_cells[0, :, :]

    # input_feed = torch.sigmoid(self.encoder_hidden_to_input_feed_proj(prev_hiddens))
    if self.encoder_hidden_to_input_feed_proj != None:
        if self.input_feed_activation != None:
            input_feed = self.input_feed_activation(
                self.encoder_hidden_to_input_feed_proj(encoder_hiddens[0, :, :]))
        else:
            if self.extra_hidden_layer:
                extra_hidden = torch.relu(
                    self.encoder_hidden_to_decoder_input_feed_hidden_layer(
                        encoder_hiddens[0, :, :]))
            else:
                extra_hidden = encoder_hiddens[0, :, :]
            input_feed = self.encoder_hidden_to_input_feed_proj(extra_hidden)
            # input_feed = self.encoder_hidden_to_input_feed_proj(encoder_hiddens[0,:,:])
    else:
        if self.input_feed_activation != None:
            input_feed = self.input_feed_activation(encoder_hiddens[0, :, :])
        else:
            input_feed = encoder_hiddens[0, :, :]

    self.first_input_feed = input_feed

    outs = []
    common_params_list = []
    segment_params_list = []
    flow_res_list = []
    for j in range(seqlen):
        # input_to_rnn = torch.cat((x[j, :, :], input_feed), dim=1)
        # hidden, cell = self.rnn(input_to_rnn, (prev_hiddens, prev_cells))
        input_x = ((x[j, :, :] * self.input_stds) + self.input_means)  # + torch.Tensor([0.5]).float()
        # input_x = F.dropout(input_x, p=self.dropout_in, training=self.training)

        # ['Seg00_q', 'Seg00_speed', 'Seg04_q', 'Seg04_speed', 'Seg04_r', 'Seg02_s']
        # T x (B x C)
        q0_i, v0_i, q4_i, v4_i, r4_i, s2_i = torch.unbind(input_x, dim=1)

        unnormed_input_feed = (input_feed * self.all_stds) + self.all_means
        rho1, v1, r1, s1, rho2, v2, r2, s2, rho3, v3, r3, s3, rho4, v4, r4, s4 = torch.unbind(
            unnormed_input_feed, dim=1)

        # q4 = rho4*v4*3.0
        # q4 = q4_i if q4_i > 0 else q4
        rho4 = ((q4_i / (v4_i * 3.0)) * ((q4_i / (v4_i * 3.0)) > 0).float()
                + rho4 * ((q4_i / (v4_i * 3.0)) <= 0).float())  # if (q4_i/(v4_i*3.0)) > 0 else rho4
        v4 = v4_i * ((v4_i > 0).float()) + v4 * ((v4_i <= 0).float())  # if v4_i > 0 else v4
        r4 = r4_i * ((r4_i > 0).float()) + r4 * ((r4_i <= 0).float())  # if r4_i > 0 else r4
        s2 = s2_i * ((s2_i > 0).float()) + s2 * ((s2_i <= 0).float())  # if s2_i > 0 else s2

        real_size_input = torch.stack([rho1, v1, r1, s1, rho2, v2, r2, s2,
                                       rho3, v3, r3, s3, rho4, v4, r4, s4], dim=1)
        # real_size_input = (blended_input * self.all_stds) + self.all_means
        # input_feed = input_feed  # + torch.Tensor([0.5]).float()
        # input_mask = (input_x*self.max_vals) > 0.0
        # input_mask = ((input_x*self.all_stds)+self.all_means) > 0.0
        # input_mask = input_mask*0.0
        # blended_input = (input_x*input_mask.float()) + ((1-input_mask.float())*(input_feed))
        # input_feed_mask = ((input_feed*self.all_stds)+self.all_means) > 0.0
        # blended_input = input_feed * input_feed_mask.float()
        # hidden, cell = self.rnn(blended_input, (prev_hiddens, prev_cells))
        hidden, cell = self.rnn(x[j, :, :], (prev_hiddens, prev_cells))
        prev_hiddens = hidden  # for next loop
        prev_cells = cell

        ntf_params = self.ntf_projection(hidden)

        # NTF
        common_params, segment_params = torch.split(
            ntf_params,
            [self.num_common_params, self.total_segment_specific_params],
            dim=1)

        if self.common_param_activation != None:
            common_params = self.common_param_activation(common_params)
        common_params = (self.common_param_multipliers * common_params) + self.common_param_additions
        v0, q0, rhoNp1, vf, a_var, rhocr = torch.unbind(common_params, dim=1)  # , g_var
        g_var = torch.Tensor([[1.0]])

        v0 = v0_i * ((v0_i > 0).float()) + v0 * ((v0_i <= 0).float())  # if v0_i > 0 else v0
        q0 = q0_i * ((q0_i > 0).float()) + q0 * ((q0_i <= 0).float())  # if q0_i > 0 else q0

        # vf = vf.detach()  # * 0.0 + 120.0
        # a_var = a_var.detach()  # * 0.0 + 1.4
        # rhocr = rhocr.detach()  # * 0.0 + 30.
        # g_var = g_var.detach()  # * 0.0 + 1.0

        if self.segment_param_activation != None:
            segment_params = self.segment_param_activation(segment_params)
        else:
            segment_params = segment_params
        segment_params = segment_params.view(
            (-1, self.num_segment_specific_params, self.num_segments))
        segment_params = segment_params * self.segment_param_multipliers + self.segment_param_additions
        future_r, future_s = torch.unbind(segment_params, dim=1)

        # real_size_input = blended_input * self.max_vals
        # real_size_input = real_size_input * input_feed_mask.float()

        model_steps = []
        for _ in range(self.num_ntf_steps):
            one_ntf_output, flow_res = self.ntf_module(
                x=real_size_input, v0=v0, q0=q0, rhoNp1=rhoNp1, vf=vf,
                a_var=a_var, rhocr=rhocr, g_var=g_var,
                future_r=future_r, future_s=future_s)
            real_size_input = one_ntf_output
            flow_res_list.append(flow_res)
            model_steps.append(one_ntf_output)

        # mean_ntf_output = torch.stack(model_steps, dim=0).mean(dim=0)
        mean_ntf_output = real_size_input
        # scaled_output = mean_ntf_output / (self.max_vals + 1e-6)
        normed_output = (mean_ntf_output - self.all_means) / (self.all_stds)
        common_params_list.append(common_params)
        segment_params_list.append(segment_params)
        # outs.append(scaled_output)
        outs.append(normed_output)
        input_feed = normed_output  # - torch.Tensor([0.5]).float()

    # collect outputs across time steps
    # dim=1 to go from T x B x C -> B x T x C
    self.returned_out = torch.stack(outs, dim=1)
    self.all_common_params = torch.stack(common_params_list, dim=1)
    self.all_segment_params = torch.stack(segment_params_list, dim=1)

    v0_a, q0_a, rhoNp1_a, vf_a, a_var_a, rhocr_a = torch.unbind(self.all_common_params, dim=2)
    q0_a = (q0_a - 3000.) / 2000.
    v0_a = (v0_a - 90.) / 20.

    rho1_a, v1_a, r1_a, s1_a, rho2_a, v2_a, r2_a, s2_a, rho3_a, v3_a, r3_a, s3_a, rho4_a, v4_a, r4_a, s4_a = torch.unbind(
        self.returned_out, dim=2)
    # q4 = rho4 * v4 * 3.0
    q4_a = ((rho4_a * 15) + 15) * ((v4_a * 20) + 90) * 3.0  # 3 lanes, lambda 4
    q4_a = (q4_a - 3000.) / 2000.
    # v4 = v4
    # r4 =
    # s2 =
    new_out = torch.stack([q0_a, v0_a, q4_a, v4_a, r4_a, s2_a], dim=2)

    self.mean_flow_res = torch.stack(flow_res_list, dim=2).sum(axis=1).abs().mean(axis=1)

    # return returned_out, self.all_common_params, self.all_segment_params
    return new_out, self.all_common_params, self.all_segment_params
def g(self, t, y):
    # Diagonal diffusion.
    y = torch.split(y, split_size_or_sections=1, dim=1)
    out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)]
    return torch.cat(out, dim=1)
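# The per-coordinate split behind the diagonal diffusion above: state dim d is
# cut into d single-column tensors, each fed to its own small net, and the
# results re-concatenated. The g_nets construction is an illustrative
# assumption.
import torch
import torch.nn as nn

d = 3
g_nets = nn.ModuleList([nn.Linear(1, 1) for _ in range(d)])
y = torch.randn(5, d)
cols = torch.split(y, split_size_or_sections=1, dim=1)
out = torch.cat([g_i(y_i) for g_i, y_i in zip(g_nets, cols)], dim=1)
assert out.shape == (5, d)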
def forward(self, input):
    embedding = self.main(input).view(-1, 64 * 8 * 2 * 4)
    embedding = self.flatten(embedding)
    ha, hg = torch.split(embedding, [opt.ha_dim, opt.hg_dim], dim=1)
    return ha, hg, embedding
def forward(self, x):
    out = self.conv1(x)
    # split the input features among the parallel modules
    outs = torch.split(out, self.core_list, dim=1)
    outs = [
        self.trans0_1(self.module0_1(outs[0].contiguous())),
        self.trans0_2(self.module0_2(outs[1].contiguous())),
        self.trans0_3(self.module0_3(outs[2].contiguous()))
    ]
    out_temp = []
    for ind, i in enumerate(self.graph[0]):
        for index, k in enumerate(i):
            if index == 0:
                out_temp.append(outs[k])
            else:
                out_temp[ind] = torch.cat((out_temp[ind], outs[k]), dim=1)
    outs = out_temp

    outs = [
        self.trans1_1(self.module1_1(outs[0].contiguous())),
        self.trans1_2(self.module1_2(outs[1].contiguous())),
        self.trans1_3(self.module1_3(outs[2].contiguous()))
    ]
    out_temp = []
    for ind, i in enumerate(self.graph[1]):
        for index, k in enumerate(i):
            if index == 0:
                out_temp.append(outs[k])
            else:
                out_temp[ind] = torch.cat((out_temp[ind], outs[k]), dim=1)
    outs = out_temp

    outs = [
        self.trans2_1(self.module2_1(outs[0].contiguous())),
        self.trans2_2(self.module2_2(outs[1].contiguous())),
        self.trans2_3(self.module2_3(outs[2].contiguous()))
    ]
    out_temp = []
    for ind, i in enumerate(self.graph[2]):
        for index, k in enumerate(i):
            if index == 0:
                out_temp.append(outs[k])
            else:
                out_temp[ind] = torch.cat((out_temp[ind], outs[k]), dim=1)
    outs = out_temp

    outs = [
        self.trans3_1(self.module3_1(outs[0].contiguous())),
        self.trans3_2(self.module3_2(outs[1].contiguous())),
        self.trans3_3(self.module3_3(outs[2].contiguous()))
    ]
    out = torch.cat((outs[0], outs[1], outs[2]), dim=1)
    out = F.avg_pool2d(F.relu(self.bn(out.contiguous())), 2)
    out = out.view(out.size(0), -1)
    # out = self.linear(out.contiguous())
    out = self.Linear_01(out.contiguous())
    return out
def mean_axis(self, xs, axis):
    y = list(map(lambda x: torch.mean(x, 0), torch.split(xs, axis)))
    return torch.stack(y)
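# mean_axis above relies on torch.split accepting a list of group sizes: each
# group is reduced to its mean along dim 0 and the results re-stacked. The
# sizes below are example values.
import torch

xs = torch.arange(6.).view(6, 1)
axis = [2, 3, 1]  # group sizes along dim 0
y = torch.stack([torch.mean(x, 0) for x in torch.split(xs, axis)])
assert torch.allclose(y.squeeze(1), torch.tensor([0.5, 3.0, 5.0]))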
def __call__(
    self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Inference

    Args:
        batch: Input speech data and corresponding lengths
    Returns:
        text, token, token_int, hyp
    """
    assert check_argument_types()

    if isinstance(batch["speech"], np.ndarray):
        batch["speech"] = torch.tensor(batch["speech"])
    if isinstance(batch["speech_lengths"], np.ndarray):
        batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

    # a. To device
    batch = to_device(batch, device=self.device)

    # b. Forward Encoder
    # enc: [N, T, C]
    enc, encoder_out_lens = self.asr_model.encode(**batch)

    # logp_encoder_output: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2
    )

    # It may be useful to tune blank_bias.
    # The valid range of blank_bias is [-inf, 0]
    logp_encoder_output[:, :, 0] += self.blank_bias

    batch_size = encoder_out_lens.size(0)
    sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
    start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
    num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
    supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)
    supervision_segments = supervision_segments.to(torch.int32)

    # An introduction to DenseFsaVec:
    # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
    # It can be viewed as an fsa-type logp_encoder_output,
    # whose weights on the arcs are initialized with logp_encoder_output.
    # The goal of converting tensor-type to fsa-type is to use
    # fsa-related functions in k2, e.g. k2.intersect_dense_pruned below.
    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

    # The term "intersect" is similar to "compose" in k2.
    # The difference is:
    # for "compose" functions, the composition involves
    # matching the output label of a.fsa and the input label of b.fsa,
    # while for "intersect" functions, the composition involves
    # matching the input label of a.fsa and the input label of b.fsa.
    # Actually, in compose functions, b.fsa is inverted and then
    # a.fsa and inv_b.fsa are intersected together.
    # For the difference between compose and intersect:
    # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
    # For the definition of k2.intersect_dense_pruned:
    # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
    lattices = k2.intersect_dense_pruned(
        self.decode_graph,
        dense_fsa_vec,
        self.search_beam_size,
        self.output_beam_size,
        self.min_active_states,
        self.max_active_states,
    )

    # lattices.scores is the sum of decode_graph.scores (a.k.a. lm weight) and
    # dense_fsa_vec.scores (a.k.a. am weight) on related arcs.
    # For the ctc decoding graph, lattices.scores only stores the am weight
    # since the decode_graph only defines the ctc topology and
    # has no lm weight on its arcs.
    # While for 3-gram decoding, whose graph is converted from language models,
    # lattices.scores contains both am weights and lm weights.
    #
    # It may be useful to tune lattices.scores.
    # The valid range of lattice_weight is [0, inf).
    # The lattice_weight will affect the search of k2.random_paths
    lattices.scores *= self.lattice_weight

    results = []
    if self.use_nbest_rescoring:
        (
            am_scores,
            lm_scores,
            token_ids,
            new2old,
            path_to_seq_map,
            seq_to_path_splits,
        ) = nbest_am_lm_scores(
            lattices, self.num_paths, self.device, self.nbest_batch_size
        )

        ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
        max_token_length = max(ys_pad_lens)
        ys_pad_list = []
        for hyp in token_ids:
            ys_pad_list.append(
                torch.cat(
                    [
                        torch.tensor(hyp, dtype=torch.long),
                        torch.tensor(
                            [self.asr_model.ignore_id]
                            * (max_token_length.item() - len(hyp)),
                            dtype=torch.long,
                        ),
                    ]
                )
            )

        ys_pad = (
            torch.stack(ys_pad_list).to(torch.long).to(self.device)
        )  # [batch, max_token_length]

        encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
            self.device
        )  # [batch, T, dim]
        encoder_out_lens = encoder_out_lens.index_select(
            0, path_to_seq_map.to(torch.long)
        ).to(self.device)  # [batch]

        decoder_scores = -self.asr_model.batchify_nll(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
        )

        # padded_value for nnlm is 0
        ys_pad[ys_pad == self.asr_model.ignore_id] = 0
        nnlm_nll, x_lengths = self.lm.batchify_nll(
            ys_pad, ys_pad_lens, self.nll_batch_size
        )
        nnlm_scores = -nnlm_nll.sum(dim=1)

        batch_tot_scores = (
            self.am_weight * am_scores
            + self.decoder_weight * decoder_scores
            + self.nnlm_weight * nnlm_scores
        )
        split_size = indices_to_split_size(
            seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
        )
        batch_tot_scores = torch.split(
            batch_tot_scores,
            split_size,
        )

        hyps = []
        scores = []
        processed_seqs = 0
        for tot_scores in batch_tot_scores:
            if tot_scores.nelement() == 0:
                # the last element by torch.tensor_split may be empty
                # e.g.
                # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                break
            best_seq_idx = processed_seqs + torch.argmax(tot_scores)

            assert best_seq_idx < len(token_ids)
            best_token_seqs = token_ids[best_seq_idx]
            processed_seqs += tot_scores.nelement()
            hyps.append(best_token_seqs)
            scores.append(tot_scores.max().item())

        assert len(hyps) == len(split_size)
    else:
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        scores = best_paths.get_tot_scores(
            use_double_scores=True, log_semiring=False
        ).tolist()
        hyps = get_texts(best_paths)

    assert len(scores) == len(hyps)

    for token_int, score in zip(hyps, scores):
        # For decoding methods nbest_rescoring and ctc_decoding,
        # hyps stores token_index, which is lattice.labels.

        # convert token_id to text with self.tokenizer
        token = self.converter.ids2tokens(token_int)
        assert self.tokenizer is not None
        text = self.tokenizer.tokens2text(token)
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results
def forward(self, x): x = self.filter(x) out = torch.split(x, self.out_channels, 1) # split channels in CNN or split D in FC return torch.max(out[0], out[1])
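# A minimal sketch of the max-feature-map ("maxout over two halves") idea
# used in the forward above: torch.split with an int chunk size cuts the
# channel dimension into equal pieces, and the elementwise max halves the
# number of channels. The shapes below are made-up assumptions.
import torch

x = torch.randn(8, 64, 32, 32)      # N, 2*C, H, W (e.g. output of a conv filter)
out = torch.split(x, 32, dim=1)     # two chunks of 32 channels each
y = torch.max(out[0], out[1])       # N, C, H, W
print(y.shape)                      # torch.Size([8, 32, 32, 32])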
def train_erm_phase(self):

    for run_erm in range(self.args.n_runs_matchdg_erm):
        self.max_epoch = -1
        self.max_val_acc = 0.0
        for epoch in range(self.args.epochs):

            if epoch == 0:
                data_match_tensor, label_match_tensor = self.init_erm_phase()
            elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                data_match_tensor, label_match_tensor = self.get_match_function(epoch)

            penalty_erm = 0
            penalty_erm_extra = 0
            penalty_ws = 0
            train_acc = 0.0
            train_size = 0

            perm = torch.randperm(data_match_tensor.size(0))
            data_match_tensor_split = torch.split(
                data_match_tensor[perm], self.args.batch_size, dim=0)
            label_match_tensor_split = torch.split(
                label_match_tensor[perm], self.args.batch_size, dim=0)
            print('Split Matched Data: ', len(data_match_tensor_split),
                  data_match_tensor_split[0].shape, len(label_match_tensor_split))

            # Batch iteration over a single epoch
            for batch_idx, (x_e, y_e, d_e, idx_e) in enumerate(self.train_dataset):
                # print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)

                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                d_e = torch.argmax(d_e, dim=1).numpy()

                # Forward pass
                out = self.phi(x_e)
                erm_loss_extra = F.cross_entropy(out, y_e.long()).to(self.cuda)
                penalty_erm_extra += float(erm_loss_extra)

                wasserstein_loss = torch.tensor(0.0).to(self.cuda)
                erm_loss = torch.tensor(0.0).to(self.cuda)
                if epoch > self.args.penalty_s:
                    # To cover the varying size of the last batch for
                    # data_match_tensor_split, label_match_tensor_split
                    total_batch_size = len(data_match_tensor_split)
                    if batch_idx >= total_batch_size:
                        break
                    curr_batch_size = data_match_tensor_split[batch_idx].shape[0]

                    # data_match = data_match_tensor[idx].to(self.cuda)
                    data_match = data_match_tensor_split[batch_idx].to(self.cuda)
                    data_match = data_match.view(
                        data_match.shape[0] * data_match.shape[1],
                        data_match.shape[2], data_match.shape[3],
                        data_match.shape[4])
                    feat_match = self.phi(data_match)

                    # label_match = label_match_tensor[idx].to(self.cuda)
                    label_match = label_match_tensor_split[batch_idx].to(self.cuda)
                    label_match = label_match.view(
                        label_match.shape[0] * label_match.shape[1])

                    erm_loss += F.cross_entropy(feat_match,
                                                label_match.long()).to(self.cuda)
                    penalty_erm += float(erm_loss)

                    train_acc += torch.sum(
                        torch.argmax(feat_match, dim=1) == label_match).item()
                    train_size += label_match.shape[0]

                    # Creating a tensor of shape (domain size, total domains, feat size)
                    if len(feat_match.shape) == 4:
                        feat_match = feat_match.view(
                            curr_batch_size, len(self.train_domains),
                            feat_match.shape[1] * feat_match.shape[2] * feat_match.shape[3])
                    else:
                        feat_match = feat_match.view(
                            curr_batch_size, len(self.train_domains),
                            feat_match.shape[1])

                    label_match = label_match.view(curr_batch_size,
                                                   len(self.train_domains))

                    # print(feat_match.shape)
                    data_match = data_match.view(
                        curr_batch_size, len(self.train_domains),
                        data_match.shape[1], data_match.shape[2],
                        data_match.shape[3])

                    # Positive match loss
                    pos_match_counter = 0
                    for d_i in range(feat_match.shape[1]):
                        # if d_i != base_domain_idx:
                        #     continue
                        for d_j in range(feat_match.shape[1]):
                            if d_j > d_i:
                                if self.args.pos_metric == 'l2':
                                    wasserstein_loss += torch.sum(
                                        torch.sum(
                                            (feat_match[:, d_i, :] - feat_match[:, d_j, :]) ** 2,
                                            dim=1))
                                elif self.args.pos_metric == 'l1':
                                    wasserstein_loss += torch.sum(
                                        torch.sum(
                                            torch.abs(feat_match[:, d_i, :] - feat_match[:, d_j, :]),
                                            dim=1))
                                elif self.args.pos_metric == 'cos':
                                    wasserstein_loss += torch.sum(
                                        cosine_similarity(
                                            feat_match[:, d_i, :],
                                            feat_match[:, d_j, :]))

                        pos_match_counter += feat_match.shape[0]
                    wasserstein_loss = wasserstein_loss / pos_match_counter
                    penalty_ws += float(wasserstein_loss)

                    loss_e += (
                        self.args.penalty_ws * (epoch - self.args.penalty_s) /
                        (self.args.epochs - self.args.penalty_s)) * wasserstein_loss
                    loss_e += erm_loss

                loss_e += erm_loss_extra
                loss_e.backward(retain_graph=False)
                self.opt.step()

                del erm_loss_extra
                del erm_loss
                del wasserstein_loss
                del loss_e
                torch.cuda.empty_cache()

            print('Train Loss Basic : ', penalty_erm_extra, penalty_erm, penalty_ws)
            print('Train Acc Env : ', 100 * train_acc / train_size)
            print('Done Training for epoch: ', epoch)

            # Train dataset accuracy
            self.train_acc.append(100 * train_acc / train_size)

            # Validation dataset accuracy
            self.val_acc.append(self.get_test_accuracy('val'))

            # Test dataset accuracy
            self.final_acc.append(self.get_test_accuracy('test'))

            # Save the model if the current epoch is the best so far
            # as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc = self.val_acc[-1]
                self.max_epoch = epoch
                self.save_model_erm_phase(run_erm)

        print('Current Best Epoch: ', self.max_epoch,
              ' with Test Accuracy: ', self.final_acc[self.max_epoch])
class TestBlendVisualizeCallback(TestVisualizeCallback): callback_cls = BlendVisualizeCallback @pytest.fixture(params=[ pytest.param(True, id="float"), pytest.param(False, id="long") ]) def data(self, data_shape, request): B, C, H, W = data_shape img = create_image(B, C, H, W) if request.param: img = img.float() else: img = img.long() return img.clone(), img.clone() @pytest.fixture def expected_calls(self, data, model, mode, callback): if not hasattr(model, callback.attr_name): return [] if callback.split_channels == (2, 1) or callback.split_channels == ([ 2, 1 ], None): pytest.skip("incompatible test") data1, data2 = data data1 = prepare_image(data1) data2 = prepare_image(data2) B, _, _, _ = data1.shape img1 = [data1] img2 = [data2] name1 = [ f"{mode}/{callback.name}", ] name2 = [ f"{mode}/{callback.name}", ] img = [img1, img2] name = [name1, name2] # channel splitting for pos in range(2): if callback.split_channels[pos]: img[pos] = [] name[pos] = [] img_new, name_new = [], [] splits = torch.split(data[pos], callback.split_channels[pos], dim=-3) for i, s in enumerate(splits): if isinstance(callback.name, str): n = f"{mode}/{callback.name}_{i}" else: n = f"{mode}/{callback.name[i]}" name_new.append(n) img_new.append(s) img[pos] = img_new name[pos] = name_new if len(img[0]) != len(img[1]): if len(img[0]) == 1: img[0] = img[0] * len(img[1]) elif len(img[1]) == 1: img[1] = img[1] * len(img[0]) else: raise RuntimeError() for pos in range(2): if callback.max_resolution: resize_mode = callback.resize_mode target = callback.max_resolution H_max, W_max = target needs_resize = [ i.shape[-2] > H_max or i.shape[-1] > W_max for i in img[pos] ] img[pos] = [ F.interpolate(i, target, mode=resize_mode) if resize else i for i, resize in zip(img[pos], needs_resize) ] for pos in range(2): if (colormap := callback.colormap[pos]): if isinstance(colormap, str): colormap = [colormap] * len(img[pos]) img[pos] = [ apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i for cmap, i in zip(colormap, img[pos]) ] name = name[0] final_img = [] for pos, (d, s) in enumerate(zip(img[0], img[1])): _, C1, _, _ = d.shape _, C2, _, _ = s.shape if C1 != C2: if C1 == 1: d = d.repeat(1, C2, 1, 1) elif C2 == 1: s = s.repeat(1, C1, 1, 1) else: raise ValueError( f"could not match shapes {d.shape}, {s.shape}") s = clamp_normalize(s, inplace=True) d = clamp_normalize(d, inplace=True) final_img.append( alpha_blend(d, s, callback.alpha[1], callback.alpha[0])[0]) img = final_img if callback.as_uint8: img = [ to_8bit(i, same_on_batch=not callback.per_img_norm) for i in img ] if callback.split_batches: new_img, new_name = [], [] for i, n in zip(img, name): split_i = torch.split(i, 1, dim=0) split_n = [f"{n}/{b}" for b in range(B)] new_img += split_i new_name += split_n name, img = new_name, new_img step = [ model.current_epoch if callback.epoch_counter else model.global_step ] * len(name) expected = [(n, i, s) for n, i, s in zip(name, img, step)] return expected
image_path = f'input/{i}.png' image = Image.open(image_path).convert("RGB") image = T.Resize(in_sz)(image) image = image_to_tensor(image).to(device=device) images.append(image) images = torch.stack(images) images = images.unsqueeze(0) print(f'i: {images.shape}') print(f'c: {cam_poses.shape}') net.encode( images, cam_poses, focal, ) print("Rendering", args.num_views * H * W, "rays") all_rgb_fine = [] for rays in tqdm.tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)): rgb, _depth = render_par(rays[None]) all_rgb_fine.append(rgb[0]) _depth = None rgb_fine = torch.cat(all_rgb_fine) frames = (rgb_fine.view(args.num_views, H, W, 3).cpu().numpy() * 255).astype( np.uint8 ) #im_name = os.path.basename(os.path.splitext(image_path)[0]) im_name = output_name frames_dir_name = os.path.join(args.output, im_name + "_frames") os.makedirs(frames_dir_name, exist_ok=True)
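# A minimal sketch of the chunked-inference pattern used above: rays are
# flattened and processed in fixed-size slices via torch.split so peak memory
# stays bounded, then the per-chunk outputs are concatenated. `toy_render`
# stands in for the real renderer and is a hypothetical placeholder.
import torch

def toy_render(rays: torch.Tensor) -> torch.Tensor:
    # fake RGB: one value per ray, broadcast to 3 channels
    return rays.sum(dim=-1, keepdim=True).expand(-1, 3)

rays = torch.randn(5, 100, 8).view(-1, 8)    # e.g. num_views * H * W rays
chunks = []
for part in torch.split(rays, 128, dim=0):   # the last chunk may be smaller
    chunks.append(toy_render(part))
rgb = torch.cat(chunks)
print(rgb.shape)  # torch.Size([500, 3])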
class TestKeypointVisualizeCallback(TestVisualizeCallback):
    callback_cls = KeypointVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long")
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        N = 3
        img = create_image(B, C, H, W)
        bbox = self.create_bbox(B, N)
        bbox = bbox.float() if request.param else bbox.long()
        cls = self.create_classes(B, N)
        # use the dedicated scores helper (the original called
        # create_classes here, leaving create_scores unused)
        score = self.create_scores(B, N)
        target = {"coords": bbox, "class": cls, "score": score}
        return img, target

    def create_bbox(self, B, N):
        torch.random.manual_seed(42)
        bbox = torch.empty(B, N, 4).fill_(-1).float()
        return bbox

    def create_classes(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1).float()

    def create_scores(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1)

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []

        data, target = data
        data = prepare_image(data)
        B, _, _, _ = data.shape

        img = [data]
        name = [f"{mode}/{callback.name}"]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            needs_resize = [
                i.shape[-2] > H_max or i.shape[-1] > W_max for i in img
            ]
            img = [
                F.interpolate(i, target, mode=resize_mode) if resize else i
                for i, resize in zip(img, needs_resize)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i
                for cmap, i in zip(colormap, img)
            ]

        img = [i.repeat(1, 3, 1, 1) if i.shape[-3] == 1 else i for i in img]

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [
            model.current_epoch if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
def decode_structure(model, root_code): """ Decode a root code into a tree structure of boxes """ decode = model.sampleDecoder(root_code) syms = [torch.ones(8).mul(10).cuda()] stack = [decode] boxes = [] while len(stack) > 0: f = stack.pop() label_prob = model.nodeClassifier(f) _, label = torch.max(label_prob, 1) label = label.data if label[0] == 1: # ADJ left, right = model.adjDecoder(f) stack.append(left) stack.append(right) s = syms.pop() syms.append(s) syms.append(s) if label[0] == 2: # SYM left, s = model.symDecoder(f) s = s.squeeze(0) stack.append(left) syms.pop() syms.append(s.data) if label[0] == 0: # BOX reBox = model.boxDecoder(f) reBoxes = [reBox] s = syms.pop() l1 = abs(s[0] + 1) l2 = abs(s[0]) l3 = abs(s[0] - 1) if l1 < 0.15: sList = torch.split(s, 1, 0) bList = torch.split(reBox.data.squeeze(0), 1, 0) f1 = torch.cat([sList[1], sList[2], sList[3]]) f1 = f1 / torch.norm(f1) f2 = torch.cat([sList[4], sList[5], sList[6]]) folds = round(1 / s[7]) for i in range(folds - 1): rotvector = torch.cat( [f1, sList[7].mul(2 * 3.1415).mul(i + 1)]) rotm = vrrotvec2mat(rotvector) center = torch.cat([bList[0], bList[1], bList[2]]) dir0 = torch.cat([bList[3], bList[4], bList[5]]) dir1 = torch.cat([bList[6], bList[7], bList[8]]) dir2 = torch.cat([bList[9], bList[10], bList[11]]) newcenter = rotm.matmul(center.add(-f2)).add(f2) newdir1 = rotm.matmul(dir1) newdir2 = rotm.matmul(dir2) newbox = torch.cat([newcenter, dir0, newdir1, newdir2]) reBoxes.append(Variable(newbox.unsqueeze(0))) if l2 < 0.15: sList = torch.split(s, 1, 0) bList = torch.split(reBox.data.squeeze(0), 1, 0) trans = torch.cat([sList[1], sList[2], sList[3]]) trans_end = torch.cat([sList[4], sList[5], sList[6]]) center = torch.cat([bList[0], bList[1], bList[2]]) trans_length = math.sqrt(torch.sum(trans**2)) trans_total = math.sqrt(torch.sum(trans_end.add(-center)**2)) folds = round(trans_total / trans_length) for i in range(folds): center = torch.cat([bList[0], bList[1], bList[2]]) dir0 = torch.cat([bList[3], bList[4], bList[5]]) dir1 = torch.cat([bList[6], bList[7], bList[8]]) dir2 = torch.cat([bList[9], bList[10], bList[11]]) newcenter = center.add(trans.mul(i + 1)) newbox = torch.cat([newcenter, dir0, dir1, dir2]) reBoxes.append(Variable(newbox.unsqueeze(0))) if l3 < 0.15: sList = torch.split(s, 1, 0) bList = torch.split(reBox.data.squeeze(0), 1, 0) ref_normal = torch.cat([sList[1], sList[2], sList[3]]) ref_normal = ref_normal / torch.norm(ref_normal) ref_point = torch.cat([sList[4], sList[5], sList[6]]) center = torch.cat([bList[0], bList[1], bList[2]]) dir0 = torch.cat([bList[3], bList[4], bList[5]]) dir1 = torch.cat([bList[6], bList[7], bList[8]]) dir2 = torch.cat([bList[9], bList[10], bList[11]]) if ref_normal.matmul(ref_point.add(-center)) < 0: ref_normal = -ref_normal newcenter = ref_normal.mul( 2 * abs(torch.sum( ref_point.add(-center).mul(ref_normal)))).add(center) if ref_normal.matmul(dir1) < 0: ref_normal = -ref_normal dir1 = dir1.add(ref_normal.mul(-2 * ref_normal.matmul(dir1))) if ref_normal.matmul(dir2) < 0: ref_normal = -ref_normal dir2 = dir2.add(ref_normal.mul(-2 * ref_normal.matmul(dir2))) newbox = torch.cat([newcenter, dir0, dir1, dir2]) reBoxes.append(Variable(newbox.unsqueeze(0))) boxes.extend(reBoxes) return boxes
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Parameters
    ----------
    inputs : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, timesteps, input_dim)
    mask : ``torch.FloatTensor``, optional (default = None).
        A tensor of shape (batch_size, timesteps).

    Returns
    -------
    A tensor of shape (batch_size, timesteps, output_projection_dim),
    where output_projection_dim = input_dim by default.
    """
    num_heads = self._num_heads

    batch_size, timesteps, hidden_dim = inputs.size()
    if mask is None:
        mask = Variable(inputs.data.new(batch_size, timesteps).fill_(1.0))

    # Treat the queries, keys and values each as a ``num_heads`` size batch.
    # shape (num_heads, batch_size * timesteps, hidden_dim)
    inputs_per_head = inputs.repeat(num_heads, 1, 1).view(
        num_heads, batch_size * timesteps, hidden_dim)

    # Do the projections for all the heads at once.
    # Then reshape the result as though it had a
    # (num_heads * batch_size) sized batch.
    queries_per_head = torch.bmm(inputs_per_head, self._query_projections)
    # shape (num_heads * batch_size, timesteps, attention_dim)
    queries_per_head = queries_per_head.view(
        num_heads * batch_size, timesteps, self._attention_dim)

    keys_per_head = torch.bmm(inputs_per_head, self._key_projections)
    # shape (num_heads * batch_size, timesteps, attention_dim)
    keys_per_head = keys_per_head.view(
        num_heads * batch_size, timesteps, self._attention_dim)

    values_per_head = torch.bmm(inputs_per_head, self._value_projections)
    # shape (num_heads * batch_size, timesteps, values_dim)
    values_per_head = values_per_head.view(
        num_heads * batch_size, timesteps, self._values_dim)

    # shape (num_heads * batch_size, timesteps, timesteps)
    scaled_similarities = (
        torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale)

    # Apply a causal mask so each position can only attend to itself
    # and earlier positions.
    causality_mask = subsequent_mask(timesteps).cuda()
    masked_scaled_similarities = scaled_similarities.masked_fill(
        causality_mask == 0, -1e9)

    # shape (num_heads * batch_size, timesteps, timesteps)
    # Normalise the distributions, using the same mask for all heads.
    attention = masked_softmax(masked_scaled_similarities,
                               mask.repeat(num_heads, 1))
    attention = self._attention_dropout(attention)

    # This is doing the following batch-wise matrix multiplication:
    # (num_heads * batch_size, timesteps, timesteps) *
    # (num_heads * batch_size, timesteps, values_dim)
    # which is equivalent to a weighted sum of the values with respect to
    # the attention distributions for each element in the num_heads * batch_size
    # dimension.
    # shape (num_heads * batch_size, timesteps, values_dim)
    outputs = torch.bmm(attention, values_per_head)

    # Reshape back to original shape (batch_size, timesteps, num_heads * values_dim)
    # Note that we _cannot_ use a reshape here, because this tensor was created
    # with num_heads being the first dimension, so reshaping naively would not
    # throw an error, but give an incorrect result.
    outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)

    # Project back to original input size.
    # shape (batch_size, timesteps, input_size)
    outputs = self._output_projection(outputs)
    return outputs
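# A small numeric check (torch only, no assumptions beyond made-up sizes) of
# the trick used at the end of the attention forward above: when heads are
# stacked along the batch dimension as (num_heads * batch_size, T, D),
# splitting by batch_size on dim 0 and concatenating on the last dim yields
# (batch_size, T, num_heads * D) with the head outputs placed side by side.
import torch

num_heads, batch_size, T, D = 2, 3, 4, 5
outputs = torch.arange(num_heads * batch_size * T * D, dtype=torch.float32)
outputs = outputs.view(num_heads * batch_size, T, D)

recombined = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)
print(recombined.shape)  # torch.Size([3, 4, 10])

# The first D features come from head 0, the next D from head 1:
assert torch.equal(recombined[:, :, :D], outputs[:batch_size])
assert torch.equal(recombined[:, :, D:], outputs[batch_size:])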
def losses(self):
    """
    Return the losses from a set of RPN predictions and their associated ground-truth.

    Returns:
        dict[loss name -> loss value]: A dict mapping from loss name to loss value.
            Loss names are: `loss_rpn_cls` for objectness classification and
            `loss_rpn_loc` for proposal localization.
    """

    def resample(label):
        """
        Randomly sample a subset of positive and negative examples by overwriting
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    gt_objectness_logits, gt_anchor_deltas = self._get_ground_truth()
    """
    gt_objectness_logits: list of N tensors. Tensor i is a vector whose length is the
        total number of anchors in image i (i.e., len(anchors[i]))
    gt_anchor_deltas: list of N tensors. Tensor i has shape (len(anchors[i]), B),
        where B is the box dimension
    """
    # Collect all objectness labels and delta targets over feature maps and images
    # The final ordering is L, N, H, W, A from slowest to fastest axis.
    num_anchors_per_map = [np.prod(x.shape[1:]) for x in self.pred_objectness_logits]
    num_anchors_per_image = sum(num_anchors_per_map)

    # Stack to: (N, num_anchors_per_image)
    gt_objectness_logits = torch.stack(
        [resample(label) for label in gt_objectness_logits], dim=0
    )

    # Log the number of positive/negative anchors per-image that's used in training
    num_pos_anchors = (gt_objectness_logits == 1).sum().item()
    num_neg_anchors = (gt_objectness_logits == 0).sum().item()
    storage = get_event_storage()
    storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / self.num_images)
    storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / self.num_images)

    assert gt_objectness_logits.shape[1] == num_anchors_per_image

    # Split to tuple of L tensors, each with shape (N, num_anchors_per_map)
    gt_objectness_logits = torch.split(gt_objectness_logits, num_anchors_per_map, dim=1)
    # Concat from all feature maps
    gt_objectness_logits = cat([x.flatten() for x in gt_objectness_logits], dim=0)

    # Stack to: (N, num_anchors_per_image, B)
    gt_anchor_deltas = torch.stack(gt_anchor_deltas, dim=0)
    assert gt_anchor_deltas.shape[1] == num_anchors_per_image
    B = gt_anchor_deltas.shape[2]  # box dimension (4 or 5)

    # Split to tuple of L tensors, each with shape (N, num_anchors_per_image)
    gt_anchor_deltas = torch.split(gt_anchor_deltas, num_anchors_per_map, dim=1)
    # Concat from all feature maps
    gt_anchor_deltas = cat([x.reshape(-1, B) for x in gt_anchor_deltas], dim=0)

    # Collect all objectness logits and delta predictions over feature maps
    # and images to arrive at the same shape as the labels and targets
    # The final ordering is L, N, H, W, A from slowest to fastest axis.
    pred_objectness_logits = cat(
        [
            # Reshape: (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N*Hi*Wi*A, )
            x.permute(0, 2, 3, 1).flatten()
            for x in self.pred_objectness_logits
        ],
        dim=0,
    )
    pred_anchor_deltas = cat(
        [
            # Reshape: (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B)
            #          -> (N*Hi*Wi*A, B)
            x.view(x.shape[0], -1, B, x.shape[-2], x.shape[-1])
            .permute(0, 3, 4, 1, 2)
            .reshape(-1, B)
            for x in self.pred_anchor_deltas
        ],
        dim=0,
    )

    objectness_loss, localization_loss = rpn_losses(
        gt_objectness_logits,
        gt_anchor_deltas,
        pred_objectness_logits,
        pred_anchor_deltas,
        self.smooth_l1_beta,
    )
    normalizer = 1.0 / (self.batch_size_per_image * self.num_images)
    loss_cls = objectness_loss * normalizer  # cls: classification loss
    loss_loc = localization_loss * normalizer  # loc: localization loss
    losses = {"loss_rpn_cls": loss_cls, "loss_rpn_loc": loss_loc}

    return losses
def split(input, sizes_or_sections, dim): return th.split(input, sizes_or_sections, dim)
def train(x, y, iteration, epoch, batch_size, target_map=None, r_mixup=0.0):
    G.optim.zero_grad()
    D.optim.zero_grad()

    if config["unet_mixup"]:
        real_target = torch.tensor([1.0]).cuda()
        fake_target = torch.tensor([0.0]).cuda()

    if config["unet_mixup"] and not config["full_batch_mixup"]:
        use_mixup_in_this_round = True
    elif config["unet_mixup"] and config["full_batch_mixup"]:
        use_mixup_in_this_round = torch.rand(1).detach().item() < r_mixup
    else:
        use_mixup_in_this_round = False

    out = {}

    skip_normal_real_fake_loss = (use_mixup_in_this_round
                                  and config["full_batch_mixup"])

    n_d_accu = config['num_D_accumulations']
    split_size = int(x.size(0) / n_d_accu)

    x = torch.split(x, split_size)
    y = torch.split(y, split_size)

    d_real_target = torch.tensor([1.0]).cuda()
    d_fake_target = torch.tensor([0.0]).cuda()
    discriminator_loss = functools.partial(BCEloss,
                                           d_real_target=d_real_target,
                                           d_fake_target=d_fake_target)
    mix_fake_target = torch.tensor([1.0]).cuda()
    fake_loss = functools.partial(BCEfakeloss, target=mix_fake_target)

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        counter = 0
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()

        for accumulation_index in range(n_d_accu):
            z_.sample_()
            y_.sample_()

            if use_mixup_in_this_round:
                if (not config["full_batch_mixup"]) or (
                        config["full_batch_mixup"]
                        and (config["consistency_loss_and_augmentation"]
                             or config["consistency_loss"])):
                    # mixup can be True here because the mixup weight is set
                    # to 0 when no mixup is used
                    D_fake, D_real, D_mixed, G_z, mixed, D_middle_fake, D_middle_real, D_middle_mixed, target_map = GD(
                        z_[:batch_size], y_[:batch_size],
                        x[counter], y[counter],
                        train_G=False, split_D=config['split_D'],
                        mixup=True, target_map=target_map)
                else:
                    D_mixed, G_z, mixed, D_middle_mixed, target_map = GD(
                        z_[:batch_size], y_[:batch_size],
                        x[counter], y[counter],
                        train_G=False, return_G_z=True,
                        split_D=config['split_D'],
                        mixup=True, mixup_only=True, target_map=target_map)

                if config["slow_mixup"] and not config["full_batch_mixup"]:
                    # used without full-batch mixup
                    mixup_coeff = min(1.0, epoch / config["warmup_epochs"])
                else:
                    mixup_coeff = 1.0

                if config["display_mixed_batch"]:
                    # This can help for debugging
                    plt.figure()
                    m = torchvision.utils.make_grid(mixed, nrow=5, padding=2, normalize=True)
                    m = m.permute(1, 2, 0)
                    m = m.cpu().numpy()
                    plt.imshow(m)
                    plt.figure()
                    plt.figure()
                    m = torchvision.utils.make_grid(G_z, nrow=5, padding=2, normalize=True)
                    m = m.permute(1, 2, 0)
                    m = m.cpu().numpy()
                    plt.imshow(m)
                    plt.figure()
                    plt.figure()
                    m = torchvision.utils.make_grid(x[counter], nrow=5, padding=2, normalize=True)
                    m = m.permute(1, 2, 0)
                    m = m.cpu().numpy()
                    plt.imshow(m)
                    plt.figure()
                    m = torchvision.utils.make_grid(target_map, nrow=5, padding=2)
                    m = m.permute(1, 2, 0)
                    m = m.cpu().numpy()
                    plt.imshow(m)
                    plt.title("mix")
                    plt.show()
                    plt.figure()
            else:
                D_fake, D_real, G_z, D_middle_fake, D_middle_real = GD(
                    z_[:batch_size], y_[:batch_size],
                    x[counter], y[counter],
                    train_G=False, split_D=config['split_D'])

            if not skip_normal_real_fake_loss:
                D_loss_real_2d, D_loss_fake_2d = discriminator_loss(
                    D_fake.view(-1), D_real.view(-1))
                D_loss_real_2d_item = D_loss_real_2d.detach().item()
                D_loss_fake_2d_item = D_loss_fake_2d.detach().item()

            if use_mixup_in_this_round and (
                    config["consistency_loss"]
                    or config["consistency_loss_and_augmentation"]):
                mix = D_real * target_map + D_fake * (1 - target_map)

            if use_mixup_in_this_round:
                D_mixed_flattened = D_mixed.view(-1)
                target_map_flattened = target_map.view(-1)

                mix_list = []
                for i in range(D_mixed.size(0)):
                    # MIXUP LOSS 2D
                    mix2d_i = F.binary_cross_entropy_with_logits(
                        D_mixed[i].view(-1), target_map[i].view(-1))
                    mix_list.append(mix2d_i)

                D_loss_mixed_2d = torch.stack(mix_list)
                # -> D_loss_mixed_2d.mean() is taken later
                D_loss_mixed_2d_item = D_loss_mixed_2d.mean().detach().item()
                # D_loss_mixed_2d = D_loss_mixed_2d.view(D_mixed.size()).mean([2, 3])

            if not skip_normal_real_fake_loss:
                D_loss_real_middle, D_loss_fake_middle = discriminator_loss(
                    D_middle_fake, D_middle_real)
                D_loss_real_middle_item = D_loss_real_middle.detach().item()
                D_loss_fake_middle_item = D_loss_fake_middle.detach().item()

            if use_mixup_in_this_round and not config["consistency_loss"]:
                # The consistency loss is only concerned with the segmenter
                # output; the target for the mixed encoder loss is "fake".
                mix_bce = F.binary_cross_entropy_with_logits(
                    D_middle_mixed, fake_target.expand_as(D_middle_mixed),
                    reduction="none")
                mixed_middle_loss = mixup_coeff * mix_bce
                mixed_middle_loss_item = mixed_middle_loss.mean().detach().item()

            if skip_normal_real_fake_loss:
                D_loss_real = torch.tensor([0.0]).cuda()
                D_loss_fake = torch.tensor([0.0]).cuda()
            else:
                D_loss_real = D_loss_real_2d + D_loss_real_middle
                D_loss_fake = D_loss_fake_2d + D_loss_fake_middle

            D_loss_real_item = D_loss_real.detach().item()
            D_loss_fake_item = D_loss_fake.detach().item()

            D_loss = 0.5 * D_loss_real + 0.5 * D_loss_fake

            if use_mixup_in_this_round:
                if config["consistency_loss"] or config["consistency_loss_and_augmentation"]:
                    consistency_loss = mixup_coeff * 1.0 * F.mse_loss(D_mixed, mix)
                    consistency_loss_item = consistency_loss.float().detach().item()

                if not config["consistency_loss"]:
                    # GAN loss from the cutmix augmentation (=/= consistency loss)
                    mix_loss = D_loss_mixed_2d + mixed_middle_loss
                    mix_loss = mix_loss.mean()
                else:
                    mix_loss = 0.0

                if config["consistency_loss"]:
                    mix_loss = consistency_loss
                elif config["consistency_loss_and_augmentation"]:
                    mix_loss = mix_loss + consistency_loss

                D_loss = D_loss + mix_loss

            D_loss = D_loss / float(config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        if iteration % 2 == 0 and EG:
            print('D extrapolation')
            D.optim.extrapolation()
        else:
            print('D step')
            D.optim.step()
        del D_loss

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    ######################################
    # G-step
    ######################################

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()
    counter = 0

    z_.sample_()
    y_.sample_()

    z__ = torch.split(z_, split_size)
    y__ = torch.split(y_, split_size)

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):
        G_fake, G_fake_middle = GD(z__[counter], y__[counter],
                                   train_G=True,
                                   split_D=config['split_D'],
                                   reference_x=x[counter])

        G_loss_fake_2d = fake_loss(G_fake)
        G_loss_fake_middle = fake_loss(G_fake_middle)
        G_loss = 0.5 * G_loss_fake_middle + 0.5 * G_loss_fake_2d
        G_loss = G_loss / float(config['num_G_accumulations'])

        G_loss_fake_middle_item = G_loss_fake_middle.detach().item()
        G_loss_fake_2d_item = G_loss_fake_2d.detach().item()
        G_loss_item = G_loss.detach().item()

        G_loss.backward()
        counter += 1

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G.
        print('using modified ortho reg in G')
        # Don't ortho reg shared; it makes no sense.
        # Really we should blacklist any embeddings for this.
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])

    print(iteration)
    if iteration % 2 == 0 and EG:
        print('G extrapolation')
        G.optim.extrapolation()
    else:
        print('G step')
        G.optim.step()
    del G_loss

    # If we have an EMA, update it, regardless of whether we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    # Save intermediate losses
    if use_mixup_in_this_round and (
            config["consistency_loss"]
            or config["consistency_loss_and_augmentation"]) and config["num_D_steps"] > 0:
        out["consistency"] = float(consistency_loss_item)

    out['G_loss'] = float(G_loss_item)

    if not (config["full_batch_mixup"] and use_mixup_in_this_round) and config["num_D_steps"] > 0:
        out['D_loss_real'] = float(D_loss_real_item)
        out['D_loss_fake'] = float(D_loss_fake_item)

    if use_mixup_in_this_round and not config["consistency_loss"] and config["num_D_steps"] > 0:
        out["mixed_middle_loss"] = float(mixed_middle_loss_item)
        out["D_loss_mixed_2d"] = float(D_loss_mixed_2d_item)

    if not (config["full_batch_mixup"] and use_mixup_in_this_round):
        if config["num_D_steps"] > 0:
            out["D_loss_real_middle"] = float(D_loss_real_middle_item)
            out["D_loss_fake_middle"] = float(D_loss_fake_middle_item)
            out["D_loss_real_2d"] = float(D_loss_real_2d_item)
            out["D_loss_fake_2d"] = float(D_loss_fake_2d_item)

    out["G_loss_fake_middle"] = float(G_loss_fake_middle_item)
    out["G_loss_fake_2d"] = float(G_loss_fake_2d_item)

    return out
def forward( self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, yesno: torch.IntTensor = None, question_tf: torch.FloatTensor = None, passage_tf: torch.FloatTensor = None, q_em_cased: torch.IntTensor = None, p_em_cased: torch.IntTensor = None, q_em_uncased: torch.IntTensor = None, p_em_uncased: torch.IntTensor = None, q_in_lemma: torch.IntTensor = None, p_in_lemma: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ x1_c_emb = self._dropout(self._char_field_embedder(passage)) x2_c_emb = self._dropout(self._char_field_embedder(question)) # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)), # self._features_embedder(q_em_cased), # self._features_embedder(q_em_uncased), # self._features_embedder(q_in_lemma), # question_tf.unsqueeze(2)], dim=2) # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)), # self._features_embedder(p_em_cased), # self._features_embedder(p_em_uncased), # self._features_embedder(p_in_lemma), # passage_tf.unsqueeze(2)], dim=2) token_emb_q = self._dropout(self._text_field_embedder(question)) token_emb_c = self._dropout(self._text_field_embedder(passage)) token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40], dim=2) token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40], dim=2) question_word_features = torch.cat([ q_ner_and_pos, self._features_embedder(q_em_cased), self._features_embedder(q_em_uncased), self._features_embedder(q_in_lemma), question_tf.unsqueeze(2) ], dim=2) passage_word_features = torch.cat([ p_ner_and_pos, self._features_embedder(p_em_cased), self._features_embedder(p_em_uncased), self._features_embedder(p_in_lemma), passage_tf.unsqueeze(2) ], dim=2) # embedded_question = self._highway_layer(embedded_q) # embedded_passage = self._highway_layer(embedded_q) question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None char_features_c = self._char_rnn( x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2), x1_c_emb.size(3))), passage_lstm_mask.unsqueeze(2).repeat( 1, 1, x1_c_emb.size(2)).reshape( (x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2)))).reshape( (x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2), -1))[:, :, -1, :] char_features_q = self._char_rnn( x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2), x2_c_emb.size(3))), question_lstm_mask.unsqueeze(2).repeat( 1, 1, x2_c_emb.size(2)).reshape( (x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2)))).reshape( (x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2), -1))[:, :, -1, :] # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2) # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2) # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask) # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask) emb_question = torch.cat( [token_emb_question, char_features_q, question_word_features], dim=2) emb_passage = torch.cat( [token_emb_passage, char_features_c, passage_word_features], dim=2) encoded_question = self._dropout( self._phrase_layer(emb_question, question_lstm_mask)) 
encoded_passage = self._dropout( self._phrase_layer(emb_passage, passage_lstm_mask)) batch_size = encoded_question.size(0) passage_length = encoded_passage.size(1) encoding_dim = encoded_question.size(-1) # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask) # q = self._stacked_brnn(encoded_question, question_lstm_mask) c_check = encoded_passage q = encoded_question for i in range(self.hops): q_tilde = self.interactive_aligners[i].forward( c_check, q, question_mask) c_bar = self.interactive_SFUs[i].forward( c_check, torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2)) c_tilde = self.self_aligners[i].forward(c_bar, passage_mask) c_hat = self.self_SFUs[i].forward( c_bar, torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde], 2)) c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask) # Predict start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward( c_check, q, passage_mask, question_mask) best_span, yesno_predict, loc = self.get_best_span( start_scores, end_scores, yesno_scores) output_dict = { "span_start_logits": start_scores, "span_end_logits": end_scores, "best_span": best_span } # Compute the loss for training. if span_start is not None: loss = nll_loss(start_scores, span_start.squeeze(-1)) self._span_start_accuracy(start_scores, span_start.squeeze(-1)) loss += nll_loss(end_scores, span_end.squeeze(-1)) self._span_end_accuracy(end_scores, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) gold_span_end_loc = [] span_end = span_end.view(batch_size).squeeze().data.cpu().numpy() for i in range(batch_size): gold_span_end_loc.append( max(span_end[i] + i * passage_length, 0)) gold_span_end_loc = span_start.new(gold_span_end_loc) _yesno = yesno_scores.view(-1, 3).index_select( 0, gold_span_end_loc).view(-1, 3) loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1) pred_span_end_loc = [] for i in range(batch_size): pred_span_end_loc.append(max(loc[i], 0)) predicted_end = span_start.new(pred_span_end_loc) _yesno = yesno_scores.view(-1, 3).index_select(0, predicted_end).view( -1, 3) self._span_yesno_accuracy(_yesno, yesno.squeeze(-1)) output_dict['loss'] = loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. if metadata is not None: output_dict['best_span_str'] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) output_dict['question_tokens'] = question_tokens output_dict['passage_tokens'] = passage_tokens output_dict['yesno'] = yesno_predict return output_dict
def forward(self, x): batch_size = x.size(0) # batch, window, n_val for i in range(self.target): embed = self.em(torch.tensor(i).cuda()) embed = torch.unsqueeze(torch.unsqueeze(embed, 0), 1) # 1, 1, size embed = embed.repeat(batch_size, window, 1) fea = x[:, :, 5 * i:5 * (i + 1)] # fea1=fea.unsqueeze(-1) # fea2=fea.unsqueeze(-2) # fea_cross=torch.matmul(fea1,fea2).reshape(batch_size,window,-1) # fea_map=[] # for j in range(self.em_size): # fea_map.append(self.fea_cross_weight[j](fea_cross)) # crossed_fea=torch.cat(fea_map,dim=2) # fc=crossed_fea fc = self.w(fea) fc += embed # fc = self.fc(torch.cat([embed, fea], dim=2)) # batch_size, window, size if i == 0: fc_output = fc else: fc_output = torch.cat([fc_output, fc], dim=2) index_fea = x[:, :, 30:] index_fea = self.index_fc(index_fea) features = torch.cat([fc_output, index_fea], dim=2) features_origin1 = features # lstnet_input = fc_output features = features.reshape(batch_size, window, -1, self.em_size) split_tensor1 = torch.stack( torch.split(features, self.em_size * [1], 3), 0) split_tensor2 = split_tensor1.permute(0, 1, 2, 4, 3) dot_result_m = torch.matmul(split_tensor1, split_tensor2) dot_result_m = dot_result_m.view(self.em_size, batch_size, window, 7 * 7) crossed_feas = self.W1(dot_result_m) crossed_feas = crossed_feas.permute(1, 0, 2, 3) crossed_feas = crossed_feas.permute(0, 2, 1, 3) crossed_feas = crossed_feas.permute(0, 1, 3, 2) # crossed_feas=F.relu(crossed_feas) features_origin2 = crossed_feas # features lstnet_input = crossed_feas lstnet_input = lstnet_input.reshape(batch_size, window, -1) lstnet_input += features_origin1 # features2 = features.reshape(batch_size,window,self.em_size,-1) # CNN # lstnet_input=features c = lstnet_input.view(-1, 1, self.P, self.m) # batch, 1, window, n_val c = F.relu(self.conv1(c)) c = self.dropout(c) c = torch.squeeze(c, 3) # batch, hidCNN, window-kernel_size+1 # RNN r = c.permute(2, 0, 1).contiguous() _, r = self.GRU1(r) r = self.dropout(torch.squeeze(r, 0)) # batch, hidRNN # skip-rnn if (self.skip > 0): s = c[:, :, int(-self.pt * self.skip):].contiguous() s = s.view(batch_size, self.hidC, self.pt, self.skip) s = s.permute(2, 0, 3, 1).contiguous() s = s.view(self.pt, batch_size * self.skip, self.hidC) _, s = self.GRUskip(s) s = s.view(batch_size, self.skip * self.hidS) s = self.dropout(s) r = torch.cat((r, s), 1) # batch, skip*hidSkip + hidRNN res = self.linear1(r) # batch, n_val # highway if (self.hw > 0): z = lstnet_input[:, -self.hw:, :] # batch, hw, n_val z = z.permute(0, 2, 1).contiguous().view(-1, self.hw) z = self.highway(z) z = z.view(-1, self.m) res = res + z # batch, n_val res = self.output(res) return res
def split(self, split_size, dim=0): r"""Splits this tensor into tensor chunks of :attr:`split_size` size. See :func:`torch.split`. """ return torch.split(self, split_size, dim)
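# A minimal demo of the two torch.split signatures the wrapper above forwards
# to: an int gives equal-size chunks (the last one may be smaller), while a
# list of sections gives chunks of exactly those sizes.
import torch

t = torch.arange(10)
print([c.tolist() for c in torch.split(t, 4)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print([c.tolist() for c in torch.split(t, [2, 3, 5])])
# [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]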
def _train_epoch(self, epoch):
    """
    Train the model for the given epoch
    """
    # TODO: load the data in the correct form (check with Sylvia)
    train_data_with_time = None
    train_data, train_times = data.get_time_columns(train_data_with_time)
    self.train_rnn_inp = data.get_rnn_input(
        self.train_tokens, self.train_counts, train_times,
        self.hyperparameters['num_times'], len(self.vocab),
        len(self.train_tokens))

    self.model.train()
    acc_loss = 0
    acc_nll = 0
    acc_kl_theta_loss = 0
    acc_kl_eta_loss = 0
    acc_kl_alpha_loss = 0
    cnt = 0

    indices = torch.randperm(train_data.shape[0])
    indices = torch.split(indices, self.hyperparameters['batch_size'])

    optimizer = self.set_optimizer()

    for idx, ind in enumerate(indices):
        optimizer.zero_grad()
        self.zero_grad()

        # We could use a PyTorch DataLoader here instead.
        data_batch, times_batch = data.get_batch(train_data, ind, self.device,
                                                 train_times)
        # The following line is commented out only so that the code runs:
        # times_batch = get_indices(train_times, times_batch)
        sums = data_batch.sum(1).unsqueeze(1)
        times_batch = torch.from_numpy(times_batch)

        if self.hyperparameters['bow_norm']:
            normalized_data_batch = data_batch / sums
        else:
            normalized_data_batch = data_batch

        loss, nll, kl_alpha, kl_eta, kl_theta = self.model.forward(
            data_batch, normalized_data_batch, times_batch,
            self.train_rnn_inp, train_data.shape[0])
        loss.backward()

        if self.hyperparameters['clip'] > 0:
            torch.nn.utils.clip_grad_norm_(self.parameters(),
                                           self.hyperparameters['clip'])
        optimizer.step()

        acc_loss += torch.sum(loss).item()
        acc_nll += torch.sum(nll).item()
        acc_kl_theta_loss += torch.sum(kl_theta).item()
        acc_kl_eta_loss += torch.sum(kl_eta).item()
        acc_kl_alpha_loss += torch.sum(kl_alpha).item()
        cnt += 1

        if idx % self.hyperparameters['log_interval'] == 0 and idx > 0:
            cur_loss = round(acc_loss / cnt, 2)
            cur_nll = round(acc_nll / cnt, 2)
            cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
            cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
            cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
            lr = optimizer.param_groups[0]['lr']
            print(
                'Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. '
                'KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
                    epoch, idx, len(indices), lr, cur_kl_theta, cur_kl_eta,
                    cur_kl_alpha, cur_nll, cur_loss))

    cur_loss = round(acc_loss / cnt, 2)
    cur_nll = round(acc_nll / cnt, 2)
    cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
    cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
    cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
    lr = optimizer.param_groups[0]['lr']
    print('*' * 100)
    print(
        'Epoch----->{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. '
        'KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
            epoch, lr, cur_kl_theta, cur_kl_eta, cur_kl_alpha, cur_nll,
            cur_loss))
    print('*' * 100)
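# A minimal sketch of the shuffling/minibatching idiom used in the epoch loop
# above: torch.randperm plus torch.split yields randomized index batches
# without copying the data itself. The sizes here are illustrative.
import torch

num_docs, batch_size = 10, 4
indices = torch.randperm(num_docs)
for ind in torch.split(indices, batch_size):
    print(ind.tolist())  # e.g. [7, 2, 9, 0] / [4, 1, 3, 8] / [5, 6]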
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer,
          ema_optimizer, criterion, epoch, use_cuda):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    # bar = Bar('Training', max=args.train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.train_iteration):
        try:
            inputs_x, targets_x = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)

        # measure data loading time
        data_time.update(time.time() - end)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        targets_x = torch.zeros(batch_size, args.num_classes).scatter_(
            1, targets_x.view(-1, 1).long(), 1)

        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u = inputs_u.cuda()
            inputs_u2 = inputs_u2.cuda()

        with torch.no_grad():
            # compute guessed labels of unlabeled samples
            outputs_u = model(inputs_u)
            outputs_u2 = model(inputs_u2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p ** (1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        # mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        l = np.random.beta(args.alpha, args.alpha)
        l = max(l, 1 - l)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches
        # to get a correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])]
        for input in mixed_input[1:]:
            logits.append(model(input))

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size],
                              logits_u, mixed_target[batch_size:],
                              epoch + batch_idx / args.train_iteration)

        loss = Lx + w * Lu

        # record loss
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        # bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
        #     batch=batch_idx + 1,
        #     size=args.train_iteration,
        #     data=data_time.avg,
        #     bt=batch_time.avg,
        #     total=bar.elapsed_td,
        #     eta=bar.eta_td,
        #     loss=losses.avg,
        #     loss_x=losses_x.avg,
        #     loss_u=losses_u.avg,
        #     w=ws.avg,
        # )
        # bar.next()
    # bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg,)
def split(self, split_size, dim=0): """Splits this tensor into a list of tensors. See :func:`torch.split`. """ return torch.split(self, split_size, dim)
def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]:
    """Compute the matrix fwt for the given input signal.

    Matrix fwts are used to avoid padding.

    Args:
        input_signal (torch.Tensor): Batched input data [batch_size, time],
            which should be of even length. 1d inputs are interpreted as [time].

    Returns:
        List[torch.Tensor]: A list with the coefficients for each scale.

    Raises:
        ValueError: If the decomposition level is not a positive integer
            or if the input signal does not have the expected shape.
    """
    if input_signal.dim() == 1:
        # assume time series
        input_signal = input_signal.unsqueeze(0)
    elif input_signal.dim() != 2:
        raise ValueError(
            f"Invalid input tensor shape {input_signal.size()}. "
            "The input signal is expected to be of the form "
            "[batch_size, length].")

    if input_signal.shape[-1] % 2 != 0:
        # odd-length input: pad a zero on the right
        input_signal = torch.nn.functional.pad(input_signal, [0, 1])

    _, length = input_signal.shape

    re_build = False
    if self.input_length != length:
        self.input_length = length
        re_build = True

    if self.level is None:
        self.level = int(np.log2(length))
        re_build = True
    elif self.level <= 0:
        raise ValueError("level must be a positive integer.")

    if not self.fwt_matrix_list or re_build:
        self._construct_analysis_matrices(device=input_signal.device,
                                          dtype=input_signal.dtype)

    lo = input_signal.T
    split_list = []
    for scale, fwt_matrix in enumerate(self.fwt_matrix_list):
        if self.pad_list[scale]:
            # fix odd coefficient lengths so the conv matrix can be applied.
            lo = torch.nn.functional.pad(lo.T.unsqueeze(1), [0, 1]).squeeze(1).T
        coefficients = torch.sparse.mm(fwt_matrix, lo)
        lo, hi = torch.split(coefficients, coefficients.shape[0] // 2, dim=0)
        split_list.append(hi)
    split_list.append(lo)

    return split_list[::-1]
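# A minimal sketch of the per-scale coefficient split performed above: each
# analysis matrix maps the current approximation to a stacked [lo; hi] block,
# which torch.split halves along dim 0. A dense random matrix stands in for
# the real sparse FWT matrix here (an assumption, for shapes only).
import torch

length = 8
lo = torch.randn(length, 1)               # current approximation, [time, batch]
fwt_matrix = torch.randn(length, length)  # stand-in for the analysis matrix

coefficients = fwt_matrix @ lo
lo_next, hi = torch.split(coefficients, coefficients.shape[0] // 2, dim=0)
print(lo_next.shape, hi.shape)  # torch.Size([4, 1]) torch.Size([4, 1])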
times = []
epochs = 500
for epoch in range(epochs):
    print("Epoch", epoch)
    epoch_start = time.time()
    model.train()
    train_losses = []
    train_accuracies = []
    train_accuracies2 = []
    _start = time.time()
    for batch_idx, sequence_batch in enumerate(train_loader):
        sequence_batch = Variable(sequence_batch, requires_grad=False)
        if use_cuda:
            sequence_batch = sequence_batch.cuda()
        sequence = sequence_batch.squeeze(dim=0)
        # pass the chunk size positionally: the keyword is
        # split_size_or_sections in current PyTorch, not split_size
        subsequences = torch.split(sequence, 100)
        for seq in subsequences:
            batch_size = 1
            seq_len = seq.size()[0]
            seq = seq.view(seq_len, -1).contiguous()
            seq = seq.unsqueeze(dim=1)
            targets = seq[1:]

            optimizer.zero_grad()
            predictions = model(seq)
            losses = [F.binary_cross_entropy(input=pred, target=targets[step])
                      for step, pred in enumerate(predictions[:-1])]
            loss = sum(losses)
            loss.backward()
            # clip_grad_norm is deprecated; use the in-place variant
            torch.nn.utils.clip_grad_norm_(model.parameters(), .25)
            optimizer.step()

            train_losses.append(np.mean([l.data.cpu().numpy() for l in losses]))
def cell(self, x, h_prev, x_mask, h_mask):
    """
    Forward pass inside one cell of the model.

    :param x: input
    :param h_prev: hidden state of the previous step
    :param x_mask: mask for input dropout
    :param h_mask: mask for hidden dropout
    :return: the hidden output of the current step
    """
    s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)

    # states contains the nodes computed as described in the paper,
    # that is, the sum of the outputs of the operations on the incoming
    # edges. If a genotype is defined, there is only one incoming edge,
    # as a constraint described in the paper.
    states = [s0]

    # IMPORTANT: genotype is None when doing arch search
    # "i" is the index of the next intermediate node,
    # "name" is the label of the activation function,
    # "pred" is the index of the previous node, so the edge is pred --name--> i
    for i, (name, pred) in enumerate(self.genotype.recurrent):
        # taking the previous state using its index
        s_prev = states[pred]

        if self.use_edge_matrices:
            # sum of the first (i-1) natural numbers plus the previous node index
            edge_weights_id = sum(num for num in range(1, i)) + pred
        else:
            edge_weights_id = i

        # applying dropout masks if training.
        # computing the matrix mul between the previous output
        # and the weights of the current node "i" (FC layer)
        if self.training:
            ch = (s_prev * h_mask).mm(self._Ws[edge_weights_id])
        else:
            ch = s_prev.mm(self._Ws[edge_weights_id])
        c, h = torch.split(ch, self.nhid, dim=-1)
        c = c.sigmoid()

        # getting the chosen activation function
        fn = self._get_activation(name)

        # activation function on hidden
        h = fn(h)

        s = s_prev + c * (h - s_prev)
        states += [s]

    # computing the output as the mean of the outputs of the
    # INTERMEDIATE nodes, whose indices are defined by the
    # "concat" list in the genotype
    output = torch.mean(
        torch.stack([states[i] for i in self.genotype.concat], -1), -1)

    if self.handle_hidden_mode == 'ACTIVATION':
        # avoid the explosion of the hidden state by forcing values into [-1, 1]
        output = torch.tanh(output)  # F.tanh is deprecated

    return output
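# A minimal sketch of the per-edge update inside the cell above: the FC
# output ch is split into a sigmoid gate c and a candidate h, and the new
# state moves from s_prev toward act(h) by the gate amount. The dimensions,
# the random weight matrix, and the tanh activation are illustrative
# assumptions, not the model's actual parameters.
import torch

nhid = 6
s_prev = torch.randn(2, nhid)     # previous node state (batch, nhid)
W = torch.randn(nhid, 2 * nhid)   # stand-in for self._Ws[edge_weights_id]

ch = s_prev @ W
c, h = torch.split(ch, nhid, dim=-1)
c = c.sigmoid()
s = s_prev + c * (torch.tanh(h) - s_prev)
print(s.shape)  # torch.Size([2, 6])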