Beispiel #1
0
def salience_val_one_epoch(net_f,
                           net,
                           optimizer,
                           batch_generator,
                           criterion,
                           SalenceGenerator,
                           AttackMethod,
                           clock,
                           attack_freq=5,
                           use_adv=True,
                           DEVICE=torch.device('cuda:0')):
    '''
    Validate a classifier ``net`` that is fed salience maps produced with ``net_f``.

    For every batch, ``SalenceGenerator(net_f, x, label)`` turns the input ``x``
    into salience data, which is then classified by ``net``. Batches are
    numbered from 1. When ``use_adv`` is set, batch ``i`` is adversarial if
    ``(i - 1) % (attack_freq + 1) == 0``, i.e. the 1st, ``attack_freq + 2``-th,
    and so on. Its input is first perturbed with
    ``AttackMethod.attack(net_f, data, label)`` and its accuracy goes into
    ``adv_acc``. All other batches go into ``clean_acc``.

    :param net_f: feature network used to craft attacks and salience maps
    :param net: classifier evaluated on the salience maps
    :param optimizer: unused; kept for signature compatibility
    :param batch_generator: iterable yielding ``(data, label)`` batches
    :param criterion: unused (no loss is reported); kept for signature compatibility
    :param SalenceGenerator: callable ``(net_f, inputs, label) -> salience data``
    :param AttackMethod: object exposing ``attack(net, data, label)``
    :param clock: training clock; only ``clock.epoch`` is read (progress bar)
    :param attack_freq: number of clean batches between two adversarial batches
    :param use_adv: whether adversarial batches are evaluated at all
    :param DEVICE: device the batches are moved to
    :return: dict with the running means ``'clean_acc'`` and ``'adv_acc'``
    '''
    clean_acc = AvgMeter()
    adv_acc = AvgMeter()

    net_f.eval()
    net.eval()

    pbar = tqdm(batch_generator)
    pbar.set_description("Validation Epoch: {}".format(clock.epoch))

    # Validation does not tick the training clock, so clock.minibatch would
    # stay constant for the whole loop. The clean/adversarial schedule is
    # therefore driven by a local batch counter.
    for batch_idx, (data, label) in enumerate(pbar, start=1):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        # `(idx - 1) % period == 0` equals `idx % period == 1` for period >= 2,
        # but it also works for attack_freq=0 (attack every batch).
        if use_adv and (batch_idx - 1) % (attack_freq + 1) == 0:
            # The attack and the salience generator need input gradients,
            # so they run outside torch.no_grad().
            inputs = AttackMethod.attack(net_f, data, label)
            net.eval()
            meter = adv_acc
        else:
            inputs = data
            meter = clean_acc
        salience_data = SalenceGenerator(net_f, inputs, label)

        # net is only scored here, so no autograd graph is needed.
        with torch.no_grad():
            pred = net(salience_data)
        acc = torch_accuracy(pred, label, (1, ))
        meter.update(acc[0].item())

        pbar.set_postfix({
            'clean_acc': clean_acc.mean,
            'adv_acc': adv_acc.mean
        })
    return {'clean_acc': clean_acc.mean, 'adv_acc': adv_acc.mean}
