Example 1
    def forward(self,
                x,
                target=None,
                mixup_hidden=False,
                mixup_alpha=0.1,
                layers_mix=None):
        # print(x.shape)

        if mixup_hidden:
            layer_mix = random.randint(0, layers_mix)

            out = x

            if layer_mix == 0:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.conv1(out)  # use the (possibly mixed) activations, not the raw input x

            out = self.layer1(out)

            if layer_mix == 1:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.layer2(out)

            if layer_mix == 2:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.layer3(out)

            if layer_mix == 3:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = act(self.bn1(out))
            out = F.avg_pool2d(out, 8)
            out = out.view(out.size(0), -1)
            out = self.linear(out)

            if layer_mix == 4:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            lam = torch.tensor(lam).cuda()
            lam = lam.repeat(y_a.size())
            return out, y_a, y_b, lam

        else:
            out = x
            out = self.conv1(x)
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = act(self.bn1(out))
            out = F.avg_pool2d(out, 8)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
            return out
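Every example in this section calls a `mixup_data` helper (and usually a matching `mixup_criterion`) defined elsewhere in its repository. For reference, a minimal sketch of the usual form follows, assuming the reference formulation of mixup (interpolate a batch with a random permutation of itself and keep both label sets); exact signatures vary between examples, some also taking a `use_cuda` or `device` argument.

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # Return mixed inputs, the two target sets and the mixing coefficient.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Interpolate the loss computed against both target sets.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)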
Example 2
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0.0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        mask = random.random()

        if epoch >= 90:
            # threshold = math.cos( math.pi * (epoch - 150) / ((200 - 150) * 2))
            threshold = (100 - epoch) / (100 - 90)
            # threshold = 1.0 - math.cos( math.pi * (200 - epoch) / ((200 - 150) * 2))
            if mask < threshold:
                inputs, targets_a, targets_b, lam = mixup_data(
                    inputs, targets, args.alpha, use_cuda)
            else:
                targets_a, targets_b = targets, targets
                lam = 1.0
        elif epoch >= 60:
            if epoch % 2 == 0:
                inputs, targets_a, targets_b, lam = mixup_data(
                    inputs, targets, args.alpha, use_cuda)
            else:
                targets_a, targets_b = targets, targets
                lam = 1.0
        else:
            inputs, targets_a, targets_b, lam = mixup_data(
                inputs, targets, args.alpha, use_cuda)

        optimizer.zero_grad()
        inputs, targets_a, targets_b = Variable(inputs), Variable(
            targets_a), Variable(targets_b)
        outputs = net(inputs)
        loss_func = mixup_criterion(targets_a, targets_b, lam)
        loss = loss_func(criterion, outputs)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += lam * predicted.eq(targets_a.data).cpu().sum().item() + (
            1.0 - lam) * predicted.eq(targets_b.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss / (batch_idx + 1),
             (100. * correct) / total, correct, total))
    return (train_loss / (batch_idx + 1), 100. * correct / total)
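Example 2 (like Examples 7, 9, 10 and 22) uses an older, closure-style `mixup_criterion` that takes the label pair and the mixing coefficient first and returns a function later applied to `(criterion, outputs)`. A minimal sketch of that variant, assuming the same interpolation as above:

def mixup_criterion(y_a, y_b, lam):
    # Returns a closure used as loss_func(criterion, pred).
    return lambda criterion, pred: (
        lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b))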
Example 3
def train(train_loader, net, criterion, optimizer, epoch, device):
    global writer

    start = time.time()
    net.train()

    train_loss = 0
    correct = 0
    total = 0
    logger.info(" === Epoch: [{}/{}] === ".format(epoch + 1, config.epochs))

    for batch_index, (inputs, targets) in enumerate(train_loader):
        # move tensor to GPU
        inputs, targets = inputs.to(device), targets.to(device)
        if config.mixup:
            inputs, targets_a, targets_b, lam = mixup_data(
                inputs, targets, config.mixup_alpha, device)

            outputs = net(inputs)
            loss = mixup_criterion(
                criterion, outputs, targets_a, targets_b, lam)
        else:
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        # zero the gradient buffers
        optimizer.zero_grad()
        # backward
        loss.backward()
        # update weight
        optimizer.step()

        # count the loss and acc
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        if config.mixup:
            correct += (lam * predicted.eq(targets_a).sum().item()
                        + (1 - lam) * predicted.eq(targets_b).sum().item())
        else:
            correct += predicted.eq(targets).sum().item()

        if (batch_index + 1) % 100 == 0:
            logger.info("   == step: [{:3}/{}], train loss: {:.3f} | train acc: {:6.3f}% | lr: {:.6f}".format(
                batch_index + 1, len(train_loader),
                train_loss / (batch_index + 1), 100.0 * correct / total, get_current_lr(optimizer)))

    logger.info("   == step: [{:3}/{}], train loss: {:.3f} | train acc: {:6.3f}% | lr: {:.6f}".format(
        batch_index + 1, len(train_loader),
        train_loss / (batch_index + 1), 100.0 * correct / total, get_current_lr(optimizer)))

    end = time.time()
    logger.info("   == cost time: {:.4f}s".format(end - start))
    train_loss = train_loss / (batch_index + 1)
    train_acc = correct / total

    writer.add_scalar('train_loss', train_loss, global_step=epoch)
    writer.add_scalar('train_acc', train_acc, global_step=epoch)

    return train_loss, train_acc
Example 4
def train(model,
          device,
          train_loader,
          optimizer,
          scheduler,
          loss_func,
          mixup=False):
    model.train()
    epoch_train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        if mixup:
            data, targets_a, targets_b, lam = mixup_data(data,
                                                         target,
                                                         alpha=1.0)

        optimizer.zero_grad()
        output = model(data)
        if mixup:
            loss = mixup_criterion(loss_func, output, targets_a, targets_b,
                                   lam)
        else:
            loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item() * data.size(0)
    scheduler.step()
    loss = epoch_train_loss / len(train_loader.dataset)
    del data
    return loss
Example 5
def train(
    model,
    device,
    train_loader,
    optimizer,
    criterion,
    epoch,
    mixup=False,
    avg_meter=None,
):
    model.train()
    batch_loss = list()
    alpha = 0.2 if mixup else 0
    lam = None  # Required if doing mixup training

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data, target_a, target_b, lam = mixup_data(
            data, target, device, alpha
        )  # Targets here correspond to the pair of examples used to create the mix
        optimizer.zero_grad()
        output = model(data)
        loss = mixup_criterion(criterion, output, target_a, target_b, lam)
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
        if avg_meter is not None:
            avg_meter.update(batch_loss[-1], n=len(data))

    return batch_loss
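Example 5 calls `mixup_data` unconditionally and simply sets `alpha = 0` when mixup is disabled. That only turns mixing off if the helper treats `alpha == 0` as "no mixing"; a sketch of that convention (an assumption, since the helper is not shown):

import numpy as np

def sample_lam(alpha):
    # With alpha == 0 the coefficient is fixed at 1, so the "mixed" batch is
    # the original batch and the second target set gets zero weight in the loss.
    return np.random.beta(alpha, alpha) if alpha > 0 else 1.0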
Example 6
def train():
    params = split_weights(model) if opt.no_wd else model.parameters()
    optimizer = optim.SGD(params, lr=base_lr, momentum=0.9, nesterov=True, weight_decay=0.0001)

    Loss = nn.CrossEntropyLoss()
    metric_loss = mloss()
    alpha = 1. if mixup else 0.
    iterations = 0
    for epoch in range(epochs):
        model.train()
        metric_loss.reset()
        st_time = time.time()
        if mixup and epoch > epochs - 20:
            alpha = 0.
        for i, (trans, labels) in enumerate(train_data):
            trans, targets_a, targets_b, lam = mixup_data(trans.cuda(), labels.cuda(), alpha=alpha)
            trans, targets_a, targets_b = map(Variable, (trans, targets_a, targets_b))

            optimizer.zero_grad()
            outputs = model(trans)
            loss = mixup_criterion(Loss, outputs, targets_a, targets_b, lam)
            loss.backward()
            optimizer.step()

            metric_loss.update(loss)
            iterations += 1
            lr_scheduler.update(optimizer, iterations)
        learning_rate = lr_scheduler.get()
        met_name, metric = metric_loss.get()
        epoch_time = time.time() - st_time
        epoch_str = 'Epoch {}. Train {}: {:.5f}. {} samples/s. lr {:.5}'. \
            format(epoch, met_name, metric, int(num_train_samples // epoch_time), learning_rate)
        logger.info(epoch_str)
        test(epoch, True)
Example 7
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # generate mixed inputs, two one-hot label vectors and mixing coefficient
        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
                                                       args.alpha, use_cuda)
        optimizer.zero_grad()
        inputs, targets_a, targets_b = Variable(inputs), Variable(
            targets_a), Variable(targets_b)
        outputs = net(inputs)

        loss_func = mixup_criterion(targets_a, targets_b, lam)
        loss = loss_func(criterion, outputs)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += lam * predicted.eq(targets_a.data).cpu().sum() + (
            1 - lam) * predicted.eq(targets_b.data).cpu().sum()

        progress_bar(
            batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
    return (train_loss / (batch_idx + 1), 100. * correct / total)
Example 8
    def training_step(self, batch, batch_idx):
        x, y, idx = batch
        x, y_a, y_b, lam = mixup_data(x, y)
        y_hat = self.forward(x)
        loss = mixup_criterion(self.crit, y_hat, y_a.float(), y_b.float(), lam)
        self.log('trn/_loss', loss)

        return loss
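Example 8 is a PyTorch Lightning `training_step`. For context, a minimal sketch of a module it could live in is shown below; the backbone, `self.crit` and the `(x, y, idx)` batch format are assumptions inferred from the snippet, not taken from the original repository (the `.float()` casts suggest a float-target loss such as `BCEWithLogitsLoss`).

import pytorch_lightning as pl
import torch
import torch.nn as nn

class MixupClassifier(pl.LightningModule):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(num_features, 128), nn.ReLU())
        self.head = nn.Linear(128, num_classes)
        self.crit = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.head(self.backbone(x))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)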
Example 9
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    total_gnorm = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # generate mixed inputs, two one-hot label vectors and mixing coefficient
        optimizer.zero_grad()

        if args.train_loss == 'mixup':
            inputs, targets_a, targets_b, lam = mixup_data(
                inputs, targets, args.alpha, use_cuda)
            outputs = net(inputs)

            loss_func = mixup_criterion(targets_a, targets_b, lam)
            loss = loss_func(criterion, outputs)
        else:
            outputs = net(inputs)
            loss = cel(outputs, targets)

        loss.backward()

        if args.train_clip > 0:
            gnorm = torch.nn.utils.clip_grad_norm_(net.parameters(),
                                                   args.train_clip)
        else:
            gnorm = -1
        total_gnorm += gnorm

        optimizer.step()
        sgdr.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        acc = 100. * float(correct) / float(total)

        if batch_idx % 50 == 0 or batch_idx == len(trainloader) - 1:
            wnorms = [
                w.norm().item() for n, w in net.named_parameters()
                if 'weight' in n
            ]
            print(
                batch_idx, len(trainloader),
                'Loss: %.3f | Acc: %.3f%% (%d/%d) | WNorm: %.3e (min: %.3e, max: %.3e) | GNorm: %.3e (%.3e)'
                % (train_loss / (batch_idx + 1), acc, correct, total,
                   sum(wnorms), min(wnorms), max(wnorms), gnorm, total_gnorm /
                   (batch_idx + 1)))

    return train_loss / (batch_idx + 1), acc
Example 10
def mixup_train(loader, model, criterion, optimizer, epoch, use_cuda):
    global BEST_ACC, LR_STATE
    # switch to train mode
    if not cfg.CLS.fix_bn:
        model.train()
    else:
        model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(loader):
        # adjust learning rate
        adjust_learning_rate(optimizer, epoch, batch=batch_idx, batch_per_epoch=len(loader))

        # mixup
        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, ALPHA)
        if use_cuda:
            inputs, targets_a, targets_b = inputs.cuda(), targets_a.cuda(), targets_b.cuda()
        inputs, targets_a, targets_b = torch.autograd.Variable(inputs), torch.autograd.Variable(targets_a), \
                                       torch.autograd.Variable(targets_b)

        # measure data loading time
        data_time.update(time.time() - end)

        # forward pass: compute output
        outputs = model(inputs)
        # forward pass: compute gradient and do SGD step
        optimizer.zero_grad()
        loss_func = mixup_criterion(targets_a, targets_b, lam)
        loss = loss_func(criterion, outputs)
        # backward
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # measure accuracy and record loss (top-1/top-5 are not computed for
        # mixed targets here, so placeholders keep the meters consistent)
        prec1, prec5 = [0.0], [0.0]
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if (batch_idx + 1) % cfg.CLS.disp_iter == 0:
            print('Training: [{}/{}][{}/{}] | Best_Acc: {:4.2f}% | Time: {:.2f} | Data: {:.2f} | '
                  'LR: {:.8f} | Top1: {:.4f}% | Top5: {:.4f}% | Loss: {:.4f} | Total: {:.2f}'
                  .format(epoch + 1, cfg.CLS.epochs, batch_idx + 1, len(loader), BEST_ACC, batch_time.average(),
                          data_time.average(), LR_STATE, top1.avg, top5.avg, losses.avg,
                          batch_time.sum + data_time.sum))

    return (losses.avg, top1.avg)
Example 11
def train(epoch, criterion_list, optimizer):
    train_loss = 0.
    train_loss_cls = 0.
    train_loss_div = 0.
    top1_num = 0
    top5_num = 0
    total = 0

    lr = adjust_lr(optimizer, epoch, args)
    start_time = time.time()
    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]

    net.train()
    for batch_idx, (input, target) in enumerate(trainloader):
        batch_start_time = time.time()
        
        input = input.cuda()
        target = target.cuda()
        input, targets_a, targets_b, lam = mixup_data(input, target, 0.4)

        logit = net(input)
        #loss_cls = criterion_cls(logit, target)
        loss_cls = mixup_criterion(CrossEntropyLoss_label_smooth, logit, targets_a, targets_b, lam)
        loss = loss_cls

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() / len(trainloader)
        train_loss_cls += loss_cls.item() / len(trainloader)

        top1, top5 = correct_num(logit, target, topk=(1, 5))
        top1_num += top1
        top5_num += top5
        total += target.size(0)
        
        
        print('Epoch:{},batch_idx:{}/{}'.format(epoch, batch_idx, len(trainloader)),  'acc:', top1_num.item() / total, 'duration:', time.time()-batch_start_time)
    
    print('Epoch:{}\t lr:{:.5f}\t duration:{:.3f}'
                '\n train_loss:{:.5f}\t train_loss_cls:{:.5f}'
                '\n top1_acc: {:.4f} \t top5_acc:{:.4f}'
                .format(epoch, lr, time.time() - start_time,
                        train_loss, train_loss_cls,
                        (top1_num/total).item(), (top5_num/total).item()))

    with open(log_txt, 'a+') as f:
        f.write('Epoch:{}\t lr:{:.5f}\t duration:{:.3f}'
                '\ntrain_loss:{:.5f}\t train_loss_cls:{:.5f}'
                '\ntop1_acc: {:.4f} \t top5_acc:{:.4f} \n'
                .format(epoch, lr, time.time() - start_time,
                        train_loss, train_loss_cls,
                        (top1_num/total).item(), (top5_num/total).item()))
Example 12
    def train(self, x_val, y_val):
        """
        Trains the network and backpropagates
        
        Args:
            x_val: input to layer as minibatches
            y_val: labels

        Returns:
            Returns model output and total loss
        """
        '''
        fit -> train -> forward -> returns output -> calculate loss-> train returns loss
        '''
        x = Variable(x_val, requires_grad=False)
        y = Variable(y_val, requires_grad=False)

        # Mixup prepare
        x, y_a, y_b, lam = utils.mixup_data(x, y, self.alpha)

        x = Variable(x, requires_grad=False)
        y_a = Variable(y_a, requires_grad=False)
        y_b = Variable(y_b, requires_grad=False)

        self.optimizer.zero_grad()

        output = self.forward(x)

        # Mixup criterion - supervisor wanted this loss
        loss_mixup = lam * self.loss(output, y_a) + (1 - lam) * self.loss(
            output, y_b)

        # Weight decay
        L2_decay_sum = 0
        for name, param in self.named_parameters():
            if 'weight' in name:
                name_id = str(name.split('.')[0])
                # Get the string name of the layer - layer type
                layer_name = copy.deepcopy(
                    self._modules[name_id].__class__.__name__)
                if layer_name == 'Conv2d' or layer_name == 'Linear' or layer_name == "depthwise_separable_conv":
                    L2_decay_sum += 0.0005 * torch.norm(param.view(-1),
                                                        2)  # Regularization

        # Total loss
        loss_loc = loss_mixup + L2_decay_sum

        # Updates the parameters at the end of the minibatch

        loss_loc.backward(retain_graph=True)

        # Update the optimizer
        self.optimizer.step()

        return output, loss_loc.data
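Examples 12 and 18 add the 0.0005 L2 penalty by hand inside the loss. The same selective decay is more commonly expressed through optimizer parameter groups (Example 24 below does exactly this); a sketch, assuming the model only contains Conv2d, Linear and BatchNorm2d parameter-bearing modules:

import torch.nn as nn
import torch.optim as optim

def build_sgd(model, lr=0.1, weight_decay=5e-4):
    decay, no_decay = [], []
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            decay.append(module.weight)
            if module.bias is not None:
                no_decay.append(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            no_decay.extend(module.parameters())
    return optim.SGD(
        [{'params': decay, 'weight_decay': weight_decay},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=lr, momentum=0.9)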
Example 13
def train(epoch):
    print('\nEpoch: %d' % epoch)
    global Train_acc
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    if epoch > learning_rate_decay_start and learning_rate_decay_start >= 0:
        frac = (epoch - learning_rate_decay_start) // learning_rate_decay_every
        decay_factor = learning_rate_decay_rate ** frac
        current_lr = opt.lr * decay_factor
        utils.set_lr(optimizer, current_lr)  # set the decayed rate
    else:
        current_lr = opt.lr
    print('learning_rate: %s' % str(current_lr))

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        
        if opt.mixup:
            inputs, targets_a, targets_b, lam = utils.mixup_data(inputs, targets, 0.6, True)
            inputs, targets_a, targets_b = map(Variable, (inputs, targets_a, targets_b))
        else:
            inputs, targets = Variable(inputs), Variable(targets)
        
        outputs = net(inputs)
        
        if opt.mixup:
            loss = utils.mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        else:
            loss = criterion(outputs, targets)
        
        loss.backward()
        utils.clip_gradient(optimizer, 0.1)
        optimizer.step()
        train_loss += loss.item()
        
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        
        if opt.mixup:
            correct += (lam * predicted.eq(targets_a.data).cpu().sum().float()
                    + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float())
        else:
            correct += predicted.eq(targets.data).cpu().sum()
       
        utils.progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*float(correct)/float(total), correct, total))

    Train_acc = 100.*float(correct)/float(total)
    
    return train_loss/(batch_idx+1), Train_acc
Example 14
 def reader():
     batch_data = []
     batch_label = []
     for data, label in read_batch(datasets, args):
         batch_data.append(data)
         batch_label.append(label)
         if len(batch_data) == args.batch_size:
             batch_data = np.array(batch_data, dtype='float32')
             batch_label = np.array(batch_label, dtype='int64')
             if is_training:
                 flatten_label, flatten_non_label = \
                   generate_reshape_label(batch_label, args.batch_size)
                 rad_var = generate_bernoulli_number(args.batch_size)
                 mixed_x, y_a, y_b, lam = utils.mixup_data(
                     batch_data, batch_label, args.batch_size,
                     args.mix_alpha)
                 batch_out = [[mixed_x, y_a, y_b, lam, flatten_label, \
                             flatten_non_label, rad_var]]
                 yield batch_out
             else:
                 batch_out = [[batch_data, batch_label]]
                 yield batch_out
             batch_data = []
             batch_label = []
     if len(batch_data) != 0:
         batch_data = np.array(batch_data, dtype='float32')
         batch_label = np.array(batch_label, dtype='int64')
         if is_training:
             flatten_label, flatten_non_label = \
               generate_reshape_label(batch_label, len(batch_data))
             rad_var = generate_bernoulli_number(len(batch_data))
             mixed_x, y_a, y_b, lam = utils.mixup_data(
                 batch_data, batch_label, len(batch_data), args.mix_alpha)
             batch_out = [[mixed_x, y_a, y_b, lam, flatten_label, \
                         flatten_non_label, rad_var]]
             yield batch_out
         else:
             batch_out = [[batch_data, batch_label]]
             yield batch_out
         batch_data = []
         batch_label = []
Example 15
    def iterate(self, phase):
        self.model.train(phase == "train")

        running_loss = 0.0
        running_acc = np.zeros(4)
        for images, masks in self.dataloaders[phase]:
            labels = (torch.sum(masks, (2, 3)) > 0).type(torch.float32)

            # try mixup
            if phase == "train":
                if self.mixup:
                    images, targets_a, targets_b, lam = mixup_data(
                        images, labels)
                    loss, outputs = self.forward_mixup(images, targets_a,
                                                       targets_b, lam)
                    outputs = (outputs.detach().cpu() > 0.5).type(
                        torch.float32).numpy()
                    targets_a, targets_b = targets_a.numpy(), targets_b.numpy()
                    correct = lam * np.equal(outputs, targets_a).astype(np.float32) + (1 - lam) * \
                              np.equal(outputs, targets_b).astype(np.float32)
                else:
                    loss, outputs = self.forward(images, labels)
                    outputs = (outputs.detach().cpu() > 0.5).type(
                        torch.float32).numpy()
                    labels = labels.numpy()
                    correct = np.equal(outputs, labels).astype(np.float32)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            else:
                loss, outputs = self.forward(images, labels)
                outputs = (outputs.detach().cpu() > 0.5).type(
                    torch.float32).numpy()
                labels = labels.numpy()
                correct = np.equal(outputs, labels).astype(np.float32)

            running_loss += loss.item()
            running_acc += np.sum(correct, axis=0) / labels.shape[0]

        epoch_loss = running_loss / len(self.dataloaders[phase])
        epoch_acc = running_acc / len(self.dataloaders[phase])

        self.loss[phase].append(epoch_loss)
        self.accuracy[phase] = np.concatenate(
            (self.accuracy[phase], np.expand_dims(epoch_acc, axis=0)), axis=0)

        torch.cuda.empty_cache()

        return epoch_loss, epoch_acc
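Example 15 relies on `self.forward_mixup(...)`, which is not shown. A sketch of what it plausibly does, assuming a multi-label head trained with `BCEWithLogitsLoss` (the later `> 0.5` threshold implies independent sigmoid outputs); the free-function form and the device handling are assumptions:

import torch
import torch.nn as nn

_bce = nn.BCEWithLogitsLoss()

def forward_mixup(model, images, targets_a, targets_b, lam, device='cuda'):
    logits = model(images.to(device))
    loss = (lam * _bce(logits, targets_a.to(device))
            + (1 - lam) * _bce(logits, targets_b.to(device)))
    return loss, torch.sigmoid(logits)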
Example 16
def train(train_loader, net, criterion, optimizer, epoch, device):
    global writer

    start = time.time()
    # set to train mode (only affects layers such as dropout and batch norm)
    net.train()

    train_loss = 0
    correct = 0
    total = 0
    logger.info("====Epoch:[{}/{}]====".format(epoch + 1, config.epochs))
    for batch_index, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        if config.mixup:
            inputs, targets_a, targets_b, lam = utils.mixup_data(
                inputs, targets, config.mixup_alpha, device)
            outputs = net(inputs)
            loss = utils.mixup_criterion(criterion, outputs, targets_a,
                                         targets_b, lam)
        else:
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += inputs.size()[0]
        if config.mixup:
            correct += lam * predicted.eq(targets_a).sum().item() + (
                1 - lam) * predicted.eq(targets_b).sum().item()
        else:
            correct += predicted.eq(targets).sum().item()
        if batch_index % 100 == 99:
            logger.info(
                "   == step: [{:3}/{}], train loss: {:.3f} | train acc: {:6.3f}% | lr: {:.6f}"
                .format(batch_index + 1, len(train_loader),
                        train_loss / (batch_index + 1),
                        100.0 * correct / total,
                        utils.get_current_lr(optimizer)))

    end = time.time()
    logger.info("   == cost time: {:.4f}s".format(end - start))
    train_loss = train_loss / (batch_index + 1)
    train_acc = correct / total
    writer.add_scalar('train_loss', train_loss, global_step=epoch)
    writer.add_scalar('train_acc', train_acc, global_step=epoch)
    return train_loss, train_acc
Example 17
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        if args.mixup:
            inputs, targets_a, targets_b, lam = mixup_data(
                inputs, targets, 1.0, use_cuda)
            inputs, targets_a, targets_b = map(Variable,
                                               (inputs, targets_a, targets_b))
            outputs = net(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b,
                                   lam)
            _, predicted = torch.max(outputs.data, 1)
            correct += lam * predicted.eq(targets_a.data).cpu().sum().float()
            correct += (1 - lam) * predicted.eq(
                targets_b.data).cpu().sum().float()
        else:
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            _, predicted = torch.max(outputs.data, 1)
            correct += predicted.eq(targets.data).cpu().sum()

        total += targets.size(0)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(
            batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))

        if args.transfer_learning:
            if batch_idx >= len(trainloader) - 2:
                break
Example 18
    def train(self, x_val, y_val):
        x = Variable(x_val, requires_grad=False)
        y = Variable(y_val, requires_grad=False)

        # MIX UP PREPARE
        x, y_a, y_b, lam = utils.mixup_data(x, y, self.alpha)

        x = Variable(x, requires_grad=False)
        y_a = Variable(y_a, requires_grad=False)
        y_b = Variable(y_b, requires_grad=False)

        ###################

        self.optimizer.zero_grad()

        output = self.forward(x)

        # MIXUP CRITERION
        loss_mixup = lam * self.loss(output, y_a) + (1 - lam) * self.loss(output, y_b)

        # weight decay
        L2_decay_sum = 0        
        for name, param in self.named_parameters():
            if 'weight' in name:
                name_id = str(name.split('.')[0]) 

                layer_name = copy.deepcopy(self._modules[name_id].__class__.__name__)
                if layer_name == 'Conv2d' or layer_name == 'Linear' or layer_name == "depthwise_separable_conv":
                    L2_decay_sum += 0.0005 * torch.norm(param.view(-1),2)

        # total loss
        loss_loc = loss_mixup + L2_decay_sum

        loss_loc.backward(retain_graph=True)

        self.optimizer.step()

        return output, loss_loc.data
Example 19
    def train(self, dataloader, alpha, h, N):
        self.net.train() 
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            self.optimizer.zero_grad()
            if self.use_cuda:     
                inputs, targets = inputs.cuda(), targets.cuda() 
            # mix-up
            if alpha>0:
                inputs, targets = utils.mixup_data(inputs, targets, alpha, self.use_cuda)
            
            if h == 0.0:
                outputs = self.net(inputs)
            else:
                outputs = self.self_ensemble(inputs, h, N)           
                
            if self.surrogate_loss == 'hinge':
                dummy_input = torch.zeros(targets.shape, dtype=torch.float, device=self.device) 
                loss = self.criterion(outputs.squeeze(), dummy_input.squeeze(), targets)
            else:
                loss = self.criterion(outputs.squeeze(), targets)        
            
                      
            loss.backward()
            self.optimizer.step()

            train_loss += targets.size(0)*loss.data.cpu().numpy()
            y_hat_class = np.where(outputs.data.cpu().numpy()<0, -1, 1)
            
            correct += np.sum(targets.data.cpu().numpy()==y_hat_class.squeeze())
            total += targets.size(0) 
        del batch_idx, inputs, targets, outputs
        
        return (train_loss/total, 100.*correct/total)
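In Example 19, `utils.mixup_data` returns only `(inputs, targets)`, i.e. it mixes the labels as well rather than returning the label pair and the coefficient. A minimal sketch of that two-value variant, assuming real-valued (here +/-1) targets; `use_cuda` is kept only for signature compatibility and the device is taken from `x`:

import numpy as np
import torch

def mixup_data(x, y, alpha, use_cuda=True):
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y.float() + (1 - lam) * y[index].float()
    return mixed_x, mixed_y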
Example 20
def train(args, i):
    '''Training. Model will be saved after several iterations. 
    
    Args: 
      dataset_dir: string, directory of dataset
      workspace: string, directory of workspace
      holdout_fold: '1' | 'none', set 1 for development and none for training 
          on all data without validation
      model_type: string, e.g. 'Cnn_9layers_AvgPooling'
      batch_size: int
      cuda: bool
      mini_data: bool, set True for debugging on a small part of data
    '''

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    batch_size = args.batch_size
    cuda = args.cuda and torch.cuda.is_available()
    mini_data = args.mini_data
    filename = args.filename
    audio_num = config.audio_num
    mel_bins = config.mel_bins
    frames_per_second = config.frames_per_second
    max_iteration = None  # Number of mini-batches to evaluate on training data
    reduce_lr = True
    in_domain_classes_num = len(config.labels)

    # Paths
    if mini_data:
        prefix = 'minidata_'
    else:
        prefix = ''

    train_csv = os.path.join(sys.path[0], 'fold' + str(i) + '_train.csv')

    validate_csv = os.path.join(sys.path[0], 'fold' + str(i) + '_test.csv')

    feature_hdf5_path = os.path.join(
        workspace, 'features',
        '{}logmel_{}frames_{}melbins.h5'.format(prefix, frames_per_second,
                                                mel_bins))

    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        '{}logmel_{}frames_{}melbins.h5'.format(prefix, frames_per_second,
                                                mel_bins),
        'holdout_fold={}'.format(holdout_fold), model_type)
    create_folder(checkpoints_dir)

    validate_statistics_path = os.path.join(
        workspace, 'statistics', filename,
        '{}logmel_{}frames_{}melbins'.format(prefix, frames_per_second,
                                             mel_bins),
        'holdout_fold={}'.format(holdout_fold), model_type,
        'validate_statistics.pickle')

    create_folder(os.path.dirname(validate_statistics_path))

    logs_dir = os.path.join(
        workspace, 'logs', filename, args.mode,
        '{}logmel_{}frames_{}melbins'.format(prefix, frames_per_second,
                                             mel_bins),
        'holdout_fold={}'.format(holdout_fold), model_type)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if cuda:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    Model = eval(model_type)

    model = Model(in_domain_classes_num, activation='logsoftmax')
    loss_func = nll_loss

    if cuda:
        model.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)
    #     optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-5)
    # Data generator
    data_generator = DataGenerator(feature_hdf5_path=feature_hdf5_path,
                                   train_csv=train_csv,
                                   validate_csv=validate_csv,
                                   holdout_fold=holdout_fold,
                                   batch_size=batch_size)

    # Evaluator
    evaluator = Evaluator(model=model,
                          data_generator=data_generator,
                          cuda=cuda)

    # Statistics
    validate_statistics_container = StatisticsContainer(
        validate_statistics_path)

    train_bgn_time = time.time()
    iteration = 0

    # Train on mini batches
    for batch_data_dict in data_generator.generate_train():

        # Evaluate
        if iteration % 100 == 0 and iteration >= 1500:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            train_statistics = evaluator.evaluate(data_type='train',
                                                  iteration=iteration,
                                                  max_iteration=None,
                                                  verbose=False)

            if holdout_fold != 'none':
                validate_statistics = evaluator.evaluate(data_type='validate',
                                                         iteration=iteration,
                                                         max_iteration=None,
                                                         verbose=False)
                validate_statistics_container.append_and_dump(
                    iteration, validate_statistics)

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         ''.format(train_time, validate_time))

            train_bgn_time = time.time()


        # Save model
        if iteration % 100 == 0 and iteration > 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Reduce learning rate
        if reduce_lr and iteration % 100 == 0 and iteration > 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.9

        # Move data to GPU
        for key in batch_data_dict.keys():
            if key in ['feature', 'target']:
                batch_data_dict[key] = move_data_to_gpu(
                    batch_data_dict[key], cuda)

        # Train (loop over audio segments; avoid reusing the fold index name `i`)
        for seg in range(audio_num):
            model.train()
            data, target_a, target_b, lam = mixup_data(
                x=batch_data_dict['feature'][:, seg, :, :],
                y=batch_data_dict['target'],
                alpha=0.2)
            batch_output = model(data)
            #         batch_output = model(batch_data_dict['feature'])
            # loss: the plain loss against the un-mixed target is superseded by the mixup loss
            # loss = loss_func(batch_output, batch_data_dict['target'])
            loss = mixup_criterion(loss_func, batch_output, target_a, target_b,
                                   lam)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Stop learning
        if iteration == 4000:
            break

        iteration += 1
Example 21
def valid(loader, model, criterion_cls, criterion_ranking, optimizer, epoch, history, logger, args):
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    total_losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    cls_losses = utils.AverageMeter()          ## cross entropy loss
    ranking_losses = utils.AverageMeter()      ## margin ranking loss
    end = time.time()

    print("*** Valid ***")
    model.eval()

    all_idx = []
    all_iscorrect = []
    all_confidence = []
    all_target = []

    ## store the original images and labels

    for i, (input, target, idx) in enumerate(loader):   ## batchsize = 128
    # for i, (input, target) in enumerate(loader):   ## batchsize = 128
        with torch.no_grad():
            data_time.update(time.time() - end)
            input, target = input.cuda(), target.cuda()
            confidence = []
            all_idx.extend(idx.tolist())
            all_target.extend(target.tolist())

            ##mixup
            if args.mixup is not None:
                input, target_a, target_b, lam = utils.mixup_data(input, target, args.mixup, True)
                input, target_a, target_b = map(Variable, (input, target_a, target_b))

            output = model(input)

            # NaN alert
            assert torch.all(output == output)

            # compute ranking target value normalize (0 ~ 1) range
            # max(softmax)
            if args.rank_target == 'softmax':
                conf = F.softmax(output, dim=1)
                confidence, prediction = conf.max(dim=1)        ## prediction: the predicted class, confidence: its softmax score

            # entropy
            elif args.rank_target == 'entropy':
                if args.data == 'cifar100':
                    value_for_normalizing = 4.605170
                else:
                    value_for_normalizing = 2.302585
                confidence = crl_utils.negative_entropy(output,
                                                        normalize=True,
                                                        max_value=value_for_normalizing)
            # margin
            elif args.rank_target == 'margin':
                conf, _ = torch.topk(F.softmax(output, dim=1), 2, dim=1)
                conf[:,0] = conf[:,0] - conf[:,1]
                confidence = conf[:,0]

            # make input pair
            rank_input1 = confidence
            rank_input2 = torch.roll(confidence, -1)
            idx2 = torch.roll(idx, -1)

            # calc target, margin
            rank_target, rank_margin, norm_cor = history.get_target_margin(idx, idx2) ## rank_target: which confidence is larger (1, 0, -1) / rank_margin: difference in correct-prediction counts

            rank_target_nonzero = rank_target.clone()
            rank_target_nonzero[rank_target_nonzero == 0] = 1 ## replace all zeros in rank_target with 1
            rank_input2 = rank_input2 + rank_margin / rank_target_nonzero
            ranking_loss = criterion_ranking(rank_input1,
                                             rank_input2,
                                             rank_target)

            # total loss
            if args.mixup is not None:
                cls_loss = utils.mixup_criterion(criterion_cls, output, target_a, target_b, lam)
            else:
                cls_loss = criterion_cls(output, target)

            ranking_loss = args.rank_weight * ranking_loss
            loss = cls_loss + ranking_loss

        # record loss and accuracy
        prec, correct = utils.accuracy(output, target)

        all_iscorrect.extend(map(int, correct))
        all_confidence.extend(confidence.tolist())
        total_losses.update(loss.item(), input.size(0))
        cls_losses.update(cls_loss.item(), input.size(0))
        ranking_losses.update(ranking_loss.item(), input.size(0))
        top1.update(prec.item(), input.size(0))


        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('[{0}][{1}/{2}] '
                  'Time {batch_time.val:.3f}({batch_time.avg:.3f}) '
                  'Data {data_time.val:.3f}({data_time.avg:.3f}) '
                  'Loss {loss.val:.4f}({loss.avg:.4f}) '
                  'CLS Loss {cls_loss.val:.4f}({cls_loss.avg:.4f}) '
                  'Rank Loss {rank_loss.val:.4f}({rank_loss.avg:.4f}) '
                  'Prec {top1.val:.2f}%({top1.avg:.2f}%)'.format(
                   epoch, i, len(loader), batch_time=batch_time,
                   data_time=data_time, loss=total_losses, cls_loss=cls_losses,
                   rank_loss=ranking_losses,top1=top1))

        # history.confidence_update(idx, correct, output)


    # max correctness update
    # history.max_correctness_update(epoch)
    logger.write([epoch, total_losses.avg, cls_losses.avg, ranking_losses.avg, top1.avg])

    return all_idx, all_iscorrect, all_confidence, all_target, total_losses, prec.item()
Example 22
    def train(self, epoch):
        batch_time = AverageMeter()
        losses = AverageMeter()
        acc = AverageMeter()
        self.scheduler.step()
        self.model.train()

        end = time.time()
        lr = self.scheduler.get_lr()[0]

        # for batch, (softmax_data, triplet_data) in enumerate(itertools.zip_longest(self.softmax_train_loader, self.triplet_train_loader)):
        for batch, (softmax_data, triplet_data) in enumerate(
                zip(self.softmax_train_loader, self.triplet_train_loader)):
            loss = 0
            softmax_inputs, softmax_labels = softmax_data
            # move to CUDA
            softmax_inputs = softmax_inputs.to(
                self.device
            ) if torch.cuda.device_count() >= 1 else softmax_inputs
            softmax_labels = softmax_labels.to(
                self.device
            ) if torch.cuda.device_count() >= 1 else softmax_labels

            # softmax_score, softmax_outputs = self.model(softmax_inputs)
            # traditional_loss = self.softmax_loss(softmax_score, softmax_outputs, softmax_labels)
            # loss += traditional_loss

            inputs, targets_a, targets_b, lam = mixup_data(softmax_inputs,
                                                           softmax_labels,
                                                           alpha=opt.alpha)
            # inputs, targets_a, targets_b = Variable(inputs), Variable(targets_a), Variable(targets_b)
            softmax_score, softmax_outputs = self.model(inputs)  # feed the mixed batch, not the raw softmax_inputs
            loss_func = mixup_criterion(targets_a, targets_b, lam)
            mixup_loss = loss_func(criterion, softmax_score)
            loss += mixup_loss

            losses.update(loss.item(), softmax_inputs.size(0))
            prec = (softmax_score.max(1)[1] == softmax_labels).float().mean()
            acc.update(prec, softmax_inputs.size(0))

            triplet_inputs, triplet_labels = triplet_data
            # move to CUDA
            triplet_inputs = triplet_inputs.to(
                self.device
            ) if torch.cuda.device_count() >= 1 else triplet_inputs
            triplet_labels = triplet_labels.to(
                self.device
            ) if torch.cuda.device_count() >= 1 else triplet_labels
            triplet_score, triplet_outputs = self.model(triplet_inputs)
            triplet_loss = self.triplet_loss(triplet_score, triplet_outputs,
                                             triplet_labels)
            loss += triplet_loss

            self.optimizer.zero_grad()
            if opt.fp16:  # use the optimizer (apex amp) to backward the scaled loss
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            # measure elapsed time per batch
            batch_time.update(time.time() - end)
            end = time.time()

            # log timing and results
            if (batch + 1) % 10 == 0:
                logger.debug(
                    'Epoch: [{}][{}/{}]\t'
                    'Base_lr: [{:.2e}]\t'
                    'Time: ({batch_time.avg:.3f})\t'
                    'Loss_val: {loss.val:.4f}  (Loss_avg: {loss.avg:.4f})\t'
                    'Accuracy_val: {acc.val:.4f}  (Accuracy_avg: {acc.avg:.4f})'.
                    format(epoch,
                           batch + 1,
                           len(self.softmax_train_loader),
                           lr,
                           batch_time=batch_time,
                           loss=losses,
                           acc=acc))

        # per-epoch summary
        log_text = 'Epoch[{}]\tBase_lr {:.2e}\tAccuracy {acc.avg:.4f}\tLoss {loss.avg:.4f}'.format(
            epoch, lr, acc=acc, loss=losses)
        logger.info(log_text)
        with open(log_file, 'a') as f:
            f.write(log_text + '\n')
            f.flush()
Example 23
def train(train_loader, net, criterion, optimizer, epoch, device,\
          layer_inputs, layer_outputs, grad_inputs, grad_outputs, layers, crit, groups):
    global writer

    start = time.time()
    net.train()

    train_loss = 0
    correct = 0
    total = 0
    eps = 0.001
    logger.info(" === Epoch: [{}/{}] === ".format(epoch + 1, config.epochs))

    for batch_index, (inputs, targets) in enumerate(train_loader):
        # move tensor to GPU
        inputs, targets = inputs.to(device), targets.to(device)
        inputs.requires_grad = True
        layer_inputs.clear()
        layer_outputs.clear()
        grad_inputs.clear()
        grad_outputs.clear()
        if config.mixup:
            inputs, targets_a, targets_b, lam = mixup_data(
                inputs, targets, config.mixup_alpha, device)

            outputs = net(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b,
                                   lam)
        else:
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        # zero the gradient buffers
        optimizer.zero_grad()
        # backward
        loss.backward()

        #fgsm
        # for p in net.parameters():
        #     p.grad *= args.alpha
        # adv_input = inputs + eps * inputs.grad.sign()
        #
        # outputs = net(adv_input)
        #
        # loss_2 = (1-args.alpha) * criterion(outputs, targets)
        # loss_2.backward()

        # layer_loss = update_grad(net, layer_inputs, layer_outputs, grad_inputs, grad_outputs, layers, crit, args.alpha)
        layer_loss = group_noise(net, groups, crit, args.alpha)
        optimizer.step()

        # count the loss and acc
        train_loss += args.alpha * loss.item() + (1 - args.alpha) * layer_loss
        _, predicted = outputs.max(1)
        total += targets.size(0)
        if config.mixup:
            correct += (lam * predicted.eq(targets_a).sum().item() +
                        (1 - lam) * predicted.eq(targets_b).sum().item())
        else:
            correct += predicted.eq(targets).sum().item()

        if (batch_index + 1) % 100 == 0:
            logger.info(
                "   == step: [{:3}/{}], train loss: {:.3f} | train acc: {:6.3f}% | lr: {:.6f}"
                .format(batch_index + 1, len(train_loader),
                        train_loss / (batch_index + 1),
                        100.0 * correct / total, get_current_lr(optimizer)))

    logger.info(
        "   == step: [{:3}/{}], train loss: {:.3f} | train acc: {:6.3f}% | lr: {:.6f}"
        .format(batch_index + 1, len(train_loader),
                train_loss / (batch_index + 1), 100.0 * correct / total,
                get_current_lr(optimizer)))

    end = time.time()
    logger.info("   == cost time: {:.4f}s".format(end - start))
    train_loss = train_loss / (batch_index + 1)
    train_acc = correct / total

    writer.add_scalar('train_loss', train_loss, global_step=epoch)
    writer.add_scalar('train_acc', train_acc, global_step=epoch)

    return train_loss, train_acc
Example 24
def main_worker(args, logger):
    try:
        writer = SummaryWriter(logdir=args.sub_tensorboard_dir)

        train_set = RSDataset(rootpth=args.data_dir, mode='train')
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

        # Per-sample selection weights (probability of each sample being drawn). Weighted resampling did not help, so it is disabled, but kept here as an example for future reference.
        # sampler_weight = train_set.get_sampler_weight()
        #
        # train_sampler = WeightedRandomSampler(sampler_weight,
        #                                 num_samples=100000,     # 每次循环,使用的样本数量
        #                                 replacement=True)
        #
        # train_loader = DataLoader(train_set,
        #                           batch_size=args.batch_size,
        #                           pin_memory=True,
        #                           num_workers=args.num_workers,
        #                           sampler=train_sampler)

        val_set = RSDataset(rootpth=args.data_dir, mode='val')
        val_loader = DataLoader(val_set,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

        net = Dense201()
        logger.info('net name: {}'.format(net.__class__.__name__))
        net.train()
        input_ = torch.randn((1, 3, 224, 224))
        writer.add_graph(net, input_)
        net = net.cuda()
        criterion = nn.CrossEntropyLoss().cuda()

        if args.pre_epoch:
            # Pre-training: freeze the earlier layers and train only the newly added fully connected layer
            for name, param in net.named_parameters():
                if 'classifier' not in name:
                    param.requires_grad = False

            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         net.parameters()),
                                  lr=args.base_lr,
                                  momentum=0.9,
                                  nesterov=args.sgdn,
                                  weight_decay=args.weight_decay)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=args.pre_epoch * len(train_loader),
                eta_min=args.min_lr)

        loss_record = []
        iter_step = 0
        running_loss = []
        st = glob_st = time.time()
        total_epoch = args.pre_epoch + args.warmup_epoch + args.normal_epoch
        total_iter_step = len(train_loader) * total_epoch

        logger.info('len(train_set): {}'.format(len(train_set)))
        logger.info('len(train_loader): {}'.format(len(train_loader)))
        logger.info('len(val_set): {}'.format(len(val_set)))
        logger.info('len(val_loader): {}'.format(len(val_loader)))
        logger.info('total_epoch: {}'.format(total_epoch))
        logger.info('total_iter_step: {}'.format(total_iter_step))

        if args.pre_epoch:
            logger.info('----- start pre train ------')
        for epoch in range(total_epoch):

            # evaluation
            # if epoch % args.eval_fre == 0 and epoch!=0 :
            if epoch % args.eval_fre == 0:
                evalute(net, val_loader, writer, epoch, logger)

            # save a checkpoint
            if epoch % args.save_fre == 0 and epoch > args.save_after:
                model_out_name = osp.join(args.sub_model_out_dir,
                                          'out_{}.pth'.format(epoch))
                # guard against save failures when the model is wrapped (e.g. DataParallel)
                state_dict = net.module.state_dict() if hasattr(
                    net, 'module') else net.state_dict()
                torch.save(state_dict, model_out_name)

            # Pre-training is over: train all parameters and rebuild the optimizer, but apply weight decay only to the multiplicative weights of Linear and Conv layers
            if epoch == args.pre_epoch:
                for param in net.parameters():
                    param.requires_grad = True

                wd_params, nowd_params = [], []
                for name, module in net.named_modules():
                    if isinstance(module, (nn.Linear, nn.Conv2d)):
                        wd_params.append(module.weight)
                        if not module.bias is None:
                            nowd_params.append(module.bias)
                    # TODO: could this kind of param list miss some parameters?
                    elif isinstance(module, nn.BatchNorm2d):
                        nowd_params += list(module.parameters())
                    # else:
                    #     nowd_params += list(module.parameters())
                param_list = [{
                    'params': wd_params
                }, {
                    'params': nowd_params,
                    'weight_decay': 0
                }]

                optimizer = optim.SGD(param_list,
                                      lr=args.base_lr,
                                      momentum=0.9,
                                      nesterov=args.sgdn,
                                      weight_decay=args.weight_decay)
                # rebuild the learning-rate scheduler
                if args.warmup_epoch:
                    scheduler = LinearScheduler(optimizer,
                                                start_lr=args.min_lr,
                                                end_lr=args.base_lr,
                                                all_steps=args.warmup_epoch *
                                                len(train_loader))
                    logger.info(
                        '-------- start warmup for {} epochs -------'.format(
                            args.warmup_epoch))

            # once normal training starts, build a new scheduler
            if epoch == args.pre_epoch + args.warmup_epoch:
                scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    optimizer,
                    T_max=args.normal_epoch * len(train_loader),
                    eta_min=args.min_lr)
                logger.info('---- start normal train for {} epoch ----'.format(
                    args.normal_epoch))

            for img, lb in train_loader:
                iter_step += 1
                img = img.cuda()
                lb = lb.cuda()

                optimizer.zero_grad()

                inputs, targets_a, targets_b, lam = mixup_data(
                    img, lb, args.mixup_alpha)
                outputs = net(inputs)
                loss = mixup_criterion(criterion, outputs, targets_a,
                                       targets_b, lam)
                # outputs = net(img)
                # loss = criterion(outputs, lb)

                loss.backward()
                optimizer.step()
                scheduler.step()

                running_loss.append(loss.item())

                if iter_step % args.msg_fre == 0:
                    ed = time.time()
                    spend = ed - st
                    global_spend = ed - glob_st
                    st = ed

                    eta = int((total_iter_step - iter_step) *
                              (global_spend / iter_step))
                    eta = str(datetime.timedelta(seconds=eta))
                    global_spend = str(
                        datetime.timedelta(seconds=(int(global_spend))))

                    avg_loss = np.mean(running_loss)
                    loss_record.append(avg_loss)
                    running_loss = []

                    lr = optimizer.param_groups[0]['lr']

                    msg = '. '.join([
                        'epoch:{epoch}', 'iter/total_iter:{iter}/{total_iter}',
                        'lr:{lr:.7f}', 'loss:{loss:.4f}',
                        'spend/global_spend:{spend:.4f}/{global_spend}',
                        'eta:{eta}'
                    ]).format(epoch=epoch,
                              iter=iter_step,
                              total_iter=total_iter_step,
                              lr=lr,
                              loss=avg_loss,
                              spend=spend,
                              global_spend=global_spend,
                              eta=eta)
                    logger.info(msg)
                    writer.add_scalar('loss', avg_loss, iter_step)
                    writer.add_scalar('lr', lr, iter_step)

        # run one final evaluation after training
        evalute(net, val_loader, writer, args.pre_epoch + args.normal_epoch,
                logger)

        out_name = osp.join(args.sub_model_out_dir, args.model_out_name)
        torch.save(net.cpu().state_dict(), out_name)

        logger.info('-----------Done!!!----------')

    except Exception:
        logger.exception('Exception logged')
    finally:
        writer.close()
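
The loop above assumes mixup_data and mixup_criterion helpers with the call signatures used inside the batch loop: mixup_data(img, lb, args.mixup_alpha) returning mixed inputs plus the two label tensors and the mixing weight, and mixup_criterion(criterion, outputs, targets_a, targets_b, lam) combining the two losses. A minimal sketch of such helpers, assuming the standard mixup formulation (Beta-sampled weight, random in-batch pairing); the names and defaults below are illustrative, not this repository's actual implementation:

import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    """Return mixed inputs, the two label tensors and the mixing weight lam."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)  # random pairing within the batch
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Convex combination of the loss computed against both label sets."""
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)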
Esempio n. 25
0
def main(args):
    if args["model_type"] == "normal":
        load_robust = False
    else:
        load_robust = True
    simple_target_model = args[
        "simple_target_model"]  # if true, target model is simple CIAR10 model (LeNet)
    simple_local_model = True  # if true, local models are simple CIFAR10 models (LeNet)

    # Set TF random seed to improve reproducibility
    tf.set_random_seed(args["seed"])
    data = CIFAR()
    if not hasattr(K, "tf"):
        raise RuntimeError("This tutorial requires keras to be configured"
                           " to use the TensorFlow backend.")

    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to "
              "'th', temporarily setting to 'tf'")

    # Create TF session and set as Keras backend session
    sess = tf.Session()
    keras.backend.set_session(sess)

    x_test, y_test = CIFAR().test_data, CIFAR().test_labels

    all_trans_rate_ls = []  # store transfer rate of all seeds
    remain_trans_rate_ls = [
    ]  # store transfer rate of remaining seeds, used only in local model fine-tuning

    # Define input TF placeholders
    class_num = 10
    image_size = 32
    num_channels = 3
    test_batch_size = 100
    x = tf.placeholder(tf.float32,
                       shape=(None, image_size, image_size, num_channels))
    y = tf.placeholder(tf.float32, shape=(None, class_num))
    # required by the local robust densenet model
    is_training = tf.placeholder(tf.bool, shape=[])
    keep_prob = tf.placeholder(tf.float32)
    ########################### load the target model ##########################################
    if not load_robust:
        if simple_target_model:
            target_model_name = 'modelA'
            target_model = cifar10_models_simple(sess,test_batch_size, 0, use_softmax=True,x = x, y = y,\
            load_existing=True,model_name=target_model_name)
        else:
            target_model_name = 'densenet'
            target_model = cifar10_models(sess,0,test_batch_size = test_batch_size,use_softmax=True,x = x, y = y,\
            load_existing=True,model_name=target_model_name)
        accuracy = target_model.calcu_acc(x_test, y_test)
        print('Test accuracy of target model {}: {:.4f}'.format(
            target_model_name, accuracy))
    else:
        if args["robust_type"] == "madry":
            target_model_name = 'madry_robust'
            model_dir = "CIFAR10_models/Robust_Deep_models/Madry_robust_target_model"  # TODO: pur your own madry robust target model directory here
            target_model = Load_Madry_Model(sess,
                                            model_dir,
                                            bias=0.5,
                                            scale=255)
        elif args["robust_type"] == "zico":
            # Note: the zico cifar10 model will be added in the future
            target_model_name = 'zico_robust'
            model_dir = ""  # TODO: put your own robust zico target model directory here
            target_model = Load_Zico_Model(model_dir=model_dir,
                                           bias=0.5,
                                           scale=255)
        else:
            raise NotImplementedError
        corr_preds = target_model.correct_prediction(x_test,
                                                     np.argmax(y_test, axis=1))
        print('Test accuracy of target robust model :{:.4f}'.format(
            np.sum(corr_preds) / len(x_test)))
    ##################################### end of load target model ###################################
    local_model_names = args["local_model_names"]
    robust_indx = []
    normal_local_types = []
    for loc_model_name in local_model_names:
        if loc_model_name == "adv_densenet" or loc_model_name == "adv_vgg" or loc_model_name == "adv_resnet":
            # normal_local_types.append(0)
            robust_indx.append(1)
        else:
            robust_indx.append(0)
            if loc_model_name == "modelB":
                normal_local_types.append(1)
            elif loc_model_name == "modelD":
                normal_local_types.append(3)
            elif loc_model_name == "modelE":
                normal_local_types.append(4)
    print("robust index: ", robust_indx)
    print("normal model types:", normal_local_types)

    local_model_folder = ''
    for ii in range(len(local_model_names)):
        if ii != len(local_model_names) - 1:
            local_model_folder += local_model_names[ii] + '_'
        else:
            local_model_folder += local_model_names[ii]

    nb_imgs = args["num_img"]
    # local model attack related params
    clip_min = -0.5
    clip_max = 0.5
    li_eps = args["cost_threshold"]
    alpha = 1.0
    k = 100
    a = 0.01

    load_existing = True  # load pretrained local models; if false, randomly initialized models are used
    with_local = args[
        "with_local"]  # if true, hybrid attack, otherwise, only baseline attacks
    if args["no_tune_local"]:
        stop_fine_tune_flag = True
        load_existing = True
    else:
        stop_fine_tune_flag = False

    if with_local:
        if load_existing:
            loc_adv = 'adv_with_tune'
        if args["no_tune_local"]:
            loc_adv = 'adv_no_tune'
    else:
        loc_adv = 'orig'

    # target type
    if args["attack_type"] == "targeted":
        is_targeted = True
    else:
        is_targeted = False

    sub_epochs = args["nb_epochs_sub"]  # epcohs for local model training
    use_loc_adv_thres = args[
        "use_loc_adv_thres"]  # threshold for transfer attack success rate, it is used when we need to start from local adversarial seeds
    use_loc_adv_flag = True  # flag for using local adversarial examples
    fine_tune_freq = args[
        "fine_tune_freq"]  # fine-tune the model every K images to save total model training time

    # store the attack input files (e.g., original image, target class)
    input_file_prefix = os.path.join(args["local_path"], target_model_name,
                                     args["attack_type"])
    os.system("mkdir -p {}".format(input_file_prefix))
    # save locally generated information
    local_info_file_prefix = os.path.join(args["local_path"],
                                          target_model_name,
                                          args["attack_type"],
                                          local_model_folder,
                                          str(args["seed"]))
    os.system("mkdir -p {}".format(local_info_file_prefix))
    # attack_input_file_prefix = os.path.join(args["local_path"],target_model_name,
    # 											args["attack_type"])
    # save bbox attack information
    out_dir_prefix = os.path.join(args["save_path"], args["attack_method"],
                                  target_model_name, args["attack_type"],
                                  local_model_folder, str(args["seed"]))
    os.system("mkdir -p {}".format(out_dir_prefix))

    #### generate the original images and target classes ####
    target_ys_one_hot,orig_images,target_ys,orig_labels,_, trans_test_images = \
    generate_attack_inputs(sess,target_model,x_test,y_test,class_num,nb_imgs,\
     load_imgs=args["load_imgs"],load_robust=load_robust,\
      file_path = input_file_prefix)
    #### end of generating original images and target classes ####

    start_points = np.copy(
        orig_images)  # either start from orig seed or local advs
    # store attack statistical info
    dist_record = np.zeros(len(orig_labels), dtype=float)
    query_num_vec = np.zeros(len(orig_labels), dtype=int)
    success_vec = np.zeros(len(orig_labels), dtype=bool)
    adv_classes = np.zeros(len(orig_labels), dtype=int)

    # local model related variables
    if simple_target_model:
        local_model_file_name = "cifar10_simple"
    elif load_robust:
        local_model_file_name = "cifar10_robust"
    else:
        local_model_file_name = "cifar10"
    # save_dir = 'model/'+local_model_file_name + '/'
    callbacks_ls = []
    attacked_flag = np.zeros(len(orig_labels), dtype=bool)

    local_model_ls = []
    if with_local:
        ###################### start loading local models ###############################
        local_model_names_all = []  # stores the complete local model names
        sss = 0
        for model_name in local_model_names:
            if model_name == "adv_densenet" or model_name == "adv_vgg" or model_name == "adv_resnet":
                # tensorflow-based robust local models
                loc_model = cifar10_tf_robust_models(sess, test_batch_size = test_batch_size, x = x,y = y, is_training=is_training,keep_prob=keep_prob,\
                 load_existing = True, model_name = model_name,loss = args["loss_function"])
                accuracy = loc_model.calcu_acc(x_test, y_test)
                local_model_ls.append(loc_model)
                print('Test accuracy of model {}: {:.4f}'.format(
                    model_name, accuracy))
                sss += 1
            else:
                # keras based local normal models
                if simple_local_model:
                    type_num = normal_local_types[sss]
                if model_name == 'resnet_v1' or model_name == 'resnet_v2':
                    depth_s = [20, 50, 110]
                else:
                    depth_s = [0]
                for depth in depth_s:
                    # model_name used for loading models
                    if model_name == 'resnet_v1' or model_name == 'resnet_v2':
                        model_load_name = model_name + str(depth)
                    else:
                        model_load_name = model_name
                    local_model_names_all.append(model_load_name)
                    if not simple_local_model:
                        loc_model = cifar10_models(sess,depth,test_batch_size = test_batch_size,use_softmax = True, x = x,y = y,\
                        load_existing = load_existing, model_name = model_name,loss = args["loss_function"])
                    else:
                        loc_model = cifar10_models_simple(sess,test_batch_size,type_num,use_softmax = True, x = x,y = y,\
                        is_training=is_training,keep_prob=keep_prob,load_existing = load_existing, model_name = model_name, loss = args["loss_function"])
                    local_model_ls.append(loc_model)

                    opt = keras.optimizers.SGD(lr=0.01,
                                               decay=1e-6,
                                               momentum=0.9,
                                               nesterov=True)
                    loc_model.model.compile(loss='categorical_crossentropy',
                                            optimizer=opt,
                                            metrics=['accuracy'])
                    orig_images_nw = orig_images
                    orig_labels_nw = orig_labels
                    if args["no_save_model"]:
                        if not load_existing:
                            loc_model.model.fit(
                                orig_images_nw,
                                orig_labels_nw,
                                batch_size=args["train_batch_size"],
                                epochs=sub_epochs,
                                verbose=0,
                                validation_data=(x_test, y_test),
                                shuffle=True)
                    else:
                        print(
                            "Saving local model is yet to be implemented, please check back later, system exiting!"
                        )
                        sys.exit(0)
                        # TODO: fix the issue of loading pretrained model first and then finetune the model
                        # if load_existing:
                        # 	filepath = save_dir + model_load_name + '_pretrained.h5'
                        # else:
                        # 	filepath = save_dir + model_load_name + '.h5'
                        # checkpoint = ModelCheckpoint(filepath=filepath,
                        # 							monitor='val_acc',
                        # 							verbose=0,
                        # 							save_best_only=True)
                        # callbacks = [checkpoint]
                        # callbacks_ls.append(callbacks)
                        # if not load_existing:
                        # 	print("Train on %d data and validate on %d data" % (len(orig_labels_nw),len(y_test)))
                        # 	loc_model.model.fit(orig_images_nw, orig_labels_nw,
                        # 		batch_size=args["train_batch_size"],
                        # 		epochs=sub_epochs,
                        # 		verbose=0,
                        # 		validation_data=(x_test, y_test),
                        # 		shuffle = True,
                        # 		callbacks = callbacks)
                    scores = loc_model.model.evaluate(x_test,
                                                      y_test,
                                                      verbose=0)
                    accuracy = scores[1]
                    print('Test accuracy of model {}: {:.4f}'.format(
                        model_load_name, accuracy))
                    sss += 1
        ##################### end of loading local models ######################################

        ##################### Define Attack Graphs of local PGD attack ###############################
        local_attack_graph = LinfPGDAttack(local_model_ls,
                                           epsilon=li_eps,
                                           k=k,
                                           a=a,
                                           random_start=True,
                                           loss_func=args["loss_function"],
                                           targeted=is_targeted,
                                           robust_indx=robust_indx,
                                           x=x,
                                           y=y,
                                           is_training=is_training,
                                           keep_prob=keep_prob)

        ##################### end of defining graphs of PGD attack ##########################

        ##################### generate local adversarial examples and also store the local attack information #####################
        if not args["load_local_AEs"]:
            # first do the transfer check to obtain local adversarial samples
            # generated local info can be used for batch attacks,
            # max_loss, min_loss, max_gap, min_gap etc are other metrics we explored for scheduling seeds based on local information
            if is_targeted:
                all_trans_rate, pred_labs, local_aes,pgd_cnt_mat, max_loss, min_loss, ave_loss, max_gap, min_gap, ave_gap\
                  = local_attack_in_batches(sess,start_points[np.logical_not(attacked_flag)],\
                target_ys_one_hot[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
            else:
                all_trans_rate, pred_labs, local_aes,pgd_cnt_mat, max_loss, min_loss, ave_loss, max_gap, min_gap, ave_gap\
                  = local_attack_in_batches(sess,start_points[np.logical_not(attacked_flag)],\
                orig_labels[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
            # calculate local adv loss used for scheduling seeds in batch attack...
            if is_targeted:
                adv_img_loss, free_idx = compute_cw_loss(sess,target_model,local_aes,\
                target_ys_one_hot,targeted=is_targeted,load_robust=load_robust)
            else:
                adv_img_loss, free_idx = compute_cw_loss(sess,target_model,local_aes,\
                orig_labels,targeted=is_targeted,load_robust=load_robust)

            # calculate orig img loss for scheduling seeds in baseline attack
            if is_targeted:
                orig_img_loss, free_idx = compute_cw_loss(sess,target_model,orig_images,\
                target_ys_one_hot,targeted=is_targeted,load_robust=load_robust)
            else:
                orig_img_loss, free_idx = compute_cw_loss(sess,target_model,orig_images,\
                orig_labels,targeted=is_targeted,load_robust=load_robust)

            pred_labs = np.argmax(target_model.predict_prob(local_aes), axis=1)
            if is_targeted:
                transfer_flag = np.argmax(target_ys_one_hot,
                                          axis=1) == pred_labs
            else:
                transfer_flag = np.argmax(orig_labels, axis=1) != pred_labs
            # save local aes
            np.save(local_info_file_prefix + '/local_aes.npy', local_aes)
            # store local info of local aes and original seeds: used for scheduling seeds in batch attacks
            np.savetxt(local_info_file_prefix + '/pgd_cnt_mat.txt',
                       pgd_cnt_mat)
            np.savetxt(local_info_file_prefix + '/orig_img_loss.txt',
                       orig_img_loss)
            np.savetxt(local_info_file_prefix + '/adv_img_loss.txt',
                       adv_img_loss)
            np.savetxt(local_info_file_prefix + '/ave_gap.txt', ave_gap)
        else:
            local_aes = np.load(local_info_file_prefix + '/local_aes.npy')
            if is_targeted:
                tmp_labels = target_ys_one_hot
            else:
                tmp_labels = orig_labels
            pred_labs = np.argmax(target_model.predict_prob(
                np.array(local_aes)),
                                  axis=1)
            print('correct number',
                  np.sum(pred_labs == np.argmax(tmp_labels, axis=1)))
            all_trans_rate = accuracy_score(np.argmax(tmp_labels, axis=1),
                                            pred_labs)
        ################################ end of generating local AEs and storing related information #######################################

        if not is_targeted:
            all_trans_rate = 1 - all_trans_rate
        print('** Transfer Rate: **' + str(all_trans_rate))

        if all_trans_rate > use_loc_adv_thres:
            print("Updated the starting points to local AEs....")
            start_points[np.logical_not(attacked_flag)] = local_aes
            use_loc_adv_flag = True

        # independent test set for checking transferability: for experiment purposes only; it does not count toward query numbers
        if is_targeted:
            ind_all_trans_rate,_,_,_,_,_,_,_,_,_ = local_attack_in_batches(sess,trans_test_images,target_ys_one_hot,eval_batch_size = test_batch_size,\
            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
        else:
            ind_all_trans_rate,_,_,_,_,_,_,_,_,_ = local_attack_in_batches(sess,trans_test_images,orig_labels,eval_batch_size = test_batch_size,\
            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)

        # record the queries spent by querying the local samples
        query_num_vec[np.logical_not(attacked_flag)] += 1
        if not is_targeted:
            ind_all_trans_rate = 1 - ind_all_trans_rate
        print('** (Independent Set) Transfer Rate: **' +
              str(ind_all_trans_rate))
        all_trans_rate_ls.append(ind_all_trans_rate)

    S = np.copy(start_points)
    S_label = target_model.predict_prob(S)
    S_label_cate = np.argmax(S_label, axis=1)
    S_label_cate = np_utils.to_categorical(S_label_cate, class_num)

    pre_free_idx = []
    candi_idx_ls = []  # store the indices of images in the order attacked

    # these parameters are used to make sure an equal number of instances from each class is selected
    # so that the diversity of the fine-tuning set is improved. However, it is not effective...
    per_cls_cnt = 0
    cls_order = 0
    change_limit = False
    max_lim_num = int(fine_tune_freq / class_num)

    # define the autozoom bbox attack graph
    if args["attack_method"] == "autozoom":
        # setup the autoencoders for autozoom attack
        codec = 0
        args["img_resize"] = 8
        # replace with your directory
        codec_dir = 'CIFAR10_models/cifar10_autoencoder/'  # TODO: replace with your own cifar10 autoencoder directory
        encoder = load_model(codec_dir + 'whole_cifar10_encoder.h5')
        decoder = load_model(codec_dir + 'whole_cifar10_decoder.h5')

        encode_img = encoder.predict(data.test_data[100:101])
        decode_img = decoder.predict(encode_img)
        diff_img = (decode_img - data.test_data[100:101])
        diff_mse = np.mean(diff_img.reshape(-1)**2)

        # diff_mse = np.mean(np.sum(diff_img.reshape(-1,784)**2,axis = 1))
        print("[Info][AE] MSE:{:.4f}".format(diff_mse))
        encode_img = encoder.predict(data.test_data[0:1])
        decode_img = decoder.predict(encode_img)
        diff_img = (decode_img - data.test_data[0:1])
        diff_mse = np.mean(diff_img.reshape(-1)**2)
        print("[Info][AE] MSE:{:.4f}".format(diff_mse))

    if args["attack_method"] == "autozoom":
        # define black-box model graph of autozoom
        autozoom_graph = AutoZOOM(sess, target_model, args, decoder, codec,
                                  num_channels, image_size, class_num)

    # main loop of hybrid attacks
    for itr in range(len(orig_labels)):
        print("#------------ Substitue training round {} ----------------#".
              format(itr))
        # compute loss functions of seeds: no query is needed here because the seeds were already queried before...
        if is_targeted:
            img_loss, free_idx = compute_cw_loss(sess,target_model,start_points,\
            target_ys_one_hot,targeted=is_targeted,load_robust=load_robust)
        else:
            img_loss, free_idx = compute_cw_loss(sess,target_model,start_points,\
            orig_labels,targeted=is_targeted,load_robust=load_robust)
        free_idx_diff = list(set(free_idx) - set(pre_free_idx))
        print("new free idx found:", free_idx_diff)
        if len(free_idx_diff) > 0:
            candi_idx_ls.extend(free_idx_diff)
        pre_free_idx = free_idx
        if with_local:
            if len(free_idx) > 0:
                # free attacks are found
                attacked_flag[free_idx] = 1
                success_vec[free_idx] = 1
                # update dist and adv class
                if args['dist_metric'] == 'l2':
                    dist = np.sum(
                        (start_points[free_idx] - orig_images[free_idx])**2,
                        axis=(1, 2, 3))**.5
                elif args['dist_metric'] == 'li':
                    dist = np.amax(np.abs(start_points[free_idx] -
                                          orig_images[free_idx]),
                                   axis=(1, 2, 3))
                # print(start_points[free_idx].shape)
                adv_class = target_model.pred_class(start_points[free_idx])
                adv_classes[free_idx] = adv_class
                dist_record[free_idx] = dist
                if np.amax(
                        dist
                ) >= args["cost_threshold"] + args["cost_threshold"] / 10:
                    print(
                        "there are some problems in setting the perturbation distance!"
                    )
                    sys.exit(0)
        print("Number of Unattacked Seeds: ",
              np.sum(np.logical_not(attacked_flag)))
        if attacked_flag.all():
            # early stop when all seeds are successfully attacked
            break

        # define the seed generation process as a function
        if args["sort_metric"] == "min":
            img_loss[attacked_flag] = 1e10
        elif args["sort_metric"] == "max":
            img_loss[attacked_flag] = -1e10
        candi_idx, per_cls_cnt, cls_order,change_limit,max_lim_num = select_next_seed(img_loss,attacked_flag,args["sort_metric"],\
        args["by_class"],fine_tune_freq,class_num,per_cls_cnt,cls_order,change_limit,max_lim_num)

        print(candi_idx)
        candi_idx_ls.append(candi_idx)

        input_img = start_points[candi_idx:candi_idx + 1]
        if args["attack_method"] == "autozoom":
            # encoder decoder performance check
            encode_img = encoder.predict(input_img)
            decode_img = decoder.predict(encode_img)
            diff_img = (decode_img - input_img)
            diff_mse = np.mean(diff_img.reshape(-1)**2)
        else:
            diff_mse = 0.0

        print("[Info][Start]: test_index:{}, true label:{}, target label:{}, MSE:{}".format(candi_idx, np.argmax(orig_labels[candi_idx]),\
         np.argmax(target_ys_one_hot[candi_idx]),diff_mse))

        ################## BEGIN: bbox attacks ############################
        if args["attack_method"] == "autozoom":
            # perform bbox attacks
            if is_targeted:
                x_s, ae, query_num = autozoom_attack(
                    autozoom_graph, input_img,
                    orig_images[candi_idx:candi_idx + 1],
                    target_ys_one_hot[candi_idx])
            else:
                x_s, ae, query_num = autozoom_attack(
                    autozoom_graph, input_img,
                    orig_images[candi_idx:candi_idx + 1],
                    orig_labels[candi_idx])
        else:
            if is_targeted:
                x_s, query_num, ae = nes_attack(args,target_model,input_img,orig_images[candi_idx:candi_idx+1],\
                 np.argmax(target_ys_one_hot[candi_idx]), lower = clip_min, upper = clip_max)
            else:
                x_s, query_num, ae = nes_attack(args,target_model,input_img,orig_images[candi_idx:candi_idx+1],\
                 np.argmax(orig_labels[candi_idx]), lower = clip_min, upper = clip_max)
            x_s = np.squeeze(np.array(x_s), axis=1)
        ################## END: bbox attacks ############################

        attacked_flag[candi_idx] = 1

        # fill the query info, etc
        if len(ae.shape) == 3:
            ae = np.expand_dims(ae, axis=0)
        if args['dist_metric'] == 'l2':
            dist = np.sum((ae - orig_images[candi_idx])**2)**.5
        elif args['dist_metric'] == 'li':
            dist = np.amax(np.abs(ae - orig_images[candi_idx]))
        adv_class = target_model.pred_class(ae)
        adv_classes[candi_idx] = adv_class
        dist_record[candi_idx] = dist

        if args["attack_method"] == "autozoom":
            # autozoom utilizes the query info of the attack input, which is already done at the beginning.
            added_query = query_num - 1
        else:
            added_query = query_num

        query_num_vec[candi_idx] += added_query
        if dist >= args["cost_threshold"] + args["cost_threshold"] / 10:
            print("the distance is not optimized properly")
            sys.exit(0)

        if is_targeted:
            if adv_class == np.argmax(target_ys_one_hot[candi_idx]):
                success_vec[candi_idx] = 1
        else:
            if adv_class != np.argmax(orig_labels[candi_idx]):
                success_vec[candi_idx] = 1
        if attacked_flag.all():
            print(
                "Early termination because all seeds are successfully attacked!"
            )
            break
        ##############################################################
        ## Starts the section of substitute training and local advs ##
        ##############################################################
        if with_local:
            if not stop_fine_tune_flag:
                # augment the local model training data with target model labels
                print(np.array(x_s).shape)
                print(S.shape)
                S = np.concatenate((S, np.array(x_s)), axis=0)
                S_label_add = target_model.predict_prob(np.array(x_s))
                S_label_add_cate = np.argmax(S_label_add, axis=1)
                S_label_add_cate = np_utils.to_categorical(
                    S_label_add_cate, class_num)
                S_label_cate = np.concatenate((S_label_cate, S_label_add_cate),
                                              axis=0)
                # empirically, tuning with the model's prediction probabilities gives slightly better results.
                # if your bbox attack is decision based, only use the prediction labels
                S_label = np.concatenate((S_label, S_label_add), axis=0)
                # fine-tune the model
                if itr % fine_tune_freq == 0 and itr != 0:
                    if len(S_label) > args["train_inst_lim"]:
                        curr_len = len(S_label)
                        rand_idx = np.random.choice(len(S_label),
                                                    args["train_inst_lim"],
                                                    replace=False)
                        S = S[rand_idx]
                        S_label = S_label[rand_idx]
                        S_label_cate = S_label_cate[rand_idx]
                        print(
                            "current num: %d, max train instance limit %d is reached, performed random sampling to get %d samples!"
                            % (curr_len, len(S_label), len(rand_idx)))
                    sss = 0

                    for loc_model in local_model_ls:
                        model_name = local_model_names_all[sss]
                        if args["use_mixup"]:
                            print(
                                "Updates the training data with mixup strayegy!"
                            )
                            S_nw = np.copy(S)
                            S_label_nw = np.copy(S_label)
                            S_nw, S_label_nw, _ = mixup_data(S_nw,
                                                             S_label_nw,
                                                             alpha=alpha)
                        else:
                            S_nw = S
                            S_label_nw = S_label
                        print("Train on %d data and validate on %d data" %
                              (len(S_label_nw), len(y_test)))
                        if args["no_save_model"]:
                            loc_model.model.fit(
                                S_nw,
                                S_label_nw,
                                batch_size=args["train_batch_size"],
                                epochs=sub_epochs,
                                verbose=0,
                                validation_data=(x_test, y_test),
                                shuffle=True)
                        else:
                            print(
                                "Saving local model is yet to be implemented, please check back later, system exiting!"
                            )
                            sys.exit(0)
                            # callbacks = callbacks_ls[sss]
                            # loc_model.model.fit(S_nw, S_label_nw,
                            # 	batch_size=args["train_batch_size"],
                            # 	epochs=sub_epochs,
                            # 	verbose=0,
                            # 	validation_data=(x_test, y_test),
                            # 	shuffle = True,
                            # 	callbacks = callbacks)
                        scores = loc_model.model.evaluate(x_test,
                                                          y_test,
                                                          verbose=0)
                        print('Test accuracy of model {}: {:.4f}'.format(
                            model_name, scores[1]))
                        sss += 1
                    if not attacked_flag.all():
                        # first check for not attacked seeds
                        if is_targeted:
                            remain_trans_rate, _, remain_local_aes,_, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,orig_images[np.logical_not(attacked_flag)],\
                            target_ys_one_hot[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        else:
                            remain_trans_rate, pred_labs, remain_local_aes,_, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,orig_images[np.logical_not(attacked_flag)],\
                            orig_labels[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        if not is_targeted:
                            remain_trans_rate = 1 - remain_trans_rate
                        print('<<Remaining Seed Transfer Rate>>: **' +
                              str(remain_trans_rate))
                        # if the local attacks no longer transfer at all, stop fine-tuning the local models
                        if remain_trans_rate <= 0 and use_loc_adv_flag:
                            print(
                                "No improvement from substitute training, stop fine-tuning!"
                            )
                            stop_fine_tune_flag = True

                        # transfer rate check with independent test examples
                        if is_targeted:
                            all_trans_rate, _, _, _, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,trans_test_images,target_ys_one_hot,eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        else:
                            all_trans_rate, _, _, _, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,trans_test_images,orig_labels,eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        if not is_targeted:
                            all_trans_rate = 1 - all_trans_rate
                        print('<<Overall Transfer Rate>>: **' +
                              str(all_trans_rate))

                        # if the transfer rate is not high enough, keep starting from the original seeds; switch to
                        # local advs only when the rate is high enough (useful when starting from a random model)
                        if not use_loc_adv_flag:
                            if remain_trans_rate > use_loc_adv_thres:
                                use_loc_adv_flag = True
                                print("Updated the starting points....")
                                start_points[np.logical_not(
                                    attacked_flag)] = remain_local_aes
                            # record the queries spent on checking newly generated loc advs
                            query_num_vec += 1
                        else:
                            print("Updated the starting points....")
                            start_points[np.logical_not(
                                attacked_flag)] = remain_local_aes
                            # record the queries spent on checking newly generated loc advs
                            query_num_vec[np.logical_not(attacked_flag)] += 1
                        remain_trans_rate_ls.append(remain_trans_rate)
                        all_trans_rate_ls.append(all_trans_rate)
                np.set_printoptions(precision=4)
                print("all_trans_rate:")
                print(all_trans_rate_ls)
                print("remain_trans_rate")
                print(remain_trans_rate_ls)

    # save the query information of all classes
    if not args["no_save_text"]:
        save_name_file = os.path.join(out_dir_prefix,
                                      "{}_num_queries.txt".format(loc_adv))
        np.savetxt(save_name_file, query_num_vec, fmt='%d', delimiter=' ')
        save_name_file = os.path.join(out_dir_prefix,
                                      "{}_success_flags.txt".format(loc_adv))
        np.savetxt(save_name_file, success_vec, fmt='%d', delimiter=' ')
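
The fine-tuning step inside this main() unpacks three values from mixup_data(S_nw, S_label_nw, alpha=alpha), i.e. a variant that mixes the (soft) label vectors directly instead of returning the two label sets separately. A plausible numpy sketch of that variant, assuming y holds one-hot or probability vectors; it only documents the expected interface, not this repository's implementation:

import numpy as np


def mixup_data(x, y, alpha=1.0):
    """Soft-label mixup: x has shape (N, ...), y has shape (N, num_classes).
    Returns mixed inputs, mixed label distributions and the mixing weight lam."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = np.random.permutation(len(x))  # random pairing within the set
    mixed_x = lam * x + (1.0 - lam) * x[index]
    mixed_y = lam * y + (1.0 - lam) * y[index]
    return mixed_x, mixed_y, lam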
Esempio n. 26
0
        avg_preci = 0.0
        avg_recall = 0.0

        model.train()
        model.clear_gradients()
        t0 = time.time()
        for batch_id, (x, y) in enumerate(train_loader()):
            if step < warm_steps:
                optimizer.set_lr(lrs[step])
            x.stop_gradient = False
            if c['balanced_sampling']:
                x = x.squeeze()
                y = y.squeeze()
            x = x.unsqueeze(1)
            if c['mixup']:
                mixed_x, mixed_y = mixup_data(x, y, c['mixup_alpha'])
                logits = model(mixed_x)
                loss_val = loss_fn(logits, mixed_y)
                loss_val.backward()
            else:
                logits = model(x)
                loss_val = bce_loss(logits, y)
                loss_val.backward()
            optimizer.step()
            model.clear_gradients()
            pred = F.sigmoid(logits)
            preci, recall = get_metrics(y.squeeze().numpy(), pred.numpy())
            avg_loss = (avg_loss * batch_id + loss_val.numpy()[0]) / (1 +
                                                                      batch_id)
            avg_preci = (avg_preci * batch_id + preci) / (1 + batch_id)
            avg_recall = (avg_recall * batch_id + recall) / (1 + batch_id)
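
The avg_loss / avg_preci / avg_recall updates above are the usual incremental mean: with batch_id counting previously seen batches from zero, the new average over batch_id + 1 values is (old_avg * batch_id + new_value) / (batch_id + 1). A tiny helper expressing the same update, purely for illustration:

def running_mean(prev_avg, new_value, count_seen):
    """Average of the first (count_seen + 1) values, given the average of the
    first count_seen values; matches the batch_id-indexed updates above."""
    return (prev_avg * count_seen + new_value) / (count_seen + 1)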
Esempio n. 27
0
def train(loader, model, criterion, optimizer, args, scheduler, epoch, lr):
    batch_time = utils.AverageMeter('Time', ':6.3f')
    data_time = utils.AverageMeter('Data', ':6.3f')
    losses = utils.AverageMeter()

    if isinstance(loader, torch.utils.data.dataloader.DataLoader):
        length = len(loader)
    else:
        length = getattr(loader, '_size', 0) / getattr(loader, 'batch_size', 1)
    model.train()
    if 'less_bn' in args.keyword:
        utils.custom_state(model)

    end = time.time()
    for i, data in enumerate(loader):
        if isinstance(data, list) and isinstance(data[0], dict):
            input = data[0]['data']
            target = data[0]['label'].squeeze()
        else:
            input, target = data
        data_time.update(time.time() - end)

        if args.device_ids is not None:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True).long()

        if args.mixup_enable:
            input, target_a, target_b, lam = utils.mixup_data(
                input,
                target,
                args.mixup_alpha,
                use_cuda=(args.device_ids is not None))

        if 'sgdr' in args.lr_policy and scheduler is not None and torch.__version__ < "1.0.4" and epoch < args.epochs:
            scheduler.step()
            for group in optimizer.param_groups:
                if 'lr_constant' in group:
                    group['lr'] = group['lr_constant']
            lr_list = scheduler.get_lr()
            if isinstance(lr_list, list):
                lr = lr_list[0]

        outputs = model(input)
        if isinstance(outputs, dict) and hasattr(model, '_out_features'):
            outputs = outputs[model._out_features[0]]

        if args.mixup_enable:
            mixup_criterion = lambda pred, target, \
                    lam: (-F.log_softmax(pred, dim=1) * torch.zeros(pred.size()).cuda().scatter_(1, target.data.view(-1, 1), lam.view(-1, 1))) \
                    .sum(dim=1).mean()
            loss = utils.mixup_criterion(target_a, target_b,
                                         lam)(mixup_criterion, outputs)
        else:
            loss = criterion(outputs, target)

        if 'quant_loss' in args.global_buffer:
            loss += args.global_buffer['quant_loss']
            args.global_buffer.pop('quant_loss')

        if i % args.iter_size == 0:
            optimizer.zero_grad()

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if i % args.iter_size == (args.iter_size - 1):
            if args.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            iterations = epoch * length + i
            if args.wakeup > iterations:
                for param_group in optimizer.param_groups:
                    if param_group.get('lr_constant', None) is not None:
                        continue
                    param_group['lr'] = param_group['lr'] * (
                        1.0 / args.wakeup) * iterations
                logging.info(
                    'train {}/{}, change learning rate to lr * {}'.format(
                        i, length, iterations / args.wakeup))
            if iterations >= args.warmup:
                optimizer.step()

        if 'sgdr' in args.lr_policy and scheduler is not None and torch.__version__ > "1.0.4" and epoch < args.epochs:
            scheduler.step()
            for group in optimizer.param_groups:
                if 'lr_constant' in group:
                    group['lr'] = group['lr_constant']
            lr_list = scheduler.get_lr()
            if isinstance(lr_list, list):
                lr = lr_list[0]

        losses.update(loss.item(), input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.report_freq == 0:
            logging.info(
                'train %d/%d, loss:%.3f(%.3f), batch time:%.2f(%.2f), data load time: %.2f(%.2f)'
                % (i, length, losses.val, losses.avg, batch_time.val,
                   batch_time.avg, data_time.val, data_time.avg))

        if epoch == 0 and i == 10:
            logging.info(utils.gpu_info())
        if args.delay > 0:
            time.sleep(args.delay)

        input = None
        target = None
        data = None

    if 'dali' in args.dataset:
        loader.reset()

    return losses.avg
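
The inline mixup_criterion lambda defined in this train() is a soft-target cross-entropy: for each sample it builds a target vector carrying weight lam at the label index (zeros elsewhere) and takes -sum(target * log_softmax(pred)). If utils.mixup_criterion(target_a, target_b, lam) applies it once with (target_a, lam) and once with (target_b, 1 - lam) and sums the two terms, the result equals the cross-entropy against the mixed label lam * onehot(target_a) + (1 - lam) * onehot(target_b). A standalone sketch of the per-call term, assuming lam arrives as a per-sample tensor (as the lam.view(-1, 1) above suggests):

import torch
import torch.nn.functional as F


def soft_target_ce(pred, target, weight):
    """Cross-entropy against a one-hot target scaled by a per-sample weight.
    pred is (N, C) logits, target is (N,) class indices, weight is (N,)."""
    soft = torch.zeros_like(pred).scatter_(1, target.view(-1, 1), weight.view(-1, 1))
    return (-F.log_softmax(pred, dim=1) * soft).sum(dim=1).mean()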
def train(args,
          model: nn.Module,
          criterion,
          *,
          params,
          train_loader,
          valid_loader,
          init_optimizer,
          use_cuda,
          n_epochs=None,
          patience=2,
          max_lr_changes=3) -> bool:
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)
    optimizer = init_optimizer(params, lr)

    run_root = Path(args.run_root)

    model_path = Path(str(run_root) + '/' + 'model.pt')

    if model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
        best_f2 = state['best_f2']
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
        best_f2 = 0

    lr_changes = 0

    save = lambda ep: torch.save(
        {
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
            'best_valid_loss': best_valid_loss,
            'best_f2': best_f2
        }, str(model_path))

    report_each = 100
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    valid_f2s = []
    lr_reset_epoch = epoch
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(
            total=(args.epoch_size or len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(tl):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets_a, targets_b, lam = mixup_data(
                    inputs, targets, 1, use_cuda)
                inputs, targets_a, targets_b = Variable(inputs), Variable(
                    targets_a), Variable(targets_b)
                outputs = model(inputs)
                loss_func = mixup_criterion(targets_a, targets_b, lam)
                loss = loss_func(criterion, outputs)
                loss = _reduce_loss(loss)

                batch_size = inputs.size(0)
                (batch_size * loss).backward()
                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')
                # if i and i % report_each == 0:
                #     write_event(log, step, loss=mean_loss)
            write_event(log, step, loss=mean_loss)
            tq.close()
            save(epoch + 1)
            valid_metrics = validation(model, criterion, valid_loader,
                                       use_cuda)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_f2 = valid_metrics['valid_f2_th_0.10']
            valid_f2s.append(valid_f2)
            valid_losses.append(valid_loss)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                #shutil.copy(str(model_path), str(run_root) + '/model_loss_' + f'{valid_loss:.4f}' + '.pt')

            if valid_f2 > best_f2:
                best_f2 = valid_f2
                shutil.copy(
                    str(model_path),
                    str(run_root) + '/model_f2_' + f'{valid_f2:.4f}' + '.pt')


#             if epoch == 7:
#                 lr = 1e-4
#                 print(f'lr updated to {lr}')
#                 optimizer = init_optimizer(params, lr)
#             if epoch == 8:
#                 lr = 1e-5
#                 optimizer = init_optimizer(params, lr)
#                 print(f'lr updated to {lr}')
        except KeyboardInterrupt:
            tq.close()
            #             print('Ctrl+C, saving snapshot')
            #             save(epoch)
            #             print('done.')
            return False
    return True
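
This train() uses the closure form of the helper: mixup_criterion(targets_a, targets_b, lam) returns a callable that is later applied to (criterion, outputs), in contrast to the direct five-argument form sketched earlier. A minimal sketch of that convention, assuming a standard two-argument criterion such as nn.CrossEntropyLoss (illustrative only):

def mixup_criterion(y_a, y_b, lam):
    """Return a closure over a fixed pair of label sets and mixing weight."""
    def loss_func(criterion, pred):
        return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)
    return loss_func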
Esempio n. 29
0
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    l2_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, data in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        input = data[0]
        target = data[-1]
        if args.l2_loss:
            dual_input = data[1]
            dual_input_var = torch.autograd.Variable(dual_input)  
        if CUDA:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        if args.mixup:
            input, y_a, y_b, lam = utils.mixup_data(input, target, alpha=1.0)
            y_a = torch.autograd.Variable(y_a)
            y_b = torch.autograd.Variable(y_b)


        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        if args.l2_loss:
            f1, f2, y1, y2 = model(input_var, dual_input_var)
            l2_loss = l2_loss_w * mse_loss(f1, f2)
            output = torch.cat([y1, y2])
            target = torch.cat([target, target])
            target_var = torch.cat([target_var, target_var])
            loss = criterion(output, target_var)
            loss = loss + l2_loss

            l2_losses.update(l2_loss.data[0], input.size(0))
        else:
            output = model(input_var)
            if args.mixup:
                loss_fun = utils.mixup_criterion(y_a, y_b, lam)
                loss = loss_fun(criterion, output)
            else:
                loss = criterion(output, target_var)

        # measure accuracy and record loss
        if args.mixup:
            _, predicted = torch.max(output.data, 1)
            prec1 = lam*predicted.eq(y_a.data).cpu().sum() + (1-lam)*predicted.eq(y_b.data).cpu().sum()
            top1.update(prec1, input.size(0))
        else:
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            top1.update(prec1[0], input.size(0))
        #top5.update(prec5[0], input.size(0))

        losses.update(loss.data[0], input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 5 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'L2Loss {l2_loss.val:.4f} ({l2_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, l2_loss=l2_losses, top1=top1))
        
        step = epoch * len(train_loader) + i
        #print(type(step))
        writer.add_scalar('train/acc', prec1[0], step)
        writer.add_scalar('train/loss', loss.data[0], step)
        if args.l2_loss:
            writer.add_scalar('train/l2_loss', l2_loss.data[0], step)
        for name, param in model.named_parameters():
            #print(name, param.data.cpu().numpy().dtype)
            if name.find('batchnorm')==-1:
                writer.add_histogram(name, param.data.cpu().numpy(), step)
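
When mixup is enabled, the loop above credits a prediction with weight lam for matching y_a and 1 - lam for matching y_b, so the reported top-1 value is a soft count of correct predictions rather than an exact accuracy. A small helper expressing the same convention (names are illustrative):

import torch


def mixup_correct(logits, y_a, y_b, lam):
    """Soft number of correct predictions under mixup; divide by the batch
    size to obtain a (soft) top-1 accuracy."""
    pred = logits.argmax(dim=1)
    return (lam * pred.eq(y_a).float().sum()
            + (1.0 - lam) * pred.eq(y_b).float().sum()).item()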
Esempio n. 30
0
    def forward(self,
                x,
                target=None,
                mixup_hidden=False,
                mixup_alpha=0.1,
                layers_mix=None,
                ext_feature=False):

        if mixup_hidden == True:
            layer_mix = random.randint(0, layers_mix)

            out = x
            if layer_mix == 0:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.conv1(out)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.maxpool(out)
            out = self.layer1(out)

            if layer_mix == 1:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.layer2(out)

            if layer_mix == 2:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.layer3(out)
            if layer_mix == 3:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.layer4(out)
            if layer_mix == 4:
                out, y_a, y_b, lam = mixup_data(out, target, mixup_alpha)

            out = self.avgpool(out)
            out = out.view(out.size(0), -1)
            out = self.fc(out)

            if ext_feature:
                return out, y_a, y_b, lam
            else:
                # out = self.dropout(out)
                out = self.activation(out)
                out = self.fc2(out)
                return out, y_a, y_b, lam

        else:

            out = x

            out = self.conv1(out)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.maxpool(out)
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.avgpool(out)
            out = out.view(out.size(0), -1)
            out = self.fc(out)

            if ext_feature:
                return out
            else:
                # out = self.dropout(out)
                out = self.activation(out)
                out = self.fc2(out)
                return out
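
The forward() above implements a manifold-mixup style interface: when mixup_hidden is true it mixes the activations at one randomly chosen depth and returns (out, y_a, y_b, lam) so the caller can form the mixed loss. A hypothetical training step against that interface; net, optimizer and criterion are placeholders, and the lam handling allows for variants that return lam either as a float or expanded to a tensor:

import torch


def manifold_mixup_step(net, optimizer, criterion, x, target,
                        mixup_alpha=0.2, layers_mix=4):
    """One illustrative training step for a forward() with mixup_hidden support."""
    out, y_a, y_b, lam = net(x, target=target, mixup_hidden=True,
                             mixup_alpha=mixup_alpha, layers_mix=layers_mix)
    if torch.is_tensor(lam):  # some variants expand lam to a tensor
        lam = lam.float().mean()
    loss = lam * criterion(out, y_a) + (1.0 - lam) * criterion(out, y_b)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()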
    def train_one_epoch(self, epoch, dataloader):
        config = self.cfg
        self.train_meters = AverageMeterGroup()

        cur_lr = self.optimizer.param_groups[0]["lr"]
        self.logger.info("Epoch %d LR %.6f", epoch, cur_lr)
        if self.enable_writter:
            self.writter.add_scalar("lr", cur_lr, global_step=epoch)

        self.model.train()

        for step, (x, y) in enumerate(dataloader):
            if self.debug and step > 1:
                break
            for callback in self.callbacks:
                callback.on_batch_begin(epoch)
            x, y = x.to(self.device,
                        non_blocking=True), y.to(self.device,
                                                 non_blocking=True)
            bs = x.size(0)
            # mixup data
            if config.mixup.enable:
                x, y_a, y_b, lam = mixup_data(x, y, config.mixup.alpha)
                mixup_y = [y_a, y_b, lam]

            # forward
            logits = self.model(x)

            # loss
            if isinstance(logits, tuple):
                logits, aux_logits = logits
                if config.mixup.enable:
                    aux_loss = mixup_loss_fn(self.loss_fn, aux_logits,
                                             *mixup_y)
                else:
                    aux_loss = self.loss_fn(aux_logits, y)
            else:
                aux_loss = 0.
            if config.mixup.enable:
                loss = mixup_loss_fn(self.loss_fn, logits, *mixup_y)
            else:
                loss = self.loss_fn(logits, y)
            if config.model.aux_weight > 0:
                loss += config.model.aux_weight * aux_loss
            if self.kd_model:
                teacher_output = self.kd_model(x)
                loss += (1 - config.kd.loss.alpha) * loss + loss_fn_kd(
                    logits, teacher_output, self.cfg.kd.loss)

            # backward
            loss.backward()
            # gradient clipping
            # nn.utils.clip_grad_norm_(model.parameters(), 20)

            if (step + 1) % config.trainer.accumulate_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            # post-processing
            accuracy = metrics(logits, y,
                               topk=(1, 3))  # e.g. {'acc1':0.65, 'acc3':0.86}
            self.train_meters.update(accuracy)
            self.train_meters.update({'train_loss': loss.item()})
            if step % config.logger.log_frequency == 0 or step == len(
                    dataloader) - 1:
                self.logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} {}".format(
                        epoch + 1, config.trainer.num_epochs, step,
                        len(dataloader) - 1, self.train_meters))

            for callback in self.callbacks:
                callback.on_batch_end(epoch)

        if self.enable_writter:
            self.writter.add_scalar("loss/train",
                                    self.train_meters['train_loss'].avg,
                                    global_step=epoch)
            self.writter.add_scalar("acc1/train",
                                    self.train_meters['acc1'].avg,
                                    global_step=epoch)
            self.writter.add_scalar("acc3/train",
                                    self.train_meters['acc3'].avg,
                                    global_step=epoch)

        self.logger.info("Train: [{:3d}/{}] Final result {}".format(
            epoch + 1, config.trainer.num_epochs, self.train_meters))

        return self.train_meters