Example #1
def train_epoch(dataloader, model, optimizer):
    model.train()
    # acc_lst, loss_lst = [], []
    stats = collections.defaultdict(list)
    for batch_idx, data in enumerate(dataloader):
        fbank, seq_lens, tokens = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()

        optimizer.zero_grad()
        if args.ngpu <= 1 or args.dist_train:
            loss = model(fbank, seq_lens, tokens).mean()  # / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            loss = (
                data_parallel(model, (fbank, seq_lens, tokens),
                              range(args.ngpu)).mean()  # / self.accum_grad
            )
        if not hasattr(model, "module"):
            if hasattr(model, "acc") and model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if hasattr(model.module, "acc") and model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
        loss.backward()
        clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        stats["loss_lst"].append(loss.item())
        logging.warning(f"Training batch: {batch_idx+1}/{len(dataloader)}")
    return dict_average(stats)
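Each of these loops hands its per-batch statistics to a dict_average helper that is not shown in these examples. Below is a minimal sketch of what it could look like, assuming it averages list-valued entries (e.g. "loss_lst") and passes scalar entries (e.g. "meta_lr") through unchanged; this is an assumption, not the project's actual implementation.

def dict_average(stats):
    # Assumed helper: average list-valued stats, keep scalar stats as-is.
    averaged = {}
    for key, value in stats.items():
        if isinstance(value, list):
            averaged[key] = sum(value) / len(value) if value else float("nan")
        else:
            averaged[key] = value
    return averaged

For example, dict_average({"loss_lst": [1.0, 3.0], "meta_lr": 0.1}) would return {"loss_lst": 2.0, "meta_lr": 0.1}.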
Example #2
def train_uda_epoch(train_loaders, model, optimizer, epoch):
    src_loader, tgt_loader = train_loaders
    iter_source, iter_target = iter(src_loader), iter(tgt_loader)
    model.train()
    stats = collections.defaultdict(list)
    n_batch = min(len(src_loader), len(tgt_loader))
    for batch_idx in range(n_batch):
        src_data = next(iter_source)
        for i in range(len(src_data)):
            src_data[i] = src_data[i].cuda()
        tgt_data = next(iter_target)
        for i in range(len(tgt_data)):
            tgt_data[i] = tgt_data[i].cuda()
        optimizer.zero_grad()
        if args.ngpu <= 1 or args.dist_train:
            ctc_att_loss, uda_loss = model(*src_data, *tgt_data)
        else:
            # apex does not support torch.nn.DataParallel
            ctc_att_loss, uda_loss = (
                data_parallel(model, (*src_data, *tgt_data), range(args.ngpu))
            )
        ctc_att_loss = ctc_att_loss.mean()
        loss = ctc_att_loss
        
        if args.transfer_loss_weight > 0:
            if args.transfer_loss_weight_warmup_steps > 0:
                current_iter = float(batch_idx + (epoch - 1) * n_batch)
                frac_done = current_iter / args.transfer_loss_weight_warmup_steps
                current_weight = args.transfer_loss_weight * min(1.0, frac_done)
                stats["transfer_loss_weight"] = current_weight
            else:
                current_weight = args.transfer_loss_weight
            transfer_loss = uda_loss.mean()
            loss = ctc_att_loss + current_weight * transfer_loss
        if not hasattr(model, "module"):
            if hasattr(model, "acc") and model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if hasattr(model.module, "acc") and model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
        loss.backward()
        clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        stats["ctc_att_loss_lst"].append(ctc_att_loss.item())
        if args.transfer_loss_weight > 0:
            stats["transfer_loss_lst"].append(transfer_loss.item())
        stats["loss_lst"].append(loss.item())
        logging.warning(f"Training batch: {batch_idx+1}/{n_batch}")
    return dict_average(stats)
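The transfer-loss weight in this loop ramps up linearly from zero over args.transfer_loss_weight_warmup_steps steps. The schedule can be read as the small standalone function sketched below; the function name and signature are illustrative and do not appear in the original code.

def linear_warmup_weight(base_weight, step, warmup_steps):
    # Linearly scale base_weight from 0 up to its full value over warmup_steps.
    if warmup_steps <= 0:
        return base_weight
    return base_weight * min(1.0, float(step) / warmup_steps)

With base_weight=1.0 and warmup_steps=1000, step 0 gives 0.0, step 250 gives 0.25, and any step at or beyond 1000 gives the full weight of 1.0.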
Example #3
def test(dataloader, model, model_path=None):
    if model_path:
        torch_load(model_path, model)
    model.eval()
    stats = collections.defaultdict(list)
    for batch_idx, data in enumerate(dataloader):
        logging.warning(f"Testing batch: {batch_idx+1}/{len(dataloader)}")
        fbank, seq_lens, tokens = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
        with torch.no_grad():
            loss = model(fbank, seq_lens, tokens)
        stats["loss_lst"].append(loss.item())
        if not hasattr(model, "module"):
            if model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
    return dict_average(stats)
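Both test functions call torch_load(model_path, model) to restore a checkpoint before evaluation. A rough, assumed equivalent is sketched below, taking the checkpoint to be a state_dict saved with torch.save; the real helper in the codebase may differ.

import torch


def torch_load(path, model):
    # Assumed behaviour: load a state_dict checkpoint and copy it into the
    # (possibly DataParallel-wrapped) model.
    state_dict = torch.load(path, map_location="cpu")
    if hasattr(model, "module"):
        model.module.load_state_dict(state_dict)
    else:
        model.load_state_dict(state_dict)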
Example #4
def test(epoch,
         dataloader,
         model,
         model_path=None,
         language=None,
         visualize_sim_adapter=False):
    if model_path:
        torch_load(model_path, model)
    orig_model = None
    if hasattr(model, "module"):
        orig_model = model
        model = model.module
    model.eval()
    stats = collections.defaultdict(list)
    if visualize_sim_adapter:
        # Accumulate sim-adapter attention statistics over the whole test set
        init_mat = lambda: np.zeros((len(model.fusion_languages), ))
        avg_atts = collections.defaultdict(init_mat)
        count = collections.defaultdict(int)
    for batch_idx, data in enumerate(dataloader):
        logging.warning(f"Testing batch: {batch_idx+1}/{len(dataloader)}")
        if len(data) == 4:
            fbank, seq_lens, tokens, language = data
        else:
            assert language is not None
            fbank, seq_lens, tokens = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
        with torch.no_grad():
            loss = model(fbank, seq_lens, tokens, language)

        if visualize_sim_adapter:
            atts = model.calculate_sim_adapter_attentions(
                fbank, seq_lens, tokens, language)
            for key in atts.keys():
                avg_atts[key] = avg_atts[key] + atts[key].sum(axis=(0, 1))
                count[key] += atts[key].shape[0] * atts[key].shape[1]
        stats["loss_lst"].append(loss.item())
        if not hasattr(model, "module"):
            if model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
    if visualize_sim_adapter:
        for key in avg_atts.keys():
            avg_atts[key] = avg_atts[key] / count[key]
            logging.warning(f"Attention scores of {key}: {avg_atts[key]}")
        fig = plt.figure(figsize=(16, 8))
        ax = fig.subplots()
        atts, labels = [], []
        for key in avg_atts.keys():
            atts.append(avg_atts[key])
            labels.append(key)
        atts = np.stack(atts)
        tick_marks = np.arange(len(labels))
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(labels)
        x_labels = list(sorted(model.fusion_languages))
        ax.set_xticks(np.arange(len(x_labels)))
        ax.set_xticklabels(x_labels)
        ax.imshow(atts)
        import itertools
        for i, j in itertools.product(range(atts.shape[0]),
                                      range(atts.shape[1])):
            plt.text(j,
                     i,
                     "{:0.2f}".format(atts[i, j]),
                     horizontalalignment="center",
                     color="white")
        fig.tight_layout()
        fig.savefig(f"{args.outdir}/att_{epoch}.png")
        plt.close()
    if orig_model is not None:
        model = orig_model
    return dict_average(stats)
Example #5
def train_maml_epoch(dataloader, model, optimizer, epoch=None):
    model.train()
    stats = collections.defaultdict(list)

    for batch_idx, total_batches in enumerate(dataloader):
        i = batch_idx  # current iteration in epoch
        len_dataloader = len(dataloader)  # total iteration in epoch
        meta_iters = args.epochs * len_dataloader
        current_iter = float(i + (epoch - 1) * len_dataloader)
        frac_done = 1.0 * float(current_iter) / meta_iters
        current_outerstepsize = args.meta_lr * (1. - frac_done)

        weights_original = copy.deepcopy(model.state_dict())
        new_weights = []
        for total_batch in total_batches:  # Iter by languages
            in_batch_size = total_batch[0].shape[0] // 2  # In-language batch size
            for meta_step in range(2):  # Meta-train & meta-valid
                if meta_step == 1:
                    last_backup = copy.deepcopy(model.state_dict())
                else:
                    last_backup = None
                batch = list(copy.deepcopy(total_batch))
                # Split every tensor (all fields except the trailing language
                # entry) into the meta-train half (meta_step == 0) or the
                # meta-valid half (meta_step == 1) of the in-language batch.
                for i_batch in range(len(batch) - 1):
                    batch[i_batch] = batch[i_batch][
                        meta_step * in_batch_size:(meta_step + 1) * in_batch_size]
                batch = tuple(batch)

                fbank, seq_lens, tokens, language = batch
                fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
                optimizer.zero_grad()
                model.zero_grad()
                if args.ngpu <= 1 or args.dist_train:
                    loss = model(fbank, seq_lens, tokens,
                                 language).mean()  # / self.accum_grad
                else:
                    # apex does not support torch.nn.DataParallel
                    loss = (
                        data_parallel(
                            model, (fbank, seq_lens, tokens, language),
                            range(args.ngpu)).mean()  # / self.accum_grad
                    )
                # print(loss.item())
                loss.backward()
                grad_norm = clip_grad_norm_(model.parameters(), args.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning("grad norm is nan. Do not update model.")
                else:
                    optimizer.step()

                if meta_step == 1:  # Record meta valid
                    if not hasattr(model, "module"):
                        if hasattr(model, "acc") and model.acc is not None:
                            stats["acc_lst"].append(model.acc)
                            model.acc = None
                    else:
                        if hasattr(model.module, "acc") and model.module.acc is not None:
                            stats["acc_lst"].append(model.module.acc)
                            model.module.acc = None
                    stats["loss_lst"].append(loss.item())
                    stats["meta_lr"] = current_outerstepsize
                    optimizer.zero_grad()

            # Compute the meta-gradient: the parameter delta accumulated by
            # this language's meta-valid inner step.
            for name in last_backup:
                last_backup[name] = model.state_dict()[name] - last_backup[name]
            # updates.append(subtract_vars(self._model_state.export_variables(), last_backup))
            new_weights.append(last_backup)
            # Change back to the original parameters for the next language
            model.load_state_dict(
                {name: weights_original[name] for name in weights_original})

        ws = len(new_weights)
        # Compute average meta-gradient
        fweights = {
            name: new_weights[0][name] / float(ws)
            for name in new_weights[0]
        }
        for i in range(1, ws):
            for name in new_weights[i]:
                fweights[name] = fweights[name] + new_weights[i][name] / float(ws)
        model.load_state_dict({
            name:
            weights_original[name] + (fweights[name] * current_outerstepsize)
            for name in weights_original
        })

        logging.warning(f"Training batch: {batch_idx+1}/{len(dataloader)}")
    return dict_average(stats)
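Stripped of batching and logging, the outer update in train_maml_epoch is a Reptile-style interpolation: each language contributes a parameter delta (here, from its meta-valid inner step), the deltas are averaged, and the original weights are moved by current_outerstepsize along that average. The condensed sketch below shows just that update; the function and argument names are illustrative, not part of the original code.

def reptile_outer_update(model, weights_original, deltas, outer_step_size):
    # Average the per-language parameter deltas and apply a single outer step:
    # w <- w_original + outer_step_size * mean(delta).
    avg_delta = {
        name: sum(d[name] for d in deltas) / float(len(deltas))
        for name in deltas[0]
    }
    model.load_state_dict({
        name: weights_original[name] + outer_step_size * avg_delta[name]
        for name in weights_original
    })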
Example #6
def train_epoch(dataloader, model, optimizer, epoch=None):
    model.train()
    stats = collections.defaultdict(list)
    for batch_idx, data in enumerate(dataloader):
        fbank, seq_lens, tokens, language = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
        if isinstance(optimizer, dict):
            optimizer[language].zero_grad()
        else:
            optimizer.zero_grad()
        model.zero_grad()
        if args.ngpu <= 1 or args.dist_train:
            ctc_att_loss, sim_adapter_guide_loss = model(
                fbank, seq_lens, tokens,
                language)  # .mean() # / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            ctc_att_loss, sim_adapter_guide_loss = (
                data_parallel(model, (fbank, seq_lens, tokens, language),
                              range(args.ngpu))  # .mean() # / self.accum_grad
            )
        loss = ctc_att_loss.mean()
        if args.sim_adapter:
            if hasattr(model, "module"):
                sim_adapter_reg_loss = model.module.get_fusion_regularization_loss()
            else:
                sim_adapter_reg_loss = model.get_fusion_regularization_loss()
            loss = loss + sim_adapter_reg_loss
            stats["sim_adapter_reg_loss_lst"].append(
                sim_adapter_reg_loss.item())
            if args.guide_loss_weight > 0:
                if args.guide_loss_weight_decay_steps > 0:
                    n_batch = len(dataloader)
                    current_iter = float(batch_idx + (epoch - 1) * n_batch)
                    frac_done = current_iter / args.guide_loss_weight_decay_steps
                    current_weight = args.guide_loss_weight * max(0.0, 1.0 - frac_done)
                    stats["sim_adapter_guide_loss_weight"] = current_weight
                else:
                    current_weight = args.guide_loss_weight
                sim_adapter_guide_loss = sim_adapter_guide_loss.mean()
                loss = loss + current_weight * sim_adapter_guide_loss
                stats["sim_adapter_guide_loss_lst"].append(
                    sim_adapter_guide_loss.item())

        if not hasattr(model, "module"):
            if hasattr(model, "acc") and model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if hasattr(model.module, "acc") and model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
        loss.backward()
        grad_norm = clip_grad_norm_(model.parameters(), args.grad_clip)
        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
        else:
            if isinstance(optimizer, dict):
                optimizer[language].step()
            else:
                optimizer.step()
            stats["loss_lst"].append(loss.item())
        logging.warning(f"Training batch: {batch_idx+1}/{len(dataloader)}")
    return dict_average(stats)
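This last loop accepts either a single optimizer or a dict of optimizers keyed by language, stepping only the entry for the language of the current batch. One plausible way to build such a dict is sketched below; the parameter-selection rule (matching the language name inside the parameter name) is a placeholder assumption, not the project's actual grouping logic.

import torch


def build_language_optimizers(model, languages, lr=1e-3):
    # Hypothetical: one Adam optimizer per language, each owning only that
    # language's (e.g. adapter) parameters. The name-matching rule below is a
    # placeholder for whatever grouping the real code uses.
    optimizers = {}
    for lang in languages:
        params = [p for n, p in model.named_parameters() if lang in n]
        if params:
            optimizers[lang] = torch.optim.Adam(params, lr=lr)
    return optimizers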