# Example no. 1
0
def test_from_dlpack_new(ext_name, numpy_type, torch_type):
    """Check that a PyTorch tensor imported through DLPack shares memory.

    A 5x5 tensor of ones is created on the PyTorch device matching the
    backend of ``ext_name``, exported with ``torch.utils.dlpack.to_dlpack``
    and imported with ``nn.utils.dlpack.from_dlpack``. The imported array
    must have dtype ``numpy_type``, and an in-place increment on the NNabla
    side must be visible from the PyTorch tensor.
    """
    ctx = get_extension_context(ext_name)
    backend = ctx.backend[0].split(':')[0]
    # PyTorch calls the cuDNN-backed device simply 'cuda'.
    torch_device_name = 'cuda' if backend == 'cudnn' else backend
    nn.set_default_context(ctx)

    source = torch.ones((5, 5), dtype=torch_type,
                        device=torch.device(torch_device_name))
    capsule = torch.utils.dlpack.to_dlpack(source)
    imported = nn.utils.dlpack.from_dlpack(capsule)
    assert imported.dtype == numpy_type

    # The increment must show up on the PyTorch side too. If it does not, the
    # DlpackArray was copied to another array in the same ArrayGroup.
    imported += 1
    # Read the NNabla data first (as the original expression did), then the
    # PyTorch copy.
    nnabla_values = imported.data
    torch_values = source.to('cpu').detach().numpy().copy()
    assert np.all(nnabla_values == torch_values)
# Example no. 2
0
def create_communicator(ignore_error=False):
    """Create and store the module-global data-parallel communicator.

    A ``cudnn`` extension context is wrapped in
    ``C.MultiProcessDataParalellCommunicator`` and initialized. On success
    the context's ``device_id`` is set to ``rank % size``. With a single
    process (``size == 1``), ``None`` is stored instead of a communicator.

    Args:
        ignore_error (bool): If True, an exception raised while creating or
            initializing the communicator is logged as a warning (with its
            traceback) and ``None`` is stored. If False, it is re-raised.

    Returns:
        The value stored in ``_current_communicator`` (a communicator or
        ``None``).
    """
    global _current_communicator

    # Not referenced below. Presumably imported so the cudnn extension is
    # loaded before its context is requested; keep it.
    import nnabla_ext.cudnn
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    context = get_extension_context(extension_module)
    try:
        logger.log(99, 'Create communicator with contexts {}'.format(context))
        _current_communicator = C.MultiProcessDataParalellCommunicator(context)
        _current_communicator.init()
        # One device per process, chosen from the global rank.
        context.device_id = str(_current_communicator.rank %
                                _current_communicator.size)
        if _current_communicator.size == 1:
            # A single process has nobody to communicate with.
            _current_communicator = None
    except Exception:
        # Catch Exception (not a bare except), so KeyboardInterrupt and
        # SystemExit are never swallowed even when ignore_error=True.
        if not ignore_error:
            raise
        logger.warning("Failed to initialize nnabla.communicators.",
                       exc_info=True)
        _current_communicator = None

    return _current_communicator
# Example no. 3
0
def main(args):
    """Fine-tune a model from ``nnabla.models.imagenet``.

    If ``args.train_csv`` is set, the file must exist. Otherwise a message
    says that the caltech101 dataset is used. ResNet, VGG and SqueezeNet
    are built with the depth/version taken from ``args.res_layer``,
    ``args.vgg_layer`` and ``args.squeeze_ver`` respectively. Any other
    model class is built without arguments. Training is delegated to
    ``CNN_run``.
    """
    if args.train_csv:
        # prepare dataset.
        assert os.path.isfile(
            args.train_csv
        ), "csv file for training not found, create dataset first."
    else:
        print(
            "No user-made data given. Use caltech101 dataset for finetuning.")

    nn.set_default_context(get_extension_context(args.context,
                                                 device_id=args.device_id,
                                                 type_config=args.type_config))
    nn.ext_utils.import_extension_module(args.context)

    model_name = args.model
    print("Use {} for fine-tuning".format(model_name))

    # Name of the argument holding the depth/version, for the architectures
    # whose constructor takes one.
    layer_arg_by_model = {
        "ResNet": "res_layer",
        "VGG": "vgg_layer",
        "SqueezeNet": "squeeze_ver",
    }
    layer_arg = layer_arg_by_model.get(model_name)
    num_layers = getattr(args, layer_arg) if layer_arg is not None else ""

    model_class = getattr(importlib.import_module("nnabla.models.imagenet"),
                          model_name)
    if layer_arg is not None:
        model = model_class(num_layers)  # got model
    else:
        model = model_class()

    CNN_run(args, model)

    return
    def create_static_mix(self, parts, input_path: str, output_path: Path):
        """Separate ``input_path`` and save the sum of the selected sources as MP3.

        Args:
            parts: mapping from a source name (a key of the dict returned by
                ``separate``) to a flag. Sources whose flag is falsy are left
                out of the mix. Every separated source name must be a key.
            input_path: audio file to load and separate.
            output_path: destination passed to ``self.audio_adapter.save``.

        Raises:
            ValueError: if ``parts`` selects none of the separated sources.
        """
        self.download_and_verify()
        ctx = get_extension_context(self.context)
        nn.set_default_context(ctx)
        nn.set_auto_forward(True)

        audio, _ = self.audio_adapter.load(input_path,
                                           sample_rate=self.sample_rate)

        # Channels are along axis 1; keep at most two.
        if audio.shape[1] > 2:
            warnings.warn('Channel count > 2! '
                          'Only the first two channels will be processed!')
            audio = audio[:, :2]

        if audio.shape[1] == 1:
            # if we have mono, let's duplicate it
            # as the input of OpenUnmix is always stereo
            print('received mono file, so duplicate channels')
            audio = np.repeat(audio, 2, axis=1)

        print('Separating...')
        estimates = separate(audio,
                             model_path=str(self.model_file_path),
                             niter=self.iterations,
                             alpha=self.alpha,
                             softmask=self.softmask,
                             residual_model=self.residual_model)

        # Sum the selected sources into a single mix.
        final_source = None
        for name, source in estimates.items():
            if not parts[name]:
                continue
            final_source = source if final_source is None else final_source + source

        if final_source is None:
            # Without this check, None would be handed to audio_adapter.save.
            raise ValueError(
                'None of the separated sources is selected in `parts`.')

        print('Writing to MP3...')
        self.audio_adapter.save(output_path, final_source, self.sample_rate,
                                'mp3', self.bitrate)
# Example no. 5
0
def main():
    """Play an Atari environment greedily with a pretrained Q-network NNP.

    If ``args.nnp`` is not given, ``asset/<gym_env>/qnet.nnp`` is used. When
    ``find_local_nnp`` reports it missing, it is first downloaded from
    nnabla.org. The function then steps the validator forever.
    """
    args = get_args()

    nn.set_default_context(get_extension_context(
        args.extension, device_id=args.device_id))

    if args.nnp is None:
        local_nnp_dir = os.path.join("asset", args.gym_env)
        local_nnp_file = os.path.join(local_nnp_dir, "qnet.nnp")

        if not find_local_nnp(args.gym_env):
            logger.info("Downloading nnp data since you didn't specify...")
            # URLs always use '/'. os.path.join would insert '\\' on Windows.
            nnp_uri = "/".join([
                "https://nnabla.org/pretrained-models/nnp_models/examples/dqn",
                args.gym_env,
                "qnet.nnp",
            ])
            # makedirs also creates the parent 'asset' directory when it is
            # missing, where os.mkdir would raise FileNotFoundError.
            os.makedirs(local_nnp_dir, exist_ok=True)
            download(nnp_uri, output_file=local_nnp_file, open_file=False)
            logger.info("Download done!")

        args.nnp = local_nnp_file

    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)
    obs_sampler = ObsSampler(args.num_frames)
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    # just play greedily
    explorer = GreedyExplorer(
        env.action_space.n, use_nnp=True, nnp_file=args.nnp, name='qnet')
    validator = Validator(env, val_replay_memory, explorer, obs_sampler,
                          num_episodes=1, render=not args.no_render)
    while True:
        validator.step()
# Example no. 6
0
def main():
    """Retrain a searched micro (cell-based) architecture on CIFAR-10.

    The architecture is loaded with ``np.load`` from
    ``args.recommended_arch``, printed with ``show_arch`` and trained from
    scratch by ``CNN_run`` with training and validation iterators.

    Raises:
        ValueError: if ``args.recommended_arch`` is not given.
    """
    args = get_micro_args()

    # NOTE(review): two nodes are subtracted before training; presumably the
    # cell's input nodes are not counted downstream — confirm against CNN_run.
    args.num_nodes = args.num_nodes - 2

    filename = args.recommended_arch
    if not filename:
        # Fail early with a clear message instead of an UnboundLocalError at
        # np.load after the data has already been loaded.
        raise ValueError(
            "No architecture file given: set args.recommended_arch.")

    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.ext_utils.import_extension_module(args.context)

    data_iterator = data_iterator_cifar10
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    mean_val_train, std_val_train, channel, img_height, img_width, num_class = get_data_stats(
        tdata)
    mean_val_valid, std_val_valid, _, _, _, _ = get_data_stats(vdata)

    data_dict = {
        "train_data": (tdata, mean_val_train, std_val_train),
        "valid_data": (vdata, mean_val_valid, std_val_valid),
        "basic_info": (channel, img_height, img_width, num_class)
    }

    # SECURITY: allow_pickle=True can execute code embedded in the file; only
    # load architecture files from trusted sources.
    check_arch = np.load(filename, allow_pickle=True)
    print("Train the model whose architecture is:")
    show_arch(check_arch)

    CNN_run(args,
            check_arch.tolist(),
            data_dict,
            with_train=True,
            after_search=True)
# Example no. 7
0
def main(args):
    """Print Chamfer and Hausdorff distances between a mesh and a point cloud.

    ``source.size`` points are sampled from the mesh at
    ``args.mesh_data_path`` (Poisson-disk, seed 412) and normalized with
    ``utils.normalize``. They are compared with the points loaded from
    ``args.fpath``. Both one-sided distances and the symmetric distance are
    printed for each metric.
    """
    nn.set_default_context(
        get_extension_context("cudnn", device_id=args.device_id))

    # Reference points (the data source normalizes them to [-1, 1]).
    source = point_cloud_data_source(args.fpath, knn=-1, test=True)
    reference_points = source.points

    # Sample the (unnormalized) mesh, then normalize the samples.
    mesh = utils.read_mesh(args.mesh_data_path)
    sampled = mesh.sample_points_poisson_disk(source.size, seed=412)
    predicted_points = utils.normalize(np.asarray(sampled.points))

    (chamfer_pt, chamfer_tp, chamfer,
     hausdorff_pt, hausdorff_tp, hausdorff) = utils.chamfer_hausdorff_dists(
        predicted_points, reference_points)

    chamfer_template = """
    One-sided Chamfer distance (Pred, True):   {}
    One-sided Chamfer distance (True, Pred):   {}
    Chamfer distance:                          {}
    """
    hausdorff_template = """
    One-sided Hausdorff distance (Pred, True): {}
    One-sided Hausdorff distance (True, Pred): {}
    Hausdorff distance:                        {}
    """

    print("----- Chamfer distance -----")
    print(chamfer_template.format(chamfer_pt, chamfer_tp, chamfer))

    print("----- Hausdorff distance -----")
    print(hausdorff_template.format(hausdorff_pt, hausdorff_tp, hausdorff))
