def initialize(self):
    """Build the model in training mode and create a saver for its variables.

    Side effects: assigns ``self.model`` and ``self.saver_sparse``. The saver
    covers the GLOBAL_VARIABLES collection filtered by the model's variable
    scope.
    """
    model = ModelBuilder(self.config).get_model()
    self.model = model
    model.config_name = self.config_name
    try:
        model.set_training(True)
    except AttributeError:
        # An AttributeError here (e.g. no set_training hook) is deliberately ignored.
        pass
    model.initialize()
    scope = model.get_variable_scope()
    scoped_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)
    self.saver_sparse = tf.train.Saver(scoped_variables)
def initialize_profile(self):
    """Build the model for profiling, using the module-level batch size ``bb``.

    ``self.config['batch_size']`` is temporarily set to ``str(bb)`` while the
    model is constructed, then restored to ``str(self.num_batch)``. The
    restore happens even if construction raises, so a failed attempt does not
    leave the shared config mutated. The model is put in non-training mode,
    initialized, and ``self.saver_sparse`` is created over the global
    variables in the model's variable scope.
    """
    self.config['batch_size'] = str(bb)
    try:
        self.model = ModelBuilder(self.config).get_model()
    finally:
        # Always undo the temporary batch-size override.
        self.config['batch_size'] = str(self.num_batch)
    self.model.config_name = self.config_name
    try:
        self.model.set_training(False)
    except AttributeError:
        # An AttributeError here (e.g. no set_training hook) is deliberately ignored.
        pass
    self.model.initialize()
    self.saver_sparse = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                          self.model.get_variable_scope()))
def train_chord_to_melody_model(self, tt_split=0.9, epochs=100,
                                model_name='basic_rnn'):
    '''
    Train model step - model takes in chord piano roll and outputs melody piano roll.

    If ``basic_rnn.h5`` exists, its weights are loaded into a freshly built
    basic RNN. Otherwise that same architecture is trained and its weights
    are saved to ``basic_rnn.h5``. Building one architecture in both paths
    guarantees the saved weights can be loaded back later.

    :param tt_split: train test split
    :param epochs: number of epochs to train
    :param model_name: specify which model we are training; only 'basic_rnn'
        is supported
    :return: None. Model is assigned as self.model for this generator
    :raises ValueError: if model_name is not supported
    '''
    weights_path = "basic_rnn.h5"

    # Check the name before the (potentially expensive) data preparation.
    # Previously an unsupported name fell through and crashed with
    # UnboundLocalError on `model`.
    if model_name != 'basic_rnn':
        raise ValueError("Unsupported model_name: {}".format(model_name))

    # Train test split
    # NOTE(review): load_model() uses src='nottingham' (piano roll) for the
    # basic_rnn_* models; confirm 'nottingham-embed' is intended here.
    self.__prepare_data_tt_splited(tt_split=tt_split, model_name=model_name,
                                   src="nottingham-embed")

    mb = ModelBuilder(self.X_train, self.Y_train, self.X_test, self.Y_test)
    model = mb.build_basic_rnn_model(input_dim=self.X_train.shape[1:])
    if os.path.exists(weights_path):
        model.load_weights(weights_path)
    else:
        model = mb.train_model(model, epochs, loss="categorical_crossentropy")
        model.save_weights(weights_path)
    self.model = model