Beispiel #2
0
def val_one_epoch(net,
                  batch_generator,
                  CrossEntropyCriterion,
                  clock,
                  tv_loss_weight=1.0,
                  device=None):
    '''
    Validate ``net``, reporting accuracy, cross-entropy and the TV loss of the
    gradient of the cross-entropy with respect to the input.

    :param net:   network (put into eval mode)
    :param batch_generator:   pytorch dataloader or other generator of ``(data, label)``
    :param CrossEntropyCriterion:  Used for calculating CrossEntropy loss
    :param clock: TrainClock from utils; only ``clock.epoch`` is read
    :param tv_loss_weight: unused during validation; kept for signature compatibility
    :param device: device to move batches to; ``None`` (default) keeps the
        original behaviour of calling ``.cuda()``
    :return: tuple ``(accuracy, cross_entropy_loss, grad_tv_loss)`` of running means
    '''

    Acc = AvgMeter()
    CrossLoss = AvgMeter()
    GradTvLoss = AvgMeter()

    net.eval()
    pbar = tqdm(batch_generator)
    pbar.set_description("Validation Epoch: {}".format(clock.epoch))

    for (data, label) in pbar:
        if device is None:
            data = data.cuda()
            label = label.cuda()
        else:
            data = data.to(device)
            label = label.to(device)

        # The input gradient is needed to build the gradient map.
        data.requires_grad_(True)

        pred = net(data)
        cross_entropy_loss = CrossEntropyCriterion(pred, label)

        # Nothing is back-propagated through the TV loss during validation,
        # so no higher-order graph is required. The default
        # create_graph=False lets autograd free the graph immediately. The
        # deprecated `only_inputs` flag is ignored by PyTorch, so it is omitted.
        grad_map = torch.autograd.grad(cross_entropy_loss, data)[0]

        grad_tv_loss = TvLoss(grad_map)

        acc = torch_accuracy(pred, label, topk=(1, ))[0].item()

        Acc.update(acc)
        CrossLoss.update(cross_entropy_loss.item())
        GradTvLoss.update(grad_tv_loss.item())

        pbar.set_postfix({
            'Acc': '{:.3f}'.format(Acc.mean),
            'CrossLoss': "{:.2f}".format(CrossLoss.mean),
            'GradTvLoss': '{:.3f}'.format(GradTvLoss.mean)
        })

    return Acc.mean, CrossLoss.mean, GradTvLoss.mean
Beispiel #3
0
    def get_batch_accuracy(self, net, inp, label):
        """Score ``net`` on an adversarially perturbed copy of one batch.

        The batch is perturbed with ``self.attack(net, inp, label)`` and then
        classified by ``net``.

        :param net: model that is both attacked and evaluated
        :param inp: clean input batch
        :param label: ground-truth labels for ``inp``
        :return: ``torch_accuracy(..., (1, ))[0].item()`` for the adversarial batch
        """
        perturbed_batch = self.attack(net, inp, label)
        logits = net(perturbed_batch)
        top1_scores = torch_accuracy(logits, label, (1, ))
        return top1_scores[0].item()
Beispiel #4
0
def eval_one_epoch(net,
                   batch_generator,
                   device=torch.device('cuda:0'),
                   val_attack=None,
                   logger=None):
    """Evaluate ``net`` on clean batches and, optionally, on adversarial batches.

    :param net: model to evaluate (put into eval mode)
    :param batch_generator: iterable yielding ``(data, label)`` batches
    :param device: device the batches are moved to
    :param val_attack: optional attack exposing ``forward(data, label)`` that
        returns adversarial inputs; when ``None`` no adversarial accuracy is measured
    :param logger: optional logger; receives one summary line after the loop
    :return: ``(clean_accuracy, adversarial_accuracy)`` running means; the
        adversarial accuracy is ``0`` when ``val_attack`` is ``None``
    """
    net.eval()
    pbar = tqdm(batch_generator)
    clean_accuracy = AvgMeter()
    adv_accuracy = AvgMeter()

    pbar.set_description('Evaluating')
    for (data, label) in pbar:
        data = data.to(device)
        label = label.to(device)

        with torch.no_grad():
            pred = net(data)
            acc = torch_accuracy(pred, label, (1, ))
            clean_accuracy.update(acc[0].item())

        if val_attack is not None:
            # The attack needs gradients, so it runs outside torch.no_grad().
            adv_inp = val_attack.forward(data, label)

            with torch.no_grad():
                pred = net(adv_inp)
                acc = torch_accuracy(pred, label, (1, ))
                adv_accuracy.update(acc[0].item())

        pbar_dic = OrderedDict()
        pbar_dic['CleanAcc'] = '{:.2f}'.format(clean_accuracy.mean)
        pbar_dic['AdvAcc'] = '{:.2f}'.format(adv_accuracy.mean)
        pbar.set_postfix(pbar_dic)

    # This is computed after the loop so that an empty batch_generator no
    # longer raises UnboundLocalError at the return statement.
    adv_acc = adv_accuracy.mean if val_attack is not None else 0

    # Log one summary line instead of one line per batch. It reports the
    # same adversarial value that is returned.
    if logger is not None:
        logger.info(
            f'standard acc: {clean_accuracy.mean:.3f}%, robustness acc: {adv_acc:.3f}%'
        )
    return clean_accuracy.mean, adv_acc
