# --- Beispiel (Example) #1 — scraped snippet separator; the stray "0" was a vote count ---
        # NOTE(review): fragment of a pretraining inner loop — the enclosing
        # `for idx ...` / epoch loops and the definitions of s_idx,
        # s_train_batch, t_train_batch, t_train1_batch, model, loss, args,
        # device, pretrain_optimizer and pretrain_scheduler are outside this
        # view.
        # Wrap indices so the shorter loaders are cycled while iterating
        # to the length of the longest one.
        t_idx = idx % len(t_train_batch)
        t1_idx = idx % len(t_train1_batch)

        s_data = s_train_batch[s_idx]
        t_data = t_train_batch[t_idx]
        t1_data = t_train1_batch[t1_idx]  # NOTE(review): fetched but unused below — confirm intent

        # Concatenate source and target samples along the batch dimension
        # and run both through the model in a single forward pass.
        img = torch.cat((s_data[0], t_data[0]), 0)
        img = img.to(device)

        model_output = model(img)

        recons, y0, y0_orig, f, f2 = model_output['recons'], model_output[
            'y0'], model_output['y0_orig'], model_output['f'], model_output[
                'f2']
        # Image-space reconstruction loss over the full (source+target) batch.
        # mean/std=0.5 presumably matches the input normalization — TODO confirm.
        img_re_loss, img_re_loss_per_sample = loss.get_reconstruction_loss(
            img, recons, mean=0.5, std=0.5)
        # Embedding-space reconstruction loss on the first args.batch_size
        # entries only (the source half of the concatenated batch).
        embedding_re_loss, embedding_re_loss_per_sample = loss.get_reconstruction_loss(
            f[:args.batch_size], f2[:args.batch_size], mean=0.5, std=0.5)

        tot_loss = img_re_loss + embedding_re_loss

        # Standard optimizer step on the combined loss.
        pretrain_optimizer.zero_grad()
        tot_loss.backward()
        pretrain_optimizer.step()
    pretrain_scheduler.step()  # LR schedule advances once per (unseen) epoch loop

