def load_model_from_file(save_folder, input_shape, device):
    """Restore a trained ContrastiveSWM model plus its training arguments.

    Args:
        save_folder: directory containing 'metadata.pkl' (a pickled dict with
            key 'args' holding the training namespace) and 'model.pt'
            (the saved state dict).
        input_shape: observation shape forwarded to the model constructor.
        device: torch device the restored model is moved to.

    Returns:
        Tuple of (model, train_args).
    """
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    # Context manager closes the metadata file promptly; the original
    # `pickle.load(open(...))` left the handle open until GC.
    with open(meta_file, 'rb') as f:
        train_args = pickle.load(f)['args']
    model = modules.ContrastiveSWM(
        embedding_dim=train_args.embedding_dim,
        hidden_dim=train_args.hidden_dim,
        action_dim=train_args.action_dim,
        input_dims=input_shape,
        num_objects=train_args.num_objects,
        sigma=train_args.sigma,
        hinge=train_args.hinge,
        ignore_action=train_args.ignore_action,
        copy_action=train_args.copy_action,
        encoder=train_args.encoder).to(device)
    # map_location lets a checkpoint saved on GPU be restored on a
    # CPU-only host (and vice versa).
    model.load_state_dict(torch.load(model_file, map_location=device))
    return model, train_args
