示例#1
0
def arg_setting():
    """Build the DALL-E training command line and parse ``sys.argv``.

    Registers data/tokenizer options, lets
    ``distributed_utils.wrap_arg_parser`` add its own options, then adds
    training, model, host/backend and SageMaker model-parallel settings.

    Boolean options are parsed leniently: ``true``/``t``/``yes``/``1``
    (case-insensitive) mean True and any other string means False, so they
    can be passed as ``--flag True`` / ``--flag False``. A bare ``--flag``
    with no value also means True.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    import json

    def str2bool(value):
        # Same truthy set as the former per-option lambdas; any other string
        # is False rather than an error.
        return value.lower() in ('true', 't', 'yes', '1')

    def add_bool(container, flag, default=False, **kwargs):
        # ``type=bool`` turns the string "False" into True, so every boolean
        # option goes through str2bool instead. nargs='?' with const=True
        # lets a bare ``--flag`` behave like store_true.
        container.add_argument(flag,
                               type=str2bool,
                               nargs='?',
                               const=True,
                               default=default,
                               **kwargs)

    def parse_hosts(value):
        # ``type=list`` split a string into single characters
        # ("algo-1" -> ['a', 'l', 'g', ...]). Accept a JSON list (the format
        # SageMaker uses for SM_HOSTS) or a comma-separated string instead.
        try:
            hosts = json.loads(value)
        except ValueError:
            hosts = None
        if isinstance(hosts, list):
            return [str(host) for host in hosts]
        return [host.strip() for host in value.split(',') if host.strip()]

    parser = argparse.ArgumentParser()

    # At most one of --vae_path / --dalle_path may be given.
    group = parser.add_mutually_exclusive_group(required=False)

    group.add_argument('--vae_path',
                       type=str,
                       help='path to your trained discrete VAE')

    group.add_argument('--dalle_path',
                       type=str,
                       help='path to your partially trained DALL-E')

    parser.add_argument(
        '--image_text_folder',
        type=str,
        default='/opt/ml/input/data/training',
        help='path to your folder of images and text for learning the DALL-E')

    parser.add_argument(
        '--truncate_captions',
        dest='truncate_captions',
        action='store_true',
        help=
        'Captions passed in which exceed the max token length will be truncated if this is set.'
    )

    parser.add_argument('--random_resize_crop_lower_ratio',
                        dest='resize_ratio',
                        type=float,
                        default=0.75,
                        help='Random resized crop lower ratio')

    parser.add_argument('--chinese', dest='chinese', action='store_true')

    add_bool(parser, '--taming', dest='taming')

    add_bool(parser, '--hug', dest='hug')

    parser.add_argument('--bpe_path',
                        type=str,
                        help='path to your huggingface BPE json file')

    add_bool(
        parser,
        '--fp16',
        help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')

    parser.add_argument(
        '--wandb_name',
        default='dalle_train_transformer',
        help=
        'Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`'
    )

    # distributed_utils may register extra options; the parser it returns
    # replaces the local one.
    parser = distributed_utils.wrap_arg_parser(parser)

    train_group = parser.add_argument_group('Training settings')

    train_group.add_argument('--epochs',
                             default=20,
                             type=int,
                             help='Number of epochs')

    train_group.add_argument('--batch_size',
                             default=4,
                             type=int,
                             help='Batch size')

    train_group.add_argument('--learning_rate',
                             default=3e-4,
                             type=float,
                             help='Learning rate')

    train_group.add_argument('--clip_grad_norm',
                             default=0.5,
                             type=float,
                             help='Clip gradient norm')

    train_group.add_argument('--lr_decay', dest='lr_decay', action='store_true')

    model_group = parser.add_argument_group('Model settings')

    model_group.add_argument('--dim',
                             default=512,
                             type=int,
                             help='Model dimension')

    model_group.add_argument('--heads',
                             default=8,
                             type=int,
                             help='Model number of heads')

    model_group.add_argument('--dim_head',
                             default=64,
                             type=int,
                             help='Model head dimension')

    add_bool(model_group, '--reversible', dest='reversible')

    model_group.add_argument('--loss_img_weight',
                             default=7,
                             type=int,
                             help='Image loss weight')

    model_group.add_argument('--text_seq_len',
                             default=256,
                             type=int,
                             help='Text sequence length')

    model_group.add_argument('--depth', default=2, type=int, help='Model depth')

    model_group.add_argument(
        '--attn_types',
        default='full',
        type=str,
        help=
        'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.'
    )

    parser.add_argument('--num_worker', type=int, default=4)
    parser.add_argument('--model_dir', type=str, default='model/dalle/')

    parser.add_argument('--num-gpus', type=int, default=8)

    parser.add_argument('--hosts', type=parse_hosts, default=[])
    parser.add_argument('--current-host', type=str, default="")
    parser.add_argument(
        '--backend',
        type=str,
        default='nccl',
        help=
        'backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)'
    )

    # Settings for SageMaker model parallel.
    add_bool(parser, "--sagemakermp")
    parser.add_argument("--num_microbatches", type=int, default=4)
    parser.add_argument("--num_partitions", type=int, default=2)
    add_bool(parser, "--horovod")
    add_bool(parser, "--ddp", default=True)
    parser.add_argument("--amp", type=int, default=0)  # integer switch: 1 means on
    parser.add_argument("--pipeline", type=str, default="interleaved")
    parser.add_argument("--optimize", type=str, default="speed")
    parser.add_argument("--placement_strategy", type=str, default="spread")
    add_bool(parser, "--assert-losses")

    parser.add_argument('--mp_parameters', type=str, default='')
    parser.add_argument("--partial-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--full-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--save-full-model",
                        action="store_true",
                        default=False,
                        help="For Saving the current Model")
    parser.add_argument(
        "--save-partial-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )

    args = parser.parse_args()
    return args
示例#2
0
                    dest='resize_ratio',
                    type=float,
                    default=0.75,
                    help='Random resized crop lower ratio')

# Boolean switches (store_true): True when the flag is present, else False.
parser.add_argument('--chinese', dest='chinese', action='store_true')

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--bpe_path',
                    type=str,
                    help='path to your huggingface BPE json file')

# NOTE(review): undocumented flag; presumably enables 16-bit precision --
# confirm against the training code.
parser.add_argument('--fp16', action='store_true')

# distributed_utils may register extra options; the parser it returns
# replaces the local one before parsing.
parser = distributed_utils.wrap_arg_parser(parser)

# Parsed at module level, so importing this script reads sys.argv.
args = parser.parse_args()

# helpers


def exists(val):
    """Return True unless *val* is the ``None`` singleton.

    Falsy-but-present values such as ``0``, ``''`` or ``[]`` still count
    as existing.
    """
    if val is None:
        return False
    return True


# constants

# Checkpoint paths straight from the CLI; resume whenever a DALL-E path is set.
VAE_PATH, DALLE_PATH = args.vae_path, args.dalle_path
RESUME = exists(DALLE_PATH)
示例#3
0
def arg_setting():
    """Build the DALL-E training command line and parse ``sys.argv``.

    Registers data/tokenizer/output options, lets
    ``distributed_utils.wrap_arg_parser`` add its own options, then adds
    training, model, host/backend and S3 output settings.

    Boolean options other than the store_true switches are parsed
    leniently: ``true``/``t``/``yes``/``1`` (case-insensitive) mean True and
    any other string means False, so they can be passed as ``--flag True`` /
    ``--flag False``. A bare ``--flag`` with no value also means True.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    import json

    def str2bool(value):
        # Same truthy set as the former per-option lambdas; any other string
        # is False rather than an error.
        return value.lower() in ('true', 't', 'yes', '1')

    def add_bool(container, flag, default=False, **kwargs):
        # nargs='?' with const=True lets a bare ``--flag`` behave like
        # store_true while still accepting an explicit value.
        container.add_argument(flag,
                               type=str2bool,
                               nargs='?',
                               const=True,
                               default=default,
                               **kwargs)

    def parse_hosts(value):
        # ``type=list`` split a string into single characters
        # ("algo-1" -> ['a', 'l', 'g', ...]). Accept a JSON list (the format
        # SageMaker uses for SM_HOSTS) or a comma-separated string instead.
        try:
            hosts = json.loads(value)
        except ValueError:
            hosts = None
        if isinstance(hosts, list):
            return [str(host) for host in hosts]
        return [host.strip() for host in value.split(',') if host.strip()]

    parser = argparse.ArgumentParser()

    # At most one of --vae_path / --dalle_path may be given.
    group = parser.add_mutually_exclusive_group(required=False)

    group.add_argument('--vae_path',
                       type=str,
                       help='path to your trained discrete VAE')

    group.add_argument('--dalle_path',
                       type=str,
                       help='path to your partially trained DALL-E')

    parser.add_argument(
        '--image_text_folder',
        type=str,
        default='../../CUB_BIRD',
        help='path to your folder of images and text for learning the DALL-E')

    add_bool(
        parser,
        '--truncate_captions',
        dest='truncate_captions',
        help=
        'Captions passed in which exceed the max token length will be truncated if this is set.'
    )

    parser.add_argument('--random_resize_crop_lower_ratio',
                        dest='resize_ratio',
                        type=float,
                        default=0.75,
                        help='Random resized crop lower ratio')

    parser.add_argument('--chinese', dest='chinese', action='store_true')

    # No explicit default in the original: stays None (falsy) when omitted.
    add_bool(parser, '--taming', dest='taming', default=None)

    parser.add_argument('--hug', dest='hug', action='store_true')

    parser.add_argument('--bpe_path',
                        type=str,
                        help='path to your huggingface BPE json file')

    parser.add_argument('--dalle_output_file_name',
                        type=str,
                        default="dalle.pt",
                        help='output_file_name')
    add_bool(
        parser,
        '--fp16',
        help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.'
    )

    parser.add_argument(
        '--wandb_name',
        default='dalle_train_transformer',
        help=
        'Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`'
    )

    # distributed_utils may register extra options; the parser it returns
    # replaces the local one.
    parser = distributed_utils.wrap_arg_parser(parser)

    train_group = parser.add_argument_group('Training settings')

    train_group.add_argument('--epochs',
                             default=20,
                             type=int,
                             help='Number of epochs')
    train_group.add_argument('--save_every_n_steps',
                             default=1000,
                             type=int,
                             help='Save a checkpoint every n steps')
    train_group.add_argument('--batch_size',
                             default=4,
                             type=int,
                             help='Batch size')

    train_group.add_argument('--learning_rate',
                             default=3e-4,
                             type=float,
                             help='Learning rate')

    train_group.add_argument('--clip_grad_norm',
                             default=0.5,
                             type=float,
                             help='Clip gradient norm')

    train_group.add_argument('--lr_decay',
                             dest='lr_decay',
                             action='store_true')

    model_group = parser.add_argument_group('Model settings')

    model_group.add_argument('--dim',
                             default=512,
                             type=int,
                             help='Model dimension')

    model_group.add_argument('--heads',
                             default=8,
                             type=int,
                             help='Model number of heads')

    model_group.add_argument('--dim_head',
                             default=64,
                             type=int,
                             help='Model head dimension')

    add_bool(model_group, '--reversible', dest='reversible')

    model_group.add_argument('--loss_img_weight',
                             default=7,
                             type=int,
                             help='Image loss weight')

    model_group.add_argument('--text_seq_len',
                             default=256,
                             type=int,
                             help='Text sequence length')

    model_group.add_argument('--depth',
                             default=2,
                             type=int,
                             help='Model depth')

    model_group.add_argument(
        '--attn_types',
        default='full',
        type=str,
        help=
        'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.'
    )

    parser.add_argument('--num_worker', type=int, default=4)
    parser.add_argument('--model_dir', type=str, default='model/dalle/')

    parser.add_argument('--num-gpus', type=int, default=8)

    parser.add_argument('--hosts', type=parse_hosts, default=[])
    parser.add_argument('--current-host', type=str, default="")
    parser.add_argument(
        '--backend',
        type=str,
        default='nccl',
        help=
        'backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)'
    )

    parser.add_argument('--output_s3', type=str, default="s3://")

    args = parser.parse_args()
    return args