Example #1
0
    def train(self):
        """Run one collection epoch over ``self.train_loader``.

        Each batch is moved to ``device`` and passed to ``self.do_iteration``,
        which returns the batch loss. Every ``self.opt.print_every``
        collection steps, the average loss and average time per batch since
        the previous report are sent to ``self.logger``. At the same point,
        the decision maker's average gradients are logged and a progress line
        is written to ``self.std_logger``. After the loop the training
        visualizer HTML is refreshed and the collected data is processed.

        ``self.collection_steps`` is incremented once per batch and persists
        across calls.
        """

        # Running totals since the last progress report. The explicit batch
        # count keeps the averages correct when fewer than ``print_every``
        # batches have been accumulated (first report, or the first report
        # after an epoch boundary, since collection_steps carries over).
        print_loss, print_batches, tic = 0, 0, time()
        steps_per_epoch = len(self.train_loader)

        for i, sample in enumerate(self.train_loader):

            # Fields used from each sample: 0 = images, 4 = references,
            # 5 = reference lengths, 7 = image paths, 8 = dataset indices.
            image, refs = [x.to(device) for x in [sample[0], sample[4]]]
            ref_lens, img_path, index = sample[5], sample[7], sample[8]
            batch_loss = self.do_iteration(image, refs, ref_lens, index,
                                           img_path)

            print_loss += batch_loss
            print_batches += 1

            if self.collection_steps % self.opt.print_every == 0:
                info = {
                    'collect/loss': print_loss / print_batches,
                    # average wall-clock time per batch since the last report
                    'collect/time': (time() - tic) / print_batches
                }
                util.step_logging(self.logger, info, self.collection_steps)
                util.log_avg_grads(self.logger,
                                   self.dmaker,
                                   self.collection_steps,
                                   name="dec")
                self.std_logger.info(
                    "Chunk {} Epoch {}, {}/{}| Loss: {} | Time per batch: {} |"
                    " Epoch remaining time (HH:MM:SS) {} | Elapsed time {}".
                    format(
                        self.chunk + 1, self.collection_epoch, i,
                        steps_per_epoch, info['collect/loss'],
                        info['collect/time'],
                        util.time_remaining(steps_per_epoch - i,
                                            info['collect/time']),
                        util.time_elapsed(self.start_time, time())))

                print_loss, print_batches, tic = 0, 0, time()

            self.collection_steps += 1

        self.trainLLvisualizer.update_html()
        self.data_collector.process_collected_data()
