def data_iterator_tiny_digits(digits, batch_size=64, shuffle=False, rng=None):
    """Create a simple data iterator over a "digits"-style dataset object.

    Args:
        digits: Object exposing ``images`` and ``target``, both indexable by
            sample position (presumably the result of scikit-learn's
            ``load_digits()`` -- confirm against the caller).
        batch_size (int): Number of samples per batch.
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``.

    Returns:
        The object returned by ``data_iterator_simple``. Every sample is a
        pair ``(image with a new leading axis, int32 array of one label)``.
    """
    num_examples = digits.target.shape[0]

    def load_func(index):
        """Return the image and the label stored at ``index``."""
        # ``[None]`` prepends an axis of length one to the image.
        image = digits.images[index][None]
        target = np.array([digits.target[index]]).astype(np.int32)
        return image, target

    return data_iterator_simple(
        load_func, num_examples, batch_size, shuffle, rng,
        with_file_cache=False)
def edges2shoes_data_iterator(img_path, batch_size=1,
                              normalize_method=lambda x: (x - 127.5) / 127.5,
                              num_samples=-1, shuffle=True, rng=None):
    """Create an iterator yielding the two halves of side-by-side JPEG images.

    Every ``*.jpg`` file directly under ``img_path`` is split down the middle
    of its width; the left half is returned first, the right half second,
    both transposed from HWC to CHW.

    Args:
        img_path (str): Directory that is globbed for ``*.jpg`` files.
        batch_size (int): Number of samples per batch.
        normalize_method (callable): Applied to the decoded image before it
            is split.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means "all files found".
        shuffle (bool): Forwarded to ``data_iterator_simple``. Previously the
            body read an undefined ``shuffle`` name; it is now a parameter.
        rng: Forwarded to ``data_iterator_simple`` (same remark as above).

    Returns:
        The object returned by ``data_iterator_simple``.
    """
    imgs = glob.glob("{}/*.jpg".format(img_path))
    if num_samples == -1:
        num_samples = len(imgs)
    else:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        """Load image ``i`` and return its (left, right) halves as CHW."""
        # NOTE(review): scipy.misc.imread is not available in recent SciPy
        # releases -- confirm the pinned SciPy version.
        img = scipy.misc.imread(imgs[i], mode="RGB")
        img = normalize_method(img)
        _, w, _ = img.shape
        half = w // 2
        img_A = img[:, 0:half, :].transpose((2, 0, 1))
        img_B = img[:, half:, :].transpose((2, 0, 1))
        return img_A, img_B

    return data_iterator_simple(load_func, num_samples, batch_size,
                                shuffle=shuffle, rng=rng,
                                with_file_cache=False)
def data_iterator_fewshot(img_path, batch_size, imsize=(256, 256),
                          num_samples=1000, shuffle=True, rng=None):
    """Create an iterator over the JPEG files found recursively in a folder.

    Args:
        img_path (str): Root directory searched recursively for ``*.jpg``.
        batch_size (int): Number of samples per batch.
        imsize (tuple): Target size handed to ``imresize``.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means "all files found".
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``.

    Returns:
        The object returned by ``data_iterator_simple``. Each sample is
        ``(CHW image mapped by x / 255 * 2 - 1, sample index)``.
    """
    pattern = "{}/**/*.jpg".format(img_path)
    image_files = glob.glob(pattern, recursive=True)

    if num_samples != -1:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))
    else:
        num_samples = len(image_files)

    def load_func(i):
        """Read, resize and rescale image ``i``; also return its index."""
        pixels = imread(image_files[i], num_channels=3)
        chw = imresize(pixels, imsize).transpose(2, 0, 1)
        return chw / 255. * 2. - 1., i

    return data_iterator_simple(
        load_func, num_samples, batch_size,
        shuffle=shuffle, rng=rng,
        with_file_cache=False, with_memory_cache=False)
def data_iterator_celeba(img_path, batch_size, imsize=(128, 128),
                         num_samples=100, shuffle=True, rng=None):
    """Create an iterator over fixed 128x128 crops of PNG images.

    Args:
        img_path (str): Directory that is globbed for ``*.png`` files.
        batch_size (int): Number of samples per batch.
        imsize (tuple): Accepted for interface compatibility; this function
            does not read it (the crop window is hard-coded).
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means "all files found".
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``.

    Returns:
        The object returned by ``data_iterator_simple``. Each sample is
        ``(CHW crop mapped by x / 255 * 2 - 1, None)``.
    """
    image_files = glob.glob("{}/*.png".format(img_path))

    if num_samples != -1:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))
    else:
        num_samples = len(image_files)

    # Crop window: 64 pixels on each side of the point (x=89, y=121).
    center_x, center_y, half = 89, 121, 64
    rows = slice(center_y - half, center_y + half)
    cols = slice(center_x - half, center_x + half)

    def load_func(i):
        """Read image ``i``, crop it and rescale the pixel values."""
        patch = imread(image_files[i])[rows, cols, :]
        scaled = patch.transpose(2, 0, 1) / 255.
        return scaled * 2. - 1., None

    return data_iterator_simple(
        load_func, num_samples, batch_size,
        shuffle=shuffle, rng=rng, with_file_cache=False)
def test_sliced_data_iterator(test_data_csv_png_10, num_of_slices, size,
                              batch_size, shuffle):
    """Check that slices of one iterator barely overlap within an epoch.

    A simple iterator whose i-th sample is the value ``i`` is sliced into
    ``num_of_slices`` parts. Every slice is drained for a fixed number of
    epochs, the values seen per epoch are collected, and for each epoch the
    overlap between the first slice and each other slice must be smaller
    than ``size // num_of_slices``.

    Args:
        test_data_csv_png_10: Fixture; not read by this test.
        num_of_slices (int): Number of slices to create.
        size (int): Number of samples in the source iterator.
        batch_size (int): Batch size of the source iterator.
        shuffle (bool): Whether the source iterator shuffles.
    """
    def test_load_func(position):
        return np.full((1), position, dtype=np.float32)

    di = data_iterator_simple(test_load_func, size,
                              batch_size, shuffle=shuffle)

    import six
    if six.PY2:
        from fractions import gcd
    else:
        from math import gcd

    def lcm(a, b):
        # Integer arithmetic: the least common multiple is always an int.
        return abs(a * b) // gcd(a, b) if a and b else 0

    # Number of epochs after which batch and epoch boundaries line up again.
    max_epoch = lcm(batch_size, size) // size

    all_data = []
    for slice_pos in range(num_of_slices):
        sliced_di = di.slice(rng=None, num_of_slices=num_of_slices,
                             slice_pos=slice_pos)
        sliced_data = {}
        while True:
            current_epoch = sliced_di.epoch
            if current_epoch > max_epoch + 1:
                break
            data = sliced_di.next()
            # Flatten every variable of the batch into the epoch's bucket.
            bucket = sliced_data.setdefault(current_epoch, [])
            for dat in data:
                bucket.extend(dat)
        all_data.append(sliced_data)

    # epoch -> list (one entry per slice) of the set of values seen.
    epochs = {}
    for sliced_data in all_data:
        for epoch in sorted(sliced_data):
            epochs.setdefault(epoch, []).append(set(sliced_data[epoch]))

    amount = size // num_of_slices
    for epoch in sorted(epochs):
        first, rest = epochs[epoch][0], epochs[epoch][1:]
        for dup in [first & other for other in rest]:
            assert len(dup) < amount
