Code Example #1
def main(args):
    # Set logging
    if not os.path.exists("./log"):
        os.makedirs("./log")

    log = set_log(args)
    tb_writer = SummaryWriter('./log/tb_{0}'.format(args.log_name))

    # Set seed
    set_seed(args.seed, cudnn=args.make_deterministic)

    # Set sampler
    sampler = BatchSampler(args, log)

    # Set policy
    policy = CaviaMLPPolicy(
        input_size=int(np.prod(sampler.observation_space.shape)),
        output_size=int(np.prod(sampler.action_space.shape)),
        hidden_sizes=(args.hidden_size, ) * args.num_layers,
        num_context_params=args.num_context_params,
        device=args.device)

    # Initialise baseline
    baseline = LinearFeatureBaseline(
        int(np.prod(sampler.observation_space.shape)))

    # Initialise meta-learner
    metalearner = MetaLearner(sampler, policy, baseline, args, tb_writer)

    # Begin training
    train(sampler, metalearner, args, log, tb_writer)
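
For context, a minimal entry point that could drive this main() might look like the sketch below; the flag names mirror the attributes the function consumes, but the defaults are illustrative assumptions rather than the repository's actual values.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)                # illustrative default
    parser.add_argument('--make-deterministic', action='store_true')
    parser.add_argument('--log-name', type=str, default='cavia_run')   # illustrative default
    parser.add_argument('--hidden-size', type=int, default=100)
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--num-context-params', type=int, default=2)
    parser.add_argument('--device', type=str, default='cpu')
    main(parser.parse_args())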
Code Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='cheetah_vel')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='gridworld')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot':
        args = args_point_robot.get_args(rest_args)
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    # --- Mujoco ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle':
        args = args_ant_semicircle.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    # make sure we have log directories
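    # (os.makedirs raises OSError if the directory already exists; in that
    # case, clear any stale monitor files left over from a previous run)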
    try:
        os.makedirs(args.agent_log_dir)
    except OSError:
        files = glob.glob(os.path.join(args.agent_log_dir, '*.monitor.csv'))
        for f in files:
            os.remove(f)
    eval_log_dir = args.agent_log_dir + "_eval"
    try:
        os.makedirs(eval_log_dir)
    except OSError:
        files = glob.glob(os.path.join(eval_log_dir, '*.monitor.csv'))
        for f in files:
            os.remove(f)

    # set gpu
    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    # start training
    learner = MetaLearner(args)

    learner.train()
Code Example #3
def _train(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path):
    '''
        Perform meta-testing for MAML, Metalight, Random, and Pretrained baselines

        Arguments:
            dic_exp_conf:           dict,   configuration of this experiment
            dic_agent_conf:         dict,   configuration of agent
            dic_traffic_env_conf:   dict,   configuration of traffic environment
            dic_path:               dict,   path of source files and output files
    '''

    random.seed(dic_agent_conf['SEED'])
    np.random.seed(dic_agent_conf['SEED'])
    tf.set_random_seed(dic_agent_conf['SEED'])

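    # NOTE: `args` is not a parameter of _train; the references below
    # (args.fast_batch_size, args.num_workers, args.algorithm) presumably
    # resolve to a module-level argparse namespace.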
    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=dic_traffic_env_conf,
                           dic_path=dic_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)

    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        dic_path=dic_path)

    metalearner = MetaLearner(sampler,
                              policy,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=dic_traffic_env_conf,
                              dic_path=dic_path)

    if dic_agent_conf['PRE_TRAIN']:
        if dic_agent_conf['PRE_TRAIN_MODEL_NAME'] != 'random':
            params = pickle.load(
                open(
                    os.path.join(
                        'model', 'initial', "common",
                        dic_agent_conf['PRE_TRAIN_MODEL_NAME'] + '.pkl'),
                    'rb'))
            metalearner.meta_params = params
            metalearner.meta_target_params = params

    tasks = dic_exp_conf['TRAFFIC_IN_TASKS']  # note: overwritten inside the loop below

    episodes = None
    for batch_id in range(dic_exp_conf['NUM_ROUNDS']):
        tasks = [dic_exp_conf['TRAFFIC_FILE']]
        if dic_agent_conf['MULTI_EPISODES']:
            episodes = metalearner.sample_meta_test(tasks[0], batch_id,
                                                    episodes)
        else:
            episodes = metalearner.sample_meta_test(tasks[0], batch_id)
Code Example #4
def main(args):

    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    stopping_patience = args.stopping_patience
    epochs = args.epochs
    fast_lr = args.learning_rate
    slow_lr = args.meta_learning_rate
    noise_level = args.noise_level
    noise_type = args.noise_type
    resume = args.resume

    first_order = False
    inner_loop_grad_clip = 20
    task_size = 50
    output_dim = 1
    checkpoint_freq = 10
    horizon = 10

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]
    output_directory = "output/"

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(
            ".")[0]

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Learning rate :%f \n" % fast_lr)
            f.write("Meta-learning rate: %f \n" % slow_lr)
            f.write("Adaptation steps: %f \n" % n_inner_iter)
            f.write("Noise level: %f \n" % noise_level)

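        # NOTE: only the LSTM branch defines `model` and `model2`; although the
        # assert above also admits "FCN", that choice would raise a NameError below.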
        if model_name == "LSTM":
            model = LSTMModel(batch_size=batch_size,
                              seq_len=window_size,
                              input_dim=input_dim,
                              n_layers=2,
                              hidden_dim=120,
                              output_dim=output_dim)
            model2 = LinearModel(120, 1)
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(model2.parameters()),
                                     lr=slow_lr)
        loss_func = mae
        #loss_func = nn.SmoothL1Loss()
        #loss_func = nn.MSELoss()
        initial_epoch = 0

        #torch.backends.cudnn.enabled = False

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)
        model.to(device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_encoder,
                                       verbose=True)
        early_stopping2 = EarlyStopping(patience=stopping_patience,
                                        model_file=save_model_file_,
                                        verbose=True)

        if resume:
            checkpoint = torch.load(checkpoint_file)
            model.load_state_dict(checkpoint["model"])
            meta_learner.load_state_dict(checkpoint["meta_learner"])
            initial_epoch = checkpoint["epoch"]
            best_score = checkpoint["best_score"]
            counter = checkpoint["counter_stopping"]

            early_stopping.best_score = best_score
            early_stopping2.best_score = best_score

            early_stopping.counter = counter
            early_stopping2.counter = counter

        total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape
        accum_mean = 0.0

        for epoch in range(initial_epoch, epochs):

            model.zero_grad()
            meta_learner._model.zero_grad()

            #train
            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)

            #for batch_idx in range(0, total_tasks-1, batch_size):

            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            # data augmentation
            epsilon = grid[np.random.randint(0, len(grid))]

            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            train_tasks = [
                Task(model.encoder(x_spt[i]), y_spt[i])
                for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(model.encoder(x_qry[i]), y_qry[i])
                for i in range(x_qry.shape[0])
            ]

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True)
            #accum_mean += mean_loss.cpu().detach().numpy()

            #progressBar(batch_idx, total_tasks, 100)

            #print(accum_mean/(batch_idx+1))

            #test

            val_error = test(validation_data_ML, meta_learner, model, device,
                             noise_level)
            test_error = test(test_data_ML, meta_learner, model, device, 0.0)
            print("Epoch:", epoch)
            print("Val error:", val_error)
            print("Test error:", test_error)

            early_stopping(val_error, model)
            early_stopping2(val_error, meta_learner)

            #checkpointing
            if epoch % checkpoint_freq == 0:
                torch.save(
                    {
                        "epoch": epoch,
                        "model": model.state_dict(),
                        "meta_learner": meta_learner.state_dict(),
                        "best_score": early_stopping2.best_score,
                        "counter_stopping": early_stopping2.counter
                    }, checkpoint_file)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        print("hallo")
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        validation_error = test(validation_data_ML,
                                meta_learner,
                                model,
                                device,
                                noise_level=0.0)
        test_error = test(test_data_ML,
                          meta_learner,
                          model,
                          device,
                          noise_level=0.0)

        validation_error_h1 = test(validation_data_ML,
                                   meta_learner,
                                   model,
                                   device,
                                   noise_level=0.0,
                                   horizon=1)
        test_error_h1 = test(test_data_ML,
                             meta_learner,
                             model,
                             device,
                             noise_level=0.0,
                             horizon=1)

        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                    first_order, 0, inner_loop_grad_clip,
                                    device)

        validation_error_h0 = test(validation_data_ML,
                                   meta_learner2,
                                   model,
                                   device,
                                   noise_level=0.0,
                                   horizon=1)
        test_error_h0 = test(test_data_ML,
                             meta_learner2,
                             model,
                             device,
                             noise_level=0.0,
                             horizon=1)

        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                    first_order, n_inner_iter,
                                    inner_loop_grad_clip, device)
        validation_error_mae = test(validation_data_ML, meta_learner2, model,
                                    device, 0.0)
        test_error_mae = test(test_data_ML, meta_learner2, model, device, 0.0)
        print("test_error_mae", test_error_mae)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Test error: %f \n" % test_error)
            f.write("Validation error: %f \n" % validation_error)
            f.write("Test error h1: %f \n" % test_error_h1)
            f.write("Validation error h1: %f \n" % validation_error_h1)
            f.write("Test error h0: %f \n" % test_error_h0)
            f.write("Validation error h0: %f \n" % validation_error_h0)
            f.write("Test error mae: %f \n" % test_error_mae)
            f.write("Validation error mae: %f \n" % validation_error_mae)

        print(test_error)
        print(validation_error)
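
The to_torch helper used above is not shown in this snippet; a plausible minimal version (an assumption, not the repository's actual code) is:

import torch

def to_torch(array, device=None):
    # hypothetical helper: wrap a numpy array as a float tensor and
    # optionally move it to a device
    tensor = torch.from_numpy(array).float()
    return tensor.to(device) if device is not None else tensor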
Code Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='gridworld_varibad')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---

    # standard
    if env == 'gridworld_oracle':
        args = args_grid_oracle.get_args(rest_args)
    elif env == 'gridworld_belief_oracle':
        args = args_grid_belief_oracle.get_args(rest_args)
    elif env == 'gridworld_varibad':
        args = args_grid_varibad.get_args(rest_args)
    elif env == 'gridworld_rl2':
        args = args_grid_rl2.get_args(rest_args)

    # --- MUJOCO ---

    # - AntDir -
    elif env == 'mujoco_ant_dir_oracle':
        args = args_mujoco_ant_dir_oracle.get_args(rest_args)
    elif env == 'mujoco_ant_dir_rl2':
        args = args_mujoco_ant_dir_rl2.get_args(rest_args)
    elif env == 'mujoco_ant_dir_varibad':
        args = args_mujoco_ant_dir_varibad.get_args(rest_args)
    #
    # - CheetahDir -
    elif env == 'mujoco_cheetah_dir_oracle':
        args = args_mujoco_cheetah_dir_oracle.get_args(rest_args)
    elif env == 'mujoco_cheetah_dir_rl2':
        args = args_mujoco_cheetah_dir_rl2.get_args(rest_args)
    elif env == 'mujoco_cheetah_dir_varibad':
        args = args_mujoco_cheetah_dir_varibad.get_args(rest_args)
    #
    # - CheetahVel -
    elif env == 'mujoco_cheetah_vel_oracle':
        args = args_mujoco_cheetah_vel_oracle.get_args(rest_args)
    elif env == 'mujoco_cheetah_vel_rl2':
        args = args_mujoco_cheetah_vel_rl2.get_args(rest_args)
    elif env == 'mujoco_cheetah_vel_varibad':
        args = args_mujoco_cheetah_vel_varibad.get_args(rest_args)
    #
    # - Walker -
    elif env == 'mujoco_walker_oracle':
        args = args_mujoco_walker_oracle.get_args(rest_args)
    elif env == 'mujoco_walker_rl2':
        args = args_mujoco_walker_rl2.get_args(rest_args)
    elif env == 'mujoco_walker_varibad':
        args = args_mujoco_walker_varibad.get_args(rest_args)
    #
    # - CheetahHField
    elif env == 'mujoco_cheetah_hfield_varibad':
        args = args_mujoco_cheetah_hfield_varibad.get_args(rest_args)

    # - CheetahHill
    elif env == 'mujoco_cheetah_hill_varibad':
        args = args_mujoco_cheetah_hill_varibad.get_args(rest_args)

    # - CheetahBasin
    elif env == 'mujoco_cheetah_basin_varibad':
        args = args_mujoco_cheetah_basin_varibad.get_args(rest_args)

    # - CheetahGentle
    elif env == 'mujoco_cheetah_gentle_varibad':
        args = args_mujoco_cheetah_gentle_varibad.get_args(rest_args)

    # - CheetahSteep
    elif env == 'mujoco_cheetah_steep_varibad':
        args = args_mujoco_cheetah_steep_varibad.get_args(rest_args)

    # - CheetahJoint
    elif env == 'mujoco_cheetah_joint_varibad':
        args = args_mujoco_cheetah_joint_varibad.get_args(rest_args)

    # - CheetahBlocks
    elif env == 'mujoco_cheetah_blocks_varibad':
        args = args_mujoco_cheetah_blocks_varibad.get_args(rest_args)

    # make sure we have log directories for mujoco
    if 'mujoco' in env:
        try:
            os.makedirs(args.agent_log_dir)
        except OSError:
            files = glob.glob(os.path.join(args.agent_log_dir,
                                           '*.monitor.csv'))
            for f in files:
                os.remove(f)
        eval_log_dir = args.agent_log_dir + "_eval"
        try:
            os.makedirs(eval_log_dir)
        except OSError:
            files = glob.glob(os.path.join(eval_log_dir, '*.monitor.csv'))
            for f in files:
                os.remove(f)

    # warning
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # start training
    if args.disable_varibad:
        # When the flag `disable_varibad` is activated, the file `learner.py` will be used instead of `metalearner.py`.
        # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
        learner = Learner(args)
    else:
        learner = MetaLearner(args)
    learner.train()
Code Example #6
def main(args):
    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    meta_learning_rate = args.meta_learning_rate
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    task_size = args.task_size
    noise_level = args.noise_level
    noise_type = args.noise_type
    epochs = args.epochs
    loss_fcn_str = args.loss
    modulate_task_net = args.modulate_task_net
    weight_vrae = args.weight_vrae
    stopping_patience = args.stopping_patience

    meta_info = {"POLLUTION": [5, 14], "HR": [32, 13], "BATTERY": [20, 3]}

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    total_tasks = len(train_data_ML)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_fn = mae if loss_fcn_str == "MAE" else nn.SmoothL1Loss()

    ## multimodal learner parameters
    # parameters to increase the capacity of the model
    n_layers_task_net = 2
    n_layers_task_encoder = 2
    n_layers_task_decoder = 2

    hidden_dim_task_net = 120
    hidden_dim_encoder = 120
    hidden_dim_decoder = 120

    # fixed values
    input_dim_task_net = input_dim
    input_dim_task_encoder = input_dim + 1
    output_dim_task_net = 1
    output_dim_task_decoder = input_dim + 1

    first_order = False
    inner_loop_grad_clip = 20

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MMAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(
            ".")[0]

        writer = SummaryWriter()

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        task_net = LSTMModel(batch_size=batch_size,
                             seq_len=window_size,
                             input_dim=input_dim_task_net,
                             n_layers=n_layers_task_net,
                             hidden_dim=hidden_dim_task_net,
                             output_dim=output_dim_task_net)

        task_encoder = LSTMModel(batch_size=batch_size,
                                 seq_len=task_size,
                                 input_dim=input_dim_task_encoder,
                                 n_layers=n_layers_task_encoder,
                                 hidden_dim=hidden_dim_encoder,
                                 output_dim=1)

        task_decoder = LSTMDecoder(batch_size=1,
                                   n_layers=n_layers_task_decoder,
                                   seq_len=task_size,
                                   output_dim=output_dim_task_decoder,
                                   hidden_dim=hidden_dim_encoder,
                                   latent_dim=hidden_dim_decoder,
                                   device=device)

        lmbd = Lambda(hidden_dim_encoder, hidden_dim_task_net)

        multimodal_learner = MultimodalLearner(task_net, task_encoder,
                                               task_decoder, lmbd,
                                               modulate_task_net)
        multimodal_learner.to(device)

        output_layer = LinearModel(120, 1)
        opt = torch.optim.Adam(list(multimodal_learner.parameters()) +
                               list(output_layer.parameters()),
                               lr=meta_learning_rate)

        meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_,
                                       verbose=True)
        early_stopping_encoder = EarlyStopping(
            patience=stopping_patience,
            model_file=save_model_file_encoder,
            verbose=True)

        task_data_train = torch.FloatTensor(
            get_task_encoder_input(train_data_ML))
        task_data_validation = torch.FloatTensor(
            get_task_encoder_input(validation_data_ML))
        task_data_test = torch.FloatTensor(
            get_task_encoder_input(test_data_ML))

        val_loss_hist = []
        test_loss_hist = []

        for epoch in range(epochs):

            multimodal_learner.train()

            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)
            task = task_data_train[batch_idx].to(device)

            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            # data augmentation
            epsilon = grid[np.random.randint(0, len(grid))]

            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            x_spt_encodings = []
            x_qry_encodings = []
            vrae_loss_accum = 0.0
            for i in range(batch_size):
                x_spt_encoding, (vrae_loss, kl_loss,
                                 rec_loss) = multimodal_learner(
                                     x_spt[i],
                                     task[i:i + 1],
                                     output_encoding=True)
                x_spt_encodings.append(x_spt_encoding)
                vrae_loss_accum += vrae_loss

                x_qry_encoding, _ = multimodal_learner(x_qry[i],
                                                       task[i:i + 1],
                                                       output_encoding=True)
                x_qry_encodings.append(x_qry_encoding)

            train_tasks = [
                Task(x_spt_encodings[i], y_spt[i])
                for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(x_qry_encodings[i], y_qry[i])
                for i in range(x_qry.shape[0])
            ]

            # print(vrae_loss)

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True,
                                          additional_loss_term=weight_vrae *
                                          vrae_loss_accum / batch_size)

            ##plotting grad of output layer
            for tag, parm in output_layer.linear.named_parameters():
                writer.add_histogram("Grads_output_layer_" + tag,
                                     parm.grad.data.cpu().numpy(), epoch)

            multimodal_learner.eval()
            val_loss = test(validation_data_ML, multimodal_learner,
                            meta_learner, task_data_validation)
            test_loss = test(test_data_ML, multimodal_learner, meta_learner,
                             task_data_test)

            print("Epoch:", epoch)
            print("Train loss:", mean_loss)
            print("Val error:", val_loss)
            print("Test error:", test_loss)

            early_stopping(val_loss, meta_learner)
            early_stopping_encoder(val_loss, multimodal_learner)

            val_loss_hist.append(val_loss)
            test_loss_hist.append(test_loss)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            writer.add_scalar("Loss/train",
                              mean_loss.cpu().detach().numpy(), epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Loss/test", test_loss, epoch)

        multimodal_learner.load_state_dict(torch.load(save_model_file_encoder))
        output_layer.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        val_loss = test(validation_data_ML, multimodal_learner, meta_learner,
                        task_data_validation)
        test_loss = test(test_data_ML, multimodal_learner, meta_learner,
                         task_data_test)

        with open(output_directory + "/results3.txt", "a+") as f:
            f.write("Dataset :%s \n" % dataset_name)
            f.write("Test error: %f \n" % test_loss)
            f.write("Val error: %f \n" % val_loss)
            f.write("\n")

        writer.add_hparams(
            {
                "fast_lr": learning_rate,
                "slow_lr": meta_learning_rate,
                "adaption_steps": n_inner_iter,
                "patience": stopping_patience,
                "weight_vrae": weight_vrae,
                "noise_level": noise_level,
                "dataset": dataset_name,
                "trial": trial
            }, {
                "val_loss": val_loss,
                "test_loss": test_loss
            })
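
The Task container passed to meta_learner.adapt and meta_learner.step in these examples is not defined in the snippets; judging from its usage, Task(inputs, targets), a minimal stand-in could look like this sketch (an assumption, not the repository's definition):

from dataclasses import dataclass
import torch

@dataclass
class Task:
    # hypothetical minimal container: a batch of (already encoded)
    # inputs together with their regression targets
    x: torch.Tensor
    y: torch.Tensor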
Code Example #7
def main():

    wandb.init(project="ofsl-implementation", entity="joeljosephjin")
    
    args, unparsed = FLAGS.parse_known_args()
    if len(unparsed) != 0:
        raise NameError("Argument {} not recognized".format(unparsed))

    if args.seed is None:
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if args.cpu:
        args.dev = torch.device('cpu')
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        args.dev = torch.device('cuda')

    logger = GOATLogger(args)

    # Load train/validation/test tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
                                                  train_ways=args.n_class,
                                                  train_samples=2*args.n_shot,
                                                  test_ways=args.n_class,
                                                  test_samples=2*args.n_shot,
                                                  root='~/data',
    )
    
    # Set up learner, meta-learner
    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev)
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev)
    # gets the model parameters in a concatenated torch list; then pushes it into cI.data
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args.lr)

    if args.resume:
        logger.loginfo("Initialized from: {}".format(args.resume))
        last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev)

    if args.mode == 'test':
        _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
        return

    best_acc = 0.0
    logger.loginfo("Start training")
    
    # Meta-training
    for eps in range(50000):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED

        batch = tasksets.train.sample()
        adapt_x, adapt_y, eval_x, eval_y = process_batch(batch, args)
        # print('len(adapt_x)',len(adapt_x)) # 25

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        
        cI = metalearner.metalstm.cI.data
        h = None
        for _ in range(args.epoch):
            # get the loss/grad
            # copy from cell state to model.parameters
            learner_w_grad.copy_flat_params(cI)

            # do a forward pass and get the loss
            output = learner_w_grad(adapt_x)
            loss = learner_w_grad.criterion(output, adapt_y)
            acc = accuracy(output, adapt_y)

            # populate the gradients
            learner_w_grad.zero_grad()
            loss.backward()

            # get the grad from the lwg.parameters
            grad = torch.cat([p.grad.data.view(-1) / args.batch_size for p in learner_w_grad.parameters()], 0)

            # preprocess grad & loss and metalearner forward
            grad_prep = preprocess_grad_loss(grad)  # [n_learner_params, 2]
            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2]

            # push the loss, grad thru the metalearner
            cI, h = metalearner([loss_prep, grad_prep, grad.unsqueeze(1)], h)


        # Train meta-learner with validation loss
        # same as copy_flat_params; only diff = parameters of the model are not nn.Params anymore, they're just plain tensors now.
        learner_wo_grad.transfer_params(learner_w_grad, cI)

        # do a forward pass and get the loss
        output = learner_wo_grad(eval_x)
        loss = learner_wo_grad.criterion(output, eval_y)
        acc = accuracy(output, eval_y)
        
        # update the metalearner aka the metalstm
        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()

        # loggers
        logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')
        wandb.log({'loss':loss.item(), 'accuracy':acc}, step=eps)

        # Meta-validation
        if eps % args.val_freq == 0 and eps != 0:
            save_ckpt(eps, metalearner, optim, args.save)
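            # NOTE: this samples from the *train* taskset; a true meta-validation
            # pass would presumably use tasksets.validation instead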
            val_batch = tasksets.train.sample()
            acc, test_loss = meta_test(eps, val_batch, learner_w_grad, learner_wo_grad, metalearner, args, logger)
            wandb.log({'test_loss':test_loss.item(), 'test_accuracy':acc}, step=eps)
            if acc > best_acc:
                best_acc = acc
                logger.loginfo("* Best accuracy so far *\n")

    logger.loginfo("Done")
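
The preprocess_grad_loss helper above is not shown; it presumably follows the input preprocessing of Ravi & Larochelle (2017), which maps each scalar gradient or loss to a (log-magnitude, sign) pair so the meta-LSTM sees inputs on a common scale. A sketch of that standard transform, assuming p = 10 as in the paper (the repository's exact implementation may differ):

import torch

def preprocess_grad_loss(x, p=10.0):
    # values with |x| >= e^{-p} become (log|x| / p, sign(x));
    # smaller values become (-1, e^p * x)
    eps = torch.exp(torch.tensor(-p))
    large = (x.abs() >= eps).float()
    x1 = large * torch.log(x.abs() + 1e-8) / p - (1 - large)
    x2 = large * torch.sign(x) + (1 - large) * torch.exp(torch.tensor(p)) * x
    return torch.stack((x1, x2), dim=1)  # e.g. [n_learner_params, 2]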
Code Example #8
File: render.py  Project: jeonggwanlee/varibad
from metalearner import MetaLearner
import argparse
from config.mujoco import args_mujoco_cheetah_joint_varibad, args_mujoco_cheetah_hfield_varibad, \
    args_mujoco_cheetah_blocks_varibad
import matplotlib

matplotlib.use('Agg')

parser = argparse.ArgumentParser()
parser.add_argument('--env-type', default='gridworld_varibad')
args, rest_args = parser.parse_known_args()
env = args.env_type

#args = args_mujoco_cheetah_joint_varibad.get_args(rest_args)
#args = args_mujoco_cheetah_hfield_varibad.get_args(rest_args)
args = args_mujoco_cheetah_blocks_varibad.get_args(rest_args)

metalearner = MetaLearner(args)

metalearner.load_and_render(load_iter=4000)
#metalearner.load(load_iter=3500)
Code Example #9
def main():

    args, unparsed = FLAGS.parse_known_args()
    if len(unparsed) != 0:
        raise NameError("Argument {} not recognized".format(unparsed))

    if args.seed is None:
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cpu:
        args.dev = torch.device('cpu')
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        args.dev = torch.device('cuda')

    logger = GOATLogger(args)

    # Get data
    #train_loader, val_loader, test_loader = prepare_data(args)
    dataset = StateDataset(epoch_len=10240, batch_size=1)
    train_loader = torch.utils.data.DataLoader(dataset)
    # Set up learner, meta-learner
    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev)
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev)
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())
    A = torch.tensor(np.array([[1.0,1.0],[0.0,1.0]])).float().to(args.dev)
    B = torch.tensor(np.array([[0.0],[1.0]])).float().to(args.dev)
    Q = torch.tensor(0.01*np.diag([1.0, 1.0])).float().to(args.dev)
    R = torch.tensor([[0.1]]).float().to(args.dev)
    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args.lr)

    if args.resume:
        logger.loginfo("Initialized from: {}".format(args.resume))
        last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev)

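    # NOTE: prepare_data is commented out above, so test_loader is undefined
    # here and entering test mode would raise a NameError.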
    if args.mode == 'test':
        _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
        return

    best_acc = 0.0
    logger.loginfo("Start training")
    # Meta-training
    for eps, X in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
        #train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
        #train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
        #test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
        #test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]
        X = X.float().to(args.dev)
        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, X, args)
        x = X[0,0:1]
        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
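        # Roll the closed-loop linear system forward for T steps: the learner
        # acts as the controller u_t = f(x_t) under the dynamics
        # x_{t+1} = A x_t + B u_t. The criterion below presumably scores the
        # trajectory with the quadratic LQR cost defined by Q and R.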
        T = 15
        x_list = []
        x_list.append(x)
        u_list = []
        for i in range(T):
            u = learner_wo_grad(x_list[i])
            x_next = A @ x_list[i].T + B @ u.T
            x_list.append(x_next.T)
            u_list.append(u)
        #output = learner_wo_grad(X)
        loss = learner_wo_grad.criterion(x_list, u_list, A, B, Q, R)
        #acc = accuracy(output, test_target)
        if eps % 1000 == 0:
            print('loss: ', loss.item())
            print(x_list)
        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()
        if eps % 100 == 0:
            print(eps)

        #logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), phase='train')

        # Meta-validation
        #if eps % args.val_freq == 0 and eps != 0:
        #    save_ckpt(eps, metalearner, optim, args.save)
            #acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
            #if acc > best_acc:
            #    best_acc = acc
            #    logger.loginfo("* Best accuracy so far *\n")

    logger.loginfo("Done")
Code Example #10
File: main.py  Project: lmzintgraf/varibad
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='gridworld_varibad')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---

    if env == 'gridworld_belief_oracle':
        args = args_grid_belief_oracle.get_args(rest_args)
    elif env == 'gridworld_varibad':
        args = args_grid_varibad.get_args(rest_args)
    elif env == 'gridworld_rl2':
        args = args_grid_rl2.get_args(rest_args)

    # --- PointRobot 2D Navigation ---

    elif env == 'pointrobot_multitask':
        args = args_pointrobot_multitask.get_args(rest_args)
    elif env == 'pointrobot_varibad':
        args = args_pointrobot_varibad.get_args(rest_args)
    elif env == 'pointrobot_rl2':
        args = args_pointrobot_rl2.get_args(rest_args)
    elif env == 'pointrobot_humplik':
        args = args_pointrobot_humplik.get_args(rest_args)

    # --- MUJOCO ---

    # - CheetahDir -
    elif env == 'cheetah_dir_multitask':
        args = args_cheetah_dir_multitask.get_args(rest_args)
    elif env == 'cheetah_dir_expert':
        args = args_cheetah_dir_expert.get_args(rest_args)
    elif env == 'cheetah_dir_varibad':
        args = args_cheetah_dir_varibad.get_args(rest_args)
    elif env == 'cheetah_dir_rl2':
        args = args_cheetah_dir_rl2.get_args(rest_args)
    #
    # - CheetahVel -
    elif env == 'cheetah_vel_multitask':
        args = args_cheetah_vel_multitask.get_args(rest_args)
    elif env == 'cheetah_vel_expert':
        args = args_cheetah_vel_expert.get_args(rest_args)
    elif env == 'cheetah_vel_avg':
        args = args_cheetah_vel_avg.get_args(rest_args)
    elif env == 'cheetah_vel_varibad':
        args = args_cheetah_vel_varibad.get_args(rest_args)
    elif env == 'cheetah_vel_rl2':
        args = args_cheetah_vel_rl2.get_args(rest_args)
    #
    # - AntDir -
    elif env == 'ant_dir_multitask':
        args = args_ant_dir_multitask.get_args(rest_args)
    elif env == 'ant_dir_expert':
        args = args_ant_dir_expert.get_args(rest_args)
    elif env == 'ant_dir_varibad':
        args = args_ant_dir_varibad.get_args(rest_args)
    elif env == 'ant_dir_rl2':
        args = args_ant_dir_rl2.get_args(rest_args)
    #
    # - AntGoal -
    elif env == 'ant_goal_multitask':
        args = args_ant_goal_multitask.get_args(rest_args)
    elif env == 'ant_goal_expert':
        args = args_ant_goal_expert.get_args(rest_args)
    elif env == 'ant_goal_varibad':
        args = args_ant_goal_varibad.get_args(rest_args)
    elif env == 'ant_goal_humplik':
        args = args_ant_goal_humplik.get_args(rest_args)
    elif env == 'ant_goal_rl2':
        args = args_ant_goal_rl2.get_args(rest_args)
    #
    # - Walker -
    elif env == 'walker_multitask':
        args = args_walker_multitask.get_args(rest_args)
    elif env == 'walker_expert':
        args = args_walker_expert.get_args(rest_args)
    elif env == 'walker_avg':
        args = args_walker_avg.get_args(rest_args)
    elif env == 'walker_varibad':
        args = args_walker_varibad.get_args(rest_args)
    elif env == 'walker_rl2':
        args = args_walker_rl2.get_args(rest_args)
    #
    # - HumanoidDir -
    elif env == 'humanoid_dir_multitask':
        args = args_humanoid_dir_multitask.get_args(rest_args)
    elif env == 'humanoid_dir_expert':
        args = args_humanoid_dir_expert.get_args(rest_args)
    elif env == 'humanoid_dir_varibad':
        args = args_humanoid_dir_varibad.get_args(rest_args)
    elif env == 'humanoid_dir_rl2':
        args = args_humanoid_dir_rl2.get_args(rest_args)
    else:
        raise Exception("Invalid Environment")

    # warning for deterministic execution
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # if we're normalising the actions, we have to make sure that the env expects actions within [-1, 1]
    if args.norm_actions_pre_sampling or args.norm_actions_post_sampling:
        envs = make_vec_envs(
            env_name=args.env_name,
            seed=0,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device='cpu',
            episodes_per_task=args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
            tasks=None,
        )
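        # sanity check (implicit elementwise comparison): assumes the action
        # space has the single bounds -1 and +1 in every dimension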
        assert np.unique(envs.action_space.low) == [-1]
        assert np.unique(envs.action_space.high) == [1]

    # clean up arguments
    if args.disable_metalearner or args.disable_decoder:
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False

    if hasattr(args, 'decode_only_past') and args.decode_only_past:
        args.split_batches_by_elbo = True
    # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes:
    #     args.split_batches_by_elbo = True

    # begin training (loop through all passed seeds)
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        print('training', seed)
        args.seed = seed
        args.action_space = None

        if args.disable_metalearner:
            # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`.
            # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
            learner = Learner(args)
        else:
            learner = MetaLearner(args)
        learner.train()
Code Example #11
def main():

    args, unparsed = FLAGS.parse_known_args()
    if len(unparsed) != 0:
        raise NameError("Argument {} not recognized".format(unparsed))

    if args.seed is None:
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cpu:
        args.dev = torch.device('cpu')
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        args.dev = torch.device('cuda')

    #logger = GOATLogger(args)
    use_qrnn = True

    # Get data
    train_loader, val_loader, test_loader = prepare_data(args)

    # Set up learner, meta-learner
    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum,
                             args.n_class).to(args.dev)
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args.input_size, args.hidden_size,
                              learner_w_grad.get_flat_params().size(0),
                              use_qrnn).to(args.dev)
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args.lr)

    if args.resume:
        #logger.loginfo("Initialized from: {}".format(args.resume))
        last_eps, metalearner, optim = resume_ckpt(metalearner, optim,
                                                   args.resume, args.dev)

    if args.mode == 'test':
        #_ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
        return

    best_acc = 0.0
    print("Starting training...")
    print("Shots: ", args.n_shot)
    print("Classes: ", args.n_class)

    start_time = datetime.now()

    # Meta-training
    for eps, (episode_x, episode_y) in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
        train_input = episode_x[:, :args.n_shot].reshape(
            -1, *episode_x.shape[-3:]).to(args.dev)  # [n_class * n_shot, :]
        train_target = torch.LongTensor(
            np.repeat(range(args.n_class),
                      args.n_shot)).to(args.dev)  # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(
            -1, *episode_x.shape[-3:]).to(args.dev)  # [n_class * n_eval, :]
        test_target = torch.LongTensor(
            np.repeat(range(args.n_class),
                      args.n_eval)).to(args.dev)  # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input,
                           train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()

        if ((eps + 1) % 250 == 0 or eps == 0):
            print(eps + 1, "/", args.episode, " Loss: ", loss.item(), " Acc:",
                  acc)
        #logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')

        # Meta-validation
        if ((eps + 1) % args.val_freq == 0
                and eps != 0) or eps + 1 == args.episode:
            #save_ckpt(eps, metalearner, optim, args.save)
            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad,
                            metalearner, args)
            print("Meta validation: ", eps + 1, " Acc: ", acc)
            if acc > best_acc:
                best_acc = acc
                print("    New best: ", acc)
            # logger.loginfo("* Best accuracy so far *\n")

    end_time = datetime.now()
    print("Time to execute: ", end_time - start_time)
    print("Average per iteration", (end_time - start_time) / args.episode)
    torch.cuda.empty_cache()
    #acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args)
    print("Training complete, best acc: ", best_acc)
Code Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='ant_dir_rl2')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---

    if env == 'gridworld_oracle':
        args = args_grid_oracle.get_args(rest_args)
    elif env == 'gridworld_belief_oracle':
        args = args_grid_belief_oracle.get_args(rest_args)
    elif env == 'gridworld_varibad':
        args = args_grid_varibad.get_args(rest_args)
    elif env == 'gridworld_rl2':
        args = args_grid_rl2.get_args(rest_args)

    # --- MUJOCO ---

    # - AntDir -
    elif env == 'ant_dir_oracle':
        args = args_ant_dir_oracle.get_args(rest_args)
    elif env == 'ant_dir_rl2':
        args = args_ant_dir_rl2.get_args(rest_args)
    elif env == 'ant_dir_varibad':
        args = args_ant_dir_varibad.get_args(rest_args)
    #
    # - AntGoal -
    elif env == 'ant_goal_oracle':
        args = args_ant_goal_oracle.get_args(rest_args)
    elif env == 'ant_goal_varibad':
        args = args_ant_goal_varibad.get_args(rest_args)
    elif env == 'ant_goal_rl2':
        args = args_ant_goal_rl2.get_args(rest_args)
    #
    # - CheetahDir -
    elif env == 'cheetah_dir_oracle':
        args = args_cheetah_dir_oracle.get_args(rest_args)
    elif env == 'cheetah_dir_rl2':
        args = args_cheetah_dir_rl2.get_args(rest_args)
    elif env == 'cheetah_dir_varibad':
        args = args_cheetah_dir_varibad.get_args(rest_args)
    #
    # - CheetahVel -
    elif env == 'cheetah_vel_oracle':
        args = args_cheetah_vel_oracle.get_args(rest_args)
    elif env == 'cheetah_vel_rl2':
        args = args_cheetah_vel_rl2.get_args(rest_args)
    elif env == 'cheetah_vel_varibad':
        args = args_cheetah_vel_varibad.get_args(rest_args)
    elif env == 'cheetah_vel_avg':
        args = args_cheetah_vel_avg.get_args(rest_args)
    #
    # - Walker -
    elif env == 'walker_oracle':
        args = args_walker_oracle.get_args(rest_args)
    elif env == 'walker_avg':
        args = args_walker_avg.get_args(rest_args)
    elif env == 'walker_rl2':
        args = args_walker_rl2.get_args(rest_args)
    elif env == 'walker_varibad':
        args = args_walker_varibad.get_args(rest_args)

    # warning for deterministic execution
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError('If you want fully deterministic code, use num_processes 1. '
                               'Warning: This will slow things down and might break A2C if '
                               'policy_num_steps < env._max_episode_steps.')

    # clean up arguments
    if hasattr(args, 'disable_decoder') and args.disable_decoder:
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False

    if hasattr(args, 'decode_only_past') and args.decode_only_past:
        args.split_batches_by_elbo = True
    # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes:
    #     args.split_batches_by_elbo = True

    # begin training (loop through all passed seeds)
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        print('training', seed)
        args.seed = seed

        if args.disable_metalearner:
            # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`.
            # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
            learner = Learner(args)
        else:
            learner = MetaLearner(args)
        learner.train()
Code Example #13
def main():

    args, unparsed = FLAGS.parse_known_args()
    args = brandos_load(args)
    if len(unparsed) != 0:
        raise NameError("Argument {} not recognized".format(unparsed))

    if args.seed is None:
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    #args.dev = torch.device('cpu')
    if args.cpu:
        args.dev = torch.device('cpu')
        args.gpu_name = args.dev
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        args.dev = torch.device('cuda')
        try:
            args.gpu_name = torch.cuda.get_device_name(0)
        except Exception:
            args.gpu_name = args.dev

    print(f'device {args.dev}')
    logger = GOATLogger(args)

    # Get data
    train_loader, val_loader, test_loader = prepare_data(args)

    # Set up learner, meta-learner
    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum,
                             args.n_class).to(args.dev)
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args.input_size, args.hidden_size,
                              learner_w_grad.get_flat_params().size(0)).to(
                                  args.dev)
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args.lr)

    if args.resume:
        logger.loginfo("Initialized from: {}".format(args.resume))
        last_eps, metalearner, optim = resume_ckpt(metalearner, optim,
                                                   args.resume, args.dev)

    if args.mode == 'test':
        _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad,
                      metalearner, args, logger)
        return

    best_acc = 0.0
    logger.loginfo("---> Start training")
    # Meta-training
    for eps, (episode_x, episode_y) in enumerate(train_loader):
        # each episode is one dataset split D = (D^{train}, D^{test})
        print(f'episode = {eps}')
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        #   e.g. torch.Size([5, 20, 3, 84, 84]): N classes, each with K support and K_eval query examples
        # episode_y.shape = [n_class, n_shot + n_eval]  (never used; targets are rebuilt below)
        train_input = episode_x[:, :args.n_shot].reshape(
            -1, *episode_x.shape[-3:]).to(args.dev)  # [n_class * n_shot, :]
        train_target = torch.LongTensor(
            np.repeat(range(args.n_class),
                      args.n_shot)).to(args.dev)  # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(
            -1, *episode_x.shape[-3:]).to(args.dev)  # [n_class * n_eval, :]
        test_target = torch.LongTensor(
            np.repeat(range(args.n_class),
                      args.n_eval)).to(args.dev)  # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input,
                           train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()

        logger.batch_info(eps=eps,
                          totaleps=args.episode,
                          loss=loss.item(),
                          acc=acc,
                          phase='train')

        # Meta-validation
        if eps % args.val_freq == 0 and eps != 0:
            save_ckpt(eps, metalearner, optim, args.save)
            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad,
                            metalearner, args, logger)
            if acc > best_acc:
                best_acc = acc
                logger.loginfo(f"* Best accuracy so far {acc}*\n")

    logger.loginfo(f'final validation acc: {acc}')
    logger.loginfo(f"* Best accuracy: {best_acc} *\n")
    logger.loginfo("Done")
Code example #14
File: main.py (project: zhanghaoyue/cavia)
def main(args):
    print('starting...')

    utils.set_seed(args.seed, cudnn=args.make_deterministic)

    continuous_actions = (args.env_name in ['AntVel-v1', 'AntDir-v1',
                                            'AntPos-v0', 'HalfCheetahVel-v1', 'HalfCheetahDir-v1',
                                            '2DNavigation-v0'])

    # subfolders for logging
    method_used = 'maml' if args.maml else 'cavia'
    num_context_params = str(args.num_context_params) + '_' if not args.maml else ''
    output_name = num_context_params + 'lr=' + str(args.fast_lr) + '_tau=' + str(args.tau)  # '_' keeps lr and tau readable in the folder name
    output_name += '_' + datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S')
    dir_path = os.path.dirname(os.path.realpath(__file__))
    log_folder = os.path.join(os.path.join(dir_path, 'logs'), args.env_name, method_used, output_name)
    save_folder = os.path.join(os.path.join(dir_path, 'saves'), output_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    # initialise tensorboard writer
    writer = SummaryWriter(log_folder)

    # save the config file (identical copies in the save and log folders)
    config = {k: v for (k, v) in vars(args).items() if k != 'device'}
    config.update(device=args.device.type)
    for folder in (save_folder, log_folder):
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(config, f, indent=2)

    sampler = BatchSampler(args.env_name, batch_size=args.fast_batch_size, num_workers=args.num_workers,
                           device=args.device, seed=args.seed)

    if continuous_actions:
        if not args.maml:
            policy = CaviaMLPPolicy(
                int(np.prod(sampler.envs.observation_space.shape)),
                int(np.prod(sampler.envs.action_space.shape)),
                hidden_sizes=(args.hidden_size,) * args.num_layers,
                num_context_params=args.num_context_params,
                device=args.device
            )
        else:
            policy = NormalMLPPolicy(
                int(np.prod(sampler.envs.observation_space.shape)),
                int(np.prod(sampler.envs.action_space.shape)),
                hidden_sizes=(args.hidden_size,) * args.num_layers
            )
    else:
        if not args.maml:
            raise NotImplementedError
        else:
            policy = CategoricalMLPPolicy(
                int(np.prod(sampler.envs.observation_space.shape)),
                sampler.envs.action_space.n,
                hidden_sizes=(args.hidden_size,) * args.num_layers)

    # initialise baseline
    baseline = LinearFeatureBaseline(int(np.prod(sampler.envs.observation_space.shape)))

    # initialise meta-learner
    metalearner = MetaLearner(sampler, policy, baseline, gamma=args.gamma, fast_lr=args.fast_lr, tau=args.tau,
                              device=args.device)

    for batch in range(args.num_batches):

        # get a batch of tasks
        tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)

        # do the inner-loop update for each task
        # this returns training (before update) and validation (after update) episodes
        episodes, inner_losses = metalearner.sample(tasks, first_order=args.first_order)

        # take the meta-gradient step
        outer_loss = metalearner.step(episodes, max_kl=args.max_kl, cg_iters=args.cg_iters,
                                      cg_damping=args.cg_damping, ls_max_steps=args.ls_max_steps,
                                      ls_backtrack_ratio=args.ls_backtrack_ratio)

        # -- logging

        curr_returns = total_rewards(episodes, interval=True)
        print('   return after update: ', curr_returns[0][1])

        # Tensorboard
        writer.add_scalar('policy/actions_train', episodes[0][0].actions.mean(), batch)
        writer.add_scalar('policy/actions_test', episodes[0][1].actions.mean(), batch)

        writer.add_scalar('running_returns/before_update', curr_returns[0][0], batch)
        writer.add_scalar('running_returns/after_update', curr_returns[0][1], batch)

        writer.add_scalar('running_cfis/before_update', curr_returns[1][0], batch)
        writer.add_scalar('running_cfis/after_update', curr_returns[1][1], batch)

        writer.add_scalar('loss/inner_rl', np.mean(inner_losses), batch)
        writer.add_scalar('loss/outer_rl', outer_loss.item(), batch)

        # -- evaluation

        # evaluate for multiple update steps
        if batch % args.test_freq == 0:
            test_tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)
            test_episodes = metalearner.test(test_tasks, num_steps=args.num_test_steps,
                                             batch_size=args.test_batch_size, halve_lr=args.halve_test_lr)
            all_returns = total_rewards(test_episodes, interval=True)
            for num in range(args.num_test_steps + 1):
                writer.add_scalar('evaluation_rew/avg_rew ' + str(num), all_returns[0][num], batch)
                writer.add_scalar('evaluation_cfi/avg_rew ' + str(num), all_returns[1][num], batch)

            print('   inner RL loss:', np.mean(inner_losses))
            print('   outer RL loss:', outer_loss.item())

        # -- save policy network
        with open(os.path.join(save_folder, 'policy-{0}.pt'.format(batch)), 'wb') as f:
            torch.save(policy.state_dict(), f)
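The loop above writes one policy checkpoint per meta-batch. Below is a minimal sketch of reloading one of those files for later evaluation, assuming the standard state_dict round-trip; the helper name is illustrative, not part of the repository.

import os
import torch

def load_policy(policy, save_folder, batch):
    """Restore the weights written above as policy-<batch>.pt (hypothetical helper)."""
    path = os.path.join(save_folder, 'policy-{0}.pt'.format(batch))
    state_dict = torch.load(path, map_location='cpu')
    policy.load_state_dict(state_dict)
    return policy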
Code example #15
File: meta_train.py (project: zyr17/metalight)
def metalight_train(dic_exp_conf, dic_agent_conf, _dic_traffic_env_conf,
                    _dic_path, tasks, batch_id):
    '''
        Metalight meta-train function.

        Arguments:
            dic_exp_conf:           dict,   configuration of this experiment
            dic_agent_conf:         dict,   configuration of agent
            _dic_traffic_env_conf:  dict,   configuration of traffic environment
            _dic_path:              dict,   paths of source files and output files
            tasks:                  list,   traffic file names used in this round
            batch_id:               int,    round number
    '''
    tot_path = []
    tot_traffic_env = []
    for task in tasks:
        dic_traffic_env_conf = copy.deepcopy(_dic_traffic_env_conf)
        dic_path = copy.deepcopy(_dic_path)
        dic_path.update({
            "PATH_TO_DATA":
            os.path.join(dic_path['PATH_TO_DATA'],
                         task.split(".")[0])
        })
        # parse roadnet
        dic_traffic_env_conf["ROADNET_FILE"] = dic_traffic_env_conf[
            "traffic_category"]["traffic_info"][task][2]
        dic_traffic_env_conf["FLOW_FILE"] = dic_traffic_env_conf[
            "traffic_category"]["traffic_info"][task][3]
        roadnet_path = os.path.join(
            dic_path['PATH_TO_DATA'],
            dic_traffic_env_conf["traffic_category"]["traffic_info"][task][2])
        lane_phase_info = parse_roadnet(roadnet_path)
        dic_traffic_env_conf["LANE_PHASE_INFO"] = lane_phase_info[
            "intersection_1_1"]
        dic_traffic_env_conf["num_lanes"] = int(
            len(lane_phase_info["intersection_1_1"]["start_lane"]) /
            4)  # num_lanes per direction
        dic_traffic_env_conf["num_phases"] = len(
            lane_phase_info["intersection_1_1"]["phase"])

        dic_traffic_env_conf["TRAFFIC_FILE"] = task

        tot_path.append(dic_path)
        tot_traffic_env.append(dic_traffic_env_conf)

    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=tot_traffic_env,
                           dic_path=tot_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)

    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=tot_traffic_env,
        dic_path=tot_path)

    metalearner = MetaLearner(sampler,
                              policy,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=tot_traffic_env,
                              dic_path=tot_path)

    if batch_id == 0:
        params = pickle.load(
            open(os.path.join(dic_path['PATH_TO_MODEL'], 'params_init.pkl'),
                 'rb'))
        params = [params] * len(policy.policy_inter)
        metalearner.meta_params = params
        metalearner.meta_target_params = params

    else:
        params = pickle.load(
            open(
                os.path.join(dic_path['PATH_TO_MODEL'],
                             'params_%d.pkl' % (batch_id - 1)), 'rb'))
        params = [params] * len(policy.policy_inter)
        metalearner.meta_params = params
        period = dic_agent_conf['PERIOD']
        target_id = int((batch_id - 1) / period)
        meta_params = pickle.load(
            open(
                os.path.join(dic_path['PATH_TO_MODEL'],
                             'params_%d.pkl' % (target_id * period)), 'rb'))
        meta_params = [meta_params] * len(policy.policy_inter)
        metalearner.meta_target_params = meta_params

    metalearner.sample_metalight(tasks, batch_id)
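The branch above reloads meta-parameters from the previous round but refreshes the meta-target parameters only every PERIOD rounds. The small sketch below spells out that cadence; the helper name is illustrative, but the arithmetic mirrors the snippet's target_id computation exactly.

def target_snapshot_id(batch_id, period):
    """Round whose saved params serve as the meta-target for this round."""
    return ((batch_id - 1) // period) * period

# e.g. with PERIOD = 5: rounds 1-5 reuse params_0.pkl,
# rounds 6-10 reuse params_5.pkl, and so on.
assert target_snapshot_id(3, 5) == 0
assert target_snapshot_id(6, 5) == 5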
Code example #16
def main(args):

    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    stopping_patience = args.stopping_patience
    epochs = args.epochs
    fast_lr = args.learning_rate
    slow_lr = args.meta_learning_rate

    first_order = False
    inner_loop_grad_clip = 20
    task_size = 50
    output_dim = 1

    horizon = 10

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY"), "Dataset was not correctly specified"

    window_size, task_size, input_dim = meta_info[dataset_name]

    output_directory = "output/"

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        load_model_file_ = output_directory + load_model_file

        # create the trial directory if it does not already exist
        os.makedirs(output_directory, exist_ok=True)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Learning rate :%f \n" % fast_lr)
            f.write("Meta-learning rate: %f \n" % slow_lr)
            f.write("Adaptation steps: %f \n" % n_inner_iter)
            f.write("\n")

        if model_name == "LSTM":
            model = LSTMModel(batch_size=batch_size,
                              seq_len=window_size,
                              input_dim=input_dim,
                              n_layers=2,
                              hidden_dim=120,
                              output_dim=output_dim)

        optimizer = torch.optim.Adam(model.parameters(), lr=slow_lr)
        loss_func = mae

        torch.backends.cudnn.enabled = False

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        meta_learner = MetaLearner(model, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_,
                                       verbose=True)

        for _ in range(epochs):

            # train: support tasks come from window t, query tasks from the next window t+1
            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)
            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            train_tasks = [
                Task(x_spt[i], y_spt[i]) for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(x_qry[i], y_qry[i]) for i in range(x_qry.shape[0])
            ]

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True)
            print(f'meta-train loss: {mean_loss}')

            # validate (this split drives early stopping, not the test set)
            val_error = test(validation_data_ML, meta_learner, device)
            print(f'validation error: {val_error}')

            early_stopping(val_error, meta_learner)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        model.load_state_dict(torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(model, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        validation_error = test(validation_data_ML, meta_learner, device)
        test_error = test(test_data_ML, meta_learner, device)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Test error: %f \n" % test_error)
            f.write("Validation error: %f \n" % validation_error)

        print(f'test error: {test_error}')
        print(f'validation error: {validation_error}')
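The EarlyStopping helper used above is defined elsewhere. Below is a hypothetical sketch of its contract, inferred only from the calls in this snippet: the attribute meta_learner._model and the verbose message are assumptions, while the checkpoint layout matches the "model_state_dict" key the script reloads.

import torch

class EarlyStopping:
    """Hypothetical sketch: checkpoint on improvement, stop after `patience` stalls."""

    def __init__(self, patience, model_file, verbose=False):
        self.patience = patience
        self.model_file = model_file
        self.verbose = verbose
        self.best = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_error, meta_learner):
        if val_error < self.best:
            self.best = val_error
            self.counter = 0
            # save under the key the training script reloads from
            torch.save({'model_state_dict': meta_learner._model.state_dict()},
                       self.model_file)
            if self.verbose:
                print(f'validation error improved to {val_error:.6f}')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True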