def evaluate(
    self,
    eps_generator: typing.Union[OmniglotLoader, ImageFolderGenerator]
) -> typing.List[float]:
    """Evaluate the meta-learned model episode by episode.

    For every episode returned by ``get_episodes`` the model is adapted on the
    support split and the query predictions of all returned logits sets are
    averaged in probability space before taking the arg-max.

    Args:
        eps_generator: source of episode data (``generate_episode``).

    Returns:
        The query-set accuracy of each episode, in episode order.
    """
    print('Evaluation is started.\n')

    model = self.load_model(resume_epoch=self.config['resume_epoch'],
                            hyper_net_class=self.hyper_net_class,
                            eps_generator=eps_generator)

    # each entry names the classes that make up one episode
    episode_names = get_episodes(episode_file_path=self.config['episode_file'])

    accuracies = [None] * len(episode_names)
    for idx, episode_name in enumerate(episode_names):
        episode_data = eps_generator.generate_episode(episode_name=episode_name)

        # support (train) / query (validation) split
        xt, yt, xv, yv = train_val_split(X=episode_data,
                                         k_shot=self.config['k_shot'],
                                         shuffle=True)

        device = self.config['device']
        support_x = torch.from_numpy(xt).float().to(device)
        support_y = torch.tensor(yt, dtype=torch.long, device=device)
        query_x = torch.from_numpy(xv).float().to(device)
        query_y = torch.tensor(yv, dtype=torch.long, device=device)

        _, logits = self.adapt_and_predict(model=model,
                                           x_t=support_x,
                                           y_t=support_y,
                                           x_v=query_x,
                                           y_v=None)

        # average class probabilities over every returned logits set;
        # len(episode_data) is used as the number of classes in the episode
        probs = torch.zeros(size=(query_y.shape[0], len(episode_data)),
                            dtype=torch.float,
                            device=device)
        for member_logits in logits:
            probs += torch.softmax(input=member_logits, dim=1)
        probs /= len(logits)

        hits = probs.argmax(dim=1) == query_y
        accuracies[idx] = hits.float().mean().item()

        # move the cursor up one line so the progress counter overwrites itself
        sys.stdout.write('\033[F')
        print(idx + 1)

    acc_mean = np.mean(a=accuracies)
    acc_std = np.std(a=accuracies)
    # 95% confidence half-width of the mean accuracy
    half_width = 1.96 * acc_std / np.sqrt(len(accuracies))
    print('\nAccuracy = {0:.2f} +/- {1:.2f}\n'.format(acc_mean * 100,
                                                     half_width * 100))
    return accuracies
def evaluate() -> None:
    """Evaluate the meta-learned network on a sequence of episodes.

    Configuration comes from module-level globals (``resume_epoch``,
    ``num_episodes``, ``episode_file``, ``logdir``, ``eps_generator``,
    ``k_shot``, ``v_shot``, ``device``). For every episode the network is
    adapted on the support set and scored on the query set. Per-episode
    accuracies are written one per line to ``<logdir>/accuracy.txt``, and the
    mean accuracy with a 95% confidence half-width is printed.

    Pressing Ctrl-C stops the loop early; the accuracies collected so far are
    still summarised. Any other exception propagates (after the accuracy file
    has been closed) instead of being silently discarded.

    Raises:
        AssertionError: if ``resume_epoch`` is not positive.
        ValueError: if both ``num_episodes`` and ``episode_file`` are None.
    """
    assert resume_epoch > 0

    acc = []

    if (num_episodes is None) and (episode_file is None):
        raise ValueError('Expect exactly one of num_episodes and episode_file to be not None, receive both are None.')

    # load model
    net, _, _ = load_model(epoch_id=resume_epoch)
    net.eval()

    episodes = get_episodes(episode_file_path=episode_file)
    if None in episodes:
        # no usable episode names: evaluate `num_episodes` episodes whose
        # name is None instead
        episodes = [None] * num_episodes

    file_acc = os.path.join(logdir, 'accuracy.txt')
    with open(file=file_acc, mode='w') as f_acc:
        try:
            for i, episode_name in enumerate(episodes):
                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into support (train) and query (validation) data
                xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # adapt on the support data
                fnet = adapt_to_episode(x=x_t, y=y_t, net=net)

                # evaluate on the query data
                logits_v = fnet(x_v)
                # NOTE(review): the denominator assumes each of the len(X)
                # classes contributes exactly `v_shot` query samples
                episode_acc = (logits_v.argmax(dim=1) == y_v).sum().item() / (len(X) * v_shot)
                acc.append(episode_acc)
                f_acc.write('{}\n'.format(episode_acc))

                # overwrite the previous progress line
                sys.stdout.write('\033[F')
                print(i)
        except KeyboardInterrupt:
            # deliberate early stop: report the episodes evaluated so far
            pass

    if not acc:
        print('No episode was evaluated.')
        return

    mean = np.mean(a=acc)
    std = np.std(a=acc)
    n = len(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(mean, 1.96 * std / np.sqrt(n)))
