def _generate_sequential_enc_asset(file, model, image, precision=2):
    """Record the per-child encodings of ``model`` for ``image`` and store them.

    The model is switched to eval mode, then ``image`` is fed through each
    child returned by ``model.named_children()`` in order. Each child's output
    becomes the next child's input.

    Args:
        file: Destination handed unchanged to ``store_asset``.
        model: Module whose children are applied one after another.
        image: Tensor fed to the first child.
        precision: Forwarded to ``pystiche.TensorKey`` for every encoding.

    Stores:
        input  -> {"image": clone of the original image}
        params -> {"precision": precision}
        output -> {"enc_keys": {child_name: TensorKey(child_output)}}
    """
    model.eval()
    # Snapshot the original image before any child touches it.
    original_image = image.clone()

    keys_by_layer = {}
    encoding = image
    for name, child in model.named_children():
        encoding = child(encoding)
        # NOTE(review): TensorKey presumably fingerprints the tensor at the
        # given precision for later comparison — confirm in pystiche.
        keys_by_layer[name] = pystiche.TensorKey(encoding, precision=precision)

    store_asset(
        {"image": original_image},
        {"precision": precision},
        {"enc_keys": keys_by_layer},
        file,
    )
def _generate_default_transformer_epoch_optim_loop_asset(
    file,
    image_loader,
    transformer,
    criterion,
    criterion_update_fn,
    epochs,
    get_lr_scheduler,
    get_optimizer=None,
):
    """Train ``transformer`` for several epochs and store the result as an asset.

    For each of ``epochs`` passes over ``image_loader``:

    * ``criterion_update_fn(target_image, criterion)`` is called for each batch.
    * One optimizer step is taken on ``criterion(transformer(target_image))``.
    * After the full pass, the LR scheduler is stepped once.

    Args:
        file: Destination handed unchanged to ``store_asset``.
        image_loader: Iterable of target-image batches.
        transformer: Module trained in place.
        criterion: Loss module evaluated on the transformer output.
        criterion_update_fn: Called with ``(target_image, criterion)`` before
            each optimizer step.
        epochs: Number of passes over ``image_loader``.
        get_lr_scheduler: Factory mapping the optimizer to an LR scheduler.
        get_optimizer: Factory mapping the transformer to an optimizer.
            Defaults to ``default_transformer_optimizer``.

    Stores:
        input  -> loader, untrained copy of the transformer, criterion,
                  criterion_update_fn, epochs
        params -> the resolved ``get_optimizer`` and ``get_lr_scheduler``
        output -> {"transformer": the trained transformer}
    """
    # Deep-copy before training so the stored input reflects the initial state.
    input_transformer = deepcopy(transformer)
    if get_optimizer is None:
        get_optimizer = default_transformer_optimizer
    optimizer = get_optimizer(transformer)
    lr_scheduler = get_lr_scheduler(optimizer)

    for _ in range(epochs):
        for target_image in image_loader:
            criterion_update_fn(target_image, criterion)

            # The forward pass lives inside the closure. Optimizers that
            # re-evaluate the closure (e.g. LBFGS) then get a fresh graph
            # built from the current parameters on every call. Single-call
            # optimizers produce the same result as computing it outside.
            def closure():
                optimizer.zero_grad()
                output_image = transformer(target_image)
                loss = criterion(output_image)
                loss.backward()
                return loss

            optimizer.step(closure)
        lr_scheduler.step()

    # NOTE(review): ``criterion`` is stored after training; if
    # criterion_update_fn mutates it, the stored object reflects the last update.
    inputs = {
        "image_loader": image_loader,
        "transformer": input_transformer,
        "criterion": criterion,
        "criterion_update_fn": criterion_update_fn,
        "epochs": epochs,
    }
    params = {
        "get_optimizer": get_optimizer,
        "get_lr_scheduler": get_lr_scheduler,
    }
    output = {"transformer": transformer}
    store_asset(inputs, params, output, file)
