Exemplo n.º 1
0
def model_validate(val_data_normalized, train_window, train_history,
                   train_forward, model, cuda, feature, label):
    """Evaluate ``model`` on the validation data using multi-step roll-outs.

    Validation samples are visited one at a time, in order.  They are
    processed in consecutive groups of 12: the first sample of a group seeds
    the input history with its own feature tensor, and after every step the
    model's mean prediction (concatenated with that step's exogenous tensor)
    is appended to the history used for the following step of the group.

    Args:
        val_data_normalized: Validation data, forwarded to
            ``nn_create_val_inout_sequences``.
        train_window: Window length; inputs are padded to ``train_window - 1``.
        train_history: Forwarded to ``nn_create_val_inout_sequences``.
        train_forward: Forwarded to ``nn_create_val_inout_sequences``.
        model: Model exposing ``eval()``, ``init_hidden(batch, cuda)`` and a
            call returning ``(mean_prediction, std_prediction)``.
        cuda: Device flag forwarded to the helpers and to ``init_hidden``.
        feature: Feature selector forwarded to the sequence builder.
        label: Label selector forwarded to the sequence builder.

    Returns:
        A tuple ``(val_single_loss, val_pred_all)`` where ``val_single_loss``
        is the sum of the per-sample MSE losses and ``val_pred_all`` is a list
        with one entry per group, each entry being a list of
        ``(mean, std, label)`` NumPy-array triples.
    """
    # Number of consecutive samples forming one roll-out group.
    # NOTE(review): presumably the forecast horizon (e.g. 12 months) -- confirm
    # whether this should instead be derived from ``train_forward``.
    rollout_steps = 12

    model.eval()
    val_feature, val_label, val_exog, val_len = nn_create_val_inout_sequences(
        val_data_normalized, train_window, train_history, train_forward,
        feature, label, cuda)
    val_set = DataHolder(val_feature, val_label, val_exog, 1, val_len, cuda)
    val_batch = DataLoader(val_set,
                           batch_size=1,
                           shuffle=False,
                           drop_last=True)
    max_tensor_len = train_window - 1
    val_single_loss = 0

    val_pred_all = []
    val_pred_batch = []
    loss_function = nn.MSELoss()

    for idx, batch in enumerate(tqdm(val_batch)):
        feature_tensor, label_tensor, exog_tensor, len_tensor = batch
        model.init_hidden(1, cuda)

        if idx % rollout_steps == 0:
            # A new group starts: store the finished one and restart the
            # history from this sample's own features.
            if idx != 0:
                val_pred_all.append(val_pred_batch)
            val_pred_batch = []
            temp_tensor = feature_tensor

        temp_padded_tensor = pad_tensor(temp_tensor,
                                        len_tensor.squeeze().item(),
                                        max_tensor_len, cuda)

        with torch.no_grad():
            val_mean_pred, val_std_pred = model(temp_padded_tensor,
                                                exog_tensor, len_tensor, 1)
            val_pred_batch.append(
                (val_mean_pred.clone().detach().cpu().numpy(),
                 val_std_pred.clone().detach().cpu().numpy(),
                 label_tensor.clone().detach().cpu().numpy()))

            single_loss = loss_function(val_mean_pred.squeeze(),
                                        label_tensor.squeeze())
            val_single_loss += single_loss.item()

            # Feed the prediction back in: it becomes one more time step
            # (dim=1) of the history used by the next sample in this group.
            temp_feature = torch.cat((val_mean_pred.unsqueeze(0), exog_tensor),
                                     dim=2)
            temp_tensor = torch.cat((temp_tensor, temp_feature), dim=1)

    # Groups are only stored when the next one begins, so the last group
    # (complete or partial) must be stored explicitly or it is lost.
    if val_pred_batch:
        val_pred_all.append(val_pred_batch)

    return val_single_loss, val_pred_all
Exemplo n.º 2
0
 def user_input(self, tag, attachment=None, timeout=15, g=None, st=None):
     """Obtain this peer player's input for ``tag`` from the server.

     A request record (timeout, player, tag, and an initially empty input) is
     announced through the ``user_input_start`` event, filled with whatever
     the server delivers for ``'input_<tag>_<st>'``, announced again through
     ``user_input_finish`` and its ``input`` value returned.

     ``g`` falls back to ``Game.getgame()`` and ``st`` to ``g.get_synctag()``
     whenever a falsy value is supplied.  ``attachment`` is accepted but not
     used by this variant.
     """
     game = g or Game.getgame()
     sync_tag = st or game.get_synctag()

     request = DataHolder()
     request.timeout = timeout
     request.player = self
     request.input = None
     request.tag = tag

     game.emit_event('user_input_start', request)
     expect = Executive.server.gexpect
     request.input = expect('input_%s_%d' % (tag, sync_tag))
     game.emit_event('user_input_finish', request)

     return request.input
