import torch

from miscc.config import cfg                     # project-local (assumed path)
from model import CAPTION_CNN, CAPTION_RNN       # project-local (assumed path)
from models import Encoder, Decoder              # project-local (assumed path)


def init_model(ixtoword):
    if cfg.CAP.USE_ORIGINAL:
        caption_cnn = CAPTION_CNN(embed_size=cfg.CAP.EMBED_SIZE)
        caption_rnn = CAPTION_RNN(embed_size=cfg.CAP.EMBED_SIZE,
                                  hidden_size=cfg.CAP.HIDDEN_SIZE,
                                  vocab_size=len(ixtoword),
                                  num_layers=cfg.CAP.NUM_LAYERS)
    else:
        caption_cnn = Encoder()
        caption_rnn = Decoder(idx2word=ixtoword)

    decoder_optimizer = torch.optim.Adam(params=caption_rnn.parameters(),
                                         lr=cfg.CAP.LEARNING_RATE)

    if cfg.CAP.CAPTION_CNN_PATH and cfg.CAP.CAPTION_RNN_PATH:
        print('Loading pre-trained caption model weights...')
        caption_cnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_CNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_rnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_RNN_PATH,
            map_location=lambda storage, loc: storage)

        caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
        caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])
        decoder_optimizer.load_state_dict(
            caption_rnn_checkpoint['optimizer_state_dict'])

    caption_cnn = caption_cnn.to(cfg.DEVICE)
    caption_rnn = caption_rnn.to(cfg.DEVICE)

    return caption_cnn, caption_rnn, decoder_optimizer
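
A minimal usage sketch (not part of the original snippet; it assumes cfg is already populated and that ixtoword maps vocabulary indices to words):

ixtoword = {0: '<end>', 1: 'a', 2: 'small', 3: 'bird'}  # toy vocabulary for illustration
caption_cnn, caption_rnn, decoder_optimizer = init_model(ixtoword)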
Example #2
import os
import os.path as osp

import torch
import torchvision.transforms as transforms

from datasets import TextDataset                 # project-local (assumed path)
from miscc.config import cfg                     # project-local (assumed path)
from model import CAPTION_CNN, CAPTION_RNN       # project-local (assumed path)

# The published snippet starts mid-pipeline; a minimal head is assumed here.
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.485, 0.456, 0.406),
    #                      (0.229, 0.224, 0.225))
])

split_dir = 'test'  # assumed; not shown in the truncated snippet
# load text data
dataset = TextDataset(cfg.DATA_DIR,
                      split_dir,
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=transform)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load caption model
caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2,
                          dataset.n_words, cfg.CAP.num_layers)
caption_cnn.to(device)
caption_rnn.to(device)

caption_cnn.load_state_dict(
    torch.load(cfg.CAP.caption_cnn_path,
               map_location=lambda storage, loc: storage))
caption_rnn.load_state_dict(
    torch.load(cfg.CAP.caption_rnn_path,
               map_location=lambda storage, loc: storage))

# inference only: fix batch-norm/dropout behavior
caption_cnn.eval()
caption_rnn.eval()

## inference
# load image dataset
ROOT_DIR = os.getcwd()
# DATA_DIR = osp.join(ROOT_DIR, 'data', 'output', 'bird', 'Model', 'netG', 'valid', 'single')
DATA_DIR = 'data/birds/CUB_200_2011/images'
OUTPUT_DIR = osp.join(ROOT_DIR, 'output')
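
A hedged sketch of the greedy decoding step these pieces feed into. caption_rnn.sample() returning a LongTensor of word ids is an assumed API (in the style of the common Show-and-Tell PyTorch implementation), and dataset.ixtoword is the vocabulary mapping that AttnGAN-style TextDataset classes expose; none of this is in the original snippet:

from PIL import Image

def caption_image(img_path):
    # preprocess a single image and run it through the frozen caption models
    image = transform(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        features = caption_cnn(image)
        word_ids = caption_rnn.sample(features)  # assumed API: LongTensor of ids
    words = [dataset.ixtoword[int(i)] for i in word_ids.squeeze().tolist()]
    return ' '.join(w for w in words if w not in ('<start>', '<end>'))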
Example #3
    def build_models(self):
        print('Building models...')
        print('N_words: ', self.n_words)

        #####################
        ##  TEXT ENCODERS  ##
        #####################
        if cfg.TRAIN.NET_E == '':
            raise ValueError('cfg.TRAIN.NET_E is empty: '
                             'no pretrained text/image encoder to load')

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Built image encoder: ', image_encoder)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Loaded image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Built text encoder: ', text_encoder)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Loaded text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ######################
        ##  CAPTION MODELS  ##
        ######################

        # cnn_encoder and rnn_encoder
        if cfg.CAP.USE_ORIGINAL:
            caption_cnn = CAPTION_CNN(embed_size=cfg.TEXT.EMBEDDING_DIM)
            caption_rnn = CAPTION_RNN(embed_size=cfg.TEXT.EMBEDDING_DIM,
                                      hidden_size=cfg.CAP.HIDDEN_SIZE,
                                      vocab_size=self.n_words,
                                      num_layers=cfg.CAP.NUM_LAYERS)
        else:
            caption_cnn = Encoder()
            caption_rnn = Decoder(idx2word=self.ixtoword)

        caption_cnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_CNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_rnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_RNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
        caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])

        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Loaded caption CNN from:', cfg.CAP.CAPTION_CNN_PATH)
        caption_cnn.eval()

        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Loaded caption RNN from:', cfg.CAP.CAPTION_RNN_PATH)
        caption_rnn.eval()

        #################################
        ##  GENERATOR & DISCRIMINATOR  ##
        #################################
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD:', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Loaded G from:', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Loaded D from:', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        text_encoder = text_encoder.to(cfg.DEVICE)
        image_encoder = image_encoder.to(cfg.DEVICE)
        caption_cnn = caption_cnn.to(cfg.DEVICE)
        caption_rnn = caption_rnn.to(cfg.DEVICE)
        netG.to(cfg.DEVICE)
        for i in range(len(netsD)):
            netsD[i].to(cfg.DEVICE)
        return [
            text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD,
            epoch
        ]
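
A hedged sketch of how the returned list is typically unpacked by the caller (the surrounding trainer class is assumed, not shown in the original):

# inside the trainer, e.g. at the top of train():
text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = \
    self.build_models()

Returning a flat list keeps the call site simple; a namedtuple would make the seven-element unpack less error-prone, but the list matches the convention this code follows.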