def data_iterator_sr(num_examples, batch_size, gt_image, lq_image, train,
                     shuffle, rng=None):
    """Create an iterator over (ground-truth, low-quality) image pairs.

    Args:
        num_examples (int): Number of samples exposed by the iterator.
        batch_size (int): Number of samples per batch.
        gt_image: Indexable collection passed element-wise to ``read_image``
            for the ground-truth image.
        lq_image: Same, for the low-quality image.
        train (bool): When true, apply a random crop and random
            flips/rotation; otherwise only ``modcrop`` the ground truth.
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``. The augmentation itself
            draws from the global ``random`` module, not from ``rng``.

    Returns:
        The object returned by ``data_iterator_simple``; each sample is
        ``(gt_img, lq_img)`` as contiguous CHW arrays.
    """
    from args import get_config
    conf = get_config()

    def _to_chw(image):
        """Transpose HWC to CHW and make the result contiguous."""
        return np.ascontiguousarray(np.transpose(image, (2, 0, 1)))

    def dataset_load_func(i):
        """Load, optionally augment, and re-layout sample ``i``."""
        scale = conf.train.scale
        gt_size = conf.train.gt_size

        gt_img = read_image(gt_image[i])
        lq_img = read_image(lq_image[i])

        if not train:
            gt_img = modcrop(gt_img, scale)
        gt_img = channel_convert(gt_img.shape[2], gt_img, color="RGB")

        if train:
            # Random crop: pick the window on the low-quality image and
            # take the matching (scaled) window on the ground truth.
            H, W, C = lq_img.shape
            lq_size = gt_size // scale
            top = random.randint(0, max(0, H - lq_size))
            left = random.randint(0, max(0, W - lq_size))
            lq_img = lq_img[top:top + lq_size, left:left + lq_size, :]
            top_gt = int(top * scale)
            left_gt = int(left * scale)
            gt_img = gt_img[top_gt:top_gt + gt_size,
                            left_gt:left_gt + gt_size, :]

            # Three independent coin tosses, drawn in this order.
            hflip = random.random() < 0.5
            vflip = random.random() < 0.5
            rot90 = random.random() < 0.5
            lq_img = augment(lq_img, hflip, rot90, vflip)
            gt_img = augment(gt_img, hflip, rot90, vflip)
            lq_img = channel_convert(C, [lq_img], color="RGB")[0]

        # Reverse the channel order of three-channel images (BGR to RGB).
        if gt_img.shape[2] == 3:
            swap = [2, 1, 0]
            gt_img = gt_img[:, :, swap]
            lq_img = lq_img[:, :, swap]

        return _to_chw(gt_img), _to_chw(lq_img)

    return data_iterator_simple(
        dataset_load_func, num_examples, batch_size,
        shuffle=shuffle, rng=rng,
        with_file_cache=False, with_memory_cache=False)
def get_data_loader(attr_path, image_dir, batch_size, batch_size_valid,
                    image_size, attribute='Bangs'):
    """Build the training and test iterators for a single attribute.

    The dataset returned by ``get_data_dict`` is shuffled with NumPy's
    global generator seeded to 313; the last 4000 entries become the test
    split and the remainder the training split.

    Args:
        attr_path: Forwarded to ``get_data_dict``.
        image_dir: Image directory forwarded to ``stargan_load_func``.
        batch_size (int): Batch size of the training iterator.
        batch_size_valid (int): Batch size of the test iterator.
        image_size: Used both as ``image_size`` and ``crop_size`` of
            ``stargan_load_func``.
        attribute (str): The single attribute requested from
            ``get_data_dict``.

    Returns:
        tuple: ``(training iterator, test iterator)``.
    """
    dataset, _attr2idx, _idx2attr = get_data_dict(attr_path, [attribute])

    # NOTE: this reseeds NumPy's *global* random state.
    np.random.seed(313)
    np.random.shuffle(dataset)

    training_dataset = dataset[:-4000]
    test_dataset = dataset[-4000:]
    print("Use {} images for training.".format(len(training_dataset)))

    def _make_iterator(samples, samples_per_batch):
        """Wrap ``samples`` in an uncached simple data iterator."""
        loader = functools.partial(
            stargan_load_func, dataset=samples, image_dir=image_dir,
            image_size=image_size, crop_size=image_size)
        return data_iterator_simple(
            loader, len(samples), samples_per_batch,
            with_file_cache=False, with_memory_cache=False)

    data_iterator = _make_iterator(training_dataset, batch_size)
    test_data_iterator = _make_iterator(test_dataset, batch_size_valid)
    return data_iterator, test_data_iterator
def munit_data_iterator(img_path, batch_size=1, image_size=256,
                        num_samples=-1,
                        normalize_method=lambda x: (x - 127.5) / 127.5,
                        shuffle=True, rng=None):
    """Create an iterator over resized, normalised RGB images.

    Args:
        img_path: One of
            * a list (or tuple) of image file paths,
            * a directory, globbed for jpg/JPG/jpeg/JPEG/png/PNG files,
            * a single file path ending with one of those extensions.
        batch_size (int): Number of samples per batch.
        image_size (int): Images are resized to
            ``(image_size, image_size)``.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means "all files found".
        normalize_method (callable): Applied to the resized image.
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``.

    Returns:
        The object returned by ``data_iterator_simple``; each sample is
        ``(CHW image, None)``.

    Raises:
        ValueError: If ``img_path`` is none of the accepted forms.
    """
    extensions = (".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG")

    if isinstance(img_path, (list, tuple)):
        imgs = list(img_path)
    elif os.path.isdir(img_path):
        imgs = []
        for ext in extensions:
            imgs += glob.glob("{}/*{}".format(img_path, ext))
    elif img_path.endswith(extensions):
        imgs = [img_path]
    else:
        raise ValueError(
            "Path specified is not `directory path` or `list of files`.")

    if num_samples == -1:
        num_samples = len(imgs)
    else:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        """Read, resize, normalise and transpose image ``i``."""
        # NOTE(review): scipy.misc.imread / imresize are not available in
        # recent SciPy releases -- confirm the pinned SciPy version.
        img = scipy.misc.imread(imgs[i], mode="RGB")
        img = scipy.misc.imresize(img, (image_size, image_size))
        img = normalize_method(img)
        img = img.transpose((2, 0, 1))
        return img, None

    return data_iterator_simple(load_func, num_samples, batch_size,
                                shuffle=shuffle, rng=rng,
                                with_file_cache=False)
