Пример #1
0
def main(args):
    """Train (or only evaluate) a re-ID model, optionally with CamStyle data.

    Pipeline, in order:

    1. Enable cuDNN autotuning.
    2. For training runs, tee stdout into ``<logs_dir>/log.txt``.
    3. Build the loaders with ``get_data``.
    4. Create the network and optionally restore weights from ``args.resume``.
    5. Wrap the network in ``nn.DataParallel`` on CUDA.
    6. Either evaluate once and return (``args.evaluate``), or train with
       Nesterov SGD and a step learning-rate schedule. Training saves a
       checkpoint after every epoch and evaluates the final weights.

    Args:
        args: parsed command-line namespace; every attribute read below
            (``dataset``, ``arch``, ``lr``, ``epochs``, ``logs_dir``, ...)
            must be present.
    """
    cudnn.benchmark = True

    # Training runs print to both the console and <logs_dir>/log.txt;
    # evaluation-only runs leave stdout untouched.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Debug mode asks get_data for 0 workers (presumably so loading happens
    # in the main process, which is easier to step through).
    loader_workers = 0 if args.debug else args.workers
    (dataset, num_classes, train_loader, query_loader, gallery_loader,
     camstyle_loader) = get_data(args.dataset, args.data_dir, args.height,
                                 args.width, args.batch_size, args.camstyle,
                                 args.re, loader_workers,
                                 camstyle_path=args.camstyle_path)

    net = models.create(args.arch, num_features=args.features,
                        dropout=args.dropout, num_classes=num_classes)

    # Optional resume: restore weights and continue from the stored epoch.
    start_epoch = 0
    if args.resume:
        state = load_checkpoint(args.resume)
        net.load_state_dict(state['state_dict'])
        start_epoch = state['epoch']
        print("=> Start epoch {} ".format(start_epoch))
    net = nn.DataParallel(net).cuda()

    def run_test(ev):
        # One evaluation call shared by the eval-only path and the final test.
        ev.evaluate(query_loader, gallery_loader, dataset.query,
                    dataset.gallery, args.output_feature, args.rerank)

    evaluator = Evaluator(net, args.logs_dir)
    if args.evaluate:
        print("Test:")
        run_test(evaluator)
        return

    criterion = nn.CrossEntropyLoss().cuda()

    # Two parameter groups: the backbone (net.module.base) gets 0.1x the base
    # learning rate, every other parameter gets 1x.
    backbone_ids = {id(p) for p in net.module.base.parameters()}
    head_params = [p for p in net.parameters() if id(p) not in backbone_ids]
    param_groups = [
        {'params': net.module.base.parameters(), 'lr_mult': 0.1},
        {'params': head_params, 'lr_mult': 1.0},
    ]
    optimizer = torch.optim.SGD(param_groups, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    trainer = (Trainer(net, criterion) if args.camstyle == 0 else
               CamStyleTrainer(net, criterion, camstyle_loader))

    def set_epoch_lr(epoch):
        # Step decay: the base rate shrinks 10x every `step_size` epochs, then
        # each group is scaled by its own 'lr_mult'.
        step_size = 40
        scaled = args.lr * (0.1 ** (epoch // step_size))
        for group in optimizer.param_groups:
            group['lr'] = scaled * group.get('lr_mult', 1)

    for epoch in range(start_epoch, args.epochs):
        set_epoch_lr(epoch)
        trainer.train(epoch, train_loader, optimizer)

        # Same checkpoint path every epoch; 'epoch' stores the next epoch to run.
        snapshot = {'state_dict': net.module.state_dict(), 'epoch': epoch + 1}
        save_checkpoint(snapshot,
                        fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d} \n'.format(epoch))

    # NOTE(review): no best-epoch selection happens above; this evaluates the
    # weights left after the last training epoch despite the label.
    print('Test with best model:')
    run_test(Evaluator(net, args.logs_dir))
Пример #2
0
def main(args):
    """Train (or only evaluate) a re-ID model with an LSRO-style loss.

    Pipeline, in order:

    1. Enable cuDNN autotuning.
    2. For training runs, tee stdout into ``<logs_dir>/log.txt``.
    3. Build the loaders.
    4. Create the network and optionally resume its weights from
       ``args.resume``.
    5. Wrap the network in ``nn.DataParallel`` on CUDA.
    6. Either evaluate once and return (``args.evaluate``), or train with
       Nesterov SGD and a step learning-rate schedule. Training saves a
       checkpoint every epoch and evaluates the final weights.

    The criterion takes a per-sample flag separating real from generated
    samples; see ``LSROloss`` below.

    Args:
        args: parsed command-line namespace providing every attribute read
            below (``dataset``, ``arch``, ``lr``, ``epochs``, ``logs_dir``, ...).
    """
    cudnn.benchmark = True
    # Redirect print to both console and log file (training runs only).
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Create data loaders
    dataset, num_classes, train_loader, query_loader, gallery_loader, camstyle_loader = \
        get_data(args.dataset, args.data_dir, args.height,
                 args.width, args.batch_size, args.camstyle, args.re, args.workers)

    # Create model
    model = models.create(args.arch,
                          num_features=args.features,
                          dropout=args.dropout,
                          num_classes=num_classes)

    # Load from checkpoint. Only weights and the epoch counter are restored;
    # the checkpoint written below does not contain optimizer state.
    start_epoch = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        print("=> Start epoch {} ".format(start_epoch))
    model = nn.DataParallel(model).cuda()

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        print("Test:")
        evaluator.evaluate(query_loader, gallery_loader, dataset.query,
                           dataset.gallery, args.output_feature, args.rerank)
        return

    # Criterion: replaces plain cross-entropy with a loss that treats
    # generated samples differently from real ones.
    class LSROloss(nn.Module):
        """Per-sample mix of target-class NLL and a uniform-label penalty.

        Rows with flag 0 ("true data" in the original comments) contribute
        the usual negative log-likelihood of their target class.

        Rows with flag 1 ("generated data") contribute the cross-entropy
        against a uniform distribution over all K classes, i.e.
        ``-mean_k log p_k``.

        The result is the mean over all rows.
        """

        def __init__(self):
            super(LSROloss, self).__init__()

        def forward(self, input, target, flg):
            """Return the mean loss over the batch.

            Args:
                input: class scores of shape (N, K). A tensor with more than
                    two dims, e.g. (N, C, H, W), is flattened to (N*H*W, C).
                target: class indices, one per row of the flattened
                    ``input``. They must be an integer tensor, as required
                    by ``gather``.
                flg: per-row flag, 0 for real and 1 for generated samples.
                    It is cast to the type of the scores.

            Returns:
                Scalar tensor with the mean per-row loss.
            """
            if input.dim() > 2:
                input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
                input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
                input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C

            # log_softmax is shift-invariant and subtracts the row max itself,
            # so no manual max-normalisation is needed. The earlier
            # `input.data = input.data - max` also overwrote the caller's
            # score tensor behind autograd's back.
            log_probs = F.log_softmax(input, dim=1)

            # Real samples: log-probability of the target class.
            target_logp = log_probs.gather(1, target.view(-1, 1)).view(-1)
            # Generated samples: average log-probability over all classes.
            uniform_logp = log_probs.mean(dim=1)

            # Cast to the scores' tensor type instead of hard-coding
            # torch.cuda.FloatTensor, so the loss also runs on CPU.
            flg = flg.view(-1).type_as(log_probs)

            loss = -target_logp * (1 - flg) - uniform_logp * flg
            return loss.mean()

    criterion = LSROloss()

    # Optimizer: the backbone (model.module.base) trains at 0.1x the base
    # learning rate, every other parameter at 1x.
    base_param_ids = set(map(id, model.module.base.parameters()))
    new_params = [p for p in model.parameters() if id(p) not in base_param_ids]
    param_groups = [{
        'params': model.module.base.parameters(),
        'lr_mult': 0.1
    }, {
        'params': new_params,
        'lr_mult': 1.0
    }]

    optimizer = torch.optim.SGD(param_groups,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # Trainer
    if args.camstyle == 0:
        trainer = Trainer(model, criterion)
    else:
        trainer = CamStyleTrainer(model, criterion, camstyle_loader)

    # Schedule learning rate: 10x decay every 40 epochs, scaled per group by
    # its 'lr_mult'.
    def adjust_lr(epoch):
        step_size = 40
        lr = args.lr * (0.1**(epoch // step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer)

        # Same checkpoint path every epoch; 'epoch' is the next epoch to run.
        save_checkpoint(
            {
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
            },
            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d} \n'.format(epoch))

    # Final test
    # NOTE(review): no best-epoch selection happens above; this evaluates the
    # weights left after the last training epoch despite the label.
    print('Test with best model:')
    evaluator = Evaluator(model)
    evaluator.evaluate(query_loader, gallery_loader, dataset.query,
                       dataset.gallery, args.output_feature, args.rerank)