示例#1
0
# Wrap the DALI training pipelines in a classification iterator; the training
# loop below draws its batches from it via `tqdm(dali_iter)`.
# NOTE(review): `size` is presumably the number of samples per epoch, and the
# iterator is re-iterated on every pass of the outer `while` loop — confirm it
# is reset (or auto-resets) at the end of each pass.
dali_iter = DALIClassificationIterator(train_pipes, size)

while it < iters + 1:

    for batches in tqdm(dali_iter):
        if it == lr_steps[lr_counter]:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
            lr_counter += 1
        datas, labels = split_and_load(batches, num_gpu)

        with ag.record():
            ots = [net(X) for X in datas]
            embedds = [ot[0] for ot in ots]
            outputs = [ot[1] for ot in ots]
            losses = [
                loss(yhat, y, emb)
                for yhat, y, emb in zip(outputs, labels, embedds)
            ]

        for l in losses:
            ag.backward(l)

        trainer.step(batch_size)
        acc_mtc.update(labels, outputs)
        loss_mtc.update(0, losses)

        if (it % save_period) == 0 and it != 0:
            _, train_loss = loss_mtc.get()
            _, train_acc = acc_mtc.get()
            toc = time.time()
            logger.info(
        it = epoch * batches_per_epoch + i
        lr_scheduler.update(i, epoch)

        datas = gluon.utils.split_and_load(batch[0],
                                           ctx_list=ctx,
                                           batch_axis=0,
                                           even_split=False)
        labels = gluon.utils.split_and_load(batch[1],
                                            ctx_list=ctx,
                                            batch_axis=0,
                                            even_split=False)

        with ag.record():
            ots = [net(X) for X in datas]
            outputs = [ot[1] for ot in ots]
            losses = [loss(yhat, y) for yhat, y in zip(outputs, labels)]

        for l in losses:
            ag.backward(l)

        trainer.step(batch_size)
        acc_mtc.update(labels, outputs)
        loss_mtc.update(0, losses)

        if (it % opt.save_frequency) == 0:
            _, train_loss = loss_mtc.get()
            train_metric_name, train_acc = acc_mtc.get()
            toc = time.time()

            logger.info(
                '\nEpoch[%d] Batch[%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f train loss: %.6f'
示例#3
0
# metrics: loss and accuracy accumulated by the train loop and read/logged
# whenever `it % save_period == 0` (skipping iteration 0)
loss_mtc, acc_mtc = mx.metric.Loss(), mx.metric.Accuracy()
# wall-clock timestamps taken before training starts; presumably used to
# compute the logged speed / elapsed time — confirm against the logging code
tic = time.time()
btic = time.time()

# train loop
for epoch in range(epochs):
    for i, batch in enumerate(train_iter):
        it = epoch * num_batches + i
        data = batch[0].data[0]
        label = batch[0].label[0]

        with ag.record():
            embedding, output = net(data)
            batch_loss = loss(output, label, embedding)

        ag.backward(batch_loss)
        trainer.step(batch_size)

        acc_mtc.update([label], [output])
        loss_mtc.update(0, [batch_loss])

        if (it % save_period) == 0 and it != 0:
            _, train_loss = loss_mtc.get()
            _, train_acc = acc_mtc.get()
            toc = time.time()
            logger.info(
                '\n[epoch % 2d] [it % 3d] train loss: %.6f, train_acc: %.6f | '
                'learning rate: %.8f speed: %.2f samples/s, time: %.6f' %
                (epoch, it, train_loss, train_acc, trainer.learning_rate,