def __init__(self, args, data, model, device):
    """
    Initializes an experiment builder using a named tuple (args), a data provider (data),
    a meta learning system (model) and a device (e.g. gpu/cpu/n)
    :param args: A namedtuple containing all experiment hyperparameters
    :param data: A data provider of instance MetaLearningSystemDataLoader
    :param model: A meta learning system instance
    :param device: Device/s to use for the experiment
    """
    self.args, self.device = args, device
    self.model = model
    self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
        experiment_name=self.args.experiment_name)
    self.total_losses = dict()
    # Experiment progress state; may be replaced wholesale by a restored checkpoint below.
    self.state = dict()
    self.state['best_val_acc'] = 0.
    self.state['best_val_iter'] = 0
    self.state['current_iter'] = 0  # fix: original assigned this key twice in a row
    self.start_epoch = 0
    self.max_models_to_save = self.args.max_models_to_save
    self.create_summary_csv = False

    if self.args.continue_from_epoch == 'from_scratch':
        self.create_summary_csv = True
    elif self.args.continue_from_epoch == 'latest':
        checkpoint = os.path.join(self.saved_models_filepath, "train_model_latest")
        print("attempting to find existing checkpoint", )
        if os.path.exists(checkpoint):
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath,
                                      model_name="train_model",
                                      model_idx='latest')
            self.start_epoch = int(self.state['current_iter'] /
                                   self.args.total_iter_per_epoch)
        else:
            # No checkpoint on disk -> silently fall back to a fresh run.
            self.args.continue_from_epoch = 'from_scratch'
            self.create_summary_csv = True
    elif int(self.args.continue_from_epoch) >= 0:
        # Resume from an explicit epoch index (string form of an int).
        self.state = \
            self.model.load_model(model_save_dir=self.saved_models_filepath,
                                  model_name="train_model",
                                  model_idx=self.args.continue_from_epoch)
        self.start_epoch = int(self.state['current_iter'] /
                               self.args.total_iter_per_epoch)

    # Seed the data provider with the restored iteration so it can resume mid-stream.
    self.data = data(args=args, current_iter=self.state['current_iter'])

    print("train_seed {}, val_seed: {}, at start time".format(
        self.data.dataset.seed["train"], self.data.dataset.seed["val"]))
    self.total_epochs_before_pause = self.args.total_epochs_before_pause
    self.state['best_epoch'] = int(self.state['best_val_iter'] /
                                   self.args.total_iter_per_epoch)
    self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)
    # Rotation augmentation only applies to omniglot-style datasets.
    self.augment_flag = 'omniglot' in self.args.dataset_name.lower()
    self.start_time = time.time()
    self.epochs_done_in_this_run = 0
    print(self.state['current_iter'],
          int(self.args.total_iter_per_epoch * self.args.total_epochs))
def housekeeping():
    """Prepare a run: parse arguments, pin GPUs, seed all RNGs, create the
    experiment folders, and archive a snapshot of the source code.

    :return: The fully-populated args namespace (with filepaths and device set).
    """
    argument_parser = get_base_argument_parser()
    args = process_args(argument_parser)

    if args.gpu_ids_to_use is None:
        # Auto-select GPUs below the given load/memory thresholds.
        select_devices(
            args.num_gpus_to_use,
            max_load=args.max_gpu_selection_load,
            max_memory=args.max_gpu_selection_memory,
            exclude_gpu_ids=args.excude_gpu_list,
        )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids_to_use.replace(
            " ", ",")

    saved_models_filepath, logs_filepath, images_filepath = build_experiment_folder(
        experiment_name=args.experiment_name, log_path=args.logs_path)
    args.saved_models_filepath = saved_models_filepath
    args.logs_filepath = logs_filepath
    args.images_filepath = images_filepath

    # Determinism Seeding can be annoying in pytorch at the moment.
    # Based on my experience, the below means of seeding allows for deterministic
    # experimentation.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)  # set seed
    random.seed(args.seed)
    device = (torch.cuda.current_device()
              if torch.cuda.is_available() and args.num_gpus_to_use > 0 else "cpu")
    args.device = device
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True

    # Always save a snapshot of the current state of the code. I've found this helps
    # immensely if you find that one of your many experiments was actually quite good
    # but you forgot what you did
    snapshot_filename = f"{saved_models_filepath}/snapshot.tar.gz"
    filetypes_to_include = [".py"]
    all_files = []
    for filetype in filetypes_to_include:
        # fix: the original ignored the loop variable and always globbed "**/*.py",
        # which would have archived duplicates for every extra filetype listed
        # instead of picking up that filetype's files.
        all_files += glob.glob(f"**/*{filetype}", recursive=True)
    with tarfile.open(snapshot_filename, "w:gz") as tar:
        for file in all_files:
            tar.add(file)
    return args
