def test_avg(dataloader, model, cam, element, evaluator, cluster_a, cluster_v,
             epoch, use_gpu):
    """Run localization inference over ``dataloader`` and score it.

    For every batch the audio and visual features returned by ``model`` are
    projected by ``model.discrim``, paired at each of the 256 (16x16)
    spatial positions, scored by ``model.discrim.auto_align`` and resized to
    a 256x256 map per sample.  Each map is visualized with ``cam_visualize``
    and handed to ``evaluator.cal_CIOU`` together with the ground-truth map.

    Args:
        dataloader: yields ``(audio, visual, roi, gtmap, box)`` batches.
        model: network exposing ``audio.gru`` and ``discrim`` sub-modules.
        cam: unused; kept for interface compatibility.
        element: unused; kept for interface compatibility.
        evaluator: object providing ``cal_CIOU``, ``final`` and ``cal_AUC``.
        cluster_a: forwarded unchanged to ``model``.
        cluster_v: forwarded unchanged to ``model``.
        epoch: unused; kept for interface compatibility.
        use_gpu: when true, inputs are moved to CUDA before the forward pass.

    Returns:
        The tuple ``(evaluator.final(), evaluator.cal_AUC())``.
    """
    model.eval()
    # NOTE(review): the GRU is put back into train mode right after eval();
    # the reason is not visible in this file -- keep the order as is.
    model.audio.gru.train()

    data_time = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(dataloader))

    for batch_idx, (audio, visual, roi, gtmap, box) in enumerate(dataloader):
        audio = audio.view(args.val_batch * args.mix, *audio.shape[-2:])
        visual = visual.view(args.val_batch * args.mix, *visual.shape[-3:])
        roi = roi.view(args.val_batch, args.mix, args.rois, 4)
        gt = gtmap.numpy()

        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
            roi = roi.cuda()

        # Record the time once per batch (loading + host-to-device copy);
        # this is the value shown in the progress bar.
        data_time.update(time.time() - end)

        # Pure inference: no autograd graph is needed for anything below.
        with torch.no_grad():
            _, feat_a, _, feat_v = model(audio, visual, roi, cluster_a,
                                         cluster_v, False)
            feat_v = model.discrim.spa_conv(feat_v)
            # Channels-last so each spatial position becomes one 512-d row.
            feat_v = feat_v.permute([0, 2, 3, 1]).contiguous()

            feat_a = model.discrim.temp_conv(feat_a)
            feat_a = model.discrim.temp_pool(feat_a)

            # Pair the single audio vector with every visual position.
            feat_a = feat_a.view(args.val_batch, 1, 512)
            feat = torch.cat(
                [feat_a.repeat(1, 256, 1),
                 feat_v.view(args.val_batch, 256, 512)], -1)
            feat = feat.view(-1, 1024)
            cams = model.discrim.auto_align(feat)
            cams = cams.view(args.val_batch, 256)
            # Lower auto_align output -> higher weight; then scale each
            # sample so that its maximum is 1.
            cams = torch.softmax(-cams * 10, -1)
            cams = cams / torch.max(cams, 1)[0].unsqueeze(-1)
            cams = torch.nn.functional.interpolate(
                cams.view(args.val_batch, 1, 16, 16), (256, 256),
                mode='bilinear')
        cams = cams.detach().cpu().numpy()

        for idx in range(args.val_batch):
            cam_visualize(batch_idx * args.val_batch + idx,
                          visual[idx:idx + 1], cams[idx], None, box[idx])
            ciou = evaluator.cal_CIOU(cams[idx], gt[idx], 0.1)

        # Restart the clock once the whole batch has been processed.
        end = time.time()
        # The bar shows the cIoU of the last sample in the batch.
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | CIOU: {ciou:.3f}'.format(
            batch=batch_idx + 1,
            size=len(dataloader),
            data=data_time.val,
            ciou=ciou)
        bar.next()

    bar.finish()
    return evaluator.final(), evaluator.cal_AUC()
示例#2
0
def extract(trainloader, model, discrimloss, epoch, use_gpu):
    """Run ``model`` over ``trainloader``, visualizing the maps it returns.

    For each batch the discriminator loss is computed with ``discrimloss``
    (only strictly positive values are recorded) and the returned map is
    passed to ``cam_visualize``.  Nothing is returned; ``epoch`` is unused.
    """
    net = model.model
    net.eval()
    net.audio.gru.train()

    load_meter = AverageMeter()
    dloss_meter = AverageMeter()
    tick = time.time()

    total = len(trainloader)
    progress = Bar('Processing', max=total)
    suffix_fmt = '({batch}/{size}) Data: {data:.3f}s |DLoss: {dloss:.3f}'

    for step, (audio, visual, _) in enumerate(trainloader):
        # Fold the batch/mix (and frame) axes into the leading dimension.
        n_clips = args.val_batch * args.mix
        audio = audio.view(n_clips, *audio.shape[-2:])
        visual = visual.view(n_clips * args.frame, *visual.shape[-3:])

        load_meter.update(time.time() - tick)
        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
        load_meter.update(time.time() - tick)

        discrim, cam = model(audio, visual)
        dloss = discrimloss(discrim[0], discrim[1], None)
        cam_visualize(step, visual, cam)

        # Only strictly positive losses enter the running meter.
        dloss_value = dloss.item()
        if dloss_value > 0:
            dloss_meter.update(dloss_value, 1)

        tick = time.time()
        progress.suffix = suffix_fmt.format(
            batch=step + 1,
            size=total,
            data=load_meter.val,
            dloss=dloss_meter.val,
        )
        progress.next()

    progress.finish()
