Example #1
def train(conf):
    """Total training procedure.
    """
    data_loader = DataLoader(ImageDataset(conf.data_root, conf.train_file),
                             batch_size=conf.batch_size,
                             shuffle=True,
                             num_workers=4)
    conf.device = torch.device('cuda:0')
    criterion = torch.nn.CrossEntropyLoss().cuda(conf.device)
    backbone_factory = BackboneFactory(conf.backbone_type,
                                       conf.backbone_conf_file)
    head_factory = HeadFactory(conf.head_type, conf.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    ori_epoch = 0
    if conf.resume:
        # Resume from the saved checkpoint: restore the epoch counter and the weights.
        checkpoint = torch.load(conf.pretrain_model)
        ori_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(parameters,
                          lr=conf.lr,
                          momentum=conf.momentum,
                          weight_decay=1e-4)
    # Mixed-precision training via NVIDIA apex (opt level O1), then multi-GPU data parallelism.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = torch.nn.DataParallel(model).cuda()
    lr_schedule = optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=conf.milestones,
                                                 gamma=0.1)
    loss_meter = AverageMeter()
    model.train()
    for epoch in range(ori_epoch, conf.epoches):
        train_one_epoch(data_loader, model, optimizer, criterion, epoch,
                        loss_meter, conf)
        lr_schedule.step()
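
A minimal, hypothetical driver for this example, assuming train is defined in the same script and that the configuration is a plain argparse namespace carrying the attributes read above; the flag names and defaults are illustrative, not the project's actual CLI:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Illustrative training entry point.')
    parser.add_argument('--data_root', type=str, required=True)
    parser.add_argument('--train_file', type=str, required=True)
    parser.add_argument('--backbone_type', type=str, required=True)
    parser.add_argument('--backbone_conf_file', type=str, required=True)
    parser.add_argument('--head_type', type=str, required=True)
    parser.add_argument('--head_conf_file', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--milestones', type=int, nargs='+', default=[10, 13, 16])
    parser.add_argument('--epoches', type=int, default=18)
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--pretrain_model', type=str, default='')
    conf = parser.parse_args()
    train(conf)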
Example #2
def train(args):
    """Total training procedure.
    """
    print("Use GPU: {} for training".format(args.local_rank))
    if args.local_rank == 0:
        writer = SummaryWriter(log_dir=args.tensorboardx_logdir)
        args.writer = writer
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)
    args.rank = dist.get_rank()
    #print('args.rank: ', dist.get_rank())
    #print('args.get_world_size: ', dist.get_world_size())
    #print('is_nccl_available: ', dist.is_nccl_available())
    args.world_size = dist.get_world_size()
    trainset = ImageDataset(args.data_root, args.train_file)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)

    backbone_factory = BackboneFactory(args.backbone_type,
                                       args.backbone_conf_file)
    head_factory = HeadFactory(args.head_type, args.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    model = model.to(args.local_rank)
    model.train()
    # Synchronize the initial parameters from rank 0 across all ranks before wrapping with DDP.
    for ps in model.parameters():
        dist.broadcast(ps, 0)
    # Wrap with DistributedDataParallel for synchronized multi-GPU training.
    model = torch.nn.parallel.DistributedDataParallel(
        module=model, broadcast_buffers=False, device_ids=[args.local_rank])
    criterion = torch.nn.CrossEntropyLoss().to(args.local_rank)
    ori_epoch = 0
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(parameters,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=1e-4)
    lr_schedule = optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=args.milestones,
                                                 gamma=0.1)
    loss_meter = AverageMeter()
    model.train()
    for epoch in range(ori_epoch, args.epoches):
        # Reseed the DistributedSampler so each epoch shuffles the data differently.
        train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, optimizer, criterion, epoch,
                        loss_meter, args)
        lr_schedule.step()
    dist.destroy_process_group()
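
A hypothetical launch wrapper for this distributed variant; the flag names are assumptions rather than the project's actual CLI. Launching with python -m torch.distributed.launch --nproc_per_node=<num_gpus> passes --local_rank to each process, while torchrun exposes it through the LOCAL_RANK environment variable, so the sketch accepts both:

import argparse
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Illustrative distributed entry point.')
    # torch.distributed.launch supplies --local_rank; torchrun sets LOCAL_RANK instead.
    parser.add_argument('--local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))
    parser.add_argument('--data_root', type=str, required=True)
    parser.add_argument('--train_file', type=str, required=True)
    parser.add_argument('--backbone_type', type=str, required=True)
    parser.add_argument('--backbone_conf_file', type=str, required=True)
    parser.add_argument('--head_type', type=str, required=True)
    parser.add_argument('--head_conf_file', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--milestones', type=int, nargs='+', default=[10, 13, 16])
    parser.add_argument('--epoches', type=int, default=18)
    parser.add_argument('--out_dir', type=str, default='checkpoints')
    parser.add_argument('--tensorboardx_logdir', type=str, default='tensorboard_logs')
    args = parser.parse_args()
    train(args)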
Example #3
def train(conf):
    """Total training procedure.
    """
    if conf.virtual_batch:
        # Emulate a large virtual batch via gradient accumulation: load mini-batches
        # of 64 and only step the optimizer every `update_interval` iterations.
        assert conf.batch_size % 64 == 0
        update_interval = conf.batch_size // 64
        batch_per_epoch = 64
    else:
        update_interval = 1
        batch_per_epoch = conf.batch_size
    data_loader = DataLoader(ImageDataset(conf.data_root, conf.train_file),
                             batch_size=batch_per_epoch,
                             shuffle=True,
                             num_workers=6)
    conf.device = torch.device('cuda:0')
    criterion = torch.nn.CrossEntropyLoss().cuda(conf.device)
    backbone_factory = BackboneFactory(conf.backbone_type, conf.backbone_conf_file)    
    head_factory = HeadFactory(conf.head_type, conf.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    ori_epoch = 0
    if conf.pretrain_model != '':
        # Initialize from a pretrained checkpoint; optionally resume the epoch counter.
        checkpoint = torch.load(conf.pretrain_model)
        model.load_state_dict(checkpoint['state_dict'])
        if conf.resume:
            ori_epoch = checkpoint['epoch'] + 1
        del checkpoint
    model = model.cuda()
    # parameters = [p for p in model.parameters() if p.requires_grad]
    backbone_parameters = [p for n, p in model.named_parameters() if ("backbone" in n) and (p.requires_grad)]
    head_parameters = [p for n, p in model.named_parameters() if ("head" in n) and (p.requires_grad)]
    optimizer = optim.AdamW(backbone_parameters + head_parameters, lr=conf.lr, weight_decay=3e-5)
    if conf.resume:
        # MultiStepLR with last_epoch >= 0 requires 'initial_lr' in every param group.
        for param_group in optimizer.param_groups:
            param_group['initial_lr'] = conf.lr
    scaler = torch.cuda.amp.GradScaler()
    model = torch.nn.DataParallel(model).cuda()
    lr_schedule = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=conf.milestones, gamma=0.1, last_epoch=ori_epoch-1)
    loss_meter = AverageMeter()
    model.train()
    for epoch in range(ori_epoch, conf.epoches):
        train_one_epoch(data_loader, model, optimizer, criterion, epoch,
                        loss_meter, backbone_parameters, conf, scaler, update_interval)
        lr_schedule.step()
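
The train_one_epoch helper is not shown in these examples. The following is a minimal sketch of how the GradScaler and update_interval arguments could be consumed for gradient accumulation, assuming the model returns class logits when called as model(images, labels) and that AverageMeter exposes update(value, n). The signature and body are illustrative only, not the project's actual implementation:

import torch

def train_one_epoch(data_loader, model, optimizer, criterion, cur_epoch,
                    loss_meter, backbone_parameters, conf, scaler, update_interval):
    """Illustrative AMP + gradient-accumulation loop (assumed, not from the source)."""
    optimizer.zero_grad()
    for batch_idx, (images, labels) in enumerate(data_loader):
        images = images.to(conf.device)
        labels = labels.to(conf.device)
        with torch.cuda.amp.autocast():
            outputs = model(images, labels)
            # Divide so the accumulated gradient matches one large-batch update.
            loss = criterion(outputs, labels) / update_interval
        scaler.scale(loss).backward()
        if (batch_idx + 1) % update_interval == 0:
            # Unscale before clipping so the threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(backbone_parameters, max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        loss_meter.update(loss.item() * update_interval, images.size(0))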