def forward(self, x, geneexpr):
    """Run three parallel first-layer convolutions on ``x`` and fuse with ``geneexpr``.

    The three first-layer activations are concatenated on the channel axis,
    pushed through two further padded conv/pool stages, flattened to
    ``200 * 13`` features per row and concatenated with the (optionally
    transformed) gene-expression vector before the dense head.

    Args:
        x: input passed to ``conv1``/``conv1a``/``conv1b``; the ``F.pad``
            calls with 2-tuples pad its last axis only.
        geneexpr: 2-D tensor concatenated with the flattened features on
            ``dim=1``.

    Returns:
        Whatever ``self.output`` produces for the final hidden activations.

    NOTE(review): the flatten width ``200 * 13`` is hard-coded and must match
    the layer definitions, which are not visible here - confirm.
    """
    # All three branches see the same input; evaluation order is conv1,
    # conv1a, conv1b.
    branches = [
        F.relu(conv(x)) for conv in (self.conv1, self.conv1a, self.conv1b)
    ]
    feats = self.maxpool_3(torch.cat(branches, dim=1))

    feats = F.relu(self.conv2(F.pad(feats, (5, 5))))
    feats = self.maxpool_4(feats)

    feats = F.relu(self.conv3(F.pad(feats, (3, 3))))
    feats = self.maxpool_4(F.pad(feats, (1, 1)))

    flat = feats.view(-1, 200 * 13)

    # ``gdl`` selects where dropout sits relative to the gene projection;
    # any other value leaves ``geneexpr`` untouched.
    if self.gdl == 0:
        gene = F.relu(self.genelinear(self.dropout(geneexpr)))
    elif self.gdl == 1:
        gene = self.dropout(F.relu(self.genelinear(geneexpr)))
    else:
        gene = geneexpr

    hidden = torch.cat([flat, gene], dim=1)
    hidden = self.dropout(F.relu(self.linear1(hidden)))
    hidden = self.dropout(F.relu(self.linear2(hidden)))
    return self.output(hidden)
def forward(self, x, geneexpr):
    """Run the padded conv/pool stack on ``x`` and fuse it with ``geneexpr``.

    Args:
        x: input to ``conv1``; its last axis is padded by 9 on each side
            before the first convolution.
        geneexpr: 2-D tensor concatenated with the flattened conv features
            on ``dim=1``.

    Returns:
        Whatever ``self.output`` produces for the final hidden activations.

    NOTE(review): the flatten width ``200 * 13`` is hard-coded and must match
    the layer definitions, which are not visible here - confirm.
    """
    feats = F.relu(self.conv1(F.pad(x, (9, 9))))
    feats = self.maxpool_3(feats)

    feats = F.relu(self.conv2(F.pad(feats, (5, 5))))
    feats = self.maxpool_4(feats)

    feats = F.relu(self.conv3(F.pad(feats, (3, 3))))
    feats = self.maxpool_4(F.pad(feats, (1, 1)))

    flat = feats.view(-1, 200 * 13)

    # ``gdl`` picks the gene-expression pre-processing:
    #   0 -> dropout, then projection + ReLU
    #   1 -> projection + ReLU, then dropout
    #   2 -> projection followed by L2 normalisation of each row
    # Any other value passes ``geneexpr`` through unchanged.
    mode = self.gdl
    if mode == 0:
        gene = F.relu(self.genelinear(self.dropout(geneexpr)))
    elif mode == 1:
        gene = self.dropout(F.relu(self.genelinear(geneexpr)))
    elif mode == 2:
        gene = F.normalize(self.genelinear(geneexpr), p=2, dim=1)
    else:
        gene = geneexpr

    hidden = torch.cat([flat, gene], dim=1)
    hidden = self.dropout(F.relu(self.linear1(hidden)))
    hidden = self.dropout(F.relu(self.linear2(hidden)))
    return self.output(hidden)
def fixed_padding(inputs, kernel_size, rate):
    """Zero-pad the last two axes of ``inputs`` for a dilated kernel.

    The total padding equals the dilated kernel extent minus one and is
    independent of the input size. When the total is odd the extra element
    goes on the trailing side.

    Args:
        inputs: tensor whose last two axes are padded.
        kernel_size: kernel size along each padded axis.
        rate: dilation rate of the kernel.

    Returns:
        The padded tensor.
    """
    # Extent of the kernel once ``rate - 1`` gaps sit between adjacent taps.
    dilated_extent = kernel_size + (kernel_size - 1) * (rate - 1)
    total = dilated_extent - 1
    leading = total // 2
    trailing = total - leading
    return F.pad(inputs, (leading, trailing, leading, trailing))
def tile(self, x):
    """Fold non-overlapping spatial patches of ``x`` into the channel axis.

    ``x`` is read as a 4-D tensor whose axes 2 and 3 are height and width.
    It is zero-padded so both become multiples of the patch size, then every
    ``patch_size_height x patch_size_width`` patch is flattened into
    channels.

    Args:
        x: 4-D tensor ``(b, n_filters, height, width)``.

    Returns:
        Contiguous tensor of shape
        ``(b, patch_size_height * patch_size_width * n_filters,
        padded_height // patch_size_height,
        padded_width // patch_size_width)``.

    Note:
        When a side is already an exact multiple of the patch size, one full
        extra patch of padding is still added (the padding amount is
        ``patch - size % patch``). This is kept as-is so output shapes stay
        the same for existing callers.
    """
    patch_h = self.patch_size_height
    patch_w = self.patch_size_width

    n_height_padding = patch_h - x.size(2) % patch_h
    n_width_padding = patch_w - x.size(3) % patch_w

    # Floor division keeps the amounts integral: ``/`` yields floats on
    # Python 3, which F.pad and Tensor.view both reject.
    n_top_padding = n_height_padding // 2
    n_bottom_padding = n_height_padding - n_top_padding
    n_left_padding = n_width_padding // 2
    n_right_padding = n_width_padding - n_left_padding

    # F.pad takes (left, right, top, bottom) for the last two axes.
    x = F.pad(x, (n_left_padding, n_right_padding,
                  n_top_padding, n_bottom_padding))

    b, n_filters, n_height, n_width = x.size()
    assert n_height % patch_h == 0
    assert n_width % patch_w == 0

    new_height = n_height // patch_h
    new_width = n_width // patch_w

    # Split each spatial axis into (patch index, offset inside patch) ...
    x = x.view(b, n_filters, new_height, patch_h, new_width, patch_w)
    # ... bring the patch indices forward and the per-patch content last ...
    x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
    # ... flatten the content of each patch and move it to the channel axis.
    x = x.view(b, new_height, new_width, patch_h * patch_w * n_filters)
    x = x.permute(0, 3, 1, 2).contiguous()
    return x
def decode_step(self, enc_hs, enc_mask, input_, hidden):
    """Advance the decoder one step and score transitions and emissions.

    Args:
        enc_hs: indexable pair; ``enc_hs[0]`` feeds the emission context and
            ``enc_hs[1]`` the transition context. Both are transposed on
            axes 0/1 before use.
        enc_mask: 2-D mask; its first axis is the source length.
        input_: decoder input for this step, passed to ``self.dec_rnn``.
        hidden: previous decoder state, passed to ``self.dec_rnn``.

    Returns:
        ``(trans, emiss, hidden)``: log transition weights, log-softmax
        emission scores, and the new decoder state.
    """
    src_seq_len, _ = enc_mask.shape
    h_t, hidden = self.dec_rnn(input_, hidden)

    # The decoder output is repeated once per source position so it can be
    # concatenated with every encoder state.
    query = h_t.unsqueeze(1).expand(-1, src_seq_len, -1)

    # Transition scores over a window of ``wid_siz`` relative offsets.
    window_scores = F.softmax(
        self.trans(torch.cat((query, enc_hs[1].transpose(0, 1)), dim=2)),
        dim=-1)
    half = (self.wid_siz - 1) // 2
    # Row ``i`` is shifted so its window lines up with source position ``i``.
    # Negative pad amounts crop the part of the window that falls off the
    # sequence.
    shifted_rows = []
    for i, row in enumerate(window_scores.split(1, dim=1)):
        shifted_rows.append(
            F.pad(row, (-half + i, src_seq_len - (half + 1) - i)))
    trans = torch.cat(shifted_rows, dim=1)

    # Apply the mask, keep every entry strictly positive, renormalise, then
    # move to log space.
    trans = trans * enc_mask.transpose(0, 1).unsqueeze(1) + EPSILON
    trans = trans / trans.sum(-1, keepdim=True)
    trans = trans.log()

    # Emission scores per source position.
    ctx = torch.tanh(
        self.linear_out(
            torch.cat((query, enc_hs[0].transpose(0, 1)), dim=2)))
    emiss = F.log_softmax(self.final_out(ctx), dim=-1)
    return trans, emiss, hidden
def forward(self, x):
    """Pad ``x`` with the amounts from ``compute_pad``, then convolve.

    The per-axis total from ``self.compute_pad(axis, size)`` is split into a
    front half (floor) and a back half (remainder), so an odd total puts the
    extra element at the back.

    Args:
        x: 5-D tensor ``(batch, channel, t, h, w)``.

    Returns:
        Output of ``self.conv3d`` on the padded input, followed by
        ``self.bn`` when ``self._use_batch_norm`` is set and by
        ``self._activation_fn`` when it is not ``None``.
    """
    _, _, t, h, w = x.size()

    # The former out_t/out_h/out_w computations were never read and have
    # been removed.
    pad_t = self.compute_pad(0, t)
    pad_h = self.compute_pad(1, h)
    pad_w = self.compute_pad(2, w)

    pad_t_f = pad_t // 2
    pad_h_f = pad_h // 2
    pad_w_f = pad_w // 2

    # F.pad consumes (front, back) pairs starting from the LAST axis, hence
    # the w, h, t order.
    pad = (pad_w_f, pad_w - pad_w_f,
           pad_h_f, pad_h - pad_h_f,
           pad_t_f, pad_t - pad_t_f)
    x = F.pad(x, pad)

    x = self.conv3d(x)
    if self._use_batch_norm:
        x = self.bn(x)
    if self._activation_fn is not None:
        x = self._activation_fn(x)
    return x
def forward(self, x):
    """Pad ``x`` with the amounts from ``compute_pad``, then run the parent pooling.

    The per-axis total from ``self.compute_pad(axis, size)`` is split into a
    front half (floor) and a back half (remainder), so an odd total puts the
    extra element at the back.

    Args:
        x: 5-D tensor ``(batch, channel, t, h, w)``.

    Returns:
        Result of the parent class ``forward`` on the padded tensor.
    """
    _, _, t, h, w = x.size()

    # The former out_t/out_h/out_w computations were never read and have
    # been removed.
    pad_t = self.compute_pad(0, t)
    pad_h = self.compute_pad(1, h)
    pad_w = self.compute_pad(2, w)

    pad_t_f = pad_t // 2
    pad_h_f = pad_h // 2
    pad_w_f = pad_w // 2

    # F.pad consumes (front, back) pairs starting from the LAST axis, hence
    # the w, h, t order.
    pad = (pad_w_f, pad_w - pad_w_f,
           pad_h_f, pad_h - pad_h_f,
           pad_t_f, pad_t - pad_t_f)
    x = F.pad(x, pad)
    return super(MaxPool3dSamePadding, self).forward(x)
def forward(self, x1, x2):
    """Upsample ``x1``, pad ``x2`` to the same spatial size, concatenate, convolve.

    Args:
        x1: tensor passed through ``self.up``; axes 2 and 3 of the result
            define the target height and width.
        x2: tensor padded on axes 2 and 3 to match the upsampled ``x1``.

    Returns:
        ``self.conv`` applied to ``torch.cat([x2, x1], dim=1)``.
    """
    x1 = self.up(x1)

    diff_h = x1.size(2) - x2.size(2)
    diff_w = x1.size(3) - x2.size(3)

    # F.pad takes (left, right, top, bottom): the width pair comes first.
    # The second element of each pair takes the remainder so odd differences
    # are fully covered. Previously the height difference was applied to the
    # width axis and ``diff // 2`` was used on both sides, so non-square or
    # odd mismatches made the concatenation below fail.
    left = diff_w // 2
    top = diff_h // 2
    x2 = F.pad(x2, (left, diff_w - left, top, diff_h - top))

    x = torch.cat([x2, x1], dim=1)  # residual connection
    x = self.conv(x)
    return x
def occlusion_sensitivity(
    model, images, ids, mean=None, patch=35, stride=1, n_batches=128
):
    """
    "Grad-CAM: Visual Explanations from Deep Networks
    via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure A5 on page 17

    Originally proposed in:
    "Visualizing and Understanding Convolutional Networks"
    https://arxiv.org/abs/1311.2901

    Slide a ``patch``-sized window filled with ``mean`` over the padded
    images and record how the selected class score changes relative to the
    un-occluded score.

    Args:
        model: callable returning per-class scores; switched to eval mode.
        images: 4-D batch ``(B, C, H, W)``.
        ids: index tensor used with ``gather(1, ids)`` to select the scores
            of interest.
        mean: fill value for padding and occlusion; falsy values mean 0.
        patch: window size, an int or an ``(height, width)`` sequence.
        stride: step between neighbouring window positions.
        n_batches: number of window positions evaluated per forward pass.

    Returns:
        Tensor of shape ``(B, new_H, new_W)`` holding occluded score minus
        baseline score for every window position.
    """
    model.eval()
    mean = mean if mean else 0
    patch_H, patch_W = patch if isinstance(patch, Sequence) else (patch, patch)
    pad_H, pad_W = patch_H // 2, patch_W // 2

    # Gradients are disabled only for the duration of this call. The
    # previous global torch.set_grad_enabled(False) stayed in effect after
    # returning and silently broke any later backward pass.
    with torch.no_grad():
        # Padded image
        images = F.pad(images, (pad_W, pad_W, pad_H, pad_H), value=mean)
        B, _, H, W = images.shape
        new_H = (H - patch_H) // stride + 1
        new_W = (W - patch_W) // stride + 1

        # Top-left corners of every window, row-major, starting at 0.
        # Previously the column was advanced before being recorded, so every
        # window sat one stride too far to the right.
        anchors = [
            (grid_h, grid_w)
            for grid_h in range(0, H - patch_H + 1, stride)
            for grid_w in range(0, W - patch_W + 1, stride)
        ]

        # Baseline score without occlusion
        baseline = model(images).detach().gather(1, ids)

        # Compute per-pixel logits
        scoremaps = []
        for i in tqdm(range(0, len(anchors), n_batches), leave=False):
            batch_images = []
            batch_ids = []
            for grid_h, grid_w in anchors[i : i + n_batches]:
                images_ = images.clone()
                images_[
                    ...,
                    grid_h : grid_h + patch_H,
                    grid_w : grid_w + patch_W,
                ] = mean
                batch_images.append(images_)
                batch_ids.append(ids)
            batch_images = torch.cat(batch_images, dim=0)
            batch_ids = torch.cat(batch_ids, dim=0)
            scores = model(batch_images).detach().gather(1, batch_ids)
            # One chunk of B rows per window position.
            scoremaps += list(torch.split(scores, B))

        diffmaps = torch.cat(scoremaps, dim=1) - baseline
        diffmaps = diffmaps.view(B, new_H, new_W)

    return diffmaps
def mk_diff_img(image, channels_first=True):
    """Build an autograd function that samples ``image`` at rounded positions.

    The forward pass looks up ``image`` at integer-rounded, clamped
    ``(row, col)`` positions. The backward pass returns, per position, the
    incoming gradient multiplied by precomputed 3x3 finite-difference images
    of ``image``.

    Args:
        image: 2-D tensor (asserted). The kernels are moved to CUDA, so the
            image presumably has to live on a CUDA device - TODO confirm.
        channels_first: must be truthy (asserted); otherwise unused.

    Returns:
        An *instance* of the locally defined ``LocalFunction``.

    NOTE(review): autograd Functions with static ``forward``/``backward`` are
    normally invoked through ``.apply`` - confirm how callers use the
    returned instance.
    """
    assert channels_first
    assert len(image.shape) == 2
    # conv2d needs (batch, channel, H, W).
    image = image.unsqueeze(0).unsqueeze(0)
    # Difference along columns (right neighbours minus left neighbours).
    x_kernel = torch.Tensor([
        [-1, 0, 1],
        [-1, 0, 1],
        [-1, 0, 1]])
    x_kernel = x_kernel.view((1, 1, 3, 3)).cuda()
    # Difference along rows (lower neighbours minus upper neighbours).
    y_kernel = torch.Tensor([
        [-1, -1, -1],
        [0, 0, 0],
        [1, 1, 1]])
    y_kernel = y_kernel.view((1, 1, 3, 3)).cuda()
    # padded_image = F.pad(image, [1, 1, 1, 1], value=image.abs().max())
    # Zero border so the 3x3 convolutions keep the original H x W size.
    padded_image = F.pad(image, [1, 1, 1, 1], value=0)
    # Naming follows the positions tensor, not the image axes: column 0 of a
    # position indexes image rows, so ``diff_img_x`` uses the row-wise
    # kernel and ``diff_img_y`` the column-wise kernel.
    diff_img_x = F.conv2d(padded_image, y_kernel)
    diff_img_y = F.conv2d(padded_image, x_kernel)
    # Drop the batch and channel axes again.
    image = image.squeeze(0).squeeze(0)
    diff_img_x.squeeze_(0)
    diff_img_x.squeeze_(0)
    diff_img_y.squeeze_(0)
    diff_img_y.squeeze_(0)

    class LocalFunction(autograd.Function):
        def __init__(self):
            super().__init__()

        @staticmethod
        def forward(ctx, points_positions):
            # points_positions: 2-D, column 0 = row index, column 1 = column
            # index.
            assert len(points_positions.shape) == 2
            points_positions_detached = points_positions.detach().round().long()
            # Clamp in place so out-of-range points read the nearest border
            # pixel.
            points_positions_detached[:, 0].clamp_(0, image.shape[0] - 1)
            points_positions_detached[:, 1].clamp_(0, image.shape[1] - 1)
            ctx.save_for_backward(points_positions_detached)
            return image[points_positions_detached[:, 0], points_positions_detached[:, 1]]

        @staticmethod
        def backward(ctx, grad_outputs):
            points_positions_detached, = ctx.saved_tensors
            # Finite-difference values at the same pixels used in forward.
            d_x = diff_img_x[points_positions_detached[:, 0], points_positions_detached[:, 1]]
            d_y = diff_img_y[points_positions_detached[:, 0], points_positions_detached[:, 1]]
            # Gradient with respect to each (row, col) position.
            res = torch.zeros(points_positions_detached.shape).cuda()
            res[:, 0] = grad_outputs * d_x
            res[:, 1] = grad_outputs * d_y
            return res

    return LocalFunction()
