Code Example #1
File: trainer.py  Project: loveisessential/showcase
def load_network(gpus):
    """hmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmx
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)
    """
    #hmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxstart
    netG = G_NET_hmx()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)
    #hmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxend

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
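
For context, a minimal sketch of how the values returned above might be consumed; it assumes cfg has already been populated from a config file and that gpus is a list of CUDA device indices. The variable names below are illustrative, not taken from the project.

# Hypothetical call site for load_network (assumes cfg is already configured).
gpus = [0, 1]
netG, netsD, num_Ds, inception_model, count = load_network(gpus)

# count is 0 for a fresh run, or the checkpoint index + 1 when
# cfg.TRAIN.NET_G points at an existing netG_<count>.pth file.
print('resuming from iteration', count, 'with', num_Ds, 'discriminators')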
Code Example #2
    def load_Dnet(self, gpus):
        if cfg.TRAIN.NET_D == '':
            print('Error: the path for models is not found!')
            sys.exit(-1)
        else:
            netsD = []
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            if cfg.TREE.BRANCH_NUM > 3:
                netsD.append(D_NET512())
            if cfg.TREE.BRANCH_NUM > 4:
                netsD.append(D_NET1024())

            self.num_Ds = len(netsD)
            for i in range(len(netsD)):
                netsD[i].apply(weights_init)
                netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)

            for i in range(cfg.TREE.BRANCH_NUM):
                print('Load %snetD%d_70000.pth' % (cfg.TRAIN.NET_D, i))
                state_dict = torch.load(
                    '%snetD%d_70000.pth' % (cfg.TRAIN.NET_D, i),
                    map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

            if cfg.CUDA:
                for i in range(len(netsD)):
                    netsD[i].cuda()
            return netsD
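
One detail worth noting in the snippet above: map_location=lambda storage, loc: storage tells torch.load to deserialize the checkpoint onto the CPU even if it was saved from GPU tensors; the modules are only moved to the GPU afterwards when cfg.CUDA is set. A minimal illustration (the file name here is hypothetical, following the '%snetD%d_70000.pth' pattern used above):

import torch

# Both forms keep the loaded tensors on the CPU regardless of where
# the checkpoint was originally saved.
state_dict = torch.load('netD0_70000.pth', map_location=lambda storage, loc: storage)
state_dict = torch.load('netD0_70000.pth', map_location='cpu')  # equivalent shorthand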
Code Example #3
File: trainer.py  Project: synsypa/Recipe2ImageGAN
def load_network(gpus):
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        try:
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            count = cfg.TRAIN.NET_G[istart:iend]
            count = int(count)
        except Exception:
            # The checkpoint name has no trailing '_<count>' suffix; fall back
            # to the counter recorded by the previous run.
            last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
            with open(last_run_dir + '/count.txt', 'r') as f:
                count = int(f.read())

        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
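
The count recovery above depends on the generator checkpoint name ending in an underscore-separated iteration number (e.g. netG_70000.pth). A standalone sketch of that parsing step, using a hypothetical path:

def parse_count(path):
    """Extract the trailing iteration number from a name like 'netG_70000.pth'."""
    istart = path.rfind('_') + 1
    iend = path.rfind('.')
    return int(path[istart:iend])

print(parse_count('output/Model/netG_70000.pth'))  # -> 70000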
Code Example #4
def load_network(gpus:list, distributed:bool):
    netG = G_NET()
    netG.apply(weights_init)
    if distributed:
        netG = netG.cuda()
        netG = torch.nn.parallel.DistributedDataParallel(netG, device_ids=gpus, output_device=gpus[0], broadcast_buffers=True)
    else:
        if cfg.CUDA:
            netG = netG.cuda()
        netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:
    # netsD_module = nn.ModuleList(netsD)
    # netsD_module.apply(weights_init)
    # netsD_module = torch.nn.parallel.DistributedDataParallel(netsD_module.cuda(), device_ids=gpus, output_device=gpus[0])
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        if distributed:
            netsD[i] = torch.nn.parallel.DistributedDataParallel(netsD[i].cuda(), device_ids=gpus, output_device=gpus[0], broadcast_buffers=True
                # , process_group=pg_Ds[i]
                )
        else:
            netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i), map_location=lambda storage, loc: storage)
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    
    if not distributed:
        if cfg.CUDA:
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
            inception_model = inception_model.cuda()
        inception_model = torch.nn.DataParallel(inception_model, device_ids=gpus)
    else:
        inception_model = torch.nn.parallel.DistributedDataParallel(inception_model.cuda(), device_ids=gpus, output_device=gpus[0])
    # inception_model = inception_model.cpu() #to(torch.device("cuda:{}".format(gpus[0])))
    inception_model.eval()
    print("model device, G:{}, D:{}, incep:{}".format(netG.device_ids, netsD[0].device_ids, inception_model.device_ids))
    return netG, netsD, len(netsD), inception_model, count
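
A hedged sketch of the call site this distributed variant appears to expect: one process per GPU, launched with torchrun and the NCCL backend. The environment-variable name and the torch.distributed calls are standard PyTorch; the surrounding script structure is an assumption, not code from the project.

import os
import torch
import torch.distributed as dist

# One process per GPU; torchrun sets LOCAL_RANK for each worker.
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

# Each process wraps its copies of the models around its own device only.
netG, netsD, num_Ds, inception_model, count = load_network(
    gpus=[local_rank], distributed=True)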