示例#1
0
    def train(self, schedule):
        """Fine-tune the whole stack as one model according to ``schedule``.

        :param schedule: dict with
            ``"train"`` -- indexable; its first two items are used,
            ``"valid"`` -- optional, same form as ``"train"``,
            ``"eval"``  -- container that must hold ``"valid"`` exactly when a
            validation set is given (checked with an ``assert``),
            ``"opt"``   -- optimizer schedule dict; reads ``"epochs"``,
            ``"stop"`` and optionally ``"eval_score"``, ``"peeks"`` and
            ``"peek_intervall"``.

        The ``"opt"`` dict is modified in place: ``"f"``/``"fprime"`` are set
        to ``self.score``/``self.grad``, and a given ``"eval_score"`` is moved
        to ``self._eval_score`` and replaced by ``self.evaluate_score``.

        When ``epochs > 0``, the optimizer from ``prepare_opt`` runs for at
        most ``epochs`` iterations. Every ``stop`` iterations it runs all
        evaluators on ``self.params`` and logs the result. Every
        ``peek_intervall`` iterations it calls each peek and saves
        ``(prediction, inputs)`` with ``np.savez`` to ``<peek>.peek``.
        Finally the parameters and ``repr(self)`` are logged.
        """
        train = [schedule["train"][0], schedule["train"][1]]
        valid = None if not schedule.get("valid") else [schedule["valid"][0], schedule["valid"][1]]

        assert (valid is not None) == ("valid" in schedule["eval"]), "Confusion about validation set!"

        opt_schedule = schedule["opt"]

        pp = {"type" : str(self)}
        munk.taggify(self.logging, "pretty").send(pp)
        log = munk.add_keyvalue(self.logging, "layer", "Stack")

        epochs = opt_schedule["epochs"]
        if epochs > 0:
            opt_schedule["f"] = self.score
            opt_schedule["fprime"] = self.grad

            if "eval_score" in opt_schedule:
                # opt_schedule is the caller's dict and is rewritten below.
                # If this schedule was already used for a previous call, its
                # "eval_score" is our own evaluate_score. Capturing that again
                # would make _eval_score refer back to evaluate_score, so only
                # capture a scorer that has not been wrapped yet.
                if opt_schedule["eval_score"] != self.evaluate_score:
                    self._eval_score = opt_schedule["eval_score"]
                opt_schedule["eval_score"] = self.evaluate_score

            opt, evals, peeks = prepare_opt(opt_schedule, self.params, schedule, train, valid)

            stop = opt_schedule["stop"]
            # Always defined so the peek loop can never hit a NameError.
            peek_files = {}
            if "peeks" in opt_schedule:
                peek_iv = opt_schedule["peek_intervall"]
                for p in opt_schedule["peeks"]:
                    peek_files[p] = p + ".peek"
            else:
                # The loop stops after `epochs` iterations, so an interval of
                # epochs + 1 means "never peek".
                peek_iv = epochs + 1

            for i, info in enumerate(opt):
                n_done = i + 1
                if n_done % stop == 0:
                    for e in evals:
                        info[e] = evals[e](self.params)
                    info = replace_gnumpy_data(info)
                    log.send(info)

                # Peek before the epoch check so that an interval dividing
                # `epochs` also produces a peek for the final iteration.
                if n_done % peek_iv == 0:
                    for p in peeks:
                        prediction, inputs = peeks[p](self.params)
                        np.savez(peek_files[p], prediction, inputs)
                        pp = {"msg": "Writing peek file %s"%peek_files[p]}
                        munk.taggify(self.logging, "pretty").send(pp)

                if n_done == epochs:
                    break

        else:
            pp = {"msg": "NO FINETUNING of stack"}
            munk.taggify(self.logging, "pretty").send(pp)

        _params = self.params.as_numpy_array().tolist()
        info = dict(params=_params, shape=self.__repr__())
        log.send(info)
示例#2
0
    def pretrain(self, schedule):
        """Pretrain the layers of the stack one after another.

        Iterates over ``izip(self, self.stack)``, which pairs every layer with
        its own per-layer schedule ``sched``. For each layer ``i``:

        * ``layer.pt_init(**sched)`` creates the pretraining parameters;
        * if ``sched["opt"]["epochs"] > 0``, an optimizer built by
          ``prepare_opt`` is run with ``layer.pt_score``/``layer.pt_grad``.
          Evaluator results are logged every ``sched["opt"]["stop"]``
          iterations, and the loop stops after ``epochs`` iterations;
        * ``layer.pt_done(pt_params, **sched)`` is called and its result is
          logged;
        * for every layer except the last, the training input (and the
          validation input, if present) is passed through the layer with
          ``self.next_hdf5`` and written into a new timestamped temporary
          HDF5 file. The result replaces the input for the next layer.

        :param schedule: dict with ``"train"`` and ``"eval"`` and optionally
            ``"valid"``; only the first two items of ``"train"``/``"valid"``
            are used, and ``"eval"`` must contain ``"valid"`` exactly when a
            validation set is given (checked with an ``assert``).
        """
        train = [schedule["train"][0], schedule["train"][1]]
        valid = None if not schedule.get("valid") else [schedule["valid"][0], schedule["valid"][1]]

        assert (valid is not None) == ("valid" in schedule["eval"]), "Confusion about validation set!"

        for i, (layer, sched) in enumerate(izip(self, self.stack)):
            pt_params = layer.pt_init(**sched)

            # NOTE: sched["opt"] is modified in place ("f"/"fprime" below).
            opt_schedule = sched["opt"]

            pp = {"layer":i, "type":str(layer)}
            munk.taggify(self.logging, "pretty").send(pp)
            log = munk.add_keyvalue(self.logging, "layer", i)

            epochs = opt_schedule["epochs"]
            if epochs > 0:
                opt_schedule["f"] = layer.pt_score
                opt_schedule["fprime"] = layer.pt_grad

                opt, evals, peeks = prepare_opt(opt_schedule, pt_params, schedule, train, valid)

                stop = opt_schedule["stop"]
                # `j` counts optimizer iterations; `i` is the layer index.
                for j, info in enumerate(opt):
                    if (j+1) % stop == 0:
                        for e in evals:
                            info[e] = evals[e](pt_params)
                        info = replace_gnumpy_data(info)
                        log.send(info)

                    if (j+1) == epochs:
                        break
            else:
                pp = {"msg": "NO PRETRAINING of layer %i"%i}
                munk.taggify(self.logging, "pretty").send(pp)

            info = layer.pt_done(pt_params, **sched)
            # Drop the local reference once pt_done has consumed the
            # parameters (presumably to let them be freed -- confirm).
            pt_params = None
            log.send(info)

            # move data forward, save in temporary hdf5
            if i < (len(self) - 1):
                nxt_name = strftime("%Y-%m-%d-%H:%M:%S") + "_L" + str(i+1) + "_TMP.h5"
                # Explicit mode "a" (read/write, create if missing) was h5py's
                # implicit default before 3.0; newer h5py opens files
                # read-only by default, which cannot create this new file.
                # NOTE(review): the handle is not closed here; the data
                # written into it is used as input for the next layer.
                nxt = h5py.File(nxt_name, "a")
                pp = {"msg": "Take care of temporary " + nxt_name}
                munk.taggify(self.logging, "pretty").send(pp)
                # if a validation set is available, move it forward, too.
                if valid:
                    valid[0] = self.next_hdf5(layer, valid[0], "validation", nxt, chunk=512)
                train[0] = self.next_hdf5(layer, train[0], "train", nxt, chunk=512)