def _generate_default_image_pyramid_optim_loop_asset(
    file,
    input_image,
    criterion,
    pyramid,
    get_optimizer=None,
    preprocessor=None,
    postprocessor=None,
):
    """Optimize an image level by level over ``pyramid`` and store the result.

    At each level, the current image is:

    1. resized with ``level.resize_image`` (without gradient tracking),
    2. preprocessed,
    3. optimized for one optimizer step per item yielded by iterating the level,
    4. detached and postprocessed.

    Args:
        file: Destination handed unchanged to ``store_asset``.
        input_image: Starting image. It is cloned, never modified.
        criterion: Loss module evaluated on the image being optimized.
        pyramid: Iterable of levels providing ``resize_image`` and iteration.
        get_optimizer: Factory mapping the image to an optimizer. Defaults
            to ``default_image_optimizer``.
        preprocessor: Optional callable applied before optimizing each level.
        postprocessor: Optional callable applied after optimizing each level.

    Stores:
        input  -> {"image", "criterion", "pyramid"}
        params -> resolved ``get_optimizer``, ``preprocessor``, ``postprocessor``
        output -> {"image": final optimized image}
    """
    if get_optimizer is None:
        get_optimizer = default_image_optimizer

    aspect_ratio = extract_aspect_ratio(input_image)
    output_image = input_image.clone()
    for level in pyramid:
        with torch.no_grad():
            output_image = level.resize_image(output_image, aspect_ratio=aspect_ratio)

        # Pre- and postprocessing are paired per level, so the image passed to
        # the next level's resize is back in the unprocessed space. The old
        # code preprocessed every level but postprocessed only once at the
        # end, which stacked the preprocessing for every level after the first.
        if preprocessor is not None:
            output_image = preprocessor(output_image)

        optimizer = get_optimizer(output_image)
        for _ in level:

            def closure():
                optimizer.zero_grad()
                loss = criterion(output_image)
                loss.backward()
                return loss

            optimizer.step(closure)

        output_image = output_image.detach()
        if postprocessor is not None:
            output_image = postprocessor(output_image)

    inputs = {"image": input_image, "criterion": criterion, "pyramid": pyramid}
    params = {
        "get_optimizer": get_optimizer,
        "preprocessor": preprocessor,
        "postprocessor": postprocessor,
    }
    output = {"image": output_image}
    store_asset(inputs, params, output, file)
def _generate_default_transformer_optim_loop_asset(
    file,
    image_loader,
    device,
    transformer,
    criterion,
    criterion_update_fn,
    get_optimizer=None,
):
    """Train ``transformer`` for one pass over ``image_loader`` and store the result.

    Each batch is:

    * moved with ``.to(device)``,
    * passed to ``criterion_update_fn(target_image, criterion)``,
    * used for one optimizer step on ``criterion(transformer(target_image))``.

    Args:
        file: Destination handed unchanged to ``store_asset``.
        image_loader: Iterable of target-image batches.
        device: Target passed to ``target_image.to``.
        transformer: Module trained in place.
        criterion: Loss module evaluated on the transformer output.
        criterion_update_fn: Called with ``(target_image, criterion)`` before
            each optimizer step.
        get_optimizer: Factory mapping the transformer to an optimizer.
            Defaults to ``default_transformer_optimizer``.

    Stores:
        input  -> loader, device, untrained copy of the transformer, criterion,
                  criterion_update_fn
        params -> the resolved ``get_optimizer``
        output -> {"transformer": the trained transformer}
    """
    # Deep-copy before training so the stored input reflects the initial state.
    input_transformer = deepcopy(transformer)
    if get_optimizer is None:
        get_optimizer = default_transformer_optimizer
    optimizer = get_optimizer(transformer)

    for target_image in image_loader:
        target_image = target_image.to(device)
        criterion_update_fn(target_image, criterion)

        # The forward pass lives inside the closure. Optimizers that
        # re-evaluate the closure (e.g. LBFGS) then get a fresh graph built
        # from the current parameters on every call.
        def closure():
            optimizer.zero_grad()
            output_image = transformer(target_image)
            loss = criterion(output_image)
            loss.backward()
            return loss

        optimizer.step(closure)

    inputs = {
        "image_loader": image_loader,
        "device": device,
        "transformer": input_transformer,
        "criterion": criterion,
        "criterion_update_fn": criterion_update_fn,
    }
    params = {
        "get_optimizer": get_optimizer,
    }
    output = {"transformer": transformer}
    store_asset(inputs, params, output, file)
def _generate_default_image_optim_loop_asset(
    file,
    input_image,
    criterion,
    get_optimizer=None,
    num_steps=500,
    preprocessor=None,
    postprocessor=None,
):
    """Optimize a copy of ``input_image`` for a fixed number of steps and store it.

    The pipeline is: clone → optional preprocess → ``num_steps`` optimizer
    steps on ``criterion(image)`` → detach → optional postprocess.

    Args:
        file: Destination handed unchanged to ``store_asset``.
        input_image: Starting image. It is cloned, never modified.
        criterion: Loss module evaluated on the image being optimized.
        get_optimizer: Factory mapping the image to an optimizer. Defaults
            to ``default_image_optimizer``.
        num_steps: Number of ``optimizer.step`` calls.
        preprocessor: Optional callable applied before optimization.
        postprocessor: Optional callable applied after optimization.

    Stores:
        input  -> {"image", "criterion"}
        params -> resolved optimizer factory, num_steps, pre/postprocessor
        output -> {"image": optimized image}
    """
    make_optimizer = default_image_optimizer if get_optimizer is None else get_optimizer

    image = input_image.clone()
    if preprocessor is not None:
        image = preprocessor(image)
    optimizer = make_optimizer(image)

    def closure():
        optimizer.zero_grad()
        loss = criterion(image)
        loss.backward()
        return loss

    for _ in range(num_steps):
        optimizer.step(closure)

    image = image.detach()
    if postprocessor is not None:
        image = postprocessor(image)

    store_asset(
        {"image": input_image, "criterion": criterion},
        {
            "get_optimizer": make_optimizer,
            "num_steps": num_steps,
            "preprocessor": preprocessor,
            "postprocessor": postprocessor,
        },
        {"image": image},
        file,
    )