def create_MyDNN(opt):
    """Construct a ``network.MyDNN`` and run ``network.weights_init`` on it.

    Args:
        opt: options object. It is passed straight to the ``MyDNN``
            constructor; ``opt.init_type`` and ``opt.init_gain`` are forwarded
            to ``network.weights_init``.

    Returns:
        The newly constructed ``MyDNN`` instance.
    """
    model = network.MyDNN(opt)
    init_kwargs = dict(init_type=opt.init_type, init_gain=opt.init_gain)
    network.weights_init(model, **init_kwargs)
    print('MyDNN is created!')
    return model
예제 #2
0
def create_discriminator(opt):
    """Construct a ``network.PatchDiscriminator`` and initialize its weights.

    The creation message is printed before ``network.weights_init`` runs; the
    chosen ``opt.init_type`` is reported afterwards.

    Args:
        opt: options object passed to the constructor; its ``init_type`` and
            ``init_gain`` attributes are forwarded to ``network.weights_init``.

    Returns:
        The ``PatchDiscriminator`` instance.
    """
    net_d = network.PatchDiscriminator(opt)
    print('Discriminator is created!')
    network.weights_init(net_d, init_type=opt.init_type, init_gain=opt.init_gain)
    print('Initialize discriminator with %s type' % opt.init_type)
    return net_d
예제 #3
0
def create_generator(opt):
    """Construct a ``network.GatedGenerator`` and initialize its weights.

    No checkpoint loading happens here; the network is always freshly
    initialized via ``network.weights_init``.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``GatedGenerator`` instance.
    """
    net_g = network.GatedGenerator(opt)
    print('Generator is created!')
    scheme, gain = opt.init_type, opt.init_gain
    network.weights_init(net_g, init_type=scheme, init_gain=gain)
    print('Initialize generator with %s type' % opt.init_type)
    return net_g
예제 #4
0
def create_discriminator(opt):
    """Construct a ``network.AdversarialDiscriminator`` and initialize it.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``AdversarialDiscriminator`` instance.
    """
    critic = network.AdversarialDiscriminator(opt)
    init_kwargs = {'init_type': opt.init_type, 'init_gain': opt.init_gain}
    network.weights_init(critic, **init_kwargs)
    print('Discriminators is created!')
    return critic
def create_CBD(opt):
    """Construct a ``network.CBD`` and run ``network.weights_init`` on it.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``CBD`` instance.
    """
    cbd_net = network.CBD(opt)
    network.weights_init(cbd_net,
                         init_type=opt.init_type,
                         init_gain=opt.init_gain)
    print('CBD is created!')
    return cbd_net
def create_DnCNN(opt):
    """Construct a ``network.DnCNN`` and run ``network.weights_init`` on it.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``DnCNN`` instance.
    """
    denoiser = network.DnCNN(opt)
    init_kwargs = dict(init_type=opt.init_type, init_gain=opt.init_gain)
    network.weights_init(denoiser, **init_kwargs)
    print('DnCNN is created!')
    return denoiser
def create_UresNet(opt):
    """Construct a ``network.UresNet`` and run ``network.weights_init`` on it.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``UresNet`` instance.
    """
    model = network.UresNet(opt)
    scheme, gain = opt.init_type, opt.init_gain
    network.weights_init(model, init_type=scheme, init_gain=gain)
    print('UresNet is created!')
    return model
def create_CBD_generator(opt):
    """Construct a ``network.CBD_Generator`` and run ``network.weights_init``.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``CBD_Generator`` instance.
    """
    gen = network.CBD_Generator(opt)
    network.weights_init(gen,
                         init_type=opt.init_type,
                         init_gain=opt.init_gain)
    print('CBD_generator is created!')
    return gen
def create_FaceDNN(opt):
    """Construct a ``network.FaceDNN`` and run ``network.weights_init`` on it.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``FaceDNN`` instance.
    """
    face_net = network.FaceDNN(opt)
    init_kwargs = {'init_type': opt.init_type, 'init_gain': opt.init_gain}
    network.weights_init(face_net, **init_kwargs)
    print('FaceDNN is created!')
    return face_net
예제 #10
0
def create_discriminator(opt):
    """Construct a ``network.PatchDiscriminator70`` and initialize its weights.

    Unlike the other factories, this one prints nothing.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``PatchDiscriminator70`` instance.
    """
    patch_d = network.PatchDiscriminator70(opt)
    scheme, gain = opt.init_type, opt.init_gain
    network.weights_init(patch_d, init_type=scheme, init_gain=gain)
    return patch_d
