Example No. 1
0
File: loss.py  Project: johndpope/BA
def D_logistic_r2(G,
                  D,
                  opt,
                  training_set,
                  minibatch_size,
                  reals,
                  labels,
                  gamma=10.0):
    """Logistic discriminator loss plus an R2 gradient penalty on generated images.

    Args:
        G: generator network (queried through ``get_output_for``).
        D: discriminator network (queried through ``get_output_for``).
        opt: unused; accepted so every loss function shares one signature.
        training_set: unused; accepted for the same reason.
        minibatch_size: number of latent vectors sampled for the fake batch.
        reals: real images fed to ``D``.
        labels: labels passed to both ``G`` and ``D``.
        gamma: penalty strength; the squared gradient norm is scaled by
            ``gamma / 2``.

    Returns:
        ``(loss, reg)`` where ``loss = softplus(D(fake)) + softplus(-D(real))``
        and ``reg = (gamma / 2) * sum(square(dD(fake)/dfake))`` reduced over
        axes 1, 2 and 3.
    """
    del opt, training_set  # intentionally unused

    # Sample latents matching G's first input and score real + generated images.
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    d_real = D.get_output_for(reals, labels, is_training=True)
    d_fake = D.get_output_for(generated,
                              labels,
                              is_training=True)
    d_real = autosummary('Loss/scores/real', d_real)
    d_fake = autosummary('Loss/scores/fake', d_fake)

    # softplus(x) == -log(1 - sigmoid(x)); softplus(-x) == -log(sigmoid(x)).
    loss = tf.nn.softplus(d_fake)
    loss += tf.nn.softplus(-d_real)  # pylint: disable=invalid-unary-operand-type

    with tf.name_scope('GradientPenalty'):
        # tf.gradients returns one gradient per entry of the xs list.
        (grad_fake,) = tf.gradients(tf.reduce_sum(d_fake), [generated])
        sq_grad_norm = tf.reduce_sum(tf.square(grad_fake), axis=[1, 2, 3])
        sq_grad_norm = autosummary('Loss/gradient_penalty', sq_grad_norm)
        reg = sq_grad_norm * (gamma * 0.5)
    return loss, reg
Example No. 2
0
File: loss.py  Project: johndpope/BA
def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels):
    """Plain logistic discriminator loss with no regularization term.

    Args:
        G: generator network used to produce the fake batch.
        D: discriminator network that scores real and fake images.
        opt: unused; kept for a uniform loss-function signature.
        training_set: unused; kept for the same reason.
        minibatch_size: number of latents sampled for the fake batch.
        reals: real images scored by ``D``.
        labels: labels shared by the real and fake batches.

    Returns:
        ``(softplus(D(fake)) + softplus(-D(real)), None)``; the ``None`` marks
        the absence of a separate regularizer.
    """
    del opt, training_set  # intentionally unused

    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    d_real = D.get_output_for(reals, labels, is_training=True)
    d_fake = D.get_output_for(generated,
                              labels,
                              is_training=True)
    d_real = autosummary('Loss/scores/real', d_real)
    d_fake = autosummary('Loss/scores/fake', d_fake)

    # -log(1 - sigmoid(D(fake))) - log(sigmoid(D(real))), written with softplus.
    loss = tf.nn.softplus(d_fake)
    loss += tf.nn.softplus(-d_real)  # pylint: disable=invalid-unary-operand-type
    return loss, None