Beispiel #5
0
def adversarial_val(net,
                    batch_generator,
                    criterion,
                    AttackMethod,
                    clock,
                    attack_freq=1,
                    DEVICE=torch.device('cuda:0')):
    """
    Validate ``net`` on both clean data and adversarial examples.

    Batches are numbered from 1. Batch ``i`` is attacked when
    ``(i - 1) % (attack_freq + 1) == 0``, i.e. the 1st, ``attack_freq + 2``-th,
    and so on. All other batches are evaluated clean.

    :param net: network to validate (kept in eval mode)
    :param batch_generator: iterable yielding ``(data, label)`` batches
    :param criterion: loss function ``criterion(pred, label)``
    :param AttackMethod: the attack method; must expose ``attack(net, data, label)``
    :param clock: clock object from my_snip.clock import TrainClock; only
        ``clock.epoch`` is read
    :param attack_freq: number of clean batches between two adversarial batches
    :param DEVICE: device the batches are moved to
    :return: OrderedDict with the running means ``loss``, ``acc`` (over all
        batches), ``clean_loss``, ``clean_acc``, ``adv_loss`` and ``adv_acc``
    """
    names = ['loss', 'acc', 'clean_loss', 'clean_acc', 'adv_loss', 'adv_acc']
    shown_in_bar = ('acc', 'clean_acc', 'adv_acc')

    training_losses = AvgMeter()
    training_accs = AvgMeter()
    clean_losses = AvgMeter()
    clean_accs = AvgMeter()
    defense_losses = AvgMeter()
    defense_accs = AvgMeter()
    clean_batch_times = AvgMeter()
    ad_batch_times = AvgMeter()

    def current_values():
        # Values are in the same order as `names`.
        return [
            training_losses.mean, training_accs.mean, clean_losses.mean,
            clean_accs.mean, defense_losses.mean, defense_accs.mean
        ]

    net.eval()

    pbar = tqdm(batch_generator)
    # The clock is not ticked here, so the epoch shown is constant.
    pbar.set_description("Validation Epoch: {}".format(clock.epoch))

    start_time = time.time()
    for i, (data, label) in enumerate(pbar, start=1):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        # `(i - 1) % period == 0` equals `i % period == 1` for period >= 2,
        # but it also works for attack_freq=0 (attack every batch).
        if (i - 1) % (attack_freq + 1) == 0:
            # The attack needs gradients, so it runs outside torch.no_grad().
            adv_inp = AttackMethod.attack(net, data, label)
            # Re-assert eval mode in case the attack changed it.
            net.eval()

            with torch.no_grad():
                pred = net(adv_inp)
                loss = criterion(pred, label)
                acc = torch_accuracy(pred, label, (1, ))

            defense_losses.update(loss.item())
            defense_accs.update(acc[0].item())
            ad_batch_times.update(time.time() - start_time)
        else:
            with torch.no_grad():
                pred = net(data)
                loss = criterion(pred, label)
                acc = torch_accuracy(pred, label, (1, ))

            clean_losses.update(loss.item())
            clean_accs.update(acc[0].item())
            clean_batch_times.update(time.time() - start_time)

        training_losses.update(loss.item())
        training_accs.update(acc[0].item())

        pbar_dic = OrderedDict(
            (n, v) for n, v in zip(names, current_values()) if n in shown_in_bar)
        pbar_dic['clean_time'] = "{:.2f}".format(clean_batch_times.mean)
        pbar_dic['ad_time'] = "{:.2f}".format(ad_batch_times.mean)
        # Update the bar only after the timing entries have been added.
        # Previously this call came first, so the timings were never shown.
        pbar.set_postfix(pbar_dic)

        start_time = time.time()

    return OrderedDict(zip(names, current_values()))
