Example #1
        # Estimate the wall-clock minutes remaining. (Reconstructed statement;
        # TRAINING_STEPS is assumed to hold the total number of training steps.)
        estimated_minutes_remaining = ((TRAINING_STEPS - training_step) *
                                       running_batch_elapsed_time) / 60.0

        print(
            "===== TRAINING STEP {} | ~{:.0f} MINUTES REMAINING =====".format(
                training_step, estimated_minutes_remaining))
        print("CRITIC LOSS:     {0}".format(running_critic_loss))
        print("GENERATOR LOSS:  {0}\n".format(running_generator_loss))

        # Loss histories
        critic_losses_per_vis_interval.append(running_critic_loss)
        generator_losses_per_vis_interval.append(running_generator_loss)
        running_critic_loss = 0.0
        running_generator_loss = 0.0

        Plot.plot_histories(
            [critic_losses_per_vis_interval], ["Critic"],
            "{0}critic_loss_history.png".format(MODEL_OUTPUT_DIR))
        Plot.plot_histories(
            [generator_losses_per_vis_interval], ["Generator"],
            "{0}generator_loss_history.png".format(MODEL_OUTPUT_DIR))

        # Save model at checkpoint
        torch.save(generator.state_dict(),
                   "{0}generator".format(MODEL_OUTPUT_DIR))
        torch.save(critic.state_dict(), "{0}critic".format(MODEL_OUTPUT_DIR))
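        # These checkpoints can later be restored with the standard PyTorch
        # pattern, e.g. (sketch):
        #   generator.load_state_dict(
        #       torch.load("{0}generator".format(MODEL_OUTPUT_DIR)))
        #   critic.load_state_dict(
        #       torch.load("{0}critic".format(MODEL_OUTPUT_DIR)))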

        # Upsample and save samples
        sample_tags = brainpedia.preprocessor.decode_label(
            labels_batch.data[0])
        real_sample_data = real_brain_img_data_batch[0].cpu().data.numpy().squeeze()
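
# NOTE: `Plot.plot_histories` is a project-local helper whose implementation is
# not part of this excerpt. A minimal matplotlib-based sketch that matches how
# it is called above (a list of histories, matching labels, an output path):
import matplotlib
matplotlib.use('Agg')  # render to files without a display
import matplotlib.pyplot as plt


def plot_histories(histories, labels, output_path):
    """Plot one curve per history and write the figure to output_path."""
    plt.figure()
    for history, label in zip(histories, labels):
        plt.plot(history, label=label)
    plt.xlabel('Visualization interval')
    plt.legend()
    plt.savefig(output_path)
    plt.close()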

Example #2

        # (Reconstructed print; the label text is assumed.)
        print("NN MIXED CLASSIFIER LOSS:                  {0}".format(
            classifier_running_losses[2]))

        print("NN CLASSIFIER TEST ACCURACY:               {0:.2f}%".format(
            100.0 * accuracies[0]))
        print("NN SYNTHETIC CLASSIFIER TEST ACCURACY:     {0:.2f}%".format(
            100.0 * accuracies[1]))
        print("NN MIXED CLASSIFIER TEST ACCURACY:         {0:.2f}%\n\n".format(
            100.0 * accuracies[2]))

        # Loss histories
        for i in range(num_classifiers):
            classifier_losses[i].append(classifier_running_losses[i])
            classifier_running_losses[i] = 0.0

        Plot.plot_histories(
            classifier_losses,
            ['[REAL] Loss', '[SYNTHETIC] Loss', '[REAL + SYNTHETIC] Loss'],
            "{0}loss_histories".format(args.output_dir))
        Plot.plot_histories(
            classifier_accuracies,
            ['[REAL] Test Accuracy', '[SYNTHETIC] Test Accuracy',
             '[REAL + SYNTHETIC] Test Accuracy'],
            "{0}accuracy_histories".format(args.output_dir))

        # Save model at checkpoint
        torch.save(classifiers[0].state_dict(),
                   "{0}nn_classifier".format(args.output_dir))
        torch.save(classifiers[1].state_dict(),
                   "{0}synthetic_nn_classifier".format(args.output_dir))
        torch.save(classifiers[2].state_dict(),
                   "{0}mixed_nn_classifier".format(args.output_dir))

# Save final NN classifier results to results_f:
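# (The write itself lies outside this excerpt; a minimal sketch, assuming
#  `results_f` is an open, writable text file and `accuracies` holds the final
#  test accuracies computed above:)
results_f.write("[REAL] accuracy: {0:.4f}\n".format(accuracies[0]))
results_f.write("[SYNTHETIC] accuracy: {0:.4f}\n".format(accuracies[1]))
results_f.write("[REAL + SYNTHETIC] accuracy: {0:.4f}\n".format(accuracies[2]))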