def extract_features(exp_dir, trial_id, test_name, save_name):
    """Extract latent features from a trained model and save them to disk.

    Loads the trial config and best checkpoint from ``exp_dir/trial_id``,
    runs every batch of the (eval-mode) dataset through the model's mean
    encoder, and saves the concatenated encodings as an ``.npz`` file in
    the trial directory.

    Args:
        exp_dir: Root experiment directory containing trial folders and a
            ``configs`` subfolder.
        trial_id: Trial subfolder name (must contain ``summary.json``).
        test_name: Optional test-set name; overrides ``data_config`` if given.
        save_name: Filename (without path) for the saved feature array.
    """
    print('#################### Feature Extraction {} ####################'.format(trial_id))

    # Get trial folder
    trial_dir = os.path.join(exp_dir, trial_id)
    assert os.path.isfile(os.path.join(trial_dir, 'summary.json'))

    # Load config
    with open(os.path.join(exp_dir, 'configs', '{}.json'.format(trial_id)), 'r') as f:
        config = json.load(f)
    data_config = config['data_config']
    model_config = config['model_config']

    # No need to load training data for feature extraction.
    data_config["label_train_set"] = False

    # Load dataset
    if test_name is not None:
        data_config["test_name"] = test_name
    dataset = load_dataset(data_config)
    dataset.eval()

    # Load best model (map_location forces tensors onto CPU)
    state_dict = torch.load(os.path.join(trial_dir, 'best.pth'),
                            map_location=lambda storage, loc: storage)
    model_class = get_model_class(model_config['name'].lower())
    model_config['label_functions'] = dataset.active_label_functions
    model_config['augmentations'] = dataset.active_augmentations
    model = model_class(model_config)
    model.load_state_dict(state_dict)
    # NOTE(review): model.eval() is never called here; if the model contains
    # dropout/batch-norm layers, features are extracted in train mode —
    # confirm this is intended.

    num_samples = 128
    loader = DataLoader(dataset, batch_size=num_samples, shuffle=False)

    # Collect per-batch encodings and concatenate once at the end.
    # (The original pre-loaded a throwaway batch before the loop and grew the
    # output with repeated np.concatenate, which is O(n^2) in total size.)
    encodings = []
    for batch_idx, (states, actions, labels_dict) in enumerate(loader):
        # Model expects time-major tensors: (seq_len, batch, dim).
        states = states.transpose(0, 1)
        actions = actions.transpose(0, 1)

        with torch.no_grad():
            if len(dataset.active_label_functions) > 0:
                # Concatenate all label-function outputs into one label vector.
                label_list = [labels_dict[lf_name] for lf_name in labels_dict]
                label_input = torch.cat(label_list, -1)
                encodings_mean, _ = model.encode_mean(states[:-1], actions, labels=label_input)
            else:
                encodings_mean, _ = model.encode_mean(states[:-1], actions)
        encodings.append(encodings_mean)

    save_array = np.concatenate(encodings, axis=0) if encodings else np.array([])

    np.savez(os.path.join(trial_dir, save_name), save_array)
    print("Saved Features: " + os.path.join(trial_dir, save_name))
