Example 1
 def reset_parameters(self):
     for cell in self._cells:
         # Xavier initialization
         gate_size = self.hidden_size / 4
         for weight in [cell.weight_ih, cell.weight_hh]:
             for w in torch.chunk(weight, 4, dim=0):
                 init.xavier_normal_(w)
         # forget-gate bias = 1
         for bias in [cell.bias_ih, cell.bias_hh]:
             torch.chunk(bias, 4, dim=0)[1].data.fill_(1)
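The chunking above matches how PyTorch lays out LSTM parameters: weight_ih / bias_ih stack the four gate blocks in (input, forget, cell, output) order along dim 0, so index 1 of each chunk is the forget gate. A minimal stand-alone sketch of the same idea, using a plain nn.LSTMCell rather than the example's custom cells:

import torch
from torch import nn

cell = nn.LSTMCell(input_size=8, hidden_size=16)
# bias_ih has shape (4 * hidden_size,); chunk index 1 is the forget gate
i_b, f_b, g_b, o_b = torch.chunk(cell.bias_ih.data, 4, dim=0)
f_b.fill_(1.0)                           # chunks are views, so this edits the cell in place
print(cell.bias_ih.data.view(4, 16)[1])  # row of ones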
Example 2
    def forward(self, embbedings, label):
        if self.device_id == None:
            kernel_norm = l2_norm(self.kernel, axis = 0)
            cos_theta = torch.mm(embbedings, kernel_norm)
        else:
            x = embbedings
            sub_kernels = torch.chunk(self.kernel, len(self.device_id), dim=1)
            temp_x = x.cuda(self.device_id[0])
            kernel_norm = l2_norm(sub_kernels[0], axis = 0).cuda(self.device_id[0])
            cos_theta = torch.mm(temp_x, kernel_norm)
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                kernel_norm = l2_norm(sub_kernels[i], axis = 0).cuda(self.device_id[i])
                cos_theta = torch.cat((cos_theta, torch.mm(temp_x, kernel_norm).cuda(self.device_id[0])), dim=1)

        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        phi = cos_theta - self.m
        label = label.view(-1, 1)  # size=(B,1)
        index = cos_theta.data * 0.0  # size=(B,Classnum)
        index.scatter_(1, label.data.view(-1, 1), 1)
        index = index.byte()
        output = cos_theta * 1.0
        output[index] = phi[index]  # only change the correct predicted output
        output *= self.s  # scale up in order to make softmax work, first introduced in normface

        return output
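For reference, the single-device branch above is just a cosine-similarity matrix between L2-normalized embeddings and a normalized kernel, followed by a subtractive margin on the target class (CosFace-style). A minimal sketch with made-up shapes, using F.normalize in place of the example's l2_norm helper:

import torch
import torch.nn.functional as F

B, d, C, m, s = 4, 128, 10, 0.35, 64.0
emb = F.normalize(torch.randn(B, d), dim=1)       # unit-norm embeddings
kernel = F.normalize(torch.randn(d, C), dim=0)    # unit-norm class weights
cos_theta = torch.mm(emb, kernel).clamp(-1, 1)    # (B, C) cosine similarities
label = torch.randint(0, C, (B,))
one_hot = torch.zeros(B, C).scatter_(1, label.view(-1, 1), 1).bool()
output = torch.where(one_hot, cos_theta - m, cos_theta) * s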
Example 3
    def forward(self, inputs, mask=None, layer_cache=None, step=None):
        """
        Args:
            inputs (FloatTensor): ``(batch_size, input_len, model_dim)``

        Returns:
            (FloatTensor, FloatTensor):

            * gating_outputs ``(batch_size, input_len, model_dim)``
            * average_outputs average attention
                ``(batch_size, input_len, model_dim)``
        """

        batch_size = inputs.size(0)
        inputs_len = inputs.size(1)

        device = inputs.device
        average_outputs = self.cumulative_average(
          inputs, self.cumulative_average_mask(batch_size,
                                               inputs_len).to(device).float()
          if layer_cache is None else step, layer_cache=layer_cache)
        average_outputs = self.average_layer(average_outputs)
        gating_outputs = self.gating_layer(torch.cat((inputs,
                                                      average_outputs), -1))
        input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
        gating_outputs = torch.sigmoid(input_gate) * inputs + \
            torch.sigmoid(forget_gate) * average_outputs

        return gating_outputs, average_outputs
Example 4
    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id == None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1) 
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size())
        if self.device_id != None:
            one_hot = one_hot.cuda(self.device_id[0])
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # ------------- torch.where: out_i = x_i if condition_i else y_i -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output
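The phi computed above is simply cos(theta + m) expanded with the angle-addition identity cos(θ + m) = cosθ·cos m − sinθ·sin m (an additive angular margin). A quick numeric check with assumed values:

import math
import torch

m = 0.5
theta = torch.tensor([0.3, 1.2, 2.0])
cosine = torch.cos(theta)
sine = torch.sqrt(1.0 - cosine ** 2)              # equals sin(theta) for theta in [0, pi]
phi = cosine * math.cos(m) - sine * math.sin(m)
print(torch.allclose(phi, torch.cos(theta + m)))  # True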
Example 5
 def forward(self, x):
     if self.device_id == None:
         out = F.linear(x, self.weight, self.bias)
     else:
         sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
         sub_biases = torch.chunk(self.bias, len(self.device_id), dim=0)
         temp_x = x.cuda(self.device_id[0])
         weight = sub_weights[0].cuda(self.device_id[0])
         bias = sub_biases[0].cuda(self.device_id[0])
         out = F.linear(temp_x, weight, bias)
         for i in range(1, len(self.device_id)):
             temp_x = x.cuda(self.device_id[i])
             weight = sub_weights[i].cuda(self.device_id[i])
             bias = sub_biases[i].cuda(self.device_id[i])
             out = torch.cat((out, F.linear(temp_x, weight, bias).cuda(self.device_id[0])), dim=1)
     return out
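The pattern above is a simple form of model parallelism: chunking the weight (and bias) along dim 0 splits the output features across devices, and concatenating the partial F.linear results along dim 1 reassembles the full output. A CPU-only sketch of the equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(3, 8)
w, b = torch.randn(10, 8), torch.randn(10)
full = F.linear(x, w, b)
parts = [F.linear(x, wi, bi)
         for wi, bi in zip(torch.chunk(w, 2, dim=0), torch.chunk(b, 2, dim=0))]
print(torch.allclose(full, torch.cat(parts, dim=1)))  # True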
Example 6
 def mio_module(self, each_mmbox, len_conf):
     # split the confidence map into single channels, then max-out the first
     # three (background) channels and keep the foreground channel
     chunk = torch.chunk(each_mmbox, each_mmbox.shape[1], 1)
     bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
     cls = torch.cat([bmax, chunk[3]], dim=1) if len_conf == 0 else torch.cat([chunk[3], bmax], dim=1)
     if len(chunk)==6:
         cls = torch.cat([cls, chunk[4], chunk[5]], dim=1) 
     elif len(chunk)==8:
         cls = torch.cat([cls, chunk[4], chunk[5], chunk[6], chunk[7]], dim=1) 
     return cls 
Example 7
File: utils.py Project: phonx/MUNIT
def vgg_preprocess(batch):
    tensortype = type(batch.data)
    (r, g, b) = torch.chunk(batch, 3, dim = 1)
    batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255]
    mean = tensortype(batch.data.size())
    mean[:, 0, :, :] = 103.939
    mean[:, 1, :, :] = 116.779
    mean[:, 2, :, :] = 123.680
    batch = batch.sub(Variable(mean)) # subtract mean
    return batch
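The chunk/cat pair above is the usual channel reorder from RGB to BGR (the Caffe-VGG convention): split along the channel dimension and concatenate in reverse order. A tiny self-contained check:

import torch

rgb = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4)
r, g, b = torch.chunk(rgb, 3, dim=1)
bgr = torch.cat((b, g, r), dim=1)
print(torch.equal(bgr[:, 0], rgb[:, 2]))  # True: blue channel now comes first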
Example 8
    def forward(self, inputs):
        '''inputs: batch_size, num_tokens, 50
        return-activations: batch_size, num_tokens + 2, output_dim
        return-mask: batch_size, num_tokens + 2'''
        token_embedding = self._token_embedder(inputs)
        mask = token_embedding["mask"]
        type_representation = token_embedding["token_embedding"]
        lstm_outputs = self._elmo_lstm(type_representation, mask)

        output_tensors = [
            torch.cat([type_representation, type_representation], dim=-1) * mask.float().unsqueeze(-1)
        ]
        for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
            output_tensors.append(layer_activations.squeeze(0))
        return {"activations": output_tensors, "mask": mask}
Example 9
    def forward(self, og_x, a=None, cond_blocks=None):
        x = self.conv_input(self.nonlinearity(og_x))
        if a is not None : 
            x += self.nin_skip(self.nonlinearity(a))
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.conv_out(x)

        if cond_blocks is not None:
            conditioning_block = cond_blocks[(x.size(2), x.size(3))]

            x += conditioning_block

        a, b = torch.chunk(x, 2, dim=1)
        c3 = a * F.sigmoid(b)
        return og_x + c3
Example 10
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """
        Parameters
        ----------
        inputs: ``torch.autograd.Variable``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.

        Returns
        -------
        Dict with keys:

        ``'activations'``: ``List[torch.autograd.Variable]``
            A list of activations at each layer of the network, each of shape
            ``(batch_size, timesteps + 2, embedding_dim)``
        ``'mask'``:  ``torch.autograd.Variable``
            Shape ``(batch_size, timesteps + 2)`` long tensor with sequence mask.

        Note that the output tensors all include additional special begin and end of sequence
        markers.
        """
        token_embedding = self._token_embedder(inputs)
        type_representation = token_embedding['token_embedding']
        mask = token_embedding['mask']
        lstm_outputs = self._elmo_lstm(type_representation, mask)

        # Prepare the output.  The first layer is duplicated.
        output_tensors = [
                torch.cat([type_representation, type_representation], dim=-1)
        ]
        for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
            output_tensors.append(layer_activations.squeeze(0))

        return {
                'activations': output_tensors,
                'mask': mask,
        }
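The loop above splits the stacked per-layer LSTM outputs, of shape (num_layers, batch_size, timesteps, dim), into one tensor per layer; chunk plus squeeze(0) is equivalent to torch.unbind along dim 0. A shape-only sketch:

import torch

lstm_outputs = torch.randn(3, 2, 7, 16)     # (num_layers, batch, timesteps, dim)
layers = [t.squeeze(0) for t in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0)]
print(len(layers), layers[0].shape)         # 3 torch.Size([2, 7, 16])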
Example 11
    def forward(self, input, label):
        # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
        self.iter += 1
        self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id == None:
            cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cos_theta = torch.cat((cos_theta, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)

        cos_theta = cos_theta.clamp(-1, 1)
        cos_m_theta = self.mlambda[self.m](cos_theta)
        theta = cos_theta.data.acos()
        k = (self.m * theta / 3.14159265).floor()
        phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
        NormOfFeature = torch.norm(input, 2, 1)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cos_theta.size())
        if self.device_id != None:
            one_hot = one_hot.cuda(self.device_id[0])
        one_hot.scatter_(1, label.view(-1, 1), 1)

        # --------------------------- Calculate output ---------------------------
        output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
        output *= NormOfFeature.view(-1, 1)

        return output
Example 12
    def chunk(self, n_chunks, dim=0):
        """Splits this tensor into a tuple of tensors.

        See :func:`torch.chunk`.
        """
        return torch.chunk(self, n_chunks, dim)
Example 13
    def chunk(self, n_chunks, dim=0):
        r"""Splits this tensor into a certain number of tensor chunks.

        See :func:`torch.chunk`.
        """
        return torch.chunk(self, n_chunks, dim)
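As a reminder of the behavior these wrappers delegate to: torch.chunk splits along the given dimension into at most n_chunks pieces, and when the size is not evenly divisible the last chunk is smaller.

import torch

t = torch.arange(7)
print([c.tolist() for c in torch.chunk(t, 3)])  # [[0, 1, 2], [3, 4, 5], [6]]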
Example 14
    def forward(self, x, actions, geco, global_step):
        """
        x: [batch_size, T, C, H, W]
        """
        x_means, masks, nll_t, kl_t, prior_zs, posterior_zs, lambdas, inference_steps = [], [], [], [], [], [], [], []
        rollout_nll = []
        undiscounted_nll = []
        dynamics_dist = []

        total_losses = 0.

        current_recurrent_states = {
            'n_step_dynamics': {
                'h': self.init_recurrent_states['n_step_dynamics']['h'].to(x.device).repeat(1, self.batch_size, 1)
                }
        }

        for step_index, inference_step in enumerate(self.time_inference_schedule):
            
            current_recurrent_states['inference_lambda'] = {
                'h': torch.zeros(1, self.batch_size * self.K, self.lstm_dim).to(x.device),
                'c': torch.zeros(1, self.batch_size * self.K, self.lstm_dim).to(x.device)}
            

            if inference_step == "I":
                x_loc, mask, nll, kl, posterior_zs, lambdas, total_loss, current_recurrent_states, i, dyn_dist = self.inference_step(
                        x[:,step_index], geco, step_index, global_step, posterior_zs, lambdas, current_recurrent_states, actions)
                x_means += [x_loc]
                masks += [mask]
                nll_t += [nll]
                kl_t += [kl]
                inference_steps += [i]
                dynamics_dist += dyn_dist
                total_losses += total_loss

            # Dynamics
            elif inference_step == "D":
                x_t = x[:, step_index]

                x_loc, mask, nll, posterior_zs, current_recurrent_states, loss, dyn_dist = \
                        self.dynamics_step(x_t, geco, step_index, global_step, posterior_zs, current_recurrent_states, actions)
                x_means += [x_loc]
                masks += [mask]
                nll_t += [nll]
                total_losses += loss
                dynamics_dist += dyn_dist
            
            # Dynamics BMS
            elif inference_step == "BMS":
                x_t = x[:, step_index]

                x_loc, mask, nll, posterior_zs, current_recurrent_states, loss, dyn_dist = \
                        self.dynamics_bms_step(x_t, geco, step_index, global_step, posterior_zs, current_recurrent_states, actions)
                x_means += [x_loc]
                masks += [mask]
                nll_t += [nll]
                total_losses += loss
                dynamics_dist += dyn_dist

            # Random rollout
            elif inference_step == "R":
                with torch.no_grad():
                    if step_index == self.context_len:
                        z_prev = torch.stack(posterior_zs)
                        # copy z_prev to [context_len, n_samples*batch_size*K, z_size]
                        z_prev = z_prev.repeat(1, self.stochastic_samples, 1)
                        z_prev = list(torch.chunk(z_prev, self.context_len, dim=0))  # list of [n_samples*N*K, z_size]
                        posterior_zs = [_.squeeze(0) for _ in z_prev]
                        current_recurrent_states['n_step_dynamics']['h'] = current_recurrent_states['n_step_dynamics']['h'].to(x.device).repeat(1, self.stochastic_samples, 1)
                    x_loc, mask, posterior_zs, current_recurrent_states, dyn_dist = self.rollout(posterior_zs, step_index, current_recurrent_states, actions)
                    if x_means[-1].shape[1] != self.stochastic_samples:
                        x_means = [_.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1, 1) for _ in x_means]
                        masks = [m.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1, 1) for m in masks]
                    x_means = x_means + x_loc 
                    masks = masks + mask
                    dynamics_dist += dyn_dist
            elif inference_step == "U":  # O(TSNK) complexity
                if step_index == self.context_len:
                    z_prev = torch.stack(posterior_zs)
                    # copy z_prev to [context_len, n_samples*batch_size*K, z_size]
                    z_prev = z_prev.repeat(1, self.stochastic_samples, 1)
                    z_prev = list(torch.chunk(z_prev, self.context_len, dim=0))  # list of [n_samples*N*K, z_size]
                    posterior_zs = [_.squeeze(0) for _ in z_prev]
                    current_recurrent_states['n_step_dynamics']['h'] = current_recurrent_states['n_step_dynamics']['h'].to(x.device).repeat(1, self.stochastic_samples, 1)
                    x_means = [_.unsqueeze(0).repeat(self.stochastic_samples, 1, 1, 1, 1, 1).permute(1,0,2,3,4,5).contiguous() for _ in x_means]
                    masks = [_.unsqueeze(0).repeat(self.stochastic_samples, 1, 1, 1, 1, 1).permute(1,0,2,3,4,5).contiguous() for _ in masks]
                
                x_t = x[:, step_index].repeat(self.stochastic_samples, 1, 1, 1)
                x_loc, mask, posterior_zs, current_recurrent_states, nll, nll_disc, _ = self.rollout(posterior_zs, step_index, current_recurrent_states, actions, x_t, True)
                x_means = x_means + x_loc 
                masks = masks + mask
                
                rollout_nll += [nll]
                #undiscounted_nll += [nll]
                
        if self.dynamics_uncertainty:
            _, _, C, H, W = x.shape
            loss, best_indices, best_nll = self.uncertainty_loss(rollout_nll, geco, global_step)
            total_losses += torch.mean(loss)  # average over batch size (summed over rollout steps)
            best_x_means, best_masks = [], []
            batch_ids = torch.arange(self.batch_size).to(x.device)
            best_batch_indices = (self.stochastic_samples * batch_ids) + best_indices

            for i in range(len(x_means)):
                x_m = x_means[i].view(self.stochastic_samples, self.batch_size, self.K, C, H, W)
                x_m = x_m.view(-1, self.K, C, H, W)
                
                masks_m = masks[i].view(self.stochastic_samples, self.batch_size, self.K, 1, H, W)
                masks_m = masks_m.view(-1, self.K, 1, H, W)

                best_x_means += [x_m[best_batch_indices]]
                best_masks += [masks_m[best_batch_indices]]

            x_means = best_x_means
            masks = best_masks
            nll_to_return = torch.sum(torch.mean(torch.stack(nll_t), dim=1)) + torch.mean(best_nll)
        else:
            nll_to_return = torch.sum(torch.mean(torch.stack(nll_t), dim=1))

        outs = {
            'total_loss': total_losses,
            'nll': nll_to_return,
            'kl': torch.sum(torch.stack(kl_t)),
            'x_means': x_means,
            'masks': masks,
            'posterior_zs': posterior_zs,
            'inference_steps': torch.mean(torch.Tensor(inference_steps).to(x.device)),
            'dynamics': dynamics_dist,
            'lambdas': lambdas
        }
        return outs
Example 15
	def forward(self, input, indices):
		asize = list(indices.size()) 
		isize = list(input.size())
		ksize = self.ksize
		kstride = self.kstride
		x = list(torch.chunk(input, chunks=isize[0], dim=0))  # one tensor per sample; np.array(..., dtype=torch.Tensor) is not a valid numpy dtype
Example 16
 def vgg_deprocess(tensor):
     bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0
     (b, g, r) = torch.chunk(bgr, 3, dim=0)
     rgb = torch.cat((r, g, b), 0)
     return rgb
Example 17
    def forward(self, lr_feature, hr_feature):

        #self.fuse.weight.data=torch.abs(self.fuse.weight.data)
        with torch.no_grad():
            scale=hr_feature.shape[-1]//lr_feature.shape[-1]
            if scale%2==0:
                x=torch.arange(-scale//2,scale//2+1).float()
                x=torch.cat([x[:x.shape[0]//2],x[x.shape[0]//2+1:]]).unsqueeze(0)
                distance_matrix=x.expand(scale,scale).unsqueeze(0)
                distance_matrix=torch.cat([distance_matrix,distance_matrix.transpose(2,1)],0)
                distance_matrix=torch.cat([distance_matrix,torch.sqrt(torch.pow(distance_matrix[0],2)+torch.pow(distance_matrix[1],2)).unsqueeze(0)],0).unsqueeze(0)
                padding1=torch.zeros(hr_feature.shape[0],1,hr_feature.shape[2],scale).float().cuda()
                padding2=torch.zeros(hr_feature.shape[0],1,scale,hr_feature.shape[3]).float().cuda()
            else:
                exit()
            #center
            distance_matrix=distance_matrix.repeat(hr_feature.shape[0],1,hr_feature.shape[-2]//scale,hr_feature.shape[-1]//scale).float().cuda()
            distance_matrix2=distance_matrix+0
            distance_matrix1=distance_matrix+0
            distance_matrix3=distance_matrix+0
            distance_matrix4=distance_matrix+0
        lr_feature=lr_feature.unsqueeze(-1).expand(lr_feature.shape[0],lr_feature.shape[1],lr_feature.shape[2],lr_feature.shape[3],scale) \
                                     .contiguous().view(lr_feature.shape[0],lr_feature.shape[1],lr_feature.shape[2],lr_feature.shape[3]*scale) \
                                  .unsqueeze(-2).expand(lr_feature.shape[0],lr_feature.shape[1],lr_feature.shape[2],scale,lr_feature.shape[3]*scale) \
                                  .contiguous().view(lr_feature.shape[0],lr_feature.shape[1],lr_feature.shape[2]*scale,lr_feature.shape[3]*scale)
        #128
        representation=torch.cat([lr_feature,hr_feature,lr_feature*hr_feature,torch.pow(lr_feature-hr_feature,2)],1)
        weights1=self.similarity1(representation)
        weights2=self.similarity2(distance_matrix)
        mapping=self.fuse(torch.cat([weights1,weights2],1))
        #right
        x=torch.arange(1,scale+1).float()
        x=x.expand(scale,scale).unsqueeze(0)
        x=x.repeat(hr_feature.shape[0],hr_feature.shape[-2]//scale,hr_feature.shape[-1]//scale).float().cuda()
        
        distance_matrix1[:,0,:,:]=scale-x+1
        distance_matrix1[:,2,:,:]=torch.sqrt(torch.pow(distance_matrix1[:,0,:,:],2)+torch.pow(distance_matrix1[:,1,:,:],2)).unsqueeze(0)
        representation_r=torch.cat([lr_feature[:,:,:,scale:],hr_feature[:,:,:,:-scale],lr_feature[:,:,:,scale:]*hr_feature[:,:,:,:-scale], \
                       torch.pow(lr_feature[:,:,:,scale:]-hr_feature[:,:,:,:-scale],2)],1)
        weights1_r=self.similarity1(representation_r)
        weights2_r=self.similarity2(distance_matrix1[:,:,:,scale:])
        #print(padding.shape)
        mapping_r=torch.cat([self.fuse(torch.cat([weights1_r,weights2_r],1)),padding1],-1)
        #left
        distance_matrix2[:,0,:,:]=x
        distance_matrix2[:,2,:,:]=torch.sqrt(torch.pow(distance_matrix2[:,0,:,:],2)+torch.pow(distance_matrix2[:,1,:,:],2)).unsqueeze(0)
        representation_l=torch.cat([lr_feature[:,:,:,:-scale],hr_feature[:,:,:,scale:],lr_feature[:,:,:,:-scale]*hr_feature[:,:,:,scale:], \
                       torch.pow(lr_feature[:,:,:,:-scale]-hr_feature[:,:,:,scale:],2)],1)
        weights1_l=self.similarity1(representation_l)
        weights2_l=self.similarity2(distance_matrix2[:,:,:,:-scale])
        mapping_l=torch.cat([padding1,self.fuse(torch.cat([weights1_l,weights2_l],1))],-1)
        #top
        x=torch.arange(1,scale+1).float()
        x=x.expand(scale,scale).unsqueeze(0).transpose(2,1)
        x=x.repeat(hr_feature.shape[0],hr_feature.shape[-2]//scale,hr_feature.shape[-1]//scale).float().cuda()
        
        distance_matrix3[:,1,:,:]=(scale-x+1)
        distance_matrix3[:,2,:,:]=torch.sqrt(torch.pow(distance_matrix3[:,0,:,:],2)+torch.pow(distance_matrix3[:,1,:,:],2)).unsqueeze(0)
        representation_t=torch.cat([lr_feature[:,:,:-scale,:],hr_feature[:,:,scale:,:],lr_feature[:,:,:-scale,:]*hr_feature[:,:,scale:,:], \
                       torch.pow(lr_feature[:,:,:-scale,:]-hr_feature[:,:,scale:,:],2)],1)
        weights1_t=self.similarity1(representation_t)
        weights2_t=self.similarity2(distance_matrix3[:,:,:-scale,:])
        mapping_t=torch.cat([padding2,self.fuse(torch.cat([weights1_t,weights2_t],1))],-2)
        #bottom
        
        distance_matrix4[:,1,:,:]=x
        distance_matrix4[:,2,:,:]=torch.sqrt(torch.pow(distance_matrix4[:,0,:,:],2)+torch.pow(distance_matrix4[:,1,:,:],2)).unsqueeze(0)
        representation_b=torch.cat([lr_feature[:,:,scale:,:],hr_feature[:,:,:-scale,:],lr_feature[:,:,scale:,:]*hr_feature[:,:,:-scale,:], \
                       torch.pow(lr_feature[:,:,scale:,:]-hr_feature[:,:,:-scale,:],2)],1)
        weights1_b=self.similarity1(representation_b)
        weights2_b=self.similarity2(distance_matrix4[:,:,scale:,:])
        mapping_b=torch.cat([self.fuse(torch.cat([weights1_b,weights2_b],1)),padding2],-2)

        mapping_all=torch.cat([mapping,mapping_r,mapping_l,mapping_t,mapping_b],dim=1)
        mapping_norm=F.softmax(mapping_all, dim=1)
        #return mapping,mapping_r,mapping_l,mapping_t,mapping_b
        return torch.chunk(mapping_norm*mapping_all,5,dim=1)
Example 18
def train():
    # Turn on training mode which enables dropout.
    if args.restart:
        global train_loss, best_val_loss, eval_start_time, log_start_time
    else:
        global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
    model.train()
    if args.batch_chunk > 1:
        mems = [tuple() for _ in range(args.batch_chunk)]
    else:
        mems = tuple()
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
    for batch, (data, target, seq_len) in enumerate(train_iter):
        model.zero_grad()
        if args.batch_chunk > 1:
            data_chunks = torch.chunk(data, args.batch_chunk, 1)
            target_chunks = torch.chunk(target, args.batch_chunk, 1)
            for i in range(args.batch_chunk):
                data_i = data_chunks[i].contiguous()
                target_i = target_chunks[i].contiguous()
                ret = para_model(data_i, target_i, *mems[i])
                loss, mems[i] = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss) / args.batch_chunk
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                train_loss += loss.float().item()
        else:
            ret = para_model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.float().mean().type_as(loss)
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if args.sample_softmax > 0:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if args.sample_softmax > 0:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step)
                    if args.sample_softmax > 0:
                        scheduler_sparse.step(train_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / args.log_interval
            elapsed = time.time() - log_start_time
            log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
                      '| ms/batch {:5.2f} | loss {:5.2f}'.format(
                epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
            else:
                log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
            logging(log_str)
            train_loss = 0
            log_start_time = time.time()

        if train_step % args.eval_interval == 0:
            val_loss = evaluate(va_iter)
            logging('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                train_step // args.eval_interval, train_step,
                (time.time() - eval_start_time), val_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
            logging(log_str)
            logging('-' * 100)
            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                if not args.debug:
                    with open(os.path.join(args.work_dir, 'model.pt'),
                              'wb') as f:
                        torch.save(model, f)
                    with open(os.path.join(args.work_dir, 'optimizer.pt'),
                              'wb') as f:
                        torch.save(optimizer.state_dict(), f)
                    with open(os.path.join(args.work_dir, 'scheduler.pt'),
                              'wb') as f:
                        torch.save(scheduler.state_dict(), f)
                    with open(os.path.join(args.work_dir, 'trainstep.pt'),
                              'wb') as f:
                        torch.save(train_step, f)
                best_val_loss = val_loss

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if args.sample_softmax > 0:
                    scheduler_sparse.step(val_loss)

            eval_start_time = time.time()

        if train_step == args.max_step:
            break
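The batch_chunk branch above is gradient accumulation: the batch is chunked along its batch dimension, each piece is run separately with its own memory state, and each partial loss is divided by the number of chunks so the accumulated gradients match a single full-batch step (for equally sized chunks). A stripped-down sketch of that pattern, with a toy loss standing in for the model:

import torch

n_chunks = 4
w = torch.randn(16, requires_grad=True)
data = torch.randn(32, 16)
for data_i in torch.chunk(data, n_chunks, dim=0):
    loss_i = (data_i @ w).pow(2).mean() / n_chunks  # toy loss, scaled per chunk
    loss_i.backward()                               # gradients accumulate in w.grad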
Example 19
 def forward(self, x):
     h = self.encoder(x)
     mu, log_var = torch.chunk(h, 2, dim=1)  # mean and log variance.
     z = self.reparametrize(mu, log_var)
     out = self.decoder(z)
     return out, mu, log_var
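The encoder output h is split in half along dim 1 into the mean and log-variance of the latent Gaussian. A minimal reparameterization step matching that split (assumed sizes, not the repo's reparametrize):

import torch

h = torch.randn(4, 2 * 8)                    # stand-in for self.encoder(x)
mu, log_var = torch.chunk(h, 2, dim=1)
z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
print(z.shape)                               # torch.Size([4, 8])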
Example 20
 def vgg_preprocess(tensor):
     (r, g, b) = torch.chunk(tensor, 3, dim=0)
     bgr = torch.cat((b, g, r), 0)
     out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr)
     return out
Example 21
 def vgg_deprocess(tensor):
     bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0
     (b, g, r) = torch.chunk(bgr, 3, dim=0)
     rgb = torch.cat((r, g, b), 0)
     return rgb
Example 22
def vgg_preprocess_var(var):
        (r, g, b) = torch.chunk(var, 3, dim=1)
        bgr = torch.cat((b, g, r), 1)
        out = bgr * 255 - torch.autograd.Variable(vgg_mean[None, ...]).type(var.type()).expand_as(bgr)
        return out
Example 23
    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]  # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn',
                          (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn',
                          (w_head_q, r_emb))  # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]  # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None],
                                        -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None],
                                        -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0),
                                              attn_vec.size(1),
                                              self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output
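The qkv_net projection above produces a single fused tensor of width 3 * d_model, and torch.chunk along the last dimension recovers the query, key and value tensors before they are reshaped into heads. A shape-only sketch:

import torch
from torch import nn

d_model, qlen, bsz = 16, 5, 2
qkv_net = nn.Linear(d_model, 3 * d_model)
w = torch.randn(qlen, bsz, d_model)
w_head_q, w_head_k, w_head_v = torch.chunk(qkv_net(w), 3, dim=-1)
print(w_head_q.shape)                        # torch.Size([5, 2, 16])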
Example 24
    def forward(self, input):

        # split left image and right image
        imgs = torch.chunk(input, 2, dim = 1)
        img_left = imgs[0]
        img_right = imgs[1]

        conv1_l = self.conv1(img_left)
        conv2_l = self.conv2(conv1_l)
        conv3a_l = self.conv3(conv2_l)

        conv1_r = self.conv1(img_right)
        conv2_r = self.conv2(conv1_r)
        conv3a_r = self.conv3(conv2_r)

        # Correlate corr3a_l and corr3a_r
        out_corr = self.corr(conv3a_l, conv3a_r)
        out_corr = self.corr_activation(out_corr)
        out_conv3a_redir = self.conv_redir(conv3a_l)
        in_conv3b = torch.cat((out_conv3a_redir, out_corr), 1)


        conv3b = self.conv3_1(in_conv3b)
        conv4a = self.conv4(conv3b)
        conv4b = self.conv4_1(conv4a)
        conv5a = self.conv5(conv4b)
        conv5b = self.conv5_1(conv5a)
        conv6a = self.conv6(conv5b)
        conv6b = self.conv6_1(conv6a)


        pr6 = self.pred_flow6(conv6b)
        upconv5 = self.upconv5(conv6b)
        upflow6 = self.upflow6to5(pr6)
        concat5 = torch.cat((upconv5, upflow6, conv5b), 1)
        iconv5 = self.iconv5(concat5)

        pr5 = self.pred_flow5(iconv5)
        upconv4 = self.upconv4(iconv5)
        upflow5 = self.upflow5to4(pr5)
        concat4 = torch.cat((upconv4, upflow5, conv4b), 1)
        iconv4 = self.iconv4(concat4)
        
        pr4 = self.pred_flow4(iconv4)
        upconv3 = self.upconv3(iconv4)
        upflow4 = self.upflow4to3(pr4)
        concat3 = torch.cat((upconv3, upflow4, conv3b), 1)
        iconv3 = self.iconv3(concat3)

        pr3 = self.pred_flow3(iconv3)
        upconv2 = self.upconv2(iconv3)
        upflow3 = self.upflow3to2(pr3)
        concat2 = torch.cat((upconv2, upflow3, conv2_l), 1)
        iconv2 = self.iconv2(concat2)

        pr2 = self.pred_flow2(iconv2)
        upconv1 = self.upconv1(iconv2)
        upflow2 = self.upflow2to1(pr2)
        concat1 = torch.cat((upconv1, upflow2, conv1_l), 1)
        iconv1 = self.iconv1(concat1)

        pr1 = self.pred_flow1(iconv1)
        upconv0 = self.upconv0(iconv1)
        upflow1 = self.upflow1to0(pr1)
        concat0 = torch.cat((upconv0, upflow1, img_left), 1)
        iconv0 = self.iconv0(concat0)

        # predict flow
        pr0 = self.pred_flow0(iconv0)

        # predict flow from dropout output
        # pr6 = self.pred_flow6(F.dropout2d(conv6b))
        # pr5 = self.pred_flow5(F.dropout2d(iconv5))
        # pr4 = self.pred_flow4(F.dropout2d(iconv4))
        # pr3 = self.pred_flow3(F.dropout2d(iconv3))
        # pr2 = self.pred_flow2(F.dropout2d(iconv2))
        # pr1 = self.pred_flow1(F.dropout2d(iconv1))
        # pr0 = self.pred_flow0(F.dropout2d(iconv0))

        # if self.training:
        #     # print("finish forwarding.")
        #     return pr0, pr1, pr2, pr3, pr4, pr5, pr6
        # else:
        #     return pr0

        # can be chosen outside
        return pr0, pr1, pr2, pr3, pr4, pr5, pr6
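The first chunk call splits a stereo pair stacked along the channel dimension (6 channels) back into the two 3-channel images. For reference:

import torch

pair = torch.randn(1, 6, 64, 64)             # left and right RGB stacked on dim 1
img_left, img_right = torch.chunk(pair, 2, dim=1)
print(img_left.shape, img_right.shape)       # torch.Size([1, 3, 64, 64]) twice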
Example 25
    def forward(self, x):

        N, C, T, V = x.size()
        if self.use_spatial_att:
            attention = self.atts
            if self.use_pes:
                y = self.pes(x)
            else:
                y = x
            if self.att_s:
                q, k = torch.chunk(self.in_nets(y).view(
                    N, 2 * self.num_subset, self.inter_channels, T, V),
                                   2,
                                   dim=1)  # nctv -> n num_subset c'tv
                attention = attention + self.tan(
                    torch.einsum('nsctu,nsctv->nsuv', [q, k]) /
                    (self.inter_channels * T)) * self.alphas
            if self.glo_reg_s:
                attention = attention + self.attention0s.repeat(N, 1, 1, 1)
            attention = self.drop(attention)
            y = torch.einsum('nctu,nsuv->nsctv', [x, attention]).contiguous() \
                .view(N, self.num_subset * self.in_channels, T, V)
            y = self.out_nets(y)  # nctv

            y = self.relu(self.downs1(x) + y)

            y = self.ff_nets(y)

            y = self.relu(self.downs2(x) + y)
        else:
            y = self.out_nets(x)
            y = self.relu(self.downs2(x) + y)

        # set_trace()
        # y_1 = self.out_nett_extend(y)
        # y_1 = self.relu(self.downt3(y) + y_1)
        # y = y_1

        forward_mask = self.backward_mask.transpose(-1, -2)
        backward_mask = self.backward_mask
        if self.use_temporal_att:
            attention = self.attt
            if self.use_pet:
                z = self.pet(y)
            else:
                z = y
            q_k_in = self.in_nett(z).view(N, 6 * self.num_subset,
                                          self.inter_channels, T, V)
            q_f, q_b, q_c, k_f, k_b, k_c = torch.chunk(q_k_in, 6, dim=1)
            attention_b = torch.einsum('nsctv,nscqv->nstq', [q_b, k_b]) / (
                self.inter_channels * V) * self.alphat_0
            attention_f = torch.einsum('nsctv,nscqv->nstq', [q_f, k_f]) / (
                self.inter_channels * V) * self.alphat_1
            attention_c = torch.einsum('nsctv,nscqv->nstq', [q_c, k_c]) / (
                self.inter_channels * V) * self.alphat_2
            attention_b = torch.einsum('nstq,tq->nstq',
                                       [attention_b, backward_mask])
            attention_f = torch.einsum('nstq,tq->nstq',
                                       [attention_f, forward_mask])
            attention_b = self.drop(attention_b)
            attention_f = self.drop(attention_f)
            attention_c = self.drop(attention_c)
            z_f = torch.einsum('nctv,nstq->nscqv', [y, attention_f]).contiguous() \
                .view(N, self.num_subset * self.out_channels, T, V)
            z_b = torch.einsum('nctv,nstq->nscqv', [y, attention_b]).contiguous() \
                .view(N, self.num_subset * self.out_channels, T, V)
            z_c = torch.einsum('nctv,nstq->nscqv', [y, attention_c]).contiguous() \
                .view(N, self.num_subset * self.out_channels, T, V)
            z = torch.cat([z_f, z_b, z_c], dim=-3)
            z = self.out_nett(z)  # nctv

            z = self.relu(self.downt1(y) + z)

            z = self.ff_nett(z)

            z = self.relu(self.downt2(y) + z)
        else:
            z = self.out_nett(y)
            z = self.relu(self.downt2(y) + z)

        # set_trace()
        z_1 = self.out_nett_extend(z)
        z_1 = self.relu(self.downt3(z) + z_1)
        z = z_1
        return z
Example 26
 def angle_axis_to_rotation_matrix_taylor(angle_axis):
     rx, ry, rz = torch.chunk(angle_axis, 3, dim=-1)
     ones = torch.ones_like(rx)
     R = torch.cat([ones, -rz, ry, rz, ones, -rx, -ry, rx, ones],
                   dim=1).view(-1, 3, 3)
     return R
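The helper above builds the first-order (small-angle) approximation of a rotation matrix, R ≈ I + [ω]×, where [ω]× is the skew-symmetric cross-product matrix of the angle-axis vector. A self-contained check, with an assumed tiny rotation vector, that the cat layout reproduces I plus that skew matrix:

import torch

w = torch.tensor([[0.01, -0.02, 0.03]])
rx, ry, rz = torch.chunk(w, 3, dim=-1)
ones = torch.ones_like(rx)
R = torch.cat([ones, -rz, ry, rz, ones, -rx, -ry, rx, ones], dim=1).view(-1, 3, 3)
skew = torch.tensor([[[0.0, -0.03, -0.02],
                      [0.03,  0.0, -0.01],
                      [0.02,  0.01,  0.0]]])
print(torch.allclose(R, torch.eye(3) + skew))  # True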
Example 27
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                word_inputs: torch.Tensor = None) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``, required.
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.
        word_inputs : ``torch.Tensor``, required.
            If you passed a cached vocab, you can in addition pass a tensor of shape ``(batch_size, timesteps)``,
            which represent word ids which have been pre-cached.

        Returns
        -------
        Dict with keys:

        ``'activations'``: ``List[torch.Tensor]``
            A list of activations at each layer of the network, each of shape
            ``(batch_size, timesteps + 2, embedding_dim)``
        ``'mask'``:  ``torch.Tensor``
            Shape ``(batch_size, timesteps + 2)`` long tensor with sequence mask.

        Note that the output tensors all include additional special begin and end of sequence
        markers.
        """
        if self._word_embedding is not None and word_inputs is not None:
            try:
                mask_without_bos_eos = (word_inputs > 0).long()
                # The character cnn part is cached - just look it up.
                embedded_inputs = self._word_embedding(word_inputs) # type: ignore
                # shape (batch_size, timesteps + 2, embedding_dim)
                type_representation, mask = add_sentence_boundary_token_ids(
                        embedded_inputs,
                        mask_without_bos_eos,
                        self._bos_embedding,
                        self._eos_embedding
                )
            except RuntimeError:
                # Back off to running the character convolutions,
                # as we might not have the words in the cache.
                token_embedding = self._token_embedder(inputs)
                mask = token_embedding['mask']
                type_representation = token_embedding['token_embedding']
        else:
            token_embedding = self._token_embedder(inputs)
            mask = token_embedding['mask']
            type_representation = token_embedding['token_embedding']
        lstm_outputs = self._elmo_lstm(type_representation, mask)

        # Prepare the output.  The first layer is duplicated.
        # Because of minor differences in how masking is applied depending
        # on whether the char cnn layers are cached, we'll be defensive and
        # multiply by the mask here. It's not strictly necessary, as the
        # mask passed on is correct, but the values in the padded areas
        # of the char cnn representations can change.
        output_tensors = [
                torch.cat([type_representation, type_representation], dim=-1) * mask.float().unsqueeze(-1)
        ]
        for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
            output_tensors.append(layer_activations.squeeze(0))

        return {
                'activations': output_tensors,
                'mask': mask,
        }
Example 28
    def forward(self, x):
        h = F.relu(self.conv1_1(x))
        h = F.relu(self.conv1_2(h))
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        f3_3 = h
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        f4_3 = h
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv5_1(h))
        h = F.relu(self.conv5_2(h))
        h = F.relu(self.conv5_3(h))
        f5_3 = h
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.fc6(h))
        h = F.relu(self.fc7(h))
        ffc7 = h
        h = F.relu(self.conv6_1(h))
        h = F.relu(self.conv6_2(h))
        f6_2 = h
        h = F.relu(self.conv7_1(h))
        h = F.relu(self.conv7_2(h))
        f7_2 = h

        f3_3 = self.conv3_3_norm(f3_3)
        f4_3 = self.conv4_3_norm(f4_3)
        f5_3 = self.conv5_3_norm(f5_3)

        cls1 = self.conv3_3_norm_mbox_conf(f3_3)
        reg1 = self.conv3_3_norm_mbox_loc(f3_3)
        cls2 = self.conv4_3_norm_mbox_conf(f4_3)
        reg2 = self.conv4_3_norm_mbox_loc(f4_3)
        cls3 = self.conv5_3_norm_mbox_conf(f5_3)
        reg3 = self.conv5_3_norm_mbox_loc(f5_3)
        cls4 = self.fc7_mbox_conf(ffc7)
        reg4 = self.fc7_mbox_loc(ffc7)
        cls5 = self.conv6_2_mbox_conf(f6_2)
        reg5 = self.conv6_2_mbox_loc(f6_2)
        cls6 = self.conv7_2_mbox_conf(f7_2)
        reg6 = self.conv7_2_mbox_loc(f7_2)

        # max-out background label
        chunk = torch.chunk(cls1, 4, 1)
        bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
        cls1 = torch.cat([bmax, chunk[3]], dim=1)

        return [
            cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6,
            reg6
        ]
Example 29
def vgg_preprocess_var(var):
        (r, g, b) = torch.chunk(var, 3, dim=1)
        bgr = torch.cat((b, g, r), 1)
        out = bgr * 255 - torch.autograd.Variable(vgg_mean[None, ...]).type(var.type()).expand_as(bgr)
        return out
Example 30
 def forward(self, input):
     splits = torch.chunk(input, 2, dim=self.dim)
     return splits[0] if self.first_half else splits[1]
Example 31
def validate(nets, loss_terms, opts, dataloader, epoch, network_type, devices=(cuda0,cuda1), batch_n="whole_test_show"):
    """
    validate phase
    """
    netD, netG = nets["netD"], nets["netG"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms['ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms["GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(),"p_loss":AverageMeter(), "s_loss":AverageMeter(), "r_loss":AverageMeter(), "whole_loss":AverageMeter(), "d_loss":AverageMeter()}

    netG.train()
    netD.train()
    end = time.time()
    val_save_dir = os.path.join(result_dir, "val_{}_{}".format(epoch, batch_n if isinstance(batch_n, str) else batch_n+1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_comp_dir = os.path.join(val_save_dir, "comp")
    for size in SIZES_TAGS:
        if not os.path.exists(os.path.join(val_save_real_dir, size)):
            os.makedirs(os.path.join(val_save_real_dir, size))
        if not os.path.exists(os.path.join(val_save_gen_dir, size)):
            os.makedirs(os.path.join(val_save_gen_dir, size))
        if not os.path.exists(os.path.join(val_save_comp_dir, size)):
            os.makedirs(os.path.join(val_save_comp_dir, size))
    info = {}
    t = 0
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        pre_complete_imgs = (pre_imgs / 127.5 - 1)

        for s_i, size in enumerate(TRAIN_SIZES):

            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if imgs.size(1) != 3:
                print(t, imgs.size() )
            pre_inter_imgs = F.interpolate(pre_complete_imgs, size)

            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(device0), masks.to(device0), pre_complete_imgs.to(device0), pre_inter_imgs.to(device0)
            #masks = (masks > 0).type(torch.FloatTensor)

            #imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward
            if network_type == 'l2h_unet':
                recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs, size)
            elif network_type == 'l2h_gated':
                recon_imgs = netG(imgs, masks, pre_inter_imgs)
            elif network_type == 'sa_gated':
                recon_imgs, _ = netG(imgs, masks)
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)


            pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat([recon_imgs, masks, torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)


            g_loss = GANLoss(pred_neg)

            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss #g_loss + r_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(s_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)


            # Logger logging


            #if t < config.STATIC_VIEW_SIZE:
            print(i, size)
            real_img = img2photo(imgs)
            gen_img = img2photo(recon_imgs)
            comp_img = img2photo(complete_imgs)

            real_img = Image.fromarray(real_img[0].astype(np.uint8))
            gen_img = Image.fromarray(gen_img[0].astype(np.uint8))
            comp_img = Image.fromarray(comp_img[0].astype(np.uint8))
            real_img.save(os.path.join(val_save_real_dir, SIZES_TAGS[s_i], "{}.png".format(i)))
            gen_img.save(os.path.join(val_save_gen_dir, SIZES_TAGS[s_i], "{}.png".format(i)))
            comp_img.save(os.path.join(val_save_comp_dir, SIZES_TAGS[s_i], "{}.png".format(i)))

            end = time.time()
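The discriminator call in the middle of the loop runs once on real and fake images concatenated along the batch dimension, and torch.chunk then splits the predictions back into the positive and negative halves. The core of that pattern:

import torch

pred_pos_neg = torch.randn(8, 1)              # stand-in for netD(pos_neg_imgs)
pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
print(pred_pos.shape, pred_neg.shape)         # torch.Size([4, 1]) twice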
Example 32
def preprocess_batch(batch):
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch
Example 33
    def episodic_training(self, train_results, tail):

        episode = self.replay_buffer.get_tail(tail)

        sl = episode['s']
        sl = list(torch.chunk(sl, int((len(sl) / self.batch) + 1)))

        s, r, t, e = [episode[k] for k in ['s', 'r', 't', 'e']]

        v = []
        for s in sl:
            v.append(self.v_net(s))

        v.append(torch.zeros_like(v[0][:1]))
        v = torch.cat(v).detach()
        v1, v2 = v[:-1], v[1:]

        adv, v_target = generalized_advantage_estimation(
            r,
            t,
            e,
            v1,
            v2,
            self.gamma,
            self.lambda_gae,
            norm=self.norm_rewards)

        episode['adv'] = adv
        episode['v_target'] = v_target

        if self.batch_ppo:
            n = self.steps_per_episode * self.batch
            indices = torch.randperm(tail * max(1, n // tail + 1)) % tail
            indices = indices[:n].unsqueeze(1).view(self.steps_per_episode,
                                                    self.batch)

            samples = {k: v[indices] for k, v in episode.items()}
            iterator_pi = iter_dict(samples)
            iterator_v = iter_dict(samples)
        else:
            iterator_pi = itertools.repeat(episode, self.steps_per_episode)
            iterator_v = itertools.repeat(episode, self.steps_per_episode)

        for i, sample in enumerate(iterator_pi):
            s, a, r, t, stag, adv, v_target, log_pi_old = [
                sample[k] for k in
                ['s', 'a', 'r', 't', 'stag', 'adv', 'v_target', 'logp']
            ]
            self.pi_net(s)
            log_pi = self.pi_net.log_prob(a)
            ratio = torch.exp((log_pi - log_pi_old).sum(dim=1))

            clip_adv = torch.clamp(ratio, 1 - self.eps_ppo,
                                   1 + self.eps_ppo) * adv
            loss_p = -(torch.min(ratio * adv, clip_adv)).mean()

            approx_kl = -float((log_pi - log_pi_old).sum(dim=1).mean())
            ent = float(self.pi_net.entropy().sum(dim=1).mean())

            if approx_kl > self.target_kl:
                train_results['scalar']['pi_opt_rounds'].append(i)
                break

            clipped = ratio.gt(1 + self.eps_ppo) | ratio.lt(1 - self.eps_ppo)
            clipfrac = float(
                torch.as_tensor(clipped, dtype=torch.float32).mean())

            self.optimizer_p.zero_grad()
            loss_p.backward()
            if self.clip_p:
                nn.utils.clip_grad_norm(self.pi_net.parameters(), self.clip_p)
            self.optimizer_p.step()

            train_results['scalar']['loss_p'].append(float(loss_p))
            train_results['scalar']['approx_kl'].append(approx_kl)
            train_results['scalar']['ent'].append(ent)
            train_results['scalar']['clipfrac'].append(clipfrac)

        for sample in iterator_v:
            s, a, r, t, stag, adv, v_target, log_pi_old = [
                sample[k] for k in
                ['s', 'a', 'r', 't', 'stag', 'adv', 'v_target', 'logp']
            ]

            v = self.v_net(s)
            loss_v = F.mse_loss(v, v_target, reduction='mean')

            self.optimizer_v.zero_grad()
            loss_v.backward()
            if self.clip_q:
                nn.utils.clip_grad_norm(self.v_net.parameters(), self.clip_q)
            self.optimizer_v.step()

            train_results['scalar']['loss_v'].append(float(loss_v))

        return train_results
Example 34
def tensor_save_bgrimage(tensor):
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))
    return tensor_save_rgbimage(tensor)
Example 35
    def __init__(self, image_size, noise_size=100, num_label=10, output_params=False):
        super(Encoder, self).__init__()

        self.noise_size = noise_size
        self.image_size = image_size
        self.num_label  = num_label

        self.core_net = nn.Sequential(
            nn.Conv2d(  4, 128, 5, 2, 2, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 256, 5, 2, 2, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 512, 5, 2, 2, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
            Expression(lambda tensor: tensor.view(tensor.size(0), 512 * 4 * 4)),
        )
        
        if output_params:
            self.core_net.add_module(str(len(self.core_net._modules)), WN_Linear(4 * 4 * 512, self.noise_size*2, train_scale=True, init_stdv=0.1))
            self.core_net.add_module(str(len(self.core_net._modules)), Expression(lambda x: torch.chunk(x, 2, 1)))
        else:
            self.core_net.add_module(str(len(self.core_net._modules)), WN_Linear(4 * 4 * 512, self.noise_size, train_scale=True, init_stdv=0.1))
Example 36
        def conv2d(
            input,
            weight,
            bias=None,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            padding_mode="zeros",
        ):
            """
            Overloads torch.conv2d to be able to use MPC on convolutional networks.
            The idea is to build new tensors from input and weight to compute a matrix multiplication
            equivalent to the convolution.

            Args:
                input: input image
                weight: convolution kernels
                bias: optional additive bias
                stride: stride of the convolution kernels
                padding: implicit padding added on both sides of the input
                dilation: spacing between kernel elements
                groups: number of blocked connections from input channels to output channels
                padding_mode: type of padding; should be either 'zeros' or 'circular', though 'reflect' and 'replicate' are also accepted
            Returns:
                the result of the convolution as an AdditiveSharingTensor
            """
            assert len(input.shape) == 4
            assert len(weight.shape) == 4

            # Change to tuple if not one
            stride = torch.nn.modules.utils._pair(stride)
            padding = torch.nn.modules.utils._pair(padding)
            dilation = torch.nn.modules.utils._pair(dilation)

            # Extract a few useful values
            batch_size, nb_channels_in, nb_rows_in, nb_cols_in = input.shape
            nb_channels_out, nb_channels_in_, nb_rows_kernel, nb_cols_kernel = weight.shape

            if bias is not None:
                assert len(bias) == nb_channels_out

            # Check if inputs are coherent
            assert nb_channels_in == nb_channels_in_ * groups
            assert nb_channels_in % groups == 0
            assert nb_channels_out % groups == 0

            # Compute output shape
            nb_rows_out = int(((nb_rows_in + 2 * padding[0] - dilation[0] *
                                (nb_rows_kernel - 1) - 1) / stride[0]) + 1)
            nb_cols_out = int(((nb_cols_in + 2 * padding[1] - dilation[1] *
                                (nb_cols_kernel - 1) - 1) / stride[1]) + 1)

            # Apply padding to the input
            if padding != (0, 0):
                padding_mode = "constant" if padding_mode == "zeros" else padding_mode
                input = torch.nn.functional.pad(
                    input, (padding[1], padding[1], padding[0], padding[0]),
                    padding_mode)
                # Update shape after padding
                nb_rows_in += 2 * padding[0]
                nb_cols_in += 2 * padding[1]

            # We want the relative positions of the input values used by one filter convolution,
            # i.e. the positions of the values used for the top-left convolution.
            pattern_ind = []
            for ch in range(nb_channels_in):
                for r in range(nb_rows_kernel):
                    for c in range(nb_cols_kernel):
                        pixel = r * nb_cols_in * dilation[0] + c * dilation[1]
                        pattern_ind.append(pixel +
                                           ch * nb_rows_in * nb_cols_in)

            # The image tensor is reshaped for the matrix multiplication:
            # on each row of the new tensor will be the input values used for each filter convolution
            # We will get a matrix [[in values to compute out value 0],
            #                       [in values to compute out value 1],
            #                       ...
            #                       [in values to compute out value nb_rows_out*nb_cols_out]]
            im_flat = input.view(batch_size, -1)
            im_reshaped = []
            for cur_row_out in range(nb_rows_out):
                for cur_col_out in range(nb_cols_out):
                    # For each new output value, we just need to shift the receptive field
                    offset = cur_row_out * stride[
                        0] * nb_cols_in + cur_col_out * stride[1]
                    tmp = [ind + offset for ind in pattern_ind]
                    im_reshaped.append(im_flat[:, tmp].wrap())
            im_reshaped = torch.stack(im_reshaped).permute(1, 0, 2)

            # The convolution kernels are also reshaped for the matrix multiplication
            # We will get a matrix [[weights for out channel 0],
            #                       [weights for out channel 1],
            #                       ...
            #                       [weights for out channel nb_channels_out]].TRANSPOSE()
            weight_reshaped = weight.view(nb_channels_out // groups,
                                          -1).t().wrap()

            # Now that everything is set up, we can compute the result
            if groups > 1:
                res = []
                chunks_im = torch.chunk(im_reshaped, groups, dim=2)
                chunks_weights = torch.chunk(weight_reshaped, groups, dim=0)
                for g in range(groups):
                    tmp = chunks_im[g].matmul(chunks_weights[g])
                    res.append(tmp)
                res = torch.cat(res, dim=2)
            else:
                res = im_reshaped.matmul(weight_reshaped)

            # Add a bias if needed
            if bias is not None:
                res += bias

            # ... And reshape it back to an image
            res = (res.permute(0, 2,
                               1).view(batch_size, nb_channels_out,
                                       nb_rows_out, nb_cols_out).contiguous())
            return res.child
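As a plain-PyTorch illustration (outside any MPC context, and not taken from the source) of the im2col idea the function above relies on: unfold the input into rows of receptive-field values, multiply by the flattened kernels, and check the result against F.conv2d. All sizes are made up.

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)   # batch of 2 images, 3 input channels
w = torch.randn(4, 3, 3, 3)   # 4 output channels, 3x3 kernels
stride, padding = 1, 1

# im2col: (batch, C_in * kH * kW, nb_rows_out * nb_cols_out)
cols = F.unfold(x, kernel_size=3, padding=padding, stride=stride)
# matrix multiplication with the kernels reshaped to (C_out, C_in * kH * kW)
out = w.view(4, -1) @ cols            # (batch, C_out, nb_rows_out * nb_cols_out)
out = out.view(2, 4, 8, 8)            # ... and reshape back to an image

assert torch.allclose(out, F.conv2d(x, w, stride=stride, padding=padding), atol=1e-4)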
Esempio n. 37
0
                    help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='gloo', type=str, help='distributed backend')
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--rank', default=0, type=int, help='The rank of this process')
args = parser.parse_args()

args.distributed = args.world_size > 1
if args.distributed:
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

if args.distributed:
    input_data = torch.randn(int(args.num_samples / args.world_size), 1, 161, args.seconds * 100).cuda()
else:
    input_data = torch.randn(args.num_samples, 1, 161, args.seconds * 100).cuda()
input_data = torch.chunk(input_data, int(len(input_data) / args.batch_size))

rnn_type = args.rnn_type.lower()
assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"

with open(args.labels_path) as label_file:
    labels = str(''.join(json.load(label_file)))

audio_conf = dict(sample_rate=args.sample_rate,
                  window_size=args.window_size)

model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                   nb_layers=args.hidden_layers,
                   audio_conf=audio_conf,
                   labels=labels,
                   rnn_type=supported_rnns[rnn_type])
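The torch.chunk call above is what splits the synthetic benchmark input into fixed-size mini-batches; a self-contained illustration of that splitting, with made-up sizes:

import torch

num_samples, batch_size = 32, 8
data = torch.randn(num_samples, 1, 161, 100)
batches = torch.chunk(data, num_samples // batch_size)   # tuple of 4 tensors, each (8, 1, 161, 100)
assert len(batches) == 4 and batches[0].shape[0] == batch_size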
Esempio n. 38
0
 def forward(self, x):
     h = self.encoder(x)
     mu, logvar = torch.chunk(h, 2, dim=1)
     z = self.reparameterize(mu, logvar)
     return self.decoder(z), mu, logvar
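The reparameterize call above is not shown in this snippet; a common VAE implementation of it (an assumption here, not taken from the source) is:

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps the sample differentiable w.r.t. mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std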
Esempio n. 40
0
    def forward(self, input):

        # split left image and right image
        # print(input.size())
        imgs = torch.chunk(input, 2, dim = 1)
        img_left = imgs[0]
        img_right = imgs[1]

        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3a = self.conv3(conv2)
        conv3b = self.conv3_1(conv3a)
        conv4a = self.conv4(conv3b)
        conv4b = self.conv4_1(conv4a)
        conv5a = self.conv5(conv4b)
        conv5b = self.conv5_1(conv5a)
        conv6a = self.conv6(conv5b)
        conv6b = self.conv6_1(conv6a)

        pr6 = self.pred_flow6(conv6b)

        upconv5 = self.upconv5(conv6b)
        upflow6 = self.upflow6to5(pr6)
        concat5 = torch.cat((upconv5, upflow6, conv5b), 1)
        iconv5 = self.iconv5(concat5)
        pr5 = self.pred_flow5(iconv5)

        upconv4 = self.upconv4(iconv5)
        upflow5 = self.upflow5to4(pr5)
        concat4 = torch.cat((upconv4, upflow5, conv4b), 1)
        iconv4 = self.iconv4(concat4)
        pr4 = self.pred_flow4(iconv4)
        
        upconv3 = self.upconv3(iconv4)
        upflow4 = self.upflow4to3(pr4)
        concat3 = torch.cat((upconv3, upflow4, conv3b), 1)
        iconv3 = self.iconv3(concat3)
        pr3 = self.pred_flow3(iconv3)

        upconv2 = self.upconv2(iconv3)
        upflow3 = self.upflow3to2(pr3)
        concat2 = torch.cat((upconv2, upflow3, conv2), 1)
        iconv2 = self.iconv2(concat2)
        pr2 = self.pred_flow2(iconv2)

        upconv1 = self.upconv1(iconv2)
        upflow2 = self.upflow2to1(pr2)
        concat1 = torch.cat((upconv1, upflow2, conv1), 1)
        iconv1 = self.iconv1(concat1)
        pr1 = self.pred_flow1(iconv1)

        upconv0 = self.upconv0(iconv1)
        upflow1 = self.upflow1to0(pr1)
        concat0 = torch.cat((upconv0, upflow1, img_left), 1)
        iconv0 = self.iconv0(concat0)
        pr0 = self.pred_flow0(iconv0)

        # img_right_rec = warp(img_left, pr0)

        # if self.training:
        #     # print("finish forwarding.")
        #     return pr0, pr1, pr2, pr3, pr4, pr5, pr6
        # else:
        #     return pr0

        # can be chosen outside
        return pr0, pr1, pr2, pr3, pr4, pr5, pr6
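The torch.chunk call at the top of this forward splits a stacked stereo pair along the channel dimension; a minimal shape check with illustrative sizes:

import torch

stereo = torch.randn(1, 6, 64, 64)                 # left and right RGB images stacked on channels
img_left, img_right = torch.chunk(stereo, 2, dim=1)
assert img_left.shape == img_right.shape == (1, 3, 64, 64)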
Esempio n. 41
0
    def forward(self, x):
        """Applies network layers and ops on input image(s) x.

        Args:
            x: input image or batch of images. Shape: [batch,3,300,300].

        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,topk,7]

            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
        """
        image_size = [x.shape[2] , x.shape[3]]
        loc = list()
        conf = list()
        
        if backbone == 'vgg':
            for k in range(16):
                x  = self.vgg[k](x)
            conv3_3_x = x
            for k in range(16 , 23):
                x = self.vgg[k](x)
            conv4_3_x = x
            for k in range(23 , 30):
                x = self.vgg[k](x)
            conv5_3_x = x
            for k in range(30, len(self.vgg)):
                x = self.vgg[k](x)
            fc7_x = x
            for k, v in enumerate(self.extras):
                x = F.relu(v(x), inplace=True)
                if k == 1:
                    conv6_2_x = x
                if k == 3 :
                    conv7_2_x = x

        elif backbone in ['senet','resnet50', 'detnet','resnet101','resnet152' , 'resnext']:
            conv3_3_x = self.layer1(x)
            conv4_3_x = self.layer2(conv3_3_x)
            conv5_3_x = self.layer3(conv4_3_x)
            fc7_x = self.layer4(conv5_3_x)
            conv6_2_x = self.layer5(fc7_x)
            conv7_2_x = self.layer6(conv6_2_x)

        if refine:   
            arm_loc = list()
            arm_conf = list()
            arm_sources = [conv3_3_x, conv4_3_x, conv5_3_x, fc7_x, conv6_2_x, conv7_2_x]
            for (x, l, c) in zip(arm_sources, self.arm_loc, self.arm_conf):
                arm_loc.append( l(x).permute(0, 2, 3, 1).contiguous() )    
                arm_conf.append( c(x).permute(0, 2, 3, 1).contiguous() )
            arm_loc = torch.cat([o.view(o.size(0), -1) for o in arm_loc], 1)
            arm_conf = torch.cat([o.view(o.size(0), -1) for o in arm_conf], 1)
              
        if fpn:
            #lfpn6 = self._upsample_product( self.latlayer6(conv7_2_x) , self.smooth6(conv7_2_x))
            #lfpn5 = self._upsample_product( self.latlayer5(lfpn6) , self.smooth5(conv6_2_x))
            #lfpn4 = self._upsample_product( self.latlayer4(lfpn5) , self.smooth4(fc7_x) )
            #lfpn3 = self._upsample_product( self.latlayer3(lfpn4) , self.smooth3(conv5_3_x) )

            lfpn3 = self._upsample_product( self.latlayer3(fc7_x) , self.smooth3(conv5_3_x) )
            lfpn2 = self._upsample_product( self.latlayer2(lfpn3) , self.smooth2(conv4_3_x) )
            lfpn1 = self._upsample_product( self.latlayer1(lfpn2) , self.smooth1(conv3_3_x) )

            #conv7_2_x = lfpn6
            #conv6_2_x = lfpn5
            #fc7_x     = lfpn4

            conv5_3_x = lfpn3
            conv4_3_x = lfpn2
            conv3_3_x = lfpn1

        if backbone == 'vgg':
            conv3_3_x = self.L2Norm_3_3(conv3_3_x)
            conv4_3_x = self.L2Norm_4_3(conv4_3_x)
            conv5_3_x = self.L2Norm_5_3(conv5_3_x)
 
        if bup:
            #conv4_3_x = F.relu(self.bup1(conv3_3_x))  * conv4_3_x 
            #conv5_3_x = F.relu(self.bup2(conv4_3_x))  * conv5_3_x 
            fc7_x     = F.relu(self.bup3(conv5_3_x))  * fc7_x 
            conv6_2_x = F.relu(self.bup4(fc7_x))      * conv6_2_x 
            conv7_2_x = F.relu( self.bup5(conv6_2_x)) * conv7_2_x 
        
        sources = [conv3_3_x, conv4_3_x, conv5_3_x, fc7_x, conv6_2_x, conv7_2_x]
        if fem:
           sources[0] = self.cpm3_3(sources[0])
           sources[1] = self.cpm4_3(sources[1])
           sources[2] = self.cpm5_3(sources[2])
           sources[3] = self.cpm7(sources[3])
           sources[4] = self.cpm6_2(sources[4])
           sources[5] = self.cpm7_2(sources[5])
        
        # apply multibox head to source layers
        featuremap_size = []
        for  (x, l, c) in zip(sources, self.loc, self.conf):
            featuremap_size.append([ x.shape[2], x.shape[3]])
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            if mo:
                if len(conf)==0:
                    chunk = torch.chunk(c(x) , 4 , 1)
                    bmax  = torch.max(torch.max(chunk[0], chunk[1]) , chunk[2])
                    cls1  = torch.cat([bmax,chunk[3]], dim=1)
                    conf.append( cls1.permute(0, 2, 3, 1).contiguous() )
                else:
                    conf.append(c(x).permute(0, 2, 3, 1).contiguous())
            elif mio:
                len_conf = len(conf)
                if cfg['mbox'][0] ==1 :
                    cls = self.mio_module(c(x),len_conf)
                else:
                    mmbox = torch.chunk(c(x) , cfg['mbox'][0] , 1)
                    cls_0 = self.mio_module(mmbox[0], len_conf)
                    cls_1 = self.mio_module(mmbox[1], len_conf)
                    cls_2 = self.mio_module(mmbox[2], len_conf)
                    cls_3 = self.mio_module(mmbox[3], len_conf)
                    cls = torch.cat([cls_0, cls_1, cls_2, cls_3] , dim=1)
                conf.append(cls.permute(0, 2, 3, 1).contiguous())
            else:
                conf.append(c(x).permute(0, 2, 3, 1).contiguous())
        if pa:
            mbox_num = cfg['mbox'][0]
            face_loc = torch.cat(  [o[:,:,:,:4*mbox_num].contiguous().view(o.size(0),-1) for o in loc],1)
            face_conf = torch.cat( [o[:,:,:,:2*mbox_num].contiguous().view(o.size(0),-1) for o in conf],1)
            head_loc = torch.cat( [o[:,:,:,4*mbox_num:8*mbox_num].contiguous().view(o.size(0),-1) for o in loc[1:]],1)
            head_conf = torch.cat( [o[:,:,:,2*mbox_num:4*mbox_num].contiguous().view(o.size(0),-1) for o in conf[1:]],1)
            body_loc = torch.cat( [o[:,:,:,8*mbox_num:].contiguous().view(o.size(0),-1) for o in loc[2:]],1)
            body_conf = torch.cat( [o[:,:,:,4*mbox_num:].contiguous().view(o.size(0),-1) for o in conf[2:]],1)
        else:
            face_loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
            face_conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        if self.phase == "test":
            self.cfg['feature_maps'] = featuremap_size
            self.cfg['min_dim'] = image_size
            self.priors = self.init_priors(self.cfg)
            if refine:
                output = self.detect(
                  face_loc.view(face_loc.size(0), -1, 4),         # loc preds
                  self.softmax(face_conf.view(face_conf.size(0), -1, self.num_classes)), # conf preds
                  self.priors.type(type(x.data)),                  # default boxes
                  arm_loc.view(arm_loc.size(0), -1, 4),
                  self.softmax(arm_conf.view(arm_conf.size(0), -1, self.num_classes)),
                )
            else:
                output = self.detect(
                  face_loc.view(face_loc.size(0), -1, 4),         # loc preds
                  self.softmax(face_conf.view(face_conf.size(0), -1, self.num_classes)), # conf preds
                  self.priors.type(type(x.data))                  # default boxes
                )
        else:
            self.cfg['feature_maps'] = featuremap_size
            self.cfg['min_dim'] = image_size
            if pa: 
              self.face_priors = self.init_priors(self.cfg)
              self.head_priors = self.init_priors(self.cfg , min_size=cfg['min_sizes'][:-1], max_size=cfg['max_sizes'][:-1])
              self.body_priors = self.init_priors(self.cfg , min_size=cfg['min_sizes'][:-2], max_size=cfg['max_sizes'][:-2])
              output = (
                face_loc.view(face_loc.size(0), -1, 4),
                face_conf.view(face_conf.size(0), -1, self.num_classes),
                self.face_priors,
 
                head_loc.view(head_loc.size(0), -1, 4),
                head_conf.view(head_conf.size(0), -1, self.num_classes),
                self.head_priors,

                body_loc.view(body_loc.size(0), -1, 4),
                body_conf.view(body_conf.size(0), -1, self.num_classes),
                self.body_priors
              )
            else:
              self.priors = self.init_priors(self.cfg)
              output = (
                face_loc.view(face_loc.size(0), -1, 4),
                face_conf.view(face_conf.size(0), -1, self.num_classes),
                self.priors
              )
            if refine:
                output = output + tuple((arm_loc.view(arm_loc.size(0), -1, 4), arm_conf.view(arm_conf.size(0), -1, self.num_classes) ))
        return output
Esempio n. 42
0
def gnn_track_finding(
        hid,
        x,
        cell_data,
        embed_ckpt_dir='/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt',
        filter_ckpt_dir='/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt',
        gnn_ckpt_dir='/global/cfs/cdirs/m3443/data/lightning_models/gnn',
        ckpt_idx=-1,
        dbscan_epsilon=0.25,
        dbscan_minsamples=2):

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ### Set up some hyperparameters and the event

    # embed_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt'
    # filter_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt'
    # gnn_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/gnn'
    # ckpt_idx = -1 # which GNN checkpoint to load
    # dbscan_epsilon, dbscan_minsamples = 0.25, 2 # hyperparameters for DBScan
    # min_hits = 5 # minimum number of hits associated with a particle to define "reconstructable particles"
    # frac_reco_matched, frac_truth_matched = 0.5, 0.5 # parameters for track matching

    data = Data(hid=torch.from_numpy(hid),
                x=torch.from_numpy(x).float(),
                cell_data=torch.from_numpy(cell_data).float()).to(device)

    # ### Evaluating Embedding
    # In[9]:
    e_ckpt = torch.load(embed_ckpt_dir, map_location=device)
    e_config = e_ckpt['hyper_parameters']
    e_config['clustering'] = 'build_edges'
    e_config['knn_val'] = 500
    e_config['r_val'] = 1.7

    e_model = LayerlessEmbedding(e_config).to(device)
    e_model.load_state_dict(e_ckpt["state_dict"])
    e_model.eval()

    # Map each hit to the embedding space, return the embedded parameters for each hit
    with torch.no_grad():
        spatial = e_model(torch.cat([data.cell_data, data.x],
                                    axis=-1))  #.to(device)

    # ### From the embedding space, form doublets

    # `r_val = 1.7` and `knn_val = 500` are the hyperparameters to be studied.
    #
    # * `r_val` defines the radius of the clustering method
    # * `knn_val` defines the number of maximum neighbors in the embedding space

    e_spatial = utils_torch.build_edges(spatial.to(device),
                                        e_model.hparams['r_val'],
                                        e_model.hparams['knn_val'])

    # Remove edges that point from the outer region to the inner region, which removes roughly half of the edges.
    # In[16]:
    R_dist = torch.sqrt(data.x[:, 0]**2 +
                        data.x[:, 2]**2)  # distance away from origin...
    e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])]

    f_ckpt = torch.load(filter_ckpt_dir, map_location='cpu')
    f_config = f_ckpt['hyper_parameters']
    f_config['train_split'] = [0, 0, 1]
    f_config['filter_cut'] = 0.18

    f_model = VanillaFilter(f_config).to(device)
    f_model.load_state_dict(f_ckpt['state_dict'])
    f_model.eval()

    emb = None  # embedding information was not used in the filtering stage.
    chunks = 10
    output_list = []
    for j in range(chunks):
        subset_ind = torch.chunk(torch.arange(e_spatial.shape[1]), chunks)[j]
        with torch.no_grad():
            output = f_model(torch.cat([data.cell_data, data.x],
                                       axis=-1), e_spatial[:, subset_ind],
                             emb).squeeze()  #.to(device)
        output_list.append(output)
        del subset_ind
        del output
        gc.collect()
    output = torch.cat(output_list)
    output = torch.sigmoid(output)

    # The filtering network assigns a score to each edge.
    # In the end, edges with scores > `filter_cut` are selected to construct graphs.
    # edge_list = e_spatial[:, output.to('cpu') > f_model.hparams['filter_cut']]
    print(f_model.hparams['filter_cut'])
    edge_list = e_spatial[:, output > f_model.hparams['filter_cut']]
    print(edge_list.shape)

    # ### Form a graph
    # Now moving to TensorFlow for GNN inference.

    n_nodes = data.x.shape[0]
    n_edges = edge_list.shape[1]
    nodes = data.x.cpu().numpy().astype(np.float32)
    edges = np.zeros((n_edges, 1), dtype=np.float32)
    senders = edge_list[0].cpu()
    receivers = edge_list[1].cpu()

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }

    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])

    num_processing_steps_tr = 8
    optimizer = snt.optimizers.Adam(0.001)
    model = SegmentClassifier()

    output_dir = gnn_ckpt_dir
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    ckpt_manager = tf.train.CheckpointManager(checkpoint,
                                              directory=output_dir,
                                              max_to_keep=10)
    status = checkpoint.restore(
        ckpt_manager.checkpoints[ckpt_idx]).expect_partial()

    # clean up GPU memory
    del e_spatial
    del e_model
    del f_model
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()

    outputs_gnn = model(input_graph, num_processing_steps_tr)
    output_graph = outputs_gnn[-1]

    # ### Track labeling
    input_matrix = prepare_labeling(
        tf.squeeze(output_graph.edges).cpu().numpy(), senders, receivers,
        n_nodes)
    predict_track_df = dbscan_clustering(data.hid.cpu(), input_matrix,
                                         dbscan_epsilon, dbscan_minsamples)
    trkx_groups = predict_track_df.groupby(['track_id'])
    all_trk_ids = np.unique(predict_track_df.track_id)
    n_trkxs = all_trk_ids.shape[0]
    predict_tracks = [
        trkx_groups.get_group(all_trk_ids[idx])['hit_id'].to_numpy().tolist()
        for idx in range(n_trkxs)
    ]
    return predict_tracks
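The chunked, no-grad filtering loop above is a general pattern for bounding peak GPU memory during inference; a minimal sketch with a stand-in model (all names here are illustrative):

import torch

def chunked_inference(model, inputs, chunks=10):
    # run a large batch through `model` in pieces so only one chunk is resident at a time
    outputs = []
    for idx in torch.chunk(torch.arange(inputs.shape[0]), chunks):
        with torch.no_grad():
            outputs.append(model(inputs[idx]))
    return torch.cat(outputs)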
Esempio n. 43
0
    def forward(self, input):
        internal_state = []

        fcn2_output = input  # 12 *(64*64) * 64 channels
        input = torch.cat(torch.chunk(input, 12, dim=0), dim=1)

        for step in range(3):
            x = input
            if step == 0:
                basize, _, height, width = input.size()
                (h_step, c) = ConvLSTMCell.init_hidden(
                    basize, self.hidden_channels[self.num_layers - 1],
                    (height, width))

            fcn2 = self.conv_fcn2(x)

            h_c = self.conv_h(h_step)

            fcn2_h_cat = fcn2 + h_c
            fcn2_h_cat = self.pool_avg(fcn2_h_cat)
            fcn2_h_cat = self.conv_c(fcn2_h_cat)

            # Attention Module
            fcn2_h_cat = torch.mul(F.softmax(fcn2_h_cat, dim=1), 12)
            Att = fcn2_h_cat

            basize, dime, h, w = fcn2_h_cat.size()
            fcn2_h_cat = fcn2_h_cat.view(1, basize, dime, h,
                                         w).transpose(0, 1).transpose(1, 2)
            fcn2_h_cat = torch.cat(torch.chunk(fcn2_h_cat, basize, dim=0),
                                   dim=1).view(basize * dime, 1, 1, 1)

            fcn2_h_cat = torch.mul(fcn2_output,
                                   fcn2_h_cat).view(1, basize * dime, 64, 64,
                                                    64)
            fcn2_h_cat = torch.cat(torch.chunk(fcn2_h_cat, basize, dim=1),
                                   dim=0)
            fcn2_h_cat = torch.sum(fcn2_h_cat, 1, keepdim=False)  #.squeeze()

            x = fcn2_h_cat
            if step < self.step - 1:
                for i in range(self.num_layers):
                    # all cells are initialized in the first step
                    if step == 0:
                        bsize, _, height, width = x.size()
                        (h,
                         c) = ConvLSTMCell.init_hidden(bsize,
                                                       self.hidden_channels[i],
                                                       (height, width))
                        internal_state.append((h, c))
                    # do forward
                    name = 'cell{}'.format(i)
                    (h, c) = internal_state[i]

                    x, new_c, new_o = getattr(self,
                                              name)(x, h,
                                                    c)  # ConvLSTMCell forward
                    internal_state[i] = (x, new_c)
                h_step = x
                # only record effective steps
                #if step in self.effective_step:

                if step == 0:
                    outputs_o = new_o
                else:
                    outputs_o = torch.cat((outputs_o, new_o), dim=1)

        # outputs_o = torch.cat([outputs_o, new_o], dim=1)
        # outputs_o = torch.cat([outputs_o, outputs_o, outputs_o], dim=1)
        outputs = self.conv_pre(outputs_o)

        output = F.interpolate(outputs, scale_factor=4, mode='bilinear', align_corners=False)

        return output
Esempio n. 44
0
 def forward(self, x):
     x = self.conv1(x)
     # add_scalar
     x = x + 3
     # mul_scalar
     x = x * 3
     # add_scalar_out
     x += 3
     # mul_scalar_out
     x *= 3
     # add_scalar_relu
     x = x + 3
     x = F.relu(x)
     # add_scalar_relu_out
     x += 3
     x = F.relu(x)
     # mul_scalar_relu
     x = x * 3
     x = F.relu(x)
     # mul_scalar_relu_out
     x *= 3
     x = F.relu(x)
     x = self.maxpool1d(x)
     x = self.maxpool2d(x)
     x = self.maxpool3d(x)
     x = torch.flatten(x)
     x = torch.max(x)
     x = torch.min(x)
     x = x.reshape([-1])
     x = x.resize_(1, 1, x.numel())
     x = x.view(-1)
     # prim::ListConstruct
     xs = [x, x]
     # prim::ListUnpack
     x, y = xs
     # prim::TupleConstruct
     xs = (x, x)
     # prim::TupleUnpack
     x, y = xs
     x = x.transpose(1, 2)
     x = x.contiguous()
     x, y = torch.chunk(x, 2)
     x = F.dropout(x)
     x = self.dropout(x)
     x, _ = torch.sort(x)
     x = x.permute(0, 2, 3, 1)
     x = x.repeat_interleave(3, 1)
     x = torch.repeat_interleave(x, 3, 1)
     x = self.relu(x)
     x = F.relu(x)
     x = F.relu(x, inplace=True)
     x = x.relu()
     x.relu_()
     x = x.squeeze(0)
     x.squeeze_(0)
     x = torch.squeeze(x, 0)
     x = x.unsqueeze(0)
     x.unsqueeze_(0)
     x = torch.unsqueeze(x, 0)
     x = x.detach()
     x.detach_()
     x = x.repeat(4, 2)
     y = []
     y.append(x)
     z = torch.stack(y, 0)
     z = [z, z]
     x, _ = z
     x = self.conv2(x)
     return x
Esempio n. 45
0
def RGB_to_BGR(batch):
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch
Esempio n. 46
0
    def calculate_cpes(
        self,
        training_batch,
        states,
        next_states,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_batch.extras.metrics is None:
            metrics_reward_concat_real_vals = training_batch.training_input.reward
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_batch.training_input.reward, training_batch.extras.metrics),
                dim=1,
            )

        model_propensities_next_states = masked_softmax(
            all_next_action_scores,
            training_batch.training_input.possible_next_actions_mask
            if self.maxq_learning
            else training_batch.training_input.next_action,
            self.rl_temperature,
        )

        with torch.enable_grad():
            ######### Train separate reward network for CPE evaluation #############
            # FIXME: the reward network should be outputting a tensor, not a q-value object
            reward_estimates = self.reward_network(states).q_values
            reward_estimates_for_logged_actions = reward_estimates.gather(
                1, self.reward_idx_offsets + logged_action_idxs
            )
            reward_loss = F.mse_loss(
                reward_estimates_for_logged_actions, metrics_reward_concat_real_vals
            )
            reward_loss.backward()
            self._maybe_run_optimizer(
                self.reward_network_optimizer, self.minibatches_per_step
            )

            ######### Train separate q-network for CPE evaluation #############
            metric_q_values = self.q_network_cpe(states).q_values.gather(
                1, self.reward_idx_offsets + logged_action_idxs
            )
            all_metrics_target_q_values = torch.chunk(
                self.q_network_cpe_target(next_states).q_values.detach(),
                len(self.metrics_to_score),
                dim=1,
            )
            target_metric_q_values = []
            for i, per_metric_target_q_values in enumerate(all_metrics_target_q_values):
                per_metric_next_q_values = torch.sum(
                    per_metric_target_q_values * model_propensities_next_states,
                    1,
                    keepdim=True,
                )
                per_metric_next_q_values = per_metric_next_q_values * not_done_mask
                per_metric_target_q_values = metrics_reward_concat_real_vals[
                    :, i : i + 1
                ] + (discount_tensor * per_metric_next_q_values)
                target_metric_q_values.append(per_metric_target_q_values)

            target_metric_q_values = torch.cat(target_metric_q_values, dim=1)
            metric_q_value_loss = self.q_network_loss(
                metric_q_values, target_metric_q_values
            )
            metric_q_value_loss.backward()
            self._maybe_run_optimizer(
                self.q_network_cpe_optimizer, self.minibatches_per_step
            )

        # Use the soft update rule to update target network
        self._maybe_soft_update(
            self.q_network_cpe,
            self.q_network_cpe_target,
            self.tau,
            self.minibatches_per_step,
        )

        model_propensities = masked_softmax(
            self.all_action_scores,
            training_batch.training_input.possible_actions_mask
            if self.maxq_learning
            else training_batch.training_input.action,
            self.rl_temperature,
        )
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]
        return reward_loss, model_rewards, model_propensities
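The torch.chunk call above splits the CPE target network's wide output into one block of Q-values per metric; a shape-only illustration with made-up sizes:

import torch

num_metrics, num_actions, batch = 3, 4, 8
q_all = torch.randn(batch, num_metrics * num_actions)
per_metric = torch.chunk(q_all, num_metrics, dim=1)
assert all(t.shape == (batch, num_actions) for t in per_metric)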
Esempio n. 47
0
def vgg_preprocess_caffe(var):
    (r, g, b) = torch.chunk(var, 3, dim=1)
    bgr = torch.cat((b, g, r), 1)
    out = bgr * 255 - torch.autograd.Variable(vgg_mean).type(var.type())
    return out
Esempio n. 48
0
 def _compute_rotation_matrix_taylor(angle_axis):
     rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
     k_one = torch.ones_like(rx)
     rotation_matrix = torch.cat(
         [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
     return rotation_matrix.view(-1, 3, 3)
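_compute_rotation_matrix_taylor builds the first-order (small-angle) approximation R ≈ I + skew(ω); as a quick numerical sanity check (not from the source), it can be compared against the exact rotation given by the matrix exponential of the skew-symmetric matrix:

import torch

angle_axis = torch.tensor([[1e-3, -2e-3, 5e-4]])
rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
k_one = torch.ones_like(rx)
R_taylor = torch.cat(
    [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1).view(-1, 3, 3)

# exact rotation: matrix exponential of the skew-symmetric matrix skew(w)
zeros = torch.zeros_like(rx)
skew = torch.cat(
    [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(-1, 3, 3)
assert torch.allclose(R_taylor, torch.matrix_exp(skew), atol=1e-5)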
Esempio n. 49
0
 def vgg_preprocess(tensor):
     (r, g, b) = torch.chunk(tensor, 3, dim=0)
     bgr = torch.cat((b, g, r), 0)
     out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr)
     return out
Esempio n. 50
0
    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head,
                                 self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head,
                                 self.d_head)  # qlen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head,
                                 self.d_head)  # qlen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head,
                                 self.d_head)  # qlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias  # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn',
                          (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn',
                          (rr_head_q, r_head_k))  # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None],
                    -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None],
                    -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0),
                                              attn_vec.size(1),
                                              self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output
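The _rel_shift helper used above is not included in the snippet; the standard Transformer-XL implementation (reproduced here as an assumption about what this code relies on) pads, reshapes and slices to realign the relative-position scores:

import torch

def _rel_shift(x):
    # x: (qlen, klen, bsz, n_head); realigns the BD term so each query row is indexed by relative position
    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)
    x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
    return x_padded[1:].view_as(x)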
Esempio n. 51
0
 def forward(self, lefts, rights, tracking=None):
     batch_size = len(lefts)
     ret = torch.cat(lefts, 0) + torch.tanh(torch.cat(rights, 0))
     return torch.chunk(ret, batch_size, 0)
Esempio n. 52
0
    def forward(
        self,
        x: torch.Tensor,
        states: Tuple[torch.Tensor, torch.Tensor],
        variational_dropout_mask: Optional[torch.BoolTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        !!! Warning
            DO NOT USE THIS LAYER DIRECTLY, instead use the AugmentedLSTM class

        # Parameters

        x : `torch.Tensor`
            Input tensor of shape (bsize x input_dim).
        states : `Tuple[torch.Tensor, torch.Tensor]`
            Tuple of tensors containing
            the hidden state and the cell state of each element in
            the batch. Each of these tensors has a dimension of
            (bsize x nhid).

        # Returns

        `Tuple[torch.Tensor, torch.Tensor]`
            Returned states. Shape of each state is (bsize x nhid).

        """
        hidden_state, memory_state = states

        # In Pytext this was done as the last step of the cell.
        # But in the original AugmentedLSTM from AllenNLP it was done before the processing
        if variational_dropout_mask is not None and self.training:
            hidden_state = hidden_state * variational_dropout_mask

        projected_input = self.input_linearity(x)
        projected_state = self.state_linearity(hidden_state)

        input_gate = forget_gate = memory_init = output_gate = highway_gate = None
        if self.use_highway:
            fused_op = projected_input[:, : 5 * self.lstm_dim] + projected_state
            fused_chunked = torch.chunk(fused_op, 5, 1)
            (input_gate, forget_gate, memory_init, output_gate, highway_gate) = fused_chunked
            highway_gate = torch.sigmoid(highway_gate)
        else:
            fused_op = projected_input + projected_state
            input_gate, forget_gate, memory_init, output_gate = torch.chunk(fused_op, 4, 1)
        input_gate = torch.sigmoid(input_gate)
        forget_gate = torch.sigmoid(forget_gate)
        memory_init = torch.tanh(memory_init)
        output_gate = torch.sigmoid(output_gate)
        memory = input_gate * memory_init + forget_gate * memory_state
        timestep_output: torch.Tensor = output_gate * torch.tanh(memory)

        if self.use_highway:
            highway_input_projection = projected_input[
                :, self._highway_inp_proj_start: self._highway_inp_proj_end
            ]
            timestep_output = (
                highway_gate * timestep_output
                + (1 - highway_gate) * highway_input_projection  # noqa
            )

        return timestep_output, memory
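To make the fused-projection trick above concrete, here is a tiny standalone sketch (with made-up sizes, not the class's real constructor) of splitting one big linear projection into the four LSTM gates with torch.chunk:

import torch
import torch.nn as nn

bsize, input_dim, lstm_dim = 2, 8, 16
input_linearity = nn.Linear(input_dim, 4 * lstm_dim)
state_linearity = nn.Linear(lstm_dim, 4 * lstm_dim)

x = torch.randn(bsize, input_dim)
h = torch.randn(bsize, lstm_dim)
c = torch.randn(bsize, lstm_dim)

fused = input_linearity(x) + state_linearity(h)    # (bsize, 4 * lstm_dim)
i, f, g, o = torch.chunk(fused, 4, dim=1)          # one (bsize, lstm_dim) slab per gate
c_next = torch.sigmoid(i) * torch.tanh(g) + torch.sigmoid(f) * c
h_next = torch.sigmoid(o) * torch.tanh(c_next)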