Example #1
0
class MCSFinder(ControlFlow):
  def __init__(self, simulation_cfg, superlog_path_or_dag,
               invariant_check=InvariantChecker.check_correspondence,
               transform_dag=None, end_wait_seconds=0.5,
               mcs_trace_path=None, extra_log=None, dump_runtime_stats=True,
               **kwargs):
    ''' Set up a minimal-causal-sequence (MCS) finder over a recorded trace.

    - simulation_cfg: forwarded to ControlFlow.__init__
    - superlog_path_or_dag: either a string path to a superlog, which is
      parsed into an EventDag, or an already constructed dag object
    - invariant_check: callable applied to the simulation after each replay;
      an empty list result is treated as "no violation"
    - transform_dag: optional callable applied to every dag before replay
    - end_wait_seconds: how long to sleep after a replay before checking
      invariants
    - mcs_trace_path: if not None, the final MCS is dumped to this path
    - extra_log: optional object with a write() method that receives a copy
      of every progress message
    - dump_runtime_stats: whether to collect and dump runtime statistics
    - kwargs: passed through unchanged to every Replayer
    '''
    super(MCSFinder, self).__init__(simulation_cfg)
    self.sync_callback = None
    self._log = logging.getLogger("mcs_finder")

    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses as paths
    if isinstance(superlog_path_or_dag, str):
      # The dag is codified as a list, where each element has
      # a list of its dependents
      self.dag = EventDag(superlog_parser.parse_path(superlog_path_or_dag))
    else:
      self.dag = superlog_path_or_dag

    self.invariant_check = invariant_check
    self.transform_dag = transform_dag
    self.mcs_trace_path = mcs_trace_path
    self._extra_log = extra_log
    self.kwargs = kwargs
    self.end_wait_seconds = end_wait_seconds
    # None disables all runtime stat tracking; a dict enables it
    self._runtime_stats = {} if dump_runtime_stats else None

  def log(self, msg):
    ''' Output a message to both self._log and self._extra_log '''
    self._log.info(msg)
    if self._extra_log is not None:
      self._extra_log.write(msg + '\n')

  def simulate(self, check_reproducability=True):
    ''' Prune self.dag with delta debugging and return the events of the
    resulting dag.

    When check_reproducability is set, the unpruned dag is replayed once
    first; the process exits with status 5 if that replay reports no
    violations. '''
    def stamp(key):
      # Epoch timestamps are only recorded when runtime stats are enabled
      if self._runtime_stats is not None:
        self._runtime_stats[key] = time.time()

    if self._runtime_stats is not None:
      self._runtime_stats["total_inputs"] = len(self.dag.input_events)
      self._runtime_stats["total_events"] = len(self.dag)

    # inject domain knowledge into the dag
    self.dag.mark_invalid_input_sequences()
    self.dag = self.dag.filter_unsupported_input_types()

    if len(self.dag) == 0:
      raise RuntimeError("No supported input types?")

    if check_reproducability:
      # First, run through without pruning to verify that the violation exists
      stamp("replay_start_epoch")
      violations = self.replay(self.dag)
      stamp("replay_end_epoch")
      if violations == []:
        self.log("Unable to reproduce correctness violation!")
        sys.exit(5)

      self.log("Violation reproduced successfully! Proceeding with pruning")

    stamp("prune_start_epoch")
    self._ddmin(2)
    stamp("prune_end_epoch")
    if self._runtime_stats is not None:
      self._dump_runtime_stats()

    final_inputs = self.dag.input_events
    msg.interactive("Final MCS (%d elements): %s" %
                    (len(final_inputs), str(final_inputs)))
    if self.mcs_trace_path is not None:
      self._dump_mcs_trace()
    return self.dag.events

  def _ddmin(self, split_ways, precomputed_subsets=None, iteration=0):
    ''' Recursively shrink self.dag using delta debugging.

    - split_ways: number of chunks the current input events are split into
    - precomputed_subsets: set of input-event tuples that have already been
      considered; shared across the recursion so nothing is replayed twice
    - iteration is the # of times we've replayed (not the number of times
      we've invoked _ddmin)

    Side effect: self.dag is replaced by the smaller dag every time a subset
    or complement still triggers a violation. Returns None.'''
    # This is the delta-debugging algorithm from:
    #   http://www.st.cs.uni-saarland.de/papers/tse2002/tse2002.pdf,
    # Section 3.2
    # TODO(cs): we could do much better if we leverage domain knowledge (e.g.,
    # start by pruning all LinkFailures)
    # Base case: cannot split into more chunks than there are input events
    if split_ways > len(self.dag.input_events):
      self._track_iteration_size(iteration + 1, split_ways)
      self.log("Done")
      return

    if precomputed_subsets is None:
      precomputed_subsets = set()

    self.log("Checking %d subsets" % split_ways)
    subsets = split_list(self.dag.input_events, split_ways)
    self.log("Subsets: %s" % str(subsets))
    # Phase 1: does any single chunk reproduce the violation on its own?
    for i, subset in enumerate(subsets):
      new_dag = self.dag.input_subset(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current subset: %s" % str(input_sequence))
      if input_sequence in precomputed_subsets:
        self.log("Already computed. Skipping")
        continue
      precomputed_subsets.add(input_sequence)
      if input_sequence == ():
        self.log("Subset after pruning dependencies was empty. Skipping")
        continue

      # Only sequences that are actually replayed count as an iteration
      iteration += 1
      violation = self._check_violation(new_dag, i, iteration, split_ways)
      if violation:
        # Reduce to the chunk and restart at the coarsest granularity
        self.dag = new_dag
        return self._ddmin(2, precomputed_subsets=precomputed_subsets,
                           iteration=iteration)

    self.log("No subsets with violations. Checking complements")
    # Phase 2: does removing any single chunk still reproduce the violation?
    for i, subset in enumerate(subsets):
      new_dag = self.dag.input_complement(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current complement: %s" % str(input_sequence))
      if input_sequence in precomputed_subsets:
        self.log("Already computed. Skipping")
        continue
      precomputed_subsets.add(input_sequence)
      if input_sequence == ():
        self.log("Subset after pruning dependencies was empty. Skipping")
        continue

      iteration += 1
      violation = self._check_violation(new_dag, i, iteration, split_ways)
      if violation:
        # Reduce to the complement; one chunk is gone, so use one fewer split
        # (but never fewer than 2)
        self.dag = new_dag
        return self._ddmin(max(split_ways - 1, 2),
                           precomputed_subsets=precomputed_subsets,
                           iteration=iteration)

    self.log("No complements with violations.")
    # Phase 3: nothing could be removed; retry with finer chunks if possible
    if split_ways < len(self.dag.input_events):
      self.log("Increasing granularity.")
      return self._ddmin(min(len(self.dag.input_events), split_ways*2),
                         precomputed_subsets=precomputed_subsets,
                         iteration=iteration)
    # Already at the finest granularity: record the final size and stop
    self._track_iteration_size(iteration + 1, split_ways)

  def _track_iteration_size(self, iteration, split_ways):
    if self._runtime_stats is not None:
      if "iteration_size" not in self._runtime_stats:
        self._runtime_stats["iteration_size"] = {}
      self._runtime_stats["iteration_size"][iteration] = len(self.dag.input_events)
      if "split_ways" not in self._runtime_stats:
        self._runtime_stats["split_ways"] = set()
        self._runtime_stats["split_ways"].add(split_ways)

  def _check_violation(self, new_dag, subset_index, iteration, split_ways):
    ''' Check if there were violations '''
    self._track_iteration_size(iteration, split_ways)
    violations = self.replay(new_dag)
    if violations == []:
      # No violation!
      # If singleton, this must be part of the MCS
      self.log("No violation in %d'th..." % subset_index)
      return False
    else:
      # Violation in the subset
      self.log("Violation! Considering %d'th" % subset_index)
      return True

  def replay(self, new_dag):
    ''' Replay new_dag in a new Replayer and return whatever
    self.invariant_check reports for the resulting simulation.

    new_dag is first passed through self.transform_dag when one is set. The
    simulation is cleaned up even if the wait or the invariant check raises.
    '''
    # Run the simulation forward
    if self.transform_dag:
      new_dag = self.transform_dag(new_dag)

    # TODO(aw): MCSFinder needs to configure Simulation to always let DataplaneEvents pass through
    replayer = Replayer(self.simulation_cfg, new_dag, **self.kwargs)
    simulation = replayer.simulate()
    try:
      # Wait a bit in case the bug takes awhile to happen
      time.sleep(self.end_wait_seconds)
      return self.invariant_check(simulation)
    finally:
      # Previously skipped whenever the sleep or the check raised, leaving the
      # simulation's resources behind
      simulation.clean_up()

  def _dump_mcs_trace(self):
    ''' Log every event of the current dag through an InputLogger created
    with output_path=self.mcs_trace_path, then close that logger '''
    trace_logger = InputLogger(output_path=self.mcs_trace_path)
    for event in self.dag.events:
      trace_logger.log_input_event(event)
    trace_logger.close(self.simulation_cfg, skip_mcs_cfg=True)

  def _dump_runtime_stats(self):
    ''' Add derived statistics to self._runtime_stats and write the whole dict
    as JSON to runtime_stats/<timestamp>.json (relative to the current
    working directory).

    No-op when runtime stats are disabled (self._runtime_stats is None). '''
    stats = self._runtime_stats
    if stats is None:
      return
    # First compute durations
    if "replay_end_epoch" in stats:
      stats["replay_duration_seconds"] = (stats["replay_end_epoch"] -
                                          stats["replay_start_epoch"])
    if "prune_end_epoch" in stats:
      stats["prune_duration_seconds"] = (stats["prune_end_epoch"] -
                                         stats["prune_start_epoch"])
    stats["total_replays"] = Replayer.total_replays
    stats["total_inputs_replayed"] = Replayer.total_inputs_replayed
    stats["config"] = str(self.simulation_cfg)
    stats["peeker"] = self.transform_dag is not None
    # json cannot serialize a set; tolerate the key never having been recorded
    stats["split_ways"] = list(stats.get("split_ways", ()))
    # Now write contents to a file
    now = timestamp_string()
    # open() instead of file(): the latter only exists in Python 2
    with open("runtime_stats/" + now + ".json", "w") as output:
      output.write(json.dumps(stats))
Example #2
0
class MCSFinder(ControlFlow):
  def __init__(self, simulation_cfg, superlog_path_or_dag,
               invariant_check_name=None,
               transform_dag=None, end_wait_seconds=0.5,
               mcs_trace_path=None, extra_log=None, runtime_stats_path=None,
               wait_on_deterministic_values=False,
               no_violation_verification_runs=1,
               optimized_filtering=False, forker=LocalForker(),
               replay_final_trace=True, strict_assertion_checking=False,
               **kwargs):
    ''' Set up a minimal-causal-sequence (MCS) finder over a recorded trace.

    - simulation_cfg: forwarded to ControlFlow.__init__
    - superlog_path_or_dag: either a string path to a superlog, which is
      parsed into an EventDag, or an already constructed dag object
    - invariant_check_name: required; key into name_to_invariant_check
    - transform_dag: optional callable applied to every dag before replay
    - end_wait_seconds: how long to sleep after a replay before checking
      invariants
    - mcs_trace_path, extra_log, runtime_stats_path: output locations; those
      left as None are filled in by init_results()
    - wait_on_deterministic_values: passed through to every Replayer
    - no_violation_verification_runs: how many times a replay is attempted
      before concluding that the bug does not show up
    - optimized_filtering: whether to try pruning whole event types before
      delta debugging
    - forker: object used to run each replay as a registered task
    - replay_final_trace: whether to replay the final MCS once more
    - strict_assertion_checking: whether to run test_serialize_response on
      each replay result
    - kwargs: passed through unchanged to every Replayer

    Raises ValueError if the invariant check name is missing or unknown, or
    if the dag contains no invariant violation. Prompts interactively when
    the last invariant violation carries more than one violation.

    NOTE(review): the forker default is evaluated once at definition time, so
    all instances relying on the default share one LocalForker -- confirm
    that this is intended.
    '''
    super(MCSFinder, self).__init__(simulation_cfg)
    # number of subsequences delta debugging has examined so far, for
    # distinguishing runtime stats from different intermediate runs.
    self.subsequence_id = 0
    self.mcs_log_tracker = None
    self.replay_log_tracker = None
    self.mcs_trace_path = mcs_trace_path
    self.sync_callback = None
    self._log = logging.getLogger("mcs_finder")

    if invariant_check_name is None:
      raise ValueError("Must specify invariant check")
    if invariant_check_name not in name_to_invariant_check:
      # Format with % here: passing the name as a second argument to
      # ValueError left the "%s" placeholder unformatted in the message
      raise ValueError('''Unknown invariant check %s.\n'''
                       '''Invariant check name must be defined in config.invariant_checks''' %
                       invariant_check_name)
    self.invariant_check = name_to_invariant_check[invariant_check_name]

    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses as paths
    if isinstance(superlog_path_or_dag, str):
      self.superlog_path = superlog_path_or_dag
      # The dag is codified as a list, where each element has
      # a list of its dependents
      self.dag = EventDag(log_parser.parse_path(self.superlog_path))
    else:
      self.dag = superlog_path_or_dag

    last_invariant_violation = self.dag.get_last_invariant_violation()
    if last_invariant_violation is None:
      raise ValueError("No invariant violation found in dag...")
    violations = last_invariant_violation.violations
    if len(violations) > 1:
      # Ask the user which violation to minimize for until the answer is valid
      self.bug_signature = None
      while self.bug_signature is None:
        msg.interactive("\n------------------------------------------\n")
        msg.interactive("Multiple violations detected! Choose one for MCS Finding:")
        for i, violation in enumerate(violations):
          msg.interactive("  [%d] %s" % (i+1, violation))
        violation_index = msg.raw_input("> ")
        # Raw string so that \d is a regex digit class, not a string escape
        if re.match(r"^\d+$", violation_index) is None or\
           int(violation_index) < 1 or\
           int(violation_index) > len(violations):
          msg.fail("Must provide an integer between 1 and %d!" % len(violations))
          continue
        self.bug_signature = violations[int(violation_index)-1]
    else:
      self.bug_signature = violations[0]
    msg.success("\nBug signature to match is %s. Proceeding with MCS finding!\n" % self.bug_signature)

    self.transform_dag = transform_dag
    # A second log with just our MCS progress log messages
    self._extra_log = extra_log
    self.kwargs = kwargs
    self.end_wait_seconds = end_wait_seconds
    self.wait_on_deterministic_values = wait_on_deterministic_values
    # `no' means "number"
    self.no_violation_verification_runs = no_violation_verification_runs
    self._runtime_stats = RuntimeStats(self.subsequence_id, runtime_stats_path=runtime_stats_path)
    # Whether to try alternate trace splitting techniques besides splitting by time.
    self.optimized_filtering = optimized_filtering
    self.forker = forker
    self.replay_final_trace = replay_final_trace
    self.strict_assertion_checking = strict_assertion_checking

  def log(self, s):
    ''' Report s through msg.mcs_event and, when an extra log is set, append
    it there as its own line and flush immediately '''
    msg.mcs_event(s)
    extra_log = self._extra_log
    if extra_log is None:
      return
    extra_log.write(s + '\n')
    extra_log.flush()

  def log_violation(self, s):
    ''' Like log(), but the message handed to msg.mcs_event is prefixed with
    color.RED; the extra log receives the uncolored text '''
    msg.mcs_event(color.RED + s)
    extra_log = self._extra_log
    if extra_log is None:
      return
    extra_log.write(s + '\n')
    extra_log.flush()

  def log_no_violation(self, s):
    ''' Like log(), but the message handed to msg.mcs_event is prefixed with
    color.GREEN; the extra log receives the uncolored text '''
    msg.mcs_event(color.GREEN + s)
    extra_log = self._extra_log
    if extra_log is None:
      return
    extra_log.write(s + '\n')
    extra_log.flush()

  def init_results(self, results_dir):
    ''' Precondition: results_dir exists, and is clean (preferably
    initialized by experiments/setup.py).

    Fills in every output location that was not given explicitly (extra log,
    runtime stats path, MCS trace path) with a file below results_dir, and
    creates the MCS and replay log trackers. '''
    if self._extra_log is None:
      self._extra_log = open("%s/mcs_finder.log" % results_dir, "w")

    stats = self._runtime_stats
    if stats.get_runtime_stats_path() is None:
      stats.set_runtime_stats_path("%s/runtime_stats.json" % results_dir)

    if self.mcs_trace_path is None:
      self.mcs_trace_path = "%s/mcs.trace" % results_dir

    # TODO(cs): assumes that transform dag is a peeker, not some other
    # transformer
    has_peeker = self.transform_dag is not None
    self.mcs_log_tracker = MCSLogTracker(results_dir, self.mcs_trace_path,
                                         stats, self.simulation_cfg,
                                         has_peeker)
    self.replay_log_tracker = ReplayLogTracker(results_dir)

  # N.B. only called in the parent process.
  def simulate(self, check_reproducability=True):
    ''' Run delta debugging over self.dag and store the pruned dag back in
    self.dag.

    Precondition: init_results() has been called. self.mcs_log_tracker and
    self.replay_log_tracker are None until then and are used unconditionally
    here and in replay().

    When check_reproducability is set, the unpruned dag is replayed up to
    no_violation_verification_runs times first; the process exits with status
    5 if the bug never shows up.

    Returns ExitCode(0). Raises RuntimeError if no input events remain after
    filtering unsupported input types. '''
    self._runtime_stats.set_dag_stats(self.dag)

    # apply domain knowledge: treat failure/recovery pairs atomically, and
    # filter event types we don't want to include in the MCS
    # (e.g. CheckInvariants)
    self.dag.mark_invalid_input_sequences()
    self.dag = self.dag.filter_unsupported_input_types()

    if len(self.dag) == 0:
      raise RuntimeError("No supported input types?")

    if check_reproducability:
      # First, run through without pruning to verify that the violation exists
      self._runtime_stats.record_replay_start()

      # Bind both names up front: with zero verification runs the loop body
      # never executes, and the code below used to die with an unbound local
      # instead of reporting that the bug could not be reproduced
      bug_found = False
      i = 0
      for i in range(0, self.no_violation_verification_runs):
        bug_found = self.replay(self.dag, "reproducibility",
                                ignore_runtime_stats=True)
        if bug_found:
          break
      self._runtime_stats.set_initial_verification_runs_needed(i)
      self._runtime_stats.record_replay_end()
      if not bug_found:
        msg.fail("Unable to reproduce correctness violation!")
        sys.exit(5)
      self.log("Violation reproduced successfully! Proceeding with pruning")

    self._runtime_stats.record_prune_start()

    # Run optimizations.
    # TODO(cs): Better than a boolean flag: check if
    # log(len(self.dag)) > number of input types to try
    if self.optimized_filtering:
      self._optimize_event_dag()
    precompute_cache = PrecomputeCache()

    # Invoke delta debugging
    (dag, total_inputs_pruned) = self._ddmin(self.dag, 2, precompute_cache=precompute_cache)
    # Make sure to track the final iteration size
    # (must happen before self.dag is replaced: the size is computed relative
    # to the dag that delta debugging started from)
    self._track_iteration_size(total_inputs_pruned)
    self.dag = dag

    self.log("=== Total replays: %d ===" % self._runtime_stats.total_replays)
    self._runtime_stats.record_prune_end()
    self.mcs_log_tracker.dump_runtime_stats()
    self.log("Final MCS (%d elements):" % len(self.dag.input_events))
    for i in self.dag.input_events:
      self.log(" - %s" % str(i))

    if self.replay_final_trace:
      #  Replaying the final trace achieves two goals:
      #  - verifies that the MCS indeed ends in the violation
      #  - allows us to prune internal events that time out
      bug_found = self.replay(self.dag, "final_mcs_trace",
                              ignore_runtime_stats=True)
      if not bug_found:
        self.log('''Warning: MCS did not result in violation. Trying replay '''
                 '''again without timed out events.''')
        # TODO(cs): always replay the MCS without timeouts, since the console
        # output will be significantly cleaner?
        no_timeouts = self.dag.filter_timeouts()
        bug_found = self.replay(no_timeouts, "final_mcs_no_timed_out_events",
                                ignore_runtime_stats=True)
        if not bug_found:
          self.log('''Warning! Final MCS did not result in violation, even '''
                   ''' after ignoring timed out internal events. '''
                   ''' See tools/visualization/visualize1D.html for debugging''')

    # N.B. dumping the MCS trace must occur after the final replay trace,
    # since we need to infer which events will time out for events.trace.notimeouts
    if self.mcs_trace_path is not None:
      self.mcs_log_tracker.dump_mcs_trace(self.dag, self)
    return ExitCode(0)

  # N.B. always called within a child process.
  def _ddmin(self, dag, split_ways, precompute_cache=None, label_prefix=(),
             total_inputs_pruned=0):
    ''' Recursively shrink dag using delta debugging.

    - dag: the dag to minimize at this level of the recursion
    - split_ways: number of chunks dag.input_events is split into
    - precompute_cache: remembers input sequences that were already
      considered; a new PrecomputeCache is created when omitted
    - label_prefix: tuple of labels of the enclosing recursion levels, used
      to name subsets in logs and replay result directories
    - total_inputs_pruned: number of input events removed so far

    Returns a tuple (minimized dag, total number of input events pruned).
    Unlike the dag argument, self.dag is not modified here. '''
    # This is the delta-debugging algorithm from:
    #   http://www.st.cs.uni-saarland.de/papers/tse2002/tse2002.pdf,
    # Section 3.2
    # TODO(cs): we could do much better if we leverage domain knowledge (e.g.,
    # start by pruning all LinkFailures, or splitting by nodes rather than
    # time)
    if split_ways > len(dag.input_events):
      self.log("Done")
      return (dag, total_inputs_pruned)

    if precompute_cache is None:
      # The documented default used to crash on the first already_done() call
      precompute_cache = PrecomputeCache()

    def local_label(i, inv=False):
      # e.g. "1/4" for a subset, "~1/4" for its complement
      return "%s%d/%d" % ("~" if inv else "", i, split_ways)

    def subset_label(label):
      # Fully qualified label: all enclosing labels joined by dots
      return ".".join(map(str, label_prefix + ( label, )))

    def print_subset(label, s):
      return subset_label(label) + ": "+" ".join(map(lambda e: e.label, s))

    subsets = split_list(dag.input_events, split_ways)
    self.log("Subsets:\n"+"\n".join(print_subset(local_label(i), s) for i, s in enumerate(subsets)))
    # Phase 1: does any single chunk reproduce the violation on its own?
    for i, subset in enumerate(subsets):
      label = local_label(i)
      new_dag = dag.input_subset(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current subset: %s" % print_subset(label, input_sequence))
      if precompute_cache.already_done(input_sequence):
        self.log("Already computed. Skipping")
        continue
      precompute_cache.update(input_sequence)
      if input_sequence == ():
        self.log("Subset after pruning dependencies was empty. Skipping")
        continue

      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(new_dag, i, label)
      if violation:
        self.log_violation("Subset %s reproduced violation. Subselecting." % subset_label(label))
        self.mcs_log_tracker.maybe_dump_intermediate_mcs(new_dag,
                                                         subset_label(label), self)

        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        # Reduce to the chunk and restart at the coarsest granularity
        return self._ddmin(new_dag, 2, precompute_cache=precompute_cache,
                           label_prefix = label_prefix + (label, ),
                           total_inputs_pruned=total_inputs_pruned)

    self.log_no_violation("No subsets with violations. Checking complements")
    # Phase 2: does removing any single chunk still reproduce the violation?
    for i, subset in enumerate(subsets):
      label = local_label(i, True)
      prefix = label_prefix + (label, )
      new_dag = dag.input_complement(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current complement: %s" % print_subset(label, input_sequence))
      if precompute_cache.already_done(input_sequence):
        self.log("Already computed. Skipping")
        continue
      precompute_cache.update(input_sequence)

      if input_sequence == ():
        # log() takes a single string: format here instead of passing the
        # label as a second positional argument, which raised TypeError
        self.log("Subset %s after pruning dependencies was empty. Skipping" % subset_label(label))
        continue

      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(new_dag, i, label)
      if violation:
        self.log_violation("Subset %s reproduced violation. Subselecting." % subset_label(label))
        self.mcs_log_tracker.maybe_dump_intermediate_mcs(new_dag,
                                                         subset_label(label), self)
        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        # One chunk is gone, so use one fewer split (but never fewer than 2)
        return self._ddmin(new_dag, max(split_ways - 1, 2),
                           precompute_cache=precompute_cache,
                           label_prefix=prefix,
                           total_inputs_pruned=total_inputs_pruned)

    self.log_no_violation("No complements with violations.")
    # Phase 3: nothing could be removed; retry with finer chunks if possible
    if split_ways < len(dag.input_events):
      self.log("Increasing granularity.")
      return self._ddmin(dag, min(len(dag.input_events), split_ways*2),
                         precompute_cache=precompute_cache,
                         label_prefix=label_prefix,
                         total_inputs_pruned=total_inputs_pruned)
    return (dag, total_inputs_pruned)

  # N.B. always called within a child process.
  def _track_iteration_size(self, total_inputs_pruned):
    self._runtime_stats.record_iteration_size(len(self.dag.input_events) - total_inputs_pruned)

  # N.B. always called within a child process.
  def _check_violation(self, new_dag, subset_index, label):
    ''' Check if there were violations '''
    # Try no_violation_verification_runs times to see if the bug shows up
    for i in range(0, self.no_violation_verification_runs):
      bug_found = self.replay(new_dag, label)

      if bug_found:
        # Violation in the subset
        self.log_violation("Violation! Considering %d'th" % subset_index)
        self._runtime_stats.record_violation_found(i)
        return True

    # No violation!
    self.log_no_violation("No violation in %d'th..." % subset_index)
    return False

  def replay(self, new_dag, label, ignore_runtime_stats=False):
    ''' Replay new_dag through self.forker and report whether the bug
    signature showed up.

    - new_dag: dag to replay; passed through self.transform_dag first when
      one is set
    - label: names the results directory requested from
      self.replay_log_tracker
    - ignore_runtime_stats: when True, the runtime stats returned by the
      replay task are not merged into self._runtime_stats

    Precondition: init_results() has been called (self.replay_log_tracker is
    None until then).

    Returns True iff the replay's violations contain self.bug_signature.
    Side effect: events that timed out during the replay are marked as such
    on new_dag. '''
    # Run the simulation forward
    if self.transform_dag:
      new_dag = self.transform_dag(new_dag)

    self._runtime_stats.record_replay_stats(len(new_dag.input_events))

    # N.B. this function is run as a child process.
    def play_forward(results_dir, subsequence_id):
      # TODO(cs): need to serialize the parameters to Replayer rather than
      # wrapping them in a closure... otherwise, can't use RemoteForker
      # TODO(aw): MCSFinder needs to configure Simulation to always let DataplaneEvents pass through
      create_clean_python_dir(results_dir)

      # Copy stdout and stderr to a file "replay.out"
      tee = Tee(open(os.path.join(results_dir, "replay.out"), "w"))
      tee.tee_stdout()
      tee.tee_stderr()

      # Set up replayer.
      input_logger = InputLogger()
      replayer = Replayer(self.simulation_cfg, new_dag,
                          wait_on_deterministic_values=self.wait_on_deterministic_values,
                          input_logger=input_logger,
                          allow_unexpected_messages=False,
                          pass_through_whitelisted_messages=True,
                          **self.kwargs)
      replayer.init_results(results_dir)
      # Start from empty stats for this run; what is collected here is handed
      # back to the caller through client_dict() below
      self._runtime_stats = RuntimeStats(subsequence_id)
      violations = []
      simulation = None
      try:
        simulation = replayer.simulate()
        self._track_new_internal_events(simulation, replayer)
        # Wait a bit in case the bug takes awhile to happen
        # (%s rather than %d: end_wait_seconds may be fractional, and %d
        # logged the default of 0.5 as "0")
        self.log("Sleeping %s seconds after run"  % self.end_wait_seconds)
        time.sleep(self.end_wait_seconds)
        violations = self.invariant_check(simulation)
        if violations != []:
          input_logger.log_input_event(InvariantViolation(violations))
      except SystemExit:
        # One of the invariant checks bailed early. Oddly, this is not an
        # error for us, it just means that there were no violations...
        # [this logic is arguably broken]
        # Return no violations, and let Forker handle system exit for us.
        violations = []
      finally:
        input_logger.close(replayer, self.simulation_cfg, skip_mcs_cfg=True)
        if simulation is not None:
          simulation.clean_up()
        tee.close()
      if self.strict_assertion_checking:
        test_serialize_response(violations, self._runtime_stats.client_dict())
      timed_out_internal = [ e.label for e in new_dag.events if e.timed_out ]
      return (violations, self._runtime_stats.client_dict(), timed_out_internal)

    # TODO(cs): once play_forward() is no longer a closure, register it only once
    self.forker.register_task("play_forward", play_forward)
    results_dir = self.replay_log_tracker.get_replay_logger_dir(label)
    self.subsequence_id += 1
    (violations, client_runtime_stats,
                 timed_out_internal) = self.forker.fork("play_forward",
                                                        results_dir,
                                                        self.subsequence_id)
    # Timeouts are reported back by label and applied to this side's dag
    new_dag.set_events_as_timed_out(timed_out_internal)

    bug_found = False
    if violations != []:
      msg.fail("Violations: %s" % str(violations))
      # Only the violation chosen at construction time counts as the bug
      if self.bug_signature in violations:
        bug_found = True
      else:
        msg.fail("Bug does not match initial violation fingerprint!")
    else:
      msg.interactive("No correctness violations!")

    if not ignore_runtime_stats:
      self._runtime_stats.merge_client_dict(client_runtime_stats)

    return bug_found

  def _optimize_event_dag(self):
    ''' Uses domain knowledge about event classes to shrink self.dag before
    delta debugging: for each known input event type, replay a candidate dag
    and adopt it as self.dag whenever the bug still shows up. '''
    # TODO(cs): Another approach for later: split by nodes
    candidate_types = (TrafficInjection, DataplaneDrop, SwitchFailure,
                       SwitchRecovery, LinkFailure, LinkRecovery,
                       HostMigration, ControllerFailure, ControllerRecovery,
                       PolicyChange, ControlChannelBlock,
                       ControlChannelUnblock)
    for event_type in candidate_types:
      current_inputs = self.dag.input_events
      # Input events that are NOT of the type under consideration
      others = [ev for ev in current_inputs if not isinstance(ev, event_type)]
      if len(others) == len(current_inputs):
        # The trace holds no event of this type, so there is nothing to try
        self.log("\t** No events pruned for event type %s. Next!" % event_type)
        continue
      # NOTE(review): elsewhere input_complement(subset) is used to drop
      # `subset`; if so this keeps only events of event_type rather than
      # pruning them -- confirm against EventDag.input_complement
      candidate_dag = self.dag.input_complement(others)
      if self.replay(candidate_dag, "opt_%s" % event_type.__name__):
        self.log("\t** VIOLATION for pruning event type %s! Resizing original dag" % event_type)
        self.dag = candidate_dag

  # N.B. always called within a child process.
  def _track_new_internal_events(self, simulation, replayer):
    ''' Pre: simulation must have been run through a replay

    Records in self._runtime_stats which message receipts are still buffered
    after the replay but were not already buffered at the end of the original
    run, plus the replayer's unexpected, early, timed out and matched events.

    Nothing is recorded (a warning is logged instead) when the superlog path
    is unknown, or when the ".unacked" file next to it is missing or cannot
    be parsed. '''
    # self.superlog_path is only set when the finder was constructed from a
    # path; a finder built from a dag object used to die here with an
    # AttributeError
    superlog_path = getattr(self, "superlog_path", None)
    if superlog_path is None:
      log.warn("superlog path unknown; cannot locate unacked internal events file")
      return
    # We always check against internal events that were buffered at the end of
    # the original run (don't want to overcount)
    path = superlog_path + ".unacked"
    if not os.path.exists(path):
      log.warn("unacked internal events file from original run does not exist")
      return
    try:
      unacked_events = EventDag(log_parser.parse_path(path)).events
      prev_buffered_receives = set(ev.pending_receive for ev in unacked_events
                                   if type(ev) == ControlMessageReceive)
    except ValueError as e:
      log.warn("unacked internal events is corrupt? %r" % e)
      return

    buffered_message_receipts = []
    for p in simulation.openflow_buffer.pending_receives():
      if p not in prev_buffered_receives:
        buffered_message_receipts.append(repr(p))
      else:
        # Each previously buffered receive excuses only one pending receive
        prev_buffered_receives.remove(p)

    self._runtime_stats.record_buffered_message_receipts(buffered_message_receipts)
    new_internal_events = replayer.unexpected_state_changes + replayer.passed_unexpected_messages
    self._runtime_stats.record_new_internal_events(new_internal_events)
    self._runtime_stats.record_early_internal_events(replayer.early_state_changes)
    self._runtime_stats.record_timed_out_events(dict(replayer.event_scheduler_stats.event2timeouts))
    self._runtime_stats.record_matched_events(dict(replayer.event_scheduler_stats.event2matched))
Example #3
0
class MCSFinder(ControlFlow):
  def __init__(self, simulation_cfg, superlog_path_or_dag,
               invariant_check_name=None,
               transform_dag=None, end_wait_seconds=0.5,
               mcs_trace_path=None, extra_log=None, runtime_stats_file=None,
               wait_on_deterministic_values=False,
               no_violation_verification_runs=1,
               optimized_filtering=False,
               **kwargs):
    ''' Set up a minimal-causal-sequence finder.

    simulation_cfg: passed to the ControlFlow constructor and to each Replayer
    superlog_path_or_dag: either a str path to a superlog (parsed into an
      EventDag) or an already constructed dag. self.superlog_path is only
      assigned in the former case.
    invariant_check_name: key into name_to_invariant_check; required
    transform_dag: optional callable applied to a dag before each replay
    end_wait_seconds: time to sleep after each replay before checking the
      invariant
    mcs_trace_path: where the final trace is dumped (may be filled in later
      by init_results)
    extra_log: optional file-like object that log messages are also written to
    runtime_stats_file: handed to RuntimeStats
    wait_on_deterministic_values: forwarded to Replayer
    no_violation_verification_runs: number of replays attempted per candidate
      before concluding there is no violation
    optimized_filtering: if True, simulate() first calls _optimize_event_dag()
    kwargs: forwarded verbatim to Replayer

    Raises ValueError if invariant_check_name is missing or unknown.
    '''
    super(MCSFinder, self).__init__(simulation_cfg)
    self.sync_callback = None
    self._log = logging.getLogger("mcs_finder")

    if invariant_check_name is None:
      raise ValueError("Must specify invariant check")
    if invariant_check_name not in name_to_invariant_check:
      # Interpolate the name into the message; passing it as a second
      # argument to ValueError would leave a literal "%s" in the text.
      raise ValueError("Unknown invariant check %s.\n"
                       "Invariant check name must be defined in config.invariant_checks"
                       % invariant_check_name)
    self.invariant_check = name_to_invariant_check[invariant_check_name]

    if isinstance(superlog_path_or_dag, str):
      self.superlog_path = superlog_path_or_dag
      # The dag is codified as a list, where each element has
      # a list of its dependents
      self.dag = EventDag(superlog_parser.parse_path(self.superlog_path))
    else:
      self.dag = superlog_path_or_dag

    self.transform_dag = transform_dag
    self.mcs_trace_path = mcs_trace_path
    self._extra_log = extra_log
    self.kwargs = kwargs
    self.end_wait_seconds = end_wait_seconds
    self.wait_on_deterministic_values = wait_on_deterministic_values
    # `no' means "number"
    self.no_violation_verification_runs = no_violation_verification_runs
    self._runtime_stats = RuntimeStats(runtime_stats_file)
    self.optimized_filtering = optimized_filtering

  def log(self, s):
    ''' Report s through msg.mcs_event and, if an extra log is configured,
    append it there as a flushed line. '''
    msg.mcs_event(s)
    extra_log = self._extra_log
    if extra_log is None:
      return
    extra_log.write(s + '\n')
    extra_log.flush()

  def log_violation(self, s):
    ''' Report s through msg.mcs_event prefixed with color.RED and, if an
    extra log is configured, append the uncolored text there as a flushed
    line. '''
    msg.mcs_event(color.RED + s)
    extra_log = self._extra_log
    if extra_log is None:
      return
    extra_log.write(s + '\n')
    extra_log.flush()

  def log_no_violation(self, s):
    ''' Report s through msg.mcs_event prefixed with color.GREEN and, if an
    extra log is configured, append the uncolored text there as a flushed
    line. '''
    msg.mcs_event(color.GREEN + s)
    extra_log = self._extra_log
    if extra_log is None:
      return
    extra_log.write(s + '\n')
    extra_log.flush()

  def init_results(self, results_dir):
    self.results_dir = results_dir
    if self._extra_log is None:
      self._extra_log = open("%s/mcs_finder.log" % results_dir, "w")
    if self._runtime_stats.runtime_stats_file is None:
      self._runtime_stats.runtime_stats_file = "%s/runtime_stats.json" % results_dir
    if self.mcs_trace_path is None:
      self.mcs_trace_path = "%s/mcs.trace" % results_dir

  def simulate(self, check_reproducability=True):
    ''' Run the full minimization: optionally verify that the violation is
    reproducible, prune the dag with delta debugging, then dump statistics
    and the resulting trace.

    check_reproducability: if True, first replay the unpruned dag (up to
      no_violation_verification_runs times) and exit the process with
      status 5 if no violation shows up.

    Returns self.simulation (not assigned anywhere in this class;
    presumably provided by the base class -- TODO confirm).
    Raises RuntimeError if no events remain after filtering unsupported
    input types.
    '''
    self._runtime_stats.set_dag_stats(self.dag)

    # inject domain knowledge into the dag
    self.dag.mark_invalid_input_sequences()
    self.dag = self.dag.filter_unsupported_input_types()

    if len(self.dag) == 0:
      raise RuntimeError("No supported input types?")

    if check_reproducability:
      # First, run through without pruning to verify that the violation exists
      self._runtime_stats.record_replay_start()

      # Pre-bind both names so that a run count of zero falls through to the
      # "unable to reproduce" path instead of raising NameError.
      violations = []
      run_index = 0
      for run_index in range(self.no_violation_verification_runs):
        violations = self.replay(self.dag)
        if violations != []:
          break
      self._runtime_stats.set_initial_verification_runs_needed(run_index)
      self._runtime_stats.record_replay_end()
      if violations == []:
        msg.fail("Unable to reproduce correctness violation!")
        sys.exit(5)
      self.log("Violation reproduced successfully! Proceeding with pruning")
      # Don't let the verification runs count towards the pruning totals.
      Replayer.total_replays = 0
      Replayer.total_inputs_replayed = 0

    self._runtime_stats.record_prune_start()

    # TODO(cs): Better than a boolean flag: check if
    # log(len(self.dag)) > number of input types to try
    if self.optimized_filtering:
      self._optimize_event_dag()
    precompute_cache = PrecomputeCache()
    (dag, total_inputs_pruned) = self._ddmin(self.dag, 2, precompute_cache=precompute_cache)
    # Make sure to track the final iteration size. This must happen before
    # self.dag is replaced, since the size is computed relative to it.
    self._track_iteration_size(total_inputs_pruned)
    self.dag = dag

    self._runtime_stats.record_prune_end()
    self._dump_runtime_stats()
    self.log("Final MCS (%d elements):" % len(self.dag.input_events))
    for input_event in self.dag.input_events:
      self.log(" - %s" % str(input_event))
    if self.mcs_trace_path is not None:
      self._dump_mcs_trace()
    self.log("=== Total replays: %d ===" % Replayer.total_replays)

    return self.simulation

  def _ddmin(self, dag, split_ways, precompute_cache=None, label_prefix=(),
             total_inputs_pruned=0):
    ''' Recursively shrink dag's input events with delta debugging.

    This is the delta-debugging algorithm from:
      http://www.st.cs.uni-saarland.de/papers/tse2002/tse2002.pdf,
    Section 3.2

    dag: event dag whose input_events are being minimized
    split_ways: number of subsets dag.input_events is split into
    precompute_cache: record of input sequences that were already tried;
      a fresh PrecomputeCache is created when omitted
    label_prefix: tuple of labels of the subsets chosen so far, used for
      logging and for naming intermediate dumps
    total_inputs_pruned: running count of input events removed so far

    Returns a (dag, total_inputs_pruned) tuple.
    '''
    # TODO(cs): we could do much better if we leverage domain knowledge (e.g.,
    # start by pruning all LinkFailures)
    if split_ways > len(dag.input_events):
      self.log("Done")
      return (dag, total_inputs_pruned)

    if precompute_cache is None:
      precompute_cache = PrecomputeCache()

    def local_label(index, inv=False):
      # "~" marks a complement.
      return "%s%d/%d" % ("~" if inv else "", index, split_ways)

    def subset_label(label):
      return ".".join(map(str, label_prefix + (label, )))

    def print_subset(label, events):
      return subset_label(label) + ": " + " ".join(e.label for e in events)

    subsets = split_list(dag.input_events, split_ways)
    self.log("Subsets:\n"+"\n".join(print_subset(local_label(i), s) for i, s in enumerate(subsets)))
    for i, subset in enumerate(subsets):
      label = local_label(i)
      new_dag = dag.input_subset(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current subset: %s" % print_subset(label, input_sequence))
      if precompute_cache.already_done(input_sequence):
        self.log("Already computed. Skipping")
        continue
      precompute_cache.update(input_sequence)
      if input_sequence == ():
        self.log("Subset after pruning dependencies was empty. Skipping")
        continue

      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(new_dag, i)
      if violation:
        self.log_violation("Subset %s reproduced violation. Subselecting." % subset_label(label))
        self._maybe_dump_intermediate_mcs(new_dag, subset_label(label))

        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        return self._ddmin(new_dag, 2, precompute_cache=precompute_cache,
                           label_prefix=label_prefix + (label, ),
                           total_inputs_pruned=total_inputs_pruned)

    self.log_no_violation("No subsets with violations. Checking complements")
    for i, subset in enumerate(subsets):
      label = local_label(i, True)
      prefix = label_prefix + (label, )
      new_dag = dag.input_complement(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current complement: %s" % print_subset(label, input_sequence))
      if precompute_cache.already_done(input_sequence):
        self.log("Already computed. Skipping")
        continue
      precompute_cache.update(input_sequence)

      if input_sequence == ():
        # log() takes a single string, so the label has to be interpolated
        # here rather than passed as a second argument.
        self.log("Subset %s after pruning dependencies was empty. Skipping" % subset_label(label))
        continue

      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(new_dag, i)
      if violation:
        self.log_violation("Subset %s reproduced violation. Subselecting." % subset_label(label))
        self._maybe_dump_intermediate_mcs(new_dag, subset_label(label))
        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        return self._ddmin(new_dag, max(split_ways - 1, 2),
                           precompute_cache=precompute_cache,
                           label_prefix=prefix,
                           total_inputs_pruned=total_inputs_pruned)

    self.log_no_violation("No complements with violations.")
    if split_ways < len(dag.input_events):
      self.log("Increasing granularity.")
      return self._ddmin(dag, min(len(dag.input_events), split_ways*2),
                         precompute_cache=precompute_cache,
                         label_prefix=label_prefix,
                         total_inputs_pruned=total_inputs_pruned)
    return (dag, total_inputs_pruned)

  def _track_iteration_size(self, total_inputs_pruned):
    self._runtime_stats.record_iteration_size(len(self.dag.input_events) - total_inputs_pruned)

  def _check_violation(self, new_dag, subset_index):
    ''' Check if there were violations '''
    # Try no_violation_verification_runs times to see if the bug shows up
    for i in range(0, self.no_violation_verification_runs):
      violations = self.replay(new_dag)

      if violations != []:
        # Violation in the subset
        self.log_violation("Violation! Considering %d'th" % subset_index)
        self._runtime_stats.record_violation_found(i)
        return True

    # No violation!
    self.log_no_violation("No violation in %d'th..." % subset_index)
    return False

  def replay(self, new_dag):
    ''' Replay new_dag in a fresh Replayer and check the invariant.

    new_dag is first passed through self.transform_dag when one is
    configured. After the replay, internal-event statistics are recorded,
    the finder sleeps for end_wait_seconds, and the invariant check is run
    against the resulting simulation.

    Returns whatever self.invariant_check returns (callers compare it
    against [] to decide whether a violation occurred).
    '''
    # Run the simulation forward
    if self.transform_dag:
      new_dag = self.transform_dag(new_dag)

    # TODO(aw): MCSFinder needs to configure Simulation to always let DataplaneEvents pass through
    replayer = Replayer(self.simulation_cfg, new_dag,
                        wait_on_deterministic_values=self.wait_on_deterministic_values,
                        **self.kwargs)
    simulation = replayer.simulate()
    try:
      self._track_new_internal_events(simulation, replayer)
      # Wait a bit in case the bug takes awhile to happen.
      # %g rather than %d: the default wait is fractional (0.5s) and %d
      # would report it as 0.
      self.log("Sleeping %g seconds after run" % self.end_wait_seconds)
      time.sleep(self.end_wait_seconds)
      return self.invariant_check(simulation)
    finally:
      # Always tear the simulation down, even if stats tracking or the
      # invariant check raised.
      simulation.clean_up()

  def _optimize_event_dag(self):
    ''' Employs domain knowledge of event classes to reduce the size of event dag.

    For each known input event class in turn, a candidate dag is built and
    replayed; whenever the replay still reports violations, self.dag is
    replaced by that candidate.
    '''
    # NOTE(review): _ddmin uses input_complement(subset) for the dag *without*
    # subset. If that holds here, the candidate keeps only the events of
    # event_type rather than dropping them -- confirm the intent.
    candidate_types = (TrafficInjection, DataplaneDrop, SwitchFailure,
                       SwitchRecovery, LinkFailure, LinkRecovery, HostMigration,
                       ControllerFailure, ControllerRecovery, PolicyChange,
                       ControlChannelBlock, ControlChannelUnblock)
    for event_type in candidate_types:
      pruned = [event for event in self.dag.input_events
                if not isinstance(event, event_type)]
      nothing_of_this_type = len(pruned) == len(self.dag.input_events)
      if nothing_of_this_type:
        self.log("\t** No events pruned for event type %s. Next!" % event_type)
        continue
      pruned_dag = self.dag.input_complement(pruned)
      if self.replay(pruned_dag) == []:
        continue
      self.log("\t** VIOLATION for pruning event type %s! Resizing original dag" % event_type)
      self.dag = pruned_dag

  def _track_new_internal_events(self, simulation, replayer):
    ''' Pre: simulation must have been run through a replay'''
    # We always check against internal events that were buffered at the end of
    # the original run (don't want to overcount)
    path = self.superlog_path + ".unacked"
    if not os.path.exists(path):
      log.warn("unacked internal events file from original run does not exists")
      return
    prev_buffered_receives = [ e.pending_receive for e in
                               EventDag(superlog_parser.parse_path(path)).events ]
    new_message_receipts = []
    for p in simulation.god_scheduler.pending_receives():
      if p not in prev_buffered_receives:
        new_message_receipts.append(repr(p))
      else:
        prev_buffered_receives.remove(p)
    new_state_changes = replayer.unexpected_state_changes
    new_internal_events = new_state_changes + new_message_receipts
    self._runtime_stats.record_new_internal_events(new_internal_events)
    self._runtime_stats.record_early_internal_events(replayer.early_state_changes)
    self._runtime_stats.record_timed_out_events(dict(replayer.event_scheduler_stats.event2timeouts))
    self._runtime_stats.record_matched_events(dict(replayer.event_scheduler_stats.event2matched))

  def _maybe_dump_intermediate_mcs(self, dag, label):
    class InterMCS(object):
      def __init__(self):
        self.min_size = sys.maxint
        self.count = 0
    if not hasattr(self, "_intermcs"):
      self._intermcs = InterMCS()
    if len(dag.events) < self._intermcs.min_size:
      self._intermcs.min_size = len(dag.events)
      self._intermcs.count += 1
      dst = os.path.join(self.results_dir, "intermcs_%d_%s" % (self._intermcs.count, label.replace("/", ".")))
      os.makedirs(dst)
      self._dump_mcs_trace(dag, os.path.join(dst, os.path.basename(self.mcs_trace_path)))
      self._dump_runtime_stats(os.path.join(dst,
          os.path.basename(self._runtime_stats.runtime_stats_file)))

  def _dump_mcs_trace(self, dag=None, mcs_trace_path=None):
    ''' Write every event of a dag to a trace through an InputLogger.

    dag: dag to dump; defaults to self.dag
    mcs_trace_path: output path; defaults to self.mcs_trace_path
    '''
    trace_dag = self.dag if dag is None else dag
    trace_path = self.mcs_trace_path if mcs_trace_path is None else mcs_trace_path

    # Dump the mcs trace
    trace_logger = InputLogger(output_path=trace_path)
    trace_logger.open(os.path.dirname(trace_path))
    for event in trace_dag.events:
      trace_logger.log_input_event(event)
    trace_logger.close(self, self.simulation_cfg, skip_mcs_cfg=True)

  def _dump_runtime_stats(self, runtime_stats_file=None):
    runtime_stats = self._runtime_stats.clone()
    if runtime_stats_file is not None:
      runtime_stats.runtime_stats_file = runtime_stats_file
    runtime_stats.set_peeker(self.transform_dag is not None)
    runtime_stats.set_config(str(self.simulation_cfg))
    runtime_stats.write_runtime_stats()