def _retrain_at_task_or_all(self, task_id, train_xs, train_labels, retrain_flags, is_verbose, **kwargs):
    X, Y, excl_X, excl_Y, opt_with_excl, loss_with_excl, keep_prob, is_training = kwargs["model_args"]

    # Split the batch evenly across the tasks that are kept (not forgotten).
    num_remaining_tasks = self.n_tasks - len(retrain_flags.task_to_forget)
    batch_size_per_task = self.batch_size // num_remaining_tasks

    xs_queues = [get_batch_iterator(train_xs[t](), batch_size_per_task) for t in range(self.n_tasks)]
    labels_queues = [get_batch_iterator(train_labels[t](), batch_size_per_task) for t in range(self.n_tasks)]

    loss_sum = 0
    # Task ids are 1-based; iterate deterministically over the tasks to keep.
    for target_task_id in sorted(set(range(1, self.n_tasks + 1)) - set(retrain_flags.task_to_forget)):
        target_t = target_task_id - 1  # 0-based task index
        # Exclusivity batch: every task except the current target and the forgotten ones.
        xs_wo_target, labels_wo_target = _get_xs_and_labels_wo_target(
            [target_t] + [t - 1 for t in retrain_flags.task_to_forget],
            xs_queues,
            labels_queues,
        )
        feed_dict = {
            X: xs_queues[target_t](),
            Y: labels_queues[target_t](),
            excl_X: xs_wo_target,
            excl_Y: labels_wo_target,
            keep_prob: self.keep_prob,
            is_training: True,
        }
        _, loss_val = self.sess.run([opt_with_excl, loss_with_excl], feed_dict=feed_dict)
        loss_sum += loss_val
    return loss_sum
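# The two helpers used above, get_batch_iterator and _get_xs_and_labels_wo_target,
# are not shown in this section. The minimal sketch below is an assumption
# inferred from the call sites (zero-argument batch queues, and concatenating
# one batch per non-excluded task) and is not the project's actual implementation.
import numpy as np

def get_batch_iterator(data, batch_size):
    """Return a zero-argument callable that yields the next `batch_size`
    rows of `data` on each call, reshuffling once the data is exhausted."""
    data = np.asarray(data)
    order, cursor = np.random.permutation(len(data)), 0

    def next_batch():
        nonlocal order, cursor
        if cursor + batch_size > len(data):
            order, cursor = np.random.permutation(len(data)), 0
        batch = data[order[cursor:cursor + batch_size]]
        cursor += batch_size
        return batch

    return next_batch

def _get_xs_and_labels_wo_target(excluded, xs_queues, labels_queues):
    """Concatenate one batch from every task queue whose index is not excluded.
    `excluded` may be a single 0-based task index or a list of them."""
    if not isinstance(excluded, (list, tuple, set)):
        excluded = [excluded]
    kept = [t for t in range(len(xs_queues)) if t not in excluded]
    xs = np.concatenate([xs_queues[t]() for t in kept], axis=0)
    labels = np.concatenate([labels_queues[t]() for t in kept], axis=0)
    return xs, labels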
def read_data(self, max_train_size, max_dev_size, read_ahead=10, batch_mode='standard',
              shuffle=True, crash_test=False, **kwargs):
    utils.debug('reading training data')
    self.batch_iterator, self.train_size = utils.get_batch_iterator(
        self.filenames.train, self.extensions, self.vocabs, self.batch_size,
        max_size=max_train_size, character_level=self.character_level,
        max_seq_len=self.max_len, read_ahead=read_ahead, mode=batch_mode,
        shuffle=shuffle, binary=self.binary, crash_test=crash_test
    )

    utils.debug('reading development data')
    dev_sets = [
        utils.read_dataset(dev, self.extensions, self.vocabs, max_size=max_dev_size,
                           character_level=self.character_level, binary=self.binary)[0]
        for dev in self.filenames.dev
    ]
    # subset of the dev set whose loss is periodically evaluated
    self.dev_batches = [utils.get_batches(dev_set, batch_size=self.batch_size)
                        for dev_set in dev_sets]
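# `utils.get_batches` is project-specific and not shown here. A plausible
# minimal sketch, assuming it simply splits a dataset into consecutive
# fixed-size batches (an inference from the call site, not the real code):
def get_batches(dataset, batch_size):
    """Split `dataset` (a list of examples) into fixed-size batches."""
    return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]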
def initial_train(self, print_iter=5, *args):
    X, Y, excl_X, excl_Y, keep_prob, is_training = self.build_model()

    # Add L1 & L2 regularization on conv/fc weights.
    l1_l2_regularizer = tf.contrib.layers.l1_l2_regularizer(
        scale_l1=self.l1_lambda,
        scale_l2=self.l2_lambda,
    )
    variables = [var for var in tf.trainable_variables()
                 if "conv" in var.name or "fc" in var.name]
    regularization_loss = tf.contrib.layers.apply_regularization(l1_l2_regularizer, variables)

    train_x, train_labels, val_x, val_labels, test_x, test_labels = \
        self._get_data_stream_from_task_as_class_data()
    num_batches = int(math.ceil(len(train_x) / self.batch_size))

    if not self.use_set_based_mask:
        # Plain training: minimize the task loss plus regularization.
        loss = self.loss + regularization_loss
        opt = tf.train.AdamOptimizer(learning_rate=self.init_lr, name="opt").minimize(loss)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        for epoch in trange(self.max_iter):
            self.initialize_batch()
            loss_sum = 0
            for _ in range(num_batches):
                batch_x, batch_y = self.get_next_batch(train_x, train_labels)
                _, loss_val = self.sess.run(
                    [opt, loss],
                    feed_dict={X: batch_x, Y: batch_y, keep_prob: self.keep_prob, is_training: True},
                )
                loss_sum += loss_val
            if epoch % print_iter == 0 or epoch == self.max_iter - 1:
                self.evaluate_overall(epoch, val_x, val_labels, loss_sum)
    else:
        # Mask-based training: also minimize the exclusivity loss, feeding the
        # target task through (X, Y) and all other tasks through (excl_X, excl_Y).
        loss_with_excl = self.loss + self.excl_loss + regularization_loss
        opt_with_excl = tf.train.AdamOptimizer(learning_rate=self.init_lr, name="opt").minimize(loss_with_excl)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        batch_size_per_task = self.batch_size // self.n_tasks
        xs_queues = [get_batch_iterator(self.trainXs[t], batch_size_per_task)
                     for t in range(self.n_tasks)]
        labels_queues = [get_batch_iterator(self.data_labels.get_train_labels(t + 1), batch_size_per_task)
                         for t in range(self.n_tasks)]

        target_t = 0
        for epoch in trange(self.max_iter):
            loss_sum = 0
            for _ in range(num_batches):
                xs_wo_target, labels_wo_target = _get_xs_and_labels_wo_target(target_t, xs_queues, labels_queues)
                feed_dict = {
                    X: xs_queues[target_t](),
                    Y: labels_queues[target_t](),
                    excl_X: xs_wo_target,
                    excl_Y: labels_wo_target,
                    keep_prob: self.keep_prob,
                    is_training: True,
                }
                _, loss_val = self.sess.run([opt_with_excl, loss_with_excl], feed_dict=feed_dict)
                loss_sum += loss_val
                # Rotate the target task round-robin every step.
                target_t = (target_t + 1) % self.n_tasks

            if epoch % print_iter == 0 or epoch == self.max_iter - 1:
                self.evaluate_overall(epoch, val_x, val_labels, loss_sum)
                for i in range(len(self.conv_dims) // 2 - 1):
                    cprint_stats_of_mask_pair(self, 1, 6, batch_size_per_task, X, is_training, mask_id=i)
                # Early stopping: after 75% of max_iter, stop once the latest
                # validation performance ties the best seen so far.
                if epoch > 0.75 * self.max_iter:
                    max_perf = max(apf for apf, acc in self.validation_results)
                    if self.validation_results[-1][0] == max_perf:
                        cprint("EARLY STOPPED", "green")
                        break
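# The early-stopping rule used above, isolated for readability. This is a
# restatement of the logic inside initial_train, not a helper from the codebase:
# after 75% of max_iter, stop as soon as the latest validation performance
# ties the best performance seen so far.
def should_early_stop(validation_results, epoch, max_iter):
    if epoch <= 0.75 * max_iter:
        return False
    best_perf = max(perf for perf, _acc in validation_results)
    return validation_results[-1][0] == best_perf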