Beispiel #1
0
 def test_upsample_bilinear_inverted_by_bilinear(self):
     """Bilinear downsampling should approximately undo bilinear upsampling."""
     import sys  # local import: only needed for the print threshold below

     # Ramp of values in [0, 1) laid out as a [2, 8, 8, 3] tensor.
     num_values = 2 * 8 * 8 * 3
     test_input = tf.reshape(
         tf.constant(np.arange(0, num_values) / num_values,
                     dtype=tf.float32), [2, 8, 8, 3])
     up_x = upsample(test_input, "bilinear")
     down_x = downsample(up_x, "bilinear")
     # Print arrays in full if the assertion fails. A NaN threshold is
     # rejected by current NumPy releases; sys.maxsize is the documented way
     # to disable summarisation.
     np.set_printoptions(threshold=sys.maxsize, suppress=True)
     self.assertAllClose(down_x, test_input, atol=.02)
Beispiel #2
0
 def test_upsample_nn_inverted_by_avg_pool(self):
     """Nearest-neighbour upsampling followed by 2x2 average pooling returns the input."""
     random_values = np.random.normal(0., 1., size=[2, 4, 4, 3])
     test_input = tf.constant(random_values, dtype=tf.float32)

     upsampled = upsample(test_input, "nearest_neighbor")

     # Non-overlapping 2x2 spatial windows: kernel and stride are the same.
     window = [1, 2, 2, 1]
     pooled = tf.nn.avg_pool(upsampled, ksize=window, strides=window, padding='SAME')

     self.assertAllEqual(pooled, test_input)
Beispiel #3
0
 def test_upsample_nn(self):
     """Nearest-neighbour upsampling repeats every pixel of a 2x2 map into a 2x2 patch."""
     batch, channels = 2, 3
     spatial_in = [[0., 1.], [2., 3.]]
     spatial_out = [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [2., 2., 3., 3.],
                    [2., 2., 3., 3.]]

     # The fixtures are built as b, c, h, w; the op is fed b, h, w, c.
     nchw_input = tf.constant([[spatial_in] * channels] * batch, dtype=tf.float32)
     nhwc_input = tf.transpose(nchw_input, (0, 2, 3, 1))

     upsampled = upsample(nhwc_input, method='nearest_neighbor')

     expected = tf.constant([[spatial_out] * channels] * batch)  # b, c, h, w
     upsampled_nchw = tf.transpose(upsampled, (0, 3, 1, 2))  # back to b, c, h, w
     self.assertAllEqual(upsampled_nchw, expected)
Beispiel #4
0
 def test_upsample_bilinear(self):
     """Bilinear upsampling of a 2x2 map matches a reference 4x4 result within tolerance."""
     batch, channels = 2, 3
     spatial_in = [[0., .1], [.2, .3]]
     # Reference values from skimage.transform.resize (mode='edge'); these differ
     # slightly from tf.image.resize_bilinear, hence the tolerance below.
     spatial_out = [[0., 0.025, 0.075, 0.1],
                    [0.05, 0.075, 0.125, 0.15],
                    [0.15, 0.175, 0.225, 0.25],
                    [0.2, 0.225, 0.275, 0.3]]

     # The fixtures are built as b, c, h, w; the op is fed b, h, w, c.
     nchw_input = tf.constant([[spatial_in] * channels] * batch, dtype=tf.float32)
     nhwc_input = tf.transpose(nchw_input, (0, 2, 3, 1))

     upsampled = upsample(nhwc_input, method='bilinear')

     expected = tf.constant([[spatial_out] * channels] * batch)  # b, c, h, w
     upsampled_nchw = tf.transpose(upsampled, (0, 3, 1, 2))  # back to b, c, h, w
     self.assertAllClose(upsampled_nchw, expected, atol=.02)
    def __init__(self,
                 sig,
                 sampling_rate,
                 r,
                 sess,
                 inputs,
                 predictions,
                 interp=True):
        """Cut a signal into target patches and matching low-resolution patches.

        :param sig: signal to patch; it is sliced, passed to ``len`` and its
            ``shape[0]`` is read (presumably a 1-D numpy array -- confirm
            against the caller).
        :param sampling_rate: stored as ``self.rate``.
        :param r: scaling factor handed to ``decimate`` and ``upsample``.
        :param sess: stored as ``self.sess``.
        :param inputs: stored as ``self.inputs``; ``inputs[0].shape[1]`` is
            used as the patch length.
        :param predictions: stored as ``self.predictions``.
        :param interp: when true, the decimated signal is upsampled back to
            the original length before being split into patches.
        """
        self.sess = sess
        self.rate = sampling_rate
        self.predictions = predictions
        self.inputs = inputs

        patch_len = int(inputs[0].shape[1])

        # Keep only a whole number of patches.
        usable = int(np.floor(len(sig) / patch_len) * patch_len)
        sig = sig[:usable]
        # Then trim to a multiple of r. E.g. scaling factor 2: an even-length
        # signal is left untouched, an odd-length one loses its last sample.
        leftover = len(sig) % r
        sig = sig[:len(sig) - leftover]

        low_res = decimate(sig, r)
        if interp:
            low_res = upsample(low_res, r)
            assert len(low_res) == len(sig)

        num_y = int(sig.shape[0] / patch_len)
        # Without interpolation the low-resolution signal is split into
        # num_y / r chunks instead of num_y.
        num_x = num_y if interp else int(num_y / r)

        # Generate patches, adding a trailing axis to each set.
        self.Y = np.expand_dims(chunkIt(sig, num_y), axis=-1)
        self.X_lr = np.expand_dims(chunkIt(low_res, num_x), axis=-1)

        self.batches = (self.X_lr, self.Y)
