Example #1
0
def main():
    """Drive integrated RL by re-running the iteration script until done.

    Each pass reads the current iteration from ``args.base_dir``. It stops once
    that reaches ``expt['num_iters']``; otherwise it runs
    ``./rp metagrok/integrated_rl.py`` with this process's own CLI arguments.
    Between passes ``/tmp`` is removed and recreated.

    Raises:
        ValueError: if a pass ends without the iteration counter advancing.
        subprocess.CalledProcessError: if the child command exits non-zero.
    """
    args = parse_args()

    log = utils.default_logger_setup()

    utils.mkdir_p(args.base_dir)
    # NOTE(review): `json.load` is given a path here; stdlib json.load expects
    # a file object, so `json` is presumably a project-local module -- confirm.
    experiment = json.load(args.expt_name)

    previous = None
    while True:
        iteration = integrated_rl.divine_current_iteration(args.base_dir)
        log.info('Current iteration: %d', iteration)

        # Guard against looping forever when a child run makes no progress.
        if previous is not None and previous >= iteration:
            raise ValueError(
                'No progress made. last_iter = %s, current_iter = %s' %
                (previous, iteration))
        previous = iteration

        if iteration >= experiment['num_iters']:
            break

        command = ['./rp', 'metagrok/integrated_rl.py'] + sys.argv[1:]
        log.info('Running command: %s', command)
        subprocess.check_call(command, stdout=sys.stdout, stderr=sys.stderr)

        # Wipe scratch space between child runs.
        log.info('Evicting /tmp directory')
        shutil.rmtree('/tmp')
        utils.mkdir_p('/tmp')

    log.info('Done!')
Example #2
0
def run_update(updater, iter_num, out_dir, gamma, lam, delete_logs):
    """Roll up one iteration's matches, update the policy and save it.

    Args:
        updater: object exposing ``policy`` and ``update(dataset)``.
        iter_num: iteration index, turned into a name via ``mk_iter_name``.
        out_dir: experiment root holding ``matches/``, ``models/`` and the
            per-iteration log directories.
        gamma: forwarded to ``rollup``.
        lam: forwarded to ``rollup``.
        delete_logs: when true, remove the iteration's log directory after
            it has been rolled up.
    """
    matches_root = os.path.join(out_dir, 'matches')
    models_root = os.path.join(out_dir, 'models')
    for path in (models_root, matches_root):
        mkdir_p(path)

    name = mk_iter_name(iter_num)
    logs_dir = os.path.join(out_dir, name)

    rolled = rollup(updater.policy, logs_dir, gamma, lam)
    np.savez_compressed(os.path.join(matches_root, '%s.npz' % name), **rolled)

    logger.info('Results: %s', check_results(logs_dir))

    if delete_logs:
        shutil.rmtree(logs_dir)

    # Train on the lexicographically last rollup in matches/ -- presumably the
    # one just written, if mk_iter_name sorts chronologically.
    newest = sorted(glob.glob('%s/*.npz' % matches_root))[-1]
    extras = prepare([newest])
    post_prepare(extras)

    updater.update(NDArrayDictDataset(extras))
    torch.save(updater.policy.state_dict(),
               os.path.join(models_root, '%s.pytorch' % name))
Example #3
0
def main():
  """Run a head-to-head evaluation between two policies and e-mail a summary.

  Writes the parsed arguments (plus the resolved format details) to
  ``<outdir>/args.json``, plays the games via ``start`` in a gevent greenlet,
  and mails p1's win rate with its standard error and a z-score against a
  50% win rate.
  """
  args = parse_args()
  utils.mkdir_p(args.outdir)

  # vars() returns the namespace's own __dict__, so this also sets
  # args._format_details.
  params = vars(args)
  params['_format_details'] = formats.get(args.format)

  with open(os.path.join(args.outdir, 'args.json'), 'w') as fd:
    json.dump(params, fd)

  # Reuse the arguments parsed above. Parsing sys.argv a second time could
  # produce different values for any default computed at parse time, and the
  # games would then not match what was recorded in args.json.
  p1wins, p2wins = gevent.spawn(start, args).get()

  subject = 'head2head evaluation finished: ' + args.outdir

  num_matches = args.num_matches
  if num_matches > 0:
    mean = float(p1wins) / num_matches
    var = mean * (1. - mean)
    sem = math.sqrt(var / num_matches)
  else:
    # No games were played, so there is no win rate to report.
    mean = sem = float('nan')
  # z-score of p1's win rate against a coin flip; 1e-8 avoids dividing by 0
  # when every game went the same way (sem == 0).
  z = (mean - 0.5) / (1e-8 + sem)

  fmt_args = (args.p1, p1wins, args.p2, p2wins, mean, sem, z, json.dumps(params, indent = 2))

  text = textwrap.dedent('''\
  Results:
    p1 (%s) num wins: %s
    p2 (%s) num wins: %s

    mean: %s
    sem: %s
    z-score: %s

  Arguments:
  %s''') % fmt_args

  mail.send(subject, text)