def _extract_patches(x, kernel_size, stride, padding):
    """Collect the sliding-window patches of a 4-D tensor.

    Args:
        x: tensor ``(batch, channels, height, width)``.
        kernel_size: ``(kh, kw)`` window size.
        stride: ``(sh, sw)`` step between windows.
        padding: ``(ph, pw)`` zero padding applied on both sides of the
            height and width axes. When any padding is applied the padded
            tensor is taken through ``.data``, i.e. detached from autograd.

    Returns:
        Tensor ``(batch, out_h, out_w, channels * kh * kw)`` with one
        flattened patch per output location.
    """
    pad_h, pad_w = padding[0], padding[1]
    if pad_h + pad_w > 0:
        # F.pad order is (left, right, top, bottom).
        x = F.pad(x, (pad_w, pad_w, pad_h, pad_h)).data

    # After both unfolds the layout is (batch, channels, out_h, out_w, kh, kw).
    windows = x.unfold(2, kernel_size[0], stride[0])
    windows = windows.unfold(3, kernel_size[1], stride[1])

    # Move channels behind the two output-location axes; this equals the
    # original transpose(1, 2) followed by transpose(2, 3).
    windows = windows.permute(0, 2, 3, 1, 4, 5).contiguous()

    n, out_h, out_w = windows.size(0), windows.size(1), windows.size(2)
    flat = windows.size(3) * windows.size(4) * windows.size(5)
    return windows.view(n, out_h, out_w, flat)
def forward(self, x):
    """Convolve, pool, run the recurrent layer, and classify.

    Args:
        x: input to ``self.conv1``.

    Returns:
        Output of ``self.output`` on the ReLU-activated dense features.

    NOTE(review): the flatten width ``39 * 1024`` is hard-coded and assumes
    39 pooled time steps of 1024 channels each - confirm against the layer
    definitions.
    """
    conv_out = F.relu(self.conv1(x))
    # Left-pad the last axis by 14 before pooling.
    pooled = self.maxpool(F.pad(conv_out, (14, 0)))
    pooled = self.dropout_2(pooled)

    # (batch, channels, steps) -> (steps, batch, channels) for the LSTM.
    seq = pooled.permute(2, 0, 1)
    init_state = self.initHidden(seq.size(1))
    seq, _ = self.lstm(seq, init_state)
    seq = self.dropout_3(seq)

    # Back to batch-first, then one flat feature row per sample.
    flat = seq.transpose(1, 0).reshape(-1, 39 * 1024)
    dense = F.relu(self.linear(flat))
    return self.output(dense)
def forward(self, x, geneexpr):
    """Run the padded conv/pool stack on ``x`` and fuse with ``geneexpr`` via ``self.tucker``.

    Args:
        x: input to ``conv1``; its last axis is padded by 9 on each side
            before the first convolution.
        geneexpr: passed as first argument to ``self.tucker`` together with
            the flattened conv features.

    Returns:
        Output of ``self.output`` on the final hidden activations.
    """
    feats = F.relu(self.conv1(F.pad(x, (9, 9))))
    feats = self.maxpool_3(feats)

    feats = F.relu(self.conv2(F.pad(feats, (5, 5))))
    feats = self.maxpool_4(feats)

    feats = F.relu(self.conv3(F.pad(feats, (3, 3))))
    feats = self.maxpool_4(F.pad(feats, (1, 1)))

    # One row of ``flat_sz`` features per sample.
    flat = feats.view(-1, self.flat_sz)

    fused = self.dropout(F.relu(self.tucker(geneexpr, flat)))
    hidden = self.dropout(F.relu(self.linear(fused)))
    return self.output(hidden)
def forward(self, x, geneexpr):
    """Convolve, pool, run the recurrent layer, and fuse with ``geneexpr``.

    Args:
        x: input to ``self.conv1``.
        geneexpr: 2-D tensor projected by ``self.genelinear`` and
            concatenated with the flattened recurrent features on ``dim=1``.

    Returns:
        Output of ``self.output`` on the ReLU-activated dense features.

    NOTE(review): the flatten width ``45 * 320`` is hard-coded and assumes
    45 pooled time steps of 320 channels each - confirm against the layer
    definitions.
    """
    conv_out = F.relu(self.conv1(x))
    # Left-pad the last axis by 14 before pooling.
    pooled = self.maxpool(F.pad(conv_out, (14, 0)))
    pooled = self.dropout_2(pooled)

    # (batch, channels, steps) -> (steps, batch, channels) for the LSTM.
    seq = pooled.permute(2, 0, 1)
    init_state = self.initHidden(seq.size(1))
    seq, _ = self.lstm(seq, init_state)
    seq = self.dropout_3(seq)

    # Back to batch-first, then one flat feature row per sample.
    flat = seq.transpose(1, 0).reshape(-1, 45 * 320)

    gene = self.dropout_4(F.relu(self.genelinear(geneexpr)))

    fused = torch.cat([flat, gene], dim=1)
    dense = F.relu(self.linear(fused))
    return self.output(dense)
def deepcompare_2ch2stream(input, params):
    """Score ``input`` with two parallel convolutional streams.

    The ``'fovea'`` stream sees a 2x average-pooled copy of the input; the
    ``'retina'`` stream sees the input with 16 pixels cropped from every
    side of the last two axes (negative padding). The flattened stream
    outputs are concatenated and passed through the ``'fc0'`` and ``'fc1'``
    layers.

    Args:
        input: image batch, indexed as a 4-D tensor by the pooling ops.
        params: parameter container handed to the ``conv2d``/``linear``
            helpers together with the layer name.

    Returns:
        Output of the ``'fc1'`` linear layer.
    """
    # (layer suffix, whether a 2x2 max-pool follows the ReLU)
    layer_plan = (
        ('.conv0', True),
        ('.conv1', True),
        ('.conv2', False),
        ('.conv3', False),
    )

    def stream(input, name):
        feats = input
        for suffix, pooled in layer_plan:
            feats = F.relu(conv2d(feats, params, name + suffix))
            if pooled:
                feats = F.max_pool2d(feats, 2, 2)
        return feats.view(feats.size(0), -1)

    fovea_feats = stream(F.avg_pool2d(input, 2, 2), 'fovea')
    retina_feats = stream(F.pad(input, (-16,) * 4), 'retina')

    joint = torch.cat([fovea_feats, retina_feats], dim=1)
    hidden = linear(joint, params, 'fc0')
    return linear(F.relu(hidden), params, 'fc1')
def forward(self, x, geneexpr):
    """Run the recurrent layer first, then convolve, pool, and fuse with ``geneexpr``.

    Args:
        x: batch-first input ``(batch, channels, steps)``; it is permuted to
            step-first for the LSTM and back to batch-first afterwards.
        geneexpr: 2-D tensor projected by ``self.genelinear`` and
            concatenated with the flattened conv features on ``dim=1``.

    Returns:
        Output of ``self.output`` on the ReLU-activated dense features.

    NOTE(review): the flatten width ``39 * 1024`` is hard-coded and assumes
    1024 channels over 39 pooled positions - confirm against the layer
    definitions.
    """
    out = x.permute(2, 0, 1)  # (steps, batch, channels)
    out, _ = self.lstm(out, self.initHidden(out.size(1)))
    out = self.dropout_2(out)
    out = out.permute(1, 2, 0)  # (batch, lstm features, steps)

    out = F.relu(self.conv1(out))
    out = F.pad(out, (14, 0))  # left-pad the last axis before pooling
    out = self.maxpool(out)
    out = self.dropout_2(out)

    # ``out`` is batch-first here, so it is flattened directly. The previous
    # ``transpose(1, 0)`` before the reshape put channels first and
    # interleaved features of different samples in every row whenever
    # batch > 1. For batch == 1 the feature ordering is identical.
    out = out.reshape(-1, 39 * 1024)

    geneexpr = F.relu(self.genelinear(geneexpr))
    geneexpr = self.dropout_4(geneexpr)

    out = torch.cat([out, geneexpr], dim=1)
    out = F.relu(self.linear(out))
    return self.output(out)
def forward(self, x, geneexpr):
    """Attend over the recurrent states using ``geneexpr`` as the query.

    The conv/pool/LSTM stack yields one state per pooled position. The
    projected gene-expression vector is dotted with every state, the scores
    are soft-maxed over positions, and the attention-weighted sum of states
    is fed to the dense head.

    Args:
        x: input to ``self.conv1``.
        geneexpr: 2-D tensor; after ``self.genelinear`` its width must match
            the LSTM feature size for the batched matrix product.

    Returns:
        Output of ``self.output``.
    """
    conv_out = F.relu(self.conv1(x))
    # Left-pad the last axis by 14 before pooling.
    pooled = self.maxpool(F.pad(conv_out, (14, 0)))
    pooled = self.dropout_2(pooled)

    # (batch, channels, steps) -> (steps, batch, channels) for the LSTM.
    seq = pooled.permute(2, 0, 1)
    seq, _ = self.lstm(seq, self.initHidden(seq.size(1)))
    seq = self.dropout_3(seq)
    states = seq.permute(1, 0, 2)  # (batch, steps, features)

    query = F.relu(self.genelinear(geneexpr)).unsqueeze(2)  # (batch, features, 1)

    scores = torch.bmm(states, query)  # (batch, steps, 1)
    weights = F.softmax(scores, dim=1)  # normalised over steps
    context = torch.bmm(weights.transpose(2, 1), states).squeeze(1)

    hidden = self.dropout_4(F.relu(self.linear(context)))
    return self.output(hidden)
def build_sequences(sequences, nenvs, nsteps, depth, return_mask=False, offset=0):
    """Cut every sequence into zero-padded sliding windows of length ``depth``.

    Each input is reshaped to ``(nenvs, nsteps, size)``. For every
    environment and every step ``t`` the window
    ``[t + offset, t + offset + depth)`` along the step axis is extracted.
    Positions past the end of the rollout read as zeros.

    Args:
        sequences: iterable of tensors viewable as ``(nenvs, nsteps, -1)``,
            e.g. rewards, actions or state representations.
        nenvs: number of environments.
        nsteps: number of steps per environment.
        depth: window length.
        return_mask: when true, an extra all-ones sequence shaped like the
            first input is processed too and appended to the result. Its
            zeros mark where the windows were padded.
        offset: shift of the window start relative to ``t``.

    Returns:
        List with one ``(nenvs * nsteps, depth, size)`` tensor per input
        (plus the mask when requested), ordered environment-major.
    """
    per_env = [seq.view(nenvs, nsteps, -1) for seq in sequences]
    if return_mask:
        per_env.append(torch.ones_like(per_env[0]).float())

    # Pad only the end of the step axis so the last windows can run off the
    # rollout.
    tail = depth + offset
    padded = [
        F.pad(seq, (0, 0, 0, tail, 0, 0), mode="constant", value=0).data
        for seq in per_env
    ]

    windows = []
    for seq in padded:
        slices = [
            seq[env, t + offset:t + offset + depth, :]
            for env in range(seq.shape[0])
            for t in range(nsteps)
        ]
        windows.append(torch.stack(slices))
    return windows
def __init__(self, in_planes, planes, stride=1, option='A'):
    """Build a two-convolution residual block with a configurable shortcut.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of both 3x3 convolutions.
        stride: stride of the first convolution.
        option: shortcut type used when the input and output shapes differ.
            ``'A'`` sub-samples spatially by 2 and zero-pads the channel
            axis by ``planes // 4`` on each side; ``'B'`` uses a strided
            1x1 convolution followed by batch norm. Any other value leaves
            the empty (identity) shortcut.
    """
    super(BasicBlock, self).__init__()

    def conv3x3(channels_in, channels_out, step):
        return nn.Conv2d(channels_in, channels_out, kernel_size=3,
                         stride=step, padding=1, bias=False)

    # Creation order is kept (conv1, bn1, conv2, bn2, shortcut) so parameter
    # initialisation and state-dict ordering do not change.
    self.conv1 = conv3x3(in_planes, planes, stride)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = conv3x3(planes, planes, 1)
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride == 1 and in_planes == planes:
        # Shapes already match: keep the identity shortcut.
        return

    if option == 'A':
        # For CIFAR10 ResNet paper uses option A.
        def subsample_and_pad(x):
            # Take every second pixel, then zero-pad the channel axis.
            return F.pad(x[:, :, ::2, ::2],
                         (0, 0, 0, 0, planes // 4, planes // 4),
                         "constant", 0)

        self.shortcut = LambdaLayer(subsample_and_pad)
    elif option == 'B':
        out_planes = self.expansion * planes
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1,
                      stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )
def forward(self, src_tokens, src_lengths):
    """Encode ``src_tokens`` with a stack of gated temporal convolutions.

    Args:
        src_tokens: token ids passed to the token and position embeddings.
        src_lengths: accepted for interface compatibility; not read here.

    Returns:
        Dict with key ``'encoder_out'`` mapping to ``(x, y)``: ``x`` is the
        projected convolutional output and ``y`` is ``x`` combined with the
        input embedding and scaled by ``sqrt(0.5)``.
    """
    half_scale = math.sqrt(0.5)

    # embed tokens and positions
    embedded = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
    embedded = F.dropout(embedded, p=self.dropout, training=self.training)
    input_embedding = embedded.transpose(0, 1)

    # project to size of convolution, then B x T x C -> T x B x C
    x = self.fc1(embedded).transpose(0, 1)

    # temporal convolutions
    layers = zip(self.projections, self.convolutions, self.attention)
    for proj, conv, attention in layers:
        # The residual branch is taken before dropout is applied.
        residual = proj(x) if proj is not None else x

        x = F.dropout(x, p=self.dropout, training=self.training)

        # Pad the time axis (axis 0) so the convolution keeps its length;
        # for even kernel sizes the extra element goes on the right.
        kernel_width = conv.kernel_size[0]
        pad_left = (kernel_width - 1) // 2
        pad_right = kernel_width // 2
        x = F.pad(x, (0, 0, 0, 0, pad_left, pad_right))

        x = F.glu(conv(x), dim=2)
        if attention is not None:
            x = attention(x)
        x = (x + residual) * half_scale

    # T x B x C -> B x T x C, then project back to size of embedding
    x = self.fc2(x.transpose(1, 0))

    # scale gradients (this only affects backward, not forward)
    x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

    # add output to input embedding for attention
    y = (x + input_embedding.transpose(0, 1)) * half_scale

    return {
        'encoder_out': (x, y),
    }
def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    """2-D convolution whose output size is ``ceil(input_size / stride)``.

    The total padding per axis is derived from the input size, kernel size,
    stride and dilation. When the total is odd, the extra row/column is
    added at the bottom/right before the symmetric part is handed to
    ``F.conv2d``.

    Args:
        input: 4-D tensor ``(batch, channels, rows, cols)``.
        weight: 4-D kernel ``(out_channels, in_channels / groups, kh, kw)``.
        bias: optional bias forwarded to ``F.conv2d``.
        stride: int or ``(rows, cols)`` pair.
        padding: ignored; kept for signature compatibility.
        dilation: int or ``(rows, cols)`` pair.
        groups: forwarded to ``F.conv2d``.

    Returns:
        The convolution result.
    """
    def as_pair(value):
        # Accept plain ints (including the defaults) as well as sequences.
        if isinstance(value, int):
            return (value, value)
        value = tuple(value)
        return value * 2 if len(value) == 1 else value

    def total_padding(in_size, kernel, step, rate):
        out_size = (in_size + step - 1) // step
        return max(0, (out_size - 1) * step + (kernel - 1) * rate + 1 - in_size)

    stride = as_pair(stride)
    dilation = as_pair(dilation)

    # Rows and columns are computed independently. Previously the column
    # padding reused the row quantities, which was wrong for non-square
    # inputs, kernels, strides or dilations.
    padding_rows = total_padding(input.size(2), weight.size(2),
                                 stride[0], dilation[0])
    padding_cols = total_padding(input.size(3), weight.size(3),
                                 stride[1], dilation[1])

    rows_odd = padding_rows % 2 != 0
    cols_odd = padding_cols % 2 != 0
    if rows_odd or cols_odd:
        # F.pad order is (left, right, top, bottom): the odd element goes on
        # the right / bottom.
        input = F.pad(input, [0, int(cols_odd), 0, int(rows_odd)])

    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)
