Exemplo n.º 1
0
    def fit(self, X, Y, constraints=None):
        """Learn ``self.w`` with an n-slack cutting-plane loop.

        Each outer iteration runs ``find_constraint`` for every training
        sample (in parallel via ``Parallel``/``delayed``), adds the returned
        constraint to that sample's working set when it is violated and not
        already present, and then re-solves the QP over all working sets
        with ``self._solve_n_slack_qp``. The loop stops when no new
        constraints are found (after the first iteration), when the primal
        objective changes by less than 1e-4, or after ``self.max_iter``
        iterations.

        Parameters
        ----------
        X : sequence
            Training inputs, passed unchanged to ``find_constraint``.
        Y : sequence
            Ground-truth labelings, one per element of ``X``.
        constraints : list of lists or None
            Optional initial working sets, one list per sample; each entry
            is ``[y_hat, delta_psi, loss]``. The lists are extended in
            place. When ``None``, an empty working set is created per
            sample.

        Side effects
        ------------
        Sets ``self.w`` (final weights), ``self.constraints_`` (working
        sets) and ``self.ws`` (weights after every QP solve), and resets
        ``self.alphas`` to an empty list (presumably filled by
        ``_solve_n_slack_qp`` -- verify). Plots learning curves when
        ``self.plot`` is true.
        """
        print("Training dual structural SVM")
        # we initialize with a small value so that loss-augmented inference
        # can give us something meaningful in the first iteration
        w = np.ones(self.problem.size_psi) * 1e-5
        n_samples = len(X)
        if constraints is None:
            constraints = [[] for _ in xrange(n_samples)]
        loss_curve = []
        objective_curve = []
        primal_objective_curve = []
        self.ws = []
        self.alphas = []  # dual solutions
        for iteration in xrange(self.max_iter):
            if self.verbose > 0:
                print("iteration %d" % iteration)
            new_constraints = 0
            current_loss = 0.
            # Most violated constraint for every sample, computed in parallel.
            candidate_constraints = (Parallel(n_jobs=self.n_jobs)
                                     (delayed(find_constraint)
                                      (self.problem, x, y, w)
                                      for x, y in zip(X, Y)))
            for i, (x, y, constraint) in enumerate(
                    zip(X, Y, candidate_constraints)):
                y_hat, delta_psi, slack, loss = constraint

                if self.break_on_bad:
                    # Debug-only sanity check: the slack returned by
                    # find_constraint should agree with the one recomputed
                    # on the loss-augmented input. Gated like the other
                    # tracer() call so normal training never enters the
                    # debugger or pays for the extra loss_augment call.
                    # NOTE(review): relies on GridCRF.psi, i.e. assumes a
                    # GridCRF-compatible problem.
                    x_loss_augmented = self.problem.loss_augment(x, y, w)
                    dpsi_ = (GridCRF.psi(self.problem, x_loss_augmented, y)
                             - GridCRF.psi(self.problem, x_loss_augmented,
                                           y_hat))
                    if np.abs(slack + min(0, np.dot(w, dpsi_))) > 0.01:
                        tracer()

                current_loss += loss

                if self.verbose > 1:
                    print("current slack: %f" % slack)
                # Skip labelings that are already in this sample's working
                # set (compared after unwrap_pairwise).
                y_hat_plain = unwrap_pairwise(y_hat)
                already_active = any(
                    (y_hat_plain == unwrap_pairwise(y_old)).all()
                    for y_old, _, _ in constraints[i])
                if already_active:
                    continue

                if self.check_constraints:
                    # "smart" stopping criterion
                    # check if most violated constraint is more violated
                    # than previous ones by more then eps.
                    # If it is less violated, inference was wrong/approximate
                    for con in constraints[i]:
                        # compute slack for old constraint
                        slack_tmp = max(con[2] - np.dot(w, con[1]), 0)
                        if self.verbose > 1:
                            print("slack old constraint: %f" % slack_tmp)
                        # if slack of new constraint is smaller or not
                        # significantly larger, don't add constraint.
                        # if smaller, complain about approximate inference.
                        if slack - slack_tmp < -1e-5:
                            print("bad inference: %f" % (slack_tmp - slack))
                            if self.break_on_bad:
                                tracer()
                            already_active = True
                            break

                # if significant slack and constraint not active
                # this is a weaker check than the "check_constraints" one.
                if not already_active and slack > 1e-5:
                    constraints[i].append([y_hat, delta_psi, loss])
                    new_constraints += 1
            current_loss /= n_samples
            loss_curve.append(current_loss)

            if new_constraints == 0:
                print("no additional constraints")
                if iteration > 0:
                    break
            w, objective = self._solve_n_slack_qp(constraints, n_samples)

            # hack to make loss-augmented prediction working:
            w[:self.problem.n_states][w[:self.problem.n_states] == 0] = 1e-10
            # Per-sample slack with respect to its working set. A sample
            # whose working set is still empty has zero slack; appending 0
            # also avoids np.max raising ValueError on an empty list.
            slacks = [max([loss_ - np.dot(w, psi_)
                           for _, psi_, loss_ in sample] + [0])
                      for sample in constraints]
            sum_of_slacks = np.sum(slacks)
            objective_p = (self.C * sum_of_slacks / n_samples
                           + np.sum(w ** 2) / 2.)
            primal_objective_curve.append(objective_p)
            # Working sets only grow, so the working-set primal objective is
            # expected not to decrease between iterations; a drop presumably
            # indicates an inaccurate QP solve.
            if (len(primal_objective_curve) > 1
                    and objective_p < primal_objective_curve[-2] - 1e-5):
                print("primal loss became smaller. that shouldn't happen.")
                if self.break_on_bad:
                    tracer()
            objective_curve.append(objective)
            if self.verbose > 0:
                print("current loss: %f  new constraints: %d, "
                      "primal objective: %f dual objective: %f" %
                      (current_loss, new_constraints,
                       primal_objective_curve[-1], objective))
            # Record every solved weight vector, including the one that
            # triggers the convergence stop below.
            self.ws.append(w)
            if (iteration > 1 and primal_objective_curve[-1] -
                    primal_objective_curve[-2] < 0.0001):
                print("objective converged.")
                break
            if self.verbose > 1:
                print(w)
        self.w = w
        self.constraints_ = constraints
        print("calls to inference: %d" % self.problem.inference_calls)
        if self.plot:
            plt.figure()
            plt.subplot(131, title="loss")
            plt.plot(loss_curve)
            plt.subplot(132, title="objective")
            plt.plot(objective_curve)
            plt.subplot(133, title="primal objective")
            plt.plot(primal_objective_curve)
            plt.show()
            plt.close()