def visualize_samples_ctvae(exp_dir, trial_id, num_samples, num_values, repeat_index, burn_in, temperature):
    """Generate and save CTVAE rollouts conditioned on the 'copulation' label.

    Loads the trial's best model, rolls out the learned policy in the
    environment, and saves the resulting trajectories via ``dataset.save``.
    When ``repeat_index`` >= 0, a single trajectory's burn-in is repeated
    ``num_samples`` times so all rollouts share the same initial conditions.

    Args:
        exp_dir: Root experiment directory.
        trial_id: Trial subfolder name (must contain ``summary.json``).
        num_samples: Batch size / number of rollouts to generate.
        num_values: Unused here; kept for signature parity with the other
            visualization entry points in this file.
        repeat_index: If >= 0, index of the sample whose burn-in is repeated.
        burn_in: Number of ground-truth steps before the policy takes over.
        temperature: Sampling temperature for the policy.
    """
    print('#################### Trial {} ####################'.format(trial_id))

    # Get trial folder
    trial_dir = os.path.join(exp_dir, trial_id)
    assert os.path.isfile(os.path.join(trial_dir, 'summary.json'))

    # Load config
    with open(os.path.join(exp_dir, 'configs', '{}.json'.format(trial_id)), 'r') as f:
        config = json.load(f)
    data_config = config['data_config']
    model_config = config['model_config']

    # Load dataset
    dataset = load_dataset(data_config)
    dataset.eval()
    print(type(dataset))

    # Load best model (map_location forces tensors onto CPU)
    state_dict = torch.load(os.path.join(trial_dir, 'best.pth'),
                            map_location=lambda storage, loc: storage)
    model_class = get_model_class(model_config['name'].lower())
    assert model_class.requires_labels
    model_config['label_functions'] = dataset.active_label_functions
    model = model_class(model_config)
    model.filter_and_load_state_dict(state_dict)

    # Load environment
    env = load_environment(data_config['name'])  # TODO make env_config?

    # TODO for now, assume just one active label function
    # assert len(dataset.active_label_functions) == 1
    # for lf in dataset.active_label_functions:
    loader = DataLoader(dataset, batch_size=num_samples, shuffle=False)
    (states, actions, labels_dict) = next(iter(loader))

    if repeat_index >= 0:
        # Repeat one trajectory so every rollout shares the same burn-in.
        states_single = states[repeat_index].unsqueeze(0)
        states = states_single.repeat(num_samples, 1, 1)
        actions_single = actions[repeat_index].unsqueeze(0)
        actions = actions_single.repeat(num_samples, 1, 1)

    # Time-major: (seq_len, batch, dim)
    states = states.transpose(0, 1)
    actions = actions.transpose(0, 1)

    y = labels_dict["copulation"]

    with torch.no_grad():
        env.reset(init_state=states[0].clone())
        # BUGFIX: use the temperature/burn_in parameters instead of the
        # module-global `args`, which silently ignored the caller's values.
        model.reset_policy(labels=y, temperature=temperature)

        rollout_states, rollout_actions = generate_rollout(
            env, model, burn_in=burn_in,
            burn_in_actions=actions, horizon=actions.size(0))
        # Back to batch-major for saving.
        rollout_states = rollout_states.transpose(0, 1)
        rollout_actions = rollout_actions.transpose(0, 1)

        dataset.save(
            rollout_states,
            rollout_actions,
            labels=y,
            lf_list=dataset.active_label_functions,
            burn_in=burn_in,
            save_path=os.path.join(trial_dir, 'results', "copulating"),
            save_name='repeat_{:03d}_{}'.format(repeat_index, "copulating") if repeat_index >= 0 else '',
            single_plot=(repeat_index >= 0))
