# Example 1
def train(conf):
    """Run the full semi-siamese (probe/gallery) training procedure.

    Builds a probe network (trained by SGD) and a gallery network (updated as
    a moving average of the probe net), then trains for ``conf.epoches``
    epochs, rebuilding the data loader each epoch so that identities excluded
    by the previous epoch are dropped.

    Args:
        conf: configuration namespace; must provide backbone/head types and
            conf files, data_root, train_file, batch_size, lr, momentum,
            milestones, epoches, resume and (when resuming) pretrain_model.
    """
    conf.device = torch.device('cuda:0')
    criterion = torch.nn.CrossEntropyLoss().cuda(conf.device)
    backbone_factory = BackboneFactory(conf.backbone_type, conf.backbone_conf_file)
    probe_net = backbone_factory.get_backbone()
    gallery_net = backbone_factory.get_backbone()
    head_factory = HeadFactory(conf.head_type, conf.head_conf_file)
    prototype = head_factory.get_head().cuda(conf.device)
    probe_net = torch.nn.DataParallel(probe_net).cuda()
    gallery_net = torch.nn.DataParallel(gallery_net).cuda()
    # Only the probe net is optimized directly; the gallery net tracks it.
    optimizer = optim.SGD(probe_net.parameters(), lr=conf.lr,
                          momentum=conf.momentum, weight_decay=5e-4)
    lr_schedule = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=conf.milestones, gamma=0.1)
    if conf.resume:
        # Bug fix: `args` was undefined in this scope (NameError on resume);
        # the pretrained-model path is carried on `conf` like every other setting.
        probe_net.load_state_dict(torch.load(conf.pretrain_model))
    # Initialize gallery weights as an exact copy of the probe net (decay 0).
    moving_average(probe_net, gallery_net, 0)
    probe_net.train()
    # Gallery net stays in eval mode but keeps its BatchNorm layers training.
    gallery_net.eval().apply(train_BN)

    exclude_id_set = set()
    loss_meter = AverageMeter()
    for epoch in range(conf.epoches):
        # Rebuild the loader each epoch so newly excluded ids are filtered out.
        data_loader = DataLoader(
            ImageDataset_SST(conf.data_root, conf.train_file, exclude_id_set),
            conf.batch_size, True, num_workers=4, drop_last=True)
        exclude_id_set = train_one_epoch(data_loader, probe_net, gallery_net,
            prototype, optimizer, criterion, epoch, conf, loss_meter)
        lr_schedule.step()
# Example 2
def train(conf):
    """Run the full triplet-loss training procedure.

    Trains a single backbone with an online triplet loss (random hard
    negatives) for ``conf.epoches`` epochs, optionally evaluating after
    each epoch.

    Args:
        conf: configuration namespace; must provide backbone type/conf file,
            data_root, train_file, batch_size, lr, momentum, milestones,
            epoches, resume, (when resuming) pretrain_model, and optionally
            evaluate + evaluator.
    """
    data_loader = DataLoader(ImageDataset_SST(conf.data_root, conf.train_file),
                             conf.batch_size,
                             True,
                             num_workers=4)
    conf.device = torch.device('cuda:0')

    # Online triplet loss with random hard-negative mining.
    triplet_selector = FunctionNegativeTripletSelector(
        margin=2.5, negative_selection_fn=random_hard_negative, cpu=False)
    criterion = OnlineTripletLoss(margin=2.5,
                                  triplet_selector=triplet_selector).cuda(
                                      conf.device)
    backbone_factory = BackboneFactory(conf.backbone_type,
                                       conf.backbone_conf_file)
    model = backbone_factory.get_backbone()
    if conf.resume:
        # Bug fix: `args` was undefined in this scope (NameError on resume);
        # the pretrained-model path is carried on `conf` like every other setting.
        model.load_state_dict(torch.load(conf.pretrain_model))
    model = torch.nn.DataParallel(model).cuda()
    # Only optimize parameters that require gradients (frozen layers skipped).
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(parameters,
                          lr=conf.lr,
                          momentum=conf.momentum,
                          weight_decay=1e-4)
    lr_schedule = optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=conf.milestones,
                                                 gamma=0.1)
    loss_meter = AverageMeter()
    model.train()
    for epoch in range(conf.epoches):
        train_one_epoch(data_loader, model, optimizer, criterion, epoch,
                        loss_meter, conf)
        lr_schedule.step()
        if conf.evaluate:
            conf.evaluator.evaluate(model)
# Example 3
"""
@author: Jun Wang 
@date: 20201012 
@contact: [email protected]
"""

import sys
import torch
from thop import profile
from thop import clever_format

sys.path.append('..')
from backbone.backbone_def import BackboneFactory

# Select the backbone whose MACs/params to report; uncomment one of the
# alternatives below to profile a different architecture.
backbone_type = 'MobileFaceNet'
#backbone_type = 'ResNet'
#backbone_type = 'EfficientNet'
#backbone_type = 'HRNet'
#backbone_type = 'GhostNet'
#backbone_type = 'AttentionNet'

backbone_conf_file = '../training_mode/backbone_conf.yaml'
backbone_factory = BackboneFactory(backbone_type, backbone_conf_file)
backbone = backbone_factory.get_backbone()
# Dummy 112x112 RGB face crop; renamed from `input` to avoid shadowing the builtin.
dummy_input = torch.randn(1, 3, 112, 112)
macs, params = profile(backbone, inputs=(dummy_input, ))
macs, params = clever_format([macs, params], "%.2f")
print('backbone type: ', backbone_type)
print('Params: ', params)
print('Macs: ', macs)