Example #1
0
    def __init__(self,
                 rep_size=64,
                 n_classes=10,
                 mi_units=128,
                 encoder_name='resnet10',
                 image_channel=1,
                 margin=5,
                 alpha=0.33,
                 beta=0.33,
                 gamma=0.33):
        """Build the encoder, the MI estimator nets and the class-conditional head.

        Args:
            rep_size: Size of the encoder's output representation (passed as
                ``fc_size``); also the input channel count of the global MI net
                and the feature size of the class-conditional mixture.
            n_classes: Number of classes for the class-conditional mixture.
            mi_units: Width passed to both ``MI1x1ConvNet`` instances.
            encoder_name: Encoder identifier of the form ``'resnet<depth>'``,
                e.g. ``'resnet10'``; ``<depth>`` is forwarded to
                ``resnet.build_resnet_32x32``.
            image_channel: Number of input image channels for the encoder.
            margin, alpha, beta, gamma: Only stored as attributes; they are not
                used in this constructor (presumably loss hyper-parameters read
                by other methods -- confirm).

        Raises:
            ValueError: If ``encoder_name`` does not start with ``'resnet'`` or
                the remainder is not an integer depth.
        """
        super().__init__()
        self.rep_size = rep_size
        self.n_classes = n_classes
        self.mi_units = mi_units
        self.encoder_name = encoder_name
        self.margin = margin
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        # Parse the depth from 'resnet<depth>'. ``str.strip('resnet')`` would
        # remove any of the characters {r, e, s, n, t} from BOTH ends, so
        # malformed names like 'resnet10t' were silently read as depth 10.
        # Remove the exact prefix instead and fail loudly otherwise.
        prefix = 'resnet'
        bad_name_msg = (
            "encoder_name must look like 'resnet<depth>', got {!r}".format(
                encoder_name))
        if not encoder_name.startswith(prefix):
            raise ValueError(bad_name_msg)
        try:
            n = int(encoder_name[len(prefix):])
        except ValueError as err:
            raise ValueError(bad_name_msg) from err

        # build encoder; its final layer outputs a rep_size representation
        self.encoder = resnet.build_resnet_32x32(
            n, fc_size=rep_size,
            image_channel=image_channel)
        print('==> # encoder parameters {}'.format(cal_parameters(
            self.encoder)))

        # NOTE(review): the meaning of (2, -1) is not established in this
        # constructor; presumably indices of the encoder outputs used as the
        # local / global features -- confirm against the forward pass.
        self.task_idx = (2, -1)

        # NOTE(review): hard-coded; assumed to be the channel count of the
        # encoder's local feature map -- confirm against the encoder.
        local_size = 64

        # 1x1 conv performed on only channel dimension
        self.local_MInet = MI1x1ConvNet(local_size, self.mi_units)
        self.global_MInet = MI1x1ConvNet(self.rep_size, self.mi_units)

        self.class_conditional = ClassConditionalGaussianMixture(
            self.n_classes, self.rep_size)
        model = SDIM(rep_size=hps.rep_size,
                     mi_units=hps.mi_units,
                     encoder_name=hps.encoder_name,
                     image_channel=hps.image_channel).to(hps.device)

        checkpoint_path = os.path.join(
            hps.log_dir,
            'sdim_{}_{}_d{}.pth'.format(hps.encoder_name, hps.problem,
                                        hps.rep_size))
        model.load_state_dict(
            torch.load(checkpoint_path,
                       map_location=lambda storage, loc: storage))
    else:
        n_encoder_layers = int(hps.encoder_name.strip('resnet'))
        model = build_resnet_32x32(n=n_encoder_layers,
                                   fc_size=hps.n_classes,
                                   image_channel=hps.image_channel).to(
                                       hps.device)

        checkpoint_path = os.path.join(
            hps.log_dir, '{}_{}.pth'.format(hps.encoder_name, hps.problem))
        model.load_state_dict(
            torch.load(checkpoint_path,
                       map_location=lambda storage, loc: storage))

    print('Model name: {}'.format(hps.encoder_name))
    print('==>  # Model parameters: {}.'.format(cal_parameters(model)))

    if not os.path.exists(hps.log_dir):
        os.mkdir(hps.log_dir)

    if not os.path.exists(hps.attack_dir):