Exemplo n.º 1
0
    def log_validation_results(trainer):
        """Evaluate on the validation set, step the LR scheduler, checkpoint.

        Runs ``evaluator`` over ``test_loader``, logs loss / pixel accuracy /
        IoU / F1, feeds the validation loss to ``lr_scheduler.step``, refreshes
        the shared ``state`` record with this epoch's validation metrics and
        saves it to a per-epoch checkpoint path built from ``ckpt_path``.

        Args:
            trainer: Engine whose ``state.epoch`` names the finished epoch.
        """
        # evaluate test(validation) set
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Validation Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # the scheduler is stepped with the monitored value (validation loss)
        lr_scheduler.step(metrics['loss'])

        # update_state *returns* the state record (it is used as
        # `state = update_state(...)` where state is first created). The
        # result used to be thrown away, so the checkpoint saved below kept
        # stale losses/metrics. Merge it back into the shared dict in place;
        # this needs no rebinding of `state`.
        state.update(update_state(weight=model.state_dict(),
                                  train_loss=state['train_loss'],
                                  val_loss=metrics['loss'],
                                  val_pix_acc=metrics['pix-acc'],
                                  val_iou=metrics['iou'],
                                  val_f1=metrics['f1']))

        path = ckpt_path.format(network=networks,
                                optimizer=optimizer,
                                lr=lr,
                                epoch=trainer.state.epoch)
        save_ckpt_file(path, state)
Exemplo n.º 2
0
    def log_training_results(trainer):
        """Evaluate on the training set and record the training loss.

        Runs ``evaluator`` over ``train_loader``, logs loss / pixel accuracy /
        IoU / F1 for the finished epoch, and refreshes the shared ``state``
        record with the new training loss. The validation fields are kept
        from ``state``.

        Args:
            trainer: Engine whose ``state.epoch`` names the finished epoch.
        """
        # evaluate on training set
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        logger.info("Training Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n".format(
            trainer.state.epoch, metrics['loss'], str(metrics['pix-acc']), str(metrics['iou']), str(metrics['f1'])))

        # update_state returns the refreshed record; merge it back into the
        # shared dict (previously the result was discarded and `state` never
        # saw the new training loss).
        state.update(update_state(weight=model.state_dict(),
                                  train_loss=metrics['loss'],
                                  val_loss=state['val_loss'],
                                  val_pix_acc=state['val_pix_acc'],
                                  val_iou=state['val_iou'],
                                  val_f1=state['val_f1']))
Exemplo n.º 3
0
def webhook_handler():
    """Handle a LINE webhook callback and advance each sender's state machine.

    Verifies the ``X-Line-Signature`` header against the raw request body
    (responding 400 when it is missing or invalid). For every event it makes
    sure a ``TocMachine`` exists for the sending user. For message and
    postback events it then restores the user's persisted state, feeds the
    event to the machine, and either persists the new state (when
    ``advance`` returns a truthy value) or replies "Not Entering any State".

    Returns:
        The literal response body ``"OK"``.
    """
    # A request without a signature is malformed: reject it with 400 instead
    # of letting the header lookup raise KeyError (a server error).
    signature = request.headers.get("X-Line-Signature")
    if signature is None:
        abort(400)

    # get request body as text; the signature is computed over it verbatim
    body = request.get_data(as_text=True)
    try:
        events = parser.parse(body, signature)
    except InvalidSignatureError:
        abort(400)

    for event in events:
        user_id = event.source.user_id
        if user_id not in machine:
            machine[user_id] = TocMachine(
                states=machine_diagram_states,
                transitions=machine_diagram_transitions,
                initial="user",
                auto_transitions=False,
                show_conditions=True,
            )
        # only message and postback events drive the FSM
        if not isinstance(event, (MessageEvent, PostbackEvent)):
            continue

        fsm = machine[user_id]
        # the in-memory machine is re-synced from the persisted state first
        fsm.state = read_state(user_id)
        response = fsm.advance(event)
        if not response:
            send_text_message(event.reply_token, "Not Entering any State")
        else:
            update_state(user_id, fsm.state)

    return "OK"
Exemplo n.º 4
0
    def run_episode(self):
        """Play a single episode, training the agent online after each step.

        Each step sleeps ``config['THREAD_DELAY']``, optionally renders, lets
        the agent act and trains it on ``(s, a, r, s')``. On a terminal step
        ``s'`` is ``None``. The episode ends on ``done`` or when
        ``self.stop_signal`` is set. Its total reward is then appended to the
        shared ``Environment.scores`` (trimmed to the last 100 entries), the
        episode counter is advanced, and the debug log is pickled when
        ``self.debug`` is on.
        """
        current = init_state(self.env.reset())
        episode_reward = 0  # total reward this episode
        self.debug_log = []

        while True:
            time.sleep(self.config['THREAD_DELAY'])  # yield to other workers
            if self.render:
                self.env.render()

            action = self.agent.act(current)
            observation, reward, done, _ = self.env.step(action)
            # update_state is always evaluated, even on the terminal step
            successor = update_state(current, observation)
            if done:
                successor = None  # terminal state marker

            self.agent.train(current, action, reward, successor)
            if self.debug:
                self.debug_log.append([current, action, reward, successor, done])

            current = successor
            episode_reward += reward

            if done or self.stop_signal:
                break

        # --- episode bookkeeping ---
        Environment.scores.append(episode_reward)
        self.episode_number += 1
        # Static purge for now: only the 100 most recent scores are kept
        Environment.scores = Environment.scores[-100:]

        if self.debug:  # Save logs
            # TODO: folder restructure
            log_name = "ENV{}_EPISODE{}".format(self.id, self.episode_number)
            save_pickle(self.debug_log, os.path.join('debug_logs', log_name))
        if self.render:  # Demo mode
            print("ENV_{} INFO: total reward this episode: {}".format(
                self.id, episode_reward))