# --- Data providers ---------------------------------------------------------
train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng)
val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng)  # setup our data providers

print("Running {}".format(experiment_name))
print("Starting from epoch {}".format(continue_from_epoch))

saved_models_filepath, logs_filepath = build_experiment_folder(
    experiment_name, logs_path)  # generate experiment dir

# Placeholder setup
# NOTE(review): shape is [batch, H, W, C] taken from the train provider's
# inputs array — assumes inputs is a 4-D NHWC array; confirm against
# CIFAR10DataProvider.
data_inputs = tf.placeholder(tf.float32, [
    batch_size, train_data.inputs.shape[1], train_data.inputs.shape[2],
    train_data.inputs.shape[3]
], 'data-inputs')
data_targets = tf.placeholder(tf.int32, [batch_size], 'data-targets')
training_phase = tf.placeholder(tf.bool, name='training-flag')
rotate_data = tf.placeholder(tf.bool, name='rotate-flag')
dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

classifier_network = ClassifierNetworkGraph(
    input_x=data_inputs,
    target_placeholder=data_targets,
    num_classes_per_set=args.classes_per_set,
    label_as_int=True)

# Build the one-shot training/eval graph; `init` initializes all variables.
experiment = ExperimentBuilder(data)
one_shot_omniglot, losses, c_error_opt_op, init = experiment.build_experiment(
    args.batch_size,
    args.classes_per_set,
    args.samples_per_class,
    args.use_full_context_embeddings,
    full_context_unroll_k=args.full_context_unroll_k,
    args=args)

# All three splits are iterated for the same number of batches per epoch.
total_train_batches = args.total_iter_per_epoch
total_val_batches = args.total_iter_per_epoch
total_test_batches = args.total_iter_per_epoch

saved_models_filepath, logs_filepath = build_experiment_folder(
    args.experiment_title)
# Write the CSV header row for per-epoch statistics.
save_statistics(logs_filepath, [
    "epoch", "total_train_c_loss_mean", "total_train_c_loss_std",
    "total_train_accuracy_mean", "total_train_accuracy_std",
    "total_val_c_loss_mean", "total_val_c_loss_std",
    "total_val_accuracy_mean", "total_val_accuracy_std",
    "total_test_c_loss_mean", "total_test_c_loss_std",
    "total_test_accuracy_mean", "total_test_accuracy_std"
], create=True)

# Experiment initialization and running
with tf.Session() as sess:
    sess.run(init)
    train_saver = tf.train.Saver()
