Code example #1
0
def main():
    """
    Smoke-test the training pipeline: load the MinAtar dataset, build a
    PredictionModel, run a few update/predict steps, and save the model.

    Returns:
        int: 0 on success.
    """
    # Train dataset
    from datasets.MinatarDataset.MinatarDataset import MinatarDataset
    dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json")
    dim_dict = dataset.get_dims()
    env_len = dim_dict["action_len"]
    obj_in_len = dim_dict["obj_len"]
    type_len = dim_dict["type_len"]

    # Construct the model
    model = PredictionModel(obj_in_len=obj_in_len,
                            env_len=env_len,
                            obj_type_len=type_len,
                            accumulate_batches=4,
                            exist_type_separate=False,
                            appear_type_separate=False)
    # model = load("")

    # Train on a handful of randomly drawn samples.
    for _ in range(5):
        # randrange avoids the off-by-one of randint(0, len(dataset)),
        # whose inclusive upper bound could index one past the end.
        idx = random.randrange(len(dataset))
        batch = dataset[idx]  # s, a, sprime, sappear, r
        # Convert each tensor in the batch to plain Python lists.
        batch_ = [item.numpy().tolist() for item in batch]
        batch_[-1] = batch_[-1][0]  # unwrap the scalar reward
        model.updateModel(*batch_)
        model.predict(batch_[0], batch_[1])

    model.save()

    return 0
Code example #2
0
def train_pl():
    """Train a VariancePointNet on the MinAtar dataset with PyTorch Lightning."""
    # Square linear
    dataset = MinatarDataset()
    dims = dataset.get_dims()

    # Dataloaders: 80/20 train/validation split, single-sample batches.
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    train_set, val_set = torch.utils.data.random_split(
        dataset, [n_train, n_total - n_train])
    train_data_loader = DataLoader(
        train_set, batch_size=1,
        shuffle=True)  # num_workers=8, pin_memory=True,
    val_data_loader = DataLoader(val_set, batch_size=1, pin_memory=True)

    # Initialize the model from the dataset's dimensions.
    model = VariancePointNet(env_len=dims["action_len"],
                             obj_in_len=dims["obj_len"],
                             obj_reg_len=2,
                             obj_attri_len=2,
                             out_set_size=10,
                             hidden_dim=512)

    # Train: mixed precision on one GPU, gradients accumulated over
    # 64 single-sample batches.
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        max_epochs=12,
        accumulate_grad_batches=64,
        profiler="simple",
    )
    trainer.fit(model, train_data_loader, val_data_loader)

    # Evaluate the freshly trained model.
    evaluate(model=model)
Code example #3
0
def evaluate(model=None, path=None):
    """
    Visualize predictions of a trained SetDSPN model on random samples.

    Args:
        model: An already-loaded model; if None, one is restored from
            ``path``.
        path: Checkpoint path; if None, the most recently created
            Lightning checkpoint under lightning_logs/ is used.
    """
    # load model
    if model is None:
        if path is None:
            # Fall back to the newest checkpoint on disk.
            list_ckpts = glob.glob(
                os.path.join("lightning_logs", "*", "checkpoints", "*.ckpt"))
            latest_ckpt = max(list_ckpts, key=os.path.getctime)
            print("Using checkpoint ", latest_ckpt)
            path = latest_ckpt

        model = SetDSPN.load_from_checkpoint(path)
        # model.freeze()

    # Evaluate on 20 random samples that have at least one appearing object.
    # dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json")
    dataset = MinatarDataset(name="dataset_random_3000_new_matched.json")

    counter = 0
    while counter < 20:
        # randrange avoids randint's inclusive upper bound, which could
        # index one past the end of the dataset.
        batch_idx = random.randrange(len(dataset))
        s, a, sprime, sappear, r = dataset[batch_idx]
        if len(sappear) == 0:
            continue  # skip samples with no appearing objects

        pred = model(s.unsqueeze(0), a.unsqueeze(0))
        visualize(pred, s, sprime, sappear)
        counter += 1
Code example #4
0
def evaluate(model=None, path=None):
    """
    Visualize predictions of a trained VariancePointNet on 5 random samples.

    Args:
        model: An already-loaded model; if None, one is restored from
            ``path``.
        path: Checkpoint path; if None, the most recently created
            Lightning checkpoint under lightning_logs/ is used.
    """
    # load model
    if model is None:
        if path is None:
            # Fall back to the newest checkpoint on disk.
            list_ckpts = glob.glob(
                os.path.join("lightning_logs", "*", "checkpoints", "*.ckpt"))
            latest_ckpt = max(list_ckpts, key=os.path.getctime)
            print("Using checkpoint ", latest_ckpt)
            path = latest_ckpt

        model = VariancePointNet.load_from_checkpoint(path)
        model.freeze()

    # Evaluate on 5 random samples.
    dataset = MinatarDataset()
    for i in range(5):
        # randrange avoids randint's inclusive upper bound, which could
        # index one past the end of the dataset.
        batch_idx = random.randrange(len(dataset))
        s, a, sprime, sappear, r = dataset[batch_idx]
        pred = model(s.unsqueeze(0), a.unsqueeze(0))
        visualize(pred, sprime, s)
Code example #5
0
def train_pl():
    """Run a quick single-epoch CPU training of SetDSPN on a 100-sample subset."""
    # Square linear
    dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json",
                             dataset_size=100)
    dims = dataset.get_dims()

    # Dataloaders: 80/20 train/validation split, single-sample batches.
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    train_set, val_set = torch.utils.data.random_split(
        dataset, [n_train, n_total - n_train])
    train_data_loader = DataLoader(
        train_set, batch_size=1,
        shuffle=True)  # num_workers=8, pin_memory=True,
    val_data_loader = DataLoader(val_set, batch_size=1, pin_memory=True)

    # Initialize the model from the dataset's dimensions.
    model = SetDSPN(obj_in_len=dims["obj_len"],
                    obj_reg_len=2,
                    obj_attri_len=2,
                    env_len=dims["action_len"],
                    latent_dim=64,
                    out_set_size=3,
                    n_iters=10,
                    internal_lr=50,
                    overall_lr=1e-3,
                    loss_encoder_weight=1)

    # Train: one epoch on CPU (gpus=None).
    trainer = pl.Trainer(
        gpus=None,
        max_epochs=1,
    )
    trainer.fit(model, train_data_loader, val_data_loader)
Code example #6
0
def train_pl():
    """
    Train a SetTransformer on the MinAtar bullet dataset with PyTorch
    Lightning, or — when ``use_lr_finder`` is flipped on — run the
    Lightning learning-rate finder instead of training.
    """
    # Square linear
    dataset = MinatarDataset(name="dataset_random_3000_bullet_matched.json")
    # dataset = MinatarDataset(name="dataset_random_3000_new_matched.json")
    # dataset = MinatarDataset(name="dataset_random_3000_full_matched.json")
    # dataset = MinatarDataset(name="asterix_dataset_random_3000.json")
    dim_dict = dataset.get_dims()
    env_len = dim_dict["action_len"]
    obj_in_len = dim_dict["obj_len"]
    type_len = dim_dict["type_len"]

    # Prepare the dataloader: 80/20 train/validation split.
    dataset_size = len(dataset)
    train_size = int(dataset_size * 0.8)
    train_set, val_set = torch.utils.data.random_split(
        dataset, [train_size, dataset_size - train_size])
    train_data_loader = DataLoader(
        train_set, batch_size=1, num_workers=8,
        shuffle=True)
    val_data_loader = DataLoader(val_set,
                                 batch_size=1,
                                 num_workers=8,
                                 pin_memory=True)

    # Initialize the model
    model = SetTransformer(obj_in_len=obj_in_len,
                           obj_reg_len=2,
                           obj_type_len=type_len,
                           env_len=env_len,
                           out_set_size=3,
                           learning_rate=1e-4)

    # Train: mixed precision on one GPU, gradients accumulated over
    # 64 single-sample batches.
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        max_epochs=16,
        accumulate_grad_batches=64,
        profiler="simple",
        auto_lr_find=True,
        log_every_n_steps=5,
    )

    # Renamed from `lr_finder`: the original reused one name for both
    # this flag and the tuner result object below.
    use_lr_finder = False
    if use_lr_finder:
        # Sweep learning rates instead of training; results live in
        # lr_finder.results.
        lr_finder = trainer.tuner.lr_find(model,
                                          train_dataloader=train_data_loader,
                                          val_dataloaders=val_data_loader,
                                          max_lr=0.1,
                                          min_lr=1e-5)
        # Plot the sweep with the suggested point marked.
        fig = lr_finder.plot(suggest=True)
        fig.show()

        # Pick point based on plot, or get suggestion; the original
        # computed this and silently discarded it.
        new_lr = lr_finder.suggestion()
        print("Suggested learning rate:", new_lr)
    else:
        trainer.fit(model, train_data_loader, val_data_loader)

        # Evaluate
        # trainer.test(model, test_dataloaders = val_data_loader)
        evaluate(model=model)