Exemplo n.º 5
0
def train_with_ignite(networks, dataset, data_dir, batch_size, img_size,
                      epochs, lr, momentum, num_workers, optimizer, logger):
    """Train a segmentation network with pytorch-ignite, checkpointing each epoch.

    Args:
        networks: Network name passed to ``get_network``; also part of the
            checkpoint filename.
        dataset: Dataset name passed to ``get_loader``.
        data_dir: Dataset root passed to ``get_loader``.
        batch_size: Training batch size (the test loader uses batch size 1).
        img_size: Random-crop size for training and the side length of the
            ``(3, img_size, img_size)`` input used for the model summary.
        epochs: Number of training epochs.
        lr: Learning rate for ``get_optimizer``; also in the checkpoint name.
        momentum: Momentum passed to ``get_optimizer``.
        num_workers: Passed to ``get_loader`` for both loaders.
        optimizer: Optimizer name passed to ``get_optimizer``; also in the
            checkpoint name.
        logger: Receives the model summary and progress messages.

    Side effects:
        After every epoch, the checkpoint record ``state`` is saved with
        ``save_ckpt_file`` to
        ``./ckpt/{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth``.
    """
    from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
    from ignite.metrics import Loss
    from utils.metrics import MultiThresholdMeasures, Accuracy, IoU, F1score

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # build model and log its summary
    model = get_network(networks)
    input_size = (3, img_size, img_size)
    summarize_model(model.to(device), input_size, logger, batch_size, device)

    # loss on raw logits
    loss = torch.nn.BCEWithLogitsLoss()

    # optimizer, plus a scheduler stepped with the validation loss each epoch
    model_optimizer = get_optimizer(optimizer, model, lr, momentum)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model_optimizer)

    # transforms applied jointly to image and mask (geometry must match)
    train_joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip()
    ])

    # transforms only on images (photometric jitter + ImageNet normalization)
    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])

    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # transforms only on mask
    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    # build train / test loader
    # NOTE(review): the training loader is built with shuffle=False; confirm
    # this is intentional.
    train_loader = get_loader(dataset=dataset,
                              data_dir=data_dir,
                              train=True,
                              joint_transforms=train_joint_transforms,
                              image_transforms=train_image_transforms,
                              mask_transforms=mask_transforms,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    test_loader = get_loader(dataset=dataset,
                             data_dir=data_dir,
                             train=False,
                             joint_transforms=test_joint_transforms,
                             image_transforms=test_image_transforms,
                             mask_transforms=mask_transforms,
                             batch_size=1,
                             shuffle=False,
                             num_workers=num_workers)

    # build trainer / evaluator with ignite
    trainer = create_supervised_trainer(model,
                                        model_optimizer,
                                        loss,
                                        device=device)
    measure = MultiThresholdMeasures()
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                # presumably attaches the shared
                                                # accumulator the derived metrics
                                                # read from - confirm
                                                '': measure,
                                                'pix-acc': Accuracy(measure),
                                                'iou': IoU(measure),
                                                'loss': Loss(loss),
                                                'f1': F1score(measure),
                                            },
                                            device=device)

    # checkpoint record; the epoch handlers merge fresh values into it in place
    state = update_state(model.state_dict(), 0, 0, 0, 0, 0)

    # make ckpt path (and make sure the directory exists before the first save)
    ckpt_root = './ckpt/'
    filename = '{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth'
    ckpt_path = os.path.join(ckpt_root, filename)
    os.makedirs(ckpt_root, exist_ok=True)

    # execution after every training iteration
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        # 1-based iteration index within the current epoch
        num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1
        if num_iter % 20 == 0:
            logger.info("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format(
                trainer.state.epoch, num_iter, trainer.state.output))

    # execution after every training epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        # evaluate on training set
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Training Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # update_state returns the new record; merge it into `state` (the
        # result used to be discarded, so `state` never changed)
        state.update(update_state(weight=model.state_dict(),
                                  train_loss=metrics['loss'],
                                  val_loss=state['val_loss'],
                                  val_pix_acc=state['val_pix_acc'],
                                  val_iou=state['val_iou'],
                                  val_f1=state['val_f1']))

    # execution after every epoch (registered after, so runs after, the
    # training-set evaluation above)
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        # evaluate test(validation) set
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Validation Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # update scheduler with the monitored validation loss
        lr_scheduler.step(metrics['loss'])

        # update and save state (merge, so the saved checkpoint is current)
        state.update(update_state(weight=model.state_dict(),
                                  train_loss=state['train_loss'],
                                  val_loss=metrics['loss'],
                                  val_pix_acc=metrics['pix-acc'],
                                  val_iou=metrics['iou'],
                                  val_f1=metrics['f1']))

        path = ckpt_path.format(network=networks,
                                optimizer=optimizer,
                                lr=lr,
                                epoch=trainer.state.epoch)
        save_ckpt_file(path, state)

    trainer.run(train_loader, max_epochs=epochs)
Exemplo n.º 6
0
Arquivo: run.py Projeto: xuihau/G2GTr
def validate(model, parser, dev_batched, dataset, device, batch_size_eval,
             pad_action, opt):
    """Parse the dev set with ``model`` and score it against the gold trees.

    Every batch of ``dev_batched`` is decoded with the model in inference
    mode (first model argument ``1``). The predicted arcs of all sentences
    are pickled to ``<opt.mainpath>/dependency/<opt.outputname>.pkl``, and
    unlabeled / labeled attachment scores are computed against ``dataset``.

    Args:
        model: Parser model, called as ``model(1, ...)`` at every step.
        parser: Vocabulary helper supplying the ROOT/NULL ids and ``id2tok``.
        dev_batched: Indexable sequence of
            ``(buffer_ind, buffer, buffer_pos, mask_buffer)`` tensor tuples.
        dataset: Gold examples, in the same order as the decoded sentences;
            each has ``'word'``, ``'head'``, ``'label'`` and ``'pos'`` lists.
        device: Torch device that batch tensors are moved to.
        batch_size_eval: Unused; kept for call compatibility.
        pad_action: Factor applied in the per-sentence "finished" test.
        opt: Options namespace (``graphinput``, ``seppoint``, ``mainpath``,
            ``outputname``).

    Returns:
        ``(UAS, LAS)`` as fractions of scored tokens; ``(0.0, 0.0)`` when
        ``dataset`` has no scorable tokens.
    """
    dependencies_total = []
    # Inference only. With autograd on, each step's graph (linked to the next
    # step through action_state/action_cell) would presumably be kept alive
    # for the whole batch.
    with torch.no_grad(), tqdm(total=len(dev_batched)) as pbar:
        for batch_idx in range(len(dev_batched)):
            dependencies_total.extend(
                _decode_dev_batch(model, parser, dev_batched[batch_idx],
                                  device, pad_action, opt))
            pbar.update(1)

    dump_path = opt.mainpath + '/dependency/' + str(opt.outputname) + '.pkl'
    with open(dump_path, 'wb') as f:
        pickle.dump(dependencies_total, f, pickle.HIGHEST_PROTOCOL)

    return _attachment_scores(parser, dataset, dependencies_total)


