class Trainer:
    def __init__(self, model, optimizer, all_loaders, args, resume_epoch):
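        """
        Store the model, optimizer and data loaders, set up the GAN generator
        used to synthesize (and edit) training images, and optionally build
        the feature clusterer used for semantic negatives.
        """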

        self.resume_epoch = resume_epoch
        self.args = args

        self.layer_list_all = args.layers
        self.layers_dict = {
            'layer2': {
                'name': 'layer2',
                'depth': 512,
                'size': 4
            },
            'layer3': {
                'name': 'layer3',
                'depth': 512,
                'size': 8
            },
            'layer4': {
                'name': 'layer4',
                'depth': 512,
                'size': 8
            },
            'layer5': {
                'name': 'layer5',
                'depth': 256,
                'size': 16
            },
            'layer6': {
                'name': 'layer6',
                'depth': 256,
                'size': 16
            },
        }

        self.generator = gantest.GanTester(args.path_model_gan,
                                           self.layer_list_all,
                                           device=torch.device('cuda'))
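        # A fixed bank of latent vectors, indexed by the integer sample id, so
        # that each training example maps to the same GAN image every epoch.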
        self.z = self.generator.standard_z_sample(200000)

        self.model = model
        self.optimizer = optimizer
        self.loaders = all_loaders
        self.loss_type = args.loss_type

        # Other parameters
        self.margin = args.margin
        self.clustering = args.clustering

        self.epoch = 0
        self.unorm = utils.UnNormalize(mean=(0.485, 0.456, 0.406),
                                       std=(0.229, 0.224, 0.225))

        output_size = 32 if 'large' in args.audio_model else 256

        if args.active_learning:
            active_learning.get_clusterer(self, args, output_size, model)
        else:
            if args.clustering:
                print('Creating cluster from scratch')
                cluster_path = os.path.join(
                    self.args.results, 'clusters',
                    args.name_checkpoint + '_' + str(time.time()))
                self.clusterer = Clusterer(
                    self.loaders['train'],
                    model,
                    path_store=cluster_path,
                    model_dim=args.embedding_dim,
                    save_results=True,
                    output_size=output_size,
                    args=self.args,
                    path_cluster_load=args.path_cluster_load)

        self.epochs_clustering = self.args.epochs_clustering
        self.clusters = self.mean_clust = self.std_clust = None
        self.cluster_counts = self.clusters_unit = None

    def train(self):
        """
        Main training loop. For each epoch train the model and save checkpoint if the results are good.
        Cluster every epochs_clustering epochs
        """
        best_eval = 0

        try:
            for epoch in range(self.resume_epoch, self.args.epochs):
                self.epoch = epoch

                # Clustering
                if self.clustering and \
                        ((epoch % self.epochs_clustering == 0) or (self.args.resume and epoch == self.resume_epoch)):
                    self.clusterer.save_results = True
                    clus, mean_clust, std_clust = self.clusterer.create_clusters(
                        iteration=0)
                    self.clusters = torch.FloatTensor(clus).cuda()
                    self.mean_clust = torch.FloatTensor(mean_clust)
                    self.std_clust = torch.FloatTensor(std_clust)
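                    # Normalize each cluster direction by its largest component
                    # so that all clusters contribute on a comparable scale.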
                    self.cluster_counts = 1 / self.clusters.max(1)[0]
                    self.clusters_unit = self.cluster_counts.view(self.clusters.size(0), 1).expand_as(self.clusters) * \
                                         self.clusters

                    self.clusterer.name_with_images_clusters()
                    self.clusterer.name_clusters()
                    self.optimize_neurons()

                    # This is for visualization:
                    # self.clusterer.segment_images()
                    # self.clusterer.create_web_images()  # segment_images has to be uncommented before
                    self.clusterer.create_web_clusters(with_images=True)

                utils.adjust_learning_rate(self.args, self.optimizer, epoch)

                # Train for one epoch
                print('Starting training epoch ' + str(epoch))
                self.train_epoch(epoch)

                # Evaluate on validation set
                print('Starting evaluation epoch ' + str(epoch))
                eval_score, recalls = self.eval()
                self.args.writer.add_scalar('eval_score', eval_score, epoch)

                # Remember best eval score and save checkpoint
                is_best = eval_score > best_eval
                best_eval = max(eval_score, best_eval)
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'model_state_dict': self.model.state_dict(),
                        'best_eval': best_eval,
                        'recall_now': recalls,
                        'optimizer': self.optimizer.state_dict(),
                    },
                    is_best,
                    self.args,
                    name_checkpoint=self.args.name_checkpoint)

        except KeyboardInterrupt:
            print('Training manually interrupted at epoch ' +
                  str(self.epoch + 1))

    def train_epoch(self, epoch):
        """
        Train one epoch. It consists of 5 steps
        Step 1: Compute the output of the positive image
        Step 2: Compute the mask for the positive image features
        Step 3: Generate the negative image from this mask
        Step 4: Compute the output of this negative
        Step 5: Compute all the losses
        And after that, do the backpropagation and weight updates
        """
        if not self.args.use_cpu:
            torch.cuda.synchronize()
        batch_time = utils.AverageMeter()
        data_time = utils.AverageMeter()
        losses_meter = utils.AverageMeter()

        # Switch to train mode
        self.model.train()

        end = time.time()
        N_examples = len(self.loaders['train'].dataset)

        loss_list_total = {
            'loss_regular': 0,
            'loss_neg': 0,
            'loss_hardneg': 0,
            'loss_total': 0
        }
        for batch_id, (image_input, audio_input, neg_images, nframes, path,
                       image_raw) in enumerate(self.loaders['train']):
            loss_list = {
                'loss_regular': 0,
                'loss_neg': 0,
                'loss_hardneg': 0,
                'loss_total': 0
            }

            # Measure data loading time
            data_time.update(time.time() - end)

            if not self.args.use_cpu:
                audio_input = audio_input.cuda(non_blocking=True)

            if not self.args.loading_image:
                # In case the audio is inside a subfolder, keep only the file name
                path_ints = [p.split('/')[-1] for p in path]

                v_init = self.z[int(path_ints[0])]
                z_img = torch.FloatTensor(image_input.size(0), v_init.shape[0])

                for k in range(image_input.size(0)):
                    z_img[k, :] = self.z[int(path_ints[k])]

                image_input = self.generator.generate_images(z_img,
                                                             intervention=None)
                image_input = utils.transform(image_input).detach()

            else:
                image_input = image_input.cuda()
                neg_images = neg_images.cuda()

            # STEP 1: Compute output positive
            model_output = self.model(image_input, audio_input, [])
            image_output = model_output[0]
            audio_output = model_output[1]

            neg_images = []
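            # Negatives are regenerated below through GAN interventions; any
            # loader-provided negatives are discarded at this point.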

            # The audio network downsamples time; rescale nframes so it indexes
            # audio_output frames rather than raw audio_input frames.
            pooling_ratio = round(audio_input.size(3) / audio_output.size(3))
            nframes = nframes // pooling_ratio

            binary_mask_0 = None

            # Only do steps 2-4 if we want to train with semantic negatives
            if self.loss_type == 'negatives_edited' or self.loss_type == 'negatives_both':
                # STEP 2: Compute mask from image features
                limits = np.zeros((image_input.size(0), 2))

                for i in range(image_input.size(0)):
                    pos_image = image_input[i, :, :, :]

                    nF = nframes[i]

                    matchmap = utils.compute_matchmap(
                        image_output[i], audio_output[i][:, :, :nF])
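                    # matchmap holds the similarity between every image location
                    # and every audio frame; its maximum marks the most strongly
                    # grounded region and time step.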

                    matchmap = matchmap.data.cpu().numpy().copy()

                    matchmap = matchmap.transpose(2, 0, 1)  # l, h, w
                    matchmap = matchmap / (matchmap.max() + 1e-10)
                    matchmap_image = matchmap.max(axis=0)
                    threshold = 0.95

                    # Coordinates of the global maximum over (time, h, w)
                    ind_max = np.argmax(matchmap)
                    ind_t, ind_h, ind_w = np.unravel_index(
                        ind_max, matchmap.shape)

                    limits[i, 0] = ind_t
                    limits[i, 1] = ind_t + 1

                    if self.clustering:
                        if self.args.active_learning and 'active' in path[i]:
                            neg_img = active_learning.get_negatives(
                                self, path_ints[i])

                        else:
                            v = (image_output[i][:, ind_h, ind_w] -
                                 self.mean_clust.cuda()) / (
                                     self.std_clust.cuda() + 1e-8)

                            normalized_clusters = np.matmul(
                                self.clusters.cpu().numpy(),
                                v.detach().cpu().numpy().transpose())
                            sorted_val = -np.sort(-normalized_clusters)
                            sorted_val = np.clip(sorted_val, 0, 4)
                            if np.sum(sorted_val) <= 0:
                                print(
                                    "None of the clusters was close to the image feature. If this happens regularly, "
                                    "it probably means the clusters were low quality. Did you pretrain with a "
                                    "regular loss before clustering?")
                            prob_samples = sorted_val / np.sum(sorted_val)
                            sorted_id = np.argsort(-normalized_clusters)
                            cluster_id = sorted_id[0]

                            norm = 0
                            threshold_random = 0.95

                            # The number of units to be ablated grows if we cannot generate a good (changed) negative
                            # The following numbers are the starting number of units to change
                            num_units_dict = {
                                'layer2': 30,
                                'layer3': 30,
                                'layer4': 140,
                                'layer5': 30,
                                'layer6': 30
                            }
                            threshold_heatmap = threshold

                            count = 0
                            binary_mask_eval = matchmap_image > (
                                threshold_heatmap * matchmap_image.max())
                            binary_mask_eval = utils.geodesic_dilation(
                                binary_mask_eval, (ind_h, ind_w))
                            binary_mask_eval = cv2.resize(
                                binary_mask_eval, (128, 128))
                            bmask = torch.Tensor(binary_mask_eval).cuda()
                            bmask = bmask.view(1, 128, 128).expand(3, 128, 128)
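                            # The loop below ablates GAN units and regenerates the
                            # image until the edit changes the masked region enough
                            # (relative norm above threshold_random).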

                            while norm < threshold_random:
                                with torch.no_grad():
                                    binary_mask = matchmap_image > (
                                        threshold_heatmap *
                                        matchmap_image.max())
                                    binary_mask = utils.geodesic_dilation(
                                        binary_mask, (ind_h, ind_w))

                                    if binary_mask_0 is None:
                                        binary_mask_0 = cv2.resize(
                                            binary_mask, (224, 224))

                                    # STEP 3: Generate new image
                                    z_img = self.z[int(path_ints[i])]
                                    z_img = z_img[np.newaxis, :]

                                    _ = self.generator.generate_images(z_img)
                                    intervention = {}
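                                    # For every layer, zero out the top-ranked
                                    # units for the sampled cluster, restricted
                                    # to the (resized) matchmap mask.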
                                    for layer_n in self.layer_list_all:
                                        units_ids = self.layers_units[layer_n][
                                            cluster_id][:num_units_dict[
                                                layer_n]]
                                        layer_size = self.layers_dict[layer_n][
                                            'size']
                                        layer_dim = self.layers_dict[layer_n][
                                            'depth']

                                        ablation, replacement = self.get_ablation_replacement(
                                            params=[layer_dim, units_ids],
                                            option='specific')
                                        ablation_final = cv2.resize(
                                            binary_mask,
                                            (layer_size, layer_size))
                                        ablation_final = np.tile(
                                            ablation_final,
                                            (layer_dim, 1, 1)).astype(
                                                np.float32)
                                        ablation_final = torch.cuda.FloatTensor(
                                            ablation_final)
                                        ablation_final = ablation.view(
                                            layer_dim, 1,
                                            1).expand_as(ablation_final
                                                         ) * ablation_final
                                        intervention[layer_n] = (
                                            ablation_final, replacement)

                                    neg_img = self.generator.generate_images(
                                        z_img,
                                        intervention=intervention).detach()
                                    neg_img_t = utils.transform(
                                        neg_img).detach()

                                    diff = (neg_img_t[0, :, :, :] -
                                            pos_image.detach()) * bmask
                                    # Nested 2-norms over dims 2, 1, 0 equal the
                                    # Frobenius norm of the whole tensor
                                    norm = torch.norm(diff)
                                    norm_normalized = norm / torch.norm(
                                        pos_image.detach() * bmask)
                                    norm = norm_normalized.item()
                                    for layer_n in self.layer_list_all:
                                        num_units_dict[layer_n] += 40  # increase units to change
                                    threshold_heatmap = threshold_heatmap - 0.1
                                    threshold_random = threshold_random - 0.05

                                    cluster_id = np.random.choice(
                                        sorted_id, size=1, p=prob_samples)[0]

                                    count = count + 1

                    else:  # random edited negatives
                        binary_mask = matchmap_image > (threshold *
                                                        matchmap_image.max())
                        binary_mask = utils.geodesic_dilation(
                            binary_mask, (ind_h, ind_w))
                        if binary_mask_0 is None:
                            binary_mask_0 = cv2.resize(binary_mask, (224, 224))
                        norm = 0
                        threshold_random = 0.95
                        p = 0.4
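                        # p is the fraction of units to randomly ablate; failed
                        # attempts either raise p or relax the acceptance
                        # threshold at the bottom of the loop.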

                        while norm < threshold_random:
                            with torch.no_grad():
                                intervention = {}

                                for layer_n in self.layer_list_all:
                                    layer_size = self.layers_dict[layer_n][
                                        'size']
                                    layer_dim = self.layers_dict[layer_n][
                                        'depth']

                                    ablation, replacement = self.get_ablation_replacement(
                                        params=[layer_dim, True, p],
                                        option='random')
                                    ablation_final = cv2.resize(
                                        binary_mask, (layer_size, layer_size))
                                    ablation_final = np.tile(
                                        ablation_final,
                                        (layer_dim, 1, 1)).astype(np.float32)
                                    ablation_final = torch.cuda.FloatTensor(
                                        ablation_final)
                                    ablation_final = ablation.view(
                                        layer_dim, 1, 1).expand_as(
                                            ablation_final) * ablation_final
                                    intervention[layer_n] = (ablation_final,
                                                             replacement)

                                # STEP 3: Generate new image
                                z_img = self.z[int(path_ints[i])]
                                z_img = z_img[np.newaxis, :].detach()
                                neg_img = self.generator.generate_images(
                                    z_img, intervention=intervention).detach()
                                neg_img_t = utils.transform(neg_img).detach()

                                binary_mask = cv2.resize(
                                    binary_mask, (128, 128))

                                bmask = torch.Tensor(binary_mask).cuda()

                                bmask = bmask.view(1, 128,
                                                   128).expand(3, 128, 128)
                                diff = (neg_img_t[0, :, :, :] -
                                        pos_image.detach()) * bmask
                                # Nested 2-norms over dims 2, 1, 0 equal the
                                # Frobenius norm of the whole tensor
                                norm = torch.norm(diff)
                                norm_normalized = norm / torch.norm(
                                    pos_image.detach() * bmask)
                                norm = norm_normalized.item()

                                if random.random() > 0.2:
                                    p = p + 0.05
                                else:
                                    threshold_random = threshold_random - 0.01

                    neg_images.append(neg_img)

                neg_images = torch.cat(neg_images)
                neg_images_t = utils.transform(neg_images)

                # STEP 4: Compute output negative
                image_output_neg, _, _ = self.model(neg_images_t, None, [])

            # STEP 5: Compute losses
            if self.args.active_learning:
                image_output, image_output_neg = active_learning.switch_pos_neg(
                    self, image_input, image_output, image_output_neg, path)

            if self.loss_type == 'regular':
                loss = losses.sampled_margin_rank_loss(image_output,
                                                       audio_output, nframes,
                                                       self.margin,
                                                       self.args.symfun)
                loss_list['loss_regular'] = loss.item()
                loss_list['loss_total'] = loss.item()

            elif self.loss_type == 'negatives_edited':  # train with semantic negatives
                loss_regular = losses.sampled_margin_rank_loss(
                    image_output, audio_output, nframes, self.margin,
                    self.args.symfun)
                loss_neg = losses.negatives_loss(image_output, audio_output,
                                                 image_output_neg, nframes,
                                                 self.margin, self.args.symfun)
                loss = loss_regular + loss_neg
                loss_list['loss_regular'] = loss_regular.item()
                loss_list['loss_neg'] = loss_neg.item()
                loss_list['loss_total'] = loss.item()

            elif self.loss_type == 'negatives_hard':  # train with hard negatives
                loss_regular = losses.sampled_margin_rank_loss(
                    image_output, audio_output, nframes, self.margin,
                    self.args.symfun)
                loss_neg = losses.hard_negative_loss(image_output,
                                                     audio_output, nframes,
                                                     self.margin,
                                                     self.args.symfun)
                loss = loss_regular + loss_neg
                loss_list['loss_regular'] = loss_regular.item()
                loss_list['loss_neg'] = loss_neg.item()
                loss_list['loss_total'] = loss.item()

            elif self.loss_type == 'negatives_both':  # combine hard negatives with semantic negatives
                loss_hardneg = losses.combined_random_hard_negative_loss(
                    image_output, audio_output, image_output_neg, nframes,
                    self.margin, self.args.symfun)
                loss_regular = losses.sampled_margin_rank_loss(
                    image_output, audio_output, nframes, self.margin,
                    self.args.symfun)
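                # Clamp both terms so a temporarily diverging loss cannot
                # dominate the combined objective.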
                loss_regular = torch.clamp(loss_regular, min=0, max=5)
                loss_hardneg = torch.clamp(loss_hardneg, min=0, max=5)
                loss = loss_regular + loss_hardneg
                loss_list['loss_regular'] = loss_regular.item()
                loss_list['loss_hardneg'] = loss_hardneg.item()
                loss_list['loss_total'] = loss.item()

            else:
                raise Exception(
                    f'The loss function {self.loss_type} is not implemented.')

            last_sample = (N_examples * epoch +
                           batch_id * self.args.batch_size +
                           image_input.size(0))

            # Record loss
            losses_meter.update(loss.item(), image_input.size(0))

            # Backward pass and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # Print results
            if (batch_id + 1) % self.args.print_freq == 0:
                for name in loss_list:
                    loss_list_total[name] = (loss_list_total[name] +
                                             loss_list[name]) / self.args.print_freq

                for loss_name in loss_list:
                    self.args.writer.add_scalar(f'losses/{loss_name}',
                                                loss_list_total[loss_name],
                                                last_sample)

                print(
                    f'Epoch: [{epoch}][{batch_id+1}/{len(self.loaders["train"])}]\t'
                    f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    f'Loss {losses_meter.val:.4f} ({losses_meter.avg:.4f})\t',
                    flush=True)

                image_raw = self.unorm(image_input[0].data.cpu())
                self.args.writer.add_image('positive', image_raw, last_sample)
                if self.loss_type == 'negatives_edited' or self.loss_type == 'negatives_both':
                    image_raw_neg = self.unorm(neg_images[0].data.cpu())
                    image_neg = image_raw_neg / torch.max(image_raw_neg)
                    self.args.writer.add_image('negative', image_neg,
                                               last_sample)
                    self.args.writer.add_image(
                        'Images/region', 255 *
                        np.array([binary_mask_0, binary_mask_0, binary_mask_0
                                  ]).swapaxes(0, 1).swapaxes(1, 2),
                        last_sample)
                loss_list_total = {k: 0 for k in loss_list_total}

            else:
                for loss_name in loss_list:
                    loss_list_total[loss_name] += loss_list[loss_name]

    def optimize_neurons(self):
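        """
        Run GAN dissection with the cluster-based segmenter to rank, for each
        GAN layer, the units most associated with every cluster. The rankings
        are stored in self.layers_units and later used to build the
        interventions that generate edited negatives.
        """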

        # Set up console output
        verbose_progress(True)

        gan_model = self.generator.model
        annotate_model_shapes(gan_model, gen=True)

        outdir = os.path.join(
            self.args.results, 'dissect',
            self.args.name_checkpoint + '_' + str(time.time()))
        os.makedirs(outdir, exist_ok=True)

        size = 1000

        sample = z_sample_for_model(gan_model, size)

        train_sample = z_sample_for_model(gan_model, size, seed=2)

        dataset = TensorDataset(sample)
        train_dataset = TensorDataset(train_sample)
        self.cluster_segmenter = ClusterSegmenter(self.model, self.clusters,
                                                  self.mean_clust,
                                                  self.std_clust)

        segrunner = GeneratorSegRunner(self.cluster_segmenter)

        netname = outdir
        # Run dissect
        with torch.no_grad():
            dissect(
                outdir,
                gan_model,
                dataset,
                train_dataset=train_dataset,
                segrunner=segrunner,
                examples_per_unit=20,
                netname=netname,
                quantile_threshold='iqr',
                meta=None,
                make_images=False,  # True,
                make_labels=True,
                make_maxiou=False,
                make_covariance=False,
                make_report=True,
                make_row_images=True,
                make_single_images=True,
                batch_size=8,
                num_workers=8,
                rank_all_labels=True)

            sample_ablate = z_sample_for_model(gan_model, 16)

            dataset_ablate = TensorDataset(sample_ablate)
            data_loader = torch.utils.data.DataLoader(dataset_ablate,
                                                      batch_size=8,
                                                      shuffle=False,
                                                      num_workers=8,
                                                      pin_memory=True,
                                                      sampler=None)

            with open(os.path.join(outdir, 'dissect.json')) as f:
                data = EasyDict(json.load(f))
            dissect_layer = {lrec.layer: lrec for lrec in data.layers}

            self.layers_units = {
                'layer2': [],
                'layer3': [],
                'layer4': [],
                'layer5': [],
                'layer6': [],
            }

            noise_units = np.array([35, 221, 496, 280])
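            # Units assumed to generate noise artifacts; their ids are remapped
            # to unit 0 below so they are never selected for ablation.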

            for i in range(2, len(self.clusters) + 2):
                print('Cluster', i)
                rank_name = 'c_{0}-iou'.format(i)
                for l in range(len(self.layer_list_all)):
                    ranking = next(
                        r
                        for r in dissect_layer[self.layer_list_all[l]].rankings
                        if r.name == rank_name)
                    unit_list = np.arange(512)
                    unit_list[noise_units] = 0
                    ordering = np.argsort(ranking.score)
                    units_list = unit_list[ordering]
                    self.layers_units[self.layer_list_all[l]].append(
                        units_list)

        # Mark the directory so that it's not done again.
        mark_job_done(outdir)

    def get_ablation_replacement(self, params=(), option='random'):
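        """
        Build a per-unit ablation mask and replacement value for a GAN layer.
        'random' ablates a random subset of units (params = [dim, binary, prob]);
        'specific' ablates exactly the given unit ids (params = [dim, unit_ids]).
        The replacement is always zero.
        """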

        if option == 'random':
            dim_mask = params[0]
            binary = params[1]
            values = np.random.rand(dim_mask)

            if binary:
                prob_ones = params[2]
                ablation = torch.FloatTensor(
                    (np.random.rand(dim_mask) < prob_ones).astype(
                        np.float32)).cuda()
            else:
                ablation = torch.FloatTensor(values).cuda()
            replacement = torch.zeros(dim_mask).cuda()

        elif option == 'specific':
            units_ids = params[1]
            dim_mask = params[0]
            ablation, replacement = torch.zeros(dim_mask).cuda(), torch.zeros(
                dim_mask).cuda()
            ablation[units_ids] = 1  # color

        else:
            raise Exception("Please provide a valid option: 'random' or 'specific'")

        return ablation, replacement

    def eval(self):
        """
        Collects features for number_recall images and audios and computes the recall @{1, 5, 10} of predicting one from
        the other. It does not involve any hard or edited negative.
        """
        number_recall = 500
        if not self.args.use_cpu:
            torch.cuda.synchronize()
        batch_time = utils.AverageMeter()

        # Switch to evaluate mode
        self.model.eval()

        end = time.time()
        N_examples = len(self.loaders['val'].dataset)
        image_embeddings = []  # torch.FloatTensor(N_examples, embedding_dim)
        audio_embeddings = []  # torch.FloatTensor(N_examples, embedding_dim)
        frame_counts = []

        with torch.no_grad():
            for i, (image_input, audio_input, negatives, nframes, path,
                    _) in enumerate(self.loaders['val']):
                if len(image_embeddings) * image_input.size(0) > number_recall:
                    break

                if not self.args.loading_image:
                    # In case the audio is inside a subfolder, keep only the file name
                    path_ints = [p.split('/')[-1] for p in path]

                    v_init = self.z[int(path_ints[0])]
                    z_img = torch.FloatTensor(image_input.size(0),
                                              v_init.shape[0])

                    for k in range(image_input.size(0)):
                        z_img[k, :] = self.z[int(path_ints[k])]

                    image_input = self.generator.generate_images(
                        z_img, intervention=None)
                    image_input = utils.transform(image_input)
                    negatives = []
                else:
                    image_input = image_input.cuda()
                    negatives = [negatives.cuda()]

                # compute output
                model_output = self.model(image_input, audio_input, negatives)
                image_output = model_output[0]
                audio_output = model_output[1]

                image_embeddings.append(image_output.data.cpu())
                audio_embeddings.append(audio_output.data.cpu())

                # find pooling ratio
                # audio_input is (B, D, 40, T)
                # audio_output is (B, D, 1, T/p)
                pooling_ratio = round(
                    audio_input.size(3) / audio_output.size(3))
                nframes = nframes // pooling_ratio
                frame_counts.append(nframes.cpu())

                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.args.print_freq == 0:
                    print('Eval: [{0}/{1}]\t'.format(i + 1,
                                                     len(self.loaders['val'])),
                          flush=True)

            image_outputs = torch.cat(image_embeddings)
            audio_outputs = torch.cat(audio_embeddings)
            frame_counts_tensor = torch.cat(frame_counts)

            N_examples = np.minimum(number_recall, N_examples)

            image_outputs = image_outputs[-N_examples:, :, :, :]
            audio_outputs = audio_outputs[-N_examples:, :, :, :]
            frame_counts_tensor = frame_counts_tensor[-N_examples:]
            # measure accuracy and record loss
            print('Computing recalls...')
            recalls = utils.calc_recalls(image_outputs,
                                         audio_outputs,
                                         frame_counts_tensor,
                                         loss_type=self.loss_type)
            A_r10 = recalls['A_r10']
            I_r10 = recalls['I_r10']
            A_r5 = recalls['A_r5']
            I_r5 = recalls['I_r5']
            A_r1 = recalls['A_r1']
            I_r1 = recalls['I_r1']

            print(
                ' * Audio R@10 {A_r10:.3f} Image R@10 {I_r10:.3f} over {N:d} validation pairs'
                .format(A_r10=A_r10, I_r10=I_r10, N=N_examples),
                flush=True)
            print(
                ' * Audio R@5 {A_r5:.3f} Image R@5 {I_r5:.3f} over {N:d} validation pairs'
                .format(A_r5=A_r5, I_r5=I_r5, N=N_examples),
                flush=True)
            print(
                ' * Audio R@1 {A_r1:.3f} Image R@1 {I_r1:.3f} over {N:d} validation pairs'
                .format(A_r1=A_r1, I_r1=I_r1, N=N_examples),
                flush=True)

            eval_score = (A_r5 + I_r5) / 2

        return eval_score, recalls