# Forward pass bookkeeping: collect reconstruction losses (pretrain stage)
# so their distribution can be visualized later.
target_anomaly_label_total = []
target_recon_loss_total = {'img': []}
source_recon_loss_total = {}
# --- Beispiel (Example) #2 — scraped snippet separator; the stray "0" was a vote count ---
    # NOTE(review): body of a per-epoch training routine — the enclosing
    # function/epoch loop and model, args, device, optimizer, loss,
    # train_writer, epoch and tqdm are defined outside this fragment.
    model.train()
    # Running sums for epoch-level logging.
    tr_re_loss, tr_mem_loss, tr_tot = 0.0, 0.0, 0.0
    progress_bar = tqdm(train_batch)

    for batch_idx, frame in enumerate(progress_bar):
        progress_bar.update()
        # Reshape the flat batch to (B, T, C, H, W), then reorder to
        # (B, C, T, H, W) — the layout the model expects.
        frame = frame.reshape(
            [args.batch_size, args.t_length, args.c, args.h, args.w])
        frame = frame.permute(0, 2, 1, 3, 4)
        frame = frame.to(device)
        optimizer.zero_grad()

        model_output = model(frame)
        recons, attr = model_output['output'], model_output['att']
        # Reconstruction loss; mean/std=0.5 presumably matches the input
        # normalization — TODO confirm.
        re_loss = loss.get_reconstruction_loss(frame,
                                               recons,
                                               mean=0.5,
                                               std=0.5)
        # Memory-attention sparsity regularizer, weighted below.
        mem_loss = loss.get_memory_loss(attr)
        tot_loss = re_loss + mem_loss * args.EntropyLossWeight
        tr_re_loss += re_loss.data.item()
        tr_mem_loss += mem_loss.data.item()
        tr_tot += tot_loss.data.item()

        tot_loss.backward()
        optimizer.step()

    # Log per-batch averages for this epoch.
    train_writer.add_scalar("model/train-recons-loss",
                            tr_re_loss / len(train_batch), epoch)
    train_writer.add_scalar("model/train-memory-sparse",
                            tr_mem_loss / len(train_batch), epoch)
    train_writer.add_scalar("model/train-total-loss",  # NOTE(review): call truncated in this view of the file
# --- Beispiel (Example) #3 — scraped snippet separator; the stray "0" was a vote count ---
                # NOTE(review): inner body of a nested pretraining loop — the
                # surrounding loops and s_idx/t_idx, the data loaders, model,
                # loss, args, device, pretrain_optimizer, writer and epoch
                # are all defined outside this fragment.
                s_data = s_train_batch[s_idx]
                t_data = t_train_batch[t_idx]

                # Stack source and target clips into one batch of size
                # 2 * args.batch_size, reshape to (B, T, C, H, W), then
                # permute to (B, C, T, H, W) for the model.
                img = torch.cat((s_data[0], t_data[0]), 0)
                img = img.reshape([
                    args.batch_size * 2, args.t_length, args.c, args.h, args.w
                ])
                img = img.permute(0, 2, 1, 3, 4)
                img = img.to(device)

                model_output = model(img)

                recons = model_output['output']
                # Source half: the first args.batch_size samples.
                s_re_loss, s_re_loss_per_sample = loss.get_reconstruction_loss(
                    img[:args.batch_size],
                    recons[:args.batch_size],
                    mean=0.5,
                    std=0.5)
                # Target half: remaining samples; weighted by args.trade_off.
                t_re_loss, t_re_loss_per_sample = loss.get_reconstruction_loss(
                    img[args.batch_size:],
                    recons[args.batch_size:],
                    mean=0.5,
                    std=0.5)
                tot_loss = s_re_loss + args.trade_off * t_re_loss
                pretrain_optimizer.zero_grad()
                tot_loss.backward()
                pretrain_optimizer.step()
                # Log once per epoch, at the last index of the longer loader.
                if idx == max(len(s_train_batch), len(t_train_batch)) - 1:
                    writer.add_scalar(
                        's_re_loss/AdversarialLossWeight_' +
                        str(args.AdversarialLossWeight), s_re_loss, epoch)
# NOTE(review): evaluation script fragment — model, device, t_batch, loss,
# np, metrics, log_dir, args and calculationAUC_plot are defined earlier in
# the (unseen) part of the file.
model.to(device)
model.eval()

# Test

img_recon_error_list = list()
embedding_recon_error_list = list()
anomaly_label = list()

# NOTE(review): no torch.no_grad() around this loop — gradients are tracked
# during evaluation; confirm whether that is intentional.
for batch_idz, data in enumerate(t_batch):
    label = data[1]
    img = data[0].to(device)
    model_output = model(img)
    recons, y0, y0_orig, f, f2 = model_output['recons'], model_output[
        'y0'], model_output['y0_orig'], model_output['f'], model_output['f2']
    # Image-space and embedding-space reconstruction errors for this batch.
    img_re_loss, img_re_loss_per_sample = loss.get_reconstruction_loss(
        img, recons, mean=0.5, std=0.5)
    embedding_re_loss, embedding_re_loss_per_sample = loss.get_reconstruction_loss(
        f, f2, mean=0.5, std=0.5)

    # NOTE(review): labels grow per-sample but only ONE scalar error is
    # appended per batch — these lengths match only if the test loader's
    # batch size is 1; verify, since roc_curve below needs equal lengths.
    anomaly_label += label.tolist()
    img_recon_error_list.append(float(img_re_loss.cpu()))
    embedding_recon_error_list.append(float(embedding_re_loss.cpu()))

print('anomaly_label:', np.array(anomaly_label))
print('img_recon_error_list:', np.array(img_recon_error_list))
print('embedding_recon_error_list:', np.array(embedding_recon_error_list))

# ROC/AUC with image-space reconstruction error as the anomaly score.
fpr, tpr, thresholds = metrics.roc_curve(np.array(anomaly_label),
                                         np.array(img_recon_error_list))
calculationAUC_plot(fpr, tpr, log_dir, args, 'img_recon')
fpr, tpr, thresholds = metrics.roc_curve(np.array(anomaly_label),
            # NOTE(review): fragment of a pretraining loop body — the
            # enclosing loops and idx, the data loaders, model, loss, args,
            # device, pretrain_optimizer, writer and epoch are defined
            # outside this view.
            s_idx = idx % len(s_train_batch)
            t_idx = idx % len(t_train_batch)

            s_data = s_train_batch[s_idx]
            t_data = t_train_batch[t_idx]

            # Concatenate source and target samples along the batch axis.
            img = torch.cat((s_data[0], t_data[0]), 0)
            img = img.to(device)

            model_output = model(img)

            # NOTE(review): this fragment reads 'out' where sibling fragments
            # read 'output' — model variants apparently differ; confirm key.
            recons = model_output['out']
            # Reconstruction loss on the source half (first args.batch_size).
            s_re_loss, s_re_loss_per_sample = loss.get_reconstruction_loss(
                img[:args.batch_size],
                recons[:args.batch_size],
                mean=0.5,
                std=0.5)
            # Reconstruction loss on the target half, weighted by trade_off.
            t_re_loss, t_re_loss_per_sample = loss.get_reconstruction_loss(
                img[args.batch_size:],
                recons[args.batch_size:],
                mean=0.5,
                std=0.5)
            tot_loss = s_re_loss + args.trade_off * t_re_loss
            pretrain_optimizer.zero_grad()
            tot_loss.backward()
            pretrain_optimizer.step()
            # Log both halves once per epoch, at the last batch index.
            if idx == max(len(s_train_batch), len(t_train_batch)) - 1:
                writer.add_scalar('s_re_loss/trade_off_' + str(args.trade_off),
                                  s_re_loss, epoch)
                writer.add_scalar('t_re_loss/trade_off_' + str(args.trade_off),  # NOTE(review): call truncated in this view
# --- Beispiel (Example) #6 — scraped snippet separator; the stray "0" was a vote count ---
# NOTE(review): evaluation fragment — model, device, t_batch, loss and args
# are defined earlier in the (unseen) part of the file; the else-branch loop
# at the bottom is truncated by the scrape.
model.eval()

# Test

recon_error_list = list()
anomaly_label = list()
y_total = list()
if args.ModelName == 'AdversarialAE':
    for batch_idx, data in enumerate(t_batch):
        label = data[1]
        img = data[0].to(device)
        model_output = model(img)
        # AdversarialAE additionally returns domain predictions ('y', 'y0').
        recons, domain_prediction, domain_prediction0 = model_output[
            'out'], model_output['y'], model_output['y0']
        re_loss, re_loss_per_sample = loss.get_reconstruction_loss(img,
                                                                   recons,
                                                                   mean=0.5,
                                                                   std=0.5)
        print('label:', label)
        print('y:', domain_prediction[0])
        print('re_loss:', re_loss)
        anomaly_label += label.tolist()
        y_total += domain_prediction[0].cpu().tolist()
        # One scalar (batch-level) error per batch is used as the score.
        recon_error_list.append(float(re_loss.cpu()))
else:
    # Non-adversarial models return the reconstruction and attention weights.
    for batch_idx, data in enumerate(t_batch):
        label = data[1]
        img = data[0].to(device)
        model_output = model(img)
        recons, attr = model_output['output'], model_output['att']
        re_loss = loss.get_reconstruction_loss(img, recons, mean=0.5, std=0.5)
        print('label:', label)
        # NOTE(review): fragment of a per-sample test loop — the opening
        # branches of this if/elif chain, and frames, model, r, unorm_trans,
        # utils, img_crop_size, loss, label, idx, start, recon_error_list,
        # label_list and ModelName come from outside this view.
        recon_error = np.mean(r ** 2)  # **0.5
    elif (ModelName == 'MemAE'):
        recon_res = model(frames)
        recon_frames = recon_res['output']
        # Un-normalize and convert the video tensors to per-frame images.
        recon_np = utils.vframes2imgs(unorm_trans(recon_frames.data), step=1, batch_idx=0)
        input_np = utils.vframes2imgs(unorm_trans(frames.data), step=1, batch_idx=0)
        # Pixel-wise residual on the cropped region; L2 error map, then mean.
        r = utils.crop_image(recon_np, img_crop_size) - utils.crop_image(input_np, img_crop_size)
        sp_error_map = sum(r ** 2) ** 0.5
        recon_error = np.mean(sp_error_map.flatten())
    elif ModelName == 'AdversarialAE':
        recon_res = model(frames)
        recon_frames = recon_res['output']
#         torchvision.utils.save_image((torch.abs(recon_frames[0, 0, 8, :, :]-frames[0, 0, 8, :, :])).cpu(), './'+args.model_name.split('.')[0]+'/%d.png'%idx, normalize=True)
#         torchvision.utils.save_image(recon_frames[0, 0, 8, :, :].cpu(), './'+args.model_name.split('.')[0]+'/recon_%d.png'%idx, normalize=True)
#         torchvision.utils.save_image(frames[0, 0, 8, :, :].cpu(), './'+args.model_name.split('.')[0]+'/original_%d.png'%idx, normalize=True)
        # Per-sample reconstruction loss tensor used as the anomaly score;
        # accumulated across batches and concatenated below.
        _, re_loss_per_sample = loss.get_reconstruction_loss(frames, recon_frames, mean=0.5, std=0.5)
        recon_error_list.append(re_loss_per_sample)
        label_list.append(label)
#         print('recon_loss_per_sample:')
#         print(re_loss_per_sample)
#         print('label:')
#         print(label)
    else:
        # Unknown model name: sentinel score, keep going.
        recon_error = -1
        print('Wrong ModelName.')
    idx += 1
# NOTE(review): time.clock() was removed in Python 3.8 — this raises
# AttributeError on modern interpreters; time.perf_counter() is the
# replacement (code left unchanged here).
end = time.clock()
print(end-start)

# Flatten accumulated per-sample losses and labels across all batches.
recon_error_list = torch.cat(recon_error_list, 0)
label = torch.cat(label_list, 0)