Exemplo n.º 1
0
    def myLogger(self,hd,htx,prdct):
# aqui ha um logger customizado que ira imprimir qual o output esperado
# e qual foi alcançado
        decode_prdct = "".join(self.decode(hangul_to_ix,prdct[0])) # decodificamos a saida
        decode_utf = list(map(lambda x : x.decode('utf-8'),prdct[0]))
        decode_inpt = jamotools.join_jamos(decode_utf)
        h = "".join(prdct[1])
        ph = round(SequenceMatcher(None,h,decode_prdct).ratio()*100)
        if ph > 50 :
            ph = Fore.GREEN+str(ph)+"%"+Fore.RESET
        else:
            ph = Fore.RED+str(ph)+"%"+Fore.RESET

        log = """
-----
    -- Prediction Resume --
with data input -> {i} ;
expected output -> {h} ;
and the model predicted -> {d} ;
similarity -> {w}
    -- end --

-----
""".format(i=Fore.GREEN+decode_inpt+Fore.RESET,h=Fore.GREEN+h+Fore.RESET,d=Fore.RED+decode_prdct+Fore.RESET,w=ph)
        print(log)
        txtlog["Prediction Resume"]["with data input"].append(decode_inpt)
        txtlog["Prediction Resume"]["expected output"].append(h)
        txtlog["Prediction Resume"]["the model predicted"].append(decode_prdct)
        txtlog["Prediction Resume"]["similarity"].append(ph)
Exemplo n.º 2
0
def mel_tensor_to_plt_image(tensor, titles, step):
    """Render a batch of mel-spectrograms as a grid and return the image pixels.

    Args:
        tensor: 3-D array-like unpacked as ``B, H, L = tensor.shape``; each
            ``tensor[i, :, :]`` is drawn with ``imshow`` and colour limits
            fixed to ``[-1, 1]``.
        titles: sequence of at least ``B`` strings. ``<s>`` and ``</s>`` are
            removed and the rest is joined with ``jamotools.join_jamos`` to
            form each subplot title.
        step: integer step shown zero-padded to 7 digits in the figure title.

    Returns:
        Contiguous ``np.uint8`` array of shape ``(3, height, width)``
        (channels first) with the rendered figure's RGB pixels.
    """
    B, H, L = tensor.shape

    n_cols = 4
    n_rows = int(np.ceil(B / n_cols))

    # squeeze=False always returns a 2-D array of Axes, so flatten() is safe.
    fig, axes = plt.subplots(n_rows, n_cols, sharey=True, figsize=(36, 12),
                             squeeze=False)
    fig.suptitle(f'Mel-spectrogram from Step #{step:07d}', fontsize=24, y=0.95)
    axes = axes.flatten()
    for i in range(B):
        im = axes[i].imshow(tensor[i, :, :], origin='lower', aspect='auto')
        im.set_clim(-1, 1)
        axes[i].axes.xaxis.set_visible(False)
        axes[i].axes.yaxis.set_visible(False)
        raw_title = titles[i].replace('<s>', '').replace('</s>', '')
        axes[i].set_title(jamotools.join_jamos(raw_title))
    # Hide the empty trailing panels when B is not a multiple of n_cols.
    for ax in axes[B:]:
        ax.set_axis_off()
    fig.colorbar(im, ax=axes, location='right')
    fig.canvas.draw()

    # ``tostring_rgb`` is deprecated since Matplotlib 3.8 and removed in 3.10.
    # ``buffer_rgba`` is an (height, width, 4) view in physical pixels, so no
    # separate reshape via get_width_height() is needed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha, move channels first, and copy so the result does not
    # reference the canvas buffer after the figure is closed.
    image_array = np.ascontiguousarray(rgba[:, :, :3].transpose(2, 0, 1))

    # Close this specific figure (plain plt.close() closes the *current* one).
    plt.close(fig)

    return image_array
