def init_validation_set(self, D_gpus, training_set):
    """Load the validation split and build the graph that scores it with D.

    Returns early, leaving all state untouched, when
    ``training_set.load_validation_set_np()`` yields zero images.

    Args:
        D_gpus: Per-GPU discriminator networks. The fed validation batch is
            split into ``len(D_gpus)`` slices, one per network.
        training_set: Dataset object. It provides
            ``load_validation_set_np()``, ``dtype``, ``shape``,
            ``label_dtype`` and ``label_size``.

    Side effects:
        Sets ``_valid_images``, ``_valid_labels``, ``_valid_images_in``,
        ``_valid_labels_in`` and ``_valid_op`` on ``self``.
    """
    # Building the validation graph twice is a programming error.
    assert self._valid_images is None
    valid_imgs, valid_lbls = training_set.load_validation_set_np()
    if images_count_is_zero := (valid_imgs.shape[0] == 0):
        return
    del images_count_is_zero
    self._valid_images = valid_imgs
    self._valid_labels = valid_lbls

    # control_dependencies(None) detaches these ops from any control
    # dependencies active at the call site.
    with tflib.absolute_name_scope('Validation'), tf.control_dependencies(None):
        # Feed placeholders live on the CPU and are sliced there.
        with tf.device('/cpu:0'):
            self._valid_images_in = tf.placeholder(
                training_set.dtype,
                name='valid_images_in',
                shape=[None] + training_set.shape,
            )
            self._valid_labels_in = tf.placeholder(
                training_set.label_dtype,
                name='valid_labels_in',
                shape=[None, training_set.label_size],
            )
            # NOTE(review): tf.split with an integer count needs the fed batch
            # size to be divisible by len(D_gpus) at run time.
            image_slices = tf.split(self._valid_images_in, len(D_gpus))
            label_slices = tf.split(self._valid_labels_in, len(D_gpus))

        score_ops = []
        per_gpu = zip(D_gpus, image_slices, label_slices)
        for gpu_index, (D_net, img_slice, lbl_slice) in enumerate(per_gpu):
            with tf.device(f'/gpu:{gpu_index}'):
                # Linearly map [0, 255] to [-1, 1]. This presumably assumes
                # raw 8-bit pixel values are fed; confirm against the dataset.
                scaled = tf.cast(img_slice, tf.float32) * (2 / 255) - 1
                evaluated = loss.eval_D(D_net, self, scaled, lbl_slice, report='valid')
                score_ops.append(evaluated.scores)
        # A single op that runs every GPU's scoring pass.
        self._valid_op = tf.group(*score_ops)
def get_strength_var(self):
    """Return the non-trainable TF variable for the augmentation strength.

    The variable is created on the first call and cached on
    ``self._strength_var``; later calls return the cached object.

    Returns:
        The cached variable. On first creation it is initialized from
        ``np.float32(self.strength)`` and named ``'strength'`` under the
        absolute ``'Augment'`` name scope.
    """
    if self._strength_var is not None:
        return self._strength_var
    # control_dependencies(None) keeps the variable's creation independent of
    # any control dependencies active where this getter happens to be called.
    with tflib.absolute_name_scope('Augment'), tf.control_dependencies(None):
        initial_value = np.float32(self.strength)
        self._strength_var = tf.Variable(initial_value, name='strength', trainable=False)
    return self._strength_var
def _read_and_decay_acc(self, name, nimg_delta):
    """Read accumulator ``name`` and return its mean, then decay it in place.

    Args:
        name: Key into ``self._acc_vars``. It also keys the cached decay op in
            ``self._acc_decay_ops``.
        nimg_delta: Number of images since the previous read. Decay is applied
            only when this is positive.

    Returns:
        ``acc_sum / acc_num`` as read before any decay, or ``0`` when the
        accumulated count is not positive.
    """
    variables = self._acc_vars[name]
    # Summing the fetched values along axis 0 and unpacking two entries
    # implies each variable holds a (count, sum) pair. Presumably there is
    # one variable per GPU; verify against where _acc_vars is populated.
    count, total = tuple(np.sum(tflib.run(variables), axis=0))

    if nimg_delta > 0:
        # The decay placeholder and per-accumulator assign ops are built
        # lazily once and then reused on every later call.
        with tflib.absolute_name_scope('Augment'), tf.control_dependencies(None):
            if self._acc_decay_in is None:
                self._acc_decay_in = tf.placeholder(tf.float32, name='acc_decay_in', shape=[])
            if name not in self._acc_decay_ops:
                with tf.name_scope('acc_' + name):
                    self._acc_decay_ops[name] = tf.group(
                        *[tf.assign(v, v * self._acc_decay_in) for v in variables]
                    )
        if self.stat_decay_kimg > 0:
            # Half-life of stat_decay_kimg * 1000 images.
            decay = 0.5 ** (nimg_delta / (self.stat_decay_kimg * 1000))
        else:
            # A non-positive half-life zeroes the accumulator outright.
            decay = 0
        tflib.run(self._acc_decay_ops[name], {self._acc_decay_in: decay})

    if count > 0:
        return total / count
    return 0