예제 #11
0
def create_discriminator(opt):
    """Construct two independent ``network.PatchDiscriminator70`` networks.

    Both networks are built from the same ``opt`` first, then each is passed
    to ``network.weights_init`` in order (first, then second).

    Args:
        opt: options object passed to each constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        A 2-tuple ``(discriminator_a, discriminator_b)``.
    """
    pair = (network.PatchDiscriminator70(opt), network.PatchDiscriminator70(opt))
    for disc in pair:
        network.weights_init(disc, init_type=opt.init_type, init_gain=opt.init_gain)
    print('Discriminators is created!')
    return pair
예제 #12
0
def create_encoder(opt):
    """Construct a ``network.Encoder`` and run ``network.weights_init`` on it.

    Args:
        opt: options object passed to the constructor; ``opt.init_type`` and
            ``opt.init_gain`` are forwarded to ``network.weights_init``.

    Returns:
        The ``Encoder`` instance.
    """
    enc = network.Encoder(opt)
    init_kwargs = dict(init_type=opt.init_type, init_gain=opt.init_gain)
    network.weights_init(enc, **init_kwargs)
    print('Encoder is created!')
    return enc
예제 #13
0
파일: utils.py 프로젝트: liujikun/ESRGAN
def create_ESRGAN_discriminator(opt):
    """Construct the ESRGAN VGG-style discriminator and initialize it.

    The architecture arguments are fixed (``in_nc=3, base_nf=64``); ``opt``
    only supplies the initialization settings.

    Args:
        opt: options object whose ``init_type`` and ``init_gain`` are
            forwarded to ``network.weights_init``.

    Returns:
        The ``SR_VGG128_Discriminator`` instance.
    """
    vgg_d = network.SR_VGG128_Discriminator(in_nc=3, base_nf=64)
    scheme, gain = opt.init_type, opt.init_gain
    network.weights_init(vgg_d, init_type=scheme, init_gain=gain)
    print('ESRGAN discriminator is created!')
    return vgg_d
예제 #14
0
def create_generator(opt):
    """Construct a ``network.Generator``, initialize it, and optionally fine-tune-load.

    The network is always passed to ``network.weights_init`` first. If
    ``opt.finetune_path`` is set, a checkpoint is then read with
    ``torch.load`` and merged in via the project helper ``load_dict``. The
    result of ``load_dict`` is used as the returned network.

    Args:
        opt: options object. It is passed to the ``Generator`` constructor.
            ``opt.init_type`` / ``opt.init_gain`` are forwarded to
            ``network.weights_init``. ``opt.finetune_path`` is an optional
            checkpoint path; an empty string or ``None`` means "no checkpoint".

    Returns:
        The generator, possibly updated from the checkpoint.
    """
    generator = network.Generator(opt)
    network.weights_init(generator,
                         init_type=opt.init_type,
                         init_gain=opt.init_gain)
    print('Generator is created!')
    # Truthiness treats both '' and None as "no checkpoint". The previous
    # `!= ""` test let None through to torch.load(None), which crashes.
    if opt.finetune_path:
        pretrained_net = torch.load(opt.finetune_path)
        generator = load_dict(generator, pretrained_net)
        print('Generator is loaded!')
    return generator
예제 #15
0
def create_generator(opt):
    """Construct a ``network.GatedGenerator`` and either load or initialize it.

    Args:
        opt: options object passed to the constructor. If ``opt.load_name`` is
            truthy, it is handed to ``load_dict`` and the result is returned.
            Otherwise ``opt.init_type`` / ``opt.init_gain`` are forwarded to
            ``network.weights_init``.

    Returns:
        The generator: the ``load_dict`` result when loading, otherwise the
        freshly initialized network.
    """
    net_g = network.GatedGenerator(opt)
    print('Generator is created!')
    if not opt.load_name:
        network.weights_init(net_g, init_type=opt.init_type, init_gain=opt.init_gain)
        print('Initialize generator with %s type' % opt.init_type)
        return net_g
    # NOTE(review): this passes the *path* to load_dict, while other factories
    # pass torch.load(path). Confirm this file's load_dict accepts a path.
    return load_dict(net_g, opt.load_name)
