def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        epoch = 0

        if self.resume:
            checkpoint_list = sorted(glob.glob(self.model_dir + '/*.pth'))
            latest_checkpoint = checkpoint_list[-1]
            state_dict = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            for i in range(len(netsD)):
                netsD[i].load_state_dict(state_dict["netD"][i])
            epoch = int(latest_checkpoint[-8:-4]) + 1
            print("Resuming training from checkpoint {} at epoch {}.".format(latest_checkpoint, epoch))

        #
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
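The resume branch above implies a specific checkpoint layout: a single dict holding the generator weights under 'netG' and a list of discriminator state dicts under 'netD', saved to a filename whose last four characters before '.pth' are a zero-padded epoch (which is what latest_checkpoint[-8:-4] parses back out). A minimal save helper consistent with those assumptions; the helper name and filename pattern are illustrative, not taken from the source:

import torch

def save_checkpoint(model_dir, netG, netsD, epoch):
    # One dict with 'netG' plus a list of 'netD' state dicts; the epoch is
    # zero-padded to four digits so latest_checkpoint[-8:-4] can recover it.
    torch.save({'netG': netG.state_dict(),
                'netD': [netD.state_dict() for netD in netsD]},
               '%s/checkpoint_%04d.pth' % (model_dir, epoch))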
Example #2
    def build_models(self):

        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        # vgg16 network
        style_loss = VGGNet()

        for p in style_loss.parameters():
            p.requires_grad = False

        print("Load the style loss model")
        style_loss.eval()

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()

            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)

        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)

            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        # Create a target network.
        target_netG = deepcopy(netG)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            style_loss = style_loss.cuda()

            # The target network is stored on the secondary GPU.
            target_netG.cuda(secondary_device)
            target_netG.ca_net.device = secondary_device

            netG.cuda()
            for i in range(len(netsD)):
                netsD[i] = netsD[i].cuda()

        # Disable training in the target network:
        for p in target_netG.parameters():
            p.requires_grad = False

        return [
            text_encoder, image_encoder, netG, target_netG, netsD, epoch,
            style_loss
        ]
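Example #2 creates target_netG as a frozen deepcopy on a secondary GPU but never shows how it is refreshed; target copies like this are usually updated with an exponential moving average of the live generator (Example #11's w_ewma applies the same idea to style weights). A minimal sketch under that assumption, with a hypothetical helper name:

import torch

def update_target(netG, target_netG, beta=0.999):
    # Hypothetical Polyak/EMA update; assumes the two networks share an
    # architecture. .to(tp.device) bridges the primary and secondary GPUs.
    with torch.no_grad():
        for p, tp in zip(netG.parameters(), target_netG.parameters()):
            tp.mul_(beta).add_(p.detach().to(tp.device), alpha=1.0 - beta)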
Example #3
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        # MODIFIED
        if cfg.PRETRAINED_G != '':
            state_dict = torch.load(cfg.PRETRAINED_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.PRETRAINED_G)
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.PRETRAINED_G
                s_tmp = Gname[:Gname.rfind('/')]
                for i in range(len(netsD)):
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    # Discriminator checkpoint names must be consistent,
                    # differing from one another only in the index i.
                    print('Load D from: ', Dname)
                    state_dict = torch.load(
                        Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
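A short illustration of the path and filename conventions the loading branches above rely on: the discriminators are expected next to the generator checkpoint as netD0.pth, netD1.pth, ..., and the resume epoch is parsed from the generator filename. The path below is hypothetical:

Gname = 'output/birds/Model/netG_epoch_600.pth'   # hypothetical path
s_tmp = Gname[:Gname.rfind('/')]                  # 'output/birds/Model'
Dname = '%s/netD%d.pth' % (s_tmp, 0)              # 'output/birds/Model/netD0.pth'
istart = Gname.rfind('_') + 1
iend = Gname.rfind('.')
epoch = int(Gname[istart:iend]) + 1               # training resumes at epoch 601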
Example #4
    def build_models(self):
        # text encoders
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        # self.n_words = 156
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Caption models - cnn_encoder and rnn_decoder
        caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
        caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage))
        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_cnn_path)
        caption_cnn.eval()

        # self.n_words = 9
        caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers)
        caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage))
        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_rnn_path)

        # Generator and Discriminator:
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]

            # print(epoch)
            # print(state_dict.keys())
            # print(netG.keys())
            # epoch = state_dict['epoch']
            epoch = int(epoch) + 1
            # epoch = 187
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            caption_cnn = caption_cnn.cuda()
            caption_rnn = caption_rnn.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]
Example #5
    def build_models(self):
        def count_parameters(model):
            total_param = 0
            for name, param in model.named_parameters():
                if param.requires_grad:
                    num_param = np.prod(param.size())
                    if param.dim() > 1:
                        print(name, ':',
                              'x'.join(str(x) for x in list(param.size())),
                              '=', num_param)
                    else:
                        print(name, ':', num_param)
                    total_param += num_param
            return total_param

        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:

        print('number of trainable parameters (netG) =', count_parameters(netG))
        print('number of trainable parameters (netD) =', count_parameters(netsD[-1]))

        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
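The verbose count_parameters helper above also prints a per-layer breakdown; for reference, the same trainable-parameter total can be computed in one line:

def count_parameters_compact(model):
    # Same total as the helper above, minus the per-layer printout.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)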
Example #6
    def build_models(self):
        ################### Text and Image encoders ########################################
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return
        """
        image_encoder = CNN_dummy()
        image_encoder.cuda()
        image_encoder.eval()
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()
        """

        VGG = VGG16()
        VGG.eval()
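        # Note: unlike Example #2, the VGG parameters are not explicitly frozen
        # here; eval() only switches layer behavior and does not stop gradients.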

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ####################### Generator and Discriminators ##############
        from model import D_NET256
        netD = D_NET256()
        netG = EncDecNet()

        netD.apply(weights_init)
        netG.apply(weights_init)

        #
        epoch = 0
        """
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        """
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            netG.cuda()
            netD.cuda()
            VGG.cuda()
        return [text_encoder, netG, netD, epoch, VGG]
Example #7
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netG = G_NET(len(self.cats_index_dict))
        netsPatD, netsShpD = [], []
        if cfg.TREE.BRANCH_NUM > 0:
            netsPatD.append(PAT_D_NET64())
            netsShpD.append(SHP_D_NET64(len(self.cats_index_dict)))
        if cfg.TREE.BRANCH_NUM > 1:
            netsPatD.append(PAT_D_NET128())
            netsShpD.append(SHP_D_NET128(len(self.cats_index_dict)))
        if cfg.TREE.BRANCH_NUM > 2:
            netsPatD.append(PAT_D_NET256())
            netsShpD.append(SHP_D_NET256(len(self.cats_index_dict)))

        netObjSSD = OBJ_SS_D_NET(len(self.cats_index_dict))
        netObjLSD = OBJ_LS_D_NET(len(self.cats_index_dict))

        netG.apply(weights_init)
        netObjSSD.apply(weights_init)
        netObjLSD.apply(weights_init)
        for i in range(len(netsPatD)):
            netsPatD[i].apply(weights_init)
            netsShpD[i].apply(weights_init)
        print('# of netsPatD', len(netsPatD))
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            netObjSSD.cuda()
            netObjLSD.cuda()
            for i in range(len(netsPatD)):
                netsPatD[i].cuda()
                netsShpD[i].cuda()

            if len(cfg.GPU_IDS) > 1:
                text_encoder = nn.DataParallel(text_encoder)
                text_encoder.to(self.device)
                image_encoder = nn.DataParallel(image_encoder)
                image_encoder.to(self.device)
                netG = nn.DataParallel(netG)
                netG.to(self.device)
                netObjSSD = nn.DataParallel(netObjSSD)
                netObjSSD.to(self.device)
                netObjLSD = nn.DataParallel(netObjLSD)
                netObjLSD.to(self.device)
                for i in range(len(netsPatD)):
                    netsPatD[i] = nn.DataParallel(netsPatD[i])
                    netsPatD[i].to(self.device)

                    netsShpD[i] = nn.DataParallel(netsShpD[i])
                    netsShpD[i].to(self.device)
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1

            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsPatD)):
                s_tmp = Gname[:Gname.rfind('/')]

                Dname = '%s/netPatD%d.pth' % (s_tmp, i)
                print('Load PatD from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsPatD[i].load_state_dict(state_dict)

                Dname = '%s/netShpD%d.pth' % (s_tmp, i)
                print('Load ShpD from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsShpD[i].load_state_dict(state_dict)

            s_tmp = Gname[:Gname.rfind('/')]
            Dname = '%s/netObjSSD.pth' % (s_tmp)
            print('Load ObjSSD from: ', Dname)
            state_dict = \
                torch.load(Dname, map_location=lambda storage, loc: storage)
            netObjSSD.load_state_dict(state_dict)

            s_tmp = Gname[:Gname.rfind('/')]
            Dname = '%s/netObjLSD.pth' % (s_tmp)
            print('Load ObjLSD from: ', Dname)
            state_dict = \
                torch.load(Dname, map_location=lambda storage, loc: storage)
            netObjLSD.load_state_dict(state_dict)

        return [
            text_encoder, image_encoder, netG, netsPatD, netsShpD, netObjSSD,
            netObjLSD, epoch
        ]
Example #8
image_encoder = CNN_ENCODER(args.image_size)
img_encoder_path = '../DAMSMencoders/' + args.dataset + '/image_encoder200.pth'
state_dict = torch.load(img_encoder_path,
                        map_location=lambda storage, loc: storage)
image_encoder.load_state_dict(state_dict)
for p in image_encoder.parameters():
    p.requires_grad = False
print('Load image encoder from:', img_encoder_path)
image_encoder.eval()

text_encoder = RNN_ENCODER(dataset_train.n_words, nhidden=args.image_size)
text_encoder_path = '../DAMSMencoders/' + args.dataset + '/text_encoder200.pth'
state_dict = torch.load(text_encoder_path,
                        map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(state_dict)
for p in text_encoder.parameters():
    p.requires_grad = False
print('Load text encoder from:', text_encoder_path)
text_encoder.eval()

if use_cuda:
    text_encoder = text_encoder.cuda()
    image_encoder = image_encoder.cuda()


def prepare_labels():
    batch_size = args.batch_size
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    match_labels = Variable(torch.LongTensor(range(batch_size)))
    if use_cuda:
        # Completion of the truncated snippet, assuming the standard AttnGAN
        # pattern of moving the label tensors to the GPU.
        real_labels = real_labels.cuda()
        fake_labels = fake_labels.cuda()
        match_labels = match_labels.cuda()
    return real_labels, fake_labels, match_labels
Example #9
    def build_models(self):
        ################### Text and Image encoders ########################################
        # if cfg.TRAIN.NET_E == '':
        #   print('Error: no pretrained text-image encoders')
        #   return

        VGG = VGGNet()

        for p in VGG.parameters():
            p.requires_grad = False

        print("Load the VGG model")
        VGG.eval()

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)

        # when NET_E = '', train the image_encoder and text_encoder jointly
        if cfg.TRAIN.NET_E != '':
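            # Note: unlike the other examples, NET_E is assumed here to hold a
            # pickled model object rather than a bare state dict, hence the
            # .state_dict() call on the loaded result.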
            state_dict = torch.load(
                cfg.TRAIN.NET_E,
                map_location=lambda storage, loc: storage).state_dict()
            text_encoder.load_state_dict(state_dict)
            for p in text_encoder.parameters():
                p.requires_grad = False
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder.eval()

            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = torch.load(
                img_encoder_path,
                map_location=lambda storage, loc: storage).state_dict()
            image_encoder.load_state_dict(state_dict)
            for p in image_encoder.parameters():
                p.requires_grad = False
            print('Load image encoder from:', img_encoder_path)
            image_encoder.eval()

        ####################### Generator and Discriminators ##############
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            netG = G_DCGAN()
            if cfg.TRAIN.W_GAN:
                netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            netG.apply(weights_init)
            if cfg.TRAIN.W_GAN:
                if cfg.TREE.BRANCH_NUM > 0:
                    netsD.append(D_NET64())
                if cfg.TREE.BRANCH_NUM > 1:
                    netsD.append(D_NET128())
                if cfg.TREE.BRANCH_NUM > 2:
                    netsD.append(D_NET256())
                for i in range(len(netsD)):
                    netsD[i].apply(weights_init)

        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = torch.load(cfg.TRAIN.NET_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                      torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            VGG = VGG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch, VGG]
Example #10
    def build_models(self):
        # ################### models ######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return
        if cfg.TRAIN.NET_G == '':
            print('Error: no pretrained main module')
            return

        VGG = VGGNet()

        for p in VGG.parameters():
            p.requires_grad = False

        print("Load the VGG model")
        VGG.eval()

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
            from model import D_NET256 as D_NET
            netD = D_NET(b_jcu=False)
        else:
            from model import D_NET256
            netG = G_NET()
            netD = D_NET256()

        netD.apply(weights_init)

        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        netG.eval()
        print('Load G from: ', cfg.TRAIN.NET_G)

        epoch = 0
        netDCM = DCM_Net()
        if cfg.TRAIN.NET_C != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_C, map_location=lambda storage, loc: storage)
            netDCM.load_state_dict(state_dict)
            print('Load DCM from: ', cfg.TRAIN.NET_C)
            istart = cfg.TRAIN.NET_C.rfind('_') + 1
            iend = cfg.TRAIN.NET_C.rfind('.')
            epoch = cfg.TRAIN.NET_C[istart:iend]
            epoch = int(epoch) + 1

        if cfg.TRAIN.NET_D != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_D, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load DCM Discriminator from: ', cfg.TRAIN.NET_D)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            netDCM.cuda()
            VGG = VGG.cuda()
            netD.cuda()
        return [text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]
Example #11
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            raise FileNotFoundError(
                'No pretrained text encoder found in directory DAMSMencoders/.\n'
                'Please train the DAMSM first before training the GAN '
                '(see README for details).')

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        if self.text_encoder_type == 'rnn':
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        elif self.text_encoder_type == 'transformer':
            text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        elif cfg.GAN.B_STYLEGEN:
            netG = G_NET_STYLED()
            if cfg.GAN.B_STYLEDISC:
                from model import D_NET_STYLED64, D_NET_STYLED128, D_NET_STYLED256
                if cfg.TREE.BRANCH_NUM > 0:
                    netsD.append(D_NET_STYLED64())
                if cfg.TREE.BRANCH_NUM > 1:
                    netsD.append(D_NET_STYLED128())
                if cfg.TREE.BRANCH_NUM > 2:
                    netsD.append(D_NET_STYLED256())
                # TODO: if cfg.TREE.BRANCH_NUM > 3:
            else:
                from model import D_NET64, D_NET128, D_NET256
                if cfg.TREE.BRANCH_NUM > 0:
                    netsD.append(D_NET64())
                if cfg.TREE.BRANCH_NUM > 1:
                    netsD.append(D_NET128())
                if cfg.TREE.BRANCH_NUM > 2:
                    netsD.append(D_NET256())
                # TODO: if cfg.TREE.BRANCH_NUM > 3:
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
            netG.apply(weights_init)
            # print(netG)
            for i in range(len(netsD)):
                netsD[i].apply(weights_init)
                # print(netsD[i])
        print(netG.__class__)
        for i in netsD:
            print(i.__class__)
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = torch.load(cfg.TRAIN.NET_G,
                                    map_location=lambda storage, loc: storage)
            if cfg.GAN.B_STYLEGEN:
                netG.w_ewma = state_dict['w_ewma']
                if cfg.CUDA:
                    netG.w_ewma = netG.w_ewma.to('cuda:' + str(cfg.GPU_ID))
                netG.load_state_dict(state_dict['netG_state_dict'])
            else:
                netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
Example #12
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            text_encoder = nn.DataParallel(text_encoder)
            image_encoder = nn.DataParallel(image_encoder)
            netG = nn.DataParallel(netG)
            for i in range(len(netsD)):
                netsD[i] = nn.DataParallel(netsD[i])

        image_encoder.to(self.device)
        text_encoder.to(self.device)
        netG.to(self.device)
        for i in range(len(netsD)):
            netsD[i].to(self.device)

        # if cfg.CUDA and torch.cuda.is_available():
        #     text_encoder = text_encoder.cuda()
        #     image_encoder = image_encoder.cuda()
        #     netG.cuda()
        #     for i in range(len(netsD)):
        #         netsD[i].cuda()

        # if cfg.PARALLEL:
        #     netG = torch.nn.DataParallel(netG, device_ids=[0, 1, 2])
        #     text_encoder = torch.nn.DataParallel(text_encoder, device_ids=[0, 1, 2])
        #     image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[0, 1, 2])
        #     for i in range(len(netsD)):
        #         netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=[0, 1, 2])

        return [text_encoder, image_encoder, netG, netsD, epoch]
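A caveat on Example #12: the checkpoints are loaded before the nn.DataParallel wrapping, which matters because DataParallel prefixes every parameter key with 'module.'. A checkpoint saved from a wrapped model therefore needs that prefix stripped before it fits an unwrapped one; a small helper sketch (the name is illustrative):

def unwrap_state_dict(state_dict):
    # Strip the 'module.' prefix that nn.DataParallel prepends to keys.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}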
Example #13
    def build_models(self):
        print('Building models...')
        print('N_words: ', self.n_words)

        #####################
        ##  TEXT ENCODERS  ##
        #####################
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Built image encoder: ', image_encoder)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Built text encoder: ', text_encoder)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ######################
        ##  CAPTION MODELS  ##
        ######################

        # cnn_encoder and rnn_encoder
        if cfg.CAP.USE_ORIGINAL:
            caption_cnn = CAPTION_CNN(embed_size=cfg.TEXT.EMBEDDING_DIM)
            caption_rnn = CAPTION_RNN(embed_size=cfg.TEXT.EMBEDDING_DIM,
                                      hidden_size=cfg.CAP.HIDDEN_SIZE,
                                      vocab_size=self.n_words,
                                      num_layers=cfg.CAP.NUM_LAYERS)
        else:
            caption_cnn = Encoder()
            caption_rnn = Decoder(idx2word=self.ixtoword)

        caption_cnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_CNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_rnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_RNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
        caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])

        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from: ', cfg.CAP.CAPTION_CNN_PATH)
        caption_cnn.eval()

        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from: ', cfg.CAP.CAPTION_RNN_PATH)

        #################################
        ##  GENERATOR & DISCRIMINATOR  ##
        #################################
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        text_encoder = text_encoder.to(cfg.DEVICE)
        image_encoder = image_encoder.to(cfg.DEVICE)
        caption_cnn = caption_cnn.to(cfg.DEVICE)
        caption_rnn = caption_rnn.to(cfg.DEVICE)
        netG.to(cfg.DEVICE)
        for i in range(len(netsD)):
            netsD[i].to(cfg.DEVICE)
        return [
            text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD,
            epoch
        ]
Example #14
class RAGAN():
    def name(self):
        return 'ResidualAttentionGAN'

    def initialize(self, opt):
        torch.cuda.set_device(opt.gpu)
        cudnn.benchmark = True
        self.D_len = 3
        self.opt = opt
        self.build_models()
        
        
    def build_models(self):
        ################### encoders #########################################
        self.E_image = None
        self.E_text = None
        use_con = False
        use_sigmoid = (self.opt.c_gan_mode == 'dcgan')
        if self.opt.c_type == 'image':
            if self.opt.model == 'supervised':
                self.E_image = E_ResNet_Local(input_nc=self.opt.output_nc, output_nc=self.opt.nc, nef=self.opt.nef, n_blocks=self.opt.e_blocks,norm_type=self.opt.norm)
            elif self.opt.model == 'unsupervised':
                self.E_image = CE_ResNet_Local(input_nc=self.opt.output_nc, output_nc=self.opt.nc, nef=self.opt.nef, n_blocks=self.opt.e_blocks, c_dim=self.opt.ne, norm_type=self.opt.norm)
        elif self.opt.c_type == 'text':
            use_con = True
            self.E_text =  RNN_ENCODER(self.opt.n_words, nhidden=self.opt.nc)
            state_dict = torch.load(self.opt.E_text_path, map_location=lambda storage, loc: storage)
            self.E_text.load_state_dict(state_dict)
            for p in self.E_text.parameters():
                p.requires_grad = False
            print('Load text encoder successful')
            self.E_text.eval()
        elif self.opt.c_type == 'image_text':
            use_con = True
            self.E_image = E_ResNet_Global(input_nc=self.opt.output_nc, output_nc=self.opt.ne, nef=self.opt.nef, n_blocks=self.opt.e_blocks,norm_type=self.opt.norm)
            self.E_text =  RNN_ENCODER(self.opt.n_words, nhidden=self.opt.nc)
            state_dict = torch.load(self.opt.E_text_path, map_location=lambda storage, loc: storage)
            self.E_text.load_state_dict(state_dict)
            for p in self.E_text.parameters():
                p.requires_grad = False
            print('Load text encoder successful')
            self.E_text.eval()
        elif self.opt.c_type == 'label':
            use_con = True
        elif self.opt.c_type == 'image_label':
            use_con = True
            self.E_image = E_ResNet_Global(input_nc=self.opt.output_nc, output_nc=self.opt.ne, nef=self.opt.nef, n_blocks=self.opt.e_blocks,norm_type=self.opt.norm)
        else:
            raise ValueError(
                'Unsupported conditional type: {}'.format(self.opt.c_type))
        
        
        ################### generator #########################################
        self.G = RAG_NET(input_nc=self.opt.input_nc, ngf=self.opt.ngf, nc=self.opt.nc, ne=self.opt.ne,norm_type=self.opt.norm)
        
        if self.opt.isTrain:    
            ################### discriminators #####################################
            self.Ds = []
            self.Ds2 = None
            bnf = 3 if self.opt.fineSize <=128 else 4
            self.Ds.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=bnf, nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,norm_type=self.opt.norm))
            self.Ds.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4, nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,norm_type=self.opt.norm))
            self.Ds.append(D_NET_Multi(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4, nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,norm_type=self.opt.norm))
            if self.opt.model == 'unsupervised' and self.opt.c_type == 'image':
                self.Ds2 = []
                self.Ds2.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=bnf, nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,norm_type=self.opt.norm))
                self.Ds2.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4, nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,norm_type=self.opt.norm))
                self.Ds2.append(D_NET_Multi(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4, nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,norm_type=self.opt.norm))
            ################### init_weights ########################################
            self.G.apply(weights_init(self.opt.init_type))
            for i in range(self.D_len):
                self.Ds[i].apply(weights_init(self.opt.init_type))
            if self.Ds2 is not None:
                for i in range(self.D_len):
                    self.Ds2[i].apply(weights_init(self.opt.init_type))
            if self.E_image is not None:
                self.E_image.apply(weights_init(self.opt.init_type))
                
            ################### use GPU #############################################
            self.G.cuda()
            for i in range(self.D_len):
                self.Ds[i].cuda()
            if self.Ds2 is not None:
                for i in range(self.D_len):
                    self.Ds2[i].cuda()
            if self.E_image is not None:
                self.E_image.cuda()
            if self.E_text is not None:
                self.E_text.cuda()
                
            ################### set criterion ########################################
            self.criterionGAN = GANLoss(mse_loss=True)
            if use_con:
                self.criterionCGAN = GANLoss(mse_loss = not use_sigmoid)
            self.criterionKL = KL_loss
            
            ################## define optimizers #####################################
            self.define_optimizers()
        
        
    def denorm(self,x):
        x = (x+1)/2
        return x.clamp_(0, 1)
        
    def define_optimizer(self, Net):
        return optim.Adam(Net.parameters(), lr=self.opt.lr, betas=(0.5, 0.999))

    def define_optimizers(self):
        self.G_opt = self.define_optimizer(self.G)
        self.Ds_opt = []
        self.Ds2_opt = []
        self.E_opt = None
        for i in range(self.D_len):
            self.Ds_opt.append(self.define_optimizer(self.Ds[i]))
        if self.Ds2 is not None:
            for i in range(self.D_len):
                self.Ds2_opt.append(self.define_optimizer(self.Ds2[i]))
        if self.E_image is not None:
            self.E_opt = self.define_optimizer(self.E_image)
    
    def update_lr(self, lr):
        for param_group in self.G_opt.param_groups:
            param_group['lr'] = lr
        for i in range(self.D_len):
            for param_group in self.Ds_opt[i].param_groups:
                param_group['lr'] = lr
        if self.E_opt is not None:
            for param_group in self.E_opt.param_groups:
                param_group['lr'] = lr
                
    def save(self, name):
        torch.save(self.G.state_dict(), '{}/G_{}.pth'.format(self.opt.model_dir, name))
        for i in range(self.D_len):
            torch.save(self.Ds[i].state_dict(), '{}/D_{}_{}.pth'.format(self.opt.model_dir, i, name))
        if self.E_image is not None:
            torch.save(self.E_image.state_dict(), '{}/E_image_{}.pth'.format(self.opt.model_dir, name))
            
    def get_c_random(self, size):
        c = torch.cuda.FloatTensor(size).normal_()
        return Variable(c)
    

    def prepare_label(self,data):
        imgs, c_global, c_local = data
        c_global = Variable(c_global).cuda()
        c_local = Variable(c_local).cuda()
        imgs_cuda = []
        for img in imgs:
            imgs_cuda.append(Variable(img).cuda())
        return [imgs_cuda, c_global, c_local]
        
    def prepare_image(self,data):
        imgAs, imgBs = data
        if isinstance(imgAs,list):
            imgA = []
            for img in imgAs:
                imgA.append(Variable(img).cuda())
        else:
            imgA = Variable(imgAs).cuda()
        imgB = []
        for img in imgBs:
            imgB.append(Variable(img).cuda())
        return [imgA, imgB]
    
    def prepare_text(self,data):
        imgs, captions, captions_lens, class_ids, keys = data
        sorted_cap_lens, sorted_cap_indices = torch.sort(captions_lens, 0, True)
        real_imgs = []
        for i in range(len(imgs)):
            imgs[i] = imgs[i][sorted_cap_indices]
            real_imgs.append(Variable(imgs[i]).cuda())
        captions = captions[sorted_cap_indices].squeeze()
        class_ids = class_ids[sorted_cap_indices].numpy()
        keys = [keys[i] for i in sorted_cap_indices.numpy()]
        captions = Variable(captions).cuda()
        sorted_cap_lens = Variable(sorted_cap_lens).cuda()

        return [real_imgs, captions, sorted_cap_lens,
                    class_ids, keys]
            
    
    def get_current_errors(self):
        pass

    def get_current_visuals(self):
        pass

    def update_model(self, data):
        pass