def get_P_net_res(self, x):
    """Run ``self.P_net`` over an image pyramid and return the merged detections.

    The input is rescaled by ``self.scale ** level`` for increasing
    ``level`` until either side drops below 12 pixels. Every level is
    padded up to a multiple of 12 with the value 127.5, normalised and fed
    to ``self.P_net``. Boxes above the confidence threshold are decoded,
    clipped to the image, filtered by minimum size and suppressed per level
    (IoU 0.5), then suppressed again across all levels.

    Args:
        x: image batch indexed as ``(batch, channels, height, width)``.

    Returns:
        Tensor ``(n, 5)`` of ``[box coordinates..., score]`` rows; an empty
        ``(0, 5)`` tensor when nothing survives.
    """
    img_h, img_w = x.shape[2:]
    # A single ROI covering the whole image; ROIAlign is used for resizing.
    full_roi = cuda(torch.tensor([[0, 0, 0, img_w, img_h]]).float())

    level = 0
    level_boxes = []
    level_scores = []
    while True:
        factor = self.scale ** level
        n_h, n_w = int(img_h * factor), int(img_w * factor)
        if n_h < 12 or n_w < 12:
            break

        resized = ROIAlign((n_h, n_w), 1 / 1., 2)(x, full_roi)
        # resized = F.interpolate(x, size=(n_h, n_w))

        # Pad bottom/right up to the next multiple of 12.
        padded_h = int(np.ceil(n_h / 12.) * 12)
        padded_w = int(np.ceil(n_w / 12.) * 12)
        resized = F.pad(resized, (0, padded_w - n_w, 0, padded_h - n_h),
                        mode='constant', value=127.5)
        resized = (resized - 127.5) / 128.0

        P_net_logits, P_net_loc, P_net_landmarks = self.P_net(resized)
        map_H, map_W = P_net_logits.shape[2:]

        # Channels-last, one row per feature-map location.
        P_net_logits = P_net_logits.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        P_net_loc = P_net_loc.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        P_net_landmarks = P_net_landmarks.permute(0, 2, 3, 1).contiguous().view(-1, 10)

        # Anchors are rescaled to original-image coordinates.
        anchors = self.anchors[:map_H, :map_W].contiguous().view(-1, 4) / factor

        # Advance now so every ``continue`` below moves to the next level.
        level += 1

        score = F.softmax(P_net_logits, dim=-1)[..., 1]
        confident = score >= self.config.P_net_conf_thresh
        if confident.sum() == 0:
            continue
        score = score[confident]
        bboxes = loc2bbox(P_net_loc[confident], anchors[confident])

        # Clip x coordinates (columns 0, 2) and y coordinates (columns 1, 3).
        bboxes[..., 0:4:2] = torch.clamp(bboxes[..., 0:4:2], 0, img_w - 1)
        bboxes[..., 1:4:2] = torch.clamp(bboxes[..., 1:4:2], 0, img_h - 1)

        # Drop boxes with a side smaller than the minimum ROI size.
        sides = bboxes[..., 2:4] - bboxes[..., :2]
        big_enough = (sides >= self.config.roi_min_size[0]).all(dim=-1)
        if big_enough.sum() == 0:
            continue
        bboxes = bboxes[big_enough]
        score = score[big_enough]

        score, order = score.sort(descending=True)
        bboxes = bboxes[order]
        keep = _box_nms(bboxes, score, 0.5)
        level_scores.append(score[keep])
        level_boxes.append(bboxes[keep])

    if len(level_boxes) == 0:
        return cuda(torch.zeros((0, 5)))

    # Merge all pyramid levels and suppress once more across scales.
    bboxes = torch.cat(level_boxes, dim=0)
    score = torch.cat(level_scores, dim=0)
    score, order = score.sort(descending=True)
    bboxes = bboxes[order]
    keep = _box_nms(bboxes, score, self.config.P_net_iou_thresh)
    bboxes = bboxes[keep]
    score = score[keep]
    return torch.cat([bboxes, score.view(-1, 1)], dim=1)
def main():
    """Parse arguments, build data/model/optimizer, then train, prune and checkpoint.

    The dataset is the first ``_``-separated segment of ``args.model``
    (``mnist``, ``cifar10`` or ``imagenet``). Learning-rate, rewiring and
    temperature schedules are read from the YAML file
    ``args.schedule_file``. A checkpoint is saved before training
    (``<filename>_initial``) and after every epoch.

    Raises:
        RuntimeError: for an unknown dataset prefix or model name.
        Exception: when the imagenet dataset is selected without ``--data``.
        yaml.YAMLError: when the schedule file cannot be parsed.
    """
    global args, best_prec1
    args = parser.parse_args()
    setup_logger(args)

    if args.fp16:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except ImportError:
            print_and_log(
                'WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

    kwargs = {'num_workers': 1, 'pin_memory': True}
    dataset = args.model.split('_')[0]
    if dataset == 'mnist':
        full_dataset = datasets.MNIST('./data',
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))

        if not (args.validate_set):
            train_loader = torch.utils.data.DataLoader(
                full_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = None
        else:
            train_dataset = split_dataset(full_dataset, split_end=50000)
            val_dataset = split_dataset(full_dataset, split_start=50000)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            './data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  **kwargs)

    elif dataset == 'cifar10':
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if args.augment:
            # Reflect-pad by 4 pixels, then take a random 32x32 crop and a
            # random horizontal flip.
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                                  (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        full_dataset = datasets.CIFAR10('./data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)

        if not (args.validate_set):
            train_loader = torch.utils.data.DataLoader(
                full_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = None
        else:
            train_dataset = split_dataset(full_dataset, split_end=45000)
            val_dataset = split_dataset(full_dataset, split_start=45000)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './data', train=False, transform=transform_test),
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)

    elif dataset == 'imagenet':
        if not (args.data):
            raise Exception(
                'need to specify imagenet dataset location using the --data argument'
            )
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        full_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        if not (args.validate_set):
            train_loader = torch.utils.data.DataLoader(
                full_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)
            val_loader = None
        else:
            # The last 10000 training images are held out for validation.
            train_dataset = split_dataset(full_dataset,
                                          split_end=len(full_dataset) - 10000)
            val_dataset = split_dataset(full_dataset,
                                        split_start=len(full_dataset) - 10000)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)

        test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

    else:
        raise RuntimeError(
            'Unknown dataset {}. Dataset is first segment of network name'.
            format(dataset))

    print_and_log(args)
    with open(args.schedule_file, 'r') as stream:
        try:
            # safe_load: the schedule is plain data, and yaml.load without an
            # explicit Loader is rejected by current PyYAML releases.
            loaded_schedule = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print_and_log(exc)
            # The schedule is required further down; fail here instead of
            # with a NameError after the model has been built.
            raise

    if args.model == 'mnist_mlp':
        model = mnist_mlp(initial_sparsity=args.initial_sparsity_fc,
                          sparse=not (args.tied),
                          no_batch_norm=args.no_batch_norm)
    elif args.model == 'cifar10_WideResNet':
        model = cifar10_WideResNet(
            args.layers,
            widen_factor=args.widen_factor,
            initial_sparsity_conv=args.initial_sparsity_conv,
            initial_sparsity_fc=args.initial_sparsity_fc,
            sub_kernel_granularity=args.sub_kernel_granularity,
            sparse=not (args.tied))
    elif args.model == 'imagenet_resnet50':
        model = imagenet_resnet50(
            initial_sparsity_conv=args.initial_sparsity_conv,
            initial_sparsity_fc=args.initial_sparsity_fc,
            sub_kernel_granularity=args.sub_kernel_granularity,
            widen_factor=args.widen_factor,
            vanilla_conv1=True,
            vanilla_conv3=True,
            vanilla_downsample=True,
            sparse=not args.sparse_momentum)
    else:
        raise RuntimeError('unrecognized model name ' + repr(args.model))

    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    if args.fp16:
        print_and_log('FP16')
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=None,
                                   dynamic_loss_scale=True,
                                   dynamic_loss_args={'init_scale': 2**16})
        model = model.half()

    mask = None
    if not args.dense:
        decay = CosineDecay(args.prune_rate, len(train_loader) * (args.epochs))
        mask = Masking(optimizer,
                       decay,
                       prune_rate=args.prune_rate,
                       prune_mode='magnitude',
                       growth_mode=args.growth,
                       redistribution_mode=args.redistribution,
                       verbose=True,
                       fp16=args.fp16)
        mask.add_module(model, density=args.density)
        #mask.remove_weight_partial_name('downsample', verbose=True)
        #mask.remove_weight('conv1.weight')

    if dataset == 'imagenet':
        print_and_log('setting up data parallel')
        model = torch.nn.DataParallel(model).cuda()
        base_model = model.module
    else:
        base_model = model

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_and_log("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                print_and_log('OPTIM')
                # ``mask`` is None for dense runs; only re-attach when the
                # sparse mask exists.
                if mask is not None:
                    mask.optimizer = optimizer
            print_and_log("=> loaded checkpoint '{}' ".format(args.resume))
        else:
            print_and_log("=> no checkpoint found at '{}'".format(args.resume))

    if args.copy_mask_from:
        if os.path.isfile(args.copy_mask_from):
            print_and_log("=> loading mask data '{}'".format(
                args.copy_mask_from))
            mask_data = torch.load(args.copy_mask_from)
            # Only entries whose key contains 'mask' are copied.
            filtered_mask_data = collections.OrderedDict([
                (x, y) for (x, y) in mask_data['state_dict'].items()
                if 'mask' in x
            ])
            model.load_state_dict(filtered_mask_data, strict=False)
        else:
            print_and_log("=> no mask checkpoint found at '{}'".format(
                args.copy_mask_from))

    # get the number of model parameters
    model_size = base_model.get_model_size()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    train_loss_l = []
    test_loss_l = []
    train_prec1_l = []
    test_prec1_l = []
    train_prec5_l = []
    test_prec5_l = []

    val_loss_l = []
    val_prec1_l = []
    val_prec5_l = []

    prune_mode = args.prune_mode
    print_and_log('PRUNE MODE ' + str(prune_mode))

    start_pruning_after_epoch_n = args.start_pruning_after_epoch
    prune_every_epoch_n = args.prune_epoch_frequency
    prune_iterations = args.prune_iterations
    post_prune_epochs = args.post_prune_epochs

    filename = args.model + '_' + repr(args.job_idx)
    n_prunes_done = 0

    if prune_mode:
        ## Special consideration so that pruning mnist_mlp does not use less than 100 parameters in the top layer after pruning
        if args.prune_target_sparsity_fc > 0.9 and args.model == 'mnist_mlp':
            total_available_weights = (1. - args.prune_target_sparsity_fc) * (
                784 * 300 + 300 * 100 + 100 * 10) - 100
            prune_target_sparsity_special = 0.9
            prune_target_sparsity_fc = 1. - total_available_weights / (
                784 * 300 + 300 * 100)
        else:
            prune_target_sparsity_fc = prune_target_sparsity_special = args.prune_target_sparsity_fc

        # Constant per-iteration fraction that reaches the target sparsity
        # after ``prune_iterations`` rounds.
        prune_fraction_fc = 1.0 - (1 - prune_target_sparsity_fc)**(
            1.0 / prune_iterations)
        prune_fraction_conv = 1.0 - (1 - args.prune_target_sparsity_conv)**(
            1.0 / prune_iterations)

        prune_fraction_fc_special = 1.0 - (
            1 - prune_target_sparsity_special)**(1.0 / prune_iterations)

        cubic_pruning_multipliers = (
            1 - np.arange(prune_iterations + 1) / prune_iterations)**3.0

        def get_prune_fraction_cubic(current_prune_iter, final_sparsity):
            return 1 - (1 - final_sparsity + final_sparsity *
                        cubic_pruning_multipliers[current_prune_iter + 1]) / (
                            1 - final_sparsity + final_sparsity *
                            cubic_pruning_multipliers[current_prune_iter])

        nEpochs_to_prune = int(start_pruning_after_epoch_n +
                               prune_every_epoch_n *
                               (prune_iterations - 1)) + post_prune_epochs
        print_and_log(
            'prune fraction fc : {} , prune_fraction conv : {} '.format(
                prune_fraction_fc, prune_fraction_conv))
        print_and_log('nepochs ' + repr(nEpochs_to_prune))

        filename += '_target_' + repr(
            args.prune_target_sparsity_fc) + ',' + repr(
                args.prune_target_sparsity_conv)
        # NOTE(review): this block was reconstructed from whitespace-collapsed
        # text; confirm that this initial validation belongs inside the
        # ``if prune_mode`` branch.
        validate(test_loader, model, criterion, 1, 'validate')

    save_checkpoint(
        {
            'model_size': base_model.get_model_size(),
            'model_name': args.model,
            'state_dict': model.state_dict(),
            'args': args
        },
        filename=filename + '_initial')

    current_iteration = 0
    lr_schedule = loaded_schedule['lr_schedule']
    rewire_schedule = loaded_schedule['rewire_period_schedule']
    DeepR_temperature_schedule = loaded_schedule['DeepR_temperature_schedule']
    threshold = 1.0e-3
    if args.resume:
        print_and_log("Validating...")
        validate(test_loader, model, criterion, 1, 'validate')
    for epoch in range(args.start_epoch,
                       nEpochs_to_prune if prune_mode else args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)
        rewire_period = get_schedule_val(rewire_schedule, epoch)
        DeepR_temperature = get_schedule_val(DeepR_temperature_schedule, epoch)
        print_and_log('rewiring every {} iterations'.format(rewire_period))

        t1 = time.time()
        current_iteration, threshold = train(mask, train_loader, model,
                                             criterion, optimizer, epoch,
                                             current_iteration, rewire_period,
                                             DeepR_temperature, threshold)
        print_and_log('epoch time ' + repr(time.time() - t1))

        if prune_mode and epoch >= start_pruning_after_epoch_n and (
                epoch - start_pruning_after_epoch_n
        ) % prune_every_epoch_n == 0 and n_prunes_done < prune_iterations:
            if args.cubic_prune_schedule:
                base_model.prune(
                    get_prune_fraction_cubic(n_prunes_done,
                                             prune_target_sparsity_fc),
                    get_prune_fraction_cubic(n_prunes_done,
                                             args.prune_target_sparsity_conv),
                    get_prune_fraction_cubic(n_prunes_done,
                                             prune_target_sparsity_special))
            else:
                base_model.prune(prune_fraction_fc, prune_fraction_conv,
                                 prune_fraction_fc_special)
            n_prunes_done += 1
            # NOTE(review): nesting reconstructed from collapsed text; confirm
            # the model size is only logged after a pruning step.
            print_and_log(base_model.get_model_size())

        if not (args.no_validate_train):
            prec1_train, prec5_train, loss_train = validate(
                train_loader, model, criterion, epoch, 'train')
        else:
            prec1_train, prec5_train, loss_train = 0.0, 0.0, 0.0

        if args.validate_set:
            prec1_val, prec5_val, loss_val = validate(val_loader, model,
                                                      criterion, epoch,
                                                      'validate')
        else:
            prec1_val, prec5_val, loss_val = 0.0, 0.0, 0.0

        prec1_test, prec5_test, loss_test = validate(test_loader, model,
                                                     criterion, epoch, 'test')

        test_loss_l.append(loss_test)
        train_loss_l.append(loss_train)
        val_loss_l.append(loss_val)

        test_prec1_l.append(prec1_test)
        train_prec1_l.append(prec1_train)
        val_prec1_l.append(prec1_val)

        test_prec5_l.append(prec5_test)
        train_prec5_l.append(prec5_train)
        val_prec5_l.append(prec5_val)

        # remember best prec@1 and save checkpoint
        filenames = [filename]
        if epoch == args.stop_rewire_epoch:
            filenames += [filename + '_StopRewiringPoint_' + repr(epoch)]
        for f in filenames:
            save_checkpoint(
                {
                    'model_size': base_model.get_model_size(),
                    'test_loss': test_loss_l,
                    'train_loss': train_loss_l,
                    'val_loss': val_loss_l,
                    'test_prec1': test_prec1_l,
                    'train_prec1': train_prec1_l,
                    'val_prec1': val_prec1_l,
                    'test_prec5': test_prec5_l,
                    'train_prec5': train_prec5_l,
                    # Previously stored train_prec5_l under this key.
                    'val_prec5': val_prec5_l,
                    'model_name': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch + 1,
                    'args': args
                },
                filename=f)

        if not args.dense and epoch < args.epochs:
            mask.at_end_of_epoch()

    print_and_log('Best accuracy: ', best_prec1)
def concatPaddedSequences(seq1, seqLens1, seq2, seqLens2, padding='right'):
    '''
    Concatenates two padded input sequences of shape (batchSize, seqLength).

    For every batch element the valid (non-padding) tokens of `seq1` are
    followed by the valid tokens of `seq2`; the result is then zero-padded
    (on the same side as the inputs) to length
    `seq1.size(1) + seq2.size(1)`.

    Args:
        seq1, seqLens1 : First sequence tokens (batchSize, seqLength1) and
                         per-element lengths (batchSize)
        seq2, seqLens2 : Second sequence tokens (batchSize, seqLength2) and
                         per-element lengths (batchSize)
        padding        : Padding sense of input sequences - either 'right'
                         or 'left'

    Returns:
        Tensor of shape (batchSize, seqLength1 + seqLength2).

    Raises:
        ValueError   : if `padding` is neither 'left' nor 'right'.
        RuntimeError : if both sequences of some batch element are empty.
    '''
    if padding not in ('left', 'right'):
        raise ValueError("Expected padding to be either 'left' or 'right', "
                         "got '%s' instead." % padding)

    maxLen1 = seq1.size(1)
    maxLen2 = seq2.size(1)
    maxCatLen = maxLen1 + maxLen2
    batchSize = seq1.size(0)

    concat_list = []
    for b_idx in range(batchSize):
        # int() accepts 0-dim tensors, 1-element tensors and plain numbers.
        len_1 = int(seqLens1[b_idx])
        len_2 = int(seqLens2[b_idx])
        cat_len_ = len_1 + len_2
        if cat_len_ == 0:
            raise RuntimeError("Both input sequences are empty")
        if len_1 == 0:
            print("[Warning] Empty input sequence 1 given to "
                  "concatPaddedSequences")
        elif len_2 == 0:
            print("[Warning] Empty input sequence 2 given to "
                  "concatPaddedSequences")

        # Slicing with a zero length yields an empty tensor, so an empty
        # sequence needs no special casing in the concatenation below.
        if padding == 'left':
            valid_1 = seq1[b_idx][maxLen1 - len_1:]
            valid_2 = seq2[b_idx][maxLen2 - len_2:]
            pad_widths = (maxCatLen - cat_len_, 0)
        else:
            valid_1 = seq1[b_idx][:len_1]
            valid_2 = seq2[b_idx][:len_2]
            pad_widths = (0, maxCatLen - cat_len_)

        cat_ = torch.cat([valid_1, valid_2], 0)
        cat_padded = F.pad(input=cat_, pad=pad_widths, mode="constant",
                           value=0)
        concat_list.append(cat_padded.unsqueeze(0))

    return torch.cat(concat_list, 0)
def main(args):
    """Run annealed Langevin sampling on two signals and write the results.

    Two models are built and put in eval mode, one per signal. Two waveforms
    are loaded from hard-coded ``.npy`` paths, their sum (``mixed``) and both
    originals are written as wav files, and then ``x0``/``x1`` are updated for
    ``n_steps`` iterations with a gradient step on each model's
    ``smoothed_loss`` plus Gaussian noise, while ``sigma`` decays
    exponentially from ``start_sigma`` to ``end_sigma``. Checkpoints are
    reloaded whenever ``sigma`` drops below the current entry of ``sigmas``.
    The final ``x0``/``x1`` are written to ``out0.wav``/``out1.wav``.

    Args:
        args: mapping that provides the ``"<checkpoint0>"`` and
            ``"<checkpoint1>"`` directory entries.

    Relies on module-level names not defined in this function: ``device``,
    ``build_model``, ``SAMPLE_SIZE``, ``hparams``, ``checkpoints``,
    ``inv_linear_quantize``, ``sf``, ``join``.
    """
    model0 = build_model().to(device)
    model0.eval()
    model1 = build_model().to(device)
    model1.eval()
    receptive_field = model0.receptive_field

    x0_original = np.load("supra_piano/dump/dev/zf882fv0052-wave.npy")
    x0_original = x0_original[200000:200000 + SAMPLE_SIZE]
    # x0_original = np.load("vctk/dump/dev/p374_422-wave.npy")
    # x0_original = x0_original[20000:20000 + SAMPLE_SIZE]
    x1_original = np.load("vctk/dump/dev/p341_048-wave.npy")
    x1_original = x1_original[32000:32000 + SAMPLE_SIZE]

    # Sum of the two source signals, shaped (1, SAMPLE_SIZE).
    mixed = torch.FloatTensor(x0_original + x1_original).reshape(1, -1).to(device)

    # Write inputs
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(), hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write("mixed.wav", mixed_out, hparams.sample_rate)
    x0_original_out = inv_linear_quantize(x0_original, hparams.quantize_channels - 1)
    sf.write("x0_original.wav", x0_original_out, hparams.sample_rate)
    x1_original_out = inv_linear_quantize(x1_original, hparams.quantize_channels - 1)
    sf.write("x1_original.wav", x1_original_out, hparams.sample_rate)

    # Initialize with noise
    # Each estimate is left-padded with `receptive_field` samples of value 127.
    x0 = torch.FloatTensor(np.random.uniform(-256, 512, size=(1, SAMPLE_SIZE))).to(device)
    x0 = F.pad(x0, (receptive_field, 0), "constant", 127)
    x0.requires_grad = True
    x1 = torch.FloatTensor(np.random.uniform(-256, 512, size=(1, SAMPLE_SIZE))).to(device)
    x1 = F.pad(x1, (receptive_field, 0), "constant", 127)
    x1.requires_grad = True

    # Initialize with noised GT
    # NOTE(review): this overwrites the uniform-noise initialisation above.
    # It is also an in-place write into tensors that already have
    # requires_grad=True, outside torch.no_grad(); recent PyTorch versions
    # may reject that -- confirm against the targeted torch version.
    x0[0, receptive_field:] = torch.FloatTensor(
        x0_original + np.random.normal(0, 256., x0_original.shape)).to(device)
    x1[0, receptive_field:] = torch.FloatTensor(
        x1_original + np.random.normal(0, 256., x1_original.shape)).to(device)

    # Noise levels used to index `checkpoints`; visited in order as sigma decays.
    sigmas = [
        175.9, 110., 68.7, 42.9, 26.8, 16.8, 10.5, 4.1, 2.56, 1.6, 1.0, 0.0
    ]
    n_steps = 10000
    start_sigma = 256.
    end_sigma = 0.1

    # Exponential annealing
    # After n_steps multiplications by `ratio`, sigma goes from start_sigma
    # to end_sigma.
    ratio = (end_sigma / start_sigma)**(1.0 / n_steps)
    sigma = start_sigma

    # Dummy start values
    # curr_model_sigma starts huge so the first iteration loads sigmas[0].
    curr_model_idx = -1
    curr_model_sigma = 1000000.

    for i in range(n_steps):
        # Bump down a model
        if sigma < curr_model_sigma:
            curr_model_idx += 1
            curr_model_sigma = sigmas[curr_model_idx]
            checkpoint_path0 = join(args["<checkpoint0>"], checkpoints[curr_model_sigma], "checkpoint_latest.pth")
            checkpoint_path1 = join(args["<checkpoint1>"], checkpoints[curr_model_sigma], "checkpoint_latest.pth")
            print("Load checkpoint0 from {}".format(checkpoint_path0))
            checkpoint0 = torch.load(checkpoint_path0)
            checkpoint1 = torch.load(checkpoint_path1)
            model0.load_state_dict(checkpoint0["state_dict"])
            model1.load_state_dict(checkpoint1["state_dict"])

        # Step size scales with sigma^2. `gamma` is only referenced by the
        # commented-out reconstruction step below.
        eta = .05 * (sigma**2)
        gamma = 15 * (1.0 / sigma)**2

        # Uncomment to see GT log likelihoods per sigma
        # x0[0, receptive_field:] = torch.FloatTensor(x0_original + np.random.normal(0, sigma, x0_original.shape)).to(device)
        # x1[0, receptive_field:] = torch.FloatTensor(x1_original + np.random.normal(0, sigma, x1_original.shape)).to(device)

        # Forward pass
        # Gradient of the summed log-probability w.r.t. the padded input;
        # only the part after the left padding is used for the update.
        model0.zero_grad()
        log_prob, prediction0 = model0.smoothed_loss(x0, sigma=sigma)
        log_prob0 = torch.sum(log_prob[:, (receptive_field - 1):])
        # log_prob0 = torch.sum(log_prob)
        grad0 = torch.autograd.grad(log_prob0, x0)[0]
        x0_update = eta * grad0[:, receptive_field:]
        # x0_update = eta * grad0

        model1.zero_grad()
        log_prob, prediction1 = model1.smoothed_loss(x1, sigma=sigma)
        log_prob1 = torch.sum(log_prob[:, (receptive_field - 1):])
        # log_prob1 = torch.sum(log_prob)
        grad1 = torch.autograd.grad(log_prob1, x1)[0]
        x1_update = eta * grad1[:, receptive_field:]
        # x1_update = eta * grad1

        # Langevin step
        # Adds standard-normal noise scaled by sqrt(2 * eta) to each update.
        epsilon0 = np.sqrt(2 * eta) * torch.normal(
            0, 1, size=(1, SAMPLE_SIZE), device=device)
        x0_update += epsilon0
        epsilon1 = np.sqrt(2 * eta) * torch.normal(
            0, 1, size=(1, SAMPLE_SIZE), device=device)
        x1_update += epsilon1

        # Reconstruction step
        # NOTE(review): the reconstruction term is disabled, so `mixed` does
        # not influence the updates in this loop -- confirm this is intended.
        # x0_update -= eta * gamma * (x0[:, receptive_field:] + x1[:, receptive_field:] - mixed)
        # x1_update -= eta * gamma * (x0[:, receptive_field:] + x1[:, receptive_field:] - mixed)
        # x0_update -= eta * gamma * (x0 + x1 - mixed)
        # x1_update -= eta * gamma * (x0 + x1 - mixed)

        # Apply the updates in place without recording them in autograd.
        with torch.no_grad():
            x0[:, receptive_field:] += x0_update
            x1[:, receptive_field:] += x1_update
            # x0 += x0_update
            # x1 += x1_update

        if not i % 50:
            # debugging
            print("--------------")
            print('sigma = {}'.format(sigma))
            print('eta = {}'.format(eta))
            print("i {}".format(i))
            print("Max sample {}".format(abs(x0).max()))
            print('Mean sample logpx: {}'.format(log_prob0 / SAMPLE_SIZE))
            print('Mean sample logpy: {}'.format(log_prob1 / SAMPLE_SIZE))
            print("Max gradient update: {}".format(eta * abs(grad0).max()))
            # print("Reconstruction: {}".format(abs(x0 + x1 - mixed).mean()))

        # Reduce sigma
        sigma *= ratio

    # Write the final estimates (including the left padding) as wav files.
    # out0 = P.inv_mulaw_quantize(x0[0].detach().cpu().numpy(), hparams.quantize_channels - 1)
    out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(), hparams.quantize_channels - 1)
    out0 = np.clip(out0, -1, 1)
    sf.write("out0.wav", out0, hparams.sample_rate)

    out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(), hparams.quantize_channels - 1)
    out1 = np.clip(out1, -1, 1)
    sf.write("out1.wav", out1, hparams.sample_rate)

    # NOTE(review): drops into the debugger at the end of every run;
    # looks like a debugging leftover -- confirm before unattended use.
    import pdb
    pdb.set_trace()