def _decode_dev_batch(model, parser, batch_data, device, pad_action, opt):
    """Run the transition parser over one dev batch; return its arc lists.

    Returns one list per sentence, holding whatever ``update_state``
    accumulated (consumed as ``(head, dependent, label)`` triples by
    ``_attachment_scores``).
    """
    batch_buffer_ind, batch_buffer, batch_buffer_pos, mask_buffer = batch_data

    batch_buffer = batch_buffer.to(device)
    batch_buffer_ind = batch_buffer_ind.to(device)
    batch_buffer_pos = batch_buffer_pos.to(device)
    mask_buffer = mask_buffer.to(device)
    # read the sizes before update_state starts rebinding batch_buffer
    batch_size = batch_buffer.size()[0]
    seq_len = batch_buffer.size()[1]

    graph_emb = None
    graph_input = None
    batch_delete = None
    batch_delete_pos = None
    batch_del_label = None
    if opt.graphinput:
        batch_delete = batch_buffer.clone()
        batch_delete_pos = batch_buffer_pos.clone()
        graph_emb = torch.zeros(
            (batch_size, seq_len + 2, seq_len)).long().to(device)
        graph_input = torch.zeros(
            (batch_size, 3 * seq_len + 4, 3 * seq_len + 4)).long().to(device)

    # stack attention mask: only the last slot (ROOT) starts active
    mask_stack = torch.zeros((batch_size, seq_len + 1)).byte().to(device)
    mask_stack[:, -1] = 1

    # stack words / indices / POS, with ROOT in the last slot
    batch_stack = torch.zeros((batch_size, seq_len + 1)).long().to(device)
    batch_stack[:, -1] = parser.ROOT
    batch_stack_ind = (torch.ones((batch_size, seq_len + 1)) *
                       (seq_len + 1)).long().to(device)
    batch_stack_ind[:, -1] = 0
    batch_stack_pos = torch.zeros((batch_size, seq_len + 1)).long().to(device)
    batch_stack_pos[:, -1] = parser.P_ROOT

    # dependency slots for stack and buffer, initialised to the NULL ids
    batch_dep = torch.ones_like(batch_stack) * parser.NULL
    batch_dep_pos = torch.ones_like(batch_stack_pos) * parser.P_NULL
    batch_dep_label = torch.ones_like(batch_stack_pos) * parser.L_NULL
    batch_dep_buffer = torch.ones_like(batch_stack[:, :-1]) * parser.NULL
    batch_dep_pos_buffer = torch.ones_like(
        batch_stack_pos[:, :-1]) * parser.P_NULL
    batch_dep_label_buffer = torch.ones_like(
        batch_stack_pos[:, :-1]) * parser.L_NULL

    # token_type_ids (segment ids of the model input)
    if opt.graphinput:
        token_type_ids = torch.zeros(
            (batch_size, 3 * seq_len + 4)).long().to(device)
        token_type_ids[:, :seq_len + 2] = 2
        token_type_ids[:, seq_len + 2:2 * seq_len + 3] = 1
    else:
        token_type_ids = torch.zeros(
            (batch_size, 2 * seq_len + 4)).long().to(device)
        token_type_ids[:, :seq_len + 3] = 1

    sep_point = seq_len + 1 if opt.seppoint else 0

    # one arc list per sentence, filled by update_state
    dependencies = [[] for _ in range(batch_size)]

    mask_cls = torch.ones((batch_size, 1)).byte().to(device)
    mask_sep = torch.ones((batch_size, 1)).byte().to(device)

    # clip cls and sep
    ones = torch.ones((batch_size, 1)).long().to(device)
    batch_temp = batch_dep[:, :-1].clone()
    batch_pos_temp = batch_dep_pos[:, :-1].clone()
    batch_label_temp = batch_dep_label[:, :-1].clone()
    if opt.graphinput:
        batch_del_label = batch_label_temp.clone()

    # per-sentence "finished" flags; decoding stops once all are set
    update = torch.zeros(batch_size).long().to(device)
    action_state = None
    action_cell = None
    transitions = None
    labels = None

    while len(torch.nonzero(update)) != batch_size:
        if opt.graphinput:
            mask_delete, graph_in = prepare_graph(graph_emb, batch_stack_ind,
                                                  batch_buffer_ind, batch_size,
                                                  graph_input, device,
                                                  parser.NULL)
            (input_ids_x, pos_ids_x, batch_dep_input, batch_dep_pos_input,
             batch_dep_label_input, batch_graph_label_input,
             attention_mask) = prepare_data(
                 ones, opt.graphinput, parser, batch_stack, batch_buffer,
                 batch_stack_pos, batch_buffer_pos, batch_dep, batch_temp,
                 batch_dep_pos, batch_pos_temp, batch_dep_label,
                 batch_label_temp, mask_cls, mask_stack, mask_sep,
                 mask_buffer, batch_dep_buffer, batch_dep_pos_buffer,
                 batch_dep_label_buffer, mask_delete, batch_del_label,
                 batch_delete, batch_delete_pos)
            # previous transitions/labels and recurrent state feed the next step
            transitions, labels, action_state, action_cell = model(
                1, sep_point, input_ids_x, pos_ids_x, batch_dep_input,
                batch_dep_pos_input, batch_dep_label_input,
                batch_graph_label_input, attention_mask, update,
                token_type_ids, mask_stack, mask_buffer, batch_stack_ind,
                transitions, labels, action_state, action_cell, graph_in)
        else:
            (input_ids_x, pos_ids_x, batch_dep_input, batch_dep_pos_input,
             batch_dep_label_input, attention_mask) = prepare_data(
                 ones, opt.graphinput, parser, batch_stack, batch_buffer,
                 batch_stack_pos, batch_buffer_pos, batch_dep, batch_temp,
                 batch_dep_pos, batch_pos_temp, batch_dep_label,
                 batch_label_temp, mask_cls, mask_stack, mask_sep,
                 mask_buffer, batch_dep_buffer, batch_dep_pos_buffer,
                 batch_dep_label_buffer)
            transitions, labels, action_state, action_cell = model(
                1, sep_point, input_ids_x, pos_ids_x, batch_dep_input,
                batch_dep_pos_input, batch_dep_label_input, None,
                attention_mask, update, token_type_ids, mask_stack,
                mask_buffer, batch_stack_ind, transitions, labels,
                action_state, action_cell)

        # apply the predicted transitions to stack and buffer
        (batch_stack_ind, batch_stack, batch_buffer_ind, batch_buffer,
         batch_stack_pos, batch_buffer_pos, batch_dep, batch_dep_pos,
         mask_buffer, mask_stack, batch_dep_label, graph_emb, batch_del_label,
         batch_dep_buffer, batch_dep_pos_buffer, batch_dep_label_buffer,
         dependencies) = update_state(
             1, opt, batch_stack_ind, batch_stack, batch_buffer_ind,
             batch_buffer, batch_stack_pos, batch_buffer_pos, batch_dep,
             batch_dep_pos, mask_buffer, mask_stack, transitions,
             batch_dep_label, labels, parser.NULL, parser.P_NULL,
             parser.L_NULL, batch_dep_buffer, batch_dep_pos_buffer,
             batch_dep_label_buffer, graph_emb, batch_del_label,
             dependencies=dependencies)

        # a sentence is finished once its buffer is empty and only ROOT is
        # left on the stack
        update = ((mask_buffer.sum(dim=1) == 0) *
                  (mask_stack.sum(dim=1) == 1) * pad_action * 1.0).long()

        # free this step's inputs before the next step allocates new ones
        del input_ids_x, pos_ids_x, batch_dep_input, \
            batch_dep_pos_input, batch_dep_label_input, attention_mask
        if opt.graphinput:
            del mask_delete, graph_in

    return dependencies


def _attachment_scores(parser, dataset, dependencies_total):
    """Return ``(UAS, LAS)`` of predicted arcs against gold ``dataset`` trees.

    Token 0 of every example is skipped. A token counts for UAS when its
    predicted head matches the gold head, and for LAS when its label matches
    too. Returns ``(0.0, 0.0)`` when no token was scored.
    """
    uas = 0.0
    las = 0.0
    all_tokens = 0.0
    with tqdm(total=len(dataset)) as prog:
        for i, ex in enumerate(dataset):
            n_words = len(ex['word'])
            head = [-1] * n_words
            label = [-1] * n_words
            for h, t, lab in dependencies_total[i]:
                head[t] = h
                label[t] = lab
            for pred_h, pred_l, gold_h, gold_l, pos in zip(
                    head[1:], label[1:], ex['head'][1:], ex['label'][1:],
                    ex['pos'][1:]):
                # sanity check: gold POS ids must map to POS-prefixed tokens
                assert parser.id2tok[pos].startswith(P_PREFIX)
                if pred_h == gold_h:
                    uas += 1.0
                    if pred_l == gold_l:
                        las += 1.0
                all_tokens += 1
            # tqdm.update takes an increment, not the running total
            prog.update(1)

    if not all_tokens:
        return 0.0, 0.0
    return uas / all_tokens, las / all_tokens
Exemplo n.º 7
0
Arquivo: run.py Projeto: xuihau/G2GTr
def train(model, parser, train_batched, train_data, dev_batched,
          dev_data_total, output_path, device, max_seq_length, mean_seq_length,
          emb_size, opt, pad_action):
    """Train the transition parser with teacher forcing; evaluate each epoch.

    For every batch the gold transitions/labels (``actions_batch`` /
    ``batch_labels``) drive the stack/buffer updates. The transition loss and
    label loss are back-propagated at every step. After each epoch the model
    is scored on the dev set with ``validate``. A checkpoint is written to
    ``output_path`` when dev LAS improves, and always to
    ``output_path + "pretrained"``.

    Args:
        model: Parser model, called as ``model(0, ...)`` during training.
        parser: Vocabulary helper supplying the ROOT/NULL ids, ``n_transit``.
        train_batched: Indexable sequence of 8-tuples
            ``(buffer_ind, buffer, buffer_pos, mask_buffer, actions,
            actions_mask, labels, mask_labels)``.
        train_data: Unused; kept for call compatibility.
        dev_batched: Batched dev data forwarded to ``validate``.
        dev_data_total: Gold dev examples forwarded to ``validate``.
        output_path: Checkpoint path (``"pretrained"`` is appended for the
            per-epoch copy).
        device: Torch device that batch tensors are moved to.
        max_seq_length: Used for the total-steps estimate unless
            ``opt.mean_seq`` is set.
        mean_seq_length: Used for the total-steps estimate if ``opt.mean_seq``.
        emb_size: Unused; kept for call compatibility.
        opt: Options namespace (optimizer choice, epochs, flags, paths).
        pad_action: Mapping whose ``'P'`` entry is the pad-action factor.
    """
    pad_action = pad_action['P']

    best_dev_LAS = 0
    n_batchs = len(train_batched)

    # total optimisation steps, used by BertAdam's warmup schedule
    seq_length = mean_seq_length if opt.mean_seq else max_seq_length
    num_train_optimization_steps = opt.nepochs * seq_length * n_batchs

    print('number of steps')
    print(num_train_optimization_steps)

    optimizer, optimizer_nonbert = _build_train_optimizers(
        model, opt, num_train_optimization_steps)
    # Every optimizer that must be zeroed/stepped together. Building this list
    # once avoids referencing an undefined second optimizer when use_two_opts
    # is set without Bertoptim (previously a NameError).
    if optimizer_nonbert is None:
        optimizers = [optimizer]
    else:
        optimizers = [optimizer_nonbert, optimizer]

    # load optimizer state(s) from a checkpoint
    if opt.pretrained:
        state_dict = torch.load(opt.mainpath + '/output/' +
                                str(opt.modelpath) + "model.weights" +
                                "pretrained")
        optimizer.load_state_dict(state_dict['optimizer'])
        # checkpoints written before 'optimizer_nonbert' was saved lack it
        if optimizer_nonbert is not None and 'optimizer_nonbert' in state_dict:
            optimizer_nonbert.load_state_dict(state_dict['optimizer_nonbert'])
        del state_dict

    loss_func = nn.CrossEntropyLoss()
    loss_func_label = nn.CrossEntropyLoss()

    batch_del_label = None
    graph_emb = None

    # opt.real_epoch epochs are assumed done already (resumed run)
    n_epochs = opt.nepochs - opt.real_epoch

    for n_epoch in range(n_epochs):
        # report the absolute epoch number, which also covers resumed runs
        print("Epoch {:} out of {:}".format(opt.real_epoch + n_epoch + 1,
                                            opt.nepochs))
        model.train()
        if not opt.Bertoptim:
            adjust_learning_rate(opt.lr, opt.updatelr, optimizer, n_epoch)
        loss_meter = AverageMeter()
        iters = np.arange(n_batchs)
        if opt.shuffle:
            print('Do shuffling')
            random.shuffle(iters)

        with tqdm(total=n_batchs) as pbar:
            for it in iters:
                ############################ preparing the data ###########################
                (batch_buffer_ind, batch_buffer, batch_buffer_pos, mask_buffer,
                 actions_batch, actions_mask_batch, batch_labels,
                 mask_labels) = train_batched[it]

                actions_batch = actions_batch.to(device)
                actions_mask_batch = actions_mask_batch.to(device)
                batch_labels = batch_labels.to(device)
                mask_labels = mask_labels.to(device)

                batch_buffer = batch_buffer.to(device)
                batch_buffer_ind = batch_buffer_ind.to(device)
                batch_buffer_pos = batch_buffer_pos.to(device)
                mask_buffer = mask_buffer.to(device)

                batch_size = len(batch_buffer)
                # read before update_state starts rebinding batch_buffer
                seq_len = batch_buffer.size()[1]

                if opt.graphinput:
                    batch_delete = batch_buffer.clone()
                    batch_delete_pos = batch_buffer_pos.clone()
                    graph_emb = torch.zeros(
                        (batch_size, seq_len + 2, seq_len)).long().to(device)
                    graph_input = torch.zeros(
                        (batch_size, 3 * seq_len + 4,
                         3 * seq_len + 4)).long().to(device)

                # stack attention mask: only the last slot (ROOT) starts active
                mask_stack = torch.zeros(
                    (batch_size, seq_len + 1)).byte().to(device)
                mask_stack[:, -1] = 1

                # stack words / indices / POS, with ROOT in the last slot
                batch_stack = torch.zeros(
                    (batch_size, seq_len + 1)).long().to(device)
                batch_stack[:, -1] = parser.ROOT
                batch_stack_ind = (torch.ones((batch_size, seq_len + 1)) *
                                   (seq_len + 1)).long().to(device)
                batch_stack_ind[:, -1] = 0
                batch_stack_pos = torch.zeros(
                    (batch_size, seq_len + 1)).long().to(device)
                batch_stack_pos[:, -1] = parser.P_ROOT

                # dependency slots for stack and buffer, set to the NULL ids
                batch_dep = torch.ones_like(batch_stack) * parser.NULL
                batch_dep_pos = torch.ones_like(batch_stack_pos) * parser.P_NULL
                batch_dep_label = torch.ones_like(
                    batch_stack_pos) * parser.L_NULL
                batch_dep_buffer = torch.ones_like(
                    batch_stack[:, :-1]) * parser.NULL
                batch_dep_pos_buffer = torch.ones_like(
                    batch_stack_pos[:, :-1]) * parser.P_NULL
                batch_dep_label_buffer = torch.ones_like(
                    batch_stack_pos[:, :-1]) * parser.L_NULL

                # token_type_ids (segment ids of the model input)
                if opt.graphinput:
                    token_type_ids = torch.zeros(
                        (batch_size, 3 * seq_len + 4)).long().to(device)
                    token_type_ids[:, :seq_len + 2] = 2
                    token_type_ids[:, seq_len + 2:2 * seq_len + 3] = 1
                else:
                    token_type_ids = torch.zeros(
                        (batch_size, 2 * seq_len + 4)).long().to(device)
                    token_type_ids[:, :seq_len + 3] = 1

                sep_point = seq_len + 1 if opt.seppoint else 0
                mask_cls = torch.ones((batch_size, 1)).byte().to(device)
                mask_sep = torch.ones((batch_size, 1)).byte().to(device)

                # actions_batch was already moved with .to() above, so it is
                # never None here (the old opt.maxsteplength fallback was
                # unreachable).
                step_length = actions_batch.size()[1]

                step_i = 0
                # clip cls and sep
                ones = torch.ones((batch_size, 1)).long().to(device)
                batch_temp = batch_dep[:, :-1].clone()
                batch_pos_temp = batch_dep_pos[:, :-1].clone()
                batch_label_temp = batch_dep_label[:, :-1].clone()

                if opt.graphinput:
                    batch_del_label = batch_label_temp.clone()

                # per-sentence "finished" flags
                update = torch.zeros(batch_size).long().to(device)

                action_state = None
                action_cell = None

                # mode 1 marks the last gold step: its backward may free the graph
                mode = 0
                while True:
                    if step_i == step_length - 1:
                        mode = 1
                    if len(torch.nonzero(update)) == batch_size:
                        break

                    # teacher forcing: previous gold transition/label as input
                    if step_i == 0:
                        prev_transitions = None
                        prev_labels = None
                    else:
                        prev_transitions = actions_batch[:, step_i - 1]
                        prev_labels = batch_labels[:, step_i - 1]

                    transitions = actions_batch[:, step_i]
                    steps = batch_labels[:, step_i]

                    if opt.graphinput:
                        # build graph input matrix
                        mask_delete, graph_in = prepare_graph(
                            graph_emb, batch_stack_ind, batch_buffer_ind,
                            batch_size, graph_input, device, parser.NULL)
                        # prepare data for transformer
                        (input_ids_x, pos_ids_x, batch_dep_input,
                         batch_dep_pos_input, batch_dep_label_input,
                         batch_graph_label_input,
                         attention_mask) = prepare_data(
                             ones, opt.graphinput, parser, batch_stack,
                             batch_buffer, batch_stack_pos, batch_buffer_pos,
                             batch_dep, batch_temp, batch_dep_pos,
                             batch_pos_temp, batch_dep_label,
                             batch_label_temp, mask_cls, mask_stack, mask_sep,
                             mask_buffer, batch_dep_buffer,
                             batch_dep_pos_buffer, batch_dep_label_buffer,
                             mask_delete, batch_del_label, batch_delete,
                             batch_delete_pos)

                        output_batch, output_label, action_state, action_cell = model(
                            0, sep_point, input_ids_x, pos_ids_x,
                            batch_dep_input, batch_dep_pos_input,
                            batch_dep_label_input, batch_graph_label_input,
                            attention_mask, update, token_type_ids,
                            mask_stack, mask_buffer, batch_stack_ind,
                            prev_transitions, prev_labels, action_state,
                            action_cell, graph_in)
                    else:
                        (input_ids_x, pos_ids_x, batch_dep_input,
                         batch_dep_pos_input, batch_dep_label_input,
                         attention_mask) = prepare_data(
                             ones, opt.graphinput, parser, batch_stack,
                             batch_buffer, batch_stack_pos, batch_buffer_pos,
                             batch_dep, batch_temp, batch_dep_pos,
                             batch_pos_temp, batch_dep_label,
                             batch_label_temp, mask_cls, mask_stack, mask_sep,
                             mask_buffer, batch_dep_buffer,
                             batch_dep_pos_buffer, batch_dep_label_buffer)

                        output_batch, output_label, action_state, action_cell = model(
                            0, sep_point, input_ids_x, pos_ids_x,
                            batch_dep_input, batch_dep_pos_input,
                            batch_dep_label_input, None, attention_mask,
                            update, token_type_ids, mask_stack, mask_buffer,
                            batch_stack_ind, prev_transitions, prev_labels,
                            action_state, action_cell)

                    ################ mask the output batch ################
                    action_mask_batch = actions_mask_batch[:, step_i]
                    action_batch = actions_batch[:, step_i].masked_select(
                        action_mask_batch)
                    output_batch = output_batch.masked_select(
                        action_mask_batch.unsqueeze(0).t()).view(-1, opt.nclass)

                    ############# mask the output label batch #############
                    mask_label = mask_labels[:, step_i]
                    batch_label = batch_labels[:, step_i].masked_select(
                        mask_label)
                    output_label = output_label.masked_select(
                        mask_label.unsqueeze(0).t()).view(
                            -1, parser.n_transit - 1)

                    # the label loss only applies when some label is unmasked
                    loss = loss_func(output_batch, action_batch)
                    if output_label.nelement():
                        loss = loss + loss_func_label(output_label, batch_label)

                    ## back propagation
                    for opt_i in optimizers:
                        opt_i.zero_grad()
                    # Keep the graph until the final gold step. Presumably
                    # needed because action_state/action_cell link each step's
                    # graph to the next.
                    if mode:
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)
                    for opt_i in optimizers:
                        opt_i.step()

                    loss_meter.update(loss.item())

                    del output_batch, output_label, input_ids_x, pos_ids_x, \
                        batch_dep_input, batch_dep_pos_input, \
                        batch_dep_label_input, attention_mask
                    if opt.graphinput:
                        del mask_delete, graph_in

                    ################ update stack and buffer ################
                    (batch_stack_ind, batch_stack, batch_buffer_ind,
                     batch_buffer, batch_stack_pos, batch_buffer_pos,
                     batch_dep, batch_dep_pos, mask_buffer, mask_stack,
                     batch_dep_label, graph_emb, batch_del_label,
                     batch_dep_buffer, batch_dep_pos_buffer,
                     batch_dep_label_buffer) = update_state(
                         0, opt, batch_stack_ind, batch_stack,
                         batch_buffer_ind, batch_buffer, batch_stack_pos,
                         batch_buffer_pos, batch_dep, batch_dep_pos,
                         mask_buffer, mask_stack, transitions,
                         batch_dep_label, steps, parser.NULL, parser.P_NULL,
                         parser.L_NULL, batch_dep_buffer,
                         batch_dep_pos_buffer, batch_dep_label_buffer,
                         graph_emb, batch_del_label)
                    # finished: empty buffer and only ROOT left on the stack
                    update = ((mask_buffer.sum(dim=1) == 0) *
                              (mask_stack.sum(dim=1) == 1) * pad_action).long()
                    step_i += 1

                pbar.update(1)
                del actions_batch, batch_labels, mask_labels, batch_buffer, \
                    batch_buffer_pos, mask_buffer, actions_mask_batch, \
                    batch_stack, batch_stack_pos, token_type_ids, batch_dep, \
                    batch_dep_pos, batch_dep_label, ones, batch_dep_buffer, \
                    batch_dep_pos_buffer, batch_dep_label_buffer
                if opt.graphinput:
                    del batch_delete, batch_delete_pos, graph_emb, \
                        batch_del_label, graph_input

        print("Average Train Loss: {}".format(loss_meter.avg))
        print("")

        print("Evaluating on dev set", )

        model.eval()

        dev_UAS, dev_LAS = validate(model, parser, dev_batched, dev_data_total,
                                    device, opt.batchsize, pad_action, opt)
        print("- dev UAS: {:.2f}".format(dev_UAS * 100.0))
        print("- dev LAS: {:.2f}".format(dev_LAS * 100.0))

        checkpoint = _train_checkpoint(model, opt, optimizer, optimizer_nonbert)
        # model selection is on LAS, so the message says LAS
        if dev_LAS > best_dev_LAS:
            best_dev_LAS = dev_LAS
            print("New best dev LAS! Saving model.")
            torch.save(checkpoint, output_path)

        torch.save(checkpoint, output_path + "pretrained")