def __init__(self, args, data, model, device):
    """
    Initializes an experiment builder using a named tuple (args), a data provider (data),
    a meta learning system (model) and a device (e.g. gpu/cpu/n)
    :param args: A namedtuple containing all experiment hyperparameters
    :param data: A data provider of instance MetaLearningSystemDataLoader
    :param model: A meta learning system instance
    :param device: Device/s to use for the experiment
    """
    self.args, self.device = args, device

    # Build a descriptive experiment name encoding the main hyperparameters,
    # e.g. "mini_60k4bs_5way1shot_5gs_trm_0seed".
    num_thousand_iters = int(args.total_epochs * args.total_iter_per_epoch / 1000)
    tmp_maml_str = 'tr' if args.TR_MAML else ''
    tmp_ds = 'omni_' if args.dataset_name == 'omniglot_dataset' else 'mini_'
    experiment_name = tmp_ds + str(num_thousand_iters) + 'k' + str(
        args.batch_size) + 'bs_' + str(
            args.num_classes_per_set) + 'way' + str(
                args.num_samples_per_class) + 'shot_' + str(
                    args.number_of_training_steps_per_iter
                ) + 'gs_' + tmp_maml_str + 'm_' + str(args.train_seed) + 'seed'

    self.model = model
    self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
        experiment_name)
    print(experiment_name)
    self.total_losses = dict()
    self.state = dict()
    self.state['best_val_acc'] = 0.
    self.state['best_val_iter'] = 0
    self.state['current_iter'] = 0  # fix: original assigned this key twice in a row
    self.start_epoch = 0
    self.max_models_to_save = self.args.max_models_to_save
    self.create_summary_csv = False
    # NOTE(review): this forced override makes the 'latest' / numeric resume
    # branches below dead code. Kept as-is because removing it would re-enable
    # checkpoint resuming and change behavior — confirm intent with the author.
    self.args.continue_from_epoch = 'from_scratch'
    if self.args.continue_from_epoch == 'from_scratch':
        self.create_summary_csv = True
    elif self.args.continue_from_epoch == 'latest':
        checkpoint = os.path.join(self.saved_models_filepath, "train_model_latest")
        print("attempting to find existing checkpoint", )
        if os.path.exists(checkpoint):
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath,
                                      model_name="train_model",
                                      model_idx='latest')
            self.start_epoch = int(self.state['current_iter'] /
                                   self.args.total_iter_per_epoch)
        else:
            self.args.continue_from_epoch = 'from_scratch'
            self.create_summary_csv = True
    elif int(self.args.continue_from_epoch) >= 0:
        self.state = \
            self.model.load_model(model_save_dir=self.saved_models_filepath,
                                  model_name="train_model",
                                  model_idx=self.args.continue_from_epoch)
        self.start_epoch = int(self.state['current_iter'] /
                               self.args.total_iter_per_epoch)

    self.num_train_tasks = self.args.num_train_tasks
    self.num_test_tasks = self.args.num_test_tasks
    # Seed the data provider with the restored iteration so it can resume mid-stream.
    self.data = data(args=args, current_iter=self.state['current_iter'])
    print("train_seed {}, val_seed: {}, at start time".format(
        self.data.dataset.seed["train"], self.data.dataset.seed["val"]))
    self.total_epochs_before_pause = self.args.total_epochs_before_pause
    self.state['best_epoch'] = int(self.state['best_val_iter'] /
                                   self.args.total_iter_per_epoch)
    self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)
    # Rotation augmentation only applies to omniglot-style datasets.
    self.augment_flag = 'omniglot' in self.args.dataset_name.lower()
    self.start_time = time.time()
    self.epochs_done_in_this_run = 0
    print("CURRENT STATE")
    print(self.state['current_iter'],
          int(self.args.total_iter_per_epoch * self.args.total_epochs))