def evaluate() -> None:
    """Evaluate the embedding network with distance-to-prototype classification.

    Configuration comes from module-level globals (``resume_epoch``,
    ``num_episodes``, ``episode_file``, ``logdir``, ``eps_generator``,
    ``k_shot``, ``v_shot``, ``device``). For every episode,
    ``adapt_to_episode`` builds prototypes from the support set. Each query
    embedding is then classified by the negated ``euclidean_distance`` to
    those prototypes.

    Per-episode accuracies are written one per line to
    ``<logdir>/accuracy.txt``, and the mean accuracy with a 95% confidence
    half-width is printed. Ctrl-C stops early and still reports partial
    results. Any other exception propagates after the accuracy file has been
    closed.

    Raises:
        AssertionError: if ``resume_epoch`` is not positive.
        ValueError: if both ``num_episodes`` and ``episode_file`` are None.
    """
    assert resume_epoch > 0

    acc = []

    # same precondition as the other evaluate routines: without an episode
    # file, `num_episodes` is needed to know how many episodes to draw
    if (num_episodes is None) and (episode_file is None):
        raise ValueError('Expect exactly one of num_episodes and episode_file to be not None, receive both are None.')

    # load model
    net, _, _ = load_model(epoch_id=resume_epoch)
    net.eval()

    episodes = get_episodes(episode_file_path=episode_file)
    if None in episodes:
        episodes = [None] * num_episodes

    file_acc = os.path.join(logdir, 'accuracy.txt')
    with open(file=file_acc, mode='w') as f_acc:
        try:
            # no gradients are needed to build prototypes or measure distances
            with torch.no_grad():
                for i, episode_ in enumerate(episodes):
                    X = eps_generator.generate_episode(episode_name=episode_)

                    # split into support (train) and query (validation) data
                    xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                    # move data to gpu
                    x_t = torch.from_numpy(xt).float().to(device)
                    y_t = torch.tensor(yt, dtype=torch.long, device=device)
                    x_v = torch.from_numpy(xv).float().to(device)
                    y_v = torch.tensor(yv, dtype=torch.long, device=device)

                    # adapt on the support data
                    z_prototypes = adapt_to_episode(x=x_t, y=y_t, net=net)

                    # evaluate on the query data: closer prototype -> larger logit
                    z_v = net.forward(x_v)
                    distance_matrix = euclidean_distance(matrixN=z_v, matrixM=z_prototypes)
                    logits_v = -distance_matrix

                    # NOTE(review): the denominator assumes each of the len(X)
                    # classes contributes exactly `v_shot` query samples
                    episode_acc = (logits_v.argmax(dim=1) == y_v).sum().item() / (len(X) * v_shot)
                    acc.append(episode_acc)
                    f_acc.write('{0}\n'.format(episode_acc))

                    # overwrite the previous progress line
                    sys.stdout.write('\033[F')
                    print(i)
        except KeyboardInterrupt:
            # deliberate early stop: report the episodes evaluated so far
            pass

    if not acc:
        print('No episode was evaluated.')
        return

    mean = np.mean(a=acc)
    std = np.std(a=acc)
    n = len(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(mean, 1.96 * std / np.sqrt(n)))