def _build_train_optimizers(model, opt, num_train_optimization_steps):
    """Create the optimizer(s) selected by ``opt``.

    Returns:
        ``(optimizer, optimizer_nonbert)``. ``optimizer_nonbert`` is ``None``
        unless both ``opt.Bertoptim`` and ``opt.use_two_opts`` are set. In that
        case ``optimizer`` (lr ``opt.lr``) owns the ``bertmodel`` parameters
        and ``optimizer_nonbert`` (lr ``opt.lr_nonbert``) everything else.
    """
    if not opt.Bertoptim:
        return optim.Adam(model.parameters(), lr=opt.lr), None

    if not opt.use_two_opts:
        return BertAdam(model.parameters(),
                        lr=opt.lr,
                        warmup=opt.warmupproportion,
                        t_total=num_train_optimization_steps), None

    print("use two optimizers")
    # 'bertmodel' parameters whose names contain any of these substrings are
    # handed to the non-BERT optimizer instead
    layernorm_params = [
        'LayerNormKeys', 'dp_relation_k', 'dp_relation_v', 'compose',
        'label_emb', 'pooler.dense', 'pooler.dense_label'
    ]
    model_bert = []
    model_nonbert = []
    for name, param in model.named_parameters():
        if 'bertmodel' in name and not any(nd in name
                                           for nd in layernorm_params):
            model_bert.append(param)
        else:
            model_nonbert.append(param)
    optimizer = BertAdam(model_bert,
                         lr=opt.lr,
                         warmup=opt.warmupproportion,
                         t_total=num_train_optimization_steps)
    optimizer_nonbert = BertAdam(model_nonbert,
                                 lr=opt.lr_nonbert,
                                 warmup=opt.warmupproportion,
                                 t_total=num_train_optimization_steps)
    return optimizer, optimizer_nonbert


def _train_checkpoint(model, opt, optimizer, optimizer_nonbert):
    """Build the checkpoint dict written by ``train``.

    Keys ``'model'``, ``'opt'`` and ``'optimizer'`` are always present. When a
    second optimizer is in use, its state is stored under
    ``'optimizer_nonbert'`` so a resumed run can restore it.
    """
    checkpoint = {
        'model': model.state_dict(),
        'opt': opt,
        'optimizer': optimizer.state_dict(),
    }
    if optimizer_nonbert is not None:
        checkpoint['optimizer_nonbert'] = optimizer_nonbert.state_dict()
    return checkpoint
