コード例 #1
0
    def add_model_specific_args(parent_parser):  # pragma: no cover

        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser], add_help=False)

        # param overwrites
        parser.set_defaults(gradient_clip_val=1.0,
                            model_save_monitor_value='val_acc',
                            model_save_monitor_mode='max',
                            early_stop_metric='val_loss',
                            early_stop_patience=10,
                            early_stop_mode='min',
                            val_check_interval=0.02,
                            max_nb_epochs=3
                            )

        running_group = parser.add_argument_group(title='Training/Evaluation options')
        model_group = parser.add_argument_group(title='Model options')
        tokenizer_group = parser.add_argument_group(title='Tokenizer options')
        task_group = parser.add_argument_group(title='Task options')

        # Add arguments to those groups

        model_group.add_argument('--model_type', type=str, required=True)
        model_group.add_argument('--model_weight', type=str, required=True)
        model_group.add_argument('--ci_alpha', type=float, default=0.95)

        tokenizer_group.add_argument('--tokenizer_type', type=str, default=None)
        tokenizer_group.add_argument('--tokenizer_weight', type=str, default=None)

        task_group.add_argument('--task_name',
                                choices=['qqp', 'alphanli', 'snli', 'hellaswag', 'physicaliqa', 'socialiqa', 'vcrqa', 'vcrqr'],
                                required=True)
        task_group.add_argument('--task_config_file', type=str, required=True)
        task_group.add_argument('--task_cache_dir', type=str, required=True)

        running_group.add_argument('--running_config_file', type=str, required=True)

        parser.add_argument('--test_input_dir', type=str, required=False, default=None)
        parser.add_argument('--output_dir', type=str, required=False, default=None)
        parser.add_argument('--weights_path', type=str, required=False, default=None)
        parser.add_argument('--tags_csv', type=str, required=False, default=None)

        return parser