def _forward_impl(self,
                  query: Tensor,
                  key: Tensor,
                  value: Tensor,
                  key_padding_mask: Optional[Tensor] = None,
                  need_weights: bool = True,
                  attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
    """Compute multi-head attention with explicit (de)quantization points.

    Args:
        query: tensor of size ``(tgt_len, bsz, embed_dim)``; the last
            dimension must equal ``self.embed_dim``.
        key: tensor whose first two sizes must match those of ``value``;
            ``key.size(0)`` is the source length used for mask validation.
        value: tensor whose first two sizes must match those of ``key``.
        key_padding_mask: optional mask checked to be of size
            ``(bsz, src_len)``; masked positions are filled with ``-inf``
            before the softmax. A uint8 mask is converted to bool with a
            deprecation warning.
        need_weights: when True, the attention weights averaged over heads
            are returned as the second output.
        attn_mask: optional 2D mask of size ``(tgt_len, src_len)`` or 3D mask
            of size ``(bsz * num_heads, tgt_len, src_len)``. Bool masks are
            applied with ``masked_fill_``; other dtypes are added to the
            scores.

    Returns:
        ``(attn_output, attn_output_weights)`` where ``attn_output`` has size
        ``(tgt_len, bsz, embed_dim)`` and the weights (if requested) are
        averaged over the head dimension, otherwise ``None``.

    Raises:
        RuntimeError: if ``attn_mask`` has an unsupported rank or size.
    """
    # This version will not deal with the static key/value pairs.
    # Keeping it here for future changes.
    #
    # TODO: This method has some duplicate lines with the
    # `torch.nn.functional.multi_head_attention`. Will need to refactor.
    static_k = None
    static_v = None

    tgt_len, bsz, embed_dim_to_check = query.size()
    assert self.embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = self.embed_dim // self.num_heads
    assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    q = self.linear_Q(query)
    k = self.linear_K(key)
    v = self.linear_V(value)

    # Scale the queries by head_dim ** -0.5 through the q_scaling_product module.
    q = self.q_scaling_product.mul_scalar(q, scaling)

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            # A 2D mask gets a leading dimension of size 1.
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)
    if self.bias_k is not None and self.bias_v is not None:
        if static_k is None and static_v is None:
            # Append the bias as one extra source position and widen the
            # masks by one column to match.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert self.bias_k is None
        assert self.bias_v is None

    # Split heads: (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim).
    q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

    # static_k / static_v are always None in this version, so the two
    # branches below are currently never taken.
    if static_k is not None:
        assert static_k.size(0) == bsz * self.num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * self.num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        # Append one all-zero source position to k and v, quantizing the
        # zeros with the same parameters when k / v are quantized.
        src_len += 1
        k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
        if k.is_quantized:
            k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
        k = torch.cat([k, k_zeros], dim=1)
        # NOTE(review): the trailing size is taken from k rather than v;
        # both were viewed with head_dim as last dimension above, so the
        # shapes agree. The zeros are also created on the default device --
        # confirm that is acceptable for the supported backends.
        v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
        if v.is_quantized:
            v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
        v = torch.cat([v, v_zeros], dim=1)

        if attn_mask is not None:
            attn_mask = nnF.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = nnF.pad(key_padding_mask, (0, 1))

    # Leaving the quantized zone here
    q = self.dequant_q(q)
    k = self.dequant_k(k)
    v = self.dequant_v(v)
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        # Temporarily expose the batch dimension so the (bsz, src_len) mask
        # broadcasts over heads and target positions.
        attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
    attn_output_weights = nnF.softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
    # Merge heads back: (bsz * num_heads, tgt_len, head_dim) -> (tgt_len, bsz, embed_dim).
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

    # Reentering the quantized zone
    attn_output = self.quant_attn_output(attn_output)
    attn_output = self.out_proj(attn_output)  # type: ignore
    attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.mean(dim=1)
    else:
        return attn_output, None
def forward(self, inputs1, inputs2):
    """Upsample ``inputs2``, pad ``inputs1`` to match it, concatenate and convolve.

    Args:
        inputs1: skip-connection tensor; padded along its last two
            dimensions to the spatial size of the upsampled ``inputs2``.
        inputs2: tensor passed through ``self.up``.

    Returns:
        ``self.conv`` applied to the channel-wise (dim 1) concatenation of
        the padded ``inputs1`` and the upsampled ``inputs2``.
    """
    outputs2 = self.up(inputs2)

    # Height and width offsets are computed independently, and any odd
    # remainder goes to the bottom/right side, so the padded tensor always
    # matches `outputs2` exactly (equal, even offsets give the same result
    # as symmetric padding by offset // 2).
    diff_h = outputs2.size(2) - inputs1.size(2)
    diff_w = outputs2.size(3) - inputs1.size(3)
    # F.pad order for the last two dims: (left, right, top, bottom).
    padding = [
        diff_w // 2, diff_w - diff_w // 2,
        diff_h // 2, diff_h - diff_h // 2,
    ]
    outputs1 = F.pad(inputs1, padding)
    return self.conv(torch.cat([outputs1, outputs2], 1))
def calculate_local_path_costs(self, non_obstacle_cost_map=None):
    """Build the per-cell cost of moving to each of the 9 neighbourhood cells.

    For every cell, nine cost maps are concatenated along dim 1 (eight
    neighbours plus the centre). Each neighbour cost is the square root of
    the summed squared coordinate differences in the relevant directions
    (plus ``self.eps``) plus ``self.ob_cost`` times the larger of the
    obstacle value at the neighbour and at the cell itself. A bias clamped
    to ``[0, self.ob_cost]`` is added to all nine maps.

    Args:
        non_obstacle_cost_map: optional bias tensor, expanded to the shape
            of the output; when None a tensor of ones shaped like
            ``self.obstacles`` is used.

    Returns:
        Tensor with the nine cost maps stacked on dim 1.
    """
    coords = self.coords
    h = coords.size(2)
    w = coords.size(3)
    # Replicate-pad the obstacle map by one cell on every side so each
    # neighbour can be read with a shifted h x w slice.
    obstacles_pd = F.pad(self.obstacles, (1, 1, 1, 1), "replicate")
    if non_obstacle_cost_map is None:
        learned_bias = torch.ones_like(self.obstacles).to(
            obstacles_pd.device
        )
    else:
        learned_bias = non_obstacle_cost_map.to(obstacles_pd.device)
    # Squared coordinate differences towards each direction. Channel 1 of
    # `coords` is padded along the last dimension, channel 0 along the
    # second-to-last one.
    left_diff_sq = (
        self.gx_to_left(
            F.pad(coords[:, 1:2, :, :], (1, 1, 0, 0), "replicate")
        )
        ** 2
    )
    right_diff_sq = (
        self.gx_to_right(
            F.pad(coords[:, 1:2, :, :], (1, 1, 0, 0), "replicate")
        )
        ** 2
    )
    up_diff_sq = (
        self.gy_to_up(
            F.pad(coords[:, 0:1, :, :], (0, 0, 1, 1), "replicate")
        )
        ** 2
    )
    down_diff_sq = (
        self.gy_to_down(
            F.pad(coords[:, 0:1, :, :], (0, 0, 1, 1), "replicate")
        )
        ** 2
    )
    out = torch.cat(
        [
            # Order in from up to down, from left to right
            # hopefully same as in PyTorch
            torch.sqrt(left_diff_sq + up_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 0:h, 0:w],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            # NOTE(review): this "left" term reads the same obstacle slice
            # as the "up" term below (rows 0:h, cols 1:w+1); by analogy
            # with the "right" term one would expect rows 1:h+1, cols 0:w
            # -- confirm whether this is intentional.
            torch.sqrt(left_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 0:h, 1 : w + 1],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            torch.sqrt(left_diff_sq + down_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 2 : h + 2, 0:w],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            torch.sqrt(up_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 0:h, 1 : w + 1],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            # Multiplying by 0 yields a zero distance term with the right
            # shape, leaving only the obstacle cost of the cell itself.
            0 * right_diff_sq
            + self.ob_cost
            * obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],  # current center
            torch.sqrt(down_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 2 : h + 2, 1 : w + 1],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            torch.sqrt(right_diff_sq + up_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 0:h, 2 : w + 2],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            torch.sqrt(right_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 1 : h + 1, 2 : w + 2],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
            torch.sqrt(right_diff_sq + down_diff_sq + self.eps)
            + self.ob_cost
            * torch.max(
                obstacles_pd[:, :, 2 : h + 2, 2 : w + 2],
                obstacles_pd[:, :, 1 : h + 1, 1 : w + 1],
            ),
        ],
        dim=1,
    )
    # The bias is clamped so it never exceeds the obstacle cost.
    return out + torch.clamp(
        learned_bias.expand_as(out), min=0, max=self.ob_cost
    )
def streams(patch, params):
    """Run the two ``siam_stream`` branches on ``patch`` and join them.

    The 'retina' branch receives ``patch`` with 16 pixels cropped from each
    side of its last two dimensions (negative padding); the 'fovea' branch
    receives ``patch`` average-pooled with kernel 2 and stride 2. The two
    outputs are concatenated along dim 1.
    """
    cropped = F.pad(patch, (-16, -16, -16, -16))
    retina_out = siam_stream(cropped, params, 'retina')

    pooled = F.avg_pool2d(patch, 2, 2)
    fovea_out = siam_stream(pooled, params, 'fovea')

    return torch.cat([retina_out, fovea_out], dim=1)
def _call(self, x): offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1) z = sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1) return y
def _pad(x, crop_size): h, w = x.size()[2:] pad_h = max(crop_size - h, 0) pad_w = max(crop_size - w, 0) x = F.pad(x, (0, pad_w, 0, pad_h)) return x, pad_h, pad_w
def crop_and_concat(upsampled, bypass, crop=False):
    """Concatenate ``upsampled`` and ``bypass`` along the channel dimension.

    Args:
        upsampled: tensor whose last two dimensions define the target size.
        bypass: skip-connection tensor.
        crop: when True, ``bypass`` is centre-cropped along its last two
            dimensions to the spatial size of ``upsampled`` before the
            concatenation.

    Returns:
        ``torch.cat((upsampled, bypass), 1)``.
    """
    if crop:
        # Height and width are handled independently, and an odd difference
        # removes the extra row/column at the bottom/right, so the cropped
        # tensor always matches `upsampled` (equal, even differences give
        # the same result as cropping diff // 2 from every side).
        diff_h = bypass.size(2) - upsampled.size(2)
        diff_w = bypass.size(3) - upsampled.size(3)
        top, left = diff_h // 2, diff_w // 2
        bottom, right = diff_h - top, diff_w - left
        # Negative padding crops; order is (left, right, top, bottom).
        bypass = F.pad(bypass, (-left, -right, -top, -bottom))
    return torch.cat((upsampled, bypass), 1)