Exemplo n.º 8
0
def test_normal_mu_shift():
    """Run each MOA univariate drift detector over the normal-shift batches.

    For every detector name in ``univariate_detector_types`` the batches from
    ``get_batches_normal()`` are fed one at a time through ``update_state``.
    The indices of batches for which the returned state reports
    ``warning_detected`` / ``change_detected`` are collected and printed.

    A failure in one detector is reported (traceback plus a
    ``>>> univariate_cd fails`` line) and the loop moves on to the next
    detector. The py4j gateway held in ``state[0].gg`` is shut down in a
    ``finally`` clause. It is therefore released even when a detector raises
    part-way through the stream, not only on the success path.
    """
    batches = get_batches_normal()

    # drift_detector_type = "MOA_KL"
    drift_detector_type = "MOA_SUBSAMPLE_SIMPLE_MAJORITY"
    univariate_detector_types = \
        ["CUSUM", "ADWIN", "GEOMETRIC_MOVING_AVERAGE",
         "DDM", "EDDM", "EWMA", "PAGE_HINKLEY", "HDDM_A",
         "HDDM_W", "SEED", "SEQ1"]

    # univariate_detector_types = ["HDDM_A"]

    # TODO EnsembleDriftDetectionMethods
    # Per-detector option overrides, mirroring the MOA Java option
    # declarations reproduced in the reference string below.
    univariate_detector_options = {
        "CUSUM": {
            "minNumInstancesOption": {"type": "IntOption", "value": 300, "min": 0, "max": sys.maxsize},
            "deltaOption": {"type": "FloatOption", "value": 0.005, "min": 0.0, "max": 1.0},
            "lambdaOption": {"type": "FloatOption", "value": 50.0, "min": 0.0, "max": sys.float_info.max},
        },
        "ADWIN": {
            "deltaAdwinOption": {"type": "FloatOption", "value": 0.002, "min": 0.0, "max": 1.0},
        },
        "DDM": {
            "minNumInstancesOption": {"type": "IntOption", "value": 30, "min": 0, "max": sys.maxsize},
            "warningLevelOption": {"type": "FloatOption", "value": 2.0, "min": 1.0, "max": 4.0},
            "outcontrolLevelOption": {"type": "FloatOption", "value": 3.0, "min": 1.0, "max": 5.0},
        },
        "EWMA": {
            "minNumInstancesOption": {"type": "IntOption", "value": 30, "min": 0, "max": sys.maxsize},
            "lambdaOption": {"type": "FloatOption", "value": 0.2, "min": 0.0, "max": sys.float_info.max},
        },
        "GEOMETRIC_MOVING_AVERAGE": {
            "minNumInstancesOption": {"type": "IntOption", "value": 30, "min": 0, "max": sys.maxsize},
            "lambdaOption": {"type": "FloatOption", "value": 1.0, "min": 0.0, "max": sys.float_info.max},
            "alphaOption": {"type": "FloatOption", "value": 0.99, "min": 0.0, "max": 1.0},
        },
        "EDDM": {},
        "PAGE_HINKLEY": {
            "minNumInstancesOption": {"type": "IntOption", "value": 30, "min": 0, "max": sys.maxsize},
            "deltaOption": {"type": "FloatOption", "value": 0.005, "min": 0.0, "max": 1.0},
            "lambdaOption": {"type": "FloatOption", "value": 50.0, "min": 0.0, "max": sys.float_info.max},
            "alphaOption": {"type": "FloatOption", "value": 1 - 0.0001, "min": 0.0, "max": 1.0},
        },
        "HDDM_A": {
            "driftConfidenceOption": {"type": "FloatOption", "value": 0.001, "min": 0.0, "max": 1.0},
            "warningConfidenceOption": {"type": "FloatOption", "value": 0.005, "min": 0.0, "max": 1.0},
            "oneSidedTestOption": {"type": "MultiChoiceOption", "value": 1, "options": [0, 1]},
        },
        "HDDM_W": {
            "driftConfidenceOption": {"type": "FloatOption", "value": 0.001, "min": 0.0, "max": 1.0},
            "warningConfidenceOption": {"type": "FloatOption", "value": 0.005, "min": 0.0, "max": 1.0},
            "lambdaOption": {"type": "FloatOption", "value": 0.05, "min": 0.0, "max": 1.0},
            "oneSidedTestOption": {"type": "MultiChoiceOption", "value": 0, "options": [0, 1]},
        },
        "SEED": {
            "deltaSEEDOption": {"type": "FloatOption", "value": 0.05, "min": 0.0, "max": 1.0},
            "blockSizeSEEDOption": {"type": "IntOption", "value": 32, "min": 32, "max": 256},
            "epsilonPrimeSEEDOption": {"type": "FloatOption", "value": 0.01, "min": 0.0025, "max": 0.01},
            "alphaSEEDOption": {"type": "FloatOption", "value": 0.8, "min": 0.2, "max": 0.8},
            "compressTermSEEDOption": {"type": "IntOption", "value": 75, "min": 50, "max": 100},
        },
        "SEQ1": {
            "deltaOption": {"type": "FloatOption", "value": 0.01, "min": 0.0, "max": 1.0},
            "deltaWarningOption": {"type": "FloatOption", "value": 0.1, "min": 0.0, "max": 1.0},
            "blockSeqDriftOption": {"type": "IntOption", "value": 200, "min": 100, "max": 10000},
        },

    }
    # Reference only (inert string, not used at runtime): the MOA Java option
    # declarations the dictionary above was transcribed from.
    """
    IntOption
    FloatOption
    ListOption
    MultiChoiceOption

    CusumDM: minNumInstancesOption = new IntOption("minNumInstances", 'n',  "The minimum number of instances before permitting detecting change.", 30, 0, Integer.MAX_VALUE);
             deltaOption = new FloatOption("delta", 'd', "Delta parameter of the Cusum Test", 0.005, 0.0, 1.0);
             lambdaOption = new FloatOption("lambda", 'l', "Threshold parameter of the Cusum Test", 50, 0.0, Float.MAX_VALUE);

    ADWINChangeDetector: deltaAdwinOption = new FloatOption("deltaAdwin", 'a',  "Delta of Adwin change detection", 0.002, 0.0, 1.0);

    DDM: minNumInstancesOption = new IntOption("minNumInstances", 'n', "The minimum number of instances before permitting detecting change.", 30, 0, Integer.MAX_VALUE);
         warningLevelOption = new FloatOption( "warningLevel", 'w', "Warning Level.", 2.0, 1.0, 4.0);
         outcontrolLevelOption = new FloatOption("outcontrolLevel", 'o', "Outcontrol Level.", 3.0, 1.0, 5.0);

    EnsembleDriftDetectionMethods: changeDetectorsOption = new ListOption("changeDetectors", 'c', "Change Detectors to use.", new ClassOption("driftDetectionMethod", 'd',  "Drift detection method to use.", ChangeDetector.class, "DDM"),new Option[0], ',');
                                   predictionOption = new MultiChoiceOption("prediction", 'l', "Prediction to use.", new String[]{"max", "min", "majority"}, new String[]{"Maximum", "Minimum", "Majority"}, 0);

    EWMAChartDM: minNumInstancesOption = new IntOption("minNumInstances", 'n',  "The minimum number of instances before permitting detecting change.", 30, 0, Integer.MAX_VALUE);
                 lambdaOption = new FloatOption("lambda", 'l', "Lambda parameter of the EWMA Chart Method", 0.2, 0.0, Float.MAX_VALUE);


    GeometricMovingAverageDM: minNumInstancesOption = new IntOption("minNumInstances", 'n',  "The minimum number of instances before permitting detecting change.", 30, 0, Integer.MAX_VALUE);
                              lambdaOption = new FloatOption("lambda", 'l', "Threshold parameter of the Geometric Moving Average Test", 1, 0.0, Float.MAX_VALUE);
                              alphaOption = new FloatOption("alpha", 'a', "Alpha parameter of the Geometric Moving Average Test", .99, 0.0, 1.0);

    HDDM_A_Test: driftConfidenceOption = new FloatOption("driftConfidence", 'd', "Confidence to the drift", 0.001, 0, 1);
                 warningConfidenceOption = new FloatOption("warningConfidence", 'w', "Confidence to the warning", 0.005, 0, 1);
                 oneSidedTestOption = new MultiChoiceOption("typeOfTest", 't',  "Monitors error increments and decrements (two-sided) or only increments (one-sided)", new String[]{ "One-sided", "Two-sided"}, new String[]{"One-sided", "Two-sided"}, 1);

    HDDM_W_Test: driftConfidenceOption = new FloatOption("driftConfidence", 'd', "Confidence to the drift", 0.001, 0, 1);
                 warningConfidenceOption = new FloatOption("warningConfidence", 'w', "Confidence to the warning", 0.005, 0, 1);
                 lambdaOption = new FloatOption("lambda", 'm', "Controls how much weight is given to more recent data compared to older data. Smaller values mean less weight given to recent data.", 0.050, 0, 1);
                 oneSidedTestOption = new MultiChoiceOption("typeOfTest", 't', "Monitors error increments and decrements (two-sided) or only increments (one-sided)", new String[]{"One-sided", "Two-sided"}, new String[]{"One-sided", "Two-sided"},0);

    PageHinkleyDM: minNumInstancesOption = new IntOption("minNumInstances", 'n', "The minimum number of instances before permitting detecting change.", 30, 0, Integer.MAX_VALUE);
                   deltaOption = new FloatOption("delta", 'd', "Delta parameter of the Page Hinkley Test", 0.005, 0.0, 1.0);
                   lambdaOption = new FloatOption("lambda", 'l', "Lambda parameter of the Page Hinkley Test", 50, 0.0, Float.MAX_VALUE);
                   alphaOption = new FloatOption("alpha", 'a', "Alpha parameter of the Page Hinkley Test", 1 - 0.0001, 0.0, 1.0);

    RDDM: minNumInstancesOption = new IntOption("minNumInstances", 'n', "Minimum number of instances before monitoring changes.", 129, 0, Integer.MAX_VALUE);
          warningLevelOption = new FloatOption("warningLevel", 'w', "Warning Level.", 1.773, 1.0, 4.0);
          driftLevelOption = new FloatOption("driftLevel", 'o', "Drift Level.", 2.258, 1.0, 5.0);
          maxSizeConceptOption = new IntOption("maxSizeConcept", 'x', "Maximum Size of Concept.", 40000, 1, Integer.MAX_VALUE);
          minSizeStableConceptOption = new IntOption("minSizeStableConcept", 'y', "Minimum Size of Stable Concept.", 7000, 1, 20000);
          warnLimitOption = new IntOption("warnLimit", 'z', "Warning Limit of instances", 1400, 1, 20000);

    SEEDChangeDetector: deltaSEEDOption = new FloatOption("deltaSEED", 'd', "Delta value of SEED Detector", 0.05, 0.0, 1.0);
                        blockSizeSEEDOption = new IntOption("blockSizeSEED", 'b', "BlockSize value of SEED Detector", 32, 32, 256);
                        epsilonPrimeSEEDOption = new FloatOption("epsilonPrimeSEED", 'e', "EpsilonPrime value of SEED Detector", 0.01, 0.0025, 0.01);
                        alphaSEEDOption = new FloatOption("alphaSEED", 'a', "Alpha value of SEED Detector", 0.8, 0.2, 0.8);
                        compressTermSEEDOption = new IntOption("compressTermSEED", 'c', "CompressTerm value of SEED Detector", 75, 50, 100);

    SeqDrift1ChangeDetector: deltaOption = new FloatOption("deltaSeqDrift1", 'd', "Delta of SeqDrift1 change detection",0.01, 0.0, 1.0);
                             deltaWarningOption = new FloatOption("deltaWarningOption", 'w', "Delta of SeqDrift1 change detector to declare warning state",0.1, 0.0, 1.0);
                             blockSeqDriftOption = new IntOption("blockSeqDrift1Option",'b',"Block size of SeqDrift1 change detector", 200, 100, 10000);


    SeqDrift2ChangeDetector: deltaSeqDrift2Option = new FloatOption("deltaSeq2Drift", 'd', "Delta of SeqDrift2 change detection",0.01, 0.0, 1.0);
                             blockSeqDrift2Option = new IntOption("blockSeqDrift2Option",'b',"Block size of SeqDrift2 change detector", 200, 100, 10000);


    EDDM: N/A
    """

    # drift_detector_type = "SKMULTIFLOW_ADWIN"
    job_conf = {
        'moa_jar_path': moa_jar_path,
        'drift_detector_type': drift_detector_type,
        # Set per detector inside the loop below, before any use.
        'univariate_detector_type': None,
        'threshold': 0.01,
        'py4j_jar_path': py4j_jar_path,
    }
    logger.warning('Starting test')
    dump = False
    for univariate_cd in univariate_detector_types:
        job_conf['univariate_detector_type'] = univariate_cd
        job_conf['univariate_detector_options'] = univariate_detector_options[univariate_cd]
        print(">>> univariate_cd: %s" % univariate_cd)
        # Bound before the try so the finally clause can see it even when the
        # first update_state call raises.
        state = None
        try:
            detection_steps = []
            warning_steps = []
            for i, b in enumerate(batches):
                new_values = [value(x) for x in b]
                state = update_state(new_values, state, job_conf, dump)

                if state[1]["warning_detected"]:
                    # logger.warning('Warning zone has been detected at index: %s, data: %s' % (i, str(b)))
                    warning_steps.append(i)
                if state[1]["change_detected"]:
                    # logger.warning('Change has been detected at index: %s, data: %s' % (i, str(b)))
                    detection_steps.append(i)
            print('warning_steps (%s): %s' % (len(warning_steps), warning_steps))
            print('detection_steps (%s): %s' % (len(detection_steps), detection_steps))
        except Exception as e:
            # Best-effort sweep: report this detector's failure and continue
            # with the next one.
            traceback.print_exc()
            print(">>> univariate_cd fails: %s, error: %s" % (univariate_cd, e))
        finally:
            # Release the gateway whether or not the detector failed; before,
            # a mid-stream exception skipped shutdown and leaked it.
            if not dump and state is not None:
                try:
                    state[0].gg.shutdown()
                except Exception as e:
                    traceback.print_exc()
                    print(">>> univariate_cd fails: %s, error: %s" % (univariate_cd, e))