class SparseConvClusteringTrainer:
    """Train, validate, test, profile and visualize a sparse-conv clustering model.

    Every setting is read from one section of an INI config file (see
    ``__init__``). The model is obtained from ``ModelBuilder`` and driven
    through a TF1-style graph/session API: queue-runner input feeds,
    ``feed_dict`` placeholders and ``tf.train.Saver`` checkpoints.
    """

    def read_config(self, config_file_path, config_name):
        """Load section ``config_name`` of ``config_file_path`` into ``self.config``.

        Raises:
            KeyError: if the section is absent. ``ConfigParser.read`` silently
                skips unreadable files, so a bad path also ends up here.
        """
        config_file = cp.ConfigParser()
        config_file.read(config_file_path)
        self.config = config_file[config_name]

    def _parse_indices(self, key):
        """Parse the comma-separated integers stored under ``key`` into a tuple."""
        return tuple(int(x) for x in self.config[key].split(','))

    def __init__(self, config_file, config_name):
        """Read every trainer setting from the given config section.

        Args:
            config_file: path of the INI config file.
            config_name: name of the section to read.

        Raises:
            KeyError: if a required key is missing.
            RuntimeError: if ``plotting_input_file_path`` is set but
                ``plot_after`` is not.
        """
        self.read_config(config_file, config_name)
        self.config_name = config_name
        self.from_scratch = int(self.config['from_scratch']) == 1
        self.model_path = self.config['model_path']
        self.summary_path = self.config['summary_path']
        self.test_out_path = self.config['test_out_path']
        self.profile_out_path = self.config['profiler_out_path']
        self.train_for_iterations = int(self.config['train_for_iterations'])
        self.save_after_iterations = int(self.config['save_after_iterations'])
        # Used by train() when the model has no learning-rate scheduler.
        self.learning_rate = float(self.config['learning_rate'])
        self.training_files = self.config['training_files_list']
        self.validation_files = self.config['validation_files_list']
        self.test_files = self.config['test_files_list']
        self.validate_after = int(self.config['validate_after'])
        self.num_testing_samples = int(self.config['num_testing_samples'])
        self.num_batch = int(self.config['batch_size'])
        self.num_max_entries = int(self.config['max_entries'])
        self.num_data_dims = int(self.config['num_data_dims'])

        # Optional: when true, test() also streams inputs[2] for each sample.
        try:
            self.output_seed_indices = int(
                self.config['output_seed_indices_in_inference']) == 1
        except KeyError:
            self.output_seed_indices = False

        try:
            self.plotting_input_file_path = self.config[
                'plotting_input_file_path']
        except KeyError:
            self.plotting_input_file_path = None
        if self.plotting_input_file_path is not None:
            try:
                self.plot_after = int(self.config['plot_after'])
            except KeyError:
                raise RuntimeError(
                    "plotting_input_file_path is set but plot_after is missing"
                ) from None
        else:
            # -1 means plotting is disabled.
            self.plot_after = -1

        self.spatial_features_indices = self._parse_indices(
            'input_spatial_features_indices')
        self.spatial_features_local_indices = self._parse_indices(
            'input_spatial_features_local_indices')
        self.other_features_indices = self._parse_indices(
            'input_other_features_indices')
        self.target_indices = self._parse_indices('target_indices')

        # An empty reader_type falls back to the default reader.
        reader_type = self.config['reader_type']
        self.reader_type = reader_type if len(reader_type) != 0 else \
            "data_and_num_entries_reader"
        self.reader_factory = ReaderFactory()
        self.model = None

    def _setup_built_model(self, is_training):
        """Finish setting up ``self.model`` and create ``self.saver_sparse``.

        The saver covers the GLOBAL_VARIABLES collection filtered by the
        model's variable scope.
        """
        self.model.config_name = self.config_name
        try:
            self.model.set_training(is_training)
        except AttributeError:
            # An AttributeError here (e.g. no set_training hook) is ignored.
            pass
        self.model.initialize()
        self.saver_sparse = tf.train.Saver(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              self.model.get_variable_scope()))

    def initialize(self):
        """Build the model in training mode."""
        self.model = ModelBuilder(self.config).get_model()
        self._setup_built_model(True)

    def initialize_test(self):
        """Build the model in non-training mode."""
        self.model = ModelBuilder(self.config).get_model()
        self._setup_built_model(False)

    def initialize_profile(self):
        """Build the model in non-training mode with the module-level batch size ``bb``.

        ``self.config['batch_size']`` is overridden only while the model is
        constructed. It is restored to ``str(self.num_batch)`` even if
        construction raises.
        """
        self.config['batch_size'] = str(bb)
        try:
            self.model = ModelBuilder(self.config).get_model()
        finally:
            self.config['batch_size'] = str(self.num_batch)
        self._setup_built_model(False)

    def _make_feed_dict(self, placeholders, inputs, is_train, learning_rate,
                        feed_seed_indices):
        """Map the model placeholders to slices of one batch read from a feed.

        ``inputs[0]`` is sliced on its third axis into spatial, local spatial,
        other and target features (placeholders 0-3). ``inputs[1]`` feeds
        placeholder 4. When ``feed_seed_indices`` is true and the model does
        not have exactly five placeholders, ``inputs[2]`` also feeds
        placeholder 5.
        """
        data = inputs[0]
        feed = {
            placeholders[0]: data[:, :, self.spatial_features_indices],
            placeholders[1]: data[:, :, self.spatial_features_local_indices],
            placeholders[2]: data[:, :, self.other_features_indices],
            placeholders[3]: data[:, :, self.target_indices],
            placeholders[4]: inputs[1],
            self.model.is_train: is_train,
            self.model.learning_rate: learning_rate,
        }
        if feed_seed_indices and len(placeholders) != 5:
            feed[placeholders[5]] = inputs[2]
        return feed

    def clean_summary_dir(self):
        """Delete the regular files directly inside ``self.summary_path``.

        Subdirectories are left untouched. Per-file OS errors are printed and
        skipped, so cleaning is best-effort.
        """
        print("Cleaning summary dir")
        for the_file in os.listdir(self.summary_path):
            file_path = os.path.join(self.summary_path, the_file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except OSError as e:
                print(e)

    def profile(self):
        """Profile 20 runs of the compute graph at the module-level batch size ``bb``.

        Prints per-iteration wall time. Afterwards it prints the mean and
        standard deviation of ``total_exec_micros`` over all but the first
        iteration, plus the peak bytes reported for the last iteration. A
        timeline is written under ``self.profile_out_path``.
        """
        tf.reset_default_graph()
        self.initialize_profile()
        print("Beginning to profile network with parameters",
              get_num_parameters(self.model.get_variable_scope()))
        placeholders = self.model.get_placeholders()
        os.makedirs(self.profile_out_path, exist_ok=True)
        graph_output = self.model.get_compute_graphs()
        inputs_feed = self.reader_factory.get_class(self.reader_type)(
            self.training_files, self.num_max_entries, self.num_data_dims,
            bb).get_feeds()
        init = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ]
        # Single-threaded execution, presumably for more repeatable timings.
        session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                                      inter_op_parallelism_threads=1)
        inference_time_values = []
        with tf.Session(config=session_conf) as sess:
            sess.run(init)
            profiler = Profiler(sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print("Starting iterations")
            for iteration_number in range(20):
                inputs_train = sess.run(list(inputs_feed))
                inputs_train_dict = self._make_feed_dict(
                    placeholders, inputs_train, is_train=True,
                    learning_rate=1, feed_seed_indices=True)
                run_meta = tf.RunMetadata()
                start_time = time.time()
                sess.run(graph_output,
                         feed_dict=inputs_train_dict,
                         options=tf.RunOptions(
                             trace_level=tf.RunOptions.FULL_TRACE),
                         run_metadata=run_meta)
                print("XC Time: --- %s seconds --- Iteration %d" %
                      (time.time() - start_time, iteration_number))
                profiler.add_step(iteration_number, run_meta)

                # Profile the timing of the model operations.
                opts = option_builder.ProfileOptionBuilder.time_and_memory()
                profiler.profile_operations(options=opts)

                # Write a timeline and collect graph-level totals.
                opts = (option_builder.ProfileOptionBuilder(
                    option_builder.ProfileOptionBuilder.time_and_memory()
                ).with_step(iteration_number).with_timeline_output(
                    os.path.join(self.profile_out_path, 'profile')).build())
                x = profiler.profile_graph(options=opts)
                inference_time_values.append(x.total_exec_micros)
                peak_bytes = x.total_peak_bytes

            print(self.config_name, "Batch size: ", bb)
            print(repr(np.array(inference_time_values)))
            # The first iteration is excluded from the statistics.
            timings = np.array(inference_time_values, dtype=np.float32)[1:]
            print("Mean", np.mean(timings))
            # np.std is a standard deviation; the old "Variance" label was wrong.
            print("Std dev", np.std(timings))
            print("Peak bytes", peak_bytes)
            # Stop the threads
            coord.request_stop()
            # Wait for threads to stop
            coord.join(threads)

    def train(self):
        """Train for ``train_for_iterations`` iterations.

        Validation runs every ``validate_after`` iterations and a checkpoint
        is saved every ``save_after_iterations`` iterations.

        When ``from_scratch`` is set:
        - the summary and test-output directories are created;
        - the model code and the ``ops`` package's ``.py`` files are copied
          next to the test outputs;
        - the summary directory is cleaned.

        Otherwise the checkpoint at ``model_path`` is restored, and the
        iteration counter is read from ``model_path + '.txt'``.
        """
        self.initialize()
        print("Beginning to train network with parameters",
              get_num_parameters(self.model.get_variable_scope()))
        print("Variable scope:", self.model.get_variable_scope())
        placeholders = self.model.get_placeholders()

        if self.from_scratch:
            os.makedirs(self.summary_path, exist_ok=True)
            ops_copy_dir = os.path.join(self.test_out_path, 'ops')
            # Also creates self.test_out_path.
            os.makedirs(ops_copy_dir, exist_ok=True)
            with open(self.model_path + '_code.py', 'w') as f:
                f.write(self.model.get_code())
            ops_parent = os.path.dirname(ops.__file__)
            for ops_file in os.listdir(ops_parent):
                if not ops_file.endswith('.py'):
                    continue
                shutil.copy(os.path.join(ops_parent, ops_file), ops_copy_dir)

        graph_loss = self.model.get_losses()
        graph_optmiser = self.model.get_optimizer()
        graph_summary = self.model.get_summary()
        graph_summary_validation = self.model.get_summary_validation()
        graph_output = self.model.get_compute_graphs()
        graph_temp = self.model.get_temp()

        # TODO: load plotting data and plot every plot_after iterations when
        # plot_after != -1 (not implemented yet).

        if self.from_scratch:
            self.clean_summary_dir()

        reader_class = self.reader_factory.get_class(self.reader_type)
        inputs_feed = reader_class(self.training_files, self.num_max_entries,
                                   self.num_data_dims,
                                   self.num_batch).get_feeds()
        inputs_validation_feed = reader_class(
            self.validation_files, self.num_max_entries, self.num_data_dims,
            self.num_batch).get_feeds(shuffle=False)

        print("\n****************************************")
        print("Feed Input type", type(inputs_feed[0]))
        print("Feed Input shape: ", inputs_feed[0].get_shape().as_list())

        init = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ]
        with tf.Session() as sess:
            sess.run(init)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            summary_writer = tf.summary.FileWriter(self.summary_path,
                                                   sess.graph)

            if not self.from_scratch:
                self.saver_sparse.restore(sess, self.model_path)
                print("\n\nINFO: Loading model\n\n")
                with open(self.model_path + '.txt', 'r') as f:
                    iteration_number = int(f.read())
            else:
                iteration_number = 0

            print("Starting iterations")
            while iteration_number < self.train_for_iterations:
                inputs_train = sess.run(list(inputs_feed))
                print("\n****************************************")
                print("Input Train type", type(inputs_train[0]))
                print("Input Train shape: ", inputs_train[0].shape)

                if hasattr(self.model, "learningrate_scheduler"):
                    learning_rate = self.model.learningrate_scheduler.get_lr(
                        iteration_number)
                else:
                    # self.model.learning_rate is the placeholder being fed
                    # (it is the feed_dict key), so it cannot also be the fed
                    # value. Use the configured rate instead.
                    learning_rate = self.learning_rate
                if iteration_number == 0:
                    print('learning rate ', learning_rate)

                inputs_train_dict = self._make_feed_dict(
                    placeholders, inputs_train, is_train=True,
                    learning_rate=learning_rate, feed_seed_indices=False)
                t, eval_loss, _, eval_summary, _ = sess.run(
                    [
                        graph_temp, graph_loss, graph_optmiser, graph_summary,
                        graph_output
                    ],
                    feed_dict=inputs_train_dict)

                if iteration_number % self.validate_after == 0:
                    inputs_validation = sess.run(list(inputs_validation_feed))
                    self.inputs_plot = inputs_validation
                    inputs_validation_dict = self._make_feed_dict(
                        placeholders, inputs_validation, is_train=False,
                        learning_rate=learning_rate, feed_seed_indices=False)
                    eval_loss_validation, eval_summary_validation = sess.run(
                        [graph_loss, graph_summary_validation],
                        feed_dict=inputs_validation_dict)
                    summary_writer.add_summary(eval_summary_validation,
                                               iteration_number)
                    print("Validation - Iteration %4d: loss %.6E" %
                          (iteration_number, eval_loss_validation))

                print("Training - Iteration %4d: loss %0.6E" %
                      (iteration_number, eval_loss))
                print(t[0])
                iteration_number += 1
                summary_writer.add_summary(eval_summary, iteration_number)
                if iteration_number % self.save_after_iterations == 0:
                    print("\n\nINFO: Saving model\n\n")
                    self.saver_sparse.save(sess, self.model_path)
                    with open(self.model_path + '.txt', 'w') as f:
                        f.write(str(iteration_number))

            # Stop the threads
            coord.request_stop()
            # Wait for threads to stop
            coord.join(threads)

    def visualize(self):
        """Restore the checkpoint and plot per-layer features for test events.

        The features come from ``self.model.temp_feat_visualize``. Once enough
        batches have been read to reach 32 events, 32 events are plotted and
        the process exits via ``sys.exit(0)``.
        """
        self.initialize_test()
        print("Beginning to visualize network with parameters",
              get_num_parameters(self.model.get_variable_scope()))
        placeholders = self.model.get_placeholders()
        graph_loss = self.model.get_losses()
        graph_output = self.model.get_compute_graphs()
        graph_temp = self.model.get_temp()
        layer_feats = self.model.temp_feat_visualize

        inputs_feed = self.reader_factory.get_class(self.reader_type)(
            self.test_files, self.num_max_entries, self.num_data_dims,
            self.num_batch).get_feeds(shuffle=False)

        init = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ]
        with tf.Session() as sess:
            sess.run(init)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            self.saver_sparse.restore(sess, self.model_path)
            print("\n\nINFO: Loading model", self.model_path, "\n\n")

            print("Starting visualizing")
            num_iterations = int(
                np.ceil(self.num_testing_samples / self.num_batch))
            iteration_number = 0
            while iteration_number < num_iterations:
                inputs_test = sess.run(list(inputs_feed))
                print("Run")
                inputs_train_dict = self._make_feed_dict(
                    placeholders, inputs_test, is_train=False,
                    learning_rate=0, feed_seed_indices=True)

                eval_out = sess.run([graph_temp, graph_loss, graph_output] +
                                    layer_feats,
                                    feed_dict=inputs_train_dict)
                layer_outs = eval_out[3:]
                prediction = eval_out[2]

                if iteration_number * self.num_batch + self.num_batch >= 32:
                    for x in range(32):
                        event_number = (32 + x) % self.num_batch
                        print("Event number", event_number)
                        seed_index = inputs_test[2][event_number, :]
                        print(seed_index)
                        event = inputs_test[0][event_number, :, :]
                        spatial_features = event[:, self.spatial_features_indices]
                        energy = event[:, 0]
                        gt = event[:, self.target_indices]
                        predictionx = prediction[event_number]
                        layer_outsx = [
                            layer_out[event_number] for layer_out in layer_outs
                        ]
                        if 'aggregators' in self.config_name:
                            plots.plot_clustering_layer_wise_visualize_agg(
                                spatial_features, energy, predictionx, gt,
                                layer_outsx, self.config_name)
                        else:
                            plots.plot_clustering_layer_wise_visualize(
                                spatial_features, energy, predictionx, gt,
                                layer_outsx, self.config_name)
                    sys.exit(0)  # Put the condition here!

                iteration_number += 1

            # Stop the threads
            coord.request_stop()
            # Wait for threads to stop
            coord.join(threads)

    def test(self):
        """Restore the checkpoint and stream test-set predictions to disk.

        Runs ceil(num_testing_samples / num_batch) batches. For every sample
        it adds ``(inputs[0][i], inputs[1][i, 0], prediction[i])`` to an
        ``InferenceOutputStreamer`` writing to ``test_out_path``; ``inputs[2][i]``
        is appended when ``output_seed_indices`` is set.
        """
        self.initialize_test()
        print("Beginning to test network with parameters",
              get_num_parameters(self.model.get_variable_scope()))
        placeholders = self.model.get_placeholders()
        graph_loss = self.model.get_losses()
        graph_output = self.model.get_compute_graphs()
        graph_temp = self.model.get_temp()

        inputs_feed = self.reader_factory.get_class(self.reader_type)(
            self.test_files, self.num_max_entries, self.num_data_dims,
            self.num_batch).get_feeds(shuffle=False)
        inference_streamer = InferenceOutputStreamer(
            output_path=self.test_out_path, cache_size=100)
        inference_streamer.start_thread()

        print(type(inputs_feed))
        print("****************************************")
        # inputs_feed is a sequence of tensors (see list(inputs_feed) below);
        # the sequence itself has no get_shape().
        print("Test Input shape: ", inputs_feed[0].get_shape().as_list())

        init = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ]
        with tf.Session() as sess:
            sess.run(init)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            self.saver_sparse.restore(sess, self.model_path)
            print("\n\nINFO: Loading model", self.model_path, "\n\n")

            print("Starting testing")
            num_iterations = int(
                np.ceil(self.num_testing_samples / self.num_batch))
            iteration_number = 0
            while iteration_number < num_iterations:
                inputs_test = sess.run(list(inputs_feed))
                inputs_train_dict = self._make_feed_dict(
                    placeholders, inputs_test, is_train=False,
                    learning_rate=0, feed_seed_indices=False)

                t, eval_loss, eval_output = sess.run(
                    [graph_temp, graph_loss, graph_output],
                    feed_dict=inputs_train_dict)

                print("Adding", len(inputs_test[0]), "test results")
                for i in range(len(inputs_test[0])):
                    if not self.output_seed_indices:
                        inference_streamer.add(
                            (inputs_test[0][i], (inputs_test[1])[i, 0],
                             eval_output[i]))
                    else:
                        inference_streamer.add(
                            (inputs_test[0][i], (inputs_test[1])[i, 0],
                             eval_output[i], inputs_test[2][i]))

                print("Testing - Sample %4d: loss %0.5f" %
                      (iteration_number * self.num_batch, eval_loss))
                print(t[0])
                iteration_number += 1

            # Stop the threads
            coord.request_stop()
            # Wait for threads to stop
            coord.join(threads)

        inference_streamer.close()