Beispiel #6
0
def adversairal_train_one_epoch(net,
                                optimizer,
                                batch_generator,
                                criterion,
                                AttackMethod,
                                clock,
                                attack_freq=1,
                                use_adv=True,
                                DEVICE=torch.device('cuda:0'),
                                act_loss_coef=0):
    """
    Run adversarial training for one epoch, mixing clean and adversarial batches.

    ``clock.tock()`` is called once before the loop and ``clock.tick()`` once
    per batch. A batch is adversarial when
    ``(clock.minibatch - 1) % (attack_freq + 1) == 0`` and ``use_adv`` is set.
    Otherwise the clean batch is used, and ``act_loss_coef * act_loss`` is
    added to its loss, where ``act_loss`` is the second value returned by
    ``net(data, if_return_activation=True)``.

    :param net: network to train; on clean batches it is called with
        ``if_return_activation=True`` and must return ``(pred, act_loss)``
    :param optimizer: optimizer updating ``net``
    :param batch_generator: iterable yielding ``(data, label)`` batches
    :param criterion: loss function ``criterion(pred, label)``
    :param AttackMethod:  the attack method; must expose ``attack(net, data, label)``
    :param clock: clock object from my_snip.clock import TrainClock
    :param attack_freq: number of clean batches between two adversarial batches
    :param use_adv: whether adversarial batches are used at all
    :param DEVICE: device the batches are moved to
    :param act_loss_coef: weight of the activation loss on clean batches
    :return: OrderedDict with the running means ``loss``, ``acc``,
        ``clean_loss`` (which includes the weighted activation loss),
        ``clean_acc``, ``adv_loss`` and ``adv_acc``
    """
    names = ['loss', 'acc', 'clean_loss', 'clean_acc', 'adv_loss', 'adv_acc']
    shown_in_bar = ('acc', 'clean_acc', 'adv_acc')

    training_losses = AvgMeter()
    training_accs = AvgMeter()
    clean_losses = AvgMeter()
    clean_accs = AvgMeter()
    defense_losses = AvgMeter()
    defense_accs = AvgMeter()
    clean_batch_times = AvgMeter()
    ad_batch_times = AvgMeter()

    def current_values():
        # Values are in the same order as `names`.
        return [
            training_losses.mean, training_accs.mean, clean_losses.mean,
            clean_accs.mean, defense_losses.mean, defense_accs.mean
        ]

    net.train()
    clock.tock()

    pbar = tqdm(batch_generator)

    start_time = time.time()
    for (data, label) in pbar:
        clock.tick()

        data = data.to(DEVICE)
        label = label.to(DEVICE)

        # `(m - 1) % period == 0` equals `m % period == 1` for period >= 2,
        # but it also works for attack_freq=0 (attack every batch).
        if (clock.minibatch - 1) % (attack_freq + 1) == 0 and use_adv:
            adv_inp = AttackMethod.attack(net, data, label)
            # Restore train mode in case the attack changed it.
            net.train()
            optimizer.zero_grad()

            pred = net(adv_inp)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            acc = torch_accuracy(pred, label, (1, ))

            defense_losses.update(loss.item())
            defense_accs.update(acc[0].item())
            ad_batch_times.update(time.time() - start_time)
        else:
            optimizer.zero_grad()

            pred, act_loss = net(data, if_return_activation=True)
            loss = criterion(pred, label) + act_loss_coef * act_loss
            loss.backward()
            optimizer.step()
            acc = torch_accuracy(pred, label, (1, ))

            clean_losses.update(loss.item())
            clean_accs.update(acc[0].item())
            clean_batch_times.update(time.time() - start_time)

        training_losses.update(loss.item())
        training_accs.update(acc[0].item())

        # Set on every batch because the clock is ticked inside the loop.
        pbar.set_description("Training Epoch: {}".format(clock.epoch))

        pbar_dic = OrderedDict(
            (n, v) for n, v in zip(names, current_values()) if n in shown_in_bar)
        pbar_dic['clean_time'] = "{:.1f}".format(clean_batch_times.mean)
        pbar_dic['ad_time'] = "{:.1f}".format(ad_batch_times.mean)
        pbar.set_postfix(pbar_dic)

        start_time = time.time()

    return OrderedDict(zip(names, current_values()))