# Example no. 8
0
def main():
    """Classify one image with Darknet19 and report top-5 and latency.

    Loads the weights from ``args.weights``, resizes ``args.input`` to
    ``args.size`` x ``args.size`` and prints the five highest-scoring labels
    from ``imagenet.shortnames.list`` with their softmax probability in
    percent. It then times 10 further forward passes and prints the average
    time per image in milliseconds.
    """
    args = get_args()
    from nnabla.ext_utils import get_extension_context
    nn.set_default_context(get_extension_context(args.context))

    nn.load_parameters(args.weights)
    side = args.size
    x = nn.Variable((1, 3, side, side))
    # Pixel values are scaled by 1/255 as part of the graph.
    y = darknet19.darknet19_classification(x / 255, test=True)

    label_names = np.loadtxt('imagenet.shortnames.list',
                             dtype=str,
                             delimiter=',')[:1000]

    image = imresize(imread(args.input), (side, side))
    # Move the channel axis first and add the batch axis.
    x.d = image.transpose(2, 0, 1).reshape(1, 3, side, side)
    y.forward(clear_buffer=True)

    # Softmax probabilities in percent, flattened to one entry per class.
    percent = F.reshape(F.mul_scalar(F.softmax(y.data), 100), (y.size, ))

    top5 = np.argsort(y.d.flatten())[::-1][:5]
    for idx in top5:
        print('{}: {:.1f}%'.format(label_names[idx], percent.data[idx]))

    n_time = 10
    start = time.time()
    for _ in range(n_time):
        y.forward(clear_buffer=True)
    # Invoking device-to-host copy to synchronize the device (if CUDA).
    _ = y.d
    elapsed = time.time() - start
    print("Processing time: {:.1f} [ms/image]".format(
        elapsed / n_time * 1000))
# Example no. 9
0
def main():
    """Retrain a derived architecture from scratch.

    Loads the JSON architecture description from ``args.model_arch_name``,
    prints its normal and reduction cells and hands it to ``CNN_run``
    together with the table of candidate operations.
    """
    args = get_args()
    print(args)

    nn.set_default_context(get_extension_context(args.context,
                                                 device_id=args.device_id,
                                                 type_config=args.type_config))
    nn.ext_utils.import_extension_module(args.context)

    assert os.path.exists(
        args.model_arch_name), "architecture's params seem to be missing!"

    # Candidate operations keyed by their integer index 0..7.
    ops = dict(enumerate([
        dil_conv_3x3,
        dil_conv_5x5,
        sep_conv_3x3,
        sep_conv_5x5,
        max_pool_3x3,
        avg_pool_3x3,
        identity,
        zero,
    ]))

    with open(args.model_arch_name, 'r') as arch_file:
        arch_dict = json.load(arch_file)

    print("Train the model whose architecture is:")
    for arch_key, cell_type in (("arch_normal", "normal"),
                                ("arch_reduction", "reduction")):
        show_derived_cell(args, ops, arch_dict[arch_key], cell_type)
    CNN_run(args, ops, arch_dict)
# Example no. 10
0
def main():
    """Compute LPIPS between two images, two image lists or two directories.

    The input kind is inferred from the (shared) extension of the two
    entries in ``args.path``:

    * ``.txt``: two image-list files, handled by ``handle_textfiles``.
    * no extension: two directories, handled by ``handle_directories``.
    * ``.png`` / ``.jpg`` / ``.jpeg`` (any case): a pair of images; the
      summed LPIPS value is printed.

    Raises:
        ValueError: if ``args.path`` does not hold exactly two paths.
        RuntimeError: for any other extension.
    """
    args = get_args()
    # Context
    ctx = get_extension_context(args.context, device_id=args.device_id)
    nn.set_default_context(ctx)
    params_dir = args.params_dir
    model = args.model
    outfile = args.outfile

    paths = args.path
    if len(paths) != 2:
        raise ValueError(
            f"Exactly two input paths are required, got {len(paths)}.")
    ext0, ext1 = [os.path.splitext(path)[-1] for path in paths]
    assert ext0 == ext1, "given inputs are not the same filetype."

    if ext0 == ".txt":
        # assume image lists are given
        handle_textfiles(paths[0], paths[1], outfile, params_dir, model)

    elif ext0 == "":
        # assume directories are given
        handle_directories(paths[0], paths[1], outfile, params_dir, model)

    # The original list had "jpg" without the dot, so .jpg inputs never
    # matched this branch.
    elif ext0.lower() in (".png", ".jpg", ".jpeg"):
        assert os.path.isfile(
            paths[0]), f"specified file {paths[0]} is not found."
        assert os.path.isfile(
            paths[1]), f"specified file {paths[1]} is not found."
        lpips = LPIPS(model=model, params_dir=params_dir)
        lpips_val = compute_lpips_of_paired_images(
            lpips, paths[0], paths[1], params_dir, model)
        print(f"LPIPS: {lpips_val.d.sum():.3f}")

    else:
        raise RuntimeError(f"Invalid input file {ext0}.")
# Example no. 11
0
    def __next__(self):
        """Return the next batch: a list with one dict per GPU pipeline.

        Each dict maps an output category name (from ``self.output_map``) to
        an ``nn.NdArray`` filled with the data of the corresponding DALI
        output. The NdArrays are allocated lazily and reused across calls in
        two slots (``self._current_data_batch`` alternates between 0 and 1).

        If ``self._first_batch`` is set, it is returned once without touching
        the pipelines. When ``self._stop_at_epoch`` is true and this batch
        pushes ``self._counter`` past ``self._size``, only the GPUs needed to
        reach exactly ``self._size`` samples are returned, and the last one is
        cut to its leading samples.
        """
        if self._first_batch is not None:
            # Return the stored first batch exactly once.
            batch = self._first_batch
            self._first_batch = None
            return batch
        if self._counter >= self._size:
            # Epoch exhausted: call self.reset() when auto-reset is enabled.
            if self._auto_reset:
                self.reset()
            # NOTE(review): StopIteration is commented out, so this method
            # never ends iteration by itself and keeps producing batches past
            # the epoch boundary. Confirm this is intended.
            # raise StopIteration
        # Gather outputs
        # (one output list per pipeline; released again after copying below)
        outputs = []
        for p in self._pipes:
            outputs.append(p._share_outputs())
        for i in range(self._num_gpus):
            device_id = self._pipes[i].device_id
            # initialize dict for all output categories
            category_outputs = dict()
            # segregate outputs into categories
            # (the j-th pipeline output is named by self.output_map[j])
            for j, out in enumerate(outputs[i]):
                category_outputs[self.output_map[j]] = out

            # Change DALI TensorLists into Tensors
            category_tensors = dict()
            category_shapes = dict()
            for category, out in category_outputs.items():
                category_tensors[category] = out.as_tensor()
                category_shapes[category] = category_tensors[category].shape()

            # If we did not yet allocate memory for that batch, do it now
            if self._data_batches[i][self._current_data_batch] is None:
                # Record, per category, the numpy dtype and the context
                # (cudnn on this pipeline's device for TensorGPU outputs,
                # cpu otherwise) later passed to feed_ndarray.
                # NOTE(review): these two dicts are shared by all GPUs and
                # rebuilt on every allocation. Once every slot is allocated,
                # each GPU's copy uses the context recorded for the last
                # allocated GPU. Confirm this is only used with one GPU per
                # process.
                self._category_nnabla_type = dict()
                self._category_device = dict()
                nnabla_gpu_device = get_extension_context('cudnn',
                                                          device_id=device_id)
                nnabla_cpu_device = get_extension_context('cpu')
                # check category and device
                for category in self._output_categories:
                    self._category_nnabla_type[category] = np.dtype(
                        category_tensors[category].dtype())
                    if type(category_tensors[category]) is TensorGPU:
                        self._category_device[category] = nnabla_gpu_device
                    else:
                        self._category_device[category] = nnabla_cpu_device

                # One NdArray per category, shaped like the DALI tensor.
                nnabla_tensors = dict()
                for category in self._output_categories:
                    nnabla_tensors[category] = nn.NdArray(
                        category_shapes[category])

                self._data_batches[i][
                    self._current_data_batch] = nnabla_tensors
            else:
                # Reuse the buffers allocated for this GPU and slot.
                nnabla_tensors = self._data_batches[i][
                    self._current_data_batch]

            # Copy data from DALI Tensors to nnabla tensors
            for category, tensor in category_tensors.items():
                feed_ndarray(tensor,
                             nnabla_tensors[category],
                             dtype=self._category_nnabla_type[category],
                             ctx=self._category_device[category])

        # Release the shared outputs and call _run() on every pipeline
        # (presumably schedules the next batch — TODO confirm with DALI docs).
        for p in self._pipes:
            p._release_outputs()
            p._run()

        copy_db_index = self._current_data_batch
        # Change index for double buffering
        self._current_data_batch = (self._current_data_batch + 1) % 2
        self._counter += self._num_gpus * self.batch_size

        if (self._stop_at_epoch) and (self._counter > self._size):
            # First calculate how much data is required to return exactly self._size entries.
            diff = self._num_gpus * self.batch_size - \
                (self._counter - self._size)
            # Figure out how many GPUs to grab from.
            numGPUs_tograb = int(np.ceil(diff / self.batch_size))
            # Figure out how many results to grab from the last GPU (as a fractional GPU batch may be required to
            # bring us right up to self._size).
            mod_diff = diff % self.batch_size
            data_fromlastGPU = mod_diff if mod_diff else self.batch_size

            # Grab the relevant data.
            # 1) Grab everything from the relevant GPUs.
            # 2) Grab the right data from the last GPU.
            # 3) Append data together correctly and return.
            output = [
                db[copy_db_index]
                for db in self._data_batches[0:numGPUs_tograb]
            ]
            # Shallow-copy the last dict so the slicing below does not replace
            # entries of the reusable buffer dict.
            output[-1] = output[-1].copy()
            for category in self._output_categories:
                output[-1][category] = output[-1][category][0:data_fromlastGPU]

            return output

        return [db[copy_db_index] for db in self._data_batches]
# Example no. 12
0
        if epoch % config['val']['interval'] == 0 and val_loader != None:
            trainer.validate(epoch)

        if comm.rank == 0:
            if epoch % config['train']['save_param_step_interval'] == 0 or epoch == config['train']['num_epochs']-1:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch, pixelcnn=args.pixelcnn_prior)


if __name__ == '__main__':

    # Script entry: read the dataset-specific YAML config, pick the data
    # iterator for args.data and set up the communicator and default context.
    parser = make_parser()
    args = parser.parse_args()
    config = read_yaml(os.path.join('configs', '{}.yaml'.format(args.data)))
    ctx = get_extension_context(
        config['extension_module'], device_id=config['device_id'])
    nn.set_auto_forward(True)

    if args.data == 'mnist':
        data_iterator = mnist_iterator
    elif args.data == 'imagenet':
        data_iterator = imagenet_iterator
    elif args.data == 'cifar10':
        data_iterator = cifar10_iterator
    else:
        import sys
        # The builtin exit() is injected by the site module for interactive
        # use and is missing under `python -S` or in frozen apps. sys.exit()
        # is always available. Errors go to stderr.
        print('Dataset not recognized', file=sys.stderr)
        sys.exit(1)

    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(ctx)
# Example no. 13
0
def main():
    """Parse the arguments, install the requested default context and train."""
    args = get_args()
    nn.set_default_context(get_extension_context(args.context,
                                                 device_id=args.device_id,
                                                 type_config=args.type_config))
    train(args)
