# Example 1
def main(args):
    """Pretraining entry point.

    Builds the data loader, model, optimizer and LR scheduler, optionally
    loads pretrained weights and/or resumes from a checkpoint, then runs
    the epoch loop with periodic checkpointing on rank 0.
    """
    train_loader = get_loader(
        args.aug,
        args,
        two_crop=args.model in ['PixPro'],
        prefix='train',
        return_coord=True,
    )

    args.num_instances = len(train_loader.dataset)
    logger.info(f"length of training dataset: {args.num_instances}")

    model, optimizer = build_model(args)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    # Optionally start from pretrained weights.
    if args.pretrained_model:
        assert os.path.isfile(args.pretrained_model)
        load_pretrained(model, args.pretrained_model)

    # Auto-resume: if a "current.pth" exists in the output dir, resume from it.
    if args.auto_resume:
        candidate = os.path.join(args.output_dir, "current.pth")
        if os.path.exists(candidate):
            logger.info(f'auto resume from {candidate}')
            args.resume = candidate
        else:
            logger.info(
                f'no checkpoint found in {args.output_dir}, ignoring auto resume'
            )

    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args,
                        model,
                        optimizer,
                        scheduler,
                        sampler=train_loader.sampler)

    # TensorBoard writer exists on the main process only.
    rank0 = dist.get_rank() == 0
    summary_writer = SummaryWriter(log_dir=args.output_dir) if rank0 else None

    for epoch in range(args.start_epoch, args.epochs + 1):
        # Reshuffle per epoch in the distributed case.
        sampler = train_loader.sampler
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)

        train(epoch, train_loader, model, optimizer, scheduler, args,
              summary_writer)

        # Checkpoint on rank 0 at the configured frequency, and always at
        # the final epoch.
        due = epoch % args.save_freq == 0 or epoch == args.epochs
        if rank0 and due:
            save_checkpoint(args,
                            epoch,
                            model,
                            optimizer,
                            scheduler,
                            sampler=train_loader.sampler)
# Example 2
def main(args):
    """Fine-tuning / evaluation entry point.

    Builds train/val loaders, the classifier head and optimizer, loads the
    pretrained backbone, optionally resumes, then either runs a single
    validation pass (``args.eval``) or the full train+validate loop with
    rank-0 checkpointing.
    """
    global best_acc1

    # Per-process batch size: split the total batch across the world.
    args.batch_size = args.total_batch_size // dist.get_world_size()
    train_loader = get_loader(args.aug, args, prefix='train')
    val_loader = get_loader('val', args, prefix='val')
    logger.info(f"length of training dataset: {len(train_loader.dataset)}")

    model, optimizer = build_model(args, num_class=len(train_loader.dataset.classes))
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    # Load the pre-trained backbone weights.
    load_pretrained(model, args.pretrained_model)

    # Optionally resume from a checkpoint.
    if args.auto_resume:
        resume_file = os.path.join(args.output_dir, "current.pth")
        if os.path.exists(resume_file):
            logger.info(f'auto resume from {resume_file}')
            args.resume = resume_file
        else:
            logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume')
    if args.resume:
        assert os.path.isfile(args.resume), f"no checkpoint found at '{args.resume}'"
        load_checkpoint(args, model, optimizer, scheduler)

    # Evaluation-only mode: validate once and exit.
    if args.eval:
        logger.info("==> testing...")
        validate(val_loader, model, args)
        return

    # TensorBoard writer exists on the main process only.
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        summary_writer = None

    # Training routine.
    for epoch in range(args.start_epoch, args.epochs + 1):
        # Reshuffle per epoch in the distributed case.
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        train(epoch, train_loader, model, optimizer, scheduler, args)
        logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}')

        logger.info("==> testing...")
        test_acc, test_acc5, test_loss = validate(val_loader, model, args)
        # FIX: best_acc1 was declared global but never updated; track the
        # best top-1 accuracy seen so far.
        best_acc1 = max(best_acc1, test_acc)
        if summary_writer is not None:
            summary_writer.add_scalar('test_acc', test_acc, epoch)
            summary_writer.add_scalar('test_acc5', test_acc5, epoch)
            summary_writer.add_scalar('test_loss', test_loss, epoch)

        # Save model on rank 0 at the configured frequency.
        # FIX: also save on the final epoch even when it is not a multiple
        # of save_freq, consistent with the pretraining script above.
        if dist.get_rank() == 0 and (epoch % args.save_freq == 0
                                     or epoch == args.epochs):
            logger.info('==> Saving...')
            save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler)