Exemplo n.º 1
0
    def _retrain_at_task_or_all(self, task_id, train_xs, train_labels, retrain_flags, is_verbose, **kwargs):
        """Run one retraining pass over every task that is not marked for forgetting.

        For each remaining task one ``self.sess.run`` step of ``opt_with_excl`` is taken:
        the target task's batch is fed to ``X``/``Y``, and a batch gathered from the other
        tasks (excluding the target and every forgotten task) is fed to ``excl_X``/``excl_Y``.

        :param task_id: not used by this method; kept for interface compatibility.
        :param train_xs: indexable by 0-based task index; ``train_xs[t]()`` is called and its
            result is wrapped by ``get_batch_iterator``.
        :param train_labels: same layout as ``train_xs``, for the labels.
        :param retrain_flags: object whose ``task_to_forget`` holds the 1-based ids of tasks
            to skip.
        :param is_verbose: not used by this method; kept for interface compatibility.
        :param kwargs: must contain ``model_args``, the 8-tuple
            ``(X, Y, excl_X, excl_Y, opt_with_excl, loss_with_excl, keep_prob, is_training)``.
        :return: sum of the ``loss_with_excl`` values over the steps taken, or 0 when every
            task is forgotten.
        """
        X, Y, excl_X, excl_Y, opt_with_excl, loss_with_excl, keep_prob, is_training = kwargs["model_args"]

        forgotten_task_ids = list(retrain_flags.task_to_forget)
        # 1-based ids of the tasks still being retrained. The list is sorted so the update
        # order is explicit instead of relying on set iteration order.
        remaining_task_ids = sorted(set(range(1, self.n_tasks + 1)) - set(forgotten_task_ids))
        if not remaining_task_ids:
            # Every task is forgotten: there is nothing to train and no batch to split.
            return 0

        # Split the global batch across the tasks that are actually trained. Counting the
        # remaining tasks, rather than using n_tasks - len(task_to_forget), keeps duplicate
        # or out-of-range ids in task_to_forget from distorting the split or dividing by zero.
        batch_size_per_task = self.batch_size // len(remaining_task_ids)

        # One queue per task (forgotten ones included), indexed by 0-based task index.
        xs_queues = [get_batch_iterator(train_xs[t](), batch_size_per_task) for t in range(self.n_tasks)]
        labels_queues = [get_batch_iterator(train_labels[t](), batch_size_per_task) for t in range(self.n_tasks)]

        # 0-based indices of the forgotten tasks. This is loop-invariant, so it is computed once.
        forgotten_indices = [t - 1 for t in forgotten_task_ids]

        loss_sum = 0
        for target_task_id in remaining_task_ids:
            target_t = target_task_id - 1

            # Batch drawn from all tasks except the target and the forgotten ones.
            xs_wo_target, labels_wo_target = _get_xs_and_labels_wo_target(
                [target_t] + forgotten_indices,
                xs_queues, labels_queues,
            )

            feed_dict = {
                X: xs_queues[target_t](),
                Y: labels_queues[target_t](),
                excl_X: xs_wo_target, excl_Y: labels_wo_target,
                keep_prob: self.keep_prob, is_training: True,
            }
            _, loss_val = self.sess.run([opt_with_excl, loss_with_excl], feed_dict=feed_dict)
            loss_sum += loss_val

        return loss_sum
    def read_data(self, max_train_size, max_dev_size, read_ahead=10, batch_mode='standard', shuffle=True,
                  crash_test=False, **kwargs):
        """Load the training batch iterator and the development batches.

        Sets ``self.batch_iterator`` and ``self.train_size`` from ``utils.get_batch_iterator``
        over ``self.filenames.train``. Sets ``self.dev_batches`` to one entry per element of
        ``self.filenames.dev``, each produced by ``utils.get_batches``.

        Extra keyword arguments are accepted and ignored.
        """
        utils.debug('reading training data')
        train_reader = utils.get_batch_iterator(
            self.filenames.train,
            self.extensions,
            self.vocabs,
            self.batch_size,
            max_size=max_train_size,
            character_level=self.character_level,
            max_seq_len=self.max_len,
            read_ahead=read_ahead,
            mode=batch_mode,
            shuffle=shuffle,
            binary=self.binary,
            crash_test=crash_test,
        )
        self.batch_iterator, self.train_size = train_reader

        utils.debug('reading development data')

        # Read every dev set first; only element [0] of each read_dataset result is kept.
        dev_sets = []
        for dev_files in self.filenames.dev:
            loaded = utils.read_dataset(
                dev_files,
                self.extensions,
                self.vocabs,
                max_size=max_dev_size,
                character_level=self.character_level,
                binary=self.binary,
            )
            dev_sets.append(loaded[0])

        # Batched dev data whose loss is periodically evaluated.
        batched = []
        for dev_set in dev_sets:
            batched.append(utils.get_batches(dev_set, batch_size=self.batch_size))
        self.dev_batches = batched
