Example #1
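    # What this snippet shows: MAML-style inner-loop adaptation of an MLP in
    # a TF1 static graph. The first self.mlp call builds the weights under
    # scope 'mlp'; later calls pass `params` explicitly (self.mlp is
    # presumably a functional layer that accepts fast weights), so each SGD
    # step below stays differentiable for the outer meta-objective. The loop
    # unrolls max(inner_iters, eval_iters) steps, records test predictions
    # after each, and returns the prediction after inner_iters steps.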
    def _model(self):
        default_args = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([self.mlp], **default_args):
            y_hat = self.mlp(self.X_c, scope='mlp')
            vars = get_trainable_variables(['mlp'])
            inner_iters = 1
            eval_iters = 10
            y_hat_test_arr = [
                self.mlp(self.X_t, scope='mlp-test-0', params=vars.copy())
            ]
            for k in range(1, max(inner_iters, eval_iters) + 1):
                loss = tf.losses.mean_squared_error(labels=self.y_c,
                                                    predictions=y_hat)
                grads = tf.gradients(loss,
                                     vars,
                                     colocate_gradients_with_ops=True)
                vars = [v - self.alpha * g for v, g in zip(vars, grads)]
                y_hat = self.mlp(self.X_c,
                                 scope='mlp-{0}'.format(k),
                                 params=vars.copy())
                y_hat_test = self.mlp(self.X_t,
                                      scope='mlp-test-{0}'.format(k),
                                      params=vars.copy())
                y_hat_test_arr.append(y_hat_test)
            self.eval_ops = y_hat_test_arr
            return y_hat_test_arr[inner_iters]
Example #2
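    # Same inner-loop scheme as Example #1, combined with a neural-process-
    # style latent path: sample_encoder and aggregator produce prior and
    # posterior parameters for z, z is reparameterization-sampled in train
    # mode (set to the posterior mean in eval mode, or overridden through
    # self.z_ph), and the conditional decoder is then adapted with unrolled
    # gradient steps on the context set (X_c, y_c).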
    def _model(self):
        default_args = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([self.conditional_decoder], **default_args):
            default_args.update({"bn": False})
            with arg_scope([self.sample_encoder, self.aggregator],
                           **default_args):
                num_c = tf.shape(self.X_c)[0]
                X_ct = tf.concat([self.X_c, self.X_t], axis=0)
                y_ct = tf.concat([self.y_c, self.y_t], axis=0)
                r_ct = self.sample_encoder(X_ct, y_ct, self.r_dim)

                (self.z_mu_pr, self.z_log_sigma_sq_pr, self.z_mu_pos,
                 self.z_log_sigma_sq_pos) = self.aggregator(
                     r_ct, num_c, self.z_dim)
                if self.user_mode == 'train':
                    z = gaussian_sampler(self.z_mu_pos,
                                         tf.exp(0.5 * self.z_log_sigma_sq_pos))
                elif self.user_mode == 'eval':
                    z = self.z_mu_pos
                else:
                    raise ValueError(
                        "unknown user_mode: {0}".format(self.user_mode))
                z = (1 - self.use_z_ph) * z + self.use_z_ph * self.z_ph

                # add maml ops
                y_hat = self.conditional_decoder(self.X_c, z)
                vars = get_trainable_variables(['conditional_decoder'])
                inner_iters = 1
                eval_iters = 10
                y_hat_test_arr = [
                    self.conditional_decoder(self.X_t, z, params=vars.copy())
                ]
                for k in range(1, max(inner_iters, eval_iters) + 1):
                    loss = sum_squared_error(labels=self.y_c,
                                             predictions=y_hat)
                    grads = tf.gradients(loss,
                                         vars,
                                         colocate_gradients_with_ops=True)
                    vars = [v - self.alpha * g for v, g in zip(vars, grads)]
                    y_hat = self.conditional_decoder(self.X_c,
                                                     z,
                                                     params=vars.copy())
                    y_hat_test = self.conditional_decoder(self.X_t,
                                                          z,
                                                          params=vars.copy())
                    y_hat_test_arr.append(y_hat_test)
                self.eval_ops = y_hat_test_arr
                return y_hat_test_arr[inner_iters]
Example #3
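    # Multi-GPU training setup: one replica per GPU shares variables through
    # tf.make_template, per-replica gradients are summed onto gpu:0, the
    # requested losses are averaged across replicas for monitoring, and a
    # single Adam update (adam_updates appears to be a project helper) is
    # applied to the shared parameters.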
    def construct_models(self,
                         model_cls,
                         model_opt,
                         learning_rate,
                         trainable_params=None,
                         eval_keys=('total loss',)):
        # models
        self.models = [model_cls(counters={}) for i in range(self.nr_gpu)]
        template = tf.make_template('model', model_cls.build_graph)
        for i in range(self.nr_gpu):
            with tf.device('/gpu:%d' % i):
                template(self.models[i], **model_opt)
        if trainable_params is None:
            self.params = tf.trainable_variables()
        else:
            self.params = get_trainable_variables(trainable_params)
        # gradients
        grads = []
        for i in range(self.nr_gpu):
            with tf.device('/gpu:%d' % i):
                grads.append(
                    tf.gradients(self.models[i].loss,
                                 self.params,
                                 colocate_gradients_with_ops=True))
        with tf.device('/gpu:0'):
            for i in range(1, self.nr_gpu):
                for j in range(len(grads[0])):
                    grads[0][j] += grads[i][j]

        mdict = {}
        if 'total loss' in eval_keys:
            mdict['total loss'] = tf.add_n(
                [model.loss for model in self.models]) / self.nr_gpu
        if 'nll loss' in eval_keys:
            mdict['nll loss'] = tf.add_n(
                [model.loss_nll for model in self.models]) / self.nr_gpu
        if 'reg loss' in eval_keys:
            mdict['reg loss'] = tf.add_n(
                [model.loss_reg for model in self.models]) / self.nr_gpu
        if 'bits per dim' in eval_keys:
            mdict['bits per dim'] = tf.add_n(
                [model.bits_per_dim for model in self.models]) / self.nr_gpu
        if 'mi' in eval_keys:
            mdict['mi'] = tf.add_n([model.mi
                                    for model in self.models]) / self.nr_gpu

        self.monitor = Monitor(dict=mdict,
                               config_str="",
                               log_file_path=self.save_dir + "/logfile")
        self.train_step = adam_updates(self.params, grads[0], lr=learning_rate)
        self.saver = tf.train.Saver()
Example #4
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
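# Driver script: open the session, then dispatch on args.mode
# (train / test / inpainting / traverse). In the 'ce' phase, pretrained
# PixelVAE sub-networks can be preloaded before training starts.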
with tf.Session(config=config) as sess:

    sess.run(initializer)
    learner.set_session(sess)
    if args.mode == 'train':
        kwargs = {
            "train_mgen": get_generator(args.mask_type, size=args.img_size), ###
            "sample_mgen": get_generator(args.mask_type, size=args.img_size), ###
            "max_num_epoch": 200,
            "save_interval": args.save_interval,
            "restore": args.load_params,
        }
        if args.phase == 'ce':
            if not args.one_stage:
                learner.preload(from_dir=args.pvae_dir, var_list=get_trainable_variables(["forward_pixel_cnn", "conv_encoder", "conv_decoder"]))
        learner.train(**kwargs)
    elif args.mode == 'test':
        learner.eval(which_set='test', mgen=get_generator('bottom half', size=args.img_size), generate_samples=True)
    elif args.mode == 'inpainting':
        layout = (10, 10)
        same_inputs = False
        use_mask_at = "{0}_{1}.npz".format(args.mask_type, args.data_set)
        learner.inpainting(get_generator(args.mask_type, size=args.img_size), layout=layout, same_inputs=same_inputs, use_mask_at=use_mask_at)
    elif args.mode == 'traverse':
        mids = [5, 6, 7, 8, 9, 10]  # [7, 8, 11, 15]
        # mask_descriptions = ['mouth', 'eye', 'nose']
        mask_descriptions = ['mnist top 20']
        for mid in mids:
            for mask in mask_descriptions:
                print("mid {0}, mask {1}".format(mid, mask))
Example #5
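# Build one model replica per GPU, select the trainable variables for the
# current phase ('ce' or 'pvae'), compute per-GPU gradients, and sum them
# onto gpu:0; the summed gradients in grads[0] presumably feed the optimizer
# elsewhere in the script.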
for i in range(args.nr_gpu):
    with tf.device('/gpu:%d' % i):
        model(pvaes[i],
              xs[i],
              x_bars[i],
              is_trainings[i],
              dropout_ps[i],
              masks=masks[i],
              input_masks=input_masks[i],
              random_indices=random_indices[i],
              **model_opt)

if args.mode == 'train':
    if args.phase == 'ce':
        all_params = get_trainable_variables(
            ["conv_pixel_cnn", "context_encoder"])
    elif args.phase == 'pvae':
        all_params = get_trainable_variables(
            ["conv_encoder", "conv_decoder", "conv_pixel_cnn"])
    grads = []
    for i in range(args.nr_gpu):
        with tf.device('/gpu:%d' % i):
            grads.append(
                tf.gradients(pvaes[i].loss,
                             all_params,
                             colocate_gradients_with_ops=True))
    with tf.device('/gpu:0'):
        for i in range(1, args.nr_gpu):
            for j in range(len(grads[0])):
                grads[0][j] += grads[i][j]
Example #6
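# Train/test driver similar to Example #4: in train mode, pretrained
# PixelCNN/encoder/decoder weights are preloaded from a dataset-specific
# checkpoint directory; in test mode, a short mask-type string selects the
# matching checkpoint.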
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:

    sess.run(initializer)
    learner.set_session(sess)
    if args.mode == 'train':
        kwargs = {
            "train_mgen": get_generator(args.mask_type, size=args.img_size), ###
            "sample_mgen": get_generator(args.mask_type, size=args.img_size), ###
            "max_num_epoch": 200,
            "save_interval": None,
            "restore": args.load_params,
        }
        if 'mnist' in args.data_set:
            learner.preload(from_dir="/data/ziz/jxu/save_dirs/model_mnist_controlled_pixel_vae_mmd_{0}_rec".format(args.phase), var_list=get_trainable_variables(["reverse_pixel_cnn", "forward_pixel_cnn", "conv_encoder", "conv_decoder"]))
        else:
            learner.preload(from_dir="/data/ziz/jxu/save_dirs/model_celeba_controlled_pixel_vae_mmd_{0}_rec".format(args.phase), var_list=get_trainable_variables(["reverse_pixel_cnn", "forward_pixel_cnn", "conv_encoder", "conv_decoder"]))
        learner.train(**kwargs)

    elif args.mode == 'test':
        if args.mask_type == 'random blobs':
            mtype = 'blob'
        elif args.mask_type == 'random rec':
            mtype = 'rec'
        elif args.mask_type == 'pepper':
            mtype = 'pepper'

        if 'mnist' in args.data_set:
            learner.preload(from_dir="/data/ziz/jxu/save_dirs/model_mnist_controlled_pixel_vae_mmd_{0}_{1}".format(args.phase, mtype), var_list=None)
        elif 'church' in args.data_set:
            pass  # (truncated in the original snippet)
Example #7
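    # Masked VAE with an importance-weighted (IWAE-style) objective: the
    # encoder sees the masked input (prior branch) and the full input
    # (posterior branch) in one batch, z is sampled with num_particles
    # particles per example, and a masked reconstruction loss yields
    # per-particle log-likelihoods. The gradient uses squared normalized
    # importance weights; note the prior/posterior correction is commented
    # out below, so log_weights reduces to the likelihood term.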
    def __model(self):
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large_mixture_logistic
                encoder_q = conv_encoder_32_q
            else:
                raise ValueError(
                    "unknown network type: {0}".format(self.network_type))
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
                encoder_q = conv_encoder_28_binary_q
            else:
                raise ValueError(
                    "unknown network type: {0}".format(self.network_type))
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([encoder, decoder, encoder_q], **kwargs):
            inputs = self.x

            self.num_particles = 16

            inputs = inputs * broadcast_masks_tf(
                self.masks, num_channels=self.num_channels)
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            inputs_pos = tf.concat(
                [self.x,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            inputs = tf.concat([inputs, inputs_pos], axis=0)

            z_mu, z_log_sigma_sq = encoder(inputs, self.z_dim)
            self.z_mu_pr, self.z_mu = (z_mu[:self.batch_size],
                                       z_mu[self.batch_size:])
            self.z_log_sigma_sq_pr, self.z_log_sigma_sq = (
                z_log_sigma_sq[:self.batch_size],
                z_log_sigma_sq[self.batch_size:])

            self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr

            x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
            masks = tf.tile(self.masks, [self.num_particles, 1, 1])
            self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
            self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
            self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                          [self.num_particles, 1])
            self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                             [self.num_particles, 1])
            sigma = tf.exp(self.z_log_sigma_sq / 2.)

            self.params = get_trainable_variables(["inference"])

            dist = tf.distributions.Normal(loc=0., scale=1.)
            epsilon = dist.sample(
                sample_shape=[self.batch_size * self.num_particles,
                              self.z_dim],
                seed=None)
            z = self.z_mu + tf.multiply(epsilon, sigma)

            if self.network_type == 'binary':
                self.pixel_params = decoder(z)
            else:
                self.pixel_params = decoder(
                    z, nr_logistic_mix=self.nr_logistic_mix)
            if self.network_type == 'binary':
                nll = bernoulli_loss(x,
                                     self.pixel_params,
                                     masks=masks,
                                     output_mean=False)
            else:
                nll = mix_logistic_loss(x,
                                        self.pixel_params,
                                        masks=masks,
                                        output_mean=False)

            log_prob_pos = dist.log_prob(epsilon)
            epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                self.z_log_sigma_sq_pr / 2.)
            log_prob_pr = dist.log_prob(epsilon_pr)
            # reshape the flat particle batch back to
            # [num_particles, batch_size, ...]
            log_prob_pr = tf.stack(
                [log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                 for i in range(self.num_particles)],
                axis=0)
            log_prob_pos = tf.stack(
                [log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                 for i in range(self.num_particles)],
                axis=0)
            log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
            log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
            nll = tf.stack(
                [nll[self.batch_size * i:self.batch_size * (i + 1)]
                 for i in range(self.num_particles)],
                axis=0)
            log_likelihood = -nll

            # log_weights = log_prob_pr + log_likelihood - log_prob_pos
            log_weights = log_likelihood
            log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
            log_avg_weight = log_sum_weight - tf.log(
                tf.to_float(self.num_particles))
            self.log_avg_weight = log_avg_weight

            normalized_weights = tf.stop_gradient(
                tf.nn.softmax(log_weights, axis=0))
            sq_normalized_weights = tf.square(normalized_weights)

            self.gradients = tf.gradients(
                -tf.reduce_sum(sq_normalized_weights * log_weights, axis=0),
                self.params,
                colocate_gradients_with_ops=True)
Example #8
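    # Variant of Example #7 in which the decoder emits conditioning features
    # for a forward PixelCNN and a reverse PixelCNN encodes the masked
    # target; batch norm is disabled inside the inner PixelCNN arg_scope.
    # The importance-weighting machinery at the end matches Example #7.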
    def __model(self):
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large
                encoder_q = conv_encoder_32_q
            else:
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
                forward_pixelcnn = forward_pixel_cnn_28_binary
                reverse_pixelcnn = reverse_pixel_cnn_28_binary
                encoder_q = conv_encoder_28_binary_q
            else:
                raise ValueError(
                    "unknown network type: {0}".format(self.network_type))
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder,
                        encoder_q], **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                self.num_particles = 16

                inp = self.x * broadcast_masks_tf(
                    self.input_masks, num_channels=self.num_channels)
                inp += tf.random_uniform(int_shape(inp), -1, 1) * (
                    1 - broadcast_masks_tf(self.input_masks,
                                           num_channels=self.num_channels))
                inp = tf.concat(
                    [inp, broadcast_masks_tf(self.input_masks,
                                             num_channels=1)],
                    axis=-1)

                inputs_pos = tf.concat(
                    [self.x,
                     broadcast_masks_tf(tf.ones_like(self.input_masks),
                                        num_channels=1)],
                    axis=-1)
                inp = tf.concat([inp, inputs_pos], axis=0)

                z_mu, z_log_sigma_sq = encoder(inp, self.z_dim)
                self.z_mu_pr, self.z_mu = (z_mu[:self.batch_size],
                                           z_mu[self.batch_size:])
                self.z_log_sigma_sq_pr, self.z_log_sigma_sq = (
                    z_log_sigma_sq[:self.batch_size],
                    z_log_sigma_sq[self.batch_size:])

                x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
                x_bar = tf.tile(self.x_bar, [self.num_particles, 1, 1, 1])
                input_masks = tf.tile(self.input_masks,
                                      [self.num_particles, 1, 1])
                masks = tf.tile(self.masks, [self.num_particles, 1, 1])

                self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
                self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                                 [self.num_particles, 1])
                self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
                self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                              [self.num_particles, 1])

                self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr

                sigma = tf.exp(self.z_log_sigma_sq / 2.)

                self.params = get_trainable_variables(["inference"])

                dist = tf.distributions.Normal(loc=0., scale=1.)
                epsilon = dist.sample(
                    sample_shape=[self.batch_size * self.num_particles,
                                  self.z_dim],
                    seed=None)
                z = self.z_mu + tf.multiply(epsilon, sigma)

                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(x, masks, context=None, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                cond_features = tf.concat(
                    [broadcast_masks_tf(input_masks, num_channels=1),
                     cond_features],
                    axis=-1)

                self.pixel_params = forward_pixelcnn(x_bar,
                                                     cond_features,
                                                     bn=False)

                if self.network_type == 'binary':
                    nll = bernoulli_loss(x,
                                         self.pixel_params,
                                         masks=masks,
                                         output_mean=False)
                else:
                    nll = mix_logistic_loss(x,
                                            self.pixel_params,
                                            masks=masks,
                                            output_mean=False)

                log_prob_pos = dist.log_prob(epsilon)
                epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                    self.z_log_sigma_sq_pr / 2.)
                log_prob_pr = dist.log_prob(epsilon_pr)
                # reshape the flat particle batch back to
                # [num_particles, batch_size, ...]
                log_prob_pr = tf.stack(
                    [log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                     for i in range(self.num_particles)],
                    axis=0)
                log_prob_pos = tf.stack(
                    [log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                     for i in range(self.num_particles)],
                    axis=0)
                log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
                log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
                nll = tf.stack(
                    [nll[self.batch_size * i:self.batch_size * (i + 1)]
                     for i in range(self.num_particles)],
                    axis=0)
                log_likelihood = -nll

                # log_weights = log_prob_pr + log_likelihood - log_prob_pos
                log_weights = log_likelihood
                log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
                log_avg_weight = log_sum_weight - tf.log(
                    tf.to_float(self.num_particles))
                self.log_avg_weight = log_avg_weight

                normalized_weights = tf.stop_gradient(
                    tf.nn.softmax(log_weights, axis=0))
                sq_normalized_weights = tf.square(normalized_weights)
                self.gradients = tf.gradients(
                    -tf.reduce_sum(sq_normalized_weights * log_weights,
                                   axis=0),
                    self.params,
                    colocate_gradients_with_ops=True)