Beispiel #6
0
    def call(self, alpha, zs=None, intermediate_ws=None, mapping_network=None, cgan_w=None,
             crossover_list=None, random_crossover=False):
        """
        Run the synthesis network and return the generated image tensor.

        :param alpha: blend weight for the output,
            ``lower_rgb + alpha * (rgb - lower_rgb)``, where ``lower_rgb`` is the upsampled
            RGB output taken at half the model resolution.
        :param zs: latent tensor or list of latent tensors, each shaped [batch size, z dim]
            or [batch size, 1, 1, z dim]. Not read when ``intermediate_ws`` is given.
        :param intermediate_ws: intermediate latent tensor or list of them; when given the
            mapping network is skipped.
        :param mapping_network: callable turning a flat z into an intermediate latent.
        :param cgan_w: conditioning tensor; concatenated to z before the mapping network
            when ``self.map_cond`` is set.
        :param crossover_list: layer indices; from each listed index onward the next
            intermediate latent in the list is used for the styles.
        :param random_crossover: when true, a crossover layer is drawn per sample and the
            styles switch from the first to the second intermediate latent at that layer.
        :return: the generated image tensor (or the initial feature map when
            ``self.model_res_w == 1``).
        :raises ValueError: when no latent is supplied, when a required mapping network is
            missing, when several latents are given without a crossover, when random
            crossover is requested with fewer than two latents, or when there is neither a
            learned input nor a z to start from.
        """
        intermediate_mode = intermediate_ws is not None
        mixing_mode = isinstance(zs, list) or isinstance(intermediate_ws, list)
        style_mixing = random_crossover or crossover_list is not None
        if zs is None and intermediate_ws is None:
            raise ValueError("Need z or intermediate")
        if self.use_mapping_network and (mapping_network is None and intermediate_ws is None):
            raise ValueError("No mapping network supplied to generator call")

        # Normalise single tensors to one-element lists.
        if not mixing_mode:
            if intermediate_mode:
                intermediate_ws = [intermediate_ws]
            else:
                zs = [zs]

        z = None  # stays None when the caller supplied intermediate latents directly
        if not intermediate_mode:
            intermediate_ws = []
            for z in zs:
                z_shape = z.get_shape().as_list()
                if self.use_pixel_norm:
                    z = pixel_norm(z)  # todo: verify correct
                if len(z_shape) == 2:  # [batch size, z dim]
                    if self.map_cond and cgan_w is not None:
                        z = tf.concat([z, cgan_w], -1)
                    intermediate_latent = mapping_network(z)
                    # Expand to [batch size, 1, 1, z dim] so it can seed the feature map.
                    z = tf.expand_dims(z, 1)
                    z = tf.expand_dims(z, 1)
                else:  # [batch size, 1, 1, z dim]
                    z_flat = tf.squeeze(z, axis=[1, 2])
                    if self.map_cond and cgan_w is not None:
                        z_flat = tf.concat([z_flat, cgan_w], -1)
                    intermediate_latent = mapping_network(z_flat)
                intermediate_ws.append(intermediate_latent)

        if len(intermediate_ws) > 1 and not style_mixing:
            raise ValueError("Need crossover for mixing mode")
        if random_crossover and len(intermediate_ws) < 2:
            raise ValueError("Random crossover needs two latents")

        # NOTE: per-layer conditioning on cgan_w is disabled; styles always come from the
        # intermediate latents. The former computation of a conditioned latent here was
        # immediately discarded and failed with an unbound name when intermediate_ws was
        # passed together with cgan_w, so it has been removed.

        batch_size = tf.shape(intermediate_ws[0])[0]
        if self.learned_input is not None:
            z = tf.expand_dims(self.learned_input(None), axis=0)
            x = tf.tile(z, [batch_size, 1, 1, 1])
        elif z is None:
            raise ValueError("Need z when the generator has no learned input")
        else:
            x = tf.pad(z, [[0, 0], [3, 3], [3, 3], [0, 0]])
        if self.model_res_w == 1:  # for testing purposes
            return x
        current_res = self.start_shape[1]
        num_layers = len(self.model_layers)

        # Inefficient implementation, will have to redo
        with tf.name_scope("style_mixing"):
            intermediate_for_layer_list = []
            if random_crossover:
                # One-hot crossover layer per sample, transposed to [num_layers, batch_size].
                intermediate_mixing_schedule = tf.random.uniform([batch_size], 0, num_layers, dtype=tf.int32)
                intermediate_mixing_schedule = tf.transpose(
                    tf.one_hot(intermediate_mixing_schedule, depth=num_layers, dtype=tf.int32))
                intermediate_multiplier_for_current_layer = tf.zeros([batch_size], dtype=tf.int32)
                for i in range(num_layers):
                    # Once a sample's crossover layer is reached the flag stays set, so
                    # all later layers use the second latent.
                    intermediate_multiplier_for_current_layer = tf.bitwise.bitwise_or(
                        intermediate_multiplier_for_current_layer,
                        intermediate_mixing_schedule[i])
                    intermediate_multiplier = tf.cast(intermediate_multiplier_for_current_layer,
                                                      dtype=tf.float32)
                    intermediate_multiplier = tf.expand_dims(intermediate_multiplier, 1)
                    intermediate_for_layer_list.append(
                        (1 - intermediate_multiplier) * intermediate_ws[0] +
                        intermediate_multiplier * intermediate_ws[1])
            elif crossover_list is not None:
                # `is not None` (not truthiness) keeps this in step with `style_mixing`:
                # an empty list must still yield one entry per layer.
                for i in range(num_layers):
                    intermediate_index = sum(1 for c in crossover_list if i >= c)
                    intermediate_for_layer_list.append(intermediate_ws[intermediate_index])

        to_rgb_lower = 0.
        for layer_counter, (conv1, noise1, bias1, tostyle1,
                            conv2, noise2, bias2, tostyle2) in enumerate(self.model_layers):
            if style_mixing:
                style_latent = intermediate_for_layer_list[layer_counter]
            else:
                style_latent = intermediate_ws[0]

            with tf.name_scope("Res%d" % current_res):
                if self.include_fmap_add_ops:
                    # Adds zeros under a resolution-tagged op name. The shape must be
                    # passed directly (not wrapped in a list) and the name formatted.
                    x += tf.zeros(tf.shape(x), dtype=tf.float32, name="FmapRes%d" % current_res)
                # The learned input is already a feature map, so the first conv is skipped.
                if layer_counter != 0 or self.learned_input is None:
                    x = conv1(x)
                if self.add_noise:
                    with tf.name_scope("noise_add1"):
                        noise_inputs = noise1(False)
                        assert(x.get_shape().as_list()[1:] == noise_inputs.get_shape().as_list()[1:])
                        x += noise_inputs
                x = bias1(x)
                x = tf.nn.leaky_relu(x, alpha=.2)
                if self.use_pixel_norm:
                    x = pixel_norm(x)
                ys, yb = tostyle1(style_latent)
                x = adaptive_instance_norm(x, ys, yb)

                x = conv2(x)
                if self.use_pixel_norm:
                    x = pixel_norm(x)
                if self.add_noise:
                    with tf.name_scope("noise_add2"):
                        noise_inputs = noise2(False)
                        assert(x.get_shape().as_list()[1:] == noise_inputs.get_shape().as_list()[1:])
                        x += noise_inputs
                x = bias2(x)
                x = tf.nn.leaky_relu(x, alpha=.2)
                # The second modulation uses its own style layer (previously tostyle1 was
                # reused here, leaving tostyle2 unused).
                ys, yb = tostyle2(style_latent)
                x = adaptive_instance_norm(x, ys, yb)

                if current_res == self.model_res_w // 2:
                    to_rgb_lower = upsample(self.toRGB_lower(x), method=self.resize_method)
                if current_res != self.model_res_w:
                    x = upsample(x, method=self.resize_method)
                current_res *= 2
        to_rgb = self.toRGB(x)
        output = to_rgb_lower + alpha * (to_rgb - to_rgb_lower)
        if self.output_res_w // self.model_res_w >= 2:
            output = upsample(output, method='nearest_neighbor',
                              factor=self.output_res_w // self.model_res_w)
        return output
Beispiel #7
0
def train(hps, files):
    """Progressively train the generator and discriminator, one resolution phase per loop pass.

    Each pass of the outer loop builds a fresh graph at ``current_res_w``, restores any
    saved models/optimizers/alpha/step, trains for ``hps.epochs_per_res`` epochs, saves,
    and resets the default graph. The resolution is doubled only after a phase that ran
    with a fixed alpha (``alpha_inc == 0``).

    :param hps: hyper-parameter object; attributes such as ``ngpus``, ``batch_size``,
        ``res_w``, ``label_file``, ``save_paths`` and ``model_dir`` are read from it.
        ``batch_size`` is either an int-like value or a literal mapping resolution to
        batch size.
    :param files: dataset file names. If any name contains ``.tfrecords`` the dataset
        size is counted from the records of files whose name contains
        ``res<current_res_w>``; otherwise it is ``len(files)``.
    """
    ngpus = hps.ngpus
    config = tf.ConfigProto()
    hvd = None
    if ngpus > 1:
        try:
            import horovod.tensorflow as hvd
        except ImportError:
            hvd = None
            print("horovod not available, can only use 1 gpu")
            ngpus = 1
        else:
            # Horovod must be initialised before local_rank()/rank() may be queried.
            hvd.init()
            config.gpu_options.visible_device_list = str(hvd.local_rank())

    # todo: organize
    current_res_w = hps.current_res_w
    res_multiplier = current_res_w // hps.start_res_w
    current_res_h = hps.start_res_h * res_multiplier

    tfrecord_input = any('.tfrecords' in fname for fname in files)
    # if using tfrecord, assume dataset is duplicated across multiple resolutions
    if tfrecord_input:
        num_files = 0
        for fname in [fname for fname in files if "res%d" % current_res_w in fname]:
            for record in tf.compat.v1.python_io.tf_record_iterator(fname):
                num_files += 1
    else:
        num_files = len(files)

    label_list = []
    total_classes = 0
    if hps.label_file:
        do_cgan = True
        label_list, total_classes = build_label_list_from_file(hps.label_file)
    else:
        do_cgan = False

    print("dataset has %d files" % num_files)
    # batch_size is either a single integer or a literal resolution -> batch size mapping.
    try:
        batch_size = int(hps.batch_size)
        try_schedule = False
    except ValueError:
        try_schedule = True
    if try_schedule:
        batch_schedule = ast.literal_eval(hps.batch_size)
    else:
        batch_schedule = None

    #  always generate 32 sample images (should be feasible at high resolutions due to no training)
    #  will probably need to edit for > 128x128
    sample_batch = 32
    sample_latent_numpy = np.random.normal(0., 1., [sample_batch, 512])

    if do_cgan:
        # One-hot class vectors for the sample grid; the first class also takes the
        # remainder so the total is exactly sample_batch.
        examples_per_class = sample_batch // total_classes
        remainder = sample_batch % total_classes
        sample_cgan_latent_numpy = None
        for i in range(0, total_classes):
            class_vector = [0.] * total_classes
            class_vector[i] = 1.
            if sample_cgan_latent_numpy is None:
                sample_cgan_latent_numpy = [class_vector] * (examples_per_class + remainder)
            else:
                sample_cgan_latent_numpy += [class_vector] * examples_per_class
        sample_cgan_latent_numpy = np.array(sample_cgan_latent_numpy)

    use_beholder = hps.use_beholder
    if use_beholder:
        try:
            from tensorboard.plugins.beholder import Beholder
        except ImportError:
            print("Could not import beholder")
            use_beholder = False
    while current_res_w <= hps.res_w:
        print("building graph")
        if batch_schedule is not None:
            batch_size = batch_schedule[current_res_w]
            print("res %d batch size is now %d" % (current_res_w, batch_size))
        gen_model, mapping_network, dis_model, sampling_model = \
            build_models(hps,
                         current_res_w,
                         use_ema_sampling=True,
                         num_classes=total_classes,
                         label_list=label_list if hps.conditional_type == "acgan" else None)
        with tf.name_scope("optimizers"):
            optimizer_d, optimizer_g, optimizer_m = build_optimizers(hps)
            if ngpus > 1:
                optimizer_d = hvd.DistributedOptimizer(optimizer_d)
                optimizer_g = hvd.DistributedOptimizer(optimizer_g)
                optimizer_m = hvd.DistributedOptimizer(optimizer_m)
        with tf.name_scope("data"):
            num_shards = None if ngpus == 1 else ngpus
            shard_index = None if ngpus == 1 else hvd.rank()
            it = build_data_iterator(hps, files, current_res_h, current_res_w, batch_size, label_list=label_list,
                                     num_shards=num_shards, shard_index=shard_index)
            next_batch = it.get_next()
            real_image = next_batch['data']

            fake_latent1 = tf.random_normal([batch_size, 512], 0., 1., name="fake_latent")
            fake_latent2 = tf.random_normal([batch_size, 512], 0., 1., name="fake_latent")

            fake_label_dict = None
            real_label_dict = None
            if do_cgan:
                fake_label_dict = {}
                real_label_dict = {}
                for label in label_list:
                    # Logits for sampling fake labels: constant (uniform) or log of the
                    # label's own probabilities.
                    if hps.cond_uniform_fake:
                        distribution = np.ones_like([label.probabilities])
                    else:
                        distribution = np.log([label.probabilities])
                    fake_labels = tf.random.categorical(distribution, batch_size)
                    if label.multi_dim is False:
                        # Single-dimension label: min-max scale the sampled indices.
                        normalized_labels = (fake_labels - tf.reduce_min(fake_labels)) / \
                                            (tf.reduce_max(fake_labels) - tf.reduce_min(fake_labels))
                        fake_labels = tf.reshape(normalized_labels, [batch_size, 1])
                    else:
                        fake_labels = tf.reshape(tf.one_hot(fake_labels, label.num_classes),
                                                 [batch_size, label.num_classes])
                    fake_label_dict[label.name] = fake_labels
                    real_label_dict[label.name] = next_batch[label.name]
                    # ideally would handle one dimensional labels differently, theory isn't well supported
                    # for that though (example: categorical values of short, medium, tall are on one dimension)
                fake_label_tensor = tf.concat([fake_label_dict[l] for l in fake_label_dict.keys()], axis=-1)
                real_label_tensor = tf.concat([real_label_dict[l] for l in real_label_dict.keys()], axis=-1)
            sample_latent = tf.constant(sample_latent_numpy, dtype=tf.float32, name="sample_latent")
            if do_cgan:
                sample_cgan_w = tf.constant(sample_cgan_latent_numpy, dtype=tf.float32, name="sample_cgan_latent")
            alpha_ph = tf.placeholder(shape=(), dtype=tf.float32, name="alpha")
            #  From Fig 2: "During a resolution transition,
            #  we interpolate between two resolutions of the real images"
            real_image = real_image*alpha_ph + \
                (1-alpha_ph)*upsample(downsample_nv(real_image),
                              method="nearest_neighbor")
            real_image = upsample(real_image, method='nearest_neighbor', factor=hps.res_w//current_res_w)
        if do_cgan:
            with tf.name_scope("gen_synthesis"):
                fake_image = gen_model(alpha_ph, zs=[fake_latent1, fake_latent2], mapping_network=mapping_network,
                                       cgan_w=fake_label_tensor, random_crossover=True)
            real_logit, real_class_logits = dis_model(real_image, alpha_ph,
                                                      real_label_tensor if hps.conditional_type == "proj" else
                                                      None)
            fake_logit, fake_class_logits = dis_model(fake_image, alpha_ph,
                                                      fake_label_tensor if hps.conditional_type == "proj" else
                                                      None)
        else:
            with tf.name_scope("gen_synthesis"):
                fake_image = gen_model(alpha_ph, zs=[fake_latent1, fake_latent2], mapping_network=mapping_network,
                                       random_crossover=True)
            real_logit, real_class_logits = dis_model(real_image, alpha_ph)  # todo: make work with other labels
            fake_logit, fake_class_logits = dis_model(fake_image, alpha_ph)

        with tf.name_scope("gen_sampling"):
            # Mean intermediate latent over 10000 random z's; low_psi pulls the sample
            # latents towards that mean.
            average_latent = tf.constant(np.random.normal(0., 1., [10000, 512]), dtype=tf.float32)
            low_psi = 0.20
            if hps.map_cond:
                class_vector = [0.] * total_classes
                class_vector[0] = 1.  # one hot encoding
                average_w = tf.reduce_mean(mapping_network(tf.concat([average_latent,
                                                                      [class_vector]*10000], axis=-1)), axis=0)
                sample_latent_lowpsi = average_w + low_psi * \
                                       (mapping_network(tf.concat([sample_latent,
                                                                   [class_vector]*sample_batch], axis=-1)) - average_w)
            else:
                average_w = tf.reduce_mean(mapping_network(average_latent), axis=0)
                sample_latent_lowpsi = average_w + low_psi * (mapping_network(sample_latent) - average_w)
            average_w_batch = tf.tile(tf.reshape(average_w, [1, 512]), [sample_batch, 1])
            if do_cgan:
                sample_img_lowpsi = sampling_model(alpha_ph, intermediate_ws=sample_latent_lowpsi,
                                                   cgan_w=sample_cgan_w)
                sample_img_base = sampling_model(alpha_ph, zs=sample_latent, mapping_network=mapping_network,
                                                 cgan_w=sample_cgan_w)
                sample_img_mode = sampling_model(alpha_ph, intermediate_ws=average_w_batch,
                                                 cgan_w=sample_cgan_w)
                # Two images from the start and two from the end of the batch. The slices
                # must be separate list items: `+` on tensors adds them element-wise
                # instead of concatenating.
                sample_img_mode = tf.concat([sample_img_mode[0:2], sample_img_mode[-3:-1]], axis=0)
            else:
                sample_img_lowpsi = sampling_model(alpha_ph, intermediate_ws=sample_latent_lowpsi)
                sample_img_base = sampling_model(alpha_ph, zs=sample_latent, mapping_network=mapping_network)
                sample_img_mode = sampling_model(alpha_ph, intermediate_ws=average_w_batch)[0:4]
            sample_images = tf.concat([sample_img_lowpsi, sample_img_mode, sample_img_base], axis=0)
            sampling_model_init_ops = weight_following_ema_ops(average_model=sampling_model,
                                                               reference_model=gen_model)

        with tf.name_scope("loss"):
            loss_discriminator, loss_generator = hps.loss_fn(real_logit, fake_logit)
            if real_class_logits is not None:
                for label in label_list:
                    label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=next_batch[label.name],
                                                                         logits=real_class_logits[label.name])
                    loss_discriminator += label_loss * hps.cond_weight * 1./(len(label_list))
                    tf.summary.scalar("label_loss_real", tf.reduce_mean(label_loss))
            if fake_class_logits is not None:
                for label in label_list:
                    label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=fake_label_dict[label.name],
                                                                         logits=fake_class_logits[label.name])
                    loss_discriminator += label_loss * hps.cond_weight * 1./(len(label_list))
                    tf.summary.scalar("label_loss_fake", tf.reduce_mean(label_loss))

                    loss_generator += label_loss * hps.cond_weight * 1./(len(label_list))
            if hps.gp_fn:
                gp = hps.gp_fn(fake_image, real_image, dis_model, alpha_ph, real_label_dict,
                               conditional_type=hps.conditional_type)
                tf.summary.scalar("gradient_penalty", tf.reduce_mean(gp))
                loss_discriminator += hps.lambda_gp*gp
            dp = drift_penalty(real_logit)
            tf.summary.scalar("drift_penalty", tf.reduce_mean(dp))
            if hps.lambda_drift != 0.:
                loss_discriminator = tf.expand_dims(loss_discriminator, -1) + hps.lambda_drift * dp

            loss_discriminator_avg = tf.reduce_mean(loss_discriminator)
            loss_generator_avg = tf.reduce_mean(loss_generator)
        with tf.name_scope("train"):
            train_step_d = optimizer_d.minimize(loss_discriminator_avg, var_list=dis_model.trainable_variables)
            # todo: test this
            # The ops from weight_following_ema_ops are made a control dependency of
            # the generator train step.
            with tf.control_dependencies(weight_following_ema_ops(average_model=sampling_model,
                                                                  reference_model=gen_model)):
                train_step_g = [optimizer_g.minimize(loss_generator_avg, var_list=gen_model.trainable_variables)]
            if hps.do_mapping_network:
                train_step_g.append(
                    optimizer_m.minimize(loss_generator_avg, var_list=mapping_network.trainable_variables))
            # With hps.no_train the graph is still evaluated but no update is applied.
            # Built once here rather than creating a new no-op on every step.
            run_step_d = train_step_d if not hps.no_train else tf.no_op()
            run_step_g = train_step_g if not hps.no_train else tf.no_op()
        with tf.name_scope("summary"):
            tf.summary.histogram("real_scores", real_logit)
            tf.summary.scalar("loss_discriminator", loss_discriminator_avg)
            tf.summary.scalar("loss_generator", loss_generator_avg)
            tf.summary.scalar("real_logit", tf.reduce_mean(real_logit))
            tf.summary.scalar("fake_logit", tf.reduce_mean(fake_logit))
            tf.summary.histogram("real_logit", real_logit)
            tf.summary.histogram("fake_logit", fake_logit)
            tf.summary.scalar("alpha", alpha_ph)
            merged = tf.summary.merge_all()
            image_summary_real = generate_image_summary(real_image, "real")
            image_summary_fake_avg = generate_image_summary(sample_images, "fake_avg")
        global_step = tf.train.get_or_create_global_step()
        if hps.profile:
            builder = tf.profiler.ProfileOptionBuilder
            opts = builder(builder.time_and_memory()).order_by('micros').build()

        # Empty trace/dump step lists: profiling happens only when requested explicitly
        # through trace_next_step()/dump_next_step() below.
        with tf.contrib.tfprof.ProfileContext(hps.model_dir,
                                              trace_steps=[],
                                              dump_steps=[]) as pctx:
            with tf.Session(config=config) as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(sampling_model_init_ops)
                alpha = 1.
                step = 0
                if os.path.exists(hps.save_paths.gen_model) and os.path.exists(hps.save_paths.dis_model):
                    if ngpus == 1 or hvd.rank() == 0:
                        print("restoring")
                        restore_models_and_optimizers(sess, gen_model, dis_model, mapping_network,
                                                      sampling_model,
                                                      optimizer_g, optimizer_d, optimizer_m, hps.save_paths)
                if os.path.exists(hps.save_paths.alpha) and os.path.exists(hps.save_paths.step):
                    alpha, step = restore_alpha_and_step(hps.save_paths)

                print("alpha")
                print(alpha)

                # alpha != 1 marks a transition phase: ramp alpha up to 1 over the phase.
                # alpha == 1 marks a phase at fixed alpha.
                if alpha != 1.:
                    alpha_inc = 1. / (hps.epochs_per_res * (num_files / batch_size))
                else:
                    alpha_inc = 0.
                writer_path = \
                    os.path.join(hps.model_dir, "summary_%d" % current_res_w, "alpha_start_%d" % alpha)
                if use_beholder:
                    beholder = Beholder(writer_path)
                writer = tf.summary.FileWriter(writer_path, sess.graph)
                writer.add_summary(image_summary_real.eval(feed_dict={alpha_ph: alpha}), step)
                print("Starting res %d training" % current_res_w)
                t = trange(hps.epochs_per_res * num_files // batch_size, desc='Training')

                if ngpus > 1:
                    sess.run(hvd.broadcast_global_variables(0))
                for phase_step in t:
                    try:
                        # hps.ncritic discriminator updates per generator update.
                        for _ in range(0, hps.ncritic):
                            if hps.profile:
                                pctx.trace_next_step()
                                pctx.dump_next_step()
                            if step % 5 == 0:
                                summary, ld, _ = sess.run([merged,
                                                           loss_discriminator_avg,
                                                           run_step_d],
                                                          feed_dict={alpha_ph: alpha})
                                writer.add_summary(summary, step)
                            else:
                                ld, _ = sess.run([loss_discriminator_avg,
                                                  run_step_d],
                                                 feed_dict={alpha_ph: alpha})
                            if hps.profile:
                                pctx.profiler.profile_operations(options=opts)
                        if hps.profile:
                            pctx.trace_next_step()
                            pctx.dump_next_step()
                        lg, _ = sess.run([loss_generator_avg,
                                          run_step_g],
                                         feed_dict={alpha_ph: alpha})
                        if hps.profile:
                            pctx.profiler.profile_operations(options=opts)
                        alpha = min(alpha+alpha_inc, 1.)

                        t.set_description('Overall step %d, loss d %f, loss g %f' % (step+1, ld, lg))
                        if use_beholder:
                            try:
                                beholder.update(session=sess)
                            except Exception as e:
                                # Beholder is optional: disable it and keep training.
                                print("Beholder failed: " + str(e))
                                use_beholder = False

                        # Sample images densely at the start of a phase, then every 1000 steps.
                        if phase_step < 5 or (phase_step < 500 and phase_step % 10 == 0) or (step % 1000 == 0):
                            writer.add_summary(image_summary_fake_avg.eval(
                                feed_dict={alpha_ph: alpha}), step)
                        if hps.steps_per_save is not None and step % hps.steps_per_save == 0 and (ngpus == 1 or hvd.rank() == 0):
                            save_models_and_optimizers(sess,
                                                       gen_model, dis_model, mapping_network,
                                                       sampling_model,
                                                       optimizer_g, optimizer_d, optimizer_m,
                                                       hps.save_paths)
                            save_alpha_and_step(1. if alpha_inc != 0. else 0., step, hps.save_paths)
                        step += 1
                    except tf.errors.OutOfRangeError:
                        # Data iterator exhausted: end this phase.
                        break
                assert (abs(alpha - 1.) < .1), "Alpha should be close to 1., not %f" % alpha  # alpha close to 1. (dataset divisible by batch_size for small sets)
                # The saved alpha is the start value for the next phase: 1. after a
                # transition phase, 0. after a fixed-alpha phase.
                if ngpus == 1 or hvd.rank() == 0:
                    print(1. if alpha_inc != 0. else 0.)
                    save_models_and_optimizers(sess,
                                               gen_model, dis_model, mapping_network, sampling_model,
                                               optimizer_g, optimizer_d, optimizer_m,
                                               hps.save_paths)
                    backup_model_for_this_phase(hps.save_paths, writer_path)
                save_alpha_and_step(1. if alpha_inc != 0. else 0., step, hps.save_paths)
                #  Evaluating the real/fake image summaries here would generate Out of range errors,
                #  see if it's easy to save a tensor so get_next() doesn't need a new value

        tf.reset_default_graph()
        # Only move to the next resolution after a phase that ran at fixed alpha.
        if alpha_inc == 0:
            current_res_h *= 2
            current_res_w *= 2