Esempio n. 1
0
 def split_dataset(dataset):
     """Split *dataset* into two equally sized halves for 2-GPU data parallelism.

     Args:
         dataset: any Chainer-compatible dataset (supports ``len`` and indexing).

     Returns:
         A 2-tuple of sub-datasets of identical length. For odd-length input
         the second half is trimmed by one element so both GPU shards match.
     """
     gpu_datasets = split_dataset_random(dataset, len(dataset) // 2)
     if len(gpu_datasets[0]) != len(gpu_datasets[1]):
         # BUG FIX: the original called the bare name `split_dataset` with two
         # arguments, but that name resolves to THIS one-argument function
         # (the local def shadows chainer's helper), raising TypeError.
         # Call chainer's deterministic splitter explicitly instead.
         adapted_second_split = chainer.datasets.split_dataset(
             gpu_datasets[1], len(gpu_datasets[0]))[0]
         gpu_datasets = (gpu_datasets[0], adapted_second_split)
     return gpu_datasets
    # Adam optimizer driven by the CLI-supplied learning rate.
    base_optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer = base_optimizer
    optimizer.setup(model)
    # Weight decay (L2 regularization) plus hard gradient-norm clipping at 2.
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))
    optimizer.add_hook(chainer.optimizer.GradientClipping(2))

    # freeze localization net
    if args.freeze_localization:
        localization_net.disable_update()

    # if we are using more than one GPU, we need to evenly split the datasets
    if len(args.gpus) > 1:
        gpu_datasets = split_dataset_n_random(train_dataset, len(args.gpus))
        # split_dataset_n_random can leave the last shard longer than the
        # others (remainder elements); trim it so MultiprocessParallelUpdater
        # receives equally sized per-GPU shards.
        if not len(gpu_datasets[0]) == len(gpu_datasets[-1]):
            # NOTE(review): this relies on `split_dataset` being the
            # two-argument chainer.datasets.split_dataset; a local
            # one-argument helper of the same name would shadow it and
            # raise TypeError — confirm against this file's imports.
            adapted_second_split = split_dataset(gpu_datasets[-1], len(gpu_datasets[0]))[0]
            gpu_datasets[-1] = adapted_second_split
    else:
        gpu_datasets = [train_dataset]

    # One training iterator per GPU shard; validation runs a single pass.
    train_iterators = [chainer.iterators.MultiprocessIterator(dataset, args.batch_size) for dataset in gpu_datasets]
    validation_iterator = chainer.iterators.MultiprocessIterator(validation_dataset, args.batch_size, repeat=False)

    # use the MultiProcessParallelUpdater in order to harness the full power of data parallel computation
    updater = MultiprocessParallelUpdater(train_iterators, optimizer, devices=args.gpus)

    # Timestamped log directory keeps runs from overwriting each other.
    log_dir = os.path.join(args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(), args.log_name))
    args.log_dir = log_dir

    # backup current file
    if not os.path.exists(log_dir):
 def split_dataset(self, dataset):
     """Randomly split *dataset* into one shard per GPU in ``self.gpus``.

     The last shard may come back longer than the others (it absorbs the
     division remainder); in that case it is trimmed so that every shard
     has exactly the length of the first one.

     Returns:
         The list of equally sized per-GPU sub-datasets.
     """
     shards = split_dataset_n_random(dataset, len(self.gpus))
     target_size = len(shards[0])
     if len(shards[-1]) != target_size:
         # Cut the oversized final shard down to the common shard length.
         shards[-1] = split_dataset(shards[-1], target_size)[0]
     return shards