Example #1
    def get_centers_from_xy(self, X, y, classes=None, model=None):
        # Extract feature vectors (and their labels) from the dataset, optionally
        # restricted to the given classes, using a fresh model if none is provided.
        if model is None:
            model = HIModel(self.config)
        model.train(False)
        cfg = self.config.copy()
        vX_tr, vy_tr = F.prepare_ds(X, y, cfg, False)

        if classes is None:
            vX_t, vy_t = vX_tr, vy_tr
        else:
            # Keep only the samples whose label is in `classes`.
            vX_t, vy_t = [], []
            for i, _ in enumerate(vX_tr):
                if vy_tr[i] in classes:
                    vX_t.append(vX_tr[i])
                    vy_t.append(vy_tr[i])
            vX_t = torch.stack(vX_t)
            vy_t = torch.stack(vy_t)

        fXl = []
        fyl = []
        for i, _ in tqdm(enumerate(vy_t)):
            # Take a random contiguous batch and collect its features and labels.
            rad = random.randint(0, len(vX_t) - self.config.BATCH_SIZE)
            x = vX_t[rad:rad + self.config.BATCH_SIZE]
            p3 = F.runtime_preprocess(self.config, x)
            fXl.extend(model.extract(p3).detach())
            fyl.extend(vy_t[rad:rad + self.config.BATCH_SIZE])
        return fXl, fyl
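Not part of the original example: a minimal usage sketch that averages the returned feature vectors into one center per class. The names `trainer`, `X` and `y` are placeholders.

import torch

# Hypothetical caller: collect features for two writers only.
fX, fy = trainer.get_centers_from_xy(X, y, classes=[0, 1])

# Average the extracted feature vectors per class to get one center per writer.
centers = {}
for c in set(int(t) for t in fy):
    centers[c] = torch.stack([f for f, t in zip(fX, fy) if int(t) == c]).mean(dim=0)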
Example #2
 def get_center(self, image):
     # Normalize the page image, segment it into word crops and return the writer
     # center together with the number of words (or None if nothing was segmented).
     normed = self.preprocessor.norm(image)
     words = self.preprocessor.segment_words(normed)
     if len(words) == 0:
         return None
     return self.model.get_center(F.words2word_block(words)), len(words)
Example #3
 def get_center(self, word_block):
     # Run the feature extractor over the word block in batches and return the
     # mean feature vector (the "center") over all words.
     self.featextractor.train(False)
     r = []
     for w in range(0, len(word_block), self.config.BATCH_SIZE):
         x = F.runtime_preprocess(self.config, word_block[w : w + self.config.BATCH_SIZE])
         r.extend(self.featextractor.extract(x).cpu().detach())
     return sum(r) / len(r)
Example #4
    def __init__(self, config):
        super().__init__()
        self.config = config.copy()

        # Feature extraction backbone: a ResNet emitting FEATURES_COUNT features.
        self.i = create_resnet(config.RESNET_LAYERS, num_classes=config.FEATURES_COUNT)
        if F.valid_path(config.MODEL_PATH):
            try:
                self.i.load_state_dict(torch.load(config.MODEL_PATH, map_location=config.DEVICE))
            except Exception:
                print("An error occurred while trying to load weights. A new model was created.")

        # Classification head used during supervised pre-training.
        self.t = nn.Sequential(
            nn.ReLU(),
            nn.Linear(config.FEATURES_COUNT, config.CLASS_COUNT),
            nn.Softmax(dim=1)
        )
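The rest of this class is not shown, but the other examples call `extract` and `classify_proba` on it. A minimal sketch of those two methods, assuming `self.i` is the feature backbone and `self.t` the classification head, might look like this:

    def extract(self, x):
        # Hypothetical sketch: the FEATURES_COUNT-dimensional embedding used for centers.
        return self.i(x)

    def classify_proba(self, x):
        # Hypothetical sketch: class probabilities for the supervised pre-training stage.
        return self.t(self.i(x))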
Example #5
def validate(config, model, paths, PAX, verbose=True):
    # Estimate the same/different-writer decision threshold on held-out pages.
    valpr = Preprocessor(config)
    pointss = []
    for cpaths in tqdm(paths):
        # Collect segmented word images from every page of this writer.
        w = []
        for p in cpaths:
            try:
                w.extend(valpr.open_norm_segm(p))
            except Exception:
                pass
        if len(w) < 2 * PAX:
            if verbose:
                print(os.path.dirname(cpaths[0]), "skipped, not enough words:",
                      len(w))
            continue
        # Split the words into groups of PAX and compute one center per group.
        pointss.append([])
        for i in range(len(w) // PAX):
            pointss[-1].append(
                model.get_center(F.words2word_block(w[i * PAX:(i + 1) * PAX])))

    # Distances between centers of the same writer get label 0.
    X_dists = []
    y_dists = []
    for pss in tqdm(pointss):
        for i, _ in enumerate(pss):
            for j in range(i + 1, len(pss)):
                X_dists.append(F.dist(pss[i], pss[j]).item())
                y_dists.append(0)

    # Sample an equal number of distances between centers of different writers (label 1).
    zero_count = len(y_dists)
    for i in range(zero_count):
        random.shuffle(pointss)
        s1 = pointss[0]
        s2 = pointss[1]
        ps1 = random.choice(s1)
        ps2 = random.choice(s2)
        X_dists.append(F.dist(ps1, ps2).item())
        y_dists.append(1)

    # Sweep candidate thresholds over the sorted distances and keep the one with
    # the best same/different-writer classification accuracy.
    allnums = sorted(X_dists)
    ns = []
    bacc, bh = 0, 0
    rctr = range(0, len(allnums) - 1, max(300 // (PAX**2), 1))
    if verbose:
        rctr = tqdm(rctr)
    thresholds = []
    for i in rctr:
        thr = (allnums[i] + allnums[i + 1]) / 2
        thresholds.append(thr)
        acc = countacc(X_dists, y_dists, thr)
        ns.append(acc)
        if acc > bacc:
            bacc = acc
            bh = thr
    if verbose:
        plt.plot(thresholds, ns)

    if verbose:
        print("Best accuracy:", bacc)
        print("Best threshold:", bh)

    return bh
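The `countacc` helper used above is not among these examples. A minimal sketch consistent with how `validate` calls it (label 0 for same-writer pairs, predicted when the distance falls below the threshold) could be:

def countacc(X_dists, y_dists, thr):
    # Hypothetical helper: accuracy of the rule "distance < thr -> same writer (0),
    # otherwise -> different writers (1)".
    correct = sum(1 for d, y in zip(X_dists, y_dists) if (0 if d < thr else 1) == y)
    return correct / len(X_dists)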
Example #6
 def gen_paths(self, path):
     return F.gen_paths(path)
Example #7
 def dist(self, c1, c2):
     return F.dist(c1, c2).item()
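`F.dist` itself is not shown in these examples. Assuming the centers are 1-D feature tensors, a simple Euclidean sketch would be:

import torch

def dist(c1, c2):
    # Hypothetical sketch of F.dist: Euclidean distance between two center embeddings.
    return torch.norm(c1 - c2)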
Example #8
    def fit(self, X, y, verbose=True, plot=False):
        assert len(X) == len(y), "X and y must have the same size"
        assert max(y) + 1 <= self.config.CLASS_COUNT, "y contains class labels outside config.CLASS_COUNT"
        self.__v = verbose
        self.__p = plot
        pr = self.__pr
        cfg = self.config
        X_train, y_train, X_test, y_test = F.prepare_ds(X, y, cfg, True)
        assert len(X_train) > 0, "An error occurred while taking X_train"
        assert len(X_test) > 0, "An error occurred while taking X_test"
        pr("First stage: fitting nn")
        REDRAW_SIZE = 20
        lasttime = time.time()
        # Class-balanced cross-entropy: weight each class inversely to its frequency.
        ydistr = F.count_distr(y_train, cfg.CLASS_COUNT)
        crit = nn.CrossEntropyLoss(weight=torch.tensor(1 / np.array(ydistr)).to(self.device).type(torch.float))
        opt = torch.optim.Adam(self.featextractor.parameters(), lr=cfg.LEARNING_RATE)
        losses = []
        accs = []
        g_losses = []
        g_accs = []

        # Validation accuracy before any training, used as the plot baseline.
        valaccs = []
        self.featextractor.train(False)
        valacc = F.validate_model(cfg, self.featextractor.classify_proba, X_test, y_test)
        self.featextractor.train(True)
        valaccs.append(valacc)

        g_valaccs = valaccs[:]

        self.featextractor.train(True)
        uniqpath = "HI" + str(int(time.time()))

        lastsave = ""
        for i in range(cfg.N_EPOCHS):
            # Sample a random contiguous batch and take one optimization step.
            batch_id = random.randint(0, len(X_train) - cfg.BATCH_SIZE)
            X_b = F.runtime_preprocess(self.config, X_train[batch_id: batch_id + cfg.BATCH_SIZE])
            ytrue = y_train[batch_id: batch_id + cfg.BATCH_SIZE]
            ypred = self.featextractor.classify_proba(X_b)
            loss = crit(ypred, ytrue.to(self.device).type(torch.long))
            loss.backward()
            opt.step()
            opt.zero_grad()

            # Track batch accuracy/loss and their smoothed running averages.
            allb = cfg.BATCH_SIZE
            s = torch.argmax(ypred.cpu(), dim=1) == ytrue.type(torch.long)
            accs.append(s.sum().item() / allb)
            losses.append(loss.item())
            g_accs.append(sum(accs[-cfg.SMOOTH_POWER:]) / len(accs[-cfg.SMOOTH_POWER:]))
            g_losses.append(sum(losses[-cfg.SMOOTH_POWER:]) / len(losses[-cfg.SMOOTH_POWER:]))

            if i % REDRAW_SIZE == REDRAW_SIZE - 1:
                clear_output(True)
                plt.figure(figsize=[24, 6.7])
                plt.subplot(1, 2, 1)
                plt.plot(g_losses[::cfg.PLOT_REDRAW_DENSE], label="loss")
                plt.legend()
                plt.subplot(1, 2, 2)
                plt.plot(g_accs[::cfg.PLOT_REDRAW_DENSE], label="train acc")
                plt.plot(g_valaccs[::cfg.PLOT_REDRAW_DENSE], label="val acc")
                plt.legend()
                plt.show()
                pr("acc:", round(sum(g_accs[-1 - REDRAW_SIZE: -1]) / REDRAW_SIZE, 3))
                pr("loss:", round(sum(g_losses[-1 - REDRAW_SIZE: -1]) / REDRAW_SIZE, 3))
                tm = time.time()
                tmgone = tm - lasttime
                lasttime = tm
                pr("last val acc:", round(g_valaccs[-1], 3))
                pr(round(REDRAW_SIZE / tmgone, 2), "epochs per second")
                pr(round(1000 * tmgone / REDRAW_SIZE, 2), "seconds for 1000 epochs")
                if lastsave != "":
                    pr("Last backup is saved to", lastsave)

            if cfg.BACKUP_DIRECTORY is not None and i % cfg.BACKUP_PERIOD == cfg.BACKUP_PERIOD - 1:
                p = cfg.BACKUP_DIRECTORY + uniqpath + "/model_" + str(i)
                self.saveto(p)
                lastsave = p

            # Re-validate every VAL_PERIOD iterations; otherwise carry the last value forward.
            if i % cfg.VAL_PERIOD == cfg.VAL_PERIOD - 1:
                self.featextractor.train(False)
                valacc = F.validate_model(cfg, self.featextractor.classify_proba, X_test, y_test)
                self.featextractor.train(True)
                valaccs.append(valacc)
            else:
                valaccs.append(valaccs[-1])
            g_valaccs.append(sum(valaccs[-cfg.SMOOTH_POWER:]) / len(valaccs[-cfg.SMOOTH_POWER:]))
        return True
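Not from the source: a hypothetical end-to-end sketch tying the methods above together. `Config`, `load_dataset` and `val_paths` are placeholders, and `identifier` stands for whatever class provides the fit/saveto/get_center methods shown in these examples.

config = Config()                                # assumed config object with the fields used above
identifier = ...                                 # the model/trainer class these methods belong to

X, y = load_dataset()                            # word images and integer writer labels
identifier.fit(X, y, verbose=True, plot=False)   # supervised pre-training of the backbone
identifier.saveto(config.MODEL_PATH)             # persist the backbone weights

# Estimate a same/different-writer decision threshold on held-out page paths.
threshold = validate(config, identifier, val_paths, PAX=16, verbose=True)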
Example #9
 def saveto(self, path=None):
     # Save the backbone weights to `path` (defaults to config.MODEL_PATH).
     if path is None:
         path = self.config.MODEL_PATH
     F.safe_path(path)
     torch.save(self.featextractor.i.state_dict(), path)