    def call(self, inputs, training=True, survival_prob=None):
        """Implementation of call().

        Args:
          inputs: the input tensor.
          training: boolean, whether the model is constructed for training.
          survival_prob: float between 0 and 1, the survival probability used for drop connect.

        Returns:
          An output tensor.
        """
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._relu_fn(
                self._bn0(self._expand_conv(x), training=training))
        x = self._relu_fn(self._bn1(self._depthwise_conv(x),
                                    training=training))

        if self._has_se:
            se_tensor = tf.reduce_mean(x, self._spatial_dims, keepdims=True)
            se_tensor = self._se_expand(
                self._relu_fn(self._se_reduce(se_tensor)))
            x = tf.sigmoid(se_tensor) * x

        x = self._bn2(self._project_conv(x), training=training)
        # Add identity so that quantization-aware training can insert quantization
        # ops correctly.
        x = tf.identity(x)
        if self._clip_projection_output:
            x = tf.clip_by_value(x, -6, 6)
        if all(
                s == 1 for s in self._block_args.strides
        ) and self._block_args.input_filters == self._block_args.output_filters:
            if survival_prob:
                x = utils.drop_connect(x, training, survival_prob)
            x = tf.add(x, inputs)
        return x
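
The residual branch above relies on utils.drop_connect for stochastic depth, which is not shown in this listing. A minimal sketch of what that helper typically looks like in EfficientNet-style TensorFlow code (treat the exact module and signature as assumptions):

import tensorflow as tf

def drop_connect(inputs, is_training, survival_prob):
    """Stochastic depth: randomly drop the whole residual branch per example."""
    if not is_training:
        return inputs
    batch_size = tf.shape(inputs)[0]
    # Bernoulli mask with keep probability `survival_prob`, one value per example.
    random_tensor = survival_prob + tf.random.uniform(
        [batch_size, 1, 1, 1], dtype=inputs.dtype)
    binary_tensor = tf.floor(random_tensor)
    # Rescale so the expected value of the output matches the input.
    return inputs / survival_prob * binary_tensor
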
Example No. 2
    def call(self, inputs, training=True, survival_prob=None):
        """Implementation of call().

        Args:
          inputs: the input tensor.
          training: boolean, whether the model is constructed for training.
          survival_prob: float between 0 and 1, the survival probability used for drop connect.

        Returns:
          An output tensor.
        """
        logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
        if self._block_args.expand_ratio != 1:
            x = self._relu_fn(self._bn0(self._expand_conv(inputs), training=training))
        else:
            x = inputs
        logging.info('Expand: %s shape: %s', x.name, x.shape)

        self.endpoints = {'expansion_output': x}

        x = self._bn1(self._project_conv(x), training=training)
        # Add identity so that quantization-aware training can insert quantization
        # ops correctly.
        x = tf.identity(x)
        if self._clip_projection_output:
            x = tf.clip_by_value(x, -6, 6)

        if self._block_args.id_skip:
            if all(
                    s == 1 for s in self._block_args.strides
            ) and self._block_args.input_filters == self._block_args.output_filters:
                # Apply drop connect only if the skip connection is present.
                if survival_prob:
                    x = utils.drop_connect(x, training, survival_prob)
                x = tf.add(x, inputs)
        logging.info('Project: %s shape: %s', x.name, x.shape)
        return x
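
Callers of these blocks typically decay survival_prob linearly with block depth, so deeper blocks are dropped more often. A hedged usage sketch of that pattern (the block list, loop, and overall drop rate below are illustrative, not part of this file):

# Hypothetical caller: linearly decayed stochastic depth across the blocks.
drop_rate = 0.2  # assumed overall drop-connect rate
for idx, block in enumerate(blocks):  # `blocks` is an assumed list of these block modules
    survival_prob = 1.0 - drop_rate * float(idx) / len(blocks)
    x = block.call(x, training=True, survival_prob=survival_prob)
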
Example No. 3
    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output tensor of the block
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = relu_fn(self._bn0(self._expand_conv(inputs)))
        x = relu_fn(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(relu_fn(self._se_reduce(x_squeezed)))
            x = torch.sigmoid(x_squeezed) * x

        x_b = x
        x = self._bn2(self._project_conv(x))

        pad_h = x_b.shape[2] - x.shape[2]
        pad_w = x_b.shape[3] - x.shape[3]
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            pad_h = inputs.shape[2] - x.shape[2]
            pad_w = inputs.shape[3] - x.shape[3]
            if pad_h > 0 or pad_w > 0:
                x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
            x = x + inputs  # skip connection
        return x
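
The PyTorch variant above calls a module-level drop_connect(x, p=..., training=...) helper that is not shown. A minimal sketch of how that helper is usually written in EfficientNet PyTorch ports (an assumption; the real implementation may differ):

import torch

def drop_connect(inputs, p, training):
    """Stochastic depth for NCHW tensors: zero the residual branch per sample."""
    if not training:
        return inputs
    keep_prob = 1.0 - p
    batch_size = inputs.shape[0]
    # One Bernoulli draw per sample, broadcast over channels and spatial dims.
    random_tensor = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)
    return inputs / keep_prob * binary_tensor
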
Example No. 4
    def call(self, inputs, training=True, survival_prob=None):
        """Implementation of call().

        Args:
          inputs: the input tensor.
          training: boolean, whether the model is constructed for training.
          survival_prob: float between 0 and 1, the survival probability used for drop connect.

        Returns:
          An output tensor.
        """
        logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
        logging.info('Block input depth: %s output depth: %s',
                     self._block_args.input_filters,
                     self._block_args.output_filters)

        x = inputs

        fused_conv_fn = self._fused_conv
        expand_conv_fn = self._expand_conv
        depthwise_conv_fn = self._depthwise_conv
        project_conv_fn = self._project_conv

        if self._block_args.condconv:
            pooled_inputs = self._avg_pooling(inputs)
            routing_weights = self._routing_fn(pooled_inputs)
            # Capture routing weights as additional input to CondConv layers
            fused_conv_fn = functools.partial(self._fused_conv,
                                              routing_weights=routing_weights)
            expand_conv_fn = functools.partial(self._expand_conv,
                                               routing_weights=routing_weights)
            depthwise_conv_fn = functools.partial(
                self._depthwise_conv, routing_weights=routing_weights)
            project_conv_fn = functools.partial(
                self._project_conv, routing_weights=routing_weights)

        # Super-pixel transform: a 2x2 strided conv applied at the block input.
        if self._block_args.super_pixel == 1:
            with tf.variable_scope('super_pixel'):
                x = self._relu_fn(
                    self._bnsp(self._superpixel(x), training=training))
            logging.info('Block start with SuperPixel: %s shape: %s', x.name,
                         x.shape)

        if self._block_args.fused_conv:
            # If using fused MBConv, skip the expansion phase and use a regular conv.
            x = self._relu_fn(self._bn1(fused_conv_fn(x), training=training))
            logging.info('Conv2D: %s shape: %s', x.name, x.shape)
        else:
            # Otherwise, first apply expansion and then apply depthwise conv.
            if self._block_args.expand_ratio != 1:
                x = self._relu_fn(
                    self._bn0(expand_conv_fn(x), training=training))
                logging.info('Expand: %s shape: %s', x.name, x.shape)

            x = self._relu_fn(
                self._bn1(depthwise_conv_fn(x), training=training))
            logging.info('DWConv: %s shape: %s', x.name, x.shape)

        if self._has_se:
            with tf.variable_scope('se'):
                x = self._call_se(x)

        self.endpoints = {'expansion_output': x}

        x = self._bn2(project_conv_fn(x), training=training)
        # Add identity so that quantization-aware training can insert quantization
        # ops correctly.
        x = tf.identity(x)
        if self._clip_projection_output:
            x = tf.clip_by_value(x, -6, 6)
        if self._block_args.id_skip:
            if all(
                    s == 1 for s in self._block_args.strides
            ) and self._block_args.input_filters == self._block_args.output_filters:
                # Apply drop connect only if the skip connection is present.
                if survival_prob:
                    x = utils.drop_connect(x, training, survival_prob)
                x = tf.add(x, inputs)
        logging.info('Project: %s shape: %s', x.name, x.shape)
        return x
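
Example No. 4 delegates squeeze-and-excitation to self._call_se, which is not shown in this listing. It presumably mirrors the inline SE block from Example No. 1; a sketch under that assumption:

    def _call_se(self, input_tensor):
        """Squeeze-and-excitation: global pooling, bottleneck, sigmoid gating.

        Sketch based on the inline SE code in Example No. 1; the real method may differ.
        """
        se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True)
        se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor)))
        return tf.sigmoid(se_tensor) * input_tensor
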
Example No. 5
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points // self.npart
        last_point = -1
        last_feature_dim = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, feature_dim = h.shape

            ######## Build Graph #########
            last_point = n_points

            ######### Dynamic Graph Conv #########
            xyz = xyz.transpose(1, 2).contiguous()
            #print(h.shape) # batchsize x point_number x feature_dim
            h = h.transpose(1, 2).contiguous()
            for j in range(self.num_conv):
                index = self.num_conv * i + j
                ####### BN + ReLU #####
                if self.pre_act:
                    if self.norm == 'ln':
                        h = h.transpose(1, 2).contiguous()
                        h = self.bn[index](h)
                        h = h.transpose(1, 2).contiguous()
                    else:
                        h = self.bn[index](h)
                    h = F.leaky_relu(h, 0.2)

                ####### Graph Feature ###########
                if self.k == 1 and j == 0:
                    h = h.unsqueeze(-1)
                else:
                    if i == self.num_layers - 1:
                        if self.cluster == 'xyz':
                            h = get_graph_feature(xyz, h, k=self.k)
                        elif self.cluster == 'xyzrgb' or self.cluster == 'allxyzrgb':
                            h = get_graph_feature(torch.cat((xyz, h), 1),
                                                  h,
                                                  k=self.k)
                    else:
                        # Common Layers
                        if self.cluster == 'allxyzrgb':
                            h = get_graph_feature(torch.cat((xyz, h), 1),
                                                  h,
                                                  k=self.k)
                        else:
                            h = get_graph_feature(xyz, h, k=self.k)

                ####### Conv ##########
                if self.light and i > 0:
                    # Shuffle channels after the first layer (light mode).
                    h = channel_shuffle(h, 2)
                    h = self.conv[index](h)
                else:
                    h = self.conv[index](h)
                h = h.max(dim=-1, keepdim=False)[0]
                ####### BN + ReLU #####
                if not self.pre_act:
                    if self.norm == 'ln':
                        h = h.transpose(1, 2).contiguous()
                        h = self.bn[index](h)
                        h = h.transpose(1, 2).contiguous()
                    else:
                        h = self.bn[index](h)
                    h = F.leaky_relu(h, 0.2)

            h = h.transpose(1, 2).contiguous()
            #print(h.shape) # batchsize x point_number x feature_dim
            batch_size, n_points, feature_dim = h.shape

            ######### Residual Before Downsampling #############
            if self.id_skip == 1:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h,
                                     p=self.drop_connect_rate,
                                     training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            ######### PointNet++ MSG ########
            if feature_dim != last_feature_dim:
                h = h.transpose(1, 2).contiguous()
                xyz, h = self.sa[s2_count](xyz_input, h)
                h = h.transpose(1, 2).contiguous()
                if self.id_skip == 2:
                    h_input = pointnet2_utils.gather_operation(
                        h_input.transpose(1, 2).contiguous(),
                        pointnet2_utils.furthest_point_sample(
                            xyz_input, h.shape[1])).transpose(1,
                                                              2).contiguous()
            else:
                xyz = xyz.transpose(1, 2).contiguous()

            ######### Residual After Downsampling (Paper) #############
            if self.id_skip == 2:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h,
                                     p=self.drop_connect_rate,
                                     training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            if feature_dim != last_feature_dim:
                s2_count += 1
                last_feature_dim = feature_dim

            #print(xyz.shape, h.shape)
        if self.npart == 1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort
            #batch_size, n_points, _ = h.shape
            #y_index = torch.argsort(xyz[:, :, 1],dim = 1).view(batch_size * n_points, -1)
            #h = h.view(batch_size * n_points, -1)
            #h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:, :, i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h
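
Example No. 5 builds edge features with get_graph_feature(coords, feats, k=...), which is defined elsewhere. A minimal DGCNN-style sketch, under the assumption that the first argument drives the kNN graph and the second supplies the features gathered along its edges:

import torch

def knn(x, k):
    # x: (batch, dims, n_points); indices of the k nearest neighbors from
    # pairwise squared distances expanded as |a|^2 - 2 a.b + |b|^2.
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
    return pairwise_distance.topk(k=k, dim=-1)[1]  # (batch, n_points, k)

def get_graph_feature(coords, feats, k=20):
    # coords: (batch, c, n_points), used only to build the kNN graph.
    # feats:  (batch, d, n_points), features gathered along the graph edges.
    batch_size, num_dims, num_points = feats.shape
    idx = knn(coords, k=k)
    idx_base = torch.arange(0, batch_size, device=feats.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    feats_t = feats.transpose(2, 1).contiguous()            # (batch, n_points, d)
    neighbors = feats_t.view(batch_size * num_points, -1)[idx, :]
    neighbors = neighbors.view(batch_size, num_points, k, num_dims)
    center = feats_t.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    # Edge feature (neighbor - center, center): shape (batch, 2d, n_points, k).
    return torch.cat((neighbors - center, center), dim=3).permute(0, 3, 1, 2).contiguous()
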
Example No. 6
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points // self.npart
        last_point = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, _ = h.shape
            if self.k > 1:
                if i == self.num_layers - 1:
                    if self.cluster == 'xyz':
                        g = self.nng(xyz, istrain=istrain and self.graph_jitter)
                    elif self.cluster == 'rgb':
                        g = self.nng(h, istrain=istrain and self.graph_jitter)
                    elif self.cluster == 'xyzrgb':
                        g = self.nng(torch.cat((xyz, h), 2), istrain=istrain and self.graph_jitter)
                elif i == 0 or n_points != last_point:
                    g = self.nng(xyz, istrain=istrain and self.graph_jitter)
            last_point = n_points
            h = h.view(batch_size * n_points, -1)

            if self.k == 1:
                h = self.conv[i](h)
            elif self.conv_type == 'GatedGCN':
                h = self.conv[i](g, h, g.edata['feat'],
                                 snorm_n=1 / g.number_of_nodes(),
                                 snorm_e=1 / g.number_of_edges())
            else:
                h = self.conv[i](g, h)
            h = F.leaky_relu(h, 0.2)
            h = h.view(batch_size, n_points, -1)
            h = h.transpose(1, 2).contiguous()
            xyz, h = self.sa[i](xyz_input, h)
            h = h.transpose(1, 2).contiguous()
            #h = self.conv_s1[i](h)
            if self.id_skip and h.shape[1] <= self.init_points // 4:
                # We could use an identity mapping here or apply drop connect.
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h, p=self.drop_connect_rate, training=istrain)

                if h.shape[1] == n_points:
                    h = h_input + self.res_scale * h  # Here I borrow the idea from Inception-ResNet-v2
                elif h.shape[1] == n_points // 2:
                    h_input_s2 = pointnet2_utils.gather_operation(
                        h_input.transpose(1, 2).contiguous(),
                        pointnet2_utils.furthest_point_sample(xyz_input, h.shape[1])
                    ).transpose(1, 2).contiguous()
                    h = self.conv_s2[s2_count](h_input_s2) + self.res_scale * h
                    s2_count += 1
        if self.npart == 1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort 
            batch_size, n_points, _ = h.shape
            y_index = torch.argsort(xyz[:, :, 1], dim=1).view(batch_size * n_points, -1)
            h = h.view(batch_size * n_points, -1)
            h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling            
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:, :, i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h
Example No. 7
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points // self.npart
        last_point = -1
        last_feature_dim = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, feature_dim = h.shape

            ######## Build Graph #########
            if self.k > 1:
                if i == self.num_layers - 1:
                    if self.cluster == 'xyz':
                        g = self.nng(xyz, istrain=istrain, jitter=self.graph_jitter)
                    elif self.cluster == 'rgb':
                        g = self.nng(h, istrain=istrain, jitter=self.graph_jitter)
                    elif self.cluster == 'xyzrgb':
                        g = self.nng(torch.cat((xyz, h), 2), istrain=istrain, jitter=self.graph_jitter)
                elif n_points != last_point:
                    g = self.nng(xyz, istrain=istrain, jitter=self.graph_jitter)
            last_point = n_points

            ######### Dynamic Graph Conv #########
            if self.pre_act:
                if self.norm == 'ln':
                    h = self.bn[i](h)
                else:
                    h = h.transpose(1, 2).contiguous()
                    h = self.bn[i](h)
                    h = F.leaky_relu(h, 0.2)
                    h = h.transpose(1, 2).contiguous()

            h = h.view(batch_size * n_points, -1)
            if self.k == 1:
                h = self.conv[i](h)
            elif self.conv_type == 'GatedGCN':
                h = self.conv[i](g, h, g.edata['feat'],
                                 snorm_n=1 / g.number_of_nodes(),
                                 snorm_e=1 / g.number_of_edges())
            else:
                h = self.conv[i](g, h)

            if not self.pre_act:
                if self.norm == 'ln':
                    h = self.bn[i](h)
                h = F.leaky_relu(h, 0.2)

            h = h.view(batch_size, n_points, -1)
            batch_size, n_points, feature_dim = h.shape

            if self.id_skip:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h, p=self.drop_connect_rate, training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            ######### PointNet++ MSG ########
            if feature_dim != last_feature_dim:
                h = h.transpose(1, 2).contiguous()
                xyz, h = self.sa[s2_count](xyz_input, h)
                h = h.transpose(1, 2).contiguous()
                s2_count += 1
                last_feature_dim = feature_dim

        if self.npart == 1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort 
            #batch_size, n_points, _ = h.shape
            #y_index = torch.argsort(xyz[:, :, 1],dim = 1).view(batch_size * n_points, -1)
            #h = h.view(batch_size * n_points, -1)
            #h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling            
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:, :, i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h