def data_iterator_segmentation(num_examples, batch_size, image_path_file,
                               label_path_file, rng=None, target_width=513,
                               target_height=513, train=True):
    """Create a shuffling iterator of (image, label, mask) triples.

    Args:
        num_examples (int): Number of samples exposed by the iterator.
        batch_size (int): Number of samples per batch.
        image_path_file: Passed to ``load_paths`` to obtain image paths.
        label_path_file: Passed to ``load_paths`` to obtain label paths.
        rng: Forwarded to ``data_iterator_simple``.
        target_width (int): Forwarded to the preprocessing routine.
        target_height (int): Forwarded to the preprocessing routine.
        train (bool): Forwarded to the preprocessing routine.

    Returns:
        The object returned by ``data_iterator_simple`` (always created with
        ``shuffle=True``).
    """
    image_paths = load_paths(image_path_file)
    label_paths = load_paths(label_path_file)

    def image_label_load_func(i):
        '''
        Returns:
            image: c x h x w array
            label: c x h x w array
            mask: c x h x w array
        '''
        # Swap the first and third channels of the float32 image
        # (OpenCV decodes as BGR).
        bgr = cv2.imread(image_paths[i]).astype('float32')
        blue, green, red = cv2.split(bgr)
        img = cv2.merge([red, green, blue])

        label_file = label_paths[i]
        if 'png' in label_file:
            lab = imageio.imread(
                label_file, as_gray=False, pilmode="RGB").astype('int32')
        else:
            # NOTE(review): allow_pickle=True can execute arbitrary code for
            # untrusted files -- confirm label files are trusted.
            lab = np.load(label_file, allow_pickle=True).astype('int32')
        if lab.ndim == 2:
            lab = lab[..., None]

        img, lab, mask = image_preprocess.preprocess_image_and_label(
            img, lab, target_width, target_height, train=train)

        # Move the last axis to the front for each of the three arrays.
        return tuple(np.rollaxis(arr, 2) for arr in (img, lab, mask))

    return data_iterator_simple(
        image_label_load_func, num_examples, batch_size,
        shuffle=True, rng=rng,
        with_file_cache=False, with_memory_cache=False)
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle):
    """Run ``check_data_iterator_result`` on an in-memory simple iterator.

    The CSV fixture is read row by row; for each row whose first column
    names an existing file (relative to the CSV's directory) the image is
    loaded and paired with the integer in the second column.
    """
    base_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for row in f.readlines():
            values = [x.strip() for x in row.split(',')]
            img_file_name = os.path.join(base_dir, values[0])
            if not os.path.exists(img_file_name):
                continue
            with open(img_file_name, 'rb') as img_file:
                image = load_image(img_file)
                src_data.append((image, [int(values[1])]))

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle) as di:
        check_data_iterator_result(di, batch_size, shuffle, False)
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Check a simple iterator together with its epoch callbacks.

    Samples are built from the CSV fixture as ``(image, [label])`` pairs.
    Two callbacks verify that they are invoked with the expected epoch
    number (``expect_epoch[0]``) and on the thread that runs this test.
    """
    base_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for row in f.readlines():
            values = [x.strip() for x in row.split(',')]
            img_file_name = os.path.join(base_dir, values[0])
            if not os.path.exists(img_file_name):
                continue
            with open(img_file_name, 'rb') as img_file:
                image = load_image(img_file)
                src_data.append((image, [int(values[1])]))

    size = len(src_data)
    main_thread = threading.current_thread().ident
    # One-element list so the checker can update the value in place.
    expect_epoch = [0]

    def test_load_func(position):
        return src_data[position]

    def _verify(epoch, message):
        """Shared body of both callbacks."""
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], message
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    def end_epoch(epoch):
        _verify(epoch, "Failed for end epoch check")

    def begin_epoch(epoch):
        _verify(epoch, "Failed for begin epoch check")

    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        # NOTE(review): ``begin_epoch`` is registered through
        # ``register_epoch_end_callback`` -- confirm this is intended.
        # The pair is registered twice when the batch is smaller than the
        # dataset, once otherwise.
        registrations = 2 if batch_size // size == 0 else 1
        for _ in range(registrations):
            di.register_epoch_end_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, False,
                                   stop_exhausted, expect_epoch)
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Run ``check_data_iterator_result`` with a ``stop_exhausted`` setting.

    Samples are built from the CSV fixture: each row whose first column
    names an existing file (relative to the CSV's directory) contributes
    one ``(image, [label])`` pair.
    """
    base_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for row in f.readlines():
            values = [x.strip() for x in row.split(',')]
            img_file_name = os.path.join(base_dir, values[0])
            if not os.path.exists(img_file_name):
                continue
            with open(img_file_name, 'rb') as img_file:
                image = load_image(img_file)
                src_data.append((image, [int(values[1])]))

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        check_data_iterator_result(
            di, batch_size, shuffle, False, stop_exhausted)
def data_iterator_celeba(img_path, attributes, transform=None, batch_size=32,
                         num_samples=-1, shuffle=True, rng=None):
    """
    create celebA data iterator
    Args:
        img_path(list) : list of image paths
        attributes (dict) : attribute list, keyed by image path
        transform : callable invoked as ``transform(image=...)`` whose result
                    is indexed with ``'image'`` (data augmentation). When
                    ``None`` (the default) the decoded image is used as is.
        batch_size (int) : number of samples contained in each generated batch
        num_samples (int) : number of samples taken in data loader
                            (if num_samples=-1, it will take all the images in the dataset)
        shuffle (bool) : shuffle the data
        rng : forwarded to ``data_iterator_simple``
    Returns:
        simple data iterator
    """
    imgs = img_path
    attr = attributes

    if num_samples == -1:
        num_samples = len(imgs)
    else:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        """Return the (transformed) CHW image ``i`` and its attributes."""
        path = imgs[i]
        # The context manager closes the underlying file once the pixels
        # have been copied into the array.
        with Image.open(path) as pillow_image:
            image = np.array(pillow_image)
        # ``transform`` defaults to None; calling it unconditionally used to
        # raise TypeError for callers relying on the default.
        if transform is not None:
            image = transform(image=image)['image']
        return image.transpose(2, 0, 1), attr[path]

    return data_iterator_simple(load_func, num_samples, batch_size,
                                shuffle=shuffle, rng=rng,
                                with_file_cache=False)
def data_iterator_segmentation(batch_size, image_paths, label_paths, rng=None,
                               train=True):
    '''
    Returns a data iterator object for semantic image segmentation dataset.

    Args:
        batch_size (int): Batch size
        image_paths (list of str): A list of image paths
        label_paths (list of str): A list of label image paths
        rng (None or numpy.random.RandomState):
            A random number generator used in shuffling dataset and data
            augmentation.
        train (bool): Passed to ``data_iterator_simple`` as ``shuffle``, so
            the dataset is shuffled only when train is True.

    Raises:
        ValueError: If ``image_paths`` and ``label_paths`` differ in length.
    '''
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(image_paths) != len(label_paths):
        raise ValueError(
            "image_paths and label_paths must have the same length "
            "({} != {})".format(len(image_paths), len(label_paths)))
    num_examples = len(image_paths)

    def image_label_load_func(i):
        '''
        Returns the preprocessed image and label of sample ``i`` as produced
        by ``image_preprocess.preprocess_image_and_label``.
        '''
        img = cv2.imread(image_paths[i], cv2.IMREAD_COLOR)
        lab = palette_png_reader(label_paths[i])
        img, lab = image_preprocess.preprocess_image_and_label(
            img, lab, rng=rng)
        return img, lab

    return data_iterator_simple(image_label_load_func, num_examples,
                                batch_size, shuffle=train, rng=rng)
