LOG_DIR = args.log_dir + '/run-test_{}-n{}-lr{}-wd{}-m{}-embeddings{}-msceleb-alpha10' \
    .format(args.optimizer, args.n_triplets, args.lr, args.wd,
            args.margin, args.embedding_size)

# create logger
logger = Logger(LOG_DIR)
# Define SummaryWriter instance for visualization
writer = SummaryWriter('Log/amsoftmax_res10', comment='margin0.3')
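# note: SummaryWriter ignores the comment kwarg when an explicit log dir is
# given, so comment='margin0.3' has no effect here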

kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
if args.cos_sim:
    l2_dist = nn.CosineSimilarity(dim=1, eps=1e-6)
else:
    l2_dist = PairwiseDistance(2)
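# Despite the shared name, l2_dist is a cosine *similarity* (higher = closer)
# under --cos-sim and an L2 *distance* (lower = closer) otherwise, so any
# downstream threshold comparison must flip direction accordingly, e.g.:
#   a, b = torch.randn(8, args.embedding_size), torch.randn(8, args.embedding_size)
#   scores = l2_dist(a, b)  # shape (8,) in both cases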

voxceleb, voxceleb_dev = wav_list_reader(args.dataroot)
if args.makemfb:
    # pbar = tqdm(voxceleb)
    for datum in voxceleb:
        mk_MFB(
            (args.dataroot + '/voxceleb1_wav/' + datum['filename'] + '.wav'))
    print("Complete convert")

if args.mfb:
    transform = transforms.Compose([
        concateinputfromMFB(),
        # truncatedinputfromMFB(),
        totensor()
    ])
    transform_T = transforms.Compose([
        concateinputfromMFB(input_per_file=args.test_input_per_file),
        totensor()  # completing the truncated Compose; assumed to mirror the train transform above
    ])


def main():
    # View the training samples and display the anchor-negative and anchor-positive distances
    voxceleb, train_set, valid_set = wav_list_reader(args.dataroot, split=True)

    # NumPy >= 1.16.5 requires allow_pickle=True to load a pickled dict
    class_to_idx = np.load('Data/dataset/voxceleb1/Fbank64_Norm/class2idx.npy', allow_pickle=True).item()
    print('Number of Speakers: {}.\n'.format(len(class_to_idx)))

    test_dir = VoxcelebTestset(dir=args.dataroot, pairs_path=args.test_pairs_path, loader=file_loader,
                               transform=transform_T)

    indices = list(range(len(test_dir)))
    random.shuffle(indices)
    indices = indices[:4000]
    test_part = torch.utils.data.Subset(test_dir, indices)
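    # note: the shuffle above is unseeded, so the 4000-pair subset differs
    # between runs; for a reproducible evaluation one could call, e.g.,
    # random.seed(args.seed) beforehand (args.seed assumed to exist)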

    valid_dir = ValidationDataset(voxceleb=valid_set, dir=args.dataroot, loader=file_loader, class_to_idx=class_to_idx,
                                  transform=transform)

    del voxceleb
    del train_set
    del valid_set

    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))

    # instantiate model and initialize weights
    model = SuperficialResCNN(layers=[1, 1, 1, 0], embedding_size=args.embedding_size, n_classes=len(class_to_idx), m=3)


    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # strip BatchNorm's num_batches_tracked buffers so checkpoints
            # saved under a different PyTorch version still load
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)

        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size, shuffle=True,
    #                                            # collate_fn=PadCollate(dim=2),
    #                                            **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.test_batch_size, shuffle=False,
                                               # collate_fn=PadCollate(dim=2),
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part, batch_size=args.test_batch_size, shuffle=False, **kwargs)



    # train(train_loader, model, ce, optimizer, scheduler, epoch)
    # train(train_loader, model, ce)
    # test(test_loader, valid_loader, model)

    model.eval()

    valid_pbar = tqdm(enumerate(valid_loader))
    softmax = nn.Softmax(dim=1)

    correct = 0.
    total_datasize = 0.

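    # batch_predict is defined elsewhere; LIME requires a function mapping a
    # batch of numpy images to class probabilities. A minimal sketch of what
    # it is assumed to do with this model and the softmax above:
    #
    #   def batch_predict(images):
    #       batch = torch.stack([torch.from_numpy(im).permute(2, 0, 1).float()
    #                            for im in images])
    #       with torch.no_grad():
    #           logits = model(batch)
    #       return softmax(logits).numpy()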
    for batch_idx, (data, label) in valid_pbar:
        # data = Variable(data.cuda())
        # true_labels = Variable(label.cuda())
        # pdb.set_trace()  # debugging breakpoint in the original; disabled here
        explainer = lime_image.LimeImageExplainer()
        # explain_instance expects a single numpy image (an (F, T) feature map
        # per sample is assumed here), not a torch batch
        explanation = explainer.explain_instance(data[0].squeeze().numpy(),
                                                 batch_predict,  # classification function
                                                 top_labels=5,
                                                 hide_color=0,
                                                 num_samples=len(data))  # perturbation count; the original ties it to batch size

        temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=10,
                                                    hide_rest=False)
        img_boundary2 = mark_boundaries(temp / 255.0, mask)
        plt.imshow(img_boundary2)
        plt.show()  # render the LIME boundary overlay

        break

    writer.close()
Example 3
kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
opt_kwargs = {
    'lr': args.lr,
    'lr_decay': args.lr_decay,
    'weight_decay': args.weight_decay,
    'dampening': args.dampening,
    'momentum': args.momentum
}
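# opt_kwargs presumably feeds an optimizer factory elsewhere in this example,
# along the lines of (create_optimizer is a hypothetical helper):
#   optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)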

if args.cos_sim:
    l2_dist = nn.CosineSimilarity(dim=1, eps=1e-6)
else:
    l2_dist = PairwiseDistance(2)

voxceleb, train_set, valid_set = wav_list_reader(args.dataroot, split=True)
# voxceleb2, voxceleb2_dev = voxceleb2_list_reader(args.dataroot)

# if args.makemfb:
#     #pbar = tqdm(voxceleb)
#     for datum in voxceleb:
#         mk_MFB((args.dataroot +'/voxceleb1_wav/' + datum['filename']+'.wav'))
#     print("Complete convert")
#
# if args.makespec:
#     num_pro = 1.
#     for datum in voxceleb:
#         # Data/voxceleb1/
#         # /data/voxceleb/voxceleb1_wav/
#         GenerateSpect(wav_path='/data/voxceleb/voxceleb1_wav/' + datum['filename']+'.wav',
#                       write_path=args.dataroot +'/spectrogram/voxceleb1_wav/' + datum['filename']+'.npy')
Example 4
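# (fragment) worker body of add_duration_vox, used by the Process calls below;
# the enclosing def/while/try lines are missing from the original listing, and
# the parameter names (queue, error_queue, vox_duration, cpid, share_lock) are
# inferred from the Process target signature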
            vox2['duration'] = len(item)
            vox_duration.append(vox2)

            share_lock.release()
            # print('')
        except Exception:
            error_queue.put(vox2)
        # share_lock.release()
        print('\rProcess {}: There are {:6d} features left.'.format(cpid, queue.qsize()), end='\r')
    pass

if __name__ == '__main__':
    queue = Queue()
    que_queue = Queue()
    # voxceleb2, voxceleb2_dev = voxceleb2_list_reader(dataroot)
    vox1, vox1_dev = wav_list_reader(dataroot)
    vox_duration = multiprocessing.Manager().list()
    # spk_utt_duration = multiprocessing.Manager().dict()

    share_lock = multiprocessing.Manager().Lock()

    for i in range(len(vox1)):
        queue.put(vox1[i])

    #check_from_queue(queue, que_queue, 1)

    process_list = []
    for i in range(15):
        # pro = Process(target=check_from_queue, args=(queue, que_queue, spk_utt_duration, i, share_lock))
        pro = Process(target=add_duration_vox, args=(queue, que_queue, vox_duration, i, share_lock))
        process_list.append(pro)
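    # the listing is truncated here; a typical continuation would start and
    # join the workers, e.g.:
    #   for pro in process_list:
    #       pro.start()
    #   for pro in process_list:
    #       pro.join()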