def debug(select):
    """Compare numpy + librosa and pytorch implementation result. For debug.

    Every branch prints mean absolute differences between the reference
    (numpy / librosa) result and the PyTorch result; the printed numbers
    are expected to be close to 0. Nothing is returned.

    Args:
      select: 'dft' | 'stft' | 'logmel'. Any other value does nothing.
    """

    if select == 'dft':
        n = 10
        norm = None     # None | 'ortho'
        np.random.seed(0)

        # Data
        np_data = np.random.uniform(-1, 1, n)
        pt_data = torch.Tensor(np_data)

        # Numpy FFT
        np_fft = np.fft.fft(np_data, norm=norm)
        np_ifft = np.fft.ifft(np_fft, norm=norm)
        np_rfft = np.fft.rfft(np_data, norm=norm)
        # NOTE(review): computed with np.fft.ifft (not irfft) and never
        # used below; the inverse real DFT is compared against np_data.
        np_irfft = np.fft.ifft(np_rfft, norm=norm)

        # Pytorch FFT
        obj = DFT(n, norm)
        pt_dft = obj.dft(pt_data, torch.zeros_like(pt_data))
        pt_idft = obj.idft(pt_dft[0], pt_dft[1])
        pt_rdft = obj.rdft(pt_data)
        pt_irdft = obj.irdft(pt_rdft[0], pt_rdft[1])

        print(
            'Comparing librosa and pytorch implementation of DFT. All numbers '
            'below should be close to 0.')
        # Real and imaginary parts are compared separately.
        print(np.mean((np.abs(np.real(np_fft) - pt_dft[0].cpu().numpy()))))
        print(np.mean((np.abs(np.imag(np_fft) - pt_dft[1].cpu().numpy()))))

        print(np.mean((np.abs(np.real(np_ifft) - pt_idft[0].cpu().numpy()))))
        print(np.mean((np.abs(np.imag(np_ifft) - pt_idft[1].cpu().numpy()))))

        print(np.mean((np.abs(np.real(np_rfft) - pt_rdft[0].cpu().numpy()))))
        print(np.mean((np.abs(np.imag(np_rfft) - pt_rdft[1].cpu().numpy()))))

        print(np.mean(np.abs(np_data - pt_irdft.cpu().numpy())))

    elif select == 'stft':
        data_length = 32000
        device = torch.device('cuda')   # 'cuda' | 'cpu'
        np.random.seed(0)

        sample_rate = 16000
        n_fft = 1024
        hop_length = 250
        win_length = 1024
        window = 'hann'
        center = True
        dtype = np.complex64
        pad_mode = 'reflect'

        # Data
        np_data = np.random.uniform(-1, 1, data_length)
        pt_data = torch.Tensor(np_data).to(device)

        # Numpy stft matrix
        # Transposed so that it can be compared with the PyTorch output
        # indexed as [0, 0] below.
        np_stft_matrix = librosa.core.stft(y=np_data, n_fft=n_fft,
            hop_length=hop_length, window=window, center=center).T

        # Pytorch stft matrix
        pt_stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        pt_stft_extractor.to(device)

        (pt_stft_real, pt_stft_imag) = pt_stft_extractor.forward(pt_data[None, :])

        # NOTE(review): the banner says "DFT" although this branch compares
        # STFT results -- confirm whether the message should be updated.
        print(
            'Comparing librosa and pytorch implementation of DFT. All numbers '
            'below should be close to 0.')
        print(
            np.mean(
                np.abs(
                    np.real(np_stft_matrix) -
                    pt_stft_real.data.cpu().numpy()[0, 0])))
        print(
            np.mean(
                np.abs(
                    np.imag(np_stft_matrix) -
                    pt_stft_imag.data.cpu().numpy()[0, 0])))

        # Numpy istft
        np_istft_s = librosa.core.istft(stft_matrix=np_stft_matrix.T,
            hop_length=hop_length, window=window, center=center, length=data_length)

        # Pytorch istft
        pt_istft_extractor = ISTFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)
        pt_istft_extractor.to(device)

        # Recover from real and imag part
        pt_istft_s = pt_istft_extractor.forward(pt_stft_real, pt_stft_imag, data_length)[0, :]

        # Recover from magnitude and phase
        (pt_stft_mag, cos, sin) = magphase(pt_stft_real, pt_stft_imag)
        pt_istft_s2 = pt_istft_extractor.forward(pt_stft_mag * cos, pt_stft_mag * sin, data_length)[0, :]

        # Reconstruction error versus librosa and versus the original signal.
        print(np.mean(np.abs(np_istft_s - pt_istft_s.data.cpu().numpy())))
        print(np.mean(np.abs(np_data - pt_istft_s.data.cpu().numpy())))
        print(np.mean(np.abs(np_data - pt_istft_s2.data.cpu().numpy())))

    elif select == 'logmel':
        data_length = 32000
        norm = None     # None | 'ortho'
        device = torch.device('cuda')   # 'cuda' | 'cpu'
        np.random.seed(0)

        # Spectrogram parameters
        sample_rate = 16000
        n_fft = 1024
        hop_length = 250
        win_length = 1024
        window = 'hann'
        center = True
        dtype = np.complex64
        pad_mode = 'reflect'

        # Mel parameters
        n_mels = 64
        fmin = 50
        fmax = 7000
        ref = 1.0
        amin = 1e-10
        top_db = None

        # Data
        np_data = np.random.uniform(-1, 1, data_length)
        pt_data = torch.Tensor(np_data).to(device)

        print('Comparing librosa and pytorch implementation of logmel '
            'spectrogram. All numbers below should be close to 0.')

        # Numpy librosa
        np_stft_matrix = librosa.core.stft(y=np_data, n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, dtype=dtype,
            pad_mode=pad_mode)

        # Reference padding, compared against the PyTorch padding below.
        np_pad = np.pad(np_data, int(n_fft // 2), mode=pad_mode)

        np_melW = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels,
            fmin=fmin, fmax=fmax).T

        # Power spectrogram projected onto the mel filter bank.
        np_mel_spectrogram = np.dot(np.abs(np_stft_matrix.T)**2, np_melW)

        np_logmel_spectrogram = librosa.core.power_to_db(np_mel_spectrogram,
            ref=ref, amin=amin, top_db=top_db)

        # Pytorch
        stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=n_fft,
            n_mels=n_mels, fmin=fmin, fmax=fmax, ref=ref, amin=amin,
            top_db=top_db, freeze_parameters=True)

        stft_extractor.to(device)
        logmel_extractor.to(device)

        pt_pad = F.pad(pt_data[None, None, :], pad=(n_fft // 2, n_fft // 2), mode=pad_mode)[0, 0]
        print(np.mean(np.abs(np_pad - pt_pad.cpu().numpy())))

        # The extractor's convolutions are applied directly to the
        # manually padded signal.
        pt_stft_matrix_real = stft_extractor.conv_real(pt_pad[None, None, :])[0]
        pt_stft_matrix_imag = stft_extractor.conv_imag(pt_pad[None, None, :])[0]
        print(
            np.mean(
                np.abs(
                    np.real(np_stft_matrix) -
                    pt_stft_matrix_real.data.cpu().numpy())))
        print(
            np.mean(
                np.abs(
                    np.imag(np_stft_matrix) -
                    pt_stft_matrix_imag.data.cpu().numpy())))

        # Spectrogram
        spectrogram_extractor = Spectrogram(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        spectrogram_extractor.to(device)

        pt_spectrogram = spectrogram_extractor.forward(pt_data[None, :])
        pt_mel_spectrogram = torch.matmul(pt_spectrogram, logmel_extractor.melW)
        print(
            np.mean(
                np.abs(np_mel_spectrogram -
                       pt_mel_spectrogram.data.cpu().numpy()[0, 0])))

        # Log mel spectrogram
        pt_logmel_spectrogram = logmel_extractor.forward(pt_spectrogram)
        print(
            np.mean(
                np.abs(np_logmel_spectrogram -
                       pt_logmel_spectrogram[0, 0].data.cpu().numpy())))
def sample_target_adaptive(im, target_bb, search_area_factor, output_sz, mode: str = 'replicate',
                           max_scale_change=None, mask=None):
    """ Extracts a crop centered at target_bb box, of area search_area_factor^2. If the crop area contains regions
    outside the image, it is shifted so that it is inside the image. Further, if the crop area exceeds the image
    size, a smaller crop which fits the image is returned instead.

    args:
        im - Input numpy image to crop.
        target_bb - target box [x, y, w, h]
        search_area_factor - Ratio of crop size to target size
        output_sz - (float, int or pair) Size to which the extracted crop is resized. A single number is expanded
                    to a square size.
        mode - If 'replicate', the boundary pixels are replicated in case the search region crop goes out of image.
               If 'inside', the search region crop is shifted/shrunk to fit completely inside the image.
               If 'inside_major', the search region crop is shifted/shrunk to fit completely inside one axis of the
               image.
        max_scale_change - Maximum allowed scale change when performing the crop (only applicable for 'inside' and
                           'inside_major'). None means no limit.
        mask - Optional mask to apply the same crop.

    returns:
        numpy image - Extracted crop.
        torch.Tensor - A bounding box [x, y, w, h] denoting the cropped region in the image.
        mask - Cropped and resized mask, returned only if mask is not None.

    raises:
        Exception - If the computed crop size is smaller than 1 pixel along either axis.
    """

    if max_scale_change is None:
        max_scale_change = float('inf')
    if isinstance(output_sz, (float, int)):
        output_sz = (output_sz, output_sz)
    output_sz = torch.Tensor(output_sz)

    im_h = im.shape[0]
    im_w = im.shape[1]

    bbx, bby, bbw, bbh = target_bb.tolist()

    # Crop image
    # The crop keeps the aspect ratio of output_sz; its area is the target
    # area scaled by search_area_factor^2 (before rounding up).
    crop_sz_x, crop_sz_y = (output_sz * (
        target_bb[2:].prod() / output_sz.prod()).sqrt() * search_area_factor).ceil().long().tolist()

    # Get new sample size if forced inside the image
    if mode == 'inside' or mode == 'inside_major':
        # Calculate rescaling factor if outside the image
        rescale_factor = [crop_sz_x / im_w, crop_sz_y / im_h]
        if mode == 'inside':
            rescale_factor = max(rescale_factor)
        elif mode == 'inside_major':
            rescale_factor = min(rescale_factor)
        # Never enlarge the crop (factor >= 1) and never shrink it by more
        # than max_scale_change.
        rescale_factor = min(max(1, rescale_factor), max_scale_change)

        crop_sz_x = math.floor(crop_sz_x / rescale_factor)
        crop_sz_y = math.floor(crop_sz_y / rescale_factor)

    if crop_sz_x < 1 or crop_sz_y < 1:
        raise Exception('Too small bounding box.')

    # Crop corners, centered on the target box.
    x1 = round(bbx + 0.5 * bbw - crop_sz_x * 0.5)
    x2 = x1 + crop_sz_x

    y1 = round(bby + 0.5 * bbh - crop_sz_y * 0.5)
    y2 = y1 + crop_sz_y

    # Move box inside image
    shift_x = max(0, -x1) + min(0, im_w - x2)
    x1 += shift_x
    x2 += shift_x

    shift_y = max(0, -y1) + min(0, im_h - y2)
    y1 += shift_y
    y2 += shift_y

    # If the box still sticks out (crop larger than the image along an
    # axis), shift it so that the overhang is split between both sides.
    out_x = (max(0, -x1) + max(0, x2 - im_w)) // 2
    out_y = (max(0, -y1) + max(0, y2 - im_h)) // 2
    shift_x = (-x1 - out_x) * (out_x > 0)
    shift_y = (-y1 - out_y) * (out_y > 0)

    x1 += shift_x
    x2 += shift_x
    y1 += shift_y
    y2 += shift_y

    # Amount of padding needed on each side for the part of the box that
    # lies outside the image.
    # NOTE(review): the "+ 1" yields one pixel of right/bottom padding even
    # when x2 == im_w / y2 == im_h -- confirm this is intended.
    x1_pad = max(0, -x1)
    x2_pad = max(x2 - im.shape[1] + 1, 0)

    y1_pad = max(0, -y1)
    y2_pad = max(y2 - im.shape[0] + 1, 0)

    # Crop target
    im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]

    if mask is not None:
        mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

    # Pad
    # The image is padded by replicating border pixels, the mask with zeros.
    im_crop_padded = cv.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv.BORDER_REPLICATE)

    if mask is not None:
        mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)

    # Resize image
    im_out = cv.resize(im_crop_padded, tuple(output_sz.long().tolist()))

    if mask is not None:
        # output_sz is flipped for F.interpolate; nearest-neighbour
        # interpolation is used for the mask.
        mask_out = \
            F.interpolate(mask_crop_padded[None, None], tuple(output_sz.flip(0).long().tolist()), mode='nearest')[0, 0]

    crop_box = torch.Tensor([x1, y1, x2 - x1, y2 - y1])

    if mask is None:
        return im_out, crop_box
    else:
        return im_out, crop_box, mask_out
