Example #1
    if args.model == 'MLP':
        shared_model = A3C_MLP(env.observation_space.shape[0],
                               env.action_space, args.stack_frames)
    if args.model == 'CONV':
        shared_model = A3C_CONV(args.stack_frames, env.action_space)
    if args.load:
        saved_state = torch.load('{0}{1}.dat'.format(args.load_model_dir,
                                                     args.env),
                                 map_location=lambda storage, loc: storage)
        shared_model.load_state_dict(saved_state)
    shared_model.share_memory()

    if args.shared_optimizer:
        if args.optimizer == 'RMSprop':
            optimizer = SharedRMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = SharedAdam(shared_model.parameters(),
                                   lr=args.lr,
                                   amsgrad=args.amsgrad)
        optimizer.share_memory()
    else:
        optimizer = None

    processes = []

    p = mp.Process(target=test, args=(args, shared_model))
    p.start()
    processes.append(p)
    time.sleep(0.1)
    for rank in range(0, args.workers):
        p = mp.Process(target=train,
                       args=(rank, args, shared_model, optimizer))
        p.start()
        processes.append(p)
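
SharedAdam and SharedRMSprop are project classes, not part of torch.optim: the parent builds the optimizer once and moves its state into shared memory, so every Hogwild worker advances the same moment estimates. A minimal sketch of that pattern, with constructor defaults and state layout assumed rather than copied from the repository:

import torch
import torch.optim as optim

class SharedAdam(optim.Adam):
    # Adam whose per-parameter state is allocated eagerly and moved into
    # shared memory before the worker processes fork.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps,
                                         weight_decay=weight_decay,
                                         amsgrad=amsgrad)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Build the state Adam would otherwise create lazily on the
                # first step(). Full implementations also override step() to
                # consume the tensor-valued 'step' counter.
                state['step'] = torch.zeros(1)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                if amsgrad:
                    state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
                if group['amsgrad']:
                    state['max_exp_avg_sq'].share_memory_()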
Example #2
    for i in setup_json.keys():
        if i in args.env:
            env_conf = setup_json[i]
    env = atari_env(args.env, env_conf)
    shared_model = A3Clstm(env.observation_space.shape[0], env.action_space)
    if args.load:
        saved_state = torch.load(
            '{0}{1}.dat'.format(args.load_model_dir, args.env))
        shared_model.load_state_dict(saved_state)
    shared_model.share_memory()

    if args.shared_optimizer:
        if args.optimizer == 'RMSprop':
            optimizer = SharedRMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
        optimizer.share_memory()
    else:
        optimizer = None

    processes = []

    p = Process(target=test, args=(args, shared_model, env_conf))
    p.start()
    processes.append(p)
    time.sleep(0.1)
    for rank in range(0, args.workers):
        p = Process(
            target=train, args=(rank, args, shared_model, optimizer, env_conf))
        p.start()
        processes.append(p)
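
Every train worker launched above follows the same Hogwild loop: copy the shared weights into a local model, roll out, backpropagate locally, then hand the local gradients to the shared model before stepping. A hypothetical skeleton under those assumptions (rollout_and_compute_a3c_loss is an invented placeholder; the real loss computation lives in each project's train module):

import torch

def ensure_shared_grads(model, shared_model):
    # Point the shared parameters' .grad at the worker's gradients so the
    # shared optimizer steps on them; stop at the first already-set slot.
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad

def train(rank, args, shared_model, optimizer, env_conf):
    env = atari_env(args.env, env_conf)
    model = A3Clstm(env.observation_space.shape[0], env.action_space)
    if optimizer is None:
        # No shared optimizer: each worker steps the shared parameters
        # with its own optimizer instance.
        optimizer = torch.optim.Adam(shared_model.parameters(), lr=args.lr)
    while True:
        model.load_state_dict(shared_model.state_dict())  # pull latest weights
        loss = rollout_and_compute_a3c_loss(model, env, args)  # placeholder
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 40.0)
        ensure_shared_grads(model, shared_model)
        optimizer.step()  # updates the shared parameters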
Example #3
File: main.py Project: hvcl/ColorRL
def main(scripts, args):
    scripts = " ".join(sys.argv[0:])
    args = parser.parse_args()
    args.scripts = scripts
    
    torch.manual_seed(args.seed)
    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method('spawn')

    raw, gt_lbl, raw_valid, gt_lbl_valid, raw_test, gt_lbl_test = setup_data(args)

    env_conf = setup_env_conf(args)

    shared_model = get_model(args, args.model, env_conf["observation_shape"],
                             args.features, atrous_rates=args.atr_rate,
                             num_actions=2, split=args.data_channel,
                             multi=args.multi)

    manager = mp.Manager()
    shared_dict = manager.dict()
    if args.wctrl == "s2m":
        shared_dict["spl_w"] = args.spl_w
        shared_dict["mer_w"] = args.mer_w

    if args.load:
        saved_state = torch.load(
            args.load,
            map_location=lambda storage, loc: storage)
        shared_model.load_state_dict(saved_state)
    if not args.deploy:
        shared_model.share_memory()

    if args.deploy:
        deploy(shared_model, args, args.gpu_ids[0], (raw_test, gt_lbl_test))
        exit()
    
    if args.shared_optimizer:
        if args.optimizer == 'RMSprop':
            optimizer = SharedRMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = SharedAdam(
                shared_model.parameters(), lr=args.lr, amsgrad=args.amsgrad)
        optimizer.share_memory()
    else:
        optimizer = None

    processes = []
    if not args.no_test:
        # args.deploy is always False here: the deploy path exits above.
        if raw_test is not None:
            p = mp.Process(target=test_func,
                           args=(args, shared_model, env_conf,
                                 [raw_valid, gt_lbl_valid],
                                 (raw_test, gt_lbl_test), shared_dict))
        else:
            p = mp.Process(target=test_func,
                           args=(args, shared_model, env_conf,
                                 [raw_valid, gt_lbl_valid], None, shared_dict))
        p.start()
        processes.append(p)
    
    time.sleep(0.1)

    for rank in range(0, args.workers):
        p = mp.Process(target=train_func,
                       args=(rank, args, shared_model, optimizer, env_conf,
                             [raw, gt_lbl], shared_dict))
        p.start()
        processes.append(p)
        time.sleep(0.1)

    for p in processes:
        time.sleep(0.1)
        p.join()
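
Unlike the shared model tensors, the mp.Manager().dict() used in this example proxies plain Python values through a manager process, so running workers can pick up updated spl_w/mer_w weights. A standalone illustration of that behavior (not project code):

import torch.multiprocessing as mp

def bump(d):
    d['spl_w'] = d['spl_w'] + 1.0  # the write goes through the manager proxy

if __name__ == '__main__':
    manager = mp.Manager()
    shared_dict = manager.dict()
    shared_dict['spl_w'] = 1.0
    p = mp.Process(target=bump, args=(shared_dict,))
    p.start()
    p.join()
    print(shared_dict['spl_w'])  # 2.0: the child's update is visible here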