Exemplo n.º 3
0
    def user_input(self, tag, attachment=None, timeout=15, g=None, st=None):
        """Collect input from this player and exchange it with the server.

        A request record is built and passed to the ``user_input_start`` and
        ``user_input`` events; the ``input`` of the ``user_input`` event's
        result is written to the server under ``'input_<tag>_<st>'``.  In
        parallel a waiter greenlet expects a value under the same key, and
        that received value is what is finally returned.

        ``g`` falls back to ``Game.getgame()`` and ``st`` to
        ``g.get_synctag()`` when a falsy value is supplied.

        A ``TimeLimitExceeded`` that was not created by this call is
        re-raised; any exception other than ``Break``/``TimeLimitExceeded``
        kills the waiter and propagates.
        """
        g = g if g else Game.getgame()
        st = st if st else g.get_synctag()
        # Request record handed to every event emitted below.
        input = DataHolder()
        input.tag = tag
        input.input = None
        input.attachment = attachment
        input.timeout = timeout
        input.player = self

        # Private exception type, so only this call's waiter can interrupt it.
        class Break(Exception):
            pass

        cur_greenlet = gevent.getcurrent()

        def waiter_func():
            # Block until the server delivers the value for this tag, then
            # interrupt the greenlet that called user_input with Break
            # (block=False: do not wait for it to finish) and hand the value
            # back through waiter.get().
            rst = Executive.server.gexpect('input_%s_%d' % (tag, st))
            cur_greenlet.kill(Break(), block=False)
            return rst

        try:
            # presumably a gevent.Timeout-style timer (start/cancel, raised
            # into this greenlet) -- TODO confirm. One second longer than the
            # timeout stored in the request.
            tle = TimeLimitExceeded(timeout + 1)
            tle.start()
            waiter = gevent.spawn(waiter_func)
            g.emit_event('user_input_start', input)
            # Yield once so other greenlets (including the waiter) can run.
            gevent.sleep(0)
            rst = g.emit_event('user_input', input)
            Executive.server.gwrite('input_%s_%d' % (tag, st), rst.input)
        except (Break, TimeLimitExceeded) as e:
            # Only our own timer counts as a timeout here; a foreign
            # TimeLimitExceeded belongs to some outer scope.
            if isinstance(e, TimeLimitExceeded) and e is not tle:
                raise
            g.emit_event('user_input_timeout', input)
            rst = input
        except:
            # Unexpected failure: do not leave the waiter greenlet behind.
            waiter.kill()
            raise
        finally:
            tle.cancel()
            g.emit_event('user_input_finish', input)
            try:
                # Wait for the server's value; the waiter's Break may be
                # delivered while we are blocked here, hence the handler.
                waiter.join()
                gevent.sleep(0)
            except Break:
                pass

        # The value received by the waiter overrides whatever was collected
        # locally.
        rst.input = waiter.get()
        # NOTE(review): 'user_input_finish' was already emitted in the finally
        # block above, so handlers see it twice per call -- confirm intended.
        g.emit_event('user_input_finish', input)
        return rst.input
Exemplo n.º 4
0
 def user_input(self, tag, attachment=None, timeout=15, g=None, st=None):
     """Obtain this peer player's input for ``tag`` from the server.

     The request record (timeout, player, tag, empty input) is published via
     ``user_input_start``; its ``input`` is then set to the value the server
     delivers for ``'input_<tag>_<st>'``, the record is published again via
     ``user_input_finish`` and the ``input`` value is returned.

     Falsy ``g`` / ``st`` are replaced by ``Game.getgame()`` and
     ``g.get_synctag()`` respectively.  ``attachment`` is accepted but unused
     in this variant.
     """
     game = g or Game.getgame()
     sync_tag = st or game.get_synctag()

     request = DataHolder()
     request.timeout = timeout
     request.player = self
     request.input = None
     request.tag = tag

     game.emit_event('user_input_start', request)
     expect = Executive.server.gexpect
     request.input = expect('input_%s_%d' % (tag, sync_tag))
     game.emit_event('user_input_finish', request)

     return request.input
