Example #1
def main(args):
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = conf.data_path
    B = conf.batch_size
    R = conf.n_rays
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    ctx = get_extension_context('cudnn', device_id=device_id)
    nn.set_default_context(ctx)

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)

    # Monitor
    monitor_path = "/".join(args.model_load_path.split("/")[0:-1])
    monitor = Monitor(monitor_path)
    monitor_psnrs = MonitorSeries("PSNRs", monitor, interval=1)
    monitor_psnr = MonitorSeries("PSNR", monitor, interval=1)

    # Load model
    nn.load_parameters(args.model_load_path)

    # Evaluate
    image_list = []
    for pose, intrinsic, mask_obj in zip(ds.poses, ds.intrinsics, ds.masks):
        image = render(pose[np.newaxis, ...], intrinsic[np.newaxis, ...],
                       mask_obj[np.newaxis, ...], conf)
        image_list.append(image)

    metric = psnr(image_list, ds.images, ds.masks, monitor_psnrs)
    monitor_psnr.add(0, metric)
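
Every example in this collection reduces to the same three-step pattern: create a Monitor bound to an output directory, wrap it in one MonitorSeries per metric, and call add(index, value) from the loop. A minimal standalone sketch (directory name and values are hypothetical):

from nnabla.monitor import Monitor, MonitorSeries

monitor = Monitor("tmp.monitor")                     # output directory for the logs
series = MonitorSeries("PSNR", monitor, interval=1)  # interval=1 flushes on every add()
for i, value in enumerate([27.9, 28.3, 28.6]):       # hypothetical PSNR values
    series.add(i, value)                             # logs the value under the monitor path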
Example #2
def get_esrgan_monitors():
    """
    Create monitors for displaying and storing ESRGAN losses.
    """
    monitor_path = './nnmonitor' + datetime.now().strftime("%Y%m%d%H%M%S")
    monitor = Monitor(monitor_path)
    monitor_feature_g = MonitorSeries('l_g_fea per iteration',
                                      monitor,
                                      interval=100)
    monitor_gan_g = MonitorSeries('l_g_gan per iteration',
                                  monitor,
                                  interval=100)
    monitor_gan_d = MonitorSeries('l_d_total per iteration',
                                  monitor,
                                  interval=100)
    monitor_d_real = MonitorSeries('D_real per iteration',
                                   monitor,
                                   interval=100)
    monitor_d_fake = MonitorSeries('D_fake per iteration',
                                   monitor,
                                   interval=100)
    Monitor_esrgan = namedtuple('Monitor_esrgan', [
        'monitor_feature_g', 'monitor_gan_g', 'monitor_gan_d',
        'monitor_d_real', 'monitor_d_fake'
    ])
    return Monitor_esrgan(monitor_feature_g, monitor_gan_g, monitor_gan_d,
                          monitor_d_real, monitor_d_fake)
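
A hedged sketch of how the returned namedtuple might be consumed in a training loop; the iteration count and scalar losses below are placeholders, not ESRGAN outputs:

mons = get_esrgan_monitors()
for i in range(300):                 # hypothetical iteration count
    # In real training these scalars come from the loss variables.
    mons.monitor_gan_g.add(i, 0.5)
    mons.monitor_d_real.add(i, 0.9)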
Example #3
def __init__(self,
             env,
             replay_memory,
             explorer,
             obs_sampler,
             num_episodes=10,
             clip_episode_step=False,
             num_ces=4500 * 4,
             num_eval_steps=125000 * 4,
             render=False,
             monitor=None,
             tbw=None):
    self.env = env
    self.replay_memory = replay_memory
    self.obs_sampler = obs_sampler
    self.explorer = explorer
    self.num_episodes = num_episodes
    self.clip_episode_step = clip_episode_step
    self.num_ces = num_ces
    self.num_eval_steps = num_eval_steps
    self.render = render
    self.steps = 0
    self.losses = []
    self.start_time = time.time()
    self.monitor = None
    if monitor:
        monitor_val_score = MonitorSeries("Validation Score", monitor)
        monitor_eval_score = MonitorSeries("Eval Score", monitor)
        self.monitor = {
            'val_score': monitor_val_score,
            'eval_score': monitor_eval_score
        }
    self.tbw = tbw
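
The dict-of-monitors pattern above keeps logging optional and decoupled from the training logic; the same structure in a standalone sketch (score value hypothetical):

from nnabla.monitor import Monitor, MonitorSeries

monitor = Monitor("tmp.monitor")
monitors = {
    'val_score': MonitorSeries("Validation Score", monitor),
    'eval_score': MonitorSeries("Eval Score", monitor),
}
monitors['val_score'].add(0, 123.0)  # hypothetical validation score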
Example #4
def evaluate(args):
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, y = get_input_and_output(nnp, args.batch_size, name=args.variable_name)

    if args.evaluation_metric == "IS":
        is_model = None
        compute_metric = compute_inception_score
    if args.evaluation_metric == "FID":
        di = data_iterator_imagenet(args.valid_dir,
                                    args.dirname_to_label_path,
                                    batch_size=args.batch_size,
                                    ih=args.image_size,
                                    iw=args.image_size,
                                    shuffle=True,
                                    train=False,
                                    noise=False)
        compute_metric = functools.partial(compute_frechet_inception_distance,
                                           di=di)
    else:
        raise ValueError("evaluation_metric must be 'IS' or 'FID', got '{}'"
                         .format(args.evaluation_metric))

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_metric = MonitorSeries("{}".format(args.evaluation_metric),
                                   monitor,
                                   interval=1)

    # Compute the evaluation metric for all models
    def cmp_func(path):
        # Extract the iteration number from a "params_<iter>.h5" file name.
        return int(os.path.splitext(os.path.basename(path))[0].replace("params_", ""))
    model_load_path = sorted(glob.glob("{}/*.h5".format(args.model_load_path)), key=cmp_func) \
        if os.path.isdir(args.model_load_path) else \
        [args.model_load_path]

    for path in model_load_path:
        # Model (SAGAN)
        nn.load_parameters(path)
        z = nn.Variable([batch_size, latent])
        y_fake = nn.Variable([batch_size])
        x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
            .apply(persistent=True)
        # Compute the evaluation metric
        score = compute_metric(z, y_fake, x_fake, x, y, args)
        itr = cmp_func(path)
        monitor_metric.add(itr, score)
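
For reference, cmp_func sorts checkpoints numerically rather than lexically, which matters once iteration counts differ in digit length (paths hypothetical):

paths = ["snapshots/params_100.h5", "snapshots/params_99.h5"]
print(sorted(paths, key=cmp_func))
# numeric order: params_99.h5 comes before params_100.h5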
Example #5
def train(env,
          model,
          buffer,
          exploration,
          monitor,
          update_fn,
          eval_fn,
          final_step,
          update_start,
          update_interval,
          save_interval,
          evaluate_interval,
          loss_labels=None):
    # Avoid a mutable default argument.
    loss_labels = loss_labels if loss_labels is not None else []
    reward_monitor = MonitorSeries('reward', monitor, interval=1)
    eval_reward_monitor = MonitorSeries('eval_reward', monitor, interval=1)
    time_monitor = MonitorTimeElapsed('time', monitor, interval=10000)
    loss_monitors = []
    for label in loss_labels:
        loss_monitors.append(MonitorSeries(label, monitor, interval=10000))

    step = 0
    while step <= final_step:
        obs_t = env.reset()
        ter_tp1 = False
        cumulative_reward = 0.0
        model.reset(step)
        while not ter_tp1:
            # select best action
            act_t = model.infer(obs_t)
            # add exploration noise
            act_t = exploration.get(step, act_t)
            # iterate environment
            obs_tp1, rew_tp1, ter_tp1, _ = env.step(act_t)
            # store transition
            buffer.append(obs_t, [act_t], rew_tp1, obs_tp1, ter_tp1)

            # update parameters
            if step > update_start and step % update_interval == 0:
                for i, loss in enumerate(update_fn(step)):
                    if loss is not None:
                        loss_monitors[i].add(step, loss)

            # save parameters
            if step % save_interval == 0:
                path = os.path.join(monitor.save_path, 'model_%d.h5' % step)
                nn.save_parameters(path)

            if step % evaluate_interval == 0:
                eval_reward_monitor.add(step, np.mean(eval_fn()))

            step += 1
            cumulative_reward += rew_tp1
            obs_t = obs_tp1
            time_monitor.add(step)

        # record metrics
        reward_monitor.add(step, cumulative_reward)
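
MonitorTimeElapsed only reports once every interval calls to add(), which is why the loop above can call it on every step cheaply. A tiny standalone sketch (interval and sleep chosen purely for illustration):

import time
from nnabla.monitor import Monitor, MonitorTimeElapsed

monitor = Monitor("tmp.monitor")
timer = MonitorTimeElapsed("time", monitor, interval=5)
for step in range(20):
    time.sleep(0.01)  # stand-in for one environment step
    timer.add(step)   # reports elapsed time once every 5 calls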
Example #6
def meta_test(args, shape_x, test_data):

    # Build episode generators
    test_episode_generator = EpisodeGenerator(
        test_data[0], test_data[1], args.n_class, args.n_shot, args.n_query)

    # Build prototypical network
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding,
               args.net_type, args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Load parameters
    nn.load_parameters(args.work_dir + "/params.h5")

    # Evaluate error rate
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        v_errs.append(float(err_v.d.copy()))  # np.float was removed in NumPy 1.24
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)

    # Monitor error rate
    monitor = Monitor(args.work_dir)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    return v_err_mean, v_err_conf
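
The 1.96 * std / sqrt(n) term is the half-width of a 95% confidence interval under a normal approximation. A worked example with hypothetical per-episode error rates:

import numpy as np

v_errs = [0.12, 0.15, 0.11, 0.14]  # hypothetical per-episode error rates
half_width = 1.96 * np.std(v_errs) / np.sqrt(len(v_errs))
# mean error 13.0%, +/- roughly 1.6 percentage points at 95% confidence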
Example #7
def main():
    # Training settings
    args = Yolov2OptionTraining().parse_args()

    nsamples = file_lines(args.train)

    set_default_context_by_args(args)

    # Training parameters
    max_epochs = args.max_batches * args.batch_size * args.accum_times / nsamples + 1

    if not os.path.exists(args.output):
        os.mkdir(args.output)

    ###############

    # Load parameters
    print("Load", args.weight, "...")
    if args.fine_tune:
        nn.load_parameters(args.weight)
        # Drop the detection head so it is re-initialized for the new classes.
        nn.parameter.pop_parameter("detection/conv/W")
        nn.parameter.pop_parameter("detection/conv/b")
    else:
        nn.load_parameters(args.weight)

    train_graph = TrainGraph(args, (args.size_aug[-1], args.size_aug[-1]))
    yolo_solver = YoloSolver(args)

    if args.on_memory_data:
        on_memory_data = dataset.load_on_memory_data(args.train)
    else:
        on_memory_data = None

    prefetch_iterator = PrefetchIterator()

    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.output)
    monitor_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_time = MonitorTimeElapsed('Time per epoch', monitor, interval=1)

    # Epoch loop
    for epoch in range(0, int(max_epochs)):
        loss = train(args, epoch, max_epochs, train_graph, yolo_solver,
                     prefetch_iterator, on_memory_data)
        monitor_loss.add(epoch, loss)
        monitor_time.add(epoch)

    # Save the final parameters
    logging('save weights to %s/%06d.h5' % (args.output, epoch + 1))
    nn.save_parameters('%s/%06d.h5' % (args.output, epoch + 1))
Example #8
def get_common_monitors(monitor):
    """
    Create monitors for displaying and storing losses.
    """
    monitor_content_loss = MonitorSeries('content loss', monitor, interval=20)
    monitor_gen_loss = MonitorSeries('generator loss', monitor, interval=20)
    monitor_warp_loss = MonitorSeries('warp loss', monitor, interval=20)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=20)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=20)
    Monitor_common = collections.namedtuple('Monitor_common', [
        'monitor_content_loss', 'monitor_gen_loss', 'monitor_warp_loss',
        'monitor_lr', 'monitor_time'
    ])
    return Monitor_common(monitor_content_loss, monitor_gen_loss,
                          monitor_warp_loss, monitor_lr, monitor_time)
Example #9
def create_monitor(net_name: str, interval=100):
    value = namedtuple("value", ("loss", "error"))
    monitor = namedtuple("monitor", ("path", "time", "train", "val"))

    path = get_monitor_path(net_name)
    kwargs = dict(monitor=Monitor(path), interval=interval)
    time = MonitorTimeElapsed("time", **kwargs)

    kwargs["verbose"] = False
    loss = MonitorSeries("train_loss", **kwargs)
    error = MonitorSeries("train_error", **kwargs)
    value_train = value(loss, error)

    loss = MonitorSeries("val_loss", **kwargs)
    error = MonitorSeries("val_error", **kwargs)
    value_val = value(loss, error)

    return monitor(path, time, value_train, value_val)
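
Hypothetical usage of the nested namedtuple returned above, assuming get_monitor_path resolves to a writable directory:

mon = create_monitor("resnet23", interval=10)  # hypothetical net name
for i in range(100):
    mon.train.loss.add(i, 0.5)   # placeholder loss
    mon.train.error.add(i, 0.1)  # placeholder error
    mon.time.add(i)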
Example #10
def __init__(self,
             env,
             replay_memory,
             learner,
             sampler,
             explorer,
             obs_sampler,
             inter_eval_steps=-1,
             num_episodes=10,
             train_start=1000,
             train_freq=4,
             batch_size=32,
             render=False,
             validator=None,
             monitor=None,
             tbw=None):
    self.env = env
    self.replay_memory = replay_memory
    self.learner = learner
    self.sampler = sampler
    self.explorer = explorer
    self.obs_sampler = obs_sampler
    self.inter_eval_steps = inter_eval_steps
    self.num_episodes = num_episodes
    self.train_start = train_start
    self.train_started = False
    self.train_freq = train_freq
    self.batch_size = batch_size
    self.render = render
    self.validator = validator
    self.steps = 0
    self.losses = []
    self.monitor = None
    if monitor:
        monitor_loss = MonitorSeries("QNetwork Loss", monitor)
        monitor_epsilon = MonitorSeries("Epsilon", monitor)
        monitor_train_reward = MonitorSeries("Train Reward", monitor)
        self.monitor = {
            'loss': monitor_loss,
            'epsilon': monitor_epsilon,
            'reward': monitor_train_reward
        }
    self.tbw = tbw
Example #11
def setup_monitor(conf, monitor):
    """
    Setup monitor to keep track of losses and times to log them
    """
    jsi_monitor = {
        'rec_loss':
        MonitorSeries('rec_loss', monitor, interval=conf.monitor_interval),
        'psnr':
        MonitorSeries('psnr', monitor, interval=conf.monitor_interval),
        'lr':
        MonitorSeries('learning_rate', monitor,
                      interval=conf.monitor_interval),
        'g_final_loss':
        MonitorSeries('g_final_loss', monitor, interval=conf.monitor_interval),
        'd_final_fm_loss':
        MonitorSeries('d_final_fm_loss',
                      monitor,
                      interval=conf.monitor_interval),
        'd_final_detail_loss':
        MonitorSeries('d_final_detail_loss',
                      monitor,
                      interval=conf.monitor_interval),
        'g_adv_loss':
        MonitorSeries('g_adv_loss', monitor, interval=conf.monitor_interval),
        'g_detail_adv_loss':
        MonitorSeries('g_detail_adv_loss',
                      monitor,
                      interval=conf.monitor_interval),
        'fm_loss':
        MonitorSeries('fm_loss', monitor, interval=conf.monitor_interval),
        'fm_detail_loss':
        MonitorSeries('fm_detail_loss',
                      monitor,
                      interval=conf.monitor_interval),
        'time':
        MonitorTimeElapsed("Training time per epoch",
                           monitor,
                           interval=conf.monitor_interval)
    }
    return jsi_monitor
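
A sketch of how the returned dict might be consumed; conf.output_path, epoch, and the PSNR value are assumptions for illustration:

monitor = Monitor(conf.output_path)   # hypothetical conf attribute
jsi_monitor = setup_monitor(conf, monitor)
jsi_monitor['psnr'].add(epoch, 28.4)  # placeholder PSNR
jsi_monitor['time'].add(epoch)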
Example #12
def main():
    # Context
    ctx = get_extension_context("cudnn", device_id="0")
    nn.set_default_context(ctx)
    nn.set_auto_forward(False)
    # Inputs
    b, c, h, w = 64, 1, 28, 28
    x = nn.Variable([b, c, h, w])
    t = nn.Variable([b, 1])
    vx = nn.Variable([b, c, h, w])
    vt = nn.Variable([b, 1])
    # Model
    model = Model()
    pred = model(x)
    loss = F.softmax_cross_entropy(pred, t)
    vpred = model(vx, test=True)
    verror = F.top_n_error(vpred, vt)
    # Solver
    solver = S.Adam()
    solver.set_parameters(model.get_parameters(grad_only=True))
    # Data Iterator
    tdi = data_iterator_mnist(b, train=True)
    vdi = data_iterator_mnist(b, train=False)
    # Monitor
    monitor = Monitor("tmp.monitor")
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Training loop
    for e in range(1):
        for j in range(tdi.size // b):
            i = e * tdi.size // b + j
            x.d, t.d = tdi.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            monitor_loss.add(i, loss.d)
        error = 0.0
        for _ in range(vdi.size // b):
            vx.d, vt.d = vdi.next()
            verror.forward(clear_buffer=True)
            error += verror.d
        error /= vdi.size // b
        monitor_verr.add(i, error)
Example #13
def get_tecogan_monitors(monitor):
    """
    Create monitors for displaying and storing TECOGAN losses.
    """
    monitor_vgg_loss = MonitorSeries('vgg loss', monitor, interval=20)
    monitor_pp_loss = MonitorSeries('ping pong', monitor, interval=20)
    monitor_sum_layer_loss = MonitorSeries('d layer loss',
                                           monitor,
                                           interval=20)
    monitor_adv_loss = MonitorSeries('adversarial loss', monitor, interval=20)
    monitor_disc_loss = MonitorSeries('discriminator loss',
                                      monitor,
                                      interval=20)
    monitor_tb = MonitorSeries('tb', monitor, interval=20)
    Monitor_tecogan = collections.namedtuple('Monitor_tecogan', [
        'monitor_vgg_loss', 'monitor_pp_loss', 'monitor_sum_layer_loss',
        'monitor_adv_loss', 'monitor_disc_loss', 'monitor_tb'
    ])
    return Monitor_tecogan(monitor_vgg_loss, monitor_pp_loss,
                           monitor_sum_layer_loss, monitor_adv_loss,
                           monitor_disc_loss, monitor_tb)
Example #14
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
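
utils.EarlyStopping is repo-specific; a minimal sketch of the contract the loop above relies on, where step() returns True once validation loss has not improved for patience epochs and best holds the best value seen:

class EarlyStopping:
    """Minimal sketch; the real utils.EarlyStopping may differ."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, value):
        if value < self.best:  # improvement resets the counter
            self.best = value
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience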
Example #15
def make_monitor(name):
    return MonitorSeries(name, monitor, interval=1)
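
monitor here is captured from the enclosing scope, so this helper assumes a Monitor instance named monitor already exists when it is defined; a hypothetical setup:

from nnabla.monitor import Monitor, MonitorSeries

monitor = Monitor("tmp.monitor")  # must exist in the enclosing scope
reward_monitor = make_monitor("reward")
reward_monitor.add(0, 1.0)        # placeholder value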
Example #16
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop.
      * Zero the parameter gradients.
      * Execute backprop.
      * Update parameters with the solver, using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(args.max_iter):
        # Validation
        if i % int(n_train_samples / args.batch_size) == 0:
            ve = 0.
            for j in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i, ve)
        if int(i % args.model_save_interval) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Solvers update
        solver.update()

        e = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % (args.max_iter)))
Example #17
h = TimeDistributed(PF.affine)(h, hidden, name='hidden')
y = TimeDistributed(PF.affine)(h, vocab_size, name='output')

mask = F.sum(F.sign(t), axis=2)  # do not predict 'pad'.
entropy = TimeDistributedSoftmaxCrossEntropy(y, t) * mask
count = F.sum(mask, axis=1)
loss = F.mean(F.div2(F.sum(entropy, axis=1), count))

# Create solver.
solver = S.Momentum(1e-2, momentum=0.9)
solver.set_parameters(nn.get_parameters())

# Create monitor.
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
monitor = Monitor('./tmp-lstmlm')
monitor_perplexity = MonitorSeries('perplexity', monitor, interval=1)
monitor_perplexity_valid = MonitorSeries('perplexity_valid',
                                         monitor,
                                         interval=1)

for epoch in range(max_epoch):
    train_loss_set = []
    for i in tqdm(range(num_train_batch)):
        x_batch, y_batch = train_data_iter.next()
        y_batch = y_batch.reshape(list(y_batch.shape) + [1])

        x.d, t.d = x_batch, y_batch

        loss.forward(clear_no_need_grad=True)
        train_loss_set.append(loss.d.copy())
        solver.zero_grad()
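
The snippet cuts off inside the batch loop. Conventionally, and this is an assumption since the rest of the loop is not shown, the monitored perplexity would be derived from the collected losses once per epoch:

    # Assumed per-epoch continuation: perplexity = exp(mean cross-entropy).
    train_perplexity = np.exp(np.mean(train_loss_set))
    monitor_perplexity.add(epoch, train_perplexity)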
Example #18
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop.
      * Zero the parameter gradients.
      * Execute backprop.
      * AllReduce the gradients.
      * Update parameters with the solver, using the all-reduced gradients.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((args.batch_size, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # note that a smaller final batch is skipped
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:  # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
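
For intuition, the warmup schedule above ramps the learning rate linearly from base_lr to base_lr * n_devices over warmup_iter iterations. With hypothetical numbers:

base_lr, n_devices, warmup_iter = 1e-3, 4, 1000  # hypothetical values
warmup_slope = base_lr * (n_devices - 1) / warmup_iter
print([base_lr + warmup_slope * i for i in (0, 500, 1000)])
# [0.001, 0.0025, 0.004], i.e. base_lr ramps up to base_lr * n_devices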
Example #19
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop on the training graph.
      * Compute the training error.
      * Zero the parameter gradients.
      * Execute backprop.
      * Update parameters with the solver, using the gradients computed by backprop.
    """
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)
Example #20
    def __init__(self,
                 solver,
                 tinput=None,
                 tlabel=None,
                 tpred=None,
                 tdata=None,
                 vinput=None,
                 vlabel=None,
                 vpred=None,
                 vdata=None,
                 monitor_path=None,
                 model_save_path=None,
                 max_epoch=1,
                 iter_per_epoch=None,
                 val_iter=None):
        # Monitors
        monitor = Monitor(monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
        monitor_vloss = MonitorSeries("Valid loss", monitor, interval=1)
        monitor_time = MonitorTimeElapsed("Training time",
                                          monitor,
                                          interval=10)

        # Loss
        tpred = tpred.apply(persistent=True)
        tloss = F.mean(F.squared_error(tpred, tlabel))
        vpred = vpred.apply(persistent=True)
        vloss = F.mean(F.squared_error(vpred, vlabel))

        # Updater
        def tdata_feeder():
            tinput.d, tlabel.d = tdata.next()

        def update_callback_on_finish(i):
            monitor_loss.add(i, tloss.d)
            monitor_time.add(i)

        updater = Updater(
            solver,
            tloss,
            data_feeder=tdata_feeder,
            update_callback_on_finish=update_callback_on_finish)

        # Evaluator
        def vdata_feeder():
            vinput.d, vlabel.d = vdata.next()

        def vloss_callback_on_finish(i, v):
            monitor_vloss.add(i, v)

        val_iter = val_iter if val_iter is not None else vdata.size // vdata.batch_size
        evaluator = Evaluator(vloss,
                              data_feeder=vdata_feeder,
                              val_iter=val_iter,
                              callback_on_finish=vloss_callback_on_finish)

        # Trainer
        iter_per_epoch = iter_per_epoch if iter_per_epoch is not None \
            else tdata.size // tdata.batch_size
        self.trainer = Trainer(updater,
                               evaluator,
                               model_save_path,
                               max_epoch=max_epoch,
                               iter_per_epoch=iter_per_epoch)
Example #21
def main():
    # Args
    args = get_args()
    save_args(args)

    # Context
    ctx = extension_context(args.context,
                            device_id=args.device_id,
                            type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Data Iterator
    di = data_iterator(args.img_path,
                       args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)
    # Model
    generator = Generator(use_bn=args.use_bn,
                          last_act=args.last_act,
                          use_wscale=args.not_use_wscale,
                          use_he_backward=args.use_he_backward)
    discriminator = Discriminator(use_ln=args.use_ln,
                                  alpha=args.leaky_alpha,
                                  use_wscale=args.not_use_wscale,
                                  use_he_backward=args.use_he_backward)

    # Solver
    solver_gen = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1,
                        beta2=args.beta2)
    solver_dis = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1,
                        beta2=args.beta2)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis = MonitorSeries("Discriminator Loss",
                                     monitor,
                                     interval=10)
    monitor_p_fake = MonitorSeries("Fake Probability", monitor, interval=10)
    monitor_p_real = MonitorSeries("Real Probability", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training Time per Resolution",
                                      monitor,
                                      interval=1)
    monitor_image_tile = MonitorImageTileWithName("Image Tile",
                                                  monitor,
                                                  num_images=4,
                                                  normalize_method=lambda x:
                                                  (x + 1.) / 2.)

    # TODO: use argument
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(di,
                      generator,
                      discriminator,
                      solver_gen,
                      solver_dis,
                      args.monitor_path,
                      monitor_loss_gen,
                      monitor_loss_dis,
                      monitor_p_fake,
                      monitor_p_real,
                      monitor_time,
                      monitor_image_tile,
                      resolution_list,
                      channel_list,
                      n_latent=args.latent,
                      n_critic=args.critic,
                      save_image_interval=args.save_image_interval,
                      hyper_sphere=args.hyper_sphere,
                      l2_fake_weight=args.l2_fake_weight)

    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution)
Example #22
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop.
      * Zero the parameter gradients.
      * Execute backprop.
      * Update parameters with the solver, using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, rng)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if device_id == 0:
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
Example #23
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct computation graphs for training and one for validation.
    * Initialize solvers and set parameter variables to those.
    * Instantiate a communicator and set parameter variables.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprops.
      * Zero the parameter gradients.
      * Execute backprops.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training).
      * Update parameters with the solvers, using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Create contexts
    extension_module = args.context
    if extension_module != "cuda" and \
            extension_module != "cuda.cudnn":
        raise Exception("Use `cuda` or `cuda.cudnn` extension_module.")
    n_devices = args.n_devices
    ctxs = []
    for i in range(n_devices):
        ctx = extension_context(extension_module, device_id=i)
        ctxs.append(ctx)
    ctx = ctxs[-1]

    # Create training graphs
    input_image_train = []
    preds_train = []
    losses_train = []
    test = False
    for i in range(n_devices):
        image = nn.Variable((args.batch_size, 3, 32, 32))
        label = nn.Variable((args.batch_size, 1))
        device_scope_name = "device{}".format(i)

        pred = cifar100_resnet23_prediction(
            image, ctxs[i], device_scope_name, test)
        loss = cifar100_resnet32_loss(pred, label)

        input_image_train.append({"image": image, "label": label})
        preds_train.append(pred)
        losses_train.append(loss)

    # Create validation graph
    test = True
    device_scope_name = "device{}".format(0)
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar100_resnet23_prediction(
        image_valid, ctxs[0], device_scope_name, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solvers = []
    for i in range(n_devices):
        with nn.context_scope(ctxs[i]):
            solver = S.Adam()
            device_scope_name = "device{}".format(i)
            with nn.parameter_scope(device_scope_name):
                params = nn.get_parameters()
                solver.set_parameters(params)
            solvers.append(solver)

    # Communicator
    comm = C.DataParalellCommunicator(ctx)
    for i in range(n_devices):
        device_scope_name = "device{}".format(i)
        with nn.parameter_scope(device_scope_name):
            ctx = ctxs[i]
            params = nn.get_parameters()
            comm.add_context_and_parameters((ctx, params))
    comm.init()

    # Create threadpools with one thread
    pools = []
    for _ in range(n_devices):
        pool = ThreadPool(processes=1)
        pools.append(pool)

    # Once forward/backward to safely secure memory
    for device_id in range(n_devices):
        data, label = \
            (np.random.randn(*input_image_train[device_id]["image"].shape),
             (np.random.rand(*input_image_train[device_id]["label"].shape) * 10).astype(np.int32))

        ret = pools[device_id].apply_async(forward_backward,
                                           (input_image_train[device_id]["image"], data,
                                            input_image_train[device_id]["label"], label,
                                               losses_train[device_id], solvers[device_id]))
        ret.get()
        losses_train[device_id].d  # sync to host

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)
    with data_iterator_cifar100(args.batch_size, True) as tdata, \
            data_iterator_cifar100(bs_valid, False) as vdata:
        # Training-loop
        for i in range(int(args.max_iter / n_devices)):
            # Validation
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))

            # Forwards/Zerograd/Backwards
            fb_results = []
            for device_id in range(n_devices):
                image, label = tdata.next()

                res = pools[device_id].apply_async(forward_backward,
                                                   (input_image_train[device_id]["image"], image,
                                                    input_image_train[device_id]["label"], label,
                                                    losses_train[device_id], solvers[device_id]))
                fb_results.append(res)
            for device_id in range(n_devices):
                fb_results[device_id].get()

            # In-place Allreduce
            comm.allreduce()

            # Solvers update
            for device_id in range(n_devices):
                solvers[device_id].update()

            e = categorical_error(
                preds_train[-1].d, input_image_train[-1]["label"].d)
            monitor_loss.add(i * n_devices, losses_train[-1].d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    nn.save_parameters(os.path.join(
        args.model_save_path,
        'params_%06d.h5' % (args.max_iter // n_devices)))
Exemplo n.º 24
0
def train(data_iterator, monitor, config, comm, args):
    monitor_train_loss, monitor_train_recon = None, None
    monitor_val_loss, monitor_val_recon = None, None
    if comm.rank == 0:
        monitor_train_loss = MonitorSeries(
            config['monitor']['train_loss'], monitor, interval=config['train']['logger_step_interval'])
        monitor_train_recon = MonitorImageTile(config['monitor']['train_recon'], monitor, interval=config['train']['logger_step_interval'],
                                               num_images=config['train']['batch_size'])

        monitor_val_loss = MonitorSeries(
            config['monitor']['val_loss'], monitor, interval=config['train']['logger_step_interval'])
        monitor_val_recon = MonitorImageTile(config['monitor']['val_recon'], monitor, interval=config['train']['logger_step_interval'],
                                             num_images=config['train']['batch_size'])

    model = VQVAE(config)

    if not args.sample_from_pixelcnn:
        if config['train']['solver'] == 'adam':
            solver = S.Adam()
        else:
            solver = S.Momentum()
        solver.set_learning_rate(config['train']['learning_rate'])

        train_loader = data_iterator(config, comm, train=True)
        if config['dataset']['name'] != 'imagenet':
            val_loader = data_iterator(config, comm, train=False)
        else:
            val_loader = None
    else:
        solver, train_loader, val_loader = None, None, None

    if not args.pixelcnn_prior:
        trainer = VQVAEtrainer(model, solver, train_loader, val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss, monitor_val_recon, config, comm)
        num_epochs = config['train']['num_epochs']
    else:
        pixelcnn_model = GatedPixelCNN(config['prior'])
        trainer = TrainerPrior(model, pixelcnn_model, solver, train_loader, val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss, monitor_val_recon, config, comm, eval=args.sample_from_pixelcnn)
        num_epochs = config['prior']['train']['num_epochs']

    checkpoint_path = config['model']['checkpoint'] if not args.pixelcnn_prior else config['prior']['checkpoint']
    if os.path.exists(checkpoint_path) and (args.load_checkpoint or args.sample_from_pixelcnn):
        trainer.load_checkpoint(checkpoint_path, msg='Parameters loaded from {}'.format(
            checkpoint_path), pixelcnn=args.pixelcnn_prior, load_solver=not args.sample_from_pixelcnn)

    if args.sample_from_pixelcnn:
        trainer.random_generate(
            args.sample_from_pixelcnn, args.sample_save_path)
        return

    for epoch in range(num_epochs):

        trainer.train(epoch)

        if epoch % config['val']['interval'] == 0 and val_loader is not None:
            trainer.validate(epoch)

        if comm.rank == 0:
            if epoch % config['train']['save_param_step_interval'] == 0 or epoch == num_epochs-1:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch, pixelcnn=args.pixelcnn_prior)
Exemplo n.º 25
0
def train(args):
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    x_rmix.need_grad = True  # enable gradient flow to x_rmix before the graph is built on it
    p_rmix = discriminator(x_rmix, maps=maps)
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
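    # i.e. the WGAN-GP critic objective (a sketch of the math, assuming
    # `gan_loss` implements the Wasserstein losses):
    #   L_D = E[D(x_fake)] - E[D(x_real)]
    #         + lambda * E[(||grad_x D(x_rmix)||_2 - 1)^2]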
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Exemplo n.º 26
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx = extension_context(extension_module, device_id=device_id)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = cifar100_resnet23_prediction(
        image_train, ctx, test)
    loss_train = cifar100_resnet32_loss(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar100_resnet23_prediction(
        image_valid, ctx, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = 1. * n_devices / warmup_iter

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)
    with data_iterator_cifar100(args.batch_size, True) as tdata, \
            data_iterator_cifar100(bs_valid, False) as vdata:
        # Training-loop
        for i in range(int(args.max_iter / n_devices)):
            # Validation
            if mpi_rank == 0:
                if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                    ve = 0.
                    for j in range(args.val_iter):
                        image, label = vdata.next()
                        input_image_valid["image"].d = image
                        pred_valid.forward()
                        ve += categorical_error(pred_valid.d, label)
                    ve /= args.val_iter
                    monitor_verr.add(i * n_devices, ve)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(os.path.join(
                        args.model_save_path, 'params_%06d.h5' % i))

            # Forward/Zerograd/Backward
            image, label = tdata.next()
            input_image_train["image"].d = image
            input_image_train["label"].d = label
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()

            # In-place Allreduce
            comm.allreduce(division=True)

            # Solvers update
            solver.update()

            # Linear warmup: ramp the learning rate up to base_lr * n_devices
            # over warmup_iter iterations (warmup_slope = n_devices /
            # warmup_iter, so the two branches meet at i == warmup_iter).
            if i < warmup_iter:
                lr = base_lr * warmup_slope * i
            else:
                lr = base_lr * n_devices
            solver.set_learning_rate(lr)

            if mpi_rank == 0:
                e = categorical_error(
                    pred_train.d, input_image_train["label"].d)
                monitor_loss.add(i * n_devices, loss_train.d.copy())
                monitor_err.add(i * n_devices, e)
                monitor_time.add(i * n_devices)
    if mpi_rank == 0:
        nn.save_parameters(os.path.join(
            args.model_save_path,
            'params_%06d.h5' % (args.max_iter // n_devices)))
Exemplo n.º 27
0
def train():
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # SSL Regularization
    loss += ssl_regularization(nn.get_parameters(), args.filter_decay,
                               args.channel_decay)
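    # SSL adds group-lasso style penalties over whole filters and channels
    # (weighted by filter_decay / channel_decay) so that entire structures
    # are driven to zero and can later be pruned.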

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Exemplo n.º 28
0
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)
Exemplo n.º 29
0
def CNN_run(args, ops, alphas_dict):
    """
        Based on the given model architecture,
        construct CNN and execute training.
        input:
            args: arguments set by user.
            ops: operations used in the network.
            alphas_dict: a dictionary containing the architecture (alpha) parameters.
    """

    data_iterator = data_iterator_cifar10
    all_data = data_iterator(args.batch_size, True)
    tdata = all_data.slice(rng=None, slice_start=0, slice_end=25000)
    vdata = all_data.slice(rng=None, slice_start=25000, slice_end=50000)

    # CIFAR10 statistics, mean and variance
    CIFAR_MEAN = np.reshape([0.49139968, 0.48215827, 0.44653124], (1, 3, 1, 1))
    CIFAR_STD = np.reshape([0.24703233, 0.24348505, 0.26158768], (1, 3, 1, 1))

    channels, image_height, image_width = 3, 32, 32
    batch_size = args.batch_size
    initial_model_lr = args.model_lr

    one_epoch = tdata.size // batch_size
    max_iter = args.epoch * one_epoch

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=100)
    monitor_err = MonitorSeries("Training error", monitor, interval=100)
    monitor_vloss = MonitorSeries("Validation loss", monitor, interval=100)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=100)

    # prepare variables and graph used for training
    image_train = nn.Variable(
        (batch_size, channels, image_height, image_width))
    label_train = nn.Variable((batch_size, 1))
    input_image_train = {"image": image_train, "label": label_train}
    pred_train = construct_networks(args, ops, image_train, test=False)
    loss_train = loss_function(pred_train, label_train)

    # prepare solvers for model parameters
    model_params_dict = \
        {k: v for k, v in nn.get_parameters().items() if "alpha_" not in k}
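    # Parameters whose names contain "alpha_" are the architecture weights;
    # the remaining model weights and the alphas get separate solvers below.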
    solver_model = S.Momentum(initial_model_lr)
    solver_model.set_parameters(
        {
            k: v
            for k, v in nn.get_parameters().items()
            if k in model_params_dict.keys()
        },
        reset=False,
        retain_state=True)

    # prepare solvers for architecture parameters
    solver_archs = S.Adam(alpha=args.arch_lr, beta1=0.5, beta2=0.999)
    solver_archs.set_parameters(
        {
            k: v
            for k, v in nn.get_parameters().items() if k in alphas_dict.keys()
        },
        reset=False,
        retain_state=True)

    # Training-loop
    for i in range(max_iter):

        # Update Model Parameters.

        if args.second_order:
            # store the weights before update.
            original_weights = {
                k: nn.Variable(v.shape, need_grad=True).apply(
                    data=nn.NdArray(v.shape).copy_from(v.data))
                for k, v in nn.get_parameters().items() if "alpha_" not in k
            }

            # gradients refuge
            accumulated_gradient = \
                {k: nn.Variable(v.shape).apply(d=0)
                 for k, v in alphas_dict.items()}

        image, label = tdata.next()
        image = image / 255.0
        image = (image - CIFAR_MEAN) / CIFAR_STD
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()

        e = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)

        if args.lr_control_model:
            new_lr = learning_rate_scheduler(i, max_iter, initial_model_lr, 0)
            solver_model.set_learning_rate(new_lr)

        solver_model.zero_grad()
        loss_train.backward(clear_buffer=True)

        if args.with_grad_clip_model:
            for k, v in model_params_dict.items():
                v.grad.copy_from(
                    F.clip_by_norm(v.grad, args.grad_clip_value_model))

        solver_model.weight_decay(args.weight_decay_model)
        solver_model.update()  # weights update ( w -> w')

        if args.second_order:
            updated_weights = {
                k: nn.Variable(v.shape, need_grad=True).apply(
                    data=nn.NdArray(v.shape).copy_from(v.data))
                for k, v in nn.get_parameters().items() if "alpha_" not in k
            }

        # Update Architecture Parameters.

        v_image, v_label = vdata.next()
        v_image = v_image / 255.0
        v_image = (v_image - CIFAR_MEAN) / CIFAR_STD
        input_image_train["image"].d = v_image
        input_image_train["label"].d = v_label
        # compute Loss_on_valid(w', alpha)
        loss_train.forward(clear_no_need_grad=True)

        ve = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_vloss.add(i, loss_train.d.copy())
        monitor_verr.add(i, ve)

        solver_archs.zero_grad()
        solver_model.zero_grad()
        loss_train.backward(clear_buffer=True)  # its gradient is stored

        if args.second_order:
            accumulated_gradient = store_gradient(accumulated_gradient,
                                                  alphas_dict,
                                                  coeff=1.)

            # grad_alpha_L_val(w', alpha).  Note that gradient stored into .data
            delta_gradient_w = {
                k: nn.Variable(v.shape).apply(data=nn.NdArray(
                    v.shape).copy_from(v.grad),
                                              need_grad=True)
                for k, v in nn.get_parameters().items() if "alpha_" not in k
            }

            epsilon = 0.01 / np.sum(
                [np.linalg.norm(v.d) for v in delta_gradient_w.values()])
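
            # Second-order DARTS: approximate the hessian-vector term with a
            # central finite difference of grad_alpha L_train around
            # w+/- = w +/- epsilon * grad_w L_val(w', alpha), and subtract
            # eta times that estimate from grad_alpha L_val(w', alpha).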

            coeff = 1.0 * epsilon
            # w -> w+ (= w + epsilon*grad_Loss_on_val(w', alpha))
            weight_modify(original_weights, delta_gradient_w,
                          model_params_dict, coeff)

            input_image_train["image"].d = image  # reuse the same data
            input_image_train["label"].d = label

            # compute Loss_on_train(w+, alpha)
            loss_train.forward()
            solver_archs.zero_grad()
            solver_model.zero_grad()
            loss_train.backward(clear_buffer=True)  # its gradient is stored

            # accumulate currently registered gradient
            coeff = (-1.) * args.eta / 2. * epsilon
            accumulated_gradient = store_gradient(accumulated_gradient,
                                                  alphas_dict, coeff)

            coeff = -1.0 * epsilon
            # w -> w- (= w - epsilon*grad_Loss_on_val(w', alpha))
            weight_modify(original_weights, delta_gradient_w,
                          model_params_dict, coeff)

            # compute Loss_on_train(w-, alpha)
            loss_train.forward()
            solver_archs.zero_grad()
            solver_model.zero_grad()
            loss_train.backward(clear_buffer=True)  # its gradient is stored

            # accumulate currently registered gradient again
            coeff = (+1.) * args.eta / 2. * epsilon
            accumulated_gradient = store_gradient(accumulated_gradient,
                                                  alphas_dict, coeff)

            # replace the weights
            for k, v in alphas_dict.items():
                nn.parameter.set_parameter(
                    k,
                    nn.Variable(v.shape).apply(data=v.data,
                                               grad=accumulated_gradient[k],
                                               need_grad=True))
            for k, v in model_params_dict.items():
                nn.parameter.set_parameter(
                    k,
                    nn.Variable(v.shape).apply(data=updated_weights[k].data,
                                               need_grad=True))

        solver_archs.weight_decay(args.weight_decay_archs)
        solver_archs.update()

        if i % 1000 == 0:
            for k, v in alphas_dict.items():
                keynames = k.split("_")
                print("\nParameters for {} cell, node {} to {};".format(
                    keynames[1], keynames[2], keynames[3]))
                show_ops_and_prob(v.d, ops)

    return alphas_dict
Exemplo n.º 30
0
def train(args):
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()
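    # The unlinked variable cuts the graph between generator and
    # discriminator: the discriminator losses below backprop only into D,
    # and the generator update later pushes the accumulated gradient back
    # through x_fake manually via x_fake.backward(grad=None).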

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp
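    # i.e. adversarial (WGAN-style) loss + domain-classification loss on
    # real images + gradient penalty, following the StarGAN objective.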

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
Exemplo n.º 31
0
def main(args):
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = args.conf.data_path
    B = conf.batch_size
    R = conf.n_rays
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)
    di = data_iterator_dtumvs(ds, B)

    camloc = nn.Variable([B, 3])
    raydir = nn.Variable([B, R, 3])
    alpha = nn.Variable.from_numpy_array(conf.alpha)
    color_gt = nn.Variable([B, R, 3])
    mask_obj = nn.Variable([B, R, 1])

    # Monitor
    interval = di.size
    monitor_path = create_monitor_path(conf.data_path, args.monitor_path)
    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=interval)
    monitor_mhit = MonitorSeries("Hit count", monitor, interval=1)
    monitor_color_loss = MonitorSeries(
        "Training color loss", monitor, interval=interval)
    monitor_mask_loss = MonitorSeries(
        "Training mask loss", monitor, interval=interval)
    monitor_eikonal_loss = MonitorSeries(
        "Training eikonal loss", monitor, interval=interval)
    monitor_time = MonitorTimeElapsed(
        "Training time", monitor, interval=interval)
    monitor_image = MonitorImage("Rendered image", monitor, interval=1)

    # Solver
    solver = S.Adam(conf.learning_rate)
    loss, color_loss, mask_loss, eikonal_loss, mask_hit = \
        idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf)
    solver.set_parameters(nn.get_parameters())
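    # The IDR objective combines an RGB reconstruction term, a mask
    # (silhouette) term, and an eikonal regularizer that pushes
    # ||grad f(x)|| toward 1 so the learned SDF stays well-behaved.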

    # Validation helper; defined outside the loop so the final call after
    # training also resolves when intermediate validation is skipped.
    def validate(i):
        pose_ = ds.poses[conf.valid_index:conf.valid_index+1, ...]
        intrinsic_ = ds.intrinsics[conf.valid_index:conf.valid_index+1, ...]
        mask_obj_ = ds.masks[conf.valid_index:conf.valid_index+1, ...]
        image = render(pose_, intrinsic_, mask_obj_, conf)
        monitor_image.add(i, image)
        nn.save_parameters(f"{monitor_path}/model_{i:05d}.h5")

    # Training loop
    for i in range(conf.train_epoch):

        ds.change_sampling_idx()

        # Validate
        if i % conf.valid_epoch_interval == 0 and not args.skip_val:
            validate(i)

        # Train
        for j in range(di.size):
            # Feed data
            color_, mask_, intrinsic_, pose_, xy_ = di.next()
            color_gt.d = color_
            mask_obj.d = mask_
            raydir_, camloc_ = generate_raydir_camloc(pose_, intrinsic_, xy_)
            raydir.d = raydir_
            camloc.d = camloc_

            # Network
            loss.forward()
            solver.zero_grad()
            loss.backward(clear_buffer=True)
            solver.update()

            # Monitor
            t = i * di.size + j
            monitor_mhit.add(t, np.sum(mask_hit.d))
            monitor_loss.add(t, loss.d)
            monitor_color_loss.add(t, color_loss.d)
            monitor_mask_loss.add(t, mask_loss.d)
            monitor_eikonal_loss.add(t, eikonal_loss.d)
            monitor_time.add(t)

        # Decay
        if i in conf.alpha_decay:
            alpha.d = alpha.d * 2.0
        if i in conf.lr_decay:
            solver.set_learning_rate(solver.learning_rate() * 0.5)

    validate(i)
Exemplo n.º 32
0
def classification_svd():
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create `slim` prediction graph; the `reference` parameters are
    # decomposed into it below.
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    pred.persistent = True

    # Decompose and set parameters
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
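    # SVD compression sketch: each reference weight matrix W is factored as
    # U S V^T and only the leading singular components (a fraction set by
    # rrate) are kept to initialize the two smaller `slim` layers.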
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create `slim` validation graph.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Exemplo n.º 33
0
def main():

    random.seed(args.seed)
    np.random.seed(args.seed)

    # Prepare for CUDA.
    ctx = get_extension_context('cudnn', device_id=args.gpus)
    nn.set_default_context(ctx)

    start_full_time = time.time()
    from iterator import data_iterator

    # Data list for sceneflow data set
    train_list = "./dataset/sceneflow_train.csv"
    test_list = "./dataset/sceneflow_test.csv"
    train = True
    validation = True

    # Set monitor path.
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))

    img_left, img_right, disp_img = read_csv(train_list)
    img_left_test, img_right_test, disp_img_test = read_csv(test_list)
    train_samples = len(img_left)
    test_samples = len(img_left_test)
    train_size = int(len(img_left) / args.batchsize_train)
    test_size = int(len(img_left_test) / args.batchsize_test)

    # Create data iterator.
    data_iterator_train = data_iterator(
        train_samples, args.batchsize_train, img_left, img_right, disp_img, train=True, shuffle=True, dataset=args.dataset)
    data_iterator_test = data_iterator(
        test_samples, args.batchsize_test, img_left_test, img_right_test, disp_img_test, train=False, shuffle=False, dataset=args.dataset)

    # Report the number of batches per epoch.
    print("train batches: {}, test batches: {}".format(train_size, test_size))

    # Define data shape for training.
    var_left = nn.Variable(
        (args.batchsize_train, 3, args.crop_height, args.crop_width))
    var_right = nn.Variable(
        (args.batchsize_train, 3, args.crop_height, args.crop_width))
    var_disp = nn.Variable(
        (args.batchsize_train, 1, args.crop_height, args.crop_width))
    # Define data shape for testing.
    var_left_test = nn.Variable(
        (args.batchsize_test, 3, args.im_height, args.im_width))
    var_right_test = nn.Variable(
        (args.batchsize_test, 3, args.im_height, args.im_width))
    var_disp_test = nn.Variable(
        (args.batchsize_test, 1, args.im_height, args.im_width))

    if args.loadmodel is not None:
        # Loading CNN pretrained parameters.
        nn.load_parameters(args.loadmodel)

    # === for Training ===
    # Definition of pred
    pred1, pred2, pred3 = psm_net(var_left, var_right, args.maxdisp, True)
    mask_train = F.less_scalar(var_disp, args.maxdisp)
    sum_mask = F.maximum_scalar(F.sum(mask_train), 1)
    # Definition of loss
    loss = 0.5 * (0.5 * F.sum(F.huber_loss(pred1, var_disp)*mask_train)/(sum_mask) + 0.7 * F.sum(F.huber_loss(
        pred2, var_disp)*mask_train)/(sum_mask) + F.sum(F.huber_loss(pred3, var_disp)*mask_train)/(sum_mask))
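    # Stacked-hourglass training loss: a weighted sum of masked Huber losses
    # over the three intermediate predictions (overall weights 0.25, 0.35,
    # 0.5), averaged over pixels whose ground-truth disparity is below
    # maxdisp.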

    # === for Testing ===
    # Definition of pred
    mask_test = F.less_scalar(var_disp_test, args.maxdisp)
    sum_mask_test = F.maximum_scalar(F.sum(mask_test), 1)
    pred_test = psm_net(var_left_test, var_right_test, args.maxdisp, False)
    test_loss = F.sum(F.abs(pred_test - var_disp_test)*mask_test)/sum_mask_test

    # Prepare monitors.
    monitor = Monitor(monitor_path)
    monitor_train = MonitorSeries('Training loss', monitor, interval=1)
    monitor_test = MonitorSeries('Validation loss', monitor, interval=1)
    monitor_time_train = MonitorTimeElapsed(
        "Training time/epoch", monitor, interval=1)

    # Create a solver (parameter updater)
    solver = S.Adam(alpha=0.001, beta1=0.9, beta2=0.999)

    # Set parameters. A single call with grad_only=False registers all
    # parameters, including those that do not require gradients.
    params = nn.get_parameters(grad_only=False)
    solver.set_parameters(params)

    for epoch in range(1, args.epochs+1):
        print('This is %d-th epoch' % (epoch))

        if validation:
            ## testing ##
            total_test_loss = 0

            index_test = 0
            while index_test < test_size:
                var_left_test.d, var_right_test.d, var_disp_test.d = data_iterator_test.next()
                test_loss.forward(clear_no_need_grad=True)
                # accumulate the host-side scalar, not the Variable itself
                total_test_loss += test_loss.d

                print('Iter %d test loss = %.3f' % (index_test, test_loss.d))
                index_test += 1
            test_error = total_test_loss / test_size
            print('epoch %d average test loss = %.3f' % (epoch, test_error))
            # Pass validation loss to a monitor.
            monitor_test.add(epoch, test_error)

        if train:
            ## training ##
            total_train_loss = 0
            index = 0

            while index < train_size:

                # Get mini batch
                # Preprocess
                var_left.d, var_right.d, var_disp.d = data_iterator_train.next()
                loss.forward(clear_no_need_grad=True)
                # Initialize gradients
                solver.zero_grad()
                # Backward execution
                loss.backward(clear_buffer=True)
                # Update parameters by computed gradients
                solver.update()
                print('Iter %d training loss = %.3f' %
                      (index, loss.d))
                total_train_loss += loss.d
                index += 1
            train_error = total_train_loss/train_size
            monitor_time_train.add(epoch)
            print('epoch %d total training loss = %.3f' % (epoch, train_error))

            # Pass training loss to a monitor.
            monitor_train.add(epoch, train_error)
            print('full training time = %.2f HR' %
                  ((time.time() - start_full_time)/3600))

            # Save Parameter
            out_param_file = os.path.join(
                args.savemodel, 'psmnet_trained_param_' + str(epoch) + '.h5')
            nn.save_parameters(out_param_file)