Code example #1
# Assumed imports; module paths follow the MirrorGAN-style layout this
# project uses (adjust if your tree differs).
import torch

from miscc.config import cfg
from model import CAPTION_CNN, CAPTION_RNN, Encoder, Decoder


def init_model(ixtoword):
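    """Build the caption CNN encoder and RNN decoder plus the decoder's Adam
    optimizer, optionally restoring all three from configured checkpoints."""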
    if cfg.CAP.USE_ORIGINAL:
        caption_cnn = CAPTION_CNN(embed_size=cfg.CAP.EMBED_SIZE)
        caption_rnn = CAPTION_RNN(embed_size=cfg.CAP.EMBED_SIZE,
                                  hidden_size=cfg.CAP.HIDDEN_SIZE,
                                  vocab_size=len(ixtoword),
                                  num_layers=cfg.CAP.NUM_LAYERS)
    else:
        caption_cnn = Encoder()
        caption_rnn = Decoder(idx2word=ixtoword)

    # Move the models to the target device *before* creating the optimizer,
    # so that any optimizer state restored below lands on the same device as
    # the parameters (loading it while the model is still on the CPU and
    # moving afterwards would leave Adam's state on the CPU).
    caption_cnn = caption_cnn.to(cfg.DEVICE)
    caption_rnn = caption_rnn.to(cfg.DEVICE)

    # Only the RNN decoder is optimized; the CNN encoder is used as a fixed
    # feature extractor.
    decoder_optimizer = torch.optim.Adam(params=caption_rnn.parameters(),
                                         lr=cfg.CAP.LEARNING_RATE)

    # Restore pretrained weights (and optimizer state) when both checkpoint
    # paths are configured.
    if cfg.CAP.CAPTION_CNN_PATH and cfg.CAP.CAPTION_RNN_PATH:
        print('Loading pre-trained caption model...')
        caption_cnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_CNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_rnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_RNN_PATH,
            map_location=lambda storage, loc: storage)

        caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
        caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])
        decoder_optimizer.load_state_dict(
            caption_rnn_checkpoint['optimizer_state_dict'])

    return caption_cnn, caption_rnn, decoder_optimizer
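
For context, a minimal call-site sketch; `dataset.ixtoword` is a hypothetical stand-in for however the project exposes its index-to-word vocabulary mapping:

# Hypothetical usage; 'dataset' is assumed to be an already-loaded TextDataset
# exposing an 'ixtoword' mapping.
caption_cnn, caption_rnn, decoder_optimizer = init_model(dataset.ixtoword)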
Code example #2
    def build_models(self):
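        """Assemble all pretrained sub-networks (text/image encoders, caption
        CNN/RNN, generator and per-resolution discriminators) and return them
        together with the starting epoch."""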
        # text encoders
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        # self.n_words = 156
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Caption models - cnn_encoder and rnn_decoder
        caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
        caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage))
        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_cnn_path)
        caption_cnn.eval()

        # self.n_words = 9
        caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers)
        caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage))
        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_rnn_path)
        caption_rnn.eval()  # added: keep the frozen decoder out of training mode

        # Generator and Discriminator:
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))

        epoch = 0
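        # Optionally resume: recover the starting epoch from the generator
        # checkpoint filename (the digits between the last '_' and the '.').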
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]

            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            caption_cnn = caption_cnn.cuda()
            caption_rnn = caption_rnn.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]
Code example #3
# Assumed imports; module paths follow the MirrorGAN-style layout (adjust to
# your project tree).
import os

import torch
import torchvision.transforms as transforms

from miscc.config import cfg
from datasets import TextDataset
from model import CAPTION_CNN, CAPTION_RNN

# image preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.485, 0.456, 0.406),
    #                      (0.229, 0.224, 0.225))
])
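# 'split_dir' ('train' or 'test') is assumed to be set earlier in the elided
# part of this script.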
# load text data
dataset = TextDataset(cfg.DATA_DIR,
                      split_dir,
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=transform)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load caption model
caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2,
                          dataset.n_words, cfg.CAP.num_layers)
caption_cnn.to(device)
caption_rnn.to(device)

caption_cnn.load_state_dict(
    torch.load(cfg.CAP.caption_cnn_path,
               map_location=lambda storage, loc: storage))
caption_rnn.load_state_dict(
    torch.load(cfg.CAP.caption_rnn_path,
               map_location=lambda storage, loc: storage))

# Added: this snippet is used for inference, so put both caption models in
# eval mode (the original snippet left them in training mode).
caption_cnn.eval()
caption_rnn.eval()

## inference
# load image dataset
ROOT_DIR = os.getcwd()
Code example #4
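    # 'image_transform', 'norm', and 'batch_size' are assumed to be defined
    # earlier in the elided part of this script.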
    dataset_val = TextDatasetCOCO(
        cfg.DATA_DIR, 
        'test', 
        base_size=cfg.TREE.BASE_SIZE,
        transform=image_transform,
        norm=norm
    )
    """
    dataloader_val = torch.utils.data.DataLoader(dataset_val,
                                                 batch_size=batch_size,
                                                 drop_last=True,
                                                 shuffle=True,
                                                 num_workers=int(cfg.WORKERS))

    # Train ##############################################################
    encoder = CAPTION_CNN(cfg.CAP.embed_size).cuda()
    decoder = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2,
                          dataset.n_words, cfg.CAP.num_layers).cuda()
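    # Optimize the decoder plus only the encoder's final linear and BatchNorm
    # layers; the rest of the CNN encoder is left out of this optimizer.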
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params, lr=cfg.CAP.learning_rate)

    log_step = 10
    save_step = 10
    num_epochs = 50

    for epoch in range(num_epochs):
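        # 'dataloader' (the training split) is assumed to be built in the
        # elided part of this script; only the validation loader is shown above.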
        total_step = len(dataloader)
        for i, data in enumerate(dataloader):
            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
Code example #5
File: trainer.py  Project: cammiida/masters-project
    def build_models(self):
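        """Variant of build_models() that picks between the original caption
        CNN/RNN and a custom Encoder/Decoder via cfg.CAP.USE_ORIGINAL."""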
        print('Building models...')
        print('N_words: ', self.n_words)

        #####################
        ##  TEXT ENCODERS  ##
        #####################
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Built image encoder: ', image_encoder)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Built text encoder: ', text_encoder)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ######################
        ##  CAPTION MODELS  ##
        ######################

        # cnn_encoder and rnn_encoder
        if cfg.CAP.USE_ORIGINAL:
            caption_cnn = CAPTION_CNN(embed_size=cfg.TEXT.EMBEDDING_DIM)
            caption_rnn = CAPTION_RNN(embed_size=cfg.TEXT.EMBEDDING_DIM,
                                      hidden_size=cfg.CAP.HIDDEN_SIZE,
                                      vocab_size=self.n_words,
                                      num_layers=cfg.CAP.NUM_LAYERS)
        else:
            caption_cnn = Encoder()
            caption_rnn = Decoder(idx2word=self.ixtoword)

        caption_cnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_CNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_rnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_RNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
        caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])

        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from: ', cfg.CAP.CAPTION_CNN_PATH)
        caption_cnn.eval()

        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from: ', cfg.CAP.CAPTION_RNN_PATH)
        caption_rnn.eval()  # added: keep the frozen decoder out of training mode

        #################################
        ##  GENERATOR & DISCRIMINATOR  ##
        #################################
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))

        epoch = 0
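        # Optionally resume: recover the starting epoch from the generator
        # checkpoint filename.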
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        text_encoder = text_encoder.to(cfg.DEVICE)
        image_encoder = image_encoder.to(cfg.DEVICE)
        caption_cnn = caption_cnn.to(cfg.DEVICE)
        caption_rnn = caption_rnn.to(cfg.DEVICE)
        netG.to(cfg.DEVICE)
        for i in range(len(netsD)):
            netsD[i].to(cfg.DEVICE)
        return [
            text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD,
            epoch
        ]
Code example #6
File: sample_STREAM.py  Project: Seth-Park/MirrorGAN
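    # 'image_transform', 'norm', 'args', and 'batch_size' are assumed to be
    # defined earlier in the elided part of this script.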
    # Inverse of Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)).
    # Each channel must use its own std; the original snippet reused 0.229
    # (the red-channel std) for all three channels.
    unnorm = transforms.Normalize(
        (-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225),
        (1 / 0.229, 1 / 0.224, 1 / 0.225)
    )

    dataset = TextDataset(cfg.DATA_DIR, 'test',
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform,
                          norm=norm)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, drop_last=True,
        shuffle=False, num_workers=int(cfg.WORKERS))

    encoder_path = os.path.join(args.model_dir, 'encoder-9.ckpt')
    decoder_path = os.path.join(args.model_dir, 'decoder-9.ckpt')
    encoder = CAPTION_CNN(cfg.CAP.embed_size).cuda()
    decoder = CAPTION_RNN(
        cfg.CAP.embed_size,
        cfg.CAP.hidden_size * 2,
        dataset.n_words,
        cfg.CAP.num_layers
    ).cuda()
    # map_location='cpu' lets the checkpoints load regardless of where they
    # were saved; load_state_dict then copies the weights onto the GPU models.
    encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))
    decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))
    encoder.eval()
    decoder.eval()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    for i, data in enumerate(dataloader):