import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.distributions import Normal

# MDN and MyDataset are project-specific classes defined elsewhere in the project
# this example comes from.


def train_method(path, lr=1e-3, epochs=50, dropout=.2, batch_size=1024):
    print('TRAINING MODEL')
    # DataLoader's default batch_size of 1 yields one item per iteration; each item
    # in this example is one file's worth of rows.
    dataset = MyDataset(bucket='test', path=path)
    data_loader = DataLoader(dataset)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = MDN(dataset.shape[1] - 1, dropout=dropout).to(device=device)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    for i in range(1, epochs + 1):
        for file_data in data_loader:
            # Drop the singleton batch dimension added by the DataLoader and split
            # the last column off as the regression target.
            file_data = file_data.squeeze().float()
            x_file, y_file = file_data[:, :-1], file_data[:, -1]
            for j in range(0, x_file.shape[0], batch_size):
                x_batch = x_file[j:j + batch_size].to(device=device)
                y_batch = y_file[j:j + batch_size].to(device=device)
                y_train_pred = net(x_batch).float()

                # Interpret the two network outputs as the mean and the standard
                # deviation of a Normal distribution (the MDN is assumed to keep its
                # second output column positive) and minimize the mean negative
                # log-likelihood of the targets.
                dist = Normal(y_train_pred[:, 0], y_train_pred[:, 1])
                loss = -dist.log_prob(y_batch).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        print(i, loss.item())
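
A minimal sketch of the loss used in the loop above (not part of the original example): the network's two output columns are treated as the mean and standard deviation of a Normal distribution, and the loss is the mean negative log-likelihood of the targets. The values below are illustrative placeholders.

import torch
from torch.distributions import Normal

mu = torch.tensor([0.0, 1.0])        # predicted means (placeholder values)
sigma = torch.tensor([1.0, 0.5])     # predicted standard deviations (must be positive)
targets = torch.tensor([0.2, 0.8])   # observed targets (placeholder values)
nll = -Normal(mu, sigma).log_prob(targets).mean()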
Example #2

import torch
from torch.distributions import Bernoulli, Categorical, Normal

# NUM_COUNTRIES and NUM_TOPICS are module-level constants defined elsewhere in the
# project this example comes from.


def generate_author_graph(config, topic_popularities, country_popularities):
    topic_popularities = torch.tensor(topic_popularities)
    country_popularities = torch.tensor(country_popularities)

    # Sample the author's country and main topic uniformly (zero logits).
    author_country_id = Categorical(
        logits=torch.zeros(NUM_COUNTRIES)).sample().item()
    author_main_topic_id = Categorical(
        logits=torch.zeros(NUM_TOPICS)).sample().item()

    # Sample an off-topic id uniformly over every topic except the main one.
    off_topic_logits = torch.zeros(NUM_TOPICS)
    off_topic_logits[author_main_topic_id] = -1000  # sets prob -> 0
    author_off_topic_id = Categorical(logits=off_topic_logits).sample().item()

    # The author's popularity is a noisy draw around the country's base popularity.
    author_base_popularity = country_popularities[author_country_id]
    author_popularity = Normal(
        author_base_popularity,
        config['author-popularity-std']).sample().item()

    num_books = Normal(config['num-books-mean'],
                       config['num-books-std']).sample().floor().int()

    # One on/off-topic flag per book; .bool() so that ~is_on_topic below is a
    # logical negation rather than a bitwise NOT on bytes.
    is_on_topic = Bernoulli(probs=torch.empty(num_books.item()).fill_(
        1 - config['off-topic-probability'])).sample().bool()

    topics = torch.where(
        is_on_topic,
        torch.empty(num_books.item(),
                    dtype=torch.long).fill_(author_main_topic_id),
        torch.empty(num_books.item(),
                    dtype=torch.long).fill_(author_off_topic_id))

    # Each book's popularity combines its topic's popularity with the author's,
    # damped for off-topic books.
    popularities = topic_popularities[topics] + author_popularity
    popularities[~is_on_topic] /= config['off-topic-effect-size']

    # Sales are a noisy, non-negative function of popularity.
    sales = torch.max(
        Normal(popularities, config['sales-std']).sample(),
        torch.zeros_like(popularities))

    return author_country_id, topics, sales, popularities
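
A hypothetical usage sketch: the config values, the NUM_TOPICS and NUM_COUNTRIES sizes, and the flat popularity lists below are illustrative assumptions, not values from the original project (only the config keys appear in the code above).

config = {
    'author-popularity-std': 0.5,
    'num-books-mean': 8.0,
    'num-books-std': 2.0,
    'off-topic-probability': 0.1,
    'off-topic-effect-size': 2.0,
    'sales-std': 1.0,
}
topic_popularities = [1.0] * NUM_TOPICS        # one base popularity per topic
country_popularities = [1.0] * NUM_COUNTRIES   # one base popularity per country
country_id, topics, sales, popularities = generate_author_graph(
    config, topic_popularities, country_popularities)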