Exemplo n.º 5
0
def train_model(train_data_normalized, batch_size, train_window, train_history,
                train_forward, model, optimizer, epochs, cuda, feature, label):
    """Train ``model`` with an MSE loss on its mean prediction.

    Args:
        train_data_normalized: Training data, forwarded to
            ``nn_create_inout_sequences``.
        batch_size: Mini-batch size; incomplete final batches are dropped.
        train_window: Forwarded to ``nn_create_inout_sequences``.
        train_history: Forwarded to ``nn_create_inout_sequences``.
        train_forward: Forwarded to ``nn_create_inout_sequences``.
        model: Model exposing ``train()``, ``init_hidden(batch, cuda)`` and a
            call returning ``(mean_prediction, std_prediction)``.
        optimizer: Optimizer already bound to the model's parameters.
        epochs: Number of passes over the training data.
        cuda: Device flag forwarded to the helpers and to ``init_hidden``.
        feature: Feature selector forwarded to the sequence builder.
        label: Label selector forwarded to the sequence builder.

    Returns:
        The mean batch loss of the final epoch, or ``nan`` when no batch was
        processed (``epochs`` is 0, or there are fewer samples than
        ``batch_size`` so ``drop_last`` discards everything).
    """
    model.train()
    train_feature, train_label, train_exog, train_len = nn_create_inout_sequences(
        train_data_normalized, train_window, train_history, train_forward,
        feature, label, cuda)
    batch_train_data = DataHolder(train_feature, train_label, train_exog, 1,
                                  train_len, cuda)
    train_batches = DataLoader(batch_train_data,
                               batch_size=batch_size,
                               shuffle=True,
                               drop_last=True)

    loss_function = nn.MSELoss()

    # Defined up front so the return below is valid even if no batch runs.
    mean_epoch_loss = float('nan')

    for epoch in range(epochs):
        epoch_loss = 0
        progress = tqdm(train_batches)
        for idx, batch in enumerate(progress):
            feature_tensor, label_tensor, exog_tensor, len_tensor = batch
            model.init_hidden(batch_size, cuda)

            optimizer.zero_grad()

            # Only the mean prediction is trained on; the std output is unused.
            mean_pred, _std_pred = model(feature_tensor, exog_tensor,
                                         len_tensor, batch_size)
            # NOTE(review): only the prediction is squeezed here, while
            # validation squeezes both sides -- confirm the shapes agree.
            single_loss = loss_function(mean_pred.squeeze(), label_tensor)
            single_loss.backward()
            optimizer.step()

            epoch_loss += single_loss.item()
            mean_epoch_loss = epoch_loss / (idx + 1)
            progress.set_description(
                f'epoch: {epoch:3} loss: {mean_epoch_loss:10.8f} ')

    return mean_epoch_loss
Exemplo n.º 6
0
    def user_input(self, tag, attachment=None, timeout=15, g=None, st=None):
        """Collect input from this player and exchange it with the server.

        A request record is built and passed to the ``user_input_start`` and
        ``user_input`` events; the ``input`` of the ``user_input`` event's
        result is written to the server under ``'input_<tag>_<st>'``.  In
        parallel a waiter greenlet expects a value under the same key, and
        that received value is what is finally returned.

        ``g`` falls back to ``Game.getgame()`` and ``st`` to
        ``g.get_synctag()`` when a falsy value is supplied.

        A ``TimeLimitExceeded`` that was not created by this call is
        re-raised; any exception other than ``Break``/``TimeLimitExceeded``
        kills the waiter and propagates.
        """
        g = g if g else Game.getgame()
        st = st if st else g.get_synctag()
        # Request record handed to every event emitted below.
        input = DataHolder()
        input.tag = tag
        input.input = None
        input.attachment = attachment
        input.timeout = timeout
        input.player = self

        # Private exception type, so only this call's waiter can interrupt it.
        class Break(Exception): pass

        cur_greenlet = gevent.getcurrent()
        def waiter_func():
            # Block until the server delivers the value for this tag, then
            # interrupt the greenlet that called user_input with Break
            # (block=False: do not wait for it to finish) and hand the value
            # back through waiter.get().
            rst = Executive.server.gexpect('input_%s_%d' % (tag, st))
            cur_greenlet.kill(Break(), block=False)
            return rst

        try:
            # presumably a gevent.Timeout-style timer (start/cancel, raised
            # into this greenlet) -- TODO confirm. One second longer than the
            # timeout stored in the request.
            tle = TimeLimitExceeded(timeout+1)
            tle.start()
            waiter = gevent.spawn(waiter_func)
            g.emit_event('user_input_start', input)
            # Yield once so other greenlets (including the waiter) can run.
            gevent.sleep(0)
            rst = g.emit_event('user_input', input)
            Executive.server.gwrite('input_%s_%d' % (tag, st), rst.input)
        except (Break, TimeLimitExceeded) as e:
            # Only our own timer counts as a timeout here; a foreign
            # TimeLimitExceeded belongs to some outer scope.
            if isinstance(e, TimeLimitExceeded) and e is not tle:
                raise
            g.emit_event('user_input_timeout', input)
            rst = input
        except:
            # Unexpected failure: do not leave the waiter greenlet behind.
            waiter.kill()
            raise
        finally:
            tle.cancel()
            g.emit_event('user_input_finish', input)
            try:
                # Wait for the server's value; the waiter's Break may be
                # delivered while we are blocked here, hence the handler.
                waiter.join()
                gevent.sleep(0)
            except Break:
                pass

        # The value received by the waiter overrides whatever was collected
        # locally.
        rst.input = waiter.get()
        # NOTE(review): 'user_input_finish' was already emitted in the finally
        # block above, so handlers see it twice per call -- confirm intended.
        g.emit_event('user_input_finish', input)
        return rst.input