def evaluate(args, args_eval, model_file):
    """Evaluate a trained ContrastiveSWM by ranking predicted latent states.

    For each evaluation path, the model encodes the first observation, rolls
    the transition model forward `args_eval.num_steps` times, and ranks the
    true final latent state against all others in the batch by pairwise
    distance. Reports Hits@k and MRR; optionally logs both to Sacred.

    Fix over the original: `eval_loader.__iter__().next()` is Python-2
    iterator style and raises AttributeError on Python 3 with current
    DataLoader versions — replaced with `next(iter(eval_loader))`.

    Args:
        args: training-time argument namespace (stored with the checkpoint).
        args_eval: evaluation argument namespace (num_steps, dedup, sacred...).
        model_file: path to the saved model state dict.

    Returns:
        Tuple (hits@1, mrr).
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        # Imported lazily so Sacred is only required when actually used.
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(
            MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=args_eval.num_steps)
    eval_loader = data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Get one data sample to infer the observation shape.
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    topk = [1]
    hits_at = defaultdict(int)
    num_samples = 0
    rr_sum = 0

    pred_states = []
    next_states = []
    next_ids = []

    with torch.no_grad():
        for batch_idx, data_batch in enumerate(eval_loader):
            data_batch = [[t.to(device) for t in tensor]
                          for tensor in data_batch]
            observations, actions, state_ids = data_batch

            # Skip ragged final batch so all distance matrices are square.
            if observations[0].size(0) != args.batch_size:
                continue

            obs = observations[0]
            next_obs = observations[-1]
            next_id = state_ids[-1]

            state = model.obj_encoder(model.obj_extractor(obs))
            next_state = model.obj_encoder(model.obj_extractor(next_obs))

            # Roll the transition model forward along the action path.
            pred_state = state
            for i in range(args_eval.num_steps):
                pred_state = model.forward_transition(pred_state, actions[i])

            # Lists are reset at the end of each iteration, so ranking is
            # computed per batch (they only ever hold the current batch).
            pred_states.append(pred_state.cpu())
            next_states.append(next_state.cpu())
            next_ids.append(next_id.cpu().numpy())

            pred_state_cat = torch.cat(pred_states, dim=0)
            next_state_cat = torch.cat(next_states, dim=0)
            next_ids_cat = np.concatenate(next_ids, axis=0)

            full_size = pred_state_cat.size(0)

            # Flatten object/feature dimensions.
            next_state_flat = next_state_cat.view(full_size, -1)
            pred_state_flat = pred_state_cat.view(full_size, -1)

            dist_matrix = utils.pairwise_distance_matrix(
                next_state_flat, pred_state_flat)
            # Column 0 holds each sample's distance to its own target, so
            # the "correct" answer for every row is index 0 after sorting.
            dist_matrix_diag = torch.diag(dist_matrix).unsqueeze(-1)
            dist_matrix_augmented = torch.cat(
                [dist_matrix_diag, dist_matrix], dim=1)

            # Workaround to get a stable sort in numpy: lexsort with the
            # row as primary key and the column index as tie-breaker.
            dist_np = dist_matrix_augmented.numpy()
            indices = []
            for row in dist_np:
                keys = (np.arange(len(row)), row)
                indices.append(np.lexsort(keys))
            indices = np.stack(indices, axis=0)

            if args_eval.dedup:
                # Count a "mistake" as a hit when the closest state is a
                # duplicate of the true target (same underlying state id).
                mask_mistakes = indices[:, 0] != 0
                closest_next_ids = next_ids_cat[indices[:, 0] - 1]
                if len(next_ids_cat.shape) == 2:
                    equal_mask = np.all(
                        closest_next_ids == next_ids_cat, axis=1)
                else:
                    equal_mask = closest_next_ids == next_ids_cat
                indices[:, 0][np.logical_and(equal_mask, mask_mistakes)] = 0

            indices = torch.from_numpy(indices).long()

            labels = torch.zeros(
                indices.size(0), device=indices.device,
                dtype=torch.int64).unsqueeze(-1)

            num_samples += full_size

            for k in topk:
                match = indices[:, :k] == labels
                num_matches = match.sum()
                hits_at[k] += num_matches.item()

            match = indices == labels
            _, ranks = match.max(1)
            reciprocal_ranks = torch.reciprocal(ranks.double() + 1)
            rr_sum += reciprocal_ranks.sum().item()

            # Reset accumulators for the next batch.
            pred_states = []
            next_states = []
            next_ids = []

    hits = hits_at[topk[0]] / float(num_samples)
    mrr = rr_sum / float(num_samples)

    if ex is not None:
        # ugly hack: Sacred only logs inside a running experiment, so wrap
        # the metric logging in a throwaway main and run it.
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", mrr)
        ex.run()

    print('Hits @ {}: {}'.format(topk[0], hits))
    print('MRR: {}'.format(mrr))

    return hits, mrr
# NOTE(review): fragment of a training-setup routine — the enclosing `def`
# is not visible in this chunk; `args` and `device` come from that scope.
dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
train_loader = data.DataLoader(dataset, batch_size=args.batch_size,
                               shuffle=True, num_workers=4)

# Get data sample
# NOTE(review): `.__iter__().next()` is Python-2 iterator style; on modern
# Python 3 / torch this raises AttributeError — should be
# `next(iter(train_loader))`. Left unchanged in this fragment.
obs = train_loader.__iter__().next()[0]
# Shape of a single observation, used to size the encoder/decoder.
input_shape = obs[0].size()

model = modules.ContrastiveSWM(embedding_dim=args.embedding_dim,
                               hidden_dim=args.hidden_dim,
                               action_dim=args.action_dim,
                               input_dims=input_shape,
                               num_objects=args.num_objects,
                               sigma=args.sigma,
                               hinge=args.hinge,
                               ignore_action=args.ignore_action,
                               copy_action=args.copy_action,
                               encoder=args.encoder).to(device)

model.apply(utils.weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

if args.decoder:
    if args.encoder == 'large':
        # Decoder hidden dim is scaled down relative to the encoder's.
        decoder = modules.DecoderCNNLarge(input_dim=args.embedding_dim,
                                          num_objects=args.num_objects,
                                          hidden_dim=args.hidden_dim // 16,
                                          output_size=input_shape).to(device)
def evaluate(args, args_eval, model_file):
    """Evaluate a trained ContrastiveSWM with a hits@1 path-matching metric.

    Encodes every observation along each length-10 path, rolls the
    transition model forward from the first state, and checks whether the
    predicted latent lands closest to the state at step
    `args_eval.num_steps` (optionally accepting duplicate states when
    `args_eval.dedup` is set). Optionally logs the result to Sacred.

    Fix over the original: `eval_loader.__iter__().next()` is Python-2
    iterator style and raises AttributeError on Python 3 with current
    DataLoader versions — replaced with `next(iter(eval_loader))`.

    Args:
        args: training-time argument namespace.
        args_eval: evaluation argument namespace.
        model_file: path to the saved model state dict.

    Returns:
        Tuple (hits@1, 0.) — MRR is not computed by this variant.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        # Imported lazily so Sacred is only required when actually used.
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(MongoObserver(url=constants.MONGO_URI,
                                          db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    # NOTE(review): path_length and batch_size are hard-coded here (10/100)
    # while the ragged-batch check below uses args.batch_size — if
    # args.batch_size != 100 every batch is skipped. Preserved as-is;
    # confirm intent with the caller.
    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=10)
    eval_loader = data.DataLoader(
        dataset, batch_size=100, shuffle=False, num_workers=4)

    # Get one data sample to infer the observation shape.
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention
    ).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    hits_list = []

    with torch.no_grad():
        for batch_idx, data_batch in enumerate(eval_loader):
            data_batch = [[t.to(device) for t in tensor]
                          for tensor in data_batch]
            observations, actions, state_ids = data_batch

            # Skip ragged final batch.
            if observations[0].size(0) != args.batch_size:
                continue

            # Encode every observation along the path.
            states = []
            for obs in observations:
                states.append(model.obj_encoder(model.obj_extractor(obs)))
            states = torch.stack(states, dim=0)
            state_ids = torch.stack(state_ids, dim=0)

            pred_state = states[0]
            if not args_eval.no_transition:
                for i in range(args_eval.num_steps):
                    pred_state = model.forward_transition(
                        pred_state, actions[i])

            # pred_state: [100, |O|, D]
            # states: [10, 100, |O|, D]
            # pred_state_flat: [100, X]
            # states_flat: [10, 100, X]
            pred_state_flat = pred_state.reshape(
                (pred_state.size(0), pred_state.size(1) * pred_state.size(2)))
            states_flat = states.reshape(
                (states.size(0), states.size(1),
                 states.size(2) * states.size(3)))

            # dist_matrix: [10, 100] — squared L2 distance from the
            # prediction to every state along the path.
            dist_matrix = (states_flat - pred_state_flat[None]).pow(2).sum(2)
            indices = torch.argmin(dist_matrix, dim=0)

            # A hit is when the nearest path state is the true target step.
            correct = indices == args_eval.num_steps

            # Check for duplicates: also accept a nearest state that shares
            # the underlying state id with the target.
            if args_eval.dedup:
                equal_mask = torch.all(
                    state_ids[indices, list(range(100))] ==
                    state_ids[args_eval.num_steps], dim=1)
                # Bool + bool acts as logical OR on torch tensors.
                correct = correct + equal_mask

            # hits
            hits_list.append(correct.float().mean().item())

    hits = np.mean(hits_list)

    if ex is not None:
        # ugly hack: Sacred only logs inside a running experiment, so wrap
        # the metric logging in a throwaway main and run it.
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", 0.)
        ex.run()

    print('Hits @ 1: {}'.format(hits))

    return hits, 0.
