Exemplo n.º 1
0
    def evaluate(self, cnn_model, trx_model, batch_size):
        """Compute mean validation losses over ``self.dataloader_val``.

        Both models are switched to eval mode and are left in eval mode when
        this method returns. The whole pass runs under ``torch.no_grad()`` so
        no autograd graph is built.

        Args:
            cnn_model: Image encoder; called as ``cnn_model(real_imgs)`` and
                expected to return ``(words_features, sent_code)``.
            trx_model: Text encoder; called as ``trx_model(captions, masks)``
                and expected to return ``(words_emb, sent_emb)``.
            batch_size: Number of samples per validation batch. It sizes the
                matching labels and the negative ids, so it must be >= 2
                (a negative id is drawn from the *other* indices) and is
                presumably expected to equal the actual size of every batch
                yielded by the loader -- TODO confirm the loader drops the
                last partial batch.

        Returns:
            Tuple ``(s_cur_loss, w_cur_loss, s_t_cur_loss, w_t_cur_loss)``:
            the per-batch means of the sentence loss, word loss, sentence
            triplet loss and word triplet loss (each is the sum of its
            ``loss0`` and ``loss1`` parts), as detached tensors.

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        cnn_model.eval()
        trx_model.eval()

        # Index labels 0..batch_size-1, used by the matching losses.
        labels = torch.arange(batch_size, dtype=torch.long)
        if cfg.CUDA:
            labels = labels.cuda()

        # Loop-invariant pool of indices the negative ids are drawn from.
        ids = np.arange(batch_size)

        s_total_loss = 0
        w_total_loss = 0
        s_t_total_loss = 0
        w_t_total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for real_imgs, captions, masks, class_ids, cap_lens in tqdm(
                    self.dataloader_val, leave=False):
                class_ids = class_ids.numpy()

                # For every position pick a random *different* position as
                # its negative sample for the triplet losses.
                neg_ids = torch.LongTensor(
                    [np.random.choice(ids[ids != x]) for x in ids])

                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    captions = captions.cuda()
                    masks = masks.cuda()
                    cap_lens = cap_lens.cuda()
                    neg_ids = neg_ids.cuda()

                words_features, sent_code = cnn_model(real_imgs)
                words_emb, sent_emb = trx_model(captions, masks)

                # The first position along the last axis of the word
                # embeddings is skipped and the caption lengths are reduced
                # by one to match (presumably a start token -- TODO confirm).
                w_loss0, w_loss1, _ = words_loss(
                    words_features, words_emb[:, :, 1:], labels,
                    cap_lens - 1, class_ids, batch_size)
                w_total_loss += (w_loss0 + w_loss1).detach()

                s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                             class_ids, batch_size)
                s_total_loss += (s_loss0 + s_loss1).detach()

                w_t_loss0, w_t_loss1, _ = words_triplet_loss(
                    words_features, words_emb[:, :, 1:], labels, neg_ids,
                    cap_lens - 1, batch_size)
                w_t_total_loss += (w_t_loss0 + w_t_loss1).detach()

                s_t_loss0, s_t_loss1 = sent_triplet_loss(
                    sent_code, sent_emb, labels, neg_ids, batch_size)
                s_t_total_loss += (s_t_loss0 + s_t_loss1).detach()

                num_batches += 1

        if num_batches == 0:
            # Fail with a clear message instead of a NameError/ZeroDivision.
            raise ValueError('validation dataloader yielded no batches')

        s_cur_loss = s_total_loss / num_batches
        w_cur_loss = w_total_loss / num_batches
        s_t_cur_loss = s_t_total_loss / num_batches
        w_t_cur_loss = w_t_total_loss / num_batches
        return s_cur_loss, w_cur_loss, s_t_cur_loss, w_t_cur_loss
Exemplo n.º 2
0
    def train(self):
        """Train the image and text encoders with the DAMSM + triplet loss.

        For every epoch from the start epoch returned by
        ``self.build_models()`` up to ``self.max_epoch``:

        1. run ``self.num_batches`` optimisation steps (see
           ``_train_one_epoch``), logging per-step losses to TensorBoard;
        2. evaluate on the validation set via ``self.evaluate`` and log the
           four validation losses;
        3. step both learning-rate schedulers;
        4. call ``self.save_model`` when the mean of the four validation
           losses is lower than the best value seen so far.

        TensorBoard events are written under
        ``../tensorboard/<DATASET_NAME>_<CONFIG_NAME>_<timestamp>`` and the
        writer is always closed, even if training raises.
        """
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
        tb_dir = '../tensorboard/{0}_{1}_{2}'.format(cfg.DATASET_NAME,
                                                     cfg.CONFIG_NAME,
                                                     timestamp)
        mkdir_p(tb_dir)
        tbw = SummaryWriter(log_dir=tb_dir)  # Tensorboard logging

        try:
            text_encoder, image_encoder, start_epoch = self.build_models()

            # Index labels 0..batch_size-1, used by the matching losses.
            labels = torch.arange(self.batch_size, dtype=torch.long)
            if cfg.CUDA:
                labels = labels.cuda()

            text_encoder.train()
            image_encoder.train()

            (optimizerI, optimizerT, lr_schedulerI,
             lr_schedulerT) = self.define_optimizers(image_encoder,
                                                     text_encoder)

            # NOTE(review): the return value was never used by this method;
            # the call is kept in case it has side effects -- TODO confirm.
            self.prepare_labels()

            # Start at +inf so the first evaluated epoch is always saved,
            # whatever the scale of the losses.
            best_val_loss = float('inf')
            for epoch in range(start_epoch, self.max_epoch):
                self._train_one_epoch(epoch, image_encoder, text_encoder,
                                      optimizerI, optimizerT, labels, tbw)

                (v_s_cur_loss, v_w_cur_loss, v_s_t_cur_loss,
                 v_w_t_cur_loss) = self.evaluate(image_encoder, text_encoder,
                                                 self.val_batch_size)
                print(
                    '[epoch: %d] val_w_loss: %.4f, val_s_loss: %.4f, val_w_t_loss: %.4f, val_s_t_loss: %.4f'
                    % (epoch, v_w_cur_loss, v_s_cur_loss, v_w_t_cur_loss,
                       v_s_t_cur_loss))
                print('-' * 80)

                val_scalars = (
                    ('Val_step/val_w_loss', v_w_cur_loss),
                    ('Val_step/val_s_loss', v_s_cur_loss),
                    ('Val_step/val_w_t_loss', v_w_t_cur_loss),
                    ('Val_step/val_s_t_loss', v_s_t_cur_loss),
                )
                for tag, value in val_scalars:
                    tbw.add_scalar(tag, float(value), epoch)

                lr_schedulerI.step()
                lr_schedulerT.step()

                # Model selection criterion: unweighted mean of the four
                # validation losses.
                total_val_loss = sum(
                    float(value) for _, value in val_scalars) / 4.0
                if total_val_loss < best_val_loss:
                    best_val_loss = total_val_loss
                    self.save_model(image_encoder, text_encoder, optimizerI,
                                    optimizerT, lr_schedulerI, lr_schedulerT,
                                    epoch)
        finally:
            # Flush pending events and release the event file handle.
            tbw.close()

    def _train_one_epoch(self, epoch, image_encoder, text_encoder,
                         optimizerI, optimizerT, labels, tbw):
        """Run ``self.num_batches`` optimisation steps for one epoch.

        Args:
            epoch: Current epoch index; used to compute the global step for
                TensorBoard as ``step + epoch * self.num_batches``.
            image_encoder: Model called as ``image_encoder(imgs)``, expected
                to return ``(words_features, sent_code)``.
            text_encoder: Model called as ``text_encoder(captions, masks)``,
                expected to return ``(words_embs, sent_emb)``.
            optimizerI: Optimizer stepped for the image encoder.
            optimizerT: Optimizer stepped for the text encoder.
            labels: Index labels ``0..batch_size-1`` passed to the losses.
            tbw: TensorBoard writer receiving the per-step scalars.
        """
        batch_size = self.batch_size

        text_encoder.train()
        image_encoder.train()

        # Print the learning rates before the epoch starts as a sanity check.
        print('Learning rates: lr_i %.7f, lr_t %.7f' %
              (optimizerI.param_groups[0]['lr'],
               optimizerT.param_groups[0]['lr']))

        # Loop-invariant pool of indices the negative ids are drawn from.
        ids = np.arange(batch_size)
        total_combo_loss = 0.0

        data_iter = iter(self.data_loader)
        pbar = tqdm(range(self.num_batches))
        for step in pbar:
            ######################################################
            # (1) Prepare training data
            ######################################################
            imgs, captions, masks, class_ids, cap_lens = next(data_iter)
            class_ids = class_ids.numpy()

            # For every position pick a random *different* position as its
            # negative sample for the triplet losses (needs batch_size >= 2).
            neg_ids = torch.LongTensor(
                [np.random.choice(ids[ids != x]) for x in ids])

            if cfg.CUDA:
                imgs = imgs.cuda()
                captions = captions.cuda()
                masks = masks.cuda()
                cap_lens = cap_lens.cuda()
                neg_ids = neg_ids.cuda()

            ######################################################
            # (2) Forward pass through both encoders
            ######################################################
            image_encoder.zero_grad()
            text_encoder.zero_grad()

            words_features, sent_code = image_encoder(imgs)
            words_embs, sent_emb = text_encoder(captions, masks)

            # The first position along the last axis of the word embeddings
            # is skipped and the caption lengths are reduced by one to match
            # (presumably a start token -- TODO confirm).
            word_tokens = words_embs[:, :, 1:]
            word_lens = cap_lens - 1

            ######################################################
            # (3) DAMSM and triplet losses
            ######################################################
            s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                         class_ids, batch_size)
            w_loss0, w_loss1, _ = words_loss(words_features, word_tokens,
                                             labels, word_lens, class_ids,
                                             batch_size)
            damsm_loss = s_loss0 + s_loss1 + w_loss0 + w_loss1

            s_t_loss0, s_t_loss1 = sent_triplet_loss(
                sent_code, sent_emb, labels, neg_ids, batch_size)
            w_t_loss0, w_t_loss1, _ = words_triplet_loss(
                words_features, word_tokens, labels, neg_ids, word_lens,
                batch_size)
            t_loss = s_t_loss0 + s_t_loss1 + w_t_loss0 + w_t_loss1

            damsm_triplet_combo_loss = (cfg.LAMBDA_DAMSM * damsm_loss +
                                        cfg.LAMBDA_TRIPLET * t_loss)
            total_combo_loss += damsm_triplet_combo_loss.item()

            ######################################################
            # (4) Backward pass and parameter update
            ######################################################
            damsm_triplet_combo_loss.backward()

            torch.nn.utils.clip_grad_norm_(image_encoder.parameters(),
                                           cfg.clip_max_norm)
            optimizerI.step()

            torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                           cfg.clip_max_norm)
            optimizerT.step()

            ######################################################
            # (5) Per-step logging
            ######################################################
            global_step = step + epoch * self.num_batches
            step_scalars = (
                # damsm
                ('Train_step/train_w_step_loss0', w_loss0),
                ('Train_step/train_s_step_loss0', s_loss0),
                ('Train_step/train_w_step_loss1', w_loss1),
                ('Train_step/train_s_step_loss1', s_loss1),
                ('Train_step/train_damsm_step_loss', damsm_loss),
                # triplet
                ('Train_step/train_w_t_step_loss0', w_t_loss0),
                ('Train_step/train_s_t_step_loss0', s_t_loss0),
                ('Train_step/train_w_t_step_loss1', w_t_loss1),
                ('Train_step/train_s_t_step_loss1', s_t_loss1),
                ('Train_step/train_t_step_loss', t_loss),
            )
            for tag, value in step_scalars:
                tbw.add_scalar(tag, value.item(), global_step)

            # Running average of the combined loss shown in the progress bar.
            pbar.set_description('combo_loss %.5f' %
                                 (total_combo_loss / (step + 1)))