Beispiel #7
0
def train_one_epoch(net,
                    batch_generator,
                    optimizer,
                    criterion,
                    device=torch.device('cuda:0'),
                    descrip_str='Training',
                    attack_method=None,
                    adv_coef=1.0,
                    logger=None):
    """Run one training pass over ``batch_generator``.

    For every batch the clean loss is back-propagated. When ``attack_method``
    is given, adversarial inputs are first produced with
    ``attack_method.forward(data, label)``, and their loss, scaled by
    ``adv_coef``, is back-propagated before the clean loss. Both gradients
    accumulate and are applied by a single ``optimizer.step()`` per batch.

    :param net: model to train (put into train mode)
    :param batch_generator: iterable yielding ``(data, label)`` batches
    :param optimizer: optimizer used for ``zero_grad()``/``step()``
        (presumably over ``net``'s parameters)
    :param criterion: loss function ``criterion(pred, label)``
    :param device: device the batches are moved to
    :param descrip_str: progress-bar description
    :param attack_method: optional attack exposing ``forward(data, label)``
    :param adv_coef: weight applied to the adversarial loss before backward
    :param logger: optional logger; when given, the current batch's loss and
        accuracy are logged on every iteration
    """
    # Put the network into training mode.
    net.train()

    # Initialise the progress bar and the per-batch acc/loss values.
    # -1 means "not computed"; the adversarial values stay -1 when
    # attack_method is None.
    pbar = tqdm(batch_generator)
    adv_acc = -1
    adv_loss = -1
    clean_acc = -1
    clean_loss = -1
    pbar.set_description(descrip_str)

    # Iterate over the data batch by batch and train.
    for i, (data, label) in enumerate(pbar):
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        pbar_dic = OrderedDict()

        # Generate adversarial examples for the current batch.
        if attack_method is not None:
            adv_inp = attack_method.forward(data, label)
            # Clear gradients and restore train mode again. NOTE(review):
            # presumably the attack may leave parameter gradients or switch
            # the mode; confirm against attack_method.
            optimizer.zero_grad()
            net.train()

            # Forward pass
            pred = net(adv_inp)
            loss = criterion(pred, label)
            acc = torch_accuracy(pred, label, (1, ))
            adv_acc = acc[0].item()
            adv_loss = loss.item()

            # Backward pass; these gradients accumulate with the clean loss
            # gradients below before optimizer.step().
            (loss * adv_coef).backward()

        pred = net(data)

        loss = criterion(pred, label)
        # TotalLoss = TotalLoss + loss
        loss.backward()
        # TotalLoss.backward()
        # param = next(net.parameters())
        # grad_mean = torch.mean(param.grad)

        optimizer.step()
        # The progress bar shows this batch's values, not running means.
        acc = torch_accuracy(pred, label, (1, ))
        clean_acc = acc[0].item()
        clean_loss = loss.item()
        pbar_dic['Acc'] = '{:.2f}'.format(clean_acc)
        pbar_dic['loss'] = '{:.2f}'.format(clean_loss)
        pbar_dic['Adv Acc'] = '{:.2f}'.format(adv_acc)
        pbar_dic['Adv loss'] = '{:.2f}'.format(adv_loss)
        pbar.set_postfix(pbar_dic)

        if logger is None:
            pass
        else:
            logger.info(
                f'standard loss: {clean_loss:.3f}, Adv loss: {adv_loss:.3f}')
            logger.info(
                f'standard acc: {clean_acc:.3f}%, robustness acc: {adv_acc:.3f}%'
            )