示例#1
0
    def create_graph(self):
        """Creates graph for training.

        For every (sequence length, item count) pair taken from ``self.bins``
        and ``self.count_list`` a pair of placeholders is created, appended to
        ``self.x_input`` / ``self.y_input`` and passed to ``self.create_loss``.
        All bins are built inside the ``"var_lengths"`` variable scope, and
        variable reuse is switched on after the first bin.

        Sets ``self.base_cost``, ``self.accuracy``, ``self.bin_losses``,
        ``self.sat_loss``, ``self.local_lr``, ``self.optimizer``,
        ``self.gnorm`` and ``self.cost_list`` (stacked into one tensor at the
        end), and registers scalar / histogram / image summaries.
        """
        RSE_network.is_training = True
        self.base_cost = 0.0
        self.accuracy = 0
        num_sizes = len(self.bins)
        self.cost_list = []
        sum_weight = 0
        self.bin_losses = []
        saturation_loss = []
        total_mean_loss = 0

        # Create all bins and calculate losses for them

        with vs.variable_scope("var_lengths"):
            for seqLength, itemCount in zip(self.bins, self.count_list):
                x_in = tf.compat.v1.placeholder(cnf.input_type,
                                                [itemCount, seqLength])
                y_in = tf.compat.v1.placeholder("int64",
                                                [itemCount, seqLength])
                self.x_input.append(x_in)
                self.y_input.append(y_in)

                # The debug lists are reset for every bin, so after the loop
                # they only hold what was collected for the last bin.
                RSE_network.saturation_costs = []
                RSE_network.gate_mem = []
                RSE_network.reset_mem = []
                RSE_network.candidate_mem = []
                RSE_network.prev_mem_list = []
                RSE_network.residual_list = []
                RSE_network.info_alpha = []

                if self.use_two_gpus:
                    # Bins at least as long as the last bin go to GPU 0,
                    # all other bins go to GPU 1.
                    device = "/device:GPU:" + (
                        "0" if seqLength >= self.bins[-1] else "1")
                    with tf.device(device):
                        c, a, mem1, logits, per_item_cost, _, _ = self.create_loss(
                            x_in, y_in, seqLength)
                else:
                    c, a, mem1, logits, per_item_cost, _, _ = self.create_loss(
                        x_in, y_in, seqLength)

                # Every bin currently contributes with the same weight.
                weight = 1.0

                # Normalise the summed saturation cost by sequence length,
                # number of collected cost terms and item count; use 0 when
                # nothing was collected for this bin.
                sat_cost = (
                    tf.add_n(RSE_network.saturation_costs) /
                    (seqLength * len(RSE_network.saturation_costs) * itemCount)
                    if RSE_network.saturation_costs else 0)
                saturation_loss.append(sat_cost * weight)
                self.bin_losses.append(per_item_cost)
                self.base_cost += c * weight
                sum_weight += weight
                self.accuracy += a
                self.cost_list.append(c)

                mean_loss = tf.reduce_mean(input_tensor=tf.square(mem1))
                total_mean_loss += mean_loss

                tf.compat.v1.get_variable_scope().reuse_variables()

        # calculate the total loss
        self.base_cost /= sum_weight
        self.accuracy /= num_sizes
        total_mean_loss /= num_sizes
        tf.compat.v1.summary.scalar("base/loss", self.base_cost)
        tf.compat.v1.summary.scalar("base/error", 1 - self.accuracy)
        # `a` and `logits` still refer to the last bin built by the loop.
        tf.compat.v1.summary.scalar("base/error_longest", 1 - a)
        tf.compat.v1.summary.histogram("logits", logits)

        # Compare by value: `is not` against a string literal tests object
        # identity, which depends on interning and is not a reliable check.
        if cnf.task != "musicnet":
            if RSE_network.gate_mem:
                gate_img = tf.stack(RSE_network.gate_mem)
                gate_img = gate_img[:, 0:1, :, :]
                gate_img = tf.cast(gate_img * 255, dtype=tf.uint8)
                tf.compat.v1.summary.image("gate",
                                           tf.transpose(a=gate_img,
                                                        perm=[3, 0, 2, 1]),
                                           max_outputs=16)
            if RSE_network.reset_mem:
                reset_img = tf.stack(RSE_network.reset_mem)
                reset_img = tf.clip_by_value(reset_img, -2, 2)
                tf.compat.v1.summary.histogram("reset", reset_img)
                reset_img = reset_img[:, 0:1, :, :]
                tf.compat.v1.summary.image(
                    "reset",
                    tf.transpose(a=reset_img, perm=[3, 0, 2, 1]),
                    max_outputs=16,
                )
            if RSE_network.prev_mem_list:
                prev_img = tf.stack(RSE_network.prev_mem_list)
                prev_img = prev_img[:, 0:1, :, :]
                prev_img = tf.cast(prev_img * 255, dtype=tf.uint8)
                tf.compat.v1.summary.image(
                    "prev_mem",
                    tf.transpose(a=prev_img, perm=[3, 0, 2, 1]),
                    max_outputs=16,
                )
            if RSE_network.residual_list:
                prev_img = tf.stack(RSE_network.residual_list)
                prev_img = prev_img[:, 0:1, :, :]
                prev_img = tf.cast(prev_img * 255, dtype=tf.uint8)
                tf.compat.v1.summary.image(
                    "residual_mem",
                    tf.transpose(a=prev_img, perm=[3, 0, 2, 1]),
                    max_outputs=16,
                )
            if RSE_network.info_alpha:
                prev_img = tf.stack(RSE_network.info_alpha)
                prev_img = prev_img[:, 0:1, :, :]
                tf.compat.v1.summary.image(
                    "info_alpha",
                    tf.transpose(a=prev_img, perm=[3, 0, 2, 1]),
                    max_outputs=16,
                )

            # Guarded like the other debug lists: tf.stack rejects an empty
            # list, so skip the summary when nothing was collected.
            if RSE_network.candidate_mem:
                candidate_img = tf.stack(RSE_network.candidate_mem)
                candidate_img = candidate_img[:, 0:1, :, :]
                # Maps [-1, 1] onto [0, 255]; presumably the candidate values
                # lie in that range (TODO confirm against the network code).
                candidate_img = tf.cast((candidate_img + 1.0) * 127.5,
                                        dtype=tf.uint8)
                tf.compat.v1.summary.image(
                    "candidate",
                    tf.transpose(a=candidate_img, perm=[3, 0, 2, 1]),
                    max_outputs=16,
                )

            mem1 = mem1[:, 0:1, :, :]
            tf.compat.v1.summary.image("mem",
                                       tf.transpose(a=mem1, perm=[3, 0, 2, 1]),
                                       max_outputs=16)

        saturation = tf.reduce_sum(
            input_tensor=tf.stack(saturation_loss)) / sum_weight
        tf.compat.v1.summary.scalar("base/activation_mean",
                                    tf.sqrt(total_mean_loss))

        self.sat_loss = saturation * self.saturation_weight
        cost = self.base_cost + self.sat_loss

        tvars = list(tf.compat.v1.trainable_variables())
        for var in tvars:
            name = var.name.replace("var_lengths", "")
            tf.compat.v1.summary.histogram(name + "/histogram", var)

        # Variables with "CvK" in their name are handed to the optimizer as
        # decay_vars; reg_cost itself is only logged, it is not added to cost.
        regvars = [var for var in tvars if "CvK" in var.name]
        print(regvars)
        reg_costlist = [
            tf.reduce_sum(input_tensor=tf.square(var)) for var in regvars
        ]
        reg_cost = tf.add_n(reg_costlist)
        tf.compat.v1.summary.scalar("base/regularize_loss", reg_cost)

        # optimizer

        self.local_lr = self.learning_rate

        optimizer = RAdamOptimizer(
            self.local_lr,
            epsilon=1e-5,
            L2_decay=0.01,
            L1_decay=0.00,
            decay_vars=regvars,
            total_steps=cnf.training_iters,
            warmup_proportion=cnf.num_warmup_steps / cnf.training_iters,
            clip_gradients=True,
        )

        self.optimizer = optimizer.minimize(cost, global_step=self.global_step)

        # some values for printout: global norm over the square roots of the
        # optimizer's "v" slots (presumably the second-moment estimates;
        # verify against RAdamOptimizer).
        max_vals = []

        for var in tvars:
            var_v = optimizer.get_slot(var, "v")
            max_vals.append(tf.sqrt(var_v))

        self.gnorm = tf.linalg.global_norm(max_vals)
        tf.compat.v1.summary.scalar("base/gnorm", self.gnorm)
        self.cost_list = tf.stack(self.cost_list)
