Code example #1
def train(cfg):
    # Set up environment.
    init_distributed_training(cfg)
    local_rank_id = get_local_rank()

    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED + 10 * local_rank_id)
    torch.manual_seed(cfg.RNG_SEED + 10 * local_rank_id)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)
    logger.info('init start')
    # Epoch counting starts from 1
    arguments = {"cur_epoch": 1}

    device = get_device(local_rank_id)
    model = build_recognizer(cfg, device)
    criterion = build_criterion(cfg, device)
    optimizer = build_optimizer(cfg, model)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = CheckPointer(model,
                                optimizer=optimizer,
                                scheduler=lr_scheduler,
                                save_dir=cfg.OUTPUT_DIR,
                                save_to_disk=True)
    if cfg.TRAIN.RESUME:
        logger.info('resume start')
        extra_checkpoint_data = checkpointer.load(map_location=device)
        if isinstance(extra_checkpoint_data, dict):
            arguments['cur_epoch'] = extra_checkpoint_data['cur_epoch']
            if cfg.LR_SCHEDULER.IS_WARMUP:
                logger.info('warmup start')
                if lr_scheduler.finished:
                    optimizer.load_state_dict(
                        lr_scheduler.after_scheduler.optimizer.state_dict())
                else:
                    optimizer.load_state_dict(
                        lr_scheduler.optimizer.state_dict())
                lr_scheduler.optimizer = optimizer
                lr_scheduler.after_scheduler.optimizer = optimizer
                logger.info('warmup end')
        logger.info('resume end')

    data_loader = build_dataloader(cfg, is_train=True)

    logger.info('init end')
    synchronize()
    do_train(cfg, arguments, data_loader, model, criterion, optimizer,
             lr_scheduler, checkpointer, device)
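
The CheckPointer used above is a project helper; as a point of reference, here is a minimal sketch of the same save/resume pattern using plain torch.save/torch.load (every name in this sketch is illustrative, not the project's API):

import torch

def save_checkpoint(path, model, optimizer, scheduler, cur_epoch):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'cur_epoch': cur_epoch,  # epoch to resume from
    }, path)

def resume_checkpoint(path, model, optimizer, scheduler, device):
    data = torch.load(path, map_location=device)
    model.load_state_dict(data['model'])
    optimizer.load_state_dict(data['optimizer'])
    scheduler.load_state_dict(data['scheduler'])
    return data.get('cur_epoch', 1)  # default matches arguments = {"cur_epoch": 1}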
Code example #2
File: fusion_test.py Project: ZJCV/TSM
def test(args):
    torch.backends.cudnn.benchmark = True
    logger = logging.setup_logging()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    map_location = {'cuda:0': 'cuda:0'}

    # Compute RGB scores
    rgb_cfg = get_cfg_defaults()
    rgb_cfg.merge_from_file(args.rgb_config_file)
    rgb_cfg.DATALOADER.TEST_BATCH_SIZE = 16
    rgb_cfg.OUTPUT.DIR = args.output
    rgb_cfg.freeze()

    rgb_model = build_model(rgb_cfg, 0)
    rgb_model.eval()
    checkpointer = CheckPointer(rgb_model, logger=logger)
    checkpointer.load(args.rgb_pretrained, map_location=map_location)

    # inference(rgb_cfg, rgb_model, device)

    # Compute RGBDiff scores
    rgbdiff_cfg = get_cfg_defaults()
    rgbdiff_cfg.merge_from_file(args.rgbdiff_config_file)
    rgbdiff_cfg.DATALOADER.TEST_BATCH_SIZE = 16
    rgbdiff_cfg.OUTPUT.DIR = args.output
    rgbdiff_cfg.freeze()

    rgbdiff_model = build_model(rgbdiff_cfg, 0)
    rgbdiff_model.eval()
    checkpointer = CheckPointer(rgbdiff_model, logger=logger)
    checkpointer.load(args.rgbdiff_pretrained, map_location=map_location)

    inference(rgb_cfg, rgb_model, rgbdiff_cfg, rgbdiff_model, device)
Code example #3
def main():
    args = parse_test_args()
    cfg = load_test_config(args)

    logging.setup_logging(cfg.OUTPUT_DIR)
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Loaded configuration file {}".format(args.config_file))
    if args.config_file:
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    launch_job(cfg=cfg, init_method=args.init_method, func=test)
Code example #4
File: util.py Project: ZJCV/TSM
def create_text_labels(classes, scores, class_names, ground_truth=False):
    """
    Create text labels.
    Args:
        classes (list[int]): a list of class ids for each example.
        scores (list[float] or None): list of scores for each example.
        class_names (list[str]): a list of class names, ordered by their ids.
        ground_truth (bool): whether the labels are ground truth.
    Returns:
        labels (list[str]): formatted text labels.
    """
    try:
        labels = [class_names[i] for i in classes]
    except IndexError:
        logger = logging.setup_logging(__name__)
        logger.error("Class indices get out of range: {}".format(classes))
        return None

    if ground_truth:
        labels = ["[{}] {}".format("GT", label) for label in labels]
    elif scores is not None:
        assert len(classes) == len(scores)
        labels = [
            "[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)
        ]
    return labels
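
A minimal usage sketch of create_text_labels (the class names here are made up):

class_names = ['archery', 'bowling', 'juggling']
labels = create_text_labels([2, 0], [0.91, 0.64], class_names)
# labels == ['[0.91] juggling', '[0.64] archery']
gt_labels = create_text_labels([1], None, class_names, ground_truth=True)
# gt_labels == ['[GT] bowling']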
Code example #5
File: demo_net.py Project: ZJCV/X3D
def run_demo(cfg, frame_provider):
    """
    Run demo visualization.
    Args:
        cfg (CfgNode): configs. Details can be found in
            tsn/config/defaults.py
        frame_provider (iterator): Python iterator that returns task objects filled
            with the necessary information, such as `frames`, `id` and `num_buffer_frames`,
            for the prediction and visualization pipeline.
    """
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    # Setup logging format.
    logger = logging.setup_logging(__name__, cfg.OUTPUT_DIR)
    # Print config.
    logger.info("Run demo with config:")
    logger.info(cfg)

    async_vis = AsyncVisualizer(cfg,
                                n_workers=cfg.VISUALIZATION.NUM_VIS_INSTANCES)

    if cfg.NUM_GPUS <= 1:
        model = ActionPredictor(cfg=cfg, async_vis=async_vis)
    else:
        model = AsyncActionPredictor(cfg=cfg, async_vis=async_vis)

    start = time.time()
    num_task = 0
    # Start reading frames.
    frame_provider.start()
    for able_to_read, task in frame_provider:
        if not able_to_read:
            break
        if task is None:
            time.sleep(0.02)
            continue
        num_task += 1

        model.put(task)
        try:
            task = model.get()
            num_task -= 1
            yield task
        except IndexError:
            continue

    while num_task != 0:
        try:
            task = model.get()
            num_task -= 1
            yield task
        except IndexError:
            continue
    logger.info("Finish video in: {}".format(time.time() - start))
Code example #6
def test(cfg):
    # Set up environment.
    init_distributed_training(cfg)
    local_rank_id = get_local_rank()

    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED + 10 * local_rank_id)
    torch.manual_seed(cfg.RNG_SEED + 10 * local_rank_id)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    logging.setup_logging(cfg.OUTPUT_DIR)

    device = get_device(local_rank=local_rank_id)
    model = build_recognizer(cfg, device=device)

    synchronize()
    do_evaluation(cfg, model, device)
Code example #7
def simple_group_split(world_size, rank, num_groups):
    """Split all processes into num_groups process groups and return this rank's group.

    world_size: total number of processes
    rank: current process rank
    num_groups: total number of groups, e.g. world_size=8 with 4 GPUs per
        syncBN group gives num_groups=2
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(dist.new_group(rank_list[i]))
    group_size = world_size // num_groups

    logger = logging.setup_logging(__name__)
    logger.info(
        "Rank no.{} start sync BN on the process group of {}".format(rank, rank_list[rank // group_size]))
    return groups[rank // group_size]
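
The grouping logic can be checked without initializing any process group. With world_size=8 and num_groups=2, as in the docstring above, the ranks split into [0..3] and [4..7], and rank 5 lands in the second group:

import numpy as np

world_size, num_groups, rank = 8, 2, 5
rank_list = [list(map(int, x))
             for x in np.split(np.arange(world_size), num_groups)]
group_size = world_size // num_groups
print(rank_list)                      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(rank_list[rank // group_size])  # [4, 5, 6, 7]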
Code example #8
File: fusion_test.py Project: ZJCV/TSM
def inference(rgb_cfg, rgb_model, rgbdiff_cfg, rgbdiff_model, device):
    dataset_name = rgb_cfg.DATASETS.TEST.NAME
    output_dir = rgb_cfg.OUTPUT.DIR

    rgb_data_loader = build_dataloader(rgb_cfg, is_train=False)
    rgbdiff_data_loader = build_dataloader(rgbdiff_cfg, is_train=False)
    dataset = rgb_data_loader.dataset

    logger = logging.setup_logging()
    logger.info("Evaluating {} dataset({} video clips):".format(
        dataset_name, len(dataset)))

    results_dict, cate_acc_dict, acc_top1, acc_top5 = \
        compute_on_dataset(rgb_model, rgb_data_loader, rgbdiff_model, rgbdiff_data_loader, device)

    top1_acc = np.mean(acc_top1)
    top5_acc = np.mean(acc_top5)
    result_str = '\ntotal - top_1 acc: {:.3f}, top_5 acc: {:.3f}\n'.format(
        top1_acc, top5_acc)

    classes = dataset.classes
    for key in sorted(results_dict.keys(), key=lambda x: int(x)):
        total_num = results_dict[key]
        acc_num = cate_acc_dict[key]

        cate_name = classes[int(key)]

        if total_num != 0:
            result_str += '{:<3} - {:<20} - acc: {:.2f}\n'.format(
                key, cate_name, acc_num / total_num * 100)
        else:
            # No samples in this category; avoid dividing by zero.
            result_str += '{:<3} - {:<20} - acc: 0.0\n'.format(key, cate_name)
    logger.info(result_str)

    result_path = os.path.join(
        output_dir,
        'result_{}.txt'.format(datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
    with open(result_path, "w") as f:
        f.write(result_str)

    for handler in list(logger.handlers):  # copy: don't mutate while iterating
        logger.removeHandler(handler)

    return {'top1': top1_acc, 'top5': top5_acc}
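
compute_on_dataset is not shown in this example; a minimal sketch of how per-batch top-1/top-5 accuracies are typically derived from fused scores (the shapes and helper name are assumptions, not the project's API):

import torch

def topk_accuracies(outputs, targets, ks=(1, 5)):
    # outputs: (N, num_classes) fused scores; targets: (N,) class ids
    maxk = max(ks)
    _, pred = outputs.topk(maxk, dim=1)     # (N, maxk) predicted class ids
    correct = pred.eq(targets.view(-1, 1))  # (N, maxk) hit mask
    return [correct[:, :k].any(dim=1).float().mean().item() for k in ks]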
Code example #9
    def _update_video(self, annotation_dir, is_train=True):
        dataset_type = 'rawframes' if self.type == 'RawFrame' else 'videos'
        if is_train:
            annotation_path = os.path.join(
                annotation_dir,
                f'ucf101_train_split_{self.split}_{dataset_type}.txt')
        else:
            annotation_path = os.path.join(
                annotation_dir,
                f'ucf101_val_split_{self.split}_{dataset_type}.txt')

        if not os.path.isfile(annotation_path):
            raise ValueError(f'{annotation_path} is not a file path')

        if self.type == 'RawFrame':
            self.video_list = [
                VideoRecord(x.strip().split(' '))
                for x in open(annotation_path)
            ]
        elif self.type == 'Video':
            video_list = list()
            for x in open(annotation_path):
                video_path, cate = x.strip().split(' ')
                video_path = os.path.join(self.data_dir, video_path)

                # Try to decode and sample a clip from a video.
                video_container = None
                try:
                    video_container = container.get_video_container(
                        video_path,
                        self.enable_multithread_decode,
                        self.decoding_backend,
                    )
                except Exception as e:
                    logger = logging.setup_logging(__name__)
                    logger.info(
                        "Failed to load video from {} with error {}".format(
                            video_path, e))
                    continue  # skip unreadable videos instead of passing a None container on

                frames_length = decoder.get_video_length(video_container)
                video_list.append(
                    VideoRecord([video_path, frames_length, cate]))
            self.video_list = video_list
        else:
            raise ValueError(f'unsupported dataset type: {self.type}')
Code example #10
File: fusion_test.py Project: ZJCV/TSM
def main():
    parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    parser.add_argument("rgb_config_file",
                        default="",
                        metavar="RGB_CONFIG_FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('rgb_pretrained',
                        default="",
                        metavar='RGB_PRETRAINED_FILE',
                        help="path to pretrained model",
                        type=str)
    parser.add_argument("rgbdiff_config_file",
                        default="",
                        metavar="RGBDIFF_CONFIG_FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('rgbdiff_pretrained',
                        default="",
                        metavar='RGBDIFF_PRETRAINED_FILE',
                        help="path to pretrained model",
                        type=str)
    parser.add_argument('--output', default="./outputs/test", type=str)
    args = parser.parse_args()

    if not os.path.isfile(args.rgb_config_file) or not os.path.isfile(
            args.rgb_pretrained):
        raise ValueError(
            'Both the RGB config file and pretrained model path are required')
    if not os.path.isfile(args.rgbdiff_config_file) or not os.path.isfile(
            args.rgbdiff_pretrained):
        raise ValueError(
            'Both the RGBDiff config file and pretrained model path are required')

    if not os.path.exists(args.output):
        os.makedirs(args.output)
    logger = logging.setup_logging(output_dir=args.output)
    logger.info(args)
    logger.info("Environment info:\n" + collect_env_info())

    test(args)
Code example #11
def inference(cfg, model, device, **kwargs):
    iteration = kwargs.get('iteration', None)
    dataset_name = cfg.DATASETS.TEST.NAME
    num_gpus = cfg.NUM_GPUS

    data_loader = build_dataloader(cfg, is_train=False)
    dataset = data_loader.dataset
    evaluator = data_loader.dataset.evaluator
    evaluator.clean()

    logger = logging.setup_logging(__name__)
    logger.info("Evaluating {} dataset({} video clips):".format(
        dataset_name, len(dataset)))

    if is_master_proc():
        for images, targets in tqdm(data_loader):
            compute_on_dataset(images, targets, device, model, num_gpus,
                               evaluator)
    else:
        for images, targets in data_loader:
            compute_on_dataset(images, targets, device, model, num_gpus,
                               evaluator)

    result_str, acc_dict = evaluator.get()
    logger.info(result_str)

    if is_master_proc():
        output_dir = cfg.OUTPUT_DIR
        if iteration is None:
            result_path = os.path.join(
                output_dir, 'result_{}.txt'.format(
                    datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
        else:
            result_path = os.path.join(
                output_dir, 'result_{:07d}.txt'.format(iteration))

        with open(result_path, "w") as f:
            f.write(result_str)

    return acc_dict
Code example #12
def build_recognizer(cfg, device):
    world_size = du.get_world_size()

    model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.NAME](cfg).to(
        device=device)

    logger = logging.setup_logging(__name__)
    if cfg.MODEL.SYNC_BN and world_size > 1:
        logger.info("start sync BN on the process group of {}".format(
            du._LOCAL_RANK_GROUP))
        convert_sync_bn(model, du._LOCAL_PROCESS_GROUP, device)
    if cfg.MODEL.PRETRAINED != "":
        logger.info(f'load pretrained: {cfg.MODEL.PRETRAINED}')
        checkpointer = CheckPointer(model, logger=logger)
        checkpointer.load(cfg.MODEL.PRETRAINED, map_location=device)
        logger.info("finish loading model weights")

    if world_size > 1:
        model = DDP(model,
                    device_ids=[device],
                    output_device=device,
                    find_unused_parameters=True)

    return model
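
convert_sync_bn and the du process-group globals are project helpers. For reference, the stock PyTorch equivalent of that SyncBN conversion step is nn.SyncBatchNorm.convert_sync_batchnorm; a minimal sketch:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
# Replace every BatchNorm*d layer with SyncBatchNorm; process_group=None means
# the default (world) group, so stats sync across all ranks once DDP is running.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)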
Code example #13
File: video_visualizer.py Project: ZJCV/TSM
    def draw_one_frame(self, frame, preds, text_alpha=0.7):
        """
            Draw labels for one image. By default, predicted labels are drawn in
            the top left corner of the image
            Args:
                frame (array-like): a tensor or numpy array of shape (H, W, C), where H and W correspond to
                    the height and width of the image respectively. C is the number of
                    color channels. The image is required to be in RGB format since that
                    is a requirement of the Matplotlib library. The image is also expected
                    to be in the range [0, 255].
                preds (tensor or list): If ground_truth is False, provide a float tensor of shape (num_boxes, num_classes)
                    that contains all of the confidence scores of the model.
                    For recognition task, input shape can be (num_classes,). To plot true label (ground_truth is True),
                    preds is a list contains int32 of the shape (num_boxes, true_class_ids) or (true_class_ids,).
                text_alpha (Optional[float]): transparency level of the box wrapped around text labels.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            n_instances = len(preds)
        else:
            logger = logging.setup_logging(__name__)
            log.getLogger("matplotlib").setLevel(log.ERROR)
            logger.error("Unsupported type of prediction input.")
            return

        if self.mode == "top-k":
            top_scores, top_classes = torch.topk(preds, k=self.top_k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                # top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_class = torch.where(mask)[0].tolist()
                top_classes.append(top_class)

        # Create labels for the top-k predicted classes with their scores.
        text_labels = []
        for i in range(n_instances):
            text_labels.append(
                create_text_labels(top_classes[i], top_scores[i],
                                   self.class_names))
        frame_visualizer = ImgVisualizer(frame, meta=None)
        font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 35, 5),
                        9)
        top_corner = False

        text = text_labels[0]
        pred_class = top_classes[0]
        colors = [self._get_color(pred) for pred in pred_class]
        frame_visualizer.draw_multiple_text(
            text,
            torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
            top_corner=top_corner,
            font_size=font_size,
            box_facecolors=colors,
            alpha=text_alpha,
        )

        return frame_visualizer.output.get_image()
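
The two selection modes above reduce to a few tensor operations; a standalone sketch with made-up scores (exact binary fractions so the printed values are stable):

import torch

preds = torch.tensor([[0.0625, 0.75, 0.125, 0.5]])

# "top-k" mode: the highest-scoring classes
scores, classes = torch.topk(preds, k=2)
print(scores.tolist(), classes.tolist())  # [[0.75, 0.5]] [[1, 3]]

# "thres" mode: every class whose score clears the threshold
mask = preds[0] >= 0.1
print(preds[0][mask].tolist())            # [0.75, 0.125, 0.5]
print(torch.where(mask)[0].tolist())      # [1, 2, 3]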
Code example #14
File: trainer.py Project: ZJCV/TSM
def do_train(cfg, arguments, data_loader, model, criterion, optimizer,
             lr_scheduler, checkpointer, device):
    logger = logging.setup_logging(__name__)
    meters = MetricLogger()
    summary_writer = None

    use_tensorboard = cfg.TRAIN.USE_TENSORBOARD
    log_step = cfg.TRAIN.LOG_STEP
    save_step = cfg.TRAIN.SAVE_STEP
    eval_step = cfg.TRAIN.EVAL_STEP
    max_iter = cfg.TRAIN.MAX_ITER
    start_iter = arguments['iteration']

    if is_master_proc() and use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        summary_writer = SummaryWriter(
            log_dir=os.path.join(cfg.OUTPUT_DIR, 'tf_logs'))
    evaluator = data_loader.dataset.evaluator

    synchronize()
    start_training_time = time.time()
    end = time.time()
    logger.info("Start training ...")
    model.train()
    for iteration, (images, targets) in enumerate(data_loader, start_iter):
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device=device, non_blocking=True)
        targets = targets.to(device=device, non_blocking=True)

        output_dict = model(images)
        loss_dict = criterion(output_dict, targets)
        loss = loss_dict['loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        acc_list = evaluator.evaluate_train(output_dict, targets)
        update_meters(cfg.NUM_GPUS, meters, loss_dict, acc_list)

        if iteration % len(data_loader) == 0 and hasattr(
                data_loader.batch_sampler, "set_epoch"):
            data_loader.batch_sampler.set_epoch(iteration)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time)
        if iteration % log_step == 0:
            eta_seconds = meters.time.global_avg * (max_iter - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                meters.delimiter.join([
                    "iter: {iter:06d}",
                    "lr: {lr:.5f}",
                    '{meters}',
                    "eta: {eta}",
                    'mem: {mem}M',
                ]).format(
                    iter=iteration,
                    lr=optimizer.param_groups[0]['lr'],
                    meters=str(meters),
                    eta=eta_string,
                    mem=round(torch.cuda.max_memory_allocated() / 1024.0 /
                              1024.0),
                ))
        if is_master_proc():
            if summary_writer:
                global_step = iteration
                for name, meter in meters.meters.items():
                    summary_writer.add_scalar('{}/avg'.format(name),
                                              float(meter.avg),
                                              global_step=global_step)
                    summary_writer.add_scalar('{}/global_avg'.format(name),
                                              meter.global_avg,
                                              global_step=global_step)
                summary_writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          global_step=global_step)

            if save_step > 0 and iteration % save_step == 0:
                checkpointer.save("model_{:06d}".format(iteration),
                                  **arguments)
        if eval_step > 0 and iteration % eval_step == 0 and iteration != max_iter:
            eval_results = do_evaluation(cfg,
                                         model,
                                         device,
                                         iteration=iteration)
            model.train()
            if is_master_proc() and summary_writer:
                for key, value in eval_results.items():
                    summary_writer.add_scalar(f'eval/{key}',
                                              value,
                                              global_step=iteration)

    if eval_step > 0:
        logger.info('Start final evaluation ...')
        torch.cuda.empty_cache()  # free cached memory to speed up evaluation after training
        eval_results = do_evaluation(cfg, model, device)

        if is_master_proc() and summary_writer:
            for key, value in eval_results.items():
                summary_writer.add_scalar(f'eval/{key}',
                                          value,
                                          global_step=arguments["iteration"])
            summary_writer.close()
    checkpointer.save("model_final", **arguments)
    # compute training time
    total_training_time = int(time.time() - start_training_time)
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / max_iter))
    return model