Example #1
import torch
import torch.nn as nn

# FeatureEmbedder, Identity, VocabularyEmbedder, PositionalEncoder, Encoder,
# Decoder and Generator are defined elsewhere in the project (not shown here).
class Transformer(nn.Module):
    def __init__(self, train_dataset, cfg):
        super(Transformer, self).__init__()
        self.modality = cfg.modality

        if cfg.modality == 'video':
            self.d_model = cfg.d_model_video
            self.d_feat = cfg.d_vid
            self.d_ff = cfg.d_ff_video
        elif cfg.modality == 'audio':
            self.d_feat = cfg.d_aud
            self.d_model = cfg.d_model_audio
            self.d_ff = cfg.d_ff_audio

        if cfg.use_linear_embedder:
            self.src_emb = FeatureEmbedder(self.d_feat, self.d_model)
        else:
            assert self.d_feat == self.d_model
            self.src_emb = Identity()

        self.trg_emb = VocabularyEmbedder(train_dataset.trg_voc_size,
                                          self.d_model)
        self.pos_emb = PositionalEncoder(self.d_model, cfg.dout_p)
        self.encoder = Encoder(self.d_model, cfg.dout_p, cfg.H, self.d_ff,
                               cfg.N)
        self.decoder = Decoder(self.d_model, cfg.dout_p, cfg.H, self.d_ff,
                               cfg.N)
        self.generator = Generator(self.d_model, train_dataset.trg_voc_size)

        print('initialization: xavier')
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # initialize the word embeddings afterwards so they overwrite the Xavier-initialized weights above
        self.trg_emb.init_word_embeddings(train_dataset.train_vocab.vectors,
                                          cfg.unfreeze_word_emb)

        # load the pretrained encoder from the proposal generation module (used in ablation studies)
        if cfg.pretrained_prop_model_path is not None:
            print(f'Pretrained prop path: \n {cfg.pretrained_prop_model_path}')
            cap_model_cpt = torch.load(cfg.pretrained_prop_model_path,
                                       map_location='cpu')
            encoder_config = cap_model_cpt['config']
            if cfg.modality == 'video':
                self.d_model = encoder_config.d_model_video
                self.d_ff = encoder_config.d_ff_video
            elif cfg.modality == 'audio':
                self.d_model = encoder_config.d_model_audio
                self.d_ff = encoder_config.d_ff_audio
            self.encoder = Encoder(self.d_model, encoder_config.dout_p,
                                   encoder_config.H, self.d_ff,
                                   encoder_config.N)
            encoder_weights = {
                k: v
                for k, v in cap_model_cpt['model_state_dict'].items()
                if 'encoder' in k
            }
            encoder_weights = {
                k.replace('encoder.', ''): v
                for k, v in encoder_weights.items()
            }
            self.encoder.load_state_dict(encoder_weights)
            self.encoder = self.encoder.to(cfg.device)
            for param in self.encoder.parameters():
                param.requires_grad = cfg.finetune_prop_encoder

    def forward(self, src: dict, trg, masks: dict):
        '''
        In: src (B, Ss, d_feat), trg (B, St), src_mask (B, 1, Ss), trg_mask (B, St, St)
        Out: (B, St, voc_size)
        '''
        if self.modality == 'audio':
            src = src['audio']
            src_mask = masks['A_mask']
        elif self.modality == 'video':
            src = src['rgb'] + src['flow']
            src_mask = masks['V_mask']

        trg_mask = masks['C_mask']

        # embed
        src = self.src_emb(src)
        trg = self.trg_emb(trg)
        src = self.pos_emb(src)
        trg = self.pos_emb(trg)

        # encode and decode
        memory = self.encoder(src, src_mask)
        out = self.decoder(trg, memory, src_mask, trg_mask)

        # generate
        out = self.generator(out)

        return out
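A hypothetical smoke test for the class above (not part of the original code): cfg and train_dataset are assumed to expose the attributes read in __init__, and the tensor shapes follow the forward() docstring.

import torch

B, Ss, St = 2, 30, 12                          # batch size, source length, target length
model = Transformer(train_dataset, cfg)        # assumes cfg.modality == 'video'

src = {'rgb':  torch.rand(B, Ss, cfg.d_vid),
       'flow': torch.rand(B, Ss, cfg.d_vid)}
trg = torch.randint(0, train_dataset.trg_voc_size, (B, St))
causal = torch.tril(torch.ones(St, St)).bool()             # lower-triangular (St, St)
masks = {'V_mask': torch.ones(B, 1, Ss).bool(),            # (B, 1, Ss)
         'C_mask': causal.unsqueeze(0).expand(B, St, St)}  # (B, St, St)

out = model(src, trg, masks)                   # (B, St, voc_size)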
Example #2
import torch
import torch.nn as nn

# FeatureEmbedder, Identity, PositionalEncoder, Encoder, ProposalGenerationHead,
# make_targets and add_dict_to_another_dict are defined elsewhere in the project.
class ProposalGenerator(nn.Module):
    def __init__(self, cfg, anchors):
        super(ProposalGenerator, self).__init__()
        self.cfg = cfg
        self.EPS = 1e-16
        self.num_logits = 3  # 3: c, w, obj
        self.anchors = anchors
        self.anchors_list = anchors[cfg.modality]
        self.anchors_num = len(self.anchors_list)

        if cfg.modality == 'video':
            self.d_feat = cfg.d_vid
            self.d_model_modality = cfg.d_model_video
            self.d_ff = cfg.d_ff_video
            layer_dims = [
                self.d_model_modality, *cfg.conv_layers_video,
                self.num_logits * self.anchors_num
            ]
        elif cfg.modality == 'audio':
            self.d_feat = cfg.d_aud
            self.d_model_modality = cfg.d_model_audio
            self.d_ff = cfg.d_ff_audio
            layer_dims = [
                self.d_model_modality, *cfg.conv_layers_audio,
                self.num_logits * self.anchors_num
            ]
        else:
            raise NotImplementedError

        if cfg.use_linear_embedder:
            self.emb = FeatureEmbedder(self.d_feat, self.d_model_modality)
        else:
            self.emb = Identity()
        self.pos_enc = PositionalEncoder(self.d_model_modality, cfg.dout_p)

        # load the pre-trained encoder from the captioning module
        if cfg.pretrained_cap_model_path is not None:
            print(f'Caption path: \n {cfg.pretrained_cap_model_path}')
            cap_model_cpt = torch.load(cfg.pretrained_cap_model_path,
                                       map_location='cpu')
            encoder_config = cap_model_cpt['config']
            if cfg.modality == 'video':
                self.d_model_modality = encoder_config.d_model_video
                self.d_ff = encoder_config.d_ff_video
            elif cfg.modality == 'audio':
                self.d_model_modality = encoder_config.d_model_audio
                self.d_ff = encoder_config.d_ff_audio
            else:
                raise NotImplementedError
            self.encoder = Encoder(self.d_model_modality,
                                   encoder_config.dout_p, encoder_config.H,
                                   self.d_ff, encoder_config.N)
            encoder_weights = {
                k: v
                for k, v in cap_model_cpt['model_state_dict'].items()
                if 'encoder' in k
            }
            encoder_weights = {
                k.replace('module.encoder.', ''): v
                for k, v in encoder_weights.items()
            }
            self.encoder.load_state_dict(encoder_weights)
            self.encoder = self.encoder.to(cfg.device)
            for param in self.encoder.parameters():
                param.requires_grad = cfg.finetune_cap_encoder
        else:
            self.encoder = Encoder(self.d_model_modality, cfg.dout_p, cfg.H,
                                   self.d_ff, cfg.N)
            # encoder initialization
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

        self.detection_layers = torch.nn.ModuleList([
            ProposalGenerationHead(layer_dims, k, cfg.dout_p, cfg.layer_norm)
            for k in cfg.kernel_sizes[cfg.modality]
        ])

        print(self.detection_layers)
        self.bce_loss = nn.BCELoss()
        self.mse_loss = nn.MSELoss()

    def kernel_size_forward(self, x, layer, stride, targets):
        # defaults, kept as-is when targets is None (no ground truth)
        loss = 0
        losses = {}
        x = layer(x)

        B, S, D = x.shape
        x = x.view(B, S, self.anchors_num, self.num_logits)

        x = x.permute(0, 2, 1, 3).contiguous()
        grid_cell = torch.arange(S).view(1, 1, S).float().to(self.cfg.device)
        # After dividing anchors by the stride, they express their size in grid
        # cells: 1.2 means one full grid cell plus 20% of another. Multiplying
        # them by the stride later maps the values back to the original scale.
        anchors_list = [[anchor / stride] for anchor in self.anchors_list]
        anchors_tensor = torch.tensor(anchors_list, device=self.cfg.device)
        # (A, 1) -> (1, A, 1) for broadcasting
        prior_length = anchors_tensor.view(1, self.anchors_num, 1)

        # prediction values for the *loss* calculation (training)
        sigma_c = torch.sigmoid(x[:, :, :, 0])  # center shift
        l = x[:, :, :, 1]  # log coefficient
        sigma_o = torch.sigmoid(x[:, :, :, 2])  # objectness

        # prediction values mapped back to the original scale (used at inference);
        # detach them from the graph as we don't need to backpropagate through them
        predictions = x.clone().detach()
        # broadcasting (B, A, S) + (1, 1, S)
        predictions[:, :, :, 0] = sigma_c + grid_cell
        # broadcasting (1, A, 1) * (B, A, S)
        predictions[:, :, :, 1] = prior_length * torch.exp(l)
        predictions[:, :, :, 2] = sigma_o

        if targets is not None:
            obj_mask, noobj_mask, gt_x, gt_w, gt_obj = make_targets(
                predictions, targets, anchors_tensor, stride)
            ## Loss
            # Localization
            loss_x = self.mse_loss(sigma_c[obj_mask], gt_x[obj_mask])
            loss_w = self.mse_loss(l[obj_mask], gt_w[obj_mask])
            loss_loc = loss_x + loss_w
            # Confidence
            loss_obj = self.bce_loss(sigma_o[obj_mask], gt_obj[obj_mask])
            loss_noobj = self.bce_loss(sigma_o[noobj_mask], gt_obj[noobj_mask])
            loss_conf = self.cfg.obj_coeff * loss_obj + self.cfg.noobj_coeff * loss_noobj
            # Total loss
            loss = loss_loc + loss_conf

            losses = {
                'loss_x': loss_x,
                'loss_w': loss_w,
                'loss_conf_obj': loss_obj,
                'loss_conf_noobj': loss_noobj
            }

        # for NMS: (B, A, S, 3) -> (B, A*S, 3)
        predictions = predictions.view(B, S * self.anchors_num,
                                       self.num_logits)
        predictions[:, :, :2] *= stride  # scale center and length back to the original resolution

        return predictions, loss, losses

    def forward(self, x, targets, masks):

        if self.cfg.modality == 'video':
            x = x['rgb'] + x['flow']
            stride = self.cfg.strides['video']
            x = self.emb(x)
            x = self.pos_enc(x)
            x = self.encoder(x, masks['V_mask'])
        elif self.cfg.modality == 'audio':
            x = x['audio']
            stride = self.cfg.strides['audio']
            x = self.emb(x)
            x = self.pos_enc(x)
            x = self.encoder(x, masks['A_mask'])

        all_predictions = []
        # total_loss must stay a tensor so .backward() can be called on it
        sum_losses_dict = {}
        total_loss = 0

        for layer in self.detection_layers:
            predictions, loss, loss_dict = self.kernel_size_forward(
                x, layer, stride, targets)
            total_loss += loss
            all_predictions.append(predictions)
            sum_losses_dict = add_dict_to_another_dict(loss_dict,
                                                       sum_losses_dict)

        all_predictions = torch.cat(all_predictions, dim=1)

        return all_predictions, total_loss, sum_losses_dict
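The decoding inside kernel_size_forward above follows a YOLO-style parameterization: the proposal center is the grid-cell index plus a sigmoid offset, the proposal length is the anchor prior scaled by exp(l), and both are multiplied by the stride to return to the original scale. A standalone sketch of that arithmetic with made-up numbers (not from the original code):

import torch

stride = 8                                 # hypothetical stride of this detection head
anchor = 48.0                              # hypothetical anchor length on the original scale
prior = torch.tensor(anchor / stride)      # anchor expressed in grid cells

raw_c = torch.tensor(0.3)                  # raw center logit -> sigmoid offset in (0, 1)
raw_l = torch.tensor(0.2)                  # raw log-length coefficient
grid_idx = torch.tensor(5.0)               # index of the responsible grid cell

center = (torch.sigmoid(raw_c) + grid_idx) * stride   # center on the original scale
length = prior * torch.exp(raw_l) * stride            # length on the original scale
print(center.item(), length.item())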