示例#2
0
    def create_graph(self):
        """Assemble the training graph across all length bins.

        One pair of int64 placeholders is created per (sequence length, item
        count) pair from ``self.bins`` / ``self.count_list`` and appended to
        ``self.x_input`` / ``self.y_input``; ``self.create_loss`` is then
        called on them inside the shared ``"var_lengths"`` variable scope.

        Sets ``self.cost`` and ``self.accuracy`` (averaged over the number of
        bins), ``self.bin_losses``, ``self.optimizer``, ``self.gnorm`` and
        ``self.cost_list`` (stacked into a single tensor), and registers the
        TensorBoard summaries.
        """
        self.cost = 0.0
        self.accuracy = 0
        bin_count = len(self.bins)
        self.cost_list = []
        self.bin_losses = []

        def first_item(mem_list):
            # Stack the collected tensors and keep index 0 of the second axis.
            return tf.stack(mem_list)[:, 0:1, :, :]

        def emit_image(tag, image):
            # Permute the axes and register the result as an image summary.
            tf.summary.image(tag, tf.transpose(image, [3, 0, 2, 1]), max_outputs=16)

        # Build one sub-graph per bin; variables are reused after the first.
        with vs.variable_scope("var_lengths"):
            for seq_len, item_count in zip(self.bins, self.count_list):
                shape = [item_count, seq_len]
                x_in = tf.placeholder("int64", shape)
                y_in = tf.placeholder("int64", shape)
                self.x_input.append(x_in)
                self.y_input.append(y_in)

                # Start each bin with empty debug lists on the network module.
                for attr in ("saturation_costs", "gate_mem", "reset_mem",
                             "candidate_mem", "prev_mem_list"):
                    setattr(network, attr, [])

                if self.use_two_gpus:
                    gpu_index = "0" if seq_len >= self.bins[-1] else "1"
                    with tf.device("/device:GPU:" + gpu_index):
                        loss_outputs = self.create_loss(x_in, y_in, seq_len)
                else:
                    loss_outputs = self.create_loss(x_in, y_in, seq_len)
                bin_cost, bin_acc, memory, _, per_item_cost, _ = loss_outputs

                self.bin_losses.append(per_item_cost)
                self.cost += bin_cost
                self.accuracy += bin_acc
                self.cost_list.append(bin_cost)
                tf.get_variable_scope().reuse_variables()

        # Average the accumulated loss and accuracy over the bins.
        self.cost /= bin_count
        self.accuracy /= bin_count

        # TensorBoard scalars; bin_acc still belongs to the last bin built.
        tf.summary.scalar("base/loss", self.cost)
        tf.summary.scalar("base/accuracy", self.accuracy)
        tf.summary.scalar("base/accuracy_longest", bin_acc)

        # Image summaries built from the lists collected for the last bin.
        for tag, mem_list in (("gate", network.gate_mem),
                              ("reset", network.reset_mem)):
            emit_image(tag, tf.cast(first_item(mem_list) * 255, dtype=tf.uint8))
        if network.prev_mem_list:
            emit_image(
                "prev_mem",
                tf.cast(first_item(network.prev_mem_list) * 255, dtype=tf.uint8),
            )

        candidate = first_item(network.candidate_mem)
        emit_image("candidate", tf.cast((candidate + 1.0) * 127.5, dtype=tf.uint8))

        emit_image("mem", memory[:, 0:1, :, :])

        trainable = tf.trainable_variables()
        for variable in trainable:
            tf.summary.histogram(
                variable.name.replace("var_lengths", "") + '/histogram', variable)

        # we use a small L2 regularization, although it is questionable if it helps
        regularizable = [variable for variable in trainable if "CvK" in variable.name]
        reg_cost = tf.add_n(
            [tf.reduce_sum(tf.square(variable)) for variable in regularizable])
        tf.summary.scalar("base/regularize_loss", reg_cost)

        # Adam optimizer works as well
        optimizer = RAdamOptimizer(
            self.learning_rate,
            epsilon=1e-5,
            L2_decay=0.01,
            decay_vars=regularizable,
            total_steps=cnf.training_iters,
            warmup_proportion=0.0,
        )
        self.optimizer = optimizer.minimize(self.cost, global_step=self.global_step)

        # some values for printout: global norm over the optimizer's "v" slots
        slot_values = [optimizer.get_slot(variable, "v") for variable in trainable]

        self.gnorm = tf.global_norm(slot_values)
        self.cost_list = tf.stack(self.cost_list)