Exemplo n.º 1
0
def init_models(opt, img_shape) -> "Tuple[models.WDiscriminator, models.WDiscriminator, models.WDiscriminator, models.GeneratorConcatSkip2CleanAdd]":
    """Build the generator and the three discriminators, optionally restoring checkpoints.

    For each network, the steps are:
    - construct it;
    - move it to ``opt.device``;
    - initialise it with ``models.weights_init``;
    - if the matching path attribute on ``opt`` is not ``''``, overwrite it
      with the state dict stored at that path;
    - log it.

    Args:
        opt: Options object. This function reads ``opt.device``,
            ``opt.netG``, ``opt.netD``, ``opt.netD_mask1`` and
            ``opt.netD_mask2``. It is also forwarded to every model
            constructor.
        img_shape: Forwarded to ``models.GeneratorConcatSkip2CleanAdd``.

    Returns:
        ``(netD, netD_mask1, netD_mask2, netG)``
    """

    def _prepare(net, checkpoint_path):
        """Place *net* on the device, init its weights, optionally load *checkpoint_path*, log it."""
        net = net.to(opt.device)
        net.apply(models.weights_init)
        if checkpoint_path != '':
            # map_location remaps stored tensors onto opt.device, so a
            # checkpoint written on a GPU can be restored on a CPU-only
            # machine (or onto a different GPU) instead of failing in
            # torch.load.
            net.load_state_dict(torch.load(checkpoint_path, map_location=opt.device))
        logger.info(net)
        return net

    # Networks are built strictly one after another, in the same order as
    # before. Each one is fully initialised before the next is constructed,
    # which keeps the order of any random-number draws unchanged.

    # generator initialization:
    netG = _prepare(models.GeneratorConcatSkip2CleanAdd(opt, img_shape), opt.netG)

    # general discriminator initialization for both images:
    netD = _prepare(models.WDiscriminator(opt), opt.netD)

    # discriminator initialization for identifying the mask of the first image:
    netD_mask1 = _prepare(models.WDiscriminator(opt), opt.netD_mask1)

    # discriminator initialization for identifying the mask of the second image:
    netD_mask2 = _prepare(models.WDiscriminator(opt), opt.netD_mask2)

    return netD, netD_mask1, netD_mask2, netG
Exemplo n.º 2
0
def load_trained_pyramid(opt, mode_='train'):
    """Load a trained generator pyramid and its pickled training artefacts.

    Behaviour:
    - The save directory is resolved with ``generate_dir2save``.
    - While that happens, ``opt.mode`` is temporarily set to ``'train'``,
      except for the ``'animation_train'``, ``'SR_train'`` and
      ``'paint_train'`` modes, which are kept as-is.
    - ``opt.mode`` is always restored before the function returns or raises.
    - Generators are loaded from the consecutive sub-directories
      ``<opt.out_>/0``, ``<opt.out_>/1``, ... until the first missing one.

    Args:
        opt: Options object. Side effect: ``opt.out_`` is rebound to
            ``os.path.join(opt.out, opt.out_)``.
        mode_: Unused; kept only so existing callers keep working.

    Returns:
        ``(Gs, Zs, reals, NoiseAmp)``. ``Gs`` is the list of loaded
        generators; the other three are the unpickled contents of
        ``Zs.pkl``, ``reals.pkl`` and ``NoiseAmp.pkl`` in ``opt.out_``.

    Raises:
        FileNotFoundError: If the directory returned by
            ``generate_dir2save`` does not exist. Previously this path
            printed a message and then failed with ``UnboundLocalError``
            on ``Zs``.
    """
    mode = opt.mode
    # Modes with their own save directory keep their mode. Every other mode
    # resolves the directory of a plain 'train' run.
    opt.mode = mode if mode in ('animation_train', 'SR_train', 'paint_train') else 'train'
    try:
        save_dir = generate_dir2save(opt)
        opt.out_ = os.path.join(opt.out, opt.out_)
        if not os.path.exists(save_dir):
            raise FileNotFoundError(
                'no appropriate trained model exists at %s, please train first' % save_dir)

        # One generator per scale, stored in numbered sub-directories.
        # Stop at the first index that has no directory.
        Gs = []
        scale = 0
        while os.path.isdir(os.path.join(opt.out_, str(scale))):
            netG = models.GeneratorConcatSkip2CleanAdd(opt)
            netG.load_weights(os.path.join(opt.out_, str(scale), 'netG'))
            Gs.append(netG)
            scale += 1

        # NOTE: pickle.load can execute arbitrary code; these files are
        # assumed to be this program's own training outputs. Never point
        # opt.out / opt.out_ at untrusted data.
        with open(os.path.join(opt.out_, 'Zs.pkl'), 'rb') as f:
            Zs = pickle.load(f)
        with open(os.path.join(opt.out_, 'reals.pkl'), 'rb') as f:
            reals = pickle.load(f)
        with open(os.path.join(opt.out_, 'NoiseAmp.pkl'), 'rb') as f:
            NoiseAmp = pickle.load(f)
    finally:
        # Always undo the temporary mode switch, even on error.
        opt.mode = mode
    return Gs, Zs, reals, NoiseAmp