Example No. 3
0
File: loss.py  Project: johndpope/BA
def D_wgan_gp(G,
              D,
              opt,
              training_set,
              minibatch_size,
              reals,
              labels,
              wgan_lambda=10.0,
              wgan_epsilon=0.001,
              wgan_target=1.0):
    """WGAN critic loss with an epsilon drift term and a WGAN-GP gradient penalty.

    Args:
        G: generator network.
        D: critic (discriminator) network.
        opt: unused; kept for a uniform loss-function signature.
        training_set: unused; kept for the same reason.
        minibatch_size: number of latents / interpolation factors to sample.
        reals: real images.
        labels: labels shared by all critic evaluations.
        wgan_lambda: gradient-penalty weight (divided by ``wgan_target**2``).
        wgan_epsilon: weight of the squared-real-score drift term.
        wgan_target: target gradient norm at the interpolated images.

    Returns:
        ``(loss, reg)``. ``loss = D(fake) - D(real) + wgan_epsilon * D(real)^2``.
        ``reg = (wgan_lambda / wgan_target^2) * (||grad D(mix)|| - wgan_target)^2``,
        where ``mix`` is produced by ``tflib.lerp`` between ``reals`` and the fakes
        with one uniform [0, 1) factor per sample.
    """
    del opt, training_set  # intentionally unused

    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    d_real = D.get_output_for(reals, labels, is_training=True)
    d_fake = D.get_output_for(generated,
                              labels,
                              is_training=True)
    d_real = autosummary('Loss/scores/real', d_real)
    d_fake = autosummary('Loss/scores/fake', d_fake)
    loss = d_fake - d_real

    # Drift term: penalize the squared real score, weighted by wgan_epsilon.
    with tf.name_scope('EpsilonPenalty'):
        drift = autosummary('Loss/epsilon_penalty', tf.square(d_real))
    loss += drift * wgan_epsilon

    with tf.name_scope('GradientPenalty'):
        img_dtype = generated.dtype
        alpha = tf.random_uniform([minibatch_size, 1, 1, 1],
                                  0.0,
                                  1.0,
                                  dtype=img_dtype)
        blended = tflib.lerp(tf.cast(reals, img_dtype), generated, alpha)
        d_blended = D.get_output_for(blended, labels, is_training=True)
        d_blended = autosummary('Loss/scores/mixed', d_blended)
        (grad_blended,) = tf.gradients(tf.reduce_sum(d_blended), [blended])
        grad_norm = tf.sqrt(
            tf.reduce_sum(tf.square(grad_blended), axis=[1, 2, 3]))
        grad_norm = autosummary('Loss/mixed_norms', grad_norm)
        penalty = tf.square(grad_norm - wgan_target)
        reg = penalty * (wgan_lambda / (wgan_target**2))
    return loss, reg