def __init__(self, parser, data):
    """Build the DAGAN training graph and bookkeeping for one experiment.

    :param parser: An argparse parser; parsed here to obtain all hyperparameters.
    :param data: A data provider exposing image_height/image_width/image_channel,
                 training_data_size and generation_data_size.
    """
    tf.reset_default_graph()
    args = parser.parse_args()
    self.continue_from_epoch = args.continue_from_epoch
    self.experiment_name = args.experiment_title
    self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(self.experiment_name)
    self.num_gpus = args.num_of_gpus
    self.batch_size = args.batch_size
    gen_depth_per_layer = args.generator_inner_layers
    discr_depth_per_layer = args.discriminator_inner_layers
    self.z_dim = args.z_dim
    self.num_generations = args.num_generations
    self.dropout_rate_value = args.dropout_rate_value
    self.data = data
    self.reverse_channels = False

    # Fixed 4-stage architectures; per-stage inner depth comes from the args above.
    generator_layers = [64, 64, 128, 128]
    discriminator_layers = [64, 64, 128, 128]
    gen_inner_layers = [gen_depth_per_layer, gen_depth_per_layer,
                        gen_depth_per_layer, gen_depth_per_layer]
    discr_inner_layers = [discr_depth_per_layer, discr_depth_per_layer,
                          discr_depth_per_layer, discr_depth_per_layer]
    generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]

    image_height = data.image_height
    image_width = data.image_width
    image_channel = data.image_channel

    # Placeholders: inputs are [num_gpus, batch, H, W, C]; x_i / x_j are
    # image pairs from the same class.
    self.input_x_i = tf.placeholder(tf.float32,
                                    [self.num_gpus, self.batch_size, image_height,
                                     image_width, image_channel], 'inputs-1')
    self.input_x_j = tf.placeholder(tf.float32,
                                    [self.num_gpus, self.batch_size, image_height,
                                     image_width, image_channel],
                                    'inputs-2-same-class')
    self.z_input = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], 'z-input')
    self.training_phase = tf.placeholder(tf.bool, name='training-flag')
    self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
    self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

    dagan = DAGAN(batch_size=self.batch_size, input_x_i=self.input_x_i,
                  input_x_j=self.input_x_j,
                  dropout_rate=self.dropout_rate,
                  generator_layer_sizes=generator_layers,
                  generator_layer_padding=generator_layer_padding,
                  num_channels=data.image_channel,
                  is_training=self.training_phase, augment=self.random_rotate,
                  discriminator_layer_sizes=discriminator_layers,
                  discr_inner_conv=discr_inner_layers,
                  gen_inner_conv=gen_inner_layers, num_gpus=self.num_gpus,
                  z_dim=self.z_dim, z_inputs=self.z_input)
    self.summary, self.losses, self.graph_ops = dagan.init_train()
    self.same_images = dagan.sample_same_images()
    # Batches per epoch, split evenly across GPUs.
    self.total_train_batches = int(data.training_data_size / (self.batch_size * self.num_gpus))
    self.total_gen_batches = int(data.generation_data_size / (self.batch_size * self.num_gpus))
    self.init = tf.global_variables_initializer()
    self.spherical_interpolation = True
    # NOTE(review): integer division — this can be 0 for small datasets or many
    # GPUs, which would affect any modulo-based tensorboard logging; verify.
    self.tensorboard_update_interval = int(self.total_train_batches/100/self.num_gpus)
    self.total_epochs = 200

    # -1 means a fresh run: write the CSV header for the loss statistics.
    if self.continue_from_epoch == -1:
        save_statistics(self.log_path,
                        ['epoch', 'total_d_train_loss_mean', 'total_d_val_loss_mean',
                         'total_d_train_loss_std', 'total_d_val_loss_std',
                         'total_g_train_loss_mean', 'total_g_val_loss_mean',
                         'total_g_train_loss_std', 'total_g_val_loss_std'],
                        create=True)
# --- Data loading ------------------------------------------------------------
trainloader, testloader, in_shape = load_dataset(args)
n_train_batches = len(trainloader)
n_train_images = len(trainloader.dataset)
n_test_batches = len(testloader)
n_test_images = len(testloader.dataset)

print("Data loaded successfully ")
print("Training --> {} images and {} batches".format(n_train_images, n_train_batches))
print("Testing --> {} images and {} batches".format(n_test_images, n_test_batches))

######################################################################################################### Admin
saved_models_filepath, logs_filepath, images_filepath = build_experiment_folder(
    args)
