Example #1
    def __init__(self, args, data, model, device):
        """
        Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system
        (model) and a device (e.g. gpu/cpu)
        :param args: A namedtuple containing all experiment hyperparameters
        :param data: A data provider of instance MetaLearningSystemDataLoader
        :param model: A meta learning system instance
        :param device: Device/s to use for the experiment
        """
        self.args, self.device = args, device

        self.model = model
        self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
            experiment_name=self.args.experiment_name)

        self.total_losses = dict()
        self.state = dict()
        self.state['best_val_acc'] = 0.
        self.state['best_val_iter'] = 0
        self.state['current_iter'] = 0
        self.start_epoch = 0
        self.max_models_to_save = self.args.max_models_to_save
        self.create_summary_csv = False

        if self.args.continue_from_epoch == 'from_scratch':
            self.create_summary_csv = True

        elif self.args.continue_from_epoch == 'latest':
            checkpoint = os.path.join(self.saved_models_filepath, "train_model_latest")
            print("attempting to find existing checkpoint", )
            if os.path.exists(checkpoint):
                self.state = \
                    self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                          model_idx='latest')
                self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)

            else:
                self.args.continue_from_epoch = 'from_scratch'
                self.create_summary_csv = True
        elif int(self.args.continue_from_epoch) >= 0:
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                      model_idx=self.args.continue_from_epoch)
            self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)

        self.data = data(args=args, current_iter=self.state['current_iter'])

        print("train_seed {}, val_seed: {}, at start time".format(self.data.dataset.seed["train"],
                                                                  self.data.dataset.seed["val"]))
        self.total_epochs_before_pause = self.args.total_epochs_before_pause
        self.state['best_epoch'] = int(self.state['best_val_iter'] / self.args.total_iter_per_epoch)
        self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)
        self.augment_flag = 'omniglot' in self.args.dataset_name.lower()
        self.start_time = time.time()
        self.epochs_done_in_this_run = 0
        print(self.state['current_iter'], int(self.args.total_iter_per_epoch * self.args.total_epochs))
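
# Hypothetical usage sketch (not part of the original snippet): assuming the constructor
# above belongs to a class named ExperimentBuilder, and that the project provides parsed
# `args`, a MetaLearningSystemDataLoader class and a meta-learning `model`, wiring it up
# might look roughly like:
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     experiment = ExperimentBuilder(args=args, data=MetaLearningSystemDataLoader,
#                                    model=model, device=device)
#     experiment.run_experiment()  # run_experiment() is an assumed method name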
Example #2
import glob
import os
import random
import tarfile

import numpy as np
import torch

# Project-local helpers such as get_base_argument_parser, process_args, select_devices
# and build_experiment_folder are assumed to be importable from the surrounding codebase.


def housekeeping():
    argument_parser = get_base_argument_parser()
    args = process_args(argument_parser)

    if args.gpu_ids_to_use is None:
        select_devices(
            args.num_gpus_to_use,
            max_load=args.max_gpu_selection_load,
            max_memory=args.max_gpu_selection_memory,
            exclude_gpu_ids=args.excude_gpu_list,
        )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids_to_use.replace(
            " ", ",")

    saved_models_filepath, logs_filepath, images_filepath = build_experiment_folder(
        experiment_name=args.experiment_name, log_path=args.logs_path)

    args.saved_models_filepath = saved_models_filepath
    args.logs_filepath = logs_filepath
    args.images_filepath = images_filepath

    # Determinism: seeding can be fiddly in PyTorch at the moment. In my experience,
    # the seeding below is enough for deterministic experimentation.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)  # set seed
    random.seed(args.seed)
    device = (torch.cuda.current_device() if torch.cuda.is_available()
              and args.num_gpus_to_use > 0 else "cpu")
    args.device = device
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True
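        # For stricter reproducibility it can also help to set
        # torch.backends.cudnn.benchmark = False (avoids nondeterministic algorithm selection).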

    # Always save a snapshot of the current state of the code. I've found this helps
    # immensely when one of your many experiments turns out to be good but you've
    # forgotten exactly what code produced it.

    snapshot_filename = f"{saved_models_filepath}/snapshot.tar.gz"
    filetypes_to_include = [".py"]
    all_files = []
    for filetype in filetypes_to_include:
        all_files += glob.glob("**/*" + filetype, recursive=True)
    with tarfile.open(snapshot_filename, "w:gz") as tar:
        for file in all_files:
            tar.add(file)

    return args
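
# Hypothetical entry-point sketch (not part of the original function): housekeeping() is
# meant to be called once at start-up; the returned args then drive the rest of the script.
# `train(args)` is an assumed placeholder for the project's training loop.
#
#     if __name__ == "__main__":
#         args = housekeeping()
#         train(args)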
Example #3
# Set up our data providers
train_data = CIFAR10DataProvider(which_set="train",
                                 batch_size=batch_size,
                                 rng=rng)
val_data = CIFAR10DataProvider(which_set="valid",
                               batch_size=batch_size,
                               rng=rng)
test_data = CIFAR10DataProvider(which_set="test",
                                batch_size=batch_size,
                                rng=rng)

print("Running {}".format(experiment_name))
print("Starting from epoch {}".format(continue_from_epoch))

saved_models_filepath, logs_filepath = build_experiment_folder(
    experiment_name, logs_path)  # generate experiment dir

# Placeholder setup
data_inputs = tf.placeholder(tf.float32, [
    batch_size, train_data.inputs.shape[1], train_data.inputs.shape[2],
    train_data.inputs.shape[3]
], 'data-inputs')
data_targets = tf.placeholder(tf.int32, [batch_size], 'data-targets')

training_phase = tf.placeholder(tf.bool, name='training-flag')
rotate_data = tf.placeholder(tf.bool, name='rotate-flag')
dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

classifier_network = ClassifierNetworkGraph(
    input_x=data_inputs,
    target_placeholder=data_targets,
    num_classes_per_set=args.classes_per_set,
    label_as_int=True)

experiment = ExperimentBuilder(data)
one_shot_omniglot, losses, c_error_opt_op, init = experiment.build_experiment(
    args.batch_size,
    args.classes_per_set,
    args.samples_per_class,
    args.use_full_context_embeddings,
    full_context_unroll_k=args.full_context_unroll_k,
    args=args)
total_train_batches = args.total_iter_per_epoch
total_val_batches = args.total_iter_per_epoch
total_test_batches = args.total_iter_per_epoch

saved_models_filepath, logs_filepath = build_experiment_folder(
    args.experiment_title)

save_statistics(logs_filepath, [
    "epoch", "total_train_c_loss_mean", "total_train_c_loss_std",
    "total_train_accuracy_mean", "total_train_accuracy_std",
    "total_val_c_loss_mean", "total_val_c_loss_std", "total_val_accuracy_mean",
    "total_val_accuracy_std", "total_test_c_loss_mean",
    "total_test_c_loss_std", "total_test_accuracy_mean",
    "total_test_accuracy_std"
],
                create=True)

# Experiment initialization and running
with tf.Session() as sess:
    sess.run(init)
    train_saver = tf.train.Saver()
Example #5
    def __init__(self, args, data, model, device):
        """
        Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system
        (model) and a device (e.g. gpu/cpu)
        :param args: A namedtuple containing all experiment hyperparameters
        :param data: A data provider of instance MetaLearningSystemDataLoader
        :param model: A meta learning system instance
        :param device: Device/s to use for the experiment
        """
        self.args, self.device = args, device
        num_thousand_iters = int(args.total_epochs *
                                 args.total_iter_per_epoch / 1000)
        if not args.TR_MAML:
            tmp_maml_str = ''
        else:
            tmp_maml_str = 'tr'

        if args.dataset_name == 'omniglot_dataset':
            tmp_ds = 'omni_'
        else:
            tmp_ds = 'mini_'

        experiment_name = (f"{tmp_ds}{num_thousand_iters}k{args.batch_size}bs_"
                           f"{args.num_classes_per_set}way"
                           f"{args.num_samples_per_class}shot_"
                           f"{args.number_of_training_steps_per_iter}gs_"
                           f"{tmp_maml_str}m_{args.train_seed}seed")

        self.model = model
        self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
            experiment_name)
        print(experiment_name)
        self.total_losses = dict()
        self.state = dict()
        self.state['best_val_acc'] = 0.
        self.state['best_val_iter'] = 0
        self.state['current_iter'] = 0
        self.start_epoch = 0
        self.max_models_to_save = self.args.max_models_to_save
        self.create_summary_csv = False

        self.args.continue_from_epoch = 'from_scratch'
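        # Since continue_from_epoch is forced to 'from_scratch' just above, the 'latest'
        # and numeric-epoch branches below are effectively unreachable in this version.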

        if self.args.continue_from_epoch == 'from_scratch':
            self.create_summary_csv = True

        elif self.args.continue_from_epoch == 'latest':
            checkpoint = os.path.join(self.saved_models_filepath,
                                      "train_model_latest")
            print("attempting to find existing checkpoint", )
            if os.path.exists(checkpoint):
                self.state = \
                    self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                          model_idx='latest')
                self.start_epoch = int(self.state['current_iter'] /
                                       self.args.total_iter_per_epoch)

            else:
                self.args.continue_from_epoch = 'from_scratch'
                self.create_summary_csv = True
        elif int(self.args.continue_from_epoch) >= 0:
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                      model_idx=self.args.continue_from_epoch)
            self.start_epoch = int(self.state['current_iter'] /
                                   self.args.total_iter_per_epoch)

        self.num_train_tasks = self.args.num_train_tasks
        self.num_test_tasks = self.args.num_test_tasks
        self.data = data(args=args, current_iter=self.state['current_iter'])

        print("train_seed {}, val_seed: {}, at start time".format(
            self.data.dataset.seed["train"], self.data.dataset.seed["val"]))
        self.total_epochs_before_pause = self.args.total_epochs_before_pause
        self.state['best_epoch'] = int(self.state['best_val_iter'] /
                                       self.args.total_iter_per_epoch)
        self.epoch = int(self.state['current_iter'] /
                         self.args.total_iter_per_epoch)
        self.augment_flag = 'omniglot' in self.args.dataset_name.lower()
        self.start_time = time.time()
        self.epochs_done_in_this_run = 0
        print("CURRENT STATE")
        print(self.state['current_iter'],
              int(self.args.total_iter_per_epoch * self.args.total_epochs))
Example #6
    def __init__(self, parser, data):
        tf.reset_default_graph()

        args = parser.parse_args()
        self.continue_from_epoch = args.continue_from_epoch
        self.experiment_name = args.experiment_title
        self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(self.experiment_name)
        self.num_gpus = args.num_of_gpus
        self.batch_size = args.batch_size
        gen_depth_per_layer = args.generator_inner_layers
        discr_depth_per_layer = args.discriminator_inner_layers
        self.z_dim = args.z_dim
        self.num_generations = args.num_generations
        self.dropout_rate_value = args.dropout_rate_value
        self.data = data
        self.reverse_channels = False

        generator_layers = [64, 64, 128, 128]
        discriminator_layers = [64, 64, 128, 128]

        gen_inner_layers = [gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer]
        discr_inner_layers = [discr_depth_per_layer, discr_depth_per_layer, discr_depth_per_layer,
                              discr_depth_per_layer]
        generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]

        image_height = data.image_height
        image_width = data.image_width
        image_channel = data.image_channel

        self.input_x_i = tf.placeholder(tf.float32, [self.num_gpus, self.batch_size, image_height, image_width,
                                                     image_channel], 'inputs-1')
        self.input_x_j = tf.placeholder(tf.float32, [self.num_gpus, self.batch_size, image_height, image_width,
                                                     image_channel], 'inputs-2-same-class')

        self.z_input = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], 'z-input')
        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

        dagan = DAGAN(batch_size=self.batch_size, input_x_i=self.input_x_i, input_x_j=self.input_x_j,
                      dropout_rate=self.dropout_rate, generator_layer_sizes=generator_layers,
                      generator_layer_padding=generator_layer_padding, num_channels=data.image_channel,
                      is_training=self.training_phase, augment=self.random_rotate,
                      discriminator_layer_sizes=discriminator_layers,
                      discr_inner_conv=discr_inner_layers,
                      gen_inner_conv=gen_inner_layers, num_gpus=self.num_gpus, z_dim=self.z_dim, z_inputs=self.z_input)

        self.summary, self.losses, self.graph_ops = dagan.init_train()
        self.same_images = dagan.sample_same_images()

        self.total_train_batches = int(data.training_data_size / (self.batch_size * self.num_gpus))

        self.total_gen_batches = int(data.generation_data_size / (self.batch_size * self.num_gpus))

        self.init = tf.global_variables_initializer()
        self.spherical_interpolation = True
        self.tensorboard_update_interval = int(self.total_train_batches/100/self.num_gpus)
        self.total_epochs = 200

        if self.continue_from_epoch == -1:
            save_statistics(self.log_path, ['epoch', 'total_d_train_loss_mean', 'total_d_val_loss_mean',
                                            'total_d_train_loss_std', 'total_d_val_loss_std',
                                            'total_g_train_loss_mean', 'total_g_val_loss_mean',
                                            'total_g_train_loss_std', 'total_g_val_loss_std'], create=True)
trainloader, testloader, in_shape = load_dataset(args)

n_train_batches = len(trainloader)
n_train_images = len(trainloader.dataset)
n_test_batches = len(testloader)
n_test_images = len(testloader.dataset)

print("Data loaded successfully ")
print("Training --> {} images and {} batches".format(n_train_images,
                                                     n_train_batches))
print("Testing --> {} images and {} batches".format(n_test_images,
                                                    n_test_batches))

######################################################################################################### Admin

saved_models_filepath, logs_filepath, images_filepath = build_experiment_folder(
    args)

start_epoch, latest_loadpath = get_start_epoch(args)
args.latest_loadpath = latest_loadpath
best_epoch, best_test_acc = get_best_epoch(args)
if best_epoch >= 0:
    print('Best evaluation acc so far at {} epochs: {:0.2f}'.format(
        best_epoch, best_test_acc))

if not args.resume:
    save_statistics(logs_filepath,
                    "result_summary_statistics", [
                        "epoch",
                        "train_loss",
                        "test_loss",
                        "train_loss_c",
Example #8
    def __init__(self, args, data, model, device):
        """
        Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system
        (model) and a device (e.g. gpu/cpu)
        :param args: A namedtuple containing all experiment hyperparameters
        :param data: A data provider of instance MetaLearningSystemDataLoader
        :param model: A meta learning system instance
        :param device: Device/s to use for the experiment
        """
        self.args, self.device = args, device

        self.model = model
        (
            self.saved_models_filepath,
            self.logs_filepath,
            self.samples_filepath,
        ) = build_experiment_folder(experiment_name=self.args.experiment_name)

        self.per_task_performance = defaultdict(lambda: 0)
        self.total_losses = dict()
        self.state = dict()
        self.state["best_val_loss"] = 10**6
        self.state["best_val_accuracy"] = 0
        self.state["best_val_iter"] = 0
        self.state["current_iter"] = 0
        self.start_epoch = 0
        self.num_epoch_no_improvements = 0
        self.patience = args.patience
        self.create_summary_csv = False

        self.writer = SummaryWriter("runs/{}".format(
            self.args.experiment_name))

        if self.args.continue_from_epoch == "from_scratch":
            self.create_summary_csv = True

        elif self.args.continue_from_epoch == "latest":
            checkpoint = os.path.join(self.saved_models_filepath,
                                      "train_model_latest")
            print("attempting to find existing checkpoint", )
            if os.path.exists(checkpoint):
                self.state = self.model.load_model(
                    model_save_dir=self.saved_models_filepath,
                    model_name="train_model",
                    model_idx="latest",
                )
                self.start_epoch = int(self.state["current_iter"] /
                                       self.args.total_iter_per_epoch)

            else:
                self.args.continue_from_epoch = "from_scratch"
                self.create_summary_csv = True
        elif int(self.args.continue_from_epoch) >= 0:
            self.state = self.model.load_model(
                model_save_dir=self.saved_models_filepath,
                model_name="train_model",
                model_idx=self.args.continue_from_epoch,
            )
            self.start_epoch = int(self.state["current_iter"] /
                                   self.args.total_iter_per_epoch)

        self.data = data(args=args, current_iter=self.state["current_iter"])

        self.idx_to_class_name = self.data.dataset.load_from_json(
            self.data.dataset.index_to_label_name_dict_file)

        print("train_seed {}, val_seed: {}, at start time".format(
            self.data.dataset.seed["train"], self.data.dataset.seed["val"]))
        self.total_epochs_before_pause = self.args.total_epochs_before_pause
        self.state["best_epoch"] = int(self.state["best_val_iter"] /
                                       self.args.total_iter_per_epoch)
        self.epoch = int(self.state["current_iter"] /
                         self.args.total_iter_per_epoch)

        self.start_time = time.time()
        self.epochs_done_in_this_run = 0
        print(
            self.state["current_iter"],
            int(self.args.total_iter_per_epoch * self.args.total_epochs),
        )

        if self.epoch == 0:
            for param_name, param in self.model.named_parameters():
                self.writer.add_histogram(param_name, param, 0)

            self.writer.flush()
    def __init__(self, args, data):
        tf.reset_default_graph()
        self.continue_from_epoch = args.continue_from_epoch
        self.experiment_name = args.experiment_title
        self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(
            self.experiment_name)
        self.num_gpus = args.num_of_gpus
        self.batch_size = args.batch_size
        # self.support_number = args.support_number
        self.selected_classes = args.selected_classes
        gen_depth_per_layer = args.generator_inner_layers
        discr_depth_per_layer = args.discriminator_inner_layers
        self.z_dim = args.z_dim
        self.num_generations = args.num_generations
        self.dropout_rate_value = args.dropout_rate_value
        self.data = data
        self.reverse_channels = False

        if args.generation_layers == 6:
            generator_layers = [64, 64, 128, 128, 256, 256]
            gen_inner_layers = [
                gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
                gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer
            ]
            generator_layer_padding = [
                "SAME", "SAME", "SAME", "SAME", "SAME", "SAME"
            ]
        else:
            generator_layers = [64, 64, 128, 128]
            gen_inner_layers = [
                gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
                gen_depth_per_layer
            ]
            generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]

        discriminator_layers = [64, 64, 128, 128]
        discr_inner_layers = [
            discr_depth_per_layer, discr_depth_per_layer,
            discr_depth_per_layer, discr_depth_per_layer
        ]

        image_height = data.image_height
        image_width = data.image_width
        image_channel = data.image_channel

        self.classes = tf.placeholder(tf.int32)
        self.selected_classes = tf.placeholder(tf.int32)
        self.support_number = tf.placeholder(tf.int32)

        #### [self.input_x_i, self.input_y_i, self.input_global_y_i] --> [images, few-shot label, global label]
        ## batch: [self.input_x_i, self.input_y_i, self.input_global_y_i]
        ## support: [self.input_x_j, self.input_y_j, self.input_global_y_j]
        ## discriminator inputs: [self.input_x_j_selected, self.input_global_y_j_selected]
        self.input_x_i = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, image_height, image_width,
            image_channel
        ], 'batch')
        self.input_y_i = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.data.selected_classes],
            'y_inputs_batch')
        self.input_global_y_i = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.data.training_classes],
            'y_inputs_batch_global')

        self.input_x_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.data.selected_classes *
            self.data.support_number, image_height, image_width, image_channel
        ], 'support')
        self.input_y_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.data.selected_classes *
            self.data.support_number, self.data.selected_classes
        ], 'y_inputs_support')
        self.input_global_y_j = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, self.data.selected_classes *
            self.data.support_number, self.data.training_classes
        ], 'y_inputs_support_global')

        self.input_x_j_selected = tf.placeholder(tf.float32, [
            self.num_gpus, self.batch_size, image_height, image_width,
            image_channel
        ], 'support_discriminator')
        self.input_global_y_j_selected = tf.placeholder(
            tf.float32,
            [self.num_gpus, self.batch_size, self.data.training_classes],
            'y_inputs_support_discriminator')

        # self.z_input = tf.placeholder(tf.float32, [self.batch_size*self.data.selected_classes, self.z_dim], 'z-input')
        # self.z_input_2 = tf.placeholder(tf.float32, [self.batch_size*self.data.selected_classes, self.z_dim], 'z-input_2')

        self.z_input = tf.placeholder(tf.float32,
                                      [self.batch_size, self.z_dim], 'z-input')
        self.z_input_2 = tf.placeholder(tf.float32,
                                        [self.batch_size, self.z_dim],
                                        'z-input_2')

        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.z1z2_training = tf.placeholder(tf.bool, name='z1z2_training-flag')
        self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
        self.is_z2 = args.is_z2
        self.is_z2_vae = args.is_z2_vae

        self.matching = args.matching
        self.fce = args.fce
        self.full_context_unroll_k = args.full_context_unroll_k
        self.average_per_class_embeddings = args.average_per_class_embeddings
        self.restore_path = args.restore_path

        self.loss_G = args.loss_G
        self.loss_D = args.loss_D
        self.loss_CLA = args.loss_CLA
        self.loss_FSL = args.loss_FSL
        self.loss_KL = args.loss_KL
        self.loss_recons_B = args.loss_recons_B
        self.loss_matching_G = args.loss_matching_G
        self.loss_matching_D = args.loss_matching_D
        self.loss_sim = args.loss_sim
        self.strategy = args.strategy

        #### training/validation/testing

        time_1 = time.time()
        dagan = DAGAN(batch_size=self.batch_size, input_x_i=self.input_x_i, input_x_j=self.input_x_j,
                      input_y_i=self.input_y_i, input_y_j=self.input_y_j, input_global_y_i=self.input_global_y_i,
                      input_global_y_j=self.input_global_y_j,
                      input_x_j_selected=self.input_x_j_selected,
                      input_global_y_j_selected=self.input_global_y_j_selected, \
                      selected_classes=self.data.selected_classes, support_num=self.data.support_number,
                      classes=self.data.training_classes,
                      dropout_rate=self.dropout_rate, generator_layer_sizes=generator_layers,
                      generator_layer_padding=generator_layer_padding, num_channels=data.image_channel,
                      is_training=self.training_phase, augment=self.random_rotate,
                      discriminator_layer_sizes=discriminator_layers,
                      discr_inner_conv=discr_inner_layers, is_z2=self.is_z2, is_z2_vae=self.is_z2_vae,
                      gen_inner_conv=gen_inner_layers, num_gpus=self.num_gpus, z_dim=self.z_dim, z_inputs=self.z_input,
                      z_inputs_2=self.z_input_2,
                      use_wide_connections=args.use_wide_connections, fce=self.fce, matching=self.matching,
                      full_context_unroll_k=self.full_context_unroll_k,
                      average_per_class_embeddings=self.average_per_class_embeddings,
                      loss_G=self.loss_G, loss_D=self.loss_D, loss_KL=self.loss_KL, loss_recons_B=self.loss_recons_B,
                      loss_matching_G=self.loss_matching_G, loss_matching_D=self.loss_matching_D,
                      loss_CLA=self.loss_CLA, loss_FSL=self.loss_FSL, loss_sim=self.loss_sim,
                      z1z2_training=self.z1z2_training)

        self.same_images = dagan.sample_same_images()

        # self.summary, self.losses, self.accuracy, self.graph_ops = classifier.init_train()

        self.total_train_batches = int(data.training_data_size /
                                       (self.batch_size * self.num_gpus))
        self.total_val_batches = int(data.validation_data_size /
                                     (self.batch_size * self.num_gpus))
        self.total_test_batches = int(data.testing_data_size /
                                      (self.batch_size * self.num_gpus))
        self.total_gen_batches = int(data.testing_data_size /
                                     (self.batch_size * self.num_gpus))
        self.init = tf.global_variables_initializer()

        time_2 = time.time()
        # print('time for constructing graph:',time_2 - time_1)

        self.tensorboard_update_interval = int(self.total_train_batches / 1 /
                                               self.num_gpus)
        self.total_epochs = 800
        self.is_generation_for_classifier = args.is_generation_for_classifier
        self.is_all_test_categories = args.is_all_test_categories
    def __init__(self, data_dict, model, experiment_name, continue_from_epoch,
                 max_models_to_save, total_iter_per_epoch, total_epochs,
                 num_evaluation_tasks, batch_size, evaluate_on_test_set_only,
                 args):
        """
        Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system
        (model) and a device (e.g. gpu/cpu/n)
        :param args: A namedtuple containing all experiment hyperparameters
        :param data: A data provider of instance MetaLearningSystemDataLoader
        :param model: A meta learning system instance
        :param device: Device/s to use for the experiment
        """

        self.model = model
        self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
            experiment_name=experiment_name)

        self.total_losses = dict()
        self.state = dict()
        self.state['best_val_acc'] = 0.
        self.state['best_val_iter'] = 0
        self.state['current_iter'] = 0
        self.start_epoch = 0
        self.max_models_to_save = max_models_to_save
        self.create_summary_csv = False
        self.evaluate_on_test_set_only = evaluate_on_test_set_only

        for key, value in args.__dict__.items():
            setattr(self, key, value)

        if continue_from_epoch == 'from_scratch':
            self.create_summary_csv = True

        elif continue_from_epoch == 'latest':
            checkpoint = os.path.join(self.saved_models_filepath,
                                      "train_model_latest")
            print("attempting to find existing checkpoint", )
            if os.path.exists(checkpoint):
                try:
                    self.state = \
                        self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                              model_idx='latest')
                    self.start_epoch = int(self.state['current_iter'] /
                                           total_iter_per_epoch)
                except Exception:
                    self.continue_from_epoch = 'from_scratch'
                    self.create_summary_csv = True

            else:
                self.continue_from_epoch = 'from_scratch'
                self.create_summary_csv = True
        elif int(continue_from_epoch) >= 0:
            checkpoint = os.path.join(
                self.saved_models_filepath,
                "train_model_{}".format(continue_from_epoch))
            if os.path.exists(checkpoint):
                self.state = \
                    self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                          model_idx=continue_from_epoch)
                self.start_epoch = int(self.state['current_iter'] /
                                       total_iter_per_epoch)
            else:
                self.continue_from_epoch = 'from_scratch'
                self.create_summary_csv = True

        self.data = data_dict
        self.total_iter_per_epoch = total_iter_per_epoch
        self.batch_size = batch_size
        self.total_epochs = total_epochs
        self.num_evaluation_tasks = num_evaluation_tasks

        print("train_seed {}, val_seed: {}, at start time".format(
            self.data["train"].dataset.seed, self.data["val"].dataset.seed))

        self.state['best_epoch'] = int(self.state['best_val_iter'] /
                                       total_iter_per_epoch)
        self.epoch = int(self.state['current_iter'] / total_iter_per_epoch)
        self.start_time = time.time()
        self.epochs_done_in_this_run = 0
        print(self.state['current_iter'],
              int(total_iter_per_epoch * total_epochs))