def main():
    """Run one experiment end to end, as described by the loaded configuration.

    Phases are gated by ``cf.train``, ``cf.validation``, ``cf.test`` and
    ``cf.predict_test``. Each phase's wall-clock time is written to the debug
    log at ``cf.log_file_debug``.

    Raises:
        ValueError: if ``cf.problem_type`` is not 'segmentation',
            'classification' or 'detection'.
    """
    # torch.cuda.get_device_name(0) raises on machines without CUDA, so only
    # query it when a device is available.
    if torch.cuda.is_available():
        print('Using GPU: ', torch.cuda.get_device_name(0))
    else:
        print('Using GPU: ', 'none (CUDA is not available)')
    start_time = time.time()

    # Prepare configuration
    config = Configuration()
    cf = config.load()

    # Enable log file
    logger_debug = Logger(cf.log_file_debug)
    logger_debug.write('\n ---------- Init experiment: ' + cf.exp_name +
                       ' ---------- \n')

    # Model building
    logger_debug.write('- Building model: ' + cf.model_name + ' <--- ')
    model = ModelBuilder(cf)
    model.build()
    print(model.net)

    # Problem type
    if cf.problem_type == 'segmentation':
        problem_manager = SemanticSegmentationManager(cf, model)
    elif cf.problem_type == 'classification':
        problem_manager = ClassificationManager(cf, model)
    elif cf.problem_type == 'detection':
        problem_manager = DetectionManager(cf, model)
    else:
        raise ValueError('Unknown problem type')

    # Create dataloader builder
    dataloader = DataLoaderBuilder(cf, model)
    # True once build_valid() has run during training. Only then can the
    # validation phase reuse the loader; otherwise it must build one.
    valid_loader_built = False

    if cf.train:
        model.net.train()  # enable dropout modules and others
        train_time = time.time()
        logger_debug.write('\n- Reading Train dataset: ')
        dataloader.build_train()
        has_valid_source = (cf.valid_dataset_path is not None or
                            (cf.valid_images_txt is not None and
                             cf.valid_gt_txt is not None))
        if has_valid_source and cf.valid_samples_epoch != 0:
            logger_debug.write('\n- Reading Validation dataset: ')
            dataloader.build_valid(cf.valid_samples_epoch, cf.valid_images_txt,
                                   cf.valid_gt_txt, cf.resize_image_valid,
                                   cf.valid_batch_size)
            valid_loader_built = True
            problem_manager.trainer.start(dataloader.train_loader,
                                          dataloader.train_set,
                                          dataloader.valid_set,
                                          dataloader.valid_loader)
        else:
            # Train without validation inside epoch
            problem_manager.trainer.start(dataloader.train_loader,
                                          dataloader.train_set)
        train_time = time.time() - train_time
        logger_debug.write('\t Train step finished: %ds ' % train_time)

    if cf.validation:
        model.net.eval()
        valid_time = time.time()
        if not valid_loader_built:
            logger_debug.write('- Reading Validation dataset: ')
            dataloader.build_valid(cf.valid_samples, cf.valid_images_txt,
                                   cf.valid_gt_txt, cf.resize_image_valid,
                                   cf.valid_batch_size)
        else:
            # The validation loader built during training is reused; only
            # update the total number of images to take.
            dataloader.valid_set.update_indexes(
                cf.valid_samples,
                valid=True)  # valid=True avoids shuffle for validation
        logger_debug.write('\n- Starting validation <---')
        problem_manager.validator.start(dataloader.valid_set,
                                        dataloader.valid_loader, 'Validation')
        valid_time = time.time() - valid_time
        logger_debug.write('\t Validation step finished: %ds ' % valid_time)

    if cf.test:
        model.net.eval()
        test_time = time.time()
        logger_debug.write('\n- Reading Test dataset: ')
        dataloader.build_valid(cf.test_samples, cf.test_images_txt,
                               cf.test_gt_txt, cf.resize_image_test,
                               cf.test_batch_size)
        logger_debug.write('\n - Starting test <---')
        problem_manager.validator.start(dataloader.valid_set,
                                        dataloader.valid_loader, 'Test')
        test_time = time.time() - test_time
        logger_debug.write('\t Test step finished: %ds ' % test_time)

    if cf.predict_test:
        model.net.eval()
        pred_time = time.time()
        logger_debug.write('\n- Reading Prediction dataset: ')
        dataloader.build_predict()
        logger_debug.write('\n - Generating predictions <---')
        problem_manager.predictor.start(dataloader.predict_loader)
        pred_time = time.time() - pred_time
        logger_debug.write('\t Prediction step finished: %ds ' % pred_time)

    total_time = time.time() - start_time
    logger_debug.write('\n- Experiment finished: %ds ' % total_time)
    logger_debug.write('\n')
