Example #1
    def save_img_results(self, fake_imgs, attn_maps, bt_attn_maps, captions,
                         cap_lens, gen_iterations):
        font_max = [50, 50]
        font_size = [30, 50]
        batch_size = fake_imgs[0].size(0)
        # Save images
        for i in range(len(attn_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None

            # use a new name so the attn_maps argument is not shadowed inside the loop
            attn_map = attn_maps[i]
            att_sze = attn_map.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword, attn_map, att_sze,
                    lr_imgs=lr_img, font_max=font_max[i], font_size=font_size[i], batch_size=batch_size)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%d_%d.png'\
                    % (self.snapshot_dir, gen_iterations, i)
                im.save(fullpath)

            bt_attn_map = bt_attn_maps[i]
            att_sze = bt_attn_map.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword, bt_attn_map, att_sze,
                    lr_imgs=lr_img, font_max=font_max[i], font_size=font_size[i], batch_size=batch_size)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/bt_G_%d_%d.png'\
                    % (self.snapshot_dir, gen_iterations, i)
                im.save(fullpath)
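For orientation, a minimal sketch of the inputs this variant consumes; the shapes below and the commented-out call are assumptions inferred from the code above, not taken from the original project:

import torch

# Two generator stages: a low-resolution image and a refined image.
fake_imgs = [torch.randn(4, 3, 64, 64), torch.randn(4, 3, 128, 128)]
# One attention map per refinement stage: batch x num_words x att_sze x att_sze,
# where att_sze is read via attn_maps[i].size(2) above.
attn_maps = [torch.rand(4, 12, 64, 64)]
bt_attn_maps = [torch.rand(4, 12, 64, 64)]
captions = torch.randint(1, 100, (4, 12))          # word ids, batch x max_words
cap_lens = torch.full((4,), 12, dtype=torch.long)
# trainer.save_img_results(fake_imgs, attn_maps, bt_attn_maps, captions,
#                          cap_lens, gen_iterations=0)   # hypothetical trainer instance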
Example #2
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         name='current'):
        # Save images
        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs,
                                               mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                    attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                myDriveAttnGanImage = '/content/drive/My Drive/cubImageGAN'
                fullpath = '%s/G_%s_%d_%d.png' % (self.image_dir, name,
                                                  gen_iterations, i)
                fullpathDrive = '%s/G_%s_%d_%d.png' % (myDriveAttnGanImage,
                                                       name, gen_iterations, i)
                im.save(fullpath)
                im.save(fullpathDrive)

        # for i in range(len(netsD)):
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, cap_lens, None,
                                    self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                                captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            myDriveAttnGanImage = '/content/drive/My Drive/cubImageGAN'
            fullpath = '%s/D_%s_%d.png' % (self.image_dir, name,
                                           gen_iterations)
            fullpathDrive = '%s/D_%s_%d.png' % (myDriveAttnGanImage, name,
                                                gen_iterations)
            im.save(fullpath)
            im.save(fullpathDrive)
Example #3
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         transf_matrices_inv,
                         label_one_hot,
                         name='current'):
        # Save images
        # fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
        inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv,
                  label_one_hot)
        fake_imgs, attention_maps, _, _ = nn.parallel.data_parallel(
            netG, inputs, self.gpus)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # for i in range(len(netsD)):
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, cap_lens, None,
                                    self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
Example #4
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         name='current'):
        # Save images
        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs,
                                               mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            # print(img.shape, lr_img.shape, attn_maps.shape) # debug
            img_set, _ = \
                build_super_images(img[:, :3], captions, self.ixtoword, attn_maps, att_sze,
                                   lr_imgs=lr_img[:, :3] if lr_img is not None else None)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # for i in range(len(netsD)):
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, cap_lens, None,
                                    self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i][:, :3].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
            self.writer.add_image(tag="image_attn",
                                  img_tensor=transforms.ToTensor()(im),
                                  global_step=gen_iterations)
Example #5
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         gen_iterations,
                         name='current'):
        # Save images
        if cfg.CUDA:
            caption = Variable(torch.tensor([])).cuda()
        else:
            caption = Variable(torch.tensor([]))
        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs,
                                               mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, caption, {},
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # for i in range(len(netsD)):
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(), words_embs, None,
                                    0, None, self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               caption, {}, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
Example #6
    def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                         image_encoder, captions, cap_lens,
                         gen_iterations, transf_matrices_inv, label_one_hot, local_noise,
                         transf_matrices, max_objects, subset_idx, name='current'):
        # Save images
        inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv,
                  label_one_hot, max_objects)
        fake_imgs, attention_maps, _, _ = netG(*inputs)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = build_super_images(img, captions, self.ixtoword, attn_maps, att_sze, lr_imgs=lr_img,
                                            batch_size=self.batch_size[0])
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png' % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # for i in range(len(netsD)):
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
            _, _, att_maps = words_loss(region_features.detach(), words_embs.detach(),
                                        None, cap_lens, None, self.batch_size[subset_idx])
        else:
            _, _, att_maps = words_loss(region_features.detach(), words_embs.detach(),
                                        None, cap_lens, None, self.batch_size[0])
        img_set, _ = build_super_images(fake_imgs[i].detach().cpu(), captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png' % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
Example #7
def evaluate(dataloader, cnn_model, rnn_model, batch_size, writer, count,
             ixtoword, labels, image_dir):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        nef, att_sze = words_features.size(1), words_features.size(2)

        # hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 50:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    writer.add_scalars(main_tag="eval_loss",
                       tag_scalar_dict={
                           's_loss': s_cur_loss,
                           'w_loss': w_cur_loss
                       },
                       global_step=count)
    # save a image
    # attention Maps
    img_set, _ = \
        build_super_images(real_imgs[-1][:,:3].cpu(), captions,
                           ixtoword, attn, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/attention_maps_eval_%d.png' % (image_dir, count)
        im.save(fullpath)
        writer.add_image(tag="image_DAMSM_eval",
                         img_tensor=transforms.ToTensor()(im),
                         global_step=count)
    return s_cur_loss, w_cur_loss
Example #8
    def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                         image_encoder, captions, cap_lens, gen_iterations,
                         name='current'):
        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None

            attn_maps = attention_maps[i]
            att_size = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword, attn_maps,
                                   att_size, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/D_%s_%d.png' \
                           % (self.image_dir, name, gen_iterations)
                im.save(fullpath)
Example #9
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         cnn_code,
                         region_features,
                         real_imgs,
                         netDCM,
                         real_features,
                         name='current'):
        # Save images
        fake_imgs, attention_maps, _, _, h_code, c_code = netG(
            noise, sent_emb, words_embs, mask, cnn_code, region_features)

        fake_img = netDCM(h_code, real_features, sent_emb, words_embs, mask,
                          c_code)

        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, cap_lens, None,
                                    self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)

        img_set, _ = \
            build_super_images(fake_img.detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/C_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
Example #10
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
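Note that `torch.nn.utils.clip_grad_norm` used above is deprecated in recent PyTorch releases; the in-place variant (already used in Examples #12 and #19) is the current spelling:

torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), cfg.TRAIN.RNN_GRAD_CLIP)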
Example #11
def train(dataloader, cnn_model, rnn_model, d_model, batch_size,
          labels, generator_optimizer, discriminator_optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    d_total_loss = 0
    g_total_loss = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        target_classes = torch.LongTensor(class_ids)
        if cfg.CUDA:
            target_classes = target_classes.cuda()
        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size)
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1

        is_fake_0, pred_class_0 = d_model(sent_code)
        is_fake_1, pred_class_1 = d_model(sent_emb)
        g_loss = (F.binary_cross_entropy_with_logits(is_fake_0, torch.zeros_like(is_fake_0))
                 + F.binary_cross_entropy_with_logits(is_fake_1, torch.zeros_like(is_fake_1))
                 + F.cross_entropy(pred_class_0, target_classes)
                 + F.cross_entropy(pred_class_1, target_classes))
        loss += g_loss * cfg.TRAIN.SMOOTH.SUPERVISED_COEF

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        g_total_loss += g_loss.item()
        generator_optimizer.step()
        d_model.zero_grad()

        _, sent_code = cnn_model(imgs[-1])
        hidden = rnn_model.init_hidden(batch_size)
        _, sent_emb = rnn_model(captions, cap_lens, hidden)

        is_fake_0, pred_class_0 = d_model(sent_code)
        is_fake_1, pred_class_1 = d_model(sent_emb)

        d_loss = (F.binary_cross_entropy_with_logits(is_fake_0, torch.zeros_like(is_fake_0))
                + F.binary_cross_entropy_with_logits(is_fake_1, torch.ones_like(is_fake_1))
                + F.cross_entropy(pred_class_0, target_classes)
                + F.cross_entropy(pred_class_1, target_classes))
        loss = d_loss
        loss.backward()
        discriminator_optimizer.step()
        d_total_loss += d_loss.item()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            d_cur_loss = d_total_loss / UPDATE_INTERVAL
            g_cur_loss = g_total_loss / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f} | d_loss {:5.2f} | g_loss {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1, d_cur_loss, g_cur_loss))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            d_total_loss = 0
            g_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
Example #12
def train(dataloader, cnn_model, trx_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):

    cnn_model.train()
    trx_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        print('step:{:6d}|{:3d}'.format(step, len(dataloader)), end='\r')
        trx_model.zero_grad()
        cnn_model.zero_grad()
        imgs, captions, cap_lens, class_ids, keys, _, _, _, _ = prepare_data(
            data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        #         print(words_features.shape,sent_code.shape)
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
        #         print('nef:{0},att_sze:{1}'.format(nef,att_sze))

        #         hidden = trx_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        #         print('captions:',captions, captions.size())

        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        #         words_emb, sent_emb = trx_model(captions, cap_lens, hidden)

        words_emb, sent_emb = trx_model(captions)
        #         print('words_emb:',words_emb.size(),', sent_emb:', sent_emb.size())
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)

        #         print(w_loss0.data)
        #         print('--------------------------')
        #         print(w_loss1.data)
        #         print('--------------------------')
        #         print(attn_maps[0].shape)

        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
        #         print(loss)
        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1

        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        #         print(s_total_loss0[0],s_total_loss1[0])
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(trx_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if (step % UPDATE_INTERVAL == 0
                or step == (len(dataloader) - 1)) and step > 0:
            count = epoch * len(dataloader) + step

            #             print(count)
            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            tbw.add_scalar('Birds_Train/train_w_loss0',
                           float(w_cur_loss0.item()), epoch)
            tbw.add_scalar('Birds_Train/train_s_loss0',
                           float(s_cur_loss0.item()), epoch)
            tbw.add_scalar('Birds_Train/train_w_loss1',
                           float(w_cur_loss1.item()), epoch)
            tbw.add_scalar('Birds_Train/train_s_loss1',
                           float(s_cur_loss1.item()), epoch)
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps

            #             print(imgs[-1].cpu().shape, captions.shape, len(attn_maps),attn_maps[-1].shape, att_sze)
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions, ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '{0}/attention_maps_e{1}_s{2}.png'.format(
                    image_dir, epoch, step)
                im.save(fullpath)
    return count
Example #13
def train(dataloader, cnn_model, nlp_model, text_encoder_type, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    nlp_model.train()
    text_encoder_type = text_encoder_type.casefold()
    assert text_encoder_type in (
        'rnn',
        'transformer',
    )
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        nlp_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data( data )

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # print( words_features.shape, sent_code.shape )
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        # Forward Prop:
        # inputs:
        #   captions: torch.LongTensor of ids of size batch x n_steps
        # outputs:
        #   words_emb: batch_size x nef x seq_len
        #   sent_emb: batch_size x nef
        if text_encoder_type == 'rnn':
            hidden = nlp_model.init_hidden(batch_size)
            words_emb, sent_emb = nlp_model(captions, cap_lens, hidden)
        elif text_encoder_type == 'transformer':
            words_emb = nlp_model(captions)[0].transpose(1, 2).contiguous()
            sent_emb = words_emb[:, :, -1].contiguous()
            # sent_emb = sent_emb.view(batch_size, -1)
        # print( words_emb.shape, sent_emb.shape )

        # Compute Loss:
        # NOTE: the ideal loss for Transformer may be different than that for bi-directional LSTM
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss( sent_code, sent_emb, labels, class_ids, batch_size )
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        # Backprop:
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        if text_encoder_type == 'rnn':
            torch.nn.utils.clip_grad_norm(nlp_model.parameters(),
                                          cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # print(  s_total_loss0, s_total_loss1 )
            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            # print(  w_total_loss0, w_total_loss1 )
            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.5f} {:5.5f} | '
                  'w_loss {:5.5f} {:5.5f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # Attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
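In the transformer branch above the sentence embedding is simply the last position of `words_emb`. If mean pooling over the valid tokens is preferred instead, a small helper along these lines would work; it assumes the `batch_size x nef x seq_len` layout documented in the comments and is not part of the original code:

import torch

def mean_pool_sent_emb(words_emb, cap_lens):
    # words_emb: batch_size x nef x seq_len; cap_lens: (batch_size,) caption lengths
    cap_lens = cap_lens.to(words_emb.device)
    seq_len = words_emb.size(2)
    positions = torch.arange(seq_len, device=words_emb.device)
    mask = (positions.unsqueeze(0) < cap_lens.unsqueeze(1)).float()   # batch x seq_len
    summed = (words_emb * mask.unsqueeze(1)).sum(dim=2)               # batch x nef
    return summed / cap_lens.clamp(min=1).unsqueeze(1).float()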
Example #14
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         glove_words_embs,
                         clabels_feat,
                         mask,
                         hmaps,
                         rois,
                         fm_rois,
                         num_rois,
                         bt_masks,
                         fm_bt_masks,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         name='current'):
        # Save images
        glb_max_num_roi = int(torch.max(num_rois))
        fake_imgs, _, attention_maps, bt_attention_maps, _, _ = netG(
            noise, sent_emb, words_embs, glove_words_embs, clabels_feat, mask,
            hmaps, rois, fm_rois, num_rois, bt_masks, fm_bt_masks,
            glb_max_num_roi)

        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None

            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

            bt_attn_maps = bt_attention_maps[i]
            att_sze = bt_attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   bt_attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/bt_G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps, _ = words_loss(region_features.detach(),
                                       words_embs.detach(), None, cap_lens,
                                       None, self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
Example #15
File: trainer.py Project: zxs789/Obj-GAN
    def save_img_results(self,
                         netG,
                         noise,
                         imgs,
                         bbox_maps_fwd,
                         bbox_maps_bwd,
                         bbox_fmaps,
                         hmaps,
                         rois,
                         num_rois,
                         gen_iterations,
                         name='current'):
        # Save images
        font_max = 20
        font_size = 12

        imgs = imgs.cpu()
        fake_hmaps = netG(noise, bbox_maps_fwd, bbox_maps_bwd, bbox_fmaps)

        fake_hmaps = fake_hmaps.squeeze().detach().cpu()
        hmaps = hmaps.squeeze().cpu()

        # prepare captions
        batch_size = fake_hmaps.size(0)
        captions = Variable(torch.zeros(batch_size, cfg.ROI.BOXES_NUM)).cuda()
        for batch_index in range(self.batch_size):
            for roi_index in range(num_rois[batch_index]):
                rela_cat_id = int(rois[batch_index, roi_index, 4])
                captions[batch_index,
                         roi_index] = self.cats_dict[rela_cat_id][0]

        att_sze = fake_hmaps.size(2)
        img_set, _ = build_super_images(imgs,
                                        captions,
                                        self.ixtoword,
                                        fake_hmaps,
                                        att_sze,
                                        lr_imgs=None,
                                        font_max=font_max,
                                        font_size=font_size,
                                        max_word_num=cfg.ROI.BOXES_NUM)

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d.png' % (self.image_dir, name,
                                           gen_iterations)
            im.save(fullpath)

        img_set, _ = build_super_images(imgs,
                                        captions,
                                        self.ixtoword,
                                        hmaps,
                                        att_sze,
                                        lr_imgs=None,
                                        font_max=font_max,
                                        font_size=font_size,
                                        max_word_num=cfg.ROI.BOXES_NUM)

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png' % (self.image_dir, name,
                                           gen_iterations)
            im.save(fullpath)

        #
        img_set, _ = build_super_images2(imgs,
                                         captions,
                                         self.ixtoword,
                                         fake_hmaps,
                                         att_sze,
                                         lr_imgs=None,
                                         font_max=font_max,
                                         font_size=font_size,
                                         max_word_num=cfg.ROI.BOXES_NUM)

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G2_%s_%d.png' % (self.image_dir, name,
                                            gen_iterations)
            im.save(fullpath)

        img_set, _ = build_super_images2(imgs,
                                         captions,
                                         self.ixtoword,
                                         hmaps,
                                         att_sze,
                                         lr_imgs=None,
                                         font_max=font_max,
                                         font_size=font_size,
                                         max_word_num=cfg.ROI.BOXES_NUM)

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D2_%s_%d.png' % (self.image_dir, name,
                                            gen_iterations)
            im.save(fullpath)
Example #16
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir, exp):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    for step, data in enumerate(dataloader):
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        shape, cap, cap_len, cls_id, key = data
        sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)

        #sort
        shapes = shape[sorted_cap_indices].squeeze()
        captions = cap[sorted_cap_indices].squeeze()
        cap_len = cap_len[sorted_cap_indices].squeeze()
        class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

        if torch.cuda.is_available():
            shapes = shapes.cuda()
            captions = captions.cuda()
        #model
        words_features, sent_code = cnn_model(shapes)

        nef, att_sze = words_features.size(1), words_features.size(2)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, sorted_cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, sorted_cap_lens,
                                                 class_ids, batch_size)

        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)

        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()

        loss.backward()

        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time

            exp.log_metric('s_cur_loss0', s_cur_loss0, step=step, epoch=epoch)
            exp.log_metric('s_cur_loss1', s_cur_loss1, step=step, epoch=epoch)

            exp.log_metric('w_cur_loss0', w_cur_loss0, step=step, epoch=epoch)
            exp.log_metric('w_cur_loss1', w_cur_loss1, step=step, epoch=epoch)

            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
        if step == 1:
            fullpath = '%s/attention_maps%d' % (image_dir, step)
            build_super_images(shapes.cpu().detach().numpy(), captions,
                               cap_len, ixtoword, attn_maps, att_sze, exp,
                               fullpath, epoch)

    return count
Example #17
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    train_function_start_time = time.time()
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0

    # print("keyword |||||||||||||||||||||||||||||||")
    # print("len(dataloader) : " , len(dataloader) )
    # print(" count = " ,  (epoch + 1) * len(dataloader)  )
    # print("keyword |||||||||||||||||||||||||||||||")
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # print ("====================================================")
            # print ("s_total_loss0 : " , s_total_loss0)
            # print ("s_total_loss0.item() : " , s_total_loss0.item())
            # print ("UPDATE_INTERVAL : " , UPDATE_INTERVAL)
            print("s_total_loss0.item()/UPDATE_INTERVAL : ",
                  s_total_loss0.item() / UPDATE_INTERVAL)
            print("s_total_loss1.item()/UPDATE_INTERVAL : ",
                  s_total_loss1.item() / UPDATE_INTERVAL)
            print("w_total_loss0.item()/UPDATE_INTERVAL : ",
                  w_total_loss0.item() / UPDATE_INTERVAL)
            print("w_total_loss1.item()/UPDATE_INTERVAL : ",
                  w_total_loss1.item() / UPDATE_INTERVAL)
            # print ("s_total_loss0/UPDATE_INTERVAL : " , s_total_loss0/UPDATE_INTERVAL)
            # print ("=====================================================")
            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            #Save image only every 8 epochs && Save it to The Drive
            if (epoch % 8 == 0):
                print("bulding images")
                img_set, _ = build_super_images(imgs[-1].cpu(), captions,
                                                ixtoword, attn_maps, att_sze)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                    im.save(fullpath)
                    mydriveimg = '/content/drive/My Drive/cubImage'
                    drivepath = '%s/attention_maps%d.png' % (mydriveimg, epoch)
                    im.save(drivepath)
    print("keyTime |||||||||||||||||||||||||||||||")
    print("train_function_time : ", time.time() - train_function_start_time)
    print("KeyTime |||||||||||||||||||||||||||||||")
    return count
Example #18
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        """
        Sets gradients of all model parameters to zero.
        Every time a variable is back propogated through, the gradient will be accumulated instead of being replaced.
        (This makes it easier for rnn, because each module will be back propogated through several times.)
        """
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
        class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
        """Dont understand completely ??"""
        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids,
                                     batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))

            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = build_super_images(imgs[-1].cpu(), captions, ixtoword,
                                            attn_maps, att_sze)

            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps_%d_epoch_%d_step.png' % (
                    image_dir, epoch, step)
                im.save(fullpath)
    return count
Example #19
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir, writer, logger, update_interval):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)
        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        # hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        global_step = epoch * len(dataloader) + step
        writer.add_scalars(main_tag="batch_loss",
                           tag_scalar_dict={
                               "loss": loss.cpu().item(),
                               "w_loss0": w_loss0.cpu().item(),
                               "w_loss1": w_loss1.cpu().item(),
                               "s_loss0": s_loss0.cpu().item(),
                               "s_loss1": s_loss1.cpu().item()
                           },
                           global_step=global_step)

        if step % update_interval == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / update_interval
            s_cur_loss1 = s_total_loss1 / update_interval

            w_cur_loss0 = w_total_loss0 / update_interval
            w_cur_loss1 = w_total_loss1 / update_interval

            elapsed = time.time() - start_time
            logger.info(
                '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                's_loss {:6.4f} {:6.4f} | '
                'w_loss {:6.4f} {:6.4f}'.format(
                    epoch, step, len(dataloader),
                    elapsed * 1000. / update_interval, s_cur_loss0,
                    s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
        if global_step % (10 * update_interval) == 0:
            img_set, _ = \
                build_super_images(imgs[-1][:,:3].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, count)
                im.save(fullpath)
                writer.add_image(tag="image_DAMSM",
                                 img_tensor=transforms.ToTensor()(im),
                                 global_step=count)
    return count
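A hedged sketch of how Example #19's train() might be driven once per epoch. The names device, start_epoch, max_epoch, image_encoder and text_encoder are assumptions about the surrounding script; labels follows the usual DAMSM convention of matching indices 0..batch_size-1.

import torch

# Assumed setup: image_encoder (CNN), text_encoder (RNN), optimizer, dataloader,
# ixtoword, image_dir, writer, logger and update_interval are built elsewhere.
labels = torch.arange(batch_size, dtype=torch.long).to(device)
for epoch in range(start_epoch, max_epoch):
    count = train(dataloader, image_encoder, text_encoder, batch_size,
                  labels, optimizer, epoch, ixtoword, image_dir,
                  writer, logger, update_interval)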
Example #20
    def save_img_results(self,
                         real_img,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         transf_matrices_inv,
                         label_one_hot,
                         name='current',
                         num_visualize=8):

        qa_nums = (cap_lens > 0).sum(1)
        real_captions = captions
        captions, _ = make_fake_captions(qa_nums)  # fake caption.

        # Save images
        # fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
        inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv,
                  label_one_hot)
        fake_imgs, attention_maps, _, _ = nn.parallel.data_parallel(
            netG, inputs, self.gpus)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img,
                                   nvis=num_visualize)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        for i in range(cfg.TREE.BRANCH_NUM):
            save_pure_img_results(real_img[i].detach().cpu(),
                                  fake_imgs[i].detach().cpu(),
                                  gen_iterations,
                                  self.image_dir,
                                  token='level%d' % i)

        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, qa_nums, None,
                                    self.batch_size)
        img_set, _ = build_super_images(fake_imgs[i].detach().cpu(),
                                        captions,
                                        self.ixtoword,
                                        att_maps,
                                        att_sze,
                                        nvis=num_visualize)
        # FIXME: `render_attn_to_html` currently supports only the last level;
        # multi-level rendering still needs to be implemented.
        html_doc = render_attn_to_html([
            real_img[i].detach().cpu(),
            fake_imgs[i].detach().cpu(),
        ],
                                       real_captions,
                                       self.ixtoword,
                                       att_maps,
                                       att_sze,
                                       None,
                                       info=['Real Images', 'Fake Images'])
        with open('%s/damsm_attn_%d.html' % (self.image_dir, gen_iterations),
                  'w') as html_f:
            html_f.write(str(html_doc))

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
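A hedged sketch of how Example #20's save_img_results might be invoked from the trainer's generator loop; the snapshot_interval check and every variable shown here (real_img, noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot, gen_iterations) are assumptions about the caller, not taken from the source.

# Inside the (not shown) generator training loop of the trainer class.
if gen_iterations % snapshot_interval == 0:
    self.save_img_results(real_img, netG, noise, sent_emb, words_embs, mask,
                          image_encoder, captions, cap_lens, gen_iterations,
                          transf_matrices_inv, label_one_hot,
                          name='current', num_visualize=8)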
Example #21
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    global global_step
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        global_step += 1
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        # imgs: b x 3 x nbasesize x nbasesize
        imgs, captions, cap_lens, \
            class_ids, _, _, _, keys, _ = prepare_data(data)
        
        class_ids = None  # FIXME: class_ids are deliberately discarded here; confirm this is intended.

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)

        # num_caps = (cap_lens > 0).sum(1)
        per_qa_embs, avg_qa_embs, num_caps = Level1RNNEncodeMagic(captions, cap_lens, rnn_model)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, per_qa_embs, labels,
                                                 num_caps, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, avg_qa_embs, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        # `clip_grad_norm` was renamed `clip_grad_norm_` in PyTorch 0.4+.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            
            def make_fake_captions(num_caps):
                # Dummy captions for visualization: position j holds token 1 ('DUMMY')
                # while j < num_caps[i]. Keep everything on the CPU so the boolean mask
                # and the indexed tensor live on the same device.
                caps = torch.zeros(batch_size, cfg.TEXT.MAX_QA_NUM, dtype=torch.int64)
                ref = torch.arange(0, cfg.TEXT.MAX_QA_NUM).view(1, -1).repeat(batch_size, 1)
                targ = num_caps.view(-1, 1).repeat(1, cfg.TEXT.MAX_QA_NUM).cpu()
                caps[ref < targ] = 1
                return caps, {1: 'DUMMY'}

            _captions, _ixtoword = make_fake_captions(num_caps)
            
            html_doc = render_attn_to_html(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)

            with open('%s/attn_step%d.html' % (image_dir, global_step), 'w') as html_f:
                html_f.write(str(html_doc))
                
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), _captions,
                                   _ixtoword, attn_maps, att_sze)

            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, global_step)
                im.save(fullpath)
    return count
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    t_total_loss = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    if cfg.LOCAL_PRETRAINED:
        tokenizer = tokenization.FullTokenizer(vocab_file=cfg.BERT_ENCODER.VOCAB,
                                               do_lower_case=True)
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    debug_flag = False
    #debug_flag = True

    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()
        if debug_flag:
            with open('./debug0.pkl', 'wb') as f:
                pickle.dump({'data': data, 'cnn_model': cnn_model,
                             'rnn_model': rnn_model, 'labels': labels}, f)

        #imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
        imgs, captions, cap_lens, class_ids, keys = prepare_data_bert(data, tokenizer)
        #imgs, captions, cap_lens, class_ids, keys, \
        #        input_ids, segment_ids, input_mask = prepare_data_bert(data, tokenizer)

        # sent_code: batch_size x nef
        #words_features, sent_code, word_logits = cnn_model(imgs[-1], captions)
        words_features, sent_code, word_logits = cnn_model(imgs[-1], captions, cap_lens)
        #words_features, sent_code, word_logits = cnn_model(imgs[-1], captions, input_ids, segment_ids, input_mask)
        # bs x T x vocab_size
        if debug_flag:
            with open('./debug1.pkl', 'wb') as f:
                pickle.dump({'words_features': words_features,
                             'sent_code': sent_code,
                             'word_logits': word_logits}, f)

        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        #words_emb, sent_emb = rnn_model(captions, cap_lens, hidden, input_ids, segment_ids, input_mask)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        if debug_flag:
            with open('./debug2.pkl', 'wb') as f:
                pickle.dump({'words_features': words_features, 'words_emb': words_emb,
                             'labels': labels, 'cap_lens': cap_lens,
                             'class_ids': class_ids, 'batch_size': batch_size}, f)

        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        if debug_flag:
            with open('./debug3.pkl', 'wb') as f:
                pickle.dump({'sent_code': sent_code, 'sent_emb': sent_emb,
                             'labels': labels, 'class_ids': class_ids,
                             'batch_size': batch_size}, f)

        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        # Added: image-to-text reconstruction loss on the decoded word logits.
        # print(word_logits.shape, captions.shape)
        t_loss = image_to_text_loss(word_logits, captions)
        if debug_flag:
            with open('./debug4.pkl', 'wb') as f:
                pickle.dump({'word_logits': word_logits, 'captions': captions}, f)

        loss += t_loss
        t_total_loss += t_loss.data

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            t_curr_loss = t_total_loss.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f} | '
                  't_loss {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1,
                          t_curr_loss))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            t_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
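image_to_text_loss is not defined in these examples. Given the shape comment above (word_logits: bs x T x vocab_size) and that captions hold token ids, a plausible implementation is per-token cross-entropy over the decoded logits; this is a guess at the intent, including the assumption that index 0 is padding, not the original function.

import torch.nn.functional as F


def image_to_text_loss(word_logits, captions):
    # word_logits: batch_size x T x vocab_size; captions: batch_size x T (token ids).
    # Flatten both so every token position becomes a separate classification target.
    vocab_size = word_logits.size(-1)
    return F.cross_entropy(word_logits.reshape(-1, vocab_size),
                           captions.reshape(-1),
                           ignore_index=0)  # assumption: 0 is the padding token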