Example #4
0
def run_iter(out_dir, iter_num, simulate_fn, clean=False):
    """Prepare the directory for iteration *iter_num* and run *simulate_fn* there.

    Args:
        out_dir: experiment root directory.
        iter_num: iteration index, mapped to a subdirectory via ``mk_iter_name``.
        simulate_fn: callable invoked with the iteration directory path.
        clean: when true, an existing iteration directory is removed first.
    """
    target = os.path.join(out_dir, mk_iter_name(iter_num))

    # Only wipe when asked to and when there is actually a directory there.
    if clean and os.path.isdir(target):
        shutil.rmtree(target)

    mkdir_p(target)
    simulate_fn(target)
Example #5
0
def start(args):
  """Play ``args.num_matches`` engine games between two saved policies.

  Each side's battle logs are written under ``<outdir>/p1`` and
  ``<outdir>/p2`` respectively.

  Returns:
    A two-element list ``[p1_wins, p2_wins]``; ties count for neither side.
  """
  logger = utils.default_logger_setup(logging.INFO)
  logger.info('Writing to ' + args.outdir)

  config.set_cuda(args.cuda)
  side_dirs = [os.path.join(args.outdir, side) for side in ('p1', 'p2')]
  for side_dir in side_dirs:
    utils.mkdir_p(side_dir)

  showdown_bin = os.path.join(
      config.get('showdown_root'),
      config.get('showdown_server_dir'),
      'pokemon-showdown')

  game = Game(options = formats.get(args.format), prog = showdown_bin)

  policies = (torch_policy.load(args.p1), torch_policy.load(args.p2))

  wins = [0, 0]

  logger.info('starting...')
  # TODO: make multithreaded. Maybe integrate this with RL experiment
  for match in range(args.num_matches):
    players = [
      EnginePkmnPlayer(policies[0], '%s-p1' % match,
        play_best_move = args.play_best_move in ['p1', 'both']),
      EnginePkmnPlayer(policies[1], '%s-p2' % match,
        play_best_move = args.play_best_move in ['p2', 'both']),
    ]
    game.play(*players)

    for side, (player, side_dir) in enumerate(zip(players, side_dirs)):
      # Persist this side's view of the battle.
      battle_log = battlelogs.BattleLogger(player.gid, side_dir)
      for block in player.blocks:
        battle_log.log(block)
      battle_log.close()

      if player.result == 'winner':
        wins[side] += 1
      else:
        assert player.result in ['loser', 'tie']

  return wins
Example #6
0
def main():
  """Entry point: run one stage of an RL experiment and report the result.

  Parses ``expt_name`` and ``base_dir`` plus options, starts the remote
  debugger, and mirrors INFO-level logs into ``/tmp/iteration.log``. It then
  runs the program selected by ``--prog`` (a key of ``_name_to_prog``).
  Unless that program is ``simulate_and_rollup``, the result is e-mailed
  with the log attached whenever ``result['iter']`` is a multiple of 5. The
  log file is deleted at the end.
  """
  import argparse
  parser = argparse.ArgumentParser()
  parser.add_argument('expt_name')
  parser.add_argument('base_dir')
  parser.add_argument('--cuda', action = 'store_true')
  parser.add_argument('--parallelism', type = int, default = mp.cpu_count())
  parser.add_argument('--prog', choices = list(_name_to_prog.keys()), default = 'run_one_iteration')

  args = parser.parse_args()

  from metagrok import remote_debug
  remote_debug.listen()

  config.set_cuda(args.cuda)

  utils.mkdir_p('/tmp')

  log_fname = '/tmp/iteration.log'
  logger = utils.default_logger_setup()
  fhandler = logging.FileHandler(log_fname)
  fhandler.setFormatter(logging.Formatter(constants.LOG_FORMAT))
  fhandler.setLevel(logging.INFO)
  logger.addHandler(fhandler)

  prog = _name_to_prog[args.prog]

  time_begin = utils.iso_ts()
  result = prog(
    expt_name = args.expt_name,
    base_dir = args.base_dir,
    parallelism = args.parallelism,
    cuda = args.cuda,
  )
  time_end = utils.iso_ts()

  result['time_begin'] = time_begin
  result['time_end'] = time_end

  # Detach the handler before closing it. A closed FileHandler that is still
  # attached reopens its file on the next log record, which would recreate
  # the log file after it is removed below.
  logger.removeHandler(fhandler)
  fhandler.close()

  if args.prog != 'simulate_and_rollup' and result['iter'] % 5 == 0:
    mail.send(
      result['subject'],
      json.dumps(result, indent = 2, sort_keys = True),
      attachments = [log_fname])
  os.remove(log_fname)
