Ejemplo n.º 1
0
def main(config):
    """Restore a checkpoint, optionally validate it, then save or check logits.

    The model is built from ``config``, wrapped in DistributedDataParallel
    and restored through ``load_checkpoint``. Unless ``args.skip_eval`` or
    ``args.check_saved_logits`` is truthy, the restored model is validated
    once. Afterwards every epoch in
    ``[config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)`` is handed to
    ``check_logits_one_epoch`` (when ``args.check_saved_logits`` is truthy)
    or to ``save_logits_one_epoch`` (otherwise).

    NOTE(review): ``args`` is not a parameter of this function; it is read
    from an enclosing scope (presumably a module-level global) -- confirm it
    is defined before this function is called.

    Args:
        config: Project configuration object. The fields read here are
            ``MODEL.TYPE``, ``MODEL.NAME``, ``MODEL.RESUME``, ``LOCAL_RANK``,
            ``TRAIN.START_EPOCH`` and ``TRAIN.EPOCHS``.

    Raises:
        AssertionError: If ``config.MODEL.RESUME`` is falsy.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()

    logger.info(str(model))

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")

    # This function never builds an optimizer or a scheduler, so
    # load_checkpoint receives None for both.
    optimizer = None
    lr_scheduler = None

    assert config.MODEL.RESUME, (
        "config.MODEL.RESUME must be set: this entry point only operates on "
        "a resumed checkpoint")
    loss_scaler = NativeScalerWithGradNormCount()
    load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler,
                    loss_scaler, logger)
    if not args.skip_eval and not args.check_saved_logits:
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: top-1 acc: {acc1:.1f}%, top-5 acc: {acc5:.1f}%"
        )

    # The mode is fixed for the whole run, so resolve the per-epoch routine
    # and the word used in the final timing message once.
    check_mode = args.check_saved_logits
    if check_mode:
        logger.info("Start checking logits")
        run_one_epoch = check_logits_one_epoch
        phase = "Checking"
    else:
        logger.info("Start saving logits")
        run_one_epoch = save_logits_one_epoch
        phase = "Saving"

    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Both the dataset and the sampler are told the current epoch.
        dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        run_one_epoch(config,
                      model,
                      data_loader_train,
                      epoch,
                      mixup_fn=mixup_fn)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    # Previously this always reported "Saving logits time", even in check
    # mode; the save-mode text is unchanged.
    logger.info(f"{phase} logits time {total_time_str}")
Ejemplo n.º 2
0
def main(config):
    """Build a model, optionally restore weights, then train and validate it.

    Order of operations:

    1. Build the data loaders, model, optimizer (passed through
       ``amp.initialize`` unless ``config.AMP_OPT_LEVEL`` is ``"O0"``), the
       DistributedDataParallel wrapper and the LR scheduler.
    2. Choose the loss from ``config.AUG.MIXUP`` and
       ``config.MODEL.LABEL_SMOOTHING``.
    3. Resolve auto-resume. If ``config.MODEL.RESUME`` is set, load the
       checkpoint and validate (returning early when ``config.EVAL_MODE`` is
       truthy); otherwise, if ``config.MODEL.PRETRAINED`` is set, load the
       pretrained weights and validate.
    4. Return after ``throughput`` when ``config.THROUGHPUT_MODE`` is truthy.
    5. Train over ``[config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)``,
       checkpointing on rank 0 and validating after every epoch.

    Args:
        config: Project configuration object. It is modified through
            ``defrost()`` / ``freeze()`` when auto-resume finds a checkpoint.
    """
    _train_set, dataset_val, train_loader, val_loader, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    net = build_model(config)
    net.cuda()
    logger.info(str(net))

    optimizer = build_optimizer(config, net)
    amp_enabled = config.AMP_OPT_LEVEL != "O0"
    if amp_enabled:
        net, optimizer = amp.initialize(
            net, optimizer, opt_level=config.AMP_OPT_LEVEL)
    ddp_net = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    bare_net = ddp_net.module

    n_parameters = sum(
        param.numel() for param in ddp_net.parameters()
        if param.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(bare_net, 'flops'):
        flops = bare_net.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(train_loader))

    def _select_criterion():
        """Return the loss object matching the mixup / label-smoothing config."""
        if config.AUG.MIXUP > 0.:
            # smoothing is handled with mixup label transform
            return SoftTargetCrossEntropy()
        if config.MODEL.LABEL_SMOOTHING > 0.:
            return LabelSmoothingCrossEntropy(
                smoothing=config.MODEL.LABEL_SMOOTHING)
        return torch.nn.CrossEntropyLoss()

    def _validate_and_log():
        """Validate the wrapped model, log top-1 accuracy and return it."""
        acc1, _, _ = validate(config, val_loader, ddp_net)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        return acc1

    criterion = _select_criterion()
    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )
        else:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            # The config has to be unfrozen before RESUME can be overwritten.
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, bare_net, optimizer,
                                       lr_scheduler, logger)
        _validate_and_log()
        if config.EVAL_MODE:
            return

    wants_pretrained = config.MODEL.PRETRAINED and not config.MODEL.RESUME
    if wants_pretrained:
        load_pretrained(config, bare_net, logger)
        _validate_and_log()

    if config.THROUGHPUT_MODE:
        throughput(val_loader, ddp_net, logger)
        return

    logger.info("Start training")
    began_at = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        train_loader.sampler.set_epoch(epoch)

        train_one_epoch(config, ddp_net, criterion, train_loader, optimizer,
                        epoch, mixup_fn, lr_scheduler)

        # Only rank 0 writes checkpoints. The checkpoint is written before
        # this epoch's validation, so the max_accuracy stored with it does
        # not include the current epoch.
        if dist.get_rank() == 0:
            if (epoch % config.SAVE_FREQ == 0
                    or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, bare_net, max_accuracy,
                                optimizer, lr_scheduler, logger)

        acc1 = _validate_and_log()
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    elapsed = time.time() - began_at
    elapsed_str = str(datetime.timedelta(seconds=int(elapsed)))
    logger.info('Training time {}'.format(elapsed_str))
Ejemplo n.º 3
0
def main(config):
    """Run the self-supervised pre-training loop described by ``config``.

    The size of the training dataset is written to
    ``config.DATA.TRAINING_IMAGES`` before the model is built. The model is
    then created, optionally passed through ``amp.initialize``, wrapped in
    DistributedDataParallel, optionally restored from a checkpoint, and
    trained over ``[config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)``. No
    validation is performed; checkpoints are written on rank 0 only.

    Args:
        config: Project configuration object. It is modified through
            ``defrost()`` / ``freeze()`` here (always for
            ``DATA.TRAINING_IMAGES``, and for ``MODEL.RESUME`` when
            auto-resume finds a checkpoint).
    """
    dataset_train, _val_set, train_loader, _val_loader, _mixup = build_loader(
        config)

    # Record the dataset size in the (temporarily unfrozen) config.
    config.defrost()
    config.DATA.TRAINING_IMAGES = len(dataset_train)
    config.freeze()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    net = build_model(config)
    net.cuda()
    logger.info(str(net))

    optimizer = build_optimizer(config, net)
    amp_enabled = config.AMP_OPT_LEVEL != "O0"
    if amp_enabled:
        net, optimizer = amp.initialize(
            net, optimizer, opt_level=config.AMP_OPT_LEVEL)
    ddp_net = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    bare_net = ddp_net.module

    n_parameters = sum(
        param.numel() for param in ddp_net.parameters()
        if param.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(bare_net, 'flops'):
        flops = bare_net.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(train_loader))

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )
        else:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        # The return value of load_checkpoint is not used by this loop.
        load_checkpoint(config, bare_net, optimizer, lr_scheduler, logger)

    logger.info("Start self-supervised pre-training")
    began_at = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        train_loader.sampler.set_epoch(epoch)

        train_one_epoch(config, ddp_net, train_loader, optimizer, epoch,
                        lr_scheduler)

        # Only rank 0 writes checkpoints; 0.0 is passed in the slot that
        # the other save_checkpoint call sites use for the best accuracy.
        if dist.get_rank() == 0:
            if (epoch % config.SAVE_FREQ == 0
                    or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, bare_net, 0.0, optimizer,
                                lr_scheduler, logger)

    elapsed = time.time() - began_at
    elapsed_str = str(datetime.timedelta(seconds=int(elapsed)))
    logger.info('Training time {}'.format(elapsed_str))
Ejemplo n.º 4
0
def main(config):
    """Train a model, optionally distilling from a teacher loaded up front.

    When ``config.DISTILL.DO_DISTILL`` is truthy a teacher is created first:
    a ``regnety_160`` model via ``create_model`` if the checkpoint file name
    contains ``'regnety_160'``, otherwise a model from ``load_teacher_model``
    (type ``'base'``, ``'large'`` or ``None`` depending on the file name)
    wrapped in DistributedDataParallel. The student is then built, optionally
    resumed and validated, and trained over
    ``[config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)`` with either
    ``train_one_epoch_distill`` or ``train_one_epoch``. Checkpoints are
    written on rank 0; validation runs every ``config.EVAL_FREQ`` epochs and
    on the last epoch.

    The function returns early after the resume-time validation when
    ``config.EVAL_MODE`` is truthy, and after ``throughput`` when
    ``config.THROUGHPUT_MODE`` is truthy.

    Args:
        config: Project configuration object. It is modified through
            ``defrost()`` / ``freeze()`` when auto-resume finds a checkpoint.
    """
    (_train_set, dataset_val, train_loader, val_loader,
     mixup_fn) = build_loader(config)

    if config.DISTILL.DO_DISTILL:
        logger.info(
            f"Loading teacher model:{config.MODEL.TYPE}/{config.DISTILL.TEACHER}"
        )
        teacher_ckpt_name = os.path.basename(config.DISTILL.TEACHER)
        if 'regnety_160' not in teacher_ckpt_name:
            # First matching tag wins ('base' is tested before 'large');
            # None is passed when neither appears in the file name.
            teacher_type = next(
                (tag for tag in ('base', 'large') if tag in teacher_ckpt_name),
                None)
            teacher = load_teacher_model(type=teacher_type)
            teacher.cuda()
            teacher = torch.nn.parallel.DistributedDataParallel(
                teacher,
                device_ids=[config.LOCAL_RANK],
                broadcast_buffers=False)
            teacher_state = torch.load(config.DISTILL.TEACHER,
                                       map_location='cpu')
            # Non-strict load; the returned key report is logged.
            load_report = teacher.module.load_state_dict(
                teacher_state['model'], strict=False)
            logger.info(load_report)
        else:
            teacher = create_model(
                'regnety_160',
                pretrained=False,
                num_classes=config.MODEL.NUM_CLASSES,
                global_pool='avg',
            )
            teacher_source = config.DISTILL.TEACHER
            if teacher_source.startswith('https'):
                teacher_state = torch.hub.load_state_dict_from_url(
                    teacher_source, map_location='cpu', check_hash=True)
            else:
                teacher_state = torch.load(teacher_source, map_location='cpu')
            teacher.load_state_dict(teacher_state['model'])
            teacher.cuda()
            teacher.eval()
        # Drop the CPU copy of the weights and release cached GPU memory.
        del teacher_state
        torch.cuda.empty_cache()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    student = build_model(config)
    student.cuda()
    logger.info(str(student))

    optimizer = build_optimizer(config, student)
    amp_enabled = config.AMP_OPT_LEVEL != "O0"
    if amp_enabled:
        student, optimizer = amp.initialize(
            student, optimizer, opt_level=config.AMP_OPT_LEVEL)
    ddp_student = torch.nn.parallel.DistributedDataParallel(
        student,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
        find_unused_parameters=True)

    bare_student = ddp_student.module

    n_parameters = sum(
        param.numel() for param in ddp_student.parameters()
        if param.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(bare_student, 'flops'):
        flops = bare_student.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(train_loader))
    soft_loss = soft_cross_entropy
    attn_loss = cal_relation_loss
    if config.DISTILL.HIDDEN_RELATION:
        hidden_loss = cal_hidden_relation_loss
    else:
        hidden_loss = cal_hidden_loss

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        truth_loss = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        truth_loss = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        truth_loss = torch.nn.CrossEntropyLoss()

    # Keyword arguments forwarded to train_one_epoch_distill.
    distill_losses = {
        'criterion_soft': soft_loss,
        'criterion_truth': truth_loss,
        'criterion_attn': attn_loss,
        'criterion_hidden': hidden_loss,
    }

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )
        else:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.DISTILL.RESUME_WEIGHT_ONLY = False
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, bare_student, optimizer,
                                       lr_scheduler, logger)
        acc1, _, _ = validate(config, val_loader, ddp_student, logger)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.THROUGHPUT_MODE:
        throughput(val_loader, ddp_student, logger)
        return

    logger.info("Start training")
    began_at = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        train_loader.sampler.set_epoch(epoch)

        if not config.DISTILL.DO_DISTILL:
            train_one_epoch(config, ddp_student, truth_loss, train_loader,
                            optimizer, epoch, mixup_fn, lr_scheduler)
        else:
            train_one_epoch_distill(config, ddp_student, teacher,
                                    train_loader, optimizer, epoch, mixup_fn,
                                    lr_scheduler, **distill_losses)

        # Only rank 0 writes checkpoints. The checkpoint is written before
        # this epoch's validation, so the max_accuracy stored with it does
        # not include the current epoch.
        if dist.get_rank() == 0:
            if (epoch % config.SAVE_FREQ == 0
                    or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, bare_student, max_accuracy,
                                optimizer, lr_scheduler, logger)

        # Skip validation unless this is an EVAL_FREQ epoch or the last one.
        if (epoch % config.EVAL_FREQ != 0
                and epoch != config.TRAIN.EPOCHS - 1):
            continue

        acc1, _, _ = validate(config, val_loader, ddp_student, logger)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    elapsed = time.time() - began_at
    elapsed_str = str(datetime.timedelta(seconds=int(elapsed)))
    logger.info('Training time {}'.format(elapsed_str))
Ejemplo n.º 5
0
def main(args, config):
    """Build, optionally restore, and train a model, validating every epoch.

    The model is built from ``config``, optionally converted to
    SyncBatchNorm, wrapped in DistributedDataParallel and trained over
    ``[config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)``. Each epoch runs
    either ``train_one_epoch_distill_using_saved_logits`` (when
    ``config.DISTILL.ENABLED`` is truthy) or ``train_one_epoch``, followed by
    a rank-0 checkpoint, a validation pass and, on the main process with
    ``args.use_wandb`` set, a ``wandb`` log entry.

    The function returns early after the resume-time validation when
    ``config.EVAL_MODE`` is truthy, and after ``throughput`` when
    ``config.THROUGHPUT_MODE`` is truthy.

    Args:
        args: Parsed command-line options. ``use_sync_bn`` and ``use_wandb``
            are read here; the object is also forwarded to the training and
            validation routines.
        config: Project configuration object. It is modified through
            ``defrost()`` / ``freeze()`` when auto-resume finds a checkpoint.

    Raises:
        AssertionError: If distillation is enabled and
            ``config.DISTILL.TEACHER_LOGITS_PATH`` is empty.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()

    if args.use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    # The scheduler receives the loader length divided by the accumulation
    # steps -- presumably the number of optimizer steps per epoch (confirm
    # against build_scheduler).
    lr_scheduler = build_scheduler(
        config, optimizer,
        len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)

    if config.DISTILL.ENABLED:
        # NOTE(review): the original comment states that MIXUP and CUTMIX are
        # disabled for knowledge distillation; this block itself only
        # requires a logits path and selects plain cross-entropy -- confirm
        # where the augmentations are actually turned off.
        assert len(config.DISTILL.TEACHER_LOGITS_PATH
                   ) > 0, "Please fill in DISTILL.TEACHER_LOGITS_PATH"
        criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    else:
        if config.AUG.MIXUP > 0.:
            # smoothing is handled with mixup label transform
            criterion = SoftTargetCrossEntropy()
        elif config.MODEL.LABEL_SMOOTHING > 0.:
            criterion = LabelSmoothingCrossEntropy(
                smoothing=config.MODEL.LABEL_SMOOTHING)
        else:
            criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            # The config has to be unfrozen before RESUME can be overwritten.
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # set_epoch for dataset_train when distillation
        if hasattr(dataset_train, 'set_epoch'):
            dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.ENABLED:
            train_one_epoch_distill_using_saved_logits(
                args, config, model, criterion, data_loader_train, optimizer,
                epoch, mixup_fn, lr_scheduler, loss_scaler)
        else:
            train_one_epoch(args, config, model, criterion, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler,
                            loss_scaler)
        # Only rank 0 writes checkpoints. The checkpoint is written before
        # this epoch's validation, so the max_accuracy stored with it does
        # not include the current epoch.
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, loss_scaler, logger)

        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

        if is_main_process() and args.use_wandb:
            # Plain string keys: the previous f-strings had no placeholders.
            wandb.log({
                "val/acc@1": acc1,
                "val/acc@5": acc5,
                "val/loss": loss,
                "epoch": epoch,
            })
            wandb.run.summary['epoch'] = epoch
            wandb.run.summary['best_acc@1'] = max_accuracy

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))