Example #1
0
    def training_step(self, batch, batch_idx):
        """One supervised training step: classify the batch, then log loss and accuracy."""
        a, v, t, label, urls = batch

        # The final batch of an epoch may be short; pad every tensor up to batch_size.
        if a.shape[0] < self.batch_size:
            a, v, t, label = (pad_batch(x, self.batch_size) for x in (a, v, t, label))

        logits, similarity = self.classifier(a, v, t)

        loss = self.loss(logits, label)
        logs = {
            'loss': loss,
            'train_top_1': compute_accuracy(logits, label, top_k=1),
            'train_top_5': compute_accuracy(logits, label, top_k=5)}

        for key, value in logs.items():
            # Only the accuracy entries go on the progress bar.
            self.log('train/{}'.format(key), value, prog_bar='top' in key)

        # Periodically dump the similarity matrix for qualitative inspection.
        if similarity is not None and batch_idx % 100 == 0:
            visualize_batch_downstream(similarity, prefix="train", dataset="k700")

        return {'loss': loss,
                'logs': logs}
Example #2
0
    def training_step(self, batch, batch_idx):
        """One contrastive training step.

        Unpacks the batch (two-video layout or single-video + image layout,
        selected by ``self.multiple_video``), pads short end-of-epoch batches
        up to ``self.batch_size``, runs the encoder, computes the loss and
        metrics, logs scalars, and periodically visualizes a retrieval matrix.

        Returns a dict with ``loss`` (consumed by Lightning) and the raw
        ``metrics`` under ``logs``.
        """
        if self.multiple_video:
            a, v1, v2, t, urls, t_mask = batch
            # Pad each modality independently; the text mask must stay
            # aligned with the padded text tensor.
            if a.shape[0] < self.batch_size:
                a = pad_batch(a, self.batch_size)
            if v1.shape[0] < self.batch_size:
                v1 = pad_batch(v1, self.batch_size)
            if v2.shape[0] < self.batch_size:
                v2 = pad_batch(v2, self.batch_size)
            if t.shape[0] < self.batch_size:
                t = pad_batch(t, self.batch_size)
                t_mask = pad_batch(t_mask, self.batch_size)
            a, v1, v2, t = self.encoder(a, v1, v2, t)
            loss, metrics = self.encoder.loss(a, v1, v2, t, t_mask.to(dtype=torch.float32))
        else:
            a, v, t, i, urls, t_mask = batch
            if a.shape[0] < self.batch_size:
                a = pad_batch(a, self.batch_size)
            if v.shape[0] < self.batch_size:
                v = pad_batch(v, self.batch_size)
            if t.shape[0] < self.batch_size:
                t = pad_batch(t, self.batch_size)
                t_mask = pad_batch(t_mask, self.batch_size)
            if i.shape[0] < self.batch_size:
                i = pad_batch(i, self.batch_size)

            a, v, t, i = self.encoder(a, v, t, i)
            loss, metrics = self.encoder.loss(a, v, t, i, t_mask.to(dtype=torch.float32))

        # Log scalar metrics only; the *_matrix entries are tensors used below.
        for k in metrics:
            if 'matrix' not in k:
                self.log('train/{}'.format(k), metrics[k], prog_bar='av_top1' in k)

        self.log('lr', self.loggable_lr)

        torch.cuda.empty_cache()

        # Visualization is best-effort: a plotting failure must never abort
        # training.  (Previously a bare `except:` swallowed everything --
        # including KeyboardInterrupt -- and clobbered `a` with a dummy value.)
        if batch_idx % 100 == 0:
            try:
                q = ('st' if self.pretrained_text else 'c') + 'vq'
                if 'vt_matrix' in metrics:
                    visualize_batch(urls, metrics['vt_matrix'], prefix="train", qualifier=q, mode='vt')
                else:
                    visualize_batch(urls, metrics['av_matrix'], prefix="train", qualifier=q, mode='av')
            except Exception:
                pass  # deliberate: visualization is non-essential

        return {'loss': loss,
                'logs': metrics}