Example #7
0
def main():
    """Convert JSON battle replays into per-player battle logs.

    Walks each subdirectory of ``args.in_dir``. For every ``*.json`` replay it
    writes ``<basename>.p1`` and ``<basename>.p2`` logs (via ``BattleLogger``)
    into the matching subdirectory of ``args.out_dir``. A replay whose two
    ``.jsons.gz`` outputs already exist is skipped, so reruns resume where
    they left off.
    """
    import os

    from metagrok import utils
    from metagrok.battlelogs import BattleLogger

    args = parse_args()
    in_dir = args.in_dir
    out_dir = args.out_dir

    for in_subdir in sorted(os.listdir(in_dir)):
        in_subdir_path = os.path.join(in_dir, in_subdir)
        # Stray top-level files would make os.listdir() below raise.
        if not os.path.isdir(in_subdir_path):
            print('Skipping: ' + in_subdir_path)
            continue
        out_subdir_path = os.path.join(out_dir, in_subdir)
        utils.mkdir_p(out_subdir_path)
        for fname in sorted(os.listdir(in_subdir_path)):
            in_path = os.path.join(in_subdir_path, fname)

            if not fname.endswith('.json'):
                print('Skipping: ' + in_path)
                continue

            basename = os.path.splitext(fname)[0]
            side_bases = (basename + '.p1', basename + '.p2')

            # NOTE(review): assumes BattleLogger writes <base>.jsons.gz --
            # confirm against metagrok.battlelogs.
            if all(os.path.isfile(os.path.join(out_subdir_path,
                                               base + '.jsons.gz'))
                   for base in side_bases):
                print('Already processed: ' + in_path)
                continue
            print('Processing: ' + in_path)

            with fileio.open(in_path) as fd:
                replay = json.load(fd)
                p1s, p2s = convert(replay)

            for base, blocks in zip(side_bases, (p1s, p2s)):
                blogger = BattleLogger(base, log_dir=out_subdir_path)
                for block in blocks:
                    blogger.log(block)
                blogger.close()
Example #8
0
    def on_start(self):
        """Spawn the battler actor and initialise per-battle state.

        Resets ``self._s.req``, spawns an ``APIPlayerBattler`` for this room,
        and opens ``<pslog_dir>/<bid>.pslog`` for writing when a pslog
        directory is configured. The handle (or ``None``) is stored on
        ``self._s.log_file``. If the timer is enabled, the parent is asked to
        send ``/timer on`` to the room.
        """
        self._s.req = None

        self._s.battler = self.spawn(
            battler.APIPlayerBattler,
            conf=self._conf,
            player_ctor=self._player_ctor,
            roomid=self._roomid,
        )

        if not self._pslog_dir:
            self._s.log_file = None
        else:
            utils.mkdir_p(self._pslog_dir)
            # NOTE(review): the handle is not closed here; presumably another
            # handler of this actor closes it -- not visible in this block.
            pslog_path = os.path.join(self._pslog_dir, '%s.pslog' % self._bid)
            self._s.log_file = open(pslog_path, 'w')

        if self._timer:
            self._parent.send('send-to-room', ('/timer on', self._roomid))
Example #9
0
def main():
  """Simulation worker: play one self-play battle per request line on stdin.

  Every stdin line other than ``done`` triggers one game between the loaded
  policies. p2 reuses p1's policy unless ``--p2-policy-tag`` is given. Both
  players' logs are written to ``/tmp/<id>/<NNNNNN>/``, and a tab-separated
  ``<battle_dir> <num_blocks>`` line is written to stdout. The loop stops on
  ``done`` or at end-of-file on stdin.
  """
  args = parse_args()
  config.set_cuda(False)

  from metagrok import remote_debug
  remote_debug.listen()

  p1_policy = torch_policy.load(args.policy_tag)
  p2_policy = p1_policy
  if args.p2_policy_tag:
    p2_policy = torch_policy.load(args.p2_policy_tag)

  fmt = formats.get(args.fmt)
  game = Game(fmt, '{}/{}/pokemon-showdown'.format(
      config.get('showdown_root'),
      config.get('showdown_server_dir')))
  count = 0
  while True:
    time.sleep(0.1)

    raw = sys.stdin.readline()
    # readline() returns '' only at end-of-file, i.e. the parent closed the
    # pipe or died. '' never equals 'done', so without this check an
    # orphaned worker would keep playing battles forever.
    if not raw:
      break
    r = raw.strip()

    if r == 'done':
      break

    battle_dir = os.path.join('/tmp', args.id, '%06d' % count)
    utils.mkdir_p(battle_dir)

    p1 = EnginePkmnPlayer(p1_policy, 'p1', epsilon = args.epsilon)
    p2 = EnginePkmnPlayer(p2_policy, 'p2', epsilon = args.epsilon)
    game.play(p1, p2)

    num_blocks = 0
    for i, player in enumerate([p1, p2]):
      blogger = battlelogs.BattleLogger('p%d' % (i + 1), battle_dir)
      for block in player.blocks:
        blogger.log(block)
        num_blocks += 1
      blogger.close()
    count += 1

    sys.stdout.write('%s\t%d\n' % (battle_dir, num_blocks))
    sys.stdout.flush()
