Example #1
def main():

    # Read sentences
    sentences = readFile("words2.txt")
    print(sentences)

    # Make uniq words list
    words = []
    uniqWords = []
    for sentence in sentences:
        for word in sentence:
            words.append(word)
            if word not in uniqWords:
                uniqWords.append(word)
    print(uniqWords)
    uniqWordSize = len(uniqWords)

    # Make trainPairs
    trainPairs = trainGenerator(sentences, uniqWords)

    dims = 5
    W1 = Variable(torch.randn(dims, uniqWordSize).float(), requires_grad=True)
    W2 = Variable(torch.randn(uniqWordSize, dims).float(), requires_grad=True)

    epo = 1001

    for i in range(epo):
        avg_loss = 0
        samples = 0
        for x, y in trainPairs:
            x = Variable(torch.from_numpy(x)).float()
            y = Variable(torch.from_numpy(np.array([y])).long())

            samples += len(y)

            a1 = torch.matmul(W1, x)
            a2 = torch.matmul(W2, a1)

            logSoftmax = F.log_softmax(a2, dim=0)
            loss = F.nll_loss(logSoftmax.view(1, -1), y)
            loss.backward()

            avg_loss += loss.item()

            W1.data -= 0.002 * W1.grad.data
            W2.data -= 0.002 * W2.grad.data

            W1.grad.data.zero_()
            W2.grad.data.zero_()

        if i != 0 and i % 100 == 0:
            print(avg_loss / samples)

    parisVector = W1[:, uniqWords.index('paris')].data.numpy()
    context_to_predict = parisVector
    hidden = Variable(torch.from_numpy(context_to_predict)).float()
    a = torch.matmul(W2, hidden)
    probs = F.softmax(a, dim=0).data.numpy()
    for context, prob in zip(uniqWords, probs):
        print(f'{context}: {prob:.2f}')
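Most of the examples on this page pair F.log_softmax with F.nll_loss, exactly as Example #1 does. As a side note (not taken from any of the projects above), that combination is numerically equivalent to calling F.cross_entropy on the raw scores; a minimal sketch with made-up tensors:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # hypothetical (batch, num_classes) scores
targets = torch.randint(0, 10, (4,))  # hypothetical class indices

loss_a = F.nll_loss(F.log_softmax(logits, dim=1), targets)
loss_b = F.cross_entropy(logits, targets)
assert torch.allclose(loss_a, loss_b)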
Example #2
def train(num_epochs=100, lr=0.001):
    embedding_size = 10
    W1 = Variable(torch.randn(embedding_size, vocab_size).float(),
                  requires_grad=True)
    W2 = Variable(torch.randn(vocab_size, embedding_size).float(),
                  requires_grad=True)

    for epoch in range(num_epochs):
        loss_val = 0
        for data, target in dataset:
            x = Variable(input_layer(data)).float()
            y_true = Variable(torch.from_numpy(np.array([target])).long())

            z1 = torch.matmul(W1, x)
            z2 = torch.matmul(W2, z1)

            log_softmax = F.log_softmax(z2, dim=0)

            loss = F.nll_loss(log_softmax.view(1, -1), y_true)
            loss_val += loss.item()
            loss.backward()
            W1.data -= lr * W1.grad.data
            W2.data -= lr * W2.grad.data

            W1.grad.data.zero_()
            W2.grad.data.zero_()
        if epoch % 10 == 0:
            print(f'Loss at epoch {epoch}: {loss_val/len(dataset)}')
Example #3
    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1),
                               -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(
                -1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        if self.logp:
            logpt = input
        else:
            logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt)**self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
Example #4
File: decoder.py  Project: YUHANYU/YHY-BYLW
    def forward(self, word_input, last_hidden, encoder_outputs):
        # Note: we run this one step at a time
        # TODO: FIX BATCHING

        # Get the embedding of the current input word (last output word)
        word_embedded = self.embedding(word_input).view(1, 1,
                                                        -1)  # S=1 x B x N
        word_embedded = self.dropout(word_embedded)

        # Calculate attention weights and apply to encoder outputs
        attn_weights = self.attn(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0,
                                                             1))  # B x 1 x N
        context = context.transpose(0, 1)  # 1 x B x N

        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat((word_embedded, context), 2)
        output, hidden = self.gru(rnn_input, last_hidden)

        # Final output layer
        output = output.squeeze(0)  # B x N
        output = F.log_softmax(
            self.out(torch.cat((output, context.squeeze(0)), 1)), dim=1)

        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden, attn_weights
Example #5
    def _train_embeddings(self, epochs, lr):
        for epoch in range(epochs):
            loss_val = 0
            for data, target in self.idx_pairs:
                x = torch.zeros(self.n_vocab).float()
                x[data] = 1.0
                y_true = Variable(torch.from_numpy(np.array([target])).long())
                z1 = torch.matmul(self.W1, x)
                z2 = torch.matmul(self.W2, z1)
                log_softmax = F.log_softmax(z2, dim=0)
                '''
                print('data')
                print(data)
                print('target')
                print(target)
                print('y_true')
                print(y_true)
                print('log_softmax')
                print(log_softmax)
                '''
                loss = F.nll_loss(log_softmax.view(1, -1), y_true)
                loss_val += loss.data.item()
                loss.backward()
                self.W1.data -= lr * self.W1.grad.data
                self.W2.data -= lr * self.W2.grad.data

                self.W1.grad.data.zero_()
                self.W2.grad.data.zero_()
            if epoch % 10 == 0:
                print('Loss at epoch {}: {}'.format(
                    epoch, loss_val / len(self.idx_pairs)))
Example #6
    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8))
        out = F.log_softmax(self.fc(out), dim=1)
        return out
Example #7
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        feature = F.dropout(x, training=self.training)
        x = self.fc2(feature)
        return F.log_softmax(x, dim=1), feature
Example #8
    def forward(self, x):
        x = self.dropout(x)
        # print((x.view(1, len(x), -1)).shape)
        print(x.view(1, len(x), -1).view(-1, 20).shape)
        lstm_out, (h_t, c_t) = self.lstm(x.view(1, len(x), -1))  # LSTM output
        model_out = self.output(x.view(1, len(x), -1).view(-1, 20))  # linear layer -> 20 values -> one-hot encoding
        output_pred = F.log_softmax(model_out, dim=0)  # softmax
        return output_pred, (h_t, c_t)
Example #9
    def forward(self, sentence):
        # sentence: [len(sent), emsize]
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(embeds.view(len(sentence), 1, -1),
                                          self.hidden)
        # tag space: [len(sent), tag size]
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))

        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores
Example #10
    def forward(self, logits, labels):
        logits = logits.float()
        labels = labels.float()

        logprobs = F.log_softmax(logits, dim=-1)

        loss = -labels * logprobs
        loss = loss.sum(-1)

        return loss.mean()
Example #11
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)

        # print(x.shape) #(for size)
        x = x.view(-1, 4*4*50)  # -1 infers the batch size
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
Example #12
    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
Example #13
    def forward(self, inputs):
        left_input, right_input = inputs
        # from (seq_len, batch_size) to (seq_len, batch_size, emb_dim)
        left_embeds = self.embedding(left_input)
        right_embeds = self.embedding(right_input)
        left_out, self.left_hidden = self.left_encoder(left_embeds, self.left_hidden)
        right_out, self.right_hidden = self.right_encoder(right_embeds, self.right_hidden)

        # concatenate the output from each direction
        # (seq_len, batch_size, hidden_dim * directions) => (batch_size, hidden_dim * directions)
        all_out = torch.cat([left_out[-1], right_out[-1]], dim=1)
        tags = self.hidden2tag(all_out)
        return F.log_softmax(tags, dim=1)
Example #14
    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)
Example #15
File: metrics.py  Project: ferrine/tutil
def evaluate(net, dataloader, num_ens=1):
    """Calculate ensemble accuracy and NLL"""
    accs = []
    nlls = []
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = torch.autograd.Variable(inputs.cuda(non_blocking=True))
        labels = torch.autograd.Variable(labels.cuda(non_blocking=True))
        outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).cuda()
        for j in range(num_ens):
            outputs[:, :, j] = F.log_softmax(net(inputs), dim=1).data
        accs.append(logits2acc(logmeanexp(outputs, dim=2), labels))
        nlls.append(
            F.nll_loss(torch.autograd.Variable(logmeanexp(outputs, dim=2)),
                       labels,
                       reduction='sum').data.cpu().numpy())
    return np.mean(accs), np.sum(nlls)
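logmeanexp and logits2acc are helpers defined elsewhere in that project and are not shown here. A minimal sketch of what logmeanexp is assumed to do (average the per-member log-probabilities in probability space, computed stably):

import math
import torch

def logmeanexp(x, dim):
    # log(mean(exp(x))) along `dim`: logsumexp(x) - log(n)
    return torch.logsumexp(x, dim=dim) - math.log(x.size(dim))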
Example #16
    def forward(self, img_feat, seq):
        # seq[:, 0] = 0
        batch_size = img_feat.shape[0]
        state = self.init_hiddden(batch_size)
        outputs = []
        for i in range(seq.shape[1]):
            if i == 0:
                x = self.img_embed(img_feat)
            else:
                if seq[:, i].sum() == 0:
                    break
                x = self.embed(seq[:, i])
            output, state = self.rnn(x, state)
            output = F.log_softmax(self.logit(output), dim=1)
            outputs.append(output)
        return torch.cat([o.unsqueeze(1) for o in outputs], dim=1)
Example #17
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)
Example #18
def main():
    tokens = tokenize(corpus)
    vocabulary = set(sum(tokens, []))  # sum() flattens the 2d list
    vocab_size = len(vocabulary)
    cc_pair = generate_center_context_pair(tokens, 2)
    # pprint(cc_pair)

    word2idx = word2index(tokens)
    idx2word = {key: val for (val, key) in word2idx.items()}
    print(word2idx)
    print(idx2word)

    idx_pairs = get_idxpairs(cc_pair, word2idx)
    idx_pairs = np.array(idx_pairs)

    embedding_dims = 5
    W1 = Variable(torch.randn(embedding_dims, vocab_size).float(),
                  requires_grad=True)
    W2 = Variable(torch.randn(vocab_size, embedding_dims).float(),
                  requires_grad=True)
    max_iter = int(sys.argv[1])
    learning_rate = 0.001

    for i in range(max_iter):
        loss_val = 0
        for data, target in idx_pairs:
            x = Variable(get_input_layer(data, vocab_size)).float()
            y_true = Variable(torch.from_numpy(np.array([target])).long())

            z1 = torch.matmul(W1, x)
            z2 = torch.matmul(W2, z1)

            log_softmax = F.log_softmax(z2, dim=0)

            loss = F.nll_loss(log_softmax.view(1, -1), y_true)
            loss_val += loss.item()
            loss.backward()
            W1.data -= learning_rate * W1.grad.data
            W2.data -= learning_rate * W2.grad.data

            W1.grad.data.zero_()
            W2.grad.data.zero_()
        if i % 10 == 0:
            print(f"Loss at iter {i}: {loss_val/len(idx_pairs)}")
Example #19
def train():
    W1 = torch.randn(EMBEDDING_DIMENSION,
                     VOCAB_SIZE,
                     dtype=torch.float,
                     device=DEVICE,
                     requires_grad=True)
    W2 = torch.randn(VOCAB_SIZE,
                     EMBEDDING_DIMENSION,
                     dtype=torch.float,
                     device=DEVICE,
                     requires_grad=True)
    dataloader = DataLoader(MSMARCO('data/pairs.txt'), MB_SIZE, shuffle=True)
    epoch = 0
    for center, context in dataloader:
        if epoch > EPOCHS:
            break
        total_loss = 0
        for i in tqdm(range(0, MB_SIZE)):
            x = Variable(get_input_layer(center[i])).float().to(DEVICE)
            y = Variable(torch.from_numpy(np.array([context[i]
                                                    ])).long()).to(DEVICE)
            z1 = torch.matmul(W1, x).to(DEVICE)
            z2 = torch.matmul(W2, z1).to(DEVICE)
            log_softmax = F.log_softmax(z2, dim=0).to(DEVICE)
            loss = F.nll_loss(log_softmax.view(1, -1), y)
            total_loss += loss.item()
            loss.backward()
            W1.data -= learning_rate * W1.grad.data
            W2.data -= learning_rate * W2.grad.data
            tmp = W1.grad.data.zero_()
            tmp = W2.grad.data.zero_()
            del x, y, z1, z2, log_softmax, loss, tmp
            torch.cuda.empty_cache()
        epoch += 1
        print_message("Epoch {}: loss {}".format(epoch, total_loss / MB_SIZE))
    idx2vec = W2.data.cpu().numpy()
    pickle.dump(idx2vec, open('data/idx2vec.txt', 'wb'))
    print_message("Word2Vec Finished Training")
Example #20
    def forward(self, word_input, last_hidden, encoder_outputs):
        '''
        :param word_input:
            word input for the current time step, in shape (B)
        :param last_hidden:
            last hidden state of the decoder, in shape (layers*directions, B, H)
        :param encoder_outputs:
            encoder outputs, in shape (T, B, H)
        :return:
            decoder output
        Note: we run this one step at a time, i.e. you should use an outer loop
            to process the whole sequence.
        Tip (update):
        EncoderRNN may be bidirectional or have multiple layers, so the shape of its
        hidden states can differ from that of DecoderRNN.
        You may have to manually guarantee that they have the same dimension outside
        this function, e.g. select the encoder hidden state of the forward/backward pass.
        '''
        # Get the embedding of the current input word (last output word)
        word_embedded = self.embedding(word_input.long()).view(1, word_input.size(0), -1)  # (1, B, V)
        # alternative: self.embedding(word_input.long()).view(batch_size, 1, -1)
        word_embedded = self.dropout(word_embedded)
        # Calculate attention weights and apply to encoder outputs
        attn_weights = self.attn(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # (B, 1, V)
        context = context.transpose(0, 1)  # (1, B, V)
        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat((word_embedded, context), 2)
        rnn_input = self.attn_combine(rnn_input)  # use it in case the size of rnn_input differs
        output, hidden = self.gru(rnn_input, last_hidden)
        output = output.squeeze(0)  # (1, B, V) -> (B, V)
        # context = context.squeeze(0)
        # update: feeding "context" into the final layer can be problematic:
        # output = F.log_softmax(self.out(torch.cat((output, context), 1)))
        output = F.log_softmax(self.out(output), dim=1)
        # Return final output, hidden state, and attention weights
        return output, hidden, attn_weights
Example #21
File: w2v.py  Project: dailydaniel/univer
def run_model(vocabulary_size: int,
              documents: list,
              word2idx: dict,
              embedding_dims: int = 128,
              num_epochs: int = 101,
              learning_rate: float = 0.001):
    W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(),
                  requires_grad=True)
    W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(),
                  requires_grad=True)

    for epo in range(num_epochs):
        start_time = time.time()
        loss_val = 0
        idx_pairs = create_idx_pairs(documents, word2idx)
        for data, target in idx_pairs:
            x = Variable(get_input_layer(data)).float()
            y_true = Variable(torch.from_numpy(np.array([target])).long())

            z1 = torch.matmul(W1, x)
            z2 = torch.matmul(W2, z1)
            log_softmax = F.log_softmax(z2, dim=0)

            loss = F.nll_loss(log_softmax.view(1, -1), y_true)
            loss_val += loss.data.item()
            loss.backward()

            with torch.no_grad():
                W1 -= learning_rate * W1.grad
                W2 -= learning_rate * W2.grad

                W1.grad.zero_()
                W2.grad.zero_()

        print('Loss at epo {0}: {1}; {2} seconds'.format(
            epo, loss_val / len(idx_pairs), int(time.time() - start_time)))
    return W1, W2
Example #22
def masked_cross_entropy(logits, target, length):
    """
    Code paraphrased from
    https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/masked_cross_entropy.py

    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each sequence in the batch.

    Returns:
        loss: An average loss value masked by the length.
    """
    length = torch.LongTensor(length).to(device)

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
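sequence_mask is imported from elsewhere in that project. A minimal sketch of the usual implementation it is assumed to follow (a boolean (batch, max_len) mask that is True at non-padded positions):

import torch

def sequence_mask(sequence_length, max_len=None):
    # sequence_length: LongTensor of shape (batch,)
    if max_len is None:
        max_len = int(sequence_length.max())
    positions = torch.arange(max_len, device=sequence_length.device)
    return positions.unsqueeze(0) < sequence_length.unsqueeze(1)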
Example #23
    def word2vec(self, words):
        vocabulary = []
        for token in words:
            if token not in vocabulary:
                vocabulary.append(token)

        word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
        idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

        vocabulary_size = len(vocabulary)

        window_size = 2
        idx_pairs = []

        # for sentence in words:
        indices = [word2idx[word] for word in words]

        for center_word_pos in range(len(indices)):
            for w in range(-window_size, window_size + 1):
                context_word_pos = center_word_pos + w
                if context_word_pos < 0 or context_word_pos >= len(
                        indices) or center_word_pos == context_word_pos:
                    continue
                context_word_idx = indices[context_word_pos]
                idx_pairs.append((indices[center_word_pos], context_word_idx))

        idx_pairs = np.array(idx_pairs)

        embedding_dims = 5
        W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(),
                      requires_grad=True)
        W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(),
                      requires_grad=True)
        num_epochs = 1
        learning_rate = 0.01

        for epo in range(num_epochs):
            loss_val = 0
            for data, target in idx_pairs:
                x = Variable(self.get_input_layer(data,
                                                  vocabulary_size)).float()
                y_true = Variable(torch.from_numpy(np.array([target])).long())

                z1 = torch.matmul(W1, x)
                z2 = torch.matmul(W2, z1)

                log_softmax = F.log_softmax(z2, dim=0)

                loss = F.nll_loss(log_softmax.view(1, -1), y_true)
                loss_val += loss.item()
                loss.backward()
                W1.data -= learning_rate * W1.grad.data
                W2.data -= learning_rate * W2.grad.data

                W1.grad.data.zero_()
                W2.grad.data.zero_()
            if epo % 10 == 0:
                print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')

        tmp = []
        for data, target in idx_pairs:
            x = Variable(self.get_input_layer(data, vocabulary_size)).float()
            z1 = torch.matmul(W1, x)
            tmp.append(z1)
        return tmp
Example #24
    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = self.net4(x)
        return F.log_softmax(x, dim=1)
Example #25
    def forward(self, x):
        # noinspection PyUnresolvedReferences
        return F.log_softmax(self.projection(x), dim=-1)
Example #26
File: loss.py  Project: eiriksfa/d4dl
    def forward(self, inputs, targets):
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
Example #27
    def forward(self, x):
        x, trans = self.feat(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=-1), trans
Example #28
    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs
Example #29
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores
Example #30
File: w2vectest.py  Project: Tsarpf/log2vec
embedding_dims = 5
W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(), requires_grad=True)
W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(), requires_grad=True)
num_epochs = 101
learning_rate = 0.001

for epo in range(num_epochs):
    loss_val = 0
    for data, target in idx_pairs:
        x = Variable(get_input_layer(data)).float()
        y_true = Variable(torch.from_numpy(np.array([target])).long())

        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)
    
        log_softmax = F.log_softmax(z2, dim=0)

        loss = F.nll_loss(log_softmax.view(1,-1), y_true)
        loss_val += loss.item()
        loss.backward()
        W1.data -= learning_rate * W1.grad.data
        W2.data -= learning_rate * W2.grad.data

        W1.grad.data.zero_()
        W2.grad.data.zero_()
    #if epo % 10 == 0:    
    print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')


#%%
W1numpy = W1.detach().cpu().numpy()
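A possible follow-up, not part of the original script: querying nearest neighbours in the learned embedding space. word2idx and idx2word are assumed to exist alongside get_input_layer; they are not shown in the snippet above.

import numpy as np

def most_similar(word, topn=5):
    # Cosine similarity of `word` against every column of W1 (one column per vocabulary word).
    vecs = W1numpy.T                                  # (vocabulary_size, embedding_dims)
    v = vecs[word2idx[word]]
    sims = vecs @ v / (np.linalg.norm(vecs, axis=1) * np.linalg.norm(v) + 1e-9)
    best = np.argsort(-sims)[1:topn + 1]              # drop the query word itself
    return [(idx2word[i], float(sims[i])) for i in best]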