def load_model(self, model_name, tt_split=0.9, is_fast_load=True):
    """Build the named model, load its stored weights and keep it as ``self.model``.

    Args:
        model_name: key selecting the architecture, input shape and weights
            file. Unknown names are reported and the method returns early
            without setting ``self.model_name``.
        tt_split: train/test split ratio, used only when ``is_fast_load`` is false.
        is_fast_load: when true, skip data preparation and build the model
            from a ModelBuilder that has no data attached.
    """
    # Start from a clean Keras session before building another model.
    K.clear_session()
    print("Chosen model: {}".format(model_name))

    if is_fast_load:
        mb = ModelBuilder(None, None, None, None)
    else:
        embedding_models = ('bidem', 'attention', 'bidem_preload')
        data_src = ('nottingham-embed' if model_name in embedding_models
                    else 'nottingham')
        # Train test split
        self.__prepare_data_tt_splited(tt_split=tt_split,
                                       model_name=model_name,
                                       src=data_src)
        print('Chords shape: {} Melodies shape: {}'.format(
            self.X_train.shape, self.Y_train.shape))
        mb = ModelBuilder(self.X_train, self.Y_train, self.X_test, self.Y_test)

    weights_dir = '../note/active_models/'
    # model name -> (ModelBuilder method, input_dim, weights file name)
    model_specs = {
        'basic_rnn_normalized':
            ('build_basic_rnn_model', (1200, 128),
             'basic_rnn_weights_500.h5'),
        'basic_rnn_unnormalized':
            ('build_basic_rnn_model', (1200, 128),
             'basic_rnn_weights_500_unnormalized.h5'),
        'bidem':
            ('build_bidirectional_rnn_model', (1200, ),
             'bidem_weights_500.h5'),
        'bidem_regularized':
            ('build_bidirectional_rnn_model_no_embeddings', (1200, 1),
             'bidirectional_regularized_500.h5'),
        'attention':
            ('build_attention_bidirectional_rnn_model', (1200, ),
             'attention_weights_1000.h5'),
        'bidem_preload':
            ('build_bidirectional_rnn_model_no_embeddings', (None, 32),
             'bidirectional_embedding_preload_100.h5'),
    }

    spec = model_specs.get(model_name)
    if spec is None:
        print('No model name: {}'.format(model_name))
        return

    builder_name, input_dim, weights_file = spec
    self.model = getattr(mb, builder_name)(input_dim=input_dim)
    weights_path = weights_dir + weights_file
    print('Loading ' + weights_path + '...')
    self.model.load_weights(weights_path)
    self.model_name = model_name