Example #2
0
    def evaluate_with_questions(self):
        """Evaluate the decision maker on ``self.val_loader`` with questions.

        The decision maker is switched to eval mode for the pass and back to
        train mode at the end. The captioner and question generator are put
        in train mode. Inside ``torch.no_grad()``, each validation batch is
        processed in four steps:

          1. Captioned greedily by ``self.captioner``.
          2. Scored by ``self.dmaker``; a greedy decision picks where to ask.
          3. Used to ask the teacher one question (``self.ask_question``).
          4. Re-captioned two ways from the answer: a rollout
             (``self.caption_with_teacher_answer``) and a single-word
             replacement (``replace_word_in_caption``). One of the two is
             chosen by ``choose_better_caption`` from their ``mixed_reward``
             scores.

        Two captions are tracked: the original caption (index 0) and the
        chosen improved caption (index 1). For both, predictions are
        collected for ``language_scores``, and a token-match accuracy against
        ``target`` is accumulated. Every ``self.opt.val_print_every``
        batches, per-sample entries are added to ``self.valLLvisualizer`` and
        rollout/replace statistics are logged.

        Returns:
            acc: two-element list ``[original, improved]``. Each value is the
                summed per-sample match fraction divided by the dataset size.
            eval_scores: two-element list holding the ``language_scores``
                result for the original and the improved captions.
        """

        self.std_logger.info("Validating decision maker")

        self.dmaker.eval()
        # NOTE(review): the captioner and question generator are set to
        # train() mode even though this is an evaluation pass. The
        # commented-out block below suggests this used to be configurable.
        self.captioner.train()
        self.qgen.train()
        # if self.opt.cap_eval:
        #     self.captioner.eval()
        # else:
        #     self.captioner.train()
        #
        # if self.opt.quegen_eval:
        #     self.question_generator.eval()
        # else:
        #     self.question_generator.train()

        c_i2w = self.val_loader.dataset.c_i2w

        # Index 0 tracks the original caption, index 1 the improved caption.
        # NOTE(review): pos_correct is never used in this method.
        correct, pos_correct, eval_scores = [0.0, 0.0], [0.0, 0.0], []
        caption_predictions = [[], []]

        # Evaluation only: no gradients are tracked.
        with torch.no_grad():

            for step, sample in enumerate(self.val_loader):

                # Fields used from each sample: 0 = images, 4 = references,
                # 2 = targets, 3 = caption lengths, 5 = reference lengths,
                # 7 = image paths, 8 = dataset indices, 9 = image ids.
                image, refs, target, caption_len = [
                    x.to(device)
                    for x in [sample[0], sample[4], sample[2], sample[3]]
                ]
                ref_lens, img_path, index, img_id = sample[5], sample[
                    7], sample[8], sample[9]

                batch_size = image.size(0)

                # Per-batch accumulators for captions, masks, POS
                # probabilities and question/answer outputs.
                caps, cmasks, poses, pps, qs, qlps, qmasks, aps, cps, atts = [], [], [], [], [], [],[], [], [], []

                # 1. Caption completely
                # The seed is reset before captioning every batch.
                self.set_seed(self.opt.seed)

                r = self.captioner.sample(
                    image,
                    greedy=True,
                    max_seq_len=self.opt.c_max_sentence_len + 1)

                caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
                    = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

                cap_len = cap_mask.long().sum(dim=1)
                caps.append(caption)
                cmasks.append(cap_mask)
                poses.append(pos_probs)

                caption = self.pad_caption(caption, cap_len)

                # get the hidden state context
                # Encoder input: the caption shifted right by one position,
                # with ones_vector prepended as the first token.
                source = torch.cat([
                    self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]
                ],
                                   dim=1)

                r = self.fixed_caption_encoder(image,
                                               source,
                                               gt_pos=None,
                                               ss=False)
                h = r.hidden

                # 2. Identify the best time to ask a question, excluding ended sentences
                logit, valid_pos_mask = self.dmaker(
                    h, attended_img, caption, cap_len, pos_probs, topk_words,
                    self.captioner.caption_embedding.weight.data)
                masked_prob = masked_softmax(
                    logit,
                    cap_mask,
                    valid_pos_mask,
                    max_len=self.opt.c_max_sentence_len)
                dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
                    masked_prob, cap_mask, greedy=True)

                # 3. Ask the teacher a question and get the answer
                ans, ans_mask, r = self.ask_question(image,
                                                     caption,
                                                     refs,
                                                     pos_probs,
                                                     h,
                                                     att,
                                                     ask_idx,
                                                     q_greedy=True)

                # logging
                # NOTE(review): every append below is issued twice with the
                # same value. This looks copied from the sampled/greedy pairs
                # in training. Only element 0 of these lists is read later in
                # this method.
                cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
                cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
                pps.append(r.pos_prob[0])
                pps.append(r.pos_prob[0])
                atts.append(r.att[0])
                atts.append(r.att[0])
                qlps.append(r.q_logprob.unsqueeze(1))
                qlps.append(r.q_logprob.unsqueeze(1))
                qmasks.append(r.q_mask)
                qmasks.append(r.q_mask)
                qs.append(r.question)
                qs.append(r.question)
                aps.append(r.ans_prob)
                aps.append(r.ans_prob)

                # 4. Compute new captions based on teacher's answer
                # rollout caption
                r = self.caption_with_teacher_answer(image,
                                                     ask_mask,
                                                     ans_mask,
                                                     greedy=True)

                poses.append(r.pos_prob)
                rollout = r.caption
                rollout_mask = r.cap_mask

                # replace caption
                replace = replace_word_in_caption(caps[0], ans, ask_idx,
                                                  ask_flag)

                # Rewards for the original, rollout and replaced captions.
                # The replaced caption reuses the original caption's mask.
                base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0],
                                                           dim=1), refs,
                                        ref_lens, self.scorers, self.c_i2w)
                rollout_rwd = mixed_reward(rollout,
                                           torch.sum(rollout_mask,
                                                     dim=1), refs, ref_lens,
                                           self.scorers, self.c_i2w)
                replace_rwd = mixed_reward(replace, torch.sum(cmasks[0],
                                                              dim=1), refs,
                                           ref_lens, self.scorers, self.c_i2w)

                caps.append(rollout)
                caps.append(replace)
                cmasks.append(rollout_mask)
                cmasks.append(cmasks[0])

                stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
                    replace_rwd, rollout_rwd, base_rwd)

                best_cap, best_cap_mask = choose_better_caption(
                    replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
                    rollout_mask)

                # From here on, caps/cmasks hold [original, chosen improved].
                caps = [caps[0], best_cap]
                cmasks = [cmasks[0], best_cap_mask]

                # Collect captions for coco evaluation
                img_id = util.to_np(img_id)

                for i in range(len(caps)):
                    words, lens, = util.to_np(caps[i]), util.to_np(
                        cmasks[i].sum(dim=1))

                    for j in range(image.size(0)):
                        inds = words[j][:lens[j]]
                        caption = ""
                        for k, ind in enumerate(inds):
                            if k > 0:
                                caption = caption + ' '
                            caption = caption + c_i2w[ind]

                        pred = {'image_id': img_id[j], 'caption': caption}
                        caption_predictions[i].append(pred)

                # Token-match accuracy: for each sample, the number of
                # positions where the masked prediction equals target, divided
                # by its caption length, summed over the batch.
                # NOTE(review): masked-out positions become 0 in predictions.
                # If target is zero-padded, padding positions also count as
                # matches. Confirm this is intended.
                for i in range(len(correct)):
                    predictions = caps[i] * cmasks[i].long()
                    correct[i] += (
                        (target == predictions).float() /
                        caption_len.float().unsqueeze(1)).sum().item()

                # Logging
                if step % self.opt.val_print_every == 0:

                    c_i2w = self.val_loader.dataset.c_i2w
                    p_i2w = self.val_loader.dataset.p_i2w
                    q_i2w = self.val_loader.dataset.q_i2w
                    a_i2w = self.val_loader.dataset.a_i2w

                    # Convert the first-question outputs to numpy for display.
                    # This rebinds caption, rollout, replace, refs and
                    # ref_lens for the rest of this batch iteration.
                    caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len, \
                    flag, raw_idx, refs, ref_lens = \
                        [util.to_np(x) for x in [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                                                 rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                                                 qmasks[0].long().sum(dim=1), ask_flag.long(),
                                                 ask_idx, refs, ref_lens]]
                    pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]

                    for i in range(image.size(0)):
                        # NOTE(review): pos_probs is not indexed by i, so the
                        # same top-3 POS tags may be reported for every
                        # sample. Confirm the intended shape.
                        top_pos = torch.topk(pos_probs, 3)[1]
                        top_ans = torch.topk(ans_probs[i], 3)[1]
                        top_cap = torch.topk(cap_probs[i], 5)[1]

                        word = c_i2w[caption[i][
                            raw_idx[i]]] if raw_idx[i] < len(
                                caption[i]) else 'Nan'

                        # p_i2w[18] is the fallback label for POS ids missing
                        # from p_i2w.
                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.collection_epoch,
                            'caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (caption[i])[:cap_len[i]])) +
                            " | Reward: {:.2f}".format(base_rwd[i]),
                            'qaskprobs':
                            ' '.join([
                                "{:.2f}".format(x) if x > 0.0 else "_"
                                for x in dec_probs[i]
                            ]),
                            'rollout_caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (rollout[i])[:rollout_len[i]])) +
                            " | Reward: {:.2f}".format(rollout_rwd[i]),
                            'replace_caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (replace[i])[:cap_len[i]])) +
                            " | Reward: {:.2f}".format(replace_rwd[i]),
                            'index':
                            raw_idx[i],
                            'flag':
                            bool(flag[i]),
                            'word':
                            word,
                            'pos':
                            ' | '.join([
                                p_i2w[x.item()]
                                if x.item() in p_i2w else p_i2w[18]
                                for x in top_pos
                            ]),
                            'question':
                            ' '.join(
                                util.idx2str(q_i2w,
                                             (question[i])[:q_len[i]])) +
                            " | logprob: {}".format(
                                q_logprob[i, :q_len[i]].sum()),
                            'answers':
                            ' | '.join([a_i2w[x.item()] for x in top_ans]),
                            'words':
                            ' | '.join([c_i2w[x.item()] for x in top_cap]),
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (refs[i, j])[:ref_lens[i,
                                                                        j]]))
                                for j in range(3)
                            ]
                        }

                        self.valLLvisualizer.add_entry(entry)

                    # NOTE(review): the "percent" value is
                    # ask_flag.sum() / ask_flag.size(0). It is a fraction, not
                    # a percentage (assuming ask_flag holds 0/1 values).
                    info = {
                        'eval/replace over rollout':
                        float(stat_rero) / (batch_size),
                        'eval/rollout over replace':
                        float(stat_rore) / (batch_size),
                        'eval/replace over all':
                        float(stat_reall) / (batch_size),
                        'eval/rollout over all':
                        float(stat_roall) / (batch_size),
                        'eval/question asking frequency (percent)':
                        float(ask_flag.sum().item()) / ask_flag.size(0)
                    }
                    util.step_logging(self.logger, info, self.eval_steps)
                    self.eval_steps += 1

        self.valLLvisualizer.update_html()

        # Mean per-sample token-match fraction over the whole dataset.
        acc = [x / len(self.val_loader.dataset) for x in correct]

        for i in range(len(correct)):
            eval_scores.append(
                language_scores(caption_predictions[i],
                                self.opt.run_name,
                                self.result_path,
                                annFile=self.opt.coco_annotation_file))

        # Only the decision maker's mode is restored; the captioner and qgen
        # stay in train() mode.
        self.dmaker.train()

        return acc, eval_scores