Exemplo n.º 9
0
    }
    # In Memory
    # context = StreamingContext(spark.sparkContext, batch_duration)
    # lines = [[1, 2, 3], [4, 5, 6]]
    # stream = context.queueStream(lines)

    # Google storage
    # context = StreamingContext.getOrCreate(checkpointDirectory, create_context)
    # stream = context.textFileStream("gs://test1-sha456-logs-sink/data/")
    # spark._jsc.hadoopConfiguration().set('fs.gs.impl', 'com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem')
    # spark._jsc.hadoopConfiguration().set('fs.AbstractFileSystem.gs.impl',
    #                                      'com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS')
    # spark._jsc.hadoopConfiguration().set('fs.gs.project.id', project_id)
    # spark._jsc.hadoopConfiguration().set('google.cloud.auth.service.account.enable', 'true')
    # spark._jsc.hadoopConfiguration().set('google.cloud.auth.service.account.json.keyfile', cred_location)

    # Pubsub
    context = StreamingContext.getOrCreate(
        checkpoint_directory,
        lambda: create_context(spark, checkpoint_directory, batch_duration))
    stream = pubsub.PubsubUtils.createStream(context, subscription_name,
                                             batch_size, True)

    stream.flatMap(parse_request) \
        .updateStateByKey(lambda new_values, state: update_state(new_values, state, job_conf)) \
        .map(lambda state: publish_state_metric(state, pushgateway_url, myelin_ns, port, input_drift_probability_metric)) \
        .foreachRDD(lambda rdd: write_state_to_bq(rdd, state_table))

    context.start()
    context.awaitTermination()
