Example #1
def main(args):
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = conf.data_path
    B = conf.batch_size
    R = conf.n_rays
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    ctx = get_extension_context('cudnn', device_id=device_id)
    nn.set_default_context(ctx)

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)

    # Monitor
    monitor_path = "/".join(args.model_load_path.split("/")[0:-1])
    monitor = Monitor(monitor_path)
    monitor_psnrs = MonitorSeries("PSNRs", monitor, interval=1)
    monitor_psnr = MonitorSeries("PSNR", monitor, interval=1)

    # Load model
    nn.load_parameters(args.model_load_path)

    # Evaluate
    image_list = []
    for pose, intrinsic, mask_obj in zip(ds.poses, ds.intrinsics, ds.masks):
        image = render(pose[np.newaxis, ...], intrinsic[np.newaxis, ...],
                       mask_obj[np.newaxis, ...], conf)
        image_list.append(image)

    metric = psnr(image_list, ds.images, ds.masks, monitor_psnrs)
    monitor_psnr.add(0, metric)
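
A note on the pattern above: MonitorSeries buffers the values passed to add() and writes their average to a text log under the Monitor directory once per interval indices. A minimal self-contained sketch (the "tmp.monitor" path and the random metric are illustrative, not part of the example):

import numpy as np
from nnabla.monitor import Monitor, MonitorSeries

monitor = Monitor("tmp.monitor")
# interval=10: add() calls are averaged and flushed every 10 indices.
series = MonitorSeries("PSNR", monitor, interval=10)
for i in range(100):
    series.add(i, np.random.rand())  # dummy metric value
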
Example #2
def meta_test(args, shape_x, test_data):

    # Build episode generators
    test_episode_generator = EpisodeGenerator(
        test_data[0], test_data[1], args.n_class, args.n_shot, args.n_query)

    # Build prototypical network
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding,
               args.net_type, args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Load parameters
    nn.load_parameters(args.work_dir + "/params.h5")

    # Evaluate error rate
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))  # np.float was removed in NumPy 1.24
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)

    # Monitor error rate
    monitor = Monitor(args.work_dir)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    return v_err_mean, v_err_conf
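
The error bar computed above is the normal-approximation 95% confidence half-width, 1.96 * std / sqrt(n). A self-contained sketch with hypothetical per-episode error rates:

import numpy as np

def mean_with_95ci(values):
    # 1.96 is the two-sided 95% quantile of the standard normal.
    values = np.asarray(values, dtype=np.float64)
    return values.mean(), 1.96 * values.std() / np.sqrt(len(values))

errs = [0.42, 0.38, 0.45, 0.40]  # hypothetical episode error rates
mean, conf = mean_with_95ci(errs)
print("%.2f%% +/- %.2f%%" % (mean * 100, conf * 100))
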
Example #3
def evaluate(args):
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, y = get_input_and_output(nnp, args.batch_size, name=args.variable_name)

    if args.evaluation_metric == "IS":
        is_model = None
        compute_metric = compute_inception_score
    elif args.evaluation_metric == "FID":
        di = data_iterator_imagenet(args.valid_dir,
                                    args.dirname_to_label_path,
                                    batch_size=args.batch_size,
                                    ih=args.image_size,
                                    iw=args.image_size,
                                    shuffle=True,
                                    train=False,
                                    noise=False)
        compute_metric = functools.partial(compute_frechet_inception_distance,
                                           di=di)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_metric = MonitorSeries("{}".format(args.evaluation_metric),
                                   monitor,
                                   interval=1)

    # Compute the evaluation metric for all models
    def cmp_func(path):
        # Extract the iteration number from a checkpoint name like
        # "params_001000.h5". str.strip()/rstrip() strip *characters*, not
        # substrings, so the original could eat trailing digits (a final "5"
        # is in the ".h5" character set); slice instead.
        name = os.path.basename(path)
        return int(name[len("params_"):-len(".h5")])
    model_load_path = sorted(glob.glob("{}/*.h5".format(args.model_load_path)), key=cmp_func) \
        if os.path.isdir(args.model_load_path) else \
        [args.model_load_path]

    for path in model_load_path:
        # Model (SAGAN)
        nn.load_parameters(path)
        z = nn.Variable([batch_size, latent])
        y_fake = nn.Variable([batch_size])
        x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
            .apply(persistent=True)
        # Compute the evaluation metric
        score = compute_metric(z, y_fake, x_fake, x, y, args)
        itr = cmp_func(path)
        monitor_metric.add(itr, score)
Example #4
def train(env,
          model,
          buffer,
          exploration,
          monitor,
          update_fn,
          eval_fn,
          final_step,
          update_start,
          update_interval,
          save_interval,
          evaluate_interval,
          loss_labels=()):  # tuple default avoids the mutable-default pitfall
    reward_monitor = MonitorSeries('reward', monitor, interval=1)
    eval_reward_monitor = MonitorSeries('eval_reward', monitor, interval=1)
    time_monitor = MonitorTimeElapsed('time', monitor, interval=10000)
    loss_monitors = []
    for label in loss_labels:
        loss_monitors.append(MonitorSeries(label, monitor, interval=10000))

    step = 0
    while step <= final_step:
        obs_t = env.reset()
        ter_tp1 = False
        cumulative_reward = 0.0
        model.reset(step)
        while not ter_tp1:
            # select best action
            act_t = model.infer(obs_t)
            # add exploration noise
            act_t = exploration.get(step, act_t)
            # iterate environment
            obs_tp1, rew_tp1, ter_tp1, _ = env.step(act_t)
            # store transition
            buffer.append(obs_t, [act_t], rew_tp1, obs_tp1, ter_tp1)

            # update parameters
            if step > update_start and step % update_interval == 0:
                for i, loss in enumerate(update_fn(step)):
                    if loss is not None:
                        loss_monitors[i].add(step, loss)

            # save parameters
            if step % save_interval == 0:
                path = os.path.join(monitor.save_path, 'model_%d.h5' % step)
                nn.save_parameters(path)

            if step % evaluate_interval == 0:
                eval_reward_monitor.add(step, np.mean(eval_fn()))

            step += 1
            cumulative_reward += rew_tp1
            obs_t = obs_tp1
            time_monitor.add(step)

        # record metrics
        reward_monitor.add(step, cumulative_reward)
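
The loss_labels argument above pairs each value returned by update_fn with its own MonitorSeries by position. A minimal sketch of the same wiring, with hypothetical labels and a dummy update result (the example itself uses interval=10000):

from nnabla.monitor import Monitor, MonitorSeries

monitor = Monitor("tmp.rl.monitor")          # illustrative path
loss_labels = ["critic_loss", "actor_loss"]  # hypothetical labels
loss_monitors = [MonitorSeries(label, monitor, interval=1)
                 for label in loss_labels]
for i, loss in enumerate([0.5, 0.1]):        # stands in for update_fn(step)
    if loss is not None:
        loss_monitors[i].add(0, loss)
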
Example #5
def main():
    # Training settings
    args = Yolov2OptionTraining().parse_args()

    nsamples = file_lines(args.train)

    set_default_context_by_args(args)

    # Training parameters
    max_epochs = args.max_batches * args.batch_size * args.accum_times / nsamples + 1

    if not os.path.exists(args.output):
        os.mkdir(args.output)

    ###############

    # Load parameters
    print("Load", args.weight, "...")
    if args.fine_tune:
        nn.load_parameters(args.weight)
        nn.parameter.pop_parameter("detection/conv/W")
        nn.parameter.pop_parameter("detection/conv/b")
    else:
        nn.load_parameters(args.weight)

    train_graph = TrainGraph(args, (args.size_aug[-1], args.size_aug[-1]))
    yolo_solver = YoloSolver(args)

    if args.on_memory_data:
        on_memory_data = dataset.load_on_memory_data(args.train)
    else:
        on_memory_data = None

    prefetch_iterator = PrefetchIterator()

    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.output)
    monitor_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_time = MonitorTimeElapsed('Time per epoch', monitor, interval=1)

    # Epoch loop
    for epoch in range(0, int(max_epochs)):
        loss = train(args, epoch, max_epochs, train_graph, yolo_solver,
                     prefetch_iterator, on_memory_data)
        monitor_loss.add(epoch, loss)
        monitor_time.add(epoch)

    # Save the final parameters
    logging('save weights to %s/%06d.h5' % (args.output, epoch + 1))
    nn.save_parameters('%s/%06d.h5' % (args.output, epoch + 1))
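
The max_epochs expression above converts an iteration budget (max_batches steps, each consuming batch_size * accum_times samples) into an epoch count. A worked illustration with made-up numbers:

max_batches, batch_size, accum_times = 500200, 64, 2
nsamples = 117264  # hypothetical number of lines in the training list
max_epochs = max_batches * batch_size * accum_times / nsamples + 1
print(int(max_epochs))  # -> 546 for these (made-up) settings
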
Example #6
def main():
    # Context
    ctx = get_extension_context("cudnn", device_id="0")
    nn.set_default_context(ctx)
    # nn.auto_forward() is a context manager; set_auto_forward() sets the
    # global flag.
    nn.set_auto_forward(False)
    # Inputs
    b, c, h, w = 64, 1, 28, 28
    x = nn.Variable([b, c, h, w])
    t = nn.Variable([b, 1])
    vx = nn.Variable([b, c, h, w])
    vt = nn.Variable([b, 1])
    # Model
    model = Model()
    pred = model(x)
    loss = F.softmax_cross_entropy(pred, t)
    vpred = model(vx, test=True)
    verror = F.top_n_error(vpred, vt)
    # Solver
    solver = S.Adam()
    solver.set_parameters(model.get_parameters(grad_only=True))
    # Data Iterator
    tdi = data_iterator_mnist(b, train=True)
    vdi = data_iterator_mnist(b, train=False)
    # Monitor
    monitor = Monitor("tmp.monitor")
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Training loop
    for e in range(1):
        for j in range(tdi.size // b):
            i = e * tdi.size // b + j
            x.d, t.d = tdi.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            monitor_loss.add(i, loss.d)
        error = 0.0
        for _ in range(vdi.size // b):
            vx.d, vt.d = vdi.next()
            verror.forward(clear_buffer=True)
            error += verror.d
        error /= vdi.size // b
        monitor_verr.add(i, error)
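
A small point about the index passed to the monitors above: i = e * tdi.size // b + j is a single monotonically increasing counter across epochs, which is what interval-based flushing expects. For instance:

tdi_size, b = 60000, 64    # MNIST training set, batch size
e, j = 2, 10
i = e * tdi_size // b + j  # 1885: globally unique, increasing index
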
Example #7
    def train(self):
        # variables for training
        tx_in = nn.Variable(
            [self._batch_size, self._x_input_length, self._cols_size])
        tx_out = nn.Variable(
            [self._batch_size, self._x_output_length, self._cols_size])
        tpred = self.network(tx_in, self._lstm_unit_name, self._lstm_units)
        tpred.persistent = True
        loss = F.mean(F.squared_error(tpred, tx_out))
        solver = S.Adam(self._learning_rate)
        solver.set_parameters(nn.get_parameters())

        # variables for validation
        vx_in = nn.Variable(
            [self._batch_size, self._x_input_length, self._cols_size])
        vx_out = nn.Variable(
            [self._batch_size, self._x_output_length, self._cols_size])
        vpred = self.network(vx_in, self._lstm_unit_name, self._lstm_units)

        # data iterators
        tdata = self._load_dataset(self._training_dataset_path,
                                   self._batch_size,
                                   shuffle=True)
        vdata = self._load_dataset(self._validation_dataset_path,
                                   self._batch_size,
                                   shuffle=True)

        # monitors
        from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
        monitor = Monitor(self._monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
        monitor_err = MonitorSeries("Training error", monitor, interval=10)
        monitor_time = MonitorTimeElapsed("Training time",
                                          monitor,
                                          interval=100)
        monitor_verr = MonitorSeries("Validation error", monitor, interval=10)

        # Training loop
        for i in range(self._max_iter):
            if i % self._val_interval == 0:
                ve = self._validate(vpred, vx_in, vx_out, vdata,
                                    self._val_iter)
                monitor_verr.add(i, ve / self._val_iter)
            te = self._train(tpred, solver, loss, tx_in, tx_out, tdata.next(),
                             self._weight_decay)
            monitor_loss.add(i, loss.d.copy())
            monitor_err.add(i, te)
            monitor_time.add(i)
        ve = self._validate(vpred, vx_in, vx_out, vdata, self._val_iter)
        monitor_verr.add(i, ve / self._val_iter)

        # Save the best model parameters.
        nn.save_parameters(self._model_params_path)
Example #8
def train_loop(env, model, num_actions, return_fn, logdir, eval_fn, args):
    monitor = Monitor(logdir)
    reward_monitor = MonitorSeries('reward', monitor, interval=1)
    eval_reward_monitor = MonitorSeries('eval_reward', monitor, interval=1)
    policy_loss_monitor = MonitorSeries('policy_loss', monitor, interval=10000)
    value_loss_monitor = MonitorSeries('value_loss', monitor, interval=10000)
    def sample_action(probs):
        # Sample a discrete action index from the policy's probabilities.
        return np.random.choice(num_actions, p=probs)

    step = 0
    obs_t = env.reset()
    cumulative_reward = np.zeros(len(env.envs), dtype=np.float32)
    obss_t, acts_t, vals_t, rews_tp1, ters_tp1 = [], [], [], [], []
    while step <= args.final_step:
        # inference
        probs_t, val_t = model.infer(pixel_to_float(obs_t))
        # sample actions
        act_t = list(map(sample_action, probs_t))
        # move environment
        obs_tp1, rew_tp1, ter_tp1, _ = env.step(act_t)
        # clip reward between [-1.0, 1.0]
        clipped_rew_tp1 = np.clip(rew_tp1, -1.0, 1.0)

        obss_t.append(obs_t)
        acts_t.append(act_t)
        vals_t.append(val_t)
        rews_tp1.append(clipped_rew_tp1)
        ters_tp1.append(ter_tp1)

        # update parameters
        if len(obss_t) == args.time_horizon:
            vals_t.append(val_t)
            rets_t = return_fn(vals_t, rews_tp1, ters_tp1)
            advs_t = rets_t - vals_t[:-1]
            policy_loss, value_loss = model.train(pixel_to_float(obss_t),
                                                  acts_t, rets_t, advs_t, step)
            policy_loss_monitor.add(step, policy_loss)
            value_loss_monitor.add(step, value_loss)
            obss_t, acts_t, vals_t, rews_tp1, ters_tp1 = [], [], [], [], []

        # save parameters
        cumulative_reward += rew_tp1
        obs_t = obs_tp1

        for i, ter in enumerate(ter_tp1):
            step += 1
            if ter:
                reward_monitor.add(step, cumulative_reward[i])
                cumulative_reward[i] = 0.0
            if step % 10**6 == 0:
                path = os.path.join(logdir, 'model_{}.h5'.format(step))
                nn.save_parameters(path)
                eval_reward_monitor.add(step, eval_fn())
Example #9
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    elif args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, rng)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if device_id == 0:
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
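
The warmup above ramps the learning rate linearly from base_lr to n_devices * base_lr over warmup_iter iterations, the usual compensation for the n_devices-fold larger effective batch. A sketch with made-up numbers:

base_lr, n_devices, warmup_iter = 0.001, 4, 1000
warmup_slope = base_lr * (n_devices - 1) / warmup_iter
for i in (0, 500, 1000):
    lr = base_lr + warmup_slope * i
    print(i, lr)  # 0.001 -> 0.0025 -> 0.004 (= n_devices * base_lr)
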
Example #10
def CNN_run(args, ops, alphas_dict):
    """
        Based on the given model architecture,
        construct CNN and execute training.
        input:
            args: arguments set by user.
            ops: operations used in the network.
            arch_dict: a dictionary containing architecture information.
    """

    data_iterator = data_iterator_cifar10
    all_data = data_iterator(args.batch_size, True)
    tdata = all_data.slice(rng=None, slice_start=0, slice_end=25000)
    vdata = all_data.slice(rng=None, slice_start=25000, slice_end=50000)

    # CIFAR10 statistics, mean and variance
    CIFAR_MEAN = np.reshape([0.49139968, 0.48215827, 0.44653124], (1, 3, 1, 1))
    CIFAR_STD = np.reshape([0.24703233, 0.24348505, 0.26158768], (1, 3, 1, 1))

    channels, image_height, image_width = 3, 32, 32
    batch_size = args.batch_size
    initial_model_lr = args.model_lr

    one_epoch = tdata.size // batch_size
    max_iter = args.epoch * one_epoch

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=100)
    monitor_err = MonitorSeries("Training error", monitor, interval=100)
    monitor_vloss = MonitorSeries("Validation loss", monitor, interval=100)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=100)

    # prepare variables and graph used for training
    image_train = nn.Variable(
        (batch_size, channels, image_height, image_width))
    label_train = nn.Variable((batch_size, 1))
    input_image_train = {"image": image_train, "label": label_train}
    pred_train = construct_networks(args, ops, image_train, test=False)
    loss_train = loss_function(pred_train, label_train)

    # prepare solvers for model parameters
    model_params_dict = \
        {k: v for k, v in nn.get_parameters().items() if "alpha_" not in k}
    solver_model = S.Momentum(initial_model_lr)
    solver_model.set_parameters(
        {
            k: v
            for k, v in nn.get_parameters().items()
            if k in model_params_dict.keys()
        },
        reset=False,
        retain_state=True)

    # prepare solvers for architecture parameters
    solver_archs = S.Adam(alpha=args.arch_lr, beta1=0.5, beta2=0.999)
    solver_archs.set_parameters(
        {
            k: v
            for k, v in nn.get_parameters().items() if k in alphas_dict.keys()
        },
        reset=False,
        retain_state=True)

    # Training-loop
    for i in range(max_iter):

        # Update Model Parameters.

        if args.second_order:
            # store the weights before update.
            original_weights = {
                k: nn.Variable(v.shape, need_grad=True).apply(
                    data=nn.NdArray(v.shape).copy_from(v.data))
                for k, v in nn.get_parameters().items() if "alpha_" not in k
            }

            # gradients refuge
            accumulated_gradient = \
                {k: nn.Variable(v.shape).apply(d=0)
                 for k, v in alphas_dict.items()}

        image, label = tdata.next()
        image = image / 255.0
        image = (image - CIFAR_MEAN) / CIFAR_STD
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()

        e = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)

        if args.lr_control_model:
            new_lr = learning_rate_scheduler(i, max_iter, initial_model_lr, 0)
            solver_model.set_learning_rate(new_lr)

        solver_model.zero_grad()
        loss_train.backward(clear_buffer=True)

        if args.with_grad_clip_model:
            for k, v in model_params_dict.items():
                v.grad.copy_from(
                    F.clip_by_norm(v.grad, args.grad_clip_value_model))

        solver_model.weight_decay(args.weight_decay_model)
        solver_model.update()  # weights update ( w -> w')

        if args.second_order:
            updated_weights = {
                k: nn.Variable(v.shape, need_grad=True).apply(
                    data=nn.NdArray(v.shape).copy_from(v.data))
                for k, v in nn.get_parameters().items() if "alpha_" not in k
            }

        # Update Architecture Parameters.

        ve, vloss = 0., 0.
        v_image, v_label = vdata.next()
        v_image = v_image / 255.0
        v_image = (v_image - CIFAR_MEAN) / CIFAR_STD
        input_image_train["image"].d = v_image
        input_image_train["label"].d = v_label
        # compute Loss_on_valid(w', alpha)
        loss_train.forward(clear_no_need_grad=True)

        ve = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_vloss.add(i, loss_train.d.copy())
        monitor_verr.add(i, ve)

        solver_archs.zero_grad()
        solver_model.zero_grad()
        loss_train.backward(clear_buffer=True)  # its gradient is stored

        if args.second_order:
            accumulated_gradient = store_gradient(accumulated_gradient,
                                                  alphas_dict,
                                                  coeff=1.)

            # grad_alpha_L_val(w', alpha).  Note that gradient stored into .data
            delta_gradient_w = {
                k: nn.Variable(v.shape).apply(data=nn.NdArray(
                    v.shape).copy_from(v.grad),
                                              need_grad=True)
                for k, v in nn.get_parameters().items() if "alpha_" not in k
            }

            epsilon = 0.01 / np.sum(
                [np.linalg.norm(v.d) for v in delta_gradient_w.values()])

            coeff = 1.0 * epsilon
            # w -> w+ (= w + epsilon*grad_Loss_on_val(w', alpha))
            weight_modify(original_weights, delta_gradient_w,
                          model_params_dict, coeff)

            input_image_train["image"].d = image  # reuse the same data
            input_image_train["label"].d = label

            # compute Loss_on_train(w+, alpha)
            loss_train.forward()
            solver_archs.zero_grad()
            solver_model.zero_grad()
            loss_train.backward(clear_buffer=True)  # its gradient is stored

            # accumulate currently registered gradient
            coeff = (-1.) * args.eta / 2. * epsilon
            accumulated_gradient = store_gradient(accumulated_gradient,
                                                  alphas_dict, coeff)

            coeff = -1.0 * epsilon
            # w -> w- (= w - epsilon*grad_Loss_on_val(w', alpha))
            weight_modify(original_weights, delta_gradient_w,
                          model_params_dict, coeff)

            # compute Loss_on_train(w-, alpha)
            loss_train.forward()
            solver_archs.zero_grad()
            solver_model.zero_grad()
            loss_train.backward(clear_buffer=True)  # its gradient is stored

            # accumulate currently registered gradient again
            coeff = (+1.) * args.eta / 2. * epsilon
            accumulated_gradient = store_gradient(accumulated_gradient,
                                                  alphas_dict, coeff)

            # replace the weights
            for k, v in alphas_dict.items():
                nn.parameter.set_parameter(
                    k,
                    nn.Variable(v.shape).apply(data=v.data,
                                               grad=accumulated_gradient[k],
                                               need_grad=True))
            for k, v in model_params_dict.items():
                nn.parameter.set_parameter(
                    k,
                    nn.Variable(v.shape).apply(data=updated_weights[k].data,
                                               need_grad=True))

        solver_archs.weight_decay(args.weight_decay_archs)
        solver_archs.update()

        if i % 1000 == 0:
            for k, v in alphas_dict.items():
                keynames = k.split("_")
                print("\nParameters for {} cell, node {} to {};".format(
                    keynames[1], keynames[2], keynames[3]))
                show_ops_and_prob(v.d, ops)

    return alphas_dict
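
A note on the second-order branch above: it follows the DARTS scheme, estimating the second-order term of the architecture gradient by finite differences. The weights are perturbed to w+ and w- (w plus or minus eps * grad_w L_val(w', alpha), with eps = 0.01 / ||grad_w L_val(w', alpha)||), and grad_alpha L_train(w+, alpha) and grad_alpha L_train(w-, alpha) are accumulated with opposite signs via the three store_gradient calls, approximating grad_alpha L_val(w', alpha) - eta * (grad_alpha L_train(w+, alpha) - grad_alpha L_train(w-, alpha)) / (2 * eps), up to the coefficient convention used by this implementation.
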
Example #11
def main(opt):
    '''
    NNabla configuration
    '''
    # os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    type_config = 'half' if opt.mixed_precision else 'float'
    comm = init_nnabla(ext_name=opt.extension_module,
                       device_id='0', type_config=type_config)
    nn.set_auto_forward(True)
    output_folder = os.path.join(
        opt.save_dir, "tmp.monitor.{}_{}".format(opt.arch, opt.num_layers))
    monitor = Monitor(output_folder)
    monitor_loss = None
    monitor_hm_loss = None
    monitor_wh_loss = None
    monitor_off_loss = None
    monitor_acc = None
    monitor_val_loss = None
    monitor_val_hm_loss = None
    monitor_val_wh_loss = None
    monitor_val_off_loss = None
    monitor_map = None
    monitor_time = None
    Detector = detector_factory[opt.task]
    detector = Detector(opt)

    interval = 1
    if comm.rank == 0:
        monitor_loss = MonitorSeries(
            "Training Loss", monitor, interval=interval, verbose=False)
        monitor_hm_loss = MonitorSeries(
            "hm_loss", monitor, interval=interval, verbose=False)
        monitor_wh_loss = MonitorSeries(
            "wh_loss", monitor, interval=interval, verbose=False)
        monitor_off_loss = MonitorSeries(
            "off_loss", monitor, interval=interval, verbose=False)
        monitor_val_loss = MonitorSeries(
            "Validation Loss", monitor, interval=interval, verbose=False)
        monitor_val_hm_loss = MonitorSeries(
            "val_hm_loss", monitor, interval=interval, verbose=False)
        monitor_val_wh_loss = MonitorSeries(
            "val_wh_loss", monitor, interval=interval, verbose=False)
        monitor_val_off_loss = MonitorSeries(
            "val_off_loss", monitor, interval=interval, verbose=False)
        monitor_map = MonitorSeries(
            "Val mAP", monitor, interval=interval, verbose=False)
        monitor_time = MonitorTimeElapsed(
            "time", monitor, interval=1, verbose=False)
    '''
    Data Iterators
    '''
    seed = opt.seed
    rng = np.random.RandomState(seed)
    source_factory = get_data_source(opt.dataset)
    train_source = source_factory(opt, 'train', shuffle=True, rng=rng,
                                  mixed_precision=opt.mixed_precision, channel_last=opt.channel_last)
    train_loader = data_iterator(train_source,
                                 opt.batch_size,
                                 with_memory_cache=False,
                                 with_file_cache=False
                                 )
    train_loader = train_loader.slice(rng, comm.n_procs, slice_pos=comm.rank)
    val_source = source_factory(opt, 'val', shuffle=False, rng=rng,
                                mixed_precision=opt.mixed_precision, channel_last=opt.channel_last)
    val_loader = data_iterator(val_source,
                               opt.batch_size,
                               with_memory_cache=False,
                               with_file_cache=False
                               )
    logger.info('Creating model...')
    logger.info(opt.heads)
    logger.info(f"batch size per gpu: {opt.batch_size}")
    model = create_model(opt.arch, opt.heads, opt.head_conv, opt.num_layers, training=True,
                         channel_last=opt.channel_last, pretrained_model_dir=opt.pretrained_model_dir)
    if opt.checkpoint != '':
        load_model(model, opt.checkpoint, clear=True)

    start_epoch = 0
    loss_func = CtdetLoss(opt)
    lr_sched = create_learning_rate_scheduler(
        opt.train_config.learning_rate_config)
    solver = S.Adam(alpha=lr_sched.get_lr())
    trainer = Trainer(
                model, loss_func, solver, train_loader, train_source, [
                    monitor_loss, monitor_hm_loss, monitor_wh_loss, monitor_off_loss, monitor_val_loss, monitor_val_hm_loss, monitor_val_wh_loss, monitor_val_off_loss], opt, comm)

    root_dir = opt.save_dir

    checkpoint_dir = os.path.join(root_dir, output_folder, 'checkpoints')
    start_epoch = 0
    if opt.resume_from is not None:
        start_epoch = trainer.load_checkpoint(checkpoint_dir, opt.resume_from)
        logger.info('resuming from the epoch {}'.format(start_epoch))

    for epoch in range(start_epoch, opt.num_epochs):
        lr_sched.set_epoch(epoch)
        trainer.solver.set_learning_rate(lr_sched.get_lr())
        iteration = trainer.update(epoch)
        if comm.rank == 0:
            if epoch % opt.save_intervals == 0 or epoch == (opt.num_epochs-1):
                monitor_time.add(epoch)
                trainer.save_checkpoint(checkpoint_dir, epoch)

        if epoch % opt.val_intervals == 0 or epoch == (opt.num_epochs-1):
            model.training = False
            trainer.evaluate(val_loader, epoch)
            if not opt.val_calc_map:
                num_iters = val_loader.size
                pbar = trange(num_iters, desc="[Test][exp_id:{} epoch:{}/{}]".format(
                    opt.exp_id, epoch, opt.num_epochs), disable=comm.rank > 0)
                if comm.rank == 0:
                    results = {}
                    for ind in pbar:
                        img_id = val_source.images[ind]
                        img_info = val_source.coco.loadImgs(ids=[img_id])[0]
                        img_path = os.path.join(
                            val_source.img_dir, img_info['file_name'])
                        with nn.context_scope(comm.ctx_float):
                            ret = detector.run(img_path)
                        results[img_id] = ret['results']
                    val_map = val_source.run_eval(
                        results, opt.save_dir, opt.data_dir)
                    monitor_map.add(epoch, val_map)
            model.training = True
            if comm.n_procs > 1:
                # Required to prevent timeout error of allreduce
                # at the first iteration of the next epoch.
                comm.comm.barrier()
Example #12
def main():

    random.seed(args.seed)
    np.random.seed(args.seed)

    # Prepare for CUDA.
    ctx = get_extension_context('cudnn', device_id=args.gpus)
    nn.set_default_context(ctx)

    start_full_time = time.time()
    from iterator import data_iterator

    # Data list for sceneflow data set
    train_list = "./dataset/sceneflow_train.csv"
    test_list = "./dataset/sceneflow_test.csv"
    train = True
    validation = True

    # Set monitor path.
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))

    img_left, img_right, disp_img = read_csv(train_list)
    img_left_test, img_right_test, disp_img_test = read_csv(test_list)
    train_samples = len(img_left)
    test_samples = len(img_left_test)
    train_size = int(len(img_left) / args.batchsize_train)
    test_size = int(len(img_left_test) / args.batchsize_test)

    # Create data iterator.
    data_iterator_train = data_iterator(
        train_samples, args.batchsize_train, img_left, img_right, disp_img, train=True, shuffle=True, dataset=args.dataset)
    data_iterator_test = data_iterator(
        test_samples, args.batchsize_test, img_left_test, img_right_test, disp_img_test, train=False, shuffle=False, dataset=args.dataset)

    # Set data size

    print(train_size, test_size)

    # Define data shape for training.
    var_left = nn.Variable(
        (args.batchsize_train, 3, args.crop_height, args.crop_width))
    var_right = nn.Variable(
        (args.batchsize_train, 3, args.crop_height, args.crop_width))
    var_disp = nn.Variable(
        (args.batchsize_train, 1, args.crop_height, args.crop_width))
    # Define data shape for testing.
    var_left_test = nn.Variable(
        (args.batchsize_test, 3, args.im_height, args.im_width))
    var_right_test = nn.Variable(
        (args.batchsize_test, 3, args.im_height, args.im_width))
    var_disp_test = nn.Variable(
        (args.batchsize_test, 1, args.im_height, args.im_width))

    if args.loadmodel is not None:
        # Loading CNN pretrained parameters.
        nn.load_parameters(args.loadmodel)

    # === for Training ===
    # Definition of pred
    pred1, pred2, pred3 = psm_net(var_left, var_right, args.maxdisp, True)
    mask_train = F.less_scalar(var_disp, args.maxdisp)
    sum_mask = F.maximum_scalar(F.sum(mask_train), 1)
    # Definition of loss
    loss = 0.5 * (0.5 * F.sum(F.huber_loss(pred1, var_disp)*mask_train)/(sum_mask) + 0.7 * F.sum(F.huber_loss(
        pred2, var_disp)*mask_train)/(sum_mask) + F.sum(F.huber_loss(pred3, var_disp)*mask_train)/(sum_mask))

    # === for Testing ===
    # Definition of pred
    mask_test = F.less_scalar(var_disp_test, args.maxdisp)
    sum_mask_test = F.maximum_scalar(F.sum(mask_test), 1)
    pred_test = psm_net(var_left_test, var_right_test, args.maxdisp, False)
    test_loss = F.sum(F.abs(pred_test - var_disp_test)*mask_test)/sum_mask_test

    # Prepare monitors.
    monitor = Monitor(monitor_path)
    monitor_train = MonitorSeries('Training loss', monitor, interval=1)
    monitor_test = MonitorSeries('Validation loss', monitor, interval=1)
    monitor_time_train = MonitorTimeElapsed(
        "Training time/epoch", monitor, interval=1)

    # Create a solver (parameter updater)
    solver = S.Adam(alpha=0.001, beta1=0.9, beta2=0.999)

    # Set parameters (one call with grad_only=False covers both the
    # trainable and the non-trainable parameters).
    params = nn.get_parameters(grad_only=False)
    solver.set_parameters(params)

    for epoch in range(1, args.epochs+1):
        print('This is %d-th epoch' % (epoch))

        if validation:
            ## testing ##
            total_test_loss = 0

            index_test = 0
            while index_test < test_size:
                var_left_test.d, var_right_test.d, var_disp_test.d = data_iterator_test.next()
                test_loss.forward(clear_no_need_grad=True)
                # Accumulate the scalar value; accumulating the Variable
                # itself would silently build an ever-growing graph.
                total_test_loss += float(test_loss.d)

                print('Iter %d test loss = %.3f' % (index_test, test_loss.d))
                index_test += 1
            test_error = total_test_loss / test_size
            print('epoch %d total 3-px error in val = %.3f' %
                  (epoch, test_error))
            # Pass validation loss to a monitor.
            monitor_test.add(epoch, test_error)

        if train:
            ## training ##
            total_train_loss = 0
            index = 0

            while index < train_size:

                # Get mini batch
                # Preprocess
                var_left.d, var_right.d, var_disp.d = data_iterator_train.next()
                loss.forward(clear_no_need_grad=True)
                # Initialize gradients
                solver.zero_grad()
                # Backward execution
                loss.backward(clear_buffer=True)
                # Update parameters by computed gradients
                solver.update()
                print('Iter %d training loss = %.3f' %
                      (index, loss.d))
                total_train_loss += loss.d
                index += 1
            train_error = total_train_loss/train_size
            monitor_time_train.add(epoch)
            print('epoch %d total training loss = %.3f' % (epoch, train_error))

            # Pass training loss to a monitor.
            monitor_train.add(epoch, train_error)
            print('full training time = %.2f HR' %
                  ((time.time() - start_full_time)/3600))

            # Save Parameter
            out_param_file = os.path.join(
                args.savemodel, 'psmnet_trained_param_' + str(epoch) + '.h5')
            nn.save_parameters(out_param_file)
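
The losses above average per-pixel errors only over valid pixels (ground-truth disparity below maxdisp), with the denominator clamped to at least 1 so an all-invalid batch cannot divide by zero. A minimal standalone sketch of the same masking pattern (shapes and the 192.0 threshold are illustrative):

import nnabla as nn
import nnabla.functions as F

pred = nn.Variable((2, 1, 8, 8))
gt = nn.Variable((2, 1, 8, 8))
mask = F.less_scalar(gt, 192.0)           # 1 where disparity < maxdisp
denom = F.maximum_scalar(F.sum(mask), 1)  # guard against division by zero
masked_loss = F.sum(F.huber_loss(pred, gt) * mask) / denom
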
Example #13
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
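
utils.EarlyStopping above comes from this repository; a sketch of the best-value-plus-patience semantics it appears to implement (an assumption, not the repo's exact code; note the training loop reads es.best after each step):

class EarlyStopping:
    def __init__(self, patience=10):
        self.patience, self.best, self.num_bad = patience, float("inf"), 0

    def step(self, value):
        # Returns True once `value` has failed to improve for `patience` steps.
        if value < self.best:
            self.best, self.num_bad = value, 0
        else:
            self.num_bad += 1
        return self.num_bad > self.patience
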
Example #14
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)
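
categorical_error above is a small helper shared by these examples; a sketch of the standard definition it corresponds to (an assumption about the exact code):

import numpy as np

def categorical_error(pred, label):
    # pred: (B, num_classes) scores; label: (B, 1) integer class ids.
    pred_label = pred.argmax(axis=1)
    return (pred_label != label.flat).mean()
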
Example #15
def train(args):
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must equal the number of selected attributes; "
              "overriding c_dim with len(selected_attrs).")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
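            # x_fake_unlinked shares its gradient buffer with x_fake, so the
            # gradient left on it by g_loss.backward() can be propagated the
            # rest of the way into the generator via x_fake.backward(grad=None).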
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Decay learning rates linearly over the second half of training.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
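
The gradient penalty used in this example's discriminator loss is the WGAN-GP term: it pushes the L2 norm of the critic's gradient at random interpolates between real and fake images toward 1. A minimal, self-contained sketch of the same pattern (the toy critic and shapes are illustrative, not from this example):

import numpy as np
import nnabla as nn
import nnabla.functions as F


def toy_critic(x):
    # Stand-in for a real discriminator: one scalar score per sample.
    return F.mean(x, axis=(1, 2, 3))


x_real = nn.Variable((4, 3, 8, 8), need_grad=True)
x_fake = nn.Variable((4, 3, 8, 8), need_grad=True)
alpha = F.rand(shape=(4, 1, 1, 1))
x_hat = F.mul2(alpha, x_real) + F.mul2(F.r_sub_scalar(alpha, 1), x_fake)
grads = nn.grad([toy_critic(x_hat)], [x_hat])
l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

x_real.d = np.random.randn(*x_real.shape)
x_fake.d = np.random.randn(*x_fake.shape)
d_loss_gp.forward()
print(d_loss_gp.d)  # penalty value for this random batch
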
Example no. 16
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop.
      * Zero the parameter gradients.
      * Execute backprop.
      * The solver updates parameters using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(bs_valid, False)

    # Training-loop
    for i in range(args.max_iter):
        # Validation
        if i % int(n_train_samples / args.batch_size) == 0:
            ve = 0.
            for j in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i, ve)
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Solvers update
        solver.update()

        e = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % (args.max_iter)))
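
The helper categorical_error called in the loop above is not defined in this listing; a minimal version consistent with how it is called (and with the convention used across these nnabla examples) would be:

import numpy as np


def categorical_error(pred, label):
    # Ratio of misclassified samples.
    # pred: (batch, n_class) raw scores; label: (batch, 1) integer labels.
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()
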
def meta_train(args, shape_x, train_data, valid_data, test_data):

    # Build episode generators
    train_episode_generator = EpisodeGenerator(
        train_data[0], train_data[1], args.n_class_tr, args.n_shot_tr, args.n_query_tr)
    valid_episode_generator = EpisodeGenerator(
        valid_data[0], valid_data[1], args.n_class, args.n_shot, args.n_query)
    test_episode_generator = EpisodeGenerator(
        test_data[0], test_data[1], args.n_class, args.n_shot, args.n_query)

    # Build training model
    xs_t = nn.Variable((args.n_class_tr * args.n_shot_tr, ) + shape_x)
    xq_t = nn.Variable((args.n_class_tr * args.n_query_tr, ) + shape_x)
    hq_t = net(args.n_class_tr, xs_t, xq_t, args.embedding,
               args.net_type, args.metric, False)
    yq_t = nn.Variable((args.n_class_tr * args.n_query_tr, 1))
    loss_t = F.mean(F.softmax_cross_entropy(hq_t, yq_t))

    # Build evaluation model
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding,
               args.net_type, args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Setup solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor outputs
    monitor = Monitor(args.work_dir)
    monitor_loss = MonitorSeries(
        "Training loss", monitor, interval=args.iter_per_epoch)
    monitor_valid_err = MonitorSeries(
        "Validation error", monitor, interval=args.iter_per_valid)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)

    # Output files
    param_file = args.work_dir + "/params.h5"
    tsne_file = args.work_dir + "/tsne.png"

    # Save NNP
    batch_size = 1
    contents = save_nnp({'x0': xs_v, 'x1': xq_v}, {
                          'y': hq_v}, batch_size)
    save.save(os.path.join(args.work_dir,
                           'MetricMetaLearning_epoch0.nnp'), contents, variable_batch_size=False)

    # Training loop
    train_losses = []
    best_err = 1.0
    for i in range(args.max_iteration):

        # Decay learning rate
        if (i + 1) % args.lr_decay_interval == 0:
            solver.set_learning_rate(solver.learning_rate() * args.lr_decay)

        # Create an episode
        xs_t.d, xq_t.d, yq_t.d = train_episode_generator.next()

        # Training by the episode
        solver.zero_grad()
        loss_t.forward(clear_no_need_grad=True)
        loss_t.backward(clear_buffer=True)
        solver.update()
        train_losses.append(loss_t.d.copy())

        # Evaluation
        if (i + 1) % args.iter_per_valid == 0:
            train_loss = np.mean(train_losses)
            train_losses = []
            valid_errs = []
            for k in range(args.n_episode_for_valid):
                xs_v.d, xq_v.d, yq_v.d = valid_episode_generator.next()
                err_v.forward(clear_no_need_grad=True, clear_buffer=True)
                valid_errs.append(float(err_v.d.copy()))
            valid_err = np.mean(valid_errs)

            monitor_loss.add(i + 1, train_loss)
            monitor_valid_err.add(i + 1, valid_err * 100)
            if valid_err < best_err:
                best_err = valid_err
                nn.save_parameters(param_file)

    # Final evaluation
    nn.load_parameters(param_file)
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    # Visualization
    n_class = 50
    n_sample = 20
    visualize_episode_generator = EpisodeGenerator(
        train_data[0], train_data[1], n_class, 0, n_sample)
    _, samples, labels = visualize_episode_generator.next()
    u = get_embeddings(samples, conv4)
    v = get_tsne(u)
    plot_tsne(v[:, 0], v[:, 1], labels[:, 0], tsne_file)

    # Save NNP
    contents = save_nnp({'x0': xs_v, 'x1': xq_v}, {
                          'y': hq_v}, batch_size)
    save.save(os.path.join(args.work_dir,
                           'MetricMetaLearning.nnp'), contents, variable_batch_size=False)
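
net(...) above is assumed to embed the support and query sets and to score each query by its distance to the per-class mean (prototype) of the support embeddings, as in prototypical networks. A NumPy sketch of that scoring step, under those assumptions:

import numpy as np


def proto_scores(support, query, n_class, n_shot):
    # support: (n_class * n_shot, dim) embeddings grouped by class;
    # query: (n_query, dim) embeddings. Returns (n_query, n_class) scores.
    protos = support.reshape(n_class, n_shot, -1).mean(axis=1)
    d2 = ((query[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return -d2  # higher score = closer prototype


support = np.random.randn(5 * 5, 64)  # 5-way, 5-shot embeddings
query = np.random.randn(15, 64)
print(proto_scores(support, query, 5, 5).shape)  # (15, 5)
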
Example no. 18
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop on the training graph.
      * Compute the training error.
      * Zero the parameter gradients.
      * Execute backprop.
      * The solver updates parameters using the gradients computed by backprop.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)
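
The parameters saved above can later be restored with nn.load_parameters for inference or fine-tuning. A minimal round-trip sketch (the scope and file names are illustrative):

import nnabla as nn
import nnabla.parametric_functions as PF

with nn.parameter_scope("demo"):
    x = nn.Variable((1, 8))
    y = PF.affine(x, 4)  # registers demo/affine/W and demo/affine/b

nn.save_parameters("/tmp/demo_params.h5")
nn.clear_parameters()
nn.load_parameters("/tmp/demo_params.h5")
print(list(nn.get_parameters().keys()))  # the restored parameter names
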
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a training computation graph per device and one graph for validation.
    * Initialize solvers and set parameter variables to them.
    * Instantiate a communicator and set parameter variables.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop on every device.
      * Zero the parameter gradients.
      * Execute backprop.
      * Allreduce the gradients across devices (THIS IS THE MAIN difference from single-device training).
      * Each solver updates parameters using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Create contexts
    extension_module = args.context
    if extension_module != "cuda" and \
            extension_module != "cuda.cudnn":
        raise Exception("Use `cuda` or `cuda.cudnn` extension_module.")
    n_devices = args.n_devices
    ctxs = []
    for i in range(n_devices):
        ctx = extension_context(extension_module, device_id=i)
        ctxs.append(ctx)
    ctx = ctxs[-1]

    # Create training graphs
    input_image_train = []
    preds_train = []
    losses_train = []
    test = False
    for i in range(n_devices):
        image = nn.Variable((args.batch_size, 3, 32, 32))
        label = nn.Variable((args.batch_size, 1))
        device_scope_name = "device{}".format(i)

        pred = cifar10_resnet23_prediction(image, ctxs[i], device_scope_name,
                                           test)
        loss = cifar10_resnet32_loss(pred, label)

        input_image_train.append({"image": image, "label": label})
        preds_train.append(pred)
        losses_train.append(loss)

    # Create validation graph
    test = True
    device_scope_name = "device{}".format(0)
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar10_resnet23_prediction(image_valid, ctxs[0],
                                             device_scope_name, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solvers = []
    for i in range(n_devices):
        with nn.context_scope(ctxs[i]):
            solver = S.Adam()
            device_scope_name = "device{}".format(i)
            with nn.parameter_scope(device_scope_name):
                params = nn.get_parameters()
                solver.set_parameters(params)
            solvers.append(solver)

    # Communicator
    comm = C.DataParalellCommunicator(ctx)
    for i in range(n_devices):
        device_scope_name = "device{}".format(i)
        with nn.parameter_scope(device_scope_name):
            ctx = ctxs[i]
            params = nn.get_parameters()
            comm.add_context_and_parameters((ctx, params))
    comm.init()

    # Create threadpools with one thread
    pools = []
    for _ in range(n_devices):
        pool = ThreadPool(processes=1)
        pools.append(pool)

    # Once forward/backward to safely secure memory
    for device_id in range(n_devices):
        data, label = \
            (np.random.randn(*input_image_train[device_id]["image"].shape),
             (np.random.rand(*input_image_train[device_id]["label"].shape) * 10).astype(np.int32))

        ret = pools[device_id].apply_async(
            forward_backward, (input_image_train[device_id]["image"], data,
                               input_image_train[device_id]["label"], label,
                               losses_train[device_id], solvers[device_id]))
        ret.get()
        losses_train[device_id].d  # sync to host

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    tdata = data_iterator_cifar10(args.batch_size, True, rng)
    vdata = data_iterator_cifar10(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve = 0.
            for j in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i * n_devices, ve)
        if i % int(args.model_save_interval / n_devices) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forwards/Zerograd/Backwards
        fb_results = []
        for device_id in range(n_devices):
            image, label = tdata.next()

            res = pools[device_id].apply_async(
                forward_backward,
                (input_image_train[device_id]["image"], image,
                 input_image_train[device_id]["label"], label,
                 losses_train[device_id], solvers[device_id]))
            fb_results.append(res)
        for device_id in range(n_devices):
            fb_results[device_id].get()

        # Allreduce over gradients (divided by the number of devices;
        # note this call is not in-place).
        comm.allreduce(division=True, inplace=False)

        # Solvers update
        for device_id in range(n_devices):
            solvers[device_id].update()

        e = categorical_error(preds_train[-1].d,
                              input_image_train[-1]["label"].d)
        monitor_loss.add(i * n_devices, losses_train[-1].d.copy())
        monitor_err.add(i * n_devices, e)
        monitor_time.add(i * n_devices)

    nn.save_parameters(
        os.path.join(args.model_save_path,
                     'params_%06d.h5' % (args.max_iter / n_devices)))
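
The forward_backward worker dispatched to the thread pools above is not shown in this listing; a minimal version consistent with its call sites (an assumption about the original helper) would be:

def forward_backward(x, x_data, label, label_data, loss, solver):
    # Feed one device's minibatch, then run forward, zero-grad and backward.
    x.d = x_data
    label.d = label_data
    loss.forward(clear_no_need_grad=True)
    solver.zero_grad()
    loss.backward(clear_buffer=True)
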
Example no. 20
class Train(BaseExecution):
    """
    Execution model for StyleGAN training and testing
    """
    def __init__(self, monitor, config, args, comm, few_shot_config):
        super(Train, self).__init__(monitor, config, args, comm,
                                    few_shot_config)

        # Initialize Monitor
        self.monitor_train_gen_loss, self.monitor_train_gen = None, None
        self.monitor_train_disc_loss = None
        self.monitor_val_loss, self.monitor_val_gen = None, None
        if comm is not None:
            if comm.rank == 0:
                self.monitor_train_gen_loss = MonitorSeries(
                    config['monitor']['train_loss'],
                    monitor,
                    interval=self.config['logger_step_interval'])
                self.monitor_train_gen = MonitorImageTile(
                    config['monitor']['train_gen'],
                    monitor,
                    interval=self.config['logger_step_interval'],
                    num_images=self.config['batch_size'])
                self.monitor_train_disc_loss = MonitorSeries(
                    config['monitor']['train_loss'],
                    monitor,
                    interval=self.config['logger_step_interval'])

        os.makedirs(self.config['saved_weights_dir'], exist_ok=True)
        self.results_dir = args.results_dir
        self.save_weights_dir = args.weights_path

        self.few_shot_config = few_shot_config

        # Initialize Discriminator
        self.discriminator = Discriminator(config['discriminator'],
                                           self.img_size)
        self.gen_exp_weight = 0.5**(32 / (10 * 1000))
        self.generator_ema = Generator(config['generator'],
                                       self.img_size,
                                       config['train']['mix_after'],
                                       global_scope='GeneratorEMA')

        # Initialize Solver
        if 'gen_solver' not in dir(self):
            if self.config['solver'] == 'Adam':
                self.gen_solver = S.Adam(beta1=0, beta2=0.99)
                self.disc_solver = S.Adam(beta1=0, beta2=0.99)
            else:
                self.gen_solver = eval('S.' + self.config['solver'])()
                self.disc_solver = eval('S.' + self.config['solver'])()

        self.gen_solver.set_learning_rate(self.config['learning_rate'])
        self.disc_solver.set_learning_rate(self.config['learning_rate'])

        self.gen_mean_path_length = 0.0

        self.dali = args.dali
        self.args = args
        # Initialize Dataloader
        if args.data == 'ffhq':
            if args.dali:
                self.train_loader = get_dali_iterator_ffhq(
                    args.dataset_path, config['data'], self.img_size,
                    self.batch_size, self.comm)
            else:
                self.train_loader = get_data_iterator_ffhq(
                    args.dataset_path, config['data'], self.batch_size,
                    self.img_size, self.comm)
        else:
            print('Dataset not recognized')
            exit(1)

        # Start training
        self.train()

    def build_static_graph(self):
        real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                      self.img_size))
        noises = [
            F.randn(shape=(self.batch_size, self.config['latent_dim']))
            for _ in range(2)
        ]

        if self.config['regularize_gen']:
            fake_img, dlatents = self.generator(self.batch_size,
                                                noises,
                                                return_latent=True)
        else:
            fake_img = self.generator(self.batch_size, noises)
        fake_img_test = self.generator_ema(self.batch_size, noises)

        fake_disc_out = self.discriminator(fake_img)
        real_disc_out = self.discriminator(real_img)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

        gen_loss = gen_nonsaturating_loss(self.discriminator(fake_img))

        if self.few_shot_config['common']['type'] == 'ewc':
            EWC_loss_function = ElasticWeightConsolidation(
                _y=fake_img,
                _out_FI_path=self.few_shot_config['ewc']['weights'])
            gen_loss += self.few_shot_config['ewc']['lambda'] * \
                EWC_loss_function()

        var_name_list = [
            'real_img', 'noises', 'fake_img', 'gen_loss', 'disc_loss',
            'fake_disc_out', 'real_disc_out', 'fake_img_test'
        ]
        var_list = [
            real_img, noises, fake_img, gen_loss, disc_loss, fake_disc_out,
            real_disc_out, fake_img_test
        ]

        if self.config['regularize_gen']:
            dlatents.need_grad = True
            mean_path_length = nn.Variable()
            pl_reg, path_mean, _ = gen_path_regularize(
                fake_img=fake_img,
                latents=dlatents,
                mean_path_length=mean_path_length)
            path_mean_update = F.assign(mean_path_length, path_mean)
            path_mean_update.name = 'path_mean_update'
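            # Multiplying by zero keeps the assign op in the graph without
            # changing the loss value; forwarding the regularized loss then
            # refreshes mean_path_length as a side effect.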
            pl_reg += 0 * path_mean_update
            gen_loss_reg = gen_loss + pl_reg
            var_name_list.append('gen_loss_reg')
            var_list.append(gen_loss_reg)

        if self.config['regularize_disc']:
            real_img.need_grad = True
            real_disc_out = self.discriminator(real_img)
            disc_loss_reg = disc_loss + self.config[
                'r1_coeff'] * 0.5 * disc_r1_loss(
                    real_disc_out, real_img) * self.config['disc_reg_step']
            real_img.need_grad = False
            var_name_list.append('disc_loss_reg')
            var_list.append(disc_loss_reg)

        Parameters = namedtuple('Parameters', var_name_list)
        self.parameters = Parameters(*var_list)

    def forward_backward_pass(self, i, real_img=None, noises=None):
        """1 training step: forward and backward propagation

        Args:
                real_img (nn.Variable): [description]
                noises (list of nn.Variable): [description]

        Returns:
                [type]: [description]
        """ """
		Graph construction for forward pass for training
		"""

        # Update Discriminator
        self.disc_solver.zero_grad()

        self.parameters.fake_img.need_grad = False
        self.parameters.fake_img.forward()

        if self.config[
                'regularize_disc'] and i % self.config['disc_reg_step'] == 0:
            self.parameters.disc_loss_reg.forward(clear_no_need_grad=True)
            self.parameters.disc_loss_reg.backward(clear_buffer=True)
        else:
            self.parameters.disc_loss.forward(clear_no_need_grad=True)
            self.parameters.disc_loss.backward(clear_buffer=True)
        if self.comm is not None:
            params = [
                x.grad for x in self.disc_solver.get_parameters().values()
            ]
            self.comm.all_reduce(params, division=False, inplace=True)

        self.disc_solver.update()

        # Update Generator

        self.gen_solver.zero_grad()
        self.parameters.fake_img.need_grad = True
        if self.config[
                'regularize_gen'] and i % self.config['gen_reg_step'] == 0:
            self.parameters.gen_loss_reg.forward(clear_no_need_grad=True)
            self.parameters.gen_loss_reg.backward(clear_buffer=True)
        else:
            self.parameters.gen_loss.forward(clear_no_need_grad=True)
            self.parameters.gen_loss.backward(clear_buffer=True)
        if self.comm is not None:
            params = [
                x.grad for x in self.gen_solver.get_parameters().values()
            ]
            self.comm.all_reduce(params, division=False, inplace=True)
        self.gen_solver.update()

    def ema_update(self):
        with nn.parameter_scope('Generator'):
            g_params = nn.get_parameters(grad_only=False)
        with nn.parameter_scope('GeneratorEMA'):
            g_ema_params = nn.get_parameters(grad_only=False)
        update_ema_list = []
        for name in g_ema_params.keys():
            params_ema_updated = self.gen_exp_weight * g_ema_params[name] + \
                (1.0 - self.gen_exp_weight) * g_params[name]
            update_ema_list.append(
                F.assign(g_ema_params[name], params_ema_updated))
        return F.sink(*update_ema_list)

    def copy_params(self, scope_from, scope_to):
        with nn.parameter_scope(scope_from):
            params_from = nn.get_parameters(grad_only=False)
        with nn.parameter_scope(scope_to):
            params_to = nn.get_parameters(grad_only=False)
        for name in params_to.keys():
            params_to[name].d = params_from[name].d

    def train(self):
        """
        Training loop: Runs forward_backward pass for the specified number of iterations and stores the generated images, model weigths and solver states
        """
        iterations_per_epoch = int(
            np.ceil(self.train_loader.size / self.train_loader.batch_size))

        self.build_static_graph()
        disc_params = {
            k: v
            for k, v in nn.get_parameters().items()
            if k.startswith('Discriminator')
        }
        self.disc_solver.set_parameters(disc_params)
        gen_params = {
            k: v
            for k, v in nn.get_parameters().items()
            if (k.startswith('Generator') and not k.startswith('GeneratorEMA'))
        }
        self.gen_solver.set_parameters(gen_params)
        if os.path.isfile(os.path.join(self.args.weights_path,
                                       'gen_solver.h5')):
            self.gen_solver.load_states(
                os.path.join(self.args.weights_path, 'gen_solver.h5'))
            self.disc_solver.load_states(
                os.path.join(self.args.weights_path, 'disc_solver.h5'))

        self.copy_params('Generator', 'GeneratorEMA')

        ema_updater = self.ema_update()

        for epoch in range(self.config['num_epochs']):
            pbar = trange(iterations_per_epoch,
                          desc='Epoch ' + str(epoch),
                          disable=self.comm.rank > 0)

            epoch_gen_loss, epoch_disc_loss = 0.0, 0.0
            print(
                f'Iterations per epoch: {iterations_per_epoch}, Number of processes: {self.comm.n_procs}, Data Loader size: {self.train_loader.size}'
            )

            for i in pbar:

                data = self.train_loader.next()

                if isinstance(data[0], nn.NdArray):
                    self.parameters.real_img.data = data[0]
                else:
                    self.parameters.real_img.d = data[0]

                self.forward_backward_pass(i)
                gen_loss = self.parameters.gen_loss
                disc_loss = self.parameters.disc_loss
                real_img = self.parameters.real_img
                fake_img = self.parameters.fake_img
                real_disc_out = self.parameters.real_disc_out
                fake_disc_out = self.parameters.fake_disc_out

                ema_updater.forward()

                epoch_gen_loss += gen_loss.d
                epoch_disc_loss += disc_loss.d

                pbar.set_description(
                    f'Gen Loss: {gen_loss.d}, Disc Loss: {disc_loss.d}')

                if np.isnan(gen_loss.d) or np.isnan(disc_loss.d):
                    for k, v in nn.get_parameters().items():
                        if v.d.max() < 1e-3 or np.any(np.isnan(v.d)):
                            print(k)

                if self.comm.rank == 0 and (
                        i == 30 and
                    (epoch % self.config['save_param_step_interval'] == 0
                     or epoch == self.config['num_epochs'] - 1)):
                    self.save_weights(self.save_weights_dir, epoch)
                    self.parameters.fake_img_test.forward(clear_buffer=True)
                    fake_img.forward(clear_buffer=True)
                    save_generations(
                        self.parameters.fake_img_test,
                        os.path.join(self.results_dir, f'fake_ema_{epoch}'))
                    save_generations(
                        real_img,
                        os.path.join(self.results_dir, f'real_{epoch}'))
                    save_generations(
                        fake_img,
                        os.path.join(self.results_dir, f'fake_{epoch}'))

            epoch_gen_loss /= iterations_per_epoch
            epoch_disc_loss /= iterations_per_epoch

            if self.comm is not None:
                if self.comm.rank == 0:
                    self.monitor_train_gen_loss.add(epoch, epoch_gen_loss)
                    self.monitor_train_disc_loss.add(epoch, epoch_disc_loss)
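
GeneratorEMA above tracks an exponential moving average of the generator weights; the decay 0.5 ** (32 / (10 * 1000)) ≈ 0.9978 per iteration corresponds to a half-life of 10k images at batch size 32. A NumPy sketch of the same update rule:

import numpy as np

decay = 0.5 ** (32 / (10 * 1000))  # ~0.99778 per step at batch size 32
w, w_ema = 1.0, 0.0
for step in range(1000):
    w_ema = decay * w_ema + (1.0 - decay) * w
print(decay, w_ema)  # w_ema converges toward w as steps accumulate
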
Example no. 21
def main():

    args = get_args()
    state_size = args.state_size
    batch_size = args.batch_size
    num_steps = args.num_steps
    num_layers = args.num_layers
    max_epoch = args.max_epoch
    max_norm = args.gradient_clipping_max_norm
    num_words = 10000
    lr = args.learning_rate

    train_data, val_data, test_data = get_data()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    from nnabla.monitor import Monitor, MonitorSeries
    monitor = Monitor(args.work_dir)
    monitor_perplexity = MonitorSeries(
        "Training perplexity", monitor, interval=10)
    monitor_vperplexity = MonitorSeries("Validation perplexity", monitor, interval=(
        len(val_data)//(num_steps*batch_size)))
    monitor_tperplexity = MonitorSeries(
        "Test perplexity", monitor, interval=(len(test_data)//(num_steps*1)))

    l1 = LSTMWrapper(batch_size, state_size)
    l2 = LSTMWrapper(batch_size, state_size)

    # train graph

    x = nn.Variable((batch_size, num_steps))
    t = nn.Variable((batch_size, num_steps))
    w = I.UniformInitializer((-0.1, 0.1))
    b = I.ConstantInitializer(1)
    loss = get_loss(l1, l2, x, t, w, b, num_words,
                    batch_size, state_size, True)
    l1.share_data()
    l2.share_data()

    # validation graph

    vx = nn.Variable((batch_size, num_steps))
    vt = nn.Variable((batch_size, num_steps))
    vloss = get_loss(l1, l2, vx, vt, w, b, num_words, batch_size, state_size)
    solver = S.Sgd(lr)
    solver.set_parameters(nn.get_parameters())

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    best_val = 10000
    for epoch in range(max_epoch):
        l1.reset_state()
        l2.reset_state()
        for i in range(len(train_data)//(num_steps*batch_size)):
            x.d, t.d = get_batch(train_data, i*num_steps,
                                 batch_size, num_steps)
            solver.zero_grad()
            loss.forward()
            loss.backward(clear_buffer=True)
            solver.weight_decay(1e-5)
            gradient_clipping(nn.get_parameters().values(), max_norm)
            solver.update()
            perp = perplexity(loss.d.copy())
            monitor_perplexity.add(
                (len(train_data)//(num_steps*batch_size))*(epoch)+i, perp)
        l1.reset_state()
        l2.reset_state()
        vloss_avg = 0
        for i in range(len(val_data)//(num_steps * batch_size)):
            vx.d, vt.d = get_batch(val_data, i*num_steps,
                                   batch_size, num_steps)
            vloss.forward()
            vloss_avg += vloss.d.copy()
        vloss_avg /= float((len(val_data)//(num_steps*batch_size)))
        vper = perplexity(vloss_avg)

        if vper < best_val:
            best_val = vper
            if vper < 200:
                save_name = "params_epoch_{:02d}.h5".format(epoch)
                nn.save_parameters(os.path.join(args.save_dir, save_name))
        else:
            solver.set_learning_rate(solver.learning_rate()*0.25)
            logger.info("Decreased learning rate to {:05f}".format(
                solver.learning_rate()))
        monitor_vperplexity.add(
            (len(val_data)//(num_steps*batch_size))*(epoch)+i, vper)

    # for final test split
    t_batch_size = 1
    tl1 = LSTMWrapper(t_batch_size, state_size)
    tl2 = LSTMWrapper(t_batch_size, state_size)
    tloss_avg = 0
    tx = nn.Variable((t_batch_size, num_steps))
    tt = nn.Variable((t_batch_size, num_steps))
    tloss = get_loss(tl1, tl2, tx, tt, w, b, num_words, 1, state_size)

    tl1.share_data()
    tl2.share_data()

    for i in range(len(test_data)//(num_steps * t_batch_size)):
        tx.d, tt.d = get_batch(test_data, i*num_steps, 1, num_steps)
        tloss.forward()
        tloss_avg += tloss.d.copy()
    tloss_avg /= float((len(test_data)//(num_steps*t_batch_size)))
    tper = perplexity(tloss_avg)
    monitor_tperplexity.add(
        (len(test_data)//(num_steps*t_batch_size))*(epoch)+i, tper)
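
perplexity above is assumed to exponentiate the mean cross-entropy loss (the standard language-model definition); a one-line version consistent with that usage:

import numpy as np


def perplexity(mean_loss):
    # exp of the average negative log-likelihood per token (in nats).
    return np.exp(mean_loss)
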
Example no. 22
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop.
      * Zero the parameter gradients.
      * Execute backprop.
      * AllReduce the gradients.
      * The solver updates parameters using the gradients computed by backprop and allreduce.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image
                       ) != bs_valid:  # note that smaller batch is ignored
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:  # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
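
The warmup above ramps the learning rate linearly from base_lr to base_lr * n_devices over warmup_epoch epochs, a common recipe for large effective batch sizes. A standalone sketch of the schedule (values are illustrative):

base_lr, n_devices, warmup_iter = 1e-3, 4, 100
warmup_slope = base_lr * (n_devices - 1) / warmup_iter
for i in (0, 50, 100, 150):
    lr = base_lr + warmup_slope * i if i <= warmup_iter else base_lr * n_devices
    print(i, lr)  # reaches base_lr * n_devices at i == warmup_iter, then holds
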
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a training computation graph per device and one graph for validation.
    * Initialize solvers and set parameter variables to them.
    * Instantiate a communicator and set parameter variables.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop on every device.
      * Zero the parameter gradients.
      * Execute backprop.
      * Allreduce the gradients (THIS IS THE MAIN difference from single-device training).
      * Each solver updates parameters using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Create contexts
    extension_module = args.context
    if extension_module != "cuda" and \
            extension_module != "cuda.cudnn":
        raise Exception("Use `cuda` or `cuda.cudnn` extension_module.")
    n_devices = args.n_devices
    ctxs = []
    for i in range(n_devices):
        ctx = extension_context(extension_module, device_id=i)
        ctxs.append(ctx)
    ctx = ctxs[-1]

    # Create training graphs
    input_image_train = []
    preds_train = []
    losses_train = []
    test = False
    for i in range(n_devices):
        image = nn.Variable((args.batch_size, 3, 32, 32))
        label = nn.Variable((args.batch_size, 1))
        device_scope_name = "device{}".format(i)

        pred = cifar100_resnet23_prediction(
            image, ctxs[i], device_scope_name, test)
        loss = cifar100_resnet32_loss(pred, label)

        input_image_train.append({"image": image, "label": label})
        preds_train.append(pred)
        losses_train.append(loss)

    # Create validation graph
    test = True
    device_scope_name = "device{}".format(0)
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar100_resnet23_prediction(
        image_valid, ctxs[0], device_scope_name, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solvers = []
    for i in range(n_devices):
        with nn.context_scope(ctxs[i]):
            solver = S.Adam()
            device_scope_name = "device{}".format(i)
            with nn.parameter_scope(device_scope_name):
                params = nn.get_parameters()
                solver.set_parameters(params)
            solvers.append(solver)

    # Communicator
    comm = C.DataParalellCommunicator(ctx)
    for i in range(n_devices):
        device_scope_name = "device{}".format(i)
        with nn.parameter_scope(device_scope_name):
            ctx = ctxs[i]
            params = nn.get_parameters()
            comm.add_context_and_parameters((ctx, params))
    comm.init()

    # Create threadpools with one thread
    pools = []
    for _ in range(n_devices):
        pool = ThreadPool(processes=1)
        pools.append(pool)

    # Once forward/backward to safely secure memory
    for device_id in range(n_devices):
        data, label = \
            (np.random.randn(*input_image_train[device_id]["image"].shape),
             (np.random.rand(*input_image_train[device_id]["label"].shape) * 10).astype(np.int32))

        ret = pools[device_id].apply_async(forward_backward,
                                           (input_image_train[device_id]["image"], data,
                                            input_image_train[device_id]["label"], label,
                                               losses_train[device_id], solvers[device_id]))
        ret.get()
        losses_train[device_id].d  # sync to host

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)
    with data_iterator_cifar100(args.batch_size, True) as tdata, \
            data_iterator_cifar100(bs_valid, False) as vdata:
        # Training-loop
        for i in range(int(args.max_iter / n_devices)):
            # Validation
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))

            # Forwards/Zerograd/Backwards
            fb_results = []
            for device_id in range(n_devices):
                image, label = tdata.next()

                res = pools[device_id].apply_async(forward_backward,
                                                   (input_image_train[device_id]["image"], image,
                                                    input_image_train[device_id]["label"], label,
                                                    losses_train[device_id], solvers[device_id]))
                fb_results.append(res)
            for device_id in range(n_devices):
                fb_results[device_id].get()

            # Allreduce over gradients
            comm.allreduce()

            # Solvers update
            for device_id in range(n_devices):
                solvers[device_id].update()

            e = categorical_error(
                preds_train[-1].d, input_image_train[-1]["label"].d)
            monitor_loss.add(i * n_devices, losses_train[-1].d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    nn.save_parameters(os.path.join(
        args.model_save_path,
        'params_%06d.h5' % (args.max_iter / n_devices)))
Example no. 24
monitor_perplexity_valid = MonitorSeries('perplexity_valid',
                                         monitor,
                                         interval=1)

for epoch in range(max_epoch):
    train_loss_set = []
    for i in tqdm(range(num_train_batch)):
        x_batch, y_batch = train_data_iter.next()
        y_batch = y_batch.reshape(list(y_batch.shape) + [1])

        x.d, t.d = x_batch, y_batch

        loss.forward(clear_no_need_grad=True)
        train_loss_set.append(loss.d.copy())
        solver.zero_grad()
        loss.backward(clear_buffer=True)
        solver.update()

    valid_loss_set = []
    for i in range(num_valid_batch):
        x_batch, y_batch = valid_data_iter.next()
        y_batch = y_batch.reshape(list(y_batch.shape) + [1])

        x.d, t.d = x_batch, y_batch

        loss.forward(clear_no_need_grad=True)
        valid_loss_set.append(loss.d.copy())

    monitor_perplexity.add(epoch + 1, np.e**np.mean(train_loss_set))
    monitor_perplexity_valid.add(epoch + 1, np.e**np.mean(valid_loss_set))
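
The perplexity reported above is e raised to the mean cross-entropy loss, i.e. the exponential of the average per-token negative log-likelihood; a tiny numeric illustration (the loss values are made up):

import numpy as np

losses = [4.6, 4.2, 3.9]  # illustrative per-batch mean NLL values in nats
print(np.e ** np.mean(losses))  # ≈ 69, the corresponding perplexity
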
Example no. 25
def train(args, train_dataset, tokenizer):
    """ Train the model """
    # Load the pretrained model
    nn.load_parameters(args.pretrained_model)
    # Drop final layer for task-specific fine-tuning
    nn.parameter.pop_parameter('affine_seq_class/affine/W')
    nn.parameter.pop_parameter('affine_seq_class/affine/b')

    train_dataloader = data_iterator(
        train_dataset, batch_size=args.train_batch_size)

    global_step = 0
    train_loss = 0.0
    model = BertForSequenceClassification()

    input_ids = nn.Variable((args.train_batch_size, args.max_seq_length))
    attention_mask = nn.Variable((args.train_batch_size, args.max_seq_length))
    token_type_ids = nn.Variable((args.train_batch_size, args.max_seq_length))
    labels = nn.Variable((args.train_batch_size, ))

    input_ids_eval = nn.Variable((args.eval_batch_size, args.max_seq_length))
    attention_mask_eval = nn.Variable(
        (args.eval_batch_size, args.max_seq_length))
    token_type_ids_eval = nn.Variable(
        (args.eval_batch_size, args.max_seq_length))
    labels_eval = nn.Variable((args.eval_batch_size, ))

    activation = F.gelu
    if args.activation == 'relu':
        activation = F.relu
    loss, _, train_error = model(args, input_ids=input_ids, attention_mask=attention_mask,
                                 token_type_ids=token_type_ids, labels=labels,
                                 num_labels=args.num_labels, vocab_size=args.vocab_size,
                                 num_embed_dim=args.num_embed_dim,
                                 num_pos_ids=args.num_position_ids,
                                 num_attention_layers=args.num_attention_layers,
                                 num_attention_embed_dim=args.num_attention_embed_dim,
                                 num_attention_heads=args.num_attention_heads,
                                 num_attention_dim_feedforward=args.num_attention_dim_feedforward,
                                 attention_activation=activation, pool_outmap=args.num_pool_outmap,
                                 embed_dropout_prob=args.embed_dropout,
                                 attention_dropout_prob=args.attention_dropout,
                                 dropout_prob=args.last_dropout, test=False)

    loss.persistent = True
    if args.solver == 'Adam':
        solver = S.Adam(args.learning_rate, eps=args.adam_epsilon)
    else:
        solver = S.AdamW(args.learning_rate, eps=args.adam_epsilon)
    solver.set_parameters(nn.get_parameters())

    monitor = Monitor(args.output_dir)
    monitor_loss = MonitorSeries(
        "Training Loss", monitor, interval=10)
    monitor_eloss = MonitorSeries(
        "Evaluation Loss", monitor, interval=10)
    monitor_train_error = MonitorSeries(
        "Training Error Rate", monitor, interval=10)
    monitor_lr = MonitorSeries(
        "learning Rate", monitor, interval=10)

    total_steps = train_dataloader.size // args.train_batch_size
    var_warmup = total_steps * (args.num_train_epochs - 1)
    for epoch in range(args.num_train_epochs):
        logger.info("Starting Epoch %d out of %d",
                    epoch+1, args.num_train_epochs)
        for it in range(total_steps):
            batch = train_dataloader.next()
            input_ids.d = batch[0]
            attention_mask.d = batch[1]
            token_type_ids.d = batch[2]
            labels.d = batch[3]

            # LR schedule: linear warmup from 0 to args.learning_rate over the
            # first epoch, then linear decay to 0 over the remaining epochs.
            if epoch == 0:
                learning_rate = args.learning_rate * (global_step / total_steps)
            else:
                learning_rate_linear = lr_linear(
                    global_step - total_steps, var_warmup)
                learning_rate = args.learning_rate * learning_rate_linear

            solver.zero_grad()
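            # forward_all executes a single shared forward pass that fills both
            # the loss and the train_error outputs.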
            nn.forward_all([loss, train_error], clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.weight_decay(args.weight_decay)
            solver.clip_grad_by_norm(args.max_grad_norm)
            solver.set_learning_rate(learning_rate)
            solver.update()

            monitor_loss.add(global_step, loss.d.copy())
            monitor_train_error.add(global_step, train_error.d.copy())
            monitor_lr.add(global_step, learning_rate)
            global_step += 1
            train_loss += float(loss.d)  # loss is already a scalar mean

        eval_task_names = (
            "mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        eval_outputs_dirs = (args.output_dir, args.output_dir +
                             '-MM') if args.task_name == "mnli" else (args.output_dir,)

        results = {}
        for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
            logger.info("Evaluating task: %s", eval_task)
            eval_dataset = BERTDataSource(
                args, tokenizer, evaluate=True, shuffle=False)
            if not os.path.exists(eval_output_dir):
                os.makedirs(eval_output_dir)

            eval_dataloader = data_iterator(
                eval_dataset, batch_size=args.eval_batch_size)
            total_eval_steps = eval_dataloader.size // args.eval_batch_size
            eval_loss = 0.0
            nb_eval_steps = 0
            preds = None
            out_label_ids = None
            tmp_eval_loss, logits, eval_error = model(args, input_ids=input_ids_eval,
                                                      attention_mask=attention_mask_eval,
                                                      token_type_ids=token_type_ids_eval, labels=labels_eval,
                                                      num_labels=args.num_labels, vocab_size=args.vocab_size,
                                                      num_embed_dim=args.num_embed_dim,
                                                      num_pos_ids=args.num_position_ids,
                                                      num_attention_layers=args.num_attention_layers,
                                                      num_attention_embed_dim=args.num_attention_embed_dim,
                                                      num_attention_heads=args.num_attention_heads,
                                                      num_attention_dim_feedforward=args.num_attention_dim_feedforward,
                                                      attention_activation=activation, pool_outmap=args.num_pool_outmap,
                                                      embed_dropout_prob=args.embed_dropout,
                                                      attention_dropout_prob=args.attention_dropout,
                                                      dropout_prob=args.last_dropout, test=True)

            tmp_eval_loss.persistent = True
            # Graph output for the evaluation loss; its value is accumulated
            # into eval_loss batch by batch below.
            eval_loss_graph = F.mean(tmp_eval_loss)
            for it in range(total_eval_steps):
                batch_eval = eval_dataloader.next()
                input_ids_eval.d = batch_eval[0]
                attention_mask_eval.d = batch_eval[1]
                token_type_ids_eval.d = batch_eval[2]
                labels_eval.d = batch_eval[3]
                nb_eval_steps += 1
                eval_loss_graph.forward()
                eval_loss += float(eval_loss_graph.d)
                monitor_eloss.add(it, eval_loss_graph.d.copy())

                if preds is None:
                    preds = logits.d.copy()
                    out_label_ids = labels_eval.d.copy()
                else:
                    preds = np.append(preds, logits.d.copy(), axis=0)

                    out_label_ids = np.append(
                        out_label_ids, labels_eval.d.copy(), axis=0)
            eval_loss = eval_loss / nb_eval_steps  # average over eval batches
            if args.output_mode == "classification":
                preds = np.argmax(preds, axis=1)
            elif args.output_mode == "regression":
                preds = np.squeeze(preds)

            result = compute_metrics(eval_task, preds, out_label_ids)
            results.update(result)

            output_eval_file = os.path.join(
                eval_output_dir, "eval_results.txt")
            with open(output_eval_file, "a") as writer:
                logger.info("***** Evaluation results %s *****", eval_task)
                for key in sorted(result.keys()):
                    logger.info("%d  %s = %s\n", epoch +
                                1, key, str(result[key]))
                    writer.write("%d %s = %s\n" %
                                 (epoch+1, key, str(result[key])))
                print("results", results)
    return results
Esempio n. 26
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation.
      * Zero the parameter gradients.
      * Execute backprop.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training).
      * Solver updates parameters using the gradients computed by backprop.
      * Compute training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx = extension_context(extension_module, device_id=device_id)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = cifar100_resnet23_prediction(
        image_train, ctx, test)
    loss_train = cifar100_resnet32_loss(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar100_resnet23_prediction(
        image_valid, ctx, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = 1. * n_devices / warmup_iter

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)
    with data_iterator_cifar100(args.batch_size, True) as tdata, \
            data_iterator_cifar100(bs_valid, False) as vdata:
        # Training-loop
        for i in range(int(args.max_iter / n_devices)):
            # Validation
            if mpi_rank == 0:
                if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                    ve = 0.
                    for j in range(args.val_iter):
                        image, label = vdata.next()
                        input_image_valid["image"].d = image
                        pred_valid.forward()
                        ve += categorical_error(pred_valid.d, label)
                    ve /= args.val_iter
                    monitor_verr.add(i * n_devices, ve)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(os.path.join(
                        args.model_save_path, 'params_%06d.h5' % i))

            # Forward/Zerograd/Backward
            image, label = tdata.next()
            input_image_train["image"].d = image
            input_image_train["label"].d = label
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()

            # In-place Allreduce
            comm.allreduce(division=True)

            # Solvers update
            solver.update()

            # Linear warmup: ramp the LR from 0 up to base_lr * n_devices over
            # warmup_iter iterations; warmup_slope (= n_devices / warmup_iter)
            # already accounts for the device count.
            if i < warmup_iter:
                lr = base_lr * warmup_slope * i
            else:
                lr = base_lr * n_devices
            solver.set_learning_rate(lr)

            if mpi_rank == 0:
                e = categorical_error(
                    pred_train.d, input_image_train["label"].d)
                monitor_loss.add(i * n_devices, loss_train.d.copy())
                monitor_err.add(i * n_devices, e)
                monitor_time.add(i * n_devices)
    if mpi_rank == 0:
        nn.save_parameters(os.path.join(
            args.model_save_path,
            'params_%06d.h5' % int(args.max_iter / n_devices)))
Esempio n. 27
0
def train(args):
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    elif args.net == 'cifar10_binary_connect_resnet23_prediction':
        model_prediction = cifar10_binary_connect_resnet23_prediction
    elif args.net == 'cifar10_binary_net_resnet23_prediction':
        model_prediction = cifar10_binary_net_resnet23_prediction
    elif args.net == 'cifar10_binary_weight_resnet23_prediction':
        model_prediction = cifar10_binary_weight_resnet23_prediction
    elif args.net == 'cifar10_fp_connect_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_fp_connect_resnet23_prediction,
            n=args.bit_width,
            delta=args.delta)
    elif args.net == 'cifar10_fp_net_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_fp_net_resnet23_prediction,
            n=args.bit_width,
            delta=args.delta)
    elif args.net == 'cifar10_pow2_connect_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_pow2_connect_resnet23_prediction,
            n=args.bit_width,
            m=args.upper_bound)
    elif args.net == 'cifar10_pow2_net_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_pow2_net_resnet23_prediction,
            n=args.bit_width,
            m=args.upper_bound)
    elif args.net == 'cifar10_inq_resnet23_prediction':
        model_prediction = functools.partial(cifar10_inq_resnet23_prediction,
                                             num_bits=args.bit_width)
    elif args.net == 'cifar10_min_max_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_min_max_resnet23_prediction,
            ql_min=args.ql_min,
            ql_max=args.ql_max,
            p_min_max=args.p_min_max,
            a_min_max=args.a_min_max,
            a_ema=args.a_ema,
            ste_fine_grained=args.ste_fine_grained)
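    # The fp/pow2/inq/min_max variants swap standard layers for quantized ones
    # (fixed-point, power-of-two, incremental network quantization, and min-max
    # quantization); bit widths and ranges are taken from the command-line args.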

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Esempio n. 28
0
def train():
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # SSL Regularization
    loss += ssl_regularization(nn.get_parameters(), args.filter_decay,
                               args.channel_decay)
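    # ssl_regularization adds group-lasso penalties over whole filters and
    # channels (Structured Sparsity Learning, Wen et al., 2016), weighted by
    # filter_decay and channel_decay respectively.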

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Esempio n. 29
0
def main(args):
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = conf.data_path
    B = conf.batch_size
    R = conf.n_rays
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    ctx = get_extension_context('cudnn', device_id=device_id)
    nn.set_default_context(ctx)

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)
    di = data_iterator_dtumvs(ds, B)

    camloc = nn.Variable([B, 3])
    raydir = nn.Variable([B, R, 3])
    alpha = nn.Variable.from_numpy_array(conf.alpha)
    color_gt = nn.Variable([B, R, 3])
    mask_obj = nn.Variable([B, R, 1])

    # Monitor
    interval = di.size
    monitor_path = create_monitor_path(conf.data_path, args.monitor_path)
    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=interval)
    monitor_mhit = MonitorSeries("Hit count", monitor, interval=1)
    monitor_color_loss = MonitorSeries(
        "Training color loss", monitor, interval=interval)
    monitor_mask_loss = MonitorSeries(
        "Training mask loss", monitor, interval=interval)
    monitor_eikonal_loss = MonitorSeries(
        "Training eikonal loss", monitor, interval=interval)
    monitor_time = MonitorTimeElapsed(
        "Training time", monitor, interval=interval)
    monitor_image = MonitorImage("Rendered image", monitor, interval=1)

    # Solver
    solver = S.Adam(conf.learning_rate)
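    # idr_loss combines a color reconstruction term, a mask term, and an
    # eikonal regularizer that keeps the SDF gradient norm close to 1.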
    loss, color_loss, mask_loss, eikonal_loss, mask_hit = \
        idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf)
    solver.set_parameters(nn.get_parameters())

    # Validation helper; defined outside the loop so the final call after
    # training works even when per-epoch validation is skipped.
    def validate(i):
        pose_ = ds.poses[conf.valid_index:conf.valid_index+1, ...]
        intrinsic_ = ds.intrinsics[conf.valid_index:conf.valid_index+1, ...]
        mask_obj_ = ds.masks[conf.valid_index:conf.valid_index+1, ...]
        image = render(pose_, intrinsic_, mask_obj_, conf)
        monitor_image.add(i, image)
        nn.save_parameters(f"{monitor_path}/model_{i:05d}.h5")

    # Training loop
    for i in range(conf.train_epoch):

        ds.change_sampling_idx()

        # Validate
        if i % conf.valid_epoch_interval == 0 and not args.skip_val:
            validate(i)

        # Train
        for j in range(di.size):
            # Feed data
            color_, mask_, intrinsic_, pose_, xy_ = di.next()
            color_gt.d = color_
            mask_obj.d = mask_
            raydir_, camloc_ = generate_raydir_camloc(pose_, intrinsic_, xy_)
            raydir.d = raydir_
            camloc.d = camloc_

            # Network
            loss.forward()
            solver.zero_grad()
            loss.backward(clear_buffer=True)
            solver.update()

            # Monitor
            t = i * di.size + j
            monitor_mhit.add(t, np.sum(mask_hit.d))
            monitor_loss.add(t, loss.d)
            monitor_color_loss.add(t, color_loss.d)
            monitor_mask_loss.add(t, mask_loss.d)
            monitor_eikonal_loss.add(t, eikonal_loss.d)
            monitor_time.add(t)

        # Schedules: double alpha and halve the learning rate at the configured epochs
        if i in conf.alpha_decay:
            alpha.d = alpha.d * 2.0
        if i in conf.lr_decay:
            solver.set_learning_rate(solver.learning_rate() * 0.5)

    validate(i)
Esempio n. 30
0
def train(args):
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty
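    # WGAN-GP term: lambda * E[(||grad_x D(x_hat)||_2 - 1)^2], where x_hat is a
    # random interpolation between real and fake samples (Gulrajani et al., 2017).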
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with a fixed latent input for testing
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need to backprop into the generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # backprop into the generator again
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Esempio n. 31
0
def CNN_run(args, both_archs, data_dict, with_train=False, after_search=False):
    """
    """

    num_cells = args.num_cells
    num_nodes = args.num_nodes

    if after_search:
        assert with_train is True, "when you train the network after architecture search, set with_train=True"
    tdata, mean_val_train, std_val_train = data_dict["train_data"]
    vdata, mean_val_valid, std_val_valid = data_dict["valid_data"]
    channels, image_height, image_width, num_class = data_dict["basic_info"]
    batch_size = args.batch_size

    output_filter = args.output_filter

    if with_train:
        if after_search:
            num_epoch = args.epoch_on_retrain
            if args.additional_filters_on_retrain > 0:
                output_filter += args.additional_filters_on_retrain
        else:
            num_epoch = args.epoch_per_search

        one_epoch = tdata.size // batch_size
        max_iter = num_epoch * one_epoch

    val_iter = args.val_iter

    monitor_path = args.monitor_path
    model_save_path = args.monitor_path
    decay_rate = args.weight_decay
    initial_lr = args.child_lr

    model_save_interval = args.model_save_interval

    image_valid = nn.Variable(
        (batch_size, channels, image_height, image_width))
    input_image_valid = {"image": image_valid}

    vdata._reset()  # rewind data

    test = True
    pred_valid, _, _ = construct_architecture(image_valid, num_class, num_cells, num_nodes,
                                              both_archs, output_filter, test)

    if with_train:
        if after_search:
            # setting for training after architecture search
            with_grad_clip = args.with_grad_clip_on_retrain
            grad_clip = args.grad_clip_value
            lr_control = args.lr_control_on_retrain
        else:
            with_grad_clip = args.with_grad_clip_on_search
            grad_clip = args.grad_clip_value
            lr_control = args.lr_control_on_search

        # prepare variables used for training
        image_train = nn.Variable(
            (batch_size, channels, image_height, image_width))
        label_train = nn.Variable((batch_size, 1))
        input_image_train = {"image": image_train, "label": label_train}

        tdata._reset()  # rewind data

        test = False
        pred_train, aux_logits, used_weights = construct_architecture(image_train, num_class, num_cells, num_nodes,
                                                                      both_archs, output_filter, test)
        loss_train = loss_function(pred_train, aux_logits, label_train)

        used_weights_dict = {key_name: nn.get_parameters(
        )[key_name] for key_name in used_weights}

        # Create monitor.
        monitor = Monitor(monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=100)
        # modified to display accuracy.
        monitor_err = MonitorSeries("Training accuracy", monitor, interval=100)
        # modified to display accuracy.
        monitor_verr = MonitorSeries("Test accuracy", monitor, interval=1)

        # Solvers
        solver = S.Momentum(initial_lr)
        solver.set_parameters(
            used_weights_dict, reset=False, retain_state=True)

        # Training-loop
        for i in range(max_iter):
            if i > 0 and i % one_epoch == 0:
                # Validation during training.
                ve = 0.
                for j in range(val_iter):
                    image, label = vdata.next()
                    image = image / 255.0
                    image = (image - mean_val_valid) / std_val_valid
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= val_iter
                monitor_verr.add(i, 1.0 - ve)  # modified to display accuracy.

            if after_search and i % model_save_interval == 0:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))

            # Forward/Zerograd/Backward
            image, label = tdata.next()
            image = image / 255.0
            image = (image - mean_val_train) / std_val_train
            input_image_train["image"].d = image
            input_image_train["label"].d = label
            loss_train.forward()

            if lr_control:
                new_lr = learning_rate_scheduler(i, max_iter, initial_lr, 0)
                solver.set_learning_rate(new_lr)

            solver.zero_grad()
            loss_train.backward()

            if with_grad_clip:
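                # Per-parameter clipping: rescale any gradient whose L2 norm
                # exceeds grad_clip (this is not global-norm clipping).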
                for k, v in used_weights_dict.items():
                    if np.linalg.norm(v.g) > grad_clip:
                        v.grad.copy_from(F.clip_by_norm(v.grad, grad_clip))

            # Solvers update
            solver.weight_decay(decay_rate)
            solver.update()
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i, loss_train.d.copy())
            monitor_err.add(i, 1.0 - e)  # modified to display accuracy.

    # Validation (After training or when called for evaluation only)
    ve = 0.
    for j in range(val_iter):
        image, label = vdata.next()
        image = image / 255.0
        image = (image - mean_val_valid) / std_val_valid
        input_image_valid["image"].d = image
        pred_valid.forward()
        ve += categorical_error(pred_valid.d, label)
    ve /= val_iter

    if with_train:
        print("Validation Accuracy on Trained CNN:",
              '{:.2f}'.format(100*(1.0 - ve)), "%\n")

    if after_search:
        nn.save_parameters(os.path.join(
            args.model_save_path, 'params_%06d.h5' % (max_iter)))

    return 1.0 - ve
Esempio n. 32
0
def classification_svd():
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create the "slim" prediction graph; the "reference" weights are decomposed into it below.
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    pred.persistent = True

    # Decompose and set parameters
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
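    # decompose_network_and_set_params factorizes the "reference" weights
    # (low-rank SVD at reduction rate rrate) and copies the factors into the
    # "slim" network's parameters.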
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create slim prediction graph for testing.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= args.val_iter  # normalize so the best-model check compares mean error
            monitor_verr.add(i, ve)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= args.val_iter
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Esempio n. 33
0
def distil():
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
        data_iterator = data_iterator_cifar10
        c = 3
        h = w = 32
        n_train = 50000
        n_valid = 10000

    # TRAIN
    teacher = "teacher"
    student = "student"
    maps = args.maps
    rrate = args.reduction_rate
    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    image.persistent = True  # do not clear; this buffer is re-used by both graphs
    label = nn.Variable([args.batch_size, 1])
    label.persistent = True  # do not clear; this buffer is re-used by both graphs
    # Create "teacher" and "student" prediction graphs.
    model_load_path = args.model_load_path
    nn.load_parameters(model_load_path)
    pred_label = model_prediction(image,
                                  net=teacher,
                                  maps=maps,
                                  test=not args.use_batch)
    pred_label.need_grad = False  # no need to backprop through the teacher graph
    pred = model_prediction(image,
                            net=student,
                            maps=int(maps * (1. - rrate)),
                            test=False)
    pred.persistent = True  # keep the student output for the error computation below
    loss_ce = F.mean(F.softmax_cross_entropy(pred, label))
    loss_ce_soft = ce_soft(pred, pred_label)
    loss = args.weight_ce * loss_ce + args.weight_ce_soft * loss_ce_soft
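    # Distillation objective: a weighted sum of the hard-label cross-entropy and
    # the soft cross-entropy against the teacher's outputs (Hinton et al., 2015).
    # A minimal sketch of what ce_soft presumably computes (ce_soft itself is
    # defined elsewhere; this is an assumption, not its actual definition):
    #
    #   def ce_soft(pred, pred_label):
    #       return F.mean(F.sum(-F.softmax(pred_label) * F.log_softmax(pred),
    #                           axis=1))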

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create student prediction graph for testing.
    vpred = model_prediction(vimage,
                             net=student,
                             maps=int(maps * (1. - rrate)),
                             test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(student):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator for CIFAR-10.
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)