def reset_parameters(self):
    for cell in self._cells:
        # Xavier initialization, applied to each of the four gates separately
        for weight in [cell.weight_ih, cell.weight_hh]:
            for w in torch.chunk(weight, 4, dim=0):
                init.xavier_normal_(w)
        # forget-gate bias = 1
        for bias in [cell.bias_ih, cell.bias_hh]:
            torch.chunk(bias, 4, dim=0)[1].data.fill_(1)
def forward(self, embeddings, label):
    if self.device_id is None:
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embeddings, kernel_norm)
    else:
        x = embeddings
        sub_kernels = torch.chunk(self.kernel, len(self.device_id), dim=1)
        temp_x = x.cuda(self.device_id[0])
        kernel_norm = l2_norm(sub_kernels[0], axis=0).cuda(self.device_id[0])
        cos_theta = torch.mm(temp_x, kernel_norm)
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            kernel_norm = l2_norm(sub_kernels[i], axis=0).cuda(self.device_id[i])
            cos_theta = torch.cat((cos_theta, torch.mm(temp_x, kernel_norm).cuda(self.device_id[0])), dim=1)
    cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
    phi = cos_theta - self.m
    label = label.view(-1, 1)       # size = (B, 1)
    index = cos_theta.data * 0.0    # size = (B, Classnum)
    index.scatter_(1, label.data.view(-1, 1), 1)
    index = index.byte()
    output = cos_theta * 1.0
    output[index] = phi[index]  # only change the logits of the ground-truth classes
    output *= self.s  # scale up so softmax gradients behave; first introduced in NormFace
    return output
def forward(self, inputs, mask=None, layer_cache=None, step=None):
    """
    Args:
        inputs (FloatTensor): ``(batch_size, input_len, model_dim)``

    Returns:
        (FloatTensor, FloatTensor):

        * gating_outputs ``(batch_size, input_len, model_dim)``
        * average_outputs average attention ``(batch_size, input_len, model_dim)``
    """
    batch_size = inputs.size(0)
    inputs_len = inputs.size(1)
    device = inputs.device
    average_outputs = self.cumulative_average(
        inputs,
        self.cumulative_average_mask(batch_size, inputs_len).to(device).float()
        if layer_cache is None else step,
        layer_cache=layer_cache)
    average_outputs = self.average_layer(average_outputs)
    gating_outputs = self.gating_layer(torch.cat((inputs, average_outputs), -1))
    input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
    gating_outputs = torch.sigmoid(input_gate) * inputs + \
        torch.sigmoid(forget_gate) * average_outputs
    return gating_outputs, average_outputs
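# The cumulative_average / cumulative_average_mask helpers are not shown above.
# A minimal sketch, assuming the average-attention formulation where row t of
# the mask averages positions 0..t; the name and signature mirror the call
# above but are assumptions, not necessarily the author's implementation.
import torch

def cumulative_average_mask(batch_size, inputs_len):
    # Lower-triangular weights: row t holds 1/(t+1) in columns 0..t, so that
    # mask @ inputs yields the running mean over all positions up to t.
    weights = torch.tril(torch.ones(inputs_len, inputs_len))
    weights = weights / torch.arange(1, inputs_len + 1).float().view(-1, 1)
    return weights.unsqueeze(0).expand(batch_size, -1, -1)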
def forward(self, input, label):
    # --------------------------- cos(theta) & phi(theta) ---------------------------
    if self.device_id is None:
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
    else:
        x = input
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)
    sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
    phi = cosine * self.cos_m - sine * self.sin_m
    if self.easy_margin:
        phi = torch.where(cosine > 0, phi, cosine)
    else:
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
    # --------------------------- convert label to one-hot ---------------------------
    one_hot = torch.zeros(cosine.size())
    if self.device_id is not None:
        one_hot = one_hot.cuda(self.device_id[0])
    one_hot.scatter_(1, label.view(-1, 1).long(), 1)
    # ------------- torch.where(out_i = x_i if condition_i else y_i) -------------
    output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # torch.where also works on torch >= 0.4
    output *= self.s
    return output
def forward(self, x):
    if self.device_id is None:
        out = F.linear(x, self.weight, self.bias)
    else:
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        sub_biases = torch.chunk(self.bias, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        bias = sub_biases[0].cuda(self.device_id[0])
        out = F.linear(temp_x, weight, bias)
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            bias = sub_biases[i].cuda(self.device_id[i])
            out = torch.cat((out, F.linear(temp_x, weight, bias).cuda(self.device_id[0])), dim=1)
    return out
def mio_module(self, each_mmbox, len_conf):
    chunk = torch.chunk(each_mmbox, each_mmbox.shape[1], 1)
    bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
    cls = (torch.cat([bmax, chunk[3]], dim=1) if len_conf == 0
           else torch.cat([chunk[3], bmax], dim=1))
    if len(chunk) == 6:
        cls = torch.cat([cls, chunk[4], chunk[5]], dim=1)
    elif len(chunk) == 8:
        cls = torch.cat([cls, chunk[4], chunk[5], chunk[6], chunk[7]], dim=1)
    return cls
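# Illustrative check of the max-out background trick in mio_module (shapes are
# hypothetical): the first three channels compete via element-wise max and
# collapse into a single background score before concatenation.
import torch

each_mmbox = torch.randn(2, 4, 8, 8)  # 4 confidence channels
chunk = torch.chunk(each_mmbox, each_mmbox.shape[1], 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
cls = torch.cat([bmax, chunk[3]], dim=1)
assert cls.shape == (2, 2, 8, 8)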
def vgg_preprocess(batch):
    tensortype = type(batch.data)
    (r, g, b) = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5      # [-1, 1] -> [0, 255]
    mean = tensortype(batch.data.size())
    mean[:, 0, :, :] = 103.939
    mean[:, 1, :, :] = 116.779
    mean[:, 2, :, :] = 123.680
    batch = batch.sub(Variable(mean))    # subtract the BGR mean
    return batch
def forward(self, inputs):
    '''inputs: batch_size, num_tokens, 50

    Returns:
        activations: batch_size, num_tokens + 2, output_dim
        mask: batch_size, num_tokens + 2
    '''
    token_embedding = self._token_embedder(inputs)
    mask = token_embedding["mask"]
    type_representation = token_embedding["token_embedding"]
    lstm_outputs = self._elmo_lstm(type_representation, mask)
    output_tensors = [
        torch.cat([type_representation, type_representation], dim=-1)
        * mask.float().unsqueeze(-1)
    ]
    for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
        output_tensors.append(layer_activations.squeeze(0))
    return {"activations": output_tensors, "mask": mask}
def forward(self, og_x, a=None, cond_blocks=None):
    x = self.conv_input(self.nonlinearity(og_x))
    if a is not None:
        x += self.nin_skip(self.nonlinearity(a))
    x = self.nonlinearity(x)
    x = self.dropout(x)
    x = self.conv_out(x)
    if cond_blocks is not None:
        conditioning_block = cond_blocks[(x.size(2), x.size(3))]
        x += conditioning_block
    a, b = torch.chunk(x, 2, dim=1)
    c3 = a * torch.sigmoid(b)  # gated activation; F.sigmoid is deprecated
    return og_x + c3
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Parameters
    ----------
    inputs: ``torch.autograd.Variable``
        Shape ``(batch_size, timesteps, 50)`` of character ids representing the
        current batch.

    Returns
    -------
    Dict with keys:

    ``'activations'``: ``List[torch.autograd.Variable]``
        A list of activations at each layer of the network, each of shape
        ``(batch_size, timesteps + 2, embedding_dim)``
    ``'mask'``: ``torch.autograd.Variable``
        Shape ``(batch_size, timesteps + 2)`` long tensor with sequence mask.

    Note that the output tensors all include additional special begin and end
    of sequence markers.
    """
    token_embedding = self._token_embedder(inputs)
    type_representation = token_embedding['token_embedding']
    mask = token_embedding['mask']
    lstm_outputs = self._elmo_lstm(type_representation, mask)

    # Prepare the output. The first layer is duplicated.
    output_tensors = [
        torch.cat([type_representation, type_representation], dim=-1)
    ]
    for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
        output_tensors.append(layer_activations.squeeze(0))

    return {
        'activations': output_tensors,
        'mask': mask,
    }
def forward(self, input, label):
    # lamb = max(lambda_min, base * (1 + gamma * iteration) ** (-power))
    self.iter += 1
    self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))

    # --------------------------- cos(theta) & phi(theta) ---------------------------
    if self.device_id is None:
        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
    else:
        x = input
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            cos_theta = torch.cat((cos_theta, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)
    cos_theta = cos_theta.clamp(-1, 1)
    cos_m_theta = self.mlambda[self.m](cos_theta)
    theta = cos_theta.data.acos()
    k = (self.m * theta / 3.14159265).floor()
    phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
    NormOfFeature = torch.norm(input, 2, 1)

    # --------------------------- convert label to one-hot ---------------------------
    one_hot = torch.zeros(cos_theta.size())
    if self.device_id is not None:
        one_hot = one_hot.cuda(self.device_id[0])
    one_hot.scatter_(1, label.view(-1, 1), 1)

    # --------------------------- calculate output ---------------------------
    output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
    output *= NormOfFeature.view(-1, 1)
    return output
def chunk(self, n_chunks, dim=0):
    """Splits this tensor into a tuple of tensors.

    See :func:`torch.chunk`.
    """
    return torch.chunk(self, n_chunks, dim)
def chunk(self, n_chunks, dim=0):
    r"""Splits this tensor into a certain number of tensor chunks.

    See :func:`torch.chunk`.
    """
    return torch.chunk(self, n_chunks, dim)
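# For reference, torch.chunk may return fewer chunks than requested, and the
# last chunk is smaller when the dimension is not evenly divisible:
import torch

pieces = torch.chunk(torch.arange(10), 3)
assert [p.numel() for p in pieces] == [4, 4, 2]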
def forward(self, x, actions, geco, global_step):
    """
    x: [batch_size, T, C, H, W]
    """
    x_means, masks, nll_t, kl_t, prior_zs, posterior_zs, lambdas, inference_steps = [], [], [], [], [], [], [], []
    rollout_nll = []
    undiscounted_nll = []
    dynamics_dist = []
    total_losses = 0.
    current_recurrent_states = {
        'n_step_dynamics': {
            'h': self.init_recurrent_states['n_step_dynamics']['h'].to(x.device).repeat(1, self.batch_size, 1)
        }
    }
    for step_index, inference_step in enumerate(self.time_inference_schedule):
        current_recurrent_states['inference_lambda'] = {
            'h': torch.zeros(1, self.batch_size * self.K, self.lstm_dim).to(x.device),
            'c': torch.zeros(1, self.batch_size * self.K, self.lstm_dim).to(x.device)}
        if inference_step == "I":
            x_loc, mask, nll, kl, posterior_zs, lambdas, total_loss, current_recurrent_states, i, dyn_dist = self.inference_step(
                x[:, step_index], geco, step_index, global_step, posterior_zs, lambdas,
                current_recurrent_states, actions)
            x_means += [x_loc]
            masks += [mask]
            nll_t += [nll]
            kl_t += [kl]
            inference_steps += [i]
            dynamics_dist += dyn_dist
            total_losses += total_loss
        # Dynamics
        elif inference_step == "D":
            x_t = x[:, step_index]
            x_loc, mask, nll, posterior_zs, current_recurrent_states, loss, dyn_dist = \
                self.dynamics_step(x_t, geco, step_index, global_step, posterior_zs,
                                   current_recurrent_states, actions)
            x_means += [x_loc]
            masks += [mask]
            nll_t += [nll]
            total_losses += loss
            dynamics_dist += dyn_dist
        # Dynamics BMS
        elif inference_step == "BMS":
            x_t = x[:, step_index]
            x_loc, mask, nll, posterior_zs, current_recurrent_states, loss, dyn_dist = \
                self.dynamics_bms_step(x_t, geco, step_index, global_step, posterior_zs,
                                       current_recurrent_states, actions)
            x_means += [x_loc]
            masks += [mask]
            nll_t += [nll]
            total_losses += loss
            dynamics_dist += dyn_dist
        # Random rollout
        elif inference_step == "R":
            with torch.no_grad():
                if step_index == self.context_len:
                    z_prev = torch.stack(posterior_zs)
                    # copy z_prev to [context_len, n_samples*batch_size*K, z_size]
                    z_prev = z_prev.repeat(1, self.stochastic_samples, 1)
                    z_prev = list(torch.chunk(z_prev, self.context_len, dim=0))  # list of [n_samples*N*K, z_size]
                    posterior_zs = [_.squeeze(0) for _ in z_prev]
                    current_recurrent_states['n_step_dynamics']['h'] = \
                        current_recurrent_states['n_step_dynamics']['h'].to(x.device).repeat(1, self.stochastic_samples, 1)
                x_loc, mask, posterior_zs, current_recurrent_states, dyn_dist = self.rollout(
                    posterior_zs, step_index, current_recurrent_states, actions)
                if x_means[-1].shape[1] != self.stochastic_samples:
                    x_means = [_.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1, 1) for _ in x_means]
                    masks = [m.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1, 1) for m in masks]
                x_means = x_means + x_loc
                masks = masks + mask
                dynamics_dist += dyn_dist
        elif inference_step == "U":
            # O(TSNK) complexity
            if step_index == self.context_len:
                z_prev = torch.stack(posterior_zs)
                # copy z_prev to [context_len, n_samples*batch_size*K, z_size]
                z_prev = z_prev.repeat(1, self.stochastic_samples, 1)
                z_prev = list(torch.chunk(z_prev, self.context_len, dim=0))  # list of [n_samples*N*K, z_size]
                posterior_zs = [_.squeeze(0) for _ in z_prev]
                current_recurrent_states['n_step_dynamics']['h'] = \
                    current_recurrent_states['n_step_dynamics']['h'].to(x.device).repeat(1, self.stochastic_samples, 1)
                x_means = [_.unsqueeze(0).repeat(self.stochastic_samples, 1, 1, 1, 1, 1).permute(1, 0, 2, 3, 4, 5).contiguous() for _ in x_means]
                masks = [_.unsqueeze(0).repeat(self.stochastic_samples, 1, 1, 1, 1, 1).permute(1, 0, 2, 3, 4, 5).contiguous() for _ in masks]
            x_t = x[:, step_index].repeat(self.stochastic_samples, 1, 1, 1)
            x_loc, mask, posterior_zs, current_recurrent_states, nll, nll_disc, _ = self.rollout(
                posterior_zs, step_index, current_recurrent_states, actions, x_t, True)
            x_means = x_means + x_loc
            masks = masks + mask
            rollout_nll += [nll]
            # undiscounted_nll += [nll]

    if self.dynamics_uncertainty:
        _, _, C, H, W = x.shape
        loss, best_indices, best_nll = self.uncertainty_loss(rollout_nll, geco, global_step)
        total_losses += torch.mean(loss)  # average over batch size (summed over rollout steps)
        best_x_means, best_masks = [], []
        batch_ids = torch.arange(self.batch_size).to(x.device)
        best_batch_indices = (self.stochastic_samples * batch_ids) + best_indices
        for i in range(len(x_means)):
            x_m = x_means[i].view(self.stochastic_samples, self.batch_size, self.K, C, H, W)
            x_m = x_m.view(-1, self.K, C, H, W)
            masks_m = masks[i].view(self.stochastic_samples, self.batch_size, self.K, 1, H, W)
            masks_m = masks_m.view(-1, self.K, 1, H, W)
            best_x_means += [x_m[best_batch_indices]]
            best_masks += [masks_m[best_batch_indices]]
        x_means = best_x_means
        masks = best_masks
        nll_to_return = torch.sum(torch.mean(torch.stack(nll_t), dim=1)) + torch.mean(best_nll)
    else:
        nll_to_return = torch.sum(torch.mean(torch.stack(nll_t), dim=1))

    outs = {
        'total_loss': total_losses,
        'nll': nll_to_return,
        'kl': torch.sum(torch.stack(kl_t)),
        'x_means': x_means,
        'masks': masks,
        'posterior_zs': posterior_zs,
        'inference_steps': torch.mean(torch.Tensor(inference_steps).to(x.device)),
        'dynamics': dynamics_dist,
        'lambdas': lambdas
    }
    return outs
def forward(self, input, indices):
    asize = list(indices.size())
    isize = list(input.size())
    ksize = self.ksize
    kstride = self.kstride
    # np.array(..., dtype=torch.Tensor) in the original is invalid; keep the
    # per-sample chunks as a plain list of tensors instead.
    x = list(torch.chunk(input, chunks=isize[0], dim=0))
def vgg_deprocess(tensor):
    bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0
    (b, g, r) = torch.chunk(bgr, 3, dim=0)
    rgb = torch.cat((r, g, b), 0)
    return rgb
def forward(self, lr_feature, hr_feature):
    # self.fuse.weight.data = torch.abs(self.fuse.weight.data)
    with torch.no_grad():
        scale = hr_feature.shape[-1] // lr_feature.shape[-1]
        if scale % 2 == 0:
            x = torch.arange(-scale // 2, scale // 2 + 1).float()
            x = torch.cat([x[:x.shape[0] // 2], x[x.shape[0] // 2 + 1:]]).unsqueeze(0)
            distance_matrix = x.expand(scale, scale).unsqueeze(0)
            distance_matrix = torch.cat([distance_matrix, distance_matrix.transpose(2, 1)], 0)
            distance_matrix = torch.cat(
                [distance_matrix,
                 torch.sqrt(torch.pow(distance_matrix[0], 2) + torch.pow(distance_matrix[1], 2)).unsqueeze(0)],
                0).unsqueeze(0)
            padding1 = torch.zeros(hr_feature.shape[0], 1, hr_feature.shape[2], scale).float().cuda()
            padding2 = torch.zeros(hr_feature.shape[0], 1, scale, hr_feature.shape[3]).float().cuda()
        else:
            exit()

    # center
    distance_matrix = distance_matrix.repeat(hr_feature.shape[0], 1,
                                             hr_feature.shape[-2] // scale,
                                             hr_feature.shape[-1] // scale).float().cuda()
    distance_matrix2 = distance_matrix + 0
    distance_matrix1 = distance_matrix + 0
    distance_matrix3 = distance_matrix + 0
    distance_matrix4 = distance_matrix + 0
    lr_feature = lr_feature.unsqueeze(-1).expand(lr_feature.shape[0], lr_feature.shape[1], lr_feature.shape[2], lr_feature.shape[3], scale) \
        .contiguous().view(lr_feature.shape[0], lr_feature.shape[1], lr_feature.shape[2], lr_feature.shape[3] * scale) \
        .unsqueeze(-2).expand(lr_feature.shape[0], lr_feature.shape[1], lr_feature.shape[2], scale, lr_feature.shape[3] * scale) \
        .contiguous().view(lr_feature.shape[0], lr_feature.shape[1], lr_feature.shape[2] * scale, lr_feature.shape[3] * scale)
    # 128
    representation = torch.cat([lr_feature, hr_feature, lr_feature * hr_feature,
                                torch.pow(lr_feature - hr_feature, 2)], 1)
    weights1 = self.similarity1(representation)
    weights2 = self.similarity2(distance_matrix)
    mapping = self.fuse(torch.cat([weights1, weights2], 1))

    # right
    x = torch.arange(1, scale + 1).float()
    x = x.expand(scale, scale).unsqueeze(0)
    x = x.repeat(hr_feature.shape[0], hr_feature.shape[-2] // scale, hr_feature.shape[-1] // scale).float().cuda()
    distance_matrix1[:, 0, :, :] = scale - x + 1
    distance_matrix1[:, 2, :, :] = torch.sqrt(torch.pow(distance_matrix1[:, 0, :, :], 2) +
                                              torch.pow(distance_matrix1[:, 1, :, :], 2)).unsqueeze(0)
    representation_r = torch.cat([lr_feature[:, :, :, scale:], hr_feature[:, :, :, :-scale],
                                  lr_feature[:, :, :, scale:] * hr_feature[:, :, :, :-scale],
                                  torch.pow(lr_feature[:, :, :, scale:] - hr_feature[:, :, :, :-scale], 2)], 1)
    weights1_r = self.similarity1(representation_r)
    weights2_r = self.similarity2(distance_matrix1[:, :, :, scale:])
    # print(padding.shape)
    mapping_r = torch.cat([self.fuse(torch.cat([weights1_r, weights2_r], 1)), padding1], -1)

    # left
    distance_matrix2[:, 0, :, :] = x
    distance_matrix2[:, 2, :, :] = torch.sqrt(torch.pow(distance_matrix2[:, 0, :, :], 2) +
                                              torch.pow(distance_matrix2[:, 1, :, :], 2)).unsqueeze(0)
    representation_l = torch.cat([lr_feature[:, :, :, :-scale], hr_feature[:, :, :, scale:],
                                  lr_feature[:, :, :, :-scale] * hr_feature[:, :, :, scale:],
                                  torch.pow(lr_feature[:, :, :, :-scale] - hr_feature[:, :, :, scale:], 2)], 1)
    weights1_l = self.similarity1(representation_l)
    weights2_l = self.similarity2(distance_matrix2[:, :, :, :-scale])
    mapping_l = torch.cat([padding1, self.fuse(torch.cat([weights1_l, weights2_l], 1))], -1)

    # top
    x = torch.arange(1, scale + 1).float()
    x = x.expand(scale, scale).unsqueeze(0).transpose(2, 1)
    x = x.repeat(hr_feature.shape[0], hr_feature.shape[-2] // scale, hr_feature.shape[-1] // scale).float().cuda()
    distance_matrix3[:, 1, :, :] = (scale - x + 1)
    distance_matrix3[:, 2, :, :] = torch.sqrt(torch.pow(distance_matrix3[:, 0, :, :], 2) +
                                              torch.pow(distance_matrix3[:, 1, :, :], 2)).unsqueeze(0)
    representation_t = torch.cat([lr_feature[:, :, :-scale, :], hr_feature[:, :, scale:, :],
                                  lr_feature[:, :, :-scale, :] * hr_feature[:, :, scale:, :],
                                  torch.pow(lr_feature[:, :, :-scale, :] - hr_feature[:, :, scale:, :], 2)], 1)
    weights1_t = self.similarity1(representation_t)
    weights2_t = self.similarity2(distance_matrix3[:, :, :-scale, :])
    mapping_t = torch.cat([padding2, self.fuse(torch.cat([weights1_t, weights2_t], 1))], -2)

    # bottom
    distance_matrix4[:, 1, :, :] = x
    distance_matrix4[:, 2, :, :] = torch.sqrt(torch.pow(distance_matrix4[:, 0, :, :], 2) +
                                              torch.pow(distance_matrix4[:, 1, :, :], 2)).unsqueeze(0)
    representation_b = torch.cat([lr_feature[:, :, scale:, :], hr_feature[:, :, :-scale, :],
                                  lr_feature[:, :, scale:, :] * hr_feature[:, :, :-scale, :],
                                  torch.pow(lr_feature[:, :, scale:, :] - hr_feature[:, :, :-scale, :], 2)], 1)
    weights1_b = self.similarity1(representation_b)
    weights2_b = self.similarity2(distance_matrix4[:, :, scale:, :])
    mapping_b = torch.cat([self.fuse(torch.cat([weights1_b, weights2_b], 1)), padding2], -2)

    mapping_all = torch.cat([mapping, mapping_r, mapping_l, mapping_t, mapping_b], dim=1)
    mapping_norm = F.softmax(mapping_all, dim=1)
    # return mapping, mapping_r, mapping_l, mapping_t, mapping_b
    return torch.chunk(mapping_norm * mapping_all, 5, dim=1)
def train():
    # Turn on training mode which enables dropout.
    if args.restart:
        global train_loss, best_val_loss, eval_start_time, log_start_time
    else:
        global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
    model.train()
    if args.batch_chunk > 1:
        mems = [tuple() for _ in range(args.batch_chunk)]
    else:
        mems = tuple()
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
    for batch, (data, target, seq_len) in enumerate(train_iter):
        model.zero_grad()
        if args.batch_chunk > 1:
            data_chunks = torch.chunk(data, args.batch_chunk, 1)
            target_chunks = torch.chunk(target, args.batch_chunk, 1)
            for i in range(args.batch_chunk):
                data_i = data_chunks[i].contiguous()
                target_i = target_chunks[i].contiguous()
                ret = para_model(data_i, target_i, *mems[i])
                loss, mems[i] = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss) / args.batch_chunk
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                train_loss += loss.float().item()
        else:
            ret = para_model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.float().mean().type_as(loss)
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if args.sample_softmax > 0:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if args.sample_softmax > 0:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step)
                    if args.sample_softmax > 0:
                        scheduler_sparse.step(train_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / args.log_interval
            elapsed = time.time() - log_start_time
            log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
                      '| ms/batch {:5.2f} | loss {:5.2f}'.format(
                epoch, train_step, batch + 1, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
            else:
                log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
            logging(log_str)
            train_loss = 0
            log_start_time = time.time()

        if train_step % args.eval_interval == 0:
            val_loss = evaluate(va_iter)
            logging('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                train_step // args.eval_interval, train_step,
                (time.time() - eval_start_time), val_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
            logging(log_str)
            logging('-' * 100)

            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                if not args.debug:
                    with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
                        torch.save(model, f)
                    with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
                        torch.save(optimizer.state_dict(), f)
                    with open(os.path.join(args.work_dir, 'scheduler.pt'), 'wb') as f:
                        torch.save(scheduler.state_dict(), f)
                    with open(os.path.join(args.work_dir, 'trainstep.pt'), 'wb') as f:
                        torch.save(train_step, f)
                best_val_loss = val_loss

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if args.sample_softmax > 0:
                    scheduler_sparse.step(val_loss)

            eval_start_time = time.time()

        if train_step == args.max_step:
            break
def forward(self, x):
    h = self.encoder(x)
    mu, log_var = torch.chunk(h, 2, dim=1)  # mean and log variance
    z = self.reparametrize(mu, log_var)
    out = self.decoder(z)
    return out, mu, log_var
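# self.reparametrize is not shown; a minimal sketch assuming the standard VAE
# reparameterization trick (z = mu + sigma * eps with eps ~ N(0, I)). The name
# mirrors the call above but this is an assumption, not the author's code.
import torch

def reparametrize(mu, log_var):
    std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
    eps = torch.randn_like(std)     # fresh noise each forward pass
    return mu + eps * std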
def vgg_preprocess(tensor):
    (r, g, b) = torch.chunk(tensor, 3, dim=0)
    bgr = torch.cat((b, g, r), 0)
    out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr)
    return out
def vgg_preprocess_var(var):
    (r, g, b) = torch.chunk(var, 3, dim=1)
    bgr = torch.cat((b, g, r), 1)
    out = bgr * 255 - torch.autograd.Variable(vgg_mean[None, ...]).type(var.type()).expand_as(bgr)
    return out
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
    # r_emb: [klen, n_head, d_head], used for term B
    # r_w_bias: [n_head, d_head], used for term C
    # r_bias: [klen, n_head], used for term D
    qlen, bsz = w.size(0), w.size(1)

    if mems is not None:
        cat = torch.cat([mems, w], 0)
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(cat))
        else:
            w_heads = self.qkv_net(cat)
        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        w_head_q = w_head_q[-qlen:]
    else:
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(w))
        else:
            w_heads = self.qkv_net(w)
        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

    klen = w_head_k.size(0)

    w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
    w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
    w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

    if klen > r_emb.size(0):
        r_emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
        r_emb = torch.cat([r_emb_pad, r_emb], 0)
        r_bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1)
        r_bias = torch.cat([r_bias_pad, r_bias], 0)
    else:
        r_emb = r_emb[-klen:]
        r_bias = r_bias[-klen:]

    #### compute attention score
    rw_head_q = w_head_q + r_w_bias[None]                         # qlen x bsz x n_head x d_head

    AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))   # qlen x klen x bsz x n_head
    B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))        # qlen x klen x bsz x n_head
    D_ = r_bias[None, :, None]                                    # 1 x klen x 1 x n_head
    BD = self._rel_shift(B_ + D_)                                 # qlen x klen x bsz x n_head

    attn_score = AC + BD
    attn_score.mul_(self.scale)

    #### compute attention probability
    if attn_mask is not None and attn_mask.any().item():
        if attn_mask.dim() == 2:
            attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
        elif attn_mask.dim() == 3:
            attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

    # [qlen x klen x bsz x n_head]
    attn_prob = F.softmax(attn_score, dim=1)
    attn_prob = self.dropatt(attn_prob)

    #### compute attention vector
    attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

    # [qlen x bsz x n_head x d_head]
    attn_vec = attn_vec.contiguous().view(
        attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

    ##### linear projection
    attn_out = self.o_net(attn_vec)
    attn_out = self.drop(attn_out)

    if self.pre_lnorm:
        ##### residual connection
        output = w + attn_out
    else:
        ##### residual connection + layer normalization
        output = self.layer_norm(w + attn_out)

    return output
def forward(self, input):
    # split left image and right image
    imgs = torch.chunk(input, 2, dim=1)
    img_left = imgs[0]
    img_right = imgs[1]

    conv1_l = self.conv1(img_left)
    conv2_l = self.conv2(conv1_l)
    conv3a_l = self.conv3(conv2_l)

    conv1_r = self.conv1(img_right)
    conv2_r = self.conv2(conv1_r)
    conv3a_r = self.conv3(conv2_r)

    # Correlate conv3a_l and conv3a_r
    out_corr = self.corr(conv3a_l, conv3a_r)
    out_corr = self.corr_activation(out_corr)
    out_conv3a_redir = self.conv_redir(conv3a_l)
    in_conv3b = torch.cat((out_conv3a_redir, out_corr), 1)

    conv3b = self.conv3_1(in_conv3b)
    conv4a = self.conv4(conv3b)
    conv4b = self.conv4_1(conv4a)
    conv5a = self.conv5(conv4b)
    conv5b = self.conv5_1(conv5a)
    conv6a = self.conv6(conv5b)
    conv6b = self.conv6_1(conv6a)

    pr6 = self.pred_flow6(conv6b)
    upconv5 = self.upconv5(conv6b)
    upflow6 = self.upflow6to5(pr6)
    concat5 = torch.cat((upconv5, upflow6, conv5b), 1)
    iconv5 = self.iconv5(concat5)

    pr5 = self.pred_flow5(iconv5)
    upconv4 = self.upconv4(iconv5)
    upflow5 = self.upflow5to4(pr5)
    concat4 = torch.cat((upconv4, upflow5, conv4b), 1)
    iconv4 = self.iconv4(concat4)

    pr4 = self.pred_flow4(iconv4)
    upconv3 = self.upconv3(iconv4)
    upflow4 = self.upflow4to3(pr4)
    concat3 = torch.cat((upconv3, upflow4, conv3b), 1)
    iconv3 = self.iconv3(concat3)

    pr3 = self.pred_flow3(iconv3)
    upconv2 = self.upconv2(iconv3)
    upflow3 = self.upflow3to2(pr3)
    concat2 = torch.cat((upconv2, upflow3, conv2_l), 1)
    iconv2 = self.iconv2(concat2)

    pr2 = self.pred_flow2(iconv2)
    upconv1 = self.upconv1(iconv2)
    upflow2 = self.upflow2to1(pr2)
    concat1 = torch.cat((upconv1, upflow2, conv1_l), 1)
    iconv1 = self.iconv1(concat1)

    pr1 = self.pred_flow1(iconv1)
    upconv0 = self.upconv0(iconv1)
    upflow1 = self.upflow1to0(pr1)
    concat0 = torch.cat((upconv0, upflow1, img_left), 1)
    iconv0 = self.iconv0(concat0)

    # predict flow
    pr0 = self.pred_flow0(iconv0)

    # Alternatively, predict flow from dropout output, e.g.:
    # pr0 = self.pred_flow0(F.dropout2d(iconv0))
    # Returning all scales; the caller can select pr0 only at inference.
    return pr0, pr1, pr2, pr3, pr4, pr5, pr6
def forward(self, x):
    N, C, T, V = x.size()
    if self.use_spatial_att:
        attention = self.atts
        if self.use_pes:
            y = self.pes(x)
        else:
            y = x
        if self.att_s:
            q, k = torch.chunk(self.in_nets(y).view(N, 2 * self.num_subset, self.inter_channels, T, V),
                               2, dim=1)  # nctv -> n num_subset c'tv
            attention = attention + self.tan(
                torch.einsum('nsctu,nsctv->nsuv', [q, k]) / (self.inter_channels * T)) * self.alphas
        if self.glo_reg_s:
            attention = attention + self.attention0s.repeat(N, 1, 1, 1)
        attention = self.drop(attention)
        y = torch.einsum('nctu,nsuv->nsctv', [x, attention]).contiguous() \
            .view(N, self.num_subset * self.in_channels, T, V)
        y = self.out_nets(y)  # nctv
        y = self.relu(self.downs1(x) + y)
        y = self.ff_nets(y)
        y = self.relu(self.downs2(x) + y)
    else:
        y = self.out_nets(x)
        y = self.relu(self.downs2(x) + y)

    forward_mask = self.backward_mask.transpose(-1, -2)
    backward_mask = self.backward_mask
    if self.use_temporal_att:
        attention = self.attt
        if self.use_pet:
            z = self.pet(y)
        else:
            z = y
        q_k_in = self.in_nett(z).view(N, 6 * self.num_subset, self.inter_channels, T, V)
        q_f, q_b, q_c, k_f, k_b, k_c = torch.chunk(q_k_in, 6, dim=1)
        attention_b = torch.einsum('nsctv,nscqv->nstq', [q_b, k_b]) / (self.inter_channels * V) * self.alphat_0
        attention_f = torch.einsum('nsctv,nscqv->nstq', [q_f, k_f]) / (self.inter_channels * V) * self.alphat_1
        attention_c = torch.einsum('nsctv,nscqv->nstq', [q_c, k_c]) / (self.inter_channels * V) * self.alphat_2
        attention_b = torch.einsum('nstq,tq->nstq', [attention_b, backward_mask])
        attention_f = torch.einsum('nstq,tq->nstq', [attention_f, forward_mask])
        attention_b = self.drop(attention_b)
        attention_f = self.drop(attention_f)
        attention_c = self.drop(attention_c)
        z_f = torch.einsum('nctv,nstq->nscqv', [y, attention_f]).contiguous() \
            .view(N, self.num_subset * self.out_channels, T, V)
        z_b = torch.einsum('nctv,nstq->nscqv', [y, attention_b]).contiguous() \
            .view(N, self.num_subset * self.out_channels, T, V)
        z_c = torch.einsum('nctv,nstq->nscqv', [y, attention_c]).contiguous() \
            .view(N, self.num_subset * self.out_channels, T, V)
        z = torch.cat([z_f, z_b, z_c], dim=-3)
        z = self.out_nett(z)  # nctv
        z = self.relu(self.downt1(y) + z)
        z = self.ff_nett(z)
        z = self.relu(self.downt2(y) + z)
    else:
        z = self.out_nett(y)
        z = self.relu(self.downt2(y) + z)

    z_1 = self.out_nett_extend(z)
    z_1 = self.relu(self.downt3(z) + z_1)
    z = z_1
    return z
def angle_axis_to_rotation_matrix_taylor(angle_axis):
    rx, ry, rz = torch.chunk(angle_axis, 3, dim=-1)
    ones = torch.ones_like(rx)
    # First-order (Taylor) approximation for small angles: R = I + [w]_x,
    # where [w]_x is the skew-symmetric matrix of the angle-axis vector.
    R = torch.cat([ones, -rz, ry,
                   rz, ones, -rx,
                   -ry, rx, ones], dim=1).view(-1, 3, 3)
    return R
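# Sanity sketch for the first-order approximation above: for small angles,
# R = I + [w]_x should be nearly orthogonal (R @ R^T ~ I up to O(|w|^2)).
# Tolerances below are illustrative.
import torch

aa = 1e-3 * torch.randn(5, 3)  # small angle-axis vectors
R = angle_axis_to_rotation_matrix_taylor(aa)
eye = torch.eye(3).expand(5, 3, 3)
assert torch.allclose(torch.bmm(R, R.transpose(1, 2)), eye, atol=1e-4)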
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            word_inputs: torch.Tensor = None) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Parameters
    ----------
    inputs: ``torch.Tensor``, required.
        Shape ``(batch_size, timesteps, 50)`` of character ids representing the
        current batch.
    word_inputs : ``torch.Tensor``, required.
        If you passed a cached vocab, you can in addition pass a tensor of shape
        ``(batch_size, timesteps)``, which represent word ids which have been
        pre-cached.

    Returns
    -------
    Dict with keys:

    ``'activations'``: ``List[torch.Tensor]``
        A list of activations at each layer of the network, each of shape
        ``(batch_size, timesteps + 2, embedding_dim)``
    ``'mask'``: ``torch.Tensor``
        Shape ``(batch_size, timesteps + 2)`` long tensor with sequence mask.

    Note that the output tensors all include additional special begin and end
    of sequence markers.
    """
    if self._word_embedding is not None and word_inputs is not None:
        try:
            mask_without_bos_eos = (word_inputs > 0).long()
            # The character cnn part is cached - just look it up.
            embedded_inputs = self._word_embedding(word_inputs)  # type: ignore
            # shape (batch_size, timesteps + 2, embedding_dim)
            type_representation, mask = add_sentence_boundary_token_ids(
                embedded_inputs,
                mask_without_bos_eos,
                self._bos_embedding,
                self._eos_embedding
            )
        except RuntimeError:
            # Back off to running the character convolutions,
            # as we might not have the words in the cache.
            token_embedding = self._token_embedder(inputs)
            mask = token_embedding['mask']
            type_representation = token_embedding['token_embedding']
    else:
        token_embedding = self._token_embedder(inputs)
        mask = token_embedding['mask']
        type_representation = token_embedding['token_embedding']
    lstm_outputs = self._elmo_lstm(type_representation, mask)

    # Prepare the output. The first layer is duplicated.
    # Because of minor differences in how masking is applied depending
    # on whether the char cnn layers are cached, we'll be defensive and
    # multiply by the mask here. It's not strictly necessary, as the
    # mask passed on is correct, but the values in the padded areas
    # of the char cnn representations can change.
    output_tensors = [
        torch.cat([type_representation, type_representation], dim=-1)
        * mask.float().unsqueeze(-1)
    ]
    for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
        output_tensors.append(layer_activations.squeeze(0))

    return {
        'activations': output_tensors,
        'mask': mask,
    }
def forward(self, x):
    h = F.relu(self.conv1_1(x))
    h = F.relu(self.conv1_2(h))
    h = F.max_pool2d(h, 2, 2)

    h = F.relu(self.conv2_1(h))
    h = F.relu(self.conv2_2(h))
    h = F.max_pool2d(h, 2, 2)

    h = F.relu(self.conv3_1(h))
    h = F.relu(self.conv3_2(h))
    h = F.relu(self.conv3_3(h))
    f3_3 = h
    h = F.max_pool2d(h, 2, 2)

    h = F.relu(self.conv4_1(h))
    h = F.relu(self.conv4_2(h))
    h = F.relu(self.conv4_3(h))
    f4_3 = h
    h = F.max_pool2d(h, 2, 2)

    h = F.relu(self.conv5_1(h))
    h = F.relu(self.conv5_2(h))
    h = F.relu(self.conv5_3(h))
    f5_3 = h
    h = F.max_pool2d(h, 2, 2)

    h = F.relu(self.fc6(h))
    h = F.relu(self.fc7(h))
    ffc7 = h

    h = F.relu(self.conv6_1(h))
    h = F.relu(self.conv6_2(h))
    f6_2 = h

    h = F.relu(self.conv7_1(h))
    h = F.relu(self.conv7_2(h))
    f7_2 = h

    f3_3 = self.conv3_3_norm(f3_3)
    f4_3 = self.conv4_3_norm(f4_3)
    f5_3 = self.conv5_3_norm(f5_3)

    cls1 = self.conv3_3_norm_mbox_conf(f3_3)
    reg1 = self.conv3_3_norm_mbox_loc(f3_3)
    cls2 = self.conv4_3_norm_mbox_conf(f4_3)
    reg2 = self.conv4_3_norm_mbox_loc(f4_3)
    cls3 = self.conv5_3_norm_mbox_conf(f5_3)
    reg3 = self.conv5_3_norm_mbox_loc(f5_3)
    cls4 = self.fc7_mbox_conf(ffc7)
    reg4 = self.fc7_mbox_loc(ffc7)
    cls5 = self.conv6_2_mbox_conf(f6_2)
    reg5 = self.conv6_2_mbox_loc(f6_2)
    cls6 = self.conv7_2_mbox_conf(f7_2)
    reg6 = self.conv7_2_mbox_loc(f7_2)

    # max-out background label
    chunk = torch.chunk(cls1, 4, 1)
    bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
    cls1 = torch.cat([bmax, chunk[3]], dim=1)

    return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
def forward(self, input):
    splits = torch.chunk(input, 2, dim=self.dim)
    return splits[0] if self.first_half else splits[1]
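# Hedged usage note for the half-split module above: with dim=1, the two
# chunks are the first and second halves of the channel dimension.
import torch

x = torch.randn(2, 6, 4)
first, second = torch.chunk(x, 2, dim=1)
assert first.shape == second.shape == (2, 3, 4)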
def validate(nets, loss_terms, opts, dataloader, epoch, network_type,
             devices=(cuda0, cuda1), batch_n="whole_test_show"):
    """
    validate phase
    """
    netD, netG = nets["netD"], nets["netG"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = (
        loss_terms['ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"],
        loss_terms["GANLoss"], loss_terms["StyleLoss"])
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss": AverageMeter(), "p_loss": AverageMeter(), "s_loss": AverageMeter(),
              "r_loss": AverageMeter(), "whole_loss": AverageMeter(), "d_loss": AverageMeter()}

    netG.train()
    netD.train()

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_comp_dir = os.path.join(val_save_dir, "comp")
    for size in SIZES_TAGS:
        if not os.path.exists(os.path.join(val_save_real_dir, size)):
            os.makedirs(os.path.join(val_save_real_dir, size))
        if not os.path.exists(os.path.join(val_save_gen_dir, size)):
            os.makedirs(os.path.join(val_save_gen_dir, size))
        if not os.path.exists(os.path.join(val_save_comp_dir, size)):
            os.makedirs(os.path.join(val_save_comp_dir, size))
    info = {}
    t = 0

    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        pre_complete_imgs = (pre_imgs / 127.5 - 1)

        for s_i, size in enumerate(TRAIN_SIZES):
            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if imgs.size(1) != 3:
                print(t, imgs.size())
            pre_inter_imgs = F.interpolate(pre_complete_imgs, size)

            imgs, masks, pre_complete_imgs, pre_inter_imgs = (
                imgs.to(device0), masks.to(device0),
                pre_complete_imgs.to(device0), pre_inter_imgs.to(device0))
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region

            # forward
            if network_type == 'l2h_unet':
                recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs, size)
            elif network_type == 'l2h_gated':
                recon_imgs = netG(imgs, masks, pre_inter_imgs)
            elif network_type == 'sa_gated':
                recon_imgs, _ = netG(imgs, masks)
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat([recon_imgs, masks, torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

            g_loss = GANLoss(pred_neg)
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss  # g_loss + r_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(s_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            pre_complete_imgs = complete_imgs

            # Update time recorder
            batch_time.update(time.time() - end)

            # Logger logging
            # if t < config.STATIC_VIEW_SIZE:
            print(i, size)
            real_img = img2photo(imgs)
            gen_img = img2photo(recon_imgs)
            comp_img = img2photo(complete_imgs)
            real_img = Image.fromarray(real_img[0].astype(np.uint8))
            gen_img = Image.fromarray(gen_img[0].astype(np.uint8))
            comp_img = Image.fromarray(comp_img[0].astype(np.uint8))
            real_img.save(os.path.join(val_save_real_dir, SIZES_TAGS[s_i], "{}.png".format(i)))
            gen_img.save(os.path.join(val_save_gen_dir, SIZES_TAGS[s_i], "{}.png".format(i)))
            comp_img.save(os.path.join(val_save_comp_dir, SIZES_TAGS[s_i], "{}.png".format(i)))

            end = time.time()
def preprocess_batch(batch):
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch
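# The transpose/chunk/cat sequence above is equivalent to reversing the
# channel order of an NCHW batch; a quick check (shapes hypothetical):
import torch

batch = torch.randn(2, 3, 4, 4)
assert torch.equal(preprocess_batch(batch), batch[:, [2, 1, 0]])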
def episodic_training(self, train_results, tail):
    episode = self.replay_buffer.get_tail(tail)
    sl = episode['s']
    sl = list(torch.chunk(sl, int((len(sl) / self.batch) + 1)))
    s, r, t, e = [episode[k] for k in ['s', 'r', 't', 'e']]

    v = []
    for s in sl:
        v.append(self.v_net(s))
    v.append(torch.zeros_like(v[0][:1]))
    v = torch.cat(v).detach()
    v1, v2 = v[:-1], v[1:]

    adv, v_target = generalized_advantage_estimation(
        r, t, e, v1, v2, self.gamma, self.lambda_gae, norm=self.norm_rewards)
    episode['adv'] = adv
    episode['v_target'] = v_target

    if self.batch_ppo:
        n = self.steps_per_episode * self.batch
        indices = torch.randperm(tail * max(1, n // tail + 1)) % tail
        indices = indices[:n].unsqueeze(1).view(self.steps_per_episode, self.batch)
        samples = {k: v[indices] for k, v in episode.items()}
        iterator_pi = iter_dict(samples)
        iterator_v = iter_dict(samples)
    else:
        iterator_pi = itertools.repeat(episode, self.steps_per_episode)
        iterator_v = itertools.repeat(episode, self.steps_per_episode)

    for i, sample in enumerate(iterator_pi):
        s, a, r, t, stag, adv, v_target, log_pi_old = [
            sample[k] for k in ['s', 'a', 'r', 't', 'stag', 'adv', 'v_target', 'logp']
        ]
        self.pi_net(s)
        log_pi = self.pi_net.log_prob(a)
        ratio = torch.exp((log_pi - log_pi_old).sum(dim=1))
        clip_adv = torch.clamp(ratio, 1 - self.eps_ppo, 1 + self.eps_ppo) * adv
        loss_p = -(torch.min(ratio * adv, clip_adv)).mean()

        approx_kl = -float((log_pi - log_pi_old).sum(dim=1).mean())
        ent = float(self.pi_net.entropy().sum(dim=1).mean())
        if approx_kl > self.target_kl:
            train_results['scalar']['pi_opt_rounds'].append(i)
            break
        clipped = ratio.gt(1 + self.eps_ppo) | ratio.lt(1 - self.eps_ppo)
        clipfrac = float(torch.as_tensor(clipped, dtype=torch.float32).mean())

        self.optimizer_p.zero_grad()
        loss_p.backward()
        if self.clip_p:
            # clip_grad_norm is deprecated; use the in-place variant
            nn.utils.clip_grad_norm_(self.pi_net.parameters(), self.clip_p)
        self.optimizer_p.step()

        train_results['scalar']['loss_p'].append(float(loss_p))
        train_results['scalar']['approx_kl'].append(approx_kl)
        train_results['scalar']['ent'].append(ent)
        train_results['scalar']['clipfrac'].append(clipfrac)

    for sample in iterator_v:
        s, a, r, t, stag, adv, v_target, log_pi_old = [
            sample[k] for k in ['s', 'a', 'r', 't', 'stag', 'adv', 'v_target', 'logp']
        ]
        v = self.v_net(s)
        loss_v = F.mse_loss(v, v_target, reduction='mean')

        self.optimizer_v.zero_grad()
        loss_v.backward()
        if self.clip_q:
            nn.utils.clip_grad_norm_(self.v_net.parameters(), self.clip_q)
        self.optimizer_v.step()

        train_results['scalar']['loss_v'].append(float(loss_v))

    return train_results
def tensor_save_bgrimage(tensor):
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))  # BGR -> RGB
    return tensor_save_rgbimage(tensor)
def __init__(self, image_size, noise_size=100, num_label=10, output_params=False):
    super(Encoder, self).__init__()

    self.noise_size = noise_size
    self.image_size = image_size
    self.num_label = num_label

    self.core_net = nn.Sequential(
        nn.Conv2d(4, 128, 5, 2, 2, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
        nn.Conv2d(128, 256, 5, 2, 2, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
        nn.Conv2d(256, 512, 5, 2, 2, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
        Expression(lambda tensor: tensor.view(tensor.size(0), 512 * 4 * 4)),
    )

    if output_params:
        self.core_net.add_module(
            str(len(self.core_net._modules)),
            WN_Linear(4 * 4 * 512, self.noise_size * 2, train_scale=True, init_stdv=0.1))
        self.core_net.add_module(
            str(len(self.core_net._modules)),
            Expression(lambda x: torch.chunk(x, 2, 1)))
    else:
        self.core_net.add_module(
            str(len(self.core_net._modules)),
            WN_Linear(4 * 4 * 512, self.noise_size, train_scale=True, init_stdv=0.1))
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
           groups=1, padding_mode="zeros"):
    """
    Overloads torch.conv2d to be able to use MPC on convolutional networks.
    The idea is to build new tensors from input and weight to compute a
    matrix multiplication equivalent to the convolution.

    Args:
        input: input image
        weight: convolution kernels
        bias: optional additive bias
        stride: stride of the convolution kernels
        padding: implicit padding on both sides of the input
        dilation: spacing between kernel elements
        groups: number of blocked connections from input to output channels
        padding_mode: type of padding; should be either 'zeros' or 'circular',
            but 'reflect' and 'replicate' are also accepted

    Returns:
        the result of the convolution as an AdditiveSharingTensor
    """
    assert len(input.shape) == 4
    assert len(weight.shape) == 4

    # Change to tuple if not one
    stride = torch.nn.modules.utils._pair(stride)
    padding = torch.nn.modules.utils._pair(padding)
    dilation = torch.nn.modules.utils._pair(dilation)

    # Extract a few useful values
    batch_size, nb_channels_in, nb_rows_in, nb_cols_in = input.shape
    nb_channels_out, nb_channels_in_, nb_rows_kernel, nb_cols_kernel = weight.shape

    if bias is not None:
        assert len(bias) == nb_channels_out

    # Check if inputs are coherent
    assert nb_channels_in == nb_channels_in_ * groups
    assert nb_channels_in % groups == 0
    assert nb_channels_out % groups == 0

    # Compute output shape
    nb_rows_out = int(((nb_rows_in + 2 * padding[0] - dilation[0] * (nb_rows_kernel - 1) - 1) / stride[0]) + 1)
    nb_cols_out = int(((nb_cols_in + 2 * padding[1] - dilation[1] * (nb_cols_kernel - 1) - 1) / stride[1]) + 1)

    # Apply padding to the input
    if padding != (0, 0):
        padding_mode = "constant" if padding_mode == "zeros" else padding_mode
        input = torch.nn.functional.pad(
            input, (padding[1], padding[1], padding[0], padding[0]), padding_mode)
        # Update shape after padding
        nb_rows_in += 2 * padding[0]
        nb_cols_in += 2 * padding[1]

    # We want to get relative positions of values in the input tensor that are
    # used by one filter convolution. It basically is the position of the
    # values used for the top-left convolution.
    pattern_ind = []
    for ch in range(nb_channels_in):
        for r in range(nb_rows_kernel):
            for c in range(nb_cols_kernel):
                pixel = r * nb_cols_in * dilation[0] + c * dilation[1]
                pattern_ind.append(pixel + ch * nb_rows_in * nb_cols_in)

    # The image tensor is reshaped for the matrix multiplication:
    # on each row of the new tensor will be the input values used for each
    # filter convolution. We will get a matrix
    # [[in values to compute out value 0],
    #  [in values to compute out value 1],
    #  ...
    #  [in values to compute out value nb_rows_out*nb_cols_out]]
    im_flat = input.view(batch_size, -1)
    im_reshaped = []
    for cur_row_out in range(nb_rows_out):
        for cur_col_out in range(nb_cols_out):
            # For each new output value, we just need to shift the receptive field
            offset = cur_row_out * stride[0] * nb_cols_in + cur_col_out * stride[1]
            tmp = [ind + offset for ind in pattern_ind]
            im_reshaped.append(im_flat[:, tmp].wrap())
    im_reshaped = torch.stack(im_reshaped).permute(1, 0, 2)

    # The convolution kernels are also reshaped for the matrix multiplication.
    # We will get a matrix
    # [[weights for out channel 0],
    #  [weights for out channel 1],
    #  ...
    #  [weights for out channel nb_channels_out]].TRANSPOSE()
    weight_reshaped = weight.view(nb_channels_out // groups, -1).t().wrap()

    # Now that everything is set up, we can compute the result
    if groups > 1:
        res = []
        chunks_im = torch.chunk(im_reshaped, groups, dim=2)
        chunks_weights = torch.chunk(weight_reshaped, groups, dim=0)
        for g in range(groups):
            tmp = chunks_im[g].matmul(chunks_weights[g])
            res.append(tmp)
        res = torch.cat(res, dim=2)
    else:
        res = im_reshaped.matmul(weight_reshaped)

    # Add a bias if needed
    if bias is not None:
        res += bias

    # ... And reshape it back to an image
    res = (res.permute(0, 2, 1)
              .view(batch_size, nb_channels_out, nb_rows_out, nb_cols_out)
              .contiguous())

    return res.child
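# The reshaping above is an im2col-style convolution. As a plain-torch sanity
# sketch (without the MPC .wrap()/.child plumbing), F.unfold plus a matmul
# reproduces F.conv2d; shapes and values below are illustrative.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
cols = F.unfold(x, kernel_size=3)              # (1, 3*3*3, 36)
out = (w.view(4, -1) @ cols).view(1, 4, 6, 6)  # matmul == convolution
assert torch.allclose(out, F.conv2d(x, w), atol=1e-5)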
                    help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='gloo', type=str,
                    help='distributed backend')
parser.add_argument('--world_size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--rank', default=0, type=int,
                    help='The rank of this process')
args = parser.parse_args()

args.distributed = args.world_size > 1
if args.distributed:
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

if args.distributed:
    input_data = torch.randn(int(args.num_samples / args.world_size),
                             1, 161, args.seconds * 100).cuda()
else:
    input_data = torch.randn(args.num_samples, 1, 161, args.seconds * 100).cuda()
input_data = torch.chunk(input_data, int(len(input_data) / args.batch_size))

rnn_type = args.rnn_type.lower()
assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"

with open(args.labels_path) as label_file:
    labels = str(''.join(json.load(label_file)))

audio_conf = dict(sample_rate=args.sample_rate,
                  window_size=args.window_size)

model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                   nb_layers=args.hidden_layers,
                   audio_conf=audio_conf,
                   labels=labels,
                   rnn_type=supported_rnns[rnn_type])
def forward(self, x):
    h = self.encoder(x)
    mu, logvar = torch.chunk(h, 2, dim=1)
    z = self.reparameterize(mu, logvar)
    return self.decoder(z), mu, logvar
def forward(self, input):
    # split left image and right image
    # print(input.size())
    imgs = torch.chunk(input, 2, dim=1)
    img_left = imgs[0]
    img_right = imgs[1]

    conv1 = self.conv1(input)
    conv2 = self.conv2(conv1)
    conv3a = self.conv3(conv2)
    conv3b = self.conv3_1(conv3a)
    conv4a = self.conv4(conv3b)
    conv4b = self.conv4_1(conv4a)
    conv5a = self.conv5(conv4b)
    conv5b = self.conv5_1(conv5a)
    conv6a = self.conv6(conv5b)
    conv6b = self.conv6_1(conv6a)

    pr6 = self.pred_flow6(conv6b)
    upconv5 = self.upconv5(conv6b)
    upflow6 = self.upflow6to5(pr6)
    concat5 = torch.cat((upconv5, upflow6, conv5b), 1)
    iconv5 = self.iconv5(concat5)

    pr5 = self.pred_flow5(iconv5)
    upconv4 = self.upconv4(iconv5)
    upflow5 = self.upflow5to4(pr5)
    concat4 = torch.cat((upconv4, upflow5, conv4b), 1)
    iconv4 = self.iconv4(concat4)

    pr4 = self.pred_flow4(iconv4)
    upconv3 = self.upconv3(iconv4)
    upflow4 = self.upflow4to3(pr4)
    concat3 = torch.cat((upconv3, upflow4, conv3b), 1)
    iconv3 = self.iconv3(concat3)

    pr3 = self.pred_flow3(iconv3)
    upconv2 = self.upconv2(iconv3)
    upflow3 = self.upflow3to2(pr3)
    concat2 = torch.cat((upconv2, upflow3, conv2), 1)
    iconv2 = self.iconv2(concat2)

    pr2 = self.pred_flow2(iconv2)
    upconv1 = self.upconv1(iconv2)
    upflow2 = self.upflow2to1(pr2)
    concat1 = torch.cat((upconv1, upflow2, conv1), 1)
    iconv1 = self.iconv1(concat1)

    pr1 = self.pred_flow1(iconv1)
    upconv0 = self.upconv0(iconv1)
    upflow1 = self.upflow1to0(pr1)
    concat0 = torch.cat((upconv0, upflow1, img_left), 1)
    iconv0 = self.iconv0(concat0)

    pr0 = self.pred_flow0(iconv0)
    # img_right_rec = warp(img_left, pr0)

    # Returning all scales; the caller can select pr0 only at inference.
    return pr0, pr1, pr2, pr3, pr4, pr5, pr6
def forward(self, x):
    """Applies network layers and ops on input image(s) x.

    Args:
        x: input image or batch of images. Shape: [batch, 3, 300, 300].

    Return:
        Depending on phase:
        test:
            Variable(tensor) of output class label predictions,
            confidence score, and corresponding location predictions for
            each object detected. Shape: [batch, topk, 7]
        train:
            list of concat outputs from:
                1: confidence layers, Shape: [batch*num_priors, num_classes]
                2: localization layers, Shape: [batch, num_priors*4]
                3: priorbox layers, Shape: [2, num_priors*4]
    """
    image_size = [x.shape[2], x.shape[3]]
    loc = list()
    conf = list()

    if backbone == 'vgg':
        for k in range(16):
            x = self.vgg[k](x)
        conv3_3_x = x
        for k in range(16, 23):
            x = self.vgg[k](x)
        conv4_3_x = x
        for k in range(23, 30):
            x = self.vgg[k](x)
        conv5_3_x = x
        for k in range(30, len(self.vgg)):
            x = self.vgg[k](x)
        fc7_x = x
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k == 1:
                conv6_2_x = x
            if k == 3:
                conv7_2_x = x
    elif backbone in ['senet', 'resnet50', 'detnet', 'resnet101', 'resnet152', 'resnext']:
        conv3_3_x = self.layer1(x)
        conv4_3_x = self.layer2(conv3_3_x)
        conv5_3_x = self.layer3(conv4_3_x)
        fc7_x = self.layer4(conv5_3_x)
        conv6_2_x = self.layer5(fc7_x)
        conv7_2_x = self.layer6(conv6_2_x)

    if refine:
        arm_loc = list()
        arm_conf = list()
        arm_sources = [conv3_3_x, conv4_3_x, conv5_3_x, fc7_x, conv6_2_x, conv7_2_x]
        for (x, l, c) in zip(arm_sources, self.arm_loc, self.arm_conf):
            arm_loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            arm_conf.append(c(x).permute(0, 2, 3, 1).contiguous())
        arm_loc = torch.cat([o.view(o.size(0), -1) for o in arm_loc], 1)
        arm_conf = torch.cat([o.view(o.size(0), -1) for o in arm_conf], 1)

    if fpn:
        # lfpn6 = self._upsample_product(self.latlayer6(conv7_2_x), self.smooth6(conv7_2_x))
        # lfpn5 = self._upsample_product(self.latlayer5(lfpn6), self.smooth5(conv6_2_x))
        # lfpn4 = self._upsample_product(self.latlayer4(lfpn5), self.smooth4(fc7_x))
        # lfpn3 = self._upsample_product(self.latlayer3(lfpn4), self.smooth3(conv5_3_x))
        lfpn3 = self._upsample_product(self.latlayer3(fc7_x), self.smooth3(conv5_3_x))
        lfpn2 = self._upsample_product(self.latlayer2(lfpn3), self.smooth2(conv4_3_x))
        lfpn1 = self._upsample_product(self.latlayer1(lfpn2), self.smooth1(conv3_3_x))

        # conv7_2_x = lfpn6
        # conv6_2_x = lfpn5
        # fc7_x = lfpn4
        conv5_3_x = lfpn3
        conv4_3_x = lfpn2
        conv3_3_x = lfpn1

    if backbone == 'vgg':
        conv3_3_x = self.L2Norm_3_3(conv3_3_x)
        conv4_3_x = self.L2Norm_4_3(conv4_3_x)
        conv5_3_x = self.L2Norm_5_3(conv5_3_x)

    if bup:
        # conv4_3_x = F.relu(self.bup1(conv3_3_x)) * conv4_3_x
        # conv5_3_x = F.relu(self.bup2(conv4_3_x)) * conv5_3_x
        fc7_x = F.relu(self.bup3(conv5_3_x)) * fc7_x
        conv6_2_x = F.relu(self.bup4(fc7_x)) * conv6_2_x
        conv7_2_x = F.relu(self.bup5(conv6_2_x)) * conv7_2_x

    sources = [conv3_3_x, conv4_3_x, conv5_3_x, fc7_x, conv6_2_x, conv7_2_x]

    if fem:
        sources[0] = self.cpm3_3(sources[0])
        sources[1] = self.cpm4_3(sources[1])
        sources[2] = self.cpm5_3(sources[2])
        sources[3] = self.cpm7(sources[3])
        sources[4] = self.cpm6_2(sources[4])
        sources[5] = self.cpm7_2(sources[5])

    # apply multibox head to source layers
    featuremap_size = []
    for (x, l, c) in zip(sources, self.loc, self.conf):
        featuremap_size.append([x.shape[2], x.shape[3]])
        loc.append(l(x).permute(0, 2, 3, 1).contiguous())
        if mo:
            if len(conf) == 0:
                chunk = torch.chunk(c(x), 4, 1)
                bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
                cls1 = torch.cat([bmax, chunk[3]], dim=1)
                conf.append(cls1.permute(0, 2, 3, 1).contiguous())
            else:
                conf.append(c(x).permute(0, 2, 3, 1).contiguous())
        elif mio:
            len_conf = len(conf)
            if cfg['mbox'][0] == 1:
                cls = self.mio_module(c(x), len_conf)
            else:
                mmbox = torch.chunk(c(x), cfg['mbox'][0], 1)
                cls_0 = self.mio_module(mmbox[0], len_conf)
                cls_1 = self.mio_module(mmbox[1], len_conf)
                cls_2 = self.mio_module(mmbox[2], len_conf)
                cls_3 = self.mio_module(mmbox[3], len_conf)
                cls = torch.cat([cls_0, cls_1, cls_2, cls_3], dim=1)
            conf.append(cls.permute(0, 2, 3, 1).contiguous())
        else:
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

    if pa:
        mbox_num = cfg['mbox'][0]
        face_loc = torch.cat([o[:, :, :, :4 * mbox_num].contiguous().view(o.size(0), -1) for o in loc], 1)
        face_conf = torch.cat([o[:, :, :, :2 * mbox_num].contiguous().view(o.size(0), -1) for o in conf], 1)
        head_loc = torch.cat([o[:, :, :, 4 * mbox_num:8 * mbox_num].contiguous().view(o.size(0), -1) for o in loc[1:]], 1)
        head_conf = torch.cat([o[:, :, :, 2 * mbox_num:4 * mbox_num].contiguous().view(o.size(0), -1) for o in conf[1:]], 1)
        body_loc = torch.cat([o[:, :, :, 8 * mbox_num:].contiguous().view(o.size(0), -1) for o in loc[2:]], 1)
        body_conf = torch.cat([o[:, :, :, 4 * mbox_num:].contiguous().view(o.size(0), -1) for o in conf[2:]], 1)
    else:
        face_loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        face_conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

    if self.phase == "test":
        self.cfg['feature_maps'] = featuremap_size
        self.cfg['min_dim'] = image_size
        self.priors = self.init_priors(self.cfg)
        if refine:
            output = self.detect(
                face_loc.view(face_loc.size(0), -1, 4),                                 # loc preds
                self.softmax(face_conf.view(face_conf.size(0), -1, self.num_classes)),  # conf preds
                self.priors.type(type(x.data)),                                         # default boxes
                arm_loc.view(arm_loc.size(0), -1, 4),
                self.softmax(arm_conf.view(arm_conf.size(0), -1, self.num_classes)),
            )
        else:
            output = self.detect(
                face_loc.view(face_loc.size(0), -1, 4),                                 # loc preds
                self.softmax(face_conf.view(face_conf.size(0), -1, self.num_classes)),  # conf preds
                self.priors.type(type(x.data))                                          # default boxes
            )
    else:
        self.cfg['feature_maps'] = featuremap_size
        self.cfg['min_dim'] = image_size
        if pa:
            self.face_priors = self.init_priors(self.cfg)
            self.head_priors = self.init_priors(self.cfg,
                                                min_size=cfg['min_sizes'][:-1],
                                                max_size=cfg['max_sizes'][:-1])
            self.body_priors = self.init_priors(self.cfg,
                                                min_size=cfg['min_sizes'][:-2],
                                                max_size=cfg['max_sizes'][:-2])
            output = (
                face_loc.view(face_loc.size(0), -1, 4),
                face_conf.view(face_conf.size(0), -1, self.num_classes),
                self.face_priors,
                head_loc.view(head_loc.size(0), -1, 4),
                head_conf.view(head_conf.size(0), -1, self.num_classes),
                self.head_priors,
                body_loc.view(body_loc.size(0), -1, 4),
                body_conf.view(body_conf.size(0), -1, self.num_classes),
                self.body_priors
            )
        else:
            self.priors = self.init_priors(self.cfg)
            output = (
                face_loc.view(face_loc.size(0), -1, 4),
                face_conf.view(face_conf.size(0), -1, self.num_classes),
                self.priors
            )
        if refine:
            output = output + tuple((arm_loc.view(arm_loc.size(0), -1, 4),
                                     arm_conf.view(arm_conf.size(0), -1, self.num_classes)))
    return output
def gnn_track_finding( hid, x, cell_data, embed_ckpt_dir='/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt', filter_ckpt_dir='/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt', gnn_ckpt_dir='/global/cfs/cdirs/m3443/data/lightning_models/gnn', ckpt_idx=-1, dbscan_epsilon=0.25, dbscan_minsamples=2): device = 'cuda' if torch.cuda.is_available() else 'cpu' # ### Setup some hyperparameters and event # embed_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt' # filter_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt' # gnn_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/gnn' # ckpt_idx = -1 # which GNN checkpoint to load # dbscan_epsilon, dbscan_minsamples = 0.25, 2 # hyperparameters for DBScan # min_hits = 5 # minimum number of hits associated with a particle to define "reconstructable particles" # frac_reco_matched, frac_truth_matched = 0.5, 0.5 # parameters for track matching data = Data(hid=torch.from_numpy(hid), x=torch.from_numpy(x).float(), cell_data=torch.from_numpy(cell_data).float()).to(device) # ### Evaluating Embedding # In[9]: e_ckpt = torch.load(embed_ckpt_dir, map_location=device) e_config = e_ckpt['hyper_parameters'] e_config['clustering'] = 'build_edges' e_config['knn_val'] = 500 e_config['r_val'] = 1.7 e_model = LayerlessEmbedding(e_config).to(device) e_model.load_state_dict(e_ckpt["state_dict"]) e_model.eval() # Map each hit to the embedding space, return the embeded parameters for each hit with torch.no_grad(): spatial = e_model(torch.cat([data.cell_data, data.x], axis=-1)) #.to(device) # ### From embeddeding space form doublets # `r_val = 1.7` and `knn_val = 500` are the hyperparameters to be studied. # # * `r_val` defines the radius of the clustering method # * `knn_val` defines the number of maximum neighbors in the embedding space e_spatial = utils_torch.build_edges(spatial.to(device), e_model.hparams['r_val'], e_model.hparams['knn_val']) # Removing edges that point from outer region to inner region, which almost removes half of edges. # In[16]: R_dist = torch.sqrt(data.x[:, 0]**2 + data.x[:, 2]**2) # distance away from origin... e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])] f_ckpt = torch.load(filter_ckpt_dir, map_location='cpu') f_config = f_ckpt['hyper_parameters'] f_config['train_split'] = [0, 0, 1] f_config['filter_cut'] = 0.18 f_model = VanillaFilter(f_config).to(device) f_model.load_state_dict(f_ckpt['state_dict']) f_model.eval() emb = None # embedding information was not used in the filtering stage. chunks = 10 output_list = [] for j in range(chunks): subset_ind = torch.chunk(torch.arange(e_spatial.shape[1]), chunks)[j] with torch.no_grad(): output = f_model(torch.cat([data.cell_data, data.x], axis=-1), e_spatial[:, subset_ind], emb).squeeze() #.to(device) output_list.append(output) del subset_ind del output gc.collect() output = torch.cat(output_list) output = torch.sigmoid(output) # The filtering network assigns a score to each edge. # In the end, edges with socres > `filter_cut` are selected to construct graphs. # edge_list = e_spatial[:, output.to('cpu') > f_model.hparams['filter_cut']] print(f_model.hparams['filter_cut']) edge_list = e_spatial[:, output > f_model.hparams['filter_cut']] print(edge_list.shape) # ### Form a graph # Now moving TensorFlow for GNN inference. 
    n_nodes = data.x.shape[0]
    n_edges = edge_list.shape[1]
    nodes = data.x.cpu().numpy().astype(np.float32)
    edges = np.zeros((n_edges, 1), dtype=np.float32)
    senders = edge_list[0].cpu()
    receivers = edge_list[1].cpu()

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32),
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])

    num_processing_steps_tr = 8
    optimizer = snt.optimizers.Adam(0.001)
    model = SegmentClassifier()

    output_dir = gnn_ckpt_dir
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    ckpt_manager = tf.train.CheckpointManager(
        checkpoint, directory=output_dir, max_to_keep=10)
    status = checkpoint.restore(ckpt_manager.checkpoints[ckpt_idx]).expect_partial()

    # Clean up GPU memory before running the GNN.
    del e_spatial
    del e_model
    del f_model
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()

    outputs_gnn = model(input_graph, num_processing_steps_tr)
    output_graph = outputs_gnn[-1]

    # ### Track labeling
    # (output_graph.edges is a TF tensor, so .numpy() rather than .cpu().numpy())
    input_matrix = prepare_labeling(
        tf.squeeze(output_graph.edges).numpy(), senders, receivers, n_nodes)
    predict_track_df = dbscan_clustering(
        data.hid.cpu(), input_matrix, dbscan_epsilon, dbscan_minsamples)

    trkx_groups = predict_track_df.groupby(['track_id'])
    all_trk_ids = np.unique(predict_track_df.track_id)
    n_trkxs = all_trk_ids.shape[0]
    predict_tracks = [
        trkx_groups.get_group(all_trk_ids[idx])['hit_id'].to_numpy().tolist()
        for idx in range(n_trkxs)
    ]
    return predict_tracks
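# A minimal sketch of the chunked edge-scoring pattern used in the filtering
# stage above, assuming only that edges are scored independently: torch.chunk
# splits the edge index into fixed pieces to bound peak memory, and the
# per-chunk scores are concatenated at the end. `score_edges` is a
# hypothetical stand-in for the filter network.
import torch

def score_in_chunks(score_edges, edge_index, chunks=10):
    outputs = []
    for idx in torch.chunk(torch.arange(edge_index.shape[1]), chunks):
        with torch.no_grad():
            outputs.append(score_edges(edge_index[:, idx]))
    return torch.cat(outputs)

edge_index = torch.randint(0, 100, (2, 57))
scores = score_in_chunks(lambda e: torch.rand(e.shape[1]), edge_index)
assert scores.shape == (57,)

# And a hedged sketch of the DBSCAN labelling step, assuming input_matrix can
# be converted to a pairwise distance matrix (e.g. 1 - edge score) so that
# sklearn's precomputed metric applies; the real prepare_labeling /
# dbscan_clustering helpers may differ in detail.
import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.default_rng(0)
dist = rng.random((8, 8)).astype(np.float32)
dist = np.minimum(dist, dist.T)   # symmetrize
np.fill_diagonal(dist, 0.0)
labels = DBSCAN(eps=0.25, min_samples=2, metric='precomputed').fit_predict(dist)
assert labels.shape == (8,)       # one track id (or -1 for noise) per hit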
def forward(self, input):
    internal_state = []
    fcn2_output = input  # 12 * (64*64) * 64 channels
    # Fold the 12 stacked frames from the batch axis into the channel axis.
    input = torch.cat(torch.chunk(input, 12, dim=0), dim=1)
    for step in range(3):
        x = input
        if step == 0:
            basize, _, height, width = input.size()
            (h_step, c) = ConvLSTMCell.init_hidden(
                basize, self.hidden_channels[self.num_layers - 1], (height, width))

        fcn2 = self.conv_fcn2(x)
        h_c = self.conv_h(h_step)
        fcn2_h_cat = fcn2 + h_c
        fcn2_h_cat = self.pool_avg(fcn2_h_cat)
        fcn2_h_cat = self.conv_c(fcn2_h_cat)

        # Attention module: softmax over the 12 frames, rescaled by 12.
        fcn2_h_cat = torch.mul(F.softmax(fcn2_h_cat, dim=1), 12)
        Att = fcn2_h_cat
        basize, dime, h, w = fcn2_h_cat.size()
        fcn2_h_cat = fcn2_h_cat.view(1, basize, dime, h, w).transpose(0, 1).transpose(1, 2)
        fcn2_h_cat = torch.cat(torch.chunk(fcn2_h_cat, basize, dim=0),
                               dim=1).view(basize * dime, 1, 1, 1)
        fcn2_h_cat = torch.mul(fcn2_output, fcn2_h_cat).view(1, basize * dime, 64, 64, 64)
        fcn2_h_cat = torch.cat(torch.chunk(fcn2_h_cat, basize, dim=1), dim=0)
        fcn2_h_cat = torch.sum(fcn2_h_cat, 1, keepdim=False)

        x = fcn2_h_cat
        if step < self.step - 1:
            for i in range(self.num_layers):
                # All cells are initialized in the first step.
                if step == 0:
                    bsize, _, height, width = x.size()
                    (h, c) = ConvLSTMCell.init_hidden(
                        bsize, self.hidden_channels[i], (height, width))
                    internal_state.append((h, c))
                # Do forward.
                name = 'cell{}'.format(i)
                (h, c) = internal_state[i]
                x, new_c, new_o = getattr(self, name)(x, h, c)  # ConvLSTMCell forward
                internal_state[i] = (x, new_c)
            h_step = x

            # Only record effective steps.
            # if step in self.effective_step:
            if step == 0:
                outputs_o = new_o
            else:
                outputs_o = torch.cat((outputs_o, new_o), dim=1)

    outputs = self.conv_pre(outputs_o)
    # F.upsample is deprecated; F.interpolate with align_corners=False matches it.
    output = F.interpolate(outputs, scale_factor=4, mode='bilinear', align_corners=False)
    return output
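# A small sketch of the chunk/cat regrouping at the top of the forward above:
# chunking along the batch axis and concatenating along the channel axis folds
# T stacked frames from the batch dimension into the channel dimension. Shapes
# here are illustrative, not from the original module.
import torch

T, B, C, H, W = 12, 2, 4, 8, 8
frames = torch.randn(T * B, C, H, W)                      # T frames stacked on the batch axis
folded = torch.cat(torch.chunk(frames, T, dim=0), dim=1)  # (B, T*C, H, W)
assert folded.shape == (B, T * C, H, W)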
def forward(self, x):
    x = self.conv1(x)
    # add_scalar
    x = x + 3
    # mul_scalar
    x = x * 3
    # add_scalar_out
    x += 3
    # mul_scalar_out
    x *= 3
    # add_scalar_relu
    x = x + 3
    x = F.relu(x)
    # add_scalar_relu_out
    x += 3
    x = F.relu(x)
    # mul_scalar_relu
    x = x * 3
    x = F.relu(x)
    # mul_scalar_relu_out
    x *= 3
    x = F.relu(x)
    x = self.maxpool1d(x)
    x = self.maxpool2d(x)
    x = self.maxpool3d(x)
    x = torch.flatten(x)
    x = torch.max(x)
    x = torch.min(x)
    x = x.reshape([-1])
    x = x.resize_(1, 1, x.numel())
    x = x.view(-1)
    # prim::ListConstruct
    xs = [x, x]
    # prim::ListUnpack
    x, y = xs
    # prim::TupleConstruct
    xs = (x, x)
    # prim::TupleUnpack
    x, y = xs
    x = x.transpose(1, 2)
    x = x.contiguous()
    x, y = torch.chunk(x, 2)
    x = F.dropout(x)
    x = self.dropout(x)
    x, _ = torch.sort(x)
    x = x.permute(0, 2, 3, 1)
    x = x.repeat_interleave(3, 1)
    x = torch.repeat_interleave(x, 3, 1)
    x = self.relu(x)
    x = F.relu(x)
    x = F.relu(x, inplace=True)
    x = x.relu()
    x.relu_()
    x = x.squeeze(0)
    x.squeeze_(0)
    x = torch.squeeze(x, 0)
    x = x.unsqueeze(0)
    x.unsqueeze_(0)
    x = torch.unsqueeze(x, 0)
    x = x.detach()
    x.detach_()
    x = x.repeat(4, 2)
    y = []
    y.append(x)
    z = torch.stack(y, 0)
    z = [z, z]
    x, _ = z
    x = self.conv2(x)
    return x
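# One caveat worth noting for the `x, y = torch.chunk(x, 2)` unpack above:
# torch.chunk may return fewer chunks than requested when the dimension is too
# small, so unpacking into a fixed number of names is only safe when the size
# along dim guarantees that count.
import torch

assert len(torch.chunk(torch.arange(2), 4)) == 2   # only 2 chunks come back, not 4
assert len(torch.chunk(torch.arange(6), 2)) == 2   # exact split: 3 + 3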
def RGB_to_BGR(batch):
    # Swap the R and B channels of an NCHW batch by splitting the channel
    # axis with torch.chunk and concatenating in reverse order.
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch
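# Usage sketch for RGB_to_BGR, assuming an NCHW batch with exactly three
# channels; for three channels the chunk/cat round-trip is equivalent to
# flipping the channel axis.
import torch

batch = torch.rand(4, 3, 32, 32)        # N x (R, G, B) x H x W
bgr = RGB_to_BGR(batch)
assert torch.equal(bgr, batch.flip(1))  # R and B swapped, G unchanged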
def calculate_cpes(
    self,
    training_batch,
    states,
    next_states,
    all_next_action_scores,
    logged_action_idxs,
    discount_tensor,
    not_done_mask,
):
    if not self.calc_cpe_in_training:
        return None, None, None

    if training_batch.extras.metrics is None:
        metrics_reward_concat_real_vals = training_batch.training_input.reward
    else:
        metrics_reward_concat_real_vals = torch.cat(
            (training_batch.training_input.reward, training_batch.extras.metrics),
            dim=1,
        )

    model_propensities_next_states = masked_softmax(
        all_next_action_scores,
        training_batch.training_input.possible_next_actions_mask
        if self.maxq_learning
        else training_batch.training_input.next_action,
        self.rl_temperature,
    )

    with torch.enable_grad():
        ######### Train separate reward network for CPE evaluation #############
        # FIXME: the reward network should be outputting a tensor, not a q-value object
        reward_estimates = self.reward_network(states).q_values
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, self.reward_idx_offsets + logged_action_idxs
        )
        reward_loss = F.mse_loss(
            reward_estimates_for_logged_actions, metrics_reward_concat_real_vals
        )
        reward_loss.backward()
        self._maybe_run_optimizer(
            self.reward_network_optimizer, self.minibatches_per_step
        )

        ######### Train separate q-network for CPE evaluation #############
        metric_q_values = self.q_network_cpe(states).q_values.gather(
            1, self.reward_idx_offsets + logged_action_idxs
        )
        all_metrics_target_q_values = torch.chunk(
            self.q_network_cpe_target(next_states).q_values.detach(),
            len(self.metrics_to_score),
            dim=1,
        )
        target_metric_q_values = []
        for i, per_metric_target_q_values in enumerate(all_metrics_target_q_values):
            per_metric_next_q_values = torch.sum(
                per_metric_target_q_values * model_propensities_next_states,
                1,
                keepdim=True,
            )
            per_metric_next_q_values = per_metric_next_q_values * not_done_mask
            per_metric_target_q_values = metrics_reward_concat_real_vals[
                :, i : i + 1
            ] + (discount_tensor * per_metric_next_q_values)
            target_metric_q_values.append(per_metric_target_q_values)

        target_metric_q_values = torch.cat(target_metric_q_values, dim=1)
        metric_q_value_loss = self.q_network_loss(
            metric_q_values, target_metric_q_values
        )
        metric_q_value_loss.backward()
        self._maybe_run_optimizer(
            self.q_network_cpe_optimizer, self.minibatches_per_step
        )

    # Use the soft update rule to update the target network.
    self._maybe_soft_update(
        self.q_network_cpe,
        self.q_network_cpe_target,
        self.tau,
        self.minibatches_per_step,
    )

    model_propensities = masked_softmax(
        self.all_action_scores,
        training_batch.training_input.possible_actions_mask
        if self.maxq_learning
        else training_batch.training_input.action,
        self.rl_temperature,
    )
    model_rewards = reward_estimates[
        :,
        torch.arange(
            self.reward_idx_offsets[0],
            self.reward_idx_offsets[0] + self.num_actions,
        ),
    ]
    return reward_loss, model_rewards, model_propensities
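# A minimal sketch of the per-metric chunking above, with illustrative sizes:
# the CPE Q-network packs one (batch, num_actions) block per metric side by
# side, and torch.chunk along dim=1 recovers the per-metric blocks.
import torch

batch, num_actions, num_metrics = 5, 3, 4
packed_q = torch.randn(batch, num_actions * num_metrics)
per_metric = torch.chunk(packed_q, num_metrics, dim=1)
assert len(per_metric) == num_metrics and per_metric[0].shape == (batch, num_actions)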
def vgg_preprocess_caffe(var):
    # Convert an RGB batch in [0, 1] to Caffe-style BGR in [0, 255] and
    # subtract the VGG channel means.
    (r, g, b) = torch.chunk(var, 3, dim=1)
    bgr = torch.cat((b, g, r), 1)
    out = bgr * 255 - torch.autograd.Variable(vgg_mean).type(var.type())
    return out
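# Usage sketch for vgg_preprocess_caffe, assuming RGB input in [0, 1] and a
# module-level vgg_mean buffer holding the Caffe BGR channel means; the values
# below are the commonly used ones, not taken from this file.
import torch

vgg_mean = torch.tensor([103.939, 116.779, 123.680]).view(1, 3, 1, 1)
img = torch.rand(1, 3, 224, 224)   # RGB in [0, 1]
out = vgg_preprocess_caffe(img)    # BGR in [0, 255], mean-subtracted
assert out.shape == (1, 3, 224, 224)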
def _compute_rotation_matrix_taylor(angle_axis):
    # First-order Taylor expansion of the Rodrigues formula: for small angles,
    # R ~ I + [r]_x, where [r]_x is the skew-symmetric matrix of angle_axis.
    rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
    k_one = torch.ones_like(rx)
    rotation_matrix = torch.cat(
        [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
    return rotation_matrix.view(-1, 3, 3)
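# A quick sanity check of the first-order rotation above: for a zero
# angle-axis vector, I + [r]_x reduces exactly to the identity matrix.
import torch

zero = torch.zeros(1, 3)
assert torch.allclose(_compute_rotation_matrix_taylor(zero), torch.eye(3).unsqueeze(0))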
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
    qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

    if mems is not None:
        cat = torch.cat([mems, w], 0)
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(cat))
        else:
            w_heads = self.qkv_net(cat)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        w_head_q = w_head_q[-qlen:]
    else:
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(w))
        else:
            w_heads = self.qkv_net(w)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

    klen = w_head_k.size(0)

    w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
    w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
    w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

    r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # rlen x n_head x d_head

    #### compute attention score
    rw_head_q = w_head_q + r_w_bias                                # qlen x bsz x n_head x d_head
    AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))    # qlen x klen x bsz x n_head

    rr_head_q = w_head_q + r_r_bias
    BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))     # qlen x klen x bsz x n_head
    BD = self._rel_shift(BD)

    # [qlen x klen x bsz x n_head]
    attn_score = AC + BD
    attn_score.mul_(self.scale)

    #### compute attention probability
    if attn_mask is not None and attn_mask.any().item():
        if attn_mask.dim() == 2:
            attn_score = attn_score.float().masked_fill(
                attn_mask[None, :, :, None], -float('inf')).type_as(attn_score)
        elif attn_mask.dim() == 3:
            attn_score = attn_score.float().masked_fill(
                attn_mask[:, :, :, None], -float('inf')).type_as(attn_score)

    # [qlen x klen x bsz x n_head]
    attn_prob = F.softmax(attn_score, dim=1)
    attn_prob = self.dropatt(attn_prob)

    #### compute attention vector
    attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

    # [qlen x bsz x n_head x d_head]
    attn_vec = attn_vec.contiguous().view(
        attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

    ##### linear projection
    attn_out = self.o_net(attn_vec)
    attn_out = self.drop(attn_out)

    if self.pre_lnorm:
        ##### residual connection
        output = w + attn_out
    else:
        ##### residual connection + layer normalization
        output = self.layer_norm(w + attn_out)

    return output
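# A minimal sketch of the fused-QKV split above, with illustrative sizes: one
# linear layer emits 3 * n_head * d_head features per position, and torch.chunk
# along the last dim separates the query, key, and value projections.
import torch

qlen, bsz, n_head, d_head = 7, 2, 4, 8
w_heads = torch.randn(qlen, bsz, 3 * n_head * d_head)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
assert w_head_q.shape == (qlen, bsz, n_head * d_head)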
def forward(self, lefts, rights, tracking=None):
    # Fuse the per-node tensors into one batch, apply the op once, then
    # split back into per-node chunks. (F.tanh is deprecated; use torch.tanh.)
    batch_size = len(lefts)
    ret = torch.cat(lefts, 0) + torch.tanh(torch.cat(rights, 0))
    return torch.chunk(ret, batch_size, 0)
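# Usage sketch for the batched tree op above: per-node tensors are fused with
# torch.cat, transformed in one shot, and torch.chunk restores the per-node list.
import torch

lefts = [torch.randn(1, 5) for _ in range(3)]
rights = [torch.randn(1, 5) for _ in range(3)]
outs = torch.chunk(torch.cat(lefts, 0) + torch.tanh(torch.cat(rights, 0)), len(lefts), 0)
assert len(outs) == 3 and outs[0].shape == (1, 5)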
def forward(
    self,
    x: torch.Tensor,
    states: Tuple[torch.Tensor, torch.Tensor],
    variational_dropout_mask: Optional[torch.BoolTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    !!! Warning
        DO NOT USE THIS LAYER DIRECTLY, instead use the AugmentedLSTM class

    # Parameters

    x : `torch.Tensor`
        Input tensor of shape (bsize x input_dim).
    states : `Tuple[torch.Tensor, torch.Tensor]`
        Tuple of tensors containing the hidden state and the cell state of each
        element in the batch. Each of these tensors has a dimension of
        (bsize x nhid).

    # Returns

    `Tuple[torch.Tensor, torch.Tensor]`
        Returned states. Shape of each state is (bsize x nhid).
    """
    hidden_state, memory_state = states

    # In PyText this was done as the last step of the cell, but in the original
    # AugmentedLSTM from AllenNLP it was done before the processing.
    if variational_dropout_mask is not None and self.training:
        hidden_state = hidden_state * variational_dropout_mask

    projected_input = self.input_linearity(x)
    projected_state = self.state_linearity(hidden_state)

    input_gate = forget_gate = memory_init = output_gate = highway_gate = None
    if self.use_highway:
        fused_op = projected_input[:, : 5 * self.lstm_dim] + projected_state
        fused_chunked = torch.chunk(fused_op, 5, 1)
        (input_gate, forget_gate, memory_init, output_gate, highway_gate) = fused_chunked
        highway_gate = torch.sigmoid(highway_gate)
    else:
        fused_op = projected_input + projected_state
        input_gate, forget_gate, memory_init, output_gate = torch.chunk(fused_op, 4, 1)
    input_gate = torch.sigmoid(input_gate)
    forget_gate = torch.sigmoid(forget_gate)
    memory_init = torch.tanh(memory_init)
    output_gate = torch.sigmoid(output_gate)

    memory = input_gate * memory_init + forget_gate * memory_state
    timestep_output: torch.Tensor = output_gate * torch.tanh(memory)

    if self.use_highway:
        highway_input_projection = projected_input[
            :, self._highway_inp_proj_start : self._highway_inp_proj_end
        ]
        timestep_output = (
            highway_gate * timestep_output
            + (1 - highway_gate) * highway_input_projection  # noqa
        )

    return timestep_output, memory
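# A minimal sketch of the fused-gate chunking above, with illustrative sizes:
# the input and state projections emit 5 * lstm_dim features when the highway
# connection is enabled (4 * lstm_dim otherwise), and torch.chunk splits the
# fused pre-activations into the individual gates.
import torch

bsize, lstm_dim = 4, 16
fused_op = torch.randn(bsize, 5 * lstm_dim)
gates = torch.chunk(fused_op, 5, dim=1)
assert len(gates) == 5 and gates[0].shape == (bsize, lstm_dim)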