# Example no. 14
0
def train():
    """
    Main script.

    Steps:

    * Specify a context for computation.
    * Initialize DataIterator for CIFAR10.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.

    Reads the module-level ``cfg`` (hyper-parameters, size targets, output
    directory) and ``args`` (``args.gpu`` selects the cuDNN device).

    Files written to ``cfg.params_dir``:

    * ``params_best.h5``: parameters with the best validation error seen
      while the network-size constraints hold.
    * ``results.csv``: one row per epoch with iteration count, mean training
      error, validation error and the mean of each loss term.
    * ``res_<kind>_<scope>.csv``: for each quantizer parameter kind
      (n, d, min, max) and scope (b, w, a), rows of
      ``parameter index, epoch index, value``. A file is only written when
      there are at least two such parameters and at least two epochs.
    """

    # define training parameters
    augmented_shift = True
    augmented_flip = True
    batch_size = 128
    vbatch_size = 100
    num_classes = 10
    weight_decay = 0.0002
    momentum = 0.9
    # 80 epochs at the initial rate, then 40 at 1/10 and 40 at 1/100.
    learning_rates = (cfg.initial_learning_rate,)*80 + \
        (cfg.initial_learning_rate / 10.,)*40 + \
        (cfg.initial_learning_rate / 100.,)*40
    print('lr={}'.format(learning_rates))
    print('weight_decay={}'.format(weight_decay))
    print('momentum={}'.format(momentum))

    # create nabla context
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context('cudnn', device_id=args.gpu)
    nn.set_default_context(ctx)

    # Initialize DataIterator for CIFAR10.
    logger.info("Get CIFAR10 Data ...")
    data = cifar_data.DataIterator(batch_size,
                                   augmented_shift=augmented_shift,
                                   augmented_flip=augmented_flip)
    vdata = cifar_data.DataIterator(vbatch_size, val=True)

    if cfg.weightfile is not None:
        logger.info(f"Loading weights from {cfg.weightfile}")
        nn.load_parameters(cfg.weightfile)

    # TRAIN
    # Create input variables.
    image = nn.Variable([batch_size, 3, 32, 32])
    label = nn.Variable([batch_size, 1])

    # Create prediction graph.
    pred, _ = resnet_cifar10(image,
                             num_classes=num_classes,
                             cfg=cfg,
                             test=False)
    pred.persistent = True

    # Compute initial network size
    num_weights, kbytes_weights = network_size_weights()
    kbytes_weights.forward()
    print(f"Initial network size (weights) is {float(kbytes_weights.d):.3f}KB "
          f"(total number of weights: {int(num_weights):d}).")

    num_activations, kbytes_activations = network_size_activations()
    kbytes_activations.forward()
    print(
        f"Initial network size (activations) is {float(kbytes_activations.d):.3f}KB "
        f"(total number of activations: {int(num_activations):d}).")

    # Create loss function.
    cost_lambda2 = nn.Variable(())
    cost_lambda2.d = cfg.initial_cost_lambda2
    cost_lambda2.persistent = True
    cost_lambda3 = nn.Variable(())
    cost_lambda3.d = cfg.initial_cost_lambda3
    cost_lambda3.persistent = True

    loss1 = F.mean(F.softmax_cross_entropy(pred, label))
    loss1.persistent = True

    # Size penalties are squared hinge losses on the size targets. They are a
    # constant 0 when the corresponding target is disabled (<= 0).
    if cfg.target_weight_kbytes > 0:
        loss2 = F.relu(kbytes_weights - cfg.target_weight_kbytes)**2
        loss2.persistent = True
    else:
        loss2 = nn.Variable(())
        loss2.d = 0
        loss2.persistent = True
    if cfg.target_activation_kbytes > 0:
        loss3 = F.relu(kbytes_activations - cfg.target_activation_kbytes)**2
        loss3.persistent = True
    else:
        loss3 = nn.Variable(())
        loss3.d = 0
        loss3.persistent = True

    loss = loss1 + cost_lambda2 * loss2 + cost_lambda3 * loss3

    # VALID
    # Create input variables.
    vimage = nn.Variable([vbatch_size, 3, 32, 32])
    vlabel = nn.Variable([vbatch_size, 1])
    # Create prediction graph.
    vpred, _ = resnet_cifar10(vimage,
                              num_classes=num_classes,
                              cfg=cfg,
                              test=True)
    vpred.persistent = True

    # Create Solver.
    if cfg.optimizer == "adam":
        solver = S.Adam(alpha=learning_rates[0])
    else:
        solver = S.Momentum(learning_rates[0], momentum)

    solver.set_parameters(nn.get_parameters())

    def validation_error():
        """Return the mean categorical error over one full pass of ``vdata``."""
        errors = []
        start = True
        while vdata.current != 0 or start:
            start = False
            vimage.d, vlabel.d = vdata.next()
            vpred.forward()
            errors.append(categorical_error(vpred.d, vlabel.d))
        return np.array(errors).mean()

    def constraints_hold():
        """Return True if every enabled size target (> 0) is currently met."""
        return ((cfg.target_weight_kbytes <= 0
                 or float(kbytes_weights.d) <= cfg.target_weight_kbytes)
                and (cfg.target_activation_kbytes <= 0
                     or float(kbytes_activations.d) <= cfg.target_activation_kbytes))

    # Training loop (epochs)
    logger.info("Start Training ...")
    i = 0
    best_v_err = 1.0

    # logs of the results
    iters = []
    res_train_err = []
    res_train_loss = []
    res_val_err = []

    # print all variables that exist
    for k in nn.get_parameters():
        print(k)

    # Per-epoch history of the quantizer parameters. A parameter name is
    # '/'-separated: its last component is the parameter kind and its
    # third-to-last the quantizer scope. Each (kind, scope) pair maps
    # parameter name -> list of values (one per epoch). The insertion order
    # fixes the order in which the CSV files are written.
    quant_kinds = ('n', 'd', 'xmin', 'xmax')
    quant_scopes = ('bquant', 'Wquant', 'Aquant')
    quant_history = collections.OrderedDict(
        ((kind, scope), collections.OrderedDict())
        for kind in quant_kinds for scope in quant_scopes)
    # CSV naming, e.g. ('xmin', 'Wquant') -> 'res_min_w.csv'.
    kind_tags = {'n': 'n', 'd': 'd', 'xmin': 'min', 'xmax': 'max'}
    scope_tags = {'bquant': 'b', 'Wquant': 'w', 'Aquant': 'a'}

    for k in nn.get_parameters():
        parts = k.split('/')
        if (parts[-1] in quant_kinds and len(parts) >= 3
                and parts[-3] in quant_scopes):
            quant_history[(parts[-1], parts[-3])][k] = []

    for epoch in range(len(learning_rates)):
        train_loss = list()
        train_loss1 = list()
        train_loss2 = list()
        train_loss3 = list()
        train_err = list()

        # check whether we need to adapt the learning rate
        if epoch > 0 and learning_rates[epoch - 1] != learning_rates[epoch]:
            solver.set_learning_rate(learning_rates[epoch])

        # Training loop (iterations)
        start_epoch = True
        while data.current != 0 or start_epoch:
            start_epoch = False
            # Next batch
            image.d, label.d = data.next()

            # Training forward/backward
            solver.zero_grad()

            loss.forward()
            loss.backward()

            if weight_decay is not None:
                solver.weight_decay(weight_decay)

            # scale gradients
            if cfg.target_weight_kbytes > 0 or cfg.target_activation_kbytes > 0:
                clip_quant_grads()

            solver.update()
            e = categorical_error(pred.d, label.d)
            # Variable.d references the variable's buffer rather than copying
            # it, and the next forward pass overwrites that buffer. Copy, so
            # the epoch means use every iteration's value and not only the
            # last one.
            train_loss += [loss.d.copy()]
            train_loss1 += [loss1.d.copy()]
            train_loss2 += [loss2.d.copy()]
            train_loss3 += [loss3.d.copy()]
            train_err += [e]

            # make sure that parametric values are clipped to correct values (if outside)
            clip_quant_vals()

            # Intermediate Validation (when constraint is set and fulfilled)
            kbytes_weights.forward()
            kbytes_activations.forward()
            # NOTE(review): this only runs when a weight-size target is set.
            # With an activation-only target no intermediate validation
            # happens. Confirm this is intended.
            if cfg.target_weight_kbytes > 0 and constraints_hold():
                v_err = validation_error()
                if v_err < best_v_err:
                    best_v_err = v_err
                    nn.save_parameters(
                        os.path.join(cfg.params_dir, 'params_best.h5'))
                    print(
                        f'Best validation error (fulfilling constraints: {best_v_err}'
                    )
                    sys.stdout.flush()
                    sys.stderr.flush()

            i += 1

        # Validation
        v_err = validation_error()
        kbytes_weights.forward()
        kbytes_activations.forward()
        if v_err < best_v_err and constraints_hold():
            best_v_err = v_err
            nn.save_parameters(os.path.join(cfg.params_dir, 'params_best.h5'))
            sys.stdout.flush()
            sys.stderr.flush()

        if cfg.target_weight_kbytes > 0:
            print(
                f"Current network size (weights) is {float(kbytes_weights.d):.3f}KB "
                f"(#params: {int(num_weights)}, "
                f"avg. bitwidth: {8. * 1024. * kbytes_weights.d / num_weights})"
            )
            sys.stdout.flush()
            sys.stderr.flush()
        if cfg.target_activation_kbytes > 0:
            print(
                f"Current network size (activations) is {float(kbytes_activations.d):.3f}KB"
            )
            sys.stdout.flush()
            sys.stderr.flush()

        # Print every quantizer parameter (value and gradient) and record its
        # value. Fetch the parameter dict once instead of per access.
        params = nn.get_parameters()
        for k, param in params.items():
            parts = k.split('/')
            if parts[-1] not in quant_kinds:
                continue
            print(f'{k}', f'{param.d}', f'{param.g}')
            sys.stdout.flush()
            sys.stderr.flush()
            if len(parts) >= 3 and parts[-3] in quant_scopes:
                # ndarray.item() replaces np.asscalar, which NumPy removed
                # in 1.23.
                quant_history[(parts[-1], parts[-3])][k].append(
                    param.d.item())

        # Print
        logger.info(f'epoch={epoch}(iter={i}); '
                    f'overall cost={np.array(train_loss).mean()}; '
                    f'cross-entropy cost={np.array(train_loss1).mean()}; '
                    f'weight-size cost={np.array(train_loss2).mean()}; '
                    f'activations-size cost={np.array(train_loss3).mean()}; '
                    f'TrainErr={np.array(train_err).mean()}; '
                    f'ValidErr={v_err}; BestValidErr={best_v_err}')
        sys.stdout.flush()
        sys.stderr.flush()

        # update the logs
        iters.append(i)
        res_train_err.append(np.array(train_err).mean())
        res_train_loss.append([
            np.array(train_loss).mean(),
            np.array(train_loss1).mean(),
            np.array(train_loss2).mean(),
            np.array(train_loss3).mean()
        ])
        res_val_err.append(np.array(v_err).mean())
        res_ges = np.concatenate([
            np.array(iters)[:, np.newaxis],
            np.array(res_train_err)[:, np.newaxis],
            np.array(res_val_err)[:, np.newaxis],
            np.array(res_train_loss)
        ],
                                 axis=-1)

        # save the results
        np.savetxt(cfg.params_dir + '/results.csv',
                   np.array(res_ges),
                   fmt='%10.8f',
                   header='iter,train_err,val_err,loss,loss1,loss2,loss3',
                   comments='',
                   delimiter=',')

        # One row per parameter, one column per epoch, written as
        # (row, column, value) triples.
        for (kind, scope), res in quant_history.items():
            rs = 'res_{}_{}.csv'.format(kind_tags[kind], scope_tags[scope])
            res_mat = np.array(list(res.values()))
            if res_mat.shape[0] > 1 and res_mat.shape[1] > 1:
                np.savetxt(
                    cfg.params_dir + '/' + rs,
                    np.array([[r, c, res_mat[r, c]] for r, c in product(
                        range(res_mat.shape[0]), range(res_mat.shape[1]))]),
                    fmt='%10.8f',
                    comments='',
                    delimiter=',')