Exemplo n.º 3
0
    def initial_train(self, print_iter=5, *args):
        """Build the model and train it from scratch.

        The mode is selected by ``self.use_set_based_mask``:

        * False: plain minibatch training that minimizes ``self.loss`` plus an L1/L2 penalty
          on trainable variables whose names contain "conv" or "fc".
        * True: each step feeds one target task's batch to ``X``/``Y`` and a batch taken from
          the other tasks' queues to ``excl_X``/``excl_Y``, minimizing
          ``self.loss + self.excl_loss`` plus the same penalty. The target task advances
          round-robin after every step. In this mode the loop may stop early once more than
          75% of ``self.max_iter`` iterations have run.

        Creates ``self.sess`` and runs ``tf.global_variables_initializer()`` in it.

        :param print_iter: ``self.evaluate_overall`` is called every ``print_iter`` outer
            iterations and on the last one.
        :param args: accepted and ignored.
        """

        # Placeholders/tensors from build_model. NOTE(review): build_model presumably also
        # sets self.loss / self.excl_loss, which are used below — confirm.
        X, Y, excl_X, excl_Y, keep_prob, is_training = self.build_model()

        # Add L1 & L2 loss regularizer
        l1_l2_regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=self.l1_lambda,
            scale_l2=self.l2_lambda,
        )
        # Regularize only the trainable variables whose name contains "conv" or "fc".
        variables = [var for var in tf.trainable_variables() if "conv" in var.name or "fc" in var.name]
        regularization_loss = tf.contrib.layers.apply_regularization(l1_l2_regularizer, variables)
        train_x, train_labels, val_x, val_labels, test_x, test_labels = self._get_data_stream_from_task_as_class_data()
        # Steps per outer iteration in both branches, rounded up. It is derived from the pooled
        # train_x size even in the masked branch, which draws from per-task queues instead.
        num_batches = int(math.ceil(len(train_x) / self.batch_size))

        if not self.use_set_based_mask:
            loss = self.loss + regularization_loss
            # The optimizer is created before tf.global_variables_initializer() below, so its
            # own variables are covered by that initializer.
            opt = tf.train.AdamOptimizer(learning_rate=self.init_lr, name="opt").minimize(loss)

            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

            for epoch in trange(self.max_iter):
                # NOTE(review): presumably resets the cursor used by get_next_batch — confirm.
                self.initialize_batch()
                loss_sum = 0
                for _ in range(num_batches):
                    batch_x, batch_y = self.get_next_batch(train_x, train_labels)
                    _, loss_val = self.sess.run(
                        [opt, loss],
                        feed_dict={X: batch_x, Y: batch_y, keep_prob: self.keep_prob, is_training: True},
                    )
                    loss_sum += loss_val

                if epoch % print_iter == 0 or epoch == self.max_iter - 1:
                    self.evaluate_overall(epoch, val_x, val_labels, loss_sum)
        else:
            loss_with_excl = self.loss + self.excl_loss + regularization_loss
            # Created before the variable initializer for the same reason as above.
            opt_with_excl = tf.train.AdamOptimizer(learning_rate=self.init_lr, name="opt").minimize(loss_with_excl)

            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

            # Integer division: any remainder of self.batch_size is not used.
            batch_size_per_task = self.batch_size // self.n_tasks

            # One batch queue per task, indexed by 0-based t. Labels are looked up with the
            # 1-based task id t + 1.
            xs_queues = [get_batch_iterator(self.trainXs[t], batch_size_per_task) for t in range(self.n_tasks)]
            labels_queues = [get_batch_iterator(self.data_labels.get_train_labels(t + 1), batch_size_per_task)
                             for t in range(self.n_tasks)]

            # 0-based index of the task fed to X/Y. It is not reset between outer iterations.
            target_t = 0
            for epoch in trange(self.max_iter):
                loss_sum = 0
                for _ in range(num_batches):
                    # NOTE(review): the helper receives a single int here, whereas
                    # _retrain_at_task_or_all passes a list of indices — confirm it accepts both.
                    xs_wo_target, labels_wo_target = _get_xs_and_labels_wo_target(target_t, xs_queues, labels_queues)

                    feed_dict = {
                        X: xs_queues[target_t](), Y: labels_queues[target_t](),
                        excl_X: xs_wo_target, excl_Y: labels_wo_target,
                        keep_prob: self.keep_prob, is_training: True,
                    }
                    _, loss_val = self.sess.run([opt_with_excl, loss_with_excl], feed_dict=feed_dict)
                    loss_sum += loss_val
                    # Round-robin to the next task after every step.
                    target_t = (target_t + 1) % self.n_tasks

                if epoch % print_iter == 0 or epoch == self.max_iter - 1:
                    self.evaluate_overall(epoch, val_x, val_labels, loss_sum)
                    # Mask statistics for task ids 1 and 6, for each mask index.
                    # NOTE(review): the hard-coded 6 presumably assumes n_tasks >= 6 — confirm.
                    for i in range(len(self.conv_dims) // 2 - 1):
                        cprint_stats_of_mask_pair(self, 1, 6, batch_size_per_task, X, is_training, mask_id=i)

                # Early stopping after 75% of max_iter: stop when the latest entry of
                # self.validation_results has the highest first component (apf) seen so far.
                # NOTE(review): validation_results is presumably appended by evaluate_overall,
                # which only runs on print iterations, so this check may re-read a stale
                # entry — confirm.
                if epoch > 0.75 * self.max_iter:
                    max_perf = max(apf for apf, acc in self.validation_results)
                    if self.validation_results[-1][0] == max_perf:
                        cprint("EARLY STOPPED", "green")
                        break