# Example #1
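# NOTE (assumed context): both functions below rely on module-level objects set
# up by the surrounding script: `config`, `args`, `logger`, `writer`,
# `is_master`, `final_output_path`, the dataset/model classes resolved via
# eval(), and project helpers such as `im_detect`, `im_post`, `Timer`,
# `DataParallel`, `SGD`, `AvgMetric`, `Speedometer`, `adjust_learning_rate`,
# `allreduce_async` and `hvd` (Horovod). Standard imports assumed: os, sys,
# pickle, pprint, logging, numpy as np, cv2, torch.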
def upsnet_test():

    pprint.pprint(config)
    logger.info('test config:{}\n'.format(pprint.pformat(config)))

    # create models
    gpus = [int(_) for _ in config.gpus.split(',')]
    test_model = eval(config.symbol)().cuda(device=gpus[0])

    # create data loader
    test_dataset = eval(config.dataset.dataset)(
        image_sets=config.dataset.test_image_set.split('+'),
        flip=False,
        result_path=final_output_path,
        phase='test')
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.test.batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False,
        pin_memory=False,
        collate_fn=test_dataset.collate)

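    # With args.eval_only, cached predictions from a previous run are loaded
    # and only the visualization / evaluation steps below are executed.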
    if args.eval_only:
        results = pickle.load(
            open(
                os.path.join(final_output_path, 'results', 'results_list.pkl'),
                'rb'))
        if config.test.vis_mask:
            test_dataset.vis_all_mask(
                results['all_boxes'], results['all_masks'],
                os.path.join(final_output_path, 'results', 'vis'))
        if config.network.has_rcnn:
            test_dataset.evaluate_boxes(
                results['all_boxes'], os.path.join(final_output_path,
                                                   'results'))
        if config.network.has_mask_head:
            test_dataset.evaluate_masks(
                results['all_boxes'], results['all_masks'],
                os.path.join(final_output_path, 'results'))
        if config.network.has_fcn_head:
            test_dataset.evaluate_ssegs(
                results['all_ssegs'],
                os.path.join(final_output_path, 'results', 'ssegs'))
            # logging.info('combined pano result:')
            # test_dataset.evaluate_panoptic(test_dataset.get_combined_pan_result(results['all_ssegs'], results['all_boxes'], results['all_masks'], stuff_area_limit=config.test.panoptic_stuff_area_limit), os.path.join(final_output_path, 'results', 'pans_combined'))
        if config.network.has_panoptic_head:
            logging.info('unified pano result:')
            test_dataset.evaluate_panoptic(
                test_dataset.get_unified_pan_result(
                    results['all_ssegs'],
                    results['all_panos'],
                    results['all_pano_cls_inds'],
                    stuff_area_limit=config.test.panoptic_stuff_area_limit),
                os.path.join(final_output_path, 'results', 'pans_unified'))
        sys.exit()

    # preparing
    curr_iter = config.test.test_iteration
    if args.weight_path == '':
        test_model.load_state_dict(
            torch.load(
                os.path.join(config.output_path,
                             os.path.basename(args.cfg).split('.')[0],
                             '_'.join(config.dataset.image_set.split('+')),
                             config.model_prefix + str(curr_iter) + '.pth')),
            resume=True)
    else:
        test_model.load_state_dict(torch.load(args.weight_path), resume=True)

    for p in test_model.parameters():
        p.requires_grad = False

    test_model = DataParallel(test_model, device_ids=gpus,
                              gather_output=False).to(gpus[0])
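    # NOTE: this DataParallel accepts gather_output, and load_state_dict above
    # takes resume=True; both appear to be project-specific extensions rather
    # than the stock torch.nn.DataParallel / nn.Module APIs.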

    # start testing
    test_model.eval()

    i_iter = 0
    idx = 0
    test_iter = iter(test_loader)
    all_boxes = [[] for _ in range(test_dataset.num_classes)]
    if config.network.has_mask_head:
        all_masks = [[] for _ in range(test_dataset.num_classes)]
    if config.network.has_fcn_head:
        all_ssegs = []
    if config.network.has_panoptic_head:
        all_panos = []
        all_pano_cls_inds = []
        panos = []

    data_timer = Timer()
    net_timer = Timer()
    post_timer = Timer()

    # Manually build one (data, label) pair per GPU for each step; the last
    # batch is padded by repeating the previous sample (see StopIteration below).
    while i_iter < len(test_loader):
        data_timer.tic()
        batch = []
        labels = []
        for gpu_id in gpus:
            try:
                data, label, _ = next(test_iter)
                if label is not None:
                    data['roidb'] = label['roidb']
                for k, v in data.items():
                    data[k] = v.pin_memory().to(
                        gpu_id, non_blocking=True) if torch.is_tensor(v) else v
            except StopIteration:
                # dataset exhausted: repeat the previous sample so every GPU
                # still gets an input (the extra outputs are trimmed later)
                data = data.copy()
                for k, v in data.items():
                    data[k] = v.pin_memory().to(
                        gpu_id, non_blocking=True) if torch.is_tensor(v) else v
            batch.append((data, None))
            labels.append(label)
            i_iter += 1

        im_infos = [_[0]['im_info'][0] for _ in batch]

        data_time = data_timer.toc()
        if i_iter > 10:
            net_timer.tic()
        with torch.no_grad():
            output = test_model(*batch)
            torch.cuda.synchronize()
            if i_iter > 10:
                net_time = net_timer.toc()
            else:
                net_time = 0
            output = im_detect(output, batch, im_infos)
        post_timer.tic()
        for score, box, mask, cls_idx, im_info in zip(output['scores'],
                                                      output['boxes'],
                                                      output['masks'],
                                                      output['cls_inds'],
                                                      im_infos):
            im_post(all_boxes, all_masks, score, box, mask, cls_idx,
                    test_dataset.num_classes,
                    np.round(im_info[:2] / im_info[2]).astype(np.int32))
            idx += 1
        if config.network.has_fcn_head:
            for i, sseg in enumerate(output['ssegs']):
                sseg = sseg.squeeze(0).astype(
                    'uint8')[:int(im_infos[i][0]), :int(im_infos[i][1])]
                all_ssegs.append(
                    cv2.resize(sseg,
                               None,
                               None,
                               fx=1 / im_infos[i][2],
                               fy=1 / im_infos[i][2],
                               interpolation=cv2.INTER_NEAREST))
        if config.network.has_panoptic_head:
            pano_cls_inds = []
            for i, (pano, cls_ind) in enumerate(
                    zip(output['panos'], output['pano_cls_inds'])):
                pano = pano.squeeze(0).astype(
                    'uint8')[:int(im_infos[i][0]), :int(im_infos[i][1])]
                panos.append(
                    cv2.resize(pano,
                               None,
                               None,
                               fx=1 / im_infos[i][2],
                               fy=1 / im_infos[i][2],
                               interpolation=cv2.INTER_NEAREST))
                pano_cls_inds.append(cls_ind)

            all_panos.extend(panos)
            panos = []
            all_pano_cls_inds.extend(pano_cls_inds)
        post_time = post_timer.toc()
        s = 'Batch %d/%d, data_time:%.3f, net_time:%.3f, post_time:%.3f' % (
            idx, len(test_dataset), data_time, net_time, post_time)
        logging.info(s)

    results = []

    # trim the predictions added by last-batch padding: keep one entry per image
    for i in range(1, test_dataset.num_classes):
        all_boxes[i] = all_boxes[i][:len(test_dataset)]
        if config.network.has_mask_head:
            all_masks[i] = all_masks[i][:len(test_dataset)]
    if config.network.has_fcn_head:
        all_ssegs = all_ssegs[:len(test_dataset)]
    if config.network.has_panoptic_head:
        all_panos = all_panos[:len(test_dataset)]

    os.makedirs(os.path.join(final_output_path, 'results'), exist_ok=True)

    results = {
        'all_boxes': all_boxes,
        'all_masks': all_masks if config.network.has_mask_head else None,
        'all_ssegs': all_ssegs if config.network.has_fcn_head else None,
        'all_panos': all_panos if config.network.has_panoptic_head else None,
        'all_pano_cls_inds': all_pano_cls_inds if config.network.has_panoptic_head else None,
    }
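    # cache the raw predictions so a later run with args.eval_only can skip inference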

    with open(os.path.join(final_output_path, 'results', 'results_list.pkl'),
              'wb') as f:
        pickle.dump(results, f, protocol=2)

    if config.test.vis_mask:
        test_dataset.vis_all_mask(
            all_boxes, all_masks,
            os.path.join(final_output_path, 'results', 'vis'))
    else:
        test_dataset.evaluate_boxes(all_boxes,
                                    os.path.join(final_output_path, 'results'))
        if config.network.has_mask_head:
            test_dataset.evaluate_masks(
                all_boxes, all_masks, os.path.join(final_output_path,
                                                   'results'))
        if config.network.has_panoptic_head:
            logging.info('unified pano result:')
            test_dataset.evaluate_panoptic(
                test_dataset.get_unified_pan_result(
                    all_ssegs,
                    all_panos,
                    all_pano_cls_inds,
                    stuff_area_limit=config.test.panoptic_stuff_area_limit),
                os.path.join(final_output_path, 'results', 'pans_unified'))
        if config.network.has_fcn_head:
            test_dataset.evaluate_ssegs(
                all_ssegs, os.path.join(final_output_path, 'results', 'ssegs'))
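
# A minimal driver sketch (assumption, not part of the original file): the
# surrounding script typically parses args/config first and then dispatches to
# one of the two entry points, e.g.
#
#     if __name__ == '__main__':
#         if getattr(args, 'test', False):   # hypothetical flag
#             upsnet_test()
#         else:
#             upsnet_train()
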
def upsnet_train():

    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create models
    train_model = eval(config.symbol)().cuda()
        
    # create optimizer
    params_lr = train_model.get_params_lr()
    # a custom SGD is used with lr=1 so each parameter group can carry its own
    # learning-rate multiplier; the actual lr is passed to optimizer.step(lr)
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loader
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'), flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'), flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, sampler=train_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, sampler=val_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)

    # preparing
    curr_iter = config.train.begin_iteration
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss'),])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss'),])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss'), ])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss'), ])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss'), ])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss'), ])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))

        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()
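    # kept in eval() mode by default, presumably to keep BatchNorm running
    # statistics frozen; the use_syncbn branch below switches to train() explicitly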

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
            train_sampler.set_epoch(curr_iter)

            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()


            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = []
                losses.append(allreduce_async(loss, name='train_total_loss'))
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))

                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1


                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)
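                    # presumably shrinks the momentum buffers in step with the
                    # 10x lr decay applied via adjust_learning_rate / optimizer.step(lr)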

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)

                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))
        else:
            inner_iter = 0
            train_iterator = iter(train_loader)
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = next(train_iterator)
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1
                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(*batch)

                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)
                
                losses = []
                losses.append(loss.item())
                for l in metrics_name:
                    losses.append(output[l].mean().item())

                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1

                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)


                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))

            # exhaust whatever is left in the iterator before starting the next epoch
            while True:
                try:
                    next(train_iterator)
                except StopIteration:
                    break

        for metric in metrics:
            metric.reset()

        if config.train.eval_data:
            train_model.eval()

            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)

                    with torch.no_grad():
                        output = train_model(data, label)

                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)

            else:
                inner_iter = 0
                val_iterator = iter(val_loader)
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = next(val_iterator)
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1

                    with torch.no_grad():
                        output = train_model(*batch)

                    losses = []
                    for l in metrics_name:
                        losses.append(output[l].mean().item())

                    for metric, loss in zip(metrics, losses):
                        if is_master:
                            metric.update(_, _, loss)

                # exhaust whatever is left in the iterator before continuing
                while True:
                    try:
                        next(val_iterator)
                    except StopIteration:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))

            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                    metric.reset()
            if is_master:
                logger.info(s)

    if is_master and config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
    elif not config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
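

# For reference, a rough sketch of the lr-schedule contract assumed above
# (the real adjust_learning_rate lives elsewhere in the project and may differ):
#
#     def adjust_learning_rate(optimizer, curr_iter, config):
#         """Return the base lr for this iteration: linear warmup followed by
#         step decay at config.train.decay_iteration; the custom SGD then
#         applies it via optimizer.step(lr)."""
#         lr = config.train.lr                                # assumed config field
#         if curr_iter < config.train.warmup_iteration:       # assumed config field
#             lr *= (curr_iter + 1) / config.train.warmup_iteration
#         for decay_iter in config.train.decay_iteration:
#             if curr_iter >= decay_iter:
#                 lr *= 0.1
#         return lr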