def main():
    """Smoke-test PredictionModel on a few random dataset samples.

    Loads the matched bullet dataset, sizes a PredictionModel from the
    dataset's dimensions, runs five update/predict iterations on randomly
    drawn samples, and saves the model.

    Returns:
        int: 0 on completion.
    """
    # Train dataset
    from datasets.MinatarDataset.MinatarDataset import MinatarDataset
    dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json")
    dim_dict = dataset.get_dims()
    env_len = dim_dict["action_len"]
    obj_in_len = dim_dict["obj_len"]
    type_len = dim_dict["type_len"]

    # Construct the model
    model = PredictionModel(obj_in_len=obj_in_len,
                            env_len=env_len,
                            obj_type_len=type_len,
                            accumulate_batches=4,
                            exist_type_separate=False,
                            appear_type_separate=False)
    # model = load("")

    # Train
    for _ in range(5):
        # randrange excludes the upper bound; the original
        # random.randint(0, len(dataset)) could return len(dataset)
        # and raise IndexError on dataset[idx].
        idx = random.randrange(len(dataset))
        batch = dataset[idx]  # s, a, sprime, sappear, r
        batch_ = []
        for item in batch:
            batch_.append(item.numpy().tolist())
        # Reward comes back as a one-element list; unwrap the scalar.
        batch_[-1] = batch_[-1][0]
        model.updateModel(*batch_)
        model.predict(batch_[0], batch_[1])

    model.save()
    return 0
def train_pl():
    """Train a VariancePointNet on MinatarDataset via PyTorch Lightning.

    Performs an 80/20 train/validation split, fits for 12 epochs on one
    GPU with fp16 and gradient accumulation, then runs `evaluate` on the
    trained model.
    """
    # Square linear
    dataset = MinatarDataset()
    dims = dataset.get_dims()

    # Model hyper-parameters.
    obj_reg_len = 2
    obj_attri_len = 2
    out_set_size = 10
    hidden_dim = 512

    # 80/20 train/validation split.
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    train_subset, val_subset = torch.utils.data.random_split(
        dataset, [n_train, n_total - n_train])
    # num_workers=8, pin_memory=True could be added to speed up loading.
    train_loader = DataLoader(train_subset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=1, pin_memory=True)

    # Initialize the model from the dataset's dimensions.
    model = VariancePointNet(env_len=dims["action_len"],
                             obj_in_len=dims["obj_len"],
                             obj_reg_len=obj_reg_len,
                             obj_attri_len=obj_attri_len,
                             out_set_size=out_set_size,
                             hidden_dim=hidden_dim)

    # Early stopping is currently disabled; a val_loss-monitoring
    # EarlyStopping callback could be passed via `callbacks=` if needed.
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        max_epochs=12,
        accumulate_grad_batches=64,
        profiler="simple"
    )
    trainer.fit(model, train_loader, val_loader)

    # Evaluate the freshly trained model (no separate test loop).
    evaluate(model=model)
def evaluate(model=None, path=None):
    """Visualize predictions on 20 random samples with a nonempty `sappear`.

    Args:
        model: A trained model; if None, one is loaded from `path`.
        path: Checkpoint path; if None, the newest Lightning checkpoint
            under lightning_logs/ is used.
    """
    # load model
    if model is None:
        if path is None:
            list_ckpts = glob.glob(
                os.path.join("lightning_logs", "*", "checkpoints", "*.ckpt"))
            # Newest checkpoint by creation time.
            latest_ckpt = max(list_ckpts, key=os.path.getctime)
            print("Using checkpoint ", latest_ckpt)
            path = latest_ckpt
        model = SetDSPN.load_from_checkpoint(path)
        # model.freeze()

    # Evaluate
    # dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json")
    dataset = MinatarDataset(name="dataset_random_3000_new_matched.json")
    eval_data_loader = DataLoader(dataset, batch_size=1)
    counter = 0
    while counter < 20:
        # randrange keeps the index in [0, len(dataset)); the original
        # random.randint's inclusive upper bound could raise IndexError.
        batch_idx = random.randrange(len(dataset))
        batch = dataset[batch_idx]
        s, a, sprime, sappear, r = batch
        # NOTE(review): loops forever if no sample has a nonempty sappear —
        # confirm the dataset guarantees at least 20 such samples.
        if len(sappear) == 0:
            continue
        pred = model(s.unsqueeze(0), a.unsqueeze(0))
        visualize(pred, s, sprime, sappear)
        counter += 1
