Beispiel #1
0
 def test(val_data, ctx):
    """Evaluate the global ``net`` on ``val_data`` and return its CTC metric.

    Forward passes run outside ``autograd.record()``, so no gradients are
    tracked here.

    Args:
        val_data: iterable yielding ``(datas, labels)`` batches.
        ctx: list of MXNet contexts each batch is sharded across.

    Returns:
        Whatever ``CtcMetrics.get()`` returns; callers unpack it as
        ``(name, value)``.
    """
    def _shard(batch):
        # Convert one host batch to an NDArray and spread it over every
        # context in ``ctx``; ``even_split=False`` tolerates batches that do
        # not divide evenly across the devices.
        return gluon.utils.split_and_load(nd.array(batch),
                                          ctx_list=ctx,
                                          batch_axis=0,
                                          even_split=False)

    val_metric = CtcMetrics(num_classes=config.num_classes)
    val_metric.reset()
    for batch_data, batch_labels in val_data:
        data_shards = _shard(batch_data)
        label_shards = _shard(batch_labels)
        predictions = [net(shard) for shard in data_shards]
        val_metric.update(label_shards, predictions)
    return val_metric.get()
Beispiel #2
0
    def train(ctx, batch_size):
        """Train the global ``net`` with Adam and a step-decayed learning rate.

        For each epoch this does the following:
        - runs one pass over the training split;
        - evaluates on the validation split via ``test``;
        - records and plots the train/validation error history;
        - saves the parameters whenever validation accuracy improves;
        - optionally exports a symbol/params checkpoint every
          ``args.save_period`` epochs.

        A final export is written after the last epoch.

        Args:
            ctx: list of MXNet contexts each batch is sharded across.
            batch_size: nominal number of samples per batch for both loaders.

        Relies on names defined outside this function: ``net``, ``loss_fn``,
        ``args``, ``default``, ``num_workers``, ``test``, ``CtcMetrics``,
        ``TrainingHistory``, ``ImageDataset`` and ``DataLoader``.
        """
        # NOTE(review): ``net`` is assumed to be initialized (or loaded) by the
        # caller; only the context of its parameters is reset here.
        train_data = DataLoader(ImageDataset(root=default.dataset_path, train=True),
                                batch_size=batch_size, shuffle=True,
                                num_workers=num_workers)
        # The validation metric is accumulated over the whole split, so sample
        # order does not matter; skip the shuffling cost.
        val_data = DataLoader(ImageDataset(root=default.dataset_path, train=False),
                              batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)

        net.collect_params().reset_ctx(ctx)

        # The scheduler counts optimizer updates (one per batch), so multiplying
        # by the number of batches per epoch decays the LR by ``lr_decay``
        # every ``lr_decay_step`` epochs, down to ``end_lr``.
        steps_per_epoch = len(train_data)
        schedule = mx.lr_scheduler.FactorScheduler(
            step=args.lr_decay_step * steps_per_epoch,
            factor=args.lr_decay,
            stop_factor_lr=args.end_lr)
        adam_optimizer = mx.optimizer.Adam(learning_rate=args.lr,
                                           lr_scheduler=schedule)
        trainer = gluon.Trainer(net.collect_params(), optimizer=adam_optimizer)

        # NOTE(review): ``test`` builds CtcMetrics(num_classes=...) while this
        # uses the default constructor; confirm the defaults agree.
        train_metric = CtcMetrics()
        train_history = TrainingHistory(['training-error', 'validation-error'])
        best_val_score = 0

        save_period = args.save_period
        save_dir = args.save_dir
        model_name = args.prefix
        plot_path = args.save_dir
        epochs = args.end_epoch
        frequent = args.frequent

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0
            num_batch = 0
            num_samples = 0
            tic_b = time.time()
            for datas, labels in train_data:
                data_nd = nd.array(datas)
                # The final batch of an epoch can be smaller than batch_size,
                # so gradient scaling and loss averaging use the real count.
                cur_batch_size = data_nd.shape[0]
                data = gluon.utils.split_and_load(data_nd,
                                                  ctx_list=ctx,
                                                  batch_axis=0,
                                                  even_split=False)
                label = gluon.utils.split_and_load(nd.array(labels),
                                                   ctx_list=ctx,
                                                   batch_axis=0,
                                                   even_split=False)
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for loss_part in loss:
                    loss_part.backward()
                trainer.step(cur_batch_size)
                train_loss += sum(loss_part.sum().asscalar() for loss_part in loss)

                train_metric.update(label, output)
                num_batch += 1
                num_samples += cur_batch_size
                if frequent and num_batch % frequent == 0:
                    # Only query the metric when a progress line is emitted.
                    _, acc = train_metric.get()
                    train_loss_b = train_loss / num_samples
                    logging.info(
                        '[Epoch %d] [num_bath %d] tain_acc=%f loss=%f time/batch: %f'
                        % (epoch, num_batch, acc, train_loss_b,
                           (time.time() - tic_b) / num_batch))
            # Mean per-sample training loss; max() guards an empty loader.
            train_loss /= max(num_samples, 1)
            _, acc = train_metric.get()
            _, val_acc = test(val_data, ctx)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))
            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-crnn-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                symbol_file = os.path.join(save_dir, model_name)
                net.export(path=symbol_file, epoch=epoch)

        # Tag the final export with the last trained epoch. Skip it when no
        # epoch ran, because ``epoch`` would never have been bound.
        if save_period and save_dir and epochs > 0:
            symbol_file = os.path.join(save_dir, model_name)
            net.export(path=symbol_file, epoch=epochs - 1)