import argparse
import math
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.optim import SGD
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-specific helpers assumed to be importable from the surrounding
# repository: ForeverDataIterator, AverageMeter, ProgressMeter, accuracy,
# Classifier, IntermediateLayerGetter, get_attribute, and the global `device`.


def train(train_iter: ForeverDataIterator, model: Classifier, bss_module,
          optimizer: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y, f = model(x)
        cls_loss = F.cross_entropy(y, labels)
        # BSS penalty on the batch feature matrix, weighted by the trade-off
        bss_loss = bss_module(f)
        loss = cls_loss + args.trade_off * bss_loss

        cls_acc = accuracy(y, labels)[0]
        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
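# The `bss_module` passed to train() above is expected to implement Batch
# Spectral Shrinkage (Chen et al., "Catastrophic Forgetting Meets Negative
# Transfer", NeurIPS 2019): penalizing the smallest singular values of the
# batch feature matrix to suppress spectral components that transfer poorly.
# A minimal sketch of such a module follows; the class name and the
# hyper-parameter `k` are illustrative assumptions, not the exact
# implementation used by the surrounding library.
class BatchSpectralShrinkagePenalty(nn.Module):
    """Sketch: sum of squares of the k smallest singular values of the
    (batch_size, feature_dim) feature matrix."""

    def __init__(self, k: int = 1):
        super().__init__()
        self.k = k

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # torch.linalg.svdvals returns singular values in descending order,
        # so the last k entries are the smallest ones
        singular_values = torch.linalg.svdvals(features)
        return torch.sum(singular_values[-self.k:] ** 2)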
def train(train_iter: ForeverDataIterator, model: Classifier, backbone_regularization: nn.Module,
          head_regularization: nn.Module, target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_reg_head, losses_reg_backbone, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: collect intermediate activations both from the
        # frozen source model and from the model being fine-tuned
        intermediate_output_s, output_s = source_getter(x)
        intermediate_output_t, output_t = target_getter(x)
        y, f = output_t

        # measure accuracy and record loss
        cls_acc = accuracy(y, labels)[0]
        cls_loss = F.cross_entropy(y, labels)
        # the feature-map regularizers compare intermediate activations;
        # other types (e.g. L2-SP on the weights) take no inputs
        if args.regularization_type in ('feature_map', 'attention_feature_map'):
            loss_reg_backbone = backbone_regularization(intermediate_output_s, intermediate_output_t)
        else:
            loss_reg_backbone = backbone_regularization()
        loss_reg_head = head_regularization()
        loss = (cls_loss + args.trade_off_backbone * loss_reg_backbone
                + args.trade_off_head * loss_reg_head)

        losses_reg_backbone.update(loss_reg_backbone.item() * args.trade_off_backbone, x.size(0))
        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head, x.size(0))
        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
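# The `backbone_regularization` and `head_regularization` modules used in
# train() above follow the DELTA / L2-SP recipe: keep the fine-tuned
# backbone's behavior close to the pretrained source model, and keep the
# newly initialized head small. The sketches below are assumptions about
# their shape, not the library's exact API; they assume each
# IntermediateLayerGetter returns an OrderedDict of feature maps keyed by
# layer name.
class FeatureMapRegularization(nn.Module):
    """Sketch of the 'feature_map' regularization type: mean squared distance
    between corresponding source and target intermediate feature maps."""

    def forward(self, intermediate_output_s, intermediate_output_t) -> torch.Tensor:
        loss = 0.
        for fm_s, fm_t in zip(intermediate_output_s.values(), intermediate_output_t.values()):
            # source feature maps act as fixed targets, so gradients only
            # flow through the target model's activations
            loss = loss + F.mse_loss(fm_t, fm_s.detach())
        return loss


class HeadL2Regularization(nn.Module):
    """Sketch of a head regularizer: squared L2 norm of the head parameters,
    pulling the randomly initialized head toward zero."""

    def __init__(self, head: nn.Module):
        super().__init__()
        self.head = head

    def forward(self) -> torch.Tensor:
        return sum(0.5 * torch.sum(p ** 2) for p in self.head.parameters())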
def calculate_channel_attention(dataset, return_layers, args):
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    data_loader = DataLoader(dataset, batch_size=args.attention_batch_size, shuffle=True,
                             num_workers=args.workers, drop_last=False)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    channel_weights = []
    for layer_id, name in enumerate(return_layers):
        layer = get_attribute(classifier, name)
        layer_channel_weight = [0] * layer.out_channels
        channel_weights.append(layer_channel_weight)

    # train the classifier head with the backbone frozen
    classifier.train()
    for param in classifier.backbone.parameters():
        param.requires_grad = False
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(
            len(data_loader),
            [losses, cls_accs],
            prefix="Epoch: [{}]".format(epoch))

        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_acc = accuracy(outputs, labels)[0]
            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # estimate each channel's importance: zero out one output channel at a
    # time and measure the resulting increase in classification loss
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        total_iteration = len(data_loader)
    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")
    for i, data in enumerate(data_loader):
        if i >= total_iteration:
            break
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs, _ = classifier(inputs)
        loss_0 = criterion(outputs, labels)
        progress.display(i)
        for layer_id, name in enumerate(tqdm(return_layers)):
            layer = get_attribute(classifier, name)
            for j in range(layer.out_channels):
                # state_dict() returns references to the live weight tensors,
                # so writing to a slice edits the model in place; zero out
                # channel j, re-evaluate the loss, then restore the weights
                tmp = classifier.state_dict()[name + '.weight'][j].clone()
                classifier.state_dict()[name + '.weight'][j] = 0.0
                outputs, _ = classifier(inputs)
                loss_1 = criterion(outputs, labels)
                difference = (loss_1 - loss_0).detach().cpu().item()
                # running average of the loss increase over batches
                history_value = channel_weights[layer_id][j]
                channel_weights[layer_id][j] = (i * history_value + difference) / (i + 1)
                classifier.state_dict()[name + '.weight'][j] = tmp

    # standardize per-layer weights and convert to a softmax attention
    # distribution (temperature 5)
    channel_attention = []
    for weight in channel_weights:
        weight = np.array(weight)
        weight = (weight - np.mean(weight)) / np.std(weight)
        weight = torch.from_numpy(weight).float().to(device)
        channel_attention.append(F.softmax(weight / 5, dim=0).detach())
    return channel_attention
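# The attention vectors returned by calculate_channel_attention() are meant
# for the 'attention_feature_map' branch of train(): the per-channel distance
# between source and target feature maps is weighted by how much zeroing that
# channel increased the loss. The sketch below shows one plausible way to
# consume them; the class name and weighting details are assumptions, not the
# library's exact implementation.
class AttentionFeatureMapRegularization(nn.Module):
    """Sketch: channel-attention-weighted distance between source and target
    feature maps, one (C,) attention vector per returned layer."""

    def __init__(self, channel_attention):
        super().__init__()
        self.channel_attention = channel_attention  # list of (C,) tensors

    def forward(self, intermediate_output_s, intermediate_output_t) -> torch.Tensor:
        loss = 0.
        for attention, fm_s, fm_t in zip(self.channel_attention,
                                         intermediate_output_s.values(),
                                         intermediate_output_t.values()):
            # (N, C, H, W) -> (C,): mean squared distance per channel,
            # then weight each channel by its attention score
            distance = torch.mean((fm_t - fm_s.detach()) ** 2, dim=(0, 2, 3))
            loss = loss + torch.sum(attention * distance)
        return loss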