Example no. 1
def trainSingle(dataloader, cnn_model, rnn_model, batch_size, labels,
                optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    data = next(iter(dataloader))

    # print('step', step)
    rnn_model.zero_grad()
    cnn_model.zero_grad()

    imgs, captions, cap_lens, \
        class_ids, keys = prepare_data(data)

    # words_features: batch_size x nef x 17 x 17
    # sent_code: batch_size x nef
    words_features, sent_code = cnn_model(imgs[-1])
    #print("words features shape",words_features.shape)
    #print("sent code shape", sent_code.shape)
    # --> batch_size x nef x 17*17
    nef, att_sze = words_features.size(1), words_features.size(2)
    #print("nef att_sze", nef, att_sze)
    # words_features = words_features.view(batch_size, nef, -1)

    hidden = rnn_model.init_hidden(batch_size)
    # words_emb: batch_size x nef x seq_len
    # sent_emb: batch_size x nef

    #print("train captions", captions.size() )
    #print("train cap_lens", cap_lens.size() )
    #print("train word features", words_features.size() )
    #print("train sent_code", sent_code.size() )

    words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
    #print("words_emb shape", words_emb.size() )
    #print("sent_emb shape", sent_emb.size() )

    # single-batch debug variant: only the word-level matching loss is computed
    # here; the full optimization step lives in train() (Example no. 3) below
    w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                             cap_lens, class_ids, batch_size)
    return w_loss0, w_loss1, attn_maps
Example no. 2
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()

    print("** rnn structure **", rnn_model.rnn)
    print("** embedd structure **", rnn_model.encoder)

    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])

        print("valid captions", captions.size())
        print("valid cap_lens", cap_lens.size())
        print("valid word features", words_features.size())
        print("valid sent_code", sent_code.size())
        print(cap_lens)

        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

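        # NOTE: `labels` is not a parameter of evaluate(); it is assumed to be
        # the module-level tensor returned by build_models() (see Example no. 4).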
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    # average over the number of evaluated batches (step is 0-based)
    s_cur_loss = s_total_loss.item() / (step + 1)
    w_cur_loss = w_total_loss.item() / (step + 1)

    return s_cur_loss, w_cur_loss
Example no. 3
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        #print("words features shape",words_features.shape)
        #print("sent code shape", sent_code.shape)
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        #print("nef att_sze", nef, att_sze)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef

        #print("train captions", captions.size() )
        #print("train cap_lens", cap_lens.size() )
        #print("train word features", words_features.size() )
        #print("train sent_code", sent_code.size() )

        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        #print("words_emb shape", words_emb.size() )
        #print("sent_emb shape", sent_emb.size() )

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)

        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
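A minimal sketch of how train() and evaluate() would typically be driven, assuming the setup from Example no. 4 below (build_models(), dataloader, batch_size, image_dir); the optimizer choice (Adam), cfg.TRAIN.ENCODER_LR, cfg.TRAIN.MAX_EPOCH, dataset.ixtoword and dataloader_val are assumptions, not taken from these excerpts.

import torch.optim as optim

# Hedged sketch of a DAMSM pretraining driver (assumed names noted above).
rnn_model, cnn_model, labels, start_epoch = build_models()
para = list(rnn_model.parameters())
para += [v for v in cnn_model.parameters() if v.requires_grad]
optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))

for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
    count = train(dataloader, cnn_model, rnn_model, batch_size, labels,
                  optimizer, epoch, dataset.ixtoword, image_dir)
    s_loss, w_loss = evaluate(dataloader_val, cnn_model, rnn_model, batch_size)
    print('| end epoch {:3d} | valid s_loss {:5.2f} | valid w_loss {:5.2f} |'
          .format(epoch, s_loss, w_loss))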
Example no. 4
def testproc2():
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if args.gpu_id == -1:
        cfg.CUDA = False
    else:
        cfg.GPU_ID = args.gpu_id

    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    if not cfg.TRAIN.FLAG:
        args.manualSeed = 100
    elif args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    ##########################################################################
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
        (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    model_dir = os.path.join(output_dir, 'Model')
    image_dir = os.path.join(output_dir, 'Image')
    mkdir_p(model_dir)
    mkdir_p(image_dir)

    torch.cuda.set_device(cfg.GPU_ID)
    cudnn.benchmark = True

    # Get data loader ##################################################
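    # imsize is the largest branch resolution: BASE_SIZE doubled
    # (BRANCH_NUM - 1) times; e.g. BASE_SIZE = 64 with BRANCH_NUM = 3 gives
    # imsize = 256 (illustrative values only; the actual sizes come from the
    # cfg file).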
    imsize = cfg.TREE.BASE_SIZE * (2**(cfg.TREE.BRANCH_NUM - 1))
    batch_size = cfg.TRAIN.BATCH_SIZE
    image_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()
    ])

    dataset = FashionTextDataset(cfg.DATA_DIR,
                                 'train',
                                 base_size=cfg.TREE.BASE_SIZE,
                                 transform=image_transform)

    print(dataset.n_words, dataset.embeddings_num)
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             drop_last=True,
                                             shuffle=True,
                                             num_workers=int(cfg.WORKERS))

    # Train ##############################################################
    rnn_model, cnn_model, labels, start_epoch = build_models()
    para = list(rnn_model.parameters())
    for v in cnn_model.parameters():
        if v.requires_grad:
            para.append(v)
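    # para now holds all RNN parameters plus only the trainable CNN parameters
    # (any CNN weights with requires_grad=False, e.g. a frozen backbone, are
    # skipped).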

    data = next(iter(dataloader))
    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

    # check last item from images list.
    print(imgs[-1].shape)
    print("labels", labels)
    #print("class ids",class_ids)

    words_features, sent_code = cnn_model(imgs[-1])
    print("words features shape", words_features.shape)
    print("sent code shape", sent_code.shape)
    # --> batch_size x nef x 17*17
    nef, att_sze = words_features.size(1), words_features.size(2)
    print("nef att_sze", nef, att_sze)
    # words_features = words_features.view(batch_size, nef, -1)

    hidden = rnn_model.init_hidden(batch_size)
    for i, h in enumerate(hidden):
        print("hidden size", i + 1, h.size())

    # 2 x batch_size x hidden_size

    # words_emb: batch_size x nef x seq_len
    # sent_emb: batch_size x nef

    print("train captions", captions.size())
    print("train cap_lens", cap_lens.size())
    #print("train word features", words_features.size() )
    #print("train sent_code", sent_code.size() )

    words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
    print("words_emb shape", words_emb.size())
    print("sent_emb shape", sent_emb.size())

    i = 10
    masks = []
    if class_ids is not None:
        mask = (class_ids == class_ids[i]).astype(np.uint8)
        mask[i] = 0
        masks.append(mask.reshape((1, -1)))
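    # The mask flags other batch entries that share sample i's class; in
    # words_loss those positions are excluded from the matching loss so that
    # same-class pairs are not penalized as mismatches (mask[i] = 0 keeps the
    # true pair itself unmasked).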

    print("no masks, if class ids are sequential.", masks)

    #data_dir = "/home/donchan/Documents/DATA/CULTECH_BIRDS/CUB_200_2011/train"
    #if os.path.isfile(data_dir + '/class_info.pickle'):
    #    with open(data_dir + '/class_info.pickle', 'rb') as f:
    #        class_id = pickle.load(f, encoding="latin1")

    # Get the i-th text description
    words_num = cap_lens[i]
    print(words_num)
    # -> 1 x nef x words_num
    word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
    print(word.size())
    # -> batch_size x nef x words_num
    word = word.repeat(batch_size, 1, 1)
    print(word.size())
    #print(word)

    context = words_features.clone()
    query = word.clone()

    batch_size, queryL = query.size(0), query.size(2)
    ih, iw = context.size(2), context.size(3)
    sourceL = ih * iw

    # --> batch x sourceL x ndf
    context = context.view(batch_size, -1, sourceL)
    contextT = torch.transpose(context, 1, 2).contiguous()

    # Get attention
    # (batch x sourceL x ndf)(batch x ndf x queryL)
    # -->batch x sourceL x queryL
    attn = torch.bmm(contextT, query)  # Eq. (7) in AttnGAN paper
    # --> batch*sourceL x queryL
    attn = attn.view(batch_size * sourceL, queryL)

    #print("attn on Eq.8 on GlobalAttention", attn.size()  , attn.data.cpu().sum() ) # 13872, 6   / 13872, 7 ??
    attn = nn.Softmax(dim=1)(attn)  # Eq. (8): normalize over the word (queryL) axis
    print("attn size", attn.size())

    # --> batch x sourceL x queryL
    attn = attn.view(batch_size, sourceL, queryL)
    # --> batch*queryL x sourceL
    attn = torch.transpose(attn, 1, 2).contiguous()
    attn = attn.view(batch_size * queryL, sourceL)
    print("attn size", attn.size())

    #print("attn on Eq.9 on GlobalAttention", attn.size() , attn.data.cpu().sum() ) # 288, 289 / 336 , 289 ?

    #  Eq. (9)

    attn = attn * cfg.TRAIN.SMOOTH.GAMMA1
    attn = nn.Softmax(dim=1)(attn)  # normalize over the region (sourceL) axis
    attn = attn.view(batch_size, queryL, sourceL)
    # --> batch x sourceL x queryL
    attnT = torch.transpose(attn, 1, 2).contiguous()

    # (batch x ndf x sourceL)(batch x sourceL x queryL)
    # --> batch x ndf x queryL
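    # weightedContext gathers, for every word, the attention-weighted sum of
    # region features, i.e. the region-context vector that is compared to the
    # word embedding via cosine similarity below.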
    weightedContext = torch.bmm(context, attnT)
    print("weight size", weightedContext.size())

    attn = attn.view(batch_size, -1, ih, iw)
    print("attn size after Eq9", attn.size())

    att_maps = []
    #weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1)
    att_maps.append(attn[i].unsqueeze(0).contiguous())
    # --> batch_size x words_num x nef
    word = word.transpose(1, 2).contiguous()
    weightedContext = weightedContext.transpose(1, 2).contiguous()
    # --> batch_size*words_num x nef
    word = word.view(batch_size * words_num, -1)
    weightedContext = weightedContext.view(batch_size * words_num, -1)
    print("weight size after Eq.10", weightedContext.size())

    #
    # -->batch_size*words_num
    row_sim = cosine_similarity(word, weightedContext)
    print("row similarities", row_sim.size())
    # --> batch_size x words_num
    row_sim = row_sim.view(batch_size, words_num)

    # Eq. (10)
    row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_()
    row_sim = row_sim.sum(dim=1, keepdim=True)
    row_sim = torch.log(row_sim)

    print(row_sim)
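For reference, a minimal sketch of how the single row_sim score computed above is assembled into the batch-level word loss, repeating the same Eq. (7)-(10) steps for every sentence in the batch. It assumes the names already used above (torch, nn, cfg, cosine_similarity, words_emb, words_features, cap_lens, labels); func_attention is assumed to wrap the attention computation walked through above, and cfg.TRAIN.SMOOTH.GAMMA3 is an assumed config key.

# Hedged sketch (not necessarily the repository's exact words_loss): collect one
# matching score per (image, sentence) pair, then score the batch with
# cross-entropy against the diagonal labels [0, 1, ..., batch_size - 1].
similarities = []
for i in range(batch_size):
    words_num = int(cap_lens[i])
    word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
    word = word.repeat(batch_size, 1, 1)
    # func_attention is assumed to implement Eq. (7)-(9) exactly as above
    weiContext, _ = func_attention(word, words_features, cfg.TRAIN.SMOOTH.GAMMA1)
    word = word.transpose(1, 2).contiguous().view(batch_size * words_num, -1)
    weiContext = weiContext.transpose(1, 2).contiguous().view(
        batch_size * words_num, -1)
    row_sim = cosine_similarity(word, weiContext).view(batch_size, words_num)
    row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_()
    similarities.append(torch.log(row_sim.sum(dim=1, keepdim=True)))

# batch_size x batch_size matrix of image-sentence matching scores
similarities = torch.cat(similarities, 1) * cfg.TRAIN.SMOOTH.GAMMA3
# the class-id masks built earlier would be applied here to drop same-class
# negatives before the cross-entropy
w_loss0 = nn.CrossEntropyLoss()(similarities, labels)
w_loss1 = nn.CrossEntropyLoss()(similarities.transpose(0, 1), labels)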