def main(args):
    """Rebuild a model from a saved run and save its estimated outputs.

    Loads the option dict and network parameters stored for
    ``args.run_name``, instantiates the model class selected by
    ``option["type"]``, and passes it to ``estimate_and_save_image``.
    When ``option["nc"] == 1`` the gradient and Laplacian helpers are
    called as well. When ``option["size"] != 0`` a PIL-resized copy of the
    original image is saved alongside as a reference.

    Args:
        args: Parsed command-line namespace. Reads ``run_name`` and
            ``size``. ``size == 0`` means "use the width and height of the
            saved original image"; any other value produces a square
            ``size`` x ``size`` output.
    """
    loader = Loader(args.run_name)
    logger = Logger(args.run_name, create_if_exists=False)

    option = loader.load_option()
    layers = [int(width_str) for width_str in option["layers"].split(",")]

    params = loader.load_params()

    Model = get_model_cls_by_type(option["type"])
    model = Model(layers, option["nc"], option["omega"])
    model.update_net_params(params)

    if args.size == 0:
        orig_img_fn = loader.get_image_filename("original")
        # Only the dimensions are needed; the context manager makes sure the
        # underlying file handle is released right away.
        with Image.open(orig_img_fn) as img:
            width = img.width
            height = img.height
    else:
        width = args.size
        height = args.size

    estimate_and_save_image(model, width, height, logger)
    if option["nc"] == 1:
        estimate_and_save_gradient(model, width, height, logger)
        estimate_and_save_laplacian(model, width, height, logger)

    # NOTE(review): this tests the size stored in the run's options, not
    # ``args.size`` -- confirm that is the intended trigger for the reference.
    if option["size"] != 0:
        # PIL resize as reference
        orig_pil_img = loader.load_pil_image("original")
        resized_pil = orig_pil_img.resize((width, height))
        pil_output_name = "pil_{}x{}".format(width, height)
        logger.save_image(pil_output_name, resized_pil)
def main(args):
    """Train a model on a single input file and save the run's artifacts.

    Builds the model and data-loader classes selected by ``args.type``,
    optimizes with ``JaxOptimizer("adam", ...)``, logs the loss every
    ``args.print_iter`` optimizer iterations (plus once at the end if the
    last iteration did not land on that boundary), and finally saves the
    optimized network parameters and the loss plot through the logger.

    Args:
        args: Parsed command-line namespace. Reads ``layers`` (comma-separated
            integers), ``type``, ``file``, ``nc``, ``size``, ``batch_size``,
            ``omega``, ``lr``, ``print_iter`` and ``epoch``.
    """
    layers = [int(width_str) for width_str in args.layers.split(",")]

    Model = get_model_cls_by_type(args.type)
    DataLoader = get_data_loader_cls_by_type(args.type)

    data_loader = DataLoader(args.file, args.nc, args.size, args.batch_size)
    model = Model(layers, args.nc, args.omega)
    optimizer = JaxOptimizer("adam", model, args.lr)

    # NOTE(review): everything after the first "." is dropped, so a path such
    # as "./img.png" yields an empty name -- confirm callers pass bare names.
    name = args.file.split(".")[0]
    logger = Logger(name)
    logger.save_option(vars(args))

    gt_img = data_loader.get_ground_truth_image()
    logger.save_image("original", data_loader.original_pil_img)
    logger.save_image("gt", gt_img)

    iter_timer = Timer()
    iter_timer.start()

    def interm_callback(i, data, params):
        """Evaluate the loss on ``data`` and record/print one log entry."""
        log = {}
        loss = model.loss_func(params, data)
        log["loss"] = float(loss)
        log["iter"] = i
        # NOTE(review): always divides by print_iter, including for the final
        # partial interval -- confirm against Timer.get_dt semantics.
        log["duration_per_iter"] = iter_timer.get_dt() / args.print_iter
        logger.save_log(log)
        print(log)

    print("Training Start")
    print(vars(args))

    total_timer = Timer()
    total_timer.start()

    last_data = None
    for _ in range(args.epoch):
        # A fresh loader is constructed for every epoch.
        data_loader = DataLoader(args.file, args.nc, args.size, args.batch_size)
        for data in data_loader:
            optimizer.step(data)
            last_data = data
            if optimizer.iter_cnt % args.print_iter == 0:
                interm_callback(
                    optimizer.iter_cnt, data, optimizer.get_optimized_params()
                )

    # Log the trailing partial interval. Use the explicitly tracked batch
    # rather than the inner loop variable, which is unbound if no batch was
    # ever produced.
    if last_data is not None and optimizer.iter_cnt % args.print_iter != 0:
        interm_callback(
            optimizer.iter_cnt, last_data, optimizer.get_optimized_params()
        )

    train_duration = total_timer.get_dt()
    print("Training Duration: {} sec".format(train_duration))

    logger.save_net_params(optimizer.get_optimized_params())
    logger.save_losses_plot()