# Imports used by this method. In the original file these live at module
# level; `find_constraint`, `unwrap_pairwise`, `GridCRF` and `tracer` are
# assumed to be defined or imported elsewhere in the same module.
import numpy as np
import matplotlib.pyplot as plt
from joblib import Parallel, delayed


def fit(self, X, Y, constraints=None):
    print("Training dual structural SVM")
    # we initialize with a small value so that loss-augmented inference
    # can give us something meaningful in the first iteration
    w = np.ones(self.problem.size_psi) * 1e-5
    n_samples = len(X)
    if constraints is None:
        constraints = [[] for _ in range(n_samples)]
    loss_curve = []
    objective_curve = []
    primal_objective_curve = []
    self.ws = []
    self.alphas = []  # dual solutions
    for iteration in range(self.max_iter):
        if self.verbose > 0:
            print("iteration %d" % iteration)
        new_constraints = 0
        current_loss = 0.
        # find the most violated constraint for every sample in parallel
        candidate_constraints = Parallel(n_jobs=self.n_jobs)(
            delayed(find_constraint)(self.problem, x, y, w)
            for x, y in zip(X, Y))
        for i, (x, y, constraint) in enumerate(zip(X, Y,
                                                   candidate_constraints)):
            y_hat, delta_psi, slack, loss = constraint
            # sanity check: the slack reported by find_constraint has to
            # agree with loss-augmented inference. We call GridCRF.psi
            # explicitly (bypassing possible subclass overrides) on the
            # loss-augmented input.
            x_loss_augmented = self.problem.loss_augment(x, y, w)
            dpsi_ = (GridCRF.psi(self.problem, x_loss_augmented, y)
                     - GridCRF.psi(self.problem, x_loss_augmented, y_hat))
            if np.abs(slack + min(0, np.dot(w, dpsi_))) > 0.01:
                tracer()  # drop into the debugger on disagreement
            current_loss += loss
            if self.verbose > 1:
                print("current slack: %f" % slack)
            y_hat_plain = unwrap_pairwise(y_hat)
            already_active = any(
                (y_hat_plain == unwrap_pairwise(y__)).all()
                for y__, _, _ in constraints[i])
            if already_active:
                continue
            if self.check_constraints:
                # "smart" stopping criterion: check whether the most
                # violated constraint is more violated than previous ones
                # by more than eps. If it is less violated, inference was
                # wrong/approximate.
                for con in constraints[i]:
                    # compute slack for old constraint
                    slack_tmp = max(con[2] - np.dot(w, con[1]), 0)
                    if self.verbose > 1:
                        print("slack old constraint: %f" % slack_tmp)
                    # if the slack of the new constraint is smaller or not
                    # significantly larger, don't add the constraint.
                    # If it is smaller, complain about approximate
                    # inference.
                    if slack - slack_tmp < -1e-5:
                        print("bad inference: %f" % (slack_tmp - slack))
                        if self.break_on_bad:
                            tracer()
                        already_active = True
                        break
            # add the constraint if it has significant slack and is not
            # active yet; this is a weaker check than "check_constraints".
            if not already_active and slack > 1e-5:
                constraints[i].append([y_hat, delta_psi, loss])
                new_constraints += 1
        current_loss /= len(X)
        loss_curve.append(current_loss)
        if new_constraints == 0:
            print("no additional constraints")
            if iteration > 0:
                break
        w, objective = self._solve_n_slack_qp(constraints, n_samples)
        # hack to make loss-augmented prediction work:
        # avoid exactly-zero unary weights
        w[:self.problem.n_states][w[:self.problem.n_states] == 0] = 1e-10
        # recompute the slack of every sample under the new w
        # (samples without constraints have zero slack)
        slacks = [max(np.max([loss_ - np.dot(w, dpsi)
                              for _, dpsi, loss_ in sample]), 0)
                  if sample else 0 for sample in constraints]
        sum_of_slacks = np.sum(slacks)
        objective_p = self.C * sum_of_slacks / len(X) + np.sum(w ** 2) / 2.
        primal_objective_curve.append(objective_p)
        if (len(primal_objective_curve) > 2
                and objective_p > primal_objective_curve[-2] + 1e-8):
            print("primal objective increased. that shouldn't happen.")
            tracer()
        objective_curve.append(objective)
        if self.verbose > 0:
            print("current loss: %f new constraints: %d, "
                  "primal objective: %f dual objective: %f"
                  % (current_loss, new_constraints,
                     primal_objective_curve[-1], objective))
        if (iteration > 1 and primal_objective_curve[-1]
                - primal_objective_curve[-2] < 0.0001):
            print("objective converged.")
            break
        self.ws.append(w)
        if self.verbose > 1:
            print(w)
    self.w = w
    self.constraints_ = constraints
    print("calls to inference: %d" % self.problem.inference_calls)
    if self.plot:
        plt.figure()
        plt.subplot(131, title="loss")
        plt.plot(loss_curve)
        plt.subplot(132, title="objective")
        plt.plot(objective_curve)
        plt.subplot(133, title="primal objective")
        plt.plot(primal_objective_curve)
        plt.show()
        plt.close()
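# ----------------------------------------------------------------------
# Usage sketch (illustrative, not from the original file). It assumes
# this fit() is a method of a pystruct-style n-slack cutting-plane
# learner; the class name `StructuredSVM` and the exact constructor
# arguments below are assumptions, not verified against this repo.
#
#   crf = GridCRF(n_states=2)
#   ssvm = StructuredSVM(problem=crf, max_iter=100, C=100,
#                        check_constraints=True, verbose=1, n_jobs=1)
#   ssvm.fit(X, Y)  # X: list of inputs, Y: list of ground-truth labelings
#   y_pred = [crf.inference(x, ssvm.w) for x in X]
# ----------------------------------------------------------------------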