def run_nn_dmi(args):
    """Train a DMI classifier backed by an MLP on the configured dataset.

    Prepares the train/test/val split from ``args['dataset']``, fits the
    classifier, and returns whatever ``DMIClassifier.fit`` returns.
    The module-level logger is attached only when ``args['seeds'] == 1``.
    """
    set_global_seeds(args['seed'])

    loader = DataLoader(args['dataset'], args)
    X_train, X_test, X_val, y_train, y_test, y_val = \
        loader.prepare_train_test_val(args)

    network = MLP(
        feature_dim=X_train.shape[-1],
        hidsizes=args['hidsize'],
        dropout=args['dropout'],
        outputs=2,
    )
    clf = DMIClassifier(model=network, learning_rate=args['lr'])

    return clf.fit(
        X_train, y_train, X_test, y_test,
        batchsize=args['batchsize'],
        episodes=args['episodes'],
        # logger only for single-seed runs
        logger=logger if args['seeds'] == 1 else None,
    )
def make_mujoco_env(env_id, seed):
    """Create a wrapped, monitored gym.Env for MuJoCo."""
    rank = MPI.COMM_WORLD.Get_rank()
    # Offset the seed per MPI rank so workers do not share RNG streams.
    set_global_seeds(seed + 10000 * rank)

    env = gym.make(env_id)
    logger.configure()
    monitor_path = os.path.join(logger.get_dir(), str(rank))
    env = Monitor(env, monitor_path)
    env.seed(seed)
    return env
def find_best_margin(args):
    """Score a margin perceptron on the validation set for each candidate margin.

    Trains one ``Perceptron`` per value in ``MARGINS`` and returns the list of
    validation-set scores, in the same order as ``MARGINS`` — the caller can
    then pick the best margin from it.

    NOTE(review): the original docstring claimed this returns
    ``best_margin / 0.1``; the code actually returns the per-margin
    validation scores, so the docstring has been corrected.
    """
    set_global_seeds(args['seed'])
    dataset = DataLoader(args['dataset'])
    X_train, X_test, X_val, y_train, y_test, y_val = \
        dataset.prepare_train_test_val(args)

    results = []
    for margin in MARGINS:
        model = Perceptron(feature_dim=X_train.shape[-1], margin=margin)
        model.fit(X_train, y_train)
        results.append(model.score(X_val, y_val))
    return results
def make_robotics_env(env_id, seed, rank=0):
    """Create a wrapped, monitored gym.Env for a robotics task."""
    set_global_seeds(seed)
    env = gym.make(env_id)
    # Flatten the goal-based observation dict into a single vector.
    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])

    log_dir = logger.get_dir()
    env = Monitor(
        env,
        log_dir and os.path.join(log_dir, str(rank)),
        info_keywords=('is_success', ),
    )
    env.seed(seed)
    return env
def __init__(self, expt_dir='experiment', load_dir=None, loss=None,
             batch_size=64, random_seed=None, checkpoint_every=100,
             print_every=100, use_gpu=False, learning_rate=0.001,
             max_grad_norm=1.0, eval_with_mask=True, scheduled_sampling=False,
             teacher_forcing_ratio=0.0, ddatt_loss_weight=0.0,
             ddattcls_loss_weight=0.0, att_scale_up=0.0):
    """Initialize trainer state, experiment directory, and TensorBoard writer.

    Args:
        expt_dir: experiment directory; made absolute relative to the
            current working directory and created if missing.
        load_dir: optional directory to resume/load from.
        loss: loss criterion; defaults to a fresh ``NLLLoss()`` per instance.
        batch_size: minibatch size.
        random_seed: if not None, seeds all global RNGs.
        checkpoint_every / print_every: step intervals for checkpointing
            and progress printing.
        use_gpu: whether training should run on GPU.
        learning_rate / max_grad_norm: optimizer hyperparameters.
        eval_with_mask, scheduled_sampling, teacher_forcing_ratio,
        ddatt_loss_weight, ddattcls_loss_weight, att_scale_up:
            training-behavior knobs, stored verbatim for later use.
    """
    self.random_seed = random_seed
    if random_seed is not None:
        set_global_seeds(random_seed)

    # BUGFIX: the original signature used ``loss=NLLLoss()``, which is
    # evaluated once at function-definition time, so every Trainer instance
    # silently shared the same criterion object. Use a None sentinel and
    # build a fresh instance per Trainer instead (backward compatible:
    # callers that pass a loss explicitly are unaffected).
    self.loss = NLLLoss() if loss is None else loss

    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    self.use_gpu = use_gpu
    self.learning_rate = learning_rate
    self.max_grad_norm = max_grad_norm
    self.eval_with_mask = eval_with_mask
    self.scheduled_sampling = scheduled_sampling
    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.ddatt_loss_weight = ddatt_loss_weight
    self.ddattcls_loss_weight = ddattcls_loss_weight
    self.att_scale_up = att_scale_up

    # Resolve and create the experiment directory.
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)

    self.load_dir = load_dir
    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
    self.writer = torch.utils.tensorboard.writer.SummaryWriter(
        log_dir=self.expt_dir)
