Code Example #1
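A validation helper for DAMSM pretraining: it switches both encoders to eval mode, accumulates the sentence- and word-level matching losses over at most 51 validation batches, and returns their running averages.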
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        # NOTE: `labels` (the diagonal match targets,
        # torch.LongTensor(range(batch_size))) is assumed to be defined at
        # module level by the surrounding pretraining script.
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).item()

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 50:
            break

    # `step` is the index of the last batch processed, so `step + 1`
    # batches contributed to the running totals
    s_cur_loss = s_total_loss / (step + 1)
    w_cur_loss = w_total_loss / (step + 1)

    return s_cur_loss, w_cur_loss
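A minimal usage sketch for the helper above. The names `dataloader_val`, `image_encoder`, and `text_encoder` are assumptions mirroring the pretraining script's setup, and `labels` must already exist at module scope as noted in the code.

import torch

# `dataloader_val`, `image_encoder`, `text_encoder`, and `batch_size` are
# assumed to be built by the surrounding script (hypothetical names)
with torch.no_grad():  # evaluation needs no gradients
    s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                              text_encoder, batch_size)
print('valid s_loss {:5.2f} | w_loss {:5.2f}'.format(s_loss, w_loss))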
Code Example #2
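An image-dumping method from a ManiGAN-style trainer: it runs the generator, overlays each stage's attention maps on the fake images, then re-encodes the final image with the DAMSM image encoder and saves its word-region attention as a second grid.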
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         cnn_code,
                         region_features,
                         real_imgs,
                         name='current'):
        # Save images
        fake_imgs, attention_maps, _, _, _, _ = netG(noise, sent_emb,
                                                     words_embs, mask,
                                                     cnn_code, region_features)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # Re-encode the highest-resolution fake image with the DAMSM image
        # encoder and visualize its word-region attention maps
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, cap_lens, None,
                                    self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
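Both grids come from `build_super_images` in the project's `miscc.utils`, which tiles the images, overlays per-word attention, and returns a drawable array; it can return `None` when the batch cannot be rendered, hence the `if img_set is not None` guards.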
Code Example #3
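A DAMSM pretraining step extended with a zero-shot-learning (ZSL) adversarial objective: in addition to the word and sentence losses, a two-headed discriminator (`d_model`) produces a real/fake logit and class logits for both the image code and the sentence embedding, and the encoders and the discriminator are updated in alternating sub-steps. A sketch of a compatible discriminator follows the example.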
def train(dataloader, cnn_model, rnn_model, d_model, batch_size, labels,
          optimizer, d_optimizer, epoch, ixtoword, image_dir):

    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    g_total_loss = 0
    d_total_loss = 0

    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
        # shift class ids to 0-based indices for F.cross_entropy targets
        class_ids = [c_id - 1 for c_id in class_ids]

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        att_sze = words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()

        pred_by_sent, pred_class_by_sent = d_model(sent_code)
        pred_by_emb, pred_class_by_emb = d_model(sent_emb)

        targets = torch.LongTensor(class_ids).to(pred_by_emb.device)

        # ZSL G Losses
        g_loss_adv_sent = F.binary_cross_entropy_with_logits(
            pred_by_sent, torch.zeros_like(pred_by_sent))
        g_loss_adv_emb = F.binary_cross_entropy_with_logits(
            pred_by_emb, torch.zeros_like(pred_by_emb))
        g_loss_cls_sent = F.cross_entropy(pred_class_by_sent, targets)
        g_loss_cls_emb = F.cross_entropy(pred_class_by_emb, targets)
        g_loss = (g_loss_adv_sent + g_loss_adv_emb + g_loss_cls_sent +
                  g_loss_cls_emb)

        loss += g_loss * cfg.ZSL.LAMBDA
        g_total_loss += g_loss.item()

        loss.backward()

        # TODO: SummaryWriter

        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        # Discriminator step
        d_model.zero_grad()

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        words_features, sent_code = cnn_model(imgs[-1])

        pred_by_sent, pred_class_by_sent = d_model(sent_code)
        pred_by_emb, pred_class_by_emb = d_model(sent_emb)

        # ZSL D Losses
        d_loss_adv_sent = F.binary_cross_entropy_with_logits(
            pred_by_sent, torch.zeros_like(pred_by_sent))
        d_loss_adv_emb = F.binary_cross_entropy_with_logits(
            pred_by_emb, torch.ones_like(pred_by_emb))
        d_loss_cls_sent = F.cross_entropy(pred_class_by_sent, targets)
        d_loss_cls_emb = F.cross_entropy(pred_class_by_emb, targets)

        d_loss = (d_loss_adv_sent + d_loss_adv_emb + d_loss_cls_sent +
                  d_loss_cls_emb)
        d_loss.backward()
        d_optimizer.step()
        d_total_loss += d_loss.item()

        if step > 0 and step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            g_cur_loss = g_total_loss / UPDATE_INTERVAL
            d_cur_loss = d_total_loss / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            info = ('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                    's_loss {:5.2f} {:5.2f} | '
                    'w_loss {:5.2f} {:5.2f} | g_loss {:5.2f} d_loss {:5.2f}'.
                    format(epoch, step, len(dataloader),
                           elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                           s_cur_loss1, w_cur_loss0, w_cur_loss1, g_cur_loss,
                           d_cur_loss))

            LOG.info(info)
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            g_total_loss = 0
            d_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
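The discriminator itself is not shown in the snippet. Below is a minimal sketch of a module compatible with the calls `pred, pred_class = d_model(feature)`: a shared trunk over an `nef`-dimensional feature with a real/fake head and a classification head. The class name `EmbeddingDiscriminator` and the layer sizes are assumptions, not the project's actual implementation.

import torch
import torch.nn as nn

class EmbeddingDiscriminator(nn.Module):
    # Hypothetical two-headed discriminator matching the calls above
    def __init__(self, nef, n_classes, nhidden=256):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(nef, nhidden),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.adv_head = nn.Linear(nhidden, 1)          # real/fake logit
        self.cls_head = nn.Linear(nhidden, n_classes)  # per-class logits

    def forward(self, feature):
        h = self.trunk(feature)
        return self.adv_head(h), self.cls_head(h)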
Code Example #4
File: pretrain_DAMSM.py   Project: sondn141/ManiGAN
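The baseline DAMSM pretraining step: word- and sentence-loss updates for the two encoders with gradient clipping, periodic console logging, and attention-map dumps. ManiGAN's `prepare_data` additionally returns mismatched captions (`wrong_caps`, `wrong_caps_len`, `wrong_cls_id`), which this function unpacks but does not use.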
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
            wrong_caps_len, wrong_cls_id = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)

        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        # accumulate as Python floats; summing graph-attached tensors here
        # would keep every batch's graph alive and leak memory
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()

        loss.backward()

        # `clip_grad_norm_` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step > 0 and step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
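For context, a minimal sketch of the epoch loop that typically drives `train` and `evaluate` in this style of script. The names `start_epoch`, `dataset`, and `dataloader_val` are assumptions modeled on the AttnGAN/ManiGAN pretraining code, not a verbatim excerpt.

import torch

labels = torch.LongTensor(range(batch_size))  # diagonal match targets
for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
    count = train(dataloader, image_encoder, text_encoder, batch_size,
                  labels, optimizer, epoch, dataset.ixtoword, image_dir)
    s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                              text_encoder, batch_size)
    print('| end epoch {:3d} | valid s_loss {:5.2f} w_loss {:5.2f} |'
          .format(epoch, s_loss, w_loss))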