# Standard-library and third-party imports used below; project-local helpers
# (CodeDescDataset, UNIFNoAttention, RandomModel, train, evaluate_top_n,
# timeSince, plot, unif_train_triplet, unif_train_cosine_pos) are assumed to
# be imported elsewhere in this repository.
import time

import torch
import torch.nn as nn
import wandb


def train_cycle(use_wandb=True):
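    """Train the UNIF (no-attention) code-search model with a cosine
    embedding loss, printing running averages of the loss, checkpointing
    periodically, and optionally logging top-n retrieval metrics to
    Weights & Biases."""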
    print("%s: Training the model" % (time.strftime("%Y/%m/%d-%H:%M:%S")))

    # n_iters = 100000
    # print_every = 5000
    # plot_every = 1000
    # print_every = 1
    # plot_every = 2
    # embedding_size = 2
    # num_epochs = 300
    print_every = 50          # iterations between console progress prints
    plot_every = 500          # iterations between checkpoint/evaluation/log steps
    embedding_size = 32       # dimensionality of the shared code/description embedding
    num_epochs = 30
    margin = -1.0             # margin passed to nn.CosineEmbeddingLoss
    train_size = None         # None = use the full dataset
    evaluate_size = 100       # number of examples used by evaluate_top_n
    save_path = './unif_model.ckpt'

    # Keep track of losses for plotting
    current_print_loss = 0
    current_plot_loss = 0
    all_losses = []

    start = time.time()
    code_snippets_file = './data/parallel_bodies_n1000'
    descriptions_file = './data/parallel_desc_n1000'
    dataset = CodeDescDataset(code_snippets_file, descriptions_file,
                              train_size)
    num_iters = len(dataset)  # one iteration per training example
    # model = UNIF(dataset.code_vocab_size, dataset.desc_vocab_size, embedding_size)
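    # UNIFNoAttention is assumed to be a variant that omits the attention-weighted
    # pooling of the full UNIF model commented out above.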
    model = UNIFNoAttention(dataset.code_vocab_size, dataset.desc_vocab_size,
                            embedding_size)
    cosine_similarity_function = nn.CosineSimilarity()

    # For target y=1 (a matching pair) nn.CosineEmbeddingLoss is
    # 1 - cos(x1, x2); the margin only affects y=-1 targets.
    loss_function = nn.CosineEmbeddingLoss(margin=margin)
    learning_rate = 0.05  # too high and training can diverge; too low and it learns very slowly
    optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)

    if use_wandb:
        wandb.init(project='code-search', name='unif-cosine-neg', reinit=True)
        config = wandb.config
        config.learning_rate = learning_rate
        config.embedding_size = embedding_size
        config.evaluate_size = evaluate_size
        config.margin = margin
        config.num_epochs = num_epochs
        config.train_size = len(dataset)
        wandb.watch(model, log_freq=plot_every)
        # Log the untrained model's retrieval metrics as a pre-training baseline
        metrics = evaluate_top_n(model, evaluate_size)
        wandb.log(metrics)

    for epoch in range(num_epochs):
        print('Epoch:', epoch)

        # One training example per iteration: the dataset is indexed directly
        # rather than batched through a DataLoader.
        for i in range(num_iters):
            # The negative description is returned by the dataset but is not
            # used by this positive-pair training step.
            tokenized_code, tokenized_positive_desc, tokenized_negative_desc =\
                dataset[i]
            code_embedding, desc_embedding, loss = train(
                model, loss_function, optimiser, tokenized_code,
                tokenized_positive_desc)
            current_print_loss += loss
            current_plot_loss += loss

            # Print iteration number, progress, elapsed time and average loss
            if (i + 1) % print_every == 0:
                print('%d %d%% (%s) %.4f' %
                      (i + 1, (i + 1) / num_iters * 100,
                       timeSince(start), current_print_loss / print_every))
                cosine_similarity = cosine_similarity_function(
                    code_embedding, desc_embedding).item()
                print('Cosine similarity:', cosine_similarity)
                current_print_loss = 0

            # Checkpoint, evaluate and record the average loss
            if (i + 1) % plot_every == 0:
                torch.save(model.state_dict(), save_path)
                metrics = evaluate_top_n(model, evaluate_size)
                metrics.update({'loss': current_plot_loss / plot_every})
                all_losses.append(current_plot_loss / plot_every)
                current_plot_loss = 0
                if use_wandb:
                    wandb.log(metrics)

    return model, current_print_loss, all_losses


def test_unif_model():
    unif_model, current_loss, all_losses = unif_train_triplet.train_cycle()
    plot(all_losses)
    evaluate_top_n(unif_model)


def test_unif_cosine_pos_model():
    unif_model, current_loss, all_losses = unif_train_cosine_pos.train_cycle(
        True)
    plot(all_losses)
    evaluate_top_n(unif_model)


def test_random_model():
    random_model = RandomModel(embedding_size=128)
    evaluate_top_n(random_model)
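

# A minimal usage sketch (assuming this module is run directly as a script):
# train without Weights & Biases logging and report how many loss averages
# were recorded.
if __name__ == '__main__':
    trained_model, _, loss_history = train_cycle(use_wandb=False)
    print('Recorded %d averaged loss values' % len(loss_history))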