Example #1
    def train(self, train_iter, epoch, report_func=None):
        """ Train next epoch.
        Args:
            train_iter: training data iterator
            epoch(int): the epoch number
            report_func(fn): function for logging

        Returns:
            stats (:obj:`onmt.Statistics`): epoch loss statistics
        """
        total_stats = Statistics()
        report_stats = Statistics()
        idx = 0
        true_batchs = []
        accum = 0
        normalization = 0
        try:
            add_on = 0
            if len(train_iter) % self.grad_accum_count > 0:
                add_on += 1
            num_batches = len(train_iter) // self.grad_accum_count + add_on
        except NotImplementedError:
            # Dynamic batching
            num_batches = -1

        for i, batch in enumerate(train_iter):
            cur_dataset = train_iter.get_cur_dataset()
            self.train_loss.cur_dataset = cur_dataset

            true_batchs.append(batch)
            accum += 1
            if self.norm_method == "tokens":
                num_tokens = batch.tgt[1:].data.view(-1) \
                    .ne(self.train_loss.padding_idx).sum()
                normalization += num_tokens
            else:
                normalization += batch.batch_size

            if accum == self.grad_accum_count:
                self._gradient_accumulation(true_batchs, total_stats,
                                            report_stats, normalization)

                if report_func is not None:
                    report_stats = report_func(epoch, idx, num_batches,
                                               self.progress_step,
                                               total_stats.start_time,
                                               self.optim.lr, report_stats)
                    self.progress_step += 1

                true_batchs = []
                accum = 0
                normalization = 0
                idx += 1

        if len(true_batchs) > 0:
            self._gradient_accumulation(true_batchs, total_stats, report_stats,
                                        normalization)
            true_batchs = []

        return total_stats
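
The norm_method == "tokens" branch above divides the accumulated gradients by the number of non-padding target tokens rather than by the number of sentences. A minimal sketch of that count, with a made-up time-major target batch and padding index (both hypothetical, only to show the shape of the computation):

import torch

# hypothetical target batch: (tgt_len, batch_size) of token ids, 1 = padding
padding_idx = 1
tgt = torch.tensor([[2, 2, 2],
                    [5, 7, 1],
                    [6, 1, 1],
                    [3, 1, 1]])

# skip the first time step (BOS), flatten, and count entries that are not
# padding, mirroring batch.tgt[1:].view(-1).ne(padding_idx).sum() in train()
num_tokens = tgt[1:].view(-1).ne(padding_idx).sum().item()
print(num_tokens)  # 4 real tokens -> used as the normalization term
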
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`onmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics()

        for batch in valid_iter:
            cur_dataset = valid_iter.get_cur_dataset()
            self.valid_loss.cur_dataset = cur_dataset

            src = onmt.io.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = onmt.io.make_features(batch, 'tgt')

            # extract indices for all entries in the mini-batch
            idxs = batch.indices.cpu().data.numpy()
            # load image features for this minibatch into a pytorch Variable
            img_feats = torch.from_numpy(self.valid_img_feats[idxs])
            img_feats = torch.autograd.Variable(img_feats, requires_grad=False)
            if next(self.model.parameters()).is_cuda:
                img_feats = img_feats.cuda()
            else:
                img_feats = img_feats.cpu()

            # F-prop through the model.
            if self.multimodal_model_type == 'src+img':
                outputs, outputs_img, attns, _ = self.model(
                    src, tgt, src_lengths, img_feats)
            elif self.multimodal_model_type in ['imgw', 'imge', 'imgd']:
                outputs, attns, _ = self.model(src, tgt, src_lengths,
                                               img_feats)
            else:
                raise Exception("Multimodal model type not yet supported: %s" %
                                (self.multimodal_model_type))

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
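
validate() above moves the image features onto whichever device the model lives on by probing next(self.model.parameters()).is_cuda. With current PyTorch the same idea is usually expressed through the parameter's .device attribute; a minimal sketch with placeholder features and a stand-in model (nothing here is the original class):

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(2048, 500)                                   # stand-in for the NMT model
valid_img_feats = np.random.rand(4, 2048).astype(np.float32)   # fake pooled image features
idxs = np.array([0, 2])                                        # indices of the current mini-batch

# torch.from_numpy shares memory with the numpy array; .to(device) then copies
# it onto whatever device the model parameters are stored on (CPU here).
device = next(model.parameters()).device
img_feats = torch.from_numpy(valid_img_feats[idxs]).to(device)
print(img_feats.shape, img_feats.device)                       # torch.Size([2, 2048]) cpu
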
Example #3
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`onmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics()

        for batch in valid_iter:
            cur_dataset = valid_iter.get_cur_dataset()
            self.valid_loss.cur_dataset = cur_dataset

            src = onmt.io.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = onmt.io.make_features(batch, 'tgt')

            # extract indices for all entries in the mini-batch
            idxs = batch.indices.cpu().data.numpy()
            # load image features for this minibatch into a pytorch Variable
            img_feats = torch.tensor(
                self.valid_img_feats[idxs].transpose((0, 2, 1)),
                dtype=torch.float32)
            img_mask = torch.tensor(self.valid_img_mask[idxs],
                                    dtype=torch.float32)
            img_attr = torch.tensor(self.valid_attr[idxs], dtype=torch.float32)
            if next(self.model.parameters()).is_cuda:
                img_feats = img_feats.cuda()
                img_mask = img_mask.cuda()
                img_attr = img_attr.cuda()
            else:
                img_feats = img_feats.cpu()
                img_mask = img_mask.cpu()
                img_attr = img_attr.cpu()
            # F-prop through the model.
            if 'bank' in self.multimodal_model_type \
                    or 'dcap' in self.multimodal_model_type \
                    or 'imgw' in self.multimodal_model_type:
                outputs, attns, _ = self.model(src,
                                               tgt,
                                               src_lengths,
                                               img_attr=img_attr,
                                               img_feats=img_feats,
                                               img_mask=img_mask)
            else:
                outputs, attns, _ = self.model(src, tgt, src_lengths)

            # Compute loss.
            if 'generator' in self.multimodal_model_type:
                batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns, img_feats=img_feats)
            else:
                batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
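
Example #3 reorders the stored feature array with transpose((0, 2, 1)) before building a float32 tensor, i.e. it swaps the last two axes so each row becomes one image region. The exact stored layout is an assumption here; this sketch only illustrates the axis swap and dtype handling:

import numpy as np
import torch

# hypothetical local image features: (batch, channels, regions)
feats = np.random.rand(2, 2048, 49).astype(np.float32)

# swap the last two axes and build a float32 tensor, as in
# torch.tensor(feats.transpose((0, 2, 1)), dtype=torch.float32) above
img_feats = torch.tensor(feats.transpose((0, 2, 1)), dtype=torch.float32)
print(img_feats.shape)  # torch.Size([2, 49, 2048])
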
Example #4
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`onmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        if self.multimodal_model_type not in MODEL_TYPES:
            stats = Statistics()
        else:
            stats = VIStatistics(self.multimodal_model_type)

        for batch in valid_iter:
            cur_dataset = valid_iter.get_cur_dataset()
            self.valid_loss.cur_dataset = cur_dataset

            src = onmt.io.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = onmt.io.make_features(batch, 'tgt')
            # we also capture the target sequence lengths here
            if self.data_type == 'text':
                _, tgt_lengths = batch.tgt
                padding_token = self.train_loss.padding_idx
            else:
                tgt_lengths = None
                padding_token = None
            # set batch.tgt back to target tokens only (the loss object expects it that way)
            batch.tgt = batch.tgt[0]

            # extract indices for all entries in the mini-batch
            idxs = batch.indices.cpu().data.numpy()
            # load image features for this minibatch into a pytorch Variable
            img_feats = torch.from_numpy(self.valid_img_feats[idxs])
            img_feats = torch.autograd.Variable(img_feats, requires_grad=False)
            if next(self.model.parameters()).is_cuda:
                img_feats = img_feats.cuda()
            else:
                img_feats = img_feats.cpu()

            if self.model_opt.image_loss == 'categorical':
                # load image vectors for this minibatch into a pytorch Variable
                img_vecs = torch.from_numpy(self.valid_img_vecs[idxs])
                img_vecs = torch.autograd.Variable(img_vecs, requires_grad=False)
                if next(self.model.parameters()).is_cuda:
                    img_vecs = img_vecs.cuda()
                else:
                    img_vecs = img_vecs.cpu()
            else:
                img_vecs = None

            # F-prop through the model.
            if self.multimodal_model_type in MODEL_TYPES:
                outputs, attns, _ = self.model(src, tgt, src_lengths, tgt_lengths, img_feats,
                        img_vecs=img_vecs, padding_token=padding_token)
            else:
                raise Exception("Multimodal model type not yet supported: %s"%(
                        self.multimodal_model_type))

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        if isinstance(stats, VIStatistics):
            stats.save_progress(self.optim.lr, self.model_updates, self._epoch, 'valid')

        # Set model back to training mode.
        self.model.train()

        return stats
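
All of the validate() variants in this section share the same skeleton: switch the model to eval mode, accumulate per-batch loss statistics, then switch back to train mode. A stripped-down sketch of that skeleton with placeholder modules (on current PyTorch one would also wrap the loop in torch.no_grad(), which the original code predates):

import torch
import torch.nn as nn

model = nn.Linear(10, 3)                          # stand-in for the full model
criterion = nn.CrossEntropyLoss(reduction='sum')
loader = [(torch.randn(8, 10), torch.randint(0, 3, (8,))) for _ in range(5)]

model.eval()                                      # disable dropout / batch-norm updates
total_loss, total_items = 0.0, 0
with torch.no_grad():                             # skip building the autograd graph
    for src, tgt in loader:
        outputs = model(src)
        total_loss += criterion(outputs, tgt).item()
        total_items += tgt.numel()
model.train()                                     # back to training mode

print(total_loss / total_items)                   # mean validation loss per item
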
Example #5
    def train(self,
              epoch,
              report_func=None,
              batch_override=-1,
              text=None,
              attnMask=False,
              startMask=0,
              endMask=0):
        """ Called for each epoch to train. """
        src_total_stats = Statistics()
        tgt_total_stats = Statistics()
        report_stats = Statistics()

        src_batches = [s for s in self.src_train_iter]
        nBatches = len(src_batches)

        if hasattr(self, 'tgt_model'):
            if self.big_text:
                tgt_batches = [t for t in text]
            else:
                tgt_batches = [t for t in self.tgt_train_iter]

            nBatches = min(len(src_batches), len(tgt_batches))
        if batch_override > 0:
            nBatches = batch_override

        src_batches = src_batches[:nBatches]
        if hasattr(self, 'tgt_model'):
            tgt_batches = tgt_batches[:nBatches]

        for i in range(nBatches):
            # SRC
            batch = src_batches[i]

            src = onmt.io.make_features(batch, 'src', 'audio')
            src_labels = src.squeeze().sum(1)[:, 0:-1:8].data.cpu().numpy()
            #print src.size(), src_labels.shape

            self.src_model.zero_grad()
            outputs, _ = self.src_model(src, None)
            l = [self.src_label] * outputs.size()[0]
            labels = Variable(torch.cuda.FloatTensor(l).view(-1, 1))
            w = np.zeros(src_labels.shape)
            w[src_labels != 0.] = 1.
            if startMask > 0:
                w[:, :startMask] = 0
            weights = torch.cuda.FloatTensor(w)
            #print src_labels.shape, w.shape, weights.size()
            self.criterion.weight = weights.view(-1, 1)[:outputs.size()[0], :]

            #print outputs.size(), labels.size()
            loss = self.criterion(outputs, labels)
            loss.backward()
            #if i % 10 == 0:
            #    print "discriminator", i, self.src_label
            #    print outputs.data[0:5], loss.data[0]

            src_total_stats.update_loss(loss.data[0])
            report_stats.update_loss(loss.data[0])

            # 4. Update the parameters and statistics.
            self.src_optim.step()

            if not hasattr(self, 'tgt_model'):
                continue

            # TGT
            if self.tgt_optim is None:
                continue

            batch = tgt_batches[i]
            _, src_lengths = batch.src

            src = onmt.io.make_features(batch, 'src')
            #src_lengths, src = self.add_noise(src_lengths, src)

            report_stats.n_src_words += src_lengths.sum()

            self.tgt_model.zero_grad()
            try:
                outputs, _ = self.tgt_model(src, src_lengths)
            except Exception:
                print(src_lengths)
                raise

            l = [self.tgt_label] * outputs.size()[0]
            labels = Variable(torch.cuda.FloatTensor(l).view(-1, 1))
            weights = torch.cuda.FloatTensor(src.size()[0],
                                             src.size()[1]).zero_()
            for j in range(len(src_lengths)):
                weights[:src_lengths[j], j] = 1.
            self.criterion.weight = weights.view(-1, 1)[:outputs.size()[0], :]

            #print outputs.size(), labels.size()
            loss = self.criterion(outputs, labels)
            loss.backward()
            #if i % 10 == 0:
            #    print "discriminator", i, self.tgt_label
            #    print outputs.data[0:5], loss.data[0]

            tgt_total_stats.update_loss(loss.data[0])
            report_stats.update_loss(loss.data[0])

            # 4. Update the parameters and statistics.
            self.tgt_optim.step()

            if report_func is not None:
                report_stats = report_func(epoch, i, nBatches,
                                           src_total_stats.start_time,
                                           self.src_optim.lr, report_stats)

        return src_total_stats, tgt_total_stats
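
The per-position weight matrix built in the TGT branch above (weights[:src_lengths[j], j] = 1.) marks which time steps of each column are real tokens. The same mask can be built without the Python loop by comparing a step index against the lengths; a small sketch, assuming the same time-major (steps x batch) layout:

import torch

src_lengths = torch.tensor([5, 3, 4])                 # lengths of three sequences
max_len = int(src_lengths.max())

# (max_len, 1) step indices broadcast against (1, batch) lengths
steps = torch.arange(max_len).unsqueeze(1)
weights = (steps < src_lengths.unsqueeze(0)).float()  # (max_len, batch)
print(weights)
# column j has src_lengths[j] ones at the top and zeros below, matching the
# explicit loop weights[:src_lengths[j], j] = 1. in train() above
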
Example #6
    def train(self, epoch, report_func=None):
        """ Called for each epoch to train. """
        total_stats = Statistics()
        report_stats = Statistics()

        for i, batch in enumerate(self.train_iter):

            _, src_lengths = batch.src

            src = onmt.io.make_features(batch, 'src')

            #src_lengths, src = self.add_noise(src_lengths, src)

            report_stats.n_src_words += src_lengths.sum()

            # compute outputs
            self.model.zero_grad()
            outputs, _ = self.model(src, src_lengths)

            # loss re: true_label, backprop through discrim
            true_l = [self.true_label] * outputs.size()[0]
            labels = Variable(torch.cuda.LongTensor(true_l))

            loss = self.criterion(outputs, labels)
            loss.backward(retain_graph=True)
            if i % 10 == 0:
                print "discriminator:", i, outputs.data[0:5]

            total_stats.update_loss(loss.data[0])
            report_stats.update_loss(loss.data[0])

            self.discrim_optim.step()

            # loss re: false_label, backprop through generator
            self.model.zero_grad()
            #outputs = self.model(src, src_lengths)
            fake_l = [1 - self.true_label] * outputs.size()[0]
            labels = Variable(torch.cuda.LongTensor(fake_l))

            loss = self.criterion(outputs, labels)
            loss.backward()
            if i % 10 == 0:
                print "generator:", i, outputs.data[0:5]

            total_stats.update_loss(loss.data[0])
            report_stats.update_loss(loss.data[0])

            self.gener_optim.step()

            if report_func is not None:
                report_stats = report_func(epoch, i, len(self.train_iter),
                                           total_stats.start_time,
                                           self.discrim_optim.lr, report_stats)

        return total_stats
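
Example #6 reuses a single forward pass for two backward passes: the first backward (discriminator loss) is called with retain_graph=True so the buffers saved during the forward are not freed, which lets the second backward (generator loss) run through the same outputs. A self-contained sketch of just that mechanic with a toy stack (note that the original also steps an optimizer between the two backward calls, which relied on older optimizer behaviour and would need reworking on current PyTorch):

import torch
import torch.nn as nn

# toy "generator + discriminator" stack, standing in for self.model
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
criterion = nn.CrossEntropyLoss()

src = torch.randn(8, 10)
outputs = model(src)                              # single forward pass

true_labels = torch.ones(8, dtype=torch.long)     # "true" class
fake_labels = torch.zeros(8, dtype=torch.long)    # flipped class for the generator

# retain_graph=True keeps the saved activations alive; without it the second
# backward would fail with "Trying to backward through the graph a second time".
criterion(outputs, true_labels).backward(retain_graph=True)
criterion(outputs, fake_labels).backward()

print(model[0].weight.grad.norm())                # gradients from both losses accumulated
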
Example #7
    def validate(self):
        """ Validate model.

        Returns:
            :obj:`onmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.src_model.eval()
        self.tgt_model.eval()

        tgt_stats = Statistics()
        src_stats = Statistics()
        src_tgt_stats = Statistics()

        for batch in self.src_valid_iter:

            src = onmt.io.make_features(batch, 'src', 'audio')
            src_labels = src.squeeze().sum(1)[:, 0:-1:8].data.cpu().numpy()
            #print src.size(), src_labels.shape

            outputs, _ = self.src_model(src, None)
            l = [self.src_label] * outputs.size()[0]
            labels = Variable(torch.cuda.FloatTensor(l).view(-1, 1))
            w = np.zeros(src_labels.shape)
            w[src_labels != 0.] = 1.
            weights = torch.cuda.FloatTensor(w)
            self.criterion.weight = weights.view(-1, 1)[:outputs.size()[0], :]

            # Compute loss.
            loss = self.criterion(outputs, labels)

            # Update statistics.
            src_stats.update_loss(loss.data[0])

            l = [self.tgt_label] * outputs.size()[0]
            labels = Variable(torch.cuda.FloatTensor(l).view(-1, 1))
            loss = self.criterion(outputs, labels)
            src_tgt_stats.update_loss(loss.data[0])

        for batch in self.tgt_valid_iter:

            _, tgt_lengths = batch.src

            tgt = onmt.io.make_features(batch, 'src')
            #print src.size(), src_labels.shape

            outputs, _ = self.tgt_model(tgt, tgt_lengths)
            l = [self.tgt_label] * outputs.size()[0]
            labels = Variable(torch.cuda.FloatTensor(l).view(-1, 1))
            weights = torch.cuda.FloatTensor(tgt.size()[0],
                                             tgt.size()[1]).zero_()
            for j in range(len(tgt_lengths)):
                weights[:tgt_lengths[j], j] = 1.

            self.criterion.weight = weights.view(-1, 1)[:outputs.size()[0], :]

            # Compute loss.
            loss = self.criterion(outputs, labels)

            # Update statistics.
            tgt_stats.update_loss(loss.data[0])

        # Set model back to training mode.
        self.src_model.train()
        self.tgt_model.train()

        return src_stats, tgt_stats, src_tgt_stats
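
Examples #5 and #7 mask out padded or silent positions by assigning a fresh weight tensor to self.criterion.weight before every call. An alternative sketch of the same effect that avoids mutating the criterion is to keep the per-element losses and apply the mask explicitly (the criterion, shapes, and labels below are placeholders, not the original setup):

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss(reduction='none')          # keep per-element losses

outputs = torch.randn(6, 1, requires_grad=True)             # (steps * batch, 1) logits
labels = torch.zeros(6, 1)                                  # one label value per position
mask = torch.tensor([1., 1., 1., 0., 0., 1.]).view(-1, 1)   # 1 = real frame, 0 = padding

loss = (criterion(outputs, labels) * mask).sum() / mask.sum()
loss.backward()  # masked positions contribute nothing to the gradient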