def extract(trainloader, model, discrimloss, epoch, use_gpu):
    """Run ``model`` over ``trainloader`` and save the returned maps.

    The map returned for batch ``i`` is written to row ``i`` of a
    ``(1098, 10, 16, 16)`` array, which is stored with ``np.save('infer', ...)``
    once the loader is exhausted.  ``discrimloss`` and ``epoch`` are unused.
    """
    net = model.model
    net.eval()
    net.discrim.train()
    net.audio.gru.train()

    load_meter = AverageMeter()
    dis_loss = AverageMeter()
    tick = time.time()
    # Output buffer: one row per batch index (size fixed by the original code).
    collected = np.zeros((1098, 10, 16, 16))

    total = len(trainloader)
    progress = Bar('Processing', max=total)
    suffix_fmt = '({batch}/{size}) Data: {data:.3f}s'

    for step, (audio, visual, _) in enumerate(trainloader):
        # Fold the leading batch axis with the mix (and frame) axes.
        audio = audio.view(audio.shape[0] * args.mix, *audio.shape[-2:])
        visual = visual.view(visual.shape[0] * args.mix * args.frame,
                             *visual.shape[-3:])

        load_meter.update(time.time() - tick)
        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
        load_meter.update(time.time() - tick)

        discrim, cam = model(audio, visual)
        collected[step] = cam.detach().cpu().numpy()

        tick = time.time()
        progress.suffix = suffix_fmt.format(
            batch=step + 1,
            size=total,
            data=load_meter.val,
        )
        progress.next()

    progress.finish()
    np.save('infer', collected)
示例#4
0
def validate(val_loader, model, criterion, args, writer, epoch):
    """Evaluate ``model`` on ``val_loader`` and return ``(top1_avg, top5_avg)``.

    Loss and top-1/top-5 accuracy are accumulated in running meters, shown
    every ``args.print_freq`` batches, and written to ``writer`` (when it is
    not ``None``) with ``epoch`` as the global step.
    """
    meter_opts = {"write_val": False}
    batch_time = AverageMeter("Time", ":6.3f", **meter_opts)
    losses = AverageMeter("Loss", ":.3f", **meter_opts)
    top1 = AverageMeter("Acc@1", ":6.2f", **meter_opts)
    top5 = AverageMeter("Acc@5", ":6.2f", **meter_opts)

    n_batches = len(val_loader)
    progress = ProgressMeter(
        n_batches, [batch_time, losses, top1, top5], prefix="Test: "
    )

    # Evaluation mode for the whole pass.
    model.eval()

    with torch.no_grad():
        last = time.time()
        batches = tqdm.tqdm(enumerate(val_loader), ascii=True, total=n_batches)
        for step, (images, target) in batches:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)

            # Targets are always moved, independent of args.gpu.
            target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass and loss.
            output = model(images)
            loss = criterion(output, target)

            # Accuracy, then update the running meters in a fixed order.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n_samples = images.size(0)
            for meter, value in ((losses, loss), (top1, acc1), (top5, acc5)):
                meter.update(value.item(), n_samples)

            # Wall-clock time for this batch.
            batch_time.update(time.time() - last)
            last = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        progress.display(n_batches)

        if writer is not None:
            progress.write_to_tensorboard(writer, prefix="test", global_step=epoch)

    return top1.avg, top5.avg
示例#5
0
def train(train_loader, model, criterion, optimizer, epoch, args, writer):
    """Train ``model`` for one epoch and return ``(top1_avg, top5_avg)``.

    Every ``args.print_freq`` batches the running meters are displayed and
    written to ``writer`` using the number of samples seen so far
    (``(num_batches * epoch + i) * batch_size``) as the global step.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")

    num_batches = len(train_loader)
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top5],
        prefix=f"Epoch: [{epoch}]",
    )

    # Training mode for the whole pass.
    model.train()

    batch_size = train_loader.batch_size
    last = time.time()
    batches = tqdm.tqdm(enumerate(train_loader), ascii=True, total=num_batches)
    for step, (images, target) in batches:
        # Time spent waiting for the batch.
        data_time.update(time.time() - last)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        # Targets are always moved, independent of args.gpu.
        target = target.cuda(args.gpu, non_blocking=True)

        # Forward pass and loss.
        output = model(images)
        loss = criterion(output, target)

        # Accuracy, then update the running meters in a fixed order.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n_samples = images.size(0)
        for meter, value in ((losses, loss), (top1, acc1), (top5, acc5)):
            meter.update(value.item(), n_samples)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for this batch.
        batch_time.update(time.time() - last)
        last = time.time()

        if step % args.print_freq == 0:
            samples_seen = (num_batches * epoch + step) * batch_size
            progress.display(step)
            progress.write_to_tensorboard(
                writer, prefix="train", global_step=samples_seen
            )

    return top1.avg, top5.avg
    def validate(self, epoch):
        """Evaluate ``self.model`` on the ``'test'`` loader.

        Logs running loss / top-1 / top-5 every ``self.args.print_freq``
        batches through ``self._print_log`` and writes the epoch averages to
        ``self.writer``.

        Args:
            epoch: zero-based epoch index (logged as ``epoch + 1``, used
                unchanged as the TensorBoard step).

        Returns:
            The average top-1 precision over the loader.
        """
        self._print_log('Validating epoch: {}'.format(epoch + 1), 'validate')

        self.model.eval()
        loader = self.data_loader['test']
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        for batch_idx, (inputs, target) in enumerate(loader, 0):
            # Tensors are used directly; the deprecated autograd.Variable
            # wrapper is a no-op on PyTorch versions that provide no_grad().
            target = target.cuda().long()
            inputs = inputs.cuda()
            with torch.no_grad():
                output = self.model(inputs)
                loss = self.criterion(output, target)

            prec1, prec5 = self.evaluator.accuracy(output.data, target, topk=(1, 5))

            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % self.args.print_freq == 0:
                message = ('Time: [{0}/{1}]\t'
                           'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                           'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                           'Prec@5 {top5.val:.3f} ({top5.avg:.3f})').format(
                    batch_idx, len(loader), batch_time=batch_time,
                    loss=losses, top1=top1, top5=top5)
                self._print_log(message, 'validate')

        self._print_log(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.4f}'
                         .format(top1=top1, top5=top5, loss=losses)), 'validate')

        self.writer.add_scalar('Validate/Loss', losses.avg, epoch)
        self.writer.add_scalar('Validate/Prec@1', top1.avg, epoch)
        self.writer.add_scalar('Validate/Prec@5', top5.avg, epoch)

        return top1.avg
    def train(self, epoch):
        """Train ``self.model`` for one epoch on the ``'train'`` loader.

        Per-iteration loss and precision are written to ``self.writer`` under
        ``Train/loss``, ``Train/prec@1`` and ``Train/prec@5``; epoch averages
        go to ``Train/Loss``, ``Train/Prec@1`` and ``Train/Prec@5``.  After
        the epoch a histogram of every named parameter is written as well.

        Args:
            epoch: zero-based epoch index (logged as ``epoch + 1``).
        """
        self._print_log('Training epoch: {}'.format(epoch + 1), 'train')

        self.model.train()
        loader = self.data_loader['train']
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        for batch_idx, (inputs, target) in enumerate(loader, 0):
            data_time.update(time.time() - end)

            # Tensors are used directly; the deprecated autograd.Variable
            # wrapper adds nothing on PyTorch versions that have .item().
            target = target.cuda().long()
            inputs = inputs.cuda()

            output = self.model(inputs)
            loss = self.criterion(output, target)

            prec1, prec5 = self.evaluator.accuracy(output.data, target, topk=(1, 5))

            # Read the scalars once; they feed both the meters and the writer.
            loss_value = loss.item()
            prec1_value = prec1.item()
            prec5_value = prec5.item()
            batch_size = inputs.size(0)

            losses.update(loss_value, batch_size)
            top1.update(prec1_value, batch_size)
            top5.update(prec5_value, batch_size)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # One-based global iteration counter across epochs.
            n_iter = epoch * len(loader) + batch_idx + 1

            batch_time.update(time.time() - end)
            end = time.time()

            self.writer.add_scalar('Train/loss', loss_value, n_iter)
            self.writer.add_scalar('Train/prec@1', prec1_value, n_iter)
            self.writer.add_scalar('Train/prec@5', prec5_value, n_iter)

            if batch_idx % self.args.print_freq == 0:
                message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                           'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                           'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                           'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                           'Prec@5 {top5.val:.3f} ({top5.avg:.3f})').format(
                    epoch, batch_idx, len(loader), batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1, top5=top5, lr=self.optimizer.param_groups[-1]['lr'])
                self._print_log(message, 'train')

        self._print_log(('Training Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.4f}'
                         .format(top1=top1, top5=top5, loss=losses)), 'train')

        self.writer.add_scalar('Train/Loss', losses.avg, epoch)
        self.writer.add_scalar('Train/Prec@1', top1.avg, epoch)
        self.writer.add_scalar('Train/Prec@5', top5.avg, epoch)

        for name, param in self.model.named_parameters():
            # splitext splits at the last dot: "a.b.weight" -> ("a.b", ".weight"),
            # giving the histogram tag "a.b/weight".
            layer, attr = os.path.splitext(name)
            attr = attr[1:]
            self.writer.add_histogram("{}/{}".format(layer, attr), param, epoch)
示例#8
0
def test_avs(valloader, model, cluster_a, cluster_v, epoch, use_gpu):
    """Measure how well the discriminator output separates its two halves.

    The two tensors in the model's first output are flattened and
    concatenated.  In the first half, entries ``< 1.0`` count as correct
    ("match"); in the second half, entries ``> 1.0`` count as correct
    ("differ").  Both rates are averaged over batches.

    Args:
        valloader: yields ``(audio, visual, roi)`` batches.
        model: network called as ``model(audio, visual, roi, cluster_a, cluster_v)``
            and returning eight values.
        cluster_a: forwarded unchanged to ``model``.
        cluster_v: forwarded unchanged to ``model``.
        epoch: unused; kept for interface compatibility.
        use_gpu: when true, inputs are moved to CUDA before the forward pass.

    Returns:
        The tuple ``(match_rate.avg, differ_rate.avg)``.
    """
    model.eval()
    data_time = AverageMeter()
    match_rate = AverageMeter()
    differ_rate = AverageMeter()
    end = time.time()
    bar = Bar('Processing', max=len(valloader))

    for batch_idx, (audio, visual, roi) in enumerate(valloader):

        audio = audio.view(args.val_batch * args.mix, *audio.shape[-2:])
        visual = visual.view(args.val_batch * args.mix * args.frame, *visual.shape[-3:])
        roi = roi.view(args.val_batch, args.mix * args.frame, args.rois, 4)

        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
            roi = roi.cuda()

        # Record the time once per batch (loading + host-to-device copy);
        # this is the value shown in the progress bar.
        data_time.update(time.time() - end)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            discrim, mask, pred_a, pred_v, label_a, label_v, _, _ = model(
                audio, visual, roi, cluster_a, cluster_v)
            discrim = [discrim[0].view(-1, 1), discrim[1].view(-1, 1)]
            discrim = torch.cat(discrim, 0)

        # First half of the rows comes from discrim[0], second from discrim[1].
        segments = discrim.shape[0]
        true_match = torch.sum(discrim[: segments//2, 0]<1.0)
        true_match = true_match.item() / float(segments/2)
        true_differ = torch.sum(discrim[segments//2:, 0]>1.0)
        true_differ = true_differ.item() / float(segments/2)
        match_rate.update(true_match, 1)
        differ_rate.update(true_differ, 1)

        end = time.time()
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s |Match: {match:.3f} |Differ: {differ: .3f}'.format(
            batch=batch_idx + 1,
            size=len(valloader),
            data=data_time.val,
            match=match_rate.val,
            differ=differ_rate.val
        )
        bar.next()

    bar.finish()
    return match_rate.avg, differ_rate.avg
示例#9
0
def train_avs(trainloader, model, cluster_a, cluster_v, discrimloss, maploss, optimizer, epoch, use_gpu):
    """Train for one pass over ``trainloader`` with gradient accumulation.

    The model is put in eval mode and then ``audio.gru``, ``visual.layer5``
    and ``discrim`` (plus both loss modules) are switched back to train mode.
    Gradients are accumulated over ``args.its`` consecutive batches before
    each optimizer step.

    Args:
        trainloader: yields ``(audio, visual, roi)`` batches.
        model: network called as ``model(audio, visual, roi, cluster_a, cluster_v)``
            and returning eight values.
        cluster_a: forwarded unchanged to ``model``.
        cluster_v: forwarded unchanged to ``model``.
        discrimloss: called as ``discrimloss(discrim[0], discrim[1], mask)``.
        maploss: called as ``maploss(label_a, *label_v, pred_a, *pred_v)`` and
            returning ``(aloss, vloss, rloss)``.
        optimizer: optimizer stepped every ``args.its`` batches.
        epoch: unused; kept for interface compatibility.
        use_gpu: when true, inputs are moved to CUDA before the forward pass.

    Returns:
        The running average of the total loss (positive values only).
    """
    model.eval()
    model.audio.gru.train()
    model.visual.layer5.train()
    model.discrim.train()
    discrimloss.train()
    maploss.train()

    data_time = AverageMeter()
    a_loss = AverageMeter()
    v_loss = AverageMeter()
    r_loss = AverageMeter()
    dis_loss = AverageMeter()
    total_loss = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(trainloader))
    optimizer.zero_grad()

    for batch_idx, (audio, visual, roi) in enumerate(trainloader):

        audio = audio.view(args.train_batch * args.mix, *audio.shape[-2:])
        visual = visual.view(args.train_batch * args.mix * args.frame, *visual.shape[-3:])
        roi = roi.view(args.train_batch, args.mix * args.frame, args.rois, 4)

        if batch_idx == args.wp:
            warm_up_lr(optimizer, False)

        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
            roi = roi.cuda()

        # Record the time once per batch (loading + host-to-device copy);
        # this is the value shown in the progress bar.
        data_time.update(time.time() - end)
        discrim, mask, pred_a, pred_v, label_a, label_v, feat, _ = model(audio, visual, roi, cluster_a, cluster_v)
        dloss = discrimloss(discrim[0], discrim[1], mask)
        aloss, vloss, rloss = maploss(label_a, *label_v, pred_a, *pred_v)
        loss = dloss + vloss / float(7) + rloss + aloss

        # Only strictly positive totals enter the running meters.
        if loss.item() > 0:
            total_loss.update(loss.item(), 1)
            # NOTE(review): aloss is logged divided by 7 although it enters
            # the total loss unscaled -- confirm which one is intended.
            a_loss.update(aloss.item() / 7, 1)
            v_loss.update(vloss.item() / 7, 1)
            r_loss.update(rloss.item(), 1)
            dis_loss.update(dloss.item(), 1)

        # Scale so that the accumulated gradient is an average over args.its batches.
        loss = loss / args.its
        loss.backward()

        # Step only after a full group of args.its batches has been
        # accumulated (testing batch_idx alone would step right after the
        # very first batch).
        if (batch_idx + 1) % args.its == 0:
            optimizer.step()
            optimizer.zero_grad()

        end = time.time()
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s |Loss: {loss:.3f} |ALoss: {aloss:.3f} |VLoss: {vloss:.3f} |RLoss: {rloss: .3f} |DLoss: {dloss:.3f}'.format(
            batch=batch_idx + 1,
            size=len(trainloader),
            data=data_time.val,
            loss=total_loss.val,
            aloss=a_loss.val,
            vloss=v_loss.val,
            rloss=r_loss.val,
            dloss=dis_loss.val
        )
        bar.next()

    bar.finish()

    return total_loss.avg
示例#10
0
    def _valid_epoch(self, epoch):
        """Run one validation pass and return the averaged losses.

        For every validation batch the reconstruction loss of the real data
        and of freshly generated samples is computed, followed by the
        generator loss on a second set of generated samples.  No gradients
        are tracked and no optimizer is stepped.

        Args:
            epoch: epoch number, used only in printed messages.

        Returns:
            A dict with the running averages under the keys ``'dlr'``
            (real reconstruction loss), ``'dlf'`` (fake reconstruction loss)
            and ``'g_loss'`` (generator loss).
        """
        self.discriminator.eval()
        self.generator.eval()

        if self.verbosity > 2:
            print("Validation at epoch {}".format(epoch))

        batch_time = AverageMeter()
        data_time = AverageMeter()
        dlr = AverageMeter()
        dlf = AverageMeter()
        g_loss = AverageMeter()

        end_time = time.time()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                data_time.update(time.time() - end_time)

                D_loss_real = self.get_reconstruction_loss(data)

                z = self._sample_z(self.data_loader.batch_size, self.noise_dim)
                z = z.to(self.device)
                x_fake = self.generator(z)

                D_loss_fake = self.get_reconstruction_loss(x_fake.detach())

                # The combined discriminator loss is not needed here (nothing
                # is optimized).  It used to be built with an in-place ``+=``
                # on an alias of D_loss_real, which silently added the margin
                # term to the "real" loss that is logged below.

                z = self._sample_z(self.data_loader.batch_size, self.noise_dim)
                z = z.to(self.device)
                x_fake = self.generator(z)
                D_fake, D_latent = self.discriminator(x_fake)

                G_loss = self._generator_loss(x_fake, D_fake, D_latent)

                batch_time.update(time.time() - end_time)
                # Restart the clock so both timers measure a single batch,
                # as the training loop does.
                end_time = time.time()

                dlr.update(D_loss_real.item(), data.size(0))
                dlf.update(D_loss_fake.item(), data.size(0))
                g_loss.update(G_loss.item(), z.size(0))

                if self.verbosity >= 2:
                    print(
                        'Epoch: {} [{}/{} ({:.0f}%)]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Discriminator Loss (Real) {dlr.val:.4f} ({dlr.avg:.4f})\t'
                        'Discriminator Loss (Fake) {dlf.val:.4f} ({dlf.avg:.4f})\t'
                        'Generator Loss {g_loss.val:.4f} ({g_loss.avg:.4f})\t'.
                        format(epoch,
                               batch_idx * self.valid_data_loader.batch_size,
                               self.data_loader.n_valid_samples,
                               100.0 * batch_idx / len(self.valid_data_loader),
                               batch_time=batch_time,
                               data_time=data_time,
                               dlr=dlr,
                               dlf=dlf,
                               g_loss=g_loss))

        log = {
            'dlr': dlr.avg,
            'dlf': dlf.avg,
            'g_loss': g_loss.avg,
        }

        return log
示例#11
0
    def _train_epoch(self, epoch):
        """Train the discriminator and the generator for one epoch.

        Each batch performs one discriminator update followed by one
        generator update.  The discriminator loss is the reconstruction loss
        on real data plus, while the reconstruction loss on generated data
        is below ``self.margin``, the term ``margin - fake_loss``.

        Args:
            epoch: the current training epoch (int), used in log messages
                and passed on to ``_valid_epoch``.

        Returns:
            A dict with the running averages ``'dlr'``, ``'dlf'`` and
            ``'g_loss'``; when ``self.do_validation`` is set, the entries
            returned by ``_valid_epoch`` are merged in and override keys of
            the same name.
        """
        if self.verbosity > 2:
            print("Train at epoch {}".format(epoch))

        self.generator.train()
        self.discriminator.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()

        dlr = AverageMeter()
        dlf = AverageMeter()
        g_loss = AverageMeter()

        end_time = time.time()
        for batch_idx, data in enumerate(self.data_loader):
            data_time.update(time.time() - end_time)

            # ---- Train the discriminator ----
            D_loss_real = self.get_reconstruction_loss(data)

            z = self._sample_z(self.data_loader.batch_size, self.noise_dim)
            z = z.to(self.device)
            x_fake = self.generator(z)

            # detach(): the generator must not receive gradients from this step.
            D_loss_fake = self.get_reconstruction_loss(x_fake.detach())

            # Build the total out of place.  An in-place ``+=`` on an alias
            # of D_loss_real would also change the "real" loss logged below.
            D_loss = D_loss_real
            if D_loss_fake.item() < self.margin:
                D_loss = D_loss_real + (self.margin - D_loss_fake)

            self.d_optimizer.zero_grad()
            D_loss.backward()
            self.d_optimizer.step()

            # ---- Train the generator ----
            z = self._sample_z(self.data_loader.batch_size, self.noise_dim)
            z = z.to(self.device)
            x_fake = self.generator(z)
            D_fake, D_latent = self.discriminator(x_fake)

            G_loss = self._generator_loss(x_fake, D_fake, D_latent)

            self.g_optimizer.zero_grad()
            G_loss.backward()
            self.g_optimizer.step()

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            dlr.update(D_loss_real.item(), data.size(0))
            dlf.update(D_loss_fake.item(), data.size(0))
            g_loss.update(G_loss.item(), z.size(0))

            if self.verbosity >= 2:
                info = (
                    'Epoch: {} [{}/{} ({:.0f}%)]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Discriminator Loss (Real) {dlr.val:.4f} ({dlr.avg:.4f})\t'
                    'Discriminator Loss (Fake) {dlf.val:.4f} ({dlf.avg:.4f})\t'
                    'Generator Loss (Real) {g_loss.val:.4f} ({g_loss.avg:.4f})\t'
                ).format(epoch,
                         batch_idx * self.data_loader.batch_size,
                         self.data_loader.n_samples,
                         100.0 * batch_idx / len(self.data_loader),
                         batch_time=batch_time,
                         data_time=data_time,
                         dlr=dlr,
                         dlf=dlf,
                         g_loss=g_loss)
                # Logger and image summary only every log_step batches;
                # the console print happens for every batch.
                if batch_idx % self.log_step == 0:
                    self.logger.info(info)
                    self.writer.add_image(
                        'inp', make_grid(data.cpu(), nrow=8, normalize=True))
                print(info)

        log = {
            'dlr': dlr.avg,
            'dlf': dlf.avg,
            'g_loss': g_loss.avg,
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        if self.g_lr_scheduler is not None:
            self.g_lr_scheduler.step()
        if self.d_lr_scheduler is not None:
            self.d_lr_scheduler.step()

        return log
示例#12
0
def main_worker(args):
    """Build model, optimizer and data from ``args``, then train or evaluate.

    With ``args.evaluate`` set, a single validation pass is run and the
    function returns.  Otherwise the model is trained from
    ``args.start_epoch`` to ``args.epochs``: each epoch runs the trainer and
    the validator obtained from ``get_trainer``, tracks the best accuracies,
    saves checkpoints, and writes timing / learning-rate diagnostics to a
    ``SummaryWriter``.  A summary row is written with ``write_result_to_csv``
    at the end.

    Args:
        args: configuration namespace; ``args.gpu``, ``args.ckpt_base_dir``,
            ``args.start_epoch`` and (for ``SampleSubnetConv``)
            ``args.prune_rate`` are overwritten here.

    Returns:
        None.
    """
    args.gpu = None
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)

    if args.pretrained:
        pretrained(args, model)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Evaluation-only mode: one validation pass, then stop.
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    # Both stay None when the epoch loop below never runs, so the final
    # summary can still be written.
    acc1 = None
    acc5 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / "initial.state",
        save=False,
    )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch (time recorded in minutes)
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set (time recorded in minutes)
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer,
                              epoch)
        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        # Check the sign first: with save_every == 0 the modulo would raise
        # ZeroDivisionError before the guard could disable periodic saving.
        save = args.save_every > 0 and (epoch % args.save_every) == 0
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        if args.conv_type == "SampleSubnetConv":
            count = 0
            sum_pr = 0.0
            for n, m in model.named_modules():
                if isinstance(m, SampleSubnetConv):
                    # avg pr across 10 samples
                    pr = 0.0
                    for _ in range(10):
                        pr += ((torch.rand_like(m.clamped_scores) >=
                                m.clamped_scores).float().mean().item())
                    pr /= 10.0
                    writer.add_scalar("pr/{}".format(n), pr, epoch)
                    sum_pr += pr
                    count += 1

            # Without any matching module there is no average to report.
            if count > 0:
                args.prune_rate = sum_pr / count
                writer.add_scalar("pr/average", args.prune_rate, epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
def test_cam(valloader, model, cam, epoch, use_gpu):
    """Compute audio-visual score maps for ``valloader`` and save them.

    For every batch, audio and visual features are projected by
    ``model.discrim``, paired for every (frame, audio slot, spatial position)
    combination, and compared through the mean squared difference of their
    ``transform_a`` embeddings.  The resulting scores are weighted by the
    visual CAM, scaled per row, resized to 256x256 and stored in row
    ``batch_idx`` of a ``(1098, 10, 15, 16, 16)`` array that is written with
    ``np.save('infer', ...)`` at the end.

    Args:
        valloader: yields ``(audio, visual, label)`` batches.
        model: network exposing ``audio.gru``, ``discrim`` and ``avalign``.
        cam: pair ``(camaudio, camvisual)`` of callables.
        epoch: unused; kept for interface compatibility.
        use_gpu: when true, inputs are moved to CUDA before the forward pass.
    """
    model.eval()
    model.audio.gru.train()
    model.discrim.train()
    data_time = AverageMeter()
    end = time.time()
    camaudio = cam[0]
    camvisual = cam[1]
    infermap = np.zeros((1098, 10, 15, 16, 16))
    # effidx / result are only consumed by the disabled generate_bbox call below.
    effidx = np.load('audioset_val.npy')
    result = []
    bar = Bar('Processing', max=len(valloader))

    for batch_idx, (audio, visual, label) in enumerate(valloader):

        audio = audio.view(audio.shape[0] * args.mix, *audio.shape[-2:])
        visual = visual.view(visual.shape[0] * args.mix * args.frame, *visual.shape[-3:])
        label = label.view(label.shape[0] * args.mix, label.shape[-1])

        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
            label = label.cuda()

        # Record the time once per batch (loading + host-to-device copy);
        # this is the value shown in the progress bar.
        data_time.update(time.time() - end)

        # Pure inference: no autograd graph is needed for anything below.
        with torch.no_grad():
            # discrim, pred_a, pred_v, feat_a, feat_v, cam_v = model(audio, visual)
            pred_a, pred_v, feat_a, feat_v, cam_v = model(audio, visual, False)
            feat_a, _, _ = camaudio(pred_a, feat_a)
            feat_a = model.discrim.temp_conv(feat_a)
            # NOTE(review): pooling comes from model.avalign while the
            # convolution comes from model.discrim -- confirm this is intended.
            feat_a = model.avalign.temp_pool(feat_a)
            cam_v = camvisual(cam_v, feat_v)
            cam_v = torch.max(cam_v, 1)[0]
            feat_v = model.discrim.spa_conv(feat_v)
            # Broadcast audio over frames/positions and visual over the 15
            # audio slots so both end up with the same 5-D layout.
            feat_a = feat_a.view(-1, 1, 15, feat_a.shape[1], 1)
            feat_v = feat_v.view(-1, args.frame, 1, feat_v.shape[1],
                                 feat_v.shape[-2]*feat_v.shape[-1])
            feat_a = feat_a.repeat([1, args.frame, 1, 1, feat_v.shape[-1]])
            feat_v = feat_v.repeat([1, 1, 15, 1, 1])
            feat = torch.cat([feat_a.permute(0, 1, 2, 4, 3).contiguous(),
                              feat_v.permute(0, 1, 2, 4, 3).contiguous()], -1)
            # score = model.discrim.discrim(feat)
            # score = torch.softmax(score, -1)[:, :, :, :, 1]
            # score = score.view(-1, 15, 16, 16)
            embed = feat.view(-1, 1024)
            # NOTE(review): transform_a is applied to both halves of the
            # embedding -- confirm a separate visual transform is not meant.
            # reduction='none' is the current spelling of reduce=False.
            dist = F.mse_loss(model.discrim.transform_a(embed[:, :512]),
                              model.discrim.transform_a(embed[:, 512:]),
                              reduction='none')
            dist = dist.mean(-1)
            dist = dist.view(*feat.shape[:4])
            # Smaller distance -> larger score; weight by the visual CAM and
            # scale each row by its maximum (epsilon avoids division by zero).
            score = torch.softmax(-dist, -1) * cam_v.view(1, 10, 1, 256)
            score = score / torch.max(score+1e-10, -1)[0].unsqueeze(-1)
            score = score.view(-1, 15, 16, 16)
            score = torch.nn.functional.interpolate(score, (256, 256), mode='bilinear')
        score = score.detach().cpu().numpy()
        infermap[batch_idx] = score
        # generate_bbox(batch_idx, pred_a, pred_v, score, effidx, result)
        # cam_visualize(batch_idx, visual, score)

        end = time.time()
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s'.format(
            batch=batch_idx + 1,
            size=len(valloader),
            data=data_time.val
        )
        bar.next()

    bar.finish()
    np.save('infer', infermap)
def test_cls(valloader, model, epoch, use_gpu):
    """Evaluate audio classification rates and dump per-class visual maps.

    For every batch the model is run on ``(audio, visual)``; the audio
    predictions are compared with ``label`` and a per-class map is built by
    multiplying ``model.visual.fc.weight`` with the ``layer4`` features.
    All maps are collected in one array that is written with
    ``np.save('infer', ...)`` after the loop.

    Args:
        valloader: iterable yielding ``(audio, visual, label)`` batches.
        model: network exposing ``model(audio, visual)`` and a ``visual``
            sub-module with ``layer4`` and ``fc``.
        epoch: unused.
        use_gpu: when true, the batch tensors are moved with ``.cuda()``.

    Returns:
        ``(match_rate.avg, differ_rate.avg)`` from the two meters.
    """
    model.eval()
    # model.discrim.train()
    data_time = AverageMeter()
    match_rate = AverageMeter()
    differ_rate = AverageMeter()
    end = time.time()
    bar = Bar('Processing', max=len(valloader))
    # Hard-coded output buffer: at most 1098 batches, each producing a
    # (10, 15, 16, 16) map -- TODO confirm these match the dataset/config.
    infermap = np.zeros((1098, 10, 15, 16, 16))
    # NOTE(review): effidx and result are only consumed by the disabled
    # generate_bbox call below; kept so that call can be re-enabled as is.
    effidx = np.load('audioset_val.npy')
    result = []

    for batch_idx, (audio, visual, label) in enumerate(valloader):

        audio = audio.view(audio.shape[0] * args.mix, *audio.shape[-2:])
        visual = visual.view(visual.shape[0] * args.mix * args.frame, *visual.shape[-3:])
        label = label.view(label.shape[0] * args.mix, label.shape[-1])

        data_time.update(time.time() - end)
        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
            label = label.cuda()

        data_time.update(time.time() - end)
        # Pure inference: nothing is back-propagated, so skip graph building.
        with torch.no_grad():
            pred_a, pred_v = model(audio, visual)
            visual_ = model.visual(visual, 'bottom')
            visual_ = model.visual.layer4(visual_)
            # assumes 10 frames of 512-channel 16x16 features -- TODO confirm
            visual_ = visual_.view(10, 512, 256)
            # Project the features with the classifier weights to obtain one
            # spatial map per fc output row.
            pred = torch.matmul(model.visual.fc.weight.unsqueeze(0), visual_)
            pred = pred.view(10, 15, 16, 16)
        # pred = torch.nn.functional.interpolate(pred, (256, 256), mode='bilinear')
        # pred = torch.relu(pred)
        pred = pred.detach().cpu().numpy()
        infermap[batch_idx] = pred
        # generate_bbox(batch_idx, pred_a, pred_v, pred, effidx, result)
        # cam_visualize(batch_idx, visual, pred)
        match = (pred_a>0.5) * (label>0)
        differ = (pred_a<0.5) * (label==0)
        if torch.sum(label) > 0:
            match_rate.update(torch.sum(match).item() / float(torch.sum(label)))
        # Guard the negative-label rate the same way as the positive one: a
        # batch without any zero label would otherwise divide by zero.
        num_negative = torch.sum(label==0).item()
        if num_negative > 0:
            differ_rate.update(torch.sum(differ).item() / float(num_negative))

        end = time.time()
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s |Match: {match:.3f} |Differ: {differ: .3f}'.format(
            batch=batch_idx + 1,
            size=len(valloader),
            data=data_time.val,
            match=match_rate.val,
            differ=differ_rate.val
        )
        bar.next()

    bar.finish()
    np.save('infer', infermap)
    # with open('bbox.json', 'w') as f:
    #     json.dump(result, f)

    return match_rate.avg, differ_rate.avg
def test_avs(valloader, model, epoch, use_gpu):
    """Score the model's two-way discriminator output over ``valloader``.

    The first value returned by the model, ``discrim``, is split into
    ``discrim[0]`` and ``discrim[1]``, each reshaped to rows of two scores.
    A row of ``discrim[0]`` counts as correct when column 1 exceeds column 0;
    a row of ``discrim[1]`` counts as correct when column 0 exceeds column 1.
    (Presumably these are matched and mismatched audio-visual pairs --
    confirm against the model definition.)

    Args:
        valloader: iterable yielding ``(audio, visual, label)`` batches.
        model: network called as ``model(audio, visual)``; it must return
            six values, of which only the first is used here.
        epoch: unused.
        use_gpu: when true, the batch tensors are moved with ``.cuda()``.

    Returns:
        ``(match_rate.avg, differ_rate.avg)`` from the two meters that track
        the per-batch correct fractions.
    """
    model.eval()
    # model.discrim.train()
    load_meter = AverageMeter()
    match_meter = AverageMeter()
    differ_meter = AverageMeter()
    tick = time.time()
    progress = Bar('Processing', max=len(valloader))
    suffix_fmt = '({batch}/{size}) Data: {data:.3f}s |Match: {match:.3f} |Differ: {differ: .3f}'

    for step, (audio, visual, label) in enumerate(valloader, 1):

        # Fold the mix (and frame) dimensions into the leading batch axis.
        audio_rows = audio.shape[0] * args.mix
        audio = audio.view(audio_rows, *audio.shape[-2:])
        visual_rows = visual.shape[0] * args.mix * args.frame
        visual = visual.view(visual_rows, *visual.shape[-3:])
        label_rows = label.shape[0] * args.mix
        label = label.view(label_rows, label.shape[-1])

        load_meter.update(time.time() - tick)
        if use_gpu:
            audio, visual, label = audio.cuda(), visual.cuda(), label.cuda()

        load_meter.update(time.time() - tick)
        # Only the discriminator output is evaluated; the rest is discarded.
        discrim, _, _, _, _, _ = model(audio, visual)
        common_scores = discrim[0].view(-1, 2)
        differ_scores = discrim[1].view(-1, 2)

        common_hits = torch.sum(common_scores[:, 1] > common_scores[:, 0])
        common_acc = common_hits.item() / float(common_scores.shape[0])
        differ_hits = torch.sum(differ_scores[:, 0] > differ_scores[:, 1])
        differ_acc = differ_hits.item() / float(differ_scores.shape[0])

        match_meter.update(common_acc, 1)
        differ_meter.update(differ_acc, 1)

        tick = time.time()
        progress.suffix = suffix_fmt.format(
            batch=step,
            size=len(valloader),
            data=load_meter.val,
            match=match_meter.val,
            differ=differ_meter.val
        )
        progress.next()

    progress.finish()
    return match_meter.avg, differ_meter.avg
def train_cam(trainloader, model, cam, discrimloss, maploss, alignloss, optimizer, epoch, use_gpu):
    """Train for one epoch on the classification and alignment losses.

    Each batch is run through the model, the audio/visual CAM helpers and
    ``model.avalign``; the total loss is
    ``aloss / 15 + vloss / 15 + eloss``. Gradients are accumulated over
    ``args.its`` batches before each optimizer step, and any incomplete
    window left at the end of the epoch is applied as well.

    Args:
        trainloader: iterable yielding ``(audio, visual, label)`` batches.
        model: network called as ``model(audio, visual)`` that returns
            ``(pred_a, pred_v, feat_a, feat_v, cam_v)`` and exposes
            ``audio.gru``, ``discrim`` and ``avalign``.
        cam: indexable pair; ``cam[0]`` is applied to the audio branch and
            ``cam[1]`` to the visual branch.
        discrimloss: only switched to train mode here, not otherwise used.
        maploss: called as ``maploss(label, pred_a, pred_v)``; returns the
            audio and visual loss terms.
        alignloss: called as ``alignloss(common, differ)``.
        optimizer: optimizer stepped once per accumulation window.
        epoch: unused.
        use_gpu: when true, the batch tensors are moved with ``.cuda()``.

    Returns:
        ``total_loss.avg`` from the meter tracking the un-scaled total loss.
    """
    # The model as a whole stays in eval mode; only the GRU, the discrim
    # sub-module and the loss modules are switched to train mode.
    # NOTE(review): discrim is put in train mode while the forward pass below
    # goes through model.avalign -- confirm this is the intended sub-module.
    model.eval()
    model.audio.gru.train()
    model.discrim.train()
    discrimloss.train()
    maploss.train()
    alignloss.train()
    camaudio = cam[0]
    camvisual = cam[1]

    # Divisor applied to both map losses (presumably the number of
    # categories -- TODO confirm).
    map_loss_div = 15.0

    data_time = AverageMeter()
    a_loss = AverageMeter()
    v_loss = AverageMeter()
    e_loss = AverageMeter()
    total_loss = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(trainloader))
    optimizer.zero_grad()
    # Number of batches whose gradients are accumulated but not yet applied.
    pending = 0

    for batch_idx, (audio, visual, label) in enumerate(trainloader):

        audio = audio.view(audio.shape[0] * args.mix, *audio.shape[-2:])
        visual = visual.view(visual.shape[0] * args.mix * args.frame, *visual.shape[-3:])
        label = label.view(label.shape[0] * args.mix, label.shape[-1])

        if batch_idx == args.wp:
            warm_up_lr(optimizer, False)

        data_time.update(time.time() - end)
        if use_gpu:
            audio = audio.cuda()
            visual = visual.cuda()
            label = label.cuda()

        data_time.update(time.time() - end)
        pred_a, pred_v, feat_a, feat_v, cam_v = model(audio, visual)
        feat_a, _, _ = camaudio(pred_a, feat_a)
        cam_v = camvisual(cam_v, feat_v)
        common, differ = model.avalign(feat_a, feat_v, label, cam_v)
        eloss = alignloss(common, differ)
        aloss, vloss = maploss(label, pred_a, pred_v)
        loss = aloss / map_loss_div + vloss / map_loss_div + eloss

        # Meters only record batches with a strictly positive loss; the
        # backward pass below runs regardless.
        if loss.item() > 0:
            total_loss.update(loss.item(), 1)
            a_loss.update(aloss.item() / map_loss_div, 1)
            v_loss.update(vloss.item() / map_loss_div, 1)
            e_loss.update(eloss.item(), 1)

        # Scale so that a full window of args.its batches sums to their mean.
        loss /= args.its
        loss.backward()
        pending += 1

        # Step only once a full window has been accumulated. The previous
        # `batch_idx % args.its == 0` test stepped right after the first
        # batch and left every later window out of phase.
        if pending >= args.its:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        end = time.time()
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s |Loss: {loss:.3f} |ALoss: {aloss:.3f} ' \
                     '|VLoss: {vloss:.3f} |ELoss: {eloss:.3f}'.format(
            batch=batch_idx + 1,
            size=len(trainloader),
            data=data_time.val,
            loss=total_loss.val,
            aloss=a_loss.val,
            vloss=v_loss.val,
            eloss=e_loss.val
        )
        bar.next()

    # Apply gradients from a trailing, incomplete window instead of letting
    # the next call's zero_grad() discard them.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    bar.finish()

    return total_loss.avg