Example #10
0
def load_latest_policy(out_dir, policy):
    """Load the newest checkpoint in ``<out_dir>/models`` into *policy*.

    If the directory holds no ``*.pytorch`` file, the current weights of
    *policy* are saved as the iteration-0 checkpoint instead. The policy is
    then cast with ``config.tt()`` and, when CUDA is enabled, ``.cuda()`` is
    called on it.

    Returns:
        Sorted list of checkpoint paths in the model directory.
    """
    model_dir = os.path.join(out_dir, 'models')
    mkdir_p(model_dir)
    pattern = '%s/*.pytorch' % model_dir

    checkpoints = sorted(glob.glob(pattern))
    if not checkpoints:
        # Bootstrap: seed the directory with the initial weights.
        torch.save(policy.state_dict(),
                   '%s/%s.pytorch' % (model_dir, mk_iter_name(0)))
        checkpoints = sorted(glob.glob(pattern))
    else:
        # `checkpoints` is already sorted, so its last entry sorts highest.
        policy.load_state_dict(torch.load(checkpoints[-1]))

    policy.type(config.tt())

    if config.use_cuda():
        policy = policy.cuda()

    return checkpoints
Example #11
0
def run_one_iteration(expt_name, base_dir, parallelism = mp.cpu_count(), cuda = False):
  """Run (or resume) one RL iteration: simulate, roll up, update the policy.

  Simulation is skipped when the current iteration directory already holds
  ``rollup.npz``.

  Returns:
    Dict with the iteration directory, the 1-based iteration number, an
    e-mail ``subject``/``message`` (the message explains how to evaluate the
    new model), the experiment's short name, the next start-model path and
    the experiment parameters.
  """
  logger = logging.getLogger('run_one_iteration')

  expt = json.load(expt_name)

  # Which iteration are we on?
  iteration = divine_current_iteration(base_dir)
  iter_dir = os.path.join(base_dir, 'iter%06d' % iteration)
  utils.mkdir_p(iter_dir)
  logger.info('Current iteration: %d', iteration)

  # An existing rollup file means all battles were already simulated;
  # otherwise finish the simulations and produce it.
  if not os.path.isfile(os.path.join(iter_dir, 'rollup.npz')):
    simulate_and_rollup(expt_name, base_dir, parallelism, cuda)

  # Gradient update; writes end.pytorch and the next iteration's start model.
  update_result = perform_policy_update(expt_name, base_dir, parallelism, cuda)
  next_start_model_file = update_result['next_start_model_file']

  # Compose the notification e-mail.
  subject = 'Iteration [%d/%d] finished for [%s]' % (iteration + 1, expt['num_iters'], base_dir)
  shortname = os.path.splitext(os.path.basename(expt_name))[0]
  tag = 'eval-%s-%03d' % (shortname, iteration)
  message = (
    'To evaluate, run:\n'
    '  \n'
    '  scripts/deploy.sh %s smogeval ./rp metagrok/exe/smogon_eval.py %s:%s\n'
    '  '
  ) % (tag, expt['policy_cls'], next_start_model_file)

  return {
    'dir': iter_dir,
    'iter': iteration + 1,
    'message': message,
    'name': shortname,
    'next_start_model_file': next_start_model_file,
    'params': expt,
    'subject': subject,
  }
