def main():
    # Load model from checkpoint
    writer = tf.summary.create_file_writer("%s/log/generated" % args.log_dir)
    optimizer = optimizers.Adam()
    model = GammaCapsuleNetwork(args.num_classes)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    checkpoint.restore("%s/ckpt/ckpt-%d" % (args.log_dir, args.ckpt))

    # Generate images that activate a capsule
    # Note: args.num_samples defines how many samples we collect per capsule, e.g. 60
    caps = 0
    while True:
        x = tf.random.uniform((args.num_samples, 28, 28, 1),
                              minval=0.1,
                              maxval=0.9)

        with tf.device("/GPU:1"):
            print("Layer %d | Capsule %d" % (args.layer, caps))
            x, logits, num_capsules = get_img_for_caps(x, model, caps)

        with writer.as_default():
            x = tf.squeeze(x)
            img = utils.plot_generated_image(x, logits)
            x = utils.plot_to_image(img)
            tf.summary.image("Layer %d/Capsule %d" % (args.layer, caps),
                             x,
                             step=args.ckpt)

        caps += 1
        if caps >= num_capsules:
            break
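
Every example on this page funnels a Matplotlib figure through a `plot_to_image` helper before calling `tf.summary.image`. The exact helper differs per repository, but a typical TF 2.x version (a sketch along the lines of the TensorFlow documentation) looks like this:

import io

import matplotlib.pyplot as plt
import tensorflow as tf


def plot_to_image(figure):
    """Render a Matplotlib figure to a 4-D PNG tensor usable by tf.summary.image."""
    buf = io.BytesIO()
    figure.savefig(buf, format="png")
    plt.close(figure)  # release the figure's memory
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)  # add a batch dimension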
Example 2
def _save_val_predictions_plot(self):
    fig = plot_data(self.val_predictions)
    img = utils.plot_to_image(fig)
    summary = tf.Summary(value=[
        tf.Summary.Value(tag="Val predictions",
                         image=tf.Summary.Image(
                             encoded_image_string=img, height=6, width=6))
    ])
    self.valid_summary_writer.add_summary(summary, self._epochs_training)
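
Note that this snippet uses the TF 1.x summary API, where `plot_to_image` must return an encoded PNG string. With the TF 2.x writers used in the other examples, the equivalent (assuming `img` is an image tensor) would be:

with self.valid_summary_writer.as_default():
    tf.summary.image("Val predictions", img, step=self._epochs_training)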
Example 3
    def make_grid(self, *args):
        samples = self.model.generate_samples(self.latents, training=False)
        if self.show_blurred_samples:
            samples = self.model.blur(samples)

        samples = utils.normalize_images(samples)
        figure = utils.samples_grid(samples)
        figure.savefig(self.log_dir + f"/samples_grid_{self.samples_seen:06}.png")
        image = utils.plot_to_image(figure)
        with self.model.summary_writer.as_default():
            # tf.summary.image needs an explicit step in TF 2.x; reuse the sample counter
            tf.summary.image("samples_grid", image, step=self.samples_seen)
Example 4
def log_redshift_scatter(self, epoch, predictions, test_name):
    log_dir = self.tensorboard_callback.log_dir + '/images'
    # TODO: create just once or close?
    file_writer = tf.summary.create_file_writer(log_dir)
    scatter_plot = redshift_scatter_plot(predictions,
                                         z_pred_col='Z_PHOTO',
                                         color_column='Z_PHOTO_STDDEV',
                                         z_max=4,
                                         return_figure=True)
    scatter_image = plot_to_image(scatter_plot)
    with file_writer.as_default():
        tf.summary.image('redshift scatter - {}'.format(test_name),
                         scatter_image,
                         step=epoch)
Example 5
def log_confusion_matrix(self, epoch, predictions, y_true, test_name):
    log_dir = self.tensorboard_callback.log_dir + '/images'
    file_writer = tf.summary.create_file_writer(log_dir)
    class_names = ['GALAXY', 'QSO', 'STAR']
    y_true_decoded = [class_names[i] for i in np.argmax(y_true, axis=1)]
    cm = confusion_matrix(y_true_decoded, predictions['CLASS_PHOTO'])
    cm_fig = plot_confusion_matrix(cm,
                                   classes=class_names,
                                   normalize=False,
                                   title=None,
                                   return_figure=True)
    cm_image = plot_to_image(cm_fig)
    with file_writer.as_default():
        tf.summary.image('confusion matrix - {}'.format(test_name),
                         cm_image,
                         step=epoch)
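
`plot_confusion_matrix` here matches the signature of the helper from the TensorBoard image-logging tutorial; a sketch of that version (the repository's own helper may differ):

import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, normalize=False, title=None,
                          return_figure=True):
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title or 'Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], 2)

    # label each cell, switching the text color on dark backgrounds
    threshold = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = 'white' if cm[i, j] > threshold else 'black'
        plt.text(j, i, cm[i, j], horizontalalignment='center', color=color)

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    return figure if return_figure else None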
Example 6
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def get_model():
    # the opening of this snippet was truncated; reconstructed up to here
    model = keras.Sequential([
        layers.Conv2D(8, 3, padding="same", activation="relu"),
        layers.Conv2D(16, 3, padding="same", activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dropout(0.1),
        layers.Dense(10),
    ])

    return model


model = get_model()
num_epochs = 1
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam(learning_rate=0.001)
acc_metric = keras.metrics.SparseCategoricalAccuracy()
writer = tf.summary.create_file_writer("logs/train/")
step = 0

for epoch in range(num_epochs):
    for batch_idx, (x, y) in enumerate(ds_train):
        figure = image_grid(x, y, class_names)

        with writer.as_default():
            tf.summary.image(
                "Visualize Images",
                plot_to_image(figure),
                step=step,
            )
            step += 1
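
`image_grid` follows the TensorBoard tutorial pattern of packing the first images of a batch into one figure; a sketch, assuming at least 25 images per batch:

import matplotlib.pyplot as plt


def image_grid(images, labels, class_names):
    figure = plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)
        plt.title(class_names[labels[i]])
    return figure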
Example 7
        D.append([state, acc_reward, action, td_err_default])
        if (episode % 5) == 0:
            with file_writer_rewards.as_default():
                tf.summary.histogram('action_taken', acc_actions, step=episode)
            print(
                f"Episode {episode}: Reward {acc_reward} with action {action_meanings[action]} which was {'explored' if do_explore else 'greedy'}"
            )
            env.render()

    # Wrap up
    loss = history.history.get("loss", [0])[0]
    time_end = np.round(time.time() - start_time, 2)
    memory_usage = process.memory_info().rss
    tmp = random.choice(experience_batch)
    # print(tmp.shape)
    episode_image = plot_to_image(
        image_grid(tmp, env.unwrapped.get_action_meanings()))
    print(f"Loss of episode {episode} is {loss} and took {time_end} seconds")
    print(f"TOTAL REWARD: {np.sum(episode_rewards)}")
    with file_writer_rewards.as_default():
        tf.summary.scalar('episode_rewards',
                          np.sum(episode_rewards),
                          step=episode)
        tf.summary.scalar('episode_loss', loss, step=episode)
        tf.summary.scalar('episode_time_in_secs', time_end, step=episode)
        tf.summary.scalar('episode_nr_frames', frame_cnt, step=episode)
        tf.summary.scalar('episode_exploration_rate',
                          exploration_rate,
                          step=episode)
        tf.summary.scalar('episode_mem_usage', memory_usage, step=episode)
        tf.summary.scalar('episode_mem_usage_in_GB',
                          np.round(memory_usage / 1024 / 1024 / 1024),
                          step=episode)
Example 8
def main():
    args = parse_args()

    # set random seed
    utils.seed_torch(args.seed)

    # Setup CUDA, GPU
    if not torch.cuda.is_available():
        print("cuda is not available")
        exit(0)
    else:
        args.device = torch.device("cuda")
        args.n_gpus = torch.cuda.device_count()
        print(f"available cuda: {args.n_gpus}")

    # Setup model
    model = MelanomaNet(arch=args.arch)
    if args.n_gpus > 1:
        model = torch.nn.DataParallel(module=model)
    model.to(args.device)
    model_path = f'{configure.MODEL_PATH}/{args.arch}_fold_{args.fold}.pth'

    # Setup data
    total_batch_size = args.per_gpu_batch_size * args.n_gpus
    train_loader, valid_loader = datasets.get_dataloader(
        image_dir=configure.TRAIN_IMAGE_PATH,
        fold=args.fold,
        batch_size=total_batch_size,
        num_workers=args.num_workers)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.BCEWithLogitsLoss()
    # criterion = MarginFocalBCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=args.learning_rate,
                                  weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.5)
    """ Train the model """
    current_time = datetime.now().strftime('%b%d_%H_%M_%S')
    log_dir = f'{configure.TRAINING_LOG_PATH}/{args.arch}_fold_{args.fold}_{current_time}'

    tb_writer = None
    if args.log:
        tb_writer = SummaryWriter(log_dir=log_dir)

    print(f'training started: {current_time}')
    best_score = 0.0
    for epoch in range(args.epochs):
        train_loss = train(dataloader=train_loader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           args=args)

        valid_loss, y_true, y_score = valid(dataloader=valid_loader,
                                            model=model,
                                            criterion=criterion,
                                            args=args)

        valid_score = roc_auc_score(y_true=y_true, y_score=y_score)

        learning_rate = scheduler.get_last_lr()[0]  # get_lr() is deprecated here
        if args.log:
            tb_writer.add_scalar("learning_rate", learning_rate, epoch)
            tb_writer.add_scalar("Loss/train", train_loss, epoch)
            tb_writer.add_scalar("Loss/valid", valid_loss, epoch)
            tb_writer.add_scalar("Score/valid", valid_score, epoch)

            # Log the roc curve as an image summary.
            figure = utils.plot_roc_curve(y_true=y_true, y_score=y_score)
            figure = utils.plot_to_image(figure)
            tb_writer.add_image("ROC curve", figure, epoch)

        if valid_score > best_score:
            best_score = valid_score
            state = {
                # DataParallel wraps the model only when multiple GPUs are used
                'state_dict': model.module.state_dict()
                if args.n_gpus > 1 else model.state_dict(),
                'train_loss': train_loss,
                'valid_loss': valid_loss,
                'valid_score': valid_score
            }
            torch.save(state, model_path)

        current_time = datetime.now().strftime('%b%d_%H_%M_%S')
        print(
            f"epoch:{epoch:02d}, "
            f"train:{train_loss:0.3f}, valid:{valid_loss:0.3f}, "
            f"score:{valid_score:0.3f}, best:{best_score:0.3f}, date:{current_time}"
        )

        scheduler.step()

    current_time = datetime.now().strftime('%b%d_%H_%M_%S')
    print(f'training finished: {current_time}')

    if args.log:
        tb_writer.close()
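
`utils.plot_roc_curve` is not shown; a hypothetical version built on scikit-learn, matching the `y_true`/`y_score` keywords used above, could be:

import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve


def plot_roc_curve(y_true, y_score):
    fpr, tpr, _ = roc_curve(y_true, y_score)
    figure = plt.figure()
    plt.plot(fpr, tpr, label=f"AUC = {auc(fpr, tpr):0.3f}")
    plt.plot([0, 1], [0, 1], linestyle="--")  # chance line
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.legend(loc="lower right")
    return figure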
Example 9
def main():
    args = parse_args()

    # set random seed
    utils.seed_torch(args.seed)

    # Setup CUDA, GPU
    if not torch.cuda.is_available():
        print("cuda is not available")
        exit(0)
    else:
        args.device = torch.device("cuda")
        args.n_gpus = torch.cuda.device_count()
        print(f"available cuda: {args.n_gpus}")

    # Setup model
    model = PandaNet(arch=args.arch, num_classes=1)
    model_path = os.path.join(
        configure.MODEL_PATH,
        f'{args.arch}_fold_{args.fold}_{args.tile_size}_{args.num_tiles}.pth')
    if args.resume:
        assert os.path.exists(model_path), "checkpoint does not exist"
        state_dict = torch.load(model_path)
        valid_score = state_dict['valid_score']
        threshold = state_dict['threshold']
        print(
            f"load model from checkpoint, threshold: {threshold}, valid score: {state_dict['valid_score']:0.3f}"
        )
        model.load_state_dict(state_dict['state_dict'])
        best_score = valid_score
        args.learning_rate = 3e-05
    else:
        best_score = 0.0

    if args.n_gpus > 1:
        model = torch.nn.DataParallel(module=model)
    model.to(args.device)

    # Setup data
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    print(f"loading data: {current_time}")
    filename = f"train_images_level_{args.level}_{args.tile_size}_{args.num_tiles}.npy"
    data = np.load(os.path.join(configure.DATA_PATH, filename),
                   allow_pickle=True)
    print(f"data loaded: {datetime.now().strftime('%b%d_%H-%M-%S')}")

    total_batch_size = args.per_gpu_batch_size * args.n_gpus
    train_loader, valid_loader = datasets.get_dataloader(
        data=data,
        fold=args.fold,
        batch_size=total_batch_size,
        num_workers=args.num_workers)

    # define loss function (criterion) and optimizer
    if args.loss == "l1":
        criterion = torch.nn.L1Loss()
    elif args.loss == "mse":
        criterion = torch.nn.MSELoss()
    elif args.loss == "smooth_l1":
        criterion = torch.nn.SmoothL1Loss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=15,
                                                gamma=0.5)
    """ Train the model """
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_prefix = f'{current_time}_{args.arch}_fold_{args.fold}_{args.tile_size}_{args.num_tiles}'
    log_dir = os.path.join(configure.TRAINING_LOG_PATH, log_prefix)

    tb_writer = None
    if args.log:
        tb_writer = SummaryWriter(log_dir=log_dir)

    print(f'training started: {current_time}')
    for epoch in range(args.epochs):
        train_loss = train(dataloader=train_loader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           args=args)

        valid_loss, valid_score, valid_cm, threshold = valid(
            dataloader=valid_loader,
            model=model,
            criterion=criterion,
            args=args)

        learning_rate = scheduler.get_last_lr()[0]  # get_lr() is deprecated here
        if args.log:
            tb_writer.add_scalar("learning_rate", learning_rate, epoch)
            tb_writer.add_scalar("Loss/train", train_loss, epoch)
            tb_writer.add_scalar("Loss/valid", valid_loss, epoch)
            tb_writer.add_scalar("Score/valid", valid_score, epoch)

            # Log the confusion matrix as an image summary.
            figure = utils.plot_confusion_matrix(
                valid_cm, class_names=[0, 1, 2, 3, 4, 5], score=valid_score)
            cm_image = utils.plot_to_image(figure)
            tb_writer.add_image("Confusion Matrix valid", cm_image, epoch)

        if valid_score > best_score:
            best_score = valid_score
            state = {
                # DataParallel wraps the model only when multiple GPUs are used
                'state_dict': model.module.state_dict()
                if args.n_gpus > 1 else model.state_dict(),
                'train_loss': train_loss,
                'valid_loss': valid_loss,
                'valid_score': valid_score,
                'threshold': np.sort(threshold),
                'mean': data.item().get('mean'),
                'std': data.item().get('std')
            }
            torch.save(state, model_path)

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        print(
            f"epoch:{epoch:02d}, "
            f"train:{train_loss:0.3f}, valid:{valid_loss:0.3f}, "
            f"threshold: {np.sort(threshold)}, "
            f"score:{valid_score:0.3f}, best:{best_score:0.3f}, date:{current_time}"
        )

        scheduler.step()

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    print(f'training finished: {current_time}')

    if args.log:
        tb_writer.close()
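
Example 9 treats grading as a regression problem, so the `threshold` values returned by `valid` presumably cut the continuous predictions into six discrete grades. A sketch of that bucketing (a hypothetical helper; the repository's `valid` is not shown):

import numpy as np


def scores_to_grades(y_pred, threshold):
    # hypothetical: bucket continuous scores into grades 0..5 via sorted cut points
    return np.digitize(y_pred, np.sort(threshold))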
Example 10
def train(train_ds, test_ds, class_names):
    """Train capsule networks mirrored across multiple GPUs."""

    # Run training for multiple epochs mirrored on multiple gpus
    strategy = tf.distribute.MirroredStrategy()
    num_replicas = strategy.num_replicas_in_sync

    train_ds = strategy.experimental_distribute_dataset(train_ds)
    test_ds = strategy.experimental_distribute_dataset(test_ds)

    # Create a checkpoint directory to store the checkpoints.
    ckpt_dir = os.path.join(args.log_dir, "ckpt/", "ckpt")

    train_writer = tf.summary.create_file_writer("%s/log/train" % args.log_dir)
    test_writer = tf.summary.create_file_writer("%s/log/test" % args.log_dir)

    with strategy.scope():
        model = CapsNet(args)
        optimizer = tf.optimizers.Adam(learning_rate=args.learning_rate)
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

        # Define metrics
        test_loss = tf.keras.metrics.Mean(name='test_loss')

        # Function for a single training step
        def train_step(inputs):
            x, y = inputs
            with tf.GradientTape() as tape:
                logits, reconstruction, layers = model(x, y)
                loss, _ = compute_loss(logits, y, reconstruction, x)

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            acc = compute_accuracy(logits, y)

            return loss, acc, (x, reconstruction)

        # Function for a single test step
        def test_step(inputs):
            x, y = inputs
            logits, reconstruction, _ = model(x, y)
            loss, _ = compute_loss(logits, y, reconstruction, x)

            test_loss.update_state(loss)
            acc = compute_accuracy(logits, y)

            pred = tf.math.argmax(logits, axis=1)
            cm = tf.math.confusion_matrix(y, pred, num_classes=10)
            return acc, cm

        # Define functions for distributed training
        def distributed_train_step(dataset_inputs):
            return strategy.run(train_step, args=(dataset_inputs, ))

        def distributed_test_step(dataset_inputs):
            return strategy.run(test_step, args=(dataset_inputs, ))

        if args.enable_tf_function:
            distributed_train_step = tf.function(distributed_train_step)
            distributed_test_step = tf.function(distributed_test_step)

        # Loop for multiple epochs
        step = 0
        max_acc = 0.0
        for epoch in range(args.epochs):
            ########################################
            # Test
            ########################################
            if args.test:
                cm = np.zeros((10, 10))
                test_acc = []
                for data in test_ds:
                    distr_acc, distr_cm = distributed_test_step(data)
                    for r in range(num_replicas):
                        if num_replicas > 1:
                            cm += distr_cm.values[r]
                            test_acc.append(distr_acc.values[r].numpy())
                        else:
                            cm += distr_cm
                            test_acc.append(distr_acc)

                # Log test results (for replica 0 only for activation map and reconstruction)
                test_acc = np.mean(test_acc)
                max_acc = test_acc if test_acc > max_acc else max_acc
                figure = utils.plot_confusion_matrix(cm, class_names)  # cm is already a NumPy array
                cm_image = utils.plot_to_image(figure)
                print("TEST | epoch %d (%d): acc=%.4f, loss=%.4f" %
                      (epoch, step, test_acc, test_loss.result()),
                      flush=True)

                with test_writer.as_default():
                    tf.summary.image("Confusion Matrix", cm_image, step=step)
                    tf.summary.scalar("General/Accuracy", test_acc, step=step)
                    tf.summary.scalar("General/Loss",
                                      test_loss.result(),
                                      step=step)
                test_loss.reset_states()
                test_writer.flush()

            ########################################
            # Train
            ########################################
            for data in train_ds:
                start = time.time()
                distr_loss, distr_acc, distr_imgs = distributed_train_step(
                    data)
                train_loss = tf.reduce_mean(
                    distr_loss.values) if num_replicas > 1 else distr_loss
                acc = tf.reduce_mean(
                    distr_acc.values) if num_replicas > 1 else distr_acc

                # Logging
                if step % 100 == 0:
                    # `start` is reset every iteration, so this is one step's time
                    time_per_step = (time.time() - start) * 1000
                    print(
                        "TRAIN | epoch %d (%d): acc=%.4f, loss=%.4f | Time per step[ms]: %.2f"
                        %
                        (epoch, step, acc, train_loss.numpy(), time_per_step),
                        flush=True)

                    # Create some recon tensorboard images (only GPU 0)
                    if args.use_reconstruction:
                        x = distr_imgs[0].values[
                            0] if num_replicas > 1 else distr_imgs[0]
                        recon_x = distr_imgs[1].values[
                            0] if num_replicas > 1 else distr_imgs[1]
                        recon_x = tf.reshape(recon_x, [
                            -1,
                            tf.shape(x)[1],
                            tf.shape(x)[2], args.img_depth
                        ])
                        x = tf.reshape(x, [
                            -1,
                            tf.shape(x)[1],
                            tf.shape(x)[2], args.img_depth
                        ])
                        img = tf.concat([x, recon_x], axis=1)
                        with train_writer.as_default():
                            tf.summary.image(
                                "X & Recon",
                                img,
                                step=step,
                                max_outputs=3,
                            )

                    with train_writer.as_default():
                        # Write scalars
                        tf.summary.scalar("General/Accuracy", acc, step=step)
                        tf.summary.scalar("General/Loss",
                                          train_loss.numpy(),
                                          step=step)

                    start = time.time()
                    train_writer.flush()

                step += 1

            ####################
            # Checkpointing
            if epoch % 15 == 0:
                checkpoint.save(ckpt_dir)

        return max_acc
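
As an aside, the manual `num_replicas > 1` branching above when averaging per-replica results can be avoided with `tf.distribute.Strategy.reduce`, which handles both cases, e.g.:

# works for one or many replicas; no .values indexing needed
train_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, distr_loss, axis=None)
acc = strategy.reduce(tf.distribute.ReduceOp.MEAN, distr_acc, axis=None)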
Example 11
def train(train_ds, all_test_ds, class_names):
  """Train gamma-capsule networks mirrored across multiple GPUs."""
  test_ds, test_ds_0 = all_test_ds

  # Run training for multiple epochs mirrored on multiple gpus
  strategy = tf.distribute.MirroredStrategy()
  num_replicas = strategy.num_replicas_in_sync
  train_ds = strategy.experimental_distribute_dataset(train_ds)
  test_ds = strategy.experimental_distribute_dataset(test_ds)
  test_ds_0 = strategy.experimental_distribute_dataset(test_ds_0)

  # Create a checkpoint directory to store the checkpoints.
  ckpt_dir = os.path.join(args.log_dir, "ckpt/", "ckpt")

  train_writer = tf.summary.create_file_writer("%s/log/train" % args.log_dir)
  test_writer = tf.summary.create_file_writer("%s/log/test [0-9]" % args.log_dir)
  test_writer_0 = tf.summary.create_file_writer("%s/log/test [0]" % args.log_dir)

  with strategy.scope():
    model = GammaCapsuleNetwork(args.num_classes)
    optimizer = optimizers.Adam(learning_rate=args.learning_rate)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

    # Define metrics 
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    train_t_score = tf.keras.metrics.Mean(name='train_t_score')
    train_d_score = tf.keras.metrics.Mean(name='train_d_score')
    
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_t_score = tf.keras.metrics.Mean(name='test_t_score')
    test_d_score = tf.keras.metrics.Mean(name='test_d_score')
    
    # Function for a single training step
    def train_step(inputs):
      # Note: here we do empirical risk minimization under attack
      x, y = inputs
      x_adv = attack.pgd(x, y, model, eps=0.1, a=0.01, k=40) if args.gamma_robust else x      
      with tf.GradientTape() as tape:
        logits, reconstruction, _, T, D = model(x_adv, y)

        # We want to reconstruct the original x rather than x_adv, because
        # small perturbations should represent the same image
        loss, _ = compute_loss(logits, y, reconstruction, x)
      
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      train_accuracy.update_state(y, logits)
      train_t_score.update_state(T)
      train_d_score.update_state(D)
      return loss, (x, x_adv, reconstruction)

    # Function for a single test step
    def test_step(inputs):
      x, y = inputs
      logits, reconstruction, layers, T, D = model(x, y)
      loss, _ = compute_loss(logits, y, reconstruction, x)
      
      test_accuracy.update_state(y, logits)
      test_loss.update_state(loss)
      test_t_score.update_state(T)
      test_d_score.update_state(D)

      pred = tf.math.argmax(logits, axis=1)
      cm = tf.math.confusion_matrix(y, pred, num_classes=args.num_classes)

      return cm, layers

    # Define functions for distributed training
    def distributed_train_step(dataset_inputs):
      # strategy.run replaces the deprecated experimental_run_v2
      return strategy.run(train_step, args=(dataset_inputs,))
      # return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

    def distributed_test_step(dataset_inputs):
      return strategy.run(test_step, args=(dataset_inputs,))
    
    if args.enable_tf_function:
      distributed_train_step = tf.function(distributed_train_step)
      distributed_test_step = tf.function(distributed_test_step)

    # Loop for multiple epochs
    step = 0
    for epoch in range(args.epochs):
      ########################################
      # Test [0-9]
      ########################################
      cm = np.zeros((args.num_classes, args.num_classes))
      for data in test_ds:
        distr_cm, distr_layers = distributed_test_step(data)
        for r in range(num_replicas):
          cm += distr_cm.values[r]

      # Log test results (activation map and reconstruction for replica 0 only)
      figure = utils.plot_confusion_matrix(cm, class_names)  # cm is already a NumPy array
      figure = utils.plot_confusion_matrix(cm.numpy(), class_names)
      cm_image = utils.plot_to_image(figure)
      with test_writer.as_default():
        tf.summary.image("Confusion Matrix", cm_image, step=step)
        tf.summary.image(
          "Activation Map",
          v_map(distr_layers[0].values[0]),  # replica 0 only
          step=step,
          max_outputs=1,)

      print("TEST [0-9] | epoch %d (%d): acc=%.2f, loss=%.3f, T=%.2f, D=%.2f" % 
            (epoch, step, test_accuracy.result(), test_loss.result(), test_t_score.result(), 
            test_d_score.result()), flush=True)  

      with test_writer.as_default(): 
        tf.summary.scalar("General/Accuracy", test_accuracy.result(), step=step)
        tf.summary.scalar("General/Loss", test_loss.result(), step=step)
        tf.summary.scalar("Gamma-Metrics/T-Score", test_t_score.result(), step=step)
        tf.summary.scalar("Gamma-Metrics/D-Score", test_d_score.result(), step=step)
      test_accuracy.reset_states()
      test_loss.reset_states()
      test_t_score.reset_states()
      test_d_score.reset_states()
      test_writer.flush()

      ########################################
      # Test [0]
      ########################################
      for data in test_ds_0:
        _, distr_layers = distributed_test_step(data)

      # Log test results
      with test_writer_0.as_default(): 
        tf.summary.image(
          "Activation Map",
          v_map(distr_layers[0].values[0]),
          step=step,
          max_outputs=1,)

      print("TEST [0] | epoch %d (%d): acc=%.2f, loss=%.3f, T=%.2f, D=%.2f" % 
            (epoch, step, test_accuracy.result(), test_loss.result(), test_t_score.result(), 
            test_d_score.result()), flush=True)  
      with test_writer_0.as_default(): 
        tf.summary.scalar("General/Accuracy", test_accuracy.result(), step=step)
        tf.summary.scalar("General/Loss", test_loss.result(), step=step)
        tf.summary.scalar("Gamma-Metrics/T-Score", test_t_score.result(), step=step)
        tf.summary.scalar("Gamma-Metrics/D-Score", test_d_score.result(), step=step)

      test_accuracy.reset_states()
      test_loss.reset_states()
      test_t_score.reset_states()
      test_d_score.reset_states()
      test_writer_0.flush()

      ########################################
      # Train
      ########################################
      for data in train_ds:
        start = time.time()
        distr_loss, distr_imgs = distributed_train_step(data)
        train_loss = 0
        for r in range(num_replicas):
          train_loss += distr_loss.values[r]        

        if step % 100 == 0:
          # Show some inputs, adversarial inputs and reconstructions
          # `start` is reset every iteration, so this is one step's time
          time_per_step = (time.time() - start) * 1000
          print("TRAIN | epoch %d (%d): acc=%.2f, loss=%.3f, T=%.2f, D=%.2f | Time per step[ms]: %.2f" % 
              (epoch, step, train_accuracy.result(), train_loss.numpy(), 
                train_t_score.result(), train_d_score.result(), time_per_step), flush=True)     

          # Create recon tensorboard images
          x, x_adv, recon_x = distr_imgs[0].values[0], distr_imgs[1].values[0], distr_imgs[2].values[0]
          recon_x = tf.reshape(recon_x, [-1, tf.shape(x)[1], tf.shape(x)[2]])  
          img = tf.concat([x, x_adv, recon_x], axis=1)
          img = tf.expand_dims(img, -1)

          with train_writer.as_default(): 
            tf.summary.scalar("General/Accuracy", train_accuracy.result(), step=step)
            tf.summary.scalar("General/Loss", train_loss.numpy(), step=step)
            tf.summary.scalar("Gamma-Metrics/T-Score", train_t_score.result(), step=step)
            tf.summary.scalar("Gamma-Metrics/D-Score", train_d_score.result(), step=step)
            tf.summary.image(
              "X & XAdv & Recon",
              img,
              step=step,
              max_outputs=3,)

          train_accuracy.reset_states()
          train_t_score.reset_states()
          train_d_score.reset_states()
          start = time.time()

          train_writer.flush()

        step += 1
      
      ####################
      # Checkpointing (every epoch)
      checkpoint.save(ckpt_dir)
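
`attack.pgd`, used in `train_step` above, is not shown. A minimal sketch of projected gradient descent under the assumed interface (`eps` = L-infinity budget, `a` = step size, `k` = iterations), reusing the snippet's `compute_loss`:

import tensorflow as tf


def pgd(x, y, model, eps=0.1, a=0.01, k=40):
    x_adv = tf.identity(x)
    for _ in range(k):
        with tf.GradientTape() as tape:
            tape.watch(x_adv)
            logits, reconstruction, *_ = model(x_adv, y)
            loss, _ = compute_loss(logits, y, reconstruction, x_adv)
        grad = tape.gradient(loss, x_adv)
        # ascend the loss, then project back into the eps-ball and valid pixel range
        x_adv = x_adv + a * tf.sign(grad)
        x_adv = tf.clip_by_value(x_adv, x - eps, x + eps)
        x_adv = tf.clip_by_value(x_adv, 0.0, 1.0)
    return x_adv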
Example 12

    next_q = set_of_batch_rewards + (discount_rate *
                                     tf.reduce_max(next_q_values, axis=1))
    history = approximator_model.fit(
        [set_of_batch_initial_states, set_of_batch_actions],
        next_q,
        verbose=1,
        callbacks=[tensorflow_callback])

    # Wrap up
    loss = history.history.get("loss", [0])[0]
    time_end = np.round(time.time() - start_time, 2)
    memory_usage = process.memory_info().rss
    tmp = random.choice(experience_batch)
    # print(tmp.shape)
    episode_image = plot_to_image(image_grid(tmp, action_meanings))

    print(f"Current memory consumption is {memory_usage}")
    print(f"Loss of episode {episode} is {loss} and took {time_end} seconds")
    print(f"TOTAL REWARD: {np.sum(episode_rewards)}")
    with file_writer_rewards.as_default():
        tf.summary.scalar('episode_rewards',
                          np.sum(episode_rewards),
                          step=episode)
        tf.summary.scalar('episode_loss', loss, step=episode)
        tf.summary.scalar('episode_time_in_secs', time_end, step=episode)
        tf.summary.scalar('episode_nr_frames', frame_cnt, step=episode)
        tf.summary.scalar('episode_exploration_rate',
                          exploration_rate,
                          step=episode)
        tf.summary.scalar('episode_mem_usage', memory_usage, step=episode)