Example #1
def run_nn_dmi(args):
    set_global_seeds(args['seed'])
    dataset = DataLoader(args['dataset'], args)
    X_train, X_test, X_val, y_train, y_test, y_val = dataset.prepare_train_test_val(
        args)
    mlp = MLP(
        feature_dim=X_train.shape[-1],
        hidsizes=args['hidsize'],
        dropout=args['dropout'],
        outputs=2,
    )
    classifier = DMIClassifier(
        model=mlp,
        learning_rate=args['lr'],
    )
    results = classifier.fit(
        X_train,
        y_train,
        X_test,
        y_test,
        batchsize=args['batchsize'],
        episodes=args['episodes'],
        logger=logger if args['seeds'] == 1 else None,
    )
    return results
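None of these examples define set_global_seeds itself. A minimal sketch of such a helper, modeled on the OpenAI Baselines function of the same name and extended with a PyTorch branch for the torch-based examples in this listing (the exact body is an assumption, not shown here):

import random

import numpy as np

def set_global_seeds(seed):
    # seed Python's and NumPy's global RNGs
    random.seed(seed)
    np.random.seed(seed)
    # seed optional frameworks only when they are installed
    try:
        import torch
        torch.manual_seed(seed)  # also seeds the CUDA RNGs on current PyTorch
    except ImportError:
        pass
    try:
        import tensorflow as tf
        tf.set_random_seed(seed)  # TF1 graph-level seed, as in Baselines
    except ImportError:
        pass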
Example #2
def make_mujoco_env(env_id, seed):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    set_global_seeds(seed + 10000 * rank)
    env = gym.make(env_id)
    logger.configure()
    env = Monitor(env, os.path.join(logger.get_dir(), str(rank)))
    env.seed(seed)
    return env
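Seeding with seed + 10000 * rank gives every MPI worker a distinct yet reproducible random stream; without the offset, all ranks would collect identical trajectories. A hypothetical launch (env id chosen for illustration):

# e.g. mpirun -np 4 python train.py
env = make_mujoco_env('Hopper-v2', seed=0)  # rank r ends up globally seeded with 10000 * r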
Example #3
def find_best_margin(args):
    """ return `best_margin / 0.1` """
    set_global_seeds(args['seed'])
    dataset = DataLoader(args['dataset'])
    X_train, X_test, X_val, y_train, y_test, y_val = dataset.prepare_train_test_val(
        args)

    results = []
    for margin in MARGINS:
        model = Perceptron(feature_dim=X_train.shape[-1], margin=margin)
        model.fit(X_train, y_train)
        results.append(model.score(X_val, y_val))
    return results
Example #4
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
    env = Monitor(env,
                  logger.get_dir()
                  and os.path.join(logger.get_dir(), str(rank)),
                  info_keywords=('is_success', ))
    env.seed(seed)
    return env
Example #5
    def __init__(self,
                 expt_dir='experiment',
                 load_dir=None,
                 loss=NLLLoss(),
                 batch_size=64,
                 random_seed=None,
                 checkpoint_every=100,
                 print_every=100,
                 use_gpu=False,
                 learning_rate=0.001,
                 max_grad_norm=1.0,
                 eval_with_mask=True,
                 scheduled_sampling=False,
                 teacher_forcing_ratio=0.0,
                 ddatt_loss_weight=0.0,
                 ddattcls_loss_weight=0.0,
                 att_scale_up=0.0):

        self.random_seed = random_seed
        if random_seed is not None:
            set_global_seeds(random_seed)

        self.loss = loss
        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every
        self.use_gpu = use_gpu
        self.learning_rate = learning_rate
        self.max_grad_norm = max_grad_norm
        self.eval_with_mask = eval_with_mask
        self.scheduled_sampling = scheduled_sampling
        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.ddatt_loss_weight = ddatt_loss_weight
        self.ddattcls_loss_weight = ddattcls_loss_weight
        self.att_scale_up = att_scale_up

        if not os.path.isabs(expt_dir):
            expt_dir = os.path.join(os.getcwd(), expt_dir)
        self.expt_dir = expt_dir
        if not os.path.exists(self.expt_dir):
            os.makedirs(self.expt_dir)
        self.load_dir = load_dir

        self.batch_size = batch_size
        self.logger = logging.getLogger(__name__)
        self.writer = torch.utils.tensorboard.writer.SummaryWriter(
            log_dir=self.expt_dir)
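A hypothetical instantiation of this trainer, only to show how the seeding hook is driven (all argument values are made up):

trainer = Trainer(
    expt_dir='experiments/run1',
    random_seed=1234,  # non-None, so __init__ calls set_global_seeds(1234)
    batch_size=32,
    use_gpu=torch.cuda.is_available(),
)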
Example #6
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID',
                        default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = Monitor(env, logger.get_dir())
    env = wrap_deepmind(env)
    model = cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )

    fit(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )

    env.close()
    sess = tf.get_default_session()
    del sess  # drop the local reference to the default TF session before exit
Example #7
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari.
    """
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def make_env(rank):  # pylint: disable=C0111
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            env = Monitor(
                env,
                logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            return wrap_deepmind(env, **wrapper_kwargs)

        return _thunk

    set_global_seeds(seed)
    return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
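Note the split: the parent process is seeded once via set_global_seeds(seed), while each subprocess env gets its own seed + rank. A hypothetical call:

venv = make_atari_env('PongNoFrameskip-v4', num_env=8, seed=42)
obs = venv.reset()  # batched observations from 8 distinctly seeded workers
venv.close()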
Example #8
File: train.py Project: EdieLu/LAS
def main():

	# load config
	parser = argparse.ArgumentParser(description='LAS Training')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# set random seed
	if config['random_seed'] is not None:
		set_global_seeds(config['random_seed'])

	# record config (normalise the save dir to an absolute path before use)
	if not os.path.isabs(config['save']):
		config['save'] = os.path.join(os.getcwd(), config['save'])
	if not os.path.exists(config['save']):
		os.makedirs(config['save'])

	# resume or not
	if config['load'] is not None:
		config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
	else:
		config_save_dir = os.path.join(config['save'], 'model.cfg')
	save_config(config, config_save_dir)

	# construct trainer
	t = Trainer(expt_dir=config['save'],
					load_dir=config['load'],
					batch_size=config['batch_size'],
					minibatch_partition=config['minibatch_partition'],
					checkpoint_every=config['checkpoint_every'],
					print_every=config['print_every'],
					learning_rate=config['learning_rate'],
					eval_with_mask=config['eval_with_mask'],
					scheduled_sampling=config['scheduled_sampling'],
					teacher_forcing_ratio=config['teacher_forcing_ratio'],
					use_gpu=config['use_gpu'],
					max_grad_norm=config['max_grad_norm'],
					max_count_no_improve=config['max_count_no_improve'],
					max_count_num_rollback=config['max_count_num_rollback'],
					keep_num=config['keep_num'],
					normalise_loss=config['normalise_loss'])

	# vocab
	path_vocab_src = config['path_vocab_src']

	# load train set
	train_path_src = config['train_path_src']
	train_acous_path = config['train_acous_path']
	train_set = Dataset(train_path_src, path_vocab_src=path_vocab_src,
		use_type=config['use_type'],
		acous_path=train_acous_path,
		seqrev=config['seqrev'],
		acous_norm=config['acous_norm'],
		acous_norm_path=config['acous_norm_path'],
		max_seq_len=config['max_seq_len'],
		batch_size=config['batch_size'],
		acous_max_len=config['acous_max_len'],
		use_gpu=config['use_gpu'],
		logger=t.logger)

	vocab_size = len(train_set.vocab_src)

	# load dev set
	if config['dev_path_src']:
		dev_path_src = config['dev_path_src']
		dev_acous_path = config['dev_acous_path']
		dev_set = Dataset(dev_path_src, path_vocab_src=path_vocab_src,
			use_type=config['use_type'],
			acous_path=dev_acous_path,
			acous_norm_path=config['acous_norm_path'],
			seqrev=config['seqrev'],
			acous_norm=config['acous_norm'],
			max_seq_len=config['max_seq_len'],
			batch_size=config['batch_size'],
			acous_max_len=config['acous_max_len'],
			use_gpu=config['use_gpu'],
			logger=t.logger)
	else:
		dev_set = None

	# construct model
	las_model = LAS(vocab_size,
					embedding_size=config['embedding_size'],
					acous_hidden_size=config['acous_hidden_size'],
					acous_att_mode=config['acous_att_mode'],
					hidden_size_dec=config['hidden_size_dec'],
					hidden_size_shared=config['hidden_size_shared'],
					num_unilstm_dec=config['num_unilstm_dec'],
					#
					acous_dim=config['acous_dim'],
					acous_norm=config['acous_norm'],
					spec_aug=config['spec_aug'],
					batch_norm=config['batch_norm'],
					enc_mode=config['enc_mode'],
					use_type=config['use_type'],
					#
					embedding_dropout=config['embedding_dropout'],
					dropout=config['dropout'],
					residual=config['residual'],
					batch_first=config['batch_first'],
					max_seq_len=config['max_seq_len'],
					load_embedding=config['load_embedding'],
					word2id=train_set.src_word2id,
					id2word=train_set.src_id2word,
					use_gpu=config['use_gpu'])

	device = check_device(config['use_gpu'])
	t.logger.info('device:{}'.format(device))
	las_model = las_model.to(device=device)

	# run training
	las_model = t.train(
		train_set, las_model, num_epochs=config['num_epochs'], dev_set=dev_set)
Example #9
def main(args):
    set_global_seeds(args.seed)

    device = args.device
    dtype = torch_dtypes.get(args.dtype)

    if 'cuda' in args.device:
        device_id = args.device_ids
        device = torch.device(device, device_id)
        torch.cuda.set_device(device_id)

    save_path = os.path.join(args.results_dir, args.model)
    if not os.path.exists(args.results_dir):
        os.mkdir(args.results_dir)
    log_file = open(os.path.join(args.results_dir, args.model + ".log"), "w")

    regime = literal_eval(args.optimization_config)
    model_config = literal_eval(args.model_config)

    vocab, rev_vocab = pickle.load(open(args.vocab, 'rb'))

    model_config.setdefault('encoder', {})
    model_config.setdefault('decoder', {})
    model_config['encoder']['vocab_size'] = len(vocab)
    model_config['decoder']['vocab_size'] = len(vocab)
    model_config['vocab_size'] = model_config['decoder']['vocab_size']
    args.model_config = model_config
    model = transformer.Transformer(**model_config)
    model.to(device)

    criterion = nn.NLLLoss(ignore_index=PAD)
    params = model.parameters()

    optimizer = optim.Adam(params, lr=regime['lr'])

    # load data, word vocab, and parse vocab
    h5f_train = h5py.File(args.train_data, 'r')
    inp_train = h5f_train['inputs']
    out_train = h5f_train['outputs']
    input_lens_train = h5f_train['input_lens']
    output_lens_train = h5f_train['output_lens']
    inp_order_train = h5f_train['reordering_input']
    out_order_train = h5f_train['reordering_output']
    print("training samples: %d" % len(inp_train))
    log_file.write("training samples: %d \n" % len(inp_train))

    batch_size = args.batch_size
    h5f_dev = h5py.File(args.dev_data, 'r')
    inp_dev = h5f_dev['inputs'][0:500]
    out_dev = h5f_dev['outputs'][0:500]
    input_lens_dev = h5f_dev['input_lens'][0:500]
    output_lens_dev = h5f_dev['output_lens'][0:500]
    inp_order_dev = h5f_dev['reordering_input'][0:500]

    include_coverage_loss = False
    include_reorder_information = args.include_reorder_information

    train_minibatches = [(start, start + batch_size)
                         for start in range(0, inp_train.shape[0], batch_size)
                         ][:-1]
    dev_minibatches = [(start, start + batch_size)
                       for start in range(0, inp_dev.shape[0], batch_size)
                       ][:-1]
    random.shuffle(train_minibatches)

    log_file.write("num training batches: %d \n \n" % len(train_minibatches))

    coverage_coef = 0.5
    for ep in range(args.epochs):
        random.shuffle(train_minibatches)
        ep_loss = 0.
        start_time = time.time()
        num_batches = 0
        cov_loss = 0.

        for b_idx, (start, end) in enumerate(train_minibatches):
            inp = inp_train[start:end]
            out = out_train[start:end]
            in_len = input_lens_train[start:end]
            out_len = output_lens_train[start:end]
            in_order = inp_order_train[start:end]
            out_order = out_order_train[start:end]

            # chop input based on length of last instance (for encoder efficiency)
            max_in_len = int(np.amax(in_len))
            inp = inp[:, :max_in_len]
            in_order = in_order[:, :max_in_len]

            # compute max output length and chop output (for decoder efficiency)
            max_out_len = int(np.amax(out_len))
            out = out[:, :max_out_len]
            out_order = out_order[:, :max_out_len]

            in_order = np.asarray(in_order)

            # sentences are too short
            if max_in_len < args.min_sent_length:
                continue

            swap = random.random() > 0.5
            if swap:
                inp, out = out, inp
                in_order, out_order = out_order, in_order

            out_x = np.concatenate(
                [out[:, 1:], np.zeros((out.shape[0], 1))], axis=1)

            # torchify input
            curr_inp = Variable(
                torch.from_numpy(inp.astype('int32')).long().cuda())
            curr_out = Variable(
                torch.from_numpy(out.astype('int32')).long().cuda())
            curr_out_x = Variable(
                torch.from_numpy(out_x.astype('int32')).long().cuda())
            curr_in_order = Variable(
                torch.from_numpy(in_order.astype('int32')).long().cuda())

            # forward prop
            if include_reorder_information:
                preds, attention = model(curr_inp,
                                         curr_out,
                                         curr_in_order,
                                         get_attention=True)
            else:
                preds, attention = model(curr_inp,
                                         curr_out,
                                         None,
                                         None,
                                         get_attention=True)
            preds = preds.view(-1, len(vocab))
            preds = nn.functional.log_softmax(preds, -1)

            num_batches += 1
            # compute masked loss
            loss = criterion(preds, curr_out_x.view(-1))

            if include_coverage_loss:
                coverage_loss = 0
                # attention[1] has shape (batch_size, max_out_len, max_in_len)
                attention = attention[1]
                coverage = torch.zeros(
                    (attention.shape[0], attention.shape[2])).cuda()
                for att_idx in range(0, attention.shape[1]):
                    if att_idx == 0:
                        c_t = coverage
                    else:
                        c_t = coverage + attention[:, att_idx - 1, :].squeeze(1)

                    x = torch.min(attention[:, att_idx, :].squeeze(1), c_t)
                    coverage_loss += torch.mean(torch.sum(x, 1))

                coverage_loss = coverage_loss / attention.shape[1]
                loss_total = loss + coverage_coef * coverage_loss
                cov_loss += coverage_loss.item()

            else:
                loss_total = loss

            optimizer.zero_grad()
            loss_total.backward(retain_graph=False)
            torch.nn.utils.clip_grad_norm_(params, args.grad_clip)
            optimizer.step()
            ep_loss += loss.data.item()

            if b_idx % (args.save_freq) == 0:

                to_print = random.randint(0, len(dev_minibatches) - 1)
                dev_nll = 0.
                for b_dev_idx, (start, end) in enumerate(dev_minibatches):

                    inp = inp_dev[start:end]
                    out = out_dev[start:end]
                    in_len = input_lens_dev[start:end]
                    out_len = output_lens_dev[start:end]
                    in_order = inp_order_dev[start:end]
                    curr_bsz = inp.shape[0]

                    max_in_len = int(np.amax(in_len))
                    inp = inp[:, :max_in_len]
                    in_order = in_order[:, :max_in_len]

                    max_out_len = int(np.amax(out_len))
                    out = out[:, :max_out_len]
                    out_x = np.concatenate(
                        [out[:, 1:], np.zeros((out.shape[0], 1))], axis=1)

                    curr_inp = Variable(
                        torch.from_numpy(inp.astype('int32')).long().cuda())
                    curr_out = Variable(
                        torch.from_numpy(out.astype('int32')).long().cuda())
                    curr_out_x = Variable(
                        torch.from_numpy(out_x.astype('int32')).long().cuda())
                    curr_in_order = Variable(
                        torch.from_numpy(
                            in_order.astype('int32')).long().cuda())

                    if include_reorder_information:
                        preds, _ = model(curr_inp, curr_out, curr_in_order)
                    else:
                        preds, _ = model(curr_inp, curr_out, None, None)
                    preds = preds.view((-1, len(vocab)))
                    preds = nn.functional.log_softmax(preds, -1)

                    bos = Variable(
                        torch.from_numpy(
                            np.asarray([vocab["BOS"]
                                        ]).astype('int32')).long().cuda())
                    loss_dev = criterion(preds, curr_out_x.view(-1))
                    dev_nll += loss_dev.item()

                    preds = preds.view(curr_bsz, max_out_len,
                                       -1).cpu().data.numpy()

                    if b_dev_idx == to_print:
                        for i in range(min(3, curr_bsz)):

                            print('input: %s' % ' '.join([rev_vocab[w] for (j, w) in enumerate(inp[i]) \
                                                          if j < in_len[i]]))
                            print('gt output: %s' % ' '.join([rev_vocab[w] for (j, w) in enumerate(out[i]) \
                                                              if j < out_len[i]]))

                            if include_reorder_information:
                                x = model.generate(
                                    curr_inp[i].unsqueeze(0), [list(bos)],
                                    curr_in_order[i].unsqueeze(0),
                                    beam_size=5,
                                    max_sequence_length=50)[0]
                            else:
                                x = model.generate(curr_inp[i].unsqueeze(0),
                                                   [list(bos)],
                                                   None,
                                                   beam_size=5,
                                                   max_sequence_length=50)[0]
                            preds = [s.output for s in x]
                            print([
                                ' '.join(
                                    [rev_vocab[int(w.data.cpu())] for w in p])
                                for p in preds
                            ][0])
                            print("\n")

                print('dev nll per token: %f' %
                      (dev_nll / float(len(dev_minibatches))))

                print('done with batch %d / %d in epoch %d, loss: %f, cov loss: %f, time:%d' \
                      % (b_idx, len(train_minibatches), ep,
                         ep_loss / num_batches, cov_loss / num_batches, time.time() - start_time))
                print('train nll per token : %f \n' %
                      (float(ep_loss) / float(num_batches)))

                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'ep_loss': ep_loss / num_batches,
                        'train_minibatches': train_minibatches,
                        'config_args': args
                    }, save_path)

                log_file.write("epoch : %d , batch : %d\n" % (ep, num_batches))
                log_file.write("dev nll: %f \n" %
                               (dev_nll / float(len(dev_minibatches))))
                log_file.write("train nll: %f \n \n" %
                               (float(ep_loss) / float(num_batches)))

                ep_loss = 0.
                num_batches = 0.
                start_time = time.time()
Example #10
def fit(
        policy,
        env,
        seed,
        nsteps=20,
        nstack=4,
        total_timesteps=int(80e6),
        q_coef=0.5,
        ent_coef=0.01,
        max_grad_norm=10,
        lr=7e-4,
        lrschedule='linear',
        rprop_epsilon=1e-5,
        rprop_alpha=0.99,
        gamma=0.99,
        log_interval=100,
        buffer_size=50000,
        replay_ratio=4,
        replay_start=10000,
        c=10.0,
        trust_region=True,
        alpha=0.99,
        delta=1
):
    print("Running Acer Simple")
    print(locals())
    tf.reset_default_graph()
    set_global_seeds(seed)

    model = Acer(
        policy=policy,
        observation_space=env.observation_space,
        action_space=env.action_space,
        nenvs=env.num_envs,
        nsteps=nsteps,
        nstack=nstack,
        ent_coef=ent_coef,
        q_coef=q_coef,
        gamma=gamma,
        max_grad_norm=max_grad_norm,
        lr=lr,
        rprop_alpha=rprop_alpha,
        rprop_epsilon=rprop_epsilon,
        total_timesteps=total_timesteps,
        lrschedule=lrschedule,
        c=c,
        trust_region=trust_region,
        alpha=alpha,
        delta=delta
    )

    env_runner = Environment(env=env, model=model, nsteps=nsteps, nstack=nstack)
    if replay_ratio > 0:
        buffer = Buffer(env=env, nsteps=nsteps, nstack=nstack, size=buffer_size)
    else:
        buffer = None
    nbatch = env.num_envs * nsteps
    agent = AgentEnv(env_runner, model, buffer, log_interval)
    agent.tstart = time.time()
    # nbatch samples, 1 on_policy call and multiple off-policy calls
    for agent.steps in range(0, total_timesteps, nbatch):
        agent.call(on_policy=True)
        if replay_ratio > 0 and buffer.has_atleast(replay_start):
            n = np.random.poisson(replay_ratio)
            for _ in range(n):
                agent.call(on_policy=False)  # no simulation steps in this

    env.close()
Example #11
def main():
    parser = arg_parser()
    parser.add_argument('--platform',
                        help='environment choice',
                        choices=['atari', 'mujoco'],
                        default='atari')

    platform_args, environ_args = parser.parse_known_args()
    platform = platform_args.platform

    rank = MPI.COMM_WORLD.Get_rank()

    # atari
    if platform == 'atari':
        from bench import Monitor
        from utils.cmd import atari_arg_parser, make_atari, \
            wrap_deepmind
        from policies.nohashingcnn import CnnPolicy

        args = atari_arg_parser().parse_known_args()[0]
        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])

        workerseed = args.seed + 10000 * rank
        set_global_seeds(workerseed)
        env = make_atari(args.env)

        env = Monitor(
            env,
            logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
        env = wrap_deepmind(env)
        env.seed(workerseed)

        model = TRPO(CnnPolicy, env.observation_space, env.action_space)
        sess = model.single_threaded_session().__enter__()
        # model.reset_graph_and_vars()
        model.init_vars()

        fit(model,
            env,
            timesteps_per_batch=512,
            max_kl=0.001,
            cg_iters=10,
            cg_damping=1e-3,
            max_timesteps=int(args.num_timesteps * 1.1),
            gamma=0.98,
            lam=1.0,
            vf_iters=3,
            vf_stepsize=1e-4,
            entcoeff=0.00)
        sess.close()
        env.close()

    # mujoco
    if platform == 'mujoco':
        from policies.ppo1mlp import PPO1Mlp
        from utils.cmd import make_mujoco_env, mujoco_arg_parser
        args = mujoco_arg_parser().parse_known_args()[0]

        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])
            logger.set_level(logger.DISABLED)

        workerseed = args.seed + 10000 * rank

        env = make_mujoco_env(args.env, workerseed)

        def policy(name, observation_space, action_space):
            return PPO1Mlp(name,
                           observation_space,
                           action_space,
                           hid_size=32,
                           num_hid_layers=2)

        model = TRPO(policy, env.observation_space, env.action_space)
        sess = model.single_threaded_session().__enter__()
        model.init_vars()

        fit(model,
            env,
            timesteps_per_batch=1024,
            max_kl=0.01,
            cg_iters=10,
            cg_damping=0.1,
            max_timesteps=args.num_timesteps,
            gamma=0.99,
            lam=0.98,
            vf_iters=5,
            vf_stepsize=1e-3)
        sess.close()
        env.close()
Example #12
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Training')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# set random seed
	if config['random_seed'] is not None:
		set_global_seeds(config['random_seed'])

	# record config (normalise the save dir to an absolute path before use)
	if not os.path.isabs(config['save']):
		config['save'] = os.path.join(os.getcwd(), config['save'])
	if not os.path.exists(config['save']):
		os.makedirs(config['save'])

	# resume or not
	if config['load']:
		resume = True
		print('resuming {} ...'.format(config['load']))
		config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
	else:
		resume = False
		config_save_dir = os.path.join(config['save'], 'model.cfg')
	save_config(config, config_save_dir)

	# construct trainer
	t = Trainer(expt_dir=config['save'],
					load_dir=config['load'],
					batch_size=config['batch_size'],
					checkpoint_every=config['checkpoint_every'],
					print_every=config['print_every'],
					learning_rate=config['learning_rate'],
					eval_with_mask=config['eval_with_mask'],
					scheduled_sampling=config['scheduled_sampling'],
					teacher_forcing_ratio=config['teacher_forcing_ratio'],
					use_gpu=config['use_gpu'],
					max_grad_norm=config['max_grad_norm'],
					max_count_no_improve=config['max_count_no_improve'],
					max_count_num_rollback=config['max_count_num_rollback'],
					keep_num=config['keep_num'],
					normalise_loss=config['normalise_loss'],
					minibatch_split=config['minibatch_split'])

	# load train set
	train_path_src = config['train_path_src']
	train_path_tgt = config['train_path_tgt']
	path_vocab_src = config['path_vocab_src']
	path_vocab_tgt = config['path_vocab_tgt']
	train_set = Dataset(train_path_src, train_path_tgt,
		path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
		seqrev=config['seqrev'],
		max_seq_len=config['max_seq_len'],
		batch_size=config['batch_size'],
		use_gpu=config['use_gpu'],
		logger=t.logger,
		use_type=config['use_type'])

	vocab_size_enc = len(train_set.vocab_src)
	vocab_size_dec = len(train_set.vocab_tgt)

	# load dev set
	if config['dev_path_src'] and config['dev_path_tgt']:
		dev_path_src = config['dev_path_src']
		dev_path_tgt = config['dev_path_tgt']
		dev_set = Dataset(dev_path_src, dev_path_tgt,
			path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
			seqrev=config['seqrev'],
			max_seq_len=config['max_seq_len'],
			batch_size=config['batch_size'],
			use_gpu=config['use_gpu'],
			logger=t.logger,
			use_type=config['use_type'])
	else:
		dev_set = None

	# construct model
	seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
					share_embedder=config['share_embedder'],
					embedding_size_enc=config['embedding_size_enc'],
					embedding_size_dec=config['embedding_size_dec'],
					embedding_dropout=config['embedding_dropout'],
					hidden_size_enc=config['hidden_size_enc'],
					num_bilstm_enc=config['num_bilstm_enc'],
					num_unilstm_enc=config['num_unilstm_enc'],
					hidden_size_dec=config['hidden_size_dec'],
					num_unilstm_dec=config['num_unilstm_dec'],
					hidden_size_att=config['hidden_size_att'],
					hidden_size_shared=config['hidden_size_shared'],
					dropout=config['dropout'],
					residual=config['residual'],
					batch_first=config['batch_first'],
					max_seq_len=config['max_seq_len'],
					load_embedding_src=config['load_embedding_src'],
					load_embedding_tgt=config['load_embedding_tgt'],
					src_word2id=train_set.src_word2id,
					tgt_word2id=train_set.tgt_word2id,
					src_id2word=train_set.src_id2word,
					tgt_id2word=train_set.tgt_id2word,
					att_mode=config['att_mode'])


	device = check_device(config['use_gpu'])
	t.logger.info('device:{}'.format(device))
	seq2seq = seq2seq.to(device=device)

	# run training
	seq2seq = t.train(train_set, seq2seq,
		num_epochs=config['num_epochs'], resume=resume, dev_set=dev_set)
Example #13
def fit(policy,
        env,
        seed,
        total_timesteps=int(40e6),
        gamma=0.99,
        log_interval=1,
        nprocs=32,
        nsteps=20,
        ent_coef=0.01,
        vf_coef=0.5,
        vf_fisher_coef=1.0,
        lr=0.25,
        max_grad_norm=0.5,
        kfac_clip=0.001,
        save_interval=None,
        lrschedule='linear'):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    model = AcktrDiscrete(policy,
                          ob_space,
                          ac_space,
                          nenvs,
                          total_timesteps,
                          nsteps=nsteps,
                          ent_coef=ent_coef,
                          vf_coef=vf_fisher_coef,  # note: vf_fisher_coef is passed as vf_coef; the vf_coef argument above goes unused
                          lr=lr,
                          max_grad_norm=max_grad_norm,
                          kfac_clip=kfac_clip,
                          lrschedule=lrschedule)

    runner = Environment(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = nenvs * nsteps
    tstart = time.time()
    coord = tf.train.Coordinator()
    enqueue_threads = model.q_runner.create_threads(model.sess,
                                                    coord=coord,
                                                    start=True)
    for update in range(1, total_timesteps // nbatch + 1):
        obs, states, rewards, masks, actions, values = runner.run()
        policy_loss, value_loss, policy_entropy = model.train(
            obs, states, rewards, masks, actions, values)
        model.old_obs = obs
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()

        if save_interval and (update % save_interval == 0 or update == 1) \
           and logger.get_dir():
            savepath = os.path.join(logger.get_dir(),
                                    'checkpoint%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    coord.request_stop()
    coord.join(enqueue_threads)
    env.close()
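A hypothetical driver for this fit, reusing the vectorized-env helper from Example #7 and the CnnPolicy imported in Example #11 (this pairing is an assumption, not shown in the listing; fit closes the env itself):

env = make_atari_env('BreakoutNoFrameskip-v4', num_env=32, seed=0)
fit(CnnPolicy, env, seed=0, total_timesteps=int(1e6), save_interval=500)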