Example #12
0
  def _setup(self):
    """Create the Adam optimizer and resume from the latest checkpoint, if any.

    When an output directory is configured, it is created and
    ``_load_latest`` is consulted. If that returns a checkpoint, the RNG
    states (CUDA torch, CPU torch, Python ``random``), the optimizer state
    and the model weights are restored, and ``_current_epoch`` is taken from
    it. Otherwise an epoch-0 checkpoint is written. Without an output
    directory, training starts at epoch 0. ``_prepare_learning_rate()`` runs
    last in every case.
    """
    # Placeholder lr; presumably overwritten by _prepare_learning_rate().
    self._optimizer = optim.Adam(
      self.policy.parameters(),
      lr = 1e-4,
      weight_decay = self._weight_decay,
    )

    if not self._out_dir:
      self._current_epoch = 0
    else:
      utils.mkdir_p(self._out_dir)
      checkpoint = _load_latest(self._out_dir)

      if not checkpoint:
        # Nothing to resume: start at epoch 0 and record that state.
        self._current_epoch = 0
        self._save_checkpoint(override = 0)
      else:
        model_path = checkpoint['model']
        optimizer_path = checkpoint['optimizer']
        rng_path = checkpoint['random']
        self._current_epoch = checkpoint['epoch']

        self.logger.info('Loading random state: %s', rng_path)
        rng_state = torch.load(rng_path)
        torch.cuda.random.set_rng_state_all(rng_state['pytorch_cuda'])
        torch.random.set_rng_state(rng_state['pytorch_cpu'])
        random.setstate(rng_state['python'])

        # NOTE(review): this uses `self.optimizer` while the optimizer above is
        # stored on `self._optimizer`; presumably a property bridges the two --
        # confirm.
        self.logger.info('Loading optimizer: %s', optimizer_path)
        self.optimizer.load_state_dict(torch.load(optimizer_path))

        self.logger.info('Loading model: %s', model_path)
        self.policy.load_state_dict(torch.load(model_path))

    self._prepare_learning_rate()
Example #13
0
def start(args):
  """Play ``args.num_matches`` LC-thunderdome games with adaptive team choice.

  For each match, an ``AgentStrategyProfile`` picks a team for each side and
  a fresh ``Game`` is built for that pairing. The outcome, from p1's
  perspective, is fed back into the profile. Battle logs go to
  ``<outdir>/p1`` and ``<outdir>/p2``. Afterwards the profile's utility matrix
  is printed, and both team lists (index, team) are written to
  ``lc_thunderdome_results.txt`` in the working directory.

  Returns:
    ``[p1_wins, p2_wins]``; ties count for neither side.
  """
  logger = utils.default_logger_setup(logging.INFO)
  logger.info('Writing to ' + args.outdir)

  config.set_cuda(args.cuda)
  p1dir = os.path.join(args.outdir, 'p1')
  p2dir = os.path.join(args.outdir, 'p2')

  utils.mkdir_p(p1dir)
  utils.mkdir_p(p2dir)

  prog = os.path.join(
      config.get('showdown_root'),
      config.get('showdown_server_dir'),
      'pokemon-showdown')

  policy_1 = torch_policy.load(args.p1)
  policy_2 = torch_policy.load(args.p2)

  wins = [0, 0]

  logger.info('starting...')
  # TODO: make multithreaded. Maybe integrate this with RL experiment
  p1_teams, p2_teams = tg.init_lc_thunderdome()
  strategy_agent = team_choice.AgentStrategyProfile(p1_teams, p2_teams)
  for i in range(args.num_matches):
    team1_ind = strategy_agent.select_action()
    team2_ind = strategy_agent.select_action_p2()
    # The team pairing changes per match, so the Game is rebuilt each time.
    game = Game(options = strategy_agent.get_teams(team1_ind, team2_ind), prog = prog)
    p1 = EnginePkmnPlayer(policy_1, '%s-p1' % i,
      play_best_move = args.play_best_move in ['p1', 'both'])
    p2 = EnginePkmnPlayer(policy_2, '%s-p2' % i,
      play_best_move = args.play_best_move in ['p2', 'both'])
    game.play(p1, p2)

    for j, (player, dirname) in enumerate([(p1, p1dir), (p2, p2dir)]):
      bogger = battlelogs.BattleLogger(player.gid, dirname)
      for block in player.blocks:
        bogger.log(block)
      bogger.close()

      if j == 0:
        # The strategy profile learns from p1's result; a tie is not a win.
        strategy_agent.update(team1_ind, team2_ind,
                              p1_win = player.result == 'winner')

      if player.result == 'winner':
        wins[j] += 1
      else:
        assert player.result in ['loser', 'tie']

  print(strategy_agent.get_utility_matrix())
  # The `with` block flushes and closes the file on exit.
  with open("lc_thunderdome_results.txt", "w+") as wf:
    for ct, team in enumerate(strategy_agent.p1_teams):
      wf.write("{}\t{}\n".format(ct, team))
    for ct, team in enumerate(strategy_agent.p2_teams):
      wf.write("{}\t{}\n".format(ct, team))
    wf.write("\n")
  return wins
