Code example #1
import time

from Loss import pairwise_ranking_loss


def apn_train_epoch(epoch, data_loader, model, optimizer, device, tb_writer):

    print('train_apn at epoch {}'.format(epoch))

    model.train()
    for i, (inputs, targets) in enumerate(data_loader):

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()

        t0 = time.time()
        logits, _, _, _ = model(inputs)

        # Collect the target-class score at every scale for each sample;
        # the ranking loss compares these scores across scales.
        preds = []
        for j in range(len(targets)):
            pred = [logit[j][targets[j]] for logit in logits]
            preds.append(pred)
        apn_loss = pairwise_ranking_loss(preds)
        apn_loss.backward()
        optimizer.step()
        t1 = time.time()
        itera = (epoch - 1) * len(data_loader) + (i + 1)  # global iteration index

        # if (itera % 20) == 0:
        print(
            " [*] apn_epoch[%d], apn_iter %d || apn_loss: %.4f || Timer: %.4fsec"
            % (epoch, i, apn_loss.item(), (t1 - t0)))
        tb_writer.add_scalar('train/rank_loss', apn_loss.item(), itera)
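
For reference, pairwise_ranking_loss is imported from the project's Loss.py and its body is not shown here. A minimal sketch of the inter-scale hinge loss described in the RACNN paper ("Look Closer to See Better", Fu et al., CVPR 2017), which this function presumably implements, could look like the following; the margin value and the mean reduction are assumptions, not the project's confirmed code:

import torch

def pairwise_ranking_loss(preds, margin=0.05):
    # Sketch only: hinge loss pushing each finer scale's target-class
    # score above the previous scale's by at least `margin`.
    # `preds` is a list (one entry per sample) of per-scale scores.
    # The margin of 0.05 and the mean reduction are assumptions.
    losses = []
    for scores in preds:
        for s in range(len(scores) - 1):
            losses.append(torch.clamp(scores[s] - scores[s + 1] + margin, min=0))
    return torch.stack(losses).mean()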
Code example #2
File: RACNN.py Project: zfxu/RACNN-pytorch
        Detailed description is in the 'Attention localization and amplification' part.
        The forward function is unchanged; the backward pass does not use autograd
        but is implemented manually (see the skeleton sketched after this class).
    """
    def forward(self, images, locs):
        return AttentionCropFunction.apply(images, locs)
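
AttentionCropFunction itself is defined earlier in RACNN.py and is truncated out of this excerpt. The skeleton below only illustrates the forward/backward structure the docstring describes, namely a torch.autograd.Function whose backward is written by hand rather than derived by autograd; the crop-and-amplify math and the gradients are placeholders, not the project's implementation:

import torch

class AttentionCropFunction(torch.autograd.Function):
    # Skeleton only: shows the shape of a custom Function with a
    # manually implemented backward. All math here is a placeholder.
    @staticmethod
    def forward(ctx, images, locs):
        ctx.save_for_backward(images, locs)  # stash inputs for backward
        return images  # placeholder: the real code returns zoomed crops

    @staticmethod
    def backward(ctx, grad_output):
        images, locs = ctx.saved_tensors
        # One gradient per forward input, derived by hand instead of autograd.
        grad_images = grad_output
        grad_locs = torch.zeros_like(locs)  # placeholder for d(loss)/d(tx, ty, tl)
        return grad_images, grad_locs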


if __name__ == '__main__':
    print(" [*] RACNN forward test...")
    x = torch.randn([2, 3, 448, 448])  # dummy batch: two 448x448 RGB images
    net = RACNN(num_classes=200)
    logits, conv5s, attens = net(x)
    print(" [*] logits[0]:", logits[0].size())

    from Loss import multitask_loss, pairwise_ranking_loss
    target_cls = torch.LongTensor([100, 150])  # dummy ground-truth class indices

    # Gather each sample's target-class score at every scale, as the
    # ranking loss expects.
    preds = []
    for i in range(len(target_cls)):
        pred = [logit[i][target_cls[i]] for logit in logits]
        preds.append(pred)
    loss_cls = multitask_loss(logits, target_cls)
    loss_rank = pairwise_ranking_loss(preds)
    print(" [*] Loss cls:", loss_cls)
    print(" [*] Loss rank:", loss_rank)

    print(" [*] Backward test")
    loss = loss_cls + loss_rank
    loss.backward()
    print(" [*] Backward done")