Exemplo n.º 7
0
    def user_input_any(self, tag, expects, attachment=None, timeout=15):
        """Request input from several players and return the one the server picks.

        ``self`` is tested with ``g.me in self``, so it is presumably a
        collection of players -- TODO confirm.  If the local player (``g.me``)
        is a ``TheChosenOne`` contained in ``self``, the local input is
        collected via the ``user_input`` event and written to the server under
        ``'inputany_<tag>_<synctag>'``; otherwise this call only waits.  In
        both cases the outcome is the ``(pid, data)`` pair received under that
        key with ``'_resp'`` appended.

        Args:
            tag: Name of the input request, used in the server keys.
            expects: Callable ``expects(player, data)``; a falsy result is
                treated as an invalid server response.
            attachment: Stored on the request record and included in the
                ``user_input_any_begin`` event payload.
            timeout: Time limit passed to ``TimeLimitExceeded``.

        Returns:
            ``(player, data)`` for the reported player id, or ``(None, None)``
            when the reported player id is ``None``.

        Raises:
            GameError: If ``expects`` rejects the reported player/data pair.
        """
        g = Game.getgame()
        st = g.get_synctag()

        tagstr = 'inputany_%s_%d' % (tag, st)

        g.emit_event('user_input_any_begin', (self, tag, attachment))

        # Request record handed to the user_input* events below.
        input = DataHolder()
        input.tag = tag
        input.input = None
        input.attachment = attachment
        input.timeout = timeout
        input.player = g.me

        class Break(Exception): pass # ('Input: you are too late!')

        def waiter_func():
            # Block until the server announces the result, then raise Break
            # in ``g`` without waiting for it to finish.
            # presumably ``g`` is the greenlet running this method -- TODO confirm.
            pid, data = Executive.server.gexpect(tagstr + '_resp')
            g.kill(Break(), block=False)
            return pid, data

        if isinstance(g.me, TheChosenOne) and g.me in self:
            try:
                waiter = gevent.spawn(waiter_func)
                # presumably a gevent.Timeout-style timer -- TODO confirm.
                tle = TimeLimitExceeded(timeout)
                tle.start()
                g.emit_event('user_input_start', input)
                rst = g.emit_event('user_input', input)
                Executive.server.gwrite(tagstr, rst.input)
            except (Break, TimeLimitExceeded) as e:
                # Only our own timer counts as a timeout; a foreign
                # TimeLimitExceeded belongs to some outer scope.
                if isinstance(e, TimeLimitExceeded) and e is not tle:
                    raise
                g.emit_event('user_input_timeout', input)
                # Interrupted or timed out: report "no input" to the server.
                rst = input
                rst.input = None
                Executive.server.gwrite(tagstr, rst.input)
            except:
                # Unexpected failure: do not leave the waiter greenlet behind.
                waiter.kill()
                raise
            finally:
                tle.cancel()
                g.emit_event('user_input_finish', input)
                try:
                    # The waiter's Break may be delivered while blocked here.
                    waiter.join()
                    gevent.sleep(0)
                except Break:
                    pass

        else:
            # none of my business, just wait for the result
            try:
                waiter = gevent.spawn(waiter_func)
                waiter.join()
                gevent.sleep(0)
            except Break:
                pass

        # Result announced by the server, as received by the waiter.
        pid, data = waiter.get()

        g.emit_event('user_input_any_end', tag)

        if pid is None:
            return None, None

        p = g.player_fromid(pid)

        # Validate the server's answer with the caller-supplied predicate.
        if not expects(p, data):
            raise GameError('WTF?! Server cheats!')

        return p, data
