Code example #1
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, w_imgs,  captions, cap_lens, class_ids, keys, \
                wrong_caps, wrong_caps_len, wrong_cls_id, _,_  = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0.item() + w_loss1.item())

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0.item() + s_loss1.item())

        if step == 50:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    return s_cur_loss, w_cur_loss
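
Several of these snippets (for example #1, #2, #6, #8 and #10) read a `labels` tensor that is never defined inside `evaluate`; in the source projects it is a module- or trainer-level tensor of "match labels" 0..batch_size-1, built the same way examples #11 and #18 build it explicitly. A minimal sketch of preparing it before calling the `evaluate` of code example #1 (the concrete batch size and the loader/encoder names are placeholders, not taken from any one project):

import torch

# Match labels: the i-th image is paired with the i-th caption, so both
# words_loss and sent_loss can use them as cross-entropy targets.
batch_size = 48                                  # placeholder; e.g. cfg.TRAIN.BATCH_SIZE
labels = torch.LongTensor(range(batch_size))     # [0, 1, ..., batch_size - 1]
if torch.cuda.is_available():
    labels = labels.cuda()

# evaluate() from code example #1 then picks `labels` up from this enclosing scope:
# s_loss, w_loss = evaluate(val_loader, image_encoder, text_encoder, batch_size)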
Code example #2
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / step
    w_cur_loss = w_total_loss.item() / step

    return s_cur_loss, w_cur_loss
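
All of these `evaluate` variants delegate the matching objective to `sent_loss` and `words_loss`, which are defined elsewhere in their projects. For orientation, here is a minimal sketch of an AttnGAN/DAMSM-style `sent_loss`: bidirectional image-to-text and text-to-image cross-entropy over scaled cosine similarities. The `gamma3` smoothing factor and the masking of same-class negatives follow the standard DAMSM formulation and are assumptions, not code copied from these projects.

import torch
import torch.nn.functional as F

def sent_loss(cnn_code, rnn_code, labels, class_ids, batch_size,
              gamma3=10.0, eps=1e-8):
    """Sketch of the DAMSM sentence loss over one batch.

    cnn_code, rnn_code: batch_size x nef global image / sentence embeddings.
    """
    # Cosine-similarity matrix between every image and every sentence.
    cnn_norm = cnn_code / cnn_code.norm(dim=1, keepdim=True).clamp(min=eps)
    rnn_norm = rnn_code / rnn_code.norm(dim=1, keepdim=True).clamp(min=eps)
    scores0 = gamma3 * cnn_norm @ rnn_norm.t()          # batch_size x batch_size

    if class_ids is not None:
        # Off-diagonal pairs that share a class id are not true mismatches.
        mask = torch.zeros(batch_size, batch_size, dtype=torch.bool,
                           device=scores0.device)
        for i in range(batch_size):
            for j in range(batch_size):
                if i != j and class_ids[i] == class_ids[j]:
                    mask[i, j] = True
        scores0 = scores0.masked_fill(mask, float('-inf'))

    scores1 = scores0.t()
    loss0 = F.cross_entropy(scores0, labels)            # image -> text
    loss1 = F.cross_entropy(scores1, labels)            # text -> image
    return loss0, loss1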
Code example #3
def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, _, _, _, keys, _ = prepare_data(data)
        
        class_ids = None # ok?
        words_features, sent_code = cnn_model(real_imgs[-1])
        
        # num_caps = (cap_lens > 0).sum(1)
        per_qa_embs, avg_qa_embs, num_caps = Level1RNNEncodeMagic(captions, cap_lens, rnn_model)

        w_loss0, w_loss1, attn = words_loss(words_features, per_qa_embs, labels,
                                            num_caps, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, avg_qa_embs, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / step
    w_cur_loss = w_total_loss.item() / step

    return s_cur_loss, w_cur_loss
Code example #4
def evaluate(dataloader, cnn_model, rnn_model, batch_size, writer, count,
             ixtoword, labels, image_dir):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        nef, att_sze = words_features.size(1), words_features.size(2)

        # hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).item()

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 50:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    writer.add_scalars(main_tag="eval_loss",
                       tag_scalar_dict={
                           's_loss': s_cur_loss,
                           'w_loss': w_cur_loss
                       },
                       global_step=count)
    # save an image
    # attention Maps
    img_set, _ = \
        build_super_images(real_imgs[-1][:,:3].cpu(), captions,
                           ixtoword, attn, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/attention_maps_eval_%d.png' % (image_dir, count)
        im.save(fullpath)
        writer.add_image(tag="image_DAMSM_eval",
                         img_tensor=transforms.ToTensor()(im),
                         global_step=count)
    return s_cur_loss, w_cur_loss
Code example #5
File: trainer_s2.py  Project: mshaikh2/MMRL
    def evaluate(self, dataloader, cnn_model, trx_model, cap_model,
                 batch_size):
        cnn_model.eval()
        trx_model.eval()
        cap_model.eval()  ###
        s_total_loss = 0
        w_total_loss = 0
        c_total_loss = 0  ###
        ### add caption criterion here. #####
        cap_criterion = torch.nn.CrossEntropyLoss(
        )  # add caption criterion here
        if cfg.CUDA:
            cap_criterion = cap_criterion.cuda()  # add caption criterion here
        cap_criterion.eval()
        #####################################
        for step, data in enumerate(dataloader, 0):
            real_imgs, captions, cap_lens, class_ids, keys, cap_imgs, cap_img_masks, sentences, sent_masks = prepare_data(
                data)

            words_features, sent_code = cnn_model(cap_imgs)

            words_emb, sent_emb = trx_model(captions)

            ##### add catr here #####
            cap_preds = cap_model(
                words_features, cap_img_masks, sentences[:, :-1],
                sent_masks[:, :-1])  # caption model feedforward

            cap_loss = caption_loss(cap_criterion, cap_preds, sentences)

            c_total_loss += cap_loss.data
            #########################

            w_loss0, w_loss1, attn = words_loss(words_features, words_emb,
                                                labels, cap_lens, class_ids,
                                                batch_size)
            w_total_loss += (w_loss0 + w_loss1).data

            s_loss0, s_loss1 = \
                sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
            s_total_loss += (s_loss0 + s_loss1).data

#             if step == 50:
#                 break

        s_cur_loss = s_total_loss / step
        w_cur_loss = w_total_loss / step
        c_cur_loss = c_total_loss / step

        return s_cur_loss, w_cur_loss, c_cur_loss
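
Code example #5 additionally scores a caption decoder (`cap_model`) with `caption_loss`, which is not shown here. One plausible minimal definition, assuming CATR-style teacher forcing in which the decoder is fed `sentences[:, :-1]` and must predict the shifted targets `sentences[:, 1:]` (an assumption, not confirmed by the project):

def caption_loss(cap_criterion, cap_preds, sentences):
    """Sketch: next-token cross-entropy for the captioning head.

    cap_preds:  batch x (seq_len - 1) x vocab_size logits from sentences[:, :-1].
    sentences:  batch x seq_len token ids; the targets are the shifted tokens.
    """
    targets = sentences[:, 1:]
    return cap_criterion(cap_preds.reshape(-1, cap_preds.size(-1)),
                         targets.reshape(-1))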
Code example #6
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0

    count = 0
    img_lst = []
    sent_lst = []
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        count += batch_size

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        img_lst.append(sent_code)
        sent_lst.append(sent_emb)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if count >= 1000:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    img_embs = torch.cat(img_lst)
    sent_embs = torch.cat(sent_lst)

    acc, pred = compute_topk(img_embs, sent_embs)

    logger.info(
        '| end epoch {:3d} | top-5 ({:4d}) {:5.2f} valid loss {:5.2f} {:5.2f} | lr {:.5f}|'
        .format(epoch, count, acc, s_cur_loss, w_cur_loss, lr))

    return s_cur_loss, w_cur_loss
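
Code example #6 also reports a top-5 retrieval score via `compute_topk` over the pooled sentence and image embeddings. A plausible sketch, assuming the i-th sentence's ground-truth match is the i-th image and cosine similarity is the ranking score (the project's own helper may differ):

import torch
import torch.nn.functional as F

def compute_topk(img_embs, sent_embs, k=5):
    """Sketch: sentence -> image retrieval accuracy (in percent) at rank k."""
    img_embs = F.normalize(img_embs, dim=1)
    sent_embs = F.normalize(sent_embs, dim=1)
    sims = sent_embs @ img_embs.t()                       # n x n cosine similarities
    pred = sims.topk(k, dim=1).indices                    # n x k candidate image indices
    target = torch.arange(sims.size(0), device=sims.device).unsqueeze(1)
    acc = (pred == target).any(dim=1).float().mean().item() * 100.0
    return acc, pred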
Code example #7
File: train_DAMSM.py  Project: axis-bit/SpeechFab
def evaluate(dataloader, cnn_model, rnn_model, batch_size, exp, epoch):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        shape, cap, cap_len, cls_id, key = data

        sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)

        shapes = shape[sorted_cap_indices].squeeze()
        captions = cap[sorted_cap_indices].squeeze()
        class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

        if torch.cuda.is_available():
            shapes = shapes.cuda()
            captions = captions.cuda()

        words_features, sent_code = cnn_model(shapes)
        nef, att_sze = words_features.size(1), words_features.size(2)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, sorted_cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, sorted_cap_lens,
                                                 class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).item()

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 10:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    exp.log_metric('eva_s_cur_loss', s_cur_loss, epoch=epoch)
    exp.log_metric('eva_w_cur_loss', w_cur_loss, epoch=epoch)

    return s_cur_loss, w_cur_loss
Code example #8
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()

    print("** rnn structure **", rnn_model.rnn)
    print("** embedd structure **", rnn_model.encoder)

    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])

        print("valid captions", captions.size())
        print("valid cap_lens", cap_lens.size())
        print("valid word features", words_features.size())
        print("valid sent_code", sent_code.size())
        print(cap_lens)

        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / step
    w_cur_loss = w_total_loss.item() / step

    return s_cur_loss, w_cur_loss
Code example #9
File: trainer.py  Project: axis-bit/SpeechFab
    def _generator_train_step(self, sent_emb, words_emb, cap_lens, step,
                              epoch):

        noise = torch.randn(self.batch_size, self.nz).to(self.device)
        fake_data, mu, logvar = self.netG(noise, sent_emb)
        fake_data = fake_data.to(self.device)

        if False:
            region_features, cnn_code = self.image_encoder(fake_data)

            match_labels = Variable(torch.LongTensor(range(self.batch_size)))
            s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, match_labels,
                                         None, self.batch_size)
            s_loss = (s_loss0 + s_loss1) * 1
            w_loss0, w_loss1, _ = words_loss(region_features, words_emb,
                                             match_labels, cap_lens, None,
                                             self.batch_size)
            w_loss = (w_loss0 + w_loss1) * 1

            g_out = self.netD(fake_data, sent_emb)
            kl_loss = KL_loss(mu, logvar)
            loss = self.gan_loss(g_out, "gen") + kl_loss + s_loss + w_loss

            self.exp.log_metric('kl_loss',
                                kl_loss.item(),
                                step=step,
                                epoch=epoch)
            self.exp.log_metric('s_loss',
                                s_loss.item(),
                                step=step,
                                epoch=epoch)

        else:
            g_out = self.netD(fake_data, sent_emb)
            kl_loss = KL_loss(mu, logvar)
            loss = self.gan_loss(g_out, "gen") + kl_loss

        self.optG.zero_grad()
        loss.backward()
        self.optG.step()
        return loss.item()
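
The generator update in code example #9 (and the trainers in examples #17 and #18) adds `KL_loss(mu, logvar)` from the conditioning-augmentation branch of the generator. For reference, a sketch of the usual StackGAN/AttnGAN term, the KL divergence between N(mu, exp(logvar)) and a standard normal prior averaged over the batch (assumed here, not copied from these projects):

import torch

def KL_loss(mu, logvar):
    """Sketch: KL( N(mu, sigma^2) || N(0, I) ), averaged over batch and dimensions."""
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())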
Code example #10
def evaluate(dataloader, cnn_model, nlp_model, text_encoder_type, batch_size):
    cnn_model.eval()
    nlp_model.eval()
    text_encoder_type = text_encoder_type.casefold()
    assert text_encoder_type in (
        'rnn',
        'transformer',
    )
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data( data )

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        if text_encoder_type == 'rnn':
            hidden = nlp_model.init_hidden(batch_size)
            words_emb, sent_emb = nlp_model(captions, cap_lens, hidden)
        elif text_encoder_type == 'transformer':
            words_emb = nlp_model(captions)[0].transpose(1, 2).contiguous()
            sent_emb = words_emb[:, :, -1].contiguous()
            # sent_emb = sent_emb.view(batch_size, -1)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss( sent_code, sent_emb, labels, class_ids, batch_size )
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / step
    w_cur_loss = w_total_loss.item() / step

    return s_cur_loss, w_cur_loss
Code example #11
def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, class_ids = prepare_data(data)
        words_features, sent_code = cnn_model(real_imgs[-1])
        if step == len(dataloader)-1:
          batch_size = len(subset_val)-(len(dataloader)-1)*batch_size
          labels = Variable(torch.LongTensor(range(batch_size)))
          labels = labels.cuda()          
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data
        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data
    s_cur_loss = s_total_loss.item() / len(dataloader)
    w_cur_loss = w_total_loss.item() / len(dataloader)
    return s_cur_loss, w_cur_loss
Code example #12
def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    t_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        #imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
        imgs, captions, cap_lens, class_ids, keys = prepare_data_bert(data, tokenizer=None)

        #words_features, sent_code, word_logits = cnn_model(imgs[-1], captions)
        words_features, sent_code, word_logits = cnn_model(imgs[-1], captions, cap_lens)
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        t_loss = image_to_text_loss(word_logits, captions)
        t_total_loss += t_loss.data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / step
    w_cur_loss = w_total_loss.item() / step
    t_cur_loss = t_total_loss.item() / step

    return s_cur_loss, w_cur_loss, t_cur_loss
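
Code example #12's image encoder additionally emits `word_logits`, an image-to-text prediction head scored by `image_to_text_loss`. That helper is not shown; one plausible definition, assuming per-position vocabulary logits scored against the caption tokens with index 0 treated as padding (both assumptions, not confirmed by the source):

import torch.nn.functional as F

def image_to_text_loss(word_logits, captions):
    """Sketch: cross-entropy between predicted word logits and caption tokens.

    word_logits: batch x seq_len x vocab_size
    captions:    batch x seq_len token ids (0 assumed to be the padding index)
    """
    return F.cross_entropy(word_logits.reshape(-1, word_logits.size(-1)),
                           captions.reshape(-1), ignore_index=0)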
Code example #13
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir, writer, logger, update_interval):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)
        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        # hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        global_step = epoch * len(dataloader) + step
        writer.add_scalars(main_tag="batch_loss",
                           tag_scalar_dict={
                               "loss": loss.cpu().item(),
                               "w_loss0": w_loss0.cpu().item(),
                               "w_loss1": w_loss1.cpu().item(),
                               "s_loss0": s_loss0.cpu().item(),
                               "s_loss1": s_loss1.cpu().item()
                           },
                           global_step=global_step)

        if step % update_interval == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / update_interval
            s_cur_loss1 = s_total_loss1 / update_interval

            w_cur_loss0 = w_total_loss0 / update_interval
            w_cur_loss1 = w_total_loss1 / update_interval

            elapsed = time.time() - start_time
            logger.info(
                '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                's_loss {:6.4f} {:6.4f} | '
                'w_loss {:6.4f} {:6.4f}'.format(
                    epoch, step, len(dataloader),
                    elapsed * 1000. / update_interval, s_cur_loss0,
                    s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
        if global_step % (10 * update_interval) == 0:
            img_set, _ = \
                build_super_images(imgs[-1][:,:3].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, count)
                im.save(fullpath)
                writer.add_image(tag="image_DAMSM",
                                 img_tensor=transforms.ToTensor()(im),
                                 global_step=count)
    return count
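
A `train`/`evaluate` pair like code example #13's is typically driven by a small epoch loop during DAMSM pretraining. The sketch below wires example #13's `train` to example #4's `evaluate`; the optimizer settings and the config names `cfg.TRAIN.ENCODER_LR` and `cfg.TRAIN.MAX_EPOCH` follow the AttnGAN convention and, like the loader and writer names, are assumptions rather than code from these projects:

import torch

# Assumes cnn_model, rnn_model, train_loader, val_loader, writer, logger,
# ixtoword and image_dir have been built as in the snippets above.
labels = torch.LongTensor(range(batch_size))
if torch.cuda.is_available():
    labels = labels.cuda()

para = list(rnn_model.parameters()) + \
       [p for p in cnn_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))

for epoch in range(cfg.TRAIN.MAX_EPOCH):
    count = train(train_loader, cnn_model, rnn_model, batch_size, labels,
                  optimizer, epoch, ixtoword, image_dir, writer, logger,
                  update_interval=100)
    s_loss, w_loss = evaluate(val_loader, cnn_model, rnn_model, batch_size,
                              writer, count, ixtoword, labels, image_dir)
    logger.info('| epoch {:3d} | eval s_loss {:6.4f} | eval w_loss {:6.4f}'.format(
        epoch, s_loss, w_loss))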
Code example #14
File: train_DAMSM.py  Project: axis-bit/SpeechFab
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir, exp):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    for step, data in enumerate(dataloader):
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        shape, cap, cap_len, cls_id, key = data
        sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)

        #sort
        shapes = shape[sorted_cap_indices].squeeze()
        captions = cap[sorted_cap_indices].squeeze()
        cap_len = cap_len[sorted_cap_indices].squeeze()
        class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

        if torch.cuda.is_available():
            shapes = shapes.cuda()
            captions = captions.cuda()
        #model
        words_features, sent_code = cnn_model(shapes)

        nef, att_sze = words_features.size(1), words_features.size(2)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, sorted_cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, sorted_cap_lens,
                                                 class_ids, batch_size)

        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)

        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time

            exp.log_metric('s_cur_loss0', s_cur_loss0, step=step, epoch=epoch)
            exp.log_metric('s_cur_loss1', s_cur_loss1, step=step, epoch=epoch)

            exp.log_metric('w_cur_loss0', w_cur_loss0, step=step, epoch=epoch)
            exp.log_metric('w_cur_loss1', w_cur_loss1, step=step, epoch=epoch)

            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
        if step == 1:
            fullpath = '%s/attention_maps%d' % (image_dir, step)
            build_super_images(shapes.cpu().detach().numpy(), captions,
                               cap_len, ixtoword, attn_maps, att_sze, exp,
                               fullpath, epoch)

    return count
Code example #15
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        """
        Sets gradients of all model parameters to zero.
        Every time a variable is back propogated through, the gradient will be accumulated instead of being replaced.
        (This makes it easier for rnn, because each module will be back propogated through several times.)
        """
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
        class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
        """Dont understand completely ??"""
        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))

            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = build_super_images(imgs[-1].cpu(), captions, ixtoword,
                                            attn_maps, att_sze)

            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps_%d_epoch_%d_step.png' % (
                    image_dir, epoch, step)
                im.save(fullpath)
    return count
Code example #16
    def evaluate(self, split_dir, hmap_size):
        text_encoder, image_encoder, netG, netShpG = self.build_models()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise_img = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise_img = noise_img.cuda()

        if cfg.TEST.USE_GT_BOX_SEG > 0:
            noise_shp = Variable(
                torch.FloatTensor(batch_size, cfg.ROI.BOXES_NUM,
                                  len(self.cats_index_dict) * 4))
            if cfg.CUDA:
                noise_shp = noise_shp.cuda()

        match_labels = self.prepare_labels()
        clabels_emb = self.prepare_cat_emb()

        predictions, fake_acts_set, acts_set, w_accuracy, s_accuracy = [], [], [], [], []
        region_features_set, cnn_code_set, words_embs_set, sent_emb_set, \
            class_ids_set, cap_lens_set = [], [], [], [], [], []
        gen_iterations = 0
        rp_count = 0

        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (1) Prepare general test data
                #######################################################
                if cfg.TEST.USE_GT_BOX_SEG < 2:
                    imgs, acts, captions, glove_captions, cap_lens, gt_hmaps, bbox_maps_fwd, \
                        bbox_maps_bwd, bbox_fmaps, rois, fm_rois, num_rois, gt_bt_masks, \
                        gt_fm_bt_masks, class_ids, keys, sent_ids = prepare_data(data)
                else:
                    imgs, acts, captions, glove_captions, cap_lens, bbox_maps_fwd, \
                        bbox_maps_bwd, bbox_fmaps, rois, fm_rois, num_rois, \
                        class_ids, keys, sent_ids = prepare_gen_data(data)
                    gt_hmaps = None

                #######################################################
                # (2) Prepare real shapes or generate fake shapes
                #######################################################
                batch_size = len(num_rois)
                max_num_roi = int(torch.max(num_rois))
                noise_img = noise_img[:batch_size].data.normal_(0, 1)

                if cfg.TEST.USE_GT_BOX_SEG > 0:  # 1 for gt box and gen shape, 2 for gen box and gen shape
                    noise_shp = noise_shp[:batch_size].data.normal_(0, 1)
                    raw_masks = netShpG(noise_shp[:, :max_num_roi],
                                        bbox_maps_fwd, bbox_maps_bwd,
                                        bbox_fmaps)
                    raw_masks = raw_masks.squeeze(2).detach()
                    if gen_iterations % self.display_interval == 0:
                        self.save_shape_results(imgs[0],
                                                raw_masks,
                                                rois[0],
                                                num_rois,
                                                gen_iterations,
                                                model_type='G')
                        if gt_hmaps is not None:
                            self.save_shape_results(imgs[0],
                                                    gt_hmaps[0].squeeze(),
                                                    rois[0],
                                                    num_rois,
                                                    gen_iterations,
                                                    model_type='D')
                    gen_hmaps, gen_bt_masks, gen_fm_bt_masks = form_hmaps(
                        raw_masks, num_rois, rois[0], hmap_size,
                        len(self.cats_index_dict))
                    hmaps = gen_hmaps
                    bt_masks = gen_bt_masks
                    fm_bt_masks = gen_fm_bt_masks
                else:  # 0 for gt box and gt shape
                    hmaps = gt_hmaps
                    bt_masks = gt_bt_masks
                    fm_bt_masks = gt_fm_bt_masks

                #######################################################
                # (3) Prepare or compute text embeddings
                #######################################################
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                max_len = int(torch.max(cap_lens))
                words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                    max_len)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # glove_words_embs: batch_size x 50 (glove dim) x seq_len
                glove_words_embs = self.glove_emb(glove_captions.view(-1))
                glove_words_embs = glove_words_embs.detach().view(
                    glove_captions.size(0), glove_captions.size(1), -1)
                glove_words_embs = glove_words_embs[:, :num_words].transpose(
                    1, 2)
                # clabels_feat: batch x 50 (glove dim) x max_num_roi x 1
                clabels_feat = form_clabels_feat(clabels_emb, rois[0],
                                                 num_rois)

                #######################################################
                # (4) Generate fake images
                #######################################################
                fake_imgs, _, attn_maps, bt_attn_maps, mu, logvar = netG(
                    noise_img, sent_emb, words_embs, glove_words_embs,
                    clabels_feat, mask, hmaps, rois, fm_rois, num_rois,
                    bt_masks, fm_bt_masks, max_num_roi)

                if gen_iterations % self.display_interval == 0:
                    if cfg.TEST.SAVE_OPTIONS == 'SNAPSHOT':
                        self.save_img_results(fake_imgs, attn_maps,
                                              bt_attn_maps, captions, cap_lens,
                                              gen_iterations)
                    elif cfg.TEST.SAVE_OPTIONS == 'IMAGE':
                        self.save_singleimages(fake_imgs[-1], keys, sent_ids)
                    print('%d / %d' % (gen_iterations, self.num_batches))

                #######################################################
                # (5) Prepare intermediate results for evaluation
                #######################################################
                images = fake_imgs[-1].detach()

                region_features, cnn_code = image_encoder(images)
                region_features, cnn_code = region_features.detach(), cnn_code.detach()

                if rp_count >= cfg.TEST.RP_POOL_SIZE:
                    region_features_set = torch.cat(region_features_set, dim=0)
                    region_features_set = region_features_set[:cfg.TEST.RP_POOL_SIZE]
                    cnn_code_set = torch.cat(cnn_code_set, dim=0)
                    cnn_code_set = cnn_code_set[:cfg.TEST.RP_POOL_SIZE]

                    sent_emb_set = torch.cat(sent_emb_set, dim=0)
                    sent_emb_set = sent_emb_set[:cfg.TEST.RP_POOL_SIZE]
                    class_ids_set = np.concatenate(class_ids_set, 0)
                    class_ids_set = class_ids_set[:cfg.TEST.RP_POOL_SIZE]
                    cap_lens_set = torch.cat(cap_lens_set, dim=0)
                    cap_lens_set = cap_lens_set[:cfg.TEST.RP_POOL_SIZE]

                    max_len = int(torch.max(cap_lens_set))
                    new_words_embs_set = torch.zeros(rp_count,
                                                     sent_emb_set.size(1),
                                                     max_len)
                    accum = 0
                    for tmp_words_embs in words_embs_set:
                        tmp_bs, tmp_max_len = tmp_words_embs.size(0), tmp_words_embs.size(2)
                        new_words_embs_set[accum:accum + tmp_bs, :, :tmp_max_len] = tmp_words_embs
                        accum += tmp_bs
                    new_words_embs_set = new_words_embs_set[:cfg.TEST.RP_POOL_SIZE]

                    _, _, _, w_accu = words_loss(region_features_set,
                                                 new_words_embs_set,
                                                 match_labels,
                                                 cap_lens_set,
                                                 class_ids_set,
                                                 cfg.TEST.RP_POOL_SIZE,
                                                 is_training=False)
                    _, _, s_accu = sent_loss(cnn_code_set,
                                             sent_emb_set,
                                             match_labels,
                                             class_ids_set,
                                             cfg.TEST.RP_POOL_SIZE,
                                             is_training=False)
                    w_accuracy.append(w_accu)
                    s_accuracy.append(s_accu)

                    rp_count = 0
                    region_features_set, cnn_code_set, words_embs_set, sent_emb_set, \
                        class_ids_set, cap_lens_set = [], [], [], [], [], []
                else:
                    region_features_set.append(region_features.cpu())
                    cnn_code_set.append(cnn_code.cpu())
                    words_embs_set.append(words_embs.cpu())
                    sent_emb_set.append(sent_emb.cpu())
                    class_ids_set.append(class_ids)
                    cap_lens_set.append(cap_lens.cpu())
                    rp_count += batch_size

                if cfg.TEST.USE_TF:
                    denorm_images = denorm_imgs(images)
                    pred = self.inception_score.get_inception_pred(
                        denorm_images)
                else:
                    pred = self.inception_model(images)
                    pred = pred.data.cpu().numpy()
                predictions.append(pred)

                if cfg.TEST.USE_TF:
                    fake_acts = self.inception_score.get_fid_pred(
                        denorm_images)
                else:
                    fake_acts = get_activations(images,
                                                self.inception_model_fid,
                                                batch_size)
                acts_set.append(acts)
                fake_acts_set.append(fake_acts)

                gen_iterations += 1
                if gen_iterations >= cfg.TEST.TEST_IMG_NUM:
                    break

        if cfg.TEST.USE_TF:
            self.inception_score.close_sess()

        #######################################################
        # (6) Evaluation
        #######################################################

        predictions = np.concatenate(predictions, 0)
        mean, std = compute_inception_score(predictions,
                                            min(10, self.batch_size))
        mean_conf, std_conf = \
            negative_log_posterior_probability(predictions, min(10, self.batch_size))
        accu_w, std_w, accu_s, std_s = np.mean(w_accuracy), np.std(
            w_accuracy), np.mean(s_accuracy), np.std(s_accuracy)

        acts_set = np.concatenate(acts_set, 0)
        fake_acts_set = np.concatenate(fake_acts_set, 0)
        real_mu, real_sigma = calculate_activation_statistics(acts_set)
        fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)
        fid_score = calculate_frechet_distance(real_mu, real_sigma, fake_mu,
                                               fake_sigma)

        fullpath = '%s/scores.txt' % (self.score_dir)
        with open(fullpath, 'w') as fp:
            fp.write(
                'mean, std, mean_conf, std_conf, accu_w, std_w, accu_s, std_s, fid_score \n'
            )
            fp.write('%f, %f, %f, %f, %f, %f, %f, %f, %f' %
                     (mean, std, mean_conf, std_conf, accu_w, std_w, accu_s,
                      std_s, fid_score))

        print(
            'inception_score: mean, std, mean_conf, std_conf, accu_w, std_w, accu_s, std_s, fid_score'
        )
        print('inception_score: %f, %f, %f, %f, %f, %f, %f, %f, %f' %
              (mean, std, mean_conf, std_conf, accu_w, std_w, accu_s, std_s,
               fid_score))
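
Code example #16 pools Inception predictions and activations and then calls `compute_inception_score` along with the FID helpers. For orientation, a sketch of the standard Inception Score computation over the pooled predictions, assuming they are softmax probabilities; the split-and-average scheme is the usual one, and the project's own helper may differ in details:

import numpy as np

def compute_inception_score(predictions, num_splits=10):
    """Sketch: IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ), as mean/std over splits."""
    scores = []
    for part in np.array_split(predictions, num_splits):
        py = np.mean(part, axis=0, keepdims=True)               # marginal p(y)
        kl = part * (np.log(part + 1e-12) - np.log(py + 1e-12))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return float(np.mean(scores)), float(np.std(scores))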
Code example #17
File: trainer.py  Project: axis-bit/t2v.pytorch
    def train(self):
        text_encoder, image_encoder = self.build_models()
        netG, netD = self.netG, self.netD
        optimizerG = self.optimizerG
        optimizerD = self.optimizerD
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        start_epoch = 0

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz)).to(device)
        fixed_noise = Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)).to(device)

        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            epoch = epoch + self.start_epoch
            for step, data in enumerate(self.data_loader):
                start_t = time.time()

                shape, cap, cap_len, cls_id, key = data

                sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)

                #sort
                shapes = shape[sorted_cap_indices].squeeze().to(device)
                captions = cap[sorted_cap_indices].squeeze().to(device)
                class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

                hidden = text_encoder.init_hidden(self.batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, sorted_cap_lens, hidden)
                # words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_shapes, mu, logvar = netG(noise, sent_emb)

                #######################################################
                # (3) Update D network
                ######################################################

                real_labels = torch.FloatTensor(self.batch_size).fill_(1).to(device)
                fake_labels = torch.FloatTensor(self.batch_size).fill_(0).to(device)

                netD.zero_grad()

                real_features = netD(shapes).to(device)
                cond_real_errD = nn.BCELoss()(real_features, real_labels)
                fake_features = netD(fake_shapes).to(device)
                cond_fake_errD = nn.BCELoss()(fake_features, fake_labels)



                errD_total = cond_real_errD + cond_fake_errD / 2.

                d_real_acu = torch.ge(real_features.squeeze(), 0.5).float()
                d_fake_acu = torch.le(fake_features.squeeze(), 0.5).float()
                d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu),0))

                if d_total_acu < 0.85:
                    errD_total.backward(retain_graph=True)
                    optimizerD.step()

                # #######################################################
                # # (4) Update G network: maximize log(D(G(z)))
                # ######################################################
                # # compute total loss for training G
                # step += 1
                # gen_iterations += 1

                # # do not need to compute gradient for Ds
                # # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                # errG_total, G_logs = \
                #     generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                #                    words_embs, sent_emb, match_labels, cap_lens, class_ids)

                labels = Variable(torch.LongTensor(range(self.batch_size))).to(device)
                real_labels = torch.FloatTensor(self.batch_size).fill_(1).to(device)

                real_features = netD(fake_shapes)
                cond_real_errG = nn.BCELoss()(real_features, real_labels)

                kl_loss = KL_loss(mu, logvar)
                errG_total = kl_loss + cond_real_errG



                if step % 10 == 0:
                    region_features, cnn_code = image_encoder(fake_shapes)

                    w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
                                                    labels, sorted_cap_lens,
                                                    class_ids, self.batch_size)
                    w_loss = (w_loss0 + w_loss1) * \
                        cfg.TRAIN.SMOOTH.LAMBDA

                    s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
                                                labels, class_ids, self.batch_size)
                    s_loss = (s_loss0 + s_loss1) * \
                        cfg.TRAIN.SMOOTH.LAMBDA

                    errG_total += s_loss + w_loss
                    self.exp.metric('s_loss', s_loss.item())
                    self.exp.metric('w_loss', w_loss.item())

                # print('kl: %.2f w s, %.2f %.2f, cond %.2f' % (kl_loss.item(), w_loss.item(), s_loss.item(), cond_real_errG.item()))
                # # backward and update parameters
                errG_total.backward()


                optimizerG.step()

                end_t = time.time()

                self.exp.metric('d_loss', errD_total.item())
                self.exp.metric('g_loss', errG_total.item())
                self.exp.metric('act', d_total_acu.item())

                print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                    % (epoch, self.max_epoch, step, self.num_batches,
                        errD_total.item(), errG_total.item(),
                        end_t - start_t))

                if step % 500 == 0:
                    fullpath = '%s/lean_%d_%d.png' % (self.image_dir,epoch, step)
                    build_images(fake_shapes, captions, self.ixtoword, fullpath)

            torch.save(netG.state_dict(),'%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
            torch.save(netD.state_dict(),'%s/netD_epoch_%d.pth' % (self.model_dir, epoch))
            print('Save G/Ds models.')
Code example #18
File: trainer_s2.py  Project: mshaikh2/MMRL
    def train(self):

        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

        tb_dir = '../tensorboard/{0}_{1}_{2}'.format(cfg.DATASET_NAME,
                                                     cfg.CONFIG_NAME,
                                                     timestamp)
        mkdir_p(tb_dir)
        tbw = SummaryWriter(log_dir=tb_dir)  # Tensorboard logging

        text_encoder, image_encoder, netG, netsD, start_epoch, cap_model = self.build_models(
        )
        labels = Variable(torch.LongTensor(range(
            self.batch_size)))  # used for matching loss

        text_encoder.train()
        image_encoder.train()
        for k, v in image_encoder.named_children(
        ):  # set the input layer1-5 not training and no grads.
            if k in frozen_list_image_encoder:
                v.training = False
                v.requires_grad_(False)
        netG.train()
        for i in range(len(netsD)):
            netsD[i].train()
        cap_model.train()

        avg_param_G = copy_G_params(netG)
        optimizerI, optimizerT, optimizerG , optimizersD , optimizerC , lr_schedulerC \
        , lr_schedulerI , lr_schedulerT = self.define_optimizers(image_encoder
                                                                , text_encoder
                                                                , netG
                                                                , netsD
                                                                , cap_model)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))

        cap_criterion = torch.nn.CrossEntropyLoss(
        )  # add caption criterion here
        if cfg.CUDA:
            labels = labels.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            cap_criterion = cap_criterion.cuda()  # add caption criterion here
        cap_criterion.train()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):

            ##### set everything to trainable ####
            text_encoder.train()
            image_encoder.train()
            netG.train()
            cap_model.train()
            for k, v in image_encoder.named_children():
                if k in frozen_list_image_encoder:
                    v.train(False)
            for i in range(len(netsD)):
                netsD[i].train()
            ##### set everything to trainable ####

            fi_w_total_loss0 = 0
            fi_w_total_loss1 = 0
            fi_s_total_loss0 = 0
            fi_s_total_loss1 = 0
            ft_w_total_loss0 = 0
            ft_w_total_loss1 = 0
            ft_s_total_loss0 = 0
            ft_s_total_loss1 = 0
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            c_total_loss = 0

            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                print('step:{:6d}|{:3d}'.format(step, self.num_batches),
                      end='\r')
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)
                # add images, image masks, captions, caption masks for catr model
                imgs, captions, cap_lens, class_ids, keys, cap_imgs, cap_img_masks, sentences, sent_masks = prepare_data(
                    data)

                ################## feedforward damsm model ##################
                image_encoder.zero_grad()  # image/text encoders zero_grad here
                text_encoder.zero_grad()

                words_features, sent_code = image_encoder(
                    cap_imgs
                )  # input catr images to image encoder, feedforward, Nx256x17x17
                #                 words_features, sent_code = image_encoder(imgs[-1]) # input image_encoder
                nef, att_sze = words_features.size(1), words_features.size(2)
                # hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(
                    captions)  #, cap_lens, hidden)

                #### damsm losses
                w_loss0, w_loss1, attn_maps = words_loss(
                    words_features, words_embs, labels, cap_lens, class_ids,
                    batch_size)
                w_total_loss0 += w_loss0.data
                w_total_loss1 += w_loss1.data
                damsm_loss = w_loss0 + w_loss1

                s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                             class_ids, batch_size)
                s_total_loss0 += s_loss0.data
                s_total_loss1 += s_loss1.data
                damsm_loss += s_loss0 + s_loss1

                #                 damsm_loss.backward()

                #                 words_features = words_features.detach()
                # real image real text matching loss graph cleared here
                # grad accumulated -> text_encoder
                #                  -> image_encoder
                #################################################################################

                ################## feedforward image encoder and caption model ##################
                #                 words_features, sent_code = image_encoder(cap_imgs)
                cap_model.zero_grad()  # caption model zero_grad here

                cap_preds = cap_model(
                    words_features, cap_img_masks, sentences[:, :-1],
                    sent_masks[:, :-1])  # caption model feedforward
                cap_loss = caption_loss(cap_criterion, cap_preds, sentences)
                c_total_loss += cap_loss.data
                #                 cap_loss.backward() # caption loss graph cleared,
                # grad accumulated -> cap_model -> image_encoder
                torch.nn.utils.clip_grad_norm_(cap_model.parameters(),
                                               config.clip_max_norm)
                #                 optimizerC.step() # update cap_model params
                #################################################################################

                ############ Prepare the input to Gan from the output of text_encoder ################
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #                 f_img = np.asarray(fake_imgs[-1].permute((0,2,3,1)).detach().cpu())
                #                 print('fake_imgs.size():{0},fake_imgs.min():{1},fake_imgs.max():{2}'.format(fake_imgs[-1].size()
                #                                   ,fake_imgs[-1].min()
                #                                   ,fake_imgs[-1].max()))

                #                 print('f_img.shape:{0}'.format(f_img.shape))

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    print(i)
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # 14 -- 2800 iterations=steps for 1 epoch
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')

                #### temporary check ####
#                 if step == 5:
#                     break
#                 print('fake_img shape:',fake_imgs[-1].size())

#                 # this is fine #### exists in GeneratorLoss
#                 fake_imgs[-1] = fake_imgs[-1].detach()
#                 ####### fake imge real text matching loss #################
#                 fi_word_features, fi_sent_code = image_encoder(fake_imgs[-1])
# #                 words_embs, sent_emb = text_encoder(captions) # to update the text

#                 fi_w_loss0, fi_w_loss1, fi_attn_maps = words_loss(fi_word_features, words_embs, labels,
#                                                  cap_lens, class_ids, batch_size)

#                 fi_w_total_loss0 += fi_w_loss0.data
#                 fi_w_total_loss1 += fi_w_loss1.data

#                 fi_damsm_loss = fi_w_loss0 + fi_w_loss1

#                 fi_s_loss0, fi_s_loss1 = sent_loss(fi_sent_code, sent_emb, labels, class_ids, batch_size)

#                 fi_s_total_loss0 += fi_s_loss0.data
#                 fi_s_total_loss1 += fi_s_loss1.data

#                 fi_damsm_loss += fi_s_loss0 + fi_s_loss1

#                 fi_damsm_loss.backward()

###### real image fake text matching loss ##############

                fake_preds = torch.argmax(cap_preds,
                                          dim=-1)  # caption predictions
                fake_captions = tokenizer.batch_decode(
                    fake_preds.tolist(),
                    skip_special_tokens=True)  # list of strings
                fake_outputs = retokenizer.batch_encode_plus(
                    fake_captions,
                    max_length=64,
                    padding='max_length',
                    add_special_tokens=False,
                    return_attention_mask=True,
                    return_token_type_ids=False,
                    truncation=True)
                fake_tokens = fake_outputs['input_ids']
                #                 fake_tkmask = fake_outputs['attention_mask']
                f_tokens = np.zeros((len(fake_tokens), 15), dtype=np.int64)
                f_cap_lens = []
                cnt = 0
                for i in fake_tokens:
                    temp = np.array([x for x in i if x != 27299 and x != 0])  # drop padding (0) and one special token id
                    num_words = len(temp)
                    if num_words <= 15:
                        f_tokens[cnt][:num_words] = temp
                    else:
                        ix = list(np.arange(num_words))  # 0, 1, ..., num_words - 1
                        np.random.shuffle(ix)
                        ix = ix[:15]
                        ix = np.sort(ix)
                        f_tokens[cnt] = temp[ix]
                        num_words = 15
                    f_cap_lens.append(num_words)
                    cnt += 1

                f_tokens = Variable(torch.tensor(f_tokens))
                f_cap_lens = Variable(torch.tensor(f_cap_lens))
                if cfg.CUDA:
                    f_tokens = f_tokens.cuda()
                    f_cap_lens = f_cap_lens.cuda()

                ft_words_emb, ft_sent_emb = text_encoder(
                    f_tokens)  # input text_encoder

                ft_w_loss0, ft_w_loss1, ft_attn_maps = words_loss(
                    words_features, ft_words_emb, labels, f_cap_lens,
                    class_ids, batch_size)

                ft_w_total_loss0 += ft_w_loss0.data
                ft_w_total_loss1 += ft_w_loss1.data

                ft_damsm_loss = ft_w_loss0 + ft_w_loss1

                ft_s_loss0, ft_s_loss1 = sent_loss(sent_code, ft_sent_emb,
                                                   labels, class_ids,
                                                   batch_size)

                ft_s_total_loss0 += ft_s_loss0.data
                ft_s_total_loss1 += ft_s_loss1.data

                ft_damsm_loss += ft_s_loss0 + ft_s_loss1

                #                 ft_damsm_loss.backward()

                total_multimodal_loss = damsm_loss + ft_damsm_loss + cap_loss
                total_multimodal_loss.backward()
                ## loss = 0.5*loss1 + 0.4*loss2 + ...
                ## loss.backward() -> accumulate grad value in parameters.grad

                ## loss1 = 0.5*loss1
                ## loss1.backward()

                torch.nn.utils.clip_grad_norm_(image_encoder.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)

                optimizerI.step()

                torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)
                optimizerT.step()

                optimizerC.step()  # update cap_model params

            lr_schedulerC.step()
            lr_schedulerI.step()
            lr_schedulerT.step()

            end_t = time.time()

            tbw.add_scalar('Loss_D', float(errD_total.item()), epoch)
            tbw.add_scalar('Loss_G', float(errG_total.item()), epoch)
            tbw.add_scalar('train_w_loss0', float(w_total_loss0.item()), epoch)
            tbw.add_scalar('train_s_loss0', float(s_total_loss0.item()), epoch)
            tbw.add_scalar('train_w_loss1', float(w_total_loss1.item()), epoch)
            tbw.add_scalar('train_s_loss1', float(s_total_loss1.item()), epoch)
            tbw.add_scalar('train_c_loss', float(c_total_loss.item()), epoch)

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.data,
                   errG_total.data, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:

                self.save_model(netG, avg_param_G, image_encoder, text_encoder,
                                netsD, epoch, cap_model, optimizerC,
                                optimizerI, optimizerT, lr_schedulerC,
                                lr_schedulerI, lr_schedulerT)

            v_s_cur_loss, v_w_cur_loss, v_c_cur_loss = self.evaluate(
                self.dataloader_val, image_encoder, text_encoder, cap_model,
                self.batch_size)
            print(
                'v_s_cur_loss:{:.5f}, v_w_cur_loss:{:.5f}, v_c_cur_loss:{:.5f}'
                .format(v_s_cur_loss, v_w_cur_loss, v_c_cur_loss))
            tbw.add_scalar('val_w_loss', float(v_w_cur_loss), epoch)
            tbw.add_scalar('val_s_loss', float(v_s_cur_loss), epoch)
            tbw.add_scalar('val_c_loss', float(v_c_cur_loss), epoch)

        self.save_model(netG, avg_param_G, image_encoder, text_encoder, netsD,
                        self.max_epoch, cap_model, optimizerC, optimizerI,
                        optimizerT, lr_schedulerC, lr_schedulerI,
                        lr_schedulerT)
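
Note: the commented hints near the end of the loop above (loss = 0.5*loss1 + 0.4*loss2 ..., then a single loss.backward()) rest on the fact that backpropagating a weighted sum of losses leaves exactly the same accumulated gradients as backpropagating each weighted term separately, which is why damsm_loss + ft_damsm_loss + cap_loss can be backpropagated in one call before the encoder and caption optimizers step. A minimal self-contained sketch of that equivalence (the toy linear model and the 0.5/0.4 weights are illustrative, not taken from the code above):

import torch

# toy model and data, purely illustrative
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
loss1 = model(x).pow(2).mean()   # stands in for the word-level DAMSM loss
loss2 = model(x).abs().mean()    # stands in for the sentence / caption loss

# variant A: one backward pass over the weighted sum
model.zero_grad()
(0.5 * loss1 + 0.4 * loss2).backward(retain_graph=True)
grad_a = model.weight.grad.clone()

# variant B: separate backward calls; gradients accumulate in .grad
model.zero_grad()
(0.5 * loss1).backward(retain_graph=True)
(0.4 * loss2).backward()
grad_b = model.weight.grad.clone()

print(torch.allclose(grad_a, grad_b))  # True: both variants accumulate the same gradients
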
コード例 #19
0
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    train_function_start_time = time.time()
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0

    # print("keyword |||||||||||||||||||||||||||||||")
    # print("len(dataloader) : " , len(dataloader) )
    # print(" count = " ,  (epoch + 1) * len(dataloader)  )
    # print("keyword |||||||||||||||||||||||||||||||")
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # print ("====================================================")
            # print ("s_total_loss0 : " , s_total_loss0)
            # print ("s_total_loss0.item() : " , s_total_loss0.item())
            # print ("UPDATE_INTERVAL : " , UPDATE_INTERVAL)
            print("s_total_loss0.item()/UPDATE_INTERVAL : ",
                  s_total_loss0.item() / UPDATE_INTERVAL)
            print("s_total_loss1.item()/UPDATE_INTERVAL : ",
                  s_total_loss1.item() / UPDATE_INTERVAL)
            print("w_total_loss0.item()/UPDATE_INTERVAL : ",
                  w_total_loss0.item() / UPDATE_INTERVAL)
            print("w_total_loss1.item()/UPDATE_INTERVAL : ",
                  w_total_loss1.item() / UPDATE_INTERVAL)
            # print ("s_total_loss0/UPDATE_INTERVAL : " , s_total_loss0/UPDATE_INTERVAL)
            # print ("=====================================================")
            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            #Save image only every 8 epochs && Save it to The Drive
            if (epoch % 8 == 0):
                print("bulding images")
                img_set, _ = build_super_images(imgs[-1].cpu(), captions,
                                                ixtoword, attn_maps, att_sze)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                    im.save(fullpath)
                    mydriveimg = '/content/drive/My Drive/cubImage'
                    drivepath = '%s/attention_maps%d.png' % (mydriveimg, epoch)
                    im.save(drivepath)
    print("keyTime |||||||||||||||||||||||||||||||")
    print("train_function_time : ", time.time() - train_function_start_time)
    print("KeyTime |||||||||||||||||||||||||||||||")
    return count
コード例 #20
0
def train(dataloader, cnn_model, nlp_model, text_encoder_type, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    nlp_model.train()
    text_encoder_type = text_encoder_type.casefold()
    assert text_encoder_type in (
        'rnn',
        'transformer',
    )
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        nlp_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data( data )

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # print( words_features.shape, sent_code.shape )
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        # Forward Prop:
        # inputs:
        #   captions: torch.LongTensor of ids of size batch x n_steps
        # outputs:
        #   words_emb: batch_size x nef x seq_len
        #   sent_emb: batch_size x nef
        if text_encoder_type == 'rnn':
            hidden = nlp_model.init_hidden(batch_size)
            words_emb, sent_emb = nlp_model(captions, cap_lens, hidden)
        elif text_encoder_type == 'transformer':
            words_emb = nlp_model(captions)[0].transpose(1, 2).contiguous()
            sent_emb = words_emb[:, :, -1].contiguous()
            # sent_emb = sent_emb.view(batch_size, -1)
        # print( words_emb.shape, sent_emb.shape )

        # Compute Loss:
        # NOTE: the ideal loss for Transformer may be different than that for bi-directional LSTM
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss( sent_code, sent_emb, labels, class_ids, batch_size )
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        # Backprop:
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        if text_encoder_type == 'rnn':
            torch.nn.utils.clip_grad_norm_(nlp_model.parameters(),
                                           cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # print(  s_total_loss0, s_total_loss1 )
            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            # print(  w_total_loss0, w_total_loss1 )
            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.5f} {:5.5f} | '
                  'w_loss {:5.5f} {:5.5f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # Attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
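
In the 'transformer' branch above, the text encoder is assumed to return a HuggingFace-style tuple whose first element has shape batch_size x seq_len x hidden; the transpose rearranges it into the batch_size x nef x seq_len layout that words_loss expects, and the last position is reused as the sentence embedding. A shape-only sketch with random stand-in tensors (no real encoder involved):

import torch

batch_size, seq_len, nef = 4, 18, 256
hidden_states = torch.randn(batch_size, seq_len, nef)  # stand-in for nlp_model(captions)[0]

words_emb = hidden_states.transpose(1, 2).contiguous()  # batch_size x nef x seq_len
sent_emb = words_emb[:, :, -1].contiguous()             # batch_size x nef, last position only

print(words_emb.shape, sent_emb.shape)  # torch.Size([4, 256, 18]) torch.Size([4, 256])

Taking the last position as the sentence vector is a design choice of this example; mean-pooling over the non-padded positions is a common alternative.
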
コード例 #21
0
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    t_total_loss = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    if(cfg.LOCAL_PRETRAINED):
        tokenizer = tokenization.FullTokenizer(vocab_file=cfg.BERT_ENCODER.VOCAB, do_lower_case=True)
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    debug_flag = False
    #debug_flag = True

    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()
        if(debug_flag):
            with open('./debug0.pkl', 'wb') as f:
                pickle.dump({'data':data, 'cnn_model':cnn_model, 'rnn_model':rnn_model, 'labels':labels}, f)  

        #imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
        imgs, captions, cap_lens, class_ids, keys = prepare_data_bert(data, tokenizer)
        #imgs, captions, cap_lens, class_ids, keys, \
        #        input_ids, segment_ids, input_mask = prepare_data_bert(data, tokenizer)

        # sent_code: batch_size x nef
        #words_features, sent_code, word_logits = cnn_model(imgs[-1], captions)
        words_features, sent_code, word_logits = cnn_model(imgs[-1], captions, cap_lens)
        #words_features, sent_code, word_logits = cnn_model(imgs[-1], captions, input_ids, segment_ids, input_mask)
        # bs x T x vocab_size
        if(debug_flag):
            with open('./debug1.pkl', 'wb') as f:
                pickle.dump({'words_features':words_features, 'sent_code':sent_code, 'word_logits':word_logits}, f)  

        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        #words_emb, sent_emb = rnn_model(captions, cap_lens, hidden, input_ids, segment_ids, input_mask)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        if(debug_flag):
            with open('./debug2.pkl', 'wb') as f:
                pickle.dump({'words_features':words_features, 'words_emb':words_emb, 'labels':labels, 'cap_lens':cap_lens, 'class_ids':class_ids, 'batch_size':batch_size}, f)  

        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        if(debug_flag):
            with open('./debug3.pkl', 'wb') as f:
                pickle.dump({'sent_code':sent_code, 'sent_emb':sent_emb, 'labels':labels, 'class_ids':class_ids, 'batch_size':batch_size}, f)  

        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        # added code
        #print(word_logits.shape, captions.shape)
        t_loss = image_to_text_loss(word_logits, captions)
        if(debug_flag):
            with open('./debug4.pkl', 'wb') as f:
                pickle.dump({'word_logits':word_logits, 'captions':captions}, f)  

        loss += t_loss
        t_total_loss += t_loss.data

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            t_curr_loss = t_total_loss.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f} | '
                  't_loss {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1,
                          t_curr_loss))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            t_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
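
prepare_data_bert is not shown in this snippet; presumably it converts raw caption strings into BERT word-piece ids using the tokenizer created above. A minimal sketch of that tokenization step with the Hugging Face BertTokenizer (the example caption and the fixed length of 15 are illustrative only):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
caption = "this bird has a red head and a short yellow beak"  # illustrative caption

enc = tokenizer(caption,
                max_length=15,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
input_ids = enc['input_ids']            # 1 x 15 LongTensor of word-piece ids
attention_mask = enc['attention_mask']  # 1 where a real token is present
cap_len = int(attention_mask.sum())     # number of non-padding tokens

print(input_ids.shape, cap_len)
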
コード例 #22
0
def train(dataloader, cnn_model, trx_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):

    cnn_model.train()
    trx_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        print('step:{:6d}|{:3d}'.format(step, len(dataloader)), end='\r')
        trx_model.zero_grad()
        cnn_model.zero_grad()
        imgs, captions, cap_lens, class_ids, keys, _, _, _, _ = prepare_data(
            data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        #         print(words_features.shape,sent_code.shape)
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
        #         print('nef:{0},att_sze:{1}'.format(nef,att_sze))

        #         hidden = trx_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        #         print('captions:',captions, captions.size())

        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        #         words_emb, sent_emb = trx_model(captions, cap_lens, hidden)

        words_emb, sent_emb = trx_model(captions)
        #         print('words_emb:',words_emb.size(),', sent_emb:', sent_emb.size())
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)

        #         print(w_loss0.data)
        #         print('--------------------------')
        #         print(w_loss1.data)
        #         print('--------------------------')
        #         print(attn_maps[0].shape)

        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
        #         print(loss)
        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1

        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        #         print(s_total_loss0[0],s_total_loss1[0])
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(trx_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if (step % UPDATE_INTERVAL == 0
                or step == (len(dataloader) - 1)) and step > 0:
            count = epoch * len(dataloader) + step

            #             print(count)
            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            tbw.add_scalar('Birds_Train/train_w_loss0',
                           float(w_cur_loss0.item()), epoch)
            tbw.add_scalar('Birds_Train/train_s_loss0',
                           float(s_cur_loss0.item()), epoch)
            tbw.add_scalar('Birds_Train/train_w_loss1',
                           float(w_cur_loss1.item()), epoch)
            tbw.add_scalar('Birds_Train/train_s_loss1',
                           float(s_cur_loss1.item()), epoch)
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps

            #             print(imgs[-1].cpu().shape, captions.shape, len(attn_maps),attn_maps[-1].shape, att_sze)
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions, ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '{0}/attention_maps_e{1}_s{2}.png'.format(
                    image_dir, epoch, step)
                im.save(fullpath)
    return count
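
The tbw handle used for the Birds_Train/* scalars above is not defined in this snippet; it is presumably a module-level TensorBoard SummaryWriter. A minimal sketch of how such a writer could be created and used (the log directory is an illustrative placeholder):

from torch.utils.tensorboard import SummaryWriter

tbw = SummaryWriter(log_dir='output/tensorboard')  # illustrative path

# same logging pattern as the calls above
tbw.add_scalar('Birds_Train/train_w_loss0', 0.42, global_step=0)
tbw.flush()
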
コード例 #23
0
def train(dataloader, cnn_model, rnn_model, d_model, batch_size,
          labels, generator_optimizer, discriminator_optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    d_total_loss = 0
    g_total_loss = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        target_classes = torch.LongTensor(class_ids)
        if cfg.CUDA:
            target_classes = target_classes.cuda()
        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size)
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1

        is_fake_0, pred_class_0 = d_model(sent_code)
        is_fake_1, pred_class_1 = d_model(sent_emb)
        g_loss = (F.binary_cross_entropy_with_logits(is_fake_0, torch.zeros_like(is_fake_0))
                 + F.binary_cross_entropy_with_logits(is_fake_1, torch.zeros_like(is_fake_1))
                 + F.cross_entropy(pred_class_0, target_classes)
                 + F.cross_entropy(pred_class_1, target_classes))
        loss += g_loss * cfg.TRAIN.SMOOTH.SUPERVISED_COEF

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        g_total_loss += g_loss.item()
        generator_optimizer.step()
        d_model.zero_grad()

        _, sent_code = cnn_model(imgs[-1])
        hidden = rnn_model.init_hidden(batch_size)
        _, sent_emb = rnn_model(captions, cap_lens, hidden)

        is_fake_0, pred_class_0 = d_model(sent_code)
        is_fake_1, pred_class_1 = d_model(sent_emb)

        d_loss = (F.binary_cross_entropy_with_logits(is_fake_0, torch.zeros_like(is_fake_0))
                + F.binary_cross_entropy_with_logits(is_fake_1, torch.ones_like(is_fake_1))
                + F.cross_entropy(pred_class_0, target_classes)
                + F.cross_entropy(pred_class_1, target_classes))
        loss = d_loss
        loss.backward()
        discriminator_optimizer.step()
        d_total_loss += d_loss.item()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            d_cur_loss = d_total_loss / UPDATE_INTERVAL
            g_cur_loss = g_total_loss / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f} | d_loss {:5.2f} | g_loss {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1, d_cur_loss, g_cur_loss))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            d_total_loss = 0
            g_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
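
d_model in the example above is expected to map an nef-dimensional sentence embedding to a real/fake logit (consumed by binary_cross_entropy_with_logits) plus per-class logits (consumed by cross_entropy). Its definition is not included here; a minimal sketch of a module with that interface (layer sizes and the 200 classes are illustrative, roughly matching the CUB birds setting):

import torch
import torch.nn as nn

class SentenceDiscriminator(nn.Module):
    """Illustrative two-head discriminator matching the d_model interface used above."""
    def __init__(self, nef=256, n_classes=200, hidden=256):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(nef, hidden), nn.LeakyReLU(0.2))
        self.fake_head = nn.Linear(hidden, 1)            # real/fake logit
        self.class_head = nn.Linear(hidden, n_classes)   # class logits

    def forward(self, emb):
        h = self.trunk(emb)
        return self.fake_head(h).squeeze(-1), self.class_head(h)

d_model = SentenceDiscriminator()
is_fake, pred_class = d_model(torch.randn(4, 256))  # shapes: (4,), (4, 200)
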
コード例 #24
0
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
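
A note on the labels argument threaded through sent_loss and words_loss in these examples: in the original AttnGAN/DAMSM code it is simply the index vector 0..batch_size-1, used as the cross-entropy target so that the i-th image is matched to the i-th caption of the batch. A minimal sketch of that preparation (assuming the AttnGAN convention; not shown in this snippet):

import torch

batch_size = 48  # illustrative; cfg.TRAIN.BATCH_SIZE in the original configs
labels = torch.arange(batch_size, dtype=torch.long)  # target: image i matches caption i
if torch.cuda.is_available():
    labels = labels.cuda()
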
コード例 #25
0
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    global global_step
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        global_step += 1
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        # imgs: b x 3 x nbasesize x nbasesize
        imgs, captions, cap_lens, \
            class_ids, _, _, _, keys, _ = prepare_data(data)
        
        class_ids = None # Oh. is this ok? FIXME

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)

        # num_caps = (cap_lens > 0).sum(1)
        per_qa_embs, avg_qa_embs, num_caps = Level1RNNEncodeMagic(captions, cap_lens, rnn_model)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, per_qa_embs, labels,
                                                 num_caps, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, avg_qa_embs, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            
            def make_fake_captions(num_caps):
                caps = torch.zeros(batch_size, cfg.TEXT.MAX_QA_NUM, dtype=torch.int64)
                ref = torch.arange(0, cfg.TEXT.MAX_QA_NUM).view(1, -1).repeat(batch_size, 1).cuda()
                targ = num_caps.view(-1, 1).repeat(1, cfg.TEXT.MAX_QA_NUM)
                caps[(ref < targ).cpu()] = 1  # mask lives on GPU, caps on CPU; move the mask before indexing
                return caps, {1: 'DUMMY'}

            _captions, _ixtoword = make_fake_captions(num_caps)
            
            html_doc = render_attn_to_html(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)

            with open('%s/attn_step%d.html' % (image_dir, global_step), 'w') as html_f:
                html_f.write(str(html_doc))
                
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), _captions,
                                   _ixtoword, attn_maps, att_sze)

            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, global_step)
                im.save(fullpath)
    return count
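
make_fake_captions above only builds a dummy 0/1 token tensor so that build_super_images can draw one placeholder word per available QA pair. A small worked example of the same masking logic on the CPU (a batch of two samples with 3 and 1 captions; all numbers illustrative):

import torch

MAX_QA_NUM = 5                    # stands in for cfg.TEXT.MAX_QA_NUM
num_caps = torch.tensor([3, 1])   # how many QA pairs each sample actually has

caps = torch.zeros(2, MAX_QA_NUM, dtype=torch.int64)
ref = torch.arange(0, MAX_QA_NUM).view(1, -1).repeat(2, 1)
targ = num_caps.view(-1, 1).repeat(1, MAX_QA_NUM)
caps[ref < targ] = 1
print(caps)
# tensor([[1, 1, 1, 0, 0],
#         [1, 0, 0, 0, 0]])
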