# Example no. 15
0
def train():
    """Train the X-UMX (``OpenUnmix_CrossNet``) separation model.

    Reads the configuration from ``get_train_args()``, builds the training
    and validation graphs and runs the epoch loop, optionally data-parallel
    across processes via ``CommunicatorWrapper``.  Every epoch consists of:

    * a training pass of ``max_iter`` iterations (Adam + weight decay, with
      gradients all-reduced when more than one process is used);
    * a validation pass in which each validation sample is evaluated in
      consecutive windows of ``args.valid_dur`` seconds;
    * a ReduceLROnPlateau learning-rate update and an early-stopping check.

    On rank 0 the best-epoch / loss / learning-rate / timing series are
    written to monitors under ``args.output``. The parameters are saved to
    ``best_xumx.h5`` whenever the validation loss equals the early-stopping
    best.

    Raises:
        ValueError: If the installed nnabla is older than v1.19.0.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        # Each process sees a disjoint slice of the data.
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    # Weight decay is scaled by the number of processes.
    weight_decay = args.weight_decay * comm.n_procs

    if comm.rank == 0:
        print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    # Mixed loss: args.mcoef * time-domain (SDR) loss + frequency-domain
    # (MSE) loss.
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Length (in samples) of one validation window; loop-invariant.
    dur = int(valid_source.sample_rate * args.valid_dur)

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        # Each validation sample is evaluated in consecutive windows of `dur`
        # samples; its loss is the mean over the evaluated windows.
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            # NOTE(review): assumes every validation sample has at least `dur`
            # samples; a shorter one would not fit the fixed-size variables.
            while True:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                # Stop once fewer than `dur` samples remain; a trailing
                # partial window is not evaluated.
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            # Update best_epoch *before* logging it, so the monitor reports
            # the current epoch when it achieves a new best validation loss.
            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

        if stop:
            print("Apply Early Stopping")
            break
Esempio n. 16
0
def generate(args):
    """Translate images between domains A and B with trained MUNIT networks.

    Loads parameters from ``args.model_load_path``, builds the content/style
    encoders and the two decoders on a cuDNN context, and writes one image
    per source sample to the monitor at ``args.monitor_path``. Each image is
    the source followed by ``args.num_repeats`` translations, concatenated
    along axis 3 (width for the ``(b, c, h, w)`` inputs).

    With ``args.example_guided`` unset, the target style is drawn from
    ``F.randn`` (stochastic). With it set, the style is encoded from the first
    sample drawn from the target-domain iterator (example-guided).

    Args:
        args: Parsed options. This function reads ``model_load_path``,
            ``type_config``, ``image_size``, ``maps``, ``example_guided``,
            ``monitor_path``, ``img_path_a``, ``img_path_b``, ``batch_size``
            and ``num_repeats``.
    """
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    # NOTE(review): the variables have batch size 1, while the data iterators
    # below are built with args.batch_size; this presumably requires
    # args.batch_size == 1 — confirm.
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(
        shape=x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(
        shape=x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_a = MonitorImage("Fake Image B to A {} Valid".format(suffix),
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B {} Valid".format(suffix),
                                   monitor,
                                   interval=1)

    # DataIterator
    di_a = munit_data_iterator(args.img_path_a, args.batch_size)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size)

    def _translate_all(di_src, x_src, x_fake, monitor_image):
        # Feed every source sample and log it next to num_repeats outputs.
        for i in range(di_src.size):
            # [0] takes the first element of the iterator's batch tuple
            # (presumably the images — verify against munit_data_iterator).
            x_src.d = di_src.next()[0]
            images = [x_src.d.copy()]
            for _ in range(args.num_repeats):
                x_fake.forward(clear_buffer=True)
                images.append(x_fake.d.copy())
            monitor_image.add(i, np.concatenate(images, axis=3))

    # Generate all
    # generate (A -> B)
    if args.example_guided:
        # Fix the B-domain style example before translating.
        x_real_b.d = di_b.next()[0]
    _translate_all(di_a, x_real_a, x_fake_b, monitor_image_b)

    # generate (B -> A)
    if args.example_guided:
        # Fix the A-domain style example before translating.
        x_real_a.d = di_a.next()[0]
    _translate_all(di_b, x_real_b, x_fake_a, monitor_image_a)
Esempio n. 17
0
def calc_latency_and_onnx(exp_nr, calc_latency=False, ext_name='cpu',
                          device_id=0, onnx=False):

    nn.clear_parameters()
    N = 10  # number of random networks to sample
    estim_net = None
    estim_accum_by_graph = None
    estim_accum_by_mod = None

    #  10 **************************
    if exp_nr == 10:
        from nnabla_nas.contrib import zoph

        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        OUTPUT_DIR = './logs/zoph/one_net/'

        # Sample one ZOPH network from the search space
        shape = (1, 3, 32, 32)
        input = nn.Variable(shape)
        zn = zoph.SearchNet()
        zn.apply(training=False)
        output = zn(input)

        estim_net, estim_accum_by_graph, estim_accum_by_mod = \
            init_calc_latency(output, ext_name=ext_name, device_id=device_id)

        # zn_unique_active_modules = get_active_and_profiled_modules(zn)

        # SAVE GRAPH in PDF
        zn.save_graph(OUTPUT_DIR + 'zn')
        # SAVE WHOLE NET. Calc whole net (real) latency using [Profiler]
        # Calculate also layer-based latency
        # The modules are discovered using the nnabla graph of the whole net
        # The latency is then calculated based on each individual module's
        # nnabla graph [LatencyGraphEstimator]
        net_lat, acc_lat = zn.save_net_nnp(
                            OUTPUT_DIR + 'zn', input, output,
                            calc_latency=calc_latency,
                            func_real_latency=estim_net,
                            func_accum_latency=estim_accum_by_graph
                            )

        # reset estim_accum_by_graph
        estim_net, estim_accum_by_graph, estim_accum_by_mod = \
            init_calc_latency(output, ext_name=ext_name, device_id=device_id)

        # SAVE ALL MODULES. Calc layer-based latency.
        # The modules are discovered by going over the module list.
        # The latency is then calculated based on each individual module's
        # nnabla graph [LatencyGraphEstimator]
        acc_lat_g = zn.save_modules_nnp(
                            OUTPUT_DIR + 'zn_by_graph',
                            active_only=True,
                            calc_latency=calc_latency,
                            func_latency=estim_accum_by_graph
                            )

        # SAVE ALL MODULES. Calc layer-based latency.
        # The modules are discovered by going over the module list.
        # The latency is then calculated using the module [LatencyEstimator]
        # **** This function is deprecated ****
        acc_lat_m = zn.save_modules_nnp_by_mod(
                            OUTPUT_DIR + 'zn_by_module',
                            active_only=True,
                            calc_latency=calc_latency,
                            func_latency=estim_accum_by_mod
                            )

        # reset estim_accum_by_graph, estim_accum_by_mod
        estim_net, estim_accum_by_graph, estim_accum_by_mod = \
            init_calc_latency(output, ext_name=ext_name, device_id=device_id)

        # Just obtain the latencies of all modules without saving files
        # By graph
        latencies_1, acc_lat_1 = zn.get_latency(estim_accum_by_graph)
        # By module (deprecated)
        latencies_2, acc_lat_2 = zn.get_latency_by_mod(estim_accum_by_mod)

        print(net_lat, acc_lat, acc_lat_g, acc_lat_m, acc_lat_1, acc_lat_2)

        # CONVERT ALL TO ONNX
        if onnx:
            zn.convert_npp_to_onnx(OUTPUT_DIR)

        # VERBOSITY - INFO OF NETWORK CONTENT
        # with open(OUTPUT_DIR + 'zn.txt', 'w') as f:
        #    print_me(zn, f)

    #  11 **************************
    if exp_nr == 11:
        from nnabla_nas.contrib import zoph
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        OUTPUT_DIR = './logs/zoph/many_different_nets/'

        shape = (1, 3, 32, 32)
        input = nn.Variable(shape)

        # Sample N zoph networks from the search space
        for i in range(0, N):
            nn.clear_parameters()
            zn = zoph.SearchNet()
            zn.apply(training=False)
            output = zn(input)
            if calc_latency:
                estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                    init_calc_latency(output,
                                      ext_name=ext_name,
                                      device_id=device_id
                                      )

            zn.save_graph(OUTPUT_DIR + 'zn' + str(i))
            zn.save_net_nnp(OUTPUT_DIR + 'zn' + str(i), input, output,
                            calc_latency=calc_latency,
                            func_real_latency=estim_net,
                            func_accum_latency=estim_accum_by_graph
                            )
            zn.save_modules_nnp(OUTPUT_DIR + 'zn' + str(i), active_only=True,
                                calc_latency=calc_latency,
                                func_latency=estim_accum_by_graph
                                )
        if onnx:
            zn = zoph.SearchNet()
            zn.convert_npp_to_onnx(OUTPUT_DIR, opset='opset_snpe')

    #  12 **************************
    if exp_nr == 12:
        from nnabla_nas.contrib import zoph
        import time
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        OUTPUT_DIR = './logs/zoph/same_net_many_times/'

        shape = (1, 3, 32, 32)
        input = nn.Variable(shape)
        zn = zoph.SearchNet()
        zn.apply(training=False)
        output = zn(input)

        # Measure add-hoc latency of zoph network
        for i in range(0, N):
            n_run = 100
            # warm up
            output.forward()

            result = 0.0
            for i in range(n_run):
                start = time.time()
                output.forward()
                stop = time.time()
                result += stop - start

            mean_time = result / n_run
            print(mean_time*1000)

        # Measure latency on same zoph network N times
        for i in range(0, N):
            if calc_latency:
                estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                    init_calc_latency(output,
                                      ext_name=ext_name,
                                      device_id=device_id
                                      )
            zn.save_net_nnp(OUTPUT_DIR + 'zn' + str(i), input, output,
                            calc_latency=calc_latency,
                            func_real_latency=estim_net,
                            func_accum_latency=estim_accum_by_graph
                            )
            zn.save_modules_nnp(OUTPUT_DIR + 'zn' + str(i), active_only=True,
                                calc_latency=calc_latency,
                                func_latency=estim_accum_by_graph
                                )
        zn.save_graph(OUTPUT_DIR)
        if onnx:
            zn.convert_npp_to_onnx(OUTPUT_DIR)

    #  20 **************************
    if exp_nr == 20:
        from nnabla_nas.contrib import random_wired
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        OUTPUT_DIR = './logs/rdn/one_net/'

        # Sample one random wired network from the search space
        shape = (1, 3, 32, 32)
        input = nn.Variable(shape)
        rw = random_wired.TrainNet()
        rw.apply(training=False)
        output = rw(input)

        if calc_latency:
            estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                init_calc_latency(output,
                                  ext_name=ext_name,
                                  device_id=device_id
                                  )

        rw.save_graph(OUTPUT_DIR + 'rw')
        rw.save_net_nnp(OUTPUT_DIR + 'rw', input, output,
                        calc_latency=calc_latency,
                        func_real_latency=estim_net,
                        func_accum_latency=estim_accum_by_graph
                        )
        rw.save_modules_nnp(OUTPUT_DIR + 'rw', active_only=True,
                            calc_latency=calc_latency,
                            func_latency=estim_accum_by_graph
                            )

        if onnx:
            rw.convert_npp_to_onnx(OUTPUT_DIR)

    #  21 **************************
    if exp_nr == 21:
        from nnabla_nas.contrib import random_wired
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        OUTPUT_DIR = './logs/rdn/many_different_nets/'

        shape = (1, 3, 32, 32)
        input = nn.Variable(shape)

        # Measure latency on same rdn network N times
        for i in range(0, N):
            nn.clear_parameters()
            rw = random_wired.TrainNet()
            rw.apply(training=False)
            output = rw(input)

            if calc_latency:
                estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                    init_calc_latency(output,
                                      ext_name=ext_name,
                                      device_id=device_id
                                      )
            rw.save_graph(OUTPUT_DIR + 'rw' + str(i))
            rw.save_net_nnp(OUTPUT_DIR + 'rw' + str(i), input, output,
                            calc_latency=calc_latency,
                            func_real_latency=estim_net,
                            func_accum_latency=estim_accum_by_graph
                            )
            rw.save_modules_nnp(OUTPUT_DIR + 'rw' + str(i), active_only=True,
                                calc_latency=calc_latency,
                                func_latency=estim_accum_by_graph
                                )

        if onnx:
            rw = random_wired.TrainNet()
            rw.convert_npp_to_onnx(OUTPUT_DIR, opset='opset_snpe')

    #  22 **************************
    if exp_nr == 22:
        from nnabla_nas.contrib import random_wired
        import time

        OUTPUT_DIR = './logs/rdn/same_net_many_times/'
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        shape = (1, 3, 32, 32)
        input = nn.Variable(shape)
        rw = random_wired.TrainNet()
        rw.apply(training=False)
        output = rw(input)

        for i in range(0, N):
            n_run = 10
            # warm up
            output.forward()

            result = 0.0
            for i in range(n_run):
                start = time.time()
                output.forward()
                stop = time.time()
                result += stop - start
            mean_time = result / n_run
            print(mean_time*1000)

        # Measure latency on same rdn network N times
        for i in range(0, N):
            if calc_latency:
                estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                    init_calc_latency(output,
                                      ext_name=ext_name,
                                      device_id=device_id
                                      )
            rw.save_net_nnp(OUTPUT_DIR + 'rw' + str(i), input, output,
                            calc_latency=calc_latency,
                            func_real_latency=estim_net,
                            func_accum_latency=estim_accum_by_graph
                            )
            rw.save_modules_nnp(OUTPUT_DIR + 'rw' + str(i), active_only=True,
                                calc_latency=calc_latency,
                                func_latency=estim_accum_by_graph
                                )
        rw.save_graph(OUTPUT_DIR + 'rw' + str(i))

        if onnx:
            rw.convert_npp_to_onnx(OUTPUT_DIR)

    #  31 **************************
    if exp_nr == 31:
        from nnabla_nas.contrib.classification.mobilenet import SearchNet
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        OUTPUT_DIR = './logs/mobilenet/many_different_nets/'

        input = nn.Variable((1, 3, 224, 224))

        # number of random networks to sample
        # Sample N networks from the search space
        for i in range(0, N):
            nn.clear_parameters()
            mobile_net = SearchNet(num_classes=1000)
            mobile_net.apply(training=False)
            output = mobile_net(input)

            if calc_latency:
                estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                    init_calc_latency(output,
                                      ext_name=ext_name,
                                      device_id=device_id
                                      )
            # This calculates the actual latency of the network
            # and the accum. latency by adding the latencies of
            # each module, following the nnabla graph
            mobile_net.save_net_nnp(OUTPUT_DIR + 'mn' + str(i), input, output,
                                    calc_latency=calc_latency,
                                    func_real_latency=estim_net,
                                    func_accum_latency=estim_accum_by_graph
                                    )

            """
            # This calculates the latency of each module going
            # over the nnabla graph (deprecated)
            mobile_net.calc_latency_all_modules(
                    OUTPUT_DIR + 'graph_mn' + str(i), output,
                    func_latency=estim_accum_by_graph)
            """
            # This calculates the latency of each module going
            # the tree of the modules
            mobile_net.save_modules_nnp(
                            OUTPUT_DIR + 'mn' + str(i),  active_only=True,
                            calc_latency=calc_latency,
                            func_latency=estim_accum_by_graph
                            )
        if onnx:
            mobile_net = SearchNet()
            mobile_net.convert_npp_to_onnx(OUTPUT_DIR, opset='opset_snpe')

    #  32 **************************
    if exp_nr == 32:
        from nnabla_nas.contrib.classification.mobilenet import SearchNet
        import time

        OUTPUT_DIR = './logs/mobilenet/same_net_many_times/'
        ctx = get_extension_context(ext_name=ext_name, device_id=device_id)
        nn.set_default_context(ctx)

        input = nn.Variable((1, 3, 224, 224))
        mobile_net = SearchNet(num_classes=1000)
        mobile_net.apply(training=False)
        output = mobile_net(input)

        for i in range(0, N):
            n_run = 100
            # warm up
            output.forward()

            result = 0.0
            for i in range(n_run):
                start = time.time()
                output.forward()
                stop = time.time()
                result += stop - start

            mean_time = result / n_run
            print(mean_time*1000)

        # Measure latency on same network N times
        for i in range(0, N):
            if calc_latency:
                estim_net, estim_accum_by_graph, estim_accum_by_mod = \
                    init_calc_latency(output,
                                      ext_name=ext_name,
                                      device_id=device_id
                                      )

            mobile_net.save_net_nnp(OUTPUT_DIR + 'mn' + str(i), input, output,
                                    calc_latency=calc_latency,
                                    func_real_latency=estim_net,
                                    func_accum_latency=estim_accum_by_graph
                                    )

            mobile_net.save_modules_nnp(
                            OUTPUT_DIR + 'mn' + str(i), active_only=True,
                            calc_latency=calc_latency,
                            func_latency=estim_accum_by_graph
                            )
        if onnx:
            mobile_net.convert_npp_to_onnx(OUTPUT_DIR)

    #  4 **************************
    if exp_nr == 4:
        from nnabla_nas.module import static as smo

        input1 = nn.Variable((1, 256, 32, 32))
        input2 = nn.Variable((1, 384, 32, 32))
        input3 = nn.Variable((1, 128, 32, 32))
        input4 = nn.Variable((1, 768, 32, 32))
        input5 = nn.Variable((1, 1280, 32, 32))
        input6 = nn.Variable((1, 2048, 32, 32))
        input7 = nn.Variable((1, 512, 32, 32))
        input8 = nn.Variable((1, 192, 32, 32))
        input9 = nn.Variable((1, 224, 32, 32))

        static_input1 = smo.Input(value=input1)
        static_input2 = smo.Input(value=input2)
        static_input3 = smo.Input(value=input3)
        static_input4 = smo.Input(value=input4)
        static_input5 = smo.Input(value=input5)
        static_input6 = smo.Input(value=input6)
        static_input7 = smo.Input(value=input7)
        static_input8 = smo.Input(value=input8)
        static_input9 = smo.Input(value=input9)

        myconv1 = smo.Conv(parents=[static_input1], in_channels=256,
                           out_channels=128, kernel=(1, 1), pad=None, group=1)
        myconv2 = smo.Conv(parents=[static_input2], in_channels=384,
                           out_channels=128, kernel=(1, 1), pad=None, group=1)
        myconv3 = smo.Conv(parents=[static_input3], in_channels=128,
                           out_channels=256, kernel=(1, 1), pad=None, group=1)
        myconv4 = smo.Conv(parents=[static_input4], in_channels=768,
                           out_channels=256, kernel=(1, 1))
        myconv5 = smo.Conv(parents=[static_input5], in_channels=1280,
                           out_channels=256, kernel=(1, 1), pad=None, group=1)
        myconv6 = smo.Conv(parents=[static_input6], in_channels=2048,
                           out_channels=256, kernel=(1, 1), pad=None, group=1)
        myconv7 = smo.Conv(parents=[static_input7], in_channels=512,
                           out_channels=512, kernel=(3, 3), pad=(1, 1), group=1
                           )
        myconv8 = smo.Conv(parents=[static_input8], in_channels=192,
                           out_channels=512, kernel=(7, 7), pad=(3, 3), group=1
                           )
        myconv9 = smo.Conv(parents=[static_input9], in_channels=224,
                           out_channels=128, kernel=(5, 5), pad=(2, 2), group=1
                           )

        output1 = myconv1()
        output2 = myconv2()
        output3 = myconv3()
        output4 = myconv4()
        output5 = myconv5()
        output6 = myconv6()
        output7 = myconv7()
        output8 = myconv8()
        output9 = myconv9()

        N = 10
        for i in range(0, N):
            mean_time = estim_fwd(output1)
            print("1, ", mean_time)
            mean_time = estim_fwd(output2)
            print("2, ", mean_time)
            mean_time = estim_fwd(output3)
            print("3, ", mean_time)
            mean_time = estim_fwd(output4)
            print("4, ", mean_time)
            mean_time = estim_fwd(output5)
            print("5, ", mean_time)
            mean_time = estim_fwd(output6)
            print("6, ", mean_time)
            mean_time = estim_fwd(output7)
            print("7, ", mean_time)
            mean_time = estim_fwd(output8)
            print("8, ", mean_time)
            mean_time = estim_fwd(output9)
            print("9, ", mean_time)

        N = 100
        from nnabla_nas.utils.estimator.latency import LatencyGraphEstimator
        for i in range(0, N):
            estimation = LatencyGraphEstimator(n_run=100, ext_name='cpu')
            latency = estimation.get_estimation(myconv1)
            latency = estimation.get_estimation(myconv2)
            latency = estimation.get_estimation(myconv3)
            latency = estimation.get_estimation(myconv4)
            latency = estimation.get_estimation(myconv5)
            latency = estimation.get_estimation(myconv6)
            latency = estimation.get_estimation(myconv7)
            latency = estimation.get_estimation(myconv8)
            latency = estimation.get_estimation(myconv9)

            estimation = LatencyGraphEstimator(n_run=100, ext_name='cpu')
            latency = estimation.get_estimation(myconv9)
            latency = estimation.get_estimation(myconv8)
            latency = estimation.get_estimation(myconv7)
            latency = estimation.get_estimation(myconv6)
            latency = estimation.get_estimation(myconv5)
            latency = estimation.get_estimation(myconv4)
            latency = estimation.get_estimation(myconv3)
            latency = estimation.get_estimation(myconv2)
            latency = estimation.get_estimation(myconv1)

            estimation = LatencyGraphEstimator(n_run=100, ext_name='cpu')
            latency = estimation.get_estimation(myconv6)
            latency = estimation.get_estimation(myconv9)
            latency = estimation.get_estimation(myconv1)
            latency = estimation.get_estimation(myconv4)
            latency = estimation.get_estimation(myconv8)
            latency = estimation.get_estimation(myconv3)
            latency = estimation.get_estimation(myconv5)
            latency = estimation.get_estimation(myconv7)
            latency = estimation.get_estimation(myconv2)

            latency += 0  # to avoid lint/flake8 error

    #  5 **************************
    if exp_nr == 5:
        from nnabla_nas.module import static as smo
        from nnabla_nas.utils.estimator.latency import LatencyGraphEstimator
        from numpy.random import permutation
        import numpy as np

        run_also_ours_at_the_end = True

        N_conv = 50  # number of different convolutions tried
        in_sizes = np.random.randint(low=1, high=1000, size=N_conv)
        out_sizes = np.random.randint(low=1, high=600, size=N_conv)
        kernel_sizes = np.random.randint(low=1, high=7, size=N_conv)
        feat_sizes = np.random.randint(low=16, high=48, size=N_conv)

        N = 100
        for j in range(N):
            estimation = LatencyGraphEstimator(n_run=100, ext_name='cpu')
            print('****************** RUN ********************')
            for i in permutation(N_conv):
                input = nn.Variable((1, in_sizes[i],
                                     feat_sizes[i], feat_sizes[i]))
                static_input = smo.Input(value=input)
                myconv = smo.Conv(parents=[static_input],
                                  in_channels=in_sizes[i],
                                  out_channels=out_sizes[i],
                                  kernel=(kernel_sizes[i], kernel_sizes[i]),
                                  pad=None, group=1
                                  )
                output = myconv()
                latency = estimation.get_estimation(myconv)

        latency += 0  # to avoid lint/flake8 error

        if run_also_ours_at_the_end is True:
            print('*********** NOW IT IS OUR TURN ***********')
            for i in range(N_conv):
                input = nn.Variable((1, in_sizes[i],
                                    feat_sizes[i], feat_sizes[i]))
                static_input = smo.Input(value=input)
                myconv = smo.Conv(parents=[static_input],
                                  in_channels=in_sizes[i],
                                  out_channels=out_sizes[i],
                                  kernel=(kernel_sizes[i], kernel_sizes[i]),
                                  pad=None, group=1
                                  )
                output = myconv()
                mean_time = estim_fwd(output, n_run=100) * 1000  # in ms
                print('Our_Conv : 100 :', mean_time, ':',
                      '[(1, ' + str(in_sizes[i]) + ', ' + str(feat_sizes[i]) +
                      ', ' + str(feat_sizes[i]) + ')]',
                      ':', out_sizes[i], ':', kernel_sizes[i]
                      )

    #  6 **************************
    if exp_nr == 6:
        import onnx
        load_onnx = False

        if len(sys.argv) > 2:
            INPUT_DIR = sys.argv[2]
        else:
            INPUT_DIR = './logs/zoph/one_net/'

        if len(sys.argv) > 3:
            load_onnx = True

        existing_networks = glob.glob(INPUT_DIR + '/*' + os.path.sep)
        all_nets_latencies = dict.fromkeys([])
        all_nets = dict.fromkeys([])
        net_idx = 0
        for network in existing_networks:
            all_blocks = glob.glob(network + '**/*.acclat', recursive=True)
            blocks = dict.fromkeys([])
            block_idx = 0

            this_net_accumulated_latency = 0.0
            this_net_accumulated_latency_of_convs = 0.0
            this_net_accumulated_latency_of_relus = 0.0
            this_net_accumulated_latency_of_bns = 0.0
            this_net_accumulated_latency_of_merges = 0.0
            this_net_accumulated_latency_of_pools = 0.0
            this_net_accumulated_latency_of_reshapes = 0.0
            this_net_accumulated_latency_of_affines = 0.0
            this_net_accumulated_latency_of_add2s = 0.0

            for block_lat in all_blocks:
                block = block_lat[:-7] + '.onnx'
                print('.... READING .... -->  ' + block)

                # Reading latency for each of the blocks of layers
                with open(block_lat, 'r') as f:
                    block_latency = float(f.read())

                this_net_accumulated_latency += block_latency

                # Layer-type-wise latencies tested
                # for Zoph
                # for Random Wired networks
                # for mobilenet

                layer_name = block.split('/')[-1].split('.')[-2]
                if layer_name.find('bn') != -1:
                    this_net_accumulated_latency_of_bns += block_latency
                elif layer_name.find('batchnorm') != -1:
                    this_net_accumulated_latency_of_bns += block_latency
                elif layer_name.find('relu') != -1:
                    this_net_accumulated_latency_of_relus += block_latency
                elif layer_name.find('conv') != -1:
                    this_net_accumulated_latency_of_convs += block_latency
                elif layer_name.find('merg') != -1:
                    this_net_accumulated_latency_of_merges += block_latency
                elif layer_name.find('pool') != -1:
                    this_net_accumulated_latency_of_pools += block_latency
                elif layer_name.find('con') != -1:  # from concat
                    this_net_accumulated_latency_of_merges += block_latency
                elif layer_name.find('reshape') != -1:
                    this_net_accumulated_latency_of_reshapes += block_latency
                elif layer_name.find('linear') != -1:
                    this_net_accumulated_latency_of_affines += block_latency
                elif layer_name.find('add2') != -1:
                    this_net_accumulated_latency_of_add2s += block_latency

                this_block = dict.fromkeys([])
                this_block['latency'] = block_latency

                if load_onnx:
                    # Interesting FIELDS in params.graph:
                    # 'input', 'name', 'node', 'output'
                    params = onnx.load(block)
                    this_block['name'] = params.graph.name
                    this_block['input'] = params.graph.input
                    this_block['output'] = params.graph.output
                    this_block['nodes'] = params.graph.node

                blocks[block_idx] = this_block
                block_idx += 1

            net_realat_file = network[:-1] + '.realat'
            with open(net_realat_file, 'r') as f:
                this_net_real_latency = float(f.read())

            net_acclat_file = network[:-1] + '.acclat'
            with open(net_acclat_file, 'r') as f:
                this_net_acc_latency = float(f.read())

            this_net = dict.fromkeys([])
            this_net['real_latency'] = this_net_real_latency
            this_net['accum_latency_graph'] = this_net_acc_latency
            this_net['accum_latency_module'] = this_net_accumulated_latency

            if load_onnx:
                net_file = network[:-1] + '.onnx'
                print('xxxx READING xxxx -->  ' + net_file)
                params = onnx.load(net_file)
                this_net['name'] = params.graph.name
                this_net['input'] = params.graph.input
                this_net['output'] = params.graph.output
                this_net['nodes'] = params.graph.node

            all_nets_latencies[net_idx] = [
                this_net_real_latency,
                this_net_acc_latency,
                this_net_accumulated_latency,
                this_net_accumulated_latency_of_convs,
                this_net_accumulated_latency_of_bns,
                this_net_accumulated_latency_of_relus,
                this_net_accumulated_latency_of_pools,
                this_net_accumulated_latency_of_merges,
                this_net_accumulated_latency_of_reshapes,
                this_net_accumulated_latency_of_affines,
                this_net_accumulated_latency_of_add2s,
                ]

            all_nets[net_idx] = this_net

            net_idx += 1

        # Compare accumulated latency to net latencies, do a plot:
        print('LATENCY Results from ' + INPUT_DIR)
        print('NETWORK, LAYER-WISE (by graph), ',
              'LAYER-WISE (by module), of CONVs, of BNs, of RELUs, ',
              'of POOLs, of MERGEs/CONCATs, of RESHAPEs, of AFFINEs, ',
              'of ADD2 layers'
              )
        for i in range(len(all_nets_latencies)):
            # print(all_nets_latencies[i])
            print(['%7.3f' % val for val in all_nets_latencies[i]])
Esempio n. 18
0
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify contexts for computation.
    * Select the network and the matching CIFAR data iterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Initialize DataIterator.
    * Training loop
      * Compute error rate for validation data (once per epoch)
      * Save parameters every ``args.model_save_interval`` iterations
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    * Save the final parameters.

    Raises:
        ValueError: if ``args.net`` is neither ``"cifar10_resnet23"`` nor
            ``"cifar100_resnet23"``.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Select the network definition and its dataset. The branches are
    # mutually exclusive; an unknown name is rejected here instead of
    # surfacing later as a NameError on `prediction`.
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    elif args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100
    else:
        raise ValueError("Unknown network: {}".format(args.net))

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator (validation uses the validation graph's batch size)
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(bs_valid, False)

    # Validate once per epoch. Clamp to 1 so that a batch size larger than
    # the training set cannot produce a modulo-by-zero below.
    val_interval = max(1, n_train_samples // args.batch_size)

    # Training-loop
    for i in range(args.max_iter):
        # Validation
        if i % val_interval == 0:
            ve = 0.
            for _ in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i, ve)
        if int(i % args.model_save_interval) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Solvers update
        solver.update()

        e = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % (args.max_iter)))
Esempio n. 19
0
def half_test(rng,
              func,
              finputs,
              hinputs,
              func_args,
              func_kwargs,
              backward,
              ctx,
              func_name,
              atol=1e-1):
    """Check that ``func`` gives close results in float and half precision.

    ``func`` is built twice. The first build is under ``ctx`` with the
    float inputs ``finputs``. The second is under a ``type_config='half'``
    context of the same extension and device, with ``hinputs``. Forward
    outputs of the two builds are compared with ``assert_allclose``. When
    any entry of ``backward`` is True, both builds receive the same random
    output gradients, and the gradients of the flagged inputs are compared
    as well.

    Args:
        rng: random state handed to ``randn`` to draw output gradients.
        func: function under test, called as
            ``func(*(inputs + func_args), **func_kwargs)``.
        finputs: input Variables (entries may be None) for the float run.
        hinputs: input Variables (entries may be None) for the half run.
        func_args: extra positional arguments for ``func``.
        func_kwargs: extra keyword arguments for ``func``.
        backward: per-input flags selecting which input gradients to compare.
        ctx: context for the float run; its first backend must be
            ``"<ext>:float"``.
        func_name: name used in the assertion messages.
        atol: absolute tolerance for both forward and backward comparisons.
    """
    has_backward = True in backward

    def _present(variables):
        # Inputs may be None placeholders; only real Variables are passed on.
        return [var for var in variables if var is not None]

    def _clear_grads(variables):
        for var in _present(variables):
            var.grad.zero()

    def _flagged_grads(variables):
        # Snapshot the gradients of the inputs whose `backward` flag is set.
        snapshot = []
        for var, flag in zip(variables, backward):
            if flag and var is not None:
                snapshot.append(var.g.copy())
        return snapshot

    def _seed_output_grads(outputs, grads=None):
        # Draw fresh random gradients unless given ones, so that the float
        # and half runs backpropagate identical values.
        if grads is None:
            grads = [randn(rng, *out.shape) for out in outputs]
        for out, grad in zip(outputs, grads):
            out.g = grad
        return grads

    def _execute(inputs, outputs):
        # Forward (and, if requested, backward) through the outputs' parent
        # function; return copies of outputs and of the flagged input grads.
        outputs[0].parent.forward(_present(inputs), outputs)
        values = [out.d.copy() for out in outputs]
        grads = None
        if has_backward:
            _clear_grads(inputs)
            outputs[0].parent.backward(_present(inputs), outputs)
            grads = _flagged_grads(inputs)
        return values, grads

    # Float run.
    with nn.context_scope(ctx):
        float_outputs = force_tuple(func(*(finputs + func_args), **func_kwargs))
    if has_backward:
        seeded_grads = _seed_output_grads(float_outputs)
    float_values, float_grads = _execute(finputs, float_outputs)

    # Half run on the same extension and device.
    ext, dtype = ctx.backend[0].split(':')
    assert dtype == 'float'
    ctx_h = ext_utils.get_extension_context(ext, type_config='half')
    ctx_h.device_id = ctx.device_id
    with nn.context_scope(ctx_h):
        half_outputs = force_tuple(func(*(hinputs + func_args), **func_kwargs))
    if has_backward:
        _seed_output_grads(half_outputs, seeded_grads)
    half_values, half_grads = _execute(hinputs, half_outputs)

    # Compare float vs half results.
    for expected, actual in zip(float_values, half_values):
        # TODO: set tol param
        assert_allclose(
            expected,
            actual,
            atol=atol,
            err_msg="{} half forward test fails.".format(func_name))
    if not has_backward:
        return
    for expected, actual in zip(float_grads, half_grads):
        # TODO: set tol param
        assert_allclose(
            expected,
            actual,
            atol=atol,
            err_msg="{} half backward test fails.".format(func_name))
Esempio n. 20
0
if __name__ == '__main__':

    parser = make_parser()
    config = read_yaml(os.path.join('configs', 'gender.yaml'))
    args = parser.parse_args()

    # Command-line values override the corresponding YAML settings.
    config.nnabla_context.device_id = args.device_id
    config.gender_faces.data_dir = args.data_root
    config.train.save_path = args.save_path
    config.train.batch_size = args.batch_size
    config.model.g_n_scales = args.g_n_scales
    config.model.d_n_scales = args.d_n_scales

    # Build the compute context, wrap it for communication, make it default.
    ctx = get_extension_context(config.nnabla_context.ext_name)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(ctx)

    # Every base dimension is multiplied by the generator's scale count.
    g_scales = config.model.g_n_scales
    image_shape = tuple(dim * g_scales
                        for dim in config.model.base_image_shape)

    # Pick the data-iterator factory (mixed faces vs. attribute-based).
    make_iterator = (get_data_iterator_mix if args.face_morph
                     else get_data_iterator_attribute)
    di = make_iterator(args.data_root, comm,
                       config.train.batch_size, image_shape)

    if args.load_path:
        nn.load_parameters(args.load_path)
Esempio n. 21
0
def train(args):
    """Train a GAN with a gradient penalty on CIFAR-10.

    Each iteration updates the discriminator (critic) ``args.n_critic``
    times, then updates the generator once. The discriminator loss is
    ``gan_loss(p_fake, p_real)`` plus ``args.lambda_`` times a penalty on
    how far the L2 norm of the critic's input gradient deviates from 1.
    That gradient is taken at random interpolates of real and fake images.
    Losses, elapsed time, image tiles (training samples and fixed-noise
    test samples) and parameter snapshots go under ``args.monitor_path``.

    Args:
        args: parsed command-line namespace. Fields read here: context,
            device_id, type_config, latent, maps, batch_size, image_size,
            lambda_, up, lrg, lrd, beta1, beta2, monitor_path, max_iter,
            n_critic, save_interval.
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty on random interpolates of real and fake images
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum(F.mean((norm - 1.0)**2.0) for norm in l2norms)
    loss_dis += lambda_ * gp
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Bound up front so the final snapshot below also works when
    # args.max_iter == 0 (the loop would otherwise never bind `i`).
    i = 0

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            # x_test is not part of either loss graph, so it must be
            # forwarded explicitly before its tile is written; otherwise
            # the tile would show data that was never computed.
            x_test.forward(clear_buffer=True)
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Esempio n. 22
0
# NOTE(review): `import_module` is not defined in this view; presumably a
# helper that loads the script at this path as a module -- TODO confirm.
I = import_module("nnabla-examples/mnist-collection/dcgan")
I.__file__  # bare expression; presumably shown as notebook cell output