Exemplo n.º 8
0
 def __init__(self, *a):
     """Create the empty ``uniform`` and ``attrib`` holders.

     Positional arguments are accepted but not used.
     """
     # Imported locally, exactly as the original does.
     from utils import DataHolder as Holder

     self.uniform, self.attrib = Holder(), Holder()
Exemplo n.º 9
0
    def user_input_any(self, tag, expects, attachment=None, timeout=15):
        """Request input from several players and return the one the server picks.

        ``self`` is tested with ``g.me in self``, so it is presumably a
        collection of players -- TODO confirm.  If the local player (``g.me``)
        is a ``TheChosenOne`` contained in ``self``, the local input is
        collected via the ``user_input`` event and written to the server under
        ``'inputany_<tag>_<synctag>'``; otherwise this call only waits.  In
        both cases the outcome is the ``(pid, data)`` pair received under that
        key with ``'_resp'`` appended.

        Args:
            tag: Name of the input request, used in the server keys.
            expects: Callable ``expects(player, data)``; a falsy result is
                treated as an invalid server response.
            attachment: Stored on the request record and included in the
                ``user_input_any_begin`` event payload.
            timeout: Time limit passed to ``TimeLimitExceeded``.

        Returns:
            ``(player, data)`` for the reported player id, or ``(None, None)``
            when the reported player id is ``None``.

        Raises:
            GameError: If ``expects`` rejects the reported player/data pair.
        """
        g = Game.getgame()
        st = g.get_synctag()

        tagstr = 'inputany_%s_%d' % (tag, st)

        g.emit_event('user_input_any_begin', (self, tag, attachment))

        # Request record handed to the user_input* events below.
        input = DataHolder()
        input.tag = tag
        input.input = None
        input.attachment = attachment
        input.timeout = timeout
        input.player = g.me

        class Break(Exception):
            pass  # ('Input: you are too late!')

        def waiter_func():
            # Block until the server announces the result, then raise Break
            # in ``g`` without waiting for it to finish.
            # presumably ``g`` is the greenlet running this method -- TODO confirm.
            pid, data = Executive.server.gexpect(tagstr + '_resp')
            g.kill(Break(), block=False)
            return pid, data

        if isinstance(g.me, TheChosenOne) and g.me in self:
            try:
                waiter = gevent.spawn(waiter_func)
                # presumably a gevent.Timeout-style timer -- TODO confirm.
                tle = TimeLimitExceeded(timeout)
                tle.start()
                g.emit_event('user_input_start', input)
                rst = g.emit_event('user_input', input)
                Executive.server.gwrite(tagstr, rst.input)
            except (Break, TimeLimitExceeded) as e:
                # Only our own timer counts as a timeout; a foreign
                # TimeLimitExceeded belongs to some outer scope.
                if isinstance(e, TimeLimitExceeded) and e is not tle:
                    raise
                g.emit_event('user_input_timeout', input)
                # Interrupted or timed out: report "no input" to the server.
                rst = input
                rst.input = None
                Executive.server.gwrite(tagstr, rst.input)
            except:
                # Unexpected failure: do not leave the waiter greenlet behind.
                waiter.kill()
                raise
            finally:
                tle.cancel()
                g.emit_event('user_input_finish', input)
                try:
                    # The waiter's Break may be delivered while blocked here.
                    waiter.join()
                    gevent.sleep(0)
                except Break:
                    pass

        else:
            # none of my business, just wait for the result
            try:
                waiter = gevent.spawn(waiter_func)
                waiter.join()
                gevent.sleep(0)
            except Break:
                pass

        # Result announced by the server, as received by the waiter.
        pid, data = waiter.get()

        g.emit_event('user_input_any_end', tag)

        if pid is None:
            return None, None

        p = g.player_fromid(pid)

        # Validate the server's answer with the caller-supplied predicate.
        if not expects(p, data):
            raise GameError('WTF?! Server cheats!')

        return p, data
Exemplo n.º 10
0
def tag_metafunc(clsname, bases, _dict):
    """Register the parsed class namespace in ``tags`` under the class name.

    Takes the ``(name, bases, namespace)`` triple; ``bases`` is not used.
    ``_dict`` is converted with ``DataHolder.parse`` and stored as
    ``tags[clsname]``.  Nothing is returned.
    """
    tags[clsname] = DataHolder.parse(_dict)