def train() -> None:
    """Meta-train the network, checkpointing after every epoch.

    Training hyper-parameters are read from the module-level ``args``
    namespace. Other settings (``resume_epoch``, ``episode_file``, ``logdir``,
    ``eps_generator``, ``k_shot``, ``device``) are module-level globals.

    Gradients are accumulated over ``args.minibatch`` episodes per
    meta-update. The mean query loss is logged to TensorBoard every
    ``lcm(minibatch, 100)`` episodes, and ``Epoch_<n>.pt`` is saved to
    ``logdir`` after each epoch.

    Ctrl-C stops training early; the summary writer is always closed.
    Model/writer setup happens before the guarded section, so a setup
    failure surfaces as itself instead of as an error about the undefined
    writer.
    """
    # parse training parameters
    meta_lr = args.meta_lr
    minibatch = args.minibatch
    minibatch_print = np.lcm(minibatch, 100)
    decay_lr = args.decay_lr
    num_episodes_per_epoch = args.num_episodes_per_epoch
    num_epochs = args.num_epochs

    # initialize/load model
    net, meta_optimizer, schdlr = load_model(epoch_id=resume_epoch, meta_lr=meta_lr, decay_lr=decay_lr)

    # zero grad
    meta_optimizer.zero_grad()

    # episode names to sample from; each name identifies the classes of one episode
    episodes = get_episodes(episode_file_path=episode_file)

    # initialize a tensorboard summary writer for logging; when resuming,
    # drop any events logged after the resume point
    tb_writer = SummaryWriter(
        log_dir=logdir,
        purge_step=resume_epoch * num_episodes_per_epoch // minibatch_print if resume_epoch > 0 else None
    )

    try:
        for epoch_id in range(resume_epoch, resume_epoch + num_epochs, 1):
            episode_count = 0
            loss_monitor = 0

            while episode_count < num_episodes_per_epoch:
                # get episode from the given csv file, or just return None
                episode_name = random.sample(population=episodes, k=1)[0]

                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into support (train) and query (validation) data
                xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # adapt on the support data
                fnet = adapt_to_episode(x=x_t, y=y_t, net=net)

                # evaluate on the query data
                logits_v = fnet.forward(x_v)
                cls_loss = torch.nn.functional.cross_entropy(input=logits_v, target=y_v)
                loss_monitor += cls_loss.item()

                # scale so the accumulated gradient is the minibatch average
                cls_loss = cls_loss / minibatch
                cls_loss.backward()

                episode_count += 1

                # update the meta-model
                if episode_count % minibatch == 0:
                    meta_optimizer.step()
                    meta_optimizer.zero_grad()

                # monitor losses
                if episode_count % minibatch_print == 0:
                    loss_monitor /= minibatch_print
                    global_step = (epoch_id * num_episodes_per_epoch + episode_count) // minibatch_print
                    tb_writer.add_scalar(
                        tag='Loss',
                        scalar_value=loss_monitor,
                        global_step=global_step
                    )
                    loss_monitor = 0

            # decay learning rate
            schdlr.step()

            # save model
            checkpoint = {
                'net_state_dict': net.state_dict(),
                'op_state_dict': meta_optimizer.state_dict(),
                'lr_schdlr_state_dict': schdlr.state_dict()
            }
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch_id + 1)
            torch.save(checkpoint, os.path.join(logdir, checkpoint_filename))
            print('SAVING parameters into {0:s}\n'.format(checkpoint_filename))
    except KeyboardInterrupt:
        # deliberate early stop; checkpoints of finished epochs are kept
        pass
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()
def evaluate(hyper_net_cls, get_f_base_net_fn: typing.Callable, adapt_to_episode: typing.Callable, get_accuracy_fn: typing.Callable) -> None:
    """Evaluate a hyper-network based meta-learner on a sequence of episodes.

    Per-episode accuracies are written one per line to
    ``<logdir>/accuracy.txt``, and the mean accuracy with a 95% confidence
    half-width is printed.

    Args:
        hyper_net_cls: hyper-network class forwarded to ``load_model``.
        get_f_base_net_fn: builds the functional base network from ``base_net``.
        adapt_to_episode: adapts the hyper-network on the support data.
        get_accuracy_fn: returns the query-set accuracy of the adapted model.

    Returns:
        None
    """
    acc = []

    # initialize/load model
    hyper_net, base_net, _, _ = load_model(hyper_net_cls=hyper_net_cls, epoch_id=config['resume_epoch'], meta_lr=config['meta_lr'], decay_lr=config['decay_lr'])
    hyper_net.eval()
    base_net.eval()

    # get list of episode names, each episode name consists of classes
    episodes = get_episodes(episode_file_path=config['episode_file'], num_episodes=config['num_episodes'])

    # `with` guarantees the file is closed even if an episode raises, and
    # avoids referencing an unopened handle if `open` itself fails
    with open(file=os.path.join(logdir, 'accuracy.txt'), mode='w') as acc_file:
        for i, episode_name in enumerate(episodes):
            X = eps_generator.generate_episode(episode_name=episode_name)

            # split into support (train) and query (validation) data
            xt, yt, xv, yv = train_val_split(X=X, k_shot=config['k_shot'], shuffle=True)

            # move data to gpu
            x_t = torch.from_numpy(xt).float().to(device)
            y_t = torch.tensor(yt, dtype=torch.long, device=device)
            x_v = torch.from_numpy(xv).float().to(device)
            y_v = torch.tensor(yv, dtype=torch.long, device=device)

            # -------------------------
            # functional base network
            # -------------------------
            f_base_net = get_f_base_net_fn(base_net=base_net)

            # -------------------------
            # adapt on the support data
            # -------------------------
            f_hyper_net = adapt_to_episode(x=x_t, y=y_t, hyper_net=hyper_net, f_base_net=f_base_net)

            # -------------------------
            # accuracy
            # -------------------------
            acc_temp = get_accuracy_fn(x=x_v, y=y_v, f_hyper_net=f_hyper_net, f_base_net=f_base_net)

            acc.append(acc_temp)
            acc_file.write('{}\n'.format(acc_temp))

            # overwrite the previous progress line
            sys.stdout.write('\033[F')
            print(i)

    if not acc:
        print('No episode was evaluated.')
        return None

    acc_mean = np.mean(acc)
    acc_std = np.std(acc)
    # 95% confidence half-width over the episodes actually evaluated
    print('Accuracy = {} +/- {}'.format(
        acc_mean, 1.96 * acc_std / np.sqrt(len(acc))))

    return None
