Example No. 1
def main(args):
    fix(args.seed)
    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    
    print(args)
    # Create data loaders
    (dataset, test_dataset, num_classes, source_train_loader,
     grid_query_loader, grid_gallery_loader,
     prid_query_loader, prid_gallery_loader,
     viper_query_loader, viper_gallery_loader,
     ilid_query_loader, ilid_gallery_loader) = \
        get_data(args.data_dir, args.height, args.width, args.batch_size,
                 args.num_instance, args.re, args.workers)

    # Create model
    Encoder, Transfer, CamDis = models.create(args.arch, num_features=args.features,
                          dropout=args.dropout, num_classes=num_classes)

    invNet = InvNet(args.features, num_classes, args.batch_size, beta=args.beta, knn=args.knn, alpha=args.alpha).cuda()

    # Load from checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        Encoder.load_state_dict(checkpoint['Encoder'])
        Transfer.load_state_dict(checkpoint['Transfer'])
        CamDis.load_state_dict(checkpoint['CamDis'])
        invNet.load_state_dict(checkpoint['InvNet'])
        start_epoch = checkpoint['epoch']

    Encoder = Encoder.cuda()
    Transfer = Transfer.cuda()
    CamDis = CamDis.cuda()

    model = [Encoder, Transfer, CamDis]
    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        # Evaluate on the four target datasets and exit
        evaluator.eval_viper(viper_query_loader, viper_gallery_loader, test_dataset.viper_query, test_dataset.viper_gallery, args.output_feature, seed=97)
        evaluator.eval_prid(prid_query_loader, prid_gallery_loader, test_dataset.prid_query, test_dataset.prid_gallery, args.output_feature, seed=40)
        evaluator.eval_grid(grid_query_loader, grid_gallery_loader, test_dataset.grid_query, test_dataset.grid_gallery, args.output_feature, seed=28)
        evaluator.eval_ilids(ilid_query_loader, test_dataset.ilid_query, args.output_feature, seed=24)
        return

    criterion = []
    criterion.append(nn.CrossEntropyLoss().cuda())
    criterion.append(TripletLoss(margin=args.margin))


    # Optimizers: pretrained backbone parameters use a 10x smaller learning rate than newly added layers
    base_param_ids = set(map(id, Encoder.base.parameters()))
    new_params = [p for p in Encoder.parameters() if
                    id(p) not in base_param_ids]
    param_groups = [
        {'params': Encoder.base.parameters(), 'lr_mult': 0.1},
        {'params': new_params, 'lr_mult': 1.0}]

    optimizer_Encoder = torch.optim.SGD(param_groups, lr=args.lr,
                                momentum=0.9, weight_decay=5e-4, nesterov=True)
    # Same backbone/new-layer split for the Transfer network
    base_param_ids = set(map(id, Transfer.base.parameters()))
    new_params = [p for p in Transfer.parameters() if
                    id(p) not in base_param_ids]
    param_groups = [
        {'params': Transfer.base.parameters(), 'lr_mult': 0.1},
        {'params': new_params, 'lr_mult': 1.0}]

    optimizer_Transfer = torch.optim.SGD(param_groups, lr=args.lr,
                                momentum=0.9, weight_decay=5e-4, nesterov=True)
    # Optimizer for the camera discriminator (all parameters at the full learning rate)
    param_groups = [
        {'params': CamDis.parameters(), 'lr_mult': 1.0},
    ]
    optimizer_Cam = torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9,
                                    weight_decay=5e-4, nesterov=True)

    optimizer = [optimizer_Encoder, optimizer_Transfer, optimizer_Cam]

    # Trainer
    trainer = Trainer(model, criterion, InvNet=invNet)

    # Schedule learning rate: decay by 10x every 40 epochs, then scale each group by its lr_mult
    def adjust_lr(epoch):
        step_size = 40
        lr = args.lr * (0.1 ** (epoch // step_size))
        for g in optimizer_Encoder.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        for g in optimizer_Transfer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        for g in optimizer_Cam.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, source_train_loader, optimizer, args.tri_weight, args.adv_weight, args.mem_weight)

        save_checkpoint({
            'Encoder': Encoder.state_dict(),
            'Transfer': Transfer.state_dict(),
            'CamDis': CamDis.state_dict(),
            'InvNet': invNet.state_dict(),
            'epoch': epoch + 1,
        }, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        evaluator = Evaluator(model)
        print('\n * Finished epoch {:3d}\n'.format(epoch))

    # Final test on the four target datasets
    print('Test with best model:')
    evaluator = Evaluator(model)
    evaluator.eval_viper(viper_query_loader, viper_gallery_loader, test_dataset.viper_query, test_dataset.viper_gallery, args.output_feature, seed=97)
    evaluator.eval_prid(prid_query_loader, prid_gallery_loader, test_dataset.prid_query, test_dataset.prid_gallery, args.output_feature, seed=40)
    evaluator.eval_grid(grid_query_loader, grid_gallery_loader, test_dataset.grid_query, test_dataset.grid_gallery, args.output_feature, seed=28)
    evaluator.eval_ilids(ilid_query_loader, test_dataset.ilid_query, args.output_feature, seed=24)
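
The adjust_lr helper above applies a plain step decay and then rescales each parameter group by its lr_mult, so the pretrained backbone always runs at one tenth of the learning rate of the newly added layers. A minimal, self-contained sketch of that scheme (the toy base/head modules and the concrete numbers are illustrative, not taken from the example):

import torch
import torch.nn as nn

# Toy stand-ins for the pretrained backbone and the newly added layers.
base = nn.Linear(10, 10)
head = nn.Linear(10, 5)

param_groups = [
    {'params': base.parameters(), 'lr_mult': 0.1},  # backbone: 10x smaller LR
    {'params': head.parameters(), 'lr_mult': 1.0},  # new layers: full LR
]
optimizer = torch.optim.SGD(param_groups, lr=0.01, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)

def adjust_lr(epoch, base_lr=0.01, step_size=40):
    # Divide the base LR by 10 every `step_size` epochs, then apply each group's lr_mult.
    lr = base_lr * (0.1 ** (epoch // step_size))
    for g in optimizer.param_groups:
        g['lr'] = lr * g.get('lr_mult', 1)

adjust_lr(0)    # group LRs become 0.001 and 0.01
adjust_lr(40)   # group LRs become 0.0001 and 0.001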
Example No. 2
def main(args):
    # For fast training.
    cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print('log_dir=', args.logs_dir)

    # Print logs
    print(args)

    # Create data loaders
    dataset, num_classes, source_train_loader, target_train_loader, \
    query_loader, gallery_loader = get_data(args.data_dir, args.source,
                                            args.target, args.height,
                                            args.width, args.batch_size,
                                            args.re, args.workers)

    # Create model
    model = models.create(args.arch, num_features=args.features,
                          dropout=args.dropout, num_classes=num_classes)

    # Invariance learning model
    num_tgt = len(dataset.target_train)
    model_inv = InvNet(args.features, num_tgt,
                        beta=args.inv_beta, knn=args.knn,
                        alpha=args.inv_alpha)

    # Load from checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        model_inv.load_state_dict(checkpoint['state_dict_inv'])
        start_epoch = checkpoint['epoch']
        print("=> Start epoch {} "
              .format(start_epoch))

    # Set model
    model = nn.DataParallel(model).to(device)
    model_inv = model_inv.to(device)

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        print("Test:")
        evaluator.evaluate(query_loader, gallery_loader, dataset.query,
                           dataset.gallery, args.output_feature)
        return

    # Optimizer
    base_param_ids = set(map(id, model.module.base.parameters()))

    base_params_need_for_grad = filter(lambda p: p.requires_grad, model.module.base.parameters())

    new_params = [p for p in model.parameters() if
                    id(p) not in base_param_ids]
    param_groups = [
        {'params': base_params_need_for_grad, 'lr_mult': 0.1},
        {'params': new_params, 'lr_mult': 1.0}]

    optimizer = torch.optim.SGD(param_groups, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # Trainer
    trainer = Trainer(model, model_inv, lmd=args.lmd, include_mmd=args.include_mmd)

    # Schedule learning rate
    def adjust_lr(epoch):
        step_size = args.epochs_decay
        lr = args.lr * (0.1 ** (epoch // step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, source_train_loader, target_train_loader, optimizer)

        save_checkpoint({
            'state_dict': model.module.state_dict(),
            'state_dict_inv': model_inv.state_dict(),
            'epoch': epoch + 1,
        }, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}\n'.format(epoch))

    # Final test
    print('Test with best model:')
    evaluator = Evaluator(model)
    evaluator.evaluate(query_loader, gallery_loader, dataset.query,
                       dataset.gallery, args.output_feature)
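
Unlike Example No. 1, the optimizer here also filters the backbone parameters by requires_grad, so layers frozen before this point are left out of the optimizer entirely while the newly added layers still get the full learning rate. A short sketch of that split under an assumed toy model (the module names and the frozen layer are made up for illustration):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # stands in for the pretrained base
        self.embed = nn.Linear(8, 4)                                 # stands in for the new layers

model = Net()

# Freeze the first backbone layer (illustrative).
for p in model.base[0].parameters():
    p.requires_grad = False

# Backbone parameters are identified by object id; only trainable ones enter the optimizer.
base_param_ids = set(map(id, model.base.parameters()))
base_params_need_for_grad = filter(lambda p: p.requires_grad, model.base.parameters())
new_params = [p for p in model.parameters() if id(p) not in base_param_ids]

param_groups = [
    {'params': base_params_need_for_grad, 'lr_mult': 0.1},
    {'params': new_params, 'lr_mult': 1.0},
]
optimizer = torch.optim.SGD(param_groups, lr=0.01, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)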
Example No. 3
def main(args):
    # For fast training.
    cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print('log_dir=', args.logs_dir)

    # Print logs
    print(args)

    # Create data loaders
    (dataset, dataset_2, num_classes, source_train_loader, target_train_loader,
     query_loader, gallery_loader, query_loader_2, gallery_loader_2) = \
        get_data(args.data_dir, args.source, args.target, args.height,
                 args.width, args.batch_size, args.re, args.workers)

    # Create model
    model = models.create(args.arch, num_features=args.features,
                          dropout=args.dropout, num_classes=num_classes)

    # Second copy of the model (model_ema), built with the same architecture and settings
    model_ema = models.create(args.arch, num_features=args.features,
                              dropout=args.dropout, num_classes=num_classes)

    # Invariance learning model
    num_tgt = len(dataset.target_train)
    model_inv = InvNet(args.features, num_tgt,
                        beta=args.inv_beta, knn=args.knn,
                        alpha=args.inv_alpha)

    # Load from checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        # model_inv.load_state_dict(checkpoint['state_dict_inv'])
        start_epoch = checkpoint['epoch']
        print("=> Start epoch {}".format(start_epoch))

    
    # Initialize both the model and its second copy from the same pretrained weights
    checkpoint = load_checkpoint(args.init_1)
    model.load_state_dict(checkpoint['state_dict'])
    model_ema.load_state_dict(checkpoint['state_dict'])

    # Alternative initialization: keep only checkpoint entries whose keys exist in the model
    # checkpoint = load_checkpoint(args.init_1)
    # model_dict = model.state_dict()
    # pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}
    # model_dict.update(pretrained_dict)
    # model.load_state_dict(model_dict)
    # model_ema.load_state_dict(model_dict)

    # Set model
    model = nn.DataParallel(model).to(device)
    model_ema = nn.DataParallel(model_ema).to(device)
    model_inv = model_inv.to(device)

    # Alternative initialization via copy_state_dict, syncing the classifier weights
    # initial_weights = load_checkpoint(args.init_1)
    # copy_state_dict(initial_weights['state_dict'], model)
    # copy_state_dict(initial_weights['state_dict'], model_ema)
    # model_ema.module.classifier.weight.data.copy_(model.module.classifier.weight.data)
    # Evaluator
    evaluator = Evaluator(model)


    # Final test
    print('Test with best model:')
    evaluator = Evaluator(model)
    evaluator.evaluate(query_loader, gallery_loader, dataset.query,
                       dataset.gallery, args.output_feature)
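
The commented-out alternative in this example loads only the checkpoint entries whose keys exist in the current model, which is the usual way to reuse a backbone when the classifier head has a different size. A self-contained sketch of that partial loading, with an extra shape check added (the helper name and the layer sizes are hypothetical):

import torch.nn as nn

def load_matching_weights(model, checkpoint_state):
    # Keep only entries whose key exists in the model and whose shape matches,
    # so a differently sized classifier does not break loading.
    model_dict = model.state_dict()
    pretrained = {k: v for k, v in checkpoint_state.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(pretrained)
    model.load_state_dict(model_dict)
    return sorted(pretrained)

# Illustrative use: the source model has a 751-way classifier, the target model a 702-way one.
src = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 751))
tgt = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 702))
print(load_matching_weights(tgt, src.state_dict()))  # only the shared layer: ['0.bias', '0.weight']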