    def __init__(self, parser, data):
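        """Build the augmentation-capable classification experiment.

        Parses the command-line arguments, creates the placeholders for the
        query and support batches, optionally instantiates the GAN used to
        synthesize extra training images, and wires up the DenseNet
        classifier together with the per-split batch counts.
        """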
        tf.reset_default_graph()
        args = parser.parse_args()
        self.continue_from_epoch = args.continue_from_epoch
        self.experiment_name = args.experiment_title
        self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(
            self.experiment_name)
        self.num_gpus = args.num_of_gpus
        self.batch_size = args.batch_size
        gen_depth_per_layer = args.generator_inner_layers
        discr_depth_per_layer = args.discriminator_inner_layers
        self.z_dim = args.z_dim
        self.num_generations = args.num_generations
        self.dropout_rate_value = args.dropout_rate_value
        self.data = data
        self.reverse_channels = False
        # self.support_number = args.support_number
        self.classification_total_epoch = args.classification_total_epoch
        self.use_wide_connections = args.use_wide_connections

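        # Per-block channel widths for the generator and discriminator; the
        # number of inner layers in each block comes from the command line.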
        generator_layers = [64, 64, 128, 128]
        self.discriminator_layers = [64, 64, 128, 128]

        gen_inner_layers = [
            gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
            gen_depth_per_layer
        ]
        self.discr_inner_layers = [
            discr_depth_per_layer, discr_depth_per_layer,
            discr_depth_per_layer, discr_depth_per_layer
        ]
        generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]

        image_height = data.image_height
        image_width = data.image_width
        image_channel = data.image_channel

        self.support_number = args.support_number
        self.selected_classes = args.selected_classes
        self.general_classification_samples = args.general_classification_samples

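        # Scalar placeholders describing the episode set-up (number of classes,
        # the currently selected class, and the support-set size).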
        self.classes = tf.placeholder(tf.int32)
        self.selected_class = tf.placeholder(tf.int32)
        self.number_support = tf.placeholder(tf.int32)

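        # Query images with their episode-level (selected classes) and global
        # (all testing classes) one-hot labels.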
        self.input_x_i = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, image_height, image_width,
            image_channel
        ], 'batch')
        self.input_y_i = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.data.selected_classes],
            'y_inputs_batch')
        self.input_global_y_i = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.data.testing_classes],
            'y_inputs_batch_global')

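        # Support set: selected_classes * support_number images per query batch,
        # again with both episode-level and global labels.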
        self.input_x_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.data.selected_classes *
            self.data.support_number, image_height, image_width, image_channel
        ], 'support')
        self.input_y_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.data.selected_classes *
            self.data.support_number, self.data.selected_classes
        ], 'y_inputs_support')
        self.input_global_y_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.data.selected_classes *
            self.data.support_number, self.data.testing_classes
        ], 'y_inputs_support_global')

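        # A single selected support image per query (fed to the discriminator
        # branch), plus its global label.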
        self.input_x_j_selected = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, image_height, image_width,
            image_channel
        ], 'support_discriminator')
        self.input_global_y_j_selected = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.data.testing_classes],
            'y_inputs_support_discriminator')

        #### Placeholders for the MatchingGAN, mainly for the support images.
        self.input_y_i_dagan = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.selected_classes],
            'y_inputs_batch_dagan')
        self.input_x_j_dagan = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.support_number, image_height,
            image_width, image_channel
        ], 'support_dagan')
        self.input_y_j_dagan = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.support_number,
            self.data.selected_classes
        ], 'y_inputs_support_dagan')
        self.input_global_y_j_dagan = tf.placeholder(
            tf.float32, [
                self.num_gpus, self.batch_size, self.support_number,
                self.data.testing_classes
            ], 'y_inputs_support_global_dagan')

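        # Latent noise inputs for the generator.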
        self.z_input = tf.placeholder(tf.float32,
                                      [self.batch_size, self.z_dim], 'z-input')
        self.z_input_2 = tf.placeholder(tf.float32,
                                        [self.batch_size, self.z_dim],
                                        'z-input_2')

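        # Integer feeds that presumably switch the augmentation count, the
        # confidence filtering, and the discriminator-loss variant at run time.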
        self.feed_augmented = tf.placeholder(tf.int32)
        self.feed_confidence = tf.placeholder(tf.int32)
        self.feed_loss_d = tf.placeholder(tf.int32)

        # self.selected_loss_d = tf.placeholder(tf.int32)
        # self.selected_confidence = tf.placeholder(tf.int32)
        # self.number_augmented = tf.placeholder(tf.int32)
        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
        self.z1z2_training = tf.placeholder(tf.bool, name='z1z2_training-flag')
        self.is_z2 = args.is_z2
        self.is_z2_vae = args.is_z2_vae

        self.loss_G = args.loss_G
        self.loss_D = args.loss_D
        self.loss_CLA = args.loss_CLA
        self.loss_FSL = args.loss_FSL
        self.loss_KL = args.loss_KL
        self.loss_recons_B = args.loss_recons_B
        self.loss_matching_G = args.loss_matching_G
        self.loss_matching_D = args.loss_matching_D
        self.loss_sim = args.loss_sim
        self.strategy = args.strategy

        self.is_fewshot_setting = args.is_fewshot_setting
        # self.few_shot_episode_classes = args.few_shot_episode_classes
        self.few_shot_episode_classes = args.selected_classes
        self.confidence = args.confidence
        self.loss_d = args.loss_d
        self.augmented_number = args.augmented_number
        self.matching = args.matching
        self.fce = args.fce
        self.full_context_unroll_k = args.full_context_unroll_k
        self.average_per_class_embeddings = args.average_per_class_embeddings
        self.restore_path = args.restore_path
        self.restore_classifier_path = args.restore_classifier_path
        self.episodes = args.episodes_number

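        # The GAN is only built when extra (augmented) samples are requested;
        # its sampling op generates same-class images from the support set.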
        if self.augmented_number > 0:
            dagan = DAGAN(batch_size=self.batch_size, input_x_i=self.input_x_i, input_x_j=self.input_x_j_dagan,
                      input_y_i=self.input_y_i_dagan, input_y_j=self.input_y_j_dagan, input_global_y_i=self.input_global_y_i,
                      input_global_y_j=self.input_global_y_j_dagan,
                      input_x_j_selected=self.input_x_j_selected,
                      input_global_y_j_selected=self.input_global_y_j_selected,
                      selected_classes=self.selected_classes, support_num=self.support_number,
                      classes=self.data.training_classes,
                      dropout_rate=self.dropout_rate, generator_layer_sizes=generator_layers,
                      generator_layer_padding=generator_layer_padding, num_channels=data.image_channel,
                      is_training=self.training_phase, augment=self.random_rotate,
                      discriminator_layer_sizes=self.discriminator_layers,
                      discr_inner_conv=self.discr_inner_layers, is_z2=self.is_z2, is_z2_vae=self.is_z2_vae,
                      gen_inner_conv=gen_inner_layers, num_gpus=self.num_gpus, z_dim=self.z_dim, z_inputs=self.z_input,
                      z_inputs_2=self.z_input_2,
                      use_wide_connections=args.use_wide_connections, fce=self.fce, matching=self.matching,
                      full_context_unroll_k=self.full_context_unroll_k,
                      average_per_class_embeddings=self.average_per_class_embeddings,
                      loss_G=self.loss_G, loss_D=self.loss_D, loss_KL=self.loss_KL, loss_recons_B=self.loss_recons_B,
                      loss_matching_G=self.loss_matching_G, loss_matching_D=self.loss_matching_D,
                      loss_CLA=self.loss_CLA, loss_FSL=self.loss_FSL, loss_sim=self.loss_sim,
                      z1z2_training=self.z1z2_training)
            self.same_images = dagan.sample_same_images()

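        # Either an episodic few-shot classifier over the selected episode
        # classes, or a general classifier over all testing classes.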
        if self.is_fewshot_setting:
            print('fewshot classifier categories:',
                  self.few_shot_episode_classes)
            classifier = densenet_classifier(
                input_x_i=self.input_x_i,
                input_y=self.input_y_i,
                classes=self.few_shot_episode_classes,
                batch_size=self.batch_size,
                layer_sizes=self.discriminator_layers,
                inner_layers=self.discr_inner_layers,
                num_gpus=self.num_gpus,
                use_wide_connections=args.use_wide_connections,
                is_training=self.training_phase,
                augment=self.random_rotate,
                dropout_rate=self.dropout_rate)
        else:
            print('general classifier categories:', self.data.testing_classes)
            classifier = densenet_classifier(
                input_x_i=self.input_x_i,
                input_y=self.input_global_y_i,
                classes=self.data.testing_classes,
                batch_size=self.batch_size,
                layer_sizes=self.discriminator_layers,
                inner_layers=self.discr_inner_layers,
                num_gpus=self.num_gpus,
                use_wide_connections=args.use_wide_connections,
                is_training=self.training_phase,
                augment=self.random_rotate,
                dropout_rate=self.dropout_rate)

        self.summary, self.losses, self.accuracy, self.graph_ops = classifier.init_train()

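        # Batches per epoch for each split, given the effective batch size
        # (batch_size * num_gpus).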
        self.total_train_batches = int(data.training_data_size /
                                       (self.batch_size * self.num_gpus))
        self.total_val_batches = int(data.validation_data_size /
                                     (self.batch_size * self.num_gpus))
        self.total_test_batches = int(data.testing_data_size /
                                      (self.batch_size * self.num_gpus))
        self.total_gen_batches = int(data.testing_data_size /
                                     (self.batch_size * self.num_gpus))
        self.init = tf.global_variables_initializer()
        self.spherical_interpolation = False

        self.tensorboard_update_interval = int(self.total_test_batches / 10 /
                                               self.num_gpus)

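        # Fixed noise vectors, presumably reused for qualitative generations.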
        self.z_vectors = np.random.normal(size=(10, self.z_dim))
        self.z_vectors_2 = np.random.normal(size=(10, self.z_dim))

        self.z_inputs = np.random.randn(self.batch_size, self.z_dim)
        self.z_inputs_2 = np.random.randn(self.batch_size, self.z_dim)

        self.total_test_items = int(
            self.general_classification_samples / self.data.support_number) + 1

    def __init__(self, parser, data):
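        """Build the baseline classification experiment.

        Parses the command-line arguments, creates the input placeholders, and
        wires up the DenseNet classifier over the selected classes, without
        the GAN-based augmentation path.
        """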
        tf.reset_default_graph()
        args = parser.parse_args()
        self.continue_from_epoch = args.continue_from_epoch
        self.experiment_name = args.experiment_title
        self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(
            self.experiment_name)
        self.num_gpus = args.num_of_gpus
        self.batch_size = args.batch_size
        gen_depth_per_layer = args.generator_inner_layers
        discr_depth_per_layer = args.discriminator_inner_layers
        self.z_dim = args.z_dim
        self.num_generations = args.num_generations
        self.dropout_rate_value = args.dropout_rate_value
        self.data = data
        self.reverse_channels = False
        self.support_number = args.support_number
        self.classification_total_epoch = args.classification_total_epoch
        image_channel = data.image_channel
        self.use_wide_connections = args.use_wide_connections
        self.pretrain = args.pretrain

        generator_layers = [64, 64, 128, 128]
        self.discriminator_layers = [64, 64, 128, 128]

        gen_inner_layers = [
            gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
            gen_depth_per_layer
        ]
        self.discr_inner_layers = [
            discr_depth_per_layer, discr_depth_per_layer,
            discr_depth_per_layer, discr_depth_per_layer
        ]
        generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]

        image_height = self.data.image_height
        image_width = self.data.image_width
        image_channels = self.data.image_channel

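        # Classification inputs: the per-GPU batch is flattened over the
        # selected classes (batch_size * selected_classes images per GPU).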
        self.classes = tf.placeholder(tf.int32)
        self.input_x_i = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size * self.data.selected_classes,
            image_height, image_width, image_channels
        ], 'inputs-1')

        self.input_y = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size * self.data.selected_classes,
            self.data.selected_classes
        ], 'y_inputs-1')

        self.input_x_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.support_number, image_height,
            image_width, image_channels
        ], 'inputs-2-same-class')

        self.z_input = tf.placeholder(tf.float32,
                                      [self.batch_size, self.z_dim], 'z-input')
        self.z_input_2 = tf.placeholder(tf.float32,
                                        [self.batch_size, self.z_dim],
                                        'z-input_2')

        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

        self.matching = args.matching
        self.fce = args.fce
        self.full_context_unroll_k = args.full_context_unroll_k
        self.average_per_class_embeddings = args.average_per_class_embeddings

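        # Batches per epoch for each split; the test split size (5 * 545
        # samples) is hard-coded, presumably to match the evaluation protocol
        # of the dataset in use.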
        self.total_train_batches = int(data.training_data_size /
                                       (self.batch_size * self.num_gpus))
        self.total_val_batches = int(data.validation_data_size /
                                     (self.batch_size * self.num_gpus))
        self.total_test_batches = int(5 * 545 / (self.batch_size * self.num_gpus))
        self.total_gen_batches = int(data.generation_data_size /
                                     (self.batch_size * self.num_gpus))
        self.spherical_interpolation = True

        self.tensorboard_update_interval = int(self.total_test_batches / 10 /
                                               self.num_gpus)

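        # DenseNet classifier trained directly on the selected classes.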
        classifier = densenet_classifier(
            input_x_i=self.input_x_i,
            input_y=self.input_y,
            classes=self.data.selected_classes,
            batch_size=self.batch_size,
            layer_sizes=self.discriminator_layers,
            inner_layers=self.discr_inner_layers,
            num_gpus=self.num_gpus,
            use_wide_connections=self.use_wide_connections,
            is_training=self.training_phase,
            augment=self.random_rotate,
            dropout_rate=self.dropout_rate)
        print('classes', self.data.selected_classes)

        self.summary, self.losses, self.accuracy, self.graph_ops = classifier.init_train()
        self.init = tf.global_variables_initializer()