Example #1
    def training_step(self, batch, batch_idx):
        v, _, label = batch
        # move channels ahead of time, e.g. (B, T, C, H, W) -> (B, C, T, H, W)
        v = v.permute(0, 2, 1, 3, 4)

        # pad a short final batch up to the fixed batch size
        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
            label = pad_batch(label, self.batch_size)

        # the encoder runs without gradient tracking, so only the linear head is updated
        with torch.no_grad():
            features = self.model.encoder.v_encoder(v)
        logits = self.fc(features)

        loss = self.loss(logits, label)
        top_1_accuracy = compute_accuracy(logits, label, top_k=1)
        top_5_accuracy = compute_accuracy(logits, label, top_k=5)

        logs = {
            'loss': loss,
            'train_top_1': top_1_accuracy,
            'train_top_5': top_5_accuracy
        }

        for k in logs:
            prog_bar = "top" in k  # only the accuracy metrics go to the progress bar
            self.log('train/{}'.format(k), logs[k], prog_bar=prog_bar)

        return {'loss': loss, 'logs': logs}
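
pad_batch is not defined in these snippets. A minimal sketch of what it might look like, assuming it simply zero-pads the leading (batch) dimension up to batch_size; the name and behavior are inferred from the call sites above, not confirmed by the source:

import torch

def pad_batch(x, batch_size):
    # hypothetical helper: zero-pad x along dim 0 up to batch_size
    pad_rows = batch_size - x.shape[0]
    if pad_rows <= 0:
        return x
    padding = torch.zeros(pad_rows, *x.shape[1:], dtype=x.dtype, device=x.device)
    return torch.cat([x, padding], dim=0)
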
Example #2
    def validation_step(self, batch, batch_idx):
        v, v2, label = batch
        v = v.permute(0, 2, 1, 3, 4)
        v2 = v2.permute(0, 2, 1, 3, 4)

        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
            v2 = pad_batch(v2, self.batch_size)  # pad the second clip too, so both views match
            label = pad_batch(label, self.batch_size)

        # Lightning already disables gradients inside validation_step,
        # so no explicit torch.no_grad() is needed here
        features = self.model.encoder.v_encoder(v)
        logits = self.fc(features)

        # cross-view similarity between the two clips, used only for visualization
        f1, f2 = self.model.encoder(v, v2)
        similarity = f1 @ f2.T
        visualize_batch_downstream(similarity,
                                   prefix="val",
                                   dataset="cvrl_ucf")

        loss = self.loss(logits, label)
        top_1_accuracy = compute_accuracy(logits, label, top_k=1)
        top_5_accuracy = compute_accuracy(logits, label, top_k=5)

        logs = {
            'val_loss': loss,
            'val_top_1': top_1_accuracy,
            'val_top_5': top_5_accuracy
        }

        for k in logs:
            prog_bar = "top" in k
            self.log('val/{}'.format(k), logs[k], prog_bar=prog_bar)

        return logs
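
compute_accuracy is likewise external to these examples. A plausible top-k implementation, assuming logits of shape (batch, classes) and integer class labels:

import torch

def compute_accuracy(logits, labels, top_k=1):
    # hypothetical helper: fraction of rows whose label appears in the top_k logits
    _, pred = logits.topk(top_k, dim=-1)      # (batch, top_k) predicted classes
    correct = pred.eq(labels.unsqueeze(-1))   # broadcast labels against predictions
    return correct.any(dim=-1).float().mean()
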
Example #3
    def validation_step(self, batch, batch_idx):
        a, v, t, label, urls = batch

        if a.shape[0] < self.batch_size:
            a = pad_batch(a, self.batch_size)
            v = pad_batch(v, self.batch_size)
            t = pad_batch(t, self.batch_size)
            label = pad_batch(label, self.batch_size)

        logits, similarity = self.classifier(a, v, t)

        loss = self.loss(logits, label)
        top_1_accuracy = compute_accuracy(logits, label, top_k=1)
        top_5_accuracy = compute_accuracy(logits, label, top_k=5)

        logs = {
            'val_loss': loss,
            'val_top_1': top_1_accuracy,
            'val_top_5': top_5_accuracy
        }

        if similarity is not None and batch_idx % 100 == 0:
            visualize_batch_downstream(similarity, prefix="val", dataset="k700")

        for k in logs:
            prog_bar = "top" in k
            self.log('val/{}'.format(k), logs[k], prog_bar=prog_bar)

        return logs
Example #4
    def test_step(self, batch, batch_idx):
        a, v, t, label = batch
        logits, similarity = self.classifier(a, v, t)

        loss = self.loss(logits, label)
        top_1_accuracy = compute_accuracy(logits, label, top_k=1)
        top_5_accuracy = compute_accuracy(logits, label, top_k=5)

        logs = {
            'test_loss': loss,
            'test_top_1': top_1_accuracy,
            'test_top_5': top_5_accuracy
        }

        for k in logs:
            self.log('test/{}'.format(k), logs[k], prog_bar=False)
Example #5
    def loss(self, a, v, t):
        # project each modality into the shared embedding space and normalize
        a = self.norm(self.a_proj(a))
        v = self.norm(self.v_proj(v))
        t = self.norm(self.t_proj(t))

        # approximate centroid vector
        centroid = (a + v + t) / 3
        centroid = self.norm(centroid)

        # pull each modality toward the shared centroid
        avt_loss = nce_loss(a, centroid, temp=self.temperature)
        avt_loss += nce_loss(v, centroid, temp=self.temperature)
        avt_loss += nce_loss(t, centroid, temp=self.temperature)

        av_loss = nce_loss(a, v, temp=self.temperature)
        vt_loss = nce_loss(v, t, temp=self.temperature)

        loss = avt_loss + av_loss + vt_loss

        # in-batch retrieval accuracy: the i-th audio, video and text are positives
        av = a @ v.T
        vt = v @ t.T

        labels = torch.arange(self.batch_size, device=v.device)
        av_top1 = 0.5 * (compute_accuracy(av, labels, top_k=1) +
                         compute_accuracy(av.T, labels, top_k=1))
        vt_top1 = 0.5 * (compute_accuracy(vt, labels, top_k=1) +
                         compute_accuracy(vt.T, labels, top_k=1))

        metrics = {
            'loss': loss.item(),
            'av_loss': av_loss.item(),
            'vt_loss': vt_loss.item(),
            'av_top1': av_top1.item(),
            'vt_top1': vt_top1.item(),
            'temp': self.temperature,
            'vt_matrix': vt.T.detach().to(dtype=torch.float32),
        }

        return loss, metrics
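
nce_loss is also not shown on this page. A minimal InfoNCE sketch under the assumption that the i-th rows of the two embedding matrices are the positive pair, which matches how the accuracy labels are built above:

import torch
import torch.nn.functional as F

def nce_loss(x, y, temp=0.1):
    # hypothetical InfoNCE: in-batch negatives, symmetric over both directions
    logits = x @ y.T / temp                        # (batch, batch) similarities
    labels = torch.arange(x.shape[0], device=x.device)
    return 0.5 * (F.cross_entropy(logits, labels) +
                  F.cross_entropy(logits.T, labels))
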
Example #6
    def test_step(self, batch, batch_idx):
        v, label = batch
        v = v.permute(0, 2, 1, 3, 4)

        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
            label = pad_batch(label, self.batch_size)

        features = self.model.encoder.v_encoder(v)
        logits = self.fc(features)

        loss = self.loss(logits, label)
        top_1_accuracy = compute_accuracy(logits, label, top_k=1)
        top_5_accuracy = compute_accuracy(logits, label, top_k=5)

        logs = {
            'test_loss': loss,
            'test_top_1': top_1_accuracy,
            'test_top_5': top_5_accuracy
        }

        for k in logs:
            self.log('test/{}'.format(k), logs[k], prog_bar=False)
Example #7
    def loss(self, a, v, t, t_mask):
        # t_mask (for zeroing out t vectors from the default string) is unused here

        codes = torch.nn.functional.normalize(self.codebook, p=2, dim=-1)
        softmax = torch.nn.Softmax(dim=-1)
        # cross-entropy between each modality's logits and the others' code labels
        ce_loss = torch.nn.CrossEntropyLoss()
        # sharp softmax temperature: pushes the code assignments toward one-hot
        temp = 1e+3

        # similarity of each feature to every codebook entry
        a_logits = a @ codes.T
        v_logits = v @ codes.T
        t_logits = t @ codes.T

        # near-one-hot soft assignments over the codebook
        a_onehot = softmax(a_logits * temp).view(-1, self.num_embeddings)
        v_onehot = softmax(v_logits * temp).view(-1, self.num_embeddings)
        t_onehot = softmax(t_logits * temp).view(-1, self.num_embeddings)

        a_logits = a_logits.view(-1, self.num_embeddings)
        v_logits = v_logits.view(-1, self.num_embeddings)
        t_logits = t_logits.view(-1, self.num_embeddings)

        # hard code assignments, used as cross-modal classification targets
        a_label = torch.argmax(a_onehot, dim=-1)
        v_label = torch.argmax(v_onehot, dim=-1)
        t_label = torch.argmax(t_onehot, dim=-1)

        # each pair of modalities is trained to predict the third modality's code labels
        av_loss = 0.5 * (ce_loss(a_logits, t_label) +
                         ce_loss(v_logits, t_label))
        at_loss = 0.5 * (ce_loss(a_logits, v_label) +
                         ce_loss(t_logits, v_label))
        vt_loss = 0.5 * (ce_loss(v_logits, a_label) +
                         ce_loss(t_logits, a_label))

        av_top1 = 0.5 * (compute_accuracy(a_onehot, v_label, top_k=1) +
                         compute_accuracy(v_onehot, a_label, top_k=1))
        vt_top1 = 0.5 * (compute_accuracy(v_onehot, t_label, top_k=1) +
                         compute_accuracy(t_onehot, v_label, top_k=1))

        # code-usage distribution per sequence position, pooled over the batch
        a_prob = softmax(
            a_logits.view(self.batch_size, self.seqlen,
                          -1).transpose(0, 1).sum(dim=1))
        v_prob = softmax(
            v_logits.view(self.batch_size, self.seqlen,
                          -1).transpose(0, 1).sum(dim=1))
        t_prob = softmax(
            t_logits.view(self.batch_size, self.seqlen,
                          -1).transpose(0, 1).sum(dim=1))

        a_entropy = (-a_prob * torch.log(a_prob)).sum(dim=-1).mean()
        v_entropy = (-v_prob * torch.log(v_prob)).sum(dim=-1).mean()
        t_entropy = (-t_prob * torch.log(t_prob)).sum(dim=-1).mean()

        # cross-modal prediction losses plus entropy terms on code usage
        loss = av_loss + at_loss + vt_loss + a_entropy + v_entropy + t_entropy

        metrics = {
            'loss': loss.item(),
            'av_top1': av_top1.item(),
            'vt_top1': vt_top1.item(),
            'av_loss': av_loss.item(),
            'at_loss': at_loss.item(),
            'vt_loss': vt_loss.item(),
            'a_entropy': a_entropy.item(),
            'v_entropy': v_entropy.item(),
            't_entropy': t_entropy.item(),
            'temp': self.temperature,
            'vt_matrix': (v[:, 0] @ t[:, 0].T).detach()
        }

        return loss, metrics
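
The temp = 1e+3 factor in this example sharpens the softmax so that each feature's distribution over codebook entries is nearly one-hot, which is why the results are named *_onehot. A quick standalone illustration of the effect:

import torch

logits = torch.tensor([1.0, 0.9, 0.1])
print(torch.softmax(logits, dim=-1))        # soft: roughly [0.43, 0.39, 0.18]
print(torch.softmax(logits * 1e3, dim=-1))  # nearly one-hot: [1., 0., 0.]
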
Example #8
    def loss(self, a, v, t, t_mask):
        # t_mask (for zeroing out t vectors from the default string) is unused here;
        # NaNs are first filtered out of the video features
        v = nan_filter(v)

        codes = torch.nn.functional.normalize(self.codebook, p=2,
                                              dim=-1).permute(0, 2, 1)
        softmax = torch.nn.Softmax(dim=-1)
        ce_loss = torch.nn.CrossEntropyLoss()
        # sharp softmax temperature: pushes the code assignments toward one-hot
        temp = 1e+3

        # similarity of each feature to every codebook entry
        a_logits = (a.unsqueeze(2) @ codes.T).squeeze()
        v_logits = (v.unsqueeze(2) @ codes.T).squeeze()
        t_logits = (t.unsqueeze(2) @ codes.T).squeeze()

        a_logits = a_logits.view(-1, self.num_embeddings)
        v_logits = v_logits.view(-1, self.num_embeddings)
        t_logits = t_logits.view(-1, self.num_embeddings)

        # near-one-hot soft assignments over the codebook
        a_onehot = softmax(a_logits * temp)
        v_onehot = softmax(v_logits * temp)
        t_onehot = softmax(t_logits * temp)

        # hard code assignments, used as cross-modal classification targets
        a_label = torch.argmax(a_onehot, dim=-1)
        v_label = torch.argmax(v_onehot, dim=-1)
        t_label = torch.argmax(t_onehot, dim=-1)

        # each modality predicts the others' code labels; CrossEntropyLoss applies
        # log-softmax internally, so it is given the raw logits
        av_loss = 0.5 * (ce_loss(a_logits, v_label) +
                         ce_loss(v_logits, a_label))
        vt_loss = 0.5 * (ce_loss(v_logits, t_label) +
                         ce_loss(t_logits, v_label))

        av_top1 = 0.5 * (compute_accuracy(a_onehot, v_label, top_k=1) +
                         compute_accuracy(v_onehot, a_label, top_k=1))
        vt_top1 = 0.5 * (compute_accuracy(v_onehot, t_label, top_k=1) +
                         compute_accuracy(t_onehot, v_label, top_k=1))

        # code-usage distribution per sequence position; divide by the batch size so
        # each row sums to one before the entropy is taken
        a_prob = a_onehot.view(self.batch_size, self.seqlen,
                               -1).transpose(0, 1).sum(dim=1) / self.batch_size
        v_prob = v_onehot.view(self.batch_size, self.seqlen,
                               -1).transpose(0, 1).sum(dim=1) / self.batch_size
        t_prob = t_onehot.view(self.batch_size, self.seqlen,
                               -1).transpose(0, 1).sum(dim=1) / self.batch_size

        a_entropy = (-a_prob * torch.log(a_prob)).sum(dim=-1).mean()
        v_entropy = (-v_prob * torch.log(v_prob)).sum(dim=-1).mean()
        t_entropy = (-t_prob * torch.log(t_prob)).sum(dim=-1).mean()

        # the entropy terms are logged as metrics but not included in the loss
        loss = av_loss + vt_loss

        metrics = {
            'loss': loss.item(),
            'av_top1': av_top1.item(),
            'vt_top1': vt_top1.item(),
            'av_loss': av_loss.item(),
            'vt_loss': vt_loss.item(),
            'a_entropy': a_entropy.item(),
            'v_entropy': v_entropy.item(),
            't_entropy': t_entropy.item(),
            'temp': self.temperature,
            'vt_matrix': (v[:, 0] @ t[:, 0].T).detach()
        }

        return loss, metrics
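
nan_filter is another helper that is not defined on this page. A minimal guess, assuming it only replaces NaN entries with zeros:

import torch

def nan_filter(x):
    # hypothetical helper: replace NaN entries with zeros
    return torch.nan_to_num(x, nan=0.0)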