Ejemplo n.º 1
0
    if args.resume:
        print("=> loading checkpoint ")
        checkpoint = torch.load('../../7.87.t7')
        #checkpoint = torch.load('../../best.t7')
        args.start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        ac_net.load_state_dict(checkpoint['state_dict'])
        opt_ac.load_state_dict(checkpoint['optimizer'])
        opt_ac.state = defaultdict(dict, opt_ac.state)
        #print(opt_ac)
        shared_obs_stats = checkpoint['obs']

        print(ac_net)
        print("=> loaded checkpoint  (epoch {})".format(checkpoint['epoch']))

    ac_net.share_memory()

    #opt_ac.share_memory()
    #running_state = ZFilter((num_inputs,), clip=5)

    processes = []

    if args.test:
        p = mp.Process(target=test,
                       args=(args.num_processes, args, ac_net,
                             shared_obs_stats, opt_ac))
        p.start()
        processes.append(p)

    p = mp.Process(target=chief,
                   args=(args, args.num_processes + 1, traffic_light, counter,
Ejemplo n.º 2
0
def main(args):
    """Launch the training workers and the test process, then wait for them.

    Sets up per-session loggers, builds the environment and the shared
    ``ActorCritic`` model with its ``SharedAdam`` optimizer, optionally
    restores both from a checkpoint, and starts one ``mp.Process`` per
    training worker plus one for ``test``. Blocks until every started
    process has exited.

    Args:
        args: Parsed run configuration. Attributes read here: ``uuid``,
            ``env_name``, ``model_id``, ``debug``, ``lr``, ``load_model``,
            ``algorithm``, ``seed``, ``num_processes`` and ``non_sample``.
            When ``load_model`` is truthy, ``model_id`` and ``start_step``
            are overwritten with the values stored in the checkpoint.

    Raises:
        AssertionError: If the checkpoint's ``'env'`` entry differs from
            ``args.env_name``.
    """
    print(f" Session ID: {args.uuid}")

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    args_logger = setup_logger('args', log_dir, 'args.log')
    env_logger = setup_logger('env', log_dir, 'env.log')

    if args.debug:
        debug.packages()

    os.environ['OMP_NUM_THREADS'] = "1"
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(i) for i in range(torch.cuda.device_count())
        )

    args_logger.info(vars(args))
    # NOTE(review): vars() on os.environ yields the mapping object's own
    # attribute dict rather than a plain name -> value dict; confirm this is
    # the intended log payload.
    env_logger.info(vars(os.environ))

    # Seed before the environment and model are built so that any random
    # weight initialisation is reproducible for a given args.seed. Seeding
    # after construction (as was done previously) cannot affect the weights.
    torch.manual_seed(args.seed)

    env = create_atari_environment(args.env_name)

    shared_model = ActorCritic(env.observation_space.shape[0],
                               env.action_space.n)
    if use_cuda:
        shared_model = shared_model.cuda()
    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

    if args.load_model:  # TODO Load model before initializing optimizer
        checkpoint_file = f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
        checkpoint = restore_checkpoint(checkpoint_file)
        assert args.env_name == checkpoint['env'], \
            "Checkpoint is for different environment"
        # The checkpoint is authoritative for the agent id and resume step.
        args.model_id = checkpoint['id']
        args.start_step = checkpoint['step']
        print("Loading model from checkpoint...")
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")
        print(f"      Start: Step {args.start_step}")
        shared_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")

    gpu_count = torch.cuda.device_count() if use_cuda else None
    print(
        FontColor.BLUE
        + f"CPUs:    {mp.cpu_count(): 3d} | "
        + f"GPUs: {gpu_count}"
        + FontColor.END
    )

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    # Queue training processes
    # When more than one process is requested, one fewer training worker is
    # started (presumably leaving a slot for the test process below).
    num_processes = args.num_processes
    if num_processes > 1:
        num_processes -= 1

    # Workers with rank < samplers are started without the extra flag; the
    # remaining args.non_sample workers get an additional positional False.
    samplers = num_processes - args.non_sample

    # Every worker is pinned to GPU 0 when CUDA is present.
    # TODO: Need to move to distributed to handle multigpu
    device = 0 if use_cuda else 'cpu'

    for rank in range(num_processes):
        worker_args = (rank, args, shared_model, counter, lock, optimizer,
                       device)
        if rank >= samplers:  # best action (ranks below this: random action)
            # Presumably switches off action sampling in train(); confirm
            # against train's signature.
            worker_args += (False,)
        p = mp.Process(target=train, args=worker_args)
        p.start()
        time.sleep(1.)  # stagger worker start-up by one second
        processes.append(p)

    # Queue test process
    p = mp.Process(target=test,
                   args=(args.num_processes, args, shared_model, counter, 0))
    p.start()
    processes.append(p)

    for p in processes:
        p.join()