Example #1
0
def get_args():
    """Build the command-line parser and return the parsed arguments.

    Parsing is done in two passes. The first pass ignores unrecognised flags
    and is only used to read ``--model``; the matching entry of
    ``models.MODEL_REGISTRY`` then registers its own options on the parser
    before the final parse, which rejects anything still unknown.
    """
    cli = argparse.ArgumentParser(allow_abbrev=False)

    # (flag, add_argument keyword options), registered in exactly this order.
    option_table = [
        # Data
        ("--data-path", {"default": "data", "help": "path to data directory"}),
        ("--dataset", {"default": "bsd400", "help": "train dataset name"}),
        ("--batch-size", {"default": 128, "type": int, "help": "train batch size"}),
        # Model
        ("--model", {"default": "dncnn", "help": "model architecture"}),
        # Noise
        ("--noise_mode", {"default": "B", "help": "B - Blind S-one noise level"}),
        ("--noise_std", {"default": 25, "type": float,
                         "help": "noise level when mode is S"}),
        ("--min_noise", {"default": 0, "type": float,
                         "help": "minimum noise level when mode is B"}),
        ("--max_noise", {"default": 55, "type": float,
                         "help": "maximum noise level when mode is B"}),
        # Optimization
        ("--lr", {"default": 1e-3, "type": float, "help": "learning rate"}),
        ("--num-epochs", {"default": 100, "type": int,
                          "help": "force stop training at specified epoch"}),
        ("--valid-interval", {"default": 1, "type": int,
                              "help": "evaluate every N epochs"}),
        ("--save-interval", {"default": 1, "type": int,
                             "help": "save a checkpoint every N steps"}),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    # The logging helper hands back the parser to keep working with.
    cli = utils.add_logging_arguments(cli)

    # First pass: model-specific flags are not registered yet, so tolerate them.
    first_pass, _unknown = cli.parse_known_args()
    models.MODEL_REGISTRY[first_pass.model].add_args(cli)

    # Second pass: every option is now known to the parser.
    return cli.parse_args()
Example #2
0
def get_args():
    """Assemble the command-line parser and return the parsed namespace.

    The command line is parsed twice: a first, lenient pass discovers the
    values of ``--model`` and ``--optimizer`` so the corresponding registry
    entries can add their own options, then a final pass parses everything.
    """
    arg_parser = argparse.ArgumentParser(allow_abbrev=False)
    add = arg_parser.add_argument

    # Data
    add("--data-path", help="path to data directory", default="data")
    add("--dataset", help="train dataset name", default="split_cifar10")
    add("--batch-size", help="train batch size", type=int, default=10)

    # Model
    add("--model", help="model architecture", default="resnet")

    # Optimization
    add("--optimizer", help="optimizer", default="adam")
    add("--lr", help="learning rate", type=float, default=2e-4)
    add("--num-repeats-per-task", help="number of repeats per task",
        type=int, default=1)
    add("--num-epochs", help="force stop training at specified epoch",
        type=int, default=1)

    # The logging helper returns the parser that is used from here on.
    arg_parser = utils.add_logging_arguments(arg_parser)

    # Lenient pass: model/optimizer specific flags are still unregistered.
    preliminary, _rest = arg_parser.parse_known_args()
    model_entry = models.MODEL_REGISTRY[preliminary.model]
    model_entry.add_args(arg_parser)
    optimizer_entry = optim.OPTIMIZER_REGISTRY[preliminary.optimizer]
    optimizer_entry.add_args(arg_parser)

    # Final pass with the complete set of options.
    return arg_parser.parse_args()
Example #3
0
import torch.nn as nn
import torch.utils.data as td

import utils
import models
import models.jlu
import dataset
import inference
from train import embeds

# Each entry gets its own required ``--<mode>-vocab`` option below.
MODES = ["word", "label", "intent"]

# Command-line definition for this script. Arguments prefixed with "@" on the
# command line are treated by argparse as files to read further arguments from.
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")

group = parser.add_argument_group("Logging Options")
# NOTE(review): presumably registers the shared logging flags on this group,
# with "predict" identifying the kind of run -- confirm against utils.
utils.add_logging_arguments(group, "predict")
group.add_argument("--argparse-filename", type=str, default="predict.args")
group.add_argument("--show-progress", action="store_true", default=False)

group = parser.add_argument_group("Data Options")
group.add_argument("--word-path", type=str, required=True)
# One required vocabulary option per entry in MODES.
for mode in MODES:
    group.add_argument(f"--{mode}-vocab", type=str, required=True)
group.add_argument("--data-workers", type=int, default=8)
group.add_argument("--seed", type=int, default=None)
# Special-token strings; the defaults are "<unk>", "<eos>" and "<bos>".
group.add_argument("--unk", type=str, default="<unk>")
group.add_argument("--eos", type=str, default="<eos>")
group.add_argument("--bos", type=str, default="<bos>")

group = parser.add_argument_group("Prediction Options")
group.add_argument("--ckpt-path", type=str, required=True)
Example #4
0
import numpy as np
import torch
import torch.nn as nn
import torch.optim as op
import torch.utils.data as td

import utils
import model
import dataset
from . import embeds

# Command-line definition for this script. Arguments prefixed with "@" on the
# command line are treated by argparse as files to read further arguments from.
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")

group = parser.add_argument_group("Logging Options")
# NOTE(review): presumably registers the shared logging flags on this group,
# with "train" identifying the kind of run -- confirm against utils.
utils.add_logging_arguments(group, "train")
group.add_argument("--argparse-filename",
                   type=str,
                   default="train-argparse.yml")
group.add_argument("--show-progress", action="store_true", default=False)

group = parser.add_argument_group("Model Parameters")
# The model module registers its own options on this group; which options
# those are is not visible from here.
model.add_arguments(group)

group = parser.add_argument_group("Data Options")
# --data-path is the only required option in this group.
group.add_argument("--data-path", type=str, required=True)
group.add_argument("--vocab", type=str, default=None)
group.add_argument("--vocab-limit", type=int, default=None)
group.add_argument("--data-workers", type=int, default=8)
# Boolean switches: off unless the flag is given.
group.add_argument("--pin-memory", action="store_true", default=False)
group.add_argument("--shuffle", action="store_true", default=False)
Example #5
0
def get_args():
    """Build the command-line parser, parse the arguments and return them.

    The command line is parsed twice: a first, lenient pass is used only to
    read ``--model`` so that the selected entry of ``models.MODEL_REGISTRY``
    can register its own options, then a final pass parses everything.

    The boolean options ``--mask_consecutive`` and ``--wtd_loss`` accept
    ``true/false``, ``t/f``, ``yes/no``, ``y/n`` and ``1/0`` (case-insensitive).

    Returns:
        argparse.Namespace: the parsed arguments. Their ``vars()`` view is
        also printed to stdout.
    """

    def _str2bool(value):
        """Convert a command-line token to a bool, rejecting anything else."""
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        if token in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser(allow_abbrev=False)

    # Add data arguments
    parser.add_argument("--data-path",
                        default="../lidar_data/32_32/",
                        help="path to data directory")
    parser.add_argument("--dataset",
                        default="masked_pwc",
                        help="masked training data for generator")
    parser.add_argument("--batch-size",
                        default=32,
                        type=int,
                        help="train batch size")
    parser.add_argument("--num_scan_lines",
                        default=1000,
                        type=int,
                        help="number of scan lines used to generate data")
    parser.add_argument("--seq_len",
                        default=32,
                        type=int,
                        help="side length of the patches")
    parser.add_argument(
        "--scan_line_gap_break",
        default=7000,
        type=int,
        help="threshold over which scan_gap indicates a new scan line")
    parser.add_argument("--min_pt_count",
                        default=1700,
                        type=int,
                        help="in a scan line, otherwise line not used")
    parser.add_argument("--max_pt_count",
                        default=2000,
                        type=int,
                        help="in a scan line, otherwise line not used")
    parser.add_argument("--mask_pts_per_seq",
                        default=5,
                        type=int,
                        help="Sqrt(masked pts), side of the missing patch")
    # Without an explicit converter any supplied value (even "False") would be
    # stored as a non-empty, hence truthy, string.
    parser.add_argument("--mask_consecutive",
                        default=True,
                        type=_str2bool,
                        help="True if pts are in a consecutive patch")
    parser.add_argument(
        "--stride_inline",
        default=5,
        type=int,
        help="The number of pts skipped between patches within the scan line")
    parser.add_argument(
        "--stride_across_lines",
        default=3,
        type=int,
        help="The number of pts skipped between patches across the scan line")

    # Add model arguments
    parser.add_argument("--model", default="lidar_unet2d", help="Model used")
    parser.add_argument(
        "--wtd_loss",
        default=True,
        type=_str2bool,
        help="True if MSELoss should be weighted by xyz distances")

    # Add optimization arguments
    parser.add_argument("--lr",
                        default=.005,
                        type=float,
                        help="learning rate for generator")
    parser.add_argument("--weight_decay",
                        default=0.,
                        type=float,
                        help="weight decay for optimizer")

    # Logistics arguments
    parser.add_argument("--num-epochs",
                        default=10,
                        type=int,
                        help="force stop training at specified epoch")
    parser.add_argument("--valid-interval",
                        default=1,
                        type=int,
                        help="evaluate every N epochs")
    parser.add_argument("--save-interval",
                        default=1,
                        type=int,
                        help="save a checkpoint every N steps")
    parser.add_argument("--output_dir",
                        default='../lidar_experiments/2d',
                        help="where the model and logs are saved.")
    parser.add_argument(
        "--MODEL_PATH_LOAD",
        default=
        '../lidar_experiments/2d/lidar_unet2d/lidar-unet2d-Nov-08-16:29:49/checkpoints/checkpoint_best.pt',
        help="where to load an existing model from")

    # Parse twice as model arguments are not known the first time
    parser = utils.add_logging_arguments(parser)
    args, _ = parser.parse_known_args()
    models.MODEL_REGISTRY[args.model].add_args(parser)
    args = parser.parse_args()

    # vars() without an argument would dump this function's locals instead of
    # the parsed options announced by the label.
    print("vars(args)", vars(args))
    return args
Example #6
0
import argparse
import collections

import torch
import torch.nn as nn
import torch.utils.data as td

import utils
import model
import dataset


# Command-line definition for this script. Arguments prefixed with "@" on the
# command line are treated by argparse as files to read further arguments from.
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")

group = parser.add_argument_group("Logging Options")
# NOTE(review): presumably registers the shared logging flags on this group,
# with "generate" identifying the kind of run -- confirm against utils.
utils.add_logging_arguments(group, "generate")
group.add_argument("--argparse-filename",
                   type=str, default="generate-argparse.yml")
group.add_argument("--samples-filename", type=str, default="samples.txt")
group.add_argument("--neighbors-filename", type=str, default="neighbors.txt")
group.add_argument("--show-progress", action="store_true", default=False)

group = parser.add_argument_group("Data Options")
# --data-path is optional here, while --vocab is required.
group.add_argument("--data-path", type=str, default=None)
group.add_argument("--vocab", type=str, required=True)
group.add_argument("--data-workers", type=int, default=8)
group.add_argument("--seed", type=int, default=None)
# Special-token strings; the defaults are "<unk>", "<eos>" and "<bos>".
group.add_argument("--unk", type=str, default="<unk>")
group.add_argument("--eos", type=str, default="<eos>")
group.add_argument("--bos", type=str, default="<bos>")