def run(self, model_path, model_indice, log_file, log_file_link):
    """Evaluate one or more checkpoints and append the results to a log file.

    ``model_indice`` selects the checkpoints; there are four modes:

    1. ``*.pth``                 -- evaluate exactly that checkpoint path.
    2. ``epoch``                 -- evaluate ``<model_path>/epoch-<epoch>.pth``.
    3. ``start_epoch-end_epoch`` -- evaluate every ``epoch-<n>.pth`` in
       ``model_path`` with ``start_epoch <= n <= end_epoch``.
    4. ``start_epoch-``          -- evaluate every ``epoch-<n>.pth`` in
       ``model_path`` with ``n >= start_epoch``.

    In modes 3 and 4 only files named ``epoch-<integer>.pth`` are considered
    (so ``epoch-last.pth`` and unrelated files are skipped), and they are
    evaluated in ascending epoch order.

    For every selected checkpoint, ``self.val_func`` is set to the result of
    ``load_model(self.network, path)`` and ``self.multi_process_evaluation()``
    is called; its returned line is appended to ``log_file`` together with
    the checkpoint path.

    Args:
        model_path: directory holding the ``epoch-<n>.pth`` checkpoints.
        model_indice: checkpoint selector, see the modes above.
        log_file: file the results are appended to.
        log_file_link: link path passed to ``link_file`` for ``log_file``.

    Raises:
        ValueError: in range mode, if the start is not an integer, or if an
            end is given that is not greater than the start.
    """
    if '.pth' in model_indice:
        models = [model_indice]
    elif "-" in model_indice:
        # "-" is present, so split() always yields at least two parts.
        start_part, end_part = model_indice.split("-")[:2]
        start_epoch = int(start_part)
        # An empty end ("10-") means "no upper bound" (mode 4).
        end_epoch = int(end_part) if end_part else None
        if end_epoch is not None and start_epoch >= end_epoch:
            raise ValueError(
                "start_epoch must be smaller than end_epoch, got %r"
                % model_indice)

        candidates = []
        for name in os.listdir(model_path):
            stem, ext = os.path.splitext(name)
            prefix, _, num = stem.partition("-")
            # Skip anything that is not a numbered checkpoint, e.g. the
            # epoch-last.pth link or other files in the directory.
            if ext != ".pth" or prefix != "epoch" or not num.isdecimal():
                continue
            epoch = int(num)
            if epoch < start_epoch:
                continue
            if end_epoch is not None and epoch > end_epoch:
                continue
            candidates.append((epoch, os.path.join(model_path, name)))
        # os.listdir order is arbitrary; evaluate in ascending epoch order.
        models = [path for _, path in sorted(candidates)]
    else:
        models = [os.path.join(model_path, 'epoch-%s.pth' % model_indice)]

    # `with` guarantees the log is closed even if loading/evaluation raises.
    with open(log_file, 'a') as results:
        link_file(log_file, log_file_link)
        for model in models:
            logger.info("Load Model: %s", model)
            self.val_func = load_model(self.network, model)
            result_line = self.multi_process_evaluation()
            results.write('Model: ' + model + '\n')
            results.write(result_line)
            results.write('\n')
            # Flush after each model so partial results survive a crash.
            results.flush()
def save_and_link_checkpoint(self, snapshot_dir, log_dir, log_dir_link):
    """Save a checkpoint for the current epoch and refresh the 'last' link.

    Steps, in order:
      1. ``ensure_dir(snapshot_dir)``.
      2. If nothing exists at ``log_dir_link`` (per ``osp.exists``), link
         ``log_dir`` to it with ``link_file``.
      3. Save ``<snapshot_dir>/epoch-<self.state.epoch>.pth`` through
         ``self.save_checkpoint``.
      4. Link that file to ``<snapshot_dir>/epoch-last.pth`` with
         ``link_file``.
    """
    ensure_dir(snapshot_dir)

    log_link_missing = not osp.exists(log_dir_link)
    if log_link_missing:
        link_file(log_dir, log_dir_link)

    epoch_name = 'epoch-{}.pth'.format(self.state.epoch)
    checkpoint_file = osp.join(snapshot_dir, epoch_name)
    self.save_checkpoint(checkpoint_file)

    latest_alias = osp.join(snapshot_dir, 'epoch-last.pth')
    link_file(checkpoint_file, latest_alias)
def run(self, model_path, model_indice, log_file, log_file_link):
    """Evaluate a checkpoint and append its result to a log file.

    ``self.val_func`` is set to ``load_model(self.network, path)`` and
    ``self.multi_process_evaluation()`` is called; its returned line is
    appended to ``log_file`` together with the checkpoint path.

    Args:
        model_path: directory holding ``epoch-<n>.pth`` checkpoints.
        model_indice: either a string containing ``.pth`` (used as the
            checkpoint path as-is) or an epoch identifier, resolved to
            ``<model_path>/epoch-<model_indice>.pth``.
        log_file: file the results are appended to.
        log_file_link: link path passed to ``link_file`` for ``log_file``.
    """
    if '.pth' in model_indice:
        models = [model_indice]
    else:
        models = [os.path.join(model_path, 'epoch-%s.pth' % model_indice)]

    # `with` guarantees the log is closed even if loading/evaluation raises.
    with open(log_file, 'a') as results:
        link_file(log_file, log_file_link)
        for model in models:
            logger.info("Load Model: %s", model)
            self.val_func = load_model(self.network, model)
            result_line = self.multi_process_evaluation()
            results.write('Model: ' + model + '\n')
            results.write(result_line)
            results.write('\n')
            # Flush after each model so partial results survive a crash.
            results.flush()
def link_tb(self, source, target):
    """Run ``ensure_dir`` on ``source`` then ``target``, then link them.

    ``link_file(source, target)`` is called only after both directories
    have been passed through ``ensure_dir``.
    """
    # NOTE(review): target is passed to ensure_dir before being used as the
    # link destination; confirm link_file handles a pre-existing target
    # directory as intended.
    for directory in (source, target):
        ensure_dir(directory)
    link_file(source, target)