Пример #1
0
def apn_train_epoch(epoch, data_loader, model, optimizer, device, tb_writer,
                    log_interval=1):
    """Train the model for one epoch on the pairwise ranking (APN) loss.

    Per batch:
    - Gather the target-class score from each tensor in the model's ``logits`` output.
    - Pass the per-sample score lists to ``pairwise_ranking_loss``.
    - Take one optimizer step on that loss.

    The loss is written to ``tb_writer`` under ``'train/rank_loss'`` on every
    iteration. A console progress line is printed every ``log_interval``
    global steps.

    Args:
        epoch: Epoch number. Assumed 1-based, because the global step is
            computed as ``(epoch - 1) * len(data_loader) + i + 1``.
        data_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        model: Module whose forward returns a 4-tuple. The first element is
            a sequence of logit tensors indexable as ``logit[sample][class]``.
        optimizer: Optimizer over the parameters to update.
        device: Device the batches are moved to.
        tb_writer: Object with an ``add_scalar(tag, value, step)`` method
            (presumably a TensorBoard ``SummaryWriter``).
        log_interval: Print progress every ``log_interval`` global steps.
            Defaults to 1, which prints every iteration.

    Raises:
        ValueError: If ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError(
            'log_interval must be >= 1, got {}'.format(log_interval))

    print('train_apn at epoch {}'.format(epoch))

    model.train()
    steps_per_epoch = len(data_loader)
    for i, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Clear accumulated gradients once per step, before forward/backward.
        optimizer.zero_grad()

        t0 = time.time()
        # NOTE(review): the RACNN smoke test elsewhere unpacks 3 outputs from
        # the model, while this unpacks 4. Confirm which model revision this
        # function is meant to train.
        logits, _, _, _ = model(inputs)

        # preds[j][k] is the score of sample j's target class taken from the
        # k-th tensor in ``logits``.
        preds = [[logit[j][target] for logit in logits]
                 for j, target in enumerate(targets)]
        apn_loss = pairwise_ranking_loss(preds)
        apn_loss.backward()
        optimizer.step()
        t1 = time.time()

        # Read the scalar once. Each .item() copies to the host (a sync on GPU).
        loss_value = apn_loss.item()
        itera = (epoch - 1) * steps_per_epoch + (i + 1)

        if itera % log_interval == 0:
            print(
                " [*] apn_epoch[%d], apn_iter %d || apn_loss: %.4f || Timer: %.4fsec"
                % (epoch, i, loss_value, (t1 - t0)))
        tb_writer.add_scalar('train/rank_loss', loss_value, itera)
Пример #2
0
        A detailed description is in the 'Attention localization and amplification' section.
        The forward function is unchanged; the backward function does not use autograd but is implemented manually.
    """
    def forward(self, images, locs):
        """Apply ``AttentionCropFunction`` to ``images`` at ``locs``.

        The call goes through ``Function.apply`` so that the Function's
        manually implemented backward is used instead of autograd's.
        """
        cropped = AttentionCropFunction.apply(images, locs)
        return cropped


if __name__ == '__main__':
    # Smoke test: one forward pass on random input, both losses, and one
    # backward pass through their sum.
    print(" [*] RACNN forward test...")
    dummy_images = torch.randn([2, 3, 448, 448])
    racnn = RACNN(num_classes=200)
    logits, conv5s, attens = racnn(dummy_images)
    print(" [*] logits[0]:", logits[0].size())

    from Loss import multitask_loss, pairwise_ranking_loss

    target_cls = torch.LongTensor([100, 150])
    # One list per sample: that sample's target-class score from every
    # tensor in ``logits``.
    preds = [[logit[idx][label] for logit in logits]
             for idx, label in enumerate(target_cls)]

    loss_cls = multitask_loss(logits, target_cls)
    loss_rank = pairwise_ranking_loss(preds)
    print(" [*] Loss cls:", loss_cls)
    print(" [*] Loss rank:", loss_rank)

    print(" [*] Backward test")
    total_loss = loss_cls + loss_rank
    total_loss.backward()
    print(" [*] Backward done")