Example No. 4
0
    def register_gradients(self, loss: TfExpression,
                           trainable_vars: Union[List, dict]) -> None:
        """Compute gradients of ``loss`` w.r.t. ``trainable_vars`` and queue them.

        Intended to be called once per GPU, before ``apply_updates()``. The
        gradients are built on the device that owns ``loss`` and appended to
        that device's ``grad_raw`` mapping (one list of gradients per variable).

        Args:
            loss: TF expression to differentiate. It is cast to float32 and
                passed through ``apply_loss_scaling()`` first.
            trainable_vars: non-empty list of variables, or a dict (for example
                ``Network.trainables``) whose values are used. Every variable
                must live on ``loss``'s device, and after the first call the
                variable shapes must match those seen on that first call.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Reduce a dict such as Network.trainables to its values.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(
            tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # The first registration pins the expected shapes; later calls must agree.
        if self._gradient_shapes is None:
            self._gradient_shapes = [v.shape.as_list() for v in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(
            v.shape.as_list() == expected
            for v, expected in zip(trainable_vars, self._gradient_shapes))

        # Optionally (only once) log GPU memory in use after `loss` is evaluated.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), \
                        tf.device(device.name), \
                        tf.control_dependencies([loss]):
                    usage_gb = tf.contrib.memory_stats.BytesInUse() / 2**30
                    deps.append(
                        autosummary.autosummary(self.id + "/mem_usage_gb",
                                                usage_gb))
            except tf.errors.NotFoundError:
                pass  # presumably the memory_stats op is unavailable; skip reporting

        with tf.name_scope(self.id + "_grad"), \
                tf.device(device.name), \
                tf.control_dependencies(deps):
            scaled_loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grad_list = device.optimizer.compute_gradients(
                loss=scaled_loss,
                var_list=trainable_vars,
                # No gating between gradient ops, to reduce memory usage.
                gate_gradients=tf.train.Optimizer.GATE_NONE)

        # Append each gradient to its variable's list on this device.
        for grad, var in grad_list:
            device.grad_raw.setdefault(var, []).append(grad)
Example No. 5
0
File: loss.py  Project: johndpope/BA
def D_wgan(G,
           D,
           opt,
           training_set,
           minibatch_size,
           reals,
           labels,
           wgan_epsilon=0.001):
    """Wasserstein critic loss with an epsilon drift term and no gradient penalty.

    Args:
        G: generator network.
        D: critic (discriminator) network.
        opt: unused; kept for a uniform loss-function signature.
        training_set: unused; kept for the same reason.
        minibatch_size: number of latents sampled for the fake batch.
        reals: real images.
        labels: labels shared by the real and fake batches.
        wgan_epsilon: weight of the squared-real-score drift term.

    Returns:
        ``(D(fake) - D(real) + wgan_epsilon * D(real)^2, None)``.
    """
    del opt, training_set  # intentionally unused

    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    d_real = D.get_output_for(reals, labels, is_training=True)
    d_fake = D.get_output_for(generated,
                              labels,
                              is_training=True)
    d_real = autosummary('Loss/scores/real', d_real)
    d_fake = autosummary('Loss/scores/fake', d_fake)
    loss = d_fake - d_real

    # Drift term, added inside the EpsilonPenalty scope.
    with tf.name_scope('EpsilonPenalty'):
        drift = autosummary('Loss/epsilon_penalty', tf.square(d_real))
        loss += drift * wgan_epsilon
    return loss, None
Example No. 6
0
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        For every device that received gradients through ``register_gradients()``,
        this builds the following ops:

        1. Cleanup: drop ``None`` gradients, cast to float32, sum and scale them.
        2. With more than one device, an NCCL all-sum across devices.
        3. When ``self.minibatch_multiplier`` is set, accumulation over that many
           steps.
        4. A conditional ``apply_gradients`` that fires only when accumulation is
           complete and every accumulated gradient is finite.

        It also adjusts the loss-scaling variable (when enabled) and emits
        autosummary statistics. Can only be called once per instance.

        Args:
            allow_no_op: if True and ``self._devices`` is empty, return a
                ``tf.no_op`` named 'TrainingOp' instead of building update ops.

        Returns:
            A single op named 'TrainingOp' (under ``self.scope``) grouping all
            update, bookkeeping and summary ops.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" %
                                            device_idx), tf.device(
                                                device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]  # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)  # Multiple gradients => sum.

                    # Scale as needed.
                    # Divide by the number of gradient entries registered for
                    # this var (counting the None ones dropped above) and by the
                    # device count, so the sum becomes an average.
                    scale = 1.0 / len(device.grad_raw[var]) / len(
                        self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    # With accumulation, each step contributes 1/multiplier.
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    # Presumably reverses the factor applied by
                    # apply_loss_scaling() in register_gradients().
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope +
                                            "/Broadcast"), tf.device(None):
                # NOTE(review): variables are paired across devices by position
                # in each device's grad_clean; this assumes every device
                # registered the same variables in the same order.
                for all_vars in zip(*[
                        device.grad_clean.keys()
                        for device in self._devices.values()
                ]):
                    if len(all_vars) > 0 and all(
                            dim > 0 for dim in all_vars[0].shape.as_list()
                    ):  # NCCL does not support zero-sized tensors.
                        all_grads = [
                            device.grad_clean[var] for device, var in zip(
                                self._devices.values(), all_vars)
                        ]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(),
                                                     all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" %
                                            device_idx), tf.device(
                                                device.name):
                # pylint: disable=cell-var-from-loop
                # The lambdas below capture loop variables on purpose: tf.cond
                # calls its branch functions while the graph is being built,
                # i.e. within the current iteration.

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every step is a complete step.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    # Created outside any enclosing control dependencies: one
                    # zero-initialized non-trainable accumulator per variable,
                    # plus a step counter.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(
                                tf.zeros(var.shape),
                                trainable=False,
                                name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(
                            tf.zeros([]),
                            trainable=False,
                            name="grad_acc_count")

                    # Track counter.
                    # acc_ok becomes true once this step brings the count up to
                    # minibatch_multiplier; the counter then resets to zero.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count,
                                                     count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count,
                                                       tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier,
                                                   tf.float32))
                    all_ops.append(
                        tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # grad_acc holds the running sum including this step. The
                    # stored accumulator is cleared when acc_ok and updated
                    # otherwise.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(
                                acc_var, tf.zeros(var.shape))
                            all_ops.append(
                                tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires both a completed accumulation step and finite
                # values in every accumulated gradient.
                all_ok = tf.reduce_all(
                    tf.stack([acc_ok] + [
                        tf.reduce_all(tf.is_finite(g))
                        for g in device.grad_acc.values()
                    ]))
                # Gradients are cast back to each variable's own dtype.
                apply_op = lambda: device.optimizer.apply_gradients(
                    [(tf.cast(grad, var.dtype), var)
                     for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only at accumulation boundaries: add loss_scaling_inc when the
                # gradients were finite, otherwise subtract loss_scaling_dec. The
                # variable appears to hold a log2 value, judging by the
                # "loss_scaling_log2" summary below.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var,
                                                      self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var,
                                                      self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(
                        tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                # overflow_frequency records 1 for a skipped (non-finite)
                # update and 0 otherwise, and only on completed accumulation
                # steps.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(
                        autosummary.autosummary(self.id + "/learning_rate",
                                                self.learning_rate))
                    all_ops.append(
                        autosummary.autosummary(self.id +
                                                "/overflow_frequency",
                                                tf.where(all_ok, 0, 1),
                                                condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(
                            autosummary.autosummary(
                                self.id + "/loss_scaling_log2",
                                device.loss_scaling_var))

        # Initialize variables.
        # Reset optimizer slots, then initialize the loss-scaling and
        # accumulator variables created by this optimizer.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars(
                [device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([
                var.initializer for device in self._devices.values()
                for var in list(device.grad_acc_vars.values()) +
                [device.grad_acc_count]
            ])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")
Example No. 7
0
File: loss.py  Project: johndpope/BA
def G_logistic_ns_pathreg(G,
                          D,
                          opt,
                          training_set,
                          minibatch_size,
                          pl_minibatch_shrink=2,
                          pl_decay=0.01,
                          pl_weight=2.0):
    """Non-saturating logistic generator loss with path length regularization.

    Args:
        G: generator network; must support ``return_dlatents=True``.
        D: discriminator network.
        opt: unused; kept for a uniform loss-function signature.
        training_set: source of random labels via ``get_random_labels_tf``.
        minibatch_size: number of samples for the main loss term (Python int
            or integer tensor).
        pl_minibatch_shrink: when > 1, the regularizer is evaluated on a
            separate, freshly sampled minibatch of
            ``max(minibatch_size // pl_minibatch_shrink, 1)`` samples to save
            memory. Otherwise the main minibatch is reused.
        pl_decay: update rate of the exponential moving average of path lengths.
        pl_weight: multiplier applied to the per-sample penalty.

    Returns:
        ``(loss, reg)``. ``loss = softplus(-D(G(z)))`` per sample.
        ``reg = pl_weight * (|J^T y| - a)^2``, where ``a`` is the moving average
        of ``|J^T y|`` stored in the non-trainable variable ``pl_mean``.
    """
    _ = opt
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out, fake_dlatents_out = G.get_output_for(latents,
                                                          labels,
                                                          is_training=True,
                                                          return_dlatents=True)
    fake_scores_out = D.get_output_for(fake_images_out,
                                       labels,
                                       is_training=True)
    loss = tf.nn.softplus(-fake_scores_out)  # -log(sigmoid(fake_scores_out))

    # Path length regularization.
    with tf.name_scope('PathReg'):

        # Evaluate the regularization term using a smaller minibatch to conserve memory.
        if pl_minibatch_shrink > 1:
            # Clamp to at least one sample. An empty batch would make
            # reduce_mean(pl_lengths) NaN, and that NaN would then be written
            # into pl_mean_var and persist across all later steps.
            pl_minibatch = tf.maximum(minibatch_size // pl_minibatch_shrink, 1)
            pl_latents = tf.random_normal([pl_minibatch] +
                                          G.input_shapes[0][1:])
            pl_labels = training_set.get_random_labels_tf(pl_minibatch)
            fake_images_out, fake_dlatents_out = G.get_output_for(
                pl_latents, pl_labels, is_training=True, return_dlatents=True)

        # Compute |J^T y| for random image-space noise y.
        pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(
            np.prod(G.output_shape[2:]))
        pl_grads = tf.gradients(tf.reduce_sum(fake_images_out * pl_noise),
                                [fake_dlatents_out])[0]
        pl_lengths = tf.sqrt(
            tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
        pl_lengths = autosummary('Loss/pl_lengths', pl_lengths)

        # Track exponential moving average of |J^T y|.
        # The variable is created outside any enclosing control dependencies.
        with tf.control_dependencies(None):
            pl_mean_var = tf.Variable(name='pl_mean',
                                      trainable=False,
                                      initial_value=0.0,
                                      dtype=tf.float32)
        pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) -
                                            pl_mean_var)
        pl_update = tf.assign(pl_mean_var, pl_mean)

        # Calculate (|J^T y| - a)^2, making sure the moving average is updated.
        with tf.control_dependencies([pl_update]):
            pl_penalty = tf.square(pl_lengths - pl_mean)
            pl_penalty = autosummary('Loss/pl_penalty', pl_penalty)

        # Apply weight.
        #
        # Note: The division in pl_noise decreases the weight by num_pixels, and the reduce_mean
        # in pl_lengths decreases it by num_affine_layers. The effective weight then becomes:
        #
        # gamma_pl = pl_weight / num_pixels / num_affine_layers
        #          = 2 / (r^2) / (log2(r) * 2 - 2)
        #          = 1 / (r^2 * (log2(r) - 1))
        #          = ln(2) / (r^2 * (ln(r) - ln(2)))
        #
        reg = pl_penalty * pl_weight

    return loss, reg
Example No. 8
0
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False
):  # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = RunContext.get().get_time_since_last_update()
            total_time = RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                %
                (autosummary('Progress/tick', cur_tick),
                 autosummary('Progress/kimg', cur_nimg / 1000.0),
                 autosummary('Progress/lod', sched.lod),
                 autosummary('Progress/minibatch', sched.minibatch_size),
                 util.format_time(autosummary('Timing/total_sec', total_time)),
                 autosummary('Timing/sec_per_tick', tick_time),
                 autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                 autosummary('Timing/maintenance_sec', maintenance_time),
                 autosummary('Resources/peak_gpu_mem_gb',
                             peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     make_run_dir_path('fakes%06d.png' %
                                                       (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = make_run_dir_path('network-snapshot-%06d.pkl' %
                                        (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=make_run_dir_path(),
                            data_dir=convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            RunContext.get().update('%.2f' % sched.lod,
                                    cur_epoch=cur_nimg // 1000,
                                    max_epoch=total_kimg)
            maintenance_time = RunContext.get().get_last_update_interval(
            ) - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()