# Number of full batches available in the training and dev sets.
num_train_batch = len(x_train) // batch_size
num_dev_batch = len(x_test) // batch_size


def load_train_func(index):
    """Return the (input, target) training pair stored at ``index``."""
    return x_train[index], y_train[index]


def load_dev_func(index):
    """Return the (input, target) dev pair stored at ``index``."""
    return x_test[index], y_test[index]


# Both iterators shuffle and disable the file cache.
train_data_iter = data_iterator_simple(load_train_func, len(x_train), batch_size, shuffle=True, with_file_cache=False)
dev_data_iter = data_iterator_simple(load_dev_func, len(x_test), batch_size, shuffle=True, with_file_cache=False)


def build_self_attention_model(train=True):
    # Placeholders: inputs of shape (batch_size, max_len), targets of shape
    # (batch_size, 1).
    x = nn.Variable((batch_size, max_len))
    t = nn.Variable((batch_size, 1))
    mask = get_mask(x)
    # (1 - mask) * float32-min: zero where ``mask`` is 1 and the most
    # negative float32 where ``mask`` is 0 (presumably added to attention
    # logits before a softmax -- confirm against the rest of the model).
    attention_mask = (F.constant(1, shape=mask.shape) - mask) * F.constant(
        np.finfo(np.float32).min, shape=mask.shape)
# Number of full batches available in the training and validation sets.
num_train_batch = len(x_train) // batch_size
num_valid_batch = len(x_valid) // batch_size


def load_train_func(index):
    """Return the (input, target) training pair stored at ``index``."""
    return x_train[index], y_train[index]


def load_valid_func(index):
    """Return the (input, target) validation pair stored at ``index``."""
    return x_valid[index], y_valid[index]


# Both iterators shuffle and disable the file cache.
train_data_iter = data_iterator_simple(load_train_func, len(x_train), batch_size, shuffle=True, with_file_cache=False)
valid_data_iter = data_iterator_simple(load_valid_func, len(x_valid), batch_size, shuffle=True, with_file_cache=False)

# Placeholders: inputs of shape (batch_size, sentence_length) and targets
# of shape (batch_size, sentence_length, 1).
x = nn.Variable((batch_size, sentence_length))
t = nn.Variable((batch_size, sentence_length, 1))

# Network: embedding -> LSTM returning the whole sequence -> two affine
# layers wrapped in TimeDistributed ('hidden', then 'output' with
# ``vocab_size`` outputs).
h = PF.embed(x, vocab_size, embedding_size)
h = LSTM(h, hidden, return_sequences=True)
h = TimeDistributed(PF.affine)(h, hidden, name='hidden')
y = TimeDistributed(PF.affine)(h, vocab_size, name='output')
def test_sliced_data_iterator_equivalence(test_data_csv_png_10, num_of_slices,
                                          size, batch_size, shuffle):
    """Compare sliced iterators with iterators built on pre-sliced data.

    For every slice position two iterators are created: one obtained via
    ``di.slice(...)`` and a reference one built directly on the matching
    block of the data. Both families are drained for the same number of
    batches; the set of values seen must be the same on both sides.

    Args:
        test_data_csv_png_10: Fixture; not read by this test.
        num_of_slices (int): Number of slices.
        size (int): Number of samples in the full dataset.
        batch_size (int): Batch size of all iterators.
        shuffle (bool): Whether the iterators shuffle.
    """
    def lcm(a, b):
        # Integer arithmetic: the least common multiple is always an int.
        return abs(a * b) // math.gcd(a, b) if a and b else 0

    max_epoch = lcm(batch_size, size) // size

    def test_load_func(position):
        # ``np.int`` no longer exists in current NumPy; builtin ``int``
        # selects the same default integer dtype.
        return np.full((1), position, dtype=int)

    def simple_load_func(data_set, position):
        return data_set[position]

    def get_data(iter_list, iter_num):
        """Yield batches from each iterator, interleaved with counters.

        After the ``iter_num`` batches of one iterator, the running batch
        count is yielded; it is yielded once more at the very end.
        """
        total = 0
        for it in iter_list:
            for _ in range(iter_num):
                yield it.next()
                total += 1
            yield total
        yield total

    # Batches drawn from each slice, rounded to the nearest integer.
    iter_num = int((max_epoch * size) / (num_of_slices * batch_size) + 0.5)

    sliced_di_list = []
    di = data_iterator_simple(test_load_func, size,
                              batch_size, shuffle=shuffle)
    for slice_pos in range(num_of_slices):
        sliced_di = di.slice(
            rng=None, num_of_slices=num_of_slices, slice_pos=slice_pos)
        sliced_di_list.append(sliced_di)

    ref_di_list = []
    all_data = [np.full((1), position, dtype=int)
                for position in range(size)]
    for slice_pos in range(num_of_slices):
        # Block boundaries are rounded to the nearest integer.
        slice_sample_size = size / num_of_slices
        start_index = int(slice_sample_size * slice_pos + 0.5)
        end_index = int(slice_sample_size * (slice_pos + 1) + 0.5)
        slice_block_size = end_index - start_index
        sliced_data = all_data[start_index: end_index]
        di = data_iterator_simple(
            partial(simple_load_func, sliced_data), slice_block_size,
            batch_size, shuffle=shuffle)
        ref_di_list.append(di)

    set_a = set()
    set_b = set()
    for ref, t in zip(get_data(ref_di_list, iter_num),
                      get_data(sliced_di_list, iter_num)):
        if isinstance(ref, tuple):
            # A batch: keep only its first variable.
            ref, t = ref[0], t[0]
        if isinstance(ref, np.ndarray):
            set_a = set_a.union(set(ref))
            set_b = set_b.union(set(t))
        else:
            # A counter yielded by ``get_data``.
            assert ref == t
    assert set_a == set_b

    di_all = ref_di_list + sliced_di_list
    for di in di_all:
        di.close()
def load_train_func(index):
    """Return one positive pair plus randomly drawn negative samples.

    Args:
        index (int): Position of the pair inside ``pdata``.

    Returns:
        tuple: ``(id of x, id of y, array of negative sample ids)`` where
        the ids come from ``pdict`` and the negatives are drawn without
        replacement, never including the ids of ``x`` or ``y``.
    """
    x, y = pdata[index]
    negative_sample_prob = np.ones(len(pdict))
    negative_sample_prob[pdict[x]] = 0.0
    negative_sample_prob[pdict[y]] = 0.0
    # Normalise by the number of entries that are actually non-zero. The
    # former ``len(pdict) - 2`` is only right when x and y map to different
    # ids; otherwise the probabilities did not sum to one and
    # ``np.random.choice`` rejected them.
    negative_sample_prob /= negative_sample_prob.sum()
    negative_sample_indices = np.random.choice(range(len(pdict)),
                                               negative_sample_size,
                                               replace=False,
                                               p=negative_sample_prob)
    return pdict[x], pdict[y], negative_sample_indices