# With this, the internals of mnist-collection/dcgan.py are now accessible.
# That example already sets its hyperparameters, so we follow them here.

# Print the script's entry-point section (from `if __name__` onward) to
# inspect the settings it uses.
source = inspect.getsource(I)
print(source[source.index("if __name__"):])

# Hyperparameters chosen to match the example script.
max_iter = 20000
learning_rate = 0.0002
batch_size = 64
weight_decay = 0.0001

# Set the context: cuDNN backend, device 0, float precision.
context = get_extension_context("cudnn", device_id=0, type_config="float")
nn.set_default_context(context)
nn.get_current_context()  # bare expression; presumably for notebook display

# Set up the fake path: generator output scored by the discriminator.
z = nn.Variable([batch_size, 100, 1, 1])  # generator input variable
fake = I.generator(z)
fake.persistent = True  # Not to clear at backward
pred_fake = I.discriminator(fake)
# Generator loss: sigmoid cross-entropy of D(fake) against target 1.
loss_gen = F.mean(
    F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape)))
# Per nnabla docs, the unlinked variable shares `fake`'s buffer but has no
# parent, so the discriminator loss below does not reach the generator.
fake_dis = fake.get_unlinked_variable(need_grad=True)
fake_dis.need_grad = True  # TODO: Workaround until v1.0.2
pred_fake_dis = I.discriminator(fake_dis)
# Discriminator loss on fake samples: sigmoid cross-entropy against target 0.
loss_dis = F.mean(
    F.sigmoid_cross_entropy(pred_fake_dis, F.constant(0, pred_fake_dis.shape)))
