Example #1
def cli_main():
    task_parser = argparse.ArgumentParser(allow_abbrev=False)
    task_parser.add_argument('--task',
                             type=str,
                             default='bert',
                             choices=['bert', 'mnist'])
    task_parser.add_argument('--optimizer',
                             type=str,
                             default='adam',
                             choices=['adam', 'adadelta'])
    task_parser.add_argument('--lr-scheduler',
                             type=str,
                             default='PolynomialDecayScheduler',
                             choices=['PolynomialDecayScheduler'])

    pre_args, s = task_parser.parse_known_args()

    parser = options.get_training_parser(
        task=pre_args.task,
        optimizer=pre_args.optimizer,
        lr_scheduler=pre_args.lr_scheduler,
    )
    args = options.parse_args_and_arch(parser, s)

    if args.distributed_init_method is not None:
        assert args.distributed_gpus <= torch.cuda.device_count()

        if args.distributed_gpus > 1 and not args.distributed_no_spawn:  # by default, run this logic
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=args.distributed_gpus,
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(
            port=port)
        args.distributed_rank = None  # set based on device id
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)
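Example #1 above pre-parses only the task/optimizer/scheduler flags with parse_known_args() and then hands the leftover argv to the full training parser. Below is a minimal, self-contained sketch of that two-stage argparse pattern; build_full_parser and two_stage_parse are hypothetical names standing in for options.get_training_parser / options.parse_args_and_arch.

import argparse

def build_full_parser(task):
    # Hypothetical second-stage parser whose options depend on the pre-parsed task
    # (stands in for options.get_training_parser above).
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--task', default=task)
    if task == 'bert':
        parser.add_argument('--max-seq-len', type=int, default=128)
    else:
        parser.add_argument('--image-size', type=int, default=28)
    parser.add_argument('--lr', type=float, default=1e-3)
    return parser

def two_stage_parse(argv=None):
    # Stage 1: parse only the task selector; leave every other flag untouched.
    pre_parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    pre_parser.add_argument('--task', default='bert', choices=['bert', 'mnist'])
    pre_args, remaining = pre_parser.parse_known_args(argv)
    # Stage 2: build the task-specific parser and parse the leftover arguments.
    return build_full_parser(pre_args.task).parse_args(remaining)

# e.g. two_stage_parse(['--task', 'mnist', '--lr', '0.01'])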
Example #2
def cli_train():
    # Create default parser from CLI
    parser = options.get_training_parser(default_env='gym_env',
                                         default_model='ppo',
                                         default_agent='default_agent',
                                         )
    # Modify this function when customizing new arguments
    args = options.parse_custom_args(parser)
    print(args)

    ############# test #############
    # Comment out this block when running.
    # Make sure the number of assigned GPUs is no less than the number of workers.
    args.gpu = "0,2,3"
    args.num_workers = 3
    ################################

    if args.gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not args.num_gpus:
        args.num_gpus = len(args.gpu.split(','))
    assert args.num_gpus
    assert args.num_workers <= args.num_gpus

    # Init ray
    ray.init(address=args.address, redis_address=args.redis_address,
             num_cpus=args.num_cpus, num_gpus=args.num_gpus,
             include_webui=False, ignore_reinit_error=True)
    assert ray.is_initialized()

    # Define agents / models
    agent_cls = ray.remote(setup_agent(args))
    model_cls = setup_model(args)
    local_model = model_cls(args, )
    remote_model_cls = ray.remote(num_gpus=1)(model_cls)

    # Create models / agents / workers
    all_models = [remote_model_cls.remote(args, ) for _ in range(args.num_workers)]
    all_agents = [agent_cls.remote(args, model) for model in all_models]
    all_workers = [DefaultWorker.remote(args, all_agents[i], all_models[i]) for i in range(args.num_workers)]

    # Get initial weights of local networks
    weights = local_model.get_weights()
    # put the weights in the object store
    weights = [ray.put(w) for w in weights]

    print(f'Start running [workers] ...')
    # Get initial observations
    all_states = [worker.start.remote() for worker in all_workers]
    while True:
        # Run agent.step
        all_actions = [agent.step.remote(states) for agent, states in zip(all_agents, all_states)]
        # Run worker.step
        all_states = [worker.step.remote(actions) for worker, actions in zip(all_workers, all_actions)]
        # Compute gradients from agents given unified weights
        rets = [agent.fetch_grads.remote(weights) for agent in all_agents]
        rets = [ret for ret in ray.get(rets) if ret is not None]
        # Check if it is trained (trajectories satisfy training batch size)
        # If so, grads of actors and critics will be returned
        # otherwise return None.
        if len(rets) > 0:
            # Collect actor / critic gradients
            a_grads, c_grads = list(zip(*rets))
            a_grads = ray.get(list(a_grads))
            c_grads = ray.get(list(c_grads))
            # Take the mean of all gradients
            avg_a_grads = [sum(g[i] for g in a_grads) / len(a_grads) for i in range(len(a_grads[0]))]
            avg_c_grads = [sum(g[i] for g in c_grads) / len(c_grads) for i in range(len(c_grads[0]))]
            # Update local networks
            a_feed_dict = {grad[0]: m_grad for (grad, m_grad) in zip(local_model.a_grads, avg_a_grads)}
            c_feed_dict = {grad[0]: m_grad for (grad, m_grad) in zip(local_model.c_grads, avg_c_grads)}
            local_model.sess.run(local_model.atrain_op, a_feed_dict)
            local_model.sess.run(local_model.ctrain_op, c_feed_dict)
            # take the updated weights
            weights = local_model.get_weights()
            weights = [ray.put(w) for w in weights]
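The training loop above shares the current weights through the Ray object store and averages per-parameter gradients returned by the remote agents. Below is a simplified sketch of that pattern, assuming the whole weight list is put into the store as a single object; Worker and average_remote_grads are hypothetical stand-ins for the agents and loop above.

import numpy as np
import ray

@ray.remote
class Worker:
    # Hypothetical stand-in for a remote agent: returns a fake per-parameter
    # "gradient" for the weights it receives.
    def compute_grads(self, weights):
        # Ray de-references ObjectRefs passed as top-level task arguments,
        # so `weights` is already the list that was ray.put() into the store.
        return [np.ones_like(w) for w in weights]

def average_remote_grads(num_workers=2):
    ray.init(ignore_reinit_error=True)
    weights = [np.zeros((2, 2)), np.zeros(3)]
    weights_ref = ray.put(weights)  # share the current weights via the object store
    workers = [Worker.remote() for _ in range(num_workers)]
    grads = ray.get([w.compute_grads.remote(weights_ref) for w in workers])
    # Average each parameter's gradient across workers, as in the loop above.
    return [sum(g[i] for g in grads) / len(grads) for i in range(len(grads[0]))]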
Example #3
def prepare():
    input_args = '--train toys/reverse/train.src toys/reverse/train.trg ' \
                 '--dev toys/reverse/dev.src toys/reverse/dev.trg ' \
                 '--vocab toys/reverse/vocab.src toys/reverse/vocab.trg ' \
                 '--model runs/test ' \
                 '--hidden-size 32 ' \
                 '--max-epoch 2 ' \
                 '--eval-steps 1 ' \
                 '--shuffle -1 ' \
                 '--eval-batch-size 1 ' \
                 '--save-checkpoint-steps 1 ' \
                 '--arch multi-head-rnn ' \
                 '--lr 0.001 '.split()
    input_args = None  # overrides the toy args above; comment this line out to run the toy example
    parser = options.get_training_parser()
    # 1. Parse static command-line args and get default args
    cli_args, default_args, unknown_args = options.parse_static_args(
        parser, input_args=input_args)

    # 2. Load config args
    config_args = None
    if cli_args.config:
        args_list = []
        for config_file in cli_args.config:
            with open(config_file) as r:
                args_list.append(json.loads(r.read()))
        args_list = [argparse.Namespace(**item) for item in args_list]
        config_args = args_list[0]

        for args_ in args_list[1:]:
            config_args = override(config_args, args_, False)
    # 3. Load model args
    if cli_args.scratch and os.path.exists(cli_args.model):
        shutil.rmtree(cli_args.model)

    try:
        ckp_path = cli_args.model
        if os.path.isdir(cli_args.model):
            ckp_path = Loader.get_latest(cli_args.model)[1]
        state_dict = Loader.load_state(ckp_path)

    except FileNotFoundError:
        state_dict = {}
    resume = len(state_dict) > 0

    model_args = state_dict.get('args')
    # 4. Override by priorities.
    # cli_args > config_args > model_args > default_args
    args = override(config_args, cli_args, False)
    args = override(model_args, args, False)
    args = override(default_args, args, False)
    # 5. Parse a second time to get complete cli args
    cli_args, default_args = options.parse_dynamic_args(parser,
                                                        input_args=input_args,
                                                        parsed_args=args)
    # 6. Retain valid keys of args
    valid_keys = set(default_args.__dict__.keys())
    # 7. Override again
    args = override(args, cli_args, False)
    args = override(default_args, args, False)
    # 8. Remove invalid keys
    stripped_args = argparse.Namespace()
    for k in valid_keys:
        setattr(stripped_args, k, getattr(args, k))

    config_name = os.path.join(args.model, 'config.json')

    if not os.path.exists(args.model):
        os.makedirs(args.model)
    args.arch = 'mulsrc'
    if model_args != args or not os.path.exists(config_name):
        with open(config_name, 'w') as w:
            w.write(json.dumps(args.__dict__, indent=4, sort_keys=True))

    if len(args.train) == 1:
        assert args.langs and len(args.langs) == 2, args.langs

        prefix = args.train[0]
        args.train = [f'{prefix}.{args.langs[0]}', f'{prefix}.{args.langs[1]}']

    if len(args.dev) == 1:
        assert args.langs and len(args.langs) == 2, args.langs

        prefix = args.dev[0]
        args.dev = [f'{prefix}.{args.langs[0]}', f'{prefix}.{args.langs[1]}']

    if len(args.vocab) == 1:
        assert args.langs and len(args.langs) == 2, args.langs

        prefix = args.vocab[0]
        args.vocab = [f'{prefix}.{args.langs[0]}', f'{prefix}.{args.langs[1]}']

    return args, state_dict, resume
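prepare() merges argument namespaces in the priority order cli_args > config_args > model_args > default_args via an override() helper whose implementation is not shown here. Below is a minimal sketch of one plausible version, assuming the third argument relaxes how None values are treated; it is a hypothetical reimplementation, not the repository's code.

import argparse

def override(base, new, strict=True):
    # Attributes from `new` win over `base`; with strict=False, None values in
    # `new` are skipped so they do not clobber settings already present in `base`.
    if base is None:
        return new
    if new is None:
        return base
    merged = argparse.Namespace(**vars(base))
    for key, value in vars(new).items():
        if value is None and not strict:
            continue
        setattr(merged, key, value)
    return merged

# Priority chain used in prepare(): cli_args > config_args > model_args > default_args
# args = override(config_args, cli_args, False)
# args = override(model_args, args, False)
# args = override(default_args, args, False)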
Example #4
    if 'coherence' in rl_reward:  # reconstructed guard; mirrors the commented-out expressiveness block below
        start = time.time()
        all_coherence = get_nsp(all_sent_pairs, model.bert_tokenizer, model.bert_nsp)
        score['coherence'] = np.mean(all_coherence)
        print('Coherence', np.mean(all_coherence), ' using time ', time.time() - start)
    # if 'expressiveness' in rl_reward:
    #     start = time.time()
    #     all_dBLEU = [0.0]
    #     # all_dBLEU = model.compute_expressiveness_reward(all_db_hyps[0:1000])
    #     score['expressiveness'] = np.mean(all_dBLEU)
    #     print('Expressiveness', score['expressiveness'], ' using time ', time.time() - start)

    return score


if __name__ == '__main__':
    parser = get_training_parser()
    parser = VistDataLoader.add_args(parser)
    parser = VistModel.add_args(parser)
    args = parser.parse_args()
    print('args', args)
    
    # seed the RNG
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    np.random.seed(int(args.seed * 13 / 7))

    if args.save_dir is not None:
        args.save_model_to = args.save_dir + '/model'
        args.save_decode_file = args.save_dir + '/decode-len100'
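Example #4 assembles the final parser by letting each component register its own options through a static add_args method. Below is a minimal sketch of that composition pattern; build_base_parser and ToyModel are hypothetical stand-ins for get_training_parser and VistModel / VistDataLoader.

import argparse

def build_base_parser():
    # Stand-in for the project's get_training_parser().
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--save-dir', default=None)
    return parser

class ToyModel:
    # Hypothetical component: registers its own options on the shared parser,
    # mirroring VistModel.add_args / VistDataLoader.add_args above.
    @staticmethod
    def add_args(parser):
        group = parser.add_argument_group('model')
        group.add_argument('--hidden-size', type=int, default=256)
        group.add_argument('--dropout', type=float, default=0.1)
        return parser

# Compose the parser from independent components, then parse.
parser = ToyModel.add_args(build_base_parser())
args = parser.parse_args(['--hidden-size', '512'])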