Esempio n. 1
0
    def Run(self, args):
        """Runs the user's training package locally.

        Mimics `ml jobs submit training`: the parent directory of the package
        path (defaulting to the current working directory) is used as the
        package root. With --distributed, a local cluster is started with 2
        parameter servers and 2 workers unless the counts are given explicitly.
        Otherwise a single 'master' process runs the module, and any
        cluster-sizing flags that were supplied are ignored with a warning.

        Args:
          args: an argparse namespace. All the arguments that were provided to
            this command invocation.

        Returns:
          None. The value returned by the local_train launcher is not used.
        """
        package_path = args.package_path or os.getcwd()
        # Mimic behavior of ml jobs submit training
        package_root = os.path.dirname(os.path.abspath(package_path))
        if args.distributed:
            local_train.RunDistributed(args.module_name,
                                       package_root,
                                       args.parameter_server_count or 2,
                                       args.worker_count or 2,
                                       args.start_port,
                                       # user_args is None when no extra
                                       # arguments were given; always hand
                                       # the launcher a list.
                                       user_args=args.user_args or [])
        else:
            # Cluster-sizing flags have no effect for a single process; warn
            # instead of failing so the command still runs.
            if args.parameter_server_count:
                log.warn(
                    _BAD_FLAGS_WARNING_MESSAGE.format(
                        flag='--parameter-server-count'))
            if args.worker_count:
                log.warn(
                    _BAD_FLAGS_WARNING_MESSAGE.format(flag='--worker-count'))
            local_train.MakeProcess(args.module_name,
                                    package_root,
                                    args=args.user_args,
                                    task_type='master')
Esempio n. 2
0
  def Run(self, args):
    """Runs the user's training package locally.

    The package root is the parent directory of --package-path (or of the
    current working directory when that flag is absent), matching
    `ml-engine jobs submit training`. With --distributed, a local cluster of
    parameter servers and workers is launched (2 of each unless the counts
    are given). Otherwise a single 'main' process runs the module, and any
    cluster-sizing flags are reported as ignored.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      None. The launcher's result is stored on self.exit_code instead.
    """
    root_dir = os.path.dirname(
        os.path.abspath(args.package_path or os.getcwd()))

    if not args.distributed:
      # These flags only size a distributed cluster; warn about each one that
      # was supplied, in the same order as before, rather than failing.
      ignored_flags = (
          (args.parameter_server_count, '--parameter-server-count'),
          (args.worker_count, '--worker-count'),
      )
      for flag_value, flag_name in ignored_flags:
        if flag_value:
          log.warn(_BAD_FLAGS_WARNING_MESSAGE.format(flag=flag_name))
      outcome = local_train.MakeProcess(args.module_name,
                                        root_dir,
                                        args=args.user_args,
                                        task_type='main')
    else:
      outcome = local_train.RunDistributed(
          args.module_name,
          root_dir,
          args.parameter_server_count or 2,
          args.worker_count or 2,
          args.start_port,
          user_args=args.user_args or [])

    # No exception on failure: the user has already seen the process output,
    # and propagating the result keeps this close to running the script
    # directly.
    self.exit_code = outcome