def generate_unconditionally(cell_size=400, num_clusters=20, steps=800, random_state=700, \
                                state_dict_file='trained_models/unconditional_epoch_50.pt'):
    """Sample an unconditional handwriting stroke sequence and plot it.

    Autoregressively samples `steps` points from a trained LSTMRandWriter:
    each step draws an end-of-stroke Bernoulli and a 2-D offset from the
    predicted mixture of bivariate Gaussians, then feeds the sample back in
    as the next input.

    Args:
        cell_size: LSTM hidden/cell size — must match the trained checkpoint.
        num_clusters: number of Gaussian mixture components — must match the
            trained checkpoint.
        steps: number of stroke points to generate.
        random_state: numpy RNG seed, for reproducible samples.
        state_dict_file: path to a checkpoint dict holding key 'model'.
    """
    model = LSTMRandWriter(cell_size, num_clusters)
    # load trained model weights
    model.load_state_dict(torch.load(state_dict_file)['model'])

    np.random.seed(random_state)
    zero_tensor = torch.zeros((1, 1, 3))
    # initialize null hidden states and memory states for both LSTM layers
    init_states = [torch.zeros((1, 1, cell_size))] * 4
    if cuda:
        model.cuda()
        zero_tensor = zero_tensor.cuda()
        init_states = [state.cuda() for state in init_states]
    x = Variable(zero_tensor)
    init_states = [
        Variable(state, requires_grad=False) for state in init_states
    ]
    h1_init, c1_init, h2_init, c2_init = init_states
    prev = (h1_init, c1_init)
    prev2 = (h2_init, c2_init)

    # each row is (end_of_stroke, dx, dy); start from the origin point
    record = [np.array([0, 0, 0])]

    for i in range(steps):
        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, prev, prev2 = model(
            x, prev, prev2)

        # sample end-of-stroke indicator; .item() moves the scalar to host
        # memory so this also works when the model runs on CUDA
        prob_end = end.data[0][0][0].item()
        sample_end = np.random.binomial(1, prob_end)
        # pick a mixture component — use num_clusters rather than a
        # hard-coded 20 so non-default mixture sizes work
        sample_index = np.random.choice(range(num_clusters),
                                        p=weights.data[0][0].cpu().numpy())

        # sample new stroke point from the chosen bivariate Gaussian
        mu = np.array([
            mu_1.data[0][0][sample_index].item(),
            mu_2.data[0][0][sample_index].item()
        ])
        sigma_1 = log_sigma_1.exp().data[0][0][sample_index].item()
        sigma_2 = log_sigma_2.exp().data[0][0][sample_index].item()
        rho = p.data[0][0][sample_index].item()
        off_diag = rho * sigma_1 * sigma_2
        cov = np.array([[sigma_1**2, off_diag], [off_diag, sigma_2**2]])
        sample_point = np.random.multivariate_normal(mu, cov)

        # feed the sampled point back in as the next model input
        out = np.insert(sample_point, 0, sample_end)
        record.append(out)
        x = torch.from_numpy(out).type(torch.FloatTensor)
        if cuda:
            x = x.cuda()
        x = Variable(x, requires_grad=False)
        x = x.view((1, 1, 3))

    plot_stroke(np.array(record))
