Example #1
from PIL import Image

# Loader, Logger, get_model_cls_by_type, and the estimate_and_save_* helpers
# are assumed to be importable from the surrounding project.


def main(args):
    loader = Loader(args.run_name)
    logger = Logger(args.run_name, create_if_exists=False)
    # restore the options and network parameters saved by the training run
    option = loader.load_option()
    layers = [int(l) for l in option["layers"].split(",")]
    params = loader.load_params()

    # rebuild the model exactly as it was trained and load its saved weights
    Model = get_model_cls_by_type(option["type"])
    model = Model(layers, option["nc"], option["omega"])
    model.update_net_params(params)

    if args.size == 0:
        # no output size requested: fall back to the original image's resolution
        orig_img_fn = loader.get_image_filename("original")
        img = Image.open(orig_img_fn)
        width, height = img.size
    else:
        width = args.size
        height = args.size

    estimate_and_save_image(model, width, height, logger)
    if option["nc"] == 1:
        estimate_and_save_gradient(model, width, height, logger)
        estimate_and_save_laplacian(model, width, height, logger)

    if option["size"] != 0:
        # PIL resize as reference
        orig_pil_img = loader.load_pil_image("original")
        resized_pil = orig_pil_img.resize((width, height))
        pil_output_name = "pil_{}x{}".format(width, height)
        logger.save_image(pil_output_name, resized_pil)
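The script above reads only args.run_name and args.size, but its entry point is not shown. The following is a minimal sketch of what a matching command line could look like; the flag names, help strings, and the default of 0 for --size are assumptions inferred from how main() uses args, not taken from the original.

import argparse

if __name__ == "__main__":
    # Hypothetical entry point: only the two attributes main() actually reads
    # (args.run_name and args.size) are exposed as flags.
    parser = argparse.ArgumentParser(
        description="Render outputs from a previously trained run")
    parser.add_argument("--run_name", type=str, required=True,
                        help="name of the training run to load")
    parser.add_argument("--size", type=int, default=0,
                        help="output width/height; 0 keeps the original image resolution")
    main(parser.parse_args())
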
Example #2
# Timer, Logger, JaxOptimizer, get_model_cls_by_type, and
# get_data_loader_cls_by_type are assumed to be importable from the
# surrounding project.
def main(args):
    # hidden layer widths, parsed from a comma-separated string
    layers = [int(l) for l in args.layers.split(",")]

    Model = get_model_cls_by_type(args.type)
    DataLoader = get_data_loader_cls_by_type(args.type)

    data_loader = DataLoader(args.file, args.nc, args.size, args.batch_size)
    model = Model(layers, args.nc, args.omega)
    optimizer = JaxOptimizer("adam", model, args.lr)

    # derive the run name from the input filename (everything before the first '.')
    name = args.file.split(".")[0]
    logger = Logger(name)
    logger.save_option(vars(args))

    # save the original image and the ground truth the model will be fit to
    gt_img = data_loader.get_ground_truth_image()
    logger.save_image("original", data_loader.original_pil_img)
    logger.save_image("gt", gt_img)

    iter_timer = Timer()
    iter_timer.start()

    # periodic logging: loss, iteration index, and average wall-clock time per
    # iteration since the last report
    def interm_callback(i, data, params):
        log = {}
        loss = model.loss_func(params, data)
        log["loss"] = float(loss)
        log["iter"] = i
        log["duration_per_iter"] = iter_timer.get_dt() / args.print_iter

        logger.save_log(log)
        print(log)

    print("Training Start")
    print(vars(args))

    total_timer = Timer()
    total_timer.start()
    last_data = None
    for _ in range(args.epoch):
        # re-create the data loader so each epoch starts a fresh pass over the data
        data_loader = DataLoader(args.file, args.nc, args.size,
                                 args.batch_size)
        for data in data_loader:
            optimizer.step(data)
            last_data = data
            if optimizer.iter_cnt % args.print_iter == 0:
                interm_callback(optimizer.iter_cnt, data,
                                optimizer.get_optimized_params())

    # flush a final log entry if training didn't end exactly on a print boundary
    if optimizer.iter_cnt % args.print_iter != 0:
        interm_callback(optimizer.iter_cnt, last_data,
                        optimizer.get_optimized_params())

    train_duration = total_timer.get_dt()
    print("Training Duration: {} sec".format(train_duration))
    logger.save_net_params(optimizer.get_optimized_params())
    logger.save_losses_plot()
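Example #2 likewise reads a number of attributes from args (file, type, layers, nc, size, batch_size, omega, lr, epoch, print_iter) without showing where they come from. A minimal sketch of a matching command-line entry point follows; the defaults and help texts are assumptions, not taken from the original.

import argparse

if __name__ == "__main__":
    # Hypothetical entry point mirroring the attributes main() reads;
    # every default below is an assumption.
    parser = argparse.ArgumentParser(
        description="Fit a coordinate-based network to an image")
    parser.add_argument("--file", type=str, required=True,
                        help="input image file")
    parser.add_argument("--type", type=str, default="siren",
                        help="model / data loader type key")
    parser.add_argument("--layers", type=str, default="256,256,256",
                        help="comma-separated hidden layer widths")
    parser.add_argument("--nc", type=int, default=3,
                        help="number of color channels")
    parser.add_argument("--size", type=int, default=0,
                        help="training resolution; 0 keeps the original size")
    parser.add_argument("--batch_size", type=int, default=1024,
                        help="number of samples per optimization step")
    parser.add_argument("--omega", type=float, default=30.0,
                        help="frequency scale passed to the model")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="learning rate for the Adam optimizer")
    parser.add_argument("--epoch", type=int, default=1,
                        help="number of passes over the data loader")
    parser.add_argument("--print_iter", type=int, default=100,
                        help="iterations between log entries")
    main(parser.parse_args())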