# Methods of the checkpoint register class; assumes module-level imports
# of os, argparse, logging, yaml, GitInfo, datetime, itemgetter and
# seqmod.utils as u.
def setup(self, args=None, d=None):
    """
    Initialize the checkpoint register.

    - args: Namespace or dictionary of params associated with the network
    - d: Dict (optional) to store together with the model
    """
    if self.is_setup:
        return self

    if not os.path.isdir(os.path.join(self.topdir, self.subdir)):
        os.makedirs(os.path.join(self.topdir, self.subdir))

    if args is not None:
        if isinstance(args, argparse.Namespace):
            args = vars(args)
        # add git info
        git_info = GitInfo(self.topdir)
        commit, branch = git_info.get_commit(), git_info.get_branch()
        args['git-commit'] = commit
        args['git-branch'] = branch
        from seqmod import __commit__
        args['seqmod-git-commit'] = __commit__
        # dump
        with open(self.checkpoint_path('params.yml'), 'w') as f:
            yaml.dump(args, f, default_flow_style=False)

    if d is not None:
        u.save_model(d, self.checkpoint_path('dict'), mode=self.ext)

    self.is_setup = True

    return self
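# --- Hedged sketch (not part of the original module): the params.yml
# round trip that setup() performs. The filename and the param values
# here are illustrative only.
import argparse
import yaml

args = argparse.Namespace(hid_dim=200, layers=2, cell='LSTM')
params = vars(args)                       # Namespace -> plain dict
params['git-commit'] = 'abc1234'          # hypothetical commit hash

with open('params.yml', 'w') as f:
    yaml.dump(params, f, default_flow_style=False)

with open('params.yml') as f:
    restored = yaml.safe_load(f)
assert restored['hid_dim'] == 200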
def save_nbest(self, model, loss):
    """
    Save the model if its validation loss ranks among the n best so far.
    """
    if not self.is_setup:
        raise ValueError("Checkpoint not setup yet")

    def format_loss(loss):
        return '{:.4f}'.format(loss)

    if len(self.buf_best) == self.keep:
        losses = [format_loss(l) for _, l in self.buf_best]
        worstm, worstl = self.buf_best[-1]
        # only save if better than the current worst and not a duplicate
        if loss < worstl and format_loss(loss) not in losses:
            try:
                os.remove(worstm)
            except FileNotFoundError:
                logging.warning("Couldn't find model [{}]".format(worstm))
            self.buf_best.pop()
        else:
            return self

    modelname = u.save_model(
        model, self.get_modelname(format_loss(loss)), mode=self.ext)
    self.buf_best.append((modelname, loss))
    self.buf_best.sort(key=itemgetter(1))

    return self
def save_nlast(self, model):
    """
    Keep only the n most recent models, regardless of loss.
    """
    if not self.is_setup:
        raise ValueError("Checkpoint not setup yet")

    if len(self.buf_last) == self.keep:
        # buffer is sorted newest-first, so the oldest model is last
        oldestm, _ = self.buf_last[-1]
        try:
            os.remove(oldestm)
        except FileNotFoundError:
            logging.warning("Couldn't find model [{}]".format(oldestm))
        self.buf_last.pop()

    timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
    modelname = u.save_model(
        model, self.get_modelname(timestamp), mode=self.ext)
    self.buf_last.append((modelname, timestamp))
    self.buf_last.sort(key=itemgetter(1), reverse=True)

    return self
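# --- Hedged sketch of the n-best buffer behind save_nbest(). ---
# A standalone illustration, not the class above: keep the `keep`
# lowest-loss (name, loss) pairs, evicting the current worst on insert.
from operator import itemgetter

def push_nbest(buf, name, loss, keep=3):
    """Insert (name, loss); keep only the `keep` best (lowest) losses."""
    if len(buf) == keep:
        _, worst = buf[-1]          # buffer is sorted, worst is last
        if loss >= worst:
            return buf              # not good enough, drop it
        buf.pop()                   # evict the current worst
    buf.append((name, loss))
    buf.sort(key=itemgetter(1))
    return buf

buf = []
for name, loss in [('m1', 2.0), ('m2', 1.5), ('m3', 3.0), ('m4', 1.2)]:
    push_nbest(buf, name, loss)
print(buf)   # [('m4', 1.2), ('m2', 1.5), ('m1', 2.0)]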
d = Dict(max_size=args.max_size, min_freq=args.min_freq,
         eos_token=u.EOS, force_unk=True)

trainpath = os.path.join(args.path, 'train.txt')
testpath = os.path.join(args.path, 'test.txt')

outputformat = (args.output + ".{}.npz").format
if os.path.isfile(outputformat("train")):
    raise ValueError("Output train file already exists")
if os.path.isfile(outputformat("test")):
    raise ValueError("Output test file already exists")

print("Fitting dictionary")
d.fit(load_lines(trainpath, processor=processor),
      load_lines(testpath, processor=processor))
u.save_model(d, args.output + '.dict')

print("Transforming train data")
# flatten the transformed lines into a single 1-D array of token ids;
# note that np.save writes plain .npy data despite the .npz suffix
with open(outputformat("train"), 'wb+') as f:
    vector = []
    for line in d.transform(load_lines(trainpath, processor=processor)):
        vector.extend(line)
    np.save(f, np.array(vector))

if os.path.isfile(testpath):
    print("Transforming test data")
    with open(outputformat("test"), 'wb+') as f:
        vector = []
        for line in d.transform(load_lines(testpath, processor=processor)):
            vector.extend(line)
        np.save(f, np.array(vector))
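# --- Hedged sketch: reading the flattened corpus back. ---
# The script above writes one 1-D array of token ids per split with
# np.save; 'corpus.train.npz' stands in for the script's actual output
# name, which depends on args.output.
import numpy as np

ids = np.array([3, 7, 7, 2, 5], dtype=np.int64)   # toy token ids
with open('corpus.train.npz', 'wb') as f:
    np.save(f, ids)

with open('corpus.train.npz', 'rb') as f:
    restored = np.load(f)
assert (restored == ids).all()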
                                table=table)  # closes the train call to
                                              # examples_from_lines above
del train_lines, train_conds

print("Processing test")
linesiter = readlines(os.path.join(args.path, 'test.csv'))
test_labels, test_lines = zip(*linesiter)
test = examples_from_lines(test_lines, test_labels, lang_d, conds_d,
                           table=table)
del test_lines, test_labels

d = tuple([lang_d] + conds_d)

if args.save_data:
    assert args.data_path, "save_data requires data_path"
    u.save_model((train, test, d, table), args.data_path)

train, valid = BlockDataset.splits_from_data(
    tuple(train), d, args.batch_size, args.bptt,
    gpu=args.gpu, table=table, test=None, dev=args.dev_split)
test = BlockDataset(tuple(test), d, args.batch_size, args.bptt, fitted=True,
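# --- Hedged sketch of BPTT blocking, the idea behind BlockDataset. ---
# Not seqmod's implementation: a flat id sequence is trimmed and
# reshaped to (batch_size, n_steps), then sliced into bptt-sized
# input/target windows shifted by one token.
import numpy as np

def batchify(ids, batch_size):
    n_steps = len(ids) // batch_size
    return np.array(ids[:n_steps * batch_size]).reshape(batch_size, -1)

data = batchify(list(range(20)), batch_size=2)    # shape (2, 10)
bptt = 4
for i in range(0, data.shape[1] - 1, bptt):
    seq_len = min(bptt, data.shape[1] - 1 - i)
    inputs = data[:, i:i + seq_len]               # current tokens
    targets = data[:, i + 1:i + 1 + seq_len]      # next-token targets
    print(inputs.shape, targets.shape)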
# early_stopping is assumed to be initialized (e.g. to None) earlier
# in the script when args.early_stopping is not set
if args.early_stopping > 0:
    early_stopping = EarlyStopping(args.early_stopping)
model_check_hook = make_lm_check_hook(
    d, method=args.decoding_method, temperature=args.temperature,
    max_seq_len=args.max_seq_len, seed_text=args.seed, gpu=args.gpu,
    early_stopping=early_stopping)
num_checkpoints = len(train) // (args.checkpoint * args.hooks_per_epoch)
trainer.add_hook(model_check_hook, num_checkpoints=num_checkpoints)

# loggers
visdom_logger = VisdomLogger(
    log_checkpoints=args.log_checkpoints, title=args.prefix, env='lm',
    server='http://' + args.visdom_server)
trainer.add_loggers(StdLogger(), visdom_logger)

trainer.train(args.epochs, args.checkpoint, gpu=args.gpu)

if args.save:
    test_ppl = trainer.validate_model(test=True)
    print("Test perplexity: %g" % test_ppl)
    fmt = '{prefix}.{cell}.{layers}l.{hid_dim}h.{emb_dim}e.{bptt}b.{ppl}'
    fname = fmt.format(ppl="%.2f" % test_ppl, **vars(args))
    if os.path.isfile(fname):
        answer = input("File [%s] exists. Overwrite? (y/n): " % fname)
        if answer.lower() not in ("y", "yes"):
            print("Goodbye!")
            sys.exit(0)
    print("Saving model to [%s]..." % fname)
    u.save_model(model, fname, d=d)
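# --- Hedged sketch of the early-stopping check wired in via the hook. ---
# Not seqmod's EarlyStopping class: a minimal stand-in that signals a
# stop after `patience` consecutive validation checks without
# improvement.
class PatienceStopper:
    def __init__(self, patience):
        self.patience, self.best, self.fails = patience, float('inf'), 0

    def add_checkpoint(self, loss):
        """Record a validation loss; return True when training should stop."""
        if loss < self.best:
            self.best, self.fails = loss, 0
        else:
            self.fails += 1
        return self.fails >= self.patience

stopper = PatienceStopper(patience=2)
for loss in [2.0, 1.8, 1.9, 1.85]:   # no improvement after 1.8
    if stopper.add_checkpoint(loss):
        print("early stop at loss", loss)
        break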