def start_training(save_path, data_config, model_config, train_config, device, test_code=False):
    """Train a model per the given configs, saving checkpoints and best/final models.

    Runs one training stage per entry in ``train_config['num_epochs']``.
    Within each stage, the model with the lowest total test loss is saved as
    ``best.pth``; periodic checkpoints and the final model are also written
    under ``save_path``.

    Args:
        save_path: Directory for checkpoints, best, and final model files.
        data_config: Dataset configuration dict.
        model_config: Model configuration dict. State/action/label dims and
            label functions/augmentations are added during training; the
            latter two are removed again before returning.
        train_config: Training configuration dict; ``'seed'`` is added if absent.
        device: Torch device string, e.g. ``'cpu'`` or ``'cuda:0'``.
        test_code: If True, epochs break early (smoke-test mode).

    Returns:
        Tuple ``(summary, logger, data_config, model_config, train_config)``.
    """
    summary = {'training': []}
    logger = []

    # Sample and fix a random seed if not set in train_config
    if 'seed' not in train_config:
        train_config['seed'] = random.randint(0, 9999)
    seed = train_config['seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

    # Initialize dataset
    dataset = load_dataset(data_config)
    summary['dataset'] = dataset.summary

    # Add state and action dims to model config
    model_config['state_dim'] = dataset.state_dim
    model_config['action_dim'] = dataset.action_dim

    # Get model class
    model_class = get_model_class(model_config['name'].lower())

    # Check if model needs labels as input
    # if model_class.requires_labels:
    model_config['label_dim'] = dataset.label_dim
    model_config['label_functions'] = dataset.active_label_functions

    # if model_class.requires_augmentations:
    model_config['augmentations'] = dataset.active_augmentations

    # Initialize model
    model = model_class(model_config).to(device)
    summary['model'] = model_config
    summary['model']['num_parameters'] = model.num_parameters

    # Initialize dataloaders.
    # BUGFIX: the original used `device is not 'cpu'` (identity comparison
    # with a string literal, not equality) and passed
    # `np.random.seed(seed)` — i.e. None, the result of calling it — as
    # worker_init_fn instead of a callable that seeds each worker.
    if device != 'cpu':
        kwargs = {'num_workers': 8, 'pin_memory': False,
                  'worker_init_fn': lambda worker_id: np.random.seed(seed)}
    else:
        kwargs = {}
    data_loader = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=True, **kwargs)

    # Initialize with pretrained model (if specified)
    if 'pretrained_model' in train_config:
        print('LOADING pretrained model: {}'.format(train_config['pretrained_model']))
        # model_path = os.path.join(os.path.dirname(save_path), train_config['pretrained_model'])
        model_path = os.path.join(os.path.dirname(os.path.dirname(save_path)),
                                  train_config['pretrained_model'])
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict)
        # Copy over as the initial "best" model so later stages can load it.
        torch.save(model.state_dict(), os.path.join(save_path, 'best.pth'))

    # Start training
    if isinstance(train_config['num_epochs'], int):
        # Normalize to a list: one entry per training stage.
        train_config['num_epochs'] = [train_config['num_epochs']]

    start_time = time.time()
    epochs_done = 0

    for num_epochs in train_config['num_epochs']:
        model.prepare_stage(train_config)
        stage_start_time = time.time()
        print('##### STAGE {} #####'.format(model.stage))

        best_test_log = {}
        best_test_log_times = []

        for epoch in range(num_epochs):
            epochs_done += 1
            print('--- EPOCH [{}/{}] ---'.format(epochs_done, sum(train_config['num_epochs'])))
            epoch_start_time = time.time()

            train_log = run_epoch(data_loader, model, device, train=True, early_break=test_code)
            # NOTE(review): test_log is computed on the same data_loader used
            # for training (no held-out loader is visible here) — confirm
            # this is intended.
            test_log = run_epoch(data_loader, model, device, train=False, early_break=test_code)

            epoch_time = time.time() - epoch_start_time
            print('{:.3f} seconds'.format(epoch_time))

            logger.append({
                'epoch': epochs_done,
                'stage': model.stage,
                'train': train_log,
                'test': test_log,
                'time': epoch_time
            })

            # Save model checkpoints
            if epochs_done % train_config['checkpoint_freq'] == 0:
                torch.save(model.state_dict(),
                           os.path.join(save_path, 'checkpoints',
                                        'checkpoint_{}.pth'.format(epochs_done)))
                print('Checkpoint saved')

            # Save model with best (lowest total) test loss during stage
            if epoch == 0 or sum(test_log['losses'].values()) < sum(best_test_log['losses'].values()):
                best_test_log = test_log
                best_test_log_times.append(epochs_done)
                torch.save(model.state_dict(), os.path.join(save_path, 'best.pth'))
                print('Best model saved')

        # Save training statistics by stage
        summary['training'].append({
            'stage': model.stage,
            'num_epochs': num_epochs,
            'stage_time': round(time.time() - stage_start_time, 3),
            'best_test_log_times': best_test_log_times,
            'best_test_log': best_test_log
        })

        # Load best model for next stage
        if model.stage < len(train_config['num_epochs']):
            best_state = torch.load(os.path.join(save_path, 'best.pth'))
            model.load_state_dict(best_state)
            torch.save(model.state_dict(),
                       os.path.join(save_path, 'best_stage_{}.pth'.format(model.stage)))

    # Save final model
    torch.save(model.state_dict(), os.path.join(save_path, 'final.pth'))
    print('Final model saved')

    # Save total time
    summary['total_time'] = round(time.time() - start_time, 3)

    # Remove the dataset objects added above so the returned model_config
    # matches the plain-config shape it came in with.
    model_config.pop('label_functions')
    model_config.pop('augmentations')

    return summary, logger, data_config, model_config, train_config