Exemplo n.º 3
0
def init_models(opt):
    """Build the generator and discriminator, optionally restoring checkpoints.

    Each network is:
    - moved to ``opt.device``;
    - initialised with ``models.weights_init``;
    - when ``opt.netG`` / ``opt.netD`` is not ``''``, overwritten with the
      state dict stored at that path;
    - printed.

    Args:
        opt: Options object. This function reads ``opt.device``,
            ``opt.netG`` and ``opt.netD``; it is also forwarded to the
            model constructors.

    Returns:
        ``(netD, netG)``
    """
    # Model initialization (generator):
    netG = models.GeneratorConcatSkip2CleanAdd(opt).to(opt.device)
    netG.apply(models.weights_init)
    if opt.netG != '':
        # map_location lets a GPU-saved checkpoint load on CPU / another device.
        netG.load_state_dict(torch.load(opt.netG, map_location=opt.device))
    print(netG)

    # discriminator initialization:
    netD = models.WDiscriminator(opt).to(opt.device)
    netD.apply(models.weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD, map_location=opt.device))
    print(netD)

    return netD, netG
Exemplo n.º 4
0
def init_models(opt):
    """Build the generator and discriminator, optionally restoring checkpoints.

    Each network is:
    - moved to ``opt.device``;
    - initialised with ``models.weights_init``;
    - when ``opt.netG`` / ``opt.netD`` is not ``''``, overwritten with the
      state dict stored at that path;
    - printed.

    Args:
        opt: Options object. This function reads ``opt.device``,
            ``opt.netG`` and ``opt.netD``; it is also forwarded to the
            model constructors.

    Returns:
        ``(netD, netG)``
    """
    # opt.min_nfc is the main parameter that controls the width (filter number) for each layer
    # generator initialization:
    netG = models.GeneratorConcatSkip2CleanAdd(opt).to(opt.device)
    netG.apply(models.weights_init)  # apply weight initialize function of models.
    if opt.netG != '':
        # map_location lets a GPU-saved checkpoint load on CPU / another device.
        netG.load_state_dict(torch.load(opt.netG, map_location=opt.device))
    print(netG)

    # discriminator initialization:
    netD = models.WDiscriminator(opt).to(opt.device)
    netD.apply(models.weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD, map_location=opt.device))
    print(netD)

    return netD, netG
Exemplo n.º 5
0
def init_models(opt):
    """Construct a discriminator/generator pair from ``opt``.

    Only construction happens here. There is no device placement, no
    weight initialisation and no checkpoint loading.

    Returns:
        ``(netD, netG)``: a ``models.WDiscriminator`` and a
        ``models.GeneratorConcatSkip2CleanAdd``, both built from ``opt``.
    """
    # Tuple items are evaluated left to right, so the discriminator is still
    # constructed before the generator. That order may matter if the
    # constructors draw from a shared random state; not verified here.
    return (
        models.WDiscriminator(opt),
        models.GeneratorConcatSkip2CleanAdd(opt),
    )