Exemplo n.º 10
0
def solve_lvl(environment, state, goals, current_plan, depth, max_depth,
              analysis):
    """Recursively build a plan that achieves every goal in ``goals``.

    Each goal is handled in turn:
    - If ``action_satisfied`` already holds for the goal, it is skipped.
    - Otherwise the candidate actions from ``get_possible_actions`` are tried,
      sorted ascending by ``initial_state_count(state, action.preconditions)``.
    - For a candidate that passes ``preconditions_reachable`` and is not
      ``is_dangerous``, the planner recurses on the action's preconditions and
      then applies the action's effects.

    Args:
        environment: passed through to ``get_possible_actions`` and
            ``preconditions_reachable``.
        state: list of facts. Mutated in place: after each goal is achieved it
            is replaced by the resulting state. This happens even if a later
            goal then fails and the call returns ``None``.
        goals: list of goals. May be reordered in place when an action undoes
            goals achieved earlier in this call.
        current_plan: list of actions on the current search path, used for
            the analysis trace and indentation. An action is pushed before
            descending and popped only if the descent fails.
        depth: current recursion depth.
        max_depth: calls with ``depth > max_depth`` and pending goals return
            ``None``.
        analysis: when true, print a step-by-step trace and wait for Enter at
            each step.

    Returns:
        A list of actions (each action preceded by the plan for its
        preconditions), or ``None`` if no plan was found.
    """
    # Use the line-input builtin available on the running interpreter so the
    # interactive trace works on both Python 2 and Python 3.
    try:
        prompt = raw_input  # noqa: F821 -- Python 2
    except NameError:
        prompt = input

    padding = "**" * len(current_plan) + " "
    plan = []

    if analysis:
        print("Current Plan: {0}".format("\n".join(
            [x.simple_str() for x in current_plan])))
        print("\n###############################\n")

    if len(goals) == 0:
        return plan

    if depth > max_depth:
        return None

    i = 0
    while i < len(goals):
        goal = goals[i]

        if analysis:
            print(padding + "Current Plan: {0}".format(" -> ".join(
                [x.simple_str() for x in current_plan])))
            print(padding + "Subgoal: {0}".format(goal))
            print(padding + "Other Goals: {0}".format(", ".join(
                [str(x) for x in goals[i + 1:]])))
            print(padding + "State: {0}".format(", ".join(
                [str(s) for s in state])))
            prompt("")

        if action_satisfied(state, goal):
            if analysis:
                prompt(padding + "Goal is already reached!")
                print("")
            i += 1
            continue

        # Find a subgoal that helps reach this goal: collect every action that
        # can achieve it, in the order described in the docstring.
        possible_actions = sorted(
            get_possible_actions(environment, goal),
            key=lambda c: initial_state_count(state, c.preconditions))

        if analysis:
            print(padding + "set of possible actions which reach goal: {0}:".format(
                goal))
            print("\n".join(
                [padding + x.simple_str() for x in possible_actions]))
            prompt("")

        action_found = False

        for action in possible_actions:

            if analysis:
                print(padding + "Trying next action to reach goal:  {0}:".format(
                    goal))
                print(padding + str(action).replace("\n", "\n" + padding))
                prompt("")

            # Skip the action unless every precondition is achievable by at
            # least one action.
            if not preconditions_reachable(environment, action):
                if analysis:
                    print(padding + "Some preconditions not reachable by any possible action. Have to skip")
                    prompt("")
                continue

            # Skip the action if it threatens another goal.
            if is_dangerous(goals, action):
                if analysis:
                    print(padding + "Action threaten another goal state. Have to skip")
                    prompt("")
                continue

            # The action cannot be rejected outright as unreachable: descend.
            if analysis:
                print(padding + "Action cannot be rejected as unreachable. Have to descend...")
                prompt("")

            # Plan the preconditions against a copy of the state; the recursive
            # call writes its resulting state back into this copy.
            temporary_state = list(state)

            subgoals = list(action.preconditions)

            current_plan.append(action)

            solution = solve_lvl(environment, temporary_state, subgoals,
                                 current_plan, depth + 1, max_depth, analysis)

            # No plan for the preconditions: undo the push and try the next action.
            if solution is None:
                if analysis:
                    print(padding + "No solution found with this action. Have to skip...")
                current_plan.pop()
                continue

            if analysis:
                print(padding + "Possible solution found!")
                prompt("")

            # Apply the effects introduced by the new action.
            for effect in action.effects:
                update_state(temporary_state, effect)

            # Goals achieved earlier in this call that the new state no
            # longer satisfies.
            deleted = [
                x for x in goals[0:i]
                if x != goal and not action_satisfied(temporary_state, x)
            ]
            if deleted:

                if analysis:
                    print(padding + "reach {0} but delete other goals: {1}".format(
                        goal, ", ".join([str(x) for x in deleted])))
                    print(padding + "Re-adding the deleted goals to the end of the list")
                    prompt("")
                # Move the undone goals to the end so they are re-achieved
                # later. All of them sat before index i, so shift i back by the
                # same amount to keep it on the current goal.
                for x in deleted:
                    goals.remove(x)
                goals.extend(deleted)
                i -= len(deleted)

                if analysis:
                    print(padding + "New goals: {0}".format(", ".join(
                        [str(x) for x in goals])))
                    prompt("")

            # Add the precondition sub-plan to the plan.
            plan.extend(solution)

            # Accept the temporary state (in place, so the caller sees it).
            del state[:]
            state.extend(temporary_state)

            # Then add the action itself.
            plan.append(action)

            if analysis:
                print(padding + "New State: " + ", ".join(
                    [str(x) for x in state]))
                prompt("")

            i += 1
            action_found = True
            break

        if not action_found:
            if analysis:
                print("")
                prompt("**" + padding +
                       "No actions found to reach this subgoal. Go back...")
                print("")

            return None

    return plan
Exemplo n.º 11
0
def respond_message(message):

    if message.text.startswith('/start'):
        register_user(message)

    message.text = emoji.demojize(message.text)
    t = Tracker.get_or_none(Tracker.id == message.chat.id)
    user = User.get_or_none(id=message.chat.id)

    if not t or not user:
        send_message(
            message.chat.id,
            ":cross_mark: Not a registered user. Please click on /start.")
        return

    # update the username with every message
    # this is important as it is the only way to find out the user identity
    User.update(username=message.chat.username, ).where(
        User.id == message.chat.id).execute()

    # ------------------------------------
    # HOST a game
    # ------------------------------------
    if t.state == 'start' and message.text == ":desktop_computer: Host a Game":
        host_start(message, user)

    # end game
    elif t.state == 'host_game' and message.text == ":cross_mark: Leave":
        host_leave(message, user)

    # select from a list of roles
    elif t.state == 'host_game' and message.text == ":right_arrow: Next":
        # FIXME: poll is deactivated. inline keyboard is used now.
        # host_select_roles_with_poll(message, user)
        host_select_roles(message, user)

    # select from a list of roles
    elif t.state == 'host_game' and message.text == ":envelope: Send Roles":
        host_send_roles(message, user)

    # ------------------------------------
    # New Game
    # ------------------------------------
    join_code_pattern = r"^/start (?P<code>\w{4})$"
    match = re.match(join_code_pattern, message.text)
    if match:
        code = match.group("code")
        message.text = code

    if t.state == 'start' and message.text == ":game_die: Join a Game":
        player_leave(message, user)

    elif t.state == 'join_game':
        player_start(message, user)

    elif t.state == 'start' and match:
        player_start(message, user)

    # ------------------------------------
    # Change Name
    # ------------------------------------
    if t.state == 'start' and message.text == ":bust_in_silhouette: Change Name":
        update_state(user, 'change_name')
        text = f":bust_in_silhouette: Current name: <b>{f2p(user.name)}</b>\n\n"
        text += ":input_latin_letters: Enter your new name:"
        send_message(user.id,
                     text,
                     reply_markup=create_keyboard([":cross_mark: Discard"]))

    elif t.state == 'change_name' and message.text == ":cross_mark: Discard":
        update_state(user, 'start')
        send_message(user.id,
                     ":cross_mark: Discard",
                     reply_markup=keyboards.main)

    elif t.state == 'change_name':
        if len(message.text) > 100:
            send_message(user.id,
                         "Name length must be less than 100 characters.")
            return

        User.update(name=message.text).where(User.id == user.id).execute()
        update_state(user, 'start')
        send_message(
            user.id,
            f":white_heavy_check_mark: Your Name is updated now to: <b>{message.text}</b>",
            reply_markup=keyboards.main)

    # ------------------------------------
    # Settings
    # ------------------------------------
    if t.state == 'start' and message.text == ":gear_selector: Settings":
        edit_game_settings(message, user)