Example #1
import torch

# masked_CE and seq_max_and_mask are project helpers (a sketch follows this example)


def validate_helper(sample, model, max_c_len):

    image, source, target, caption_len, pos = sample

    # Compute loss

    result = model(image, source, pos)
    logits, pos_logits = result.logits, result.pos_logits

    word_loss = masked_CE(logits, target, caption_len).item()
    pos_loss = masked_CE(pos_logits, pos, caption_len - 1).item()

    # Compute symbol accuracy

    word_preds = seq_max_and_mask(logits, caption_len, max_c_len + 1)
    pos_preds = seq_max_and_mask(pos_logits, caption_len - 1, max_c_len + 1)

    cor = torch.sum((target == word_preds).float() /
                    caption_len.float().unsqueeze(1)).item()
    pos_cor = torch.sum((pos == pos_preds).float() /
                        (caption_len - 1).float().unsqueeze(1)).item()

    return [word_loss, cor, pos_cor, pos_loss]
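
validate_helper depends on two helpers that are not shown on this page: masked_CE, a cross-entropy averaged only over the valid (unpadded) time steps, and seq_max_and_mask, a greedy argmax whose predictions are zeroed out past each sequence's length. Their exact signatures here are an assumption; the following is a minimal sketch consistent with how they are called above, not the project's actual code:

import torch
import torch.nn.functional as F


def masked_CE(logits, target, lengths, weight=None):
    """Length-masked cross-entropy over a padded batch.

    logits: (B, T, V), target: (B, T), lengths: (B,) valid steps per sequence.
    """
    B, T, V = logits.size()
    mask = (torch.arange(T, device=lengths.device).unsqueeze(0) <
            lengths.unsqueeze(1)).float()                     # (B, T)
    ce = F.cross_entropy(logits.reshape(-1, V), target.reshape(-1),
                         reduction='none').view(B, T)
    if weight is not None:                                    # optional per-example weights
        ce = ce * weight.unsqueeze(1)
    return (ce * mask).sum() / mask.sum()


def seq_max_and_mask(logits, lengths, max_len):
    """Greedy argmax predictions, zeroed (pad index 0) past each sequence length."""
    preds = logits.argmax(dim=2)[:, :max_len]                 # (B, min(T, max_len))
    steps = torch.arange(preds.size(1), device=lengths.device)
    mask = steps.unsqueeze(0) < lengths.unsqueeze(1)
    return preds * mask.long()
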
Example #2
    def train_captioner(self):

        self.captioner.train()

        for epoch in range(self.opt.cap_epochs):
            self.cap_epoch = epoch

            self.update_lr(epoch)
            self.update_ss(epoch)

            print_loss, tic = 0, time()

            print("Training captioner")

            for i, sample in enumerate(self.train_loader):

                image, source, target, caption_len, pos, weight = [
                    x.to(device) for x in sample
                ]

                # Forward pass
                self.c_optimizer.zero_grad()

                r = self.captioner(image, source, pos)
                logits, pos_logits = r.logits, r.pos_logits

                if self.opt.weight_captions:
                    word_loss = masked_CE(logits, target, caption_len,
                                          weight.float())
                    pos_loss = masked_CE(pos_logits, pos, caption_len - 1,
                                         weight.float())
                else:
                    word_loss = masked_CE(logits, target, caption_len)
                    pos_loss = masked_CE(pos_logits, pos, caption_len - 1)

                total_loss = word_loss + self.opt.pos_alpha * pos_loss

                # Backwards pass
                total_loss.backward()

                if self.opt.grad_clip:
                    util.gradient_noise_and_clip(self.captioner.parameters(),
                                                 self.opt.max_clip)

                self.c_optimizer.step()

                # Logging
                print_loss += total_loss.item()

                if self.cap_steps % self.opt.print_every == 0:
                    info = {
                        'cap/loss': print_loss / self.opt.print_every,
                        # average time per batch over the last print_every steps
                        'cap/time': (time() - tic) / self.opt.print_every
                    }
                    util.step_logging(self.logger, info, self.cap_steps)
                    util.log_avg_grads(self.logger,
                                       self.captioner,
                                       self.cap_steps,
                                       name="cap/")
                    steps_per_epoch = len(self.train_loader)
                    self.std_logger.info(
                        "Chunk {} Epoch {}, {}/{}| Loss: {} | Time per batch: {} |"
                        " Epoch remaining time (HH:MM:SS) {} | Elapsed time {}"
                        .format(
                            self.chunk + 1, epoch + 1, i, steps_per_epoch,
                            info['cap/loss'], info['cap/time'],
                            util.time_remaining(steps_per_epoch - i,
                                                info['cap/time']),
                            util.time_elapsed(self.start_time, time())))

                    print_loss, tic = 0, time()

                self.cap_steps += 1

            model_score = self.evaluate_captioner()
            self.save_captioner(epoch, model_score)
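
util.gradient_noise_and_clip is referenced above but not defined on this page. A common recipe behind such a name adds small Gaussian noise to the gradients and then clips them by global norm; the version below is an illustrative assumption built on torch.nn.utils.clip_grad_norm_, not the project's actual implementation:

import torch
from torch.nn.utils import clip_grad_norm_


def gradient_noise_and_clip(parameters, max_clip, noise_std=1e-3):
    """Add Gaussian noise to every gradient, then clip the global gradient norm."""
    params = [p for p in parameters if p.grad is not None]
    for p in params:
        p.grad.add_(torch.randn_like(p.grad) * noise_std)
    clip_grad_norm_(params, max_clip)
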
Example #3

    def validate(self):

        loss, correct, correct_answers = 0.0, 0.0, 0
        self.model.eval()

        for step, batch in enumerate(self.val_loader):
            img_path = batch[-1]
            batch = [x.to(device) for x in batch[:-1]]

            (image, question_len, source, target, caption, q_idx_vec, pos,
             att, context, refs, answer) = self.compute_cap_features_val(batch)

            logits = self.model(image, caption, pos, context, att, source,
                                q_idx_vec)

            batch_loss = masked_CE(logits, target, question_len)
            loss += batch_loss.item()
            predictions = seq_max_and_mask(
                logits, question_len, self.val_loader.dataset.max_q_len + 1)

            correct += torch.sum((target == predictions).float() /
                                 question_len.float().unsqueeze(1)).item()

            # evaluate using VQA expert
            result = self.model.sample(
                image,
                caption,
                pos,
                context,
                att,
                q_idx_vec,
                greedy=True,
                max_seq_len=self.val_loader.dataset.max_q_len + 1)
            sample, log_probs, mask = result.question, result.log_prob, result.mask

            correct_answers += self.query_vqa(image, sample, refs, answer,
                                              mask)

            # write image and Q&A to html file to visualize training progress
            if step % self.opt.visualize_every == 0:
                beam_size = 3
                c_i2w = self.val_loader.dataset.c_i2w
                a_i2w = self.val_loader.dataset.a_i2w
                q_i2w = self.val_loader.dataset.q_i2w

                sample_len = mask.sum(dim=1)

                _, _, beam_predictions, beam_lps = self.model.sample_beam(
                    image,
                    caption,
                    pos,
                    context,
                    att,
                    q_idx_vec,
                    beam_size=beam_size,
                    max_seq_len=15)
                beam_predictions, lps, lens = [
                    util.to_np(x) for x in [
                        beam_predictions,
                        torch.sum(beam_lps, dim=2),
                        torch.sum(beam_lps != 0, dim=2)
                    ]
                ]

                target, question_len, sample, sample_len = [
                    util.to_np(x)
                    for x in [target, question_len, sample, sample_len]
                ]

                for i in range(image.size(0)):

                    entry = {
                        'img': img_path[i],
                        'epoch': self.epoch,
                        'answer': a_i2w[answer[i].item()],
                        'gt_question': ' '.join(
                            util.idx2str(q_i2w,
                                         target[i][:question_len[i] - 1])),
                        'greedy_question': ' '.join(
                            util.idx2str(q_i2w, sample[i][:sample_len[i]])),
                        'beamsearch': [
                            ' '.join(util.idx2str(
                                q_i2w, beam_predictions[i, j][:lens[i, j]])) +
                            " logprob: {}".format(lps[i, j])
                            for j in range(beam_size)
                        ]
                    }
                    self.visualizer.add_entry(entry)

        self.visualizer.update_html()

        # Reset
        self.model.train()

        n = len(self.val_loader.dataset)
        return [loss / n, correct / n, 100.0 * correct_answers / n]
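
The visualization step above decodes index tensors back into words through util.idx2str together with the dataset's index-to-word dictionaries (c_i2w, q_i2w, a_i2w). A plausible minimal form of that helper, assuming each i2w maps integer ids to word strings, is:

def idx2str(i2w, indices):
    """Map a sequence of integer token ids to word strings via an id-to-word dict."""
    return [i2w[int(idx)] for idx in indices]
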
Example #4

    def train(self, epoch):

        print("Training")

        print_loss, tic = 0, time()
        self.model.train()

        # manually iterate over both loaders in lockstep
        word_iter = iter(self.word_match_loader)
        pos_iter = iter(self.pos_match_loader)
        while True:
            try:
                word_batch = next(word_iter)
                pos_batch = next(pos_iter)
            except StopIteration:
                break

            word_batch = [x.to(device) for x in word_batch[:-1]]
            pos_batch = [x.to(device) for x in pos_batch[:-1]]

            (image, question_len, source, target, caption, q_idx_vec, pos,
             att, context) = self.compute_cap_features(word_batch, pos_batch)

            # Forward pass
            self.optimizer.zero_grad()
            logits = self.model(image, caption, pos, context, att, source,
                                q_idx_vec)
            loss = masked_CE(logits, target, question_len)

            # Backward pass
            loss.backward()

            if self.opt.grad_clip:
                util.gradient_noise_and_clip(self.model.parameters(),
                                             self.opt.max_clip)

            self.optimizer.step()

            # Logging
            print_loss += loss.item()

            if self.global_step % self.opt.print_every == 0:
                info = {
                    'loss': print_loss / self.opt.print_every,
                    'time': (time() - tic) / self.opt.print_every  # time per step
                }

                util.step_logging(self.logger, info, self.global_step)
                util.log_avg_grads(self.logger, self.model, self.global_step)

                steps_per_epoch = len(self.word_match_loader)
                step = self.global_step - epoch * steps_per_epoch
                remaining_steps = steps_per_epoch * (self.opt.max_epochs -
                                                     epoch) - step
                self.std_logger.info(
                    "{}, {}/{}| Loss: {} | Time per batch: {} | Epoch remaining time (HH:MM:SS) {} | "
                    "Elapsed time {} | Total remaining time {}".format(
                        epoch + 1, step, steps_per_epoch, info['loss'],
                        info['time'],
                        util.time_remaining(steps_per_epoch - step,
                                            info['time']),
                        util.time_elapsed(self.start_time, time()),
                        util.time_remaining(remaining_steps, info['time'])))
                print_loss, tic = 0, time()

            self.global_step = self.global_step + 1

        model_score = self.evaluate(epoch + 1)
        self.save_checkpoint(epoch, model_score)
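
The log line above relies on util.time_remaining and util.time_elapsed to report durations as HH:MM:SS. Their real implementations are not included on this page; a simple version consistent with how they are called (remaining steps times average time per step, and elapsed wall-clock time) might look like this:

import datetime


def time_remaining(remaining_steps, time_per_step):
    """Estimated remaining wall-clock time, formatted as H:MM:SS."""
    return str(datetime.timedelta(seconds=int(remaining_steps * time_per_step)))


def time_elapsed(start_time, current_time):
    """Elapsed time between two time.time() stamps, formatted as H:MM:SS."""
    return str(datetime.timedelta(seconds=int(current_time - start_time)))
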
Example #5
    def train(self, epoch):

        print("Training")

        print_loss, tic = 0, time()
        self.model.train()

        for i, sample in enumerate(self.train_loader):
            (image, source, target, caption_len,
             refs, ref_lens, pos) = sample[:-3]
            image, source, target, caption_len, pos = [
                x.to(device)
                for x in [image, source, target, caption_len, pos]
            ]

            # Forward pass
            self.optimizer.zero_grad()
            result = self.model(image, source, pos)
            logits, pos_logits = result.logits, result.pos_logits

            # Get losses
            word_loss = masked_CE(logits, target, caption_len)
            pos_loss = masked_CE(pos_logits, pos, caption_len - 1)

            total_loss = word_loss + self.opt.pos_alpha * pos_loss

            # Backward pass
            total_loss.backward()

            if self.opt.grad_clip:
                util.gradient_noise_and_clip(self.model.parameters(),
                                             self.opt.max_clip)

            self.optimizer.step()

            loss = total_loss.item()

            # Logging
            print_loss += loss

            if self.global_step % self.opt.print_every == 0:
                info = {
                    'loss': print_loss / self.opt.print_every,
                    'time': (time() - tic) / self.opt.print_every  # time per step
                }
                util.step_logging(self.logger, info, self.global_step)
                util.log_avg_grads(self.logger, self.model, self.global_step)

                steps_per_epoch = len(self.train_loader)
                step = self.global_step - epoch * steps_per_epoch
                remaining_steps = steps_per_epoch * (self.opt.max_epochs -
                                                     epoch) - step
                self.std_logger.info(
                    "{}, {}/{}| Loss: {} | Time per batch: {} | Epoch remaining time (HH:MM:SS) {} | "
                    "Elapsed time {} | Total remaining time {}".format(
                        epoch + 1, step, steps_per_epoch, info['loss'],
                        info['time'],
                        util.time_remaining(steps_per_epoch - step,
                                            info['time']),
                        util.time_elapsed(self.start_time, time()),
                        util.time_remaining(remaining_steps, info['time'])))
                print_loss, tic = 0, time()

            self.global_step = self.global_step + 1

        model_score = self.evaluate(epoch + 1)
        self.save_checkpoint(epoch, model_score)
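
As a usage sketch, a trainer exposing this train(epoch) method would normally be driven by an outer epoch loop; the Trainer name and its constructor argument below are hypothetical, not part of this page:

# Hypothetical driver; Trainer and opt are assumptions.
trainer = Trainer(opt)
for epoch in range(opt.max_epochs):
    trainer.train(epoch)  # one epoch of training, then evaluation and checkpointing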