コード例 #2
0
    # APPNP params 
    parser.add_argument('--edge-drop', default=0.5, type=float)
    parser.add_argument('--alpha', default=0.1, type=float, help='teleporting probability')
    parser.add_argument('--k', type=int,default=10, help='number of propagation steps.') 
    
    #GTN params 
    parser.add_argument('--sample-number', type=int, default=32, help='characteristic function sample number. Will generate feats_t.npy.')
    
    parser.add_argument('--concat', action='store_true', default=True, help='concat neighbors with itself.')
    
    # GResNet params 
    parser.add_argument('--residual-type', type=str, default='graph_raw', choices=['naive','raw','graph_naive','graph_raw'])
    
    parser.add_argument("--self-loop", action='store_true',
            help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    if not args.ckpt_name:
        d = datetime.now() 
        time_str = d.strftime('%m-%dT%H%M')
        args.ckpt_name = '{}_{}_l{}_h{}_{}'.format(args.model,
                args.dataset, args.n_layers, args.n_hidden, time_str)

    checkpoint_callback = ModelCheckpoint(
        filepath=f'./checkpoints/{args.ckpt_name}',
        save_best_only=True,
コード例 #3
0
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy,
                                        parents=[parent_parser])

        parser.set_defaults(device=torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'))

        # network params
        parser.opt_list('--gcn_mid_dim',
                        default=256,
                        type=int,
                        options=[128, 256, 512, 1024],
                        tunable=True)
        parser.opt_list('--gcn_output_dim',
                        default=256,
                        type=int,
                        options=[128, 256, 512, 1024],
                        tunable=True)
        parser.opt_list('--txtcnn_drop_prob',
                        default=0.0,
                        options=[0.0, 0.1, 0.2],
                        type=float,
                        tunable=True)
        parser.opt_list('--gcn_drop_prob',
                        default=0.5,
                        options=[0.2, 0.5],
                        type=float,
                        tunable=True)
        parser.opt_list('--warploss_margin',
                        default=0.4,
                        type=float,
                        tunable=True)
        parser.opt_list('--freeze_embeddings',
                        default=True,
                        options=[True, False],
                        type=lambda x: (str(x).lower() == 'true'),
                        tunable=True)

        parser.opt_list('--txtcnn_pfilter_num1',
                        default=64,
                        options=[16, 32, 64, 128],
                        type=int,
                        tunable=True)
        parser.opt_list('--txtcnn_pfilter_num2',
                        default=64,
                        options=[16, 32, 64, 128],
                        type=int,
                        tunable=True)
        parser.opt_list('--txtcnn_pfilter_num3',
                        default=64,
                        options=[16, 32, 64, 128],
                        type=int,
                        tunable=True)
        parser.opt_list('--txtcnn_pfilter_num4',
                        default=64,
                        options=[16, 32, 64, 128],
                        type=int,
                        tunable=True)
        parser.opt_list('--txtcnn_rfilter_num1',
                        default=64,
                        options=[16, 32, 64, 128],
                        type=int,
                        tunable=True)
        parser.opt_list('--txtcnn_rfilter_num2',
                        default=32,
                        options=[16, 32, 64, 128],
                        type=int,
                        tunable=True)

        # data
        parser.add_argument('--data_root',
                            default=os.path.join(root_dir, 'data'),
                            type=str)
        parser.add_argument('--top_t', default=6, type=int)
        parser.add_argument('--total_onehop', default=20, type=int)
        parser.add_argument('--total', default=50, type=int)
        parser.add_argument('--shuffle',
                            default=True,
                            type=lambda x: (str(x).lower() == 'true'))
        parser.add_argument('--train_div', default=1.0, type=float)

        # training params (opt)
        parser.opt_list('--batch_size',
                        default=64,
                        options=[32, 64, 128, 256],
                        type=int,
                        tunable=False)
        parser.opt_list('--max_nb_epochs',
                        default=8,
                        options=[256, 512, 1024],
                        type=int,
                        tunable=False)
        parser.opt_list('--learning_rate',
                        default=0.0005,
                        options=[0.0001, 0.0005, 0.001],
                        type=float,
                        tunable=True)
        parser.opt_list('--weight_decay',
                        default=0.001,
                        options=[0.0001, 0.0005, 0.001],
                        type=float,
                        tunable=True)
        parser.add_argument('--model_save_path',
                            default=os.path.join(root_dir, 'experiment'),
                            type=str)
        return parser
コード例 #4
0
ファイル: cpc_model.py プロジェクト: williamFalcon/WpJNw9n6sC
    def add_model_specific_args(parent_parser, root_dir):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy,
                                        parents=[parent_parser])

        parser.set_defaults(nb_hopt_trials=1000)
        parser.set_defaults(min_nb_epochs=1000)
        parser.set_defaults(max_nb_epochs=1100)
        parser.set_defaults(early_stop_metric='val_nce')
        parser.set_defaults(model_save_monitor_value='val_nce')
        parser.set_defaults(model_save_monitor_mode='min')
        parser.set_defaults(early_stop_mode='min')

        # CIFAR 10
        dataset_name = 'CIFAR10'
        image_height = 32
        nb_classes = 10
        patch_size = 8
        patch_overlap = 4

        # dataset options
        parser.opt_list('--nb_classes',
                        default=nb_classes,
                        type=int,
                        options=[10],
                        tunable=False)
        parser.opt_list('--patch_size',
                        default=patch_size,
                        type=int,
                        options=[10],
                        tunable=False)
        parser.opt_list('--patch_overlap',
                        default=patch_overlap,
                        type=int,
                        options=[10],
                        tunable=False)

        # network params
        parser.add_argument('--image_height', type=int, default=image_height)

        # trainin params
        parser.add_argument('--dataset_name', type=str, default=dataset_name)
        parser.add_argument('--batch_size',
                            type=int,
                            default=200,
                            help='input batch size (default: 200)')
        parser.opt_list(
            '--learning_rate',
            type=float,
            default=0.0002,
            options=[
                2e-4 * (1 / 64),
                2e-4 * (1 / 32),
                2e-4 * (1 / 16),
                2e-4 * (1 / 8),
                2e-4 * (1 / 4),
                2e-4 * (1 / 2),
                2e-4 * (1 / 4),
                2e-4,
                2e-4 * 4,  #2e-4*4,
                2e-4 * 8,
            ],
            tunable=False)

        # data
        parser.opt_list('--cifar10_root',
                        default=f'{root_dir}/fisherman/datasets',
                        type=str,
                        tunable=False)
        return parser