Code example #1
    def log_record(self, config):
        """
        Record the loss on the test data during training.

        :param config: run configuration to write to the log file
        :return: None
        """
        log_dir = "log_{}".format('AI_GAN')
        tl.files.exists_or_mkdir(log_dir)
        self.log_all, self.log_all_filename = utils.logging_setup(log_dir)
        utils.log_config(self.log_all_filename, config)
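
This example (and example #5 below) relies on a utils module that is not shown. As a hedged sketch only — the real helpers may differ — logging_setup could create a timestamped log file and return a (logger, path) pair, while log_config(filename, config) dumps the run configuration to that file:

import logging
import os
import time


def logging_setup(log_dir):
    # Create a timestamped log file under log_dir; return (logger, path).
    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, time.strftime("%Y%m%d_%H%M%S") + ".log")
    logger = logging.getLogger("log_all")
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.FileHandler(path))
    return logger, path


def log_config(filename, config):
    # Append every config entry to the log file for reproducibility.
    # Assumes config is dict-like or a namespace (hence the vars() fallback).
    items = config if isinstance(config, dict) else vars(config)
    with open(filename, "a") as f:
        for key, value in sorted(items.items()):
            f.write("{}: {}\n".format(key, value))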
Code example #2
    def __init__(self, outtoken, hidden, enc_layers=1, dec_layers=1, nhead=1, dropout=0.1, pretrained=True):
        super(TransformerModel, self).__init__()

        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        
        self.backbone_name = 'resnet50'
        self.backbone = models.resnet50(pretrained=pretrained)
        self.backbone.fc = nn.Conv2d(2048, hidden // 2, 1)

        self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.decoder = nn.Embedding(outtoken, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        self.transformer = nn.Transformer(d_model=hidden, nhead=nhead, num_encoder_layers=enc_layers,
                                          num_decoder_layers=dec_layers, dim_feedforward=hidden * 4, dropout=dropout,
                                          activation='relu')

        self.fc_out = nn.Linear(hidden, outtoken)
        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

        log_config(self)
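
This model (and the one in example #7) assumes a PositionalEncoding module that is not included in the snippet. A minimal sinusoidal sketch in the style of the standard PyTorch tutorial — an assumption, not the project's actual class (sequence-first layout, even d_model):

import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for (seq_len, batch, d_model) inputs."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x matches nn.Transformer's default (seq_len, batch, d_model) layout
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)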
Code example #3
File: main.py Project: timole/sopernovus
def main():
    args = parse_args()
    inputFileNameUsage = args['input_file_usage']
    inputFileNameOperative = args['input_file_operative']
    outputApplicationsFileName = args['output_file_applications']
    outputUsersFileName = args['output_file_users']
    predictionOutputFileName = args['prediction_output_file']

    utils.log_config()
    logger = logging.getLogger(__name__)

    logger.info("Data file: {}".format(inputFileNameUsage))
    if inputFileNameOperative:
        logger.info("Operative data file: {}".format(inputFileNameOperative))
    else:
        logger.info("No operative data available.")

    startTime = datetime.datetime.now()

    # exported to global scope for debugging purposes
    global df
    df = data_helper.import_data(inputFileNameUsage)

    global odf

    if inputFileNameOperative:
        odf = data_helper.import_operative_data(inputFileNameOperative)
    else:
        odf = None

    logger.info("N of events: {}, from {} to {} ".format(len(df), df['datetime'].min(), df['datetime'].max()))

    create_user_summary(outputUsersFileName)
    create_application_summary(outputApplicationsFileName)
#    create_prediction_summary(predictionOutputFileName)

    print_stats(startTime)
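
The dict-style access above (args['input_file_usage']) suggests parse_args wraps argparse and returns vars(parser.parse_args()). A hedged reconstruction; the flag spellings are assumptions:

import argparse


def parse_args():
    # Hypothetical reconstruction of the helper used above; the actual
    # flag names are not shown in the snippet.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file_usage', required=True)
    parser.add_argument('--input_file_operative')
    parser.add_argument('--output_file_applications')
    parser.add_argument('--output_file_users')
    parser.add_argument('--prediction_output_file')
    return vars(parser.parse_args())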
Code example #4
                   save_dir + "/network.epoch{}".format(epoch + 1))
        torch.save(optimizer.state_dict(),
                   save_dir + "/network.optimizer.epoch{}".format(epoch + 1))
        adjust_learning_rate(optimizer, epoch + 1)
        print("EPOCH {} end".format(epoch + 1))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--train')
    parser.add_argument('--save_dir')
    args = parser.parse_args()
    train_script = args.train
    save_dir = args.save_dir

    log_config()
    if hp.baseline:
        model = Baseline()
    elif hp.text_based:
        model = text_based()
    elif hp.combined:
        model = Combined()
    elif hp.decoder_type == 'Attention':
        model = AttModel()
    else:
        raise ValueError('No model selected by the hyperparameter flags')

    model.apply(init_weight)

    if torch.cuda.device_count() > 1:
        # multi-gpu configuration
        ngpu = torch.cuda.device_count()
        device_ids = list(range(ngpu))
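
The excerpt stops right after collecting device_ids. A typical continuation — an assumption, since it is not shown in the snippet — wraps the model in torch.nn.DataParallel:

import torch

# Hypothetical continuation of the multi-GPU setup above; `model` comes
# from the snippet.
if torch.cuda.device_count() > 1:
    ngpu = torch.cuda.device_count()
    device_ids = list(range(ngpu))
    # Inputs are scattered along the batch dimension; gradients are
    # gathered and reduced on device_ids[0].
    model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)
elif torch.cuda.is_available():
    model = model.cuda()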
Code example #5
    def log_record(self, config):
        log_dir = "log_{}".format('BrainQuantAI_Part_one')
        tl.files.exists_or_mkdir(log_dir)
        self.log_all, self.log_all_filename = utils.logging_setup(log_dir)
        utils.log_config(self.log_all_filename, config)
Code example #6
def main():
    ''' set default hyperparams in default_hyperparams.py '''
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=wilds.supported_datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )
    parser.add_argument('--pretrained_model_path',
                        default=None,
                        type=str,
                        help="Specify a path to a pretrained model's weights")

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to download the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.'
    )
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--unlabeled_n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--unlabeled_batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        'Number of batches to process before stepping optimizer and/or schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).'
    )

    # Active Learning
    parser.add_argument('--active_learning',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument(
        '--target_split',
        default="test",
        type=str,
        help=
        'Split from which to sample labeled examples and use as unlabeled data for self-training.'
    )
    parser.add_argument(
        '--use_target_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=True,
        help=
        "If false, we sample target labels and remove them from the eval set, but don't actually train on them."
    )
    parser.add_argument(
        '--use_source_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        "Train on labeled source examples (perhaps in addition to labeled target examples.)"
    )
    parser.add_argument(
        '--upsample_target_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        "If concatenating source labels, upsample target labels s.t. our labeled batches are 1/2 src and 1/2 tgt."
    )
    parser.add_argument('--selection_function',
                        choices=supported.selection_functions)
    parser.add_argument(
        '--selection_function_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        "keyword arguments for selection fn passed as key1=value1 key2=value2")
    parser.add_argument(
        '--selectby_fields',
        nargs='+',
        help=
        "If set, acts like a grouper and n_shots are acquired per selection group (e.g. y x hospital selects K examples per y x hospital)."
    )
    parser.add_argument('--n_shots',
                        type=int,
                        help="number of shots (labels) to actively acquire")

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )
    parser.add_argument('--freeze_featurizer',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        help="Only train classifier weights")
    parser.add_argument(
        '--teacher_model_path',
        type=str,
        help=
        'Path to teacher model weights. If this is defined, pseudolabels will first be computed for unlabeled data before anything else runs.'
    )
    parser.add_argument('--dropout_rate', type=float)

    # Transforms
    parser.add_argument('--transform', choices=supported.transforms)
    parser.add_argument('--additional_labeled_transform',
                        type=parse_none,
                        choices=supported.additional_transforms)
    parser.add_argument('--additional_unlabeled_transform',
                        type=parse_none,
                        nargs='+',
                        choices=supported.additional_transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)
    parser.add_argument(
        '--randaugment_n',
        type=int,
        help=
        'N parameter of RandAugment - the number of transformations to apply.')

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--maml_first_order',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--metalearning_k', type=int)
    parser.add_argument('--metalearning_adapt_lr', type=float)
    parser.add_argument('--metalearning_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--self_training_labeled_weight',
                        type=float,
                        help='Weight of labeled loss')
    parser.add_argument('--self_training_unlabeled_weight',
                        type=float,
                        help='Weight of unlabeled loss')
    parser.add_argument('--self_training_threshold', type=float)
    parser.add_argument(
        '--pseudolabel_T2',
        type=float,
        help=
        'Percentage of total iterations at which to end linear scheduling and hold unlabeled weight at the max value'
    )
    parser.add_argument('--soft_pseudolabels',
                        default=False,
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--algo_log_metric')

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--eval_splits', nargs='+', default=['val', 'test'])
    parser.add_argument(
        '--save_splits',
        nargs='+',
        default=['test'],
        help=
        'If save_pred_step or save_pseudo_step are set, then this sets which splits to save pred / pseudos for. Must be a subset of eval_splits.'
    )
    parser.add_argument('--eval_additional_every',
                        default=1,
                        type=int,
                        help='Eval additional splits every _ training epochs.')
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--eval_epoch',
        default=None,
        type=int,
        help=
        'If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.'
    )

    # Misc
    parser.add_argument('--device', type=int, nargs='+', default=[0])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_model_step', type=int)
    parser.add_argument('--save_pred_step', type=int)
    parser.add_argument('--save_pseudo_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--resume',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        'Whether to resume from the most recent saved model in the current log_dir.'
    )

    # Weights & Biases
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--wandb_api_key_path',
        type=str,
        help=
        "Path to Weights & Biases API Key. If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate."
    )
    parser.add_argument('--wandb_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={},
                        help="Will be passed directly into wandb.init().")

    config = parser.parse_args()
    config = populate_defaults(config)

    # Set device
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        if len(config.device) > device_count:
            raise ValueError(
                f"Specified {len(config.device)} devices, but only {device_count} devices found."
            )
        config.use_data_parallel = len(config.device) > 1
        try:
            device_str = ",".join(map(str, config.device))
            config.device = torch.device(f"cuda:{device_str}")
        except RuntimeError as e:
            print(
                f"Failed to initialize CUDA. Using torch.device('cuda') instead. Error: {str(e)}"
            )
            config.device = torch.device("cuda")
    else:
        config.use_data_parallel = False
        config.device = torch.device("cpu")

    ## Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        config.mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        config.mode = 'a'
    else:
        resume = False
        config.mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), config.mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Algorithms that use unlabeled data must be run in active learning mode,
    # because otherwise we have no unlabeled data.
    if config.algorithm in ["PseudoLabel", "FixMatch", "NoisyStudent"]:
        assert config.active_learning

    # Data
    full_dataset = wilds.get_dataset(dataset=config.dataset,
                                     version=config.version,
                                     root_dir=config.root_dir,
                                     download=config.download,
                                     split_scheme=config.split_scheme,
                                     **config.dataset_kwargs)

    # In this project, we sometimes train on batches of mixed splits, e.g. some labeled
    # train examples together with labeled test examples. Within each batch, we may want
    # to sample uniformly across splits, or to log the train vs. test label balance.
    # To facilitate this, we'll hack the WILDS dataset to include each point's split in the metadata array
    add_split_to_wilds_dataset_metadata_array(full_dataset)

    # To modify data augmentation, modify the following code block.
    # If you want to use transforms that modify both `x` and `y`,
    # set `do_transform_y` to True when initializing the `WILDSSubset` below.
    train_transform = initialize_transform(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
        additional_transform=config.additional_labeled_transform,
        is_training=True)
    eval_transform = initialize_transform(transform_name=config.transform,
                                          config=config,
                                          dataset=full_dataset,
                                          is_training=False)

    # Define any special transforms for the algorithms that use unlabeled data
    # if config.algorithm == "FixMatch":
    #     # For FixMatch, we need our loader to return batches in the form ((x_weak, x_strong), m)
    #     # We do this by initializing a special transform function
    #     unlabeled_train_transform = initialize_transform(
    #         config.transform, config, full_dataset, is_training=True, additional_transform="fixmatch"
    #     )
    # else:
    unlabeled_train_transform = initialize_transform(
        config.transform,
        config,
        full_dataset,
        is_training=True,
        additional_transform=config.additional_unlabeled_transform)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False

        data = full_dataset.get_subset(split,
                                       frac=config.frac,
                                       transform=transform)

        datasets[split] = configure_split_dict(
            data=data,
            split=split,
            split_name=full_dataset.split_names[split],
            get_train=(split == 'train'),
            get_eval=(split != 'train'),
            verbose=verbose,
            grouper=train_grouper,
            batch_size=config.batch_size,
            config=config)

        pseudolabels = None
        if config.algorithm == "NoisyStudent" and config.target_split == split:
            # Infer teacher outputs on unlabeled examples in sequential order
            # During forward pass, ensure we are not shuffling and not applying strong augs
            print(
                f"Inferring teacher pseudolabels on {config.target_split} for Noisy Student"
            )
            assert config.teacher_model_path is not None
            if not config.teacher_model_path.endswith(".pth"):
                # Use the best model
                config.teacher_model_path = os.path.join(
                    config.teacher_model_path,
                    f"{config.dataset}_seed:{config.seed}_epoch:best_model.pth"
                )
            teacher_model = initialize_model(
                config, infer_d_out(full_dataset)).to(config.device)
            load(teacher_model,
                 config.teacher_model_path,
                 device=config.device)
            # Infer teacher outputs on weakly augmented unlabeled examples in sequential order
            weak_transform = initialize_transform(
                transform_name=config.transform,
                config=config,
                dataset=full_dataset,
                is_training=True,
                additional_transform="weak")
            unlabeled_split_dataset = full_dataset.get_subset(
                split, transform=weak_transform, frac=config.frac)
            sequential_loader = get_eval_loader(
                loader=config.eval_loader,
                dataset=unlabeled_split_dataset,
                grouper=train_grouper,
                batch_size=config.unlabeled_batch_size,
                **config.loader_kwargs)
            pseudolabels = infer_predictions(teacher_model, sequential_loader,
                                             config)
            del teacher_model

        if config.active_learning and config.target_split == split:
            datasets[split]['label_manager'] = LabelManager(
                subset=data,
                train_transform=train_transform,
                eval_transform=eval_transform,
                unlabeled_train_transform=unlabeled_train_transform,
                pseudolabels=pseudolabels)

    if config.use_wandb:
        initialize_wandb(config)

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    ## Schedulers are initialized as if we will iterate over "train" split batches.
    ## If we train on another split (e.g. labeled test), we'll re-initialize schedulers later using algorithm.change_n_train_steps()
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)
    if config.freeze_featurizer:
        freeze_features(algorithm)

    if config.active_learning:
        select_grouper = CombinatorialGrouper(
            dataset=full_dataset, groupby_fields=config.selectby_fields)
        selection_fn = initialize_selection_function(
            config, algorithm, select_grouper, algo_grouper=train_grouper)

    # Resume from most recent model in log_dir
    model_prefix = get_model_prefix(datasets['train'], config)
    if not config.eval_only:
        ## If doing active learning, expects to load a model trained on source
        resume_success = False
        if config.resume:
            save_path = model_prefix + 'epoch:last_model.pth'
            if not os.path.exists(save_path):
                epochs = [
                    int(file.split('epoch:')[1].split('_')[0])
                    for file in os.listdir(config.log_dir)
                    if file.endswith('.pth')
                ]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = model_prefix + f'epoch:{latest_epoch}_model.pth'
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path,
                                                   config.device)
                # also load previous selections

                epoch_offset = prev_epoch + 1
                config.selection_function_kwargs[
                    'load_selection_path'] = config.log_dir
                logger.write(
                    f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}\n'
                )
                resume_success = True
            except FileNotFoundError:
                pass

        if not resume_success:
            epoch_offset = 0
            best_val_metric = None

        # Log effective batch size
        logger.write((
            f'\nUsing gradient_accumulation_steps {config.gradient_accumulation_steps} means that'
        ) + (
            f' the effective labeled batch size is {config.batch_size * config.gradient_accumulation_steps}'
        ) + (
            f' and the effective unlabeled batch size is {config.unlabeled_batch_size * config.gradient_accumulation_steps}'
            if config.unlabeled_batch_size else '') + (
                '. Updates behave as if torch loaders have drop_last=False\n'))

        if config.active_learning:
            # create new labeled/unlabeled test splits
            train_split, unlabeled_split = run_active_learning(
                selection_fn=selection_fn,
                datasets=datasets,
                grouper=train_grouper,
                config=config,
                general_logger=logger,
                full_dataset=full_dataset)
            # reset schedulers, which were originally initialized to schedule based on the 'train' split
            # one epoch = one pass over labeled data
            algorithm.change_n_train_steps(
                new_n_train_steps=infer_n_train_steps(
                    datasets[train_split]['train_loader'], config),
                config=config)
        else:
            train_split = "train"
            unlabeled_split = None

        train(algorithm=algorithm,
              datasets=datasets,
              train_split=train_split,
              val_split="val",
              unlabeled_split=unlabeled_split,
              general_logger=logger,
              config=config,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)

    else:
        if config.eval_epoch is None:
            eval_model_path = model_prefix + 'epoch:best_model.pth'
        else:
            eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
        best_epoch, best_val_metric = load(algorithm, eval_model_path,
                                           config.device)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch

        if config.active_learning:
            # create new labeled/unlabeled test splits
            config.selection_function_kwargs[
                'load_selection_path'] = config.log_dir
            run_active_learning(selection_fn=selection_fn,
                                datasets=datasets,
                                grouper=train_grouper,
                                config=config,
                                general_logger=logger,
                                full_dataset=full_dataset)

        evaluate(algorithm=algorithm,
                 datasets=datasets,
                 epoch=epoch,
                 general_logger=logger,
                 config=config)

    if config.use_wandb:
        wandb.finish()
    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
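
The WILDS scripts above pass key1=value1 tokens through a ParseKwargs action, which is not shown. A sketch of what such an action could look like — an assumption; the real implementation lives in the WILDS example utilities and may coerce types differently:

import argparse


class ParseKwargs(argparse.Action):
    """Collect 'key1=value1 key2=value2' tokens into a dict with naive
    type coercion. A sketch only, not the project's actual code."""

    def __call__(self, parser, namespace, values, option_string=None):
        kwargs = {}
        for token in values:
            key, _, raw = token.partition('=')
            if raw in ('True', 'False'):
                value = raw == 'True'
            else:
                try:
                    value = int(raw)
                except ValueError:
                    try:
                        value = float(raw)
                    except ValueError:
                        value = raw  # leave as string
            kwargs[key] = value
        setattr(namespace, self.dest, kwargs)


# e.g. --model_kwargs pretrained=True depth=50 name=resnet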
Code example #7
    def __init__(self,
                 outtoken,
                 hidden,
                 enc_layers=1,
                 dec_layers=1,
                 nhead=1,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.backbone_name = 'conv(64)->conv(64)->conv(128)->conv(256)->conv(256)->conv(512)->conv(512)'

        self.conv0 = Conv2d(1,
                            64,
                            kernel_size=(3, 3),
                            stride=(1, 1),
                            padding=(1, 1))
        self.conv1 = Conv2d(64,
                            128,
                            kernel_size=(3, 3),
                            stride=(1, 1),
                            padding=(1, 1))
        self.conv2 = Conv2d(128,
                            256,
                            kernel_size=(3, 3),
                            stride=(2, 1),
                            padding=(1, 1))
        self.conv3 = Conv2d(256,
                            256,
                            kernel_size=(3, 3),
                            stride=(1, 1),
                            padding=(1, 1))
        self.conv4 = Conv2d(256,
                            512,
                            kernel_size=(3, 3),
                            stride=(2, 1),
                            padding=(1, 1))
        self.conv5 = Conv2d(512,
                            512,
                            kernel_size=(3, 3),
                            stride=(1, 1),
                            padding=(1, 1))
        self.conv6 = Conv2d(512, 512, kernel_size=(2, 1), stride=(1, 1))

        self.pool1 = MaxPool2d(kernel_size=2,
                               stride=2,
                               padding=0,
                               dilation=1,
                               ceil_mode=False)
        self.pool3 = MaxPool2d(kernel_size=2,
                               stride=2,
                               padding=0,
                               dilation=1,
                               ceil_mode=False)
        self.pool5 = MaxPool2d(kernel_size=(2, 2),
                               stride=(2, 1),
                               padding=(0, 1),
                               dilation=1,
                               ceil_mode=False)

        self.bn0 = BatchNorm2d(64,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)
        self.bn1 = BatchNorm2d(128,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)
        self.bn2 = BatchNorm2d(256,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)
        self.bn3 = BatchNorm2d(256,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)
        self.bn4 = BatchNorm2d(512,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)
        self.bn5 = BatchNorm2d(512,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)
        self.bn6 = BatchNorm2d(512,
                               eps=1e-05,
                               momentum=0.1,
                               affine=True,
                               track_running_stats=True)

        self.activ = LeakyReLU()

        self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.decoder = nn.Embedding(outtoken, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        self.transformer = nn.Transformer(d_model=hidden,
                                          nhead=nhead,
                                          num_encoder_layers=enc_layers,
                                          num_decoder_layers=dec_layers,
                                          dim_feedforward=hidden * 4,
                                          dropout=dropout)

        self.fc_out = nn.Linear(hidden, outtoken)
        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

        log_config(self)
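
The stride-(2, 1) convolutions and asymmetric pooling collapse the image height while roughly preserving width, which becomes the sequence dimension fed to the transformer. A shape check under stated assumptions — a 64x256 grayscale input and each pool_i applied after conv_i, since the actual forward() is not shown (LeakyReLU/BatchNorm omitted because they do not change shapes):

import torch
from torch.nn import Conv2d, MaxPool2d

x = torch.randn(1, 1, 64, 256)  # hypothetical grayscale line image
stack = [
    Conv2d(1, 64, 3, 1, 1),
    Conv2d(64, 128, 3, 1, 1), MaxPool2d(2, 2),                      # pool1
    Conv2d(128, 256, 3, (2, 1), 1),
    Conv2d(256, 256, 3, 1, 1), MaxPool2d(2, 2),                     # pool3
    Conv2d(256, 512, 3, (2, 1), 1),
    Conv2d(512, 512, 3, 1, 1), MaxPool2d((2, 2), (2, 1), (0, 1)),   # pool5
    Conv2d(512, 512, (2, 1), 1),
]
for layer in stack:
    x = layer(x)
print(x.shape)  # torch.Size([1, 512, 1, 65]): a 65-step sequence of 512-d features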
Code example #8
def main():
    ''' set default hyperparams in default_hyperparams.py '''
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=wilds.supported_datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to download the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.'
    )
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )

    # Transforms
    parser.add_argument('--train_transform', choices=supported.transforms)
    parser.add_argument('--eval_transform', choices=supported.transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--dann_lambda', type=float)
    parser.add_argument('--dann_domain_layers', type=int,
                        default=1)  # hidden layers
    parser.add_argument('--dann_label_layers', type=int,
                        default=1)  # hidden layers
    parser.add_argument('--domain_loss_function', choices=supported.losses)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--eval_epoch',
        default=None,
        type=int,
        help=
        'If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.'
    )

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_pred',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--resume',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # set device
    config.device = torch.device("cuda:" + str(
        config.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    full_dataset = wilds.get_dataset(dataset=config.dataset,
                                     version=config.version,
                                     root_dir=config.root_dir,
                                     download=config.download,
                                     split_scheme=config.split_scheme,
                                     **config.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset)
    eval_transform = initialize_transform(transform_name=config.eval_transform,
                                          config=config,
                                          dataset=full_dataset)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split, frac=config.frac, transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))

    # Initialize W&B once, after the per-split loggers are set up
    # (not once per split, which would re-run wandb.init repeatedly)
    if config.use_wandb:
        initialize_wandb(config)

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)

    model_prefix = get_model_prefix(datasets['train'], config)
    if not config.eval_only:
        ## Load saved results if resuming
        resume_success = False
        if resume:
            save_path = model_prefix + 'epoch:last_model.pth'
            if not os.path.exists(save_path):
                epochs = [
                    int(file.split('epoch:')[1].split('_')[0])
                    for file in os.listdir(config.log_dir)
                    if file.endswith('.pth')
                ]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = model_prefix + f'epoch:{latest_epoch}_model.pth'
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path)
                epoch_offset = prev_epoch + 1
                logger.write(
                    f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}'
                )
                resume_success = True
            except FileNotFoundError:
                pass

        if not resume_success:
            epoch_offset = 0
            best_val_metric = None

        train(algorithm=algorithm,
              datasets=datasets,
              general_logger=logger,
              config=config,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)
    else:
        if config.eval_epoch is None:
            eval_model_path = model_prefix + 'epoch:best_model.pth'
        else:
            eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
        best_epoch, best_val_metric = load(algorithm, eval_model_path)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch
        evaluate(algorithm=algorithm,
                 datasets=datasets,
                 epoch=epoch,
                 general_logger=logger,
                 config=config)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
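
The recurring pattern type=parse_bool, const=True, nargs='?' makes a bare --flag mean True while still accepting an explicit value. A hedged sketch of such a parse_bool; the actual helper is not shown:

import argparse


def parse_bool(v):
    """Interpret common string spellings of booleans. With
    type=parse_bool, const=True, nargs='?', a bare --flag yields True,
    '--flag false' yields False, and omitting the flag keeps the default."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % v)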
Code example #9
def run(config, device, epochs, replications, seed, num_data_workers):
    """
    Run an experiment of the given config.

    An MLFlow experiment is set according to the name in the
    config. A BaseTask is built and its train function called.
    Each call of the run function with the same config becomes a
    run of this experiment. If replications is greater than zero,
    a nested run is created and the task is executed that many
    times.

    When debugging, nothing is written to disk to avoid
    cluttering the results directory.

    :param config: path to the config JSON file or config dict
    :param device: device to train on
    :param epochs: epochs to train for
    :param replications: number of times to replicate this run
    :param seed: random seed to use
    :param num_data_workers: number of worker threads for data loading
    """
    # Set seed for randomization
    if seed is not None:
        # Make PyTorch and numpy deterministic
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        print('Fixed randomization. Seed %d' % seed)
        print('#' * 40)
    else:
        # Retrieve default seed, as it is not set
        seed = np.random.randint(np.iinfo(np.int32).max)
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Load config JSON
    if isinstance(config, str):
        print('Run experiment from %s' % config)
        print('#' * 40)
        config = utils.read_config(config)
    elif isinstance(config, dict):
        print('Run experiment with dict named %s' % config['name'])
        print('#' * 40)
    else:
        raise ValueError(
            'Config has to be either a string path or a dict, but is %s.' %
            str(type(config)))

    # Extract config dicts for components
    name = config['name']
    dataset = config['dataset']
    model = config['model']
    trainer = config['trainer']
    metrics = config['metrics']

    # Setup mlflow experiment
    if utils.is_debugging():
        # Reroute mlflow to tmp file on debugging
        warnings.warn(
            'Debugging mode: MLFlow stuff will be saved to temporary dir.',
            UserWarning)
        mlflow.set_tracking_uri('file:' + utils.build_tmp_dir())
    else:
        script_path = os.path.dirname(__file__)
        root_path = os.path.dirname(script_path)
        mlflow.set_tracking_uri('file:' + root_path)
    mlflow.set_experiment(name)

    # Start the top level run
    nest_runs = replications > 0
    with mlflow.start_run(nested=nest_runs):
        # Log parameters to run
        utils.log_config(config)
        mlflow.log_param('max_epochs', epochs)
        mlflow.log_param('seed', seed)
        mlflow.set_tag('device', device)

        if nest_runs:
            # Open child runs for each replication
            mlflow.log_param('replications', replications)
            seeds = np.random.randint(np.iinfo(np.int32).max,
                                      size=replications)
            for i, s in enumerate(seeds):
                print('Run replication %d/%d...' % (i + 1, replications))
                with mlflow.start_run(nested=True):
                    # Log params to child runs
                    utils.log_config(config)
                    mlflow.set_tag('replication', i)

                    # Set derived seed for child runs to make each reproducible
                    mlflow.log_param('seed', s)
                    torch.manual_seed(s)
                    np.random.seed(s)

                    # Execute run
                    task = BaseTask(name, device, dataset, model, trainer,
                                    metrics)
                    task.train(epochs, num_data_workers)
        else:
            # Simply execute top level run, when replications are zero
            task = BaseTask(name, device, dataset, model, trainer, metrics)
            task.train(epochs, num_data_workers)
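
A hypothetical invocation, inferring the config structure from the keys the function reads (each component dict would be filled to match BaseTask's expectations):

# Hypothetical usage of run() above; key names mirror those it reads.
config = {
    'name': 'demo_experiment',
    'dataset': {},   # dataset component config
    'model': {},     # model component config
    'trainer': {},   # trainer component config
    'metrics': {},   # metrics component config
}
run(config, device='cuda:0', epochs=10, replications=0, seed=42,
    num_data_workers=4)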
Code example #10
File: __init__.py Project: carry156176/auto_nexttao
"""
初始化日志文件
"""
# 导包
import logging
import utils

# 初始化
utils.log_config()

logging.info("Test日志会不会打印!!!")
Code example #11
File: analysis_test.py Project: timole/sopernovus
        self.assertEqual(apps[apps['applicationId'] == 'LP-100']['authorityId'].item(), 102986)
        self.assertEqual(apps[apps['applicationId'] == 'LP-100']['authorities'].item(), 1)

class TestUsersSummary(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.df = data_helper.import_data("data/" + _TEST_DATA_FILE)
        cls.users = user_analysis.summarize_users(cls.df)

        # print users once for debugging
        print("Summary of users based on test data:")
        print(cls.users)

    def test_number_of_users(self):
        self.assertEqual(len(self.users), 6)

    def test_user_applications(self):
        users = self.users

        self.assertEqual(users[users['userId'] == 101302]['applicantRoles'].item(), 3)
        self.assertEqual(users[users['userId'] == 101346]['authorityRoles'].item(), 1)


if __name__ == '__main__':
    utils.log_config()
    logger = logging.getLogger(__name__)
    logger.info("Run unit tests")
    unittest.main()

Code example #12
File: analyze_expt.py Project: hrayrhar/wilds
def main():
    ''' set default hyperparams in default_hyperparams.py '''
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=supported.datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )
    parser.add_argument('--analyze_sample', default=1)

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to download the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.'
    )

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )

    # Transforms
    parser.add_argument('--train_transform', choices=supported.transforms)
    parser.add_argument('--eval_transform', choices=supported.transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'Target resolution; for example, --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')
    parser.add_argument('--hsic_beta', type=float)
    parser.add_argument('--grad_penalty_lamb', type=float)
    parser.add_argument(
        '--params_regex',
        type=str,
        help='Regular expression specifying which gradients to penalize.')
    parser.add_argument('--label_cond',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--dann_lamb', type=float)
    parser.add_argument('--dann_dc_name', type=str)

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--eval_epoch', default=None, type=int)
    parser.add_argument('--save_z',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--resume',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # set device
    config.device = torch.device("cuda:" + str(
        config.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    full_dataset = supported.datasets[config.dataset](
        root_dir=config.root_dir,
        download=config.download,
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset)
    eval_transform = initialize_transform(transform_name=config.eval_transform,
                                          config=config,
                                          dataset=full_dataset)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split, frac=config.frac, transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))

        if config.use_wandb:
            initialize_wandb(config)

    # Logging dataset info
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)

    if config.eval_epoch is None:
        eval_model_path = os.path.join(config.log_dir, 'best_model.pth')
    else:
        eval_model_path = os.path.join(config.log_dir,
                                       f'{config.eval_epoch}_model.pth')
    best_epoch, best_val_metric = load(algorithm, eval_model_path)
    if config.eval_epoch is None:
        epoch = best_epoch
    else:
        epoch = config.eval_epoch

    results, z_splits, y_splits, c_splits = evaluate(algorithm=algorithm,
                                                     datasets=datasets,
                                                     epoch=epoch,
                                                     general_logger=logger,
                                                     config=config)

    include_test = config.evaluate_all_splits or 'test' in config.eval_splits

    logistics = all_logistics(z_splits,
                              c_splits,
                              y_splits,
                              epoch=epoch,
                              sample=int(config.analyze_sample),
                              include_test=include_test)

    logistics['G0'] = results['id_val']['acc_avg']
    logistics['G1'] = logistics['val_on_val']
    logistics['G2'] = logistics['trainval_on_val']
    logistics['G3'] = results['val']['acc_avg']

    logistics['I0'] = logistics['c_train']
    logistics['I1'] = logistics['c_val']
    per_class = torch.tensor(list(logistics['c_perclass'].values()))
    logistics['I2'] = torch.mean(per_class).item()

    if include_test:
        logistics['G1_test'] = logistics['test_on_test']
        logistics['G2_test'] = logistics['traintest_on_test']
        logistics['G3_test'] = results['test']['acc_avg']

        logistics['I1_test'] = logistics['c_test']
        per_class = torch.tensor(list(logistics['c_perclass_test'].values()))
        logistics['I2_test'] = torch.mean(per_class).item()

    with open(os.path.join(config.log_dir, f'tests_epoch_{epoch}.pkl'),
              "wb") as f:
        pickle.dump(logistics, f)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
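
Code example #12 leans on two argparse helpers, ParseKwargs and parse_bool, that are imported from elsewhere in the WILDS codebase and not shown here. A minimal sketch of what they might look like (an assumption, not the project's exact implementation):

import argparse

def parse_bool(v):
    # Hypothetical sketch: map common true/false spellings to bool, for
    # flags declared with type=parse_bool, const=True, nargs='?'.
    if str(v).lower() in ('yes', 'true', 't', '1'):
        return True
    if str(v).lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

class ParseKwargs(argparse.Action):
    # Hypothetical sketch: turn "key1=value1 key2=value2" command-line
    # tokens into a dict, as used by --model_kwargs, --loader_kwargs, etc.
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, dict())
        for value in values:
            key, _, val = value.partition('=')
            getattr(namespace, self.dest)[key] = val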
Code example #13
def main(e2e_start_time):
    # Parse essential arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--model_size",
                        default="base",
                        type=str,
                        help="base or large")
    parser.add_argument("--pretrain_tfrecords", type=str)
    parser.add_argument("--phase2", action='store_true')
    parser.add_argument("--fp16_compression", action='store_true')
    parser.add_argument("--amp",
                        action='store_true',
                        help="Whether to use fp16.")
    parser.add_argument("--xla",
                        action='store_true',
                        help="Whether to use xla.")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--num_train_steps", type=int)
    parser.add_argument("--num_warmup_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--max_seq_length", type=int)

    parser.add_argument("--mask_prob", type=float)
    parser.add_argument("--disc_weight", type=float)
    parser.add_argument("--generator_hidden_size", type=float)

    parser.add_argument("--log_freq",
                        type=int,
                        default=10,
                        help="Training metrics logging frequency")
    parser.add_argument("--save_checkpoints_steps", type=int)
    parser.add_argument("--keep_checkpoint_max", type=int)
    parser.add_argument("--restore_checkpoint", default=None, type=str)
    parser.add_argument("--load_weights", action='store_true')
    parser.add_argument("--weights_dir")

    parser.add_argument("--optimizer",
                        default="adam",
                        type=str,
                        help="adam or lamb")
    parser.add_argument(
        "--skip_adaptive",
        action='store_true',
        help="Whether to skip adaptive LR on LayerNorm and biases")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Number of Gradient Accumulation steps")
    parser.add_argument("--lr_decay_power",
                        type=float,
                        default=0.5,
                        help="LR decay power")
    parser.add_argument("--opt_beta_1",
                        type=float,
                        default=0.878,
                        help="Optimizer beta1")
    parser.add_argument("--opt_beta_2",
                        type=float,
                        default=0.974,
                        help="Optimizer beta2")
    parser.add_argument("--end_lr", type=float, default=0.0, help="Ending LR")
    parser.add_argument("--log_dir",
                        type=str,
                        default=None,
                        help="Path to store logs")
    parser.add_argument("--results_dir",
                        type=str,
                        default=None,
                        help="Path to store all model results")
    parser.add_argument("--skip_checkpoint",
                        action='store_true',
                        default=False,
                        help="If set, skip saving checkpoints")
    parser.add_argument(
        '--json-summary',
        type=str,
        default=None,
        help=
        'If provided, the json summary will be written to the specified file.')
    args = parser.parse_args()
    config = PretrainingConfig(**args.__dict__)
    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Set up tensorflow
    hvd.init()

    args.log_dir = config.log_dir
    # DLLogger
    setup_logger(args)

    set_affinity(hvd.local_rank())
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')
    tf.config.optimizer.set_jit(config.xla)
    #tf.config.optimizer.set_experimental_options({"auto_mixed_precision": config.amp})

    if config.amp:
        policy = tf.keras.mixed_precision.experimental.Policy(
            "mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' %
              policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' %
              policy.variable_dtype)  # Variable dtype: float32

    #tf.random.set_seed(config.seed)

    # Set up config (cont'd)
    if config.load_weights and config.restore_checkpoint:
        raise ValueError(
            "`load_weights` and `restore_checkpoint` should not be on at the same time."
        )
    if config.phase2 and not config.restore_checkpoint:
        raise ValueError(
            "`phase2` cannot be used without `restore_checkpoint`.")
    utils.heading("Config:")
    log_config(config)

    # Save pretrain configs
    pretrain_config_json = os.path.join(config.checkpoints_dir,
                                        'pretrain_config.json')
    if is_main_process():
        utils.write_json(config.__dict__, pretrain_config_json)
        log("Configuration saved in {}".format(pretrain_config_json))

    # Set up model
    model = PretrainingModel(config)

    # Set up metrics
    metrics = dict()
    metrics["train_perf"] = tf.keras.metrics.Mean(name="train_perf")
    metrics["total_loss"] = tf.keras.metrics.Mean(name="total_loss")
    metrics["masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
        name="masked_lm_accuracy")
    metrics["masked_lm_loss"] = tf.keras.metrics.Mean(name="masked_lm_loss")
    if config.electra_objective:
        metrics["sampled_masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
            name="sampled_masked_lm_accuracy")
        if config.disc_weight > 0:
            metrics["disc_loss"] = tf.keras.metrics.Mean(name="disc_loss")
            metrics["disc_auc"] = tf.keras.metrics.AUC(name="disc_auc")
            metrics["disc_accuracy"] = tf.keras.metrics.Accuracy(
                name="disc_accuracy")
            metrics["disc_precision"] = tf.keras.metrics.Accuracy(
                name="disc_precision")
            metrics["disc_recall"] = tf.keras.metrics.Accuracy(
                name="disc_recall")

    # Set up tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(
        config.log_dir, current_time,
        'train_' + str(get_rank()) + '_of_' + str(get_world_size()))
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Set up dataset
    dataset = pretrain_utils.get_dataset(config,
                                         config.train_batch_size,
                                         world_size=get_world_size(),
                                         rank=get_rank())
    train_iterator = iter(dataset)

    # Set up optimizer
    optimizer = create_optimizer(init_lr=config.learning_rate,
                                 num_train_steps=config.num_train_steps,
                                 num_warmup_steps=config.num_warmup_steps,
                                 weight_decay_rate=config.weight_decay_rate,
                                 optimizer=config.optimizer,
                                 skip_adaptive=config.skip_adaptive,
                                 power=config.lr_decay_power,
                                 beta_1=config.opt_beta_1,
                                 beta_2=config.opt_beta_2,
                                 end_lr=config.end_lr)

    accumulator = GradientAccumulator()
    if config.amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, "dynamic")

    # Set up model checkpoint
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     phase2=tf.Variable(False),
                                     optimizer=optimizer,
                                     model=model)
    manager = tf.train.CheckpointManager(
        checkpoint,
        config.checkpoints_dir,
        max_to_keep=config.keep_checkpoint_max)
    if config.restore_checkpoint and config.restore_checkpoint != "latest":
        checkpoint.restore(config.restore_checkpoint)
        log(" ** Restored model checkpoint from {}".format(
            config.restore_checkpoint))
    elif config.restore_checkpoint and config.restore_checkpoint == "latest" and manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        log(" ** Restored model checkpoint from {}".format(
            manager.latest_checkpoint))
    elif config.load_weights:
        model.generator(model.generator.dummy_inputs)
        model.discriminator(model.discriminator.dummy_inputs)
        model.generator.load_weights(
            os.path.join(config.weights_dir, 'generator', 'tf_model.h5'))
        model.discriminator.load_weights(
            os.path.join(config.weights_dir, 'discriminator', 'tf_model.h5'))
    else:
        log(" ** Initializing from scratch.")

    restore_iterator = bool(
        config.restore_checkpoint) and config.restore_checkpoint == "latest"
    # Initialize global step for phase2
    if config.phase2 and not bool(checkpoint.phase2):
        optimizer.iterations.assign(0)
        checkpoint.step.assign(0)
        checkpoint.phase2.assign(True)
        restore_iterator = False
    if bool(checkpoint.phase2):
        manager = tf.train.CheckpointManager(
            checkpoint,
            config.checkpoints_dir,
            checkpoint_name='ckpt-p2',
            max_to_keep=config.keep_checkpoint_max)

    # Set up iterator checkpoint
    iter_checkpoint = tf.train.Checkpoint(train_iterator=train_iterator,
                                          world_size=tf.Variable(
                                              get_world_size()),
                                          rank=tf.Variable(get_rank()))
    iter_manager = tf.train.CheckpointManager(
        iter_checkpoint,
        os.path.join(config.checkpoints_dir,
                     'iter_ckpt_rank_' + '{:02}'.format(get_rank())),
        checkpoint_name='iter_ckpt_rank_' + '{:02}'.format(get_rank()),
        max_to_keep=config.keep_checkpoint_max)
    if restore_iterator and iter_manager.latest_checkpoint:
        ckpt_world_size = tf.train.load_variable(
            iter_manager.latest_checkpoint,
            'world_size/.ATTRIBUTES/VARIABLE_VALUE')
        if ckpt_world_size == get_world_size():
            iter_checkpoint.restore(iter_manager.latest_checkpoint)
            log(" ** Restored iterator checkpoint from {}".format(
                iter_manager.latest_checkpoint),
                all_rank=True)

    utils.heading("Running training")
    accumulator.reset()
    train_start, start_step = time.time(), int(checkpoint.step) - 1
    local_step = 0
    saved_ckpt = False
    while int(checkpoint.step) <= config.num_train_steps:
        saved_ckpt = False
        step = int(checkpoint.step)
        features = next(train_iterator)
        iter_start = time.time()

        # if step == 200: tf.profiler.experimental.start(logdir=train_log_dir)
        total_loss, eval_fn_inputs = train_one_step(
            config,
            model,
            optimizer,
            features,
            accumulator,
            local_step == 1,
            take_step=local_step % args.gradient_accumulation_steps == 0)
        # if step == 300: tf.profiler.experimental.stop()

        metrics["train_perf"].update_state(config.train_batch_size *
                                           get_world_size() /
                                           (time.time() - iter_start))
        metrics["total_loss"].update_state(values=total_loss)
        metric_fn(config, metrics, eval_fn_inputs)

        if (step % args.log_freq
                == 0) and (local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {
                k: float(v.result().numpy() *
                         100) if "accuracy" in k else float(v.result().numpy())
                for k, v in metrics.items()
            }
            dllogger.log(step=(step, ), data=log_info_dict, verbosity=0)
            log('Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f}, Loss Scaler: {loss_scale}, Elapsed: {elapsed}, ETA: {eta}, '
                .format(step=step,
                        **log_info_dict,
                        loss_scale=optimizer.loss_scale if config.amp else 1,
                        elapsed=utils.get_readable_time(time.time() -
                                                        train_start),
                        eta=utils.get_readable_time(
                            (time.time() - train_start) / (step - start_step) *
                            (config.num_train_steps - step))),
                all_rank=True)

            with train_summary_writer.as_default():
                for key, m in metrics.items():
                    tf.summary.scalar(key, m.result(), step=step)

            if int(checkpoint.step) < config.num_train_steps:
                for m in metrics.values():
                    m.reset_states()

        # Print all-reduced metrics on the last step
        if int(checkpoint.step) == config.num_train_steps and (
                local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {
                k: float(hvd.allreduce(v.result()).numpy() * 100) if "accuracy"
                in k else float(hvd.allreduce(v.result()).numpy())
                for k, v in metrics.items()
            }
            log_info_dict["training_sequences_per_second"] = log_info_dict[
                "train_perf"]
            log_info_dict["final_loss"] = log_info_dict["total_loss"]
            log_info_dict["e2e_train_time"] = time.time() - e2e_start_time
            dllogger.log(step=(), data=log_info_dict, verbosity=0)
            log('<FINAL STEP METRICS> Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f},'.
                format(step=step, **log_info_dict),
                all_rank=False)

        if local_step % args.gradient_accumulation_steps == 0:
            checkpoint.step.assign(int(optimizer.iterations))

        local_step += 1
        if not config.skip_checkpoint and (
                local_step %
            (config.save_checkpoints_steps * args.gradient_accumulation_steps)
                == 0):
            saved_ckpt = True
            if is_main_process():
                save_path = manager.save(checkpoint_number=step)
                log(" ** Saved model checkpoint for step {}: {}".format(
                    step, save_path))
            iter_save_path = iter_manager.save(checkpoint_number=step)
            log(" ** Saved iterator checkpoint for step {}: {}".format(
                step, iter_save_path),
                all_rank=True)

    step = (int(checkpoint.step) - 1)
    dllogger.flush()
    if not config.skip_checkpoint and not saved_ckpt:
        if is_main_process():
            save_path = manager.save(checkpoint_number=step)
            log(" ** Saved model checkpoint for step {}: {}".format(
                step, save_path))
        iter_save_path = iter_manager.save(checkpoint_number=step)
        log(" ** Saved iterator checkpoint for step {}: {}".format(
            step, iter_save_path),
            all_rank=True)

    return args
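
Code example #13 drives gradient accumulation through a GradientAccumulator object together with take_step=local_step % args.gradient_accumulation_steps == 0. The accumulator itself is not shown; a minimal sketch of the idea (an assumption, not NVIDIA's exact implementation):

import tensorflow as tf

class GradientAccumulator:
    # Hypothetical sketch: sum gradients over several micro-batches and
    # expose the running totals for a single optimizer step.

    def __init__(self):
        self._gradients = None

    def __call__(self, gradients):
        if self._gradients is None:
            self._gradients = [
                tf.Variable(tf.zeros_like(g), trainable=False)
                for g in gradients
            ]
        for acc, grad in zip(self._gradients, gradients):
            acc.assign_add(grad)

    @property
    def gradients(self):
        return [g.value() for g in self._gradients]

    def reset(self):
        if self._gradients is not None:
            for g in self._gradients:
                g.assign(tf.zeros_like(g))

Under this reading, train_one_step would accumulate on every call and, when take_step is true, apply accumulator.gradients (scaled by the number of accumulated micro-batches) before calling reset().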
Code example #14
File: run_pretrain.py Project: sharathts/electra
def main():
    # Parse essential args
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        required=True,
                        help="Location of data files (model weights, etc).")
    parser.add_argument("--model_name",
                        required=True,
                        help="The name of the model being fine-tuned.")
    parser.add_argument("--pretrain_tfrecords", type=str)

    parser.add_argument("--seed", type=int)
    parser.add_argument("--num_train_steps", type=int)
    parser.add_argument("--num_warmup_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--max_seq_length", type=int)

    parser.add_argument("--mask_prob", type=float)
    parser.add_argument("--disc_weight", type=float)
    parser.add_argument("--generator_hidden_size", type=float)

    parser.add_argument("--save_checkpoints_steps", type=int)
    parser.add_argument("--keep_checkpoint_max", type=int)
    parser.add_argument("--restore_checkpoint", action='store_true')

    parser.add_argument("--optimizer",
                        default="adam",
                        type=str,
                        help="adam or lamb")

    args = parser.parse_args()
    config = PretrainingConfig(**args.__dict__)

    # Set up tensorflow
    hvd.init()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')
    tf.config.optimizer.set_jit(config.xla)
    tf.config.optimizer.set_experimental_options(
        {"auto_mixed_precision": config.amp})
    tf.random.set_seed(config.seed)

    # Set up config
    if config.do_train == config.do_eval:
        raise ValueError(
            "Exactly one of `do_train` or `do_eval` must be True.")
    if config.debug and config.do_train:
        utils.rmkdir(config.model_dir)
    utils.heading("Config:")
    log_config(config)

    # Save pretrain configs
    pretrain_config_json = os.path.join(config.checkpoints_dir,
                                        'pretrain_config.json')
    if is_main_process():
        utils.write_json(config.__dict__, pretrain_config_json)
        log("Configuration saved in {}".format(pretrain_config_json))

    # Set up model
    model = PretrainingModel(config)

    # Set up metrics
    perf_metrics = dict()
    perf_metrics["train_perf"] = tf.keras.metrics.Mean(name="train_perf")

    eval_metrics = dict()
    eval_metrics["total_loss"] = tf.keras.metrics.Mean(name="total_loss")
    eval_metrics["masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
        name="masked_lm_accuracy")
    eval_metrics["masked_lm_loss"] = tf.keras.metrics.Mean(
        name="masked_lm_loss")
    if config.electra_objective:
        eval_metrics["sampled_masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
            name="sampled_masked_lm_accuracy")
        if config.disc_weight > 0:
            eval_metrics["disc_loss"] = tf.keras.metrics.Mean(name="disc_loss")
            eval_metrics["disc_auc"] = tf.keras.metrics.AUC(name="disc_auc")
            eval_metrics["disc_accuracy"] = tf.keras.metrics.Accuracy(
                name="disc_accuracy")
            eval_metrics["disc_precision"] = tf.keras.metrics.Accuracy(
                name="disc_precision")
            eval_metrics["disc_recall"] = tf.keras.metrics.Accuracy(
                name="disc_recall")

    # Set up tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(
        config.log_dir, current_time,
        'train_' + str(get_rank()) + '_of_' + str(get_world_size()))
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Set up dataset
    dataset = pretrain_utils.get_dataset(config,
                                         config.train_batch_size,
                                         world_size=get_world_size(),
                                         rank=get_rank())
    train_iterator = iter(dataset)

    # Set up optimizer
    optimizer = create_optimizer(init_lr=config.learning_rate,
                                 num_train_steps=config.num_train_steps,
                                 num_warmup_steps=config.num_warmup_steps,
                                 weight_decay_rate=config.weight_decay_rate,
                                 optimizer=config.optimizer)
    if config.amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, "dynamic")

    if config.do_train:
        # Set up checkpoint manager
        checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                         optimizer=optimizer,
                                         model=model)
        manager = tf.train.CheckpointManager(
            checkpoint,
            config.checkpoints_dir,
            max_to_keep=config.keep_checkpoint_max)
        iter_checkpoint = tf.train.Checkpoint(train_iterator=train_iterator)
        iter_manager = tf.train.CheckpointManager(
            iter_checkpoint,
            os.path.join(config.checkpoints_dir,
                         'iter_ckpt_rank_' + '{:02}'.format(get_rank())),
            checkpoint_name='iter_ckpt_rank_' + '{:02}'.format(get_rank()),
            max_to_keep=config.keep_checkpoint_max)
        if config.restore_checkpoint and manager.latest_checkpoint:
            checkpoint.restore(manager.latest_checkpoint)
            log(" ** Restored model checkpoint from {}".format(
                manager.latest_checkpoint))
            if iter_manager.latest_checkpoint:
                iter_checkpoint.restore(iter_manager.latest_checkpoint)
                log(" ** Restored iterator checkpoint from {}".format(
                    iter_manager.latest_checkpoint),
                    all_rank=True)
        else:
            log(" ** Initializing from scratch.")

        utils.heading("Running training")
        train_start, start_step = time.time(), int(checkpoint.step) - 1
        while int(checkpoint.step) <= config.num_train_steps:
            step = int(checkpoint.step)
            features = next(train_iterator)
            iter_start = time.time()

            # if step == 200: tf.profiler.experimental.start(logdir=train_log_dir)
            total_loss, eval_fn_inputs = train_one_step(
                config, model, optimizer, features, step <= 1)
            # if step == 300: tf.profiler.experimental.stop()

            perf_metrics["train_perf"].update_state(config.train_batch_size *
                                                    get_world_size() /
                                                    (time.time() - iter_start))
            eval_metrics["total_loss"].update_state(values=total_loss)
            metric_fn(config, eval_metrics, eval_fn_inputs)

            if step % 100 == 0:
                log('Step:{:6d}, Loss:{:10.6f}, Gen_loss:{:10.6f}, Disc_loss:{:10.6f}, Gen_acc:{:6.2f}, '
                    'Disc_acc:{:6.2f}, Perf:{:4.0f}, Elapsed: {}, ETA: {}, '.
                    format(
                        step, total_loss,
                        eval_metrics["masked_lm_loss"].result().numpy(),
                        eval_metrics["disc_loss"].result().numpy(),
                        eval_metrics["masked_lm_accuracy"].result().numpy() *
                        100,
                        eval_metrics["disc_accuracy"].result().numpy() * 100,
                        perf_metrics["train_perf"].result().numpy(),
                        utils.get_readable_time(time.time() - train_start),
                        utils.get_readable_time(
                            (time.time() - train_start) / (step - start_step) *
                            (config.num_train_steps - step))),
                    all_rank=True)

                with train_summary_writer.as_default():
                    for key, m in eval_metrics.items():
                        tf.summary.scalar(key, m.result(), step=step)

                for m in eval_metrics.values():
                    m.reset_states()

            checkpoint.step.assign_add(1)
            if step % config.save_checkpoints_steps == 0:
                if is_main_process():
                    save_path = manager.save()
                    log(" ** Saved model checkpoint for step {}: {}".format(
                        step, save_path))
                iter_save_path = iter_manager.save()
                log(" ** Saved iterator checkpoint for step {}: {}".format(
                    step, iter_save_path),
                    all_rank=True)

    if config.do_eval:
        pass
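
Both ELECTRA snippets also rely on small Horovod wrappers (get_rank, get_world_size, is_main_process, and the rank-aware log) defined in the projects' utils module. A plausible sketch, assuming Horovod has been initialized with hvd.init():

import horovod.tensorflow as hvd

def get_rank():
    return hvd.rank()

def get_world_size():
    return hvd.size()

def is_main_process():
    # Conventionally, rank 0 handles printing and checkpoint saving.
    return hvd.rank() == 0

def log(*args, all_rank=False):
    # Hypothetical sketch: print on rank 0 only, unless all_rank=True.
    if all_rank or is_main_process():
        print(*args)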