def train_and_eval(args, eval_every, use_trans_model, ft_data_loader,
                   ft_eval_data_loader):
    """Train a ContrastiveSWM and periodically fine-tune/evaluate downstream.

    Trains with the contrastive (or NT-Xent) loss, checkpoints on best
    average epoch loss, and every `eval_every` epochs clones the model,
    runs `fine_tune_and_eval_downstream` on the clone, and records the best
    fine-tuned accuracy for that epoch.

    Fixes over the original:
      * `train_loader.__iter__().next()` (Python-2 iterator style, breaks on
        Python 3 with current DataLoader) -> `next(iter(train_loader))`.
      * metadata pickle is written via a context manager (the original
        leaked the file handle).
      * removed `if args.decoder: optimizer_dec.step()` — this function
        never constructs a decoder or `optimizer_dec`, so that branch could
        only raise NameError.

    Args:
        args: argument namespace (training hyperparameters and paths).
        eval_every: run downstream fine-tuning every this many epochs.
        use_trans_model: forwarded to the downstream fine-tuning routine.
        ft_data_loader: fine-tuning training data loader.
        ft_eval_data_loader: fine-tuning evaluation data loader.

    Returns:
        List of (epoch, best_fine_tuned_accuracy) tuples.
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    now = datetime.datetime.now()
    timestamp = now.isoformat()

    if args.name == 'none':
        exp_name = timestamp
    else:
        exp_name = args.name

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    save_folder = '{}/{}/'.format(args.save_folder, exp_name)

    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    log_file = os.path.join(save_folder, 'log.txt')

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, 'a'))
    # Deliberately shadows the builtin so all progress output below also
    # lands in the experiment log file.
    print = logger.info

    with open(meta_file, "wb") as f:
        pickle.dump({'args': args}, f)

    device = torch.device('cuda' if args.cuda else 'cpu')

    print("About to get dataset")
    transform = None
    if args.data_aug:
        transform = utils.get_data_augmentation()
        dataset = utils.StateTransitionsDataAugDataset(
            hdf5_file=args.dataset, transforms=transform)
    else:
        dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
    train_loader = data.DataLoader(dataset, batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    print("Dataset loaded")

    # Get one data sample to infer the observation shape.
    obs = next(iter(train_loader))[0]
    input_shape = obs[0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        encoder=args.encoder,
        use_nt_xent_loss=args.use_nt_xent_loss,
        temperature=args.temperature,
        use_slot_attn=args.slot_attn).to(device)

    model.apply(utils.weights_init)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Train model.
    print('Starting model training...')
    step = 0
    best_loss = 1e9
    epoch_acc_list = []

    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0

        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = [tensor.to(device) for tensor in data_batch]
            optimizer.zero_grad()

            if model.use_nt_xent_loss:
                loss = model.nt_xent_loss(*data_batch)
            else:
                loss = model.contrastive_loss(*data_batch)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_batch[0]),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data_batch[0])))

            step += 1

        avg_loss = train_loss / len(train_loader.dataset)
        print('====> Epoch: {} Average loss: {:.6f}'.format(epoch, avg_loss))

        # Checkpoint on best average epoch loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_file)

        if epoch % eval_every == 0 or epoch == args.epochs:
            # Copy model for fine tuning (deepcopy does not work on the
            # model itself, so clone via a fresh instance + state dict).
            model_clone = modules.ContrastiveSWM(
                embedding_dim=args.embedding_dim,
                hidden_dim=args.hidden_dim,
                action_dim=args.action_dim,
                input_dims=input_shape,
                num_objects=args.num_objects,
                sigma=args.sigma,
                hinge=args.hinge,
                ignore_action=args.ignore_action,
                copy_action=args.copy_action,
                encoder=args.encoder,
                use_nt_xent_loss=args.use_nt_xent_loss,
                temperature=args.temperature,
                use_slot_attn=args.slot_attn).to('cpu')
            model_clone.load_state_dict(copy.deepcopy(model.state_dict()))
            model_clone.to(device)
            ft_acc_list = fine_tune_and_eval_downstream(
                model_clone, device, ft_data_loader, ft_eval_data_loader, 10,
                acc_every=6, use_trans_model=use_trans_model, epochs=60,
                learning_rate=5e-4)
            model_clone.to('cpu')
            # Get best accuracy from list and use as the evaluation result
            # for this training epoch.
            best_ft_acc = max(ft_acc_list)
            epoch_acc_list.append((epoch, best_ft_acc))
            print('[Epoch %d] test acc: %.3f' % (epoch, best_ft_acc))

    return epoch_acc_list
# NOTE(review): fragment — the enclosing `def` is outside this chunk;
# `bisim_model`, `bisim_metric`, `args`, `input_shape` and `device` come
# from that scope.
# Restore pretrained bisimulation-model weights before wiring it into the
# main model.
bisim_model.load_state_dict(torch.load(args.bisim_model_path))

model = modules.ContrastiveSWM(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    action_dim=args.action_dim,
    input_dims=input_shape,
    num_objects=args.num_objects,
    sigma=args.sigma,
    hinge=args.hinge,
    ignore_action=args.ignore_action,
    copy_action=args.copy_action,
    split_mlp=args.split_mlp,
    same_ep_neg=args.same_ep_neg,
    only_same_ep_neg=args.only_same_ep_neg,
    immovable_bit=args.immovable_bit,
    split_gnn=args.split_gnn,
    no_loss_first_two=args.no_loss_first_two,
    gamma=args.gamma,
    bisim_metric=bisim_metric,
    bisim_eps=args.bisim_eps,
    next_state_neg=args.next_state_neg,
    nl_type=args.nl_type,
    encoder=args.encoder,
    use_coord_grid=args.coord_grid,
    # Multiple negatives are enabled when more than one negative is
    # requested or negative mixing is on.
    many_negs=args.num_negs != 1 or args.mix_negs,
    detach_negs=args.detach_negs,
    mix_negs=args.mix_negs,
    use_action_attention=args.use_action_attention).to(device)

model.apply(utils.weights_init)
# NOTE(review): fragment — this chunk opens with the tail of a
# data.DataLoader(...) call whose opening line is outside this view.
shuffle=False, num_workers=4)

# Get data sample
# NOTE(review): `.__iter__().next()` is Python-2 iterator style; on modern
# Python 3 / torch this raises AttributeError — should be
# `next(iter(eval_loader))`. Left unchanged in this fragment.
obs = eval_loader.__iter__().next()[0]
input_shape = obs[0][0].size()

model = modules.ContrastiveSWM(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    action_dim=args.action_dim,
    input_dims=input_shape,
    num_objects=args.num_objects,
    sigma=args.sigma,
    hinge=args.hinge,
    ignore_action=args.ignore_action,
    copy_action=args.copy_action,
    split_mlp=args.split_mlp,
    same_ep_neg=args.same_ep_neg,
    only_same_ep_neg=args.only_same_ep_neg,
    immovable_bit=args.immovable_bit,
    split_gnn=args.split_gnn,
    no_loss_first_two=args.no_loss_first_two,
    encoder=args.encoder,
    use_action_attention=args.use_action_attention).to(device)

# Restore trained weights and switch to inference mode.
model.load_state_dict(torch.load(model_file))
model.eval()

# Accumulator for encoded states collected in the loop that follows
# (the loop body is outside this chunk).
all_states = []

with torch.no_grad():
# action_dim=args.action_dim, # input_dims=input_shape, # num_objects=args.num_objects, # sigma=args.sigma, # hinge=args.hinge, # ignore_action=args.ignore_action, # copy_action=args.copy_action, # encoder=args.encoder).to(device) model = modules.ContrastiveSWM(embedding_dim=args.embedding_dim, hidden_dim=args.hidden_dim, action_dim=args.action_dim, input_dims=input_shape, num_objects=args.num_objects, sigma=args.sigma, hinge=args.hinge, ignore_action=args.ignore_action, copy_action=args.copy_action, encoder=args.encoder, use_nt_xent_loss=args.use_nt_xent_loss, temperature=args.temperature, use_slot_attn=args.slot_attn).to(device) model.load_state_dict(torch.load(model_file)) model.eval() # topk = [1, 5, 10] topk = [1] hits_at = defaultdict(int) num_samples = 0 rr_sum = 0
def train_c_swm(args):
    """Train a ContrastiveSWM model (optionally with a pixel decoder).

    Sets up seeds, the save folder, logging, and the transitions dataset,
    then trains with either the contrastive loss or (when `args.decoder`)
    a reconstruction loss through a decoder matching the encoder size.
    The model with the best average epoch loss is checkpointed.

    Fixes over the original:
      * `train_loader.__iter__().next()` (Python-2 iterator style, breaks
        on Python 3 with current DataLoader) -> `next(iter(train_loader))`.
      * metadata pickle is written via a context manager (the original
        leaked the file handle).

    Args:
        args: argument namespace (training hyperparameters and paths).

    Returns:
        The trained model (final state, not necessarily the checkpointed
        best-loss state).
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print("Inside train_c_swm")
    now = datetime.datetime.now()
    timestamp = now.isoformat()

    if args.name == 'none':
        exp_name = timestamp
    else:
        exp_name = args.name

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    save_folder = '{}/{}/'.format(args.save_folder, exp_name)

    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    log_file = os.path.join(save_folder, 'log.txt')

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, 'a'))
    #print = logger.info

    with open(meta_file, "wb") as f:
        pickle.dump({'args': args}, f)

    device = torch.device('cuda' if args.cuda else 'cpu')

    print("About to get dataset")
    dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
    train_loader = data.DataLoader(dataset, batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    print("Dataset loaded")

    # Get one data sample to infer the observation shape.
    obs = next(iter(train_loader))[0]
    input_shape = obs[0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        encoder=args.encoder).to(device)

    model.apply(utils.weights_init)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.decoder:
        # Decoder size mirrors the chosen encoder; hidden dim is scaled
        # down relative to the encoder's.
        if args.encoder == 'large':
            decoder = modules.DecoderCNNLarge(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        elif args.encoder == 'medium':
            decoder = modules.DecoderCNNMedium(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        elif args.encoder == 'small':
            decoder = modules.DecoderCNNSmall(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        decoder.apply(utils.weights_init)
        optimizer_dec = torch.optim.Adam(decoder.parameters(),
                                         lr=args.learning_rate)

    # Train model.
    print('Starting model training...')
    step = 0
    best_loss = 1e9

    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0

        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = [tensor.to(device) for tensor in data_batch]
            optimizer.zero_grad()

            if args.decoder:
                optimizer_dec.zero_grad()
                obs, action, next_obs = data_batch
                objs = model.obj_extractor(obs)
                state = model.obj_encoder(objs)

                # Reconstruct current observation; per-sample-normalized BCE.
                rec = torch.sigmoid(decoder(state))
                loss = F.binary_cross_entropy(
                    rec, obs, reduction='sum') / obs.size(0)

                # Reconstruct the predicted next observation as well.
                next_state_pred = state + model.transition_model(state,
                                                                 action)
                next_rec = torch.sigmoid(decoder(next_state_pred))
                next_loss = F.binary_cross_entropy(
                    next_rec, next_obs, reduction='sum') / obs.size(0)
                loss += next_loss
            else:
                loss = model.contrastive_loss(*data_batch)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if args.decoder:
                optimizer_dec.step()

            if batch_idx % args.log_interval == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_batch[0]),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data_batch[0])))

            step += 1

        avg_loss = train_loss / len(train_loader.dataset)
        # print('====> Epoch: {} Average loss: {:.6f}'.format(
        #     epoch, avg_loss))

        # Checkpoint on best average epoch loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_file)

    return model