Example #14
0
def start(args):
    """Play ``args.num_matches`` matchmaking games on a Showdown server.

    Sets up ``root_dir`` (default ``data/evals/<timestamp>``), dumps the
    effective parameters to ``config.json``, logs to a rotating
    ``debug.log``, and queues matchmaking games for a policy-driven player.

    Returns:
        The parameter dict (``vars(args)`` plus ``username``/``root_dir``)
        with a ``record`` entry counting ``winner``/``loser``/``tie`` results.
    """
    config.set_cuda(False)
    num_matches = args.num_matches
    username = args.username or utils.random_name()

    policy = torch_policy.load(args.spec)

    root_dir = args.root_dir or ('data/evals/%s' % utils.ts())
    proc_log_fname = os.path.join(root_dir, 'debug.log')
    player_log_dir = None
    if args.debug_mode:
        player_log_dir = os.path.join(root_dir, 'player-logs')

    utils.mkdir_p(root_dir)
    if player_log_dir:
        utils.mkdir_p(player_log_dir)

    params = vars(args)
    params['username'] = username
    params['root_dir'] = root_dir
    with open(os.path.join(root_dir, 'config.json'), 'w') as fd:
        json.dump(params, fd)

    logger = utils.default_logger_setup(logging.DEBUG)
    fhandler = logging.handlers.RotatingFileHandler(proc_log_fname,
                                                    maxBytes=16 * 1024 * 1024,
                                                    backupCount=5)
    fhandler.setFormatter(logging.Formatter(constants.LOG_FORMAT))
    fhandler.setLevel(logging.DEBUG if args.debug_mode else logging.INFO)
    logger.addHandler(fhandler)

    conf = utils.Ns()
    conf.accept_challenges = False
    conf.formats = [args.format]
    conf.timer = True
    conf.username = username
    conf.host = args.host
    conf.port = args.port
    conf.max_concurrent = args.max_concurrent
    conf.pslog_dir = None
    conf.log_dir = player_log_dir
    conf.wait_time_before_requesting_move_seconds = args.wait_time_before_move

    logger.info('Setting up %s on %s:%s', conf.username, conf.host, conf.port)
    logger.info('Outputting logs to %s', root_dir)

    player_ctor = lambda gid: EnginePkmnPlayer(
        policy, gid, play_best_move=args.play_best_move)

    if args.team:
        with open(args.team) as fd:
            team = fd.read().strip()
    else:
        team = ''

    game = showdown.MatchmakingGame(conf, fmt=args.format, team=team)
    game.main()

    # Each queued match yields a handle with ready()/get(); get() returns a
    # dict whose 'result' is presumably one of the keys of `record` below.
    matches = {i: game([player_ctor]) for i in range(num_matches)}

    count = 0
    record = {'winner': 0, 'loser': 0, 'tie': 0}
    while matches:
        # Retire every match that finished since the last poll, rather than
        # at most one per sleep interval.
        finished = [i for i, msg in matches.items() if msg.ready()]
        for i in finished:
            result = matches.pop(i).get()
            logger.info('Finished %d/%d matches: %s', count + 1,
                        num_matches, result)
            record[result['result']] += 1
            count += 1

        if matches:
            gevent.sleep(1.)

    logger.info('Battles completed! Quitting...')
    params['record'] = record
    logger.info(params['record'])

    game.stop()
    game.join()
    return params
