def __init__(
            self,
            encoder_type,
            shape_decoder_type,
            texture_decoder_type,
            discriminator_type,
            vertex_scaling,
            texture_scaling,
            silhouette_loss_levels,
            lambda_silhouettes,
            lambda_textures,
            lambda_perceptual,
            lambda_inflation,
            lambda_discriminator,
            lambda_discriminator2,
            lambda_graph_laplacian,
            single_view_training,
            class_conditional,
            no_texture,
            num_views,
            dim_hidden=512,
            anti_aliasing=False,
    ):
        """Configure the ShapeNet model: loss weights, renderer and sub-networks.

        Args:
            encoder_type: Name passed to ``encoders.get_encoder``.
            shape_decoder_type: Name passed to ``decoders.get_shape_decoder``.
            texture_decoder_type: Name passed to
                ``decoders.get_texture_decoder``; replaced by ``'dummy'`` when
                ``no_texture`` is true.
            discriminator_type: Name passed to
                ``discriminators.get_discriminator``.
            vertex_scaling: Forwarded to the shape decoder factory.
            texture_scaling: Forwarded to the texture decoder factory.
            silhouette_loss_levels: Stored as ``self.silhouette_loss_levels``.
            lambda_silhouettes, lambda_textures, lambda_perceptual,
            lambda_inflation, lambda_discriminator, lambda_discriminator2,
            lambda_graph_laplacian: Loss weights, stored under the same
                attribute names.
            single_view_training, class_conditional, num_views: Stored as
                attributes of the same name.
            no_texture: If true, use the ``'dummy'`` texture decoder and a
                discriminator input dimension of 1 (4 otherwise).
            dim_hidden: Stored as ``self.dim_hidden`` and passed to the encoder
                and both decoder factories.
            anti_aliasing: Assigned to ``renderer.anti_aliasing``.
        """
        super(ShapeNetModel, self).__init__()
        self.trainer = None
        self.dataset_name = 'shapenet'

        # model size
        self.dim_hidden = dim_hidden

        # loss weights
        self.silhouette_loss_levels = silhouette_loss_levels
        self.lambda_silhouettes = lambda_silhouettes
        self.lambda_textures = lambda_textures
        self.lambda_perceptual = lambda_perceptual
        self.lambda_discriminator = lambda_discriminator
        self.lambda_discriminator2 = lambda_discriminator2
        self.lambda_inflation = lambda_inflation
        self.lambda_graph_laplacian = lambda_graph_laplacian

        # others
        self.single_view_training = single_view_training
        self.class_conditional = class_conditional
        self.no_texture = no_texture
        self.num_views = num_views
        self.use_depth = False

        # setup renderer: 224x224 output, perspective 'look_at' camera whose
        # viewing angle is degrees(arctan(16 / 60)).
        self.renderer = neural_renderer.Renderer()
        self.renderer.image_size = 224
        self.renderer.anti_aliasing = anti_aliasing
        self.renderer.perspective = True
        self.renderer.viewing_angle = self.xp.degrees(self.xp.arctan(16. / 60.))
        self.renderer.camera_mode = 'look_at'
        self.renderer.blur_size = 0

        with self.init_scope():
            # setup links
            dim_in_encoder = 3
            if no_texture:
                texture_decoder_type = 'dummy'
                # presumably silhouette-only input (1) vs. RGB + alpha (4)
                # -- TODO confirm against the discriminator implementation
                dim_in_discriminator = 1
            else:
                dim_in_discriminator = 4

            self.encoder = encoders.get_encoder(encoder_type, dim_in_encoder, self.dim_hidden)
            self.shape_decoder = decoders.get_shape_decoder(shape_decoder_type, self.dim_hidden, vertex_scaling)
            self.texture_decoder = decoders.get_texture_decoder(texture_decoder_type, self.dim_hidden, texture_scaling)
            self.discriminator = discriminators.get_discriminator(discriminator_type, dim_in_discriminator)

        # ``shape_encoder`` is a plain alias of ``encoder``. It is assigned
        # outside ``init_scope`` on purpose: registering the same link as a
        # second child would make ``params()`` yield the encoder's parameters
        # twice, so an optimizer iterating them would update the encoder twice
        # per step.
        self.shape_encoder = self.encoder
