Example #1
def classifier(x, dropout):
    """
	AlexNet fully connected layers definition

	Args:
		x: tensor of shape [batch_size, width, height, channels]
		dropout: probability of non dropping out units

	Returns:
		fc3: 1000 linear tensor taken just before applying the softmax operation
			it is needed to feed it to tf.softmax_cross_entropy_with_logits()
		softmax: 1000 linear tensor representing the output probabilities of the image to classify

	"""
    pool5 = cnn(x)

    dim = pool5.get_shape().as_list()
    flat_dim = dim[1] * dim[2] * dim[3]  # 6 * 6 * 256
    flat = tf.reshape(pool5, [-1, flat_dim])

    with tf.name_scope('alexnet_classifier') as scope:
        with tf.name_scope('alexnet_classifier_fc1') as inner_scope:
            wfc1 = tu.weight([flat_dim, 4096], name='wfc1')
            bfc1 = tu.bias(0.0, [4096], name='bfc1')
            # Ternarize the weights; the scalar scale alpha is folded into the input.
            alpha_full_1 = compute_alpha(wfc1)
            wfc_1 = tenary_opration(wfc1)
            flat = tf.multiply(flat, alpha_full_1)
            fc1 = tf.add(tf.matmul(flat, wfc_1), bfc1)
            fc1 = tu.batch_norm(fc1)
            fc1 = selu(fc1)
            fc1 = tf.nn.dropout(fc1, dropout)

        with tf.name_scope('alexnet_classifier_fc2') as inner_scope:
            wfc2 = tu.weight([4096, 4096], name='wfc2')
            bfc2 = tu.bias(0.0, [4096], name='bfc2')
            alpha6 = compute_alpha(wfc2)
            wfc_2 = tenary_opration(wfc2)
            fc1 = tf.multiply(fc1, alpha6)
            fc2 = tf.add(tf.matmul(fc1, wfc_2), bfc2)
            fc2 = tu.batch_norm(fc2)
            fc2 = selu(fc2)
            fc2 = tf.nn.dropout(fc2, dropout)

        with tf.name_scope('alexnet_classifier_output') as inner_scope:
            wfc3 = tu.weight([4096, 1000], name='wfc3')
            bfc3 = tu.bias(0.0, [1000], name='bfc3')
            # The output layer is kept full-precision:
            # wfc3 = tenary_opration(wfc3)
            fc3 = tf.add(tf.matmul(fc2, wfc3), bfc3)
            softmax = tf.nn.softmax(fc3)

    return fc3, softmax
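
For context, here is a minimal usage sketch of this function, assuming TF 1.x-style placeholders; the labels tensor and the names below are illustrative, not taken from the source:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 224, 224, 3])
labels = tf.placeholder(tf.float32, [None, 1000])
keep_prob = tf.placeholder(tf.float32)

logits, probs = classifier(x, keep_prob)
# As the docstring notes: the loss takes the pre-softmax logits (fc3), not probs.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))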
Example #2
    def forward(self, batch):
        # feature_maps is for visualization with https://github.com/fornaxai/receptivefield
        self.feature_maps = []

        #                    B,  C, H,   W
        # batch:            [B0, 1, 112, 112]
        # => padded_batch   [B0, 1, 112, 112]
        if self.num_modalities > 0:
            B0, C, H, W, MOD = batch.shape
            batch_bm = batch.permute([0, 4, 1, 2, 3])
            # reshape (not view), since permute makes the tensor non-contiguous.
            batch_bm = batch_bm.reshape([B0 * MOD, C, H, W])
            # Merge the modality dim into the batch dim.
            # Then go through ordinary preprocessing (splitting of depth, merging depth into batch).
            # After extracting features, merge MOD sets of feature maps into one set.
            batch = batch_bm
        else:
            # MOD = 0 means there is no modality dimension
            # (i.e., images come in a single modality). MOD = 1 is not used, to
            # distinguish this case from a modality dimension holding only one modality.
            MOD = 0
            B0 = batch.shape[0]

        B, C, H, W = batch.shape

        # nonzero_mask: if '3': [B, 14, 14]; if '2': [B, 36, 36].
        nonzero_mask = self.get_mask(batch)

        if self.backbone_type.startswith('res'):
            batch_base_feats = self.backbone.ext_features(batch)
        elif self.backbone_type.startswith('eff'):
            feats_dict = self.backbone.extract_endpoints(batch)
            # Shapes, e.g.: reduction_1: [10, 16, 288, 288], reduction_2: [10, 24, 144, 144],
            # reduction_3: [10, 40, 72, 72], reduction_4: [10, 112, 36, 36], reduction_5: [10, 1280, 18, 18].
            batch_base_feats = (feats_dict['reduction_1'], feats_dict['reduction_2'],
                                feats_dict['reduction_3'], feats_dict['reduction_4'],
                                feats_dict['reduction_5'])
            # Corresponding stages in the EfficientNet paper, Table 1: 2, 3, 4, 6, 9.

        # vfeat_fpn: [B (B0*MOD), 1296, 1792]
        vfeat_fpn, vmask, H2, W2 = self.in_fpn_forward(batch_base_feats,
                                                       nonzero_mask, B)
        vfeat_origshape = vfeat_fpn.transpose(1, 2).view(B, -1, H2, W2)
        self.feature_maps.append(vfeat_origshape)

        if self.num_modalities > 0:
            # vfeat_fpn_MOD: [B0, MOD, 1296, 1792]
            vfeat_fpn_MOD = vfeat_fpn.view(B0, MOD, -1, self.trans_in_dim)
            # vfeat_fpn: [B0, 1296, 1792]
            # vfeat_fpn = self.mod_fuse_conv(vfeat_fpn_MOD).squeeze(1)
            vfeat_fpn = vfeat_fpn_MOD.max(dim=1)[0]
            # No need to normalize features here. Each feature vector in vfeat_fpn
            # will be layer-normed in SegtranInputFeatEncoder.

        # if self.in_fpn_layers == '234', xy_shape = (36, 36)
        # if self.in_fpn_layers == '34',  xy_shape = (14, 14)
        xy_shape = torch.Size((H2, W2))
        # xy_indices: [H2, W2, 2]
        xy_indices = get_all_indices(xy_shape, device=self.device)
        scale_H = H // H2
        scale_W = W // W2

        # H and W have to be exact multiples of H2 and W2.
        assert scale_H * H2 == H and scale_W * W2 == W, \
            "Feature map size (%d, %d) does not evenly divide input size (%d, %d)" % (H2, W2, H, W)

        if not self.scales_printed:
            print("\nImage scales: %dx%d. Voxels: %s" %
                  (scale_H, scale_W, list(vfeat_fpn.shape)))
            self.scales_printed = True

        scale = torch.tensor([[scale_H, scale_W]], device=self.device)
        # xy_indices: [1296, 2]
        # Rectify the scales on H, W.
        xy_indices = xy_indices.view([-1, 2]).float() * scale

        # voxels_pos: [B0, 1296, 2], "2" is coordinates.
        voxels_pos = xy_indices.unsqueeze(0).repeat((B0, 1, 1))

        # pos_embed = self.featemb(voxels_pos)
        # vfeat_fused: [B0, H2*W2, trans_out_dim]
        vfeat_fused = self.voxel_fusion(vfeat_fpn, voxels_pos,
                                        vmask.unsqueeze(2))
        for i in range(self.num_translayers):
            self.feature_maps.append(
                self.voxel_fusion.translayers[i].attention_scores)

        # vfeat_fused: [B0, H2, W2, trans_out_dim]
        vfeat_fused = vfeat_fused.view([B0, H2, W2, self.trans_out_dim])
        # [B0, H2, W2, C] => [B0, C, H2, W2]
        vfeat_fused = vfeat_fused.permute([0, 3, 1, 2])

        for i in range(self.num_translayers):
            layer_vfeat = self.voxel_fusion.layers_vfeat[i]
            layer_vfeat = layer_vfeat.view(
                [B0, H2, W2, self.translayer_dims[i + 1]])
            # layer_vfeat: [B0, H2, W2, C] => [B0, C, H2, W2]
            layer_vfeat = layer_vfeat.permute([0, 3, 1, 2])
            self.feature_maps.append(layer_vfeat)

        if self.do_out_fpn:
            vfeat_fused_fpn = self.out_fpn_forward(batch_base_feats,
                                                   vfeat_fused, B0)
            if self.posttrans_use_bn:
                vfeat_fused_fpn = batch_norm(vfeat_fused_fpn)
            trans_scores_small = self.out_conv(vfeat_fused_fpn)
        else:
            # scores: [B0, 2, 36, 36]
            # if vfeat_fpn is already 28*28 (in_fpn_layers=='234'),
            # then out_conv does not do upsampling.
            if self.posttrans_use_bn:
                vfeat_fused = batch_norm(vfeat_fused)
            trans_scores_small = self.out_conv(vfeat_fused)

        # full_scores: [B0, 2, 112, 112]
        trans_scores_up = F.interpolate(trans_scores_small,
                                        size=(H, W),
                                        mode='bilinear',
                                        align_corners=False)

        return trans_scores_up
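
The position grid above comes from get_all_indices, which is not shown in this snippet. A minimal sketch consistent with how it is used here (a [H, W, 2] grid of integer (y, x) coordinates, flattened later by view([-1, 2])) could look like this; it is an assumption, not the repo's actual implementation:

import torch

def get_all_indices(shape, device=None):
    # Build a [H, W, 2] tensor where entry (i, j) holds the coordinates (i, j).
    H, W = shape
    ys = torch.arange(H, device=device).view(-1, 1).expand(H, W)
    xs = torch.arange(W, device=device).view(1, -1).expand(H, W)
    return torch.stack([ys, xs], dim=-1)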
Example #3
def cnn(x):
    """
	AlexNet convolutional layers definition

	Args:
		x: tensor of shape [batch_size, width, height, channels]

	Returns:
		pool5: tensor with all convolutions, pooling and lrn operations applied

	"""
    with tf.name_scope('alexnet_cnn') as scope:
        with tf.name_scope('alexnet_cnn_conv1') as inner_scope:
            wcnn1 = tu.weight([11, 11, 3, 96], name='wcnn1')
            bcnn1 = tu.bias(0.0, [96], name='bcnn1')
            conv1 = tf.add(tu.conv2d(x, wcnn1, stride=(4, 4), padding='SAME'),
                           bcnn1)
            #conv1 = tu.batch_norm(conv1)
            conv1 = tu.relu(conv1)
            norm1 = tu.batch_norm(conv1)
            pool1 = tu.max_pool2d(norm1,
                                  kernel=[1, 3, 3, 1],
                                  stride=[1, 2, 2, 1],
                                  padding='VALID')

        with tf.name_scope('alexnet_cnn_conv2') as inner_scope:
            wcnn2 = tu.weight([5, 5, 96, 256], name='wcnn2')
            bcnn2 = tu.bias(1.0, [256], name='bcnn2')
            conv2 = tf.add(
                tu.conv2d(pool1, wcnn2, stride=(1, 1), padding='SAME'), bcnn2)
            #conv2 = tu.batch_norm(conv2)
            conv2 = tu.relu(conv2)
            norm2 = tu.batch_norm(conv2)
            pool2 = tu.max_pool2d(norm2,
                                  kernel=[1, 3, 3, 1],
                                  stride=[1, 2, 2, 1],
                                  padding='VALID')

        with tf.name_scope('alexnet_cnn_conv3') as inner_scope:
            wcnn3 = tu.weight([3, 3, 256, 384], name='wcnn3')
            bcnn3 = tu.bias(0.0, [384], name='bcnn3')
            conv3 = tf.add(
                tu.conv2d(pool2, wcnn3, stride=(1, 1), padding='SAME'), bcnn3)
            conv3 = tu.batch_norm(conv3)
            conv3 = tu.relu(conv3)

        with tf.name_scope('alexnet_cnn_conv4') as inner_scope:
            wcnn4 = tu.weight([3, 3, 384, 384], name='wcnn4')
            bcnn4 = tu.bias(1.0, [384], name='bcnn4')
            conv4 = tf.add(
                tu.conv2d(conv3, wcnn4, stride=(1, 1), padding='SAME'), bcnn4)
            conv4 = tu.batch_norm(conv4)
            conv4 = tu.relu(conv4)

        with tf.name_scope('alexnet_cnn_conv5') as inner_scope:
            wcnn5 = tu.weight([3, 3, 384, 256], name='wcnn5')
            bcnn5 = tu.bias(1.0, [256], name='bcnn5')
            conv5 = tf.add(
                tu.conv2d(conv4, wcnn5, stride=(1, 1), padding='SAME'), bcnn5)
            conv5 = tu.batch_norm(conv5)
            conv5 = tu.relu(conv5)
            pool5 = tu.max_pool2d(conv5,
                                  kernel=[1, 3, 3, 1],
                                  stride=[1, 2, 2, 1],
                                  padding='VALID')

        return pool5
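
The tu module used throughout these examples is a thin wrapper around TensorFlow ops. Its real contents are not shown here, but sketches consistent with the calls above might look like the following (an assumption; tu.batch_norm in particular may differ, e.g. in how it handles training vs. inference):

import tensorflow as tf

def weight(shape, name):
    # Trainable filter/matrix, small truncated-normal init as is typical for AlexNet.
    return tf.Variable(tf.truncated_normal(shape, stddev=0.01), name=name)

def bias(value, shape, name):
    return tf.Variable(tf.constant(value, shape=shape), name=name)

def conv2d(x, w, stride=(1, 1), padding='SAME'):
    return tf.nn.conv2d(x, w, strides=[1, stride[0], stride[1], 1], padding=padding)

def max_pool2d(x, kernel, stride, padding='VALID'):
    return tf.nn.max_pool(x, ksize=kernel, strides=stride, padding=padding)

def relu(x):
    return tf.nn.relu(x)

def batch_norm(x):
    return tf.layers.batch_normalization(x, training=True)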
Example #4
def cnn(x):
    """
	AlexNet convolutional layers definition

	Args:
		x: tensor of shape [batch_size, width, height, channels]

	Returns:
		pool5: tensor with all convolutions, pooling and lrn operations applied

	"""
    with tf.name_scope('alexnet_cnn') as scope:
        with tf.name_scope('alexnet_cnn_conv1') as inner_scope:
            wcnn1 = tu.weight([11, 11, 3, 96], name='wcnn1')
            bcnn1 = tu.bias(0.0, [96], name='bcnn1')
            # alpha1 = compute_alpha(wcnn1)
            # wcnn1 = tenary_opration(wcnn1)
            # wcnn1_1 =  tf.multiply(alpha1, wcnn1)
            conv1 = tf.add(tu.conv2d(x, wcnn1, stride=(4, 4), padding='SAME'),
                           bcnn1)
            conv1 = tu.batch_norm(conv1)
            conv1 = selu(conv1)
            # norm1 = tu.lrn(conv1, depth_radius=2, bias=1.0, alpha=2e-05, beta=0.75)
            pool1 = tu.max_pool2d(conv1,
                                  kernel=[1, 3, 3, 1],
                                  stride=[1, 2, 2, 1],
                                  padding='VALID')

        with tf.name_scope('alexnet_cnn_conv2') as inner_scope:
            wcnn2 = tu.weight([5, 5, 96, 256], name='wcnn2')
            bcnn2 = tu.bias(1.0, [256], name='bcnn2')
            alpha2 = compute_alpha(wcnn2)
            pool1 = tf.multiply(pool1, alpha2)
            wcnn_2 = tenary_opration(wcnn2)
            # wcnn_2 = tf.multiply(alpha2, wcnn2)
            conv2 = tf.add(
                tu.conv2d(pool1, wcnn_2, stride=(1, 1), padding='SAME'), bcnn2)
            conv2 = tu.batch_norm(conv2)
            conv2 = selu(conv2)
            # norm2 = tu.lrn(conv2, depth_radius=2, bias=1.0, alpha=2e-05, beta=0.75)
            pool2 = tu.max_pool2d(conv2,
                                  kernel=[1, 3, 3, 1],
                                  stride=[1, 2, 2, 1],
                                  padding='VALID')

        with tf.name_scope('alexnet_cnn_conv3') as inner_scope:
            wcnn3 = tu.weight([3, 3, 256, 384], name='wcnn3')
            bcnn3 = tu.bias(0.0, [384], name='bcnn3')
            alpha3 = compute_alpha(wcnn3)
            wcnn_3 = tenary_opration(wcnn3)
            pool2 = tf.multiply(pool2, alpha3)
            conv3 = tf.add(
                tu.conv2d(pool2, wcnn_3, stride=(1, 1), padding='SAME'), bcnn3)
            conv3 = tu.batch_norm(conv3)
            conv3 = selu(conv3)

        with tf.name_scope('alexnet_cnn_conv4') as inner_scope:
            wcnn4 = tu.weight([3, 3, 384, 384], name='wcnn4')
            bcnn4 = tu.bias(1.0, [384], name='bcnn4')
            alpha4 = compute_alpha(wcnn4)
            wcnn_4 = tenary_opration(wcnn4)
            conv3 = tf.multiply(conv3, alpha4)
            conv4 = tf.add(
                tu.conv2d(conv3, wcnn_4, stride=(1, 1), padding='SAME'), bcnn4)
            conv4 = tu.batch_norm(conv4)
            conv4 = selu(conv4)

        with tf.name_scope('alexnet_cnn_conv5') as inner_scope:
            wcnn5 = tu.weight([3, 3, 384, 256], name='wcnn5')
            bcnn5 = tu.bias(1.0, [256], name='bcnn5')
            alpha5 = compute_alpha(wcnn5)
            wcnn_5 = tenary_opration(wcnn5)
            conv4 = tf.multiply(conv4, alpha5)
            conv5 = tf.add(
                tu.conv2d(conv4, wcnn_5, stride=(1, 1), padding='SAME'), bcnn5)
            conv5 = tu.batch_norm(conv5)
            conv5 = selu(conv5)
            pool5 = tu.max_pool2d(conv5,
                                  kernel=[1, 3, 3, 1],
                                  stride=[1, 2, 2, 1],
                                  padding='VALID')

        return pool5
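
compute_alpha and tenary_opration are defined elsewhere in this project. Judging by how they are used (a scalar scale folded into the activations, weights quantized to {-1, 0, +1}), they appear to follow the Ternary Weight Networks recipe: threshold at roughly 0.7 of the mean |W|, with alpha the mean |W| over the surviving entries. A sketch under that assumption:

import tensorflow as tf

def compute_alpha(w):
    # TWN-style scaling factor: mean |w| over the entries above the threshold.
    thresh = 0.7 * tf.reduce_mean(tf.abs(w))
    mask = tf.cast(tf.abs(w) > thresh, tf.float32)
    return tf.reduce_sum(tf.abs(w) * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)

def tenary_opration(w):
    # Quantize weights to {-1, 0, +1} using the same threshold.
    thresh = 0.7 * tf.reduce_mean(tf.abs(w))
    return tf.sign(w) * tf.cast(tf.abs(w) > thresh, tf.float32)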
Example #5
def eval_robustness(args, net, refnet, dataloader, mask_prepred_mapping_func=None):
    AUG_DEG = args.robust_aug_degrees
    if not isinstance(AUG_DEG, collections.abc.Iterable):
        AUG_DEG = (AUG_DEG, AUG_DEG)

    if args.robustness_augs:
        augs = [ args.robustness_augs ]
        is_resize = [ False ]
    else:
        augs = [
            transforms.ColorJitter(brightness=AUG_DEG),
            transforms.ColorJitter(contrast=AUG_DEG),
            transforms.ColorJitter(saturation=AUG_DEG),
            transforms.Resize((192, 192)),
            transforms.Resize((432, 432)),
            transforms.Pad(0)   # Placeholder; for this aug, inputs are replaced with random noise below.
        ]
        is_resize = [ False, False, False, True, True, False ]

    num_augs = len(augs)
    num_iters = args.robust_sample_num // args.batch_size
    # on_pearsons: Pearson correlations between the original and augmented feature maps.
    on_pearsons = np.zeros((num_augs, net.num_vis_layers))
    # lr_old_pearsons / lr_new_pearsons: Pearson correlations between the left and right
    # halves of the feature maps.
    lr_old_pearsons = np.zeros((net.num_vis_layers))
    old_stds        = np.zeros((net.num_vis_layers))
    lr_new_pearsons = np.zeros((num_augs, net.num_vis_layers))
    new_stds        = np.zeros((num_augs, net.num_vis_layers))
    aug_counts      = np.zeros(num_augs) + 0.0001  # small epsilon to avoid division by zero
    print("Evaluating %d augs on %d layers of feature maps, %d samples" %(num_augs, net.num_vis_layers, args.robust_sample_num))
    do_BN = True
    orig_allcls_dice_sum    = np.zeros(args.num_classes - 1)
    aug_allcls_dice_sum     = np.zeros((num_augs, args.num_classes - 1))
    orig_sample_count       = 0
    aug_sample_counts       = np.zeros(num_augs) + 0.0001  # epsilon to avoid division by zero

    # Compare the feature maps from the same network.
    if refnet is None:
        refnet = net
        
    for it in tqdm(range(num_iters)):
        aug_idx = it % num_augs
        aug_counts[aug_idx] += 1
        aug = augs[aug_idx]
        dataloader.dataset.image_trans_func2 = transforms.Compose( [ aug ] + \
                                                                   dataloader.dataset.image_trans_func.transforms )

        # Draw one batch; a fresh iterator is created each time, so with a shuffling
        # dataloader this samples a random batch per iteration.
        batch = next(iter(dataloader))
        image_batch, image2_batch, mask_batch = batch['image'].cuda(), batch['image2'].cuda(), batch['mask'].cuda()
        image_batch = F.interpolate(image_batch, size=args.patch_size,
                                   mode='bilinear', align_corners=False)
        image2_batch = F.interpolate(image2_batch, size=args.patch_size,
                                   mode='bilinear', align_corners=False)
        if mask_prepred_mapping_func:
            mask_batch = mask_prepred_mapping_func(mask_batch)

        orig_input_size = mask_batch.shape[2:]
        if it == 0:
            print("Input size: {}, orig image size: {}".format(image_batch.shape[2:], orig_input_size))

        if aug_idx == 5:
            # The Pad(0) placeholder aug: replace the input images with random noise.
            image2_batch.normal_()

        with torch.no_grad():
            scores_raw = refnet(image_batch)
            scores_raw = F.interpolate(scores_raw, size=orig_input_size,
                                       mode='bilinear', align_corners=False)

            batch_allcls_dice = calc_batch_metric(scores_raw, mask_batch, args.num_classes, 0.5)
            orig_allcls_dice_sum    += batch_allcls_dice.sum(axis=0)
            orig_sample_count       += len(batch_allcls_dice)

            orig_features = copy.copy(refnet.feature_maps)
            orig_bn_features = list(orig_features)
            net.feature_maps = []
            scores_raw2 = net(image2_batch)

            batch_allcls_dice = calc_batch_metric(scores_raw2, mask_batch, args.num_classes, 0.5)
            aug_allcls_dice_sum[aug_idx]    += batch_allcls_dice.sum(axis=0)
            aug_sample_counts[aug_idx]      += len(batch_allcls_dice)

            aug_features  = copy.copy(net.feature_maps)
            aug_bn_features  = list(aug_features)
            net.feature_maps = []
            for layer_idx in range(net.num_vis_layers):
                if is_resize[aug_idx] and orig_features[layer_idx].shape != aug_features[layer_idx].shape:
                    aug_features[layer_idx] = F.interpolate(aug_features[layer_idx],
                                                            size=orig_features[layer_idx].shape[-2:],
                                                            mode='bilinear', align_corners=False)

                if do_BN and orig_features[layer_idx].dim() == 4:
                    orig_bn_features[layer_idx] = batch_norm(orig_features[layer_idx])
                    aug_bn_features[layer_idx]  = batch_norm(aug_features[layer_idx])

                pear = pearson(orig_bn_features[layer_idx], aug_bn_features[layer_idx])
                on_pearsons[aug_idx, layer_idx]     += pear
                lr_old_pearsons[layer_idx] += lr_pearson(orig_bn_features[layer_idx])
                lr_new_pearsons[aug_idx, layer_idx] += lr_pearson(aug_bn_features[layer_idx])

                # 4D feature maps. Assume dim 1 is the channel dim.
                if orig_features[layer_idx].dim() == 4:
                    chan_num = orig_features[layer_idx].shape[1]
                    old_std  = orig_features[layer_idx].transpose(0, 1).reshape(chan_num, -1).std(dim=1).mean()
                    new_std  = aug_features[layer_idx].transpose(0, 1).reshape(chan_num, -1).std(dim=1).mean()
                else:
                    old_std  = orig_features[layer_idx].std()
                    new_std  = aug_features[layer_idx].std()
                old_stds[layer_idx] += old_std
                new_stds[aug_idx, layer_idx] += new_std

    aug_counts = np.expand_dims(aug_counts, 1)
    aug_sample_counts = np.expand_dims(aug_sample_counts, 1)
    on_pearsons /= aug_counts
    lr_old_pearsons /= num_iters
    lr_new_pearsons /= aug_counts
    old_stds /= num_iters
    new_stds /= aug_counts

    orig_allcls_avg_metric = orig_allcls_dice_sum / orig_sample_count
    aug_allcls_avg_metric  = aug_allcls_dice_sum  / aug_sample_counts

    print('Orig dices: ', end='')
    for cls in range(1, args.num_classes):
        orig_dice = orig_allcls_avg_metric[cls-1]
        print('%.3f ' %(orig_dice), end='')
    print()

    for aug_idx in range(num_augs):
        print('Aug %d dices: ' %aug_idx, end='')
        for cls in range(1, args.num_classes):
            aug_dice = aug_allcls_avg_metric[aug_idx, cls-1]
            print('%.3f ' %(aug_dice), end='')
        print()

    for layer_idx in range(net.num_vis_layers):
        print("%d: Orig LR P %.3f, Std %.3f" %(layer_idx, lr_old_pearsons[layer_idx], old_stds[layer_idx]))

    for aug_idx in range(num_augs):
        print(augs[aug_idx])
        for layer_idx in range(net.num_vis_layers):
            print("%d: ON P %.3f, LR P %.3f, Std %.3f" %(layer_idx,
                            on_pearsons[aug_idx, layer_idx],        # old/new    pearson
                            lr_new_pearsons[aug_idx, layer_idx],    # left/right pearson
                            new_stds[aug_idx, layer_idx]))
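
pearson and lr_pearson are not shown in this example. From the comments above (correlation between original and augmented feature maps, and between the left and right halves of one map), minimal sketches could look like this; these are assumptions, not the repo's code:

import torch

def pearson(a, b):
    # Pearson correlation between two tensors, flattened and mean-centered.
    a = a.flatten().float()
    b = b.flatten().float()
    a = a - a.mean()
    b = b - b.mean()
    return ((a * b).sum() / (a.norm() * b.norm() + 1e-8)).item()

def lr_pearson(feat):
    # Correlation between the left and right halves along the width dim.
    half = feat.shape[-1] // 2
    return pearson(feat[..., :half], feat[..., -half:])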