Example #15
0
def simulate_and_rollup(expt_name, base_dir, parallelism, cuda):
  """Simulate the current iteration's remaining battles, then roll them up.

  Battles are played by ``metagrok/exe/simulate_worker.py`` subprocesses.
  Each worker receives one ``battle`` line per requested game and answers
  with a ``<battle_dir> <num_blocks>`` line. Finished battle directories are
  moved to ``<iter_dir>/battles/NNNNNN``. Battle directories that already
  hold exactly two entries are counted as done and not re-simulated. Finally
  ``perform_rollup`` writes ``<iter_dir>/rollup.npz``.

  Args:
    expt_name: experiment spec, loaded via ``json.load``.
    base_dir: experiment root containing the ``iterNNNNNN`` directories.
    parallelism: maximum number of worker processes.
    cuda: not used by this function.

  Returns:
    Summary dict for the experiment log e-mail.

  Raises:
    ValueError: if the current policy contains NaN parameters.
    RuntimeError: if a worker exits before reporting all of its battles.
  """
  logger = logging.getLogger('simulate_and_rollup')

  expt = json.load(expt_name)

  # 1: Figure out which iteration we are running
  current_iter = divine_current_iteration(base_dir)
  iter_dir = os.path.join(base_dir, 'iter%06d' % current_iter)
  utils.mkdir_p(iter_dir)
  logger.info('Current iteration: %d', current_iter)

  # 2: Load the current policy file
  policy_tag = divine_current_policy_tag(expt, iter_dir, current_iter)
  logger.info('Using policy: %s', policy_tag)
  policy = torch_policy.load(policy_tag)

  # Refuse to simulate with a model that has NaN parameters.
  for name, param in policy.named_parameters():
    if torch.isnan(param).any().item():
      raise ValueError('Encountered nan in latest model in parameter ' + name)

  rollup_fname = os.path.join(iter_dir, 'rollup.npz')
  assert not os.path.isfile(rollup_fname), 'rollup detected means matches already simulated'

  battles_dir = os.path.join(iter_dir, 'battles')
  utils.mkdir_p(battles_dir)
  # A battle directory with exactly two entries counts as complete
  # (presumably one log per player).
  num_battles = len([d
    for d in glob.glob(os.path.join(battles_dir, '*'))
    if len(os.listdir(d)) == 2])

  total_matches = expt['simulate_args']['num_matches']
  num_battles_remaining = total_matches - num_battles
  logger.info('%d battles left to simulate for this iteration', num_battles_remaining)
  if num_battles_remaining > 0:
    start_time = time.time()

    def spawn_battler(bid):
      tag = str(bid)
      logger.info('Spawn battler with ID %s', bid)
      env = os.environ.copy()
      # One thread per worker; workers are pinned to separate CPUs below.
      env['OMP_NUM_THREADS'] = '1'
      env['MKL_NUM_THREADS'] = '1'
      err_fd = open('/tmp/%03d.err.log' % bid, 'w')
      args = ['./rp', 'metagrok/exe/simulate_worker.py',
        policy_tag,
        expt.get('format', 'gen7randombattle'),
        str(bid),
      ]
      if 'epsilon' in expt['simulate_args']:
        args.append('--epsilon')
        args.append(str(expt['simulate_args']['epsilon']))
      if 'p2' in expt['simulate_args']:
        args.append('--p2-policy-tag')
        args.append(str(expt['simulate_args']['p2']))
      rv = subprocess.Popen(
        args,
        stdout = subprocess.PIPE,
        stdin = subprocess.PIPE,
        stderr = err_fd,
        env = env,
        encoding = 'utf-8',
        bufsize = 0,
      )
      os.system('taskset -p -c %d %d' % (bid % mp.cpu_count(), rv.pid))
      return rv, err_fd

    num_blocks = 0
    battle_number = num_battles

    # Workers beyond the number of remaining battles would sit idle.
    num_workers = min(parallelism, num_battles_remaining)
    workers, fds = list(zip(*[spawn_battler(i) for i in range(num_workers)]))
    # Requests are dealt round-robin, and results are read back in the same
    # round-robin order.
    for i in range(num_battles_remaining):
      workers[i % len(workers)].stdin.write('battle\n')

    while battle_number < total_matches:
      time.sleep(0.1)
      for bid, w in enumerate(workers):
        raw = w.stdout.readline()
        if not raw:
          # readline() returns '' only at EOF: this worker exited before
          # reporting all of its battles. Fail loudly instead of polling
          # forever.
          for other in workers:
            if other.poll() is None:
              other.kill()
              other.wait()
          for fd in fds:
            fd.close()
          raise RuntimeError(
            'simulate_worker %d exited early (return code %s); see /tmp/%03d.err.log'
            % (bid, w.poll(), bid))
        line = raw.strip()
        if line:
          proc_battle_dir, num_blocks_in_battle = line.split()
          num_blocks_in_battle = int(num_blocks_in_battle)
          num_blocks += num_blocks_in_battle
          battle_dir = os.path.join(battles_dir, '%06d' % battle_number)
          shutil.rmtree(battle_dir, ignore_errors = True)
          shutil.move(proc_battle_dir, battle_dir)
          battle_number += 1
          current_pct = int(100 * battle_number / total_matches)
          prev_pct = int(100 * (battle_number - 1) / total_matches)

          # Log only when the integer percentage ticks over.
          if current_pct > prev_pct:
            logger.info('Battle %s (%s%%) completed. Num blocks: %s',
              battle_number, current_pct, num_blocks_in_battle)

          if battle_number >= total_matches:
            break
      for fd in fds:
        fd.flush()

    for i, w in enumerate(workers):
      logger.info('Shutting down worker %s', i)
      w.stdin.write('done\n')
      w.communicate()

    for fd in fds:
      fd.close()

    for fname in glob.glob('/tmp/*.err.log'):
      os.remove(fname)

    total_time = time.time() - start_time
    logger.info('Ran %d blocks in %ss, rate = %s block/worker/s',
      num_blocks, total_time, float(num_blocks) / len(workers) / total_time)

  logger.info('Rolling up files...')
  num_records = perform_rollup(expt, iter_dir, policy_tag, parallelism, rollup_fname)

  expt_shortname = os.path.splitext(os.path.basename(expt_name))[0]

  return dict(
    a__status = 'Simulations complete',
    dir = iter_dir,
    iter = current_iter + 1,
    name = expt_shortname,
    num_matches = num_battles_remaining,
    num_total_records = num_records,
    subject = 'Experiment log: ' + base_dir,
    z__params = expt,
  )
