def dump_graph():
    import tensorflow as tf
    from tensorflow.python.ops import summary_ops_v2
    from tensorflow.python.keras.backend import get_graph

    tb_path = './tensorboard-new'
    tb_writer = tf.summary.create_file_writer(tb_path)
    with tb_writer.as_default():
        summary_ops_v2.graph(get_graph(), step=0)
def graph(self, model):
    from tensorflow.python.ops import summary_ops_v2
    from tensorflow.python.keras import backend as K

    with self.train_writer.as_default():
        with summary_ops_v2.always_record_summaries():
            if not model.run_eagerly:
                summary_ops_v2.graph(K.get_graph(), step=0)
def write_epoch_models(self, mode: str) -> None:
    with self.tf_summary_writers[mode].as_default(), summary_ops_v2.always_record_summaries():
        summary_ops_v2.graph(backend.get_graph(), step=0)

        for model in self.network.epoch_models:
            summary_writable = (model.__class__.__name__ == 'Sequential'
                                or (hasattr(model, '_is_graph_network') and model._is_graph_network))
            if summary_writable:
                summary_ops_v2.keras_model(model.model_name, model, step=0)
def train(model, tensor_log, manager, init_epoch, train_set, test_set):
    logdir = os.path.join(params.logdir, model.name)
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    train_writer = tf.summary.create_file_writer(os.path.join(logdir, 'train'))
    test_writer = tf.summary.create_file_writer(os.path.join(logdir, 'test'))
    with train_writer.as_default():
        summary_ops_v2.graph(K.get_graph(), step=0)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    train_step = get_train_step(model, train_loss, train_accuracy)
    test_step = get_test_step(model, test_loss, test_accuracy)
    log_step = get_log_step(tensor_log, train_writer)

    do_callbacks('on_train_begin', model.callbacks)
    for epoch in range(init_epoch, params.training.epochs):
        do_callbacks('on_epoch_begin', model.callbacks, epoch=epoch)
        # Reset the metrics
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

        tf.keras.backend.set_learning_phase(1)
        for batch, (images, labels) in enumerate(train_set):
            do_callbacks('on_batch_begin', model.callbacks, batch=batch)
            train_step(images, labels)
            if batch == 0 and params.training.log:
                log_step(images, labels, epoch)
            do_callbacks('on_batch_end', model.callbacks, batch=batch)
        do_callbacks('on_epoch_end', model.callbacks, epoch=epoch)

        # Get the metric results
        train_loss_result = float(train_loss.result())
        train_accuracy_result = float(train_accuracy.result())
        with train_writer.as_default():
            tf.summary.scalar('loss', train_loss_result, step=epoch + 1)
            tf.summary.scalar('accuracy', train_accuracy_result, step=epoch + 1)

        # Run a test loop at the end of each epoch.
        tf.keras.backend.set_learning_phase(0)
        for images, labels in test_set:
            test_step(images, labels)

        # Get the metric results
        test_loss_result = float(test_loss.result())
        test_accuracy_result = float(test_accuracy.result())
        with test_writer.as_default():
            tf.summary.scalar('loss', test_loss_result, step=epoch + 1)
            tf.summary.scalar('accuracy', test_accuracy_result, step=epoch + 1)

        print('Epoch:{}, train acc:{:f}, test acc:{:f}'.format(
            epoch + 1, train_accuracy_result, test_accuracy_result))

        if (params.training.save_frequency != 0 and
                epoch % params.training.save_frequency == 0) or epoch == params.training.epochs - 1:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(
                model.optimizer.iterations.numpy(), save_path))
def testGraphSummary(self):
    training_util.get_or_create_global_step()
    name = 'hi'
    graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
    with summary_ops.always_record_summaries():
        with self.create_db_writer().as_default():
            summary_ops.graph(graph)
    six.assertCountEqual(self, [name],
                         get_all(self.db, 'SELECT node_name FROM Nodes'))
def _save_graph(self):
    state = tf.zeros((1, self.model.get_state_dim(), 1), dtype=tf.float64)
    action = tf.zeros((1, self.model.get_action_dim(), 1), dtype=tf.float64)

    with self.writer.as_default():
        graph = tf.function(self.model.build_step_graph).get_concrete_function(
            "graph", state, action).graph
        # visualize
        summary_ops_v2.graph(graph.as_graph_def())
def save_graph(self, function, graphMode=True):
    state = tf.zeros((self._sDim, 1), dtype=tf.float64)
    seq = tf.zeros((self._tau, self._aDim, 1), dtype=tf.float64)

    with self._writer.as_default():
        if graphMode:
            graph = function.get_concrete_function(1, state, seq).graph
        else:
            graph = tf.function(function).get_concrete_function(1, state, seq).graph
        # visualize
        summary_ops_v2.graph(graph.as_graph_def())
def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
        self.writer = summary_ops_v2.create_file_writer(self.log_dir)
        if not model.run_eagerly and self.write_graph:
            with self.writer.as_default():
                summary_ops_v2.graph(K.get_graph(), step=0)
    elif self.write_graph:
        self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
        self.writer = tf_summary.FileWriter(self.log_dir)
def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
        self.writer = summary_ops_v2.create_file_writer(self.log_dir)
        if not model.run_eagerly and self.write_graph:
            with self.writer.as_default():
                summary_ops_v2.graph(K.get_graph())
    elif self.write_graph:
        self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
        self.writer = tf_summary.FileWriter(self.log_dir)
def write_graph(self, model: k.Model):
    """Sets Keras model and writes graph if specified."""
    if model and self.is_write_graph:
        with self.writer.as_default(), summary_ops_v2.always_record_summaries():
            if not model.run_eagerly:
                summary_ops_v2.graph(get_graph(), step=0)

            summary_writable = (
                model._is_graph_network or  # pylint: disable=protected-access
                model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
            if summary_writable:
                summary_ops_v2.keras_model('keras', model, step=0)
def on_begin(self, state):
    if self.write_graph:
        with self.summary_writers['train'].as_default():
            with summary_ops_v2.always_record_summaries():
                summary_ops_v2.graph(backend.get_graph(), step=0)
                for name, model in self.network.model.items():
                    summary_writable = (model._is_graph_network or
                                        model.__class__.__name__ == 'Sequential')
                    if summary_writable:
                        summary_ops_v2.keras_model(name, model, step=0)
    if self.embeddings_freq:
        self._configure_embeddings()
def write_graph_summary(graph, summary_dir, **kwargs):
    """
    Writes the summary of a *graph* to a directory *summary_dir* using a
    ``tf.summary.FileWriter`` (v1) or ``tf.summary.create_file_writer`` (v2). This summary can be
    used later on to visualize the graph via tensorboard. *graph* can be either a graph object or
    a path to a protobuf file. In the latter case, :py:func:`load_graph` is used and all *kwargs*
    are forwarded.

    .. note::
        When used with TensorFlow v1, eager mode must be disabled.
    """
    # prepare the summary dir
    if not os.path.exists(summary_dir):
        os.makedirs(summary_dir)

    # read the graph when a string is passed
    if isinstance(graph, six.string_types):
        graph = load_graph(graph, create_session=False, **kwargs)

    # further handling is version dependent
    tf, tf1, tf_version = import_tf()
    if tf_version[0] == "1":
        # switch to non-eager mode for the FileWriter to work
        eager = getattr(tf1, "executing_eagerly", lambda: False)()
        if eager:
            tf1.disable_eager_execution()

        # write to file
        writer = tf1.summary.FileWriter(summary_dir)
        writer.add_graph(graph)

        # reset the eager mode
        if eager:
            tf1.enable_eager_execution()
    else:  # 2.X
        from tensorflow.python.ops import summary_ops_v2 as summary_ops

        # create the writer
        writer = tf.summary.create_file_writer(summary_dir)

        # write the graph
        with writer.as_default():
            # the graph summary op requires a step argument prior to 2.5
            graph_kwargs = {}
            if tf_version[1] < "5":
                graph_kwargs["step"] = 0
            summary_ops.graph(graph.as_graph_def(), **graph_kwargs)

        # close
        writer.close()
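A minimal usage sketch for the helper above; the frozen-graph path and the log directory here are hypothetical placeholders, not names taken from the original project:

# Hypothetical call: "model_frozen.pb" and "/tmp/graph_summary" are placeholders.
write_graph_summary("model_frozen.pb", "/tmp/graph_summary")
# then inspect it with: tensorboard --logdir /tmp/graph_summary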
def write_keras_graph(self, model: Union[Model, Sequential], step: int = 0, name: str = "keras"):
    r"""
    Writes Keras graph networks to TensorBoard.
    """
    with self.summary_writer.as_default():
        with summary_ops_v2.always_record_summaries():
            if not model.run_eagerly:
                summary_ops_v2.graph(keras.backend.get_graph(), step=step)

            summary_writable = (model._is_graph_network or
                                model.__class__.__name__ == 'Sequential')
            if summary_writable:
                summary_ops_v2.keras_model(name=str(name), data=model, step=step)
    return self
def write_epoch_models(self, mode: str, epoch: int) -> None:
    with self.tf_summary_writers[mode].as_default(), summary_ops_v2.always_record_summaries():
        # Record the overall execution summary
        if hasattr(self.network._forward_step_static, '_concrete_stateful_fn'):
            # noinspection PyProtectedMember
            summary_ops_v2.graph(self.network._forward_step_static._concrete_stateful_fn.graph)
        # Record the individual model summaries
        for model in self.network.epoch_models:
            summary_writable = (model.__class__.__name__ == 'Sequential'
                                or (hasattr(model, '_is_graph_network') and model._is_graph_network))
            if summary_writable:
                keras_model_summary(model.model_name, model, step=epoch)
def fit(self, train_set, test_set):
    logdir = os.path.join(self.params.logdir, self.model.name)
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    train_writer = tf.summary.create_file_writer(os.path.join(logdir, 'train'))
    test_writer = tf.summary.create_file_writer(os.path.join(logdir, 'test'))
    with train_writer.as_default():
        summary_ops_v2.graph(K.get_graph(), step=0)

    do_callbacks('on_train_begin', self.model.callbacks)
    for epoch in range(self.init_epoch, self.params.training.epochs):
        self.train_one_epoch(train_set, train_writer, epoch)
        self.evaluate(test_set, test_writer, epoch)
        self.checkpoint(self.manager, epoch + 1, self.params.training.save_frequency)
    self.checkpoint(self.manager, self.params.training.epochs, 0)
def set_model(self, model):
    """Sets Keras model and writes graph if specified."""
    self.model = model.model
    with context.eager_mode():
        self._close_writers()
        if self.write_graph:
            with self._get_writer(self._train_run_name).as_default():
                with summary_ops_v2.always_record_summaries():
                    if not self.model.run_eagerly:
                        summary_ops_v2.graph(K.get_graph(), step=0)

                    summary_writable = (
                        self.model._is_graph_network or  # pylint: disable=protected-access
                        self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
                    if summary_writable:
                        summary_ops_v2.keras_model('keras', self.model, step=0)
    if self.embeddings_freq:
        self._configure_embeddings()
n_actions = len(CarActionWrapper.ACTIONS)
policy = EpsGreedyQPolicy(eps=0.05)
memory = SequentialMemory(limit=args.memory_limit, window_length=args.window_length)
model = SimpleBiModel().get_model(args.window_length, n_actions)

tb_log_dir = 'tensorboard'
tb_logs = f'{tb_log_dir}/{datetime.now().strftime("%Y%m%d-%H%M%S")}'
graph_dir = f'{tb_logs}/graph'
writer = tf.summary.create_file_writer(logdir=graph_dir)

# save the graph
with writer.as_default():
    summary_ops_v2.graph(K.get_graph(), step=0)
writer.close()

agent = DQNAgent(
    model=model,
    nb_actions=n_actions,
    policy=policy,
    memory=memory,
    nb_steps_warmup=args.warmup_steps,
    gamma=.99,
    target_model_update=args.target_model_update,
    train_interval=args.train_interval,
    delta_clip=1.,
    enable_dueling_network=True)

agent.compile(Adam(lr=args.learning_rate), metrics=['mae'])
def main(args):
    # init environment
    env = Env(do_render=args.render).reset_sim()

    # Init tf writer
    writer = tf.summary.create_file_writer('./logs')

    # Compile agent
    # get agent config
    with open(args.rl_config, "rt") as f:
        rl_config = json.load(f)
    # get net config
    with open(args.net_config, "rt") as f:
        net_config = json.load(f)

    with writer.as_default():
        state_space = (96, 96)
        action_space = np.arange(9)

        # Construct Q Function
        QFun = getQNetFunc(state_space, action_space, BuildConvNet, **net_config)

        # Construct Policy
        for key, val in rl_config["pi"]["kwargs_eval"].copy().items():
            rl_config["pi"]["kwargs_eval"][key] = eval(val)
        policy = eval(rl_config["pi"]["type"])(action_space,
                                               **rl_config["pi"]["kwargs"],
                                               **rl_config["pi"]["kwargs_eval"])

        # Construct Utility
        for key, val in rl_config["util"]["kwargs_eval"].copy().items():
            rl_config["util"]["kwargs_eval"][key] = eval(val)
        UFun = eval(rl_config["util"]["type"])(action_space,
                                               **rl_config["util"]["kwargs"],
                                               **rl_config["util"]["kwargs_eval"])

        # Assemble agent
        a = QAgent(
            state_space=state_space,
            action_space=action_space,
        ).setQ(QFun).setU(UFun).setPolicy(policy)
        a.setTrain(getSimpleQLOptim(args.gamma, args.alpha, a.QModel, a.UModel))
        a.compile()
        a = RandomReplay(a, args.memory_span)

        if args.log:
            summary_ops_v2.graph(a.trainFN.outputs[0].graph, step=0)

        # Start action loop
        cumulative_r = 0.0
        history_r = np.zeros(args.steps)
        state = env.start_state[None].astype(np.float32)
        progress_bar = tqdm(range(args.steps))
        for i in progress_bar:
            # Act in the environment
            new_state, reward, info = env.get_action(a(state))
            new_state = new_state[None].astype(np.float32)

            # Observe transition
            a.observe(new_state, reward)
            a.learn()

            # Update state
            state = new_state

            # Stats
            cumulative_r += reward
            history_r[i] = reward
            progress_bar.set_description("cumulative reward: %.3f" % cumulative_r)
            progress_bar.refresh()

        # Compute metrics
        cumulative_avg = np.cumsum(history_r[:i]) / (np.arange(len(history_r[:i])) + 1)

        # Dump plot of results
        plt.plot(np.arange(i), cumulative_avg, label="average")
        plt.plot(np.arange(i), history_r[:i], label="reward")
        plt.title("Reward over time")
        plt.ylabel("Reward")
        plt.xlabel("Time-Step")
        plt.legend()
        plt.savefig(args.save_to + ".png", dpi=900)

        data = {}
        data["args"] = args.__dict__
        data["results"] = {
            "tot": cumulative_r,
            "history": list(history_r),
            "avg": list(cumulative_avg)
        }
        with open(args.save_to + ".json", "wt") as f:
            json.dump(data, f)
def train(model, model_log, manager, init_epoch, shift_train, shift_test, aff_test):
    logdir = os.path.join(params.logdir, model.name)
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    train_writer = tf.summary.create_file_writer(os.path.join(logdir, 'train'))
    test1_writer = tf.summary.create_file_writer(os.path.join(logdir, 'test1'))
    test2_writer = tf.summary.create_file_writer(os.path.join(logdir, 'test2'))
    with train_writer.as_default():
        summary_ops_v2.graph(K.get_graph(), step=0)

    loss = tf.keras.metrics.Mean(name='loss')
    acc = tf.keras.metrics.Mean(name='acc')

    train_step = get_train_step(model)
    test1_step = get_test_step(model)
    test2_step = get_test_step(model)
    train_log_step = get_log_step(model_log, train_writer)
    test1_log_step = get_log_step(model_log, test1_writer)
    test2_log_step = get_log_step(model_log, test2_writer)

    do_callbacks('on_train_begin', model.callbacks)
    for epoch in range(init_epoch, params.training.epochs):
        do_callbacks('on_epoch_begin', model.callbacks, epoch=epoch)
        # Reset the metrics
        loss.reset_states()
        acc.reset_states()
        step = 0

        tf.keras.backend.set_learning_phase(1)
        for batch, (images, labels) in enumerate(shift_train):
            do_callbacks('on_batch_begin', model.callbacks, batch=batch)
            pred_loss, prediction = train_step(images, labels)
            loss.update_state(pred_loss)
            acc.update_state(get_difference(labels, prediction))

            step = model.optimizer.iterations.numpy()
            if step > params.training.steps:
                break
            if step % params.training.log_steps == 0 and params.training.log:
                tf.keras.backend.set_learning_phase(0)
                train_log_step(images, labels, step)
                # Get the metric results
                train_loss_result = float(loss.result())
                train_acc_result = float(acc.result())
                with train_writer.as_default():
                    tf.summary.scalar('loss', train_loss_result, step=step)
                    tf.summary.scalar('accuracy', train_acc_result, step=step)
                loss.reset_states()
                acc.reset_states()

                if step % (10 * params.training.log_steps) == 0:
                    # shift mnist
                    log_batch = np.random.randint(0, 500)
                    for batch, (images, labels) in enumerate(shift_test):
                        pred_loss, prediction = test1_step(images, labels)
                        loss.update_state(pred_loss)
                        acc.update_state(get_difference(labels, prediction))
                        if batch == log_batch:
                            test1_log_step(images, labels, step)
                    # Get the metric results
                    test1_loss_result = float(loss.result())
                    test1_acc_result = float(acc.result())
                    with test1_writer.as_default():
                        tf.summary.scalar('loss', test1_loss_result, step=step)
                        tf.summary.scalar('accuracy', test1_acc_result, step=step)
                    loss.reset_states()
                    acc.reset_states()

                    # aff mnist
                    log_batch = np.random.randint(0, 500)
                    for batch, (images, labels) in enumerate(aff_test):
                        pred_loss, prediction = test2_step(images, labels)
                        loss.update_state(pred_loss)
                        acc.update_state(get_difference(labels, prediction))
                        if batch == log_batch:
                            test2_log_step(images, labels, step)
                    # Get the metric results
                    test2_loss_result = float(loss.result())
                    test2_acc_result = float(acc.result())
                    with test2_writer.as_default():
                        tf.summary.scalar('loss', test2_loss_result, step=step)
                        tf.summary.scalar('accuracy', test2_acc_result, step=step)
                    loss.reset_states()
                    acc.reset_states()

                tf.keras.backend.set_learning_phase(1)
            do_callbacks('on_batch_end', model.callbacks, batch=batch)
        do_callbacks('on_epoch_end', model.callbacks, epoch=epoch)

        if (params.training.save_frequency != 0 and
                epoch % params.training.save_frequency == 0) or epoch == params.training.epochs - 1:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(
                model.optimizer.iterations.numpy(), save_path))
        if step > params.training.steps:
            break
def train(model, model_log, manager, init_epoch, train_set, test_set):
    logdir = os.path.join(params.logdir, model.name)
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    train_writer = tf.summary.create_file_writer(os.path.join(logdir, 'train'))
    test_writer = tf.summary.create_file_writer(os.path.join(logdir, 'test'))
    with train_writer.as_default():
        summary_ops_v2.graph(K.get_graph(), step=0)

    loss = tf.keras.metrics.Mean(name='loss')
    acc_both = tf.keras.metrics.Mean(name='acc_both')
    acc_part = tf.keras.metrics.Mean(name='acc_part')

    train_step = get_train_step(model)
    test_step = get_test_step(model)
    train_log_step = get_log_step(model_log, train_writer)
    test_log_step = get_log_step(model_log, test_writer)

    do_callbacks('on_train_begin', model.callbacks)
    for epoch in range(init_epoch, params.training.epochs):
        do_callbacks('on_epoch_begin', model.callbacks, epoch=epoch)
        # Reset the metrics
        loss.reset_states()
        acc_both.reset_states()
        acc_part.reset_states()

        tf.keras.backend.set_learning_phase(1)
        for batch, features in enumerate(train_set):
            images, labels, image1, label1, image2, label2 = parse_feature(features)
            do_callbacks('on_batch_begin', model.callbacks, batch=batch)
            pred_loss, predictions = train_step(images, labels, image1, image2, label1, label2)
            # Update the metrics
            loss.update_state(pred_loss)
            acc1, acc2 = get_difference(labels, predictions, params.recons.threshold)
            acc_both.update_state(acc1)
            acc_part.update_state(acc2)

            step = model.optimizer.iterations.numpy()
            if step > params.training.steps:
                break
            if step % params.training.log_steps == 0 and params.training.log:
                train_log_step(images, labels, image1, image2, label1, label2, step)
                # Get the metric results
                train_loss_result = float(loss.result())
                train_acc_both_result = float(acc_both.result())
                train_acc_part_result = float(acc_part.result())
                with train_writer.as_default():
                    tf.summary.scalar('loss', train_loss_result, step=step)
                    tf.summary.scalar('accuracy_both', train_acc_both_result, step=step)
                    tf.summary.scalar('accuracy_part', train_acc_part_result, step=step)
                loss.reset_states()
                acc_both.reset_states()
                acc_part.reset_states()

                tf.keras.backend.set_learning_phase(0)
                log_batch = np.random.randint(0, 500)
                for batch, features in enumerate(test_set):
                    images, labels, image1, label1, image2, label2 = parse_feature(features)
                    pred_loss, predictions = test_step(images, labels, image1, image2, label1, label2)
                    # Update the metrics
                    loss.update_state(pred_loss)
                    acc1, acc2 = get_difference(labels, predictions, params.recons.threshold)
                    acc_both.update_state(acc1)
                    acc_part.update_state(acc2)
                    if batch == log_batch:
                        test_log_step(images, labels, image1, image2, label1, label2, step)
                # Get the metric results
                test_loss_result = float(loss.result())
                test_acc_both = float(acc_both.result())
                test_acc_part = float(acc_part.result())
                with test_writer.as_default():
                    tf.summary.scalar('loss', test_loss_result, step=step)
                    tf.summary.scalar('accuracy_both', test_acc_both, step=step)
                    tf.summary.scalar('accuracy_part', test_acc_part, step=step)
                loss.reset_states()
                acc_both.reset_states()
                acc_part.reset_states()

                tf.keras.backend.set_learning_phase(1)
            do_callbacks('on_batch_end', model.callbacks, batch=batch)
        do_callbacks('on_epoch_end', model.callbacks, epoch=epoch)

        if (params.training.save_frequency != 0 and
                epoch % params.training.save_frequency == 0) or epoch == params.training.epochs - 1:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(
                model.optimizer.iterations.numpy(), save_path))
        if step > params.training.steps:
            break
def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self):
    with self.assertRaises(TypeError):
        summary_ops.graph(ops.Graph())
    with self.assertRaises(TypeError):
        summary_ops.graph('')
import warnings
import logging
import os

warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from tensorflow.python.ops import summary_ops_v2

# Graph
a = tf.Variable(2, name='a')
b = tf.Variable(3, name='b')


@tf.function  # tf.function allows us to take a graph from a function
def graph_to_visualize(a, b):
    c = tf.add(a, b, name='Add')


# Visualize
writer = tf.summary.create_file_writer('./graphs')
with writer.as_default():
    graph = graph_to_visualize.get_concrete_function(a, b).graph  # get graph from function
    summary_ops_v2.graph(graph.as_graph_def())  # visualize
writer.close()
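For comparison, the same graph can be exported without importing the private summary_ops_v2 module by using TensorFlow's public tracing API. This is only a sketch of that alternative; the log directory name is arbitrary:

import tensorflow as tf


@tf.function
def graph_to_visualize(a, b):
    return tf.add(a, b, name='Add')


writer = tf.summary.create_file_writer('./graphs-trace')  # arbitrary log dir
tf.summary.trace_on(graph=True)                      # start recording the traced graph
graph_to_visualize(tf.constant(2), tf.constant(3))   # run the function once so it gets traced
with writer.as_default():
    tf.summary.trace_export(name='graph_to_visualize', step=0)  # write the graph event
writer.close()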
def on_train_begin(self, logs=None):
    """
    At the beginning of training, write the graph to the tensorboard.
    """
    super(MetricReductionCallback, self).on_train_begin(logs)
    if self._should_summary:
        summary_ops_v2.graph(K.get_graph(), step=0)
# Visualize a giant graph
import tensorflow as tf
from tensorflow.python.ops import summary_ops_v2

a = tf.Variable(2.0, name='a')
b = tf.Variable(3.0, name='b')
c = tf.Variable(7.0, name='c')


@tf.function
def graph_to_visualize(a, b, c):
    d = tf.multiply(a, b, name='d-mul')
    e = tf.add(b, c, name='e-add')
    f = tf.subtract(e, a, name='f-sub')
    g = tf.multiply(d, b, name='g-mul')
    h = tf.divide(g, d, name='h-div')
    i = tf.add(h, f, name='i-add')


writer = tf.summary.create_file_writer('./graphs')
with writer.as_default():
    # get graph from function
    graph = graph_to_visualize.get_concrete_function(a, b, c).graph
    # visualize
    summary_ops_v2.graph(graph.as_graph_def())
writer.close()
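From TensorFlow 2.5 onward a public tf.summary.graph op is available, so the private summary_ops_v2 import in the two examples above can likely be dropped. A sketch under that assumption (TF >= 2.5; the log directory is arbitrary):

import tensorflow as tf


@tf.function
def graph_to_visualize(a, b):
    return tf.multiply(a, b, name='mul')


writer = tf.summary.create_file_writer('./graphs-public')  # arbitrary log dir
concrete = graph_to_visualize.get_concrete_function(tf.constant(2.0), tf.constant(3.0))
with writer.as_default():
    tf.summary.graph(concrete.graph)  # accepts a tf.Graph or serialized GraphDef (TF >= 2.5)
writer.close()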
def run(self, batch_size: int, output_dir: str, image_type: CqDataType, scale: int):
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"Number of GPUS: {len(gpus)}")

    cq_dataset = CqData(image_type, scale_down=scale)
    dataset = tf.data.Dataset.from_tensor_slices(cq_dataset.images) \
        .shuffle(len(cq_dataset.images)) \
        .repeat() \
        .batch(batch_size, drop_remainder=True)

    output_dir = f"{output_dir}_{cq_dataset.get_type_str()}"
    output_num_x = 8
    output_num_y = 8
    num_output_image = output_num_x * output_num_y
    img_width = cq_dataset.get_image_width()
    img_height = cq_dataset.get_image_height()
    num_channel = cq_dataset.get_channel_count()
    lr = 0.0002
    z_size = 256
    num_cat = cq_dataset.get_count()
    # num_cat = 0
    output_interval = 100

    opt_g = keras.optimizers.Adam(lr)
    opt_d = keras.optimizers.Adam(lr)
    opt_g = tf.keras.mixed_precision.LossScaleOptimizer(opt_g)
    opt_d = tf.keras.mixed_precision.LossScaleOptimizer(opt_d)

    if strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif strategy_type == "tpu":
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=f"grpc://{os.environ['COLAB_TPU_ADDR']}")
        tf.config.experimental_connect_to_host(resolver.master())
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    else:
        raise Exception(f"Wrong strategy type: {strategy_type}")

    with strategy.scope():
        gen = Generator(4, img_width, img_height, num_channel)
        disc = Discriminator(4, num_cat)
        gen.optimizer = opt_g
        disc.optimizer = opt_d
        iwgan = IWGanLoss(disc)

        input_g = Input(batch_size=batch_size, shape=z_size + num_cat, name="z")
        input_d = Input(batch_size=batch_size,
                        shape=(img_height, img_width, num_channel),
                        name="real_images")
        input_eps = Input(batch_size=batch_size, shape=(1, 1, 1), name="eps")

        disc_gen: tf.Tensor
        disc_real: tf.Tensor
        cat_output: tf.Tensor
        gen_images: tf.Tensor = gen(input_g)
        x_pn = input_eps * input_d + (1 - input_eps) * gen_images
        iwgan_loss = iwgan(x_pn)
        disc_gen, cat_output = disc(gen_images)
        disc_real, _ = disc(input_d)

        model_g = Model(inputs=[input_g], outputs=[disc_gen, cat_output])
        model_d = Model(inputs=[input_g, input_d, input_eps],
                        outputs=[disc_gen, disc_real, iwgan_loss, cat_output])

        dataset = strategy.experimental_distribute_dataset(dataset)
        data_it = iter(dataset)

    # Model
    model_dir = os.path.join(output_dir, "model_save")
    gen_model_dir = os.path.join(model_dir, "gen")
    disc_model_dir = os.path.join(model_dir, "disc")
    gen_ckpt = tf.train.Checkpoint(model=gen, optimizer=gen.optimizer)
    disc_ckpt = tf.train.Checkpoint(model=disc, optimizer=disc.optimizer)
    gen_ckpt_mgr = tf.train.CheckpointManager(gen_ckpt, gen_model_dir, max_to_keep=5)
    disc_ckpt_mgr = tf.train.CheckpointManager(disc_ckpt, disc_model_dir, max_to_keep=5)
    gen_latest = gen_ckpt_mgr.latest_checkpoint
    disc_latest = disc_ckpt_mgr.latest_checkpoint
    with strategy.scope():
        if gen_latest:
            gen_ckpt.restore(gen_latest)
        if disc_latest:
            disc_ckpt.restore(disc_latest)

    if mode == "train":
        # Image output
        image_dir = os.path.join(output_dir, "images")
        helper.clean_create_dir(image_dir)
        test_input_images = cq_dataset.get_ordered_batch(num_output_image, False)
        test_input_images = tf.convert_to_tensor(test_input_images)
        test_input_images = CqGAN.convert_to_save_format(
            test_input_images, output_num_x, output_num_y,
            img_width, img_height, num_channel)
        tf.io.write_file(os.path.join(image_dir, "input.png"), test_input_images)

        # Summary
        summary_dir = "cq_log"
        summary_writer = tf.summary.create_file_writer(summary_dir)
        metrics = {
            "disc_gen": keras.metrics.Mean(),
            "disc_real": keras.metrics.Mean(),
            "loss_gen": keras.metrics.Mean(),
            "loss_real": keras.metrics.Mean(),
            "loss_cat": keras.metrics.Mean(),
            "iwgan_loss": keras.metrics.Mean()
        }
        summary_writer.set_as_default()
        graph(K.get_graph(), step=0)

        # Test variables
        z_fixed, _, cat_fixed = self.generate_fixed_z(num_output_image, z_size, num_cat)

        # Begin training
        begin = datetime.datetime.now()
        for step in range(num_iter):
            train_step(strategy, data_it, disc, gen, model_g, model_d,
                       batch_size, z_size, num_cat, metrics)

            # Output
            if step % output_interval == 0 and step != 0:
                now = datetime.datetime.now()
                diff = now - begin
                begin = now

                output_count = step // output_interval
                output_filename = f"output{output_count}.png"
                output_images = gen(z_fixed)
                output_images = CqGAN.convert_to_save_format(
                    output_images, output_num_x, output_num_y,
                    img_width, img_height, num_channel)
                tf.io.write_file(os.path.join(image_dir, output_filename), output_images)

                tf.saved_model.save(disc, disc_model_dir)
                tf.saved_model.save(gen, gen_model_dir)
                gen_ckpt_mgr.save()
                disc_ckpt_mgr.save()

                print(f"{output_count * output_interval} times done: {diff.total_seconds()}s passed, "
                      f"loss_gen: {metrics['loss_gen'].result()}, "
                      f"loss_real: {metrics['loss_real'].result()}, "
                      f"loss_cat: {metrics['loss_cat'].result()}")

                # Summary
                for metric_name, metric in metrics.items():
                    tf.summary.scalar(f"loss/{metric_name}", metric.result(), step)
                    metric.reset_states()
    elif mode == "predict":
        # Image output
        image_dir = os.path.join(output_dir, "predict_images")
        helper.clean_create_dir(image_dir)
        prr = plot_utils.Plot_Reproduce_Performance(
            image_dir, output_num_x, output_num_y, img_width, img_height, scale)

        with strategy.scope():
            fixed_z, real_z, _ = self.generate_nocat_z(batch_size, z_size, num_cat)
            fixed_cat_input = self.generate_fixed_cat(batch_size, num_cat)
            fixed_z2 = tf.concat([real_z, fixed_cat_input], 1)
            input_images1 = predict_image(strategy, gen, fixed_z).numpy()
            input_images2 = predict_image(strategy, gen, fixed_z2).numpy()
            prr.save_pngs(input_images1, num_channel, "input1.png")
            prr.save_pngs(input_images2, num_channel, "input2.png")

            for i in range(num_cat):
                cat_input = self.generate_fixed_cat(batch_size, num_cat, [i])
                input_z = tf.concat([real_z, cat_input], 1)
                gen_images = predict_image(strategy, gen, input_z)
                gen_images = gen_images.numpy()
                output_filename = f"output{i + 1}.png"
                prr.save_pngs(gen_images, num_channel, output_filename)