def __init__(self, dataset, weights_manager, objective,
             rollout_type="compare",
             guide_evaluator_type="mepa", guide_evaluator_cfg=None,
             arch_network_type="pointwise_comparator", arch_network_cfg=None,
             guide_batch_size=16,
             schedule_cfg=None):
    super(ArchNetworkEvaluator, self).__init__(
        dataset, weights_manager, objective, rollout_type, schedule_cfg)

    # construct the evaluator that will be used to guide the learning of the predictor
    ge_cls = BaseEvaluator.get_class_(guide_evaluator_type)
    self.guide_evaluator = ge_cls(dataset, weights_manager, objective,
                                  rollout_type=rollout_type,
                                  **(guide_evaluator_cfg or {}))

    # construct the architecture network
    an_cls = ArchNetwork.get_class_(arch_network_type)
    self.arch_network = an_cls(
        search_space=self.weights_manager.search_space,
        **(arch_network_cfg or {}))

    # configurations
    self.guide_batch_size = guide_batch_size
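# Illustrative usage sketch (not from the original source): constructing an
# ArchNetworkEvaluator. The component objects (`dataset`, `weights_manager`,
# `objective`) and the guide evaluator type are assumed to have been built
# elsewhere by the surrounding framework; only the keyword arguments mirror the
# signature above.
#
# evaluator = ArchNetworkEvaluator(
#     dataset, weights_manager, objective,
#     rollout_type="compare",
#     guide_evaluator_type="mepa",                 # evaluator whose results supervise the predictor
#     arch_network_type="pointwise_comparator",    # forwarded to ArchNetwork.get_class_
#     arch_network_cfg={},                         # forwarded to the arch network constructor
#     guide_batch_size=16)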
def __init__(self, search_space,
             rollout_type="compare",
             arch_network_type="pointwise_comparator", arch_network_cfg=None,
             schedule_cfg=None):
    super(BatchUpdateArchNetworkEvaluator, self).__init__(
        dataset=None, weights_manager=None, objective=None,
        rollout_type=rollout_type, schedule_cfg=schedule_cfg)

    # construct the architecture network
    an_cls = ArchNetwork.get_class_(arch_network_type)
    self.arch_network = an_cls(
        # NOTE: this evaluator is constructed without a weights manager
        # (`weights_manager=None` above), so the search space passed in
        # directly is used here instead of `self.weights_manager.search_space`
        search_space=search_space,
        **(arch_network_cfg or {}))
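# Hedged example (assumed usage): BatchUpdateArchNetworkEvaluator only needs a
# search space, so it can be built without a dataset or weights manager. The
# "cnn" search-space type and its config keys are taken from the ss_cfg_str
# used by train_cellss_pkl.py below; the evaluator's registry name is not shown
# in this excerpt, so the class is called directly.
#
# search_space = get_search_space("cnn", **ss_cfg["search_space_cfg"])
# evaluator = BatchUpdateArchNetworkEvaluator(
#     search_space,
#     rollout_type="compare",
#     arch_network_type="pointwise_comparator",
#     arch_network_cfg=None)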
def main(argv):
    parser = argparse.ArgumentParser(prog="train_cellss_pkl.py")
    parser.add_argument("cfg_file")
    parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
    parser.add_argument("--num-workers", default=4, type=int)
    parser.add_argument("--report-freq", default=200, type=int)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--train-dir", default=None, help="Save train log/results into TRAIN_DIR")
    parser.add_argument("--save-every", default=None, type=int)
    parser.add_argument("--test-only", default=False, action="store_true")
    parser.add_argument("--test-funcs", default=None, help="comma-separated list of test funcs")
    parser.add_argument("--load", default=None, help="Load comparator from disk.")
    parser.add_argument("--sample", default=None, type=int)
    parser.add_argument("--sample-batchify-inner-sample-n", default=None, type=int)
    parser.add_argument("--sample-to-file", default=None, type=str)
    parser.add_argument("--sample-from-file", default=None, type=str)
    parser.add_argument("--sample-conflict-file", default=None, type=str, action="append")
    parser.add_argument("--sample-ratio", default=10, type=float)
    parser.add_argument("--sample-output-dir", default="./sample_output/")
    # parser.add_argument("--data-fname", default="cellss_data.pkl")
    # parser.add_argument("--data-fname", default="cellss_data_round1_999.pkl")
    parser.add_argument("--data-fname", default="enas_data_round1_980.pkl")
    parser.add_argument("--addi-train", default=[], action="append", help="additional train data")
    parser.add_argument("--addi-train-only", action="store_true", default=False)
    parser.add_argument("--addi-valid", default=[], action="append", help="additional valid data")
    parser.add_argument("--addi-valid-only", action="store_true", default=False)
    parser.add_argument("--valid-true-split", default=None)
    parser.add_argument("--valid-score-split", default=None)
    # NOTE: with `default=True`, this store_true flag is effectively always True
    parser.add_argument("--enas-ss", default=True, action="store_true")
    args = parser.parse_args(argv)

    setproctitle.setproctitle("python train_cellss_pkl.py config: {}; train_dir: {}; cwd: {}"\
                              .format(args.cfg_file, args.train_dir, os.getcwd()))

    # log
    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt="%m/%d %I:%M:%S %p")

    if not args.test_only:
        assert args.train_dir is not None, "Must specify `--train-dir` when training"
        # if training, set up the log file and back up the config file
        if not os.path.exists(args.train_dir):
            os.makedirs(args.train_dir)
        log_file = os.path.join(args.train_dir, "train.log")
        logging.getLogger().addFile(log_file)
        # copy config file
        backup_cfg_file = os.path.join(args.train_dir, "config.yaml")
        shutil.copyfile(args.cfg_file, backup_cfg_file)
    else:
        backup_cfg_file = args.cfg_file

    # cuda
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        logging.info("GPU device = %d" % args.gpu)
    else:
        logging.info("no GPU available, use CPU!!")

    if args.seed is not None:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_fname = args.data_fname
    logging.info("Load pkl cache from %s", data_fname)
    with open(data_fname, "rb") as rf:
        data = pickle.load(rf)
    with open(backup_cfg_file, "r") as cfg_f:
        cfg = yaml.load(cfg_f)

    logging.info("Config: %s", cfg)

    arch_network_type = cfg.get("arch_network_type", "pointwise_comparator")
    model_cls = ArchNetwork.get_class_(arch_network_type)

    # search space
    if args.enas_ss:
        ss_cfg_str = """
search_space_type: cnn
search_space_cfg:
  cell_layout: null
  num_cell_groups: 2
  num_init_nodes: 2
  num_layers: 8
  num_node_inputs: 2
  num_steps: 4
  reduce_cell_groups:
  - 1
  shared_primitives:
  - skip_connect
  - sep_conv_3x3
  - sep_conv_5x5
  - avg_pool_3x3
  - max_pool_3x3
"""
    else:
        ss_cfg_str = """
search_space_cfg:
  cell_layout: null
  num_cell_groups: 2
  num_init_nodes: 2
  num_layers: 8
  num_node_inputs: 2
  num_steps: 4
  reduce_cell_groups:
  - 1
  shared_primitives:
  - none
  - max_pool_3x3
  - avg_pool_3x3
  - skip_connect
  - sep_conv_3x3
  - sep_conv_5x5
  - dil_conv_3x3
  - dil_conv_5x5
search_space_type: cnn
"""
    ss_cfg = yaml.load(StringIO(ss_cfg_str))
    search_space = get_search_space(ss_cfg["search_space_type"], **ss_cfg["search_space_cfg"])

    model = model_cls(search_space, **cfg.pop("arch_network_cfg"))
    if args.load is not None:
        logging.info("Load %s from %s", arch_network_type, args.load)
        model.load(args.load)
    model.to(device)

    args.__dict__.update(cfg)
    logging.info("Combined args: %s", args)

    # init data loaders
    if hasattr(args, "train_size"):
        train_valid_split = args.train_size
    else:
        train_valid_split = int(getattr(args, "train_valid_split", 0.6) * len(data))
    train_data = data[:train_valid_split]
    valid_data = data[train_valid_split:]
    if hasattr(args, "train_ratio") and args.train_ratio is not None:
        _num = len(train_data)
        train_data = train_data[:int(_num * args.train_ratio)]
        logging.info("Train dataset ratio: %.3f", args.train_ratio)
    if args.addi_train:
        if args.addi_train_only:
            train_data = []
        for addi_train_fname in args.addi_train:
            with open(addi_train_fname, "rb") as rf:
                addi_train_data = pickle.load(rf)
            train_data += addi_train_data
    if args.addi_valid:
        if args.addi_valid_only:
            valid_data = []
        for addi_fname in args.addi_valid:
            with open(addi_fname, "rb") as rf:
                addi_valid_data = pickle.load(rf)
            valid_data += addi_valid_data
    num_train_archs = len(train_data)
    logging.info("Number of architectures: train: %d; valid: %d",
                 num_train_archs, len(valid_data))
    train_data = CellSSDataset(train_data, minus=cfg.get("dataset_minus", None),
                               div=cfg.get("dataset_div", None))
    valid_data = CellSSDataset(valid_data, minus=cfg.get("dataset_minus", None),
                               div=cfg.get("dataset_div", None))
    train_loader = DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=lambda items: [np.array(x) for x in zip(*items)])
    val_loader = DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=lambda items: [np.array(x) for x in zip(*items)])

    if args.test_funcs is not None:
        test_func_names = args.test_funcs.split(",")
        test_funcs = [globals()[func_name] for func_name in test_func_names]
    else:
        test_funcs = []

    # init test
    if arch_network_type != "pairwise_comparator" or args.test_only:
        corr, func_res = valid(val_loader, model, args, funcs=test_funcs)
        if args.sample is not None:
            if args.sample_from_file:
                logging.info("Read genotypes from: {}".format(args.sample_from_file))
                with open(args.sample_from_file, "r") as rf:
                    from_genotypes = yaml.load(rf)
                assert len(from_genotypes) == args.sample * int(args.sample_ratio)
            else:
                from_genotypes = None
            if args.sample_conflict_file:
                conflict_archs = []
                for scf in args.sample_conflict_file:
                    conflict_archs += pickle.load(open(scf, "rb"))
            else:
                conflict_archs = None
            if args.sample_batchify_inner_sample_n is not None:
                # multi-stage is not supported for now
                genotypes = sample_batchify(search_space, model, args.sample_ratio,
                                            args.sample, args, conflict_archs=conflict_archs)
            else:
                genotypes = sample(search_space, model, args.sample * int(args.sample_ratio),
                                   args.sample, args, from_genotypes=from_genotypes,
                                   conflict_archs=conflict_archs)
            if args.sample_to_file:
                with open(args.sample_to_file, "w") as wf:
                    yaml.dump([str(geno) for geno in genotypes], wf)
            else:
                with open("./final_template.yaml", "r") as rf:
                    logging.info("Load final template config file!")
                    template_cfg = yaml.load(rf)
                for i, genotype in enumerate(genotypes):
                    sample_cfg = copy.deepcopy(template_cfg)
                    sample_cfg["final_model_cfg"]["genotypes"] = str(genotype)
                    if not os.path.exists(args.sample_output_dir):
                        os.makedirs(args.sample_output_dir)
                    with open(os.path.join(args.sample_output_dir, "{}.yaml".format(i)), "w") as wf:
                        yaml.dump(sample_cfg, wf)

        if args.test_funcs is None:
            logging.info("INIT: kendall tau {:.4f}".format(corr))
        else:
            logging.info("INIT: kendall tau {:.4f};\n\t{}".format(
                corr, "\n\t".join([
                    "{}: {}".format(name, _get_float_format(res, "{:.4f}"))
                    for name, res in zip(test_func_names, func_res)])))

    if args.test_only:
        return

    _multi_stage = getattr(args, "multi_stage", False)
    _multi_stage_pair_pool = getattr(args, "multi_stage_pair_pool", False)
    if _multi_stage:
        all_perfs = np.array([item[1] for item in train_data.data])
        all_inds = np.arange(all_perfs.shape[0])
        stage_epochs = getattr(args, "stage_epochs", [0, 1, 2, 3])
        num_stages = len(stage_epochs)
        default_stage_nums = [all_perfs.shape[0] // num_stages] * (num_stages - 1) + \
            [all_perfs.shape[0] - all_perfs.shape[0] // num_stages * (num_stages - 1)]
        stage_nums = getattr(args, "stage_nums", default_stage_nums)
        assert np.sum(stage_nums) == all_perfs.shape[0]
        logging.info("Stage nums: {}".format(stage_nums))

        stage_inds_lst = []
        for i_stage in range(num_stages):
            max_stage_ = np.max(all_perfs[all_inds, stage_epochs[i_stage]])
            min_stage_ = np.min(all_perfs[all_inds, stage_epochs[i_stage]])
            logging.info(
                "Stage {}, epoch {}: min {:.2f} %; max {:.2f} % (range {:.2f} %)".format(
                    i_stage, stage_epochs[i_stage],
                    min_stage_ * 100, max_stage_ * 100, (max_stage_ - min_stage_) * 100))
            sorted_inds = np.argsort(all_perfs[all_inds, stage_epochs[i_stage]])
            stage_inds, all_inds = all_inds[sorted_inds[:stage_nums[i_stage]]], \
                all_inds[sorted_inds[stage_nums[i_stage]:]]
            stage_inds_lst.append(stage_inds)
        train_stages = [[train_data.data[ind] for ind in _stage_inds]
                        for _stage_inds in stage_inds_lst]

        avg_score_stages = []
        for i_stage in range(num_stages - 1):
            avg_score_stages.append((
                all_perfs[stage_inds_lst[i_stage], stage_epochs[i_stage]].sum(),
                np.sum([
                    all_perfs[stage_inds_lst[j_stage], stage_epochs[i_stage]].sum()
                    for j_stage in range(i_stage + 1, num_stages)])))

        if _multi_stage_pair_pool:
            all_stages, pairs_list = make_pair_pool(train_stages, args, stage_epochs)

        total_eval_time = all_perfs.shape[0] * all_perfs.shape[1]
        multi_stage_eval_time = sum([
            (stage_epochs[i_stage] + 1) * len(_stage_inds)
            for i_stage, _stage_inds in enumerate(stage_inds_lst)])
        logging.info("Percentage of evaluation time: {:.2f} %".format(
            float(multi_stage_eval_time) / total_eval_time * 100))

    for i_epoch in range(1, args.epochs + 1):
        model.on_epoch_start(i_epoch)
        if _multi_stage:
            if _multi_stage_pair_pool:
                avg_loss = train_multi_stage_pair_pool(all_stages, pairs_list, model, i_epoch, args)
            elif getattr(args, "use_listwise", False):
                avg_loss = train_multi_stage_listwise(
                    train_stages, model, i_epoch, args, avg_score_stages, stage_epochs)
            else:
                avg_loss = train_multi_stage(
                    train_stages, model, i_epoch, args, avg_score_stages, stage_epochs)
        else:
            avg_loss = train(train_loader, model, i_epoch, args)
        logging.info("Train: Epoch {:3d}: train loss {:.4f}".format(i_epoch, avg_loss))

        train_corr, train_func_res = valid(train_loader, model, args, funcs=test_funcs)
        if args.test_funcs is not None:
            for name, res in zip(test_func_names, train_func_res):
                logging.info("Train: Epoch {:3d}: {}: {}".format(
                    i_epoch, name, _get_float_format(res, "{:.4f}")))
        logging.info("Train: Epoch {:3d}: train kd {:.4f}".format(i_epoch, train_corr))

        corr, func_res = valid(val_loader, model, args, funcs=test_funcs)
        if args.test_funcs is not None:
            for name, res in zip(test_func_names, func_res):
                logging.info("Valid: Epoch {:3d}: {}: {}".format(
                    i_epoch, name, _get_float_format(res, "{:.4f}")))
        logging.info("Valid: Epoch {:3d}: kendall tau {:.4f}".format(i_epoch, corr))

        if args.save_every is not None and i_epoch % args.save_every == 0:
            save_path = os.path.join(args.train_dir, "{}.ckpt".format(i_epoch))
            model.save(save_path)
            logging.info("Epoch {:3d}: Save checkpoint to {}".format(i_epoch, save_path))
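# A minimal sketch (assumptions flagged below) of the pickle layout that
# train_cellss_pkl.py appears to consume: a list of records whose second field
# is indexed by epoch (see `all_perfs[..., stage_epochs[i_stage]]` in the
# multi-stage branch above). The first field is assumed to be the architecture
# (e.g., a genotype string); the exact representation is not shown in this
# excerpt and depends on CellSSDataset.
import pickle

dummy_data = [
    # (architecture, per-epoch performance list) -- layout assumed
    ("arch-0-genotype-placeholder", [0.55, 0.71, 0.80, 0.84]),
    ("arch-1-genotype-placeholder", [0.50, 0.68, 0.77, 0.82]),
]
with open("dummy_cellss_data.pkl", "wb") as wf:
    pickle.dump(dummy_data, wf)
# The script would then be pointed at this file via `--data-fname dummy_cellss_data.pkl`.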
def __init__(
        self, search_space, device, rollout_type, mode="eval",
        inner_controller_type=None, inner_controller_cfg=None,
        arch_network_type="pointwise_comparator", arch_network_cfg=None,
        # how to use the inner controller and arch network to sample new archs
        inner_sample_n=1,
        inner_samples=1,
        inner_steps=200,
        inner_report_freq=50,
        predict_batch_size=512,
        inner_random_init=True,
        inner_iter_random_init=True,
        inner_enumerate_search_space=False,  # DEPRECATED
        inner_enumerate_sample_ratio=None,  # DEPRECATED
        min_inner_sample_ratio=10,
        # how to train the arch network
        begin_train_num=0,
        predictor_train_cfg={
            "epochs": 200,
            "num_workers": 2,
            "batch_size": 128,
            "compare": True,
            "max_compare_ratio": 4,
            "compare_threshold": 0.,
            "report_freq": 50,
            "train_valid_split": None,
            "n_cross_valid": None,
        },
        training_on_load=False,  # force retraining on load
        schedule_cfg=None):
    super(PredictorBasedController, self).__init__(
        search_space, rollout_type, mode, schedule_cfg)

    expect(inner_controller_type is not None,
           "Must specify inner controller type", ConfigException)

    self.device = device
    self.predictor_train_cfg = predictor_train_cfg
    self.inner_controller_reinit = True
    self.inner_sample_n = inner_sample_n
    self.inner_samples = inner_samples
    self.inner_steps = inner_steps
    self.inner_report_freq = inner_report_freq
    self.inner_random_init = inner_random_init
    self.inner_iter_random_init = inner_iter_random_init
    self.inner_enumerate_search_space = inner_enumerate_search_space
    if inner_enumerate_search_space:
        warnings.warn(
            "The `inner_enumerate_search_space` option is DEPRECATED. "
            "Use inner_controller, and set `inner_samples`, `inner_steps` accordingly",
            DeprecationWarning)
    self.inner_enumerate_sample_ratio = inner_enumerate_sample_ratio
    self.min_inner_sample_ratio = min_inner_sample_ratio
    self.predict_batch_size = predict_batch_size
    self.begin_train_num = begin_train_num
    self.training_on_load = training_on_load

    # initialize the inner controller
    inner_controller_cfg = inner_controller_cfg or {}
    tmp_r_type = inner_controller_cfg.get("rollout_type", None)
    if tmp_r_type is not None:
        expect(tmp_r_type == rollout_type,
               "If specified, inner_controller's `rollout_type` must match "
               "the outer `rollout_type`", ConfigException)
    inner_controller_cfg.pop("rollout_type", None)
    inner_controller_cfg.pop("mode", None)
    self.inner_controller_type = inner_controller_type
    self.inner_controller_cfg = inner_controller_cfg
    # if not self.inner_controller_reinit:
    self.inner_controller = BaseController.get_class_(self.inner_controller_type)(
        self.search_space, self.device, rollout_type=self.rollout_type,
        **self.inner_controller_cfg)
    # else:
    #     self.inner_controller = None
    # Currently, we do not use controllers with parameters to be optimized
    # (e.g., an RL-learned RNN)
    self.inner_cont_optimizer = None

    # initialize the predictor
    arch_network_cfg = arch_network_cfg or {}
    expect(arch_network_type == "pointwise_comparator",
           "only support pointwise_comparator arch network for now", ConfigException)
    model_cls = ArchNetwork.get_class_(arch_network_type)
    self.model = model_cls(self.search_space, **arch_network_cfg)
    self.model.to(self.device)

    self.gt_rollouts = []
    self.gt_arch_scores = []
    self.num_gt_rollouts = 0
    # self.train_loader = None
    # self.val_loader = None
    self.is_predictor_trained = False
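# Illustrative construction sketch (not from the original source). The keyword
# arguments mirror the signature above; `search_space`, `device`, the
# `rollout_type` value, and the "random_sample" inner-controller registry name
# are placeholders/assumptions, since this excerpt does not show which inner
# controller types are registered.
#
# controller = PredictorBasedController(
#     search_space, device, rollout_type="discrete",
#     inner_controller_type="random_sample",      # must be specified (see `expect` above)
#     inner_controller_cfg={},
#     arch_network_type="pointwise_comparator",   # the only supported type for now
#     inner_sample_n=1, inner_samples=1, inner_steps=200,
#     begin_train_num=0,
#     predictor_train_cfg={"epochs": 200, "batch_size": 128, "compare": True,
#                          "max_compare_ratio": 4, "compare_threshold": 0.,
#                          "report_freq": 50, "num_workers": 2,
#                          "train_valid_split": None, "n_cross_valid": None})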
if __name__ == "__main__":
    import sys

    base_dir = sys.argv[1]

    def print_with_fmt(container, fmt):
        if isinstance(container, (list, tuple)):
            return ", ".join([print_with_fmt(i, fmt) for i in container])
        return fmt.format(container)

    cfg_file = os.path.join(base_dir, "config.yaml")
    with open(cfg_file, "r") as rf:
        cfg = yaml.load(rf)
    ss = get_search_space(cfg["search_space_type"], **cfg["search_space_cfg"])
    pred = ArchNetwork.get_class_(cfg["controller_cfg"]["arch_network_type"])(
        ss, **cfg["controller_cfg"]["arch_network_cfg"])

    def calc_correlation(name_list, rollouts, corr_func=kendalltau):
        perfs = [[r.perf[name] for r in rollouts] for name in name_list]
        num_perf = len(perfs)
        corr_mat = np.zeros((num_perf, num_perf))
        for i in range(num_perf):
            for j in range(num_perf):
                res = corr_func(perfs[i], perfs[j])
                # scipy's `kendalltau` returns a (statistic, p-value) pair;
                # keep only the statistic when assigning into the matrix
                corr_mat[i][j] = res[0] if isinstance(res, tuple) else res
        return corr_mat

    def print_corr_mat(name_list, corr_mat, print_name):
        print(print_name)
        sz = len(name_list)
        name_list = [name[:12] for name in name_list]
        fmt_str = " ".join(["{:12}"] * (sz + 1))
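# Self-contained illustration of the correlation-matrix computation above,
# using made-up performance lists. It shows why only the first element of
# scipy's return value is kept: `kendalltau` returns a (statistic, p-value) pair.
import numpy as np
from scipy.stats import kendalltau

def demo_corr_mat(perfs):
    num_perf = len(perfs)
    corr_mat = np.zeros((num_perf, num_perf))
    for i in range(num_perf):
        for j in range(num_perf):
            corr_mat[i][j] = kendalltau(perfs[i], perfs[j])[0]  # keep the statistic only
    return corr_mat

# e.g., predictor scores vs. two ground-truth accuracy lists (dummy values)
dummy_perfs = [[0.10, 0.40, 0.35, 0.80],
               [0.20, 0.50, 0.30, 0.90],
               [0.30, 0.20, 0.50, 0.70]]
print(demo_corr_mat(dummy_perfs))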
def main(argv):
    parser = argparse.ArgumentParser(prog="train_nasbench201_pkl.py")
    parser.add_argument("cfg_file")
    parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
    parser.add_argument("--num-workers", default=4, type=int)
    parser.add_argument("--report_freq", default=200, type=int)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--train-dir", default=None, help="Save train log/results into TRAIN_DIR")
    parser.add_argument("--save-every", default=10, type=int)
    parser.add_argument("--test-only", default=False, action="store_true")
    parser.add_argument("--test-funcs", default=None, help="comma-separated list of test funcs")
    parser.add_argument("--load", default=None, help="Load comparator from disk.")
    parser.add_argument("--eval-only-last", default=None, type=int,
                        help=("for the pairwise comparator, the evaluation is slow,"
                              " so only evaluate in the final epochs"))
    parser.add_argument("--save-predict", default=None, help="Save the predict scores")
    parser.add_argument("--train-pkl", default="nasbench201_05.pkl", help="Training dataset pickle")
    parser.add_argument("--valid-pkl", default="nasbench201_05_valid.pkl", help="Evaluation dataset pickle")
    args = parser.parse_args(argv)

    setproctitle.setproctitle("python train_nasbench201_pkl.py config: {}; train_dir: {}; cwd: {}"\
                              .format(args.cfg_file, args.train_dir, os.getcwd()))

    # log
    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt="%m/%d %I:%M:%S %p")

    if not args.test_only:
        assert args.train_dir is not None, "Must specify `--train-dir` when training"
        # if training, set up the log file and back up the config file
        if not os.path.exists(args.train_dir):
            os.makedirs(args.train_dir)
        log_file = os.path.join(args.train_dir, "train.log")
        logging.getLogger().addFile(log_file)
        # copy config file
        backup_cfg_file = os.path.join(args.train_dir, "config.yaml")
        shutil.copyfile(args.cfg_file, backup_cfg_file)
    else:
        backup_cfg_file = args.cfg_file

    # cuda
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        logging.info("GPU device = %d" % args.gpu)
    else:
        logging.info("no GPU available, use CPU!!")

    if args.seed is not None:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    search_space = get_search_space("nasbench-201", load_nasbench=False)
    logging.info("Load pkl cache from %s and %s", args.train_pkl, args.valid_pkl)
    with open(args.train_pkl, "rb") as rf:
        train_data = pickle.load(rf)
    with open(args.valid_pkl, "rb") as rf:
        valid_data = pickle.load(rf)
    with open(backup_cfg_file, "r") as cfg_f:
        cfg = yaml.load(cfg_f)

    logging.info("Config: %s", cfg)

    arch_network_type = cfg.get("arch_network_type", "pointwise_comparator")
    model_cls = ArchNetwork.get_class_(arch_network_type)
    model = model_cls(search_space, **cfg.pop("arch_network_cfg"))
    if args.load is not None:
        logging.info("Load %s from %s", arch_network_type, args.load)
        model.load(args.load)
    model.to(device)

    args.__dict__.update(cfg)
    logging.info("Combined args: %s", args)

    # init nasbench data loaders
    if hasattr(args, "train_ratio") and args.train_ratio is not None:
        _num = len(train_data)
        train_data = train_data[:int(_num * args.train_ratio)]
        logging.info("Train dataset ratio: %.3f", args.train_ratio)
    num_train_archs = len(train_data)
    logging.info("Number of architectures: train: %d; valid: %d",
                 num_train_archs, len(valid_data))

    # decide how many archs would only be trained to half time
    if hasattr(args, "ignore_quantile") and args.ignore_quantile is not None:
        half_accs = [item[2] for item in train_data]
        if not args.compare or getattr(args, "ignore_halftime", False):
            # just ignore the half-time archs
            full_inds = np.argsort(half_accs)[int(num_train_archs * args.ignore_quantile):]
            train_data = [train_data[ind] for ind in full_inds]
            logging.info(
                "#Train architectures after ignoring the %.2f worst half-time archs: %d",
                args.ignore_quantile, len(train_data))
        else:
            half_inds = np.argsort(half_accs)[:int(num_train_archs * args.ignore_quantile)]
            logging.info(
                "#Architectures that do not need to be trained to the final epoch: %d (%.2f %%)",
                len(half_inds), args.ignore_quantile * 100)
            for ind in half_inds:
                train_data[ind] = (train_data[ind][0], None, train_data[ind][2])

    train_data = NasBench201Dataset(train_data, minus=cfg.get("dataset_minus", None),
                                    div=cfg.get("dataset_div", None))
    valid_data = NasBench201Dataset(valid_data, minus=cfg.get("dataset_minus", None),
                                    div=cfg.get("dataset_div", None))
    train_loader = DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True,
        num_workers=args.num_workers, collate_fn=lambda items: list(zip(*items)))
    val_loader = DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True,
        num_workers=args.num_workers, collate_fn=lambda items: list(zip(*items)))

    # init test
    if arch_network_type not in {"pairwise_comparator", "random_forest"} or args.test_only:
        if args.test_funcs is not None:
            test_func_names = args.test_funcs.split(",")
        corr, func_res = valid(
            val_loader, model, args,
            funcs=[globals()[func_name] for func_name in test_func_names]
            if args.test_funcs is not None else [])
        if args.test_funcs is None:
            logging.info("INIT: kendall tau {:.4f}".format(corr))
        else:
            logging.info("INIT: kendall tau {:.4f};\n\t{}".format(
                corr, "\n\t".join([
                    "{}: {}".format(name, res)
                    for name, res in zip(test_func_names, func_res)])))

    if args.test_only:
        return

    for i_epoch in range(1, args.epochs + 1):
        model.on_epoch_start(i_epoch)
        if getattr(args, "use_listwise", False):
            avg_loss = train_listwise(train_data, model, i_epoch, args, arch_network_type)
        else:
            avg_loss = train(train_loader, model, i_epoch, args, arch_network_type)
        logging.info("Train: Epoch {:3d}: train loss {:.4f}".format(i_epoch, avg_loss))
        if args.eval_only_last is None or (args.epochs - i_epoch < args.eval_only_last):
            corr, _ = valid(val_loader, model, args)
            logging.info("Valid: Epoch {:3d}: kendall tau {:.4f}".format(i_epoch, corr))
        if i_epoch % args.save_every == 0:
            save_path = os.path.join(args.train_dir, "{}.ckpt".format(i_epoch))
            model.save(save_path)
            logging.info("Epoch {:3d}: Save checkpoint to {}".format(i_epoch, save_path))
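# Small self-contained illustration of the `collate_fn` used by the loaders
# above: `zip(*items)` transposes a batch of per-architecture records into
# per-field tuples. The (arch, final_acc, half_time_acc) record layout is
# inferred from the `ignore_quantile` branch (item[2] is the half-time
# accuracy); the concrete values below are made up.
batch_items = [
    ("arch-a", 0.93, 0.71),
    ("arch-b", 0.91, 0.69),
    ("arch-c", 0.88, 0.65),
]
archs, final_accs, half_accs = list(zip(*batch_items))
print(archs)       # ('arch-a', 'arch-b', 'arch-c')
print(final_accs)  # (0.93, 0.91, 0.88)
print(half_accs)   # (0.71, 0.69, 0.65)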