import time

import numpy as np
import torch
import torch.nn.functional as F

# AverageMeter, accuracy, and multi_scores are project utilities, and writer
# (a tensorboard SummaryWriter) plus cp_recorder (the checkpoint recorder)
# are assumed to be configured at module level; hedged sketches of the
# utilities' conventional implementations appear later in this file.


def validate_eval(val_loader, model, criterion, args, epoch=None, fnames=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    end = time.time()
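    # Pre-allocate dataset-sized buffers so per-class metrics can be computed
    # over the full validation set after the loop.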
    scores = np.zeros((len(val_loader.dataset), args.num_class))
    labels = np.zeros((len(val_loader.dataset), ))
    for i, (frames, objects, target) in enumerate(val_loader):
        with torch.no_grad():
            # non_blocking replaces the old async= kwarg ('async' is a
            # reserved word in Python 3.7+).
            target = target.cuda(non_blocking=True)
            frames = frames.cuda()
            objects = objects.cuda()
            output = model(frames, objects)

            loss = criterion(output, target)
            losses.update(loss.item(), target.size(0))
            prec1 = accuracy(output.data, target)
            top1.update(prec1[0], target.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            # Record softmax scores for dataset-level metrics.
            output_f = F.softmax(output, dim=1)  # logits -> [0, 1]
            output_np = output_f.cpu().numpy()
            labels_np = target.cpu().numpy()
            b_ind = i * args.batch_size
            e_ind = b_ind + output_np.shape[0]  # last batch may be smaller
            scores[b_ind:e_ind, :] = output_np
            labels[b_ind:e_ind] = labels_np

    print(
        'Test [Epoch {0}/{1}]:  '
        '*Time {2:.2f}mins ({batch_time.avg:.2f}s)  '
        '*Loss {loss.avg:.4f}  '
        '*Prec@1 {top1.avg:.3f}'.format(epoch,
                                        args.epoch,
                                        batch_time.sum / 60,
                                        batch_time=batch_time,
                                        top1=top1,
                                        loss=losses))

    model.train()  # Switch back to training mode for the caller.
    res_scores = multi_scores(scores, labels,
                              ['precision', 'recall', 'average_precision'])
    return (top1.avg, losses.avg, res_scores['precision'],
            res_scores['recall'], res_scores['average_precision'])
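

# `multi_scores` is a project utility defined elsewhere. As a reading aid, the
# sketch below shows the per-class metrics its usage here implies, built on
# sklearn.metrics; it is a hypothetical reimplementation, not the project's
# actual code.
def multi_scores_sketch(scores, labels, metrics):
    from sklearn.metrics import (average_precision_score, precision_score,
                                 recall_score)

    y_true = labels.astype(int)
    preds = scores.argmax(axis=1)  # hard prediction per sample
    out = {}
    if 'precision' in metrics:
        out['precision'] = precision_score(y_true, preds, average=None)
    if 'recall' in metrics:
        out['recall'] = recall_score(y_true, preds, average=None)
    if 'average_precision' in metrics:
        # Per-class AP of the softmax scores against one-hot targets.
        onehot = np.eye(scores.shape[1])[y_true]
        out['average_precision'] = average_precision_score(
            onehot, scores, average=None)
    return out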


def train_eval(train_loader, val_loader, model, criterion, optimizer, args,
               epoch, fnames=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    # Optionally keep the ResNet-50 backbone in eval mode (freezes BatchNorm
    # statistics and disables dropout):
    # model.module.resnet50.eval()

    end = time.time()
    scores = np.zeros((len(train_loader.dataset), args.num_class))
    labels = np.zeros((len(train_loader.dataset), ))

    # Resume support: if this epoch was interrupted, restart from the batch
    # recorded in the checkpoint.
    b_batch = 0
    if epoch == cp_recorder.contextual['b_epoch']:
        b_batch = cp_recorder.contextual['b_batch'] + 1

    for i, (frames, target) in enumerate(train_loader):
        # Skip batches already processed before the checkpoint was written.
        if i < b_batch:
            continue
        target = target.cuda(non_blocking=True)
        frames = frames.cuda()
        output = model(frames)

        loss = criterion(output, target)
        losses.update(loss.item(), target.size(0))
        prec1 = accuracy(output.data, target)
        top1.update(prec1[0], target.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            # Every `print_freq` batches, print progress to the console and
            # log training scalars to tensorboard.
            niter = epoch * len(train_loader) + i
            print(
                'Train [Batch {0}/{1}|Epoch {2}/{3}]:  '
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    i,
                    len(train_loader),
                    epoch,
                    args.epoch,
                    batch_time=batch_time,
                    loss=losses,
                    top1=top1))

            writer.add_scalars('Loss (per batch)', {'train-10b': loss.item()},
                               niter)
            writer.add_scalars('Prec@1 (per batch)', {'train-10b': prec1[0]},
                               niter)

        if (i + 1) % (args.print_freq * 10) == 0:
            # Every `print_freq * 10` batches, run validation and log the
            # results to tensorboard.

            top1_avg_val, loss_avg_val, prec, recall, ap = validate_eval(
                val_loader, model, criterion, args, epoch)
            writer.add_scalars('Loss (per batch)', {'valid': loss_avg_val},
                               niter)
            writer.add_scalars('Prec@1 (per batch)', {'valid': top1_avg_val},
                               niter)
            writer.add_scalars('mAP (per batch)',
                               {'valid': np.nan_to_num(ap).mean()}, niter)

            # Save a resumable checkpoint at the same cadence as validation.
            cp_recorder.record_contextual({
                'b_epoch': epoch,
                'b_batch': i,
                'prec': top1_avg_val,
                'loss': loss_avg_val,
                'class_prec': prec,
                'class_recall': recall,
                'class_ap': ap,
                'mAP': np.nan_to_num(ap).mean()
            })
            cp_recorder.save_checkpoint(model)

        # Record softmax scores for epoch-level metrics; batches skipped on
        # resume leave zero rows in the buffers.
        output_f = F.softmax(output.detach(), dim=1)  # logits -> [0, 1]
        output_np = output_f.cpu().numpy()
        labels_np = target.cpu().numpy()
        b_ind = i * args.batch_size
        e_ind = b_ind + output_np.shape[0]  # last batch may be smaller
        scores[b_ind:e_ind, :] = output_np
        labels[b_ind:e_ind] = labels_np

    res_scores = multi_scores(scores, labels,
                              ['precision', 'recall', 'average_precision'])
    print(
        'Train [Epoch {0}/{1}]:  '
        '*Time {2:.2f}mins ({batch_time.avg:.2f}s)  '
        '*Loss {loss.avg:.4f}  '
        '*Prec@1 {top1.avg:.3f}'.format(epoch,
                                        args.epoch,
                                        batch_time.sum / 60,
                                        batch_time=batch_time,
                                        loss=losses,
                                        top1=top1))

    return (top1.avg, losses.avg, res_scores['precision'],
            res_scores['recall'], res_scores['average_precision'])
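

# A hypothetical driver showing how the two functions above are composed per
# epoch. The loop mirrors how args.epoch, writer, and validate_eval's return
# values are used above; writer and cp_recorder are assumed to be configured
# at module level.
def run_training_sketch(train_loader, val_loader, model, criterion, optimizer,
                        args):
    for epoch in range(args.epoch):
        train_eval(train_loader, val_loader, model, criterion, optimizer,
                   args, epoch)
        top1_avg, loss_avg, prec, recall, ap = validate_eval(
            val_loader, model, criterion, args, epoch)
        writer.add_scalars('mAP (per epoch)',
                           {'valid': np.nan_to_num(ap).mean()}, epoch)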


# Variant of validate_eval for the relation model, which consumes object-pair
# inputs (union region, both objects, their spatial encoding, the full image,
# and per-image bounding boxes). It shadows the earlier definition if both
# live in one module.
def validate_eval(val_loader, model, criterion, args, epoch=None, fnames=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    end = time.time()
    scores = np.zeros((len(val_loader.dataset), args.num_classes))
    labels = np.zeros((len(val_loader.dataset), ))
    for i, (union, obj1, obj2, bpos, target, full_im, bboxes_14,
            categories) in enumerate(val_loader):
        with torch.no_grad():
            # Flatten the padded per-image box tensor into one (N, 4) tensor;
            # categories[b, 0] gives the number of valid boxes for sample b.
            batch_size = bboxes_14.size(0)
            cur_rois_sum = categories[0, 0].item()
            bboxes = bboxes_14[0, :cur_rois_sum, :]
            for b in range(1, batch_size):
                bboxes = torch.cat(
                    (bboxes, bboxes_14[b, :categories[b, 0], :]), 0)
                cur_rois_sum += categories[b, 0].item()
            assert bboxes.size(0) == cur_rois_sum, \
                'Total box count must match the per-image counts in categories'

            target = target.cuda(non_blocking=True)
            union = union.cuda()
            obj1 = obj1.cuda()
            obj2 = obj2.cuda()
            bpos = bpos.cuda()
            full_im = full_im.cuda()
            bboxes = bboxes.cuda()
            categories = categories.cuda()

            output = model(union, obj1, obj2, bpos, full_im, bboxes,
                           categories)

            loss = criterion(output, target)
            losses.update(loss.item(), union.size(0))
            prec1 = accuracy(output.data, target)
            top1.update(prec1[0], union.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            # Record softmax scores for dataset-level metrics.
            output_f = F.softmax(output, dim=1)  # logits -> [0, 1]
            output_np = output_f.cpu().numpy()
            labels_np = target.cpu().numpy()
            b_ind = i * args.batch_size
            e_ind = b_ind + output_np.shape[0]  # last batch may be smaller
            scores[b_ind:e_ind, :] = output_np
            labels[b_ind:e_ind] = labels_np

    print(
        'Test [Epoch {0}/{1}]:  '
        '*Time {2:.2f}mins ({batch_time.avg:.2f}s)  '
        '*Loss {loss.avg:.4f}  '
        '*Prec@1 {top1.avg:.3f}'.format(epoch,
                                        args.epoch,
                                        batch_time.sum / 60,
                                        batch_time=batch_time,
                                        top1=top1,
                                        loss=losses))
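
    # Switch back to training mode for the caller (parity with the first
    # validate_eval variant; assumes this may also run mid-training).
    model.train()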

    res_scores = multi_scores(scores, labels,
                              ['precision', 'recall', 'average_precision'])
    return (top1.avg, losses.avg, res_scores['precision'],
            res_scores['recall'], res_scores['average_precision'])
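

# The helpers below are sketches of the conventional implementations that the
# usage above implies; they are illustrations (hence the *_sketch names), not
# the project's actual utilities.
class AverageMeterSketch(object):
    """Tracks the latest value plus a running sum, count, and average,
    matching the .val/.avg/.sum/.update(val, n) interface used above."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy_sketch(output, target, topk=(1, )):
    """Top-k precision in percent; accuracy(output, target)[0] is prec@1."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res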