Ejemplo n.º 1
0
    type=str,
    required=True,
    help='path to your folder of images and text for learning the DALL-E')

parser.add_argument(
    '--truncate_captions',
    dest='truncate_captions',
    help=
    'Captions passed in which exceed the max token length will be truncated if this is set.'
)

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--fp16', action='store_true')

parser = deepspeed_utils.wrap_arg_parser(parser)

args = parser.parse_args()

# helpers


def exists(val):
    return val is not None


# constants

VAE_PATH = args.vae_path
DALLE_PATH = args.dalle_path
RESUME = exists(DALLE_PATH)
Ejemplo n.º 2
0
def get_parser():
    parser = argparse.ArgumentParser()

    # parser.add_argument('--image_folder', type = str, required = True,
    #                     help='path to your folder of images for learning the discrete VAE and its codebook')

    ## Data/Model/Output
    parser.add_argument('--image_folder',
                        type=str,
                        default='../../dataset/val2017')
    parser.add_argument('--model_dir', type=str, default='../model/vae')
    #     parser.add_argument('--output_dir', type=str, default='../output/vae')
    parser.add_argument('--image_size',
                        type=int,
                        required=False,
                        default=128,
                        help='image size')
    ## Hyperparameter
    parser.add_argument('--EPOCHS', type=int, default=20)
    parser.add_argument('--BATCH_SIZE', type=int, default=8)
    parser.add_argument('--LEARNING_RATE', type=float, default=1e-3)
    parser.add_argument('--LR_DECAY_RATE', type=float, default=0.98)

    parser.add_argument('--NUM_TOKENS', type=int, default=8192)
    parser.add_argument('--NUM_LAYERS', type=int, default=2)
    parser.add_argument('--NUM_RESNET_BLOCKS', type=int, default=2)
    parser.add_argument('--SMOOTH_L1_LOSS', type=bool, default=False)
    parser.add_argument('--EMB_DIM', type=int, default=512)
    parser.add_argument('--HID_DIM', type=int, default=256)
    parser.add_argument('--KL_LOSS_WEIGHT', type=int, default=0)

    parser.add_argument('--STARTING_TEMP', type=float, default=1.)
    parser.add_argument('--TEMP_MIN', type=float, default=0.5)
    parser.add_argument('--ANNEAL_RATE', type=float, default=1e-6)

    parser.add_argument('--NUM_IMAGES_SAVE', type=int, default=4)
    parser.add_argument('--model_parallel', type=bool, default=False)
    parser.add_argument('--num_worker', type=int, default=4)

    # Setting for Model Parallel
    parser.add_argument("--num_microbatches", type=int, default=4)
    parser.add_argument("--num_partitions", type=int, default=2)
    parser.add_argument("--horovod", type=bool, default=False)
    parser.add_argument("--ddp", type=bool, default=True)
    parser.add_argument("--amp", type=int, default=0)  ## if amp is 1, true
    parser.add_argument("--pipeline", type=str, default="interleaved")
    parser.add_argument("--optimize", type=str, default="speed")
    parser.add_argument("--placement_strategy", type=str, default="spread")
    parser.add_argument("--assert-losses", type=bool, default=False)

    parser.add_argument('--mp_parameters', type=str, default='')
    parser.add_argument("--partial-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--full-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--save-full-model",
                        action="store_true",
                        default=False,
                        help="For Saving the current Model")
    parser.add_argument(
        "--save-partial-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument('--hosts', type=list, default=['algo-1'])
    parser.add_argument('--num-gpus', type=int, default=4)
    parser.add_argument('--channels-last', type=bool, default=True)
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=5,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser = deepspeed_utils.wrap_arg_parser(parser)

    return parser