def train(hyper_net_cls, get_f_base_net_fn: typing.Callable, adapt_to_episode: typing.Callable, loss_on_query_fn: typing.Callable) -> None:
    """Meta-train a hyper-network based learner, checkpointing every epoch.

    Gradients are accumulated over ``config['minibatch']`` episodes per
    meta-update. The mean (unscaled) query loss is logged to TensorBoard
    every ``minibatch_print`` episodes, and ``Epoch_<n>.pt`` is saved to
    ``logdir`` after each epoch.

    Args:
        hyper_net_cls: hyper-network class forwarded to ``load_model``.
        get_f_base_net_fn: builds the functional base network from ``base_net``.
        adapt_to_episode: adapts the hyper-network on the support data.
        loss_on_query_fn: returns the meta-loss of the adapted model on the query data.

    Raises:
        ValueError: if the query loss becomes NaN.
    """
    # initialize/load model
    hyper_net, base_net, meta_opt, schdlr = load_model(
        hyper_net_cls=hyper_net_cls,
        epoch_id=config['resume_epoch'],
        meta_lr=config['meta_lr'],
        decay_lr=config['decay_lr'])

    # zero grad
    meta_opt.zero_grad()

    # get list of episode names, each episode name consists of classes
    episodes = get_episodes(episode_file_path=config['episode_file'])

    # initialize a tensorboard summary writer for logging; when resuming,
    # drop any events logged after the resume point
    tb_writer = SummaryWriter(
        log_dir=logdir,
        purge_step=config['resume_epoch'] * config['num_episodes_per_epoch'] // minibatch_print if config['resume_epoch'] > 0 else None)

    try:
        for epoch_id in range(config['resume_epoch'], config['resume_epoch'] + config['num_epochs'], 1):
            episode_count = 0
            loss_monitor = 0

            while episode_count < config['num_episodes_per_epoch']:
                # get episode from the given csv file, or just return None
                episode_name = random.sample(population=episodes, k=1)[0]

                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into support (train) and query (validation) data
                xt, yt, xv, yv = train_val_split(X=X, k_shot=config['k_shot'], shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # -------------------------
                # functional base network
                # -------------------------
                f_base_net = get_f_base_net_fn(base_net=base_net)

                # -------------------------
                # adapt on the support data
                # -------------------------
                f_hyper_net = adapt_to_episode(x=x_t, y=y_t, hyper_net=hyper_net, f_base_net=f_base_net)

                # -------------------------
                # loss on query data
                # -------------------------
                loss_meta = loss_on_query_fn(x=x_v, y=y_v, f_hyper_net=f_hyper_net, f_base_net=f_base_net, hyper_net=hyper_net)

                if torch.isnan(loss_meta):
                    raise ValueError('Validation loss is NaN.')

                # monitor the unscaled query loss; previously the value was
                # recorded after the 1/minibatch scaling, so the logged curve
                # was off by a factor of `minibatch`
                loss_monitor += loss_meta.item()

                # scale so the accumulated gradient is the minibatch average
                loss_meta = loss_meta / config['minibatch']
                loss_meta.backward()

                episode_count += 1

                # update the meta-model
                if episode_count % config['minibatch'] == 0:
                    meta_opt.step()
                    meta_opt.zero_grad()

                # monitor losses
                if episode_count % minibatch_print == 0:
                    loss_monitor /= minibatch_print
                    global_step = (epoch_id * config['num_episodes_per_epoch'] + episode_count) // minibatch_print
                    tb_writer.add_scalar(tag='Loss', scalar_value=loss_monitor, global_step=global_step)
                    loss_monitor = 0

            # decay learning rate
            schdlr.step()

            # save model
            checkpoint = {
                'hyper_net_state_dict': hyper_net.state_dict(),
                'op_state_dict': meta_opt.state_dict(),
                'lr_schdlr_state_dict': schdlr.state_dict()
            }
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch_id + 1)
            torch.save(checkpoint, os.path.join(logdir, checkpoint_filename))
            print('SAVING parameters into {0:s}\n'.format(checkpoint_filename))
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()

    return None
def train(
        self,
        eps_generator: typing.Union[OmniglotLoader, ImageFolderGenerator]) -> None:
    """Meta-train the model, saving a checkpoint after every epoch.

    Gradients are accumulated over ``minibatch`` episodes before each
    meta-update (which also back-propagates ``loss_prior`` when it is a
    tensor). The classification loss and the KL divergence are logged to
    TensorBoard every ``minibatch_print`` episodes.

    Args:
        eps_generator: the generator that produces episodes/tasks
    """
    print('Training is started.\nLog is stored at {0:s}.\n'.format(
        self.config['logdir']))

    # model layout is defined by the subclass' load_model; its last entry is
    # the object stepped and zeroed as the meta-optimizer below
    model = self.load_model(resume_epoch=self.config['resume_epoch'],
                            hyper_net_class=self.hyper_net_class,
                            eps_generator=eps_generator)
    model[-1].zero_grad()

    # each entry names the classes that make up one episode
    episode_names = get_episodes(episode_file_path=self.config['episode_file'])

    # when resuming, let TensorBoard discard events logged past the resume point
    if self.config['resume_epoch'] > 0:
        purge_step = (self.config['resume_epoch']
                      * self.config['num_episodes_per_epoch']
                      // self.config['minibatch_print'])
    else:
        purge_step = None
    tb_writer = SummaryWriter(log_dir=self.config['logdir'], purge_step=purge_step)

    def to_device(xt, yt, xv, yv):
        """Turn a numpy support/query split into tensors on the configured device."""
        return (torch.from_numpy(xt).float().to(self.config['device']),
                torch.tensor(yt, dtype=torch.long, device=self.config['device']),
                torch.from_numpy(xv).float().to(self.config['device']),
                torch.tensor(yv, dtype=torch.long, device=self.config['device']))

    try:
        first_epoch = self.config['resume_epoch']
        last_epoch = first_epoch + self.config['num_epochs']
        for epoch_id in range(first_epoch, last_epoch):
            cls_loss_running = 0.
            kl_running = 0.

            # `step` counts episodes from 1 within the epoch
            for step in range(1, self.config['num_episodes_per_epoch'] + 1):
                # episode from the csv file, or a random one (name None)
                episode_name = random.sample(population=episode_names, k=1)[0]
                episode_data = eps_generator.generate_episode(episode_name=episode_name)

                # support (train) / query (validation) split
                xt, yt, xv, yv = train_val_split(X=episode_data,
                                                 k_shot=self.config['k_shot'],
                                                 shuffle=True)
                x_t, y_t, x_v, y_v = to_device(xt, yt, xv, yv)

                f_hyper_net, logits = self.adapt_and_predict(model=model,
                                                             x_t=x_t,
                                                             y_t=y_t,
                                                             x_v=x_v,
                                                             y_v=y_v)

                # query cross-entropy averaged over every returned logits set
                cls_loss = 0.
                for member_logits in logits:
                    cls_loss = cls_loss + torch.nn.functional.cross_entropy(
                        input=member_logits, target=y_v)
                cls_loss = cls_loss / len(logits)
                cls_loss_running += cls_loss.item()

                KL_div = self.KL_divergence(model=model, f_hyper_net=f_hyper_net)
                if isinstance(KL_div, torch.Tensor):
                    kl_running += KL_div.item()
                else:
                    kl_running += KL_div

                # extra loss (only non-trivial for ABML)
                loss_extra = self.loss_extra(model=model,
                                             f_hyper_net=f_hyper_net,
                                             x_t=x_t,
                                             y_t=y_t)

                total_loss = cls_loss + loss_extra + self.config['KL_weight'] * KL_div
                total_loss = total_loss / self.config['minibatch']
                # accumulate gradients w.r.t. the hyper-network parameters
                total_loss.backward()

                # meta-update once per minibatch of episodes
                if step % self.config['minibatch'] == 0:
                    loss_prior = self.loss_prior(model=model)
                    if hasattr(loss_prior, 'requires_grad'):
                        loss_prior.backward()
                    model[-1].step()
                    model[-1].zero_grad()

                # periodic TensorBoard logging
                if step % self.config['minibatch_print'] == 0:
                    cls_loss_running /= self.config['minibatch_print']
                    kl_running = kl_running * self.config['minibatch'] / self.config['minibatch_print']

                    global_step = (epoch_id * self.config['num_episodes_per_epoch']
                                   + step) // self.config['minibatch_print']

                    tb_writer.add_scalar(tag='Cls loss',
                                         scalar_value=cls_loss_running,
                                         global_step=global_step)
                    tb_writer.add_scalar(tag='KL divergence',
                                         scalar_value=kl_running,
                                         global_step=global_step)

                    cls_loss_running = 0.
                    kl_running = 0.

            # checkpoint after every epoch
            checkpoint = {
                'hyper_net_state_dict': model[0].state_dict(),
                'opt_state_dict': model[-1].state_dict()
            }
            checkpoint_path = os.path.join(self.config['logdir'],
                                           'Epoch_{0:d}.pt'.format(epoch_id + 1))
            torch.save(obj=checkpoint, f=checkpoint_path)
            print('State dictionaries are saved into {0:s}\n'.format(checkpoint_path))

        print('Training is completed.')
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()

    return None
def evaluate(
        self,
        eps_generator: typing.Union[OmniglotLoader, ImageFolderGenerator]) -> None:
    """Evaluate accuracy, ECE, AUROC and an NLL-style score over test episodes.

    Episodes come from ``get_episodes`` (``num_episodes=1000``). For each
    episode the model is adapted on the support set, and the query class
    probabilities are averaged over all returned logits sets. The aggregated
    metrics are appended to ``{extra}_{acc,ece,auroc,nll}_{run}.txt`` in
    ``self.config['logdir']``, where ``extra`` is ``corrupt``, ``oodtest`` or
    ``plain`` depending on the config flags. A summary is then printed.

    Args:
        eps_generator: produces ``(x_t, y_t, x_v, y_v)`` tensors per episode.
    """
    print('Evaluation is started.\n')

    # load model
    model = self.load_model(resume_epoch=self.config['resume_epoch'],
                            hyper_net_class=self.hyper_net_class,
                            eps_generator=eps_generator)

    # get list of episode names, each episode name consists of classes
    eps = get_episodes(episode_file_path=self.config['episode_file'],
                       num_episodes=1000)

    # Accumulate plain Python floats. Previously `correct` and `ll` silently
    # became torch tensors, so the metric files received reprs such as
    # "tensor(0.61, device='cuda:0')" instead of numbers.
    ece, roc_auc, ll = 0., 0., 0.
    correct, total = 0., 0.
    num_eps = 0
    for i, eps_name in enumerate(eps):
        x_t, y_t, x_v, y_v = eps_generator.generate_episode(episode_name=eps_name)
        x_t, y_t, x_v, y_v = x_t.cuda(), y_t.cuda(), x_v.cuda(), y_v.cuda()

        _, logits = self.adapt_and_predict(model=model,
                                           x_t=x_t,
                                           y_t=y_t,
                                           x_v=x_v,
                                           y_v=None)

        # average class probabilities over every returned logits set
        y_pred = torch.zeros((y_v.shape[0], self.config["n_way"]), device=x_v.device)
        for logits_ in logits:
            y_pred += torch.softmax(input=logits_, dim=1)
        y_pred /= len(logits)
        # only metrics are computed from here on; drop the autograd graph
        y_pred = y_pred.detach()

        correct += (y_pred.argmax(dim=1) == y_v).sum().item()
        total += y_v.numel()

        e, _, _ = ece_yhat_only(100, y_v, y_pred, device=y_v.device)
        ece += float(e)

        yhot = OneHotEncoder(categories='auto').fit_transform(
            y_v.cpu().numpy().reshape(-1, 1)).toarray()
        roc_auc += roc_auc_score(yhot, y_pred.cpu().numpy(), multi_class="ovr")

        # running sum of the probability assigned to the true class
        ll += y_pred[torch.arange(y_v.size(0)), y_v].cpu().sum().item()

        num_eps = i + 1
        # overwrite the previous progress line
        sys.stdout.write('\033[F')
        print(num_eps)

    if num_eps == 0:
        print('No episode was evaluated.')
        return None

    acc = correct / total
    ece = ece / num_eps
    roc_auc = roc_auc / num_eps
    # NOTE(review): this is -log(mean true-class probability), not the mean
    # per-sample negative log-likelihood -mean(log p). It is kept unchanged
    # so results stay comparable with previously written files.
    nll = -np.log(ll) + np.log(total)

    # file-name tag describing the evaluation setting (loop-invariant)
    if self.config["corrupt"]:
        extra = "corrupt"
    elif self.config["ood_test"]:
        extra = "oodtest"
    else:
        extra = "plain"

    files = [
        "{}_acc_{}.txt", "{}_ece_{}.txt", "{}_auroc_{}.txt", "{}_nll_{}.txt"
    ]
    metrics = [acc, ece, roc_auc, nll]
    for fl, mtrc in zip(files, metrics):
        path = os.path.join(self.config['logdir'],
                            fl.format(extra, self.config['run']))
        with open(path, "a+") as f:
            f.write("{}\n".format(mtrc))

    print(
        '\nAccuracy: {0:.2f}\nECE: {1:.2f}\nAUROC: {2:.2f}\nNLL: {3:.2f}'.
        format(acc * 100, ece, roc_auc, nll))
def train(
        self,
        eps_generator: typing.Union[OmniglotLoader, ImageFolderGenerator]) -> None:
    """Meta-train the model and checkpoint it once per epoch.

    Gradients are accumulated over ``minibatch`` episodes per meta-update.
    The classification loss and KL divergence are logged to TensorBoard
    every ``minibatch_print`` episodes. The running query accuracy of the
    logit-averaged prediction is printed after each epoch, and the latest
    state is saved to ``run_{run}.pt`` in ``self.config['logdir']``,
    overwriting the previous epoch's file. ``self.config['iters']`` is
    incremented once per episode.

    Args:
        eps_generator: produces ``(x_t, y_t, x_v, y_v)`` tensors per episode
    """
    print('Training is started.\nLog is stored at {0:s}.\n'.format(
        self.config['logdir']))

    # model layout is defined by the subclass' load_model; its last entry is
    # the object stepped and zeroed as the meta-optimizer below
    model = self.load_model(resume_epoch=self.config['resume_epoch'],
                            hyper_net_class=self.hyper_net_class,
                            eps_generator=eps_generator)
    model[-1].zero_grad()

    # each entry names the classes that make up one episode
    episode_names = get_episodes(episode_file_path=self.config['episode_file'])

    # when resuming, let TensorBoard discard events logged past the resume point
    if self.config['resume_epoch'] > 0:
        purge_step = (self.config['resume_epoch']
                      * self.config['num_episodes_per_epoch']
                      // self.config['minibatch_print'])
    else:
        purge_step = None
    tb_writer = SummaryWriter(log_dir=self.config['logdir'], purge_step=purge_step)

    try:
        first_epoch = self.config['resume_epoch']
        last_epoch = first_epoch + self.config['num_epochs']
        for epoch_id in range(first_epoch, last_epoch):
            print(f"Starting epoch: {epoch_id}")
            cls_loss_running = 0.
            kl_running = 0.
            n_correct, n_seen = 0., 0.

            # `step` counts episodes from 1 within the epoch
            for step in range(1, self.config['num_episodes_per_epoch'] + 1):
                # episode from the csv file, or a random one (name None)
                episode_name = random.sample(population=episode_names, k=1)[0]

                x_t, y_t, x_v, y_v = eps_generator.generate_episode(
                    episode_name=episode_name)
                x_t, y_t, x_v, y_v = (x_t.cuda(), y_t.cuda(),
                                      x_v.cuda(), y_v.cuda())

                f_hyper_net, logits = self.adapt_and_predict(model=model,
                                                             x_t=x_t,
                                                             y_t=y_t,
                                                             x_v=x_v,
                                                             y_v=y_v)

                # query cross-entropy averaged over every returned logits set
                cls_loss = 0.
                for member_logits in logits:
                    cls_loss = cls_loss + torch.nn.functional.cross_entropy(
                        input=member_logits, target=y_v)
                cls_loss = cls_loss / len(logits)
                cls_loss_running += cls_loss.item()

                # running accuracy of the logit-averaged prediction
                ensemble_pred = torch.stack(logits).mean(dim=0).argmax(dim=-1)
                n_correct += (ensemble_pred == y_v).sum().item()
                n_seen += y_v.numel()

                KL_div = self.KL_divergence(model=model, f_hyper_net=f_hyper_net)
                if isinstance(KL_div, torch.Tensor):
                    kl_running += KL_div.item()
                else:
                    kl_running += KL_div

                # extra loss (only non-trivial for ABML)
                loss_extra = self.loss_extra(model=model,
                                             f_hyper_net=f_hyper_net,
                                             x_t=x_t,
                                             y_t=y_t)

                total_loss = cls_loss + loss_extra + self.config['KL_weight'] * KL_div
                total_loss = total_loss / self.config['minibatch']
                # accumulate gradients w.r.t. the hyper-network parameters
                total_loss.backward()

                self.config['iters'] += 1

                # meta-update once per minibatch of episodes
                if step % self.config['minibatch'] == 0:
                    loss_prior = self.loss_prior(model=model)
                    if hasattr(loss_prior, 'requires_grad'):
                        loss_prior.backward()
                    model[-1].step()
                    model[-1].zero_grad()

                # periodic TensorBoard logging
                if step % self.config['minibatch_print'] == 0:
                    cls_loss_running /= self.config['minibatch_print']
                    kl_running = kl_running * self.config['minibatch'] / self.config['minibatch_print']

                    global_step = (epoch_id * self.config['num_episodes_per_epoch']
                                   + step) // self.config['minibatch_print']

                    tb_writer.add_scalar(tag='Cls loss',
                                         scalar_value=cls_loss_running,
                                         global_step=global_step)
                    tb_writer.add_scalar(tag='KL divergence',
                                         scalar_value=kl_running,
                                         global_step=global_step)

                    cls_loss_running = 0.
                    kl_running = 0.

            print(f"epoch {epoch_id} acc: {n_correct / n_seen}")

            # checkpoint after every epoch (one file per run, overwritten)
            checkpoint = {
                'hyper_net_state_dict': model[0].state_dict(),
                'opt_state_dict': model[-1].state_dict(),
                'epoch': epoch_id,
                'iters': self.config['iters']
            }
            checkpoint_path = os.path.join(self.config['logdir'],
                                           'run_{}.pt'.format(self.config['run']))
            torch.save(obj=checkpoint, f=checkpoint_path)
            print('State dictionaries are saved into {0:s}\n'.format(checkpoint_path))

        print('Training is completed.')
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()

    return None