Example #1
    # assumes module-level imports: import torch; import torch.nn.functional as F
    def forward(self, input_trunk, input_task_specific, output_size=None):

        if len(input_task_specific) != self.n_tasks:
            raise ValueError('Number of task-specific inputs does not match the number of tasks.')

        # Normalize the trunk input so the tuple and single-tensor cases share one call.
        args = input_trunk if isinstance(input_trunk, tuple) else (input_trunk,)

        if self.index_intermediate is None:
            output_trunk = self.attented_layer(*args, output_size=output_size)
            output_trunk_ = output_trunk
        else:
            output_trunk, output_trunk_intermediate = self.attented_layer(
                *args,
                index_intermediate=self.index_intermediate,
                output_size=output_size)
            output_trunk_ = output_trunk_intermediate

        output_attentions = []

        for i in range(self.n_tasks):
            if self.upsampling:
                # F.interpolate accepts either `size` or `scale_factor`, not both,
                # and `output_size[2:]` would fail when output_size is None.
                if output_size is None:
                    output_attention = F.interpolate(input_task_specific[i],
                                                     scale_factor=2,
                                                     mode='bilinear',
                                                     align_corners=True)
                else:
                    output_attention = F.interpolate(input_task_specific[i],
                                                     size=output_size[2:],
                                                     mode='bilinear',
                                                     align_corners=True)
            else:
                output_attention = input_task_specific[i]
            output_attention = self.shared_feature_extractor(output_attention)

            input_attention = torch.cat((output_trunk_, output_attention.type_as(output_trunk_)), dim=1)

            output_attention = self.specific_feature_extractor[i](input_attention)
            if self.save_attention_mask:
                self.attention_mask = output_attention.detach().cpu().numpy()
            # Gate the trunk features with the per-task soft mask.
            output_attention = output_attention * output_trunk
            output_attentions.append(output_attention)

        return output_trunk, tuple(output_attentions)
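
The decoder-side flow above reduces to: upsample the incoming task features, refine them with the shared extractor, concatenate with the trunk, predict a soft mask, and gate the trunk with it. A minimal, self-contained sketch of that pattern follows; the module names and channel sizes here are illustrative assumptions, not the original class:

import torch
import torch.nn as nn
import torch.nn.functional as F

trunk_feat = torch.randn(4, 64, 32, 32)              # shared-trunk features
task_feat = torch.randn(4, 64, 16, 16)               # coarser per-task features

shared = nn.Conv2d(64, 64, 3, padding=1)             # stand-in for shared_feature_extractor
mask_head = nn.Sequential(nn.Conv2d(128, 64, 1),     # stand-in for specific_feature_extractor[i]
                          nn.Sigmoid())              # soft mask values in [0, 1]

up = F.interpolate(task_feat, scale_factor=2, mode='bilinear', align_corners=True)
mask = mask_head(torch.cat((trunk_feat, shared(up)), dim=1))
gated = mask * trunk_feat                            # gated trunk for this task's head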
Example #2
    # assumes module-level imports: import torch; import torch.nn.functional as F
    def forward(self, input_trunk, input_task_specific=None, index_intermediate=None):

        if not self.first_block and input_task_specific is None:
            raise ValueError('Not the first attention block, but no input from the task-specific route was given.')

        if input_task_specific is None:
            # First block: 1-D empty tensors act as placeholders; torch.cat skips them.
            input_task_specific = [torch.empty(0) for _ in range(self.n_tasks)]

        if len(input_task_specific) != self.n_tasks:
            raise ValueError('Number of task-specific inputs does not match the number of tasks.')

        if index_intermediate is None:
            output_trunk = self.attented_layer(input_trunk)
        else:
            output_trunk, output_trunk_intermediate = self.attented_layer(
                input_trunk, index_intermediate=index_intermediate)

        output_trunk_ = output_trunk[0] if isinstance(output_trunk, tuple) else output_trunk

        if index_intermediate is None:
            output_trunk_intermediate = output_trunk_

        output_attentions = []

        for i in range(self.n_tasks):
            input_attention = torch.cat((output_trunk_intermediate,
                                         input_task_specific[i].type_as(output_trunk_intermediate)), dim=1)
            output_attention = self.specific_feature_extractor[i](input_attention)
            if self.save_attention_mask:
                self.attention_mask = output_attention.detach().cpu().numpy()
            # Gate the trunk features with the per-task soft mask.
            output_attention = output_attention * output_trunk_
            # encoder_block_att is shared across all tasks.
            output_attention = self.shared_feature_extractor(output_attention)
            if self.downsampling:
                output_attention = F.max_pool2d(output_attention, kernel_size=2, stride=2)
            output_attentions.append(output_attention)

        return output_trunk, tuple(output_attentions)
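
Example #2 is the encoder-side counterpart: the mask is predicted first, the gated trunk is refined by the shared extractor, and the result is max-pooled so the next block receives features at half resolution. A hedged sketch of one such step, with assumed shapes and stand-in modules:

import torch
import torch.nn as nn
import torch.nn.functional as F

trunk = torch.randn(2, 16, 32, 32)
task_in = torch.randn(2, 16, 32, 32)                 # per-task output of the previous block

mask_head = nn.Sequential(nn.Conv2d(32, 16, 1), nn.Sigmoid())
shared = nn.Conv2d(16, 16, 3, padding=1)

mask = mask_head(torch.cat((trunk, task_in), dim=1))
task_out = F.max_pool2d(shared(mask * trunk), kernel_size=2, stride=2)   # (2, 16, 16, 16)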
Example #3
    # assumes module-level: import torch
    def preprocess(self, img, labels):
        # Embed each label, reshape the embedding into a single 28x28 plane,
        # and stack it with the image along the channel dimension.
        n_labels = self.label_embedding(labels)
        n_labels = n_labels.reshape(n_labels.size(0), 1, 28, 28)
        n_img = torch.cat((n_labels, img), 1)
        return n_img
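
This is the discriminator-side conditioning of a conditional GAN: each label is embedded into as many values as the image has pixels and attached as an extra input channel. A quick usage sketch, assuming 10 classes and 1x28x28 images (e.g. MNIST); the sizes are assumptions:

import torch
import torch.nn as nn

label_embedding = nn.Embedding(10, 28 * 28)
img = torch.randn(4, 1, 28, 28)
labels = torch.randint(0, 10, (4,))

planes = label_embedding(labels).reshape(-1, 1, 28, 28)
x = torch.cat((planes, img), dim=1)                  # (4, 2, 28, 28): label plane + image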
Example #4
    # assumes module-level: import torch
    def preprocess(self, noise, labels):
        # Embed each label, view the embedding as a stack of 1x1 feature maps,
        # and prepend it to the noise tensor along the channel dimension.
        n_labels = self.label_embedding(labels)
        n_labels = n_labels.reshape(n_labels.size(0), n_labels.size(1), 1, 1)
        n_noise = torch.cat((n_labels, noise), 1)
        return n_noise
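
And the generator-side counterpart: the label embedding is concatenated with the latent noise as extra 1x1 channels, so a ConvTranspose2d stack can upsample both together. A sketch with assumed sizes (10 classes, 16-dim embedding, 100-dim latent):

import torch
import torch.nn as nn

label_embedding = nn.Embedding(10, 16)
noise = torch.randn(4, 100, 1, 1)
labels = torch.randint(0, 10, (4,))

emb = label_embedding(labels).reshape(4, 16, 1, 1)
z = torch.cat((emb, noise), dim=1)                   # (4, 116, 1, 1) conditioned latent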