def generate(
    cfg, 
    model: torch.nn.Module, 
    data_loader: torch.utils.data.DataLoader, 
    device: torch.device, 
    logger=None, 
    *args, 
    **kwargs, 
):
    model.eval()
    with utils.log_info(msg="Generate results", level="INFO", state=True, logger=logger):

        pbar = tqdm(total=len(data_loader), dynamic_ncols=True)
        for idx, data in enumerate(data_loader):
            output, *_ = utils.inference(model=model, data=data, device=device)

            save_dir = os.path.join(cfg.SAVE.DIR, "results")
            os.makedirs(save_dir, exist_ok=True)
            for i in range(output.shape[0]):
                path2file = os.path.join(save_dir, data["img_idx"][i]+".png")
                succeed = utils.save_image(output[i].detach().cpu().numpy(), cfg.DATA.MEAN, cfg.DATA.NORM, path2file)
                if not succeed:
                    utils.notify("Cannot save image to {}".format(path2file))

            pbar.update()
        pbar.close()
def generate(
    cfg, 
    model: torch.nn.Module, 
    data_loader: torch.utils.data.DataLoader, 
    device: torch.device, 
    phase, 
    logger=None, 
    *args, 
    **kwargs, 
):
    model.eval()
    # Prepare to log info.
    log_info = print if logger is None else logger.log_info
    inference_time = []
    # Read data and evaluate and record info.
    with utils.log_info(msg="Generate results", level="INFO", state=True, logger=logger):
        pbar = tqdm(total=len(data_loader), dynamic_ncols=True)
        for idx, data in enumerate(data_loader):
            start_time = time.time()
            output = utils.inference(model=model, data=data, device=device)
            inference_time.append(time.time()-start_time)

            save_dir = os.path.join(cfg.SAVE.DIR, phase)
            os.makedirs(save_dir, exist_ok=True)
            for i in range(output.shape[0]):
                path2file = os.path.join(save_dir, data["img_idx"][i]+"_g.png")
                succeed = utils.save_image(output[i].detach().cpu().numpy(), cfg.DATA.MEAN, cfg.DATA.NORM, path2file)
                if not succeed:
                    log_info("Cannot save image to {}".format(path2file))
            pbar.update()
        pbar.close()
    log_info("Runtime per image: {:<5} seconds.".format(round(sum(inference_time)/len(inference_time), 4)))
Example #3
def validate(args, val_dataloader, nets, iteration=0, iou_thresh=0.5):
    """
    Test the model on validation set
    """

    # write results to files for evaluation
    output_files = []
    fouts = []
    for i in range(args.max_iter):
        output_file = args.save_root + 'val_result-' + str(
            iteration) + '-iter' + str(i + 1) + '.csv'
        output_files.append(output_file)
        f = open(output_file, 'w')
        fouts.append(f)

    gt_file = args.save_root + 'val_gt.csv'
    fout = open(gt_file, 'w')

    with torch.no_grad():  # for evaluation
        for num, (images, targets, tubes, infos) in enumerate(val_dataloader):

            if (num + 1) % 100 == 0:
                print("%d / %d" % (num + 1, len(val_dataloader)))

            for b in range(len(infos)):
                for n in range(len(infos[b]['boxes'])):
                    mid = int(len(infos[b]['boxes'][n]) / 2)
                    box = infos[b]['boxes'][n][mid]
                    labels = infos[b]['labels'][n][mid]
                    for label in labels:
                        fout.write(
                            '{0},{1:04},{2:.4},{3:.4},{4:.4},{5:.4},{6}\n'.
                            format(infos[b]['video_name'], infos[b]['fid'],
                                   box[0], box[1], box[2], box[3], label))

            _, _, channels, height, width = images.size()
            images = images.cuda()

            # get conv features
            conv_feat = nets['base_net'](images)
            context_feat = None
            if not args.no_context:
                context_feat = nets['context_net'](conv_feat)

            ############## Inference ##############

            history, _ = inference(args, conv_feat, context_feat, nets,
                                   args.max_iter, tubes)
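            # `history` holds one entry per refinement step; each entry is a
            # dict with 'pred_prob', 'pred_loc', and 'tubes_nums' (number of
            # tubes per sample), which the evaluation loop below consumes.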

            #################### Evaluation #################

            # loop over each refinement iteration
            for i in range(len(history)):
                pred_prob = history[i]['pred_prob'].cpu()
                pred_prob = pred_prob[:, int(pred_prob.shape[1] / 2)]
                pred_tubes = history[i]['pred_loc'].cpu()
                pred_tubes = pred_tubes[:, int(pred_tubes.shape[1] / 2)]
                tubes_nums = history[i]['tubes_nums']

                # loop for each sample in a batch
                tubes_count = 0
                for b in range(len(tubes_nums)):
                    info = infos[b]
                    seq_start = tubes_count
                    tubes_count = tubes_count + tubes_nums[b]

                    cur_pred_prob = pred_prob[seq_start:seq_start +
                                              tubes_nums[b]]
                    cur_pred_tubes = pred_tubes[seq_start:seq_start +
                                                tubes_nums[b]]

                    # do NMS first
                    all_scores = []
                    all_boxes = []
                    all_idx = []
                    for cl_ind in range(args.num_classes):
                        scores = cur_pred_prob[:, cl_ind].squeeze().reshape(-1)
                        c_mask = scores.gt(args.conf_thresh)  # greater than the minimum confidence threshold
                        scores = scores[c_mask]
                        idx = np.where(c_mask.numpy())[0]
                        if len(scores) == 0:
                            all_scores.append([])
                            all_boxes.append([])
                            all_idx.append([])  # keep the per-class lists aligned
                            continue
                        boxes = cur_pred_tubes.clone()
                        l_mask = c_mask.unsqueeze(1).expand_as(boxes)
                        boxes = boxes[l_mask].view(-1, 4)

                        boxes = valid_tubes(boxes.view(-1, 1, 4)).view(-1, 4)
                        keep = nms(boxes, scores, args.nms_thresh)
                        boxes = boxes[keep].numpy()
                        scores = scores[keep].numpy()
                        idx = idx[keep]

                        boxes[:, ::2] /= width
                        boxes[:, 1::2] /= height
                        all_scores.append(scores)
                        all_boxes.append(boxes)
                        all_idx.append(idx)

                    # get the top scores
                    scores_list = [(s, cl_ind, j)
                                   for cl_ind, scores in enumerate(all_scores)
                                   for j, s in enumerate(scores)]
                    if args.evaluate_topk > 0:
                        scores_list.sort(key=lambda x: x[0])
                        scores_list = scores_list[::-1]
                        scores_list = scores_list[:args.topk]

                    for s, cl_ind, j in scores_list:
                        # write to files
                        box = all_boxes[cl_ind][j]
                        fouts[i].write(
                            '{0},{1:04},{2:.4},{3:.4},{4:.4},{5:.4},{6},{7:.4}\n'
                            .format(info['video_name'], info['fid'], box[0],
                                    box[1], box[2], box[3], label_dict[cl_ind],
                                    s))
    fout.close()

    all_metrics = []
    for i in range(args.max_iter):
        fouts[i].close()

        metrics = ava_evaluation(os.path.join(args.data_root, 'label/'),
                                 output_files[i], gt_file)
        all_metrics.append(metrics)

    return all_metrics
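
The per-class confidence filtering followed by NMS recurs throughout these examples. Below is a self-contained sketch of the same pattern; it substitutes torchvision.ops.nms for the repo-local `nms` and `valid_tubes` helpers, so it illustrates the technique rather than the exact code path.

# Standalone sketch of the per-class threshold + NMS pattern used above,
# with torchvision.ops.nms standing in for the repo-local nms() helper.
import torch
from torchvision.ops import nms

def per_class_nms(probs, boxes, num_classes, conf_thresh=0.4, nms_thresh=0.5):
    """probs: (N, num_classes) scores; boxes: (N, 4) in xyxy format."""
    results = []
    for cl_ind in range(num_classes):
        scores = probs[:, cl_ind].reshape(-1)
        mask = scores.gt(conf_thresh)
        if mask.sum() == 0:
            results.append((torch.empty(0, 4), torch.empty(0)))
            continue
        cl_boxes = boxes[mask]
        cl_scores = scores[mask]
        keep = nms(cl_boxes, cl_scores, nms_thresh)
        results.append((cl_boxes[keep], cl_scores[keep]))
    return results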
Example #4
def train(args, nets, optimizer, scheduler, train_dataloader, val_dataloader,
          log_file):
    global best_mAP

    for _, net in nets.items():
        net.train()

    # loss counters
    batch_time = AverageMeter(200)
    losses = [AverageMeter(200) for _ in range(args.max_iter)]
    losses_global_cls = AverageMeter(200)
    losses_local_loc = AverageMeter(200)
    losses_neighbor_loc = AverageMeter(200)

    #    writer = SummaryWriter(args.save_root+"summary"+datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))

    ################ Training loop #################

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    epochs = args.start_epochs
    iteration = args.start_iteration
    epoch_size = int(np.ceil(len(train_dataloader.dataset) / args.batch_size))

    while epochs < args.max_epochs:
        for _, (images, targets, tubes, infos) in enumerate(train_dataloader):

            images = images.cuda()

            # adjust learning rate
            scheduler.step()
            lr = optimizer.param_groups[-1]['lr']

            # get conv features
            conv_feat = nets['base_net'](images)
            context_feat = None
            if not args.no_context:
                context_feat = nets['context_net'](conv_feat)

            ############# Inference to get candidates for each iteration ########

            # randomly sample a fixed number of tubes
            if 0 < args.NUM_SAMPLE < tubes[0].shape[0]:
                sampled_idx = np.random.choice(tubes[0].shape[0],
                                               args.NUM_SAMPLE,
                                               replace=False)
                for i in range(len(tubes)):
                    tubes[i] = tubes[i][sampled_idx]

            for _, net in nets.items():
                net.eval()
            with torch.no_grad():
                history, _ = inference(args, conv_feat, context_feat, nets,
                                       args.max_iter - 1, tubes)
            for _, net in nets.items():
                net.train()

            ########### Forward pass for each iteration ############
            optimizer.zero_grad()
            loss_back = 0.

            # loop for each step
            for i in range(1, args.max_iter + 1):  # index from 1

                # adaptively get the start chunk
                chunks = args.NUM_CHUNKS[i]
                max_chunks = args.NUM_CHUNKS[args.max_iter]
                T_start = int((max_chunks - chunks) / 2) * args.T
                T_length = chunks * args.T
                T_mid = int(chunks / 2) * args.T  # center chunk within T_length
                chunk_idx = [
                    j * args.T + int(args.T / 2) for j in range(chunks)
                ]  # used to index the middle frame of each chunk

                # select training samples
                selected_tubes, target_tubes = train_select(
                    i, history[i - 2], targets, tubes, args)

                ######### Start training ########

                # flatten list of tubes
                flat_targets, _ = flatten_tubes(target_tubes, batch_idx=False)
                flat_tubes, _ = flatten_tubes(
                    selected_tubes,
                    batch_idx=True)  # add batch_idx for ROI pooling
                flat_targets = torch.FloatTensor(flat_targets).to(conv_feat)
                flat_tubes = torch.FloatTensor(flat_tubes).to(conv_feat)

                # ROI Pooling
                pooled_feat = nets['roi_net'](conv_feat[:, T_start:T_start +
                                                        T_length].contiguous(),
                                              flat_tubes)
                _, C, W, H = pooled_feat.size()
                pooled_feat = pooled_feat.view(-1, T_length, C, W, H)

                temp_context_feat = None
                if not args.no_context:
                    temp_context_feat = torch.zeros(
                        (pooled_feat.size(0), context_feat.size(1), T_length,
                         1, 1)).to(context_feat)
                    for p in range(pooled_feat.size(0)):
                        temp_context_feat[p] = context_feat[
                            int(flat_tubes[p, 0, 0].item() / T_length), :,
                            T_start:T_start + T_length].contiguous().clone()

                _, _, _, _, cur_loss_global_cls, cur_loss_local_loc, cur_loss_neighbor_loc = nets[
                    'det_net%d' % (i - 1)](pooled_feat,
                                           context_feat=temp_context_feat,
                                           tubes=flat_tubes,
                                           targets=flat_targets)
                cur_loss_global_cls = cur_loss_global_cls.mean()
                cur_loss_local_loc = cur_loss_local_loc.mean()
                cur_loss_neighbor_loc = cur_loss_neighbor_loc.mean()

                cur_loss = cur_loss_global_cls + \
                            cur_loss_local_loc * args.lambda_reg + \
                            cur_loss_neighbor_loc * args.lambda_neighbor
                loss_back += cur_loss.to(conv_feat.device)

                losses[i - 1].update(cur_loss.item())
                if cur_loss_neighbor_loc.item() > 0:
                    losses_neighbor_loc.update(cur_loss_neighbor_loc.item())

            ########### Gradient updates ############
            # record last step only
            losses_global_cls.update(cur_loss_global_cls.item())
            losses_local_loc.update(cur_loss_local_loc.item())

            if args.fp16:
                loss_back /= args.max_iter  # prevent gradient overflow
                with amp.scale_loss(loss_back, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_back.backward()
            optimizer.step()

            ############### Print logs and save models ############

            iteration += 1

            if iteration % args.print_step == 0 and iteration > 0:

                gpu_memory = get_gpu_memory()

                torch.cuda.synchronize()
                t1 = time.perf_counter()
                batch_time.update(t1 - t0)

                print_line = 'Epoch {}/{}({}) Iteration {:06d} lr {:.2e} '.format(
                    epochs + 1, args.max_epochs, epoch_size, iteration, lr)
                for i in range(args.max_iter):
                    print_line += 'loss-{} {:.3f} '.format(
                        i + 1, losses[i].avg)
                print_line += 'loss_global_cls {:.3f} loss_local_loc {:.3f} loss_neighbor_loc {:.3f} Timer {:0.3f}({:0.3f}) GPU usage: {}'.format(
                    losses_global_cls.avg, losses_local_loc.avg,
                    losses_neighbor_loc.avg, batch_time.val, batch_time.avg,
                    gpu_memory)

                torch.cuda.synchronize()
                t0 = time.perf_counter()
                log_file.write(print_line + '\n')
                print(print_line)

            if (iteration % args.save_step == 0) and iteration > 0:
                print('Saving state, iter:', iteration)
                save_name = args.save_root + 'checkpoint_' + str(
                    iteration) + '.pth'
                save_dict = {
                    'epochs': epochs + 1,
                    'iteration': iteration,
                    'base_net': nets['base_net'].state_dict(),
                    'context_net': nets['context_net'].state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'val_mAP': best_mAP,
                    'cfg': args
                }
                for i in range(args.max_iter):
                    save_dict['det_net%d' % i] = nets['det_net%d' %
                                                      i].state_dict()
                torch.save(save_dict, save_name)

                # only keep the latest model
                if os.path.isfile(args.save_root + 'checkpoint_' +
                                  str(iteration - args.save_step) + '.pth'):
                    os.remove(args.save_root + 'checkpoint_' +
                              str(iteration - args.save_step) + '.pth')
                    print(args.save_root + 'checkpoint_' +
                          str(iteration - args.save_step) + '.pth  removed!')

            # For consistency when resuming from the middle of an epoch
            if iteration % epoch_size == 0 and iteration > 0:
                break

        ##### Validation at the end of each epoch #####

        validate_epochs = [0, 1, 5, 9, 13, 14]
        if epochs in validate_epochs:
            torch.cuda.synchronize()
            tvs = time.perf_counter()

            for _, net in nets.items():
                net.eval()  # switch net to evaluation mode
            print('Validating at ', iteration)
            all_metrics = validate(args,
                                   val_dataloader,
                                   nets,
                                   iteration,
                                   iou_thresh=args.iou_thresh)

            prt_str = ''
            for i in range(args.max_iter):
                prt_str += 'Iter ' + str(i + 1) + ': MEANAP =>' + str(
                    all_metrics[i]['PascalBoxes_Precision/[email protected]']) + '\n'
            print(prt_str)
            log_file.write(prt_str)

            log_file.write("Best MEANAP so far => {}\n".format(best_mAP))
            for i in class_whitelist:
                log_file.write("({}) {}: {}\n".format(
                    i, id2class[i], all_metrics[-1]
                    ["PascalBoxes_PerformanceByCategory/[email protected]/{}".format(
                        id2class[i])]))

    #        writer.add_scalar('mAP', all_metrics[-1]['PascalBoxes_Precision/[email protected]'], iteration)
    #        for key, ap in all_metrics[-1].items():
    #            writer.add_scalar(key, ap, iteration)

            if all_metrics[-1]['PascalBoxes_Precision/[email protected]'] > best_mAP:
                best_mAP = all_metrics[-1]['PascalBoxes_Precision/[email protected]']
                print('Saving current best model, iter:', iteration)
                save_name = args.save_root + 'checkpoint_best.pth'
                save_dict = {
                    'epochs': epochs + 1,
                    'iteration': iteration,
                    'base_net': nets['base_net'].state_dict(),
                    'context_net': nets['context_net'].state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'val_mAP': best_mAP,
                    'cfg': args
                }
                for i in range(args.max_iter):
                    save_dict['det_net%d' % i] = nets['det_net%d' %
                                                      i].state_dict()

                torch.save(save_dict, save_name)

            for _, net in nets.items():
                net.train()  # switch net to training mode
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str2 = '\nValidation TIME::: {:0.3f}\n\n'.format(t0 - tvs)
            print(prt_str2)
            log_file.write(prt_str2)

        epochs += 1

    log_file.close()
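
`train` leans on an `AverageMeter(200)` helper defined elsewhere in the repo. Below is a minimal sketch consistent with how `.update()`, `.val`, and `.avg` are used above, assuming the constructor argument is a window size over the most recent values.

# Minimal sketch of the AverageMeter used above, assuming a fixed-size
# window (the constructor argument) over the most recent values.
from collections import deque

class AverageMeter:
    def __init__(self, window_size=None):
        self.values = deque(maxlen=window_size)
        self.val = 0.0

    def update(self, value):
        self.val = value                 # most recent value
        self.values.append(value)

    @property
    def avg(self):                       # running average over the window
        return sum(self.values) / len(self.values) if self.values else 0.0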
Example #5
def main():

    ################## Customize your configuratons here ###################

    checkpoint_path = 'pretrained/ava_step.pth'
    if os.path.isfile(checkpoint_path):
        print("Loading pretrained model from %s" % checkpoint_path)
        map_location = 'cuda:0'
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        args = checkpoint['cfg']
    else:
        raise ValueError("Pretrained model not found!", checkpoint_path)

    # TODO: Set data_root to the customized input dataset
    args.data_root = '/datasets/demo/frames/'
    args.save_root = os.path.join(os.path.dirname(args.data_root), 'results/')
    os.makedirs(args.save_root, exist_ok=True)

    # TODO: modify this setting according to the actual frame rate and file name
    source_fps = 30
    im_format = 'frame%04d.jpg'
    conf_thresh = 0.4
    global_thresh = 0.8  # used for cross-class NMS

    ################ Define models #################

    gpu_count = torch.cuda.device_count()
    nets = OrderedDict()
    # backbone network
    nets['base_net'] = BaseNet(args)
    # ROI pooling
    nets['roi_net'] = ROINet(args.pool_mode, args.pool_size)

    # detection network
    for i in range(args.max_iter):
        if args.det_net == "two_branch":
            nets['det_net%d' % i] = TwoBranchNet(args)
        else:
            raise NotImplementedError
    if not args.no_context:
        # context branch
        nets['context_net'] = ContextNet(args)

    for key in nets:
        nets[key] = nets[key].cuda()

    nets['base_net'] = torch.nn.DataParallel(nets['base_net'])
    if not args.no_context:
        nets['context_net'] = torch.nn.DataParallel(nets['context_net'])
    for i in range(args.max_iter):
        nets['det_net%d' % i].to('cuda:%d' % ((i + 1) % gpu_count))
        nets['det_net%d' % i].set_device('cuda:%d' % ((i + 1) % gpu_count))

    # load pretrained model
    nets['base_net'].load_state_dict(checkpoint['base_net'])
    if not args.no_context and 'context_net' in checkpoint:
        nets['context_net'].load_state_dict(checkpoint['context_net'])
    for i in range(args.max_iter):
        pretrained_dict = checkpoint['det_net%d' % i]
        nets['det_net%d' % i].load_state_dict(pretrained_dict)

    ################ DataLoader setup #################

    dataset = CustomizedDataset(args.data_root,
                                args.T,
                                args.NUM_CHUNKS[args.max_iter],
                                source_fps,
                                args.fps,
                                BaseTransform(args.image_size, args.means,
                                              args.stds, args.scale_norm),
                                anchor_mode=args.anchor_mode,
                                im_format=im_format)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=False,
                                             collate_fn=detection_collate,
                                             pin_memory=True)

    ################ Inference #################

    for _, net in nets.items():
        net.eval()

    fout = open(os.path.join(args.save_root, 'results.txt'), 'w')
    torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():
        for _, (images, tubes, infos) in enumerate(dataloader):

            _, _, channels, height, width = images.size()
            images = images.cuda()

            # get conv features
            conv_feat = nets['base_net'](images)
            context_feat = None
            if not args.no_context:
                context_feat = nets['context_net'](conv_feat)

            history, _ = inference(args, conv_feat, context_feat, nets,
                                   args.max_iter, tubes)

            # collect result of the last step
            pred_prob = history[-1]['pred_prob'].cpu()
            pred_prob = pred_prob[:, int(pred_prob.shape[1] / 2)]
            pred_tubes = history[-1]['pred_loc'].cpu()
            pred_tubes = pred_tubes[:, int(pred_tubes.shape[1] / 2)]
            tubes_nums = history[-1]['tubes_nums']

            # loop for each batch
            tubes_count = 0
            for b in range(len(tubes_nums)):
                info = infos[b]
                seq_start = tubes_count
                tubes_count = tubes_count + tubes_nums[b]

                cur_pred_prob = pred_prob[seq_start:seq_start + tubes_nums[b]]
                cur_pred_tubes = pred_tubes[seq_start:seq_start +
                                            tubes_nums[b]]

                # do NMS first
                all_scores = []
                all_boxes = []
                all_idx = []
                for cl_ind in range(args.num_classes):
                    scores = cur_pred_prob[:, cl_ind].squeeze().reshape(-1)
                    c_mask = scores.gt(conf_thresh)  # greater than a threshold
                    scores = scores[c_mask]
                    idx = np.where(c_mask.numpy())[0]
                    if len(scores) == 0:
                        all_scores.append([])
                        all_boxes.append([])
                        all_idx.append([])  # keep the per-class lists aligned
                        continue
                    boxes = cur_pred_tubes.clone()
                    l_mask = c_mask.unsqueeze(1).expand_as(boxes)
                    boxes = boxes[l_mask].view(-1, 4)

                    boxes = valid_tubes(boxes.view(-1, 1, 4)).view(-1, 4)
                    keep = nms(boxes, scores, args.nms_thresh)
                    boxes = boxes[keep].numpy()
                    scores = scores[keep].numpy()
                    idx = idx[keep]

                    boxes[:, ::2] /= width
                    boxes[:, 1::2] /= height
                    all_scores.append(scores)
                    all_boxes.append(boxes)
                    all_idx.append(idx)

                # get the top scores
                scores_list = [(s, cl_ind, j)
                               for cl_ind, scores in enumerate(all_scores)
                               for j, s in enumerate(scores)]
                if args.evaluate_topk > 0:
                    scores_list.sort(key=lambda x: x[0])
                    scores_list = scores_list[::-1]
                    scores_list = scores_list[:args.topk]

                # merge high overlapping boxes (a simple greedy method)
                merged_result = {}
                flag = [1 for _ in range(len(scores_list))]
                for i in range(len(scores_list)):
                    if flag[i]:
                        s, cl_ind, j = scores_list[i]
                        box = all_boxes[cl_ind][j]
                        temp = ([box], [args.label_dict[cl_ind]], [s])

                        # find all high IoU boxes
                        for ii in range(i + 1, len(scores_list)):
                            if flag[ii]:
                                s2, cl_ind2, j2 = scores_list[ii]
                                box2 = all_boxes[cl_ind2][j2]
                                if compute_box_iou(box, box2) > global_thresh:
                                    flag[ii] = 0
                                    temp[0].append(box2)
                                    temp[1].append(args.label_dict[cl_ind2])
                                    temp[2].append(s2)

                        merged_box = np.mean(np.concatenate(temp[0],
                                                            axis=0).reshape(
                                                                -1, 4),
                                             axis=0)
                        key = ','.join(merged_box.astype(str).tolist())
                        merged_result[key] = [
                            (l, s) for l, s in zip(temp[1], temp[2])
                        ]

                # visualize results
                if not os.path.isdir(
                        os.path.join(args.save_root, info['video_name'])):
                    os.makedirs(
                        os.path.join(args.save_root, info['video_name']))
                print(info)
                overlay_image(os.path.join(args.data_root, info['video_name'],
                                           im_format % info['fid']),
                              os.path.join(args.save_root, info['video_name'],
                                           im_format % info['fid']),
                              pred_boxes=merged_result,
                              id2class=args.id2class)

                # write to files
                for key in merged_result:
                    box = np.asarray(key.split(','), dtype=np.float32)
                    for l, s in merged_result[key]:
                        fout.write(
                            '{0},{1:04},{2:.4},{3:.4},{4:.4},{5:.4},{6},{7:.4}\n'
                            .format(info['video_name'], info['fid'], box[0],
                                    box[1], box[2], box[3], l, s))
            torch.cuda.synchronize()
            t1 = time.time()
            print("Batch time: ", t1 - t0)

            torch.cuda.synchronize()
            t0 = time.time()

    fout.close()
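
The greedy merge above hinges on `compute_box_iou`, which is imported from elsewhere. For single `[x1, y1, x2, y2]` boxes like the ones being merged, a standard IoU looks like the sketch below (an assumption: the repo's own helper may differ in interface details such as broadcasting).

# Standard IoU for two boxes in [x1, y1, x2, y2] format, matching how
# compute_box_iou is used in the greedy merge above. This is a sketch;
# the repo's helper may differ in details.
def compute_box_iou(box_a, box_b):
    x1 = max(box_a[0], box_b[0])
    y1 = max(box_a[1], box_b[1])
    x2 = min(box_a[2], box_b[2])
    y2 = min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0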
Example #6
    # Note: this snippet starts mid-script; `model`, `variables_to_restore`,
    # `name_best_model`, `width`, `height`, `n_classes`, and `args` are
    # assumed to be defined earlier.
    restore_model = tfe.Saver(var_list=variables_to_restore)

    # restore if model saved and show number of params
    restore_state(restore_model, name_best_model)
    get_params(model)

    img = cv2.imread(args.image_path, 0)
    img = cv2.resize(img, (width, height),
                     interpolation=cv2.INTER_AREA).astype(np.float32)
    img = np.expand_dims(img, -1)
    img = np.expand_dims(img, 0)
    print(img.shape)

    prediction = inference(model,
                           img,
                           n_classes,
                           flip_inference=True,
                           scales=[0.75, 1, 1.5],
                           preprocess_mode=None)
    print(prediction.numpy().shape)
    prediction = tf.argmax(prediction, -1)
    print(prediction.numpy().shape)

    img = np.squeeze(img).astype(np.uint8)
    prediction = np.squeeze(prediction.numpy()).astype(np.uint8)
    prediction_6classes = fromIdTrainToId(prediction).astype(np.uint8)

    cv2.imshow('image', img)
    cv2.imshow('pred (cityscapes classes)',
               prediction * 13)  # scaled by 13 for visualization
    cv2.imshow('pred (event classes)',
               prediction_6classes * 40)  # scaled by 40 for visualization
    cv2.waitKey(0)  # block so the OpenCV windows actually render
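
`inference(model, img, n_classes, flip_inference=True, scales=[0.75, 1, 1.5], ...)` above is a repo-local helper. Below is a hedged sketch of what flip-and-multi-scale inference typically does: average class probabilities over several scales and a horizontal flip. `predict_fn` is a hypothetical single-forward-pass callable, not the repo's API.

# Hedged sketch of flip + multi-scale inference; predict_fn is an assumed
# callable mapping an image to per-pixel class probabilities.
import cv2
import numpy as np

def _resize_probs(probs, w, h):
    # cv2.resize handles at most a few channels, so resize per class channel
    return np.stack([cv2.resize(np.ascontiguousarray(probs[..., c]), (w, h))
                     for c in range(probs.shape[-1])], axis=-1)

def multiscale_flip_inference(predict_fn, img, scales=(0.75, 1, 1.5)):
    """img: (H, W) or (H, W, C) float32; predict_fn returns (h, w, n_classes)."""
    h, w = img.shape[:2]
    acc = 0.0
    for s in scales:
        scaled = cv2.resize(img, (int(w * s), int(h * s)))
        for flip in (False, True):
            inp = scaled[:, ::-1] if flip else scaled
            out = np.asarray(predict_fn(inp), dtype=np.float32)
            if flip:
                out = out[:, ::-1]  # un-flip the prediction
            acc = acc + _resize_probs(out, w, h)
    return acc / (2 * len(scales))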
Example #7
def main():

    ################## Load pretrained model and configurations ###################

    checkpoint_path = 'pretrained/ava_step.pth'
    if os.path.isfile(checkpoint_path):
        print("Loading pretrained model from %s" % checkpoint_path)
        map_location = 'cuda:0'
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        args = checkpoint['cfg']
    else:
        raise ValueError("Pretrained model not found!", checkpoint_path)

    os.makedirs(args.save_root, exist_ok=True)
    
    label_dict = {}
    if args.num_classes == 60:
        label_map = os.path.join(args.data_root, 'label/ava_action_list_v2.1_for_activitynet_2018.pbtxt')
        categories, class_whitelist = read_labelmap(open(label_map, 'r'))
        classes = [(val['id'], val['name']) for val in categories]
        id2class = {c[0]: c[1] for c in classes}    # gt class id (1~80) --> class name
        for i, c in enumerate(sorted(list(class_whitelist))):
            label_dict[i] = c
    else:
        for i in range(80):
            label_dict[i] = i+1

    ################ Define models #################

    gpu_count = torch.cuda.device_count()
    nets = OrderedDict()
    # backbone network
    nets['base_net'] = BaseNet(args)
    # ROI pooling
    nets['roi_net'] = ROINet(args.pool_mode, args.pool_size)

    # detection network
    for i in range(args.max_iter):
        if args.det_net == "two_branch":
            nets['det_net%d' % i] = TwoBranchNet(args)
        else:
            raise NotImplementedError
    if not args.no_context:
        # context branch
        nets['context_net'] = ContextNet(args)

    for key in nets:
        nets[key] = nets[key].cuda()

    nets['base_net'] = torch.nn.DataParallel(nets['base_net'])
    if not args.no_context:
        nets['context_net'] = torch.nn.DataParallel(nets['context_net'])
    for i in range(args.max_iter):
        nets['det_net%d' % i].to('cuda:%d' % ((i+1)%gpu_count))
        nets['det_net%d' % i].set_device('cuda:%d' % ((i+1)%gpu_count))

    # load pretrained weights
    nets['base_net'].load_state_dict(checkpoint['base_net'])
    if not args.no_context and 'context_net' in checkpoint:
        nets['context_net'].load_state_dict(checkpoint['context_net'])
    for i in range(args.max_iter):
        pretrained_dict = checkpoint['det_net%d' % i]
        nets['det_net%d' % i].load_state_dict(pretrained_dict)

    
    ################ DataLoader setup #################

    dataset = AVADataset(args.data_root, 'test', args.input_type, args.T,
                         args.NUM_CHUNKS[args.max_iter], args.fps,
                         BaseTransform(args.image_size, args.means, args.stds,
                                       args.scale_norm),
                         proposal_path=args.proposal_path_val, stride=1,
                         anchor_mode=args.anchor_mode,
                         num_classes=args.num_classes, foreground_only=False)
    dataloader = torch.utils.data.DataLoader(dataset, args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=False,
                                             collate_fn=detection_collate,
                                             pin_memory=True)

    ################ Inference #################

    for _, net in nets.items():
        net.eval()

    # write results to files for evaluation
    output_files = []
    fouts = []
    for i in range(args.max_iter):
        output_file = args.save_root+'testing_result-iter'+str(i+1)+'.csv'
        output_files.append(output_file)
        f = open(output_file, 'w')
        fouts.append(f)

    gt_file = args.save_root+'testing_gt.csv'
    fout = open(gt_file, 'w')

    torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():    # for evaluation
        for num, (images, targets, tubes, infos) in enumerate(dataloader):

            if (num+1) % 100 == 0:
                print("%d / %d" % (num+1, len(dataloader)))

            for b in range(len(infos)):
                for n in range(len(infos[b]['boxes'])):
                    mid = int(len(infos[b]['boxes'][n])/2)
                    box = infos[b]['boxes'][n][mid]
                    labels = infos[b]['labels'][n][mid]
                    for label in labels:
                        fout.write('{0},{1:04},{2:.4},{3:.4},{4:.4},{5:.4},{6}\n'.format(
                                    infos[b]['video_name'],
                                    infos[b]['fid'],
                                    box[0], box[1], box[2], box[3],
                                    label))

            _, _, channels, height, width = images.size()
            images = images.cuda()

            # get conv features
            conv_feat = nets['base_net'](images)
            context_feat = None
            if not args.no_context:
                context_feat = nets['context_net'](conv_feat)

            ############## Inference ##############

            history, _ = inference(args, conv_feat, context_feat, nets, args.max_iter, tubes)

            #################### Evaluation #################

            # loop over each refinement iteration
            for i in range(len(history)):
                pred_prob = history[i]['pred_prob'].cpu()
                pred_prob = pred_prob[:,int(pred_prob.shape[1]/2)]
                pred_tubes = history[i]['pred_loc'].cpu()
                pred_tubes = pred_tubes[:,int(pred_tubes.shape[1]/2)]
                tubes_nums = history[i]['tubes_nums']

                # loop for each sample in a batch
                tubes_count = 0
                for b in range(len(tubes_nums)):
                    info = infos[b]
                    seq_start = tubes_count
                    tubes_count = tubes_count + tubes_nums[b]
    
                    cur_pred_prob = pred_prob[seq_start:seq_start+tubes_nums[b]]
                    cur_pred_tubes = pred_tubes[seq_start:seq_start+tubes_nums[b]]

                    # do NMS first
                    all_scores = []
                    all_boxes = []
                    all_idx = []
                    for cl_ind in range(args.num_classes):
                        scores = cur_pred_prob[:, cl_ind].squeeze().reshape(-1)
                        c_mask = scores.gt(args.conf_thresh)  # greater than the minimum threshold
                        scores = scores[c_mask]
                        idx = np.where(c_mask.numpy())[0]
                        if len(scores) == 0:
                            all_scores.append([])
                            all_boxes.append([])
                            all_idx.append([])  # keep the per-class lists aligned
                            continue
                        boxes = cur_pred_tubes.clone()
                        l_mask = c_mask.unsqueeze(1).expand_as(boxes)
                        boxes = boxes[l_mask].view(-1, 4)
    
                        boxes = valid_tubes(boxes.view(-1,1,4)).view(-1,4)
                        keep = nms(boxes, scores, args.nms_thresh)
                        boxes = boxes[keep].numpy()
                        scores = scores[keep].numpy()
                        idx = idx[keep]
    
                        boxes[:, ::2] /= width
                        boxes[:, 1::2] /= height
                        all_scores.append(scores)
                        all_boxes.append(boxes)
                        all_idx.append(idx)

                    # get the top scores
                    scores_list = [(s,cl_ind,j) for cl_ind,scores in enumerate(all_scores) for j,s in enumerate(scores)]
                    if args.evaluate_topk > 0:
                        scores_list.sort(key=lambda x: x[0])
                        scores_list = scores_list[::-1]
                        scores_list = scores_list[:args.topk]

                    for s,cl_ind,j in scores_list:
                        # write to files
                        box = all_boxes[cl_ind][j]
                        fouts[i].write('{0},{1:04},{2:.4},{3:.4},{4:.4},{5:.4},{6},{7:.4}\n'.format(
                                                    info['video_name'],
                                                    info['fid'],
                                                    box[0],box[1],box[2],box[3],
                                                    label_dict[cl_ind],
                                                    s))
    fout.close()

    all_metrics = []
    for i in range(args.max_iter):
        fouts[i].close()

        metrics = ava_evaluation(os.path.join(args.data_root, 'label/'), output_files[i], gt_file)
        all_metrics.append(metrics)

    # Logging
    log_name = args.save_root+"testing_results.log"
    log_file = open(log_name, "w", 1)
    prt_str = ''
    for i in range(args.max_iter):
        prt_str += 'Iter '+str(i+1)+': MEANAP =>'+str(all_metrics[i]['PascalBoxes_Precision/[email protected]'])+'\n'
    log_file.write(prt_str)
    
    if args.num_classes == 60:  # per-class names only exist when the label map was loaded
        for i in class_whitelist:
            log_file.write("({}) {}: {}\n".format(i, id2class[i],
                all_metrics[-1]["PascalBoxes_PerformanceByCategory/[email protected]/{}".format(id2class[i])]))

    log_file.close()
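
The `label_dict` built in `main()` above maps the network's contiguous class indices onto the original AVA label ids by sorting the whitelist. A tiny worked example, using a toy whitelist for illustration:

# Worked example of the label_dict mapping built in main() above:
# contiguous model output index -> original AVA class id.
class_whitelist = {1, 3, 7}                 # toy whitelist, for illustration
label_dict = {}
for i, c in enumerate(sorted(list(class_whitelist))):
    label_dict[i] = c
print(label_dict)                           # {0: 1, 1: 3, 2: 7}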