# Put the parent of this file's directory at the front of sys.path. This
# presumably lets the project-local packages imported below (agents, models,
# utils) resolve when the script is run directly.
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

# These imports must stay after the sys.path insertion above.
# NOTE(review): SC2Wrapper and ProtossWrapper are imported but not used here.
from agents.actions.sc2_wrapper import SC2Wrapper, TerranWrapper, ProtossWrapper
from agents.states.sc2 import Simple64State
from models.model_builder import ModelBuilder
from utils.logger import Logger

# Configure a ModelBuilder with:
# - an input layer sized by the state dimension;
# - two fully connected layers (512 and 256 units);
# - an output layer sized by the action-space dimension.
state_builder = Simple64State()
action_wrapper = TerranWrapper()
helper = ModelBuilder()
helper.add_input_layer(int(state_builder.get_state_dim()))
helper.add_fullyconn_layer(512)
helper.add_fullyconn_layer(256)
helper.add_output_layer(action_wrapper.get_action_space_dim())

# Ad-hoc check of Logger persistence: save to ".", overwrite ep_total, then
# load from "." again.
# NOTE(review): the meaning of Logger(10000) and whether load() restores
# ep_total cannot be confirmed from here; nothing is printed after load().
b = Logger(10000)
print("Saving b")
b.save(".")
print(b.ep_total)
b.ep_total = 900
print(b.ep_total)
print("Loading b")
b.load(".")