Example #1
def get_logger(work_dir, cfg):
    logger = DistSummaryWriter(work_dir)
    config_txt = os.path.join(work_dir, 'cfg.txt')
    # Only the main process writes the config file, so ranks do not race on it.
    if is_main_process():
        with open(config_txt, 'w') as fp:
            fp.write(str(cfg))

    return logger
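
A minimal usage sketch, assuming DistSummaryWriter exposes the add_scalar interface used in Example #2 and that the work directory already exists; the directory name and config dict below are placeholders, not values from the original project.

cfg = {'arch': 'fcanet50', 'lr': 0.1}                  # stand-in config; any object with a __str__ works
logger = get_logger('./runs/exp0', cfg)                # the main process also writes ./runs/exp0/cfg.txt
logger.add_scalar('train/loss', 0.5, global_step=0)    # log as with a regular SummaryWriter
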
Example #2
def main():
    time_stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    global best_prec1, args
    best_prec1 = 0
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # make apex optional
    if args.opt_level is not None or args.sync_bn:
        try:
            global DDP, amp, optimizers, parallel
            from apex.parallel import DistributedDataParallel as DDP
            from apex import amp, optimizers, parallel
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to run this example."
            )
    if args.opt_level is None and args.distributed:
        from torch.nn.parallel import DistributedDataParallel as DDP

    dist_print("opt_level = {}".format(args.opt_level))
    dist_print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
               type(args.keep_batchnorm_fp32))
    dist_print("loss_scale = {}".format(args.loss_scale),
               type(args.loss_scale))
    dist_print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    torch.backends.cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # cudnn.benchmark = False
        # cudnn.deterministic = True
        # torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
        setup_seed(0)

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    args.work_dir = os.path.join(args.work_dir,
                                 time_stamp + args.arch + args.note)
    if not args.evaluate:
        if args.local_rank == 0:
            os.makedirs(args.work_dir)
        logger = DistSummaryWriter(args.work_dir)

    # create model
    if args.pretrained:
        dist_print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch == 'fcanet34':
            model = fcanet34(pretrained=True)
        elif args.arch == 'fcanet50':
            model = fcanet50(pretrained=True)
        elif args.arch == 'fcanet101':
            model = fcanet101(pretrained=True)
        elif args.arch == 'fcanet152':
            model = fcanet152(pretrained=True)
        else:
            model = models.__dict__[args.arch](pretrained=True)
    else:
        dist_print("=> creating model '{}'".format(args.arch))
        if args.arch == 'fcanet34':
            model = fcanet34()
        elif args.arch == 'fcanet50':
            model = fcanet50()
        elif args.arch == 'fcanet101':
            model = fcanet101()
        elif args.arch == 'fcanet152':
            model = fcanet152()
        else:
            model = models.__dict__[args.arch]()

    if args.sync_bn:
        dist_print("using apex synced BN")
        model = parallel.convert_syncbn_model(model)

    if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
        if args.channels_last:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format
        model = model.cuda().to(memory_format=memory_format)
    else:
        model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.opt_level is not None:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        if args.opt_level is not None:
            model = DDP(model, delay_allreduce=True)
        else:
            model = DDP(model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # Declare best_prec1 global; otherwise the assignment below only
            # binds a local name and the restored best accuracy is lost.
            global best_prec1
            if os.path.isfile(args.resume):
                dist_print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))

                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                dist_print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                dist_print("=> no checkpoint found at '{}'".format(
                    args.resume))

        resume()
    if args.evaluate:
        assert args.evaluate_model is not None
        dist_print("=> loading checkpoint '{}' for eval".format(
            args.evaluate_model))
        checkpoint = torch.load(
            args.evaluate_model,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # Plain (non-DDP) checkpoint: prepend 'module.' so the keys match
            # the DistributedDataParallel-wrapped model.
            state_dict_with_module = {}
            for k, v in checkpoint.items():
                state_dict_with_module['module.' + k] = v
            model.load_state_dict(state_dict_with_module)

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    if (args.arch == "inception_v3"):
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.workers,
                           device_id=args.local_rank,
                           data_dir=traindir,
                           crop=crop_size,
                           dali_cpu=args.dali_cpu,
                           shard_id=args.local_rank,
                           num_shards=args.world_size)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe,
                                              reader_name="Reader",
                                              fill_last_batch=False)

    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.workers,
                         device_id=args.local_rank,
                         data_dir=valdir,
                         crop=crop_size,
                         size=val_size,
                         shard_id=args.local_rank,
                         num_shards=args.world_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe,
                                            reader_name="Reader",
                                            fill_last_batch=False)

    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = CrossEntropyLabelSmooth().cuda()

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # Number of iterations that make up one training epoch.
    len_epoch = int(math.ceil(train_loader._size / args.batch_size))
    # Cosine horizon of 95 epochs with 5 epochs of linear warmup.
    T_max = 95 * len_epoch
    warmup_iters = 5 * len_epoch
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max,
                                  warmup='linear',
                                  warmup_iters=warmup_iters)

    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        avg_train_time = train(train_loader, model, criterion, optimizer,
                               epoch, logger, scheduler)
        total_time.update(avg_train_time)
        torch.cuda.empty_cache()
        # evaluate on validation set
        [prec1, prec5] = validate(val_loader, model, criterion)
        logger.add_scalar('Val/prec1', prec1, global_step=epoch)
        logger.add_scalar('Val/prec5', prec5, global_step=epoch)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                work_dir=args.work_dir)
            if epoch == args.epochs - 1:
                dist_print('##Best Top-1 {0}\n'
                           '##Perf  {1}'.format(
                               best_prec1,
                               args.total_batch_size / total_time.avg))
                with open(os.path.join(args.work_dir, 'res.txt'), 'w') as f:
                    f.write('arch: {0} \n best_prec1 {1}'.format(
                        args.arch + args.note, best_prec1))

        train_loader.reset()
        val_loader.reset()
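
Example #2 relies on a CrossEntropyLabelSmooth criterion that is not shown in this snippet. Below is a minimal, self-contained sketch of label-smoothed cross-entropy under an assumed smoothing factor of 0.1; the original project's implementation and constants may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy against targets softened toward the uniform distribution."""

    def __init__(self, epsilon=0.1):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        # Smoothed target = (1 - eps) * one_hot + eps * uniform, so the loss
        # splits into a standard NLL term and a term against uniform labels.
        nll = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        uniform = -log_probs.mean(dim=-1)
        return ((1.0 - self.epsilon) * nll + self.epsilon * uniform).mean()


With the default argument this drops in where the commented-out nn.CrossEntropyLoss() sits in Example #2; only the criterion changes, the training loop is untouched.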