def setup_model(self, load_checkpoint=None, print_model_summary=False):
    """Build the model and restore weights from a checkpoint when available.

    Args:
        load_checkpoint: optional explicit checkpoint path; when given it is
            loaded instead of the most recently modified file in the session's
            checkpoints directory, and the epoch counter restarts at 0.
        print_model_summary: if True, print the Keras model summary and dump
            the config to stdout.

    Returns:
        The compiled Keras model.

    Side effects: sets ``self.checkpoints_path``, ``self.samples_path``,
    ``self.history_filename`` and ``self.epoch_num``; creates the session
    directories; writes ``config.json`` on the first run.
    """
    session_path = self.config['training']['path']
    self.checkpoints_path = os.path.join(session_path, 'checkpoints')
    self.samples_path = os.path.join(session_path, 'samples')
    # os.path.basename is equivalent to the old rindex('/') slice for POSIX
    # paths, but does not raise ValueError when the path has no separator.
    self.history_filename = 'history_' + os.path.basename(session_path) + '.csv'

    model = self.build_model()

    if os.path.exists(self.checkpoints_path) and util.dir_contains_files(
            self.checkpoints_path):
        if load_checkpoint is not None:
            last_checkpoint_path = load_checkpoint
            self.epoch_num = 0
        else:
            # Resume from the most recently modified checkpoint file.
            checkpoints = os.listdir(self.checkpoints_path)
            checkpoints.sort(key=lambda x: os.stat(
                os.path.join(self.checkpoints_path, x)).st_mtime)
            last_checkpoint = checkpoints[-1]
            last_checkpoint_path = os.path.join(self.checkpoints_path,
                                                last_checkpoint)
            # The epoch number is the first digit run in the filename
            # (e.g. 'checkpoint.00005-...'). This replaces the fixed slice
            # [11:16], which broke for epochs of more than five digits.
            m = re.search(r'\d+', last_checkpoint)
            self.epoch_num = int(m.group(0)) if m else 0
        print('Loading model from epoch: %d' % self.epoch_num)
        model.load_weights(last_checkpoint_path)
    else:
        print('Building new model...')
        if not os.path.exists(session_path):
            os.makedirs(session_path)
        if not os.path.exists(self.checkpoints_path):
            os.makedirs(self.checkpoints_path)
        self.epoch_num = 0

    if not os.path.exists(self.samples_path):
        os.makedirs(self.samples_path)

    if print_model_summary:
        model.summary()

    self.compile_model(model)

    # Record the parameter count and persist the config once per session.
    self.config['model']['num_params'] = model.count_params()
    config_path = os.path.join(session_path, 'config.json')
    if not os.path.exists(config_path):
        util.pretty_json_dump(self.config, config_path)

    if print_model_summary:
        util.pretty_json_dump(self.config)

    return model
def setup_model(self):
    """Construct the model, optionally convert it for TPU, restore the newest
    checkpoint, and compile it.

    Returns:
        The compiled Keras model.

    Side effects: sets ``self.checkpoints_path`` and ``self.epoch_num``, and
    for non-TPU runs creates and records ``self.tb_path``.
    """
    self.checkpoints_path = os.path.join(self.config['training']['path'],
                                         'checkpoints')
    model = self.build_model()

    if not self.useTPU:
        # NOTE(review): self.epoch_num is read here before the checkpoint scan
        # below assigns it — presumably initialized elsewhere; verify.
        self.tb_path = os.path.join(self.config['tensorboard']['path'],
                                    "%d" % self.epoch_num)
        os.makedirs(self.tb_path, exist_ok=True)

    # One loss (and one weight) per source output head.
    losses = {}
    loss_weights = []
    for idx in range(self.num_sources):
        losses['data_output%d' % (idx + 1)] = self.get_out_loss()
        loss_weights.append(self.config['learning']['loss_weights'][idx])

    if self.useTPU:
        # Rebuild the model for TPU execution via the Colab-provided address.
        tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
        resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    have_checkpoints = (os.path.exists(self.checkpoints_path)
                        and util.dir_contains_files(self.checkpoints_path))
    if have_checkpoints:
        # Newest-by-mtime checkpoint wins.
        names = sorted(
            os.listdir(self.checkpoints_path),
            key=lambda f: os.stat(
                os.path.join(self.checkpoints_path, f)).st_mtime)
        newest = names[-1]
        newest_path = os.path.join(self.checkpoints_path, newest)
        match = re.match(r'checkpoint\.([0-9]+)-', newest)
        if match:
            self.epoch_num = int(match.group(1))
            print('Loading model from epoch: %d' % self.epoch_num)
            model.load_weights(newest_path)
        else:
            # Filename didn't follow the expected pattern; start fresh.
            self.epoch_num = 0
            print('Building new model...')
    else:
        print('Building new model...')
        os.makedirs(self.checkpoints_path, exist_ok=True)
        self.epoch_num = 0

    model.compile(optimizer=self.get_optimizer(),
                  loss=losses,
                  loss_weights=loss_weights,
                  metrics=self.get_metrics())
    return model