def add_overlay(self, tag, embed, img, alpha=0.8, cmap='inferno', add_ref=None):
    """ Adds to the summary the images of the input image (ref or search)
    overlaid with the corresponding embedding or correlation map. It expects
    tensors with dimensions [C x H x W] or [B x C x H x W]; if the tensor
    has a batch dimension it takes the FIRST ELEMENT of the batch. The image
    is displayed as a fusion of the input image in grayscale and the overlay
    in the chosen color map; this fusion is controlled by the alpha factor.
    In the case of the embeddings, since there are multiple feature
    channels, we show each of them individually in a grid.
    OBS: The colors represent relative values, where the peak color
    corresponds to the maximum value in any given channel, so no direct
    value comparisons can be made between epochs, only the relative
    distribution of neighboring pixel values (which should be enough, since
    we are mostly interested in finding the maximum of a given correlation
    map).

    Args:
        tag: (str) The string identifying the image in tensorboard, images
            with the same tag are grouped together with a slider, and are
            indexed by epoch.
        embed: (torch.Tensor) The tensor containing the embedding of an
            input (ref or search image) or a correlation map (the final
            output). The shape should be [B, C, H, W] or [B, H, W] for the
            case of the correlation map.
        img: (torch.Tensor) The image on top of which the embed is going
            to be overlaid. Reference image embeddings should be overlaid
            on top of reference images and search image embeddings as well
            as the correlation maps should be overlaid on top of the search
            images.
        alpha: (float) A mixing variable, it controls how much of the final
            embedding corresponds to the grayscale input image and how much
            corresponds to the overlay. Alpha = 0 means there is no
            overlay in the final image, only the input image. Conversely,
            Alpha = 1 means there is only overlay. Adjust this value so
            you can distinctly see the overlay details while still seeing
            where it is in relation to the original image.
        cmap: (str) The name of the colormap to be used with the overlay.
            The colormaps are defined in the colormaps.py module, but values
            include 'viridis' (greenish blue) and 'inferno' (yellowish red).
        add_ref: (torch.Tensor) Optional. An additional reference image that
            will be plotted to the side of the other images. Useful when
            plotting correlation maps, because it lets the user see both
            the search image and the reference that is used as the target.

    ``Example``
        >>> summ_maker = SummaryMaker(os.path.join(exp_dir, 'tensorboard'), params,
                                       model.upscale_factor)
        ...
        >>> embed_ref = model.get_embedding(ref_img_batch)
        >>> embed_srch = model.get_embedding(search_batch)
        >>> output_batch = model.match_corr(embed_ref, embed_srch)
        >>> batch_index = 0
        >>> summ_maker.add_overlay("Ref_image_{}".format(tbx_index), embed_ref[batch_index], ref_img_batch[batch_index], cmap='inferno')
        >>> summ_maker.add_overlay("Search_image_{}".format(tbx_index), embed_srch[batch_index], search_batch[batch_index], cmap='inferno')
        >>> summ_maker.add_overlay("Correlation_map_{}".format(tbx_index), output_batch[batch_index], search_batch[batch_index], cmap='inferno')
    """
    # TODO Add numbers in the final image to the feature channels.
    # TODO Add the color bar showing the progression of values.
    # If minibatch is given, take only the first image
    # TODO let the user select the image? Loop on all images?
    if len(embed.shape) == 4:
        embed = embed[0]
    if len(img.shape) == 4:
        img = img[0]
    # Normalize the image.
    # NOTE(review): a constant image gives img.max() == 0 after the shift,
    # so the division below would divide by zero -- confirm inputs.
    img = img - img.min()
    img = img/img.max()
    embed = cm.apply_cmap(embed, cmap=cmap)
    # Get grayscale version of image by taking the weighted average of the channels
    # as described in https://www.cs.virginia.edu/~vicente/recognition/notebooks/image_processing_lab.html#2.-Converting-to-Grayscale
    # Unpacking requires img to have exactly 3 entries along its first dimension.
    R,G,B = img
    img_gray = 0.21 * R + 0.72 * G + 0.07 * B
    # Get the upscaled size of the embedding, so as to take into account
    # the network's downscale caused by the stride.
    upsc_size = (embed.shape[-1] - 1) * self.up_factor + 1
    # presumably cm.apply_cmap returns a 4D tensor, as bilinear
    # interpolation needs one -- TODO confirm.
    embed = F.interpolate(embed, upsc_size, mode='bilinear',
                          align_corners=False)
    # Pad the embedding with zeros to match the image dimensions. We pad
    # all 4 corners equally to keep the embedding centered.
    tot_pad = img.shape[-1] - upsc_size
    # Sanity check 1. The amount of padding must be equal on all sides, so
    # the total padding on any dimension must be an even integer.
    assert tot_pad % 2 == 0, "The embed or image dimensions are incorrect."
    pad = int(tot_pad/2)
    embed = F.pad(embed, (pad, pad, pad, pad), 'constant', 0)
    # Sanity check 2, the size of the embedding in the (H, W) dimensions
    # matches the size of the image.
    assert embed.shape[-2:] == img.shape[-2:], ("The embedding overlay "
                                                 "and image dimensions "
                                                 "do not agree.")
    final_imgs = alpha * embed + (1-alpha) * img_gray
    # The embedding_channel (or feature channel) dimension is treated like
    # a batch dimension, so the grid shows each individual embedding
    # overlaid with the input image. Plus the original image is also shown.
    # If add_ref is used the ref image is the first to be shown.
    img = img.unsqueeze(0)
    final_imgs = torch.cat((img, final_imgs))
    if add_ref is not None:
        # Pads the image if necessary
        # NOTE(review): an odd size difference is floored, which would leave
        # add_ref one pixel smaller than img -- confirm sizes always differ
        # by an even amount.
        pad = int((img.shape[-1] - add_ref.shape[-1])//2)
        add_ref = F.pad(add_ref, (pad, pad, pad, pad), 'constant', 0)
        add_ref = add_ref.unsqueeze(0)
        final_imgs = torch.cat((add_ref, final_imgs))
    final_imgs = make_grid(final_imgs, nrow=6)
    self.writer_val.add_image(tag, final_imgs, self.epoch)
def forward(self, x):
    """Apply a 2x2 max-pool with stride 1 that keeps the spatial size.

    The input is first extended by one replicated pixel on the right and
    bottom, so the pooled output has the same height and width as ``x``.
    """
    # F.pad order for the last two dims: (left, right, top, bottom).
    padded = F.pad(x, (0, 1, 0, 1), mode='replicate')
    return F.max_pool2d(padded, 2, stride=1)
def forward(self, x):
    """Upsample, reflect-pad, then apply ``c2``, ``b3`` and a ReLU.

    ``x`` goes through ``self.up1``, is reflect-padded by one pixel on all
    four sides of its last two dimensions, passed through ``self.c2`` and
    then ``self.b3``, and finally rectified.
    """
    upsampled = self.up1(x)
    padded = F.pad(upsampled, (1, 1, 1, 1), mode='reflect')
    features = self.c2(padded)
    normalized = self.b3(features)
    return F.relu(normalized)
def forward(self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            mask: Optional[str] = None) -> torch.Tensor:
    """Propagate forward the input through the windowed multi-head block.

    For each head the queries, keys and values matrices are computed, the
    sequence is cut into overlapping windows of ``self._window_size`` steps
    (stride ``self._step``), and Scaled Dot-Product attention is applied
    inside every window. The ``self._padding`` border rows of each window
    are dropped again before the heads are concatenated.

    Parameters
    ----------
    query:
        Input tensor with shape (batch_size, K, d_model) used to compute queries.
    key:
        Input tensor with shape (batch_size, K, d_model) used to compute keys.
    value:
        Input tensor with shape (batch_size, K, d_model) used to compute values.
    mask:
        Mask to apply on scores before computing attention.
        One of ``'subsequent'``, None. Default is None.

    Returns
    -------
        Self attention tensor with shape (batch_size, K, d_model).

    Notes
    -----
    The attention weights of the last call are kept in ``self._scores``.
    A padding of 0 is supported: the replicate padding is skipped and the
    whole window is kept (the previous ``padding:-padding`` slice was empty
    in that case and produced an empty output).
    """
    batch_size = query.shape[0]
    padding = self._padding
    window = self._window_size

    def pad_time(sequence: torch.Tensor) -> torch.Tensor:
        # Replicate the first/last time step `padding` times on each side.
        if padding == 0:
            return sequence
        return F.pad(sequence.transpose(1, 2), (padding, padding),
                     'replicate').transpose(1, 2)

    # Apply padding to input sequence
    query = pad_time(query)
    key = pad_time(key)
    value = pad_time(value)

    # Compute Q, K and V, concatenate heads on batch dimension
    queries = torch.cat(self._W_q(query).chunk(self._h, dim=-1), dim=0)
    keys = torch.cat(self._W_k(key).chunk(self._h, dim=-1), dim=0)
    values = torch.cat(self._W_v(value).chunk(self._h, dim=-1), dim=0)

    # Divide Q, K and V using a moving window
    queries = queries.unfold(dimension=1, size=window, step=self._step).reshape(
        (-1, self._q, window)).transpose(1, 2)
    keys = keys.unfold(dimension=1, size=window, step=self._step).reshape(
        (-1, self._q, window)).transpose(1, 2)
    values = values.unfold(dimension=1, size=window, step=self._step).reshape(
        (-1, self._v, window)).transpose(1, 2)

    # Scaled Dot Product (scaled by the window size, as in the original code)
    self._scores = torch.bmm(queries, keys.transpose(1, 2)) / np.sqrt(window)

    # Compute local map mask
    if self._attention_size is not None:
        self._scores = self._scores.masked_fill(self._attention_mask,
                                                float('-inf'))

    # Compute future mask
    if mask == "subsequent":
        self._scores = self._scores.masked_fill(self._future_mask,
                                                float('-inf'))

    # Apply softmax
    self._scores = F.softmax(self._scores, dim=-1)

    attention = torch.bmm(self._scores, values)

    # Fold chunks back, dropping the padded border of every window. An
    # explicit end index is used because `padding:-padding` is empty when
    # padding == 0.
    attention = attention.reshape((batch_size * self._h, -1, window, self._v))
    attention = attention[:, :, padding:window - padding, :]
    attention = attention.reshape((batch_size * self._h, -1, self._v))

    # Concatenate the heads
    attention_heads = torch.cat(attention.chunk(self._h, dim=0), dim=-1)

    # Apply linear transformation W^O
    self_attention = self._W_o(attention_heads)

    return self_attention
def forward(
    self,
    obstacles,
    coords,
    start_map,
    goal_map,
    non_obstacle_cost_map=None,
    additional_steps=50,
    return_path=True,
):
    """Propagate path costs over the map until the goal cell is settled.

    The cost-to-come map ``self.g_map`` is relaxed iteratively with a
    min-over-neighbours update while the open/closed maps are advanced by
    a 3x3 max-pool "wavefront". While the growing square window around the
    start still lies strictly inside the map only that window is updated;
    afterwards the whole map is updated.

    Args:
        obstacles: Obstacle map; ``size(2)``/``size(3)`` are used as map
            height/width, so it needs at least 4 dimensions
            (presumably (B, C, H, W) -- TODO confirm).
        coords: Coordinate grid; multiplied with the start/goal maps and
            summed over dims 2 and 3 to obtain start/goal coordinates.
        start_map: Map marking the start cell (assumed one-hot -- TODO confirm).
        goal_map: Map marking the goal cell; its flattened argmax is used
            as the goal index.
        non_obstacle_cost_map: Forwarded to ``calculate_local_path_costs``.
        additional_steps: Number of extra full-map propagation steps run
            after the main loop, unless it stopped on ``max_steps``.
        return_path: When true, return ``self.reconstruct_path()``.

    Returns:
        ``(out_path, cost)`` from ``self.reconstruct_path()`` when
        ``return_path`` is true, otherwise ``None``.
    """
    # Timing accumulators are reset on every call.
    self.trav_init_time = 0
    self.trav_mask_time = 0
    self.trav_soft_time = 0
    self.conv_time = 0
    self.close_time = 0
    self.obstacles = self.preprocess_obstacle_map(
        obstacles.to(self.device)
    )
    self.start_map = start_map.to(self.device)
    # NOTE(review): this CPU tensor is overwritten three statements below
    # without being read -- looks like a leftover.
    self.been_there = torch.zeros_like(self.start_map).to(
        torch.device("cpu")
    )
    self.coords = coords.to(self.device)
    self.goal_map = goal_map.to(self.device)
    self.been_there = torch.zeros_like(self.goal_map).to(self.device)
    self.height = obstacles.size(2)
    self.width = obstacles.size(3)
    # Flat index of the goal cell; the max value `m` itself is not used.
    m, goal_idx = torch.max(self.goal_map.view(-1), 0)
    c_map = self.calculate_local_path_costs(non_obstacle_cost_map)
    # c_map might be non persistent in map update
    self.g_map = self.init_g_map()
    self.close_list_map = self.init_closelistmap()
    self.open_list_map = self.init_openlistmap()
    not_done = False
    step = 0
    stopped_by_max_iter = False
    if self.visualize:
        self.fig, self.ax = plt.subplots(1, 1)
        self.image = self.ax.imshow(
            self.g_map.squeeze().cpu().detach().numpy().astype(np.float32),
            animated=True,
        )
        self.fig.canvas.draw()
    # Keep iterating while the goal is not closed yet or its cost is still
    # at/above the threshold.
    # NOTE(review): the threshold here is 0.9 * self.ob_cost, while the
    # in-loop check uses 0.1 * self.inf -- confirm the difference is intended.
    not_done = (self.close_list_map.view(-1)[goal_idx].item() < 1.0) or (
        self.g_map.view(-1)[goal_idx].item() >= 0.9 * self.ob_cost
    )
    rad = 1
    self.start_coords = (
        (self.coords * self.start_map.expand_as(self.coords))
        .sum(dim=2)
        .sum(dim=2)
        .squeeze()
    )
    # The update window stays centred on the start; only `rad` grows.
    node_coords = self.start_coords
    self.goal_coords = (
        (self.coords * self.goal_map.expand_as(self.coords))
        .sum(dim=2)
        .sum(dim=2)
        .squeeze()
    )
    # Iteration budget: 4x the Euclidean start-goal distance (truncated).
    self.max_steps = 4 * int(
        torch.sqrt(
            ((self.start_coords - self.goal_coords) ** 2).sum() + 1e-6
        ).item()
    )
    while not_done:
        ymin, ymax, xmin, xmax = self.safe_roi_2d(
            node_coords[0] - rad,
            node_coords[0] + rad + 1,
            node_coords[1] - rad,
            node_coords[1] + rad + 1,
        )
        if (
            (ymin - 1 > 0)
            and (xmin - 1 > 0)
            and (ymax + 1 < self.height)
            and (xmax + 1 < self.width)
        ):
            # Local update: the window plus a one-cell margin fits inside
            # the map, so the margin is read directly instead of padding.
            # presumably neights2channels stacks each cell's neighbours
            # along dim 1 (a min over dim 1 follows) -- TODO confirm.
            n2c = self.neights2channels(
                self.g_map[:, :, ymin - 1 : ymax + 1, xmin - 1 : xmax + 1]
            )
            self.g_map[:, :, ymin:ymax, xmin:xmax] = torch.min(
                self.g_map[:, :, ymin:ymax, xmin:xmax].clone(),
                (n2c + c_map[:, :, ymin:ymax, xmin:xmax]).min(
                    dim=1, keepdim=True
                )[0],
            )
            self.close_list_map[:, :, ymin:ymax, xmin:xmax] = torch.max(
                self.close_list_map[:, :, ymin:ymax, xmin:xmax],
                self.open_list_map[:, :, ymin:ymax, xmin:xmax],
            )
            # Grow the open set by one cell (3x3 max pool over the
            # margin-extended window), then remove closed and obstacle cells.
            self.open_list_map[:, :, ymin:ymax, xmin:xmax] = F.relu(
                F.max_pool2d(
                    self.open_list_map[
                        :, :, ymin - 1 : ymax + 1, xmin - 1 : xmax + 1
                    ],
                    3,
                    stride=1,
                    padding=0,
                )
                - self.close_list_map[:, :, ymin:ymax, xmin:xmax]
                - self.obstacles[:, :, ymin:ymax, xmin:xmax]
            )
        else:
            # Global update: same three steps on the full map, with
            # replicate padding supplying the one-cell margin.
            self.g_map = torch.min(
                self.g_map,
                (
                    self.neights2channels(
                        F.pad(self.g_map, (1, 1, 1, 1), "replicate")
                    )
                    + c_map
                ).min(dim=1, keepdim=True)[0],
            )
            self.close_list_map = torch.max(
                self.close_list_map, self.open_list_map
            )
            self.open_list_map = F.relu(
                F.max_pool2d(self.open_list_map, 3, stride=1, padding=1)
                - self.close_list_map
                - self.obstacles
            )
        step += 1
        if step >= self.max_steps:
            stopped_by_max_iter = True
            break
        not_done = (
            self.close_list_map.view(-1)[goal_idx].item() < 1.0
        ) or (self.g_map.view(-1)[goal_idx].item() >= 0.1 * self.inf)
        rad += 1
    if not stopped_by_max_iter:
        for i in range(additional_steps):
            # now propagating beyond start point
            self.g_map = torch.min(
                self.g_map,
                (
                    self.neights2channels(
                        F.pad(self.g_map, (1, 1, 1, 1), "replicate")
                    )
                    + c_map
                ).min(dim=1, keepdim=True)[0],
            )
            self.close_list_map = torch.max(
                self.close_list_map, self.open_list_map
            )
            self.open_list_map = F.relu(
                F.max_pool2d(self.open_list_map, 3, stride=1, padding=1)
                - self.close_list_map
                - self.obstacles
            )
    if return_path:
        out_path, cost = self.reconstruct_path()
        return out_path, cost
    # Falls through with an implicit None when no path is requested.
    return
def forward(self, x, pad=True):
    """Optionally left-pad the last dimension of ``x``, then pool it.

    Args:
        x: Tensor handed to ``self.pool``.
        pad: When truthy, ``self.padding`` zeros are prepended along the
            last dimension before pooling; nothing is added on the right.

    Returns:
        The result of ``self.pool`` on the (possibly padded) input.
    """
    pool_input = F.pad(x, (self.padding, 0)) if pad else x
    return self.pool(pool_input)
def forward(self, x):
    """Append ``self.add_channels`` zero-filled slices, then pool.

    The padding spec only touches the third-from-last dimension (the
    channel axis for a 4-D NCHW input), adding zeros after the existing
    entries; the result is passed through ``self.pooling``.
    """
    channel_padding = (0, 0, 0, 0, 0, self.add_channels)
    widened = F.pad(x, channel_padding)
    return self.pooling(widened)
def forward(self, x):
    """Zero-pad and convolve ``x`` with dim 0 temporarily moved to the end.

    The input axes are reordered from (0, 1, 2) to (1, 2, 0), padded with
    zeros according to ``self.pad``, run through ``self.conv`` and then
    reordered back. The returned tensor is made contiguous.
    """
    # presumably x is (time, batch, channels) and self.conv a 1-D
    # convolution over time -- TODO confirm against the class definition.
    swapped = x.transpose(0, 1)
    conv_input = F.pad(swapped.transpose(1, 2), pad=self.pad, value=0)
    conv_output = self.conv(conv_input)
    restored = conv_output.transpose(1, 2).transpose(0, 1)
    return restored.contiguous()
def __init__(self):
    """Assemble the experiment configuration and log it to TensorBoard.

    Creates the model, the summary writer, the classification and
    augmentation losses, one SGD optimizer + StepLR scheduler pair for each
    of them, and the MNIST train/validation loaders. Every instance
    attribute set here is finally written to the summary writer as text.
    """
    # self.name = "class_sp_inter_ba"
    self.name = "class_sp_inter_ba_mult_long"
    self.model = PseudoMultiTaskNetMult()
    if CUDA:
        self.model.cuda(0)
    self.writer = SummaryWriter()
    self.epoch_function = epoch_mixed
    # Keyword arguments collected for `epoch_function`
    # (presumably consumed by the training loop -- TODO confirm).
    self.epoch_args = {}
    self.epochs = 50
    # self.lr = 0.1
    self.lr = 1e-3
    self.momentum = 0.9
    self.weight_decay = 1e-4
    self.step_size = self.epochs // 3
    self.gamma = 0.1

    def loss(outputs, labels):
        """Return ``(loss, predicted_labels)`` for the classification task.

        The loss is the NLL of the log of the mean of ``outputs[0..2]``,
        plus 0.01 x the summed absolute difference between ``outputs[0]``
        and ``outputs[1]``, plus 0.0001 x the L1 norm of ``outputs[3]``.
        """
        likelihoods = torch.log((outputs[0] + outputs[1] + outputs[2]) / 3)
        disagreement = torch.sum(torch.abs(outputs[0] - outputs[1]))
        sparsity = torch.norm(outputs[3], 1)
        _, y_pred = torch.max(likelihoods.data, 1)
        return (F.nll_loss(likelihoods, labels) + 0.01 * disagreement +
                0.0001 * sparsity, y_pred)

    self.epoch_args["loss_fn"] = loss
    # L1 loss against `y` after zero-padding `x` by 2 on each side of its
    # last two dimensions.
    self.epoch_args["aug_loss_fn"] = lambda x, y: F.l1_loss(
        F.pad(x, (2, 2, 2, 2)), y)
    # Note: the optimizer is created with 10x the nominal `self.lr`.
    self.epoch_args["optimizer"] = optim.SGD(
        self.model.parameters(),
        lr=10 * self.lr,
        momentum=self.momentum,
        weight_decay=self.weight_decay)
    self.epoch_args["scheduler"] = StepLR(self.epoch_args["optimizer"],
                                          step_size=self.step_size,
                                          gamma=self.gamma)
    self.aug_batch_size = 16
    self.batch_size = 16 * self.aug_batch_size
    # self.aug_lr = 0.01
    self.aug_lr = 1e-3
    # Second optimizer over the same parameters, used with the augmentation
    # loss; its scheduler steps every epochs // 4 instead of epochs // 3.
    self.epoch_args["aug_optimizer"] = optim.SGD(
        self.model.parameters(),
        lr=10 * self.aug_lr,
        momentum=self.momentum,
        weight_decay=self.weight_decay)
    self.epoch_args["aug_scheduler"] = StepLR(
        self.epoch_args["aug_optimizer"],
        step_size=self.epochs // 4,
        gamma=self.gamma)
    # self.train_loader = load_augmented(self.batch_size,
    #                                    self.aug_batch_size)
    self.train_loader = load_mnist(self.batch_size, train=True)
    self.val_loader = load_mnist(self.batch_size, train=False)
    # Dump every attribute set above (in assignment order) as text so the
    # run's configuration is visible in TensorBoard.
    for key, value in self.__dict__.items():
        self.writer.add_text(f"config/{key}", str(value))
def from_tensors(
    tensors: Sequence[torch.Tensor],
    size_divisibility: int = 0,
    pad_ref_long: bool = False,
    pad_value: float = 0.0,
) -> "ImageList":
    """Batch tensors of different spatial sizes into one padded ``ImageList``.

    Args:
        tensors: a tuple or list of `torch.Tensors`, each of shape (Hi, Wi) or
            (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded
            at the bottom/right with `pad_value` so that they all have the
            same shape. All leading (non-spatial) dimensions must match.
        size_divisibility (int): If `size_divisibility > 0`, also adds padding
            to ensure the common height and width are divisible by
            `size_divisibility`.
        pad_ref_long (bool): If True, the common height and width are both set
            to the larger of the two, i.e. images are padded onto a square
            canvas (before `size_divisibility` is applied).
        pad_value (float): value to pad with.

    Returns:
        an `ImageList` holding the batched tensor and the original
        (H, W) size of every input.
    """
    assert len(tensors) > 0
    assert isinstance(tensors, (tuple, list))
    for t in tensors:
        assert isinstance(t, torch.Tensor), type(t)
        # Compare ALL leading dims. The previous `[1:-2]` slice was empty for
        # (C, H, W) inputs, so a channel mismatch was never detected.
        assert t.shape[:-2] == tensors[0].shape[:-2], t.shape

    # per dimension maximum (H, W) or (C_1, ..., C_K, H, W) where K >= 1
    # among all tensors
    max_size = [max(dims) for dims in zip(*(img.shape for img in tensors))]

    if pad_ref_long:
        # Square canvas: use the longer spatial side for both H and W.
        max_size[-2:] = [max(max_size[-2:])] * 2

    if size_divisibility > 0:
        stride = size_divisibility
        # Round H and W up to the next multiple of `stride` (integer ceil).
        max_size[-2] = (max_size[-2] + stride - 1) // stride * stride
        max_size[-1] = (max_size[-1] + stride - 1) // stride * stride

    max_size = tuple(max_size)
    image_sizes = [im.shape[-2:] for im in tensors]

    if len(tensors) == 1:
        # This seems slightly (2%) faster.
        # TODO: check whether it's faster for multiple images as well
        image_size = image_sizes[0]
        padded = F.pad(
            tensors[0],
            [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]],
            value=pad_value,
        )
        batched_imgs = padded.unsqueeze_(0)
    else:
        batch_shape = (len(tensors), ) + max_size
        batched_imgs = tensors[0].new_full(batch_shape, pad_value)
        # Copy every image into the top-left corner of its padded slot.
        for img, pad_img in zip(tensors, batched_imgs):
            pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)

    return ImageList(batched_imgs.contiguous(), image_sizes)
