Example #1
File: detect.py Project: tf369/GEP_PAMI
def train(data_loader, model, criterion, optimizer, epoch, args):
    logger = logutils.Logger()
    model.train()
    start_time = time.time()
    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='Batch Loop Training')):
        logger.data_time.update(time.time() - start_time)
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        batch_num = features_batch.size(1)
        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        output = model(features_batch)
        _, pred_labels = torch.max(output, dim=-1)
        # loss_func is defined elsewhere in detect.py; total_lengths is
        # presumably used to mask padded frames when applying the criterion
        loss = loss_func(criterion, output, labels_batch, total_lengths)
        video_length = torch.sum(total_lengths).item()
        logger.losses.update(loss.item(), video_length)

        for in_batch_idx in range(batch_num):
            detections = pred_labels[:,
                                     in_batch_idx].cpu().data.numpy().flatten(
                                     ).tolist()
            if args.subsample != 1:
                all_total_labels, all_total_lengths = additional
                gt_detections = all_total_labels[:all_total_lengths[
                    in_batch_idx], in_batch_idx].flatten().tolist()
                video_length = len(gt_detections)
                detections = evalutils.upsample(detections,
                                                freq=args.subsample,
                                                length=video_length)
            else:
                gt_detections = labels_batch[:total_lengths[in_batch_idx],
                                             in_batch_idx].cpu().data.numpy(
                                             ).flatten().tolist()
                detections = detections[:total_lengths[in_batch_idx]]
                video_length = len(gt_detections)
            micro_prec = logutils.compute_accuracy(gt_detections, detections)
            macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                gt_detections, detections, metric='macro')
            task_acc_ratio.update(micro_prec, video_length)
            task_macro_prec.update(macro_prec, video_length)
            task_macro_rec.update(macro_rec, video_length)
            task_macro_f1.update(macro_f1, video_length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.batch_time.update(time.time() - start_time)
        start_time = time.time()

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('Task {} Epoch: [{}][{}/{}]\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                       'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                       'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                       'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                           args.task,
                           epoch,
                           batch_idx,
                           len(data_loader),
                           loss=logger.losses,
                           top1=task_acc_ratio,
                           prec=task_macro_prec,
                           recall=task_macro_rec,
                           f1=task_macro_f1))
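
Note: logutils.AverageMeter is not shown in these snippets. A minimal sketch of the conventional meter pattern implied by the .update(value, n) / .val / .avg usage above (an assumption, not the repo's actual code):

class AverageMeter(object):
    """Tracks the most recent value and a weighted running average."""

    def __init__(self):
        self.val = 0.0    # last value passed to update()
        self.sum = 0.0    # weighted sum of all values seen
        self.count = 0    # total weight seen so far
        self.avg = 0.0    # running weighted average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)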
Example #2
File: inference.py Project: tf369/GEP_PAMI
def evaluate(paths):
    try:
        priors = load_prior(paths)
    except IOError:
        sys.exit('Prior information not found.')

    try:
        with open(os.path.join(paths.tmp_root, 'activity_corpus.p'), 'rb') as f:
            activity_corpus = pickle.load(f)
    except IOError:
        sys.exit('Ground truth pickle file not found.')

    grammar_dict = grammarutils.read_induced_grammar(paths)
    languages = grammarutils.read_languages(paths)

    # Prediction horizon: index 0 is the current frame (detection),
    # the remaining 45 entries are future frames (prediction)
    duration = 45 + 1

    total_seg_gt_s = list()
    total_seg_pred_s = list()
    total_seg_gt_u = list()
    total_seg_pred_u = list()

    total_gt_s = [list() for _ in range(duration)]
    total_pred_s = [list() for _ in range(duration)]
    total_gt_u = [list() for _ in range(duration)]
    total_pred_u = [list() for _ in range(duration)]

    total_gt_e = list()
    total_pred_e = list()

    for activity, tpgs in activity_corpus.items():
        print(activity)
        for tpg in tpgs:
            print(tpg.id, tpg.terminals[-1].end_frame)
            # if tpg.subject != 'Subject5':
            #     continue
            # if tpg.id != '1204142858':  # Taking medicine, start_frame != 0
            #     continue
            # if tpg.id != '1204144736':
            #     continue
            # if tpg.id == '1204174554' or tpg.id == '1204142616' or tpg.id == '0510142336' or tpg.id == '1204175712' or tpg.id == '1130151154' or tpg.id == '0510172333' or tpg.id == '1130151154':
            #     continue
            infer_start_time = time.time()
            results = infer(paths, tpg, priors, grammar_dict, languages, duration)
            print('Inference time elapsed: {}s'.format(time.time() - infer_start_time))
            # seg_gt_s, seg_pred_s, seg_gt_u, seg_pred_u, gt_s, pred_s, gt_u, pred_u, trace_s, lstm_pred_s, e = results
            seg_gt_s, seg_pred_s, seg_gt_u, seg_pred_u, gt_s, pred_s, gt_u, pred_u, e = results

            # vizutils.plot_segmentation([gt_s[0], pred_s[0], trace_s[:-1], lstm_pred_s[:-1]], len(gt_s[0]))

            total_seg_gt_s.extend(seg_gt_s)
            total_seg_pred_s.extend(seg_pred_s)
            total_seg_gt_u.extend(seg_gt_u)
            total_seg_pred_u.extend(seg_pred_u)

            total_gt_e.append(metadata.activity_index[tpg.activity])
            total_pred_e.append(metadata.activity_index[e])

            for i in range(duration):
                total_gt_s[i].extend(gt_s[i])
                total_pred_s[i].extend(pred_s[i])
                total_gt_u[i].extend(gt_u[i])
                total_pred_u[i].extend(pred_u[i])

    print('===================  Frame wise Action detection  ===================')
    print(utils.compute_accuracy(total_gt_s[0], total_pred_s[0], metric='micro'))
    print(utils.compute_accuracy(total_gt_s[0], total_pred_s[0], metric='macro'))

    print('=================  Frame wise Affordance detection  =================')
    print(utils.compute_accuracy(total_gt_u[0], total_pred_u[0], metric='micro'))
    print(utils.compute_accuracy(total_gt_u[0], total_pred_u[0], metric='macro'))

    print('===================  Segment wise prediction  ===================')
    print(utils.compute_accuracy(total_seg_gt_s, total_seg_pred_s, metric='micro'))
    print(utils.compute_accuracy(total_seg_gt_s, total_seg_pred_s, metric='macro'))

    eval_gt_frame_s = list()
    eval_pred_frame_s = list()
    eval_gt_frame_u = list()
    eval_pred_frame_u = list()
    for i in range(duration):
        eval_gt_frame_s.extend(total_gt_s[i])
        eval_pred_frame_s.extend(total_pred_s[i])
        eval_gt_frame_u.extend(total_gt_u[i])
        eval_pred_frame_u.extend(total_pred_u[i])
    print('===================  Frame wise action prediction  ===================')
    print(utils.compute_accuracy(eval_gt_frame_s, eval_pred_frame_s, metric='micro'))
    print(utils.compute_accuracy(eval_gt_frame_s, eval_pred_frame_s, metric='macro'))

    print('===================  Frame wise affordance prediction  ===================')
    print(utils.compute_accuracy(eval_gt_frame_u, eval_pred_frame_u, metric='micro'))
    print(utils.compute_accuracy(eval_gt_frame_u, eval_pred_frame_u, metric='macro'))
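
Note: the compute_accuracy helper used throughout these examples is not included. Judging from the call sites, metric='micro' returns a single precision score while metric='macro' (or 'weighted') returns a (precision, recall, f1) triple; a plausible sketch on top of scikit-learn, offered as an assumption rather than the repo's implementation:

from sklearn.metrics import precision_recall_fscore_support

def compute_accuracy(gt_labels, pred_labels, metric='micro'):
    # 'micro' returns one precision score; 'macro'/'weighted' return a triple
    if metric == 'micro':
        prec, _, _, _ = precision_recall_fscore_support(
            gt_labels, pred_labels, average='micro', zero_division=0)
        return prec
    prec, rec, f1, _ = precision_recall_fscore_support(
        gt_labels, pred_labels, average=metric, zero_division=0)
    return prec, rec, f1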
Example #3
File: detect.py Project: tf369/GEP_PAMI
def validate(data_loader, model, args, test=False, save=False):
    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    all_labels = list()
    all_gt_labels = list()

    count = 0
    # switch to evaluate mode
    model.eval()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader,
                 desc='Batch Loop Validation'
                 if not test else 'Batch Loop Testing')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        batch_num = features_batch.size(1)
        count += torch.sum(total_lengths)

        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        output = model(features_batch)
        # Save raw per-sequence outputs unconditionally; gep.py later loads
        # them by this same '{}_out_s{}_b{}_c{}.npy' naming scheme
        for batch_i in range(output.size()[1]):
            save_output = output[:int(total_lengths[batch_i]),
                                 batch_i, :].squeeze().cpu().data.numpy()
            if not os.path.exists(args.save_path):
                os.makedirs(args.save_path)
            np.save(
                os.path.join(
                    args.save_path,
                    '{}_out_s{}_b{}_c{}.npy'.format(sequence_ids[batch_i],
                                                    args.subsample,
                                                    args.batch_size,
                                                    args.epochs)), save_output)

        _, pred_labels = torch.max(output, dim=-1)

        for in_batch_idx in range(batch_num):
            detections = pred_labels[:,
                                     in_batch_idx].cpu().data.numpy().flatten(
                                     ).tolist()
            if args.subsample != 1:
                all_total_labels, all_total_lengths = additional
                gt_detections = all_total_labels[:all_total_lengths[
                    in_batch_idx], in_batch_idx].flatten().tolist()
                video_length = len(gt_detections)
                detections = evalutils.upsample(detections,
                                                freq=args.subsample,
                                                length=video_length)
            else:
                gt_detections = labels_batch[:total_lengths[in_batch_idx],
                                             in_batch_idx].cpu().data.numpy(
                                             ).flatten().tolist()
                detections = detections[:total_lengths[in_batch_idx]]
                video_length = len(gt_detections)
            all_labels.extend(detections)
            all_gt_labels.extend(gt_detections)

            micro_prec = logutils.compute_accuracy(gt_detections, detections)
            macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                gt_detections, detections, metric='macro')
            task_acc_ratio.update(micro_prec, video_length)
            task_macro_prec.update(macro_prec, video_length)
            task_macro_rec.update(macro_rec, video_length)
            task_macro_f1.update(macro_f1, video_length)

        if not test:
            if (batch_idx + 1) % args.log_interval == 0:
                tqdm.write('[Validation] Task {} {} Batch [{}/{}]\t'
                           'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                           'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                           'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                           'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                               args.task,
                               'val' if not test else 'test',
                               batch_idx + 1,
                               len(data_loader),
                               top1=task_acc_ratio,
                               prec=task_macro_prec,
                               recall=task_macro_rec,
                               f1=task_macro_f1))

        if args.task == 'affordance':
            output = output.view(
                (output.size(0), batch_num, -1, output.size(-1)))
        if save:
            for batch_i in range(output.size()[1]):
                if args.task == 'affordance':
                    model_probs = torch.nn.Softmax(dim=-1)(
                        output[:int(total_lengths[batch_i]),
                               batch_i, :, :].squeeze()).cpu().data.numpy()
                else:
                    model_probs = torch.nn.Softmax(dim=-1)(
                        output[:int(total_lengths[batch_i]),
                               batch_i, :].squeeze()).cpu().data.numpy()
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                np.save(
                    os.path.join(args.save_path,
                                 '{}.npy'.format(sequence_ids[batch_i])),
                    model_probs)
    print(count)  # total number of frames evaluated
    micro_prec = logutils.compute_accuracy(all_gt_labels, all_labels)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_labels, all_labels, metric='macro')
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, macro_prec, macro_recall,
                                          macro_fscore))
    return micro_prec
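
Note: when args.subsample != 1, predictions made on every subsample-th frame are stretched back to the full video length by evalutils.upsample before scoring. Its implementation is not shown; a minimal sketch assuming each label is repeated freq times and the result padded or trimmed to length:

def upsample(labels, freq, length):
    # Repeat each subsampled label `freq` times to approximate the
    # full-frame-rate sequence
    upsampled = []
    for label in labels:
        upsampled.extend([label] * freq)
    # Pad with the final label or trim so the result has exactly `length`
    # entries, matching the ground-truth sequence
    if len(upsampled) < length:
        upsampled.extend([labels[-1]] * (length - len(upsampled)))
    return upsampled[:length]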
Example #4
def validate(data_loader, model, args):
    all_gt_detections = list()
    all_detections = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        # Build a near-one-hot log-probability map from the ground-truth
        # labels: log(1e-4) everywhere, with the remaining probability mass
        # placed at the labelled action, then soften with a temperature softmax
        epsilon = torch.log(torch.tensor(1e-4))
        maximum = torch.log(
            torch.tensor(1 - 1e-4 * (len(args.metadata.actions) - 1)))
        model_outputs = torch.ones(
            (features_batch.size(0), features_batch.size(1),
             len(args.metadata.actions))) * epsilon
        model_outputs = model_outputs.scatter_(
            2,
            labels_batch.type(torch.LongTensor).unsqueeze(1), maximum)
        model_outputs = F.softmax(model_outputs / args.temperature, dim=-1)
        # model_outputs = torch.ones((features_batch.size(0), features_batch.size(1), len(args.metadata.actions))) / len(args.metadata.actions)

        # Inference
        tqdm.write('[{}] Inference'.format(sequence_ids[0]))
        _, nn_pred_labels = torch.max(model_outputs, dim=-1)
        nn_detections = nn_pred_labels.cpu().data.numpy().flatten().tolist()
        pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
            model_outputs, activities, sequence_ids, args)
        # Evaluation
        # Frame-wise detection (assumes batch size 1, hence the [0] indexing)
        detections = [
            l for pred_labels in batch_earley_pred_labels
            for l in pred_labels.tolist()
        ]
        if args.subsample != 1:
            all_total_labels, all_total_lengths = additional
            gt_detections = all_total_labels[:all_total_lengths[0]].flatten(
            ).tolist()
            video_length = len(gt_detections)

            detections = evalutils.upsample(detections,
                                            freq=args.subsample,
                                            length=video_length)
            nn_detections = evalutils.upsample(nn_detections,
                                               freq=args.subsample,
                                               length=video_length)
        else:
            gt_detections = labels_batch[:total_lengths[0]].cpu().data.numpy(
            ).flatten().tolist()
            detections = detections[:total_lengths[0]]
        video_length = len(gt_detections)

        # vizutils.plot_segmentation([gt_detections, nn_detections, detections], video_length,
        #                            filename=os.path.join(args.paths.visualize_root, '{}.jpg'.format(sequence_ids[0])), border=False)

        micro_prec = logutils.compute_accuracy(gt_detections, detections)
        micro_prec_nn = logutils.compute_accuracy(gt_detections, nn_detections)
        macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
            gt_detections, detections, metric='macro')
        task_acc_ratio.update(micro_prec, video_length)
        task_acc_ratio_nn.update(micro_prec_nn, video_length)
        task_macro_prec.update(macro_prec, video_length)
        task_macro_rec.update(macro_rec, video_length)
        task_macro_f1.update(macro_f1, video_length)

        all_gt_detections.extend(gt_detections)
        all_detections.extend(detections)

        micro_prec = logutils.compute_accuracy(all_gt_detections,
                                               all_detections)
        macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
            all_gt_detections, all_detections, metric='macro')
        tqdm.write('[Evaluation] Micro Prec: {}\t'
                   'Macro Precision: {}\t'
                   'Macro Recall: {}\t'
                   'Macro F-score: {}'.format(micro_prec, macro_prec,
                                              macro_recall, macro_fscore))

    micro_prec = logutils.compute_accuracy(all_gt_detections, all_detections)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_detections, all_detections, metric='macro')
    tqdm.write('Detection:\n'
               'Micro Prec: {}\t'
               'NN Prec:{}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}\n\n'.format(micro_prec,
                                              task_acc_ratio_nn.avg,
                                              macro_prec, macro_recall,
                                              macro_fscore))
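
Note: the epsilon/maximum/scatter_ block at the top of this example turns ground-truth labels into a peaked log-probability map and then softens it with a temperature softmax. A standalone toy illustration of the same trick (class count, labels, and temperature below are made-up values):

import torch
import torch.nn.functional as F

C = 5                                  # toy number of action classes
labels = torch.tensor([[2], [0]])      # (T=2, B=1) ground-truth labels
eps = 1e-4
log_eps = torch.log(torch.tensor(eps)).item()
log_peak = torch.log(torch.tensor(1 - eps * (C - 1))).item()
logits = torch.full((2, 1, C), log_eps)
logits.scatter_(2, labels.unsqueeze(-1), log_peak)  # place the peak at the label
probs = F.softmax(logits / 1.0, dim=-1)             # temperature = 1.0 here
print(probs[0, 0])                                  # mass concentrated on class 2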
Example #5
File: inference.py Project: tf369/GEP_PAMI
def infer(paths, gt_tpg, priors, grammar_dict, languages, duration):
    gt_subactivity, gt_objects, gt_affordance = get_ground_truth_label(gt_tpg)
    likelihoods = get_intermediate_results(paths, gt_tpg)
    obj_num = gt_objects.shape[0]

    # Segmentation
    # dp_start_time = time.time()
    trace_begin, trace_a, trace_o, trace_u, trace_s = dp_segmentation(priors, likelihoods)
    ttrace_begin, ttrace_a, ttrace_s = dp_segmentation_s(priors, likelihoods)
    # print('DP segmentation time elapsed: {}'.format(time.time() - dp_start_time))

    # Labels for evaluation
    seg_gt_s = list()
    seg_pred_s = list()
    seg_gt_u = list()
    seg_pred_u = list()

    gt_s = [list() for _ in range(duration)]
    pred_s = [list() for _ in range(duration)]
    gt_u = [list() for _ in range(duration)]
    pred_u = [list() for _ in range(duration)]

    for end_frame in range(1, int(trace_begin.shape[0])):
    # for end_frame in range(10, 350, 10):
    # for end_frame in [350]:
        # Gibbs sampling to refine the parsing
        tpg = generate_parse_graph(trace_begin, trace_a, trace_o, trace_u, trace_s, gt_tpg.terminals[0].start_frame, end_frame)
        # print str(gt_tpg), tpg_to_tokens(tpg, np.inf)
        # vizutil.visualize_tpg_labeling(gt_subactivity, gt_affordance, tpg, obj_num, end_frame)
        tpg.activity = gt_tpg.activity
        tpg = gibbs_sampling(tpg, grammar_dict, languages, priors, likelihoods)
        # vizutil.visualize_tpg_labeling(gt_subactivity, gt_affordance, tpg, obj_num, end_frame)

        # Prediction
        predicted_tpg = predict(grammar_dict, languages, tpg, end_frame, duration, priors, likelihoods)
        # vizutil.visualize_tpg_labeling(gt_subactivity, gt_affordance, predicted_tpg, obj_num, end_frame+duration)

        # Labels for evaluation
        get_next_subactivity_label(gt_tpg, predicted_tpg, seg_gt_s, seg_pred_s, end_frame)
        get_next_affordance_label(gt_tpg, predicted_tpg, seg_gt_u, seg_pred_u, obj_num, end_frame)

        subactivities, actions, objects, affordance = get_label(predicted_tpg, obj_num)
        pred_end_frame = np.min([subactivities.shape[0], gt_subactivity.shape[0], end_frame-1+duration])
        # print subactivities.shape, actions.shape, objects.shape, affordance.shape
        # print gt_subactivity.shape, gt_objects.shape, gt_affordance.shape
        for f in range(end_frame-1, pred_end_frame):
            gt_s[f-end_frame+1].append(gt_subactivity[f])
            pred_s[f-end_frame+1].append(subactivities[f])
            for io in range(obj_num):
                gt_u[f-end_frame+1].append(gt_affordance[io, f])
                pred_u[f-end_frame+1].append(affordance[io, f])

    print(gt_tpg.activity, tpg.activity, predicted_tpg.activity)
    print(str(gt_tpg))
    print(tpg_to_tokens(tpg, np.inf))

    final_pred_seg = np.zeros_like(seg_gt_s)
    for spg in predicted_tpg.terminals:
        final_pred_seg[spg.start_frame: spg.end_frame + 1] = spg.action

    # action_log_likelihood = np.load(os.path.join(paths.inter_root, 'likelihood', 'activity',
    #                                              gt_tpg.activity + '$' + gt_tpg.id + '.npy')).T
    # lstm_pred_s = np.argmax(action_log_likelihood, axis=0).tolist()
    print('Action detection macro evaluation:', utils.compute_accuracy(gt_s[0], pred_s[0], metric='macro'))
    print('Affordance detection macro evaluation:', utils.compute_accuracy(gt_u[0], pred_u[0], metric='macro'))
    return seg_gt_s, seg_pred_s, seg_gt_u, seg_pred_u, gt_s, pred_s, gt_u, pred_u, predicted_tpg.activity
Example #6
def validate(data_loader, detection_model, prediction_model, args):
    all_gt_frame_predictions = list()
    all_frame_predictions = list()
    all_nn_frame_predictions = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    detection_model.eval()
    prediction_model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit

        padding = features_batch[0, :, :].repeat(args.using_pred_duration - 1,
                                                 1, 1)
        prediction_features = torch.cat((padding, features_batch), dim=0)
        prediction_output = prediction_model(prediction_features)
        detection_output = detection_model(features_batch)

        _, detection_labels = torch.max(detection_output, dim=-1)
        detection_labels = detection_labels.cpu().numpy()

        for batch_i in range(detection_output.size(1)):

            gt_all_pred_labels = labels_batch[1:total_lengths[batch_i],
                                              batch_i].cpu().numpy().tolist()
            _, nn_all_pred_labels = torch.max(
                prediction_output[:total_lengths[batch_i] - 1, batch_i, :],
                dim=-1)
            nn_all_pred_labels = nn_all_pred_labels.cpu().numpy().tolist()

            # Initialization of Earley Parser
            class_num = detection_output.shape[2]
            grammar_file = os.path.join(args.paths.grammar_root,
                                        activities[batch_i] + '.pcfg')
            grammar = grammarutils.read_grammar(grammar_file, index=True)
            gen_earley_parser = GEP.GeneralizedEarley(
                grammar, class_num, mapping=args.metadata.action_index)
            with open(
                    os.path.join(args.paths.prior_root,
                                 'duration_prior.json')) as f:
                duration_prior = json.load(f)

            record = dict()

            start_time = time.time()
            for frame in range(total_lengths[batch_i] -
                               args.using_pred_duration):
                nn_pred_labels = nn_all_pred_labels[frame:frame +
                                                    args.using_pred_duration]
                gt_pred_labels = gt_all_pred_labels[frame:frame +
                                                    args.using_pred_duration]
                update_length = len(nn_pred_labels)

                pred_labels = predict(gen_earley_parser,
                                      detection_output[frame, batch_i, :],
                                      duration_prior, record, frame, args)
                # gt = torch.ones(detection_output.size(2)) * 1e-5
                # gt[labels_batch[frame, batch_i]] = 1
                # gt = torch.log(gt / torch.sum(gt))
                # pred_labels = predict(gen_earley_parser, gt,
                #                       duration_prior, record, frame, args)
                # print(frame)
                # print('detection_labels', detection_labels[max(0, frame - 44) : frame + 1, batch_i].tolist())
                # print('gt_detect labels', labels_batch[max(0, frame - 44) :frame+1, batch_i].cpu().numpy().tolist())
                # print('gt_predic_labels', gt_pred_labels)
                # print('nn_predic_labels', nn_pred_labels)
                # print('xx_predic_labels', pred_labels)

                micro_prec = logutils.compute_accuracy(gt_pred_labels,
                                                       pred_labels)
                nn_micro_prec = logutils.compute_accuracy(
                    gt_pred_labels, nn_pred_labels)
                macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                    gt_pred_labels, pred_labels, metric='macro')
                task_acc_ratio.update(micro_prec, update_length)
                task_acc_ratio_nn.update(nn_micro_prec, update_length)
                task_macro_prec.update(macro_prec, update_length)
                task_macro_rec.update(macro_rec, update_length)
                task_macro_f1.update(macro_f1, update_length)

                all_gt_frame_predictions.extend(gt_pred_labels)
                all_frame_predictions.extend(pred_labels)
                all_nn_frame_predictions.extend(nn_pred_labels)

            print(time.time() - start_time)

        tqdm.write('Task {} {} Batch [{}/{}]\t'
                   'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                   'NN Acc {nn.val:.4f} ({nn.avg:.4f})\t'
                   'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                   'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                   'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                       args.task,
                       'test',
                       batch_idx,
                       len(data_loader),
                       top1=task_acc_ratio,
                       nn=task_acc_ratio_nn,
                       prec=task_macro_prec,
                       recall=task_macro_rec,
                       f1=task_macro_f1))

    micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                           all_frame_predictions)
    nn_micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                              all_nn_frame_predictions)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_frame_predictions, all_frame_predictions, metric='macro')
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'NN Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, nn_micro_prec,
                                          macro_prec, macro_recall,
                                          macro_fscore))
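
Note: Examples #6 and #7 prepend (args.using_pred_duration - 1) copies of the first frame so the prediction model has a full input window from frame 0. A toy shape check of that padding trick (all dimensions below are made up):

import torch

T, B, D = 10, 2, 4                     # sequence length, batch, feature dim
using_pred_duration = 3
features_batch = torch.randn(T, B, D)
# Repeat frame 0 along the time axis and prepend it
padding = features_batch[0, :, :].repeat(using_pred_duration - 1, 1, 1)
prediction_features = torch.cat((padding, features_batch), dim=0)
print(prediction_features.shape)       # torch.Size([12, 2, 4])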
Example #7
def validate(data_loader, detection_model, prediction_model, args):
    all_gt_frame_predictions = list()
    all_frame_predictions = list()
    all_nn_frame_predictions = list()

    task_acc_ratio = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    detection_model.eval()
    prediction_model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        detection_likelihood = torch.nn.Softmax(dim=-1)(
            detection_model(features_batch)).data.cpu().numpy()

        padding = features_batch[0, :, :].repeat(args.using_pred_duration - 1,
                                                 1, 1)
        prediction_features = torch.cat((padding, features_batch), dim=0)
        prediction_output = prediction_model(prediction_features)
        prediction_likelihood = torch.nn.Softmax(
            dim=-1)(prediction_output).data.cpu().numpy()

        for batch_i in range(features_batch.size(1)):
            _, pred_labels = torch.max(
                prediction_output[:total_lengths[batch_i] - 1, batch_i, :],
                dim=-1)
            # Slice into a new name: overwriting prediction_likelihood here
            # would corrupt later iterations of the batch loop
            seq_prediction_likelihood = prediction_likelihood[:total_lengths[
                batch_i] - 1, batch_i, :]

            skip_size = args.using_pred_duration - args.pred_duration

            # for frame in range(0, total_lengths[batch_i]-1, skip_size):
            for frame in range(
                    0, total_lengths[batch_i] - args.using_pred_duration,
                    skip_size):
                det = detection_likelihood[:frame + 1, batch_i, :]
                # det = detection_likelihood[:frame+1+args.using_pred_duration, batch_i, :]
                gt_det = torch.zeros(det.shape)
                gt_det.scatter_(1, labels_batch[:frame + 1,
                                                batch_i].unsqueeze(1), 1)
                gt_det = gt_det * 0.95 + (0.05 / 10) * torch.ones(det.shape)
                gt_det = gt_det.numpy()

                pred = seq_prediction_likelihood[frame:frame +
                                                 args.using_pred_duration, :]
                prob_mat = np.concatenate((det, pred), axis=0)
                pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
                    prob_mat, activities[batch_i], sequence_ids[batch_i], args)

                # Testing
                gep_predictions = batch_earley_pred_labels[
                    frame + 1:frame + args.using_pred_duration + 1]
                all_frame_predictions.extend(gep_predictions)
                nn_frame_predictions = pred_labels[frame + 1:frame +
                                                   args.using_pred_duration +
                                                   1]
                all_nn_frame_predictions.extend(nn_frame_predictions)
                gt_frame_predictions = labels_batch[
                    frame + 1:frame + args.using_pred_duration + 1,
                    batch_i].cpu().numpy().tolist()
                all_gt_frame_predictions.extend(gt_frame_predictions)

                video_length = len(gt_frame_predictions)
                micro_prec_nn = logutils.compute_accuracy(
                    gt_frame_predictions, nn_frame_predictions)
                task_acc_ratio_nn.update(micro_prec_nn, video_length)

            micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                                   all_frame_predictions)
            nn_micro_prec = logutils.compute_accuracy(
                all_gt_frame_predictions, all_nn_frame_predictions)
            macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
                all_gt_frame_predictions,
                all_frame_predictions,
                metric='macro')
            tqdm.write('[Evaluation] Micro Prec: {}\t'
                       'NN Precision: {}\t'
                       'Macro Precision: {}\t'
                       'Macro Recall: {}\t'
                       'Macro F-score: {}'.format(micro_prec, nn_micro_prec,
                                                  macro_prec, macro_recall,
                                                  macro_fscore))
Example #8
File: gep.py Project: tf369/GEP_PAMI
def validate(data_loader, model, args):
    all_gt_detections = list()
    all_detections = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        print(
            os.path.join(
                args.save_path,
                '{}_out_s{}_b{}_c{}.npy'.format(sequence_ids[0],
                                                args.subsample,
                                                args.using_batch_size,
                                                args.trained_epochs)))
        # exit()
        # Load the per-sequence model outputs saved by validate() in
        # detect.py (same '{}_out_s{}_b{}_c{}.npy' naming scheme)
        model_outputs = torch.tensor(
            np.load(
                os.path.join(
                    args.save_path, '{}_out_s{}_b{}_c{}.npy'.format(
                        sequence_ids[0], args.subsample, args.using_batch_size,
                        args.trained_epochs)))).unsqueeze(1)

        # Inference
        tqdm.write('[{}] Inference'.format(sequence_ids[0]))

        seg_path = os.path.join(args.paths.inter_root, 'segmentation')
        if not os.path.exists(seg_path):
            os.makedirs(seg_path)

        # # If no prior model outputs are provided
        # if not os.path.isfile(os.path.join(seg_path, '{}.npy'.format(sequence_ids[0]))):
        #     _, nn_pred_labels = torch.max(model_outputs, dim=-1)
        #     nn_detections = nn_pred_labels.cpu().data.numpy().flatten().tolist()
        #     pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(model_outputs, activities, sequence_ids, args)
        #
        #     # Evaluation
        #     # Frame-wise detection
        #     detections = [l for pred_labels in batch_earley_pred_labels for l in pred_labels.tolist()]
        #     if args.subsample != 1:
        #         all_total_labels, all_total_lengths = additional
        #         gt_detections = all_total_labels[:all_total_lengths[0]].flatten().tolist()
        #         video_length = len(gt_detections)
        #
        #         detections = evalutils.upsample(detections, freq=args.subsample, length=video_length)
        #         nn_detections = evalutils.upsample(nn_detections, freq=args.subsample, length=video_length)
        #     else:
        #         gt_detections = labels_batch[:total_lengths[0]].cpu().data.numpy().flatten().tolist()
        #         detections = detections[:total_lengths[0]]
        #     np.save(os.path.join(args.paths.inter_root, 'segmentation', '{}.npy'.format(sequence_ids[0])),
        #             [gt_detections, nn_detections, detections])
        # else:
        #     results = np.load(os.path.join(seg_path, '{}.npy'.format(sequence_ids[0])))
        #     gt_detections, nn_detections, detections = results[0], results[1], results[2]

        _, nn_pred_labels = torch.max(model_outputs, dim=-1)
        nn_detections = nn_pred_labels.cpu().data.numpy().flatten().tolist()
        pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
            model_outputs, activities, sequence_ids, args)

        # Evaluation
        # Frame-wise detection
        detections = [
            l for pred_labels in batch_earley_pred_labels
            for l in pred_labels.tolist()
        ]
        if args.subsample != 1:
            all_total_labels, all_total_lengths = additional
            gt_detections = all_total_labels[:all_total_lengths[0]].flatten(
            ).tolist()
            video_length = len(gt_detections)

            detections = evalutils.upsample(detections,
                                            freq=args.subsample,
                                            length=video_length)
            nn_detections = evalutils.upsample(nn_detections,
                                               freq=args.subsample,
                                               length=video_length)
        else:
            gt_detections = labels_batch[:total_lengths[0]].cpu().data.numpy(
            ).flatten().tolist()
            detections = detections[:total_lengths[0]]
        video_length = len(gt_detections)

        # # Visualization code for figures
        # vizutils.plot_segmentation([gt_detections, nn_detections, detections], video_length,
        #                            filename=os.path.join(args.paths.visualize_root, '{}.jpg'.format(sequence_ids[0])), border=False)

        micro_prec = logutils.compute_accuracy(gt_detections, detections)
        micro_prec_nn = logutils.compute_accuracy(gt_detections, nn_detections)
        macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
            gt_detections, detections, metric='macro')
        task_acc_ratio.update(micro_prec, video_length)
        task_acc_ratio_nn.update(micro_prec_nn, video_length)
        task_macro_prec.update(macro_prec, video_length)
        task_macro_rec.update(macro_rec, video_length)
        task_macro_f1.update(macro_f1, video_length)

        all_gt_detections.extend(gt_detections)
        all_detections.extend(detections)

        micro_prec = logutils.compute_accuracy(all_gt_detections,
                                               all_detections)
        macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
            all_gt_detections, all_detections, metric='macro')
        tqdm.write('[Evaluation] Micro Prec: {}\t'
                   'Macro Precision: {}\t'
                   'Macro Recall: {}\t'
                   'Macro F-score: {}'.format(micro_prec, macro_prec,
                                              macro_recall, macro_fscore))

    micro_prec = logutils.compute_accuracy(all_gt_detections, all_detections)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_detections, all_detections, metric='macro')
    tqdm.write('Detection:\n'
               'Micro Prec: {}\t'
               'NN Prec:{}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}\n\n'.format(micro_prec,
                                              task_acc_ratio_nn.avg,
                                              macro_prec, macro_recall,
                                              macro_fscore))
Example #9
File: finetune.py Project: tf369/GEP_PAMI
def train(data_loader, model, criterion, optimizer, epoch, args):
    logger = utils.Logger()
    model.train()
    start_time = time.time()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='Batch Loop Training')):
        logger.data_time.update(time.time() - start_time)
        sequence_id, rgb_image, depth_image, aligned_image, activity, object_labels, object_images, affordance, skeleton, affordance_features = data_unit

        if args.task == 'affordance':
            for object_id in range(object_labels.size(1)):
                # Batch_size * object_num * 3 * 224 * 224
                object_name = metadata.objects[np.argmax(
                    object_labels[0, object_id].numpy())]
                object_image = object_images[:, object_id, :, :, :].squeeze(
                    1).cuda()
                # affordance (batch_size, )
                affordance_label = affordance[:, object_id].cuda()
                feature, output = model(object_image)
                loss = criterion(output, affordance_label)
                _, pred_labels = torch.max(torch.nn.Softmax(dim=-1)(output),
                                           dim=-1)
                pred_labels = pred_labels.cpu().data.numpy().flatten().tolist()
                gt_labels = affordance_label.cpu().data.numpy().flatten(
                ).tolist()
                prec = utils.compute_accuracy(gt_labels,
                                              pred_labels,
                                              metric='micro')
                logger.multi_losses.update(object_name, loss.item(),
                                           len(sequence_id))
                logger.top1.update(object_name, prec, len(sequence_id))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        else:
            rgb_image = rgb_image.cuda()
            affordance_features = affordance_features.cuda()
            activity = activity.squeeze(1).cuda()
            feature, output = model(rgb_image, affordance_features)
            loss = criterion(output, activity)
            _, pred_labels = torch.max(torch.nn.Softmax(dim=-1)(output),
                                       dim=-1)
            pred_labels = pred_labels.cpu().data.numpy().flatten().tolist()
            gt_labels = activity.cpu().data.numpy().flatten().tolist()
            prec = utils.compute_accuracy(gt_labels,
                                          pred_labels,
                                          metric='micro')
            logger.multi_losses.update('Activity', loss.item(),
                                       len(sequence_id))
            logger.top1.update('Activity', prec, len(sequence_id))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        logger.batch_time.update(time.time() - start_time)
        start_time = time.time()

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('Task {} Epoch: [{}][{}/{}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val} ({top1.avg})'.format(
                           args.task,
                           epoch,
                           batch_idx,
                           len(data_loader),
                           batch_time=logger.batch_time,
                           data_time=logger.data_time,
                           loss=logger.multi_losses,
                           top1=logger.top1))
Example #10
File: finetune.py Project: tf369/GEP_PAMI
def validate(data_loader, model, args, test=False, save=False):
    logger = utils.Logger()
    model.eval()
    start_time = time.time()
    if save and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    all_detections = list()
    all_gt_detections = list()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader,
                 desc='Batch Loop Validating'
                 if not test else 'Batch Loop Testing')):
        logger.data_time.update(time.time() - start_time)
        sequence_id, rgb_image, depth_image, aligned_image, activity, object_labels, object_images, affordance, skeleton, affordance_features = data_unit
        features = None
        if args.task == 'affordance':
            for object_id in range(object_labels.size(1)):
                # Batch_size * object_num * 3 * 224 * 224
                object_name = metadata.objects[np.argmax(
                    object_labels[0, object_id].numpy())]
                object_image = object_images[:, object_id, :, :, :].squeeze(
                    1).cuda()
                # affordance (batch_size, )
                affordance_label = affordance[:, object_id].cuda()
                feature, output = model(object_image)
                _, pred_labels = torch.max(torch.nn.Softmax(dim=-1)(output),
                                           dim=-1)
                pred_labels = pred_labels.cpu().data.numpy().flatten().tolist()
                gt_labels = affordance_label.cpu().data.numpy().flatten(
                ).tolist()
                all_detections.extend(pred_labels)
                all_gt_detections.extend(gt_labels)
                prec = utils.compute_accuracy(gt_labels,
                                              pred_labels,
                                              metric='micro')
                logger.top1.update(object_name, prec, len(sequence_id))
                if save:
                    feature = feature.detach().cpu().numpy()
                    if features is None:
                        features = feature
                    else:
                        features = np.vstack((features, feature))
        else:
            rgb_image = rgb_image.cuda()
            affordance_features = affordance_features.cuda()
            feature, output = model(rgb_image, affordance_features)
            _, pred_labels = torch.max(torch.nn.Softmax(dim=-1)(output),
                                       dim=-1)
            pred_labels = pred_labels.cpu().data.numpy().flatten().tolist()
            gt_labels = activity.cpu().data.numpy().flatten().tolist()
            all_detections.extend(pred_labels)
            all_gt_detections.extend(gt_labels)
            prec = utils.compute_accuracy(gt_labels,
                                          pred_labels,
                                          metric='micro')
            logger.top1.update('Activity', prec, len(sequence_id))
            if save:
                features = feature.detach().cpu().numpy()

        if save:
            assert (len(sequence_id) == 1)
            np.save(os.path.join(args.save_path, sequence_id[0] + '.npy'),
                    features)

        logger.batch_time.update(time.time() - start_time)
        start_time = time.time()
        if not test:
            if (batch_idx + 1) % args.log_interval == 0:
                tqdm.write('Task {} Test: [{}/{}]\t'
                           'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                               args.task,
                               batch_idx,
                               len(data_loader),
                               batch_time=logger.batch_time,
                               data_time=logger.data_time,
                               top1=logger.top1))
    if test:
        micro_prec = utils.compute_accuracy(all_gt_detections, all_detections)
        macro_prec, macro_recall, macro_fscore = utils.compute_accuracy(
            all_gt_detections, all_detections, metric='macro')
        tqdm.write('Micro Prec: {}\t'
                   'Macro Precision: {}\t'
                   'Macro Recall: {}\t'
                   'Macro F-score: {}'.format(micro_prec, macro_prec,
                                              macro_recall, macro_fscore))
    return logger.top1.avg
Example #11
def train(data_loader, model, criterion, optimizer, epoch, args):
    logger = utils.Logger()
    model.train()
    start_time = time.time()
    task_acc_ratio = utils.AverageMeter()
    task_macro_prec = utils.AverageMeter()
    task_macro_rec = utils.AverageMeter()
    task_macro_f1 = utils.AverageMeter()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='Batch Loop Training')):
        logger.data_time.update(time.time() - start_time)
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        padding = features_batch[0, :, :].repeat(args.pred_duration - 1, 1, 1)
        features = torch.cat((padding, features_batch), dim=0)
        output = model(features)
        loss = 0
        for batch_i in range(features.size(1)):
            gt_pred_labels = labels_batch[1:total_lengths[batch_i], batch_i]
            _, pred_labels = torch.max(output[:total_lengths[batch_i] - 1,
                                              batch_i],
                                       dim=-1)
            loss += criterion(output[:total_lengths[batch_i] - 1, batch_i],
                              gt_pred_labels)
            gt_pred_labels = gt_pred_labels.cpu().numpy().tolist()
            pred_labels = pred_labels.cpu().numpy().tolist()
            video_length = len(gt_pred_labels)

            logger.losses.update(loss.item(), video_length)
            micro_prec = utils.compute_accuracy(gt_pred_labels, pred_labels)
            macro_prec, macro_rec, macro_f1 = utils.compute_accuracy(
                gt_pred_labels, pred_labels, metric='macro')
            task_acc_ratio.update(micro_prec, video_length)
            task_macro_prec.update(macro_prec, video_length)
            task_macro_rec.update(macro_rec, video_length)
            task_macro_f1.update(macro_f1, video_length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.batch_time.update(time.time() - start_time)
        start_time = time.time()

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('Task {} Epoch: [{}][{}/{}]\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                       'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                       'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                       'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                           args.task,
                           epoch,
                           batch_idx,
                           len(data_loader),
                           loss=logger.losses,
                           top1=task_acc_ratio,
                           prec=task_macro_prec,
                           recall=task_macro_rec,
                           f1=task_macro_f1))
Example #12
def validate(data_loader, model, args, test=False):
    task_acc_ratio = utils.AverageMeter()
    task_macro_prec = utils.AverageMeter()
    task_macro_rec = utils.AverageMeter()
    task_macro_f1 = utils.AverageMeter()
    all_labels = list()
    all_gt_labels = list()

    # switch to evaluate mode
    model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader,
                 desc='Batch Loop Validation'
                 if not test else 'Batch Loop Testing')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        padding = features_batch[0, :, :].repeat(args.pred_duration - 1, 1, 1)
        features = torch.cat((padding, features_batch), dim=0)
        output = model(features)
        for batch_i in range(features.size(1)):
            gt_pred_labels = labels_batch[1:total_lengths[batch_i], batch_i]
            _, pred_labels = torch.max(output[:total_lengths[batch_i] - 1,
                                              batch_i, :],
                                       dim=-1)
            gt_pred_labels = gt_pred_labels.cpu().numpy().tolist()
            pred_labels = pred_labels.cpu().numpy().tolist()

            # Score each prediction window against its ground-truth window
            for frame in range(total_lengths[batch_i] - 1):
                gt_window = gt_pred_labels[frame:frame + args.pred_duration]
                pred_window = pred_labels[frame:frame + args.pred_duration]
                video_length = len(gt_window)
                all_gt_labels.extend(gt_window)
                all_labels.extend(pred_window)
                micro_prec = utils.compute_accuracy(gt_window, pred_window)
                macro_prec, macro_rec, macro_f1 = utils.compute_accuracy(
                    gt_window, pred_window, metric='macro')
                task_acc_ratio.update(micro_prec, video_length)
                task_macro_prec.update(macro_prec, video_length)
                task_macro_rec.update(macro_rec, video_length)
                task_macro_f1.update(macro_f1, video_length)

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('[Validation] Task {} {} Batch [{}/{}]\t'
                       'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                       'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                       'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                       'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                           args.task,
                           'val' if not test else 'test',
                           batch_idx,
                           len(data_loader),
                           top1=task_acc_ratio,
                           prec=task_macro_prec,
                           recall=task_macro_rec,
                           f1=task_macro_f1))

    micro_prec = utils.compute_accuracy(all_gt_labels, all_labels)
    macro_prec, macro_recall, macro_fscore = utils.compute_accuracy(
        all_gt_labels, all_labels, metric='macro')
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, macro_prec, macro_recall,
                                          macro_fscore))
    return micro_prec