Ejemplo n.º 1
0
def main():
    """Evaluate a pose network on the validation split.

    Parses the command line, merges it into the global ``cfg``, builds the
    model named by ``cfg.MODEL.NAME``, restores its weights and passes the
    model, criterion and validation loader to ``validate``.

    Weights are read from ``cfg.TEST.MODEL_FILE`` when it is set, otherwise
    from ``final_state.pth`` in the run's output directory. Checkpoints are
    deserialized onto the CPU first so that a file saved on a GPU index that
    does not exist on this machine still loads; the model is moved to the
    configured GPUs afterwards.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(review): eval() on a config-derived string executes arbitrary
    # code if the config is untrusted; kept to preserve the lookup semantics.
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        state = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')
        # A full training checkpoint carries the weights under
        # 'best_state_dict'; a bare state dict is used as-is.
        if 'best_state_dict' in state:
            state = state['best_state_dict']
        state = model_key_helper(state)
        model.load_state_dict(state)
    else:
        model_state_file = os.path.join(final_output_dir, 'final_state.pth')
        logger.info('=> loading model from {}'.format(model_state_file))
        state = torch.load(model_state_file, map_location='cpu')
        model.load_state_dict(model_key_helper(state))

    # define loss function (criterion)
    matcher = build_matcher(cfg.MODEL.NUM_JOINTS)
    weight_dict = {'loss_ce': 1, 'loss_kpts': cfg.MODEL.EXTRA.KPT_LOSS_COEF}
    criterion = SetCriterion(model.num_classes, matcher, weight_dict,
                             cfg.MODEL.EXTRA.EOS_COEF,
                             ['labels', 'kpts', 'cardinality']).cuda()

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True)

    # evaluate on validation set
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
Ejemplo n.º 2
0
def build(args):
    """Assemble the DETR model, its training criterion and its postprocessors.

    Returns a ``(model, criterion, postprocessors)`` tuple. ``postprocessors``
    always has a ``'bbox'`` entry; ``'segm'`` is added when ``args.masks`` is
    set, and ``'panoptic'`` additionally when the dataset is
    ``"coco_panoptic"``.
    """
    # Class count: 250 for panoptic, 91 for COCO, 20 for anything else.
    if args.dataset_file == "coco_panoptic":
        num_classes = 250
    elif args.dataset_file == 'coco':
        num_classes = 91
    else:
        num_classes = 20
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(args,
                 backbone,
                 transformer,
                 num_classes=num_classes,
                 num_queries=args.num_queries,
                 aux_loss=args.aux_loss)
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)

    # Per-loss coefficients for the final decoder layer.
    weight_dict = {
        'loss_ce': 1,
        'loss_bbox': args.bbox_loss_coef,
        'loss_giou': args.giou_loss_coef,
    }
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        # Repeat every coefficient once per auxiliary decoder layer, with the
        # layer index appended to the key.
        base_items = list(weight_dict.items())
        for layer in range(args.dec_layers - 1):
            for key, coef in base_items:
                weight_dict[key + f'_{layer}'] = coef

    losses = ['labels', 'boxes', 'cardinality']
    if args.masks:
        losses.append("masks")
    criterion = SetCriterion(num_classes,
                             matcher=matcher,
                             weight_dict=weight_dict,
                             eos_coef=args.eos_coef,
                             losses=losses)
    criterion.to(device)

    postprocessors = {'bbox': PostProcess()}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            # Ids 0..90 map to True, 91..200 to False.
            is_thing_map = dict((idx, idx <= 90) for idx in range(201))
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map,
                                                             threshold=0.85)

    return model, criterion, postprocessors
Ejemplo n.º 3
0
def build(args):
    """Assemble the DETR model, its training criterion and its postprocessors.

    The transformer implementation is chosen from the environment: the first
    of ``cross_transformer``, ``sparse_transformer`` and ``linear_transformer``
    that parses to a non-zero integer wins; otherwise the default transformer
    is built.

    Returns a ``(model, criterion, postprocessors)`` tuple.
    """

    def _env_flag(name):
        # Unset variables count as 0; a non-integer value raises ValueError.
        return int(os.environ.get(name, 0))

    # `num_classes` is really `max_obj_id + 1`, where max_obj_id is the
    # largest class id in the dataset. COCO's largest id is 90, hence 91; a
    # dataset with a single class of id 1 needs 2. See
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    if args.dataset_file == "coco_panoptic":
        # For panoptic any value large enough to hold max_obj_id + 1 works;
        # the exact number does not matter.
        num_classes = 250
    elif args.dataset_file == 'coco':
        num_classes = 91
    else:
        num_classes = 20
    device = torch.device(args.device)

    backbone = build_backbone(args)

    if _env_flag("cross_transformer"):
        transformer = build_cross_transformer(args)
    elif _env_flag("sparse_transformer"):
        transformer = build_sparse_transformer(args)
    elif _env_flag("linear_transformer"):
        transformer = build_linear_transformer(args)
    else:
        transformer = build_transformer(args)

    model = DETR(backbone,
                 transformer,
                 num_classes=num_classes,
                 num_queries=args.num_queries,
                 aux_loss=args.aux_loss)
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)

    # Per-loss coefficients for the final decoder layer.
    weight_dict = {
        'loss_ce': 1,
        'loss_bbox': args.bbox_loss_coef,
        'loss_giou': args.giou_loss_coef,
    }
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        # One suffixed copy of every coefficient per auxiliary decoder layer.
        suffixed = {
            key + f'_{layer}': coef
            for layer in range(args.dec_layers - 1)
            for key, coef in weight_dict.items()
        }
        weight_dict.update(suffixed)

    losses = ['labels', 'boxes', 'cardinality']
    if args.masks:
        losses.append("masks")
    criterion = SetCriterion(num_classes,
                             matcher=matcher,
                             weight_dict=weight_dict,
                             eos_coef=args.eos_coef,
                             losses=losses)
    criterion.to(device)

    postprocessors = {'bbox': PostProcess()}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            # Ids 0..90 map to True, 91..200 to False.
            is_thing_map = dict((idx, idx <= 90) for idx in range(201))
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map,
                                                             threshold=0.85)

    return model, criterion, postprocessors
Ejemplo n.º 4
0
Archivo: engine.py Proyecto: L4zyy/detr
def save_img_and_update_conf_acc_list(args,
                                      samples,
                                      outputs,
                                      targets,
                                      conf_acc_list,
                                      b_id,
                                      save_err_only=False):
    """Draw matched predictions and ground truth on a sample and save a PNG.

    The matcher built from ``args`` pairs predictions with the targets of the
    first sample. A pair counts as a *conflict* when the predicted class name
    differs from the ground-truth name and is not ``'None'``.

    * ``save_err_only`` true and at least one conflict: only conflicting
      pairs are drawn (prediction in red, ground truth in green) and the
      image goes to ``<output_dir>/err/<b_id>_<conflicts>.png``.
    * otherwise: every matched prediction whose class is not ``'None'`` is
      drawn in white, every matched ground-truth box in green, and the image
      goes to ``<output_dir>/<b_id>_<conflicts>_<num_matched>.png``.

    Args:
        args: options object; ``args.output_dir`` is read and the whole
            object is passed to ``build_matcher``.
        samples: batch container exposing ``.tensors``.
            NOTE(review): only index 0 of outputs/targets is used, so this
            presumably expects a batch of one image - confirm with callers.
        outputs: model output dict with ``'pred_logits'`` and ``'pred_boxes'``
            (an ``'aux_outputs'`` entry, if present, is ignored).
        targets: list of target dicts with ``'labels'`` and ``'boxes'``.
        conf_acc_list: accepted for interface compatibility; this function
            does not read or modify it.
        b_id: integer used as the zero-padded filename prefix.
        save_err_only: select the error-only rendering described above.
    """
    tens = samples.tensors.squeeze().cpu()
    img = sample2img(tens)
    w, h = img.size
    drw = ImageDraw.Draw(img)
    # Boxes are multiplied by the image size before drawing.
    scale = torch.Tensor([w, h, w, h])

    matcher = build_matcher(args)

    outputs_without_aux = {
        k: v
        for k, v in outputs.items() if k != 'aux_outputs'
    }

    # Retrieve the matching between the outputs of the last layer and the targets
    indices = matcher(outputs_without_aux, targets)
    pred_idx = indices[0][0]
    tgt_idx = indices[0][1]

    pred_labels = torch.Tensor(
        [outputs['pred_logits'][0][i].argmax() for i in pred_idx]).int()
    pred_confs = torch.Tensor([
        torch.nn.Softmax(dim=0)(outputs['pred_logits'][0][i]).max()
        for i in pred_idx
    ])
    labels = torch.Tensor([targets[0]['labels'][i] for i in tgt_idx]).int()
    pboxes = [outputs['pred_boxes'][0][i] for i in pred_idx]
    gboxes = [targets[0]['boxes'][i] for i in tgt_idx]

    pred_names = [CLASSES[plabel] for plabel in pred_labels]
    label_names = [CLASSES[label] for label in labels]

    # Count conflicts up front: the branch below depends on the total, so it
    # cannot be accumulated inside the branch it guards.
    conflict = sum(
        1 for pred_cls, label_cls in zip(pred_names, label_names)
        if pred_cls != label_cls and pred_cls != 'None')

    if save_err_only and conflict > 0:
        for pred_cls, conf, label_cls, pbox, gbox in zip(
                pred_names, pred_confs, label_names, pboxes, gboxes):
            if pred_cls == label_cls or pred_cls == 'None':
                continue

            # draw pred
            draw_box(drw, pbox.cpu() * scale, 'red',
                     '{}[{:.2f}]'.format(pred_cls, conf), 'red')
            # draw gt
            draw_box(drw, gbox.cpu() * scale, 'green',
                     '{}'.format(label_cls), 'green')

        fp = Path(args.output_dir, 'err',
                  '{:04d}_{}.png'.format(b_id, conflict))
    else:
        for pred_cls, conf, pbox in zip(pred_names, pred_confs, pboxes):
            if pred_cls == 'None':
                continue

            # draw pred
            draw_box(drw, pbox.cpu() * scale, 'white',
                     '{}[{:.2f}]'.format(pred_cls, conf), 'white')

        for label_cls, gbox in zip(label_names, gboxes):
            # draw gt
            draw_box(drw, gbox.cpu() * scale, 'green',
                     '{}'.format(label_cls), 'green')

        fp = Path(args.output_dir,
                  '{:04d}_{}_{}.png'.format(b_id, conflict, len(pboxes)))

    # The 'err' subdirectory is not created anywhere else in this function.
    fp.parent.mkdir(parents=True, exist_ok=True)
    img.save(fp)
Ejemplo n.º 5
0
def main():
    """Train a pose network, validating and checkpointing after every epoch.

    Steps, in order: parse the command line and merge it into the global
    ``cfg``; build the model named by ``cfg.MODEL.NAME`` and copy its source
    file into the output directory; build the matcher/criterion, datasets and
    loaders; optionally resume from ``checkpoint.pth`` in the output
    directory when ``cfg.AUTO_RESUME`` is set; then run the train/validate
    loop up to ``cfg.TRAIN.END_EPOCH``. The final weights of the unwrapped
    model are written to ``final_state.pth``.

    The checkpoint records both the last epoch's score (``'perf'``) and the
    best score seen so far (``'best_perf'``). On resume the best score is
    restored from ``'best_perf'``, falling back to ``'perf'`` for checkpoints
    written before that key existed, so a resumed run does not flag a worse
    model as the best one.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(review): eval() on a config-derived string executes arbitrary
    # code if the config is untrusted; kept to preserve the lookup semantics.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    matcher = build_matcher(cfg.MODEL.NUM_JOINTS)
    weight_dict = {'loss_ce': 1, 'loss_kpts': cfg.MODEL.EXTRA.KPT_LOSS_COEF}
    if cfg.MODEL.EXTRA.AUX_LOSS:
        # Repeat every coefficient once per auxiliary decoder layer, with the
        # layer index appended to the key.
        aux_weight_dict = {}
        for i in range(cfg.MODEL.EXTRA.DEC_LAYERS - 1):
            aux_weight_dict.update(
                {k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
    criterion = SetCriterion(model.num_classes, matcher, weight_dict, cfg.MODEL.EXTRA.EOS_COEF, [
        'labels', 'kpts', 'cardinality']).cuda()

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth'
    )

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        begin_epoch = checkpoint['epoch']
        # 'perf' is only the last epoch's score; prefer the stored best and
        # fall back to 'perf' for checkpoints that predate 'best_perf'.
        best_perf = checkpoint.get('best_perf', checkpoint['perf'])
        last_epoch = checkpoint['epoch']
        model.load_state_dict(model_key_helper(checkpoint['state_dict']))

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

        if 'train_global_steps' in checkpoint:
            writer_dict['train_global_steps'] = checkpoint['train_global_steps']
            writer_dict['valid_global_steps'] = checkpoint['valid_global_steps']

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch
    )

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(
            cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict)

        # Ties count as a new best.
        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        ckpt = {
            # Stored as the next epoch to run, so resume starts after this one.
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'best_perf': best_perf,
            'optimizer': optimizer.state_dict(),
            'train_global_steps': writer_dict['train_global_steps'],
            'valid_global_steps': writer_dict['valid_global_steps'],
        }

        # Periodic, epoch-numbered snapshot in addition to the rolling one.
        if epoch % cfg.SAVE_FREQ == 0:
            save_checkpoint(ckpt, best_model, final_output_dir,
                            filename=f'checkpoint_{epoch}.pth')

        save_checkpoint(ckpt, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> saving final model state to {}'.format(
        final_model_state_file)
    )
    # Save the unwrapped module's weights (without the DataParallel wrapper).
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()