コード例 #1
0
ファイル: trainer.py プロジェクト: ivanL96/text2imagesGAN
    def load_network_stageII(self):
        """Build the Stage-II generator/discriminator pair and restore weights.

        The generator is a ``STAGE2_G`` wrapping a new ``STAGE1_G``; both
        networks get ``weights_init`` applied.  Generator weights are read
        from the ``"netG"`` entry of the checkpoint at ``cfg.NET_G`` when that
        path is set; otherwise only the embedded Stage-I generator is restored
        from the ``"netG"`` entry of ``cfg.STAGE1_G``.  Discriminator weights
        are loaded from ``cfg.NET_D`` when set (the checkpoint object is
        passed to ``load_state_dict`` as-is).

        Returns:
            ``(netG, netD)``, after calling ``.cuda()`` on both when
            ``cfg.CUDA`` is truthy; ``None`` when neither ``cfg.NET_G`` nor
            ``cfg.STAGE1_G`` is set.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        def keep_storage(storage, loc):
            # Identity map_location: hand back the storage torch.load gives us.
            return storage

        generator = STAGE2_G(STAGE1_G())
        generator.apply(weights_init)
        print(generator)

        # Without any generator checkpoint there is nothing to build on.
        if cfg.NET_G == '' and cfg.STAGE1_G == '':
            print("Please give the Stage1_G path")
            return

        if cfg.NET_G != '':
            # A full Stage-II generator checkpoint takes precedence.
            checkpoint = torch.load(cfg.NET_G, map_location=keep_storage)
            generator.load_state_dict(checkpoint["netG"])
            print('Load from: ', cfg.NET_G)
        else:
            # Only the nested Stage-I generator is restored.
            checkpoint = torch.load(cfg.STAGE1_G, map_location=keep_storage)
            generator.STAGE1_G.load_state_dict(checkpoint["netG"])
            print('Load from: ', cfg.STAGE1_G)

        discriminator = STAGE2_D()
        discriminator.apply(weights_init)
        if cfg.NET_D != '':
            checkpoint = torch.load(cfg.NET_D, map_location=keep_storage)
            discriminator.load_state_dict(checkpoint)
            print('Load from: ', cfg.NET_D)
        print(discriminator)

        if cfg.CUDA:
            generator.cuda()
            discriminator.cuda()
        return generator, discriminator
コード例 #2
0
    def load_network_stageII(self, text_dim, gf_dim, condition_dim, z_dim, df_dim, res_num):
        """Build the Stage-II generator/discriminator pair and restore weights.

        Args:
            text_dim, gf_dim, condition_dim, z_dim: forwarded to both
                ``STAGE1_G`` and ``STAGE2_G``.
            df_dim: forwarded to ``STAGE2_D`` together with ``condition_dim``.
            res_num: forwarded to ``STAGE2_G``.

        Generator weights come from ``self.net_g`` when it is set; otherwise
        only the embedded Stage-I generator is restored from
        ``self.stage1_g``.  Discriminator weights are restored from
        ``self.net_d`` when it is set.

        Returns:
            ``(netG, netD)``, after calling ``.cuda()`` on both when
            ``self.cuda`` is truthy; ``None`` when neither ``self.net_g`` nor
            ``self.stage1_g`` is set.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        def keep_storage(storage, loc):
            # Identity map_location: hand back the storage torch.load gives
            # us, so a checkpoint saved from a GPU can still be read on a
            # host without CUDA.  The networks are moved to the GPU below.
            return storage

        Stage1_G = STAGE1_G(text_dim, gf_dim, condition_dim, z_dim, self.cuda)
        netG = STAGE2_G(Stage1_G, text_dim, gf_dim, condition_dim, z_dim, res_num, self.cuda)
        netG.apply(weights_init)
        print(netG)
        if self.net_g != '':
            # Use the same map_location as the other two loads in this method.
            state_dict = torch.load(self.net_g, map_location=keep_storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', self.net_g)
        elif self.stage1_g != '':
            state_dict = torch.load(self.stage1_g, map_location=keep_storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', self.stage1_g)
        else:
            print("Please give the Stage1_G path")
            return

        netD = STAGE2_D(df_dim, condition_dim)
        netD.apply(weights_init)
        if self.net_d != '':
            state_dict = torch.load(self.net_d, map_location=keep_storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', self.net_d)
        print(netD)

        if self.cuda:
            netG.cuda()
            netD.cuda()
        return netG, netD
コード例 #3
0
def load_model(model_name):
    """Return a ``STAGE2_G`` generator restored from the file *model_name*.

    The generator wraps a default-constructed ``STAGE1_G``; the object read
    from *model_name* is passed directly to ``load_state_dict``.
    """
    from model import STAGE1_G, STAGE2_G

    def keep_storage(storage, loc):
        # Identity map_location: hand back the storage torch.load gives us.
        return storage

    generator = STAGE2_G(STAGE1_G())
    weights = torch.load(model_name, map_location=keep_storage)
    generator.load_state_dict(weights)
    return generator
コード例 #4
0
    def load_network_stageII(self):
        """Build the Stage-II networks, restore weights and find the start epoch.

        The generator is a ``STAGE2_G`` wrapping a new ``STAGE1_G``; both
        networks get ``weights_init`` applied.  Generator weights come from
        ``cfg.NET_G`` when it is set; otherwise only the embedded Stage-I
        generator is restored from ``cfg.STAGE1_G``.  Discriminator weights
        are restored from ``cfg.NET_D`` when it is set.

        When ``cfg.NET_G`` is set, the text between its last ``'_'`` and its
        last ``'.'`` is parsed as an integer and the start epoch is that
        value plus one; otherwise the start epoch is 0.

        Returns:
            ``(netG, netD, start_epoch)``, with ``.cuda()`` called on both
            networks when ``cfg.CUDA`` is truthy; ``None`` when neither
            ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        def keep_storage(storage, loc):
            # Identity map_location: hand back the storage torch.load gives us.
            return storage

        generator = STAGE2_G(STAGE1_G())
        generator.apply(weights_init)
        print(generator)

        # Without any generator checkpoint there is nothing to build on.
        if cfg.NET_G == '' and cfg.STAGE1_G == '':
            print("Please give the Stage1_G path")
            return

        if cfg.NET_G != '':
            weights = torch.load(cfg.NET_G, map_location=keep_storage)
            generator.load_state_dict(weights)
            print('Load from: ', cfg.NET_G)
        else:
            # Only the nested Stage-I generator is restored.
            weights = torch.load(cfg.STAGE1_G, map_location=keep_storage)
            generator.STAGE1_G.load_state_dict(weights)
            print('Load from: ', cfg.STAGE1_G)

        discriminator = STAGE2_D()
        discriminator.apply(weights_init)
        if cfg.NET_D != '':
            weights = torch.load(cfg.NET_D, map_location=keep_storage)
            discriminator.load_state_dict(weights)
            print('Load from: ', cfg.NET_D)
        print(discriminator)

        if cfg.CUDA:
            generator.cuda()
            discriminator.cuda()

        if cfg.NET_G != '':
            # The epoch number is the text between the last '_' and last '.'.
            ckpt_path = cfg.NET_G
            epoch_text = ckpt_path[ckpt_path.rfind('_') + 1:ckpt_path.rfind('.')]
            resume_epoch = int(epoch_text) + 1
        else:
            resume_epoch = 0

        return generator, discriminator, resume_epoch
コード例 #5
0
    def load_network_stageII(self):
        """Build the Stage-II generator/discriminator pair and restore weights.

        The generator is a ``STAGE2_G`` wrapping a new ``STAGE1_G``; both
        networks get ``weights_init`` applied.  The current working directory
        and ``cfg.NET_G`` are printed for diagnostics before loading.
        Generator weights are read from the ``"netG"`` entry of the
        checkpoint at ``cfg.NET_G`` when that path is set; otherwise only the
        embedded Stage-I generator is restored from the ``"netG"`` entry of
        ``cfg.STAGE1_G``.  Discriminator weights are read from the ``"netD"``
        entry of ``cfg.NET_D`` when set.

        Returns:
            ``(netG, netD)``, after calling ``.cuda()`` on both when
            ``cfg.CUDA`` is truthy; ``None`` when neither ``cfg.NET_G`` nor
            ``cfg.STAGE1_G`` is set.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        def keep_storage(storage, loc):
            # Identity map_location: hand back the storage torch.load gives us.
            return storage

        generator = STAGE2_G(STAGE1_G())
        generator.apply(weights_init)
        print(generator)

        # Diagnostic banner showing where relative checkpoint paths resolve.
        rule = "---------------------------------------------------------------------------------"
        print(rule)
        print("Current Working Directory : ", os.getcwd())
        print("cfg.NET_G : ", cfg.NET_G)
        print(rule)

        # Without any generator checkpoint there is nothing to build on.
        if cfg.NET_G == '' and cfg.STAGE1_G == '':
            print("Please give the Stage1_G path")
            return

        if cfg.NET_G != '':
            print("THIS THE CURRENT DIRECTORY : ", os.getcwd())
            checkpoint = torch.load(cfg.NET_G, map_location=keep_storage)
            generator.load_state_dict(checkpoint["netG"])
            print('Load from: ', cfg.NET_G)
        else:
            # Only the nested Stage-I generator is restored.
            checkpoint = torch.load(cfg.STAGE1_G, map_location=keep_storage)
            generator.STAGE1_G.load_state_dict(checkpoint["netG"])
            print('Load from: ', cfg.STAGE1_G)

        discriminator = STAGE2_D()
        discriminator.apply(weights_init)
        if cfg.NET_D != '':
            checkpoint = torch.load(cfg.NET_D, map_location=keep_storage)
            discriminator.load_state_dict(checkpoint["netD"])
            print('Load from: ', cfg.NET_D)
        print(discriminator)

        if cfg.CUDA:
            generator.cuda()
            discriminator.cuda()
        return generator, discriminator