Esempio n. 23
0
def main():
    """Segment a single image with a DeepLab v3+ model and show the result.

    Steps:

    * Read the label file and the trained parameters given on the command
      line.
    * Preprocess ``args.test_image_file`` to the network input size.
    * Run one forward pass and print the elapsed time.
    * Print the label of every class present in the post-processed
      prediction, then visualize it.
    """
    args = get_args()

    # Get context
    from nnabla.ext_utils import get_extension_context, import_extension_module
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = import_extension_module(args.context)

    # Read the label file; line k is printed for class index k. Lines keep
    # their trailing newline. The context manager guarantees the file is
    # closed.
    with open(args.label_file_path, "r") as f:
        labels_dict = f.readlines()

    # Load parameters
    nn.load_parameters(args.model_load_path)

    # Build a Deeplab v3+ network
    x = nn.Variable((1, 3, args.image_height, args.image_width),
                    need_grad=False)
    y = net.deeplabv3plus_model(x,
                                args.output_stride,
                                args.num_class,
                                test=True)

    # preprocess image (pilmode="RGB" yields an H x W x 3 array)
    image = imageio.imread(args.test_image_file, as_gray=False, pilmode="RGB")
    orig_h, orig_w = image.shape[:2]
    old_size = (orig_h, orig_w)

    input_array = image_preprocess.preprocess_image_and_label(
        image,
        label=None,
        target_width=args.image_width,
        target_height=args.image_height,
        train=False)
    print('Input', input_array.shape)
    # HWC -> CHW, then add a leading batch axis of size 1.
    input_array = np.transpose(input_array, (2, 0, 1))
    input_array = np.reshape(input_array, (1,) + input_array.shape)

    # Compute inference and inference time
    t = time.time()

    x.d = input_array
    y.forward(clear_buffer=True)
    print("done")
    # Wait for the device so the measured time covers the whole forward pass.
    available_devices = ext.get_devices()
    ext.device_synchronize(available_devices[0])
    ext.clear_memory_cache()

    elapsed = time.time() - t
    print('Inference time : %s seconds' % (elapsed))

    output = np.argmax(y.d, axis=1)  # (batch,h,w)

    # Apply post processing
    post_processed = post_process(output[0], old_size,
                                  (args.image_height, args.image_width))

    # Get the classes predicted
    predicted_classes = np.unique(post_processed)
    for class_idx in predicted_classes:
        print('Classes Segmented: ', labels_dict[class_idx])

    # Visualize inference result
    visualize(post_processed)