# Example #2
0
def main(argv=None):
    """Train a ``MyPMD`` model on image data.

    Behaviour depends on ``FLAGS.arch``:

    * ``'ae'``: train a ``ConvAE`` and encode the (rescaled) training images.
      Then fit ``MyPMD`` on those codes, passing the autoencoder along.
    * ``'adv'``: fit ``MyPMD`` on pixels with an extra ``F`` (discriminator)
      and ``D`` (decoder) network.
    * anything else: fit ``MyPMD`` directly on pixels.

    Outputs go to a ``results/...`` directory named after the flags, which is
    created if missing.

    Args:
        argv: Unused here (presumably accepted for a ``tf.app.run``-style
            entry point).
    """
    tf.set_random_seed(1237)
    np.random.seed(1237)

    # Load data
    x_train, sorted_x_train = \
            utils.load_image_data(FLAGS.dataset, n_xl, n_channels, FLAGS.mbs)
    xshape = (-1, n_xl, n_xl, n_channels)
    print('Data shape = {}'.format(x_train.shape))

    # Affine map x -> 2x - 1 (takes [0, 1] to [-1, 1]; the loader's output
    # range is assumed here, not checked).
    x_train = x_train * 2 - 1
    sorted_x_train = sorted_x_train * 2 - 1

    # Build the networks. ``is_training`` defaults to False and is overridden
    # through the feed dicts passed to ``model.train`` below.
    is_training = tf.placeholder_with_default(False,
                                              shape=[],
                                              name='is_training')
    generator = get_generator(FLAGS.dataset, FLAGS.arch,
                              n_code if FLAGS.arch == 'ae' else n_x, n_xl,
                              n_channels, n_z, ngf, is_training,
                              'transformation')
    if FLAGS.arch == 'adv':
        discriminator = get_discriminator(FLAGS.dataset, FLAGS.arch, n_x, n_xl,
                                          n_channels, n_f, ngf // 2,
                                          is_training)
        decoder = get_generator(FLAGS.dataset, FLAGS.arch, n_x, n_xl,
                                n_channels, n_f, ngf, is_training, 'decoder')

    # Output directory named after the run's hyper-parameters.
    run_name = 'results/{}_{}_{}_{}_c{}_mbs{}_bs{}_lr{}_t0{}'.format(
        FLAGS.dataset, FLAGS.arch, FLAGS.dist, FLAGS.match, n_code, FLAGS.mbs,
        FLAGS.bs, FLAGS.lr0, FLAGS.t0)

    if not os.path.exists(run_name):
        os.makedirs(run_name)

    # Build the computation graph
    if FLAGS.arch == 'ae':
        n_vars_before_ae = len(tf.global_variables())
        ae = ConvAE(x_train, (None, n_xl, n_xl, n_channels), ngf)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ae.train(sess)
            x_code = ae.encode(x_train, sess)
            sorted_x_code = ae.encode(sorted_x_train, sess)
            # Snapshot every variable the autoencoder created (the collection
            # is in creation order). The session below re-initializes all
            # variables, and these exact weights produced x_code.
            ae_vars = tf.global_variables()[n_vars_before_ae:]
            ae_values = sess.run(ae_vars)

        model = MyPMD(x_code, sorted_x_code, xshape, generator, run_name, ae)
    elif FLAGS.arch == 'adv':
        model = MyPMD(x_train,
                      sorted_x_train,
                      xshape,
                      generator,
                      run_name,
                      F=discriminator,
                      D=decoder)
    else:
        model = MyPMD(x_train, sorted_x_train, xshape, generator, run_name)

    # Fit the model
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.arch == 'ae':
            # The initializer above also reset the autoencoder. Restore the
            # weights trained in the first session instead of training again.
            # A second run need not reproduce them, and MyPMD was built on
            # codes produced by those exact weights.
            for var, value in zip(ae_vars, ae_values):
                var.load(value, sess)

        print('Training...')
        model.train(sess,
                    gen_dict={
                        model.batch_size_ph: FLAGS.mbs,
                        is_training: False
                    },
                    opt_dict={
                        model.batch_size_ph: FLAGS.bs,
                        is_training: True
                    },
                    # ceil(len(x_train) / FLAGS.mbs)
                    iters=((x_train.shape[0] - 1) // FLAGS.mbs) + 1)
    def __init__(self, source_dataset):
        """Set up bookkeeping, data loaders, networks and optimizers.

        Args:
            source_dataset: Key passed to ``get_params(...,
                experiment='adaptation')``, ``get_classifier`` and
                ``get_discriminator``. The source and target loaders are built
                from ``args.source_dataset`` / ``args.target_dataset``.
        """
        args = get_params(source_dataset, experiment='adaptation')
        self.args = args
        on_gpu = args.cuda

        # Best scores seen so far and the matching network snapshots.
        self.source_best_pred = 0.0
        self.target_best_pred = 0.0
        self.best_source_net_state = None
        self.best_target_net_state = None
        # Evaluation histories (distinct lists for each attribute).
        self.source_test_losses, self.target_test_losses = [], []
        self.source_test_acc, self.target_test_acc = [], []
        self.iters = 0

        # Fixed discriminator targets: zeros for source, ones for target,
        # one row per training-batch element.
        label_size = (args.batch_size_train, 1)
        source_labels = torch.zeros(size=label_size).requires_grad_(False)
        target_labels = torch.ones(size=label_size).requires_grad_(False)
        if on_gpu:
            source_labels, target_labels = source_labels.cuda(), target_labels.cuda()
        self.source_disc_labels = source_labels
        self.target_disc_labels = target_labels

        # Data loaders for both domains, sharing the same loader options.
        loader_options = dict(num_workers=8, pin_memory=True)
        self.source_train_loader, self.source_test_loader = get_dataloaders(
            args.source_dataset, **loader_options)
        self.target_train_loader, self.target_test_loader = get_dataloaders(
            args.target_dataset, **loader_options)
        # Batches per epoch are capped by the shorter training loader.
        self.n_batch = min(len(self.target_train_loader),
                           len(self.source_train_loader))

        # Classifier, wrapped in single-device DataParallel when on GPU.
        classifier = get_classifier(source_dataset)
        if on_gpu:
            classifier = torch.nn.DataParallel(classifier, device_ids=[0]).cuda()
        self.net = classifier
        # DataParallel exposes the wrapped model via ``.module``.
        self.encoder = classifier.module.encode if on_gpu else classifier.encode

        # Domain classifier = shared encoder + discriminator.
        self.discriminator = get_discriminator(source_dataset)
        domain_classifier = DomainClassifier(self.encoder, self.discriminator)
        if on_gpu:
            domain_classifier = torch.nn.DataParallel(
                domain_classifier, device_ids=[0]).cuda()
        self.domain_classifier = domain_classifier

        # SGD optimizers sharing the same learning rate and momentum.
        lr, momentum = args.learning_rate, args.momentum
        self.net_optimizer = torch.optim.SGD(
            self.net.parameters(), lr=lr, momentum=momentum)
        # NOTE(review): built over *all* of self.net's parameters, exactly like
        # net_optimizer, not only the encoder's -- confirm that is intended.
        self.encoder_optimizer = torch.optim.SGD(
            self.net.parameters(), lr=lr, momentum=momentum)
        self.discriminator_optimizer = torch.optim.SGD(
            self.discriminator.parameters(), lr=lr, momentum=momentum)
# Example #4
0
    def __init__(
        self,
        encoder_type,
        shape_decoder_type,
        texture_decoder_type,
        discriminator_type,
        silhouette_loss_type,
        vertex_scaling,
        texture_scaling,
        silhouette_loss_levels,
        lambda_silhouettes,
        lambda_perceptual,
        lambda_inflation,
        lambda_graph_laplacian,
        lambda_discriminator,
        no_texture,
        class_conditional,
        symmetric,
        dim_hidden=512,
        image_size=224,
        anti_aliasing=False,
    ):
        """Configure the PASCAL model: loss settings, renderer and sub-networks.

        Args:
            encoder_type: Name passed to ``encoders.get_encoder``.
            shape_decoder_type: Name passed to ``decoders.get_shape_decoder``.
            texture_decoder_type: Name passed to
                ``decoders.get_texture_decoder``; replaced by ``'dummy'`` when
                ``no_texture`` is true.
            discriminator_type: Name passed to
                ``discriminators.get_discriminator``.
            silhouette_loss_type: Stored as ``self.silhouette_loss_type``.
            vertex_scaling: Forwarded to the shape decoder factory.
            texture_scaling: Forwarded to the texture decoder factory.
            silhouette_loss_levels: Stored as ``self.silhouette_loss_levels``.
            lambda_silhouettes, lambda_perceptual, lambda_inflation,
            lambda_graph_laplacian, lambda_discriminator: Loss weights, stored
                under the same attribute names.
            no_texture: If true, use the ``'dummy'`` texture decoder and a
                discriminator input dimension of 1 (4 otherwise).
            class_conditional: Stored as ``self.class_conditional``.
            symmetric: Stored as ``self.symmetric`` and forwarded to both
                decoder factories.
            dim_hidden: Stored as ``self.dim_hidden`` and passed to the encoder
                and both decoder factories.
            image_size: Assigned to ``renderer.image_size``.
            anti_aliasing: Assigned to ``renderer.anti_aliasing``.
        """
        super(PascalModel, self).__init__()
        self.trainer = None
        self.dataset_name = 'pascal'

        # model size
        self.dim_hidden = dim_hidden

        # loss type
        self.silhouette_loss_type = silhouette_loss_type
        self.silhouette_loss_levels = silhouette_loss_levels

        # loss weights
        self.lambda_silhouettes = lambda_silhouettes
        self.lambda_perceptual = lambda_perceptual
        self.lambda_discriminator = lambda_discriminator
        self.lambda_inflation = lambda_inflation
        self.lambda_graph_laplacian = lambda_graph_laplacian

        # others
        self.no_texture = no_texture
        self.class_conditional = class_conditional
        self.symmetric = symmetric
        self.use_depth = False

        # setup renderer: perspective disabled and camera_mode 'none'
        # (presumably the camera transform is applied elsewhere -- TODO confirm)
        self.renderer = neural_renderer.Renderer()
        self.renderer.image_size = image_size
        self.renderer.anti_aliasing = anti_aliasing
        self.renderer.perspective = False
        self.renderer.camera_mode = 'none'

        with self.init_scope():
            # setup links
            dim_in_encoder = 3
            if no_texture:
                texture_decoder_type = 'dummy'
                # presumably silhouette-only input (1) vs. RGB + alpha (4)
                # -- TODO confirm against the discriminator implementation
                dim_in_discriminator = 1
            else:
                dim_in_discriminator = 4

            self.encoder = encoders.get_encoder(encoder_type, dim_in_encoder,
                                                self.dim_hidden)
            # Both decoders receive the same symmetry flag (the shape decoder
            # via the argument, the texture decoder via self.symmetric).
            self.shape_decoder = decoders.get_shape_decoder(
                shape_decoder_type, self.dim_hidden, vertex_scaling, symmetric)
            self.texture_decoder = decoders.get_texture_decoder(
                texture_decoder_type, self.dim_hidden, texture_scaling,
                self.symmetric)
            self.discriminator = discriminators.get_discriminator(
                discriminator_type, dim_in_discriminator)