def generate_conditionally(text, cell_size=400, num_clusters=20, K=10, random_state=700, \
                            bias=1., bias2=1., state_dict_file='trained_models/conditional_epoch_60.pt'):
    """Synthesize handwriting for `text` and plot the strokes + attention map.

    Autoregressively samples stroke points from a trained LSTMSynthesis model
    whose attention window (parameterized by kappa/phi) tracks the input text.
    Sampling stops once the attention settles on the final (padding) character.

    Args:
        text: the string to render; a trailing space is appended so the
            attention window has a terminal character to settle on.
        cell_size: LSTM hidden/cell size — must match the trained checkpoint.
        num_clusters: number of Gaussian mixture components — must match the
            trained checkpoint.
        K: number of attention window mixture components.
        random_state: numpy RNG seed, for reproducible samples.
        bias: sampling bias passed into the model's forward pass.
        bias2: extra sharpening bias subtracted from the log-sigmas before
            sampling each point.
        state_dict_file: path to a checkpoint dict holding key 'model'.
    """
    char_to_code = torch.load('char_to_code.pt')
    np.random.seed(random_state)
    text = text + ' '

    model = LSTMSynthesis(len(text),
                          len(char_to_code) + 1, cell_size, num_clusters, K)
    model.load_state_dict(torch.load(state_dict_file)['model'])

    # one-hot encode the text; the final extra column is the
    # unknown-character bucket
    onehots = np.zeros((len(text), len(char_to_code) + 1))
    for idx, ch in enumerate(text):
        try:
            onehots[idx][char_to_code[ch]] = 1
        except KeyError:
            # character not in the training vocabulary — map to the
            # unknown bucket (narrowed from a bare except, which would
            # also have hidden genuine indexing bugs)
            onehots[idx][-1] = 1

    zero_tensor = torch.zeros((1, 1, 3))
    # null initial hidden/cell states for the two LSTM layers
    h1_init, c1_init = torch.zeros((1, cell_size)), torch.zeros((1, cell_size))
    h2_init, c2_init = torch.zeros((1, 1, cell_size)), torch.zeros(
        (1, 1, cell_size))
    kappa_old = torch.zeros(1, K)
    onehots = torch.from_numpy(onehots).type(torch.FloatTensor)
    text_len = torch.from_numpy(np.array([[len(text)]
                                          ])).type(torch.FloatTensor)

    if cuda:
        model.cuda()
        zero_tensor = zero_tensor.cuda()
        h1_init, c1_init = h1_init.cuda(), c1_init.cuda()
        h2_init, c2_init = h2_init.cuda(), c2_init.cuda()
        kappa_old = kappa_old.cuda()
        onehots = onehots.cuda()
        text_len = text_len.cuda()

    x = Variable(zero_tensor)
    h1_init, c1_init = Variable(h1_init), Variable(c1_init)
    h2_init, c2_init = Variable(h2_init), Variable(c2_init)
    prev = (h1_init, c1_init)
    prev2 = (h2_init, c2_init)
    kappa_old = Variable(kappa_old)
    onehots = Variable(onehots, requires_grad=False)
    w_old = onehots.narrow(0, 0, 1)  # attention on the first input text char
    text_len = Variable(text_len)

    record = [np.zeros(3)]
    phis = []
    stop = False
    count = 0
    while not stop:
        outputs = model(x, onehots, text_len, w_old, kappa_old, prev, prev2,
                        bias)
        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, rho, w_old, kappa_old, prev, prev2, old_phi = outputs

        # bernoulli sample for end-of-stroke; .item() moves the scalar to
        # host memory so this also works when the model runs on CUDA
        prob_end = end.data[0][0][0].item()
        sample_end = np.random.binomial(1, prob_end)

        # mog sample — use num_clusters rather than a hard-coded 20 so
        # non-default mixture sizes work
        sample_index = np.random.choice(range(num_clusters),
                                        p=weights.data[0][0].cpu().numpy())
        mu = np.array([
            mu_1.data[0][0][sample_index].item(),
            mu_2.data[0][0][sample_index].item()
        ])
        # bias2 sharpens the point distribution by shrinking the log-sigmas
        log_sigma_1 = log_sigma_1 - bias2
        log_sigma_2 = log_sigma_2 - bias2
        sigma_1 = log_sigma_1.exp().data[0][0][sample_index].item()
        sigma_2 = log_sigma_2.exp().data[0][0][sample_index].item()
        corr = rho.data[0][0][sample_index].item()
        off_diag = corr * sigma_1 * sigma_2
        cov = np.array([[sigma_1**2, off_diag], [off_diag, sigma_2**2]])
        sample_point = np.random.multivariate_normal(mu, cov)

        # feed the sampled point back in as the next model input
        out = np.insert(sample_point, 0, sample_end)
        record.append(out)
        x = torch.from_numpy(out).type(torch.FloatTensor)
        if cuda:
            x = x.cuda()
        x = Variable(x, requires_grad=False)
        x = x.view(1, 1, 3)

        # attention
        old_phi = old_phi.squeeze(0)
        phis.append(old_phi)
        old_phi = old_phi.data.cpu().numpy()

        # hack to prevent early exit (attention is unstable at the beginning)
        if count >= 20 and np.max(old_phi) == old_phi[-1]:
            stop = True
        count += 1

    phis = torch.stack(phis).data.cpu().numpy().T

    plot_stroke(np.array(record))
    attention_plot(phis)