def my_pad(_t: Tensor, pad: List[int]) -> Tensor:
    """Pad ``_t`` as if its dimensions 1 and 2 were swapped.

    ``pad`` follows the ``F.pad`` convention (pairs starting from the last
    dimension). Because the tensor is transposed before and after padding,
    the first pair of ``pad`` ends up applied to dimension 1 of ``_t``.
    """
    swapped = _t.transpose(1, 2)
    padded = F.pad(swapped, pad=pad)
    return padded.transpose(1, 2)
def forward(self, x):
    """Replicate-pad the right/bottom edges of ``x`` and max-pool it.

    ``self.pad`` columns and rows (copies of the border values) are added
    on the right and bottom. The pooling window is ``self.kernel_size``
    and the pooling stride is ``self.pad``.
    """
    extended = F.pad(x, (0, self.pad, 0, self.pad), mode="replicate")
    return F.max_pool2d(extended, self.kernel_size, self.pad)
def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config):
    """Build the per-image training targets for the detection heads.

    Pipeline
        1. Remove the zero padding from the proposals.
        2. Separate crowd and non-crowd GT boxes.
        3. Compute the overlaps (IoU) between proposals and GT boxes.
        4. Select positive / negative ROIs (IoU, counts, ratio).
        5. Assign a GT box and GT class id to every positive ROI.
        6. Compute the box refinement (delta) targets.
        7. Compute the mask targets.
        8. Concatenate positive and negative ROIs and zero-pad the targets.

    Matching criteria, as stated by the original author (note that the
    currently active ``method2`` simply takes the top-k ROIs by IoU as
    positives):
        1. From the proposal's point of view: IoU >= 0.7 with some GT box is
           positive, IoU < 0.5 with every GT box is negative, in between is
           neutral.
        2. From the GT box's point of view: the closest proposal of every GT
           box must be positive, so every GT box has a match.

    Inputs (shapes as documented by the original author):
        proposals: [N, (y1, x1, y2, x2)] normalized, zero padded. Variable.
        gt_class_ids: [MAX_GT_INSTANCES] class labels. Variable.
        gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2)] NOT normalized, no
            zero padding. Variable.
        gt_masks: [height, width, MAX_GT_INSTANCES] of boolean type. Variable.
        config: object providing IMAGE_SHAPE, TRAIN_ROIS_PER_IMAGE,
            ROIS_POSITIVE_RATIO, ROIS_GTBOX_IOU, BBOX_STD_DEV, USE_MINI_MASK
            and MASK_SHAPE.

    Returns: target ROIs and corresponding class IDs, bounding box shifts
    and masks.
        rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized
            coordinates.
        class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs. Zero padded.
        deltas: [TRAIN_ROIS_PER_IMAGE, NUM_CLASSES,
            (dy, dx, log(dh), log(dw))] class-specific bbox refinements.
        masks: [TRAIN_ROIS_PER_IMAGE, height, width]. Masks cropped to the
            bbox boundaries and resized to the network output size given in
            the config.

    Note: outputs are zero padded when there are not enough target ROIs.
    The original author assumes MAX_GT_INSTANCES < TRAIN_ROIS_PER_IMAGE.
    """
    assert proposals.size(0) > 0, '当前的proposal是空的!'

    # 0. Normalize the GT box coordinates by the image height / width.
    h, w = config.IMAGE_SHAPE[0:2]
    scale = Variable(torch.from_numpy(np.array([h, w, h, w])).float(),
                     requires_grad=False).cuda()
    gt_boxes = gt_boxes / scale

    # 1. Remove the zero padding from the proposals.
    proposals = utils.trim_zeros(proposals)

    # 2. Separate crowd GT from non-crowd GT.
    # Handle COCO crowds: a crowd box in COCO is a bounding box around
    # several instances. It is given a negative class ID and is excluded
    # from training.
    crowd_ix = np.where(gt_class_ids.data < 0)[0]
    # `crowd_ix` is a NumPy index array here: test its size explicitly. Plain
    # truthiness raises for more than one crowd box and is False when the
    # only crowd box sits at index 0.
    if crowd_ix.size > 0:
        crowd_ix = Variable(torch.from_numpy(crowd_ix).long().cuda(),
                            requires_grad=False)
        crowd_gt_boxes = torch.gather(
            gt_boxes, dim=0,
            index=crowd_ix.unsqueeze(-1).expand(crowd_ix.size(0), 4))
    else:
        crowd_gt_boxes = Variable(torch.FloatTensor([])).cuda()

    non_crowd_ix = Variable(torch.from_numpy(
        np.where(gt_class_ids.data > 0)[0]).long().cuda(),
        requires_grad=False)
    gt_class_ids = torch.gather(gt_class_ids, dim=0, index=non_crowd_ix)
    gt_boxes = torch.gather(
        gt_boxes, dim=0,
        index=non_crowd_ix.unsqueeze(-1).expand(non_crowd_ix.size(0), 4))
    gt_masks = torch.gather(
        gt_masks, dim=2,
        index=non_crowd_ix.unsqueeze(0).unsqueeze(0).expand(
            gt_masks.shape[0:2] + (non_crowd_ix.size(0), )))
    crowd_ix, non_crowd_ix = None, None

    # 3. Compute the overlaps matrix [proposals, gt_boxes].
    overlaps = utils.bbox_overlaps(proposals, gt_boxes)  # shape: N x K
    if crowd_gt_boxes.numel() > 0:
        crowd_overlaps = utils.bbox_overlaps(proposals, crowd_gt_boxes)
        crowd_iou_max = torch.max(crowd_overlaps, dim=1)[0].data
        no_crowd_bool = (crowd_iou_max < 0.001)  # shape: N
    else:
        no_crowd_bool = torch.ones(proposals.shape[0]).byte().cuda()
    crowd_overlaps, crowd_iou_max = None, None

    # 4. Determine positive and negative ROIs.
    # method1: given counts and max/min thresholds, first filter by the
    #          thresholds, then randomly sample positive/negative ROIs. Keeps
    #          ROIs within the thresholds but cannot guarantee the counts and
    #          needs the thresholds to be set by hand.
    # method2: given counts and a min threshold, take the top-k
    #          (positive_count) ROIs by IoU as positives and randomly sample
    #          negatives from the rest. Positives are not guaranteed to meet
    #          a max threshold, but that parameter is no longer needed. There
    #          are usually plenty of potential negatives.
    # method3: another implementation of method1.
    # Negatives should be background as far as possible: IoU < 0.5 with the
    # non-crowd boxes and IoU < 0.001 with the crowd boxes, because a crowd
    # box contains several whole objects.
    method1, method2, method3 = (False, True, False)
    if method1:
        # Best IoU of every proposal over all GT boxes.
        roi_iou_max = torch.max(overlaps, dim=1)[0].data  # shape: N
        # 4.1 Positive ROIs: best IoU >= ROIS_GTBOX_IOU[0].
        positive_indices = np.where(
            (roi_iou_max >= config.ROIS_GTBOX_IOU[0]))[0]
        positive_indices = torch.from_numpy(positive_indices).long().cuda()
        # 4.2 Negative ROIs: best IoU < ROIS_GTBOX_IOU[1]. Skip crowds.
        negative_indices = np.where(
            np.logical_and(roi_iou_max < config.ROIS_GTBOX_IOU[1],
                           no_crowd_bool))[0]
        negative_indices = torch.from_numpy(negative_indices).long().cuda()
        # Subsample the positives; fewer than requested may be available, so
        # the real count is read back from the sliced tensor.
        positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
                             config.ROIS_POSITIVE_RATIO)
        positive_indices = positive_indices[list(
            torch.randperm(positive_indices.numel()))][:positive_count]
        positive_count = positive_indices.shape[0]
        # Add enough negatives to keep the positive:negative ratio. The final
        # number of ROIs honours the ratio but not necessarily the total.
        r = 1.0 / config.ROIS_POSITIVE_RATIO
        negative_count = math.floor((r - 1) * positive_count)
        negative_indices = negative_indices[list(
            torch.randperm(negative_indices.numel()))][:negative_count]
        roi_iou_max = None
    elif method2:
        roi_iou_max = torch.max(overlaps, dim=1)[0].data  # shape: N
        roi_iou_max, roi_iou_ind = torch.sort(roi_iou_max, dim=0,
                                              descending=True)
        positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
                             config.ROIS_POSITIVE_RATIO)
        positive_indices = roi_iou_ind[0:positive_count]
        # NOTE(review): indexing with `positive_count` assumes there are more
        # proposals than requested positives -- confirm with callers.
        print('roi_iou_max max & mid & max : %s, %s, %s' %
              (roi_iou_max[0], roi_iou_max[positive_count], roi_iou_max[-1]))
        roi_iou_ind = roi_iou_ind[positive_count:]
        roi_iou_max = roi_iou_max[positive_count:]
        negative_count = config.TRAIN_ROIS_PER_IMAGE - positive_count
        negative_indices = roi_iou_ind[
            (roi_iou_max < config.ROIS_GTBOX_IOU[1]) &
            no_crowd_bool[roi_iou_ind]]
        print('negative_indices.numel()---@rois_target()--:',
              negative_indices.numel())
        print('proposals numbers : %s' % overlaps.shape[0])
        # Author's observations (todo):
        # hot point + hot proposal: very often all overlaps are > 0.77, which
        #     leaves no negative samples.
        # hot point + random proposal: unstable.
        # general point + hot proposal: values fairly normal, runs through.
        # general point + random proposal: values normal, no bugs seen.
        negative_indices = negative_indices[torch.randperm(
            negative_indices.numel())[0:negative_count].cuda()]
        roi_iou_max, roi_iou_ind, index = None, None, None
    elif method3:
        roi_iou_max = torch.max(overlaps, dim=1)[0].data  # shape: N
        roi_iou_ind = torch.arange(
            0, roi_iou_max.shape[0]).long().cuda()  # shape: N
        positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
                             config.ROIS_POSITIVE_RATIO)
        positive_indices = roi_iou_ind[
            roi_iou_max > config.ROIS_GTBOX_IOU[0]]
        positive_indices = positive_indices[torch.randperm(
            positive_indices.numel())[0:positive_count].cuda()]
        negative_count = (config.TRAIN_ROIS_PER_IMAGE -
                          positive_indices.shape[0])
        negative_indices = roi_iou_ind[
            (roi_iou_max < config.ROIS_GTBOX_IOU[1]) &
            no_crowd_bool[roi_iou_ind]]
        negative_indices = negative_indices[torch.randperm(
            negative_indices.numel())[0:negative_count].cuda()]
        roi_iou_max = None
    else:
        raise ValueError('wrongt method!')

    # Gather the selected ROIs. Index shape: N -> N x 1 -> N x 4.
    index = Variable(positive_indices.unsqueeze(-1).expand(
        positive_indices.size(0), proposals.size(-1)), requires_grad=False)
    positive_rois = torch.gather(proposals, dim=0, index=index)
    index = Variable(negative_indices.unsqueeze(-1).expand(
        negative_indices.size(0), proposals.size(-1)), requires_grad=False)
    negative_rois = torch.gather(proposals, dim=0, index=index)

    # 5. Assign positive ROIs to GT boxes: take the max along the GT axis
    # (dim=1) to get the best GT box / class id of every positive ROI.
    # Index shape: N -> N x 1 -> N x gt_boxes_count.
    index = Variable(positive_indices.unsqueeze(-1).expand(
        positive_indices.size(0), overlaps.size(-1)), requires_grad=False)
    positive_overlaps = torch.gather(overlaps, dim=0, index=index)  # N x K
    roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1].long()  # N
    # N -> N x 1 -> N x 4
    index = roi_gt_box_assignment.unsqueeze(-1).expand(
        roi_gt_box_assignment.size(0), gt_boxes.size(-1))
    roi_gt_boxes = torch.gather(gt_boxes, dim=0, index=index)
    roi_gt_class_ids = torch.gather(gt_class_ids, dim=0,
                                    index=roi_gt_box_assignment)
    gt_boxes, index = None, None

    # 6. Compute the bbox refinement targets for the positive ROIs.
    deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes)
    deltas /= Variable(torch.from_numpy(
        config.BBOX_STD_DEV).float().cuda(), requires_grad=False)

    # 7. Pick the GT mask of every positive ROI and compute mask targets:
    # the part of the GT mask that lies inside the ROI box is cropped and
    # resized to the configured shape. The result is used as target for a
    # binary cross-entropy loss.
    # Permute masks from [h, w, n] to [N, channel, height, width].
    transposed_masks = gt_masks.permute(2, 0, 1).unsqueeze(1)
    gt_masks = None
    # Pick the right mask for each ROI.
    index = roi_gt_box_assignment.unsqueeze(-1).unsqueeze(-1).unsqueeze(
        -1)  # N -> N x 1 x 1 x 1
    index = index.expand((roi_gt_box_assignment.size(0), ) +
                         transposed_masks.shape[1:])  # -> N x 1 x H x W
    roi_masks = torch.gather(transposed_masks, dim=0, index=index)
    transposed_masks, roi_gt_box_assignment, index = None, None, None

    # Compute mask targets: cut out the mask region matching each ROI.
    boxes = positive_rois
    if config.USE_MINI_MASK:
        # Transform ROI coordinates from normalized image space to
        # normalized mini-mask space.
        y1, x1, y2, x2 = torch.split(positive_rois, 1, dim=1)
        gt_y1, gt_x1, gt_y2, gt_x2 = torch.split(roi_gt_boxes, 1, dim=1)
        gt_h = gt_y2 - gt_y1
        gt_w = gt_x2 - gt_x1
        y1 = (y1 - gt_y1) / gt_h
        x1 = (x1 - gt_x1) / gt_w
        y2 = (y2 - gt_y1) / gt_h
        x2 = (x2 - gt_x1) / gt_w
        boxes = torch.cat([y1, x1, y2, x2], 1)
    box_ids = Variable(torch.arange(0, roi_masks.size(0)),
                       requires_grad=False).int().cuda()
    # Crop `boxes` out of roi_masks and resize to config.MASK_SHAPE:
    # crfuc(N x 1 x H x W, N x 4, N) -> N x 1 x mask_h x mask_w.
    crfuc = CropAndResizeFunction(config.MASK_SHAPE[0],
                                  config.MASK_SHAPE[1], 0)
    masks = Variable(crfuc(roi_masks, boxes, box_ids).data,
                     requires_grad=False)
    roi_masks, box_ids = None, None
    # Remove the extra dimension from masks.
    masks = torch.squeeze(masks, dim=1)
    # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
    # binary cross entropy loss.
    masks = torch.round(masks)

    # 8. Append negative ROIs and pad bbox deltas and masks that are not
    # used for negative ROIs with zeros.
    rois = torch.cat([positive_rois, negative_rois], dim=0)  # shape: N x 4
    P = max(config.TRAIN_ROIS_PER_IMAGE - rois.size(0), 0)
    N = negative_rois.size(0)
    rois = F.pad(rois, (0, 0, 0, P))
    roi_gt_class_ids = F.pad(roi_gt_class_ids, (0, N + P))
    deltas = F.pad(deltas, (0, 0, 0, N + P))
    masks = F.pad(masks, (0, 0, 0, 0, 0, N + P))
    return rois, roi_gt_class_ids, deltas, masks