Example #3
0
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and report the total loss."""
        bs = self.batch_size

        def pad_if_short(x):
            # End-of-epoch batches may be smaller than bs; pad them back up.
            return pad_batch(x, bs) if x.shape[0] < bs else x

        if self.multiple_video:
            a, v1, v2, t, urls, t_mask = batch
            a, v1, v2 = pad_if_short(a), pad_if_short(v1), pad_if_short(v2)
            if t.shape[0] < bs:
                # Keep the text mask aligned with the padded text tensor.
                t = pad_batch(t, bs)
                t_mask = pad_batch(t_mask, bs)
            a, v1, v2, t = self.encoder(a, v1, v2, t)
            loss, metrics = self.encoder.loss(a, v1, v2, t, t_mask.to(dtype=torch.float32))
        else:
            a, v, t, urls, t_mask = batch
            a, v = pad_if_short(a), pad_if_short(v)
            if t.shape[0] < bs:
                t = pad_batch(t, bs)
                t_mask = pad_batch(t_mask, bs)
            a, v, t = self.encoder(a, v, t)
            loss, metrics = self.encoder.loss(a, v, t, t_mask.to(dtype=torch.float32))

        return {'test_total_loss': metrics['total_loss'],}
Example #4
0
    def validation_step(self, batch, batch_idx):
        """One validation step: pad a short batch, encode, compute loss/metrics,
        log scalars, and periodically visualize a retrieval matrix.

        Returns a dict with the total validation loss.
        """
        a, v, t, urls, t_mask = batch
        # Pad short end-of-epoch batches; the text mask must stay aligned
        # with the padded text tensor.
        if a.shape[0] < self.batch_size:
            a = pad_batch(a, self.batch_size)
        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
        if t.shape[0] < self.batch_size:
            t = pad_batch(t, self.batch_size)
            t_mask = pad_batch(t_mask, self.batch_size)

        a, v, t = self.encoder(a, v, t)
        loss, metrics = self.encoder.loss(a, v, t,
                                          t_mask.to(dtype=torch.float32))

        # Log scalar metrics only; the *_matrix entries are tensors used below.
        for k in metrics:
            if 'matrix' not in k:
                self.log('val/{}'.format(k), metrics[k], prog_bar='av_top1' in k)

        torch.cuda.empty_cache()

        # Visualization is best-effort: a plotting failure must never abort
        # validation.  (Previously a bare `except:` swallowed everything and
        # clobbered `a` with a dummy value.)
        if batch_idx % 100 == 0:
            try:
                q = ('st' if self.pretrained_text else 'c') + 'vq'
                if 'vt_matrix' in metrics:
                    visualize_batch(urls,
                                    metrics['vt_matrix'],
                                    prefix="val",
                                    qualifier=q,
                                    mode='vt')
                else:
                    visualize_batch(urls,
                                    metrics['av_matrix'],
                                    prefix="val",
                                    qualifier=q,
                                    mode='av')
            except Exception:
                pass  # deliberate: visualization is non-essential

        return {'val_total_loss': loss}
Example #5
0
    def validation_step(self, batch, batch_idx):
        """One validation step for the mask-free (audio, video, text) loss variant."""
        a, v, t = batch
        # Pad short end-of-epoch batches up to the fixed batch size.
        if a.shape[0] < self.batch_size:
            a = pad_batch(a, self.batch_size)
        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
        if t.shape[0] < self.batch_size:
            # BUG FIX: the original also padded `t_mask`, which is never
            # unpacked from this 3-tuple batch and raised NameError whenever
            # the text tensor needed padding.
            t = pad_batch(t, self.batch_size)

        a, v, t = self.encoder(a, v, t)
        loss, metrics = self.encoder.loss(a, v, t)

        # Log scalar metrics only; *_matrix entries are tensors.
        for k in metrics:
            if 'matrix' not in k:
                self.log('val/{}'.format(k), metrics[k], prog_bar='av_top1' in k)

        torch.cuda.empty_cache()

        return {'val_total_loss': loss}
Example #6
0
    def test_step(self, batch, batch_idx):
        """Evaluate the video classifier on one test batch and log loss/accuracy."""
        v, label = batch
        # Reorder dimensions for the video encoder -- assumes the loader yields
        # (B, T, C, H, W) and the encoder wants (B, C, T, H, W); TODO confirm.
        v = v.permute(0, 2, 1, 3, 4)

        # Pad a short end-of-epoch batch up to the fixed batch size.
        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
            label = pad_batch(label, self.batch_size)

        logits = self.fc(self.model.encoder.v_encoder(v))
        loss = self.loss(logits, label)

        logs = {
            'test_loss': loss,
            'test_top_1': compute_accuracy(logits, label, top_k=1),
            'test_top_5': compute_accuracy(logits, label, top_k=5),
        }
        for name, value in logs.items():
            self.log('test/{}'.format(name), value, prog_bar=False)
Example #7
0
    def training_step(self, batch, batch_idx):
        """One training step for the mask-free (audio, video, text) loss variant."""
        a, v, t = batch
        # Pad short end-of-epoch batches up to the fixed batch size.
        if a.shape[0] < self.batch_size:
            a = pad_batch(a, self.batch_size)
        if v.shape[0] < self.batch_size:
            v = pad_batch(v, self.batch_size)
        if t.shape[0] < self.batch_size:
            # BUG FIX: the original also padded `t_mask`, which is never
            # unpacked from this 3-tuple batch and raised NameError whenever
            # the text tensor needed padding.
            t = pad_batch(t, self.batch_size)

        a, v, t = self.encoder(a, v, t)
        loss, metrics = self.encoder.loss(a, v, t)

        # Log scalar metrics only; *_matrix entries are tensors.
        for k in metrics:
            if 'matrix' not in k:
                self.log('train/{}'.format(k), metrics[k], prog_bar='av_top1' in k)

        self.log('lr', self.loggable_lr)

        torch.cuda.empty_cache()

        return {'loss': loss, 'logs': metrics}