Example #1
def train(env,
          model,
          buffer,
          exploration,
          monitor,
          update_fn,
          eval_fn,
          final_step,
          update_start,
          update_interval,
          save_interval,
          evaluate_interval,
          loss_labels=None):
    # avoid the mutable-default-argument pitfall
    loss_labels = loss_labels if loss_labels is not None else []
    reward_monitor = MonitorSeries('reward', monitor, interval=1)
    eval_reward_monitor = MonitorSeries('eval_reward', monitor, interval=1)
    time_monitor = MonitorTimeElapsed('time', monitor, interval=10000)
    loss_monitors = []
    for label in loss_labels:
        loss_monitors.append(MonitorSeries(label, monitor, interval=10000))

    step = 0
    while step <= final_step:
        obs_t = env.reset()
        ter_tp1 = False
        cumulative_reward = 0.0
        model.reset(step)
        while not ter_tp1:
            # select best action
            act_t = model.infer(obs_t)
            # add exploration noise
            act_t = exploration.get(step, act_t)
            # iterate environment
            obs_tp1, rew_tp1, ter_tp1, _ = env.step(act_t)
            # store transition
            buffer.append(obs_t, [act_t], rew_tp1, obs_tp1, ter_tp1)

            # update parameters
            if step > update_start and step % update_interval == 0:
                for i, loss in enumerate(update_fn(step)):
                    if loss is not None:
                        loss_monitors[i].add(step, loss)

            # save parameters
            if step % save_interval == 0:
                path = os.path.join(monitor.save_path, 'model_%d.h5' % step)
                nn.save_parameters(path)

            if step % evaluate_interval == 0:
                eval_reward_monitor.add(step, np.mean(eval_fn()))

            step += 1
            cumulative_reward += rew_tp1
            obs_t = obs_tp1
            time_monitor.add(step)

        # record metrics
        reward_monitor.add(step, cumulative_reward)
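train() depends on several collaborators that are not shown. A minimal sketch of the interfaces it assumes; every name below is a hypothetical stand-in, not part of the snippet above:
class NoExploration:
    def get(self, step, act):
        # exploration.get(step, action) -> possibly perturbed action
        return act

def update_fn(step):
    # must return one value per entry in loss_labels; None entries are skipped
    return [None]

def eval_fn():
    # train() logs the mean of the returned episode rewards
    return [0.0]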
Example #2
    def train(self):
        # variables for training
        tx_in = nn.Variable(
            [self._batch_size, self._x_input_length, self._cols_size])
        tx_out = nn.Variable(
            [self._batch_size, self._x_output_length, self._cols_size])
        tpred = self.network(tx_in, self._lstm_unit_name, self._lstm_units)
        tpred.persistent = True
        loss = F.mean(F.squared_error(tpred, tx_out))
        solver = S.Adam(self._learning_rate)
        solver.set_parameters(nn.get_parameters())

        # variables for validation
        vx_in = nn.Variable(
            [self._batch_size, self._x_input_length, self._cols_size])
        vx_out = nn.Variable(
            [self._batch_size, self._x_output_length, self._cols_size])
        vpred = self.network(vx_in, self._lstm_unit_name, self._lstm_units)

        # data iterators
        tdata = self._load_dataset(self._training_dataset_path,
                                   self._batch_size,
                                   shuffle=True)
        vdata = self._load_dataset(self._validation_dataset_path,
                                   self._batch_size,
                                   shuffle=True)

        # monitors
        from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
        monitor = Monitor(self._monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
        monitor_err = MonitorSeries("Training error", monitor, interval=10)
        monitor_time = MonitorTimeElapsed("Training time",
                                          monitor,
                                          interval=100)
        monitor_verr = MonitorSeries("Validation error", monitor, interval=10)

        # Training loop
        for i in range(self._max_iter):
            if i % self._val_interval == 0:
                ve = self._validate(vpred, vx_in, vx_out, vdata,
                                    self._val_iter)
                monitor_verr.add(i, ve / self._val_iter)
            te = self._train(tpred, solver, loss, tx_in, tx_out, tdata.next(),
                             self._weight_decay)
            monitor_loss.add(i, loss.d.copy())
            monitor_err.add(i, te)
            monitor_time.add(i)
        ve = self._validate(vpred, vx_in, vx_out, vdata, self._val_iter)
        monitor_verr.add(i, ve / self._val_iter)

        # Save the best model parameters
        nn.save_parameters(self._model_params_path)
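self._train and self._validate are elided. A plausible sketch of the training step under the usual nnabla forward/backward/update pattern; this is an assumption, not the author's actual helper, and it returns a scalar metric:
    def _train(self, pred, solver, loss, x_in, x_out, batch, weight_decay):
        # one update step: load the minibatch, then forward/backward/update
        x_in.d, x_out.d = batch
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(weight_decay)
        solver.update()
        return float(loss.d)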
Example #3
def main():
    # Training settings
    args = Yolov2OptionTraining().parse_args()

    nsamples = file_lines(args.train)

    set_default_context_by_args(args)

    # Training parameters
    max_epochs = args.max_batches * args.batch_size * args.accum_times / nsamples + 1

    if not os.path.exists(args.output):
        os.mkdir(args.output)

    ###############

    # Load parameters
    print("Load", args.weight, "...")
    if args.fine_tune:
        nn.load_parameters(args.weight)
        nn.parameter.pop_parameter("detection/conv/W")
        nn.parameter.pop_parameter("detection/conv/b")
    else:
        nn.load_parameters(args.weight)

    train_graph = TrainGraph(args, (args.size_aug[-1], args.size_aug[-1]))
    yolo_solver = YoloSolver(args)

    if args.on_memory_data:
        on_memory_data = dataset.load_on_memory_data(args.train)
    else:
        on_memory_data = None

    prefetch_iterator = PrefetchIterator()

    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.output)
    monitor_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_time = MonitorTimeElapsed('Time per epoch', monitor, interval=1)

    # Epoch loop
    for epoch in range(0, int(max_epochs)):
        loss = train(args, epoch, max_epochs, train_graph, yolo_solver,
                     prefetch_iterator, on_memory_data)
        monitor_loss.add(epoch, loss)
        monitor_time.add(epoch)

    # Save the final parameters
    logging('save weights to %s/%06d.h5' % (args.output, epoch + 1))
    nn.save_parameters('%s/%06d.h5' % (args.output, epoch + 1))
Example #4
def get_common_monitors(monitor):
    """
    Create monitors for displaying and storing losses.
    """
    monitor_content_loss = MonitorSeries('content loss', monitor, interval=20)
    monitor_gen_loss = MonitorSeries('generator loss', monitor, interval=20)
    monitor_warp_loss = MonitorSeries('warp loss', monitor, interval=20)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=20)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=20)
    Monitor_common = collections.namedtuple('Monitor_common', [
        'monitor_content_loss', 'monitor_gen_loss', 'monitor_warp_loss',
        'monitor_lr', 'monitor_time'
    ])
    return Monitor_common(monitor_content_loss, monitor_gen_loss,
                          monitor_warp_loss, monitor_lr, monitor_time)
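Typical usage of the returned namedtuple; a sketch, with 'tmp.monitor' as an arbitrary log directory and placeholder values:
from nnabla.monitor import Monitor

monitors = get_common_monitors(Monitor('tmp.monitor'))
monitors.monitor_content_loss.add(0, 0.5)  # add(iteration_index, value)
monitors.monitor_time.add(0)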
Example #5
def create_monitor(net_name: str, interval=100):
    value = namedtuple("value", ("loss", "error"))
    monitor = namedtuple("monitor", ("path", "time", "train", "val"))

    path = get_monitor_path(net_name)
    kwargs = dict(monitor=Monitor(path), interval=interval)
    time = MonitorTimeElapsed("time", **kwargs)

    kwargs["verbose"] = False
    loss = MonitorSeries("train_loss", **kwargs)
    error = MonitorSeries("train_error", **kwargs)
    value_train = value(loss, error)

    loss = MonitorSeries("val_loss", **kwargs)
    error = MonitorSeries("val_error", **kwargs)
    value_val = value(loss, error)

    return monitor(path, time, value_train, value_val)
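A sketch of how the nested namedtuples would be used in a training loop; it assumes the elided get_monitor_path helper resolves to a writable log directory, and all logged values are placeholders:
mon = create_monitor("my_net", interval=10)
for i in range(100):
    mon.train.loss.add(i, 0.5)    # placeholder training loss
    mon.train.error.add(i, 0.1)   # placeholder training error
    mon.time.add(i)
mon.val.error.add(100, 0.08)      # placeholder validation error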
Example #6
def setup_monitor(conf, monitor):
    """
    Set up monitors to keep track of losses and times, and to log them
    """
    interval = conf.monitor_interval
    jsi_monitor = {
        'rec_loss': MonitorSeries('rec_loss', monitor, interval=interval),
        'psnr': MonitorSeries('psnr', monitor, interval=interval),
        'lr': MonitorSeries('learning_rate', monitor, interval=interval),
        'g_final_loss': MonitorSeries('g_final_loss', monitor,
                                      interval=interval),
        'd_final_fm_loss': MonitorSeries('d_final_fm_loss', monitor,
                                         interval=interval),
        'd_final_detail_loss': MonitorSeries('d_final_detail_loss', monitor,
                                             interval=interval),
        'g_adv_loss': MonitorSeries('g_adv_loss', monitor, interval=interval),
        'g_detail_adv_loss': MonitorSeries('g_detail_adv_loss', monitor,
                                           interval=interval),
        'fm_loss': MonitorSeries('fm_loss', monitor, interval=interval),
        'fm_detail_loss': MonitorSeries('fm_detail_loss', monitor,
                                        interval=interval),
        'time': MonitorTimeElapsed("Training time per epoch", monitor,
                                   interval=interval),
    }
    return jsi_monitor
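A sketch of how the returned dict is consumed; get_config is a hypothetical stand-in for any object exposing monitor_interval and monitor_path, and the logged values are placeholders:
from nnabla.monitor import Monitor

conf = get_config()  # hypothetical config loader
jsi_monitor = setup_monitor(conf, Monitor(conf.monitor_path))
jsi_monitor['rec_loss'].add(0, 0.12)  # add(iteration_index, value)
jsi_monitor['time'].add(0)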
Example #7
def distil():
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
        data_iterator = data_iterator_cifar10
        c = 3
        h = w = 32
        n_train = 50000
        n_valid = 10000
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    teacher = "teacher"
    student = "student"
    maps = args.maps
    rrate = args.reduction_rate
    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    image.persistent = True  # do not clear; the intermediate buffer is re-used
    label = nn.Variable([args.batch_size, 1])
    label.persistent = True  # do not clear; the intermediate buffer is re-used
    # Create teacher and student prediction graphs.
    model_load_path = args.model_load_path
    nn.load_parameters(model_load_path)
    pred_label = model_prediction(image,
                                  net=teacher,
                                  maps=maps,
                                  test=not args.use_batch)
    pred_label.need_grad = False  # no backward needed through the teacher graph
    pred = model_prediction(image,
                            net=student,
                            maps=int(maps * (1. - rrate)),
                            test=False)
    pred.persistent = True  # do not clear; the intermediate buffer is re-used
    loss_ce = F.mean(F.softmax_cross_entropy(pred, label))
    loss_ce_soft = ce_soft(pred, pred_label)
    loss = args.weight_ce * loss_ce + args.weight_ce_soft * loss_ce_soft

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create teacher prediction graph.
    vpred = model_prediction(vimage,
                             net=student,
                             maps=int(maps * (1. - rrate)),
                             test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(student):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator for CIFAR-10.
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
            # Save the best parameters right after validation.
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
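ce_soft is elided above. A common definition of the soft-target distillation loss, given as a sketch (assuming import nnabla.functions as F), not necessarily the author's exact code:
def ce_soft(pred, teacher_logits):
    # cross entropy between the student's log-softmax and the teacher's softmax
    return -F.mean(F.sum(F.softmax(teacher_logits) * F.log_softmax(pred),
                         axis=1))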
Example #8
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct computation graphs for training and one for validation.
    * Initialize solvers and set parameter variables to those.
    * Instantiate a communicator and set parameter variables.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprops
      * Set parameter gradients zero
      * Execute backprop.
      * Inplace allreduce (THIS IS THE MAIN difference from a single device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Create contexts
    extension_module = args.context
    if extension_module != "cuda" and \
            extension_module != "cuda.cudnn":
        raise Exception("Use `cuda` or `cuda.cudnn` extension_module.")
    n_devices = args.n_devices
    ctxs = []
    for i in range(n_devices):
        ctx = extension_context(extension_module, device_id=i)
        ctxs.append(ctx)
    ctx = ctxs[-1]

    # Create training graphs
    input_image_train = []
    preds_train = []
    losses_train = []
    test = False
    for i in range(n_devices):
        image = nn.Variable((args.batch_size, 3, 32, 32))
        label = nn.Variable((args.batch_size, 1))
        device_scope_name = "device{}".format(i)

        pred = cifar100_resnet23_prediction(
            image, ctxs[i], device_scope_name, test)
        loss = cifar100_resnet32_loss(pred, label)

        input_image_train.append({"image": image, "label": label})
        preds_train.append(pred)
        losses_train.append(loss)

    # Create validation graph
    test = True
    device_scope_name = "device{}".format(0)
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar100_resnet23_prediction(
        image_valid, ctxs[i], device_scope_name, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solvers = []
    for i in range(n_devices):
        with nn.context_scope(ctxs[i]):
            solver = S.Adam()
            device_scope_name = "device{}".format(i)
            with nn.parameter_scope(device_scope_name):
                params = nn.get_parameters()
                solver.set_parameters(params)
            solvers.append(solver)

    # Communicator
    comm = C.DataParalellCommunicator(ctx)
    for i in range(n_devices):
        device_scope_name = "device{}".format(i)
        with nn.parameter_scope(device_scope_name):
            ctx = ctxs[i]
            params = nn.get_parameters()
            comm.add_context_and_parameters((ctx, params))
    comm.init()

    # Create threadpools with one thread
    pools = []
    for _ in range(n_devices):
        pool = ThreadPool(processes=1)
        pools.append(pool)

    # Once forward/backward to safely secure memory
    for device_id in range(n_devices):
        data, label = \
            (np.random.randn(*input_image_train[device_id]["image"].shape),
             (np.random.rand(*input_image_train[device_id]["label"].shape) * 10).astype(np.int32))

        ret = pools[device_id].apply_async(forward_backward,
                                           (input_image_train[device_id]["image"], data,
                                            input_image_train[device_id]["label"], label,
                                               losses_train[device_id], solvers[device_id]))
        ret.get()
        losses_train[device_id].d  # sync to host

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)
    with data_iterator_cifar100(args.batch_size, True) as tdata, \
            data_iterator_cifar100(bs_valid, False) as vdata:
        # Training-loop
        for i in range(int(args.max_iter / n_devices)):
            # Validation
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))

            # Forwards/Zerograd/Backwards
            fb_results = []
            for device_id in range(n_devices):
                image, label = tdata.next()

                res = pools[device_id].apply_async(forward_backward,
                                                   (input_image_train[device_id]["image"], image,
                                                    input_image_train[device_id]["label"], label,
                                                    losses_train[device_id], solvers[device_id]))
                fb_results.append(res)
            for device_id in range(n_devices):
                fb_results[device_id].get()

            # In-place Allreduce
            comm.allreduce()

            # Solvers update
            for device_id in range(n_devices):
                solvers[device_id].update()

            e = categorical_error(
                preds_train[-1].d, input_image_train[-1]["label"].d)
            monitor_loss.add(i * n_devices, losses_train[-1].d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    nn.save_parameters(os.path.join(
        args.model_save_path,
        'params_%06d.h5' % (args.max_iter / n_devices)))
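forward_backward, dispatched to each device's thread pool, is elided. A plausible sketch consistent with how it is called above:
def forward_backward(x, x_data, y, y_data, loss, solver):
    # load the minibatch, then run one forward/backward pass on this device
    x.d = x_data
    y.d = y_data
    loss.forward(clear_no_need_grad=True)
    solver.zero_grad()
    loss.backward(clear_buffer=True)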
Example #9
def train(args):
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must equal the number of selected attributes; overriding c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k, v in vars(args).items():
        if not k.startswith("_"):
            config[k] = v
            print("'{}' : {}".format(k, v))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim,
        num_downsample=args.num_downsample, num_upsample=args.num_upsample,
        repeat_num=args.g_repeat_num)
    discriminator = functools.partial(
        model.discriminator, image_size=args.image_size,
        conv_dim=args.d_conv_dim, c_dim=args.c_dim,
        repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(
        stargan_load_func, dataset=training_dataset,
        image_dir=args.celeba_image_dir, image_size=args.image_size,
        crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(
        load_func, len(training_dataset), args.batch_size,
        with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(
        stargan_load_func, dataset=test_dataset,
        image_dir=args.celeba_image_dir, image_size=args.image_size,
        crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(
        load_func_test, len(test_dataset), args.batch_size,
        with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Decay learning rates linearly after half of the training.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    conf_name = "training_conf_{}.json".format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    with open(os.path.join(args.model_save_path, conf_name), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
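The loss module is elided. Given the gradient-penalty term above, the helpers are presumably WGAN-GP style; a sketch (assuming import nnabla.functions as F), not the author's exact code:
def gan_loss(logits):
    # Wasserstein critic term; negated for real samples in d_loss_real above
    return F.mean(logits)

def classification_loss(logits, label):
    # multi-label attribute classification
    return F.mean(F.sigmoid_cross_entropy(logits, label))

def recon_loss(x, y):
    # L1 reconstruction loss
    return F.mean(F.abs(x - y))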
Example #10
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)
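categorical_error, shared by several of these examples, computes the top-1 error rate; a minimal sketch:
import numpy as np

def categorical_error(pred, label):
    # fraction of samples whose argmax class differs from the label
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()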
Example #11
def train(args):
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domain A)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(x_style_rec_a,
                                      z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(x_content_rec_b,
                                        x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(x_style_rec_b,
                                      z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(x_content_rec_a,
                                        x_content_a).apply(persistent=True)

    # adversarial

    def f(x, y):
        return x + y

    loss_gen_a = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor,
                                                 interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor,
                                                 interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train",
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train",
                                   monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))
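lsgan_loss and recon_loss are elided. Plausible definitions matching the one-argument generator form and two-argument discriminator form used above; a sketch, not the author's exact code:
def lsgan_loss(p_fake, p_real=None):
    # generator form: push fake logits toward 1
    if p_real is None:
        return F.mean((p_fake - 1.0) ** 2.0)
    # discriminator form: real toward 1, fake toward 0
    return F.mean((p_real - 1.0) ** 2.0) + F.mean(p_fake ** 2.0)

def recon_loss(x, y):
    # L1 reconstruction loss
    return F.mean(F.abs(x - y))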
Example #12
def main():

    random.seed(args.seed)
    np.random.seed(args.seed)

    # Prepare for CUDA.
    ctx = get_extension_context('cudnn', device_id=args.gpus)
    nn.set_default_context(ctx)

    start_full_time = time.time()
    from iterator import data_iterator

    # Data list for sceneflow data set
    train_list = "./dataset/sceneflow_train.csv"
    test_list = "./dataset/sceneflow_test.csv"
    train = True
    validation = True

    # Set monitor path.
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))

    img_left, img_right, disp_img = read_csv(train_list)
    img_left_test, img_right_test, disp_img_test = read_csv(test_list)
    train_samples = len(img_left)
    test_samples = len(img_left_test)
    train_size = int(len(img_left) / args.batchsize_train)
    test_size = int(len(img_left_test) / args.batchsize_test)

    # Create data iterator.
    data_iterator_train = data_iterator(
        train_samples, args.batchsize_train, img_left, img_right, disp_img, train=True, shuffle=True, dataset=args.dataset)
    data_iterator_test = data_iterator(
        test_samples, args.batchsize_test, img_left_test, img_right_test, disp_img_test, train=False, shuffle=False, dataset=args.dataset)

    # Report the number of training and testing batches.
    print(train_size, test_size)

    # Define data shape for training.
    var_left = nn.Variable(
        (args.batchsize_train, 3, args.crop_height, args.crop_width))
    var_right = nn.Variable(
        (args.batchsize_train, 3, args.crop_height, args.crop_width))
    var_disp = nn.Variable(
        (args.batchsize_train, 1, args.crop_height, args.crop_width))
    # Define data shape for testing.
    var_left_test = nn.Variable(
        (args.batchsize_test, 3, args.im_height, args.im_width))
    var_right_test = nn.Variable(
        (args.batchsize_test, 3, args.im_height, args.im_width))
    var_disp_test = nn.Variable(
        (args.batchsize_test, 1, args.im_height, args.im_width))

    if args.loadmodel is not None:
        # Loading CNN pretrained parameters.
        nn.load_parameters(args.loadmodel)

    # === for Training ===
    # Definition of pred
    pred1, pred2, pred3 = psm_net(var_left, var_right, args.maxdisp, True)
    mask_train = F.less_scalar(var_disp, args.maxdisp)
    sum_mask = F.maximum_scalar(F.sum(mask_train), 1)
    # Definition of loss
    loss = 0.5 * (
        0.5 * F.sum(F.huber_loss(pred1, var_disp) * mask_train) / sum_mask
        + 0.7 * F.sum(F.huber_loss(pred2, var_disp) * mask_train) / sum_mask
        + F.sum(F.huber_loss(pred3, var_disp) * mask_train) / sum_mask)

    # === for Testing ===
    # Definition of pred
    mask_test = F.less_scalar(var_disp_test, args.maxdisp)
    sum_mask_test = F.maximum_scalar(F.sum(mask_test), 1)
    pred_test = psm_net(var_left_test, var_right_test, args.maxdisp, False)
    test_loss = F.sum(F.abs(pred_test - var_disp_test)*mask_test)/sum_mask_test

    # Prepare monitors.
    monitor = Monitor(monitor_path)
    monitor_train = MonitorSeries('Training loss', monitor, interval=1)
    monitor_test = MonitorSeries('Validation loss', monitor, interval=1)
    monitor_time_train = MonitorTimeElapsed(
        "Training time/epoch", monitor, interval=1)

    # Create a solver (parameter updater)
    solver = S.Adam(alpha=0.001, beta1=0.9, beta2=0.999)

    # Set parameters (grad_only=False also registers parameters that receive
    # no gradients, e.g. batch-norm running statistics)
    params = nn.get_parameters(grad_only=False)
    solver.set_parameters(params)

    for epoch in range(1, args.epochs+1):
        print('This is %d-th epoch' % (epoch))

        if validation:
            ## testing ##
            total_test_loss = 0

            index_test = 0
            while index_test < test_size:
                var_left_test.d, var_right_test.d, var_disp_test.d = data_iterator_test.next()
                test_loss.forward(clear_no_need_grad=True)
                total_test_loss += test_loss.d

                print('Iter %d test loss = %.3f' % (index_test, test_loss.d))
                index_test += 1
            test_error = total_test_loss / test_size
            print('epoch %d average test loss in val = %.3f' %
                  (epoch, test_error))
            # Pass validation loss to a monitor.
            monitor_test.add(epoch, test_error)

        if train:
            ## training ##
            total_train_loss = 0
            index = 0

            while index < train_size:

                # Get mini batch
                # Preprocess
                var_left.d, var_right.d, var_disp.d = data_iterator_train.next()
                loss.forward(clear_no_need_grad=True)
                # Initialize gradients
                solver.zero_grad()
                # Backward execution
                loss.backward(clear_buffer=True)
                # Update parameters by computed gradients
                solver.update()
                print('Iter %d training loss = %.3f' %
                      (index, loss.d))
                total_train_loss += loss.d
                index += 1
            train_error = total_train_loss/train_size
            monitor_time_train.add(epoch)
            print('epoch %d total training loss = %.3f' % (epoch, train_error))

            # Pass training loss to a monitor.
            monitor_train.add(epoch, train_error)
            print('full training time = %.2f HR' %
                  ((time.time() - start_full_time)/3600))

            # Save Parameter
            out_param_file = os.path.join(
                args.savemodel, 'psmnet_trained_param_' + str(epoch) + '.h5')
            nn.save_parameters(out_param_file)
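read_csv is elided. A minimal sketch consistent with its use above, assuming each CSV row lists a left image, a right image, and a disparity map path:
import csv

def read_csv(path):
    left, right, disp = [], [], []
    with open(path) as f:
        for row in csv.reader(f):
            left.append(row[0])
            right.append(row[1])
            disp.append(row[2])
    return left, right, disp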
Example #13
def main():
    # Args
    args = get_args()
    save_args(args)

    # Context
    ctx = extension_context(args.context,
                            device_id=args.device_id,
                            type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Data Iterator
    di = data_iterator(args.img_path,
                       args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)
    # Model
    generator = Generator(use_bn=args.use_bn,
                          last_act=args.last_act,
                          use_wscale=args.not_use_wscale,
                          use_he_backward=args.use_he_backward)
    discriminator = Discriminator(use_ln=args.use_ln,
                                  alpha=args.leaky_alpha,
                                  use_wscale=args.not_use_wscale,
                                  use_he_backward=args.use_he_backward)

    # Solver
    solver_gen = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1,
                        beta2=args.beta2)
    solver_dis = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1,
                        beta2=args.beta2)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis = MonitorSeries("Discriminator Loss",
                                     monitor,
                                     interval=10)
    monitor_p_fake = MonitorSeries("Fake Probability", monitor, interval=10)
    monitor_p_real = MonitorSeries("Real Probability", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training Time per Resolution",
                                      monitor,
                                      interval=1)
    monitor_image_tile = MonitorImageTileWithName(
        "Image Tile", monitor, num_images=4,
        normalize_method=lambda x: (x + 1.) / 2.)

    # TODO: use argument
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(di,
                      generator,
                      discriminator,
                      solver_gen,
                      solver_dis,
                      args.monitor_path,
                      monitor_loss_gen,
                      monitor_loss_dis,
                      monitor_p_fake,
                      monitor_p_real,
                      monitor_time,
                      monitor_image_tile,
                      resolution_list,
                      channel_list,
                      n_latent=args.latent,
                      n_critic=args.critic,
                      save_image_interval=args.save_image_interval,
                      hyper_sphere=args.hyper_sphere,
                      l2_fake_weight=args.l2_fake_weight)

    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution)
Example #14
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)
Example #15
def train():
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Train
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    h_d, h_copy, pred, g_pred, g_label = cnn_dni(image, y=label)
    loss_ce = ce_loss(pred, label)  # loss of a problem at hand
    loss_se = se_loss(g_pred, g_label)  # gradient synthesizer loss

    # Test
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vpred = cnn(vimage, test=True)

    # Solver
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope("ref"):
        solver.set_parameters(nn.get_parameters())
    solver_gs = S.Adam(args.learning_rate)
    with nn.parameter_scope("gs"):
        solver_gs.set_parameters(nn.get_parameters())

    # Monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # DataIterator
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    # Training loop
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=False)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Training
        image.d, label.d = data.next()
        solver.zero_grad()
        solver_gs.zero_grad()

        ## forward
        h_d.forward(clear_no_need_grad=False)
        loss_ce.forward(clear_no_need_grad=False)
        loss_se.forward(clear_no_need_grad=False)

        ## backward
        loss_ce.backward(clear_buffer=False)
        h_d.backward(clear_buffer=False)
        loss_se.backward(clear_buffer=False)

        ## update
        solver.weight_decay(args.weight_decay)
        solver.update()
        solver_gs.weight_decay(args.weight_decay)
        solver_gs.update()

        ## monitor
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss_ce.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % args.max_iter))
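
`ce_loss` and `se_loss` above are not shown in this snippet. Assuming they follow the usual decoupled-neural-interface formulation (classification loss for the task, squared error between the synthesized and true gradients), minimal sketches could be:

import nnabla.functions as F

def ce_loss(pred, label):
    # Classification loss of the problem at hand.
    return F.mean(F.softmax_cross_entropy(pred, label))

def se_loss(g_pred, g_label):
    # Regress the gradient synthesizer output onto the true gradient signal.
    return F.mean(F.squared_error(g_pred, g_label))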
Example #16
0
def main():
    conf = get_config()
    train_gt_path = sorted(glob.glob(conf.DIV2K.gt_train + "/*.png"))
    train_lq_path = sorted(glob.glob(conf.DIV2K.lq_train + "/*.png"))
    val_gt_path = sorted(glob.glob(conf.SET14.gt_val + "/*.png"))
    val_lq_path = sorted(glob.glob(conf.SET14.lq_val + "/*.png"))
    train_samples = len(train_gt_path)
    val_samples = len(val_gt_path)
    lr_g = conf.hyperparameters.lr_g
    lr_d = conf.hyperparameters.lr_d
    lr_steps = conf.train.lr_steps

    random.seed(conf.train.seed)
    np.random.seed(conf.train.seed)

    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(
        extension_module, device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # data iterators for train and val data
    from data_loader import data_iterator_sr
    data_iterator_train = data_iterator_sr(
        train_samples, conf.train.batch_size, train_gt_path, train_lq_path, train=True, shuffle=True)
    data_iterator_val = data_iterator_sr(
        val_samples, conf.val.batch_size, val_gt_path, val_lq_path, train=False, shuffle=False)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    train_gt = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size, conf.train.gt_size))
    train_lq = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size // conf.train.scale, conf.train.gt_size // conf.train.scale))

    # setting up monitors for logging
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    monitor_pixel_g = MonitorSeries(
        'l_g_pix per iteration', monitor, interval=100)
    monitor_val = MonitorSeries(
        'Validation loss per epoch', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "Training time per epoch", monitor, interval=1)

    with nn.parameter_scope("gen"):
        nn.load_parameters(conf.train.gen_pretrained)
        fake_h = rrdb_net(train_lq, 64, 23)
        fake_h.persistent = True
    pixel_loss = F.mean(F.absolute_error(fake_h, train_gt))
    pixel_loss.persistent = True
    gen_loss = pixel_loss

    if conf.model.esrgan:
        from esrgan_model import get_esrgan_gen, get_esrgan_dis, get_esrgan_monitors
        gen_model = get_esrgan_gen(conf, train_gt, train_lq, fake_h)
        gen_loss = conf.hyperparameters.eta_pixel_loss * pixel_loss + conf.hyperparameters.feature_loss_weight * gen_model.feature_loss + \
            conf.hyperparameters.lambda_gan_loss * gen_model.loss_gan_gen
        dis_model = get_esrgan_dis(fake_h, gen_model.pred_d_real)
        # Set discriminator parameters
        solver_dis = S.Adam(lr_d, beta1=0.9, beta2=0.99)
        with nn.parameter_scope("dis"):
            solver_dis.set_parameters(nn.get_parameters())
        esr_mon = get_esrgan_monitors()

    # Set generator parameters
    solver_gen = S.Adam(alpha=lr_g, beta1=0.9, beta2=0.99)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    train_size = int(
        train_samples / conf.train.batch_size / comm.n_procs)
    total_epochs = conf.train.n_epochs
    start_epoch = 0
    current_iter = 0
    if comm.rank == 0:
        print("total_epochs", total_epochs)
        print("train_samples", train_samples)
        print("val_samples", val_samples)
        print("train_size", train_size)

    for epoch in range(start_epoch + 1, total_epochs + 1):
        index = 0
        # Training loop for psnr rrdb model
        while index < train_size:
            current_iter += comm.n_procs
            train_gt.d, train_lq.d = data_iterator_train.next()

            if not conf.model.esrgan:
                lr_g = get_repeated_cosine_annealing_learning_rate(
                    current_iter, conf.hyperparameters.eta_max, conf.hyperparameters.eta_min, conf.train.cosine_period,
                    conf.train.cosine_num_period)

            if conf.model.esrgan:
                lr_g = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_g)
                gen_model.var_ref.d = train_gt.d
                gen_model.pred_d_real.grad.zero()
                gen_model.pred_d_real.forward(clear_no_need_grad=True)
                gen_model.pred_d_real.need_grad = False

            # Generator update
            gen_loss.forward(clear_no_need_grad=True)
            solver_gen.zero_grad()
            # All-reduce gradients every 2MiB parameters during backward computation
            if comm.n_procs > 1:
                with nn.parameter_scope('gen'):
                    all_reduce_callback = comm.get_all_reduce_callback()
                    gen_loss.backward(clear_buffer=True,
                                      communicator_callbacks=all_reduce_callback)
            else:
                gen_loss.backward(clear_buffer=True)
            solver_gen.set_learning_rate(lr_g)
            solver_gen.update()

            # Discriminator Update
            if conf.model.esrgan:
                gen_model.pred_d_real.need_grad = True
                lr_d = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_d)
                solver_dis.zero_grad()
                dis_model.l_d_total.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    with nn.parameter_scope('dis'):
                        all_reduce_callback = comm.get_all_reduce_callback()
                    dis_model.l_d_total.backward(
                        clear_buffer=True, communicator_callbacks=all_reduce_callback)
                else:
                    dis_model.l_d_total.backward(clear_buffer=True)
                solver_dis.set_learning_rate(lr_d)
                solver_dis.update()

            index += 1
            if comm.rank == 0:
                monitor_pixel_g.add(
                    current_iter, pixel_loss.d.copy())
                monitor_time.add(epoch * comm.n_procs)
            if comm.rank == 0 and conf.model.esrgan:
                esr_mon.monitor_feature_g.add(
                    current_iter, gen_model.feature_loss.d.copy())
                esr_mon.monitor_gan_g.add(
                    current_iter, gen_model.loss_gan_gen.d.copy())
                esr_mon.monitor_gan_d.add(
                    current_iter, dis_model.l_d_total.d.copy())
                esr_mon.monitor_d_real.add(current_iter, F.mean(
                    gen_model.pred_d_real.data).data)
                esr_mon.monitor_d_fake.add(current_iter, F.mean(
                    gen_model.pred_g_fake.data).data)

        # Validation Loop
        if comm.rank == 0:
            avg_psnr = 0.0
            for idx in range(val_samples):
                val_gt_im, val_lq_im = data_iterator_val.next()
                val_gt = nn.NdArray.from_numpy_array(val_gt_im)
                val_lq = nn.NdArray.from_numpy_array(val_lq_im)
                with nn.parameter_scope("gen"):
                    avg_psnr = val_save(
                        val_gt, val_lq, val_lq_path, idx, epoch, avg_psnr)
            avg_psnr = avg_psnr / val_samples
            monitor_val.add(epoch, avg_psnr)

        # Save generator weights
        if comm.rank == 0:
            if not os.path.exists(conf.train.savemodel):
                os.makedirs(conf.train.savemodel)
            with nn.parameter_scope("gen"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "generator_param_%06d.h5" % epoch))
        # Save discriminator weights
        if comm.rank == 0 and conf.model.esrgan:
            with nn.parameter_scope("dis"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "discriminator_param_%06d.h5" % epoch))
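
`get_repeated_cosine_annealing_learning_rate` is assumed to implement an SGDR-style restarted cosine schedule, lr(t) = eta_min + (eta_max - eta_min)(1 + cos(pi * t / T)) / 2 within each period T; a sketch under that assumption:

import numpy as np

def get_repeated_cosine_annealing_learning_rate(current_iter, eta_max, eta_min,
                                                period, num_period):
    # Cosine decay from eta_max to eta_min, restarted every `period` iterations
    # for at most `num_period` periods.
    t = min(current_iter, period * num_period - 1) % period
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t / period))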
Example #17
0
def train(args):
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps,
                       n_classes=n_classes, test=True, sn=not_sn)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile("Image Tile Train", monitor,
                                                    num_images=args.batch_size,
                                                    interval=1,
                                                    normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile("Image Tile Test", monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=normalize_method)
    # DataIterator
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)
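
`gan_loss` above is called with one argument for the generator and two for the discriminator. Assuming the hinge formulation commonly paired with spectral normalization (an assumption, not confirmed by this snippet), a sketch:

import nnabla.functions as F

def gan_loss(p_fake, p_real=None):
    if p_real is None:
        # Generator loss: raise the critic score of fake samples.
        return -F.mean(p_fake)
    # Discriminator hinge loss on real and fake logits.
    return F.mean(F.relu(1.0 - p_real)) + F.mean(F.relu(1.0 + p_fake))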
Example #18
0
def train(args):
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    elif args.net == 'cifar10_binary_connect_resnet23_prediction':
        model_prediction = cifar10_binary_connect_resnet23_prediction
    elif args.net == 'cifar10_binary_net_resnet23_prediction':
        model_prediction = cifar10_binary_net_resnet23_prediction
    elif args.net == 'cifar10_binary_weight_resnet23_prediction':
        model_prediction = cifar10_binary_weight_resnet23_prediction
    elif args.net == 'cifar10_fp_connect_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_fp_connect_resnet23_prediction,
            n=args.bit_width,
            delta=args.delta)
    elif args.net == 'cifar10_fp_net_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_fp_net_resnet23_prediction,
            n=args.bit_width,
            delta=args.delta)
    elif args.net == 'cifar10_pow2_connect_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_pow2_connect_resnet23_prediction,
            n=args.bit_width,
            m=args.upper_bound)
    elif args.net == 'cifar10_pow2_net_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_pow2_net_resnet23_prediction,
            n=args.bit_width,
            m=args.upper_bound)
    elif args.net == 'cifar10_inq_resnet23_prediction':
        model_prediction = functools.partial(cifar10_inq_resnet23_prediction,
                                             num_bits=args.bit_width)
    elif args.net == 'cifar10_min_max_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_min_max_resnet23_prediction,
            ql_min=args.ql_min,
            ql_max=args.ql_max,
            p_min_max=args.p_min_max,
            a_min_max=args.a_min_max,
            a_ema=args.a_ema,
            ste_fine_grained=args.ste_fine_grained)

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
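
The functools.partial calls above pre-bind the quantization hyperparameters so that every network variant exposes the same (image, maps, test) call signature inside the training loop. A toy illustration of the pattern (hypothetical function, for shape only):

import functools

def quantized_prediction(image, maps=64, test=False, n=8, delta=2 ** -4):
    # A real variant would build its quantized graph here using n and delta.
    return image

model_prediction = functools.partial(quantized_prediction, n=4, delta=2 ** -6)
# model_prediction(image, maps=maps, test=False) now uses the bound n and delta.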
Example #19
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get the next minibatch.
      * Execute forwardprop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop and all reduce.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # note that a smaller batch is ignored
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:  # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
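
`backward_and_all_reduce` is not defined in this snippet. Assuming it switches between fusing the gradient exchange into backward (via communicator callbacks) and a plain all_reduce after backward, a sketch might look like:

def backward_and_all_reduce(loss, comm, with_all_reduce_callback=False):
    grads = [v.grad for v in nn.get_parameters().values()]
    if with_all_reduce_callback:
        # Overlap gradient exchange with backward, packing ~2 MiB at a time.
        cb = comm.all_reduce_callback(grads, 2 << 20)
        loss.backward(clear_buffer=True, communicator_callbacks=cb)
    else:
        loss.backward(clear_buffer=True)
        comm.all_reduce(grads, division=False, inplace=True)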
Example #20
0
def main(args):
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = args.conf.data_path
    B = conf.batch_size
    R = conf.n_rays
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)
    di = data_iterator_dtumvs(ds, B)

    camloc = nn.Variable([B, 3])
    raydir = nn.Variable([B, R, 3])
    alpha = nn.Variable.from_numpy_array(conf.alpha)
    color_gt = nn.Variable([B, R, 3])
    mask_obj = nn.Variable([B, R, 1])

    # Monitor
    interval = di.size
    monitor_path = create_monitor_path(conf.data_path, args.monitor_path)
    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=interval)
    monitor_mhit = MonitorSeries("Hit count", monitor, interval=1)
    monitor_color_loss = MonitorSeries(
        "Training color loss", monitor, interval=interval)
    monitor_mask_loss = MonitorSeries(
        "Training mask loss", monitor, interval=interval)
    monitor_eikonal_loss = MonitorSeries(
        "Training eikonal loss", monitor, interval=interval)
    monitor_time = MonitorTimeElapsed(
        "Training time", monitor, interval=interval)
    monitor_image = MonitorImage("Rendered image", monitor, interval=1)

    # Solver
    solver = S.Adam(conf.learning_rate)
    loss, color_loss, mask_loss, eikonal_loss, mask_hit = \
        idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf)
    solver.set_parameters(nn.get_parameters())

    # Training loop
    for i in range(conf.train_epoch):

        ds.change_sampling_idx()

        # Validate
        if i % conf.valid_epoch_interval == 0 and not args.skip_val:
            def validate(i):
                pose_ = ds.poses[conf.valid_index:conf.valid_index+1, ...]
                intrinsic_ = ds.intrinsics[conf.valid_index:conf.valid_index+1, ...]
                mask_obj_ = ds.masks[conf.valid_index:conf.valid_index+1, ...]
                image = render(pose_, intrinsic_, mask_obj_, conf)
                monitor_image.add(i, image)
                nn.save_parameters(f"{monitor_path}/model_{i:05d}.h5")
            validate(i)

        # Train
        for j in range(di.size):
            # Feed data
            color_, mask_, intrinsic_, pose_, xy_ = di.next()
            color_gt.d = color_
            mask_obj.d = mask_
            raydir_, camloc_ = generate_raydir_camloc(pose_, intrinsic_, xy_)
            raydir.d = raydir_
            camloc.d = camloc_

            # Network
            loss.forward()
            solver.zero_grad()
            loss.backward(clear_buffer=True)
            solver.update()

            # Monitor
            t = i * di.size + j
            monitor_mhit.add(t, np.sum(mask_hit.d))
            monitor_loss.add(t, loss.d)
            monitor_color_loss.add(t, color_loss.d)
            monitor_mask_loss.add(t, mask_loss.d)
            monitor_eikonal_loss.add(t, eikonal_loss.d)
            monitor_time.add(t)

        # Decay
        if i in conf.alpha_decay:
            alpha.d = alpha.d * 2.0
        if i in conf.lr_decay:
            solver.set_learning_rate(solver.learning_rate() * 0.5)

    validate(i)
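
`idr_loss` is not shown here; its eikonal term typically regularizes the SDF gradient norm toward 1, i.e. E_x[(||∇f(x)|| - 1)^2]. A minimal sketch of that term alone, assuming `grad_sdf` holds ∇f at sampled points:

import nnabla.functions as F

def eikonal_loss(grad_sdf, eps=1e-12):
    # Penalize deviation of the SDF gradient norm from 1.
    last_axis = len(grad_sdf.shape) - 1
    norm = (F.sum(grad_sdf ** 2.0, axis=last_axis) + eps) ** 0.5
    return F.mean((norm - 1.0) ** 2.0)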
Example #21
0
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('Learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time per epoch",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if (x[Ellipsis, sp:sp + dur].shape[-1] < dur
                        or x.shape[-1] == cnt * dur):
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
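
The `ReduceLROnPlateau` used above is assumed to follow the standard scheduler contract: shrink the learning rate by `factor` once the validation loss has not improved for `patience` epochs. A sketch under that assumption:

class ReduceLROnPlateau:
    def __init__(self, lr, factor=0.3, patience=10):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def update_lr(self, metric, epoch=None):
        if metric < self.best:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr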
Example #22
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get the next minibatch.
      * Execute forwardprop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, rng)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if device_id == 0:
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
Example #23
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get the next minibatch.
      * Execute forwardprop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * In-place allreduce (the main difference from single-device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx = extension_context(extension_module, device_id=device_id)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = cifar100_resnet23_prediction(
        image_train, ctx, test)
    loss_train = cifar100_resnet32_loss(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar100_resnet23_prediction(
        image_valid, ctx, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = 1. * n_devices / warmup_iter

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)
    with data_iterator_cifar100(args.batch_size, True) as tdata, \
            data_iterator_cifar100(bs_valid, False) as vdata:
        # Training-loop
        for i in range(int(args.max_iter / n_devices)):
            # Validation
            if mpi_rank == 0:
                if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                    ve = 0.
                    for j in range(args.val_iter):
                        image, label = vdata.next()
                        input_image_valid["image"].d = image
                        pred_valid.forward()
                        ve += categorical_error(pred_valid.d, label)
                    ve /= args.val_iter
                    monitor_verr.add(i * n_devices, ve)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(os.path.join(
                        args.model_save_path, 'params_%06d.h5' % i))

            # Forward/Zerograd/Backward
            image, label = tdata.next()
            input_image_train["image"].d = image
            input_image_train["label"].d = label
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()

            # In-place Allreduce
            comm.allreduce(division=True)

            # Solvers update
            solver.update()

            # Linear Warmup
            if i < warmup_iter:
                # Ramp linearly to base_lr * n_devices at i == warmup_iter.
                lr = base_lr * warmup_slope * i
                solver.set_learning_rate(lr)
            else:
                lr = base_lr * n_devices
                solver.set_learning_rate(lr)

            if mpi_rank == 0:
                e = categorical_error(
                    pred_train.d, input_image_train["label"].d)
                monitor_loss.add(i * n_devices, loss_train.d.copy())
                monitor_err.add(i * n_devices, e)
                monitor_time.add(i * n_devices)
    if mpi_rank == 0:
        nn.save_parameters(os.path.join(
            args.model_save_path,
            'params_%06d.h5' % (args.max_iter / n_devices)))
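
With warmup_slope = n_devices / warmup_iter, the warmup above ramps the learning rate linearly from 0 to base_lr * n_devices, meeting the post-warmup rate exactly at i == warmup_iter. A quick endpoint check with hypothetical numbers:

base_lr, n_devices, warmup_iter = 0.001, 4, 1000
warmup_slope = 1. * n_devices / warmup_iter
assert abs(base_lr * warmup_slope * warmup_iter - base_lr * n_devices) < 1e-12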
Example #24
0
    def __init__(self,
                 solver,
                 tinput=None,
                 tlabel=None,
                 tpred=None,
                 tdata=None,
                 vinput=None,
                 vlabel=None,
                 vpred=None,
                 vdata=None,
                 monitor_path=None,
                 model_save_path=None,
                 max_epoch=1,
                 iter_per_epoch=None,
                 val_iter=None):
        # Monitors
        monitor = Monitor(monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
        monitor_vloss = MonitorSeries("Valid loss", monitor, interval=1)
        monitor_time = MonitorTimeElapsed("Training time",
                                          monitor,
                                          interval=10)

        # Loss
        tpred = tpred.apply(persistent=True)
        tloss = F.mean(F.squared_error(tpred, tlabel))
        vpred = vpred.apply(persistent=True)
        vloss = F.mean(F.squared_error(vpred, vlabel))

        # Updater
        def tdata_feeder():
            tinput.d, tlabel.d = tdata.next()

        def update_callback_on_finish(i):
            monitor_loss.add(i, tloss.d)
            monitor_time.add(i)

        updater = Updater(
            solver,
            tloss,
            data_feeder=tdata_feeder,
            # `forward_callback_on_finish` is assumed to be defined in the
            # enclosing scope; it does not appear in this snippet.
            forward_callback_on_finish=forward_callback_on_finish,
            update_callback_on_finish=update_callback_on_finish)

        # Evaluator
        def vdata_feeder():
            vinput.d, vlabel.d = vdata.next()

        def vloss_callback_on_finish(i, v):
            monitor_vloss.add(i, v)

        val_iter = val_iter if val_iter is not None else vdata.size // vdata.batch_size
        evaluator = Evaluator(vloss,
                              data_feeder=vdata_feeder,
                              val_iter=val_iter,
                              callback_on_finish=vloss_callback_on_finish)

        # Trainer
        iter_per_epoch = iter_per_epoch if iter_per_epoch is not None \
            else tdata.size // tdata.batch_size
        self.trainer = Trainer(updater,
                               evaluator,
                               model_save_path,
                               max_epoch=max_epoch,
                               iter_per_epoch=iter_per_epoch)
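
The Updater/Evaluator/Trainer classes wired up here are assumed to follow the callback contract visible above (feed data, forward, fire the forward callback, backward, update, fire the update callback). A minimal Updater sketch under that assumption:

class Updater:
    def __init__(self, solver, loss, data_feeder,
                 forward_callback_on_finish=None,
                 update_callback_on_finish=None):
        self.solver, self.loss = solver, loss
        self.data_feeder = data_feeder
        self.forward_cb = forward_callback_on_finish
        self.update_cb = update_callback_on_finish

    def update(self, i):
        # One training iteration: feed, forward, backward, update.
        self.data_feeder()
        self.solver.zero_grad()
        self.loss.forward(clear_no_need_grad=True)
        if self.forward_cb:
            self.forward_cb(i)
        self.loss.backward(clear_buffer=True)
        self.solver.update()
        if self.update_cb:
            self.update_cb(i)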
Example #25
0
    # [def solver]
    if g_optimizer == 'SGD':
        solver = S.Sgd(g_default_learning_rate)
    elif g_optimizer == 'Adam':
        solver = S.Adam(g_default_learning_rate)
    elif g_optimizer == 'AdaBound':
        solver = S.AdaBound(g_default_learning_rate)
    solver.set_parameters(coef_dict_source)

    # [def monitor]
    monitor = Monitor(g_save_log_dir)
    monitor_loss = MonitorSeries("Training loss",
                                 monitor,
                                 interval=g_monitor_interval)
    monitor_time = MonitorTimeElapsed("Training time",
                                      monitor,
                                      interval=g_monitor_interval)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)

    # [training]
    iter_num = 0
    start_index = 0
    newest_model = 'dummy'
    for epoch in range(g_max_epoch):
        for i in range(iter_num_max):
            data_batch, start_index = get_data_batch(my_tdata, start_index,
                                                     g_batch_size, tdata_num)
            image_batch = np.array(
                data_batch['image'])  # shape=(batch_size, 3, 300, 300)
            label_batch = np.array(data_batch['label'])
Example #26
0
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get the next minibatch.
      * Execute forwardprop.
      * Set parameter gradients to zero.
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(args.max_iter):
        # Validation
        if i % int(n_train_samples / args.batch_size) == 0:
            ve = 0.
            for j in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i, ve)
        if int(i % args.model_save_interval) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Solvers update
        solver.update()

        e = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % (args.max_iter)))
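
`loss_function` is not defined in this snippet; it is assumed to be the usual mean softmax cross-entropy, matching the loss built inline in the MNIST examples:

import nnabla.functions as F

def loss_function(pred, label):
    # Mean softmax cross-entropy over the minibatch.
    return F.mean(F.softmax_cross_entropy(pred, label))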
Example #27
0
def train():
    '''
    Main script.
    '''
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # TRAIN
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    x = image / 255.0
    t_onehot = F.one_hot(label, (10, ))
    with nn.parameter_scope("capsnet"):
        c1, pcaps, u_hat, caps, pred = model.capsule_net(
            x,
            test=False,
            aug=True,
            grad_dynamic_routing=args.grad_dynamic_routing)
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot)
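    # Total CapsNet loss: margin loss on capsule lengths plus a down-weighted
    # reconstruction loss, following Sabour et al. (2017).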
    loss_margin, loss_reconst, loss = model.capsule_loss(
        pred, t_onehot, recon, x)
    pred.persistent = True

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vx = vimage / 255.0
    with nn.parameter_scope("capsnet"):
        _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    train_iter = int(60000 / args.batch_size)
    val_iter = int(10000 / args.batch_size)
    logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter))
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1)
    monitor_rloss = MonitorSeries("Training reconstruction loss",
                                  monitor,
                                  interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_lr = MonitorSeries("Learning rate", monitor, interval=1)

    # Save the network in NNP format (epoch-0 snapshot)
    m_image, m_label, m_noise, m_recon = model_tweak_digitscaps(
        args.batch_size)
    contents = save_nnp({
        'x1': m_image,
        'x2': m_label,
        'x3': m_noise
    }, {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_epoch0_result.nnp'),
              contents)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)
    # Training loop.
    for e in range(start_point, args.max_epochs):

        # Learning rate decay
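        # (exponential schedule: lr_e = lr_0 * 0.9**e, with lr_0 kept for
        # epoch 0)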
        learning_rate = solver.learning_rate()
        if e != 0:
            learning_rate *= 0.9
        solver.set_learning_rate(learning_rate)
        monitor_lr.add(e, learning_rate)

        # Training
        train_error = 0.0
        train_loss = 0.0
        train_mloss = 0.0
        train_rloss = 0.0
        for i in range(train_iter):
            image.d, label.d = data.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            train_error += categorical_error(pred.d, label.d)
            train_loss += loss.d
            train_mloss += loss_margin.d
            train_rloss += loss_reconst.d
        train_error /= train_iter
        train_loss /= train_iter
        train_mloss /= train_iter
        train_rloss /= train_iter

        # Validation
        val_error = 0.0
        for j in range(val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            val_error += categorical_error(vpred.d, vlabel.d)
        val_error /= val_iter

        # Monitor
        monitor_time.add(e)
        monitor_loss.add(e, train_loss)
        monitor_mloss.add(e, train_mloss)
        monitor_rloss.add(e, train_rloss)
        monitor_err.add(e, train_error)
        monitor_verr.add(e, val_error)
        save_checkpoint(args.monitor_path, e, solver)

    # Save the trained network in NNP format
    contents = save_nnp({
        'x1': m_image,
        'x2': m_label,
        'x3': m_noise
    }, {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_result.nnp'), contents)
Example #28
0
def train():
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # SSL Regularization
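    # (group-sparsity penalties over convolution filters and channels,
    # weighted by the filter_decay / channel_decay arguments; cf. Structured
    # Sparsity Learning, Wen et al., 2016)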
    loss += ssl_regularization(nn.get_parameters(), args.filter_decay,
                               args.channel_decay)

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Example #29
0
def main():
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    num_batches = num_images // args.batch_size

    # DataIterator
    di = data_iterator(args.img_path,
                       args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)
    # generator
    gen = load_gen(args.model_load_path,
                   use_bn=args.use_bn,
                   last_act=args.last_act,
                   use_wscale=args.not_use_wscale,
                   use_he_backward=args.use_he_backward)

    # compute metric
    if args.validation_metric == "ms-ssim":
        logger.info("Multi Scale SSIM")
        monitor_time = MonitorTimeElapsed("MS-SSIM-ValidationTime",
                                          monitor,
                                          interval=1)
        monitor_metric = MonitorSeries("MS-SSIM", monitor, interval=1)
        from ms_ssim import compute_metric
        score = compute_metric(gen, args.batch_size, num_images, args.latent,
                               args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)
    elif args.validation_metric == "swd":
        logger.info("Sliced Wasserstein Distance")
        monitor_time = MonitorTimeElapsed("SWD-ValidationTime",
                                          monitor,
                                          interval=1)
        monitor_metric = MonitorSeries("SWD", monitor, interval=1)
        nhoods_per_image = 128
        nhood_size = 7
        level_list = [128, 64, 32, 16]  # TODO: use argument
        dir_repeats = 4
        dirs_per_repeat = 128
        from sliced_wasserstein import compute_metric
        score = compute_metric(di, gen, args.latent, num_batches,
                               nhoods_per_image, nhood_size, level_list,
                               dir_repeats, dirs_per_repeat, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)  # averaged in the log
    else:
        raise ValueError(
            "Set `validation-metric` to either `ms-ssim` or `swd`.")
    logger.info(score)
    logger.info("End validation")
Example #30
0
def train(args):
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty
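    # WGAN-GP style: sample random points on lines between real and fake
    # samples, and penalize the critic for gradient norms there that deviate
    # from 1, i.e. lambda * E[(||grad D(x_rmix)||_2 - 1)^2].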
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need to backprop into the generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # backprop into the generator is needed
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Example #31
0
def main(opt):
    '''
    NNabla configuration
    '''
    # os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    type_config = 'half' if opt.mixed_precision else 'float'
    comm = init_nnabla(ext_name=opt.extension_module,
                       device_id='0', type_config=type_config)
    nn.set_auto_forward(True)
    output_folder = os.path.join(
        opt.save_dir, "tmp.monitor.{}_{}".format(opt.arch, opt.num_layers))
    monitor = Monitor(output_folder)
    monitor_loss = None
    monitor_hm_loss = None
    monitor_wh_loss = None
    monitor_off_loss = None
    monitor_acc = None
    monitor_val_loss = None
    monitor_val_hm_loss = None
    monitor_val_wh_loss = None
    monitor_val_off_loss = None
    monitor_map = None
    monitor_time = None
    Detector = detector_factory[opt.task]
    detector = Detector(opt)

    interval = 1
    if comm.rank == 0:
        monitor_loss = MonitorSeries(
            "Training Loss", monitor, interval=interval, verbose=False)
        monitor_hm_loss = MonitorSeries(
            "hm_loss", monitor, interval=interval, verbose=False)
        monitor_wh_loss = MonitorSeries(
            "wh_loss", monitor, interval=interval, verbose=False)
        monitor_off_loss = MonitorSeries(
            "off_loss", monitor, interval=interval, verbose=False)
        monitor_val_loss = MonitorSeries(
            "Validation Loss", monitor, interval=interval, verbose=False)
        monitor_val_hm_loss = MonitorSeries(
            "val_hm_loss", monitor, interval=interval, verbose=False)
        monitor_val_wh_loss = MonitorSeries(
            "val_wh_loss", monitor, interval=interval, verbose=False)
        monitor_val_off_loss = MonitorSeries(
            "val_off_loss", monitor, interval=interval, verbose=False)
        monitor_map = MonitorSeries(
            "Val mAP", monitor, interval=interval, verbose=False)
        monitor_time = MonitorTimeElapsed(
            "time", monitor, interval=1, verbose=False)
    '''
    Data Iterators
    '''
    seed = opt.seed
    rng = np.random.RandomState(seed)
    source_factory = get_data_source(opt.dataset)
    train_source = source_factory(opt, 'train', shuffle=True, rng=rng,
                                  mixed_precision=opt.mixed_precision, channel_last=opt.channel_last)
    train_loader = data_iterator(train_source,
                                 opt.batch_size,
                                 with_memory_cache=False,
                                 with_file_cache=False
                                 )
    train_loader = train_loader.slice(rng, comm.n_procs, slice_pos=comm.rank)
    val_source = source_factory(opt, 'val', shuffle=False, rng=rng,
                                mixed_precision=opt.mixed_precision, channel_last=opt.channel_last)
    val_loader = data_iterator(val_source,
                               opt.batch_size,
                               with_memory_cache=False,
                               with_file_cache=False
                               )
    logger.info('Creating model...')
    logger.info(opt.heads)
    logger.info(f"batch size per gpu: {opt.batch_size}")
    model = create_model(opt.arch, opt.heads, opt.head_conv, opt.num_layers, training=True,
                         channel_last=opt.channel_last, pretrained_model_dir=opt.pretrained_model_dir)
    if opt.checkpoint != '':
        load_model(model, opt.checkpoint, clear=True)

    start_epoch = 0
    loss_func = CtdetLoss(opt)
    lr_sched = create_learning_rate_scheduler(
        opt.train_config.learning_rate_config)
    solver = S.Adam(alpha=lr_sched.get_lr())
    trainer = Trainer(
        model, loss_func, solver, train_loader, train_source,
        [monitor_loss, monitor_hm_loss, monitor_wh_loss, monitor_off_loss,
         monitor_val_loss, monitor_val_hm_loss, monitor_val_wh_loss,
         monitor_val_off_loss],
        opt, comm)

    root_dir = opt.save_dir

    checkpoint_dir = os.path.join(root_dir, output_folder, 'checkpoints')
    if opt.resume_from is not None:
        start_epoch = trainer.load_checkpoint(checkpoint_dir, opt.resume_from)
        logger.info('resuming from epoch {}'.format(start_epoch))

    for epoch in range(start_epoch, opt.num_epochs):
        lr_sched.set_epoch(epoch)
        trainer.solver.set_learning_rate(lr_sched.get_lr())
        iteration = trainer.update(epoch)
        if comm.rank == 0:
            if epoch % opt.save_intervals == 0 or epoch == (opt.num_epochs-1):
                monitor_time.add(epoch)
                trainer.save_checkpoint(checkpoint_dir, epoch)

        if epoch % opt.val_intervals == 0 or epoch == (opt.num_epochs-1):
            model.training = False
            trainer.evaluate(val_loader, epoch)
            if not opt.val_calc_map:
                num_iters = val_loader.size
                pbar = trange(num_iters, desc="[Test][exp_id:{} epoch:{}/{}]".format(
                    opt.exp_id, epoch, opt.num_epochs), disable=comm.rank > 0)
                if comm.rank == 0:
                    results = {}
                    for ind in pbar:
                        img_id = val_source.images[ind]
                        img_info = val_source.coco.loadImgs(ids=[img_id])[0]
                        img_path = os.path.join(
                            val_source.img_dir, img_info['file_name'])
                        with nn.context_scope(comm.ctx_float):
                            ret = detector.run(img_path)
                        results[img_id] = ret['results']
                    val_map = val_source.run_eval(
                        results, opt.save_dir, opt.data_dir)
                    monitor_map.add(epoch, val_map)
            model.training = True
            if comm.n_procs > 1:
                # Required to prevent timeout error of allreduce
                # at the first iteration of the next epoch.
                comm.comm.barrier()
Example #32
0
def classification_svd():
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create the "slim" prediction graph ("reference" parameters are loaded
    # and decomposed below).
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    pred.persistent = True

    # Decompose the reference parameters via truncated SVD and set them
    # (a NumPy sketch of the idea follows this example)
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create reference prediction graph.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
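
A minimal NumPy sketch of the low-rank idea behind
decompose_network_and_set_params above (the helper itself is defined
elsewhere in the repository; `svd_decompose` and its `rrate` argument are
illustrative only): a trained weight matrix is factorized with a truncated
SVD so that one large affine layer can be replaced by two smaller ones.

import numpy as np

def svd_decompose(W, rrate=0.5):
    # Factorize W (in_dim x out_dim) into W1 @ W2, keeping `rrate` of the rank.
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    rank = max(1, int(len(s) * rrate))   # reduced inner dimension
    W1 = U[:, :rank] * s[:rank]          # (in_dim, rank)
    W2 = Vt[:rank, :]                    # (rank, out_dim)
    return W1, W2

W = np.random.randn(128, 64)
W1, W2 = svd_decompose(W)
print(np.linalg.norm(W - W1 @ W2) / np.linalg.norm(W))  # relative error
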
Example #33
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct computation graphs for training and one for validation.
    * Initialize solvers and set parameter variables to those.
    * Instantiate a communicator and set parameter variables.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprops
      * Set parameter gradients zero
      * Execute backprop.
      * In-place allreduce (THIS IS THE MAIN difference from a single device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Create contexts
    extension_module = args.context
    if extension_module != "cuda" and \
            extension_module != "cuda.cudnn":
        raise Exception("Use `cuda` or `cuda.cudnn` extension_module.")
    n_devices = args.n_devices
    ctxs = []
    for i in range(n_devices):
        ctx = extension_context(extension_module, device_id=i)
        ctxs.append(ctx)
    ctx = ctxs[-1]

    # Create training graphs
    input_image_train = []
    preds_train = []
    losses_train = []
    test = False
    for i in range(n_devices):
        image = nn.Variable((args.batch_size, 3, 32, 32))
        label = nn.Variable((args.batch_size, 1))
        device_scope_name = "device{}".format(i)

        pred = cifar10_resnet23_prediction(image, ctxs[i], device_scope_name,
                                           test)
        loss = cifar10_resnet32_loss(pred, label)

        input_image_train.append({"image": image, "label": label})
        preds_train.append(pred)
        losses_train.append(loss)

    # Create validation graph
    test = True
    device_scope_name = "device{}".format(0)
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar10_resnet23_prediction(image_valid, ctxs[0],
                                             device_scope_name, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solvers = []
    for i in range(n_devices):
        with nn.context_scope(ctxs[i]):
            solver = S.Adam()
            device_scope_name = "device{}".format(i)
            with nn.parameter_scope(device_scope_name):
                params = nn.get_parameters()
                solver.set_parameters(params)
            solvers.append(solver)

    # Communicator
    comm = C.DataParalellCommunicator(ctx)
    for i in range(n_devices):
        device_scope_name = "device{}".format(i)
        with nn.parameter_scope(device_scope_name):
            ctx = ctxs[i]
            params = nn.get_parameters()
            comm.add_context_and_parameters((ctx, params))
    comm.init()

    # Create one single-thread pool per device
    pools = []
    for _ in range(n_devices):
        pool = ThreadPool(processes=1)
        pools.append(pool)

    # Run forward/backward once to safely allocate device memory
    for device_id in range(n_devices):
        data, label = \
            (np.random.randn(*input_image_train[device_id]["image"].shape),
             (np.random.rand(*input_image_train[device_id]["label"].shape) * 10).astype(np.int32))

        ret = pools[device_id].apply_async(
            forward_backward, (input_image_train[device_id]["image"], data,
                               input_image_train[device_id]["label"], label,
                               losses_train[device_id], solvers[device_id]))
        ret.get()
        losses_train[device_id].d  # sync to host

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)  # device_id is left at n_devices - 1 here
    tdata = data_iterator_cifar10(args.batch_size, True, rng)
    vdata = data_iterator_cifar10(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve = 0.
            for j in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i * n_devices, ve)
        if i % int(args.model_save_interval / n_devices) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forwards/Zerograd/Backwards
        fb_results = []
        for device_id in range(n_devices):
            image, label = tdata.next()

            res = pools[device_id].apply_async(
                forward_backward,
                (input_image_train[device_id]["image"], image,
                 input_image_train[device_id]["label"], label,
                 losses_train[device_id], solvers[device_id]))
            fb_results.append(res)
        for device_id in range(n_devices):
            fb_results[device_id].get()

        # Allreduce gradients across devices
        # (division=True averages them; note inplace=False here)
        comm.allreduce(division=True, inplace=False)

        # Solvers update
        for device_id in range(n_devices):
            solvers[device_id].update()

        e = categorical_error(preds_train[-1].d,
                              input_image_train[-1]["label"].d)
        monitor_loss.add(i * n_devices, losses_train[-1].d.copy())
        monitor_err.add(i * n_devices, e)
        monitor_time.add(i * n_devices)

    nn.save_parameters(
        os.path.join(args.model_save_path,
                     'params_%06d.h5' % (args.max_iter // n_devices)))
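
The allreduce call above is what distinguishes this loop from single-device
training. Below is a library-agnostic NumPy sketch of that data-parallel
step, simulating devices with a list of gradient arrays instead of nnabla's
communicator, and assuming a plain SGD update:

import numpy as np

def allreduce_mean(grads):
    # Average gradients across devices; this is the effect of
    # comm.allreduce(division=True) on the registered parameters.
    mean = sum(grads) / len(grads)
    return [mean.copy() for _ in grads]

n_devices, lr = 4, 0.1
params = [np.ones(3) for _ in range(n_devices)]          # replicated parameters
grads = [np.random.randn(3) for _ in range(n_devices)]   # per-device gradients

grads = allreduce_mean(grads)      # every device now holds the same gradient
for d in range(n_devices):
    params[d] -= lr * grads[d]     # identical updates keep replicas in sync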