Example 1
def evaluate(netWrapper, loader, history, epoch, args):
    print('Evaluating at {} epochs...'.format(epoch))
    torch.set_grad_enabled(False)
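    # note: this disables autograd globally and is never re-enabled here;
    # the caller must turn it back on before training resumes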

    # remove previous viz results
    makedirs(args.vis, remove=True)

    # switch to eval mode
    netWrapper.eval()

    # initialize meters
    loss_meter = AverageMeter()
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # initialize HTML header
    visualizer = HTMLVisualizer(os.path.join(args.vis, 'index.html'))
    header = ['Filename', 'Input Mixed Audio']
    for n in range(1, args.num_mix + 1):
        header += [
            'Video {:d}'.format(n), 'Predicted Audio {:d}'.format(n),
            'GroundTruth Audio {}'.format(n), 'Predicted Mask {}'.format(n),
            'GroundTruth Mask {}'.format(n)
        ]
    header += ['Loss weighting']
    visualizer.add_header(header)
    vis_rows = []

    for i, batch_data in enumerate(loader):
        # forward pass
        err, _, g, outputs = netWrapper.forward(batch_data, args)
        err = err.mean()

        loss_meter.update(err.item())
        print('[Eval] iter {}, loss: {:.4f}'.format(i, err.item()))
        # grounding accuracy: round the soft grounding scores to {0, 1}
        # and count correct assignments
        g0 = np.round(g[0][:, 0].detach().cpu().numpy())
        g1 = np.round(g[1][:, 1].detach().cpu().numpy())
        grd_acc = np.sum(g0 + g1) / (2 * len(g0))

        g_mix = [np.round(g[2][k][:, col].detach().cpu().numpy())
                 for k, col in enumerate((0, 1, 0, 1))]
        grd_mix_acc = np.sum(sum(g_mix)) / (4 * len(g_mix[0]))

        g_solo = [np.round(g[3][k][:, col].detach().cpu().numpy())
                  for k, col in enumerate((0, 1, 0, 1))]
        grd_solo_acc = np.sum(sum(g_solo)) / (4 * len(g_solo[0]))

        print('Grounding acc: {:.2f}, Solo Grounding acc: {:.2f}, '
              'Sep Grounding acc: {:.2f}'.format(grd_acc, grd_solo_acc,
                                                 grd_mix_acc))

        # calculate metrics
        sdr_mix, sdr, sir, sar = calc_metrics(batch_data, outputs, args)

        sdr_mix_meter.update(sdr_mix)
        sdr_meter.update(sdr)
        sir_meter.update(sir)
        sar_meter.update(sar)
        # output visualization (the num_vis cap is disabled here, so every
        # batch is rendered)
        output_visuals(vis_rows, batch_data, outputs, args)

    print('[Eval Summary] Epoch: {}, Loss: {:.4f}, '
          'SDR_mixture: {:.4f}, SDR: {:.4f}, SIR: {:.4f}, SAR: {:.4f}'.format(
              epoch, loss_meter.average(), sdr_mix_meter.average(),
              sdr_meter.average(), sir_meter.average(), sar_meter.average()))
    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['sdr'].append(sdr_meter.average())
    history['val']['sir'].append(sir_meter.average())
    history['val']['sar'].append(sar_meter.average())

    print('Plotting html for visualization...')
    visualizer.add_rows(vis_rows)
    visualizer.write_html()

    # Plot figure
    if epoch > 0:
        print('Plotting figures...')
        plot_loss_metrics(args.ckpt, history)
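
Every evaluation loop in these examples relies on an AverageMeter helper that is not shown. A minimal sketch consistent with how it is used here (update() per batch, average() for the running mean, and sum_value() as queried in Example 2) might look like the following; the attribute names are assumptions:

class AverageMeter(object):
    """Tracks the running sum, count, and mean of scalar values."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # running sum
        self.count = 0   # number of accumulated samples

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    def average(self):
        return self.sum / self.count if self.count > 0 else 0.0

    def sum_value(self):
        return self.sum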
Example 2
    def evaluate(self, loader):
        print('Evaluating at {} epochs...'.format(self.epoch))
        torch.set_grad_enabled(False)

        # remove previous viz results
        makedirs(self.args.vis, remove=True)

        self.netwrapper.eval()

        # initialize meters
        loss_meter = AverageMeter()
        sdr_mix_meter = AverageMeter()
        sdr_meter = AverageMeter()
        sir_meter = AverageMeter()
        sar_meter = AverageMeter()

        # initialize HTML header
        visualizer = HTMLVisualizer(os.path.join(self.args.vis, 'index.html'))
        header = ['Filename', 'Input Mixed Audio']
        for n in range(1, self.args.num_mix + 1):
            header += [
                'Video {:d}'.format(n), 'Predicted Audio {:d}'.format(n),
                'GroundTruth Audio {}'.format(n),
                'Predicted Mask {}'.format(n), 'GroundTruth Mask {}'.format(n)
            ]
        header += ['Loss weighting']
        visualizer.add_header(header)
        vis_rows = []
        eval_num = 0
        valid_num = 0

        for i, batch_data in enumerate(loader):
            # forward pass
            eval_num += batch_data['mag_mix'].shape[0]
            with torch.no_grad():
                err, outputs = self.netwrapper.forward(batch_data, self.args)
                err = err.mean()

            if self.mode == 'train':
                self.writer.add_scalar('data/val_loss', err,
                                       self.args.epoch_iters * self.epoch + i)

            loss_meter.update(err.item())
            print('[Eval] iter {}, loss: {:.4f}'.format(i, err.item()))

            # calculate metrics
            sdr_mix, sdr, sir, sar, cur_valid_num = calc_metrics(
                batch_data, outputs, self.args)
            print("sdr_mix, sdr, sir, sar: ", sdr_mix, sdr, sir, sar)
            sdr_mix_meter.update(sdr_mix)
            sdr_meter.update(sdr)
            sir_meter.update(sir)
            sar_meter.update(sar)
            valid_num += cur_valid_num
            # output visualization (disabled)
            # if len(vis_rows) < self.args.num_vis:
            #     output_visuals(vis_rows, batch_data, outputs, self.args)
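
        # note: the summary below normalizes each meter's sum by eval_num
        # (the number of evaluated samples), whereas Example 1 reports
        # average(), the mean over update() calls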
        metric_output = '[Eval Summary] Epoch: {}, Loss: {:.4f}, ' \
            'SDR_mixture: {:.4f}, SDR: {:.4f}, SIR: {:.4f}, SAR: {:.4f}'.format(
                self.epoch, loss_meter.average(),
                sdr_mix_meter.sum_value()/eval_num,
                sdr_meter.sum_value()/eval_num,
                sir_meter.sum_value()/eval_num,
                sar_meter.sum_value()/eval_num
        )
        if valid_num / eval_num < 0.8:
            metric_output += ' ---- Invalid ---- '

        print(metric_output)
        learning_rate = ' lr_sound: {}, lr_frame: {}'.format(
            self.args.lr_sound, self.args.lr_frame)
        with open(self.args.log, 'a') as f:
            f.write(metric_output + learning_rate + '\n')

        self.history['val']['epoch'].append(self.epoch)
        self.history['val']['err'].append(loss_meter.average())
        self.history['val']['sdr'].append(sdr_meter.sum_value() / eval_num)
        self.history['val']['sir'].append(sir_meter.sum_value() / eval_num)
        self.history['val']['sar'].append(sar_meter.sum_value() / eval_num)
        # print('Plotting html for visualization...')
        # visualizer.add_rows(vis_rows)
        # visualizer.write_html()
        # Plot figure
        if self.epoch > 0:
            print('Plotting figures...')
            plot_loss_metrics(self.args.ckpt, self.history)
Example 3
def evaluate_adv(netWrapper, loader, history, epoch, args):
    print('Evaluating at {} epochs...'.format(epoch))
    criterion = nn.CrossEntropyLoss()
    # note: autograd stays enabled here because the attacks need input gradients

    # switch to eval mode
    netWrapper.eval()
    # initialize meters
    loss_meter = AverageMeter()

    fig = plt.figure()
    # (visual, audio) perturbation budgets: clean, audio-only, visual-only,
    # and joint attacks
    ep = 0.006
    epsilons = [[0, 0], [0, ep], [ep, 0], [ep, ep]]
    for epsilon in epsilons:
        cos_sim = []
        # initialize HTML header
        visualizer = HTMLVisualizer(os.path.join(args.vis, 'index.html'))
        header = ['Filename']
        for n in range(1, args.num_mix + 1):
            header += [
                'Original Image {:d}'.format(n), 'Adv. Image {:d}'.format(n),
                'Original Audio {}'.format(n), 'Adv. Audio {}'.format(n)
            ]
        visualizer.add_header(header)
        vis_rows = []
        correct = 0
        adv_correct = 0
        total = 0

        for i, batch_data in enumerate(loader):
            audios = batch_data['audios']
            frames = batch_data['frames']
            gts = batch_data['labels']

            audio = audios[0].to(args.device)
            frame = frames[0].to(args.device).squeeze(2)
            gt = gts[0].to(args.device)

            if args.attack_type == "fsgm":
                data_viz = []
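                # FGSM needs gradients w.r.t. the inputs (not the weights),
                # so requires_grad is turned on for both modalities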
                frame.requires_grad = True
                audio.requires_grad = True

                # forward pass
                preds, feat_v, feat_a = netWrapper(frame, audio)
                netWrapper.zero_grad()
                err = criterion(preds, gt) + F.cosine_similarity(
                    feat_v, feat_a, 1).mean()
                err.backward()

                # original frame and audio
                frame_ori = inv_norm_tensor(frame.clone())
                data_viz.append(frame_ori)
                data_viz.append(audio)

                # Add perturbation
                if args.arch_classifier != "audio":
                    frame_adv = frame + epsilon[0] * torch.sign(
                        frame.grad.data)
                else:
                    frame_adv = frame
                frame_adv = inv_norm_tensor(frame_adv.clone())

                frame_adv = torch.clamp(frame_adv, 0, 1)
                data_viz.append(frame_adv)
                frame_adv = norm_tensor(frame_adv.clone())

                if args.arch_classifier != "visual":
                    audio_adv = audio + epsilon[1] * torch.sign(
                        audio.grad.data)
                else:
                    audio_adv = audio
                data_viz.append(audio_adv)

                adv_preds, feat_v, feat_a = netWrapper(frame_adv, audio_adv)
                sim = F.cosine_similarity(feat_v, feat_a, -1)
                cos_sim = np.concatenate((cos_sim, sim.detach().cpu().numpy()),
                                         axis=0)
            elif args.attack_type == "pgd":
                # original frame and audio
                data_viz = []
                frame_ori = inv_norm_tensor(frame.clone())
                data_viz.append(frame_ori)
                data_viz.append(audio)
                preds, _, _ = netWrapper(frame, audio)
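                # 10 PGD iterations with step size epsilon/8, so the iterate
                # can reach the epsilon-ball boundary before being clipped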
                alpha_v = epsilon[0] / 8
                alpha_a = epsilon[1] / 8
                frame_adv = frame.clone().detach()
                audio_adv = audio.clone().detach()
                for t in range(10):
                    frame_adv.requires_grad = True
                    audio_adv.requires_grad = True
                    # forward pass
                    preds_iter, feat_v, feat_a = netWrapper(
                        frame_adv, audio_adv)
                    netWrapper.zero_grad()
                    err = criterion(preds_iter, gt) + F.cosine_similarity(
                        feat_v, feat_a, 1).mean()
                    err.backward()

                    # Add perturbation
                    if args.arch_classifier in ["concat", "visual"]:
                        frame_adv = frame_adv.detach() + alpha_v * torch.sign(
                            frame_adv.grad.data)
                        eta = torch.clamp(frame_adv - frame,
                                          min=-epsilon[0],
                                          max=epsilon[0])
                        frame_adv = (frame + eta).detach_()
                    else:
                        frame_adv = frame.detach()
                    frame_adv = inv_norm_tensor(frame_adv.clone())
                    frame_adv = torch.clamp(frame_adv, 0, 1)
                    frame_adv = norm_tensor(frame_adv.clone())

                    if args.arch_classifier in ["concat", "audio"]:
                        audio_adv = audio_adv.detach() + alpha_a * torch.sign(
                            audio_adv.grad.data)
                        eta = torch.clamp(audio_adv - audio,
                                          min=-epsilon[1],
                                          max=epsilon[1])
                        audio_adv = torch.clamp(audio + eta, min=-1,
                                                max=1).detach_()
                    else:
                        audio_adv = audio.detach()

                data_viz.append(
                    torch.clamp(inv_norm_tensor(frame_adv.clone()), 0, 1))
                data_viz.append(audio_adv)
                adv_preds, _, _ = netWrapper(frame_adv, audio_adv)
            elif args.attack_type == "mim":
                # original frame and audio
                data_viz = []
                frame_ori = inv_norm_tensor(frame.clone())
                data_viz.append(frame_ori)
                data_viz.append(audio)
                preds, _, _ = netWrapper(frame, audio)

                alpha_v = epsilon[0] / 8
                alpha_a = epsilon[1] / 8
                frame_adv = frame.clone().detach()
                audio_adv = audio.clone().detach()
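                # MIM keeps a momentum buffer of L1-normalized gradients so
                # the update direction is smoothed across iterations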
                momentum_v = torch.zeros_like(frame).to(args.device)
                momentum_a = torch.zeros_like(audio).to(args.device)

                for t in range(10):
                    frame_adv.requires_grad = True
                    audio_adv.requires_grad = True
                    # forward pass
                    preds_iter, feat_v, feat_a = netWrapper(
                        frame_adv, audio_adv)
                    netWrapper.zero_grad()
                    err = criterion(preds_iter, gt) + F.cosine_similarity(
                        feat_v, feat_a, 1).mean()
                    err.backward()

                    # Add perturbation
                    if args.arch_classifier in ["concat", "visual"]:
                        grad = frame_adv.grad.data
                        grad_norm = torch.norm(grad, p=1)
                        grad /= grad_norm
                        grad += momentum_v * 1.0
                        momentum_v = grad
                        frame_adv = frame_adv.detach() + alpha_v * torch.sign(grad)
                        # project back into the epsilon-ball around the
                        # clean frame
                        lower = torch.clamp(frame - epsilon[0], min=0)
                        upper = frame + epsilon[0]
                        frame_adv = torch.min(torch.max(frame_adv, lower),
                                              upper).detach_()
                    else:
                        frame_adv = frame.detach()
                    frame_adv = inv_norm_tensor(frame_adv.clone())
                    frame_adv = torch.clamp(frame_adv, 0, 1)
                    frame_adv = norm_tensor(frame_adv.clone())

                    if args.arch_classifier in ["concat", "audio"]:
                        grad = audio_adv.grad.data
                        grad_norm = torch.norm(grad, p=1)
                        grad /= grad_norm
                        grad += momentum_a * 1.0
                        momentum_a = grad
                        audio_adv = audio_adv.detach() + alpha_a * torch.sign(grad)
                        # project back into the epsilon-ball around the
                        # clean audio, then clip to the valid range
                        lower = torch.clamp(audio - epsilon[1], min=-1)
                        upper = audio + epsilon[1]
                        audio_adv = torch.min(torch.max(audio_adv, lower), upper)
                        audio_adv = torch.clamp(audio_adv, min=-1,
                                                max=1).detach_()
                    else:
                        audio_adv = audio.detach()

                data_viz.append(
                    torch.clamp(inv_norm_tensor(frame_adv.clone()), 0, 1))
                data_viz.append(audio_adv)
                adv_preds, _, _ = netWrapper(frame_adv, audio_adv)
            else:
                raise ValueError(
                    'Unknown attack type: {}'.format(args.attack_type))

            _, predicted = torch.max(preds.data, 1)
            total += preds.size(0)
            correct += (predicted == gt).sum().item()

            _, predicted = torch.max(adv_preds.data, 1)
            adv_correct += (predicted == gt).sum().item()

            loss_meter.update(err.item())
            # print('[Eval] iter {}, loss: {:.4f}'.format(i, err.item()))

            # viz
            output_visuals(vis_rows, batch_data, data_viz, args)

        print('[Eval Summary] Epoch: {}, Loss: {:.4f}'.format(
            epoch, loss_meter.average()))
        history['val']['epoch'].append(epoch)
        history['val']['err'].append(loss_meter.average())

        print(
            'Accuracy of the audio-visual event recognition network: %.2f %%' %
            (100 * correct / total))
        print(
            'adv Accuracy of the audio-visual event recognition network: %.2f %%'
            % (100 * adv_correct / total))

        print('Plotting html for visualization...')
        visualizer.add_rows(vis_rows)
        visualizer.write_html()

        # Plot figure
        if epoch > 0:
            print('Plotting figures...')
            plot_loss_metrics(args.ckpt, history)

        # note: cos_sim is only populated by the "fsgm" branch
        plt.plot(cos_sim,
                 label='v: {} + a: {} acc: {:.1f}'.format(
                     epsilon[0] * 1e3, epsilon[1] * 1e3,
                     100 * adv_correct / total))
    plt.legend()
    fig.savefig(os.path.join(args.ckpt, 'cos_sim.png'), dpi=200)
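
The three attack branches in Example 3 share one core step: nudge an input in the direction of the sign of the loss gradient. A distilled single-tensor FGSM helper in the same style (the function name and signature are assumptions, not part of the original code) would be:

import torch
import torch.nn as nn

def fgsm_perturb(model, x, target, epsilon, criterion=None):
    """One-step FGSM: x_adv = x + epsilon * sign(dL/dx)."""
    criterion = criterion or nn.CrossEntropyLoss()
    x = x.clone().detach().requires_grad_(True)   # track input gradients
    loss = criterion(model(x), target)
    model.zero_grad()
    loss.backward()                               # populate x.grad
    return (x + epsilon * x.grad.sign()).detach()

The PGD and MIM branches repeat this step 10 times with a smaller step size, projecting back into the epsilon-ball after each update (MIM additionally smooths the gradient with momentum).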
Example 4
def evaluate(netWrapper, loader, history, epoch, args):
    print('Evaluating at {} epochs...'.format(epoch))
    with torch.no_grad():
        tic = time.perf_counter()

        # remove previous viz results
        makedirs(args.vis, remove=True)

        # switch to eval mode
        netWrapper.eval()

        # initialize meters
        loss_meter = AverageMeter()
        sdr_mix_meter = AverageMeter()
        sdr_meter = AverageMeter()
        sir_meter = AverageMeter()
        sar_meter = AverageMeter()

        # initialize HTML header
        visualizer = HTMLVisualizer(os.path.join(args.vis, 'index.html'))
        header = ['Filename', 'Input Mixed Audio']
        for n in range(1, args.num_mix + 1):
            header += [
                'Video {:d}'.format(n), 'Predicted Audio {:d}'.format(n),
                'GroundTruth Audio {}'.format(n),
                'Predicted Mask {}'.format(n), 'GroundTruth Mask {}'.format(n)
            ]
        header += ['Loss weighting']
        visualizer.add_header(header)
        vis_rows = []

        for batch_data in tqdm(loader):
            # forward pass
            err, outputs = netWrapper.forward(batch_data, args)
            err = err.mean()

            loss_meter.update(err.item())
            outputs = loader.dataset._dump_stft(outputs, batch_data,
                                                args)  # compute mag, mask

            # calculate metrics (TODO: speed this up; flagged as a possible BUG)
            sdr_mix, sdr, sir, sar = calc_metrics(batch_data, outputs, args)
            sdr_mix_meter.update(sdr_mix)
            sdr_meter.update(sdr)
            sir_meter.update(sir)
            sar_meter.update(sar)

            # output visualization
            if len(vis_rows) < args.num_vis:
                output_visuals(vis_rows, batch_data, outputs, args)

        print('[Eval Summary] Epoch: {}, Time: {:.2f} Loss: {:.4f}, '
              'SDR_mixture: {:.4f}, SDR: {:.4f}, SIR: {:.4f}, SAR: {:.4f}'.
              format(epoch,
                     time.perf_counter() - tic, loss_meter.average(),
                     sdr_mix_meter.average(), sdr_meter.average(),
                     sir_meter.average(), sar_meter.average()))
        history['val']['epoch'].append(epoch)
        history['val']['err'].append(loss_meter.average())
        history['val']['sdr'].append(sdr_meter.average())
        history['val']['sir'].append(sir_meter.average())
        history['val']['sar'].append(sar_meter.average())

        print('Plotting html for visualization...')
        visualizer.add_rows(vis_rows)
        visualizer.write_html()

        # Plot figure
        if epoch > 0:
            print('Plotting figures...')
            plot_loss_metrics(args.ckpt, history)
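
Because Example 1 disables autograd globally with torch.set_grad_enabled(False), any training driver that calls these evaluate functions must restore it afterwards. A minimal sketch of such a loop (num_epochs, train, loader_train, loader_val, and args.eval_epoch are assumptions):

for epoch in range(1, num_epochs + 1):
    torch.set_grad_enabled(True)  # re-enable autograd after evaluation
    train(netWrapper, loader_train, optimizer, history, epoch, args)
    if epoch % args.eval_epoch == 0:
        evaluate(netWrapper, loader_val, history, epoch, args)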