def train(train_iter: ForeverDataIterator, model: Classifier, bss_module, optimizer: SGD,
        epoch: int, args: argparse.Namespace):
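    """Train `model` for one epoch. `bss_module` is expected to implement a
    feature regularizer such as Batch Spectral Shrinkage (BSS), which penalizes
    the smallest singular values of the batch feature matrix; `args.trade_off`
    balances it against the cross-entropy loss.
    """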
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels = next(train_iter)

        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
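        # y: class logits; f: backbone features consumed by bss_module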
        y, f = model(x)

        cls_loss = F.cross_entropy(y, labels)
        bss_loss = bss_module(f)
        loss = cls_loss + args.trade_off * bss_loss

        cls_acc = accuracy(y, labels)[0]

        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

# Example 2
def train(train_iter: ForeverDataIterator, model: Classifier,
          backbone_regularization: nn.Module, head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int,
          args: argparse.Namespace):
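    """Train `model` for one epoch with two regularizers. `backbone_regularization`
    compares intermediate activations of a (typically frozen) source model and the
    target model being fine-tuned, optionally weighted by channel attention;
    `head_regularization` is a data-free penalty on the head's parameters
    (e.g. an L2 / L2-SP term) and is therefore called without arguments.
    """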
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_reg_head,
         losses_reg_backbone, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
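        # each getter runs its model and returns (intermediate activations, model output)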
        intermediate_output_s, output_s = source_getter(x)
        intermediate_output_t, output_t = target_getter(x)
        y, f = output_t

        # measure accuracy and record loss
        cls_acc = accuracy(y, labels)[0]
        cls_loss = F.cross_entropy(y, labels)
        if args.regularization_type in ('feature_map', 'attention_feature_map'):
            # both variants penalize the distance between source and target
            # intermediate activations; the attention variant weights channels
            loss_reg_backbone = backbone_regularization(
                intermediate_output_s, intermediate_output_t)
        else:
            loss_reg_backbone = backbone_regularization()
        loss_reg_head = head_regularization()
        loss = (cls_loss + args.trade_off_backbone * loss_reg_backbone
                + args.trade_off_head * loss_reg_head)

        losses_reg_backbone.update(
            loss_reg_backbone.item() * args.trade_off_backbone, x.size(0))
        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head,
                               x.size(0))
        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

# Example 3
def calculate_channel_attention(dataset, return_layers, args):
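    """Estimate per-channel attention for every layer named in `return_layers`:
    briefly fine-tune a classifier, then zero out one output channel at a time
    and record how much the loss increases. Channels whose removal hurts more
    receive larger attention weights.
    """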
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr),
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    data_loader = DataLoader(dataset,
                             batch_size=args.attention_batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             drop_last=False)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

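    # channel_weights[k][j] holds the running loss increase observed when
    # channel j of the k-th returned layer is zeroed out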
    channel_weights = []
    for layer_id, name in enumerate(return_layers):
        layer = get_attribute(classifier, name)
        layer_channel_weight = [0] * layer.out_channels
        channel_weights.append(layer_channel_weight)

    # train the classifier
    classifier.train()
    # requires_grad_ (note the underscore) recursively freezes the backbone's
    # parameters; plain attribute assignment would silently have no effect
    classifier.backbone.requires_grad_(False)
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(len(data_loader), [losses, cls_accs],
                                 prefix="Epoch: [{}]".format(epoch))

        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_acc = accuracy(outputs, labels)[0]

            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        total_iteration = len(data_loader)

    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")

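    # ablation study: zero one channel at a time and measure the loss increase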
    for i, data in enumerate(data_loader):
        if i >= total_iteration:
            break
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # baseline loss with all channels active
        with torch.no_grad():
            outputs, _ = classifier(inputs)
            loss_0 = criterion(outputs, labels)
        progress.display(i)
        for layer_id, name in enumerate(tqdm(return_layers)):
            layer = get_attribute(classifier, name)
            # state_dict() tensors share storage with the live parameters,
            # so the in-place edits below take effect immediately
            weight = classifier.state_dict()[name + '.weight']
            for j in range(layer.out_channels):
                # zero channel j, measure the loss increase, then restore it
                tmp = weight[j].clone()
                weight[j] = 0.0
                with torch.no_grad():
                    outputs, _ = classifier(inputs)
                    loss_1 = criterion(outputs, labels)
                difference = (loss_1 - loss_0).item()
                # running mean of the loss increase across batches
                history_value = channel_weights[layer_id][j]
                channel_weights[layer_id][j] = (i * history_value + difference) / (i + 1)
                weight[j] = tmp

    channel_attention = []
    for weight in channel_weights:
        weight = np.array(weight)
        # standardize the raw loss increases, then turn them into a soft
        # distribution over channels; the divisor 5 is a softmax temperature
        weight = (weight - np.mean(weight)) / np.std(weight)
        weight = torch.from_numpy(weight).float().to(device)
        channel_attention.append(F.softmax(weight / 5, dim=0).detach())
    return channel_attention
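

# Hypothetical usage sketch for calculate_channel_attention (not part of the
# original script). The hyperparameter values, the FakeData stand-in dataset,
# and the return-layer name below are illustrative assumptions; the layer name
# must match a conv layer inside the chosen backbone.
def _example_channel_attention():
    import torchvision.transforms as T
    from torchvision.datasets import FakeData

    args = argparse.Namespace(
        arch='resnet50', lr=0.01, momentum=0.9, wd=1e-4, workers=2,
        print_freq=100, attention_batch_size=32, attention_epochs=1,
        attention_lr_decay_epochs=1, attention_iteration_limit=2)
    # FakeData exposes .num_classes, which Classifier reads above
    dataset = FakeData(size=64, image_size=(3, 224, 224), num_classes=10,
                       transform=T.ToTensor())
    # assumed torchvision resnet50 layout wrapped as classifier.backbone
    return_layers = ['backbone.layer4.2.conv3']
    return calculate_channel_attention(dataset, return_layers, args)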