예제 #16
0
파일: utils.py 프로젝트: liujikun/ESRGAN
def create_ESRGAN_generator(opt):
    """Construct the ESRGAN ``network.RRDBNet`` and load or initialize it.

    The architecture arguments are fixed: positional ``3, 3, 64, 23`` and
    ``gc=32``. These presumably match the standard ESRGAN RRDBNet
    configuration — confirm against ``network.RRDBNet``.

    Args:
        opt: options object. If ``opt.ESRGAN_name`` is truthy, that checkpoint
            is read with ``torch.load`` and applied via ``load_dict``.
            Otherwise ``opt.init_type`` / ``opt.init_gain`` are forwarded to
            ``network.weights_init``.

    Returns:
        The ``RRDBNet`` instance.
    """
    rrdb = network.RRDBNet(3, 3, 64, 23, gc=32)
    if not opt.ESRGAN_name:
        scheme, gain = opt.init_type, opt.init_gain
        network.weights_init(rrdb, init_type=scheme, init_gain=gain)
    else:
        checkpoint = torch.load(opt.ESRGAN_name)
        # load_dict's return value is not used; `rrdb` itself is returned.
        load_dict(rrdb, checkpoint)
    print('ESRGAN generator is created!')
    return rrdb
예제 #17
0
def create_generator(opt):
    """Construct a ``network.Generator`` and either load a checkpoint or initialize it.

    Args:
        opt: options object passed to the constructor. If
            ``opt.load_pre_train`` is truthy, ``opt.load_name + '.pth'`` is
            read with ``torch.load`` and applied via ``load_dict``. Otherwise
            ``opt.init_type`` / ``opt.init_gain`` are forwarded to
            ``network.weights_init``.

    Returns:
        The ``Generator`` instance.
    """
    net = network.Generator(opt)
    if not opt.load_pre_train:
        network.weights_init(net, init_type=opt.init_type, init_gain=opt.init_gain)
        print('Generator is created!')
        return net
    checkpoint = torch.load(opt.load_name + '.pth')
    load_dict(net, checkpoint)
    print('Generator is loaded!')
    return net
예제 #18
0
def create_generator(opt):
    """Construct a ``network.Generator`` and initialize it or load pre-trained weights.

    Note the polarity of the flag: a truthy ``opt.pre_train`` means
    "initialize from scratch", and a falsy one means "load ``opt.load_name``".

    Args:
        opt: options object passed to the constructor; see above for how
            ``pre_train``, ``load_name``, ``init_type`` and ``init_gain`` are
            used.

    Returns:
        The ``Generator`` instance.
    """
    net = network.Generator(opt)
    if not opt.pre_train:
        # Load a pre-trained network.
        state = torch.load(opt.load_name)
        load_dict(net, state)
        print('Generator is loaded!')
        return net
    scheme, gain = opt.init_type, opt.init_gain
    network.weights_init(net, init_type=scheme, init_gain=gain)
    print('Generator is created!')
    return net
예제 #19
0
def create_generator(opt):
    """Construct a ``network.SCGAN`` colorization network and set up its weights.

    When ``opt.load_name`` is the empty string, two steps happen. First, the
    whole network is initialized via ``network.weights_init``. Then only its
    ``global_feature_network`` submodule is loaded from
    ``opt.global_feature_network_path``. Otherwise the full network is loaded
    from ``opt.load_name``.

    Args:
        opt: options object passed to the constructor; see above for the
            attributes it reads.

    Returns:
        The ``SCGAN`` instance.
    """
    net = network.SCGAN(opt)
    if opt.load_name == '':
        print('Generator is created!')
        network.weights_init(net, init_type=opt.init_type, init_gain=opt.init_gain)
        feature_path = opt.global_feature_network_path
        feature_state = torch.load(feature_path)
        load_dict(net.global_feature_network, feature_state)
        print('Generator is loaded with %s!' % (feature_path))
        return net
    full_state = torch.load(opt.load_name)
    load_dict(net, full_state)
    print('Generator is loaded!')
    return net