Example #16
0
def perform_policy_update(expt_name, base_dir, parallelism, cuda):
  """Update the current policy from buffered rollups and stage the next model.

  Loads ``rollup.npz`` from the current iteration and from up to
  ``updater_buffer_length_iters - 1`` earlier ones (default 1, i.e. only the
  current one). It concatenates them per key and runs the configured updater
  (with optional per-iteration scheduled arguments). The result is saved as
  ``<iter_dir>/end.pytorch`` and copied to the next iteration's
  ``start.pytorch``.

  Args:
    expt_name: experiment spec, loaded via ``json.load``.
    base_dir: experiment root containing the ``iterNNNNNN`` directories.
    parallelism: not used by this function.
    cuda: not used directly; CUDA use is read from ``config.use_cuda()``.

  Returns:
    Summary dict including ``next_start_model_file``.

  Raises:
    AssertionError: if the current iteration has no rollup file.
  """
  logger = logging.getLogger('perform_policy_update')

  expt = json.load(expt_name)

  # 1: Figure out which iteration we are running
  current_iter = divine_current_iteration(base_dir)
  iter_dir = os.path.join(base_dir, 'iter%06d' % current_iter)
  utils.mkdir_p(iter_dir)
  logger.info('Current iteration: %d', current_iter)

  # 2: Load the current policy file
  policy_tag = divine_current_policy_tag(expt, iter_dir, current_iter)
  logger.info('Using policy: %s', policy_tag)

  rollup_fname = os.path.join(iter_dir, 'rollup.npz')
  assert os.path.isfile(rollup_fname), 'cannot do policy update without rollup file'

  start_time = time.time()
  all_extras = collections.defaultdict(list)
  for iter_offset in range(expt.get('updater_buffer_length_iters', 1)):
    iter_num = current_iter - iter_offset
    if iter_num >= 0:
      r_fname = os.path.join(base_dir, 'iter%06d' % iter_num, 'rollup.npz')
      logger.info('Loading: %s', r_fname)
      # Indexing an NpzFile reads the array into memory; the context manager
      # closes the underlying archive afterwards.
      with np.load(r_fname) as npz:
        for k in npz.files:
          all_extras[k].append(npz[k])
  extras = {}
  for k in list(all_extras):
    # This is a hack to save on memory: each key's chunk list is popped out
    # of the dict, so it can be freed right after it is concatenated.
    # The optimal solution is to
    #   1) read each rollup to determine array size,
    #   2) pre-allocate a big array and
    #   3) fill.
    # (metagrok/methods/learner.py does a similar thing but operates on jsons files.)
    vs = all_extras.pop(k)
    extras[k] = np.concatenate(vs)
    del vs
  del all_extras
  learner.post_prepare(extras)

  total_time = time.time() - start_time
  logger.info('Loaded rollups in %ss', total_time)

  start_time = time.time()
  logger.info('Starting policy update...')
  extras = {k: torch.from_numpy(v) for k, v in extras.items()}
  extras = TTensorDictDataset(extras, in_place_shuffle = True)

  policy = torch_policy.load(policy_tag)
  updater_cls = utils.hydrate(expt['updater'])
  updater_args = dict(expt['updater_args'])
  # Scheduled arguments override the static ones for this iteration.
  for k, v in expt.get('updater_args_schedules', {}).items():
    updater_args[k] = Scheduler(v).select(current_iter)
  updater = updater_cls(policy = policy, **updater_args)
  if config.use_cuda():
    policy.cuda()

  updater.update(extras)
  if config.use_cuda():
    policy.cpu()

  total_time = time.time() - start_time
  logger.info('Ran policy update in %ss', total_time)

  # Write to /tmp first, then move into place.
  with open('/tmp/end_model_file.pytorch', 'wb') as fd:
    torch.save(policy.state_dict(), fd)
  end_model_file = os.path.join(iter_dir, 'end.pytorch')
  shutil.move('/tmp/end_model_file.pytorch', end_model_file)

  next_iter_dir = os.path.join(base_dir, 'iter%06d' % (current_iter + 1))
  next_start_model_file = os.path.join(next_iter_dir, 'start.pytorch')
  utils.mkdir_p(next_iter_dir)
  shutil.copy(end_model_file, next_start_model_file)
  logger.info('Wrote to %s', next_start_model_file)

  expt_shortname = os.path.splitext(os.path.basename(expt_name))[0]

  return dict(
    a__status = 'Policy update complete',
    dir = iter_dir,
    iter = current_iter + 1,
    name = expt_shortname,
    next_start_model_file = next_start_model_file,
    subject = 'Experiment log: ' + base_dir,
    z__params = expt,
  )