def visualize_samples_ctvae(exp_dir, trial_id, num_samples, num_values, repeat_index, burn_in, temperature, bad_experiment=True):
    """Roll out a CTVAE policy and print the mean classification loss.

    Generates rollouts conditioned on the 'copulation' label over three
    repetitions. When ``bad_experiment`` is True, pairs of single-agent
    rollouts are stacked along the feature dimension (and two such pairs
    concatenated along time) to form pseudo multi-agent trajectories
    before scoring with ``get_classification_loss``.

    NOTE(review): this shadows an earlier function of the same name in this
    file — only the last definition is callable at import time.

    Args:
        exp_dir: Root experiment directory.
        trial_id: Trial subfolder name (must contain ``summary.json``).
        num_samples: Batch size / number of rollouts per repetition.
        num_values: Unused here; kept for signature parity.
        repeat_index: If >= 0, repeat this sample's burn-in num_samples times.
        burn_in: Number of ground-truth steps before the policy takes over.
        temperature: Sampling temperature for the policy.
        bad_experiment: If True, score stacked single-agent rollouts instead
            of the raw rollout.
    """
    print('#################### Trial {} ####################'.format(trial_id))

    # Get trial folder
    trial_dir = os.path.join(exp_dir, trial_id)
    assert os.path.isfile(os.path.join(trial_dir, 'summary.json'))

    # Load config
    with open(os.path.join(exp_dir, 'configs', '{}.json'.format(trial_id)), 'r') as f:
        config = json.load(f)
    data_config = config['data_config']
    model_config = config['model_config']

    # Load dataset
    dataset = load_dataset(data_config)
    dataset.eval()

    # Load best model (map_location forces tensors onto CPU)
    state_dict = torch.load(os.path.join(trial_dir, 'best.pth'),
                            map_location=lambda storage, loc: storage)
    model_class = get_model_class(model_config['name'].lower())
    assert model_class.requires_labels
    model_config['label_functions'] = dataset.active_label_functions
    model = model_class(model_config)
    model.filter_and_load_state_dict(state_dict)

    # Load environment
    env = load_environment(data_config['name'])  # TODO make env_config?

    loader = DataLoader(dataset, batch_size=num_samples, shuffle=False)
    (states, actions, labels_dict) = next(iter(loader))

    if repeat_index >= 0:
        # Repeat one trajectory so every rollout shares the same burn-in.
        states_single = states[repeat_index].unsqueeze(0)
        states = states_single.repeat(num_samples, 1, 1)
        actions_single = actions[repeat_index].unsqueeze(0)
        actions = actions_single.repeat(num_samples, 1, 1)

    # Time-major: (seq_len, batch, dim)
    states = states.transpose(0, 1)
    actions = actions.transpose(0, 1)

    losses = []
    y = labels_dict["copulation"]

    # BUGFIX throughout: use the temperature/burn_in parameters instead of
    # the module-global `args`, which silently ignored the caller's values.
    with torch.no_grad():
        for k in range(3):
            env.reset(init_state=states[0].clone())
            model.reset_policy(labels=y, temperature=temperature)

            rollout_states, rollout_actions = generate_rollout(
                env, model, model2=model, burn_in=burn_in,
                burn_in_actions=actions, horizon=actions.size(0))
            rollout_states = rollout_states.transpose(0, 1)
            rollout_actions = rollout_actions.transpose(0, 1)

            # if we have a single agent setting, we generate two rollouts and
            # vert stack them
            if bad_experiment:
                rollout_states_2, rollout_actions_2 = generate_rollout(
                    env, model, burn_in=burn_in,
                    burn_in_actions=actions, horizon=actions.size(0))
                rollout_states_2 = rollout_states_2.transpose(0, 1)
                rollout_actions_2 = rollout_actions_2.transpose(0, 1)

                # Stack the two single-agent rollouts along the feature dim.
                stack_tensor_states = torch.cat((rollout_states, rollout_states_2), dim=2)
                stack_tensor_action = torch.cat((rollout_actions, rollout_actions_2), dim=2)

                rollout_states_3, rollout_actions_3 = generate_rollout(
                    env, model, burn_in=burn_in,
                    burn_in_actions=actions, horizon=actions.size(0))
                rollout_states_3 = rollout_states_3.transpose(0, 1)
                rollout_actions_3 = rollout_actions_3.transpose(0, 1)

                rollout_states_4, rollout_actions_4 = generate_rollout(
                    env, model, burn_in=burn_in,
                    burn_in_actions=actions, horizon=actions.size(0))
                rollout_states_4 = rollout_states_4.transpose(0, 1)
                rollout_actions_4 = rollout_actions_4.transpose(0, 1)

                stack_tensor_states_2 = torch.cat((rollout_states_3, rollout_states_4), dim=2)
                stack_tensor_action_2 = torch.cat((rollout_actions_3, rollout_actions_4), dim=2)

                # Concatenate both stacked pairs along the time dimension.
                final_states_tensor = torch.cat((stack_tensor_states, stack_tensor_states_2), dim=1)
                final_actions_tensor = torch.cat((stack_tensor_action, stack_tensor_action_2), dim=1)

                losses.append(get_classification_loss(final_states_tensor, final_actions_tensor))
            else:
                losses.append(get_classification_loss(rollout_states, rollout_actions))

    print(np.mean(losses))
