Example #1
def conv_layer(input_,
               filter_size,
               num_filters,
               stride,
               pad,
               nonlinearity=relu,
               W=Normal(0.02),
               **kwargs):
    return dnn.Conv2DDNNLayer(input_,
                              num_filters=num_filters,
                              stride=parse_tuple(stride),
                              filter_size=parse_tuple(filter_size),
                              pad=pad,
                              W=W,
                              nonlinearity=nonlinearity,
                              **kwargs)
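Every snippet on this page calls a parse_tuple helper (imported either directly or via utils) that normalizes a scalar-or-tuple argument into a tuple of a given length. The helper itself is never shown; the following is a minimal sketch of the behavior the call sites appear to assume. Only the name and call signatures come from the examples, while the body is hypothetical:

def parse_tuple(value, length=1):
    # Hypothetical sketch: None passes through (Example #3 tests for None
    # after the call), tuples pass through unchanged, scalars are broadcast.
    if value is None:
        return None
    if isinstance(value, tuple):
        return value
    return (value,) * length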
Example #2
    def _seperate_time_from_spatial(self, tup):
        if isinstance(tup, tuple):
            time = tup[0]
            space = tup[1:]
        else:
            time = tup
            space = parse_tuple(tup, 2)
        return (time,), space
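For instance, splitting a 3D filter specification yields a one-element time tuple plus a 2-tuple of spatial sizes (hypothetical calls, assuming the parse_tuple sketch above):

# _seperate_time_from_spatial((3, 5, 5))  ->  ((3,), (5, 5))
# _seperate_time_from_spatial(3)          ->  ((3,), (3, 3))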
Example #3
    def __init__(self, filter_size, num_filters, strides=(1,1), padding=None,
                 image_size=None, num_channels=None, **kwargs):
        super(ConvLayer, self).__init__(**kwargs)
        image_size = parse_tuple(image_size, 2)
        self.image_size = image_size

        self.num_filters = num_filters
        self.strides = parse_tuple(strides, 2)
        if isinstance(padding, int):
            padding = parse_tuple(padding, 2)
        self.padding = padding

        self.num_channels = num_channels
        if image_size is None:
            self.input_dims = (num_channels, None, None)
        else:
            self.input_dims = (num_channels, image_size[0], image_size[1])

        self.filter_size = parse_tuple(filter_size, 2)
Example #4
    def __init__(self, filter_size, num_filters, dimshuffle_inp=True,
                 pad_time=(0,), **kwargs):
        # a bit of fooling around to use ConvLayer.__init__
        time_filter_size, filter_size = self._seperate_time_from_spatial(filter_size)
        strides = kwargs.pop('strides', (1, 1, 1))
        time_stride, strides = self._seperate_time_from_spatial(strides)
        super(Conv3DLayer, self).__init__(filter_size, num_filters, strides=strides, **kwargs)
        self.time_filter_size = time_filter_size
        self.time_stride = time_stride
        self.dimshuffle_inp = dimshuffle_inp
        self.pad_time = parse_tuple(pad_time)
        self._gemm = False
Example #5
    def __init__(self,
                 filter_size,
                 num_filters,
                 time_filter_size=None,
                 time_num_filters=None,
                 convupward='conv',
                 convtime='conv',
                 **kwargs):
        if time_filter_size is None:
            time_filter_size = utils.parse_tuple(filter_size, 2)
        else:
            time_filter_size = utils.parse_tuple(time_filter_size, 2)

        # The time application doesn't change the spatial dimensions:
        # it uses a filter_size of odd shape together with half padding.
        assert time_filter_size[0] % 2 == 1
        if time_num_filters is None:
            time_num_filters = num_filters

        scan_spatial_input_dims = num_filters * 4

        if convupward == 'conv':
            convupward = ConvLayer(filter_size, scan_spatial_input_dims,
                                   **kwargs)
        elif convupward == 'deconv':
            convupward = DeConvLayer(filter_size, scan_spatial_input_dims,
                                     **kwargs)

        kwargs = self.popkwargs(convupward, kwargs)
        if convtime == 'conv':
            convtime = ScanConvLSTM(time_filter_size,
                                    time_num_filters,
                                    num_channels=num_filters,
                                    spatial_input_dims=scan_spatial_input_dims,
                                    **kwargs)

        super(ConvLSTM, self).__init__(convtime, convupward, **kwargs)
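The factor of 4 in scan_spatial_input_dims = num_filters * 4 corresponds to the four LSTM gates (input, forget, cell candidate, output): the upward convolution produces all gate pre-activations in a single pass, which the recurrent step then slices apart along the channel axis. A minimal sketch of such a split (hypothetical, not this class's actual code):

def split_gates(preact, num_filters):
    # preact carries num_filters * 4 channels; one contiguous slice per gate.
    i = preact[:, 0 * num_filters:1 * num_filters]  # input gate
    f = preact[:, 1 * num_filters:2 * num_filters]  # forget gate
    c = preact[:, 2 * num_filters:3 * num_filters]  # cell candidate
    o = preact[:, 3 * num_filters:4 * num_filters]  # output gate
    return i, f, c, o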
Example #6
    def infer_outputdim(self):
        i_dim = self.image_size[0]
        k_dim = self.filter_size[0]
        s_dim = self.strides[0]
        border_mode = self.padding
        if border_mode == 'valid':
            border_mode = 0
        elif border_mode == 'half':
            border_mode = k_dim // 2
        elif border_mode == 'full':
            border_mode = k_dim - 1
        elif isinstance(border_mode, tuple):
            border_mode = border_mode[0]
        else:
            raise ValueError("Unrecognized padding {} in {}".format(
                self.padding, self.prefix))
        self._border_mode = parse_tuple(border_mode, 2)
        o_dim = (i_dim + 2 * border_mode - k_dim) // s_dim + 1

        self.feature_size = (o_dim, o_dim)
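The last line is the standard convolution output-size formula o = (i + 2p - k) // s + 1; note the method assumes square inputs, filters, and strides, since it inspects only index [0] of each. A worked instance:

# i = 32, k = 5, s = 2, 'half' padding -> p = 5 // 2 = 2
# o = (32 + 2 * 2 - 5) // 2 + 1 = 31 // 2 + 1 = 16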
Example #7
    def __init__(self,
                 output_dims,
                 input_dims=None,
                 upward='default',
                 time='default',
                 **kwargs):
        output_dims = utils.parse_tuple(output_dims)

        scan_spatial_input_dims = (output_dims[0] * 4, ) + output_dims[1:]
        if upward == 'default':
            upward = FullyConnectedLayer(output_dims=scan_spatial_input_dims,
                                         input_dims=input_dims,
                                         **kwargs)

        # there are no kwargs specific to a fully connected layer
        #kwargs = self.popkwargs(upward, kwargs)
        if time == 'default':
            time = ScanLSTM(output_dims=output_dims,
                            input_dims=output_dims,
                            spatial_input_dims=scan_spatial_input_dims,
                            **kwargs)

        super(LSTM, self).__init__(time, upward, **kwargs)
Example #8
def main(args, config):

    if args.horovod:
        verbose = hvd.rank() == 0
        global_size = hvd.size()
        # global_rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        verbose = True
        global_size = 1
        # global_rank = 0
        local_rank = 0

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.gmtime())
    logdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'runs',
                          args.architecture, timestamp)

    if verbose:
        writer = tf.summary.FileWriter(logdir=logdir)
        print("Arguments passed:")
        print(args)
        print(f"Saving files to {logdir}")

    else:
        writer = None

    final_shape = parse_tuple(args.final_shape)
    image_channels = final_shape[0]
    final_resolution = final_shape[-1]
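    # Progressive growing: phase p trains at resolution 2 * 2**p, so sizes run
    # 4, 8, ..., final_resolution, giving log2(final_resolution) - 1 phases.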
    num_phases = int(np.log2(final_resolution) - 1)
    base_dim = num_filters(1, num_phases, size=args.network_size)

    var_list = list()
    global_step = 0

    for phase in range(1, num_phases + 1):

        tf.reset_default_graph()

        # ------------------------------------------------------------------------------------------#
        # DATASET

        size = 2 * 2**phase
        if args.dataset == 'imagenet':
            dataset = imagenet_dataset(
                args.dataset_path,
                args.scratch_path,
                size,
                copy_files=local_rank == 0,
                is_correct_phase=phase >= args.starting_phase,
                gpu=args.gpu,
                num_labels=1 if args.num_labels is None else args.num_labels)
        else:
            raise ValueError(f"Unknown dataset {args.dataset}")

        # Get DataLoader
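        # Halve the per-worker batch size each phase, presumably so the larger
        # activations at higher resolutions still fit in memory.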
        batch_size = max(1, args.base_batch_size // (2**(phase - 1)))

        if phase >= args.starting_phase:
            assert batch_size * global_size <= args.max_global_batch_size
            if verbose:
                print(
                    f"Using local batch size of {batch_size} and global batch size of {batch_size * global_size}"
                )

        if args.horovod:
            dataset = dataset.shard(hvd.size(), hvd.rank())

        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.repeat()
        dataset = dataset.prefetch(AUTOTUNE)
        dataset = dataset.make_one_shot_iterator()
        data = dataset.get_next()
        if len(data) == 1:
            real_image_input = data[0]
            real_label = None
        elif len(data) == 2:
            real_image_input, real_label = data
        else:
            raise NotImplementedError()

        real_image_input = tf.ensure_shape(
            real_image_input, [batch_size, image_channels, size, size])
        real_image_input = real_image_input + tf.random.normal(
            tf.shape(real_image_input)) * .01

        if real_label is not None:
            real_label = tf.one_hot(real_label, depth=args.num_labels)

        # ------------------------------------------------------------------------------------------#
        # OPTIMIZERS

        g_lr = args.g_lr
        d_lr = args.d_lr

        if args.horovod:
            if args.g_scaling == 'sqrt':
                g_lr = g_lr * np.sqrt(hvd.size())
            elif args.g_scaling == 'linear':
                g_lr = g_lr * hvd.size()
            elif args.g_scaling == 'none':
                pass
            else:
                raise ValueError(args.g_scaling)

            if args.d_scaling == 'sqrt':
                d_lr = d_lr * np.sqrt(hvd.size())
            elif args.d_scaling == 'linear':
                d_lr = d_lr * hvd.size()
            elif args.d_scaling == 'none':
                pass
            else:
                raise ValueError(args.d_scaling)

        # d_lr = tf.Variable(d_lr, name='d_lr', dtype=tf.float32)
        # g_lr = tf.Variable(g_lr, name='g_lr', dtype=tf.float32)

        # # optimizer_gen = tf.train.AdamOptimizer(learning_rate=g_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_disc = tf.train.AdamOptimizer(learning_rate=d_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_gen = LAMB(learning_rate=g_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_disc = LAMB(learning_rate=d_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_gen = LARSOptimizer(learning_rate=g_lr, momentum=0, weight_decay=0)
        # # optimizer_disc = LARSOptimizer(learning_rate=d_lr, momentum=0, weight_decay=0)

        # # optimizer_gen = tf.train.RMSPropOptimizer(learning_rate=1e-3)
        # # optimizer_disc = tf.train.RMSPropOptimizer(learning_rate=1e-3)
        # # optimizer_gen = tf.train.GradientDescentOptimizer(learning_rate=1e-3)
        # # optimizer_disc = tf.train.GradientDescentOptimizer(learning_rate=1e-3)
        # # optimizer_gen = RAdamOptimizer(learning_rate=g_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_disc = RAdamOptimizer(learning_rate=d_lr, beta1=args.beta1, beta2=args.beta2)

        # lr_step = tf.Variable(0, name='step', dtype=tf.float32)
        # update_step = lr_step.assign_add(1.0)

        # with tf.control_dependencies([update_step]):
        #     update_g_lr = g_lr.assign(g_lr * args.g_annealing)
        #     update_d_lr = d_lr.assign(d_lr * args.d_annealing)

        # if args.horovod:
        #     if args.use_adasum:
        #         # optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, op=hvd.Adasum)
        #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen)
        #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc, op=hvd.Adasum)
        #     else:
        #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen)
        #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc)

        # ------------------------------------------------------------------------------------------#
        # NETWORKS

        with tf.variable_scope('alpha'):
            alpha = tf.Variable(1, name='alpha', dtype=tf.float32)
            # Alpha init
            init_alpha = alpha.assign(1)

            # Specify alpha update op for mixing phase.
            num_steps = args.mixing_nimg // (batch_size * global_size)
            alpha_update = 1 / num_steps
            # noinspection PyTypeChecker
            update_alpha = alpha.assign(tf.maximum(alpha - alpha_update, 0))
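            # Net effect: alpha fades linearly from 1 to 0 over mixing_nimg
            # images, gradually blending in the newly added resolution block.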

        base_shape = [image_channels, 4, 4]

        if args.optim_strategy == 'simultaneous':
            gen_loss, disc_loss, gp_loss, gen_sample = forward_simultaneous(
                generator,
                discriminator,
                real_image_input,
                args.latent_dim,
                alpha,
                phase,
                num_phases,
                base_dim,
                base_shape,
                args.activation,
                args.leakiness,
                args.network_size,
                args.loss_fn,
                args.gp_weight,
                conditioning=real_label,
            )

            gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='generator')
            disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='discriminator')

            with tf.variable_scope('optimizer_gen'):
                # disc_loss = tf.Print(gen_loss, [gen_loss], 'g_loss')
                optimizer_gen = create_optimizer(
                    gen_loss,
                    gen_vars,
                    1e-8, (args.mixing_nimg + args.stabilizing_nimg) /
                    (batch_size * global_size),
                    8,
                    hvd=hvd,
                    optimizer_type='adam')

            with tf.variable_scope('optimizer_disc'):
                # disc_loss = tf.Print(disc_loss, [disc_loss], 'd_loss')
                optimizer_disc = create_optimizer(
                    disc_loss,
                    disc_vars,
                    1e-8, (args.mixing_nimg + args.stabilizing_nimg) /
                    (batch_size * global_size),
                    8,
                    hvd=hvd,
                    optimizer_type='lamb')

            # if args.horovod:
            #     if args.use_adasum:
            #         # optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, op=hvd.Adasum)
            #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, sparse_as_dense=True)
            #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc, op=hvd.Adasum, sparse_as_dense=True)
            #     else:
            #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, sparse_as_dense=True)
            #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc, sparse_as_dense=True)

            # g_gradients = optimizer_gen.compute_gradients(gen_loss, var_list=gen_vars)
            # d_gradients = optimizer_disc.compute_gradients(disc_loss, var_list=disc_vars)

            # g_norms = tf.stack([tf.norm(grad) for grad, var in g_gradients if grad is not None])
            # max_g_norm = tf.reduce_max(g_norms)
            # d_norms = tf.stack([tf.norm(grad) for grad, var in d_gradients if grad is not None])
            # max_d_norm = tf.reduce_max(d_norms)

            # # g_clipped_grads = [(tf.clip_by_norm(grad, clip_norm=128), var) for grad, var in g_gradients]
            # # train_gen = optimizer_gen.apply_gradients(g_clipped_grads)
            # train_gen = optimizer_gen.apply_gradients(g_gradients)
            # train_disc = optimizer_disc.apply_gradients(d_gradients)

        # elif args.optim_strategy == 'alternate':

        #     disc_loss, gp_loss = forward_discriminator(
        #         generator,
        #         discriminator,
        #         real_image_input,
        #         args.latent_dim,
        #         alpha,
        #         phase,
        #         num_phases,
        #         base_dim,
        #         base_shape,
        #         args.activation,
        #         args.leakiness,
        #         args.network_size,
        #         args.loss_fn,
        #         args.gp_weight,
        #         conditioning=real_label
        #     )

        #     # disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        #     # d_gradients = optimizer_disc.compute_gradients(disc_loss, var_list=disc_vars)
        #     # d_norms = tf.stack([tf.norm(grad) for grad, var in d_gradients if grad is not None])
        #     # max_d_norm = tf.reduce_max(d_norms)

        #     # train_disc = optimizer_disc.apply_gradients(d_gradients)

        #     with tf.control_dependencies([train_disc]):
        #         gen_sample, gen_loss = forward_generator(
        #             generator,
        #             discriminator,
        #             real_image_input,
        #             args.latent_dim,
        #             alpha,
        #             phase,
        #             num_phases,
        #             base_dim,
        #             base_shape,
        #             args.activation,
        #             args.leakiness,
        #             args.network_size,
        #             args.loss_fn,
        #             is_reuse=True
        #         )

        #         gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        #         g_gradients = optimizer_gen.compute_gradients(gen_loss, var_list=gen_vars)
        #         g_norms = tf.stack([tf.norm(grad) for grad, var in g_gradients if grad is not None])
        #         max_g_norm = tf.reduce_max(g_norms)
        #         train_gen = optimizer_gen.apply_gradients(g_gradients)

        else:
            raise ValueError("Unknown optim strategy ", args.optim_strategy)

        if verbose:
            print(f"Generator parameters: {count_parameters('generator')}")
            print(
                f"Discriminator parameters: {count_parameters('discriminator')}"
            )

        # train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
        # train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

        ema = tf.train.ExponentialMovingAverage(decay=args.ema_beta)
        ema_op = ema.apply(gen_vars)
        # Transfer EMA values to original variables
        ema_update_weights = tf.group(
            [tf.assign(var, ema.average(var)) for var in gen_vars])

        with tf.name_scope('summaries'):
            # Summaries
            tf.summary.scalar('d_loss', disc_loss)
            tf.summary.scalar('g_loss', gen_loss)
            tf.summary.scalar('gp', tf.reduce_mean(gp_loss))

            # for g in g_gradients:
            #     tf.summary.histogram(f'grad_{g[1].name}', g[0])

            # for g in d_gradients:
            #     tf.summary.histogram(f'grad_{g[1].name}', g[0])

            # tf.summary.scalar('convergence', tf.reduce_mean(disc_real) - tf.reduce_mean(tf.reduce_mean(disc_fake_d)))

            # tf.summary.scalar('max_g_grad_norm', max_g_norm)
            # tf.summary.scalar('max_d_grad_norm', max_d_norm)

            real_image_grid = tf.transpose(real_image_input,
                                           (0, 2, 3, 1))  # B C H W -> B H W C
            shape = real_image_grid.get_shape().as_list()
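            # Use the largest power of two <= sqrt(batch size) as the column
            # count so the summary image grid comes out roughly square.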
            grid_cols = int(2**np.floor(np.log(np.sqrt(shape[0])) / np.log(2)))
            grid_rows = shape[0] // grid_cols
            grid_shape = [grid_rows, grid_cols]
            real_image_grid = image_grid(real_image_grid,
                                         grid_shape,
                                         image_shape=shape[1:3],
                                         num_channels=shape[-1])

            fake_image_grid = tf.transpose(gen_sample, (0, 2, 3, 1))
            fake_image_grid = image_grid(fake_image_grid,
                                         grid_shape,
                                         image_shape=shape[1:3],
                                         num_channels=shape[-1])

            fake_image_grid = tf.clip_by_value(fake_image_grid, -1, 1)

            tf.summary.image('real_image', real_image_grid)
            tf.summary.image('fake_image', fake_image_grid)

            tf.summary.scalar('fake_image_min', tf.math.reduce_min(gen_sample))
            tf.summary.scalar('fake_image_max', tf.math.reduce_max(gen_sample))

            tf.summary.scalar('real_image_min',
                              tf.math.reduce_min(real_image_input[0]))
            tf.summary.scalar('real_image_max',
                              tf.math.reduce_max(real_image_input[0]))
            tf.summary.scalar('alpha', alpha)

            tf.summary.scalar('g_lr', g_lr)
            tf.summary.scalar('d_lr', d_lr)

            merged_summaries = tf.summary.merge_all()

        # Other ops
        init_op = tf.global_variables_initializer()
        assign_starting_alpha = alpha.assign(args.starting_alpha)
        assign_zero = alpha.assign(0)
        broadcast = hvd.broadcast_global_variables(0) if args.horovod else tf.no_op()

        with tf.Session(config=config) as sess:
            sess.run(init_op)

            trainable_variable_names = [
                v.name for v in tf.trainable_variables()
            ]

            if var_list and phase > args.starting_phase:
                print("Restoring variables from:",
                      os.path.join(logdir, f'model_{phase - 1}'))
                var_names = [v.name for v in var_list]
                load_vars = [
                    sess.graph.get_tensor_by_name(n) for n in var_names
                    if n in trainable_variable_names
                ]
                saver = tf.train.Saver(load_vars)
                saver.restore(sess, os.path.join(logdir, f'model_{phase - 1}'))
            elif var_list and args.continue_path and phase == args.starting_phase:
                print("Restoring variables from:", args.continue_path)
                var_names = [v.name for v in var_list]
                load_vars = [
                    sess.graph.get_tensor_by_name(n) for n in var_names
                    if n in trainable_variable_names
                ]
                saver = tf.train.Saver(load_vars)
                saver.restore(sess, os.path.join(args.continue_path))
            else:
                if verbose:
                    print("Not restoring variables.")
                    print("Variable List Length:", len(var_list))

            var_list = gen_vars + disc_vars

            if phase < args.starting_phase:
                continue

            if phase == args.starting_phase:
                sess.run(assign_starting_alpha)
            else:
                sess.run(init_alpha)

            if verbose:
                print(f"Begin mixing epochs in phase {phase}")
            if args.horovod:
                sess.run(broadcast)

            local_step = 0
            # take_first_snapshot = True

            while True:
                start = time.time()
                if local_step % 128 == 0 and local_step > 1:
                    if args.horovod:
                        sess.run(broadcast)
                    saver = tf.train.Saver(var_list)
                    if verbose:
                        saver.save(
                            sess,
                            os.path.join(logdir,
                                         f'model_{phase}_ckpt_{global_step}'))

                # _, _, summary, d_loss, g_loss = sess.run(
                #      [train_gen, train_disc, merged_summaries,
                #       disc_loss, gen_loss])

                _, _, summary, d_loss, g_loss = sess.run([
                    optimizer_gen, optimizer_disc, merged_summaries, disc_loss,
                    gen_loss
                ])

                global_step += batch_size * global_size
                local_step += 1

                end = time.time()
                img_s = global_size * batch_size / (end - start)
                if verbose:

                    writer.add_summary(summary, global_step)
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='img_s', simple_value=img_s)
                        ]), global_step)
                    memory_percentage = psutil.Process(
                        os.getpid()).memory_percent()
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='memory_percentage',
                                             simple_value=memory_percentage)
                        ]), global_step)

                    print(f"Step {global_step:09} \t"
                          f"img/s {img_s:.2f} \t "
                          f"d_loss {d_loss:.4f} \t "
                          f"g_loss {g_loss:.4f} \t "
                          f"memory {memory_percentage:.4f} % \t"
                          f"alpha {alpha.eval():.2f}")

                    # if take_first_snapshot:
                    #     import tracemalloc
                    #     tracemalloc.start()
                    #     snapshot_first = tracemalloc.take_snapshot()
                    #     take_first_snapshot = False

                    # snapshot = tracemalloc.take_snapshot()
                    # top_stats = snapshot.compare_to(snapshot_first, 'lineno')
                    # print("[ Top 10 differences ]")
                    # for stat in top_stats[:10]:
                    #     print(stat)
                    # snapshot_prev = snapshot

                if global_step >= ((phase - args.starting_phase) *
                                   (args.mixing_nimg + args.stabilizing_nimg) +
                                   args.mixing_nimg):
                    break

                sess.run(update_alpha)
                sess.run(ema_op)
                # sess.run(update_d_lr)
                # sess.run(update_g_lr)

                assert alpha.eval() >= 0

                if verbose:
                    writer.flush()

            if verbose:
                print(f"Begin stabilizing epochs in phase {phase}")

            sess.run(assign_zero)

            while True:
                start = time.time()
                assert alpha.eval() == 0
                if local_step % 128 == 0 and local_step > 0:

                    if args.horovod:
                        sess.run(broadcast)
                    saver = tf.train.Saver(var_list)
                    if verbose:
                        saver.save(
                            sess,
                            os.path.join(logdir,
                                         f'model_{phase}_ckpt_{global_step}'))

                # _, _, summary, d_loss, g_loss = sess.run(
                #      [train_gen, train_disc, merged_summaries,
                #       disc_loss, gen_loss])

                _, _, summary, d_loss, g_loss = sess.run([
                    optimizer_gen, optimizer_disc, merged_summaries, disc_loss,
                    gen_loss
                ])

                global_step += batch_size * global_size
                local_step += 1

                end = time.time()
                img_s = global_size * batch_size / (end - start)
                if verbose:
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='img_s', simple_value=img_s)
                        ]), global_step)
                    writer.add_summary(summary, global_step)
                    memory_percentage = psutil.Process(
                        os.getpid()).memory_percent()
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='memory_percentage',
                                             simple_value=memory_percentage)
                        ]), global_step)

                    print(f"Step {global_step:09} \t"
                          f"img/s {img_s:.2f} \t "
                          f"d_loss {d_loss:.4f} \t "
                          f"g_loss {g_loss:.4f} \t "
                          f"memory {memory_percentage:.4f} % \t"
                          f"alpha {alpha.eval():.2f}")

                sess.run(ema_op)

                if verbose:
                    writer.flush()

                if global_step >= (phase - args.starting_phase + 1) * (
                        args.stabilizing_nimg + args.mixing_nimg):
                    # if verbose:
                    #     run_metadata = tf.RunMetadata()
                    #     opts = tf.profiler.ProfileOptionBuilder.float_operation()
                    #     g = tf.get_default_graph()
                    #     flops = tf.profiler.profile(g, run_meta=run_metadata, cmd='op', options=opts)
                    #     writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='graph_flops',
                    #                                                           simple_value=flops.total_float_ops)]),
                    #                        global_step)
                    #
                    #     # Print memory info.
                    #     try:
                    #         print(nvgpu.gpu_info())
                    #     except subprocess.CalledProcessError:
                    #         pid = os.getpid()
                    #         py = psutil.Process(pid)
                    #         print(f"CPU Percent: {py.cpu_percent()}")
                    #         print(f"Memory info: {py.memory_info()}")

                    break

            # # Calculate metrics.
            # calc_swds: bool = size >= 16
            # calc_ssims: bool = min(npy_data.shape[1:]) >= 16
            #
            # if args.calc_metrics:
            #     fids_local = []
            #     swds_local = []
            #     psnrs_local = []
            #     mses_local = []
            #     nrmses_local = []
            #     ssims_local = []
            #
            #     counter = 0
            #     while True:
            #         if args.horovod:
            #             start_loc = counter + hvd.rank() * batch_size
            #         else:
            #             start_loc = 0
            #         real_batch = np.stack([npy_data[i] for i in range(start_loc, start_loc + batch_size)])
            #         real_batch = real_batch.astype(np.int16) - 1024
            #         fake_batch = sess.run(gen_sample).astype(np.float32)
            #
            #         # Turn fake batch into HUs and clip to training range.
            #         fake_batch = (np.clip(fake_batch, -1, 2) * 1024).astype(np.int16)
            #
            #         if verbose:
            #             print('real min, max', real_batch.min(), real_batch.max())
            #             print('fake min, max', fake_batch.min(), fake_batch.max())
            #
            #         fids_local.append(calculate_fid_given_batch_volumes(real_batch, fake_batch, sess))
            #
            #         if calc_swds:
            #             swds = get_swd_for_volumes(real_batch, fake_batch)
            #             swds_local.append(swds)
            #
            #         psnr = get_psnr(real_batch, fake_batch)
            #         if calc_ssims:
            #             ssim = get_ssim(real_batch, fake_batch)
            #             ssims_local.append(ssim)
            #         mse = get_mean_squared_error(real_batch, fake_batch)
            #         nrmse = get_normalized_root_mse(real_batch, fake_batch)
            #
            #         psnrs_local.append(psnr)
            #         mses_local.append(mse)
            #         nrmses_local.append(nrmse)
            #
            #         if args.horovod:
            #             counter = counter + global_size * batch_size
            #         else:
            #             counter += batch_size
            #
            #         if counter >= args.num_metric_samples:
            #             break
            #
            #     fid_local = np.mean(fids_local)
            #     psnr_local = np.mean(psnrs_local)
            #     ssim_local = np.mean(ssims_local)
            #     mse_local = np.mean(mses_local)
            #     nrmse_local = np.mean(nrmses_local)
            #
            #     if args.horovod:
            #         fid = MPI.COMM_WORLD.allreduce(fid_local, op=MPI.SUM) / hvd.size()
            #         psnr = MPI.COMM_WORLD.allreduce(psnr_local, op=MPI.SUM) / hvd.size()
            #         mse = MPI.COMM_WORLD.allreduce(mse_local, op=MPI.SUM) / hvd.size()
            #         nrmse = MPI.COMM_WORLD.allreduce(nrmse_local, op=MPI.SUM) / hvd.size()
            #         if calc_ssims:
            #             ssim = MPI.COMM_WORLD.allreduce(ssim_local, op=MPI.SUM) / hvd.size()
            #     else:
            #         fid = fid_local
            #         psnr = psnr_local
            #         ssim = ssim_local
            #         mse = mse_local
            #         nrmse = nrmse_local
            #
            #     if calc_swds:
            #         swds_local = np.array(swds_local)
            #         # Average over batches
            #         swds_local = swds_local.mean(axis=0)
            #         if args.horovod:
            #             swds = MPI.COMM_WORLD.allreduce(swds_local, op=MPI.SUM) / hvd.size()
            #         else:
            #             swds = swds_local
            #
            #     if verbose:
            #         print(f"FID: {fid:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='fid',
            #                                                               simple_value=fid)]),
            #                            global_step)
            #
            #         print(f"PSNR: {psnr:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='psnr',
            #                                                               simple_value=psnr)]),
            #                            global_step)
            #
            #         print(f"MSE: {mse:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='mse',
            #                                                               simple_value=mse)]),
            #                            global_step)
            #
            #         print(f"Normalized Root MSE: {nrmse:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='nrmse',
            #                                                               simple_value=nrmse)]),
            #                            global_step)
            #
            #         if calc_swds:
            #             print(f"SWDS: {swds}")
            #             for i in range(len(swds))[:-1]:
            #                 lod = 16 * 2 ** i
            #                 writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=f'swd_{lod}',
            #                                                                       simple_value=swds[
            #                                                                           i])]),
            #                                    global_step)
            #             writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=f'swd_mean',
            #                                                                   simple_value=swds[
            #                                                                       -1])]), global_step)
            #         if calc_ssims:
            #             print(f"SSIM: {ssim}")
            #             writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=f'ssim',
            #                                                                   simple_value=ssim)]), global_step)

            if verbose:
                print("\n\n\n End of phase.")

                # Save Session.
                sess.run(ema_update_weights)
                saver = tf.train.Saver(var_list)
                saver.save(sess, os.path.join(logdir, f'model_{phase}'))

            if args.ending_phase and phase == args.ending_phase:
                print("Reached final phase, breaking.")
                break
Example #9
    def __init__(self, spatial_input_dims=None):
        self.spatial_input_dims = parse_tuple(spatial_input_dims)
Example #10
    def __init__(self, input_dims=None, output_dims=None, prefix=''):
        self.input_dims = utils.parse_tuple(input_dims)
        self.output_dims = utils.parse_tuple(output_dims)
        self.prefix = prefix
Example #11
def main(args, config):
    phase = args.phase
    if args.horovod:
        verbose = hvd.rank() == 0
        global_size = hvd.size()
        global_rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        verbose = True
        global_size = 1
        global_rank = 0
        local_rank = 0

    if verbose:
        timestamp = time.strftime("%Y-%m-%d_%H:%M", time.gmtime())
        logdir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'generated_samples', timestamp)
        os.makedirs(logdir)
    else:
        logdir = None

    if args.horovod:
        logdir = MPI.COMM_WORLD.bcast(logdir, root=0)

    if verbose:
        print("Arguments passed:")
        print(args)
        print(f"Saving files to {logdir}")

    tf.reset_default_graph()
    # Get Dataset.

    final_shape = parse_tuple(args.final_shape)
    final_resolution = final_shape[-1]
    num_phases = int(np.log2(final_resolution) - 1)
    size = 2 * 2**phase
    data_path = os.path.join(args.dataset_path, f'{size}x{size}/')
    npy_data = NumpyPathDataset(data_path,
                                None,
                                copy_files=False,
                                is_correct_phase=False)
    dataset = tf.data.Dataset.from_tensor_slices(npy_data.scratch_files)

    batch_size = 1

    if args.horovod:
        dataset = dataset.shard(hvd.size(), hvd.rank())

    def load(x):
        x = np.load(x.numpy().decode('utf-8'))[np.newaxis, ...]
        return x

    # Lay out the graph.
    dataset = dataset.shuffle(len(npy_data))
    dataset = dataset.map(
        lambda x: tf.py_function(func=load, inp=[x], Tout=tf.uint16),
        num_parallel_calls=AUTOTUNE)
    dataset = dataset.map(lambda x: tf.cast(x, tf.float32) / 1024 - 1,
                          num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(AUTOTUNE)
    dataset = dataset.make_one_shot_iterator()
    real_image_input = dataset.get_next()
    real_image_input = tf.ensure_shape(real_image_input,
                                       [batch_size] + list(npy_data.shape))
    # real_image_input = real_image_input + tf.random.normal(tf.shape(real_image_input)) * .01

    with tf.variable_scope('alpha'):
        alpha = tf.Variable(0, name='alpha', dtype=tf.float32)

    zdim_base = max(1, final_shape[1] // (2**(num_phases - 1)))
    base_shape = (1, zdim_base, 4, 4)

    noise_input_d = tf.random.normal(
        shape=[tf.shape(real_image_input)[0], args.latent_dim])
    gen_sample_d = generator(noise_input_d,
                             alpha,
                             phase,
                             num_phases,
                             args.base_dim,
                             base_shape,
                             activation=args.activation,
                             param=args.leakiness)

    if verbose:
        print(f"Generator parameters: {count_parameters('generator')}")

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='generator')
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='discriminator')

    real_image_grid = tf.transpose(real_image_input[0], (1, 2, 3, 0))
    shape = real_image_grid.get_shape().as_list()
    grid_cols = int(2**np.floor(np.log(np.sqrt(shape[0])) / np.log(2)))
    grid_rows = shape[0] // grid_cols
    grid_shape = [grid_rows, grid_cols]
    real_image_grid = image_grid(real_image_grid,
                                 grid_shape,
                                 image_shape=shape[1:3],
                                 num_channels=shape[-1])

    fake_image_grid = tf.transpose(gen_sample_d[0], (1, 2, 3, 0))
    fake_image_grid = image_grid(fake_image_grid,
                                 grid_shape,
                                 image_shape=shape[1:3],
                                 num_channels=shape[-1])

    with tf.Session(config=config) as sess:

        sess.run(tf.global_variables_initializer())

        trainable_variable_names = [v.name for v in tf.trainable_variables()]
        var_list = gen_vars + disc_vars
        var_names = [v.name for v in var_list]
        load_vars = [
            sess.graph.get_tensor_by_name(n) for n in var_names
            if n in trainable_variable_names
        ]
        saver = tf.train.Saver(load_vars)
        saver.restore(sess, os.path.join(args.model_path))

        if args.horovod:
            sess.run(hvd.broadcast_global_variables(0))

        num_samples = args.num_samples // global_size

        calc_swds: bool = size >= 16
        calc_ssims: bool = min(npy_data.shape[1:]) >= 16

        fids_local = []
        swds_local = []
        psnrs_local = []
        mses_local = []
        nrmses_local = []
        ssims_local = []

        for i in tqdm(range(num_samples)):

            ix = (global_rank + i * global_size)
            real_batch, fake_batch, grid_real, grid_fake = sess.run([
                real_image_input, gen_sample_d, real_image_grid,
                fake_image_grid
            ])

            # fake_batch = sess.run(real_image_input)

            grid_real = np.squeeze(grid_real)
            grid_fake = np.squeeze(grid_fake)

            imageio.imwrite(os.path.join(logdir, f'grid_real_{ix}.png'),
                            grid_real)
            imageio.imwrite(os.path.join(logdir, f'grid_fake_{ix}.png'),
                            grid_fake)

            fake_batch = (np.clip(fake_batch, -1, 2) * 1024).astype(np.int16)
            real_batch = (np.clip(real_batch, -1, 2) * 1024).astype(np.int16)
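            # np.clip(..., -1, 2) * 1024 maps the network's normalized range
            # back to int16 intensities, mirroring the loader's
            # x / 1024 - 1 transform above.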

            # fake_batch = real_batch

            assert real_batch.min() < -512
            assert fake_batch.min() < -512

            fids_local.append(
                calculate_fid_given_batch_volumes(real_batch, fake_batch,
                                                  sess))
            if calc_swds:
                swds = get_swd_for_volumes(real_batch, fake_batch)
                swds_local.append(swds)

            psnr = get_psnr(real_batch, fake_batch)
            if calc_ssims:
                ssim = get_ssim(real_batch, fake_batch)
                ssims_local.append(ssim)

            mse = get_mean_squared_error(real_batch, fake_batch)
            nrmse = get_normalized_root_mse(real_batch, fake_batch)

            psnrs_local.append(psnr)
            mses_local.append(mse)
            nrmses_local.append(nrmse)

            save_path = os.path.join(logdir, f'{ix}.npy')
            np.save(save_path, fake_batch)

        fid_local = np.stack(fids_local).mean(0)
        psnr_local = np.mean(psnrs_local)
        ssim_local = np.mean(ssims_local)
        mse_local = np.mean(mses_local)
        nrmse_local = np.mean(nrmses_local)

        if args.horovod:
            fid = MPI.COMM_WORLD.allreduce(fid_local, op=MPI.SUM) / hvd.size()
            psnr = MPI.COMM_WORLD.allreduce(psnr_local,
                                            op=MPI.SUM) / hvd.size()
            mse = MPI.COMM_WORLD.allreduce(mse_local, op=MPI.SUM) / hvd.size()
            nrmse = MPI.COMM_WORLD.allreduce(nrmse_local,
                                             op=MPI.SUM) / hvd.size()
            if calc_ssims:
                ssim = MPI.COMM_WORLD.allreduce(ssim_local,
                                                op=MPI.SUM) / hvd.size()
        else:
            fid = fid_local
            psnr = psnr_local
            ssim = ssim_local
            mse = mse_local
            nrmse = nrmse_local

        if calc_swds:
            swds_local = np.array(swds_local)
            # Average over batches
            swds_local = swds_local.mean(axis=0)
            if args.horovod:
                swds = MPI.COMM_WORLD.allreduce(swds_local,
                                                op=MPI.SUM) / hvd.size()
            else:
                swds = swds_local

        summary_str = ""
        if verbose:
            summary_str += f"FIDS: {fid.tolist()} \n\n"
            summary_str += f"FID: {fid.mean():.4f} \n"
            summary_str += f"PSNR: {psnr:.4f} \n"
            summary_str += f"MSE: {mse:.4f} \n"
            summary_str += f"Normalized Root MSE: {nrmse:.4f} \n"
            if calc_swds:
                summary_str += f"SWDS: {swds} \n"
            if calc_ssims:
                summary_str += f"SSIM: {ssim} \n"

        if verbose:
            with open(os.path.join(logdir, 'summary.txt'), 'w') as f:
                f.write(summary_str)