Example #1
import torch
import torch.nn as nn
import torch.optim as optim

# CharSet, Watch, Spell, get_dataloaders and train are assumed to come from
# the project's own modules and are not shown here.


def trainIters(args):
    charSet = CharSet(args['LANGUAGE'])

    watch = Watch(args['LAYER_SIZE'], args['HIDDEN_SIZE'], args['HIDDEN_SIZE'])
    spell = Spell(args['LAYER_SIZE'], args['HIDDEN_SIZE'],
                  charSet.get_total_num())

    # watch = nn.DataParallel(watch)
    # spell = nn.DataParallel(spell)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    watch = watch.to(device)
    spell = spell.to(device)

    watch_optimizer = optim.Adam(watch.parameters(), lr=args['LEARNING_RATE'])
    spell_optimizer = optim.Adam(spell.parameters(), lr=args['LEARNING_RATE'])
    watch_scheduler = optim.lr_scheduler.StepLR(
        watch_optimizer,
        step_size=args['LEARNING_RATE_DECAY_EPOCH'],
        gamma=args['LEARNING_RATE_DECAY_RATIO'])
    spell_scheduler = optim.lr_scheduler.StepLR(
        spell_optimizer,
        step_size=args['LEARNING_RATE_DECAY_EPOCH'],
        gamma=args['LEARNING_RATE_DECAY_RATIO'])
    criterion = nn.CrossEntropyLoss(ignore_index=charSet.get_index_of('<pad>'))

    train_loader, eval_loader = get_dataloaders(args['PATH'], args['BS'],
                                                args['VMAX'], args['TMAX'],
                                                args['WORKER'], charSet,
                                                args['VALIDATION_RATIO'])
    # train_loader = DataLoader(dataset=dataset,
    #                     batch_size=batch_size,
    #                     shuffle=True)
    total_batch = len(train_loader)
    total_eval_batch = len(eval_loader)

    for epoch in range(args['ITER']):
        avg_loss = 0.0
        avg_eval_loss = 0.0
        avg_cer = 0.0
        avg_eval_cer = 0.0
        # In PyTorch >= 1.1 the LR scheduler should be stepped after the
        # optimizer updates, so the decay step is moved to the end of the epoch.

        watch.train()
        spell.train()

        for i, (data, labels) in enumerate(train_loader):

            loss, cer = train(data, labels, watch, spell, watch_optimizer,
                              spell_optimizer, criterion, True, charSet)
            avg_loss += loss
            avg_cer += cer
            print('Batch : ', i + 1, '/', total_batch,
                  ', ERROR in this minibatch: ', loss)
            print('Character error rate : ', cer)

        watch.eval()
        spell.eval()

        for k, (data, labels) in enumerate(eval_loader):
            loss, cer = train(data, labels, watch, spell, watch_optimizer,
                              spell_optimizer, criterion, False, charSet)
            avg_eval_loss += loss
            avg_eval_cer += cer
        # Decay the learning rate once per epoch, after all optimizer updates.
        watch_scheduler.step()
        spell_scheduler.step()

        print('epoch:', epoch, ' train_loss:', float(avg_loss / total_batch))
        print('epoch:', epoch, ' Average train CER:',
              float(avg_cer / total_batch))
        print('epoch:', epoch, ' Validation_loss:',
              float(avg_eval_loss / total_eval_batch))
        print('epoch:', epoch, ' Average validation CER:',
              float(avg_eval_cer / total_eval_batch))
        if epoch % args['SAVE_EVERY'] == 0 and epoch != 0:
            torch.save(watch, 'watch{}.pt'.format(epoch))
            torch.save(spell, 'spell{}.pt'.format(epoch))
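
The function above reads every hyperparameter from a single args dictionary. The sketch below shows one way such a driver might look; the key names are taken from the code, but all of the values (and the 'en' language tag) are illustrative assumptions rather than settings from the original project.

# Hypothetical driver for Example #1; every value is an illustrative assumption.
args = {
    'LANGUAGE': 'en',
    'LAYER_SIZE': 3,
    'HIDDEN_SIZE': 512,
    'LEARNING_RATE': 1e-3,
    'LEARNING_RATE_DECAY_EPOCH': 20,
    'LEARNING_RATE_DECAY_RATIO': 0.5,
    'PATH': 'data/',           # dataset root passed to get_dataloaders
    'BS': 16,                  # batch size
    'VMAX': 75,                # maximum video length (frames)
    'TMAX': 40,                # maximum transcript length (characters)
    'WORKER': 4,               # DataLoader worker processes
    'VALIDATION_RATIO': 0.1,
    'ITER': 100,               # number of training epochs
    'SAVE_EVERY': 10,          # checkpoint interval (epochs)
}

trainIters(args)

Because the checkpoints are written with torch.save on the whole module, they can later be restored with, e.g., watch = torch.load('watch10.pt'), provided the Watch class is importable.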
Example #2
import torch
import torch.nn as nn
import torch.optim as optim

# Watch, Spell, get_dataloaders, train and int_list are assumed to come from
# the project's own modules and are not shown here.


def trainIters(n_iters,
               videomax,
               txtmax,
               data_path,
               batch_size,
               worker,
               ratio_of_validation=0.0001,
               learning_rate_decay=2000,
               save_every=30,
               learning_rate=0.01):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    watch = Watch(3, 512, 512)
    spell = Spell(num_layers=3, output_size=len(int_list), hidden_size=512)

    watch = watch.to(device)
    spell = spell.to(device)

    watch_optimizer = optim.Adam(watch.parameters(), lr=learning_rate)
    spell_optimizer = optim.Adam(spell.parameters(), lr=learning_rate)
    watch_scheduler = optim.lr_scheduler.StepLR(watch_optimizer,
                                                step_size=learning_rate_decay,
                                                gamma=0.1)
    spell_scheduler = optim.lr_scheduler.StepLR(spell_optimizer,
                                                step_size=learning_rate_decay,
                                                gamma=0.1)
    # 38 is assumed to be the index of the <pad> token (cf. Example #1).
    criterion = nn.CrossEntropyLoss(ignore_index=38)

    train_loader, eval_loader = get_dataloaders(
        data_path,
        batch_size,
        videomax,
        txtmax,
        worker,
        ratio_of_validation=ratio_of_validation)
    # train_loader = DataLoader(dataset=dataset,
    #                     batch_size=batch_size,
    #                     shuffle=True)
    total_batch = len(train_loader)
    total_eval_batch = len(eval_loader)

    for epoch in range(n_iters):
        avg_loss = 0.0
        avg_eval_loss = 0.0
        # In PyTorch >= 1.1 the LR scheduler should be stepped after the
        # optimizer updates, so the decay step is moved to the end of the epoch.

        watch.train()
        spell.train()

        for i, (data, labels) in enumerate(train_loader):

            loss = train(data.to(device), labels.to(device), watch, spell,
                         watch_optimizer, spell_optimizer, criterion, True)
            avg_loss += loss
            print('Batch : ', i + 1, '/', total_batch,
                  ', ERROR in this minibatch: ', loss)
            del data, labels, loss

        watch.eval()
        spell.eval()

        for k, (data, labels) in enumerate(eval_loader):
            loss = train(data.to(device), labels.to(device), watch, spell,
                         watch_optimizer, spell_optimizer, criterion, False)
            avg_eval_loss += loss
            print('Batch : ', k + 1, '/', total_eval_batch,
                  ', Validation ERROR in this minibatch: ', loss)
            del data, labels, loss

        # Decay the learning rate once per epoch, after all optimizer updates.
        watch_scheduler.step()
        spell_scheduler.step()

        print('epoch:', epoch, ' train_loss:', float(avg_loss / total_batch))
        print('epoch:', epoch, ' eval_loss:',
              float(avg_eval_loss / total_eval_batch))
        if epoch % save_every == 0 and epoch != 0:
            torch.save(watch, 'watch{}.pt'.format(epoch))
            torch.save(spell, 'spell{}.pt'.format(epoch))
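
This variant takes its hyperparameters as keyword arguments instead of a dictionary. A minimal, hypothetical call might look as follows; the decay, save and learning-rate values mirror the function defaults, while the remaining values are illustrative assumptions chosen only to show the expected types.

# Hypothetical call to Example #2; the data path and size limits are assumptions.
trainIters(n_iters=100,
           videomax=75,
           txtmax=40,
           data_path='data/',
           batch_size=16,
           worker=4,
           ratio_of_validation=0.01,
           learning_rate_decay=2000,
           save_every=30,
           learning_rate=0.01)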