def compute_stylecon_ctvae(exp_dir, trial_id, args):
    """Compute style-consistency of CTVAE rollouts against their input labels.

    Loads the trial's best model, generates rollouts conditioned on (possibly
    permuted) dataset labels, re-applies each active label function to the
    rollouts, and reports per-class consistency for categorical label
    functions or L1/L2 error plus a sweep plot for continuous ones.

    Args:
        exp_dir: Root experiment directory.
        trial_id: Trial subfolder name (must contain ``summary.json``).
        args: Namespace with ``num_samples``, ``num_values``, ``burn_in``,
            ``temperature``, and ``sampling_mode`` ('indep' permutes labels
            across samples for independent sampling).
    """
    print('#################### Trial {} ####################'.format(trial_id))

    # Get trial folder
    trial_dir = os.path.join(exp_dir, trial_id)
    assert os.path.isfile(os.path.join(trial_dir, 'summary.json'))

    # Load config
    with open(os.path.join(exp_dir, 'configs', '{}.json'.format(trial_id)), 'r') as f:
        config = json.load(f)
    data_config = config['data_config']
    model_config = config['model_config']

    # Load dataset
    dataset = load_dataset(data_config)
    dataset.eval()

    # Load best model (map_location forces tensors onto CPU)
    state_dict = torch.load(os.path.join(trial_dir, 'best.pth'),
                            map_location=lambda storage, loc: storage)
    model_class = get_model_class(model_config['name'].lower())
    assert model_class.requires_labels
    model_config['label_functions'] = dataset.active_label_functions
    model = model_class(model_config)
    model.filter_and_load_state_dict(state_dict)

    # Load environment
    env = load_environment(data_config['name'])  # TODO make env_config?

    # Load batch
    loader = DataLoader(dataset, batch_size=args.num_samples, shuffle=True)
    (states, actions, labels_dict) = next(iter(loader))
    # Time-major: (seq_len, batch, dim)
    states = states.transpose(0, 1)
    actions = actions.transpose(0, 1)

    # Randomly permute labels for independent sampling
    if args.sampling_mode == 'indep':
        for lf_name, labels in labels_dict.items():
            random_idx = torch.randperm(labels.size(0))
            labels_dict[lf_name] = labels[random_idx]

    labels_concat = torch.cat(list(labels_dict.values()), dim=-1)  # MC sample of labels

    # Generate rollouts with labels
    with torch.no_grad():
        env.reset(init_state=states[0].clone())
        model.reset_policy(labels=labels_concat, temperature=args.temperature)
        rollout_states, rollout_actions = generate_rollout(
            env, model, burn_in=args.burn_in,
            burn_in_actions=actions, horizon=actions.size(0))
        rollout_states = rollout_states.transpose(0, 1)
        rollout_actions = rollout_actions.transpose(0, 1)

    # Tracks whether ALL categorical labels are self-consistent per sample.
    stylecon_by_sample = torch.ones(args.num_samples)
    categorical_lf_count = 0

    for lf in dataset.active_label_functions:
        print('--- {} ---'.format(lf.name))
        y = labels_dict[lf.name]

        # Apply labeling functions on rollouts
        rollouts_y = lf.label(rollout_states, rollout_actions, batch=True)

        if lf.categorical:
            # Compute stylecon for each label class
            matching_y = y * rollouts_y
            class_count = torch.sum(y, dim=0)
            stylecon_class_count = torch.sum(matching_y, dim=0)
            stylecon_by_class = stylecon_class_count / class_count
            stylecon_by_class = [round(i, 4) for i in stylecon_by_class.tolist()]

            # Compute stylecon for each sample
            stylecon_by_sample *= torch.sum(matching_y, dim=1)
            categorical_lf_count += 1

            print('class_sc_cnt:\t {}'.format(stylecon_class_count.int().tolist()))
            print('class_cnt:\t {}'.format(class_count.int().tolist()))
            print('class_sc:\t {}'.format(stylecon_by_class))
            print('average: {}'.format(torch.sum(stylecon_class_count) / torch.sum(class_count)))
        else:
            # Compute stylecon
            diff = rollouts_y - y
            print('L1 stylecon {}'.format(torch.mean(torch.abs(diff)).item()))
            print('L2 stylecon {}'.format(torch.mean(diff ** 2).item()))

            # Visualizing stylecon: sweep label values across the train range.
            range_lower = dataset.summary['label_functions'][lf.name]['train_dist']['min']
            range_upper = dataset.summary['label_functions'][lf.name]['train_dist']['max']
            label_values = np.linspace(range_lower, range_upper, args.num_values)

            rollouts_y_mean = np.zeros(args.num_values)
            rollouts_y_std = np.zeros(args.num_values)

            for i, val in enumerate(label_values):
                # Set labels
                # TODO this is not MC-sampling, need to do rejection sampling
                # for true computation I think
                labels_dict_copy = {key: value for key, value in labels_dict.items()}
                labels_dict_copy[lf.name] = val * torch.ones(args.num_samples, 1)
                labels_concat = torch.cat(list(labels_dict_copy.values()), dim=-1)

                # Generate samples with labels.
                # BUGFIX: the original passed an undefined name `x` here
                # (NameError); `states` is the only time-major burn-in input
                # in scope — TODO(review): confirm against model.generate's
                # signature.
                with torch.no_grad():
                    samples = model.generate(states, labels_concat,
                                             burn_in=args.burn_in,
                                             temperature=args.temperature)
                    samples = samples.transpose(0, 1)

                # Apply labeling functions on samples
                rollouts_y = lf.label(samples, batch=True)

                # Compute statistics of labels
                rollouts_y_mean[i] = torch.mean(rollouts_y).item()
                rollouts_y_std[i] = torch.std(rollouts_y).item()

            plt.plot(label_values, label_values, color='b', marker='o')
            plt.plot(label_values, rollouts_y_mean, color='r', marker='o')
            plt.fill_between(label_values,
                             rollouts_y_mean - 2 * rollouts_y_std,
                             rollouts_y_mean + 2 * rollouts_y_std,
                             color='red', alpha=0.3)
            plt.xlabel('Input Label')
            plt.ylabel('Output Label')
            plt.title('LF_{}, {} samples, 2 stds'.format(lf.name, args.num_samples))
            plt.savefig(os.path.join(trial_dir, 'results', '{}.png'.format(lf.name)))
            plt.close()

    stylecon_all_count = int(torch.sum(stylecon_by_sample))
    print('--- stylecon for {} categorical LFs: {} [{}/{}] ---'.format(
        categorical_lf_count, stylecon_all_count / args.num_samples,
        stylecon_all_count, args.num_samples))