예제 #20
0
def create_generator(opt):
    """Build (or load) a pair of generators ``(generator_a, generator_b)``.

    If ``opt.pre_train == True``, two ``network.Generator`` instances are
    constructed and both are passed to ``network.weights_init``. Otherwise two
    whole pickled objects are read with ``torch.load`` from
    ``opt.load_name + '_a.pth'`` and ``opt.load_name + '_b.pth'``.
    ``torch.load`` unpickles, so only load trusted checkpoints.

    Args:
        opt: options object; see above for the attributes it reads.

    Returns:
        A 2-tuple ``(generator_a, generator_b)``.
    """
    # Equality (not truthiness) is kept deliberately to preserve the exact
    # semantics for non-bool flag values.
    if opt.pre_train == True:  # noqa: E712
        gen_a = network.Generator(opt)
        gen_b = network.Generator(opt)
        for gen in (gen_a, gen_b):
            network.weights_init(gen, init_type=opt.init_type, init_gain=opt.init_gain)
        print('Generator is created!')
        return gen_a, gen_b
    gen_a = torch.load(opt.load_name + '_a.pth')
    gen_b = torch.load(opt.load_name + '_b.pth')
    print('Generator is loaded!')
    return gen_a, gen_b
예제 #21
0
def create_generator(opt):
    """Construct a ``network.GrayInpaintingNet`` and load or initialize it.

    Args:
        opt: options object passed to the constructor. If
            ``opt.finetune_path`` is truthy, that checkpoint is read with
            ``torch.load`` and applied via ``load_dict`` (whose result is
            returned). Otherwise ``opt.init_type`` / ``opt.init_gain`` are
            forwarded to ``network.weights_init``.

    Returns:
        The generator network.
    """
    inpainter = network.GrayInpaintingNet(opt)
    print('Generator is created!')
    if not opt.finetune_path:
        scheme, gain = opt.init_type, opt.init_gain
        network.weights_init(inpainter, init_type=scheme, init_gain=gain)
        print('Initialize generator with %s type' % opt.init_type)
        return inpainter
    checkpoint = torch.load(opt.finetune_path)
    inpainter = load_dict(inpainter, checkpoint)
    print('Load generator with %s' % opt.finetune_path)
    return inpainter
예제 #22
0
def create_generator(opt):
    """Construct a ``network.KPN`` from option fields and load or initialize it.

    The constructor receives, positionally: ``color``, ``burst_length``,
    ``blind_est``, ``kernel_size``, ``sep_conv``, ``channel_att``,
    ``spatial_att``, ``upMode`` and ``core_bias``, all read from ``opt``.

    Args:
        opt: options object. When ``opt.load_name == ''`` the network is
            passed to ``network.weights_init`` with ``opt.init_type`` /
            ``opt.init_gain``. Otherwise ``opt.load_name`` is read with
            ``torch.load`` and applied via ``load_dict``.

    Returns:
        The ``KPN`` instance.
    """
    kpn_args = (opt.color, opt.burst_length, opt.blind_est, opt.kernel_size,
                opt.sep_conv, opt.channel_att, opt.spatial_att, opt.upMode,
                opt.core_bias)
    kpn = network.KPN(*kpn_args)
    if opt.load_name == '':
        network.weights_init(kpn, init_type=opt.init_type, init_gain=opt.init_gain)
        print('Generator is created!')
        return kpn
    # Load a pre-trained network.
    state = torch.load(opt.load_name)
    load_dict(kpn, state)
    print('Generator is loaded!')
    return kpn
예제 #23
0
def create_generator(opt):
    """Construct a ``network.Net`` and either load a checkpoint or initialize it.

    The architecture hyper-parameters ``d=32, s=5, m=1`` are fixed. Only
    ``num_channels`` and ``scale_factor`` come from ``opt``.

    When ``opt.load_pre_train`` is truthy, ``opt.load_name`` is read once with
    ``torch.load``. A ``.pkl`` file is applied with ``load_state_dict``, which
    is strict by default in PyTorch. Any other file is applied via the project
    helper ``load_dict``, presumably a more tolerant merge — confirm. Otherwise
    ``opt.init_type`` / ``opt.init_gain`` are forwarded to
    ``network.weights_init``.

    Args:
        opt: options object; see above for the attributes it reads.

    Returns:
        The ``Net`` instance.
    """
    generator = network.Net(num_channels=opt.num_channels,
                            scale_factor=opt.scale_factor, d=32, s=5, m=1)
    if opt.load_pre_train:
        # Match the file extension, not a substring. A '.pkl' elsewhere in the
        # path (e.g. a directory name) must not select the strict loader.
        is_state_dict_file = opt.load_name.endswith('.pkl')
        pretrained_net = torch.load(opt.load_name)
        if is_state_dict_file:
            generator.load_state_dict(pretrained_net)
        else:
            load_dict(generator, pretrained_net)
        print('Generator is loaded!')
    else:
        network.weights_init(generator, init_type=opt.init_type, init_gain=opt.init_gain)
        print('Generator is created!')
    return generator