def main():
    """Train a (dueling, prioritized-replay) DQN on an Atari environment."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID',
                        default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)
    args = parser.parse_args()

    logger.configure()
    set_global_seeds(args.seed)

    env = make_atari(args.env)
    env = Monitor(env, logger.get_dir())
    env = wrap_deepmind(env)

    # Convolutional torso + MLP head; optional dueling architecture.
    model = cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )
    fit(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )
    env.close()
    # Removed dead code: the original fetched ``tf.get_default_session()``
    # into a local and immediately ``del``-ed the local, which neither
    # closes nor resets the session — a no-op.
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """Create a wrapped, monitored SubprocVecEnv for Atari."""
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def make_env(rank):  # pylint: disable=C0111
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            log_dir = logger.get_dir()
            env = Monitor(env, log_dir and os.path.join(log_dir, str(rank)))
            return wrap_deepmind(env, **wrapper_kwargs)
        return _thunk

    set_global_seeds(seed)
    # One env-constructing thunk per subprocess, ranked from start_index.
    thunks = [make_env(rank)
              for rank in range(start_index, start_index + num_env)]
    return SubprocVecEnv(thunks)
def main():
    """Train a LAS (Listen-Attend-Spell) model from command-line config."""
    # load config
    parser = argparse.ArgumentParser(description='LAS Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        # NOTE(review): this assignment appears dead — config_save_dir is
        # unconditionally reassigned below; kept as-is pending confirmation
        # of the original intent (possibly meant to rewrite config['save']).
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # resume or not
    # BUGFIX: was ``type(config['load']) != type(None)`` — replaced with the
    # idiomatic (and equivalent) identity test against None.
    if config['load'] is not None:
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    # construct trainer
    t = Trainer(expt_dir=config['save'],
                load_dir=config['load'],
                batch_size=config['batch_size'],
                minibatch_partition=config['minibatch_partition'],
                checkpoint_every=config['checkpoint_every'],
                print_every=config['print_every'],
                learning_rate=config['learning_rate'],
                eval_with_mask=config['eval_with_mask'],
                scheduled_sampling=config['scheduled_sampling'],
                teacher_forcing_ratio=config['teacher_forcing_ratio'],
                use_gpu=config['use_gpu'],
                max_grad_norm=config['max_grad_norm'],
                max_count_no_improve=config['max_count_no_improve'],
                max_count_num_rollback=config['max_count_num_rollback'],
                keep_num=config['keep_num'],
                normalise_loss=config['normalise_loss'])

    # vocab
    path_vocab_src = config['path_vocab_src']

    # load train set
    train_path_src = config['train_path_src']
    train_acous_path = config['train_acous_path']
    train_set = Dataset(train_path_src,
                        path_vocab_src=path_vocab_src,
                        use_type=config['use_type'],
                        acous_path=train_acous_path,
                        seqrev=config['seqrev'],
                        acous_norm=config['acous_norm'],
                        acous_norm_path=config['acous_norm_path'],
                        max_seq_len=config['max_seq_len'],
                        batch_size=config['batch_size'],
                        acous_max_len=config['acous_max_len'],
                        use_gpu=config['use_gpu'],
                        logger=t.logger)
    vocab_size = len(train_set.vocab_src)

    # load dev set (optional)
    if config['dev_path_src']:
        dev_path_src = config['dev_path_src']
        dev_acous_path = config['dev_acous_path']
        dev_set = Dataset(dev_path_src,
                          path_vocab_src=path_vocab_src,
                          use_type=config['use_type'],
                          acous_path=dev_acous_path,
                          acous_norm_path=config['acous_norm_path'],
                          seqrev=config['seqrev'],
                          acous_norm=config['acous_norm'],
                          max_seq_len=config['max_seq_len'],
                          batch_size=config['batch_size'],
                          acous_max_len=config['acous_max_len'],
                          use_gpu=config['use_gpu'],
                          logger=t.logger)
    else:
        dev_set = None

    # construct model
    las_model = LAS(vocab_size,
                    embedding_size=config['embedding_size'],
                    acous_hidden_size=config['acous_hidden_size'],
                    acous_att_mode=config['acous_att_mode'],
                    hidden_size_dec=config['hidden_size_dec'],
                    hidden_size_shared=config['hidden_size_shared'],
                    num_unilstm_dec=config['num_unilstm_dec'],
                    # acous_dim=config['acous_dim'],
                    acous_norm=config['acous_norm'],
                    spec_aug=config['spec_aug'],
                    batch_norm=config['batch_norm'],
                    enc_mode=config['enc_mode'],
                    use_type=config['use_type'],
                    # embedding_dropout=config['embedding_dropout'],
                    dropout=config['dropout'],
                    residual=config['residual'],
                    batch_first=config['batch_first'],
                    max_seq_len=config['max_seq_len'],
                    load_embedding=config['load_embedding'],
                    word2id=train_set.src_word2id,
                    id2word=train_set.src_id2word,
                    use_gpu=config['use_gpu'])

    device = check_device(config['use_gpu'])
    t.logger.info('device:{}'.format(device))
    las_model = las_model.to(device=device)

    # run training
    las_model = t.train(
        train_set, las_model, num_epochs=config['num_epochs'], dev_set=dev_set)
def main(args):
    """Train a Transformer sentence-rewriting model on HDF5-stored data.

    Loads train/dev tensors from HDF5, then runs an epoch loop that
    optionally adds a coverage penalty, periodically evaluates on the dev
    subset, prints sample generations, logs NLL, and checkpoints the model.

    NOTE(review): this function calls ``.cuda()`` unconditionally, so it
    requires a CUDA device; it also resets ``ep_loss``/``num_batches`` after
    every checkpoint, so the logged "train nll" covers only the window since
    the previous checkpoint.
    """
    set_global_seeds(args.seed)
    device = args.device
    dtype = torch_dtypes.get(args.dtype)
    if 'cuda' in args.device:
        device_id = args.device_ids
        device = torch.device(device, device_id)
        torch.cuda.set_device(device_id)
    save_path = os.path.join(args.results_dir, args.model)
    if not os.path.exists(args.results_dir):
        os.mkdir(args.results_dir)
    log_file = open(os.path.join(args.results_dir, args.model + ".log"), "w")
    # Optimizer/model hyperparameters arrive as Python-literal strings.
    regime = literal_eval(args.optimization_config)
    model_config = literal_eval(args.model_config)
    vocab, rev_vocab = pickle.load(open(args.vocab, 'rb'))
    # Encoder and decoder share the same vocabulary size.
    model_config.setdefault('encoder', {})
    model_config.setdefault('decoder', {})
    model_config['encoder']['vocab_size'] = len(vocab)
    model_config['decoder']['vocab_size'] = len(vocab)
    model_config['vocab_size'] = model_config['decoder']['vocab_size']
    args.model_config = model_config
    model = transformer.Transformer(**model_config)
    model.to(device)
    criterion = nn.NLLLoss(ignore_index=PAD)
    params = model.parameters()
    optimizer = optim.Adam(params, lr=regime['lr'])
    # load data, word vocab, and parse vocab
    h5f_train = h5py.File(args.train_data, 'r')
    inp_train = h5f_train['inputs']
    out_train = h5f_train['outputs']
    input_lens_train = h5f_train['input_lens']
    output_lens_train = h5f_train['output_lens']
    inp_order_train = h5f_train['reordering_input']
    out_order_train = h5f_train['reordering_output']
    print("training samples: %d" % len(inp_train))
    log_file.write("training samples: %d \n" % len(inp_train))
    batch_size = args.batch_size
    # Dev evaluation uses only the first 500 examples.
    h5f_dev = h5py.File(args.dev_data, 'r')
    inp_dev = h5f_dev['inputs'][0:500]
    out_dev = h5f_dev['outputs'][0:500]
    input_lens_dev = h5f_dev['input_lens'][0:500]
    output_lens_dev = h5f_dev['output_lens'][0:500]
    inp_order_dev = h5f_dev['reordering_input'][0:500]
    include_coverage_loss = False
    include_reorder_information = args.include_reorder_information
    # (start, end) index pairs; [:-1] drops the final (possibly ragged) batch.
    train_minibatches = [(start, start + batch_size)
                         for start in range(0, inp_train.shape[0], batch_size)
                         ][:-1]
    dev_minibatches = [(start, start + batch_size)
                       for start in range(0, inp_dev.shape[0], batch_size)
                       ][:-1]
    random.shuffle(train_minibatches)
    log_file.write("num training batches: %d \n \n" % len(train_minibatches))
    coverage_coef = 0.5
    for ep in range(args.epochs):
        random.shuffle(train_minibatches)
        ep_loss = 0.
        start_time = time.time()
        num_batches = 0
        cov_loss = 0.
        for b_idx, (start, end) in enumerate(train_minibatches):
            inp = inp_train[start:end]
            out = out_train[start:end]
            in_len = input_lens_train[start:end]
            out_len = output_lens_train[start:end]
            in_order = inp_order_train[start:end]
            out_order = out_order_train[start:end]
            # chop input based on length of last instance (for encoder efficiency)
            max_in_len = int(np.amax(in_len))
            inp = inp[:, :max_in_len]
            in_order = in_order[:, :max_in_len]
            # compute max output length and chop output (for decoder efficiency)
            max_out_len = int(np.amax(out_len))
            out = out[:, :max_out_len]
            out_order = out_order[:, :max_out_len]
            in_order = np.asarray(in_order)
            # skip batches whose sentences are too short
            if max_in_len < args.min_sent_length:
                continue
            # Data augmentation: randomly swap source/target direction.
            swap = random.random() > 0.5
            if swap:
                inp, out = out, inp
                in_order, out_order = out_order, in_order
            # Teacher-forcing target: output shifted left by one, zero-padded.
            out_x = np.concatenate(
                [out[:, 1:], np.zeros((out.shape[0], 1))], axis=1)
            # torchify input
            curr_inp = Variable(
                torch.from_numpy(inp.astype('int32')).long().cuda())
            curr_out = Variable(
                torch.from_numpy(out.astype('int32')).long().cuda())
            curr_out_x = Variable(
                torch.from_numpy(out_x.astype('int32')).long().cuda())
            curr_in_order = Variable(
                torch.from_numpy(in_order.astype('int32')).long().cuda())
            # forward prop
            if include_reorder_information:
                preds, attention = model(curr_inp, curr_out, curr_in_order,
                                         get_attention=True)
            else:
                preds, attention = model(curr_inp, curr_out, None, None,
                                         get_attention=True)
            preds = preds.view(-1, len(vocab))
            preds = nn.functional.log_softmax(preds, -1)
            num_batches += 1
            # compute masked loss (PAD positions ignored by the criterion)
            loss = criterion(preds, curr_out_x.view(-1))
            if include_coverage_loss:
                coverage_loss = 0
                attention = attention[1]  # batch * max_out_len * max_in_len
                coverage = torch.zeros(
                    (attention.shape[0], attention.shape[2])).cuda()
                # Penalize attending to positions already covered earlier.
                for att_idx in range(0, attention.shape[1]):
                    if att_idx == 0:
                        c_t = coverage
                    else:
                        c_t = coverage + attention[:, att_idx - 1, :].squeeze(1)
                    x = torch.min(attention[:, att_idx, :].squeeze(1), c_t)
                    coverage_loss += torch.mean(torch.sum(x, 1))
                coverage_loss = coverage_loss / attention.shape[1]
                loss_total = loss + coverage_coef * coverage_loss
                cov_loss += coverage_loss.item()
            else:
                loss_total = loss
            optimizer.zero_grad()
            loss_total.backward(retain_graph=False)
            torch.nn.utils.clip_grad_norm_(params, args.grad_clip)
            optimizer.step()
            ep_loss += loss.data.item()
            # Periodic dev evaluation, sample printing, and checkpointing.
            if b_idx % (args.save_freq) == 0:
                # Pick one random dev batch whose samples will be printed.
                to_print = random.randint(0, len(dev_minibatches) - 1)
                dev_nll = 0.
                for b_dev_idx, (start, end) in enumerate(dev_minibatches):
                    inp = inp_dev[start:end]
                    out = out_dev[start:end]
                    in_len = input_lens_dev[start:end]
                    out_len = output_lens_dev[start:end]
                    in_order = inp_order_dev[start:end]
                    curr_bsz = inp.shape[0]
                    max_in_len = int(np.amax(in_len))
                    inp = inp[:, :max_in_len]
                    in_order = in_order[:, :max_in_len]
                    max_out_len = int(np.amax(out_len))
                    out = out[:, :max_out_len]
                    out_x = np.concatenate(
                        [out[:, 1:], np.zeros((out.shape[0], 1))], axis=1)
                    curr_inp = Variable(
                        torch.from_numpy(inp.astype('int32')).long().cuda())
                    curr_out = Variable(
                        torch.from_numpy(out.astype('int32')).long().cuda())
                    curr_out_x = Variable(
                        torch.from_numpy(out_x.astype('int32')).long().cuda())
                    curr_in_order = Variable(
                        torch.from_numpy(
                            in_order.astype('int32')).long().cuda())
                    if include_reorder_information:
                        preds, _ = model(curr_inp, curr_out, curr_in_order)
                    else:
                        preds, _ = model(curr_inp, curr_out, None, None)
                    preds = preds.view((-1, len(vocab)))
                    preds = nn.functional.log_softmax(preds, -1)
                    bos = Variable(
                        torch.from_numpy(
                            np.asarray([vocab["BOS"]
                                        ]).astype('int32')).long().cuda())
                    loss_dev = criterion(preds, curr_out_x.view(-1))
                    dev_nll += loss_dev.item()
                    preds = preds.view(curr_bsz, max_out_len,
                                       -1).cpu().data.numpy()
                    if b_dev_idx == to_print:
                        # Print up to 3 (input, gold, beam-search) samples.
                        for i in range(min(3, curr_bsz)):
                            print('input: %s' %
                                  ' '.join([rev_vocab[w]
                                            for (j, w) in enumerate(inp[i])
                                            if j < in_len[i]]))
                            print('gt output: %s' %
                                  ' '.join([rev_vocab[w]
                                            for (j, w) in enumerate(out[i])
                                            if j < out_len[i]]))
                            if include_reorder_information:
                                x = model.generate(
                                    curr_inp[i].unsqueeze(0), [list(bos)],
                                    curr_in_order[i].unsqueeze(0),
                                    beam_size=5,
                                    max_sequence_length=50)[0]
                            else:
                                x = model.generate(curr_inp[i].unsqueeze(0),
                                                   [list(bos)], None,
                                                   beam_size=5,
                                                   max_sequence_length=50)[0]
                            preds = [s.output for s in x]
                            print([
                                ' '.join(
                                    [rev_vocab[int(w.data.cpu())] for w in p])
                                for p in preds
                            ][0])
                            print("\n")
                print('dev nll per token: %f' %
                      (dev_nll / float(len(dev_minibatches))))
                print('done with batch %d / %d in epoch %d, loss: %f, cov loss: %f, time:%d'
                      % (b_idx, len(train_minibatches), ep,
                         ep_loss / num_batches, cov_loss / num_batches,
                         time.time() - start_time))
                print('train nll per token : %f \n' %
                      (float(ep_loss) / float(num_batches)))
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'ep_loss': ep_loss / num_batches,
                        'train_minibatches': train_minibatches,
                        'config_args': args
                    }, save_path)
                log_file.write("epoch : %d , batch : %d\n" % (ep, num_batches))
                log_file.write("dev nll: %f \n" %
                               (dev_nll / float(len(dev_minibatches))))
                log_file.write("train nll: %f \n \n" %
                               (float(ep_loss) / float(num_batches)))
                # Reset windowed statistics after each checkpoint.
                ep_loss = 0.
                num_batches = 0.
                start_time = time.time()
def fit(policy, env, seed, nsteps=20, nstack=4, total_timesteps=int(80e6),
        q_coef=0.5, ent_coef=0.01, max_grad_norm=10, lr=7e-4,
        lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
        log_interval=100, buffer_size=50000, replay_ratio=4,
        replay_start=10000, c=10.0, trust_region=True, alpha=0.99, delta=1):
    """Run the ACER training loop on a vectorized environment.

    Alternates one on-policy call per ``nbatch`` collected samples with a
    Poisson(``replay_ratio``)-distributed number of off-policy replay calls
    once the buffer holds at least ``replay_start`` samples.
    """
    print("Running Acer Simple")
    print(locals())

    tf.reset_default_graph()
    set_global_seeds(seed)
    # num_procs = len(env.remotes) # HACK

    model = Acer(
        policy=policy,
        observation_space=env.observation_space,
        action_space=env.action_space,
        nenvs=env.num_envs,
        nsteps=nsteps,
        nstack=nstack,
        ent_coef=ent_coef,
        q_coef=q_coef,
        gamma=gamma,
        max_grad_norm=max_grad_norm,
        lr=lr,
        rprop_alpha=rprop_alpha,
        rprop_epsilon=rprop_epsilon,
        total_timesteps=total_timesteps,
        lrschedule=lrschedule,
        c=c,
        trust_region=trust_region,
        alpha=alpha,
        delta=delta,
    )
    env_runner = Environment(env=env, model=model, nsteps=nsteps,
                             nstack=nstack)
    # Replay buffer only when experience replay is enabled.
    buffer = (Buffer(env=env, nsteps=nsteps, nstack=nstack, size=buffer_size)
              if replay_ratio > 0 else None)

    nbatch = env.num_envs * nsteps
    agent = AgentEnv(env_runner, model, buffer, log_interval)
    agent.tstart = time.time()

    # nbatch samples, 1 on_policy call and multiple off-policy calls
    for agent.steps in range(0, total_timesteps, nbatch):
        agent.call(on_policy=True)
        if replay_ratio > 0 and buffer.has_atleast(replay_start):
            for _ in range(np.random.poisson(replay_ratio)):
                agent.call(on_policy=False)

    # no simulation steps in this
    env.close()
def main():
    """Entry point: train TRPO on either an Atari or a MuJoCo environment."""
    parser = arg_parser()
    parser.add_argument('--platform', help='environment choice',
                        choices=['atari', 'mujoco'], default='atari')
    known_args, _ = parser.parse_known_args()
    platform = known_args.platform
    rank = MPI.COMM_WORLD.Get_rank()

    # atari
    if platform == 'atari':
        from bench import Monitor
        from utils.cmd import atari_arg_parser, make_atari, \
            wrap_deepmind
        from policies.nohashingcnn import CnnPolicy

        args = atari_arg_parser().parse_known_args()[0]
        # Only rank 0 writes full logs; other ranks log nothing.
        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])

        workerseed = args.seed + 10000 * rank
        set_global_seeds(workerseed)
        env = make_atari(args.env)
        env = Monitor(env,
                      logger.get_dir() and
                      os.path.join(logger.get_dir(), str(rank)))
        env.seed(workerseed)
        env = wrap_deepmind(env)
        env.seed(workerseed)

        model = TRPO(CnnPolicy, env.observation_space, env.action_space)
        sess = model.single_threaded_session().__enter__()
        # model.reset_graph_and_vars()
        model.init_vars()
        fit(model, env,
            timesteps_per_batch=512,
            max_kl=0.001,
            cg_iters=10,
            cg_damping=1e-3,
            max_timesteps=int(args.num_timesteps * 1.1),
            gamma=0.98,
            lam=1.0,
            vf_iters=3,
            vf_stepsize=1e-4,
            entcoeff=0.00)
        sess.close()
        env.close()

    # mujoco
    if platform == 'mujoco':
        from policies.ppo1mlp import PPO1Mlp
        from utils.cmd import make_mujoco_env, mujoco_arg_parser

        args = mujoco_arg_parser().parse_known_args()[0]
        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * rank
        env = make_mujoco_env(args.env, workerseed)

        def policy(name, observation_space, action_space):
            return PPO1Mlp(name, env.observation_space, env.action_space,
                           hid_size=32, num_hid_layers=2)

        model = TRPO(policy, env.observation_space, env.action_space)
        sess = model.single_threaded_session().__enter__()
        model.init_vars()
        fit(model, env,
            timesteps_per_batch=1024,
            max_kl=0.01,
            cg_iters=10,
            cg_damping=0.1,
            max_timesteps=args.num_timesteps,
            gamma=0.99,
            lam=0.98,
            vf_iters=5,
            vf_stepsize=1e-3)
        sess.close()
        env.close()
def main():
    """Train a Seq2seq model: parse config, build trainer/datasets/model, train."""
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        # NOTE(review): this assignment appears dead — config_save_dir is
        # unconditionally reassigned in both branches below; confirm the
        # original intent (possibly meant to update config['save']).
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # resume or not: a truthy config['load'] means "continue from checkpoint"
    if config['load']:
        resume = True
        print('resuming {} ...'.format(config['load']))
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        resume = False
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    # construct trainer
    t = Trainer(expt_dir=config['save'],
                load_dir=config['load'],
                batch_size=config['batch_size'],
                checkpoint_every=config['checkpoint_every'],
                print_every=config['print_every'],
                learning_rate=config['learning_rate'],
                eval_with_mask=config['eval_with_mask'],
                scheduled_sampling=config['scheduled_sampling'],
                teacher_forcing_ratio=config['teacher_forcing_ratio'],
                use_gpu=config['use_gpu'],
                max_grad_norm=config['max_grad_norm'],
                max_count_no_improve=config['max_count_no_improve'],
                max_count_num_rollback=config['max_count_num_rollback'],
                keep_num=config['keep_num'],
                normalise_loss=config['normalise_loss'],
                minibatch_split=config['minibatch_split'])

    # load train set
    train_path_src = config['train_path_src']
    train_path_tgt = config['train_path_tgt']
    path_vocab_src = config['path_vocab_src']
    path_vocab_tgt = config['path_vocab_tgt']
    train_set = Dataset(train_path_src, train_path_tgt,
                        path_vocab_src=path_vocab_src,
                        path_vocab_tgt=path_vocab_tgt,
                        seqrev=config['seqrev'],
                        max_seq_len=config['max_seq_len'],
                        batch_size=config['batch_size'],
                        use_gpu=config['use_gpu'],
                        logger=t.logger,
                        use_type=config['use_type'])
    vocab_size_enc = len(train_set.vocab_src)
    vocab_size_dec = len(train_set.vocab_tgt)

    # load dev set (optional — requires both src and tgt paths)
    if config['dev_path_src'] and config['dev_path_tgt']:
        dev_path_src = config['dev_path_src']
        dev_path_tgt = config['dev_path_tgt']
        dev_set = Dataset(dev_path_src, dev_path_tgt,
                          path_vocab_src=path_vocab_src,
                          path_vocab_tgt=path_vocab_tgt,
                          seqrev=config['seqrev'],
                          max_seq_len=config['max_seq_len'],
                          batch_size=config['batch_size'],
                          use_gpu=config['use_gpu'],
                          logger=t.logger,
                          use_type=config['use_type'])
    else:
        dev_set = None

    # construct model
    seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
                      share_embedder=config['share_embedder'],
                      embedding_size_enc=config['embedding_size_enc'],
                      embedding_size_dec=config['embedding_size_dec'],
                      embedding_dropout=config['embedding_dropout'],
                      hidden_size_enc=config['hidden_size_enc'],
                      num_bilstm_enc=config['num_bilstm_enc'],
                      num_unilstm_enc=config['num_unilstm_enc'],
                      hidden_size_dec=config['hidden_size_dec'],
                      num_unilstm_dec=config['num_unilstm_dec'],
                      hidden_size_att=config['hidden_size_att'],
                      hidden_size_shared=config['hidden_size_shared'],
                      dropout=config['dropout'],
                      residual=config['residual'],
                      batch_first=config['batch_first'],
                      max_seq_len=config['max_seq_len'],
                      load_embedding_src=config['load_embedding_src'],
                      load_embedding_tgt=config['load_embedding_tgt'],
                      src_word2id=train_set.src_word2id,
                      tgt_word2id=train_set.tgt_word2id,
                      src_id2word=train_set.src_id2word,
                      tgt_id2word=train_set.tgt_id2word,
                      att_mode=config['att_mode'])

    device = check_device(config['use_gpu'])
    t.logger.info('device:{}'.format(device))
    seq2seq = seq2seq.to(device=device)

    # run training
    seq2seq = t.train(train_set, seq2seq,
                      num_epochs=config['num_epochs'],
                      resume=resume,
                      dev_set=dev_set)
def fit(policy, env, seed, total_timesteps=int(40e6), gamma=0.99,
        log_interval=1, nprocs=32, nsteps=20, ent_coef=0.01, vf_coef=0.5,
        vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5, kfac_clip=0.001,
        save_interval=None, lrschedule='linear'):
    """Train an ACKTR agent on a vectorized discrete-action environment.

    Runs the K-FAC update loop for ``total_timesteps`` environment steps,
    logging tabular statistics every ``log_interval`` updates and, when
    ``save_interval`` is set and a log directory exists, saving checkpoints
    every ``save_interval`` updates.
    """
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    # NOTE(review): ``vf_coef=vf_fisher_coef`` passes the Fisher coefficient
    # (default 1.0) where the value-loss coefficient (default 0.5) looks
    # intended, leaving the ``vf_coef`` parameter unused — confirm against
    # the AcktrDiscrete signature before changing.
    model = AcktrDiscrete(policy, ob_space, ac_space, nenvs, total_timesteps,
                          nsteps=nsteps, ent_coef=ent_coef,
                          vf_coef=vf_fisher_coef, lr=lr,
                          max_grad_norm=max_grad_norm, kfac_clip=kfac_clip,
                          lrschedule=lrschedule)

    runner = Environment(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = nenvs * nsteps
    tstart = time.time()
    # Start the K-FAC queue-runner threads under a coordinator so they can
    # be stopped cleanly after the update loop.
    coord = tf.train.Coordinator()
    enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord,
                                                    start=True)
    for update in range(1, total_timesteps // nbatch + 1):
        obs, states, rewards, masks, actions, values = runner.run()
        policy_loss, value_loss, policy_entropy = model.train(
            obs, states, rewards, masks, actions, values)
        model.old_obs = obs
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            # Explained variance of value predictions vs. observed returns.
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()
        if save_interval and (update % save_interval == 0 or update == 1) \
                and logger.get_dir():
            savepath = os.path.join(logger.get_dir(),
                                    'checkpoint%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    coord.request_stop()
    coord.join(enqueue_threads)
    env.close()