def top_k_error(target_action,
                target_action_type,
                target_action_mask,
                rule_prob,
                terminal_gen_action_prob,
                token_prob,
                copy_prob,
                k=5):
    batch_size, max_action_length, _ = target_action.shape
    _, _, rule_num = rule_prob.shape
    _, _, token_num = token_prob.shape
    _, _, max_query_length = copy_prob.shape

    # (batch_size, max_action_length)
    rule_mask, token_mask, copy_mask = F.split(target_action_type, axis=2)

    # (batch_size, max_action_length)
    target_rule, target_token, target_copy = F.split(target_action, axis=2)
    target_rule = F.reshape(target_rule, (batch_size, max_action_length, 1))

    # (batch_size, max_action_length)
    gen_token_prob, copy_token_prob = F.split(terminal_gen_action_prob, axis=2)
    gen_token_prob = F.reshape(gen_token_prob,
                               (batch_size, max_action_length, 1))
    gen_token_prob = F.broadcast(gen_token_prob,
                                 (batch_size, max_action_length, token_num))
    copy_token_prob = F.reshape(copy_token_prob,
                                (batch_size, max_action_length, 1))
    copy_token_prob = F.broadcast(
        copy_token_prob, (batch_size, max_action_length, max_query_length))
    # (batch_size, max_action_length, token_num)
    token_prob = gen_token_prob * token_prob
    # (batch_size, max_action_length, max_query_length)
    copy_prob = copy_token_prob * copy_prob
    # (batch_size, max_action_length, token_num + max_query_length)
    gen_or_copy = F.concatenate(token_prob, copy_prob, axis=2)

    # (batch_size, max_action_length)
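    # Copy targets are offset by token_num so that generated-token indices and
    # copied-position indices share a single label space over gen_or_copy.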
    token_label = token_mask * target_token + (copy_mask *
                                               (target_copy + token_num))
    token_label = F.reshape(token_label, (batch_size, max_action_length, 1))

    # (batch_size, max_action_length, 1)
    rule_err = F.top_n_error(rule_prob, target_rule, axis=2, n=k)
    rule_err = F.reshape(rule_err, (batch_size, max_action_length))
    # (batch_size, max_action_length, 1)
    token_err = F.top_n_error(gen_or_copy, token_label, axis=2, n=k)
    token_err = F.reshape(token_err, (batch_size, max_action_length))

    # (batch_size, max_action_length)
    err = rule_mask * rule_err + (token_mask + copy_mask) * token_err
    # (batch_size,)
    num = F.sum(rule_mask, axis=1) + F.sum(token_mask, axis=1) + F.sum(
        copy_mask, axis=1)
    # (batch_size,)
    err = F.sum(err, axis=1)
    # (batch_size,)
    err = err / (num + 1e-7)
    return F.mean(err)
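
# As a minimal, self-contained sketch of the F.top_n_error building block used
# throughout these examples (shapes and values below are made up purely for
# illustration): the result is 1.0 for every sample whose true class is not
# among the top-n scores, and 0.0 otherwise.
import numpy as np
import nnabla as nn
import nnabla.functions as F

pred_demo = nn.Variable((4, 10))    # class scores for 4 samples, 10 classes
label_demo = nn.Variable((4, 1))    # integer class labels
err_demo = F.mean(F.top_n_error(pred_demo, label_demo, n=5))

pred_demo.d = np.random.rand(4, 10)
label_demo.d = np.random.randint(0, 10, size=(4, 1))
err_demo.forward()
print(err_demo.d)  # fraction of samples misclassified under the top-5 criterion
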
def meta_test(args, shape_x, test_data):

    # Build episode generators
    test_episode_generator = EpisodeGenerator(
        test_data[0], test_data[1], args.n_class, args.n_shot, args.n_query)

    # Build prototypical network
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding,
               args.net_type, args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Load parameters
    nn.load_parameters(args.work_dir + "/params.h5")

    # Evaluate error rate
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
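    # 1.96 is the z-value for a 95% confidence interval of the mean error.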
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)

    # Monitor error rate
    monitor = Monitor(args.work_dir)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    return v_err_mean, v_err_conf
Example #3
def net(input, label, bn_batch_stat, args, init_params=None):
    output = forward_conv(input, bn_batch_stat, args, init_params)
    loss = loss_func(output, label)
    output2 = output.get_unlinked_variable(need_grad=False)
    accuracy = 1.0 - F.mean(F.top_n_error(output2, label, n=1))

    return (loss, accuracy)
Example #4
def get_model(args,
              num_classes,
              test=False,
              channel_last=False,
              with_error=True):
    """
    Create computation graph and variables.
    """
    nn_in_size = 224
    if channel_last:
        image = nn.Variable([args.batch_size, nn_in_size, nn_in_size, 4])
    else:
        image = nn.Variable([args.batch_size, 4, nn_in_size, nn_in_size])
    label = nn.Variable([args.batch_size, 1])
    pred, hidden = model_resnet_nhwc.resnet_imagenet(image,
                                                     num_classes,
                                                     args.num_layers,
                                                     args.shortcut_type,
                                                     test=test,
                                                     tiny=False,
                                                     channel_last=channel_last)
    pred.persistent = True
    loss = F.mean(loss_function(pred, label, args.label_smoothing))
    error = F.sum(F.top_n_error(pred, label, n=1))
    Model = namedtuple('Model',
                       ['image', 'label', 'pred', 'loss', 'error', 'hidden'])
    return Model(image, label, pred, loss, error, hidden)
Example #5
def test_top_n_error_forward(seed, axis, n, ctx, func_name):
    ishape = [5, 6, 7]
    rng = np.random.RandomState(seed)

    l_shape = list(ishape)
    l_shape[axis] = 1
    n_class = ishape[axis]

    inputs = [
        rng.rand(5, 6, 7).astype(np.float32) * 0.9 + 0.05,
        rng.randint(0, n_class, size=l_shape).astype(np.int32)
    ]

    ref = ref_top_n_error(inputs[0], inputs[1], axis, n)

    x = nn.Variable(ishape)
    l = nn.Variable(l_shape)
    y = F.top_n_error(x, l, axis, n)
    x.d = inputs[0]
    l.d = inputs[1]
    y.forward()
    res = y.d

    atol_f = 1e-6
    assert_allclose(ref, res, atol=atol_f)
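
# The test above relies on a NumPy reference, ref_top_n_error, which is not
# shown here. A possible reference with the semantics implied by the test (an
# element is an error when n or more classes score strictly higher than the
# labeled class) could look like the sketch below; this is an assumption, not
# the repository's actual implementation.
import numpy as np

def ref_top_n_error_sketch(x, label, axis, n):
    # Score of the labeled class, keeping the class axis with size 1.
    label_score = np.take_along_axis(x, label.astype(np.int64), axis=axis)
    # Number of classes scoring strictly higher than the labeled class.
    rank = (x > label_score).sum(axis=axis, keepdims=True)
    # Error (1.0) when the labeled class is not within the top n scores.
    return (rank >= n).astype(np.float32)
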
Example #6
def main():
    # Context
    ctx = get_extension_context("cudnn", device_id="0")
    nn.set_default_context(ctx)
    nn.auto_forward(False)
    # Inputs
    b, c, h, w = 64, 1, 28, 28
    x = nn.Variable([b, c, h, w])
    t = nn.Variable([b, 1])
    vx = nn.Variable([b, c, h, w])
    vt = nn.Variable([b, 1])
    # Model
    model = Model()
    pred = model(x)
    loss = F.softmax_cross_entropy(pred, t)
    vpred = model(vx, test=True)
    verror = F.top_n_error(vpred, vt)
    # Solver
    solver = S.Adam()
    solver.set_parameters(model.get_parameters(grad_only=True))
    # Data Iterator
    tdi = data_iterator_mnist(b, train=True)
    vdi = data_iterator_mnist(b, train=False)
    # Monitor
    monitor = Monitor("tmp.monitor")
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Training loop
    for e in range(1):
        for j in range(tdi.size // b):
            i = e * tdi.size // b + j
            x.d, t.d = tdi.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            monitor_loss.add(i, loss.d)
        error = 0.0
        for _ in range(vdi.size // b):
            vx.d, vt.d = vdi.next()
            verror.forward(clear_buffer=True)
            error += verror.d
        error /= vdi.size // b
        monitor_verr.add(i, error)
Example #7
    def metrics(self, outputs, targets):
        r"""Return a dictionary of metrics to monitor during training.

        It is expected to have a 1:1 mapping between the
        model outputs and targets variables.

        Args:
            outputs (list of nn.Variable):
                A list of output variables computed from the model.
            targets (list of nn.Variable):
                A list of target variables loaded from the data.

        Returns:
            dict: A dictionary containing all metrics (nn.Variable) to monitor
                E.g., {'error': nn.Variable((1,)), 'F1': nn.Variable((1,))}
        """
        assert len(targets) == 1

        return {"error": F.mean(F.top_n_error(outputs[0], targets[0]))}
Example #8
    def __call__(self, args, input_ids, attention_mask=None, token_type_ids=None,
                 position_ids=None, head_mask=None, labels=None, num_labels=2,
                 vocab_size=30522, num_embed_dim=768, num_pos_ids=512,
                 num_attention_layers=12, num_attention_embed_dim=768,
                 num_attention_heads=12, num_attention_dim_feedforward=3072,
                 attention_activation=None, pool_outmap=768, embed_dropout_prob=0.1,
                 attention_dropout_prob=0.1, dropout_prob=0.1,
                 test=True):

        pooled_output = self.bert(args, input_ids,
                                  attention_mask=attention_mask,
                                  token_type_ids=token_type_ids,
                                  position_ids=position_ids,
                                  head_mask=head_mask,
                                  vocab_size=vocab_size,
                                  num_embed_dim=num_embed_dim,
                                  num_pos_ids=num_pos_ids,
                                  num_attention_layers=num_attention_layers,
                                  num_attention_embed_dim=num_attention_embed_dim,
                                  num_attention_heads=num_attention_heads,
                                  num_attention_dim_feedforward=num_attention_dim_feedforward,
                                  attention_activation=attention_activation,
                                  pool_outmap=pool_outmap,
                                  embed_dropout_prob=embed_dropout_prob,
                                  attention_dropout_prob=attention_dropout_prob,
                                  test=test)

        if not test:
            pooled_output = F.dropout(pooled_output, p=dropout_prob)
        logits = PF.affine(pooled_output, num_labels,
                           base_axis=1, name='affine_seq_class')

        label = F.reshape(labels, (-1, 1), inplace=False)
        if args.task_name == "sts-b":
            loss = F.mean((logits-label)**2)
        else:
            loss = F.mean(F.softmax_cross_entropy(logits, label))
        error = F.sum(F.top_n_error(logits, label))

        return loss, logits, error
Example #9
    def __init__(self,
                 solver,
                 tinput=None,
                 tlabel=None,
                 tpred=None,
                 tdata=None,
                 vinput=None,
                 vlabel=None,
                 vpred=None,
                 vdata=None,
                 monitor_path=None,
                 model_save_path=None,
                 max_epoch=1,
                 iter_per_epoch=None,
                 val_iter=None):
        # Monitors
        monitor = Monitor(monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
        monitor_err = MonitorSeries("Training error", monitor, interval=10)
        monitor_vloss = MonitorSeries("Valid loss", monitor, interval=1)
        monitor_verr = MonitorSeries("Valid error", monitor, interval=1)
        monitor_time = MonitorTimeElapsed("Training time",
                                          monitor,
                                          interval=10)

        # Loss and error
        tpred = tpred.apply(persistent=True)
        tloss = F.mean(F.softmax_cross_entropy(tpred, tlabel))
        terror = F.mean(F.top_n_error(tpred.get_unlinked_variable(), tlabel))
        vpred = vpred.apply(persistent=True)
        vloss = F.mean(F.softmax_cross_entropy(vpred, vlabel))
        verror = F.mean(F.top_n_error(vpred.get_unlinked_variable(), vlabel))

        # Updater
        def tdata_feeder():
            tinput.d, tlabel.d = tdata.next()

        def forward_callback_on_finish(i):
            terror.forward()

        def update_callback_on_finish(i):
            monitor_loss.add(i, tloss.d)
            monitor_err.add(i, terror.d)
            monitor_time.add(i)

        updater = Updater(
            solver,
            tloss,
            data_feeder=tdata_feeder,
            forward_callback_on_finish=forward_callback_on_finish,
            update_callback_on_finish=update_callback_on_finish)

        # Evaluator
        def vdata_feeder():
            vinput.d, vlabel.d = vdata.next()

        def vloss_callback_on_finish(i, v):
            monitor_vloss.add(i, v)

        def verror_callback_on_finish(i, v):
            monitor_verr.add(i, v)

        val_iter = val_iter if val_iter is not None else vdata.size // vdata.batch_size
        evaluator = Evaluator([vloss, verror],
                              data_feeder=vdata_feeder,
                              val_iter=val_iter,
                              callback_on_finish=[
                                  vloss_callback_on_finish,
                                  verror_callback_on_finish
                              ])

        # Trainer
        iter_per_epoch = iter_per_epoch if iter_per_epoch is not None \
            else tdata.size // tdata.batch_size
        self.trainer = Trainer(updater,
                               evaluator,
                               model_save_path,
                               max_epoch=max_epoch,
                               iter_per_epoch=iter_per_epoch)
Example #10
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-file", type=str)
    parser.add_argument("--valid-file", type=str)
    parser.add_argument("--num-training-examples", type=int, default=50)
    parser.add_argument("--accum-grad", type=int, default=1)
    parser.add_argument("--valid-interval", type=int, default=200)
    parser.add_argument("--threshold", type=float, default=0.95)
    parser.add_argument("--context", type=str, default="cpu")
    parser.add_argument("--device-id", type=int, default=0)

    args = parser.parse_args()

    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    ctx = get_extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # prepare data iterators
    tdata = data_iterator(
        BAbI15DataSource(args.train_file,
                         args.num_training_examples,
                         shuffle=True), 1, False, False, False)
    vdata = data_iterator(
        BAbI15DataSource(args.valid_file, 1000, shuffle=True), 1, False, False,
        False)

    # prepare monitors
    monitor = M.Monitor("./bAbI15")
    tloss = M.MonitorSeries("Training Loss", monitor, interval=10)
    terror = M.MonitorSeries("Training Error", monitor, interval=10)
    verror = M.MonitorSeries("Validation Error", monitor, interval=1)

    # prepare solver
    solver = S.Adam()
    solver_initialized = False

    cnt = 0
    while True:
        l = 0.0
        e = 0.0

        solver.zero_grad()
        for _ in range(args.accum_grad):
            # read next data
            x = tdata.next()
            V = x[1][0][0]
            E = x[2][0][0]
            ans = x[3][0][0]

            # construct GGNN
            output = predict(V, E)
            output = F.reshape(output, (1, output.shape[0]))

            # initialize solver
            if not solver_initialized:
                solver.set_parameters(nn.get_parameters())
                solver_initialized = True
                solver.zero_grad()

            # calculate loss/error
            label = nn.Variable((1, 1))
            label.data.data[0, 0] = ans
            output2 = output.unlinked()
            loss = F.mean(F.softmax_cross_entropy(output, label))
            error = F.mean(F.top_n_error(output2, label))
            F.sink(loss, error).forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)

            l += loss.data.data
            e += error.data.data

        # dump log
        tloss.add(cnt, l / args.accum_grad)
        terror.add(cnt, e / args.accum_grad)
        l = 0.0
        e = 0.0

        solver.update()

        cnt += 1
        if cnt % args.valid_interval == 0:
            # validation
            validation_error = 0
            correct_example = None
            wrong_example = None
            for _ in range(vdata.size):
                x = vdata.next()
                id2str = x[0][0][0]
                V = x[1][0][0]
                E = x[2][0][0]
                ans = x[3][0][0]

                output = predict(V, E)
                output = F.reshape(output, (1, output.shape[0]))

                # calculate error
                label = nn.Variable((1, 1))
                label.data.data[0, 0] = ans
                error = F.top_n_error(output, label)
                error.forward(clear_no_need_grad=True)

                if error.data.data > 0.5:
                    if wrong_example is None:
                        wrong_example = (id2str, V, E, ans, output.data.data)
                else:
                    if correct_example is None:
                        correct_example = (id2str, V, E, ans, output.data.data)
                validation_error += error.data.data
            validation_error /= vdata.size
            verror.add(cnt, validation_error)
            accuracy = 1 - validation_error
            if accuracy >= args.threshold:

                def show(example):
                    for i, j in example[2]["is"]:
                        print("{} is {}.".format(example[0][i], example[0][j]))
                    for i, j in example[2]["has_fear"]:
                        print("{} are afraid of {}.".format(
                            example[0][i], example[0][j]))
                    i = np.argmax(example[1])
                    print("What is {} afraid of?".format(example[0][i]))
                    i = np.argmax(example[4])
                    print("Expected: {}, Actual: {}".format(
                        example[0][example[3]], example[0][i]))

                if correct_example is not None:
                    show(correct_example)
                if wrong_example is not None:
                    show(wrong_example)

                break
def meta_train(args, train_data, valid_data, test_data):

    # Build episode generators
    shape_x = (1, 28, 28)
    train_episode_generator = EpisodeGenerator(args.n_class_tr, args.n_shot_tr,
                                               args.n_query_tr, shape_x,
                                               train_data)
    valid_episode_generator = EpisodeGenerator(args.n_class, args.n_shot,
                                               args.n_query, shape_x,
                                               valid_data)
    test_episode_generator = EpisodeGenerator(args.n_class, args.n_shot,
                                              args.n_query, shape_x, test_data)

    # Build training model
    xs_t = nn.Variable((args.n_class_tr * args.n_shot_tr, ) + shape_x)
    xq_t = nn.Variable((args.n_class_tr * args.n_query_tr, ) + shape_x)
    hq_t = net(args.n_class_tr, xs_t, xq_t, args.embedding, args.net_type,
               args.metric, False)
    yq_t = nn.Variable((args.n_class_tr * args.n_query_tr, 1))
    loss_t = F.mean(F.softmax_cross_entropy(hq_t, yq_t))

    # Build evaluation model
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding, args.net_type,
               args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Setup solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor outputs
    monitor = Monitor(args.work_dir)
    monitor_loss = MonitorSeries("Training loss",
                                 monitor,
                                 interval=args.iter_per_epoch)
    monitor_valid_err = MonitorSeries("Validation error",
                                      monitor,
                                      interval=args.iter_per_valid)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)

    # Output files
    param_file = args.work_dir + "params.h5"
    tsne_file = args.work_dir + "tsne.png"

    # Training loop
    train_losses = []
    best_err = 1.0
    for i in range(args.max_iteration):

        # Decay learning rate
        if (i + 1) % args.lr_decay_interval == 0:
            solver.set_learning_rate(solver.learning_rate() * args.lr_decay)

        # Create an episode
        xs_t.d, xq_t.d, yq_t.d = train_episode_generator.next()

        # Training by the episode
        solver.zero_grad()
        loss_t.forward(clear_no_need_grad=True)
        loss_t.backward(clear_buffer=True)
        solver.update()
        train_losses.append(loss_t.d.copy())

        # Evaluation
        if (i + 1) % args.iter_per_valid == 0:
            train_loss = np.mean(train_losses)
            train_losses = []
            valid_errs = []
            for k in range(args.n_episode_for_valid):
                xs_v.d, xq_v.d, yq_v.d = valid_episode_generator.next()
                err_v.forward(clear_no_need_grad=True, clear_buffer=True)
                valid_errs.append(float(err_v.d.copy()))
            valid_err = np.mean(valid_errs)

            monitor_valid_err.add(i + 1, valid_err * 100)
            if valid_err < best_err:
                best_err = valid_err
                nn.save_parameters(param_file)

    # Final evaluation
    nn.load_parameters(param_file)
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    # Visualization
    n_class = 50
    n_sample = 20
    batch = test_data[:n_class].reshape(n_class * n_sample, 1, 28, 28)
    label = []
    for i in range(n_class):
        label.extend(np.ones(n_sample) * (i % 50))
    u = get_embeddings(batch, conv4)
    v = get_tsne(u)
    plot_tsne(v[:, 0], v[:, 1], label, tsne_file)
Example #12
def train():
    """
    Main script.
    """

    args = get_args()

    _ = nn.load_parameters(args.pretrained_model_path)
    if args.fine_tune:
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/W')
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/b')

    n_train_samples = args.train_samples
    n_val_samples = args.val_samples
    distributed = args.distributed
    compute_acc = args.compute_acc

    if distributed:
        # Communicator and Context
        from nnabla.ext_utils import get_extension_context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        from nnabla.ext_utils import get_extension_context
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id, type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        device_id = 0

    # training data
    data = data_iterator_segmentation(
            args.train_samples, args.batch_size, args.train_dir, args.train_label_dir, target_width=args.image_width, target_height=args.image_height)
    # validation data
    vdata = data_iterator_segmentation(args.val_samples, args.batch_size, args.val_dir,
                                       args.val_label_dir, target_width=args.image_width, target_height=args.image_height)

    if distributed:
        data = data.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
    num_classes = args.num_class

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(
        args, test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1)
                * t_model.mask) / F.sum(t_model.mask)

    v_model = get_model(
        args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1)
                * v_model.mask) / F.sum(v_model.mask)

    # Create Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Load checkpoint
    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed(
        "Validation time", monitor, interval=1)

    # save_nnp
    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result_epoch0.nnp'), contents, variable_batch_size=False)

    # Training loop
    for i in range(start_point, int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            save_checkpoint(args.model_save_path, i, solver)
        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            vmiou_local = 0.
            val_iter_local = n_val_samples // args.batch_size
            vl_local = nn.NdArray()
            vl_local.zero()
            ve_local = nn.NdArray()
            ve_local.zero()
            for j in range(val_iter_local):
                images, labels, masks = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.mask.d = masks
                v_model.image.data.cast(np.float32, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.data
                ve_local += v_e.data
                # Mean IOU computation
                if compute_acc:
                    vmiou_local += compute_miou(num_classes, labels,
                                                np.argmax(v_model.pred.d, axis=1), masks)

            vl_local /= val_iter_local
            ve_local /= val_iter_local
            if compute_acc:
                vmiou_local /= val_iter_local
                vmiou_ndarray = nn.NdArray.from_numpy_array(
                    np.array(vmiou_local))
            if distributed:
                comm.all_reduce(vl_local, division=True, inplace=True)
                comm.all_reduce(ve_local, division=True, inplace=True)
                if compute_acc:
                    comm.all_reduce(vmiou_ndarray, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl_local.data.copy())
                monitor_verr.add(i * n_devices, ve_local.data.copy())
                if compute_acc:
                    monitor_miou.add(i * n_devices, vmiou_local)
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        e_acc = nn.NdArray(t_e.shape)
        e_acc.zero()
        l_acc = nn.NdArray(t_model.loss.shape)
        l_acc.zero()
        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels, masks = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.mask.d = masks
            t_model.image.data.cast(np.float32, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            e_acc += t_e.data
            l_acc += t_model.loss.data

        # AllReduce
        if distributed:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=False)
            comm.all_reduce(l_acc, division=True, inplace=True)
            comm.all_reduce(e_acc, division=True, inplace=True)
        solver.scale_grad(1./args.accum_grad)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if distributed:
            # Synchronize by averaging the weights over devices using allreduce
            if (i+1) % args.sync_weight_every_itr == 0:
                weights = [x.data for x in nn.get_parameters().values()]
                comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(
                i * n_devices, (l_acc / args.accum_grad).data.copy())
            monitor_err.add(
                i * n_devices, (e_acc / args.accum_grad).data.copy())
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter --> changed to poly learning rate decay policy
        # if i in args.learning_rate_decay_at:
        solver.set_learning_rate(base_lr * ((1 - i / args.max_iter)**0.1))

    if device_id == 0:
        nn.save_parameters(os.path.join(args.model_save_path,
                                        'param_%06d.h5' % args.max_iter))

    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result.nnp'), contents, variable_batch_size=False)
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error

    """

    args = get_args()
    n_train_samples = 1281167
    num_classes = 1000

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Pipelines and Iterators for training
    train_pipes = [
        TrainPipeline(args.batch_size,
                      args.num_threads,
                      device_id,
                      args.train_cachefile_dir,
                      args.train_list,
                      seed=device_id + 1,
                      num_gpu=n_devices,
                      random_area=args.random_area)
    ]
    train_pipes[0].build()
    data = DALIClassificationIterator(train_pipes,
                                      train_pipes[0].epoch_size("Reader") //
                                      n_devices,
                                      auto_reset=True,
                                      stop_at_epoch=False)
    # Pipelines and Iterators for validation
    val_pipes = [
        ValPipeline(args.batch_size,
                    args.num_threads,
                    device_id,
                    args.val_cachefile_dir,
                    args.val_list,
                    seed=device_id + 1,
                    num_gpu=n_devices)
    ]
    val_pipes[0].build()
    vdata = DALIClassificationIterator(val_pipes,
                                       val_pipes[0].epoch_size("Reader") //
                                       n_devices,
                                       auto_reset=True,
                                       stop_at_epoch=False)
    # Network for training
    t_model = get_model(args,
                        num_classes,
                        n_devices,
                        args.accum_grad,
                        test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.get_unlinked_variable(need_grad=False)
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        n_devices,
                        args.accum_grad,
                        test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.get_unlinked_variable(need_grad=False)
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_learning_rate(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Training loop
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                nextImage, nextLabel = vdata.next()
                v_model.image.data = nextImage
                v_model.label.data = nextLabel
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            nextImage, nextLabel = data.next()
            t_model.image.data = nextImage
            t_model.label.data = nextLabel
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        # AllReduce
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
def meta_train(args, shape_x, train_data, valid_data, test_data):

    # Build episode generators
    train_episode_generator = EpisodeGenerator(
        train_data[0], train_data[1], args.n_class_tr, args.n_shot_tr, args.n_query_tr)
    valid_episode_generator = EpisodeGenerator(
        valid_data[0], valid_data[1], args.n_class, args.n_shot, args.n_query)
    test_episode_generator = EpisodeGenerator(
        test_data[0], test_data[1], args.n_class, args.n_shot, args.n_query)

    # Build training model
    xs_t = nn.Variable((args.n_class_tr * args.n_shot_tr, ) + shape_x)
    xq_t = nn.Variable((args.n_class_tr * args.n_query_tr, ) + shape_x)
    hq_t = net(args.n_class_tr, xs_t, xq_t, args.embedding,
               args.net_type, args.metric, False)
    yq_t = nn.Variable((args.n_class_tr * args.n_query_tr, 1))
    loss_t = F.mean(F.softmax_cross_entropy(hq_t, yq_t))

    # Build evaluation model
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding,
               args.net_type, args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Setup solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor outputs
    monitor = Monitor(args.work_dir)
    monitor_loss = MonitorSeries(
        "Training loss", monitor, interval=args.iter_per_epoch)
    monitor_valid_err = MonitorSeries(
        "Validation error", monitor, interval=args.iter_per_valid)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)

    # Output files
    param_file = args.work_dir + "/params.h5"
    tsne_file = args.work_dir + "/tsne.png"

    # Save NNP
    batch_size = 1
    contents = save_nnp({'x0': xs_v, 'x1': xq_v}, {
                          'y': hq_v}, batch_size)
    save.save(os.path.join(args.work_dir,
                           'MetricMetaLearning_epoch0.nnp'), contents, variable_batch_size=False)

    # Training loop
    train_losses = []
    best_err = 1.0
    for i in range(args.max_iteration):

        # Decay learning rate
        if (i + 1) % args.lr_decay_interval == 0:
            solver.set_learning_rate(solver.learning_rate() * args.lr_decay)

        # Create an episode
        xs_t.d, xq_t.d, yq_t.d = train_episode_generator.next()

        # Training by the episode
        solver.zero_grad()
        loss_t.forward(clear_no_need_grad=True)
        loss_t.backward(clear_buffer=True)
        solver.update()
        train_losses.append(loss_t.d.copy())

        # Evaluation
        if (i + 1) % args.iter_per_valid == 0:
            train_loss = np.mean(train_losses)
            train_losses = []
            valid_errs = []
            for k in range(args.n_episode_for_valid):
                xs_v.d, xq_v.d, yq_v.d = valid_episode_generator.next()
                err_v.forward(clear_no_need_grad=True, clear_buffer=True)
                valid_errs.append(float(err_v.d.copy()))
            valid_err = np.mean(valid_errs)

            monitor_loss.add(i + 1, loss_t.d.copy())
            monitor_valid_err.add(i + 1, valid_err * 100)
            if valid_err < best_err:
                best_err = valid_err
                nn.save_parameters(param_file)

    # Final evaluation
    nn.load_parameters(param_file)
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    # Visualization
    n_class = 50
    n_sample = 20
    visualize_episode_generator = EpisodeGenerator(
        train_data[0], train_data[1], n_class, 0, n_sample)
    _, samples, labels = visualize_episode_generator.next()
    u = get_embeddings(samples, conv4)
    v = get_tsne(u)
    plot_tsne(v[:, 0], v[:, 1], labels[:, 0], tsne_file)

    # Save NNP
    contents = save_nnp({'x0': xs_v, 'x1': xq_v}, {
                          'y': hq_v}, batch_size)
    save.save(os.path.join(args.work_dir,
                           'MetricMetaLearning.nnp'), contents, variable_batch_size=False)
Example #15
def main():
    """
    Main script.

    Steps:
    * Get and set context.
    * Load Dataset
    * Initialize DataIterator.
    * Create Networks
    *   Net for Labeled Data
    *   Net for Unlabeled Data
    *   Net for Test Data
    * Create Solver.
    * Training Loop.
    *   Test
    *   Training
    *     by Labeled Data
    *       Calculate Supervised Loss
    *     by Unlabeled Data
    *       Calculate Virtual Adversarial Noise
    *       Calculate Unsupervised Loss
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST Dataset
    from mnist_data import load_mnist, data_iterator_mnist
    images, labels = load_mnist(train=True)
    rng = np.random.RandomState(706)
    inds = rng.permutation(len(images))

    def feed_labeled(i):
        j = inds[i]
        return images[j], labels[j]

    def feed_unlabeled(i):
        j = inds[i]
        return images[j], labels[j]

    di_l = data_iterator_simple(feed_labeled,
                                args.n_labeled,
                                args.batchsize_l,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False)
    di_u = data_iterator_simple(feed_unlabeled,
                                args.n_train,
                                args.batchsize_u,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False)
    di_v = data_iterator_mnist(args.batchsize_v, train=False)

    # Create networks
    # feed-forward-net building function
    def forward(x, test=False):
        return mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data
    xl = nn.Variable((args.batchsize_l, ) + shape_x, need_grad=False)
    yl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(yl, tl))

    # Net for learning unlabeled data
    xu = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=False)
    yu = forward(xu, test=False)
    y1 = yu.get_unlinked_variable()
    y1.need_grad = False

    noise = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=True)
    r = noise / (F.sum(noise**2, [1, 2, 3], keepdims=True))**0.5
    r.persistent = True
    y2 = forward(xu + args.xi_for_vat * r, test=False)
    y3 = forward(xu + args.eps_for_vat * r, test=False)
    loss_k = F.mean(distance(y1, y2))
    loss_u = F.mean(distance(y1, y3))

    # Net for evaluating validation data
    xv = nn.Variable((args.batchsize_v, ) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)
    err = F.mean(F.top_n_error(hv, tv, n=1))

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    import nnabla.monitor as M
    monitor = M.Monitor(args.model_save_path)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)

    # Training Loop.
    t0 = time.time()

    for i in range(args.max_iter):

        # Validation Test
        if i % args.val_interval == 0:
            valid_error = calc_validation_error(di_v, xv, tv, err,
                                                args.val_iter)
            monitor_verr.add(i, valid_error)

        #################################
        ## Training by Labeled Data #####
        #################################

        # forward, backward and update
        xl.d, tl.d = di_l.next()
        xl.d = xl.d / 255
        solver.zero_grad()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        #################################
        ## Training by Unlabeled Data ###
        #################################

        # Calculate y without noise, only once.
        xu.d, _ = di_u.next()
        xu.d = xu.d / 255
        yu.forward(clear_buffer=True)

        ##### Calculate Adversarial Noise #####
        # Do power method iteration
        noise.d = np.random.normal(size=xu.shape).astype(np.float32)
        for k in range(args.n_iter_for_power_method):
            r.grad.zero()
            loss_k.forward(clear_no_need_grad=True)
            loss_k.backward(clear_buffer=True)
            noise.data.copy_from(r.grad)

        ##### Calculate loss for unlabeled data #####
        # forward, backward and update
        solver.zero_grad()
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        ##### Learning rate update #####
        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(solver.learning_rate() *
                                     args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with validation dataset
    valid_error = calc_validation_error(di_v, xv, tv, err, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    # Save the model.
    parameter_file = os.path.join(args.model_save_path,
                                  'params_%06d.h5' % args.max_iter)
    nn.save_parameters(parameter_file)
def train_and_eval():

    # Settings
    args = get_args()
    n_class = args.n_class
    n_shot = args.n_shot
    n_query = args.n_query
    n_class_tr = args.n_class_tr
    n_shot_tr = args.n_shot_tr
    if n_shot_tr == 0:
        n_shot_tr = n_shot
    n_query_tr = args.n_query_tr
    if n_query_tr == 0:
        n_query_tr = n_query

    dataset = args.dataset
    dataset_root = args.dataset_root

    init_type = args.init_type
    embedding = args.embedding
    net_type = args.net_type
    metric = args.metric

    max_iteration = args.max_iteration
    lr_decay_interval = args.lr_decay_interval
    lr_decay = args.lr_decay
    iter_per_epoch = args.iter_per_epoch
    iter_per_valid = args.iter_per_valid
    n_episode_for_valid = args.n_episode_for_valid
    n_episode_for_test = args.n_episode_for_test
    work_dir = args.work_dir

    # Set context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Monitor outputs
    from nnabla.monitor import Monitor, MonitorSeries
    monitor = Monitor(args.work_dir)
    monitor_loss = MonitorSeries("Training loss",
                                 monitor,
                                 interval=iter_per_epoch)
    monitor_valid_err = MonitorSeries("Validation error",
                                      monitor,
                                      interval=iter_per_valid)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)

    # Output files
    param_file = work_dir + "params.h5"
    tsne_file = work_dir + "tsne.png"

    # Load data
    shape_x = (1, 28, 28)
    train_data, valid_data, test_data = load_omniglot(dataset_root +
                                                      "/omniglot/data/")
    train_episode_generator = EpisodeGenerator(n_class_tr, n_shot_tr,
                                               n_query_tr, shape_x, train_data)
    valid_episode_generator = EpisodeGenerator(n_class, n_shot, n_query,
                                               shape_x, valid_data)
    test_episode_generator = EpisodeGenerator(n_class, n_shot, n_query,
                                              shape_x, test_data)

    # Build training model
    xs_t = nn.Variable((n_class_tr * n_shot_tr, ) + shape_x)
    xq_t = nn.Variable((n_class_tr * n_query_tr, ) + shape_x)
    hq_t = net(n_class_tr, xs_t, xq_t, init_type, embedding, net_type, metric,
               False)
    yq_t = nn.Variable((n_class_tr * n_query_tr, 1))
    loss_t = F.mean(F.softmax_cross_entropy(hq_t, yq_t))

    # Build evaluation model
    xs_v = nn.Variable((n_class * n_shot, ) + shape_x)
    xq_v = nn.Variable((n_class * n_query, ) + shape_x)
    hq_v = net(n_class, xs_v, xq_v, init_type, embedding, net_type, metric,
               True)
    yq_v = nn.Variable((n_class * n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Setup solver
    solver = S.Adam(1.0e-3)
    solver.set_parameters(nn.get_parameters())
    learning_rate_decay_activate = True

    # Training loop
    train_losses = []
    best_err = 1.0
    for i in range(max_iteration):

        # Decay learning rate
        if learning_rate_decay_activate and ((i + 1) % lr_decay_interval == 0):
            solver.set_learning_rate(solver.learning_rate() * lr_decay)

        # Create an episode
        xs_t.d, xq_t.d, yq_t.d = train_episode_generator.next()

        # Training by the episode
        solver.zero_grad()
        loss_t.forward(clear_no_need_grad=True)
        loss_t.backward(clear_buffer=True)
        solver.update()
        train_losses.append(loss_t.d.copy())

        # Evaluation
        if (i + 1) % iter_per_valid == 0:
            train_loss = np.mean(train_losses)
            train_losses = []
            valid_errs = []
            for k in range(n_episode_for_valid):
                xs_v.d, xq_v.d, yq_v.d = valid_episode_generator.next()
                err_v.forward(clear_no_need_grad=True, clear_buffer=True)
                valid_errs.append(float(err_v.d.copy()))
            valid_err = np.mean(valid_errs)

            #monitor_loss.add(i + 1, train_loss)
            monitor_valid_err.add(i + 1, valid_err * 100)
            if valid_err < best_err:
                best_err = valid_err
                nn.save_parameters(param_file)

    # Final evaluation
    nn.load_parameters(param_file)
    v_errs = []
    for k in range(n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))
    v_err = np.mean(v_errs)
    v_err_conf = 1.96 * np.std(v_errs) / np.sqrt(n_episode_for_test)
    monitor_test_err.add(0, v_err * 100)
    monitor_test_conf.add(0, v_err_conf)

    # Visualization
    n_class = 50
    n_sample = 20
    batch = test_data[:n_class].reshape(n_class * n_sample, 1, 28, 28)
    label = []
    for i in range(n_class):
        label.extend(np.ones(n_sample) * (i % 50))
    u = get_embeddings(batch, conv4)
    v = get_tsne(u)
    plot_tsne(v[:, 0], v[:, 1], label, tsne_file)
Example #17
def valid():
    """
    Main script for validation.

    """

    args = get_args()
    n_valid_samples = 50000
    num_classes = 1000
    assert n_valid_samples % args.batch_size == 0, \
        "Set batch_size such that n_valid_samples (50000) can be devided by batch_size. \Batch size is now set as {}".format(
            args.batch_size)

    # Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Pipelines and Iterators for validation
    device_id = int(args.device_id)
    val_pipes = [
        ValPipeline(args.batch_size,
                    args.num_threads,
                    device_id,
                    args.val_cachefile_dir,
                    args.val_list,
                    seed=device_id,
                    num_gpu=1)
    ]
    val_pipes[0].build()
    vdata = DALIClassificationIterator(val_pipes,
                                       val_pipes[0].epoch_size("Reader"),
                                       auto_reset=True,
                                       stop_at_epoch=False)

    # Network for validation
    nn.load_parameters(args.model_load_path)
    v_model = get_model(args, num_classes, 1, args.accum_grad, test=True)
    v_e = F.mean(F.top_n_error(v_model.pred, v_model.label, n=args.top_n))

    # Monitors
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Validation
    ve_local = 0.
    val_iter_local = n_valid_samples // args.batch_size
    for i in range(val_iter_local):
        nextImage, nextLabel = vdata.next()
        v_model.image.data.copy_from(nextImage)
        v_model.label.data.copy_from(nextLabel)
        v_model.image.data.cast(np.float32, ctx)
        v_model.label.data.cast(np.int32, ctx)
        v_e.forward(clear_buffer=True)
        nn.logger.info("validation error is {} at {}-th batch".format(
            v_e.d, i))
        ve_local += v_e.d.copy()
    ve_local /= val_iter_local

    monitor_verr.add(0, ve_local)
    monitor_vtime.add(0)
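# A standalone sketch (not part of the script above) illustrating what
# F.mean(F.top_n_error(pred, label, n=k)) computes: the fraction of samples
# whose true class is not among the k highest-scoring predictions. Shapes and
# values below are made up for illustration and checked against NumPy.
import numpy as np
import nnabla as nn
import nnabla.functions as F


def check_top_k_error_sketch(batch_size=8, num_classes=10, k=5, seed=0):
    pred = nn.Variable((batch_size, num_classes))
    label = nn.Variable((batch_size, 1))
    err = F.mean(F.top_n_error(pred, label, axis=1, n=k))

    rng = np.random.RandomState(seed)
    pred.d = rng.randn(batch_size, num_classes)
    label.d = rng.randint(0, num_classes, size=(batch_size, 1))
    err.forward()

    # Reference: a sample is an error if its label is not in the top-k indices.
    topk = np.argsort(-pred.d, axis=1)[:, :k]
    ref = np.mean([label.d[i, 0] not in topk[i] for i in range(batch_size)])
    assert np.isclose(err.d, ref)
    return float(err.d)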
def CNN_run(args, model):

    data_iterator_train, data_iterator_valid, num_class = \
                get_data_iterator_and_num_class(args)

    channels, image_height, image_width = 3, args.height, args.width
    batch_size = args.batch_size
    initial_model_lr = args.model_lr

    one_epoch = data_iterator_train.size // batch_size
    max_iter = args.epoch * one_epoch
    val_iter = data_iterator_valid.size // batch_size

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=100)
    monitor_err = MonitorSeries("Training error", monitor, interval=100)
    monitor_vloss = MonitorSeries("Test loss", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=100)

    # prepare variables and graph used for test
    image_valid = nn.Variable(
        (batch_size, channels, image_height, image_width))
    label_valid = nn.Variable((batch_size, 1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    pred_valid = construct_networks(args,
                                    image_valid,
                                    model,
                                    num_class,
                                    test=True)
    pred_valid.persistent = True
    loss_valid = loss_function(pred_valid, label_valid)
    top_1e_valid = F.mean(F.top_n_error(pred_valid, label_valid))

    # prepare variables and graph used for training
    image_train = nn.Variable(
        (batch_size, channels, image_height, image_width))
    label_train = nn.Variable((batch_size, 1))
    input_image_train = {"image": image_train, "label": label_train}

    pred_train = construct_networks(args,
                                    image_train,
                                    model,
                                    num_class,
                                    test=False)
    loss_train = loss_function(pred_train, label_train)
    top_1e_train = F.mean(F.top_n_error(pred_train, label_train))

    # prepare solvers
    solver = S.Momentum(initial_model_lr)
    solver.set_parameters(nn.get_parameters())

    # Training-loop
    for i in range(max_iter):
        image, label = data_iterator_train.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        nn.forward_all([loss_train, top_1e_train], clear_no_need_grad=True)

        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, top_1e_train.d.copy())

        if args.lr_control_model:
            new_lr = learning_rate_scheduler(i, max_iter, initial_model_lr, 0)
            solver.set_learning_rate(new_lr)

        solver.zero_grad()
        loss_train.backward(clear_buffer=True)

        if args.with_grad_clip_model:
            for k, v in nn.get_parameters().items():
                v.grad.copy_from(
                    F.clip_by_norm(v.grad, args.grad_clip_value_model))

        # update parameters
        solver.weight_decay(args.weight_decay_model)
        solver.update()

        if i % args.model_save_interval == 0:
            # Validation during training.
            ve = 0.
            vloss = 0.
            for j in range(val_iter):
                v_image, v_label = data_iterator_valid.next()
                input_image_valid["image"].d = v_image
                input_image_valid["label"].d = v_label
                nn.forward_all([loss_valid, top_1e_valid], clear_buffer=True)
                vloss += loss_valid.d.copy()
                ve += top_1e_valid.d.copy()

            ve /= val_iter
            vloss /= val_iter
            monitor_vloss.add(i, vloss)
            monitor_verr.add(i, ve)

            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_{}.h5'.format(i)))

    ve = 0.
    vloss = 0.
    for j in range(val_iter):
        v_image, v_label = data_iterator_valid.next()
        input_image_valid["image"].d = v_image
        input_image_valid["label"].d = v_label
        nn.forward_all([loss_valid, top_1e_valid], clear_buffer=True)
        vloss += loss_valid.d.copy()
        ve += top_1e_valid.d.copy()

    ve /= val_iter
    vloss /= val_iter
    monitor_vloss.add(i, vloss)
    monitor_verr.add(i, ve)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_{}.h5'.format(i)))

    return
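# learning_rate_scheduler(i, max_iter, initial_model_lr, 0) above is a project
# helper that is not shown. The sketch below is one plausible implementation
# matching that call signature (cosine annealing from the initial rate down to
# the final value); this is an assumption, and the actual helper may use a
# different schedule.
import numpy as np


def learning_rate_scheduler_sketch(curr_iter, max_iter, initial_lr, final_lr=0.0):
    # Cosine annealing: initial_lr at iteration 0, final_lr at max_iter.
    cosine = 0.5 * (1.0 + np.cos(np.pi * curr_iter / max_iter))
    return final_lr + (initial_lr - final_lr) * cosine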
def train(args):
    """
    Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    Steps:
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Load checkpoint to resume previous training.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forward prop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * AllReduce the gradients.
      * Solver updates parameters using the gradients computed by backprop and all-reduce.
      * Compute the training error.
    """
    # Create Communicator and Context
    comm = create_communicator(ignore_error=True)
    if comm:
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = comm.local_rank
    else:
        n_devices = 1
        mpi_rank = 0
        device_id = args.device_id

    if args.context == 'cpu':
        import nnabla_ext.cpu
        context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cudnn
        context = nnabla_ext.cudnn.context(device_id=device_id)
    nn.set_default_context(context)

    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size
    iter_per_epoch = int(n_train_samples / args.batch_size / n_devices)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)

    # Create validation graphs
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = iter_per_epoch * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Load checkpoint if a checkpoint file exists.
    start_point = 0
    if args.use_latest_checkpoint:
        files = glob.glob(f'{args.model_save_path}/checkpoint_*.json')
        if len(files) != 0:
            index = max([
                int(n) for n in
                [re.sub(r'.*checkpoint_(\d+).json', '\\1', f) for f in files]
            ])
            # Load weights and solver state info from the specified checkpoint file.
            start_point = load_checkpoint(
                f'{args.model_save_path}/checkpoint_{index}.json', solver)
            print(f'Checkpoint loaded. Resuming from iteration {start_point}.')

    # Create monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator

    # If the data does not exist, it will be downloaded from the server and
    # prepared. When running multiple processes on the same host, the initial
    # data preparation must be done by the representative process (rank 0) on
    # that host.

    # Download dataset by rank-0 process
    if single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Wait for data to be prepared without watchdog
    if comm:
        comm.barrier()

    # Prepare dataset for remaining process
    if not single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(start_point // n_devices, args.epochs * iter_per_epoch):
        # Validation
        if i % iter_per_epoch == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # Note: a smaller final batch is skipped.
                    continue
                image_valid.d = image
                label_valid.d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            if comm:
                comm.all_reduce(ve.data, division=True, inplace=True)

            # Monitoring error and elapsed time
            if single_or_rankzero():
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Save model
        if single_or_rankzero():
            if i % (args.model_save_interval // n_devices) == 0:
                itr = i * n_devices
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'params_%06d.h5' % itr))
                if args.use_latest_checkpoint:
                    save_checkpoint(args.model_save_path, itr, solver)

        # Forward/Zerograd
        image, label = tdata.next()
        image_train.d = image
        label_train.d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Monitoring loss, error and elapsed time
        if single_or_rankzero():
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    # Save nnp last epoch
    if single_or_rankzero():
        runtime_contents = {
            'networks': [{
                'name': 'Validation',
                'batch_size': args.batch_size,
                'outputs': {
                    'y': pred_valid
                },
                'names': {
                    'x': image_valid
                }
            }],
            'executors': [{
                'name': 'Runtime',
                'network': 'Validation',
                'data': ['x'],
                'output': ['y']
            }]
        }
        itr = args.epochs * iter_per_epoch
        nn.save_parameters(
            os.path.join(args.model_save_path, 'params_%06d.h5' % itr))
        nnabla.utils.save.save(
            os.path.join(args.model_save_path, f'{args.net}_result.nnp'),
            runtime_contents)
    if comm:
        comm.barrier()
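# backward_and_all_reduce above is a project helper that is not shown. A minimal
# sketch of the plain (non-callback) path is given below, assuming it simply runs
# backprop and then all-reduces the parameter gradients across devices; the real
# helper may additionally fuse the all-reduce into backward via callbacks.
import nnabla as nn


def backward_and_all_reduce_sketch(loss, comm):
    loss.backward(clear_buffer=True)
    if comm is not None:
        grads = [p.grad for p in nn.get_parameters().values()]
        comm.all_reduce(grads, division=False, inplace=False)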
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-file", type=str)
    parser.add_argument("--valid-file", type=str)
    parser.add_argument("--num-training-examples", type=int, default=250)
    parser.add_argument("--accum-grad", type=int, default=1)
    parser.add_argument("--valid-interval", type=int, default=200)
    parser.add_argument("--threshold", type=float, default=0.95)
    parser.add_argument("--context", type=str, default="cpu")
    parser.add_argument("--device-id", type=int, default=0)

    args = parser.parse_args()

    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    ctx = get_extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # prepare data iterators
    tdata = data_iterator(
        BAbI19DataSource(args.train_file,
                         args.num_training_examples,
                         shuffle=True), 1, False, False, False)
    vdata = data_iterator(
        BAbI19DataSource(args.valid_file, 1000, shuffle=True), 1, False, False,
        False)

    # prepare monitors
    monitor = M.Monitor("./bAbI19")
    tloss = M.MonitorSeries("Training Loss", monitor, interval=10)
    terror = M.MonitorSeries("Training Error", monitor, interval=10)
    verror = M.MonitorSeries("Validation Error", monitor, interval=1)

    # prepare solver
    solver = S.Adam()
    solver_initialized = False

    cnt = 0
    while True:
        l = 0.0
        e = 0.0

        solver.zero_grad()
        for _ in range(args.accum_grad):
            # read next data
            x = tdata.next()
            V = x[1][0][0]
            E = x[2][0][0]
            ans = x[3][0][0]

            # construct GGNN
            ## convert to nn.Variable
            x = nn.Variable(V.shape)
            x.data.data = V
            h = nn.Variable((len(V), 6))
            h.data.data = utils.h_0(V, 6)

            outputs = predict(V, E, len(ans))
            losses = []
            errors = []
            for a, output in zip(ans, outputs):
                label = nn.Variable((1, 1))
                label.data.data[0, 0] = a

                losses.append(F.mean(F.softmax_cross_entropy(output, label)))
                output2 = output.unlinked()
                errors.append(F.mean(F.top_n_error(output2, label)))

            # initialize solver
            if not solver_initialized:
                solver.set_parameters(nn.get_parameters())
                solver_initialized = True
                solver.zero_grad()

            # calculate loss/error
            loss = F.mean(F.stack(*losses))
            error = F.mean(F.stack(*errors))
            F.sink(loss, error).forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)

            l += loss.data.data
            e += error.data.data

        # dump log
        tloss.add(cnt, l / args.accum_grad)
        terror.add(cnt, e / args.accum_grad)
        l = 0.0
        e = 0.0

        solver.update()

        cnt += 1
        if cnt % args.valid_interval == 0:
            # validation
            validation_error = 0
            correct_example = None
            wrong_example = None
            for _ in range(vdata.size):
                x = vdata.next()
                id2str = x[0][0][0]
                V = x[1][0][0]
                E = x[2][0][0]
                ans = x[3][0][0]

                # construct GGNN
                ## convert to nn.Variable
                x = nn.Variable(V.shape)
                x.data.data = V
                h = nn.Variable((len(V), 6))
                h.data.data = utils.h_0(V, 6)

                outputs = predict(V, E, len(ans))
                errors = []
                actual = []
                for a, output in zip(ans, outputs):
                    label = nn.Variable((1, 1))
                    label.data.data[0, 0] = a

                    errors.append(F.mean(F.top_n_error(output, label)))
                    actual.append(output.data.data)

                error = F.mean(F.stack(*errors))
                error.forward(clear_no_need_grad=True)

                # 1 if any answer in this example was wrong, 0 otherwise.
                x = 0 if error.data.data == 0 else 1

                if x > 0.5:
                    if wrong_example is None:
                        wrong_example = (id2str, V, E, ans, actual)
                else:
                    if correct_example is None:
                        correct_example = (id2str, V, E, ans, actual)
                validation_error += x
            validation_error /= vdata.size
            verror.add(cnt, validation_error)
            accuracy = 1 - validation_error
            if accuracy >= args.threshold:

                def show(example):
                    if "s" in example[2]:
                        for i, j in example[2]["s"]:
                            print("The {} is south the {}.".format(
                                example[0][i], example[0][j]))
                    if "n" in example[2]:
                        for i, j in example[2]["n"]:
                            print("The {} is north the {}.".format(
                                example[0][i], example[0][j]))
                    if "w" in example[2]:
                        for i, j in example[2]["w"]:
                            print("The {} is west the {}.".format(
                                example[0][i], example[0][j]))
                    if "e" in example[2]:
                        for i, j in example[2]["e"]:
                            print("The {} is east the {}.".format(
                                example[0][i], example[0][j]))
                    i = np.argmax(example[1][:, 0])
                    j = np.argmax(example[1][:, 1])
                    print("What is the path from {} to {}?".format(
                        example[0][i], example[0][j]))

                    for (expected, actual) in zip(example[3], example[4]):
                        i = np.argmax(actual[0])
                        print("Expected: {}, Actual: {}".format(
                            id2classes[expected], id2classes[i]))

                if correct_example is not None:
                    show(correct_example)
                if wrong_example is not None:
                    show(wrong_example)

                break
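# The training loop above builds one loss and one error Variable per answer
# token, reduces them with F.stack and F.mean, and forwards both through a
# single F.sink so that one forward pass computes everything. A standalone
# sketch of that aggregation pattern (shapes and values are made up):
import numpy as np
import nnabla as nn
import nnabla.functions as F


def stacked_loss_error_sketch(n_steps=3, n_class=4, seed=0):
    logits = [nn.Variable((1, n_class)) for _ in range(n_steps)]
    labels = [nn.Variable((1, 1)) for _ in range(n_steps)]
    losses = [F.mean(F.softmax_cross_entropy(y, t))
              for y, t in zip(logits, labels)]
    errors = [F.mean(F.top_n_error(y, t)) for y, t in zip(logits, labels)]
    loss = F.mean(F.stack(*losses))
    error = F.mean(F.stack(*errors))

    rng = np.random.RandomState(seed)
    for y, t in zip(logits, labels):
        y.d = rng.randn(1, n_class)
        t.d = rng.randint(0, n_class, size=(1, 1))
    F.sink(loss, error).forward()
    return float(loss.d), float(error.d)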
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forward prop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * AllReduce the gradients.
      * Solver updates parameters using the gradients computed by backprop and all-reduce.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # loss_error_train.forward()

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # Note: a smaller final batch is skipped.
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:  # Monitor local loss, error, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

        # exit(0)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
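# Each rank above validates only its own contiguous slice of the validation set,
# and any final batch smaller than bs_valid is skipped ("smaller batch is
# ignored"). A quick standalone check of that partitioning in plain Python, with
# made-up values:
def show_validation_partition_sketch(n_valid_samples=10000, n_devices=4,
                                     bs_valid=64):
    for mpi_rank in range(n_devices):
        start = int(n_valid_samples / n_devices * mpi_rank)
        stop = int(n_valid_samples / n_devices * (mpi_rank + 1))
        full_batches = (stop - start) // bs_valid
        skipped = (stop - start) - full_batches * bs_valid
        print("rank {}: samples [{}, {}), {} full batches, {} samples skipped"
              .format(mpi_rank, start, stop, full_batches, skipped))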
def train(max_iter=24000):
    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST Dataset
    from mnist_data import load_mnist, data_iterator_mnist

    images, labels = load_mnist(train=True)
    rng = np.random.RandomState(706)
    inds = rng.permutation(len(images))

    def feed_labeled(i):
        j = inds[i]
        return images[j], labels[j]

    def feed_unlabeled(i):
        j = inds[i]
        return images[j], labels[j]

    di_l = I.data_iterator_simple(
        feed_labeled,
        args.n_labeled,
        args.batchsize_l,
        shuffle=True,
        rng=rng,
        with_file_cache=False,
    )
    di_u = I.data_iterator_simple(
        feed_unlabeled,
        args.n_train,
        args.batchsize_u,
        shuffle=True,
        rng=rng,
        with_file_cache=False,
    )
    di_v = data_iterator_mnist(args.batchsize_v, train=False)

    # Create networks
    # feed-forward-net building function
    def forward(x, test=False):
        return I.mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data
    xl = nn.Variable((args.batchsize_l,) + shape_x, need_grad=False)
    yl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(yl, tl))

    # Net for learning unlabeled data
    xu = nn.Variable((args.batchsize_u,) + shape_x, need_grad=False)
    yu = forward(xu, test=False)
    y1 = yu.get_unlinked_variable()
    y1.need_grad = False

    noise = nn.Variable((args.batchsize_u,) + shape_x, need_grad=True)
    r = noise / (F.sum(noise ** 2, [1, 2, 3], keepdims=True)) ** 0.5
    r.persistent = True
    y2 = forward(xu + args.xi_for_vat * r, test=False)
    y3 = forward(xu + args.eps_for_vat * r, test=False)
    loss_k = F.mean(I.distance(y1, y2))
    loss_u = F.mean(I.distance(y1, y3))

    # Net for evaluating validation data
    xv = nn.Variable((args.batchsize_v,) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)
    err = F.mean(F.top_n_error(hv, tv, n=1))

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    path = cache_dir(os.path.join(I.name, "monitor"))
    monitor = M.Monitor(path)
    monitor_verr = M.MonitorSeries("val_error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("time", monitor, interval=240)

    # Training Loop.
    for i in range(max_iter):

        # Validation Test
        if i % args.val_interval == 0:
            valid_error = I.calc_validation_error(di_v, xv, tv, err, args.val_iter)
            monitor_verr.add(i, valid_error)

        # forward, backward and update
        xl.d, tl.d = di_l.next()
        xl.d = xl.d / 255
        solver.zero_grad()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Calculate y without noise, only once.
        xu.d, _ = di_u.next()
        xu.d = xu.d / 255
        yu.forward(clear_buffer=True)

        # Do power method iteration
        noise.d = np.random.normal(size=xu.shape).astype(np.float32)
        for k in range(args.n_iter_for_power_method):
            r.grad.zero()
            loss_k.forward(clear_no_need_grad=True)
            loss_k.backward(clear_buffer=True)
            noise.data.copy_from(r.grad)

        # forward, backward and update
        solver.zero_grad()
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(solver.learning_rate() * args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with validation dataset
    valid_error = I.calc_validation_error(di_v, xv, tv, err, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    return path
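# I.distance above is a project helper that is not shown. In virtual adversarial
# training the distance is commonly the KL divergence between the softmax
# outputs of the clean and perturbed networks; the sketch below assumes that
# choice and may differ from the repository's actual implementation.
import nnabla.functions as F


def kl_distance_sketch(y0, y1, eps=1e-8):
    # KL(softmax(y0) || softmax(y1)) per sample; F.mean of this gives the
    # scalar loss used in the training loop above.
    p0 = F.softmax(y0, axis=1)
    p1 = F.softmax(y1, axis=1)
    return F.sum(p0 * (F.log(p0 + eps) - F.log(p1 + eps)), axis=1)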
def define_loss(pred, in_label, label, label_smoothing):
    loss = F.mean(softmax_cross_entropy_with_label_smoothing(
        pred, label, label_smoothing))
    error = F.sum(F.top_n_error(pred, in_label, n=1))
    return loss, error
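# softmax_cross_entropy_with_label_smoothing above is a project helper that is
# not shown. A common formulation, sketched below as an assumption about what it
# computes, mixes the usual cross-entropy with a uniform-label term weighted by
# label_smoothing; the repository's actual helper may differ in details.
import nnabla.functions as F


def softmax_cross_entropy_with_label_smoothing_sketch(pred, label,
                                                      label_smoothing=0.1):
    # pred: (batch, n_class) logits, label: (batch, 1) integer class indices.
    ce = F.softmax_cross_entropy(pred, label)              # (batch, 1)
    if label_smoothing <= 0:
        return ce
    logp = F.log(F.softmax(pred, axis=1) + 1e-8)           # (batch, n_class)
    uniform = F.mean(-logp, axis=1, keepdims=True)         # (batch, 1)
    return (1.0 - label_smoothing) * ce + label_smoothing * uniform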
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forward prop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training).
      * Solver updates parameters using the gradients computed by backprop.
      * Compute the training error.

    """

    args = get_args()
    if args.tiny_mode:
        n_train_samples = 100000
    else:
        n_train_samples = 1282167

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Workaround to start with the same parameters.
    rng = np.random.RandomState(device_id)
    if args.tiny_mode:
        # We use Tiny ImageNet from the Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories; each category has 500 images
        # in the training set. The image size is 64x64. To adapt ResNet to 64x64
        # image inputs, the input image size of ResNet is set to 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check the README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories; each category has 1280 images
        # in the training set. The image sizes vary. To adapt ResNet to
        # 320x320 image inputs, the input image size of ResNet is set to
        # 224x224. We need to get the tar file and create cache files
        # (320x320 images). Please check the README.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir,
                                      rng=rng)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        vdata = vdata.slice(rng=None,
                            num_of_slices=n_devices,
                            slice_pos=device_id)
        num_classes = 1000
    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Add parameters to communicator.
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples / args.batch_size /
                      args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Training loop.
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of last iteration is performed
                # after solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        # AllReduce
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Synchronize by averaging the weights over devices using allreduce
        if (i + 1) % args.sync_weight_every_itr == 0:
            weights = [x.data for x in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
def train():
    """
    Main script.
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    if args.tiny_mode:
        # We use Tiny ImageNet from the Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories; each category has 500 images
        # in the training set. The image size is 64x64. To adapt ResNet to 64x64
        # image inputs, the input image size of ResNet is set to 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check the README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories; each category has 1280 images
        # in the training set. The image sizes vary. To adapt ResNet to
        # 320x320 image inputs, the input image size of ResNet is set to
        # 224x224. We need to get the tar file and create cache files
        # (320x320 images). Please check the README.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        num_classes = 1000
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward

    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    t_pred2 = t_model.pred.get_unlinked_variable()
    t_pred2.need_grad = False

    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    v_pred2 = v_model.pred.get_unlinked_variable()
    v_pred2.need_grad = False

    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Save_nnp_Epoch0
    contents = save_nnp({'x': v_model.image}, {'y': v_model.pred},
                        args.batch_size)
    save.save(os.path.join(args.model_save_path, 'Imagenet_result_epoch0.nnp'),
              contents)

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=10)

    # Training loop.
    for i in range(start_point, args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            # save checkpoint file
            save_checkpoint(args.model_save_path, i, solver)

        # Validation
        if i % args.val_interval == 0 and i != 0:

            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()

            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                l += v_model.loss.d
                e += v_e.d
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)
            monitor_vtime.add(i)

            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        solver.weight_decay(args.weight_decay)
        solver.update()

        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)
    nn.save_parameters(
        os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter))

    # Save_nnp
    contents = save_nnp({'x': v_model.image}, {'y': v_model.pred},
                        args.batch_size)
    save.save(os.path.join(args.model_save_path, 'Imagenet_result.nnp'),
              contents)
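# save_checkpoint/load_checkpoint above come from the repository's checkpoint
# utility, which is not shown here. The sketch below captures the idea under the
# assumption that a checkpoint bundles network parameters, solver states, and
# the iteration count; the file layout and field names are made up, and the
# actual utility may differ.
import json
import os
import nnabla as nn


def save_checkpoint_sketch(path, iteration, solver):
    os.makedirs(path, exist_ok=True)
    params_path = os.path.join(path, "params_{}.h5".format(iteration))
    states_path = os.path.join(path, "states_{}.h5".format(iteration))
    nn.save_parameters(params_path)
    solver.save_states(states_path)
    meta = {"iteration": iteration, "params": params_path,
            "states": states_path}
    with open(os.path.join(path, "checkpoint_{}.json".format(iteration)),
              "w") as f:
        json.dump(meta, f)


def load_checkpoint_sketch(checkpoint_json, solver):
    with open(checkpoint_json) as f:
        meta = json.load(f)
    nn.load_parameters(meta["params"])
    solver.set_parameters(nn.get_parameters())
    solver.load_states(meta["states"])
    return meta["iteration"]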
def train():
    """
    Main script.
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    if args.tiny_mode:
        # We use Tiny ImageNet from the Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories; each category has 500 images
        # in the training set. The image size is 64x64. To adapt ResNet to 64x64
        # image inputs, the input image size of ResNet is set to 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check the README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories; each category has 1280 images
        # in the training set. The image sizes vary. To adapt ResNet to
        # 320x320 image inputs, the input image size of ResNet is set to
        # 224x224. We need to get the tar file and create cache files
        # (320x320 images). Please check the README.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        num_classes = 1000
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=10)

    # Training loop.
    for i in range(args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % args.val_interval == 0 and i != 0:

            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()

            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                l += v_model.loss.d
                e += v_e.d
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)
            monitor_vtime.add(i)

            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of last iteration is performed
                # after solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)
    nn.save_parameters(
        os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter))