Example #1
    def __init__(self, d_in, d_outs, gen_func):
        super().__init__()
        # Expect exactly two output sizes: one per generator head.
        aeq(len(d_outs), 2)

        # One generator head for the base task, one for the crosslingual task.
        self.generators = nn.ModuleDict()
        self.generators['base'] = Generator(d_in, d_outs[0], gen_func)
        self.generators['crosslingual'] = Generator(d_in, d_outs[1], gen_func)

        # Register a switch so the active head can be toggled between the two.
        self.add_switch('generator', self.generators, 'crosslingual', 'base')
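
Every example on this page checks sizes with aeq, the small assertion helper from OpenNMT-py's onmt.utils.misc that verifies all of its arguments are equal. A minimal standalone sketch of an equivalent helper (not the library's exact source), in case you want to run the snippets in isolation:

def aeq(*args):
    """Assert that all arguments have the same value."""
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "Not all arguments have the same value: " + str(args)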
Example #2
    def __init__(self, d_in, d_hid, d_out, mode='residual'):
        super().__init__()
        # One hidden projection each for the event, agent and theme inputs.
        self.event_hidden_layer = nn.Linear(d_in, d_hid)
        self.agent_hidden_layer = nn.Linear(d_in, d_hid)
        self.theme_hidden_layer = nn.Linear(d_in, d_hid)
        # Combine the three hidden representations into one output.
        self.hidden = nn.Linear(d_hid * 3, d_out)

        assert mode in ['nonlinear', 'sum', 'residual']
        self.mode = mode
        if self.mode == 'residual':
            # A residual connection adds the input back, so widths must match.
            aeq(d_in, d_out)
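
The aeq(d_in, d_out) check exists because residual mode adds the module input back onto the output, and that elementwise addition is only well-defined when the two widths match. A hypothetical illustration of the constraint (the snippet's real forward method is not shown here, so which tensor receives the residual is an assumption):

import torch

d_in = d_out = 16                  # residual mode forces these to be equal
x = torch.randn(4, d_in)           # module input
y = torch.randn(4, d_out)          # module output before the residual add
z = y + x                          # elementwise add requires d_in == d_out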
Example #3
    def _example_dict_iter(self, line, index):
        line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {self.side: words, "indices": index}
        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update((prefix + str(j), f)
                                for j, f in enumerate(feats))

        return example_dict
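
The update call above turns the list of per-token feature sequences into numbered keys such as src_feat_0, src_feat_1, and so on. A tiny standalone illustration with toy values (the 'src' side and the feature strings are made up):

side = "src"
words = ["the", "cat", "sat"]
feats = [["DT", "NN", "VBD"], ["B", "I", "O"]]   # two toy feature streams

example_dict = {side: words, "indices": 0}
example_dict.update((side + "_feat_" + str(j), f) for j, f in enumerate(feats))
# example_dict now has keys: 'src', 'indices', 'src_feat_0', 'src_feat_1'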
Example #4
    def score(self, h_t, h_s):
        """
        Args:
          h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
          h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

        Returns:
          FloatTensor: raw attention scores (unnormalized) for each src index
            ``(batch, tgt_len, src_len)``
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
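
For the 'dot' branch, the raw scores are just a batched matrix product between the queries and the transposed source states. A standalone sketch with arbitrary toy sizes:

import torch

batch, tgt_len, src_len, dim = 2, 3, 5, 8   # toy sizes
h_t = torch.randn(batch, tgt_len, dim)      # queries
h_s = torch.randn(batch, src_len, dim)      # sources

# (batch, tgt_len, dim) x (batch, dim, src_len) -> (batch, tgt_len, src_len)
scores = torch.bmm(h_t, h_s.transpose(1, 2))
assert scores.shape == (batch, tgt_len, src_len)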
Example #5
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): previous coverage vector ``(batch, src_len)``, or None

        Returns:
          (FloatTensor, FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
          * Coverage vector (updated in the one-step case), or None
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank = memory_bank + self.linear_cover(cover).view_as(
                memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        # if memory_lengths is not None:
        #     mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        #     mask = mask.unsqueeze(1)  # Make it broadcastable.
        #     align.masked_fill_(1 - mask, -float('inf'))

        if self.mask is not None:
            align.data.masked_fill_(self.mask.view(-1, self.mask.size(-1)),
                                    -float('inf'))

        # Softmax to normalize the attention weights over source positions
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)
            if coverage is not None:
                covered = F.softmax(
                    align.view(batch * target_l, source_l) / self.temp, -1)
                covered = covered.view(batch, target_l, source_l).squeeze(1)
                coverage = coverage - covered
            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors, coverage
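
The commented-out block above is the usual way of masking padded source positions before the softmax: setting their scores to -inf gives them zero attention weight. A standalone sketch of that masking plus the weighted average that produces the context vectors (toy shapes, mask built inline instead of with sequence_mask):

import torch
import torch.nn.functional as F

batch, tgt_len, src_len, dim = 2, 3, 5, 8        # toy sizes
align = torch.randn(batch, tgt_len, src_len)     # raw attention scores
memory_bank = torch.randn(batch, src_len, dim)   # source states
memory_lengths = torch.tensor([5, 3])            # true source lengths

# True at padded positions: (batch, src_len).
pad_mask = torch.arange(src_len).unsqueeze(0) >= memory_lengths.unsqueeze(1)
align = align.masked_fill(pad_mask.unsqueeze(1), float('-inf'))

align_vectors = F.softmax(align, dim=-1)         # (batch, tgt_len, src_len)
c = torch.bmm(align_vectors, memory_bank)        # context vectors (batch, tgt_len, dim)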
Example #6
    def _gradient_accumulation(self, true_batches, normalization, total_stats,
                               report_stats):
        if self.accum_count > 1:
            self.optim.zero_grad()

        for k, batch in enumerate(true_batches):
            task = self._get_task(batch)
            if task.category == 'lm':
                target_size = batch.tgt_agent.size(0)
            else:
                target_size = batch.tgt.size(0)
            # Truncated BPTT: reminder, truncation is not compatible with accum > 1
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            if task.name == 'lm' or (task.name == 'crosslingual'
                                     and self.crosslingual == 'lm'):
                src_attr = 'src_event'
            elif task.name == 'crosslingual':
                src_attr = 'src_old' if batch.eat_format == 'combined' else 'src'
            elif batch.eat_format in ['old', 'new']:
                src_attr = 'src'
            else:
                src_attr = 'src_old'
            src_value = getattr(batch, src_attr)
            src, src_lengths = src_value if isinstance(src_value, tuple) \
                else (src_value, None)

            r_stats = report_stats[task.name]
            t_stats = total_stats[task.name]
            if src_lengths is not None:
                r_stats.n_src_words += src_lengths.sum().item()

            tgt_outer = batch.tgt

            bptt = False
            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j:j + trunc_size]

                # 2. F-prop all but generator.
                if self.accum_count == 1:
                    self.optim.zero_grad()

                # 2.5 Prepare for crosslingual mode if needed.
                outputs, attns = self.model(src,
                                            tgt,
                                            src_lengths,
                                            bptt=bptt,
                                            task=task)
                bptt = True

                # 2.9 Get appropriate loss function.
                if task.name == 'crosslingual':
                    loss_func = self.crosslingual_train_loss
                elif task.name == 'lm':
                    loss_func = self.lm_train_loss
                elif task.name == 'aux':
                    loss_func = self.aux_train_loss
                else:
                    loss_func = self.train_loss

                # 3. Compute loss.
                try:

                    def loss_call(batch, outputs):
                        loss, batch_stats = loss_func(
                            batch,
                            outputs,
                            attns,
                            normalization=normalization,
                            shard_size=self.shard_size,
                            trunc_start=j,
                            trunc_size=trunc_size)
                        return loss, batch_stats

                    if task.category == 'lm':
                        agent_preds = outputs['agent']
                        theme_preds = outputs['theme']

                        batch.agent_preds = agent_preds
                        batch.theme_preds = theme_preds
                        batch.tgt_backup = batch.tgt

                        def update_loss_and_stats(tgt_attr, preds, loss,
                                                  batch_stats):
                            batch.tgt = getattr(batch, tgt_attr)
                            this_loss, this_batch_stats = loss_call(
                                batch, preds)
                            if this_loss is not None:
                                raise RuntimeError(
                                    'loss is not properly updated from within this function.'
                                )
                            batch_stats.update(this_batch_stats)

                        loss = None
                        batch_stats = onmt.utils.Statistics()
                        update_loss_and_stats('tgt_agent', agent_preds, loss,
                                              batch_stats)
                        # update_loss_and_stats('tgt_agent_mod', agent_preds, loss, batch_stats)
                        # update_loss_and_stats('tgt_theme', theme_preds, loss, batch_stats)
                        # update_loss_and_stats('tgt_theme_mod', theme_preds, loss, batch_stats)
                    else:
                        loss, batch_stats = loss_call(batch, outputs)

                    if task.name == 'crosslingual' and self.almt_reg_hyper > 0.0:
                        weight = self.model.encoder.embeddings.almt_layers[
                            'mapping'].weight
                        # Orthogonality penalty on the mapping: || W W^T - I ||^2.
                        d1, d2 = weight.shape
                        aeq(d1, d2)
                        eye = torch.eye(d1).to(weight.device)
                        reg_loss = ((weight @ weight.T - eye)**2).sum()
                        reg_loss.backward()

                    if loss is not None:
                        self.optim.backward(loss)

                    t_stats.update(batch_stats)
                    r_stats.update(batch_stats)

                except Exception:
                    traceback.print_exc()
                    logger.info("At step %d, we removed a batch - accum %d",
                                self.optim.training_step, k)
                    raise

                # 4. Update the parameters and statistics.
                if self.accum_count == 1:
                    # Multi GPU gradient gather
                    if self.n_gpu > 1:
                        grads = [
                            p.grad.data for p in self.model.parameters()
                            if p.requires_grad and p.grad is not None
                        ]
                        onmt.utils.distributed.all_reduce_and_rescale_tensors(
                            grads, float(1))
                    self.optim.step()

                # If truncated, don't backprop fully.
                # TO CHECK
                # if dec_state is not None:
                #    dec_state.detach()
                if self.model.decoder.state is not None:
                    self.model.decoder.detach_state()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.accum_count > 1:
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                onmt.utils.distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
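
The accum_count branches above implement the standard gradient-accumulation pattern: zero the gradients once, backpropagate several batches, then take a single optimizer step. A minimal sketch of the same pattern with a placeholder model and loss:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(4, 10) for _ in range(8)]
accum_count = 4

optimizer.zero_grad()
for k, x in enumerate(batches):
    loss = model(x).pow(2).mean()       # placeholder loss
    loss.backward()                     # gradients accumulate across calls
    if (k + 1) % accum_count == 0:
        optimizer.step()                # one update per accum_count batches
        optimizer.zero_grad()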
Example #7
def main(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    aux_vocab = None
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
        if opt.crosslingual:
            aux_vocab = checkpoint['aux_vocab']
    elif opt.crosslingual:
        assert opt.crosslingual in ['old', 'lm']
        vocab = torch.load(opt.data + '.vocab.pt')
        aux_vocab = torch.load(opt.aux_train_data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    def get_fields(vocab):
        if old_style_vocab(vocab):
            return load_old_vocab(vocab,
                                  opt.model_type,
                                  dynamic_dict=opt.copy_attn)
        else:
            return vocab

    fields = get_fields(vocab)
    aux_fields = None
    if opt.crosslingual:
        aux_fields = get_fields(aux_vocab)

    if opt.crosslingual:
        if opt.crosslingual == 'old':
            aeq(len(opt.eat_formats), 3)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainCrosslingualTask, 'crosslingual', opt.eat_formats[2])
            ]
        else:
            aeq(len(opt.eat_formats), 4)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', fields, 'data', EatLMMonoTask, 'lm',
                 opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[2]),
                ('train', aux_fields, 'aux_train_data', EatLMCrosslingualTask,
                 'crosslingual', opt.eat_formats[3])
            ]
        train_iter = build_crosslingual_dataset_iter(fields_info, opt)
    elif len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    else:
        device_id = 0 if nb_gpu == 1 else -1
        # NOTE: only pass train_iter in the custom crosslingual mode.
        train_iter = train_iter if opt.crosslingual else None
        passed_fields = {
            'main': fields,
            'crosslingual': aux_fields
        } if opt.crosslingual else None
        single_main(opt,
                    device_id,
                    train_iter=train_iter,
                    passed_fields=passed_fields)
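
The multi-GPU branch spawns one consumer process per GPU plus a batch producer, wired together with per-device queues. A minimal sketch of that spawn-based producer/consumer layout (the worker and producer bodies are placeholders, not the real run and batch_producer functions):

import torch.multiprocessing

def worker(device_id, q):
    # Placeholder consumer: pull items meant for this device until a sentinel arrives.
    while True:
        item = q.get()
        if item is None:
            break
        # ... train on `item` on GPU `device_id` ...

def producer(queues):
    # Placeholder producer: round-robin toy batches onto the per-device queues.
    for i in range(10):
        queues[i % len(queues)].put(i)
    for q in queues:
        q.put(None)            # sentinel: tell each worker to stop

if __name__ == '__main__':
    mp = torch.multiprocessing.get_context('spawn')
    n_gpu = 2
    queues = [mp.Queue(40) for _ in range(n_gpu)]
    procs = [mp.Process(target=worker, args=(d, queues[d]), daemon=True)
             for d in range(n_gpu)]
    for p in procs:
        p.start()
    prod = mp.Process(target=producer, args=(queues,), daemon=True)
    prod.start()
    for p in procs:
        p.join()
    prod.join()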