Example #1
def train(data_iter, net, cross_entropy, trainer, num_epochs, batch_size):
    sw = mb.SummaryWriter(logdir='./logs', flush_secs=2)
    params = net.collect_params('.*W|.*dense')
    param_names = params.keys()
    ls = 0
    # train_x, train_y, test_x, test_y = allData['train_x'], allData['train_y'], allData['test_x'], allData['test_y']
    for epoch in range(num_epochs):
        train_loss_sum, train_acc_sum, n, start = 0., 0., 0., time.time()
        for X, Y in data_iter:
            # X.attach_grad()
            with autograd.record():
                pre = net(X.reshape(*X.shape, 1))
                loss = cross_entropy(pre, Y).sum()
            loss.backward()
            trainer.step(batch_size)

            # log metrics
            train_loss_sum += loss.asscalar()
            train_acc_sum += (pre.argmax(axis=1) == Y).sum().asscalar()
            n += len(Y)
            sw.add_scalar(tag='cross_entropy', value=train_loss_sum / n, global_step=ls)

            for i, name in enumerate(param_names):
                sw.add_histogram(tag=name,
                                 values=net.collect_params()[name].grad(),
                                 global_step=ls, bins=1000)
            ls += 1
        # test_acc = evaluate_accuracy(test_x, test_y, net)
        print('epoch %d, loss %.4f, train acc %.3f,  time %.1f sec' %
              (epoch + 1, train_loss_sum / n, train_acc_sum / n, time.time() - start))
    sw.close()
    return net
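
For context, the function above can be driven by a small script like the following sketch; the toy network, random data, and hyperparameters are placeholders, and the module-level imports that train() relies on (time, autograd, and mxboard as mb) are spelled out explicitly:

import time

import mxnet as mx
import mxboard as mb
from mxnet import autograd, gluon
from mxnet.gluon import nn

# toy model and random data, just enough to exercise train()
net = nn.Sequential()
net.add(nn.Dense(64, activation='relu'), nn.Dense(10))
net.initialize(mx.init.Xavier())

X = mx.nd.random.uniform(shape=(1000, 28))
Y = mx.nd.random.randint(0, 10, shape=(1000,)).astype('float32')
batch_size = 32
data_iter = gluon.data.DataLoader(gluon.data.ArrayDataset(X, Y),
                                  batch_size=batch_size, shuffle=True)

cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

net = train(data_iter, net, cross_entropy, trainer,
            num_epochs=5, batch_size=batch_size)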
Example #2
    def _init_train(self):
        self.exp = self.exp + datetime.datetime.now().strftime("%m-%dx%H-%M")

        self.batch_size *= len(self.ctx)

        print(global_variable.yellow)
        print('Batch Size = {}'.format(self.batch_size))
        print('Record Step = {}'.format(self.record_step))

        #self.L1_loss = mxnet.gluon.loss.L1Loss()
        #self.L2_loss = mxnet.gluon.loss.L2Loss()
        self.HB_loss = mxnet.gluon.loss.HuberLoss()
        self.LG_loss = mxnet.gluon.loss.LogisticLoss(label_format='binary')
        self.CE_loss = mxnet.gluon.loss.SoftmaxCrossEntropyLoss(
            from_logits=False, sparse_label=False)

        # -------------------- init trainer -------------------- #
        optimizer = mxnet.optimizer.create('adam',
                                           learning_rate=self.learning_rate,
                                           multi_precision=False)

        self.trainer = mxnet.gluon.Trainer(self.net.collect_params(),
                                           optimizer=optimizer)

        logdir = os.path.join(self.version, 'logs')
        self.sw = mxboard.SummaryWriter(logdir=logdir, verbose=True)

        if not os.path.exists(self.backup_dir):
            os.makedirs(self.backup_dir)
Example #3
 def __init__(self,
              logdir: str,
              source_vocab: Optional[vocab.Vocab] = None,
              target_vocab: Optional[vocab.Vocab] = None) -> None:
     self.logdir = logdir
     self.source_labels = vocab.get_ordered_tokens_from_vocab(
         source_vocab) if source_vocab is not None else None
     self.target_labels = vocab.get_ordered_tokens_from_vocab(
         target_vocab) if target_vocab is not None else None
     try:
         import mxboard
         logger.info("Logging training events for Tensorboard at '%s'",
                     self.logdir)
         self._writer = mxboard.SummaryWriter(logdir=self.logdir,
                                              flush_secs=60,
                                              verbose=False)
     except ImportError:
         logger.info(
             "mxboard not found. Consider 'pip install mxboard' to log events to Tensorboard."
         )
         self._writer = None
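
Because self._writer is left as None when mxboard cannot be imported, any logging method on this class has to guard its writer calls; a minimal sketch of such a guard (the method name log_metrics and its signature are hypothetical, and Dict is assumed to come from typing):

 def log_metrics(self, metrics: Dict[str, float], global_step: int) -> None:
     if self._writer is None:
         # mxboard is unavailable; skip TensorBoard logging silently
         return
     for name, value in metrics.items():
         self._writer.add_scalar(tag=name, value=value, global_step=global_step)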
Example #4
def train1():
    neth = NetH(json='insightface.json', nclass=nclass, ctx=ctx)
    neth.hybridize()
    trainer_neth = gluon.Trainer(neth.collect_params(), 'adam',
                                 {'learning_rate': 1e-4})

    circle_loss = CircleLoss(nclass, scale=64, margin=0.25)

    tick = time.time()
    ts = time.localtime(tick)
    stamp = time.strftime('%Y%m%d%H%M%S', ts)
    with mxboard.SummaryWriter(logdir='logs/' + stamp) as sw:
        iternum = 0
        for epoch in range(nepoch):
            for batch in train_loader:
                faces = batch.data[0].as_in_context(ctx)  # 0-1
                labels = batch.label[0].as_in_context(ctx)

                with autograd.record():
                    pred = neth(faces)
                    loss = circle_loss(pred, labels)
                    loss.backward()
                trainer_neth.step(batch_size)

                if iternum % 100 == 0:
                    step = iternum // 100
                    print("epoch: %d, iter: %d, loss: %f" %
                          (epoch, iternum, loss.mean().asscalar()))
                    sw.add_scalar("LOSS",
                                  value=('LOSS', loss.mean().asscalar()),
                                  global_step=step)

                iternum += 1

            neth.export('circle', epoch)
Example #5
 def __init__(self, logdir, **kwargs):
     self.num_inst = 0
     self.global_num_inst = 0
     self.sw = mxboard.SummaryWriter(logdir=logdir)  # optionally pass flush_secs=5
     super(MxboardAccuracy, self).__init__(**kwargs)
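
The counters initialized here hint that this metric also overrides update() to push the running accuracy to mxboard; a plausible sketch of such an override (hypothetical, not taken from the original project):

 def update(self, labels, preds):
     # let mx.metric.Accuracy do the usual bookkeeping
     num_before = self.num_inst
     super(MxboardAccuracy, self).update(labels, preds)
     # global_num_inst survives reset(), so the TensorBoard step keeps increasing
     self.global_num_inst += self.num_inst - num_before
     self.sw.add_scalar(tag=self.name, value=self.get()[1],
                        global_step=self.global_num_inst)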
Example #6
File: logging.py Project: olk/ki-go
 def __init__(self, logs_p, print_n=100):
     super().__init__()
     self._print_n = print_n
     self._sw = mxb.SummaryWriter(logdir=str(logs_p))
Example #7
def main():
    data_p = Path('/storage/data/').resolve()
    checkpoint_p = Path('./checkpoints/').resolve()
    checkpoint_p.mkdir(parents=True, exist_ok=True)
    logs_p = Path('./logs/').resolve()
    shutil.rmtree(logs_p, ignore_errors=True)
    encoder = SevenPlaneEncoder((19, 19))
    builder = SGFDatasetBuilder(data_p, encoder=encoder)
    builder.download_and_prepare()
    train_itr = builder.train_dataset(batch_size=BATCH_SIZE,
                                      max_worker=cpu_count(),
                                      factor=FACTOR)
    test_itr = builder.test_dataset(batch_size=BATCH_SIZE,
                                    max_worker=cpu_count(),
                                    factor=FACTOR)
    # build model
    betago = Model()
    # convert to half-precision floating point (FP16)
    # NOTE: all NVIDIA GPUs with compute capability 6.1 have low-rate FP16 performance, i.e. FP16 is not the fast path on these GPUs
    #       data passed to split_and_load() must be float16 too
    #betago.cast('float16')
    # hybridize for speed
    betago.hybridize(static_alloc=True, static_shape=True)
    # print graph
    shape = (1, ) + encoder.shape()
    mx.viz.print_summary(betago(mx.sym.var('data')), shape={'data': shape})
    # pin GPUs
    ctx = [mx.gpu(i) for i in range(GPU_COUNT)]
    # optimizer
    opt_params = {
        'learning_rate': 0.001,
        'beta1': 0.9,
        'beta2': 0.999,
        'epsilon': 1e-08
    }
    opt = mx.optimizer.create('adam', **opt_params)
    # initialize parameters
    # MXNet initializes weight matrices uniformly by drawing from [-0.07, 0.07]; bias parameters are all set to 0
    # the 'Xavier' initializer is designed to keep the scale of gradients roughly the same in all layers
    betago.initialize(mx.init.Xavier(magnitude=2.3),
                      ctx=ctx,
                      force_reinit=True)
    # fetch and broadcast parameters
    params = betago.collect_params()
    # trainer
    trainer = Trainer(params=params, optimizer=opt, kvstore='device')
    # loss function
    loss_fn = SoftmaxCrossEntropyLoss()
    # use accuracy as the evaluation metric
    metric = Accuracy()
    with mxb.SummaryWriter(logdir='./logs') as sw:
        # add graph to MXBoard
        #betago.forward(mx.nd.ones(shape, ctx=ctx[0]))
        #betago.forward(mx.nd.ones(shape, ctx=ctx[1]))
        #sw.add_graph(betago)
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename='profile_output.json')
        start = time.perf_counter()
        # train
        for e in range(EPOCHS):
            if 0 == e:
                profiler.set_state('run')
            tick = time.time()
            # reset the train data iterator.
            train_itr.reset()
            # loop over the train data iterator
            for i, batch in enumerate(train_itr):
                if 0 == i:
                    tick_0 = time.time()
                # splits train data into multiple slices along batch_axis
                # copy each slice into a context
                data = split_and_load(batch.data[0],
                                      ctx_list=ctx,
                                      batch_axis=0,
                                      even_split=False)
                # splits train label into multiple slices along batch_axis
                # copy each slice into a context
                label = split_and_load(batch.label[0],
                                       ctx_list=ctx,
                                       batch_axis=0,
                                       even_split=False)
                outputs = []
                losses = []
                # inside training scope
                with ag.record():
                    for x, y in zip(data, label):
                        z = betago(x)
                        # computes softmax cross entropy loss
                        l = loss_fn(z, y)
                        outputs.append(z)
                        losses.append(l)
                # backpropagate the error for one iteration
                for l in losses:
                    l.backward()
                # make one step of parameter update.
                # trainer needs to know the batch size of data
                # to normalize the gradient by 1/batch_size
                trainer.step(BATCH_SIZE)
                # updates internal evaluation
                metric.update(label, outputs)
                # Print batch metrics
                if 0 == i % PRINT_N and 0 < i:
                    # checkpointing
                    betago.save_parameters(
                        str(checkpoint_p.joinpath(
                            'betago-{}.params'.format(e))))
                    sw.add_scalar(tag='Accuracy',
                                  value={'naive': metric.get()[1]},
                                  global_step=i - PRINT_N)
                    sw.add_scalar(tag='Speed',
                                  value={
                                      'naive':
                                      BATCH_SIZE * (PRINT_N) /
                                      (time.time() - tick)
                                  },
                                  global_step=i - PRINT_N)
                    print(
                        'epoch[{}] batch [{}], accuracy {:.4f}, samples/sec: {:.4f}'
                        .format(e, i,
                                metric.get()[1],
                                BATCH_SIZE * (PRINT_N) / (time.time() - tick)))
                    tick = time.time()
            if 0 == e:
                profiler.set_state('stop')
                profiler.dump()
            # gets the evaluation result
            print('epoch [{}], accuracy {:.4f}, samples/sec: {:.4f}'.format(
                e,
                metric.get()[1],
                BATCH_SIZE * (i + 1) / (time.time() - tick_0)))
            # reset evaluation result to initial state
            metric.reset()

    elapsed = time.perf_counter() - start
    print('elapsed: {:0.3f}'.format(elapsed))
    # use Accuracy as the evaluation metric
    metric = Accuracy()
    for batch in test_itr:
        data = split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(betago(x))
        metric.update(label, outputs)
    print('validation %s=%f' % metric.get())
Example #8
def train_network(net, lr, input_shape, batch_size, train_path, test_path,
                  epoch, ctx):
    train_data, val_data = prepare_data(train_path, test_path, input_shape,
                                        batch_size)

    for X, y in train_data:
        print("X shape {}, y shape", X.shape, y.shape)
        break

    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    net.summary(nd.zeros(shape=(1, 3) + input_shape, ctx=ctx))

    net.hybridize()

    lr_sched = mx.lr_scheduler.FactorScheduler(2000, factor=0.6, base_lr=1.0)
    optim = mx.optimizer.SGD(learning_rate=lr,
                             momentum=0.9,
                             wd=0.0001,
                             lr_scheduler=lr_sched)
    trainer = gluon.Trainer(net.collect_params(), optim)

    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    train_acc_meter = mx.metric.Accuracy()
    train_loss_meter = mx.metric.CrossEntropy()

    hybridized = False

    with mxboard.SummaryWriter(logdir="./vgg_logs", flush_secs=60) as sw:
        for ep in range(1, epoch + 1):
            epoch_start = timeit.default_timer()

            train_acc_meter.reset()
            train_loss_meter.reset()

            print("Current Learning Rate: {}".format(trainer.learning_rate))
            for it, (data, label) in enumerate(train_data):
                data = data.as_in_context(ctx)
                label = label.as_in_context(ctx)

                with autograd.record():
                    output = net(data)
                    loss_val = loss_fn(output, label)
                loss_val.backward()
                trainer.step(data.shape[0])

                train_acc_meter.update(preds=[output], labels=[label])
                train_loss_meter.update(labels=[label],
                                        preds=[nd.softmax(output, axis=1)])

                if it % 10 == 0:
                    print(
                        "Epoch {}, batch {}, train loss {:.4f}, train acc {:.4f}"
                        .format(ep, it,
                                train_loss_meter.get()[1],
                                train_acc_meter.get()[1]))

            nd.waitall()
            epoch_stop = timeit.default_timer()

            val_loss, val_acc = evaluate(val_data, net, ctx)
            nd.waitall()
            print(
                "Epoch {}, Training time {}, learning rate {}, validation loss {:.5f}, validatoin acc {:.5f}"
                .format(ep, epoch_stop - epoch_start, trainer.learning_rate,
                        val_loss, val_acc))
            sw.add_scalar(tag="train_loss",
                          value=train_loss_meter.get()[1],
                          global_step=ep)
            sw.add_scalar(tag="train_acc",
                          value=train_acc_meter.get()[1],
                          global_step=ep)
            sw.add_scalar(tag="val_acc", value=val_acc, global_step=ep)
            sw.add_scalar(tag="val_loss", value=val_loss, global_step=ep)
            sw.add_scalar(tag="learning_rate",
                          value=trainer.learning_rate,
                          global_step=ep)
            if not hybridized:
                sw.add_graph(net)
                hybridized = True

            if ep % 2 == 0:
                net.export("vgg_models/vgg", ep)

    return net
Example #9
        SigmoidBinaryCrossEntropyLoss(from_sigmoid=False, batch_axis=0),
        input_dim, args.latent_dim)

    # optimizer
    trainer = Trainer(params=model_params,
                      optimizer='adam',
                      optimizer_params={'learning_rate': args.learning_rate})

    # forward function for training
    def forward_fn(batch):
        x = batch.data[0].as_in_context(ctx)
        y, q = vae_nn(x)
        loss = loss_fn(x, q, y)
        return loss

    # train
    run_id = train(forward_fn, train_iter, val_iter, trainer,
                   args.num_train_samples, args.num_val_samples, args.val_freq,
                   args.logdir)

    # generate latent space figure if latent dim = 2
    sw = mxboard.SummaryWriter(logdir=os.path.join(args.logdir, run_id))
    if args.latent_dim == 2:
        img = generate_2d_latent_space_image(vae_nn,
                                             val_iter,
                                             input_shape,
                                             n=20,
                                             ctx=ctx)
        sw.add_image('2D_Latent_space', img)
    sw.close()
Example #10
send_every_n = 50


best_test_loss = 10e20


log_dir = './logs/text_denoising'
checkpoint_dir = "model_checkpoint"
checkpoint_name = key + ".params"
sw = mxboard.SummaryWriter(logdir=log_dir, flush_secs=1)


# Creating network

net = Denoiser(alphabet_size=len(ALPHABET), max_src_length=FEATURE_LEN,
               max_tgt_length=FEATURE_LEN, num_heads=num_heads,
               embed_size=embed_size, num_layers=num_layers)
net.initialize(mx.init.Xavier(), ctx)


# Preparing the loss

Example #11
def Train():
    N = k_at_hop[0] * (k_at_hop[1] + 1) + 1
    gcn = model.GCN()
    #mlp = model.MLP()

    gcn.collect_params().initialize(mx.init.Normal(0.01), ctx=ctx)
    #mlp.collect_params().initialize(mx.init.Normal(0.01), ctx=ctx)

    # hybridize
    gcn.hybridize()
    #mlp.hybridize()

    trainer_gcn = gluon.Trainer(gcn.collect_params(), 'adam',
                                {'learning_rate': 0.001})  #, 'beta1': 0.5
    #trainer_mlp = gluon.Trainer(mlp.collect_params(), 'adam', {'learning_rate': 0.001}) #, 'beta1': 0.5

    crit = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)

    tick = time.time()
    ts = time.localtime(tick)
    stamp = time.strftime('%Y%m%d%H%M%S', ts)
    with mxboard.SummaryWriter(logdir='logs/' + stamp) as sw:
        iternum = 0
        for epoch in range(nepoch):
            for feat, A, center_idx, one_hop_idcs, edge_labels in data_loader:
                feat = feat.as_in_context(ctx)
                A = A.as_in_context(ctx)
                center_idx = center_idx.as_in_context(ctx)
                one_hop_idcs = one_hop_idcs.as_in_context(ctx)
                edge_labels = edge_labels.as_in_context(ctx)

                batch_size, _, _ = feat.shape

                w = mx.nd.zeros((batch_size, N), ctx=ctx)
                for b in range(batch_size):
                    w[b][one_hop_idcs[b]] = 1

                with autograd.record():
                    #x = gcn(feat, A)
                    #_,_,dout = x.shape
                    #x = x.reshape(-1, dout)
                    pred = gcn(feat, A)
                    labels = edge_labels.reshape(-1, 1)
                    loss = crit(pred, labels).reshape((batch_size, -1)) * w
                loss.backward()

                trainer_gcn.step(batch_size)
                pred = pred.reshape(batch_size, -1)
                labels = labels.reshape(batch_size, -1)

                pred_ = mx.nd.zeros((batch_size, 200))
                labels_ = mx.nd.zeros((batch_size, 200))
                for b in range(batch_size):
                    pred_[b] = pred[b][one_hop_idcs[b]]
                    labels_[b] = labels[b][one_hop_idcs[b]]

                lr = trainer_gcn.learning_rate
                p, r, acc = accuracy(pred_.reshape(-1), labels_.reshape(-1))

                sw.add_scalar(tag='Eva',
                              value=('Acc', acc.mean().asscalar()),
                              global_step=iternum)
                sw.add_scalar(tag='Eva', value=('P', p), global_step=iternum)
                sw.add_scalar(tag='Eva', value=('R', r), global_step=iternum)
                sw.add_scalar(tag='Loss',
                              value=('loss', loss.mean().asscalar()),
                              global_step=iternum)
                print("Loss:",
                      loss.mean().asscalar(), "Acc:",
                      acc.mean().asscalar(), "P:", p, "R:", r)

                if iternum % 2000 == 0:
                    #trainer_gcn.set_learning_rate(lr * 0.1)
                    gcn.export('./models/gcn-' + stamp, iternum // 2000)
                    #mlp.export('./models/mlp-' + stamp, iternum // 2000)

                iternum += 1
Example #12
import mxnet as mx, cv2 as cv
from mxnet.gluon import nn, loss as gloss, data as gdata
from mxnet import nd, image, autograd
from utils import utils, predata
from utils.utils import calc_loss, cls_eval, bbox_eval
from utils.train import validate
import matplotlib.pyplot as plt
import fpn
import time, argparse
import mxboard as mxb

sw = mxb.SummaryWriter(logdir='./logs', flush_secs=5)

# parsing cli arguments
parser = argparse.ArgumentParser()
parser.add_argument("-l", "--load", dest="load",
                    help="bool: load model to directly infer rather than training",
                    type=int, default=1)
parser.add_argument("-b", "--base", dest="base",
                    help="bool: using additional base network",
                    type=int, default=0)
parser.add_argument("-e", "--epoches", dest="num_epoches",
                    help="int: trainig epoches",
                    type=int, default=20)
parser.add_argument("-bs", "--batch_size", dest="batch_size",
                    help="int: batch size for training",
                    type=int, default=4)
parser.add_argument("-is", "--imsize", dest="input_size",
                    help="int: input size",
                    type=int, default=256)
Example #13
def train_ResNeXt(net, lr, input_shape, batch_size, train_path, test_path,
                  epoch, ctx):
    train_data, val_data = prepare_data(train_path, test_path, input_shape,
                                        batch_size)

    lr_sched = mx.lr_scheduler.FactorScheduler(step=1000,
                                               factor=0.94,
                                               base_lr=1)
    optim = mx.optimizer.SGD(learning_rate=lr,
                             momentum=0.9,
                             wd=1e-3,
                             lr_scheduler=lr_sched)
    trainer = gluon.Trainer(net.collect_params(), optim)

    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    train_acc_meter = mx.metric.Accuracy()
    train_loss_meter = mx.metric.CrossEntropy()

    hybridized = False

    with mxboard.SummaryWriter(logdir="./resnext_logs", flush_secs=30) as sw:
        for ep in range(1, epoch + 1):
            #train_data.reset()
            #val_data.reset()
            print("Current Learning Rate {}".format(trainer.learning_rate))
            epoch_start = timeit.default_timer()

            train_acc_meter.reset()
            train_loss_meter.reset()

            for it, (data, label) in enumerate(train_data):
                data = data.as_in_context(ctx)
                label = label.as_in_context(ctx)

                with autograd.record():
                    output = net(data)
                    loss_val = loss_fn(output, label)
                    loss_val.backward()
                trainer.step(data.shape[0])

                train_acc_meter.update(preds=[output], labels=[label])
                train_loss_meter.update(labels=[label],
                                        preds=[nd.softmax(output, axis=1)])

                if it % 10 == 0:
                    print(
                        "Epoch {}, batch {}, train loss {:.4f}, train acc {:.4f}"
                        .format(ep, it,
                                train_loss_meter.get()[1],
                                train_acc_meter.get()[1]))

            epoch_stop = timeit.default_timer()

            val_loss, val_acc = evaluate(val_data, net, ctx)
            print(
                "Epoch {}, Training time {}, validation loss {:.5f}, validation acc {:.5f}"
                .format(ep, epoch_stop - epoch_start, val_loss, val_acc))
            sw.add_scalar(tag="train_loss",
                          value=train_loss_meter.get()[1],
                          global_step=ep)
            sw.add_scalar(tag="train_acc",
                          value=train_acc_meter.get()[1],
                          global_step=ep)
            sw.add_scalar(tag="val_acc", value=val_acc, global_step=ep)
            sw.add_scalar(tag="val_loss", value=val_loss, global_step=ep)
            sw.add_scalar(tag="learning_rate",
                          value=trainer.learning_rate,
                          global_step=ep)
            if not hybridized:
                sw.add_graph(net)
                hybridized = True

            if ep % 1 == 0:
                net.export("resnext_models/resnext", ep)

    return net
Example #14
def train(forward_fn: Callable[[mx.io.DataBatch, Dict[str, float]],
                               nd.NDArray],
          train_iter: mx.io.DataIter,
          val_iter: mx.io.DataIter,
          trainer: Trainer,
          num_train_samples: int,
          num_val_samples: int,
          val_freq: int,
          logdir: str,
          run_suffix: str = '',
          validate_at_end: bool = True,
          hyperparam_scheduler: HyperparamScheduler = None,
          plot_callbacks: Tuple[Callable[[mxboard.SummaryWriter, int], None],
                                ...] = tuple()):
    """
    Train the model given by its forward function.

    :param forward_fn: The forward function of the model that takes in a batch and returns loss (over samples in a
        batch)
    :param train_iter: Training data iterator.
    :param val_iter: Validation data iterator.
    :param trainer: Trainer.
    :param num_train_samples: Number of training samples.
    :param num_val_samples: Number of validation samples (per validation).
    :param val_freq: Validation frequency (in number of samples).
    :param logdir: Log directory for mxboard.
    :param run_suffix: Suffix for run id.
    :param validate_at_end: If True, run validation over the full validation set at the end.
    :param hyperparam_scheduler: Hyperparameter scheduler. train calls the scheduler's get_hyperparams method and
        passes the returned dictionary to the forward function.
    :param plot_callbacks: A tuple of additional plotting callbacks, called after each update. A plotting callback
        should expect an mxboard.SummaryWriter and the number of samples processed so far. See DRAW/train_mnist.py
        for an example.
    """
    run_id = '{}_{}'.format(strftime('%Y%m%d%H%M%S'), run_suffix)
    sw = mxboard.SummaryWriter(logdir=os.path.join(logdir, run_id))
    pm = tqdm.tqdm(total=num_train_samples)

    last_val_loss = np.inf
    last_val_time = 0
    samples_processed = 0

    train_params = {}
    if hyperparam_scheduler:
        train_params = hyperparam_scheduler.get_hyperparams(
            samples_processed=0)

    while samples_processed < num_train_samples:
        batch = _get_batch(train_iter)

        # train step
        with autograd.record():
            loss = forward_fn(batch, **train_params)
        autograd.backward(loss)

        batch_size = loss.shape[0]
        trainer.step(batch_size=batch_size)

        samples_processed += batch_size
        last_train_loss = nd.mean(loss).asscalar()  # loss per sample
        # plot loss
        sw.add_scalar('Loss', {'Training': last_train_loss}, samples_processed)
        pm.update(n=batch_size)

        # call plot callbacks
        for callback in plot_callbacks:
            callback(sw, samples_processed)

        # validation step
        if samples_processed - last_val_time >= val_freq:
            last_val_time = samples_processed
            tot_val_loss = 0.0
            j = 0
            while j < num_val_samples:
                batch = _get_batch(val_iter)
                loss = forward_fn(batch, **train_params)
                tot_val_loss += nd.sum(loss).asscalar()
                j += loss.shape[0]

            last_val_loss = tot_val_loss / j  # loss per sample
            sw.add_scalar('Loss', {'Validation': last_val_loss},
                          samples_processed)
            sw.flush()

        # call hyperparam scheduler
        if hyperparam_scheduler:
            train_params = hyperparam_scheduler.get_hyperparams(
                samples_processed=samples_processed)

        pm.set_postfix({
            'Train loss': last_train_loss,
            'Val loss': last_val_loss
        })

    if validate_at_end:
        # calculate loss on the whole validation set
        tot_val_loss = 0.0
        j = 0
        for batch in val_iter:
            loss = forward_fn(batch, **train_params)
            tot_val_loss += nd.sum(loss).asscalar()
            j += loss.shape[0]

        last_val_loss = tot_val_loss / j  # loss per sample
        sw.add_scalar('Loss', {'Validation_final': last_val_loss},
                      samples_processed)
        pm.set_postfix({
            'Train loss': last_train_loss,
            'Val loss': last_val_loss
        })

    pm.close()
    sw.flush()
    sw.close()

    return run_id
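
The plot_callbacks hook described in the docstring can be exercised with a small factory like the sketch below; logging the trainer's learning rate is just an illustrative choice, not something the original code does:

def make_lr_callback(trainer: Trainer):
    """Build a plotting callback matching the documented contract: it receives
    the mxboard.SummaryWriter and the number of samples processed so far."""
    def callback(sw: mxboard.SummaryWriter, samples_processed: int) -> None:
        sw.add_scalar('LearningRate', trainer.learning_rate, samples_processed)
    return callback

# e.g. train(forward_fn, train_iter, val_iter, trainer, num_train_samples,
#            num_val_samples, val_freq, logdir,
#            plot_callbacks=(make_lr_callback(trainer),))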