def get_model_config(self, config):
    if config.model_type == 'davenet':
        self.audio_model = Davenet(input_dim=self.input_size,
                                   embedding_dim=1024)
    elif config.model_type == 'blstm':
        self.audio_model = BLSTM(512,
                                 input_size=self.input_size,
                                 n_layers=config.num_layers)
    self.image_model = nn.Linear(2048, 1024)
    self.attention_model = DotProductClassAttender(
        input_dim=1024, hidden_dim=1024, n_class=self.n_visual_class)
    if config.mode in ['test', 'align']:
        self.load_checkpoint()
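
Note that, unlike Example #3 further down, this dispatch silently leaves self.audio_model unset for unrecognized model types. A minimal defensive variant, sketched as a standalone helper (the helper name and argument list are illustrative; Davenet and BLSTM are the repo's own classes), might look like:

def build_audio_model(config, input_size):
    # Hypothetical helper mirroring the dispatch above; it only adds a fail-fast branch.
    if config.model_type == 'davenet':
        return Davenet(input_dim=input_size, embedding_dim=1024)
    elif config.model_type == 'blstm':
        return BLSTM(512, input_size=input_size, n_layers=config.num_layers)
    else:
        raise ValueError(f'Invalid model type {config.model_type}')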
Example #2
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    dataset = pickle2dict(input_dir + "features_glove.pkl")
    embeddings = pickle2dict(input_dir + "embeddings_glove.pkl")
    dataset["embeddings"] = embeddings

    emb_np = np.asarray(embeddings, dtype=np.float32)
    emb = torch.from_numpy(emb_np).to(device)

    blstm_model = BLSTM(embeddings=emb,
                        input_dim=embsize,
                        hidden_dim=hidden_size,
                        num_layers=n_layers,
                        output_dim=2,
                        max_len=max_len,
                        dropout=dropout)

    blstm_model = blstm_model.to(device)

    optimizer = optim.SGD(blstm_model.parameters(),
                          lr=l_rate,
                          weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    training_set = dataset["training"]
    training_set = YDataset(training_set["xIndexes"],
                            training_set["yLabels"],
                            to_pad=True,
                            max_len=max_len)

    best_acc_test, best_acc_valid = -np.inf, -np.inf
    batches_per_epoch = int(len(training_set) / batch_size)

    for epoch in range(epochs):
        print("Epoch:{}".format(epoch))
        for n_batch in range(batches_per_epoch):
            training_batch = training_set.next_batch(batch_size)
            train(blstm_model, training_batch, optimizer, criterion)
        acc_val = test(blstm_model, dataset, data_part="validation")
        acc_train = test(blstm_model, dataset, data_part="training")
        training_accuracy.append(acc_train)
        validation_accuracy.append(acc_val)
        print("The Training set prediction accuracy is {}".format(acc_train))
        print("The validation set prediction accuracy is {}".format(acc_val))
        print(" ")
Example #3
    def get_model_config(self, config):
        if config.model_type == 'blstm':
            self.audio_net = cuda(
                BLSTM(self.K,
                      n_layers=self.n_layers,
                      n_class=self.n_visual_class,
                      input_size=80,
                      ds_ratio=1,
                      bidirectional=True), self.cuda)

        elif config.model_type == 'mlp':
            self.audio_net = cuda(
                MLP(self.K,
                    n_layers=self.n_layers,
                    n_class=self.n_visual_class,
                    input_size=self.input_size,
                    max_seq_len=self.max_segment_num), self.cuda)
        else:
            raise ValueError(f'Invalid model type {config.model_type}')
Example #4
def __init__(self, idim, hdim, K, n_layers, dropout, lamb):
    super(Model, self).__init__()
    self.net = BLSTM(idim, hdim, n_layers, dropout=dropout)
    self.linear = nn.Linear(hdim * 2, K)
    self.loss_fn = CTC_CRF_LOSS(lamb=lamb)
Example #5
    def __init__(self, config):
        self.config = config

        self.cuda = torch.cuda.is_available()
        self.beta = 1.  # XXX
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.n_layers = config.get('num_layers', 3)
        self.eps = 1e-9
        if config.audio_feature == 'mfcc':
            self.audio_feature_net = None
            self.input_size = 80
            self.hop_len_ms = 10
        elif config.audio_feature == 'wav2vec2':
            self.audio_feature_net = cuda(
                fairseq.checkpoint_utils.load_model_ensemble_and_task(
                    [config.wav2vec_path])[0][0], self.cuda)
            for p in self.audio_feature_net.parameters():
                p.requires_grad = False
            self.input_size = 512
            self.hop_len_ms = 20
        elif config.audio_feature == 'cpc':
            self.audio_feature_net = None
            self.input_size = 256
            self.hop_len_ms = 10
        else:
            raise ValueError(
                f"Feature type {config.audio_feature} not supported")

        self.K = config.K
        self.global_iter = 0
        self.global_epoch = 0
        self.audio_feature = config.audio_feature
        self.image_feature = config.image_feature
        self.debug = config.debug
        self.dataset = config.dataset
        self.max_normalize = config.get('max_normalize', False)
        self.loss_type = config.get('loss_type', 'macro_token_floss')
        self.beta_f_measure = config.get('beta_f_measure', 0.3)
        self.weight_word_loss = config.get('weight_word_loss', 1.0)
        self.weight_phone_loss = config.get('weight_phone_loss', 0.0)
        self.ckpt_dir = Path(config.ckpt_dir)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)

        if self.loss_type == 'macro_token_floss':
            self.criterion = MacroTokenFLoss(beta=self.beta_f_measure)
        elif self.loss_type == 'binary_cross_entropy':
            self.criterion = nn.BCELoss()
        else:
            raise ValueError(f'Invalid loss type {self.loss_type}')

        # Dataset
        self.data_loader = return_data(config)
        self.ignore_index = config.get('ignore_index', -100)
        self.n_visual_class = self.data_loader['train']\
                              .dataset.preprocessor.num_visual_words
        self.n_phone_class = self.data_loader[
            'train'].dataset.preprocessor.num_tokens
        self.visual_words = self.data_loader[
            'train'].dataset.preprocessor.visual_words
        self.phone_set = self.data_loader['train'].dataset.preprocessor.tokens
        self.max_feat_len = self.data_loader['train'].dataset.max_feat_len
        self.max_word_len = self.data_loader['train'].dataset.max_word_len
        print(f'Number of visual label classes = {self.n_visual_class}')
        print(f'Number of phone classes = {self.n_phone_class}')
        print(f'Max normalized: {self.max_normalize}')

        self.audio_net = cuda(
            BLSTM(self.K,
                  n_layers=self.n_layers,
                  n_class=self.n_phone_class,
                  input_size=self.input_size,
                  ds_ratio=1,
                  bidirectional=True), self.cuda)

        self.phone_net = cuda(
            HMMPronunciator(self.visual_words,
                            self.phone_set,
                            config=config,
                            ignore_index=self.ignore_index), self.cuda)
        self.phone_net.train_model()
        self.align_net = cuda(LinearPositionAligner(scale=0.),
                              self.cuda)  # XXX

        trainables = [p for p in self.audio_net.parameters()]
        optim_type = config.get('optim', 'adam')
        if optim_type == 'sgd':
            self.optim = optim.SGD(trainables, lr=self.lr)
        else:
            self.optim = optim.Adam(trainables, lr=self.lr, betas=(0.5, 0.999))
        self.scheduler = lr_scheduler.ExponentialLR(self.optim, gamma=0.97)
        self.load_ckpt = config.load_ckpt
        if self.load_ckpt or config.mode in ['test', 'cluster']:
            self.load_checkpoint()

        # History
        self.history = dict()
        self.history['token_f1'] = 0.
        self.history['visual_token_f1'] = 0.
        self.history['loss'] = 0.
        self.history['epoch'] = 0
        self.history['iter'] = 0
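
The constructor above selects MacroTokenFLoss(beta=self.beta_f_measure) as its default criterion, whose implementation is not shown in this listing. As a rough illustration of what a soft token F-beta loss typically computes, one plausible sketch is given below; it is an assumption, not the repo's MacroTokenFLoss, and it omits the third mask-like argument seen in the later train() call:

import torch
import torch.nn as nn

class SoftFBetaLoss(nn.Module):
    # Hypothetical stand-in: 1 - F_beta computed from soft true-positive counts.
    def __init__(self, beta=0.3, eps=1e-9):
        super().__init__()
        self.beta2 = beta ** 2
        self.eps = eps

    def forward(self, probs, targets):
        tp = (probs * targets).sum(-1)
        precision = tp / (probs.sum(-1) + self.eps)
        recall = tp / (targets.sum(-1) + self.eps)
        f_beta = ((1 + self.beta2) * precision * recall
                  / (self.beta2 * precision + recall + self.eps))
        return 1. - f_beta.mean()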
Example #6
def __init__(self, idim, hdim, K, n_layers, dropout):
    super(Model, self).__init__()
    self.net = BLSTM(idim, hdim, n_layers, dropout)
    self.linear = nn.Linear(hdim * 2, K)
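
For readers without access to the repo's BLSTM class, the wiring in Examples #4 and #6 (a bidirectional recurrent encoder followed by an hdim * 2 -> K projection) can be reproduced with plain PyTorch. The sketch below is an assumption about what BLSTM computes, not the repo's implementation:

import torch
import torch.nn as nn

class ToyBLSTMTagger(nn.Module):
    # Hypothetical stand-in for the repo's BLSTM encoder plus linear head.
    def __init__(self, idim=80, hdim=256, K=100, n_layers=2, dropout=0.1):
        super().__init__()
        self.net = nn.LSTM(idim, hdim, n_layers, batch_first=True,
                           bidirectional=True, dropout=dropout)
        self.linear = nn.Linear(hdim * 2, K)   # hdim * 2 because of bidirectionality

    def forward(self, x):                      # x: (batch, frames, idim)
        hidden, _ = self.net(x)                # hidden: (batch, frames, 2 * hdim)
        return self.linear(hidden)             # per-frame scores: (batch, frames, K)

scores = ToyBLSTMTagger()(torch.randn(4, 50, 80))
print(scores.shape)                            # torch.Size([4, 50, 100])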
Example #7
  def __init__(self, config):
    self.config = config

    self.cuda = torch.cuda.is_available()
    self.epoch = config.epoch
    self.batch_size = config.batch_size
    self.beta = config.beta
    self.lr = config.lr
    self.n_layers = config.get('num_layers', 1)
    self.weight_phone_loss = config.get('weight_phone_loss', 1.)
    self.weight_word_loss = config.get('weight_word_loss', 1.)
    self.anneal_rate = config.get('anneal_rate', 3e-6)
    self.num_sample = config.get('num_sample', 1)
    self.eps = 1e-9
    self.max_grad_norm = config.get('max_grad_norm', None)
    if config.audio_feature == 'mfcc':
      self.audio_feature_net = None
      self.input_size = 80
      self.hop_len_ms = 10

    elif config.audio_feature == 'wav2vec2':
      self.audio_feature_net = cuda(fairseq.checkpoint_utils.load_model_ensemble_and_task([config.wav2vec_path])[0][0],
                                    self.cuda)
      for p in self.audio_feature_net.parameters():
        p.requires_grad = False
      self.input_size = 512
      self.hop_len_ms = 20 
    else:
      raise ValueError(f"Feature type {config.audio_feature} not supported")
   
    self.K = config.K
    self.global_iter = 0
    self.global_epoch = 0
    self.audio_feature = config.audio_feature
    self.image_feature = config.image_feature
    self.debug = config.debug
    self.dataset = config.dataset

    # Dataset
    self.data_loader = return_data(config)
    self.n_visual_class = self.data_loader['train']\
                          .dataset.preprocessor.num_visual_words
    self.n_phone_class = self.data_loader['train']\
                         .dataset.preprocessor.num_tokens
    self.visual_words = self.data_loader['train']\
                        .dataset.preprocessor.visual_words 
    print(f'Number of visual label classes = {self.n_visual_class}')
    print(f'Number of phone classes = {self.n_phone_class}')
  
    self.model_type = config.model_type 
    if config.model_type == 'gumbel_blstm':
      self.audio_net = cuda(GumbelBLSTM(
                              self.K,
                              input_size=self.input_size,
                              n_layers=self.n_layers,
                              n_class=self.n_visual_class,
                              n_gumbel_units=self.n_phone_class,
                              ds_ratio=1,
                              bidirectional=True), self.cuda)
      self.K = 2 * self.K
    elif config.model_type == 'blstm':
      self.audio_net = cuda(BLSTM(
        self.K,
        input_size=self.input_size,
        n_layers=self.n_layers,
        n_class=self.n_visual_class+self.n_phone_class,
        bidirectional=True), self.cuda)
      self.K = 2 * self.K
    elif config.model_type == 'mlp':
      self.audio_net = cuda(GumbelMLP(
                                self.K,
                                input_size=self.input_size,
                                n_class=self.n_visual_class,
                                n_gumbel_units=self.n_phone_class,
                            ), self.cuda)
    elif config.model_type == 'tds':
      self.audio_net = cuda(GumbelTDS(
                              input_size=self.input_size,
                              n_class=self.n_visual_class,
                              n_gumbel_units=self.n_phone_class,
                            ), self.cuda)
    elif config.model_type == 'vq-mlp':
      self.audio_net = cuda(VQMLP(
                              self.K,
                              input_size=self.input_size,
                              n_class=self.n_visual_class,
                              n_embeddings=self.n_phone_class
                            ), self.cuda) 
  
    trainables = [p for p in self.audio_net.parameters()]
    optim_type = config.get('optim', 'adam')
    if optim_type == 'sgd':
      self.optim = optim.SGD(trainables, lr=self.lr)
    else:
      self.optim = optim.Adam(trainables,
                              lr=self.lr, betas=(0.5,0.999))
    self.scheduler = lr_scheduler.ExponentialLR(self.optim, gamma=0.97)
    self.ckpt_dir = Path(config.ckpt_dir)
    if not self.ckpt_dir.exists(): 
      self.ckpt_dir.mkdir(parents=True, exist_ok=True)
    self.load_ckpt = config.load_ckpt
    if self.load_ckpt or config.mode in ['test', 'cluster']: 
      self.load_checkpoint()
    
    # History
    self.history = dict()
    self.history['acc']=0. 
    self.history['token_f1']=0.
    self.history['loss']=0.
    self.history['epoch']=0
    self.history['iter']=0
    self.history['temp']=1.
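
Example #7 stores anneal_rate, num_sample, and a starting temperature (history['temp'] = 1.), but the annealing itself happens in the training loop, which is not part of this snippet. A generic exponential-decay step for Gumbel-Softmax sampling, written with plain PyTorch and assumed variable names, looks like:

import math
import torch
import torch.nn.functional as F

def anneal_temperature(temp, anneal_rate, global_iter, min_temp=0.5):
    # Hypothetical schedule: exponential decay of the Gumbel-Softmax temperature.
    return max(temp * math.exp(-anneal_rate * global_iter), min_temp)

logits = torch.randn(8, 49)                      # e.g. per-frame phone logits (assumed shape)
temp = anneal_temperature(1.0, 3e-6, global_iter=10000)
soft_codes = F.gumbel_softmax(logits, tau=temp)  # differentiable, approximately one-hot samples
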
class Solver(object):
    def __init__(self, config):
        self.config = config
        self.debug = config.debug
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.global_epoch = 0
        self.get_feature_config(config)
        self.get_dataset_config(config)
        self.get_model_config(config)
        self.global_epoch = 0
        self.best_threshold = None
        self.history = dict()
        self.history['f1'] = 0.
        self.history['loss'] = 0.
        self.history['epoch'] = 0.
        if not os.path.exists(config.exp_dir):
            os.makedirs(config.exp_dir)

    def get_feature_config(self, config):
        self.audio_feature = config.audio_feature
        if config.audio_feature == 'mfcc':
            self.audio_feature_net = None
            self.input_size = 80
            self.hop_len_ms = 10
        elif config.audio_feature == 'wav2vec2':
            self.audio_feature_net = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [config.wav2vec_path])[0][0]
            self.audio_feature_net = self.audio_feature_net.to(self.device)

            for p in self.audio_feature_net.parameters():
                p.requires_grad = False
            self.input_size = 512
            self.hop_len_ms = 20
        elif config.audio_feature == 'cpc':
            self.audio_feature_net = None
            self.input_size = 256
            self.hop_len_ms = 10
        else:
            raise ValueError(
                f"Feature type {config.audio_feature} not supported")

    def get_dataset_config(self, config):
        self.data_loader = return_data(config)
        self.ignore_index = config.get('ignore_index', -100)

        self.n_visual_class = self.data_loader['train']\
                              .dataset.preprocessor.num_visual_words
        self.n_phone_class = self.data_loader[
            'train'].dataset.preprocessor.num_tokens
        self.visual_words = self.data_loader[
            'train'].dataset.preprocessor.visual_words
        self.phone_set = self.data_loader['train'].dataset.preprocessor.tokens
        self.max_feat_len = self.data_loader['train'].dataset.max_feat_len
        self.max_word_len = self.data_loader['train'].dataset.max_word_len
        self.max_normalize = config.get('max_normalize', False)

        print(f'Number of visual label classes = {self.n_visual_class}')
        print(f'Number of phone classes = {self.n_phone_class}')
        print(f'Max normalized: {self.max_normalize}')

    def get_model_config(self, config):
        if config.model_type == 'davenet':
            self.audio_model = Davenet(input_dim=self.input_size,
                                       embedding_dim=1024)
        elif config.model_type == 'blstm':
            self.audio_model = BLSTM(512,
                                     input_size=self.input_size,
                                     n_layers=config.num_layers)
        self.image_model = nn.Linear(2048, 1024)
        self.attention_model = DotProductClassAttender(
            input_dim=1024, hidden_dim=1024, n_class=self.n_visual_class)
        if config.mode in ['test', 'align']:
            self.load_checkpoint()

    def train(self):
        device = self.device
        torch.set_grad_enabled(True)

        args = self.config
        audio_model = self.audio_model
        image_model = self.image_model
        attention_model = self.attention_model
        train_loader = self.data_loader['train']
        n_visual_class = self.n_visual_class

        # Initialize all of the statistics we want to keep track of
        batch_time = AverageMeter()
        data_time = AverageMeter()
        loss_meter = AverageMeter()
        progress = []
        best_epoch, best_acc = 0, -np.inf
        global_step, epoch = 0, 0
        start_time = time.time()
        exp_dir = args.exp_dir

        def _save_progress():
            progress.append([
                epoch, global_step, best_epoch, best_acc,
                time.time() - start_time
            ])
            with open("%s/progress.pkl" % exp_dir, "wb") as f:
                pickle.dump(progress, f)

        # create/load exp
        if args.resume:
            progress_pkl = "%s/progress.pkl" % exp_dir
            progress, epoch, global_step, best_epoch, best_acc = load_progress(
                progress_pkl)
            print("\nResume training from:")
            print("  epoch = %s" % epoch)
            print("  global_step = %s" % global_step)
            print("  best_epoch = %s" % best_epoch)
            print("  best_acc = %.4f" % best_acc)

        if not isinstance(audio_model, torch.nn.DataParallel):
            audio_model = nn.DataParallel(audio_model)

        if not isinstance(image_model, torch.nn.DataParallel):
            image_model = nn.DataParallel(image_model)

        if not isinstance(attention_model, torch.nn.DataParallel):
            attention_model = nn.DataParallel(attention_model)

        if epoch != 0:
            audio_model.load_state_dict(
                torch.load("%s/audio_model.pth" % (exp_dir)))
            image_model.load_state_dict(
                torch.load("%s/image_model.pth" % (exp_dir)))
            print("loaded parameters from epoch %d" % epoch)

        audio_model = audio_model.to(device)
        image_model = image_model.to(device)
        attention_model = attention_model.to(device)
        if args.loss_type == 'macro_token_floss':
            criterion = MacroTokenFLoss()
        elif args.loss_type == 'micro_token_floss':
            criterion = MicroTokenFLoss()
        elif args.loss_type == 'binary_cross_entropy':
            criterion = nn.BCEWithLogitsLoss()
        else:
            raise ValueError(f'Invalid loss type: {args.loss_type}')

        # Set up the optimizer
        audio_trainables = [
            p for p in audio_model.parameters() if p.requires_grad
        ]
        image_trainables = [
            p for p in image_model.parameters() if p.requires_grad
        ]
        attention_trainables = [
            p for p in attention_model.parameters() if p.requires_grad
        ]
        trainables = audio_trainables + image_trainables + attention_trainables
        if args.optim == 'sgd':
            optimizer = torch.optim.SGD(trainables,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        elif args.optim == 'adam':
            optimizer = torch.optim.Adam(trainables,
                                         args.lr,
                                         weight_decay=args.weight_decay,
                                         betas=(0.95, 0.999))
        else:
            raise ValueError('Optimizer %s is not supported' % args.optim)

        if epoch != 0:
            optimizer.load_state_dict(
                torch.load("%s/optim_state.pth" % (exp_dir, epoch)))
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
            print("loaded state dict from epoch %d" % epoch)

        epoch += 1

        print("current #steps=%s, #epochs=%s" % (global_step, epoch))
        print("start training...")

        audio_model.train()
        image_model.train()
        attention_model.train()
        while epoch < args.epoch:
            self.global_epoch += 1
            adjust_learning_rate(args.lr, args.lr_decay, optimizer, epoch)
            end_time = time.time()
            audio_model.train()
            image_model.train()
            for i, batch in enumerate(train_loader):
                if self.debug and i > 2:
                    break
                audio_input = batch[0]
                word_label = batch[2]
                input_mask = batch[3]
                word_mask = batch[5]

                # measure data loading time
                data_time.update(time.time() - end_time)
                B = audio_input.size(0)
                audio_input = audio_input.to(device)
                if self.audio_feature == 'wav2vec2':
                    audio_input = self.audio_feature_net.feature_extractor(
                        audio_input)

                word_label = word_label.to(device)
                input_mask = input_mask.to(device)
                word_mask = word_mask.to(device)
                nframes = input_mask.sum(-1)
                word_mask = torch.where(
                    word_mask.sum(dim=(-2, -1)) > 0,
                    torch.tensor(1, device=device),
                    torch.tensor(0, device=device))
                nwords = word_mask.sum(-1)

                # (batch size, n word class)
                word_label_onehot = (F.one_hot(word_label, n_visual_class) *
                                     word_mask.unsqueeze(-1)).sum(-2)
                word_label_onehot = torch.where(word_label_onehot > 0,
                                                torch.tensor(1, device=device),
                                                torch.tensor(0, device=device))
                optimizer.zero_grad()

                audio_output = audio_model(audio_input, masks=input_mask)
                pooling_ratio = round(
                    audio_input.size(-1) / audio_output.size(-2))
                nframes = nframes // pooling_ratio
                input_mask_ds = input_mask[:, ::pooling_ratio]
                word_logit, attn_weights = attention_model(
                    audio_output, input_mask_ds)

                if args.loss_type == 'binary_cross_entropy':
                    loss = criterion(word_logit, word_label_onehot.float())
                else:
                    word_prob = torch.sigmoid(word_logit)
                    if self.max_normalize:
                        word_prob = word_prob / word_prob.max(-1,
                                                              keepdim=True)[0]
                    loss = criterion(word_prob, word_label_onehot,
                                     torch.ones(B, device=device))
                loss.backward()
                optimizer.step()

                # record loss
                loss_meter.update(loss.item(), B)
                batch_time.update(time.time() - end_time)
                global_step += 1
                if i % 500 == 0:
                    info = 'Itr {} {loss_meter.val:.4f} ({loss_meter.avg:.4f})'.format(
                        i, loss_meter=loss_meter)
                    print(info)
                i += 1
            info = ('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss total {loss_meter.val:.4f} ({loss_meter.avg:.4f})'
                    ).format(epoch,
                             i,
                             len(train_loader),
                             batch_time=batch_time,
                             data_time=data_time,
                             loss_meter=loss_meter)
            print(info)

            end_time = time.time()

            if np.isnan(loss_meter.avg):
                print("training diverged...")
                return
            if epoch % 1 == 0:
                precision, recall, f1 = self.validate()
                self.align_finetune()
                avg_acc = f1

                torch.save(audio_model.state_dict(),
                           "%s/audio_model.pth" % (exp_dir))
                torch.save(image_model.state_dict(),
                           "%s/image_model.pth" % (exp_dir))
                torch.save(attention_model.state_dict(),
                           "%s/attention_model.pth" % (exp_dir))
                torch.save(optimizer.state_dict(),
                           "%s/optim_state.pth" % (exp_dir))

                info = f' Epoch: [{epoch}] Loss: {loss_meter.val:.4f}  Token Precision: {precision:.4f} Recall: {recall:.4f}  F1: {f1:.4f}\n'
                save_path = os.path.join(exp_dir, 'result_file.txt')
                with open(save_path, "a") as file:
                    file.write(info)

                if avg_acc > best_acc:
                    self.history['f1'] = f1
                    self.history['loss'] = loss_meter.avg
                    self.history['epoch'] = self.global_epoch
                    best_epoch = epoch
                    best_acc = avg_acc
                    shutil.copyfile("%s/audio_model.pth" % (exp_dir),
                                    "%s/best_audio_model.pth" % (exp_dir))
                    shutil.copyfile("%s/attention_model.pth" % (exp_dir),
                                    "%s/best_attention_model.pth" % (exp_dir))
                    shutil.copyfile("%s/image_model.pth" % (exp_dir),
                                    "%s/best_image_model.pth" % (exp_dir))
                _save_progress()
            epoch += 1

    def validate(self):
        device = self.device
        args = self.config
        audio_model = self.audio_model
        image_model = self.image_model
        attention_model = self.attention_model
        val_loader = self.data_loader['test']
        n_visual_class = self.n_visual_class
        epoch = self.global_epoch
        batch_time = AverageMeter()

        if not isinstance(audio_model, torch.nn.DataParallel):
            audio_model = nn.DataParallel(audio_model)
        if not isinstance(image_model, torch.nn.DataParallel):
            image_model = nn.DataParallel(image_model)
        if not isinstance(attention_model, torch.nn.DataParallel):
            attention_model = nn.DataParallel(attention_model)

        audio_model = audio_model.to(device)
        image_model = image_model.to(device)
        attention_model = attention_model.to(device)

        # switch to evaluate mode
        image_model.eval()
        audio_model.eval()
        attention_model.eval()

        end = time.time()
        N_examples = val_loader.dataset.__len__()
        gold_labels = []
        pred_labels = []
        readable_f = open(
            os.path.join(args.exp_dir, f'keyword_predictions_{epoch}.txt'),
            'w')
        readable_f.write('ID\tGold\tPred\n')
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if self.debug and i > 2:
                    break
                audio_input = batch[0]
                word_label = batch[2]
                input_mask = batch[3]
                word_mask = batch[5]
                B = audio_input.size(0)
                audio_input = audio_input.to(device)
                if self.audio_feature == 'wav2vec2':
                    audio_input = self.audio_feature_net.feature_extractor(
                        audio_input)

                word_label = word_label.to(device)
                input_mask = input_mask.to(device)
                word_mask = word_mask.to(device)
                nframes = input_mask.sum(-1)
                word_mask = torch.where(
                    word_mask.sum(dim=(-2, -1)) > 0,
                    torch.tensor(1, device=device),
                    torch.tensor(0, device=device))
                nwords = word_mask.sum(-1)

                # (batch size, n word class)
                word_label_onehot = (F.one_hot(word_label, n_visual_class) *
                                     word_mask.unsqueeze(-1)).sum(-2)
                word_label_onehot = torch.where(word_label_onehot > 0,
                                                torch.tensor(1, device=device),
                                                torch.tensor(0, device=device))

                # compute output
                audio_output = audio_model(audio_input, masks=input_mask)
                pooling_ratio = round(
                    audio_input.size(-1) / audio_output.size(-2))
                input_mask_ds = input_mask[:, ::pooling_ratio]
                word_logit, attn_weights = attention_model(
                    audio_output, input_mask_ds)
                pred_label_onehot = (word_logit > 0).long()
                gold_labels.append(
                    word_label_onehot.flatten().detach().cpu().numpy())
                pred_labels.append(
                    pred_label_onehot.flatten().detach().cpu().numpy())

                batch_time.update(time.time() - end)
                end = time.time()

                for ex in range(B):
                    global_idx = i * val_loader.batch_size + ex
                    audio_id = os.path.splitext(
                        os.path.split(
                            val_loader.dataset.dataset[global_idx][0])[1])[0]
                    pred_idxs = torch.nonzero(
                        pred_label_onehot[ex],
                        as_tuple=True)[0].detach().cpu().numpy().tolist()
                    gold_idxs = torch.nonzero(
                        word_label_onehot[ex],
                        as_tuple=True)[0].detach().cpu().numpy().tolist()

                    pred_word_names = '|'.join(
                        val_loader.dataset.preprocessor.to_word_text(
                            pred_idxs))
                    gold_word_names = '|'.join(
                        val_loader.dataset.preprocessor.to_word_text(
                            gold_idxs))
                    readable_f.write(
                        f'{audio_id}\t{gold_word_names}\t{pred_word_names}\n')

            gold_labels = np.concatenate(gold_labels)
            pred_labels = np.concatenate(pred_labels)
        readable_f.close()
        macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
            gold_labels, pred_labels, average='macro')
        print(
            f'Macro Precision: {macro_precision: .3f}, Recall: {macro_recall:.3f}, F1: {macro_f1:.3f}'
        )
        precision, recall, f1, _ = precision_recall_fscore_support(
            gold_labels, pred_labels)
        precision = precision[1]
        recall = recall[1]
        f1 = f1[1]
        print(
            f'Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}')
        return precision, recall, f1

    def align_finetune(self):
        """ Fine-tune the localization threshold on validation set """
        device = self.device
        args = self.config
        audio_model = self.audio_model
        image_model = self.image_model
        attention_model = self.attention_model
        val_loader = self.data_loader['test']
        n_visual_class = self.n_visual_class
        epoch = self.global_epoch
        batch_time = AverageMeter()

        if not isinstance(audio_model, torch.nn.DataParallel):
            audio_model = nn.DataParallel(audio_model)
        if not isinstance(image_model, torch.nn.DataParallel):
            image_model = nn.DataParallel(image_model)
        if not isinstance(attention_model, torch.nn.DataParallel):
            attention_model = nn.DataParallel(attention_model)

        audio_model = audio_model.to(device)
        image_model = image_model.to(device)
        attention_model = attention_model.to(device)

        # switch to evaluate mode
        image_model.eval()
        audio_model.eval()
        attention_model.eval()

        end = time.time()
        gold_masks = []
        pred_masks = []
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if self.debug and i > 2:
                    break
                audio_input = batch[0]
                word_label = batch[2]
                input_mask = batch[3]
                word_mask = batch[5]
                B = audio_input.size(0)
                audio_input = audio_input.to(device)
                if self.audio_feature == 'wav2vec2':
                    audio_input = self.audio_feature_net.feature_extractor(
                        audio_input)

                word_label = word_label.to(device)
                input_mask = input_mask.to(device)
                word_mask = word_mask.to(device)
                nframes = input_mask.sum(-1).long()
                word_mask = word_mask.sum(-2)
                word_mask_2d = torch.where(
                    word_mask.sum(-1) > 0, torch.tensor(1, device=device),
                    torch.tensor(0, device=device))
                nwords = word_mask_2d.sum(-1).long()
                # (batch size, n word class)
                word_label_onehot = (F.one_hot(word_label, n_visual_class) *
                                     word_mask_2d.unsqueeze(-1)).sum(-2)
                word_label_onehot = torch.where(word_label_onehot > 0,
                                                torch.tensor(1, device=device),
                                                torch.tensor(0, device=device))

                audio_output = audio_model(audio_input, masks=input_mask)
                pooling_ratio = round(
                    audio_input.size(-1) / audio_output.size(-2))
                input_mask_ds = input_mask[:, ::pooling_ratio]
                word_logit, attn_weights = attention_model(
                    audio_output, input_mask_ds)
                attn_weights = attn_weights.unsqueeze(-1).expand(
                    -1, -1, -1, pooling_ratio).contiguous()
                attn_weights = attn_weights.view(B, self.n_visual_class, -1)

                batch_time.update(time.time() - end)
                end = time.time()

                for ex in range(B):
                    global_idx = i * val_loader.batch_size + ex
                    audio_id = os.path.splitext(
                        os.path.split(
                            val_loader.dataset.dataset[global_idx][0])[1])[0]
                    gold_mask = torch.zeros(
                        (self.n_visual_class, nframes[ex].item()),
                        device=device)
                    for w in range(
                            nwords[ex]
                    ):  # Aggregate masks for the same word class
                        gold_mask[word_label[ex, w]] = gold_mask[word_label[ex, w]]\
                                                       + word_mask[ex, w, :nframes[ex]]
                    gold_mask = (gold_mask > 0).long()

                    for v in range(self.n_visual_class):
                        if word_label_onehot[ex, v]:
                            pred_mask = attn_weights[ex, v, :nframes[ex]]
                            pred_masks.append(
                                pred_mask.detach().cpu().numpy().flatten())
                            gold_masks.append(
                                gold_mask[v].detach().cpu().numpy().flatten())

            pred_masks_1d = np.concatenate(pred_masks)
            gold_masks_1d = np.concatenate(gold_masks)
            precision, recall, thresholds = precision_recall_curve(
                gold_masks_1d, pred_masks_1d)
            EPS = 1e-10
            f1 = 2 * precision * recall / (precision + recall + EPS)
            best_idx = np.argmax(f1)
            self.best_threshold = thresholds[best_idx]
            best_precision = precision[best_idx]
            best_recall = recall[best_idx]
            best_f1 = f1[best_idx]
            print(
                f'Best Localization Precision: {best_precision:.3f}\tRecall: {best_recall:.3f}\tF1: {best_f1:.3f}\tmAP: {np.mean(precision)}'
            )

    def align(self):
        if self.best_threshold is None:
            self.best_threshold = 0.5
        print(f'Best Threshold: {self.best_threshold}')

        device = self.device
        args = self.config
        audio_model = self.audio_model
        image_model = self.image_model
        attention_model = self.attention_model
        train_loader = self.data_loader['train']
        val_loader = self.data_loader['test']
        n_visual_class = self.n_visual_class
        epoch = self.global_epoch
        batch_time = AverageMeter()

        if not isinstance(audio_model, torch.nn.DataParallel):
            audio_model = nn.DataParallel(audio_model)
        if not isinstance(image_model, torch.nn.DataParallel):
            image_model = nn.DataParallel(image_model)
        if not isinstance(attention_model, torch.nn.DataParallel):
            attention_model = nn.DataParallel(attention_model)

        audio_model = audio_model.to(device)
        image_model = image_model.to(device)
        attention_model = attention_model.to(device)

        # switch to evaluate mode
        image_model.eval()
        audio_model.eval()
        attention_model.eval()

        end = time.time()
        gold_masks = []
        pred_masks = []
        with torch.no_grad():
            # TODO Extract alignments for training set
            pred_word_dict = dict()
            for i, batch in enumerate(val_loader):
                if self.debug and i > 2:
                    break
                audio_input = batch[0]
                word_label = batch[2]
                input_mask = batch[3]
                word_mask = batch[5]
                B = audio_input.size(0)
                audio_input = audio_input.to(device)
                if self.audio_feature == 'wav2vec2':
                    audio_input = self.audio_feature_net.feature_extractor(
                        audio_input)

                word_label = word_label.to(device)
                input_mask = input_mask.to(device)
                word_mask = word_mask.to(device)
                nframes = input_mask.sum(-1)
                word_mask = torch.where(
                    word_mask.sum(dim=(-2, -1)) > 0,
                    torch.tensor(1, device=device),
                    torch.tensor(0, device=device))
                nwords = word_mask.sum(-1)
                # (batch size, n word class)
                word_label_onehot = (F.one_hot(word_label, n_visual_class) *
                                     word_mask.unsqueeze(-1)).sum(-2)
                word_label_onehot = torch.where(word_label_onehot > 0,
                                                torch.tensor(1, device=device),
                                                torch.tensor(0, device=device))

                audio_output = audio_model(audio_input, masks=input_mask)
                pooling_ratio = round(
                    audio_input.size(-1) / audio_output.size(-2))
                input_mask_ds = input_mask[:, ::pooling_ratio]
                word_logit, attn_weights = attention_model(
                    audio_output, input_mask_ds)

                batch_time.update(time.time() - end)
                end = time.time()

                for ex in range(B):
                    global_idx = i * val_loader.batch_size + ex
                    audio_id = os.path.splitext(
                        os.path.split(
                            val_loader.dataset.dataset[global_idx][0])[1])[0]
                    pred_word_dict[audio_id] = {'pred': [], 'gold': []}

                    pred_word_dict[audio_id]['gold'] = self.mask_to_interval(
                        word_mask[ex, :nwords[ex], :nframes[ex]],
                        word_label[ex, :nwords[ex]])
                    for v in range(self.n_visual_class):
                        if word_label_onehot[ex, v]:
                            pred_mask = (attn_weights[ex, v, :nframes[ex]] >=
                                         self.best_threshold).long()
                            pred_word_dict[audio_id]['pred'].extend(
                                self.mask_to_interval(
                                    pred_mask.unsqueeze(0),
                                    torch.tensor([v], device=device)))
        json.dump(pred_word_dict,
                  open(os.path.join(self.config.exp_dir, 'pred_words.json'), 'w'),
                  indent=2)

    def mask_to_interval(self, m, y):
        """Convert per-word binary frame masks into begin/end intervals in seconds."""
        intervals = []
        y = y.detach().cpu().numpy().tolist()
        for ex in range(m.size(0)):
            is_inside = False
            begin = 0
            for t, is_mask in enumerate(m[ex]):
                if is_mask and not is_inside:
                    begin = t
                    is_inside = True
                elif not is_mask and is_inside:
                    intervals.append({
                        'begin': begin * self.hop_len_ms / 1000,
                        'end': (t - 1) * self.hop_len_ms / 1000,
                        'text': y[ex]
                    })
                    is_inside = False
            if is_inside:
                intervals.append({
                    'begin': begin * self.hop_len_ms / 1000,
                    'end': t * self.hop_len_ms / 1000,
                    'text': y[ex]
                })
        return intervals
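
    # Worked example for mask_to_interval (hypothetical inputs, assuming hop_len_ms = 10):
    #   m = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1]]), y = torch.tensor([7])
    #   -> [{'begin': 0.02, 'end': 0.04, 'text': 7},
    #       {'begin': 0.06, 'end': 0.07, 'text': 7}]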

    def load_checkpoint(self):
        audio_model_file = os.path.join(self.config.exp_dir,
                                        'best_audio_model.pth')
        image_model_file = os.path.join(self.config.exp_dir,
                                        'best_image_model.pth')
        attention_model_file = os.path.join(self.config.exp_dir,
                                            'best_attention_model.pth')
        self.audio_model.load_state_dict(torch.load(audio_model_file))
        self.image_model.load_state_dict(torch.load(image_model_file))
        self.attention_model.load_state_dict(torch.load(attention_model_file))
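
A typical entry point for a Solver like this one is not included in the listing; a hypothetical driver (parse_config stands in for whatever argument/config parsing the original project uses) could look like:

if __name__ == '__main__':
    config = parse_config()      # assumed helper returning an attribute-style config object
    solver = Solver(config)
    if config.mode == 'train':
        solver.train()
    elif config.mode == 'test':
        solver.validate()
    elif config.mode == 'align':
        solver.align()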