Example #3
0
    def do_iteration(self, image, refs, ref_lens, index, img_path):
        """Run one decision-maker training step on a batch and collect data.

        Steps:
          1. Caption ``image`` with ``self.captioner``. Its outputs are
             detached, so no gradient reaches the captioner. Encode the
             caption with ``self.fixed_caption_encoder`` to get hidden states.
          2. Let ``self.dmaker`` score caption positions, then draw both a
             sampled and a greedy decision of where to ask a question.
          3. For each decision, ask the teacher a question
             (``self.ask_question``) and get an answer.
          4. Build two candidate captions per decision: a rollout
             re-captioned from the answer and a single-word replacement.
             Pick one with ``choose_better_caption`` from their
             ``mixed_reward`` scores.
          5. Update the decision maker with ``masked_PG``. The reward is the
             sampled reward minus the greedy reward, minus
             ``self.opt.ask_penalty`` for each sampled ask.

        The following are passed to ``self.data_collector.collect_batch``:
        the base, sampled-best and greedy-best captions; their rewards; the
        answers; the questions; and the top-20 caption-word candidates.
        Every ``self.opt.print_every`` collection steps, per-sample entries
        are added to ``self.trainLLvisualizer`` and batch statistics are
        logged once.

        Args:
            image: batch of images on ``device``.
            refs: reference captions on ``device``.
            ref_lens: reference caption lengths.
            index: dataset indices of the batch (cloned before collection).
            img_path: image paths, used only for visualizer entries.

        Returns:
            The policy-gradient loss as a Python number.
        """
        self.d_optimizer.zero_grad()

        batch_size = image.size(0)

        caps, cmasks, qs, qlps, qmasks, aps, pps, atts, cps, dps = [], [], [], [], [], [], [], [], [], []

        # 1. Caption completely
        self.set_seed(self.opt.seed)

        r = self.captioner.sample(image,
                                  greedy=self.opt.cap_greedy,
                                  max_seq_len=self.opt.c_max_sentence_len + 1,
                                  temperature=self.opt.temperature)

        caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
            = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

        # Don't backprop through captioner
        caption, cap_probs, cap_mask, pos_probs, attended_img \
            = [x.detach() for x in [caption, cap_probs, cap_mask, pos_probs, attended_img]]

        cap_len = cap_mask.long().sum(dim=1)
        caps.append(caption)
        cmasks.append(cap_mask)

        caption = self.pad_caption(caption, cap_len)

        # Get the hidden state context. Encoder input is the caption shifted
        # right by one position, with ones_vector prepended.
        source = torch.cat(
            [self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]],
            dim=1)
        r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
        h = r.hidden.detach()
        topk_words = [[y.detach() for y in x] for x in topk_words]

        # 2. Identify the best time to ask a question, excluding ended sentences, baseline against the greedy decision
        logit, valid_pos_mask = self.dmaker(
            h, attended_img, caption, cap_len, pos_probs, topk_words,
            self.captioner.caption_embedding.weight.data)
        masked_prob = masked_softmax(logit,
                                     cap_mask,
                                     valid_pos_mask,
                                     self.opt.dm_temperature,
                                     max_len=self.opt.c_max_sentence_len)

        # The sampled decision drives the policy gradient. The greedy decision
        # on the same distribution provides the reward baseline.
        dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
            masked_prob, cap_mask, greedy=False)
        _, ask_idx_greedy, ask_flag_greedy, ask_mask_greedy = self.sample_decision(
            masked_prob, cap_mask, greedy=True)

        dps.append(dm_prob.unsqueeze(1))

        # 3. Ask the teacher a question and get the answer
        ans, ans_mask, r = self.ask_question(image,
                                             caption,
                                             refs,
                                             pos_probs,
                                             h,
                                             att,
                                             ask_idx,
                                             q_greedy=self.opt.q_greedy,
                                             temperature=self.opt.temperature)
        ans_greedy, ans_mask_greedy, rg = self.ask_question(
            image,
            caption,
            refs,
            pos_probs,
            h,
            att,
            ask_idx_greedy,
            q_greedy=self.opt.q_greedy,
            temperature=self.opt.temperature)

        # Logging stuff: index 0 = sampled decision, index 1 = greedy decision.
        cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
        cps.append(cap_probs[self.range_vector[:batch_size], ask_idx_greedy])
        pps.append(r.pos_prob[0])
        pps.append(rg.pos_prob[0])
        atts.append(r.att[0])
        atts.append(rg.att[0])
        qlps.append(r.q_logprob.unsqueeze(1))
        qlps.append(rg.q_logprob.unsqueeze(1))
        qmasks.append(r.q_mask)
        qmasks.append(rg.q_mask)
        qs.append(r.question.detach())
        qs.append(rg.question.detach())
        aps.append(r.ans_prob.detach())
        aps.append(rg.ans_prob.detach())

        # 4. Compute new captions based on teacher's answer
        # rollout caption
        r = self.caption_with_teacher_answer(image,
                                             ask_mask,
                                             ans_mask,
                                             greedy=self.opt.cap_greedy,
                                             temperature=self.opt.temperature)
        rg = self.caption_with_teacher_answer(image,
                                              ask_mask_greedy,
                                              ans_mask_greedy,
                                              greedy=self.opt.cap_greedy,
                                              temperature=self.opt.temperature)

        rollout, rollout_mask, rollout_greedy, rollout_mask_greedy = [
            x.detach()
            for x in [r.caption, r.cap_mask, rg.caption, rg.cap_mask]
        ]

        # replace caption
        replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)
        replace_greedy = replace_word_in_caption(caps[0], ans_greedy,
                                                 ask_idx_greedy,
                                                 ask_flag_greedy)

        # 5. Compute reward for captions. Replaced captions reuse the original
        # caption's mask.
        base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1), refs,
                                ref_lens, self.scorers, self.c_i2w)
        rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1),
                                   refs, ref_lens, self.scorers, self.c_i2w)
        rollout_greedy_rwd = mixed_reward(
            rollout_greedy, torch.sum(rollout_mask_greedy, dim=1), refs,
            ref_lens, self.scorers, self.c_i2w)

        replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1), refs,
                                   ref_lens, self.scorers, self.c_i2w)
        replace_greedy_rwd = mixed_reward(replace_greedy,
                                          torch.sum(cmasks[0], dim=1), refs,
                                          ref_lens, self.scorers, self.c_i2w)

        # Per-sample reward of the better of rollout / replace.
        rwd = np.maximum(replace_rwd, rollout_rwd)
        rwd_greedy = np.maximum(replace_greedy_rwd, rollout_greedy_rwd)

        best_cap, best_cap_mask = choose_better_caption(
            replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
            rollout_mask)
        best_cap_greedy, best_cap_greedy_mask = choose_better_caption(
            replace_greedy_rwd, replace_greedy, cmasks[0], rollout_greedy_rwd,
            rollout_greedy, rollout_mask_greedy)

        # some statistics on whether rollout or single-word-replace is better
        stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
            replace_rwd, rollout_rwd, base_rwd)
        caps.append(best_cap)
        cmasks.append(best_cap_mask)
        caps.append(best_cap_greedy)
        cmasks.append(best_cap_greedy_mask)

        # Backwards pass to train decision maker with policy gradient loss.
        # Advantage = sampled reward - greedy reward, penalised per sampled ask.
        reward_delta = torch.from_numpy(rwd - rwd_greedy).type(
            torch.float).to(device)
        reward_delta = reward_delta - self.opt.ask_penalty * ask_flag.float()

        loss = masked_PG(reward_delta.detach(),
                         torch.log(dps[0]).squeeze(), ask_flag.detach())
        loss.backward()

        self.d_optimizer.step()

        # also save the question asked and answer, and top-k predictions from captioner
        answers = [torch.max(x, dim=1)[1] for x in aps]
        topwords = [torch.topk(x, 20)[1] for x in cps]
        question_lens = [x.sum(dim=1) for x in qmasks]
        self.data_collector.collect_batch(index.clone(), caps, cmasks,
                                          [base_rwd, rwd, rwd_greedy], answers,
                                          qs, question_lens, topwords)

        # Logging
        if self.collection_steps % self.opt.print_every == 0:

            c_i2w = self.val_loader.dataset.c_i2w
            p_i2w = self.val_loader.dataset.p_i2w
            q_i2w = self.val_loader.dataset.q_i2w
            a_i2w = self.val_loader.dataset.a_i2w

            # Convert the sampled-decision outputs to numpy for display. This
            # rebinds caption, rollout, replace, refs and ref_lens locally.
            caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len,\
            flag, raw_idx, refs, ref_lens = \
                [util.to_np(x) for x in [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                                         rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                                         qmasks[0].long().sum(dim=1), ask_flag.long(),
                                         ask_idx, refs, ref_lens]]
            pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]

            for i in range(image.size(0)):

                # NOTE(review): pos_probs is not indexed by i; confirm the
                # intended shape.
                top_pos = torch.topk(pos_probs, 3)[1]
                top_ans = torch.topk(ans_probs[i], 3)[1]
                top_cap = torch.topk(cap_probs[i], 5)[1]

                word = c_i2w[caption[i][raw_idx[i]]] if raw_idx[i] < len(
                    caption[i]) else 'Nan'

                entry = {
                    'img':
                    img_path[i],
                    'epoch':
                    self.collection_epoch,
                    'caption':
                    ' '.join(util.idx2str(c_i2w, (caption[i])[:cap_len[i]])) +
                    " | Reward: {:.2f}".format(base_rwd[i]),
                    'qaskprobs':
                    ' '.join([
                        "{:.2f}".format(x) if x > 0.0 else "_"
                        for x in dec_probs[i]
                    ]),
                    'rollout_caption':
                    ' '.join(util.idx2str(c_i2w,
                                          (rollout[i])[:rollout_len[i]])) +
                    " | Reward: {:.2f}".format(rollout_rwd[i]),
                    'replace_caption':
                    ' '.join(util.idx2str(c_i2w, (replace[i])[:cap_len[i]])) +
                    " | Reward: {:.2f}".format(replace_rwd[i]),
                    'index':
                    raw_idx[i],
                    'flag':
                    bool(flag[i]),
                    'word':
                    word,
                    'pos':
                    ' | '.join([
                        p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                        for x in top_pos
                    ]),
                    'question':
                    ' '.join(util.idx2str(q_i2w, (question[i])[:q_len[i]])) +
                    " | logprob: {}".format(q_logprob[i, :q_len[i]].sum()),
                    'answers':
                    ' | '.join([a_i2w[x.item()] for x in top_ans]),
                    'words':
                    ' | '.join([c_i2w[x.item()] for x in top_cap]),
                    'refs': [
                        ' '.join(
                            util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                        for j in range(3)
                    ]
                }

                self.trainLLvisualizer.add_entry(entry)

            # Batch-level statistics do not depend on the sample index, so
            # they are built and logged once per reporting step, after the
            # per-sample visualizer loop.
            info = {
                'collect/question logprob':
                torch.mean(torch.cat(qlps, dim=1)).item(),
                'collect/replace over rollout':
                float(stat_rero) / (batch_size),
                'collect/rollout over replace':
                float(stat_rore) / (batch_size),
                'collect/replace over all':
                float(stat_reall) / (batch_size),
                'collect/rollout over all':
                float(stat_roall) / (batch_size),
                'collect/sampled decision equals greedy decision':
                float((ask_idx == ask_idx_greedy).sum().item()) /
                (batch_size),
                'collect/question asking frequency (percent)':
                float(ask_flag.sum().item()) / ask_flag.size(0)
            }
            util.step_logging(self.logger, info, self.collection_steps)

        return loss.item()
    def train(self, epoch):
        """Train ``self.model`` for one epoch, then evaluate and checkpoint.

        Each batch from ``self.train_loader`` is moved to ``device``, run
        through ``self.model`` and optimized against ``self.loss_function``.
        When ``self.opt.grad_clip`` is set, gradients are passed through
        ``util.gradient_noise_and_clip``. Every ``self.opt.print_every``
        global steps, the following are logged: the average loss and time
        per batch since the previous report, the average gradients, and a
        progress line with remaining-time estimates.

        Args:
            epoch: zero-based epoch index. Used for progress reporting and
                passed to ``self.evaluate`` (as ``epoch + 1``) and to
                ``self.save_checkpoint``.
        """

        print("Training")

        # Enter training mode at the start of every epoch. The previous call
        # ended with ``self.evaluate``, which may have switched the model to
        # eval mode.
        self.model.train()

        # Running totals since the last progress report. The explicit batch
        # count keeps averages correct when fewer than ``print_every``
        # batches were accumulated (first report, or after an epoch
        # boundary, since global_step carries over).
        print_loss, print_batches, tic = 0, 0, time()
        steps_per_epoch = len(self.train_loader)

        for sample in self.train_loader:

            # The last three fields of each sample are not used here.
            image, question, question_len, answer, captions = sample[:-3]
            image, question, captions, answer = [
                x.to(device) for x in [image, question, captions, answer]
            ]

            self.optimizer.zero_grad()

            # Forward pass
            result = self.model(image, question, captions)
            logits = result.logits

            # Get loss. The original author observed `answer` arriving as a
            # double tensor; it is passed to the loss function unchanged.
            loss = self.loss_function(logits, answer)

            # Backward pass
            loss.backward()

            if self.opt.grad_clip:
                util.gradient_noise_and_clip(self.model.parameters(),
                                             self.opt.max_clip)

            self.optimizer.step()

            # Logging
            print_loss += loss.item()
            print_batches += 1

            if self.global_step % self.opt.print_every == 0:

                info = {
                    'loss': print_loss / print_batches,
                    # average wall-clock time per batch since the last report
                    'time': (time() - tic) / print_batches
                }
                util.step_logging(self.logger, info, self.global_step)
                util.log_avg_grads(self.logger, self.model, self.global_step)

                step = self.global_step - epoch * steps_per_epoch
                remaining_steps = steps_per_epoch * (self.opt.max_epochs -
                                                     epoch) - step
                self.std_logger.info(
                    "{}, {}/{}| Loss: {} | Time per batch: {} | Epoch remaining time (HH:MM:SS) {} | "
                    "Elapsed time {} | Total remaining time {}".format(
                        epoch + 1, step, steps_per_epoch, info['loss'],
                        info['time'],
                        util.time_remaining(steps_per_epoch - step,
                                            info['time']),
                        util.time_elapsed(self.start_time, time()),
                        util.time_remaining(remaining_steps, info['time'])))
                print_loss, print_batches, tic = 0, 0, time()

            self.global_step = self.global_step + 1

        model_score = self.evaluate(epoch + 1)
        self.save_checkpoint(epoch, model_score)
Example #5
0
    def train_captioner(self):
        """Train ``self.captioner`` for ``self.opt.cap_epochs`` epochs.

        At the start of each epoch the captioner is put in training mode and
        ``self.update_lr`` / ``self.update_ss`` are called with the epoch
        index. Each batch is optimized on a combined loss:
        ``masked_CE(logits, target, caption_len)`` plus
        ``self.opt.pos_alpha * masked_CE(pos_logits, pos, caption_len - 1)``.
        When ``self.opt.weight_captions`` is set, per-sample weights are
        passed to both losses. Every ``self.opt.print_every`` caption steps,
        the average loss and time per batch since the previous report are
        logged. After each epoch the captioner is evaluated and saved.
        """

        for epoch in range(self.opt.cap_epochs):
            self.cap_epoch = epoch

            # Re-enter training mode every epoch. Each epoch ends with
            # ``self.evaluate_captioner``, which may switch the captioner to
            # eval mode.
            self.captioner.train()

            self.update_lr(epoch)
            self.update_ss(epoch)

            # Running totals since the last progress report. The explicit
            # batch count keeps averages correct when fewer than
            # ``print_every`` batches were accumulated (cap_steps carries
            # over between epochs).
            print_loss, print_batches, tic = 0, 0, time()

            print("Training captioner")

            steps_per_epoch = len(self.train_loader)
            for i, sample in enumerate(self.train_loader):

                image, source, target, caption_len, pos, weight = [
                    x.to(device) for x in sample
                ]

                # Forward pass
                self.c_optimizer.zero_grad()

                r = self.captioner(image, source, pos)
                logits, pos_logits = r.logits, r.pos_logits

                if self.opt.weight_captions:
                    word_loss = masked_CE(logits, target, caption_len,
                                          weight.float())
                    pos_loss = masked_CE(pos_logits, pos, caption_len - 1,
                                         weight.float())
                else:
                    word_loss = masked_CE(logits, target, caption_len)
                    pos_loss = masked_CE(pos_logits, pos, caption_len - 1)

                total_loss = word_loss + self.opt.pos_alpha * pos_loss

                # Backwards pass
                total_loss.backward()

                if self.opt.grad_clip:
                    util.gradient_noise_and_clip(self.captioner.parameters(),
                                                 self.opt.max_clip)

                self.c_optimizer.step()

                # Logging
                print_loss += total_loss.item()
                print_batches += 1

                if self.cap_steps % self.opt.print_every == 0:
                    info = {
                        'cap/loss': print_loss / print_batches,
                        # average wall-clock time per batch since last report
                        'cap/time': (time() - tic) / print_batches
                    }
                    util.step_logging(self.logger, info, self.cap_steps)
                    util.log_avg_grads(self.logger,
                                       self.captioner,
                                       self.cap_steps,
                                       name="cap/")
                    self.std_logger.info(
                        "Chunk {} Epoch {}, {}/{}| Loss: {} | Time per batch: {} |"
                        " Epoch remaining time (HH:MM:SS) {} | Elapsed time {}"
                        .format(
                            self.chunk + 1, epoch + 1, i, steps_per_epoch,
                            info['cap/loss'], info['cap/time'],
                            util.time_remaining(steps_per_epoch - i,
                                                info['cap/time']),
                            util.time_elapsed(self.start_time, time())))

                    print_loss, print_batches, tic = 0, 0, time()

                self.cap_steps += 1

            model_score = self.evaluate_captioner()
            self.save_captioner(epoch, model_score)
    def train(self, epoch):
        """Run one training epoch over the paired word/POS match loaders.

        Batches are drawn in lockstep from ``self.word_match_loader`` and
        ``self.pos_match_loader``; the epoch ends as soon as either loader is
        exhausted. Every ``self.opt.print_every`` global steps the running
        loss and per-step time are logged. After the loop the model is
        evaluated with ``self.evaluate(epoch + 1)`` and a checkpoint is saved
        via ``self.save_checkpoint(epoch, model_score)``.

        Args:
            epoch: zero-based epoch index, used for progress reporting and
                forwarded to evaluation / checkpointing as described above.
        """
        print("Training")

        print_loss, tic = 0, time()
        self.model.train()

        # Iterate both loaders in lockstep. zip() stops as soon as either
        # loader is exhausted (same as the former manual StopIteration loop),
        # but uses the Python 3 iterator protocol instead of calling the
        # Python 2 style ``.next()`` method on the loader iterators.
        for word_batch, pos_batch in zip(self.word_match_loader,
                                         self.pos_match_loader):
            # The last element of each batch is excluded before moving to the
            # device (presumably non-tensor metadata -- TODO confirm).
            word_batch = [x.to(device) for x in word_batch[:-1]]
            pos_batch = [x.to(device) for x in pos_batch[:-1]]

            (image, question_len, source, target, caption, q_idx_vec, pos,
             att, context) = self.compute_cap_features(word_batch, pos_batch)

            # Forward pass
            self.optimizer.zero_grad()
            logits = self.model(image, caption, pos, context, att, source,
                                q_idx_vec)
            loss = masked_CE(logits, target, question_len)

            # Backward pass
            loss.backward()

            if self.opt.grad_clip:
                util.gradient_noise_and_clip(self.model.parameters(),
                                             self.opt.max_clip)

            self.optimizer.step()

            # Logging: accumulate the scalar loss for the current print window
            print_loss += loss.item()

            if self.global_step % self.opt.print_every == 0:
                info = {
                    'loss': print_loss / self.opt.print_every,
                    'time':
                    (time() - tic) / self.opt.print_every  # time per step
                }

                util.step_logging(self.logger, info, self.global_step)
                util.log_avg_grads(self.logger, self.model, self.global_step)

                # NOTE(review): the within-epoch step assumes every epoch runs
                # exactly len(word_match_loader) steps; if pos_match_loader is
                # shorter, zip() ends the epoch early and these figures drift.
                steps_per_epoch = len(self.word_match_loader)
                step = self.global_step - epoch * steps_per_epoch
                remaining_steps = steps_per_epoch * (self.opt.max_epochs -
                                                     epoch) - step
                self.std_logger.info(
                    "{}, {}/{}| Loss: {} | Time per batch: {} | Epoch remaining time (HH:MM:SS) {} | "
                    "Elapsed time {} | Total remaining time {}".format(
                        epoch + 1, step, steps_per_epoch, info['loss'],
                        info['time'],
                        util.time_remaining(steps_per_epoch - step,
                                            info['time']),
                        util.time_elapsed(self.start_time, time()),
                        util.time_remaining(remaining_steps, info['time'])))
                print_loss, tic = 0, time()

            self.global_step = self.global_step + 1

        model_score = self.evaluate(epoch + 1)
        self.save_checkpoint(epoch, model_score)
Пример #7
0
    def train(self, epoch):
        """Train ``self.model`` for one epoch over ``self.train_loader``.

        The optimized objective is the word-level masked cross-entropy plus
        the part-of-speech masked cross-entropy scaled by
        ``self.opt.pos_alpha``. Every ``self.opt.print_every`` global steps
        the running loss and per-step time are logged. When the loader is
        exhausted the model is evaluated (``self.evaluate(epoch + 1)``) and a
        checkpoint is written (``self.save_checkpoint(epoch, score)``).

        Args:
            epoch: zero-based epoch index, used for progress reporting and
                forwarded to evaluation / checkpointing as described above.
        """
        print("Training")

        running_loss, window_start = 0, time()
        self.model.train()

        for batch in self.train_loader:
            # Only the leading seven fields are used; refs / ref_lens are
            # unpacked solely to keep the field layout explicit.
            (image, source, target, caption_len,
             _refs, _ref_lens, pos) = batch[:-3]
            image, source, target, caption_len, pos = (
                tensor.to(device)
                for tensor in (image, source, target, caption_len, pos))

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(image, source, pos)
            logits = output.logits
            pos_logits = output.pos_logits

            # Word loss over caption_len positions, POS loss over one fewer
            word_loss = masked_CE(logits, target, caption_len)
            pos_loss = masked_CE(pos_logits, pos, caption_len - 1)
            combined = word_loss + self.opt.pos_alpha * pos_loss

            # Backward pass and parameter update
            combined.backward()
            if self.opt.grad_clip:
                util.gradient_noise_and_clip(self.model.parameters(),
                                             self.opt.max_clip)
            self.optimizer.step()

            running_loss += combined.item()

            if not self.global_step % self.opt.print_every:
                window = self.opt.print_every
                info = {
                    'loss': running_loss / window,
                    'time': (time() - window_start) / window,  # per step
                }
                util.step_logging(self.logger, info, self.global_step)
                util.log_avg_grads(self.logger, self.model, self.global_step)

                per_epoch = len(self.train_loader)
                done = self.global_step - epoch * per_epoch
                left_total = per_epoch * (self.opt.max_epochs - epoch) - done
                epoch_eta = util.time_remaining(per_epoch - done,
                                                info['time'])
                elapsed = util.time_elapsed(self.start_time, time())
                total_eta = util.time_remaining(left_total, info['time'])
                template = (
                    "{}, {}/{}| Loss: {} | Time per batch: {} | Epoch remaining time (HH:MM:SS) {} | "
                    "Elapsed time {} | Total remaining time {}")
                self.std_logger.info(
                    template.format(epoch + 1, done, per_epoch, info['loss'],
                                    info['time'], epoch_eta, elapsed,
                                    total_eta))
                running_loss, window_start = 0, time()

            self.global_step += 1

        score = self.evaluate(epoch + 1)
        self.save_checkpoint(epoch, score)