def loss_fn(func_param, labels: Variable):
    """Mean CRPS-style loss of a piecewise-quadratic (spline) function.

    The four entries of ``func_param`` are combined into a function of
    ``alpha`` of the form ``b0 + b1 * alpha + sum_k beta_k * max(alpha -
    ksi_k, 0) ** 2`` with knots ``ksi`` given by the shifted cumulative sum
    of ``sigma``. For every label the ``alpha`` at which the function equals
    the label is solved for, and the closed-form terms ``crps_1..crps_4``
    are combined and averaged over the batch.

    Args:
        func_param: Tuple ``(beta_0, beta_2k, sigma, gamma)``. From the
            indexing used below: ``beta_2k`` has at least two columns and
            ``sigma`` / ``gamma`` are 2-D with matching shapes
            (assumed (batch,), (batch, 2), (batch, K), (batch, K) --
            TODO confirm with the model that produces them).
        labels: Target values, one per batch row.

    Returns:
        A 0-dim tensor holding the batch mean of the loss.
    """
    beta_0, beta_2k, sigma, gamma = func_param

    # Coefficients of the linear part: column 0 is the intercept, column 1
    # the slope.
    beta_1 = gamma[:, 0] - 2 * beta_2k[:, 0] * sigma[:, 0]
    beta_N = pad(torch.unsqueeze(beta_1, dim=1), (1, 0))
    beta_N[:, 0] = beta_0

    # Coefficients of the quadratic pieces (first differences, with the
    # first and last entries fixed by beta_2k).
    beta = (gamma - pad(gamma, (1, 0))[:, :-1]) / (2 * sigma)
    beta[:, 0] = beta_2k[:, 0]
    beta = beta - pad(beta, (1, 0))[:, :-1]
    beta[:, -1] = beta_2k[:, 1] - beta[:, :-1].sum(dim=1)

    # calculate the maximum for each segment of the spline
    ksi = torch.cumsum(sigma, dim=1)
    # Reverse the three axes explicitly: `.T` on a 3-D tensor is deprecated
    # in PyTorch and will become an error.
    df1 = ksi.expand(sigma.shape[1], sigma.shape[0],
                     sigma.shape[1]).permute(2, 1, 0).clone()
    df2 = pad(ksi.T.unsqueeze(2), (1, 0), 'constant', value=1)
    ksi = pad(ksi, (1, 0))[:, :-1]
    knots = df1 - ksi
    knots[knots < 0] = 0
    knots = (df2 * beta_N).sum(dim=2) + (knots.pow(2) * beta).sum(dim=2)
    knots = pad(knots.T, (1, 0))[:, :-1]  # F(ksi_1~K)=0~max

    # Pieces that are active for each label, and the resulting quadratic
    # A * alpha^2 + B * alpha + C = 0.
    diff = labels.view(-1, 1) - knots
    alpha_l = diff > 0
    alpha_A = torch.sum(alpha_l * beta, dim=1)
    alpha_B = beta_N[:, 1] - 2 * torch.sum(alpha_l * beta * ksi, dim=1)
    alpha_C = beta_N[:, 0] - labels + torch.sum(alpha_l * beta * ksi * ksi,
                                                dim=1)

    # since A may be zero, roots can be from different methods.
    not_zero = (alpha_A != 0)
    alpha = torch.zeros_like(alpha_A)
    # since there may be numerical calculation error: rows with a negative
    # discriminant fall back to the knot closest to the label.
    idx = (alpha_B**2 - 4 * alpha_A * alpha_C) < 0
    diff = diff.abs()
    index = diff == (diff.min(dim=1)[0].view(-1, 1))
    index[~idx, :] = False
    alpha[idx] = ksi[index]
    # Linear case (A == 0).
    alpha[~not_zero] = -alpha_C[~not_zero] / alpha_B[~not_zero]
    # Regular quadratic case: A != 0 and non-negative discriminant.
    not_zero = ~(~not_zero | idx)
    delta = alpha_B[not_zero].pow(
        2) - 4 * alpha_A[not_zero] * alpha_C[not_zero]
    alpha[not_zero] = (-alpha_B[not_zero] +
                       torch.sqrt(delta)) / (2 * alpha_A[not_zero])

    crps_1 = labels * (2 * alpha - 1)
    crps_2 = beta_N[:, 0] * (1 - 2 * alpha) + beta_N[:, 1] * (1 / 3 -
                                                              alpha.pow(2))
    crps_3 = torch.sum(2 * beta / ((2 + 1) * (2 + 2)) *
                       (1 - ksi).pow(2 + 2),
                       dim=1)
    crps_4 = torch.sum(alpha_l * 2 * beta / (2 + 1) *
                       (torch.unsqueeze(alpha, 1) - ksi).pow(2 + 1),
                       dim=1)
    crps = crps_1 + crps_2 + crps_3 - crps_4

    crps = torch.mean(crps)
    return crps
def __init__(self, shape, nb_class, channel, kernel_size, nb_pixel_block,
             nb_res_block, res_channel, dropout=0.1, nb_cond_res_block=0,
             cond_res_channel=0, cond_res_kernel=3, nb_out_res_block=0,
             cond_interpolate=1, attention=True):
    """Build the layers of the model.

    Creates two causal input convolutions (each followed by a 2-D batch
    norm), a fixed two-channel coordinate buffer ``bg``, ``nb_pixel_block``
    pixel blocks, an optional conditioning network (only when
    ``nb_cond_res_block > 0``) and the output head. Two shift helpers,
    ``shift_down`` and ``shift_right``, are stored as attributes.

    Args:
        shape: ``(height, width)`` of the input grid.
        nb_class: Number of input channels of the first convolutions and
            number of output channels of the head.
        channel: Channel width of the main path.
        kernel_size: Kernel size of the causal convolutions; must be odd.
        nb_pixel_block: Number of ``PixelBlock`` modules.
        nb_res_block: Passed to every ``PixelBlock``.
        res_channel: Passed to ``PixelBlock`` and the output residual blocks.
        dropout: Dropout value passed to the blocks.
        nb_cond_res_block: Number of residual blocks in the conditioning net.
        cond_res_channel: Channel width of the conditioning net.
        cond_res_kernel: Kernel size of the conditioning net.
        nb_out_res_block: Number of residual blocks in the output head.
        cond_interpolate: Stored on the instance as-is.
        attention: Passed to every ``PixelBlock``.
    """
    super().__init__()

    height, width = shape
    self.nb_class = nb_class
    self.cond_interpolate = cond_interpolate

    assert kernel_size % 2, "Kernel size must be odd"

    half_kernel = kernel_size // 2
    self.horz_conv = CausalConv2d(nb_class, channel,
                                  [half_kernel, kernel_size],
                                  padding='down')
    self.horz_bn = nn.BatchNorm2d(channel)
    self.vert_conv = CausalConv2d(nb_class, channel,
                                  [(kernel_size + 1) // 2, half_kernel],
                                  padding='downright')
    self.vert_bn = nn.BatchNorm2d(channel)

    # Centred, size-normalised coordinates along each spatial axis,
    # broadcast to full (1, 1, height, width) planes.
    row_coord = (torch.arange(height).float() - height / 2) / height
    row_plane = row_coord.view(1, 1, height, 1).expand(1, 1, height, width)
    col_coord = (torch.arange(width).float() - width / 2) / width
    col_plane = col_coord.view(1, 1, 1, width).expand(1, 1, height, width)
    self.register_buffer('bg', torch.cat([row_plane, col_plane], 1))

    self.blks = nn.ModuleList()
    for _ in range(nb_pixel_block):
        self.blks.append(
            PixelBlock(channel, res_channel, kernel_size, nb_res_block,
                       dropout=dropout, condition_dim=cond_res_channel,
                       attention=attention))

    if nb_cond_res_block > 0:
        cond_layers = [
            WNConv2d(nb_class, cond_res_channel, cond_res_kernel,
                     padding=cond_res_kernel // 2)
        ]
        for _ in range(nb_cond_res_block):
            cond_layers.append(
                GatedResBlock(cond_res_channel, cond_res_channel,
                              cond_res_kernel, dropout=dropout))
        self.cond_net = nn.Sequential(*cond_layers)

    head = [
        GatedResBlock(channel, res_channel, 1, dropout=dropout)
        for _ in range(nb_out_res_block)
    ]
    head.append(nn.ELU(inplace=True))
    head.append(WNConv2d(channel, nb_class, 1))
    self.out = nn.Sequential(*head)

    # Shift the content of dim 2 (down) / dim 3 (right) by `size` cells:
    # zeros are inserted at the start and the overflow is cropped off.
    self.shift_down = lambda x, size=1: F.pad(
        x, [0, 0, size, 0])[:, :, :x.shape[2], :]
    self.shift_right = lambda x, size=1: F.pad(
        x, [size, 0, 0, 0])[:, :, :, :x.shape[3]]
def forward(self, x):
    """Zero-pad ``x`` symmetrically and max-pool it with stride 2.

    The padding on each side of the last two dimensions is
    ``floor((kernel_size - 1) / 2)``; the pooling window is
    ``self.kernel_size``.
    """
    margin = int((self.kernel_size - 1) // 2)
    bordered = F.pad(x, (margin,) * 4)
    return F.max_pool2d(bordered, self.kernel_size, stride=2)
def cifar10_loader(args, num_workers=4):
    """Create the CIFAR-10 train and test data loaders.

    Args:
        args: Options object. Reads ``data_aug``, ``data_path``,
            ``use_few_data``, ``batch_size`` and, in few-shot mode,
            ``dataset``, ``K_shot`` and ``seed``.
        num_workers: Worker count given to every ``DataLoader``.

    Returns:
        ``(train_loader, test_loader)``. With ``args.use_few_data`` the
        train loader is the result of ``pseudo_loader`` wrapped around a
        loader over the selected subset.
    """
    normalize = transforms.Normalize(
        mean=[v / 255.0 for v in (125.3, 123.0, 113.9)],
        std=[v / 255.0 for v in (63.0, 62.1, 66.7)])

    if args.data_aug:
        with torch.no_grad():
            # Reflect-pad by 4 pixels per side, then take a random 32x32
            # crop and a random horizontal flip.
            reflect_pad = transforms.Lambda(
                lambda x: F.pad(
                    Variable(x.unsqueeze(0), requires_grad=False),
                    (4, 4, 4, 4), mode='reflect').data.squeeze())
            train_steps = [
                transforms.ToTensor(),
                reflect_pad,
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
            transform_train = transforms.Compose(train_steps)
            transform_test = transforms.Compose(
                [transforms.ToTensor(), normalize])
    else:
        # Without augmentation a second, hard-coded set of statistics is
        # used for normalisation.
        plain_mean = (0.4914, 0.4822, 0.4465)
        plain_std = (0.2023, 0.1994, 0.2010)
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(plain_mean, plain_std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(plain_mean, plain_std),
        ])

    trainset = datasets.CIFAR10(root=args.data_path, train=True,
                                download=True, transform=transform_train)
    testset = datasets.CIFAR10(root=args.data_path, train=False,
                               transform=transform_test)

    loader_options = {"num_workers": num_workers, "pin_memory": True}
    if not args.use_few_data:
        train_loader = DataLoader(trainset, batch_size=args.batch_size,
                                  shuffle=True, **loader_options)
    else:
        # Few-shot mode: keep K_shot samples for each of the 10 classes and
        # load them as a single batch before handing over to pseudo_loader.
        indices = classwise_get_indices(args.dataset, trainset, args.K_shot,
                                        10, args.seed)
        trainset = Subset(trainset, indices=indices)
        few_shot_num = args.K_shot * 10
        base_loader = DataLoader(trainset, batch_size=few_shot_num,
                                 shuffle=False, **loader_options)
        train_loader = pseudo_loader(base_loader, args.batch_size,
                                     few_shot_num, shuffle=False)

    test_loader = DataLoader(testset, batch_size=args.batch_size,
                             shuffle=False, **loader_options)
    return train_loader, test_loader
def neg_log_likelihood_loss(self, x, y):
    """Return ``crf.partition(potentials) - crf.score(potentials, y)``.

    The potentials are the output of ``self.bilstm(x)`` extended by two
    extra entries at the end of the last dimension, filled with
    ``LOW_POT``.
    """
    # presumably the two extra slots are start/stop tags that must never be
    # predicted, hence the low potential -- TODO confirm against the CRF class.
    potentials = F.pad(self.bilstm(x), (0, 2), 'constant', LOW_POT)
    log_partition = self.crf.partition(potentials)
    return log_partition - self.crf.score(potentials, y)
def forward(self, x):
    """Run the multi-scale detector and return box offsets and class scores.

    The backbone is eight conv -> batch-norm -> ReLU stages (``conv2d_0`` to
    ``conv2d_7``) separated by 2x2 max pools. The activations of stages 3-7
    each feed a classification head and a localisation head; every head is
    an "insert" conv/BN/ReLU followed by an output conv.

    Args:
        x: 4-D input tensor (assumed ``(N, C, H, W)``; the explicit 4-value
            pads act on the last two dimensions).

    Returns:
        ``(loc, cls)`` where ``loc`` has shape ``(N, A, 4)`` and ``cls`` has
        shape ``(N, A, 2)`` with a sigmoid applied. ``A`` is the sum over
        the five heads of the per-head prediction count, in head order 0-4.
    """
    one_px = (1, 1, 1, 1)

    def conv_bn_relu(feat, conv, bn, pad=True):
        # Optional explicit 1-pixel zero pad, then conv -> BN -> ReLU.
        if pad:
            feat = F.pad(feat, one_px)
        return F.relu(bn(conv(feat)))

    def pool(feat, pad_odd=False):
        # 2x2/stride-2 max pool. With pad_odd, one row/column of -inf is
        # appended bottom/right first so odd sizes round up instead of
        # dropping the last row/column (the -inf never wins the max).
        if pad_odd:
            feat = F.pad(feat, (0, 1, 0, 1), value=float('-inf'))
        return F.max_pool2d(feat, kernel_size=(2, 2), stride=(2, 2),
                            padding=0, ceil_mode=False)

    def head(feat, insert_conv, insert_bn, out_conv, width):
        # insert conv/BN/ReLU -> output conv -> channels-last -> (N, -1, width)
        hidden = conv_bn_relu(feat, insert_conv, insert_bn)
        raw = out_conv(F.pad(hidden, one_px))
        return torch.reshape(raw.permute(0, 2, 3, 1),
                             (raw.size(0), -1, width))

    # ---- backbone -------------------------------------------------------
    feat = pool(conv_bn_relu(x, self.conv2d_0, self.conv2d_0_bn))
    feat = pool(conv_bn_relu(feat, self.conv2d_1, self.conv2d_1_bn))
    feat = conv_bn_relu(feat, self.conv2d_2, self.conv2d_2_bn)
    feat_3 = conv_bn_relu(pool(feat, pad_odd=True),
                          self.conv2d_3, self.conv2d_3_bn)
    feat_4 = conv_bn_relu(pool(feat_3, pad_odd=True),
                          self.conv2d_4, self.conv2d_4_bn)
    feat_5 = conv_bn_relu(pool(feat_4, pad_odd=True),
                          self.conv2d_5, self.conv2d_5_bn)
    feat_6 = conv_bn_relu(pool(feat_5, pad_odd=True),
                          self.conv2d_6, self.conv2d_6_bn)
    # conv2d_7 is the only stage applied without explicit padding.
    feat_7 = conv_bn_relu(feat_6, self.conv2d_7, self.conv2d_7_bn, pad=False)

    # ---- detection heads ------------------------------------------------
    cls_heads = (
        (feat_3, self.cls_0_insert_conv2d, self.cls_0_insert_conv2d_bn,
         self.cls_0_conv),
        (feat_4, self.cls_1_insert_conv2d, self.cls_1_insert_conv2d_bn,
         self.cls_1_conv),
        (feat_5, self.cls_2_insert_conv2d, self.cls_2_insert_conv2d_bn,
         self.cls_2_conv),
        (feat_6, self.cls_3_insert_conv2d, self.cls_3_insert_conv2d_bn,
         self.cls_3_conv),
        (feat_7, self.cls_4_insert_conv2d, self.cls_4_insert_conv2d_bn,
         self.cls_4_conv),
    )
    loc_heads = (
        (feat_3, self.loc_0_insert_conv2d, self.loc_0_insert_conv2d_bn,
         self.loc_0_conv),
        (feat_4, self.loc_1_insert_conv2d, self.loc_1_insert_conv2d_bn,
         self.loc_1_conv),
        (feat_5, self.loc_2_insert_conv2d, self.loc_2_insert_conv2d_bn,
         self.loc_2_conv),
        (feat_6, self.loc_3_insert_conv2d, self.loc_3_insert_conv2d_bn,
         self.loc_3_conv),
        (feat_7, self.loc_4_insert_conv2d, self.loc_4_insert_conv2d_bn,
         self.loc_4_conv),
    )

    loc_outputs = [head(*spec, 4) for spec in loc_heads]
    # torch.sigmoid replaces the deprecated F.sigmoid; same computation.
    cls_outputs = [torch.sigmoid(head(*spec, 2)) for spec in cls_heads]

    loc_branch_concat = torch.cat(loc_outputs, 1)
    cls_branch_concat = torch.cat(cls_outputs, 1)
    return loc_branch_concat, cls_branch_concat
Normal(posterior_means, posterior_std_devs), global_prior).sum(dim=2).mean(dim=(0, 1)) # Calculate latent overshooting objective for t > 0 if args.overshooting_kl_beta != 0: overshooting_vars = [ ] # Collect variables for overshooting to process in batch for t in range(1, args.chunk_size - 1): d = min(t + args.overshooting_distance, args.chunk_size - 1) # Overshooting distance t_, d_ = t - 1, d - 1 # Use t_ and d_ to deal with different time indexing for latent states seq_pad = ( 0, 0, 0, 0, 0, t - d + args.overshooting_distance ) # Calculate sequence padding so overshooting terms can be calculated in one batch # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks overshooting_vars.append( (F.pad(actions[t:d], seq_pad), F.pad(nonterminals[t:d], seq_pad), F.pad(rewards[t:d], seq_pad[2:]), beliefs[t_], prior_states[t_], F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad), F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1), F.pad( torch.ones(d - t, args.batch_size, args.state_size, device=args.device), seq_pad)) ) # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences overshooting_vars = tuple(zip(*overshooting_vars)) # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once) beliefs, prior_states, prior_means, prior_std_devs = transition_model(
def forward(self, x):
    """Max-pool ``x`` after explicitly replicate-padding its borders.

    The border handling is done by ``F.pad`` with ``self.padding`` in
    'replicate' mode, so the pooling call itself uses zero padding.

    Args:
        x: Input tensor to pool.

    Returns:
        The pooled tensor.
    """
    padded = F.pad(x, self.padding, mode='replicate')
    return F.max_pool2d(
        padded,
        kernel_size=self.kernel_size,
        stride=self.stride,
        padding=0,
        dilation=self.dilation,
    )
def forward(self, x):
    """Append ``self.num_zeros`` zero slices to ``x`` and apply ``self.identity``.

    The 6-value pad spec leaves the last two dimensions untouched and adds
    zeros only at the end of the third-from-last dimension (the channel
    axis, assuming an ``(N, C, H, W)`` input).

    Args:
        x: Input tensor with at least three dimensions.

    Returns:
        The result of ``self.identity`` on the padded tensor.
    """
    pad_spec = (0, 0, 0, 0, 0, self.num_zeros)
    return self.identity(F.pad(x, pad_spec))