def train(args):
    """Total training procedure."""
    print("Use GPU: {} for training".format(args.local_rank))
    # TensorBoard writer and output directory are handled by rank 0 only.
    if args.local_rank == 0:
        writer = SummaryWriter(log_dir=args.tensorboardx_logdir)
        args.writer = writer
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

    # Distributed setup: one process per GPU, NCCL backend, env:// rendezvous.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    args.rank = dist.get_rank()
    #print('args.rank: ', dist.get_rank())
    #print('args.get_world_size: ', dist.get_world_size())
    #print('is_nccl_available: ', dist.is_nccl_available())
    args.world_size = dist.get_world_size()

    # Data: each rank sees its own shard via DistributedSampler.
    trainset = ImageDataset(args.data_root, args.train_file)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)

    # Model: backbone + head assembled by their factories.
    backbone_factory = BackboneFactory(args.backbone_type,
                                       args.backbone_conf_file)
    head_factory = HeadFactory(args.head_type, args.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    model = model.to(args.local_rank)
    model.train()
    # Make sure every rank starts from rank 0's parameters.
    for ps in model.parameters():
        dist.broadcast(ps, 0)

    optimizer = build_optimizer(model, args.lr)
    lr_schedule = build_scheduler(optimizer, len(train_loader), args.epoches,
                                  args.warm_up_epoches)
    # Mixed precision (Apex AMP), then DDP.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = torch.nn.parallel.DistributedDataParallel(
        module=model, broadcast_buffers=False, device_ids=[args.local_rank])

    criterion = torch.nn.CrossEntropyLoss().to(args.local_rank)
    loss_meter = AverageMeter()
    model.train()
    ori_epoch = 0
    for epoch in range(ori_epoch, args.epoches):
        train_one_epoch(train_loader, model, optimizer, lr_schedule, criterion,
                        epoch, loss_meter, args)
    dist.destroy_process_group()
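# ---------------------------------------------------------------------------
# Illustrative entry point (a sketch, not part of the original script): the
# train() function above reads many attributes off `args`, so a launcher has
# to supply them, e.g. via `python -m torch.distributed.launch
# --nproc_per_node=<N> train.py ...`, which also passes --local_rank.
# Argument names mirror the attributes used above; defaults marked
# "placeholder" are assumptions, not values from the original code.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Distributed face-model training (usage sketch).')
    parser.add_argument('--local_rank', type=int, default=0)        # set by the launcher
    parser.add_argument('--data_root', type=str, required=True)
    parser.add_argument('--train_file', type=str, required=True)
    parser.add_argument('--backbone_type', type=str, required=True)
    parser.add_argument('--backbone_conf_file', type=str, required=True)
    parser.add_argument('--head_type', type=str, required=True)
    parser.add_argument('--head_conf_file', type=str, required=True)
    parser.add_argument('--out_dir', type=str, required=True)
    parser.add_argument('--tensorboardx_logdir', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=128)      # placeholder
    parser.add_argument('--lr', type=float, default=0.1)            # placeholder
    parser.add_argument('--epoches', type=int, default=18)          # placeholder
    parser.add_argument('--warm_up_epoches', type=int, default=1)   # placeholder
    args = parser.parse_args()
    train(args)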
def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer,
                        epoch, mixup_fn, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, logger)

        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
def main(config):
    dataset_train, _, data_loader_train, _, _ = build_loader(config)

    config.defrost()
    config.DATA.TRAINING_IMAGES = len(dataset_train)
    config.freeze()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        _ = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler,
                            logger)

    logger.info("Start self-supervised pre-training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, data_loader_train, optimizer, epoch,
                        lr_scheduler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, 0.0, optimizer,
                            lr_scheduler, logger)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    if config.DISTILL.DO_DISTILL:
        logger.info(
            f"Loading teacher model:{config.MODEL.TYPE}/{config.DISTILL.TEACHER}")
        model_checkpoint_name = os.path.basename(config.DISTILL.TEACHER)
        if 'regnety_160' in model_checkpoint_name:
            model_teacher = create_model(
                'regnety_160',
                pretrained=False,
                num_classes=config.MODEL.NUM_CLASSES,
                global_pool='avg',
            )
            if config.DISTILL.TEACHER.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    config.DISTILL.TEACHER, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(config.DISTILL.TEACHER,
                                        map_location='cpu')
            model_teacher.load_state_dict(checkpoint['model'])
            model_teacher.cuda()
            model_teacher.eval()
            del checkpoint
            torch.cuda.empty_cache()
        else:
            if 'base' in model_checkpoint_name:
                teacher_type = 'base'
            elif 'large' in model_checkpoint_name:
                teacher_type = 'large'
            else:
                teacher_type = None
            model_teacher = load_teacher_model(type=teacher_type)
            model_teacher.cuda()
            model_teacher = torch.nn.parallel.DistributedDataParallel(
                model_teacher,
                device_ids=[config.LOCAL_RANK],
                broadcast_buffers=False)
            checkpoint = torch.load(config.DISTILL.TEACHER, map_location='cpu')
            msg = model_teacher.module.load_state_dict(checkpoint['model'],
                                                       strict=False)
            logger.info(msg)
            del checkpoint
            torch.cuda.empty_cache()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
        find_unused_parameters=True)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    criterion_soft = soft_cross_entropy
    criterion_attn = cal_relation_loss
    criterion_hidden = cal_hidden_relation_loss if config.DISTILL.HIDDEN_RELATION else cal_hidden_loss
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion_truth = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion_truth = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion_truth = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.DISTILL.RESUME_WEIGHT_ONLY = False
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model, logger)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.DO_DISTILL:
            train_one_epoch_distill(config,
                                    model,
                                    model_teacher,
                                    data_loader_train,
                                    optimizer,
                                    epoch,
                                    mixup_fn,
                                    lr_scheduler,
                                    criterion_soft=criterion_soft,
                                    criterion_truth=criterion_truth,
                                    criterion_attn=criterion_attn,
                                    criterion_hidden=criterion_hidden)
        else:
            train_one_epoch(config, model, criterion_truth, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, logger)

        if epoch % config.EVAL_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
            acc1, acc5, loss = validate(config, data_loader_val, model, logger)
            logger.info(
                f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
            )
            max_accuracy = max(max_accuracy, acc1)
            logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
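# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not the repository's implementation):
# criterion_soft above is a soft-target cross-entropy between student and
# teacher logits. A common formulation is shown here for reference only.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def soft_cross_entropy_sketch(student_logits: torch.Tensor,
                              teacher_logits: torch.Tensor) -> torch.Tensor:
    # -sum(softmax(teacher) * log_softmax(student)) per sample, averaged over
    # the batch; this equals KL(teacher || student) up to the teacher entropy.
    teacher_prob = F.softmax(teacher_logits, dim=-1)
    student_logprob = F.log_softmax(student_logits, dim=-1)
    return -(teacher_prob * student_logprob).sum(dim=-1).mean()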
def main(args, config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    if args.use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(
        config, optimizer,
        len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)

    if config.DISTILL.ENABLED:
        # MIXUP and CUTMIX are disabled when doing knowledge distillation
        assert len(config.DISTILL.TEACHER_LOGITS_PATH
                   ) > 0, "Please fill in DISTILL.TEACHER_LOGITS_PATH"
        criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    else:
        if config.AUG.MIXUP > 0.:
            # smoothing is handled with mixup label transform
            criterion = SoftTargetCrossEntropy()
        elif config.MODEL.LABEL_SMOOTHING > 0.:
            criterion = LabelSmoothingCrossEntropy(
                smoothing=config.MODEL.LABEL_SMOOTHING)
        else:
            criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # set_epoch on dataset_train as well when distilling
        if hasattr(dataset_train, 'set_epoch'):
            dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.ENABLED:
            train_one_epoch_distill_using_saved_logits(
                args, config, model, criterion, data_loader_train, optimizer,
                epoch, mixup_fn, lr_scheduler, loss_scaler)
        else:
            train_one_epoch(args, config, model, criterion, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler,
                            loss_scaler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, loss_scaler, logger)

        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

        if is_main_process() and args.use_wandb:
            wandb.log({
                f"val/acc@1": acc1,
                f"val/acc@5": acc5,
                f"val/loss": loss,
                "epoch": epoch,
            })
            wandb.run.summary['epoch'] = epoch
            wandb.run.summary['best_acc@1'] = max_accuracy

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
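# ---------------------------------------------------------------------------
# Illustrative launcher sketch (an assumption; the original entry points are
# not shown above): the main(config) / main(args, config) variants expect the
# process group and CUDA device to be set up before they run, e.g. under
# `torchrun --nproc_per_node=<N> main.py ...`. Config construction
# (parse_option()/get_config()) is repo-specific and left out here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import os

    import torch
    import torch.distributed as dist

    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    dist.barrier()
    # args, config = parse_option()   # repo-specific; omitted in this sketch
    # main(config)                    # or main(args, config) for the last variant
    dist.destroy_process_group()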