def evaluate(model=None, path=None):
    """Visualize VariancePointNet predictions on 5 random dataset samples.

    Args:
        model: A trained model; if None, one is loaded from `path`.
        path: Checkpoint path; if None, the newest Lightning checkpoint
            under lightning_logs/ is used.
    """
    # load model
    if model is None:
        if path is None:
            list_ckpts = glob.glob(
                os.path.join("lightning_logs", "*", "checkpoints", "*.ckpt"))
            # Newest checkpoint by creation time.
            latest_ckpt = max(list_ckpts, key=os.path.getctime)
            print("Using checkpoint ", latest_ckpt)
            path = latest_ckpt
        model = VariancePointNet.load_from_checkpoint(path)
        model.freeze()

    # Evaluate
    dataset = MinatarDataset()
    eval_data_loader = DataLoader(dataset, batch_size=1)
    for i in range(5):
        # randrange excludes the upper bound; the original
        # random.randint(0, len(dataset)) could index one past the end.
        batch_idx = random.randrange(len(dataset))
        batch = dataset[batch_idx]
        s, a, sprime, sappear, r = batch
        pred = model(s.unsqueeze(0), a.unsqueeze(0))
        visualize(pred, sprime, s)
def train_pl():
    """Debug-scale SetDSPN training: 100 samples, CPU only, one epoch."""
    # Square linear
    dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json",
                             dataset_size=100)
    # Alternative datasets:
    # dataset = MinatarDataset(name="dataset_random_3000_new_matched.json")
    # dataset = MinatarDataset(name="dataset_random_3000_full_matched.json")
    dims = dataset.get_dims()

    # 80/20 train/validation split.
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    train_subset, val_subset = torch.utils.data.random_split(
        dataset, [n_train, n_total - n_train])
    # num_workers=8, pin_memory=True could be added to speed up loading.
    train_loader = DataLoader(train_subset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=1, pin_memory=True)

    # Initialize the model.
    model = SetDSPN(obj_in_len=dims["obj_len"],
                    obj_reg_len=2,
                    obj_attri_len=2,
                    env_len=dims["action_len"],
                    latent_dim=64,
                    out_set_size=3,
                    n_iters=10,
                    internal_lr=50,
                    overall_lr=1e-3,
                    loss_encoder_weight=1)

    # CPU-only smoke run: a single epoch, no mixed precision, no gradient
    # accumulation, no early stopping.
    trainer = pl.Trainer(
        gpus=None,
        max_epochs=1,
    )
    trainer.fit(model, train_loader, val_loader)
def train_pl():
    """Train a SetTransformer on the matched bullet dataset, then evaluate.

    Splits the dataset 80/20, fits for 16 epochs on one GPU with fp16 and
    gradient accumulation (or runs the LR finder instead when the local
    `lr_finder` flag is flipped), and finishes with `evaluate(model=model)`.
    """
    # Square linear
    dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json")
    # Alternative datasets:
    # dataset = MinatarDataset(name="dataset_random_3000_new_matched.json")
    # dataset = MinatarDataset(name="dataset_random_3000_full_matched.json")
    # dataset = MinatarDataset(name="asterix_dataset_random_3000.json")
    dims = dataset.get_dims()
    env_len = dims["action_len"]
    obj_in_len = dims["obj_len"]
    type_len = dims["type_len"]

    # 80/20 train/validation split with multi-worker loading.
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    train_subset, val_subset = torch.utils.data.random_split(
        dataset, [n_train, n_total - n_train])
    train_loader = DataLoader(train_subset, batch_size=1, num_workers=8,
                              shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=1, num_workers=8,
                            pin_memory=True)

    # Initialize the model (a previous SetDSPN configuration was replaced
    # by this SetTransformer).
    model = SetTransformer(obj_in_len=obj_in_len,
                           obj_reg_len=2,
                           obj_type_len=type_len,
                           env_len=env_len,
                           out_set_size=3,
                           learning_rate=1e-4)

    # NOTE(review): the device count below is computed but never used —
    # the trainer is hard-wired to gpus=1; confirm that is intentional.
    gpus = torch.cuda.device_count()
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        max_epochs=16,
        accumulate_grad_batches=64,
        profiler="simple",
        auto_lr_find=True,
        log_every_n_steps=5,
        # callbacks=[early_stop_callback]
    )

    lr_finder = False
    if lr_finder:
        # Find the ideal lr
        lr_finder = trainer.tuner.lr_find(model,
                                          train_dataloader=train_loader,
                                          val_dataloaders=val_loader,
                                          max_lr=0.1,
                                          min_lr=1e-5)
        # Results can be found in lr_finder.results; plot the sweep.
        fig = lr_finder.plot(suggest=True)
        fig.show()
        # Pick point based on plot, or get suggestion
        new_lr = lr_finder.suggestion()
    else:
        trainer.fit(model, train_loader, val_loader)

    # Evaluate
    # trainer.test(model, test_dataloaders=val_loader)
    evaluate(model=model)