Esempio n. 24
0
def valid():
    """
    Main script for validation.

    Evaluates the model loaded from ``args.model_load_path`` on
    ``n_valid_samples`` images. The images come from a DALI validation
    pipeline running on one cuDNN device. The top-n error of every batch
    is logged; the mean error and elapsed time go to the monitor.
    """

    args = get_args()
    n_valid_samples = 50000
    num_classes = 1000
    assert n_valid_samples % args.batch_size == 0, (
        "Set batch_size such that n_valid_samples (50000) can be divided "
        "by batch_size. Batch size is now set as {}".format(args.batch_size))

    # Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Pipelines and Iterators for validation
    device_id = int(args.device_id)
    val_pipes = [
        ValPipeline(args.batch_size,
                    args.num_threads,
                    device_id,
                    args.val_cachefile_dir,
                    args.val_list,
                    seed=device_id,
                    num_gpu=1)
    ]
    val_pipes[0].build()
    vdata = DALIClassificationIterator(val_pipes,
                                       val_pipes[0].epoch_size("Reader"),
                                       auto_reset=True,
                                       stop_at_epoch=False)

    # Network for validation
    nn.load_parameters(args.model_load_path)
    v_model = get_model(args, num_classes, 1, args.accum_grad, test=True)
    v_e = F.mean(F.top_n_error(v_model.pred, v_model.label, n=args.top_n))

    # Monitors
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Validation
    ve_local = 0.
    val_iter_local = n_valid_samples // args.batch_size
    for i in range(val_iter_local):
        next_image, next_label = vdata.next()
        v_model.image.data.copy_from(next_image)
        v_model.label.data.copy_from(next_label)
        # `np.float` was only an alias of the builtin `float` and no longer
        # exists in recent NumPy; `float` selects the same dtype.
        v_model.image.data.cast(float, ctx)
        v_model.label.data.cast(np.int32, ctx)
        v_e.forward(clear_buffer=True)
        nn.logger.info("validation error is {} at {}-th batch".format(
            v_e.d, i))
        ve_local += v_e.d.copy()
    ve_local /= val_iter_local

    monitor_verr.add(0, ve_local)
    monitor_vtime.add(0)
Esempio n. 25
0
import pytest
import numpy as np
import model
import test
import nnabla.functions as F
from nnabla.ext_utils import get_extension_context
import nnabla as nn

# Tests in this module run on the plain CPU backend.
ctx = get_extension_context(ext_name='cpu')
nn.set_default_context(ctx)


@pytest.fixture(params=[4096, 4096 * 10])
def nb_timesteps(request):
    """Parametrized number of time steps, converted to ``int``."""
    timesteps = request.param
    return int(timesteps)


@pytest.fixture(params=[1, 2, 3])
def nb_channels(request):
    """Parametrized channel count, passed through unchanged."""
    channels = request.param
    return channels


@pytest.fixture(params=[1])
def nb_samples(request):
    """Parametrized sample count, passed through unchanged."""
    samples = request.param
    return samples


@pytest.fixture(params=[1024, 2048, 4096])
def nfft(request):
    """Parametrized FFT size, converted to ``int``."""
    fft_size = request.param
    return int(fft_size)
Esempio n. 26
0
# Command-line options: the compute backend and, for cuDNN, the device id.
parser = argparse.ArgumentParser(description='Encoder-decoder model training.')
parser.add_argument('--context', '-c', type=str, default='cpu',
                    help='You can choose cpu or cudnn.')
parser.add_argument('--device', '-d', type=int, default=0,
                    help='You can choose the device id when you use cudnn.')
args = parser.parse_args()

# Only the cuDNN backend installs an explicit default context; any other
# --context value leaves the library's default context untouched.
if args.context == 'cudnn':
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context('cudnn', device_id=args.device)
    nn.set_default_context(ctx)

# Input / vocabulary sizes.
max_len: int = 400
batch_size: int = 64
embedding_size: int = 300
# Network sizes. NOTE(review): `da` and `r` look like the attention hidden
# size and number of attention hops -- confirm against the model code.
hidden_size: int = 300
da: int = 350
r: int = 30
output_mlp_size: int = 3000
# Training schedule and regularization.
max_epoch: int = 20
vocab_size: int = 20000
dropout_ratio: float = 0.3
attention_penalty_coef: float = 0.03
l2_penalty_coef: float = 1e-4
Esempio n. 27
0
                        type=str,
                        default=None,
                        help='Path to the reference audio.')
    parser.add_argument("--output",
                        "-o",
                        type=str,
                        default=None,
                        help="Path to the converted audio file.")
    args = parser.parse_args()

    # setup context for nnabla
    if args.device_id != '-1':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    # setup the context
    ctx = get_extension_context(args.context, device_id='0')
    nn.set_default_context(ctx)

    hp.batch_size = 1
    model = NVCNet(hp)
    model.training = False
    model.load_parameters(args.model)

    x_audio = lr.load(args.input, sr=hp.sr)[0]  # read input utterance
    y_audio = lr.load(args.reference, sr=hp.sr)[0]  # read reference utterance

    x = nn.Variable.from_numpy_array(x_audio[None, None, ...])
    y = nn.Variable.from_numpy_array(y_audio[None, None, ...])
    out = model(x, y)

    out.forward(clear_buffer=True)