# Work out where to resume from (if anywhere) and record the checkpoint path.
start_epoch, latest_loadpath = get_start_epoch(args)
args.latest_loadpath = latest_loadpath
best_epoch, best_test_acc = get_best_epoch(args)
# NOTE(review): best_epoch appears to be negative when no evaluation has
# happened yet — confirm against get_best_epoch.
if best_epoch >= 0:
    print('Best evaluation acc so far at {} epochs: {:0.2f}'.format(
        best_epoch, best_test_acc))
# Fresh (non-resumed) runs write a new statistics CSV header.
if not args.resume:
    save_statistics(logs_filepath, "result_summary_statistics", [
        "epoch",
        "train_loss",
        "test_loss",
        "train_loss_c",
def __init__(self, args, data, model, device):
    """
    Set up an experiment: output folders, progress state, tensorboard writer,
    optional checkpoint restoration, and the data provider.
    :param args: A namedtuple containing all experiment hyperparameters
    :param data: A data provider of instance MetaLearningSystemDataLoader
    :param model: A meta learning system instance
    :param device: Device/s to use for the experiment
    """
    self.args = args
    self.device = device
    self.model = model

    folders = build_experiment_folder(experiment_name=self.args.experiment_name)
    self.saved_models_filepath, self.logs_filepath, self.samples_filepath = folders

    self.per_task_performance = defaultdict(int)
    self.total_losses = {}
    # Progress/bookkeeping state; replaced wholesale when a checkpoint is restored.
    self.state = {
        "best_val_loss": 10 ** 6,
        "best_val_accuracy": 0,
        "best_val_iter": 0,
        "current_iter": 0,
    }
    self.start_epoch = 0
    self.num_epoch_no_improvements = 0
    self.patience = args.patience
    self.create_summary_csv = False
    self.writer = SummaryWriter("runs/{}".format(self.args.experiment_name))

    resume = self.args.continue_from_epoch
    if resume == "from_scratch":
        self.create_summary_csv = True
    elif resume == "latest":
        ckpt_path = os.path.join(self.saved_models_filepath, "train_model_latest")
        print("attempting to find existing checkpoint", )
        if os.path.exists(ckpt_path):
            restored = self.model.load_model(
                model_save_dir=self.saved_models_filepath,
                model_name="train_model",
                model_idx="latest",
            )
            self.state = restored
            self.start_epoch = int(restored["current_iter"] /
                                   self.args.total_iter_per_epoch)
        else:
            # Nothing to resume from: behave like a fresh run.
            self.args.continue_from_epoch = "from_scratch"
            self.create_summary_csv = True
    elif int(resume) >= 0:
        restored = self.model.load_model(
            model_save_dir=self.saved_models_filepath,
            model_name="train_model",
            model_idx=resume,
        )
        self.state = restored
        self.start_epoch = int(restored["current_iter"] /
                               self.args.total_iter_per_epoch)

    # The provider is seeded with the (possibly restored) iteration counter.
    self.data = data(args=args, current_iter=self.state["current_iter"])
    self.idx_to_class_name = self.data.dataset.load_from_json(
        self.data.dataset.index_to_label_name_dict_file)

    print("train_seed {}, val_seed: {}, at start time".format(
        self.data.dataset.seed["train"], self.data.dataset.seed["val"]))

    self.total_epochs_before_pause = self.args.total_epochs_before_pause
    iters_per_epoch = self.args.total_iter_per_epoch
    self.state["best_epoch"] = int(self.state["best_val_iter"] / iters_per_epoch)
    self.epoch = int(self.state["current_iter"] / iters_per_epoch)
    self.start_time = time.time()
    self.epochs_done_in_this_run = 0
    print(
        self.state["current_iter"],
        int(self.args.total_iter_per_epoch * self.args.total_epochs),
    )

    # At the very start, log an initial histogram of every model parameter.
    if self.epoch == 0:
        for name, tensor in self.model.named_parameters():
            self.writer.add_histogram(name, tensor, 0)
        self.writer.flush()
def __init__(self, args, data):
    """Build the few-shot DAGAN training graph, placeholders and bookkeeping.

    :param args: Parsed hyperparameters (architecture depths, loss weights, flags).
    :param data: A data provider exposing image dimensions, class counts,
                 support_number and the split sizes used to compute batch totals.
    """
    tf.reset_default_graph()
    self.continue_from_epoch = args.continue_from_epoch
    self.experiment_name = args.experiment_title
    self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(
        self.experiment_name)
    self.num_gpus = args.num_of_gpus
    self.batch_size = args.batch_size
    # self.support_number = args.support_number
    self.selected_classes = args.selected_classes
    gen_depth_per_layer = args.generator_inner_layers
    discr_depth_per_layer = args.discriminator_inner_layers
    self.z_dim = args.z_dim
    self.num_generations = args.num_generations
    self.dropout_rate_value = args.dropout_rate_value
    self.data = data
    self.reverse_channels = False

    # Generator architecture: either a deeper 6-stage or the default 4-stage stack.
    if args.generation_layers == 6:
        generator_layers = [64, 64, 128, 128, 256, 256]
        gen_inner_layers = [
            gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
            gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer
        ]
        generator_layer_padding = [
            "SAME", "SAME", "SAME", "SAME", "SAME", "SAME"
        ]
    else:
        generator_layers = [64, 64, 128, 128]
        gen_inner_layers = [
            gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
            gen_depth_per_layer
        ]
        generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]

    discriminator_layers = [64, 64, 128, 128]
    discr_inner_layers = [
        discr_depth_per_layer, discr_depth_per_layer, discr_depth_per_layer,
        discr_depth_per_layer
    ]
    image_height = data.image_height
    image_width = data.image_width
    image_channel = data.image_channel

    self.classes = tf.placeholder(tf.int32)
    # NOTE(review): this overwrites the args-derived value assigned above with a
    # placeholder; downstream code uses self.data.selected_classes instead, but
    # confirm the shadowing is intentional.
    self.selected_classes = tf.placeholder(tf.int32)
    self.support_number = tf.placeholder(tf.int32)

    #### [self.input_x_i, self.input_y_i, self.input_global_y_i] --> [images, few shot label, global label]
    ## batch: [self.input_x_i, self.input_y_i, self.input_global_y_i]
    ## support: self.input_x_j, self.input_y_j, self.input_global_y_j]
    ## the input of discriminator: [self.input_x_j_selected, self.input_global_y_j_selected]
    self.input_x_i = tf.placeholder(tf.float32, [
        self.num_gpus,
        self.batch_size, image_height, image_width, image_channel
    ], 'batch')
    self.input_y_i = tf.placeholder(
        tf.float32,
        [self.num_gpus, self.batch_size, self.data.selected_classes],
        'y_inputs_bacth')
    self.input_global_y_i = tf.placeholder(
        tf.float32,
        [self.num_gpus, self.batch_size, self.data.training_classes],
        'y_inputs_bacth_global')
    # Support set: selected_classes * support_number images per batch element.
    self.input_x_j = tf.placeholder(tf.float32, [
        self.num_gpus, self.batch_size,
        self.data.selected_classes * self.data.support_number, image_height,
        image_width, image_channel
    ], 'support')
    self.input_y_j = tf.placeholder(tf.float32, [
        self.num_gpus, self.batch_size,
        self.data.selected_classes * self.data.support_number,
        self.data.selected_classes
    ], 'y_inputs_support')
    self.input_global_y_j = tf.placeholder(tf.float32, [
        self.num_gpus, self.batch_size,
        self.data.selected_classes * self.data.support_number,
        self.data.training_classes
    ], 'y_inputs_support_global')
    # Single support image + its global label fed to the discriminator.
    self.input_x_j_selected = tf.placeholder(tf.float32, [
        self.num_gpus, self.batch_size, image_height, image_width,
        image_channel
    ], 'support_discriminator')
    self.input_global_y_j_selected = tf.placeholder(
        tf.float32,
        [self.num_gpus, self.batch_size, self.data.training_classes],
        'y_inputs_support_discriminator')

    # self.z_input = tf.placeholder(tf.float32, [self.batch_size*self.data.selected_classes, self.z_dim], 'z-input')
    # self.z_input_2 = tf.placeholder(tf.float32, [self.batch_size*self.data.selected_classes, self.z_dim], 'z-input_2')
    self.z_input = tf.placeholder(tf.float32, [self.batch_size, self.z_dim],
                                  'z-input')
    self.z_input_2 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim],
                                    'z-input_2')

    self.training_phase = tf.placeholder(tf.bool, name='training-flag')
    self.z1z2_training = tf.placeholder(tf.bool, name='z1z2_training-flag')
    self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
    self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

    # Model/loss configuration flags copied from args.
    self.is_z2 = args.is_z2
    self.is_z2_vae = args.is_z2_vae
    self.matching = args.matching
    self.fce = args.fce
    self.full_context_unroll_k = args.full_context_unroll_k
    self.average_per_class_embeddings = args.average_per_class_embeddings
    self.restore_path = args.restore_path
    self.is_z2 = args.is_z2
    self.is_z2_vae = args.is_z2_vae
    # Loss weights/switches for the composite objective.
    self.loss_G = args.loss_G
    self.loss_D = args.loss_D
    self.loss_CLA = args.loss_CLA
    self.loss_FSL = args.loss_FSL
    self.loss_KL = args.loss_KL
    self.loss_recons_B = args.loss_recons_B
    self.loss_matching_G = args.loss_matching_G
    self.loss_matching_D = args.loss_matching_D
    self.loss_sim = args.loss_sim
    self.strategy = args.strategy

    #### training/validation/testin
    time_1 = time.time()
    dagan = DAGAN(batch_size=self.batch_size, input_x_i=self.input_x_i, input_x_j=self.input_x_j,
                  input_y_i=self.input_y_i, input_y_j=self.input_y_j,
                  input_global_y_i=self.input_global_y_i, input_global_y_j=self.input_global_y_j,
                  input_x_j_selected=self.input_x_j_selected,
                  input_global_y_j_selected=self.input_global_y_j_selected, \
                  selected_classes=self.data.selected_classes,
                  support_num=self.data.support_number,
                  classes=self.data.training_classes,
                  dropout_rate=self.dropout_rate, generator_layer_sizes=generator_layers,
                  generator_layer_padding=generator_layer_padding,
                  num_channels=data.image_channel,
                  is_training=self.training_phase, augment=self.random_rotate,
                  discriminator_layer_sizes=discriminator_layers,
                  discr_inner_conv=discr_inner_layers,
                  is_z2=self.is_z2, is_z2_vae=self.is_z2_vae,
                  gen_inner_conv=gen_inner_layers, num_gpus=self.num_gpus,
                  z_dim=self.z_dim, z_inputs=self.z_input,
                  z_inputs_2=self.z_input_2,
                  use_wide_connections=args.use_wide_connections,
                  fce=self.fce, matching=self.matching,
                  full_context_unroll_k=self.full_context_unroll_k,
                  average_per_class_embeddings=self.average_per_class_embeddings,
                  loss_G=self.loss_G, loss_D=self.loss_D, loss_KL=self.loss_KL,
                  loss_recons_B=self.loss_recons_B,
                  loss_matching_G=self.loss_matching_G,
                  loss_matching_D=self.loss_matching_D,
                  loss_CLA=self.loss_CLA, loss_FSL=self.loss_FSL,
                  loss_sim=self.loss_sim, z1z2_training=self.z1z2_training)
    self.same_images = dagan.sample_same_images()
    # self.summary, self.losses, self.accuracy, self.graph_ops = classifier.init_train()

    # Batches per epoch for each split, divided evenly across GPUs.
    self.total_train_batches = int(data.training_data_size /
                                   (self.batch_size * self.num_gpus))
    self.total_val_batches = int(data.validation_data_size /
                                 (self.batch_size * self.num_gpus))
    self.total_test_batches = int(data.testing_data_size /
                                  (self.batch_size * self.num_gpus))
    self.total_gen_batches = int(data.testing_data_size /
                                 (self.batch_size * self.num_gpus))
    self.init = tf.global_variables_initializer()
    time_2 = time.time()
    # print('time for constructing graph:',time_2 - time_1)
    self.tensorboard_update_interval = int(self.total_train_batches / 1 /
                                           self.num_gpus)
    self.total_epochs = 800
    self.is_generation_for_classifier = args.is_generation_for_classifier
    self.is_all_test_categories = args.is_all_test_categories
