    def __init__(self, opt):
        assert opt.isTrain
        opt = copy.deepcopy(opt)
        if len(opt.gpu_ids) > 0:
            opt.gpu_ids = opt.gpu_ids[:1]
        self.gpu_ids = opt.gpu_ids
        super(SPADEModelModules, self).__init__()
        self.opt = opt
        self.model_names = ['G_student', 'G_teacher', 'D']

        teacher_opt = self.create_option('teacher')
        self.netG_teacher = networks.define_G(opt.teacher_netG,
                                              gpu_ids=self.gpu_ids,
                                              opt=teacher_opt)
        student_opt = self.create_option('student')
        self.netG_student = networks.define_G(opt.student_netG,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              gpu_ids=self.gpu_ids,
                                              opt=student_opt)
        if hasattr(opt, 'distiller'):
            pretrained_opt = self.create_option('pretrained')
            self.netG_pretrained = networks.define_G(opt.pretrained_netG,
                                                     gpu_ids=self.gpu_ids,
                                                     opt=pretrained_opt)
        self.netD = networks.define_D(opt.netD,
                                      init_type=opt.init_type,
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
        self.mapping_layers = ['head_0', 'G_middle_1', 'up_1']
        self.netAs = nn.ModuleList()
        for i, mapping_layer in enumerate(self.mapping_layers):
            if mapping_layer != 'up_1':
                fs, ft = opt.student_ngf * 16, opt.teacher_ngf * 16
            else:
                fs, ft = opt.student_ngf * 4, opt.teacher_ngf * 4
            if hasattr(opt, 'distiller'):
                netA = nn.Conv2d(in_channels=fs,
                                 out_channels=ft,
                                 kernel_size=1)
            else:
                netA = SuperConv2d(in_channels=fs,
                                   out_channels=ft,
                                   kernel_size=1)
            networks.init_net(netA, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netAs.append(netA)
        self.criterionGAN = GANLoss(opt.gan_mode)
        self.criterionFeat = nn.L1Loss()
        self.criterionVGG = VGGLoss()
        self.optimizers = []
        self.netG_teacher.eval()
        self.config = None
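
A note on the netAs built above: each 1x1 convolution adapts a student feature map with fs channels to the teacher's ft channels so the two activations can be compared directly. A minimal, self-contained sketch of that pattern (the channel counts, random activations, and the L1 criterion below are illustrative assumptions, not code from this class):

import torch
import torch.nn as nn

# Hedged sketch of channel adaptation for feature distillation; fs/ft and the
# random activations are illustrative assumptions.
fs, ft = 48, 64
netA = nn.Conv2d(in_channels=fs, out_channels=ft, kernel_size=1)
student_act = torch.randn(2, fs, 32, 32)   # assumed student activation
teacher_act = torch.randn(2, ft, 32, 32)   # assumed teacher activation
loss_distill = nn.L1Loss()(netA(student_act), teacher_act)
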
Example #2
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a 3D generator

    Parameters:
        input_nc (int) -- the number of channels in input images
        output_nc (int) -- the number of channels in output images
        ngf (int) -- the number of filters in the last conv layer
        netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str) -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator

    Our current 3D implementation provides only the U-Net generator [unet_256] (built as UnetGenerator3d);
    the other names below are kept for reference but currently raise NotImplementedError:
        U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
        Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).


    The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        raise NotImplementedError
        # net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        raise NotImplementedError
        # net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        raise NotImplementedError
        # net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator3d(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
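
For reference, a minimal usage sketch of this factory; only 'unet_256' is currently implemented, and the channel counts below are illustrative assumptions rather than project defaults:

# Hedged usage sketch (argument values are assumptions, not project defaults).
netG = define_G(input_nc=1, output_nc=1, ngf=64, netG='unet_256',
                norm='batch', use_dropout=False,
                init_type='normal', init_gain=0.02, gpu_ids=[])
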
Example #3
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator

    Our current 3D implementation provides three types of discriminators ([pixel] is not yet implemented and raises NotImplementedError):
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70×70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter <n_layers_D> (default=3, as used in [basic] (PatchGAN)).

        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by <init_net>. It uses LeakyReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator3d(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator3d(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':     # classify if each pixel is real or fake
        raise NotImplementedError
        # net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
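
A matching usage sketch for the discriminator factory (again, the argument values are illustrative assumptions):

# Hedged usage sketch: a PatchGAN-style 3D discriminator over concatenated
# input/output volumes (input_nc + output_nc channels); values are assumptions.
netD = define_D(input_nc=2, ndf=64, netD='basic',
                n_layers_D=3, norm='batch',
                init_type='normal', init_gain=0.02, gpu_ids=[])
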
Example #4
    def __init__(self, opt):
        assert opt.isTrain
        valid_netGs = [
            'munit', 'super_munit', 'super_mobile_munit',
            'super_mobile_munit2', 'super_mobile_munit3'
        ]
        assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
        super(BaseMunitDistiller, self).__init__(opt)
        self.loss_names = [
            'G_gan', 'G_rec_x', 'G_rec_c', 'D_fake', 'D_real'
        ]
        if not opt.student_no_style_encoder:
            self.loss_names.append('G_rec_s')
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        opt_teacher = self.create_option('teacher')
        self.netG_teacher = networks.define_G(opt.teacher_netG,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              gpu_ids=self.gpu_ids,
                                              opt=opt_teacher)
        opt_student = self.create_option('student')
        self.netG_student = networks.define_G(opt.student_netG,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              gpu_ids=self.gpu_ids,
                                              opt=opt_student)
        self.netD = networks.define_D(opt.netD,
                                      input_nc=opt.output_nc,
                                      init_type='normal',
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
        if hasattr(opt, 'distiller'):
            self.netA = nn.Conv2d(in_channels=4 * opt.student_ngf,
                                  out_channels=4 * opt.teacher_ngf,
                                  kernel_size=1).to(self.device)
        else:
            self.netA = SuperConv2d(in_channels=4 * opt.student_ngf,
                                    out_channels=4 * opt.teacher_ngf,
                                    kernel_size=1).to(self.device)
        networks.init_net(self.netA)
        self.netG_teacher.eval()

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionRec = torch.nn.L1Loss()

        G_params = []
        G_params.append(self.netG_student.parameters())
        G_params.append(self.netA.parameters())
        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999),
                                            weight_decay=opt.weight_decay)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999),
                                            weight_decay=opt.weight_decay)
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt,
                                                      direction=opt.direction)
        self.inception_model, _, _ = create_metric_models(opt,
                                                          device=self.device)
        self.npz = np.load(opt.real_stat_path)
        self.is_best = False
Example #5
print('device: {}'.format(device))

# initial mesh
mesh = Mesh(opts.initial_mesh, device=device, hold_history=True)

# input point cloud
input_xyz, input_normals = utils.read_pts(opts.input_pc)
# normalize point cloud based on initial mesh
input_xyz /= mesh.scale
input_xyz += mesh.translations[None, :]
input_xyz = torch.Tensor(input_xyz).type(options.dtype()).to(device)[None, :, :]
input_normals = torch.Tensor(input_normals).type(options.dtype()).to(device)[None, :, :]

part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
print(f'number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)

for i in range(opts.iterations):
    num_samples = options.get_num_samples(i % opts.upsamp)
    if opts.global_step:
        optimizer.zero_grad()
    start_time = time.time()
    for part_i, est_verts in enumerate(net(rand_verts, part_mesh)):
        if not opts.global_step:
            optimizer.zero_grad()
        part_mesh.update_verts(est_verts[0], part_i)
        num_samples = options.get_num_samples(i % opts.upsamp)
        recon_xyz, recon_normals = sample_surface(part_mesh.main_mesh.faces, part_mesh.main_mesh.vs.unsqueeze(0), num_samples)
        # calc chamfer loss w/ normals
        recon_xyz, recon_normals = recon_xyz.type(options.dtype()), recon_normals.type(options.dtype())
        xyz_chamfer_loss, normals_chamfer_loss = chamfer_distance(recon_xyz, input_xyz, x_normals=recon_normals, y_normals=input_normals,
                                                                   unoriented=opts.unoriented)
Example #6
    def __init__(self, opt):
        assert opt.isTrain
        super(BaseResnetDistiller, self).__init__(opt)
        self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        self.netG_teacher = networks.define_G(opt.input_nc, opt.output_nc, opt.teacher_ngf,
                                              opt.teacher_netG, opt.norm, opt.teacher_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_student = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                              opt.student_netG, opt.norm, opt.student_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
            self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                                      opt.student_netG.replace('super_', ''), opt.norm,
                                                      opt.student_dropout_rate, opt.init_type, opt.init_gain,
                                                      self.gpu_ids, opt=opt)
        if hasattr(opt, 'distiller'):
            self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
                                                     opt.pretrained_netG, opt.norm, 0,
                                                     opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        if opt.dataset_mode == 'aligned':
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        elif opt.dataset_mode == 'unaligned':
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!!!' % opt.dataset_mode)

        self.netG_teacher.eval()
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss(self.device)
        else:
            raise NotImplementedError('Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)

        if isinstance(self.netG_teacher, nn.DataParallel):
            self.mapping_layers = ['module.model.%d' % i for i in range(9, 21, 3)]
        else:
            self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]

        self.netAs = []
        self.Tacts, self.Sacts = {}, {}

        G_params = [self.netG_student.parameters()]
        for i, n in enumerate(self.mapping_layers):
            ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
            if hasattr(opt, 'distiller'):
                netA = nn.Conv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1). \
                    to(self.device)
            else:
                netA = SuperConv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1). \
                    to(self.device)
            networks.init_net(netA)
            G_params.append(netA.parameters())
            self.netAs.append(netA)
            self.loss_names.append('G_distill%d' % i)

        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        self.npz = np.load(opt.real_stat_path)
        self.is_best = False
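
The mapping_layers, Tacts and Sacts attributes above are typically filled by forward hooks registered on the named submodules, so that the netAs adapters can compare teacher and student activations during distillation. A minimal hook sketch under that assumption (not this class's actual hook-registration code):

# Hedged sketch: store activations of selected submodules in a dict via hooks.
def add_activation_hooks(net, layer_names, store):
    for name, module in net.named_modules():
        if name in layer_names:
            module.register_forward_hook(
                lambda m, inp, out, key=name: store.__setitem__(key, out))

# Assumed usage:
#   add_activation_hooks(self.netG_teacher, self.mapping_layers, self.Tacts)
#   add_activation_hooks(self.netG_student, self.mapping_layers, self.Sacts)
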
Example #7
    def __init__(self, opts, input_dim, output_dim, lambda_ms=None):
        super(BicycleGANAdaIN, self).__init__()
        self.isTrain = (opts.phase == 'train')
        self.gpu_ids = opts.gpu_ids
        self.device = torch.device('cuda:{}'.format(
            self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.nz = opts.nz

        # generator
        self.netG = networks.init_net(Gen(input_dim,
                                          output_dim,
                                          style_dim=opts.nz),
                                      init_type='xavier',
                                      init_gain=0.02,
                                      gpu_ids=self.gpu_ids)

        # discriminator
        if self.isTrain:
            self.netD = networks.define_D(output_dim,
                                          64,
                                          netD='basic_256_multi',
                                          norm='instance',
                                          num_Ds=2,
                                          gpu_ids=self.gpu_ids)
            self.netD2 = networks.define_D(output_dim,
                                           64,
                                           netD='basic_256_multi',
                                           norm='instance',
                                           num_Ds=2,
                                           gpu_ids=self.gpu_ids)

        # encoder
        self.netE = networks.define_E(output_dim,
                                      opts.nz,
                                      64,
                                      netE=opts.bicycleE,
                                      norm='instance',
                                      vaeLike=True,
                                      gpu_ids=self.gpu_ids)

        # loss and optimizer and scheduler
        if self.isTrain:
            self.criterionGAN = networks.GANLoss(gan_mode=opts.gan_mode).to(
                self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionZ = torch.nn.L1Loss()
            self.lambda_ms = 0. if lambda_ms is None else lambda_ms
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opts.lr,
                                                betas=(0.5, 0.999))
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(),
                                                lr=opts.lr,
                                                betas=(0.5, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opts.lr,
                                                betas=(0.5, 0.999))
            self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(),
                                                 lr=opts.lr,
                                                 betas=(0.5, 0.999))
            self.optimizers = [
                self.optimizer_G, self.optimizer_E, self.optimizer_D,
                self.optimizer_D2
            ]
Example #8
    def __init__(self, opts):

        BaseModel.__init__(self, opts)

        lr = self.opt.lr

        self.model_names = ['G_A', 'G_B']
        self.loss_names = [
            'd_total', 'g_total', 'g_rec_x_a', 'g_rec_x_b', 'g_rec_s_a',
            'g_rec_s_b', 'g_rec_c_a', 'g_rec_c_b', 'g_adv_a', 'g_adv_b'
        ]
        self.visual_names = []
        # Initiate the networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
            self.netD_A = init_net(MsImageDis(self.opt.input_dim_a,
                                              self.opt.dis),
                                   init_type=self.opt.init,
                                   gpu_ids=self.gpu_ids)
            self.netD_B = init_net(MsImageDis(self.opt.input_dim_b,
                                              self.opt.dis),
                                   init_type=self.opt.init,
                                   gpu_ids=self.gpu_ids)
        self.netG_A = init_net(AdaINGen(self.opt.input_dim_a, self.opt.gen),
                               init_type=self.opt.init,
                               gpu_ids=self.gpu_ids)
        self.netG_B = init_net(AdaINGen(self.opt.input_dim_b, self.opt.gen),
                               init_type=self.opt.init,
                               gpu_ids=self.gpu_ids)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = self.opt.gen['style_dim']

        # fix the noise used in sampling
        display_size = self.opt.display_size
        self.s_a_fixed = torch.randn(display_size, self.style_dim, 1,
                                     1).to(self.device)
        self.s_b_fixed = torch.randn(display_size, self.style_dim, 1,
                                     1).to(self.device)

        if self.isTrain:
            # Setup the optimizers

            d_params = list(self.netD_A.parameters()) + list(
                self.netD_B.parameters())
            g_params = list(self.netG_A.parameters()) + list(
                self.netG_B.parameters())

            self.optimizer_D = torch.optim.Adam(
                [p for p in d_params if p.requires_grad],
                lr=lr,
                betas=(self.opt.beta1, self.opt.beta2),
                weight_decay=self.opt.weight_decay)
            self.optimizer_G = torch.optim.Adam(
                [p for p in g_params if p.requires_grad],
                lr=lr,
                betas=(self.opt.beta1, self.opt.beta2),
                weight_decay=self.opt.weight_decay)
            self.optimizer_names = ['optimizer_D', 'optimizer_G']
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_G)

            self.scheduler_D = get_scheduler(self.optimizer_D, self.opt)
            self.scheduler_G = get_scheduler(self.optimizer_G, self.opt)
            self.schedulers = [self.scheduler_D, self.scheduler_G]

        # Load VGG model if needed
        if (self.opt.vgg_w is not None) and self.opt.vgg_w > 0:
            self.vgg = load_vgg16(self.opt.vgg_model_path + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Example #9
def main():
    print(
        '############################### train.py ###############################'
    )

    # Set random seed for reproducibility
    manual_seed = 999
    # manual_seed = random.randint(1, 10000) # use if you want new results
    print("Random Seed: ", manual_seed)
    print()
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    # Hyper parameters
    workers = 2
    batch_size = 128
    image_size = 128
    nc = 3
    in_ngc = 3
    out_ngc = 3
    in_ndc = in_ngc + out_ngc
    out_ndc = 1
    ngf = 64
    ndf = 32
    sf = 100  # style factor for generator
    learning_rate = 0.0005
    beta1 = 0.5
    epochs = 100
    gpu = True
    load_saved_model = False

    # print hyper parameters
    print(f'number of workers : {workers}')
    print(f'batch size : {batch_size}')
    print(f'image size : {image_size}')
    print(f'number of channels : {nc}')
    print(f'generator feature map size : {ngf}')
    print(f'discriminator feature map size : {ndf}')
    print(f'style factor : {sf}')
    print(f'learning rate : {learning_rate}')
    print(f'beta1 : {beta1}')
    print(f'epochs: {epochs}')
    print(f'GPU: {gpu}')
    print(f'load saved model: {load_saved_model}')
    print()

    # set up GPU device
    cuda = True if gpu and torch.cuda.is_available() else False

    # load CelebA dataset
    download_path = '/home/pbuddare/EEE_598/data/CelebA'
    # download_path = '/Users/prasanth/Academics/ASU/FALL_2019/EEE_598_CIU/data/Project/CelebA'
    data_loader_src = prepare_celeba_data(download_path, batch_size,
                                          image_size, workers)

    # load respective cartoon dataset
    download_path = '/home/pbuddare/EEE_598/data/Cartoon'
    # download_path = '/Users/prasanth/Academics/ASU/FALL_2019/EEE_598_CIU/data/Project/Cartoon'
    data_loader_tgt = prepare_cartoon_data(download_path, batch_size,
                                           image_size, workers)

    # show sample images
    show_images(
        next(iter(data_loader_src))[0], (8, 8), 16,
        'Training images (Natural)', 'human_real')
    show_images(
        next(iter(data_loader_tgt))[0], (8, 8), 16,
        'Training images (Cartoon)', 'cartoon_real')

    # create generator and discriminator networks
    generator = Generator(in_ngc, out_ngc, ngf)
    discriminator = Discriminator(in_ndc, ndf)
    if cuda:
        generator.cuda()
        discriminator.cuda()
    init_net(generator, gpu_ids=[0])
    init_net(discriminator, gpu_ids=[0])

    # loss function and optimizers
    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()
    optimizer_g = optim.Adam(generator.parameters(),
                             lr=learning_rate,
                             betas=(beta1, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=learning_rate,
                             betas=(beta1, 0.999))

    # Train GAN
    loss_G, loss_D = train_gan(data_loader_src, data_loader_tgt, generator,
                               discriminator, criterion_GAN, criterion_L1,
                               optimizer_d, optimizer_g, sf, batch_size,
                               epochs, cuda)

    # save parameters
    current_time = str(datetime.datetime.now().time()).replace(
        ':', '').replace('.', '') + '.pth'
    g_path = './project_G_' + current_time
    d_path = './project_D_' + current_time
    torch.save(generator.state_dict(), g_path)
    torch.save(discriminator.state_dict(), d_path)

    # generate and display fake images
    test_imgs = next(iter(data_loader_src))[0]
    show_images(test_imgs, (8, 8), 16, 'Testing images (Natural)',
                'human_real_test')
    test_imgs = test_imgs.cuda() if cuda else test_imgs
    fake_imgs = generator(test_imgs).detach()
    show_images(fake_imgs.cpu(), (8, 8), 16, 'Fake images (Cartoon)',
                'cartoon_fake')
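
train_gan itself is not shown on this page. A heavily simplified, single-step sketch consistent with the setup above (BCEWithLogitsLoss adversarial term, L1 term weighted by the style factor sf, discriminator fed the concatenated source/target pair); everything below is an assumption about the project's training code, not its actual implementation:

import torch

# Hedged sketch (assumption, not the project's train_gan): one pix2pix-style
# update of the discriminator and the generator.
def train_step(real_src, real_tgt, G, D, crit_gan, crit_l1, opt_d, opt_g, sf):
    fake_tgt = G(real_src)
    # Discriminator: real pairs -> 1, fake pairs -> 0.
    opt_d.zero_grad()
    d_real = D(torch.cat([real_src, real_tgt], dim=1))
    d_fake = D(torch.cat([real_src, fake_tgt.detach()], dim=1))
    loss_d = crit_gan(d_real, torch.ones_like(d_real)) + \
             crit_gan(d_fake, torch.zeros_like(d_fake))
    loss_d.backward()
    opt_d.step()
    # Generator: fool the discriminator and stay close to the target in L1.
    opt_g.zero_grad()
    d_fake = D(torch.cat([real_src, fake_tgt], dim=1))
    loss_g = crit_gan(d_fake, torch.ones_like(d_fake)) + sf * crit_l1(fake_tgt, real_tgt)
    loss_g.backward()
    opt_g.step()
    return loss_g.item(), loss_d.item()
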
Example #10
File: trainer.py Project: deJQK/CAT
    def start(self):
        opt = self.opt
        dataloader = self.dataloader
        model = self.model
        modules_on_one_gpu = getattr(model, 'modules_on_one_gpu', model)
        logger = self.logger

        if self.task == 'distill':
            shrink(model, opt)
            modules_on_one_gpu.netG_student = init_net(
                modules_on_one_gpu.netG_student, opt.init_type, opt.init_gain,
                []).to(model.device)
            if getattr(opt, 'prune_continue', False):
                model.load_networks(restore_pretrain=False)
                logger.print_info('All networks loaded.')
            model.print_networks()
            if 'spade' in self.opt.distiller:
                logger.print_info(
                    f'netG student FLOPs: {mc.unwrap_model(modules_on_one_gpu.netG_student).n_macs}.'
                )
            else:
                logger.print_info(
                    f'netG student FLOPs: {mc.unwrap_model(modules_on_one_gpu.netG_student).n_macs}; down sampling: {mc.unwrap_model(modules_on_one_gpu.netG_student).down_sampling.n_macs}; features: {mc.unwrap_model(modules_on_one_gpu.netG_student).features.n_macs}; up sampling: {mc.unwrap_model(modules_on_one_gpu.netG_student).up_sampling.n_macs}.'
                )
            if getattr(opt, 'prune_only', False):
                return

        start_epoch = opt.epoch_base
        end_epoch = opt.epoch_base + opt.nepochs + opt.nepochs_decay - 1
        total_iter = opt.iter_base
        for epoch in range(start_epoch, end_epoch + 1):
            epoch_start_time = time.time()
            for i, data_i in enumerate(dataloader):
                iter_start_time = time.time()
                model.set_input(data_i)
                model.optimize_parameters(total_iter)

                if total_iter % opt.print_freq == 0:
                    losses = model.get_current_losses()
                    logger.print_current_errors(epoch, total_iter, losses,
                                                time.time() - iter_start_time)
                    logger.plot(losses, total_iter)

                if total_iter % opt.save_latest_freq == 0 or total_iter == opt.iter_base:
                    self.evaluate(
                        epoch, total_iter,
                        'Saving the latest model (epoch %d, total_steps %d)' %
                        (epoch, total_iter))
                    if getattr(model, 'is_best', False):
                        model.save_networks('iter%d' % total_iter)
                        model.save_networks('best')
                    if getattr(model, 'is_best_A', False):
                        model.save_networks('iter%d' % total_iter)
                        model.save_networks('best_A')
                    if getattr(model, 'is_best_B', False):
                        model.save_networks('iter%d' % total_iter)
                        model.save_networks('best_B')

                total_iter += 1
            logger.print_info(
                'End of epoch %d / %d \t Time Taken: %.2f sec' %
                (epoch, end_epoch, time.time() - epoch_start_time))
            if epoch % opt.save_epoch_freq == 0 or epoch == end_epoch:
                self.evaluate(
                    epoch, total_iter,
                    'Saving the model at the end of epoch %d, iters %d' %
                    (epoch, total_iter))
                model.save_networks(epoch)
                if getattr(model, 'is_best', False):
                    model.save_networks('iter%d' % total_iter)
                    model.save_networks('best')
                if getattr(model, 'is_best_A', False):
                    model.save_networks('iter%d' % total_iter)
                    model.save_networks('best_A')
                if getattr(model, 'is_best_B', False):
                    model.save_networks('iter%d' % total_iter)
                    model.save_networks('best_B')
            model.update_learning_rate(logger)
Example #11
    def run(self):
        # mesh = Mesh(opts.initial_mesh, device=device, hold_history=True)
        mesh = vtkMesh(self.initPoly, device=device, hold_history=True)

        # input point cloud
        input_xyz, input_normals = self.MakeInputData(self.targetPoly)
        # normalize point cloud based on initial mesh
        input_xyz /= mesh.scale
        input_xyz += mesh.translations[None, :]
        input_xyz = torch.Tensor(input_xyz).type(
            options.dtype()).to(device)[None, :, :]
        input_normals = torch.Tensor(input_normals).type(
            options.dtype()).to(device)[None, :, :]

        part_mesh = PartMesh(mesh,
                             num_parts=options.get_num_parts(len(mesh.faces)),
                             bfs_depth=opts.overlap)
        print(f'number of parts {part_mesh.n_submeshes}')
        net, optimizer, rand_verts, scheduler = init_net(
            mesh, part_mesh, device, opts)

        beamgap_loss = BeamGapLoss(device)

        if opts.beamgap_iterations > 0:
            print('beamgap on')
            beamgap_loss.update_pm(
                part_mesh, torch.cat([input_xyz, input_normals], dim=-1))

        for i in range(opts.iterations):
            num_samples = options.get_num_samples(i % opts.upsamp)
            if opts.global_step:
                optimizer.zero_grad()
            start_time = time.time()
            for part_i, est_verts in enumerate(net(rand_verts, part_mesh)):
                if not opts.global_step:
                    optimizer.zero_grad()
                part_mesh.update_verts(est_verts[0], part_i)
                num_samples = options.get_num_samples(i % opts.upsamp)
                recon_xyz, recon_normals = sample_surface(
                    part_mesh.main_mesh.faces,
                    part_mesh.main_mesh.vs.unsqueeze(0), num_samples)
                # calc chamfer loss w/ normals
                recon_xyz, recon_normals = recon_xyz.type(
                    options.dtype()), recon_normals.type(options.dtype())
                xyz_chamfer_loss, normals_chamfer_loss = chamfer_distance(
                    recon_xyz,
                    input_xyz,
                    x_normals=recon_normals,
                    y_normals=input_normals,
                    unoriented=opts.unoriented)

                if (i < opts.beamgap_iterations) and (i % opts.beamgap_modulo
                                                      == 0):
                    loss = beamgap_loss(part_mesh, part_i)
                else:
                    loss = (xyz_chamfer_loss +
                            (opts.ang_wt * normals_chamfer_loss))
                if opts.local_non_uniform > 0:
                    loss += opts.local_non_uniform * local_nonuniform_penalty(
                        part_mesh.main_mesh).float()
                loss.backward()
                if not opts.global_step:
                    optimizer.step()
                    scheduler.step()
                part_mesh.main_mesh.vs.detach_()
            if opts.global_step:
                optimizer.step()
                scheduler.step()
            end_time = time.time()

            if i % 1 == 0:
                print(
                    f'{os.path.basename(opts.input_pc)}; iter: {i} out of: {opts.iterations}; loss: {loss.item():.4f};'
                    f' sample count: {num_samples}; time: {end_time - start_time:.2f}'
                )

            # mesh.export(os.path.join("./", f'recon_iter_{i}.obj'))
            self.backwarded.emit(mesh.vs)

            # if (i > 0 and (i + 1) % opts.upsamp == 0):
            #     mesh = part_mesh.main_mesh
            #     num_faces = int(np.clip(len(mesh.faces) * 1.5, len(mesh.faces), opts.max_faces))

            #     if num_faces > len(mesh.faces) or opts.manifold_always:
            #         # up-sample mesh
            #         mesh = utils.manifold_upsample(mesh, opts.save_path, Mesh,
            #                                     num_faces=min(num_faces, opts.max_faces),
            #                                     res=opts.manifold_res, simplify=True)

            #         part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
            #         print(f'upsampled to {len(mesh.faces)} faces; number of parts {part_mesh.n_submeshes}')
            #         net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)
            #         if i < opts.beamgap_iterations:
            #             print('beamgap updated')
            #             beamgap_loss.update_pm(part_mesh, input_xyz)

        with torch.no_grad():
            mesh.export(os.path.join(opts.save_path, 'last_recon.obj'))
Example #12
    def __init__(self, opt):
        assert opt.isTrain
        valid_netGs = [
            'resnet_9blocks', 'mobile_resnet_9blocks',
            'super_mobile_resnet_9blocks', 'sub_mobile_resnet_9blocks'
        ]
        assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
        super(BaseResnetDistiller, self).__init__(opt)
        self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        self.netG_teacher = networks.define_G(
            opt.teacher_netG,
            input_nc=opt.input_nc,
            output_nc=opt.output_nc,
            ngf=opt.teacher_ngf,
            norm=opt.norm,
            dropout_rate=opt.teacher_dropout_rate,
            gpu_ids=self.gpu_ids,
            opt=opt)
        self.netG_student = networks.define_G(
            opt.student_netG,
            input_nc=opt.input_nc,
            output_nc=opt.output_nc,
            ngf=opt.student_ngf,
            norm=opt.norm,
            dropout_rate=opt.student_dropout_rate,
            init_type=opt.init_type,
            init_gain=opt.init_gain,
            gpu_ids=self.gpu_ids,
            opt=opt)
        if hasattr(opt, 'distiller'):
            self.netG_pretrained = networks.define_G(opt.pretrained_netG,
                                                     input_nc=opt.input_nc,
                                                     output_nc=opt.output_nc,
                                                     ngf=opt.pretrained_ngf,
                                                     norm=opt.norm,
                                                     gpu_ids=self.gpu_ids,
                                                     opt=opt)
        if opt.dataset_mode == 'aligned':
            self.netD = networks.define_D(opt.netD,
                                          input_nc=opt.input_nc +
                                          opt.output_nc,
                                          ndf=opt.ndf,
                                          n_layers_D=opt.n_layers_D,
                                          norm=opt.norm,
                                          init_type=opt.init_type,
                                          init_gain=opt.init_gain,
                                          gpu_ids=self.gpu_ids,
                                          opt=opt)
        elif opt.dataset_mode == 'unaligned':
            self.netD = networks.define_D(opt.netD,
                                          input_nc=opt.output_nc,
                                          ndf=opt.ndf,
                                          n_layers_D=opt.n_layers_D,
                                          norm=opt.norm,
                                          init_type=opt.init_type,
                                          init_gain=opt.init_gain,
                                          gpu_ids=self.gpu_ids,
                                          opt=opt)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!!!' %
                                      opt.dataset_mode)

        self.netG_teacher.eval()
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss(self.device)
        else:
            raise NotImplementedError(
                'Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)

        if isinstance(self.netG_teacher, nn.DataParallel):
            self.mapping_layers = [
                'module.model.%d' % i for i in range(9, 21, 3)
            ]
        else:
            self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]

        self.netAs = []
        self.Tacts, self.Sacts = {}, {}

        G_params = [self.netG_student.parameters()]
        for i, n in enumerate(self.mapping_layers):
            ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
            if hasattr(opt, 'distiller'):
                netA = nn.Conv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1). \
                    to(self.device)
            else:
                netA = SuperConv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1). \
                    to(self.device)
            networks.init_net(netA)
            G_params.append(netA.parameters())
            self.netAs.append(netA)
            self.loss_names.append('G_distill%d' % i)

        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt,
                                                      direction=opt.direction)
        self.inception_model, self.drn_model, _ = create_metric_models(
            opt, device=self.device)
        self.npz = np.load(opt.real_stat_path)
        self.is_best = False
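
The np.load(opt.real_stat_path) call above loads precomputed statistics of real images (typically the mean and covariance of Inception features) that are used together with inception_model for FID evaluation. A standard, self-contained sketch of that distance under this assumption (the key names in the .npz file are not confirmed by this page):

import numpy as np
from scipy import linalg

# Hedged sketch of the Frechet Inception Distance between two Gaussians fitted
# to real and generated Inception activations.
def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
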