def visualize_samples_ctvae(exp_dir, trial_id, num_samples, num_values, repeat_index, burn_in, temperature):
    """Generate and save CTVAE rollouts swept over each label function's values.

    For the single active label function, iterates over its label values
    (all classes if categorical, otherwise ``num_values`` evenly spaced
    values inside the dataset's observed range), generates conditioned
    rollouts, and saves them via ``dataset.save``.

    NOTE(review): this is the third definition of ``visualize_samples_ctvae``
    in this file; it shadows the earlier two at import time.

    Args:
        exp_dir: Root experiment directory.
        trial_id: Trial subfolder name (must contain ``summary.json``).
        num_samples: Batch size / number of rollouts per label value.
        num_values: Number of interior values swept for continuous labels.
        repeat_index: If >= 0, repeat this sample's burn-in num_samples times.
        burn_in: Number of ground-truth steps before the policy takes over.
        temperature: Sampling temperature for the policy.
    """
    print('#################### Trial {} ####################'.format(trial_id))

    # Get trial folder
    trial_dir = os.path.join(exp_dir, trial_id)
    assert os.path.isfile(os.path.join(trial_dir, 'summary.json'))

    # Load config
    with open(os.path.join(exp_dir, 'configs', '{}.json'.format(trial_id)), 'r') as f:
        config = json.load(f)
    data_config = config['data_config']
    model_config = config['model_config']

    # Load dataset
    dataset = load_dataset(data_config)
    dataset.eval()

    # Load best model (map_location forces tensors onto CPU)
    state_dict = torch.load(os.path.join(trial_dir, 'best.pth'),
                            map_location=lambda storage, loc: storage)
    model_class = get_model_class(model_config['name'].lower())
    assert model_class.requires_labels
    model_config['label_functions'] = dataset.active_label_functions
    model = model_class(model_config)
    model.filter_and_load_state_dict(state_dict)

    # Load environment
    env = load_environment(data_config['name'])  # TODO make env_config?

    # TODO for now, assume just one active label function
    assert len(dataset.active_label_functions) == 1

    for lf in dataset.active_label_functions:
        loader = DataLoader(dataset, batch_size=num_samples, shuffle=False)
        (states, actions, labels_dict) = next(iter(loader))

        if repeat_index >= 0:
            # Repeat one trajectory so every rollout shares the same burn-in.
            states_single = states[repeat_index].unsqueeze(0)
            states = states_single.repeat(num_samples, 1, 1)
            actions_single = actions[repeat_index].unsqueeze(0)
            actions = actions_single.repeat(num_samples, 1, 1)

        # Time-major: (seq_len, batch, dim)
        states = states.transpose(0, 1)
        actions = actions.transpose(0, 1)

        if lf.categorical:
            # One rollout batch per class index.
            label_values = np.arange(0, lf.output_dim)
        else:
            # num_values interior points of the observed label range
            # (endpoints are dropped after rounding).
            range_lower = torch.min(dataset.lf_labels[lf.name])
            range_upper = torch.max(dataset.lf_labels[lf.name])
            label_values = np.linspace(range_lower, range_upper, num_values + 2)
            label_values = np.around(label_values, decimals=1)
            label_values = label_values[1:-1]

        for c in label_values:
            if lf.categorical:
                # One-hot label for class c.
                y = torch.zeros(num_samples, lf.output_dim)
                y[:, c] = 1
            else:
                y = c * torch.ones(num_samples, 1)

            # Generate rollouts with labels.
            # BUGFIX: use the temperature/burn_in parameters instead of the
            # module-global `args`, which silently ignored the caller's values.
            with torch.no_grad():
                env.reset(init_state=states[0].clone())
                model.reset_policy(labels=y, temperature=temperature)

                rollout_states, rollout_actions = generate_rollout(
                    env, model, burn_in=burn_in,
                    burn_in_actions=actions, horizon=actions.size(0))
                rollout_states = rollout_states.transpose(0, 1)
                rollout_actions = rollout_actions.transpose(0, 1)

                dataset.save(
                    rollout_states,
                    rollout_actions,
                    labels=y,
                    lf_list=dataset.active_label_functions,
                    burn_in=burn_in,
                    # save_path=os.path.join(trial_dir, 'results', '{}_label_{}'.format(lf.name, c)),
                    # save_name='repeat_{:03d}'.format(repeat_index) if repeat_index >= 0 else '',
                    save_path=os.path.join(trial_dir, 'results', lf.name),
                    save_name='repeat_{:03d}_{}'.format(repeat_index, c) if repeat_index >= 0 else '',
                    single_plot=(repeat_index >= 0))