def __init__(self, data_dict, model, experiment_name, continue_from_epoch,
             max_models_to_save, total_iter_per_epoch, total_epochs,
             num_evaluation_tasks, batch_size, evaluate_on_test_set_only, args):
    """
    Initializes an experiment builder using a named tuple (args), a data provider (data),
    a meta learning system (model) and a device (e.g. gpu/cpu/n)
    :param args: A namedtuple containing all experiment hyperparameters
    :param data: A data provider of instance MetaLearningSystemDataLoader
    :param model: A meta learning system instance
    :param device: Device/s to use for the experiment
    """
    self.model = model
    self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
        experiment_name=experiment_name)
    self.total_losses = dict()
    # Experiment progress state; may be replaced wholesale by a restored checkpoint.
    self.state = dict()
    self.state['best_val_acc'] = 0.
    self.state['best_val_iter'] = 0
    self.state['current_iter'] = 0  # fix: original assigned this key twice in a row
    self.start_epoch = 0
    self.max_models_to_save = max_models_to_save
    self.create_summary_csv = False
    self.evaluate_on_test_set_only = evaluate_on_test_set_only
    # Mirror every CLI argument onto the builder for convenient attribute access.
    for key, value in args.__dict__.items():
        setattr(self, key, value)

    if continue_from_epoch == 'from_scratch':
        self.create_summary_csv = True
    elif continue_from_epoch == 'latest':
        checkpoint = os.path.join(self.saved_models_filepath, "train_model_latest")
        print("attempting to find existing checkpoint", )
        if os.path.exists(checkpoint):
            try:
                self.state = \
                    self.model.load_model(model_save_dir=self.saved_models_filepath,
                                          model_name="train_model",
                                          model_idx='latest')
                self.start_epoch = int(self.state['current_iter'] /
                                       total_iter_per_epoch)
            # fix: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; a corrupt checkpoint still falls back to a fresh run.
            except Exception:
                self.continue_from_epoch = 'from_scratch'
                self.create_summary_csv = True
        else:
            self.continue_from_epoch = 'from_scratch'
            self.create_summary_csv = True
    elif int(continue_from_epoch) >= 0:
        checkpoint = os.path.join(
            self.saved_models_filepath,
            "train_model_{}".format(continue_from_epoch))
        if os.path.exists(checkpoint):
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath,
                                      model_name="train_model",
                                      model_idx=continue_from_epoch)
            self.start_epoch = int(self.state['current_iter'] /
                                   total_iter_per_epoch)
        else:
            self.continue_from_epoch = 'from_scratch'
            self.create_summary_csv = True

    self.data = data_dict
    self.total_iter_per_epoch = total_iter_per_epoch
    self.batch_size = batch_size
    self.total_epochs = total_epochs
    self.num_evaluation_tasks = num_evaluation_tasks
    print("train_seed {}, val_seed: {}, at start time".format(
        self.data["train"].dataset.seed, self.data["val"].dataset.seed))
    self.state['best_epoch'] = int(self.state['best_val_iter'] /
                                   total_iter_per_epoch)
    self.epoch = int(self.state['current_iter'] / total_iter_per_epoch)
    self.start_time = time.time()
    self.epochs_done_in_this_run = 0
    print(self.state['current_iter'],
          int(total_iter_per_epoch * total_epochs))