Esempio n. 28
0
def train(args):
    """Train a CycleGAN between domain B (``x``) and domain A (``y``).

    Builds the generators (``models.g``: x -> y, ``models.f``: y -> x), the
    discriminators (``models.d_x`` / ``models.d_y``) and their LSGAN,
    reconstruction and optional identity losses, then trains them with Adam
    and an image history pool. At the end of every epoch it logs training and
    test images and losses, saves parameters to ``args.model_save_path`` and
    applies ``linear_decay`` to every solver.

    Args:
        args: Parsed command-line options. Attributes read here include
            ``context``, ``device_id``, ``type_config``, ``lambda_recon``,
            ``lambda_idt``, ``learning_rate``, ``init_method``, ``unpool``,
            ``dataset``, ``batch_size``, ``monitor_path``,
            ``model_save_path`` and ``max_epoch``.

    Raises:
        ValueError: If ``args.batch_size`` is larger than the number of
            training images, which would leave zero iterations per epoch.
    """
    # Settings
    # NOTE(review): the graph variables below use a fixed batch of 1, while
    # the data iterators use args.batch_size; these only agree when
    # args.batch_size == 1 - confirm that this is the intended usage.
    b, c, h, w = 1, 3, 256, 256
    beta1 = 0.5
    beta2 = 0.999
    pool_size = 50
    lambda_recon = args.lambda_recon
    lambda_idt = args.lambda_idt
    base_lr = args.learning_rate
    init_method = args.init_method

    # Context (falls back to CPU when no extension module is given)
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Inputs (x: domain B, y: domain A, matching the iterators below)
    x_raw = nn.Variable([b, c, h, w], need_grad=False)
    y_raw = nn.Variable([b, c, h, w], need_grad=False)
    x_real = image_augmentation(x_raw)
    y_real = image_augmentation(y_raw)
    # Fakes drawn from the image pools, fed to the discriminators.
    x_history = nn.Variable([b, c, h, w])
    y_history = nn.Variable([b, c, h, w])
    x_real_test = nn.Variable([b, c, h, w], need_grad=False)
    y_real_test = nn.Variable([b, c, h, w], need_grad=False)

    # Models for training
    # Generate
    y_fake = models.g(x_real, unpool=args.unpool, init_method=init_method)
    x_fake = models.f(y_real, unpool=args.unpool, init_method=init_method)
    # Keep the fakes after forward: they are pushed into the image pools.
    y_fake.persistent, x_fake.persistent = True, True
    # Reconstruct
    x_recon = models.f(y_fake, unpool=args.unpool, init_method=init_method)
    y_recon = models.g(x_fake, unpool=args.unpool, init_method=init_method)
    # Discriminate
    d_y_fake = models.d_y(y_fake, init_method=init_method)
    d_x_fake = models.d_x(x_fake, init_method=init_method)
    d_y_real = models.d_y(y_real, init_method=init_method)
    d_x_real = models.d_x(x_real, init_method=init_method)
    d_y_history = models.d_y(y_history, init_method=init_method)
    d_x_history = models.d_x(x_history, init_method=init_method)

    # Models for test
    y_fake_test = models.g(
        x_real_test, unpool=args.unpool, init_method=init_method)
    x_fake_test = models.f(
        y_real_test, unpool=args.unpool, init_method=init_method)
    y_fake_test.persistent, x_fake_test.persistent = True, True
    # Reconstruct
    x_recon_test = models.f(
        y_fake_test, unpool=args.unpool, init_method=init_method)
    y_recon_test = models.g(
        x_fake_test, unpool=args.unpool, init_method=init_method)

    # Losses
    # Reconstruction Loss
    loss_recon = models.recon_loss(x_recon, x_real) \
        + models.recon_loss(y_recon, y_real)
    # Generator loss
    loss_gen = models.lsgan_loss(d_y_fake) \
        + models.lsgan_loss(d_x_fake) \
        + lambda_recon * loss_recon
    # Identity loss
    if lambda_idt != 0:
        logger.info("Identity loss was added.")
        # Identity
        y_idt = models.g(y_real, unpool=args.unpool, init_method=init_method)
        x_idt = models.f(x_real, unpool=args.unpool, init_method=init_method)
        loss_idt = models.recon_loss(x_idt, x_real) \
            + models.recon_loss(y_idt, y_real)
        loss_gen += lambda_recon * lambda_idt * loss_idt
    # Discriminator losses (fakes come from the history pools)
    loss_dis_y = models.lsgan_loss(d_y_history, d_y_real)
    loss_dis_x = models.lsgan_loss(d_x_history, d_x_real)

    # Solvers
    solver_gen = S.Adam(base_lr, beta1, beta2)
    solver_dis_x = S.Adam(base_lr, beta1, beta2)
    solver_dis_y = S.Adam(base_lr, beta1, beta2)
    with nn.parameter_scope('generator'):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope('discriminator'):
        with nn.parameter_scope("x"):
            solver_dis_x.set_parameters(nn.get_parameters())
        with nn.parameter_scope("y"):
            solver_dis_y.set_parameters(nn.get_parameters())

    # Datasets
    rng = np.random.RandomState(313)
    ds_train_B = cycle_gan_data_source(
        args.dataset, train=True, domain="B", shuffle=True, rng=rng)
    ds_train_A = cycle_gan_data_source(
        args.dataset, train=True, domain="A", shuffle=True, rng=rng)
    ds_test_B = cycle_gan_data_source(
        args.dataset, train=False, domain="B", shuffle=False, rng=rng)
    ds_test_A = cycle_gan_data_source(
        args.dataset, train=False, domain="A", shuffle=False, rng=rng)
    di_train_B = cycle_gan_data_iterator(ds_train_B, args.batch_size)
    di_train_A = cycle_gan_data_iterator(ds_train_A, args.batch_size)
    di_test_B = cycle_gan_data_iterator(ds_test_B, args.batch_size)
    di_test_A = cycle_gan_data_iterator(ds_test_A, args.batch_size)

    # Monitors
    monitor = Monitor(args.monitor_path)

    def make_monitor(name):
        return MonitorSeries(name, monitor, interval=1)
    monitor_loss_gen = make_monitor('generator_loss')
    monitor_loss_dis_x = make_monitor('discriminator_B_domain_loss')
    monitor_loss_dis_y = make_monitor('discriminator_A_domain_loss')

    def make_monitor_image(name):
        # Map generator outputs from [-1, 1] to [0, 255].
        return MonitorImage(name, monitor, interval=1,
                            normalize_method=lambda x: (x + 1.0) * 127.5)
    monitor_train_gx = make_monitor_image('fake_images_train_A')
    monitor_train_fy = make_monitor_image('fake_images_train_B')
    monitor_train_x_recon = make_monitor_image('fake_images_B_recon_train')
    monitor_train_y_recon = make_monitor_image('fake_images_A_recon_train')
    monitor_test_gx = make_monitor_image('fake_images_test_A')
    monitor_test_fy = make_monitor_image('fake_images_test_B')
    monitor_test_x_recon = make_monitor_image('fake_images_recon_test_B')
    monitor_test_y_recon = make_monitor_image('fake_images_recon_test_A')
    monitor_train_list = [
        (monitor_train_gx, y_fake),
        (monitor_train_fy, x_fake),
        (monitor_train_x_recon, x_recon),
        (monitor_train_y_recon, y_recon),
        (monitor_loss_gen, loss_gen),
        (monitor_loss_dis_x, loss_dis_x),
        (monitor_loss_dis_y, loss_dis_y),
    ]
    monitor_test_list = [
        (monitor_test_gx, y_fake_test),
        (monitor_test_fy, x_fake_test),
        (monitor_test_x_recon, x_recon_test),
        (monitor_test_y_recon, y_recon_test)]

    # ImagePool
    pool_x = ImagePool(pool_size)
    pool_y = ImagePool(pool_size)

    # Training loop
    epoch = 0
    # An epoch is sized by the larger of the two domains.
    n_images = max(ds_train_B.size, ds_train_A.size)
    iters_per_epoch = n_images // args.batch_size
    if iters_per_epoch == 0:
        raise ValueError(
            "batch_size ({}) must not exceed the number of training images "
            "({})".format(args.batch_size, n_images))
    max_iter = args.max_epoch * n_images // args.batch_size
    for i in range(max_iter):
        # End of an epoch: monitor, validate, checkpoint and decay the LR.
        if (i + 1) % iters_per_epoch == 0:
            logger.info("Mode:Test,Epoch:{}".format(epoch))
            # Monitor for train
            for series, var in monitor_train_list:
                series.add(i, var.d)
            # The test graphs reuse the training generators; there is no
            # separate test mode.
            x_data, _ = di_test_B.next()
            y_data, _ = di_test_A.next()
            x_real_test.d = x_data
            y_real_test.d = y_data
            x_recon_test.forward()
            y_recon_test.forward()
            # Monitor for test
            for series, var in monitor_test_list:
                series.add(i, var.d)
            # Save model
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
            # Learning rate decay
            for solver in [solver_gen, solver_dis_x, solver_dis_y]:
                linear_decay(solver, base_lr, epoch, args.max_epoch)
            epoch += 1

        # Get data
        x_data, _ = di_train_B.next()
        y_data, _ = di_train_A.next()
        x_raw.d = x_data
        y_raw.d = y_data

        # Train Generators (keep intermediate buffers so x_fake/y_fake stay
        # readable for the pools below)
        loss_gen.forward(clear_no_need_grad=False)
        solver_gen.zero_grad()
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Insert and Get to/from pool
        x_history.d = pool_x.insert_then_get(x_fake.d)
        y_history.d = pool_y.insert_then_get(y_fake.d)

        # Train Discriminator Y
        loss_dis_y.forward(clear_no_need_grad=False)
        solver_dis_y.zero_grad()
        loss_dis_y.backward(clear_buffer=True)
        solver_dis_y.update()

        # Train Discriminator X
        loss_dis_x.forward(clear_no_need_grad=False)
        solver_dis_x.zero_grad()
        loss_dis_x.backward(clear_buffer=True)
        solver_dis_x.update()
def classification_svd():
    """Fine-tune an SVD-compressed ("slim") LeNet on MNIST.

    Builds the slim prediction graph, then calls
    ``decompose_network_and_set_params(model_load_path, "reference", "slim",
    rrate)`` to initialize its parameters from ``args.model_load_path``. The
    slim scope is trained with Adam and weight decay. Every
    ``args.val_interval`` iterations, the mean test error over
    ``args.val_iter`` batches is logged, and parameters are saved whenever
    that error improves. Final parameters are saved as
    ``params_<max_iter>.h5`` in ``args.model_save_path``.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create "slim" prediction graph.
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    pred.persistent = True

    # Decompose the loaded parameters and set them into the slim scope.
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create slim prediction graph in test mode.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    def mean_val_error():
        """Return the mean categorical error of vpred over args.val_iter batches."""
        total = 0.0
        for _ in range(args.val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / args.val_iter

    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = mean_val_error()
            monitor_verr.add(i, ve)
            # Compare the *mean* error (same scale as best_ve's initial 1.0),
            # and only right after a fresh validation.
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final evaluation, logged at the last training iteration index.
    final_iter = max(args.max_iter - 1, 0)
    monitor_verr.add(final_iter, mean_val_error())

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Esempio n. 30
0
def generate(args):
    """Generate images with a trained class-conditional generator.

    Loads parameters from ``args.model_load_path`` and builds the generator in
    test mode. If ``args.generate_all`` is set, one tile of
    ``args.batch_size`` images is written for each class in
    ``range(args.n_classes)``. Otherwise a single batch is generated, for
    ``args.class_id``, or for classes from ``generate_random_class`` when
    ``args.class_id`` is -1. That batch is written both as individual images
    and as a tile.

    Args:
        args: Parsed command-line options (``type_config``, ``latent``,
            ``maps``, ``batch_size``, ``n_classes``, ``not_sn``,
            ``truncation_threshold``, ``model_load_path``,
            ``generate_all``, ``monitor_path``, ``class_id``).
    """
    # Context: generation always uses the cuDNN extension.
    ctx = get_extension_context("cudnn", type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    batch_size = args.batch_size
    n_classes = args.n_classes
    threshold = args.truncation_threshold

    # Model: load trained parameters, then build the test-mode generator.
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=args.maps, n_classes=n_classes,
                       test=True, sn=args.not_sn).apply(persistent=True)

    def forward_batch(label_fn, *label_args):
        """Sample latents, then labels via label_fn(*label_args); run the generator."""
        # Latents are drawn before labels, as in the original ordering.
        z_data = resample(batch_size, latent, threshold)
        y_data = label_fn(*label_args)
        z.d = z_data
        y_fake.d = y_data
        x_fake.forward(clear_buffer=True)

    monitor = Monitor(args.monitor_path)

    # Generate all: one image tile per class.
    if args.generate_all:
        monitor_image = MonitorImageTile("Generated Image Tile All",
                                         monitor,
                                         interval=1,
                                         num_images=args.batch_size,
                                         normalize_method=normalize_method)
        for class_id in range(args.n_classes):
            forward_batch(generate_one_class, class_id, batch_size)
            monitor_image.add(class_id, x_fake.d)
        return

    # Generate individually: one batch for args.class_id (random classes if -1).
    suffix = " {}".format(args.class_id) if args.class_id != -1 else ""
    monitor_image_tile = MonitorImageTile("Generated Image Tile" + suffix,
                                          monitor,
                                          interval=1,
                                          num_images=args.batch_size,
                                          normalize_method=normalize_method)
    monitor_image = MonitorImage("Generated Image" + suffix,
                                 monitor,
                                 interval=1,
                                 num_images=args.batch_size,
                                 normalize_method=normalize_method)
    if args.class_id == -1:
        forward_batch(generate_random_class, n_classes, batch_size)
    else:
        forward_batch(generate_one_class, args.class_id, batch_size)
    monitor_image.add(0, x_fake.d)
    monitor_image_tile.add(0, x_fake.d)