Example #1
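A game-client helper that collects every unit standing inside the square area of effect of an area spell. The spell may be passed directly or resolved from its type id, and the center may be given either as a Cell or as row/col coordinates.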
def get_area_spell_targets(self,
                           center: Cell = None,
                           row: int = None,
                           col: int = None,
                           spell: Spell = None,
                           type_id: int = None) -> List["Unit"]:
    # Resolve the spell from its type id when it is not passed directly.
    if spell is None:
        if type_id is not None:
            spell = self.get_cast_spell_by_id(type_id)
        else:
            return []
    if not isinstance(spell, Spell):
        Logs.show_log("invalid spell chosen in get_area_spell_targets")
        return []
    if not spell.is_area_spell():
        return []
    if center is None:
        if row is None or col is None:
            return []
        center = Cell(row, col)
    # Scan the square of radius spell.range around the center, clamped to
    # the map bounds, and collect every unit the spell can target.
    targets = []
    for i in range(max(0, center.row - spell.range),
                   min(center.row + spell.range + 1, self._map.row_num)):
        for j in range(max(0, center.col - spell.range),
                       min(center.col + spell.range + 1, self._map.col_num)):
            cell = self._map.get_cell(i, j)
            for unit in cell.units:
                if self._is_unit_targeted(unit, spell.target):
                    targets.append(unit)
    return targets
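A hypothetical call site, assuming the method lives on the SDK's world/game-client object (the name world and all argument values here are assumptions):

targets = world.get_area_spell_targets(row=4, col=7, type_id=2)
# equivalent, passing the resolved objects directly:
targets = world.get_area_spell_targets(center=cell, spell=spell)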
Example #2
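A parser that turns the spell section of a server handshake message into Spell objects.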
def _spells_init(self, msg):
    # Build one Spell per entry of the server message; the payload carries
    # no damage flag, so is_damaging is always initialized to False.
    self.spells = [Spell(type=SpellType.get_value(spell["type"]),
                         type_id=spell["typeId"],
                         duration=spell["duration"],
                         priority=spell["priority"],
                         range=spell["range"],
                         power=spell["power"],
                         target=SpellTarget.get_value(spell["target"]),
                         is_damaging=False)
                   for spell in msg]
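For reference, the parser expects msg to be a list of dictionaries carrying exactly the keys read above; a minimal sketch with invented field values:

msg = [{"type": "TELE", "typeId": 1, "duration": 1, "priority": 2,
        "range": 3, "power": 6, "target": "ALLIED"}]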
Example #3
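A full training driver for a Watch/Spell encoder-decoder: Adam optimizers with step learning-rate decay, a padding-aware cross-entropy loss, alternating train and validation passes with character error rate (CER) tracking, and periodic whole-model checkpoints.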
import torch
from torch import nn, optim

# CharSet, Watch, Spell, train and get_dataloaders are project-local helpers.
def trainIters(args):
    charSet = CharSet(args['LANGUAGE'])

    watch = Watch(args['LAYER_SIZE'], args['HIDDEN_SIZE'], args['HIDDEN_SIZE'])
    spell = Spell(args['LAYER_SIZE'], args['HIDDEN_SIZE'],
                  charSet.get_total_num())

    # watch = nn.DataParallel(watch)
    # spell = nn.DataParallel(spell)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    watch = watch.to(device)
    spell = spell.to(device)

    watch_optimizer = optim.Adam(watch.parameters(), lr=args['LEARNING_RATE'])
    spell_optimizer = optim.Adam(spell.parameters(), lr=args['LEARNING_RATE'])
    watch_scheduler = optim.lr_scheduler.StepLR(
        watch_optimizer,
        step_size=args['LEARNING_RATE_DECAY_EPOCH'],
        gamma=args['LEARNING_RATE_DECAY_RATIO'])
    spell_scheduler = optim.lr_scheduler.StepLR(
        spell_optimizer,
        step_size=args['LEARNING_RATE_DECAY_EPOCH'],
        gamma=args['LEARNING_RATE_DECAY_RATIO'])
    criterion = nn.CrossEntropyLoss(ignore_index=charSet.get_index_of('<pad>'))

    train_loader, eval_loader = get_dataloaders(args['PATH'], args['BS'],
                                                args['VMAX'], args['TMAX'],
                                                args['WORKER'], charSet,
                                                args['VALIDATION_RATIO'])
    # train_loader = DataLoader(dataset=dataset,
    #                     batch_size=batch_size,
    #                     shuffle=True)
    total_batch = len(train_loader)
    total_eval_batch = len(eval_loader)

    for epoch in range(args['ITER']):
        avg_loss = 0.0
        avg_eval_loss = 0.0
        avg_cer = 0.0
        avg_eval_cer = 0.0

        watch = watch.train()
        spell = spell.train()

        for i, (data, labels) in enumerate(train_loader):

            loss, cer = train(data, labels, watch, spell, watch_optimizer,
                              spell_optimizer, criterion, True, charSet)
            avg_loss += loss
            avg_cer += cer
            print('Batch : ', i + 1, '/', total_batch,
                  ', ERROR in this minibatch: ', loss)
            print('Character error rate : ', cer)

        # Step the LR schedulers once per epoch, after the optimizer updates
        # (stepping them before optimizer.step() would skip the first LR value).
        watch_scheduler.step()
        spell_scheduler.step()

        watch = watch.eval()
        spell = spell.eval()

        for k, (data, labels) in enumerate(eval_loader):
            loss, cer = train(data, labels, watch, spell, watch_optimizer,
                              spell_optimizer, criterion, False, charSet)
            avg_eval_loss += loss
            avg_eval_cer += cer
        print('epoch:', epoch, ' train_loss:', float(avg_loss / total_batch))
        print('epoch:', epoch, ' train CER:', float(avg_cer / total_batch))
        print('epoch:', epoch, ' validation_loss:',
              float(avg_eval_loss / total_eval_batch))
        print('epoch:', epoch, ' validation CER:',
              float(avg_eval_cer / total_eval_batch))
        if epoch % args['SAVE_EVERY'] == 0 and epoch != 0:
            torch.save(watch, 'watch{}.pt'.format(epoch))
            torch.save(spell, 'spell{}.pt'.format(epoch))
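Every hyperparameter arrives through a single args dict; a sketch of the keys the function reads (the values are placeholders, not recommended settings):

args = {
    'LANGUAGE': 'en', 'LAYER_SIZE': 3, 'HIDDEN_SIZE': 512,
    'LEARNING_RATE': 1e-3, 'LEARNING_RATE_DECAY_EPOCH': 20,
    'LEARNING_RATE_DECAY_RATIO': 0.1, 'PATH': 'data/',
    'BS': 16, 'VMAX': 100, 'TMAX': 40, 'WORKER': 4,
    'VALIDATION_RATIO': 0.05, 'ITER': 100, 'SAVE_EVERY': 30,
}
trainIters(args)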
Example #4
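The same watch/spell architecture trained on a video dataset with batch size 1: the decoder runs character by character with teacher forcing, loss and learning rates go to TensorBoard, and state_dict checkpoints are written every epoch.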
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Watch, Spell, VideoDataset, CHAR_SET and the LEARNING_RATE_* constants
# are project-local.
def run(dataset_path):
    os.makedirs('checkpoints', exist_ok=True)
    writer = SummaryWriter()

    # Create the device once and reuse it below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    watcher = Watch().to(device)
    speller = Spell().to(device)

    # Applying learning rate decay as we observed loss diverge.
    # https://discuss.pytorch.org/t/how-to-use-torch-optim-lr-scheduler-exponentiallr/12444/6
    watch_optimizer = optim.Adam(watcher.parameters(), lr=LEARNING_RATE)
    spell_optimizer = optim.Adam(speller.parameters(), lr=LEARNING_RATE)
    watch_scheduler = optim.lr_scheduler.StepLR(watch_optimizer, step_size=LEARNING_RATE_STEP, gamma=LEARNING_RATE_GAMMA)
    spell_scheduler = optim.lr_scheduler.StepLR(spell_optimizer, step_size=LEARNING_RATE_STEP, gamma=LEARNING_RATE_GAMMA)

    dataset = VideoDataset(dataset_path)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    criterion = nn.CrossEntropyLoss()

    losses = []
    for epoch in range(1, 100):
        watcher = watcher.train()
        speller = speller.train()

        for i, (x, chars) in tqdm(enumerate(dataloader, 1), total=len(dataloader), desc=f'Epoch: {epoch}'):
            chars = chars[0]
            loss = 0
            watch_optimizer.zero_grad()
            spell_optimizer.zero_grad()

            x = x.to(device)
            chars = chars.to(device)
            chars_len = chars.size(0)

            output_from_vgg_lstm, states_from_vgg_lstm = watcher(x)

            spell_input = torch.tensor([[CHAR_SET.index('<sos>')]]).repeat(output_from_vgg_lstm.size(0), 1).to(device)
            spell_hidden = states_from_vgg_lstm
            spell_state = torch.zeros_like(spell_hidden).to(device)
            context = torch.zeros(output_from_vgg_lstm.size(0), 1, spell_hidden.size(2)).to(device)

            # Decode one character per time step with teacher forcing: the
            # ground-truth character, not the prediction, feeds the next step.
            for idx in range(chars_len):
                spell_output, spell_hidden, spell_state, context = speller(spell_input, spell_hidden, spell_state, output_from_vgg_lstm, context)
                spell_input = chars[idx].long().view(1, 1)
                loss += criterion(spell_output.squeeze(1), chars[idx].long().view(1))

            loss.backward()
            watch_optimizer.step()
            spell_optimizer.step()

            norm_loss = float(loss / chars.size(0))
            losses.append(norm_loss)

            writer.add_scalar('train/loss', norm_loss, global_step=epoch * len(dataloader) + i)
            writer.add_scalar('train/lr-watcher', watch_scheduler.get_last_lr()[0], global_step=epoch * len(dataloader) + i)
            writer.add_scalar('train/lr-speller', spell_scheduler.get_last_lr()[0], global_step=epoch * len(dataloader) + i)

        # Step the LR schedulers once per epoch, after the optimizer updates.
        watch_scheduler.step()
        spell_scheduler.step()

        watcher = watcher.eval()
        speller = speller.eval()

        torch.save({
            'watcher': watcher.state_dict(),
            'speller': speller.state_dict(),
            }, f'checkpoints/{epoch:03d}_{norm_loss}.pth')


    print(losses)
Example #5
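A variant of the training driver from Example #3 that takes explicit keyword arguments instead of an args dict and tracks only the loss (no CER).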
import torch
from torch import nn, optim

# Watch, Spell, train, get_dataloaders and int_list are project-local.
def trainIters(n_iters,
               videomax,
               txtmax,
               data_path,
               batch_size,
               worker,
               ratio_of_validation=0.0001,
               learning_rate_decay=2000,
               save_every=30,
               learning_rate=0.01):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    watch = Watch(3, 512, 512)
    spell = Spell(num_layers=3, output_size=len(int_list), hidden_size=512)

    watch = watch.to(device)
    spell = spell.to(device)

    watch_optimizer = optim.Adam(watch.parameters(), lr=learning_rate)
    spell_optimizer = optim.Adam(spell.parameters(), lr=learning_rate)
    watch_scheduler = optim.lr_scheduler.StepLR(watch_optimizer,
                                                step_size=learning_rate_decay,
                                                gamma=0.1)
    spell_scheduler = optim.lr_scheduler.StepLR(spell_optimizer,
                                                step_size=learning_rate_decay,
                                                gamma=0.1)
    # Targets equal to 38 (the padding index, cf. Example #3) are excluded
    # from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=38)

    train_loader, eval_loader = get_dataloaders(
        data_path,
        batch_size,
        videomax,
        txtmax,
        worker,
        ratio_of_validation=ratio_of_validation)
    # train_loader = DataLoader(dataset=dataset,
    #                     batch_size=batch_size,
    #                     shuffle=True)
    total_batch = len(train_loader)
    total_eval_batch = len(eval_loader)

    for epoch in range(n_iters):
        avg_loss = 0.0
        avg_eval_loss = 0.0

        watch = watch.train()
        spell = spell.train()

        for i, (data, labels) in enumerate(train_loader):

            loss = train(data.to(device), labels.to(device), watch, spell,
                         watch_optimizer, spell_optimizer, criterion, True)
            avg_loss += loss
            print('Batch : ', i + 1, '/', total_batch,
                  ', ERROR in this minibatch: ', loss)
            del data, labels, loss

        # Step the LR schedulers once per epoch, after the optimizer updates.
        watch_scheduler.step()
        spell_scheduler.step()

        watch = watch.eval()
        spell = spell.eval()

        for k, (data, labels) in enumerate(eval_loader):
            loss = train(data.to(device), labels.to(device), watch, spell,
                         watch_optimizer, spell_optimizer, criterion, False)
            avg_eval_loss += loss
            print('Batch : ', k + 1, '/', total_eval_batch,
                  ', Validation ERROR in this minibatch: ', loss)
            del data, labels, loss

        print('epoch:', epoch, ' train_loss:', float(avg_loss / total_batch))
        print('epoch:', epoch, ' eval_loss:',
              float(avg_eval_loss / total_eval_batch))
        if epoch % save_every == 0 and epoch != 0:
            torch.save(watch, 'watch{}.pt'.format(epoch))
            torch.save(spell, 'spell{}.pt'.format(epoch))