Example #1
    def setup(self, args=None, d=None):
        """
        Initialize the checkpoint register.

        - args: Namespace or dictionary of params associated with the network
        - d: Dict (optional) to store together with the model
        """
        if self.is_setup:
            return self

        checkpoint_dir = os.path.join(self.topdir, self.subdir)
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        if args is not None:
            if isinstance(args, argparse.Namespace):
                args = vars(args)
            # add git info
            git_info = GitInfo(self.topdir)
            commit, branch = git_info.get_commit(), git_info.get_branch()
            args['git-commit'] = commit
            args['git-branch'] = branch
            from seqmod import __commit__
            args['seqmod-git-commit'] = __commit__
            # dump
            with open(self.checkpoint_path('params.yml'), 'w') as f:
                yaml.dump(args, f, default_flow_style=False)

        if d is not None:
            u.save_model(d, self.checkpoint_path('dict'), mode=self.ext)

        self.is_setup = True

        return self
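
A minimal standalone sketch of the params-dump step in setup(): an argparse.Namespace is converted to a plain dict and written as YAML with default_flow_style=False. The git lookup is replaced by a placeholder value here, since GitInfo and the checkpoint paths belong to the surrounding class; the parameter names are made up.

import argparse
import yaml

# Hypothetical parameters standing in for the `args` Namespace above.
args = argparse.Namespace(lr=0.001, hid_dim=256, epochs=10)

params = vars(args)               # Namespace -> plain dict, as in setup()
params['git-commit'] = 'unknown'  # placeholder; setup() fills this via GitInfo

with open('params.yml', 'w') as f:
    yaml.dump(params, f, default_flow_style=False)
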
Example #2
    def save_nbest(self, model, loss):
        """
        Save the model depending on the current buffer state and the given validation loss
        """
        if not self.is_setup:
            raise ValueError("Checkpoint not setup yet")

        def format_loss(loss):
            return '{:.4f}'.format(loss)

        if len(self.buf_best) == self.keep:
            losses = [format_loss(l) for _, l in self.buf_best]
            (worstm, worstl) = self.buf_best[-1]
            # avoid duplicates of an already-stored loss
            if loss < worstl and format_loss(loss) not in losses:
                try:
                    os.remove(worstm)
                except FileNotFoundError:
                    logging.warning("Couldn't find model [{}]".format(worstm))
                    print(self.buf_best, worstm, loss, worstl)
                self.buf_best.pop()
            else:
                return

        modelname = u.save_model(model,
                                 self.get_modelname(format_loss(loss)),
                                 mode=self.ext)
        self.buf_best.append((modelname, loss))
        self.buf_best.sort(key=itemgetter(1))

        return self
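
The n-best bookkeeping in save_nbest() can be illustrated in isolation. The sketch below keeps a buffer of (name, loss) pairs of fixed size `keep`, sorted by ascending loss, and evicts the worst entry when a better, non-duplicate loss arrives; file removal and u.save_model are replaced by plain list operations, and the model names are made up.

from operator import itemgetter

keep = 3
buf_best = []  # (modelname, loss) pairs, sorted by ascending loss


def add_nbest(name, loss):
    """Simplified version of the buffer logic in save_nbest()."""
    if len(buf_best) == keep:
        _, worst_loss = buf_best[-1]
        losses = ['{:.4f}'.format(l) for _, l in buf_best]
        if loss < worst_loss and '{:.4f}'.format(loss) not in losses:
            buf_best.pop()   # drop the worst checkpoint
        else:
            return           # not good enough (or a duplicate): skip
    buf_best.append((name, loss))
    buf_best.sort(key=itemgetter(1))


for i, l in enumerate([0.9, 0.5, 0.7, 0.6, 0.8]):
    add_nbest('model-{}'.format(i), l)

print(buf_best)  # -> [('model-1', 0.5), ('model-3', 0.6), ('model-2', 0.7)]
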
Example #3
    def save_nlast(self, model):
        """
        Only keep track of the n last models, regardless of loss
        """
        if not self.is_setup:
            raise ValueError("Checkpoint not setup yet")

        if len(self.buf_last) == self.keep:
            oldestm, _ = self.buf_last[-1]
            try:
                os.remove(oldestm)
            except FileNotFoundError:
                logging.warning("Couldn't find model [{}]".format(oldestm))
            self.buf_last.pop()

        timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        modelname = u.save_model(model,
                                 self.get_modelname(timestamp),
                                 mode=self.ext)
        self.buf_last.append((modelname, timestamp))
        self.buf_last.sort(key=itemgetter(1), reverse=True)

        return self
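
Likewise, the n-last rotation in save_nlast() can be shown without file I/O: timestamps sort newest-first, so the oldest model sits at the end of the buffer and is evicted once `keep` entries are stored. The timestamps below are hard-coded in the same "%Y_%m_%d-%H_%M_%S" format, purely for illustration.

from operator import itemgetter

keep = 2
buf_last = []  # (modelname, timestamp) pairs, newest first


def add_nlast(name, timestamp):
    """Simplified version of the rotation in save_nlast()."""
    if len(buf_last) == keep:
        buf_last.pop()  # oldest entry is last because of the reverse sort
    buf_last.append((name, timestamp))
    buf_last.sort(key=itemgetter(1), reverse=True)


stamps = ["2024_01_01-10_00_00", "2024_01_01-10_05_00",
          "2024_01_01-10_10_00", "2024_01_01-10_15_00"]
for i, ts in enumerate(stamps):
    add_nlast('model-{}'.format(i), ts)

print(buf_last)  # only ('model-3', ...) and ('model-2', ...) remain
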
Example #4
    d = Dict(max_size=args.max_size, min_freq=args.min_freq,
             eos_token=u.EOS, force_unk=True)

    trainpath = os.path.join(args.path, 'train.txt')
    testpath = os.path.join(args.path, 'test.txt')
    outputformat = (args.output + ".{}.npz").format

    if os.path.isfile(outputformat("train")):
        raise ValueError("Output train file already exists")
    if os.path.isfile(outputformat("test")):
        raise ValueError("Output test file already exists")

    print("Fitting dictionary")
    d.fit(load_lines(trainpath, processor=processor),
          load_lines(testpath, processor=processor))
    u.save_model(d, args.output + '.dict')

    print("Transforming train data")
    with open(outputformat("train"), 'wb+') as f:
        vector = []
        for line in d.transform(load_lines(trainpath, processor=processor)):
            vector.extend(line)
        np.save(f, np.array(vector))

    if os.path.isfile(testpath):
        print("Transforming test data")
        with open(outputformat("test"), 'wb+') as f:
            vector = []
            for line in d.transform(load_lines(testpath, processor=processor)):
                vector.extend(line)
            np.save(f, np.array(vector))
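
The transform-and-flatten step above can be reproduced with plain numpy: each transformed line is a sequence of integer indices, and all lines are concatenated into one flat array before saving. A toy token-to-index mapping stands in for the fitted Dict here; note that np.save writes .npy-format data even though the file is named .npz, mirroring the snippet.

import numpy as np

# Toy stand-in for the fitted Dict: maps tokens to integer indices.
vocab = {'<eos>': 0, 'the': 1, 'cat': 2, 'sat': 3}
lines = [['the', 'cat', '<eos>'], ['the', 'cat', 'sat', '<eos>']]

# Same pattern as above: extend a flat list with each transformed line.
vector = []
for line in lines:
    vector.extend(vocab[token] for token in line)

with open('train.toy.npz', 'wb') as f:
    np.save(f, np.array(vector))

print(np.load('train.toy.npz'))  # -> [1 2 0 1 2 3 0]
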
Example #5
                                    table=table)
        del train_lines, train_conds
        print("Processing test")
        linesiter = readlines(os.path.join(args.path, 'test.csv'))
        test_labels, test_lines = zip(*linesiter)
        test = examples_from_lines(test_lines,
                                   test_labels,
                                   lang_d,
                                   conds_d,
                                   table=table)
        del test_lines, test_labels
        d = tuple([lang_d] + conds_d)

        if args.save_data:
            assert args.data_path, "save_data requires data_path"
            u.save_model((train, test, d, table), args.data_path)

    train, valid = BlockDataset.splits_from_data(tuple(train),
                                                 d,
                                                 args.batch_size,
                                                 args.bptt,
                                                 gpu=args.gpu,
                                                 table=table,
                                                 test=None,
                                                 dev=args.dev_split)

    test = BlockDataset(tuple(test),
                        d,
                        args.batch_size,
                        args.bptt,
                        fitted=True,
Example #6
    early_stopping = None
    if args.early_stopping > 0:
        early_stopping = EarlyStopping(args.early_stopping)
    model_check_hook = make_lm_check_hook(
        d, method=args.decoding_method, temperature=args.temperature,
        max_seq_len=args.max_seq_len, seed_text=args.seed, gpu=args.gpu,
        early_stopping=early_stopping)
    num_checkpoints = len(train) // (args.checkpoint * args.hooks_per_epoch)
    trainer.add_hook(model_check_hook, num_checkpoints=num_checkpoints)

    # loggers
    visdom_logger = VisdomLogger(
        log_checkpoints=args.log_checkpoints, title=args.prefix, env='lm',
        server='http://' + args.visdom_server)
    trainer.add_loggers(StdLogger(), visdom_logger)

    trainer.train(args.epochs, args.checkpoint, gpu=args.gpu)

    if args.save:
        test_ppl = trainer.validate_model(test=True)
        print("Test perplexity: %g" % test_ppl)
        f = '{prefix}.{cell}.{layers}l.{hid_dim}h.{emb_dim}e.{bptt}b.{ppl}'
        fname = f.format(ppl="%.2f" % test_ppl, **vars(args))
        if os.path.isfile(fname):
            answer = input("File [%s] exists. Overwrite? (y/n): " % fname)
            if answer.lower() not in ("y", "yes"):
                print("Goodbye!")
                sys.exit(0)
        print("Saving model to [%s]..." % fname)
        u.save_model(model, fname, d=d)
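
One detail worth spelling out is the hook frequency: num_checkpoints = len(train) // (args.checkpoint * args.hooks_per_epoch), which presumably makes the hook fire every num_checkpoints-th checkpoint, i.e. roughly hooks_per_epoch times per epoch. A quick numeric check with assumed values:

# Assumed values, for illustration only.
len_train = 10000     # training batches per epoch
checkpoint = 100      # batches between checkpoints
hooks_per_epoch = 5   # desired hook calls per epoch

num_checkpoints = len_train // (checkpoint * hooks_per_epoch)
print(num_checkpoints)  # -> 20: the hook would run every 20 checkpoints,
                        #    i.e. every 2000 batches, 5 times per epoch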