def get_latest_checkpoint_path(session_dir):
    """Return the path of the most recently modified checkpoint in session_dir.

    Args:
        session_dir: session directory expected to contain a 'checkpoints'
            subdirectory.

    Returns:
        Full path of the newest checkpoint file, or '' when the checkpoints
        directory is missing or empty.
    """
    checkpoints_dir = os.path.join(session_dir, 'checkpoints')
    # Guard clause: nothing to resume from.
    if not (os.path.exists(checkpoints_dir)
            and util.dir_contains_files(checkpoints_dir)):
        return ''
    # Sort ascending by modification time; the last entry is the newest.
    newest = sorted(
        os.listdir(checkpoints_dir),
        key=lambda name: os.stat(
            os.path.join(checkpoints_dir, name)).st_mtime)[-1]
    return os.path.join(checkpoints_dir, newest)
def setup_model(self):
    """Create and compile the SFUN model, restoring the newest checkpoint
    when one exists.

    Returns:
        The compiled Keras model (also stored on ``self.model``).

    Side effects: sets ``self.checkpoints_path``, ``self.history_filename``
    and ``self.epoch_num``; creates the checkpoints directory; records
    ``num_freq`` in the config and writes ``config.json`` on the first run.
    """
    session_dir = self._config['training']['session_dir']
    self.checkpoints_path = os.path.join(session_dir, 'checkpoints')
    # exist_ok avoids the race between the old existence check and mkdir.
    os.makedirs(self.checkpoints_path, exist_ok=True)
    # os.path.basename is equivalent to the old rindex('/') slice for POSIX
    # paths, but does not raise ValueError when the path has no separator.
    self.history_filename = 'history_' + os.path.basename(session_dir) + '.csv'

    self.model, inputs_mask = self.build_model(train_bn=self.train_bn)
    self.compile_sfun(self.model, inputs_mask, self._config['training']['lr'])

    self._config['dataset']['num_freq'] = self.num_freq
    config_path = os.path.join(session_dir, 'config.json')

    if os.path.exists(self.checkpoints_path) and util.dir_contains_files(
            self.checkpoints_path):
        # Resume from the most recently modified checkpoint file.
        checkpoints = os.listdir(self.checkpoints_path)
        checkpoints.sort(key=lambda x: os.stat(
            os.path.join(self.checkpoints_path, x)).st_mtime)
        last_checkpoint = checkpoints[-1]
        last_checkpoint_path = os.path.join(self.checkpoints_path,
                                            last_checkpoint)
        # The epoch number is the first digit run in the filename
        # (e.g. 'checkpoint.00012-...'). This replaces the fixed slice
        # [11:16], which broke for epochs of more than five digits.
        m = re.search(r'\d+', last_checkpoint)
        self.epoch_num = int(m.group(0)) if m else 0
        print('Loading Sound Field Network model from epoch: %d'
              % self.epoch_num)
        self.model.load_weights(last_checkpoint_path)
    else:
        print('Building new Sound Field Network model...')
        self.epoch_num = 0

    self.model.summary()

    # Persist the config once per session.
    if not os.path.exists(config_path):
        util.save_config(config_path, self._config)

    return self.model