Example No. 1
def train(data_loader, model, criterion, optimizer, epoch, args):
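    """Run one training epoch of frame-wise detection.

    Moves each batch to the GPU when `args.cuda` is set, flattens the
    object dimension for the 'affordance' task, computes the loss via the
    module-level `loss_func`, back-propagates once per batch, and tracks
    micro accuracy and macro precision/recall/F1 with
    `logutils.AverageMeter`, upsampling predictions back to the full
    video length when `args.subsample != 1`.
    """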
    logger = logutils.Logger()
    model.train()
    start_time = time.time()
    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='Batch Loop Training')):
        logger.data_time.update(time.time() - start_time)
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        batch_num = features_batch.size(1)
        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        output = model(features_batch)
        _, pred_labels = torch.max(output, dim=-1)
        loss = loss_func(criterion, output, labels_batch, total_lengths)
        video_length = torch.sum(total_lengths).item()
        logger.losses.update(loss.item(), video_length)

        for in_batch_idx in range(batch_num):
            detections = (pred_labels[:, in_batch_idx]
                          .cpu().data.numpy().flatten().tolist())
            if args.subsample != 1:
                all_total_labels, all_total_lengths = additional
                gt_detections = all_total_labels[:all_total_lengths[
                    in_batch_idx], in_batch_idx].flatten().tolist()
                video_length = len(gt_detections)
                detections = evalutils.upsample(detections,
                                                freq=args.subsample,
                                                length=video_length)
            else:
                gt_detections = (labels_batch[:total_lengths[in_batch_idx],
                                              in_batch_idx]
                                 .cpu().data.numpy().flatten().tolist())
                detections = detections[:total_lengths[in_batch_idx]]
                video_length = len(gt_detections)
            micro_prec = logutils.compute_accuracy(gt_detections, detections)
            macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                gt_detections, detections, metric='macro')
            task_acc_ratio.update(micro_prec, video_length)
            task_macro_prec.update(macro_prec, video_length)
            task_macro_rec.update(macro_rec, video_length)
            task_macro_f1.update(macro_f1, video_length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.batch_time.update(time.time() - start_time)
        start_time = time.time()

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('Task {} Epoch: [{}][{}/{}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                       'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                       'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                       'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                           args.task,
                           epoch,
                           batch_idx,
                           len(data_loader),
                           batch_time=logger.batch_time,
                           data_time=logger.data_time,
                           loss=logger.losses,
                           top1=task_acc_ratio,
                           prec=task_macro_prec,
                           recall=task_macro_rec,
                           f1=task_macro_f1))
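
A minimal sketch of how this training loop might be driven, for context. Everything below is hypothetical scaffolding, not part of the original project: `DetectionNet` and `make_loader` are placeholders, the `args` values are invented, and the project's own `logutils`, `evalutils`, and `loss_func` must be importable for `train` to run.

import argparse

import torch

# Hypothetical args namespace carrying only the fields train() reads.
args = argparse.Namespace(cuda=torch.cuda.is_available(),
                          task='activity',  # or 'affordance'
                          subsample=1,
                          log_interval=10)

model = DetectionNet()  # placeholder frame-wise classifier
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# make_loader() stands in for a DataLoader yielding the 10-element
# tuples unpacked at the top of train().
train_loader = make_loader()
for epoch in range(20):
    train(train_loader, model, criterion, optimizer, epoch, args)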
Example No. 2
def validate(data_loader, model, args):
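    """Evaluate GEP inference on near-oracle inputs.

    Instead of running `model` forward, this variant synthesizes
    near-one-hot log-probability outputs from the ground-truth labels,
    softens them with `args.temperature`, parses them with `inference`,
    and reports micro/macro detection metrics per sequence and over the
    whole dataset.
    """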
    all_gt_detections = list()
    all_detections = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
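        # Synthesize near-one-hot log-probabilities from the ground-truth
        # labels (an oracle stand-in for real model outputs), then soften
        # them with a temperature. The unsqueeze(1) below effectively
        # assumes a batch size of 1.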
        epsilon = torch.log(torch.tensor(1e-4))
        maximum = torch.log(
            torch.tensor(1 - 1e-4 * (len(args.metadata.actions) - 1)))
        model_outputs = torch.ones(
            (features_batch.size(0), features_batch.size(1),
             len(args.metadata.actions))) * epsilon
        model_outputs = model_outputs.scatter_(
            2,
            labels_batch.type(torch.LongTensor).unsqueeze(1), maximum)
        model_outputs = F.softmax(model_outputs / args.temperature, dim=-1)
        # model_outputs = torch.ones((features_batch.size(0), features_batch.size(1), len(args.metadata.actions))) / len(args.metadata.actions)

        # Inference
        tqdm.write('[{}] Inference'.format(sequence_ids[0]))
        _, nn_pred_labels = torch.max(model_outputs, dim=-1)
        nn_detections = nn_pred_labels.cpu().data.numpy().flatten().tolist()
        pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
            model_outputs, activities, sequence_ids, args)
        # Evaluation
        # Frame-wise detection
        detections = [
            l for pred_labels in batch_earley_pred_labels
            for l in pred_labels.tolist()
        ]
        if args.subsample != 1:
            all_total_labels, all_total_lengths = additional
            gt_detections = (all_total_labels[:all_total_lengths[0]]
                             .flatten().tolist())
            video_length = len(gt_detections)

            detections = evalutils.upsample(detections,
                                            freq=args.subsample,
                                            length=video_length)
            nn_detections = evalutils.upsample(nn_detections,
                                               freq=args.subsample,
                                               length=video_length)
        else:
            gt_detections = (labels_batch[:total_lengths[0]]
                             .cpu().data.numpy().flatten().tolist())
            detections = detections[:total_lengths[0]]
        video_length = len(gt_detections)

        # vizutils.plot_segmentation([gt_detections, nn_detections, detections], video_length,
        #                            filename=os.path.join(args.paths.visualize_root, '{}.jpg'.format(sequence_ids[0])), border=False)

        micro_prec = logutils.compute_accuracy(gt_detections, detections)
        micro_prec_nn = logutils.compute_accuracy(gt_detections, nn_detections)
        macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
            gt_detections, detections, metric='macro')
        task_acc_ratio.update(micro_prec, video_length)
        task_acc_ratio_nn.update(micro_prec_nn, video_length)
        task_macro_prec.update(macro_prec, video_length)
        task_macro_rec.update(macro_rec, video_length)
        task_macro_f1.update(macro_f1, video_length)

        all_gt_detections.extend(gt_detections)
        all_detections.extend(detections)

        micro_prec = logutils.compute_accuracy(all_gt_detections,
                                               all_detections)
        macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
            all_gt_detections, all_detections, metric='macro')
        tqdm.write('[Evaluation] Micro Prec: {}\t'
                   'Macro Precision: {}\t'
                   'Macro Recall: {}\t'
                   'Macro F-score: {}'.format(micro_prec, macro_prec,
                                              macro_recall, macro_fscore))

    micro_prec = logutils.compute_accuracy(all_gt_detections, all_detections)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_detections, all_detections, metric='macro')
    tqdm.write('Detection:\n'
               'Micro Prec: {}\t'
               'NN Prec:{}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}\n\n'.format(micro_prec,
                                              task_acc_ratio_nn.avg,
                                              macro_prec, macro_recall,
                                              macro_fscore))
Example No. 3
def validate(data_loader, model, args, test=False, save=False):
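    """Validate (or test) frame-wise detection and dump raw outputs.

    Runs `model` over each batch, saves per-sequence outputs to
    `args.save_path` as .npy files, accumulates micro/macro metrics, and,
    when `save` is set, also writes per-frame softmax probabilities.
    Returns the overall micro precision.
    """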
    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    all_labels = list()
    all_gt_labels = list()

    count = 0
    # switch to evaluate mode
    model.eval()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader,
                 desc='Batch Loop Validation'
                 if not test else 'Batch Loop Testing')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        batch_num = features_batch.size(1)
        count += torch.sum(total_lengths).item()

        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        output = model(features_batch)
        for batch_i in range(output.size()[1]):
            save_output = output[:int(total_lengths[batch_i]),
                                 batch_i, :].squeeze().cpu().data.numpy()
            if not os.path.exists(args.save_path):
                os.makedirs(args.save_path)
            np.save(
                os.path.join(
                    args.save_path,
                    '{}_out_s{}_b{}_c{}.npy'.format(sequence_ids[batch_i],
                                                    args.subsample,
                                                    args.batch_size,
                                                    args.epochs)), save_output)

        _, pred_labels = torch.max(output, dim=-1)

        for in_batch_idx in range(batch_num):
            detections = (pred_labels[:, in_batch_idx]
                          .cpu().data.numpy().flatten().tolist())
            if args.subsample != 1:
                all_total_labels, all_total_lengths = additional
                gt_detections = all_total_labels[:all_total_lengths[
                    in_batch_idx], in_batch_idx].flatten().tolist()
                video_length = len(gt_detections)
                detections = evalutils.upsample(detections,
                                                freq=args.subsample,
                                                length=video_length)
            else:
                gt_detections = (labels_batch[:total_lengths[in_batch_idx],
                                              in_batch_idx]
                                 .cpu().data.numpy().flatten().tolist())
                detections = detections[:total_lengths[in_batch_idx]]
                video_length = len(gt_detections)
            all_labels.extend(detections)
            all_gt_labels.extend(gt_detections)

            micro_prec = logutils.compute_accuracy(gt_detections, detections)
            macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                gt_detections, detections, metric='macro')
            task_acc_ratio.update(micro_prec, video_length)
            task_macro_prec.update(macro_prec, video_length)
            task_macro_rec.update(macro_rec, video_length)
            task_macro_f1.update(macro_f1, video_length)

        if not test:
            if (batch_idx + 1) % args.log_interval == 0:
                tqdm.write('[Validation] Task {} {} Batch [{}/{}]\t'
                           'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                           'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                           'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                           'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                               args.task,
                               'val' if not test else 'test',
                               batch_idx + 1,
                               len(data_loader),
                               top1=task_acc_ratio,
                               prec=task_macro_prec,
                               recall=task_macro_rec,
                               f1=task_macro_f1))

        if args.task == 'affordance':
            output = output.view(
                (output.size(0), batch_num, -1, output.size(-1)))
        if save:
            for batch_i in range(output.size()[1]):
                if args.task == 'affordance':
                    model_probs = torch.nn.Softmax(dim=-1)(
                        output[:int(total_lengths[batch_i]),
                               batch_i, :, :].squeeze()).cpu().data.numpy()
                else:
                    model_probs = torch.nn.Softmax(dim=-1)(
                        output[:int(total_lengths[batch_i]),
                               batch_i, :].squeeze()).cpu().data.numpy()
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                np.save(
                    os.path.join(args.save_path,
                                 '{}.npy'.format(sequence_ids[batch_i])),
                    model_probs)
    print('Total frames evaluated: {}'.format(count))
    micro_prec = logutils.compute_accuracy(all_gt_labels, all_labels)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_labels, all_labels, metric='macro')
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, macro_prec, macro_recall,
                                          macro_fscore))
    return micro_prec
Example No. 4
def validate(data_loader, detection_model, prediction_model, args):
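    """Evaluate GEP-based future prediction against a plain NN predictor.

    For each sequence, concatenates detection likelihoods up to the
    current frame with predicted likelihoods for the next
    `args.using_pred_duration` frames, parses the resulting probability
    matrix with `inference`, and compares the frame-wise predictions
    against the ground truth and the NN baseline.
    """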
    all_gt_frame_predictions = list()
    all_frame_predictions = list()
    all_nn_frame_predictions = list()

    task_acc_ratio = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    detection_model.eval()
    prediction_model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        detection_likelihood = torch.nn.Softmax(dim=-1)(
            detection_model(features_batch)).data.cpu().numpy()

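        # Repeat the first frame (args.using_pred_duration - 1) times so the
        # prediction model sees a full input window from the very first frame.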
        padding = features_batch[0, :, :].repeat(args.using_pred_duration - 1,
                                                 1, 1)
        prediction_features = torch.cat((padding, features_batch), dim=0)
        prediction_output = prediction_model(prediction_features)
        prediction_likelihood = torch.nn.Softmax(
            dim=-1)(prediction_output).data.cpu().numpy()

        for batch_i in range(features_batch.size(1)):
            _, pred_labels = torch.max(
                prediction_output[:total_lengths[batch_i] - 1, batch_i, :],
                dim=-1)
            # Slice into a per-sequence local instead of overwriting the
            # full tensor, which would break later batch items.
            batch_pred_likelihood = prediction_likelihood[
                :total_lengths[batch_i] - 1, batch_i, :]

            skip_size = args.using_pred_duration - args.pred_duration

            # for frame in range(0, total_lengths[batch_i]-1, skip_size):
            for frame in range(
                    0, total_lengths[batch_i] - args.using_pred_duration,
                    skip_size):
                det = detection_likelihood[:frame + 1, batch_i, :]
                # det = detection_likelihood[:frame+1+args.using_pred_duration, batch_i, :]
                gt_det = torch.zeros(det.shape)
                gt_det.scatter_(1, labels_batch[:frame + 1,
                                                batch_i].unsqueeze(1), 1)
                gt_det = gt_det * 0.95 + (0.05 / 10) * torch.ones(det.shape)
                gt_det = gt_det.numpy()

                pred = batch_pred_likelihood[
                    frame:frame + args.using_pred_duration, :]
                prob_mat = np.concatenate((det, pred), axis=0)
                pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
                    prob_mat, activities[batch_i], sequence_ids[batch_i], args)

                # Testing
                gep_predictions = batch_earley_pred_labels[
                    frame + 1:frame + args.using_pred_duration + 1]
                all_frame_predictions.extend(gep_predictions)
                nn_frame_predictions = pred_labels[
                    frame + 1:frame + args.using_pred_duration + 1]
                all_nn_frame_predictions.extend(nn_frame_predictions)
                gt_frame_predictions = labels_batch[
                    frame + 1:frame + args.using_pred_duration + 1,
                    batch_i].cpu().numpy().tolist()
                all_gt_frame_predictions.extend(gt_frame_predictions)

                video_length = len(gt_frame_predictions)
                micro_prec_nn = logutils.compute_accuracy(
                    gt_frame_predictions, nn_frame_predictions)
                task_acc_ratio_nn.update(micro_prec_nn, video_length)

            micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                                   all_frame_predictions)
            nn_micro_prec = logutils.compute_accuracy(
                all_gt_frame_predictions, all_nn_frame_predictions)
            macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
                all_gt_frame_predictions,
                all_frame_predictions,
                metric='macro')
            tqdm.write('[Evaluation] Micro Prec: {}\t'
                       'NN Precision: {}\t'
                       'Macro Precision: {}\t'
                       'Macro Recall: {}\t'
                       'Macro F-score: {}'.format(micro_prec, nn_micro_prec,
                                                  macro_prec, macro_recall,
                                                  macro_fscore))
Example No. 5
def validate(data_loader, detection_model, prediction_model, args):
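    """Evaluate online GEP prediction with duration priors, frame by frame.

    Builds a `GEP.GeneralizedEarley` parser per sequence from its
    activity grammar in `args.paths.grammar_root` and a duration prior,
    then at every frame predicts the next `args.using_pred_duration`
    labels via `predict` and scores them against the NN predictor and the
    ground truth.
    """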
    all_gt_frame_predictions = list()
    all_frame_predictions = list()
    all_nn_frame_predictions = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    detection_model.eval()
    prediction_model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit

        padding = features_batch[0, :, :].repeat(args.using_pred_duration - 1,
                                                 1, 1)
        prediction_features = torch.cat((padding, features_batch), dim=0)
        prediction_output = prediction_model(prediction_features)
        detection_output = detection_model(features_batch)

        _, detection_labels = torch.max(detection_output, dim=-1)
        detection_labels = detection_labels.cpu().numpy()

        for batch_i in range(detection_output.size(1)):

            gt_all_pred_labels = labels_batch[1:total_lengths[batch_i],
                                              batch_i].cpu().numpy().tolist()
            _, nn_all_pred_labels = torch.max(
                prediction_output[:total_lengths[batch_i] - 1, batch_i, :],
                dim=-1)
            nn_all_pred_labels = nn_all_pred_labels.cpu().numpy().tolist()

            # Initialization of Earley Parser
            class_num = detection_output.shape[2]
            grammar_file = os.path.join(args.paths.grammar_root,
                                        activities[batch_i] + '.pcfg')
            grammar = grammarutils.read_grammar(grammar_file, index=True)
            gen_earley_parser = GEP.GeneralizedEarley(
                grammar, class_num, mapping=args.metadata.action_index)
            with open(
                    os.path.join(args.paths.prior_root,
                                 'duration_prior.json')) as f:
                duration_prior = json.load(f)

            record = dict()

            start_time = time.time()
            for frame in range(total_lengths[batch_i] -
                               args.using_pred_duration):
                nn_pred_labels = nn_all_pred_labels[frame:frame +
                                                    args.using_pred_duration]
                gt_pred_labels = gt_all_pred_labels[frame:frame +
                                                    args.using_pred_duration]
                update_length = len(nn_pred_labels)

                pred_labels = predict(gen_earley_parser,
                                      detection_output[frame, batch_i, :],
                                      duration_prior, record, frame, args)
                # gt = torch.ones(detection_output.size(2)) * 1e-5
                # gt[labels_batch[frame, batch_i]] = 1
                # gt = torch.log(gt / torch.sum(gt))
                # pred_labels = predict(gen_earley_parser, gt,
                #                       duration_prior, record, frame, args)
                # print(frame)
                # print('detection_labels', detection_labels[max(0, frame - 44) : frame + 1, batch_i].tolist())
                # print('gt_detect labels', labels_batch[max(0, frame - 44) :frame+1, batch_i].cpu().numpy().tolist())
                # print('gt_predic_labels', gt_pred_labels)
                # print('nn_predic_labels', nn_pred_labels)
                # print('xx_predic_labels', pred_labels)

                micro_prec = logutils.compute_accuracy(gt_pred_labels,
                                                       pred_labels)
                nn_micro_prec = logutils.compute_accuracy(
                    gt_pred_labels, nn_pred_labels)
                macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                    gt_pred_labels, nn_pred_labels, metric='macro')
                task_acc_ratio.update(micro_prec, update_length)
                task_acc_ratio_nn.update(nn_micro_prec, update_length)
                task_macro_prec.update(macro_prec, update_length)
                task_macro_rec.update(macro_rec, update_length)
                task_macro_f1.update(macro_f1, update_length)

                all_gt_frame_predictions.extend(gt_pred_labels)
                all_frame_predictions.extend(pred_labels)
                all_nn_frame_predictions.extend(nn_pred_labels)

            print('Elapsed: {:.3f}s'.format(time.time() - start_time))

        tqdm.write('Task {} {} Batch [{}/{}]\t'
                   'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                   'NN Acc {nn.val:.4f} ({nn.avg:.4f})\t'
                   'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                   'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                   'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                       args.task,
                       'test',
                       batch_idx,
                       len(data_loader),
                       top1=task_acc_ratio,
                       nn=task_acc_ratio_nn,
                       prec=task_macro_prec,
                       recall=task_macro_rec,
                       f1=task_macro_f1))

    micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                           all_frame_predictions)
    nn_micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                              all_nn_frame_predictions)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_frame_predictions, all_nn_frame_predictions, metric='weighted')
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'NN Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, nn_micro_prec,
                                          macro_prec, macro_recall,
                                          macro_fscore))
Example No. 6
def validate(data_loader, model, args):
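    """Run GEP inference on pre-computed model outputs loaded from disk.

    Loads per-sequence .npy outputs saved to `args.save_path` by an
    earlier detection run instead of running `model` forward, parses them
    with `inference`, upsamples the predictions when
    `args.subsample != 1`, and reports micro/macro detection metrics per
    sequence and overall.
    """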
    all_gt_detections = list()
    all_detections = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        print(
            os.path.join(
                args.save_path,
                '{}_out_s{}_b{}_c{}.npy'.format(sequence_ids[0],
                                                args.subsample,
                                                args.using_batch_size,
                                                args.trained_epochs)))
        # exit()
        model_outputs = torch.tensor(
            np.load(
                os.path.join(
                    args.save_path, '{}_out_s{}_b{}_c{}.npy'.format(
                        sequence_ids[0], args.subsample, args.using_batch_size,
                        args.trained_epochs)))).unsqueeze(1)

        # Inference
        tqdm.write('[{}] Inference'.format(sequence_ids[0]))

        seg_path = os.path.join(args.paths.inter_root, 'segmentation')
        if not os.path.exists(seg_path):
            os.makedirs(seg_path)

        # # If no prior model outputs are provided
        # if not os.path.isfile(os.path.join(seg_path, '{}.npy'.format(sequence_ids[0]))):
        #     _, nn_pred_labels = torch.max(model_outputs, dim=-1)
        #     nn_detections = nn_pred_labels.cpu().data.numpy().flatten().tolist()
        #     pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(model_outputs, activities, sequence_ids, args)
        #
        #     # Evaluation
        #     # Frame-wise detection
        #     detections = [l for pred_labels in batch_earley_pred_labels for l in pred_labels.tolist()]
        #     if args.subsample != 1:
        #         all_total_labels, all_total_lengths = additional
        #         gt_detections = all_total_labels[:all_total_lengths[0]].flatten().tolist()
        #         video_length = len(gt_detections)
        #
        #         detections = evalutils.upsample(detections, freq=args.subsample, length=video_length)
        #         nn_detections = evalutils.upsample(nn_detections, freq=args.subsample, length=video_length)
        #     else:
        #         gt_detections = labels_batch[:total_lengths[0]].cpu().data.numpy().flatten().tolist()
        #         detections = detections[:total_lengths[0]]
        #     np.save(os.path.join(args.paths.inter_root, 'segmentation', '{}.npy'.format(sequence_ids[0])),
        #             [gt_detections, nn_detections, detections])
        # else:
        #     results = np.load(os.path.join(seg_path, '{}.npy'.format(sequence_ids[0])))
        #     gt_detections, nn_detections, detections = results[0], results[1], results[2]

        _, nn_pred_labels = torch.max(model_outputs, dim=-1)
        nn_detections = nn_pred_labels.cpu().data.numpy().flatten().tolist()
        pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
            model_outputs, activities, sequence_ids, args)

        # Evaluation
        # Frame-wise detection
        detections = [
            l for pred_labels in batch_earley_pred_labels
            for l in pred_labels.tolist()
        ]
        if args.subsample != 1:
            all_total_labels, all_total_lengths = additional
            gt_detections = (all_total_labels[:all_total_lengths[0]]
                             .flatten().tolist())
            video_length = len(gt_detections)

            detections = evalutils.upsample(detections,
                                            freq=args.subsample,
                                            length=video_length)
            nn_detections = evalutils.upsample(nn_detections,
                                               freq=args.subsample,
                                               length=video_length)
        else:
            gt_detections = (labels_batch[:total_lengths[0]]
                             .cpu().data.numpy().flatten().tolist())
            detections = detections[:total_lengths[0]]
        video_length = len(gt_detections)

        # # Visualization code for figures
        # vizutils.plot_segmentation([gt_detections, nn_detections, detections], video_length,
        #                            filename=os.path.join(args.paths.visualize_root, '{}.jpg'.format(sequence_ids[0])), border=False)

        micro_prec = logutils.compute_accuracy(gt_detections, detections)
        micro_prec_nn = logutils.compute_accuracy(gt_detections, nn_detections)
        macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
            gt_detections, detections, metric='macro')
        task_acc_ratio.update(micro_prec, video_length)
        task_acc_ratio_nn.update(micro_prec_nn, video_length)
        task_macro_prec.update(macro_prec, video_length)
        task_macro_rec.update(macro_rec, video_length)
        task_macro_f1.update(macro_f1, video_length)

        all_gt_detections.extend(gt_detections)
        all_detections.extend(detections)

        micro_prec = logutils.compute_accuracy(all_gt_detections,
                                               all_detections)
        macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
            all_gt_detections, all_detections, metric='macro')
        tqdm.write('[Evaluation] Micro Prec: {}\t'
                   'Macro Precision: {}\t'
                   'Macro Recall: {}\t'
                   'Macro F-score: {}'.format(micro_prec, macro_prec,
                                              macro_recall, macro_fscore))

    micro_prec = logutils.compute_accuracy(all_gt_detections, all_detections)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_detections, all_detections, metric='macro')
    tqdm.write('Detection:\n'
               'Micro Prec: {}\t'
               'NN Prec:{}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}\n\n'.format(micro_prec,
                                              task_acc_ratio_nn.avg,
                                              macro_prec, macro_recall,
                                              macro_fscore))
Example No. 7
def train(data_loader, model, criterion, optimizer, epoch, args):
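    """Run one training epoch of the frame-wise prediction model.

    Pads each sequence with `args.pred_duration - 1` copies of its first
    frame so a prediction exists from the start, trains the model to
    predict the next frame's label with `criterion`, back-propagates the
    summed per-sequence loss once per batch, and tracks micro and
    weighted precision/recall/F1 with `utils.AverageMeter`.
    """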
    logger = utils.Logger()
    model.train()
    start_time = time.time()
    task_acc_ratio = utils.AverageMeter()
    task_macro_prec = utils.AverageMeter()
    task_macro_rec = utils.AverageMeter()
    task_macro_f1 = utils.AverageMeter()
    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='Batch Loop Training')):
        logger.data_time.update(time.time() - start_time)
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

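        # Repeat the first frame (args.pred_duration - 1) times so the model
        # emits a prediction for every original frame from the start.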
        padding = features_batch[0, :, :].repeat(args.pred_duration - 1, 1, 1)
        features = torch.cat((padding, features_batch), dim=0)
        output = model(features)
        loss = 0
        for batch_i in range(features.size(1)):
            gt_pred_labels = labels_batch[1:total_lengths[batch_i], batch_i]
            _, pred_labels = torch.max(output[:total_lengths[batch_i] - 1,
                                              batch_i],
                                       dim=-1)
            loss += criterion(output[:total_lengths[batch_i] - 1, batch_i],
                              gt_pred_labels)
            gt_pred_labels = gt_pred_labels.cpu().numpy().tolist()
            pred_labels = pred_labels.cpu().numpy().tolist()
            video_length = len(gt_pred_labels)

            logger.losses.update(loss.item(), video_length)
            micro_prec = utils.compute_accuracy(gt_pred_labels, pred_labels)
            macro_prec, macro_rec, macro_f1 = utils.compute_accuracy(
                gt_pred_labels, pred_labels, metric='weighted')
            task_acc_ratio.update(micro_prec, video_length)
            task_macro_prec.update(macro_prec, video_length)
            task_macro_rec.update(macro_rec, video_length)
            task_macro_f1.update(macro_f1, video_length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.batch_time.update(time.time() - start_time)
        start_time = time.time()

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('Task {} Epoch: [{}][{}/{}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                       'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                       'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                       'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                           args.task,
                           epoch,
                           batch_idx,
                           len(data_loader),
                           batch_time=logger.batch_time,
                           data_time=logger.data_time,
                           loss=logger.losses,
                           top1=task_acc_ratio,
                           prec=task_macro_prec,
                           recall=task_macro_rec,
                           f1=task_macro_f1))
Example No. 8
def validate(data_loader, model, args, test=False):
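    """Validate (or test) the prediction model over sliding windows.

    Mirrors the training-time padding, then for every frame collects the
    next `args.pred_duration` ground-truth and predicted labels, updates
    micro/weighted meters, and returns the overall micro precision.
    """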
    task_acc_ratio = utils.AverageMeter()
    task_macro_prec = utils.AverageMeter()
    task_macro_rec = utils.AverageMeter()
    task_macro_f1 = utils.AverageMeter()
    all_labels = list()
    all_gt_labels = list()

    # switch to evaluate mode
    model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader,
                 desc='Batch Loop Validation'
                 if not test else 'Batch Loop Testing')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit
        if args.cuda:
            features_batch = features_batch.cuda()
            labels_batch = labels_batch.cuda()
        if args.task == 'affordance':
            obj_num, _ = torch.max(obj_nums, dim=-1)
            features_batch = features_batch[:, :, :obj_num, :]
            labels_batch = labels_batch[:, :, :obj_num]
            features_batch = features_batch.view(
                (features_batch.size(0), -1, features_batch.size(-1)))
            labels_batch = labels_batch.view((labels_batch.size(0), -1))

        padding = features_batch[0, :, :].repeat(args.pred_duration - 1, 1, 1)
        features = torch.cat((padding, features_batch), dim=0)
        output = model(features)
        for batch_i in range(features.size(1)):
            gt_pred_labels = labels_batch[1:total_lengths[batch_i], batch_i]
            _, pred_labels = torch.max(output[:total_lengths[batch_i] - 1,
                                              batch_i, :],
                                       dim=-1)
            gt_pred_labels = gt_pred_labels.cpu().numpy().tolist()
            pred_labels = pred_labels.cpu().numpy().tolist()

            for frame in range(total_lengths[batch_i] - 1):
                video_length = len(gt_pred_labels)
                all_gt_labels.extend(gt_pred_labels[frame:frame +
                                                    args.pred_duration])
                all_labels.extend(pred_labels[frame:frame +
                                              args.pred_duration])
                micro_prec = utils.compute_accuracy(gt_pred_labels,
                                                    pred_labels)
                macro_prec, macro_rec, macro_f1 = utils.compute_accuracy(
                    gt_pred_labels, pred_labels, metric='weighted')
                task_acc_ratio.update(micro_prec, video_length)
                task_macro_prec.update(macro_prec, video_length)
                task_macro_rec.update(macro_rec, video_length)
                task_macro_f1.update(macro_f1, video_length)

        if (batch_idx + 1) % args.log_interval == 0:
            tqdm.write('[Validation] Task {} {} Batch [{}/{}]\t'
                       'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                       'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                       'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                       'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                           args.task,
                           'val' if not test else 'test',
                           batch_idx,
                           len(data_loader),
                           top1=task_acc_ratio,
                           prec=task_macro_prec,
                           recall=task_macro_rec,
                           f1=task_macro_f1))

    micro_prec = utils.compute_accuracy(all_gt_labels, all_labels)
    macro_prec, macro_recall, macro_fscore = utils.compute_accuracy(
        all_gt_labels, all_labels, metric='macro')
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, macro_prec, macro_recall,
                                          macro_fscore))
    return micro_prec