Code Example #1
File: trainer.py  Project: xinshi-chen/l2stop
    def _valid_epoch(self, epoch):
        """
        validation after training an epoch
        :return:
        """
        self.score_net.eval()

        # evaluate on a single batch drawn from the test loader
        x, y = next(iter(self.data_loader.test_loader))
        x = x.to(self.device)
        y = y.to(self.device)
        batch_size = y.shape[0]
        with torch.no_grad():
            xhs = self.sdn_model(x)
            # internal_fm = self.sdn_model.internal_fm
            # self.sdn_model.internal_fm = [None]*len(internal_fm)
            # placeholder feature maps; the real internal feature maps are disabled above
            internal_fm = torch.rand(2, 2)
            scores = self.score_net(x, internal_fm, xhs)
            # scores = self.score_net(x, xhs)

            stop_idx = self.q_posterior(self.args.policy_type,
                                        scores,
                                        stochastic=False,
                                        device=self.device)
            q = self.q_posterior(self.args.policy_type,
                                 scores,
                                 stochastic=True,
                                 device=self.device)

            if epoch in (20, 50, 70, 99):
                # stop_idx was already computed above; only the stochastic
                # posterior is needed for this accuracy check
                q_p = self.q_posterior(self.args.policy_type,
                                       scores,
                                       stochastic=True,
                                       device=self.device)
                q_p_idx = torch.argmax(q_p, dim=-1)
                p_true, _ = self.true_posterior(self.args, xhs, y)
                p_true_b = max_onehot(p_true, dim=-1, device=self.device)
                p_true_idx = torch.argmax(p_true_b, dim=-1)
                print('Here is the policy classification validation accuracy:')
                print(data.accuracy(q_p, p_true_idx))
                # pdb.set_trace()

            # validation loss
            if self.args.kl_type == 'forward':
                loss, _ = self.forward_kl_loss(y, xhs, scores, p_det=False)
            else:
                assert self.args.kl_type == 'backward'
                loss, _ = self.backward_kl_loss(y, xhs, scores)

            if self.args.stochastic:
                log = {'val loss': loss, 'sto q': torch.mean(q, dim=0)}
            else:
                log = {'val loss': loss, 'det q': torch.mean(stop_idx, dim=0)}
        return log, log
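Note: max_onehot is a project helper that these excerpts do not show. A minimal sketch of what it presumably does (a one-hot vector at each row's argmax), assuming the signature used above, so the snippets can be run standalone:

import torch

def max_onehot(p, dim=-1, device='cpu'):
    # hypothetical reimplementation of the project's helper:
    # 1 at the argmax along `dim`, 0 elsewhere
    idx = torch.argmax(p, dim=dim, keepdim=True)
    return torch.zeros_like(p).to(device).scatter_(dim, idx, 1.0)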
Code Example #2
File: trainer.py  Project: xinshi-chen/l2stop
    def forward_kl_loss(self, y, xhs, scores, p_det=True):
        batch_size = y.shape[0]
        # true posterior
        p_true, _ = self.true_posterior(self.args, xhs, y)
        assert batch_size == p_true.shape[0]
        p = torch.stack([p_true[:, t] for t in self.nz_post.values()], dim=1)
        # mse = torch.stack([mse_all[:, t] for t in self.nz_post.values()], dim=1)

        if p_det:
            p = max_onehot(p, dim=-1, device=self.device)
        # parametric log posterior
        log_q_pi = self.log_q_posterior(self.args.policy_type, scores)
        assert batch_size == log_q_pi.shape[0]
        # pdb.set_trace()
        loss = -torch.sum(p * log_q_pi, dim=-1).mean()
        return loss, loss
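Since p is fixed with respect to the score net, the loss above is the cross-entropy -sum_t p_t * log q_t, which differs from the forward KL(p || q) only by the entropy of p, a constant. A small self-contained check of that identity with toy numbers (not project code):

import torch
import torch.nn.functional as F

p = torch.tensor([[0.7, 0.2, 0.1]])               # toy "true posterior"
log_q = F.log_softmax(torch.randn(1, 3), dim=-1)  # toy log q
ce = -(p * log_q).sum(dim=-1).mean()              # the loss used above
kl = F.kl_div(log_q, p, reduction='batchmean')    # forward KL(p || q)
entropy = -(p * p.log()).sum(dim=-1).mean()       # constant w.r.t. q
assert torch.allclose(ce, kl + entropy, atol=1e-6)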
Code Example #3
File: trainer.py  Project: xinshi-chen/l2stop
    @staticmethod
    def q_posterior(policy_type, scores, stochastic=True, device='cuda'):
        """Distribution over stopping positions induced by the scores."""
        if (policy_type == 'multiclass') or (policy_type == 'confidence'):
            if stochastic:
                return F.softmax(scores, dim=-1)
            else:
                return max_onehot(scores, dim=-1, device=device)
        if policy_type == 'sequential':
            batch_size, num_train_post = scores.shape
            q_pi = []
            # pi[:, i] is the probability of continuing past exit i
            pi = torch.sigmoid(scores)
            if not stochastic:
                pi = (pi > 0.5).float()
            q_cont = torch.ones(batch_size).to(device)
            for i in range(num_train_post):
                # prob of stopping at exit i
                q_pi.append(((1 - pi[:, i]) * q_cont).view(-1, 1))
                # prob of continuing to the next exit
                q_cont = q_cont * pi[:, i]
            # remaining mass: stop at the final classifier
            q_pi.append(q_cont.view(-1, 1))
            return torch.cat(q_pi, dim=-1)
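In the 'sequential' branch, each sigmoid pi[:, t] is the probability of continuing past exit t, so q(stop at t) = (1 - pi_t) * prod_{s<t} pi_s, and the leftover mass prod_t pi_t is assigned to the final classifier. A toy check with made-up scores that the resulting q is a valid distribution:

import torch

scores = torch.tensor([[2.0, 0.0, -1.0]])  # made-up scores for three exits
pi = torch.sigmoid(scores)
q_pi, q_cont = [], torch.ones(1)
for t in range(scores.shape[1]):
    q_pi.append(((1 - pi[:, t]) * q_cont).view(-1, 1))  # stop at exit t
    q_cont = q_cont * pi[:, t]                          # continue past t
q_pi.append(q_cont.view(-1, 1))                         # reach final exit
q = torch.cat(q_pi, dim=-1)
assert torch.allclose(q.sum(dim=-1), torch.ones(1))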
Code Example #4
File: trainer.py  Project: xinshi-chen/l2stop
    def _update_policy(self, x, y):
        self.model.eval()
        self.score_net.train()

        # pdb.set_trace()
        xhs = self.model(x)
        # placeholder feature maps; the real internal feature maps are unused here
        internal_fm = torch.rand(2, 2)
        scores = self.score_net(x, internal_fm, xhs)

        self.optimizer2.zero_grad()
        # loss
        # true posterior
        p_true, _ = PolicyKL.true_posterior(self.args, xhs, y)
        p = torch.stack([p_true[:, t] for t in self.nz_post.values()], dim=1)
        p = max_onehot(p, dim=-1)
        # parametric log posterior
        log_q_pi = PolicyKL.log_q_posterior(self.args.policy_type, scores)
        loss = -torch.sum(p * log_q_pi, dim=-1).mean()

        # backward
        loss.backward()
        self.optimizer2.step()
        return loss
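With p collapsed to a one-hot by max_onehot, the loss in _update_policy reduces to ordinary cross-entropy against the argmax index of the true posterior. A quick equivalence check with made-up shapes:

import torch
import torch.nn.functional as F

log_q = F.log_softmax(torch.randn(4, 5), dim=-1)  # toy log posterior
p_idx = torch.randint(0, 5, (4,))                 # toy argmax indices
p = F.one_hot(p_idx, num_classes=5).float()       # one-hot targets
manual = -(p * log_q).sum(dim=-1).mean()          # the loss used above
builtin = F.nll_loss(log_q, p_idx)                # standard cross-entropy
assert torch.allclose(manual, builtin)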
Code Example #5
File: trainer.py  Project: xinshi-chen/l2stop
    def _train_epoch(self, epoch):
        # n outputs, n-1 nets
        self.score_net.train()

        total_loss = 0.0
        for i, batch in enumerate(self.train_data_generator):
            # generate path
            x, y = batch
            x = x.to(self.device)
            y = y.to(self.device)
            batch_size = y.shape[0]

            xhs = self.sdn_model(x)
            # internal_fm = self.sdn_model.internal_fm
            # self.sdn_model.internal_fm = [None]*len(internal_fm)
            # placeholder feature maps; the real internal feature maps are disabled above
            internal_fm = torch.rand(2, 2)

            # if self.args.policy_type == 'sequential':
            #     pi_all = []
            #     for i, t in self.train_post.items():
            #         pi_all.append(self.policy_nets[i](y, xhs[t]).view(-1))
            # if self.args.policy_type == 'multiclass':
            #     pi_all = self.policy_nets(y, xhs)

            scores = self.score_net(x, internal_fm, xhs)

            # scores = self.score_net(x, xhs)

            self.optimizer.zero_grad()
            # loss
            if self.args.kl_type == 'forward':
                loss, _ = self.forward_kl_loss(y, xhs, scores, p_det=False)
            else:
                assert self.args.kl_type == 'backward'
                loss, _ = self.backward_kl_loss(y, xhs, scores)

            # backward
            loss.backward()
            self.optimizer.step()

            if i % self.args.iters_per_eval == 0:
                print('Epoch: {}, Step: {}, Loss: {}'.format(epoch, i, loss))

            total_loss += loss.item()

        if epoch in (20, 50, 70, 99):
            self.score_net.eval()
            x, y = next(iter(self.train_data_generator))
            x = x.to(self.device)
            y = y.to(self.device)
            batch_size = y.shape[0]

            xhs = self.sdn_model(x)
            internal_fm = torch.rand(2, 2)
            scores = self.score_net(x, internal_fm, xhs)

            stop_idx = self.q_posterior(self.args.policy_type,
                                        scores,
                                        stochastic=False,
                                        device=self.device)
            q_p = self.q_posterior(self.args.policy_type,
                                   scores,
                                   stochastic=True,
                                   device=self.device)
            q_p_idx = torch.argmax(q_p, dim=-1)
            p_true, _ = self.true_posterior(self.args, xhs, y)
            p_true_b = max_onehot(p_true, dim=-1, device=self.device)
            p_true_idx = torch.argmax(p_true_b, dim=-1)
            print('Here is the policy classification training accuracy:')
            print(data.accuracy(q_p, p_true_idx))
            # p_true_max, _ = self.true_posterior_max(self.args, xhs, y)
            # p_true_max_b = max_onehot(p_true_max, dim=-1, device=self.device)
            # p_true_max_b_idx =  torch.argmax(p_true_max_b, dim=-1)

            # pdb.set_trace()
        # i is the last batch index, so i + 1 is the number of batches
        log = {'epo': epoch, 'train loss': total_loss / (i + 1)}

        return log
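The excerpts never show PolicyKL.train() itself. A minimal sketch of how the epoch methods above might be driven (assumed structure, not project code; the checkpoint path mirrors the one loaded in Code Example #6):

    def train(self):
        # hypothetical driver; assumes self.args.num_epochs exists
        best_val = float('inf')
        for epoch in range(self.args.num_epochs):
            train_log = self._train_epoch(epoch)
            val_log, _ = self._valid_epoch(epoch)  # _valid_epoch returns its log twice
            self.scheduler.step()
            print(train_log, val_log)
            if val_log['val loss'] < best_val:
                best_val = val_log['val loss']
                torch.save(self.score_net.state_dict(),
                           '{}/{}_best_val_policy.dump'.format(
                               self.args.save_dir, self.sdn_name))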
Code Example #6
def policy_training(models_path, device='cpu'):
    #sdn_name = 'cifar10_vgg16bn_bd_sdn_converted'; add_trigger = True  # for the backdoored network

    add_trigger = False

    #task = 'cifar10'
    # task = 'cifar100'
    task = 'tinyimagenet'

    network = 'vgg16bn'
    #network = 'resnet56'
    #network = 'wideresnet32_4'
    #network = 'mobilenet'

    # sdn_name = task + '_' + network + '_sdn_ic_only'
    sdn_name = task + '_' + network + '_sdn_ic_only_ic1'
    # sdn_name = task + '_' + network + '_sdn_ic_only_ic1_ds'
    # sdn_name = task + '_' + network + '_sdn_sdn_training'
    # sdn_name = task + '_' + network + '_sdn_sdn_training_ds'
    # sdn_name = task + '_' + network + '_sdn_sdn_training_ic14_ds'

    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)

    # We need to construct and train the policy network;
    # its architecture still needs to be designed.

    ######################################
    # need to decide on the architecture of the policy net
    ######################################
    sdn_model.eval()
    p_true_all = list()
    xhs_all = list()
    y_all = list()
    for batch in dataset.val_loader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        batch_size = y.shape[0]
        with torch.no_grad():
            xhs = sdn_model(x)
            categories = xhs[-1].shape[-1]  # number of classes
            # pdb.set_trace()
            # internal_fm = sdn_model.internal_fm
            # sdn_model.internal_fm = [None]*len(internal_fm)
            p_true, _ = PolicyKL.true_posterior(cmd_args, xhs, y)

        xhs_all.append(xhs)
        y_all.append(y)
        p_true_all.append(p_true)

    p_true = torch.cat(p_true_all, dim=0)
    p_det = max_onehot(p_true, dim=-1, device=device)
    p_true = torch.mean(p_true, dim=0)
    # find positions with nonzero posterior
    train_post = {}
    nz_post = {}
    i = 0
    for t in range(cmd_args.num_output):
        if p_true[t] > 0.001:
            train_post[i] = t
            nz_post[i] = t
            i += 1
    # the final position needs no trained stop score, so drop it from
    # train_post (nz_post keeps every position with nonzero posterior)
    del train_post[i - 1]

    p_str = 'val p true:['
    p_str += ','.join(['%0.3f' % p_true[t] for t in nz_post.values()])
    print(p_str + ']')

    p_det = torch.mean(p_det, dim=0)
    p_str = 'val p true det:['
    p_str += ','.join(['%0.3f' % p_det[t] for t in nz_post.values()])
    print(p_str + ']')
    ######################################

    ####
    # check the performance when exits are chosen by confidence score
    ####
    y_all = torch.cat(y_all, dim=-1)
    # regroup per exit: xhs_all[i] collects exit i's logits over all batches
    xhs_all = list(zip(*xhs_all))
    for i in range(len(xhs_all)):
        xhs_all[i] = torch.cat(xhs_all[i], dim=0)
        print('The {}th classifier performance:'.format(i))
        prec1, prec5 = data.accuracy(xhs_all[i], y_all, topk=(1, 5))
        print('Top1 validation accuracy: {}'.format(prec1))
        print('Top5 validation accuracy: {}'.format(prec5))

    xhs_all = list(map(lambda x: F.softmax(x, dim=-1), xhs_all))
    max_confidences = list(map(lambda x: torch.max(x, dim=-1)[0], xhs_all))
    max_confidences = torch.stack(max_confidences, dim=-1)
    xhs_all_stack = torch.stack(xhs_all, dim=1)
    predictions = list(map(lambda x: torch.argmax(x, dim=-1), xhs_all))
    predictions = torch.stack(predictions, dim=-1)

    thresholds = [
        0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999,
        -1
    ]
    # thresholds = [0.8, 0.9, 0.95, 0.99, 0.999, -1]
    for threshold in thresholds:
        if threshold == -1:
            # oracle-style rule: take the exit with the highest confidence
            index = torch.argmax(max_confidences, dim=-1).cpu().numpy()
            # pdb.set_trace()
        else:
            # first exit whose confidence exceeds the threshold;
            # the final exit is always eligible
            mask = (max_confidences > threshold).to(int).cpu().numpy()
            mask[:, -1] = 1
            index = np.array(list(map(lambda x: list(x).index(1), list(mask))))
        # pick, for each sample, the logits of its chosen exit
        # (categories = 200 classes for Tiny ImageNet)
        results = xhs_all_stack.gather(
            1,
            torch.Tensor([index] * categories).t().view(
                -1, 1, categories).long().to(device)).squeeze()

        prec1, prec5 = data.accuracy(results, y_all, topk=(1, 5))
        print('threshold: ', threshold)
        print('Top1 validation accuracy: {}'.format(prec1))
        print('Top5 validation accuracy: {}'.format(prec5))
    ####
    # confidence score check finished
    ####

    # pdb.set_trace()
    # placeholder feature maps, one per output head
    internal_fm = [torch.rand(2, 2) for i in range(cmd_args.num_output)]
    # initialize nets with nonzero posterior
    # (`x` below is the last validation batch left over from the loop above)
    if cmd_args.model_type == 'sequential':
        score_net = MNIconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
        score_net.to(device)
        # print('Sequential model to be implemented')
    if cmd_args.model_type == 'multiclass':
        score_net = MulticlassNetImage(cmd_args,
                                       x,
                                       internal_fm,
                                       train_post,
                                       category=categories)
        score_net.to(device)
    if cmd_args.model_type == 'confidence':
        score_net = MNIconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
        score_net.to(device)
    if cmd_args.model_type == 'imiconfidence':
        score_net = Imiconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
        score_net.to(device)

    # train
    if cmd_args.phase == 'train':

        # start training
        optimizer = optim.Adam(list(score_net.parameters()),
                               lr=cmd_args.learning_rate,
                               weight_decay=cmd_args.weight_decay)
        milestones = [10, 20, 40, 60, 80]
        gammas = [0.4, 0.2, 0.2, 0.2, 0.2]
        scheduler = MultiStepMultiLR(optimizer,
                                     milestones=milestones,
                                     gammas=gammas)
        trainer = PolicyKL(args=cmd_args,
                           sdn_model=sdn_model,
                           score_net=score_net,
                           train_post=train_post,
                           nz_post=nz_post,
                           optimizer=optimizer,
                           data_loader=dataset,
                           device=device,
                           scheduler=scheduler,
                           sdn_name=sdn_name)
        trainer.train()
        #pdb.set_trace()
    # test
    dump = cmd_args.save_dir + '/{}_best_val_policy.dump'.format(sdn_name)
    print('Loading model...')
    score_net.load_state_dict(torch.load(dump))

    PolicyKL.test(args=cmd_args,
                  score_net=score_net,
                  sdn_model=sdn_model,
                  data_loader=dataset.test_loader,
                  nz_post=nz_post,
                  device=device)
    print(cmd_args.save_dir)
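For completeness, a hypothetical entry point (not part of the excerpt; 'networks' is a placeholder for the saved-model directory):

if __name__ == '__main__':
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    policy_training(models_path='networks', device=device)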