def main(args):
    # Set logging
    if not os.path.exists("./log"):
        os.makedirs("./log")
    log = set_log(args)
    tb_writer = SummaryWriter('./log/tb_{0}'.format(args.log_name))

    # Set seed
    set_seed(args.seed, cudnn=args.make_deterministic)

    # Set sampler
    sampler = BatchSampler(args, log)

    # Set policy
    policy = CaviaMLPPolicy(
        input_size=int(np.prod(sampler.observation_space.shape)),
        output_size=int(np.prod(sampler.action_space.shape)),
        hidden_sizes=(args.hidden_size,) * args.num_layers,
        num_context_params=args.num_context_params,
        device=args.device)

    # Initialise baseline
    baseline = LinearFeatureBaseline(int(np.prod(sampler.observation_space.shape)))

    # Initialise meta-learner
    metalearner = MetaLearner(sampler, policy, baseline, args, tb_writer)

    # Begin train
    train(sampler, metalearner, args, log, tb_writer)
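# The entry point above delegates all optimisation to `train()`. Below is a minimal
# sketch of the kind of outer loop that function runs, assuming the sampler/metalearner
# interface used in the CAVIA RL example later in this section (sample_tasks /
# metalearner.sample / metalearner.step); the hyperparameter names on `args` are
# illustrative, not a quote of the repository's implementation.
def train(sampler, metalearner, args, log, tb_writer):
    for batch in range(args.num_batches):
        # sample a batch of tasks and run the per-task inner-loop adaptation
        tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)
        episodes, inner_losses = metalearner.sample(tasks, first_order=args.first_order)

        # outer (meta) TRPO-style step over the post-adaptation episodes
        outer_loss = metalearner.step(episodes, max_kl=args.max_kl, cg_iters=args.cg_iters,
                                      cg_damping=args.cg_damping, ls_max_steps=args.ls_max_steps,
                                      ls_backtrack_ratio=args.ls_backtrack_ratio)

        tb_writer.add_scalar('loss/inner_rl', np.mean(inner_losses), batch)
        tb_writer.add_scalar('loss/outer_rl', outer_loss.item(), batch)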
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='cheetah_vel')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='gridworld')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot':
        args = args_point_robot.get_args(rest_args)
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    # --- Mujoco ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle':
        args = args_ant_semicircle.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    # make sure we have log directories
    try:
        os.makedirs(args.agent_log_dir)
    except OSError:
        files = glob.glob(os.path.join(args.agent_log_dir, '*.monitor.csv'))
        for f in files:
            os.remove(f)
    eval_log_dir = args.agent_log_dir + "_eval"
    try:
        os.makedirs(eval_log_dir)
    except OSError:
        files = glob.glob(os.path.join(eval_log_dir, '*.monitor.csv'))
        for f in files:
            os.remove(f)

    # set gpu
    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    # start training
    learner = MetaLearner(args)
    learner.train()
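# The script above does two-stage argument parsing: `--env-type` picks a config module,
# and every remaining flag is forwarded untouched to that module's `get_args`. A small
# self-contained illustration of the mechanism (the forwarded `--seed` flag here is
# hypothetical):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--env-type', default='cheetah_vel')
args, rest_args = parser.parse_known_args(['--env-type', 'point_robot_sparse', '--seed', '73'])
assert args.env_type == 'point_robot_sparse'
assert rest_args == ['--seed', '73']   # handed on to args_point_robot_sparse.get_args(rest_args)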
def _train(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path):
    '''
    Perform meta-testing for MAML, Metalight, Random, and Pretrained

    Arguments:
        dic_exp_conf:         dict, configuration of this experiment
        dic_agent_conf:       dict, configuration of agent
        dic_traffic_env_conf: dict, configuration of traffic environment
        dic_path:             dict, path of source files and output files
    '''
    random.seed(dic_agent_conf['SEED'])
    np.random.seed(dic_agent_conf['SEED'])
    tf.set_random_seed(dic_agent_conf['SEED'])

    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=dic_traffic_env_conf,
                           dic_path=dic_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)

    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        dic_path=dic_path)

    metalearner = MetaLearner(sampler, policy,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=dic_traffic_env_conf,
                              dic_path=dic_path)

    if dic_agent_conf['PRE_TRAIN']:
        if not dic_agent_conf['PRE_TRAIN_MODEL_NAME'] == 'random':
            params = pickle.load(
                open(
                    os.path.join('model', 'initial', "common",
                                 dic_agent_conf['PRE_TRAIN_MODEL_NAME'] + '.pkl'),
                    'rb'))
            metalearner.meta_params = params
            metalearner.meta_target_params = params

    tasks = dic_exp_conf['TRAFFIC_IN_TASKS']

    episodes = None
    for batch_id in range(dic_exp_conf['NUM_ROUNDS']):
        tasks = [dic_exp_conf['TRAFFIC_FILE']]
        if dic_agent_conf['MULTI_EPISODES']:
            episodes = metalearner.sample_meta_test(tasks[0], batch_id, episodes)
        else:
            episodes = metalearner.sample_meta_test(tasks[0], batch_id)
def main(args): dataset_name = args.dataset model_name = args.model n_inner_iter = args.adaptation_steps batch_size = args.batch_size save_model_file = args.save_model_file load_model_file = args.load_model_file lower_trial = args.lower_trial upper_trial = args.upper_trial is_test = args.is_test stopping_patience = args.stopping_patience epochs = args.epochs fast_lr = args.learning_rate slow_lr = args.meta_learning_rate noise_level = args.noise_level noise_type = args.noise_type resume = args.resume first_order = False inner_loop_grad_clip = 20 task_size = 50 output_dim = 1 checkpoint_freq = 10 horizon = 10 ##test meta_info = { "POLLUTION": [5, 50, 14], "HR": [32, 50, 13], "BATTERY": [20, 50, 3] } assert model_name in ("FCN", "LSTM"), "Model was not correctly specified" assert dataset_name in ("POLLUTION", "HR", "BATTERY") window_size, task_size, input_dim = meta_info[dataset_name] grid = [0., noise_level] output_directory = "output/" train_data_ML = pickle.load( open( "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) validation_data_ML = pickle.load( open( "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) test_data_ML = pickle.load( open( "../../Data/TEST-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) for trial in range(lower_trial, upper_trial): output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str( trial) + "/" save_model_file_ = output_directory + save_model_file save_model_file_encoder = output_directory + "encoder_" + save_model_file load_model_file_ = output_directory + load_model_file checkpoint_file = output_directory + "checkpoint_" + save_model_file.split( ".")[0] try: os.mkdir(output_directory) except OSError as error: print(error) with open(output_directory + "/results2.txt", "a+") as f: f.write("Learning rate :%f \n" % fast_lr) f.write("Meta-learning rate: %f \n" % slow_lr) f.write("Adaptation steps: %f \n" % n_inner_iter) f.write("Noise level: %f \n" % noise_level) if model_name == "LSTM": model = LSTMModel(batch_size=batch_size, seq_len=window_size, input_dim=input_dim, n_layers=2, hidden_dim=120, output_dim=output_dim) model2 = LinearModel(120, 1) optimizer = torch.optim.Adam(list(model.parameters()) + list(model2.parameters()), lr=slow_lr) loss_func = mae #loss_func = nn.SmoothL1Loss() #loss_func = nn.MSELoss() initial_epoch = 0 #torch.backends.cudnn.enabled = False device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order, n_inner_iter, inner_loop_grad_clip, device) model.to(device) early_stopping = EarlyStopping(patience=stopping_patience, model_file=save_model_file_encoder, verbose=True) early_stopping2 = EarlyStopping(patience=stopping_patience, model_file=save_model_file_, verbose=True) if resume: checkpoint = torch.load(checkpoint_file) model.load_state_dict(checkpoint["model"]) meta_learner.load_state_dict(checkpoint["meta_learner"]) initial_epoch = checkpoint["epoch"] best_score = checkpoint["best_score"] counter = checkpoint["counter_stopping"] early_stopping.best_score = best_score early_stopping2.best_score = best_score early_stopping.counter = counter early_stopping2.counter = counter total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape accum_mean = 0.0 for epoch in range(initial_epoch, epochs): model.zero_grad() meta_learner._model.zero_grad() #train batch_idx = 
np.random.randint(0, total_tasks - 1, batch_size) #for batch_idx in range(0, total_tasks-1, batch_size): x_spt, y_spt = train_data_ML[batch_idx] x_qry, y_qry = train_data_ML[batch_idx + 1] x_spt, y_spt = to_torch(x_spt), to_torch(y_spt) x_qry = to_torch(x_qry) y_qry = to_torch(y_qry) # data augmentation epsilon = grid[np.random.randint(0, len(grid))] if noise_type == "additive": y_spt = y_spt + epsilon y_qry = y_qry + epsilon else: y_spt = y_spt * (1 + epsilon) y_qry = y_qry * (1 + epsilon) train_tasks = [ Task(model.encoder(x_spt[i]), y_spt[i]) for i in range(x_spt.shape[0]) ] val_tasks = [ Task(model.encoder(x_qry[i]), y_qry[i]) for i in range(x_qry.shape[0]) ] adapted_params = meta_learner.adapt(train_tasks) mean_loss = meta_learner.step(adapted_params, val_tasks, is_training=True) #accum_mean += mean_loss.cpu().detach().numpy() #progressBar(batch_idx, total_tasks, 100) #print(accum_mean/(batch_idx+1)) #test val_error = test(validation_data_ML, meta_learner, model, device, noise_level) test_error = test(test_data_ML, meta_learner, model, device, 0.0) print("Epoch:", epoch) print("Val error:", val_error) print("Test error:", test_error) early_stopping(val_error, model) early_stopping2(val_error, meta_learner) #checkpointing if epochs % checkpoint_freq == 0: torch.save( { "epoch": epoch, "model": model.state_dict(), "meta_learner": meta_learner.state_dict(), "best_score": early_stopping2.best_score, "counter_stopping": early_stopping2.counter }, checkpoint_file) if early_stopping.early_stop: print("Early stopping") break print("hallo") model.load_state_dict(torch.load(save_model_file_encoder)) model2.load_state_dict( torch.load(save_model_file_)["model_state_dict"]) meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order, n_inner_iter, inner_loop_grad_clip, device) validation_error = test(validation_data_ML, meta_learner, model, device, noise_level=0.0) test_error = test(test_data_ML, meta_learner, model, device, noise_level=0.0) validation_error_h1 = test(validation_data_ML, meta_learner, model, device, noise_level=0.0, horizon=1) test_error_h1 = test(test_data_ML, meta_learner, model, device, noise_level=0.0, horizon=1) model.load_state_dict(torch.load(save_model_file_encoder)) model2.load_state_dict( torch.load(save_model_file_)["model_state_dict"]) meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order, 0, inner_loop_grad_clip, device) validation_error_h0 = test(validation_data_ML, meta_learner2, model, device, noise_level=0.0, horizon=1) test_error_h0 = test(test_data_ML, meta_learner2, model, device, noise_level=0.0, horizon=1) model.load_state_dict(torch.load(save_model_file_encoder)) model2.load_state_dict( torch.load(save_model_file_)["model_state_dict"]) meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order, n_inner_iter, inner_loop_grad_clip, device) validation_error_mae = test(validation_data_ML, meta_learner2, model, device, 0.0) test_error_mae = test(test_data_ML, meta_learner2, model, device, 0.0) print("test_error_mae", test_error_mae) with open(output_directory + "/results2.txt", "a+") as f: f.write("Test error: %f \n" % test_error) f.write("Validation error: %f \n" % validation_error) f.write("Test error h1: %f \n" % test_error_h1) f.write("Validation error h1: %f \n" % validation_error_h1) f.write("Test error h0: %f \n" % test_error_h0) f.write("Validation error h0: %f \n" % validation_error_h0) f.write("Test error mae: %f \n" % test_error_mae) f.write("Validation error mae: %f \n" % 
validation_error_mae) print(test_error) print(validation_error)
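# The heart of the training function above is the Task -> adapt -> step pattern: support
# windows come from the sampled task indices `batch_idx`, and the corresponding query
# windows from the following task (`batch_idx + 1`). A condensed sketch of that pattern,
# reusing the same names and MetaLearner / Task API as in the code above:
x_spt, y_spt = train_data_ML[batch_idx]              # support windows for the sampled tasks
x_qry, y_qry = train_data_ML[batch_idx + 1]          # query windows from the following task
x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
x_qry, y_qry = to_torch(x_qry), to_torch(y_qry)

train_tasks = [Task(model.encoder(x_spt[i]), y_spt[i]) for i in range(x_spt.shape[0])]
val_tasks = [Task(model.encoder(x_qry[i]), y_qry[i]) for i in range(x_qry.shape[0])]

adapted_params = meta_learner.adapt(train_tasks)     # inner loop: n_inner_iter adaptation steps
mean_loss = meta_learner.step(adapted_params,        # outer loop: meta-update on the query tasks
                              val_tasks, is_training=True)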
def main(): parser = argparse.ArgumentParser() parser.add_argument('--env-type', default='gridworld_varibad') args, rest_args = parser.parse_known_args() env = args.env_type # --- GridWorld --- # standard if env == 'gridworld_oracle': args = args_grid_oracle.get_args(rest_args) elif env == 'gridworld_belief_oracle': args = args_grid_belief_oracle.get_args(rest_args) elif env == 'gridworld_varibad': args = args_grid_varibad.get_args(rest_args) elif env == 'gridworld_rl2': args = args_grid_rl2.get_args(rest_args) # --- MUJOCO --- # - AntDir - elif env == 'mujoco_ant_dir_oracle': args = args_mujoco_ant_dir_oracle.get_args(rest_args) elif env == 'mujoco_ant_dir_rl2': args = args_mujoco_ant_dir_rl2.get_args(rest_args) elif env == 'mujoco_ant_dir_varibad': args = args_mujoco_ant_dir_varibad.get_args(rest_args) # # - CheetahDir - elif env == 'mujoco_cheetah_dir_oracle': args = args_mujoco_cheetah_dir_oracle.get_args(rest_args) elif env == 'mujoco_cheetah_dir_rl2': args = args_mujoco_cheetah_dir_rl2.get_args(rest_args) elif env == 'mujoco_cheetah_dir_varibad': args = args_mujoco_cheetah_dir_varibad.get_args(rest_args) # # - CheetahVel - elif env == 'mujoco_cheetah_vel_oracle': args = args_mujoco_cheetah_vel_oracle.get_args(rest_args) elif env == 'mujoco_cheetah_vel_rl2': args = args_mujoco_cheetah_vel_rl2.get_args(rest_args) elif env == 'mujoco_cheetah_vel_varibad': args = args_mujoco_cheetah_vel_varibad.get_args(rest_args) # # - Walker - elif env == 'mujoco_walker_oracle': args = args_mujoco_walker_oracle.get_args(rest_args) elif env == 'mujoco_walker_rl2': args = args_mujoco_walker_rl2.get_args(rest_args) elif env == 'mujoco_walker_varibad': args = args_mujoco_walker_varibad.get_args(rest_args) # # - CheetahHField elif env == 'mujoco_cheetah_hfield_varibad': args = args_mujoco_cheetah_hfield_varibad.get_args(rest_args) # - CheetahHill elif env == 'mujoco_cheetah_hill_varibad': args = args_mujoco_cheetah_hill_varibad.get_args(rest_args) # - CheetahBasin elif env == 'mujoco_cheetah_basin_varibad': args = args_mujoco_cheetah_basin_varibad.get_args(rest_args) # - CheetahGentle elif env == 'mujoco_cheetah_gentle_varibad': args = args_mujoco_cheetah_gentle_varibad.get_args(rest_args) # - CheetahSteep elif env == 'mujoco_cheetah_steep_varibad': args = args_mujoco_cheetah_steep_varibad.get_args(rest_args) # - CheetahJoint elif env == 'mujoco_cheetah_joint_varibad': args = args_mujoco_cheetah_joint_varibad.get_args(rest_args) # - CheetahBlocks elif env == 'mujoco_cheetah_blocks_varibad': args = args_mujoco_cheetah_blocks_varibad.get_args(rest_args) # make sure we have log directories for mujoco if 'mujoco' in env: try: os.makedirs(args.agent_log_dir) except OSError: files = glob.glob(os.path.join(args.agent_log_dir, '*.monitor.csv')) for f in files: os.remove(f) eval_log_dir = args.agent_log_dir + "_eval" try: os.makedirs(eval_log_dir) except OSError: files = glob.glob(os.path.join(eval_log_dir, '*.monitor.csv')) for f in files: os.remove(f) # warning if args.deterministic_execution: print('Envoking deterministic code execution.') if torch.backends.cudnn.enabled: warnings.warn('Running with deterministic CUDNN.') if args.num_processes > 1: raise RuntimeError( 'If you want fully deterministic code, run it with num_processes=1.' 'Warning: This will slow things down and might break A2C if ' 'policy_num_steps < env._max_episode_steps.') # start training if args.disable_varibad: # When the flag `disable_varibad` is activated, the file `learner.py` will be used instead of `metalearner.py`. 
# This is a stripped down version without encoder, decoder, stochastic latent variables, etc. learner = Learner(args) else: learner = MetaLearner(args) learner.train()
def main(args): dataset_name = args.dataset model_name = args.model n_inner_iter = args.adaptation_steps meta_learning_rate = args.meta_learning_rate learning_rate = args.learning_rate batch_size = args.batch_size save_model_file = args.save_model_file load_model_file = args.load_model_file lower_trial = args.lower_trial upper_trial = args.upper_trial task_size = args.task_size noise_level = args.noise_level noise_type = args.noise_type epochs = args.epochs loss_fcn_str = args.loss modulate_task_net = args.modulate_task_net weight_vrae = args.weight_vrae stopping_patience = args.stopping_patience meta_info = {"POLLUTION": [5, 14], "HR": [32, 13], "BATTERY": [20, 3]} assert model_name in ("FCN", "LSTM"), "Model was not correctly specified" assert dataset_name in ("POLLUTION", "HR", "BATTERY") window_size, input_dim = meta_info[dataset_name] grid = [0., noise_level] train_data_ML = pickle.load( open( "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) validation_data_ML = pickle.load( open( "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) test_data_ML = pickle.load( open( "../../Data/TEST-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) total_tasks = len(train_data_ML) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") loss_fn = mae if loss_fcn_str == "MAE" else nn.SmoothL1Loss() ##multimodal learner parameters # paramters wto increase capactiy of the model n_layers_task_net = 2 n_layers_task_encoder = 2 n_layers_task_decoder = 2 hidden_dim_task_net = 120 hidden_dim_encoder = 120 hidden_dim_decoder = 120 # fixed values input_dim_task_net = input_dim input_dim_task_encoder = input_dim + 1 output_dim_task_net = 1 output_dim_task_decoder = input_dim + 1 first_order = False inner_loop_grad_clip = 20 for trial in range(lower_trial, upper_trial): output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MMAML/" + str( trial) + "/" save_model_file_ = output_directory + save_model_file save_model_file_encoder = output_directory + "encoder_" + save_model_file load_model_file_ = output_directory + load_model_file checkpoint_file = output_directory + "checkpoint_" + save_model_file.split( ".")[0] writer = SummaryWriter() try: os.mkdir(output_directory) except OSError as error: print(error) task_net = LSTMModel(batch_size=batch_size, seq_len=window_size, input_dim=input_dim_task_net, n_layers=n_layers_task_net, hidden_dim=hidden_dim_task_net, output_dim=output_dim_task_net) task_encoder = LSTMModel(batch_size=batch_size, seq_len=task_size, input_dim=input_dim_task_encoder, n_layers=n_layers_task_encoder, hidden_dim=hidden_dim_encoder, output_dim=1) task_decoder = LSTMDecoder(batch_size=1, n_layers=n_layers_task_decoder, seq_len=task_size, output_dim=output_dim_task_decoder, hidden_dim=hidden_dim_encoder, latent_dim=hidden_dim_decoder, device=device) lmbd = Lambda(hidden_dim_encoder, hidden_dim_task_net) multimodal_learner = MultimodalLearner(task_net, task_encoder, task_decoder, lmbd, modulate_task_net) multimodal_learner.to(device) output_layer = LinearModel(120, 1) opt = torch.optim.Adam(list(multimodal_learner.parameters()) + list(output_layer.parameters()), lr=meta_learning_rate) meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn, first_order, n_inner_iter, inner_loop_grad_clip, device) early_stopping = EarlyStopping(patience=stopping_patience, model_file=save_model_file_, verbose=True) 
early_stopping_encoder = EarlyStopping( patience=stopping_patience, model_file=save_model_file_encoder, verbose=True) task_data_train = torch.FloatTensor( get_task_encoder_input(train_data_ML)) task_data_validation = torch.FloatTensor( get_task_encoder_input(validation_data_ML)) task_data_test = torch.FloatTensor( get_task_encoder_input(test_data_ML)) val_loss_hist = [] test_loss_hist = [] for epoch in range(epochs): multimodal_learner.train() batch_idx = np.random.randint(0, total_tasks - 1, batch_size) task = task_data_train[batch_idx].cuda() x_spt, y_spt = train_data_ML[batch_idx] x_qry, y_qry = train_data_ML[batch_idx + 1] x_spt, y_spt = to_torch(x_spt), to_torch(y_spt) x_qry = to_torch(x_qry) y_qry = to_torch(y_qry) # data augmentation epsilon = grid[np.random.randint(0, len(grid))] if noise_type == "additive": y_spt = y_spt + epsilon y_qry = y_qry + epsilon else: y_spt = y_spt * (1 + epsilon) y_qry = y_qry * (1 + epsilon) x_spt_encodings = [] x_qry_encodings = [] vrae_loss_accum = 0.0 for i in range(batch_size): x_spt_encoding, (vrae_loss, kl_loss, rec_loss) = multimodal_learner( x_spt[i], task[i:i + 1], output_encoding=True) x_spt_encodings.append(x_spt_encoding) vrae_loss_accum += vrae_loss x_qry_encoding, _ = multimodal_learner(x_qry[i], task[i:i + 1], output_encoding=True) x_qry_encodings.append(x_qry_encoding) train_tasks = [ Task(x_spt_encodings[i], y_spt[i]) for i in range(x_spt.shape[0]) ] val_tasks = [ Task(x_qry_encodings[i], y_qry[i]) for i in range(x_qry.shape[0]) ] # print(vrae_loss) adapted_params = meta_learner.adapt(train_tasks) mean_loss = meta_learner.step(adapted_params, val_tasks, is_training=True, additional_loss_term=weight_vrae * vrae_loss_accum / batch_size) ##plotting grad of output layer for tag, parm in output_layer.linear.named_parameters(): writer.add_histogram("Grads_output_layer_" + tag, parm.grad.data.cpu().numpy(), epoch) multimodal_learner.eval() val_loss = test(validation_data_ML, multimodal_learner, meta_learner, task_data_validation) test_loss = test(test_data_ML, multimodal_learner, meta_learner, task_data_test) print("Epoch:", epoch) print("Train loss:", mean_loss) print("Val error:", val_loss) print("Test error:", test_loss) early_stopping(val_loss, meta_learner) early_stopping_encoder(val_loss, multimodal_learner) val_loss_hist.append(val_loss) test_loss_hist.append(test_loss) if early_stopping.early_stop: print("Early stopping") break writer.add_scalar("Loss/train", mean_loss.cpu().detach().numpy(), epoch) writer.add_scalar("Loss/val", val_loss, epoch) writer.add_scalar("Loss/test", test_loss, epoch) multimodal_learner.load_state_dict(torch.load(save_model_file_encoder)) output_layer.load_state_dict( torch.load(save_model_file_)["model_state_dict"]) meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn, first_order, n_inner_iter, inner_loop_grad_clip, device) val_loss = test(validation_data_ML, multimodal_learner, meta_learner, task_data_validation) test_loss = test(test_data_ML, multimodal_learner, meta_learner, task_data_test) with open(output_directory + "/results3.txt", "a+") as f: f.write("Dataset :%s \n" % dataset_name) f.write("Test error: %f \n" % test_loss) f.write("Val error: %f \n" % val_loss) f.write("\n") writer.add_hparams( { "fast_lr": learning_rate, "slow_lr": meta_learning_rate, "adaption_steps": n_inner_iter, "patience": stopping_patience, "weight_vrae": weight_vrae, "noise_level": noise_level, "dataset": dataset_name, "trial": trial }, { "val_loss": val_loss, "test_loss": test_loss })
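# The EarlyStopping helper used throughout these scripts is assumed to follow the usual
# contract: it is called with the current validation error and the object to checkpoint,
# tracks the best score and a patience counter, and raises `early_stop` once the patience
# is exhausted. A hypothetical minimal version for reference (the real class may save
# either the raw state_dict or a wrapped dict, as both load patterns appear in this
# section):
import torch

class EarlyStopping:
    def __init__(self, patience, model_file, verbose=False):
        self.patience = patience
        self.model_file = model_file
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_error, model):
        if self.best_score is None or val_error < self.best_score:
            # improvement: reset the counter and checkpoint the model
            self.best_score = val_error
            self.counter = 0
            torch.save({"model_state_dict": model.state_dict()}, self.model_file)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True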
def main(): wandb.init(project="ofsl-implementation", entity="joeljosephjin") args, unparsed = FLAGS.parse_known_args() if len(unparsed) != 0: raise NameError("Argument {} not recognized".format(unparsed)) if args.seed is None: args.seed = random.randint(0, 1e3) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if args.cpu: args.dev = torch.device('cpu') else: if not torch.cuda.is_available(): raise RuntimeError("GPU unavailable.") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False args.dev = torch.device('cuda') logger = GOATLogger(args) # Load train/validation/test tasksets using the benchmark interface tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet', train_ways=args.n_class, train_samples=2*args.n_shot, test_ways=args.n_class, test_samples=2*args.n_shot, root='~/data', ) # Set up learner, meta-learner learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev) learner_wo_grad = copy.deepcopy(learner_w_grad) metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev) # gets the model parameters in a concatenated torch list; then pushes it into cI.data metalearner.metalstm.init_cI(learner_w_grad.get_flat_params()) # Set up loss, optimizer, learning rate scheduler optim = torch.optim.Adam(metalearner.parameters(), args.lr) if args.resume: logger.loginfo("Initialized from: {}".format(args.resume)) last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev) if args.mode == 'test': _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) return best_acc = 0.0 logger.loginfo("Start training") # Meta-training for eps in range(50000): # episode_x.shape = [n_class, n_shot + n_eval, c, h, w] # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED batch = tasksets.train.sample() adapt_x, adapt_y, eval_x, eval_y = process_batch(batch, args) # print('len(adapt_x)',len(adapt_x)) # 25 # Train learner with metalearner learner_w_grad.reset_batch_stats() learner_wo_grad.reset_batch_stats() learner_w_grad.train() learner_wo_grad.train() cI = metalearner.metalstm.cI.data h = None for _ in range(args.epoch): # get the loss/grad # copy from cell state to model.parameters learner_w_grad.copy_flat_params(cI) # do a forward pass and get the loss output = learner_w_grad(adapt_x) loss = learner_w_grad.criterion(output, adapt_y) acc = accuracy(output, adapt_y) # populate the gradients learner_w_grad.zero_grad() loss.backward() # get the grad from the lwg.parameters grad = torch.cat([p.grad.data.view(-1) / args.batch_size for p in learner_w_grad.parameters()], 0) # preprocess grad & loss and metalearner forward grad_prep = preprocess_grad_loss(grad) # [n_learner_params, 2] loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2] # push the loss, grad thru the metalearner cI, h = metalearner([loss_prep, grad_prep, grad.unsqueeze(1)], h) # Train meta-learner with validation loss # same as copy_flat_params; only diff = parameters of the model are not nn.Params anymore, they're just plain tensors now. 
learner_wo_grad.transfer_params(learner_w_grad, cI) # do a forward pass and get the loss output = learner_wo_grad(eval_x) loss = learner_wo_grad.criterion(output, eval_y) acc = accuracy(output, eval_y) # update the metalearner aka the metalstm optim.zero_grad() loss.backward() nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip) optim.step() # loggers logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train') wandb.log({'loss':loss.item(), 'accuracy':acc}, step=eps) # Meta-validation if eps % args.val_freq == 0 and eps != 0: save_ckpt(eps, metalearner, optim, args.save) val_batch = tasksets.train.sample() acc, test_loss = meta_test(eps, val_batch, learner_w_grad, learner_wo_grad, metalearner, args, logger) wandb.log({'test_loss':test_loss.item(), 'test_accuracy':acc}, step=eps) if acc > best_acc: best_acc = acc logger.loginfo("* Best accuracy so far *\n") logger.loginfo("Done")
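# Before the meta-LSTM forward pass above, the flattened gradient and the loss are run
# through `preprocess_grad_loss`, producing the [n, 2] shape noted in the comments. A
# sketch of the standard log/sign preprocessing from Andrychowicz et al. (2016) that such
# a helper typically implements -- an assumption about this repo's exact code, not a quote:
import math
import torch

def preprocess_grad_loss(x, p=10.0):
    eps = math.exp(-p)
    indicator = (x.abs() >= eps).float()
    # channel 1: scaled log-magnitude for "large" values, -1 otherwise
    x1 = indicator * torch.log(x.abs() + 1e-8) / p - (1 - indicator)
    # channel 2: sign for "large" values, exp(p) * x otherwise
    x2 = indicator * torch.sign(x) + (1 - indicator) * math.exp(p) * x
    return torch.stack((x1, x2), dim=1)  # [n, 2]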
from metalearner import MetaLearner
import argparse
from config.mujoco import args_mujoco_cheetah_joint_varibad, args_mujoco_cheetah_hfield_varibad, \
    args_mujoco_cheetah_blocks_varibad
import matplotlib

matplotlib.use('Agg')

parser = argparse.ArgumentParser()
parser.add_argument('--env-type', default='gridworld_varibad')
args, rest_args = parser.parse_known_args()
env = args.env_type

# args = args_mujoco_cheetah_joint_varibad.get_args(rest_args)
# args = args_mujoco_cheetah_hfield_varibad.get_args(rest_args)
args = args_mujoco_cheetah_blocks_varibad.get_args(rest_args)

metalearner = MetaLearner(args)
metalearner.load_and_render(load_iter=4000)
# metalearner.load(load_iter=3500)
def main(): args, unparsed = FLAGS.parse_known_args() if len(unparsed) != 0: raise NameError("Argument {} not recognized".format(unparsed)) if args.seed is None: args.seed = random.randint(0, 1e3) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if args.cpu: args.dev = torch.device('cpu') else: if not torch.cuda.is_available(): raise RuntimeError("GPU unavailable.") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False args.dev = torch.device('cuda') logger = GOATLogger(args) # Get data #train_loader, val_loader, test_loader = prepare_data(args) dataset = StateDataset(epoch_len=10240, batch_size=1) train_loader = torch.utils.data.DataLoader(dataset) # Set up learner, meta-learner learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev) learner_wo_grad = copy.deepcopy(learner_w_grad) metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev) metalearner.metalstm.init_cI(learner_w_grad.get_flat_params()) A = torch.tensor(np.array([[1.0,1.0],[0.0,1.0]])).float().to(args.dev) B = torch.tensor(np.array([[0.0],[1.0]])).float().to(args.dev) Q = torch.tensor(0.01*np.diag([1.0, 1.0])).float().to(args.dev) R = torch.tensor([[0.1]]).float().to(args.dev) # Set up loss, optimizer, learning rate scheduler optim = torch.optim.Adam(metalearner.parameters(), args.lr) if args.resume: logger.loginfo("Initialized from: {}".format(args.resume)) last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev) if args.mode == 'test': _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) return best_acc = 0.0 logger.loginfo("Start training") # Meta-training for eps, X in enumerate(train_loader): # episode_x.shape = [n_class, n_shot + n_eval, c, h, w] # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED #train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :] #train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot] #test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :] #test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval] X = X.float().to(args.dev) # Train learner with metalearner learner_w_grad.reset_batch_stats() learner_wo_grad.reset_batch_stats() learner_w_grad.train() learner_wo_grad.train() cI = train_learner(learner_w_grad, metalearner, X, args) x = X[0,0:1] # Train meta-learner with validation loss learner_wo_grad.transfer_params(learner_w_grad, cI) T = 15 x_list = [] x_list.append(x) u_list = [] for i in range(T): u = learner_wo_grad(x_list[i]) x_next = A@x_list[i].T + [email protected] x_list.append(x_next.T) u_list.append(u) #output = learner_wo_grad(X) loss = learner_wo_grad.criterion(x_list, u_list, A, B, Q, R) #acc = accuracy(output, test_target) if(eps%1000==0): print('loss: ', loss.item()) print(x_list) optim.zero_grad() loss.backward() nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip) optim.step() if(eps%100==0): print(eps) #logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), phase='train') # Meta-validation #if eps % args.val_freq == 0 and eps != 0: # save_ckpt(eps, metalearner, optim, args.save) #acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) #if acc > best_acc: # 
best_acc = acc # logger.loginfo("* Best accuracy so far *\n") logger.loginfo("Done")
def main(): parser = argparse.ArgumentParser() parser.add_argument('--env-type', default='gridworld_varibad') args, rest_args = parser.parse_known_args() env = args.env_type # --- GridWorld --- if env == 'gridworld_belief_oracle': args = args_grid_belief_oracle.get_args(rest_args) elif env == 'gridworld_varibad': args = args_grid_varibad.get_args(rest_args) elif env == 'gridworld_rl2': args = args_grid_rl2.get_args(rest_args) # --- PointRobot 2D Navigation --- elif env == 'pointrobot_multitask': args = args_pointrobot_multitask.get_args(rest_args) elif env == 'pointrobot_varibad': args = args_pointrobot_varibad.get_args(rest_args) elif env == 'pointrobot_rl2': args = args_pointrobot_rl2.get_args(rest_args) elif env == 'pointrobot_humplik': args = args_pointrobot_humplik.get_args(rest_args) # --- MUJOCO --- # - CheetahDir - elif env == 'cheetah_dir_multitask': args = args_cheetah_dir_multitask.get_args(rest_args) elif env == 'cheetah_dir_expert': args = args_cheetah_dir_expert.get_args(rest_args) elif env == 'cheetah_dir_varibad': args = args_cheetah_dir_varibad.get_args(rest_args) elif env == 'cheetah_dir_rl2': args = args_cheetah_dir_rl2.get_args(rest_args) # # - CheetahVel - elif env == 'cheetah_vel_multitask': args = args_cheetah_vel_multitask.get_args(rest_args) elif env == 'cheetah_vel_expert': args = args_cheetah_vel_expert.get_args(rest_args) elif env == 'cheetah_vel_avg': args = args_cheetah_vel_avg.get_args(rest_args) elif env == 'cheetah_vel_varibad': args = args_cheetah_vel_varibad.get_args(rest_args) elif env == 'cheetah_vel_rl2': args = args_cheetah_vel_rl2.get_args(rest_args) # # - AntDir - elif env == 'ant_dir_multitask': args = args_ant_dir_multitask.get_args(rest_args) elif env == 'ant_dir_expert': args = args_ant_dir_expert.get_args(rest_args) elif env == 'ant_dir_varibad': args = args_ant_dir_varibad.get_args(rest_args) elif env == 'ant_dir_rl2': args = args_ant_dir_rl2.get_args(rest_args) # # - AntGoal - elif env == 'ant_goal_multitask': args = args_ant_goal_multitask.get_args(rest_args) elif env == 'ant_goal_expert': args = args_ant_goal_expert.get_args(rest_args) elif env == 'ant_goal_varibad': args = args_ant_goal_varibad.get_args(rest_args) elif env == 'ant_goal_humplik': args = args_ant_goal_humplik.get_args(rest_args) elif env == 'ant_goal_rl2': args = args_ant_goal_rl2.get_args(rest_args) # # - Walker - elif env == 'walker_multitask': args = args_walker_multitask.get_args(rest_args) elif env == 'walker_expert': args = args_walker_expert.get_args(rest_args) elif env == 'walker_avg': args = args_walker_avg.get_args(rest_args) elif env == 'walker_varibad': args = args_walker_varibad.get_args(rest_args) elif env == 'walker_rl2': args = args_walker_rl2.get_args(rest_args) # # - HumanoidDir - elif env == 'humanoid_dir_multitask': args = args_humanoid_dir_multitask.get_args(rest_args) elif env == 'humanoid_dir_expert': args = args_humanoid_dir_expert.get_args(rest_args) elif env == 'humanoid_dir_varibad': args = args_humanoid_dir_varibad.get_args(rest_args) elif env == 'humanoid_dir_rl2': args = args_humanoid_dir_rl2.get_args(rest_args) else: raise Exception("Invalid Environment") # warning for deterministic execution if args.deterministic_execution: print('Envoking deterministic code execution.') if torch.backends.cudnn.enabled: warnings.warn('Running with deterministic CUDNN.') if args.num_processes > 1: raise RuntimeError( 'If you want fully deterministic code, run it with num_processes=1.' 
'Warning: This will slow things down and might break A2C if ' 'policy_num_steps < env._max_episode_steps.') # if we're normalising the actions, we have to make sure that the env expects actions within [-1, 1] if args.norm_actions_pre_sampling or args.norm_actions_post_sampling: envs = make_vec_envs( env_name=args.env_name, seed=0, num_processes=args.num_processes, gamma=args.policy_gamma, device='cpu', episodes_per_task=args.max_rollouts_per_task, normalise_rew=args.norm_rew_for_policy, ret_rms=None, tasks=None, ) assert np.unique(envs.action_space.low) == [-1] assert np.unique(envs.action_space.high) == [1] # clean up arguments if args.disable_metalearner or args.disable_decoder: args.decode_reward = False args.decode_state = False args.decode_task = False if hasattr(args, 'decode_only_past') and args.decode_only_past: args.split_batches_by_elbo = True # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes: # args.split_batches_by_elbo = True # begin training (loop through all passed seeds) seed_list = [args.seed] if isinstance(args.seed, int) else args.seed for seed in seed_list: print('training', seed) args.seed = seed args.action_space = None if args.disable_metalearner: # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`. # This is a stripped down version without encoder, decoder, stochastic latent variables, etc. learner = Learner(args) else: learner = MetaLearner(args) learner.train()
def main(): args, unparsed = FLAGS.parse_known_args() if len(unparsed) != 0: raise NameError("Argument {} not recognized".format(unparsed)) if args.seed is None: args.seed = random.randint(0, 1e3) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if args.cpu: args.dev = torch.device('cpu') else: if not torch.cuda.is_available(): raise RuntimeError("GPU unavailable.") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False args.dev = torch.device('cuda') #logger = GOATLogger(args) use_qrnn = True # Get data train_loader, val_loader, test_loader = prepare_data(args) # Set up learner, meta-learner learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev) learner_wo_grad = copy.deepcopy(learner_w_grad) metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0), use_qrnn).to(args.dev) metalearner.metalstm.init_cI(learner_w_grad.get_flat_params()) # Set up loss, optimizer, learning rate scheduler optim = torch.optim.Adam(metalearner.parameters(), args.lr) if args.resume: #logger.loginfo("Initialized from: {}".format(args.resume)) last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev) if args.mode == 'test': #_ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) return best_acc = 0.0 print("Starting training...") print("Shots: ", args.n_shot) print("Classes: ", args.n_class) start_time = datetime.now() # Meta-training for eps, (episode_x, episode_y) in enumerate(train_loader): # episode_x.shape = [n_class, n_shot + n_eval, c, h, w] # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED train_input = episode_x[:, :args.n_shot].reshape( -1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :] train_target = torch.LongTensor( np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot] test_input = episode_x[:, args.n_shot:].reshape( -1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :] test_target = torch.LongTensor( np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval] # Train learner with metalearner learner_w_grad.reset_batch_stats() learner_wo_grad.reset_batch_stats() learner_w_grad.train() learner_wo_grad.train() cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args) # Train meta-learner with validation loss learner_wo_grad.transfer_params(learner_w_grad, cI) output = learner_wo_grad(test_input) loss = learner_wo_grad.criterion(output, test_target) acc = accuracy(output, test_target) optim.zero_grad() loss.backward() nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip) optim.step() if ((eps + 1) % 250 == 0 or eps == 0): print(eps + 1, "/", args.episode, " Loss: ", loss.item(), " Acc:", acc) #logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train') # Meta-validation if ((eps + 1) % args.val_freq == 0 and eps != 0) or eps + 1 == args.episode: #save_ckpt(eps, metalearner, optim, args.save) acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args) print("Meta validation: ", eps + 1, " Acc: ", acc) if acc > best_acc: best_acc = acc print(" New best: ", acc) # logger.loginfo("* Best accuracy so far *\n") end_time = datetime.now() print("Time to execute: ", end_time - start_time) print("Average per iteration", (end_time - start_time) / args.episode) torch.cuda.empty_cache() #acc = meta_test(eps, val_loader, learner_w_grad, 
learner_wo_grad, metalearner, args) print("Training complete, best acc: ", best_acc)
def main(): parser = argparse.ArgumentParser() parser.add_argument('--env-type', default='ant_dir_rl2') args, rest_args = parser.parse_known_args() env = args.env_type # --- GridWorld --- if env == 'gridworld_oracle': args = args_grid_oracle.get_args(rest_args) elif env == 'gridworld_belief_oracle': args = args_grid_belief_oracle.get_args(rest_args) elif env == 'gridworld_varibad': args = args_grid_varibad.get_args(rest_args) elif env == 'gridworld_rl2': args = args_grid_rl2.get_args(rest_args) # --- MUJOCO --- # - AntDir - elif env == 'ant_dir_oracle': args = args_ant_dir_oracle.get_args(rest_args) elif env == 'ant_dir_rl2': args = args_ant_dir_rl2.get_args(rest_args) elif env == 'ant_dir_varibad': args = args_ant_dir_varibad.get_args(rest_args) # # - AntGoal - elif env == 'ant_goal_oracle': args = args_ant_goal_oracle.get_args(rest_args) elif env == 'ant_goal_varibad': args = args_ant_goal_varibad.get_args(rest_args) elif env == 'ant_goal_rl2': args = args_ant_goal_rl2.get_args(rest_args) # # - CheetahDir - elif env == 'cheetah_dir_oracle': args = args_cheetah_dir_oracle.get_args(rest_args) elif env == 'cheetah_dir_rl2': args = args_cheetah_dir_rl2.get_args(rest_args) elif env == 'cheetah_dir_varibad': args = args_cheetah_dir_varibad.get_args(rest_args) # # - CheetahVel - elif env == 'cheetah_vel_oracle': args = args_cheetah_vel_oracle.get_args(rest_args) elif env == 'cheetah_vel_rl2': args = args_cheetah_vel_rl2.get_args(rest_args) elif env == 'cheetah_vel_varibad': args = args_cheetah_vel_varibad.get_args(rest_args) elif env == 'cheetah_vel_avg': args = args_cheetah_vel_avg.get_args(rest_args) # # - Walker - elif env == 'walker_oracle': args = args_walker_oracle.get_args(rest_args) elif env == 'walker_avg': args = args_walker_avg.get_args(rest_args) elif env == 'walker_rl2': args = args_walker_rl2.get_args(rest_args) elif env == 'walker_varibad': args = args_walker_varibad.get_args(rest_args) # warning for deterministic execution if args.deterministic_execution: print('Envoking deterministic code execution.') if torch.backends.cudnn.enabled: warnings.warn('Running with deterministic CUDNN.') if args.num_processes > 1: raise RuntimeError('If you want fully deterministic code, use num_processes 1.' 'Warning: This will slow things down and might break A2C if ' 'policy_num_steps < env._max_episode_steps.') # clean up arguments if hasattr(args, 'disable_decoder') and args.disable_decoder: args.decode_reward = False args.decode_state = False args.decode_task = False if hasattr(args, 'decode_only_past') and args.decode_only_past: args.split_batches_by_elbo = True # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes: # args.split_batches_by_elbo = True # begin training (loop through all passed seeds) seed_list = [args.seed] if isinstance(args.seed, int) else args.seed for seed in seed_list: print('training', seed) args.seed = seed if args.disable_metalearner: # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`. # This is a stripped down version without encoder, decoder, stochastic latent variables, etc. learner = Learner(args) else: learner = MetaLearner(args) learner.train()
def main(): args, unparsed = FLAGS.parse_known_args() args = brandos_load(args) if len(unparsed) != 0: raise NameError("Argument {} not recognized".format(unparsed)) if args.seed is None: args.seed = random.randint(0, 1e3) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) #args.dev = torch.device('cpu') if args.cpu: args.dev = torch.device('cpu') args.gpu_name = args.dev else: if not torch.cuda.is_available(): raise RuntimeError("GPU unavailable.") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False args.dev = torch.device('cuda') try: args.gpu_name = torch.cuda.get_device_name(0) except: args.gpu_name = args.dev print(f'device {args.dev}') logger = GOATLogger(args) # Get data train_loader, val_loader, test_loader = prepare_data(args) # Set up learner, meta-learner learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev) learner_wo_grad = copy.deepcopy(learner_w_grad) metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to( args.dev) metalearner.metalstm.init_cI(learner_w_grad.get_flat_params()) # Set up loss, optimizer, learning rate scheduler optim = torch.optim.Adam(metalearner.parameters(), args.lr) if args.resume: logger.loginfo("Initialized from: {}".format(args.resume)) last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev) if args.mode == 'test': _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) return best_acc = 0.0 logger.loginfo("---> Start training") # Meta-training for eps, (episode_x, episode_y) in enumerate( train_loader ): # sample data set split episode_x = D = (D^{train},D^{test}) print(f'episode = {eps}') #print(f'episode_y = {episode_y}') # print(f'episide_x.size() = {episode_x.size()}') # episide_x.size() = torch.Size([5, 20, 3, 84, 84]) i.e. 
N classes for K shot task with K_eval query examples # print(f'episode_x.mean() = {episode_x.mean()}') # episode_x.shape = [n_class, n_shot + n_eval, c, h, w] # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED train_input = episode_x[:, :args.n_shot].reshape( -1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :] train_target = torch.LongTensor( np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot] test_input = episode_x[:, args.n_shot:].reshape( -1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :] test_target = torch.LongTensor( np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval] # Train learner with metalearner learner_w_grad.reset_batch_stats() learner_wo_grad.reset_batch_stats() learner_w_grad.train() learner_wo_grad.train() cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args) # Train meta-learner with validation loss learner_wo_grad.transfer_params(learner_w_grad, cI) output = learner_wo_grad(test_input) loss = learner_wo_grad.criterion(output, test_target) acc = accuracy(output, test_target) optim.zero_grad() loss.backward() nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip) optim.step() logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train') # Meta-validation if eps % args.val_freq == 0 and eps != 0: save_ckpt(eps, metalearner, optim, args.save) acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) if acc > best_acc: best_acc = acc logger.loginfo(f"* Best accuracy so far {acc}*\n") logger.loginfo(f'acc: {acc}') logger.loginfo(f"* Best accuracy so far {acc}*\n") logger.loginfo("Done")
def main(args): print('starting....') utils.set_seed(args.seed, cudnn=args.make_deterministic) continuous_actions = (args.env_name in ['AntVel-v1', 'AntDir-v1', 'AntPos-v0', 'HalfCheetahVel-v1', 'HalfCheetahDir-v1', '2DNavigation-v0']) # subfolders for logging method_used = 'maml' if args.maml else 'cavia' num_context_params = str(args.num_context_params) + '_' if not args.maml else '' output_name = num_context_params + 'lr=' + str(args.fast_lr) + 'tau=' + str(args.tau) output_name += '_' + datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S') dir_path = os.path.dirname(os.path.realpath(__file__)) log_folder = os.path.join(os.path.join(dir_path, 'logs'), args.env_name, method_used, output_name) save_folder = os.path.join(os.path.join(dir_path, 'saves'), output_name) if not os.path.exists(save_folder): os.makedirs(save_folder) if not os.path.exists(log_folder): os.makedirs(log_folder) # initialise tensorboard writer writer = SummaryWriter(log_folder) # save config file with open(os.path.join(save_folder, 'config.json'), 'w') as f: config = {k: v for (k, v) in vars(args).items() if k != 'device'} config.update(device=args.device.type) json.dump(config, f, indent=2) with open(os.path.join(log_folder, 'config.json'), 'w') as f: config = {k: v for (k, v) in vars(args).items() if k != 'device'} config.update(device=args.device.type) json.dump(config, f, indent=2) sampler = BatchSampler(args.env_name, batch_size=args.fast_batch_size, num_workers=args.num_workers, device=args.device, seed=args.seed) if continuous_actions: if not args.maml: policy = CaviaMLPPolicy( int(np.prod(sampler.envs.observation_space.shape)), int(np.prod(sampler.envs.action_space.shape)), hidden_sizes=(args.hidden_size,) * args.num_layers, num_context_params=args.num_context_params, device=args.device ) else: policy = NormalMLPPolicy( int(np.prod(sampler.envs.observation_space.shape)), int(np.prod(sampler.envs.action_space.shape)), hidden_sizes=(args.hidden_size,) * args.num_layers ) else: if not args.maml: raise NotImplementedError else: policy = CategoricalMLPPolicy( int(np.prod(sampler.envs.observation_space.shape)), sampler.envs.action_space.n, hidden_sizes=(args.hidden_size,) * args.num_layers) # initialise baseline baseline = LinearFeatureBaseline(int(np.prod(sampler.envs.observation_space.shape))) # initialise meta-learner metalearner = MetaLearner(sampler, policy, baseline, gamma=args.gamma, fast_lr=args.fast_lr, tau=args.tau, device=args.device) for batch in range(args.num_batches): # get a batch of tasks tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size) # do the inner-loop update for each task # this returns training (before update) and validation (after update) episodes episodes, inner_losses = metalearner.sample(tasks, first_order=args.first_order) # take the meta-gradient step outer_loss = metalearner.step(episodes, max_kl=args.max_kl, cg_iters=args.cg_iters, cg_damping=args.cg_damping, ls_max_steps=args.ls_max_steps, ls_backtrack_ratio=args.ls_backtrack_ratio) # -- logging curr_returns = total_rewards(episodes, interval=True) print(' return after update: ', curr_returns[0][1]) # Tensorboard writer.add_scalar('policy/actions_train', episodes[0][0].actions.mean(), batch) writer.add_scalar('policy/actions_test', episodes[0][1].actions.mean(), batch) writer.add_scalar('running_returns/before_update', curr_returns[0][0], batch) writer.add_scalar('running_returns/after_update', curr_returns[0][1], batch) writer.add_scalar('running_cfis/before_update', curr_returns[1][0], batch) 
writer.add_scalar('running_cfis/after_update', curr_returns[1][1], batch) writer.add_scalar('loss/inner_rl', np.mean(inner_losses), batch) writer.add_scalar('loss/outer_rl', outer_loss.item(), batch) # -- evaluation # evaluate for multiple update steps if batch % args.test_freq == 0: test_tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size) test_episodes = metalearner.test(test_tasks, num_steps=args.num_test_steps, batch_size=args.test_batch_size, halve_lr=args.halve_test_lr) all_returns = total_rewards(test_episodes, interval=True) for num in range(args.num_test_steps + 1): writer.add_scalar('evaluation_rew/avg_rew ' + str(num), all_returns[0][num], batch) writer.add_scalar('evaluation_cfi/avg_rew ' + str(num), all_returns[1][num], batch) print(' inner RL loss:', np.mean(inner_losses)) print(' outer RL loss:', outer_loss.item()) # -- save policy network with open(os.path.join(save_folder, 'policy-{0}.pt'.format(batch)), 'wb') as f: torch.save(policy.state_dict(), f)
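# Each meta-batch above ends with `torch.save(policy.state_dict(), ...)`. A short sketch
# of restoring one of those checkpoints later for evaluation (the batch index used for the
# filename is illustrative):
with open(os.path.join(save_folder, 'policy-{0}.pt'.format(args.num_batches - 1)), 'rb') as f:
    policy.load_state_dict(torch.load(f))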
def metalight_train(dic_exp_conf, dic_agent_conf, _dic_traffic_env_conf, _dic_path, tasks, batch_id): ''' metalight meta-train function Arguments: dic_exp_conf: dict, configuration of this experiment dic_agent_conf: dict, configuration of agent _dic_traffic_env_conf: dict, configuration of traffic environment _dic_path: dict, path of source files and output files tasks: list, traffic files name in this round batch_id: int, round number ''' tot_path = [] tot_traffic_env = [] for task in tasks: dic_traffic_env_conf = copy.deepcopy(_dic_traffic_env_conf) dic_path = copy.deepcopy(_dic_path) dic_path.update({ "PATH_TO_DATA": os.path.join(dic_path['PATH_TO_DATA'], task.split(".")[0]) }) # parse roadnet dic_traffic_env_conf["ROADNET_FILE"] = dic_traffic_env_conf[ "traffic_category"]["traffic_info"][task][2] dic_traffic_env_conf["FLOW_FILE"] = dic_traffic_env_conf[ "traffic_category"]["traffic_info"][task][3] roadnet_path = os.path.join( dic_path['PATH_TO_DATA'], dic_traffic_env_conf["traffic_category"] ["traffic_info"][task][2]) # dic_traffic_env_conf['ROADNET_FILE']) lane_phase_info = parse_roadnet(roadnet_path) dic_traffic_env_conf["LANE_PHASE_INFO"] = lane_phase_info[ "intersection_1_1"] dic_traffic_env_conf["num_lanes"] = int( len(lane_phase_info["intersection_1_1"]["start_lane"]) / 4) # num_lanes per direction dic_traffic_env_conf["num_phases"] = len( lane_phase_info["intersection_1_1"]["phase"]) dic_traffic_env_conf["TRAFFIC_FILE"] = task tot_path.append(dic_path) tot_traffic_env.append(dic_traffic_env_conf) sampler = BatchSampler(dic_exp_conf=dic_exp_conf, dic_agent_conf=dic_agent_conf, dic_traffic_env_conf=tot_traffic_env, dic_path=tot_path, batch_size=args.fast_batch_size, num_workers=args.num_workers) policy = config.DIC_AGENTS[args.algorithm]( dic_agent_conf=dic_agent_conf, dic_traffic_env_conf=tot_traffic_env, dic_path=tot_path) metalearner = MetaLearner(sampler, policy, dic_agent_conf=dic_agent_conf, dic_traffic_env_conf=tot_traffic_env, dic_path=tot_path) if batch_id == 0: params = pickle.load( open(os.path.join(dic_path['PATH_TO_MODEL'], 'params_init.pkl'), 'rb')) params = [params] * len(policy.policy_inter) metalearner.meta_params = params metalearner.meta_target_params = params else: params = pickle.load( open( os.path.join(dic_path['PATH_TO_MODEL'], 'params_%d.pkl' % (batch_id - 1)), 'rb')) params = [params] * len(policy.policy_inter) metalearner.meta_params = params period = dic_agent_conf['PERIOD'] target_id = int((batch_id - 1) / period) meta_params = pickle.load( open( os.path.join(dic_path['PATH_TO_MODEL'], 'params_%d.pkl' % (target_id * period)), 'rb')) meta_params = [meta_params] * len(policy.policy_inter) metalearner.meta_target_params = meta_params metalearner.sample_metalight(tasks, batch_id)
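# Worked example of the warm-start logic above: with PERIOD = 5 and batch_id = 7, the
# current meta params are restored from 'params_6.pkl' (batch_id - 1), while
# target_id = int((7 - 1) / 5) = 1, so the meta *target* params come from
# 'params_5.pkl' (target_id * PERIOD).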
def main(args): dataset_name = args.dataset model_name = args.model n_inner_iter = args.adaptation_steps batch_size = args.batch_size save_model_file = args.save_model_file load_model_file = args.load_model_file lower_trial = args.lower_trial upper_trial = args.upper_trial is_test = args.is_test stopping_patience = args.stopping_patience epochs = args.epochs fast_lr = args.learning_rate slow_lr = args.meta_learning_rate first_order = False inner_loop_grad_clip = 20 task_size = 50 output_dim = 1 horizon = 10 ##test meta_info = { "POLLUTION": [5, 50, 14], "HR": [32, 50, 13], "BATTERY": [20, 50, 3] } assert model_name in ("FCN", "LSTM"), "Model was not correctly specified" assert dataset_name in ("POLLUTION", "HR", "BATTERY") window_size, task_size, input_dim = meta_info[dataset_name] output_directory = "output/" train_data_ML = pickle.load( open( "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) validation_data_ML = pickle.load( open( "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) test_data_ML = pickle.load( open( "../../Data/TEST-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb")) for trial in range(lower_trial, upper_trial): output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str( trial) + "/" save_model_file_ = output_directory + save_model_file load_model_file_ = output_directory + load_model_file try: os.mkdir(output_directory) except OSError as error: print(error) with open(output_directory + "/results2.txt", "a+") as f: f.write("Learning rate :%f \n" % fast_lr) f.write("Meta-learning rate: %f \n" % slow_lr) f.write("Adaptation steps: %f \n" % n_inner_iter) f.write("\n") if model_name == "LSTM": model = LSTMModel(batch_size=batch_size, seq_len=window_size, input_dim=input_dim, n_layers=2, hidden_dim=120, output_dim=output_dim) optimizer = torch.optim.Adam(model.parameters(), lr=slow_lr) loss_func = mae torch.backends.cudnn.enabled = False device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") meta_learner = MetaLearner(model, optimizer, fast_lr, loss_func, first_order, n_inner_iter, inner_loop_grad_clip, device) total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape early_stopping = EarlyStopping(patience=stopping_patience, model_file=save_model_file_, verbose=True) for _ in range(epochs): #train batch_idx = np.random.randint(0, total_tasks - 1, batch_size) x_spt, y_spt = train_data_ML[batch_idx] x_qry, y_qry = train_data_ML[batch_idx + 1] x_spt, y_spt = to_torch(x_spt), to_torch(y_spt) x_qry = to_torch(x_qry) y_qry = to_torch(y_qry) train_tasks = [ Task(x_spt[i], y_spt[i]) for i in range(x_spt.shape[0]) ] val_tasks = [ Task(x_qry[i], y_qry[i]) for i in range(x_qry.shape[0]) ] adapted_params = meta_learner.adapt(train_tasks) mean_loss = meta_learner.step(adapted_params, val_tasks, is_training=True) print(mean_loss) #test val_error = test(validation_data_ML, meta_learner, device) print(val_error) early_stopping(val_error, meta_learner) if early_stopping.early_stop: print("Early stopping") break model.load_state_dict(torch.load(save_model_file_)["model_state_dict"]) meta_learner = MetaLearner(model, optimizer, fast_lr, loss_func, first_order, n_inner_iter, inner_loop_grad_clip, device) validation_error = test(validation_data_ML, meta_learner, device) test_error = test(test_data_ML, meta_learner, device) with open(output_directory + "/results2.txt", "a+") as f: f.write("Test 
error: %f \n" % test_error) f.write("Validation error: %f \n" % validation_error) print(test_error) print(validation_error)