train_data_iter = data_iterator_simple(load_train_func, len(pdata),
                                       batch_size, shuffle=True,
                                       with_file_cache=False)
""" """


def distance(u, v, eps=1e-5):
    """Return ``acosh(1 + 2 * |u - v|^2 / (alpha * beta))`` along axis 1.

    ``alpha`` and ``beta`` are ``1 - |u|^2`` and ``1 - |v|^2`` clamped from
    below by ``eps`` so the denominator never reaches zero.

    Args:
        u: nnabla variable; squared norms are taken over axis 1.
        v: nnabla variable combined element-wise with ``u``.
        eps (float): Lower bound applied to ``alpha`` and ``beta``.
    """
    uu = F.sum(F.pow_scalar(u, 2), axis=1)
    vv = F.sum(F.pow_scalar(v, 2), axis=1)
    euclid_norm_pow2 = F.sum(F.pow_scalar(u - v, 2), axis=1)
    alpha = F.maximum2(F.constant(eps, shape=uu.shape), 1.0 - uu)
    beta = F.maximum2(F.constant(eps, shape=vv.shape), 1.0 - vv)
    return F.acosh(1 + 2 * euclid_norm_pow2 / (alpha * beta))
def data_iterator_imagenet(img_path, dirname_to_label_path,
                           batch_size=16, ih=128, iw=128, n_classes=1000,
                           class_id=-1,
                           noise=True,
                           normalize=lambda x: x / 128.0 - 1.0,
                           train=True, shuffle=True, rng=None):
    """Create an iterator over a directory-per-class tree of JPEG images.

    The first ``n_classes`` sub-directories of ``img_path`` (sorted by name)
    are scanned for ``*.JPEG`` files. Every image is resized to
    ``(iw, ih)``, converted to RGB, transposed to CHW and passed through
    ``normalize``.

    Args:
        img_path (str): Root directory containing one folder per class.
        dirname_to_label_path: Passed to ``create_dirname_label_maps``; only
            used when ``train`` is true.
        batch_size (int): Number of samples per batch.
        ih (int): Target image height.
        iw (int): Target image width.
        n_classes (int): Number of class directories to use.
        class_id (int): When not ``-1`` (training only), keep only images
            whose path contains the directory name of this label.
        noise (bool): Training only; add uniform noise in ``[0, 1/128)``.
        normalize (callable): Pixel normalisation. The previous code ignored
            this argument and always applied ``x / 128.0 - 1.0``, which is
            also the default here.
        train (bool): Training mode returns ``(image, label)``; validation
            mode returns ``(image, empty array)``.
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``.

    Returns:
        The object returned by ``data_iterator_simple``.
    """
    # Class directories, then the images inside them (shared by both modes).
    dir_paths = sorted(glob.glob("{}/*".format(img_path)))[0:n_classes]
    imgs = []
    for dir_path in dir_paths:
        imgs += glob.glob("{}/*.JPEG".format(dir_path))

    def _read_normalized(path):
        """Decode ``path`` into a normalised CHW array."""
        # The context manager closes the source file after resizing.
        with Image.open(path) as source:
            img = source.resize((iw, ih), Image.BILINEAR).convert("RGB")
        img = np.asarray(img).transpose((2, 0, 1))
        return normalize(img)

    if not train:
        # ------
        # Valid
        # ------
        def load_func(i):
            return _read_normalized(imgs[i]), np.array([])
    else:
        # ------
        # Train
        # ------
        dirname_to_label, label_to_dirname = create_dirname_label_maps(
            dirname_to_label_path)

        # Filter by class_id
        if class_id != -1:
            dirname = label_to_dirname[class_id]
            imgs = [path for path in imgs if dirname in path]

        def load_func(i):
            img = _read_normalized(imgs[i])
            if noise:
                img = img + np.random.uniform(
                    size=img.shape, low=0.0, high=1.0 / 128)
            # The label is derived from the parent directory name.
            dname = imgs[i].rstrip().split("/")[-2]
            return img, np.array(dirname_to_label[dname])

    return data_iterator_simple(
        load_func, len(imgs), batch_size,
        shuffle=shuffle, rng=rng, with_file_cache=False)
def data_iterator(num_examples, batch_size, img_left, img_right, img_disp,
                  train, shuffle, dataset, rng=None):
    """Create an iterator of (left image, right image, disparity) triples.

    Args:
        num_examples (int): Number of samples exposed by the iterator.
        batch_size (int): Number of samples per batch.
        img_left: Indexable collection of left-image paths.
        img_right: Indexable collection of right-image paths.
        img_disp: Indexable collection of disparity-map paths.
        train (bool): Random crop when true, fixed crop otherwise.
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        dataset (str): ``"SceneFlow"`` or ``"Kitti"``.
        rng: Forwarded to ``data_iterator_simple``. The random crop draws
            from the global ``random`` module, not from ``rng``.

    Returns:
        The object returned by ``data_iterator_simple``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError while loading
            the first sample).
    """
    if dataset not in ("SceneFlow", "Kitti"):
        raise ValueError(
            "Unsupported dataset: {!r} (expected 'SceneFlow' or 'Kitti')"
            .format(dataset))

    # Per-channel statistics, shaped to broadcast over CHW images.
    mean_imagenet = np.asarray([0.485, 0.456, 0.406]).astype(
        np.float32).reshape(3, 1, 1)
    std_imagenet = np.asarray([0.229, 0.224, 0.225]).astype(
        np.float32).reshape(3, 1, 1)

    def _normalize_chw(image):
        """HWC -> CHW, scale by 1/255, then standardise per channel."""
        image = np.rollaxis(image, 2)
        image = (image / 255).astype(np.float32)
        image -= mean_imagenet
        image /= std_imagenet
        return image

    def dataset_load_func(i):
        """Load, crop and normalise sample ``i``."""
        image_left = imread(img_left[i]).astype('float32')
        image_right = imread(img_right[i]).astype('float32')

        # The imports stay inside the loader, as in the original code, so
        # they are resolved only once samples are requested.
        if dataset == "SceneFlow":
            from main import parser
            args = parser.parse_args()
            image_disp, _ = readPFM(img_disp[i])
            image_disp = np.ascontiguousarray(image_disp, dtype=np.float32)
        elif dataset == "Kitti":
            from finetune import parser
            args = parser.parse_args()
            image_disp = imread(img_disp[i]).astype('float32')
            image_disp = image_disp.reshape(
                image_disp.shape[0], image_disp.shape[1], 1)

        if train:
            # Random crop of size (crop_height, crop_width).
            w, h = image_left.shape[1], image_left.shape[0]
            th, tw = args.crop_height, args.crop_width
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            image_left = image_left[y1:y1 + th, x1:x1 + tw]
            image_right = image_right[y1:y1 + th, x1:x1 + tw]
            if dataset == "Kitti":
                image_disp = np.ascontiguousarray(
                    image_disp, dtype=np.float32) / 256
            image_disp = image_disp[y1:y1 + th, x1:x1 + tw]
        else:
            if dataset == "SceneFlow":
                # Keep the top-left (im_height, im_width) window.
                image_left = image_left[:args.im_height, :args.im_width, :]
                image_right = image_right[:args.im_height, :args.im_width, :]
                image_disp = image_disp[:args.im_height, :args.im_width, :]
            elif dataset == "Kitti":
                # Keep the bottom-right (im_height, im_width) window.
                w, h = image_left.shape[1], image_left.shape[0]
                image_left = image_left[h - args.im_height:h,
                                        w - args.im_width:w, :]
                image_right = image_right[h - args.im_height:h,
                                          w - args.im_width:w, :]
                image_disp = image_disp[h - args.im_height:h,
                                        w - args.im_width:w, :]
                image_disp = np.ascontiguousarray(
                    image_disp, dtype=np.float32) / 256

        # Identical post-processing for both modes.
        image_left = _normalize_chw(image_left)
        image_right = _normalize_chw(image_right)
        image_disp = np.rollaxis(image_disp, 2)
        return image_left, image_right, image_disp

    return data_iterator_simple(
        dataset_load_func, num_examples, batch_size,
        shuffle=shuffle, rng=rng,
        with_file_cache=False, with_memory_cache=False)
