# Example #1
def run(args):
    """Entry point: build the A3C model/optimizer, then train and/or test.

    Args:
        args: parsed CLI namespace. Fields read here: lr, alpha, weight_decay,
            momentum, load_fp, save_fp, train, test, num_procs, verbose.

    Side effects: spawns `args.num_procs` worker processes for training and
    optionally saves a checkpoint to `args.save_fp`.
    """
    device = torch.device("cpu")
    env = gym.make('SpaceInvaders-v0')
    action_size = env.action_space.n

    # NOTE(review): the input shape is hard-coded to 4 stacked 84x84 frames
    # rather than derived from env.observation_space — confirm this matches
    # the preprocessing done in the worker's environment wrapper.
    model = ActorCritic([1, 4, 84, 84], action_size).to(device)
    opt = SharedRMSprop(model.parameters(),
                        lr=args.lr,
                        alpha=args.alpha,
                        eps=1e-8,
                        weight_decay=args.weight_decay,
                        momentum=args.momentum,
                        centered=False)
    opt_lock = mp.Lock()
    scheduler = LRScheduler(args)

    if args.load_fp:
        checkpoint = torch.load(args.load_fp)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Checkpoints written by this function omit the optimizer state (see
        # the save below), so only restore it when actually present —
        # otherwise reloading our own checkpoint would raise KeyError.
        if 'optimizer_state_dict' in checkpoint:
            opt.load_state_dict(checkpoint['optimizer_state_dict'])

    if args.train:
        start = time.time()

        # Place the model in shared memory so all worker processes update
        # the same parameters (Hogwild-style training).
        model.share_memory()
        model.train()

        # Shared scalar counters visible to every worker.
        step_counter, max_reward, ma_reward, ma_loss = [
            mp.Value('d', 0.0) for _ in range(4)
        ]

        processes = []
        if args.num_procs == -1:
            args.num_procs = mp.cpu_count()
        for rank in range(args.num_procs):
            p = mp.Process(target=train,
                           args=(rank, args, device, model, opt, opt_lock,
                                 scheduler, step_counter, max_reward,
                                 ma_reward, ma_loss))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

        if args.verbose > 0:
            print(f"Seconds taken: {time.time() - start}")
        if args.save_fp:
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    # Optimizer state deliberately not saved; the load path
                    # above tolerates its absence.
                },
                args.save_fp)

    if args.test:
        model.eval()
        test(args, device, model)
# Example #2
        shared_model.load_state_dict(saved_state)
    # Expose the model's parameters in shared memory so every worker
    # process updates the same tensors (Hogwild-style A3C training).
    shared_model.share_memory()

    if args.shared_optimizer:
        # NOTE(review): these are independent `if`s, not `elif` — and if
        # args.optimizer is neither 'RMSprop' nor 'Adam', `optimizer` is
        # never bound and `optimizer.share_memory()` below raises
        # NameError. Confirm the CLI restricts the choices to these two.
        if args.optimizer == 'RMSprop':
            optimizer = SharedRMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = SharedAdam(shared_model.parameters(),
                                   lr=args.lr,
                                   amsgrad=args.amsgrad)
        if args.load:
            # Restore a previously saved optimizer state; map_location
            # forces all tensors onto CPU regardless of where they were
            # saved.
            saved_state = torch.load('{0}{1}1.torch'.format(
                args.load_model_dir, args.env),
                                     map_location=lambda storage, loc: storage)
            print("load state dict")
            optimizer.load_state_dict(saved_state)
            print("loaded optimizer")
        # Share the optimizer's state tensors across worker processes.
        optimizer.share_memory()
    else:
        # Each worker will construct its own (non-shared) optimizer.
        optimizer = None

    processes = []

    # One dedicated evaluation process...
    p = mp.Process(target=testocpg, args=(args, shared_model, env_conf))
    p.start()
    processes.append(p)
    # Brief pause so the test process starts before the trainers pile on.
    time.sleep(0.1)
    # ...plus args.workers training processes. (The matching p.join() calls
    # are past the end of this view.)
    for rank in range(0, args.workers):
        p = mp.Process(target=trainocpg,
                       args=(rank, args, shared_model, optimizer, env_conf))
        p.start()