def build_train_loader(cls, cfg):
    """Build the training data loader for *cfg*.

    Pairs detectron2's standard train loader with the panoptic-deeplab
    dataset mapper, using the semantic-segmentation training augmentations.
    """
    augs = build_sem_seg_train_aug(cfg)
    loader_mapper = PanopticDeeplabDatasetMapper(cfg, augmentations=augs)
    return build_detection_train_loader(cfg, mapper=loader_mapper)
# Esempio n. 2 ("Example n. 2" — leftover separator from the code-example
# scrape, along with a stray "0" vote count; kept as a comment so the
# module stays syntactically valid)
def main():
    """Distributed training entry point for the segmentation network.

    Flow: parse args -> build detectron2 panoptic data pipeline -> init
    NCCL process group (one process per GPU) -> seed RNGs -> build
    model / OHEM criterion / optimizer -> optionally resume from
    ``last.pth.tar`` -> wrap in (apex or torch) DDP -> epoch loop of
    train/validate, with rank 0 writing logs, TensorBoard scalars and
    ``best_checkpoint.pth.tar`` / ``last.pth.tar`` checkpoints.
    """
    args, args_text = _parse_args()

    # detectron2 data loader ###########################
    # det2_args = default_argument_parser().parse_args()
    det2_args = args
    det2_args.config_file = args.det2_cfg
    cfg = setup(det2_args)
    mapper = PanopticDeeplabDatasetMapper(
        cfg, augmentations=build_sem_seg_train_aug(cfg))
    det2_dataset = iter(build_detection_train_loader(cfg, mapper=mapper))

    # dist init: rank and world size are taken from the environment.
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)
    args.world_size = torch.distributed.get_world_size()
    args.local_rank = torch.distributed.get_rank()

    args.save = args.save + args.exp_name

    # Only rank 0 creates the experiment dir, file logging and TensorBoard.
    if args.local_rank == 0:
        create_exp_dir(args.save,
                       scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
        logger = SummaryWriter(args.save)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", str(args))
    else:
        logger = None

    # preparation ################
    # Fix all RNG seeds and force deterministic cuDNN for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # config network and criterion ################
    gt_down_sampling = 1
    # OHEM keeps at least 1/16 of the (downsampled) pixels per batch.
    min_kept = int(args.batch_size * args.image_height * args.image_width //
                   (16 * gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # data loader ###########################

    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True
    }
    train_loader, train_sampler, val_loader, val_sampler, num_classes = dataloaders.make_data_loader(
        args, **kwargs)

    # Searched-architecture description (NAS result) in JSON form.
    with open(args.json_file, 'r') as f:
        # dict_a = json.loads(f, cls=NpEncoder)
        model_dict = json.loads(f.read())

    # NOTE(review): width_mult_list and `last` below are never used in this
    # function — presumably consumed by an earlier Network variant that took
    # the searched widths/cells explicitly; confirm before removing. The
    # `model_dict["lasts"]` access also validates the JSON has that key.
    width_mult_list = [
        4. / 12,
        6. / 12,
        8. / 12,
        10. / 12,
        1.,
    ]
    model = Network(Fch=args.Fch,
                    num_classes=num_classes,
                    stem_head_width=(args.stem_head_width,
                                     args.stem_head_width))

    last = model_dict["lasts"]

    if args.local_rank == 0:
        logging.info("net: " + str(model))
        # Report MACs/params for a full-resolution 3x1024x2048 input.
        with torch.cuda.device(0):
            macs, params = get_model_complexity_info(model, (3, 1024, 2048),
                                                     as_strings=True,
                                                     print_per_layer_stat=True,
                                                     verbose=True)
            logging.info('{:<30}  {:<8}'.format('Computational complexity: ',
                                                macs))
            logging.info('{:<30}  {:<8}'.format('Number of parameters: ',
                                                params))

        # Snapshot the fully-resolved config next to the logs.
        with open(os.path.join(args.save, 'args.yaml'), 'w') as f:
            f.write(args_text)

    init_weight(model,
                nn.init.kaiming_normal_,
                torch.nn.BatchNorm2d,
                args.bn_eps,
                args.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')

    if args.pretrain:
        model.backbone = load_pretrain(model.backbone, args.pretrain)
    model = model.cuda()

    # if args.sync_bn:
    #     if has_apex:
    #         model = apex.parallel.convert_syncbn_model(model)
    #     else:
    #         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Optimizer ###################################
    base_lr = args.base_lr

    if args.opt == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=base_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=base_lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08)
    elif args.opt == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=base_lr,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=args.weight_decay)
    else:
        # Anything else falls through to the timm optimizer factory.
        optimizer = create_optimizer(args, model)

    if args.sched == "raw":
        lr_scheduler = None
    else:
        max_iteration = len(train_loader) * args.epochs
        lr_scheduler = Iter_LR_Scheduler(args, max_iteration,
                                         len(train_loader))

    # Auto-resume: an existing last.pth.tar in the save dir takes priority.
    start_epoch = 0
    if os.path.exists(os.path.join(args.save, 'last.pth.tar')):
        args.resume = os.path.join(args.save, 'last.pth.tar')

    if args.resume:
        model_state_file = args.resume
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['start_epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logging.info('Loaded checkpoint (starting from iter {})'.format(
                checkpoint['start_epoch']))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=None)

    # Validation runs on the EMA weights when EMA is enabled; note this is
    # captured BEFORE the DDP wrap below, so eval_model stays unwrapped.
    if model_ema:
        eval_model = model_ema.ema
    else:
        eval_model = model

    if has_apex:
        model = DDP(model, delay_allreduce=True)
    else:
        model = DDP(model, device_ids=[args.local_rank])

    def _save_checkpoint(filename):
        """Serialize resume state to `filename` under args.save.

        Saves EMA weights when EMA is enabled, otherwise the DDP-unwrapped
        model; `epoch`/`optimizer` are read from the enclosing scope at
        call time. Extracted here because the original body repeated this
        dict four times (best/last x EMA/no-EMA).
        """
        if model_ema is not None:
            state_dict = model_ema.ema.state_dict()
        else:
            state_dict = model.module.state_dict()
        torch.save(
            {
                'start_epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': optimizer.state_dict(),
                # 'lr_scheduler': lr_scheduler.state_dict(),
            },
            os.path.join(args.save, filename))

    best_valid_iou = 0.
    best_epoch = 0

    logging.info("rank: {} world_size: {}".format(args.local_rank,
                                                  args.world_size))
    for epoch in range(start_epoch, args.epochs):
        # Re-seed the distributed samplers so shards reshuffle each epoch.
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        if args.local_rank == 0:
            logging.info(args.load_path)
            logging.info(args.save)
            logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        # NOTE(review): drop_prob is computed but its consumer is commented
        # out below — drop-path is effectively disabled; confirm intent.
        drop_prob = args.drop_path_prob * epoch / args.epochs
        # model.module.drop_path_prob(drop_prob)

        train_mIoU = train(train_loader, det2_dataset, model, model_ema,
                           ohem_criterion, num_classes, lr_scheduler,
                           optimizer, logger, epoch, args, cfg)

        torch.cuda.empty_cache()

        # Skip the (expensive) mIoU validation for the first third of training.
        if epoch > args.epochs // 3:
            # if epoch >= 10:
            temp_iou, avg_loss = validation(val_loader,
                                            eval_model,
                                            ohem_criterion,
                                            num_classes,
                                            args,
                                            cal_miou=True)
        else:
            temp_iou = 0.
            avg_loss = -1

        torch.cuda.empty_cache()
        if args.local_rank == 0:
            logging.info("Epoch: {} train miou: {:.2f}".format(
                epoch + 1, 100 * train_mIoU))
            if temp_iou > best_valid_iou:
                best_valid_iou = temp_iou
                best_epoch = epoch
                _save_checkpoint('best_checkpoint.pth.tar')

            logger.add_scalar("mIoU/val", temp_iou, epoch)
            logging.info("[Epoch %d/%d] valid mIoU %.4f eval loss %.4f" %
                         (epoch + 1, args.epochs, temp_iou, avg_loss))
            logging.info("Best valid mIoU %.4f Epoch %d" %
                         (best_valid_iou, best_epoch))

            _save_checkpoint('last.pth.tar')