Example #1
    def __init__(self, args):
        self.reconstruction_path = args.reconstruction_path
        if not os.path.exists(self.reconstruction_path):
            os.makedirs(self.reconstruction_path)

        self.beta = args.beta
        self.train_batch_size = args.recon_train_batch_size
        self.test_batch_size = args.test_batch_size
        self.epochs = args.epochs
        self.early_stop = args.early_stop
        self.early_stop_observation_period = args.early_stop_observation_period
        self.use_scheduler = False
        self.print_training = args.print_training
        self.class_num = args.class_num
        self.disentangle_with_reparameterization = args.disentangle_with_reparameterization
        self.share_encoder = args.share_encoder
        self.train_flag = False

        # self.small_recon_weight = 0.001
        self.small_recon_weight = 0.01
        # self.small_recon_weight = 0.1

        self.z_dim = args.z_dim

        # class (pos/neg) mem (pos/neg)
        self.encoder_name_list = ['pn', 'pp', 'np', 'nn']

        self.z_idx = dict()
        self.base_z_dim = int(self.z_dim / 4)
        for encoder_idx, encoder_name in enumerate(self.encoder_name_list):
            self.z_idx[encoder_name] = range(encoder_idx * self.base_z_dim,
                                             (encoder_idx + 1) *
                                             self.base_z_dim)

        # Enc/Dec
        self.encoders = dict()
        if args.dataset in ['MNIST', 'Fashion-MNIST', 'CIFAR-10', 'SVHN']:
            if args.dataset in ['MNIST', 'Fashion-MNIST']:
                self.num_channels = 1
            elif args.dataset in ['CIFAR-10', 'SVHN']:
                self.num_channels = 3

            for encoder_name in self.encoder_name_list:
                self.encoders[encoder_name] = module.VAEConvEncoder(
                    self.z_dim, self.num_channels)

            self.decoder = module.VAEConvDecoder(self.z_dim, self.num_channels)

        # Discriminators
        self.class_discs = dict()
        self.membership_discs = dict()
        for encoder_name in self.encoder_name_list:
            self.class_discs[encoder_name] = module.ClassDiscriminator(
                self.base_z_dim, args.class_num)
            self.membership_discs[encoder_name] = module.MembershipDiscriminator(
                self.base_z_dim + args.class_num, 1)

        # Optimizer
        self.encoders_opt = dict()
        self.class_discs_opt = dict()
        self.membership_discs_opt = dict()
        for encoder_name in self.encoder_name_list:
            self.encoders_opt[encoder_name] = optim.Adam(
                self.encoders[encoder_name].parameters(),
                lr=args.recon_lr,
                betas=(0.5, 0.999))
            self.class_discs_opt[encoder_name] = optim.Adam(
                self.class_discs[encoder_name].parameters(),
                lr=args.disc_lr,
                betas=(0.5, 0.999))
            self.membership_discs_opt[encoder_name] = optim.Adam(
                self.membership_discs[encoder_name].parameters(),
                lr=args.disc_lr,
                betas=(0.5, 0.999))

        self.decoder_opt = optim.Adam(self.decoder.parameters(),
                                      lr=args.recon_lr,
                                      betas=(0.5, 0.999))

        # Loss
        self.recon_loss = self.get_loss_function()
        self.class_loss = nn.CrossEntropyLoss(reduction='sum')
        self.membership_loss = nn.BCEWithLogitsLoss(reduction='sum')

        self.weights = {
            'recon': args.recon_weight,
            'class_pos': args.class_pos_weight,
            'class_neg': args.class_neg_weight,
            'membership_pos': args.membership_pos_weight,
            'membership_neg': args.membership_neg_weight,
        }

        # To device
        self.device = torch.device("cuda:{}".format(args.gpu_id))
        for encoder_name in self.encoder_name_list:
            self.encoders[encoder_name] = self.encoders[encoder_name].to(
                self.device)
            self.class_discs[encoder_name] = self.class_discs[encoder_name].to(
                self.device)
            self.membership_discs[encoder_name] = (
                self.membership_discs[encoder_name].to(self.device))
        self.decoder = self.decoder.to(self.device)

        self.disentangle = (
            self.weights['class_pos'] + self.weights['class_neg'] +
            self.weights['membership_pos'] + self.weights['membership_neg'] >
            0)

        self.start_epoch = 0
        self.best_valid_loss = float("inf")
        # self.train_loss = 0
        self.early_stop_count = 0

        self.class_acc_dict = {
            'pn': 0.,
            'pp': 0.,
            'np': 0.,
            'nn': 0.,
        }
        self.membership_acc_dict = {
            'pn': 0.,
            'pp': 0.,
            'np': 0.,
            'nn': 0.,
        }
        self.best_class_acc_dict = {}
        self.best_membership_acc_dict = {}

        if 'cuda' in str(self.device):
            cudnn.benchmark = True

        if args.resume:
            print('==> Resuming from checkpoint..')
            try:
                self.load()
            except FileNotFoundError:
                print(
                    'There is no pre-trained model; Train model from scratch')
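
A minimal usage sketch for the constructor above. Only the attribute names are taken from the snippet; the class name Reconstructor and every value below are hypothetical placeholders.

    from argparse import Namespace

    # Hypothetical argument values; only the attribute names come from Example #1.
    args = Namespace(
        reconstruction_path='./reconstruction',
        beta=1.0,
        recon_train_batch_size=64,
        test_batch_size=100,
        epochs=100,
        early_stop=True,
        early_stop_observation_period=10,
        print_training=True,
        class_num=10,
        disentangle_with_reparameterization=True,
        share_encoder=False,
        z_dim=64,  # split into four equal chunks, so it should be divisible by 4
        dataset='CIFAR-10',
        recon_lr=1e-3,
        disc_lr=1e-4,
        recon_weight=1.0,
        class_pos_weight=1.0,
        class_neg_weight=1.0,
        membership_pos_weight=1.0,
        membership_neg_weight=1.0,
        gpu_id=0,
        resume=False,
    )

    # model = Reconstructor(args)  # 'Reconstructor' is a placeholder for the enclosing class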
Example #2
    def __init__(self, args):
        self.reconstruction_path = args.reconstruction_path
        if not os.path.exists(self.reconstruction_path):
            os.makedirs(self.reconstruction_path)

        self.beta = args.beta
        self.train_batch_size = args.recon_train_batch_size
        self.test_batch_size = args.test_batch_size
        self.epochs = args.epochs
        self.early_stop = args.early_stop
        self.early_stop_observation_period = args.early_stop_observation_period
        self.use_scheduler = False
        self.print_training = args.print_training
        self.class_num = args.class_num
        self.disentangle_with_reparameterization = args.disentangle_with_reparameterization

        self.z_dim = args.z_dim
        self.disc_input_dim = int(self.z_dim / 2)

        self.class_idx = range(0, self.disc_input_dim)
        self.membership_idx = range(self.disc_input_dim, self.z_dim)

        self.nets = dict()

        if args.dataset in ['MNIST', 'Fashion-MNIST', 'CIFAR-10', 'SVHN']:
            if args.dataset in ['MNIST', 'Fashion-MNIST']:
                self.num_channels = 1
            elif args.dataset in ['CIFAR-10', 'SVHN']:
                self.num_channels = 3

            self.nets['encoder'] = module.VAEConvEncoder(
                self.z_dim, self.num_channels)
            self.nets['decoder'] = module.VAEConvDecoder(
                self.z_dim, self.num_channels)

        elif args.dataset in ['adult', 'location']:
            self.nets['encoder'] = module.VAEFCEncoder(args.encoder_input_dim,
                                                       self.z_dim)
            self.nets['decoder'] = module.FCDecoder(args.encoder_input_dim,
                                                    self.z_dim)

        self.discs = {
            'class_fz': module.ClassDiscriminator(self.z_dim, args.class_num),
            'class_cz': module.ClassDiscriminator(self.disc_input_dim, args.class_num),
            'class_mz': module.ClassDiscriminator(self.disc_input_dim, args.class_num),
            'membership_fz': module.MembershipDiscriminator(self.z_dim + args.class_num, 1),
            'membership_cz': module.MembershipDiscriminator(self.disc_input_dim + args.class_num, 1),
            'membership_mz': module.MembershipDiscriminator(self.disc_input_dim + args.class_num, 1),
        }

        self.recon_loss = self.get_loss_function()
        self.class_loss = nn.CrossEntropyLoss(reduction='sum')
        self.membership_loss = nn.BCEWithLogitsLoss(reduction='sum')

        # optimizer
        self.optimizer = dict()
        for net_type in self.nets:
            self.optimizer[net_type] = optim.Adam(
                self.nets[net_type].parameters(),
                lr=args.recon_lr,
                betas=(0.5, 0.999))
        self.discriminator_lr = args.disc_lr
        for disc_type in self.discs:
            self.optimizer[disc_type] = optim.Adam(
                self.discs[disc_type].parameters(),
                lr=self.discriminator_lr,
                betas=(0.5, 0.999))

        self.weights = {
            'recon': args.recon_weight,
            'class_fz': args.class_fz_weight,
            'class_cz': args.class_cz_weight,
            'class_mz': args.class_mz_weight,
            'membership_fz': args.membership_fz_weight,
            'membership_cz': args.membership_cz_weight,
            'membership_mz': args.membership_mz_weight,
        }

        self.scheduler_enc = StepLR(self.optimizer['encoder'],
                                    step_size=50,
                                    gamma=0.1)
        self.scheduler_dec = StepLR(self.optimizer['decoder'],
                                    step_size=50,
                                    gamma=0.1)

        # to device
        self.device = torch.device("cuda:{}".format(args.gpu_id))
        for net_type in self.nets:
            self.nets[net_type] = self.nets[net_type].to(self.device)
        for disc_type in self.discs:
            self.discs[disc_type] = self.discs[disc_type].to(self.device)

        self.disentangle = (
            self.weights['class_fz'] + self.weights['class_cz'] +
            self.weights['class_mz'] + self.weights['membership_fz'] +
            self.weights['membership_cz'] + self.weights['membership_mz'] > 0)

        self.start_epoch = 0
        self.best_valid_loss = float("inf")
        # self.train_loss = 0
        self.early_stop_count = 0

        self.acc_dict = {
            'class_fz': 0,
            'class_cz': 0,
            'class_mz': 0,
            'membership_fz': 0,
            'membership_cz': 0,
            'membership_mz': 0,
        }
        self.best_acc_dict = {}

        if 'cuda' in str(self.device):
            cudnn.benchmark = True

        if args.resume:
            print('==> Resuming from checkpoint..')
            try:
                self.load()
            except FileNotFoundError:
                print(
                    'There is no pre-trained model; Train model from scratch')
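
A small sketch of the latent-space split used in Example #2: the first half of z (class_idx) feeds the 'cz' discriminators, the second half (membership_idx) the 'mz' discriminators, and the full code the 'fz' ones. Concatenating a one-hot label for the membership discriminators is an assumption inferred from their input dimension (disc_input_dim + class_num); shapes and values are placeholders.

    import torch
    import torch.nn.functional as F

    z_dim, class_num, batch = 64, 10, 8   # placeholder sizes
    disc_input_dim = z_dim // 2
    class_idx = range(0, disc_input_dim)
    membership_idx = range(disc_input_dim, z_dim)

    z = torch.randn(batch, z_dim)                   # a batch of latent codes
    labels = torch.randint(0, class_num, (batch,))
    one_hot = F.one_hot(labels, class_num).float()

    cz = z[:, list(class_idx)]        # class partition      -> 'class_cz'
    mz = z[:, list(membership_idx)]   # membership partition -> 'class_mz'
    fz = z                            # full latent code     -> 'class_fz'

    # Membership discriminators expect disc_input_dim + class_num (or z_dim + class_num)
    # inputs, which suggests the one-hot label is concatenated to the code.
    membership_cz_input = torch.cat([cz, one_hot], dim=1)
    membership_mz_input = torch.cat([mz, one_hot], dim=1)
    membership_fz_input = torch.cat([fz, one_hot], dim=1)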
Example #3
    def __init__(self, args):
        self.reconstruction_path = args.reconstruction_path
        if not os.path.exists(self.reconstruction_path):
            os.makedirs(self.reconstruction_path)

        self.beta = args.beta
        self.train_batch_size = args.recon_train_batch_size
        self.test_batch_size = args.test_batch_size
        self.epochs = args.epochs
        # self.early_stop = args.early_stop_recon
        self.early_stop = False
        self.early_stop_observation_period = args.early_stop_observation_period
        self.print_training = args.print_training
        self.class_num = args.class_num
        self.disentangle_with_reparameterization = args.disentangle_with_reparameterization
        self.share_encoder = args.share_encoder
        self.train_flag = False
        self.resume = args.resume
        self.adversarial_loss_mode = args.adversarial_loss_mode
        self.gradient_penalty_weight = args.gradient_penalty_weight
        self.reduction = 'sum'
        # self.reduction = 'mean'
        self.scheduler_type = args.scheduler_type 

        self.disentanglement_start_epoch = 0
        self.save_step_size = 100 
        self.scheduler_step_size = 100 

        self.small_recon_weight = args.small_recon_weight
        self.z_dim = args.z_dim

        # class (pos/neg) mem (pos/neg)
        self.encoder_name_list = ['pn', 'pp', 'np', 'nn']

        self.z_idx = dict()
        self.base_z_dim = int(self.z_dim / 4)
        for encoder_idx, encoder_name in enumerate(self.encoder_name_list):
            self.z_idx[encoder_name] = range(encoder_idx * self.base_z_dim, (encoder_idx + 1) * self.base_z_dim)

        # Enc/Dec
        self.encoders = dict()
        if args.dataset in ['MNIST', 'Fashion-MNIST', 'CIFAR-10', 'SVHN']:
            if args.dataset in ['MNIST', 'Fashion-MNIST']:
                self.num_channels = 1
            elif args.dataset in ['CIFAR-10', 'SVHN']:
                self.num_channels = 3

            for encoder_name in self.encoder_name_list:
                self.encoders[encoder_name] = module.VAEConvEncoder(self.z_dim, self.num_channels)

            self.decoder = module.VAEConvDecoder(self.z_dim, self.num_channels)

        # Discriminators
        self.class_discs = dict()
        self.membership_discs = dict()
        for encoder_name in self.encoder_name_list:
            # self.class_discs[encoder_name] = module.ClassDiscriminator(self.base_z_dim, args.class_num)
            # self.membership_discs[encoder_name] = module.MembershipDiscriminator(self.base_z_dim + args.class_num, 1)
            self.class_discs[encoder_name] = module.ClassDiscriminatorImproved(self.base_z_dim, args.class_num)
            self.membership_discs[encoder_name] = module.MembershipDiscriminatorImproved(self.base_z_dim, args.class_num)
        self.rf_disc = module.Discriminator()

        # Optimizer
        self.encoders_opt = dict()
        self.class_discs_opt = dict()
        self.membership_discs_opt = dict()
        for encoder_name in self.encoder_name_list:
            self.encoders_opt[encoder_name] = optim.Adam(self.encoders[encoder_name].parameters(), lr=args.recon_lr, betas=(0.5, 0.999))
            self.class_discs_opt[encoder_name] = optim.Adam(self.class_discs[encoder_name].parameters(), lr=args.recon_lr, betas=(0.5, 0.999))
            self.membership_discs_opt[encoder_name] = optim.Adam(self.membership_discs[encoder_name].parameters(), lr=args.recon_lr, betas=(0.5, 0.999))

        self.decoder_opt = optim.Adam(self.decoder.parameters(), lr=args.recon_lr, betas=(0.5, 0.999))
        self.rf_disc_opt = optim.Adam(self.rf_disc.parameters(), lr=args.recon_lr, betas=(0.5, 0.999))

        # Scheduler 
        self.encoders_opt_scheduler = dict()
        self.class_discs_opt_scheduler = dict()
        self.membership_discs_opt_scheduler = dict()

        if self.early_stop:
            for encoder_name in self.encoder_name_list:
                self.encoders_opt_scheduler[encoder_name] = ReduceLROnPlateau(self.encoders_opt[encoder_name], 'min', patience=self.early_stop_observation_period, threshold=0)
                self.class_discs_opt_scheduler[encoder_name] = ReduceLROnPlateau(self.class_discs_opt[encoder_name], 'min', patience=self.early_stop_observation_period, threshold=0)
                self.membership_discs_opt_scheduler[encoder_name] = ReduceLROnPlateau(self.membership_discs_opt[encoder_name], 'min', patience=self.early_stop_observation_period, threshold=0)

            self.decoder_opt_scheduler = ReduceLROnPlateau(self.decoder_opt, 'min', patience=self.early_stop_observation_period, threshold=0)
            self.rf_disc_opt_scheduler = ReduceLROnPlateau(self.rf_disc_opt, 'min', patience=self.early_stop_observation_period, threshold=0)
        else:
            for encoder_name in self.encoder_name_list:
                # self.encoders_opt_scheduler[encoder_name] = StepLR(self.encoders_opt[encoder_name], self.scheduler_step_size)
                # self.class_discs_opt_scheduler[encoder_name] = StepLR(self.class_discs_opt[encoder_name], self.scheduler_step_size)
                # self.membership_discs_opt_scheduler[encoder_name] = StepLR(self.membership_discs_opt[encoder_name], self.scheduler_step_size)
                self.encoders_opt_scheduler[encoder_name] = self.get_scheduler(self.encoders_opt[encoder_name])
                self.class_discs_opt_scheduler[encoder_name] = self.get_scheduler(self.class_discs_opt[encoder_name])
                self.membership_discs_opt_scheduler[encoder_name] = self.get_scheduler(self.membership_discs_opt[encoder_name])

            # self.decoder_opt_scheduler = StepLR(self.decoder_opt, self.scheduler_step_size)
            # self.rf_disc_opt_scheduler = StepLR(self.rf_disc_opt, self.scheduler_step_size)
            self.decoder_opt_scheduler = self.get_scheduler(self.decoder_opt)
            self.rf_disc_opt_scheduler = self.get_scheduler(self.rf_disc_opt)


        # Loss
        self.vae_loss = self.get_loss_function()
        self.class_loss = nn.CrossEntropyLoss(reduction=self.reduction)
        self.membership_loss = nn.BCEWithLogitsLoss(reduction=self.reduction)
        self.bce_loss = nn.BCEWithLogitsLoss(reduction=self.reduction)

        self.weights = {
            'recon': args.recon_weight,
            'real_fake': args.real_fake_weight,
            'class_pos': args.class_pos_weight,
            'class_neg': args.class_neg_weight,
            'membership_pos': args.membership_pos_weight,
            'membership_neg': args.membership_neg_weight,
        }

        # To device
        self.device = torch.device("cuda:{}".format(args.gpu_id))
        for encoder_name in self.encoder_name_list:
            self.encoders[encoder_name] = self.encoders[encoder_name].to(self.device)
            self.class_discs[encoder_name] = self.class_discs[encoder_name].to(self.device)
            self.membership_discs[encoder_name] = self.membership_discs[encoder_name].to(self.device)
        self.decoder = self.decoder.to(self.device)
        self.rf_disc = self.rf_disc.to(self.device)

        self.start_epoch = 0
        self.best_valid_loss = float("inf")
        self.early_stop_count = 0
        self.early_stop_count_total = 0
        self.EARLY_STOP_COUNT_TOTAL_MAX = 4 

        self.class_acc_dict = {
            'pn': 0., 'pp': 0., 'np': 0., 'nn': 0.,
        }
        self.membership_acc_dict = {
            'pn': 0., 'pp': 0., 'np': 0., 'nn': 0.,
        }
        # self.class_loss_dict = {
        #     'pn': 0., 'pp': 0., 'np': 0., 'nn': 0.,
        # }
        # self.membership_loss_dict = {
        #     'pn': 0., 'pp': 0., 'np': 0., 'nn': 0.,
        # }
        self.best_class_acc_dict = {}
        self.best_membership_acc_dict = {}

        if 'cuda' in str(self.device):
            cudnn.benchmark = True
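
Example #3 relies on a get_scheduler helper (driven by args.scheduler_type and self.scheduler_step_size) that is not part of the snippet. A minimal sketch of what it might look like, assuming scheduler_type selects between the StepLR seen in the commented-out lines and an exponential decay; the 'step' / 'exp' values and the 0.99 gamma are hypothetical.

    from torch.optim.lr_scheduler import StepLR, ExponentialLR

    def get_scheduler(optimizer, scheduler_type, step_size):
        # 'step' mirrors the commented-out StepLR(..., scheduler_step_size) calls above;
        # 'exp' and its gamma are purely illustrative.
        if scheduler_type == 'step':
            return StepLR(optimizer, step_size=step_size)
        if scheduler_type == 'exp':
            return ExponentialLR(optimizer, gamma=0.99)
        return None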