Exemplo n.º 3
0
def predict_next_character(model, start_string, vocabulary, num_generate,
                           *, temperature=1.0):
    """Sample ``num_generate`` characters from ``model`` following ``start_string``.

    The ids of ``start_string`` (via ``vocabulary.char2index``) are fed as one
    batch; afterwards only the newly sampled id is fed back each step. Logits
    are divided by ``temperature`` and ``tf.random.categorical`` samples the
    next id from the last time step.

    Args:
        model: model called as ``model(ids)``; ``reset_states()`` is called
            before generation.
        start_string: seed characters (jamo), also prepended to the result.
        vocabulary: object providing ``char2index`` and ``index2char``.
        num_generate: number of characters to sample.
        temperature: positive divisor applied to the logits; below 1 makes
            sampling greedier, above 1 more random. Defaults to 1.0, the
            value previously hard-coded.

    Returns:
        ``jamotools.join_jamos(start_string)`` followed by the generated jamo
        joined into syllables.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError('temperature must be positive, got %r' % (temperature,))

    input_eval = tf.expand_dims(
        [vocabulary.char2index(ch) for ch in start_string], 0)
    generated_jamo = []
    model.reset_states()
    for _ in range(num_generate):
        logits = tf.squeeze(model(input_eval), 0) / temperature
        # One sample per time step; keep the sample for the last step.
        predicted_id = tf.random.categorical(logits, num_samples=1)[-1, 0].numpy()
        # Only the new id is fed back; presumably the stateful model's
        # internal state carries the earlier context.
        input_eval = tf.expand_dims([predicted_id], 0)
        generated_jamo.append(vocabulary.index2char(predicted_id))
    text_generated = jamotools.join_jamos(generated_jamo)
    return jamotools.join_jamos(start_string) + text_generated
Exemplo n.º 4
0
def wordedits(word):
    """Return ``dictionarycomparer``'s result for all single-edit variants of ``word``.

    ``word`` is decomposed into jamo with ``jamotools.split_syllables``; the
    candidates from ``inserts``, ``removals``, ``swaps`` and ``replaces``
    (concatenated in that order) are each re-composed with
    ``jamotools.join_jamos``. The re-composed list is printed, then passed to
    ``dictionarycomparer`` and its return value is returned unchanged.
    """
    jamo = jamotools.split_syllables(word)
    candidates = list(
        inserts(jamo) + removals(jamo) + swaps(jamo) + replaces(jamo))
    recomposed = [jamotools.join_jamos(candidate) for candidate in candidates]
    print(recomposed)
    return dictionarycomparer(recomposed)
Exemplo n.º 5
0
def get_data_from_hashtag(hashtag: str, data_path='./data', *, timeout=30):
    """Download the posts of an Instagram hashtag page and save them to disk.

    Creates ``<data_path>/<YYYYmmdd_HHMMSS>_<hashtag>/`` containing:

    * ``data.json`` — the raw JSON response of the tag page,
    * ``img_<id>.jpg`` — each post's ``display_url`` image,
    * ``metadata.csv`` — one row per post: id, original caption (jamo
      joined), content, hashtags, category (``popular`` / ``recent``) and
      created date (``9999-12-31`` when none is found in the caption).

    Args:
        hashtag: tag inserted into the ``explore/tags`` URL.
        data_path: parent directory for the output folder.
        timeout: seconds passed to every ``requests.get`` so a stalled
            connection raises instead of hanging indefinitely.

    Raises:
        requests.HTTPError: if the tag page returns an error status.
    """
    current_datetime = datetime.now().strftime('%Y%m%d_%H%M%S')
    dir_name = '%s/%s_%s' % (data_path, current_datetime, hashtag)
    print(dir_name, 'start')
    os.makedirs(dir_name, exist_ok=True)

    url = 'https://www.instagram.com/explore/tags/%s/?__a=1' % hashtag
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of a KeyError on an error page.
    response.raise_for_status()
    raw_json = response.json()
    # ``with`` closes the file deterministically (it was previously leaked).
    with open('%s/data.json' % dir_name, 'w', encoding='utf-8') as f:
        json.dump(raw_json, f)

    hashtag_data = raw_json['graphql']['hashtag']
    nodes = {
        'popular': hashtag_data['edge_hashtag_to_top_posts']['edges'],
        'recent': hashtag_data['edge_hashtag_to_media']['edges'],
    }
    rows = []
    for category, edges in nodes.items():
        for edge in edges:
            node = edge['node']
            img = requests.get(node['display_url'], timeout=timeout).content
            caption_edges = node['edge_media_to_caption']['edges']
            original = join_jamos(
                caption_edges[0]['node']['text'] if caption_edges else "")
            tags, content = split_content_tag(original)
            date = _extract_caption_date(node['accessibility_caption'])
            rows.append([node['id'], original, content, tags, category, date])

            with open('%s/img_%s.jpg' % (dir_name, node['id']), 'wb') as f:
                f.write(img)

    # utf-8-sig adds a BOM; newline='' is what the csv module expects.
    with open('%s/metadata.csv' % dir_name,
              'w',
              encoding='utf-8-sig',
              newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            'id', 'original', 'content', 'hashtag', 'category', 'created_date'
        ])
        writer.writerows(rows)
    print(dir_name, 'finish')


def _extract_caption_date(caption):
    """Return the first ``Month DD, YYYY`` date in ``caption`` as ``YYYY-MM-DD``.

    Returns ``'9999-12-31'`` when ``caption`` is falsy or contains no match.
    """
    if caption:
        matches = re.findall('[A-Za-z]+ [0-9]{2}, [0-9]{4}', caption)
        if matches:
            return datetime.strptime(matches[0],
                                     '%B %d, %Y').strftime('%Y-%m-%d')
    return '9999-12-31'
Exemplo n.º 6
0
def testmodel2(epoch, logs):
    """Epoch-end hook: print a 300-jamo sample generated by the global ``model``.

    Runs only when ``epoch`` is a multiple of 5 or equals 99. Generation is
    seeded with the first 48 characters of ``train_text`` split into jamo.
    Each step encodes the last ``seq_length`` jamo (unknown symbols map to
    ``char2idx['UNK']``), left-pads with the UNK id, and appends the most
    probable next jamo. ``logs`` is accepted for the callback signature and
    ignored.
    """
    if epoch % 5 != 0 and epoch != 99:
        return

    sentence = jamotools.split_syllables(train_text[:48])
    unk_idx = char2idx['UNK']

    next_chars = 300
    for _ in range(next_chars):
        window = np.array(
            [char2idx.get(c, unk_idx) for c in sentence[-seq_length:]])
        window = pad_sequences([window],
                               maxlen=seq_length,
                               padding='pre',
                               value=unk_idx)
        # ``Sequential.predict_classes`` was removed in TensorFlow 2.6. For a
        # multi-class output it returned argmax over the last axis of
        # ``predict`` (called with verbose=0), which is reproduced here.
        probs = model.predict(window, verbose=0)
        output_idx = np.argmax(probs, axis=-1)
        sentence += idx2char[output_idx[0]]

    print()
    print(jamotools.join_jamos(sentence))
    print()
Exemplo n.º 7
0
 def test_join_jamos(self, input, output):
     """Joining the jamo decoded from ``input`` must produce ``output``.

     Each element of ``input`` is converted with ``_hex_string_to_str`` and the
     pieces are concatenated before ``jamotools.join_jamos`` is applied.
     """
     jamo_text = "".join(_hex_string_to_str(code) for code in input)
     joined = jamotools.join_jamos(jamo_text)
     self.assertEqual(joined, output)
Exemplo n.º 8
0
def jamo2text(jamo):
    """Compose a jamo sequence into syllables with ``jamotools.join_jamos``."""
    text = jamotools.join_jamos(jamo)
    return text
Exemplo n.º 9
0
def ota_translater(word):
    """Apply ``correct`` to the jamo decomposition of ``word`` and re-compose it.

    ``word`` is split with ``jamotools.split_syllables``, passed to
    ``correct``, and the result is joined back with ``jamotools.join_jamos``.
    """
    return jamotools.join_jamos(correct(jamotools.split_syllables(word)))
Exemplo n.º 10
0
import numpy as np

path_to_file = tf.keras.utils.get_file(
    'toji.txt',
    'https://raw.githubusercontent.com/pykwon/etc/master/rnn_test_toji.txt')
# Alternative local corpus: path_to_file = 'silrok.txt'

# Read the whole corpus as UTF-8. The ``with`` block closes the file handle,
# which the previous one-liner ``open(...).read()`` left open.
with open(path_to_file, 'rb') as corpus_file:
    train_text = corpus_file.read().decode(encoding='utf-8')
s = train_text[:100]
print(s)

# Split Hangul text into individual jamo; other characters (e.g. Hanja) are
# left as they are.
s_split = jamotools.split_syllables(s)  # the first 100 characters, split into jamo
print(s_split)

# Round-trip check: re-compose the jamo into syllables.
s2 = jamotools.join_jamos(s_split)
print(s2)  # the re-composed text
print(s == s2)  # expected True: split followed by join leaves the text unchanged

# Jamo tokenization of the whole corpus (this can take a while).
train_text_X = jamotools.split_syllables(train_text)
vocab = sorted(set(train_text_X))
vocab.append('UNK')  # reserve a token for symbols missing from the vocabulary
print('{} unique characters'.format(len(vocab)))  # noted as 179 for this corpus

# Map each vocabulary symbol to an integer id, and ids back to symbols.
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

text_as_int = np.array([char2idx[c] for c in train_text_X])
Exemplo n.º 11
0
from shutil import copytree
import glob, os, pdb, jamotools

src_dir = '/mnt/nvme0/snuface'
tgt_dir = '/mnt/nvme0/snuface_release'

splits = ['train', 'val', 'test', 'test_foreign']

# Refuse to overwrite an existing release. An explicit raise (unlike
# ``assert``) still fires when Python runs with -O.
if os.path.exists(tgt_dir):
    raise FileExistsError('target directory already exists: %s' % tgt_dir)

# Convert all identity folder names to alphanumeric ids (idNNNNN) so the
# release works on non-unicode systems; each split's mapping.txt records
# "<new id>,<original name>".
for sid, split in enumerate(splits):
    os.makedirs(os.path.join(tgt_dir, split))
    folders = sorted(glob.glob('%s/%s/*/' % (src_dir, split)))
    # The original names are jamo-joined text, so write mapping.txt as UTF-8
    # explicitly: the locale default on a non-UTF-8 system would raise
    # UnicodeEncodeError on the very systems this release targets.
    with open(os.path.join(tgt_dir, split, 'mapping.txt'), 'w',
              encoding='utf-8') as f:
        for iid, folder in enumerate(folders):
            # NOTE: ids are sid * 10000 + iid, so a split with 10000 or more
            # identities would overlap the next split's id range.
            newname = 'id%05d' % (sid * 10000 + iid)
            # glob's trailing '/' makes the folder name the second-to-last part.
            oldname = folder.split('/')[-2]
            f.write('%s,%s\n' % (newname, jamotools.join_jamos(oldname)))
            copytree(folder, os.path.join(tgt_dir, split, newname))
            print('Copied %s %s' % (newname, oldname))
Exemplo n.º 12
0
def main():
    """Train the two-stage speech model and run an inference check each epoch.

    Stage A (``net``, ``Mel2SeqNet_v2``) maps mel features to jamo and is
    trained with a CTC loss. Stage B (``net_B``, ``Seq2SeqNet_v2``) maps jamo
    token sequences to label-index scripts and is trained with an NLL loss:
    every batch on the ground-truth jamo ("reference" pass), and additionally
    on stage A's CTC-decoded output once the CTC loss is below ``--loss_lim``.

    Each epoch logs train/eval loss and CER, reports them to NSML, saves the
    model under ``--save_name`` and as ``best`` when eval CER improves, and
    then decodes every evaluation file as an inference check.

    Side effects: assigns the module globals ``char2index``, ``index2char``,
    ``SOS_token``, ``EOS_token`` and ``PAD_token``.
    """

    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(
        description='Speech hackathon lilililill model')
    parser.add_argument('--max_epochs',
                        type=int,
                        default=100,
                        help='number of max epochs in training (default: 100)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--save_name',
                        type=str,
                        default='model',
                        help='the name of model in nsml or local')

    parser.add_argument('--dropout',
                        type=float,
                        default=0.2,
                        help='dropout rate in training (default: 0.2)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-03,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--num_mels',
                        type=int,
                        default=80,
                        help='number of the mel bands (default: 80)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='batch size in training (default: 128)')
    parser.add_argument("--num_thread",
                        type=int,
                        default=4,
                        help='number of the loading thread (default: 4)')
    parser.add_argument('--num_hidden_enc',
                        type=int,
                        default=1024,
                        help='hidden size of model (default: 1024)')
    parser.add_argument('--num_hidden_dec',
                        type=int,
                        default=512,
                        help='hidden size of model decoder (default: 512)')
    parser.add_argument(
        '--nsc_in_ms',
        type=int,
        default=50,
        help='Number of sample size per time segment in ms (default: 50)')

    parser.add_argument(
        '--ref_repeat',
        type=int,
        default=1,
        help='Number of repetition of reference seq2seq (default: 1)')
    parser.add_argument('--loss_lim',
                        type=float,
                        default=0.05,
                        help='Minimum loss threshold (default: 0.05)')

    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument("--pause", type=int, default=0)
    parser.add_argument('--memo',
                        type=str,
                        default='',
                        help='Comment you wish to leave')
    parser.add_argument('--debug',
                        type=str,
                        default='False',
                        help='debug mode')

    parser.add_argument('--load', type=str, default=None)

    args = parser.parse_args()

    batch_size = args.batch_size
    num_thread = args.num_thread
    num_mels = args.num_mels

    char2index, index2char = load_label('./hackathon.labels')
    SOS_token = char2index['<s>']  # '<sos>' or '<s>'
    EOS_token = char2index['</s>']  # '<eos>' or '</s>'
    PAD_token = char2index['_']  # '-' or '_'

    unicode_jamo_list = My_Unicode_Jamo_v2()

    tokenizer = Tokenizer(unicode_jamo_list)
    jamo_tokens = tokenizer.word2num(unicode_jamo_list)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if args.cuda else 'cpu')

    # Stage A: mel features -> jamo, trained with CTC.
    net = Mel2SeqNet_v2(num_mels, args.num_hidden_enc, args.num_hidden_dec,
                        len(unicode_jamo_list), device)
    net_optimizer = optim.Adam(net.parameters(), lr=args.lr)
    ctc_loss = nn.CTCLoss().to(device)

    # Stage B: jamo tokens -> label-index script, trained with an unreduced
    # (per-element) NLL loss.
    net_B = Seq2SeqNet_v2(1024, jamo_tokens, char2index, device)
    net_B_optimizer = optim.Adam(net_B.parameters(), lr=args.lr)
    net_B_criterion = nn.NLLLoss(reduction='none').to(device)

    bind_model(net, net_B, net_optimizer, net_B_optimizer, index2char,
               tokenizer)

    if args.pause == 1:
        nsml.paused(scope=locals())

    if args.mode != "train":
        return

    if args.load is not None:
        nsml.load(checkpoint='best',
                  session='team47/sr-hack-2019-dataset/' + args.load)
        nsml.save('saved')

    # NOTE: this overrides --lr — every parameter group of both optimizers is
    # forced to 1e-05 before training starts.
    for g in net_optimizer.param_groups:
        g['lr'] = 1e-05

    for g in net_B_optimizer.param_groups:
        g['lr'] = 1e-05

    for g in net_optimizer.param_groups:
        logger.info(g['lr'])

    for g in net_B_optimizer.param_groups:
        logger.info(g['lr'])

    wav_paths, script_paths, korean_script_paths = get_paths(DATASET_PATH)
    logger.info('Korean script path 0: {}'.format(korean_script_paths[0]))

    logger.info('wav_paths len: {}'.format(len(wav_paths)))
    logger.info('script_paths len: {}'.format(len(script_paths)))
    logger.info('korean_script_paths len: {}'.format(len(korean_script_paths)))

    # Load Korean Scripts

    korean_script_list, jamo_script_list = get_korean_and_jamo_list_v2(
        korean_script_paths)

    logger.info('Korean script 0: {}'.format(korean_script_list[0]))
    logger.info('Korean script 0 length: {}'.format(len(
        korean_script_list[0])))
    logger.info('Jamo script 0: {}'.format(jamo_script_list[0]))
    logger.info('Jamo script 0 length: {}'.format(len(jamo_script_list[0])))

    script_path_list = get_script_list(script_paths, SOS_token, EOS_token)

    # Stage-A targets: each jamo script wrapped in <s> ... </s> and tokenized.
    ground_truth_list = [
        (tokenizer.word2num(['<s>'] + list(jamo_script) + ['</s>']))
        for jamo_script in jamo_script_list
    ]

    # The first 95% of the data is used for training, the rest for evaluation.
    split_index = int(0.95 * len(wav_paths))

    wav_path_list_train = wav_paths[:split_index]
    ground_truth_list_train = ground_truth_list[:split_index]
    korean_script_list_train = korean_script_list[:split_index]
    script_path_list_train = script_path_list[:split_index]

    wav_path_list_eval = wav_paths[split_index:]
    ground_truth_list_eval = ground_truth_list[split_index:]
    korean_script_list_eval = korean_script_list[split_index:]
    script_path_list_eval = script_path_list[split_index:]

    logger.info('Total:Train:Eval = {}:{}:{}'.format(len(wav_paths),
                                                     len(wav_path_list_train),
                                                     len(wav_path_list_eval)))

    # NOTE(review): the evaluation preloader gets is_train=True and the
    # training preloader is_train=False, which looks swapped. Confirm what
    # Threading_Batched_Preloader_v2 does with the flag before changing it.
    preloader_eval = Threading_Batched_Preloader_v2(wav_path_list_eval,
                                                    ground_truth_list_eval,
                                                    script_path_list_eval,
                                                    korean_script_list_eval,
                                                    batch_size,
                                                    num_mels,
                                                    args.nsc_in_ms,
                                                    is_train=True)
    preloader_train = Threading_Batched_Preloader_v2(wav_path_list_train,
                                                     ground_truth_list_train,
                                                     script_path_list_train,
                                                     korean_script_list_train,
                                                     batch_size,
                                                     num_mels,
                                                     args.nsc_in_ms,
                                                     is_train=False)

    best_eval_cer = 1e10

    # load all target scripts for reducing disk i/o
    target_path = os.path.join(DATASET_PATH, 'train_label')
    load_targets(target_path)

    logger.info('start')

    train_begin = time.time()

    for epoch in range(args.max_epochs):

        logger.info((datetime.now().strftime('%m-%d %H:%M:%S')))

        net.train()
        net_B.train()

        preloader_train.initialize_batch(num_thread)
        loss_list_train = list()
        seq2seq_loss_list_train = list()
        seq2seq_loss_list_train_ref = list()

        logger.info("Initialized Training Preloader")
        count = 0

        # Lengths start at 1 so the CER divisions below never divide by zero.
        total_dist = 0
        total_length = 1
        total_dist_ref = 0
        total_length_ref = 1

        while not preloader_train.end_flag:
            batch = preloader_train.get_batch()
            if batch is not None:
                tensor_input, ground_truth, loss_mask, length_list, batched_num_script, batched_num_script_loss_mask = batch
                pred_tensor, loss = train(net, net_optimizer, ctc_loss,
                                          tensor_input.to(device),
                                          ground_truth.to(device),
                                          length_list.to(device), device)
                loss_list_train.append(loss)

                jamo_result = Decode_Prediction_No_Filtering(
                    pred_tensor, tokenizer)

                true_string_list = Decode_Num_Script(
                    batched_num_script.detach().cpu().numpy(), index2char)

                # Reference pass: train stage B on the ground-truth jamo.
                # NOTE: requires --ref_repeat >= 1; with 0 iterations
                # lev_pred_ref / seq2seq_loss_ref below are undefined.
                for _ in range(args.ref_repeat):
                    lev_input_ref = ground_truth

                    lev_pred_ref, attentions_ref, seq2seq_loss_ref = net_B.net_train(
                        lev_input_ref.to(device),
                        batched_num_script.to(device),
                        batched_num_script_loss_mask.to(device),
                        net_B_optimizer, net_B_criterion)

                pred_string_list_ref = Decode_Lev_Prediction(
                    lev_pred_ref, index2char)
                seq2seq_loss_list_train_ref.append(seq2seq_loss_ref)
                dist_ref, length_ref = char_distance_list(
                    true_string_list, pred_string_list_ref)

                pred_string_list = [None]

                dist = 0
                length = 0

                # Only once stage A's CTC loss is low enough, also train
                # stage B on stage A's own decoded output.
                if loss < args.loss_lim:
                    lev_input = Decode_CTC_Prediction_And_Batch(pred_tensor)
                    lev_pred, attentions, seq2seq_loss = net_B.net_train(
                        lev_input.to(device), batched_num_script.to(device),
                        batched_num_script_loss_mask.to(device),
                        net_B_optimizer, net_B_criterion)
                    pred_string_list = Decode_Lev_Prediction(
                        lev_pred, index2char)
                    seq2seq_loss_list_train.append(seq2seq_loss)
                    dist, length = char_distance_list(true_string_list,
                                                      pred_string_list)

                total_dist_ref += dist_ref
                total_length_ref += length_ref

                total_dist += dist
                total_length += length

                count += 1

                if count % 25 == 0:
                    logger.info("Train: Count {} | {} => {}".format(
                        count, true_string_list[0], pred_string_list_ref[0]))

                    logger.info("Train: Count {} | {} => {} => {}".format(
                        count, true_string_list[0], jamo_result[0],
                        pred_string_list[0]))

            else:
                logger.info("Training Batch is None")

        train_loss = np.mean(np.asarray(loss_list_train))
        train_cer = total_dist / total_length
        train_cer_ref = total_dist_ref / total_length_ref

        logger.info("Mean Train Loss: {}".format(train_loss))
        logger.info("Total Train CER: {}".format(train_cer))
        logger.info("Total Train Reference CER: {}".format(train_cer_ref))

        preloader_eval.initialize_batch(num_thread)
        loss_list_eval = list()
        seq2seq_loss_list_eval = list()
        seq2seq_loss_list_eval_ref = list()

        logger.info("Initialized Evaluation Preloader")

        count = 0
        total_dist = 0
        total_length = 1
        total_dist_ref = 0
        total_length_ref = 1

        net.eval()
        net_B.eval()

        while not preloader_eval.end_flag:
            batch = preloader_eval.get_batch()
            if batch is not None:
                tensor_input, ground_truth, loss_mask, length_list, batched_num_script, batched_num_script_loss_mask = batch
                pred_tensor, loss = evaluate(net, ctc_loss,
                                             tensor_input.to(device),
                                             ground_truth.to(device),
                                             length_list.to(device), device)
                loss_list_eval.append(loss)

                jamo_result = Decode_Prediction_No_Filtering(
                    pred_tensor, tokenizer)

                true_string_list = Decode_Num_Script(
                    batched_num_script.detach().cpu().numpy(), index2char)

                lev_input_ref = ground_truth
                lev_pred_ref, attentions_ref, seq2seq_loss_ref = net_B.net_eval(
                    lev_input_ref.to(device), batched_num_script.to(device),
                    batched_num_script_loss_mask.to(device), net_B_criterion)

                pred_string_list_ref = Decode_Lev_Prediction(
                    lev_pred_ref, index2char)
                # Evaluation losses go to the evaluation lists (they were
                # previously appended to the training lists).
                seq2seq_loss_list_eval_ref.append(seq2seq_loss_ref)
                dist_ref, length_ref = char_distance_list(
                    true_string_list, pred_string_list_ref)

                lev_input = Decode_CTC_Prediction_And_Batch(pred_tensor)
                lev_pred, attentions, seq2seq_loss = net_B.net_eval(
                    lev_input.to(device), batched_num_script.to(device),
                    batched_num_script_loss_mask.to(device), net_B_criterion)
                pred_string_list = Decode_Lev_Prediction(lev_pred, index2char)
                seq2seq_loss_list_eval.append(seq2seq_loss)
                dist, length = char_distance_list(true_string_list,
                                                  pred_string_list)

                total_dist_ref += dist_ref
                total_length_ref += length_ref

                total_dist += dist
                total_length += length

                count += 1

                if count % 5 == 0:
                    logger.info("Eval: Count {} | {} => {}".format(
                        count, true_string_list[0], pred_string_list_ref[0]))

                    logger.info("Eval: Count {} | {} => {} => {}".format(
                        count, true_string_list[0], jamo_result[0],
                        pred_string_list[0]))

            else:
                logger.info("Evaluation Batch is None")

        eval_cer = total_dist / total_length
        eval_cer_ref = total_dist_ref / total_length_ref
        eval_loss = np.mean(np.asarray(loss_list_eval))

        logger.info("Mean Evaluation Loss: {}".format(eval_loss))
        logger.info("Total Evaluation CER: {}".format(eval_cer))
        logger.info("Total Evaluation Reference CER: {}".format(eval_cer_ref))

        nsml.report(False,
                    step=epoch,
                    train_epoch__loss=train_loss,
                    train_epoch__cer=train_cer,
                    train_epoch__cer_ref=train_cer_ref,
                    eval__loss=eval_loss,
                    eval__cer=eval_cer,
                    eval__cer_ref=eval_cer_ref)

        nsml.save(args.save_name)
        if eval_cer < best_eval_cer:
            nsml.save('best')
            best_eval_cer = eval_cer

        logger.info("Inference Check")

        net.eval()
        net_B.eval()

        # Reuse the ``device`` chosen from --no_cuda above. It was previously
        # recomputed from torch.cuda.is_available() alone, which sent inputs
        # to CUDA even when the models were built for the CPU via --no_cuda.
        # No gradients are needed for this inference pass.
        with torch.no_grad():
            for wav_path in wav_path_list_eval:
                # NOTE(review): 40 is hard-coded while the network was built
                # with ``num_mels`` (default 80) — confirm what CREATE_MEL's
                # second argument means and whether it should follow args.
                mel_input = CREATE_MEL(wav_path, 40)
                mel_input = mel_input.type(torch.FloatTensor).to(device)

                pred_tensor = net(mel_input)

                jamo_result = Decode_Prediction_No_Filtering(
                    pred_tensor, tokenizer)

                lev_input = Decode_CTC_Prediction_And_Batch(pred_tensor)
                lev_pred = net_B.net_infer(lev_input.to(device))
                pred_string_list = Decode_Lev_Prediction(lev_pred, index2char)

                logger.info(pred_string_list[0])
                logger.info(
                    jamotools.join_jamos(jamo_result[0]).replace('<s>', ''))
Exemplo n.º 13
0
# Read the CSV (exported from an Excel spreadsheet) as a list of rows.
# newline='' is what the csv module expects, so quoted fields that contain
# line breaks are parsed correctly.
# NOTE(review): the encoding is the locale default; Excel CSV exports are
# often cp949 or utf-8-sig — confirm against the actual file.
with open('dataset.csv', newline='') as read_obj:
    list_of_rows = list(reader(read_obj))

# Transpose so each element of ``data`` is one spreadsheet column.
data = [list(column) for column in zip(*list_of_rows)]

# Make a dictionary for every student: skip the first column and any column
# whose header cell is empty; the jamo-joined header is the key and the
# remaining jamo-joined cells are its listed names.
namesdict = {}
assigndict = {}
for row in data[1:]:
    if row[0] != '':
        row_kr = [jamotools.join_jamos(x) for x in row]
        namesdict[row_kr[0]] = row_kr[1:]
        assigndict[row_kr[0]] = []

# Invert the dictionary: each listed name -> the students that listed it.
invdict = invert_dict(namesdict)

# In increasing order of frequency, assign each celebrity name to the student
# (among those who listed it) with the fewest names assigned so far. The
# upper bound is the largest frequency actually present; the previous fixed
# range(0, 10) silently skipped names listed by 10 or more students.
max_frequency = max((len(students) for students in invdict.values()),
                    default=0)
for ii in range(0, max_frequency + 1):
    for key in invdict:
        if len(invdict[key]) == ii:
            lengths = numpy.array([len(assigndict[x]) for x in invdict[key]])
            minidx = numpy.argmin(lengths)
            assigndict[invdict[key][minidx]].append(key)

for key in assigndict: