Example no. 1
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# AStarType, GoalTypeDataset, collate, train_epoch and validate_epoch are
# project-local imports.


def main(config):
    train_loader = DataLoader(GoalTypeDataset("train"),
                              batch_size=config.batch_size,
                              collate_fn=collate,
                              shuffle=True)
    valid_loader = DataLoader(GoalTypeDataset("val"),
                              batch_size=config.batch_size,
                              collate_fn=collate,
                              shuffle=True)
    model = AStarType(config).to(config.device)
    criterion = nn.BCELoss()
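    # Put the goal embedding into its own parameter group with a larger
    # learning rate (0.5); all other trainable parameters use config.lr.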
    embedding_params_ids = list(map(id, model.goal_embedding.parameters()))
    rest_params = filter(lambda x: id(x) not in embedding_params_ids,
                         model.parameters())
    optimizer = optim.Adam(
        [{
            'params': filter(lambda p: p.requires_grad, rest_params)
        }, {
            'params': model.goal_embedding.parameters(),
            'lr': 0.5
        }],
        lr=config.lr,
    )
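    # Cosine-anneal the learning rates over a period of T_max=32 scheduler steps.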
    scheduler = CosineAnnealingLR(optimizer, 32)

    train_losses, valid_losses, min_valid_loss = [], [], float('inf')
    for epoch in range(config.num_epoch):
        train_loss, train_acc = train_epoch(model,
                                            criterion,
                                            optimizer,
                                            train_loader,
                                            config.device,
                                            config.max_norm,
                                            scheduler=scheduler)
        valid_loss, valid_acc = validate_epoch(model, valid_loader, criterion,
                                               config.device)
        tqdm.write(f'epoch #{epoch + 1:3d}\ttrain_loss: {train_loss:.3e}'
                   f' train_acc: {train_acc:.4f}'
                   f' valid_loss: {valid_loss:.3e}'
                   f' valid_acc: {valid_acc:.4f}\n')

        # Early stopping: halt if the current valid_loss is no lower than each
        # of the last three validation losses
        if len(valid_losses) > 2 and all(valid_loss >= loss
                                         for loss in valid_losses[-3:]):
            print('Stop early')
            break

        if valid_loss < min_valid_loss:
            min_valid_loss = valid_loss
            torch.save(model.state_dict(), config.save_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
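
A minimal sketch of driving main() with a plain config object. Only the attribute names below are taken from the code above; the values, and any extra fields that AStarType's constructor reads, are placeholder assumptions.

class TrainConfig:
    batch_size = 32        # placeholder
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lr = 1e-3              # placeholder
    max_norm = 5.0         # gradient-clipping bound, placeholder
    num_epoch = 50         # placeholder
    save_path = 'goal_type_best.pt'  # placeholder

if __name__ == '__main__':
    main(TrainConfig())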
Example no. 2
import torch

# RNN, AStarType, AStarEntity and the FinishConfig/TypeConfig/EntityConfig
# classes are project-local imports.


class GoalPlanning:
    def __init__(self, config):
        self.goal_finish = RNN(FinishConfig())
        self.goal_type = AStarType(TypeConfig())
        self.goal_entity = AStarEntity(EntityConfig())

        # Load the pretrained weights, mapping all tensors to CPU regardless of
        # the device they were saved on.
        self.goal_finish.load_state_dict(torch.load(config.goal_finish_path, map_location=torch.device('cpu')))
        self.goal_type.load_state_dict(torch.load(config.goal_type_path, map_location=torch.device('cpu')))
        self.goal_entity.load_state_dict(torch.load(config.goal_entity_path, map_location=torch.device('cpu')))

    def goal_finish_infer(self, text, first_goal, final_goal):
        self.goal_finish.eval()
        with torch.no_grad():
            if len(text) > 0:
                text = torch.tensor(text, dtype=torch.long).unsqueeze(1) # (seq_len, batch)
                first_goal = torch.tensor(first_goal, dtype=torch.long).unsqueeze(0) # (batch, seq_len)
                final_goal = torch.tensor(final_goal, dtype=torch.long).unsqueeze(0)
                output = self.goal_finish(text, first_goal, final_goal, "test")
                pred = output.argmax(dim=-1)
                return pred.item()
            else:
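                # Empty conversation so far: no prediction is possible, so
                # default to class 0.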
                return 0

    def goal_type_infer(self, past_goal_type_seq, cur_goal_type, final_goal_type):
        self.goal_type.eval()
        with torch.no_grad():
            past_goal_type_seq = torch.tensor(past_goal_type_seq, dtype=torch.long).unsqueeze(1)  # (seq_len, batch)
            cur_goal_type = torch.tensor(cur_goal_type, dtype=torch.long).unsqueeze(0)  # (batch, seq_len)
            final_goal_type = torch.tensor(final_goal_type, dtype=torch.long).unsqueeze(0)
            output = self.goal_type(past_goal_type_seq, cur_goal_type, final_goal_type, "test")
            return output.item()

    def goal_entity_infer(self, past_entity_seq, cur_entity, final_entity):
        self.goal_entity.eval()
        with torch.no_grad():
            past_entity_seq = torch.tensor(past_entity_seq, dtype=torch.long).unsqueeze(1)  # (seq_len, batch)
            cur_entity = torch.tensor(cur_entity, dtype=torch.long).unsqueeze(0)  # (batch, seq_len)
            final_entity = torch.tensor(final_entity, dtype=torch.long).unsqueeze(0)
            output = self.goal_entity(past_entity_seq, cur_entity, final_entity, "test")
            return output.item()
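
A minimal sketch of running inference with GoalPlanning. The checkpoint paths and the integer ID sequences are illustrative placeholders; real inputs would come from the project's goal and entity vocabularies.

class PlanningConfig:
    goal_finish_path = 'goal_finish.pt'  # placeholder paths
    goal_type_path = 'goal_type.pt'
    goal_entity_path = 'goal_entity.pt'

planner = GoalPlanning(PlanningConfig())
finished = planner.goal_finish_infer(text=[12, 7, 3], first_goal=[1], final_goal=[4])
type_score = planner.goal_type_infer(past_goal_type_seq=[2, 5], cur_goal_type=[5], final_goal_type=[9])
entity_score = planner.goal_entity_infer(past_entity_seq=[31, 8], cur_entity=[8], final_entity=[40])
print(finished, type_score, entity_score)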