def main():
    """
    Main script.

    Steps:
    * Get and set context.
    * Load Dataset
    * Initialize DataIterator.
    * Create Networks
        * Net for Labeled Data
        * Net for Unlabeled Data
        * Net for Test Data
    * Create Solver.
    * Training Loop.
        * Test
        * Training
            * by Labeled Data
                * Calculate Supervised Loss
            * by Unlabeled Data
                * Calculate Virtual Adversarial Noise
                * Calculate Unsupervised Loss
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input shape of one sample and the network sizes.
    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST Dataset
    from mnist_data import load_mnist, data_iterator_mnist
    images, labels = load_mnist(train=True)
    # Fixed-seed permutation; both feeders below index through it.
    rng = np.random.RandomState(706)
    inds = rng.permutation(len(images))

    def feed_labeled(i):
        # Sample ``i`` of the permuted dataset (used for the labeled set).
        j = inds[i]
        return images[j], labels[j]

    def feed_unlabeled(i):
        # Same lookup as ``feed_labeled``; the training loop discards the
        # label returned here.
        j = inds[i]
        return images[j], labels[j]

    # The labeled iterator covers the first ``n_labeled`` permuted samples,
    # the unlabeled one the first ``n_train``.
    di_l = data_iterator_simple(feed_labeled, args.n_labeled, args.batchsize_l, shuffle=True, rng=rng, with_file_cache=False)
    di_u = data_iterator_simple(feed_unlabeled, args.n_train, args.batchsize_u, shuffle=True, rng=rng, with_file_cache=False)
    di_v = data_iterator_mnist(args.batchsize_v, train=False)

    # Create networks
    # feed-forward-net building function
    def forward(x, test=False):
        return mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data
    xl = nn.Variable((args.batchsize_l, ) + shape_x, need_grad=False)
    yl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(yl, tl))

    # Net for learning unlabeled data
    xu = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=False)
    yu = forward(xu, test=False)
    # ``y1`` is an unlinked copy of the clean prediction with gradients
    # disabled, used as the reference in both distances below.
    y1 = yu.get_unlinked_variable()
    y1.need_grad = False

    noise = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=True)
    # ``r`` is ``noise`` divided by its L2 norm over axes 1-3.
    r = noise / (F.sum(noise**2, [1, 2, 3], keepdims=True))**0.5
    r.persistent = True
    # Predictions for the input perturbed by ``xi * r`` (power method) and
    # by ``eps * r`` (unsupervised loss).
    y2 = forward(xu + args.xi_for_vat * r, test=False)
    y3 = forward(xu + args.eps_for_vat * r, test=False)
    loss_k = F.mean(distance(y1, y2))
    loss_u = F.mean(distance(y1, y3))

    # Net for evaluating validation data
    xv = nn.Variable((args.batchsize_v, ) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)
    err = F.mean(F.top_n_error(hv, tv, n=1))

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    import nnabla.monitor as M
    monitor = M.Monitor(args.model_save_path)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)

    # Training Loop.
    t0 = time.time()

    for i in range(args.max_iter):

        # Validation Test
        if i % args.val_interval == 0:
            valid_error = calc_validation_error(di_v, xv, tv, err, args.val_iter)
            monitor_verr.add(i, valid_error)

        #################################
        ## Training by Labeled Data #####
        #################################

        # forward, backward and update
        xl.d, tl.d = di_l.next()
        xl.d = xl.d / 255
        solver.zero_grad()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        #################################
        ## Training by Unlabeled Data ###
        #################################

        # Calculate y without noise, only once.
        xu.d, _ = di_u.next()
        xu.d = xu.d / 255
        yu.forward(clear_buffer=True)

        ##### Calculate Adversarial Noise #####
        # Do power method iteration
        noise.d = np.random.normal(size=xu.shape).astype(np.float32)
        for k in range(args.n_iter_for_power_method):
            r.grad.zero()
            loss_k.forward(clear_no_need_grad=True)
            loss_k.backward(clear_buffer=True)
            # The gradient w.r.t. ``r`` becomes the next noise estimate.
            noise.data.copy_from(r.grad)

        ##### Calculate loss for unlabeled data #####
        # forward, backward and update
        solver.zero_grad()
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        ##### Learning rate update #####
        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(solver.learning_rate() * args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with validation dataset
    # NOTE(review): ``i`` is the last loop index here, so it is undefined
    # when ``args.max_iter`` is 0 -- confirm that case cannot occur.
    valid_error = calc_validation_error(di_v, xv, tv, err, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    # Save the model.
    parameter_file = os.path.join(args.model_save_path, 'params_%06d.h5' % args.max_iter)
    nn.save_parameters(parameter_file)
def data_iterator_sr(conf, num_samples, sample_names, tar_size, shuffle,
                     rng=None):
    """
    Data iterator for TecoGAN training
    return: makes provision for low res & high res frames in RNN segments for specified batch_size

    Args:
        conf: Configuration object; ``conf.train`` supplies ``rnn_n``,
            ``movingFirstFrame``, ``random_crop``, ``flip``, ``crop_size``
            and ``batch_size``.
        num_samples (int): Number of samples exposed by the iterator.
        sample_names: ``sample_names[f][i]`` is the image path of frame
            ``f`` of sample ``i``.
        tar_size (int): Side length of the random crop.
        shuffle (bool): Forwarded to ``data_iterator_simple``.
        rng: Forwarded to ``data_iterator_simple``. Augmentation draws from
            NumPy's global generator, not from ``rng``.
    """

    def populate_hr_data(i):
        """Read the ``rnn_n`` high-resolution frames of sample ``i``."""
        hr_data = []  # high res rgb, in range 0-1, shape any
        # Data augmentation: re-use the first frame at shifting offsets to
        # mimic camera motion.
        if conf.train.movingFirstFrame:
            lefttop_pos, range_pos, is_move = moving_decision(conf)
        for f_i in range(conf.train.rnn_n):
            img_name = sample_names[f_i][i]
            img_data = cv.imread(img_name, 3).astype(np.float32)
            img_data = img_data / 255
            if conf.train.movingFirstFrame:
                if f_i == 0:
                    img_data_0 = img_data
                    target_size = img_data.shape
                # random data augmentation -> move first frame only with 30% probability
                if not is_move < 0.7:
                    img_data = img_data_0[
                        lefttop_pos[f_i][1]:target_size[0] - range_pos[1] + lefttop_pos[f_i][1],
                        lefttop_pos[f_i][0]:target_size[1] - range_pos[0] + lefttop_pos[f_i][0],
                        :]
            hr_data.append(img_data)
        return hr_data

    def dataset_load_func(i):
        """Return ``(hr_frames, target_frames)`` for sample ``i``."""
        hr_data = populate_hr_data(i)

        # random crop each batch entry separately
        if conf.train.random_crop is True:
            cur_size = hr_data[0].shape
            offset_h = np.floor(
                np.random.uniform(0, cur_size[0] - tar_size, [])).astype(int)
            offset_w = np.floor(
                np.random.uniform(0, cur_size[1] - tar_size, [])).astype(int)
            for frame_t in range(conf.train.rnn_n):
                hr_data[frame_t] = hr_data[frame_t][
                    offset_h:offset_h + tar_size,
                    offset_w:offset_w + tar_size, :]

        # random flip: one decision shared by all frames of the sample
        if conf.train.flip is True:
            flip_decision = np.random.uniform(0, 1, []).astype(float)
            if flip_decision < 0.5:
                for frame_t in range(conf.train.rnn_n):
                    # ``np.fliplr`` returns a flipped view and does not
                    # modify its argument; the result used to be discarded,
                    # so the flip was never applied.
                    hr_data[frame_t] = np.ascontiguousarray(
                        np.fliplr(hr_data[frame_t]))

        hr_frames = hr_data
        target_frames = []
        k_w_border = int(1.5 * 3.0)
        window = conf.train.crop_size * 4
        for rnn_inst in range(conf.train.rnn_n):
            # crop out desired data
            cropped_data = hr_data[rnn_inst][
                k_w_border:k_w_border + window,
                k_w_border:k_w_border + window, :]
            target_frames.append(preprocess(cropped_data))

        return hr_frames, target_frames

    return data_iterator_simple(
        dataset_load_func, num_samples, conf.train.batch_size,
        shuffle=shuffle, rng=rng,
        with_file_cache=False, with_memory_cache=False)
def train(args):
    """Build the generator/discriminator graphs, train them, and translate
    test images.

    Args:
        args: Parsed options. The attributes read here include the model
            sizes, loss weights (``lambda_cls``, ``lambda_gp``,
            ``lambda_rec``), solver settings, dataset paths, and the
            logging/sampling steps.

    Side effects visible in this function: prints the configuration, writes
    monitor series, saves the parameters and a JSON copy of the
    configuration under ``args.model_save_path``, and calls
    ``save_results`` for sample/translation images.
    """
    # ``c_dim`` is forced to match the number of selected attributes.
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size, conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    # Input placeholders: real images plus original and target labels.
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    # ``x_hat`` is a random per-sample blend of the real and the fake image.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    # Penalise the deviation of the per-sample gradient L2 norm from 1.
    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset, image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset, image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        # Append two trailing axes of length one to match ``label_org``.
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        # (a permutation of the labels inside the current batch)
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # The generator is updated once every ``n_critic`` iterations.
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # ``x_fake_unlinked`` cuts the graph, so the backward pass is
            # continued from ``x_fake`` separately.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

        if (i + 1) % args.sample_step == 0:
            # save image.
            save_results(i, args, x_real, x_fake, label_org, label_trg, x_recon)
            if args.test_during_training:
                # translate images from test dataset.
                # NOTE: ``rand_idx`` is the permutation drawn for the
                # current training batch.
                x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                label_trg.d = test_label_ndarray[rand_idx]
                x_fake.forward(clear_no_need_grad=True)
                save_results(i, args, x_real, x_fake, label_org, label_trg, None, is_training=False)

        # Learning rates get decayed
        # (linearly towards zero over the second half of training, applied
        # every ``lr_update_step`` iterations)
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step * args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step * args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    # NOTE(review): the timestamp is computed twice (here and for the JSON
    # file name below), so the two names can differ across a minute
    # boundary -- confirm whether that matters.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org, label_trg, None, is_training=False)
def data_iterator(conf, shuffle, rng=None):
    """
    Data iterator for Zooming SloMo training.

    Args:
        conf: configuration object. This function reads
            ``conf.data.{n_frames, cache_keys, lmdb_data_gt, lmdb_data_lq,
            interval_list, border_mode, random_reverse, gt_size, use_flip,
            use_rot}`` and ``conf.train.{only_slomo, batch_size}``.
        shuffle: forwarded unchanged to ``data_iterator_simple``.
        rng: forwarded unchanged to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` over
        ``len(paths_gt)`` samples. Each sample is
        ``(img_lq_stack, img_gt_stack)``, both float32 arrays in
        (frames, channels, height, width) order with values in [0, 1].
        When ``conf.train.only_slomo`` is set, the first element is
        instead a constant placeholder (see ``load_slomo_data``).

    Raises:
        AssertionError: if ``conf.data.n_frames`` is not greater than 1, or
            if the neighbor list does not contain ``n_frames`` entries.
        ValueError: if the configured frame count and interval cannot fit
            inside the 1..7 frame range (non-border mode).
    """

    assert conf.data.n_frames > 1, 'Error: Not enough LR frames to interpolate'
    n_frames = conf.data.n_frames
    half_n_frames = n_frames // 2

    # Stored image sizes, given as (channels, height, width).
    gt_size_tuple = (3, 256, 448)
    lq_size_tuple = (3, 64, 112)
    scale = 4  # spatial ratio between GT and LQ crops

    # determine the LQ frame list (every second entry of the neighbor list)
    # N | frames
    # 1 | error
    # 3 | 0,2
    # 5 | 0,2,4
    # 7 | 0,2,4,6
    lr_index_list = [i * 2 for i in range(1 + half_n_frames)]

    # NOTE(security): pickle executes arbitrary code on load; the key cache
    # must come from a trusted source.
    with open(conf.data.cache_keys, 'rb') as cache_file:
        paths_gt = pickle.load(cache_file)

    # The LMDB environments stay open for the lifetime of the returned
    # iterator because the load functions below read from them lazily.
    gt_lmdb = lmdb.open(conf.data.lmdb_data_gt, readonly=True,
                        lock=False, readahead=False, meminit=False)
    lq_lmdb = lmdb.open(conf.data.lmdb_data_lq, readonly=True,
                        lock=False, readahead=False, meminit=False)

    center_frame_idx = random.randint(2, 6)  # 2<= index <=6

    def determine_neighbor_list(central_frame_idx):
        """
        Given the central frame index, determine the neighborhood frames.

        Frame indices are kept within 1..7.
        """
        interval = random.choice(conf.data.interval_list)

        if conf.data.border_mode:
            direction = 1  # 1: forward; 0: backward
            # The flag is read from conf.data, like every other data option
            # in this function (it was previously read from conf directly).
            if conf.data.random_reverse and random.random() < 0.5:
                direction = random.choice([0, 1])
            # Override the direction when it would run past either border.
            if central_frame_idx + interval * (n_frames - 1) > 7:
                direction = 0
            elif central_frame_idx - interval * (n_frames - 1) < 1:
                direction = 1
            # get the neighbor list
            if direction == 1:
                neighbor_list = list(
                    range(central_frame_idx,
                          central_frame_idx + interval * n_frames,
                          interval))
            else:
                neighbor_list = list(
                    range(central_frame_idx,
                          central_frame_idx - interval * n_frames,
                          -interval))
        else:
            reach = half_n_frames * interval
            # A centre frame in 2..6 can be at most 3 frames away from both
            # borders (centre 4); beyond that the re-draw loop below would
            # never terminate.
            if reach > 3:
                raise ValueError(
                    'n_frames={} with interval={} does not fit in frames '
                    '1..7'.format(n_frames, interval))
            # ensure not exceeding the borders
            while (central_frame_idx + reach > 7) or \
                    (central_frame_idx - reach < 1):
                central_frame_idx = random.randint(2, 6)
            # get the neighbor list
            neighbor_list = list(
                range(central_frame_idx - reach,
                      central_frame_idx + reach + 1,
                      interval))
            if conf.data.random_reverse and random.random() < 0.5:
                neighbor_list.reverse()

        return neighbor_list

    # NOTE: the centre frame and the neighbor list are drawn once, here, so
    # every sample produced by the returned iterator uses the same frame
    # indices.
    neighbors = determine_neighbor_list(center_frame_idx)
    lq_frames_list = [neighbors[i] for i in lr_index_list]

    assert len(neighbors) == n_frames, \
        'Wrong length of neighbor list: {}'.format(len(neighbors))

    # image read and augment functions
    def augment(img_list, flip=True, rot=True):
        """
        Apply one random flip / rotation to every image in ``img_list``.

        The random decisions are drawn once per call so that all frames of
        a sample (LQ and GT alike) receive the same transform and stay
        spatially aligned.
        """
        do_flip = flip and random.random() < 0.5
        do_rot = rot and random.random() < 0.5

        def _augment(img):
            if do_flip:  # horizontal flip
                img = img[:, ::-1, :]
            if do_rot:  # vertical flip and 90 degree rotation
                img = img[::-1, :, :]
                img = img.transpose(1, 0, 2)
            return img

        return [_augment(img) for img in img_list]

    def _read_img_from_lmdb(env, key, size):
        """
        Read an image from lmdb with key (w/ and w/o fixed size).

        size: (channels, height, width) tuple
        Returns a float32 HWC array scaled to [0, 1] with at most 3 channels.
        """
        with env.begin(write=False) as txn:
            buf = txn.get(key.encode('ascii'))
        if buf is None:
            # txn.get returns None for an absent key; fail with a clear
            # message instead of an obscure error from np.frombuffer.
            raise KeyError('Key not found in LMDB: {}'.format(key))
        channels, height, width = size
        img = np.frombuffer(buf, dtype=np.uint8).reshape(
            height, width, channels)
        img = img.astype(np.float32) / 255.
        # some images have 4 channels
        if img.shape[2] > 3:
            img = img[:, :, :3]
        return img

    def _stack_frames(frames):
        """Stack HWC frames into a contiguous (N, C, H, W) array."""
        stacked = np.stack(frames, axis=0)  # NHWC, N is the frame number
        # Reverse the channel order (BGR to RGB, per the original comments).
        stacked = stacked[:, :, :, [2, 1, 0]]
        return np.ascontiguousarray(
            np.transpose(stacked, (0, 3, 1, 2)))  # HWC to CHW

    def load_zoomingslomo_data(i):
        """
        Load the LQ and GT frame stacks for sample ``i``
        -> primary function in data loader
        """
        key = paths_gt[i]

        # get the GT images
        img_gt_l = [
            _read_img_from_lmdb(gt_lmdb, key + '_{}'.format(v), gt_size_tuple)
            for v in neighbors
        ]

        # get Low Quality images
        img_lq_l = [
            _read_img_from_lmdb(lq_lmdb, key + '_{}'.format(v), lq_size_tuple)
            for v in lq_frames_list
        ]
        _, height, width = lq_size_tuple  # LQ size

        # randomly crop; the GT crop is the LQ crop window scaled up
        gt_size = conf.data.gt_size
        lr_size = gt_size // scale
        rnd_h = random.randint(0, max(0, height - lr_size))
        rnd_w = random.randint(0, max(0, width - lr_size))
        img_lq_l = [
            v[rnd_h:rnd_h + lr_size, rnd_w:rnd_w + lr_size, :]
            for v in img_lq_l
        ]
        rnd_h_highres, rnd_w_highres = int(rnd_h * scale), int(rnd_w * scale)
        img_gt_l = [
            v[rnd_h_highres:rnd_h_highres + gt_size,
              rnd_w_highres:rnd_w_highres + gt_size, :]
            for v in img_gt_l
        ]

        # augmentation - flip, rotate (LQ and GT together, then split again)
        rlt = augment(img_lq_l + img_gt_l,
                      conf.data.use_flip, conf.data.use_rot)
        img_lq_l = rlt[0:-n_frames]
        img_gt_l = rlt[-n_frames:]

        return _stack_frames(img_lq_l), _stack_frames(img_gt_l)

    def load_slomo_data(i):
        """
        Load only the GT frame stack for sample ``i``
        -> primary function in data loader

        The first returned element is a placeholder: the channel count
        taken from the GT size tuple.
        """
        key = paths_gt[i]

        # get the GT images
        img_gt_l = [
            _read_img_from_lmdb(gt_lmdb, key + '_{}'.format(v), gt_size_tuple)
            for v in neighbors
        ]
        placeholder, height, width = gt_size_tuple  # GT size

        # randomly crop
        gt_size = conf.data.gt_size
        rnd_h = random.randint(0, max(0, height - gt_size))
        rnd_w = random.randint(0, max(0, width - gt_size))
        img_gt_l = [
            v[rnd_h:rnd_h + gt_size, rnd_w:rnd_w + gt_size, :]
            for v in img_gt_l
        ]

        # augmentation - flip, rotate
        img_gt_l = augment(img_gt_l, conf.data.use_flip, conf.data.use_rot)

        return placeholder, _stack_frames(img_gt_l)

    dataset_load_func = load_slomo_data if conf.train.only_slomo \
        else load_zoomingslomo_data

    return data_iterator_simple(dataset_load_func, len(paths_gt),
                                conf.train.batch_size, shuffle=shuffle,
                                rng=rng, with_file_cache=False,
                                with_memory_cache=False)