Example #1
def model_fn_uncond(hparams, z, x, generator, discriminator, mdevice):

    # Get theta placeholders
    theta_ph = mdevice.get_theta_ph(hparams)
    theta_gen_ph = mdevice.get_theta_ph(hparams)

    # Get generations
    x_gen = generator(hparams, z, 'gen', train=True, reuse=False)
    x_sample = generator(hparams, z, 'gen', train=False, reuse=True)

    # Get lossy versions
    x_lossy, x_gen_lossy = get_lossy(hparams, mdevice, x, theta_ph, x_gen,
                                     theta_gen_ph)

    # Apply discriminator
    _, d_logit = discriminator(hparams,
                               x_lossy,
                               'discrim',
                               train=True,
                               reuse=False)
    _, d_gen_logit = discriminator(hparams,
                                   x_gen_lossy,
                                   'discrim',
                                   train=True,
                                   reuse=True)

    # Get loss
    d_loss, g_loss = get_loss(hparams, d_logit, d_gen_logit, x_lossy,
                              x_gen_lossy, discriminator, None, None, None)

    # Get train ops
    d_update_op, g_update_op, iter_ph = utils.get_train_ops(
        hparams, d_loss, g_loss)

    return x_lossy, x_sample, theta_ph, theta_gen_ph, d_loss, g_loss, d_update_op, g_update_op, iter_ph
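A minimal sketch of how these returned handles might be driven in a training loop. The session loop below is an assumption about the surrounding project: `sess`, `num_train_iterations`, and `mdevice.sample_theta` are hypothetical names, not part of the example above.

# Hypothetical driver loop for model_fn_uncond (all names assumed).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(num_train_iterations):
        feed_dict = {
            theta_ph: mdevice.sample_theta(hparams),      # measurement params for real data
            theta_gen_ph: mdevice.sample_theta(hparams),  # measurement params for generations
            iter_ph: it,
        }
        sess.run(d_update_op, feed_dict=feed_dict)  # discriminator update
        sess.run(g_update_op, feed_dict=feed_dict)  # generator update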
Example #2
    def _build_train(self):
        tf.logging.info("-" * 80)
        tf.logging.info("Build train graph")
        logits = self._model(self.x_train,
                             is_training=True,
                             reuse=tf.AUTO_REUSE)
        # Per-example cross-entropy values (despite the name, not raw
        # log-probabilities); averaged into the scalar training loss below.
        log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.y_train)
        self.loss = tf.reduce_mean(log_probs)

        if self.use_aux_heads:
            log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.aux_logits, labels=self.y_train)
            self.aux_loss = tf.reduce_mean(log_probs)
            train_loss = self.loss + 0.4 * self.aux_loss
        else:
            train_loss = self.loss

        self.train_preds = tf.argmax(logits, axis=1)
        self.train_preds = tf.to_int32(self.train_preds)
        self.train_acc = tf.equal(self.train_preds, self.y_train)
        self.train_acc = tf.to_int32(self.train_acc)
        self.train_acc = tf.reduce_sum(self.train_acc)

        tf_variables = [
            var for var in tf.trainable_variables()
            if (var.name.startswith(self.name) and "aux_head" not in var.name)
        ]
        self.num_vars = count_model_params(tf_variables)
        tf.logging.info("Model has {0} params".format(self.num_vars))

        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
                train_loss,
                tf_variables,
                self.global_step,
                self.num_train_steps,
                clip_mode=self.clip_mode,
                grad_bound=self.grad_bound,
                l2_reg=self.l2_reg,
                lr_init=self.lr_init,
                lr_dec_start=self.lr_dec_start,
                lr_dec_every=self.lr_dec_every,
                lr_dec_rate=self.lr_dec_rate,
                lr_cosine=self.lr_cosine,
                lr_max=self.lr_max,
                lr_min=self.lr_min,
                lr_T_0=self.lr_T_0,
                lr_T_mul=self.lr_T_mul,
                num_train_batches=self.num_train_batches,
                optim_algo=self.optim_algo,
                sync_replicas=self.sync_replicas,
                num_aggregate=self.num_aggregate,
                num_replicas=self.num_replicas)
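Note that `self.train_acc` above is a count of correct predictions (a `reduce_sum` over a 0/1 tensor), not a rate. A minimal sketch of turning the fetched count into an accuracy; `sess`, `model`, and `batch_size` are assumptions about the surrounding training loop.

# Hypothetical evaluation of one training step (names assumed).
_, loss_val, acc_count = sess.run([model.train_op, model.loss, model.train_acc])
train_accuracy = acc_count / float(batch_size)  # convert correct-count to a rate
tf.logging.info("loss=%.4f acc=%.4f", loss_val, train_accuracy)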
Example #3
    def build_trainer(self, child_model):
        child_model.build_valid_rl()
        self.valid_acc = (tf.to_float(child_model.valid_shuffle_acc) /
                          tf.to_float(child_model.batch_size))
        self.current_normal_arc = child_model.current_normal_arc
        self.current_reduce_arc = child_model.current_reduce_arc

        self.reward = self.valid_acc

        if self.entropy_weight is not None:
            self.reward += self.entropy_weight * self.sample_entropy

        self.sample_log_prob = tf.reduce_sum(self.sample_log_prob)
        # Exponential-moving-average baseline for REINFORCE:
        # baseline <- bl_dec * baseline + (1 - bl_dec) * reward.
        self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
        baseline_update = tf.assign_sub(self.baseline, (1 - self.bl_dec) *
                                        (self.baseline - self.reward))

        with tf.control_dependencies([baseline_update]):
            self.reward = tf.identity(self.reward)

        self.loss = self.sample_log_prob * (self.reward - self.baseline)
        self.train_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name="train_step")

        tf_variables = [
            var for var in tf.trainable_variables()
            if var.name.startswith(self.name)
        ]
        print("-" * 80)
        for var in tf_variables:
            print(var)

        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            self.loss,
            tf_variables,
            self.train_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)

        self.skip_rate = tf.constant(0.0, dtype=tf.float32)
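A plain-Python sketch of the baseline arithmetic from `tf.assign_sub` above; the decay value and rewards are only illustrative.

# Plain-Python illustration of the EMA baseline update.
bl_dec = 0.99
baseline = 0.0
for reward in [0.60, 0.62, 0.61]:  # illustrative validation rewards
    baseline -= (1 - bl_dec) * (baseline - reward)
    # equivalent form: baseline = bl_dec * baseline + (1 - bl_dec) * reward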
Example #4
    def _build_train(self):
        print("-" * 80)
        print("Build train graph")
        logits = self._model(self.x_train, is_training=True)
        log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.y_train)  # self.x_train has shape [32, 3, 32, 32], i.e. a batch size of 32

        self.loss = tf.reduce_mean(log_probs)  # scalar training loss

        if self.use_aux_heads:
            log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.aux_logits, labels=self.y_train)
            self.aux_loss = tf.reduce_mean(log_probs)
            train_loss = self.loss + 0.4 * self.aux_loss
        else:
            train_loss = self.loss

        self.train_preds = tf.argmax(logits, axis=1)
        self.train_preds = tf.to_int32(self.train_preds)

        self.train_acc = tf.equal(self.train_preds, self.y_train)
        self.train_acc = tf.to_int32(self.train_acc)
        self.train_acc = tf.reduce_sum(self.train_acc)  # correct-prediction count; divide by the batch size (32) for accuracy

        tf_variables = [
            var for var in tf.trainable_variables() if (
                    var.name.startswith(self.name) and "aux_head" not in var.name)]
        self.num_vars = count_model_params(tf_variables)
        print("Model has {0} params".format(self.num_vars))

        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            train_loss,
            tf_variables,
            self.global_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            lr_cosine=self.lr_cosine,
            lr_max=self.lr_max,
            lr_min=self.lr_min,
            lr_T_0=self.lr_T_0,
            lr_T_mul=self.lr_T_mul,
            num_train_batches=self.num_train_batches,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)
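`tf.to_int32` is deprecated in later TF 1.x releases in favor of `tf.cast`. A behavior-equivalent sketch of the prediction/accuracy block above using the non-deprecated call (a drop-in for the same method body, nothing else changed):

# Equivalent block with tf.cast instead of the deprecated tf.to_int32.
self.train_preds = tf.cast(tf.argmax(logits, axis=1), tf.int32)
self.train_acc = tf.reduce_sum(
    tf.cast(tf.equal(self.train_preds, self.y_train), tf.int32))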
Example #5
    def _build_train(self):
        print("Build train graph")
        logits = self._model(self.x_train, True)
        log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.y_train)
        self.loss = tf.reduce_mean(log_probs)

        self.train_preds = tf.argmax(logits, axis=1)
        self.train_preds = tf.to_int32(self.train_preds)
        self.train_acc = tf.equal(self.train_preds, self.y_train)
        self.train_acc = tf.to_int32(self.train_acc)
        self.train_acc = tf.reduce_sum(self.train_acc)

        tf_variables = [
            var for var in tf.trainable_variables()
            if var.name.startswith(self.name)
        ]
        self.num_vars = count_model_params(tf_variables)
        print("-" * 80)
        for var in tf_variables:
            print(var)

        self.global_step = tf.Variable(0,
                                       dtype=tf.int32,
                                       trainable=False,
                                       name="global_step")
        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            self.loss,
            tf_variables,
            self.global_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)
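`count_model_params` is a project helper not shown in these examples. A minimal sketch of what such a counter typically looks like; this implementation is an assumption, not the project's code.

import numpy as np

def count_model_params(tf_variables):
    """Sum of element counts across the given variables (assumed helper)."""
    return sum(int(np.prod(v.get_shape().as_list())) for v in tf_variables)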
Example #6
    def __init__(
        self,
        model,
        inputs_batch_size=16,
        embedding_dim=128,
        lstm_hidden_size=128,
        lstm_num_layers=2,
        useAttention=True,
        lr_init=0.0035,
        lr_dec_start=0,
        lr_dec_every=10000000,
        lr_dec_rate=0.9,
        l2_reg=0,
        clip_mode=None,
        grad_bound=None,
        optim_algo="adam",
        sync_replicas=False,
        num_aggregate=None,
        num_replicas=None,
        train_player_steps=200,
    ):
        self.train_player_steps = train_player_steps
        self.model = model
        self.inputs_batch_size = inputs_batch_size

        num_players = 0
        self.dict_player_id = {}

        self.pos_player_list = []
        for n_or_r in range(2):
            for node in range(2, self.model.num_cells + 2):
                for x_or_y in [0, 2]:
                    tmp_player_list = []
                    for prev_node in range(0, node):

                        self.dict_player_id[num_players] = (n_or_r, node,
                                                            prev_node, x_or_y)
                        tmp_player_list.append(num_players)
                        num_players += 1
                    self.pos_player_list.append(tmp_player_list)
        self.num_players = num_players

        self.num_arc_class = 3 + 1

        self.START_TOKEN = -1
        self.END_TOKEN = self.num_arc_class

        self.seq_inputs = tf.placeholder(
            shape=(inputs_batch_size, self.num_players,
                   self.num_arc_class * self.model.num_ops),
            dtype=tf.int32,
            name='seq_inputs')

        self.seq_inputs_length = tf.ones([inputs_batch_size],
                                         dtype=tf.int32) * self.num_players

        self.reward_value_list = tf.placeholder(shape=(inputs_batch_size,
                                                       self.model.num_ops,
                                                       self.num_arc_class),
                                                dtype=tf.float32,
                                                name="reward_value_list")

        self.select_player_id = tf.placeholder(shape=(),
                                               dtype=tf.int32,
                                               name="select_player_id")
        with tf.variable_scope("environment_encoder", reuse=tf.AUTO_REUSE):
            encoder_embedding = tf.Variable(tf.random_uniform(
                [self.num_arc_class * self.model.num_ops, embedding_dim]),
                                            dtype=tf.float32,
                                            name='encoder_embedding')

            encoder_inputs_embedded = tf.einsum(
                'ibn,nd->ibd', tf.cast(self.seq_inputs, dtype=tf.float32),
                encoder_embedding)

            encoder_cell = tf.nn.rnn_cell.MultiRNNCell([
                tf.nn.rnn_cell.BasicLSTMCell(lstm_hidden_size)
                for _ in range(lstm_num_layers)
            ])
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell=encoder_cell,
                inputs=encoder_inputs_embedded,
                sequence_length=self.seq_inputs_length,
                dtype=tf.float32,
                time_major=False)
            encoder_state = encoder_state[-1]
        self.player_out_list = []
        self.player_log_prob_list = []
        self.player_loss_list = []
        self.train_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name="train_step")
        self.run_op_list = []
        tf_variables_list = []
        self.att_matrix_list = []
        for player_id in range(num_players):
            with tf.variable_scope("player_decoder_{}".format(player_id),
                                   reuse=tf.AUTO_REUSE):
                tokens_go = tf.ones([inputs_batch_size],
                                    dtype=tf.int32) * self.START_TOKEN
                decoder_embedding = tf.Variable(tf.random_uniform(
                    [self.num_arc_class, embedding_dim]),
                                                dtype=tf.float32,
                                                name='decoder_embedding')
                decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_hidden_size)
                if useAttention:
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        num_units=lstm_hidden_size,
                        memory=encoder_outputs,
                        memory_sequence_length=self.seq_inputs_length)

                    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                        decoder_cell,
                        attention_mechanism,
                        alignment_history=True,
                        output_attention=True)
                    decoder_initial_state = decoder_cell.zero_state(
                        batch_size=inputs_batch_size, dtype=tf.float32)
                    decoder_initial_state = decoder_initial_state.clone(
                        cell_state=encoder_state)
                else:
                    # Without attention, decoder_initial_state was previously
                    # unbound; start decoding from the final encoder state.
                    decoder_initial_state = encoder_state

                helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                    decoder_embedding, tokens_go, self.END_TOKEN)
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell,
                    helper,
                    decoder_initial_state,
                    output_layer=tf.layers.Dense(self.num_arc_class))
                decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
                    decoder, maximum_iterations=self.model.num_ops)

                attention_matrices = decoder_state.alignment_history.stack(
                    name="train_attention_matrix.{}".format(player_id))

                att_matrix = tf.reduce_mean(attention_matrices, axis=1)

                self.att_matrix_list.append(att_matrix)

                decoder_logits = decoder_outputs.rnn_output

                tmp_decoder_logits = tf.reshape(decoder_logits,
                                                shape=[-1, self.num_arc_class])

                selected_classes = tf.multinomial(tmp_decoder_logits, 1)

                selected_classes = tf.reshape(
                    selected_classes,
                    shape=[self.inputs_batch_size, self.model.num_ops])

                log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=decoder_logits, labels=selected_classes)

                reshape_selected_classes = tf.reshape(
                    selected_classes,
                    shape=[self.inputs_batch_size, self.model.num_ops, 1])

                sum_log_prob = tf.reduce_sum(log_prob,
                                             axis=-1)  # shape = (batch_size,)
                self.player_log_prob_list.append(sum_log_prob)
                self.player_out_list.append(
                    selected_classes)  # shape = [batch_size, num_ops]
                reward = tf.gather_nd(
                    self.reward_value_list,
                    reshape_selected_classes,
                    batch_dims=2)  # the same shape as selected_classes

                reshape_reward = tf.reshape(
                    reward, shape=[self.inputs_batch_size, self.model.num_ops])
                loss = tf.reduce_sum(reshape_reward * log_prob, axis=-1)
                loss = tf.reduce_mean(loss)
                self.player_loss_list.append(loss)

                tf_variables = [
                    var for var in tf.trainable_variables() if
                    var.name.startswith("player_decoder_{}".format(player_id))
                ]
                num_shapes = len(tf_variables)

                tf_variables_list.append(tf_variables)

        tf_shape_variables_list = [
            tf.stack([
                tf_variables_list[p_id][shape_id]
                for p_id in range(num_players)
            ],
                     axis=0) for shape_id in range(num_shapes)
        ]
        tf_encoder_variable = [
            var for var in tf.trainable_variables()
            if var.name.startswith("environment_encoder")
        ]

        self.player_loss_list = tf.stack(self.player_loss_list, axis=0)

        train_op, lr, grad_norm, optimizer = get_train_ops(
            self.player_loss_list[self.select_player_id],
            tf_encoder_variable + [
                tf_shape_variable[self.select_player_id]
                for tf_shape_variable in tf_shape_variables_list
            ],
            self.train_step,
            clip_mode=clip_mode,
            grad_bound=grad_bound,
            l2_reg=l2_reg,
            lr_init=lr_init,
            lr_dec_start=lr_dec_start,
            lr_dec_every=lr_dec_every,
            lr_dec_rate=lr_dec_rate,
            optim_algo=optim_algo,
            sync_replicas=sync_replicas,
            num_aggregate=num_aggregate,
            num_replicas=num_replicas)
        self.run_op = [
            train_op, self.player_loss_list[self.select_player_id],
            self.train_step, lr, grad_norm, self.att_matrix_list
        ]
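A minimal sketch of how one update step for a single player might be run; `sess`, `trainer`, `encoded_inputs`, `rewards`, and `player_id` are assumptions about the surrounding training loop.

# Hypothetical single-player update step (all names assumed).
fetches = sess.run(
    trainer.run_op,
    feed_dict={
        trainer.seq_inputs: encoded_inputs,    # (batch, num_players, num_arc_class * num_ops)
        trainer.reward_value_list: rewards,    # (batch, num_ops, num_arc_class)
        trainer.select_player_id: player_id,   # which decoder's loss to optimize
    })
_, loss_val, step, lr_val, grad_norm, att_matrices = fetches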